re-style: format code

John Richard
2025-10-02 20:26:01 +08:00
committed by Windpicker-owo
parent 00ba07e0e1
commit a79253c714
263 changed files with 3781 additions and 3189 deletions

View File

@@ -1,18 +1,19 @@
import time # 导入 time 模块以获取当前时间
import random
import re
import time # 导入 time 模块以获取当前时间
from collections.abc import Callable
from typing import Any
from typing import List, Dict, Any, Tuple, Optional, Callable
from rich.traceback import install
from sqlalchemy import and_, select
from src.config.config import global_config
from src.common.message_repository import find_messages, count_messages
from src.common.database.sqlalchemy_models import ActionRecords, Images
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
from src.chat.utils.utils import translate_timestamp_to_human_readable, assign_message_ids
from src.chat.utils.utils import assign_message_ids, translate_timestamp_to_human_readable
from src.common.database.sqlalchemy_database_api import get_db_session
from sqlalchemy import select, and_
from src.common.database.sqlalchemy_models import ActionRecords, Images
from src.common.logger import get_logger
from src.common.message_repository import count_messages, find_messages
from src.config.config import global_config
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
logger = get_logger("chat_message_builder")
@@ -22,7 +23,7 @@ install(extra_lines=3)
def replace_user_references_sync(
content: str,
platform: str,
name_resolver: Optional[Callable[[str, str], str]] = None,
name_resolver: Callable[[str, str], str] | None = None,
replace_bot_name: bool = True,
) -> str:
"""
@@ -98,7 +99,7 @@ def replace_user_references_sync(
async def replace_user_references_async(
content: str,
platform: str,
name_resolver: Optional[Callable[[str, str], Any]] = None,
name_resolver: Callable[[str, str], Any] | None = None,
replace_bot_name: bool = True,
) -> str:
"""
@@ -171,7 +172,7 @@ async def replace_user_references_async(
async def get_raw_msg_by_timestamp(
timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""
获取从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制
@@ -191,7 +192,7 @@ async def get_raw_msg_by_timestamp_with_chat(
limit_mode: str = "latest",
filter_bot=False,
filter_command=False,
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""获取在特定聊天从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制
limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录。默认为 'latest'
@@ -217,7 +218,7 @@ async def get_raw_msg_by_timestamp_with_chat_inclusive(
limit: int = 0,
limit_mode: str = "latest",
filter_bot=False,
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""获取在特定聊天从指定时间戳到指定时间戳的消息(包含边界),按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制
limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录。默认为 'latest'
@@ -236,10 +237,10 @@ async def get_raw_msg_by_timestamp_with_chat_users(
chat_id: str,
timestamp_start: float,
timestamp_end: float,
person_ids: List[str],
person_ids: list[str],
limit: int = 0,
limit_mode: str = "latest",
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""获取某些特定用户在特定聊天从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制
limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录。默认为 'latest'
@@ -260,7 +261,7 @@ async def get_actions_by_timestamp_with_chat(
timestamp_end: float = time.time(),
limit: int = 0,
limit_mode: str = "latest",
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录,按时间升序排序,返回动作记录列表"""
from src.common.logger import get_logger
@@ -369,7 +370,7 @@ async def get_actions_by_timestamp_with_chat(
async def get_actions_by_timestamp_with_chat_inclusive(
chat_id: str, timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录(包含边界),按时间升序排序,返回动作记录列表"""
async with get_db_session() as session:
if limit > 0:
@@ -420,7 +421,7 @@ async def get_actions_by_timestamp_with_chat_inclusive(
async def get_raw_msg_by_timestamp_random(
timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""
先在范围时间戳内随机选择一条消息取得消息的chat_id然后根据chat_id获取该聊天在指定时间戳范围内的消息
"""
@@ -438,7 +439,7 @@ async def get_raw_msg_by_timestamp_random(
async def get_raw_msg_by_timestamp_with_users(
timestamp_start: float, timestamp_end: float, person_ids: list, limit: int = 0, limit_mode: str = "latest"
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""获取某些特定用户在 *所有聊天* 中从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制
limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录。默认为 'latest'
@@ -449,7 +450,7 @@ async def get_raw_msg_by_timestamp_with_users(
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
async def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List[Dict[str, Any]]:
async def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> list[dict[str, Any]]:
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制
"""
@@ -460,7 +461,7 @@ async def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List
async def get_raw_msg_before_timestamp_with_chat(
chat_id: str, timestamp: float, limit: int = 0
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制
"""
@@ -471,7 +472,7 @@ async def get_raw_msg_before_timestamp_with_chat(
async def get_raw_msg_before_timestamp_with_users(
timestamp: float, person_ids: list, limit: int = 0
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制
"""
@@ -480,9 +481,7 @@ async def get_raw_msg_before_timestamp_with_users(
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
async def num_new_messages_since(
chat_id: str, timestamp_start: float = 0.0, timestamp_end: Optional[float] = None
) -> int:
async def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: float | None = None) -> int:
"""
检查特定聊天从 timestamp_start (不含) 到 timestamp_end (不含) 之间有多少新消息。
如果 timestamp_end 为 None则检查从 timestamp_start (不含) 到当前时间的消息。
@@ -514,17 +513,16 @@ async def num_new_messages_since_with_users(
async def _build_readable_messages_internal(
messages: List[Dict[str, Any]],
messages: list[dict[str, Any]],
replace_bot_name: bool = True,
merge_messages: bool = False,
timestamp_mode: str = "relative",
truncate: bool = False,
pic_id_mapping: Optional[Dict[str, str]] = None,
pic_id_mapping: dict[str, str] | None = None,
pic_counter: int = 1,
show_pic: bool = True,
message_id_list: Optional[List[Dict[str, Any]]] = None,
read_mark: float = 0.0,
) -> Tuple[str, List[Tuple[float, str, str]], Dict[str, str], int]:
message_id_list: list[dict[str, Any]] | None = None,
) -> tuple[str, list[tuple[float, str, str]], dict[str, str], int]:
"""
内部辅助函数,构建可读消息字符串和原始消息详情列表。
@@ -543,7 +541,7 @@ async def _build_readable_messages_internal(
if not messages:
return "", [], pic_id_mapping or {}, pic_counter
message_details_raw: List[Tuple[float, str, str, bool]] = []
message_details_raw: list[tuple[float, str, str, bool]] = []
# 使用传入的映射字典,如果没有则创建新的
if pic_id_mapping is None:
@@ -669,7 +667,7 @@ async def _build_readable_messages_internal(
message_details_with_flags.append((timestamp, name, content, is_action))
# 应用截断逻辑 (如果 truncate 为 True)
message_details: List[Tuple[float, str, str, bool]] = []
message_details: list[tuple[float, str, str, bool]] = []
n_messages = len(message_details_with_flags)
if truncate and n_messages > 0:
for i, (timestamp, name, content, is_action) in enumerate(message_details_with_flags):
@@ -811,7 +809,7 @@ async def _build_readable_messages_internal(
)
async def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
async def build_pic_mapping_info(pic_id_mapping: dict[str, str]) -> str:
# sourcery skip: use-contextlib-suppress
"""
构建图片映射信息字符串,显示图片的具体描述内容
@@ -849,7 +847,7 @@ async def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
return "\n".join(mapping_lines)
def build_readable_actions(actions: List[Dict[str, Any]]) -> str:
def build_readable_actions(actions: list[dict[str, Any]]) -> str:
"""
将动作列表转换为可读的文本格式。
格式: 在()分钟前,你使用了(action_name)具体内容是action_prompt_display
@@ -924,12 +922,12 @@ def build_readable_actions(actions: List[Dict[str, Any]]) -> str:
async def build_readable_messages_with_list(
messages: List[Dict[str, Any]],
messages: list[dict[str, Any]],
replace_bot_name: bool = True,
merge_messages: bool = False,
timestamp_mode: str = "relative",
truncate: bool = False,
) -> Tuple[str, List[Tuple[float, str, str]]]:
) -> tuple[str, list[tuple[float, str, str]]]:
"""
将消息列表转换为可读的文本格式,并返回原始(时间戳, 昵称, 内容)列表。
允许通过参数控制格式化行为。
@@ -945,7 +943,7 @@ async def build_readable_messages_with_list(
async def build_readable_messages_with_id(
messages: List[Dict[str, Any]],
messages: list[dict[str, Any]],
replace_bot_name: bool = True,
merge_messages: bool = False,
timestamp_mode: str = "relative",
@@ -953,7 +951,7 @@ async def build_readable_messages_with_id(
truncate: bool = False,
show_actions: bool = False,
show_pic: bool = True,
) -> Tuple[str, List[Dict[str, Any]]]:
) -> tuple[str, list[dict[str, Any]]]:
"""
将消息列表转换为可读的文本格式,并返回原始(时间戳, 昵称, 内容)列表。
允许通过参数控制格式化行为。
@@ -982,7 +980,7 @@ async def build_readable_messages_with_id(
async def build_readable_messages(
messages: List[Dict[str, Any]],
messages: list[dict[str, Any]],
replace_bot_name: bool = True,
merge_messages: bool = False,
timestamp_mode: str = "relative",
@@ -990,7 +988,7 @@ async def build_readable_messages(
truncate: bool = False,
show_actions: bool = True,
show_pic: bool = True,
message_id_list: Optional[List[Dict[str, Any]]] = None,
message_id_list: list[dict[str, Any]] | None = None,
) -> str: # sourcery skip: extract-method
"""
将消息列表转换为可读的文本格式。
@@ -1150,7 +1148,7 @@ async def build_readable_messages(
return "".join(result_parts)
async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
async def build_anonymous_messages(messages: list[dict[str, Any]]) -> str:
"""
构建匿名可读消息将不同人的名称转为唯一占位符A、B、C...bot自己用SELF。
处理 回复<aaa:bbb> 和 @<aaa:bbb> 字段将bbb映射为匿名占位符。
@@ -1263,7 +1261,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
return formatted_string
async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
async def get_person_id_list(messages: list[dict[str, Any]]) -> list[str]:
"""
从消息列表中提取不重复的 person_id 列表 (忽略机器人自身)。
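The changes in this file are mechanical annotation updates: typing.List / Dict / Tuple / Optional are replaced with builtin generics (PEP 585) and X | None unions (PEP 604), which need Python 3.9/3.10 or newer. A minimal before/after sketch of the pattern, using an illustrative signature rather than one of the repository's functions:

from typing import Any

# before:
#   from typing import Any, Dict, List, Optional
#   async def fetch(limit: int = 0, chat_id: Optional[str] = None) -> List[Dict[str, Any]]: ...

# after (the style applied throughout this commit):
async def fetch(limit: int = 0, chat_id: str | None = None) -> list[dict[str, Any]]:
    # illustrative only; the real signatures live in chat_message_builder.py
    return []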

View File

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
记忆系统相关的映射表和工具函数
提供记忆类型、置信度、重要性等的中文标签映射
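The only change in this file is dropping the "# -*- coding: utf-8 -*-" header. UTF-8 has been the default source encoding since Python 3.0 (PEP 3120), so the declaration is redundant and the module's Chinese docstring loads exactly as before. A trivial check with an illustrative string:

label = "记忆类型"  # non-ASCII literal, no encoding declaration required on Python 3
assert label.encode("utf-8").decode("utf-8") == label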

View File

@@ -3,19 +3,20 @@
将原有的Prompt类和SmartPrompt功能整合为一个真正的Prompt类
"""
import re
import asyncio
import time
import contextvars
from dataclasses import dataclass, field
from typing import Dict, Any, Optional, List, Literal, Tuple
import re
import time
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from typing import Any, Literal, Optional
from rich.traceback import install
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.chat_message_builder import build_readable_messages
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.utils.chat_message_builder import build_readable_messages
from src.chat.message_receive.chat_stream import get_chat_manager
from src.person_info.person_info import get_person_info_manager
install(extra_lines=3)
@@ -50,11 +51,11 @@ class PromptParameters:
debug_mode: bool = False
# 聊天历史和上下文
chat_target_info: Optional[Dict[str, Any]] = None
message_list_before_now_long: List[Dict[str, Any]] = field(default_factory=list)
message_list_before_short: List[Dict[str, Any]] = field(default_factory=list)
chat_target_info: dict[str, Any] | None = None
message_list_before_now_long: list[dict[str, Any]] = field(default_factory=list)
message_list_before_short: list[dict[str, Any]] = field(default_factory=list)
chat_talking_prompt_short: str = ""
target_user_info: Optional[Dict[str, Any]] = None
target_user_info: dict[str, Any] | None = None
# 已构建的内容块
expression_habits_block: str = ""
@@ -77,12 +78,12 @@ class PromptParameters:
action_descriptions: str = ""
# 可用动作信息
available_actions: Optional[Dict[str, Any]] = None
available_actions: dict[str, Any] | None = None
# 动态生成的聊天场景提示
chat_scene: str = ""
def validate(self) -> List[str]:
def validate(self) -> list[str]:
"""参数验证"""
errors = []
if not self.chat_id:
@@ -98,22 +99,22 @@ class PromptContext:
"""提示词上下文管理器"""
def __init__(self):
self._context_prompts: Dict[str, Dict[str, "Prompt"]] = {}
self._context_prompts: dict[str, dict[str, "Prompt"]] = {}
self._current_context_var = contextvars.ContextVar("current_context", default=None)
self._context_lock = asyncio.Lock()
@property
def _current_context(self) -> Optional[str]:
def _current_context(self) -> str | None:
"""获取当前协程的上下文ID"""
return self._current_context_var.get()
@_current_context.setter
def _current_context(self, value: Optional[str]):
def _current_context(self, value: str | None):
"""设置当前协程的上下文ID"""
self._current_context_var.set(value) # type: ignore
@asynccontextmanager
async def async_scope(self, context_id: Optional[str] = None):
async def async_scope(self, context_id: str | None = None):
"""创建一个异步的临时提示模板作用域"""
if context_id is not None:
try:
@@ -159,7 +160,7 @@ class PromptContext:
return self._context_prompts[current_context][name]
return None
async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None:
async def register_async(self, prompt: "Prompt", context_id: str | None = None) -> None:
"""异步注册提示模板到指定作用域"""
async with self._context_lock:
if target_context := context_id or self._current_context:
@@ -177,7 +178,7 @@ class PromptManager:
self._lock = asyncio.Lock()
@asynccontextmanager
async def async_message_scope(self, message_id: Optional[str] = None):
async def async_message_scope(self, message_id: str | None = None):
"""为消息处理创建异步临时作用域"""
async with self._context.async_scope(message_id):
yield self
@@ -240,8 +241,8 @@ class Prompt:
def __init__(
self,
template: str,
name: Optional[str] = None,
parameters: Optional[PromptParameters] = None,
name: str | None = None,
parameters: PromptParameters | None = None,
should_register: bool = True,
):
"""
@@ -281,7 +282,7 @@ class Prompt:
"""将临时标记还原为实际的花括号字符"""
return template.replace(Prompt._TEMP_LEFT_BRACE, "{").replace(Prompt._TEMP_RIGHT_BRACE, "}")
def _parse_template_args(self, template: str) -> List[str]:
def _parse_template_args(self, template: str) -> list[str]:
"""解析模板参数"""
template_args = []
processed_template = self._process_escaped_braces(template)
@@ -325,7 +326,7 @@ class Prompt:
logger.error(f"构建Prompt失败: {e}")
raise RuntimeError(f"构建Prompt失败: {e}") from e
async def _build_context_data(self) -> Dict[str, Any]:
async def _build_context_data(self) -> dict[str, Any]:
"""构建智能上下文数据"""
# 并行执行所有构建任务
start_time = time.time()
@@ -405,7 +406,7 @@ class Prompt:
default_result = self._get_default_result_for_task(task_name)
results.append(default_result)
except Exception as e:
logger.error(f"构建任务{task_name}失败: {str(e)}")
logger.error(f"构建任务{task_name}失败: {e!s}")
default_result = self._get_default_result_for_task(task_name)
results.append(default_result)
@@ -415,7 +416,7 @@ class Prompt:
task_name = task_names[i] if i < len(task_names) else f"task_{i}"
if isinstance(result, Exception):
logger.error(f"构建任务{task_name}失败: {str(result)}")
logger.error(f"构建任务{task_name}失败: {result!s}")
elif isinstance(result, dict):
context_data.update(result)
@@ -457,7 +458,7 @@ class Prompt:
return context_data
async def _build_s4u_chat_context(self, context_data: Dict[str, Any]) -> None:
async def _build_s4u_chat_context(self, context_data: dict[str, Any]) -> None:
"""构建S4U模式的聊天上下文"""
if not self.parameters.message_list_before_now_long:
return
@@ -472,7 +473,7 @@ class Prompt:
context_data["read_history_prompt"] = read_history_prompt
context_data["unread_history_prompt"] = unread_history_prompt
async def _build_normal_chat_context(self, context_data: Dict[str, Any]) -> None:
async def _build_normal_chat_context(self, context_data: dict[str, Any]) -> None:
"""构建normal模式的聊天上下文"""
if not self.parameters.chat_talking_prompt_short:
return
@@ -481,8 +482,8 @@ class Prompt:
{self.parameters.chat_talking_prompt_short}"""
async def _build_s4u_chat_history_prompts(
self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str, chat_id: str
) -> Tuple[str, str]:
self, message_list_before_now: list[dict[str, Any]], target_user_id: str, sender: str, chat_id: str
) -> tuple[str, str]:
"""构建S4U风格的已读/未读历史消息prompt"""
try:
# 动态导入default_generator以避免循环导入
@@ -496,7 +497,7 @@ class Prompt:
except Exception as e:
logger.error(f"构建S4U历史消息prompt失败: {e}")
async def _build_expression_habits(self) -> Dict[str, Any]:
async def _build_expression_habits(self) -> dict[str, Any]:
"""构建表达习惯"""
use_expression, _, _ = global_config.expression.get_expression_config_for_chat(self.parameters.chat_id)
if not use_expression:
@@ -537,7 +538,7 @@ class Prompt:
logger.error(f"构建表达习惯失败: {e}")
return {"expression_habits_block": ""}
async def _build_memory_block(self) -> Dict[str, Any]:
async def _build_memory_block(self) -> dict[str, Any]:
"""构建记忆块"""
if not global_config.memory.enable_memory:
return {"memory_block": ""}
@@ -657,7 +658,7 @@ class Prompt:
logger.error(f"构建记忆块失败: {e}")
return {"memory_block": ""}
async def _build_memory_block_fast(self) -> Dict[str, Any]:
async def _build_memory_block_fast(self) -> dict[str, Any]:
"""快速构建记忆块(简化版本,用于未预构建时的后备方案)"""
if not global_config.memory.enable_memory:
return {"memory_block": ""}
@@ -681,7 +682,7 @@ class Prompt:
logger.warning(f"快速构建记忆块失败: {e}")
return {"memory_block": ""}
async def _build_relation_info(self) -> Dict[str, Any]:
async def _build_relation_info(self) -> dict[str, Any]:
"""构建关系信息"""
try:
relation_info = await Prompt.build_relation_info(self.parameters.chat_id, self.parameters.reply_to)
@@ -690,7 +691,7 @@ class Prompt:
logger.error(f"构建关系信息失败: {e}")
return {"relation_info_block": ""}
async def _build_tool_info(self) -> Dict[str, Any]:
async def _build_tool_info(self) -> dict[str, Any]:
"""构建工具信息"""
if not global_config.tool.enable_tool:
return {"tool_info_block": ""}
@@ -738,7 +739,7 @@ class Prompt:
logger.error(f"构建工具信息失败: {e}")
return {"tool_info_block": ""}
async def _build_knowledge_info(self) -> Dict[str, Any]:
async def _build_knowledge_info(self) -> dict[str, Any]:
"""构建知识信息"""
if not global_config.lpmm_knowledge.enable:
return {"knowledge_prompt": ""}
@@ -787,7 +788,7 @@ class Prompt:
logger.error(f"构建知识信息失败: {e}")
return {"knowledge_prompt": ""}
async def _build_cross_context(self) -> Dict[str, Any]:
async def _build_cross_context(self) -> dict[str, Any]:
"""构建跨群上下文"""
try:
cross_context = await Prompt.build_cross_context(
@@ -798,7 +799,7 @@ class Prompt:
logger.error(f"构建跨群上下文失败: {e}")
return {"cross_context_block": ""}
async def _format_with_context(self, context_data: Dict[str, Any]) -> str:
async def _format_with_context(self, context_data: dict[str, Any]) -> str:
"""使用上下文数据格式化模板"""
if self.parameters.prompt_mode == "s4u":
params = self._prepare_s4u_params(context_data)
@@ -809,7 +810,7 @@ class Prompt:
return await global_prompt_manager.format_prompt(self.name, **params) if self.name else self.format(**params)
def _prepare_s4u_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
def _prepare_s4u_params(self, context_data: dict[str, Any]) -> dict[str, Any]:
"""准备S4U模式的参数"""
return {
**context_data,
@@ -838,7 +839,7 @@ class Prompt:
or "你正在一个QQ群里聊天你需要理解整个群的聊天动态和话题走向并做出自然的回应。",
}
def _prepare_normal_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
def _prepare_normal_params(self, context_data: dict[str, Any]) -> dict[str, Any]:
"""准备Normal模式的参数"""
return {
**context_data,
@@ -866,7 +867,7 @@ class Prompt:
or "你正在一个QQ群里聊天你需要理解整个群的聊天动态和话题走向并做出自然的回应。",
}
def _prepare_default_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
def _prepare_default_params(self, context_data: dict[str, Any]) -> dict[str, Any]:
"""准备默认模式的参数"""
return {
"expression_habits_block": context_data.get("expression_habits_block", ""),
@@ -909,7 +910,7 @@ class Prompt:
result = self._restore_escaped_braces(processed_template)
return result
except (IndexError, KeyError) as e:
raise ValueError(f"格式化模板失败: {self.template}, args={args}, kwargs={kwargs} {str(e)}") from e
raise ValueError(f"格式化模板失败: {self.template}, args={args}, kwargs={kwargs} {e!s}") from e
def __str__(self) -> str:
"""返回格式化后的结果或原始模板"""
@@ -926,7 +927,7 @@ class Prompt:
# =============================================================================
@staticmethod
def parse_reply_target(target_message: str) -> Tuple[str, str]:
def parse_reply_target(target_message: str) -> tuple[str, str]:
"""
解析回复目标消息 - 统一实现
@@ -985,7 +986,7 @@ class Prompt:
return await relationship_fetcher.build_relation_info(person_id, points_num=5)
def _get_default_result_for_task(self, task_name: str) -> Dict[str, Any]:
def _get_default_result_for_task(self, task_name: str) -> dict[str, Any]:
"""
为超时的任务提供默认结果
@@ -1012,7 +1013,7 @@ class Prompt:
return {}
@staticmethod
async def build_cross_context(chat_id: str, prompt_mode: str, target_user_info: Optional[Dict[str, Any]]) -> str:
async def build_cross_context(chat_id: str, prompt_mode: str, target_user_info: dict[str, Any] | None) -> str:
"""
构建跨群聊上下文 - 统一实现
@@ -1075,7 +1076,7 @@ class Prompt:
# 工厂函数
def create_prompt(
template: str, name: Optional[str] = None, parameters: Optional[PromptParameters] = None, **kwargs
template: str, name: str | None = None, parameters: PromptParameters | None = None, **kwargs
) -> Prompt:
"""快速创建Prompt实例的工厂函数"""
if parameters is None:
@@ -1084,7 +1085,7 @@ def create_prompt(
async def create_prompt_async(
template: str, name: Optional[str] = None, parameters: Optional[PromptParameters] = None, **kwargs
template: str, name: str | None = None, parameters: PromptParameters | None = None, **kwargs
) -> Prompt:
"""异步创建Prompt实例"""
prompt = create_prompt(template, name, parameters, **kwargs)
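Besides the same annotation updates, this file rewrites str(e) inside f-strings as the conversion flag e!s, which produces identical output (the kind of rewrite automated linters such as Ruff perform; which tool was actually run here is an assumption). A small runnable sketch:

try:
    raise RuntimeError("boom")
except RuntimeError as e:
    before = f"构建Prompt失败: {str(e)}"
    after = f"构建Prompt失败: {e!s}"
    assert before == after  # same text, terser spelling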

View File

@@ -1,11 +1,11 @@
import asyncio
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Any, Dict, Tuple, List
from typing import Any
from src.common.database.sqlalchemy_database_api import db_get, db_query, db_save
from src.common.database.sqlalchemy_models import LLMUsage, Messages, OnlineTime
from src.common.logger import get_logger
from src.common.database.sqlalchemy_models import OnlineTime, LLMUsage, Messages
from src.common.database.sqlalchemy_database_api import db_query, db_save, db_get
from src.manager.async_task_manager import AsyncTask
from src.manager.local_store_manager import local_storage
@@ -162,7 +162,7 @@ class StatisticOutputTask(AsyncTask):
# 延迟300秒启动运行间隔300秒
super().__init__(task_name="Statistics Data Output Task", wait_before_start=0, run_interval=300)
self.name_mapping: Dict[str, Tuple[str, float]] = {}
self.name_mapping: dict[str, tuple[str, float]] = {}
"""
联系人/群聊名称映射 {聊天ID: (联系人/群聊名称, 记录时间timestamp)}
注:设计记录时间的目的是方便更新名称,使联系人/群聊名称保持最新
@@ -182,7 +182,7 @@ class StatisticOutputTask(AsyncTask):
deploy_time = datetime(2000, 1, 1)
local_storage["deploy_time"] = now.timestamp()
self.stat_period: List[Tuple[str, timedelta, str]] = [
self.stat_period: list[tuple[str, timedelta, str]] = [
("all_time", now - deploy_time, "自部署以来"), # 必须保留"all_time"
("last_7_days", timedelta(days=7), "最近7天"),
("last_24_hours", timedelta(days=1), "最近24小时"),
@@ -193,7 +193,7 @@ class StatisticOutputTask(AsyncTask):
统计时间段 [(统计名称, 统计时间段, 统计描述), ...]
"""
def _statistic_console_output(self, stats: Dict[str, Any], now: datetime):
def _statistic_console_output(self, stats: dict[str, Any], now: datetime):
"""
输出统计数据到控制台
:param stats: 统计数据
@@ -251,7 +251,7 @@ class StatisticOutputTask(AsyncTask):
# -- 以下为统计数据收集方法 --
@staticmethod
async def _collect_model_request_for_period(collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
async def _collect_model_request_for_period(collect_period: list[tuple[str, datetime]]) -> dict[str, Any]:
"""
收集指定时间段的LLM请求统计数据
@@ -405,8 +405,8 @@ class StatisticOutputTask(AsyncTask):
@staticmethod
async def _collect_online_time_for_period(
collect_period: List[Tuple[str, datetime]], now: datetime
) -> Dict[str, Any]:
collect_period: list[tuple[str, datetime]], now: datetime
) -> dict[str, Any]:
"""
收集指定时间段的在线时间统计数据
@@ -464,7 +464,7 @@ class StatisticOutputTask(AsyncTask):
break
return stats
async def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
async def _collect_message_count_for_period(self, collect_period: list[tuple[str, datetime]]) -> dict[str, Any]:
"""
收集指定时间段的消息统计数据
@@ -535,7 +535,7 @@ class StatisticOutputTask(AsyncTask):
break
return stats
async def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]:
async def _collect_all_statistics(self, now: datetime) -> dict[str, dict[str, Any]]:
"""
收集各时间段的统计数据
:param now: 基准当前时间
@@ -545,7 +545,7 @@ class StatisticOutputTask(AsyncTask):
if "last_full_statistics" in local_storage:
# 如果存在上次完整统计数据,则使用该数据进行增量统计
last_stat: Dict[str, Any] = local_storage["last_full_statistics"] # 上次完整统计数据 # type: ignore
last_stat: dict[str, Any] = local_storage["last_full_statistics"] # 上次完整统计数据 # type: ignore
self.name_mapping = last_stat["name_mapping"] # 上次完整统计数据的名称映射
last_all_time_stat = last_stat["stat_data"] # 上次完整统计的统计数据
@@ -632,7 +632,7 @@ class StatisticOutputTask(AsyncTask):
# -- 以下为统计数据格式化方法 --
@staticmethod
def _format_total_stat(stats: Dict[str, Any]) -> str:
def _format_total_stat(stats: dict[str, Any]) -> str:
"""
格式化总统计数据
"""
@@ -648,7 +648,7 @@ class StatisticOutputTask(AsyncTask):
return "\n".join(output)
@staticmethod
def _format_model_classified_stat(stats: Dict[str, Any]) -> str:
def _format_model_classified_stat(stats: dict[str, Any]) -> str:
"""
格式化按模型分类的统计数据
"""
@@ -674,7 +674,7 @@ class StatisticOutputTask(AsyncTask):
output.append("")
return "\n".join(output)
def _format_chat_stat(self, stats: Dict[str, Any]) -> str:
def _format_chat_stat(self, stats: dict[str, Any]) -> str:
"""
格式化聊天统计数据
"""
@@ -1019,7 +1019,7 @@ class StatisticOutputTask(AsyncTask):
async def _generate_chart_data(self, stat: dict[str, Any]) -> dict:
"""生成图表数据 (异步)"""
now = datetime.now()
chart_data: Dict[str, Any] = {}
chart_data: dict[str, Any] = {}
time_ranges = [
("6h", 6, 10),
@@ -1035,16 +1035,16 @@ class StatisticOutputTask(AsyncTask):
async def _collect_interval_data(self, now: datetime, hours: int, interval_minutes: int) -> dict:
start_time = now - timedelta(hours=hours)
time_points: List[datetime] = []
time_points: list[datetime] = []
current_time = start_time
while current_time <= now:
time_points.append(current_time)
current_time += timedelta(minutes=interval_minutes)
total_cost_data = [0.0] * len(time_points)
cost_by_model: Dict[str, List[float]] = {}
cost_by_module: Dict[str, List[float]] = {}
message_by_chat: Dict[str, List[int]] = {}
cost_by_model: dict[str, list[float]] = {}
cost_by_module: dict[str, list[float]] = {}
message_by_chat: dict[str, list[int]] = {}
time_labels = [t.strftime("%H:%M") for t in time_points]
interval_seconds = interval_minutes * 60
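The import block of this file is only regrouped: module paths are alphabetised and the names inside each from-import are sorted, the ordering that import sorters such as isort or Ruff produce (which tool was used is an assumption). The file's own lines, rearranged to show the pattern:

# before (as removed in this hunk):
#   from src.common.logger import get_logger
#   from src.common.database.sqlalchemy_models import OnlineTime, LLMUsage, Messages
#   from src.common.database.sqlalchemy_database_api import db_query, db_save, db_get
# after:
from src.common.database.sqlalchemy_database_api import db_get, db_query, db_save
from src.common.database.sqlalchemy_models import LLMUsage, Messages, OnlineTime
from src.common.logger import get_logger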

View File

@@ -1,8 +1,8 @@
import asyncio
from time import perf_counter
from collections.abc import Callable
from functools import wraps
from typing import Optional, Dict, Callable
from time import perf_counter
from rich.traceback import install
install(extra_lines=3)
@@ -75,12 +75,12 @@ class Timer:
3. 直接实例化:如果不调用 __enter__打印对象时将显示当前 perf_counter 的值
"""
__slots__ = ("name", "storage", "elapsed", "auto_unit", "start")
__slots__ = ("auto_unit", "elapsed", "name", "start", "storage")
def __init__(
self,
name: Optional[str] = None,
storage: Optional[Dict[str, float]] = None,
name: str | None = None,
storage: dict[str, float] | None = None,
auto_unit: bool = True,
do_type_check: bool = False,
):
@@ -103,7 +103,7 @@ class Timer:
if storage is not None and not isinstance(storage, dict):
raise TimerTypeError("storage", "Optional[dict]", type(storage))
def __call__(self, func: Optional[Callable] = None) -> Callable:
def __call__(self, func: Callable | None = None) -> Callable:
"""装饰器模式"""
if func is None:
return lambda f: Timer(name=self.name or f.__name__, storage=self.storage, auto_unit=self.auto_unit)(f)
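Callable also moves from typing to collections.abc here; PEP 585 (Python 3.9) deprecated the typing alias in favour of the ABC, and annotations written with either name behave the same. Sketch with an illustrative function, not one from the repository:

from collections.abc import Callable

def run(func: Callable[[], int] | None = None) -> int:
    # new-style annotation: abc Callable combined with a PEP 604 optional
    return func() if func is not None else 0

assert run(lambda: 42) == 42
assert run() == 0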

View File

@@ -2,15 +2,15 @@
错别字生成器 - 基于拼音和字频的中文错别字生成工具
"""
import orjson
import math
import os
import random
import time
import jieba
from collections import defaultdict
from pathlib import Path
import jieba
import orjson
from pypinyin import Style, pinyin
from src.common.logger import get_logger
@@ -51,7 +51,7 @@ class ChineseTypoGenerator:
# 如果缓存文件存在,直接加载
if cache_file.exists():
with open(cache_file, "r", encoding="utf-8") as f:
with open(cache_file, encoding="utf-8") as f:
return orjson.loads(f.read())
# 使用内置的词频文件
@@ -59,7 +59,7 @@ class ChineseTypoGenerator:
dict_path = os.path.join(os.path.dirname(jieba.__file__), "dict.txt")
# 读取jieba的词典文件
with open(dict_path, "r", encoding="utf-8") as f:
with open(dict_path, encoding="utf-8") as f:
for line in f:
word, freq = line.strip().split()[:2]
# 对词中的每个字进行频率累加
@@ -254,7 +254,7 @@ class ChineseTypoGenerator:
# 获取jieba词典和词频信息
dict_path = os.path.join(os.path.dirname(jieba.__file__), "dict.txt")
valid_words = {} # 改用字典存储词语及其频率
with open(dict_path, "r", encoding="utf-8") as f:
with open(dict_path, encoding="utf-8") as f:
for line in f:
parts = line.strip().split()
if len(parts) >= 2:
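Two patterns show up in this file: the import block is regrouped, and open(path, "r", encoding=...) is shortened to open(path, encoding=...) because "r" is already the default mode, so behaviour is unchanged. A runnable sketch of the equivalence (the temporary file exists only so the example executes):

import os
import tempfile

fd, path = tempfile.mkstemp(text=True)
os.close(fd)

# before: with open(path, "r", encoding="utf-8") as f:
with open(path, encoding="utf-8") as f:  # after: identical behaviour
    f.read()
os.remove(path)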

View File

@@ -3,20 +3,21 @@ import random
import re
import string
import time
from collections import Counter
from typing import Any
import jieba
import numpy as np
from collections import Counter
from maim_message import UserInfo
from typing import Optional, Tuple, Dict, List, Any, Coroutine
from src.common.logger import get_logger
from src.common.message_repository import find_messages, count_messages
from src.config.config import global_config, model_config
from src.chat.message_receive.message import MessageRecv
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.message import MessageRecv
from src.common.logger import get_logger
from src.common.message_repository import count_messages, find_messages
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.person_info.person_info import Person
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
from .typo_generator import ChineseTypoGenerator
logger = get_logger("chat_utils")
@@ -86,9 +87,9 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
if not is_mentioned:
# 判断是否被回复
if re.match(
rf"\[回复 (.+?)\({str(global_config.bot.qq_account)}\)(.+?)\],说:", message.processed_plain_text
rf"\[回复 (.+?)\({global_config.bot.qq_account!s}\)(.+?)\],说:", message.processed_plain_text
) or re.match(
rf"\[回复<(.+?)(?=:{str(global_config.bot.qq_account)}>)\:{str(global_config.bot.qq_account)}>(.+?)\],说:",
rf"\[回复<(.+?)(?=:{global_config.bot.qq_account!s}>)\:{global_config.bot.qq_account!s}>(.+?)\],说:",
message.processed_plain_text,
):
is_mentioned = True
@@ -110,14 +111,14 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
return is_mentioned, reply_probability
async def get_embedding(text, request_type="embedding") -> Optional[List[float]]:
async def get_embedding(text, request_type="embedding") -> list[float] | None:
"""获取文本的embedding向量"""
# 每次都创建新的LLMRequest实例以避免事件循环冲突
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type=request_type)
try:
embedding, _ = await llm.get_embedding(text)
except Exception as e:
logger.error(f"获取embedding失败: {str(e)}")
logger.error(f"获取embedding失败: {e!s}")
embedding = None
return embedding
@@ -622,7 +623,7 @@ def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal"
return time.strftime("%H:%M:%S", time.localtime(timestamp))
async def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
def get_chat_type_and_target_info(chat_id: str) -> tuple[bool, dict | None]:
"""
获取聊天类型(是否群聊)和私聊对象信息。
@@ -675,7 +676,6 @@ async def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Di
if loop.is_running():
# 如果事件循环在运行,从其他线程提交并等待结果
try:
fut = asyncio.run_coroutine_threadsafe(
person_info_manager.get_value(person_id, "person_name"), loop
)
@@ -711,7 +711,7 @@ async def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Di
return is_group_chat, chat_target_info
def assign_message_ids(messages: List[Any]) -> List[Dict[str, Any]]:
def assign_message_ids(messages: list[Any]) -> list[dict[str, Any]]:
"""
为消息列表中的每个消息分配唯一的简短随机ID

View File

@@ -1,29 +1,27 @@
import base64
import hashlib
import io
import os
import time
import hashlib
import uuid
import io
import numpy as np
from typing import Any
from typing import Optional, Tuple, Dict, Any
import numpy as np
from PIL import Image
from rich.traceback import install
from sqlalchemy import and_, select
from src.common.database.sqlalchemy_models import ImageDescriptions, Images, get_db_session
from src.common.logger import get_logger
from src.common.database.sqlalchemy_models import Images, ImageDescriptions
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.common.database.sqlalchemy_models import get_db_session
from sqlalchemy import select, and_
install(extra_lines=3)
logger = get_logger("chat_image")
def is_image_message(message: Dict[str, Any]) -> bool:
def is_image_message(message: dict[str, Any]) -> bool:
"""
判断消息是否为图片消息
@@ -69,7 +67,7 @@ class ImageManager:
os.makedirs(self.IMAGE_DIR, exist_ok=True)
@staticmethod
async def _get_description_from_db(image_hash: str, description_type: str) -> Optional[str]:
async def _get_description_from_db(image_hash: str, description_type: str) -> str | None:
"""从数据库获取图片描述
Args:
@@ -93,7 +91,7 @@ class ImageManager:
).scalar()
return record.description if record else None
except Exception as e:
logger.error(f"从数据库获取描述失败 (SQLAlchemy): {str(e)}")
logger.error(f"从数据库获取描述失败 (SQLAlchemy): {e!s}")
return None
@staticmethod
@@ -136,7 +134,7 @@ class ImageManager:
await session.commit()
# 会在上下文管理器中自动调用
except Exception as e:
logger.error(f"保存描述到数据库失败 (SQLAlchemy): {str(e)}")
logger.error(f"保存描述到数据库失败 (SQLAlchemy): {e!s}")
@staticmethod
async def get_emoji_tag(image_base64: str) -> str:
@@ -287,10 +285,10 @@ class ImageManager:
session.add(new_img)
await session.commit()
except Exception as e:
logger.error(f"保存到Images表失败: {str(e)}")
logger.error(f"保存到Images表失败: {e!s}")
except Exception as e:
logger.error(f"保存表情包文件或元数据失败: {str(e)}")
logger.error(f"保存表情包文件或元数据失败: {e!s}")
else:
logger.debug("偷取表情包功能已关闭,跳过保存。")
@@ -300,7 +298,7 @@ class ImageManager:
return f"[表情包:{final_emotion}]"
except Exception as e:
logger.error(f"获取表情包描述失败: {str(e)}")
logger.error(f"获取表情包描述失败: {e!s}")
return "[表情包(处理失败)]"
async def get_image_description(self, image_base64: str) -> str:
@@ -391,11 +389,11 @@ class ImageManager:
logger.info(f"[VLM完成] 图片描述生成: {description}...")
return f"[图片:{description}]"
except Exception as e:
logger.error(f"获取图片描述失败: {str(e)}")
logger.error(f"获取图片描述失败: {e!s}")
return "[图片(处理失败)]"
@staticmethod
def transform_gif(gif_base64: str, similarity_threshold: float = 1000.0, max_frames: int = 15) -> Optional[str]:
def transform_gif(gif_base64: str, similarity_threshold: float = 1000.0, max_frames: int = 15) -> str | None:
# sourcery skip: use-contextlib-suppress
"""将GIF转换为水平拼接的静态图像, 跳过相似的帧
@@ -512,10 +510,10 @@ class ImageManager:
logger.error("GIF转换失败: 内存不足可能是GIF太大或帧数太多")
return None # 内存不够啦
except Exception as e:
logger.error(f"GIF转换失败: {str(e)}", exc_info=True) # 记录详细错误信息
logger.error(f"GIF转换失败: {e!s}", exc_info=True) # 记录详细错误信息
return None # 其他错误也返回None
async def process_image(self, image_base64: str) -> Tuple[str, str]:
async def process_image(self, image_base64: str) -> tuple[str, str]:
# sourcery skip: hoist-if-from-if
"""处理图片并返回图片ID和描述
@@ -604,7 +602,7 @@ class ImageManager:
return image_id, f"[picid:{image_id}]"
except Exception as e:
logger.error(f"处理图片失败: {str(e)}")
logger.error(f"处理图片失败: {e!s}")
return "", "[图片]"
@@ -637,4 +635,4 @@ def image_path_to_base64(image_path: str) -> str:
if image_data := f.read():
return base64.b64encode(image_data).decode("utf-8")
else:
raise IOError(f"读取图片文件失败: {image_path}")
raise OSError(f"读取图片文件失败: {image_path}")
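IOError has been a plain alias of OSError since Python 3.3, so raising OSError directly, as the new line does, changes nothing for callers that still catch IOError. A minimal check:

assert IOError is OSError

try:
    raise OSError("读取图片文件失败: example.png")  # illustrative message
except IOError as err:  # still caught: IOError names the same class
    assert type(err) is OSError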

View File

@@ -1,35 +1,31 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""纯 inkfox 视频关键帧分析工具
仅依赖 `inkfox.video` 提供的 Rust 扩展能力:
- extract_keyframes_from_video
- get_system_info
功能:
- 关键帧提取 (base64, timestamp)
- 批量 / 逐帧 LLM 描述
- 自动模式 (<=3 帧批量,否则逐帧)
"""
视频分析器模块 - Rust优化版本
集成了Rust视频关键帧提取模块提供高性能的视频分析功能
支持SIMD优化、多线程处理和智能关键帧检测
"""
from __future__ import annotations
import os
import io
import asyncio
import base64
import tempfile
from pathlib import Path
from typing import List, Tuple, Optional, Dict, Any
import hashlib
import io
import os
import tempfile
import time
from pathlib import Path
import numpy as np
from PIL import Image
from src.common.logger import get_logger
from src.common.database.sqlalchemy_models import get_db_session, Videos
from sqlalchemy import select
from src.common.database.sqlalchemy_models import Videos, get_db_session
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
logger = get_logger("utils_video")
# Rust模块可用性检测
@@ -205,7 +201,7 @@ class VideoAnalyzer:
hash_obj.update(video_data)
return hash_obj.hexdigest()
async def _check_video_exists(self, video_hash: str) -> Optional[Videos]:
async def _check_video_exists(self, video_hash: str) -> Videos | None:
"""检查视频是否已经分析过"""
try:
async with get_db_session() as session:
@@ -222,8 +218,8 @@ class VideoAnalyzer:
return None
async def _store_video_result(
self, video_hash: str, description: str, metadata: Optional[Dict] = None
) -> Optional[Videos]:
self, video_hash: str, description: str, metadata: dict | None = None
) -> Videos | None:
"""存储视频分析结果到数据库"""
# 检查描述是否为错误信息,如果是则不保存
if description.startswith("❌"):
@@ -283,7 +279,7 @@ class VideoAnalyzer:
else:
logger.warning(f"无效的分析模式: {mode}")
async def extract_frames(self, video_path: str) -> List[Tuple[str, float]]:
async def extract_frames(self, video_path: str) -> list[tuple[str, float]]:
"""提取视频帧 - 智能选择最佳实现"""
# 检查是否应该使用Rust实现
if RUST_VIDEO_AVAILABLE and self.frame_extraction_mode == "keyframe":
@@ -305,8 +301,8 @@ class VideoAnalyzer:
logger.info(f"🔄 抽帧模式为 {self.frame_extraction_mode}使用Python抽帧实现")
return await self._extract_frames_python_fallback(video_path)
# ---- 系统信息 ----
def _log_system(self) -> None:
async def _extract_frames_rust_advanced(self, video_path: str) -> list[tuple[str, float]]:
"""使用 Rust 高级接口的帧提取"""
try:
info = video.get_system_info() # type: ignore[attr-defined]
logger.info(
@@ -329,25 +325,174 @@ class VideoAnalyzer:
threads=self.threads,
verbose=False,
)
files = sorted(Path(tmp).glob("keyframe_*.jpg"))[: self.max_frames]
total_ms = getattr(result, "total_time_ms", 0)
frames: List[Tuple[str, float]] = []
for i, f in enumerate(files):
img = Image.open(f).convert("RGB")
if max(img.size) > self.max_image_size:
scale = self.max_image_size / max(img.size)
img = img.resize((int(img.width * scale), int(img.height * scale)), Image.Resampling.LANCZOS)
buf = io.BytesIO()
img.save(buf, format="JPEG", quality=self.frame_quality)
b64 = base64.b64encode(buf.getvalue()).decode()
ts = (i / max(1, len(files) - 1)) * (total_ms / 1000.0) if total_ms else float(i)
frames.append((b64, ts))
logger.info(f"检测到 {len(keyframe_indices)} 个关键帧")
# 3. 转换选定的关键帧为 base64
frames = []
frame_count = 0
for idx in keyframe_indices[: self.max_frames]:
if idx < len(frames_data):
try:
frame = frames_data[idx]
frame_data = frame.get_data()
# 将灰度数据转换为PIL图像
frame_array = np.frombuffer(frame_data, dtype=np.uint8).reshape((frame.height, frame.width))
pil_image = Image.fromarray(
frame_array,
mode="L", # 灰度模式
)
# 转换为RGB模式以便保存为JPEG
pil_image = pil_image.convert("RGB")
# 调整图像大小
if max(pil_image.size) > self.max_image_size:
ratio = self.max_image_size / max(pil_image.size)
new_size = tuple(int(dim * ratio) for dim in pil_image.size)
pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS)
# 转换为 base64
buffer = io.BytesIO()
pil_image.save(buffer, format="JPEG", quality=self.frame_quality)
frame_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
# 估算时间戳
estimated_timestamp = frame.frame_number * (1.0 / 30.0) # 假设30fps
frames.append((frame_base64, estimated_timestamp))
frame_count += 1
logger.debug(
f"处理关键帧 {frame_count}: 帧号 {frame.frame_number}, 时间 {estimated_timestamp:.2f}s"
)
except Exception as e:
logger.error(f"处理关键帧 {idx} 失败: {e}")
continue
logger.info(f"✅ Rust 高级提取完成: {len(frames)} 关键帧")
return frames
# ---- 批量分析 ----
async def _analyze_batch(self, frames: List[Tuple[str, float]], question: Optional[str]) -> str:
from src.llm_models.payload_content.message import MessageBuilder, RoleType
from src.llm_models.utils_model import RequestType
except Exception as e:
logger.error(f"❌ Rust 高级帧提取失败: {e}")
# 回退到基础方法
logger.info("回退到基础 Rust 方法")
return await self._extract_frames_rust(video_path)
async def _extract_frames_rust(self, video_path: str) -> list[tuple[str, float]]:
"""使用 Rust 实现的帧提取"""
try:
logger.info("🔄 使用 Rust 模块提取关键帧...")
# 创建临时输出目录
with tempfile.TemporaryDirectory() as temp_dir:
# 使用便捷函数进行关键帧提取,使用配置参数
result = rust_video.extract_keyframes_from_video(
video_path=video_path,
output_dir=temp_dir,
threshold=self.rust_keyframe_threshold,
max_frames=self.max_frames * 2, # 提取更多帧以便筛选
max_save=self.max_frames,
ffmpeg_path=self.ffmpeg_path,
use_simd=self.rust_use_simd,
threads=self.rust_threads,
verbose=False, # 使用固定值,不需要配置
)
logger.info(
f"Rust 处理完成: 总帧数 {result.total_frames}, 关键帧 {result.keyframes_extracted}, 处理速度 {result.processing_fps:.1f} FPS"
)
# 转换保存的关键帧为 base64 格式
frames = []
temp_dir_path = Path(temp_dir)
# 获取所有保存的关键帧文件
keyframe_files = sorted(temp_dir_path.glob("keyframe_*.jpg"))
for i, keyframe_file in enumerate(keyframe_files):
if len(frames) >= self.max_frames:
break
try:
# 读取关键帧文件
with open(keyframe_file, "rb") as f:
image_data = f.read()
# 转换为 PIL 图像并压缩
pil_image = Image.open(io.BytesIO(image_data))
# 调整图像大小
if max(pil_image.size) > self.max_image_size:
ratio = self.max_image_size / max(pil_image.size)
new_size = tuple(int(dim * ratio) for dim in pil_image.size)
pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS)
# 转换为 base64
buffer = io.BytesIO()
pil_image.save(buffer, format="JPEG", quality=self.frame_quality)
frame_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
# 估算时间戳(基于帧索引和总时长)
if result.total_frames > 0:
# 假设关键帧在时间上均匀分布
estimated_timestamp = (i * result.total_time_ms / 1000.0) / result.keyframes_extracted
else:
estimated_timestamp = i * 1.0 # 默认每秒一帧
frames.append((frame_base64, estimated_timestamp))
logger.debug(f"处理关键帧 {i + 1}: 估算时间 {estimated_timestamp:.2f}s")
except Exception as e:
logger.error(f"处理关键帧 {keyframe_file.name} 失败: {e}")
continue
logger.info(f"✅ Rust 提取完成: {len(frames)} 关键帧")
return frames
except Exception as e:
logger.error(f"❌ Rust 帧提取失败: {e}")
raise e
async def _extract_frames_python_fallback(self, video_path: str) -> list[tuple[str, float]]:
"""Python降级抽帧实现 - 支持多种抽帧模式"""
try:
# 导入旧版本分析器
from .utils_video_legacy import get_legacy_video_analyzer
logger.info("🔄 使用Python降级抽帧实现...")
legacy_analyzer = get_legacy_video_analyzer()
# 同步配置参数
legacy_analyzer.max_frames = self.max_frames
legacy_analyzer.frame_quality = self.frame_quality
legacy_analyzer.max_image_size = self.max_image_size
legacy_analyzer.frame_extraction_mode = self.frame_extraction_mode
legacy_analyzer.frame_interval_seconds = self.frame_interval_seconds
legacy_analyzer.use_multiprocessing = self.use_multiprocessing
# 使用旧版本的抽帧功能
frames = await legacy_analyzer.extract_frames(video_path)
logger.info(f"✅ Python降级抽帧完成: {len(frames)}")
return frames
except Exception as e:
logger.error(f"❌ Python降级抽帧失败: {e}")
return []
async def analyze_frames_batch(self, frames: list[tuple[str, float]], user_question: str = None) -> str:
"""批量分析所有帧"""
logger.info(f"开始批量分析{len(frames)}")
if not frames:
return "❌ 没有可分析的帧"
# 构建提示词并格式化人格信息,要不然占位符的那个会爆炸
prompt = self.batch_analysis_prompt.format(
personality_core=self.personality_core, personality_side=self.personality_side
)
@@ -376,7 +521,7 @@ class VideoAnalyzer:
logger.error(f"❌ 视频识别失败: {e}")
raise e
async def _analyze_multiple_frames(self, frames: List[Tuple[str, float]], prompt: str) -> str:
async def _analyze_multiple_frames(self, frames: list[tuple[str, float]], prompt: str) -> str:
"""使用多图片分析方法"""
logger.info(f"开始构建包含{len(frames)}帧的分析请求")
@@ -412,53 +557,75 @@ class VideoAnalyzer:
temperature=None,
max_tokens=None,
)
return resp.content or "❌ 未获得响应"
# ---- 逐帧分析 ----
async def _analyze_sequential(self, frames: List[Tuple[str, float]], question: Optional[str]) -> str:
results: List[str] = []
for i, (b64, ts) in enumerate(frames):
prompt = f"分析第{i+1}" + (f" (时间: {ts:.2f}s)" if self.enable_frame_timing else "")
if question:
prompt += f"\n关注: {question}"
logger.info(f"视频识别完成,响应长度: {len(api_response.content or '')} ")
return api_response.content or "❌ 未获得响应内容"
async def analyze_frames_sequential(self, frames: list[tuple[str, float]], user_question: str = None) -> str:
"""逐帧分析并汇总"""
logger.info(f"开始逐帧分析{len(frames)}")
frame_analyses = []
for i, (frame_base64, timestamp) in enumerate(frames):
try:
text, _ = await self.video_llm.generate_response_for_image(
prompt=prompt, image_base64=b64, image_format="jpeg"
)
results.append(f"{i+1}帧: {text}")
except Exception as e: # pragma: no cover
results.append(f"{i+1}帧: 失败 {e}")
if i < len(frames) - 1:
await asyncio.sleep(self.frame_analysis_delay)
summary_prompt = "基于以下逐帧结果给出完整总结:\n\n" + "\n".join(results)
try:
final, _ = await self.video_llm.generate_response_for_image(
prompt=summary_prompt, image_base64=frames[-1][0], image_format="jpeg"
)
return final
except Exception: # pragma: no cover
return "\n".join(results)
logger.info("✅ 逐帧分析和汇总完成")
return summary
else:
return "❌ 没有可用于汇总的帧"
except Exception as e:
logger.error(f"❌ 汇总分析失败: {e}")
# 如果汇总失败,返回各帧分析结果
return f"视频逐帧分析结果:\n\n{chr(10).join(frame_analyses)}"
# ---- 主入口 ----
async def analyze_video(self, video_path: str, question: Optional[str] = None) -> Tuple[bool, str]:
if not os.path.exists(video_path):
return False, "❌ 文件不存在"
frames = await self.extract_keyframes(video_path)
if not frames:
return False, "❌ 未提取到关键帧"
mode = self.analysis_mode
if mode == "auto":
mode = "batch" if len(frames) <= 20 else "sequential"
text = await (self._analyze_batch(frames, question) if mode == "batch" else self._analyze_sequential(frames, question))
return True, text
async def analyze_video(self, video_path: str, user_question: str = None) -> tuple[bool, str]:
"""分析视频的主要方法
Returns:
Tuple[bool, str]: (是否成功, 分析结果或错误信息)
"""
if self.disabled:
error_msg = "❌ 视频分析功能已禁用:没有可用的视频处理实现"
logger.warning(error_msg)
return (False, error_msg)
try:
logger.info(f"开始分析视频: {os.path.basename(video_path)}")
# 提取帧
frames = await self.extract_frames(video_path)
if not frames:
error_msg = "❌ 无法从视频中提取有效帧"
return (False, error_msg)
# 根据模式选择分析方法
if self.analysis_mode == "auto":
# 智能选择少于等于3帧用批量否则用逐帧
mode = "batch" if len(frames) <= 3 else "sequential"
logger.info(f"自动选择分析模式: {mode} (基于{len(frames)}帧)")
else:
mode = self.analysis_mode
# 执行分析
if mode == "batch":
result = await self.analyze_frames_batch(frames, user_question)
else: # sequential
result = await self.analyze_frames_sequential(frames, user_question)
logger.info("✅ 视频分析完成")
return (True, result)
except Exception as e:
error_msg = f"❌ 视频分析失败: {e!s}"
logger.error(error_msg)
return (False, error_msg)
async def analyze_video_from_bytes(
self,
video_bytes: bytes,
filename: Optional[str] = None,
prompt: Optional[str] = None,
question: Optional[str] = None,
) -> Dict[str, str]:
self, video_bytes: bytes, filename: str = None, user_question: str = None, prompt: str = None
) -> dict[str, str]:
"""从字节数据分析视频
Args:
@@ -568,34 +735,81 @@ class VideoAnalyzer:
return {"summary": result}
except Exception as e:
error_msg = f"❌ 从字节数据分析视频失败: {str(e)}"
error_msg = f"❌ 从字节数据分析视频失败: {e!s}"
logger.error(error_msg)
async def _save_cache(self, video_hash: str, summary: str, file_size: int) -> None:
# 不保存错误信息到数据库,允许后续重试
logger.info("💡 错误信息不保存到数据库,允许后续重试")
# 处理失败,通知等待者并清理资源
try:
if video_hash and video_event:
async with video_lock_manager:
if video_hash in video_events:
video_events[video_hash].set()
video_locks.pop(video_hash, None)
video_events.pop(video_hash, None)
except Exception as cleanup_e:
logger.error(f"❌ 清理锁资源失败: {cleanup_e}")
return {"summary": error_msg}
def is_supported_video(self, file_path: str) -> bool:
"""检查是否为支持的视频格式"""
supported_formats = {".mp4", ".avi", ".mov", ".mkv", ".flv", ".wmv", ".m4v", ".3gp", ".webm"}
return Path(file_path).suffix.lower() in supported_formats
def get_processing_capabilities(self) -> dict[str, any]:
"""获取处理能力信息"""
if not RUST_VIDEO_AVAILABLE:
return {"error": "Rust视频处理模块不可用", "available": False, "reason": "rust_video模块未安装或加载失败"}
try:
async with get_db_session() as session: # type: ignore
stmt = insert(Videos).values( # type: ignore
video_id="",
video_hash=video_hash,
description=summary,
count=1,
timestamp=time.time(),
vlm_processed=True,
duration=None,
frame_count=None,
fps=None,
resolution=None,
file_size=file_size,
)
try:
await session.execute(stmt)
await session.commit()
logger.debug(f"视频缓存写入 success hash={video_hash}")
except sa_exc.IntegrityError: # 可能并发已写入
await session.rollback()
logger.debug(f"视频缓存已存在 hash={video_hash}")
except Exception: # pragma: no cover
logger.debug("视频缓存写入失败")
system_info = rust_video.get_system_info()
# 创建一个临时的extractor来获取CPU特性
extractor = rust_video.VideoKeyframeExtractor(threads=0, verbose=False)
cpu_features = extractor.get_cpu_features()
capabilities = {
"system": {
"threads": system_info.get("threads", 0),
"rust_version": system_info.get("version", "unknown"),
},
"cpu_features": cpu_features,
"recommended_settings": self._get_recommended_settings(cpu_features),
"analysis_modes": ["auto", "batch", "sequential"],
"supported_formats": [".mp4", ".avi", ".mov", ".mkv", ".flv", ".wmv", ".m4v", ".3gp", ".webm"],
"available": True,
}
return capabilities
except Exception as e:
logger.error(f"获取处理能力信息失败: {e}")
return {"error": str(e), "available": False}
def _get_recommended_settings(self, cpu_features: dict[str, bool]) -> dict[str, any]:
"""根据CPU特性推荐最佳设置"""
settings = {
"use_simd": any(cpu_features.values()),
"block_size": 8192,
"threads": 0, # 自动检测
}
# 根据CPU特性调整设置
if cpu_features.get("avx2", False):
settings["block_size"] = 16384 # AVX2支持更大的块
settings["optimization_level"] = "avx2"
elif cpu_features.get("sse2", False):
settings["block_size"] = 8192
settings["optimization_level"] = "sse2"
else:
settings["use_simd"] = False
settings["block_size"] = 4096
settings["optimization_level"] = "scalar"
return settings
# ---- 外部接口 ----
@@ -613,7 +827,14 @@ def is_video_analysis_available() -> bool:
return True
def get_video_analysis_status() -> Dict[str, Any]:
def get_video_analysis_status() -> dict[str, any]:
"""获取视频分析功能的详细状态信息
Returns:
Dict[str, any]: 包含功能状态信息的字典
"""
# 检查OpenCV是否可用
opencv_available = False
try:
info = video.get_system_info() # type: ignore[attr-defined]
except Exception as e: # pragma: no cover
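The rewritten analyze_video entry point resolves the "auto" analysis mode by frame count: three or fewer extracted frames go through batch analysis, anything more is analysed frame by frame. A condensed standalone restatement of that dispatch (pick_mode is an illustrative helper name, not part of the module):

def pick_mode(analysis_mode: str, frame_count: int) -> str:
    # mirrors the selection in analyze_video(): only "auto" needs resolving
    if analysis_mode != "auto":
        return analysis_mode
    return "batch" if frame_count <= 3 else "sequential"

assert pick_mode("auto", 3) == "batch"
assert pick_mode("auto", 10) == "sequential"
assert pick_mode("sequential", 1) == "sequential"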

View File

@@ -1,25 +1,25 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
视频分析器模块 - 旧版本兼容模块
支持多种分析模式:批处理、逐帧、自动选择
包含Python原生的抽帧功能作为Rust模块的降级方案
"""
import os
import cv2
import asyncio
import base64
import io
import os
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Any
import cv2
import numpy as np
from PIL import Image
from pathlib import Path
from typing import List, Tuple, Optional, Any
import io
from concurrent.futures import ThreadPoolExecutor
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
logger = get_logger("utils_video_legacy")
@@ -30,7 +30,7 @@ def _extract_frames_worker(
frame_quality: int,
max_image_size: int,
frame_extraction_mode: str,
frame_interval_seconds: Optional[float],
frame_interval_seconds: float | None,
) -> list[Any] | list[tuple[str, str]]:
"""线程池中提取视频帧的工作函数"""
frames = []
@@ -221,7 +221,7 @@ class LegacyVideoAnalyzer:
f"✅ 旧版本视频分析器初始化完成,分析模式: {self.analysis_mode}, 线程池: {self.use_multiprocessing}"
)
async def extract_frames(self, video_path: str) -> List[Tuple[str, float]]:
async def extract_frames(self, video_path: str) -> list[tuple[str, float]]:
"""提取视频帧 - 支持多进程和单线程模式"""
# 先获取视频信息
cap = cv2.VideoCapture(video_path)
@@ -247,7 +247,7 @@ class LegacyVideoAnalyzer:
else:
return await self._extract_frames_fallback(video_path)
async def _extract_frames_multiprocess(self, video_path: str) -> List[Tuple[str, float]]:
async def _extract_frames_multiprocess(self, video_path: str) -> list[tuple[str, float]]:
"""线程池版本的帧提取"""
loop = asyncio.get_event_loop()
@@ -282,7 +282,7 @@ class LegacyVideoAnalyzer:
logger.info("🔄 降级到单线程模式...")
return await self._extract_frames_fallback(video_path)
async def _extract_frames_fallback(self, video_path: str) -> List[Tuple[str, float]]:
async def _extract_frames_fallback(self, video_path: str) -> list[tuple[str, float]]:
"""帧提取的降级方法 - 原始异步版本"""
frames = []
extracted_count = 0
@@ -389,7 +389,7 @@ class LegacyVideoAnalyzer:
logger.info(f"✅ 成功提取{len(frames)}")
return frames
async def analyze_frames_batch(self, frames: List[Tuple[str, float]], user_question: str = None) -> str:
async def analyze_frames_batch(self, frames: list[tuple[str, float]], user_question: str = None) -> str:
"""批量分析所有帧"""
logger.info(f"开始批量分析{len(frames)}")
@@ -441,7 +441,7 @@ class LegacyVideoAnalyzer:
logger.error(f"❌ 降级分析也失败: {fallback_e}")
raise
async def _analyze_multiple_frames(self, frames: List[Tuple[str, float]], prompt: str) -> str:
async def _analyze_multiple_frames(self, frames: list[tuple[str, float]], prompt: str) -> str:
"""使用多图片分析方法"""
logger.info(f"开始构建包含{len(frames)}帧的分析请求")
@@ -481,7 +481,7 @@ class LegacyVideoAnalyzer:
logger.info(f"视频识别完成,响应长度: {len(api_response.content or '')} ")
return api_response.content or "❌ 未获得响应内容"
async def analyze_frames_sequential(self, frames: List[Tuple[str, float]], user_question: str = None) -> str:
async def analyze_frames_sequential(self, frames: list[tuple[str, float]], user_question: str = None) -> str:
"""逐帧分析并汇总"""
logger.info(f"开始逐帧分析{len(frames)}")
@@ -567,7 +567,7 @@ class LegacyVideoAnalyzer:
return result
except Exception as e:
error_msg = f"❌ 视频分析失败: {str(e)}"
error_msg = f"❌ 视频分析失败: {e!s}"
logger.error(error_msg)
return error_msg

View File

@@ -1,8 +1,8 @@
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from rich.traceback import install
from src.common.logger import get_logger
from rich.traceback import install
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
install(extra_lines=3)
@@ -25,5 +25,5 @@ async def get_voice_text(voice_base64: str) -> str:
return f"[语音:{text}]"
except Exception as e:
logger.error(f"语音转文字失败: {str(e)}")
logger.error(f"语音转文字失败: {e!s}")
return "[语音]"