Merge remote-tracking branch 'upstream/debug' into tc_refractor

This commit is contained in:
Rikki
2025-03-11 06:01:54 +08:00
47 changed files with 3696 additions and 1392 deletions

View File

@@ -7,6 +7,7 @@ from typing import Dict, List
import jieba
import numpy as np
from nonebot import get_driver
from loguru import logger
from ..models.utils_model import LLM_request
from ..utils.typo_generator import ChineseTypoGenerator
@@ -21,16 +22,16 @@ config = driver.config
def db_message_to_str(message_dict: Dict) -> str:
print(f"message_dict: {message_dict}")
logger.debug(f"message_dict: {message_dict}")
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message_dict["time"]))
try:
name = "[(%s)%s]%s" % (
message_dict['user_id'], message_dict.get("user_nickname", ""), message_dict.get("user_cardname", ""))
message_dict['user_id'], message_dict.get("user_nickname", ""), message_dict.get("user_cardname", ""))
except:
name = message_dict.get("user_nickname", "") or f"用户{message_dict['user_id']}"
content = message_dict.get("processed_plain_text", "")
result = f"[{time_str}] {name}: {content}\n"
print(f"result: {result}")
logger.debug(f"result: {result}")
return result
@@ -71,37 +72,43 @@ def calculate_information_content(text):
def get_cloest_chat_from_db(db, length: int, timestamp: str):
"""从数据库中获取最接近指定时间戳的聊天记录,并记录读取次数"""
chat_text = ''
"""从数据库中获取最接近指定时间戳的聊天记录,并记录读取次数
Returns:
list: 消息记录字典列表,每个字典包含消息内容和时间信息
"""
chat_records = []
closest_record = db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)])
if closest_record and closest_record.get('memorized', 0) < 4:
if closest_record and closest_record.get('memorized', 0) < 4:
closest_time = closest_record['time']
chat_id = closest_record['chat_id'] # 获取groupid
# 获取该时间戳之后的length条消息且groupid相同
chat_records = list(db.db.messages.find(
{"time": {"$gt": closest_time}, "chat_id": chat_id}
).sort('time', 1).limit(length))
# 更新每条消息的memorized属性
for record in chat_records:
# 检查当前记录的memorized值
for record in records:
current_memorized = record.get('memorized', 0)
if current_memorized > 3:
# print(f"消息已读取3次跳过")
print("消息已读取3次跳过")
return ''
# 更新memorized值
db.db.messages.update_one(
{"_id": record["_id"]},
{"$set": {"memorized": current_memorized + 1}}
)
chat_text += record["detailed_plain_text"]
return chat_text
# print(f"消息已读取3次跳过")
return ''
# 添加到记录列表中
chat_records.append({
'text': record["detailed_plain_text"],
'time': record["time"],
'group_id': record["group_id"]
})
return chat_records
async def get_recent_group_messages(db, chat_id:str, limit: int = 12) -> list:
@@ -142,7 +149,7 @@ async def get_recent_group_messages(db, chat_id:str, limit: int = 12) -> list:
)
message_objects.append(msg)
except KeyError:
print("[WARNING] 数据库中存在无效的消息")
logger.warning("数据库中存在无效的消息")
continue
# 按时间正序排列
@@ -259,11 +266,10 @@ def split_into_sentences_w_remove_punctuation(text: str) -> List[str]:
sentence = sentence.replace('', ' ').replace(',', ' ')
sentences_done.append(sentence)
print(f"处理后的句子: {sentences_done}")
logger.info(f"处理后的句子: {sentences_done}")
return sentences_done
def random_remove_punctuation(text: str) -> str:
"""随机处理标点符号,模拟人类打字习惯
@@ -291,43 +297,70 @@ def random_remove_punctuation(text: str) -> str:
return result
def process_llm_response(text: str) -> List[str]:
# processed_response = process_text_with_typos(content)
if len(text) > 300:
print(f"回复过长 ({len(text)} 字符),返回默认回复")
if len(text) > 200:
logger.warning(f"回复过长 ({len(text)} 字符),返回默认回复")
return ['懒得说']
# 处理长消息
typo_generator = ChineseTypoGenerator(
error_rate=0.03,
min_freq=7,
tone_error_rate=0.2,
word_replace_rate=0.02
error_rate=global_config.chinese_typo_error_rate,
min_freq=global_config.chinese_typo_min_freq,
tone_error_rate=global_config.chinese_typo_tone_error_rate,
word_replace_rate=global_config.chinese_typo_word_replace_rate
)
typoed_text = typo_generator.create_typo_sentence(text)[0]
sentences = split_into_sentences_w_remove_punctuation(typoed_text)
split_sentences = split_into_sentences_w_remove_punctuation(text)
sentences = []
for sentence in split_sentences:
if global_config.chinese_typo_enable:
typoed_text, typo_corrections = typo_generator.create_typo_sentence(sentence)
sentences.append(typoed_text)
if typo_corrections:
sentences.append(typo_corrections)
else:
sentences.append(sentence)
# 检查分割后的消息数量是否过多超过3条
if len(sentences) > 4:
print(f"分割后消息数量过多 ({len(sentences)} 条),返回默认回复")
if len(sentences) > 5:
logger.warning(f"分割后消息数量过多 ({len(sentences)} 条),返回默认回复")
return [f'{global_config.BOT_NICKNAME}不知道哦']
return sentences
def calculate_typing_time(input_string: str, chinese_time: float = 0.2, english_time: float = 0.1) -> float:
def calculate_typing_time(input_string: str, chinese_time: float = 0.4, english_time: float = 0.2) -> float:
"""
计算输入字符串所需的时间,中文和英文字符有不同的输入时间
input_string (str): 输入的字符串
chinese_time (float): 中文字符的输入时间默认为0.3
english_time (float): 英文字符的输入时间默认为0.15
chinese_time (float): 中文字符的输入时间默认为0.2
english_time (float): 英文字符的输入时间默认为0.1秒
特殊情况:
- 如果只有一个中文字符将使用3倍的中文输入时间
- 在所有输入结束后额外加上回车时间0.3秒
"""
mood_manager = MoodManager.get_instance()
# 将0-1的唤醒度映射到-1到1
mood_arousal = mood_manager.current_mood.arousal
# 映射到0.5到2倍的速度系数
typing_speed_multiplier = 1.5 ** mood_arousal # 唤醒度为1时速度翻倍,为-1时速度减半
chinese_time *= 1 / typing_speed_multiplier
english_time *= 1 / typing_speed_multiplier
# 计算中文字符数
chinese_chars = sum(1 for char in input_string if '\u4e00' <= char <= '\u9fff')
# 如果只有一个中文字符使用3倍时间
if chinese_chars == 1 and len(input_string.strip()) == 1:
return chinese_time * 3 + 0.3 # 加上回车时间
# 正常计算所有字符的输入时间
total_time = 0.0
for char in input_string:
if '\u4e00' <= char <= '\u9fff': # 判断是否为中文字符
total_time += chinese_time
else: # 其他字符(如英文)
total_time += english_time
return total_time
return total_time + 0.3 # 加上回车时间
def cosine_similarity(v1, v2):