Merge branch 'refractor' of https://github.com/tcmofashi/MaiMBot into refractor
This commit is contained in:
@@ -11,18 +11,25 @@ from .cq_code import CQCode,cq_code_tool # 导入CQCode模块
|
|||||||
from .emoji_manager import emoji_manager # 导入表情包管理器
|
from .emoji_manager import emoji_manager # 导入表情包管理器
|
||||||
from .llm_generator import ResponseGenerator
|
from .llm_generator import ResponseGenerator
|
||||||
from .message import MessageSending, MessageRecv, MessageThinking, MessageSet
|
from .message import MessageSending, MessageRecv, MessageThinking, MessageSet
|
||||||
|
from .message import MessageSending, MessageRecv, MessageThinking, MessageSet
|
||||||
from .message_cq import (
|
from .message_cq import (
|
||||||
MessageRecvCQ,
|
MessageRecvCQ,
|
||||||
MessageSendCQ,
|
MessageSendCQ,
|
||||||
)
|
)
|
||||||
|
from .chat_stream import chat_manager
|
||||||
|
MessageRecvCQ,
|
||||||
|
MessageSendCQ,
|
||||||
|
)
|
||||||
from .chat_stream import chat_manager
|
from .chat_stream import chat_manager
|
||||||
from .message_sender import message_manager # 导入新的消息管理器
|
from .message_sender import message_manager # 导入新的消息管理器
|
||||||
from .relationship_manager import relationship_manager
|
from .relationship_manager import relationship_manager
|
||||||
from .storage import MessageStorage
|
from .storage import MessageStorage
|
||||||
from .utils import calculate_typing_time, is_mentioned_bot_in_txt
|
from .utils import calculate_typing_time, is_mentioned_bot_in_txt
|
||||||
from .utils_image import image_path_to_base64
|
from .utils_image import image_path_to_base64
|
||||||
|
from .utils_image import image_path_to_base64
|
||||||
from .willing_manager import willing_manager # 导入意愿管理器
|
from .willing_manager import willing_manager # 导入意愿管理器
|
||||||
from .message_base import UserInfo, GroupInfo, Seg
|
from .message_base import UserInfo, GroupInfo, Seg
|
||||||
|
from .message_base import UserInfo, GroupInfo, Seg
|
||||||
|
|
||||||
class ChatBot:
|
class ChatBot:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -47,6 +54,7 @@ class ChatBot:
|
|||||||
self.bot = bot # 更新 bot 实例
|
self.bot = bot # 更新 bot 实例
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
group_info = await bot.get_group_info(group_id=event.group_id)
|
group_info = await bot.get_group_info(group_id=event.group_id)
|
||||||
sender_info = await bot.get_group_member_info(group_id=event.group_id, user_id=event.user_id, no_cache=True)
|
sender_info = await bot.get_group_member_info(group_id=event.group_id, user_id=event.user_id, no_cache=True)
|
||||||
@@ -54,11 +62,15 @@ class ChatBot:
|
|||||||
await relationship_manager.update_relationship(user_id = event.user_id, data = sender_info)
|
await relationship_manager.update_relationship(user_id = event.user_id, data = sender_info)
|
||||||
await relationship_manager.update_relationship_value(user_id = event.user_id, relationship_value = 0.5)
|
await relationship_manager.update_relationship_value(user_id = event.user_id, relationship_value = 0.5)
|
||||||
|
|
||||||
|
message_cq=MessageRecvCQ(
|
||||||
message_cq=MessageRecvCQ(
|
message_cq=MessageRecvCQ(
|
||||||
message_id=event.message_id,
|
message_id=event.message_id,
|
||||||
user_id=event.user_id,
|
user_id=event.user_id,
|
||||||
raw_message=str(event.original_message),
|
raw_message=str(event.original_message),
|
||||||
group_id=event.group_id,
|
group_id=event.group_id,
|
||||||
|
user_id=event.user_id,
|
||||||
|
raw_message=str(event.original_message),
|
||||||
|
group_id=event.group_id,
|
||||||
reply_message=event.reply,
|
reply_message=event.reply,
|
||||||
platform='qq'
|
platform='qq'
|
||||||
)
|
)
|
||||||
@@ -88,12 +100,15 @@ class ChatBot:
|
|||||||
return
|
return
|
||||||
# 过滤词
|
# 过滤词
|
||||||
for word in global_config.ban_words:
|
for word in global_config.ban_words:
|
||||||
|
if word in message.processed_plain_text:
|
||||||
|
logger.info(f"\033[1;32m[{groupinfo.group_name}]{userinfo.user_nickname}:\033[0m {message.processed_plain_text}")
|
||||||
if word in message.processed_plain_text:
|
if word in message.processed_plain_text:
|
||||||
logger.info(f"\033[1;32m[{groupinfo.group_name}]{userinfo.user_nickname}:\033[0m {message.processed_plain_text}")
|
logger.info(f"\033[1;32m[{groupinfo.group_name}]{userinfo.user_nickname}:\033[0m {message.processed_plain_text}")
|
||||||
logger.info(f"\033[1;32m[过滤词识别]\033[0m 消息中含有{word},filtered")
|
logger.info(f"\033[1;32m[过滤词识别]\033[0m 消息中含有{word},filtered")
|
||||||
return
|
return
|
||||||
|
|
||||||
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(messageinfo.time))
|
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(messageinfo.time))
|
||||||
|
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(messageinfo.time))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -124,6 +139,11 @@ class ChatBot:
|
|||||||
response = None
|
response = None
|
||||||
|
|
||||||
if random() < reply_probability:
|
if random() < reply_probability:
|
||||||
|
bot_user_info=UserInfo(
|
||||||
|
user_id=global_config.BOT_QQ,
|
||||||
|
user_nickname=global_config.BOT_NICKNAME,
|
||||||
|
platform=messageinfo.platform
|
||||||
|
)
|
||||||
bot_user_info=UserInfo(
|
bot_user_info=UserInfo(
|
||||||
user_id=global_config.BOT_QQ,
|
user_id=global_config.BOT_QQ,
|
||||||
user_nickname=global_config.BOT_NICKNAME,
|
user_nickname=global_config.BOT_NICKNAME,
|
||||||
@@ -136,6 +156,11 @@ class ChatBot:
|
|||||||
message_id=think_id,
|
message_id=think_id,
|
||||||
reply=message
|
reply=message
|
||||||
)
|
)
|
||||||
|
thinking_message = MessageThinking.from_chat_stream(
|
||||||
|
chat_stream=chat,
|
||||||
|
message_id=think_id,
|
||||||
|
reply=message
|
||||||
|
)
|
||||||
|
|
||||||
message_manager.add_message(thinking_message)
|
message_manager.add_message(thinking_message)
|
||||||
|
|
||||||
@@ -146,6 +171,7 @@ class ChatBot:
|
|||||||
response,raw_content = await self.gpt.generate_response(message)
|
response,raw_content = await self.gpt.generate_response(message)
|
||||||
|
|
||||||
if response:
|
if response:
|
||||||
|
container = message_manager.get_container(chat.stream_id)
|
||||||
container = message_manager.get_container(chat.stream_id)
|
container = message_manager.get_container(chat.stream_id)
|
||||||
thinking_message = None
|
thinking_message = None
|
||||||
# 找到message,删除
|
# 找到message,删除
|
||||||
@@ -163,6 +189,7 @@ class ChatBot:
|
|||||||
#记录开始思考的时间,避免从思考到回复的时间太久
|
#记录开始思考的时间,避免从思考到回复的时间太久
|
||||||
thinking_start_time = thinking_message.thinking_start_time
|
thinking_start_time = thinking_message.thinking_start_time
|
||||||
message_set = MessageSet(chat, think_id)
|
message_set = MessageSet(chat, think_id)
|
||||||
|
message_set = MessageSet(chat, think_id)
|
||||||
#计算打字时间,1是为了模拟打字,2是避免多条回复乱序
|
#计算打字时间,1是为了模拟打字,2是避免多条回复乱序
|
||||||
accu_typing_time = 0
|
accu_typing_time = 0
|
||||||
|
|
||||||
@@ -174,6 +201,8 @@ class ChatBot:
|
|||||||
accu_typing_time += typing_time
|
accu_typing_time += typing_time
|
||||||
timepoint = tinking_time_point + accu_typing_time
|
timepoint = tinking_time_point + accu_typing_time
|
||||||
|
|
||||||
|
message_segment = Seg(type='text', data=msg)
|
||||||
|
bot_message = MessageSending(
|
||||||
message_segment = Seg(type='text', data=msg)
|
message_segment = Seg(type='text', data=msg)
|
||||||
bot_message = MessageSending(
|
bot_message = MessageSending(
|
||||||
message_id=think_id,
|
message_id=think_id,
|
||||||
@@ -182,6 +211,12 @@ class ChatBot:
|
|||||||
reply=message,
|
reply=message,
|
||||||
is_head=not mark_head,
|
is_head=not mark_head,
|
||||||
is_emoji=False
|
is_emoji=False
|
||||||
|
)
|
||||||
|
chat_stream=chat,
|
||||||
|
message_segment=message_segment,
|
||||||
|
reply=message,
|
||||||
|
is_head=not mark_head,
|
||||||
|
is_emoji=False
|
||||||
)
|
)
|
||||||
if not mark_head:
|
if not mark_head:
|
||||||
mark_head = True
|
mark_head = True
|
||||||
@@ -200,6 +235,7 @@ class ChatBot:
|
|||||||
if emoji_raw != None:
|
if emoji_raw != None:
|
||||||
emoji_path,discription = emoji_raw
|
emoji_path,discription = emoji_raw
|
||||||
|
|
||||||
|
emoji_cq = image_path_to_base64(emoji_path)
|
||||||
emoji_cq = image_path_to_base64(emoji_path)
|
emoji_cq = image_path_to_base64(emoji_path)
|
||||||
|
|
||||||
if random() < 0.5:
|
if random() < 0.5:
|
||||||
@@ -207,6 +243,15 @@ class ChatBot:
|
|||||||
else:
|
else:
|
||||||
bot_response_time = bot_response_time + 1
|
bot_response_time = bot_response_time + 1
|
||||||
|
|
||||||
|
message_segment = Seg(type='emoji', data=emoji_cq)
|
||||||
|
bot_message = MessageSending(
|
||||||
|
message_id=think_id,
|
||||||
|
chat_stream=chat,
|
||||||
|
message_segment=message_segment,
|
||||||
|
reply=message,
|
||||||
|
is_head=False,
|
||||||
|
is_emoji=True
|
||||||
|
)
|
||||||
message_segment = Seg(type='emoji', data=emoji_cq)
|
message_segment = Seg(type='emoji', data=emoji_cq)
|
||||||
bot_message = MessageSending(
|
bot_message = MessageSending(
|
||||||
message_id=think_id,
|
message_id=think_id,
|
||||||
@@ -237,6 +282,12 @@ class ChatBot:
|
|||||||
user_info=userinfo,
|
user_info=userinfo,
|
||||||
group_info=groupinfo
|
group_info=groupinfo
|
||||||
)
|
)
|
||||||
|
|
||||||
|
willing_manager.change_reply_willing_after_sent(
|
||||||
|
platform=messageinfo.platform,
|
||||||
|
user_info=userinfo,
|
||||||
|
group_info=groupinfo
|
||||||
|
)
|
||||||
|
|
||||||
# 创建全局ChatBot实例
|
# 创建全局ChatBot实例
|
||||||
chat_bot = ChatBot()
|
chat_bot = ChatBot()
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
from ...common.database import Database
|
from ...common.database import Database
|
||||||
from .message_base import UserInfo
|
from .message_base import UserInfo
|
||||||
@@ -15,6 +16,7 @@ class Impression:
|
|||||||
class Relationship:
|
class Relationship:
|
||||||
user_id: int = None
|
user_id: int = None
|
||||||
platform: str = None
|
platform: str = None
|
||||||
|
platform: str = None
|
||||||
gender: str = None
|
gender: str = None
|
||||||
age: int = None
|
age: int = None
|
||||||
nickname: str = None
|
nickname: str = None
|
||||||
@@ -33,6 +35,7 @@ class Relationship:
|
|||||||
class RelationshipManager:
|
class RelationshipManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.relationships: dict[tuple[int, str], Relationship] = {} # 修改为使用(user_id, platform)作为键
|
self.relationships: dict[tuple[int, str], Relationship] = {} # 修改为使用(user_id, platform)作为键
|
||||||
|
self.relationships: dict[tuple[int, str], Relationship] = {} # 修改为使用(user_id, platform)作为键
|
||||||
|
|
||||||
async def update_relationship(self,
|
async def update_relationship(self,
|
||||||
chat_stream:ChatStream,
|
chat_stream:ChatStream,
|
||||||
@@ -63,16 +66,23 @@ class RelationshipManager:
|
|||||||
|
|
||||||
# 检查是否在内存中已存在
|
# 检查是否在内存中已存在
|
||||||
relationship = self.relationships.get(key)
|
relationship = self.relationships.get(key)
|
||||||
|
relationship = self.relationships.get(key)
|
||||||
if relationship:
|
if relationship:
|
||||||
# 如果存在,更新现有对象
|
# 如果存在,更新现有对象
|
||||||
if isinstance(data, dict):
|
if isinstance(data, dict):
|
||||||
for k, value in data.items():
|
for k, value in data.items():
|
||||||
if hasattr(relationship, k) and value is not None:
|
if hasattr(relationship, k) and value is not None:
|
||||||
setattr(relationship, k, value)
|
setattr(relationship, k, value)
|
||||||
|
for k, value in data.items():
|
||||||
|
if hasattr(relationship, k) and value is not None:
|
||||||
|
setattr(relationship, k, value)
|
||||||
else:
|
else:
|
||||||
for k, value in kwargs.items():
|
for k, value in kwargs.items():
|
||||||
if hasattr(relationship, k) and value is not None:
|
if hasattr(relationship, k) and value is not None:
|
||||||
setattr(relationship, k, value)
|
setattr(relationship, k, value)
|
||||||
|
for k, value in kwargs.items():
|
||||||
|
if hasattr(relationship, k) and value is not None:
|
||||||
|
setattr(relationship, k, value)
|
||||||
else:
|
else:
|
||||||
# 如果不存在,创建新对象
|
# 如果不存在,创建新对象
|
||||||
if user_info is not None:
|
if user_info is not None:
|
||||||
@@ -85,6 +95,16 @@ class RelationshipManager:
|
|||||||
kwargs['user_id'] = user_id
|
kwargs['user_id'] = user_id
|
||||||
relationship = Relationship(**kwargs)
|
relationship = Relationship(**kwargs)
|
||||||
self.relationships[key] = relationship
|
self.relationships[key] = relationship
|
||||||
|
if user_info is not None:
|
||||||
|
relationship = Relationship(user_info=user_info, **kwargs)
|
||||||
|
elif isinstance(data, dict):
|
||||||
|
data['platform'] = platform
|
||||||
|
relationship = Relationship(user_id=user_id, data=data)
|
||||||
|
else:
|
||||||
|
kwargs['platform'] = platform
|
||||||
|
kwargs['user_id'] = user_id
|
||||||
|
relationship = Relationship(**kwargs)
|
||||||
|
self.relationships[key] = relationship
|
||||||
|
|
||||||
# 保存到数据库
|
# 保存到数据库
|
||||||
await self.storage_relationship(relationship)
|
await self.storage_relationship(relationship)
|
||||||
@@ -92,6 +112,33 @@ class RelationshipManager:
|
|||||||
|
|
||||||
return relationship
|
return relationship
|
||||||
|
|
||||||
|
async def update_relationship_value(self,
|
||||||
|
user_id: int = None,
|
||||||
|
platform: str = None,
|
||||||
|
user_info: UserInfo = None,
|
||||||
|
**kwargs) -> Optional[Relationship]:
|
||||||
|
"""更新关系值
|
||||||
|
Args:
|
||||||
|
user_id: 用户ID(可选,如果提供user_info则不需要)
|
||||||
|
platform: 平台(可选,如果提供user_info则不需要)
|
||||||
|
user_info: 用户信息对象(可选)
|
||||||
|
**kwargs: 其他参数
|
||||||
|
Returns:
|
||||||
|
Relationship: 关系对象
|
||||||
|
"""
|
||||||
|
# 确定user_id和platform
|
||||||
|
if user_info is not None:
|
||||||
|
user_id = user_info.user_id
|
||||||
|
platform = user_info.platform or 'qq'
|
||||||
|
else:
|
||||||
|
platform = platform or 'qq'
|
||||||
|
|
||||||
|
if user_id is None:
|
||||||
|
raise ValueError("必须提供user_id或user_info")
|
||||||
|
|
||||||
|
# 使用(user_id, platform)作为键
|
||||||
|
key = (user_id, platform)
|
||||||
|
|
||||||
async def update_relationship_value(self,
|
async def update_relationship_value(self,
|
||||||
user_id: int = None,
|
user_id: int = None,
|
||||||
platform: str = None,
|
platform: str = None,
|
||||||
@@ -121,7 +168,10 @@ class RelationshipManager:
|
|||||||
|
|
||||||
# 检查是否在内存中已存在
|
# 检查是否在内存中已存在
|
||||||
relationship = self.relationships.get(key)
|
relationship = self.relationships.get(key)
|
||||||
|
relationship = self.relationships.get(key)
|
||||||
if relationship:
|
if relationship:
|
||||||
|
for k, value in kwargs.items():
|
||||||
|
if k == 'relationship_value':
|
||||||
for k, value in kwargs.items():
|
for k, value in kwargs.items():
|
||||||
if k == 'relationship_value':
|
if k == 'relationship_value':
|
||||||
relationship.relationship_value += value
|
relationship.relationship_value += value
|
||||||
@@ -129,12 +179,41 @@ class RelationshipManager:
|
|||||||
relationship.saved = True
|
relationship.saved = True
|
||||||
return relationship
|
return relationship
|
||||||
else:
|
else:
|
||||||
|
# 如果不存在且提供了user_info,则创建新的关系
|
||||||
|
if user_info is not None:
|
||||||
|
return await self.update_relationship(user_info=user_info, **kwargs)
|
||||||
|
print(f"\033[1;31m[关系管理]\033[0m 用户 {user_id}({platform}) 不存在,无法更新")
|
||||||
# 如果不存在且提供了user_info,则创建新的关系
|
# 如果不存在且提供了user_info,则创建新的关系
|
||||||
if user_info is not None:
|
if user_info is not None:
|
||||||
return await self.update_relationship(user_info=user_info, **kwargs)
|
return await self.update_relationship(user_info=user_info, **kwargs)
|
||||||
print(f"\033[1;31m[关系管理]\033[0m 用户 {user_id}({platform}) 不存在,无法更新")
|
print(f"\033[1;31m[关系管理]\033[0m 用户 {user_id}({platform}) 不存在,无法更新")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def get_relationship(self,
|
||||||
|
user_id: int = None,
|
||||||
|
platform: str = None,
|
||||||
|
user_info: UserInfo = None) -> Optional[Relationship]:
|
||||||
|
"""获取用户关系对象
|
||||||
|
Args:
|
||||||
|
user_id: 用户ID(可选,如果提供user_info则不需要)
|
||||||
|
platform: 平台(可选,如果提供user_info则不需要)
|
||||||
|
user_info: 用户信息对象(可选)
|
||||||
|
Returns:
|
||||||
|
Relationship: 关系对象
|
||||||
|
"""
|
||||||
|
# 确定user_id和platform
|
||||||
|
if user_info is not None:
|
||||||
|
user_id = user_info.user_id
|
||||||
|
platform = user_info.platform or 'qq'
|
||||||
|
else:
|
||||||
|
platform = platform or 'qq'
|
||||||
|
|
||||||
|
if user_id is None:
|
||||||
|
raise ValueError("必须提供user_id或user_info")
|
||||||
|
|
||||||
|
key = (user_id, platform)
|
||||||
|
if key in self.relationships:
|
||||||
|
return self.relationships[key]
|
||||||
def get_relationship(self,
|
def get_relationship(self,
|
||||||
user_id: int = None,
|
user_id: int = None,
|
||||||
platform: str = None,
|
platform: str = None,
|
||||||
@@ -169,10 +248,18 @@ class RelationshipManager:
|
|||||||
if 'platform' not in data:
|
if 'platform' not in data:
|
||||||
data['platform'] = 'qq'
|
data['platform'] = 'qq'
|
||||||
|
|
||||||
|
rela = Relationship(data=data)
|
||||||
|
"""从数据库加载或创建新的关系对象"""
|
||||||
|
# 确保data中有platform字段,如果没有则默认为'qq'
|
||||||
|
if 'platform' not in data:
|
||||||
|
data['platform'] = 'qq'
|
||||||
|
|
||||||
rela = Relationship(data=data)
|
rela = Relationship(data=data)
|
||||||
rela.saved = True
|
rela.saved = True
|
||||||
key = (rela.user_id, rela.platform)
|
key = (rela.user_id, rela.platform)
|
||||||
self.relationships[key] = rela
|
self.relationships[key] = rela
|
||||||
|
key = (rela.user_id, rela.platform)
|
||||||
|
self.relationships[key] = rela
|
||||||
return rela
|
return rela
|
||||||
|
|
||||||
async def load_all_relationships(self):
|
async def load_all_relationships(self):
|
||||||
@@ -190,6 +277,7 @@ class RelationshipManager:
|
|||||||
# 依次加载每条记录
|
# 依次加载每条记录
|
||||||
for data in all_relationships:
|
for data in all_relationships:
|
||||||
await self.load_relationship(data)
|
await self.load_relationship(data)
|
||||||
|
await self.load_relationship(data)
|
||||||
print(f"\033[1;32m[关系管理]\033[0m 已加载 {len(self.relationships)} 条关系记录")
|
print(f"\033[1;32m[关系管理]\033[0m 已加载 {len(self.relationships)} 条关系记录")
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
@@ -200,15 +288,19 @@ class RelationshipManager:
|
|||||||
async def _save_all_relationships(self):
|
async def _save_all_relationships(self):
|
||||||
"""将所有关系数据保存到数据库"""
|
"""将所有关系数据保存到数据库"""
|
||||||
# 保存所有关系数据
|
# 保存所有关系数据
|
||||||
|
for (userid, platform), relationship in self.relationships.items():
|
||||||
for (userid, platform), relationship in self.relationships.items():
|
for (userid, platform), relationship in self.relationships.items():
|
||||||
if not relationship.saved:
|
if not relationship.saved:
|
||||||
relationship.saved = True
|
relationship.saved = True
|
||||||
await self.storage_relationship(relationship)
|
await self.storage_relationship(relationship)
|
||||||
|
|
||||||
|
async def storage_relationship(self, relationship: Relationship):
|
||||||
|
"""将关系记录存储到数据库中"""
|
||||||
async def storage_relationship(self, relationship: Relationship):
|
async def storage_relationship(self, relationship: Relationship):
|
||||||
"""将关系记录存储到数据库中"""
|
"""将关系记录存储到数据库中"""
|
||||||
user_id = relationship.user_id
|
user_id = relationship.user_id
|
||||||
platform = relationship.platform
|
platform = relationship.platform
|
||||||
|
platform = relationship.platform
|
||||||
nickname = relationship.nickname
|
nickname = relationship.nickname
|
||||||
relationship_value = relationship.relationship_value
|
relationship_value = relationship.relationship_value
|
||||||
gender = relationship.gender
|
gender = relationship.gender
|
||||||
@@ -217,8 +309,10 @@ class RelationshipManager:
|
|||||||
|
|
||||||
db = Database.get_instance()
|
db = Database.get_instance()
|
||||||
db.db.relationships.update_one(
|
db.db.relationships.update_one(
|
||||||
|
{'user_id': user_id, 'platform': platform},
|
||||||
{'user_id': user_id, 'platform': platform},
|
{'user_id': user_id, 'platform': platform},
|
||||||
{'$set': {
|
{'$set': {
|
||||||
|
'platform': platform,
|
||||||
'platform': platform,
|
'platform': platform,
|
||||||
'nickname': nickname,
|
'nickname': nickname,
|
||||||
'relationship_value': relationship_value,
|
'relationship_value': relationship_value,
|
||||||
@@ -229,6 +323,28 @@ class RelationshipManager:
|
|||||||
upsert=True
|
upsert=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_name(self,
|
||||||
|
user_id: int = None,
|
||||||
|
platform: str = None,
|
||||||
|
user_info: UserInfo = None) -> str:
|
||||||
|
"""获取用户昵称
|
||||||
|
Args:
|
||||||
|
user_id: 用户ID(可选,如果提供user_info则不需要)
|
||||||
|
platform: 平台(可选,如果提供user_info则不需要)
|
||||||
|
user_info: 用户信息对象(可选)
|
||||||
|
Returns:
|
||||||
|
str: 用户昵称
|
||||||
|
"""
|
||||||
|
# 确定user_id和platform
|
||||||
|
if user_info is not None:
|
||||||
|
user_id = user_info.user_id
|
||||||
|
platform = user_info.platform or 'qq'
|
||||||
|
else:
|
||||||
|
platform = platform or 'qq'
|
||||||
|
|
||||||
|
if user_id is None:
|
||||||
|
raise ValueError("必须提供user_id或user_info")
|
||||||
|
|
||||||
def get_name(self,
|
def get_name(self,
|
||||||
user_id: int = None,
|
user_id: int = None,
|
||||||
platform: str = None,
|
platform: str = None,
|
||||||
@@ -254,6 +370,11 @@ class RelationshipManager:
|
|||||||
# 确保user_id是整数类型
|
# 确保user_id是整数类型
|
||||||
user_id = int(user_id)
|
user_id = int(user_id)
|
||||||
key = (user_id, platform)
|
key = (user_id, platform)
|
||||||
|
if key in self.relationships:
|
||||||
|
return self.relationships[key].nickname
|
||||||
|
elif user_info is not None:
|
||||||
|
return user_info.user_nickname or user_info.user_cardname or "某人"
|
||||||
|
key = (user_id, platform)
|
||||||
if key in self.relationships:
|
if key in self.relationships:
|
||||||
return self.relationships[key].nickname
|
return self.relationships[key].nickname
|
||||||
elif user_info is not None:
|
elif user_info is not None:
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
from ...common.database import Database
|
from ...common.database import Database
|
||||||
from .message_base import MessageBase
|
from .message_base import MessageBase
|
||||||
|
|||||||
@@ -2,6 +2,9 @@ import asyncio
|
|||||||
from typing import Dict
|
from typing import Dict
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
from typing import Dict
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
from .config import global_config
|
from .config import global_config
|
||||||
from .message_base import UserInfo, GroupInfo
|
from .message_base import UserInfo, GroupInfo
|
||||||
from .chat_stream import chat_manager,ChatStream
|
from .chat_stream import chat_manager,ChatStream
|
||||||
@@ -9,6 +12,7 @@ from .chat_stream import chat_manager,ChatStream
|
|||||||
|
|
||||||
class WillingManager:
|
class WillingManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
self.chat_reply_willing: Dict[str, float] = {} # 存储每个聊天流的回复意愿
|
||||||
self.chat_reply_willing: Dict[str, float] = {} # 存储每个聊天流的回复意愿
|
self.chat_reply_willing: Dict[str, float] = {} # 存储每个聊天流的回复意愿
|
||||||
self._decay_task = None
|
self._decay_task = None
|
||||||
self._started = False
|
self._started = False
|
||||||
@@ -19,6 +23,8 @@ class WillingManager:
|
|||||||
await asyncio.sleep(5)
|
await asyncio.sleep(5)
|
||||||
for chat_id in self.chat_reply_willing:
|
for chat_id in self.chat_reply_willing:
|
||||||
self.chat_reply_willing[chat_id] = max(0, self.chat_reply_willing[chat_id] * 0.6)
|
self.chat_reply_willing[chat_id] = max(0, self.chat_reply_willing[chat_id] * 0.6)
|
||||||
|
for chat_id in self.chat_reply_willing:
|
||||||
|
self.chat_reply_willing[chat_id] = max(0, self.chat_reply_willing[chat_id] * 0.6)
|
||||||
|
|
||||||
def get_willing(self,chat_stream:ChatStream) -> float:
|
def get_willing(self,chat_stream:ChatStream) -> float:
|
||||||
"""获取指定聊天流的回复意愿"""
|
"""获取指定聊天流的回复意愿"""
|
||||||
@@ -27,6 +33,9 @@ class WillingManager:
|
|||||||
return self.chat_reply_willing.get(stream.stream_id, 0)
|
return self.chat_reply_willing.get(stream.stream_id, 0)
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
def set_willing(self, chat_id: str, willing: float):
|
||||||
|
"""设置指定聊天流的回复意愿"""
|
||||||
|
self.chat_reply_willing[chat_id] = willing
|
||||||
def set_willing(self, chat_id: str, willing: float):
|
def set_willing(self, chat_id: str, willing: float):
|
||||||
"""设置指定聊天流的回复意愿"""
|
"""设置指定聊天流的回复意愿"""
|
||||||
self.chat_reply_willing[chat_id] = willing
|
self.chat_reply_willing[chat_id] = willing
|
||||||
@@ -81,6 +90,7 @@ class WillingManager:
|
|||||||
if reply_probability < 0:
|
if reply_probability < 0:
|
||||||
reply_probability = 0
|
reply_probability = 0
|
||||||
|
|
||||||
|
self.chat_reply_willing[chat_id] = min(current_willing, 3.0)
|
||||||
self.chat_reply_willing[chat_id] = min(current_willing, 3.0)
|
self.chat_reply_willing[chat_id] = min(current_willing, 3.0)
|
||||||
return reply_probability
|
return reply_probability
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user