feat: 重构完成开始测试debug
This commit is contained in:
@@ -11,7 +11,9 @@ from nonebot import get_driver
|
||||
from ..models.utils_model import LLM_request
|
||||
from ..utils.typo_generator import ChineseTypoGenerator
|
||||
from .config import global_config
|
||||
from .message_cq import Message
|
||||
from .message import MessageThinking, MessageRecv,MessageSending,MessageProcessBase,Message
|
||||
from .message_base import MessageBase,BaseMessageInfo,UserInfo,GroupInfo
|
||||
from .chat_stream import ChatStream
|
||||
|
||||
driver = get_driver()
|
||||
config = driver.config
|
||||
@@ -32,7 +34,7 @@ def db_message_to_str(message_dict: Dict) -> str:
|
||||
return result
|
||||
|
||||
|
||||
def is_mentioned_bot_in_message(message: Message) -> bool:
|
||||
def is_mentioned_bot_in_message(message: MessageRecv) -> bool:
|
||||
"""检查消息是否提到了机器人"""
|
||||
keywords = [global_config.BOT_NICKNAME]
|
||||
for keyword in keywords:
|
||||
@@ -41,15 +43,6 @@ def is_mentioned_bot_in_message(message: Message) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def is_mentioned_bot_in_txt(message: str) -> bool:
|
||||
"""检查消息是否提到了机器人"""
|
||||
keywords = [global_config.BOT_NICKNAME]
|
||||
for keyword in keywords:
|
||||
if keyword in message:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def get_embedding(text):
|
||||
"""获取文本的embedding向量"""
|
||||
llm = LLM_request(model=global_config.embedding)
|
||||
@@ -84,10 +77,10 @@ def get_cloest_chat_from_db(db, length: int, timestamp: str):
|
||||
|
||||
if closest_record and closest_record.get('memorized', 0) < 4:
|
||||
closest_time = closest_record['time']
|
||||
group_id = closest_record['group_id'] # 获取groupid
|
||||
chat_id = closest_record['chat_id'] # 获取groupid
|
||||
# 获取该时间戳之后的length条消息,且groupid相同
|
||||
chat_records = list(db.db.messages.find(
|
||||
{"time": {"$gt": closest_time}, "group_id": group_id}
|
||||
{"time": {"$gt": closest_time}, "chat_id": chat_id}
|
||||
).sort('time', 1).limit(length))
|
||||
|
||||
# 更新每条消息的memorized属性
|
||||
@@ -111,7 +104,7 @@ def get_cloest_chat_from_db(db, length: int, timestamp: str):
|
||||
return ''
|
||||
|
||||
|
||||
async def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
|
||||
async def get_recent_group_messages(db, chat_id:str, limit: int = 12) -> list:
|
||||
"""从数据库获取群组最近的消息记录
|
||||
|
||||
Args:
|
||||
@@ -125,35 +118,28 @@ async def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
|
||||
|
||||
# 从数据库获取最近消息
|
||||
recent_messages = list(db.db.messages.find(
|
||||
{"group_id": group_id},
|
||||
# {
|
||||
# "time": 1,
|
||||
# "user_id": 1,
|
||||
# "user_nickname": 1,
|
||||
# "message_id": 1,
|
||||
# "raw_message": 1,
|
||||
# "processed_text": 1
|
||||
# }
|
||||
{"chat_id": chat_id},
|
||||
).sort("time", -1).limit(limit))
|
||||
|
||||
if not recent_messages:
|
||||
return []
|
||||
|
||||
# 转换为 Message对象列表
|
||||
from .message_cq import Message
|
||||
message_objects = []
|
||||
for msg_data in recent_messages:
|
||||
try:
|
||||
chat_info=msg_data.get("chat_info",{})
|
||||
chat_stream=ChatStream.from_dict(chat_info)
|
||||
user_info=msg_data.get("user_info",{})
|
||||
user_info=UserInfo.from_dict(user_info)
|
||||
msg = Message(
|
||||
time=msg_data["time"],
|
||||
user_id=msg_data["user_id"],
|
||||
user_nickname=msg_data.get("user_nickname", ""),
|
||||
message_id=msg_data["message_id"],
|
||||
raw_message=msg_data["raw_message"],
|
||||
chat_stream=chat_stream,
|
||||
time=msg_data["time"],
|
||||
user_info=user_info,
|
||||
processed_plain_text=msg_data.get("processed_text", ""),
|
||||
group_id=group_id
|
||||
detailed_plain_text=msg_data.get("detailed_plain_text", "")
|
||||
)
|
||||
await msg.initialize()
|
||||
message_objects.append(msg)
|
||||
except KeyError:
|
||||
print("[WARNING] 数据库中存在无效的消息")
|
||||
@@ -164,13 +150,14 @@ async def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
|
||||
return message_objects
|
||||
|
||||
|
||||
def get_recent_group_detailed_plain_text(db, group_id: int, limit: int = 12, combine=False):
|
||||
def get_recent_group_detailed_plain_text(db, chat_stream_id: int, limit: int = 12, combine=False):
|
||||
recent_messages = list(db.db.messages.find(
|
||||
{"group_id": group_id},
|
||||
{"chat_id": chat_stream_id},
|
||||
{
|
||||
"time": 1, # 返回时间字段
|
||||
"user_id": 1, # 返回用户ID字段
|
||||
"user_nickname": 1, # 返回用户昵称字段
|
||||
"chat_id":1,
|
||||
"chat_info":1,
|
||||
"user_info": 1,
|
||||
"message_id": 1, # 返回消息ID字段
|
||||
"detailed_plain_text": 1 # 返回处理后的文本字段
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user