feat: 重构完成开始测试debug

This commit is contained in:
tcmofashi
2025-03-11 01:15:32 +08:00
parent 20b8778e2b
commit 7899e67cb2
13 changed files with 486 additions and 572 deletions

View File

@@ -1,6 +1,5 @@
import time
from random import random
from loguru import logger
from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent
@@ -11,25 +10,18 @@ from .cq_code import CQCode,cq_code_tool # 导入CQCode模块
from .emoji_manager import emoji_manager # 导入表情包管理器
from .llm_generator import ResponseGenerator
from .message import MessageSending, MessageRecv, MessageThinking, MessageSet
from .message import MessageSending, MessageRecv, MessageThinking, MessageSet
from .message_cq import (
MessageRecvCQ,
MessageSendCQ,
)
from .chat_stream import chat_manager
MessageRecvCQ,
MessageSendCQ,
)
from .chat_stream import chat_manager
from .message_sender import message_manager # 导入新的消息管理器
from .relationship_manager import relationship_manager
from .storage import MessageStorage
from .utils import calculate_typing_time, is_mentioned_bot_in_txt
from .utils_image import image_path_to_base64
from .utils_image import image_path_to_base64
from .willing_manager import willing_manager # 导入意愿管理器
from .message_base import UserInfo, GroupInfo, Seg
from .message_base import UserInfo, GroupInfo, Seg
class ChatBot:
def __init__(self):
@@ -53,24 +45,21 @@ class ChatBot:
self.bot = bot # 更新 bot 实例
group_info = await bot.get_group_info(group_id=event.group_id)
sender_info = await bot.get_group_member_info(group_id=event.group_id, user_id=event.user_id, no_cache=True)
await relationship_manager.update_relationship(user_id = event.user_id, data = sender_info)
await relationship_manager.update_relationship_value(user_id = event.user_id, relationship_value = 0.5)
message_cq=MessageRecvCQ(
# 白名单设定由nontbot侧完成
if event.group_id:
if event.group_id not in global_config.talk_allowed_groups:
return
if event.user_id in global_config.ban_user_id:
return
message_cq=MessageRecvCQ(
message_id=event.message_id,
user_id=event.user_id,
raw_message=str(event.original_message),
group_id=event.group_id,
user_id=event.user_id,
raw_message=str(event.original_message),
group_id=event.group_id,
reply_message=event.reply,
platform='qq'
)
@@ -78,37 +67,26 @@ class ChatBot:
# 进入maimbot
message=MessageRecv(**message_json)
await message.process()
groupinfo=message.message_info.group_info
userinfo=message.message_info.user_info
messageinfo=message.message_info
chat = await chat_manager.get_or_create_stream(platform=messageinfo.platform, user_info=userinfo, group_info=groupinfo)
# 消息过滤涉及到config有待更新
if groupinfo:
if groupinfo.group_id not in global_config.talk_allowed_groups:
return
else:
if userinfo:
if userinfo.user_id in []:
pass
else:
return
else:
return
if userinfo.user_id in global_config.ban_user_id:
return
chat = await chat_manager.get_or_create_stream(platform=messageinfo.platform, user_info=userinfo, group_info=groupinfo)
await relationship_manager.update_relationship(chat_stream=chat,)
await relationship_manager.update_relationship_value(chat_stream=chat, relationship_value = 0.5)
await message.process()
# 过滤词
for word in global_config.ban_words:
if word in message.processed_plain_text:
logger.info(f"\033[1;32m[{groupinfo.group_name}]{userinfo.user_nickname}:\033[0m {message.processed_plain_text}")
if word in message.processed_plain_text:
logger.info(f"\033[1;32m[{groupinfo.group_name}]{userinfo.user_nickname}:\033[0m {message.processed_plain_text}")
logger.info(f"\033[1;32m[过滤词识别]\033[0m 消息中含有{word}filtered")
return
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(messageinfo.time))
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(messageinfo.time))
@@ -130,20 +108,13 @@ class ChatBot:
is_emoji=message.is_emoji,
interested_rate=interested_rate
)
current_willing = willing_manager.get_willing(
chat_stream=chat
)
current_willing = willing_manager.get_willing(chat_stream=chat)
print(f"\033[1;32m[{current_time}][{chat.group_info.group_name}]{chat.user_info.user_nickname}:\033[0m {message.processed_plain_text}\033[1;36m[回复意愿:{current_willing:.2f}][概率:{reply_probability * 100:.1f}%]\033[0m")
response = None
if random() < reply_probability:
bot_user_info=UserInfo(
user_id=global_config.BOT_QQ,
user_nickname=global_config.BOT_NICKNAME,
platform=messageinfo.platform
)
bot_user_info=UserInfo(
user_id=global_config.BOT_QQ,
user_nickname=global_config.BOT_NICKNAME,
@@ -151,22 +122,16 @@ class ChatBot:
)
tinking_time_point = round(time.time(), 2)
think_id = 'mt' + str(tinking_time_point)
thinking_message = MessageThinking.from_chat_stream(
chat_stream=chat,
thinking_message = MessageThinking(
message_id=think_id,
reply=message
)
thinking_message = MessageThinking.from_chat_stream(
chat_stream=chat,
message_id=think_id,
bot_user_info=bot_user_info,
reply=message
)
message_manager.add_message(thinking_message)
willing_manager.change_reply_willing_sent(
chat_stream=chat
)
willing_manager.change_reply_willing_sent(chat)
response,raw_content = await self.gpt.generate_response(message)
@@ -201,18 +166,11 @@ class ChatBot:
accu_typing_time += typing_time
timepoint = tinking_time_point + accu_typing_time
message_segment = Seg(type='text', data=msg)
bot_message = MessageSending(
message_segment = Seg(type='text', data=msg)
bot_message = MessageSending(
message_id=think_id,
chat_stream=chat,
message_segment=message_segment,
reply=message,
is_head=not mark_head,
is_emoji=False
)
chat_stream=chat,
bot_user_info=bot_user_info,
message_segment=message_segment,
reply=message,
is_head=not mark_head,
@@ -235,7 +193,6 @@ class ChatBot:
if emoji_raw != None:
emoji_path,discription = emoji_raw
emoji_cq = image_path_to_base64(emoji_path)
emoji_cq = image_path_to_base64(emoji_path)
if random() < 0.5:
@@ -247,15 +204,7 @@ class ChatBot:
bot_message = MessageSending(
message_id=think_id,
chat_stream=chat,
message_segment=message_segment,
reply=message,
is_head=False,
is_emoji=True
)
message_segment = Seg(type='emoji', data=emoji_cq)
bot_message = MessageSending(
message_id=think_id,
chat_stream=chat,
bot_user_info=bot_user_info,
message_segment=message_segment,
reply=message,
is_head=False,
@@ -273,20 +222,12 @@ class ChatBot:
'fearful': -0.7,
'neutral': 0.1
}
await relationship_manager.update_relationship_value(message.user_id, relationship_value=valuedict[emotion[0]])
await relationship_manager.update_relationship_value(chat_stream=chat, relationship_value=valuedict[emotion[0]])
# 使用情绪管理器更新情绪
self.mood_manager.update_mood_from_emotion(emotion[0], global_config.mood_intensity_factor)
willing_manager.change_reply_willing_after_sent(
platform=messageinfo.platform,
user_info=userinfo,
group_info=groupinfo
)
willing_manager.change_reply_willing_after_sent(
platform=messageinfo.platform,
user_info=userinfo,
group_info=groupinfo
chat_stream=chat
)
# 创建全局ChatBot实例

View File

@@ -1,55 +1,67 @@
import time
import asyncio
from typing import Optional, Dict, Tuple
import hashlib
import time
from typing import Dict, Optional
from loguru import logger
from ...common.database import Database
from .message_base import UserInfo, GroupInfo
from .message_base import GroupInfo, UserInfo
class ChatStream:
"""聊天流对象,存储一个完整的聊天上下文"""
def __init__(self,
stream_id: str,
platform: str,
user_info: UserInfo,
group_info: Optional[GroupInfo] = None,
data: dict = None):
def __init__(
self,
stream_id: str,
platform: str,
user_info: UserInfo,
group_info: Optional[GroupInfo] = None,
data: dict = None,
):
self.stream_id = stream_id
self.platform = platform
self.user_info = user_info
self.group_info = group_info
self.create_time = data.get('create_time', int(time.time())) if data else int(time.time())
self.last_active_time = data.get('last_active_time', self.create_time) if data else self.create_time
self.create_time = (
data.get("create_time", int(time.time())) if data else int(time.time())
)
self.last_active_time = (
data.get("last_active_time", self.create_time) if data else self.create_time
)
self.saved = False
def to_dict(self) -> dict:
"""转换为字典格式"""
result = {
'stream_id': self.stream_id,
'platform': self.platform,
'user_info': self.user_info.to_dict() if self.user_info else None,
'group_info': self.group_info.to_dict() if self.group_info else None,
'create_time': self.create_time,
'last_active_time': self.last_active_time
"stream_id": self.stream_id,
"platform": self.platform,
"user_info": self.user_info.to_dict() if self.user_info else None,
"group_info": self.group_info.to_dict() if self.group_info else None,
"create_time": self.create_time,
"last_active_time": self.last_active_time,
}
return result
@classmethod
def from_dict(cls, data: dict) -> 'ChatStream':
def from_dict(cls, data: dict) -> "ChatStream":
"""从字典创建实例"""
user_info = UserInfo(**data.get('user_info', {})) if data.get('user_info') else None
group_info = GroupInfo(**data.get('group_info', {})) if data.get('group_info') else None
user_info = (
UserInfo(**data.get("user_info", {})) if data.get("user_info") else None
)
group_info = (
GroupInfo(**data.get("group_info", {})) if data.get("group_info") else None
)
return cls(
stream_id=data['stream_id'],
platform=data['platform'],
stream_id=data["stream_id"],
platform=data["platform"],
user_info=user_info,
group_info=group_info,
data=data
data=data,
)
def update_active_time(self):
"""更新最后活跃时间"""
self.last_active_time = int(time.time())
@@ -58,14 +70,15 @@ class ChatStream:
class ChatManager:
"""聊天管理器,管理所有聊天流"""
_instance = None
_initialized = False
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
if not self._initialized:
self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream
@@ -76,7 +89,7 @@ class ChatManager:
asyncio.create_task(self._initialize())
# 启动自动保存任务
asyncio.create_task(self._auto_save_task())
async def _initialize(self):
"""异步初始化"""
try:
@@ -84,7 +97,7 @@ class ChatManager:
logger.success(f"聊天管理器已启动,已加载 {len(self.streams)} 个聊天流")
except Exception as e:
logger.error(f"聊天管理器启动失败: {str(e)}")
async def _auto_save_task(self):
"""定期自动保存所有聊天流"""
while True:
@@ -94,49 +107,48 @@ class ChatManager:
logger.info("聊天流自动保存完成")
except Exception as e:
logger.error(f"聊天流自动保存失败: {str(e)}")
def _ensure_collection(self):
"""确保数据库集合存在并创建索引"""
if 'chat_streams' not in self.db.db.list_collection_names():
self.db.db.create_collection('chat_streams')
if "chat_streams" not in self.db.db.list_collection_names():
self.db.db.create_collection("chat_streams")
# 创建索引
self.db.db.chat_streams.create_index([('stream_id', 1)], unique=True)
self.db.db.chat_streams.create_index([
('platform', 1),
('user_info.user_id', 1),
('group_info.group_id', 1)
])
def _generate_stream_id(self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None) -> str:
self.db.db.chat_streams.create_index([("stream_id", 1)], unique=True)
self.db.db.chat_streams.create_index(
[("platform", 1), ("user_info.user_id", 1), ("group_info.group_id", 1)]
)
def _generate_stream_id(
self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
) -> str:
"""生成聊天流唯一ID"""
# 组合关键信息
components = [
platform,
str(user_info.user_id),
str(group_info.group_id) if group_info else 'private'
str(group_info.group_id) if group_info else "private",
]
# 使用MD5生成唯一ID
key = '_'.join(components)
key = "_".join(components)
return hashlib.md5(key.encode()).hexdigest()
async def get_or_create_stream(self,
platform: str,
user_info: UserInfo,
group_info: Optional[GroupInfo] = None) -> ChatStream:
async def get_or_create_stream(
self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
) -> ChatStream:
"""获取或创建聊天流
Args:
platform: 平台标识
user_info: 用户信息
group_info: 群组信息(可选)
Returns:
ChatStream: 聊天流对象
"""
# 生成stream_id
stream_id = self._generate_stream_id(platform, user_info, group_info)
# 检查内存中是否存在
if stream_id in self.streams:
stream = self.streams[stream_id]
@@ -146,9 +158,9 @@ class ChatManager:
stream.group_info = group_info
stream.update_active_time()
return stream
# 检查数据库中是否存在
data = self.db.db.chat_streams.find_one({'stream_id': stream_id})
data = self.db.db.chat_streams.find_one({"stream_id": stream_id})
if data:
stream = ChatStream.from_dict(data)
# 更新用户信息和群组信息
@@ -162,41 +174,38 @@ class ChatManager:
stream_id=stream_id,
platform=platform,
user_info=user_info,
group_info=group_info
group_info=group_info,
)
# 保存到内存和数据库
self.streams[stream_id] = stream
await self._save_stream(stream)
return stream
def get_stream(self, stream_id: str) -> Optional[ChatStream]:
"""通过stream_id获取聊天流"""
return self.streams.get(stream_id)
def get_stream_by_info(self,
platform: str,
user_info: UserInfo,
group_info: Optional[GroupInfo] = None) -> Optional[ChatStream]:
def get_stream_by_info(
self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
) -> Optional[ChatStream]:
"""通过信息获取聊天流"""
stream_id = self._generate_stream_id(platform, user_info, group_info)
return self.streams.get(stream_id)
async def _save_stream(self, stream: ChatStream):
"""保存聊天流到数据库"""
if not stream.saved:
self.db.db.chat_streams.update_one(
{'stream_id': stream.stream_id},
{'$set': stream.to_dict()},
upsert=True
{"stream_id": stream.stream_id}, {"$set": stream.to_dict()}, upsert=True
)
stream.saved = True
async def _save_all_streams(self):
"""保存所有聊天流"""
for stream in self.streams.values():
await self._save_stream(stream)
async def load_all_streams(self):
"""从数据库加载所有聊天流"""
all_streams = self.db.db.chat_streams.find({})

View File

@@ -3,23 +3,22 @@ import html
import os
import time
from dataclasses import dataclass
from typing import Dict, Optional, List, Union
from typing import Dict, List, Optional, Union
import requests
# 解析各种CQ码
# 包含CQ码类
import urllib3
from loguru import logger
from nonebot import get_driver
from urllib3.util import create_urllib3_context
from loguru import logger
from ..models.utils_model import LLM_request
from .config import global_config
from .mapper import emojimapper
from .utils_image import image_manager
from .utils_user import get_user_nickname
from .message_base import Seg
from .utils_user import get_user_nickname
driver = get_driver()
config = driver.config
@@ -37,21 +36,25 @@ class TencentSSLAdapter(requests.adapters.HTTPAdapter):
def init_poolmanager(self, connections, maxsize, block=False):
self.poolmanager = urllib3.poolmanager.PoolManager(
num_pools=connections, maxsize=maxsize,
block=block, ssl_context=self.ssl_context)
num_pools=connections,
maxsize=maxsize,
block=block,
ssl_context=self.ssl_context,
)
@dataclass
class CQCode:
"""
CQ码数据类用于存储和处理CQ码
属性:
type: CQ码类型'image', 'at', 'face'等)
params: CQ码的参数字典
raw_code: 原始CQ码字符串
translated_segments: 经过处理后的Seg对象列表
"""
type: str
params: Dict[str, str]
group_id: int
@@ -65,77 +68,52 @@ class CQCode:
def __post_init__(self):
"""初始化LLM实例"""
self._llm = LLM_request(model=global_config.vlm, temperature=0.4, max_tokens=300)
self._llm = LLM_request(
model=global_config.vlm, temperature=0.4, max_tokens=300
)
def translate(self):
"""根据CQ码类型进行相应的翻译处理转换为Seg对象"""
if self.type == 'text':
if self.type == "text":
self.translated_segments = Seg(
type='text',
data=self.params.get('text', '')
type="text", data=self.params.get("text", "")
)
elif self.type == 'image':
elif self.type == "image":
base64_data = self.translate_image()
if base64_data:
if self.params.get('sub_type') == '0':
self.translated_segments = Seg(
type='image',
data=base64_data
)
if self.params.get("sub_type") == "0":
self.translated_segments = Seg(type="image", data=base64_data)
else:
self.translated_segments = Seg(
type='emoji',
data=base64_data
)
self.translated_segments = Seg(type="emoji", data=base64_data)
else:
self.translated_segments = Seg(
type='text',
data='[图片]'
)
elif self.type == 'at':
user_nickname = get_user_nickname(self.params.get('qq', ''))
self.translated_segments = Seg(type="text", data="[图片]")
elif self.type == "at":
user_nickname = get_user_nickname(self.params.get("qq", ""))
self.translated_segments = Seg(
type='text',
data=f"[@{user_nickname or '某人'}]"
type="text", data=f"[@{user_nickname or '某人'}]"
)
elif self.type == 'reply':
elif self.type == "reply":
reply_segments = self.translate_reply()
if reply_segments:
self.translated_segments = Seg(
type='seglist',
data=reply_segments
)
self.translated_segments = Seg(type="seglist", data=reply_segments)
else:
self.translated_segments = Seg(
type='text',
data='[回复某人消息]'
)
elif self.type == 'face':
face_id = self.params.get('id', '')
self.translated_segments = Seg(type="text", data="[回复某人消息]")
elif self.type == "face":
face_id = self.params.get("id", "")
self.translated_segments = Seg(
type='text',
data=f"[{emojimapper.get(int(face_id), '表情')}]"
type="text", data=f"[{emojimapper.get(int(face_id), '表情')}]"
)
elif self.type == 'forward':
elif self.type == "forward":
forward_segments = self.translate_forward()
if forward_segments:
self.translated_segments = Seg(
type='seglist',
data=forward_segments
)
self.translated_segments = Seg(type="seglist", data=forward_segments)
else:
self.translated_segments = Seg(
type='text',
data='[转发消息]'
)
self.translated_segments = Seg(type="text", data="[转发消息]")
else:
self.translated_segments = Seg(
type='text',
data=f"[{self.type}]"
)
self.translated_segments = Seg(type="text", data=f"[{self.type}]")
def get_img(self):
'''
"""
headers = {
'User-Agent': 'QQ/8.9.68.11565 CFNetwork/1220.1 Darwin/20.3.0',
'Accept': 'image/*;q=0.8',
@@ -144,18 +122,18 @@ class CQCode:
'Cache-Control': 'no-cache',
'Pragma': 'no-cache'
}
'''
"""
# 腾讯专用请求头配置
headers = {
'User-Agent': 'Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.87 Safari/537.36',
'Accept': 'text/html, application/xhtml xml, */*',
'Accept-Encoding': 'gbk, GB2312',
'Accept-Language': 'zh-cn',
'Content-Type': 'application/x-www-form-urlencoded',
'Cache-Control': 'no-cache'
"User-Agent": "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.87 Safari/537.36",
"Accept": "text/html, application/xhtml xml, */*",
"Accept-Encoding": "gbk, GB2312",
"Accept-Language": "zh-cn",
"Content-Type": "application/x-www-form-urlencoded",
"Cache-Control": "no-cache",
}
url = html.unescape(self.params['url'])
if not url.startswith(('http://', 'https://')):
url = html.unescape(self.params["url"])
if not url.startswith(("http://", "https://")):
return None
# 创建专用会话
@@ -171,30 +149,30 @@ class CQCode:
headers=headers,
timeout=15,
allow_redirects=True,
stream=True # 流式传输避免大内存问题
stream=True, # 流式传输避免大内存问题
)
# 腾讯服务器特殊状态码处理
if response.status_code == 400 and 'multimedia.nt.qq.com.cn' in url:
if response.status_code == 400 and "multimedia.nt.qq.com.cn" in url:
return None
if response.status_code != 200:
raise requests.exceptions.HTTPError(f"HTTP {response.status_code}")
# 验证内容类型
content_type = response.headers.get('Content-Type', '')
if not content_type.startswith('image/'):
content_type = response.headers.get("Content-Type", "")
if not content_type.startswith("image/"):
raise ValueError(f"非图片内容类型: {content_type}")
# 转换为Base64
image_base64 = base64.b64encode(response.content).decode('utf-8')
image_base64 = base64.b64encode(response.content).decode("utf-8")
self.image_base64 = image_base64
return image_base64
except (requests.exceptions.SSLError, requests.exceptions.HTTPError) as e:
if retry == max_retries - 1:
print(f"\033[1;31m[致命错误]\033[0m 最终请求失败: {str(e)}")
time.sleep(1.5 ** retry) # 指数退避
time.sleep(1.5**retry) # 指数退避
except Exception as e:
print(f"\033[1;33m[未知错误]\033[0m {str(e)}")
@@ -202,21 +180,21 @@ class CQCode:
return None
def translate_image(self) -> Optional[str]:
"""处理图片类型的CQ码返回base64字符串"""
if 'url' not in self.params:
if "url" not in self.params:
return None
return self.get_img()
def translate_forward(self) -> Optional[List[Seg]]:
"""处理转发消息返回Seg列表"""
try:
if 'content' not in self.params:
if "content" not in self.params:
return None
content = self.unescape(self.params['content'])
content = self.unescape(self.params["content"])
import ast
try:
messages = ast.literal_eval(content)
except ValueError as e:
@@ -225,46 +203,52 @@ class CQCode:
formatted_segments = []
for msg in messages:
sender = msg.get('sender', {})
nickname = sender.get('card') or sender.get('nickname', '未知用户')
raw_message = msg.get('raw_message', '')
message_array = msg.get('message', [])
sender = msg.get("sender", {})
nickname = sender.get("card") or sender.get("nickname", "未知用户")
raw_message = msg.get("raw_message", "")
message_array = msg.get("message", [])
if message_array and isinstance(message_array, list):
for message_part in message_array:
if message_part.get('type') == 'forward':
content_seg = Seg(type='text', data='[转发消息]')
if message_part.get("type") == "forward":
content_seg = Seg(type="text", data="[转发消息]")
break
else:
if raw_message:
from .message_cq import MessageRecvCQ
message_obj = MessageRecvCQ(
user_id=msg.get('user_id', 0),
message_id=msg.get('message_id', 0),
user_id=msg.get("user_id", 0),
message_id=msg.get("message_id", 0),
raw_message=raw_message,
plain_text=raw_message,
group_id=msg.get('group_id', 0)
group_id=msg.get("group_id", 0),
)
content_seg = Seg(
type="seglist", data=message_obj.message_segments
)
content_seg = Seg(type='seglist', data=message_obj.message_segments)
else:
content_seg = Seg(type='text', data='[空消息]')
content_seg = Seg(type="text", data="[空消息]")
else:
if raw_message:
from .message_cq import MessageRecvCQ
message_obj = MessageRecvCQ(
user_id=msg.get('user_id', 0),
message_id=msg.get('message_id', 0),
user_id=msg.get("user_id", 0),
message_id=msg.get("message_id", 0),
raw_message=raw_message,
plain_text=raw_message,
group_id=msg.get('group_id', 0)
group_id=msg.get("group_id", 0),
)
content_seg = Seg(
type="seglist", data=message_obj.message_segments
)
content_seg = Seg(type='seglist', data=message_obj.message_segments)
else:
content_seg = Seg(type='text', data='[空消息]')
content_seg = Seg(type="text", data="[空消息]")
formatted_segments.append(Seg(type='text', data=f"{nickname}: "))
formatted_segments.append(Seg(type="text", data=f"{nickname}: "))
formatted_segments.append(content_seg)
formatted_segments.append(Seg(type='text', data='\n'))
formatted_segments.append(Seg(type="text", data="\n"))
return formatted_segments
@@ -275,6 +259,7 @@ class CQCode:
def translate_reply(self) -> Optional[List[Seg]]:
"""处理回复类型的CQ码返回Seg列表"""
from .message_cq import MessageRecvCQ
if self.reply_message is None:
return None
@@ -283,17 +268,26 @@ class CQCode:
user_id=self.reply_message.sender.user_id,
message_id=self.reply_message.message_id,
raw_message=str(self.reply_message.message),
group_id=self.group_id
group_id=self.group_id,
)
segments = []
if message_obj.user_id == global_config.BOT_QQ:
segments.append(Seg(type='text', data=f"[回复 {global_config.BOT_NICKNAME} 的消息: "))
segments.append(
Seg(
type="text", data=f"[回复 {global_config.BOT_NICKNAME} 的消息: "
)
)
else:
segments.append(Seg(type='text', data=f"[回复 {self.reply_message.sender.nickname} 的消息: "))
segments.append(Seg(type='seglist', data=message_obj.message_segments))
segments.append(Seg(type='text', data="]"))
segments.append(
Seg(
type="text",
data=f"[回复 {self.reply_message.sender.nickname} 的消息: ",
)
)
segments.append(Seg(type="seglist", data=message_obj.message_segments))
segments.append(Seg(type="text", data="]"))
return segments
else:
return None
@@ -301,12 +295,12 @@ class CQCode:
@staticmethod
def unescape(text: str) -> str:
"""反转义CQ码中的特殊字符"""
return text.replace('&#44;', ',') \
.replace('&#91;', '[') \
.replace('&#93;', ']') \
.replace('&amp;', '&')
return (
text.replace("&#44;", ",")
.replace("&#91;", "[")
.replace("&#93;", "]")
.replace("&amp;", "&")
)
class CQCode_tool:
@@ -314,29 +308,25 @@ class CQCode_tool:
def cq_from_dict_to_class(cq_code: Dict, reply: Optional[Dict] = None) -> CQCode:
"""
将CQ码字典转换为CQCode对象
Args:
cq_code: CQ码字典
reply: 回复消息的字典(可选)
Returns:
CQCode对象
"""
# 处理字典形式的CQ码
# 从cq_code字典中获取type字段的值,如果不存在则默认为'text'
cq_type = cq_code.get('type', 'text')
cq_type = cq_code.get("type", "text")
params = {}
if cq_type == 'text':
params['text'] = cq_code.get('data', {}).get('text', '')
if cq_type == "text":
params["text"] = cq_code.get("data", {}).get("text", "")
else:
params = cq_code.get('data', {})
params = cq_code.get("data", {})
instance = CQCode(
type=cq_type,
params=params,
group_id=0,
user_id=0,
reply_message=reply
type=cq_type, params=params, group_id=0, user_id=0, reply_message=reply
)
# 进行翻译处理
@@ -353,7 +343,7 @@ class CQCode_tool:
回复CQ码字符串
"""
return f"[CQ:reply,id={message_id}]"
@staticmethod
def create_emoji_cq(file_path: str) -> str:
"""
@@ -366,13 +356,15 @@ class CQCode_tool:
# 确保使用绝对路径
abs_path = os.path.abspath(file_path)
# 转义特殊字符
escaped_path = abs_path.replace('&', '&amp;') \
.replace('[', '&#91;') \
.replace(']', '&#93;') \
.replace(',', '&#44;')
escaped_path = (
abs_path.replace("&", "&amp;")
.replace("[", "&#91;")
.replace("]", "&#93;")
.replace(",", "&#44;")
)
# 生成CQ码设置sub_type=1表示这是表情包
return f"[CQ:image,file=file:///{escaped_path},sub_type=1]"
@staticmethod
def create_emoji_cq_base64(base64_data: str) -> str:
"""
@@ -383,15 +375,14 @@ class CQCode_tool:
表情包CQ码字符串
"""
# 转义base64数据
escaped_base64 = base64_data.replace('&', '&amp;') \
.replace('[', '&#91;') \
.replace(']', '&#93;') \
.replace(',', '&#44;')
escaped_base64 = (
base64_data.replace("&", "&amp;")
.replace("[", "&#91;")
.replace("]", "&#93;")
.replace(",", "&#44;")
)
# 生成CQ码设置sub_type=1表示这是表情包
return f"[CQ:image,file=base64://{escaped_base64},sub_type=1]"
cq_code_tool = CQCode_tool()

View File

@@ -1,11 +1,11 @@
import asyncio
import base64
import hashlib
import os
import random
import time
import traceback
from typing import Optional, Tuple
import base64
import hashlib
from loguru import logger
from nonebot import get_driver
@@ -13,9 +13,8 @@ from nonebot import get_driver
from ...common.database import Database
from ..chat.config import global_config
from ..chat.utils import get_embedding
from ..chat.utils_image import image_path_to_base64
from ..chat.utils_image import ImageManager, image_path_to_base64
from ..models.utils_model import LLM_request
from ..chat.utils_image import ImageManager
driver = get_driver()
config = driver.config
@@ -78,7 +77,6 @@ class EmojiManager:
if 'emoji' not in self.db.db.list_collection_names():
self.db.db.create_collection('emoji')
self.db.db.emoji.create_index([('embedding', '2dsphere')])
self.db.db.emoji.create_index([('tags', 1)])
self.db.db.emoji.create_index([('filename', 1)], unique=True)
def record_usage(self, emoji_id: str):

View File

@@ -7,7 +7,7 @@ from nonebot import get_driver
from ...common.database import Database
from ..models.utils_model import LLM_request
from .config import global_config
from .message_cq import Message
from .message import MessageRecv, MessageThinking, MessageSending
from .prompt_builder import prompt_builder
from .relationship_manager import relationship_manager
from .utils import process_llm_response
@@ -18,58 +18,88 @@ config = driver.config
class ResponseGenerator:
def __init__(self):
self.model_r1 = LLM_request(model=global_config.llm_reasoning, temperature=0.7,max_tokens=1000,stream=True)
self.model_v3 = LLM_request(model=global_config.llm_normal, temperature=0.7,max_tokens=1000)
self.model_r1_distill = LLM_request(model=global_config.llm_reasoning_minor, temperature=0.7,max_tokens=1000)
self.model_v25 = LLM_request(model=global_config.llm_normal_minor, temperature=0.7,max_tokens=1000)
self.model_r1 = LLM_request(
model=global_config.llm_reasoning,
temperature=0.7,
max_tokens=1000,
stream=True,
)
self.model_v3 = LLM_request(
model=global_config.llm_normal, temperature=0.7, max_tokens=1000
)
self.model_r1_distill = LLM_request(
model=global_config.llm_reasoning_minor, temperature=0.7, max_tokens=1000
)
self.model_v25 = LLM_request(
model=global_config.llm_normal_minor, temperature=0.7, max_tokens=1000
)
self.db = Database.get_instance()
self.current_model_type = 'r1' # 默认使用 R1
self.current_model_type = "r1" # 默认使用 R1
async def generate_response(self, message: Message) -> Optional[Union[str, List[str]]]:
async def generate_response(
self, message: MessageThinking
) -> Optional[Union[str, List[str]]]:
"""根据当前模型类型选择对应的生成函数"""
# 从global_config中获取模型概率值并选择模型
rand = random.random()
if rand < global_config.MODEL_R1_PROBABILITY:
self.current_model_type = 'r1'
self.current_model_type = "r1"
current_model = self.model_r1
elif rand < global_config.MODEL_R1_PROBABILITY + global_config.MODEL_V3_PROBABILITY:
self.current_model_type = 'v3'
elif (
rand
< global_config.MODEL_R1_PROBABILITY + global_config.MODEL_V3_PROBABILITY
):
self.current_model_type = "v3"
current_model = self.model_v3
else:
self.current_model_type = 'r1_distill'
self.current_model_type = "r1_distill"
current_model = self.model_r1_distill
print(f"+++++++++++++++++{global_config.BOT_NICKNAME}{self.current_model_type}思考中+++++++++++++++++")
model_response = await self._generate_response_with_model(message, current_model)
raw_content=model_response
print(
f"+++++++++++++++++{global_config.BOT_NICKNAME}{self.current_model_type}思考中+++++++++++++++++"
)
model_response = await self._generate_response_with_model(
message, current_model
)
raw_content = model_response
if model_response:
print(f'{global_config.BOT_NICKNAME}的回复是:{model_response}')
print(f"{global_config.BOT_NICKNAME}的回复是:{model_response}")
model_response = await self._process_response(model_response)
if model_response:
return model_response, raw_content
return None, raw_content
return model_response ,raw_content
return None,raw_content
async def _generate_response_with_model(self, message: Message, model: LLM_request) -> Optional[str]:
async def _generate_response_with_model(
self, message: MessageThinking, model: LLM_request
) -> Optional[str]:
"""使用指定的模型生成回复"""
sender_name = message.user_nickname or f"用户{message.user_id}"
if message.user_cardname:
sender_name=f"[({message.user_id}){message.user_nickname}]{message.user_cardname}"
sender_name = (
message.chat_stream.user_info.user_nickname
or f"用户{message.chat_stream.user_info.user_id}"
)
if message.chat_stream.user_info.user_cardname:
sender_name = f"[({message.chat_stream.user_info.user_id}){message.chat_stream.user_info.user_nickname}]{message.chat_stream.user_info.user_cardname}"
# 获取关系值
relationship_value = relationship_manager.get_relationship(message.user_id).relationship_value if relationship_manager.get_relationship(message.user_id) else 0.0
relationship_value = (
relationship_manager.get_relationship(
message.chat_stream
).relationship_value
if relationship_manager.get_relationship(message.chat_stream)
else 0.0
)
if relationship_value != 0.0:
# print(f"\033[1;32m[关系管理]\033[0m 回复中_当前关系值: {relationship_value}")
pass
# 构建prompt
prompt, prompt_check = await prompt_builder._build_prompt(
message_txt=message.processed_plain_text,
sender_name=sender_name,
relationship_value=relationship_value,
group_id=message.group_id
stream_id=message.chat_stream.stream_id,
)
# 读空气模块 简化逻辑,先停用
@@ -95,7 +125,7 @@ class ResponseGenerator:
except Exception as e:
print(f"生成回复时出错: {e}")
return None
# 保存到数据库
self._save_to_db(
message=message,
@@ -107,54 +137,71 @@ class ResponseGenerator:
reasoning_content=reasoning_content,
# reasoning_content_check=reasoning_content_check if global_config.enable_kuuki_read else ""
)
return content
# def _save_to_db(self, message: Message, sender_name: str, prompt: str, prompt_check: str,
# content: str, content_check: str, reasoning_content: str, reasoning_content_check: str):
def _save_to_db(self, message: Message, sender_name: str, prompt: str, prompt_check: str,
content: str, reasoning_content: str,):
def _save_to_db(
self,
message: Message,
sender_name: str,
prompt: str,
prompt_check: str,
content: str,
reasoning_content: str,
):
"""保存对话记录到数据库"""
self.db.db.reasoning_logs.insert_one({
'time': time.time(),
'group_id': message.group_id,
'user': sender_name,
'message': message.processed_plain_text,
'model': self.current_model_type,
# 'reasoning_check': reasoning_content_check,
# 'response_check': content_check,
'reasoning': reasoning_content,
'response': content,
'prompt': prompt,
'prompt_check': prompt_check
})
self.db.db.reasoning_logs.insert_one(
{
"time": time.time(),
"group_id": message.group_id,
"user": sender_name,
"message": message.processed_plain_text,
"model": self.current_model_type,
# 'reasoning_check': reasoning_content_check,
# 'response_check': content_check,
"reasoning": reasoning_content,
"response": content,
"prompt": prompt,
"prompt_check": prompt_check,
}
)
async def _get_emotion_tags(self, content: str) -> List[str]:
"""提取情感标签"""
try:
prompt = f'''请从以下内容中,从"happy,angry,sad,surprised,disgusted,fearful,neutral"中选出最匹配的1个情感标签并输出
prompt = f"""请从以下内容中,从"happy,angry,sad,surprised,disgusted,fearful,neutral"中选出最匹配的1个情感标签并输出
只输出标签就好,不要输出其他内容:
内容:{content}
输出:
'''
"""
content, _ = await self.model_v25.generate_response(prompt)
content=content.strip()
if content in ['happy','angry','sad','surprised','disgusted','fearful','neutral']:
content = content.strip()
if content in [
"happy",
"angry",
"sad",
"surprised",
"disgusted",
"fearful",
"neutral",
]:
return [content]
else:
return ["neutral"]
except Exception as e:
print(f"获取情感标签时出错: {e}")
return ["neutral"]
async def _process_response(self, content: str) -> Tuple[List[str], List[str]]:
"""处理响应内容,返回处理后的内容和情感标签"""
if not content:
return None, []
processed_response = process_llm_response(content)
return processed_response

View File

@@ -5,7 +5,6 @@ from typing import Dict, ForwardRef, List, Optional, Union
import urllib3
from loguru import logger
from .utils_user import get_groupname, get_user_cardname, get_user_nickname
from .utils_image import image_manager
from .message_base import Seg, GroupInfo, UserInfo, BaseMessageInfo, MessageBase
from .chat_stream import ChatStream
@@ -108,25 +107,32 @@ class MessageRecv(MessageBase):
else f"{user_info.user_nickname}(ta的id:{user_info.user_id})"
)
return f"[{time_str}] {name}: {self.processed_plain_text}\n"
@dataclass
class MessageProcessBase(MessageBase):
"""消息处理基类,用于处理中和发送中的消息"""
@dataclass
class Message(MessageBase):
chat_stream: ChatStream=None
reply: Optional['Message'] = None
detailed_plain_text: str = ""
processed_plain_text: str = ""
def __init__(
self,
message_id: str,
time: int,
chat_stream: ChatStream,
user_info: UserInfo,
message_segment: Optional[Seg] = None,
reply: Optional['MessageRecv'] = None
reply: Optional['MessageRecv'] = None,
detailed_plain_text: str = "",
processed_plain_text: str = "",
):
# 构造基础消息信息
message_info = BaseMessageInfo(
platform=chat_stream.platform,
message_id=message_id,
time=int(time.time()),
time=time,
group_info=chat_stream.group_info,
user_info=chat_stream.user_info
user_info=user_info
)
# 调用父类初始化
@@ -136,17 +142,41 @@ class MessageProcessBase(MessageBase):
raw_message=None
)
# 处理状态相关属性
self.thinking_start_time = int(time.time())
self.thinking_time = 0
self.chat_stream = chat_stream
# 文本处理相关属性
self.processed_plain_text = ""
self.detailed_plain_text = ""
self.processed_plain_text = detailed_plain_text
self.detailed_plain_text = processed_plain_text
# 回复消息
self.reply = reply
@dataclass
class MessageProcessBase(Message):
"""消息处理基类,用于处理中和发送中的消息"""
def __init__(
self,
message_id: str,
chat_stream: ChatStream,
bot_user_info: UserInfo,
message_segment: Optional[Seg] = None,
reply: Optional['MessageRecv'] = None
):
# 调用父类初始化
super().__init__(
message_id=message_id,
time=int(time.time()),
chat_stream=chat_stream,
user_info=bot_user_info,
message_segment=message_segment,
reply=reply
)
# 处理状态相关属性
self.thinking_start_time = int(time.time())
self.thinking_time = 0
def update_thinking_time(self) -> float:
"""更新思考时间"""
self.thinking_time = round(time.time() - self.thinking_start_time, 2)
@@ -224,12 +254,14 @@ class MessageThinking(MessageProcessBase):
self,
message_id: str,
chat_stream: ChatStream,
bot_user_info: UserInfo,
reply: Optional['MessageRecv'] = None
):
# 调用父类初始化
super().__init__(
message_id=message_id,
chat_stream=chat_stream,
bot_user_info=bot_user_info,
message_segment=None, # 思考状态不需要消息段
reply=reply
)
@@ -237,15 +269,6 @@ class MessageThinking(MessageProcessBase):
# 思考状态特有属性
self.interrupt = False
@classmethod
def from_chat_stream(cls, chat_stream: ChatStream, message_id: str, reply: Optional['MessageRecv'] = None) -> 'MessageThinking':
"""从聊天流创建思考状态消息"""
return cls(
message_id=message_id,
chat_stream=chat_stream,
reply=reply
)
@dataclass
class MessageSending(MessageProcessBase):
"""发送状态的消息类"""
@@ -254,6 +277,7 @@ class MessageSending(MessageProcessBase):
self,
message_id: str,
chat_stream: ChatStream,
bot_user_info: UserInfo,
message_segment: Seg,
reply: Optional['MessageRecv'] = None,
is_head: bool = False,
@@ -263,6 +287,7 @@ class MessageSending(MessageProcessBase):
super().__init__(
message_id=message_id,
chat_stream=chat_stream,
bot_user_info=bot_user_info,
message_segment=message_segment,
reply=reply
)
@@ -296,10 +321,16 @@ class MessageSending(MessageProcessBase):
message_id=thinking.message_info.message_id,
chat_stream=thinking.chat_stream,
message_segment=message_segment,
bot_user_info=thinking.message_info.user_info,
reply=thinking.reply,
is_head=is_head,
is_emoji=is_emoji
)
def to_dict(self):
ret= super().to_dict()
ret['mesage_info']['user_info']=self.chat_stream.user_info.to_dict()
return ret
@dataclass
class MessageSet:

View File

@@ -78,6 +78,21 @@ class GroupInfo:
def to_dict(self) -> Dict:
"""转换为字典格式"""
return {k: v for k, v in asdict(self).items() if v is not None}
def from_dict(cls, data: Dict) -> 'GroupInfo':
"""从字典创建GroupInfo实例
Args:
data: 包含必要字段的字典
Returns:
GroupInfo: 新的实例
"""
return cls(
platform=data.get('platform'),
group_id=data.get('group_id'),
group_name=data.get('group_name',None)
)
@dataclass
class UserInfo:
@@ -90,6 +105,22 @@ class UserInfo:
def to_dict(self) -> Dict:
"""转换为字典格式"""
return {k: v for k, v in asdict(self).items() if v is not None}
def from_dict(cls, data: Dict) -> 'UserInfo':
"""从字典创建UserInfo实例
Args:
data: 包含必要字段的字典
Returns:
UserInfo: 新的实例
"""
return cls(
platform=data.get('platform'),
user_id=data.get('user_id'),
user_nickname=data.get('user_nickname',None),
user_cardname=data.get('user_cardname',None)
)
@dataclass
class BaseMessageInfo:
@@ -147,7 +178,7 @@ class MessageBase:
"""
message_info = BaseMessageInfo(**data.get('message_info', {}))
message_segment = Seg(**data.get('message_segment', {}))
raw_message = data.get('raw_message')
raw_message = data.get('raw_message',None)
return cls(
message_info=message_info,
message_segment=message_segment,

View File

@@ -139,26 +139,23 @@ class MessageSendCQ(MessageCQ):
def __init__(
self,
message_id: int,
user_id: int,
message_segment: Seg,
group_id: Optional[int] = None,
reply_to_message_id: Optional[int] = None,
platform: str = "qq"
data: Dict
):
# 调用父类初始化
super().__init__(message_id, user_id, group_id, platform)
message_info = BaseMessageInfo(**data.get('message_info', {}))
message_segment = Seg(**data.get('message_segment', {}))
super().__init__(
message_info.message_id,
message_info.user_info.user_id,
message_info.group_info.group_id if message_info.group_info else None,
message_info.platform)
self.message_segment = message_segment
self.raw_message = self._generate_raw_message(reply_to_message_id)
self.raw_message = self._generate_raw_message()
def _generate_raw_message(self, reply_to_message_id: Optional[int] = None) -> str:
def _generate_raw_message(self, ) -> str:
"""将Seg对象转换为raw_message"""
segments = []
# 添加回复消息
if reply_to_message_id:
segments.append(cq_code_tool.create_reply_cq(reply_to_message_id))
# 处理消息段
if self.message_segment.type == 'seglist':

View File

@@ -29,13 +29,12 @@ class Message_Sender:
) -> None:
"""发送消息"""
if isinstance(message, MessageSending):
message_json = message.to_dict()
message_send=MessageSendCQ(
message_id=message.message_id,
user_id=message.message_info.user_info.user_id,
message_segment=message.message_segment,
reply=message.reply
data=message_json
)
if message.message_info.group_info:
if message_send.message_info.group_info:
try:
await self._current_bot.send_group_msg(
group_id=message.message_info.group_info.group_id,

View File

@@ -8,6 +8,7 @@ from ..moods.moods import MoodManager
from ..schedule.schedule_generator import bot_schedule
from .config import global_config
from .utils import get_embedding, get_recent_group_detailed_plain_text
from .chat_stream import ChatStream, chat_manager
class PromptBuilder:
@@ -22,7 +23,7 @@ class PromptBuilder:
message_txt: str,
sender_name: str = "某人",
relationship_value: float = 0.0,
group_id: Optional[int] = None) -> tuple[str, str]:
stream_id: Optional[int] = None) -> tuple[str, str]:
"""构建prompt
Args:
@@ -72,11 +73,17 @@ class PromptBuilder:
print(f"\033[1;32m[知识检索]\033[0m 耗时: {(end_time - start_time):.3f}")
# 获取聊天上下文
chat_in_group=True
chat_talking_prompt = ''
if group_id:
chat_talking_prompt = get_recent_group_detailed_plain_text(self.db, group_id, limit=global_config.MAX_CONTEXT_SIZE,combine = True)
chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}"
if stream_id:
chat_talking_prompt = get_recent_group_detailed_plain_text(self.db, stream_id, limit=global_config.MAX_CONTEXT_SIZE,combine = True)
chat_stream=chat_manager.get_stream(stream_id)
if chat_stream.group_info:
chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}"
else:
chat_in_group=False
chat_talking_prompt = f"以下是你正在和{sender_name}私聊的内容:\n{chat_talking_prompt}"
# print(f"\033[1;34m[调试]\033[0m 已从数据库获取群 {group_id} 的消息记录:{chat_talking_prompt}")
@@ -112,8 +119,10 @@ class PromptBuilder:
#激活prompt构建
activate_prompt = ''
activate_prompt = f"以上是群里正在进行的聊天,{memory_prompt} 现在昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和他{relation_prompt},{mood_prompt},你想要{relation_prompt_2}"
if chat_in_group:
activate_prompt = f"以上是群里正在进行的聊天,{memory_prompt} 现在昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和ta{relation_prompt},{mood_prompt},你想要{relation_prompt_2}"
else:
activate_prompt = f"以上是你正在和{sender_name}私聊的内容,{memory_prompt} 现在昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和ta{relation_prompt},{mood_prompt},你想要{relation_prompt_2}"
#检测机器人相关词汇
bot_keywords = ['人机', 'bot', '机器', '入机', 'robot', '机器人']
is_bot = any(keyword in message_txt.lower() for keyword in bot_keywords)
@@ -129,16 +138,20 @@ class PromptBuilder:
probability_3 = global_config.PERSONALITY_3
prompt_personality = ''
personality_choice = random.random()
if chat_in_group:
prompt_in_group=f"你正在浏览{chat_stream.platform}"
else:
prompt_in_group=f"你正在{chat_stream.platform}上和{sender_name}私聊"
if personality_choice < probability_1: # 第一种人格
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[0]}, 你正在浏览qq群,{promt_info_prompt},
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[0]}{prompt_in_group},{promt_info_prompt},
现在请你给出日常且口语化的回复,平淡一些,尽量简短一些。{is_bot_prompt}
请注意把握群里的聊天内容,不要刻意突出自身学科背景,不要回复的太有条理,可以有个性。'''
elif personality_choice < probability_1 + probability_2: # 第二种人格
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[1]}, 你正在浏览qq群{promt_info_prompt},
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[1]}{prompt_in_group}{promt_info_prompt},
现在请你给出日常且口语化的回复,请表现你自己的见解,不要一昧迎合,尽量简短一些。{is_bot_prompt}
请你表达自己的见解和观点。可以有个性。'''
else: # 第三种人格
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[2]}, 你正在浏览qq群{promt_info_prompt},
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[2]}{prompt_in_group}{promt_info_prompt},
现在请你给出日常且口语化的回复,请表现你自己的见解,不要一昧迎合,尽量简短一些。{is_bot_prompt}
请你表达自己的见解和观点。可以有个性。'''

View File

@@ -16,17 +16,16 @@ class Impression:
class Relationship:
user_id: int = None
platform: str = None
platform: str = None
gender: str = None
age: int = None
nickname: str = None
relationship_value: float = None
saved = False
def __init__(self, chat:ChatStream,data:dict):
self.user_id=chat.user_info.user_id
self.platform=chat.platform
self.nickname=chat.user_info.user_nickname
def __init__(self, chat:ChatStream=None,data:dict=None):
self.user_id=chat.user_info.user_id if chat.user_info else data.get('user_id',0)
self.platform=chat.platform if chat.user_info else data.get('platform','')
self.nickname=chat.user_info.user_nickname if chat.user_info else data.get('nickname','')
self.relationship_value=data.get('relationship_value',0)
self.age=data.get('age',0)
self.gender=data.get('gender','')
@@ -35,7 +34,6 @@ class Relationship:
class RelationshipManager:
def __init__(self):
self.relationships: dict[tuple[int, str], Relationship] = {} # 修改为使用(user_id, platform)作为键
self.relationships: dict[tuple[int, str], Relationship] = {} # 修改为使用(user_id, platform)作为键
async def update_relationship(self,
chat_stream:ChatStream,
@@ -43,9 +41,7 @@ class RelationshipManager:
**kwargs) -> Optional[Relationship]:
"""更新或创建关系
Args:
user_id: 用户ID可选如果提供user_info则不需要
platform: 平台可选如果提供user_info则不需要
user_info: 用户信息对象(可选)
chat_stream: 聊天流对象
data: 字典格式的数据(可选)
**kwargs: 其他参数
Returns:
@@ -66,44 +62,18 @@ class RelationshipManager:
# 检查是否在内存中已存在
relationship = self.relationships.get(key)
relationship = self.relationships.get(key)
if relationship:
# 如果存在,更新现有对象
if isinstance(data, dict):
for k, value in data.items():
if hasattr(relationship, k) and value is not None:
setattr(relationship, k, value)
for k, value in data.items():
if hasattr(relationship, k) and value is not None:
setattr(relationship, k, value)
else:
for k, value in kwargs.items():
if hasattr(relationship, k) and value is not None:
setattr(relationship, k, value)
for k, value in kwargs.items():
if hasattr(relationship, k) and value is not None:
setattr(relationship, k, value)
else:
# 如果不存在,创建新对象
if user_info is not None:
relationship = Relationship(user_info=user_info, **kwargs)
elif isinstance(data, dict):
data['platform'] = platform
relationship = Relationship(user_id=user_id, data=data)
if chat_stream.user_info is not None:
relationship = Relationship(chat=chat_stream, **kwargs)
else:
kwargs['platform'] = platform
kwargs['user_id'] = user_id
relationship = Relationship(**kwargs)
self.relationships[key] = relationship
if user_info is not None:
relationship = Relationship(user_info=user_info, **kwargs)
elif isinstance(data, dict):
data['platform'] = platform
relationship = Relationship(user_id=user_id, data=data)
else:
kwargs['platform'] = platform
kwargs['user_id'] = user_id
relationship = Relationship(**kwargs)
raise ValueError("必须提供user_id或user_info")
self.relationships[key] = relationship
# 保存到数据库
@@ -113,36 +83,7 @@ class RelationshipManager:
return relationship
async def update_relationship_value(self,
user_id: int = None,
platform: str = None,
user_info: UserInfo = None,
**kwargs) -> Optional[Relationship]:
"""更新关系值
Args:
user_id: 用户ID可选如果提供user_info则不需要
platform: 平台可选如果提供user_info则不需要
user_info: 用户信息对象(可选)
**kwargs: 其他参数
Returns:
Relationship: 关系对象
"""
# 确定user_id和platform
if user_info is not None:
user_id = user_info.user_id
platform = user_info.platform or 'qq'
else:
platform = platform or 'qq'
if user_id is None:
raise ValueError("必须提供user_id或user_info")
# 使用(user_id, platform)作为键
key = (user_id, platform)
async def update_relationship_value(self,
user_id: int = None,
platform: str = None,
user_info: UserInfo = None,
chat_stream:ChatStream,
**kwargs) -> Optional[Relationship]:
"""更新关系值
Args:
@@ -154,6 +95,7 @@ class RelationshipManager:
Relationship: 关系对象
"""
# 确定user_id和platform
user_info = chat_stream.user_info
if user_info is not None:
user_id = user_info.user_id
platform = user_info.platform or 'qq'
@@ -168,10 +110,7 @@ class RelationshipManager:
# 检查是否在内存中已存在
relationship = self.relationships.get(key)
relationship = self.relationships.get(key)
if relationship:
for k, value in kwargs.items():
if k == 'relationship_value':
for k, value in kwargs.items():
if k == 'relationship_value':
relationship.relationship_value += value
@@ -181,43 +120,12 @@ class RelationshipManager:
else:
# 如果不存在且提供了user_info则创建新的关系
if user_info is not None:
return await self.update_relationship(user_info=user_info, **kwargs)
print(f"\033[1;31m[关系管理]\033[0m 用户 {user_id}({platform}) 不存在,无法更新")
# 如果不存在且提供了user_info则创建新的关系
if user_info is not None:
return await self.update_relationship(user_info=user_info, **kwargs)
return await self.update_relationship(chat_stream=chat_stream, **kwargs)
print(f"\033[1;31m[关系管理]\033[0m 用户 {user_id}({platform}) 不存在,无法更新")
return None
def get_relationship(self,
user_id: int = None,
platform: str = None,
user_info: UserInfo = None) -> Optional[Relationship]:
"""获取用户关系对象
Args:
user_id: 用户ID可选如果提供user_info则不需要
platform: 平台可选如果提供user_info则不需要
user_info: 用户信息对象(可选)
Returns:
Relationship: 关系对象
"""
# 确定user_id和platform
if user_info is not None:
user_id = user_info.user_id
platform = user_info.platform or 'qq'
else:
platform = platform or 'qq'
if user_id is None:
raise ValueError("必须提供user_id或user_info")
key = (user_id, platform)
if key in self.relationships:
return self.relationships[key]
def get_relationship(self,
user_id: int = None,
platform: str = None,
user_info: UserInfo = None) -> Optional[Relationship]:
chat_stream:ChatStream) -> Optional[Relationship]:
"""获取用户关系对象
Args:
user_id: 用户ID可选如果提供user_info则不需要
@@ -227,6 +135,8 @@ class RelationshipManager:
Relationship: 关系对象
"""
# 确定user_id和platform
user_info = chat_stream.user_info
platform = chat_stream.user_info.platform or 'qq'
if user_info is not None:
user_id = user_info.user_id
platform = user_info.platform or 'qq'
@@ -248,18 +158,10 @@ class RelationshipManager:
if 'platform' not in data:
data['platform'] = 'qq'
rela = Relationship(data=data)
"""从数据库加载或创建新的关系对象"""
# 确保data中有platform字段如果没有则默认为'qq'
if 'platform' not in data:
data['platform'] = 'qq'
rela = Relationship(data=data)
rela.saved = True
key = (rela.user_id, rela.platform)
self.relationships[key] = rela
key = (rela.user_id, rela.platform)
self.relationships[key] = rela
return rela
async def load_all_relationships(self):
@@ -277,7 +179,6 @@ class RelationshipManager:
# 依次加载每条记录
for data in all_relationships:
await self.load_relationship(data)
await self.load_relationship(data)
print(f"\033[1;32m[关系管理]\033[0m 已加载 {len(self.relationships)} 条关系记录")
while True:
@@ -288,19 +189,15 @@ class RelationshipManager:
async def _save_all_relationships(self):
"""将所有关系数据保存到数据库"""
# 保存所有关系数据
for (userid, platform), relationship in self.relationships.items():
for (userid, platform), relationship in self.relationships.items():
if not relationship.saved:
relationship.saved = True
await self.storage_relationship(relationship)
async def storage_relationship(self, relationship: Relationship):
"""将关系记录存储到数据库中"""
async def storage_relationship(self, relationship: Relationship):
"""将关系记录存储到数据库中"""
user_id = relationship.user_id
platform = relationship.platform
platform = relationship.platform
nickname = relationship.nickname
relationship_value = relationship.relationship_value
gender = relationship.gender
@@ -309,10 +206,8 @@ class RelationshipManager:
db = Database.get_instance()
db.db.relationships.update_one(
{'user_id': user_id, 'platform': platform},
{'user_id': user_id, 'platform': platform},
{'$set': {
'platform': platform,
'platform': platform,
'nickname': nickname,
'relationship_value': relationship_value,
@@ -323,27 +218,6 @@ class RelationshipManager:
upsert=True
)
def get_name(self,
user_id: int = None,
platform: str = None,
user_info: UserInfo = None) -> str:
"""获取用户昵称
Args:
user_id: 用户ID可选如果提供user_info则不需要
platform: 平台可选如果提供user_info则不需要
user_info: 用户信息对象(可选)
Returns:
str: 用户昵称
"""
# 确定user_id和platform
if user_info is not None:
user_id = user_info.user_id
platform = user_info.platform or 'qq'
else:
platform = platform or 'qq'
if user_id is None:
raise ValueError("必须提供user_id或user_info")
def get_name(self,
user_id: int = None,
@@ -370,11 +244,6 @@ class RelationshipManager:
# 确保user_id是整数类型
user_id = int(user_id)
key = (user_id, platform)
if key in self.relationships:
return self.relationships[key].nickname
elif user_info is not None:
return user_info.user_nickname or user_info.user_cardname or "某人"
key = (user_id, platform)
if key in self.relationships:
return self.relationships[key].nickname
elif user_info is not None:

View File

@@ -18,8 +18,9 @@ class MessageStorage:
"time": message.message_info.time,
"chat_id":chat_stream.stream_id,
"chat_info": chat_stream.to_dict(),
"detailed_plain_text": message.detailed_plain_text,
"user_info": message.message_info.user_info.to_dict(),
"processed_plain_text": message.processed_plain_text,
"detailed_plain_text": message.detailed_plain_text,
"topic": topic,
}
self.db.db.messages.insert_one(message_data)

View File

@@ -11,7 +11,9 @@ from nonebot import get_driver
from ..models.utils_model import LLM_request
from ..utils.typo_generator import ChineseTypoGenerator
from .config import global_config
from .message_cq import Message
from .message import MessageThinking, MessageRecv,MessageSending,MessageProcessBase,Message
from .message_base import MessageBase,BaseMessageInfo,UserInfo,GroupInfo
from .chat_stream import ChatStream
driver = get_driver()
config = driver.config
@@ -32,7 +34,7 @@ def db_message_to_str(message_dict: Dict) -> str:
return result
def is_mentioned_bot_in_message(message: Message) -> bool:
def is_mentioned_bot_in_message(message: MessageRecv) -> bool:
"""检查消息是否提到了机器人"""
keywords = [global_config.BOT_NICKNAME]
for keyword in keywords:
@@ -41,15 +43,6 @@ def is_mentioned_bot_in_message(message: Message) -> bool:
return False
def is_mentioned_bot_in_txt(message: str) -> bool:
"""检查消息是否提到了机器人"""
keywords = [global_config.BOT_NICKNAME]
for keyword in keywords:
if keyword in message:
return True
return False
async def get_embedding(text):
"""获取文本的embedding向量"""
llm = LLM_request(model=global_config.embedding)
@@ -84,10 +77,10 @@ def get_cloest_chat_from_db(db, length: int, timestamp: str):
if closest_record and closest_record.get('memorized', 0) < 4:
closest_time = closest_record['time']
group_id = closest_record['group_id'] # 获取groupid
chat_id = closest_record['chat_id'] # 获取groupid
# 获取该时间戳之后的length条消息且groupid相同
chat_records = list(db.db.messages.find(
{"time": {"$gt": closest_time}, "group_id": group_id}
{"time": {"$gt": closest_time}, "chat_id": chat_id}
).sort('time', 1).limit(length))
# 更新每条消息的memorized属性
@@ -111,7 +104,7 @@ def get_cloest_chat_from_db(db, length: int, timestamp: str):
return ''
async def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
async def get_recent_group_messages(db, chat_id:str, limit: int = 12) -> list:
"""从数据库获取群组最近的消息记录
Args:
@@ -125,35 +118,28 @@ async def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
# 从数据库获取最近消息
recent_messages = list(db.db.messages.find(
{"group_id": group_id},
# {
# "time": 1,
# "user_id": 1,
# "user_nickname": 1,
# "message_id": 1,
# "raw_message": 1,
# "processed_text": 1
# }
{"chat_id": chat_id},
).sort("time", -1).limit(limit))
if not recent_messages:
return []
# 转换为 Message对象列表
from .message_cq import Message
message_objects = []
for msg_data in recent_messages:
try:
chat_info=msg_data.get("chat_info",{})
chat_stream=ChatStream.from_dict(chat_info)
user_info=msg_data.get("user_info",{})
user_info=UserInfo.from_dict(user_info)
msg = Message(
time=msg_data["time"],
user_id=msg_data["user_id"],
user_nickname=msg_data.get("user_nickname", ""),
message_id=msg_data["message_id"],
raw_message=msg_data["raw_message"],
chat_stream=chat_stream,
time=msg_data["time"],
user_info=user_info,
processed_plain_text=msg_data.get("processed_text", ""),
group_id=group_id
detailed_plain_text=msg_data.get("detailed_plain_text", "")
)
await msg.initialize()
message_objects.append(msg)
except KeyError:
print("[WARNING] 数据库中存在无效的消息")
@@ -164,13 +150,14 @@ async def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
return message_objects
def get_recent_group_detailed_plain_text(db, group_id: int, limit: int = 12, combine=False):
def get_recent_group_detailed_plain_text(db, chat_stream_id: int, limit: int = 12, combine=False):
recent_messages = list(db.db.messages.find(
{"group_id": group_id},
{"chat_id": chat_stream_id},
{
"time": 1, # 返回时间字段
"user_id": 1, # 返回用户ID字段
"user_nickname": 1, # 返回用户昵称字段
"chat_id":1,
"chat_info":1,
"user_info": 1,
"message_id": 1, # 返回消息ID字段
"detailed_plain_text": 1 # 返回处理后的文本字段
}