Merge branch 'dev' of https://github.com/mcn1630/MoFox-Core into patch
This commit is contained in:
@@ -1107,28 +1107,22 @@ class EmojiManager:
|
|||||||
if emoji_base64 is None: # 再次检查读取
|
if emoji_base64 is None: # 再次检查读取
|
||||||
logger.error(f"[注册失败] 无法读取图片以生成描述: {filename}")
|
logger.error(f"[注册失败] 无法读取图片以生成描述: {filename}")
|
||||||
return False
|
return False
|
||||||
task = asyncio.create_task(self.build_emoji_description(emoji_base64))
|
|
||||||
|
|
||||||
def after_built_description(fut: asyncio.Future):
|
# 等待描述生成完成
|
||||||
if fut.cancelled():
|
description, emotions = await self.build_emoji_description(emoji_base64)
|
||||||
logger.error(f"[注册失败] 描述生成任务被取消: {filename}")
|
|
||||||
elif fut.exception():
|
|
||||||
logger.error(f"[注册失败] 描述生成任务出错 ({filename}): {fut.exception()}")
|
|
||||||
else:
|
|
||||||
description, emotions = fut.result()
|
|
||||||
|
|
||||||
if not description: # 检查描述是否成功生成或审核通过
|
if not description: # 检查描述是否成功生成或审核通过
|
||||||
logger.warning(f"[注册失败] 未能生成有效描述或审核未通过: {filename}")
|
logger.warning(f"[注册失败] 未能生成有效描述或审核未通过: {filename}")
|
||||||
# 删除未能生成描述的文件
|
# 删除未能生成描述的文件
|
||||||
try:
|
try:
|
||||||
os.remove(file_full_path)
|
os.remove(file_full_path)
|
||||||
logger.info(f"[清理] 删除描述生成失败的文件: {filename}")
|
logger.info(f"[清理] 删除描述生成失败的文件: {filename}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[错误] 删除描述生成失败文件时出错: {e!s}")
|
logger.error(f"[错误] 删除描述生成失败文件时出错: {e!s}")
|
||||||
return False
|
return False
|
||||||
new_emoji.description = description
|
|
||||||
new_emoji.emotion = emotions
|
new_emoji.description = description
|
||||||
task.add_done_callback(after_built_description)
|
new_emoji.emotion = emotions
|
||||||
except Exception as build_desc_error:
|
except Exception as build_desc_error:
|
||||||
logger.error(f"[注册失败] 生成描述/情感时出错 ({filename}): {build_desc_error}")
|
logger.error(f"[注册失败] 生成描述/情感时出错 ({filename}): {build_desc_error}")
|
||||||
# 同样考虑删除文件
|
# 同样考虑删除文件
|
||||||
|
|||||||
@@ -134,7 +134,7 @@ class MessageHandler:
|
|||||||
predicate=_is_adapter_response,
|
predicate=_is_adapter_response,
|
||||||
handler=self._handle_adapter_response_route,
|
handler=self._handle_adapter_response_route,
|
||||||
name="adapter_response_handler",
|
name="adapter_response_handler",
|
||||||
message_type="adapter_response",
|
priority=100
|
||||||
)
|
)
|
||||||
|
|
||||||
# 注册 notice 消息处理器(处理通知消息,如戳一戳、禁言等)
|
# 注册 notice 消息处理器(处理通知消息,如戳一戳、禁言等)
|
||||||
@@ -152,7 +152,7 @@ class MessageHandler:
|
|||||||
predicate=_is_notice_message,
|
predicate=_is_notice_message,
|
||||||
handler=self._handle_notice_message,
|
handler=self._handle_notice_message,
|
||||||
name="notice_message_handler",
|
name="notice_message_handler",
|
||||||
message_type="notice",
|
priority=90
|
||||||
)
|
)
|
||||||
|
|
||||||
# 注册默认消息处理器(处理所有其他消息)
|
# 注册默认消息处理器(处理所有其他消息)
|
||||||
@@ -160,6 +160,7 @@ class MessageHandler:
|
|||||||
predicate=lambda _: True, # 匹配所有消息
|
predicate=lambda _: True, # 匹配所有消息
|
||||||
handler=self._handle_normal_message,
|
handler=self._handle_normal_message,
|
||||||
name="default_message_handler",
|
name="default_message_handler",
|
||||||
|
priority=50
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("MessageHandler 已向 MessageRuntime 注册处理器和钩子")
|
logger.info("MessageHandler 已向 MessageRuntime 注册处理器和钩子")
|
||||||
@@ -314,7 +315,7 @@ class MessageHandler:
|
|||||||
# 触发 notice 事件(可供插件监听)
|
# 触发 notice 事件(可供插件监听)
|
||||||
await event_manager.trigger_event(
|
await event_manager.trigger_event(
|
||||||
EventType.ON_NOTICE_RECEIVED,
|
EventType.ON_NOTICE_RECEIVED,
|
||||||
permission_group="USER",
|
permission_group="SYSTEM",
|
||||||
message=message,
|
message=message,
|
||||||
notice_type=notice_type,
|
notice_type=notice_type,
|
||||||
chat_stream=chat,
|
chat_stream=chat,
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from mofox_wire import MessageEnvelope
|
|||||||
from src.chat.message_receive.message_processor import process_message_from_dict
|
from src.chat.message_receive.message_processor import process_message_from_dict
|
||||||
from src.chat.message_receive.storage import MessageStorage
|
from src.chat.message_receive.storage import MessageStorage
|
||||||
from src.chat.utils.utils import calculate_typing_time, truncate_message
|
from src.chat.utils.utils import calculate_typing_time, truncate_message
|
||||||
from src.common.data_models.database_data_model import DatabaseMessages
|
from src.common.data_models.database_data_model import DatabaseMessages, DatabaseUserInfo
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
||||||
@@ -27,13 +27,13 @@ logger = get_logger("sender")
|
|||||||
|
|
||||||
async def send_envelope(
|
async def send_envelope(
|
||||||
envelope: MessageEnvelope,
|
envelope: MessageEnvelope,
|
||||||
chat_stream: "ChatStream" | None = None,
|
chat_stream: ChatStream | None = None,
|
||||||
db_message: DatabaseMessages | None = None,
|
db_message: DatabaseMessages | None = None,
|
||||||
show_log: bool = True,
|
show_log: bool = True,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""发送消息"""
|
"""发送消息"""
|
||||||
message_preview = truncate_message(
|
message_preview = truncate_message(
|
||||||
(db_message.processed_plain_text if db_message else str(envelope.get("message_segment", ""))),
|
(db_message.processed_plain_text or "" if db_message else str(envelope.get("message_segment", ""))),
|
||||||
max_length=120,
|
max_length=120,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -81,6 +81,7 @@ class HeartFCSender:
|
|||||||
show_log: bool = True,
|
show_log: bool = True,
|
||||||
thinking_start_time: float = 0.0,
|
thinking_start_time: float = 0.0,
|
||||||
display_message: str | None = None,
|
display_message: str | None = None,
|
||||||
|
storage_user_info: "DatabaseUserInfo | None" = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
if not chat_stream:
|
if not chat_stream:
|
||||||
logger.error("消息缺少 chat_stream,无法发送")
|
logger.error("消息缺少 chat_stream,无法发送")
|
||||||
@@ -93,6 +94,13 @@ class HeartFCSender:
|
|||||||
platform=chat_stream.platform,
|
platform=chat_stream.platform,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 如果提供了用于存储的用户信息,则覆盖
|
||||||
|
if storage_message and storage_user_info:
|
||||||
|
db_message.user_info.user_id = storage_user_info.user_id
|
||||||
|
db_message.user_info.user_nickname = storage_user_info.user_nickname
|
||||||
|
db_message.user_info.user_cardname = storage_user_info.user_cardname
|
||||||
|
db_message.user_info.platform = storage_user_info.platform
|
||||||
|
|
||||||
# 使用调用方指定的展示文本
|
# 使用调用方指定的展示文本
|
||||||
if display_message:
|
if display_message:
|
||||||
db_message.display_message = display_message
|
db_message.display_message = display_message
|
||||||
@@ -125,9 +133,13 @@ class HeartFCSender:
|
|||||||
|
|
||||||
# 将发送的消息写入上下文历史
|
# 将发送的消息写入上下文历史
|
||||||
try:
|
try:
|
||||||
if chat_stream.context:
|
if chat_stream and chat_stream.context and global_config.chat:
|
||||||
context = chat_stream.context
|
context = chat_stream.context
|
||||||
max_context_size = getattr(global_config.chat, "max_context_size", 40)
|
chat_config = global_config.chat
|
||||||
|
if chat_config:
|
||||||
|
max_context_size = getattr(chat_config, "max_context_size", 40)
|
||||||
|
else:
|
||||||
|
max_context_size = 40
|
||||||
|
|
||||||
if len(context.history_messages) >= max_context_size:
|
if len(context.history_messages) >= max_context_size:
|
||||||
context.history_messages = context.history_messages[1:]
|
context.history_messages = context.history_messages[1:]
|
||||||
|
|||||||
@@ -54,6 +54,7 @@ class ImageManager:
|
|||||||
self._ensure_image_dir()
|
self._ensure_image_dir()
|
||||||
|
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
|
assert model_config is not None
|
||||||
self.vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="image")
|
self.vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="image")
|
||||||
|
|
||||||
# try:
|
# try:
|
||||||
@@ -189,7 +190,7 @@ class ImageManager:
|
|||||||
return "[表情包(描述生成失败)]"
|
return "[表情包(描述生成失败)]"
|
||||||
|
|
||||||
# 4. (可选) 如果启用了“偷表情包”,则将图片和完整描述存入待注册区
|
# 4. (可选) 如果启用了“偷表情包”,则将图片和完整描述存入待注册区
|
||||||
if global_config.emoji.steal_emoji:
|
if global_config and global_config.emoji and global_config.emoji.steal_emoji:
|
||||||
logger.debug(f"偷取表情包功能已开启,保存待注册表情包: {image_hash}")
|
logger.debug(f"偷取表情包功能已开启,保存待注册表情包: {image_hash}")
|
||||||
try:
|
try:
|
||||||
image_format = (Image.open(io.BytesIO(image_bytes)).format or "jpeg").lower()
|
image_format = (Image.open(io.BytesIO(image_bytes)).format or "jpeg").lower()
|
||||||
@@ -226,6 +227,22 @@ class ImageManager:
|
|||||||
image_bytes = base64.b64decode(image_base64)
|
image_bytes = base64.b64decode(image_base64)
|
||||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||||
|
|
||||||
|
# 1.5. 如果是GIF,先转换为JPG
|
||||||
|
try:
|
||||||
|
image_format_check = (Image.open(io.BytesIO(image_bytes)).format or "jpeg").lower()
|
||||||
|
if image_format_check == "gif":
|
||||||
|
logger.info(f"检测到GIF图片 (Hash: {image_hash[:8]}...),正在转换为JPG...")
|
||||||
|
if transformed_b64 := self.transform_gif(image_base64):
|
||||||
|
image_base64 = transformed_b64
|
||||||
|
image_bytes = base64.b64decode(image_base64)
|
||||||
|
logger.info("GIF转换成功,将使用转换后的图片进行描述")
|
||||||
|
else:
|
||||||
|
logger.error("GIF转换失败,无法生成描述")
|
||||||
|
return "[图片(GIF转换失败)]"
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"图片格式检测失败: {e!s},将按原格式处理")
|
||||||
|
|
||||||
|
|
||||||
# 2. 优先查询 Images 表缓存
|
# 2. 优先查询 Images 表缓存
|
||||||
async with get_db_session() as session:
|
async with get_db_session() as session:
|
||||||
result = await session.execute(select(Images).where(Images.emoji_hash == image_hash))
|
result = await session.execute(select(Images).where(Images.emoji_hash == image_hash))
|
||||||
@@ -242,6 +259,8 @@ class ImageManager:
|
|||||||
# 4. 如果都未命中,则同步调用VLM生成新描述
|
# 4. 如果都未命中,则同步调用VLM生成新描述
|
||||||
logger.info(f"[新图片识别] 无缓存 (Hash: {image_hash[:8]}...),调用VLM生成描述")
|
logger.info(f"[新图片识别] 无缓存 (Hash: {image_hash[:8]}...),调用VLM生成描述")
|
||||||
description = None
|
description = None
|
||||||
|
assert global_config is not None
|
||||||
|
assert global_config.custom_prompt is not None
|
||||||
prompt = global_config.custom_prompt.image_prompt
|
prompt = global_config.custom_prompt.image_prompt
|
||||||
logger.info(f"[识图VLM调用] Prompt: {prompt}")
|
logger.info(f"[识图VLM调用] Prompt: {prompt}")
|
||||||
for i in range(3): # 重试3次
|
for i in range(3): # 重试3次
|
||||||
|
|||||||
@@ -126,6 +126,12 @@ async def get_db_session_direct() -> AsyncGenerator[AsyncSession, None]:
|
|||||||
用于特殊场景,如需要完全独立的连接时。
|
用于特殊场景,如需要完全独立的连接时。
|
||||||
一般情况下应使用 get_db_session()。
|
一般情况下应使用 get_db_session()。
|
||||||
|
|
||||||
|
事务管理说明:
|
||||||
|
- 正常退出时自动提交事务
|
||||||
|
- 发生异常时自动回滚事务
|
||||||
|
- 如果用户代码已手动调用 commit/rollback,再次调用是安全的
|
||||||
|
- 适用于所有数据库类型(SQLite, MySQL, PostgreSQL)
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
AsyncSession: SQLAlchemy异步会话对象
|
AsyncSession: SQLAlchemy异步会话对象
|
||||||
"""
|
"""
|
||||||
@@ -139,8 +145,16 @@ async def get_db_session_direct() -> AsyncGenerator[AsyncSession, None]:
|
|||||||
await _apply_session_settings(session, global_config.database.database_type)
|
await _apply_session_settings(session, global_config.database.database_type)
|
||||||
|
|
||||||
yield session
|
yield session
|
||||||
|
|
||||||
|
# 正常退出时提交事务
|
||||||
|
# 这对所有数据库都很重要,因为 SQLAlchemy 默认不是 autocommit 模式
|
||||||
|
# 检查事务是否活动,避免在已回滚的事务上提交
|
||||||
|
if session.is_active:
|
||||||
|
await session.commit()
|
||||||
except Exception:
|
except Exception:
|
||||||
await session.rollback()
|
# 检查是否需要回滚(事务是否活动)
|
||||||
|
if session.is_active:
|
||||||
|
await session.rollback()
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
await session.close()
|
await session.close()
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ from typing import Any, TypeVar
|
|||||||
|
|
||||||
from sqlalchemy import delete, insert, select, update
|
from sqlalchemy import delete, insert, select, update
|
||||||
|
|
||||||
from src.common.database.core.session import get_db_session
|
from src.common.database.core.session import get_db_session_direct
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.memory_utils import estimate_size_smart
|
from src.common.memory_utils import estimate_size_smart
|
||||||
|
|
||||||
@@ -330,7 +330,7 @@ class AdaptiveBatchScheduler:
|
|||||||
operations: list[BatchOperation],
|
operations: list[BatchOperation],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""批量执行查询操作"""
|
"""批量执行查询操作"""
|
||||||
async with get_db_session() as session:
|
async with get_db_session_direct() as session:
|
||||||
for op in operations:
|
for op in operations:
|
||||||
try:
|
try:
|
||||||
# 构建查询
|
# 构建查询
|
||||||
@@ -371,7 +371,7 @@ class AdaptiveBatchScheduler:
|
|||||||
operations: list[BatchOperation],
|
operations: list[BatchOperation],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""批量执行插入操作"""
|
"""批量执行插入操作"""
|
||||||
async with get_db_session() as session:
|
async with get_db_session_direct() as session:
|
||||||
try:
|
try:
|
||||||
# 收集数据,并过滤掉 id=None 的情况(让数据库自动生成)
|
# 收集数据,并过滤掉 id=None 的情况(让数据库自动生成)
|
||||||
all_data = []
|
all_data = []
|
||||||
@@ -387,7 +387,7 @@ class AdaptiveBatchScheduler:
|
|||||||
# 批量插入
|
# 批量插入
|
||||||
stmt = insert(operations[0].model_class).values(all_data)
|
stmt = insert(operations[0].model_class).values(all_data)
|
||||||
await session.execute(stmt)
|
await session.execute(stmt)
|
||||||
await session.commit()
|
# 注意:commit 由 get_db_session_direct 上下文管理器自动处理
|
||||||
|
|
||||||
# 设置结果
|
# 设置结果
|
||||||
for op in operations:
|
for op in operations:
|
||||||
@@ -402,20 +402,21 @@ class AdaptiveBatchScheduler:
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"批量插入失败: {e}")
|
logger.error(f"批量插入失败: {e}")
|
||||||
await session.rollback()
|
# 注意:rollback 由 get_db_session_direct 上下文管理器自动处理
|
||||||
for op in operations:
|
for op in operations:
|
||||||
if op.future and not op.future.done():
|
if op.future and not op.future.done():
|
||||||
op.future.set_exception(e)
|
op.future.set_exception(e)
|
||||||
|
raise # 重新抛出异常以触发 rollback
|
||||||
|
|
||||||
async def _execute_update_batch(
|
async def _execute_update_batch(
|
||||||
self,
|
self,
|
||||||
operations: list[BatchOperation],
|
operations: list[BatchOperation],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""批量执行更新操作"""
|
"""批量执行更新操作"""
|
||||||
async with get_db_session() as session:
|
async with get_db_session_direct() as session:
|
||||||
results = []
|
results = []
|
||||||
try:
|
try:
|
||||||
# 🔧 修复:收集所有操作后一次性commit,而不是循环中多次commit
|
# 🔧 收集所有操作后一次性commit,而不是循环中多次commit
|
||||||
for op in operations:
|
for op in operations:
|
||||||
# 构建更新语句
|
# 构建更新语句
|
||||||
stmt = update(op.model_class)
|
stmt = update(op.model_class)
|
||||||
@@ -430,8 +431,7 @@ class AdaptiveBatchScheduler:
|
|||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
results.append((op, result.rowcount))
|
results.append((op, result.rowcount))
|
||||||
|
|
||||||
# 所有操作成功后,一次性commit
|
# 注意:commit 由 get_db_session_direct 上下文管理器自动处理
|
||||||
await session.commit()
|
|
||||||
|
|
||||||
# 设置所有操作的结果
|
# 设置所有操作的结果
|
||||||
for op, rowcount in results:
|
for op, rowcount in results:
|
||||||
@@ -446,21 +446,22 @@ class AdaptiveBatchScheduler:
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"批量更新失败: {e}")
|
logger.error(f"批量更新失败: {e}")
|
||||||
await session.rollback()
|
# 注意:rollback 由 get_db_session_direct 上下文管理器自动处理
|
||||||
# 所有操作都失败
|
# 所有操作都失败
|
||||||
for op in operations:
|
for op in operations:
|
||||||
if op.future and not op.future.done():
|
if op.future and not op.future.done():
|
||||||
op.future.set_exception(e)
|
op.future.set_exception(e)
|
||||||
|
raise # 重新抛出异常以触发 rollback
|
||||||
|
|
||||||
async def _execute_delete_batch(
|
async def _execute_delete_batch(
|
||||||
self,
|
self,
|
||||||
operations: list[BatchOperation],
|
operations: list[BatchOperation],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""批量执行删除操作"""
|
"""批量执行删除操作"""
|
||||||
async with get_db_session() as session:
|
async with get_db_session_direct() as session:
|
||||||
results = []
|
results = []
|
||||||
try:
|
try:
|
||||||
# 🔧 修复:收集所有操作后一次性commit,而不是循环中多次commit
|
# 🔧 收集所有操作后一次性commit,而不是循环中多次commit
|
||||||
for op in operations:
|
for op in operations:
|
||||||
# 构建删除语句
|
# 构建删除语句
|
||||||
stmt = delete(op.model_class)
|
stmt = delete(op.model_class)
|
||||||
@@ -472,8 +473,7 @@ class AdaptiveBatchScheduler:
|
|||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
results.append((op, result.rowcount))
|
results.append((op, result.rowcount))
|
||||||
|
|
||||||
# 所有操作成功后,一次性commit
|
# 注意:commit 由 get_db_session_direct 上下文管理器自动处理
|
||||||
await session.commit()
|
|
||||||
|
|
||||||
# 设置所有操作的结果
|
# 设置所有操作的结果
|
||||||
for op, rowcount in results:
|
for op, rowcount in results:
|
||||||
@@ -488,11 +488,12 @@ class AdaptiveBatchScheduler:
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"批量删除失败: {e}")
|
logger.error(f"批量删除失败: {e}")
|
||||||
await session.rollback()
|
# 注意:rollback 由 get_db_session_direct 上下文管理器自动处理
|
||||||
# 所有操作都失败
|
# 所有操作都失败
|
||||||
for op in operations:
|
for op in operations:
|
||||||
if op.future and not op.future.done():
|
if op.future and not op.future.done():
|
||||||
op.future.set_exception(e)
|
op.future.set_exception(e)
|
||||||
|
raise # 重新抛出异常以触发 rollback
|
||||||
|
|
||||||
async def _adjust_parameters(self) -> None:
|
async def _adjust_parameters(self) -> None:
|
||||||
"""根据性能自适应调整参数"""
|
"""根据性能自适应调整参数"""
|
||||||
|
|||||||
@@ -123,6 +123,12 @@ class ConnectionPoolManager:
|
|||||||
"""
|
"""
|
||||||
获取数据库会话的透明包装器
|
获取数据库会话的透明包装器
|
||||||
如果有可用连接则复用,否则创建新连接
|
如果有可用连接则复用,否则创建新连接
|
||||||
|
|
||||||
|
事务管理说明:
|
||||||
|
- 正常退出时自动提交事务
|
||||||
|
- 发生异常时自动回滚事务
|
||||||
|
- 如果用户代码已手动调用 commit/rollback,再次调用是安全的(空操作)
|
||||||
|
- 支持所有数据库类型:SQLite、MySQL、PostgreSQL
|
||||||
"""
|
"""
|
||||||
connection_info = None
|
connection_info = None
|
||||||
|
|
||||||
@@ -151,21 +157,30 @@ class ConnectionPoolManager:
|
|||||||
|
|
||||||
yield connection_info.session
|
yield connection_info.session
|
||||||
|
|
||||||
# 🔧 修复:正常退出时提交事务
|
# 🔧 正常退出时提交事务
|
||||||
# 这对SQLite至关重要,因为SQLite没有autocommit
|
# 这对所有数据库(SQLite、MySQL、PostgreSQL)都很重要
|
||||||
|
# 因为 SQLAlchemy 默认使用事务模式,不会自动提交
|
||||||
|
# 注意:如果用户代码已调用 commit(),这里的 commit() 是安全的空操作
|
||||||
if connection_info and connection_info.session:
|
if connection_info and connection_info.session:
|
||||||
try:
|
try:
|
||||||
await connection_info.session.commit()
|
# 检查事务是否处于活动状态,避免在已回滚的事务上提交
|
||||||
|
if connection_info.session.is_active:
|
||||||
|
await connection_info.session.commit()
|
||||||
except Exception as commit_error:
|
except Exception as commit_error:
|
||||||
logger.warning(f"提交事务时出错: {commit_error}")
|
logger.warning(f"提交事务时出错: {commit_error}")
|
||||||
await connection_info.session.rollback()
|
try:
|
||||||
|
await connection_info.session.rollback()
|
||||||
|
except Exception:
|
||||||
|
pass # 忽略回滚错误,因为事务可能已经结束
|
||||||
raise
|
raise
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
# 发生错误时回滚连接
|
# 发生错误时回滚连接
|
||||||
if connection_info and connection_info.session:
|
if connection_info and connection_info.session:
|
||||||
try:
|
try:
|
||||||
await connection_info.session.rollback()
|
# 检查是否需要回滚(事务是否活动)
|
||||||
|
if connection_info.session.is_active:
|
||||||
|
await connection_info.session.rollback()
|
||||||
except Exception as rollback_error:
|
except Exception as rollback_error:
|
||||||
logger.warning(f"回滚连接时出错: {rollback_error}")
|
logger.warning(f"回滚连接时出错: {rollback_error}")
|
||||||
raise
|
raise
|
||||||
|
|||||||
@@ -192,6 +192,7 @@ def _build_message_envelope(
|
|||||||
timestamp: float,
|
timestamp: float,
|
||||||
) -> MessageEnvelope:
|
) -> MessageEnvelope:
|
||||||
"""构建发送的 MessageEnvelope 数据结构"""
|
"""构建发送的 MessageEnvelope 数据结构"""
|
||||||
|
# 这里的 user_info 决定了消息要发给谁,所以在私聊场景下必须是目标用户
|
||||||
target_user_info = target_stream.user_info or bot_user_info
|
target_user_info = target_stream.user_info or bot_user_info
|
||||||
message_info: dict[str, Any] = {
|
message_info: dict[str, Any] = {
|
||||||
"message_id": message_id,
|
"message_id": message_id,
|
||||||
@@ -212,7 +213,7 @@ def _build_message_envelope(
|
|||||||
"platform": target_stream.group_info.platform,
|
"platform": target_stream.group_info.platform,
|
||||||
}
|
}
|
||||||
|
|
||||||
return {
|
return { # type: ignore
|
||||||
"id": str(uuid.uuid4()),
|
"id": str(uuid.uuid4()),
|
||||||
"direction": "outgoing",
|
"direction": "outgoing",
|
||||||
"platform": target_stream.platform,
|
"platform": target_stream.platform,
|
||||||
@@ -257,9 +258,14 @@ async def _send_to_target(
|
|||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
message_id = f"send_api_{int(current_time * 1000)}"
|
message_id = f"send_api_{int(current_time * 1000)}"
|
||||||
|
|
||||||
|
bot_config = global_config.bot
|
||||||
|
if not bot_config:
|
||||||
|
logger.error("机器人配置丢失,无法构建机器人用户信息")
|
||||||
|
return False
|
||||||
|
|
||||||
bot_user_info = DatabaseUserInfo(
|
bot_user_info = DatabaseUserInfo(
|
||||||
user_id=str(global_config.bot.qq_account),
|
user_id=str(bot_config.qq_account),
|
||||||
user_nickname=global_config.bot.nickname,
|
user_nickname=bot_config.nickname,
|
||||||
platform=target_stream.platform,
|
platform=target_stream.platform,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -328,6 +334,7 @@ async def _send_to_target(
|
|||||||
show_log=show_log,
|
show_log=show_log,
|
||||||
thinking_start_time=current_time,
|
thinking_start_time=current_time,
|
||||||
display_message=display_message_for_db,
|
display_message=display_message_for_db,
|
||||||
|
storage_user_info=bot_user_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
if sent_msg:
|
if sent_msg:
|
||||||
|
|||||||
@@ -433,6 +433,7 @@ class EventManager:
|
|||||||
EventType.AFTER_LLM,
|
EventType.AFTER_LLM,
|
||||||
EventType.POST_SEND,
|
EventType.POST_SEND,
|
||||||
EventType.AFTER_SEND,
|
EventType.AFTER_SEND,
|
||||||
|
EventType.ON_NOTICE_RECEIVED
|
||||||
]
|
]
|
||||||
|
|
||||||
for event_name in default_events:
|
for event_name in default_events:
|
||||||
|
|||||||
@@ -49,14 +49,6 @@ class EmojiAction(BaseAction):
|
|||||||
----------------------------------------
|
----------------------------------------
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# ========== 以下使用旧的激活配置(已废弃但兼容) ==========
|
|
||||||
# 激活设置
|
|
||||||
if global_config.emoji.emoji_activate_type == "llm":
|
|
||||||
activation_type = ActionActivationType.LLM_JUDGE
|
|
||||||
random_activation_probability = 0
|
|
||||||
else:
|
|
||||||
activation_type = ActionActivationType.RANDOM
|
|
||||||
random_activation_probability = global_config.emoji.emoji_chance
|
|
||||||
mode_enable = ChatMode.ALL
|
mode_enable = ChatMode.ALL
|
||||||
parallel_action = True
|
parallel_action = True
|
||||||
|
|
||||||
@@ -88,6 +80,15 @@ class EmojiAction(BaseAction):
|
|||||||
# 关联类型
|
# 关联类型
|
||||||
associated_types: ClassVar[list[str]] = ["emoji"]
|
associated_types: ClassVar[list[str]] = ["emoji"]
|
||||||
|
|
||||||
|
async def go_activate(self, chat_content: str = "", llm_judge_model=None) -> bool:
|
||||||
|
"""根据配置选择激活方式"""
|
||||||
|
assert global_config is not None
|
||||||
|
if global_config.emoji.emoji_activate_type == "llm":
|
||||||
|
return await self._llm_judge_activation(
|
||||||
|
judge_prompt=self.llm_judge_prompt, llm_judge_model=llm_judge_model
|
||||||
|
)
|
||||||
|
return await self._random_activation(global_config.emoji.emoji_chance)
|
||||||
|
|
||||||
async def execute(self) -> tuple[bool, str]:
|
async def execute(self) -> tuple[bool, str]:
|
||||||
"""执行表情动作"""
|
"""执行表情动作"""
|
||||||
logger.info(f"{self.log_prefix} 决定发送表情")
|
logger.info(f"{self.log_prefix} 决定发送表情")
|
||||||
@@ -95,6 +96,7 @@ class EmojiAction(BaseAction):
|
|||||||
try:
|
try:
|
||||||
# 1. 获取发送表情的原因
|
# 1. 获取发送表情的原因
|
||||||
reason = self.action_data.get("reason", "表达当前情绪")
|
reason = self.action_data.get("reason", "表达当前情绪")
|
||||||
|
main_reply_content = self.action_data.get("main_reply_content", "")
|
||||||
logger.info(f"{self.log_prefix} 发送表情原因: {reason}")
|
logger.info(f"{self.log_prefix} 发送表情原因: {reason}")
|
||||||
|
|
||||||
# 2. 获取所有有效的表情包对象
|
# 2. 获取所有有效的表情包对象
|
||||||
@@ -108,7 +110,7 @@ class EmojiAction(BaseAction):
|
|||||||
|
|
||||||
# 3. 根据历史记录筛选表情
|
# 3. 根据历史记录筛选表情
|
||||||
try:
|
try:
|
||||||
recent_emojis_desc = get_recent_emojis(self.chat_id, limit=10)
|
recent_emojis_desc = get_recent_emojis(self.chat_id, limit=20)
|
||||||
if recent_emojis_desc:
|
if recent_emojis_desc:
|
||||||
filtered_emojis = [emoji for emoji in all_emojis_obj if emoji.description not in recent_emojis_desc]
|
filtered_emojis = [emoji for emoji in all_emojis_obj if emoji.description not in recent_emojis_desc]
|
||||||
if filtered_emojis:
|
if filtered_emojis:
|
||||||
@@ -120,8 +122,8 @@ class EmojiAction(BaseAction):
|
|||||||
logger.error(f"{self.log_prefix} 获取或处理表情发送历史时出错: {e}")
|
logger.error(f"{self.log_prefix} 获取或处理表情发送历史时出错: {e}")
|
||||||
|
|
||||||
# 4. 准备情感数据和后备列表
|
# 4. 准备情感数据和后备列表
|
||||||
emotion_map: ClassVar = {}
|
emotion_map = {}
|
||||||
all_emojis_data: ClassVar = []
|
all_emojis_data = []
|
||||||
|
|
||||||
for emoji in all_emojis_obj:
|
for emoji in all_emojis_obj:
|
||||||
b64 = image_path_to_base64(emoji.full_path)
|
b64 = image_path_to_base64(emoji.full_path)
|
||||||
@@ -146,14 +148,15 @@ class EmojiAction(BaseAction):
|
|||||||
chosen_emotion = "表情包" # 默认描述,避免变量未定义错误
|
chosen_emotion = "表情包" # 默认描述,避免变量未定义错误
|
||||||
|
|
||||||
# 4. 根据配置选择不同的表情选择模式
|
# 4. 根据配置选择不同的表情选择模式
|
||||||
|
assert global_config is not None
|
||||||
if global_config.emoji.emoji_selection_mode == "emotion":
|
if global_config.emoji.emoji_selection_mode == "emotion":
|
||||||
# --- 情感标签选择模式 ---
|
# --- 情感标签选择模式 ---
|
||||||
if not available_emotions:
|
if not available_emotions:
|
||||||
logger.warning(f"{self.log_prefix} 获取到的表情包均无情感标签, 将随机发送")
|
logger.warning(f"{self.log_prefix} 获取到的表情包均无情感标签, 将随机发送")
|
||||||
emoji_base64, emoji_description = random.choice(all_emojis_data)
|
emoji_base64, emoji_description = random.choice(all_emojis_data)
|
||||||
else:
|
else:
|
||||||
# 获取最近的5条消息内容用于判断
|
# 获取最近的20条消息内容用于判断
|
||||||
recent_messages = await message_api.get_recent_messages(chat_id=self.chat_id, limit=5)
|
recent_messages = await message_api.get_recent_messages(chat_id=self.chat_id, limit=20)
|
||||||
messages_text = ""
|
messages_text = ""
|
||||||
if recent_messages:
|
if recent_messages:
|
||||||
messages_text = await message_api.build_readable_messages(
|
messages_text = await message_api.build_readable_messages(
|
||||||
@@ -164,8 +167,15 @@ class EmojiAction(BaseAction):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 构建prompt让LLM选择情感
|
# 构建prompt让LLM选择情感
|
||||||
|
prompt_addition = ""
|
||||||
|
if main_reply_content:
|
||||||
|
prompt_addition = f"""
|
||||||
|
这是你刚刚生成、准备发送的消息:
|
||||||
|
"{main_reply_content}"
|
||||||
|
"""
|
||||||
prompt = f"""
|
prompt = f"""
|
||||||
你是一个正在进行聊天的网友,你需要根据一个理由和最近的聊天记录,从一个情感标签列表中选择最匹配的一个。
|
你是一个正在进行聊天的网友,你需要根据一个理由、最近的聊天记录以及你自己将要发送的消息,从一个情感标签列表中选择最匹配的一个。
|
||||||
|
{prompt_addition}
|
||||||
这是最近的聊天记录:
|
这是最近的聊天记录:
|
||||||
{messages_text}
|
{messages_text}
|
||||||
|
|
||||||
@@ -174,10 +184,8 @@ class EmojiAction(BaseAction):
|
|||||||
请直接返回最匹配的那个情感标签,不要进行任何解释或添加其他多余的文字。
|
请直接返回最匹配的那个情感标签,不要进行任何解释或添加其他多余的文字。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if global_config.debug.show_prompt:
|
assert global_config is not None
|
||||||
logger.info(f"{self.log_prefix} 生成的LLM Prompt: {prompt}")
|
logger.debug(f"{self.log_prefix} 生成的LLM Prompt: {prompt}")
|
||||||
else:
|
|
||||||
logger.debug(f"{self.log_prefix} 生成的LLM Prompt: {prompt}")
|
|
||||||
|
|
||||||
# 调用LLM
|
# 调用LLM
|
||||||
models = llm_api.get_available_models()
|
models = llm_api.get_available_models()
|
||||||
@@ -211,10 +219,11 @@ class EmojiAction(BaseAction):
|
|||||||
)
|
)
|
||||||
emoji_base64, emoji_description = random.choice(all_emojis_data)
|
emoji_base64, emoji_description = random.choice(all_emojis_data)
|
||||||
|
|
||||||
elif global_config.emoji.emoji_selection_mode == "description":
|
assert global_config is not None
|
||||||
|
if global_config.emoji.emoji_selection_mode == "description":
|
||||||
# --- 详细描述选择模式 ---
|
# --- 详细描述选择模式 ---
|
||||||
# 获取最近的5条消息内容用于判断
|
# 获取最近的5条消息内容用于判断
|
||||||
recent_messages = await message_api.get_recent_messages(chat_id=self.chat_id, limit=5)
|
recent_messages = await message_api.get_recent_messages(chat_id=self.chat_id, limit=20)
|
||||||
messages_text = ""
|
messages_text = ""
|
||||||
if recent_messages:
|
if recent_messages:
|
||||||
messages_text = await message_api.build_readable_messages(
|
messages_text = await message_api.build_readable_messages(
|
||||||
@@ -234,8 +243,15 @@ class EmojiAction(BaseAction):
|
|||||||
emoji_descriptions = [extract_refined_info(desc) for _, desc in all_emojis_data]
|
emoji_descriptions = [extract_refined_info(desc) for _, desc in all_emojis_data]
|
||||||
|
|
||||||
# 构建prompt让LLM选择描述
|
# 构建prompt让LLM选择描述
|
||||||
|
prompt_addition = ""
|
||||||
|
if main_reply_content:
|
||||||
|
prompt_addition = f"""
|
||||||
|
这是你刚刚生成、准备发送的消息:
|
||||||
|
"{main_reply_content}"
|
||||||
|
"""
|
||||||
prompt = f"""
|
prompt = f"""
|
||||||
你是一个正在进行聊天的网友,你需要根据一个理由和最近的聊天记录,从一个表情包描述列表中选择最匹配的一个。
|
你是一个正在进行聊天的网友,你需要根据一个理由、最近的聊天记录以及你自己将要发送的消息,从一个表情包描述列表中选择最匹配的一个。
|
||||||
|
{prompt_addition}
|
||||||
这是最近的聊天记录:
|
这是最近的聊天记录:
|
||||||
{messages_text}
|
{messages_text}
|
||||||
|
|
||||||
@@ -264,44 +280,22 @@ class EmojiAction(BaseAction):
|
|||||||
chosen_emotion = chosen_description # 在描述模式下,用描述作为情感标签
|
chosen_emotion = chosen_description # 在描述模式下,用描述作为情感标签
|
||||||
logger.info(f"{self.log_prefix} LLM选择的描述: {chosen_description}")
|
logger.info(f"{self.log_prefix} LLM选择的描述: {chosen_description}")
|
||||||
|
|
||||||
# 优化匹配逻辑:优先在精炼描述中精确匹配,然后进行关键词匹配
|
# 使用更鲁棒的子字符串匹配逻辑
|
||||||
def extract_refined_info(full_desc: str) -> str:
|
|
||||||
return full_desc.split(" Desc:")[0].strip()
|
|
||||||
|
|
||||||
# 1. 尝试在精炼描述中找到最匹配的表情
|
|
||||||
# 我们假设LLM返回的是精炼描述的一部分或全部
|
|
||||||
matched_emoji = None
|
matched_emoji = None
|
||||||
best_match_score = 0
|
for b64, desc in all_emojis_data:
|
||||||
|
# 检查LLM返回的描述是否是数据库中某个表情完整描述的一部分
|
||||||
for item in all_emojis_data:
|
if chosen_description in desc:
|
||||||
refined_info = extract_refined_info(item[1])
|
matched_emoji = (b64, desc)
|
||||||
# 计算一个简单的匹配分数
|
break
|
||||||
score = 0
|
|
||||||
if chosen_description.lower() in refined_info.lower():
|
|
||||||
score += 2 # 包含匹配
|
|
||||||
if refined_info.lower() in chosen_description.lower():
|
|
||||||
score += 2 # 包含匹配
|
|
||||||
|
|
||||||
# 关键词匹配加分
|
|
||||||
chosen_keywords = re.findall(r"\w+", chosen_description.lower())
|
|
||||||
item_keywords = re.findall(r"\[(.*?)\]", refined_info)
|
|
||||||
if item_keywords:
|
|
||||||
item_keywords_set = {k.strip().lower() for k in item_keywords[0].split(",")}
|
|
||||||
for kw in chosen_keywords:
|
|
||||||
if kw in item_keywords_set:
|
|
||||||
score += 1
|
|
||||||
|
|
||||||
if score > best_match_score:
|
|
||||||
best_match_score = score
|
|
||||||
matched_emoji = item
|
|
||||||
|
|
||||||
if matched_emoji:
|
if matched_emoji:
|
||||||
emoji_base64, emoji_description = matched_emoji
|
emoji_base64, emoji_description = matched_emoji
|
||||||
logger.info(f"{self.log_prefix} 找到匹配描述的表情包: {extract_refined_info(emoji_description)}")
|
logger.info(f"{self.log_prefix} 找到匹配描述的表情包: {emoji_description}")
|
||||||
else:
|
else:
|
||||||
logger.warning(f"{self.log_prefix} LLM选择的描述无法匹配任何表情包, 将随机选择")
|
logger.warning(f"{self.log_prefix} LLM选择的描述无法匹配任何表情包, 将随机选择")
|
||||||
emoji_base64, emoji_description = random.choice(all_emojis_data)
|
emoji_base64, emoji_description = random.choice(all_emojis_data)
|
||||||
else:
|
else:
|
||||||
|
assert global_config is not None
|
||||||
logger.error(f"{self.log_prefix} 无效的表情选择模式: {global_config.emoji.emoji_selection_mode}")
|
logger.error(f"{self.log_prefix} 无效的表情选择模式: {global_config.emoji.emoji_selection_mode}")
|
||||||
return False, "无效的表情选择模式"
|
return False, "无效的表情选择模式"
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user