4次修改
This commit is contained in:
@@ -95,17 +95,23 @@ class Message(MessageBase, metaclass=ABCMeta):
|
|||||||
class MessageRecv(Message):
|
class MessageRecv(Message):
|
||||||
"""接收消息类,用于处理从MessageCQ序列化的消息"""
|
"""接收消息类,用于处理从MessageCQ序列化的消息"""
|
||||||
|
|
||||||
def __init__(self, message_dict: dict[str, Any], message_id: str, chat_stream: "ChatStream", user_info: UserInfo):
|
def __init__(self, message_dict: dict[str, Any]):
|
||||||
"""从MessageCQ的字典初始化
|
"""从MessageCQ的字典初始化
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
message_dict: MessageCQ序列化后的字典
|
message_dict: MessageCQ序列化后的字典
|
||||||
"""
|
"""
|
||||||
super().__init__(message_id, chat_stream, user_info)
|
# Manually initialize attributes from MessageBase and Message
|
||||||
self.message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {}))
|
self.message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {}))
|
||||||
self.message_segment = Seg.from_dict(message_dict.get("message_segment", {}))
|
self.message_segment = Seg.from_dict(message_dict.get("message_segment", {}))
|
||||||
self.raw_message = message_dict.get("raw_message")
|
self.raw_message = message_dict.get("raw_message")
|
||||||
|
|
||||||
|
self.chat_stream = None
|
||||||
|
self.reply = None
|
||||||
self.processed_plain_text = message_dict.get("processed_plain_text", "")
|
self.processed_plain_text = message_dict.get("processed_plain_text", "")
|
||||||
|
self.memorized_times = 0
|
||||||
|
|
||||||
|
# MessageRecv specific attributes
|
||||||
self.is_emoji = False
|
self.is_emoji = False
|
||||||
self.has_emoji = False
|
self.has_emoji = False
|
||||||
self.is_picid = False
|
self.is_picid = False
|
||||||
|
|||||||
@@ -123,7 +123,7 @@ class MessageStorage:
|
|||||||
is_picid=is_picid,
|
is_picid=is_picid,
|
||||||
)
|
)
|
||||||
async with get_db_session() as session:
|
async with get_db_session() as session:
|
||||||
await session.add(new_message)
|
session.add(new_message)
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("存储消息失败")
|
logger.exception("存储消息失败")
|
||||||
|
|||||||
@@ -124,7 +124,7 @@ class PlanFilter:
|
|||||||
if plan.mode == ChatMode.PROACTIVE:
|
if plan.mode == ChatMode.PROACTIVE:
|
||||||
long_term_memory_block = await self._get_long_term_memory_context()
|
long_term_memory_block = await self._get_long_term_memory_context()
|
||||||
|
|
||||||
chat_content_block, message_id_list = build_readable_messages_with_id(
|
chat_content_block, message_id_list = await build_readable_messages_with_id(
|
||||||
messages=[msg.flatten() for msg in plan.chat_history],
|
messages=[msg.flatten() for msg in plan.chat_history],
|
||||||
timestamp_mode="normal",
|
timestamp_mode="normal",
|
||||||
truncate=False,
|
truncate=False,
|
||||||
@@ -160,7 +160,7 @@ class PlanFilter:
|
|||||||
show_actions=True,
|
show_actions=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
actions_before_now = get_actions_by_timestamp_with_chat(
|
actions_before_now = await get_actions_by_timestamp_with_chat(
|
||||||
chat_id=plan.chat_id,
|
chat_id=plan.chat_id,
|
||||||
timestamp_start=time.time() - 3600,
|
timestamp_start=time.time() - 3600,
|
||||||
timestamp_end=time.time(),
|
timestamp_end=time.time(),
|
||||||
|
|||||||
@@ -69,7 +69,7 @@ class ImageManager:
|
|||||||
os.makedirs(self.IMAGE_DIR, exist_ok=True)
|
os.makedirs(self.IMAGE_DIR, exist_ok=True)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_description_from_db(image_hash: str, description_type: str) -> Optional[str]:
|
async def _get_description_from_db(image_hash: str, description_type: str) -> Optional[str]:
|
||||||
"""从数据库获取图片描述
|
"""从数据库获取图片描述
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -80,22 +80,22 @@ class ImageManager:
|
|||||||
Optional[str]: 描述文本,如果不存在则返回None
|
Optional[str]: 描述文本,如果不存在则返回None
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
with get_db_session() as session:
|
async with get_db_session() as session:
|
||||||
record = session.execute(
|
record = (await session.execute(
|
||||||
select(ImageDescriptions).where(
|
select(ImageDescriptions).where(
|
||||||
and_(
|
and_(
|
||||||
ImageDescriptions.image_description_hash == image_hash,
|
ImageDescriptions.image_description_hash == image_hash,
|
||||||
ImageDescriptions.type == description_type,
|
ImageDescriptions.type == description_type,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
).scalar()
|
)).scalar()
|
||||||
return record.description if record else None
|
return record.description if record else None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"从数据库获取描述失败 (SQLAlchemy): {str(e)}")
|
logger.error(f"从数据库获取描述失败 (SQLAlchemy): {str(e)}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _save_description_to_db(image_hash: str, description: str, description_type: str) -> None:
|
async def _save_description_to_db(image_hash: str, description: str, description_type: str) -> None:
|
||||||
"""保存图片描述到数据库
|
"""保存图片描述到数据库
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -105,16 +105,16 @@ class ImageManager:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
current_timestamp = time.time()
|
current_timestamp = time.time()
|
||||||
with get_db_session() as session:
|
async with get_db_session() as session:
|
||||||
# 查找现有记录
|
# 查找现有记录
|
||||||
existing = session.execute(
|
existing = (await session.execute(
|
||||||
select(ImageDescriptions).where(
|
select(ImageDescriptions).where(
|
||||||
and_(
|
and_(
|
||||||
ImageDescriptions.image_description_hash == image_hash,
|
ImageDescriptions.image_description_hash == image_hash,
|
||||||
ImageDescriptions.type == description_type,
|
ImageDescriptions.type == description_type,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
).scalar()
|
)).scalar()
|
||||||
|
|
||||||
if existing:
|
if existing:
|
||||||
# 更新现有记录
|
# 更新现有记录
|
||||||
@@ -129,7 +129,7 @@ class ImageManager:
|
|||||||
timestamp=current_timestamp,
|
timestamp=current_timestamp,
|
||||||
)
|
)
|
||||||
session.add(new_desc)
|
session.add(new_desc)
|
||||||
session.commit()
|
await session.commit()
|
||||||
# 会在上下文管理器中自动调用
|
# 会在上下文管理器中自动调用
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"保存描述到数据库失败 (SQLAlchemy): {str(e)}")
|
logger.error(f"保存描述到数据库失败 (SQLAlchemy): {str(e)}")
|
||||||
@@ -175,7 +175,7 @@ class ImageManager:
|
|||||||
logger.debug(f"查询EmojiManager时出错: {e}")
|
logger.debug(f"查询EmojiManager时出错: {e}")
|
||||||
|
|
||||||
# 查询ImageDescriptions表的缓存描述
|
# 查询ImageDescriptions表的缓存描述
|
||||||
if cached_description := self._get_description_from_db(image_hash, "emoji"):
|
if cached_description := await self._get_description_from_db(image_hash, "emoji"):
|
||||||
logger.info(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description}...")
|
logger.info(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description}...")
|
||||||
return f"[表情包:{cached_description}]"
|
return f"[表情包:{cached_description}]"
|
||||||
|
|
||||||
@@ -239,7 +239,7 @@ class ImageManager:
|
|||||||
|
|
||||||
logger.info(f"[emoji识别] 详细描述: {detailed_description}... -> 情感标签: {final_emotion}")
|
logger.info(f"[emoji识别] 详细描述: {detailed_description}... -> 情感标签: {final_emotion}")
|
||||||
|
|
||||||
if cached_description := self._get_description_from_db(image_hash, "emoji"):
|
if cached_description := await self._get_description_from_db(image_hash, "emoji"):
|
||||||
logger.warning(f"虽然生成了描述,但是找到缓存表情包描述: {cached_description}")
|
logger.warning(f"虽然生成了描述,但是找到缓存表情包描述: {cached_description}")
|
||||||
return f"[表情包:{cached_description}]"
|
return f"[表情包:{cached_description}]"
|
||||||
|
|
||||||
@@ -261,10 +261,10 @@ class ImageManager:
|
|||||||
try:
|
try:
|
||||||
from src.common.database.sqlalchemy_models import get_db_session
|
from src.common.database.sqlalchemy_models import get_db_session
|
||||||
|
|
||||||
with get_db_session() as session:
|
async with get_db_session() as session:
|
||||||
existing_img = session.execute(
|
existing_img = (await session.execute(
|
||||||
select(Images).where(and_(Images.emoji_hash == image_hash, Images.type == "emoji"))
|
select(Images).where(and_(Images.emoji_hash == image_hash, Images.type == "emoji"))
|
||||||
).scalar()
|
)).scalar()
|
||||||
|
|
||||||
if existing_img:
|
if existing_img:
|
||||||
existing_img.path = file_path
|
existing_img.path = file_path
|
||||||
@@ -279,7 +279,7 @@ class ImageManager:
|
|||||||
timestamp=current_timestamp,
|
timestamp=current_timestamp,
|
||||||
)
|
)
|
||||||
session.add(new_img)
|
session.add(new_img)
|
||||||
session.commit()
|
await session.commit()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"保存到Images表失败: {str(e)}")
|
logger.error(f"保存到Images表失败: {str(e)}")
|
||||||
|
|
||||||
@@ -289,7 +289,7 @@ class ImageManager:
|
|||||||
logger.debug("偷取表情包功能已关闭,跳过保存。")
|
logger.debug("偷取表情包功能已关闭,跳过保存。")
|
||||||
|
|
||||||
# 保存最终的情感标签到缓存 (ImageDescriptions表)
|
# 保存最终的情感标签到缓存 (ImageDescriptions表)
|
||||||
self._save_description_to_db(image_hash, final_emotion, "emoji")
|
await self._save_description_to_db(image_hash, final_emotion, "emoji")
|
||||||
|
|
||||||
return f"[表情包:{final_emotion}]"
|
return f"[表情包:{final_emotion}]"
|
||||||
|
|
||||||
@@ -306,9 +306,9 @@ class ImageManager:
|
|||||||
image_bytes = base64.b64decode(image_base64)
|
image_bytes = base64.b64decode(image_base64)
|
||||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||||
|
|
||||||
# 优先检查Images表中是否已有完整的描述
|
async with get_db_session() as session:
|
||||||
with get_db_session() as session:
|
# 优先检查Images表中是否已有完整的描述
|
||||||
existing_image = session.execute(select(Images).where(Images.emoji_hash == image_hash)).scalar()
|
existing_image = (await session.execute(select(Images).where(Images.emoji_hash == image_hash))).scalar()
|
||||||
if existing_image:
|
if existing_image:
|
||||||
# 更新计数
|
# 更新计数
|
||||||
if hasattr(existing_image, "count") and existing_image.count is not None:
|
if hasattr(existing_image, "count") and existing_image.count is not None:
|
||||||
@@ -318,34 +318,34 @@ class ImageManager:
|
|||||||
|
|
||||||
# 如果已有描述,直接返回
|
# 如果已有描述,直接返回
|
||||||
if existing_image.description:
|
if existing_image.description:
|
||||||
|
await session.commit()
|
||||||
logger.debug(f"[缓存命中] 使用Images表中的图片描述: {existing_image.description}...")
|
logger.debug(f"[缓存命中] 使用Images表中的图片描述: {existing_image.description}...")
|
||||||
return f"[图片:{existing_image.description}]"
|
return f"[图片:{existing_image.description}]"
|
||||||
|
|
||||||
if cached_description := self._get_description_from_db(image_hash, "image"):
|
# 如果没有描述,继续在当前会话中操作
|
||||||
logger.debug(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description}...")
|
if cached_description := await self._get_description_from_db(image_hash, "image"):
|
||||||
return f"[图片:{cached_description}]"
|
logger.debug(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description}...")
|
||||||
|
return f"[图片:{cached_description}]"
|
||||||
|
|
||||||
# 调用AI获取描述
|
# 调用AI获取描述
|
||||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
|
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
|
||||||
prompt = global_config.custom_prompt.image_prompt
|
prompt = global_config.custom_prompt.image_prompt
|
||||||
logger.info(f"[VLM调用] 为图片生成新描述 (Hash: {image_hash[:8]}...)")
|
logger.info(f"[VLM调用] 为图片生成新描述 (Hash: {image_hash[:8]}...)")
|
||||||
description, _ = await self.vlm.generate_response_for_image(
|
description, _ = await self.vlm.generate_response_for_image(
|
||||||
prompt, image_base64, image_format, temperature=0.4, max_tokens=300
|
prompt, image_base64, image_format, temperature=0.4, max_tokens=300
|
||||||
)
|
)
|
||||||
|
|
||||||
if description is None:
|
if description is None:
|
||||||
logger.warning("AI未能生成图片描述")
|
logger.warning("AI未能生成图片描述")
|
||||||
return "[图片(描述生成失败)]"
|
return "[图片(描述生成失败)]"
|
||||||
|
|
||||||
# 保存图片和描述
|
# 保存图片和描述
|
||||||
current_timestamp = time.time()
|
current_timestamp = time.time()
|
||||||
filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}"
|
filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}"
|
||||||
image_dir = os.path.join(self.IMAGE_DIR, "image")
|
image_dir = os.path.join(self.IMAGE_DIR, "image")
|
||||||
os.makedirs(image_dir, exist_ok=True)
|
os.makedirs(image_dir, exist_ok=True)
|
||||||
file_path = os.path.join(image_dir, filename)
|
file_path = os.path.join(image_dir, filename)
|
||||||
|
|
||||||
try:
|
|
||||||
# 保存文件
|
|
||||||
with open(file_path, "wb") as f:
|
with open(file_path, "wb") as f:
|
||||||
f.write(image_bytes)
|
f.write(image_bytes)
|
||||||
|
|
||||||
@@ -358,7 +358,6 @@ class ImageManager:
|
|||||||
existing_image.image_id = str(uuid.uuid4())
|
existing_image.image_id = str(uuid.uuid4())
|
||||||
if not hasattr(existing_image, "vlm_processed") or existing_image.vlm_processed is None:
|
if not hasattr(existing_image, "vlm_processed") or existing_image.vlm_processed is None:
|
||||||
existing_image.vlm_processed = True
|
existing_image.vlm_processed = True
|
||||||
|
|
||||||
logger.debug(f"[数据库] 更新已有图片记录: {image_hash[:8]}...")
|
logger.debug(f"[数据库] 更新已有图片记录: {image_hash[:8]}...")
|
||||||
else:
|
else:
|
||||||
new_img = Images(
|
new_img = Images(
|
||||||
@@ -372,13 +371,15 @@ class ImageManager:
|
|||||||
count=1,
|
count=1,
|
||||||
)
|
)
|
||||||
session.add(new_img)
|
session.add(new_img)
|
||||||
|
|
||||||
logger.debug(f"[数据库] 创建新图片记录: {image_hash[:8]}...")
|
logger.debug(f"[数据库] 创建新图片记录: {image_hash[:8]}...")
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"保存图片文件或元数据失败: {str(e)}")
|
|
||||||
|
|
||||||
# 保存描述到ImageDescriptions表作为备用缓存
|
await session.commit()
|
||||||
self._save_description_to_db(image_hash, description, "image")
|
|
||||||
|
# 保存描述到ImageDescriptions表作为备用缓存
|
||||||
|
await self._save_description_to_db(image_hash, description, "image")
|
||||||
|
|
||||||
|
logger.info(f"[VLM完成] 图片描述生成: {description}...")
|
||||||
|
return f"[图片:{description}]"
|
||||||
|
|
||||||
logger.info(f"[VLM完成] 图片描述生成: {description}...")
|
logger.info(f"[VLM完成] 图片描述生成: {description}...")
|
||||||
return f"[图片:{description}]"
|
return f"[图片:{description}]"
|
||||||
@@ -525,8 +526,8 @@ class ImageManager:
|
|||||||
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
||||||
image_bytes = base64.b64decode(image_base64)
|
image_bytes = base64.b64decode(image_base64)
|
||||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||||
with get_db_session() as session:
|
async with get_db_session() as session:
|
||||||
existing_image = session.execute(select(Images).where(Images.emoji_hash == image_hash)).scalar()
|
existing_image = (await session.execute(select(Images).where(Images.emoji_hash == image_hash))).scalar()
|
||||||
if existing_image:
|
if existing_image:
|
||||||
# 检查是否缺少必要字段,如果缺少则创建新记录
|
# 检查是否缺少必要字段,如果缺少则创建新记录
|
||||||
if (
|
if (
|
||||||
@@ -546,6 +547,7 @@ class ImageManager:
|
|||||||
existing_image.vlm_processed = False
|
existing_image.vlm_processed = False
|
||||||
|
|
||||||
existing_image.count += 1
|
existing_image.count += 1
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
# 如果已有描述,直接返回
|
# 如果已有描述,直接返回
|
||||||
if existing_image.description and existing_image.description.strip():
|
if existing_image.description and existing_image.description.strip():
|
||||||
@@ -556,6 +558,7 @@ class ImageManager:
|
|||||||
# 更新数据库中的描述
|
# 更新数据库中的描述
|
||||||
existing_image.description = description.replace("[图片:", "").replace("]", "")
|
existing_image.description = description.replace("[图片:", "").replace("]", "")
|
||||||
existing_image.vlm_processed = True
|
existing_image.vlm_processed = True
|
||||||
|
await session.commit()
|
||||||
return existing_image.image_id, f"[picid:{existing_image.image_id}]"
|
return existing_image.image_id, f"[picid:{existing_image.image_id}]"
|
||||||
|
|
||||||
# print(f"图片不存在: {image_hash}")
|
# print(f"图片不存在: {image_hash}")
|
||||||
@@ -588,7 +591,7 @@ class ImageManager:
|
|||||||
count=1,
|
count=1,
|
||||||
)
|
)
|
||||||
session.add(new_img)
|
session.add(new_img)
|
||||||
session.commit()
|
await session.commit()
|
||||||
|
|
||||||
return image_id, f"[picid:{image_id}]"
|
return image_id, f"[picid:{image_id}]"
|
||||||
|
|
||||||
|
|||||||
@@ -146,7 +146,7 @@ class LLMUsageRecorder:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def record_usage_to_database(
|
async def record_usage_to_database(
|
||||||
model_info: ModelInfo,
|
model_info: ModelInfo,
|
||||||
model_usage: UsageRecord,
|
model_usage: UsageRecord,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
@@ -161,7 +161,7 @@ class LLMUsageRecorder:
|
|||||||
session = None
|
session = None
|
||||||
try:
|
try:
|
||||||
# 使用 SQLAlchemy 会话创建记录
|
# 使用 SQLAlchemy 会话创建记录
|
||||||
with get_db_session() as session:
|
async with get_db_session() as session:
|
||||||
usage_record = LLMUsage(
|
usage_record = LLMUsage(
|
||||||
model_name=model_info.model_identifier,
|
model_name=model_info.model_identifier,
|
||||||
model_assign_name=model_info.name,
|
model_assign_name=model_info.name,
|
||||||
@@ -179,7 +179,7 @@ class LLMUsageRecorder:
|
|||||||
)
|
)
|
||||||
|
|
||||||
session.add(usage_record)
|
session.add(usage_record)
|
||||||
session.commit()
|
await session.commit()
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Token使用情况 - 模型: {model_usage.model_name}, "
|
f"Token使用情况 - 模型: {model_usage.model_name}, "
|
||||||
|
|||||||
@@ -202,7 +202,7 @@ class LLMRequest:
|
|||||||
content, extracted_reasoning = self._extract_reasoning(content)
|
content, extracted_reasoning = self._extract_reasoning(content)
|
||||||
reasoning_content = extracted_reasoning
|
reasoning_content = extracted_reasoning
|
||||||
if usage := response.usage:
|
if usage := response.usage:
|
||||||
llm_usage_recorder.record_usage_to_database(
|
await llm_usage_recorder.record_usage_to_database(
|
||||||
model_info=model_info,
|
model_info=model_info,
|
||||||
model_usage=usage,
|
model_usage=usage,
|
||||||
user_id="system",
|
user_id="system",
|
||||||
@@ -368,7 +368,7 @@ class LLMRequest:
|
|||||||
|
|
||||||
# 成功获取响应
|
# 成功获取响应
|
||||||
if usage := response.usage:
|
if usage := response.usage:
|
||||||
llm_usage_recorder.record_usage_to_database(
|
await llm_usage_recorder.record_usage_to_database(
|
||||||
model_info=model_info,
|
model_info=model_info,
|
||||||
model_usage=usage,
|
model_usage=usage,
|
||||||
time_cost=time.time() - start_time,
|
time_cost=time.time() - start_time,
|
||||||
@@ -443,7 +443,7 @@ class LLMRequest:
|
|||||||
embedding = response.embedding
|
embedding = response.embedding
|
||||||
|
|
||||||
if usage := response.usage:
|
if usage := response.usage:
|
||||||
llm_usage_recorder.record_usage_to_database(
|
await llm_usage_recorder.record_usage_to_database(
|
||||||
model_info=model_info,
|
model_info=model_info,
|
||||||
time_cost=time.time() - start_time,
|
time_cost=time.time() - start_time,
|
||||||
model_usage=usage,
|
model_usage=usage,
|
||||||
|
|||||||
Reference in New Issue
Block a user