4次修改

This commit is contained in:
tt-P607
2025-09-20 11:57:22 +08:00
parent 898208f425
commit a8992cdd51
6 changed files with 69 additions and 60 deletions

View File

@@ -95,17 +95,23 @@ class Message(MessageBase, metaclass=ABCMeta):
class MessageRecv(Message):
"""接收消息类用于处理从MessageCQ序列化的消息"""
def __init__(self, message_dict: dict[str, Any], message_id: str, chat_stream: "ChatStream", user_info: UserInfo):
def __init__(self, message_dict: dict[str, Any]):
"""从MessageCQ的字典初始化
Args:
message_dict: MessageCQ序列化后的字典
"""
super().__init__(message_id, chat_stream, user_info)
# Manually initialize attributes from MessageBase and Message
self.message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {}))
self.message_segment = Seg.from_dict(message_dict.get("message_segment", {}))
self.raw_message = message_dict.get("raw_message")
self.chat_stream = None
self.reply = None
self.processed_plain_text = message_dict.get("processed_plain_text", "")
self.memorized_times = 0
# MessageRecv specific attributes
self.is_emoji = False
self.has_emoji = False
self.is_picid = False

View File

@@ -122,7 +122,7 @@ class MessageStorage:
is_picid=is_picid,
)
async with get_db_session() as session:
await session.add(new_message)
session.add(new_message)
except Exception:
logger.exception("存储消息失败")

View File

@@ -124,7 +124,7 @@ class PlanFilter:
if plan.mode == ChatMode.PROACTIVE:
long_term_memory_block = await self._get_long_term_memory_context()
chat_content_block, message_id_list = build_readable_messages_with_id(
chat_content_block, message_id_list = await build_readable_messages_with_id(
messages=[msg.flatten() for msg in plan.chat_history],
timestamp_mode="normal",
truncate=False,
@@ -160,7 +160,7 @@ class PlanFilter:
show_actions=True,
)
actions_before_now = get_actions_by_timestamp_with_chat(
actions_before_now = await get_actions_by_timestamp_with_chat(
chat_id=plan.chat_id,
timestamp_start=time.time() - 3600,
timestamp_end=time.time(),

View File

@@ -69,7 +69,7 @@ class ImageManager:
os.makedirs(self.IMAGE_DIR, exist_ok=True)
@staticmethod
def _get_description_from_db(image_hash: str, description_type: str) -> Optional[str]:
async def _get_description_from_db(image_hash: str, description_type: str) -> Optional[str]:
"""从数据库获取图片描述
Args:
@@ -80,22 +80,22 @@ class ImageManager:
Optional[str]: 描述文本如果不存在则返回None
"""
try:
with get_db_session() as session:
record = session.execute(
async with get_db_session() as session:
record = (await session.execute(
select(ImageDescriptions).where(
and_(
ImageDescriptions.image_description_hash == image_hash,
ImageDescriptions.type == description_type,
)
)
).scalar()
)).scalar()
return record.description if record else None
except Exception as e:
logger.error(f"从数据库获取描述失败 (SQLAlchemy): {str(e)}")
return None
@staticmethod
def _save_description_to_db(image_hash: str, description: str, description_type: str) -> None:
async def _save_description_to_db(image_hash: str, description: str, description_type: str) -> None:
"""保存图片描述到数据库
Args:
@@ -105,16 +105,16 @@ class ImageManager:
"""
try:
current_timestamp = time.time()
with get_db_session() as session:
async with get_db_session() as session:
# 查找现有记录
existing = session.execute(
existing = (await session.execute(
select(ImageDescriptions).where(
and_(
ImageDescriptions.image_description_hash == image_hash,
ImageDescriptions.type == description_type,
)
)
).scalar()
)).scalar()
if existing:
# 更新现有记录
@@ -129,7 +129,7 @@ class ImageManager:
timestamp=current_timestamp,
)
session.add(new_desc)
session.commit()
await session.commit()
# 会在上下文管理器中自动调用
except Exception as e:
logger.error(f"保存描述到数据库失败 (SQLAlchemy): {str(e)}")
@@ -175,7 +175,7 @@ class ImageManager:
logger.debug(f"查询EmojiManager时出错: {e}")
# 查询ImageDescriptions表的缓存描述
if cached_description := self._get_description_from_db(image_hash, "emoji"):
if cached_description := await self._get_description_from_db(image_hash, "emoji"):
logger.info(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description}...")
return f"[表情包:{cached_description}]"
@@ -239,7 +239,7 @@ class ImageManager:
logger.info(f"[emoji识别] 详细描述: {detailed_description}... -> 情感标签: {final_emotion}")
if cached_description := self._get_description_from_db(image_hash, "emoji"):
if cached_description := await self._get_description_from_db(image_hash, "emoji"):
logger.warning(f"虽然生成了描述,但是找到缓存表情包描述: {cached_description}")
return f"[表情包:{cached_description}]"
@@ -261,10 +261,10 @@ class ImageManager:
try:
from src.common.database.sqlalchemy_models import get_db_session
with get_db_session() as session:
existing_img = session.execute(
async with get_db_session() as session:
existing_img = (await session.execute(
select(Images).where(and_(Images.emoji_hash == image_hash, Images.type == "emoji"))
).scalar()
)).scalar()
if existing_img:
existing_img.path = file_path
@@ -279,7 +279,7 @@ class ImageManager:
timestamp=current_timestamp,
)
session.add(new_img)
session.commit()
await session.commit()
except Exception as e:
logger.error(f"保存到Images表失败: {str(e)}")
@@ -289,7 +289,7 @@ class ImageManager:
logger.debug("偷取表情包功能已关闭,跳过保存。")
# 保存最终的情感标签到缓存 (ImageDescriptions表)
self._save_description_to_db(image_hash, final_emotion, "emoji")
await self._save_description_to_db(image_hash, final_emotion, "emoji")
return f"[表情包:{final_emotion}]"
@@ -306,9 +306,9 @@ class ImageManager:
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
# 优先检查Images表中是否已有完整的描述
with get_db_session() as session:
existing_image = session.execute(select(Images).where(Images.emoji_hash == image_hash)).scalar()
async with get_db_session() as session:
# 优先检查Images表中是否已有完整的描述
existing_image = (await session.execute(select(Images).where(Images.emoji_hash == image_hash))).scalar()
if existing_image:
# 更新计数
if hasattr(existing_image, "count") and existing_image.count is not None:
@@ -318,34 +318,34 @@ class ImageManager:
# 如果已有描述,直接返回
if existing_image.description:
await session.commit()
logger.debug(f"[缓存命中] 使用Images表中的图片描述: {existing_image.description}...")
return f"[图片:{existing_image.description}]"
if cached_description := self._get_description_from_db(image_hash, "image"):
logger.debug(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description}...")
return f"[图片:{cached_description}]"
# 如果没有描述,继续在当前会话中操作
if cached_description := await self._get_description_from_db(image_hash, "image"):
logger.debug(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description}...")
return f"[图片:{cached_description}]"
# 调用AI获取描述
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
prompt = global_config.custom_prompt.image_prompt
logger.info(f"[VLM调用] 为图片生成新描述 (Hash: {image_hash[:8]}...)")
description, _ = await self.vlm.generate_response_for_image(
prompt, image_base64, image_format, temperature=0.4, max_tokens=300
)
# 调用AI获取描述
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
prompt = global_config.custom_prompt.image_prompt
logger.info(f"[VLM调用] 为图片生成新描述 (Hash: {image_hash[:8]}...)")
description, _ = await self.vlm.generate_response_for_image(
prompt, image_base64, image_format, temperature=0.4, max_tokens=300
)
if description is None:
logger.warning("AI未能生成图片描述")
return "[图片(描述生成失败)]"
if description is None:
logger.warning("AI未能生成图片描述")
return "[图片(描述生成失败)]"
# 保存图片和描述
current_timestamp = time.time()
filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}"
image_dir = os.path.join(self.IMAGE_DIR, "image")
os.makedirs(image_dir, exist_ok=True)
file_path = os.path.join(image_dir, filename)
# 保存图片和描述
current_timestamp = time.time()
filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}"
image_dir = os.path.join(self.IMAGE_DIR, "image")
os.makedirs(image_dir, exist_ok=True)
file_path = os.path.join(image_dir, filename)
try:
# 保存文件
with open(file_path, "wb") as f:
f.write(image_bytes)
@@ -358,7 +358,6 @@ class ImageManager:
existing_image.image_id = str(uuid.uuid4())
if not hasattr(existing_image, "vlm_processed") or existing_image.vlm_processed is None:
existing_image.vlm_processed = True
logger.debug(f"[数据库] 更新已有图片记录: {image_hash[:8]}...")
else:
new_img = Images(
@@ -372,13 +371,15 @@ class ImageManager:
count=1,
)
session.add(new_img)
logger.debug(f"[数据库] 创建新图片记录: {image_hash[:8]}...")
except Exception as e:
logger.error(f"保存图片文件或元数据失败: {str(e)}")
# 保存描述到ImageDescriptions表作为备用缓存
self._save_description_to_db(image_hash, description, "image")
await session.commit()
# 保存描述到ImageDescriptions表作为备用缓存
await self._save_description_to_db(image_hash, description, "image")
logger.info(f"[VLM完成] 图片描述生成: {description}...")
return f"[图片:{description}]"
logger.info(f"[VLM完成] 图片描述生成: {description}...")
return f"[图片:{description}]"
@@ -525,8 +526,8 @@ class ImageManager:
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
with get_db_session() as session:
existing_image = session.execute(select(Images).where(Images.emoji_hash == image_hash)).scalar()
async with get_db_session() as session:
existing_image = (await session.execute(select(Images).where(Images.emoji_hash == image_hash))).scalar()
if existing_image:
# 检查是否缺少必要字段,如果缺少则创建新记录
if (
@@ -546,6 +547,7 @@ class ImageManager:
existing_image.vlm_processed = False
existing_image.count += 1
await session.commit()
# 如果已有描述,直接返回
if existing_image.description and existing_image.description.strip():
@@ -556,6 +558,7 @@ class ImageManager:
# 更新数据库中的描述
existing_image.description = description.replace("[图片:", "").replace("]", "")
existing_image.vlm_processed = True
await session.commit()
return existing_image.image_id, f"[picid:{existing_image.image_id}]"
# print(f"图片不存在: {image_hash}")
@@ -588,7 +591,7 @@ class ImageManager:
count=1,
)
session.add(new_img)
session.commit()
await session.commit()
return image_id, f"[picid:{image_id}]"

View File

@@ -146,7 +146,7 @@ class LLMUsageRecorder:
"""
@staticmethod
def record_usage_to_database(
async def record_usage_to_database(
model_info: ModelInfo,
model_usage: UsageRecord,
user_id: str,
@@ -161,7 +161,7 @@ class LLMUsageRecorder:
session = None
try:
# 使用 SQLAlchemy 会话创建记录
with get_db_session() as session:
async with get_db_session() as session:
usage_record = LLMUsage(
model_name=model_info.model_identifier,
model_assign_name=model_info.name,
@@ -179,7 +179,7 @@ class LLMUsageRecorder:
)
session.add(usage_record)
session.commit()
await session.commit()
logger.debug(
f"Token使用情况 - 模型: {model_usage.model_name}, "

View File

@@ -202,7 +202,7 @@ class LLMRequest:
content, extracted_reasoning = self._extract_reasoning(content)
reasoning_content = extracted_reasoning
if usage := response.usage:
llm_usage_recorder.record_usage_to_database(
await llm_usage_recorder.record_usage_to_database(
model_info=model_info,
model_usage=usage,
user_id="system",
@@ -367,7 +367,7 @@ class LLMRequest:
# 成功获取响应
if usage := response.usage:
llm_usage_recorder.record_usage_to_database(
await llm_usage_recorder.record_usage_to_database(
model_info=model_info,
model_usage=usage,
time_cost=time.time() - start_time,
@@ -442,7 +442,7 @@ class LLMRequest:
embedding = response.embedding
if usage := response.usage:
llm_usage_recorder.record_usage_to_database(
await llm_usage_recorder.record_usage_to_database(
model_info=model_info,
time_cost=time.time() - start_time,
model_usage=usage,