Initialize

This commit is contained in:
雅诺狐
2025-08-11 19:34:18 +08:00
committed by Windpicker-owo
parent ef7a3aee23
commit 23ee3767ef
77 changed files with 10000 additions and 7525 deletions

View File

@@ -12,9 +12,10 @@ import binascii
from typing import Optional, Tuple, List, Any
from PIL import Image
from rich.traceback import install
from src.common.database.database_model import Emoji
from src.common.database.database import db as peewee_db
from sqlalchemy import select
from src.common.database.database import db
from src.common.database.sqlalchemy_database_api import get_session
from src.common.database.sqlalchemy_models import Emoji, Images
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.chat.utils.utils_image import image_path_to_base64, get_image_manager
@@ -29,6 +30,8 @@ EMOJI_DIR = os.path.join(BASE_DIR, "emoji") # 表情包存储目录
EMOJI_REGISTERED_DIR = os.path.join(BASE_DIR, "emoji_registed") # 已注册表情包的存储目录
MAX_EMOJI_FOR_PROMPT = 20 # 图片替换 prompt 中允许的最大表情包描述数量
session = get_session()
"""
还没经过测试,有些地方数据库和内存数据同步可能不完全
@@ -151,7 +154,7 @@ class MaiEmoji:
# 准备数据库记录 for emoji collection
emotion_str = ",".join(self.emotion) if self.emotion else ""
Emoji.create(
emoji = Emoji(
emoji_hash=self.hash,
full_path=self.full_path,
format=self.format,
@@ -165,6 +168,8 @@ class MaiEmoji:
usage_count=self.usage_count,
last_used_time=self.last_used_time,
)
session.add(emoji)
session.commit()
logger.info(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})")
@@ -200,7 +205,7 @@ class MaiEmoji:
# 2. 删除数据库记录
try:
will_delete_emoji = Emoji.get(Emoji.emoji_hash == self.hash)
will_delete_emoji = session.execute(select(Emoji).where(Emoji.emoji_hash == self.hash)).scalar_one_or_none()
if will_delete_emoji is None:  # scalar_one_or_none 返回 None,不会抛出 DoesNotExist
    logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。")
    result = 0
else:
    session.delete(will_delete_emoji)
    session.commit()
    result = 1  # 删除的行数
except Exception as e:
    logger.error(f"[删除] 删除数据库记录失败: {e}")
    result = 0
@@ -248,7 +253,6 @@ def _emoji_objects_to_readable_list(emoji_objects: List["MaiEmoji"]) -> List[str
def _to_emoji_objects(data: Any) -> Tuple[List["MaiEmoji"], int]:
emoji_objects = []
load_errors = 0
# data is now an iterable of SQLAlchemy Emoji model instances
emoji_data_list = list(data)
for emoji_data in emoji_data_list: # emoji_data is an Emoji model instance
@@ -393,12 +397,17 @@ class EmojiManager:
def initialize(self) -> None:
"""初始化数据库连接和表情目录"""
peewee_db.connect(reuse_if_open=True)
if peewee_db.is_closed():
raise RuntimeError("数据库连接失败")
_ensure_emoji_dir()
Emoji.create_table(safe=True) # Ensures table exists
self._initialized = True
try:
db.connect(reuse_if_open=True)
if db.is_closed():
raise RuntimeError("数据库连接失败")
_ensure_emoji_dir()
self._initialized = True # 标记为已初始化
logger.info("EmojiManager初始化成功")
except Exception as e:
logger.error(f"EmojiManager初始化失败: {e}")
self._initialized = False
raise
def _ensure_db(self) -> None:
"""确保数据库已初始化"""
@@ -410,7 +419,7 @@ class EmojiManager:
def record_usage(self, emoji_hash: str) -> None:
"""记录表情使用次数"""
try:
emoji_update = Emoji.get(Emoji.emoji_hash == emoji_hash)
emoji_update = session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalar_one_or_none()
if emoji_update:
    emoji_update.usage_count += 1
    emoji_update.last_used_time = time.time()  # 更新最后使用时间
    session.commit()  # SQLAlchemy 对象没有 .save(),直接提交会话即可
@@ -644,10 +653,10 @@ class EmojiManager:
"""获取所有表情包并初始化为MaiEmoji类对象更新 self.emoji_objects"""
try:
self._ensure_db()
logger.debug("[数据库] 开始加载所有表情包记录 (Peewee)...")
logger.debug("[数据库] 开始加载所有表情包记录 ...")
emoji_peewee_instances = Emoji.select()
emoji_objects, load_errors = _to_emoji_objects(emoji_peewee_instances)
emoji_instances = session.execute(select(Emoji)).scalars().all()
emoji_objects, load_errors = _to_emoji_objects(emoji_instances)
# 更新内存中的列表和数量
self.emoji_objects = emoji_objects
@@ -675,15 +684,15 @@ class EmojiManager:
self._ensure_db()
if emoji_hash:
query = Emoji.select().where(Emoji.emoji_hash == emoji_hash)
query = session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalars().all()
else:
logger.warning(
"[查询] 未提供 hash将尝试加载所有表情包建议使用 get_all_emoji_from_db 更新管理器状态。"
)
query = Emoji.select()
query = session.execute(select(Emoji)).scalars().all()
emoji_peewee_instances = query
emoji_objects, load_errors = _to_emoji_objects(emoji_peewee_instances)
emoji_instances = query
emoji_objects, load_errors = _to_emoji_objects(emoji_instances)
if load_errors > 0:
logger.warning(f"[查询] 加载过程中出现 {load_errors} 个错误。")
@@ -760,7 +769,7 @@ class EmojiManager:
# 如果内存中没有,从数据库查找
self._ensure_db()
try:
emoji_record = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash)
emoji_record = session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalar_one_or_none()
if emoji_record and emoji_record.description:
logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.description[:50]}...")
return emoji_record.description
@@ -921,9 +930,10 @@ class EmojiManager:
# 尝试从Images表获取已有的详细描述可能在收到表情包时已生成
existing_description = None
try:
from src.common.database.database_model import Images
# from src.common.database.database_model_compat import Images
existing_image = Images.get_or_none((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
stmt = select(Images).where((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
existing_image = session.execute(stmt).scalar_one_or_none()
if existing_image and existing_image.description:
existing_description = existing_image.description
logger.info(f"[复用描述] 找到已有详细描述: {existing_description[:50]}...")

View File

@@ -7,7 +7,9 @@ from datetime import datetime
from typing import List, Dict, Optional, Any, Tuple
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.common.database.sqlalchemy_database_api import get_session
from sqlalchemy import select
from src.common.database.sqlalchemy_models import Expression
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive, build_anonymous_messages
@@ -20,7 +22,7 @@ DECAY_DAYS = 30 # 30天衰减到0.01
DECAY_MIN = 0.01 # 最小衰减值
logger = get_logger("expressor")
session = get_session()
def format_create_date(timestamp: float) -> str:
"""
@@ -168,30 +170,50 @@ class ExpressionLearner:
logger.error(f"为聊天流 {self.chat_name} 触发学习失败: {e}")
return False
# def get_expression_by_chat_id(self) -> Tuple[List[Dict[str, float]], List[Dict[str, float]]]:
# """
# 获取指定chat_id的style表达方式(已禁用grammar的获取)
# 返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作
# """
# learnt_style_expressions = []
def get_expression_by_chat_id(self) -> Tuple[List[Dict[str, float]], List[Dict[str, float]]]:
"""
获取指定chat_id的style和grammar表达方式
返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作
"""
learnt_style_expressions = []
learnt_grammar_expressions = []
# 直接从数据库查询
style_query = session.execute(select(Expression).where((Expression.chat_id == self.chat_id) & (Expression.type == "style")))
for expr in style_query.scalars():
# 确保create_date存在如果不存在则使用last_active_time
create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
learnt_style_expressions.append(
{
"situation": expr.situation,
"style": expr.style,
"count": expr.count,
"last_active_time": expr.last_active_time,
"source_id": self.chat_id,
"type": "style",
"create_date": create_date,
}
)
grammar_query = session.execute(select(Expression).where((Expression.chat_id == self.chat_id) & (Expression.type == "grammar")))
for expr in grammar_query.scalars():
# 确保create_date存在如果不存在则使用last_active_time
create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
learnt_grammar_expressions.append(
{
"situation": expr.situation,
"style": expr.style,
"count": expr.count,
"last_active_time": expr.last_active_time,
"source_id": self.chat_id,
"type": "grammar",
"create_date": create_date,
}
)
return learnt_style_expressions, learnt_grammar_expressions
# # 直接从数据库查询
# style_query = Expression.select().where((Expression.chat_id == self.chat_id) & (Expression.type == "style"))
# for expr in style_query:
# # 确保create_date存在如果不存在则使用last_active_time
# create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
# learnt_style_expressions.append(
# {
# "situation": expr.situation,
# "style": expr.style,
# "count": expr.count,
# "last_active_time": expr.last_active_time,
# "source_id": self.chat_id,
# "type": "style",
# "create_date": create_date,
# }
# )
# return learnt_style_expressions
@@ -201,7 +223,7 @@ class ExpressionLearner:
"""
try:
# 获取所有表达方式
all_expressions = Expression.select()
all_expressions = session.execute(select(Expression)).scalars().all()  # 先取全量,避免边遍历边删除
updated_count = 0
deleted_count = 0
@@ -217,18 +239,20 @@ class ExpressionLearner:
if new_count <= 0.01:
# 如果count太小删除这个表达方式
expr.delete_instance()
session.delete(expr)
deleted_count += 1
else:
# 更新count
expr.count = new_count
expr.save()
updated_count += 1
session.commit()
if updated_count > 0 or deleted_count > 0:
logger.info(f"全局衰减完成:更新了 {updated_count} 个表达方式,删除了 {deleted_count} 个表达方式")
except Exception as e:
session.rollback()
logger.error(f"数据库全局衰减失败: {e}")
def calculate_decay_factor(self, time_diff_days: float) -> float:
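The hunk cuts off before the body of calculate_decay_factor. Given DECAY_DAYS = 30 and DECAY_MIN = 0.01 at the top of this file ("30天衰减到0.01"), a consistent implementation would be exponential decay; a sketch under that assumption, not the commit's actual body:

import math

def calculate_decay_factor(self, time_diff_days: float) -> float:
    # 假设:指数衰减,使 time_diff_days == DECAY_DAYS(30 天)时因子恰为 DECAY_MIN(0.01)
    if time_diff_days <= 0:
        return 1.0
    rate = math.log(DECAY_MIN) / DECAY_DAYS  # ln(0.01) / 30,为负数
    return max(DECAY_MIN, math.exp(rate * time_diff_days))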
@@ -297,23 +321,22 @@ class ExpressionLearner:
for chat_id, expr_list in chat_dict.items():
for new_expr in expr_list:
# 查找是否已存在相似表达方式
query = Expression.select().where(
query = session.execute(select(Expression).where(
(Expression.chat_id == chat_id)
& (Expression.type == "style")
& (Expression.situation == new_expr["situation"])
& (Expression.style == new_expr["style"])
)
if query.exists():
expr_obj = query.get()
)).scalar()
if query:
expr_obj = query
# 50%概率替换内容
if random.random() < 0.5:
expr_obj.situation = new_expr["situation"]
expr_obj.style = new_expr["style"]
expr_obj.count = expr_obj.count + 1
expr_obj.last_active_time = current_time
expr_obj.save()
else:
Expression.create(
new_expression = Expression(
situation=new_expr["situation"],
style=new_expr["style"],
count=1,
@@ -322,16 +345,18 @@ class ExpressionLearner:
type="style",
create_date=current_time, # 手动设置创建日期
)
session.add(new_expression)
# 限制最大数量
exprs = list(
Expression.select()
.where((Expression.chat_id == chat_id) & (Expression.type == "style"))
.order_by(Expression.count.asc())
session.execute(select(Expression)
.where((Expression.chat_id == chat_id) & (Expression.type == "style"))
.order_by(Expression.count.asc())).scalars()
)
if len(exprs) > MAX_EXPRESSION_COUNT:
# 删除count最小的多余表达方式
for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]:
expr.delete_instance()
session.delete(expr)
session.commit()
return learnt_expressions
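The cap-enforcement step above loads every matching row just to delete the overflow. If that ever becomes hot, the surplus rows can be targeted directly; a sketch, assuming Expression has an integer primary key named id and that delete is added to the sqlalchemy import in this file:

surplus_ids = session.execute(
    select(Expression.id)
    .where((Expression.chat_id == chat_id) & (Expression.type == "style"))
    .order_by(Expression.count.desc())
    .offset(MAX_EXPRESSION_COUNT)  # 保留 count 最大的前 MAX_EXPRESSION_COUNT 条
).scalars().all()
if surplus_ids:
    session.execute(delete(Expression).where(Expression.id.in_(surplus_ids)))
    session.commit()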
async def learn_expression(self, num: int = 10) -> Optional[Tuple[List[Tuple[str, str, str]], str]]:
@@ -509,54 +534,35 @@ class ExpressionLearnerManager:
logger.warning(f"表达方式缺少必要字段,跳过: {expr}")
continue
# 查重同chat_id+type+situation+style
from src.common.database.database_model import Expression
# 查重同chat_id+type+situation+style
query = Expression.select().where(
(Expression.chat_id == chat_id)
& (Expression.type == type_str)
& (Expression.situation == situation)
& (Expression.style == style_val)
query = session.execute(select(Expression).where(
(Expression.chat_id == chat_id)
& (Expression.type == type_str)
& (Expression.situation == situation)
& (Expression.style == style_val)
)).scalar()
if query:
expr_obj = query
expr_obj.count = max(expr_obj.count, count)
expr_obj.last_active_time = max(expr_obj.last_active_time, last_active_time)
else:
new_expression = Expression(
situation=situation,
style=style_val,
count=count,
last_active_time=last_active_time,
chat_id=chat_id,
type=type_str,
create_date=last_active_time, # 迁移时使用last_active_time作为创建时间
)
if query.exists():
expr_obj = query.get()
expr_obj.count = max(expr_obj.count, count)
expr_obj.last_active_time = max(expr_obj.last_active_time, last_active_time)
expr_obj.save()
else:
Expression.create(
situation=situation,
style=style_val,
count=count,
last_active_time=last_active_time,
chat_id=chat_id,
type=type_str,
create_date=last_active_time, # 迁移时使用last_active_time作为创建时间
)
migrated_count += 1
logger.info(f"已迁移 {expr_file} 到数据库,包含 {len(expressions)} 个表达方式")
except json.JSONDecodeError as e:
logger.error(f"JSON解析失败 {expr_file}: {e}")
except Exception as e:
logger.error(f"迁移表达方式 {expr_file} 失败: {e}")
# 标记迁移完成
try:
# 确保done.done文件的父目录存在
done_parent_dir = os.path.dirname(done_flag)
if not os.path.exists(done_parent_dir):
os.makedirs(done_parent_dir, exist_ok=True)
logger.debug(f"为done.done创建父目录: {done_parent_dir}")
with open(done_flag, "w", encoding="utf-8") as f:
f.write("done\n")
logger.info(f"表达方式JSON迁移已完成共迁移 {migrated_count} 个表达方式已写入done.done标记文件")
except PermissionError as e:
logger.error(f"权限不足无法写入done.done标记文件: {e}")
except OSError as e:
logger.error(f"文件系统错误无法写入done.done标记文件: {e}")
except Exception as e:
logger.error(f"写入done.done标记文件失败: {e}")
session.add(new_expression)
migrated_count += 1
logger.info(f"已迁移 {expr_file} 到数据库,包含 {len(expressions)} 个表达方式")
except json.JSONDecodeError as e:
logger.error(f"JSON解析失败 {expr_file}: {e}")
except Exception as e:
logger.error(f"迁移表达方式 {expr_file} 失败: {e}")
# 检查并处理grammar表达删除
if not os.path.exists(done_flag2):
@@ -581,18 +587,20 @@ class ExpressionLearnerManager:
"""
try:
# 查找所有create_date为空的表达方式
old_expressions = Expression.select().where(Expression.create_date.is_null())
old_expressions = session.execute(select(Expression).where(Expression.create_date.is_(None))).scalars()
updated_count = 0
for expr in old_expressions:
# 使用last_active_time作为create_date
expr.create_date = expr.last_active_time
expr.save()
updated_count += 1
session.commit()
if updated_count > 0:
logger.info(f"已为 {updated_count} 个老的表达方式设置创建日期")
except Exception as e:
session.rollback()
logger.error(f"迁移老数据创建日期失败: {e}")
def delete_all_grammar_expressions(self) -> int:

View File

@@ -9,8 +9,11 @@ from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from sqlalchemy import select
from src.common.database.sqlalchemy_models import Expression
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.common.database.sqlalchemy_database_api import get_session
session = get_session()
logger = get_logger("expression_selector")
@@ -131,9 +134,12 @@ class ExpressionSelector:
related_chat_ids = self.get_related_chat_ids(chat_id)
# 优化一次性查询所有相关chat_id的表达方式
style_query = Expression.select().where(
style_query = session.execute(select(Expression).where(
(Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style")
)
))
grammar_query = session.execute(select(Expression).where(
(Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "grammar")
))
style_exprs = [
{
@@ -146,9 +152,24 @@ class ExpressionSelector:
"type": "style",
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
}
for expr in style_query
for expr in style_query.scalars()
]
grammar_exprs = [
{
"situation": expr.situation,
"style": expr.style,
"count": expr.count,
"last_active_time": expr.last_active_time,
"source_id": expr.chat_id,
"type": "grammar",
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
}
for expr in grammar_query.scalars()
]
style_num = int(total_num * style_percentage)
grammar_num = int(total_num * grammar_percentage)
# 按权重抽样使用count作为权重
if style_exprs:
style_weights = [expr.get("count", 1) for expr in style_exprs]
@@ -174,19 +195,19 @@ class ExpressionSelector:
if key not in updates_by_key:
updates_by_key[key] = expr
for chat_id, expr_type, situation, style in updates_by_key:
query = Expression.select().where(
query = session.execute(select(Expression).where(
(Expression.chat_id == chat_id)
& (Expression.type == expr_type)
& (Expression.situation == situation)
& (Expression.style == style)
)
if query.exists():
expr_obj = query.get()
)).scalar()
if query:
expr_obj = query
current_count = expr_obj.count
new_count = min(current_count + increment, 5.0)
expr_obj.count = new_count
expr_obj.last_active_time = time.time()
expr_obj.save()
session.commit()
logger.debug(
f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in db"
)
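The "按权重抽样" step earlier in this file (weights taken from count) maps directly onto random.choices. A sketch of that sampling, noting that random.choices samples with replacement, so duplicates may need filtering afterwards:

import random

def sample_weighted(exprs: list[dict], k: int) -> list[dict]:
    if not exprs:
        return []
    weights = [e.get("count", 1) for e in exprs]  # count 作为权重
    return random.choices(exprs, weights=weights, k=min(k, len(exprs)))

selected_styles = sample_weighted(style_exprs, style_num)
selected_grammar = sample_weighted(grammar_exprs, grammar_num)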

View File

@@ -1,33 +0,0 @@
raise DeprecationWarning("MemoryActiveManager is not used yet, please do not import it")
from .lpmmconfig import global_config
from .embedding_store import EmbeddingManager
from .llm_client import LLMClient
from .utils.dyn_topk import dyn_select_top_k
class MemoryActiveManager:
def __init__(
self,
embed_manager: EmbeddingManager,
llm_client_embedding: LLMClient,
):
self.embed_manager = embed_manager
self.embedding_client = llm_client_embedding
def get_activation(self, question: str) -> float:
"""获取记忆激活度"""
# 生成问题的Embedding
question_embedding = self.embedding_client.send_embedding_request("text-embedding", question)
# 查询关系库中的相似度
rel_search_res = self.embed_manager.relation_embedding_store.search_top_k(question_embedding, 10)
# 动态过滤阈值
rel_scores = dyn_select_top_k(rel_search_res, 0.5, 1.0)
if rel_scores[0][1] < global_config["qa"]["params"]["relation_threshold"]:
# 未找到相关关系
return 0.0
# 计算激活度
activation = sum([item[2] for item in rel_scores]) * 10
return activation

View File

@@ -16,8 +16,10 @@ from rich.traceback import install
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.database.database_model import Messages, GraphNodes, GraphEdges # Peewee Models导入
from sqlalchemy import select, insert, update, text, delete
from src.common.database.sqlalchemy_models import Messages, GraphNodes, GraphEdges # SQLAlchemy Models导入
from src.common.logger import get_logger
from src.common.database.sqlalchemy_database_api import get_session
from src.chat.memory_system.sample_distribution import MemoryBuildScheduler # 分布生成器
from src.chat.utils.chat_message_builder import (
get_raw_msg_by_timestamp,
@@ -37,7 +39,7 @@ def cosine_similarity(v1, v2):
install(extra_lines=3)
session = get_session()
def calculate_information_content(text):
"""计算文本的信息量(熵)"""
@@ -731,13 +733,14 @@ class Hippocampus:
memory_items = node_data.get("memory_items", "")
# 直接使用完整的记忆内容
if memory_items:
logger.debug("节点包含完整记忆")
# 计算记忆与关键词的相似度
memory_words = set(jieba.cut(memory_items))
text_words = set(keywords)
all_words = memory_words | text_words
if all_words:
# 计算相似度(虽然这里没有使用,但保持逻辑一致性)
logger.debug(f"节点包含 {len(memory_items)}记忆")
# 计算每条记忆与输入文本的相似度
memory_similarities = []
for memory in memory_items:
# 计算与输入文本的相似度
memory_words = set(jieba.cut(memory))
text_words = set(jieba.cut(text))
all_words = memory_words | text_words
v1 = [1 if word in memory_words else 0 for word in all_words]
v2 = [1 if word in text_words else 0 for word in all_words]
_ = cosine_similarity(v1, v2)  # 计算但不使用,用 _ 表示
@@ -844,11 +847,6 @@ class Hippocampus:
else:
activate_map[node] = activation_value
# 输出激活映射
# logger.info("激活映射统计:")
# for node, total_activation in sorted(activate_map.items(), key=lambda x: x[1], reverse=True):
# logger.info(f"节点 '{node}': 累计激活值 = {total_activation:.2f}")
# 计算激活节点数与总节点数的比值
total_activation = sum(activate_map.values())
# logger.debug(f"总激活值: {total_activation:.2f}")
@@ -942,10 +940,13 @@ class EntorhinalCortex:
for message in messages:
# 确保在更新前获取最新的 memorized_times
current_memorized_times = message.get("memorized_times", 0)
# 使用 Peewee 更新记录
Messages.update(memorized_times=current_memorized_times + 1).where(
Messages.message_id == message["message_id"]
).execute()
# 使用 SQLAlchemy 2.0 更新记录
session.execute(
update(Messages)
.where(Messages.message_id == message["message_id"])
.values(memorized_times=current_memorized_times + 1)
)
session.commit()
return messages # 直接返回原始的消息列表
target_timestamp -= 120 # 如果第一次尝试失败,稍微向前调整时间戳再试
@@ -959,7 +960,7 @@ class EntorhinalCortex:
current_time = datetime.datetime.now().timestamp()
# 获取数据库中所有节点和内存中所有节点
db_nodes = {node.concept: node for node in GraphNodes.select()}
db_nodes = {node.concept: node for node in session.execute(select(GraphNodes)).scalars()}
memory_nodes = list(self.memory_graph.G.nodes(data=True))
# 批量准备节点数据
@@ -1025,22 +1026,27 @@ class EntorhinalCortex:
batch_size = 100
for i in range(0, len(nodes_to_create), batch_size):
batch = nodes_to_create[i : i + batch_size]
GraphNodes.insert_many(batch).execute()
session.execute(insert(GraphNodes), batch)
session.commit()
if nodes_to_update:
batch_size = 100
for i in range(0, len(nodes_to_update), batch_size):
batch = nodes_to_update[i : i + batch_size]
for node_data in batch:
GraphNodes.update(**{k: v for k, v in node_data.items() if k != "concept"}).where(
GraphNodes.concept == node_data["concept"]
).execute()
session.execute(
update(GraphNodes)
.where(GraphNodes.concept == node_data["concept"])
.values(**{k: v for k, v in node_data.items() if k != "concept"})
)
session.commit()
if nodes_to_delete:
GraphNodes.delete().where(GraphNodes.concept.in_(nodes_to_delete)).execute() # type: ignore
session.execute(delete(GraphNodes).where(GraphNodes.concept.in_(nodes_to_delete)))
session.commit()
# 处理边的信息
db_edges = list(GraphEdges.select())
db_edges = list(session.execute(select(GraphEdges)).scalars())
memory_edges = list(self.memory_graph.G.edges(data=True))
# 创建边的哈希值字典
@@ -1092,20 +1098,29 @@ class EntorhinalCortex:
batch_size = 100
for i in range(0, len(edges_to_create), batch_size):
batch = edges_to_create[i : i + batch_size]
GraphEdges.insert_many(batch).execute()
session.execute(insert(GraphEdges), batch)
session.commit()
if edges_to_update:
batch_size = 100
for i in range(0, len(edges_to_update), batch_size):
batch = edges_to_update[i : i + batch_size]
for edge_data in batch:
GraphEdges.update(**{k: v for k, v in edge_data.items() if k not in ["source", "target"]}).where(
(GraphEdges.source == edge_data["source"]) & (GraphEdges.target == edge_data["target"])
).execute()
session.execute(
update(GraphEdges)
.where(
(GraphEdges.source == edge_data["source"]) & (GraphEdges.target == edge_data["target"])
)
.values(**{k: v for k, v in edge_data.items() if k not in ["source", "target"]})
)
session.commit()
if edges_to_delete:
for source, target in edges_to_delete:
GraphEdges.delete().where((GraphEdges.source == source) & (GraphEdges.target == target)).execute()
session.execute(
delete(GraphEdges).where((GraphEdges.source == source) & (GraphEdges.target == target))
)
session.commit()
end_time = time.time()
logger.info(f"[数据库] 同步完成,总耗时: {end_time - start_time:.2f}")
@@ -1118,8 +1133,9 @@ class EntorhinalCortex:
# 清空数据库
clear_start = time.time()
GraphNodes.delete().execute()
GraphEdges.delete().execute()
session.execute(delete(GraphNodes))
session.execute(delete(GraphEdges))
session.commit()
clear_end = time.time()
logger.info(f"[数据库] 清空数据库耗时: {clear_end - clear_start:.2f}")
@@ -1186,12 +1202,27 @@ class EntorhinalCortex:
logger.error(f"准备边 {source}-{target} 数据时发生错误: {e}")
continue
# 批量插入边
# 批量写入节点
node_start = time.time()
if nodes_data:
batch_size = 500 # 增加批量大小
for i in range(0, len(nodes_data), batch_size):
batch = nodes_data[i : i + batch_size]
session.execute(insert(GraphNodes), batch)
session.commit()
node_end = time.time()
logger.info(f"[数据库] 写入 {len(nodes_data)} 个节点耗时: {node_end - node_start:.2f}")
# 批量写入边
edge_start = time.time()
if edges_data:
batch_size = 100
batch_size = 500 # 增加批量大小
for i in range(0, len(edges_data), batch_size):
batch = edges_data[i : i + batch_size]
GraphEdges.insert_many(batch).execute()
session.execute(insert(GraphEdges), batch)
session.commit()
edge_end = time.time()
logger.info(f"[数据库] 写入 {len(edges_data)} 条边耗时: {edge_end - edge_start:.2f}")
end_time = time.time()
logger.info(f"[数据库] 重新同步完成,总耗时: {end_time - start_time:.2f}")
@@ -1211,9 +1242,7 @@ class EntorhinalCortex:
skipped_nodes = 0
# 从数据库加载所有节点
nodes = list(GraphNodes.select())
total_nodes = len(nodes)
nodes = list(session.execute(select(GraphNodes)).scalars())
for node in nodes:
concept = node.concept
try:
@@ -1235,8 +1264,10 @@ class EntorhinalCortex:
if not node.last_modified:
update_data["last_modified"] = current_time
if update_data:
GraphNodes.update(**update_data).where(GraphNodes.concept == concept).execute()
session.execute(
update(GraphNodes).where(GraphNodes.concept == concept).values(**update_data)
)
session.commit()
# 获取时间信息(如果不存在则使用当前时间)
created_time = node.created_time or current_time
@@ -1256,7 +1287,7 @@ class EntorhinalCortex:
continue
# 从数据库加载所有边
edges = list(GraphEdges.select())
edges = list(session.execute(select(GraphEdges)).scalars())
for edge in edges:
source = edge.source
target = edge.target
@@ -1272,9 +1303,12 @@ class EntorhinalCortex:
if not edge.last_modified:
update_data["last_modified"] = current_time
GraphEdges.update(**update_data).where(
(GraphEdges.source == source) & (GraphEdges.target == target)
).execute()
session.execute(
update(GraphEdges)
.where((GraphEdges.source == source) & (GraphEdges.target == target))
.values(**update_data)
)
session.commit()
# 获取时间信息(如果不存在则使用当前时间)
created_time = edge.created_time or current_time
@@ -1398,7 +1432,6 @@ class ParahippocampalGyrus:
all_words = topic_words | existing_words
v1 = [1 if word in topic_words else 0 for word in all_words]
v2 = [1 if word in existing_words else 0 for word in all_words]
similarity = cosine_similarity(v1, v2)
if similarity >= 0.7:
@@ -1502,7 +1535,7 @@ class ParahippocampalGyrus:
check_nodes_count = max(1, min(len(all_nodes), int(len(all_nodes) * percentage)))
check_edges_count = max(1, min(len(all_edges), int(len(all_edges) * percentage)))
# 只有在有足够的节点和边时进行采样
if len(all_nodes) >= check_nodes_count and len(all_edges) >= check_edges_count:
try:
nodes_to_check = random.sample(all_nodes, check_nodes_count)
@@ -1548,6 +1581,11 @@ class ParahippocampalGyrus:
logger.info("[遗忘] 开始检查节点...")
node_check_start = time.time()
# 初始化整合相关变量
merged_count = 0
nodes_modified = set()
for node in nodes_to_check:
# 检查节点是否存在,以防在迭代中被移除(例如边移除导致)
if node not in self.memory_graph.G:
@@ -1567,64 +1605,91 @@ class ParahippocampalGyrus:
logger.warning(f"[遗忘] 移除空节点 {node} 时发生错误(可能已被移除): {e}")
continue # 处理下一个节点
# --- 如果节点不为空,则执行原来的不活跃检查和随机移除逻辑 ---
# 检查节点的最后修改时间,如果太旧则尝试遗忘
last_modified = node_data.get("last_modified", current_time)
node_weight = node_data.get("weight", 1.0)
# 条件1检查是否长时间未修改 (使用配置的遗忘时间)
time_threshold = 3600 * global_config.memory.memory_forget_time
# 基于权重调整遗忘阈值:权重越高,需要更长时间才能被遗忘
# 权重为1时使用默认阈值权重越高阈值越大越难遗忘
adjusted_threshold = time_threshold * node_weight
if current_time - last_modified > adjusted_threshold and memory_items:
# 既然每个节点现在是完整记忆,直接删除整个节点
try:
self.memory_graph.G.remove_node(node)
node_changes["removed"].append(f"{node}(长时间未修改,权重{node_weight:.1f})")
logger.debug(f"[遗忘] 移除了长时间未修改的节点: {node} (权重: {node_weight:.1f})")
except nx.NetworkXError as e:
logger.warning(f"[遗忘] 移除节点 {node} 时发生错误(可能已被移除): {e}")
continue
if current_time - last_modified > 3600 * global_config.memory.memory_forget_time:
# 随机遗忘一条记忆
if len(memory_items) > 1:
removed_item = self.memory_graph.forget_topic(node)
if removed_item:
node_changes["reduced"].append(f"{node} (移除: {removed_item[:50]}...)")
elif len(memory_items) == 1:
# 如果只有一条记忆,检查是否应该完全移除节点
try:
self.memory_graph.G.remove_node(node)
node_changes["removed"].append(f"{node} (最后记忆)")
except nx.NetworkXError as e:
logger.warning(f"[遗忘] 移除节点 {node} 时发生错误: {e}")
# 检查节点内是否有相似的记忆项需要整合
if len(memory_items) > 1:
merged_in_this_node = False
items_to_remove = []
for i in range(len(memory_items)):
for j in range(i + 1, len(memory_items)):
similarity = self._calculate_item_similarity(memory_items[i], memory_items[j])
if similarity > 0.8: # 相似度阈值
# 合并相似记忆项
longer_item = memory_items[i] if len(memory_items[i]) > len(memory_items[j]) else memory_items[j]
shorter_item = memory_items[j] if len(memory_items[i]) > len(memory_items[j]) else memory_items[i]
# 保留更长的记忆项,标记短的用于删除
if shorter_item not in items_to_remove:
items_to_remove.append(shorter_item)
merged_count += 1
merged_in_this_node = True
logger.debug(f"[整合] 在节点 {node} 中合并相似记忆: {shorter_item[:30]}... -> {longer_item[:30]}...")
# 移除被合并的记忆项
if items_to_remove:
for item in items_to_remove:
if item in memory_items:
memory_items.remove(item)
nodes_modified.add(node)
# 更新节点的记忆项
self.memory_graph.G.nodes[node]["memory_items"] = memory_items
self.memory_graph.G.nodes[node]["last_modified"] = current_time
node_check_end = time.time()
logger.info(f"[遗忘] 节点检查耗时: {node_check_end - node_check_start:.2f}")
if any(edge_changes.values()) or any(node_changes.values()):
# 输出变化统计
if edge_changes["weakened"]:
logger.info(f"[遗忘] 减弱了 {len(edge_changes['weakened'])} 个连接")
if edge_changes["removed"]:
logger.info(f"[遗忘] 移除了 {len(edge_changes['removed'])} 个连接")
if node_changes["reduced"]:
logger.info(f"[遗忘] 减少了 {len(node_changes['reduced'])} 个节点的记忆")
if node_changes["removed"]:
logger.info(f"[遗忘] 移除了 {len(node_changes['removed'])} 个节点")
# 检查是否有变化需要同步到数据库
has_changes = (
edge_changes["weakened"] or
edge_changes["removed"] or
node_changes["reduced"] or
node_changes["removed"] or
merged_count > 0
)
if has_changes:
logger.info("[遗忘] 开始将变更同步到数据库...")
sync_start = time.time()
await self.hippocampus.entorhinal_cortex.resync_memory_to_db()
await self.hippocampus.entorhinal_cortex.sync_memory_to_db()
sync_end = time.time()
logger.info(f"[遗忘] 数据库同步耗时: {sync_end - sync_start:.2f}")
# 汇总输出所有变化
logger.info("[遗忘] 遗忘操作统计:")
if edge_changes["weakened"]:
logger.info(
f"[遗忘] 减弱的连接 ({len(edge_changes['weakened'])}个): {', '.join(edge_changes['weakened'])}"
)
if edge_changes["removed"]:
logger.info(
f"[遗忘] 移除的连接 ({len(edge_changes['removed'])}个): {', '.join(edge_changes['removed'])}"
)
if node_changes["reduced"]:
logger.info(
f"[遗忘] 减少记忆的节点 ({len(node_changes['reduced'])}个): {', '.join(node_changes['reduced'])}"
)
if node_changes["removed"]:
logger.info(
f"[遗忘] 移除的节点 ({len(node_changes['removed'])}个): {', '.join(node_changes['removed'])}"
)
if merged_count > 0:
logger.info(f"[整合] 共合并了 {merged_count} 对相似记忆项,分布在 {len(nodes_modified)} 个节点中。")
sync_start = time.time()
logger.info("[整合] 开始将变更同步到数据库...")
# 使用 resync 更安全地处理删除和添加
await self.hippocampus.entorhinal_cortex.resync_memory_to_db()
sync_end = time.time()
logger.info(f"[整合] 数据库同步耗时: {sync_end - sync_start:.2f}")
else:
logger.info("[遗忘] 本次检查没有节点或连接满足遗忘条件")
end_time = time.time()
logger.info(f"[遗忘] 总耗时: {end_time - start_time:.2f}")
logger.info("[整合] 本次检查未发现需要合并的记忆项。")
@@ -1734,10 +1799,7 @@ class HippocampusManager:
"""获取所有节点名称的公共接口"""
if not self._initialized:
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
return self._hippocampus.get_all_node_names()
# 创建全局实例
hippocampus_manager = HippocampusManager()

View File

@@ -10,12 +10,13 @@ from datetime import datetime, timedelta
from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger
from src.common.database.database_model import Memory # Peewee Models导入
from src.common.database.sqlalchemy_models import Memory # SQLAlchemy Models导入
from src.common.database.sqlalchemy_database_api import get_session
from src.config.config import model_config
from sqlalchemy import select
logger = get_logger(__name__)
session = get_session()
class MemoryItem:
def __init__(self, memory_id: str, chat_id: str, memory_text: str, keywords: list[str]):
@@ -120,7 +121,8 @@ class InstantMemory:
create_time=memory_item.create_time,
last_view_time=memory_item.last_view_time,
)
memory.save()
session.add(memory)
session.commit()
async def get_memory(self, target: str):
from json_repair import repair_json
@@ -166,13 +168,13 @@ class InstantMemory:
if start_time and end_time:
start_ts = start_time.timestamp()
end_ts = end_time.timestamp()
query = Memory.select().where(
query = session.execute(select(Memory).where(
(Memory.chat_id == self.chat_id)
& (Memory.create_time >= start_ts) # type: ignore
& (Memory.create_time < end_ts) # type: ignore
)
& (Memory.create_time >= start_ts)
& (Memory.create_time < end_ts)
)).scalars()
else:
query = Memory.select().where(Memory.chat_id == self.chat_id)
query = session.execute(select(Memory).where(Memory.chat_id == self.chat_id)).scalars()
for mem in query:
# 对每条记忆

View File

@@ -8,8 +8,12 @@ from maim_message import GroupInfo, UserInfo
from src.common.logger import get_logger
from src.common.database.database import db
from src.common.database.database_model import ChatStreams # 新增导入
from sqlalchemy import select, text
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
from sqlalchemy.dialects.mysql import insert as mysql_insert
from src.common.database.sqlalchemy_models import ChatStreams # 新增导入
from src.common.database.sqlalchemy_database_api import get_session
from src.config.config import global_config # 新增导入
# 避免循环导入使用TYPE_CHECKING进行类型提示
if TYPE_CHECKING:
from .message import MessageRecv
@@ -19,7 +23,7 @@ install(extra_lines=3)
logger = get_logger("chat_stream")
session = get_session()
class ChatMessageContext:
"""聊天消息上下文,存储消息的上下文信息"""
@@ -131,7 +135,8 @@ class ChatManager:
try:
db.connect(reuse_if_open=True)
# 确保 ChatStreams 表存在
db.create_tables([ChatStreams], safe=True)
session.execute(text("CREATE TABLE IF NOT EXISTS chat_streams (stream_id TEXT PRIMARY KEY, platform TEXT, create_time REAL, last_active_time REAL, user_platform TEXT, user_id TEXT, user_nickname TEXT, user_cardname TEXT, group_platform TEXT, group_id TEXT, group_name TEXT)"))
session.commit()
except Exception as e:
logger.error(f"数据库连接或 ChatStreams 表创建失败: {e}")
@@ -231,7 +236,7 @@ class ChatManager:
# 检查数据库中是否存在
def _db_find_stream_sync(s_id: str):
return ChatStreams.get_or_none(ChatStreams.stream_id == s_id)
return session.execute(select(ChatStreams).where(ChatStreams.stream_id == s_id)).scalar()
model_instance = await asyncio.to_thread(_db_find_stream_sync, stream_id)
@@ -342,7 +347,28 @@ class ChatManager:
"group_name": group_info_d["group_name"] if group_info_d else "",
}
ChatStreams.replace(stream_id=s_data_dict["stream_id"], **fields_to_save).execute()
# 根据数据库类型选择插入语句
if global_config.database.database_type == "sqlite":
stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
stmt = stmt.on_conflict_do_update(
index_elements=['stream_id'],
set_=fields_to_save
)
elif global_config.database.database_type == "mysql":
stmt = mysql_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
stmt = stmt.on_duplicate_key_update(
**{key: value for key, value in fields_to_save.items() if key != "stream_id"}
)
else:
# 默认使用通用插入尝试SQLite语法
stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
stmt = stmt.on_conflict_do_update(
index_elements=['stream_id'],
set_=fields_to_save
)
session.execute(stmt)
session.commit()
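The dialect split above exists because ON CONFLICT DO UPDATE is SQLite/PostgreSQL syntax while MySQL uses ON DUPLICATE KEY UPDATE. A self-contained SQLite sketch of the same upsert (hypothetical minimal table, not the real ChatStreams model):

from sqlalchemy import create_engine, String, Float
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, Session

class Base(DeclarativeBase):
    pass

class ChatStreams(Base):  # 示例用的极简模型
    __tablename__ = "chat_streams"
    stream_id: Mapped[str] = mapped_column(String, primary_key=True)
    last_active_time: Mapped[float] = mapped_column(Float, default=0.0)

engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

with Session(engine) as session:
    fields = {"last_active_time": 12345.0}
    stmt = sqlite_insert(ChatStreams).values(stream_id="s1", **fields)
    # 已存在则更新,不存在则插入(等价于 Peewee 的 ChatStreams.replace)
    stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=fields)
    session.execute(stmt)
    session.commit()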
try:
await asyncio.to_thread(_db_save_stream_sync, stream_data_dict)
@@ -361,7 +387,7 @@ class ChatManager:
def _db_load_all_streams_sync():
loaded_streams_data = []
for model_instance in ChatStreams.select():
for model_instance in session.execute(select(ChatStreams)).scalars():
user_info_data = {
"platform": model_instance.user_platform,
"user_id": model_instance.user_id,

View File

@@ -1,16 +1,18 @@
import re
import json
import traceback
import json
from typing import Union
from src.common.database.database_model import Messages, Images
from src.common.database.sqlalchemy_models import Messages, Images
from src.common.logger import get_logger
from .chat_stream import ChatStream
from .message import MessageSending, MessageRecv
from src.common.database.sqlalchemy_database_api import get_session
from sqlalchemy import select, update, desc
logger = get_logger("message_storage")
class MessageStorage:
@staticmethod
def _serialize_keywords(keywords) -> str:
@@ -33,15 +35,11 @@ class MessageStorage:
async def store_message(message: Union[MessageSending, MessageRecv], chat_stream: ChatStream) -> None:
"""存储消息到数据库"""
try:
# 莫越权 救世啊
# 过滤敏感信息的正则模式
pattern = r"<MainRule>.*?</MainRule>|<schedule>.*?</schedule>|<UserMessage>.*?</UserMessage>"
# print(message)
processed_plain_text = message.processed_plain_text
# print(processed_plain_text)
if processed_plain_text:
processed_plain_text = MessageStorage.replace_image_descriptions(processed_plain_text)
filtered_processed_plain_text = re.sub(pattern, "", processed_plain_text, flags=re.DOTALL)
@@ -93,11 +91,16 @@ class MessageStorage:
# 安全地获取 user_info, 如果为 None 则视为空字典 (以防万一)
user_info_from_chat = chat_info_dict.get("user_info") or {}
Messages.create(
# 将priority_info字典序列化为JSON字符串以便存储到数据库的Text字段
priority_info_json = json.dumps(priority_info) if priority_info else None
# 获取数据库会话
session = get_session()
new_message = Messages(
message_id=msg_id,
time=float(message.message_info.time), # type: ignore
time=float(message.message_info.time),
chat_id=chat_stream.stream_id,
# Flattened chat_info
reply_to=reply_to,
is_mentioned=is_mentioned,
chat_info_stream_id=chat_info_dict.get("stream_id"),
@@ -111,18 +114,16 @@ class MessageStorage:
chat_info_group_name=group_info_from_chat.get("group_name"),
chat_info_create_time=float(chat_info_dict.get("create_time", 0.0)),
chat_info_last_active_time=float(chat_info_dict.get("last_active_time", 0.0)),
# Flattened user_info (message sender)
user_platform=user_info_dict.get("platform"),
user_id=user_info_dict.get("user_id"),
user_nickname=user_info_dict.get("user_nickname"),
user_cardname=user_info_dict.get("user_cardname"),
# Text content
processed_plain_text=filtered_processed_plain_text,
display_message=filtered_display_message,
memorized_times=message.memorized_times,
interest_value=interest_value,
priority_mode=priority_mode,
priority_info=priority_info,
priority_info=priority_info_json,
is_emoji=is_emoji,
is_picid=is_picid,
is_notify=is_notify,
@@ -131,35 +132,44 @@ class MessageStorage:
key_words_lite=key_words_lite,
selected_expressions=selected_expressions,
)
session.add(new_message)
session.commit()
except Exception:
logger.exception("存储消息失败")
logger.error(f"消息:{message}")
traceback.print_exc()
# 如果需要其他存储相关的函数,可以在这里添加
@staticmethod
async def update_message(
message: MessageRecv,
) -> None: # 用于实时更新数据库的自身发送消息ID目前能处理text,reply,image和emoji
"""更新最新一条匹配消息的message_id"""
async def update_message(message):
"""更新消息ID"""
try:
if message.message_segment.type == "notify":
mmc_message_id = message.message_segment.data.get("echo") # type: ignore
qq_message_id = message.message_segment.data.get("actual_id") # type: ignore
mmc_message_id = message.message_info.message_id # 修复:正确访问 message_id
if message.message_segment.type == "text":
qq_message_id = message.message_segment.data.get("id")
elif message.message_segment.type == "reply":
qq_message_id = message.message_segment.data.get("id")
else:
logger.info(f"更新消息ID错误seg类型为{message.message_segment.type}")
return
if not qq_message_id:
logger.info("消息不存在message_id无法更新")
return
if matched_message := (
Messages.select().where((Messages.message_id == mmc_message_id)).order_by(Messages.time.desc()).first()
):
# 更新找到的消息记录
Messages.update(message_id=qq_message_id).where(Messages.id == matched_message.id).execute() # type: ignore
logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}")
else:
logger.debug("未找到匹配的消息")
# 使用上下文管理器确保session正确管理
from src.common.database.sqlalchemy_models import get_db_session
with get_db_session() as session:
matched_message = session.execute(
select(Messages).where(Messages.message_id == mmc_message_id).order_by(desc(Messages.time))
).scalar()
if matched_message:
session.execute(
update(Messages).where(Messages.id == matched_message.id).values(message_id=qq_message_id)
)
# session.commit() 会在上下文管理器中自动调用
logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}")
else:
logger.debug("未找到匹配的消息")
except Exception as e:
logger.error(f"更新消息ID失败: {e}")
@@ -178,10 +188,12 @@ class MessageStorage:
def replace_match(match):
description = match.group(1).strip()
try:
image_record = (
Images.select().where(Images.description == description).order_by(Images.timestamp.desc()).first()
)
return f"[picid:{image_record.image_id}]" if image_record else match.group(0)
from src.common.database.sqlalchemy_models import get_db_session
with get_db_session() as session:
image_record = session.execute(
select(Images).where(Images.description == description).order_by(desc(Images.timestamp))
).scalar()
return f"[picid:{image_record.image_id}]" if image_record else match.group(0)
except Exception:
return match.group(0)

View File

@@ -7,13 +7,14 @@ from rich.traceback import install
from src.config.config import global_config
from src.common.message_repository import find_messages, count_messages
from src.common.database.database_model import ActionRecords
from src.common.database.database_model import Images
from src.person_info.person_info import Person,get_person_id
from src.common.database.sqlalchemy_models import ActionRecords, Images
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
from src.chat.utils.utils import translate_timestamp_to_human_readable, assign_message_ids
from src.common.database.sqlalchemy_database_api import get_session
from sqlalchemy import select, and_
install(extra_lines=3)
session = get_session()
def replace_user_references_sync(
content: str,
@@ -254,50 +255,90 @@ def get_actions_by_timestamp_with_chat(
limit_mode: str = "latest",
) -> List[Dict[str, Any]]:
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录,按时间升序排序,返回动作记录列表"""
query = ActionRecords.select().where(
(ActionRecords.chat_id == chat_id)
& (ActionRecords.time > timestamp_start) # type: ignore
& (ActionRecords.time < timestamp_end) # type: ignore
)
# 先构建基础查询,各分支按需加排序/限制后再执行,避免重复执行
base_stmt = select(ActionRecords).where(
    and_(
        ActionRecords.chat_id == chat_id,
        ActionRecords.time > timestamp_start,
        ActionRecords.time < timestamp_end
    )
)
if limit > 0:
if limit_mode == "latest":
query = query.order_by(ActionRecords.time.desc()).limit(limit)
query = session.execute(base_stmt.order_by(ActionRecords.time.desc()).limit(limit))
# 获取后需要反转列表,以保持最终输出为时间升序
actions = list(query)
return [action.__data__ for action in reversed(actions)]
actions = list(query.scalars())
return [action.__dict__ for action in reversed(actions)]
else: # earliest
query = query.order_by(ActionRecords.time.asc()).limit(limit)
query = session.execute(base_stmt.order_by(ActionRecords.time.asc()).limit(limit))
else:
query = query.order_by(ActionRecords.time.asc())
query = session.execute(base_stmt.order_by(ActionRecords.time.asc()))
actions = list(query)
return [action.__data__ for action in actions]
actions = list(query.scalars())
return [action.__dict__ for action in actions]
def get_actions_by_timestamp_with_chat_inclusive(
chat_id: str, timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
) -> List[Dict[str, Any]]:
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录(包含边界),按时间升序排序,返回动作记录列表"""
query = ActionRecords.select().where(
(ActionRecords.chat_id == chat_id)
& (ActionRecords.time >= timestamp_start) # type: ignore
& (ActionRecords.time <= timestamp_end) # type: ignore
)
# 同样先构建基础查询,分支内再执行
base_stmt = select(ActionRecords).where(
    and_(
        ActionRecords.chat_id == chat_id,
        ActionRecords.time >= timestamp_start,
        ActionRecords.time <= timestamp_end
    )
)
if limit > 0:
if limit_mode == "latest":
query = query.order_by(ActionRecords.time.desc()).limit(limit)
query = session.execute(base_stmt.order_by(ActionRecords.time.desc()).limit(limit))
# 获取后需要反转列表,以保持最终输出为时间升序
actions = list(query)
return [action.__data__ for action in reversed(actions)]
actions = list(query.scalars())
return [action.__dict__ for action in reversed(actions)]
else: # earliest
query = query.order_by(ActionRecords.time.asc()).limit(limit)
query = session.execute(base_stmt.order_by(ActionRecords.time.asc()).limit(limit))
else:
query = query.order_by(ActionRecords.time.asc())
query = session.execute(base_stmt.order_by(ActionRecords.time.asc()))
actions = list(query)
return [action.__data__ for action in actions]
actions = list(query.scalars())
return [action.__dict__ for action in actions]
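One caveat with the action.__dict__ conversions above: on a mapped instance, __dict__ also carries SQLAlchemy's internal _sa_instance_state entry. If callers expect pure column data, a small helper (suggested here, not part of the commit) filters it out:

from sqlalchemy import inspect

def row_to_dict(obj) -> dict:
    # 仅导出映射列,过滤 _sa_instance_state 等内部状态
    return {attr.key: getattr(obj, attr.key) for attr in inspect(obj).mapper.column_attrs}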
def get_raw_msg_by_timestamp_random(
@@ -700,7 +741,7 @@ def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
# 从数据库中获取图片描述
description = "内容正在阅读,请稍等"
try:
image = Images.get_or_none(Images.image_id == pic_id)
image = session.execute(select(Images).where(Images.image_id == pic_id)).scalar()
if image and image.description:
description = image.description
except Exception:
@@ -813,7 +854,7 @@ def build_readable_messages(
timestamp_mode: str = "relative",
read_mark: float = 0.0,
truncate: bool = False,
show_actions: bool = False,
show_actions: bool = True,
show_pic: bool = True,
message_id_list: Optional[List[Dict[str, Any]]] = None,
) -> str: # sourcery skip: extract-method
@@ -846,21 +887,21 @@ def build_readable_messages(
chat_id = copy_messages[0].get("chat_id") if copy_messages else None
# 获取这个时间范围内的动作记录并匹配chat_id
actions_in_range = (
ActionRecords.select()
.where(
(ActionRecords.time >= min_time) & (ActionRecords.time <= max_time) & (ActionRecords.chat_id == chat_id)
actions_in_range = session.execute(select(ActionRecords).where(
and_(
ActionRecords.time >= min_time,
ActionRecords.time <= max_time,
ActionRecords.chat_id == chat_id
)
.order_by(ActionRecords.time)
)
).order_by(ActionRecords.time)).scalars()
# 获取最新消息之后的第一个动作记录
action_after_latest = (
ActionRecords.select()
.where((ActionRecords.time > max_time) & (ActionRecords.chat_id == chat_id))
.order_by(ActionRecords.time)
.limit(1)
)
action_after_latest = session.execute(select(ActionRecords).where(
and_(
ActionRecords.time > max_time,
ActionRecords.chat_id == chat_id
)
).order_by(ActionRecords.time).limit(1)).scalars()
# 合并两部分动作记录
actions = list(actions_in_range) + list(action_after_latest)

View File

@@ -6,13 +6,52 @@ from datetime import datetime, timedelta
from typing import Any, Dict, Tuple, List
from src.common.logger import get_logger
from src.common.database.database import db
from src.common.database.database_model import OnlineTime, LLMUsage, Messages
from src.common.database.sqlalchemy_models import OnlineTime, LLMUsage, Messages
from src.common.database.sqlalchemy_database_api import get_db_session, db_query, db_save, db_get
from src.manager.async_task_manager import AsyncTask
from src.manager.local_store_manager import local_storage
logger = get_logger("maibot_statistic")
# 同步包装器函数,用于在非异步环境中调用异步数据库API
def _sync_db_get(model_class, filters=None, order_by=None, limit=None, single_result=False):
"""同步版本的db_get用于在线程池中调用"""
import asyncio
try:
loop = asyncio.get_event_loop()
if loop.is_running():
# 如果事件循环正在运行,创建新的事件循环
import threading
result = None
exception = None
def run_in_thread():
nonlocal result, exception
try:
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
result = new_loop.run_until_complete(
db_get(model_class, filters, limit, order_by, single_result)
)
new_loop.close()
except Exception as e:
exception = e
thread = threading.Thread(target=run_in_thread)
thread.start()
thread.join()
if exception:
raise exception
return result
else:
return loop.run_until_complete(
db_get(model_class, filters, limit, order_by, single_result)
)
except RuntimeError:
# 没有事件循环,创建一个新的
return asyncio.run(db_get(model_class, filters, limit, order_by, single_result))
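The thread-per-call fallback above spins up a fresh event loop for every query. When the running loop is known, the standard-library alternative is asyncio.run_coroutine_threadsafe, which submits the coroutine to the existing loop and blocks for the result; a sketch under the assumption that the caller is on a different thread than the loop:

import asyncio

def _sync_db_get_threadsafe(loop, model_class, filters=None, order_by=None,
                            limit=None, single_result=False):
    # 将协程提交到已运行的事件循环,阻塞等待结果(参数顺序与上方 db_get 调用一致)
    future = asyncio.run_coroutine_threadsafe(
        db_get(model_class, filters, limit, order_by, single_result), loop
    )
    return future.result()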
# 统计数据的键
TOTAL_REQ_CNT = "total_requests"
TOTAL_COST = "total_cost"
@@ -59,17 +98,9 @@ class OnlineTimeRecordTask(AsyncTask):
def __init__(self):
super().__init__(task_name="Online Time Record Task", run_interval=60)
self.record_id: int | None = None # Changed to int for Peewee's default ID
self.record_id: int | None = None
"""记录ID"""
self._init_database() # 初始化数据库
@staticmethod
def _init_database():
"""初始化数据库"""
with db.atomic(): # Use atomic operations for schema changes
OnlineTime.create_table(safe=True) # Creates table if it doesn't exist, Peewee handles indexes from model
async def run(self): # sourcery skip: use-named-expression
try:
current_time = datetime.now()
@@ -77,36 +108,50 @@ class OnlineTimeRecordTask(AsyncTask):
if self.record_id:
# 如果有记录,则更新结束时间
query = OnlineTime.update(end_timestamp=extended_end_time).where(OnlineTime.id == self.record_id) # type: ignore
updated_rows = query.execute()
updated_rows = await db_query(
model_class=OnlineTime,
query_type="update",
filters={"id": self.record_id},
data={"end_timestamp": extended_end_time}
)
if updated_rows == 0:
# Record might have been deleted or ID is stale, try to find/create
self.record_id = None # Reset record_id to trigger find/create logic below
self.record_id = None
if not self.record_id: # Check again if record_id was reset or initially None
# 如果没有记录,检查一分钟以内是否已有记录
# Look for a record whose end_timestamp is recent enough to be considered ongoing
recent_record = (
OnlineTime.select()
.where(OnlineTime.end_timestamp >= (current_time - timedelta(minutes=1))) # type: ignore
.order_by(OnlineTime.end_timestamp.desc())
.first()
if not self.record_id:
# 查找最近一分钟内的记录
recent_threshold = current_time - timedelta(minutes=1)
recent_records = await db_get(
model_class=OnlineTime,
filters={"end_timestamp": {"$gte": recent_threshold}},
order_by="-end_timestamp",
limit=1,
single_result=True
)
if recent_record:
# 如果有记录,更新结束时间
self.record_id = recent_record.id
recent_record.end_timestamp = extended_end_time
recent_record.save()
else:
# 若没有记录,则插入新的在线时间记录
new_record = OnlineTime.create(
timestamp=current_time.timestamp(), # 添加此行
start_timestamp=current_time,
end_timestamp=extended_end_time,
duration=5, # 初始时长为5分钟
if recent_records:
# 找到近期记录,更新
self.record_id = recent_records['id']
await db_query(
model_class=OnlineTime,
query_type="update",
filters={"id": self.record_id},
data={"end_timestamp": extended_end_time}
)
self.record_id = new_record.id
else:
# 创建新记录
new_record = await db_save(
model_class=OnlineTime,
data={
"timestamp": str(current_time),
"duration": 5, # 初始时长为5分钟
"start_timestamp": current_time,
"end_timestamp": extended_end_time,
}
)
if new_record:
self.record_id = new_record['id']
except Exception as e:
logger.error(f"在线时间记录失败,错误信息:{e}")
@@ -322,18 +367,23 @@ class StatisticOutputTask(AsyncTask):
}
# 以最早的时间戳为起始时间获取记录
# Assuming LLMUsage.timestamp is a DateTime column
query_start_time = collect_period[-1][1]
for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time): # type: ignore
record_timestamp = record.timestamp # This is already a datetime object
records = _sync_db_get(
model_class=LLMUsage,
filters={"timestamp": {"$gte": query_start_time}},
order_by="-timestamp"
)
for record in records:
record_timestamp = record['timestamp'] # 从字典中获取
for idx, (_, period_start) in enumerate(collect_period):
if record_timestamp >= period_start:
for period_key, _ in collect_period[idx:]:
stats[period_key][TOTAL_REQ_CNT] += 1
request_type = record.request_type or "unknown"
user_id = record.user_id or "unknown" # user_id is TextField, already string
model_name = record.model_name or "unknown"
request_type = record.get('request_type') or "unknown"
user_id = record.get('user_id') or "unknown"
model_name = record.get('model_name') or "unknown"
# 提取模块名:如果请求类型包含".",取第一个"."之前的部分
module_name = request_type.split(".")[0] if "." in request_type else request_type
@@ -343,8 +393,8 @@ class StatisticOutputTask(AsyncTask):
stats[period_key][REQ_CNT_BY_MODEL][model_name] += 1
stats[period_key][REQ_CNT_BY_MODULE][module_name] += 1
prompt_tokens = record.prompt_tokens or 0
completion_tokens = record.completion_tokens or 0
prompt_tokens = record.get('prompt_tokens') or 0
completion_tokens = record.get('completion_tokens') or 0
total_tokens = prompt_tokens + completion_tokens
stats[period_key][IN_TOK_BY_TYPE][request_type] += prompt_tokens
@@ -362,7 +412,7 @@ class StatisticOutputTask(AsyncTask):
stats[period_key][TOTAL_TOK_BY_MODEL][model_name] += total_tokens
stats[period_key][TOTAL_TOK_BY_MODULE][module_name] += total_tokens
cost = record.cost or 0.0
cost = record.get('cost') or 0.0
stats[period_key][TOTAL_COST] += cost
stats[period_key][COST_BY_TYPE][request_type] += cost
stats[period_key][COST_BY_USER][user_id] += cost
@@ -425,11 +475,15 @@ class StatisticOutputTask(AsyncTask):
}
query_start_time = collect_period[-1][1]
# Assuming OnlineTime.end_timestamp is a DateTime column
for record in OnlineTime.select().where(OnlineTime.end_timestamp >= query_start_time): # type: ignore
# record.end_timestamp and record.start_timestamp are datetime objects
record_end_timestamp = record.end_timestamp
record_start_timestamp = record.start_timestamp
records = _sync_db_get(
model_class=OnlineTime,
filters={"end_timestamp": {"$gte": query_start_time}},
order_by="-end_timestamp"
)
for record in records:
record_end_timestamp = record['end_timestamp']
record_start_timestamp = record['start_timestamp']
for idx, (_, period_boundary_start) in enumerate(collect_period):
if record_end_timestamp >= period_boundary_start:
@@ -466,24 +520,30 @@ class StatisticOutputTask(AsyncTask):
}
query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a float (Unix timestamp) column
for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore
message_time_ts = message.time # This is a float timestamp
records = _sync_db_get(
model_class=Messages,
filters={"time": {"$gte": query_start_timestamp}},
order_by="-time"
)
for message in records:
message_time_ts = message['time'] # This is a float timestamp
chat_id = None
chat_name = None
# Logic based on Peewee model structure, aiming to replicate original intent
if message.chat_info_group_id:
chat_id = f"g{message.chat_info_group_id}"
chat_name = message.chat_info_group_name or f"{message.chat_info_group_id}"
elif message.user_id: # Fallback to sender's info for chat_id if not a group_info based chat
# Logic based on SQLAlchemy model structure, aiming to replicate original intent
if message.get('chat_info_group_id'):
chat_id = f"g{message['chat_info_group_id']}"
chat_name = message.get('chat_info_group_name') or f"{message['chat_info_group_id']}"
elif message.get('user_id'): # Fallback to sender's info for chat_id if not a group_info based chat
# This uses the message SENDER's ID as per original logic's fallback
chat_id = f"u{message.user_id}" # SENDER's user_id
chat_name = message.user_nickname # SENDER's nickname
chat_id = f"u{message['user_id']}" # SENDER's user_id
chat_name = message.get('user_nickname') # SENDER's nickname
else:
# If neither group_id nor sender_id is available for chat identification
logger.warning(
f"Message (PK: {message.id if hasattr(message, 'id') else 'N/A'}) lacks group_id and user_id for chat stats."
f"Message (PK: {message.get('id', 'N/A')}) lacks group_id and user_id for chat stats."
)
continue
@@ -1025,8 +1085,14 @@ class StatisticOutputTask(AsyncTask):
# 查询LLM使用记录
query_start_time = start_time
for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time): # type: ignore
record_time = record.timestamp
records = _sync_db_get(
model_class=LLMUsage,
filters={"timestamp": {"$gte": query_start_time}},
order_by="-timestamp"
)
for record in records:
record_time = record['timestamp']
# 找到对应的时间间隔索引
time_diff = (record_time - start_time).total_seconds()
@@ -1034,17 +1100,17 @@ class StatisticOutputTask(AsyncTask):
if 0 <= interval_index < len(time_points):
# 累加总花费数据
cost = record.cost or 0.0
cost = record.get('cost') or 0.0
total_cost_data[interval_index] += cost # type: ignore
# 累加按模型分类的花费
model_name = record.model_name or "unknown"
model_name = record.get('model_name') or "unknown"
if model_name not in cost_by_model:
cost_by_model[model_name] = [0] * len(time_points)
cost_by_model[model_name][interval_index] += cost
# 累加按模块分类的花费
request_type = record.request_type or "unknown"
request_type = record.get('request_type') or "unknown"
module_name = request_type.split(".")[0] if "." in request_type else request_type
if module_name not in cost_by_module:
cost_by_module[module_name] = [0] * len(time_points)
@@ -1052,8 +1118,14 @@ class StatisticOutputTask(AsyncTask):
# 查询消息记录
query_start_timestamp = start_time.timestamp()
for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore
message_time_ts = message.time
records = _sync_db_get(
model_class=Messages,
filters={"time": {"$gte": query_start_timestamp}},
order_by="-time"
)
for message in records:
message_time_ts = message['time']
# 找到对应的时间间隔索引
time_diff = message_time_ts - query_start_timestamp
@@ -1062,10 +1134,10 @@ class StatisticOutputTask(AsyncTask):
if 0 <= interval_index < len(time_points):
# 确定聊天流名称
chat_name = None
if message.chat_info_group_id:
chat_name = message.chat_info_group_name or f"{message.chat_info_group_id}"
elif message.user_id:
chat_name = message.user_nickname or f"用户{message.user_id}"
if message.get('chat_info_group_id'):
chat_name = message.get('chat_info_group_name') or f"{message['chat_info_group_id']}"
elif message.get('user_id'):
chat_name = message.get('user_nickname') or f"用户{message['user_id']}"
else:
continue

View File

@@ -13,10 +13,12 @@ from rich.traceback import install
from src.common.logger import get_logger
from src.common.database.database import db
from src.common.database.database_model import Images, ImageDescriptions
from src.common.database.sqlalchemy_models import Images, ImageDescriptions
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.common.database.sqlalchemy_models import get_db_session
from sqlalchemy import select, and_
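# select()/and_() are the SQLAlchemy 2.0-style query constructs used by the rewritten queries below.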
install(extra_lines=3)
logger = get_logger("chat_image")
@@ -41,9 +43,10 @@ class ImageManager:
try:
db.connect(reuse_if_open=True)
db.create_tables([Images, ImageDescriptions], safe=True)
# 使用SQLAlchemy创建表已在初始化时完成
logger.debug("使用SQLAlchemy进行表管理")
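# Assumption: tables are created once via Base.metadata.create_all() when sqlalchemy_models initializes the engine.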
except Exception as e:
logger.error(f"数据库连接或表创建失败: {e}")
logger.error(f"SQLAlchemy表管理初始化失败: {e}")
self._initialized = True
@@ -63,12 +66,13 @@ class ImageManager:
Optional[str]: 描述文本,如果不存在则返回None
"""
try:
record = ImageDescriptions.get_or_none(
(ImageDescriptions.image_description_hash == image_hash) & (ImageDescriptions.type == description_type)
)
return record.description if record else None
with get_db_session() as session:
record = session.execute(select(ImageDescriptions).where(
and_(ImageDescriptions.image_description_hash == image_hash, ImageDescriptions.type == description_type)
)).scalar()
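# .scalar() returns the first matching ImageDescriptions row or None, mirroring the old get_or_none() semantics.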
return record.description if record else None
except Exception as e:
logger.error(f"从数据库获取描述失败 (Peewee): {str(e)}")
logger.error(f"从数据库获取描述失败 (SQLAlchemy): {str(e)}")
return None
@staticmethod
@@ -82,16 +86,28 @@ class ImageManager:
"""
try:
current_timestamp = time.time()
defaults = {"description": description, "timestamp": current_timestamp}
desc_obj, created = ImageDescriptions.get_or_create(
image_description_hash=image_hash, type=description_type, defaults=defaults
)
if not created: # 如果记录已存在,则更新
desc_obj.description = description
desc_obj.timestamp = current_timestamp
desc_obj.save()
with get_db_session() as session:
# 查找现有记录
existing = session.execute(select(ImageDescriptions).where(
and_(ImageDescriptions.image_description_hash == image_hash, ImageDescriptions.type == description_type)
)).scalar()
if existing:
# 更新现有记录
existing.description = description
existing.timestamp = current_timestamp
else:
# 创建新记录
new_desc = ImageDescriptions(
image_description_hash=image_hash,
type=description_type,
description=description,
timestamp=current_timestamp
)
session.add(new_desc)
# session.commit() 会在上下文管理器中自动调用
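# (Assumes get_db_session() is a context manager that commits on clean exit and rolls back on exception.)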
except Exception as e:
logger.error(f"保存描述到数据库失败 (Peewee): {str(e)}")
logger.error(f"保存描述到数据库失败 (SQLAlchemy): {str(e)}")
async def get_emoji_tag(self, image_base64: str) -> str:
from src.chat.emoji_system.emoji_manager import get_emoji_manager
@@ -214,19 +230,29 @@ class ImageManager:
# 保存到数据库 (Images表) - 包含详细描述用于可能的注册流程
try:
img_obj = Images.get((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
img_obj.path = file_path
img_obj.description = detailed_description # 保存详细描述
img_obj.timestamp = current_timestamp
img_obj.save()
except Images.DoesNotExist: # type: ignore
Images.create(
emoji_hash=image_hash,
path=file_path,
type="emoji",
description=detailed_description, # 保存详细描述
timestamp=current_timestamp,
)
with get_db_session() as session:
existing_img = session.execute(select(Images).where(
and_(Images.emoji_hash == image_hash, Images.type == "emoji")
)).scalar()
if existing_img:
existing_img.path = file_path
existing_img.description = detailed_description # 保存详细描述
existing_img.timestamp = current_timestamp
else:
new_img = Images(
emoji_hash=image_hash,
path=file_path,
type="emoji",
description=detailed_description, # 保存详细描述
timestamp=current_timestamp,
)
session.add(new_img)
# session.commit() 会在上下文管理器中自动调用
except Exception as e:
logger.error(f"保存到Images表失败: {str(e)}")
except Exception as e:
logger.error(f"保存表情包文件或元数据失败: {str(e)}")
@@ -249,19 +275,19 @@ class ImageManager:
image_hash = hashlib.md5(image_bytes).hexdigest()
# 优先检查Images表中是否已有完整的描述
existing_image = Images.get_or_none(Images.emoji_hash == image_hash)
if existing_image:
# 更新计数
if hasattr(existing_image, "count") and existing_image.count is not None:
existing_image.count += 1
else:
existing_image.count = 1
existing_image.save()
with get_db_session() as session:
existing_image = session.execute(select(Images).where(Images.emoji_hash == image_hash)).scalar()
if existing_image:
# 更新计数
if hasattr(existing_image, "count") and existing_image.count is not None:
existing_image.count += 1
else:
existing_image.count = 1
# 如果已有描述,直接返回
if existing_image.description:
logger.debug(f"[缓存命中] 使用Images表中的图片描述: {existing_image.description[:50]}...")
return f"[图片:{existing_image.description}]"
# 如果已有描述,直接返回
if existing_image.description:
logger.debug(f"[缓存命中] 使用Images表中的图片描述: {existing_image.description[:50]}...")
return f"[图片:{existing_image.description}]"
if cached_description := self._get_description_from_db(image_hash, "image"):
logger.debug(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description[:50]}...")
@@ -300,10 +326,10 @@ class ImageManager:
existing_image.image_id = str(uuid.uuid4())
if not hasattr(existing_image, "vlm_processed") or existing_image.vlm_processed is None:
existing_image.vlm_processed = True
existing_image.save()
session.commit()
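# 'session' is assumed to come from a get_db_session() context opened earlier in this method, outside this hunk.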
logger.debug(f"[数据库] 更新已有图片记录: {image_hash[:8]}...")
else:
Images.create(
new_img = Images(
image_id=str(uuid.uuid4()),
emoji_hash=image_hash,
path=file_path,
@@ -313,6 +339,8 @@ class ImageManager:
vlm_processed=True,
count=1,
)
session.add(new_img)
session.commit()
logger.debug(f"[数据库] 创建新图片记录: {image_hash[:8]}...")
except Exception as e:
logger.error(f"保存图片文件或元数据失败: {str(e)}")
@@ -465,31 +493,32 @@ class ImageManager:
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
with get_db_session() as session:
existing_image = session.execute(select(Images).where(Images.emoji_hash == image_hash)).scalar()
if existing_image:
# 检查是否缺少必要字段,如果缺少则创建新记录
if (
not hasattr(existing_image, "image_id")
or not existing_image.image_id
or not hasattr(existing_image, "count")
or existing_image.count is None
or not hasattr(existing_image, "vlm_processed")
or existing_image.vlm_processed is None
):
logger.debug(f"图片记录缺少必要字段,补全旧记录: {image_hash}")
if not existing_image.image_id:
existing_image.image_id = str(uuid.uuid4())
if existing_image.count is None:
existing_image.count = 0
if existing_image.vlm_processed is None:
existing_image.vlm_processed = False
if existing_image := Images.get_or_none(Images.emoji_hash == image_hash):
# 检查是否缺少必要字段,如果缺少则创建新记录
if (
not hasattr(existing_image, "image_id")
or not existing_image.image_id
or not hasattr(existing_image, "count")
or existing_image.count is None
or not hasattr(existing_image, "vlm_processed")
or existing_image.vlm_processed is None
):
logger.debug(f"图片记录缺少必要字段,补全旧记录: {image_hash}")
if not existing_image.image_id:
existing_image.image_id = str(uuid.uuid4())
if existing_image.count is None:
existing_image.count = 0
if existing_image.vlm_processed is None:
existing_image.vlm_processed = False
existing_image.count += 1
session.commit()
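# Reading existing_image.image_id after commit() assumes the session is still open; with the default expire_on_commit=True the attribute is lazily refreshed.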
return existing_image.image_id, f"[picid:{existing_image.image_id}]"
existing_image.count += 1
existing_image.save()
return existing_image.image_id, f"[picid:{existing_image.image_id}]"
else:
# print(f"图片不存在: {image_hash}")
image_id = str(uuid.uuid4())
# print(f"图片不存在: {image_hash}")
image_id = str(uuid.uuid4())
# 保存新图片
current_timestamp = time.time()
@@ -503,7 +532,7 @@ class ImageManager:
f.write(image_bytes)
# 保存到数据库
Images.create(
new_img = Images(
image_id=image_id,
emoji_hash=image_hash,
path=file_path,
@@ -512,6 +541,8 @@ class ImageManager:
vlm_processed=False,
count=1,
)
session.add(new_img)
session.commit()
# 启动异步VLM处理
asyncio.create_task(self._process_image_with_vlm(image_id, image_base64))
@@ -536,60 +567,64 @@ class ImageManager:
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
with get_db_session() as session:
# 获取当前图片记录
image = session.execute(select(Images).where(Images.image_id == image_id)).scalar()
if image is None:
    # .scalar() returns None when no row matches; the old Images.get raised DoesNotExist here
    logger.warning(f"[VLM异步] 图片记录不存在,跳过处理: {image_id}")
    return
# 获取当前图片记录
image = Images.get(Images.image_id == image_id)
# 优先检查是否已有其他相同哈希的图片记录包含描述
existing_with_description = session.execute(select(Images).where(
and_(
Images.emoji_hash == image_hash,
Images.description.isnot(None),
Images.description != "",
Images.id != image.id
)
)).scalar()
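# isnot(None) and != "" mirror the old Peewee .is_null(False) / != "" filters; .scalar() yields the first match or None.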
if existing_with_description:
logger.debug(f"[缓存复用] 从其他相同图片记录复用描述: {existing_with_description.description[:50]}...")
image.description = existing_with_description.description
image.vlm_processed = True
session.commit()
# 同时保存到ImageDescriptions表作为备用缓存
self._save_description_to_db(image_hash, existing_with_description.description, "image")
return
# 优先检查是否已有其他相同哈希的图片记录包含描述
existing_with_description = Images.get_or_none(
(Images.emoji_hash == image_hash) & (Images.description.is_null(False)) & (Images.description != "")
)
if existing_with_description and existing_with_description.id != image.id:
logger.debug(f"[缓存复用] 从其他相同图片记录复用描述: {existing_with_description.description[:50]}...")
image.description = existing_with_description.description
# 检查ImageDescriptions表的缓存描述
if cached_description := self._get_description_from_db(image_hash, "image"):
logger.debug(f"[缓存复用] 从ImageDescriptions表复用描述: {cached_description[:50]}...")
image.description = cached_description
image.vlm_processed = True
session.commit()
return
# 获取图片格式
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
# 构建prompt
prompt = global_config.custom_prompt.image_prompt
# 获取VLM描述
logger.info(f"[VLM异步调用] 为图片生成描述 (ID: {image_id}, Hash: {image_hash[:8]}...)")
description, _ = await self.vlm.generate_response_for_image(
prompt, image_base64, image_format, temperature=0.4, max_tokens=300
)
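# NOTE: the session opened above is held across this await; workable for a single-writer SQLite setup, but a slow VLM call keeps the connection open.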
if description is None:
logger.warning("VLM未能生成图片描述")
description = "无法生成描述"
if cached_description := self._get_description_from_db(image_hash, "image"):
logger.warning(f"虽然生成了描述,但是找到缓存图片描述: {cached_description}")
description = cached_description
# 更新数据库
image.description = description
image.vlm_processed = True
image.save()
# 同时保存到ImageDescriptions表作为备用缓存
self._save_description_to_db(image_hash, existing_with_description.description, "image")
return
# 检查ImageDescriptions表的缓存描述
if cached_description := self._get_description_from_db(image_hash, "image"):
logger.debug(f"[缓存复用] 从ImageDescriptions表复用描述: {cached_description[:50]}...")
image.description = cached_description
image.vlm_processed = True
image.save()
return
# 保存描述到ImageDescriptions表作为备用缓存
self._save_description_to_db(image_hash, description, "image")
# 获取图片格式
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
# 构建prompt
prompt = global_config.custom_prompt.image_prompt
# 获取VLM描述
logger.info(f"[VLM异步调用] 为图片生成描述 (ID: {image_id}, Hash: {image_hash[:8]}...)")
description, _ = await self.vlm.generate_response_for_image(
prompt, image_base64, image_format, temperature=0.4, max_tokens=300
)
if description is None:
logger.warning("VLM未能生成图片描述")
description = "无法生成描述"
if cached_description := self._get_description_from_db(image_hash, "image"):
logger.warning(f"虽然生成了描述,但是找到缓存图片描述: {cached_description}")
description = cached_description
# 更新数据库
image.description = description
image.vlm_processed = True
image.save()
# 保存描述到ImageDescriptions表作为备用缓存
self._save_description_to_db(image_hash, description, "image")
logger.info(f"[VLM异步完成] 图片描述生成: {description[:50]}...")
logger.info(f"[VLM异步完成] 图片描述生成: {description[:50]}...")
except Exception as e:
logger.error(f"VLM处理图片失败: {str(e)}")