修复代码格式和文件名大小写问题
This commit is contained in:
@@ -260,89 +260,105 @@ def get_actions_by_timestamp_with_chat(
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录,按时间升序排序,返回动作记录列表"""
|
||||
from src.common.logger import get_logger
|
||||
|
||||
|
||||
logger = get_logger("chat_message_builder")
|
||||
|
||||
|
||||
# 记录函数调用参数
|
||||
logger.debug(f"[get_actions_by_timestamp_with_chat] 调用参数: chat_id={chat_id}, "
|
||||
f"timestamp_start={timestamp_start}, timestamp_end={timestamp_end}, "
|
||||
f"limit={limit}, limit_mode={limit_mode}")
|
||||
|
||||
logger.debug(
|
||||
f"[get_actions_by_timestamp_with_chat] 调用参数: chat_id={chat_id}, "
|
||||
f"timestamp_start={timestamp_start}, timestamp_end={timestamp_end}, "
|
||||
f"limit={limit}, limit_mode={limit_mode}"
|
||||
)
|
||||
|
||||
with get_db_session() as session:
|
||||
if limit > 0:
|
||||
if limit_mode == "latest":
|
||||
query = session.execute(select(ActionRecords).where(
|
||||
and_(
|
||||
ActionRecords.chat_id == chat_id,
|
||||
ActionRecords.time > timestamp_start,
|
||||
ActionRecords.time < timestamp_end
|
||||
query = session.execute(
|
||||
select(ActionRecords)
|
||||
.where(
|
||||
and_(
|
||||
ActionRecords.chat_id == chat_id,
|
||||
ActionRecords.time > timestamp_start,
|
||||
ActionRecords.time < timestamp_end,
|
||||
)
|
||||
)
|
||||
).order_by(ActionRecords.time.desc()).limit(limit))
|
||||
.order_by(ActionRecords.time.desc())
|
||||
.limit(limit)
|
||||
)
|
||||
actions = list(query.scalars())
|
||||
actions_result = []
|
||||
for action in reversed(actions):
|
||||
action_dict = {
|
||||
'id': action.id,
|
||||
'action_id': action.action_id,
|
||||
'time': action.time,
|
||||
'action_name': action.action_name,
|
||||
'action_data': action.action_data,
|
||||
'action_done': action.action_done,
|
||||
'action_build_into_prompt': action.action_build_into_prompt,
|
||||
'action_prompt_display': action.action_prompt_display,
|
||||
'chat_id': action.chat_id,
|
||||
'chat_info_stream_id': action.chat_info_stream_id,
|
||||
'chat_info_platform': action.chat_info_platform,
|
||||
"id": action.id,
|
||||
"action_id": action.action_id,
|
||||
"time": action.time,
|
||||
"action_name": action.action_name,
|
||||
"action_data": action.action_data,
|
||||
"action_done": action.action_done,
|
||||
"action_build_into_prompt": action.action_build_into_prompt,
|
||||
"action_prompt_display": action.action_prompt_display,
|
||||
"chat_id": action.chat_id,
|
||||
"chat_info_stream_id": action.chat_info_stream_id,
|
||||
"chat_info_platform": action.chat_info_platform,
|
||||
}
|
||||
actions_result.append(action_dict)
|
||||
else: # earliest
|
||||
query = session.execute(select(ActionRecords).where(
|
||||
and_(
|
||||
ActionRecords.chat_id == chat_id,
|
||||
ActionRecords.time > timestamp_start,
|
||||
ActionRecords.time < timestamp_end
|
||||
query = session.execute(
|
||||
select(ActionRecords)
|
||||
.where(
|
||||
and_(
|
||||
ActionRecords.chat_id == chat_id,
|
||||
ActionRecords.time > timestamp_start,
|
||||
ActionRecords.time < timestamp_end,
|
||||
)
|
||||
)
|
||||
).order_by(ActionRecords.time.asc()).limit(limit))
|
||||
.order_by(ActionRecords.time.asc())
|
||||
.limit(limit)
|
||||
)
|
||||
actions = list(query.scalars())
|
||||
actions_result = []
|
||||
for action in actions:
|
||||
action_dict = {
|
||||
'id': action.id,
|
||||
'action_id': action.action_id,
|
||||
'time': action.time,
|
||||
'action_name': action.action_name,
|
||||
'action_data': action.action_data,
|
||||
'action_done': action.action_done,
|
||||
'action_build_into_prompt': action.action_build_into_prompt,
|
||||
'action_prompt_display': action.action_prompt_display,
|
||||
'chat_id': action.chat_id,
|
||||
'chat_info_stream_id': action.chat_info_stream_id,
|
||||
'chat_info_platform': action.chat_info_platform,
|
||||
"id": action.id,
|
||||
"action_id": action.action_id,
|
||||
"time": action.time,
|
||||
"action_name": action.action_name,
|
||||
"action_data": action.action_data,
|
||||
"action_done": action.action_done,
|
||||
"action_build_into_prompt": action.action_build_into_prompt,
|
||||
"action_prompt_display": action.action_prompt_display,
|
||||
"chat_id": action.chat_id,
|
||||
"chat_info_stream_id": action.chat_info_stream_id,
|
||||
"chat_info_platform": action.chat_info_platform,
|
||||
}
|
||||
actions_result.append(action_dict)
|
||||
else:
|
||||
query = session.execute(select(ActionRecords).where(
|
||||
and_(
|
||||
ActionRecords.chat_id == chat_id,
|
||||
ActionRecords.time > timestamp_start,
|
||||
ActionRecords.time < timestamp_end
|
||||
query = session.execute(
|
||||
select(ActionRecords)
|
||||
.where(
|
||||
and_(
|
||||
ActionRecords.chat_id == chat_id,
|
||||
ActionRecords.time > timestamp_start,
|
||||
ActionRecords.time < timestamp_end,
|
||||
)
|
||||
)
|
||||
).order_by(ActionRecords.time.asc()))
|
||||
.order_by(ActionRecords.time.asc())
|
||||
)
|
||||
actions = list(query.scalars())
|
||||
actions_result = []
|
||||
for action in actions:
|
||||
action_dict = {
|
||||
'id': action.id,
|
||||
'action_id': action.action_id,
|
||||
'time': action.time,
|
||||
'action_name': action.action_name,
|
||||
'action_data': action.action_data,
|
||||
'action_done': action.action_done,
|
||||
'action_build_into_prompt': action.action_build_into_prompt,
|
||||
'action_prompt_display': action.action_prompt_display,
|
||||
'chat_id': action.chat_id,
|
||||
'chat_info_stream_id': action.chat_info_stream_id,
|
||||
'chat_info_platform': action.chat_info_platform,
|
||||
"id": action.id,
|
||||
"action_id": action.action_id,
|
||||
"time": action.time,
|
||||
"action_name": action.action_name,
|
||||
"action_data": action.action_data,
|
||||
"action_done": action.action_done,
|
||||
"action_build_into_prompt": action.action_build_into_prompt,
|
||||
"action_prompt_display": action.action_prompt_display,
|
||||
"chat_id": action.chat_id,
|
||||
"chat_info_stream_id": action.chat_info_stream_id,
|
||||
"chat_info_platform": action.chat_info_platform,
|
||||
}
|
||||
actions_result.append(action_dict)
|
||||
return actions_result
|
||||
@@ -355,31 +371,45 @@ def get_actions_by_timestamp_with_chat_inclusive(
|
||||
with get_db_session() as session:
|
||||
if limit > 0:
|
||||
if limit_mode == "latest":
|
||||
query = session.execute(select(ActionRecords).where(
|
||||
and_(
|
||||
ActionRecords.chat_id == chat_id,
|
||||
ActionRecords.time >= timestamp_start,
|
||||
ActionRecords.time <= timestamp_end
|
||||
query = session.execute(
|
||||
select(ActionRecords)
|
||||
.where(
|
||||
and_(
|
||||
ActionRecords.chat_id == chat_id,
|
||||
ActionRecords.time >= timestamp_start,
|
||||
ActionRecords.time <= timestamp_end,
|
||||
)
|
||||
)
|
||||
).order_by(ActionRecords.time.desc()).limit(limit))
|
||||
.order_by(ActionRecords.time.desc())
|
||||
.limit(limit)
|
||||
)
|
||||
actions = list(query.scalars())
|
||||
return [action.__dict__ for action in reversed(actions)]
|
||||
else: # earliest
|
||||
query = session.execute(select(ActionRecords).where(
|
||||
query = session.execute(
|
||||
select(ActionRecords)
|
||||
.where(
|
||||
and_(
|
||||
ActionRecords.chat_id == chat_id,
|
||||
ActionRecords.time >= timestamp_start,
|
||||
ActionRecords.time <= timestamp_end,
|
||||
)
|
||||
)
|
||||
.order_by(ActionRecords.time.asc())
|
||||
.limit(limit)
|
||||
)
|
||||
else:
|
||||
query = session.execute(
|
||||
select(ActionRecords)
|
||||
.where(
|
||||
and_(
|
||||
ActionRecords.chat_id == chat_id,
|
||||
ActionRecords.time >= timestamp_start,
|
||||
ActionRecords.time <= timestamp_end
|
||||
ActionRecords.time <= timestamp_end,
|
||||
)
|
||||
).order_by(ActionRecords.time.asc()).limit(limit))
|
||||
else:
|
||||
query = session.execute(select(ActionRecords).where(
|
||||
and_(
|
||||
ActionRecords.chat_id == chat_id,
|
||||
ActionRecords.time >= timestamp_start,
|
||||
ActionRecords.time <= timestamp_end
|
||||
)
|
||||
).order_by(ActionRecords.time.asc()))
|
||||
.order_by(ActionRecords.time.asc())
|
||||
)
|
||||
|
||||
actions = list(query.scalars())
|
||||
return [action.__dict__ for action in actions]
|
||||
@@ -782,7 +812,6 @@ def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
|
||||
# 按图片编号排序
|
||||
sorted_items = sorted(pic_id_mapping.items(), key=lambda x: int(x[1].replace("图片", "")))
|
||||
|
||||
|
||||
for pic_id, display_name in sorted_items:
|
||||
# 从数据库中获取图片描述
|
||||
description = "内容正在阅读,请稍等"
|
||||
@@ -791,7 +820,8 @@ def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
|
||||
image = session.execute(select(Images).where(Images.image_id == pic_id)).scalar()
|
||||
if image and image.description:
|
||||
description = image.description
|
||||
except Exception: ...
|
||||
except Exception:
|
||||
...
|
||||
# 如果查询失败,保持默认描述
|
||||
|
||||
mapping_lines.append(f"[{display_name}] 的内容:{description}")
|
||||
@@ -811,17 +841,18 @@ def build_readable_actions(actions: List[Dict[str, Any]]) -> str:
|
||||
格式化的动作字符串。
|
||||
"""
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("chat_message_builder")
|
||||
|
||||
|
||||
logger.debug(f"[build_readable_actions] 开始处理 {len(actions) if actions else 0} 条动作记录")
|
||||
|
||||
|
||||
if not actions:
|
||||
logger.debug("[build_readable_actions] 动作记录为空,返回空字符串")
|
||||
return ""
|
||||
|
||||
output_lines = []
|
||||
current_time = time.time()
|
||||
|
||||
|
||||
logger.debug(f"[build_readable_actions] 当前时间戳: {current_time}")
|
||||
|
||||
# The get functions return actions sorted ascending by time. Let's reverse it to show newest first.
|
||||
@@ -830,12 +861,12 @@ def build_readable_actions(actions: List[Dict[str, Any]]) -> str:
|
||||
for i, action in enumerate(actions):
|
||||
logger.debug(f"[build_readable_actions] === 处理第 {i} 条动作记录 ===")
|
||||
logger.debug(f"[build_readable_actions] 原始动作数据: {action}")
|
||||
|
||||
|
||||
action_time = action.get("time", current_time)
|
||||
action_name = action.get("action_name", "未知动作")
|
||||
|
||||
|
||||
logger.debug(f"[build_readable_actions] 动作时间戳: {action_time}, 动作名称: '{action_name}'")
|
||||
|
||||
|
||||
# 检查是否是原始的 action_name 值
|
||||
original_action_name = action.get("action_name")
|
||||
if original_action_name is None:
|
||||
@@ -844,7 +875,7 @@ def build_readable_actions(actions: List[Dict[str, Any]]) -> str:
|
||||
logger.error(f"[build_readable_actions] 动作 #{i}: action_name 为空字符串!")
|
||||
elif original_action_name == "未知动作":
|
||||
logger.error(f"[build_readable_actions] 动作 #{i}: action_name 已经是'未知动作'!")
|
||||
|
||||
|
||||
if action_name in ["no_action", "no_reply"]:
|
||||
logger.debug(f"[build_readable_actions] 跳过动作 #{i}: {action_name} (在跳过列表中)")
|
||||
continue
|
||||
@@ -863,7 +894,7 @@ def build_readable_actions(actions: List[Dict[str, Any]]) -> str:
|
||||
|
||||
logger.debug(f"[build_readable_actions] 时间描述: '{time_ago_str}'")
|
||||
|
||||
line = f"{time_ago_str},你使用了\"{action_name}\",具体内容是:\"{action_prompt_display}\""
|
||||
line = f'{time_ago_str},你使用了"{action_name}",具体内容是:"{action_prompt_display}"'
|
||||
logger.debug(f"[build_readable_actions] 生成的行: '{line}'")
|
||||
output_lines.append(line)
|
||||
|
||||
@@ -964,23 +995,26 @@ def build_readable_messages(
|
||||
chat_id = copy_messages[0].get("chat_id") if copy_messages else None
|
||||
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
|
||||
with get_db_session() as session:
|
||||
# 获取这个时间范围内的动作记录,并匹配chat_id
|
||||
actions_in_range = session.execute(select(ActionRecords).where(
|
||||
and_(
|
||||
ActionRecords.time >= min_time,
|
||||
ActionRecords.time <= max_time,
|
||||
ActionRecords.chat_id == chat_id
|
||||
actions_in_range = session.execute(
|
||||
select(ActionRecords)
|
||||
.where(
|
||||
and_(
|
||||
ActionRecords.time >= min_time, ActionRecords.time <= max_time, ActionRecords.chat_id == chat_id
|
||||
)
|
||||
)
|
||||
).order_by(ActionRecords.time)).scalars()
|
||||
.order_by(ActionRecords.time)
|
||||
).scalars()
|
||||
|
||||
# 获取最新消息之后的第一个动作记录
|
||||
action_after_latest = session.execute(select(ActionRecords).where(
|
||||
and_(
|
||||
ActionRecords.time > max_time,
|
||||
ActionRecords.chat_id == chat_id
|
||||
)
|
||||
).order_by(ActionRecords.time).limit(1)).scalars()
|
||||
action_after_latest = session.execute(
|
||||
select(ActionRecords)
|
||||
.where(and_(ActionRecords.time > max_time, ActionRecords.chat_id == chat_id))
|
||||
.order_by(ActionRecords.time)
|
||||
.limit(1)
|
||||
).scalars()
|
||||
|
||||
# 合并两部分动作记录,并转为 dict,避免 DetachedInstanceError
|
||||
actions = [
|
||||
|
||||
@@ -12,6 +12,7 @@ install(extra_lines=3)
|
||||
|
||||
logger = get_logger("prompt_build")
|
||||
|
||||
|
||||
class PromptContext:
|
||||
def __init__(self):
|
||||
self._context_prompts: Dict[str, Dict[str, "Prompt"]] = {}
|
||||
@@ -27,7 +28,7 @@ class PromptContext:
|
||||
@_current_context.setter
|
||||
def _current_context(self, value: Optional[str]):
|
||||
"""设置当前协程的上下文ID"""
|
||||
self._current_context_var.set(value) # type: ignore
|
||||
self._current_context_var.set(value) # type: ignore
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_scope(self, context_id: Optional[str] = None):
|
||||
@@ -51,7 +52,7 @@ class PromptContext:
|
||||
# 保存当前协程的上下文值,不影响其他协程
|
||||
previous_context = self._current_context
|
||||
# 设置当前协程的新上下文
|
||||
token = self._current_context_var.set(context_id) if context_id else None # type: ignore
|
||||
token = self._current_context_var.set(context_id) if context_id else None # type: ignore
|
||||
else:
|
||||
# 如果没有提供新上下文,保持当前上下文不变
|
||||
previous_context = self._current_context
|
||||
@@ -69,7 +70,8 @@ class PromptContext:
|
||||
# 如果reset失败,尝试直接设置
|
||||
try:
|
||||
self._current_context = previous_context
|
||||
except Exception: ...
|
||||
except Exception:
|
||||
...
|
||||
# 静默忽略恢复失败
|
||||
|
||||
async def get_prompt_async(self, name: str) -> Optional["Prompt"]:
|
||||
@@ -174,7 +176,9 @@ class Prompt(str):
|
||||
"""将临时标记还原为实际的花括号字符"""
|
||||
return template.replace(Prompt._TEMP_LEFT_BRACE, "{").replace(Prompt._TEMP_RIGHT_BRACE, "}")
|
||||
|
||||
def __new__(cls, fstr, name: Optional[str] = None, args: Optional[Union[List[Any], tuple[Any, ...]]] = None, **kwargs):
|
||||
def __new__(
|
||||
cls, fstr, name: Optional[str] = None, args: Optional[Union[List[Any], tuple[Any, ...]]] = None, **kwargs
|
||||
):
|
||||
# 如果传入的是元组,转换为列表
|
||||
if isinstance(args, tuple):
|
||||
args = list(args)
|
||||
@@ -219,7 +223,9 @@ class Prompt(str):
|
||||
return prompt
|
||||
|
||||
@classmethod
|
||||
def _format_template(cls, template, args: Optional[List[Any]] = None, kwargs: Optional[Dict[str, Any]] = None) -> str:
|
||||
def _format_template(
|
||||
cls, template, args: Optional[List[Any]] = None, kwargs: Optional[Dict[str, Any]] = None
|
||||
) -> str:
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
# 预处理模板中的转义花括号
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
智能提示词参数模块 - 优化参数结构
|
||||
简化SmartPromptParameters,减少冗余和重复
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Any, Optional, List, Literal
|
||||
|
||||
@@ -9,6 +10,7 @@ from typing import Dict, Any, Optional, List, Literal
|
||||
@dataclass
|
||||
class SmartPromptParameters:
|
||||
"""简化的智能提示词参数系统"""
|
||||
|
||||
# 基础参数
|
||||
chat_id: str = ""
|
||||
is_group_chat: bool = False
|
||||
@@ -17,7 +19,7 @@ class SmartPromptParameters:
|
||||
reply_to: str = ""
|
||||
extra_info: str = ""
|
||||
prompt_mode: Literal["s4u", "normal", "minimal"] = "s4u"
|
||||
|
||||
|
||||
# 功能开关
|
||||
enable_tool: bool = True
|
||||
enable_memory: bool = True
|
||||
@@ -25,20 +27,20 @@ class SmartPromptParameters:
|
||||
enable_relation: bool = True
|
||||
enable_cross_context: bool = True
|
||||
enable_knowledge: bool = True
|
||||
|
||||
|
||||
# 性能控制
|
||||
max_context_messages: int = 50
|
||||
|
||||
|
||||
# 调试选项
|
||||
debug_mode: bool = False
|
||||
|
||||
|
||||
# 聊天历史和上下文
|
||||
chat_target_info: Optional[Dict[str, Any]] = None
|
||||
message_list_before_now_long: List[Dict[str, Any]] = field(default_factory=list)
|
||||
message_list_before_short: List[Dict[str, Any]] = field(default_factory=list)
|
||||
chat_talking_prompt_short: str = ""
|
||||
target_user_info: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
# 已构建的内容块
|
||||
expression_habits_block: str = ""
|
||||
relation_info_block: str = ""
|
||||
@@ -46,7 +48,7 @@ class SmartPromptParameters:
|
||||
tool_info_block: str = ""
|
||||
knowledge_prompt: str = ""
|
||||
cross_context_block: str = ""
|
||||
|
||||
|
||||
# 其他内容块
|
||||
keywords_reaction_prompt: str = ""
|
||||
extra_info_block: str = ""
|
||||
@@ -57,7 +59,10 @@ class SmartPromptParameters:
|
||||
reply_target_block: str = ""
|
||||
mood_prompt: str = ""
|
||||
action_descriptions: str = ""
|
||||
|
||||
|
||||
# 可用动作信息
|
||||
available_actions: Optional[Dict[str, Any]] = None
|
||||
|
||||
def validate(self) -> List[str]:
|
||||
"""统一的参数验证"""
|
||||
errors = []
|
||||
@@ -68,39 +73,39 @@ class SmartPromptParameters:
|
||||
if self.max_context_messages <= 0:
|
||||
errors.append("max_context_messages必须大于0")
|
||||
return errors
|
||||
|
||||
|
||||
def get_needed_build_tasks(self) -> List[str]:
|
||||
"""获取需要执行的任务列表"""
|
||||
tasks = []
|
||||
|
||||
|
||||
if self.enable_expression and not self.expression_habits_block:
|
||||
tasks.append("expression_habits")
|
||||
|
||||
|
||||
if self.enable_memory and not self.memory_block:
|
||||
tasks.append("memory_block")
|
||||
|
||||
|
||||
if self.enable_relation and not self.relation_info_block:
|
||||
tasks.append("relation_info")
|
||||
|
||||
|
||||
if self.enable_tool and not self.tool_info_block:
|
||||
tasks.append("tool_info")
|
||||
|
||||
|
||||
if self.enable_knowledge and not self.knowledge_prompt:
|
||||
tasks.append("knowledge_info")
|
||||
|
||||
|
||||
if self.enable_cross_context and not self.cross_context_block:
|
||||
tasks.append("cross_context")
|
||||
|
||||
|
||||
return tasks
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_legacy_params(cls, **kwargs) -> 'SmartPromptParameters':
|
||||
def from_legacy_params(cls, **kwargs) -> "SmartPromptParameters":
|
||||
"""
|
||||
从旧版参数创建新参数对象
|
||||
|
||||
|
||||
Args:
|
||||
**kwargs: 旧版参数
|
||||
|
||||
|
||||
Returns:
|
||||
SmartPromptParameters: 新参数对象
|
||||
"""
|
||||
@@ -113,7 +118,6 @@ class SmartPromptParameters:
|
||||
reply_to=kwargs.get("reply_to", ""),
|
||||
extra_info=kwargs.get("extra_info", ""),
|
||||
prompt_mode=kwargs.get("current_prompt_mode", "s4u"),
|
||||
|
||||
# 功能开关
|
||||
enable_tool=kwargs.get("enable_tool", True),
|
||||
enable_memory=kwargs.get("enable_memory", True),
|
||||
@@ -121,18 +125,15 @@ class SmartPromptParameters:
|
||||
enable_relation=kwargs.get("enable_relation", True),
|
||||
enable_cross_context=kwargs.get("enable_cross_context", True),
|
||||
enable_knowledge=kwargs.get("enable_knowledge", True),
|
||||
|
||||
# 性能控制
|
||||
max_context_messages=kwargs.get("max_context_messages", 50),
|
||||
debug_mode=kwargs.get("debug_mode", False),
|
||||
|
||||
# 聊天历史和上下文
|
||||
chat_target_info=kwargs.get("chat_target_info"),
|
||||
message_list_before_now_long=kwargs.get("message_list_before_now_long", []),
|
||||
message_list_before_short=kwargs.get("message_list_before_short", []),
|
||||
chat_talking_prompt_short=kwargs.get("chat_talking_prompt_short", ""),
|
||||
target_user_info=kwargs.get("target_user_info"),
|
||||
|
||||
# 已构建的内容块
|
||||
expression_habits_block=kwargs.get("expression_habits_block", ""),
|
||||
relation_info_block=kwargs.get("relation_info", ""),
|
||||
@@ -140,7 +141,6 @@ class SmartPromptParameters:
|
||||
tool_info_block=kwargs.get("tool_info", ""),
|
||||
knowledge_prompt=kwargs.get("knowledge_prompt", ""),
|
||||
cross_context_block=kwargs.get("cross_context_block", ""),
|
||||
|
||||
# 其他内容块
|
||||
keywords_reaction_prompt=kwargs.get("keywords_reaction_prompt", ""),
|
||||
extra_info_block=kwargs.get("extra_info_block", ""),
|
||||
@@ -151,4 +151,6 @@ class SmartPromptParameters:
|
||||
reply_target_block=kwargs.get("reply_target_block", ""),
|
||||
mood_prompt=kwargs.get("mood_prompt", ""),
|
||||
action_descriptions=kwargs.get("action_descriptions", ""),
|
||||
)
|
||||
# 可用动作信息
|
||||
available_actions=kwargs.get("available_actions", None),
|
||||
)
|
||||
|
||||
@@ -2,16 +2,14 @@
|
||||
共享提示词工具模块 - 消除重复代码
|
||||
提供统一的工具函数供DefaultReplyer和SmartPrompt使用
|
||||
"""
|
||||
|
||||
import re
|
||||
import time
|
||||
import asyncio
|
||||
from typing import Dict, Any, List, Optional, Tuple, Union
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, Optional, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
build_readable_messages,
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
build_readable_messages_with_id,
|
||||
)
|
||||
@@ -23,25 +21,25 @@ logger = get_logger("prompt_utils")
|
||||
|
||||
class PromptUtils:
|
||||
"""提示词工具类 - 提供共享功能,移除缓存相关功能和依赖检查"""
|
||||
|
||||
|
||||
@staticmethod
|
||||
def parse_reply_target(target_message: str) -> Tuple[str, str]:
|
||||
"""
|
||||
解析回复目标消息 - 统一实现
|
||||
|
||||
|
||||
Args:
|
||||
target_message: 目标消息,格式为 "发送者:消息内容" 或 "发送者:消息内容"
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple[str, str]: (发送者名称, 消息内容)
|
||||
"""
|
||||
sender = ""
|
||||
target = ""
|
||||
|
||||
|
||||
# 添加None检查,防止NoneType错误
|
||||
if target_message is None:
|
||||
return sender, target
|
||||
|
||||
|
||||
if ":" in target_message or ":" in target_message:
|
||||
# 使用正则表达式匹配中文或英文冒号
|
||||
parts = re.split(pattern=r"[::]", string=target_message, maxsplit=1)
|
||||
@@ -49,16 +47,16 @@ class PromptUtils:
|
||||
sender = parts[0].strip()
|
||||
target = parts[1].strip()
|
||||
return sender, target
|
||||
|
||||
|
||||
@staticmethod
|
||||
async def build_relation_info(chat_id: str, reply_to: str) -> str:
|
||||
"""
|
||||
构建关系信息 - 统一实现
|
||||
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
reply_to: 回复目标字符串
|
||||
|
||||
|
||||
Returns:
|
||||
str: 关系信息字符串
|
||||
"""
|
||||
@@ -66,8 +64,9 @@ class PromptUtils:
|
||||
return ""
|
||||
|
||||
from src.person_info.relationship_fetcher import relationship_fetcher_manager
|
||||
|
||||
relationship_fetcher = relationship_fetcher_manager.get_fetcher(chat_id)
|
||||
|
||||
|
||||
if not reply_to:
|
||||
return ""
|
||||
sender, text = PromptUtils.parse_reply_target(reply_to)
|
||||
@@ -82,21 +81,19 @@ class PromptUtils:
|
||||
return f"你完全不认识{sender},不理解ta的相关信息。"
|
||||
|
||||
return await relationship_fetcher.build_relation_info(person_id, points_num=5)
|
||||
|
||||
|
||||
@staticmethod
|
||||
async def build_cross_context(
|
||||
chat_id: str,
|
||||
target_user_info: Optional[Dict[str, Any]],
|
||||
current_prompt_mode: str
|
||||
chat_id: str, target_user_info: Optional[Dict[str, Any]], current_prompt_mode: str
|
||||
) -> str:
|
||||
"""
|
||||
构建跨群聊上下文 - 统一实现,完全继承DefaultReplyer功能
|
||||
|
||||
|
||||
Args:
|
||||
chat_id: 当前聊天ID
|
||||
target_user_info: 目标用户信息
|
||||
current_prompt_mode: 当前提示模式
|
||||
|
||||
|
||||
Returns:
|
||||
str: 跨群上下文块
|
||||
"""
|
||||
@@ -108,7 +105,7 @@ class PromptUtils:
|
||||
current_stream = get_chat_manager().get_stream(chat_id)
|
||||
if not current_stream or not current_stream.group_info:
|
||||
return ""
|
||||
|
||||
|
||||
try:
|
||||
current_chat_raw_id = current_stream.group_info.group_id
|
||||
except Exception as e:
|
||||
@@ -144,7 +141,7 @@ class PromptUtils:
|
||||
if messages:
|
||||
chat_name = get_chat_manager().get_stream_name(stream_id) or stream_id
|
||||
formatted_messages, _ = build_readable_messages_with_id(messages, timestamp_mode="relative")
|
||||
cross_context_messages.append(f"[以下是来自\"{chat_name}\"的近期消息]\n{formatted_messages}")
|
||||
cross_context_messages.append(f'[以下是来自"{chat_name}"的近期消息]\n{formatted_messages}')
|
||||
except Exception as e:
|
||||
logger.error(f"获取群聊{chat_raw_id}的消息失败: {e}")
|
||||
continue
|
||||
@@ -175,14 +172,15 @@ class PromptUtils:
|
||||
if user_messages:
|
||||
chat_name = get_chat_manager().get_stream_name(stream_id) or stream_id
|
||||
user_name = (
|
||||
target_user_info.get("person_name") or
|
||||
target_user_info.get("user_nickname") or user_id
|
||||
target_user_info.get("person_name")
|
||||
or target_user_info.get("user_nickname")
|
||||
or user_id
|
||||
)
|
||||
formatted_messages, _ = build_readable_messages_with_id(
|
||||
user_messages, timestamp_mode="relative"
|
||||
)
|
||||
cross_context_messages.append(
|
||||
f"[以下是\"{user_name}\"在\"{chat_name}\"的近期发言]\n{formatted_messages}"
|
||||
f'[以下是"{user_name}"在"{chat_name}"的近期发言]\n{formatted_messages}'
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户{user_id}在群聊{chat_raw_id}的消息失败: {e}")
|
||||
@@ -192,31 +190,31 @@ class PromptUtils:
|
||||
return ""
|
||||
|
||||
return "# 跨群上下文参考\n" + "\n\n".join(cross_context_messages) + "\n"
|
||||
|
||||
|
||||
@staticmethod
|
||||
def parse_reply_target_id(reply_to: str) -> str:
|
||||
"""
|
||||
解析回复目标中的用户ID
|
||||
|
||||
|
||||
Args:
|
||||
reply_to: 回复目标字符串
|
||||
|
||||
|
||||
Returns:
|
||||
str: 用户ID
|
||||
"""
|
||||
if not reply_to:
|
||||
return ""
|
||||
|
||||
|
||||
# 复用parse_reply_target方法的逻辑
|
||||
sender, _ = PromptUtils.parse_reply_target(reply_to)
|
||||
if not sender:
|
||||
return ""
|
||||
|
||||
|
||||
# 获取用户ID
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_id = person_info_manager.get_person_id_by_person_name(sender)
|
||||
if person_id:
|
||||
user_id = person_info_manager.get_value_sync(person_id, "user_id")
|
||||
return str(user_id) if user_id else ""
|
||||
|
||||
return ""
|
||||
|
||||
return ""
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -13,45 +13,45 @@ from src.manager.local_store_manager import local_storage
|
||||
|
||||
logger = get_logger("maibot_statistic")
|
||||
|
||||
|
||||
# 同步包装器函数,用于在非异步环境中调用异步数据库API
|
||||
def _sync_db_get(model_class, filters=None, order_by=None, limit=None, single_result=False):
|
||||
"""同步版本的db_get,用于在线程池中调用"""
|
||||
import asyncio
|
||||
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
# 如果事件循环正在运行,创建新的事件循环
|
||||
import threading
|
||||
|
||||
result = None
|
||||
exception = None
|
||||
|
||||
|
||||
def run_in_thread():
|
||||
nonlocal result, exception
|
||||
try:
|
||||
new_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(new_loop)
|
||||
result = new_loop.run_until_complete(
|
||||
db_get(model_class, filters, limit, order_by, single_result)
|
||||
)
|
||||
result = new_loop.run_until_complete(db_get(model_class, filters, limit, order_by, single_result))
|
||||
new_loop.close()
|
||||
except Exception as e:
|
||||
exception = e
|
||||
|
||||
|
||||
thread = threading.Thread(target=run_in_thread)
|
||||
thread.start()
|
||||
thread.join()
|
||||
|
||||
|
||||
if exception:
|
||||
raise exception
|
||||
return result
|
||||
else:
|
||||
return loop.run_until_complete(
|
||||
db_get(model_class, filters, limit, order_by, single_result)
|
||||
)
|
||||
return loop.run_until_complete(db_get(model_class, filters, limit, order_by, single_result))
|
||||
except RuntimeError:
|
||||
# 没有事件循环,创建一个新的
|
||||
return asyncio.run(db_get(model_class, filters, limit, order_by, single_result))
|
||||
|
||||
|
||||
# 统计数据的键
|
||||
TOTAL_REQ_CNT = "total_requests"
|
||||
TOTAL_COST = "total_cost"
|
||||
@@ -112,7 +112,7 @@ class OnlineTimeRecordTask(AsyncTask):
|
||||
model_class=OnlineTime,
|
||||
query_type="update",
|
||||
filters={"id": self.record_id},
|
||||
data={"end_timestamp": extended_end_time}
|
||||
data={"end_timestamp": extended_end_time},
|
||||
)
|
||||
if updated_rows == 0:
|
||||
# Record might have been deleted or ID is stale, try to find/create
|
||||
@@ -126,17 +126,17 @@ class OnlineTimeRecordTask(AsyncTask):
|
||||
filters={"end_timestamp": {"$gte": recent_threshold}},
|
||||
order_by="-end_timestamp",
|
||||
limit=1,
|
||||
single_result=True
|
||||
single_result=True,
|
||||
)
|
||||
|
||||
|
||||
if recent_records:
|
||||
# 找到近期记录,更新它
|
||||
self.record_id = recent_records['id']
|
||||
self.record_id = recent_records["id"]
|
||||
await db_query(
|
||||
model_class=OnlineTime,
|
||||
query_type="update",
|
||||
filters={"id": self.record_id},
|
||||
data={"end_timestamp": extended_end_time}
|
||||
data={"end_timestamp": extended_end_time},
|
||||
)
|
||||
else:
|
||||
# 创建新记录
|
||||
@@ -147,10 +147,10 @@ class OnlineTimeRecordTask(AsyncTask):
|
||||
"duration": 5, # 初始时长为5分钟
|
||||
"start_timestamp": current_time,
|
||||
"end_timestamp": extended_end_time,
|
||||
}
|
||||
},
|
||||
)
|
||||
if new_record:
|
||||
self.record_id = new_record['id']
|
||||
self.record_id = new_record["id"]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"在线时间记录失败,错误信息:{e}")
|
||||
@@ -368,20 +368,19 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
# 以最早的时间戳为起始时间获取记录
|
||||
query_start_time = collect_period[-1][1]
|
||||
records = _sync_db_get(
|
||||
model_class=LLMUsage,
|
||||
filters={"timestamp": {"$gte": query_start_time}},
|
||||
order_by="-timestamp"
|
||||
) or []
|
||||
|
||||
records = (
|
||||
_sync_db_get(model_class=LLMUsage, filters={"timestamp": {"$gte": query_start_time}}, order_by="-timestamp")
|
||||
or []
|
||||
)
|
||||
|
||||
for record in records:
|
||||
if not isinstance(record, dict):
|
||||
continue
|
||||
|
||||
record_timestamp = record.get('timestamp')
|
||||
|
||||
record_timestamp = record.get("timestamp")
|
||||
if isinstance(record_timestamp, str):
|
||||
record_timestamp = datetime.fromisoformat(record_timestamp)
|
||||
|
||||
|
||||
if not record_timestamp:
|
||||
continue
|
||||
|
||||
@@ -390,9 +389,9 @@ class StatisticOutputTask(AsyncTask):
|
||||
for period_key, _ in collect_period[idx:]:
|
||||
stats[period_key][TOTAL_REQ_CNT] += 1
|
||||
|
||||
request_type = record.get('request_type') or "unknown"
|
||||
user_id = record.get('user_id') or "unknown"
|
||||
model_name = record.get('model_name') or "unknown"
|
||||
request_type = record.get("request_type") or "unknown"
|
||||
user_id = record.get("user_id") or "unknown"
|
||||
model_name = record.get("model_name") or "unknown"
|
||||
|
||||
# 提取模块名:如果请求类型包含".",取第一个"."之前的部分
|
||||
module_name = request_type.split(".")[0] if "." in request_type else request_type
|
||||
@@ -402,8 +401,8 @@ class StatisticOutputTask(AsyncTask):
|
||||
stats[period_key][REQ_CNT_BY_MODEL][model_name] += 1
|
||||
stats[period_key][REQ_CNT_BY_MODULE][module_name] += 1
|
||||
|
||||
prompt_tokens = record.get('prompt_tokens') or 0
|
||||
completion_tokens = record.get('completion_tokens') or 0
|
||||
prompt_tokens = record.get("prompt_tokens") or 0
|
||||
completion_tokens = record.get("completion_tokens") or 0
|
||||
total_tokens = prompt_tokens + completion_tokens
|
||||
|
||||
stats[period_key][IN_TOK_BY_TYPE][request_type] += prompt_tokens
|
||||
@@ -421,40 +420,40 @@ class StatisticOutputTask(AsyncTask):
|
||||
stats[period_key][TOTAL_TOK_BY_MODEL][model_name] += total_tokens
|
||||
stats[period_key][TOTAL_TOK_BY_MODULE][module_name] += total_tokens
|
||||
|
||||
cost = record.get('cost') or 0.0
|
||||
cost = record.get("cost") or 0.0
|
||||
stats[period_key][TOTAL_COST] += cost
|
||||
stats[period_key][COST_BY_TYPE][request_type] += cost
|
||||
stats[period_key][COST_BY_USER][user_id] += cost
|
||||
stats[period_key][COST_BY_MODEL][model_name] += cost
|
||||
stats[period_key][COST_BY_MODULE][module_name] += cost
|
||||
|
||||
|
||||
# 收集time_cost数据
|
||||
time_cost = record.get('time_cost') or 0.0
|
||||
time_cost = record.get("time_cost") or 0.0
|
||||
if time_cost > 0: # 只记录有效的time_cost
|
||||
stats[period_key][TIME_COST_BY_TYPE][request_type].append(time_cost)
|
||||
stats[period_key][TIME_COST_BY_USER][user_id].append(time_cost)
|
||||
stats[period_key][TIME_COST_BY_MODEL][model_name].append(time_cost)
|
||||
stats[period_key][TIME_COST_BY_MODULE][module_name].append(time_cost)
|
||||
break
|
||||
|
||||
# 计算平均耗时和标准差
|
||||
|
||||
# 计算平均耗时和标准差
|
||||
for period_key in stats:
|
||||
for category in [REQ_CNT_BY_TYPE, REQ_CNT_BY_USER, REQ_CNT_BY_MODEL, REQ_CNT_BY_MODULE]:
|
||||
time_cost_key = f"time_costs_by_{category.split('_')[-1]}"
|
||||
avg_key = f"avg_time_costs_by_{category.split('_')[-1]}"
|
||||
std_key = f"std_time_costs_by_{category.split('_')[-1]}"
|
||||
|
||||
|
||||
for item_name in stats[period_key][category]:
|
||||
time_costs = stats[period_key][time_cost_key].get(item_name, [])
|
||||
if time_costs:
|
||||
# 计算平均耗时
|
||||
avg_time_cost = sum(time_costs) / len(time_costs)
|
||||
stats[period_key][avg_key][item_name] = round(avg_time_cost, 3)
|
||||
|
||||
|
||||
# 计算标准差
|
||||
if len(time_costs) > 1:
|
||||
variance = sum((x - avg_time_cost) ** 2 for x in time_costs) / len(time_costs)
|
||||
std_time_cost = variance ** 0.5
|
||||
std_time_cost = variance**0.5
|
||||
stats[period_key][std_key][item_name] = round(std_time_cost, 3)
|
||||
else:
|
||||
stats[period_key][std_key][item_name] = 0.0
|
||||
@@ -483,21 +482,22 @@ class StatisticOutputTask(AsyncTask):
|
||||
}
|
||||
|
||||
query_start_time = collect_period[-1][1]
|
||||
records = _sync_db_get(
|
||||
model_class=OnlineTime,
|
||||
filters={"end_timestamp": {"$gte": query_start_time}},
|
||||
order_by="-end_timestamp"
|
||||
) or []
|
||||
|
||||
records = (
|
||||
_sync_db_get(
|
||||
model_class=OnlineTime, filters={"end_timestamp": {"$gte": query_start_time}}, order_by="-end_timestamp"
|
||||
)
|
||||
or []
|
||||
)
|
||||
|
||||
for record in records:
|
||||
if not isinstance(record, dict):
|
||||
continue
|
||||
|
||||
record_end_timestamp = record.get('end_timestamp')
|
||||
record_end_timestamp = record.get("end_timestamp")
|
||||
if isinstance(record_end_timestamp, str):
|
||||
record_end_timestamp = datetime.fromisoformat(record_end_timestamp)
|
||||
|
||||
record_start_timestamp = record.get('start_timestamp')
|
||||
record_start_timestamp = record.get("start_timestamp")
|
||||
if isinstance(record_start_timestamp, str):
|
||||
record_start_timestamp = datetime.fromisoformat(record_start_timestamp)
|
||||
|
||||
@@ -539,16 +539,15 @@ class StatisticOutputTask(AsyncTask):
|
||||
}
|
||||
|
||||
query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp)
|
||||
records = _sync_db_get(
|
||||
model_class=Messages,
|
||||
filters={"time": {"$gte": query_start_timestamp}},
|
||||
order_by="-time"
|
||||
) or []
|
||||
|
||||
records = (
|
||||
_sync_db_get(model_class=Messages, filters={"time": {"$gte": query_start_timestamp}}, order_by="-time")
|
||||
or []
|
||||
)
|
||||
|
||||
for message in records:
|
||||
if not isinstance(message, dict):
|
||||
continue
|
||||
message_time_ts = message.get('time') # This is a float timestamp
|
||||
message_time_ts = message.get("time") # This is a float timestamp
|
||||
|
||||
if not message_time_ts:
|
||||
continue
|
||||
@@ -557,18 +556,16 @@ class StatisticOutputTask(AsyncTask):
|
||||
chat_name = None
|
||||
|
||||
# Logic based on SQLAlchemy model structure, aiming to replicate original intent
|
||||
if message.get('chat_info_group_id'):
|
||||
if message.get("chat_info_group_id"):
|
||||
chat_id = f"g{message['chat_info_group_id']}"
|
||||
chat_name = message.get('chat_info_group_name') or f"群{message['chat_info_group_id']}"
|
||||
elif message.get('user_id'): # Fallback to sender's info for chat_id if not a group_info based chat
|
||||
chat_name = message.get("chat_info_group_name") or f"群{message['chat_info_group_id']}"
|
||||
elif message.get("user_id"): # Fallback to sender's info for chat_id if not a group_info based chat
|
||||
# This uses the message SENDER's ID as per original logic's fallback
|
||||
chat_id = f"u{message['user_id']}" # SENDER's user_id
|
||||
chat_name = message.get('user_nickname') # SENDER's nickname
|
||||
chat_name = message.get("user_nickname") # SENDER's nickname
|
||||
else:
|
||||
# If neither group_id nor sender_id is available for chat identification
|
||||
logger.warning(
|
||||
f"Message (PK: {message.get('id', 'N/A')}) lacks group_id and user_id for chat stats."
|
||||
)
|
||||
logger.warning(f"Message (PK: {message.get('id', 'N/A')}) lacks group_id and user_id for chat stats.")
|
||||
continue
|
||||
|
||||
if not chat_id: # Should not happen if above logic is correct
|
||||
@@ -589,8 +586,6 @@ class StatisticOutputTask(AsyncTask):
|
||||
break
|
||||
return stats
|
||||
|
||||
|
||||
|
||||
def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]:
|
||||
"""
|
||||
收集各时间段的统计数据
|
||||
@@ -721,7 +716,9 @@ class StatisticOutputTask(AsyncTask):
|
||||
cost = stats[COST_BY_MODEL][model_name]
|
||||
avg_time_cost = stats[AVG_TIME_COST_BY_MODEL][model_name]
|
||||
std_time_cost = stats[STD_TIME_COST_BY_MODEL][model_name]
|
||||
output.append(data_fmt.format(name, count, in_tokens, out_tokens, tokens, cost, avg_time_cost, std_time_cost))
|
||||
output.append(
|
||||
data_fmt.format(name, count, in_tokens, out_tokens, tokens, cost, avg_time_cost, std_time_cost)
|
||||
)
|
||||
|
||||
output.append("")
|
||||
return "\n".join(output)
|
||||
@@ -1109,13 +1106,11 @@ class StatisticOutputTask(AsyncTask):
|
||||
# 查询LLM使用记录
|
||||
query_start_time = start_time
|
||||
records = _sync_db_get(
|
||||
model_class=LLMUsage,
|
||||
filters={"timestamp": {"$gte": query_start_time}},
|
||||
order_by="-timestamp"
|
||||
model_class=LLMUsage, filters={"timestamp": {"$gte": query_start_time}}, order_by="-timestamp"
|
||||
)
|
||||
|
||||
|
||||
for record in records:
|
||||
record_time = record['timestamp']
|
||||
record_time = record["timestamp"]
|
||||
|
||||
# 找到对应的时间间隔索引
|
||||
time_diff = (record_time - start_time).total_seconds()
|
||||
@@ -1123,17 +1118,17 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
if 0 <= interval_index < len(time_points):
|
||||
# 累加总花费数据
|
||||
cost = record.get('cost') or 0.0
|
||||
cost = record.get("cost") or 0.0
|
||||
total_cost_data[interval_index] += cost # type: ignore
|
||||
|
||||
# 累加按模型分类的花费
|
||||
model_name = record.get('model_name') or "unknown"
|
||||
model_name = record.get("model_name") or "unknown"
|
||||
if model_name not in cost_by_model:
|
||||
cost_by_model[model_name] = [0] * len(time_points)
|
||||
cost_by_model[model_name][interval_index] += cost
|
||||
|
||||
# 累加按模块分类的花费
|
||||
request_type = record.get('request_type') or "unknown"
|
||||
request_type = record.get("request_type") or "unknown"
|
||||
module_name = request_type.split(".")[0] if "." in request_type else request_type
|
||||
if module_name not in cost_by_module:
|
||||
cost_by_module[module_name] = [0] * len(time_points)
|
||||
@@ -1142,13 +1137,11 @@ class StatisticOutputTask(AsyncTask):
|
||||
# 查询消息记录
|
||||
query_start_timestamp = start_time.timestamp()
|
||||
records = _sync_db_get(
|
||||
model_class=Messages,
|
||||
filters={"time": {"$gte": query_start_timestamp}},
|
||||
order_by="-time"
|
||||
model_class=Messages, filters={"time": {"$gte": query_start_timestamp}}, order_by="-time"
|
||||
)
|
||||
|
||||
|
||||
for message in records:
|
||||
message_time_ts = message['time']
|
||||
message_time_ts = message["time"]
|
||||
|
||||
# 找到对应的时间间隔索引
|
||||
time_diff = message_time_ts - query_start_timestamp
|
||||
@@ -1157,10 +1150,10 @@ class StatisticOutputTask(AsyncTask):
|
||||
if 0 <= interval_index < len(time_points):
|
||||
# 确定聊天流名称
|
||||
chat_name = None
|
||||
if message.get('chat_info_group_id'):
|
||||
chat_name = message.get('chat_info_group_name') or f"群{message['chat_info_group_id']}"
|
||||
elif message.get('user_id'):
|
||||
chat_name = message.get('user_nickname') or f"用户{message['user_id']}"
|
||||
if message.get("chat_info_group_id"):
|
||||
chat_name = message.get("chat_info_group_name") or f"群{message['chat_info_group_id']}"
|
||||
elif message.get("user_id"):
|
||||
chat_name = message.get("user_nickname") or f"用户{message['user_id']}"
|
||||
else:
|
||||
continue
|
||||
|
||||
|
||||
@@ -73,9 +73,7 @@ class ChineseTypoGenerator:
|
||||
|
||||
# 保存到缓存文件
|
||||
with open(cache_file, "w", encoding="utf-8") as f:
|
||||
f.write(orjson.dumps(
|
||||
normalized_freq, option=orjson.OPT_INDENT_2).decode('utf-8')
|
||||
)
|
||||
f.write(orjson.dumps(normalized_freq, option=orjson.OPT_INDENT_2).decode("utf-8"))
|
||||
|
||||
return normalized_freq
|
||||
|
||||
|
||||
@@ -669,10 +669,10 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
|
||||
def assign_message_ids(messages: List[Any]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
为消息列表中的每个消息分配唯一的简短随机ID
|
||||
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
|
||||
|
||||
Returns:
|
||||
包含 {'id': str, 'message': any} 格式的字典列表
|
||||
"""
|
||||
@@ -685,47 +685,41 @@ def assign_message_ids(messages: List[Any]) -> List[Dict[str, Any]]:
|
||||
else:
|
||||
a = 1
|
||||
b = 9
|
||||
|
||||
|
||||
for i, message in enumerate(messages):
|
||||
# 生成唯一的简短ID
|
||||
while True:
|
||||
# 使用索引+随机数生成简短ID
|
||||
random_suffix = random.randint(a, b)
|
||||
message_id = f"m{i+1}{random_suffix}"
|
||||
|
||||
message_id = f"m{i + 1}{random_suffix}"
|
||||
|
||||
if message_id not in used_ids:
|
||||
used_ids.add(message_id)
|
||||
break
|
||||
|
||||
result.append({
|
||||
'id': message_id,
|
||||
'message': message
|
||||
})
|
||||
|
||||
|
||||
result.append({"id": message_id, "message": message})
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def assign_message_ids_flexible(
|
||||
messages: list,
|
||||
prefix: str = "msg",
|
||||
id_length: int = 6,
|
||||
use_timestamp: bool = False
|
||||
messages: list, prefix: str = "msg", id_length: int = 6, use_timestamp: bool = False
|
||||
) -> list:
|
||||
"""
|
||||
为消息列表中的每个消息分配唯一的简短随机ID(增强版)
|
||||
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
prefix: ID前缀,默认为"msg"
|
||||
id_length: ID的总长度(不包括前缀),默认为6
|
||||
use_timestamp: 是否在ID中包含时间戳,默认为False
|
||||
|
||||
|
||||
Returns:
|
||||
包含 {'id': str, 'message': any} 格式的字典列表
|
||||
"""
|
||||
result = []
|
||||
used_ids = set()
|
||||
|
||||
|
||||
for i, message in enumerate(messages):
|
||||
# 生成唯一的ID
|
||||
while True:
|
||||
@@ -733,38 +727,35 @@ def assign_message_ids_flexible(
|
||||
# 使用时间戳的后几位 + 随机字符
|
||||
timestamp_suffix = str(int(time.time() * 1000))[-3:]
|
||||
remaining_length = id_length - 3
|
||||
random_chars = ''.join(random.choices(string.ascii_lowercase + string.digits, k=remaining_length))
|
||||
random_chars = "".join(random.choices(string.ascii_lowercase + string.digits, k=remaining_length))
|
||||
message_id = f"{prefix}{timestamp_suffix}{random_chars}"
|
||||
else:
|
||||
# 使用索引 + 随机字符
|
||||
index_str = str(i + 1)
|
||||
remaining_length = max(1, id_length - len(index_str))
|
||||
random_chars = ''.join(random.choices(string.ascii_lowercase + string.digits, k=remaining_length))
|
||||
random_chars = "".join(random.choices(string.ascii_lowercase + string.digits, k=remaining_length))
|
||||
message_id = f"{prefix}{index_str}{random_chars}"
|
||||
|
||||
|
||||
if message_id not in used_ids:
|
||||
used_ids.add(message_id)
|
||||
break
|
||||
|
||||
result.append({
|
||||
'id': message_id,
|
||||
'message': message
|
||||
})
|
||||
|
||||
|
||||
result.append({"id": message_id, "message": message})
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# 使用示例:
|
||||
# messages = ["Hello", "World", "Test message"]
|
||||
#
|
||||
#
|
||||
# # 基础版本
|
||||
# result1 = assign_message_ids(messages)
|
||||
# # 结果: [{'id': 'm1123', 'message': 'Hello'}, {'id': 'm2456', 'message': 'World'}, {'id': 'm3789', 'message': 'Test message'}]
|
||||
#
|
||||
#
|
||||
# # 增强版本 - 自定义前缀和长度
|
||||
# result2 = assign_message_ids_flexible(messages, prefix="chat", id_length=8)
|
||||
# # 结果: [{'id': 'chat1abc2', 'message': 'Hello'}, {'id': 'chat2def3', 'message': 'World'}, {'id': 'chat3ghi4', 'message': 'Test message'}]
|
||||
#
|
||||
#
|
||||
# # 增强版本 - 使用时间戳
|
||||
# result3 = assign_message_ids_flexible(messages, prefix="ts", use_timestamp=True)
|
||||
# # 结果: [{'id': 'ts123a1b', 'message': 'Hello'}, {'id': 'ts123c2d', 'message': 'World'}, {'id': 'ts123e3f', 'message': 'Test message'}]
|
||||
|
||||
@@ -18,6 +18,7 @@ from src.llm_models.utils_model import LLMRequest
|
||||
from src.common.database.sqlalchemy_models import get_db_session
|
||||
|
||||
from sqlalchemy import select, and_
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("chat_image")
|
||||
@@ -66,9 +67,14 @@ class ImageManager:
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
record = session.execute(select(ImageDescriptions).where(
|
||||
and_(ImageDescriptions.image_description_hash == image_hash, ImageDescriptions.type == description_type)
|
||||
)).scalar()
|
||||
record = session.execute(
|
||||
select(ImageDescriptions).where(
|
||||
and_(
|
||||
ImageDescriptions.image_description_hash == image_hash,
|
||||
ImageDescriptions.type == description_type,
|
||||
)
|
||||
)
|
||||
).scalar()
|
||||
return record.description if record else None
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库获取描述失败 (SQLAlchemy): {str(e)}")
|
||||
@@ -87,9 +93,14 @@ class ImageManager:
|
||||
current_timestamp = time.time()
|
||||
with get_db_session() as session:
|
||||
# 查找现有记录
|
||||
existing = session.execute(select(ImageDescriptions).where(
|
||||
and_(ImageDescriptions.image_description_hash == image_hash, ImageDescriptions.type == description_type)
|
||||
)).scalar()
|
||||
existing = session.execute(
|
||||
select(ImageDescriptions).where(
|
||||
and_(
|
||||
ImageDescriptions.image_description_hash == image_hash,
|
||||
ImageDescriptions.type == description_type,
|
||||
)
|
||||
)
|
||||
).scalar()
|
||||
|
||||
if existing:
|
||||
# 更新现有记录
|
||||
@@ -101,16 +112,17 @@ class ImageManager:
|
||||
image_description_hash=image_hash,
|
||||
type=description_type,
|
||||
description=description,
|
||||
timestamp=current_timestamp
|
||||
timestamp=current_timestamp,
|
||||
)
|
||||
session.add(new_desc)
|
||||
session.commit()
|
||||
# 会在上下文管理器中自动调用
|
||||
except Exception as e:
|
||||
logger.error(f"保存描述到数据库失败 (SQLAlchemy): {str(e)}")
|
||||
|
||||
|
||||
async def get_emoji_tag(self, image_base64: str) -> str:
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
|
||||
emoji_manager = get_emoji_manager()
|
||||
if isinstance(image_base64, str):
|
||||
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
||||
@@ -135,6 +147,7 @@ class ImageManager:
|
||||
# 优先使用EmojiManager查询已注册表情包的描述
|
||||
try:
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
|
||||
emoji_manager = get_emoji_manager()
|
||||
cached_emoji_description = await emoji_manager.get_emoji_description_by_hash(image_hash)
|
||||
if cached_emoji_description:
|
||||
@@ -228,10 +241,11 @@ class ImageManager:
|
||||
# 保存到数据库 (Images表) - 包含详细描述用于可能的注册流程
|
||||
try:
|
||||
from src.common.database.sqlalchemy_models import get_db_session
|
||||
|
||||
with get_db_session() as session:
|
||||
existing_img = session.execute(select(Images).where(
|
||||
and_(Images.emoji_hash == image_hash, Images.type == "emoji")
|
||||
)).scalar()
|
||||
existing_img = session.execute(
|
||||
select(Images).where(and_(Images.emoji_hash == image_hash, Images.type == "emoji"))
|
||||
).scalar()
|
||||
|
||||
if existing_img:
|
||||
existing_img.path = file_path
|
||||
@@ -324,7 +338,7 @@ class ImageManager:
|
||||
existing_image.image_id = str(uuid.uuid4())
|
||||
if not hasattr(existing_image, "vlm_processed") or existing_image.vlm_processed is None:
|
||||
existing_image.vlm_processed = True
|
||||
|
||||
|
||||
logger.debug(f"[数据库] 更新已有图片记录: {image_hash[:8]}...")
|
||||
else:
|
||||
new_img = Images(
|
||||
@@ -338,7 +352,7 @@ class ImageManager:
|
||||
count=1,
|
||||
)
|
||||
session.add(new_img)
|
||||
|
||||
|
||||
logger.debug(f"[数据库] 创建新图片记录: {image_hash[:8]}...")
|
||||
except Exception as e:
|
||||
logger.error(f"保存图片文件或元数据失败: {str(e)}")
|
||||
@@ -381,7 +395,8 @@ class ImageManager:
|
||||
# 确保是RGB格式方便比较
|
||||
frame = gif.convert("RGB")
|
||||
all_frames.append(frame.copy())
|
||||
except EOFError: ... # 读完啦
|
||||
except EOFError:
|
||||
... # 读完啦
|
||||
|
||||
if not all_frames:
|
||||
logger.warning("GIF中没有找到任何帧")
|
||||
@@ -511,7 +526,7 @@ class ImageManager:
|
||||
existing_image.vlm_processed = False
|
||||
|
||||
existing_image.count += 1
|
||||
|
||||
|
||||
return existing_image.image_id, f"[picid:{existing_image.image_id}]"
|
||||
|
||||
# print(f"图片不存在: {image_hash}")
|
||||
@@ -569,19 +584,23 @@ class ImageManager:
|
||||
image = session.execute(select(Images).where(Images.image_id == image_id)).scalar()
|
||||
|
||||
# 优先检查是否已有其他相同哈希的图片记录包含描述
|
||||
existing_with_description = session.execute(select(Images).where(
|
||||
and_(
|
||||
Images.emoji_hash == image_hash,
|
||||
Images.description.isnot(None),
|
||||
Images.description != "",
|
||||
Images.id != image.id
|
||||
existing_with_description = session.execute(
|
||||
select(Images).where(
|
||||
and_(
|
||||
Images.emoji_hash == image_hash,
|
||||
Images.description.isnot(None),
|
||||
Images.description != "",
|
||||
Images.id != image.id,
|
||||
)
|
||||
)
|
||||
)).scalar()
|
||||
).scalar()
|
||||
if existing_with_description:
|
||||
logger.debug(f"[缓存复用] 从其他相同图片记录复用描述: {existing_with_description.description[:50]}...")
|
||||
logger.debug(
|
||||
f"[缓存复用] 从其他相同图片记录复用描述: {existing_with_description.description[:50]}..."
|
||||
)
|
||||
image.description = existing_with_description.description
|
||||
image.vlm_processed = True
|
||||
|
||||
|
||||
# 同时保存到ImageDescriptions表作为备用缓存
|
||||
self._save_description_to_db(image_hash, existing_with_description.description, "image")
|
||||
return
|
||||
@@ -591,7 +610,7 @@ class ImageManager:
|
||||
logger.debug(f"[缓存复用] 从ImageDescriptions表复用描述: {cached_description[:50]}...")
|
||||
image.description = cached_description
|
||||
image.vlm_processed = True
|
||||
|
||||
|
||||
return
|
||||
|
||||
# 获取图片格式
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -8,32 +8,30 @@
|
||||
|
||||
import os
|
||||
import cv2
|
||||
import tempfile
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import time
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, Optional, Dict
|
||||
from typing import List, Tuple, Optional
|
||||
import io
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.sqlalchemy_models import get_db_session, Videos
|
||||
|
||||
logger = get_logger("utils_video_legacy")
|
||||
|
||||
def _extract_frames_worker(video_path: str,
|
||||
max_frames: int,
|
||||
frame_quality: int,
|
||||
max_image_size: int,
|
||||
frame_extraction_mode: str,
|
||||
frame_interval_seconds: Optional[float]) -> List[Tuple[str, float]]:
|
||||
|
||||
def _extract_frames_worker(
|
||||
video_path: str,
|
||||
max_frames: int,
|
||||
frame_quality: int,
|
||||
max_image_size: int,
|
||||
frame_extraction_mode: str,
|
||||
frame_interval_seconds: Optional[float],
|
||||
) -> List[Tuple[str, float]]:
|
||||
"""线程池中提取视频帧的工作函数"""
|
||||
frames = []
|
||||
try:
|
||||
@@ -41,42 +39,42 @@ def _extract_frames_worker(video_path: str,
|
||||
fps = cap.get(cv2.CAP_PROP_FPS)
|
||||
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
duration = total_frames / fps if fps > 0 else 0
|
||||
|
||||
|
||||
if frame_extraction_mode == "time_interval":
|
||||
# 新模式:按时间间隔抽帧
|
||||
time_interval = frame_interval_seconds
|
||||
next_frame_time = 0.0
|
||||
extracted_count = 0 # 初始化提取帧计数器
|
||||
|
||||
|
||||
while cap.isOpened():
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
|
||||
|
||||
current_time = cap.get(cv2.CAP_PROP_POS_MSEC) / 1000.0
|
||||
|
||||
|
||||
if current_time >= next_frame_time:
|
||||
# 转换为PIL图像并压缩
|
||||
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
pil_image = Image.fromarray(frame_rgb)
|
||||
|
||||
|
||||
# 调整图像大小
|
||||
if max(pil_image.size) > max_image_size:
|
||||
ratio = max_image_size / max(pil_image.size)
|
||||
new_size = tuple(int(dim * ratio) for dim in pil_image.size)
|
||||
pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS)
|
||||
|
||||
|
||||
# 转换为base64
|
||||
buffer = io.BytesIO()
|
||||
pil_image.save(buffer, format='JPEG', quality=frame_quality)
|
||||
frame_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
|
||||
|
||||
pil_image.save(buffer, format="JPEG", quality=frame_quality)
|
||||
frame_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
frames.append((frame_base64, current_time))
|
||||
extracted_count += 1
|
||||
|
||||
|
||||
# 注意:这里不能使用logger,因为在线程池中
|
||||
# logger.debug(f"提取第{extracted_count}帧 (时间: {current_time:.2f}s)")
|
||||
|
||||
|
||||
next_frame_time += time_interval
|
||||
else:
|
||||
# 使用numpy优化帧间隔计算
|
||||
@@ -84,49 +82,49 @@ def _extract_frames_worker(video_path: str,
|
||||
frame_interval = max(1, int(duration / max_frames * fps))
|
||||
else:
|
||||
frame_interval = 30 # 默认间隔
|
||||
|
||||
|
||||
# 使用numpy计算目标帧位置
|
||||
target_frames = np.arange(0, min(max_frames, total_frames // frame_interval + 1)) * frame_interval
|
||||
target_frames = target_frames[target_frames < total_frames].astype(int)
|
||||
|
||||
|
||||
for target_frame in target_frames:
|
||||
# 跳转到目标帧
|
||||
cap.set(cv2.CAP_PROP_POS_FRAMES, target_frame)
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
continue
|
||||
|
||||
|
||||
# 使用numpy优化图像处理
|
||||
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
|
||||
|
||||
# 转换为PIL图像并使用numpy进行尺寸计算
|
||||
height, width = frame_rgb.shape[:2]
|
||||
max_dim = max(height, width)
|
||||
|
||||
|
||||
if max_dim > max_image_size:
|
||||
# 使用numpy计算缩放比例
|
||||
ratio = max_image_size / max_dim
|
||||
new_width = int(width * ratio)
|
||||
new_height = int(height * ratio)
|
||||
|
||||
|
||||
# 使用opencv进行高效缩放
|
||||
frame_resized = cv2.resize(frame_rgb, (new_width, new_height), interpolation=cv2.INTER_LANCZOS4)
|
||||
pil_image = Image.fromarray(frame_resized)
|
||||
else:
|
||||
pil_image = Image.fromarray(frame_rgb)
|
||||
|
||||
|
||||
# 转换为base64
|
||||
buffer = io.BytesIO()
|
||||
pil_image.save(buffer, format='JPEG', quality=frame_quality)
|
||||
frame_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
|
||||
|
||||
pil_image.save(buffer, format="JPEG", quality=frame_quality)
|
||||
frame_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
# 计算时间戳
|
||||
timestamp = target_frame / fps if fps > 0 else 0
|
||||
frames.append((frame_base64, timestamp))
|
||||
|
||||
|
||||
cap.release()
|
||||
return frames
|
||||
|
||||
|
||||
except Exception as e:
|
||||
# 返回错误信息
|
||||
return [("ERROR", str(e))]
|
||||
@@ -140,38 +138,39 @@ class LegacyVideoAnalyzer:
|
||||
# 使用专用的视频分析配置
|
||||
try:
|
||||
self.video_llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.video_analysis,
|
||||
request_type="video_analysis"
|
||||
model_set=model_config.model_task_config.video_analysis, request_type="video_analysis"
|
||||
)
|
||||
logger.info("✅ 使用video_analysis模型配置")
|
||||
except (AttributeError, KeyError) as e:
|
||||
# 如果video_analysis不存在,使用vlm配置
|
||||
self.video_llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.vlm,
|
||||
request_type="vlm"
|
||||
)
|
||||
self.video_llm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="vlm")
|
||||
logger.warning(f"video_analysis配置不可用({e}),回退使用vlm配置")
|
||||
|
||||
|
||||
# 从配置文件读取参数,如果配置不存在则使用默认值
|
||||
config = global_config.video_analysis
|
||||
|
||||
# 使用 getattr 统一获取配置参数,如果配置不存在则使用默认值
|
||||
self.max_frames = getattr(config, 'max_frames', 6)
|
||||
self.frame_quality = getattr(config, 'frame_quality', 85)
|
||||
self.max_image_size = getattr(config, 'max_image_size', 600)
|
||||
self.enable_frame_timing = getattr(config, 'enable_frame_timing', True)
|
||||
|
||||
self.max_frames = getattr(config, "max_frames", 6)
|
||||
self.frame_quality = getattr(config, "frame_quality", 85)
|
||||
self.max_image_size = getattr(config, "max_image_size", 600)
|
||||
self.enable_frame_timing = getattr(config, "enable_frame_timing", True)
|
||||
|
||||
# 从personality配置中获取人格信息
|
||||
try:
|
||||
personality_config = global_config.personality
|
||||
self.personality_core = getattr(personality_config, 'personality_core', "是一个积极向上的女大学生")
|
||||
self.personality_side = getattr(personality_config, 'personality_side', "用一句话或几句话描述人格的侧面特点")
|
||||
self.personality_core = getattr(personality_config, "personality_core", "是一个积极向上的女大学生")
|
||||
self.personality_side = getattr(
|
||||
personality_config, "personality_side", "用一句话或几句话描述人格的侧面特点"
|
||||
)
|
||||
except AttributeError:
|
||||
# 如果没有personality配置,使用默认值
|
||||
self.personality_core = "是一个积极向上的女大学生"
|
||||
self.personality_side = "用一句话或几句话描述人格的侧面特点"
|
||||
|
||||
self.batch_analysis_prompt = getattr(config, 'batch_analysis_prompt', """请以第一人称的视角来观看这一个视频,你看到的这些是从视频中按时间顺序提取的关键帧。
|
||||
|
||||
self.batch_analysis_prompt = getattr(
|
||||
config,
|
||||
"batch_analysis_prompt",
|
||||
"""请以第一人称的视角来观看这一个视频,你看到的这些是从视频中按时间顺序提取的关键帧。
|
||||
|
||||
你的核心人设是:{personality_core}。
|
||||
你的人格细节是:{personality_side}。
|
||||
@@ -184,16 +183,17 @@ class LegacyVideoAnalyzer:
|
||||
5. 整体氛围和情感表达
|
||||
6. 任何特殊的视觉效果或文字内容
|
||||
|
||||
请用中文回答,结果要详细准确。""")
|
||||
|
||||
请用中文回答,结果要详细准确。""",
|
||||
)
|
||||
|
||||
# 新增的线程池配置
|
||||
self.use_multiprocessing = getattr(config, 'use_multiprocessing', True)
|
||||
self.max_workers = getattr(config, 'max_workers', 2)
|
||||
self.frame_extraction_mode = getattr(config, 'frame_extraction_mode', 'fixed_number')
|
||||
self.frame_interval_seconds = getattr(config, 'frame_interval_seconds', 2.0)
|
||||
|
||||
self.use_multiprocessing = getattr(config, "use_multiprocessing", True)
|
||||
self.max_workers = getattr(config, "max_workers", 2)
|
||||
self.frame_extraction_mode = getattr(config, "frame_extraction_mode", "fixed_number")
|
||||
self.frame_interval_seconds = getattr(config, "frame_interval_seconds", 2.0)
|
||||
|
||||
# 将配置文件中的模式映射到内部使用的模式名称
|
||||
config_mode = getattr(config, 'analysis_mode', 'auto')
|
||||
config_mode = getattr(config, "analysis_mode", "auto")
|
||||
if config_mode == "batch_frames":
|
||||
self.analysis_mode = "batch"
|
||||
elif config_mode == "frame_by_frame":
|
||||
@@ -203,21 +203,23 @@ class LegacyVideoAnalyzer:
|
||||
else:
|
||||
logger.warning(f"无效的分析模式: {config_mode},使用默认的auto模式")
|
||||
self.analysis_mode = "auto"
|
||||
|
||||
|
||||
self.frame_analysis_delay = 0.3 # API调用间隔(秒)
|
||||
self.frame_interval = 1.0 # 抽帧时间间隔(秒)
|
||||
self.batch_size = 3 # 批处理时每批处理的帧数
|
||||
self.timeout = 60.0 # 分析超时时间(秒)
|
||||
|
||||
|
||||
if config:
|
||||
logger.info("✅ 从配置文件读取视频分析参数")
|
||||
else:
|
||||
logger.warning("配置文件中缺少video_analysis配置,使用默认值")
|
||||
|
||||
|
||||
# 系统提示词
|
||||
self.system_prompt = "你是一个专业的视频内容分析助手。请仔细观察用户提供的视频关键帧,详细描述视频内容。"
|
||||
|
||||
logger.info(f"✅ 旧版本视频分析器初始化完成,分析模式: {self.analysis_mode}, 线程池: {self.use_multiprocessing}")
|
||||
|
||||
logger.info(
|
||||
f"✅ 旧版本视频分析器初始化完成,分析模式: {self.analysis_mode}, 线程池: {self.use_multiprocessing}"
|
||||
)
|
||||
|
||||
async def extract_frames(self, video_path: str) -> List[Tuple[str, float]]:
|
||||
"""提取视频帧 - 支持多进程和单线程模式"""
|
||||
@@ -227,18 +229,18 @@ class LegacyVideoAnalyzer:
|
||||
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
duration = total_frames / fps if fps > 0 else 0
|
||||
cap.release()
|
||||
|
||||
|
||||
logger.info(f"视频信息: {total_frames}帧, {fps:.2f}FPS, {duration:.2f}秒")
|
||||
|
||||
|
||||
# 估算提取帧数
|
||||
if duration > 0:
|
||||
frame_interval = max(1, int(duration / self.max_frames * fps))
|
||||
estimated_frames = min(self.max_frames, total_frames // frame_interval + 1)
|
||||
else:
|
||||
estimated_frames = self.max_frames
|
||||
|
||||
|
||||
logger.info(f"计算得出帧间隔: {frame_interval} (将提取约{estimated_frames}帧)")
|
||||
|
||||
|
||||
# 根据配置选择处理方式
|
||||
if self.use_multiprocessing:
|
||||
return await self._extract_frames_multiprocess(video_path)
|
||||
@@ -248,7 +250,7 @@ class LegacyVideoAnalyzer:
|
||||
async def _extract_frames_multiprocess(self, video_path: str) -> List[Tuple[str, float]]:
|
||||
"""线程池版本的帧提取"""
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
|
||||
try:
|
||||
logger.info("🔄 启动线程池帧提取...")
|
||||
# 使用线程池,避免进程间的导入问题
|
||||
@@ -261,19 +263,19 @@ class LegacyVideoAnalyzer:
|
||||
self.frame_quality,
|
||||
self.max_image_size,
|
||||
self.frame_extraction_mode,
|
||||
self.frame_interval_seconds
|
||||
self.frame_interval_seconds,
|
||||
)
|
||||
|
||||
|
||||
# 检查是否有错误
|
||||
if frames and frames[0][0] == "ERROR":
|
||||
logger.error(f"线程池帧提取失败: {frames[0][1]}")
|
||||
# 降级到单线程模式
|
||||
logger.info("🔄 降级到单线程模式...")
|
||||
return await self._extract_frames_fallback(video_path)
|
||||
|
||||
|
||||
logger.info(f"✅ 成功提取{len(frames)}帧 (线程池模式)")
|
||||
return frames
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"线程池帧提取失败: {e}")
|
||||
# 降级到原始方法
|
||||
@@ -288,43 +290,42 @@ class LegacyVideoAnalyzer:
|
||||
fps = cap.get(cv2.CAP_PROP_FPS)
|
||||
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
duration = total_frames / fps if fps > 0 else 0
|
||||
|
||||
|
||||
logger.info(f"视频信息: {total_frames}帧, {fps:.2f}FPS, {duration:.2f}秒")
|
||||
|
||||
|
||||
if self.frame_extraction_mode == "time_interval":
|
||||
# 新模式:按时间间隔抽帧
|
||||
time_interval = self.frame_interval_seconds
|
||||
next_frame_time = 0.0
|
||||
|
||||
|
||||
while cap.isOpened():
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
|
||||
|
||||
current_time = cap.get(cv2.CAP_PROP_POS_MSEC) / 1000.0
|
||||
|
||||
|
||||
if current_time >= next_frame_time:
|
||||
# 转换为PIL图像并压缩
|
||||
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
pil_image = Image.fromarray(frame_rgb)
|
||||
|
||||
|
||||
# 调整图像大小
|
||||
if max(pil_image.size) > self.max_image_size:
|
||||
ratio = self.max_image_size / max(pil_image.size)
|
||||
new_size = tuple(int(dim * ratio) for dim in pil_image.size)
|
||||
pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS)
|
||||
|
||||
|
||||
# 转换为base64
|
||||
buffer = io.BytesIO()
|
||||
pil_image.save(buffer, format='JPEG', quality=self.frame_quality)
|
||||
frame_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
|
||||
|
||||
pil_image.save(buffer, format="JPEG", quality=self.frame_quality)
|
||||
frame_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
frames.append((frame_base64, current_time))
|
||||
extracted_count += 1
|
||||
|
||||
|
||||
logger.debug(f"提取第{extracted_count}帧 (时间: {current_time:.2f}s)")
|
||||
|
||||
|
||||
next_frame_time += time_interval
|
||||
else:
|
||||
# 使用numpy优化帧间隔计算
|
||||
@@ -332,53 +333,55 @@ class LegacyVideoAnalyzer:
|
||||
frame_interval = max(1, int(duration / self.max_frames * fps))
|
||||
else:
|
||||
frame_interval = 30 # 默认间隔
|
||||
|
||||
logger.info(f"计算得出帧间隔: {frame_interval} (将提取约{min(self.max_frames, total_frames // frame_interval + 1)}帧)")
|
||||
|
||||
logger.info(
|
||||
f"计算得出帧间隔: {frame_interval} (将提取约{min(self.max_frames, total_frames // frame_interval + 1)}帧)"
|
||||
)
|
||||
|
||||
# 使用numpy计算目标帧位置
|
||||
target_frames = np.arange(0, min(self.max_frames, total_frames // frame_interval + 1)) * frame_interval
|
||||
target_frames = target_frames[target_frames < total_frames].astype(int)
|
||||
|
||||
|
||||
extracted_count = 0
|
||||
|
||||
|
||||
for target_frame in target_frames:
|
||||
# 跳转到目标帧
|
||||
cap.set(cv2.CAP_PROP_POS_FRAMES, target_frame)
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
continue
|
||||
|
||||
|
||||
# 使用numpy优化图像处理
|
||||
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
|
||||
|
||||
# 转换为PIL图像并使用numpy进行尺寸计算
|
||||
height, width = frame_rgb.shape[:2]
|
||||
max_dim = max(height, width)
|
||||
|
||||
|
||||
if max_dim > self.max_image_size:
|
||||
# 使用numpy计算缩放比例
|
||||
ratio = self.max_image_size / max_dim
|
||||
new_width = int(width * ratio)
|
||||
new_height = int(height * ratio)
|
||||
|
||||
|
||||
# 使用opencv进行高效缩放
|
||||
frame_resized = cv2.resize(frame_rgb, (new_width, new_height), interpolation=cv2.INTER_LANCZOS4)
|
||||
pil_image = Image.fromarray(frame_resized)
|
||||
else:
|
||||
pil_image = Image.fromarray(frame_rgb)
|
||||
|
||||
|
||||
# 转换为base64
|
||||
buffer = io.BytesIO()
|
||||
pil_image.save(buffer, format='JPEG', quality=self.frame_quality)
|
||||
frame_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
|
||||
|
||||
pil_image.save(buffer, format="JPEG", quality=self.frame_quality)
|
||||
frame_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
# 计算时间戳
|
||||
timestamp = target_frame / fps if fps > 0 else 0
|
||||
frames.append((frame_base64, timestamp))
|
||||
extracted_count += 1
|
||||
|
||||
|
||||
logger.debug(f"提取第{extracted_count}帧 (时间: {timestamp:.2f}s, 帧号: {target_frame})")
|
||||
|
||||
|
||||
# 每提取一帧让步一次
|
||||
await asyncio.sleep(0.001)
|
||||
|
||||
@@ -389,48 +392,48 @@ class LegacyVideoAnalyzer:
|
||||
async def analyze_frames_batch(self, frames: List[Tuple[str, float]], user_question: str = None) -> str:
|
||||
"""批量分析所有帧"""
|
||||
logger.info(f"开始批量分析{len(frames)}帧")
|
||||
|
||||
|
||||
if not frames:
|
||||
return "❌ 没有可分析的帧"
|
||||
|
||||
|
||||
# 构建提示词并格式化人格信息,要不然占位符的那个会爆炸
|
||||
prompt = self.batch_analysis_prompt.format(
|
||||
personality_core=self.personality_core,
|
||||
personality_side=self.personality_side
|
||||
personality_core=self.personality_core, personality_side=self.personality_side
|
||||
)
|
||||
|
||||
|
||||
if user_question:
|
||||
prompt += f"\n\n用户问题: {user_question}"
|
||||
|
||||
|
||||
# 添加帧信息到提示词
|
||||
frame_info = []
|
||||
for i, (_frame_base64, timestamp) in enumerate(frames):
|
||||
if self.enable_frame_timing:
|
||||
frame_info.append(f"第{i+1}帧 (时间: {timestamp:.2f}s)")
|
||||
frame_info.append(f"第{i + 1}帧 (时间: {timestamp:.2f}s)")
|
||||
else:
|
||||
frame_info.append(f"第{i+1}帧")
|
||||
|
||||
frame_info.append(f"第{i + 1}帧")
|
||||
|
||||
prompt += f"\n\n视频包含{len(frames)}帧图像:{', '.join(frame_info)}"
|
||||
prompt += "\n\n请基于所有提供的帧图像进行综合分析,关注并描述视频的完整内容和故事发展。"
|
||||
|
||||
|
||||
try:
|
||||
# 尝试使用多图片分析
|
||||
response = await self._analyze_multiple_frames(frames, prompt)
|
||||
logger.info("✅ 视频识别完成")
|
||||
return response
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 视频识别失败: {e}")
|
||||
# 降级到单帧分析
|
||||
logger.warning("降级到单帧分析模式")
|
||||
try:
|
||||
frame_base64, timestamp = frames[0]
|
||||
fallback_prompt = prompt + f"\n\n注意:由于技术限制,当前仅显示第1帧 (时间: {timestamp:.2f}s),视频共有{len(frames)}帧。请基于这一帧进行分析。"
|
||||
|
||||
fallback_prompt = (
|
||||
prompt
|
||||
+ f"\n\n注意:由于技术限制,当前仅显示第1帧 (时间: {timestamp:.2f}s),视频共有{len(frames)}帧。请基于这一帧进行分析。"
|
||||
)
|
||||
|
||||
response, _ = await self.video_llm.generate_response_for_image(
|
||||
prompt=fallback_prompt,
|
||||
image_base64=frame_base64,
|
||||
image_format="jpeg"
|
||||
prompt=fallback_prompt, image_base64=frame_base64, image_format="jpeg"
|
||||
)
|
||||
logger.info("✅ 降级的单帧分析完成")
|
||||
return response
|
||||
@@ -441,22 +444,22 @@ class LegacyVideoAnalyzer:
|
||||
async def _analyze_multiple_frames(self, frames: List[Tuple[str, float]], prompt: str) -> str:
|
||||
"""使用多图片分析方法"""
|
||||
logger.info(f"开始构建包含{len(frames)}帧的分析请求")
|
||||
|
||||
|
||||
# 导入MessageBuilder用于构建多图片消息
|
||||
from src.llm_models.payload_content.message import MessageBuilder, RoleType
|
||||
from src.llm_models.utils_model import RequestType
|
||||
|
||||
|
||||
# 构建包含多张图片的消息
|
||||
message_builder = MessageBuilder().set_role(RoleType.User).add_text_content(prompt)
|
||||
|
||||
|
||||
# 添加所有帧图像
|
||||
for _i, (frame_base64, _timestamp) in enumerate(frames):
|
||||
message_builder.add_image_content("jpeg", frame_base64)
|
||||
# logger.info(f"已添加第{i+1}帧到分析请求 (时间: {timestamp:.2f}s, 图片大小: {len(frame_base64)} chars)")
|
||||
|
||||
|
||||
message = message_builder.build()
|
||||
# logger.info(f"✅ 多帧消息构建完成,包含{len(frames)}张图片")
|
||||
|
||||
|
||||
# 获取模型信息和客户端
|
||||
model_info, api_provider, client = self.video_llm._select_model()
|
||||
# logger.info(f"使用模型: {model_info.name} 进行多帧分析")
|
||||
@@ -469,45 +472,43 @@ class LegacyVideoAnalyzer:
|
||||
model_info=model_info,
|
||||
message_list=[message],
|
||||
temperature=None,
|
||||
max_tokens=None
|
||||
max_tokens=None,
|
||||
)
|
||||
|
||||
|
||||
logger.info(f"视频识别完成,响应长度: {len(api_response.content or '')} ")
|
||||
return api_response.content or "❌ 未获得响应内容"
|
||||
|
||||
async def analyze_frames_sequential(self, frames: List[Tuple[str, float]], user_question: str = None) -> str:
|
||||
"""逐帧分析并汇总"""
|
||||
logger.info(f"开始逐帧分析{len(frames)}帧")
|
||||
|
||||
|
||||
frame_analyses = []
|
||||
|
||||
|
||||
for i, (frame_base64, timestamp) in enumerate(frames):
|
||||
try:
|
||||
prompt = f"请分析这个视频的第{i+1}帧"
|
||||
prompt = f"请分析这个视频的第{i + 1}帧"
|
||||
if self.enable_frame_timing:
|
||||
prompt += f" (时间: {timestamp:.2f}s)"
|
||||
prompt += "。描述你看到的内容,包括人物、动作、场景、文字等。"
|
||||
|
||||
|
||||
if user_question:
|
||||
prompt += f"\n特别关注: {user_question}"
|
||||
|
||||
|
||||
response, _ = await self.video_llm.generate_response_for_image(
|
||||
prompt=prompt,
|
||||
image_base64=frame_base64,
|
||||
image_format="jpeg"
|
||||
prompt=prompt, image_base64=frame_base64, image_format="jpeg"
|
||||
)
|
||||
|
||||
frame_analyses.append(f"第{i+1}帧 ({timestamp:.2f}s): {response}")
|
||||
logger.debug(f"✅ 第{i+1}帧分析完成")
|
||||
|
||||
|
||||
frame_analyses.append(f"第{i + 1}帧 ({timestamp:.2f}s): {response}")
|
||||
logger.debug(f"✅ 第{i + 1}帧分析完成")
|
||||
|
||||
# API调用间隔
|
||||
if i < len(frames) - 1:
|
||||
await asyncio.sleep(self.frame_analysis_delay)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 第{i+1}帧分析失败: {e}")
|
||||
frame_analyses.append(f"第{i+1}帧: 分析失败 - {e}")
|
||||
|
||||
logger.error(f"❌ 第{i + 1}帧分析失败: {e}")
|
||||
frame_analyses.append(f"第{i + 1}帧: 分析失败 - {e}")
|
||||
|
||||
# 生成汇总
|
||||
logger.info("开始生成汇总分析")
|
||||
summary_prompt = f"""基于以下各帧的分析结果,请提供一个完整的视频内容总结:
|
||||
@@ -518,15 +519,13 @@ class LegacyVideoAnalyzer:
|
||||
|
||||
if user_question:
|
||||
summary_prompt += f"\n特别回答用户的问题: {user_question}"
|
||||
|
||||
|
||||
try:
|
||||
# 使用最后一帧进行汇总分析
|
||||
if frames:
|
||||
last_frame_base64, _ = frames[-1]
|
||||
summary, _ = await self.video_llm.generate_response_for_image(
|
||||
prompt=summary_prompt,
|
||||
image_base64=last_frame_base64,
|
||||
image_format="jpeg"
|
||||
prompt=summary_prompt, image_base64=last_frame_base64, image_format="jpeg"
|
||||
)
|
||||
logger.info("✅ 逐帧分析和汇总完成")
|
||||
return summary
|
||||
@@ -541,12 +540,12 @@ class LegacyVideoAnalyzer:
|
||||
"""分析视频的主要方法"""
|
||||
try:
|
||||
logger.info(f"开始分析视频: {os.path.basename(video_path)}")
|
||||
|
||||
|
||||
# 提取帧
|
||||
frames = await self.extract_frames(video_path)
|
||||
if not frames:
|
||||
return "❌ 无法从视频中提取有效帧"
|
||||
|
||||
|
||||
# 根据模式选择分析方法
|
||||
if self.analysis_mode == "auto":
|
||||
# 智能选择:少于等于3帧用批量,否则用逐帧
|
||||
@@ -554,16 +553,16 @@ class LegacyVideoAnalyzer:
|
||||
logger.info(f"自动选择分析模式: {mode} (基于{len(frames)}帧)")
|
||||
else:
|
||||
mode = self.analysis_mode
|
||||
|
||||
|
||||
# 执行分析
|
||||
if mode == "batch":
|
||||
result = await self.analyze_frames_batch(frames, user_question)
|
||||
else: # sequential
|
||||
result = await self.analyze_frames_sequential(frames, user_question)
|
||||
|
||||
|
||||
logger.info("✅ 视频分析完成")
|
||||
return result
|
||||
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"❌ 视频分析失败: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
@@ -571,16 +570,17 @@ class LegacyVideoAnalyzer:
|
||||
|
||||
def is_supported_video(self, file_path: str) -> bool:
|
||||
"""检查是否为支持的视频格式"""
|
||||
supported_formats = {'.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.m4v', '.3gp', '.webm'}
|
||||
supported_formats = {".mp4", ".avi", ".mov", ".mkv", ".flv", ".wmv", ".m4v", ".3gp", ".webm"}
|
||||
return Path(file_path).suffix.lower() in supported_formats
|
||||
|
||||
|
||||
# 全局实例
|
||||
_legacy_video_analyzer = None
|
||||
|
||||
|
||||
def get_legacy_video_analyzer() -> LegacyVideoAnalyzer:
|
||||
"""获取旧版本视频分析器实例(单例模式)"""
|
||||
global _legacy_video_analyzer
|
||||
if _legacy_video_analyzer is None:
|
||||
_legacy_video_analyzer = LegacyVideoAnalyzer()
|
||||
return _legacy_video_analyzer
|
||||
return _legacy_video_analyzer
|
||||
|
||||
Reference in New Issue
Block a user