From d53238dfc931e0130f74dc8bbe59c71d79a19197 Mon Sep 17 00:00:00 2001 From: tt-P607 <68868379+tt-P607@users.noreply.github.com> Date: Sat, 20 Sep 2025 11:57:22 +0800 Subject: [PATCH] =?UTF-8?q?4=E6=AC=A1=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/message_receive/message.py | 10 ++- src/chat/message_receive/storage.py | 2 +- src/chat/planner_actions/plan_filter.py | 4 +- src/chat/utils/utils_image.py | 101 ++++++++++++------------ src/llm_models/utils.py | 6 +- src/llm_models/utils_model.py | 6 +- 6 files changed, 69 insertions(+), 60 deletions(-) diff --git a/src/chat/message_receive/message.py b/src/chat/message_receive/message.py index 3bd522c8b..4a7608761 100644 --- a/src/chat/message_receive/message.py +++ b/src/chat/message_receive/message.py @@ -95,17 +95,23 @@ class Message(MessageBase, metaclass=ABCMeta): class MessageRecv(Message): """接收消息类,用于处理从MessageCQ序列化的消息""" - def __init__(self, message_dict: dict[str, Any], message_id: str, chat_stream: "ChatStream", user_info: UserInfo): + def __init__(self, message_dict: dict[str, Any]): """从MessageCQ的字典初始化 Args: message_dict: MessageCQ序列化后的字典 """ - super().__init__(message_id, chat_stream, user_info) + # Manually initialize attributes from MessageBase and Message self.message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {})) self.message_segment = Seg.from_dict(message_dict.get("message_segment", {})) self.raw_message = message_dict.get("raw_message") + + self.chat_stream = None + self.reply = None self.processed_plain_text = message_dict.get("processed_plain_text", "") + self.memorized_times = 0 + + # MessageRecv specific attributes self.is_emoji = False self.has_emoji = False self.is_picid = False diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py index edb007238..4666fae67 100644 --- a/src/chat/message_receive/storage.py +++ b/src/chat/message_receive/storage.py @@ -123,7 +123,7 @@ class MessageStorage: is_picid=is_picid, ) async with get_db_session() as session: - await session.add(new_message) + session.add(new_message) except Exception: logger.exception("存储消息失败") diff --git a/src/chat/planner_actions/plan_filter.py b/src/chat/planner_actions/plan_filter.py index 91237c9cb..19d11bc4e 100644 --- a/src/chat/planner_actions/plan_filter.py +++ b/src/chat/planner_actions/plan_filter.py @@ -124,7 +124,7 @@ class PlanFilter: if plan.mode == ChatMode.PROACTIVE: long_term_memory_block = await self._get_long_term_memory_context() - chat_content_block, message_id_list = build_readable_messages_with_id( + chat_content_block, message_id_list = await build_readable_messages_with_id( messages=[msg.flatten() for msg in plan.chat_history], timestamp_mode="normal", truncate=False, @@ -160,7 +160,7 @@ class PlanFilter: show_actions=True, ) - actions_before_now = get_actions_by_timestamp_with_chat( + actions_before_now = await get_actions_by_timestamp_with_chat( chat_id=plan.chat_id, timestamp_start=time.time() - 3600, timestamp_end=time.time(), diff --git a/src/chat/utils/utils_image.py b/src/chat/utils/utils_image.py index 18adc8a9b..93ec14957 100644 --- a/src/chat/utils/utils_image.py +++ b/src/chat/utils/utils_image.py @@ -69,7 +69,7 @@ class ImageManager: os.makedirs(self.IMAGE_DIR, exist_ok=True) @staticmethod - def _get_description_from_db(image_hash: str, description_type: str) -> Optional[str]: + async def _get_description_from_db(image_hash: str, description_type: str) -> Optional[str]: """从数据库获取图片描述 Args: @@ -80,22 +80,22 @@ class ImageManager: Optional[str]: 描述文本,如果不存在则返回None """ try: - with get_db_session() as session: - record = session.execute( + async with get_db_session() as session: + record = (await session.execute( select(ImageDescriptions).where( and_( ImageDescriptions.image_description_hash == image_hash, ImageDescriptions.type == description_type, ) ) - ).scalar() + )).scalar() return record.description if record else None except Exception as e: logger.error(f"从数据库获取描述失败 (SQLAlchemy): {str(e)}") return None @staticmethod - def _save_description_to_db(image_hash: str, description: str, description_type: str) -> None: + async def _save_description_to_db(image_hash: str, description: str, description_type: str) -> None: """保存图片描述到数据库 Args: @@ -105,16 +105,16 @@ class ImageManager: """ try: current_timestamp = time.time() - with get_db_session() as session: + async with get_db_session() as session: # 查找现有记录 - existing = session.execute( + existing = (await session.execute( select(ImageDescriptions).where( and_( ImageDescriptions.image_description_hash == image_hash, ImageDescriptions.type == description_type, ) ) - ).scalar() + )).scalar() if existing: # 更新现有记录 @@ -129,7 +129,7 @@ class ImageManager: timestamp=current_timestamp, ) session.add(new_desc) - session.commit() + await session.commit() # 会在上下文管理器中自动调用 except Exception as e: logger.error(f"保存描述到数据库失败 (SQLAlchemy): {str(e)}") @@ -175,7 +175,7 @@ class ImageManager: logger.debug(f"查询EmojiManager时出错: {e}") # 查询ImageDescriptions表的缓存描述 - if cached_description := self._get_description_from_db(image_hash, "emoji"): + if cached_description := await self._get_description_from_db(image_hash, "emoji"): logger.info(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description}...") return f"[表情包:{cached_description}]" @@ -239,7 +239,7 @@ class ImageManager: logger.info(f"[emoji识别] 详细描述: {detailed_description}... -> 情感标签: {final_emotion}") - if cached_description := self._get_description_from_db(image_hash, "emoji"): + if cached_description := await self._get_description_from_db(image_hash, "emoji"): logger.warning(f"虽然生成了描述,但是找到缓存表情包描述: {cached_description}") return f"[表情包:{cached_description}]" @@ -261,10 +261,10 @@ class ImageManager: try: from src.common.database.sqlalchemy_models import get_db_session - with get_db_session() as session: - existing_img = session.execute( + async with get_db_session() as session: + existing_img = (await session.execute( select(Images).where(and_(Images.emoji_hash == image_hash, Images.type == "emoji")) - ).scalar() + )).scalar() if existing_img: existing_img.path = file_path @@ -279,7 +279,7 @@ class ImageManager: timestamp=current_timestamp, ) session.add(new_img) - session.commit() + await session.commit() except Exception as e: logger.error(f"保存到Images表失败: {str(e)}") @@ -289,7 +289,7 @@ class ImageManager: logger.debug("偷取表情包功能已关闭,跳过保存。") # 保存最终的情感标签到缓存 (ImageDescriptions表) - self._save_description_to_db(image_hash, final_emotion, "emoji") + await self._save_description_to_db(image_hash, final_emotion, "emoji") return f"[表情包:{final_emotion}]" @@ -306,9 +306,9 @@ class ImageManager: image_bytes = base64.b64decode(image_base64) image_hash = hashlib.md5(image_bytes).hexdigest() - # 优先检查Images表中是否已有完整的描述 - with get_db_session() as session: - existing_image = session.execute(select(Images).where(Images.emoji_hash == image_hash)).scalar() + async with get_db_session() as session: + # 优先检查Images表中是否已有完整的描述 + existing_image = (await session.execute(select(Images).where(Images.emoji_hash == image_hash))).scalar() if existing_image: # 更新计数 if hasattr(existing_image, "count") and existing_image.count is not None: @@ -318,34 +318,34 @@ class ImageManager: # 如果已有描述,直接返回 if existing_image.description: + await session.commit() logger.debug(f"[缓存命中] 使用Images表中的图片描述: {existing_image.description}...") return f"[图片:{existing_image.description}]" - if cached_description := self._get_description_from_db(image_hash, "image"): - logger.debug(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description}...") - return f"[图片:{cached_description}]" + # 如果没有描述,继续在当前会话中操作 + if cached_description := await self._get_description_from_db(image_hash, "image"): + logger.debug(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description}...") + return f"[图片:{cached_description}]" - # 调用AI获取描述 - image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore - prompt = global_config.custom_prompt.image_prompt - logger.info(f"[VLM调用] 为图片生成新描述 (Hash: {image_hash[:8]}...)") - description, _ = await self.vlm.generate_response_for_image( - prompt, image_base64, image_format, temperature=0.4, max_tokens=300 - ) + # 调用AI获取描述 + image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore + prompt = global_config.custom_prompt.image_prompt + logger.info(f"[VLM调用] 为图片生成新描述 (Hash: {image_hash[:8]}...)") + description, _ = await self.vlm.generate_response_for_image( + prompt, image_base64, image_format, temperature=0.4, max_tokens=300 + ) - if description is None: - logger.warning("AI未能生成图片描述") - return "[图片(描述生成失败)]" + if description is None: + logger.warning("AI未能生成图片描述") + return "[图片(描述生成失败)]" - # 保存图片和描述 - current_timestamp = time.time() - filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}" - image_dir = os.path.join(self.IMAGE_DIR, "image") - os.makedirs(image_dir, exist_ok=True) - file_path = os.path.join(image_dir, filename) + # 保存图片和描述 + current_timestamp = time.time() + filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}" + image_dir = os.path.join(self.IMAGE_DIR, "image") + os.makedirs(image_dir, exist_ok=True) + file_path = os.path.join(image_dir, filename) - try: - # 保存文件 with open(file_path, "wb") as f: f.write(image_bytes) @@ -358,7 +358,6 @@ class ImageManager: existing_image.image_id = str(uuid.uuid4()) if not hasattr(existing_image, "vlm_processed") or existing_image.vlm_processed is None: existing_image.vlm_processed = True - logger.debug(f"[数据库] 更新已有图片记录: {image_hash[:8]}...") else: new_img = Images( @@ -372,13 +371,15 @@ class ImageManager: count=1, ) session.add(new_img) - logger.debug(f"[数据库] 创建新图片记录: {image_hash[:8]}...") - except Exception as e: - logger.error(f"保存图片文件或元数据失败: {str(e)}") - # 保存描述到ImageDescriptions表作为备用缓存 - self._save_description_to_db(image_hash, description, "image") + await session.commit() + + # 保存描述到ImageDescriptions表作为备用缓存 + await self._save_description_to_db(image_hash, description, "image") + + logger.info(f"[VLM完成] 图片描述生成: {description}...") + return f"[图片:{description}]" logger.info(f"[VLM完成] 图片描述生成: {description}...") return f"[图片:{description}]" @@ -525,8 +526,8 @@ class ImageManager: image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii") image_bytes = base64.b64decode(image_base64) image_hash = hashlib.md5(image_bytes).hexdigest() - with get_db_session() as session: - existing_image = session.execute(select(Images).where(Images.emoji_hash == image_hash)).scalar() + async with get_db_session() as session: + existing_image = (await session.execute(select(Images).where(Images.emoji_hash == image_hash))).scalar() if existing_image: # 检查是否缺少必要字段,如果缺少则创建新记录 if ( @@ -546,6 +547,7 @@ class ImageManager: existing_image.vlm_processed = False existing_image.count += 1 + await session.commit() # 如果已有描述,直接返回 if existing_image.description and existing_image.description.strip(): @@ -556,6 +558,7 @@ class ImageManager: # 更新数据库中的描述 existing_image.description = description.replace("[图片:", "").replace("]", "") existing_image.vlm_processed = True + await session.commit() return existing_image.image_id, f"[picid:{existing_image.image_id}]" # print(f"图片不存在: {image_hash}") @@ -588,7 +591,7 @@ class ImageManager: count=1, ) session.add(new_img) - session.commit() + await session.commit() return image_id, f"[picid:{image_id}]" diff --git a/src/llm_models/utils.py b/src/llm_models/utils.py index 48f7be5f2..bf23f144a 100644 --- a/src/llm_models/utils.py +++ b/src/llm_models/utils.py @@ -146,7 +146,7 @@ class LLMUsageRecorder: """ @staticmethod - def record_usage_to_database( + async def record_usage_to_database( model_info: ModelInfo, model_usage: UsageRecord, user_id: str, @@ -161,7 +161,7 @@ class LLMUsageRecorder: session = None try: # 使用 SQLAlchemy 会话创建记录 - with get_db_session() as session: + async with get_db_session() as session: usage_record = LLMUsage( model_name=model_info.model_identifier, model_assign_name=model_info.name, @@ -179,7 +179,7 @@ class LLMUsageRecorder: ) session.add(usage_record) - session.commit() + await session.commit() logger.debug( f"Token使用情况 - 模型: {model_usage.model_name}, " diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index f223d1913..7662d0fe9 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -202,7 +202,7 @@ class LLMRequest: content, extracted_reasoning = self._extract_reasoning(content) reasoning_content = extracted_reasoning if usage := response.usage: - llm_usage_recorder.record_usage_to_database( + await llm_usage_recorder.record_usage_to_database( model_info=model_info, model_usage=usage, user_id="system", @@ -368,7 +368,7 @@ class LLMRequest: # 成功获取响应 if usage := response.usage: - llm_usage_recorder.record_usage_to_database( + await llm_usage_recorder.record_usage_to_database( model_info=model_info, model_usage=usage, time_cost=time.time() - start_time, @@ -443,7 +443,7 @@ class LLMRequest: embedding = response.embedding if usage := response.usage: - llm_usage_recorder.record_usage_to_database( + await llm_usage_recorder.record_usage_to_database( model_info=model_info, time_cost=time.time() - start_time, model_usage=usage,