修复代码格式和文件名大小写问题

This commit is contained in:
Windpicker-owo
2025-08-31 20:50:17 +08:00
parent df29014e41
commit 8149731925
218 changed files with 6913 additions and 8257 deletions

View File

@@ -260,89 +260,105 @@ def get_actions_by_timestamp_with_chat(
) -> List[Dict[str, Any]]:
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录,按时间升序排序,返回动作记录列表"""
from src.common.logger import get_logger
logger = get_logger("chat_message_builder")
# 记录函数调用参数
logger.debug(f"[get_actions_by_timestamp_with_chat] 调用参数: chat_id={chat_id}, "
f"timestamp_start={timestamp_start}, timestamp_end={timestamp_end}, "
f"limit={limit}, limit_mode={limit_mode}")
logger.debug(
f"[get_actions_by_timestamp_with_chat] 调用参数: chat_id={chat_id}, "
f"timestamp_start={timestamp_start}, timestamp_end={timestamp_end}, "
f"limit={limit}, limit_mode={limit_mode}"
)
with get_db_session() as session:
if limit > 0:
if limit_mode == "latest":
query = session.execute(select(ActionRecords).where(
and_(
ActionRecords.chat_id == chat_id,
ActionRecords.time > timestamp_start,
ActionRecords.time < timestamp_end
query = session.execute(
select(ActionRecords)
.where(
and_(
ActionRecords.chat_id == chat_id,
ActionRecords.time > timestamp_start,
ActionRecords.time < timestamp_end,
)
)
).order_by(ActionRecords.time.desc()).limit(limit))
.order_by(ActionRecords.time.desc())
.limit(limit)
)
actions = list(query.scalars())
actions_result = []
for action in reversed(actions):
action_dict = {
'id': action.id,
'action_id': action.action_id,
'time': action.time,
'action_name': action.action_name,
'action_data': action.action_data,
'action_done': action.action_done,
'action_build_into_prompt': action.action_build_into_prompt,
'action_prompt_display': action.action_prompt_display,
'chat_id': action.chat_id,
'chat_info_stream_id': action.chat_info_stream_id,
'chat_info_platform': action.chat_info_platform,
"id": action.id,
"action_id": action.action_id,
"time": action.time,
"action_name": action.action_name,
"action_data": action.action_data,
"action_done": action.action_done,
"action_build_into_prompt": action.action_build_into_prompt,
"action_prompt_display": action.action_prompt_display,
"chat_id": action.chat_id,
"chat_info_stream_id": action.chat_info_stream_id,
"chat_info_platform": action.chat_info_platform,
}
actions_result.append(action_dict)
else: # earliest
query = session.execute(select(ActionRecords).where(
and_(
ActionRecords.chat_id == chat_id,
ActionRecords.time > timestamp_start,
ActionRecords.time < timestamp_end
query = session.execute(
select(ActionRecords)
.where(
and_(
ActionRecords.chat_id == chat_id,
ActionRecords.time > timestamp_start,
ActionRecords.time < timestamp_end,
)
)
).order_by(ActionRecords.time.asc()).limit(limit))
.order_by(ActionRecords.time.asc())
.limit(limit)
)
actions = list(query.scalars())
actions_result = []
for action in actions:
action_dict = {
'id': action.id,
'action_id': action.action_id,
'time': action.time,
'action_name': action.action_name,
'action_data': action.action_data,
'action_done': action.action_done,
'action_build_into_prompt': action.action_build_into_prompt,
'action_prompt_display': action.action_prompt_display,
'chat_id': action.chat_id,
'chat_info_stream_id': action.chat_info_stream_id,
'chat_info_platform': action.chat_info_platform,
"id": action.id,
"action_id": action.action_id,
"time": action.time,
"action_name": action.action_name,
"action_data": action.action_data,
"action_done": action.action_done,
"action_build_into_prompt": action.action_build_into_prompt,
"action_prompt_display": action.action_prompt_display,
"chat_id": action.chat_id,
"chat_info_stream_id": action.chat_info_stream_id,
"chat_info_platform": action.chat_info_platform,
}
actions_result.append(action_dict)
else:
query = session.execute(select(ActionRecords).where(
and_(
ActionRecords.chat_id == chat_id,
ActionRecords.time > timestamp_start,
ActionRecords.time < timestamp_end
query = session.execute(
select(ActionRecords)
.where(
and_(
ActionRecords.chat_id == chat_id,
ActionRecords.time > timestamp_start,
ActionRecords.time < timestamp_end,
)
)
).order_by(ActionRecords.time.asc()))
.order_by(ActionRecords.time.asc())
)
actions = list(query.scalars())
actions_result = []
for action in actions:
action_dict = {
'id': action.id,
'action_id': action.action_id,
'time': action.time,
'action_name': action.action_name,
'action_data': action.action_data,
'action_done': action.action_done,
'action_build_into_prompt': action.action_build_into_prompt,
'action_prompt_display': action.action_prompt_display,
'chat_id': action.chat_id,
'chat_info_stream_id': action.chat_info_stream_id,
'chat_info_platform': action.chat_info_platform,
"id": action.id,
"action_id": action.action_id,
"time": action.time,
"action_name": action.action_name,
"action_data": action.action_data,
"action_done": action.action_done,
"action_build_into_prompt": action.action_build_into_prompt,
"action_prompt_display": action.action_prompt_display,
"chat_id": action.chat_id,
"chat_info_stream_id": action.chat_info_stream_id,
"chat_info_platform": action.chat_info_platform,
}
actions_result.append(action_dict)
return actions_result
@@ -355,31 +371,45 @@ def get_actions_by_timestamp_with_chat_inclusive(
with get_db_session() as session:
if limit > 0:
if limit_mode == "latest":
query = session.execute(select(ActionRecords).where(
and_(
ActionRecords.chat_id == chat_id,
ActionRecords.time >= timestamp_start,
ActionRecords.time <= timestamp_end
query = session.execute(
select(ActionRecords)
.where(
and_(
ActionRecords.chat_id == chat_id,
ActionRecords.time >= timestamp_start,
ActionRecords.time <= timestamp_end,
)
)
).order_by(ActionRecords.time.desc()).limit(limit))
.order_by(ActionRecords.time.desc())
.limit(limit)
)
actions = list(query.scalars())
return [action.__dict__ for action in reversed(actions)]
else: # earliest
query = session.execute(select(ActionRecords).where(
query = session.execute(
select(ActionRecords)
.where(
and_(
ActionRecords.chat_id == chat_id,
ActionRecords.time >= timestamp_start,
ActionRecords.time <= timestamp_end,
)
)
.order_by(ActionRecords.time.asc())
.limit(limit)
)
else:
query = session.execute(
select(ActionRecords)
.where(
and_(
ActionRecords.chat_id == chat_id,
ActionRecords.time >= timestamp_start,
ActionRecords.time <= timestamp_end
ActionRecords.time <= timestamp_end,
)
).order_by(ActionRecords.time.asc()).limit(limit))
else:
query = session.execute(select(ActionRecords).where(
and_(
ActionRecords.chat_id == chat_id,
ActionRecords.time >= timestamp_start,
ActionRecords.time <= timestamp_end
)
).order_by(ActionRecords.time.asc()))
.order_by(ActionRecords.time.asc())
)
actions = list(query.scalars())
return [action.__dict__ for action in actions]
@@ -782,7 +812,6 @@ def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
# 按图片编号排序
sorted_items = sorted(pic_id_mapping.items(), key=lambda x: int(x[1].replace("图片", "")))
for pic_id, display_name in sorted_items:
# 从数据库中获取图片描述
description = "内容正在阅读,请稍等"
@@ -791,7 +820,8 @@ def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
image = session.execute(select(Images).where(Images.image_id == pic_id)).scalar()
if image and image.description:
description = image.description
except Exception: ...
except Exception:
...
# 如果查询失败,保持默认描述
mapping_lines.append(f"[{display_name}] 的内容:{description}")
@@ -811,17 +841,18 @@ def build_readable_actions(actions: List[Dict[str, Any]]) -> str:
格式化的动作字符串。
"""
from src.common.logger import get_logger
logger = get_logger("chat_message_builder")
logger.debug(f"[build_readable_actions] 开始处理 {len(actions) if actions else 0} 条动作记录")
if not actions:
logger.debug("[build_readable_actions] 动作记录为空,返回空字符串")
return ""
output_lines = []
current_time = time.time()
logger.debug(f"[build_readable_actions] 当前时间戳: {current_time}")
# The get functions return actions sorted ascending by time. Let's reverse it to show newest first.
@@ -830,12 +861,12 @@ def build_readable_actions(actions: List[Dict[str, Any]]) -> str:
for i, action in enumerate(actions):
logger.debug(f"[build_readable_actions] === 处理第 {i} 条动作记录 ===")
logger.debug(f"[build_readable_actions] 原始动作数据: {action}")
action_time = action.get("time", current_time)
action_name = action.get("action_name", "未知动作")
logger.debug(f"[build_readable_actions] 动作时间戳: {action_time}, 动作名称: '{action_name}'")
# 检查是否是原始的 action_name 值
original_action_name = action.get("action_name")
if original_action_name is None:
@@ -844,7 +875,7 @@ def build_readable_actions(actions: List[Dict[str, Any]]) -> str:
logger.error(f"[build_readable_actions] 动作 #{i}: action_name 为空字符串!")
elif original_action_name == "未知动作":
logger.error(f"[build_readable_actions] 动作 #{i}: action_name 已经是'未知动作'!")
if action_name in ["no_action", "no_reply"]:
logger.debug(f"[build_readable_actions] 跳过动作 #{i}: {action_name} (在跳过列表中)")
continue
@@ -863,7 +894,7 @@ def build_readable_actions(actions: List[Dict[str, Any]]) -> str:
logger.debug(f"[build_readable_actions] 时间描述: '{time_ago_str}'")
line = f"{time_ago_str},你使用了\"{action_name}\",具体内容是:\"{action_prompt_display}\""
line = f'{time_ago_str},你使用了"{action_name}",具体内容是:"{action_prompt_display}"'
logger.debug(f"[build_readable_actions] 生成的行: '{line}'")
output_lines.append(line)
@@ -964,23 +995,26 @@ def build_readable_messages(
chat_id = copy_messages[0].get("chat_id") if copy_messages else None
from src.common.database.sqlalchemy_database_api import get_db_session
with get_db_session() as session:
# 获取这个时间范围内的动作记录并匹配chat_id
actions_in_range = session.execute(select(ActionRecords).where(
and_(
ActionRecords.time >= min_time,
ActionRecords.time <= max_time,
ActionRecords.chat_id == chat_id
actions_in_range = session.execute(
select(ActionRecords)
.where(
and_(
ActionRecords.time >= min_time, ActionRecords.time <= max_time, ActionRecords.chat_id == chat_id
)
)
).order_by(ActionRecords.time)).scalars()
.order_by(ActionRecords.time)
).scalars()
# 获取最新消息之后的第一个动作记录
action_after_latest = session.execute(select(ActionRecords).where(
and_(
ActionRecords.time > max_time,
ActionRecords.chat_id == chat_id
)
).order_by(ActionRecords.time).limit(1)).scalars()
action_after_latest = session.execute(
select(ActionRecords)
.where(and_(ActionRecords.time > max_time, ActionRecords.chat_id == chat_id))
.order_by(ActionRecords.time)
.limit(1)
).scalars()
# 合并两部分动作记录,并转为 dict避免 DetachedInstanceError
actions = [

View File

@@ -12,6 +12,7 @@ install(extra_lines=3)
logger = get_logger("prompt_build")
class PromptContext:
def __init__(self):
self._context_prompts: Dict[str, Dict[str, "Prompt"]] = {}
@@ -27,7 +28,7 @@ class PromptContext:
@_current_context.setter
def _current_context(self, value: Optional[str]):
"""设置当前协程的上下文ID"""
self._current_context_var.set(value) # type: ignore
self._current_context_var.set(value) # type: ignore
@asynccontextmanager
async def async_scope(self, context_id: Optional[str] = None):
@@ -51,7 +52,7 @@ class PromptContext:
# 保存当前协程的上下文值,不影响其他协程
previous_context = self._current_context
# 设置当前协程的新上下文
token = self._current_context_var.set(context_id) if context_id else None # type: ignore
token = self._current_context_var.set(context_id) if context_id else None # type: ignore
else:
# 如果没有提供新上下文,保持当前上下文不变
previous_context = self._current_context
@@ -69,7 +70,8 @@ class PromptContext:
# 如果reset失败尝试直接设置
try:
self._current_context = previous_context
except Exception: ...
except Exception:
...
# 静默忽略恢复失败
async def get_prompt_async(self, name: str) -> Optional["Prompt"]:
@@ -174,7 +176,9 @@ class Prompt(str):
"""将临时标记还原为实际的花括号字符"""
return template.replace(Prompt._TEMP_LEFT_BRACE, "{").replace(Prompt._TEMP_RIGHT_BRACE, "}")
def __new__(cls, fstr, name: Optional[str] = None, args: Optional[Union[List[Any], tuple[Any, ...]]] = None, **kwargs):
def __new__(
cls, fstr, name: Optional[str] = None, args: Optional[Union[List[Any], tuple[Any, ...]]] = None, **kwargs
):
# 如果传入的是元组,转换为列表
if isinstance(args, tuple):
args = list(args)
@@ -219,7 +223,9 @@ class Prompt(str):
return prompt
@classmethod
def _format_template(cls, template, args: Optional[List[Any]] = None, kwargs: Optional[Dict[str, Any]] = None) -> str:
def _format_template(
cls, template, args: Optional[List[Any]] = None, kwargs: Optional[Dict[str, Any]] = None
) -> str:
if kwargs is None:
kwargs = {}
# 预处理模板中的转义花括号

View File

@@ -2,6 +2,7 @@
智能提示词参数模块 - 优化参数结构
简化SmartPromptParameters减少冗余和重复
"""
from dataclasses import dataclass, field
from typing import Dict, Any, Optional, List, Literal
@@ -9,6 +10,7 @@ from typing import Dict, Any, Optional, List, Literal
@dataclass
class SmartPromptParameters:
"""简化的智能提示词参数系统"""
# 基础参数
chat_id: str = ""
is_group_chat: bool = False
@@ -17,7 +19,7 @@ class SmartPromptParameters:
reply_to: str = ""
extra_info: str = ""
prompt_mode: Literal["s4u", "normal", "minimal"] = "s4u"
# 功能开关
enable_tool: bool = True
enable_memory: bool = True
@@ -25,20 +27,20 @@ class SmartPromptParameters:
enable_relation: bool = True
enable_cross_context: bool = True
enable_knowledge: bool = True
# 性能控制
max_context_messages: int = 50
# 调试选项
debug_mode: bool = False
# 聊天历史和上下文
chat_target_info: Optional[Dict[str, Any]] = None
message_list_before_now_long: List[Dict[str, Any]] = field(default_factory=list)
message_list_before_short: List[Dict[str, Any]] = field(default_factory=list)
chat_talking_prompt_short: str = ""
target_user_info: Optional[Dict[str, Any]] = None
# 已构建的内容块
expression_habits_block: str = ""
relation_info_block: str = ""
@@ -46,7 +48,7 @@ class SmartPromptParameters:
tool_info_block: str = ""
knowledge_prompt: str = ""
cross_context_block: str = ""
# 其他内容块
keywords_reaction_prompt: str = ""
extra_info_block: str = ""
@@ -57,7 +59,10 @@ class SmartPromptParameters:
reply_target_block: str = ""
mood_prompt: str = ""
action_descriptions: str = ""
# 可用动作信息
available_actions: Optional[Dict[str, Any]] = None
def validate(self) -> List[str]:
"""统一的参数验证"""
errors = []
@@ -68,39 +73,39 @@ class SmartPromptParameters:
if self.max_context_messages <= 0:
errors.append("max_context_messages必须大于0")
return errors
def get_needed_build_tasks(self) -> List[str]:
"""获取需要执行的任务列表"""
tasks = []
if self.enable_expression and not self.expression_habits_block:
tasks.append("expression_habits")
if self.enable_memory and not self.memory_block:
tasks.append("memory_block")
if self.enable_relation and not self.relation_info_block:
tasks.append("relation_info")
if self.enable_tool and not self.tool_info_block:
tasks.append("tool_info")
if self.enable_knowledge and not self.knowledge_prompt:
tasks.append("knowledge_info")
if self.enable_cross_context and not self.cross_context_block:
tasks.append("cross_context")
return tasks
@classmethod
def from_legacy_params(cls, **kwargs) -> 'SmartPromptParameters':
def from_legacy_params(cls, **kwargs) -> "SmartPromptParameters":
"""
从旧版参数创建新参数对象
Args:
**kwargs: 旧版参数
Returns:
SmartPromptParameters: 新参数对象
"""
@@ -113,7 +118,6 @@ class SmartPromptParameters:
reply_to=kwargs.get("reply_to", ""),
extra_info=kwargs.get("extra_info", ""),
prompt_mode=kwargs.get("current_prompt_mode", "s4u"),
# 功能开关
enable_tool=kwargs.get("enable_tool", True),
enable_memory=kwargs.get("enable_memory", True),
@@ -121,18 +125,15 @@ class SmartPromptParameters:
enable_relation=kwargs.get("enable_relation", True),
enable_cross_context=kwargs.get("enable_cross_context", True),
enable_knowledge=kwargs.get("enable_knowledge", True),
# 性能控制
max_context_messages=kwargs.get("max_context_messages", 50),
debug_mode=kwargs.get("debug_mode", False),
# 聊天历史和上下文
chat_target_info=kwargs.get("chat_target_info"),
message_list_before_now_long=kwargs.get("message_list_before_now_long", []),
message_list_before_short=kwargs.get("message_list_before_short", []),
chat_talking_prompt_short=kwargs.get("chat_talking_prompt_short", ""),
target_user_info=kwargs.get("target_user_info"),
# 已构建的内容块
expression_habits_block=kwargs.get("expression_habits_block", ""),
relation_info_block=kwargs.get("relation_info", ""),
@@ -140,7 +141,6 @@ class SmartPromptParameters:
tool_info_block=kwargs.get("tool_info", ""),
knowledge_prompt=kwargs.get("knowledge_prompt", ""),
cross_context_block=kwargs.get("cross_context_block", ""),
# 其他内容块
keywords_reaction_prompt=kwargs.get("keywords_reaction_prompt", ""),
extra_info_block=kwargs.get("extra_info_block", ""),
@@ -151,4 +151,6 @@ class SmartPromptParameters:
reply_target_block=kwargs.get("reply_target_block", ""),
mood_prompt=kwargs.get("mood_prompt", ""),
action_descriptions=kwargs.get("action_descriptions", ""),
)
# 可用动作信息
available_actions=kwargs.get("available_actions", None),
)

View File

@@ -2,16 +2,14 @@
共享提示词工具模块 - 消除重复代码
提供统一的工具函数供DefaultReplyer和SmartPrompt使用
"""
import re
import time
import asyncio
from typing import Dict, Any, List, Optional, Tuple, Union
from datetime import datetime
from typing import Dict, Any, Optional, Tuple
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.utils.chat_message_builder import (
build_readable_messages,
get_raw_msg_before_timestamp_with_chat,
build_readable_messages_with_id,
)
@@ -23,25 +21,25 @@ logger = get_logger("prompt_utils")
class PromptUtils:
"""提示词工具类 - 提供共享功能,移除缓存相关功能和依赖检查"""
@staticmethod
def parse_reply_target(target_message: str) -> Tuple[str, str]:
"""
解析回复目标消息 - 统一实现
Args:
target_message: 目标消息,格式为 "发送者:消息内容""发送者:消息内容"
Returns:
Tuple[str, str]: (发送者名称, 消息内容)
"""
sender = ""
target = ""
# 添加None检查防止NoneType错误
if target_message is None:
return sender, target
if ":" in target_message or "" in target_message:
# 使用正则表达式匹配中文或英文冒号
parts = re.split(pattern=r"[:]", string=target_message, maxsplit=1)
@@ -49,16 +47,16 @@ class PromptUtils:
sender = parts[0].strip()
target = parts[1].strip()
return sender, target
@staticmethod
async def build_relation_info(chat_id: str, reply_to: str) -> str:
"""
构建关系信息 - 统一实现
Args:
chat_id: 聊天ID
reply_to: 回复目标字符串
Returns:
str: 关系信息字符串
"""
@@ -66,8 +64,9 @@ class PromptUtils:
return ""
from src.person_info.relationship_fetcher import relationship_fetcher_manager
relationship_fetcher = relationship_fetcher_manager.get_fetcher(chat_id)
if not reply_to:
return ""
sender, text = PromptUtils.parse_reply_target(reply_to)
@@ -82,21 +81,19 @@ class PromptUtils:
return f"你完全不认识{sender}不理解ta的相关信息。"
return await relationship_fetcher.build_relation_info(person_id, points_num=5)
@staticmethod
async def build_cross_context(
chat_id: str,
target_user_info: Optional[Dict[str, Any]],
current_prompt_mode: str
chat_id: str, target_user_info: Optional[Dict[str, Any]], current_prompt_mode: str
) -> str:
"""
构建跨群聊上下文 - 统一实现完全继承DefaultReplyer功能
Args:
chat_id: 当前聊天ID
target_user_info: 目标用户信息
current_prompt_mode: 当前提示模式
Returns:
str: 跨群上下文块
"""
@@ -108,7 +105,7 @@ class PromptUtils:
current_stream = get_chat_manager().get_stream(chat_id)
if not current_stream or not current_stream.group_info:
return ""
try:
current_chat_raw_id = current_stream.group_info.group_id
except Exception as e:
@@ -144,7 +141,7 @@ class PromptUtils:
if messages:
chat_name = get_chat_manager().get_stream_name(stream_id) or stream_id
formatted_messages, _ = build_readable_messages_with_id(messages, timestamp_mode="relative")
cross_context_messages.append(f"[以下是来自\"{chat_name}\"的近期消息]\n{formatted_messages}")
cross_context_messages.append(f'[以下是来自"{chat_name}"的近期消息]\n{formatted_messages}')
except Exception as e:
logger.error(f"获取群聊{chat_raw_id}的消息失败: {e}")
continue
@@ -175,14 +172,15 @@ class PromptUtils:
if user_messages:
chat_name = get_chat_manager().get_stream_name(stream_id) or stream_id
user_name = (
target_user_info.get("person_name") or
target_user_info.get("user_nickname") or user_id
target_user_info.get("person_name")
or target_user_info.get("user_nickname")
or user_id
)
formatted_messages, _ = build_readable_messages_with_id(
user_messages, timestamp_mode="relative"
)
cross_context_messages.append(
f"[以下是\"{user_name}\"\"{chat_name}\"的近期发言]\n{formatted_messages}"
f'[以下是"{user_name}""{chat_name}"的近期发言]\n{formatted_messages}'
)
except Exception as e:
logger.error(f"获取用户{user_id}在群聊{chat_raw_id}的消息失败: {e}")
@@ -192,31 +190,31 @@ class PromptUtils:
return ""
return "# 跨群上下文参考\n" + "\n\n".join(cross_context_messages) + "\n"
@staticmethod
def parse_reply_target_id(reply_to: str) -> str:
"""
解析回复目标中的用户ID
Args:
reply_to: 回复目标字符串
Returns:
str: 用户ID
"""
if not reply_to:
return ""
# 复用parse_reply_target方法的逻辑
sender, _ = PromptUtils.parse_reply_target(reply_to)
if not sender:
return ""
# 获取用户ID
person_info_manager = get_person_info_manager()
person_id = person_info_manager.get_person_id_by_person_name(sender)
if person_id:
user_id = person_info_manager.get_value_sync(person_id, "user_id")
return str(user_id) if user_id else ""
return ""
return ""

File diff suppressed because it is too large Load Diff

View File

@@ -13,45 +13,45 @@ from src.manager.local_store_manager import local_storage
logger = get_logger("maibot_statistic")
# 同步包装器函数用于在非异步环境中调用异步数据库API
def _sync_db_get(model_class, filters=None, order_by=None, limit=None, single_result=False):
"""同步版本的db_get用于在线程池中调用"""
import asyncio
try:
loop = asyncio.get_event_loop()
if loop.is_running():
# 如果事件循环正在运行,创建新的事件循环
import threading
result = None
exception = None
def run_in_thread():
nonlocal result, exception
try:
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
result = new_loop.run_until_complete(
db_get(model_class, filters, limit, order_by, single_result)
)
result = new_loop.run_until_complete(db_get(model_class, filters, limit, order_by, single_result))
new_loop.close()
except Exception as e:
exception = e
thread = threading.Thread(target=run_in_thread)
thread.start()
thread.join()
if exception:
raise exception
return result
else:
return loop.run_until_complete(
db_get(model_class, filters, limit, order_by, single_result)
)
return loop.run_until_complete(db_get(model_class, filters, limit, order_by, single_result))
except RuntimeError:
# 没有事件循环,创建一个新的
return asyncio.run(db_get(model_class, filters, limit, order_by, single_result))
# 统计数据的键
TOTAL_REQ_CNT = "total_requests"
TOTAL_COST = "total_cost"
@@ -112,7 +112,7 @@ class OnlineTimeRecordTask(AsyncTask):
model_class=OnlineTime,
query_type="update",
filters={"id": self.record_id},
data={"end_timestamp": extended_end_time}
data={"end_timestamp": extended_end_time},
)
if updated_rows == 0:
# Record might have been deleted or ID is stale, try to find/create
@@ -126,17 +126,17 @@ class OnlineTimeRecordTask(AsyncTask):
filters={"end_timestamp": {"$gte": recent_threshold}},
order_by="-end_timestamp",
limit=1,
single_result=True
single_result=True,
)
if recent_records:
# 找到近期记录,更新它
self.record_id = recent_records['id']
self.record_id = recent_records["id"]
await db_query(
model_class=OnlineTime,
query_type="update",
filters={"id": self.record_id},
data={"end_timestamp": extended_end_time}
data={"end_timestamp": extended_end_time},
)
else:
# 创建新记录
@@ -147,10 +147,10 @@ class OnlineTimeRecordTask(AsyncTask):
"duration": 5, # 初始时长为5分钟
"start_timestamp": current_time,
"end_timestamp": extended_end_time,
}
},
)
if new_record:
self.record_id = new_record['id']
self.record_id = new_record["id"]
except Exception as e:
logger.error(f"在线时间记录失败,错误信息:{e}")
@@ -368,20 +368,19 @@ class StatisticOutputTask(AsyncTask):
# 以最早的时间戳为起始时间获取记录
query_start_time = collect_period[-1][1]
records = _sync_db_get(
model_class=LLMUsage,
filters={"timestamp": {"$gte": query_start_time}},
order_by="-timestamp"
) or []
records = (
_sync_db_get(model_class=LLMUsage, filters={"timestamp": {"$gte": query_start_time}}, order_by="-timestamp")
or []
)
for record in records:
if not isinstance(record, dict):
continue
record_timestamp = record.get('timestamp')
record_timestamp = record.get("timestamp")
if isinstance(record_timestamp, str):
record_timestamp = datetime.fromisoformat(record_timestamp)
if not record_timestamp:
continue
@@ -390,9 +389,9 @@ class StatisticOutputTask(AsyncTask):
for period_key, _ in collect_period[idx:]:
stats[period_key][TOTAL_REQ_CNT] += 1
request_type = record.get('request_type') or "unknown"
user_id = record.get('user_id') or "unknown"
model_name = record.get('model_name') or "unknown"
request_type = record.get("request_type") or "unknown"
user_id = record.get("user_id") or "unknown"
model_name = record.get("model_name") or "unknown"
# 提取模块名:如果请求类型包含".",取第一个"."之前的部分
module_name = request_type.split(".")[0] if "." in request_type else request_type
@@ -402,8 +401,8 @@ class StatisticOutputTask(AsyncTask):
stats[period_key][REQ_CNT_BY_MODEL][model_name] += 1
stats[period_key][REQ_CNT_BY_MODULE][module_name] += 1
prompt_tokens = record.get('prompt_tokens') or 0
completion_tokens = record.get('completion_tokens') or 0
prompt_tokens = record.get("prompt_tokens") or 0
completion_tokens = record.get("completion_tokens") or 0
total_tokens = prompt_tokens + completion_tokens
stats[period_key][IN_TOK_BY_TYPE][request_type] += prompt_tokens
@@ -421,40 +420,40 @@ class StatisticOutputTask(AsyncTask):
stats[period_key][TOTAL_TOK_BY_MODEL][model_name] += total_tokens
stats[period_key][TOTAL_TOK_BY_MODULE][module_name] += total_tokens
cost = record.get('cost') or 0.0
cost = record.get("cost") or 0.0
stats[period_key][TOTAL_COST] += cost
stats[period_key][COST_BY_TYPE][request_type] += cost
stats[period_key][COST_BY_USER][user_id] += cost
stats[period_key][COST_BY_MODEL][model_name] += cost
stats[period_key][COST_BY_MODULE][module_name] += cost
# 收集time_cost数据
time_cost = record.get('time_cost') or 0.0
time_cost = record.get("time_cost") or 0.0
if time_cost > 0: # 只记录有效的time_cost
stats[period_key][TIME_COST_BY_TYPE][request_type].append(time_cost)
stats[period_key][TIME_COST_BY_USER][user_id].append(time_cost)
stats[period_key][TIME_COST_BY_MODEL][model_name].append(time_cost)
stats[period_key][TIME_COST_BY_MODULE][module_name].append(time_cost)
break
# 计算平均耗时和标准差
# 计算平均耗时和标准差
for period_key in stats:
for category in [REQ_CNT_BY_TYPE, REQ_CNT_BY_USER, REQ_CNT_BY_MODEL, REQ_CNT_BY_MODULE]:
time_cost_key = f"time_costs_by_{category.split('_')[-1]}"
avg_key = f"avg_time_costs_by_{category.split('_')[-1]}"
std_key = f"std_time_costs_by_{category.split('_')[-1]}"
for item_name in stats[period_key][category]:
time_costs = stats[period_key][time_cost_key].get(item_name, [])
if time_costs:
# 计算平均耗时
avg_time_cost = sum(time_costs) / len(time_costs)
stats[period_key][avg_key][item_name] = round(avg_time_cost, 3)
# 计算标准差
if len(time_costs) > 1:
variance = sum((x - avg_time_cost) ** 2 for x in time_costs) / len(time_costs)
std_time_cost = variance ** 0.5
std_time_cost = variance**0.5
stats[period_key][std_key][item_name] = round(std_time_cost, 3)
else:
stats[period_key][std_key][item_name] = 0.0
@@ -483,21 +482,22 @@ class StatisticOutputTask(AsyncTask):
}
query_start_time = collect_period[-1][1]
records = _sync_db_get(
model_class=OnlineTime,
filters={"end_timestamp": {"$gte": query_start_time}},
order_by="-end_timestamp"
) or []
records = (
_sync_db_get(
model_class=OnlineTime, filters={"end_timestamp": {"$gte": query_start_time}}, order_by="-end_timestamp"
)
or []
)
for record in records:
if not isinstance(record, dict):
continue
record_end_timestamp = record.get('end_timestamp')
record_end_timestamp = record.get("end_timestamp")
if isinstance(record_end_timestamp, str):
record_end_timestamp = datetime.fromisoformat(record_end_timestamp)
record_start_timestamp = record.get('start_timestamp')
record_start_timestamp = record.get("start_timestamp")
if isinstance(record_start_timestamp, str):
record_start_timestamp = datetime.fromisoformat(record_start_timestamp)
@@ -539,16 +539,15 @@ class StatisticOutputTask(AsyncTask):
}
query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp)
records = _sync_db_get(
model_class=Messages,
filters={"time": {"$gte": query_start_timestamp}},
order_by="-time"
) or []
records = (
_sync_db_get(model_class=Messages, filters={"time": {"$gte": query_start_timestamp}}, order_by="-time")
or []
)
for message in records:
if not isinstance(message, dict):
continue
message_time_ts = message.get('time') # This is a float timestamp
message_time_ts = message.get("time") # This is a float timestamp
if not message_time_ts:
continue
@@ -557,18 +556,16 @@ class StatisticOutputTask(AsyncTask):
chat_name = None
# Logic based on SQLAlchemy model structure, aiming to replicate original intent
if message.get('chat_info_group_id'):
if message.get("chat_info_group_id"):
chat_id = f"g{message['chat_info_group_id']}"
chat_name = message.get('chat_info_group_name') or f"{message['chat_info_group_id']}"
elif message.get('user_id'): # Fallback to sender's info for chat_id if not a group_info based chat
chat_name = message.get("chat_info_group_name") or f"{message['chat_info_group_id']}"
elif message.get("user_id"): # Fallback to sender's info for chat_id if not a group_info based chat
# This uses the message SENDER's ID as per original logic's fallback
chat_id = f"u{message['user_id']}" # SENDER's user_id
chat_name = message.get('user_nickname') # SENDER's nickname
chat_name = message.get("user_nickname") # SENDER's nickname
else:
# If neither group_id nor sender_id is available for chat identification
logger.warning(
f"Message (PK: {message.get('id', 'N/A')}) lacks group_id and user_id for chat stats."
)
logger.warning(f"Message (PK: {message.get('id', 'N/A')}) lacks group_id and user_id for chat stats.")
continue
if not chat_id: # Should not happen if above logic is correct
@@ -589,8 +586,6 @@ class StatisticOutputTask(AsyncTask):
break
return stats
def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]:
"""
收集各时间段的统计数据
@@ -721,7 +716,9 @@ class StatisticOutputTask(AsyncTask):
cost = stats[COST_BY_MODEL][model_name]
avg_time_cost = stats[AVG_TIME_COST_BY_MODEL][model_name]
std_time_cost = stats[STD_TIME_COST_BY_MODEL][model_name]
output.append(data_fmt.format(name, count, in_tokens, out_tokens, tokens, cost, avg_time_cost, std_time_cost))
output.append(
data_fmt.format(name, count, in_tokens, out_tokens, tokens, cost, avg_time_cost, std_time_cost)
)
output.append("")
return "\n".join(output)
@@ -1109,13 +1106,11 @@ class StatisticOutputTask(AsyncTask):
# 查询LLM使用记录
query_start_time = start_time
records = _sync_db_get(
model_class=LLMUsage,
filters={"timestamp": {"$gte": query_start_time}},
order_by="-timestamp"
model_class=LLMUsage, filters={"timestamp": {"$gte": query_start_time}}, order_by="-timestamp"
)
for record in records:
record_time = record['timestamp']
record_time = record["timestamp"]
# 找到对应的时间间隔索引
time_diff = (record_time - start_time).total_seconds()
@@ -1123,17 +1118,17 @@ class StatisticOutputTask(AsyncTask):
if 0 <= interval_index < len(time_points):
# 累加总花费数据
cost = record.get('cost') or 0.0
cost = record.get("cost") or 0.0
total_cost_data[interval_index] += cost # type: ignore
# 累加按模型分类的花费
model_name = record.get('model_name') or "unknown"
model_name = record.get("model_name") or "unknown"
if model_name not in cost_by_model:
cost_by_model[model_name] = [0] * len(time_points)
cost_by_model[model_name][interval_index] += cost
# 累加按模块分类的花费
request_type = record.get('request_type') or "unknown"
request_type = record.get("request_type") or "unknown"
module_name = request_type.split(".")[0] if "." in request_type else request_type
if module_name not in cost_by_module:
cost_by_module[module_name] = [0] * len(time_points)
@@ -1142,13 +1137,11 @@ class StatisticOutputTask(AsyncTask):
# 查询消息记录
query_start_timestamp = start_time.timestamp()
records = _sync_db_get(
model_class=Messages,
filters={"time": {"$gte": query_start_timestamp}},
order_by="-time"
model_class=Messages, filters={"time": {"$gte": query_start_timestamp}}, order_by="-time"
)
for message in records:
message_time_ts = message['time']
message_time_ts = message["time"]
# 找到对应的时间间隔索引
time_diff = message_time_ts - query_start_timestamp
@@ -1157,10 +1150,10 @@ class StatisticOutputTask(AsyncTask):
if 0 <= interval_index < len(time_points):
# 确定聊天流名称
chat_name = None
if message.get('chat_info_group_id'):
chat_name = message.get('chat_info_group_name') or f"{message['chat_info_group_id']}"
elif message.get('user_id'):
chat_name = message.get('user_nickname') or f"用户{message['user_id']}"
if message.get("chat_info_group_id"):
chat_name = message.get("chat_info_group_name") or f"{message['chat_info_group_id']}"
elif message.get("user_id"):
chat_name = message.get("user_nickname") or f"用户{message['user_id']}"
else:
continue

View File

@@ -73,9 +73,7 @@ class ChineseTypoGenerator:
# 保存到缓存文件
with open(cache_file, "w", encoding="utf-8") as f:
f.write(orjson.dumps(
normalized_freq, option=orjson.OPT_INDENT_2).decode('utf-8')
)
f.write(orjson.dumps(normalized_freq, option=orjson.OPT_INDENT_2).decode("utf-8"))
return normalized_freq

View File

@@ -669,10 +669,10 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
def assign_message_ids(messages: List[Any]) -> List[Dict[str, Any]]:
"""
为消息列表中的每个消息分配唯一的简短随机ID
Args:
messages: 消息列表
Returns:
包含 {'id': str, 'message': any} 格式的字典列表
"""
@@ -685,47 +685,41 @@ def assign_message_ids(messages: List[Any]) -> List[Dict[str, Any]]:
else:
a = 1
b = 9
for i, message in enumerate(messages):
# 生成唯一的简短ID
while True:
# 使用索引+随机数生成简短ID
random_suffix = random.randint(a, b)
message_id = f"m{i+1}{random_suffix}"
message_id = f"m{i + 1}{random_suffix}"
if message_id not in used_ids:
used_ids.add(message_id)
break
result.append({
'id': message_id,
'message': message
})
result.append({"id": message_id, "message": message})
return result
def assign_message_ids_flexible(
messages: list,
prefix: str = "msg",
id_length: int = 6,
use_timestamp: bool = False
messages: list, prefix: str = "msg", id_length: int = 6, use_timestamp: bool = False
) -> list:
"""
为消息列表中的每个消息分配唯一的简短随机ID增强版
Args:
messages: 消息列表
prefix: ID前缀默认为"msg"
id_length: ID的总长度不包括前缀默认为6
use_timestamp: 是否在ID中包含时间戳默认为False
Returns:
包含 {'id': str, 'message': any} 格式的字典列表
"""
result = []
used_ids = set()
for i, message in enumerate(messages):
# 生成唯一的ID
while True:
@@ -733,38 +727,35 @@ def assign_message_ids_flexible(
# 使用时间戳的后几位 + 随机字符
timestamp_suffix = str(int(time.time() * 1000))[-3:]
remaining_length = id_length - 3
random_chars = ''.join(random.choices(string.ascii_lowercase + string.digits, k=remaining_length))
random_chars = "".join(random.choices(string.ascii_lowercase + string.digits, k=remaining_length))
message_id = f"{prefix}{timestamp_suffix}{random_chars}"
else:
# 使用索引 + 随机字符
index_str = str(i + 1)
remaining_length = max(1, id_length - len(index_str))
random_chars = ''.join(random.choices(string.ascii_lowercase + string.digits, k=remaining_length))
random_chars = "".join(random.choices(string.ascii_lowercase + string.digits, k=remaining_length))
message_id = f"{prefix}{index_str}{random_chars}"
if message_id not in used_ids:
used_ids.add(message_id)
break
result.append({
'id': message_id,
'message': message
})
result.append({"id": message_id, "message": message})
return result
# 使用示例:
# messages = ["Hello", "World", "Test message"]
#
#
# # 基础版本
# result1 = assign_message_ids(messages)
# # 结果: [{'id': 'm1123', 'message': 'Hello'}, {'id': 'm2456', 'message': 'World'}, {'id': 'm3789', 'message': 'Test message'}]
#
#
# # 增强版本 - 自定义前缀和长度
# result2 = assign_message_ids_flexible(messages, prefix="chat", id_length=8)
# # 结果: [{'id': 'chat1abc2', 'message': 'Hello'}, {'id': 'chat2def3', 'message': 'World'}, {'id': 'chat3ghi4', 'message': 'Test message'}]
#
#
# # 增强版本 - 使用时间戳
# result3 = assign_message_ids_flexible(messages, prefix="ts", use_timestamp=True)
# # 结果: [{'id': 'ts123a1b', 'message': 'Hello'}, {'id': 'ts123c2d', 'message': 'World'}, {'id': 'ts123e3f', 'message': 'Test message'}]

View File

@@ -18,6 +18,7 @@ from src.llm_models.utils_model import LLMRequest
from src.common.database.sqlalchemy_models import get_db_session
from sqlalchemy import select, and_
install(extra_lines=3)
logger = get_logger("chat_image")
@@ -66,9 +67,14 @@ class ImageManager:
"""
try:
with get_db_session() as session:
record = session.execute(select(ImageDescriptions).where(
and_(ImageDescriptions.image_description_hash == image_hash, ImageDescriptions.type == description_type)
)).scalar()
record = session.execute(
select(ImageDescriptions).where(
and_(
ImageDescriptions.image_description_hash == image_hash,
ImageDescriptions.type == description_type,
)
)
).scalar()
return record.description if record else None
except Exception as e:
logger.error(f"从数据库获取描述失败 (SQLAlchemy): {str(e)}")
@@ -87,9 +93,14 @@ class ImageManager:
current_timestamp = time.time()
with get_db_session() as session:
# 查找现有记录
existing = session.execute(select(ImageDescriptions).where(
and_(ImageDescriptions.image_description_hash == image_hash, ImageDescriptions.type == description_type)
)).scalar()
existing = session.execute(
select(ImageDescriptions).where(
and_(
ImageDescriptions.image_description_hash == image_hash,
ImageDescriptions.type == description_type,
)
)
).scalar()
if existing:
# 更新现有记录
@@ -101,16 +112,17 @@ class ImageManager:
image_description_hash=image_hash,
type=description_type,
description=description,
timestamp=current_timestamp
timestamp=current_timestamp,
)
session.add(new_desc)
session.commit()
# 会在上下文管理器中自动调用
except Exception as e:
logger.error(f"保存描述到数据库失败 (SQLAlchemy): {str(e)}")
async def get_emoji_tag(self, image_base64: str) -> str:
from src.chat.emoji_system.emoji_manager import get_emoji_manager
emoji_manager = get_emoji_manager()
if isinstance(image_base64, str):
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
@@ -135,6 +147,7 @@ class ImageManager:
# 优先使用EmojiManager查询已注册表情包的描述
try:
from src.chat.emoji_system.emoji_manager import get_emoji_manager
emoji_manager = get_emoji_manager()
cached_emoji_description = await emoji_manager.get_emoji_description_by_hash(image_hash)
if cached_emoji_description:
@@ -228,10 +241,11 @@ class ImageManager:
# 保存到数据库 (Images表) - 包含详细描述用于可能的注册流程
try:
from src.common.database.sqlalchemy_models import get_db_session
with get_db_session() as session:
existing_img = session.execute(select(Images).where(
and_(Images.emoji_hash == image_hash, Images.type == "emoji")
)).scalar()
existing_img = session.execute(
select(Images).where(and_(Images.emoji_hash == image_hash, Images.type == "emoji"))
).scalar()
if existing_img:
existing_img.path = file_path
@@ -324,7 +338,7 @@ class ImageManager:
existing_image.image_id = str(uuid.uuid4())
if not hasattr(existing_image, "vlm_processed") or existing_image.vlm_processed is None:
existing_image.vlm_processed = True
logger.debug(f"[数据库] 更新已有图片记录: {image_hash[:8]}...")
else:
new_img = Images(
@@ -338,7 +352,7 @@ class ImageManager:
count=1,
)
session.add(new_img)
logger.debug(f"[数据库] 创建新图片记录: {image_hash[:8]}...")
except Exception as e:
logger.error(f"保存图片文件或元数据失败: {str(e)}")
@@ -381,7 +395,8 @@ class ImageManager:
# 确保是RGB格式方便比较
frame = gif.convert("RGB")
all_frames.append(frame.copy())
except EOFError: ... # 读完啦
except EOFError:
... # 读完啦
if not all_frames:
logger.warning("GIF中没有找到任何帧")
@@ -511,7 +526,7 @@ class ImageManager:
existing_image.vlm_processed = False
existing_image.count += 1
return existing_image.image_id, f"[picid:{existing_image.image_id}]"
# print(f"图片不存在: {image_hash}")
@@ -569,19 +584,23 @@ class ImageManager:
image = session.execute(select(Images).where(Images.image_id == image_id)).scalar()
# 优先检查是否已有其他相同哈希的图片记录包含描述
existing_with_description = session.execute(select(Images).where(
and_(
Images.emoji_hash == image_hash,
Images.description.isnot(None),
Images.description != "",
Images.id != image.id
existing_with_description = session.execute(
select(Images).where(
and_(
Images.emoji_hash == image_hash,
Images.description.isnot(None),
Images.description != "",
Images.id != image.id,
)
)
)).scalar()
).scalar()
if existing_with_description:
logger.debug(f"[缓存复用] 从其他相同图片记录复用描述: {existing_with_description.description[:50]}...")
logger.debug(
f"[缓存复用] 从其他相同图片记录复用描述: {existing_with_description.description[:50]}..."
)
image.description = existing_with_description.description
image.vlm_processed = True
# 同时保存到ImageDescriptions表作为备用缓存
self._save_description_to_db(image_hash, existing_with_description.description, "image")
return
@@ -591,7 +610,7 @@ class ImageManager:
logger.debug(f"[缓存复用] 从ImageDescriptions表复用描述: {cached_description[:50]}...")
image.description = cached_description
image.vlm_processed = True
return
# 获取图片格式

File diff suppressed because it is too large Load Diff

View File

@@ -8,32 +8,30 @@
import os
import cv2
import tempfile
import asyncio
import base64
import hashlib
import time
import numpy as np
from PIL import Image
from pathlib import Path
from typing import List, Tuple, Optional, Dict
from typing import List, Tuple, Optional
import io
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.common.database.sqlalchemy_models import get_db_session, Videos
logger = get_logger("utils_video_legacy")
def _extract_frames_worker(video_path: str,
max_frames: int,
frame_quality: int,
max_image_size: int,
frame_extraction_mode: str,
frame_interval_seconds: Optional[float]) -> List[Tuple[str, float]]:
def _extract_frames_worker(
video_path: str,
max_frames: int,
frame_quality: int,
max_image_size: int,
frame_extraction_mode: str,
frame_interval_seconds: Optional[float],
) -> List[Tuple[str, float]]:
"""线程池中提取视频帧的工作函数"""
frames = []
try:
@@ -41,42 +39,42 @@ def _extract_frames_worker(video_path: str,
fps = cap.get(cv2.CAP_PROP_FPS)
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
duration = total_frames / fps if fps > 0 else 0
if frame_extraction_mode == "time_interval":
# 新模式:按时间间隔抽帧
time_interval = frame_interval_seconds
next_frame_time = 0.0
extracted_count = 0 # 初始化提取帧计数器
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break
current_time = cap.get(cv2.CAP_PROP_POS_MSEC) / 1000.0
if current_time >= next_frame_time:
# 转换为PIL图像并压缩
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
pil_image = Image.fromarray(frame_rgb)
# 调整图像大小
if max(pil_image.size) > max_image_size:
ratio = max_image_size / max(pil_image.size)
new_size = tuple(int(dim * ratio) for dim in pil_image.size)
pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS)
# 转换为base64
buffer = io.BytesIO()
pil_image.save(buffer, format='JPEG', quality=frame_quality)
frame_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
pil_image.save(buffer, format="JPEG", quality=frame_quality)
frame_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
frames.append((frame_base64, current_time))
extracted_count += 1
# 注意这里不能使用logger因为在线程池中
# logger.debug(f"提取第{extracted_count}帧 (时间: {current_time:.2f}s)")
next_frame_time += time_interval
else:
# 使用numpy优化帧间隔计算
@@ -84,49 +82,49 @@ def _extract_frames_worker(video_path: str,
frame_interval = max(1, int(duration / max_frames * fps))
else:
frame_interval = 30 # 默认间隔
# 使用numpy计算目标帧位置
target_frames = np.arange(0, min(max_frames, total_frames // frame_interval + 1)) * frame_interval
target_frames = target_frames[target_frames < total_frames].astype(int)
for target_frame in target_frames:
# 跳转到目标帧
cap.set(cv2.CAP_PROP_POS_FRAMES, target_frame)
ret, frame = cap.read()
if not ret:
continue
# 使用numpy优化图像处理
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
# 转换为PIL图像并使用numpy进行尺寸计算
height, width = frame_rgb.shape[:2]
max_dim = max(height, width)
if max_dim > max_image_size:
# 使用numpy计算缩放比例
ratio = max_image_size / max_dim
new_width = int(width * ratio)
new_height = int(height * ratio)
# 使用opencv进行高效缩放
frame_resized = cv2.resize(frame_rgb, (new_width, new_height), interpolation=cv2.INTER_LANCZOS4)
pil_image = Image.fromarray(frame_resized)
else:
pil_image = Image.fromarray(frame_rgb)
# 转换为base64
buffer = io.BytesIO()
pil_image.save(buffer, format='JPEG', quality=frame_quality)
frame_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
pil_image.save(buffer, format="JPEG", quality=frame_quality)
frame_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
# 计算时间戳
timestamp = target_frame / fps if fps > 0 else 0
frames.append((frame_base64, timestamp))
cap.release()
return frames
except Exception as e:
# 返回错误信息
return [("ERROR", str(e))]
@@ -140,38 +138,39 @@ class LegacyVideoAnalyzer:
# 使用专用的视频分析配置
try:
self.video_llm = LLMRequest(
model_set=model_config.model_task_config.video_analysis,
request_type="video_analysis"
model_set=model_config.model_task_config.video_analysis, request_type="video_analysis"
)
logger.info("✅ 使用video_analysis模型配置")
except (AttributeError, KeyError) as e:
# 如果video_analysis不存在使用vlm配置
self.video_llm = LLMRequest(
model_set=model_config.model_task_config.vlm,
request_type="vlm"
)
self.video_llm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="vlm")
logger.warning(f"video_analysis配置不可用({e})回退使用vlm配置")
# 从配置文件读取参数,如果配置不存在则使用默认值
config = global_config.video_analysis
# 使用 getattr 统一获取配置参数,如果配置不存在则使用默认值
self.max_frames = getattr(config, 'max_frames', 6)
self.frame_quality = getattr(config, 'frame_quality', 85)
self.max_image_size = getattr(config, 'max_image_size', 600)
self.enable_frame_timing = getattr(config, 'enable_frame_timing', True)
self.max_frames = getattr(config, "max_frames", 6)
self.frame_quality = getattr(config, "frame_quality", 85)
self.max_image_size = getattr(config, "max_image_size", 600)
self.enable_frame_timing = getattr(config, "enable_frame_timing", True)
# 从personality配置中获取人格信息
try:
personality_config = global_config.personality
self.personality_core = getattr(personality_config, 'personality_core', "是一个积极向上的女大学生")
self.personality_side = getattr(personality_config, 'personality_side', "用一句话或几句话描述人格的侧面特点")
self.personality_core = getattr(personality_config, "personality_core", "是一个积极向上的女大学生")
self.personality_side = getattr(
personality_config, "personality_side", "用一句话或几句话描述人格的侧面特点"
)
except AttributeError:
# 如果没有personality配置使用默认值
self.personality_core = "是一个积极向上的女大学生"
self.personality_side = "用一句话或几句话描述人格的侧面特点"
self.batch_analysis_prompt = getattr(config, 'batch_analysis_prompt', """请以第一人称的视角来观看这一个视频,你看到的这些是从视频中按时间顺序提取的关键帧。
self.batch_analysis_prompt = getattr(
config,
"batch_analysis_prompt",
"""请以第一人称的视角来观看这一个视频,你看到的这些是从视频中按时间顺序提取的关键帧。
你的核心人设是:{personality_core}
你的人格细节是:{personality_side}
@@ -184,16 +183,17 @@ class LegacyVideoAnalyzer:
5. 整体氛围和情感表达
6. 任何特殊的视觉效果或文字内容
请用中文回答,结果要详细准确。""")
请用中文回答,结果要详细准确。""",
)
# 新增的线程池配置
self.use_multiprocessing = getattr(config, 'use_multiprocessing', True)
self.max_workers = getattr(config, 'max_workers', 2)
self.frame_extraction_mode = getattr(config, 'frame_extraction_mode', 'fixed_number')
self.frame_interval_seconds = getattr(config, 'frame_interval_seconds', 2.0)
self.use_multiprocessing = getattr(config, "use_multiprocessing", True)
self.max_workers = getattr(config, "max_workers", 2)
self.frame_extraction_mode = getattr(config, "frame_extraction_mode", "fixed_number")
self.frame_interval_seconds = getattr(config, "frame_interval_seconds", 2.0)
# 将配置文件中的模式映射到内部使用的模式名称
config_mode = getattr(config, 'analysis_mode', 'auto')
config_mode = getattr(config, "analysis_mode", "auto")
if config_mode == "batch_frames":
self.analysis_mode = "batch"
elif config_mode == "frame_by_frame":
@@ -203,21 +203,23 @@ class LegacyVideoAnalyzer:
else:
logger.warning(f"无效的分析模式: {config_mode}使用默认的auto模式")
self.analysis_mode = "auto"
self.frame_analysis_delay = 0.3 # API调用间隔
self.frame_interval = 1.0 # 抽帧时间间隔(秒)
self.batch_size = 3 # 批处理时每批处理的帧数
self.timeout = 60.0 # 分析超时时间(秒)
if config:
logger.info("✅ 从配置文件读取视频分析参数")
else:
logger.warning("配置文件中缺少video_analysis配置使用默认值")
# 系统提示词
self.system_prompt = "你是一个专业的视频内容分析助手。请仔细观察用户提供的视频关键帧,详细描述视频内容。"
logger.info(f"✅ 旧版本视频分析器初始化完成,分析模式: {self.analysis_mode}, 线程池: {self.use_multiprocessing}")
logger.info(
f"✅ 旧版本视频分析器初始化完成,分析模式: {self.analysis_mode}, 线程池: {self.use_multiprocessing}"
)
async def extract_frames(self, video_path: str) -> List[Tuple[str, float]]:
"""提取视频帧 - 支持多进程和单线程模式"""
@@ -227,18 +229,18 @@ class LegacyVideoAnalyzer:
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
duration = total_frames / fps if fps > 0 else 0
cap.release()
logger.info(f"视频信息: {total_frames}帧, {fps:.2f}FPS, {duration:.2f}")
# 估算提取帧数
if duration > 0:
frame_interval = max(1, int(duration / self.max_frames * fps))
estimated_frames = min(self.max_frames, total_frames // frame_interval + 1)
else:
estimated_frames = self.max_frames
logger.info(f"计算得出帧间隔: {frame_interval} (将提取约{estimated_frames}帧)")
# 根据配置选择处理方式
if self.use_multiprocessing:
return await self._extract_frames_multiprocess(video_path)
@@ -248,7 +250,7 @@ class LegacyVideoAnalyzer:
async def _extract_frames_multiprocess(self, video_path: str) -> List[Tuple[str, float]]:
"""线程池版本的帧提取"""
loop = asyncio.get_event_loop()
try:
logger.info("🔄 启动线程池帧提取...")
# 使用线程池,避免进程间的导入问题
@@ -261,19 +263,19 @@ class LegacyVideoAnalyzer:
self.frame_quality,
self.max_image_size,
self.frame_extraction_mode,
self.frame_interval_seconds
self.frame_interval_seconds,
)
# 检查是否有错误
if frames and frames[0][0] == "ERROR":
logger.error(f"线程池帧提取失败: {frames[0][1]}")
# 降级到单线程模式
logger.info("🔄 降级到单线程模式...")
return await self._extract_frames_fallback(video_path)
logger.info(f"✅ 成功提取{len(frames)}帧 (线程池模式)")
return frames
except Exception as e:
logger.error(f"线程池帧提取失败: {e}")
# 降级到原始方法
@@ -288,43 +290,42 @@ class LegacyVideoAnalyzer:
fps = cap.get(cv2.CAP_PROP_FPS)
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
duration = total_frames / fps if fps > 0 else 0
logger.info(f"视频信息: {total_frames}帧, {fps:.2f}FPS, {duration:.2f}")
if self.frame_extraction_mode == "time_interval":
# 新模式:按时间间隔抽帧
time_interval = self.frame_interval_seconds
next_frame_time = 0.0
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break
current_time = cap.get(cv2.CAP_PROP_POS_MSEC) / 1000.0
if current_time >= next_frame_time:
# 转换为PIL图像并压缩
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
pil_image = Image.fromarray(frame_rgb)
# 调整图像大小
if max(pil_image.size) > self.max_image_size:
ratio = self.max_image_size / max(pil_image.size)
new_size = tuple(int(dim * ratio) for dim in pil_image.size)
pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS)
# 转换为base64
buffer = io.BytesIO()
pil_image.save(buffer, format='JPEG', quality=self.frame_quality)
frame_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
pil_image.save(buffer, format="JPEG", quality=self.frame_quality)
frame_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
frames.append((frame_base64, current_time))
extracted_count += 1
logger.debug(f"提取第{extracted_count}帧 (时间: {current_time:.2f}s)")
next_frame_time += time_interval
else:
# 使用numpy优化帧间隔计算
@@ -332,53 +333,55 @@ class LegacyVideoAnalyzer:
frame_interval = max(1, int(duration / self.max_frames * fps))
else:
frame_interval = 30 # 默认间隔
logger.info(f"计算得出帧间隔: {frame_interval} (将提取约{min(self.max_frames, total_frames // frame_interval + 1)}帧)")
logger.info(
f"计算得出帧间隔: {frame_interval} (将提取约{min(self.max_frames, total_frames // frame_interval + 1)}帧)"
)
# 使用numpy计算目标帧位置
target_frames = np.arange(0, min(self.max_frames, total_frames // frame_interval + 1)) * frame_interval
target_frames = target_frames[target_frames < total_frames].astype(int)
extracted_count = 0
for target_frame in target_frames:
# 跳转到目标帧
cap.set(cv2.CAP_PROP_POS_FRAMES, target_frame)
ret, frame = cap.read()
if not ret:
continue
# 使用numpy优化图像处理
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
# 转换为PIL图像并使用numpy进行尺寸计算
height, width = frame_rgb.shape[:2]
max_dim = max(height, width)
if max_dim > self.max_image_size:
# 使用numpy计算缩放比例
ratio = self.max_image_size / max_dim
new_width = int(width * ratio)
new_height = int(height * ratio)
# 使用opencv进行高效缩放
frame_resized = cv2.resize(frame_rgb, (new_width, new_height), interpolation=cv2.INTER_LANCZOS4)
pil_image = Image.fromarray(frame_resized)
else:
pil_image = Image.fromarray(frame_rgb)
# 转换为base64
buffer = io.BytesIO()
pil_image.save(buffer, format='JPEG', quality=self.frame_quality)
frame_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
pil_image.save(buffer, format="JPEG", quality=self.frame_quality)
frame_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
# 计算时间戳
timestamp = target_frame / fps if fps > 0 else 0
frames.append((frame_base64, timestamp))
extracted_count += 1
logger.debug(f"提取第{extracted_count}帧 (时间: {timestamp:.2f}s, 帧号: {target_frame})")
# 每提取一帧让步一次
await asyncio.sleep(0.001)
@@ -389,48 +392,48 @@ class LegacyVideoAnalyzer:
async def analyze_frames_batch(self, frames: List[Tuple[str, float]], user_question: str = None) -> str:
"""批量分析所有帧"""
logger.info(f"开始批量分析{len(frames)}")
if not frames:
return "❌ 没有可分析的帧"
# 构建提示词并格式化人格信息,要不然占位符的那个会爆炸
prompt = self.batch_analysis_prompt.format(
personality_core=self.personality_core,
personality_side=self.personality_side
personality_core=self.personality_core, personality_side=self.personality_side
)
if user_question:
prompt += f"\n\n用户问题: {user_question}"
# 添加帧信息到提示词
frame_info = []
for i, (_frame_base64, timestamp) in enumerate(frames):
if self.enable_frame_timing:
frame_info.append(f"{i+1}帧 (时间: {timestamp:.2f}s)")
frame_info.append(f"{i + 1}帧 (时间: {timestamp:.2f}s)")
else:
frame_info.append(f"{i+1}")
frame_info.append(f"{i + 1}")
prompt += f"\n\n视频包含{len(frames)}帧图像:{', '.join(frame_info)}"
prompt += "\n\n请基于所有提供的帧图像进行综合分析,关注并描述视频的完整内容和故事发展。"
try:
# 尝试使用多图片分析
response = await self._analyze_multiple_frames(frames, prompt)
logger.info("✅ 视频识别完成")
return response
except Exception as e:
logger.error(f"❌ 视频识别失败: {e}")
# 降级到单帧分析
logger.warning("降级到单帧分析模式")
try:
frame_base64, timestamp = frames[0]
fallback_prompt = prompt + f"\n\n注意由于技术限制当前仅显示第1帧 (时间: {timestamp:.2f}s),视频共有{len(frames)}帧。请基于这一帧进行分析。"
fallback_prompt = (
prompt
+ f"\n\n注意由于技术限制当前仅显示第1帧 (时间: {timestamp:.2f}s),视频共有{len(frames)}帧。请基于这一帧进行分析。"
)
response, _ = await self.video_llm.generate_response_for_image(
prompt=fallback_prompt,
image_base64=frame_base64,
image_format="jpeg"
prompt=fallback_prompt, image_base64=frame_base64, image_format="jpeg"
)
logger.info("✅ 降级的单帧分析完成")
return response
@@ -441,22 +444,22 @@ class LegacyVideoAnalyzer:
async def _analyze_multiple_frames(self, frames: List[Tuple[str, float]], prompt: str) -> str:
"""使用多图片分析方法"""
logger.info(f"开始构建包含{len(frames)}帧的分析请求")
# 导入MessageBuilder用于构建多图片消息
from src.llm_models.payload_content.message import MessageBuilder, RoleType
from src.llm_models.utils_model import RequestType
# 构建包含多张图片的消息
message_builder = MessageBuilder().set_role(RoleType.User).add_text_content(prompt)
# 添加所有帧图像
for _i, (frame_base64, _timestamp) in enumerate(frames):
message_builder.add_image_content("jpeg", frame_base64)
# logger.info(f"已添加第{i+1}帧到分析请求 (时间: {timestamp:.2f}s, 图片大小: {len(frame_base64)} chars)")
message = message_builder.build()
# logger.info(f"✅ 多帧消息构建完成,包含{len(frames)}张图片")
# 获取模型信息和客户端
model_info, api_provider, client = self.video_llm._select_model()
# logger.info(f"使用模型: {model_info.name} 进行多帧分析")
@@ -469,45 +472,43 @@ class LegacyVideoAnalyzer:
model_info=model_info,
message_list=[message],
temperature=None,
max_tokens=None
max_tokens=None,
)
logger.info(f"视频识别完成,响应长度: {len(api_response.content or '')} ")
return api_response.content or "❌ 未获得响应内容"
async def analyze_frames_sequential(self, frames: List[Tuple[str, float]], user_question: str = None) -> str:
"""逐帧分析并汇总"""
logger.info(f"开始逐帧分析{len(frames)}")
frame_analyses = []
for i, (frame_base64, timestamp) in enumerate(frames):
try:
prompt = f"请分析这个视频的第{i+1}"
prompt = f"请分析这个视频的第{i + 1}"
if self.enable_frame_timing:
prompt += f" (时间: {timestamp:.2f}s)"
prompt += "。描述你看到的内容,包括人物、动作、场景、文字等。"
if user_question:
prompt += f"\n特别关注: {user_question}"
response, _ = await self.video_llm.generate_response_for_image(
prompt=prompt,
image_base64=frame_base64,
image_format="jpeg"
prompt=prompt, image_base64=frame_base64, image_format="jpeg"
)
frame_analyses.append(f"{i+1}帧 ({timestamp:.2f}s): {response}")
logger.debug(f"✅ 第{i+1}帧分析完成")
frame_analyses.append(f"{i + 1}帧 ({timestamp:.2f}s): {response}")
logger.debug(f"✅ 第{i + 1}帧分析完成")
# API调用间隔
if i < len(frames) - 1:
await asyncio.sleep(self.frame_analysis_delay)
except Exception as e:
logger.error(f"❌ 第{i+1}帧分析失败: {e}")
frame_analyses.append(f"{i+1}帧: 分析失败 - {e}")
logger.error(f"❌ 第{i + 1}帧分析失败: {e}")
frame_analyses.append(f"{i + 1}帧: 分析失败 - {e}")
# 生成汇总
logger.info("开始生成汇总分析")
summary_prompt = f"""基于以下各帧的分析结果,请提供一个完整的视频内容总结:
@@ -518,15 +519,13 @@ class LegacyVideoAnalyzer:
if user_question:
summary_prompt += f"\n特别回答用户的问题: {user_question}"
try:
# 使用最后一帧进行汇总分析
if frames:
last_frame_base64, _ = frames[-1]
summary, _ = await self.video_llm.generate_response_for_image(
prompt=summary_prompt,
image_base64=last_frame_base64,
image_format="jpeg"
prompt=summary_prompt, image_base64=last_frame_base64, image_format="jpeg"
)
logger.info("✅ 逐帧分析和汇总完成")
return summary
@@ -541,12 +540,12 @@ class LegacyVideoAnalyzer:
"""分析视频的主要方法"""
try:
logger.info(f"开始分析视频: {os.path.basename(video_path)}")
# 提取帧
frames = await self.extract_frames(video_path)
if not frames:
return "❌ 无法从视频中提取有效帧"
# 根据模式选择分析方法
if self.analysis_mode == "auto":
# 智能选择少于等于3帧用批量否则用逐帧
@@ -554,16 +553,16 @@ class LegacyVideoAnalyzer:
logger.info(f"自动选择分析模式: {mode} (基于{len(frames)}帧)")
else:
mode = self.analysis_mode
# 执行分析
if mode == "batch":
result = await self.analyze_frames_batch(frames, user_question)
else: # sequential
result = await self.analyze_frames_sequential(frames, user_question)
logger.info("✅ 视频分析完成")
return result
except Exception as e:
error_msg = f"❌ 视频分析失败: {str(e)}"
logger.error(error_msg)
@@ -571,16 +570,17 @@ class LegacyVideoAnalyzer:
def is_supported_video(self, file_path: str) -> bool:
"""检查是否为支持的视频格式"""
supported_formats = {'.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.m4v', '.3gp', '.webm'}
supported_formats = {".mp4", ".avi", ".mov", ".mkv", ".flv", ".wmv", ".m4v", ".3gp", ".webm"}
return Path(file_path).suffix.lower() in supported_formats
# 全局实例
_legacy_video_analyzer = None
def get_legacy_video_analyzer() -> LegacyVideoAnalyzer:
"""获取旧版本视频分析器实例(单例模式)"""
global _legacy_video_analyzer
if _legacy_video_analyzer is None:
_legacy_video_analyzer = LegacyVideoAnalyzer()
return _legacy_video_analyzer
return _legacy_video_analyzer