初始化
This commit is contained in:
1145
src/chat/utils/chat_message_builder.py
Normal file
1145
src/chat/utils/chat_message_builder.py
Normal file
File diff suppressed because it is too large
Load Diff
282
src/chat/utils/prompt_builder.py
Normal file
282
src/chat/utils/prompt_builder.py
Normal file
@@ -0,0 +1,282 @@
|
||||
import re
|
||||
import asyncio
|
||||
import contextvars
|
||||
|
||||
from rich.traceback import install
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Dict, Any, Optional, List, Union
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("prompt_build")
|
||||
|
||||
|
||||
class PromptContext:
|
||||
def __init__(self):
|
||||
self._context_prompts: Dict[str, Dict[str, "Prompt"]] = {}
|
||||
# 使用contextvars创建协程上下文变量
|
||||
self._current_context_var = contextvars.ContextVar("current_context", default=None)
|
||||
self._context_lock = asyncio.Lock() # 保留锁用于其他操作
|
||||
|
||||
@property
|
||||
def _current_context(self) -> Optional[str]:
|
||||
"""获取当前协程的上下文ID"""
|
||||
return self._current_context_var.get()
|
||||
|
||||
@_current_context.setter
|
||||
def _current_context(self, value: Optional[str]):
|
||||
"""设置当前协程的上下文ID"""
|
||||
self._current_context_var.set(value)
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_scope(self, context_id: Optional[str] = None):
|
||||
# sourcery skip: hoist-statement-from-if, use-contextlib-suppress
|
||||
"""创建一个异步的临时提示模板作用域"""
|
||||
# 保存当前上下文并设置新上下文
|
||||
if context_id is not None:
|
||||
try:
|
||||
# 添加超时保护,避免长时间等待锁
|
||||
await asyncio.wait_for(self._context_lock.acquire(), timeout=5.0)
|
||||
try:
|
||||
if context_id not in self._context_prompts:
|
||||
self._context_prompts[context_id] = {}
|
||||
finally:
|
||||
self._context_lock.release()
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"获取上下文锁超时,context_id: {context_id}")
|
||||
# 超时时直接进入,不设置上下文
|
||||
context_id = None
|
||||
|
||||
# 保存当前协程的上下文值,不影响其他协程
|
||||
previous_context = self._current_context
|
||||
# 设置当前协程的新上下文
|
||||
token = self._current_context_var.set(context_id) if context_id else None
|
||||
else:
|
||||
# 如果没有提供新上下文,保持当前上下文不变
|
||||
previous_context = self._current_context
|
||||
token = None
|
||||
|
||||
try:
|
||||
yield self
|
||||
finally:
|
||||
# 恢复之前的上下文,添加异常保护
|
||||
if context_id is not None and token is not None:
|
||||
try:
|
||||
self._current_context_var.reset(token)
|
||||
except Exception as e:
|
||||
logger.warning(f"恢复上下文时出错: {e}")
|
||||
# 如果reset失败,尝试直接设置
|
||||
try:
|
||||
self._current_context = previous_context
|
||||
except Exception:
|
||||
pass # 静默忽略恢复失败
|
||||
|
||||
async def get_prompt_async(self, name: str) -> Optional["Prompt"]:
|
||||
"""异步获取当前作用域中的提示模板"""
|
||||
async with self._context_lock:
|
||||
current_context = self._current_context
|
||||
logger.debug(f"获取提示词: {name} 当前上下文: {current_context}")
|
||||
if (
|
||||
current_context
|
||||
and current_context in self._context_prompts
|
||||
and name in self._context_prompts[current_context]
|
||||
):
|
||||
return self._context_prompts[current_context][name]
|
||||
return None
|
||||
|
||||
async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None:
|
||||
"""异步注册提示模板到指定作用域"""
|
||||
async with self._context_lock:
|
||||
if target_context := context_id or self._current_context:
|
||||
self._context_prompts.setdefault(target_context, {})[prompt.name] = prompt
|
||||
|
||||
|
||||
class PromptManager:
|
||||
def __init__(self):
|
||||
self._prompts = {}
|
||||
self._counter = 0
|
||||
self._context = PromptContext()
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_message_scope(self, message_id: Optional[str] = None):
|
||||
"""为消息处理创建异步临时作用域,支持 message_id 为 None 的情况"""
|
||||
async with self._context.async_scope(message_id):
|
||||
yield self
|
||||
|
||||
async def get_prompt_async(self, name: str) -> "Prompt":
|
||||
# 首先尝试从当前上下文获取
|
||||
context_prompt = await self._context.get_prompt_async(name)
|
||||
if context_prompt is not None:
|
||||
logger.debug(f"从上下文中获取提示词: {name} {context_prompt}")
|
||||
return context_prompt
|
||||
# 如果上下文中不存在,则使用全局提示模板
|
||||
async with self._lock:
|
||||
# logger.debug(f"从全局获取提示词: {name}")
|
||||
if name not in self._prompts:
|
||||
raise KeyError(f"Prompt '{name}' not found")
|
||||
return self._prompts[name]
|
||||
|
||||
def generate_name(self, template: str) -> str:
|
||||
"""为未命名的prompt生成名称"""
|
||||
self._counter += 1
|
||||
return f"prompt_{self._counter}"
|
||||
|
||||
def register(self, prompt: "Prompt") -> None:
|
||||
"""注册一个prompt"""
|
||||
if not prompt.name:
|
||||
prompt.name = self.generate_name(prompt.template)
|
||||
self._prompts[prompt.name] = prompt
|
||||
|
||||
def add_prompt(self, name: str, fstr: str) -> "Prompt":
|
||||
prompt = Prompt(fstr, name=name)
|
||||
self._prompts[prompt.name] = prompt
|
||||
return prompt
|
||||
|
||||
async def format_prompt(self, name: str, **kwargs) -> str:
|
||||
prompt = await self.get_prompt_async(name)
|
||||
return prompt.format(**kwargs)
|
||||
|
||||
|
||||
# 全局单例
|
||||
global_prompt_manager = PromptManager()
|
||||
|
||||
|
||||
class Prompt(str):
|
||||
# 临时标记,作为类常量
|
||||
_TEMP_LEFT_BRACE = "__ESCAPED_LEFT_BRACE__"
|
||||
_TEMP_RIGHT_BRACE = "__ESCAPED_RIGHT_BRACE__"
|
||||
|
||||
@staticmethod
|
||||
def _process_escaped_braces(template) -> str:
|
||||
"""处理模板中的转义花括号,将 \{ 和 \} 替换为临时标记""" # type: ignore
|
||||
# 如果传入的是列表,将其转换为字符串
|
||||
if isinstance(template, list):
|
||||
template = "\n".join(str(item) for item in template)
|
||||
elif not isinstance(template, str):
|
||||
template = str(template)
|
||||
|
||||
return template.replace("\\{", Prompt._TEMP_LEFT_BRACE).replace("\\}", Prompt._TEMP_RIGHT_BRACE)
|
||||
|
||||
@staticmethod
|
||||
def _restore_escaped_braces(template: str) -> str:
|
||||
"""将临时标记还原为实际的花括号字符"""
|
||||
return template.replace(Prompt._TEMP_LEFT_BRACE, "{").replace(Prompt._TEMP_RIGHT_BRACE, "}")
|
||||
|
||||
def __new__(cls, fstr, name: Optional[str] = None, args: Union[List[Any], tuple[Any, ...]] = None, **kwargs):
|
||||
# 如果传入的是元组,转换为列表
|
||||
if isinstance(args, tuple):
|
||||
args = list(args)
|
||||
should_register = kwargs.pop("_should_register", True)
|
||||
|
||||
# 预处理模板中的转义花括号
|
||||
processed_fstr = cls._process_escaped_braces(fstr)
|
||||
|
||||
# 解析模板
|
||||
template_args = []
|
||||
result = re.findall(r"\{(.*?)}", processed_fstr)
|
||||
for expr in result:
|
||||
if expr and expr not in template_args:
|
||||
template_args.append(expr)
|
||||
|
||||
# 如果提供了初始参数,立即格式化
|
||||
if kwargs or args:
|
||||
formatted = cls._format_template(fstr, args=args, kwargs=kwargs)
|
||||
obj = super().__new__(cls, formatted)
|
||||
else:
|
||||
obj = super().__new__(cls, "")
|
||||
|
||||
obj.template = fstr
|
||||
obj.name = name
|
||||
obj.args = template_args
|
||||
obj._args = args or []
|
||||
obj._kwargs = kwargs
|
||||
|
||||
# 修改自动注册逻辑
|
||||
if should_register and not global_prompt_manager._context._current_context:
|
||||
global_prompt_manager.register(obj)
|
||||
return obj
|
||||
|
||||
@classmethod
|
||||
async def create_async(
|
||||
cls, fstr, name: Optional[str] = None, args: Union[List[Any], tuple[Any, ...]] = None, **kwargs
|
||||
):
|
||||
"""异步创建Prompt实例"""
|
||||
prompt = cls(fstr, name, args, **kwargs)
|
||||
if global_prompt_manager._context._current_context:
|
||||
await global_prompt_manager._context.register_async(prompt)
|
||||
return prompt
|
||||
|
||||
@classmethod
|
||||
def _format_template(cls, template, args: List[Any] = None, kwargs: Dict[str, Any] = None) -> str:
|
||||
# 预处理模板中的转义花括号
|
||||
processed_template = cls._process_escaped_braces(template)
|
||||
|
||||
template_args = []
|
||||
result = re.findall(r"\{(.*?)}", processed_template)
|
||||
for expr in result:
|
||||
if expr and expr not in template_args:
|
||||
template_args.append(expr)
|
||||
formatted_args = {}
|
||||
formatted_kwargs = {}
|
||||
|
||||
# 处理位置参数
|
||||
if args:
|
||||
# print(len(template_args), len(args), template_args, args)
|
||||
for i in range(len(args)):
|
||||
if i < len(template_args):
|
||||
arg = args[i]
|
||||
if isinstance(arg, Prompt):
|
||||
formatted_args[template_args[i]] = arg.format(**kwargs)
|
||||
else:
|
||||
formatted_args[template_args[i]] = arg
|
||||
else:
|
||||
logger.error(
|
||||
f"构建提示词模板失败,解析到的参数列表{template_args},长度为{len(template_args)},输入的参数列表为{args},提示词模板为{template}"
|
||||
)
|
||||
raise ValueError("格式化模板失败")
|
||||
|
||||
# 处理关键字参数
|
||||
if kwargs:
|
||||
for key, value in kwargs.items():
|
||||
if isinstance(value, Prompt):
|
||||
remaining_kwargs = {k: v for k, v in kwargs.items() if k != key}
|
||||
formatted_kwargs[key] = value.format(**remaining_kwargs)
|
||||
else:
|
||||
formatted_kwargs[key] = value
|
||||
|
||||
try:
|
||||
# 先用位置参数格式化
|
||||
if args:
|
||||
processed_template = processed_template.format(**formatted_args)
|
||||
# 再用关键字参数格式化
|
||||
if kwargs:
|
||||
processed_template = processed_template.format(**formatted_kwargs)
|
||||
|
||||
# 将临时标记还原为实际的花括号
|
||||
result = cls._restore_escaped_braces(processed_template)
|
||||
return result
|
||||
except (IndexError, KeyError) as e:
|
||||
raise ValueError(
|
||||
f"格式化模板失败: {template}, args={formatted_args}, kwargs={formatted_kwargs} {str(e)}"
|
||||
) from e
|
||||
|
||||
def format(self, *args, **kwargs) -> "str":
|
||||
"""支持位置参数和关键字参数的格式化,使用"""
|
||||
ret = type(self)(
|
||||
self.template,
|
||||
self.name,
|
||||
args=list(args) if args else self._args,
|
||||
_should_register=False,
|
||||
**kwargs or self._kwargs,
|
||||
)
|
||||
# print(f"prompt build result: {ret} name: {ret.name} ")
|
||||
return str(ret)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return super().__str__() if self._kwargs or self._args else self.template
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Prompt(template='{self.template}', name='{self.name}')"
|
||||
1467
src/chat/utils/statistic.py
Normal file
1467
src/chat/utils/statistic.py
Normal file
File diff suppressed because it is too large
Load Diff
158
src/chat/utils/timer_calculator.py
Normal file
158
src/chat/utils/timer_calculator.py
Normal file
@@ -0,0 +1,158 @@
|
||||
import asyncio
|
||||
|
||||
from time import perf_counter
|
||||
from functools import wraps
|
||||
from typing import Optional, Dict, Callable
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
"""
|
||||
# 更好的计时器
|
||||
|
||||
使用形式:
|
||||
- 上下文
|
||||
- 装饰器
|
||||
- 直接实例化
|
||||
|
||||
使用场景:
|
||||
- 使用Timer:在需要测量代码执行时间时(如性能测试、计时器工具),Timer类是更可靠、高精度的选择。
|
||||
- 使用time.time()的场景:当需要记录实际时间点(如日志、时间戳)时使用,但避免用它测量时间间隔。
|
||||
|
||||
使用方式:
|
||||
|
||||
【装饰器】
|
||||
time_dict = {}
|
||||
@Timer("计数", time_dict)
|
||||
def func():
|
||||
pass
|
||||
print(time_dict)
|
||||
|
||||
【上下文_1】
|
||||
def func():
|
||||
with Timer() as t:
|
||||
pass
|
||||
print(t)
|
||||
print(t.human_readable)
|
||||
|
||||
【上下文_2】
|
||||
def func():
|
||||
time_dict = {}
|
||||
with Timer("计数", time_dict):
|
||||
pass
|
||||
print(time_dict)
|
||||
|
||||
【直接实例化】
|
||||
a = Timer()
|
||||
print(a) # 直接输出当前 perf_counter 值
|
||||
|
||||
参数:
|
||||
- name:计时器的名字,默认为 None
|
||||
- storage:计时器结果存储字典,默认为 None
|
||||
- auto_unit:自动选择单位(毫秒或秒),默认为 True(自动根据时间切换毫秒或秒)
|
||||
- do_type_check:是否进行类型检查,默认为 False(不进行类型检查)
|
||||
|
||||
属性:human_readable
|
||||
|
||||
自定义错误:TimerTypeError
|
||||
"""
|
||||
|
||||
|
||||
class TimerTypeError(TypeError):
|
||||
"""自定义类型错误"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(self, param, expected_type, actual_type):
|
||||
super().__init__(f"参数 '{param}' 类型错误,期望 {expected_type},实际得到 {actual_type.__name__}")
|
||||
|
||||
|
||||
class Timer:
|
||||
"""
|
||||
Timer 支持三种模式:
|
||||
1. 装饰器模式:用于测量函数/协程运行时间
|
||||
2. 上下文管理器模式:用于 with 语句块内部计时
|
||||
3. 直接实例化:如果不调用 __enter__,打印对象时将显示当前 perf_counter 的值
|
||||
"""
|
||||
|
||||
__slots__ = ("name", "storage", "elapsed", "auto_unit", "start")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: Optional[str] = None,
|
||||
storage: Optional[Dict[str, float]] = None,
|
||||
auto_unit: bool = True,
|
||||
do_type_check: bool = False,
|
||||
):
|
||||
if do_type_check:
|
||||
self._validate_types(name, storage)
|
||||
|
||||
self.name = name
|
||||
self.storage = storage
|
||||
self.elapsed: float = None # type: ignore
|
||||
|
||||
self.auto_unit = auto_unit
|
||||
self.start: float = None # type: ignore
|
||||
|
||||
@staticmethod
|
||||
def _validate_types(name, storage):
|
||||
"""类型检查"""
|
||||
if name is not None and not isinstance(name, str):
|
||||
raise TimerTypeError("name", "Optional[str]", type(name))
|
||||
|
||||
if storage is not None and not isinstance(storage, dict):
|
||||
raise TimerTypeError("storage", "Optional[dict]", type(storage))
|
||||
|
||||
def __call__(self, func: Optional[Callable] = None) -> Callable:
|
||||
"""装饰器模式"""
|
||||
if func is None:
|
||||
return lambda f: Timer(name=self.name or f.__name__, storage=self.storage, auto_unit=self.auto_unit)(f)
|
||||
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
with self:
|
||||
return await func(*args, **kwargs)
|
||||
return None
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
with self:
|
||||
return func(*args, **kwargs)
|
||||
return None
|
||||
|
||||
wrapper = async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||||
wrapper.__timer__ = self # 保留计时器引用 # type: ignore
|
||||
return wrapper
|
||||
|
||||
def __enter__(self):
|
||||
"""上下文管理器入口"""
|
||||
self.start = perf_counter()
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.elapsed = perf_counter() - self.start
|
||||
self._record_time()
|
||||
return False
|
||||
|
||||
def _record_time(self):
|
||||
"""记录时间"""
|
||||
if self.storage is not None and self.name:
|
||||
self.storage[self.name] = self.elapsed
|
||||
|
||||
@property
|
||||
def human_readable(self) -> str:
|
||||
"""人类可读时间格式"""
|
||||
if self.elapsed is None:
|
||||
return "未计时"
|
||||
|
||||
if self.auto_unit:
|
||||
return f"{self.elapsed * 1000:.2f}毫秒" if self.elapsed < 1 else f"{self.elapsed:.2f}秒"
|
||||
return f"{self.elapsed:.4f}秒"
|
||||
|
||||
def __str__(self):
|
||||
if self.start is not None:
|
||||
if self.elapsed is None:
|
||||
current_elapsed = perf_counter() - self.start
|
||||
return f"<Timer {self.name or '匿名'} [计时中: {current_elapsed:.4f}秒]>"
|
||||
return f"<Timer {self.name or '匿名'} [{self.human_readable}]>"
|
||||
return f"{perf_counter()}"
|
||||
477
src/chat/utils/typo_generator.py
Normal file
477
src/chat/utils/typo_generator.py
Normal file
@@ -0,0 +1,477 @@
|
||||
"""
|
||||
错别字生成器 - 基于拼音和字频的中文错别字生成工具
|
||||
"""
|
||||
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
import jieba
|
||||
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from pypinyin import Style, pinyin
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("typo_gen")
|
||||
|
||||
|
||||
class ChineseTypoGenerator:
|
||||
def __init__(self, error_rate=0.3, min_freq=5, tone_error_rate=0.2, word_replace_rate=0.3, max_freq_diff=200):
|
||||
"""
|
||||
初始化错别字生成器
|
||||
|
||||
参数:
|
||||
error_rate: 单字替换概率
|
||||
min_freq: 最小字频阈值
|
||||
tone_error_rate: 声调错误概率
|
||||
word_replace_rate: 整词替换概率
|
||||
max_freq_diff: 最大允许的频率差异
|
||||
"""
|
||||
self.error_rate = error_rate
|
||||
self.min_freq = min_freq
|
||||
self.tone_error_rate = tone_error_rate
|
||||
self.word_replace_rate = word_replace_rate
|
||||
self.max_freq_diff = max_freq_diff
|
||||
|
||||
# 加载数据
|
||||
# print("正在加载汉字数据库,请稍候...")
|
||||
# logger.info("正在加载汉字数据库,请稍候...")
|
||||
|
||||
self.pinyin_dict = self._create_pinyin_dict()
|
||||
self.char_frequency = self._load_or_create_char_frequency()
|
||||
|
||||
def _load_or_create_char_frequency(self):
|
||||
"""
|
||||
加载或创建汉字频率字典
|
||||
"""
|
||||
cache_file = Path("depends-data/char_frequency.json")
|
||||
|
||||
# 如果缓存文件存在,直接加载
|
||||
if cache_file.exists():
|
||||
with open(cache_file, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
|
||||
# 使用内置的词频文件
|
||||
char_freq = defaultdict(int)
|
||||
dict_path = os.path.join(os.path.dirname(jieba.__file__), "dict.txt")
|
||||
|
||||
# 读取jieba的词典文件
|
||||
with open(dict_path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
word, freq = line.strip().split()[:2]
|
||||
# 对词中的每个字进行频率累加
|
||||
for char in word:
|
||||
if self._is_chinese_char(char):
|
||||
char_freq[char] += int(freq)
|
||||
|
||||
# 归一化频率值
|
||||
max_freq = max(char_freq.values())
|
||||
normalized_freq = {char: freq / max_freq * 1000 for char, freq in char_freq.items()}
|
||||
|
||||
# 保存到缓存文件
|
||||
with open(cache_file, "w", encoding="utf-8") as f:
|
||||
json.dump(normalized_freq, f, ensure_ascii=False, indent=2)
|
||||
|
||||
return normalized_freq
|
||||
|
||||
@staticmethod
|
||||
def _create_pinyin_dict():
|
||||
"""
|
||||
创建拼音到汉字的映射字典
|
||||
"""
|
||||
# 常用汉字范围
|
||||
chars = [chr(i) for i in range(0x4E00, 0x9FFF)]
|
||||
pinyin_dict = defaultdict(list)
|
||||
|
||||
# 为每个汉字建立拼音映射
|
||||
for char in chars:
|
||||
try:
|
||||
py = pinyin(char, style=Style.TONE3)[0][0]
|
||||
pinyin_dict[py].append(char)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return pinyin_dict
|
||||
|
||||
@staticmethod
|
||||
def _is_chinese_char(char):
|
||||
"""
|
||||
判断是否为汉字
|
||||
"""
|
||||
try:
|
||||
return "\u4e00" <= char <= "\u9fff"
|
||||
except Exception as e:
|
||||
logger.debug(str(e))
|
||||
return False
|
||||
|
||||
def _get_pinyin(self, sentence):
|
||||
"""
|
||||
将中文句子拆分成单个汉字并获取其拼音
|
||||
"""
|
||||
# 将句子拆分成单个字符
|
||||
characters = list(sentence)
|
||||
|
||||
# 获取每个字符的拼音
|
||||
result = []
|
||||
for char in characters:
|
||||
# 跳过空格和非汉字字符
|
||||
if char.isspace() or not self._is_chinese_char(char):
|
||||
continue
|
||||
# 获取拼音(数字声调)
|
||||
py = pinyin(char, style=Style.TONE3)[0][0]
|
||||
result.append((char, py))
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _get_similar_tone_pinyin(py):
|
||||
"""
|
||||
获取相似声调的拼音
|
||||
"""
|
||||
# 检查拼音是否为空或无效
|
||||
if not py or len(py) < 1:
|
||||
return py
|
||||
|
||||
# 如果最后一个字符不是数字,说明可能是轻声或其他特殊情况
|
||||
if not py[-1].isdigit():
|
||||
# 为非数字结尾的拼音添加数字声调1
|
||||
return f"{py}1"
|
||||
|
||||
base = py[:-1] # 去掉声调
|
||||
tone = int(py[-1]) # 获取声调
|
||||
|
||||
# 处理轻声(通常用5表示)或无效声调
|
||||
if tone not in [1, 2, 3, 4]:
|
||||
return base + str(random.choice([1, 2, 3, 4]))
|
||||
|
||||
# 正常处理声调
|
||||
possible_tones = [1, 2, 3, 4]
|
||||
possible_tones.remove(tone) # 移除原声调
|
||||
new_tone = random.choice(possible_tones) # 随机选择一个新声调
|
||||
return base + str(new_tone)
|
||||
|
||||
def _calculate_replacement_probability(self, orig_freq, target_freq):
|
||||
"""
|
||||
根据频率差计算替换概率
|
||||
"""
|
||||
if target_freq > orig_freq:
|
||||
return 1.0 # 如果替换字频率更高,保持原有概率
|
||||
|
||||
freq_diff = orig_freq - target_freq
|
||||
if freq_diff > self.max_freq_diff:
|
||||
return 0.0 # 频率差太大,不替换
|
||||
|
||||
# 使用指数衰减函数计算概率
|
||||
# 频率差为0时概率为1,频率差为max_freq_diff时概率接近0
|
||||
return math.exp(-3 * freq_diff / self.max_freq_diff)
|
||||
|
||||
def _get_similar_frequency_chars(self, char, py, num_candidates=5):
|
||||
"""
|
||||
获取与给定字频率相近的同音字,可能包含声调错误
|
||||
"""
|
||||
homophones = []
|
||||
|
||||
# 有一定概率使用错误声调
|
||||
if random.random() < self.tone_error_rate:
|
||||
wrong_tone_py = self._get_similar_tone_pinyin(py)
|
||||
homophones.extend(self.pinyin_dict[wrong_tone_py])
|
||||
|
||||
# 添加正确声调的同音字
|
||||
homophones.extend(self.pinyin_dict[py])
|
||||
|
||||
if not homophones:
|
||||
return None
|
||||
|
||||
# 获取原字的频率
|
||||
orig_freq = self.char_frequency.get(char, 0)
|
||||
|
||||
# 计算所有同音字与原字的频率差,并过滤掉低频字
|
||||
freq_diff = [
|
||||
(h, self.char_frequency.get(h, 0))
|
||||
for h in homophones
|
||||
if h != char and self.char_frequency.get(h, 0) >= self.min_freq
|
||||
]
|
||||
|
||||
if not freq_diff:
|
||||
return None
|
||||
|
||||
# 计算每个候选字的替换概率
|
||||
candidates_with_prob = []
|
||||
for h, freq in freq_diff:
|
||||
prob = self._calculate_replacement_probability(orig_freq, freq)
|
||||
if prob > 0: # 只保留有效概率的候选字
|
||||
candidates_with_prob.append((h, prob))
|
||||
|
||||
if not candidates_with_prob:
|
||||
return None
|
||||
|
||||
# 根据概率排序
|
||||
candidates_with_prob.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
# 返回概率最高的几个字
|
||||
return [char for char, _ in candidates_with_prob[:num_candidates]]
|
||||
|
||||
@staticmethod
|
||||
def _get_word_pinyin(word):
|
||||
"""
|
||||
获取词语的拼音列表
|
||||
"""
|
||||
return [py[0] for py in pinyin(word, style=Style.TONE3)]
|
||||
|
||||
@staticmethod
|
||||
def _segment_sentence(sentence):
|
||||
"""
|
||||
使用jieba分词,返回词语列表
|
||||
"""
|
||||
return list(jieba.cut(sentence))
|
||||
|
||||
def _get_word_homophones(self, word):
|
||||
"""
|
||||
获取整个词的同音词,只返回高频的有意义词语
|
||||
"""
|
||||
if len(word) == 1:
|
||||
return []
|
||||
|
||||
# 获取词的拼音
|
||||
word_pinyin = self._get_word_pinyin(word)
|
||||
|
||||
# 遍历所有可能的同音字组合
|
||||
candidates = []
|
||||
for py in word_pinyin:
|
||||
chars = self.pinyin_dict.get(py, [])
|
||||
if not chars:
|
||||
return []
|
||||
candidates.append(chars)
|
||||
|
||||
# 生成所有可能的组合
|
||||
import itertools
|
||||
|
||||
all_combinations = itertools.product(*candidates)
|
||||
|
||||
# 获取jieba词典和词频信息
|
||||
dict_path = os.path.join(os.path.dirname(jieba.__file__), "dict.txt")
|
||||
valid_words = {} # 改用字典存储词语及其频率
|
||||
with open(dict_path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
parts = line.strip().split()
|
||||
if len(parts) >= 2:
|
||||
word_text = parts[0]
|
||||
word_freq = float(parts[1]) # 获取词频
|
||||
valid_words[word_text] = word_freq
|
||||
|
||||
# 获取原词的词频作为参考
|
||||
original_word_freq = valid_words.get(word, 0)
|
||||
min_word_freq = original_word_freq * 0.1 # 设置最小词频为原词频的10%
|
||||
|
||||
# 过滤和计算频率
|
||||
homophones = []
|
||||
for combo in all_combinations:
|
||||
new_word = "".join(combo)
|
||||
if new_word != word and new_word in valid_words:
|
||||
new_word_freq = valid_words[new_word]
|
||||
# 只保留词频达到阈值的词
|
||||
if new_word_freq >= min_word_freq:
|
||||
# 计算词的平均字频(考虑字频和词频)
|
||||
char_avg_freq = sum(self.char_frequency.get(c, 0) for c in new_word) / len(new_word)
|
||||
# 综合评分:结合词频和字频
|
||||
combined_score = new_word_freq * 0.7 + char_avg_freq * 0.3
|
||||
if combined_score >= self.min_freq:
|
||||
homophones.append((new_word, combined_score))
|
||||
|
||||
# 按综合分数排序并限制返回数量
|
||||
sorted_homophones = sorted(homophones, key=lambda x: x[1], reverse=True)
|
||||
return [word for word, _ in sorted_homophones[:5]] # 限制返回前5个结果
|
||||
|
||||
def create_typo_sentence(self, sentence):
|
||||
"""
|
||||
创建包含同音字错误的句子,支持词语级别和字级别的替换
|
||||
|
||||
参数:
|
||||
sentence: 输入的中文句子
|
||||
|
||||
返回:
|
||||
typo_sentence: 包含错别字的句子
|
||||
correction_suggestion: 随机选择的一个纠正建议,返回正确的字/词
|
||||
"""
|
||||
result = []
|
||||
typo_info = []
|
||||
word_typos = [] # 记录词语错误对(错词,正确词)
|
||||
char_typos = [] # 记录单字错误对(错字,正确字)
|
||||
current_pos = 0
|
||||
|
||||
# 分词
|
||||
words = self._segment_sentence(sentence)
|
||||
|
||||
for word in words:
|
||||
# 如果是标点符号或空格,直接添加
|
||||
if all(not self._is_chinese_char(c) for c in word):
|
||||
result.append(word)
|
||||
current_pos += len(word)
|
||||
continue
|
||||
|
||||
# 获取词语的拼音
|
||||
word_pinyin = self._get_word_pinyin(word)
|
||||
|
||||
# 尝试整词替换
|
||||
if len(word) > 1 and random.random() < self.word_replace_rate:
|
||||
word_homophones = self._get_word_homophones(word)
|
||||
if word_homophones:
|
||||
typo_word = random.choice(word_homophones)
|
||||
# 计算词的平均频率
|
||||
orig_freq = sum(self.char_frequency.get(c, 0) for c in word) / len(word)
|
||||
typo_freq = sum(self.char_frequency.get(c, 0) for c in typo_word) / len(typo_word)
|
||||
|
||||
# 添加到结果中
|
||||
result.append(typo_word)
|
||||
typo_info.append(
|
||||
(
|
||||
word,
|
||||
typo_word,
|
||||
" ".join(word_pinyin),
|
||||
" ".join(self._get_word_pinyin(typo_word)),
|
||||
orig_freq,
|
||||
typo_freq,
|
||||
)
|
||||
)
|
||||
word_typos.append((typo_word, word)) # 记录(错词,正确词)对
|
||||
current_pos += len(typo_word)
|
||||
continue
|
||||
|
||||
# 如果不进行整词替换,则进行单字替换
|
||||
if len(word) == 1:
|
||||
char = word
|
||||
py = word_pinyin[0]
|
||||
if random.random() < self.error_rate:
|
||||
similar_chars = self._get_similar_frequency_chars(char, py)
|
||||
if similar_chars:
|
||||
typo_char = random.choice(similar_chars)
|
||||
typo_freq = self.char_frequency.get(typo_char, 0)
|
||||
orig_freq = self.char_frequency.get(char, 0)
|
||||
replace_prob = self._calculate_replacement_probability(orig_freq, typo_freq)
|
||||
if random.random() < replace_prob:
|
||||
result.append(typo_char)
|
||||
typo_py = pinyin(typo_char, style=Style.TONE3)[0][0]
|
||||
typo_info.append((char, typo_char, py, typo_py, orig_freq, typo_freq))
|
||||
char_typos.append((typo_char, char)) # 记录(错字,正确字)对
|
||||
current_pos += 1
|
||||
continue
|
||||
result.append(char)
|
||||
current_pos += 1
|
||||
else:
|
||||
# 处理多字词的单字替换
|
||||
word_result = []
|
||||
for _, (char, py) in enumerate(zip(word, word_pinyin, strict=False)):
|
||||
# 词中的字替换概率降低
|
||||
word_error_rate = self.error_rate * (0.7 ** (len(word) - 1))
|
||||
|
||||
if random.random() < word_error_rate:
|
||||
similar_chars = self._get_similar_frequency_chars(char, py)
|
||||
if similar_chars:
|
||||
typo_char = random.choice(similar_chars)
|
||||
typo_freq = self.char_frequency.get(typo_char, 0)
|
||||
orig_freq = self.char_frequency.get(char, 0)
|
||||
replace_prob = self._calculate_replacement_probability(orig_freq, typo_freq)
|
||||
if random.random() < replace_prob:
|
||||
word_result.append(typo_char)
|
||||
typo_py = pinyin(typo_char, style=Style.TONE3)[0][0]
|
||||
typo_info.append((char, typo_char, py, typo_py, orig_freq, typo_freq))
|
||||
char_typos.append((typo_char, char)) # 记录(错字,正确字)对
|
||||
continue
|
||||
word_result.append(char)
|
||||
result.append("".join(word_result))
|
||||
current_pos += len(word)
|
||||
|
||||
# 优先从词语错误中选择,如果没有则从单字错误中选择
|
||||
correction_suggestion = None
|
||||
# 50%概率返回纠正建议
|
||||
if random.random() < 0.5:
|
||||
if word_typos:
|
||||
wrong_word, correct_word = random.choice(word_typos)
|
||||
correction_suggestion = correct_word
|
||||
elif char_typos:
|
||||
wrong_char, correct_char = random.choice(char_typos)
|
||||
correction_suggestion = correct_char
|
||||
|
||||
return "".join(result), correction_suggestion
|
||||
|
||||
@staticmethod
|
||||
def format_typo_info(typo_info):
|
||||
"""
|
||||
格式化错别字信息
|
||||
|
||||
参数:
|
||||
typo_info: 错别字信息列表
|
||||
|
||||
返回:
|
||||
格式化后的错别字信息字符串
|
||||
"""
|
||||
if not typo_info:
|
||||
return "未生成错别字"
|
||||
|
||||
result = []
|
||||
for orig, typo, orig_py, typo_py, orig_freq, typo_freq in typo_info:
|
||||
# 判断是否为词语替换
|
||||
is_word = " " in orig_py
|
||||
if is_word:
|
||||
error_type = "整词替换"
|
||||
else:
|
||||
tone_error = orig_py[:-1] == typo_py[:-1] and orig_py[-1] != typo_py[-1]
|
||||
error_type = "声调错误" if tone_error else "同音字替换"
|
||||
|
||||
result.append(
|
||||
f"原文:{orig}({orig_py}) [频率:{orig_freq:.2f}] -> "
|
||||
f"替换:{typo}({typo_py}) [频率:{typo_freq:.2f}] [{error_type}]"
|
||||
)
|
||||
|
||||
return "\n".join(result)
|
||||
|
||||
def set_params(self, **kwargs):
|
||||
"""
|
||||
设置参数
|
||||
|
||||
可设置参数:
|
||||
error_rate: 单字替换概率
|
||||
min_freq: 最小字频阈值
|
||||
tone_error_rate: 声调错误概率
|
||||
word_replace_rate: 整词替换概率
|
||||
max_freq_diff: 最大允许的频率差异
|
||||
"""
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(self, key):
|
||||
setattr(self, key, value)
|
||||
print(f"参数 {key} 已设置为 {value}")
|
||||
else:
|
||||
print(f"警告: 参数 {key} 不存在")
|
||||
|
||||
|
||||
def main():
|
||||
# 创建错别字生成器实例
|
||||
typo_generator = ChineseTypoGenerator(error_rate=0.03, min_freq=7, tone_error_rate=0.02, word_replace_rate=0.3)
|
||||
|
||||
# 获取用户输入
|
||||
sentence = input("请输入中文句子:")
|
||||
|
||||
# 创建包含错别字的句子
|
||||
start_time = time.time()
|
||||
typo_sentence, correction_suggestion = typo_generator.create_typo_sentence(sentence)
|
||||
|
||||
# 打印结果
|
||||
print("\n原句:", sentence)
|
||||
print("错字版:", typo_sentence)
|
||||
|
||||
# 打印纠正建议
|
||||
if correction_suggestion:
|
||||
print("\n随机纠正建议:")
|
||||
print(f"应该改为:{correction_suggestion}")
|
||||
|
||||
# 计算并打印总耗时
|
||||
end_time = time.time()
|
||||
total_time = end_time - start_time
|
||||
print(f"\n总耗时:{total_time:.2f}秒")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
767
src/chat/utils/utils.py
Normal file
767
src/chat/utils/utils.py
Normal file
@@ -0,0 +1,767 @@
|
||||
import random
|
||||
import re
|
||||
import string
|
||||
import time
|
||||
import jieba
|
||||
import numpy as np
|
||||
|
||||
from collections import Counter
|
||||
from maim_message import UserInfo
|
||||
from typing import Optional, Tuple, Dict, List, Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.message_repository import find_messages, count_messages
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||
from .typo_generator import ChineseTypoGenerator
|
||||
|
||||
logger = get_logger("chat_utils")
|
||||
|
||||
|
||||
def is_english_letter(char: str) -> bool:
|
||||
"""检查字符是否为英文字母(忽略大小写)"""
|
||||
return "a" <= char.lower() <= "z"
|
||||
|
||||
|
||||
def db_message_to_str(message_dict: dict) -> str:
|
||||
logger.debug(f"message_dict: {message_dict}")
|
||||
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message_dict["time"]))
|
||||
try:
|
||||
name = f"[({message_dict['user_id']}){message_dict.get('user_nickname', '')}]{message_dict.get('user_cardname', '')}"
|
||||
except Exception:
|
||||
name = message_dict.get("user_nickname", "") or f"用户{message_dict['user_id']}"
|
||||
content = message_dict.get("processed_plain_text", "")
|
||||
result = f"[{time_str}] {name}: {content}\n"
|
||||
logger.debug(f"result: {result}")
|
||||
return result
|
||||
|
||||
|
||||
def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
|
||||
"""检查消息是否提到了机器人"""
|
||||
keywords = [global_config.bot.nickname]
|
||||
nicknames = global_config.bot.alias_names
|
||||
reply_probability = 0.0
|
||||
is_at = False
|
||||
is_mentioned = False
|
||||
if message.is_mentioned is not None:
|
||||
return bool(message.is_mentioned), message.is_mentioned
|
||||
if (
|
||||
message.message_info.additional_config is not None
|
||||
and message.message_info.additional_config.get("is_mentioned") is not None
|
||||
):
|
||||
try:
|
||||
reply_probability = float(message.message_info.additional_config.get("is_mentioned")) # type: ignore
|
||||
is_mentioned = True
|
||||
return is_mentioned, reply_probability
|
||||
except Exception as e:
|
||||
logger.warning(str(e))
|
||||
logger.warning(
|
||||
f"消息中包含不合理的设置 is_mentioned: {message.message_info.additional_config.get('is_mentioned')}"
|
||||
)
|
||||
|
||||
if global_config.bot.nickname in message.processed_plain_text:
|
||||
is_mentioned = True
|
||||
|
||||
for alias_name in global_config.bot.alias_names:
|
||||
if alias_name in message.processed_plain_text:
|
||||
is_mentioned = True
|
||||
|
||||
# 判断是否被@
|
||||
if re.search(rf"@<(.+?):{global_config.bot.qq_account}>", message.processed_plain_text):
|
||||
is_at = True
|
||||
is_mentioned = True
|
||||
|
||||
# print(f"message.processed_plain_text: {message.processed_plain_text}")
|
||||
# print(f"is_mentioned: {is_mentioned}")
|
||||
# print(f"is_at: {is_at}")
|
||||
|
||||
if is_at and global_config.chat.at_bot_inevitable_reply:
|
||||
reply_probability = 1.0
|
||||
logger.debug("被@,回复概率设置为100%")
|
||||
else:
|
||||
if not is_mentioned:
|
||||
# 判断是否被回复
|
||||
if re.match(
|
||||
rf"\[回复 (.+?)\({str(global_config.bot.qq_account)}\):(.+?)\],说:", message.processed_plain_text
|
||||
) or re.match(
|
||||
rf"\[回复<(.+?)(?=:{str(global_config.bot.qq_account)}>)\:{str(global_config.bot.qq_account)}>:(.+?)\],说:",
|
||||
message.processed_plain_text,
|
||||
):
|
||||
is_mentioned = True
|
||||
else:
|
||||
# 判断内容中是否被提及
|
||||
message_content = re.sub(r"@(.+?)((\d+))", "", message.processed_plain_text)
|
||||
message_content = re.sub(r"@<(.+?)(?=:(\d+))\:(\d+)>", "", message_content)
|
||||
message_content = re.sub(r"\[回复 (.+?)\(((\d+)|未知id)\):(.+?)\],说:", "", message_content)
|
||||
message_content = re.sub(r"\[回复<(.+?)(?=:(\d+))\:(\d+)>:(.+?)\],说:", "", message_content)
|
||||
for keyword in keywords:
|
||||
if keyword in message_content:
|
||||
is_mentioned = True
|
||||
for nickname in nicknames:
|
||||
if nickname in message_content:
|
||||
is_mentioned = True
|
||||
if is_mentioned and global_config.chat.mentioned_bot_inevitable_reply:
|
||||
reply_probability = 1.0
|
||||
logger.debug("被提及,回复概率设置为100%")
|
||||
return is_mentioned, reply_probability
|
||||
|
||||
|
||||
async def get_embedding(text, request_type="embedding") -> Optional[List[float]]:
|
||||
"""获取文本的embedding向量"""
|
||||
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type=request_type)
|
||||
try:
|
||||
embedding, _ = await llm.get_embedding(text)
|
||||
except Exception as e:
|
||||
logger.error(f"获取embedding失败: {str(e)}")
|
||||
embedding = None
|
||||
return embedding
|
||||
|
||||
|
||||
def get_recent_group_speaker(chat_stream_id: str, sender, limit: int = 12) -> list:
|
||||
# 获取当前群聊记录内发言的人
|
||||
filter_query = {"chat_id": chat_stream_id}
|
||||
sort_order = [("time", -1)]
|
||||
recent_messages = find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
|
||||
|
||||
if not recent_messages:
|
||||
return []
|
||||
|
||||
who_chat_in_group = []
|
||||
for msg_db_data in recent_messages:
|
||||
user_info = UserInfo.from_dict(
|
||||
{
|
||||
"platform": msg_db_data["user_platform"],
|
||||
"user_id": msg_db_data["user_id"],
|
||||
"user_nickname": msg_db_data["user_nickname"],
|
||||
"user_cardname": msg_db_data.get("user_cardname", ""),
|
||||
}
|
||||
)
|
||||
if (
|
||||
(user_info.platform, user_info.user_id) != sender
|
||||
and user_info.user_id != global_config.bot.qq_account
|
||||
and (user_info.platform, user_info.user_id, user_info.user_nickname) not in who_chat_in_group
|
||||
and len(who_chat_in_group) < 5
|
||||
): # 排除重复,排除消息发送者,排除bot,限制加载的关系数目
|
||||
who_chat_in_group.append((user_info.platform, user_info.user_id, user_info.user_nickname))
|
||||
|
||||
return who_chat_in_group
|
||||
|
||||
|
||||
def split_into_sentences_w_remove_punctuation(text: str) -> list[str]:
|
||||
"""将文本分割成句子,并根据概率合并
|
||||
1. 识别分割点(, , 。 ; 空格),但如果分割点左右都是英文字母则不分割。
|
||||
2. 将文本分割成 (内容, 分隔符) 的元组。
|
||||
3. 根据原始文本长度计算合并概率,概率性地合并相邻段落。
|
||||
注意:此函数假定颜文字已在上层被保护。
|
||||
Args:
|
||||
text: 要分割的文本字符串 (假定颜文字已被保护)
|
||||
Returns:
|
||||
List[str]: 分割和合并后的句子列表
|
||||
"""
|
||||
# 预处理:处理多余的换行符
|
||||
# 1. 将连续的换行符替换为单个换行符
|
||||
text = re.sub(r"\n\s*\n+", "\n", text)
|
||||
# 2. 处理换行符和其他分隔符的组合
|
||||
text = re.sub(r"\n\s*([,,。;\s])", r"\1", text)
|
||||
text = re.sub(r"([,,。;\s])\s*\n", r"\1", text)
|
||||
|
||||
# 处理两个汉字中间的换行符
|
||||
text = re.sub(r"([\u4e00-\u9fff])\n([\u4e00-\u9fff])", r"\1。\2", text)
|
||||
|
||||
len_text = len(text)
|
||||
if len_text < 3:
|
||||
return list(text) if random.random() < 0.01 else [text]
|
||||
|
||||
# 定义分隔符
|
||||
separators = {",", ",", " ", "。", ";"}
|
||||
segments = []
|
||||
current_segment = ""
|
||||
|
||||
# 1. 分割成 (内容, 分隔符) 元组
|
||||
i = 0
|
||||
while i < len(text):
|
||||
char = text[i]
|
||||
if char in separators:
|
||||
# 检查分割条件:如果分隔符左右都是英文字母,则不分割
|
||||
can_split = True
|
||||
if 0 < i < len(text) - 1:
|
||||
prev_char = text[i - 1]
|
||||
next_char = text[i + 1]
|
||||
# if is_english_letter(prev_char) and is_english_letter(next_char) and char == ' ': # 原计划只对空格应用此规则,现应用于所有分隔符
|
||||
if is_english_letter(prev_char) and is_english_letter(next_char):
|
||||
can_split = False
|
||||
|
||||
if can_split:
|
||||
# 只有当当前段不为空时才添加
|
||||
if current_segment:
|
||||
segments.append((current_segment, char))
|
||||
# 如果当前段为空,但分隔符是空格,则也添加一个空段(保留空格)
|
||||
elif char == " ":
|
||||
segments.append(("", char))
|
||||
current_segment = ""
|
||||
else:
|
||||
# 不分割,将分隔符加入当前段
|
||||
current_segment += char
|
||||
else:
|
||||
current_segment += char
|
||||
i += 1
|
||||
|
||||
# 添加最后一个段(没有后续分隔符)
|
||||
if current_segment:
|
||||
segments.append((current_segment, ""))
|
||||
|
||||
# 过滤掉完全空的段(内容和分隔符都为空)
|
||||
segments = [(content, sep) for content, sep in segments if content or sep]
|
||||
|
||||
# 如果分割后为空(例如,输入全是分隔符且不满足保留条件),恢复颜文字并返回
|
||||
if not segments:
|
||||
return [text] if text else [] # 如果原始文本非空,则返回原始文本(可能只包含未被分割的字符或颜文字占位符)
|
||||
|
||||
# 2. 概率合并
|
||||
if len_text < 12:
|
||||
split_strength = 0.2
|
||||
elif len_text < 32:
|
||||
split_strength = 0.6
|
||||
else:
|
||||
split_strength = 0.7
|
||||
# 合并概率与分割强度相反
|
||||
merge_probability = 1.0 - split_strength
|
||||
|
||||
merged_segments = []
|
||||
idx = 0
|
||||
while idx < len(segments):
|
||||
current_content, current_sep = segments[idx]
|
||||
|
||||
# 检查是否可以与下一段合并
|
||||
# 条件:不是最后一段,且随机数小于合并概率,且当前段有内容(避免合并空段)
|
||||
if idx + 1 < len(segments) and random.random() < merge_probability and current_content:
|
||||
next_content, next_sep = segments[idx + 1]
|
||||
# 合并: (内容1 + 分隔符1 + 内容2, 分隔符2)
|
||||
# 只有当下一段也有内容时才合并文本,否则只传递分隔符
|
||||
if next_content:
|
||||
merged_content = current_content + current_sep + next_content
|
||||
merged_segments.append((merged_content, next_sep))
|
||||
else: # 下一段内容为空,只保留当前内容和下一段的分隔符
|
||||
merged_segments.append((current_content, next_sep))
|
||||
|
||||
idx += 2 # 跳过下一段,因为它已被合并
|
||||
else:
|
||||
# 不合并,直接添加当前段
|
||||
merged_segments.append((current_content, current_sep))
|
||||
idx += 1
|
||||
|
||||
# 提取最终的句子内容
|
||||
final_sentences = [content for content, sep in merged_segments if content] # 只保留有内容的段
|
||||
|
||||
# 清理可能引入的空字符串和仅包含空白的字符串
|
||||
final_sentences = [
|
||||
s for s in final_sentences if s.strip()
|
||||
] # 过滤掉空字符串以及仅包含空白(如换行符、空格)的字符串
|
||||
|
||||
logger.debug(f"分割并合并后的句子: {final_sentences}")
|
||||
return final_sentences
|
||||
|
||||
|
||||
def random_remove_punctuation(text: str) -> str:
|
||||
"""随机处理标点符号,模拟人类打字习惯
|
||||
|
||||
Args:
|
||||
text: 要处理的文本
|
||||
|
||||
Returns:
|
||||
str: 处理后的文本
|
||||
"""
|
||||
result = ""
|
||||
text_len = len(text)
|
||||
|
||||
for i, char in enumerate(text):
|
||||
if char == "。" and i == text_len - 1: # 结尾的句号
|
||||
if random.random() > 0.1: # 90%概率删除结尾句号
|
||||
continue
|
||||
elif char == ",":
|
||||
rand = random.random()
|
||||
if rand < 0.05: # 5%概率删除逗号
|
||||
continue
|
||||
elif rand < 0.25: # 20%概率把逗号变成空格
|
||||
result += " "
|
||||
continue
|
||||
result += char
|
||||
return result
|
||||
|
||||
|
||||
def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese_typo: bool = True) -> list[str]:
|
||||
if not global_config.response_post_process.enable_response_post_process:
|
||||
return [text]
|
||||
|
||||
# 先保护颜文字
|
||||
if global_config.response_splitter.enable_kaomoji_protection:
|
||||
protected_text, kaomoji_mapping = protect_kaomoji(text)
|
||||
logger.debug(f"保护颜文字后的文本: {protected_text}")
|
||||
else:
|
||||
protected_text = text
|
||||
kaomoji_mapping = {}
|
||||
# 提取被 () 或 [] 或 ()包裹且包含中文的内容
|
||||
pattern = re.compile(r"[(\[(](?=.*[一-鿿]).*?[)\])]")
|
||||
_extracted_contents = pattern.findall(protected_text) # 在保护后的文本上查找
|
||||
# 去除 () 和 [] 及其包裹的内容
|
||||
cleaned_text = pattern.sub("", protected_text)
|
||||
|
||||
if cleaned_text == "":
|
||||
return ["呃呃"]
|
||||
|
||||
logger.debug(f"{text}去除括号处理后的文本: {cleaned_text}")
|
||||
|
||||
# 对清理后的文本进行进一步处理
|
||||
max_length = global_config.response_splitter.max_length * 2
|
||||
max_sentence_num = global_config.response_splitter.max_sentence_num
|
||||
# 如果基本上是中文,则进行长度过滤
|
||||
if get_western_ratio(cleaned_text) < 0.1 and len(cleaned_text) > max_length:
|
||||
logger.warning(f"回复过长 ({len(cleaned_text)} 字符),返回默认回复")
|
||||
return ["懒得说"]
|
||||
|
||||
typo_generator = ChineseTypoGenerator(
|
||||
error_rate=global_config.chinese_typo.error_rate,
|
||||
min_freq=global_config.chinese_typo.min_freq,
|
||||
tone_error_rate=global_config.chinese_typo.tone_error_rate,
|
||||
word_replace_rate=global_config.chinese_typo.word_replace_rate,
|
||||
)
|
||||
|
||||
if global_config.response_splitter.enable and enable_splitter:
|
||||
split_sentences = split_into_sentences_w_remove_punctuation(cleaned_text)
|
||||
else:
|
||||
split_sentences = [cleaned_text]
|
||||
|
||||
sentences = []
|
||||
for sentence in split_sentences:
|
||||
if global_config.chinese_typo.enable and enable_chinese_typo:
|
||||
typoed_text, typo_corrections = typo_generator.create_typo_sentence(sentence)
|
||||
sentences.append(typoed_text)
|
||||
if typo_corrections:
|
||||
sentences.append(typo_corrections)
|
||||
else:
|
||||
sentences.append(sentence)
|
||||
|
||||
if len(sentences) > max_sentence_num:
|
||||
logger.warning(f"分割后消息数量过多 ({len(sentences)} 条),返回默认回复")
|
||||
return [f"{global_config.bot.nickname}不知道哦"]
|
||||
|
||||
# if extracted_contents:
|
||||
# for content in extracted_contents:
|
||||
# sentences.append(content)
|
||||
|
||||
# 在所有句子处理完毕后,对包含占位符的列表进行恢复
|
||||
if global_config.response_splitter.enable_kaomoji_protection:
|
||||
sentences = recover_kaomoji(sentences, kaomoji_mapping)
|
||||
|
||||
return sentences
|
||||
|
||||
|
||||
def calculate_typing_time(
|
||||
input_string: str,
|
||||
thinking_start_time: float,
|
||||
chinese_time: float = 0.3,
|
||||
english_time: float = 0.15,
|
||||
is_emoji: bool = False,
|
||||
) -> float:
|
||||
"""
|
||||
计算输入字符串所需的时间,中文和英文字符有不同的输入时间
|
||||
input_string (str): 输入的字符串
|
||||
chinese_time (float): 中文字符的输入时间,默认为0.2秒
|
||||
english_time (float): 英文字符的输入时间,默认为0.1秒
|
||||
is_emoji (bool): 是否为emoji,默认为False
|
||||
|
||||
特殊情况:
|
||||
- 如果只有一个中文字符,将使用3倍的中文输入时间
|
||||
- 在所有输入结束后,额外加上回车时间0.3秒
|
||||
- 如果is_emoji为True,将使用固定1秒的输入时间
|
||||
"""
|
||||
# # 将0-1的唤醒度映射到-1到1
|
||||
# mood_arousal = mood_manager.current_mood.arousal
|
||||
# # 映射到0.5到2倍的速度系数
|
||||
# typing_speed_multiplier = 1.5**mood_arousal # 唤醒度为1时速度翻倍,为-1时速度减半
|
||||
# chinese_time *= 1 / typing_speed_multiplier
|
||||
# english_time *= 1 / typing_speed_multiplier
|
||||
# 计算中文字符数
|
||||
chinese_chars = sum("\u4e00" <= char <= "\u9fff" for char in input_string)
|
||||
|
||||
# 如果只有一个中文字符,使用3倍时间
|
||||
if chinese_chars == 1 and len(input_string.strip()) == 1:
|
||||
return chinese_time * 3 + 0.3 # 加上回车时间
|
||||
|
||||
# 正常计算所有字符的输入时间
|
||||
total_time = 0.0
|
||||
for char in input_string:
|
||||
total_time += chinese_time if "\u4e00" <= char <= "\u9fff" else english_time
|
||||
if is_emoji:
|
||||
total_time = 1
|
||||
|
||||
if time.time() - thinking_start_time > 10:
|
||||
total_time = 1
|
||||
|
||||
# print(f"thinking_start_time:{thinking_start_time}")
|
||||
# print(f"nowtime:{time.time()}")
|
||||
# print(f"nowtime - thinking_start_time:{time.time() - thinking_start_time}")
|
||||
# print(f"{total_time}")
|
||||
|
||||
return total_time # 加上回车时间
|
||||
|
||||
|
||||
def cosine_similarity(v1, v2):
|
||||
"""计算余弦相似度"""
|
||||
dot_product = np.dot(v1, v2)
|
||||
norm1 = np.linalg.norm(v1)
|
||||
norm2 = np.linalg.norm(v2)
|
||||
return 0 if norm1 == 0 or norm2 == 0 else dot_product / (norm1 * norm2)
|
||||
|
||||
|
||||
def text_to_vector(text):
|
||||
"""将文本转换为词频向量"""
|
||||
# 分词
|
||||
words = jieba.lcut(text)
|
||||
return Counter(words)
|
||||
|
||||
|
||||
def find_similar_topics_simple(text: str, topics: list, top_k: int = 5) -> list:
|
||||
"""使用简单的余弦相似度计算文本相似度"""
|
||||
# 将输入文本转换为词频向量
|
||||
text_vector = text_to_vector(text)
|
||||
|
||||
# 计算每个主题的相似度
|
||||
similarities = []
|
||||
for topic in topics:
|
||||
topic_vector = text_to_vector(topic)
|
||||
# 获取所有唯一词
|
||||
all_words = set(text_vector.keys()) | set(topic_vector.keys())
|
||||
# 构建向量
|
||||
v1 = [text_vector.get(word, 0) for word in all_words]
|
||||
v2 = [topic_vector.get(word, 0) for word in all_words]
|
||||
# 计算相似度
|
||||
similarity = cosine_similarity(v1, v2)
|
||||
similarities.append((topic, similarity))
|
||||
|
||||
# 按相似度降序排序并返回前k个
|
||||
return sorted(similarities, key=lambda x: x[1], reverse=True)[:top_k]
|
||||
|
||||
|
||||
def truncate_message(message: str, max_length=20) -> str:
|
||||
"""截断消息,使其不超过指定长度"""
|
||||
return f"{message[:max_length]}..." if len(message) > max_length else message
|
||||
|
||||
|
||||
def protect_kaomoji(sentence):
|
||||
""" "
|
||||
识别并保护句子中的颜文字(含括号与无括号),将其替换为占位符,
|
||||
并返回替换后的句子和占位符到颜文字的映射表。
|
||||
Args:
|
||||
sentence (str): 输入的原始句子
|
||||
Returns:
|
||||
tuple: (处理后的句子, {占位符: 颜文字})
|
||||
"""
|
||||
kaomoji_pattern = re.compile(
|
||||
r"("
|
||||
r"[(\[(【]" # 左括号
|
||||
r"[^()\[\]()【】]*?" # 非括号字符(惰性匹配)
|
||||
r"[^一-龥a-zA-Z0-9\s]" # 非中文、非英文、非数字、非空格字符(必须包含至少一个)
|
||||
r"[^()\[\]()【】]*?" # 非括号字符(惰性匹配)
|
||||
r"[)\])】" # 右括号
|
||||
r"]"
|
||||
r")"
|
||||
r"|"
|
||||
r"([▼▽・ᴥω・﹏^><≧≦ ̄`´∀ヮДд︿﹀へ。゚╥╯╰︶︹•⁄]{2,15})"
|
||||
)
|
||||
|
||||
kaomoji_matches = kaomoji_pattern.findall(sentence)
|
||||
placeholder_to_kaomoji = {}
|
||||
|
||||
for idx, match in enumerate(kaomoji_matches):
|
||||
kaomoji = match[0] or match[1]
|
||||
placeholder = f"__KAOMOJI_{idx}__"
|
||||
sentence = sentence.replace(kaomoji, placeholder, 1)
|
||||
placeholder_to_kaomoji[placeholder] = kaomoji
|
||||
|
||||
return sentence, placeholder_to_kaomoji
|
||||
|
||||
|
||||
def recover_kaomoji(sentences, placeholder_to_kaomoji):
|
||||
"""
|
||||
根据映射表恢复句子中的颜文字。
|
||||
Args:
|
||||
sentences (list): 含有占位符的句子列表
|
||||
placeholder_to_kaomoji (dict): 占位符到颜文字的映射表
|
||||
Returns:
|
||||
list: 恢复颜文字后的句子列表
|
||||
"""
|
||||
recovered_sentences = []
|
||||
for sentence in sentences:
|
||||
for placeholder, kaomoji in placeholder_to_kaomoji.items():
|
||||
sentence = sentence.replace(placeholder, kaomoji)
|
||||
recovered_sentences.append(sentence)
|
||||
return recovered_sentences
|
||||
|
||||
|
||||
def get_western_ratio(paragraph):
|
||||
"""计算段落中字母数字字符的西文比例
|
||||
原理:检查段落中字母数字字符的西文比例
|
||||
通过is_english_letter函数判断每个字符是否为西文
|
||||
只检查字母数字字符,忽略标点符号和空格等非字母数字字符
|
||||
|
||||
Args:
|
||||
paragraph: 要检查的文本段落
|
||||
|
||||
Returns:
|
||||
float: 西文字符比例(0.0-1.0),如果没有字母数字字符则返回0.0
|
||||
"""
|
||||
alnum_chars = [char for char in paragraph if char.isalnum()]
|
||||
if not alnum_chars:
|
||||
return 0.0
|
||||
|
||||
western_count = sum(bool(is_english_letter(char)) for char in alnum_chars)
|
||||
return western_count / len(alnum_chars)
|
||||
|
||||
|
||||
def count_messages_between(start_time: float, end_time: float, stream_id: str) -> tuple[int, int]:
|
||||
"""计算两个时间点之间的消息数量和文本总长度
|
||||
|
||||
Args:
|
||||
start_time (float): 起始时间戳 (不包含)
|
||||
end_time (float): 结束时间戳 (包含)
|
||||
stream_id (str): 聊天流ID
|
||||
|
||||
Returns:
|
||||
tuple[int, int]: (消息数量, 文本总长度)
|
||||
"""
|
||||
count = 0
|
||||
total_length = 0
|
||||
|
||||
# 参数校验 (可选但推荐)
|
||||
if start_time >= end_time:
|
||||
# logger.debug(f"开始时间 {start_time} 大于或等于结束时间 {end_time},返回 0, 0")
|
||||
return 0, 0
|
||||
if not stream_id:
|
||||
logger.error("stream_id 不能为空")
|
||||
return 0, 0
|
||||
|
||||
# 使用message_repository中的count_messages和find_messages函数
|
||||
|
||||
# 构建查询条件
|
||||
filter_query = {"chat_id": stream_id, "time": {"$gt": start_time, "$lte": end_time}}
|
||||
|
||||
try:
|
||||
# 先获取消息数量
|
||||
count = count_messages(filter_query)
|
||||
|
||||
# 获取消息内容计算总长度
|
||||
messages = find_messages(message_filter=filter_query)
|
||||
total_length = sum(len(msg.get("processed_plain_text", "")) for msg in messages)
|
||||
|
||||
return count, total_length
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"计算消息数量时发生意外错误: {e}")
|
||||
return 0, 0
|
||||
|
||||
|
||||
def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal") -> str:
|
||||
# sourcery skip: merge-comparisons, merge-duplicate-blocks, switch
|
||||
"""将时间戳转换为人类可读的时间格式
|
||||
|
||||
Args:
|
||||
timestamp: 时间戳
|
||||
mode: 转换模式,"normal"为标准格式,"relative"为相对时间格式
|
||||
|
||||
Returns:
|
||||
str: 格式化后的时间字符串
|
||||
"""
|
||||
if mode == "normal":
|
||||
return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(timestamp))
|
||||
elif mode == "normal_no_YMD":
|
||||
return time.strftime("%H:%M:%S", time.localtime(timestamp))
|
||||
elif mode == "relative":
|
||||
now = time.time()
|
||||
diff = now - timestamp
|
||||
|
||||
if diff < 20:
|
||||
return "刚刚"
|
||||
elif diff < 60:
|
||||
return f"{int(diff)}秒前"
|
||||
elif diff < 3600:
|
||||
return f"{int(diff / 60)}分钟前"
|
||||
elif diff < 86400:
|
||||
return f"{int(diff / 3600)}小时前"
|
||||
elif diff < 86400 * 2:
|
||||
return f"{int(diff / 86400)}天前"
|
||||
else:
|
||||
return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(timestamp)) + ":"
|
||||
else: # mode = "lite" or unknown
|
||||
# 只返回时分秒格式
|
||||
return time.strftime("%H:%M:%S", time.localtime(timestamp))
|
||||
|
||||
|
||||
def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
|
||||
"""
|
||||
获取聊天类型(是否群聊)和私聊对象信息。
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
|
||||
Returns:
|
||||
Tuple[bool, Optional[Dict]]:
|
||||
- bool: 是否为群聊 (True 是群聊, False 是私聊或未知)
|
||||
- Optional[Dict]: 如果是私聊,包含对方信息的字典;否则为 None。
|
||||
字典包含: platform, user_id, user_nickname, person_id, person_name
|
||||
"""
|
||||
is_group_chat = False # Default to private/unknown
|
||||
chat_target_info = None
|
||||
|
||||
try:
|
||||
if chat_stream := get_chat_manager().get_stream(chat_id):
|
||||
if chat_stream.group_info:
|
||||
is_group_chat = True
|
||||
chat_target_info = None # Explicitly None for group chat
|
||||
elif chat_stream.user_info: # It's a private chat
|
||||
is_group_chat = False
|
||||
user_info = chat_stream.user_info
|
||||
platform: str = chat_stream.platform
|
||||
user_id: str = user_info.user_id # type: ignore
|
||||
|
||||
# Initialize target_info with basic info
|
||||
target_info = {
|
||||
"platform": platform,
|
||||
"user_id": user_id,
|
||||
"user_nickname": user_info.user_nickname,
|
||||
"person_id": None,
|
||||
"person_name": None,
|
||||
}
|
||||
|
||||
# Try to fetch person info
|
||||
try:
|
||||
# Assume get_person_id is sync (as per original code), keep using to_thread
|
||||
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
||||
person_name = None
|
||||
if person_id:
|
||||
# get_value is async, so await it directly
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_name = person_info_manager.get_value_sync(person_id, "person_name")
|
||||
|
||||
target_info["person_id"] = person_id
|
||||
target_info["person_name"] = person_name
|
||||
except Exception as person_e:
|
||||
logger.warning(
|
||||
f"获取 person_id 或 person_name 时出错 for {platform}:{user_id} in utils: {person_e}"
|
||||
)
|
||||
|
||||
chat_target_info = target_info
|
||||
else:
|
||||
logger.warning(f"无法获取 chat_stream for {chat_id} in utils")
|
||||
except Exception as e:
|
||||
logger.error(f"获取聊天类型和目标信息时出错 for {chat_id}: {e}", exc_info=True)
|
||||
# Keep defaults on error
|
||||
|
||||
return is_group_chat, chat_target_info
|
||||
|
||||
|
||||
def assign_message_ids(messages: List[Any]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
为消息列表中的每个消息分配唯一的简短随机ID
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
|
||||
Returns:
|
||||
包含 {'id': str, 'message': any} 格式的字典列表
|
||||
"""
|
||||
result = []
|
||||
used_ids = set()
|
||||
len_i = len(messages)
|
||||
if len_i > 100:
|
||||
a = 10
|
||||
b = 99
|
||||
else:
|
||||
a = 1
|
||||
b = 9
|
||||
|
||||
for i, message in enumerate(messages):
|
||||
# 生成唯一的简短ID
|
||||
while True:
|
||||
# 使用索引+随机数生成简短ID
|
||||
random_suffix = random.randint(a, b)
|
||||
message_id = f"m{i+1}{random_suffix}"
|
||||
|
||||
if message_id not in used_ids:
|
||||
used_ids.add(message_id)
|
||||
break
|
||||
|
||||
result.append({
|
||||
'id': message_id,
|
||||
'message': message
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def assign_message_ids_flexible(
|
||||
messages: list,
|
||||
prefix: str = "msg",
|
||||
id_length: int = 6,
|
||||
use_timestamp: bool = False
|
||||
) -> list:
|
||||
"""
|
||||
为消息列表中的每个消息分配唯一的简短随机ID(增强版)
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
prefix: ID前缀,默认为"msg"
|
||||
id_length: ID的总长度(不包括前缀),默认为6
|
||||
use_timestamp: 是否在ID中包含时间戳,默认为False
|
||||
|
||||
Returns:
|
||||
包含 {'id': str, 'message': any} 格式的字典列表
|
||||
"""
|
||||
result = []
|
||||
used_ids = set()
|
||||
|
||||
for i, message in enumerate(messages):
|
||||
# 生成唯一的ID
|
||||
while True:
|
||||
if use_timestamp:
|
||||
# 使用时间戳的后几位 + 随机字符
|
||||
timestamp_suffix = str(int(time.time() * 1000))[-3:]
|
||||
remaining_length = id_length - 3
|
||||
random_chars = ''.join(random.choices(string.ascii_lowercase + string.digits, k=remaining_length))
|
||||
message_id = f"{prefix}{timestamp_suffix}{random_chars}"
|
||||
else:
|
||||
# 使用索引 + 随机字符
|
||||
index_str = str(i + 1)
|
||||
remaining_length = max(1, id_length - len(index_str))
|
||||
random_chars = ''.join(random.choices(string.ascii_lowercase + string.digits, k=remaining_length))
|
||||
message_id = f"{prefix}{index_str}{random_chars}"
|
||||
|
||||
if message_id not in used_ids:
|
||||
used_ids.add(message_id)
|
||||
break
|
||||
|
||||
result.append({
|
||||
'id': message_id,
|
||||
'message': message
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# 使用示例:
|
||||
# messages = ["Hello", "World", "Test message"]
|
||||
#
|
||||
# # 基础版本
|
||||
# result1 = assign_message_ids(messages)
|
||||
# # 结果: [{'id': 'm1123', 'message': 'Hello'}, {'id': 'm2456', 'message': 'World'}, {'id': 'm3789', 'message': 'Test message'}]
|
||||
#
|
||||
# # 增强版本 - 自定义前缀和长度
|
||||
# result2 = assign_message_ids_flexible(messages, prefix="chat", id_length=8)
|
||||
# # 结果: [{'id': 'chat1abc2', 'message': 'Hello'}, {'id': 'chat2def3', 'message': 'World'}, {'id': 'chat3ghi4', 'message': 'Test message'}]
|
||||
#
|
||||
# # 增强版本 - 使用时间戳
|
||||
# result3 = assign_message_ids_flexible(messages, prefix="ts", use_timestamp=True)
|
||||
# # 结果: [{'id': 'ts123a1b', 'message': 'Hello'}, {'id': 'ts123c2d', 'message': 'World'}, {'id': 'ts123e3f', 'message': 'Test message'}]
|
||||
659
src/chat/utils/utils_image.py
Normal file
659
src/chat/utils/utils_image.py
Normal file
@@ -0,0 +1,659 @@
|
||||
import base64
|
||||
import os
|
||||
import time
|
||||
import hashlib
|
||||
import uuid
|
||||
import io
|
||||
import asyncio
|
||||
import numpy as np
|
||||
|
||||
from typing import Optional, Tuple
|
||||
from PIL import Image
|
||||
from rich.traceback import install
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database import db
|
||||
from src.common.database.sqlalchemy_models import Images, ImageDescriptions
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.common.database.sqlalchemy_models import get_db_session
|
||||
|
||||
from sqlalchemy import select, and_
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("chat_image")
|
||||
|
||||
|
||||
class ImageManager:
|
||||
_instance = None
|
||||
IMAGE_DIR = "data" # 图像存储根目录
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if not self._initialized:
|
||||
self._ensure_image_dir()
|
||||
|
||||
self._initialized = True
|
||||
self.vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="image")
|
||||
|
||||
try:
|
||||
db.connect(reuse_if_open=True)
|
||||
# 使用SQLAlchemy创建表已在初始化时完成
|
||||
logger.debug("使用SQLAlchemy进行表管理")
|
||||
except Exception as e:
|
||||
logger.error(f"数据库连接失败: {e}")
|
||||
|
||||
self._initialized = True
|
||||
|
||||
def _ensure_image_dir(self):
|
||||
"""确保图像存储目录存在"""
|
||||
os.makedirs(self.IMAGE_DIR, exist_ok=True)
|
||||
|
||||
@staticmethod
|
||||
def _get_description_from_db(image_hash: str, description_type: str) -> Optional[str]:
|
||||
"""从数据库获取图片描述
|
||||
|
||||
Args:
|
||||
image_hash: 图片哈希值
|
||||
description_type: 描述类型 ('emoji' 或 'image')
|
||||
|
||||
Returns:
|
||||
Optional[str]: 描述文本,如果不存在则返回None
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
record = session.execute(select(ImageDescriptions).where(
|
||||
and_(ImageDescriptions.image_description_hash == image_hash, ImageDescriptions.type == description_type)
|
||||
)).scalar()
|
||||
return record.description if record else None
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库获取描述失败 (SQLAlchemy): {str(e)}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _save_description_to_db(image_hash: str, description: str, description_type: str) -> None:
|
||||
"""保存图片描述到数据库
|
||||
|
||||
Args:
|
||||
image_hash: 图片哈希值
|
||||
description: 描述文本
|
||||
description_type: 描述类型 ('emoji' 或 'image')
|
||||
"""
|
||||
try:
|
||||
current_timestamp = time.time()
|
||||
with get_db_session() as session:
|
||||
# 查找现有记录
|
||||
existing = session.execute(select(ImageDescriptions).where(
|
||||
and_(ImageDescriptions.image_description_hash == image_hash, ImageDescriptions.type == description_type)
|
||||
)).scalar()
|
||||
|
||||
if existing:
|
||||
# 更新现有记录
|
||||
existing.description = description
|
||||
existing.timestamp = current_timestamp
|
||||
else:
|
||||
# 创建新记录
|
||||
new_desc = ImageDescriptions(
|
||||
image_description_hash=image_hash,
|
||||
type=description_type,
|
||||
description=description,
|
||||
timestamp=current_timestamp
|
||||
)
|
||||
session.add(new_desc)
|
||||
# session.commit() 会在上下文管理器中自动调用
|
||||
except Exception as e:
|
||||
logger.error(f"保存描述到数据库失败 (SQLAlchemy): {str(e)}")
|
||||
|
||||
async def get_emoji_tag(self, image_base64: str) -> str:
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
emoji_manager = get_emoji_manager()
|
||||
if isinstance(image_base64, str):
|
||||
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
emoji = await emoji_manager.get_emoji_from_manager(image_hash)
|
||||
emotion_list = emoji.emotion
|
||||
tag_str = ",".join(emotion_list)
|
||||
return f"[表情包:{tag_str}]"
|
||||
|
||||
async def get_emoji_description(self, image_base64: str) -> str:
|
||||
"""获取表情包描述,优先使用Emoji表中的缓存数据"""
|
||||
try:
|
||||
# 计算图片哈希
|
||||
# 确保base64字符串只包含ASCII字符
|
||||
if isinstance(image_base64, str):
|
||||
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
|
||||
|
||||
# 优先使用EmojiManager查询已注册表情包的描述
|
||||
try:
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
emoji_manager = get_emoji_manager()
|
||||
cached_emoji_description = await emoji_manager.get_emoji_description_by_hash(image_hash)
|
||||
if cached_emoji_description:
|
||||
logger.info(f"[缓存命中] 使用已注册表情包描述: {cached_emoji_description[:50]}...")
|
||||
return cached_emoji_description
|
||||
except Exception as e:
|
||||
logger.debug(f"查询EmojiManager时出错: {e}")
|
||||
|
||||
# 查询ImageDescriptions表的缓存描述
|
||||
if cached_description := self._get_description_from_db(image_hash, "emoji"):
|
||||
logger.info(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description[:50]}...")
|
||||
return f"[表情包:{cached_description}]"
|
||||
|
||||
# === 二步走识别流程 ===
|
||||
|
||||
# 第一步:VLM视觉分析 - 生成详细描述
|
||||
if image_format in ["gif", "GIF"]:
|
||||
image_base64_processed = self.transform_gif(image_base64)
|
||||
if image_base64_processed is None:
|
||||
logger.warning("GIF转换失败,无法获取描述")
|
||||
return "[表情包(GIF处理失败)]"
|
||||
vlm_prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
|
||||
detailed_description, _ = await self.vlm.generate_response_for_image(
|
||||
vlm_prompt, image_base64_processed, "jpg", temperature=0.4, max_tokens=300
|
||||
)
|
||||
else:
|
||||
vlm_prompt = (
|
||||
"这是一个表情包,请详细描述一下表情包所表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
|
||||
)
|
||||
detailed_description, _ = await self.vlm.generate_response_for_image(
|
||||
vlm_prompt, image_base64, image_format, temperature=0.4, max_tokens=300
|
||||
)
|
||||
|
||||
if detailed_description is None:
|
||||
logger.warning("VLM未能生成表情包详细描述")
|
||||
return "[表情包(VLM描述生成失败)]"
|
||||
|
||||
# 第二步:LLM情感分析 - 基于详细描述生成简短的情感标签
|
||||
emotion_prompt = f"""
|
||||
请你基于这个表情包的详细描述,提取出最核心的情感含义,用1-2个词概括。
|
||||
详细描述:'{detailed_description}'
|
||||
|
||||
要求:
|
||||
1. 只输出1-2个最核心的情感词汇
|
||||
2. 从互联网梗、meme的角度理解
|
||||
3. 输出简短精准,不要解释
|
||||
4. 如果有多个词用逗号分隔
|
||||
"""
|
||||
|
||||
# 使用较低温度确保输出稳定
|
||||
emotion_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="emoji")
|
||||
emotion_result, _ = await emotion_llm.generate_response_async(
|
||||
emotion_prompt, temperature=0.3, max_tokens=50
|
||||
)
|
||||
|
||||
if emotion_result is None:
|
||||
logger.warning("LLM未能生成情感标签,使用详细描述的前几个词")
|
||||
# 降级处理:从详细描述中提取关键词
|
||||
import jieba
|
||||
|
||||
words = list(jieba.cut(detailed_description))
|
||||
emotion_result = ",".join(words[:2]) if len(words) >= 2 else (words[0] if words else "表情")
|
||||
|
||||
# 处理情感结果,取前1-2个最重要的标签
|
||||
emotions = [e.strip() for e in emotion_result.replace(",", ",").split(",") if e.strip()]
|
||||
final_emotion = emotions[0] if emotions else "表情"
|
||||
|
||||
# 如果有第二个情感且不重复,也包含进来
|
||||
if len(emotions) > 1 and emotions[1] != emotions[0]:
|
||||
final_emotion = f"{emotions[0]},{emotions[1]}"
|
||||
|
||||
logger.info(f"[emoji识别] 详细描述: {detailed_description[:50]}... -> 情感标签: {final_emotion}")
|
||||
|
||||
if cached_description := self._get_description_from_db(image_hash, "emoji"):
|
||||
logger.warning(f"虽然生成了描述,但是找到缓存表情包描述: {cached_description}")
|
||||
return f"[表情包:{cached_description}]"
|
||||
|
||||
# 保存表情包文件和元数据(用于可能的后续分析)
|
||||
logger.debug(f"保存表情包: {image_hash}")
|
||||
current_timestamp = time.time()
|
||||
filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}"
|
||||
emoji_dir = os.path.join(self.IMAGE_DIR, "emoji")
|
||||
os.makedirs(emoji_dir, exist_ok=True)
|
||||
file_path = os.path.join(emoji_dir, filename)
|
||||
|
||||
try:
|
||||
# 保存文件
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(image_bytes)
|
||||
|
||||
# 保存到数据库 (Images表) - 包含详细描述用于可能的注册流程
|
||||
try:
|
||||
from src.common.database.sqlalchemy_models import get_db_session
|
||||
with get_db_session() as session:
|
||||
existing_img = session.execute(select(Images).where(
|
||||
and_(Images.emoji_hash == image_hash, Images.type == "emoji")
|
||||
)).scalar()
|
||||
|
||||
if existing_img:
|
||||
existing_img.path = file_path
|
||||
existing_img.description = detailed_description # 保存详细描述
|
||||
existing_img.timestamp = current_timestamp
|
||||
else:
|
||||
new_img = Images(
|
||||
emoji_hash=image_hash,
|
||||
path=file_path,
|
||||
type="emoji",
|
||||
description=detailed_description, # 保存详细描述
|
||||
timestamp=current_timestamp,
|
||||
)
|
||||
session.add(new_img)
|
||||
# session.commit() 会在上下文管理器中自动调用
|
||||
except Exception as e:
|
||||
logger.error(f"保存到Images表失败: {str(e)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"保存表情包文件或元数据失败: {str(e)}")
|
||||
|
||||
# 保存最终的情感标签到缓存 (ImageDescriptions表)
|
||||
self._save_description_to_db(image_hash, final_emotion, "emoji")
|
||||
|
||||
return f"[表情包:{final_emotion}]"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取表情包描述失败: {str(e)}")
|
||||
return "[表情包(处理失败)]"
|
||||
|
||||
async def get_image_description(self, image_base64: str) -> str:
|
||||
"""获取普通图片描述,优先使用Images表中的缓存数据"""
|
||||
try:
|
||||
# 计算图片哈希
|
||||
if isinstance(image_base64, str):
|
||||
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
|
||||
# 优先检查Images表中是否已有完整的描述
|
||||
with get_db_session() as session:
|
||||
existing_image = session.execute(select(Images).where(Images.emoji_hash == image_hash)).scalar()
|
||||
if existing_image:
|
||||
# 更新计数
|
||||
if hasattr(existing_image, "count") and existing_image.count is not None:
|
||||
existing_image.count += 1
|
||||
else:
|
||||
existing_image.count = 1
|
||||
|
||||
# 如果已有描述,直接返回
|
||||
if existing_image.description:
|
||||
logger.debug(f"[缓存命中] 使用Images表中的图片描述: {existing_image.description[:50]}...")
|
||||
return f"[图片:{existing_image.description}]"
|
||||
|
||||
if cached_description := self._get_description_from_db(image_hash, "image"):
|
||||
logger.debug(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description[:50]}...")
|
||||
return f"[图片:{cached_description}]"
|
||||
|
||||
# 调用AI获取描述
|
||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
|
||||
prompt = global_config.custom_prompt.image_prompt
|
||||
logger.info(f"[VLM调用] 为图片生成新描述 (Hash: {image_hash[:8]}...)")
|
||||
description, _ = await self.vlm.generate_response_for_image(
|
||||
prompt, image_base64, image_format, temperature=0.4, max_tokens=300
|
||||
)
|
||||
|
||||
if description is None:
|
||||
logger.warning("AI未能生成图片描述")
|
||||
return "[图片(描述生成失败)]"
|
||||
|
||||
# 保存图片和描述
|
||||
current_timestamp = time.time()
|
||||
filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}"
|
||||
image_dir = os.path.join(self.IMAGE_DIR, "image")
|
||||
os.makedirs(image_dir, exist_ok=True)
|
||||
file_path = os.path.join(image_dir, filename)
|
||||
|
||||
try:
|
||||
# 保存文件
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(image_bytes)
|
||||
|
||||
# 保存到数据库,补充缺失字段
|
||||
if existing_image:
|
||||
existing_image.path = file_path
|
||||
existing_image.description = description
|
||||
existing_image.timestamp = current_timestamp
|
||||
if not hasattr(existing_image, "image_id") or not existing_image.image_id:
|
||||
existing_image.image_id = str(uuid.uuid4())
|
||||
if not hasattr(existing_image, "vlm_processed") or existing_image.vlm_processed is None:
|
||||
existing_image.vlm_processed = True
|
||||
session.commit()
|
||||
logger.debug(f"[数据库] 更新已有图片记录: {image_hash[:8]}...")
|
||||
else:
|
||||
new_img = Images(
|
||||
image_id=str(uuid.uuid4()),
|
||||
emoji_hash=image_hash,
|
||||
path=file_path,
|
||||
type="image",
|
||||
description=description,
|
||||
timestamp=current_timestamp,
|
||||
vlm_processed=True,
|
||||
count=1,
|
||||
)
|
||||
session.add(new_img)
|
||||
session.commit()
|
||||
logger.debug(f"[数据库] 创建新图片记录: {image_hash[:8]}...")
|
||||
except Exception as e:
|
||||
logger.error(f"保存图片文件或元数据失败: {str(e)}")
|
||||
|
||||
# 保存描述到ImageDescriptions表作为备用缓存
|
||||
self._save_description_to_db(image_hash, description, "image")
|
||||
|
||||
logger.info(f"[VLM完成] 图片描述生成: {description[:50]}...")
|
||||
return f"[图片:{description}]"
|
||||
except Exception as e:
|
||||
logger.error(f"获取图片描述失败: {str(e)}")
|
||||
return "[图片(处理失败)]"
|
||||
|
||||
@staticmethod
|
||||
def transform_gif(gif_base64: str, similarity_threshold: float = 1000.0, max_frames: int = 15) -> Optional[str]:
|
||||
# sourcery skip: use-contextlib-suppress
|
||||
"""将GIF转换为水平拼接的静态图像, 跳过相似的帧
|
||||
|
||||
Args:
|
||||
gif_base64: GIF的base64编码字符串
|
||||
similarity_threshold: 判定帧相似的阈值 (MSE),越小表示要求差异越大才算不同帧,默认1000.0
|
||||
max_frames: 最大抽取的帧数,默认15
|
||||
|
||||
Returns:
|
||||
Optional[str]: 拼接后的JPG图像的base64编码字符串, 或者在失败时返回None
|
||||
"""
|
||||
try:
|
||||
# 确保base64字符串只包含ASCII字符
|
||||
if isinstance(gif_base64, str):
|
||||
gif_base64 = gif_base64.encode("ascii", errors="ignore").decode("ascii")
|
||||
# 解码base64
|
||||
gif_data = base64.b64decode(gif_base64)
|
||||
gif = Image.open(io.BytesIO(gif_data))
|
||||
|
||||
# 收集所有帧
|
||||
all_frames = []
|
||||
try:
|
||||
while True:
|
||||
gif.seek(len(all_frames))
|
||||
# 确保是RGB格式方便比较
|
||||
frame = gif.convert("RGB")
|
||||
all_frames.append(frame.copy())
|
||||
except EOFError:
|
||||
pass # 读完啦
|
||||
|
||||
if not all_frames:
|
||||
logger.warning("GIF中没有找到任何帧")
|
||||
return None # 空的GIF直接返回None
|
||||
|
||||
# --- 新的帧选择逻辑 ---
|
||||
selected_frames = []
|
||||
last_selected_frame_np = None
|
||||
|
||||
for i, current_frame in enumerate(all_frames):
|
||||
current_frame_np = np.array(current_frame)
|
||||
|
||||
# 第一帧总是要选的
|
||||
if i == 0:
|
||||
selected_frames.append(current_frame)
|
||||
last_selected_frame_np = current_frame_np
|
||||
continue
|
||||
|
||||
# 计算和上一张选中帧的差异(均方误差 MSE)
|
||||
if last_selected_frame_np is not None:
|
||||
mse = np.mean((current_frame_np - last_selected_frame_np) ** 2)
|
||||
# logger.debug(f"帧 {i} 与上一选中帧的 MSE: {mse}") # 可以取消注释来看差异值
|
||||
|
||||
# 如果差异够大,就选它!
|
||||
if mse > similarity_threshold:
|
||||
selected_frames.append(current_frame)
|
||||
last_selected_frame_np = current_frame_np
|
||||
# 检查是不是选够了
|
||||
if len(selected_frames) >= max_frames:
|
||||
# logger.debug(f"已选够 {max_frames} 帧,停止选择。")
|
||||
break
|
||||
# 如果差异不大就跳过这一帧啦
|
||||
|
||||
# --- 帧选择逻辑结束 ---
|
||||
|
||||
# 如果选择后连一帧都没有(比如GIF只有一帧且后续处理失败?)或者原始GIF就没帧,也返回None
|
||||
if not selected_frames:
|
||||
logger.warning("处理后没有选中任何帧")
|
||||
return None
|
||||
|
||||
# logger.debug(f"总帧数: {len(all_frames)}, 选中帧数: {len(selected_frames)}")
|
||||
|
||||
# 获取选中的第一帧的尺寸(假设所有帧尺寸一致)
|
||||
frame_width, frame_height = selected_frames[0].size
|
||||
|
||||
# 计算目标尺寸,保持宽高比
|
||||
target_height = 200 # 固定高度
|
||||
# 防止除以零
|
||||
if frame_height == 0:
|
||||
logger.error("帧高度为0,无法计算缩放尺寸")
|
||||
return None
|
||||
target_width = int((target_height / frame_height) * frame_width)
|
||||
# 宽度也不能是0
|
||||
if target_width == 0:
|
||||
logger.warning(f"计算出的目标宽度为0 (原始尺寸 {frame_width}x{frame_height}),调整为1")
|
||||
target_width = 1
|
||||
|
||||
# 调整所有选中帧的大小
|
||||
resized_frames = [
|
||||
frame.resize((target_width, target_height), Image.Resampling.LANCZOS) for frame in selected_frames
|
||||
]
|
||||
|
||||
# 创建拼接图像
|
||||
total_width = target_width * len(resized_frames)
|
||||
# 防止总宽度为0
|
||||
if total_width == 0 and resized_frames:
|
||||
logger.warning("计算出的总宽度为0,但有选中帧,可能目标宽度太小")
|
||||
# 至少给点宽度吧
|
||||
total_width = len(resized_frames)
|
||||
elif total_width == 0:
|
||||
logger.error("计算出的总宽度为0且无选中帧")
|
||||
return None
|
||||
|
||||
combined_image = Image.new("RGB", (total_width, target_height))
|
||||
|
||||
# 水平拼接图像
|
||||
for idx, frame in enumerate(resized_frames):
|
||||
combined_image.paste(frame, (idx * target_width, 0))
|
||||
|
||||
# 转换为base64
|
||||
buffer = io.BytesIO()
|
||||
combined_image.save(buffer, format="JPEG", quality=85) # 保存为JPEG
|
||||
return base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
except MemoryError:
|
||||
logger.error("GIF转换失败: 内存不足,可能是GIF太大或帧数太多")
|
||||
return None # 内存不够啦
|
||||
except Exception as e:
|
||||
logger.error(f"GIF转换失败: {str(e)}", exc_info=True) # 记录详细错误信息
|
||||
return None # 其他错误也返回None
|
||||
|
||||
async def process_image(self, image_base64: str) -> Tuple[str, str]:
|
||||
# sourcery skip: hoist-if-from-if
|
||||
"""处理图片并返回图片ID和描述
|
||||
|
||||
Args:
|
||||
image_base64: 图片的base64编码
|
||||
|
||||
Returns:
|
||||
Tuple[str, str]: (图片ID, 描述)
|
||||
"""
|
||||
try:
|
||||
# 生成图片ID
|
||||
# 计算图片哈希
|
||||
# 确保base64字符串只包含ASCII字符
|
||||
if isinstance(image_base64, str):
|
||||
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
with get_db_session() as session:
|
||||
existing_image = session.execute(select(Images).where(Images.emoji_hash == image_hash)).scalar()
|
||||
if existing_image:
|
||||
# 检查是否缺少必要字段,如果缺少则创建新记录
|
||||
if (
|
||||
not hasattr(existing_image, "image_id")
|
||||
or not existing_image.image_id
|
||||
or not hasattr(existing_image, "count")
|
||||
or existing_image.count is None
|
||||
or not hasattr(existing_image, "vlm_processed")
|
||||
or existing_image.vlm_processed is None
|
||||
):
|
||||
logger.debug(f"图片记录缺少必要字段,补全旧记录: {image_hash}")
|
||||
if not existing_image.image_id:
|
||||
existing_image.image_id = str(uuid.uuid4())
|
||||
if existing_image.count is None:
|
||||
existing_image.count = 0
|
||||
if existing_image.vlm_processed is None:
|
||||
existing_image.vlm_processed = False
|
||||
|
||||
existing_image.count += 1
|
||||
session.commit()
|
||||
return existing_image.image_id, f"[picid:{existing_image.image_id}]"
|
||||
|
||||
# print(f"图片不存在: {image_hash}")
|
||||
image_id = str(uuid.uuid4())
|
||||
|
||||
# 保存新图片
|
||||
current_timestamp = time.time()
|
||||
image_dir = os.path.join(self.IMAGE_DIR, "images")
|
||||
os.makedirs(image_dir, exist_ok=True)
|
||||
filename = f"{image_id}.png"
|
||||
file_path = os.path.join(image_dir, filename)
|
||||
|
||||
# 保存文件
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(image_bytes)
|
||||
|
||||
# 保存到数据库
|
||||
new_img = Images(
|
||||
image_id=image_id,
|
||||
emoji_hash=image_hash,
|
||||
path=file_path,
|
||||
type="image",
|
||||
timestamp=current_timestamp,
|
||||
vlm_processed=False,
|
||||
count=1,
|
||||
)
|
||||
session.add(new_img)
|
||||
session.commit()
|
||||
|
||||
# 启动异步VLM处理
|
||||
asyncio.create_task(self._process_image_with_vlm(image_id, image_base64))
|
||||
|
||||
return image_id, f"[picid:{image_id}]"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理图片失败: {str(e)}")
|
||||
return "", "[图片]"
|
||||
|
||||
async def _process_image_with_vlm(self, image_id: str, image_base64: str) -> None:
|
||||
"""使用VLM处理图片并更新数据库
|
||||
|
||||
Args:
|
||||
image_id: 图片ID
|
||||
image_base64: 图片的base64编码
|
||||
"""
|
||||
try:
|
||||
# 计算图片哈希
|
||||
# 确保base64字符串只包含ASCII字符
|
||||
if isinstance(image_base64, str):
|
||||
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
with get_db_session() as session:
|
||||
# 获取当前图片记录
|
||||
image = session.execute(select(Images).where(Images.image_id == image_id)).scalar()
|
||||
|
||||
# 优先检查是否已有其他相同哈希的图片记录包含描述
|
||||
existing_with_description = session.execute(select(Images).where(
|
||||
and_(
|
||||
Images.emoji_hash == image_hash,
|
||||
Images.description.isnot(None),
|
||||
Images.description != "",
|
||||
Images.id != image.id
|
||||
)
|
||||
)).scalar()
|
||||
if existing_with_description:
|
||||
logger.debug(f"[缓存复用] 从其他相同图片记录复用描述: {existing_with_description.description[:50]}...")
|
||||
image.description = existing_with_description.description
|
||||
image.vlm_processed = True
|
||||
session.commit()
|
||||
# 同时保存到ImageDescriptions表作为备用缓存
|
||||
self._save_description_to_db(image_hash, existing_with_description.description, "image")
|
||||
return
|
||||
|
||||
# 检查ImageDescriptions表的缓存描述
|
||||
if cached_description := self._get_description_from_db(image_hash, "image"):
|
||||
logger.debug(f"[缓存复用] 从ImageDescriptions表复用描述: {cached_description[:50]}...")
|
||||
image.description = cached_description
|
||||
image.vlm_processed = True
|
||||
session.commit()
|
||||
return
|
||||
|
||||
# 获取图片格式
|
||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
|
||||
|
||||
# 构建prompt
|
||||
prompt = global_config.custom_prompt.image_prompt
|
||||
|
||||
# 获取VLM描述
|
||||
logger.info(f"[VLM异步调用] 为图片生成描述 (ID: {image_id}, Hash: {image_hash[:8]}...)")
|
||||
description, _ = await self.vlm.generate_response_for_image(
|
||||
prompt, image_base64, image_format, temperature=0.4, max_tokens=300
|
||||
)
|
||||
|
||||
if description is None:
|
||||
logger.warning("VLM未能生成图片描述")
|
||||
description = "无法生成描述"
|
||||
|
||||
if cached_description := self._get_description_from_db(image_hash, "image"):
|
||||
logger.warning(f"虽然生成了描述,但是找到缓存图片描述: {cached_description}")
|
||||
description = cached_description
|
||||
|
||||
# 更新数据库
|
||||
image.description = description
|
||||
image.vlm_processed = True
|
||||
|
||||
# 保存描述到ImageDescriptions表作为备用缓存
|
||||
self._save_description_to_db(image_hash, description, "image")
|
||||
|
||||
logger.info(f"[VLM异步完成] 图片描述生成: {description[:50]}...")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"VLM处理图片失败: {str(e)}")
|
||||
|
||||
|
||||
# 创建全局单例
|
||||
image_manager = None
|
||||
|
||||
|
||||
def get_image_manager() -> ImageManager:
|
||||
"""获取全局图片管理器单例"""
|
||||
global image_manager
|
||||
if image_manager is None:
|
||||
image_manager = ImageManager()
|
||||
return image_manager
|
||||
|
||||
|
||||
def image_path_to_base64(image_path: str) -> str:
|
||||
"""将图片路径转换为base64编码
|
||||
Args:
|
||||
image_path: 图片文件路径
|
||||
Returns:
|
||||
str: base64编码的图片数据
|
||||
Raises:
|
||||
FileNotFoundError: 当图片文件不存在时
|
||||
IOError: 当读取图片文件失败时
|
||||
"""
|
||||
if not os.path.exists(image_path):
|
||||
raise FileNotFoundError(f"图片文件不存在: {image_path}")
|
||||
|
||||
with open(image_path, "rb") as f:
|
||||
if image_data := f.read():
|
||||
return base64.b64encode(image_data).decode("utf-8")
|
||||
else:
|
||||
raise IOError(f"读取图片文件失败: {image_path}")
|
||||
29
src/chat/utils/utils_voice.py
Normal file
29
src/chat/utils/utils_voice.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("chat_voice")
|
||||
|
||||
|
||||
async def get_voice_text(voice_base64: str) -> str:
|
||||
"""获取音频文件转录文本"""
|
||||
if not global_config.voice.enable_asr:
|
||||
logger.warning("语音识别未启用,无法处理语音消息")
|
||||
return "[语音]"
|
||||
try:
|
||||
_llm = LLMRequest(model_set=model_config.model_task_config.voice, request_type="audio")
|
||||
text = await _llm.generate_response_for_voice(voice_base64)
|
||||
if text is None:
|
||||
logger.warning("未能生成语音文本")
|
||||
return "[语音(文本生成失败)]"
|
||||
|
||||
logger.debug(f"描述是{text}")
|
||||
|
||||
return f"[语音:{text}]"
|
||||
except Exception as e:
|
||||
logger.error(f"语音转文字失败: {str(e)}")
|
||||
return "[语音]"
|
||||
Reference in New Issue
Block a user