完成所有类型注解的修复

This commit is contained in:
UnCLAS-Prommer
2025-07-13 00:19:54 +08:00
parent d2ad6ea1d8
commit 7ef0bfb7c8
32 changed files with 358 additions and 434 deletions

View File

@@ -1,23 +1,21 @@
import random
import re
import time
from collections import Counter
import jieba
import numpy as np
from collections import Counter
from maim_message import UserInfo
from typing import Optional, Tuple, Dict
from src.common.logger import get_logger
# from src.mood.mood_manager import mood_manager
from ..message_receive.message import MessageRecv
from src.llm_models.utils_model import LLMRequest
from .typo_generator import ChineseTypoGenerator
from ...config.config import global_config
from ...common.message_repository import find_messages, count_messages
from typing import Optional, Tuple, Dict
from src.common.message_repository import find_messages, count_messages
from src.config.config import global_config
from src.chat.message_receive.message import MessageRecv
from src.chat.message_receive.chat_stream import get_chat_manager
from src.llm_models.utils_model import LLMRequest
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
from .typo_generator import ChineseTypoGenerator
logger = get_logger("chat_utils")
@@ -31,11 +29,7 @@ def db_message_to_str(message_dict: dict) -> str:
logger.debug(f"message_dict: {message_dict}")
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message_dict["time"]))
try:
name = "[(%s)%s]%s" % (
message_dict["user_id"],
message_dict.get("user_nickname", ""),
message_dict.get("user_cardname", ""),
)
name = f"[({message_dict['user_id']}){message_dict.get('user_nickname', '')}]{message_dict.get('user_cardname', '')}"
except Exception:
name = message_dict.get("user_nickname", "") or f"用户{message_dict['user_id']}"
content = message_dict.get("processed_plain_text", "")
@@ -58,11 +52,11 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
and message.message_info.additional_config.get("is_mentioned") is not None
):
try:
reply_probability = float(message.message_info.additional_config.get("is_mentioned"))
reply_probability = float(message.message_info.additional_config.get("is_mentioned")) # type: ignore
is_mentioned = True
return is_mentioned, reply_probability
except Exception as e:
logger.warning(e)
logger.warning(str(e))
logger.warning(
f"消息中包含不合理的设置 is_mentioned: {message.message_info.additional_config.get('is_mentioned')}"
)
@@ -135,20 +129,17 @@ def get_recent_group_detailed_plain_text(chat_stream_id: str, limit: int = 12, c
if not recent_messages:
return []
message_detailed_plain_text = ""
message_detailed_plain_text_list = []
# 反转消息列表,使最新的消息在最后
recent_messages.reverse()
if combine:
for msg_db_data in recent_messages:
message_detailed_plain_text += str(msg_db_data["detailed_plain_text"])
return message_detailed_plain_text
else:
for msg_db_data in recent_messages:
message_detailed_plain_text_list.append(msg_db_data["detailed_plain_text"])
return message_detailed_plain_text_list
return "".join(str(msg_db_data["detailed_plain_text"]) for msg_db_data in recent_messages)
message_detailed_plain_text_list = []
for msg_db_data in recent_messages:
message_detailed_plain_text_list.append(msg_db_data["detailed_plain_text"])
return message_detailed_plain_text_list
def get_recent_group_speaker(chat_stream_id: str, sender, limit: int = 12) -> list:
@@ -204,10 +195,7 @@ def split_into_sentences_w_remove_punctuation(text: str) -> list[str]:
len_text = len(text)
if len_text < 3:
if random.random() < 0.01:
return list(text) # 如果文本很短且触发随机条件,直接按字符分割
else:
return [text]
return list(text) if random.random() < 0.01 else [text]
# 定义分隔符
separators = {"", ",", " ", "", ";"}
@@ -352,10 +340,9 @@ def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese
max_length = global_config.response_splitter.max_length * 2
max_sentence_num = global_config.response_splitter.max_sentence_num
# 如果基本上是中文,则进行长度过滤
if get_western_ratio(cleaned_text) < 0.1:
if len(cleaned_text) > max_length:
logger.warning(f"回复过长 ({len(cleaned_text)} 字符),返回默认回复")
return ["懒得说"]
if get_western_ratio(cleaned_text) < 0.1 and len(cleaned_text) > max_length:
logger.warning(f"回复过长 ({len(cleaned_text)} 字符),返回默认回复")
return ["懒得说"]
typo_generator = ChineseTypoGenerator(
error_rate=global_config.chinese_typo.error_rate,
@@ -420,7 +407,7 @@ def calculate_typing_time(
# chinese_time *= 1 / typing_speed_multiplier
# english_time *= 1 / typing_speed_multiplier
# 计算中文字符数
chinese_chars = sum(1 for char in input_string if "\u4e00" <= char <= "\u9fff")
chinese_chars = sum("\u4e00" <= char <= "\u9fff" for char in input_string)
# 如果只有一个中文字符使用3倍时间
if chinese_chars == 1 and len(input_string.strip()) == 1:
@@ -429,11 +416,7 @@ def calculate_typing_time(
# 正常计算所有字符的输入时间
total_time = 0.0
for char in input_string:
if "\u4e00" <= char <= "\u9fff": # 判断是否为中文字符
total_time += chinese_time
else: # 其他字符(如英文)
total_time += english_time
total_time += chinese_time if "\u4e00" <= char <= "\u9fff" else english_time
if is_emoji:
total_time = 1
@@ -453,18 +436,14 @@ def cosine_similarity(v1, v2):
dot_product = np.dot(v1, v2)
norm1 = np.linalg.norm(v1)
norm2 = np.linalg.norm(v2)
if norm1 == 0 or norm2 == 0:
return 0
return dot_product / (norm1 * norm2)
return 0 if norm1 == 0 or norm2 == 0 else dot_product / (norm1 * norm2)
def text_to_vector(text):
"""将文本转换为词频向量"""
# 分词
words = jieba.lcut(text)
# 统计词频
word_freq = Counter(words)
return word_freq
return Counter(words)
def find_similar_topics_simple(text: str, topics: list, top_k: int = 5) -> list:
@@ -491,9 +470,7 @@ def find_similar_topics_simple(text: str, topics: list, top_k: int = 5) -> list:
def truncate_message(message: str, max_length=20) -> str:
"""截断消息,使其不超过指定长度"""
if len(message) > max_length:
return message[:max_length] + "..."
return message
return f"{message[:max_length]}..." if len(message) > max_length else message
def protect_kaomoji(sentence):
@@ -522,7 +499,7 @@ def protect_kaomoji(sentence):
placeholder_to_kaomoji = {}
for idx, match in enumerate(kaomoji_matches):
kaomoji = match[0] if match[0] else match[1]
kaomoji = match[0] or match[1]
placeholder = f"__KAOMOJI_{idx}__"
sentence = sentence.replace(kaomoji, placeholder, 1)
placeholder_to_kaomoji[placeholder] = kaomoji
@@ -563,7 +540,7 @@ def get_western_ratio(paragraph):
if not alnum_chars:
return 0.0
western_count = sum(1 for char in alnum_chars if is_english_letter(char))
western_count = sum(bool(is_english_letter(char)) for char in alnum_chars)
return western_count / len(alnum_chars)
@@ -610,6 +587,7 @@ def count_messages_between(start_time: float, end_time: float, stream_id: str) -
def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal") -> str:
# sourcery skip: merge-comparisons, merge-duplicate-blocks, switch
"""将时间戳转换为人类可读的时间格式
Args:
@@ -621,7 +599,7 @@ def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal"
"""
if mode == "normal":
return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(timestamp))
if mode == "normal_no_YMD":
elif mode == "normal_no_YMD":
return time.strftime("%H:%M:%S", time.localtime(timestamp))
elif mode == "relative":
now = time.time()
@@ -640,7 +618,7 @@ def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal"
else:
return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(timestamp)) + ":"
else: # mode = "lite" or unknown
# 只返回时分秒格式,喵~
# 只返回时分秒格式
return time.strftime("%H:%M:%S", time.localtime(timestamp))
@@ -670,8 +648,8 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
elif chat_stream.user_info: # It's a private chat
is_group_chat = False
user_info = chat_stream.user_info
platform = chat_stream.platform
user_id = user_info.user_id
platform: str = chat_stream.platform # type: ignore
user_id: str = user_info.user_id # type: ignore
# Initialize target_info with basic info
target_info = {