4次修改

This commit is contained in:
tt-P607
2025-09-20 11:57:22 +08:00
committed by Windpicker-owo
parent 93542cadef
commit d53238dfc9
6 changed files with 69 additions and 60 deletions

View File

@@ -95,17 +95,23 @@ class Message(MessageBase, metaclass=ABCMeta):
class MessageRecv(Message): class MessageRecv(Message):
"""接收消息类用于处理从MessageCQ序列化的消息""" """接收消息类用于处理从MessageCQ序列化的消息"""
def __init__(self, message_dict: dict[str, Any], message_id: str, chat_stream: "ChatStream", user_info: UserInfo): def __init__(self, message_dict: dict[str, Any]):
"""从MessageCQ的字典初始化 """从MessageCQ的字典初始化
Args: Args:
message_dict: MessageCQ序列化后的字典 message_dict: MessageCQ序列化后的字典
""" """
super().__init__(message_id, chat_stream, user_info) # Manually initialize attributes from MessageBase and Message
self.message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {})) self.message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {}))
self.message_segment = Seg.from_dict(message_dict.get("message_segment", {})) self.message_segment = Seg.from_dict(message_dict.get("message_segment", {}))
self.raw_message = message_dict.get("raw_message") self.raw_message = message_dict.get("raw_message")
self.chat_stream = None
self.reply = None
self.processed_plain_text = message_dict.get("processed_plain_text", "") self.processed_plain_text = message_dict.get("processed_plain_text", "")
self.memorized_times = 0
# MessageRecv specific attributes
self.is_emoji = False self.is_emoji = False
self.has_emoji = False self.has_emoji = False
self.is_picid = False self.is_picid = False

View File

@@ -123,7 +123,7 @@ class MessageStorage:
is_picid=is_picid, is_picid=is_picid,
) )
async with get_db_session() as session: async with get_db_session() as session:
await session.add(new_message) session.add(new_message)
except Exception: except Exception:
logger.exception("存储消息失败") logger.exception("存储消息失败")

View File

@@ -124,7 +124,7 @@ class PlanFilter:
if plan.mode == ChatMode.PROACTIVE: if plan.mode == ChatMode.PROACTIVE:
long_term_memory_block = await self._get_long_term_memory_context() long_term_memory_block = await self._get_long_term_memory_context()
chat_content_block, message_id_list = build_readable_messages_with_id( chat_content_block, message_id_list = await build_readable_messages_with_id(
messages=[msg.flatten() for msg in plan.chat_history], messages=[msg.flatten() for msg in plan.chat_history],
timestamp_mode="normal", timestamp_mode="normal",
truncate=False, truncate=False,
@@ -160,7 +160,7 @@ class PlanFilter:
show_actions=True, show_actions=True,
) )
actions_before_now = get_actions_by_timestamp_with_chat( actions_before_now = await get_actions_by_timestamp_with_chat(
chat_id=plan.chat_id, chat_id=plan.chat_id,
timestamp_start=time.time() - 3600, timestamp_start=time.time() - 3600,
timestamp_end=time.time(), timestamp_end=time.time(),

View File

@@ -69,7 +69,7 @@ class ImageManager:
os.makedirs(self.IMAGE_DIR, exist_ok=True) os.makedirs(self.IMAGE_DIR, exist_ok=True)
@staticmethod @staticmethod
def _get_description_from_db(image_hash: str, description_type: str) -> Optional[str]: async def _get_description_from_db(image_hash: str, description_type: str) -> Optional[str]:
"""从数据库获取图片描述 """从数据库获取图片描述
Args: Args:
@@ -80,22 +80,22 @@ class ImageManager:
Optional[str]: 描述文本如果不存在则返回None Optional[str]: 描述文本如果不存在则返回None
""" """
try: try:
with get_db_session() as session: async with get_db_session() as session:
record = session.execute( record = (await session.execute(
select(ImageDescriptions).where( select(ImageDescriptions).where(
and_( and_(
ImageDescriptions.image_description_hash == image_hash, ImageDescriptions.image_description_hash == image_hash,
ImageDescriptions.type == description_type, ImageDescriptions.type == description_type,
) )
) )
).scalar() )).scalar()
return record.description if record else None return record.description if record else None
except Exception as e: except Exception as e:
logger.error(f"从数据库获取描述失败 (SQLAlchemy): {str(e)}") logger.error(f"从数据库获取描述失败 (SQLAlchemy): {str(e)}")
return None return None
@staticmethod @staticmethod
def _save_description_to_db(image_hash: str, description: str, description_type: str) -> None: async def _save_description_to_db(image_hash: str, description: str, description_type: str) -> None:
"""保存图片描述到数据库 """保存图片描述到数据库
Args: Args:
@@ -105,16 +105,16 @@ class ImageManager:
""" """
try: try:
current_timestamp = time.time() current_timestamp = time.time()
with get_db_session() as session: async with get_db_session() as session:
# 查找现有记录 # 查找现有记录
existing = session.execute( existing = (await session.execute(
select(ImageDescriptions).where( select(ImageDescriptions).where(
and_( and_(
ImageDescriptions.image_description_hash == image_hash, ImageDescriptions.image_description_hash == image_hash,
ImageDescriptions.type == description_type, ImageDescriptions.type == description_type,
) )
) )
).scalar() )).scalar()
if existing: if existing:
# 更新现有记录 # 更新现有记录
@@ -129,7 +129,7 @@ class ImageManager:
timestamp=current_timestamp, timestamp=current_timestamp,
) )
session.add(new_desc) session.add(new_desc)
session.commit() await session.commit()
# 会在上下文管理器中自动调用 # 会在上下文管理器中自动调用
except Exception as e: except Exception as e:
logger.error(f"保存描述到数据库失败 (SQLAlchemy): {str(e)}") logger.error(f"保存描述到数据库失败 (SQLAlchemy): {str(e)}")
@@ -175,7 +175,7 @@ class ImageManager:
logger.debug(f"查询EmojiManager时出错: {e}") logger.debug(f"查询EmojiManager时出错: {e}")
# 查询ImageDescriptions表的缓存描述 # 查询ImageDescriptions表的缓存描述
if cached_description := self._get_description_from_db(image_hash, "emoji"): if cached_description := await self._get_description_from_db(image_hash, "emoji"):
logger.info(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description}...") logger.info(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description}...")
return f"[表情包:{cached_description}]" return f"[表情包:{cached_description}]"
@@ -239,7 +239,7 @@ class ImageManager:
logger.info(f"[emoji识别] 详细描述: {detailed_description}... -> 情感标签: {final_emotion}") logger.info(f"[emoji识别] 详细描述: {detailed_description}... -> 情感标签: {final_emotion}")
if cached_description := self._get_description_from_db(image_hash, "emoji"): if cached_description := await self._get_description_from_db(image_hash, "emoji"):
logger.warning(f"虽然生成了描述,但是找到缓存表情包描述: {cached_description}") logger.warning(f"虽然生成了描述,但是找到缓存表情包描述: {cached_description}")
return f"[表情包:{cached_description}]" return f"[表情包:{cached_description}]"
@@ -261,10 +261,10 @@ class ImageManager:
try: try:
from src.common.database.sqlalchemy_models import get_db_session from src.common.database.sqlalchemy_models import get_db_session
with get_db_session() as session: async with get_db_session() as session:
existing_img = session.execute( existing_img = (await session.execute(
select(Images).where(and_(Images.emoji_hash == image_hash, Images.type == "emoji")) select(Images).where(and_(Images.emoji_hash == image_hash, Images.type == "emoji"))
).scalar() )).scalar()
if existing_img: if existing_img:
existing_img.path = file_path existing_img.path = file_path
@@ -279,7 +279,7 @@ class ImageManager:
timestamp=current_timestamp, timestamp=current_timestamp,
) )
session.add(new_img) session.add(new_img)
session.commit() await session.commit()
except Exception as e: except Exception as e:
logger.error(f"保存到Images表失败: {str(e)}") logger.error(f"保存到Images表失败: {str(e)}")
@@ -289,7 +289,7 @@ class ImageManager:
logger.debug("偷取表情包功能已关闭,跳过保存。") logger.debug("偷取表情包功能已关闭,跳过保存。")
# 保存最终的情感标签到缓存 (ImageDescriptions表) # 保存最终的情感标签到缓存 (ImageDescriptions表)
self._save_description_to_db(image_hash, final_emotion, "emoji") await self._save_description_to_db(image_hash, final_emotion, "emoji")
return f"[表情包:{final_emotion}]" return f"[表情包:{final_emotion}]"
@@ -306,9 +306,9 @@ class ImageManager:
image_bytes = base64.b64decode(image_base64) image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest() image_hash = hashlib.md5(image_bytes).hexdigest()
# 优先检查Images表中是否已有完整的描述 async with get_db_session() as session:
with get_db_session() as session: # 优先检查Images表中是否已有完整的描述
existing_image = session.execute(select(Images).where(Images.emoji_hash == image_hash)).scalar() existing_image = (await session.execute(select(Images).where(Images.emoji_hash == image_hash))).scalar()
if existing_image: if existing_image:
# 更新计数 # 更新计数
if hasattr(existing_image, "count") and existing_image.count is not None: if hasattr(existing_image, "count") and existing_image.count is not None:
@@ -318,34 +318,34 @@ class ImageManager:
# 如果已有描述,直接返回 # 如果已有描述,直接返回
if existing_image.description: if existing_image.description:
await session.commit()
logger.debug(f"[缓存命中] 使用Images表中的图片描述: {existing_image.description}...") logger.debug(f"[缓存命中] 使用Images表中的图片描述: {existing_image.description}...")
return f"[图片:{existing_image.description}]" return f"[图片:{existing_image.description}]"
if cached_description := self._get_description_from_db(image_hash, "image"): # 如果没有描述,继续在当前会话中操作
logger.debug(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description}...") if cached_description := await self._get_description_from_db(image_hash, "image"):
return f"[图片:{cached_description}]" logger.debug(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description}...")
return f"[图片:{cached_description}]"
# 调用AI获取描述 # 调用AI获取描述
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
prompt = global_config.custom_prompt.image_prompt prompt = global_config.custom_prompt.image_prompt
logger.info(f"[VLM调用] 为图片生成新描述 (Hash: {image_hash[:8]}...)") logger.info(f"[VLM调用] 为图片生成新描述 (Hash: {image_hash[:8]}...)")
description, _ = await self.vlm.generate_response_for_image( description, _ = await self.vlm.generate_response_for_image(
prompt, image_base64, image_format, temperature=0.4, max_tokens=300 prompt, image_base64, image_format, temperature=0.4, max_tokens=300
) )
if description is None: if description is None:
logger.warning("AI未能生成图片描述") logger.warning("AI未能生成图片描述")
return "[图片(描述生成失败)]" return "[图片(描述生成失败)]"
# 保存图片和描述 # 保存图片和描述
current_timestamp = time.time() current_timestamp = time.time()
filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}" filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}"
image_dir = os.path.join(self.IMAGE_DIR, "image") image_dir = os.path.join(self.IMAGE_DIR, "image")
os.makedirs(image_dir, exist_ok=True) os.makedirs(image_dir, exist_ok=True)
file_path = os.path.join(image_dir, filename) file_path = os.path.join(image_dir, filename)
try:
# 保存文件
with open(file_path, "wb") as f: with open(file_path, "wb") as f:
f.write(image_bytes) f.write(image_bytes)
@@ -358,7 +358,6 @@ class ImageManager:
existing_image.image_id = str(uuid.uuid4()) existing_image.image_id = str(uuid.uuid4())
if not hasattr(existing_image, "vlm_processed") or existing_image.vlm_processed is None: if not hasattr(existing_image, "vlm_processed") or existing_image.vlm_processed is None:
existing_image.vlm_processed = True existing_image.vlm_processed = True
logger.debug(f"[数据库] 更新已有图片记录: {image_hash[:8]}...") logger.debug(f"[数据库] 更新已有图片记录: {image_hash[:8]}...")
else: else:
new_img = Images( new_img = Images(
@@ -372,13 +371,15 @@ class ImageManager:
count=1, count=1,
) )
session.add(new_img) session.add(new_img)
logger.debug(f"[数据库] 创建新图片记录: {image_hash[:8]}...") logger.debug(f"[数据库] 创建新图片记录: {image_hash[:8]}...")
except Exception as e:
logger.error(f"保存图片文件或元数据失败: {str(e)}")
# 保存描述到ImageDescriptions表作为备用缓存 await session.commit()
self._save_description_to_db(image_hash, description, "image")
# 保存描述到ImageDescriptions表作为备用缓存
await self._save_description_to_db(image_hash, description, "image")
logger.info(f"[VLM完成] 图片描述生成: {description}...")
return f"[图片:{description}]"
logger.info(f"[VLM完成] 图片描述生成: {description}...") logger.info(f"[VLM完成] 图片描述生成: {description}...")
return f"[图片:{description}]" return f"[图片:{description}]"
@@ -525,8 +526,8 @@ class ImageManager:
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii") image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
image_bytes = base64.b64decode(image_base64) image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest() image_hash = hashlib.md5(image_bytes).hexdigest()
with get_db_session() as session: async with get_db_session() as session:
existing_image = session.execute(select(Images).where(Images.emoji_hash == image_hash)).scalar() existing_image = (await session.execute(select(Images).where(Images.emoji_hash == image_hash))).scalar()
if existing_image: if existing_image:
# 检查是否缺少必要字段,如果缺少则创建新记录 # 检查是否缺少必要字段,如果缺少则创建新记录
if ( if (
@@ -546,6 +547,7 @@ class ImageManager:
existing_image.vlm_processed = False existing_image.vlm_processed = False
existing_image.count += 1 existing_image.count += 1
await session.commit()
# 如果已有描述,直接返回 # 如果已有描述,直接返回
if existing_image.description and existing_image.description.strip(): if existing_image.description and existing_image.description.strip():
@@ -556,6 +558,7 @@ class ImageManager:
# 更新数据库中的描述 # 更新数据库中的描述
existing_image.description = description.replace("[图片:", "").replace("]", "") existing_image.description = description.replace("[图片:", "").replace("]", "")
existing_image.vlm_processed = True existing_image.vlm_processed = True
await session.commit()
return existing_image.image_id, f"[picid:{existing_image.image_id}]" return existing_image.image_id, f"[picid:{existing_image.image_id}]"
# print(f"图片不存在: {image_hash}") # print(f"图片不存在: {image_hash}")
@@ -588,7 +591,7 @@ class ImageManager:
count=1, count=1,
) )
session.add(new_img) session.add(new_img)
session.commit() await session.commit()
return image_id, f"[picid:{image_id}]" return image_id, f"[picid:{image_id}]"

View File

@@ -146,7 +146,7 @@ class LLMUsageRecorder:
""" """
@staticmethod @staticmethod
def record_usage_to_database( async def record_usage_to_database(
model_info: ModelInfo, model_info: ModelInfo,
model_usage: UsageRecord, model_usage: UsageRecord,
user_id: str, user_id: str,
@@ -161,7 +161,7 @@ class LLMUsageRecorder:
session = None session = None
try: try:
# 使用 SQLAlchemy 会话创建记录 # 使用 SQLAlchemy 会话创建记录
with get_db_session() as session: async with get_db_session() as session:
usage_record = LLMUsage( usage_record = LLMUsage(
model_name=model_info.model_identifier, model_name=model_info.model_identifier,
model_assign_name=model_info.name, model_assign_name=model_info.name,
@@ -179,7 +179,7 @@ class LLMUsageRecorder:
) )
session.add(usage_record) session.add(usage_record)
session.commit() await session.commit()
logger.debug( logger.debug(
f"Token使用情况 - 模型: {model_usage.model_name}, " f"Token使用情况 - 模型: {model_usage.model_name}, "

View File

@@ -202,7 +202,7 @@ class LLMRequest:
content, extracted_reasoning = self._extract_reasoning(content) content, extracted_reasoning = self._extract_reasoning(content)
reasoning_content = extracted_reasoning reasoning_content = extracted_reasoning
if usage := response.usage: if usage := response.usage:
llm_usage_recorder.record_usage_to_database( await llm_usage_recorder.record_usage_to_database(
model_info=model_info, model_info=model_info,
model_usage=usage, model_usage=usage,
user_id="system", user_id="system",
@@ -368,7 +368,7 @@ class LLMRequest:
# 成功获取响应 # 成功获取响应
if usage := response.usage: if usage := response.usage:
llm_usage_recorder.record_usage_to_database( await llm_usage_recorder.record_usage_to_database(
model_info=model_info, model_info=model_info,
model_usage=usage, model_usage=usage,
time_cost=time.time() - start_time, time_cost=time.time() - start_time,
@@ -443,7 +443,7 @@ class LLMRequest:
embedding = response.embedding embedding = response.embedding
if usage := response.usage: if usage := response.usage:
llm_usage_recorder.record_usage_to_database( await llm_usage_recorder.record_usage_to_database(
model_info=model_info, model_info=model_info,
time_cost=time.time() - start_time, time_cost=time.time() - start_time,
model_usage=usage, model_usage=usage,