# Mofox-Core/src/chat/utils/utils.py

import random
import re
import string
import time
import jieba
import numpy as np
from collections import Counter
from maim_message import UserInfo
from typing import Optional, Tuple, Dict, List, Any
from src.common.logger import get_logger
from src.common.message_repository import find_messages, count_messages
from src.config.config import global_config, model_config
from src.chat.message_receive.message import MessageRecv
from src.chat.message_receive.chat_stream import get_chat_manager
from src.llm_models.utils_model import LLMRequest
from src.person_info.person_info import Person
from .typo_generator import ChineseTypoGenerator
logger = get_logger("chat_utils")

def is_english_letter(char: str) -> bool:
    """Check whether a character is an English letter (case-insensitive)."""
    return "a" <= char.lower() <= "z"

def db_message_to_str(message_dict: dict) -> str:
    logger.debug(f"message_dict: {message_dict}")
    time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message_dict["time"]))
    try:
        name = f"[({message_dict['user_id']}){message_dict.get('user_nickname', '')}]{message_dict.get('user_cardname', '')}"
    except Exception:
        name = message_dict.get("user_nickname", "") or f"用户{message_dict['user_id']}"
    content = message_dict.get("processed_plain_text", "")
    result = f"[{time_str}] {name}: {content}\n"
    logger.debug(f"result: {result}")
    return result
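
# Usage sketch (hypothetical record; the timestamp is rendered in local time, UTC+8 shown here):
# db_message_to_str({"time": 1723561048, "user_id": "10001", "user_nickname": "Alice", "processed_plain_text": "hi"})
# -> '[08-13 22:57:28] [(10001)Alice]: hi\n'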

def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
    """Check whether the message mentions the bot."""
    keywords = [global_config.bot.nickname]
    nicknames = global_config.bot.alias_names
    reply_probability = 0.0
    is_at = False
    is_mentioned = False
    if message.is_mentioned is not None:
        return bool(message.is_mentioned), message.is_mentioned
    if (
        message.message_info.additional_config is not None
        and message.message_info.additional_config.get("is_mentioned") is not None
    ):
        try:
            reply_probability = float(message.message_info.additional_config.get("is_mentioned"))  # type: ignore
            is_mentioned = True
            return is_mentioned, reply_probability
        except Exception as e:
            logger.warning(str(e))
            logger.warning(
                f"消息中包含不合理的设置 is_mentioned: {message.message_info.additional_config.get('is_mentioned')}"
            )
    if global_config.bot.nickname in message.processed_plain_text:
        is_mentioned = True
    for alias_name in global_config.bot.alias_names:
        if alias_name in message.processed_plain_text:
            is_mentioned = True
    # Check whether the bot was @-mentioned
    if re.search(rf"@<(.+?):{global_config.bot.qq_account}>", message.processed_plain_text):
        is_at = True
        is_mentioned = True
    # print(f"message.processed_plain_text: {message.processed_plain_text}")
    # print(f"is_mentioned: {is_mentioned}")
    # print(f"is_at: {is_at}")
    if is_at and global_config.chat.at_bot_inevitable_reply:
        reply_probability = 1.0
        logger.debug("被@回复概率设置为100%")
    else:
        if not is_mentioned:
            # Check whether this message is a reply to one of the bot's messages
            if re.match(
                rf"\[回复 (.+?)\({str(global_config.bot.qq_account)}\)(.+?)\],说:", message.processed_plain_text
            ) or re.match(
                rf"\[回复<(.+?)(?=:{str(global_config.bot.qq_account)}>)\:{str(global_config.bot.qq_account)}>(.+?)\],说:",
                message.processed_plain_text,
            ):
                is_mentioned = True
            else:
                # Check whether the bot is mentioned in the message body itself
                message_content = re.sub(r"@(.+?)(\d+)", "", message.processed_plain_text)
                message_content = re.sub(r"@<(.+?)(?=:(\d+))\:(\d+)>", "", message_content)
                message_content = re.sub(r"\[回复 (.+?)\(((\d+)|未知id)\)(.+?)\],说:", "", message_content)
                message_content = re.sub(r"\[回复<(.+?)(?=:(\d+))\:(\d+)>(.+?)\],说:", "", message_content)
                for keyword in keywords:
                    if keyword in message_content:
                        is_mentioned = True
                for nickname in nicknames:
                    if nickname in message_content:
                        is_mentioned = True
        if is_mentioned and global_config.chat.mentioned_bot_inevitable_reply:
            reply_probability = 1.0
            logger.debug("被提及回复概率设置为100%")
    return is_mentioned, reply_probability

async def get_embedding(text, request_type="embedding") -> Optional[List[float]]:
    """Get the embedding vector for the given text."""
    llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type=request_type)
    try:
        embedding, _ = await llm.get_embedding(text)
    except Exception as e:
        logger.error(f"获取embedding失败: {str(e)}")
        embedding = None
    return embedding

def get_recent_group_speaker(chat_stream_id: str, sender, limit: int = 12) -> list:
    # Collect the users who have spoken recently in this group chat
    filter_query = {"chat_id": chat_stream_id}
    sort_order = [("time", -1)]
    recent_messages = find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
    if not recent_messages:
        return []
    who_chat_in_group = []
    for msg_db_data in recent_messages:
        user_info = UserInfo.from_dict(
            {
                "platform": msg_db_data["user_platform"],
                "user_id": msg_db_data["user_id"],
                "user_nickname": msg_db_data["user_nickname"],
                "user_cardname": msg_db_data.get("user_cardname", ""),
            }
        )
        if (
            (user_info.platform, user_info.user_id) != sender
            and user_info.user_id != global_config.bot.qq_account
            and (user_info.platform, user_info.user_id, user_info.user_nickname) not in who_chat_in_group
            and len(who_chat_in_group) < 5
        ):  # Skip duplicates, the message sender, and the bot; load at most 5 relations
            who_chat_in_group.append((user_info.platform, user_info.user_id, user_info.user_nickname))
    return who_chat_in_group

def split_into_sentences_w_remove_punctuation(text: str) -> list[str]:
    """Split text into sentences, then probabilistically merge them.
    1. Detect split points (,, 。, ;, space), but do not split when the characters
       on both sides of the split point are English letters.
    2. Split the text into (content, separator) tuples.
    3. Compute a merge probability from the original text length and merge
       adjacent segments probabilistically.
    Note: this function assumes kaomoji have already been protected upstream.
    Args:
        text: the text to split (kaomoji assumed to be protected)
    Returns:
        List[str]: the sentences after splitting and merging
    """
    # Preprocessing: clean up redundant newlines
    # 1. Collapse runs of newlines into a single newline
    text = re.sub(r"\n\s*\n+", "\n", text)
    # 2. Normalize combinations of newlines and other separators
    text = re.sub(r"\n\s*([,。;\s])", r"\1", text)
    text = re.sub(r"([,。;\s])\s*\n", r"\1", text)
    # Replace a newline between two Chinese characters with a full-width period
    text = re.sub(r"([\u4e00-\u9fff])\n([\u4e00-\u9fff])", r"\1。\2", text)
    len_text = len(text)
    if len_text < 3:
        return list(text) if random.random() < 0.01 else [text]
    # Separator characters (full-width and ASCII)
    separators = {",", ",", " ", "。", ";"}
    segments = []
    current_segment = ""
    # 1. Split into (content, separator) tuples
    i = 0
    while i < len(text):
        char = text[i]
        if char in separators:
            # Do not split when the separator sits between two English letters
            can_split = True
            if 0 < i < len(text) - 1:
                prev_char = text[i - 1]
                next_char = text[i + 1]
                # Originally this rule was planned for spaces only; it now applies to all separators
                if is_english_letter(prev_char) and is_english_letter(next_char):
                    can_split = False
            if can_split:
                # Only append when the current segment is non-empty
                if current_segment:
                    segments.append((current_segment, char))
                # If the current segment is empty but the separator is a space, keep it as an empty segment
                elif char == " ":
                    segments.append(("", char))
                current_segment = ""
            else:
                # No split: keep the separator inside the current segment
                current_segment += char
        else:
            current_segment += char
        i += 1
    # Append the final segment (which has no trailing separator)
    if current_segment:
        segments.append((current_segment, ""))
    # Drop segments that are entirely empty (no content and no separator)
    segments = [(content, sep) for content, sep in segments if content or sep]
    # If nothing survives the split (e.g. the input was all separators), return the original text,
    # which may contain only unsplit characters or kaomoji placeholders
    if not segments:
        return [text] if text else []
    # 2. Probabilistic merging
    if len_text < 12:
        split_strength = 0.2
    elif len_text < 32:
        split_strength = 0.6
    else:
        split_strength = 0.7
    # The merge probability is the inverse of the split strength
    merge_probability = 1.0 - split_strength
    merged_segments = []
    idx = 0
    while idx < len(segments):
        current_content, current_sep = segments[idx]
        # Merge with the next segment when: this is not the last segment,
        # the random draw falls below the merge probability, and the current segment has content
        if idx + 1 < len(segments) and random.random() < merge_probability and current_content:
            next_content, next_sep = segments[idx + 1]
            # Merge: (content1 + separator1 + content2, separator2)
            # Only merge the text when the next segment has content; otherwise just take its separator
            if next_content:
                merged_content = current_content + current_sep + next_content
                merged_segments.append((merged_content, next_sep))
            else:  # The next segment is empty: keep the current content with the next separator
                merged_segments.append((current_content, next_sep))
            idx += 2  # Skip the next segment; it has been merged
        else:
            # No merge: append the current segment unchanged
            merged_segments.append((current_content, current_sep))
            idx += 1
    # Extract the final sentence contents, keeping only non-empty segments
    final_sentences = [content for content, sep in merged_segments if content]
    # Filter out empty strings and whitespace-only strings (e.g. newlines, spaces)
    final_sentences = [s for s in final_sentences if s.strip()]
    logger.debug(f"分割并合并后的句子: {final_sentences}")
    return final_sentences
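
# Usage sketch (merging is probabilistic, so this is one possible outcome):
# split_into_sentences_w_remove_punctuation("今天天气不错,我们出去玩吧。好主意")
# -> ['今天天气不错', '我们出去玩吧', '好主意']
# A merge may instead yield e.g. ['今天天气不错,我们出去玩吧', '好主意'].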

def random_remove_punctuation(text: str) -> str:
    """Randomly drop or alter punctuation to mimic human typing habits.
    Args:
        text: the text to process
    Returns:
        str: the processed text
    """
    result = ""
    text_len = len(text)
    for i, char in enumerate(text):
        if char == "。" and i == text_len - 1:  # Trailing full-width period
            if random.random() > 0.1:  # 90% chance to drop the trailing period
                continue
        elif char == ",":
            rand = random.random()
            if rand < 0.05:  # 5% chance to drop the comma
                continue
            elif rand < 0.25:  # 20% chance to turn the comma into a space
                result += " "
                continue
        result += char
    return result
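
# Usage sketch (probabilistic; one possible outcome shown):
# random_remove_punctuation("好的,我知道了。")
# -> '好的 我知道了'  # the comma became a space and the trailing period was dropped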

def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese_typo: bool = True) -> list[str]:
    if not global_config.response_post_process.enable_response_post_process:
        return [text]
    # Protect kaomoji first
    if global_config.response_splitter.enable_kaomoji_protection:
        protected_text, kaomoji_mapping = protect_kaomoji(text)
        logger.debug(f"保护颜文字后的文本: {protected_text}")
    else:
        protected_text = text
        kaomoji_mapping = {}
    # Extract content wrapped in (), [], or full-width () that contains Chinese characters
    pattern = re.compile(r"[((\[](?=.*[一-鿿]).*?[))\]]")
    _extracted_contents = pattern.findall(protected_text)  # Search the protected text
    # Strip the brackets and their wrapped content
    cleaned_text = pattern.sub("", protected_text)
    if cleaned_text == "":
        return ["呃呃"]
    logger.debug(f"{text}去除括号处理后的文本: {cleaned_text}")
    # Further processing on the cleaned text
    max_length = global_config.response_splitter.max_length * 2
    max_sentence_num = global_config.response_splitter.max_sentence_num
    # If the text is mostly Chinese, apply the length filter
    if get_western_ratio(cleaned_text) < 0.1 and len(cleaned_text) > max_length:
        logger.warning(f"回复过长 ({len(cleaned_text)} 字符),返回默认回复")
        return ["懒得说"]
    typo_generator = ChineseTypoGenerator(
        error_rate=global_config.chinese_typo.error_rate,
        min_freq=global_config.chinese_typo.min_freq,
        tone_error_rate=global_config.chinese_typo.tone_error_rate,
        word_replace_rate=global_config.chinese_typo.word_replace_rate,
    )
    if global_config.response_splitter.enable and enable_splitter:
        split_sentences = split_into_sentences_w_remove_punctuation(cleaned_text)
    else:
        split_sentences = [cleaned_text]
    sentences = []
    for sentence in split_sentences:
        if global_config.chinese_typo.enable and enable_chinese_typo:
            typoed_text, typo_corrections = typo_generator.create_typo_sentence(sentence)
            sentences.append(typoed_text)
            if typo_corrections:
                sentences.append(typo_corrections)
        else:
            sentences.append(sentence)
    if len(sentences) > max_sentence_num:
        logger.warning(f"分割后消息数量过多 ({len(sentences)} 条),返回默认回复")
        return [f"{global_config.bot.nickname}不知道哦"]
    # if extracted_contents:
    #     for content in extracted_contents:
    #         sentences.append(content)
    # After all sentences have been processed, restore kaomoji placeholders
    if global_config.response_splitter.enable_kaomoji_protection:
        sentences = recover_kaomoji(sentences, kaomoji_mapping)
    return sentences

def calculate_typing_time(
    input_string: str,
    thinking_start_time: float,
    chinese_time: float = 0.3,
    english_time: float = 0.15,
    is_emoji: bool = False,
) -> float:
    """
    Estimate the time needed to type the input string; Chinese and other characters use different per-character times.
    input_string (str): the string to type
    chinese_time (float): seconds per Chinese character, default 0.3
    english_time (float): seconds per other character, default 0.15
    is_emoji (bool): whether the message is an emoji, default False
    Special cases:
    - A single Chinese character uses 3x the Chinese typing time, plus 0.3s for pressing Enter
    - If is_emoji is True, a fixed time of 1 second is used
    - If more than 10 seconds have passed since thinking_start_time, the time is fixed at 1 second
    """
    # # Map arousal in 0-1 to the range -1..1
    # mood_arousal = mood_manager.current_mood.arousal
    # # Map to a speed multiplier between 0.5x and 2x
    # typing_speed_multiplier = 1.5**mood_arousal  # Arousal of 1 doubles the speed; -1 halves it
    # chinese_time *= 1 / typing_speed_multiplier
    # english_time *= 1 / typing_speed_multiplier
    # Count the Chinese characters
    chinese_chars = sum("\u4e00" <= char <= "\u9fff" for char in input_string)
    # A single Chinese character takes 3x the normal time
    if chinese_chars == 1 and len(input_string.strip()) == 1:
        return chinese_time * 3 + 0.3  # Plus the Enter keypress
    # Otherwise sum the per-character typing times
    total_time = 0.0
    for char in input_string:
        total_time += chinese_time if "\u4e00" <= char <= "\u9fff" else english_time
    if is_emoji:
        total_time = 1
    if time.time() - thinking_start_time > 10:
        total_time = 1
    # print(f"thinking_start_time:{thinking_start_time}")
    # print(f"nowtime:{time.time()}")
    # print(f"nowtime - thinking_start_time:{time.time() - thinking_start_time}")
    # print(f"{total_time}")
    return total_time
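
# Worked example with the default speeds: "你好ab" has 2 Chinese and 2 other characters,
# so the estimate is 2 * 0.3 + 2 * 0.15 = 0.9 seconds; the result is overridden to 1 second
# for emoji, or when thinking has already taken more than 10 seconds.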

def cosine_similarity(v1, v2):
    """Compute the cosine similarity between two vectors."""
    dot_product = np.dot(v1, v2)
    norm1 = np.linalg.norm(v1)
    norm2 = np.linalg.norm(v2)
    return 0 if norm1 == 0 or norm2 == 0 else dot_product / (norm1 * norm2)
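
# Examples:
# cosine_similarity([1, 0], [1, 0])  -> 1.0 (identical direction)
# cosine_similarity([1, 0], [0, 1])  -> 0.0 (orthogonal)
# cosine_similarity([0, 0], [1, 1])  -> 0   (zero vector short-circuits)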

def text_to_vector(text):
    """Convert text into a term-frequency vector."""
    # Tokenize with jieba
    words = jieba.lcut(text)
    return Counter(words)

def find_similar_topics_simple(text: str, topics: list, top_k: int = 5) -> list:
    """Rank topics against a text using simple cosine similarity."""
    # Convert the input text into a term-frequency vector
    text_vector = text_to_vector(text)
    # Compute the similarity against each topic
    similarities = []
    for topic in topics:
        topic_vector = text_to_vector(topic)
        # Collect the union of words from both vectors
        all_words = set(text_vector.keys()) | set(topic_vector.keys())
        # Build aligned count vectors
        v1 = [text_vector.get(word, 0) for word in all_words]
        v2 = [topic_vector.get(word, 0) for word in all_words]
        # Compute the similarity
        similarity = cosine_similarity(v1, v2)
        similarities.append((topic, similarity))
    # Sort by similarity, descending, and return the top k
    return sorted(similarities, key=lambda x: x[1], reverse=True)[:top_k]
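
# Usage sketch (the score is illustrative; actual values depend on jieba's segmentation):
# find_similar_topics_simple("今天天气很好", ["天气预报", "美食推荐"], top_k=1)
# -> [('天气预报', 0.41)]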

def truncate_message(message: str, max_length=20) -> str:
    """Truncate a message so it does not exceed max_length characters."""
    return f"{message[:max_length]}..." if len(message) > max_length else message

def protect_kaomoji(sentence):
    """
    Detect kaomoji (with or without brackets) in a sentence, replace each with a
    placeholder, and return the processed sentence plus a placeholder-to-kaomoji map.
    Args:
        sentence (str): the original sentence
    Returns:
        tuple: (processed sentence, {placeholder: kaomoji})
    """
    kaomoji_pattern = re.compile(
        r"("
        r"[(\[(【]"  # Opening bracket
        r"[^()\[\]()【】]*?"  # Non-bracket characters (lazy)
        r"[^一-龥a-zA-Z0-9\s]"  # At least one char that is not Chinese, Latin, digit, or whitespace
        r"[^()\[\]()【】]*?"  # Non-bracket characters (lazy)
        r"[)\])】]"  # Closing bracket
        r")"
        r"|"
        r"([▼▽・ᴥω・﹏^><≧≦ ̄`´∀ヮДд︿﹀へ。゚╥╯╰︶︹•⁄]{2,15})"
    )
    kaomoji_matches = kaomoji_pattern.findall(sentence)
    placeholder_to_kaomoji = {}
    for idx, match in enumerate(kaomoji_matches):
        kaomoji = match[0] or match[1]
        placeholder = f"__KAOMOJI_{idx}__"
        sentence = sentence.replace(kaomoji, placeholder, 1)
        placeholder_to_kaomoji[placeholder] = kaomoji
    return sentence, placeholder_to_kaomoji
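
# Usage sketch:
# masked, mapping = protect_kaomoji("好耶 (≧▽≦) 出发")
# masked  -> '好耶 __KAOMOJI_0__ 出发'
# mapping -> {'__KAOMOJI_0__': '(≧▽≦)'}
# recover_kaomoji([masked], mapping) restores the original text.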

def recover_kaomoji(sentences, placeholder_to_kaomoji):
    """
    Restore the kaomoji in a list of sentences using the placeholder map.
    Args:
        sentences (list): sentences containing placeholders
        placeholder_to_kaomoji (dict): map from placeholder to kaomoji
    Returns:
        list: sentences with kaomoji restored
    """
    recovered_sentences = []
    for sentence in sentences:
        for placeholder, kaomoji in placeholder_to_kaomoji.items():
            sentence = sentence.replace(placeholder, kaomoji)
        recovered_sentences.append(sentence)
    return recovered_sentences

def get_western_ratio(paragraph):
    """Compute the proportion of Western (Latin) characters among a paragraph's alphanumeric characters.
    Each alphanumeric character is tested with is_english_letter; punctuation,
    whitespace, and other non-alphanumeric characters are ignored.
    Args:
        paragraph: the text to check
    Returns:
        float: Western-character ratio (0.0-1.0); 0.0 if there are no alphanumeric characters
    """
    alnum_chars = [char for char in paragraph if char.isalnum()]
    if not alnum_chars:
        return 0.0
    western_count = sum(bool(is_english_letter(char)) for char in alnum_chars)
    return western_count / len(alnum_chars)
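
# Examples:
# get_western_ratio("hello世界")  -> 0.714...  (5 Latin letters out of 7 alphanumeric chars)
# get_western_ratio("你好!")      -> 0.0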

def count_messages_between(start_time: float, end_time: float, stream_id: str) -> tuple[int, int]:
    """Count the messages and total text length between two points in time.
    Args:
        start_time (float): start timestamp (exclusive)
        end_time (float): end timestamp (inclusive)
        stream_id (str): chat stream ID
    Returns:
        tuple[int, int]: (message count, total text length)
    """
    count = 0
    total_length = 0
    # Parameter validation (optional but recommended)
    if start_time >= end_time:
        # logger.debug(f"Start time {start_time} is at or after end time {end_time}; returning 0, 0")
        return 0, 0
    if not stream_id:
        logger.error("stream_id 不能为空")
        return 0, 0
    # Build the query for message_repository's count_messages and find_messages
    filter_query = {"chat_id": stream_id, "time": {"$gt": start_time, "$lte": end_time}}
    try:
        # First count the messages
        count = count_messages(filter_query)
        # Then fetch them to compute the total text length
        messages = find_messages(message_filter=filter_query)
        total_length = sum(len(msg.get("processed_plain_text", "")) for msg in messages)
        return count, total_length
    except Exception as e:
        logger.error(f"计算消息数量时发生意外错误: {e}")
        return 0, 0

def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal") -> str:
    # sourcery skip: merge-comparisons, merge-duplicate-blocks, switch
    """Convert a timestamp into a human-readable time string.
    Args:
        timestamp: the timestamp to convert
        mode: "normal" for the full date-time format, "normal_no_YMD" for time of day only,
              "relative" for a relative format; anything else falls back to time of day only
    Returns:
        str: the formatted time string
    """
    if mode == "normal":
        return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(timestamp))
    elif mode == "normal_no_YMD":
        return time.strftime("%H:%M:%S", time.localtime(timestamp))
    elif mode == "relative":
        now = time.time()
        diff = now - timestamp
        if diff < 20:
            return "刚刚"
        elif diff < 60:
            return f"{int(diff)}秒前"
        elif diff < 3600:
            return f"{int(diff / 60)}分钟前"
        elif diff < 86400:
            return f"{int(diff / 3600)}小时前"
        elif diff < 86400 * 2:
            return f"{int(diff / 86400)}天前"
        else:
            return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(timestamp)) + ":"
    else:  # mode == "lite" or unknown
        # Return the hour-minute-second format only
        return time.strftime("%H:%M:%S", time.localtime(timestamp))
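
# Usage sketch:
# translate_timestamp_to_human_readable(time.time() - 90, mode="relative")  -> '1分钟前'
# translate_timestamp_to_human_readable(time.time(), mode="lite")           -> e.g. '14:30:05'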

def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
    """
    Get the chat type (group or not) and, for private chats, the peer's info.
    Args:
        chat_id: chat stream ID
    Returns:
        Tuple[bool, Optional[Dict]]:
        - bool: whether this is a group chat (True for group, False for private or unknown)
        - Optional[Dict]: for private chats, a dict describing the peer; otherwise None.
          The dict contains: platform, user_id, user_nickname, person_id, person_name
    """
    is_group_chat = False  # Default to private/unknown
    chat_target_info = None
    try:
        if chat_stream := get_chat_manager().get_stream(chat_id):
            if chat_stream.group_info:
                is_group_chat = True
                chat_target_info = None  # Explicitly None for group chat
            elif chat_stream.user_info:  # It's a private chat
                is_group_chat = False
                user_info = chat_stream.user_info
                platform: str = chat_stream.platform
                user_id: str = user_info.user_id  # type: ignore
                # Initialize target_info with basic info
                target_info = {
                    "platform": platform,
                    "user_id": user_id,
                    "user_nickname": user_info.user_nickname,
                    "person_id": None,
                    "person_name": None,
                }
                # Try to fetch person info
                try:
                    person = Person(platform=platform, user_id=user_id)
                    if not person.is_known:
                        logger.warning(f"用户 {user_info.user_nickname} 尚未认识")
                        return False, None
                    person_id = person.person_id
                    person_name = None
                    if person_id:
                        person_name = person.person_name
                    target_info["person_id"] = person_id
                    target_info["person_name"] = person_name
                except Exception as person_e:
                    logger.warning(
                        f"获取 person_id 或 person_name 时出错 for {platform}:{user_id} in utils: {person_e}"
                    )
                chat_target_info = target_info
        else:
            logger.warning(f"无法获取 chat_stream for {chat_id} in utils")
    except Exception as e:
        logger.error(f"获取聊天类型和目标信息时出错 for {chat_id}: {e}", exc_info=True)
        # Keep defaults on error
    return is_group_chat, chat_target_info

def assign_message_ids(messages: List[Any]) -> List[Dict[str, Any]]:
    """
    Assign a unique short random ID to every message in a list.
    Args:
        messages: the message list
    Returns:
        A list of dicts shaped like {'id': str, 'message': any}
    """
    result = []
    used_ids = set()
    len_i = len(messages)
    # Use a two-digit random suffix for long lists, one digit otherwise
    if len_i > 100:
        a = 10
        b = 99
    else:
        a = 1
        b = 9
    for i, message in enumerate(messages):
        # Generate a unique short ID
        while True:
            # Build the short ID from the index plus a random suffix
            random_suffix = random.randint(a, b)
            message_id = f"m{i + 1}{random_suffix}"
            if message_id not in used_ids:
                used_ids.add(message_id)
                break
        result.append({"id": message_id, "message": message})
    return result

def assign_message_ids_flexible(
    messages: list,
    prefix: str = "msg",
    id_length: int = 6,
    use_timestamp: bool = False,
) -> list:
    """
    Assign a unique short random ID to every message in a list (enhanced version).
    Args:
        messages: the message list
        prefix: ID prefix, default "msg"
        id_length: total ID length excluding the prefix, default 6
        use_timestamp: whether to embed a timestamp in the ID, default False
    Returns:
        A list of dicts shaped like {'id': str, 'message': any}
    """
    result = []
    used_ids = set()
    for i, message in enumerate(messages):
        # Generate a unique ID
        while True:
            if use_timestamp:
                # Last three digits of the millisecond timestamp plus random characters
                timestamp_suffix = str(int(time.time() * 1000))[-3:]
                remaining_length = id_length - 3
                random_chars = "".join(random.choices(string.ascii_lowercase + string.digits, k=remaining_length))
                message_id = f"{prefix}{timestamp_suffix}{random_chars}"
            else:
                # Index plus random characters
                index_str = str(i + 1)
                remaining_length = max(1, id_length - len(index_str))
                random_chars = "".join(random.choices(string.ascii_lowercase + string.digits, k=remaining_length))
                message_id = f"{prefix}{index_str}{random_chars}"
            if message_id not in used_ids:
                used_ids.add(message_id)
                break
        result.append({"id": message_id, "message": message})
    return result

# Usage examples:
# messages = ["Hello", "World", "Test message"]
#
# # Basic version
# result1 = assign_message_ids(messages)
# # Result: [{'id': 'm1123', 'message': 'Hello'}, {'id': 'm2456', 'message': 'World'}, {'id': 'm3789', 'message': 'Test message'}]
#
# # Enhanced version with a custom prefix and length
# result2 = assign_message_ids_flexible(messages, prefix="chat", id_length=8)
# # Result: [{'id': 'chat1abc2', 'message': 'Hello'}, {'id': 'chat2def3', 'message': 'World'}, {'id': 'chat3ghi4', 'message': 'Test message'}]
#
# # Enhanced version with timestamps
# result3 = assign_message_ids_flexible(messages, prefix="ts", use_timestamp=True)
# # Result: [{'id': 'ts123a1b', 'message': 'Hello'}, {'id': 'ts123c2d', 'message': 'World'}, {'id': 'ts123e3f', 'message': 'Test message'}]

def parse_keywords_string(keywords_input) -> list[str]:
    """
    Unified keyword parser that accepts keyword strings in several formats.
    Supported formats:
    1. String-encoded list: '["utils.py", "修改", "代码", "动作"]'
    2. Slash-separated: 'utils.py/修改/代码/动作'
    3. Comma-separated: 'utils.py,修改,代码,动作'
    4. Space-separated: 'utils.py 修改 代码 动作'
    5. An actual list: ["utils.py", "修改", "代码", "动作"]
    6. JSON object string: '{"keywords": ["utils.py", "修改", "代码", "动作"]}'
    Args:
        keywords_input: the keyword input, either a string or a list
    Returns:
        list[str]: the parsed keywords with blank entries removed
    """
    if not keywords_input:
        return []
    # If it is already a list, process it directly
    if isinstance(keywords_input, list):
        return [str(k).strip() for k in keywords_input if str(k).strip()]
    # Otherwise treat it as a string
    keywords_str = str(keywords_input).strip()
    if not keywords_str:
        return []
    try:
        # Try parsing as JSON (supports the {"keywords": [...]} format)
        import json
        json_data = json.loads(keywords_str)
        if isinstance(json_data, dict) and "keywords" in json_data:
            keywords_list = json_data["keywords"]
            if isinstance(keywords_list, list):
                return [str(k).strip() for k in keywords_list if str(k).strip()]
        elif isinstance(json_data, list):
            # The input is a plain JSON array
            return [str(k).strip() for k in json_data if str(k).strip()]
    except (json.JSONDecodeError, ValueError):
        pass
    try:
        # Try ast.literal_eval (supports Python literal formats)
        import ast
        parsed = ast.literal_eval(keywords_str)
        if isinstance(parsed, list):
            return [str(k).strip() for k in parsed if str(k).strip()]
    except (ValueError, SyntaxError):
        pass
    # Try the known separators in turn
    separators = ["/", ",", " ", "|", ";"]
    for separator in separators:
        if separator in keywords_str:
            keywords_list = [k.strip() for k in keywords_str.split(separator) if k.strip()]
            if len(keywords_list) > 1:  # Only accept splits that produce multiple keywords
                return keywords_list
    # No separator found: return the single keyword
    return [keywords_str] if keywords_str else []
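
# Usage examples (each returns ['utils.py', '修改', '代码', '动作']):
# parse_keywords_string('["utils.py", "修改", "代码", "动作"]')
# parse_keywords_string('utils.py/修改/代码/动作')
# parse_keywords_string('{"keywords": ["utils.py", "修改", "代码", "动作"]}')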