re-style: 格式化代码
This commit is contained in:
committed by
Windpicker-owo
parent
00ba07e0e1
commit
a79253c714
@@ -1,18 +1,19 @@
|
||||
import time # 导入 time 模块以获取当前时间
|
||||
import random
|
||||
import re
|
||||
import time # 导入 time 模块以获取当前时间
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from typing import List, Dict, Any, Tuple, Optional, Callable
|
||||
from rich.traceback import install
|
||||
from sqlalchemy import and_, select
|
||||
|
||||
from src.config.config import global_config
|
||||
from src.common.message_repository import find_messages, count_messages
|
||||
from src.common.database.sqlalchemy_models import ActionRecords, Images
|
||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||
from src.chat.utils.utils import translate_timestamp_to_human_readable, assign_message_ids
|
||||
from src.chat.utils.utils import assign_message_ids, translate_timestamp_to_human_readable
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from sqlalchemy import select, and_
|
||||
from src.common.database.sqlalchemy_models import ActionRecords, Images
|
||||
from src.common.logger import get_logger
|
||||
from src.common.message_repository import count_messages, find_messages
|
||||
from src.config.config import global_config
|
||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||
|
||||
logger = get_logger("chat_message_builder")
|
||||
|
||||
@@ -22,7 +23,7 @@ install(extra_lines=3)
|
||||
def replace_user_references_sync(
|
||||
content: str,
|
||||
platform: str,
|
||||
name_resolver: Optional[Callable[[str, str], str]] = None,
|
||||
name_resolver: Callable[[str, str], str] | None = None,
|
||||
replace_bot_name: bool = True,
|
||||
) -> str:
|
||||
"""
|
||||
@@ -98,7 +99,7 @@ def replace_user_references_sync(
|
||||
async def replace_user_references_async(
|
||||
content: str,
|
||||
platform: str,
|
||||
name_resolver: Optional[Callable[[str, str], Any]] = None,
|
||||
name_resolver: Callable[[str, str], Any] | None = None,
|
||||
replace_bot_name: bool = True,
|
||||
) -> str:
|
||||
"""
|
||||
@@ -171,7 +172,7 @@ async def replace_user_references_async(
|
||||
|
||||
async def get_raw_msg_by_timestamp(
|
||||
timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
获取从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
@@ -191,7 +192,7 @@ async def get_raw_msg_by_timestamp_with_chat(
|
||||
limit_mode: str = "latest",
|
||||
filter_bot=False,
|
||||
filter_command=False,
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""获取在特定聊天从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录。默认为 'latest'。
|
||||
@@ -217,7 +218,7 @@ async def get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
limit: int = 0,
|
||||
limit_mode: str = "latest",
|
||||
filter_bot=False,
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""获取在特定聊天从指定时间戳到指定时间戳的消息(包含边界),按时间升序排序,返回消息列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录。默认为 'latest'。
|
||||
@@ -236,10 +237,10 @@ async def get_raw_msg_by_timestamp_with_chat_users(
|
||||
chat_id: str,
|
||||
timestamp_start: float,
|
||||
timestamp_end: float,
|
||||
person_ids: List[str],
|
||||
person_ids: list[str],
|
||||
limit: int = 0,
|
||||
limit_mode: str = "latest",
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""获取某些特定用户在特定聊天从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录。默认为 'latest'。
|
||||
@@ -260,7 +261,7 @@ async def get_actions_by_timestamp_with_chat(
|
||||
timestamp_end: float = time.time(),
|
||||
limit: int = 0,
|
||||
limit_mode: str = "latest",
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录,按时间升序排序,返回动作记录列表"""
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -369,7 +370,7 @@ async def get_actions_by_timestamp_with_chat(
|
||||
|
||||
async def get_actions_by_timestamp_with_chat_inclusive(
|
||||
chat_id: str, timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录(包含边界),按时间升序排序,返回动作记录列表"""
|
||||
async with get_db_session() as session:
|
||||
if limit > 0:
|
||||
@@ -420,7 +421,7 @@ async def get_actions_by_timestamp_with_chat_inclusive(
|
||||
|
||||
async def get_raw_msg_by_timestamp_random(
|
||||
timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
先在范围时间戳内随机选择一条消息,取得消息的chat_id,然后根据chat_id获取该聊天在指定时间戳范围内的消息
|
||||
"""
|
||||
@@ -438,7 +439,7 @@ async def get_raw_msg_by_timestamp_random(
|
||||
|
||||
async def get_raw_msg_by_timestamp_with_users(
|
||||
timestamp_start: float, timestamp_end: float, person_ids: list, limit: int = 0, limit_mode: str = "latest"
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""获取某些特定用户在 *所有聊天* 中从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录。默认为 'latest'。
|
||||
@@ -449,7 +450,7 @@ async def get_raw_msg_by_timestamp_with_users(
|
||||
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
|
||||
|
||||
|
||||
async def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List[Dict[str, Any]]:
|
||||
async def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> list[dict[str, Any]]:
|
||||
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
"""
|
||||
@@ -460,7 +461,7 @@ async def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List
|
||||
|
||||
async def get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id: str, timestamp: float, limit: int = 0
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
"""
|
||||
@@ -471,7 +472,7 @@ async def get_raw_msg_before_timestamp_with_chat(
|
||||
|
||||
async def get_raw_msg_before_timestamp_with_users(
|
||||
timestamp: float, person_ids: list, limit: int = 0
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
"""
|
||||
@@ -480,9 +481,7 @@ async def get_raw_msg_before_timestamp_with_users(
|
||||
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
|
||||
|
||||
|
||||
async def num_new_messages_since(
|
||||
chat_id: str, timestamp_start: float = 0.0, timestamp_end: Optional[float] = None
|
||||
) -> int:
|
||||
async def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: float | None = None) -> int:
|
||||
"""
|
||||
检查特定聊天从 timestamp_start (不含) 到 timestamp_end (不含) 之间有多少新消息。
|
||||
如果 timestamp_end 为 None,则检查从 timestamp_start (不含) 到当前时间的消息。
|
||||
@@ -514,17 +513,16 @@ async def num_new_messages_since_with_users(
|
||||
|
||||
|
||||
async def _build_readable_messages_internal(
|
||||
messages: List[Dict[str, Any]],
|
||||
messages: list[dict[str, Any]],
|
||||
replace_bot_name: bool = True,
|
||||
merge_messages: bool = False,
|
||||
timestamp_mode: str = "relative",
|
||||
truncate: bool = False,
|
||||
pic_id_mapping: Optional[Dict[str, str]] = None,
|
||||
pic_id_mapping: dict[str, str] | None = None,
|
||||
pic_counter: int = 1,
|
||||
show_pic: bool = True,
|
||||
message_id_list: Optional[List[Dict[str, Any]]] = None,
|
||||
read_mark: float = 0.0,
|
||||
) -> Tuple[str, List[Tuple[float, str, str]], Dict[str, str], int]:
|
||||
message_id_list: list[dict[str, Any]] | None = None,
|
||||
) -> tuple[str, list[tuple[float, str, str]], dict[str, str], int]:
|
||||
"""
|
||||
内部辅助函数,构建可读消息字符串和原始消息详情列表。
|
||||
|
||||
@@ -543,7 +541,7 @@ async def _build_readable_messages_internal(
|
||||
if not messages:
|
||||
return "", [], pic_id_mapping or {}, pic_counter
|
||||
|
||||
message_details_raw: List[Tuple[float, str, str, bool]] = []
|
||||
message_details_raw: list[tuple[float, str, str, bool]] = []
|
||||
|
||||
# 使用传入的映射字典,如果没有则创建新的
|
||||
if pic_id_mapping is None:
|
||||
@@ -669,7 +667,7 @@ async def _build_readable_messages_internal(
|
||||
message_details_with_flags.append((timestamp, name, content, is_action))
|
||||
|
||||
# 应用截断逻辑 (如果 truncate 为 True)
|
||||
message_details: List[Tuple[float, str, str, bool]] = []
|
||||
message_details: list[tuple[float, str, str, bool]] = []
|
||||
n_messages = len(message_details_with_flags)
|
||||
if truncate and n_messages > 0:
|
||||
for i, (timestamp, name, content, is_action) in enumerate(message_details_with_flags):
|
||||
@@ -811,7 +809,7 @@ async def _build_readable_messages_internal(
|
||||
)
|
||||
|
||||
|
||||
async def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
|
||||
async def build_pic_mapping_info(pic_id_mapping: dict[str, str]) -> str:
|
||||
# sourcery skip: use-contextlib-suppress
|
||||
"""
|
||||
构建图片映射信息字符串,显示图片的具体描述内容
|
||||
@@ -849,7 +847,7 @@ async def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
|
||||
return "\n".join(mapping_lines)
|
||||
|
||||
|
||||
def build_readable_actions(actions: List[Dict[str, Any]]) -> str:
|
||||
def build_readable_actions(actions: list[dict[str, Any]]) -> str:
|
||||
"""
|
||||
将动作列表转换为可读的文本格式。
|
||||
格式: 在()分钟前,你使用了(action_name),具体内容是:(action_prompt_display)
|
||||
@@ -924,12 +922,12 @@ def build_readable_actions(actions: List[Dict[str, Any]]) -> str:
|
||||
|
||||
|
||||
async def build_readable_messages_with_list(
|
||||
messages: List[Dict[str, Any]],
|
||||
messages: list[dict[str, Any]],
|
||||
replace_bot_name: bool = True,
|
||||
merge_messages: bool = False,
|
||||
timestamp_mode: str = "relative",
|
||||
truncate: bool = False,
|
||||
) -> Tuple[str, List[Tuple[float, str, str]]]:
|
||||
) -> tuple[str, list[tuple[float, str, str]]]:
|
||||
"""
|
||||
将消息列表转换为可读的文本格式,并返回原始(时间戳, 昵称, 内容)列表。
|
||||
允许通过参数控制格式化行为。
|
||||
@@ -945,7 +943,7 @@ async def build_readable_messages_with_list(
|
||||
|
||||
|
||||
async def build_readable_messages_with_id(
|
||||
messages: List[Dict[str, Any]],
|
||||
messages: list[dict[str, Any]],
|
||||
replace_bot_name: bool = True,
|
||||
merge_messages: bool = False,
|
||||
timestamp_mode: str = "relative",
|
||||
@@ -953,7 +951,7 @@ async def build_readable_messages_with_id(
|
||||
truncate: bool = False,
|
||||
show_actions: bool = False,
|
||||
show_pic: bool = True,
|
||||
) -> Tuple[str, List[Dict[str, Any]]]:
|
||||
) -> tuple[str, list[dict[str, Any]]]:
|
||||
"""
|
||||
将消息列表转换为可读的文本格式,并返回原始(时间戳, 昵称, 内容)列表。
|
||||
允许通过参数控制格式化行为。
|
||||
@@ -982,7 +980,7 @@ async def build_readable_messages_with_id(
|
||||
|
||||
|
||||
async def build_readable_messages(
|
||||
messages: List[Dict[str, Any]],
|
||||
messages: list[dict[str, Any]],
|
||||
replace_bot_name: bool = True,
|
||||
merge_messages: bool = False,
|
||||
timestamp_mode: str = "relative",
|
||||
@@ -990,7 +988,7 @@ async def build_readable_messages(
|
||||
truncate: bool = False,
|
||||
show_actions: bool = True,
|
||||
show_pic: bool = True,
|
||||
message_id_list: Optional[List[Dict[str, Any]]] = None,
|
||||
message_id_list: list[dict[str, Any]] | None = None,
|
||||
) -> str: # sourcery skip: extract-method
|
||||
"""
|
||||
将消息列表转换为可读的文本格式。
|
||||
@@ -1150,7 +1148,7 @@ async def build_readable_messages(
|
||||
return "".join(result_parts)
|
||||
|
||||
|
||||
async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
|
||||
async def build_anonymous_messages(messages: list[dict[str, Any]]) -> str:
|
||||
"""
|
||||
构建匿名可读消息,将不同人的名称转为唯一占位符(A、B、C...),bot自己用SELF。
|
||||
处理 回复<aaa:bbb> 和 @<aaa:bbb> 字段,将bbb映射为匿名占位符。
|
||||
@@ -1263,7 +1261,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
|
||||
return formatted_string
|
||||
|
||||
|
||||
async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
|
||||
async def get_person_id_list(messages: list[dict[str, Any]]) -> list[str]:
|
||||
"""
|
||||
从消息列表中提取不重复的 person_id 列表 (忽略机器人自身)。
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
记忆系统相关的映射表和工具函数
|
||||
提供记忆类型、置信度、重要性等的中文标签映射
|
||||
|
||||
@@ -3,19 +3,20 @@
|
||||
将原有的Prompt类和SmartPrompt功能整合为一个真正的Prompt类
|
||||
"""
|
||||
|
||||
import re
|
||||
import asyncio
|
||||
import time
|
||||
import contextvars
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Any, Optional, List, Literal, Tuple
|
||||
import re
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from rich.traceback import install
|
||||
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
|
||||
install(extra_lines=3)
|
||||
@@ -50,11 +51,11 @@ class PromptParameters:
|
||||
debug_mode: bool = False
|
||||
|
||||
# 聊天历史和上下文
|
||||
chat_target_info: Optional[Dict[str, Any]] = None
|
||||
message_list_before_now_long: List[Dict[str, Any]] = field(default_factory=list)
|
||||
message_list_before_short: List[Dict[str, Any]] = field(default_factory=list)
|
||||
chat_target_info: dict[str, Any] | None = None
|
||||
message_list_before_now_long: list[dict[str, Any]] = field(default_factory=list)
|
||||
message_list_before_short: list[dict[str, Any]] = field(default_factory=list)
|
||||
chat_talking_prompt_short: str = ""
|
||||
target_user_info: Optional[Dict[str, Any]] = None
|
||||
target_user_info: dict[str, Any] | None = None
|
||||
|
||||
# 已构建的内容块
|
||||
expression_habits_block: str = ""
|
||||
@@ -77,12 +78,12 @@ class PromptParameters:
|
||||
action_descriptions: str = ""
|
||||
|
||||
# 可用动作信息
|
||||
available_actions: Optional[Dict[str, Any]] = None
|
||||
available_actions: dict[str, Any] | None = None
|
||||
|
||||
# 动态生成的聊天场景提示
|
||||
chat_scene: str = ""
|
||||
|
||||
def validate(self) -> List[str]:
|
||||
def validate(self) -> list[str]:
|
||||
"""参数验证"""
|
||||
errors = []
|
||||
if not self.chat_id:
|
||||
@@ -98,22 +99,22 @@ class PromptContext:
|
||||
"""提示词上下文管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self._context_prompts: Dict[str, Dict[str, "Prompt"]] = {}
|
||||
self._context_prompts: dict[str, dict[str, "Prompt"]] = {}
|
||||
self._current_context_var = contextvars.ContextVar("current_context", default=None)
|
||||
self._context_lock = asyncio.Lock()
|
||||
|
||||
@property
|
||||
def _current_context(self) -> Optional[str]:
|
||||
def _current_context(self) -> str | None:
|
||||
"""获取当前协程的上下文ID"""
|
||||
return self._current_context_var.get()
|
||||
|
||||
@_current_context.setter
|
||||
def _current_context(self, value: Optional[str]):
|
||||
def _current_context(self, value: str | None):
|
||||
"""设置当前协程的上下文ID"""
|
||||
self._current_context_var.set(value) # type: ignore
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_scope(self, context_id: Optional[str] = None):
|
||||
async def async_scope(self, context_id: str | None = None):
|
||||
"""创建一个异步的临时提示模板作用域"""
|
||||
if context_id is not None:
|
||||
try:
|
||||
@@ -159,7 +160,7 @@ class PromptContext:
|
||||
return self._context_prompts[current_context][name]
|
||||
return None
|
||||
|
||||
async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None:
|
||||
async def register_async(self, prompt: "Prompt", context_id: str | None = None) -> None:
|
||||
"""异步注册提示模板到指定作用域"""
|
||||
async with self._context_lock:
|
||||
if target_context := context_id or self._current_context:
|
||||
@@ -177,7 +178,7 @@ class PromptManager:
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_message_scope(self, message_id: Optional[str] = None):
|
||||
async def async_message_scope(self, message_id: str | None = None):
|
||||
"""为消息处理创建异步临时作用域"""
|
||||
async with self._context.async_scope(message_id):
|
||||
yield self
|
||||
@@ -240,8 +241,8 @@ class Prompt:
|
||||
def __init__(
|
||||
self,
|
||||
template: str,
|
||||
name: Optional[str] = None,
|
||||
parameters: Optional[PromptParameters] = None,
|
||||
name: str | None = None,
|
||||
parameters: PromptParameters | None = None,
|
||||
should_register: bool = True,
|
||||
):
|
||||
"""
|
||||
@@ -281,7 +282,7 @@ class Prompt:
|
||||
"""将临时标记还原为实际的花括号字符"""
|
||||
return template.replace(Prompt._TEMP_LEFT_BRACE, "{").replace(Prompt._TEMP_RIGHT_BRACE, "}")
|
||||
|
||||
def _parse_template_args(self, template: str) -> List[str]:
|
||||
def _parse_template_args(self, template: str) -> list[str]:
|
||||
"""解析模板参数"""
|
||||
template_args = []
|
||||
processed_template = self._process_escaped_braces(template)
|
||||
@@ -325,7 +326,7 @@ class Prompt:
|
||||
logger.error(f"构建Prompt失败: {e}")
|
||||
raise RuntimeError(f"构建Prompt失败: {e}") from e
|
||||
|
||||
async def _build_context_data(self) -> Dict[str, Any]:
|
||||
async def _build_context_data(self) -> dict[str, Any]:
|
||||
"""构建智能上下文数据"""
|
||||
# 并行执行所有构建任务
|
||||
start_time = time.time()
|
||||
@@ -405,7 +406,7 @@ class Prompt:
|
||||
default_result = self._get_default_result_for_task(task_name)
|
||||
results.append(default_result)
|
||||
except Exception as e:
|
||||
logger.error(f"构建任务{task_name}失败: {str(e)}")
|
||||
logger.error(f"构建任务{task_name}失败: {e!s}")
|
||||
default_result = self._get_default_result_for_task(task_name)
|
||||
results.append(default_result)
|
||||
|
||||
@@ -415,7 +416,7 @@ class Prompt:
|
||||
task_name = task_names[i] if i < len(task_names) else f"task_{i}"
|
||||
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"构建任务{task_name}失败: {str(result)}")
|
||||
logger.error(f"构建任务{task_name}失败: {result!s}")
|
||||
elif isinstance(result, dict):
|
||||
context_data.update(result)
|
||||
|
||||
@@ -457,7 +458,7 @@ class Prompt:
|
||||
|
||||
return context_data
|
||||
|
||||
async def _build_s4u_chat_context(self, context_data: Dict[str, Any]) -> None:
|
||||
async def _build_s4u_chat_context(self, context_data: dict[str, Any]) -> None:
|
||||
"""构建S4U模式的聊天上下文"""
|
||||
if not self.parameters.message_list_before_now_long:
|
||||
return
|
||||
@@ -472,7 +473,7 @@ class Prompt:
|
||||
context_data["read_history_prompt"] = read_history_prompt
|
||||
context_data["unread_history_prompt"] = unread_history_prompt
|
||||
|
||||
async def _build_normal_chat_context(self, context_data: Dict[str, Any]) -> None:
|
||||
async def _build_normal_chat_context(self, context_data: dict[str, Any]) -> None:
|
||||
"""构建normal模式的聊天上下文"""
|
||||
if not self.parameters.chat_talking_prompt_short:
|
||||
return
|
||||
@@ -481,8 +482,8 @@ class Prompt:
|
||||
{self.parameters.chat_talking_prompt_short}"""
|
||||
|
||||
async def _build_s4u_chat_history_prompts(
|
||||
self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str, chat_id: str
|
||||
) -> Tuple[str, str]:
|
||||
self, message_list_before_now: list[dict[str, Any]], target_user_id: str, sender: str, chat_id: str
|
||||
) -> tuple[str, str]:
|
||||
"""构建S4U风格的已读/未读历史消息prompt"""
|
||||
try:
|
||||
# 动态导入default_generator以避免循环导入
|
||||
@@ -496,7 +497,7 @@ class Prompt:
|
||||
except Exception as e:
|
||||
logger.error(f"构建S4U历史消息prompt失败: {e}")
|
||||
|
||||
async def _build_expression_habits(self) -> Dict[str, Any]:
|
||||
async def _build_expression_habits(self) -> dict[str, Any]:
|
||||
"""构建表达习惯"""
|
||||
use_expression, _, _ = global_config.expression.get_expression_config_for_chat(self.parameters.chat_id)
|
||||
if not use_expression:
|
||||
@@ -537,7 +538,7 @@ class Prompt:
|
||||
logger.error(f"构建表达习惯失败: {e}")
|
||||
return {"expression_habits_block": ""}
|
||||
|
||||
async def _build_memory_block(self) -> Dict[str, Any]:
|
||||
async def _build_memory_block(self) -> dict[str, Any]:
|
||||
"""构建记忆块"""
|
||||
if not global_config.memory.enable_memory:
|
||||
return {"memory_block": ""}
|
||||
@@ -657,7 +658,7 @@ class Prompt:
|
||||
logger.error(f"构建记忆块失败: {e}")
|
||||
return {"memory_block": ""}
|
||||
|
||||
async def _build_memory_block_fast(self) -> Dict[str, Any]:
|
||||
async def _build_memory_block_fast(self) -> dict[str, Any]:
|
||||
"""快速构建记忆块(简化版本,用于未预构建时的后备方案)"""
|
||||
if not global_config.memory.enable_memory:
|
||||
return {"memory_block": ""}
|
||||
@@ -681,7 +682,7 @@ class Prompt:
|
||||
logger.warning(f"快速构建记忆块失败: {e}")
|
||||
return {"memory_block": ""}
|
||||
|
||||
async def _build_relation_info(self) -> Dict[str, Any]:
|
||||
async def _build_relation_info(self) -> dict[str, Any]:
|
||||
"""构建关系信息"""
|
||||
try:
|
||||
relation_info = await Prompt.build_relation_info(self.parameters.chat_id, self.parameters.reply_to)
|
||||
@@ -690,7 +691,7 @@ class Prompt:
|
||||
logger.error(f"构建关系信息失败: {e}")
|
||||
return {"relation_info_block": ""}
|
||||
|
||||
async def _build_tool_info(self) -> Dict[str, Any]:
|
||||
async def _build_tool_info(self) -> dict[str, Any]:
|
||||
"""构建工具信息"""
|
||||
if not global_config.tool.enable_tool:
|
||||
return {"tool_info_block": ""}
|
||||
@@ -738,7 +739,7 @@ class Prompt:
|
||||
logger.error(f"构建工具信息失败: {e}")
|
||||
return {"tool_info_block": ""}
|
||||
|
||||
async def _build_knowledge_info(self) -> Dict[str, Any]:
|
||||
async def _build_knowledge_info(self) -> dict[str, Any]:
|
||||
"""构建知识信息"""
|
||||
if not global_config.lpmm_knowledge.enable:
|
||||
return {"knowledge_prompt": ""}
|
||||
@@ -787,7 +788,7 @@ class Prompt:
|
||||
logger.error(f"构建知识信息失败: {e}")
|
||||
return {"knowledge_prompt": ""}
|
||||
|
||||
async def _build_cross_context(self) -> Dict[str, Any]:
|
||||
async def _build_cross_context(self) -> dict[str, Any]:
|
||||
"""构建跨群上下文"""
|
||||
try:
|
||||
cross_context = await Prompt.build_cross_context(
|
||||
@@ -798,7 +799,7 @@ class Prompt:
|
||||
logger.error(f"构建跨群上下文失败: {e}")
|
||||
return {"cross_context_block": ""}
|
||||
|
||||
async def _format_with_context(self, context_data: Dict[str, Any]) -> str:
|
||||
async def _format_with_context(self, context_data: dict[str, Any]) -> str:
|
||||
"""使用上下文数据格式化模板"""
|
||||
if self.parameters.prompt_mode == "s4u":
|
||||
params = self._prepare_s4u_params(context_data)
|
||||
@@ -809,7 +810,7 @@ class Prompt:
|
||||
|
||||
return await global_prompt_manager.format_prompt(self.name, **params) if self.name else self.format(**params)
|
||||
|
||||
def _prepare_s4u_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def _prepare_s4u_params(self, context_data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""准备S4U模式的参数"""
|
||||
return {
|
||||
**context_data,
|
||||
@@ -838,7 +839,7 @@ class Prompt:
|
||||
or "你正在一个QQ群里聊天,你需要理解整个群的聊天动态和话题走向,并做出自然的回应。",
|
||||
}
|
||||
|
||||
def _prepare_normal_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def _prepare_normal_params(self, context_data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""准备Normal模式的参数"""
|
||||
return {
|
||||
**context_data,
|
||||
@@ -866,7 +867,7 @@ class Prompt:
|
||||
or "你正在一个QQ群里聊天,你需要理解整个群的聊天动态和话题走向,并做出自然的回应。",
|
||||
}
|
||||
|
||||
def _prepare_default_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def _prepare_default_params(self, context_data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""准备默认模式的参数"""
|
||||
return {
|
||||
"expression_habits_block": context_data.get("expression_habits_block", ""),
|
||||
@@ -909,7 +910,7 @@ class Prompt:
|
||||
result = self._restore_escaped_braces(processed_template)
|
||||
return result
|
||||
except (IndexError, KeyError) as e:
|
||||
raise ValueError(f"格式化模板失败: {self.template}, args={args}, kwargs={kwargs} {str(e)}") from e
|
||||
raise ValueError(f"格式化模板失败: {self.template}, args={args}, kwargs={kwargs} {e!s}") from e
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""返回格式化后的结果或原始模板"""
|
||||
@@ -926,7 +927,7 @@ class Prompt:
|
||||
# =============================================================================
|
||||
|
||||
@staticmethod
|
||||
def parse_reply_target(target_message: str) -> Tuple[str, str]:
|
||||
def parse_reply_target(target_message: str) -> tuple[str, str]:
|
||||
"""
|
||||
解析回复目标消息 - 统一实现
|
||||
|
||||
@@ -985,7 +986,7 @@ class Prompt:
|
||||
|
||||
return await relationship_fetcher.build_relation_info(person_id, points_num=5)
|
||||
|
||||
def _get_default_result_for_task(self, task_name: str) -> Dict[str, Any]:
|
||||
def _get_default_result_for_task(self, task_name: str) -> dict[str, Any]:
|
||||
"""
|
||||
为超时的任务提供默认结果
|
||||
|
||||
@@ -1012,7 +1013,7 @@ class Prompt:
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
async def build_cross_context(chat_id: str, prompt_mode: str, target_user_info: Optional[Dict[str, Any]]) -> str:
|
||||
async def build_cross_context(chat_id: str, prompt_mode: str, target_user_info: dict[str, Any] | None) -> str:
|
||||
"""
|
||||
构建跨群聊上下文 - 统一实现
|
||||
|
||||
@@ -1075,7 +1076,7 @@ class Prompt:
|
||||
|
||||
# 工厂函数
|
||||
def create_prompt(
|
||||
template: str, name: Optional[str] = None, parameters: Optional[PromptParameters] = None, **kwargs
|
||||
template: str, name: str | None = None, parameters: PromptParameters | None = None, **kwargs
|
||||
) -> Prompt:
|
||||
"""快速创建Prompt实例的工厂函数"""
|
||||
if parameters is None:
|
||||
@@ -1084,7 +1085,7 @@ def create_prompt(
|
||||
|
||||
|
||||
async def create_prompt_async(
|
||||
template: str, name: Optional[str] = None, parameters: Optional[PromptParameters] = None, **kwargs
|
||||
template: str, name: str | None = None, parameters: PromptParameters | None = None, **kwargs
|
||||
) -> Prompt:
|
||||
"""异步创建Prompt实例"""
|
||||
prompt = create_prompt(template, name, parameters, **kwargs)
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, Tuple, List
|
||||
from typing import Any
|
||||
|
||||
from src.common.database.sqlalchemy_database_api import db_get, db_query, db_save
|
||||
from src.common.database.sqlalchemy_models import LLMUsage, Messages, OnlineTime
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.sqlalchemy_models import OnlineTime, LLMUsage, Messages
|
||||
from src.common.database.sqlalchemy_database_api import db_query, db_save, db_get
|
||||
from src.manager.async_task_manager import AsyncTask
|
||||
from src.manager.local_store_manager import local_storage
|
||||
|
||||
@@ -162,7 +162,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
# 延迟300秒启动,运行间隔300秒
|
||||
super().__init__(task_name="Statistics Data Output Task", wait_before_start=0, run_interval=300)
|
||||
|
||||
self.name_mapping: Dict[str, Tuple[str, float]] = {}
|
||||
self.name_mapping: dict[str, tuple[str, float]] = {}
|
||||
"""
|
||||
联系人/群聊名称映射 {聊天ID: (联系人/群聊名称, 记录时间(timestamp))}
|
||||
注:设计记录时间的目的是方便更新名称,使联系人/群聊名称保持最新
|
||||
@@ -182,7 +182,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
deploy_time = datetime(2000, 1, 1)
|
||||
local_storage["deploy_time"] = now.timestamp()
|
||||
|
||||
self.stat_period: List[Tuple[str, timedelta, str]] = [
|
||||
self.stat_period: list[tuple[str, timedelta, str]] = [
|
||||
("all_time", now - deploy_time, "自部署以来"), # 必须保留"all_time"
|
||||
("last_7_days", timedelta(days=7), "最近7天"),
|
||||
("last_24_hours", timedelta(days=1), "最近24小时"),
|
||||
@@ -193,7 +193,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
统计时间段 [(统计名称, 统计时间段, 统计描述), ...]
|
||||
"""
|
||||
|
||||
def _statistic_console_output(self, stats: Dict[str, Any], now: datetime):
|
||||
def _statistic_console_output(self, stats: dict[str, Any], now: datetime):
|
||||
"""
|
||||
输出统计数据到控制台
|
||||
:param stats: 统计数据
|
||||
@@ -251,7 +251,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
# -- 以下为统计数据收集方法 --
|
||||
|
||||
@staticmethod
|
||||
async def _collect_model_request_for_period(collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
|
||||
async def _collect_model_request_for_period(collect_period: list[tuple[str, datetime]]) -> dict[str, Any]:
|
||||
"""
|
||||
收集指定时间段的LLM请求统计数据
|
||||
|
||||
@@ -405,8 +405,8 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
@staticmethod
|
||||
async def _collect_online_time_for_period(
|
||||
collect_period: List[Tuple[str, datetime]], now: datetime
|
||||
) -> Dict[str, Any]:
|
||||
collect_period: list[tuple[str, datetime]], now: datetime
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
收集指定时间段的在线时间统计数据
|
||||
|
||||
@@ -464,7 +464,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
break
|
||||
return stats
|
||||
|
||||
async def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
|
||||
async def _collect_message_count_for_period(self, collect_period: list[tuple[str, datetime]]) -> dict[str, Any]:
|
||||
"""
|
||||
收集指定时间段的消息统计数据
|
||||
|
||||
@@ -535,7 +535,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
break
|
||||
return stats
|
||||
|
||||
async def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]:
|
||||
async def _collect_all_statistics(self, now: datetime) -> dict[str, dict[str, Any]]:
|
||||
"""
|
||||
收集各时间段的统计数据
|
||||
:param now: 基准当前时间
|
||||
@@ -545,7 +545,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
if "last_full_statistics" in local_storage:
|
||||
# 如果存在上次完整统计数据,则使用该数据进行增量统计
|
||||
last_stat: Dict[str, Any] = local_storage["last_full_statistics"] # 上次完整统计数据 # type: ignore
|
||||
last_stat: dict[str, Any] = local_storage["last_full_statistics"] # 上次完整统计数据 # type: ignore
|
||||
|
||||
self.name_mapping = last_stat["name_mapping"] # 上次完整统计数据的名称映射
|
||||
last_all_time_stat = last_stat["stat_data"] # 上次完整统计的统计数据
|
||||
@@ -632,7 +632,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
# -- 以下为统计数据格式化方法 --
|
||||
|
||||
@staticmethod
|
||||
def _format_total_stat(stats: Dict[str, Any]) -> str:
|
||||
def _format_total_stat(stats: dict[str, Any]) -> str:
|
||||
"""
|
||||
格式化总统计数据
|
||||
"""
|
||||
@@ -648,7 +648,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
return "\n".join(output)
|
||||
|
||||
@staticmethod
|
||||
def _format_model_classified_stat(stats: Dict[str, Any]) -> str:
|
||||
def _format_model_classified_stat(stats: dict[str, Any]) -> str:
|
||||
"""
|
||||
格式化按模型分类的统计数据
|
||||
"""
|
||||
@@ -674,7 +674,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
output.append("")
|
||||
return "\n".join(output)
|
||||
|
||||
def _format_chat_stat(self, stats: Dict[str, Any]) -> str:
|
||||
def _format_chat_stat(self, stats: dict[str, Any]) -> str:
|
||||
"""
|
||||
格式化聊天统计数据
|
||||
"""
|
||||
@@ -1019,7 +1019,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
async def _generate_chart_data(self, stat: dict[str, Any]) -> dict:
|
||||
"""生成图表数据 (异步)"""
|
||||
now = datetime.now()
|
||||
chart_data: Dict[str, Any] = {}
|
||||
chart_data: dict[str, Any] = {}
|
||||
|
||||
time_ranges = [
|
||||
("6h", 6, 10),
|
||||
@@ -1035,16 +1035,16 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
async def _collect_interval_data(self, now: datetime, hours: int, interval_minutes: int) -> dict:
|
||||
start_time = now - timedelta(hours=hours)
|
||||
time_points: List[datetime] = []
|
||||
time_points: list[datetime] = []
|
||||
current_time = start_time
|
||||
while current_time <= now:
|
||||
time_points.append(current_time)
|
||||
current_time += timedelta(minutes=interval_minutes)
|
||||
|
||||
total_cost_data = [0.0] * len(time_points)
|
||||
cost_by_model: Dict[str, List[float]] = {}
|
||||
cost_by_module: Dict[str, List[float]] = {}
|
||||
message_by_chat: Dict[str, List[int]] = {}
|
||||
cost_by_model: dict[str, list[float]] = {}
|
||||
cost_by_module: dict[str, list[float]] = {}
|
||||
message_by_chat: dict[str, list[int]] = {}
|
||||
time_labels = [t.strftime("%H:%M") for t in time_points]
|
||||
interval_seconds = interval_minutes * 60
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import asyncio
|
||||
|
||||
from time import perf_counter
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Optional, Dict, Callable
|
||||
from time import perf_counter
|
||||
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
@@ -75,12 +75,12 @@ class Timer:
|
||||
3. 直接实例化:如果不调用 __enter__,打印对象时将显示当前 perf_counter 的值
|
||||
"""
|
||||
|
||||
__slots__ = ("name", "storage", "elapsed", "auto_unit", "start")
|
||||
__slots__ = ("auto_unit", "elapsed", "name", "start", "storage")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: Optional[str] = None,
|
||||
storage: Optional[Dict[str, float]] = None,
|
||||
name: str | None = None,
|
||||
storage: dict[str, float] | None = None,
|
||||
auto_unit: bool = True,
|
||||
do_type_check: bool = False,
|
||||
):
|
||||
@@ -103,7 +103,7 @@ class Timer:
|
||||
if storage is not None and not isinstance(storage, dict):
|
||||
raise TimerTypeError("storage", "Optional[dict]", type(storage))
|
||||
|
||||
def __call__(self, func: Optional[Callable] = None) -> Callable:
|
||||
def __call__(self, func: Callable | None = None) -> Callable:
|
||||
"""装饰器模式"""
|
||||
if func is None:
|
||||
return lambda f: Timer(name=self.name or f.__name__, storage=self.storage, auto_unit=self.auto_unit)(f)
|
||||
|
||||
@@ -2,15 +2,15 @@
|
||||
错别字生成器 - 基于拼音和字频的中文错别字生成工具
|
||||
"""
|
||||
|
||||
import orjson
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
import jieba
|
||||
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
import jieba
|
||||
import orjson
|
||||
from pypinyin import Style, pinyin
|
||||
|
||||
from src.common.logger import get_logger
|
||||
@@ -51,7 +51,7 @@ class ChineseTypoGenerator:
|
||||
|
||||
# 如果缓存文件存在,直接加载
|
||||
if cache_file.exists():
|
||||
with open(cache_file, "r", encoding="utf-8") as f:
|
||||
with open(cache_file, encoding="utf-8") as f:
|
||||
return orjson.loads(f.read())
|
||||
|
||||
# 使用内置的词频文件
|
||||
@@ -59,7 +59,7 @@ class ChineseTypoGenerator:
|
||||
dict_path = os.path.join(os.path.dirname(jieba.__file__), "dict.txt")
|
||||
|
||||
# 读取jieba的词典文件
|
||||
with open(dict_path, "r", encoding="utf-8") as f:
|
||||
with open(dict_path, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
word, freq = line.strip().split()[:2]
|
||||
# 对词中的每个字进行频率累加
|
||||
@@ -254,7 +254,7 @@ class ChineseTypoGenerator:
|
||||
# 获取jieba词典和词频信息
|
||||
dict_path = os.path.join(os.path.dirname(jieba.__file__), "dict.txt")
|
||||
valid_words = {} # 改用字典存储词语及其频率
|
||||
with open(dict_path, "r", encoding="utf-8") as f:
|
||||
with open(dict_path, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
parts = line.strip().split()
|
||||
if len(parts) >= 2:
|
||||
|
||||
@@ -3,20 +3,21 @@ import random
|
||||
import re
|
||||
import string
|
||||
import time
|
||||
from collections import Counter
|
||||
from typing import Any
|
||||
|
||||
import jieba
|
||||
import numpy as np
|
||||
|
||||
from collections import Counter
|
||||
from maim_message import UserInfo
|
||||
from typing import Optional, Tuple, Dict, List, Any, Coroutine
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.message_repository import find_messages, count_messages
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.common.logger import get_logger
|
||||
from src.common.message_repository import count_messages, find_messages
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.person_info.person_info import Person
|
||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||
|
||||
from .typo_generator import ChineseTypoGenerator
|
||||
|
||||
logger = get_logger("chat_utils")
|
||||
@@ -86,9 +87,9 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
|
||||
if not is_mentioned:
|
||||
# 判断是否被回复
|
||||
if re.match(
|
||||
rf"\[回复 (.+?)\({str(global_config.bot.qq_account)}\):(.+?)\],说:", message.processed_plain_text
|
||||
rf"\[回复 (.+?)\({global_config.bot.qq_account!s}\):(.+?)\],说:", message.processed_plain_text
|
||||
) or re.match(
|
||||
rf"\[回复<(.+?)(?=:{str(global_config.bot.qq_account)}>)\:{str(global_config.bot.qq_account)}>:(.+?)\],说:",
|
||||
rf"\[回复<(.+?)(?=:{global_config.bot.qq_account!s}>)\:{global_config.bot.qq_account!s}>:(.+?)\],说:",
|
||||
message.processed_plain_text,
|
||||
):
|
||||
is_mentioned = True
|
||||
@@ -110,14 +111,14 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
|
||||
return is_mentioned, reply_probability
|
||||
|
||||
|
||||
async def get_embedding(text, request_type="embedding") -> Optional[List[float]]:
|
||||
async def get_embedding(text, request_type="embedding") -> list[float] | None:
|
||||
"""获取文本的embedding向量"""
|
||||
# 每次都创建新的LLMRequest实例以避免事件循环冲突
|
||||
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type=request_type)
|
||||
try:
|
||||
embedding, _ = await llm.get_embedding(text)
|
||||
except Exception as e:
|
||||
logger.error(f"获取embedding失败: {str(e)}")
|
||||
logger.error(f"获取embedding失败: {e!s}")
|
||||
embedding = None
|
||||
return embedding
|
||||
|
||||
@@ -622,7 +623,7 @@ def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal"
|
||||
return time.strftime("%H:%M:%S", time.localtime(timestamp))
|
||||
|
||||
|
||||
async def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
|
||||
def get_chat_type_and_target_info(chat_id: str) -> tuple[bool, dict | None]:
|
||||
"""
|
||||
获取聊天类型(是否群聊)和私聊对象信息。
|
||||
|
||||
@@ -675,7 +676,6 @@ async def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Di
|
||||
if loop.is_running():
|
||||
# 如果事件循环在运行,从其他线程提交并等待结果
|
||||
try:
|
||||
|
||||
fut = asyncio.run_coroutine_threadsafe(
|
||||
person_info_manager.get_value(person_id, "person_name"), loop
|
||||
)
|
||||
@@ -711,7 +711,7 @@ async def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Di
|
||||
return is_group_chat, chat_target_info
|
||||
|
||||
|
||||
def assign_message_ids(messages: List[Any]) -> List[Dict[str, Any]]:
|
||||
def assign_message_ids(messages: list[Any]) -> list[dict[str, Any]]:
|
||||
"""
|
||||
为消息列表中的每个消息分配唯一的简短随机ID
|
||||
|
||||
|
||||
@@ -1,29 +1,27 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import io
|
||||
import os
|
||||
import time
|
||||
import hashlib
|
||||
import uuid
|
||||
import io
|
||||
import numpy as np
|
||||
from typing import Any
|
||||
|
||||
from typing import Optional, Tuple, Dict, Any
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from rich.traceback import install
|
||||
from sqlalchemy import and_, select
|
||||
|
||||
from src.common.database.sqlalchemy_models import ImageDescriptions, Images, get_db_session
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.sqlalchemy_models import Images, ImageDescriptions
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.common.database.sqlalchemy_models import get_db_session
|
||||
|
||||
from sqlalchemy import select, and_
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("chat_image")
|
||||
|
||||
|
||||
def is_image_message(message: Dict[str, Any]) -> bool:
|
||||
def is_image_message(message: dict[str, Any]) -> bool:
|
||||
"""
|
||||
判断消息是否为图片消息
|
||||
|
||||
@@ -69,7 +67,7 @@ class ImageManager:
|
||||
os.makedirs(self.IMAGE_DIR, exist_ok=True)
|
||||
|
||||
@staticmethod
|
||||
async def _get_description_from_db(image_hash: str, description_type: str) -> Optional[str]:
|
||||
async def _get_description_from_db(image_hash: str, description_type: str) -> str | None:
|
||||
"""从数据库获取图片描述
|
||||
|
||||
Args:
|
||||
@@ -93,7 +91,7 @@ class ImageManager:
|
||||
).scalar()
|
||||
return record.description if record else None
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库获取描述失败 (SQLAlchemy): {str(e)}")
|
||||
logger.error(f"从数据库获取描述失败 (SQLAlchemy): {e!s}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
@@ -136,7 +134,7 @@ class ImageManager:
|
||||
await session.commit()
|
||||
# 会在上下文管理器中自动调用
|
||||
except Exception as e:
|
||||
logger.error(f"保存描述到数据库失败 (SQLAlchemy): {str(e)}")
|
||||
logger.error(f"保存描述到数据库失败 (SQLAlchemy): {e!s}")
|
||||
|
||||
@staticmethod
|
||||
async def get_emoji_tag(image_base64: str) -> str:
|
||||
@@ -287,10 +285,10 @@ class ImageManager:
|
||||
session.add(new_img)
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"保存到Images表失败: {str(e)}")
|
||||
logger.error(f"保存到Images表失败: {e!s}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"保存表情包文件或元数据失败: {str(e)}")
|
||||
logger.error(f"保存表情包文件或元数据失败: {e!s}")
|
||||
else:
|
||||
logger.debug("偷取表情包功能已关闭,跳过保存。")
|
||||
|
||||
@@ -300,7 +298,7 @@ class ImageManager:
|
||||
return f"[表情包:{final_emotion}]"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取表情包描述失败: {str(e)}")
|
||||
logger.error(f"获取表情包描述失败: {e!s}")
|
||||
return "[表情包(处理失败)]"
|
||||
|
||||
async def get_image_description(self, image_base64: str) -> str:
|
||||
@@ -391,11 +389,11 @@ class ImageManager:
|
||||
logger.info(f"[VLM完成] 图片描述生成: {description}...")
|
||||
return f"[图片:{description}]"
|
||||
except Exception as e:
|
||||
logger.error(f"获取图片描述失败: {str(e)}")
|
||||
logger.error(f"获取图片描述失败: {e!s}")
|
||||
return "[图片(处理失败)]"
|
||||
|
||||
@staticmethod
|
||||
def transform_gif(gif_base64: str, similarity_threshold: float = 1000.0, max_frames: int = 15) -> Optional[str]:
|
||||
def transform_gif(gif_base64: str, similarity_threshold: float = 1000.0, max_frames: int = 15) -> str | None:
|
||||
# sourcery skip: use-contextlib-suppress
|
||||
"""将GIF转换为水平拼接的静态图像, 跳过相似的帧
|
||||
|
||||
@@ -512,10 +510,10 @@ class ImageManager:
|
||||
logger.error("GIF转换失败: 内存不足,可能是GIF太大或帧数太多")
|
||||
return None # 内存不够啦
|
||||
except Exception as e:
|
||||
logger.error(f"GIF转换失败: {str(e)}", exc_info=True) # 记录详细错误信息
|
||||
logger.error(f"GIF转换失败: {e!s}", exc_info=True) # 记录详细错误信息
|
||||
return None # 其他错误也返回None
|
||||
|
||||
async def process_image(self, image_base64: str) -> Tuple[str, str]:
|
||||
async def process_image(self, image_base64: str) -> tuple[str, str]:
|
||||
# sourcery skip: hoist-if-from-if
|
||||
"""处理图片并返回图片ID和描述
|
||||
|
||||
@@ -604,7 +602,7 @@ class ImageManager:
|
||||
return image_id, f"[picid:{image_id}]"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理图片失败: {str(e)}")
|
||||
logger.error(f"处理图片失败: {e!s}")
|
||||
return "", "[图片]"
|
||||
|
||||
|
||||
@@ -637,4 +635,4 @@ def image_path_to_base64(image_path: str) -> str:
|
||||
if image_data := f.read():
|
||||
return base64.b64encode(image_data).decode("utf-8")
|
||||
else:
|
||||
raise IOError(f"读取图片文件失败: {image_path}")
|
||||
raise OSError(f"读取图片文件失败: {image_path}")
|
||||
|
||||
@@ -1,35 +1,31 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""纯 inkfox 视频关键帧分析工具
|
||||
|
||||
仅依赖 `inkfox.video` 提供的 Rust 扩展能力:
|
||||
- extract_keyframes_from_video
|
||||
- get_system_info
|
||||
|
||||
功能:
|
||||
- 关键帧提取 (base64, timestamp)
|
||||
- 批量 / 逐帧 LLM 描述
|
||||
- 自动模式 (<=3 帧批量,否则逐帧)
|
||||
"""
|
||||
视频分析器模块 - Rust优化版本
|
||||
集成了Rust视频关键帧提取模块,提供高性能的视频分析功能
|
||||
支持SIMD优化、多线程处理和智能关键帧检测
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import io
|
||||
import asyncio
|
||||
import base64
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, Optional, Dict, Any
|
||||
import hashlib
|
||||
import io
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.sqlalchemy_models import get_db_session, Videos
|
||||
from sqlalchemy import select
|
||||
|
||||
from src.common.database.sqlalchemy_models import Videos, get_db_session
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
logger = get_logger("utils_video")
|
||||
|
||||
# Rust模块可用性检测
|
||||
@@ -205,7 +201,7 @@ class VideoAnalyzer:
|
||||
hash_obj.update(video_data)
|
||||
return hash_obj.hexdigest()
|
||||
|
||||
async def _check_video_exists(self, video_hash: str) -> Optional[Videos]:
|
||||
async def _check_video_exists(self, video_hash: str) -> Videos | None:
|
||||
"""检查视频是否已经分析过"""
|
||||
try:
|
||||
async with get_db_session() as session:
|
||||
@@ -222,8 +218,8 @@ class VideoAnalyzer:
|
||||
return None
|
||||
|
||||
async def _store_video_result(
|
||||
self, video_hash: str, description: str, metadata: Optional[Dict] = None
|
||||
) -> Optional[Videos]:
|
||||
self, video_hash: str, description: str, metadata: dict | None = None
|
||||
) -> Videos | None:
|
||||
"""存储视频分析结果到数据库"""
|
||||
# 检查描述是否为错误信息,如果是则不保存
|
||||
if description.startswith("❌"):
|
||||
@@ -283,7 +279,7 @@ class VideoAnalyzer:
|
||||
else:
|
||||
logger.warning(f"无效的分析模式: {mode}")
|
||||
|
||||
async def extract_frames(self, video_path: str) -> List[Tuple[str, float]]:
|
||||
async def extract_frames(self, video_path: str) -> list[tuple[str, float]]:
|
||||
"""提取视频帧 - 智能选择最佳实现"""
|
||||
# 检查是否应该使用Rust实现
|
||||
if RUST_VIDEO_AVAILABLE and self.frame_extraction_mode == "keyframe":
|
||||
@@ -305,8 +301,8 @@ class VideoAnalyzer:
|
||||
logger.info(f"🔄 抽帧模式为 {self.frame_extraction_mode},使用Python抽帧实现")
|
||||
return await self._extract_frames_python_fallback(video_path)
|
||||
|
||||
# ---- 系统信息 ----
|
||||
def _log_system(self) -> None:
|
||||
async def _extract_frames_rust_advanced(self, video_path: str) -> list[tuple[str, float]]:
|
||||
"""使用 Rust 高级接口的帧提取"""
|
||||
try:
|
||||
info = video.get_system_info() # type: ignore[attr-defined]
|
||||
logger.info(
|
||||
@@ -329,25 +325,174 @@ class VideoAnalyzer:
|
||||
threads=self.threads,
|
||||
verbose=False,
|
||||
)
|
||||
files = sorted(Path(tmp).glob("keyframe_*.jpg"))[: self.max_frames]
|
||||
total_ms = getattr(result, "total_time_ms", 0)
|
||||
frames: List[Tuple[str, float]] = []
|
||||
for i, f in enumerate(files):
|
||||
img = Image.open(f).convert("RGB")
|
||||
if max(img.size) > self.max_image_size:
|
||||
scale = self.max_image_size / max(img.size)
|
||||
img = img.resize((int(img.width * scale), int(img.height * scale)), Image.Resampling.LANCZOS)
|
||||
buf = io.BytesIO()
|
||||
img.save(buf, format="JPEG", quality=self.frame_quality)
|
||||
b64 = base64.b64encode(buf.getvalue()).decode()
|
||||
ts = (i / max(1, len(files) - 1)) * (total_ms / 1000.0) if total_ms else float(i)
|
||||
frames.append((b64, ts))
|
||||
|
||||
logger.info(f"检测到 {len(keyframe_indices)} 个关键帧")
|
||||
|
||||
# 3. 转换选定的关键帧为 base64
|
||||
frames = []
|
||||
frame_count = 0
|
||||
|
||||
for idx in keyframe_indices[: self.max_frames]:
|
||||
if idx < len(frames_data):
|
||||
try:
|
||||
frame = frames_data[idx]
|
||||
frame_data = frame.get_data()
|
||||
|
||||
# 将灰度数据转换为PIL图像
|
||||
frame_array = np.frombuffer(frame_data, dtype=np.uint8).reshape((frame.height, frame.width))
|
||||
pil_image = Image.fromarray(
|
||||
frame_array,
|
||||
mode="L", # 灰度模式
|
||||
)
|
||||
|
||||
# 转换为RGB模式以便保存为JPEG
|
||||
pil_image = pil_image.convert("RGB")
|
||||
|
||||
# 调整图像大小
|
||||
if max(pil_image.size) > self.max_image_size:
|
||||
ratio = self.max_image_size / max(pil_image.size)
|
||||
new_size = tuple(int(dim * ratio) for dim in pil_image.size)
|
||||
pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS)
|
||||
|
||||
# 转换为 base64
|
||||
buffer = io.BytesIO()
|
||||
pil_image.save(buffer, format="JPEG", quality=self.frame_quality)
|
||||
frame_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
# 估算时间戳
|
||||
estimated_timestamp = frame.frame_number * (1.0 / 30.0) # 假设30fps
|
||||
|
||||
frames.append((frame_base64, estimated_timestamp))
|
||||
frame_count += 1
|
||||
|
||||
logger.debug(
|
||||
f"处理关键帧 {frame_count}: 帧号 {frame.frame_number}, 时间 {estimated_timestamp:.2f}s"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理关键帧 {idx} 失败: {e}")
|
||||
continue
|
||||
|
||||
logger.info(f"✅ Rust 高级提取完成: {len(frames)} 关键帧")
|
||||
return frames
|
||||
|
||||
# ---- 批量分析 ----
|
||||
async def _analyze_batch(self, frames: List[Tuple[str, float]], question: Optional[str]) -> str:
|
||||
from src.llm_models.payload_content.message import MessageBuilder, RoleType
|
||||
from src.llm_models.utils_model import RequestType
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Rust 高级帧提取失败: {e}")
|
||||
# 回退到基础方法
|
||||
logger.info("回退到基础 Rust 方法")
|
||||
return await self._extract_frames_rust(video_path)
|
||||
|
||||
async def _extract_frames_rust(self, video_path: str) -> list[tuple[str, float]]:
|
||||
"""使用 Rust 实现的帧提取"""
|
||||
try:
|
||||
logger.info("🔄 使用 Rust 模块提取关键帧...")
|
||||
|
||||
# 创建临时输出目录
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# 使用便捷函数进行关键帧提取,使用配置参数
|
||||
result = rust_video.extract_keyframes_from_video(
|
||||
video_path=video_path,
|
||||
output_dir=temp_dir,
|
||||
threshold=self.rust_keyframe_threshold,
|
||||
max_frames=self.max_frames * 2, # 提取更多帧以便筛选
|
||||
max_save=self.max_frames,
|
||||
ffmpeg_path=self.ffmpeg_path,
|
||||
use_simd=self.rust_use_simd,
|
||||
threads=self.rust_threads,
|
||||
verbose=False, # 使用固定值,不需要配置
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Rust 处理完成: 总帧数 {result.total_frames}, 关键帧 {result.keyframes_extracted}, 处理速度 {result.processing_fps:.1f} FPS"
|
||||
)
|
||||
|
||||
# 转换保存的关键帧为 base64 格式
|
||||
frames = []
|
||||
temp_dir_path = Path(temp_dir)
|
||||
|
||||
# 获取所有保存的关键帧文件
|
||||
keyframe_files = sorted(temp_dir_path.glob("keyframe_*.jpg"))
|
||||
|
||||
for i, keyframe_file in enumerate(keyframe_files):
|
||||
if len(frames) >= self.max_frames:
|
||||
break
|
||||
|
||||
try:
|
||||
# 读取关键帧文件
|
||||
with open(keyframe_file, "rb") as f:
|
||||
image_data = f.read()
|
||||
|
||||
# 转换为 PIL 图像并压缩
|
||||
pil_image = Image.open(io.BytesIO(image_data))
|
||||
|
||||
# 调整图像大小
|
||||
if max(pil_image.size) > self.max_image_size:
|
||||
ratio = self.max_image_size / max(pil_image.size)
|
||||
new_size = tuple(int(dim * ratio) for dim in pil_image.size)
|
||||
pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS)
|
||||
|
||||
# 转换为 base64
|
||||
buffer = io.BytesIO()
|
||||
pil_image.save(buffer, format="JPEG", quality=self.frame_quality)
|
||||
frame_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
# 估算时间戳(基于帧索引和总时长)
|
||||
if result.total_frames > 0:
|
||||
# 假设关键帧在时间上均匀分布
|
||||
estimated_timestamp = (i * result.total_time_ms / 1000.0) / result.keyframes_extracted
|
||||
else:
|
||||
estimated_timestamp = i * 1.0 # 默认每秒一帧
|
||||
|
||||
frames.append((frame_base64, estimated_timestamp))
|
||||
|
||||
logger.debug(f"处理关键帧 {i + 1}: 估算时间 {estimated_timestamp:.2f}s")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理关键帧 {keyframe_file.name} 失败: {e}")
|
||||
continue
|
||||
|
||||
logger.info(f"✅ Rust 提取完成: {len(frames)} 关键帧")
|
||||
return frames
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Rust 帧提取失败: {e}")
|
||||
raise e
|
||||
|
||||
async def _extract_frames_python_fallback(self, video_path: str) -> list[tuple[str, float]]:
|
||||
"""Python降级抽帧实现 - 支持多种抽帧模式"""
|
||||
try:
|
||||
# 导入旧版本分析器
|
||||
from .utils_video_legacy import get_legacy_video_analyzer
|
||||
|
||||
logger.info("🔄 使用Python降级抽帧实现...")
|
||||
legacy_analyzer = get_legacy_video_analyzer()
|
||||
|
||||
# 同步配置参数
|
||||
legacy_analyzer.max_frames = self.max_frames
|
||||
legacy_analyzer.frame_quality = self.frame_quality
|
||||
legacy_analyzer.max_image_size = self.max_image_size
|
||||
legacy_analyzer.frame_extraction_mode = self.frame_extraction_mode
|
||||
legacy_analyzer.frame_interval_seconds = self.frame_interval_seconds
|
||||
legacy_analyzer.use_multiprocessing = self.use_multiprocessing
|
||||
|
||||
# 使用旧版本的抽帧功能
|
||||
frames = await legacy_analyzer.extract_frames(video_path)
|
||||
|
||||
logger.info(f"✅ Python降级抽帧完成: {len(frames)} 帧")
|
||||
return frames
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Python降级抽帧失败: {e}")
|
||||
return []
|
||||
|
||||
async def analyze_frames_batch(self, frames: list[tuple[str, float]], user_question: str = None) -> str:
|
||||
"""批量分析所有帧"""
|
||||
logger.info(f"开始批量分析{len(frames)}帧")
|
||||
|
||||
if not frames:
|
||||
return "❌ 没有可分析的帧"
|
||||
|
||||
# 构建提示词并格式化人格信息,要不然占位符的那个会爆炸
|
||||
prompt = self.batch_analysis_prompt.format(
|
||||
personality_core=self.personality_core, personality_side=self.personality_side
|
||||
)
|
||||
@@ -376,7 +521,7 @@ class VideoAnalyzer:
|
||||
logger.error(f"❌ 视频识别失败: {e}")
|
||||
raise e
|
||||
|
||||
async def _analyze_multiple_frames(self, frames: List[Tuple[str, float]], prompt: str) -> str:
|
||||
async def _analyze_multiple_frames(self, frames: list[tuple[str, float]], prompt: str) -> str:
|
||||
"""使用多图片分析方法"""
|
||||
logger.info(f"开始构建包含{len(frames)}帧的分析请求")
|
||||
|
||||
@@ -412,53 +557,75 @@ class VideoAnalyzer:
|
||||
temperature=None,
|
||||
max_tokens=None,
|
||||
)
|
||||
return resp.content or "❌ 未获得响应"
|
||||
|
||||
# ---- 逐帧分析 ----
|
||||
async def _analyze_sequential(self, frames: List[Tuple[str, float]], question: Optional[str]) -> str:
|
||||
results: List[str] = []
|
||||
for i, (b64, ts) in enumerate(frames):
|
||||
prompt = f"分析第{i+1}帧" + (f" (时间: {ts:.2f}s)" if self.enable_frame_timing else "")
|
||||
if question:
|
||||
prompt += f"\n关注: {question}"
|
||||
logger.info(f"视频识别完成,响应长度: {len(api_response.content or '')} ")
|
||||
return api_response.content or "❌ 未获得响应内容"
|
||||
|
||||
async def analyze_frames_sequential(self, frames: list[tuple[str, float]], user_question: str = None) -> str:
|
||||
"""逐帧分析并汇总"""
|
||||
logger.info(f"开始逐帧分析{len(frames)}帧")
|
||||
|
||||
frame_analyses = []
|
||||
|
||||
for i, (frame_base64, timestamp) in enumerate(frames):
|
||||
try:
|
||||
text, _ = await self.video_llm.generate_response_for_image(
|
||||
prompt=prompt, image_base64=b64, image_format="jpeg"
|
||||
)
|
||||
results.append(f"第{i+1}帧: {text}")
|
||||
except Exception as e: # pragma: no cover
|
||||
results.append(f"第{i+1}帧: 失败 {e}")
|
||||
if i < len(frames) - 1:
|
||||
await asyncio.sleep(self.frame_analysis_delay)
|
||||
summary_prompt = "基于以下逐帧结果给出完整总结:\n\n" + "\n".join(results)
|
||||
try:
|
||||
final, _ = await self.video_llm.generate_response_for_image(
|
||||
prompt=summary_prompt, image_base64=frames[-1][0], image_format="jpeg"
|
||||
)
|
||||
return final
|
||||
except Exception: # pragma: no cover
|
||||
return "\n".join(results)
|
||||
logger.info("✅ 逐帧分析和汇总完成")
|
||||
return summary
|
||||
else:
|
||||
return "❌ 没有可用于汇总的帧"
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 汇总分析失败: {e}")
|
||||
# 如果汇总失败,返回各帧分析结果
|
||||
return f"视频逐帧分析结果:\n\n{chr(10).join(frame_analyses)}"
|
||||
|
||||
# ---- 主入口 ----
|
||||
async def analyze_video(self, video_path: str, question: Optional[str] = None) -> Tuple[bool, str]:
|
||||
if not os.path.exists(video_path):
|
||||
return False, "❌ 文件不存在"
|
||||
frames = await self.extract_keyframes(video_path)
|
||||
if not frames:
|
||||
return False, "❌ 未提取到关键帧"
|
||||
mode = self.analysis_mode
|
||||
if mode == "auto":
|
||||
mode = "batch" if len(frames) <= 20 else "sequential"
|
||||
text = await (self._analyze_batch(frames, question) if mode == "batch" else self._analyze_sequential(frames, question))
|
||||
return True, text
|
||||
async def analyze_video(self, video_path: str, user_question: str = None) -> tuple[bool, str]:
|
||||
"""分析视频的主要方法
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: (是否成功, 分析结果或错误信息)
|
||||
"""
|
||||
if self.disabled:
|
||||
error_msg = "❌ 视频分析功能已禁用:没有可用的视频处理实现"
|
||||
logger.warning(error_msg)
|
||||
return (False, error_msg)
|
||||
|
||||
try:
|
||||
logger.info(f"开始分析视频: {os.path.basename(video_path)}")
|
||||
|
||||
# 提取帧
|
||||
frames = await self.extract_frames(video_path)
|
||||
if not frames:
|
||||
error_msg = "❌ 无法从视频中提取有效帧"
|
||||
return (False, error_msg)
|
||||
|
||||
# 根据模式选择分析方法
|
||||
if self.analysis_mode == "auto":
|
||||
# 智能选择:少于等于3帧用批量,否则用逐帧
|
||||
mode = "batch" if len(frames) <= 3 else "sequential"
|
||||
logger.info(f"自动选择分析模式: {mode} (基于{len(frames)}帧)")
|
||||
else:
|
||||
mode = self.analysis_mode
|
||||
|
||||
# 执行分析
|
||||
if mode == "batch":
|
||||
result = await self.analyze_frames_batch(frames, user_question)
|
||||
else: # sequential
|
||||
result = await self.analyze_frames_sequential(frames, user_question)
|
||||
|
||||
logger.info("✅ 视频分析完成")
|
||||
return (True, result)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"❌ 视频分析失败: {e!s}"
|
||||
logger.error(error_msg)
|
||||
return (False, error_msg)
|
||||
|
||||
async def analyze_video_from_bytes(
|
||||
self,
|
||||
video_bytes: bytes,
|
||||
filename: Optional[str] = None,
|
||||
prompt: Optional[str] = None,
|
||||
question: Optional[str] = None,
|
||||
) -> Dict[str, str]:
|
||||
self, video_bytes: bytes, filename: str = None, user_question: str = None, prompt: str = None
|
||||
) -> dict[str, str]:
|
||||
"""从字节数据分析视频
|
||||
|
||||
Args:
|
||||
@@ -568,34 +735,81 @@ class VideoAnalyzer:
|
||||
return {"summary": result}
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"❌ 从字节数据分析视频失败: {str(e)}"
|
||||
error_msg = f"❌ 从字节数据分析视频失败: {e!s}"
|
||||
logger.error(error_msg)
|
||||
|
||||
async def _save_cache(self, video_hash: str, summary: str, file_size: int) -> None:
|
||||
# 不保存错误信息到数据库,允许后续重试
|
||||
logger.info("💡 错误信息不保存到数据库,允许后续重试")
|
||||
|
||||
# 处理失败,通知等待者并清理资源
|
||||
try:
|
||||
if video_hash and video_event:
|
||||
async with video_lock_manager:
|
||||
if video_hash in video_events:
|
||||
video_events[video_hash].set()
|
||||
video_locks.pop(video_hash, None)
|
||||
video_events.pop(video_hash, None)
|
||||
except Exception as cleanup_e:
|
||||
logger.error(f"❌ 清理锁资源失败: {cleanup_e}")
|
||||
|
||||
return {"summary": error_msg}
|
||||
|
||||
def is_supported_video(self, file_path: str) -> bool:
|
||||
"""检查是否为支持的视频格式"""
|
||||
supported_formats = {".mp4", ".avi", ".mov", ".mkv", ".flv", ".wmv", ".m4v", ".3gp", ".webm"}
|
||||
return Path(file_path).suffix.lower() in supported_formats
|
||||
|
||||
def get_processing_capabilities(self) -> dict[str, any]:
|
||||
"""获取处理能力信息"""
|
||||
if not RUST_VIDEO_AVAILABLE:
|
||||
return {"error": "Rust视频处理模块不可用", "available": False, "reason": "rust_video模块未安装或加载失败"}
|
||||
|
||||
try:
|
||||
async with get_db_session() as session: # type: ignore
|
||||
stmt = insert(Videos).values( # type: ignore
|
||||
video_id="",
|
||||
video_hash=video_hash,
|
||||
description=summary,
|
||||
count=1,
|
||||
timestamp=time.time(),
|
||||
vlm_processed=True,
|
||||
duration=None,
|
||||
frame_count=None,
|
||||
fps=None,
|
||||
resolution=None,
|
||||
file_size=file_size,
|
||||
)
|
||||
try:
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
logger.debug(f"视频缓存写入 success hash={video_hash}")
|
||||
except sa_exc.IntegrityError: # 可能并发已写入
|
||||
await session.rollback()
|
||||
logger.debug(f"视频缓存已存在 hash={video_hash}")
|
||||
except Exception: # pragma: no cover
|
||||
logger.debug("视频缓存写入失败")
|
||||
system_info = rust_video.get_system_info()
|
||||
|
||||
# 创建一个临时的extractor来获取CPU特性
|
||||
extractor = rust_video.VideoKeyframeExtractor(threads=0, verbose=False)
|
||||
cpu_features = extractor.get_cpu_features()
|
||||
|
||||
capabilities = {
|
||||
"system": {
|
||||
"threads": system_info.get("threads", 0),
|
||||
"rust_version": system_info.get("version", "unknown"),
|
||||
},
|
||||
"cpu_features": cpu_features,
|
||||
"recommended_settings": self._get_recommended_settings(cpu_features),
|
||||
"analysis_modes": ["auto", "batch", "sequential"],
|
||||
"supported_formats": [".mp4", ".avi", ".mov", ".mkv", ".flv", ".wmv", ".m4v", ".3gp", ".webm"],
|
||||
"available": True,
|
||||
}
|
||||
|
||||
return capabilities
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取处理能力信息失败: {e}")
|
||||
return {"error": str(e), "available": False}
|
||||
|
||||
def _get_recommended_settings(self, cpu_features: dict[str, bool]) -> dict[str, any]:
|
||||
"""根据CPU特性推荐最佳设置"""
|
||||
settings = {
|
||||
"use_simd": any(cpu_features.values()),
|
||||
"block_size": 8192,
|
||||
"threads": 0, # 自动检测
|
||||
}
|
||||
|
||||
# 根据CPU特性调整设置
|
||||
if cpu_features.get("avx2", False):
|
||||
settings["block_size"] = 16384 # AVX2支持更大的块
|
||||
settings["optimization_level"] = "avx2"
|
||||
elif cpu_features.get("sse2", False):
|
||||
settings["block_size"] = 8192
|
||||
settings["optimization_level"] = "sse2"
|
||||
else:
|
||||
settings["use_simd"] = False
|
||||
settings["block_size"] = 4096
|
||||
settings["optimization_level"] = "scalar"
|
||||
|
||||
return settings
|
||||
|
||||
|
||||
# ---- 外部接口 ----
|
||||
@@ -613,7 +827,14 @@ def is_video_analysis_available() -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def get_video_analysis_status() -> Dict[str, Any]:
|
||||
def get_video_analysis_status() -> dict[str, any]:
|
||||
"""获取视频分析功能的详细状态信息
|
||||
|
||||
Returns:
|
||||
Dict[str, any]: 包含功能状态信息的字典
|
||||
"""
|
||||
# 检查OpenCV是否可用
|
||||
opencv_available = False
|
||||
try:
|
||||
info = video.get_system_info() # type: ignore[attr-defined]
|
||||
except Exception as e: # pragma: no cover
|
||||
|
||||
@@ -1,25 +1,25 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
视频分析器模块 - 旧版本兼容模块
|
||||
支持多种分析模式:批处理、逐帧、自动选择
|
||||
包含Python原生的抽帧功能,作为Rust模块的降级方案
|
||||
"""
|
||||
|
||||
import os
|
||||
import cv2
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, Optional, Any
|
||||
import io
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
logger = get_logger("utils_video_legacy")
|
||||
|
||||
@@ -30,7 +30,7 @@ def _extract_frames_worker(
|
||||
frame_quality: int,
|
||||
max_image_size: int,
|
||||
frame_extraction_mode: str,
|
||||
frame_interval_seconds: Optional[float],
|
||||
frame_interval_seconds: float | None,
|
||||
) -> list[Any] | list[tuple[str, str]]:
|
||||
"""线程池中提取视频帧的工作函数"""
|
||||
frames = []
|
||||
@@ -221,7 +221,7 @@ class LegacyVideoAnalyzer:
|
||||
f"✅ 旧版本视频分析器初始化完成,分析模式: {self.analysis_mode}, 线程池: {self.use_multiprocessing}"
|
||||
)
|
||||
|
||||
async def extract_frames(self, video_path: str) -> List[Tuple[str, float]]:
|
||||
async def extract_frames(self, video_path: str) -> list[tuple[str, float]]:
|
||||
"""提取视频帧 - 支持多进程和单线程模式"""
|
||||
# 先获取视频信息
|
||||
cap = cv2.VideoCapture(video_path)
|
||||
@@ -247,7 +247,7 @@ class LegacyVideoAnalyzer:
|
||||
else:
|
||||
return await self._extract_frames_fallback(video_path)
|
||||
|
||||
async def _extract_frames_multiprocess(self, video_path: str) -> List[Tuple[str, float]]:
|
||||
async def _extract_frames_multiprocess(self, video_path: str) -> list[tuple[str, float]]:
|
||||
"""线程池版本的帧提取"""
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
@@ -282,7 +282,7 @@ class LegacyVideoAnalyzer:
|
||||
logger.info("🔄 降级到单线程模式...")
|
||||
return await self._extract_frames_fallback(video_path)
|
||||
|
||||
async def _extract_frames_fallback(self, video_path: str) -> List[Tuple[str, float]]:
|
||||
async def _extract_frames_fallback(self, video_path: str) -> list[tuple[str, float]]:
|
||||
"""帧提取的降级方法 - 原始异步版本"""
|
||||
frames = []
|
||||
extracted_count = 0
|
||||
@@ -389,7 +389,7 @@ class LegacyVideoAnalyzer:
|
||||
logger.info(f"✅ 成功提取{len(frames)}帧")
|
||||
return frames
|
||||
|
||||
async def analyze_frames_batch(self, frames: List[Tuple[str, float]], user_question: str = None) -> str:
|
||||
async def analyze_frames_batch(self, frames: list[tuple[str, float]], user_question: str = None) -> str:
|
||||
"""批量分析所有帧"""
|
||||
logger.info(f"开始批量分析{len(frames)}帧")
|
||||
|
||||
@@ -441,7 +441,7 @@ class LegacyVideoAnalyzer:
|
||||
logger.error(f"❌ 降级分析也失败: {fallback_e}")
|
||||
raise
|
||||
|
||||
async def _analyze_multiple_frames(self, frames: List[Tuple[str, float]], prompt: str) -> str:
|
||||
async def _analyze_multiple_frames(self, frames: list[tuple[str, float]], prompt: str) -> str:
|
||||
"""使用多图片分析方法"""
|
||||
logger.info(f"开始构建包含{len(frames)}帧的分析请求")
|
||||
|
||||
@@ -481,7 +481,7 @@ class LegacyVideoAnalyzer:
|
||||
logger.info(f"视频识别完成,响应长度: {len(api_response.content or '')} ")
|
||||
return api_response.content or "❌ 未获得响应内容"
|
||||
|
||||
async def analyze_frames_sequential(self, frames: List[Tuple[str, float]], user_question: str = None) -> str:
|
||||
async def analyze_frames_sequential(self, frames: list[tuple[str, float]], user_question: str = None) -> str:
|
||||
"""逐帧分析并汇总"""
|
||||
logger.info(f"开始逐帧分析{len(frames)}帧")
|
||||
|
||||
@@ -567,7 +567,7 @@ class LegacyVideoAnalyzer:
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"❌ 视频分析失败: {str(e)}"
|
||||
error_msg = f"❌ 视频分析失败: {e!s}"
|
||||
logger.error(error_msg)
|
||||
return error_msg
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from rich.traceback import install
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from rich.traceback import install
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
@@ -25,5 +25,5 @@ async def get_voice_text(voice_base64: str) -> str:
|
||||
|
||||
return f"[语音:{text}]"
|
||||
except Exception as e:
|
||||
logger.error(f"语音转文字失败: {str(e)}")
|
||||
logger.error(f"语音转文字失败: {e!s}")
|
||||
return "[语音]"
|
||||
|
||||
Reference in New Issue
Block a user