feat: 重构完成开始测试debug

This commit is contained in:
tcmofashi
2025-03-11 01:15:32 +08:00
parent 20b8778e2b
commit 7899e67cb2
13 changed files with 486 additions and 572 deletions

View File

@@ -11,7 +11,9 @@ from nonebot import get_driver
from ..models.utils_model import LLM_request
from ..utils.typo_generator import ChineseTypoGenerator
from .config import global_config
from .message_cq import Message
from .message import MessageThinking, MessageRecv,MessageSending,MessageProcessBase,Message
from .message_base import MessageBase,BaseMessageInfo,UserInfo,GroupInfo
from .chat_stream import ChatStream
driver = get_driver()
config = driver.config
@@ -32,7 +34,7 @@ def db_message_to_str(message_dict: Dict) -> str:
return result
def is_mentioned_bot_in_message(message: Message) -> bool:
def is_mentioned_bot_in_message(message: MessageRecv) -> bool:
"""检查消息是否提到了机器人"""
keywords = [global_config.BOT_NICKNAME]
for keyword in keywords:
@@ -41,15 +43,6 @@ def is_mentioned_bot_in_message(message: Message) -> bool:
return False
def is_mentioned_bot_in_txt(message: str) -> bool:
"""检查消息是否提到了机器人"""
keywords = [global_config.BOT_NICKNAME]
for keyword in keywords:
if keyword in message:
return True
return False
async def get_embedding(text):
"""获取文本的embedding向量"""
llm = LLM_request(model=global_config.embedding)
@@ -84,10 +77,10 @@ def get_cloest_chat_from_db(db, length: int, timestamp: str):
if closest_record and closest_record.get('memorized', 0) < 4:
closest_time = closest_record['time']
group_id = closest_record['group_id'] # 获取groupid
chat_id = closest_record['chat_id'] # 获取groupid
# 获取该时间戳之后的length条消息且groupid相同
chat_records = list(db.db.messages.find(
{"time": {"$gt": closest_time}, "group_id": group_id}
{"time": {"$gt": closest_time}, "chat_id": chat_id}
).sort('time', 1).limit(length))
# 更新每条消息的memorized属性
@@ -111,7 +104,7 @@ def get_cloest_chat_from_db(db, length: int, timestamp: str):
return ''
async def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
async def get_recent_group_messages(db, chat_id:str, limit: int = 12) -> list:
"""从数据库获取群组最近的消息记录
Args:
@@ -125,35 +118,28 @@ async def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
# 从数据库获取最近消息
recent_messages = list(db.db.messages.find(
{"group_id": group_id},
# {
# "time": 1,
# "user_id": 1,
# "user_nickname": 1,
# "message_id": 1,
# "raw_message": 1,
# "processed_text": 1
# }
{"chat_id": chat_id},
).sort("time", -1).limit(limit))
if not recent_messages:
return []
# 转换为 Message对象列表
from .message_cq import Message
message_objects = []
for msg_data in recent_messages:
try:
chat_info=msg_data.get("chat_info",{})
chat_stream=ChatStream.from_dict(chat_info)
user_info=msg_data.get("user_info",{})
user_info=UserInfo.from_dict(user_info)
msg = Message(
time=msg_data["time"],
user_id=msg_data["user_id"],
user_nickname=msg_data.get("user_nickname", ""),
message_id=msg_data["message_id"],
raw_message=msg_data["raw_message"],
chat_stream=chat_stream,
time=msg_data["time"],
user_info=user_info,
processed_plain_text=msg_data.get("processed_text", ""),
group_id=group_id
detailed_plain_text=msg_data.get("detailed_plain_text", "")
)
await msg.initialize()
message_objects.append(msg)
except KeyError:
print("[WARNING] 数据库中存在无效的消息")
@@ -164,13 +150,14 @@ async def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
return message_objects
def get_recent_group_detailed_plain_text(db, group_id: int, limit: int = 12, combine=False):
def get_recent_group_detailed_plain_text(db, chat_stream_id: int, limit: int = 12, combine=False):
recent_messages = list(db.db.messages.find(
{"group_id": group_id},
{"chat_id": chat_stream_id},
{
"time": 1, # 返回时间字段
"user_id": 1, # 返回用户ID字段
"user_nickname": 1, # 返回用户昵称字段
"chat_id":1,
"chat_info":1,
"user_info": 1,
"message_id": 1, # 返回消息ID字段
"detailed_plain_text": 1 # 返回处理后的文本字段
}