refactor(models): unify request handling and optimize response processing (refactor/unified_request)
Refactors the request-handling logic in `utils_model.py` by introducing a unified request executor, `_execute_request`. The method centralizes request construction, retry logic, and response handling, replacing the redundant code in `generate_response`, `generate_response_for_image`, and `generate_response_async`.

Key changes:
- Introduce `_execute_request` as the single entry point for API requests
- Add support for custom retry policies and response handlers
- Simplify image and text payload construction via `_build_payload`
- Improve error handling and logging
- Remove the deprecated synchronous methods
- Add `max_response_length` for compatibility with koboldcpp's hard-coded default of 500

This refactor preserves existing functionality while improving maintainability and reducing duplicated code.
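The `utils_model.py` internals are not part of the diff shown below, which covers only the utility module. As a rough sketch of the pattern the list above describes — every name apart from `_execute_request`, `_build_payload`, and `max_response_length`, and every signature and default here, is an assumption rather than the actual implementation — a unified executor could look like:

```python
import asyncio
from typing import Any, Callable, Dict, Optional

import aiohttp


class LLM_request:
    def __init__(self, model: Dict[str, Any]):
        self.model = model  # hypothetical config layout

    def _build_payload(self, prompt: str, image_base64: Optional[str] = None) -> Dict[str, Any]:
        """Build a chat payload; attach an image block only when one is supplied."""
        content: Any = prompt
        if image_base64 is not None:
            content = [
                {"type": "text", "text": prompt},
                {"type": "image_url",
                 "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}},
            ]
        return {
            "model": self.model["name"],
            "messages": [{"role": "user", "content": content}],
            # explicit cap so koboldcpp does not fall back to its hard-coded 500
            "max_response_length": self.model.get("max_response_length", 1024),
        }

    async def _execute_request(
        self,
        endpoint: str,
        payload: Dict[str, Any],
        max_retries: int = 3,
        response_handler: Optional[Callable[[Dict[str, Any]], Any]] = None,
    ) -> Any:
        """Single entry point: POST the payload, retry on failure, then hand
        the parsed JSON to the caller-supplied response handler."""
        last_error: Optional[Exception] = None
        for attempt in range(1, max_retries + 1):
            try:
                async with aiohttp.ClientSession() as session:
                    async with session.post(endpoint, json=payload) as resp:
                        resp.raise_for_status()
                        data = await resp.json()
                return response_handler(data) if response_handler else data
            except Exception as e:
                last_error = e
                print(f"request attempt {attempt}/{max_retries} failed: {e}")
                await asyncio.sleep(2 ** attempt)  # exponential backoff
        raise RuntimeError(f"request failed after {max_retries} retries") from last_error
```

The point of such a design is that `generate_response`, `generate_response_for_image`, and `generate_response_async` collapse into thin wrappers that differ only in the payload they build and the handler they pass in.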
@@ -32,16 +32,18 @@ def combine_messages(messages: List[Message]) -> str:
         time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message.time))
         name = message.user_nickname or f"用户{message.user_id}"
         content = message.processed_plain_text or message.plain_text
 
         result += f"[{time_str}] {name}: {content}\n"
 
     return result
 
-def db_message_to_str (message_dict: Dict) -> str:
+
+def db_message_to_str(message_dict: Dict) -> str:
     print(f"message_dict: {message_dict}")
     time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message_dict["time"]))
     try:
-        name="[(%s)%s]%s" % (message_dict['user_id'],message_dict.get("user_nickname", ""),message_dict.get("user_cardname", ""))
+        name = "[(%s)%s]%s" % (
+            message_dict['user_id'], message_dict.get("user_nickname", ""), message_dict.get("user_cardname", ""))
     except:
         name = message_dict.get("user_nickname", "") or f"用户{message_dict['user_id']}"
     content = message_dict.get("processed_plain_text", "")
@@ -58,6 +60,7 @@ def is_mentioned_bot_in_message(message: Message) -> bool:
             return True
     return False
 
+
 def is_mentioned_bot_in_txt(message: str) -> bool:
     """Check whether the message mentions the bot"""
     keywords = [global_config.BOT_NICKNAME]
@@ -66,10 +69,13 @@ def is_mentioned_bot_in_txt(message: str) -> bool:
             return True
     return False
 
-def get_embedding(text):
+
+async def get_embedding(text):
    """Get the embedding vector for a piece of text"""
     llm = LLM_request(model=global_config.embedding)
-    return llm.get_embedding_sync(text)
+    # return llm.get_embedding_sync(text)
+    return await llm.get_embedding(text)
+
 
 def cosine_similarity(v1, v2):
     dot_product = np.dot(v1, v2)
@@ -77,51 +83,54 @@ def cosine_similarity(v1, v2):
     norm2 = np.linalg.norm(v2)
     return dot_product / (norm1 * norm2)
 
 
 def calculate_information_content(text):
     """Compute the information content (entropy) of a text"""
     char_count = Counter(text)
     total_chars = len(text)
 
     entropy = 0
     for count in char_count.values():
         probability = count / total_chars
         entropy -= probability * math.log2(probability)
 
     return entropy
 
 
 def get_cloest_chat_from_db(db, length: int, timestamp: str):
     """Fetch the chat records closest to the given timestamp from the database, tracking how often they have been read"""
     chat_text = ''
     closest_record = db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)])
 
     if closest_record and closest_record.get('memorized', 0) < 4:
         closest_time = closest_record['time']
         group_id = closest_record['group_id']  # get the group id
         # fetch length messages after this timestamp from the same group
         chat_records = list(db.db.messages.find(
             {"time": {"$gt": closest_time}, "group_id": group_id}
         ).sort('time', 1).limit(length))
 
         # update the memorized attribute of every message
         for record in chat_records:
             # check the current record's memorized value
             current_memorized = record.get('memorized', 0)
             if current_memorized > 3:
                 # print(f"消息已读取3次,跳过")
                 return ''
 
             # update the memorized value
             db.db.messages.update_one(
                 {"_id": record["_id"]},
                 {"$set": {"memorized": current_memorized + 1}}
             )
 
             chat_text += record["detailed_plain_text"]
 
         return chat_text
     # print(f"消息已读取3次,跳过")
     return ''
 
 
 def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
     """Fetch a group's most recent message records from the database
@@ -134,7 +143,7 @@ def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
         list: list of Message objects in ascending time order
     """
 
     # fetch recent messages from the database
     recent_messages = list(db.db.messages.find(
         {"group_id": group_id},
         # {
@@ -149,7 +158,7 @@ def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
 
     if not recent_messages:
         return []
 
     # convert to a list of Message objects
     from .message import Message
     message_objects = []
@@ -168,12 +177,13 @@ def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
         except KeyError:
             print("[WARNING] 数据库中存在无效的消息")
             continue
 
     # sort by time in ascending order
     message_objects.reverse()
     return message_objects
 
-def get_recent_group_detailed_plain_text(db, group_id: int, limit: int = 12,combine = False):
+
+def get_recent_group_detailed_plain_text(db, group_id: int, limit: int = 12, combine=False):
     recent_messages = list(db.db.messages.find(
         {"group_id": group_id},
         {
@@ -187,16 +197,16 @@ def get_recent_group_detailed_plain_text(db, group_id: int, limit: int = 12,comb
 
     if not recent_messages:
         return []
 
     message_detailed_plain_text = ''
     message_detailed_plain_text_list = []
 
     # reverse the list so the newest messages come last
     recent_messages.reverse()
 
     if combine:
         for msg_db_data in recent_messages:
-            message_detailed_plain_text+=str(msg_db_data["detailed_plain_text"])
+            message_detailed_plain_text += str(msg_db_data["detailed_plain_text"])
         return message_detailed_plain_text
     else:
         for msg_db_data in recent_messages:
@@ -204,7 +214,6 @@ def get_recent_group_detailed_plain_text(db, group_id: int, limit: int = 12,comb
         return message_detailed_plain_text_list
 
 
-
 def split_into_sentences_w_remove_punctuation(text: str) -> List[str]:
     """Split text into sentences while keeping content inside book-title marks intact
     Args:
@@ -224,30 +233,30 @@ def split_into_sentences_w_remove_punctuation(text: str) -> List[str]:
         split_strength = 0.7
     else:
         split_strength = 0.9
-    #strip line breaks first
+    # strip line breaks first
     # print(f"split_strength: {split_strength}")
 
     # print(f"处理前的文本: {text}")
 
     # normalize ASCII commas to Chinese commas
     text = text.replace(',', ',')
     text = text.replace('\n', ' ')
 
     # print(f"处理前的文本: {text}")
 
     text_no_1 = ''
     for letter in text:
         # print(f"当前字符: {letter}")
-        if letter in ['!','!','?','?']:
+        if letter in ['!', '!', '?', '?']:
             # print(f"当前字符: {letter}, 随机数: {random.random()}")
             if random.random() < split_strength:
                 letter = ''
-        if letter in ['。','…']:
+        if letter in ['。', '…']:
             # print(f"当前字符: {letter}, 随机数: {random.random()}")
             if random.random() < 1 - split_strength:
                 letter = ''
         text_no_1 += letter
 
     # decide for each comma individually whether to split there
     sentences = [text_no_1]
     new_sentences = []
@@ -276,15 +285,16 @@ def split_into_sentences_w_remove_punctuation(text: str) -> List[str]:
     sentences_done = []
     for sentence in sentences:
         sentence = sentence.rstrip(',,')
-        if random.random() < split_strength*0.5:
+        if random.random() < split_strength * 0.5:
             sentence = sentence.replace(',', '').replace(',', '')
         elif random.random() < split_strength:
             sentence = sentence.replace(',', ' ').replace(',', ' ')
         sentences_done.append(sentence)
 
     print(f"处理后的句子: {sentences_done}")
     return sentences_done
 
 
 # mapping of common typos
 TYPO_DICT = {
     '的': '地得',
@@ -355,6 +365,7 @@ TYPO_DICT = {
     '嘻': '嘻西希'
 }
 
+
 def random_remove_punctuation(text: str) -> str:
     """Randomly drop punctuation to mimic human typing habits
 
@@ -366,7 +377,7 @@ def random_remove_punctuation(text: str) -> str:
     """
     result = ''
     text_len = len(text)
 
     for i, char in enumerate(text):
         if char == '。' and i == text_len - 1:  # trailing full stop
             if random.random() > 0.4:  # 80% chance to drop the trailing full stop
@@ -381,6 +392,7 @@ def random_remove_punctuation(text: str) -> str:
         result += char
     return result
 
+
 def add_typos(text: str) -> str:
     TYPO_RATE = 0.02  # probability of introducing a typo (2%)
     result = ""
@@ -393,20 +405,22 @@ def add_typos(text: str) -> str:
             result += char
     return result
 
 
 def process_llm_response(text: str) -> List[str]:
     # processed_response = process_text_with_typos(content)
     if len(text) > 300:
         print(f"回复过长 ({len(text)} 字符),返回默认回复")
         return ['懒得说']
     # handle long messages
     sentences = split_into_sentences_w_remove_punctuation(add_typos(text))
     # check whether the split produced too many messages (more than 3)
     if len(sentences) > 4:
         print(f"分割后消息数量过多 ({len(sentences)} 条),返回默认回复")
         return [f'{global_config.BOT_NICKNAME}不知道哦']
 
     return sentences
 
 
 def calculate_typing_time(input_string: str, chinese_time: float = 0.2, english_time: float = 0.1) -> float:
     """
     Estimate the time needed to type the input string; Chinese and English characters take different amounts of time
@@ -419,32 +433,10 @@ def calculate_typing_time(input_string: str, chinese_time: float = 0.2, english_
         if '\u4e00' <= char <= '\u9fff':  # Chinese character?
             total_time += chinese_time
         else:  # other characters (e.g. English)
             total_time += english_time
     return total_time
 
 
-def find_similar_topics(message_txt: str, all_memory_topic: list, top_k: int = 5) -> list:
-    """Use the rerank API to find the topics most similar to the input text
-
-    Args:
-        message_txt: the input text
-        all_memory_topic: list of all memory topics
-        top_k: number of most similar topics to return
-
-    Returns:
-        list: the most similar topics with their similarity scores
-    """
-    if not all_memory_topic:
-        return []
-
-    try:
-        llm = LLM_request(model=global_config.rerank)
-        return llm.rerank_sync(message_txt, all_memory_topic, top_k)
-    except Exception as e:
-        print(f"重排序API调用出错: {str(e)}")
-        return []
-
 def cosine_similarity(v1, v2):
     """Compute cosine similarity"""
     dot_product = np.dot(v1, v2)
@@ -454,6 +446,7 @@ def cosine_similarity(v1, v2):
         return 0
     return dot_product / (norm1 * norm2)
 
+
 def text_to_vector(text):
     """Convert text to a term-frequency vector"""
     # tokenize
@@ -462,11 +455,12 @@ def text_to_vector(text):
     word_freq = Counter(words)
     return word_freq
 
+
 def find_similar_topics_simple(text: str, topics: list, top_k: int = 5) -> list:
     """Compute text similarity with plain cosine similarity"""
     # convert the input text to a term-frequency vector
     text_vector = text_to_vector(text)
 
     # compute the similarity of each topic
     similarities = []
     for topic in topics:
@@ -479,6 +473,6 @@ def find_similar_topics_simple(text: str, topics: list, top_k: int = 5) -> list:
         # compute the similarity
         similarity = cosine_similarity(v1, v2)
         similarities.append((topic, similarity))
 
     # sort by similarity in descending order and return the top k
     return sorted(similarities, key=lambda x: x[1], reverse=True)[:top_k]