Initialization

雅诺狐
2025-08-11 19:34:18 +08:00
committed by Windpicker-owo
parent ef7a3aee23
commit 23ee3767ef
77 changed files with 10000 additions and 7525 deletions


@@ -12,9 +12,10 @@ import binascii
from typing import Optional, Tuple, List, Any
from PIL import Image
from rich.traceback import install
from src.common.database.database_model import Emoji
from src.common.database.database import db as peewee_db
from sqlalchemy import select
from src.common.database.database import db
from src.common.database.sqlalchemy_database_api import get_session
from src.common.database.sqlalchemy_models import Emoji, Images
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.chat.utils.utils_image import image_path_to_base64, get_image_manager
@@ -29,6 +30,8 @@ EMOJI_DIR = os.path.join(BASE_DIR, "emoji") # 表情包存储目录
EMOJI_REGISTERED_DIR = os.path.join(BASE_DIR, "emoji_registed") # 已注册的表情包注册目录
MAX_EMOJI_FOR_PROMPT = 20 # 最大允许的表情包描述数量于图片替换的 prompt 中
session = get_session()
"""
Not yet tested; in some places the database and in-memory state may not be fully synchronized.
@@ -151,7 +154,7 @@ class MaiEmoji:
# 准备数据库记录 for emoji collection
emotion_str = ",".join(self.emotion) if self.emotion else ""
Emoji.create(
emoji = Emoji(
emoji_hash=self.hash,
full_path=self.full_path,
format=self.format,
@@ -165,6 +168,8 @@ class MaiEmoji:
usage_count=self.usage_count,
last_used_time=self.last_used_time,
)
session.add(emoji)
session.commit()
logger.info(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})")
@@ -200,7 +205,7 @@ class MaiEmoji:
# 2. 删除数据库记录
try:
will_delete_emoji = Emoji.get(Emoji.emoji_hash == self.hash)
will_delete_emoji = session.execute(select(Emoji).where(Emoji.emoji_hash == self.hash)).scalar_one_or_none()
if will_delete_emoji is None:  # scalar_one_or_none() returns None where Peewee raised DoesNotExist
    logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。")
    result = 0
else:
    session.delete(will_delete_emoji)  # replaces Peewee's delete_instance()
    session.commit()
    result = 1  # number of rows deleted, matching the old Peewee return value
@@ -248,7 +253,6 @@ def _emoji_objects_to_readable_list(emoji_objects: List["MaiEmoji"]) -> List[str
def _to_emoji_objects(data: Any) -> Tuple[List["MaiEmoji"], int]:
emoji_objects = []
load_errors = 0
# data is now an iterable of Peewee Emoji model instances
emoji_data_list = list(data)
for emoji_data in emoji_data_list: # emoji_data is an Emoji model instance
@@ -393,12 +397,17 @@ class EmojiManager:
def initialize(self) -> None:
"""初始化数据库连接和表情目录"""
peewee_db.connect(reuse_if_open=True)
if peewee_db.is_closed():
raise RuntimeError("数据库连接失败")
_ensure_emoji_dir()
Emoji.create_table(safe=True) # Ensures table exists
self._initialized = True
try:
db.connect(reuse_if_open=True)
if db.is_closed():
raise RuntimeError("数据库连接失败")
_ensure_emoji_dir()
self._initialized = True # 标记为已初始化
logger.info("EmojiManager初始化成功")
except Exception as e:
logger.error(f"EmojiManager初始化失败: {e}")
self._initialized = False
raise
def _ensure_db(self) -> None:
"""确保数据库已初始化"""
@@ -410,7 +419,7 @@ class EmojiManager:
def record_usage(self, emoji_hash: str) -> None:
"""记录表情使用次数"""
try:
emoji_update = Emoji.get(Emoji.emoji_hash == emoji_hash)
emoji_update = session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalar_one_or_none()
if emoji_update is None:
    return
emoji_update.usage_count += 1
emoji_update.last_used_time = time.time()  # update last used time
session.commit()  # the session tracks attribute mutations; Peewee's save() is not needed
@@ -644,10 +653,10 @@ class EmojiManager:
"""获取所有表情包并初始化为MaiEmoji类对象更新 self.emoji_objects"""
try:
self._ensure_db()
logger.debug("[数据库] 开始加载所有表情包记录 (Peewee)...")
logger.debug("[数据库] 开始加载所有表情包记录 ...")
emoji_peewee_instances = Emoji.select()
emoji_objects, load_errors = _to_emoji_objects(emoji_peewee_instances)
emoji_instances = session.execute(select(Emoji)).scalars().all()
emoji_objects, load_errors = _to_emoji_objects(emoji_instances)
# 更新内存中的列表和数量
self.emoji_objects = emoji_objects
@@ -675,15 +684,15 @@ class EmojiManager:
self._ensure_db()
if emoji_hash:
query = Emoji.select().where(Emoji.emoji_hash == emoji_hash)
query = session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalars().all()
else:
logger.warning(
"[查询] 未提供 hash将尝试加载所有表情包建议使用 get_all_emoji_from_db 更新管理器状态。"
)
query = Emoji.select()
query = session.execute(select(Emoji)).scalars().all()
emoji_peewee_instances = query
emoji_objects, load_errors = _to_emoji_objects(emoji_peewee_instances)
emoji_instances = query
emoji_objects, load_errors = _to_emoji_objects(emoji_instances)
if load_errors > 0:
logger.warning(f"[查询] 加载过程中出现 {load_errors} 个错误。")
@@ -760,7 +769,7 @@ class EmojiManager:
# 如果内存中没有,从数据库查找
self._ensure_db()
try:
emoji_record = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash)
emoji_record = session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalar_one_or_none()
if emoji_record and emoji_record.description:
logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.description[:50]}...")
return emoji_record.description
@@ -921,9 +930,10 @@ class EmojiManager:
# 尝试从Images表获取已有的详细描述可能在收到表情包时已生成
existing_description = None
try:
from src.common.database.database_model import Images
# from src.common.database.database_model_compat import Images
existing_image = Images.get_or_none((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
stmt = select(Images).where((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
existing_image = session.execute(stmt).scalar_one_or_none()
if existing_image and existing_image.description:
existing_description = existing_image.description
logger.info(f"[复用描述] 找到已有详细描述: {existing_description[:50]}...")


@@ -7,7 +7,9 @@ from datetime import datetime
from typing import List, Dict, Optional, Any, Tuple
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.common.database.sqlalchemy_database_api import get_session
from sqlalchemy import select
from src.common.database.sqlalchemy_models import Expression
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive, build_anonymous_messages
@@ -20,7 +22,7 @@ DECAY_DAYS = 30 # 30天衰减到0.01
DECAY_MIN = 0.01 # 最小衰减值
logger = get_logger("expressor")
session = get_session()
def format_create_date(timestamp: float) -> str:
"""
@@ -168,30 +170,50 @@ class ExpressionLearner:
logger.error(f"为聊天流 {self.chat_name} 触发学习失败: {e}")
return False
# def get_expression_by_chat_id(self) -> Tuple[List[Dict[str, float]], List[Dict[str, float]]]:
# """
# 获取指定chat_id的style表达方式(已禁用grammar的获取)
# 返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作
# """
# learnt_style_expressions = []
def get_expression_by_chat_id(self) -> Tuple[List[Dict[str, float]], List[Dict[str, float]]]:
"""
获取指定chat_id的stylegrammar表达方式
返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作
"""
learnt_style_expressions = []
learnt_grammar_expressions = []
# 直接从数据库查询
style_query = session.execute(select(Expression).where((Expression.chat_id == self.chat_id) & (Expression.type == "style")))
for expr in style_query.scalars():
# 确保create_date存在如果不存在则使用last_active_time
create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
learnt_style_expressions.append(
{
"situation": expr.situation,
"style": expr.style,
"count": expr.count,
"last_active_time": expr.last_active_time,
"source_id": self.chat_id,
"type": "style",
"create_date": create_date,
}
)
grammar_query = session.execute(select(Expression).where((Expression.chat_id == self.chat_id) & (Expression.type == "grammar")))
for expr in grammar_query.scalars():
# 确保create_date存在如果不存在则使用last_active_time
create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
learnt_grammar_expressions.append(
{
"situation": expr.situation,
"style": expr.style,
"count": expr.count,
"last_active_time": expr.last_active_time,
"source_id": self.chat_id,
"type": "grammar",
"create_date": create_date,
}
)
return learnt_style_expressions, learnt_grammar_expressions
# # 直接从数据库查询
# style_query = Expression.select().where((Expression.chat_id == self.chat_id) & (Expression.type == "style"))
# for expr in style_query:
# # 确保create_date存在如果不存在则使用last_active_time
# create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
# learnt_style_expressions.append(
# {
# "situation": expr.situation,
# "style": expr.style,
# "count": expr.count,
# "last_active_time": expr.last_active_time,
# "source_id": self.chat_id,
# "type": "style",
# "create_date": create_date,
# }
# )
# return learnt_style_expressions
@@ -201,7 +223,7 @@ class ExpressionLearner:
"""
try:
# 获取所有表达方式
all_expressions = Expression.select()
all_expressions = session.execute(select(Expression)).scalars()
updated_count = 0
deleted_count = 0
@@ -217,18 +239,20 @@ class ExpressionLearner:
if new_count <= 0.01:
# 如果count太小删除这个表达方式
expr.delete_instance()
session.delete(expr)
deleted_count += 1
else:
# 更新count
expr.count = new_count
expr.save()
updated_count += 1
session.commit()
if updated_count > 0 or deleted_count > 0:
logger.info(f"全局衰减完成:更新了 {updated_count} 个表达方式,删除了 {deleted_count} 个表达方式")
except Exception as e:
session.rollback()
logger.error(f"数据库全局衰减失败: {e}")
def calculate_decay_factor(self, time_diff_days: float) -> float:
@@ -297,23 +321,22 @@ class ExpressionLearner:
for chat_id, expr_list in chat_dict.items():
for new_expr in expr_list:
# 查找是否已存在相似表达方式
query = Expression.select().where(
query = session.execute(select(Expression).where(
(Expression.chat_id == chat_id)
& (Expression.type == "style")
& (Expression.situation == new_expr["situation"])
& (Expression.style == new_expr["style"])
)
if query.exists():
expr_obj = query.get()
)).scalar()
if query:
expr_obj = query
# 50%概率替换内容
if random.random() < 0.5:
expr_obj.situation = new_expr["situation"]
expr_obj.style = new_expr["style"]
expr_obj.count = expr_obj.count + 1
expr_obj.last_active_time = current_time
expr_obj.save()
else:
Expression.create(
new_expression = Expression(
situation=new_expr["situation"],
style=new_expr["style"],
count=1,
@@ -322,16 +345,18 @@ class ExpressionLearner:
type="style",
create_date=current_time, # 手动设置创建日期
)
session.add(new_expression)
# 限制最大数量
exprs = list(
Expression.select()
.where((Expression.chat_id == chat_id) & (Expression.type == "style"))
.order_by(Expression.count.asc())
session.execute(select(Expression)
.where((Expression.chat_id == chat_id) & (Expression.type == "style"))
.order_by(Expression.count.asc())).scalars()
)
if len(exprs) > MAX_EXPRESSION_COUNT:
# 删除count最小的多余表达方式
for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]:
expr.delete_instance()
session.delete(expr)
session.commit()
return learnt_expressions
async def learn_expression(self, num: int = 10) -> Optional[Tuple[List[Tuple[str, str, str]], str]]:
@@ -509,54 +534,35 @@ class ExpressionLearnerManager:
logger.warning(f"表达方式缺少必要字段,跳过: {expr}")
continue
# 查重同chat_id+type+situation+style
from src.common.database.database_model import Expression
# 查重同chat_id+type+situation+style
query = Expression.select().where(
(Expression.chat_id == chat_id)
& (Expression.type == type_str)
& (Expression.situation == situation)
& (Expression.style == style_val)
query = session.execute(select(Expression).where(
(Expression.chat_id == chat_id)
& (Expression.type == type_str)
& (Expression.situation == situation)
& (Expression.style == style_val)
)).scalar()
if query:
expr_obj = query
expr_obj.count = max(expr_obj.count, count)
expr_obj.last_active_time = max(expr_obj.last_active_time, last_active_time)
else:
new_expression = Expression(
situation=situation,
style=style_val,
count=count,
last_active_time=last_active_time,
chat_id=chat_id,
type=type_str,
create_date=last_active_time, # 迁移时使用last_active_time作为创建时间
)
if query.exists():
expr_obj = query.get()
expr_obj.count = max(expr_obj.count, count)
expr_obj.last_active_time = max(expr_obj.last_active_time, last_active_time)
expr_obj.save()
else:
Expression.create(
situation=situation,
style=style_val,
count=count,
last_active_time=last_active_time,
chat_id=chat_id,
type=type_str,
create_date=last_active_time, # 迁移时使用last_active_time作为创建时间
)
migrated_count += 1
logger.info(f"已迁移 {expr_file} 到数据库,包含 {len(expressions)} 个表达方式")
except json.JSONDecodeError as e:
logger.error(f"JSON解析失败 {expr_file}: {e}")
except Exception as e:
logger.error(f"迁移表达方式 {expr_file} 失败: {e}")
# 标记迁移完成
try:
# 确保done.done文件的父目录存在
done_parent_dir = os.path.dirname(done_flag)
if not os.path.exists(done_parent_dir):
os.makedirs(done_parent_dir, exist_ok=True)
logger.debug(f"为done.done创建父目录: {done_parent_dir}")
with open(done_flag, "w", encoding="utf-8") as f:
f.write("done\n")
logger.info(f"表达方式JSON迁移已完成共迁移 {migrated_count} 个表达方式已写入done.done标记文件")
except PermissionError as e:
logger.error(f"权限不足无法写入done.done标记文件: {e}")
except OSError as e:
logger.error(f"文件系统错误无法写入done.done标记文件: {e}")
except Exception as e:
logger.error(f"写入done.done标记文件失败: {e}")
session.add(new_expression)
migrated_count += 1
logger.info(f"已迁移 {expr_file} 到数据库,包含 {len(expressions)} 个表达方式")
except json.JSONDecodeError as e:
logger.error(f"JSON解析失败 {expr_file}: {e}")
except Exception as e:
logger.error(f"迁移表达方式 {expr_file} 失败: {e}")
# 检查并处理grammar表达删除
if not os.path.exists(done_flag2):
@@ -581,18 +587,20 @@ class ExpressionLearnerManager:
"""
try:
# 查找所有create_date为空的表达方式
old_expressions = Expression.select().where(Expression.create_date.is_null())
old_expressions = session.execute(select(Expression).where(Expression.create_date.is_(None))).scalars()
updated_count = 0
for expr in old_expressions:
# 使用last_active_time作为create_date
expr.create_date = expr.last_active_time
expr.save()
updated_count += 1
session.commit()
if updated_count > 0:
logger.info(f"已为 {updated_count} 个老的表达方式设置创建日期")
except Exception as e:
session.rollback()
logger.error(f"迁移老数据创建日期失败: {e}")
def delete_all_grammar_expressions(self) -> int:


@@ -9,8 +9,11 @@ from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from sqlalchemy import select
from src.common.database.sqlalchemy_models import Expression
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.common.database.sqlalchemy_database_api import get_session
session = get_session()
logger = get_logger("expression_selector")
@@ -131,9 +134,12 @@ class ExpressionSelector:
related_chat_ids = self.get_related_chat_ids(chat_id)
# 优化一次性查询所有相关chat_id的表达方式
style_query = Expression.select().where(
style_query = session.execute(select(Expression).where(
(Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style")
)
))
grammar_query = session.execute(select(Expression).where(
(Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "grammar")
))
style_exprs = [
{
@@ -146,9 +152,24 @@ class ExpressionSelector:
"type": "style",
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
}
for expr in style_query
for expr in style_query.scalars()
]
grammar_exprs = [
{
"situation": expr.situation,
"style": expr.style,
"count": expr.count,
"last_active_time": expr.last_active_time,
"source_id": expr.chat_id,
"type": "grammar",
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
}
for expr in grammar_query.scalars()
]
style_num = int(total_num * style_percentage)
grammar_num = int(total_num * grammar_percentage)
# 按权重抽样使用count作为权重
if style_exprs:
style_weights = [expr.get("count", 1) for expr in style_exprs]
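random.choices is the natural fit for the count-weighted draw set up above; note that it samples with replacement, so any dedup has to happen downstream of this hunk. A toy run:

import random

style_exprs = [{"style": "s1", "count": 3}, {"style": "s2", "count": 1}]
weights = [e.get("count", 1) for e in style_exprs]  # same weighting as above
picked = random.choices(style_exprs, weights=weights, k=1)
print(picked[0]["style"])  # "s1" roughly 75% of the time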
@@ -174,19 +195,19 @@ class ExpressionSelector:
if key not in updates_by_key:
updates_by_key[key] = expr
for chat_id, expr_type, situation, style in updates_by_key:
query = Expression.select().where(
query = session.execute(select(Expression).where(
(Expression.chat_id == chat_id)
& (Expression.type == expr_type)
& (Expression.situation == situation)
& (Expression.style == style)
)
if query.exists():
expr_obj = query.get()
)).scalar()
if query:
expr_obj = query
current_count = expr_obj.count
new_count = min(current_count + increment, 5.0)
expr_obj.count = new_count
expr_obj.last_active_time = time.time()
expr_obj.save()
session.commit()
logger.debug(
f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in db"
)


@@ -1,33 +0,0 @@
raise DeprecationWarning("MemoryActiveManager is not used yet, please do not import it")
from .lpmmconfig import global_config
from .embedding_store import EmbeddingManager
from .llm_client import LLMClient
from .utils.dyn_topk import dyn_select_top_k
class MemoryActiveManager:
def __init__(
self,
embed_manager: EmbeddingManager,
llm_client_embedding: LLMClient,
):
self.embed_manager = embed_manager
self.embedding_client = llm_client_embedding
def get_activation(self, question: str) -> float:
"""获取记忆激活度"""
# 生成问题的Embedding
question_embedding = self.embedding_client.send_embedding_request("text-embedding", question)
# 查询关系库中的相似度
rel_search_res = self.embed_manager.relation_embedding_store.search_top_k(question_embedding, 10)
# 动态过滤阈值
rel_scores = dyn_select_top_k(rel_search_res, 0.5, 1.0)
if rel_scores[0][1] < global_config["qa"]["params"]["relation_threshold"]:
# 未找到相关关系
return 0.0
# 计算激活度
activation = sum([item[2] for item in rel_scores]) * 10
return activation


@@ -16,8 +16,10 @@ from rich.traceback import install
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.database.database_model import Messages, GraphNodes, GraphEdges # Peewee Models导入
from sqlalchemy import select,insert,update,text,delete
from src.common.database.sqlalchemy_models import Messages, GraphNodes, GraphEdges # SQLAlchemy Models导入
from src.common.logger import get_logger
from src.common.database.sqlalchemy_database_api import get_session
from src.chat.memory_system.sample_distribution import MemoryBuildScheduler # 分布生成器
from src.chat.utils.chat_message_builder import (
get_raw_msg_by_timestamp,
@@ -37,7 +39,7 @@ def cosine_similarity(v1, v2):
install(extra_lines=3)
session = get_session()
def calculate_information_content(text):
"""计算文本的信息量(熵)"""
@@ -731,13 +733,14 @@ class Hippocampus:
memory_items = node_data.get("memory_items", "")
# 直接使用完整的记忆内容
if memory_items:
logger.debug("节点包含完整记忆")
# 计算记忆与关键词的相似度
memory_words = set(jieba.cut(memory_items))
text_words = set(keywords)
all_words = memory_words | text_words
if all_words:
# 计算相似度(虽然这里没有使用,但保持逻辑一致性)
logger.debug(f"节点包含 {len(memory_items)}记忆")
# 计算每条记忆与输入文本的相似度
memory_similarities = []
for memory in memory_items:
# 计算与输入文本的相似度
memory_words = set(jieba.cut(memory))
text_words = set(jieba.cut(text))
all_words = memory_words | text_words
v1 = [1 if word in memory_words else 0 for word in all_words]
v2 = [1 if word in text_words else 0 for word in all_words]
_ = cosine_similarity(v1, v2) # 计算但不使用用_表示
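The similarity computed here is plain bag-of-words cosine: tokenize both texts (jieba in the real code), build 0/1 vectors over the union vocabulary, and compare. A self-contained worked example, with hardcoded token sets standing in for jieba output:

import math

def cosine_similarity(v1, v2):
    dot = sum(a * b for a, b in zip(v1, v2))
    norm = math.sqrt(sum(a * a for a in v1)) * math.sqrt(sum(b * b for b in v2))
    return dot / norm if norm else 0.0

memory_words = {"猫", "喜欢", "鱼"}  # stands in for set(jieba.cut(memory))
text_words = {"猫", "讨厌", "水"}    # stands in for set(jieba.cut(text))
vocab = memory_words | text_words
v1 = [1 if w in memory_words else 0 for w in vocab]
v2 = [1 if w in text_words else 0 for w in vocab]
print(round(cosine_similarity(v1, v2), 3))  # one shared token of three each -> 0.333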
@@ -844,11 +847,6 @@ class Hippocampus:
else:
activate_map[node] = activation_value
# 输出激活映射
# logger.info("激活映射统计:")
# for node, total_activation in sorted(activate_map.items(), key=lambda x: x[1], reverse=True):
# logger.info(f"节点 '{node}': 累计激活值 = {total_activation:.2f}")
# 计算激活节点数与总节点数的比值
total_activation = sum(activate_map.values())
# logger.debug(f"总激活值: {total_activation:.2f}")
@@ -942,10 +940,13 @@ class EntorhinalCortex:
for message in messages:
# 确保在更新前获取最新的 memorized_times
current_memorized_times = message.get("memorized_times", 0)
# 使用 Peewee 更新记录
Messages.update(memorized_times=current_memorized_times + 1).where(
Messages.message_id == message["message_id"]
).execute()
# 使用 SQLAlchemy 2.0 更新记录
session.execute(
update(Messages)
.where(Messages.message_id == message["message_id"])
.values(memorized_times=current_memorized_times + 1)
)
session.commit()
return messages # 直接返回原始的消息列表
target_timestamp -= 120 # 如果第一次尝试失败,稍微向前调整时间戳再试
@@ -959,7 +960,7 @@ class EntorhinalCortex:
current_time = datetime.datetime.now().timestamp()
# 获取数据库中所有节点和内存中所有节点
db_nodes = {node.concept: node for node in GraphNodes.select()}
db_nodes = {node.concept: node for node in session.execute(select(GraphNodes)).scalars()}
memory_nodes = list(self.memory_graph.G.nodes(data=True))
# 批量准备节点数据
@@ -1025,22 +1026,27 @@ class EntorhinalCortex:
batch_size = 100
for i in range(0, len(nodes_to_create), batch_size):
batch = nodes_to_create[i : i + batch_size]
GraphNodes.insert_many(batch).execute()
session.execute(insert(GraphNodes), batch)
session.commit()
if nodes_to_update:
batch_size = 100
for i in range(0, len(nodes_to_update), batch_size):
batch = nodes_to_update[i : i + batch_size]
for node_data in batch:
GraphNodes.update(**{k: v for k, v in node_data.items() if k != "concept"}).where(
GraphNodes.concept == node_data["concept"]
).execute()
session.execute(
update(GraphNodes)
.where(GraphNodes.concept == node_data["concept"])
.values(**{k: v for k, v in node_data.items() if k != "concept"})
)
session.commit()
if nodes_to_delete:
GraphNodes.delete().where(GraphNodes.concept.in_(nodes_to_delete)).execute() # type: ignore
session.execute(delete(GraphNodes).where(GraphNodes.concept.in_(nodes_to_delete)))
session.commit()
# 处理边的信息
db_edges = list(GraphEdges.select())
db_edges = list(session.execute(select(GraphEdges)).scalars())
memory_edges = list(self.memory_graph.G.edges(data=True))
# 创建边的哈希值字典
@@ -1092,20 +1098,29 @@ class EntorhinalCortex:
batch_size = 100
for i in range(0, len(edges_to_create), batch_size):
batch = edges_to_create[i : i + batch_size]
GraphEdges.insert_many(batch).execute()
session.execute(insert(GraphEdges), batch)
session.commit()
if edges_to_update:
batch_size = 100
for i in range(0, len(edges_to_update), batch_size):
batch = edges_to_update[i : i + batch_size]
for edge_data in batch:
GraphEdges.update(**{k: v for k, v in edge_data.items() if k not in ["source", "target"]}).where(
(GraphEdges.source == edge_data["source"]) & (GraphEdges.target == edge_data["target"])
).execute()
session.execute(
update(GraphEdges)
.where(
(GraphEdges.source == edge_data["source"]) & (GraphEdges.target == edge_data["target"])
)
.values(**{k: v for k, v in edge_data.items() if k not in ["source", "target"]})
)
session.commit()
if edges_to_delete:
for source, target in edges_to_delete:
GraphEdges.delete().where((GraphEdges.source == source) & (GraphEdges.target == target)).execute()
session.execute(
delete(GraphEdges).where((GraphEdges.source == source) & (GraphEdges.target == target))
)
session.commit()
end_time = time.time()
logger.info(f"[数据库] 同步完成,总耗时: {end_time - start_time:.2f}")
@@ -1118,8 +1133,9 @@ class EntorhinalCortex:
# 清空数据库
clear_start = time.time()
GraphNodes.delete().execute()
GraphEdges.delete().execute()
session.execute(delete(GraphNodes))
session.execute(delete(GraphEdges))
session.commit()
clear_end = time.time()
logger.info(f"[数据库] 清空数据库耗时: {clear_end - clear_start:.2f}")
@@ -1186,12 +1202,27 @@ class EntorhinalCortex:
logger.error(f"准备边 {source}-{target} 数据时发生错误: {e}")
continue
# 批量插入边
# 批量写入节点
node_start = time.time()
if nodes_data:
batch_size = 500 # 增加批量大小
for i in range(0, len(nodes_data), batch_size):
batch = nodes_data[i : i + batch_size]
session.execute(insert(GraphNodes), batch)
session.commit()
node_end = time.time()
logger.info(f"[数据库] 写入 {len(nodes_data)} 个节点耗时: {node_end - node_start:.2f}")
# 批量写入边
edge_start = time.time()
if edges_data:
batch_size = 100
batch_size = 500 # 增加批量大小
for i in range(0, len(edges_data), batch_size):
batch = edges_data[i : i + batch_size]
GraphEdges.insert_many(batch).execute()
session.execute(insert(GraphEdges), batch)
session.commit()
edge_end = time.time()
logger.info(f"[数据库] 写入 {len(edges_data)} 条边耗时: {edge_end - edge_start:.2f}")
end_time = time.time()
logger.info(f"[数据库] 重新同步完成,总耗时: {end_time - start_time:.2f}")
@@ -1211,9 +1242,7 @@ class EntorhinalCortex:
skipped_nodes = 0
# 从数据库加载所有节点
nodes = list(GraphNodes.select())
total_nodes = len(nodes)
nodes = list(session.execute(select(GraphNodes)).scalars())
for node in nodes:
concept = node.concept
try:
@@ -1235,8 +1264,10 @@ class EntorhinalCortex:
if not node.last_modified:
update_data["last_modified"] = current_time
if update_data:
GraphNodes.update(**update_data).where(GraphNodes.concept == concept).execute()
session.execute(
update(GraphNodes).where(GraphNodes.concept == concept).values(**update_data)
)
session.commit()
# 获取时间信息(如果不存在则使用当前时间)
created_time = node.created_time or current_time
@@ -1256,7 +1287,7 @@ class EntorhinalCortex:
continue
# 从数据库加载所有边
edges = list(GraphEdges.select())
edges = list(session.execute(select(GraphEdges)).scalars())
for edge in edges:
source = edge.source
target = edge.target
@@ -1272,9 +1303,12 @@ class EntorhinalCortex:
if not edge.last_modified:
update_data["last_modified"] = current_time
GraphEdges.update(**update_data).where(
(GraphEdges.source == source) & (GraphEdges.target == target)
).execute()
session.execute(
update(GraphEdges)
.where((GraphEdges.source == source) & (GraphEdges.target == target))
.values(**update_data)
)
session.commit()
# 获取时间信息(如果不存在则使用当前时间)
created_time = edge.created_time or current_time
@@ -1398,7 +1432,6 @@ class ParahippocampalGyrus:
all_words = topic_words | existing_words
v1 = [1 if word in topic_words else 0 for word in all_words]
v2 = [1 if word in existing_words else 0 for word in all_words]
similarity = cosine_similarity(v1, v2)
if similarity >= 0.7:
@@ -1502,7 +1535,7 @@ class ParahippocampalGyrus:
check_nodes_count = max(1, min(len(all_nodes), int(len(all_nodes) * percentage)))
check_edges_count = max(1, min(len(all_edges), int(len(all_edges) * percentage)))
# 只有在有足够的节点和边时进行采样
if len(all_nodes) >= check_nodes_count and len(all_edges) >= check_edges_count:
try:
nodes_to_check = random.sample(all_nodes, check_nodes_count)
@@ -1548,6 +1581,11 @@ class ParahippocampalGyrus:
logger.info("[遗忘] 开始检查节点...")
node_check_start = time.time()
# 初始化整合相关变量
merged_count = 0
nodes_modified = set()
for node in nodes_to_check:
# 检查节点是否存在,以防在迭代中被移除(例如边移除导致)
if node not in self.memory_graph.G:
@@ -1567,64 +1605,91 @@ class ParahippocampalGyrus:
logger.warning(f"[遗忘] 移除空节点 {node} 时发生错误(可能已被移除): {e}")
continue # 处理下一个节点
# --- 如果节点不为空,则执行原来的不活跃检查和随机移除逻辑 ---
# 检查节点的最后修改时间,如果太旧则尝试遗忘
last_modified = node_data.get("last_modified", current_time)
node_weight = node_data.get("weight", 1.0)
# 条件1检查是否长时间未修改 (使用配置的遗忘时间)
time_threshold = 3600 * global_config.memory.memory_forget_time
# 基于权重调整遗忘阈值:权重越高,需要更长时间才能被遗忘
# 权重为1时使用默认阈值权重越高阈值越大越难遗忘
adjusted_threshold = time_threshold * node_weight
if current_time - last_modified > adjusted_threshold and memory_items:
# 既然每个节点现在是完整记忆,直接删除整个节点
try:
self.memory_graph.G.remove_node(node)
node_changes["removed"].append(f"{node}(长时间未修改,权重{node_weight:.1f})")
logger.debug(f"[遗忘] 移除了长时间未修改的节点: {node} (权重: {node_weight:.1f})")
except nx.NetworkXError as e:
logger.warning(f"[遗忘] 移除节点 {node} 时发生错误(可能已被移除): {e}")
continue
if current_time - last_modified > 3600 * global_config.memory.memory_forget_time:
# 随机遗忘一条记忆
if len(memory_items) > 1:
removed_item = self.memory_graph.forget_topic(node)
if removed_item:
node_changes["reduced"].append(f"{node} (移除: {removed_item[:50]}...)")
elif len(memory_items) == 1:
# 如果只有一条记忆,检查是否应该完全移除节点
try:
self.memory_graph.G.remove_node(node)
node_changes["removed"].append(f"{node} (最后记忆)")
except nx.NetworkXError as e:
logger.warning(f"[遗忘] 移除节点 {node} 时发生错误: {e}")
# 检查节点内是否有相似的记忆项需要整合
if len(memory_items) > 1:
merged_in_this_node = False
items_to_remove = []
for i in range(len(memory_items)):
for j in range(i + 1, len(memory_items)):
similarity = self._calculate_item_similarity(memory_items[i], memory_items[j])
if similarity > 0.8: # 相似度阈值
# 合并相似记忆项
longer_item = memory_items[i] if len(memory_items[i]) > len(memory_items[j]) else memory_items[j]
shorter_item = memory_items[j] if len(memory_items[i]) > len(memory_items[j]) else memory_items[i]
# 保留更长的记忆项,标记短的用于删除
if shorter_item not in items_to_remove:
items_to_remove.append(shorter_item)
merged_count += 1
merged_in_this_node = True
logger.debug(f"[整合] 在节点 {node} 中合并相似记忆: {shorter_item[:30]}... -> {longer_item[:30]}...")
# 移除被合并的记忆项
if items_to_remove:
for item in items_to_remove:
if item in memory_items:
memory_items.remove(item)
nodes_modified.add(node)
# 更新节点的记忆项
self.memory_graph.G.nodes[node]["memory_items"] = memory_items
self.memory_graph.G.nodes[node]["last_modified"] = current_time
node_check_end = time.time()
logger.info(f"[遗忘] 节点检查耗时: {node_check_end - node_check_start:.2f}")
if any(edge_changes.values()) or any(node_changes.values()):
# 输出变化统计
if edge_changes["weakened"]:
logger.info(f"[遗忘] 减弱了 {len(edge_changes['weakened'])} 个连接")
if edge_changes["removed"]:
logger.info(f"[遗忘] 移除了 {len(edge_changes['removed'])} 个连接")
if node_changes["reduced"]:
logger.info(f"[遗忘] 减少了 {len(node_changes['reduced'])} 个节点的记忆")
if node_changes["removed"]:
logger.info(f"[遗忘] 移除了 {len(node_changes['removed'])} 个节点")
# 检查是否有变化需要同步到数据库
has_changes = (
edge_changes["weakened"] or
edge_changes["removed"] or
node_changes["reduced"] or
node_changes["removed"] or
merged_count > 0
)
if has_changes:
logger.info("[遗忘] 开始将变更同步到数据库...")
sync_start = time.time()
await self.hippocampus.entorhinal_cortex.resync_memory_to_db()
await self.hippocampus.entorhinal_cortex.sync_memory_to_db()
sync_end = time.time()
logger.info(f"[遗忘] 数据库同步耗时: {sync_end - sync_start:.2f}")
# 汇总输出所有变化
logger.info("[遗忘] 遗忘操作统计:")
if edge_changes["weakened"]:
logger.info(
f"[遗忘] 减弱的连接 ({len(edge_changes['weakened'])}个): {', '.join(edge_changes['weakened'])}"
)
if edge_changes["removed"]:
logger.info(
f"[遗忘] 移除的连接 ({len(edge_changes['removed'])}个): {', '.join(edge_changes['removed'])}"
)
if node_changes["reduced"]:
logger.info(
f"[遗忘] 减少记忆的节点 ({len(node_changes['reduced'])}个): {', '.join(node_changes['reduced'])}"
)
if node_changes["removed"]:
logger.info(
f"[遗忘] 移除的节点 ({len(node_changes['removed'])}个): {', '.join(node_changes['removed'])}"
)
if merged_count > 0:
logger.info(f"[整合] 共合并了 {merged_count} 对相似记忆项,分布在 {len(nodes_modified)} 个节点中。")
sync_start = time.time()
logger.info("[整合] 开始将变更同步到数据库...")
# 使用 resync 更安全地处理删除和添加
await self.hippocampus.entorhinal_cortex.resync_memory_to_db()
sync_end = time.time()
logger.info(f"[整合] 数据库同步耗时: {sync_end - sync_start:.2f}")
else:
logger.info("[遗忘] 本次检查没有节点或连接满足遗忘条件")
end_time = time.time()
logger.info(f"[遗忘] 总耗时: {end_time - start_time:.2f}")
logger.info("[整合] 本次检查未发现需要合并的记忆项。")
@@ -1734,10 +1799,7 @@ class HippocampusManager:
"""获取所有节点名称的公共接口"""
if not self._initialized:
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
return self._hippocampus.get_all_node_names()
# 创建全局实例
hippocampus_manager = HippocampusManager()


@@ -10,12 +10,13 @@ from datetime import datetime, timedelta
from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger
from src.common.database.database_model import Memory # Peewee Models导入
from src.common.database.sqlalchemy_models import Memory # SQLAlchemy Models导入
from src.common.database.sqlalchemy_database_api import get_session
from src.config.config import model_config
from sqlalchemy import select
logger = get_logger(__name__)
session = get_session()
class MemoryItem:
def __init__(self, memory_id: str, chat_id: str, memory_text: str, keywords: list[str]):
@@ -120,7 +121,8 @@ class InstantMemory:
create_time=memory_item.create_time,
last_view_time=memory_item.last_view_time,
)
memory.save()
session.add(memory)
session.commit()
async def get_memory(self, target: str):
from json_repair import repair_json
@@ -166,13 +168,13 @@ class InstantMemory:
if start_time and end_time:
start_ts = start_time.timestamp()
end_ts = end_time.timestamp()
query = Memory.select().where(
query = session.execute(select(Memory).where(
(Memory.chat_id == self.chat_id)
& (Memory.create_time >= start_ts) # type: ignore
& (Memory.create_time < end_ts) # type: ignore
)
& (Memory.create_time >= start_ts)
& (Memory.create_time < end_ts)
)).scalars()
else:
query = Memory.select().where(Memory.chat_id == self.chat_id)
query = session.execute(select(Memory).where(Memory.chat_id == self.chat_id)).scalars()
for mem in query:
# 对每条记忆


@@ -8,8 +8,12 @@ from maim_message import GroupInfo, UserInfo
from src.common.logger import get_logger
from src.common.database.database import db
from src.common.database.database_model import ChatStreams # 新增导入
from sqlalchemy import select, text
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
from sqlalchemy.dialects.mysql import insert as mysql_insert
from src.common.database.sqlalchemy_models import ChatStreams # 新增导入
from src.common.database.sqlalchemy_database_api import get_session
from src.config.config import global_config # 新增导入
# 避免循环导入使用TYPE_CHECKING进行类型提示
if TYPE_CHECKING:
from .message import MessageRecv
@@ -19,7 +23,7 @@ install(extra_lines=3)
logger = get_logger("chat_stream")
session = get_session()
class ChatMessageContext:
"""聊天消息上下文,存储消息的上下文信息"""
@@ -131,7 +135,8 @@ class ChatManager:
try:
db.connect(reuse_if_open=True)
# 确保 ChatStreams 表存在
db.create_tables([ChatStreams], safe=True)
session.execute(text("CREATE TABLE IF NOT EXISTS chat_streams (stream_id TEXT PRIMARY KEY, platform TEXT, create_time REAL, last_active_time REAL, user_platform TEXT, user_id TEXT, user_nickname TEXT, user_cardname TEXT, group_platform TEXT, group_id TEXT, group_name TEXT)"))
session.commit()
except Exception as e:
logger.error(f"数据库连接或 ChatStreams 表创建失败: {e}")
@@ -231,7 +236,7 @@ class ChatManager:
# 检查数据库中是否存在
def _db_find_stream_sync(s_id: str):
return ChatStreams.get_or_none(ChatStreams.stream_id == s_id)
return session.execute(select(ChatStreams).where(ChatStreams.stream_id == s_id)).scalar()
model_instance = await asyncio.to_thread(_db_find_stream_sync, stream_id)
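Because the migrated session is synchronous, lookups on the async path are pushed off the event loop with asyncio.to_thread, as _db_find_stream_sync shows. The pattern in isolation (session and ChatStreams as defined in this module); note that sharing one module-level session across worker threads is not thread-safe in general, so a per-call session from a sessionmaker would be the safer variant:

import asyncio
from sqlalchemy import select

async def find_stream(stream_id: str):
    def _lookup(s_id: str):
        # Blocking SQLAlchemy call, kept off the event loop.
        return session.execute(
            select(ChatStreams).where(ChatStreams.stream_id == s_id)
        ).scalar()
    return await asyncio.to_thread(_lookup, stream_id)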
@@ -342,7 +347,28 @@ class ChatManager:
"group_name": group_info_d["group_name"] if group_info_d else "",
}
ChatStreams.replace(stream_id=s_data_dict["stream_id"], **fields_to_save).execute()
# 根据数据库类型选择插入语句
if global_config.database.database_type == "sqlite":
stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
stmt = stmt.on_conflict_do_update(
index_elements=['stream_id'],
set_=fields_to_save
)
elif global_config.database.database_type == "mysql":
stmt = mysql_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
stmt = stmt.on_duplicate_key_update(
**{key: value for key, value in fields_to_save.items() if key != "stream_id"}
)
else:
# 默认使用通用插入尝试SQLite语法
stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
stmt = stmt.on_conflict_do_update(
index_elements=['stream_id'],
set_=fields_to_save
)
session.execute(stmt)
session.commit()
try:
await asyncio.to_thread(_db_save_stream_sync, stream_data_dict)
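Peewee's replace() has no portable SQLAlchemy equivalent, hence the dialect dispatch above: SQLite spells the upsert ON CONFLICT DO UPDATE, MySQL spells it ON DUPLICATE KEY UPDATE. The same dispatch as a standalone helper, assuming ChatStreams keys on stream_id:

from sqlalchemy.dialects.mysql import insert as mysql_insert
from sqlalchemy.dialects.sqlite import insert as sqlite_insert

def upsert_stream(session, dialect: str, stream_id: str, fields: dict) -> None:
    if dialect == "mysql":
        stmt = mysql_insert(ChatStreams).values(stream_id=stream_id, **fields)
        stmt = stmt.on_duplicate_key_update(**fields)
    else:  # sqlite, also the fallback path above
        stmt = sqlite_insert(ChatStreams).values(stream_id=stream_id, **fields)
        stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=fields)
    session.execute(stmt)
    session.commit()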
@@ -361,7 +387,7 @@ class ChatManager:
def _db_load_all_streams_sync():
loaded_streams_data = []
for model_instance in ChatStreams.select():
for model_instance in session.execute(select(ChatStreams)).scalars():
user_info_data = {
"platform": model_instance.user_platform,
"user_id": model_instance.user_id,


@@ -1,16 +1,18 @@
import re
import json
import traceback
from typing import Union
from src.common.database.database_model import Messages, Images
from src.common.database.sqlalchemy_models import Messages, Images
from src.common.logger import get_logger
from .chat_stream import ChatStream
from .message import MessageSending, MessageRecv
from src.common.database.sqlalchemy_database_api import get_session
from sqlalchemy import select, update, desc
logger = get_logger("message_storage")
class MessageStorage:
@staticmethod
def _serialize_keywords(keywords) -> str:
@@ -33,15 +35,11 @@ class MessageStorage:
async def store_message(message: Union[MessageSending, MessageRecv], chat_stream: ChatStream) -> None:
"""存储消息到数据库"""
try:
# 莫越权 救世啊
# 过滤敏感信息的正则模式
pattern = r"<MainRule>.*?</MainRule>|<schedule>.*?</schedule>|<UserMessage>.*?</UserMessage>"
# print(message)
processed_plain_text = message.processed_plain_text
# print(processed_plain_text)
if processed_plain_text:
processed_plain_text = MessageStorage.replace_image_descriptions(processed_plain_text)
filtered_processed_plain_text = re.sub(pattern, "", processed_plain_text, flags=re.DOTALL)
@@ -93,11 +91,16 @@ class MessageStorage:
# 安全地获取 user_info, 如果为 None 则视为空字典 (以防万一)
user_info_from_chat = chat_info_dict.get("user_info") or {}
Messages.create(
# 将priority_info字典序列化为JSON字符串以便存储到数据库的Text字段
priority_info_json = json.dumps(priority_info) if priority_info else None
# 获取数据库会话
session = get_session()
new_message = Messages(
message_id=msg_id,
time=float(message.message_info.time), # type: ignore
time=float(message.message_info.time),
chat_id=chat_stream.stream_id,
# Flattened chat_info
reply_to=reply_to,
is_mentioned=is_mentioned,
chat_info_stream_id=chat_info_dict.get("stream_id"),
@@ -111,18 +114,16 @@ class MessageStorage:
chat_info_group_name=group_info_from_chat.get("group_name"),
chat_info_create_time=float(chat_info_dict.get("create_time", 0.0)),
chat_info_last_active_time=float(chat_info_dict.get("last_active_time", 0.0)),
# Flattened user_info (message sender)
user_platform=user_info_dict.get("platform"),
user_id=user_info_dict.get("user_id"),
user_nickname=user_info_dict.get("user_nickname"),
user_cardname=user_info_dict.get("user_cardname"),
# Text content
processed_plain_text=filtered_processed_plain_text,
display_message=filtered_display_message,
memorized_times=message.memorized_times,
interest_value=interest_value,
priority_mode=priority_mode,
priority_info=priority_info,
priority_info=priority_info_json,
is_emoji=is_emoji,
is_picid=is_picid,
is_notify=is_notify,
@@ -131,35 +132,44 @@ class MessageStorage:
key_words_lite=key_words_lite,
selected_expressions=selected_expressions,
)
session.add(new_message)
session.commit()
except Exception:
logger.exception("存储消息失败")
logger.error(f"消息:{message}")
traceback.print_exc()
# 如果需要其他存储相关的函数,可以在这里添加
@staticmethod
async def update_message(
message: MessageRecv,
) -> None: # 用于实时更新数据库的自身发送消息ID目前能处理text,reply,image和emoji
"""更新最新一条匹配消息的message_id"""
async def update_message(message):
"""更新消息ID"""
try:
if message.message_segment.type == "notify":
mmc_message_id = message.message_segment.data.get("echo") # type: ignore
qq_message_id = message.message_segment.data.get("actual_id") # type: ignore
mmc_message_id = message.message_info.message_id # 修复正确访问message_id
if message.message_segment.type == "text":
qq_message_id = message.message_segment.data.get("id")
elif message.message_segment.type == "reply":
qq_message_id = message.message_segment.data.get("id")
else:
logger.info(f"更新消息ID错误seg类型为{message.message_segment.type}")
return
if not qq_message_id:
logger.info("消息不存在message_id无法更新")
return
if matched_message := (
Messages.select().where((Messages.message_id == mmc_message_id)).order_by(Messages.time.desc()).first()
):
# 更新找到的消息记录
Messages.update(message_id=qq_message_id).where(Messages.id == matched_message.id).execute() # type: ignore
logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}")
else:
logger.debug("未找到匹配的消息")
# 使用上下文管理器确保session正确管理
from src.common.database.sqlalchemy_models import get_db_session
with get_db_session() as session:
matched_message = session.execute(
select(Messages).where(Messages.message_id == mmc_message_id).order_by(desc(Messages.time))
).scalar()
if matched_message:
session.execute(
update(Messages).where(Messages.id == matched_message.id).values(message_id=qq_message_id)
)
# session.commit() 会在上下文管理器中自动调用
logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}")
else:
logger.debug("未找到匹配的消息")
except Exception as e:
logger.error(f"更新消息ID失败: {e}")
@@ -178,10 +188,12 @@ class MessageStorage:
def replace_match(match):
description = match.group(1).strip()
try:
image_record = (
Images.select().where(Images.description == description).order_by(Images.timestamp.desc()).first()
)
return f"[picid:{image_record.image_id}]" if image_record else match.group(0)
from src.common.database.sqlalchemy_models import get_db_session
with get_db_session() as session:
image_record = session.execute(
select(Images).where(Images.description == description).order_by(desc(Images.timestamp))
).scalar()
return f"[picid:{image_record.image_id}]" if image_record else match.group(0)
except Exception:
return match.group(0)


@@ -7,13 +7,14 @@ from rich.traceback import install
from src.config.config import global_config
from src.common.message_repository import find_messages, count_messages
from src.common.database.database_model import ActionRecords
from src.common.database.database_model import Images
from src.person_info.person_info import Person,get_person_id
from src.common.database.sqlalchemy_models import ActionRecords, Images
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
from src.chat.utils.utils import translate_timestamp_to_human_readable, assign_message_ids
from src.common.database.sqlalchemy_database_api import get_session
from sqlalchemy import select, and_
install(extra_lines=3)
session = get_session()
def replace_user_references_sync(
content: str,
@@ -254,50 +255,90 @@ def get_actions_by_timestamp_with_chat(
limit_mode: str = "latest",
) -> List[Dict[str, Any]]:
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录,按时间升序排序,返回动作记录列表"""
query = ActionRecords.select().where(
(ActionRecords.chat_id == chat_id)
& (ActionRecords.time > timestamp_start) # type: ignore
& (ActionRecords.time < timestamp_end) # type: ignore
)
query = session.execute(select(ActionRecords).where(
and_(
ActionRecords.chat_id == chat_id,
ActionRecords.time > timestamp_start,
ActionRecords.time < timestamp_end
)
))
if limit > 0:
if limit_mode == "latest":
query = query.order_by(ActionRecords.time.desc()).limit(limit)
query = session.execute(select(ActionRecords).where(
and_(
ActionRecords.chat_id == chat_id,
ActionRecords.time > timestamp_start,
ActionRecords.time < timestamp_end
)
).order_by(ActionRecords.time.desc()).limit(limit))
# 获取后需要反转列表,以保持最终输出为时间升序
actions = list(query)
return [action.__data__ for action in reversed(actions)]
actions = list(query.scalars())
return [action.__dict__ for action in reversed(actions)]
else: # earliest
query = query.order_by(ActionRecords.time.asc()).limit(limit)
query = session.execute(select(ActionRecords).where(
and_(
ActionRecords.chat_id == chat_id,
ActionRecords.time > timestamp_start,
ActionRecords.time < timestamp_end
)
).order_by(ActionRecords.time.asc()).limit(limit))
else:
query = query.order_by(ActionRecords.time.asc())
query = session.execute(select(ActionRecords).where(
and_(
ActionRecords.chat_id == chat_id,
ActionRecords.time > timestamp_start,
ActionRecords.time < timestamp_end
)
).order_by(ActionRecords.time.asc()))
actions = list(query)
return [action.__data__ for action in actions]
actions = list(query.scalars())
return [action.__dict__ for action in actions]
def get_actions_by_timestamp_with_chat_inclusive(
chat_id: str, timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
) -> List[Dict[str, Any]]:
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录(包含边界),按时间升序排序,返回动作记录列表"""
query = ActionRecords.select().where(
(ActionRecords.chat_id == chat_id)
& (ActionRecords.time >= timestamp_start) # type: ignore
& (ActionRecords.time <= timestamp_end) # type: ignore
)
query = session.execute(select(ActionRecords).where(
and_(
ActionRecords.chat_id == chat_id,
ActionRecords.time >= timestamp_start,
ActionRecords.time <= timestamp_end
)
))
if limit > 0:
if limit_mode == "latest":
query = query.order_by(ActionRecords.time.desc()).limit(limit)
query = session.execute(select(ActionRecords).where(
and_(
ActionRecords.chat_id == chat_id,
ActionRecords.time >= timestamp_start,
ActionRecords.time <= timestamp_end
)
).order_by(ActionRecords.time.desc()).limit(limit))
# 获取后需要反转列表,以保持最终输出为时间升序
actions = list(query)
return [action.__data__ for action in reversed(actions)]
actions = list(query.scalars())
return [action.__dict__ for action in reversed(actions)]
else: # earliest
query = query.order_by(ActionRecords.time.asc()).limit(limit)
query = session.execute(select(ActionRecords).where(
and_(
ActionRecords.chat_id == chat_id,
ActionRecords.time >= timestamp_start,
ActionRecords.time <= timestamp_end
)
).order_by(ActionRecords.time.asc()).limit(limit))
else:
query = query.order_by(ActionRecords.time.asc())
query = session.execute(select(ActionRecords).where(
and_(
ActionRecords.chat_id == chat_id,
ActionRecords.time >= timestamp_start,
ActionRecords.time <= timestamp_end
)
).order_by(ActionRecords.time.asc()))
actions = list(query)
return [action.__data__ for action in actions]
actions = list(query.scalars())
return [action.__dict__ for action in actions]
def get_raw_msg_by_timestamp_random(
@@ -700,7 +741,7 @@ def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
# 从数据库中获取图片描述
description = "内容正在阅读,请稍等"
try:
image = Images.get_or_none(Images.image_id == pic_id)
image = session.execute(select(Images).where(Images.image_id == pic_id)).scalar()
if image and image.description:
description = image.description
except Exception:
@@ -813,7 +854,7 @@ def build_readable_messages(
timestamp_mode: str = "relative",
read_mark: float = 0.0,
truncate: bool = False,
show_actions: bool = False,
show_actions: bool = True,
show_pic: bool = True,
message_id_list: Optional[List[Dict[str, Any]]] = None,
) -> str: # sourcery skip: extract-method
@@ -846,21 +887,21 @@ def build_readable_messages(
chat_id = copy_messages[0].get("chat_id") if copy_messages else None
# 获取这个时间范围内的动作记录并匹配chat_id
actions_in_range = (
ActionRecords.select()
.where(
(ActionRecords.time >= min_time) & (ActionRecords.time <= max_time) & (ActionRecords.chat_id == chat_id)
actions_in_range = session.execute(select(ActionRecords).where(
and_(
ActionRecords.time >= min_time,
ActionRecords.time <= max_time,
ActionRecords.chat_id == chat_id
)
.order_by(ActionRecords.time)
)
).order_by(ActionRecords.time)).scalars()
# 获取最新消息之后的第一个动作记录
action_after_latest = (
ActionRecords.select()
.where((ActionRecords.time > max_time) & (ActionRecords.chat_id == chat_id))
.order_by(ActionRecords.time)
.limit(1)
)
action_after_latest = session.execute(select(ActionRecords).where(
and_(
ActionRecords.time > max_time,
ActionRecords.chat_id == chat_id
)
).order_by(ActionRecords.time).limit(1)).scalars()
# 合并两部分动作记录
actions = list(actions_in_range) + list(action_after_latest)


@@ -6,13 +6,52 @@ from datetime import datetime, timedelta
from typing import Any, Dict, Tuple, List
from src.common.logger import get_logger
from src.common.database.database import db
from src.common.database.database_model import OnlineTime, LLMUsage, Messages
from src.common.database.sqlalchemy_models import OnlineTime, LLMUsage, Messages
from src.common.database.sqlalchemy_database_api import get_db_session, db_query, db_save, db_get
from src.manager.async_task_manager import AsyncTask
from src.manager.local_store_manager import local_storage
logger = get_logger("maibot_statistic")
# 同步包装器函数用于在非异步环境中调用异步数据库API
def _sync_db_get(model_class, filters=None, order_by=None, limit=None, single_result=False):
"""同步版本的db_get用于在线程池中调用"""
import asyncio
try:
loop = asyncio.get_event_loop()
if loop.is_running():
# 如果事件循环正在运行,创建新的事件循环
import threading
result = None
exception = None
def run_in_thread():
nonlocal result, exception
try:
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
result = new_loop.run_until_complete(
db_get(model_class, filters, limit, order_by, single_result)
)
new_loop.close()
except Exception as e:
exception = e
thread = threading.Thread(target=run_in_thread)
thread.start()
thread.join()
if exception:
raise exception
return result
else:
return loop.run_until_complete(
db_get(model_class, filters, limit, order_by, single_result)
)
except RuntimeError:
# 没有事件循环,创建一个新的
return asyncio.run(db_get(model_class, filters, limit, order_by, single_result))
# 统计数据的键
TOTAL_REQ_CNT = "total_requests"
TOTAL_COST = "total_cost"
@@ -59,17 +98,9 @@ class OnlineTimeRecordTask(AsyncTask):
def __init__(self):
super().__init__(task_name="Online Time Record Task", run_interval=60)
self.record_id: int | None = None # Changed to int for Peewee's default ID
self.record_id: int | None = None
"""记录ID"""
self._init_database() # 初始化数据库
@staticmethod
def _init_database():
"""初始化数据库"""
with db.atomic(): # Use atomic operations for schema changes
OnlineTime.create_table(safe=True) # Creates table if it doesn't exist, Peewee handles indexes from model
async def run(self): # sourcery skip: use-named-expression
try:
current_time = datetime.now()
@@ -77,36 +108,50 @@ class OnlineTimeRecordTask(AsyncTask):
if self.record_id:
# 如果有记录,则更新结束时间
query = OnlineTime.update(end_timestamp=extended_end_time).where(OnlineTime.id == self.record_id) # type: ignore
updated_rows = query.execute()
updated_rows = await db_query(
model_class=OnlineTime,
query_type="update",
filters={"id": self.record_id},
data={"end_timestamp": extended_end_time}
)
if updated_rows == 0:
# Record might have been deleted or ID is stale, try to find/create
self.record_id = None # Reset record_id to trigger find/create logic below
self.record_id = None
if not self.record_id: # Check again if record_id was reset or initially None
# 如果没有记录,检查一分钟以内是否已有记录
# Look for a record whose end_timestamp is recent enough to be considered ongoing
recent_record = (
OnlineTime.select()
.where(OnlineTime.end_timestamp >= (current_time - timedelta(minutes=1))) # type: ignore
.order_by(OnlineTime.end_timestamp.desc())
.first()
if not self.record_id:
# 查找最近一分钟内的记录
recent_threshold = current_time - timedelta(minutes=1)
recent_records = await db_get(
model_class=OnlineTime,
filters={"end_timestamp": {"$gte": recent_threshold}},
order_by="-end_timestamp",
limit=1,
single_result=True
)
if recent_record:
# 如果有记录,更新结束时间
self.record_id = recent_record.id
recent_record.end_timestamp = extended_end_time
recent_record.save()
else:
# 若没有记录,则插入新的在线时间记录
new_record = OnlineTime.create(
timestamp=current_time.timestamp(), # 添加此行
start_timestamp=current_time,
end_timestamp=extended_end_time,
duration=5, # 初始时长为5分钟
if recent_records:
# 找到近期记录,更新
self.record_id = recent_records['id']
await db_query(
model_class=OnlineTime,
query_type="update",
filters={"id": self.record_id},
data={"end_timestamp": extended_end_time}
)
self.record_id = new_record.id
else:
# 创建新记录
new_record = await db_save(
model_class=OnlineTime,
data={
"timestamp": str(current_time),
"duration": 5, # 初始时长为5分钟
"start_timestamp": current_time,
"end_timestamp": extended_end_time,
}
)
if new_record:
self.record_id = new_record['id']
except Exception as e:
logger.error(f"在线时间记录失败,错误信息:{e}")
@@ -322,18 +367,23 @@ class StatisticOutputTask(AsyncTask):
}
# 以最早的时间戳为起始时间获取记录
# Assuming LLMUsage.timestamp is a DateTimeField
query_start_time = collect_period[-1][1]
for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time): # type: ignore
record_timestamp = record.timestamp # This is already a datetime object
records = _sync_db_get(
model_class=LLMUsage,
filters={"timestamp": {"$gte": query_start_time}},
order_by="-timestamp"
)
for record in records:
record_timestamp = record['timestamp'] # 从字典中获取
for idx, (_, period_start) in enumerate(collect_period):
if record_timestamp >= period_start:
for period_key, _ in collect_period[idx:]:
stats[period_key][TOTAL_REQ_CNT] += 1
request_type = record.request_type or "unknown"
user_id = record.user_id or "unknown" # user_id is TextField, already string
model_name = record.model_name or "unknown"
request_type = record.get('request_type') or "unknown"
user_id = record.get('user_id') or "unknown"
model_name = record.get('model_name') or "unknown"
# 提取模块名:如果请求类型包含".",取第一个"."之前的部分
module_name = request_type.split(".")[0] if "." in request_type else request_type
@@ -343,8 +393,8 @@ class StatisticOutputTask(AsyncTask):
stats[period_key][REQ_CNT_BY_MODEL][model_name] += 1
stats[period_key][REQ_CNT_BY_MODULE][module_name] += 1
prompt_tokens = record.prompt_tokens or 0
completion_tokens = record.completion_tokens or 0
prompt_tokens = record.get('prompt_tokens') or 0
completion_tokens = record.get('completion_tokens') or 0
total_tokens = prompt_tokens + completion_tokens
stats[period_key][IN_TOK_BY_TYPE][request_type] += prompt_tokens
@@ -362,7 +412,7 @@ class StatisticOutputTask(AsyncTask):
stats[period_key][TOTAL_TOK_BY_MODEL][model_name] += total_tokens
stats[period_key][TOTAL_TOK_BY_MODULE][module_name] += total_tokens
cost = record.cost or 0.0
cost = record.get('cost') or 0.0
stats[period_key][TOTAL_COST] += cost
stats[period_key][COST_BY_TYPE][request_type] += cost
stats[period_key][COST_BY_USER][user_id] += cost
@@ -425,11 +475,15 @@ class StatisticOutputTask(AsyncTask):
}
query_start_time = collect_period[-1][1]
# Assuming OnlineTime.end_timestamp is a DateTimeField
for record in OnlineTime.select().where(OnlineTime.end_timestamp >= query_start_time): # type: ignore
# record.end_timestamp and record.start_timestamp are datetime objects
record_end_timestamp = record.end_timestamp
record_start_timestamp = record.start_timestamp
records = _sync_db_get(
model_class=OnlineTime,
filters={"end_timestamp": {"$gte": query_start_time}},
order_by="-end_timestamp"
)
for record in records:
record_end_timestamp = record['end_timestamp']
record_start_timestamp = record['start_timestamp']
for idx, (_, period_boundary_start) in enumerate(collect_period):
if record_end_timestamp >= period_boundary_start:
@@ -466,24 +520,30 @@ class StatisticOutputTask(AsyncTask):
}
query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp)
for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore
message_time_ts = message.time # This is a float timestamp
records = _sync_db_get(
model_class=Messages,
filters={"time": {"$gte": query_start_timestamp}},
order_by="-time"
)
for message in records:
message_time_ts = message['time'] # This is a float timestamp
chat_id = None
chat_name = None
# Logic based on Peewee model structure, aiming to replicate original intent
if message.chat_info_group_id:
chat_id = f"g{message.chat_info_group_id}"
chat_name = message.chat_info_group_name or f"{message.chat_info_group_id}"
elif message.user_id: # Fallback to sender's info for chat_id if not a group_info based chat
# Logic based on SQLAlchemy model structure, aiming to replicate original intent
if message.get('chat_info_group_id'):
chat_id = f"g{message['chat_info_group_id']}"
chat_name = message.get('chat_info_group_name') or f"{message['chat_info_group_id']}"
elif message.get('user_id'): # Fallback to sender's info for chat_id if not a group_info based chat
# This uses the message SENDER's ID as per original logic's fallback
chat_id = f"u{message.user_id}" # SENDER's user_id
chat_name = message.user_nickname # SENDER's nickname
chat_id = f"u{message['user_id']}" # SENDER's user_id
chat_name = message.get('user_nickname') # SENDER's nickname
else:
# If neither group_id nor sender_id is available for chat identification
logger.warning(
f"Message (PK: {message.id if hasattr(message, 'id') else 'N/A'}) lacks group_id and user_id for chat stats."
f"Message (PK: {message.get('id', 'N/A')}) lacks group_id and user_id for chat stats."
)
continue
@@ -1025,8 +1085,14 @@ class StatisticOutputTask(AsyncTask):
# 查询LLM使用记录
query_start_time = start_time
for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time): # type: ignore
record_time = record.timestamp
records = _sync_db_get(
model_class=LLMUsage,
filters={"timestamp": {"$gte": query_start_time}},
order_by="-timestamp"
)
for record in records:
record_time = record['timestamp']
# 找到对应的时间间隔索引
time_diff = (record_time - start_time).total_seconds()
@@ -1034,17 +1100,17 @@ class StatisticOutputTask(AsyncTask):
if 0 <= interval_index < len(time_points):
# 累加总花费数据
cost = record.cost or 0.0
cost = record.get('cost') or 0.0
total_cost_data[interval_index] += cost # type: ignore
# 累加按模型分类的花费
model_name = record.model_name or "unknown"
model_name = record.get('model_name') or "unknown"
if model_name not in cost_by_model:
cost_by_model[model_name] = [0] * len(time_points)
cost_by_model[model_name][interval_index] += cost
# 累加按模块分类的花费
request_type = record.request_type or "unknown"
request_type = record.get('request_type') or "unknown"
module_name = request_type.split(".")[0] if "." in request_type else request_type
if module_name not in cost_by_module:
cost_by_module[module_name] = [0] * len(time_points)
@@ -1052,8 +1118,14 @@ class StatisticOutputTask(AsyncTask):
# 查询消息记录
query_start_timestamp = start_time.timestamp()
for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore
message_time_ts = message.time
records = _sync_db_get(
model_class=Messages,
filters={"time": {"$gte": query_start_timestamp}},
order_by="-time"
)
for message in records:
message_time_ts = message['time']
# 找到对应的时间间隔索引
time_diff = message_time_ts - query_start_timestamp
@@ -1062,10 +1134,10 @@ class StatisticOutputTask(AsyncTask):
if 0 <= interval_index < len(time_points):
# 确定聊天流名称
chat_name = None
if message.chat_info_group_id:
chat_name = message.chat_info_group_name or f"{message.chat_info_group_id}"
elif message.user_id:
chat_name = message.user_nickname or f"用户{message.user_id}"
if message.get('chat_info_group_id'):
chat_name = message.get('chat_info_group_name') or f"{message['chat_info_group_id']}"
elif message.get('user_id'):
chat_name = message.get('user_nickname') or f"用户{message['user_id']}"
else:
continue
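
上面两个 hunk 都改用了 _sync_db_get,该辅助函数本身未出现在此处的差异中。下面是一个最小实现草图(仅为示意,假设它同步包装了 sqlalchemy_database_api 中的过滤与排序逻辑并返回字典列表;实际定义以仓库代码为准):

from typing import Any, Dict, List, Optional, Type

def _sync_db_get(
    model_class: Type[Any],
    filters: Optional[Dict[str, Any]] = None,
    order_by: Optional[str] = None,
    limit: Optional[int] = None,
) -> List[Dict[str, Any]]:
    """同步查询:按 MongoDB 风格的 filters 过滤,返回字典列表(示意实现)"""
    from sqlalchemy import and_, asc, desc
    from src.common.database.sqlalchemy_database_api import build_filters
    from src.common.database.sqlalchemy_models import get_db_session

    with get_db_session() as session:
        query = session.query(model_class)
        if filters:
            conditions = build_filters(session, model_class, filters)
            if conditions:
                query = query.filter(and_(*conditions))
        if order_by:
            field = getattr(model_class, order_by.lstrip("-"))
            query = query.order_by(desc(field) if order_by.startswith("-") else asc(field))
        if limit and limit > 0:
            query = query.limit(limit)
        # 与 db_query 一致:把 ORM 实例展开成 {列名: 值} 字典
        return [
            {col.name: getattr(row, col.name) for col in row.__table__.columns}
            for row in query.all()
        ]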

View File

@@ -13,10 +13,12 @@ from rich.traceback import install
from src.common.logger import get_logger
from src.common.database.database import db
from src.common.database.database_model import Images, ImageDescriptions
from src.common.database.sqlalchemy_models import Images, ImageDescriptions
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.common.database.sqlalchemy_models import get_db_session
from sqlalchemy import select, and_
install(extra_lines=3)
logger = get_logger("chat_image")
@@ -41,9 +43,10 @@ class ImageManager:
try:
db.connect(reuse_if_open=True)
db.create_tables([Images, ImageDescriptions], safe=True)
            # 使用SQLAlchemy创建表(已在初始化时完成)
logger.debug("使用SQLAlchemy进行表管理")
except Exception as e:
logger.error(f"数据库连接或表创建失败: {e}")
logger.error(f"数据库连接失败: {e}")
self._initialized = True
@@ -63,12 +66,13 @@ class ImageManager:
Optional[str]: 描述文本如果不存在则返回None
"""
try:
record = ImageDescriptions.get_or_none(
(ImageDescriptions.image_description_hash == image_hash) & (ImageDescriptions.type == description_type)
)
return record.description if record else None
with get_db_session() as session:
record = session.execute(select(ImageDescriptions).where(
and_(ImageDescriptions.image_description_hash == image_hash, ImageDescriptions.type == description_type)
)).scalar()
return record.description if record else None
except Exception as e:
logger.error(f"从数据库获取描述失败 (Peewee): {str(e)}")
logger.error(f"从数据库获取描述失败 (SQLAlchemy): {str(e)}")
return None
@staticmethod
@@ -82,16 +86,28 @@ class ImageManager:
"""
try:
current_timestamp = time.time()
defaults = {"description": description, "timestamp": current_timestamp}
desc_obj, created = ImageDescriptions.get_or_create(
image_description_hash=image_hash, type=description_type, defaults=defaults
)
if not created: # 如果记录已存在,则更新
desc_obj.description = description
desc_obj.timestamp = current_timestamp
desc_obj.save()
with get_db_session() as session:
# 查找现有记录
existing = session.execute(select(ImageDescriptions).where(
and_(ImageDescriptions.image_description_hash == image_hash, ImageDescriptions.type == description_type)
)).scalar()
if existing:
# 更新现有记录
existing.description = description
existing.timestamp = current_timestamp
else:
# 创建新记录
new_desc = ImageDescriptions(
image_description_hash=image_hash,
type=description_type,
description=description,
timestamp=current_timestamp
)
session.add(new_desc)
# session.commit() 会在上下文管理器中自动调用
except Exception as e:
logger.error(f"保存描述到数据库失败 (Peewee): {str(e)}")
logger.error(f"保存描述到数据库失败 (SQLAlchemy): {str(e)}")
async def get_emoji_tag(self, image_base64: str) -> str:
from src.chat.emoji_system.emoji_manager import get_emoji_manager
@@ -214,19 +230,29 @@ class ImageManager:
# 保存到数据库 (Images表) - 包含详细描述用于可能的注册流程
try:
img_obj = Images.get((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
img_obj.path = file_path
img_obj.description = detailed_description # 保存详细描述
img_obj.timestamp = current_timestamp
img_obj.save()
except Images.DoesNotExist: # type: ignore
Images.create(
emoji_hash=image_hash,
path=file_path,
type="emoji",
description=detailed_description, # 保存详细描述
timestamp=current_timestamp,
)
from src.common.database.sqlalchemy_models import get_db_session
with get_db_session() as session:
existing_img = session.execute(select(Images).where(
and_(Images.emoji_hash == image_hash, Images.type == "emoji")
)).scalar()
if existing_img:
existing_img.path = file_path
existing_img.description = detailed_description # 保存详细描述
existing_img.timestamp = current_timestamp
else:
new_img = Images(
emoji_hash=image_hash,
path=file_path,
type="emoji",
description=detailed_description, # 保存详细描述
timestamp=current_timestamp,
)
session.add(new_img)
# session.commit() 会在上下文管理器中自动调用
except Exception as e:
logger.error(f"保存到Images表失败: {str(e)}")
except Exception as e:
logger.error(f"保存表情包文件或元数据失败: {str(e)}")
@@ -249,19 +275,19 @@ class ImageManager:
image_hash = hashlib.md5(image_bytes).hexdigest()
# 优先检查Images表中是否已有完整的描述
existing_image = Images.get_or_none(Images.emoji_hash == image_hash)
if existing_image:
# 更新计数
if hasattr(existing_image, "count") and existing_image.count is not None:
existing_image.count += 1
else:
existing_image.count = 1
existing_image.save()
with get_db_session() as session:
existing_image = session.execute(select(Images).where(Images.emoji_hash == image_hash)).scalar()
if existing_image:
# 更新计数
if hasattr(existing_image, "count") and existing_image.count is not None:
existing_image.count += 1
else:
existing_image.count = 1
# 如果已有描述,直接返回
if existing_image.description:
logger.debug(f"[缓存命中] 使用Images表中的图片描述: {existing_image.description[:50]}...")
return f"[图片:{existing_image.description}]"
# 如果已有描述,直接返回
if existing_image.description:
logger.debug(f"[缓存命中] 使用Images表中的图片描述: {existing_image.description[:50]}...")
return f"[图片:{existing_image.description}]"
if cached_description := self._get_description_from_db(image_hash, "image"):
logger.debug(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description[:50]}...")
@@ -300,10 +326,10 @@ class ImageManager:
existing_image.image_id = str(uuid.uuid4())
if not hasattr(existing_image, "vlm_processed") or existing_image.vlm_processed is None:
existing_image.vlm_processed = True
existing_image.save()
session.commit()
logger.debug(f"[数据库] 更新已有图片记录: {image_hash[:8]}...")
else:
Images.create(
new_img = Images(
image_id=str(uuid.uuid4()),
emoji_hash=image_hash,
path=file_path,
@@ -313,6 +339,8 @@ class ImageManager:
vlm_processed=True,
count=1,
)
session.add(new_img)
session.commit()
logger.debug(f"[数据库] 创建新图片记录: {image_hash[:8]}...")
except Exception as e:
logger.error(f"保存图片文件或元数据失败: {str(e)}")
@@ -465,31 +493,32 @@ class ImageManager:
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
with get_db_session() as session:
existing_image = session.execute(select(Images).where(Images.emoji_hash == image_hash)).scalar()
if existing_image:
# 检查是否缺少必要字段,如果缺少则创建新记录
if (
not hasattr(existing_image, "image_id")
or not existing_image.image_id
or not hasattr(existing_image, "count")
or existing_image.count is None
or not hasattr(existing_image, "vlm_processed")
or existing_image.vlm_processed is None
):
logger.debug(f"图片记录缺少必要字段,补全旧记录: {image_hash}")
if not existing_image.image_id:
existing_image.image_id = str(uuid.uuid4())
if existing_image.count is None:
existing_image.count = 0
if existing_image.vlm_processed is None:
existing_image.vlm_processed = False
if existing_image := Images.get_or_none(Images.emoji_hash == image_hash):
# 检查是否缺少必要字段,如果缺少则创建新记录
if (
not hasattr(existing_image, "image_id")
or not existing_image.image_id
or not hasattr(existing_image, "count")
or existing_image.count is None
or not hasattr(existing_image, "vlm_processed")
or existing_image.vlm_processed is None
):
logger.debug(f"图片记录缺少必要字段,补全旧记录: {image_hash}")
if not existing_image.image_id:
existing_image.image_id = str(uuid.uuid4())
if existing_image.count is None:
existing_image.count = 0
if existing_image.vlm_processed is None:
existing_image.vlm_processed = False
existing_image.count += 1
session.commit()
return existing_image.image_id, f"[picid:{existing_image.image_id}]"
existing_image.count += 1
existing_image.save()
return existing_image.image_id, f"[picid:{existing_image.image_id}]"
else:
# print(f"图片不存在: {image_hash}")
image_id = str(uuid.uuid4())
# print(f"图片不存在: {image_hash}")
image_id = str(uuid.uuid4())
# 保存新图片
current_timestamp = time.time()
@@ -503,7 +532,7 @@ class ImageManager:
f.write(image_bytes)
# 保存到数据库
Images.create(
new_img = Images(
image_id=image_id,
emoji_hash=image_hash,
path=file_path,
@@ -512,6 +541,8 @@ class ImageManager:
vlm_processed=False,
count=1,
)
session.add(new_img)
session.commit()
# 启动异步VLM处理
asyncio.create_task(self._process_image_with_vlm(image_id, image_base64))
@@ -536,60 +567,64 @@ class ImageManager:
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
with get_db_session() as session:
# 获取当前图片记录
image = session.execute(select(Images).where(Images.image_id == image_id)).scalar()
# 获取当前图片记录
image = Images.get(Images.image_id == image_id)
# 优先检查是否已有其他相同哈希的图片记录包含描述
existing_with_description = session.execute(select(Images).where(
and_(
Images.emoji_hash == image_hash,
Images.description.isnot(None),
Images.description != "",
Images.id != image.id
)
)).scalar()
if existing_with_description:
logger.debug(f"[缓存复用] 从其他相同图片记录复用描述: {existing_with_description.description[:50]}...")
image.description = existing_with_description.description
image.vlm_processed = True
session.commit()
# 同时保存到ImageDescriptions表作为备用缓存
self._save_description_to_db(image_hash, existing_with_description.description, "image")
return
# 优先检查是否已有其他相同哈希的图片记录包含描述
existing_with_description = Images.get_or_none(
(Images.emoji_hash == image_hash) & (Images.description.is_null(False)) & (Images.description != "")
)
if existing_with_description and existing_with_description.id != image.id:
logger.debug(f"[缓存复用] 从其他相同图片记录复用描述: {existing_with_description.description[:50]}...")
image.description = existing_with_description.description
# 检查ImageDescriptions表的缓存描述
if cached_description := self._get_description_from_db(image_hash, "image"):
logger.debug(f"[缓存复用] 从ImageDescriptions表复用描述: {cached_description[:50]}...")
image.description = cached_description
image.vlm_processed = True
session.commit()
return
# 获取图片格式
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
# 构建prompt
prompt = global_config.custom_prompt.image_prompt
# 获取VLM描述
logger.info(f"[VLM异步调用] 为图片生成描述 (ID: {image_id}, Hash: {image_hash[:8]}...)")
description, _ = await self.vlm.generate_response_for_image(
prompt, image_base64, image_format, temperature=0.4, max_tokens=300
)
if description is None:
logger.warning("VLM未能生成图片描述")
description = "无法生成描述"
if cached_description := self._get_description_from_db(image_hash, "image"):
logger.warning(f"虽然生成了描述,但是找到缓存图片描述: {cached_description}")
description = cached_description
# 更新数据库
image.description = description
image.vlm_processed = True
image.save()
# 同时保存到ImageDescriptions表作为备用缓存
self._save_description_to_db(image_hash, existing_with_description.description, "image")
return
# 检查ImageDescriptions表的缓存描述
if cached_description := self._get_description_from_db(image_hash, "image"):
logger.debug(f"[缓存复用] 从ImageDescriptions表复用描述: {cached_description[:50]}...")
image.description = cached_description
image.vlm_processed = True
image.save()
return
# 保存描述到ImageDescriptions表作为备用缓存
self._save_description_to_db(image_hash, description, "image")
# 获取图片格式
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
# 构建prompt
prompt = global_config.custom_prompt.image_prompt
# 获取VLM描述
logger.info(f"[VLM异步调用] 为图片生成描述 (ID: {image_id}, Hash: {image_hash[:8]}...)")
description, _ = await self.vlm.generate_response_for_image(
prompt, image_base64, image_format, temperature=0.4, max_tokens=300
)
if description is None:
logger.warning("VLM未能生成图片描述")
description = "无法生成描述"
if cached_description := self._get_description_from_db(image_hash, "image"):
logger.warning(f"虽然生成了描述,但是找到缓存图片描述: {cached_description}")
description = cached_description
# 更新数据库
image.description = description
image.vlm_processed = True
image.save()
# 保存描述到ImageDescriptions表作为备用缓存
self._save_description_to_db(image_hash, description, "image")
logger.info(f"[VLM异步完成] 图片描述生成: {description[:50]}...")
logger.info(f"[VLM异步完成] 图片描述生成: {description[:50]}...")
except Exception as e:
logger.error(f"VLM处理图片失败: {str(e)}")

View File

@@ -1,14 +1,103 @@
import os
from pymongo import MongoClient
from peewee import SqliteDatabase
from pymongo.database import Database
from rich.traceback import install
from src.common.logger import get_logger
# SQLAlchemy相关导入
from src.common.database.sqlalchemy_init import initialize_database_compat
from src.common.database.sqlalchemy_models import get_engine, get_session
install(extra_lines=3)
_client = None
_db = None
_sql_engine = None
logger = get_logger("database")
# 兼容性:为了不破坏现有代码,保留 db 变量,但指向 SQLAlchemy
class DatabaseProxy:
"""数据库代理类提供Peewee到SQLAlchemy的兼容性接口"""
def __init__(self):
self._engine = None
self._session = None
def initialize(self, *args, **kwargs):
"""初始化数据库连接"""
return initialize_database_compat()
def connect(self, reuse_if_open=True):
"""连接数据库(兼容性方法)"""
try:
self._engine = get_engine()
return True
except Exception as e:
logger.error(f"数据库连接失败: {e}")
return False
def is_closed(self):
"""检查数据库是否关闭(兼容性方法)"""
return self._engine is None
def create_tables(self, models, safe=True):
"""创建表(兼容性方法)"""
try:
from src.common.database.sqlalchemy_models import Base
engine = get_engine()
Base.metadata.create_all(bind=engine)
return True
except Exception as e:
logger.error(f"创建表失败: {e}")
return False
def table_exists(self, model):
"""检查表是否存在(兼容性方法)"""
try:
from sqlalchemy import inspect
engine = get_engine()
inspector = inspect(engine)
            table_name = getattr(model, '__tablename__', None) or model.__name__.lower()
return table_name in inspector.get_table_names()
except Exception:
return False
def execute_sql(self, sql):
"""执行SQL兼容性方法"""
try:
from sqlalchemy import text
session = get_session()
result = session.execute(text(sql))
session.close()
return result
except Exception as e:
logger.error(f"执行SQL失败: {e}")
raise
def atomic(self):
"""事务上下文管理器(兼容性方法)"""
return SQLAlchemyTransaction()
class SQLAlchemyTransaction:
"""SQLAlchemy事务上下文管理器"""
def __init__(self):
self.session = None
def __enter__(self):
self.session = get_session()
return self.session
def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type is None:
self.session.commit()
else:
self.session.rollback()
self.session.close()
# 创建全局数据库代理实例
db = DatabaseProxy()
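
一个最小的使用示意(假设调用方仍按原有 Peewee 风格编写,代理在背后转发到 SQLAlchemy):

def _demo_compat_usage():
    from sqlalchemy import text
    if db.connect(reuse_if_open=True):
        db.create_tables([], safe=True)  # 实际调用 Base.metadata.create_all,参数仅为签名兼容
        with db.atomic() as session:     # SQLAlchemy 事务,正常退出时自动 commit
            session.execute(text("SELECT 1"))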
def __create_database_instance():
uri = os.getenv("MONGODB_URI")
@@ -39,7 +128,7 @@ def __create_database_instance():
def get_db():
"""获取数据库连接实例,延迟初始化。"""
"""获取MongoDB连接实例,延迟初始化。"""
global _client, _db
if _client is None:
_client = __create_database_instance()
@@ -47,6 +136,47 @@ def get_db():
return _db
def initialize_sql_database(database_config):
"""
根据配置初始化SQL数据库连接SQLAlchemy版本
Args:
database_config: DatabaseConfig对象
"""
global _sql_engine
try:
logger.info("使用SQLAlchemy初始化SQL数据库...")
# 记录数据库配置信息
if database_config.database_type == "mysql":
connection_info = f"{database_config.mysql_user}@{database_config.mysql_host}:{database_config.mysql_port}/{database_config.mysql_database}"
logger.info("MySQL数据库连接配置:")
logger.info(f" 连接信息: {connection_info}")
logger.info(f" 字符集: {database_config.mysql_charset}")
else:
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
if not os.path.isabs(database_config.sqlite_path):
db_path = os.path.join(ROOT_PATH, database_config.sqlite_path)
else:
db_path = database_config.sqlite_path
logger.info("SQLite数据库连接配置:")
logger.info(f" 数据库文件: {db_path}")
# 使用SQLAlchemy初始化
success = initialize_database_compat()
if success:
_sql_engine = get_engine()
logger.info("SQLAlchemy数据库初始化成功")
else:
logger.error("SQLAlchemy数据库初始化失败")
return _sql_engine
except Exception as e:
logger.error(f"初始化SQL数据库失败: {e}")
return None
class DBWrapper:
"""数据库代理类,保持接口兼容性同时实现懒加载。"""
@@ -57,26 +187,6 @@ class DBWrapper:
return get_db()[key] # type: ignore
# 全局数据库访问点
# 全局MongoDB数据库访问点
memory_db: Database = DBWrapper() # type: ignore
# 定义数据库文件路径
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
_DB_DIR = os.path.join(ROOT_PATH, "data")
_DB_FILE = os.path.join(_DB_DIR, "MaiBot.db")
# 确保数据库目录存在
os.makedirs(_DB_DIR, exist_ok=True)
# 全局 Peewee SQLite 数据库访问点
db = SqliteDatabase(
_DB_FILE,
pragmas={
"journal_mode": "wal", # WAL模式提高并发性能
"cache_size": -64 * 1000, # 64MB缓存
"foreign_keys": 1,
"ignore_check_constraints": 0,
"synchronous": 0, # 异步写入提高性能
"busy_timeout": 1000, # 1秒超时而不是3秒
},
)

View File

@@ -0,0 +1,420 @@
"""SQLAlchemy数据库API模块
提供基于SQLAlchemy的数据库操作,替换Peewee以解决MySQL连接问题
支持自动重连、连接池管理和更好的错误处理
"""
import traceback
import time
from typing import Dict, List, Any, Union, Type, Optional
from contextlib import contextmanager
from sqlalchemy.orm import Session
from sqlalchemy.exc import SQLAlchemyError, DisconnectionError, OperationalError
from sqlalchemy import desc, asc, func, and_, or_
from src.common.logger import get_logger
from src.common.database.sqlalchemy_models import (
    Base, Messages, ActionRecords, PersonInfo, ChatStreams,
    LLMUsage, Emoji, Images, ImageDescriptions, OnlineTime, Memory,
    Expression, ThinkingLog, GraphNodes, GraphEdges, get_session
)
logger = get_logger("sqlalchemy_database_api")
# 模型映射表,用于通过名称获取模型类
MODEL_MAPPING = {
'Messages': Messages,
'ActionRecords': ActionRecords,
'PersonInfo': PersonInfo,
'ChatStreams': ChatStreams,
'LLMUsage': LLMUsage,
'Emoji': Emoji,
'Images': Images,
'ImageDescriptions': ImageDescriptions,
'OnlineTime': OnlineTime,
'Memory': Memory,
'Expression': Expression,
'ThinkingLog': ThinkingLog,
'GraphNodes': GraphNodes,
'GraphEdges': GraphEdges,
}
@contextmanager
def get_db_session():
    """数据库会话上下文管理器,自动处理事务和连接错误(提交失败时自动重试)"""
    session = None
    max_retries = 3
    retry_delay = 1.0
    try:
        session = get_session()
        yield session
        # 生成器只能 yield 一次,因此重试只包住 commit,避免重复 yield 触发 RuntimeError
        for attempt in range(max_retries):
            try:
                session.commit()
                break
            except (DisconnectionError, OperationalError) as e:
                logger.warning(f"数据库连接错误 (尝试 {attempt + 1}/{max_retries}): {e}")
                session.rollback()
                if attempt < max_retries - 1:
                    time.sleep(retry_delay * (attempt + 1))
                else:
                    raise
    except Exception:
        if session:
            session.rollback()
        raise
    finally:
        if session:
            session.close()
def build_filters(session: Session, model_class: Type[Base], filters: Dict[str, Any]):
"""构建查询过滤条件"""
conditions = []
for field_name, value in filters.items():
if not hasattr(model_class, field_name):
logger.warning(f"模型 {model_class.__name__} 中不存在字段 '{field_name}'")
continue
field = getattr(model_class, field_name)
if isinstance(value, dict):
# 处理 MongoDB 风格的操作符
for op, op_value in value.items():
if op == "$gt":
conditions.append(field > op_value)
elif op == "$lt":
conditions.append(field < op_value)
elif op == "$gte":
conditions.append(field >= op_value)
elif op == "$lte":
conditions.append(field <= op_value)
elif op == "$ne":
conditions.append(field != op_value)
elif op == "$in":
conditions.append(field.in_(op_value))
elif op == "$nin":
conditions.append(~field.in_(op_value))
else:
logger.warning(f"未知操作符 '{op}' (字段: '{field_name}')")
else:
# 直接相等比较
conditions.append(field == value)
return conditions
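
例如,下面这个 MongoDB 风格的过滤字典(以 Messages 模型为例,取值纯属示意)会被逐项翻译成 SQLAlchemy 条件:

filters = {
    "chat_id": "g123456",            # 等值比较 -> Messages.chat_id == "g123456"
    "time": {"$gte": 1700000000.0},  # 范围操作符 -> Messages.time >= 1700000000.0
    "is_command": {"$ne": True},     # 不等 -> Messages.is_command != True
}
# conditions = build_filters(session, Messages, filters)
# query = session.query(Messages).filter(and_(*conditions))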
async def db_query(
model_class: Type[Base],
data: Optional[Dict[str, Any]] = None,
query_type: Optional[str] = "get",
filters: Optional[Dict[str, Any]] = None,
limit: Optional[int] = None,
order_by: Optional[List[str]] = None,
single_result: Optional[bool] = False,
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
"""执行数据库查询操作
Args:
model_class: SQLAlchemy模型类
data: 用于创建或更新的数据字典
query_type: 查询类型 ("get", "create", "update", "delete", "count")
filters: 过滤条件字典
limit: 限制结果数量
order_by: 排序字段,前缀'-'表示降序
single_result: 是否只返回单个结果
Returns:
根据查询类型返回相应结果
"""
try:
if query_type not in ["get", "create", "update", "delete", "count"]:
raise ValueError("query_type must be 'get', 'create', 'update', 'delete' or 'count'")
with get_db_session() as session:
if query_type == "get":
query = session.query(model_class)
# 应用过滤条件
if filters:
conditions = build_filters(session, model_class, filters)
if conditions:
query = query.filter(and_(*conditions))
# 应用排序
if order_by:
for field_name in order_by:
if field_name.startswith("-"):
field_name = field_name[1:]
if hasattr(model_class, field_name):
query = query.order_by(desc(getattr(model_class, field_name)))
else:
if hasattr(model_class, field_name):
query = query.order_by(asc(getattr(model_class, field_name)))
# 应用限制
if limit and limit > 0:
query = query.limit(limit)
# 执行查询
results = query.all()
# 转换为字典格式
result_dicts = []
for result in results:
result_dict = {}
for column in result.__table__.columns:
result_dict[column.name] = getattr(result, column.name)
result_dicts.append(result_dict)
if single_result:
return result_dicts[0] if result_dicts else None
return result_dicts
elif query_type == "create":
if not data:
raise ValueError("创建记录需要提供data参数")
# 创建新记录
new_record = model_class(**data)
session.add(new_record)
session.flush() # 获取自动生成的ID
# 转换为字典格式返回
result_dict = {}
for column in new_record.__table__.columns:
result_dict[column.name] = getattr(new_record, column.name)
return result_dict
elif query_type == "update":
if not data:
raise ValueError("更新记录需要提供data参数")
query = session.query(model_class)
# 应用过滤条件
if filters:
conditions = build_filters(session, model_class, filters)
if conditions:
query = query.filter(and_(*conditions))
# 执行更新
affected_rows = query.update(data)
return affected_rows
elif query_type == "delete":
query = session.query(model_class)
# 应用过滤条件
if filters:
conditions = build_filters(session, model_class, filters)
if conditions:
query = query.filter(and_(*conditions))
# 执行删除
affected_rows = query.delete()
return affected_rows
elif query_type == "count":
query = session.query(func.count(model_class.id))
# 应用过滤条件
if filters:
base_query = session.query(model_class)
conditions = build_filters(session, model_class, filters)
if conditions:
base_query = base_query.filter(and_(*conditions))
query = session.query(func.count()).select_from(base_query.subquery())
return query.scalar()
except SQLAlchemyError as e:
logger.error(f"[SQLAlchemy] 数据库操作出错: {e}")
traceback.print_exc()
# 根据查询类型返回合适的默认值
if query_type == "get":
return None if single_result else []
elif query_type in ["create", "update", "delete", "count"]:
return None
return None
except Exception as e:
logger.error(f"[SQLAlchemy] 意外错误: {e}")
traceback.print_exc()
if query_type == "get":
return None if single_result else []
return None
async def db_save(
model_class: Type[Base],
data: Dict[str, Any],
key_field: Optional[str] = None,
key_value: Optional[Any] = None
) -> Optional[Dict[str, Any]]:
"""保存数据到数据库(创建或更新)
Args:
model_class: SQLAlchemy模型类
data: 要保存的数据字典
key_field: 用于查找现有记录的字段名
key_value: 用于查找现有记录的字段值
Returns:
保存后的记录数据或None
"""
try:
with get_db_session() as session:
# 如果提供了key_field和key_value尝试更新现有记录
if key_field and key_value is not None:
if hasattr(model_class, key_field):
existing_record = session.query(model_class).filter(
getattr(model_class, key_field) == key_value
).first()
if existing_record:
# 更新现有记录
for field, value in data.items():
if hasattr(existing_record, field):
setattr(existing_record, field, value)
session.flush()
# 转换为字典格式返回
result_dict = {}
for column in existing_record.__table__.columns:
result_dict[column.name] = getattr(existing_record, column.name)
return result_dict
# 创建新记录
new_record = model_class(**data)
session.add(new_record)
session.flush()
# 转换为字典格式返回
result_dict = {}
for column in new_record.__table__.columns:
result_dict[column.name] = getattr(new_record, column.name)
return result_dict
except SQLAlchemyError as e:
logger.error(f"[SQLAlchemy] 保存数据库记录出错: {e}")
traceback.print_exc()
return None
except Exception as e:
logger.error(f"[SQLAlchemy] 保存时意外错误: {e}")
traceback.print_exc()
return None
async def db_get(
model_class: Type[Base],
filters: Optional[Dict[str, Any]] = None,
limit: Optional[int] = None,
order_by: Optional[str] = None,
single_result: Optional[bool] = False,
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
"""从数据库获取记录
Args:
model_class: SQLAlchemy模型类
filters: 过滤条件
limit: 结果数量限制
order_by: 排序字段,前缀'-'表示降序
single_result: 是否只返回单个结果
Returns:
记录数据或None
"""
order_by_list = [order_by] if order_by else None
return await db_query(
model_class=model_class,
query_type="get",
filters=filters,
limit=limit,
order_by=order_by_list,
single_result=single_result
)
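
调用示意(假设处于异步上下文中,查询最近一小时的消息):

import time as _time

async def _demo_query():
    one_hour_ago = _time.time() - 3600
    # 按时间倒序取最近 100 条消息,返回字典列表
    messages = await db_get(
        Messages,
        filters={"time": {"$gte": one_hour_ago}},
        order_by="-time",
        limit=100,
    )
    # single_result=True 时直接返回单个字典或 None
    latest = await db_get(Messages, filters={"chat_id": "g123"}, order_by="-time", single_result=True)
    return messages, latest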
async def store_action_info(
chat_stream=None,
action_build_into_prompt: bool = False,
action_prompt_display: str = "",
action_done: bool = True,
thinking_id: str = "",
action_data: Optional[dict] = None,
action_name: str = "",
) -> Optional[Dict[str, Any]]:
"""存储动作信息到数据库
Args:
chat_stream: 聊天流对象
action_build_into_prompt: 是否将此动作构建到提示中
action_prompt_display: 动作的提示显示文本
action_done: 动作是否完成
thinking_id: 关联的思考ID
action_data: 动作数据字典
action_name: 动作名称
Returns:
保存的记录数据或None
"""
try:
import json
# 构建动作记录数据
record_data = {
"action_id": thinking_id or str(int(time.time() * 1000000)),
"time": time.time(),
"action_name": action_name,
"action_data": json.dumps(action_data or {}, ensure_ascii=False),
"action_done": action_done,
"action_build_into_prompt": action_build_into_prompt,
"action_prompt_display": action_prompt_display,
}
# 从chat_stream获取聊天信息
if chat_stream:
record_data.update({
"chat_id": getattr(chat_stream, "stream_id", ""),
"chat_info_stream_id": getattr(chat_stream, "stream_id", ""),
"chat_info_platform": getattr(chat_stream, "platform", ""),
})
else:
record_data.update({
"chat_id": "",
"chat_info_stream_id": "",
"chat_info_platform": "",
})
# 保存记录
saved_record = await db_save(
ActionRecords,
data=record_data,
key_field="action_id",
key_value=record_data["action_id"]
)
if saved_record:
logger.debug(f"[SQLAlchemy] 成功存储动作信息: {action_name} (ID: {record_data['action_id']})")
else:
logger.error(f"[SQLAlchemy] 存储动作信息失败: {action_name}")
return saved_record
except Exception as e:
logger.error(f"[SQLAlchemy] 存储动作信息时发生错误: {e}")
traceback.print_exc()
return None
# 兼容性函数方便从Peewee迁移
def get_model_class(model_name: str) -> Optional[Type[Base]]:
"""根据模型名称获取模型类"""
return MODEL_MAPPING.get(model_name)
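
迁移期的动态调用示意(假设调用方只持有模型名字符串):

async def _demo_dynamic():
    model_cls = get_model_class("Messages")  # 未知名称返回 None
    if model_cls is None:
        return []
    return await db_get(model_cls, filters={"is_command": False}, limit=10)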

View File

@@ -0,0 +1,158 @@
"""SQLAlchemy数据库初始化模块
替换Peewee的数据库初始化逻辑
提供统一的数据库初始化接口
"""
from typing import Optional
from sqlalchemy.exc import SQLAlchemyError
from src.common.logger import get_logger
from src.common.database.sqlalchemy_models import (
Base, get_engine, get_session, initialize_database
)
logger = get_logger("sqlalchemy_init")
def initialize_sqlalchemy_database() -> bool:
"""
初始化SQLAlchemy数据库
创建所有表结构
Returns:
bool: 初始化是否成功
"""
try:
logger.info("开始初始化SQLAlchemy数据库...")
# 初始化数据库引擎和会话
engine, session_local = initialize_database()
if engine is None:
logger.error("数据库引擎初始化失败")
return False
logger.info("SQLAlchemy数据库初始化成功")
return True
except SQLAlchemyError as e:
logger.error(f"SQLAlchemy数据库初始化失败: {e}")
return False
except Exception as e:
logger.error(f"数据库初始化过程中发生未知错误: {e}")
return False
def create_all_tables() -> bool:
"""
创建所有数据库表
Returns:
bool: 创建是否成功
"""
try:
logger.info("开始创建数据库表...")
engine = get_engine()
if engine is None:
logger.error("无法获取数据库引擎")
return False
# 创建所有表
Base.metadata.create_all(bind=engine)
logger.info("数据库表创建成功")
return True
except SQLAlchemyError as e:
logger.error(f"创建数据库表失败: {e}")
return False
except Exception as e:
logger.error(f"创建数据库表过程中发生未知错误: {e}")
return False
def check_database_connection() -> bool:
"""
检查数据库连接是否正常
Returns:
bool: 连接是否正常
"""
try:
session = get_session()
if session is None:
logger.error("无法获取数据库会话")
return False
        # 执行一次轻量查询,确认连接真正可用
        from sqlalchemy import text
        session.execute(text("SELECT 1"))
session.close()
logger.info("数据库连接检查通过")
return True
except SQLAlchemyError as e:
logger.error(f"数据库连接检查失败: {e}")
return False
except Exception as e:
logger.error(f"数据库连接检查过程中发生未知错误: {e}")
return False
def get_database_info() -> Optional[dict]:
"""
获取数据库信息
Returns:
dict: 数据库信息字典,包含引擎信息等
"""
try:
engine = get_engine()
if engine is None:
return None
info = {
'engine_name': engine.name,
'driver': engine.driver,
            'url': engine.url.render_as_string(hide_password=True),  # 隐藏密码
            'pool_size': engine.pool.size() if hasattr(engine.pool, 'size') else None,
            'max_overflow': getattr(engine.pool, '_max_overflow', None),
}
return info
except Exception as e:
logger.error(f"获取数据库信息失败: {e}")
return None
_database_initialized = False
def initialize_database_compat() -> bool:
"""
兼容性数据库初始化函数
用于替换原有的Peewee初始化代码
Returns:
bool: 初始化是否成功
"""
global _database_initialized
if _database_initialized:
return True
success = initialize_sqlalchemy_database()
if success:
success = create_all_tables()
if success:
success = check_database_connection()
if success:
_database_initialized = True
return success

View File

@@ -0,0 +1,555 @@
"""SQLAlchemy数据库模型定义
替换Peewee ORM,使用SQLAlchemy提供更好的连接池管理和错误恢复能力
"""
from sqlalchemy import Column, String, Float, Integer, Boolean, Text, Index, create_engine, DateTime
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import QueuePool
import os
import datetime
from src.config.config import global_config
from src.common.logger import get_logger
import threading
from contextlib import contextmanager
logger = get_logger("sqlalchemy_models")
# 创建基类
Base = declarative_base()
# MySQL兼容的字段类型辅助函数
def get_string_field(max_length=255, **kwargs):
"""
根据数据库类型返回合适的字符串字段
    MySQL需要指定长度的VARCHAR用于索引,SQLite可以使用Text
"""
if global_config.database.database_type == "mysql":
return String(max_length, **kwargs)
else:
return Text(**kwargs)
class SessionProxy:
"""线程安全的Session代理类自动管理session生命周期"""
def __init__(self):
self._local = threading.local()
def _get_current_session(self):
"""获取当前线程的session如果没有则创建新的"""
if not hasattr(self._local, 'session') or self._local.session is None:
_, SessionLocal = initialize_database()
self._local.session = SessionLocal()
return self._local.session
def _close_current_session(self):
"""关闭当前线程的session"""
if hasattr(self._local, 'session') and self._local.session is not None:
try:
self._local.session.close()
            except Exception:
pass
finally:
self._local.session = None
def __getattr__(self, name):
"""代理所有session方法"""
session = self._get_current_session()
attr = getattr(session, name)
# 如果是方法,需要特殊处理一些关键方法
if callable(attr):
if name in ['commit', 'rollback']:
def wrapper(*args, **kwargs):
try:
result = attr(*args, **kwargs)
if name == 'commit':
                            # commit后不要清除session,只是刷新状态
pass # 保持session活跃
return result
                    except Exception:
                        try:
                            if session and hasattr(session, 'rollback'):
                                session.rollback()
                        except Exception:
                            pass
# 发生错误时重新创建session
self._close_current_session()
raise
return wrapper
elif name == 'close':
def wrapper(*args, **kwargs):
result = attr(*args, **kwargs)
self._close_current_session()
return result
return wrapper
elif name in ['execute', 'query', 'add', 'delete', 'merge']:
def wrapper(*args, **kwargs):
try:
return attr(*args, **kwargs)
except Exception as e:
# 如果是连接相关错误重新创建session再试一次
if "not bound to a Session" in str(e) or "provisioning a new connection" in str(e):
logger.warning(f"Session问题重新创建session: {e}")
self._close_current_session()
new_session = self._get_current_session()
new_attr = getattr(new_session, name)
return new_attr(*args, **kwargs)
raise
return wrapper
return attr
def new_session(self):
"""强制创建新的session关闭当前的创建新的"""
self._close_current_session()
return self._get_current_session()
def ensure_fresh_session(self):
"""确保使用新鲜的session如果当前session有问题则重新创建"""
if hasattr(self._local, 'session') and self._local.session is not None:
try:
                # 测试session是否还可用(SQLAlchemy 2.x 下裸 SQL 需用 text() 包装)
                from sqlalchemy import text
                self._local.session.execute(text("SELECT 1"))
except Exception:
# session有问题重新创建
self._close_current_session()
return self._get_current_session()
# 创建全局session代理实例
_global_session_proxy = SessionProxy()
def get_session():
"""返回线程安全的session代理自动管理生命周期"""
return _global_session_proxy
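
设计说明:代理把 session 缓存在 threading.local 中,各线程互不串用;commit/execute 出错时会重建 session,而不是让调用方继续持有悬挂会话。使用示意(Emoji 模型见下文定义,查询值为占位):

from sqlalchemy import select

def _demo_session_usage():
    session = get_session()  # 返回线程安全代理
    row = session.execute(select(Emoji).where(Emoji.emoji_hash == "abc123")).scalar_one_or_none()
    if row is not None:
        row.usage_count += 1
    session.commit()  # 失败时代理自动 rollback 并重建底层 session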
class ChatStreams(Base):
"""聊天流模型"""
__tablename__ = 'chat_streams'
id = Column(Integer, primary_key=True, autoincrement=True)
stream_id = Column(get_string_field(64), nullable=False, unique=True, index=True)
create_time = Column(Float, nullable=False)
group_platform = Column(Text, nullable=True)
group_id = Column(get_string_field(100), nullable=True, index=True)
group_name = Column(Text, nullable=True)
last_active_time = Column(Float, nullable=False)
platform = Column(Text, nullable=False)
user_platform = Column(Text, nullable=False)
user_id = Column(get_string_field(100), nullable=False, index=True)
user_nickname = Column(Text, nullable=False)
user_cardname = Column(Text, nullable=True)
__table_args__ = (
Index('idx_chatstreams_stream_id', 'stream_id'),
Index('idx_chatstreams_user_id', 'user_id'),
Index('idx_chatstreams_group_id', 'group_id'),
)
class LLMUsage(Base):
"""LLM使用记录模型"""
__tablename__ = 'llm_usage'
id = Column(Integer, primary_key=True, autoincrement=True)
model_name = Column(get_string_field(100), nullable=False, index=True)
user_id = Column(get_string_field(50), nullable=False, index=True)
request_type = Column(get_string_field(50), nullable=False, index=True)
endpoint = Column(Text, nullable=False)
prompt_tokens = Column(Integer, nullable=False)
completion_tokens = Column(Integer, nullable=False)
total_tokens = Column(Integer, nullable=False)
cost = Column(Float, nullable=False)
status = Column(Text, nullable=False)
timestamp = Column(DateTime, nullable=False, index=True, default=datetime.datetime.now)
__table_args__ = (
Index('idx_llmusage_model_name', 'model_name'),
Index('idx_llmusage_user_id', 'user_id'),
Index('idx_llmusage_request_type', 'request_type'),
Index('idx_llmusage_timestamp', 'timestamp'),
)
class Emoji(Base):
"""表情包模型"""
__tablename__ = 'emoji'
id = Column(Integer, primary_key=True, autoincrement=True)
full_path = Column(get_string_field(500), nullable=False, unique=True, index=True)
format = Column(Text, nullable=False)
emoji_hash = Column(get_string_field(64), nullable=False, index=True)
description = Column(Text, nullable=False)
query_count = Column(Integer, nullable=False, default=0)
is_registered = Column(Boolean, nullable=False, default=False)
is_banned = Column(Boolean, nullable=False, default=False)
emotion = Column(Text, nullable=True)
record_time = Column(Float, nullable=False)
register_time = Column(Float, nullable=True)
usage_count = Column(Integer, nullable=False, default=0)
last_used_time = Column(Float, nullable=True)
__table_args__ = (
Index('idx_emoji_full_path', 'full_path'),
Index('idx_emoji_hash', 'emoji_hash'),
)
class Messages(Base):
"""消息模型"""
__tablename__ = 'messages'
id = Column(Integer, primary_key=True, autoincrement=True)
message_id = Column(get_string_field(100), nullable=False, index=True)
time = Column(Float, nullable=False)
chat_id = Column(get_string_field(64), nullable=False, index=True)
reply_to = Column(Text, nullable=True)
interest_value = Column(Float, nullable=True)
is_mentioned = Column(Boolean, nullable=True)
# 从 chat_info 扁平化而来的字段
chat_info_stream_id = Column(Text, nullable=False)
chat_info_platform = Column(Text, nullable=False)
chat_info_user_platform = Column(Text, nullable=False)
chat_info_user_id = Column(Text, nullable=False)
chat_info_user_nickname = Column(Text, nullable=False)
chat_info_user_cardname = Column(Text, nullable=True)
chat_info_group_platform = Column(Text, nullable=True)
chat_info_group_id = Column(Text, nullable=True)
chat_info_group_name = Column(Text, nullable=True)
chat_info_create_time = Column(Float, nullable=False)
chat_info_last_active_time = Column(Float, nullable=False)
# 从顶层 user_info 扁平化而来的字段
user_platform = Column(Text, nullable=True)
user_id = Column(get_string_field(100), nullable=True, index=True)
user_nickname = Column(Text, nullable=True)
user_cardname = Column(Text, nullable=True)
processed_plain_text = Column(Text, nullable=True)
display_message = Column(Text, nullable=True)
memorized_times = Column(Integer, nullable=False, default=0)
priority_mode = Column(Text, nullable=True)
priority_info = Column(Text, nullable=True)
additional_config = Column(Text, nullable=True)
is_emoji = Column(Boolean, nullable=False, default=False)
is_picid = Column(Boolean, nullable=False, default=False)
is_command = Column(Boolean, nullable=False, default=False)
is_notify = Column(Boolean, nullable=False, default=False)
__table_args__ = (
Index('idx_messages_message_id', 'message_id'),
Index('idx_messages_chat_id', 'chat_id'),
Index('idx_messages_time', 'time'),
Index('idx_messages_user_id', 'user_id'),
)
class ActionRecords(Base):
"""动作记录模型"""
__tablename__ = 'action_records'
id = Column(Integer, primary_key=True, autoincrement=True)
action_id = Column(get_string_field(100), nullable=False, index=True)
time = Column(Float, nullable=False)
action_name = Column(Text, nullable=False)
action_data = Column(Text, nullable=False)
action_done = Column(Boolean, nullable=False, default=False)
action_build_into_prompt = Column(Boolean, nullable=False, default=False)
action_prompt_display = Column(Text, nullable=False)
chat_id = Column(get_string_field(64), nullable=False, index=True)
chat_info_stream_id = Column(Text, nullable=False)
chat_info_platform = Column(Text, nullable=False)
__table_args__ = (
Index('idx_actionrecords_action_id', 'action_id'),
Index('idx_actionrecords_chat_id', 'chat_id'),
Index('idx_actionrecords_time', 'time'),
)
class Images(Base):
"""图像信息模型"""
__tablename__ = 'images'
id = Column(Integer, primary_key=True, autoincrement=True)
image_id = Column(Text, nullable=False, default="")
emoji_hash = Column(get_string_field(64), nullable=False, index=True)
description = Column(Text, nullable=True)
path = Column(get_string_field(500), nullable=False, unique=True)
count = Column(Integer, nullable=False, default=1)
timestamp = Column(Float, nullable=False)
type = Column(Text, nullable=False)
vlm_processed = Column(Boolean, nullable=False, default=False)
__table_args__ = (
Index('idx_images_emoji_hash', 'emoji_hash'),
Index('idx_images_path', 'path'),
)
class ImageDescriptions(Base):
"""图像描述信息模型"""
__tablename__ = 'image_descriptions'
id = Column(Integer, primary_key=True, autoincrement=True)
type = Column(Text, nullable=False)
image_description_hash = Column(get_string_field(64), nullable=False, index=True)
description = Column(Text, nullable=False)
timestamp = Column(Float, nullable=False)
__table_args__ = (
Index('idx_imagedesc_hash', 'image_description_hash'),
)
class OnlineTime(Base):
"""在线时长记录模型"""
__tablename__ = 'online_time'
id = Column(Integer, primary_key=True, autoincrement=True)
    timestamp = Column(Text, nullable=False, default=lambda: str(datetime.datetime.now()))  # 默认取当前时间的字符串
duration = Column(Integer, nullable=False)
start_timestamp = Column(DateTime, nullable=False, default=datetime.datetime.now)
end_timestamp = Column(DateTime, nullable=False, index=True)
__table_args__ = (
Index('idx_onlinetime_end_timestamp', 'end_timestamp'),
)
class PersonInfo(Base):
"""人物信息模型"""
__tablename__ = 'person_info'
id = Column(Integer, primary_key=True, autoincrement=True)
person_id = Column(get_string_field(100), nullable=False, unique=True, index=True)
person_name = Column(Text, nullable=True)
name_reason = Column(Text, nullable=True)
platform = Column(Text, nullable=False)
user_id = Column(get_string_field(50), nullable=False, index=True)
nickname = Column(Text, nullable=True)
impression = Column(Text, nullable=True)
short_impression = Column(Text, nullable=True)
points = Column(Text, nullable=True)
forgotten_points = Column(Text, nullable=True)
info_list = Column(Text, nullable=True)
know_times = Column(Float, nullable=True)
know_since = Column(Float, nullable=True)
last_know = Column(Float, nullable=True)
attitude = Column(Integer, nullable=True, default=50)
__table_args__ = (
Index('idx_personinfo_person_id', 'person_id'),
Index('idx_personinfo_user_id', 'user_id'),
)
class Memory(Base):
"""记忆模型"""
__tablename__ = 'memory'
id = Column(Integer, primary_key=True, autoincrement=True)
memory_id = Column(get_string_field(64), nullable=False, index=True)
chat_id = Column(Text, nullable=True)
memory_text = Column(Text, nullable=True)
keywords = Column(Text, nullable=True)
create_time = Column(Float, nullable=True)
last_view_time = Column(Float, nullable=True)
__table_args__ = (
Index('idx_memory_memory_id', 'memory_id'),
)
class Expression(Base):
"""表达风格模型"""
__tablename__ = 'expression'
id = Column(Integer, primary_key=True, autoincrement=True)
situation = Column(Text, nullable=False)
style = Column(Text, nullable=False)
count = Column(Float, nullable=False)
last_active_time = Column(Float, nullable=False)
chat_id = Column(get_string_field(64), nullable=False, index=True)
type = Column(Text, nullable=False)
create_date = Column(Float, nullable=True)
__table_args__ = (
Index('idx_expression_chat_id', 'chat_id'),
)
class ThinkingLog(Base):
"""思考日志模型"""
__tablename__ = 'thinking_logs'
id = Column(Integer, primary_key=True, autoincrement=True)
chat_id = Column(get_string_field(64), nullable=False, index=True)
trigger_text = Column(Text, nullable=True)
response_text = Column(Text, nullable=True)
trigger_info_json = Column(Text, nullable=True)
response_info_json = Column(Text, nullable=True)
timing_results_json = Column(Text, nullable=True)
chat_history_json = Column(Text, nullable=True)
chat_history_in_thinking_json = Column(Text, nullable=True)
chat_history_after_response_json = Column(Text, nullable=True)
heartflow_data_json = Column(Text, nullable=True)
reasoning_data_json = Column(Text, nullable=True)
created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
__table_args__ = (
Index('idx_thinkinglog_chat_id', 'chat_id'),
)
class GraphNodes(Base):
"""记忆图节点模型"""
__tablename__ = 'graph_nodes'
id = Column(Integer, primary_key=True, autoincrement=True)
concept = Column(get_string_field(255), nullable=False, unique=True, index=True)
memory_items = Column(Text, nullable=False)
hash = Column(Text, nullable=False)
created_time = Column(Float, nullable=False)
last_modified = Column(Float, nullable=False)
__table_args__ = (
Index('idx_graphnodes_concept', 'concept'),
)
class GraphEdges(Base):
"""记忆图边模型"""
__tablename__ = 'graph_edges'
id = Column(Integer, primary_key=True, autoincrement=True)
source = Column(get_string_field(255), nullable=False, index=True)
target = Column(get_string_field(255), nullable=False, index=True)
strength = Column(Integer, nullable=False)
hash = Column(Text, nullable=False)
created_time = Column(Float, nullable=False)
last_modified = Column(Float, nullable=False)
__table_args__ = (
Index('idx_graphedges_source', 'source'),
Index('idx_graphedges_target', 'target'),
)
# 数据库引擎和会话管理
_engine = None
_SessionLocal = None
def get_database_url():
"""获取数据库连接URL"""
config = global_config.database
if config.database_type == "mysql":
# 对用户名和密码进行URL编码处理特殊字符
from urllib.parse import quote_plus
encoded_user = quote_plus(config.mysql_user)
encoded_password = quote_plus(config.mysql_password)
return (
f"mysql+pymysql://{encoded_user}:{encoded_password}"
f"@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}"
f"?charset={config.mysql_charset}"
)
else: # SQLite
# 如果是相对路径,则相对于项目根目录
if not os.path.isabs(config.sqlite_path):
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
db_path = os.path.join(ROOT_PATH, config.sqlite_path)
else:
db_path = config.sqlite_path
# 确保数据库目录存在
os.makedirs(os.path.dirname(db_path), exist_ok=True)
return f"sqlite:///{db_path}"
def initialize_database():
"""初始化数据库引擎和会话"""
global _engine, _SessionLocal
if _engine is not None:
return _engine, _SessionLocal
database_url = get_database_url()
config = global_config.database
# 配置引擎参数
engine_kwargs = {
'echo': False, # 生产环境关闭SQL日志
'future': True,
}
if config.database_type == "mysql":
# MySQL连接池配置
engine_kwargs.update({
'poolclass': QueuePool,
'pool_size': config.connection_pool_size,
'max_overflow': config.connection_pool_size * 2,
'pool_timeout': config.connection_timeout,
'pool_recycle': 3600, # 1小时回收连接
'pool_pre_ping': True, # 连接前ping检查
'connect_args': {
'autocommit': config.mysql_autocommit,
'charset': config.mysql_charset,
'connect_timeout': config.connection_timeout,
'read_timeout': 30,
'write_timeout': 30,
}
})
else:
# SQLite配置 - 添加连接池设置以避免连接耗尽
engine_kwargs.update({
'poolclass': QueuePool,
'pool_size': 20, # 增加池大小
'max_overflow': 30, # 增加溢出连接数
'pool_timeout': 60, # 增加超时时间
'pool_recycle': 3600, # 1小时回收连接
'pool_pre_ping': True, # 连接前ping检查
'connect_args': {
'check_same_thread': False,
'timeout': 30,
}
})
_engine = create_engine(database_url, **engine_kwargs)
_SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=_engine)
# 创建所有表
Base.metadata.create_all(bind=_engine)
logger.info(f"SQLAlchemy数据库初始化成功: {config.database_type}")
return _engine, _SessionLocal
@contextmanager
def get_db_session():
"""数据库会话上下文管理器 - 推荐使用这个而不是get_session()"""
session = None
try:
_, SessionLocal = initialize_database()
session = SessionLocal()
yield session
session.commit()
except Exception as e:
if session:
session.rollback()
raise
finally:
if session:
session.close()
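
推荐用法示意(哈希值为占位):

from sqlalchemy import select

def _demo_db_session():
    with get_db_session() as session:
        img = session.execute(select(Images).where(Images.emoji_hash == "abc123")).scalar_one_or_none()
        if img:
            img.count += 1  # 正常退出 with 时自动 commit,异常时自动 rollback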
def get_engine():
"""获取数据库引擎"""
engine, _ = initialize_database()
return engine

View File

@@ -373,6 +373,7 @@ MODULE_COLORS = {
"base_command": "\033[38;5;208m", # 橙色
"component_registry": "\033[38;5;214m", # 橙黄色
"stream_api": "\033[38;5;220m", # 黄色
"plugin_hot_reload": "\033[38;5;226m", #品红色
"config_api": "\033[38;5;226m", # 亮黄色
"heartflow_api": "\033[38;5;154m", # 黄绿色
"action_apis": "\033[38;5;118m", # 绿色
@@ -406,6 +407,7 @@ MODULE_COLORS = {
"base_action": "\033[38;5;250m", # 浅灰色
# 数据库和消息
"database_model": "\033[38;5;94m", # 橙褐色
"database": "\033[38;5;46m", # 橙褐色
"maim_message": "\033[38;5;140m", # 紫褐色
# 日志系统
"logger": "\033[38;5;8m", # 深灰色
@@ -430,6 +432,8 @@ MODULE_ALIASES = {
"memory_activator": "记忆",
"tool_use": "工具",
"expressor": "表达方式",
"plugin_hot_reload": "热重载",
"database": "数据库",
"database_model": "数据库",
"mood": "情绪",
"memory": "记忆",

View File

@@ -1,20 +1,26 @@
import traceback
from typing import List, Any, Optional
from peewee import Model # 添加 Peewee Model 导入
from typing import List, Optional, Any, Dict
from sqlalchemy import not_, select, func
from sqlalchemy.orm import DeclarativeBase
from src.config.config import global_config
from src.common.database.database_model import Messages
# from src.common.database.database_model import Messages
from src.common.database.sqlalchemy_models import Messages
from src.common.database.sqlalchemy_database_api import get_session
from src.common.logger import get_logger
logger = get_logger(__name__)
class Base(DeclarativeBase):
pass
def _model_to_dict(model_instance: Model) -> dict[str, Any]:
def _model_to_dict(instance: Base) -> Dict[str, Any]:
"""
Peewee 模型实例转换为字典。
SQLAlchemy 模型实例转换为字典。
"""
return model_instance.__data__
return {col.name: getattr(instance, col.name) for col in instance.__table__.columns}
def find_messages(
@@ -38,7 +44,8 @@ def find_messages(
消息字典列表,如果出错则返回空列表。
"""
try:
query = Messages.select()
session = get_session()
query = select(Messages)
# 应用过滤器
if message_filter:
@@ -77,42 +84,57 @@ def find_messages(
query = query.where(Messages.user_id != global_config.bot.qq_account)
if filter_command:
query = query.where(not Messages.is_command)
query = query.where(not_(Messages.is_command))
if limit > 0:
# 确保limit是正整数
limit = max(1, int(limit))
if limit_mode == "earliest":
# 获取时间最早的 limit 条记录,已经是正序
query = query.order_by(Messages.time.asc()).limit(limit)
peewee_results = list(query)
try:
results = session.execute(query).scalars().all()
except Exception as e:
logger.error(f"执行earliest查询失败: {e}")
results = []
else: # 默认为 'latest'
# 获取时间最晚的 limit 条记录
query = query.order_by(Messages.time.desc()).limit(limit)
latest_results_peewee = list(query)
# 将结果按时间正序排列
peewee_results = sorted(latest_results_peewee, key=lambda msg: msg.time)
try:
latest_results = session.execute(query).scalars().all()
# 将结果按时间正序排列
results = sorted(latest_results, key=lambda msg: msg.time)
except Exception as e:
logger.error(f"执行latest查询失败: {e}")
results = []
else:
# limit 为 0 时,应用传入的 sort 参数
if sort:
peewee_sort_terms = []
sort_terms = []
for field_name, direction in sort:
if hasattr(Messages, field_name):
field = getattr(Messages, field_name)
if direction == 1: # ASC
peewee_sort_terms.append(field.asc())
sort_terms.append(field.asc())
elif direction == -1: # DESC
peewee_sort_terms.append(field.desc())
sort_terms.append(field.desc())
else:
logger.warning(f"字段 '{field_name}' 的排序方向 '{direction}' 无效。将跳过此排序条件。")
else:
logger.warning(f"排序字段 '{field_name}' 在 Messages 模型中未找到。将跳过此排序条件。")
if peewee_sort_terms:
query = query.order_by(*peewee_sort_terms)
peewee_results = list(query)
if sort_terms:
query = query.order_by(*sort_terms)
try:
results = session.execute(query).scalars().all()
except Exception as e:
logger.error(f"执行无限制查询失败: {e}")
results = []
return [_model_to_dict(msg) for msg in peewee_results]
return [_model_to_dict(msg) for msg in results]
except Exception as e:
log_message = (
f"使用 Peewee 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n"
f"使用 SQLAlchemy 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n"
+ traceback.format_exc()
)
logger.error(log_message)
@@ -130,7 +152,8 @@ def count_messages(message_filter: dict[str, Any]) -> int:
符合条件的消息数量,如果出错则返回 0。
"""
try:
query = Messages.select()
session = get_session()
query = select(func.count(Messages.id))
# 应用过滤器
if message_filter:
@@ -167,14 +190,14 @@ def count_messages(message_filter: dict[str, Any]) -> int:
if conditions:
query = query.where(*conditions)
count = query.count()
return count
count = session.execute(query).scalar()
return count or 0
except Exception as e:
log_message = f"使用 Peewee 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}"
log_message = f"使用 SQLAlchemy 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}"
logger.error(log_message)
return 0
# 你可以在这里添加更多与 messages 集合相关的数据库操作函数,例如 find_one_message, insert_message 等。
# 注意:对于 Peewee插入操作通常是 Messages.create(...) 或 instance.save()。
# 查找单个消息可以是 Messages.get_or_none(...) 或 query.first()。
# 注意:对于 SQLAlchemy插入操作通常是使用 session.add() 和 session.commit()。
# 查找单个消息可以使用 session.execute(select(Messages).where(...)).scalar_one_or_none()。
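
对应的调用示意(参数取值仅为假设):

def _demo_message_api():
    # 取某群聊最新 20 条非命令消息,返回时按时间正序排列
    msgs = find_messages(
        message_filter={"chat_id": "g123456"},
        limit=20,
        limit_mode="latest",
        filter_command=True,
    )
    total = count_messages({"chat_id": "g123456"})
    return msgs, total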

View File

@@ -13,6 +13,7 @@ from typing import List, Optional
from src.common.logger import get_logger
from src.config.config_base import ConfigBase
from src.config.official_configs import (
DatabaseConfig,
BotConfig,
PersonalityConfig,
ExpressionConfig,
@@ -340,6 +341,7 @@ class Config(ConfigBase):
MMC_VERSION: str = field(default=MMC_VERSION, repr=False, init=False) # 硬编码的版本信息
database: DatabaseConfig
bot: BotConfig
personality: PersonalityConfig
relationship: RelationshipConfig
@@ -466,4 +468,25 @@ update_model_config()
logger.info("正在品鉴配置文件...")
global_config = load_config(config_path=os.path.join(CONFIG_DIR, "bot_config.toml"))
model_config = api_ada_load_config(config_path=os.path.join(CONFIG_DIR, "model_config.toml"))
# 初始化数据库连接
logger.info("正在初始化数据库连接...")
from src.common.database.database import initialize_sql_database
try:
initialize_sql_database(global_config.database)
logger.info(f"数据库连接初始化成功,使用 {global_config.database.database_type} 数据库")
except Exception as e:
logger.error(f"数据库连接初始化失败: {e}")
raise e
# 初始化数据库表结构
logger.info("正在初始化数据库表结构...")
from src.common.database.sqlalchemy_models import initialize_database as init_db
try:
init_db()
logger.info("数据库表结构初始化完成")
except Exception as e:
logger.error(f"数据库表结构初始化失败: {e}")
raise e
logger.info("非常的新鲜,非常的美味!")

View File

@@ -13,6 +13,65 @@ from src.config.config_base import ConfigBase
4. 对于新增的字段若为可选项则应在其后添加field()并设置default_factory或default
"""
@dataclass
class DatabaseConfig(ConfigBase):
"""数据库配置类"""
database_type: Literal["sqlite", "mysql"] = "sqlite"
"""数据库类型,支持 sqlite 或 mysql"""
# SQLite 配置
sqlite_path: str = "data/MaiBot.db"
"""SQLite数据库文件路径"""
# MySQL 配置
mysql_host: str = "localhost"
"""MySQL服务器地址"""
mysql_port: int = 3306
"""MySQL服务器端口"""
mysql_database: str = "maibot"
"""MySQL数据库名"""
mysql_user: str = "root"
"""MySQL用户名"""
mysql_password: str = ""
"""MySQL密码"""
mysql_charset: str = "utf8mb4"
"""MySQL字符集"""
mysql_unix_socket: str = ""
"""MySQL Unix套接字路径可选用于本地连接优先于host/port"""
# MySQL SSL 配置
mysql_ssl_mode: str = "DISABLED"
"""SSL模式: DISABLED, PREFERRED, REQUIRED, VERIFY_CA, VERIFY_IDENTITY"""
mysql_ssl_ca: str = ""
"""SSL CA证书路径"""
mysql_ssl_cert: str = ""
"""SSL客户端证书路径"""
mysql_ssl_key: str = ""
"""SSL客户端密钥路径"""
# MySQL 高级配置
mysql_autocommit: bool = True
"""自动提交事务"""
mysql_sql_mode: str = "TRADITIONAL"
"""SQL模式"""
# 连接池配置
connection_pool_size: int = 10
"""连接池大小仅MySQL有效"""
connection_timeout: int = 10
"""连接超时时间(秒)"""
@dataclass
class BotConfig(ConfigBase):
@@ -72,6 +131,19 @@ class ChatConfig(ConfigBase):
max_context_size: int = 18
"""上下文长度"""
replyer_random_probability: float = 0.5
"""
    发言时选择推理模型的概率(0-1之间)
    选择普通模型的概率为 1 - replyer_random_probability
"""
thinking_timeout: int = 40
"""麦麦最长思考规划时间超过这个时间的思考会放弃往往是api反应太慢"""
talk_frequency: float = 1
"""回复频率阈值"""
mentioned_bot_inevitable_reply: bool = False
"""提及 bot 必然回复"""
@@ -93,17 +165,17 @@ class ChatConfig(ConfigBase):
"""
统一的活跃度和专注度配置
格式:[["platform:chat_id:type", "HH:MM,frequency", "HH:MM,frequency", ...], ...]
全局配置示例:
[["", "8:00,1", "12:00,2", "18:00,1.5", "00:00,0.5"]]
特定聊天流配置示例:
[
["", "8:00,1", "12:00,1.2", "18:00,1.5", "01:00,0.6"], # 全局默认配置
["qq:1026294844:group", "12:20,1", "16:10,2", "20:10,1", "00:10,0.3"], # 特定群聊配置
["qq:729957033:private", "8:20,1", "12:10,2", "20:10,1.5", "00:10,0.2"] # 特定私聊配置
]
说明:
- 当第一个元素为空字符串""时,表示全局默认配置
- 当第一个元素为"platform:id:type"格式时,表示特定聊天流配置
@@ -155,72 +227,11 @@ class ChatConfig(ConfigBase):
# 检查全局时段配置(第一个元素为空字符串的配置)
global_frequency = self._get_global_frequency()
return self.talk_frequency if global_frequency is None else global_frequency
def _get_global_focus_value(self) -> Optional[float]:
"""
获取全局默认专注度配置
if global_frequency is not None:
return global_frequency
Returns:
float: 专注度值,如果没有配置则返回 None
"""
for config_item in self.focus_value_adjust:
if not config_item or len(config_item) < 2:
continue
# 检查是否为全局默认配置(第一个元素为空字符串)
if config_item[0] == "":
return self._get_time_based_focus_value(config_item[1:])
return None
def _get_time_based_focus_value(self, time_focus_list: list[str]) -> Optional[float]:
"""
根据时间配置列表获取当前时段的专注度
Args:
time_focus_list: 时间专注度配置列表,格式为 ["HH:MM,focus_value", ...]
Returns:
float: 专注度值,如果没有配置则返回 None
"""
from datetime import datetime
current_time = datetime.now().strftime("%H:%M")
current_hour, current_minute = map(int, current_time.split(":"))
current_minutes = current_hour * 60 + current_minute
# 解析时间专注度配置
time_focus_pairs = []
for time_focus_str in time_focus_list:
try:
time_str, focus_str = time_focus_str.split(",")
hour, minute = map(int, time_str.split(":"))
focus_value = float(focus_str)
minutes = hour * 60 + minute
time_focus_pairs.append((minutes, focus_value))
except (ValueError, IndexError):
continue
if not time_focus_pairs:
return None
# 按时间排序
time_focus_pairs.sort(key=lambda x: x[0])
# 查找当前时间对应的专注度
current_focus_value = None
for minutes, focus_value in time_focus_pairs:
if current_minutes >= minutes:
current_focus_value = focus_value
else:
break
# 如果当前时间在所有配置时间之前,使用最后一个时间段的专注度(跨天逻辑)
if current_focus_value is None and time_focus_pairs:
current_focus_value = time_focus_pairs[-1][1]
return current_focus_value
# 如果都没有匹配,返回默认值
return self.talk_frequency
def _get_time_based_frequency(self, time_freq_list: list[str]) -> Optional[float]:
"""
@@ -395,6 +406,14 @@ class MessageReceiveConfig(ConfigBase):
ban_msgs_regex: set[str] = field(default_factory=lambda: set())
"""过滤正则表达式列表"""
@dataclass
class NormalChatConfig(ConfigBase):
"""普通聊天配置类"""
willing_mode: str = "classical"
"""意愿模式"""
@dataclass
class ExpressionConfig(ConfigBase):
"""表达配置类"""
@@ -403,14 +422,14 @@ class ExpressionConfig(ConfigBase):
"""
表达学习配置列表,支持按聊天流配置
格式: [["chat_stream_id", "use_expression", "enable_learning", learning_intensity], ...]
示例:
[
["", "enable", "enable", 1.0], # 全局配置使用表达启用学习学习强度1.0
["qq:1919810:private", "enable", "enable", 1.5], # 特定私聊配置使用表达启用学习学习强度1.5
["qq:114514:private", "enable", "disable", 0.5], # 特定私聊配置使用表达禁用学习学习强度0.5
]
说明:
- 第一位: chat_stream_id空字符串表示全局配置
- 第二位: 是否使用学到的表达 ("enable"/"disable")
@@ -475,14 +494,14 @@ class ExpressionConfig(ConfigBase):
# 优先检查聊天流特定的配置
if chat_stream_id:
specific_expression_config = self._get_stream_specific_config(chat_stream_id)
if specific_expression_config is not None:
return specific_expression_config
specific_config = self._get_stream_specific_config(chat_stream_id)
if specific_config is not None:
return specific_config
# 检查全局配置(第一个元素为空字符串的配置)
global_expression_config = self._get_global_config()
if global_expression_config is not None:
return global_expression_config
global_config = self._get_global_config()
if global_config is not None:
return global_config
# 如果都没有匹配,返回默认值
return True, True, 300
@@ -518,10 +537,10 @@ class ExpressionConfig(ConfigBase):
# 解析配置
try:
use_expression: bool = config_item[1].lower() == "enable"
enable_learning: bool = config_item[2].lower() == "enable"
learning_intensity: float = float(config_item[3])
return use_expression, enable_learning, learning_intensity # type: ignore
use_expression = config_item[1].lower() == "enable"
enable_learning = config_item[2].lower() == "enable"
learning_intensity = float(config_item[3])
return use_expression, enable_learning, learning_intensity
except (ValueError, IndexError):
continue
@@ -541,10 +560,10 @@ class ExpressionConfig(ConfigBase):
# 检查是否为全局配置(第一个元素为空字符串)
if config_item[0] == "":
try:
use_expression: bool = config_item[1].lower() == "enable"
enable_learning: bool = config_item[2].lower() == "enable"
use_expression = config_item[1].lower() == "enable"
enable_learning = config_item[2].lower() == "enable"
learning_intensity = float(config_item[3])
return use_expression, enable_learning, learning_intensity # type: ignore
return use_expression, enable_learning, learning_intensity
except (ValueError, IndexError):
continue
@@ -558,7 +577,6 @@ class ToolConfig(ConfigBase):
enable_tool: bool = False
"""是否在聊天中启用工具"""
@dataclass
class VoiceConfig(ConfigBase):
"""语音识别配置类"""
@@ -703,7 +721,6 @@ class KeywordReactionConfig(ConfigBase):
if not isinstance(rule, KeywordRuleConfig):
raise ValueError(f"规则必须是KeywordRuleConfig类型而不是{type(rule).__name__}")
@dataclass
class CustomPromptConfig(ConfigBase):
"""自定义提示词配置类"""
@@ -852,3 +869,4 @@ class LPMMKnowledgeConfig(ConfigBase):
embedding_dimension: int = 1024
"""嵌入向量维度,应该与模型的输出维度一致"""

View File

@@ -5,8 +5,7 @@ from PIL import Image
from datetime import datetime
from src.common.logger import get_logger
from src.common.database.database import db # 确保 db 被导入用于 create_tables
from src.common.database.database_model import LLMUsage
from src.common.database.sqlalchemy_models import LLMUsage, get_session
from src.config.api_ada_configs import ModelInfo
from .payload_content.message import Message, MessageBuilder
from .model_client.base_client import UsageRecord
@@ -143,16 +142,9 @@ def compress_messages(messages: list[Message], img_target_size: int = 1 * 1024 *
class LLMUsageRecorder:
"""
LLM使用情况记录器
LLM使用情况记录器(SQLAlchemy 版本)
"""
def __init__(self):
try:
# 使用 Peewee 创建表safe=True 表示如果表已存在则不会抛出错误
db.create_tables([LLMUsage], safe=True)
# logger.debug("LLMUsage 表已初始化/确保存在。")
except Exception as e:
logger.error(f"创建 LLMUsage 表失败: {str(e)}")
def record_usage_to_database(
self, model_info: ModelInfo, model_usage: UsageRecord, user_id: str, request_type: str, endpoint: str, time_cost: float = 0.0
@@ -160,9 +152,13 @@ class LLMUsageRecorder:
input_cost = (model_usage.prompt_tokens / 1000000) * model_info.price_in
output_cost = (model_usage.completion_tokens / 1000000) * model_info.price_out
total_cost = round(input_cost + output_cost, 6)
session = None
try:
# 使用 Peewee 模型创建记录
LLMUsage.create(
# 使用 SQLAlchemy 会话创建记录
session = get_session()
usage_record = LLMUsage(
model_name=model_info.model_identifier,
model_assign_name=model_info.name,
model_api_provider=model_info.api_provider,
@@ -175,8 +171,12 @@ class LLMUsageRecorder:
cost=total_cost or 0.0,
time_cost = round(time_cost or 0.0, 3),
status="success",
timestamp=datetime.now(), # Peewee 会处理 DateTimeField
timestamp=datetime.now(), # SQLAlchemy 会处理 DateTime 字段
)
session.add(usage_record)
session.commit()
logger.debug(
f"Token使用情况 - 模型: {model_usage.model_name}, "
f"用户: {user_id}, 类型: {request_type}, "
@@ -184,6 +184,11 @@ class LLMUsageRecorder:
f"总计: {model_usage.total_tokens}"
)
except Exception as e:
if session:
session.rollback()
logger.error(f"记录token使用情况失败: {str(e)}")
finally:
if session:
session.close()
llm_usage_recorder = LLMUsageRecorder()
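The try/commit/rollback/close sequence above maps naturally onto a context manager. A minimal sketch, assuming get_session() returns a standard SQLAlchemy Session; recording_session is hypothetical, not part of the codebase:

from contextlib import contextmanager
from src.common.database.sqlalchemy_models import get_session

@contextmanager
def recording_session():
    # commit on success, roll back on error, always close --
    # the same discipline as record_usage_to_database above
    session = get_session()
    try:
        yield session
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()

# usage:
# with recording_session() as s:
#     s.add(usage_record)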

View File

@@ -1,5 +1,7 @@
import asyncio
import time
import signal
import sys
from maim_message import MessageServer
from src.common.remote import TelemetryHeartBeatTask
@@ -17,8 +19,9 @@ from rich.traceback import install
from src.migrate_helper.migrate import check_and_run_migrations
# from src.api.main import start_api_server
# 导入新的插件管理器
# 导入新的插件管理器和热重载管理器
from src.plugin_system.core.plugin_manager import plugin_manager
from src.plugin_system.core.plugin_hot_reload import hot_reload_manager
# 导入消息API和traceback模块
from src.common.message import get_global_api
@@ -48,6 +51,28 @@ class MainSystem:
self.app: MessageServer = get_global_api()
self.server: Server = get_global_server()
# 设置信号处理器用于优雅退出
self._setup_signal_handlers()
def _setup_signal_handlers(self):
"""设置信号处理器"""
def signal_handler(signum, frame):
logger.info("收到退出信号,正在优雅关闭系统...")
self._cleanup()
sys.exit(0)
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
def _cleanup(self):
"""清理资源"""
try:
# 停止插件热重载系统
hot_reload_manager.stop()
logger.info("🛑 插件热重载系统已停止")
except Exception as e:
logger.error(f"停止热重载系统时出错: {e}")
async def initialize(self):
"""初始化系统组件"""
logger.info(f"正在唤醒{global_config.bot.nickname}......")
@@ -58,14 +83,7 @@ class MainSystem:
logger.info(f"""
--------------------------------
全部系统初始化完成,{global_config.bot.nickname}已成功唤醒
--------------------------------
如果想要自定义{global_config.bot.nickname}的功能,请查阅https://docs.mai-mai.org/manual/usage/
或者遇到了问题,请访问我们的文档:https://docs.mai-mai.org/
--------------------------------
如果你想要编写或了解插件相关内容请访问开发文档https://docs.mai-mai.org/develop/
--------------------------------
如果你需要查阅模型的消耗以及麦麦的统计数据请访问根目录的maibot_statistics.html文件
""")
--------------------------------""")
async def _init_components(self):
"""初始化其他组件"""
@@ -87,6 +105,10 @@ class MainSystem:
# 加载所有actions包括默认的和插件的
plugin_manager.load_all_plugins()
# 启动插件热重载系统
hot_reload_manager.start()
# 初始化表情管理器
get_emoji_manager().initialize()
logger.info("表情包管理器初始化成功")

View File

View File

@@ -2,17 +2,17 @@ import hashlib
import asyncio
import json
import time
import random
from json_repair import repair_json
from typing import Union
from typing import Any, Callable, Dict, Union, Optional
from sqlalchemy import select
from src.common.logger import get_logger
from src.common.database.database import db
from src.common.database.database_model import PersonInfo
from src.common.database.sqlalchemy_models import PersonInfo
from src.common.database.sqlalchemy_database_api import get_session
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
session = get_session()
logger = get_logger("person_info")
@@ -380,36 +380,282 @@ class Person:
return relation_info
# 统一的会话管理函数
def with_session(func):
"""装饰器为函数自动注入session参数"""
if asyncio.iscoroutinefunction(func):
async def async_wrapper(*args, **kwargs):
return await func(session, *args, **kwargs)
return async_wrapper
else:
def sync_wrapper(*args, **kwargs):
return func(session, *args, **kwargs)
return sync_wrapper
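A hypothetical consumer of the decorator above (touch_person is illustrative, not project code): the wrapped call site never passes a session explicitly.

@with_session
def touch_person(session, person_id: str) -> bool:
    # session is injected by the decorator
    return session.execute(
        select(PersonInfo).where(PersonInfo.person_id == person_id)
    ).scalar_one_or_none() is not None

# exists = touch_person("0123abcd...")  # session argument supplied automatically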
# 全局会话获取函数用于替换所有裸露的session使用
def _get_session():
"""获取数据库会话的统一函数"""
return get_session()
class PersonInfoManager:
def __init__(self):
"""初始化PersonInfoManager"""
from src.common.database.sqlalchemy_models import PersonInfo
self.person_name_list = {}
self.qv_name_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="relation.qv_name")
try:
db.connect(reuse_if_open=True)
# 设置连接池参数
# 设置连接池参数(仅对 SQLite 有效)
if hasattr(db, "execute_sql"):
# 设置SQLite优化参数
db.execute_sql("PRAGMA cache_size = -64000") # 64MB缓存
db.execute_sql("PRAGMA temp_store = memory") # 临时存储在内存中
db.execute_sql("PRAGMA mmap_size = 268435456") # 256MB内存映射
# 检查数据库类型只对SQLite执行PRAGMA语句
if global_config.database.database_type == "sqlite":
# 设置SQLite优化参数
db.execute_sql("PRAGMA cache_size = -64000") # 64MB缓存
db.execute_sql("PRAGMA temp_store = memory") # 临时存储在内存中
db.execute_sql("PRAGMA mmap_size = 268435456") # 256MB内存映射
db.create_tables([PersonInfo], safe=True)
except Exception as e:
logger.error(f"数据库连接或 PersonInfo 表创建失败: {e}")
# 初始化时读取所有person_name
try:
for record in PersonInfo.select(PersonInfo.person_id, PersonInfo.person_name).where(
PersonInfo.person_name.is_null(False)
):
from src.common.database.sqlalchemy_models import PersonInfo
# 在这里获取会话
for record in session.execute(select(PersonInfo.person_id, PersonInfo.person_name).where(
PersonInfo.person_name.is_not(None)
)).fetchall():
if record.person_name:
self.person_name_list[record.person_id] = record.person_name
logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称 (Peewee)")
logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称 (SQLAlchemy)")
except Exception as e:
logger.error(f"Peewee 加载 person_name_list 失败: {e}")
logger.error(f"SQLAlchemy 加载 person_name_list 失败: {e}")
@staticmethod
def get_person_id(platform: str, user_id: Union[int, str]) -> str:
"""获取唯一id"""
if "-" in platform:
platform = platform.split("-")[1]
components = [platform, str(user_id)]
key = "_".join(components)
return hashlib.md5(key.encode()).hexdigest()
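A quick self-check of the id scheme just defined (values are made up):

import hashlib
assert PersonInfoManager.get_person_id("qq", 114514) == hashlib.md5(b"qq_114514").hexdigest()
# adapter-prefixed platforms collapse first: "xxx-qq" hashes the same as "qq"
assert PersonInfoManager.get_person_id("xxx-qq", 114514) == PersonInfoManager.get_person_id("qq", 114514)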
async def is_person_known(self, platform: str, user_id: int):
"""判断是否认识某人"""
person_id = self.get_person_id(platform, user_id)
def _db_check_known_sync(p_id: str):
# 在需要时获取会话
return session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar() is not None
try:
return await asyncio.to_thread(_db_check_known_sync, person_id)
except Exception as e:
logger.error(f"检查用户 {person_id} 是否已知时出错 (SQLAlchemy): {e}")
return False
def get_person_id_by_person_name(self, person_name: str) -> str:
"""根据用户名获取用户ID"""
try:
# 在需要时获取会话
record = session.execute(select(PersonInfo).where(PersonInfo.person_name == person_name)).scalar()
return record.person_id if record else ""
except Exception as e:
logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (SQLAlchemy): {e}")
return ""
@staticmethod
async def create_person_info(person_id: str, data: Optional[dict] = None):
"""创建一个项"""
if not person_id:
logger.debug("创建失败person_id不存在")
return
_person_info_default = copy.deepcopy(person_info_default)
# 获取 SQLAlchemy 模型的所有字段名
model_fields = [column.name for column in PersonInfo.__table__.columns]
final_data = {"person_id": person_id}
# Start with defaults for all model fields
for key, default_value in _person_info_default.items():
if key in model_fields:
final_data[key] = default_value
# Override with provided data
if data:
for key, value in data.items():
if key in model_fields:
final_data[key] = value
# Ensure person_id is correctly set from the argument
final_data["person_id"] = person_id
# Serialize JSON fields
for key in JSON_SERIALIZED_FIELDS:
if key in final_data:
if isinstance(final_data[key], (list, dict)):
final_data[key] = json.dumps(final_data[key], ensure_ascii=False)
elif final_data[key] is None: # Default for lists is [], store as "[]"
final_data[key] = json.dumps([], ensure_ascii=False)
# If it's already a string, assume it's valid JSON or a non-JSON string field
def _db_create_sync(p_data: dict):
try:
new_person = PersonInfo(**p_data)
session.add(new_person)
session.commit()
return True
except Exception as e:
session.rollback()
logger.error(f"创建 PersonInfo 记录 {p_data.get('person_id')} 失败 (SQLAlchemy): {e}")
return False
await asyncio.to_thread(_db_create_sync, final_data)
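The JSON round-trip assumed above for fields in JSON_SERIALIZED_FIELDS (e.g. "points") is plain dumps/loads; a minimal illustration with made-up data:

import json
points = [["likes cats", 1.5]]
stored = json.dumps(points, ensure_ascii=False)  # persisted as text in the row
assert json.loads(stored) == points              # parsed back on read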
async def _safe_create_person_info(self, person_id: str, data: Optional[dict] = None):
"""安全地创建用户信息,处理竞态条件"""
if not person_id:
logger.debug("创建失败person_id不存在")
return
_person_info_default = copy.deepcopy(person_info_default)
# 获取 SQLAlchemy 模型的所有字段名
model_fields = [column.name for column in PersonInfo.__table__.columns]
final_data = {"person_id": person_id}
# Start with defaults for all model fields
for key, default_value in _person_info_default.items():
if key in model_fields:
final_data[key] = default_value
# Override with provided data
if data:
for key, value in data.items():
if key in model_fields:
final_data[key] = value
# Ensure person_id is correctly set from the argument
final_data["person_id"] = person_id
# Serialize JSON fields
for key in JSON_SERIALIZED_FIELDS:
if key in final_data:
if isinstance(final_data[key], (list, dict)):
final_data[key] = json.dumps(final_data[key], ensure_ascii=False)
elif final_data[key] is None: # Default for lists is [], store as "[]"
final_data[key] = json.dumps([], ensure_ascii=False)
def _db_safe_create_sync(p_data: dict):
try:
existing = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_data["person_id"])).scalar()
if existing:
logger.debug(f"用户 {p_data['person_id']} 已存在,跳过创建")
return True
# 尝试创建
new_person = PersonInfo(**p_data)
session.add(new_person)
session.commit()
return True
except Exception as e:
session.rollback()
if "UNIQUE constraint failed" in str(e):
logger.debug(f"检测到并发创建用户 {p_data.get('person_id')},跳过错误")
return True # 其他协程已创建,视为成功
else:
logger.error(f"创建 PersonInfo 记录 {p_data.get('person_id')} 失败 (SQLAlchemy): {e}")
return False
await asyncio.to_thread(_db_safe_create_sync, final_data)
async def update_one_field(self, person_id: str, field_name: str, value, data: Optional[Dict] = None):
"""更新某一个字段,会补全"""
# 获取 SQLAlchemy 模型的所有字段名
model_fields = [column.name for column in PersonInfo.__table__.columns]
if field_name not in model_fields:
logger.debug(f"更新'{field_name}'失败,未在 PersonInfo SQLAlchemy 模型中定义的字段。")
return
processed_value = value
if field_name in JSON_SERIALIZED_FIELDS:
if isinstance(value, (list, dict)):
processed_value = json.dumps(value, ensure_ascii=False, indent=None)
elif value is None: # Store None as "[]" for JSON list fields
processed_value = json.dumps([], ensure_ascii=False, indent=None)
def _db_update_sync(p_id: str, f_name: str, val_to_set):
start_time = time.time()
try:
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
query_time = time.time()
if record:
setattr(record, f_name, val_to_set)
session.commit()
save_time = time.time()
total_time = save_time - start_time
if total_time > 0.5: # 如果超过500ms就记录日志
logger.warning(
f"数据库更新操作耗时 {total_time:.3f}秒 (查询: {query_time - start_time:.3f}s, 保存: {save_time - query_time:.3f}s) person_id={p_id}, field={f_name}"
)
return True, False # Found and updated, no creation needed
else:
total_time = time.time() - start_time
if total_time > 0.5:
logger.warning(f"数据库查询操作耗时 {total_time:.3f}秒 person_id={p_id}, field={f_name}")
return False, True # Not found, needs creation
except Exception as e:
session.rollback()
total_time = time.time() - start_time
logger.error(f"数据库操作异常,耗时 {total_time:.3f}秒: {e}")
raise
found, needs_creation = await asyncio.to_thread(_db_update_sync, person_id, field_name, processed_value)
if needs_creation:
logger.info(f"{person_id} 不存在,将新建。")
creation_data = data if data is not None else {}
# Ensure platform and user_id are present for context if available from 'data'
# but primarily, set the field that triggered the update.
# The create_person_info will handle defaults and serialization.
creation_data[field_name] = value # Pass original value to create_person_info
# Ensure platform and user_id are in creation_data if available,
# otherwise create_person_info will use defaults.
if data and "platform" in data:
creation_data["platform"] = data["platform"]
if data and "user_id" in data:
creation_data["user_id"] = data["user_id"]
# 使用安全的创建方法,处理竞态条件
await self._safe_create_person_info(person_id, creation_data)
@staticmethod
async def has_one_field(person_id: str, field_name: str):
"""判断是否存在某一个字段"""
# 获取 SQLAlchemy 模型的所有字段名
model_fields = [column.name for column in PersonInfo.__table__.columns]
if field_name not in model_fields:
logger.debug(f"检查字段'{field_name}'失败,未在 PersonInfo SQLAlchemy 模型中定义。")
return False
def _db_has_field_sync(p_id: str, f_name: str):
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
return bool(record)
try:
return await asyncio.to_thread(_db_has_field_sync, person_id, field_name)
except Exception as e:
logger.error(f"检查字段 {field_name} for {person_id} 时出错 (SQLAlchemy): {e}")
return False
@staticmethod
def _extract_json_from_text(text: str) -> dict:
@@ -513,12 +759,13 @@ class PersonInfoManager:
else:
def _db_check_name_exists_sync(name_to_check):
return PersonInfo.select().where(PersonInfo.person_name == name_to_check).exists()
return session.execute(select(PersonInfo).where(PersonInfo.person_name == name_to_check)).scalar() is not None
if await asyncio.to_thread(_db_check_name_exists_sync, generated_nickname):
is_duplicate = True
current_name_set.add(generated_nickname)
if not is_duplicate:
person.person_name = generated_nickname
person.name_reason = result.get("reason", "未提供理由")
@@ -547,4 +794,304 @@ class PersonInfoManager:
return {"nickname": unique_nickname, "reason": "使用用户原始昵称作为默认值"}
person_info_manager = PersonInfoManager()
@staticmethod
async def del_one_document(person_id: str):
"""删除指定 person_id 的文档"""
if not person_id:
logger.debug("删除失败person_id 不能为空")
return
def _db_delete_sync(p_id: str):
try:
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
if record:
session.delete(record)
session.commit()
return 1
return 0
except Exception as e:
session.rollback()
logger.error(f"删除 PersonInfo {p_id} 失败 (SQLAlchemy): {e}")
return 0
deleted_count = await asyncio.to_thread(_db_delete_sync, person_id)
if deleted_count > 0:
logger.debug(f"删除成功person_id={person_id} (Peewee)")
else:
logger.debug(f"删除失败:未找到 person_id={person_id} 或删除未影响行 (Peewee)")
@staticmethod
async def get_value(person_id: str, field_name: str):
"""获取指定用户指定字段的值"""
default_value_for_field = person_info_default.get(field_name)
if field_name in JSON_SERIALIZED_FIELDS and default_value_for_field is None:
default_value_for_field = [] # Ensure JSON fields default to [] if not in DB
def _db_get_value_sync(p_id: str, f_name: str):
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
if record:
val = getattr(record, f_name, None)
if f_name in JSON_SERIALIZED_FIELDS:
if isinstance(val, str):
try:
return json.loads(val)
except json.JSONDecodeError:
logger.warning(f"字段 {f_name} for {p_id} 包含无效JSON: {val}. 返回默认值.")
return [] # Default for JSON fields on error
elif val is None: # Field exists in DB but is None
return [] # Default for JSON fields
# If val is already a list/dict (e.g. if somehow set without serialization)
return val # Should ideally not happen if update_one_field is always used
return val
return None # Record not found
try:
value_from_db = await asyncio.to_thread(_db_get_value_sync, person_id, field_name)
if value_from_db is not None:
return value_from_db
if field_name in person_info_default:
return default_value_for_field
logger.warning(f"字段 {field_name} 在 person_info_default 中未定义,且在数据库中未找到。")
return None # Ultimate fallback
except Exception as e:
logger.error(f"获取字段 {field_name} for {person_id} 时出错 (Peewee): {e}")
# Fallback to default in case of any error during DB access
return default_value_for_field if field_name in person_info_default else None
@staticmethod
def get_value_sync(person_id: str, field_name: str):
"""同步获取指定用户指定字段的值"""
default_value_for_field = person_info_default.get(field_name)
if field_name in JSON_SERIALIZED_FIELDS and default_value_for_field is None:
default_value_for_field = []
if record := session.execute(select(PersonInfo).where(PersonInfo.person_id == person_id)).scalar():
val = getattr(record, field_name, None)
if field_name in JSON_SERIALIZED_FIELDS:
if isinstance(val, str):
try:
return json.loads(val)
except json.JSONDecodeError:
logger.warning(f"字段 {field_name} for {person_id} 包含无效JSON: {val}. 返回默认值.")
return []
elif val is None:
return []
return val
return val
if field_name in person_info_default:
return default_value_for_field
logger.warning(f"字段 {field_name} 在 person_info_default 中未定义,且在数据库中未找到。")
return None
@staticmethod
async def get_values(person_id: str, field_names: list) -> dict:
"""获取指定person_id文档的多个字段值若不存在该字段则返回该字段的全局默认值"""
if not person_id:
logger.debug("get_values获取失败person_id不能为空")
return {}
result = {}
def _db_get_record_sync(p_id: str):
return session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
record = await asyncio.to_thread(_db_get_record_sync, person_id)
# 获取 SQLAlchemy 模型的所有字段名
model_fields = [column.name for column in PersonInfo.__table__.columns]
for field_name in field_names:
if field_name not in model_fields:
if field_name in person_info_default:
result[field_name] = copy.deepcopy(person_info_default[field_name])
logger.debug(f"字段'{field_name}'不在SQLAlchemy模型中使用默认配置值。")
else:
logger.debug(f"get_values查询失败字段'{field_name}'未在SQLAlchemy模型和默认配置中定义。")
result[field_name] = None
continue
if record:
value = getattr(record, field_name)
if value is not None:
result[field_name] = value
else:
result[field_name] = copy.deepcopy(person_info_default.get(field_name))
else:
result[field_name] = copy.deepcopy(person_info_default.get(field_name))
return result
@staticmethod
async def get_specific_value_list(
field_name: str,
way: Callable[[Any], bool],
) -> Dict[str, Any]:
"""
获取满足条件的字段值字典
"""
# 获取 SQLAlchemy 模型的所有字段名
model_fields = [column.name for column in PersonInfo.__table__.columns]
if field_name not in model_fields:
logger.error(f"字段检查失败:'{field_name}'未在 PersonInfo SQLAlchemy 模 modelo中定义")
return {}
def _db_get_specific_sync(f_name: str):
found_results = {}
try:
for record in session.execute(select(PersonInfo.person_id, getattr(PersonInfo, f_name))).fetchall():
value = getattr(record, f_name)
if way(value):
found_results[record.person_id] = value
except Exception as e_query:
logger.error(f"数据库查询失败 (SQLAlchemy specific_value_list for {f_name}): {str(e_query)}", exc_info=True)
return found_results
try:
return await asyncio.to_thread(_db_get_specific_sync, field_name)
except Exception as e:
logger.error(f"执行 get_specific_value_list 线程时出错: {str(e)}", exc_info=True)
return {}
async def get_or_create_person(
self, platform: str, user_id: int, nickname: str, user_cardname: str, user_avatar: Optional[str] = None
) -> str:
"""
根据 platform 和 user_id 获取 person_id。
如果对应的用户不存在,则使用提供的可选信息创建新用户。
使用try-except处理竞态条件避免重复创建错误。
"""
person_id = self.get_person_id(platform, user_id)
def _db_get_or_create_sync(p_id: str, init_data: dict):
"""原子性的获取或创建操作"""
# 首先尝试获取现有记录
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
if record:
return record, False # 记录存在,未创建
# 记录不存在,尝试创建
try:
new_person = PersonInfo(**init_data)
session.add(new_person)
session.commit()
return session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar(), True # 创建成功
except Exception as e:
session.rollback()
# 如果创建失败(可能是因为竞态条件),再次尝试获取
if "UNIQUE constraint failed" in str(e):
logger.debug(f"检测到并发创建用户 {p_id},获取现有记录")
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
if record:
return record, False # 其他协程已创建,返回现有记录
# 如果仍然失败,重新抛出异常
raise e
unique_nickname = await self._generate_unique_person_name(nickname)
initial_data = {
"person_id": person_id,
"platform": platform,
"user_id": str(user_id),
"nickname": nickname,
"person_name": unique_nickname, # 使用群昵称作为person_name
"name_reason": "从群昵称获取",
"know_times": 0,
"know_since": int(datetime.datetime.now().timestamp()),
"last_know": int(datetime.datetime.now().timestamp()),
"impression": None,
"points": [],
"forgotten_points": [],
}
# 序列化JSON字段
for key in JSON_SERIALIZED_FIELDS:
if key in initial_data:
if isinstance(initial_data[key], (list, dict)):
initial_data[key] = json.dumps(initial_data[key], ensure_ascii=False)
elif initial_data[key] is None:
initial_data[key] = json.dumps([], ensure_ascii=False)
# 获取 SQLAlchemy 模型的所有字段名
model_fields = [column.name for column in PersonInfo.__table__.columns]
filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields}
record, was_created = await asyncio.to_thread(_db_get_or_create_sync, person_id, filtered_initial_data)
if was_created:
logger.info(f"用户 {platform}:{user_id} (person_id: {person_id}) 不存在,将创建新记录 (Peewee)。")
logger.info(f"已为 {person_id} 创建新记录,初始数据 (filtered for model): {filtered_initial_data}")
else:
logger.debug(f"用户 {platform}:{user_id} (person_id: {person_id}) 已存在,返回现有记录。")
return person_id
async def get_person_info_by_name(self, person_name: str) -> dict | None:
"""根据 person_name 查找用户并返回基本信息 (如果找到)"""
if not person_name:
logger.debug("get_person_info_by_name 获取失败person_name 不能为空")
return None
found_person_id = None
for pid, name_in_cache in self.person_name_list.items():
if name_in_cache == person_name:
found_person_id = pid
break
if not found_person_id:
def _db_find_by_name_sync(p_name_to_find: str):
return session.execute(select(PersonInfo).where(PersonInfo.person_name == p_name_to_find)).scalar()
record = await asyncio.to_thread(_db_find_by_name_sync, person_name)
if record:
found_person_id = record.person_id
if (
found_person_id not in self.person_name_list
or self.person_name_list[found_person_id] != person_name
):
self.person_name_list[found_person_id] = person_name
else:
logger.debug(f"数据库中也未找到名为 '{person_name}' 的用户 (Peewee)")
return None
if found_person_id:
required_fields = [
"person_id",
"platform",
"user_id",
"nickname",
"user_cardname",
"user_avatar",
"person_name",
"name_reason",
]
# 获取 SQLAlchemy 模型的所有字段名
model_fields = [column.name for column in PersonInfo.__table__.columns]
valid_fields_to_get = [
f
for f in required_fields
if f in model_fields or f in person_info_default
]
person_data = await self.get_values(found_person_id, valid_fields_to_get)
if person_data:
final_result = {key: person_data.get(key) for key in required_fields}
return final_result
else:
logger.warning(f"找到了 person_id '{found_person_id}' 但 get_values 返回空 (Peewee)")
return None
logger.error(f"逻辑错误:未能为 '{person_name}' 确定 person_id (Peewee)")
return None
person_info_manager = None
def get_person_info_manager():
global person_info_manager
if person_info_manager is None:
person_info_manager = PersonInfoManager()
return person_info_manager
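The module now exposes a lazily created singleton instead of instantiating at import time; a brief usage sketch (ids are made up):

manager = get_person_info_manager()
pid = manager.get_person_id("qq", 10001)
# repeated calls hand back the same instance
assert get_person_info_manager() is manager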

View File

@@ -5,385 +5,25 @@
from src.plugin_system.apis import database_api
records = await database_api.db_query(ActionRecords, query_type="get")
record = await database_api.db_save(ActionRecords, data={"action_id": "123"})
注意:此模块现在使用 SQLAlchemy 实现,提供更好的连接管理和错误处理
"""
import traceback
from typing import Dict, List, Any, Union, Type, Optional
from src.common.logger import get_logger
from peewee import Model, DoesNotExist
from src.common.database.sqlalchemy_database_api import (
db_query,
db_save,
db_get,
store_action_info,
get_model_class,
MODEL_MAPPING
)
logger = get_logger("database_api")
# =============================================================================
# 通用数据库查询API函数
# =============================================================================
async def db_query(
model_class: Type[Model],
data: Optional[Dict[str, Any]] = None,
query_type: Optional[str] = "get",
filters: Optional[Dict[str, Any]] = None,
limit: Optional[int] = None,
order_by: Optional[List[str]] = None,
single_result: Optional[bool] = False,
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
"""执行数据库查询操作
这个方法提供了一个通用接口来执行数据库操作,包括查询、创建、更新和删除记录。
Args:
model_class: Peewee 模型类,例如 ActionRecords, Messages 等
data: 用于创建或更新的数据字典
query_type: 查询类型,可选值: "get", "create", "update", "delete", "count"
filters: 过滤条件字典,键为字段名,值为要匹配的值
limit: 限制结果数量
order_by: 排序字段,前缀'-'表示降序,例如'-time'表示按时间字段即time字段降序
single_result: 是否只返回单个结果
Returns:
根据查询类型返回不同的结果:
- "get": 返回查询结果列表或单个结果(如果 single_result=True
- "create": 返回创建的记录
- "update": 返回受影响的行数
- "delete": 返回受影响的行数
- "count": 返回记录数量
"""
"""
示例:
# 查询最近10条消息
messages = await database_api.db_query(
Messages,
query_type="get",
filters={"chat_id": chat_stream.stream_id},
limit=10,
order_by=["-time"]
)
# 创建一条记录
new_record = await database_api.db_query(
ActionRecords,
data={"action_id": "123", "time": time.time(), "action_name": "TestAction"},
query_type="create",
)
# 更新记录
updated_count = await database_api.db_query(
ActionRecords,
data={"action_done": True},
query_type="update",
filters={"action_id": "123"},
)
# 删除记录
deleted_count = await database_api.db_query(
ActionRecords,
query_type="delete",
filters={"action_id": "123"}
)
# 计数
count = await database_api.db_query(
Messages,
query_type="count",
filters={"chat_id": chat_stream.stream_id}
)
"""
try:
if query_type not in ["get", "create", "update", "delete", "count"]:
raise ValueError("query_type must be 'get' or 'create' or 'update' or 'delete' or 'count'")
# 构建基本查询
if query_type in ["get", "update", "delete", "count"]:
query = model_class.select()
# 应用过滤条件
if filters:
for field, value in filters.items():
query = query.where(getattr(model_class, field) == value)
# 执行查询
if query_type == "get":
# 应用排序
if order_by:
for field in order_by:
if field.startswith("-"):
query = query.order_by(getattr(model_class, field[1:]).desc())
else:
query = query.order_by(getattr(model_class, field))
# 应用限制
if limit:
query = query.limit(limit)
# 执行查询
results = list(query.dicts())
# 返回结果
if single_result:
return results[0] if results else None
return results
elif query_type == "create":
if not data:
raise ValueError("创建记录需要提供data参数")
# 创建记录
record = model_class.create(**data)
# 返回创建的记录
return model_class.select().where(model_class.id == record.id).dicts().get() # type: ignore
elif query_type == "update":
if not data:
raise ValueError("更新记录需要提供data参数")
# 更新记录
return query.update(**data).execute()
elif query_type == "delete":
# 删除记录
return query.delete().execute()
elif query_type == "count":
# 计数
return query.count()
else:
raise ValueError(f"不支持的查询类型: {query_type}")
except DoesNotExist:
# 记录不存在
return None if query_type == "get" and single_result else []
except Exception as e:
logger.error(f"[DatabaseAPI] 数据库操作出错: {e}")
traceback.print_exc()
# 根据查询类型返回合适的默认值
if query_type == "get":
return None if single_result else []
elif query_type in ["create", "update", "delete", "count"]:
return None
return None
async def db_save(
model_class: Type[Model], data: Dict[str, Any], key_field: Optional[str] = None, key_value: Optional[Any] = None
) -> Optional[Dict[str, Any]]:
# sourcery skip: inline-immediately-returned-variable
"""保存数据到数据库(创建或更新)
如果提供了key_field和key_value会先尝试查找匹配的记录进行更新
如果没有找到匹配记录或未提供key_field和key_value则创建新记录。
Args:
model_class: Peewee模型类如ActionRecords, Messages等
data: 要保存的数据字典
key_field: 用于查找现有记录的字段名,例如"action_id"
key_value: 用于查找现有记录的字段值
Returns:
Dict[str, Any]: 保存后的记录数据
None: 如果操作失败
示例:
# 创建或更新一条记录
record = await database_api.db_save(
ActionRecords,
{
"action_id": "123",
"time": time.time(),
"action_name": "TestAction",
"action_done": True
},
key_field="action_id",
key_value="123"
)
"""
try:
# 如果提供了key_field和key_value尝试更新现有记录
if key_field and key_value is not None:
if existing_records := list(
model_class.select().where(getattr(model_class, key_field) == key_value).limit(1)
):
# 更新现有记录
existing_record = existing_records[0]
for field, value in data.items():
setattr(existing_record, field, value)
existing_record.save()
# 返回更新后的记录
updated_record = model_class.select().where(model_class.id == existing_record.id).dicts().get() # type: ignore
return updated_record
# 如果没有找到现有记录或未提供key_field和key_value创建新记录
new_record = model_class.create(**data)
# 返回创建的记录
created_record = model_class.select().where(model_class.id == new_record.id).dicts().get() # type: ignore
return created_record
except Exception as e:
logger.error(f"[DatabaseAPI] 保存数据库记录出错: {e}")
traceback.print_exc()
return None
async def db_get(
model_class: Type[Model],
filters: Optional[Dict[str, Any]] = None,
limit: Optional[int] = None,
order_by: Optional[str] = None,
single_result: Optional[bool] = False,
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
"""从数据库获取记录
这是db_query方法的简化版本专注于数据检索操作。
Args:
model_class: Peewee模型类
filters: 过滤条件,字段名和值的字典
limit: 结果数量限制
order_by: 排序字段,前缀'-'表示降序,例如'-time'表示按时间字段即time字段降序
single_result: 是否只返回单个结果如果为True则返回单个记录字典或None否则返回记录字典列表或空列表
Returns:
如果single_result为True返回单个记录字典或None
否则返回记录字典列表或空列表。
示例:
# 获取单个记录
record = await database_api.db_get(
ActionRecords,
filters={"action_id": "123"},
limit=1
)
# 获取最近10条记录
records = await database_api.db_get(
Messages,
filters={"chat_id": chat_stream.stream_id},
limit=10,
order_by="-time",
)
"""
try:
# 构建查询
query = model_class.select()
# 应用过滤条件
if filters:
for field, value in filters.items():
query = query.where(getattr(model_class, field) == value)
# 应用排序
if order_by:
if order_by.startswith("-"):
query = query.order_by(getattr(model_class, order_by[1:]).desc())
else:
query = query.order_by(getattr(model_class, order_by))
# 应用限制
if limit:
query = query.limit(limit)
# 执行查询
results = list(query.dicts())
# 返回结果
if single_result:
return results[0] if results else None
return results
except Exception as e:
logger.error(f"[DatabaseAPI] 获取数据库记录出错: {e}")
traceback.print_exc()
return None if single_result else []
async def store_action_info(
chat_stream=None,
action_build_into_prompt: bool = False,
action_prompt_display: str = "",
action_done: bool = True,
thinking_id: str = "",
action_data: Optional[dict] = None,
action_name: str = "",
) -> Optional[Dict[str, Any]]:
"""存储动作信息到数据库
将Action执行的相关信息保存到ActionRecords表中用于后续的记忆和上下文构建。
Args:
chat_stream: 聊天流对象,包含聊天相关信息
action_build_into_prompt: 是否将此动作构建到提示中
action_prompt_display: 动作的提示显示文本
action_done: 动作是否完成
thinking_id: 关联的思考ID
action_data: 动作数据字典
action_name: 动作名称
Returns:
Dict[str, Any]: 保存的记录数据
None: 如果保存失败
示例:
record = await database_api.store_action_info(
chat_stream=chat_stream,
action_build_into_prompt=True,
action_prompt_display="执行了回复动作",
action_done=True,
thinking_id="thinking_123",
action_data={"content": "Hello"},
action_name="reply_action"
)
"""
try:
import time
import json
from src.common.database.database_model import ActionRecords
# 构建动作记录数据
record_data = {
"action_id": thinking_id or str(int(time.time() * 1000000)), # 使用thinking_id或生成唯一ID
"time": time.time(),
"action_name": action_name,
"action_data": json.dumps(action_data or {}, ensure_ascii=False),
"action_done": action_done,
"action_build_into_prompt": action_build_into_prompt,
"action_prompt_display": action_prompt_display,
}
# 从chat_stream获取聊天信息
if chat_stream:
record_data.update(
{
"chat_id": getattr(chat_stream, "stream_id", ""),
"chat_info_stream_id": getattr(chat_stream, "stream_id", ""),
"chat_info_platform": getattr(chat_stream, "platform", ""),
}
)
else:
# 如果没有chat_stream设置默认值
record_data.update(
{
"chat_id": "",
"chat_info_stream_id": "",
"chat_info_platform": "",
}
)
# 使用已有的db_save函数保存记录
saved_record = await db_save(
ActionRecords, data=record_data, key_field="action_id", key_value=record_data["action_id"]
)
if saved_record:
logger.debug(f"[DatabaseAPI] 成功存储动作信息: {action_name} (ID: {record_data['action_id']})")
else:
logger.error(f"[DatabaseAPI] 存储动作信息失败: {action_name}")
return saved_record
except Exception as e:
logger.error(f"[DatabaseAPI] 存储动作信息时发生错误: {e}")
traceback.print_exc()
return None
# 保持向后兼容性
__all__ = [
'db_query',
'db_save',
'db_get',
'store_action_info',
'get_model_class',
'MODEL_MAPPING'
]
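Since the module body is now re-exports, call sites should keep working unchanged. A hedged example, assuming the re-exported functions keep the documented signatures and that ActionRecords lives in sqlalchemy_models:

import asyncio
from src.common.database.sqlalchemy_models import ActionRecords  # assumed model path

async def recent_actions():
    # same call shape as before the migration; returns a list of dicts
    return await db_query(ActionRecords, query_type="get",
                          filters={"action_done": True},
                          limit=5, order_by=["-time"])

# asyncio.run(recent_actions())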

View File

@@ -8,10 +8,12 @@ from src.plugin_system.core.plugin_manager import plugin_manager
from src.plugin_system.core.component_registry import component_registry
from src.plugin_system.core.events_manager import events_manager
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
from src.plugin_system.core.plugin_hot_reload import hot_reload_manager
__all__ = [
"plugin_manager",
"component_registry",
"events_manager",
"global_announcement_manager",
"hot_reload_manager",
]

View File

@@ -237,35 +237,55 @@ class ComponentRegistry:
logger.warning(f"组件 {component_name} 未注册,无法移除")
return False
try:
# 根据组件类型进行特定的清理操作
match component_type:
case ComponentType.ACTION:
self._action_registry.pop(component_name)
self._default_actions.pop(component_name)
# 移除Action注册
self._action_registry.pop(component_name, None)
self._default_actions.pop(component_name, None)
logger.debug(f"已移除Action组件: {component_name}")
case ComponentType.COMMAND:
self._command_registry.pop(component_name)
# 移除Command注册和模式
self._command_registry.pop(component_name, None)
keys_to_remove = [k for k, v in self._command_patterns.items() if v == component_name]
for key in keys_to_remove:
self._command_patterns.pop(key)
self._command_patterns.pop(key, None)
logger.debug(f"已移除Command组件: {component_name} (清理了 {len(keys_to_remove)} 个模式)")
case ComponentType.TOOL:
self._tool_registry.pop(component_name)
self._llm_available_tools.pop(component_name)
# 移除Tool注册
self._tool_registry.pop(component_name, None)
self._llm_available_tools.pop(component_name, None)
logger.debug(f"已移除Tool组件: {component_name}")
case ComponentType.EVENT_HANDLER:
# 移除EventHandler注册和事件订阅
from .events_manager import events_manager # 延迟导入防止循环导入问题
self._event_handler_registry.pop(component_name)
self._enabled_event_handlers.pop(component_name)
await events_manager.unregister_event_subscriber(component_name)
self._event_handler_registry.pop(component_name, None)
self._enabled_event_handlers.pop(component_name, None)
try:
await events_manager.unregister_event_subscriber(component_name)
logger.debug(f"已移除EventHandler组件: {component_name}")
except Exception as e:
logger.warning(f"移除EventHandler事件订阅时出错: {e}")
case _:
logger.warning(f"未知的组件类型: {component_type}")
return False
# 移除通用注册信息
namespaced_name = f"{component_type}.{component_name}"
self._components.pop(namespaced_name)
self._components_by_type[component_type].pop(component_name)
self._components_classes.pop(namespaced_name)
logger.info(f"组件 {component_name} 已移除")
self._components.pop(namespaced_name, None)
self._components_by_type[component_type].pop(component_name, None)
self._components_classes.pop(namespaced_name, None)
logger.info(f"组件 {component_name} ({component_type}) 已完全移除")
return True
except KeyError as e:
logger.warning(f"移除组件时未找到组件: {component_name}, 发生错误: {e}")
return False
except Exception as e:
logger.error(f"移除组件 {component_name} 时发生错误: {e}")
logger.error(f"移除组件 {component_name} ({component_type}) 时发生错误: {e}")
return False
def remove_plugin_registry(self, plugin_name: str) -> bool:
@@ -615,5 +635,54 @@ class ComponentRegistry:
"enabled_plugins": len([p for p in self._plugins.values() if p.enabled]),
}
# === 组件移除相关 ===
async def unregister_plugin(self, plugin_name: str) -> bool:
"""卸载插件及其所有组件
Args:
plugin_name: 插件名称
Returns:
bool: 是否成功卸载
"""
plugin_info = self.get_plugin_info(plugin_name)
if not plugin_info:
logger.warning(f"插件 {plugin_name} 未注册,无法卸载")
return False
logger.info(f"开始卸载插件: {plugin_name}")
# 记录卸载失败的组件
failed_components = []
# 逐个移除插件的所有组件
for component_info in plugin_info.components:
try:
success = await self.remove_component(
component_info.name,
component_info.component_type,
plugin_name,
)
if not success:
failed_components.append(f"{component_info.component_type}.{component_info.name}")
except Exception as e:
logger.error(f"移除组件 {component_info.name} 时发生异常: {e}")
failed_components.append(f"{component_info.component_type}.{component_info.name}")
# 移除插件注册信息
plugin_removed = self.remove_plugin_registry(plugin_name)
if failed_components:
logger.warning(f"插件 {plugin_name} 部分组件卸载失败: {failed_components}")
return False
elif not plugin_removed:
logger.error(f"插件 {plugin_name} 注册信息移除失败")
return False
else:
logger.info(f"插件 {plugin_name} 卸载成功")
return True
# 创建全局组件注册中心实例
component_registry = ComponentRegistry()
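A hedged sketch of driving the new unload path from async code; "tts_plugin" is just an example name:

import asyncio

async def demo_unload():
    # False if any component failed to unregister or the plugin's
    # registry entry could not be removed
    ok = await component_registry.unregister_plugin("tts_plugin")
    print("unloaded:", ok)

# asyncio.run(demo_unload())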

View File

@@ -33,7 +33,7 @@ class EventsManager:
if handler_name in self._handler_mapping:
logger.warning(f"事件处理器 {handler_name} 已存在,跳过注册")
return False
return True
if not issubclass(handler_class, BaseEventHandler):
logger.error(f"{handler_class.__name__} 不是 BaseEventHandler 的子类")

View File

@@ -0,0 +1,242 @@
"""
插件热重载模块
使用 Watchdog 监听插件目录变化,自动重载插件
"""
import os
import time
from pathlib import Path
from threading import Thread
from typing import Dict, Optional, Set
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler
from src.common.logger import get_logger
from .plugin_manager import plugin_manager
logger = get_logger("plugin_hot_reload")
class PluginFileHandler(FileSystemEventHandler):
"""插件文件变化处理器"""
def __init__(self, hot_reload_manager):
super().__init__()
self.hot_reload_manager = hot_reload_manager
self.pending_reloads: Set[str] = set() # 待重载的插件名称
self.last_reload_time: Dict[str, float] = {} # 上次重载时间
self.debounce_delay = 1.0 # 防抖延迟(秒)
def on_modified(self, event):
"""文件修改事件"""
if not event.is_directory and (event.src_path.endswith('.py') or event.src_path.endswith('.toml')):
self._handle_file_change(event.src_path, "modified")
def on_created(self, event):
"""文件创建事件"""
if not event.is_directory and (event.src_path.endswith('.py') or event.src_path.endswith('.toml')):
self._handle_file_change(event.src_path, "created")
def on_deleted(self, event):
"""文件删除事件"""
if not event.is_directory and (event.src_path.endswith('.py') or event.src_path.endswith('.toml')):
self._handle_file_change(event.src_path, "deleted")
def _handle_file_change(self, file_path: str, change_type: str):
"""处理文件变化"""
try:
# 获取插件名称
plugin_name = self._get_plugin_name_from_path(file_path)
if not plugin_name:
return
current_time = time.time()
last_time = self.last_reload_time.get(plugin_name, 0)
# 防抖处理,避免频繁重载
if current_time - last_time < self.debounce_delay:
return
file_name = Path(file_path).name
logger.info(f"📁 检测到插件文件变化: {file_name} ({change_type})")
# 如果是删除事件,处理关键文件删除
if change_type == "deleted":
if file_name == "plugin.py":
if plugin_name in plugin_manager.loaded_plugins:
logger.info(f"🗑️ 插件主文件被删除,卸载插件: {plugin_name}")
self.hot_reload_manager._unload_plugin(plugin_name)
return
elif file_name == "manifest.toml":
if plugin_name in plugin_manager.loaded_plugins:
logger.info(f"🗑️ 插件配置文件被删除,卸载插件: {plugin_name}")
self.hot_reload_manager._unload_plugin(plugin_name)
return
# 对于修改和创建事件,都进行重载
# 添加到待重载列表
self.pending_reloads.add(plugin_name)
self.last_reload_time[plugin_name] = current_time
# 延迟重载,避免文件正在写入时重载
reload_thread = Thread(
target=self._delayed_reload,
args=(plugin_name,),
daemon=True
)
reload_thread.start()
except Exception as e:
logger.error(f"❌ 处理文件变化时发生错误: {e}")
def _delayed_reload(self, plugin_name: str):
"""延迟重载插件"""
try:
time.sleep(self.debounce_delay)
if plugin_name in self.pending_reloads:
self.pending_reloads.remove(plugin_name)
self.hot_reload_manager._reload_plugin(plugin_name)
except Exception as e:
logger.error(f"❌ 延迟重载插件 {plugin_name} 时发生错误: {e}")
def _get_plugin_name_from_path(self, file_path: str) -> str:
"""从文件路径获取插件名称"""
try:
path = Path(file_path)
# 检查是否在监听的插件目录中
plugin_root = Path(self.hot_reload_manager.watch_directory)
if not path.is_relative_to(plugin_root):
return ""
# 获取插件目录名(插件名)
relative_path = path.relative_to(plugin_root)
plugin_name = relative_path.parts[0]
# 确认这是一个有效的插件目录(检查是否有 plugin.py 或 manifest.toml
plugin_dir = plugin_root / plugin_name
if plugin_dir.is_dir() and ((plugin_dir / "plugin.py").exists() or (plugin_dir / "manifest.toml").exists()):
return plugin_name
return ""
except Exception:
return ""
class PluginHotReloadManager:
"""插件热重载管理器"""
def __init__(self, watch_directory: Optional[str] = None):
# 默认监听当前工作目录下的 plugins 目录,也可通过参数显式指定
self.watch_directory = watch_directory or os.path.join(os.getcwd(), "plugins")
self.observer = None
self.file_handler = None
self.is_running = False
# 确保监听目录存在
if not os.path.exists(self.watch_directory):
os.makedirs(self.watch_directory, exist_ok=True)
logger.info(f"创建插件监听目录: {self.watch_directory}")
def start(self):
"""启动热重载监听"""
if self.is_running:
logger.warning("插件热重载已经在运行中")
return
try:
self.observer = Observer()
self.file_handler = PluginFileHandler(self)
self.observer.schedule(
self.file_handler,
self.watch_directory,
recursive=True
)
self.observer.start()
self.is_running = True
logger.info("🚀 插件热重载已启动,监听目录: plugins")
except Exception as e:
logger.error(f"❌ 启动插件热重载失败: {e}")
self.is_running = False
def stop(self):
"""停止热重载监听"""
if not self.is_running:
return
if self.observer:
self.observer.stop()
self.observer.join()
self.is_running = False
def _reload_plugin(self, plugin_name: str):
"""重载指定插件"""
try:
logger.info(f"🔄 开始重载插件: {plugin_name}")
if plugin_manager.reload_plugin(plugin_name):
logger.info(f"✅ 插件重载成功: {plugin_name}")
else:
logger.error(f"❌ 插件重载失败: {plugin_name}")
except Exception as e:
logger.error(f"❌ 重载插件 {plugin_name} 时发生错误: {e}")
def _unload_plugin(self, plugin_name: str):
"""卸载指定插件"""
try:
logger.info(f"🗑️ 开始卸载插件: {plugin_name}")
if plugin_manager.unload_plugin(plugin_name):
logger.info(f"✅ 插件卸载成功: {plugin_name}")
else:
logger.error(f"❌ 插件卸载失败: {plugin_name}")
except Exception as e:
logger.error(f"❌ 卸载插件 {plugin_name} 时发生错误: {e}")
def reload_all_plugins(self):
"""重载所有插件"""
try:
logger.info("🔄 开始重载所有插件...")
# 获取当前已加载的插件列表
loaded_plugins = list(plugin_manager.loaded_plugins.keys())
success_count = 0
fail_count = 0
for plugin_name in loaded_plugins:
if plugin_manager.reload_plugin(plugin_name):
success_count += 1
else:
fail_count += 1
logger.info(f"✅ 插件重载完成: 成功 {success_count} 个,失败 {fail_count}")
except Exception as e:
logger.error(f"❌ 重载所有插件时发生错误: {e}")
def get_status(self) -> dict:
"""获取热重载状态"""
return {
"is_running": self.is_running,
"watch_directory": self.watch_directory,
"loaded_plugins": len(plugin_manager.loaded_plugins),
"failed_plugins": len(plugin_manager.failed_plugins),
}
# 全局热重载管理器实例
hot_reload_manager = PluginHotReloadManager()
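A usage sketch for the manager above; the watched path depends on the working directory at startup:

hot_reload_manager.start()             # watch ./plugins recursively
print(hot_reload_manager.get_status())
# editing plugins/<name>/plugin.py triggers a debounced (~1s) reload
hot_reload_manager.stop()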

View File

@@ -1,5 +1,6 @@
import os
import traceback
import sys
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Type, Any
from importlib.util import spec_from_file_location, module_from_spec
@@ -488,6 +489,105 @@ class PluginManager:
else:
logger.info(f"✅ 插件加载成功: {plugin_name}")
# === 插件卸载和重载管理 ===
def unload_plugin(self, plugin_name: str) -> bool:
"""卸载指定插件
Args:
plugin_name: 插件名称
Returns:
bool: 卸载是否成功
"""
if plugin_name not in self.loaded_plugins:
logger.warning(f"插件 {plugin_name} 未加载,无需卸载")
return False
try:
# 获取插件实例
plugin_instance = self.loaded_plugins[plugin_name]
# 调用插件的清理方法(如果有的话)
if hasattr(plugin_instance, 'on_unload'):
plugin_instance.on_unload()
# 从组件注册表中移除插件的所有组件
component_registry.unregister_plugin(plugin_name)
# 从已加载插件中移除
del self.loaded_plugins[plugin_name]
# 从失败列表中移除(如果存在)
if plugin_name in self.failed_plugins:
del self.failed_plugins[plugin_name]
logger.info(f"✅ 插件卸载成功: {plugin_name}")
return True
except Exception as e:
logger.error(f"❌ 插件卸载失败: {plugin_name} - {str(e)}")
return False
def reload_plugin(self, plugin_name: str) -> bool:
"""重载指定插件
Args:
plugin_name: 插件名称
Returns:
bool: 重载是否成功
"""
try:
# 先卸载插件
if plugin_name in self.loaded_plugins:
self.unload_plugin(plugin_name)
# 清除Python模块缓存
plugin_path = self.plugin_paths.get(plugin_name)
if plugin_path:
plugin_file = os.path.join(plugin_path, "plugin.py")
if os.path.exists(plugin_file):
# 从sys.modules中移除相关模块
modules_to_remove = []
plugin_module_prefix = ".".join(Path(plugin_file).parent.parts)
for module_name in sys.modules:
if module_name.startswith(plugin_module_prefix):
modules_to_remove.append(module_name)
for module_name in modules_to_remove:
del sys.modules[module_name]
# 从插件类注册表中移除
if plugin_name in self.plugin_classes:
del self.plugin_classes[plugin_name]
# 重新加载插件模块
if self._load_plugin_module_file(plugin_file):
# 重新加载插件实例
success, _ = self.load_registered_plugin_classes(plugin_name)
if success:
logger.info(f"🔄 插件重载成功: {plugin_name}")
return True
else:
logger.error(f"❌ 插件重载失败: {plugin_name} - 实例化失败")
return False
else:
logger.error(f"❌ 插件重载失败: {plugin_name} - 模块加载失败")
return False
else:
logger.error(f"❌ 插件重载失败: {plugin_name} - 插件文件不存在")
return False
else:
logger.error(f"❌ 插件重载失败: {plugin_name} - 插件路径未知")
return False
except Exception as e:
logger.error(f"❌ 插件重载失败: {plugin_name} - {str(e)}")
logger.debug("详细错误信息: ", exc_info=True)
return False
# 全局插件管理器实例
plugin_manager = PluginManager()

View File

@@ -1,149 +1,149 @@
from src.plugin_system.apis.plugin_register_api import register_plugin
from src.plugin_system.base.base_plugin import BasePlugin
from src.plugin_system.base.component_types import ComponentInfo
from src.common.logger import get_logger
from src.plugin_system.base.base_action import BaseAction, ActionActivationType, ChatMode
from src.plugin_system.base.config_types import ConfigField
from typing import Tuple, List, Type
logger = get_logger("tts")
class TTSAction(BaseAction):
"""TTS语音转换动作处理类"""
# 激活设置
focus_activation_type = ActionActivationType.LLM_JUDGE
normal_activation_type = ActionActivationType.KEYWORD
mode_enable = ChatMode.ALL
parallel_action = False
# 动作基本信息
action_name = "tts_action"
action_description = "将文本转换为语音进行播放,适用于需要语音输出的场景"
# 关键词配置 - Normal模式下使用关键词触发
activation_keywords = ["语音", "tts", "播报", "读出来", "语音播放", "", "朗读"]
keyword_case_sensitive = False
# 动作参数定义
action_parameters = {
"text": "需要转换为语音的文本内容,必填,内容应当适合语音播报,语句流畅、清晰",
}
# 动作使用场景
action_require = [
"当需要发送语音信息时使用",
"当用户明确要求使用语音功能时使用",
"当表达内容更适合用语音而不是文字传达时使用",
"当用户想听到语音回答而非阅读文本时使用",
]
# 关联类型
associated_types = ["tts_text"]
async def execute(self) -> Tuple[bool, str]:
"""处理TTS文本转语音动作"""
logger.info(f"{self.log_prefix} 执行TTS动作: {self.reasoning}")
# 获取要转换的文本
text = self.action_data.get("text")
if not text:
logger.error(f"{self.log_prefix} 执行TTS动作时未提供文本内容")
return False, "执行TTS动作失败未提供文本内容"
# 确保文本适合TTS使用
processed_text = self._process_text_for_tts(text)
try:
# 发送TTS消息
await self.send_custom(message_type="tts_text", content=processed_text)
# 记录动作信息
await self.store_action_info(
action_build_into_prompt=True, action_prompt_display="已经发送了语音消息。", action_done=True
)
logger.info(f"{self.log_prefix} TTS动作执行成功文本长度: {len(processed_text)}")
return True, "TTS动作执行成功"
except Exception as e:
logger.error(f"{self.log_prefix} 执行TTS动作时出错: {e}")
return False, f"执行TTS动作时出错: {e}"
def _process_text_for_tts(self, text: str) -> str:
"""
处理文本使其更适合TTS使用
- 移除不必要的特殊字符和表情符号
- 修正标点符号以提高语音质量
- 优化文本结构使语音更流畅
"""
# 这里可以添加文本处理逻辑
# 例如:移除多余的标点、表情符号,优化语句结构等
# 简单示例实现
processed_text = text
# 移除多余的标点符号
import re
processed_text = re.sub(r"([!?,.;:。!?,、;:])\1+", r"\1", processed_text)
# 确保句子结尾有合适的标点
if not any(processed_text.endswith(end) for end in [".", "?", "!", "。", "?", "!"]):
processed_text = f"{processed_text}。"
return processed_text
@register_plugin
class TTSPlugin(BasePlugin):
"""TTS插件
- 这是文字转语音插件
- Normal模式下依靠关键词触发
- Focus模式下由LLM判断触发
- 具有一定的文本预处理能力
"""
# 插件基本信息
plugin_name: str = "tts_plugin" # 内部标识符
enable_plugin: bool = True
dependencies: list[str] = [] # 插件依赖列表
python_dependencies: list[str] = [] # Python包依赖列表
config_file_name: str = "config.toml"
# 配置节描述
config_section_descriptions = {
"plugin": "插件基本信息配置",
"components": "组件启用控制",
"logging": "日志记录相关配置",
}
# 配置Schema定义
config_schema: dict = {
"plugin": {
"name": ConfigField(type=str, default="tts_plugin", description="插件名称", required=True),
"version": ConfigField(type=str, default="0.1.0", description="插件版本号"),
"enabled": ConfigField(type=bool, default=True, description="是否启用插件"),
"description": ConfigField(type=str, default="文字转语音插件", description="插件描述", required=True),
},
"components": {"enable_tts": ConfigField(type=bool, default=True, description="是否启用TTS Action")},
"logging": {
"level": ConfigField(
type=str, default="INFO", description="日志记录级别", choices=["DEBUG", "INFO", "WARNING", "ERROR"]
),
"prefix": ConfigField(type=str, default="[TTS]", description="日志记录前缀"),
},
}
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
"""返回插件包含的组件列表"""
# 从配置获取组件启用状态
enable_tts = self.get_config("components.enable_tts", True)
components = [] # 添加Action组件
if enable_tts:
components.append((TTSAction.get_action_info(), TTSAction))
return components
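For reference, the two cleanup steps in _process_text_for_tts can be replayed standalone (sample strings are made up):

import re

text = "真的吗??太好了"
text = re.sub(r"([!?,.;:。!?,、;:])\1+", r"\1", text)  # collapse repeated punctuation
if not any(text.endswith(end) for end in [".", "?", "!", "。", "?", "!"]):
    text = f"{text}。"  # close the sentence
print(text)  # -> 真的吗?太好了。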