Merge branch 'dev' of https://github.com/MaiM-with-u/MaiBot into dev
@@ -20,3 +20,4 @@
 - `chat_api.py`中获取流的参数中可以使用一个特殊的枚举类型来获得所有平台的 ChatStream 了。
 - `config_api.py`中的`get_global_config`和`get_plugin_config`方法现在支持嵌套访问的配置键名。
 - `database_api.py`中的`db_query`方法调整了参数顺序以增强参数限制的同时,保证了typing正确;`db_get`方法增加了`single_result`参数,与`db_query`保持一致。
+4. 现在增加了参数类型检查,完善了对应注释
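下面补充一段示意性代码(非本次提交的改动),串联上面更新日志提到的几个 API 变化;其中导入路径、`SpecialTypes` 的枚举成员名和 `ChatManager` 的方法名均为假设,实际以各 API 模块源码为准。

```python
# 示意代码:仅演示更新日志中的 API 约定;下述名称为假设,非实际实现
from src.plugin_system.apis import config_api
from src.plugin_system.apis.chat_api import ChatManager, SpecialTypes  # SpecialTypes 的实际定义位置以源码为准


def demo_plugin_apis():
    # config_api:get_global_config / get_plugin_config 支持嵌套键名
    # (此处以点号分隔为例,具体格式以 config_api 实现为准)
    nickname = config_api.get_global_config("bot.nickname")

    # chat_api:platform 参数接受字符串或 SpecialTypes 枚举,
    # 传入其他类型会触发本次新增的类型检查并抛出 TypeError
    try:
        streams = ChatManager.get_all_streams(SpecialTypes.ALL_PLATFORMS)  # 方法名与枚举成员为假设
    except TypeError:
        streams = []
    return nickname, streams
```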
@@ -47,7 +47,7 @@ class MaiEmoji:
         self.embedding = []
         self.hash = ""  # 初始为空,在创建实例时会计算
         self.description = ""
-        self.emotion = []
+        self.emotion: List[str] = []
         self.usage_count = 0
         self.last_used_time = time.time()
         self.register_time = time.time()
@@ -243,6 +243,8 @@ class HeartFChatting:
             loop_start_time = time.time()
             await self.relationship_builder.build_relation()
 
+            available_actions = {}
+
             # 第一步:动作修改
             with Timer("动作修改", cycle_timers):
                 try:
@@ -38,7 +38,9 @@ class HeartFCSender:
     def __init__(self):
         self.storage = MessageStorage()
 
-    async def send_message(self, message: MessageSending, typing=False, set_reply=False, storage_message=True, show_log=True):
+    async def send_message(
+        self, message: MessageSending, typing=False, set_reply=False, storage_message=True, show_log=True
+    ):
         """
         处理、发送并存储一条消息。
 
@@ -79,7 +79,9 @@ class ActionPlanner:
 
         self.last_obs_time_mark = 0.0
 
-    async def plan(self, mode: ChatMode = ChatMode.FOCUS) -> Dict[str, Dict[str, Any] | str]:  # sourcery skip: dict-comprehension
+    async def plan(
+        self, mode: ChatMode = ChatMode.FOCUS
+    ) -> Dict[str, Dict[str, Any] | str]:  # sourcery skip: dict-comprehension
         """
         规划器 (Planner): 使用LLM根据上下文决定做出什么动作。
         """
@@ -508,7 +508,7 @@ class DefaultReplyer:
         # 构建背景对话 prompt
         background_dialogue_prompt = ""
         if background_dialogue_list:
-            latest_25_msgs = background_dialogue_list[-int(global_config.chat.max_context_size*0.6):]
+            latest_25_msgs = background_dialogue_list[-int(global_config.chat.max_context_size * 0.6) :]
             background_dialogue_prompt_str = build_readable_messages(
                 latest_25_msgs,
                 replace_bot_name=True,
@@ -521,7 +521,7 @@ class DefaultReplyer:
         # 构建核心对话 prompt
         core_dialogue_prompt = ""
         if core_dialogue_list:
-            core_dialogue_list = core_dialogue_list[-int(global_config.chat.max_context_size*2):]  # 限制消息数量
+            core_dialogue_list = core_dialogue_list[-int(global_config.chat.max_context_size * 2) :]  # 限制消息数量
 
             core_dialogue_prompt_str = build_readable_messages(
                 core_dialogue_list,
@@ -586,7 +586,6 @@ class DefaultReplyer:
             limit=global_config.chat.max_context_size * 2,
         )
 
-
         message_list_before_now = get_raw_msg_before_timestamp_with_chat(
             chat_id=chat_id,
             timestamp=time.time(),
@@ -713,8 +712,6 @@ class DefaultReplyer:
         # 根据sender通过person_info_manager反向查找person_id,再获取user_id
         person_id = person_info_manager.get_person_id_by_person_name(sender)
 
-
-
         # 根据配置选择使用哪种 prompt 构建模式
         if global_config.chat.use_s4u_prompt_mode and person_id:
             # 使用 s4u 对话构建模式:分离当前对话对象和其他对话
@@ -726,7 +723,6 @@ class DefaultReplyer:
                 logger.warning(f"无法从person_id {person_id} 获取user_id: {e}")
                 target_user_id = ""
 
-
             # 构建分离的对话 prompt
             core_dialogue_prompt, background_dialogue_prompt = self.build_s4u_chat_history_prompts(
                 message_list_before_now_long, target_user_id
@@ -106,7 +106,6 @@ class ChatConfig(ConfigBase):
     focus_value: float = 1.0
     """麦麦的专注思考能力,越低越容易专注,消耗token也越多"""
 
-
     def get_current_talk_frequency(self, chat_stream_id: Optional[str] = None) -> float:
         """
         根据当前时间和聊天流获取对应的 talk_frequency
@@ -246,6 +245,7 @@ class ChatConfig(ConfigBase):
         except (ValueError, IndexError):
             return None
 
+
 @dataclass
 class MessageReceiveConfig(ConfigBase):
     """消息接收配置类"""
@@ -274,8 +274,6 @@ class NormalChatConfig(ConfigBase):
     """@bot 必然回复"""
 
 
-
-
 @dataclass
 class ExpressionConfig(ConfigBase):
     """表达配置类"""
@@ -63,10 +63,10 @@ class Individuality:
             personality_side: 人格侧面描述
             identity: 身份细节描述
         """
-        bot_nickname=global_config.bot.nickname
-        personality_core=global_config.personality.personality_core
-        personality_side=global_config.personality.personality_side
-        identity=global_config.personality.identity
+        bot_nickname = global_config.bot.nickname
+        personality_core = global_config.personality.personality_core
+        personality_side = global_config.personality.personality_side
+        identity = global_config.personality.identity
 
         logger.info("正在初始化个体特征")
         person_info_manager = get_person_info_manager()
@@ -168,7 +168,6 @@ class Individuality:
         else:
             logger.error("人设构建失败")
 
-
     async def get_personality_block(self) -> str:
         person_info_manager = get_person_info_manager()
         bot_person_id = person_info_manager.get_person_id("system", "bot_id")
@@ -200,7 +199,6 @@ class Individuality:
 
         return identity_block
 
-
     def _get_config_hash(
         self, bot_nickname: str, personality_core: str, personality_side: str, identity: list
     ) -> tuple[str, str]:
@@ -295,7 +293,6 @@ class Individuality:
         except IOError as e:
             logger.error(f"保存meta_info文件失败: {e}")
 
-
     async def _create_personality(self, personality_core: str, personality_side: str) -> str:
         # sourcery skip: merge-list-append, move-assign
         """使用LLM创建压缩版本的impression
@@ -42,7 +42,15 @@ class Personality:
         return cls._instance
 
     @classmethod
-    def initialize(cls, bot_nickname: str, personality_core: str, personality_side: str, identity: List[str] = None, compress_personality: bool = True, compress_identity: bool = True) -> "Personality":
+    def initialize(
+        cls,
+        bot_nickname: str,
+        personality_core: str,
+        personality_side: str,
+        identity: List[str] = None,
+        compress_personality: bool = True,
+        compress_identity: bool = True,
+    ) -> "Personality":
         """初始化人格特质
 
         Args:
@@ -30,7 +30,7 @@ class ContextMessage:
             "user_id": self.user_id,
             "content": self.content,
             "timestamp": self.timestamp.strftime("%m-%d %H:%M:%S"),
-            "group_name": self.group_name
+            "group_name": self.group_name,
         }
 
 
@@ -66,20 +66,20 @@ class ContextWebManager:
         self.app = web.Application()
 
         # 设置CORS
-        cors = aiohttp_cors.setup(self.app, defaults={
-            "*": aiohttp_cors.ResourceOptions(
-                allow_credentials=True,
-                expose_headers="*",
-                allow_headers="*",
-                allow_methods="*"
-            )
-        })
+        cors = aiohttp_cors.setup(
+            self.app,
+            defaults={
+                "*": aiohttp_cors.ResourceOptions(
+                    allow_credentials=True, expose_headers="*", allow_headers="*", allow_methods="*"
+                )
+            },
+        )
 
         # 添加路由
-        self.app.router.add_get('/', self.index_handler)
-        self.app.router.add_get('/ws', self.websocket_handler)
-        self.app.router.add_get('/api/contexts', self.get_contexts_handler)
-        self.app.router.add_get('/debug', self.debug_handler)
+        self.app.router.add_get("/", self.index_handler)
+        self.app.router.add_get("/ws", self.websocket_handler)
+        self.app.router.add_get("/api/contexts", self.get_contexts_handler)
+        self.app.router.add_get("/debug", self.debug_handler)
 
         # 为所有路由添加CORS
         for route in list(self.app.router.routes()):
@@ -88,7 +88,7 @@ class ContextWebManager:
         self.runner = web.AppRunner(self.app)
         await self.runner.setup()
 
-        self.site = web.TCPSite(self.runner, 'localhost', self.port)
+        self.site = web.TCPSite(self.runner, "localhost", self.port)
         await self.site.start()
 
         logger.info(f"🌐 上下文网页服务器启动成功在 http://localhost:{self.port}")
@@ -118,7 +118,8 @@ class ContextWebManager:
 
     async def index_handler(self, request):
         """主页处理器"""
-        html_content = '''
+        html_content = (
+            """
 <!DOCTYPE html>
 <html>
 <head>
@@ -231,7 +232,9 @@ class ContextWebManager:
 
         function connectWebSocket() {
             console.log('正在连接WebSocket...');
-            ws = new WebSocket('ws://localhost:''' + str(self.port) + '''/ws');
+            ws = new WebSocket('ws://localhost:"""
+            + str(self.port)
+            + """/ws');
 
             ws.onopen = function() {
                 console.log('WebSocket连接已建立');
@@ -402,8 +405,9 @@ class ContextWebManager:
 </script>
 </body>
 </html>
-        '''
-        return web.Response(text=html_content, content_type='text/html')
+        """
+        )
+        return web.Response(text=html_content, content_type="text/html")
 
     async def websocket_handler(self, request):
         """WebSocket处理器"""
@@ -418,7 +422,7 @@ class ContextWebManager:
 
         async for msg in ws:
             if msg.type == WSMsgType.ERROR:
-                logger.error(f'WebSocket错误: {ws.exception()}')
+                logger.error(f"WebSocket错误: {ws.exception()}")
                 break
 
         # 清理断开的连接
@@ -438,7 +442,7 @@ class ContextWebManager:
         all_context_msgs.sort(key=lambda x: x.timestamp)
 
         # 转换为字典格式
-        contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages:]]
+        contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages :]]
 
         logger.debug(f"返回上下文数据,共 {len(contexts_data)} 条消息")
         return web.json_response({"contexts": contexts_data})
@@ -461,14 +465,14 @@ class ContextWebManager:
                 content = msg.content[:50] + "..." if len(msg.content) > 50 else msg.content
                 messages_html += f'<div class="message">[{timestamp}] {msg.user_name}: {content}</div>'
 
-            chats_html += f'''
+            chats_html += f"""
             <div class="chat">
                 <h3>聊天 {chat_id} ({len(contexts)} 条消息)</h3>
                 {messages_html}
             </div>
-            '''
+            """
 
-        html_content = f'''
+        html_content = f"""
 <!DOCTYPE html>
 <html>
 <head>
@@ -510,9 +514,9 @@ class ContextWebManager:
 </script>
 </body>
 </html>
-        '''
+        """
 
-        return web.Response(text=html_content, content_type='text/html')
+        return web.Response(text=html_content, content_type="text/html")
 
     async def add_message(self, chat_id: str, message: MessageRecv):
         """添加新消息到上下文"""
@@ -526,14 +530,18 @@ class ContextWebManager:
         # 统计当前总消息数
         total_messages = sum(len(contexts) for contexts in self.contexts.values())
 
-        logger.info(f"✅ 添加消息到上下文 [总数: {total_messages}]: [{context_msg.group_name}] {context_msg.user_name}: {context_msg.content}")
+        logger.info(
+            f"✅ 添加消息到上下文 [总数: {total_messages}]: [{context_msg.group_name}] {context_msg.user_name}: {context_msg.content}"
+        )
 
         # 调试:打印当前所有消息
         logger.info(f"📝 当前上下文中的所有消息:")
         for cid, contexts in self.contexts.items():
             logger.info(f"  聊天 {cid}: {len(contexts)} 条消息")
             for i, msg in enumerate(contexts):
-                logger.info(f"    {i+1}. [{msg.timestamp.strftime('%H:%M:%S')}] {msg.user_name}: {msg.content[:30]}...")
+                logger.info(
+                    f"    {i + 1}. [{msg.timestamp.strftime('%H:%M:%S')}] {msg.user_name}: {msg.content[:30]}..."
+                )
 
         # 广播更新给所有WebSocket连接
         await self.broadcast_contexts()
@@ -548,7 +556,7 @@ class ContextWebManager:
         all_context_msgs.sort(key=lambda x: x.timestamp)
 
         # 转换为字典格式
-        contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages:]]
+        contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages :]]
 
         data = {"contexts": contexts_data}
         await ws.send_str(json.dumps(data, ensure_ascii=False))
@@ -567,7 +575,7 @@ class ContextWebManager:
         all_context_msgs.sort(key=lambda x: x.timestamp)
 
         # 转换为字典格式
-        contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages:]]
+        contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages :]]
 
         data = {"contexts": contexts_data}
         message = json.dumps(data, ensure_ascii=False)
@@ -614,4 +622,3 @@ async def init_context_web_manager():
     manager = get_context_web_manager()
     await manager.start_server()
     return manager
-
@@ -11,6 +11,7 @@ from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
 from src.manager.async_task_manager import AsyncTask, async_task_manager
 from src.plugin_system.apis import send_api
 
+
 async def send_loading(chat_id: str, content: str):
     await send_api.custom_to_stream(
         message_type="loading",
@@ -20,6 +21,7 @@ async def send_loading(chat_id: str, content: str):
         show_log=True,
     )
 
+
 async def send_unloading(chat_id: str):
     await send_api.custom_to_stream(
         message_type="loading",
@@ -28,4 +30,3 @@ async def send_unloading(chat_id: str):
         storage_message=False,
         show_log=True,
     )
-
@@ -30,7 +30,6 @@ class MessageSenderContainer:
         self._paused_event = asyncio.Event()
         self._paused_event.set()  # 默认设置为非暂停状态
 
-
     async def add_message(self, chunk: str):
         """向队列中添加一个消息块。"""
         await self.queue.put(chunk)
@@ -201,9 +200,9 @@ class S4UChat:
         score = 0.0
         # 如果消息 @ 了机器人,则增加一个很大的分数
         # if f"@{global_config.bot.nickname}" in message.processed_plain_text or any(
         #     f"@{alias}" in message.processed_plain_text for alias in global_config.bot.alias_names
         # ):
         #     score += self.at_bot_priority_bonus
 
         # 加上消息自带的优先级
         score += priority_info.get("message_priority", 0.0)
@@ -302,7 +301,9 @@ class S4UChat:
                 self._normal_queue.put_nowait(item)
 
         if removed_count > 0:
-            logger.info(f"[{self.stream_name}] Cleaned up {removed_count} old normal messages outside recent {self.recent_message_keep_count} range.")
+            logger.info(
+                f"[{self.stream_name}] Cleaned up {removed_count} old normal messages outside recent {self.recent_message_keep_count} range."
+            )
 
     async def _message_processor(self):
         """调度器:优先处理VIP队列,然后处理普通队列。"""
@@ -396,12 +397,10 @@ class S4UChat:
                     # a. 发送文本块
                     await sender_container.add_message(chunk)
 
-
                 # 等待所有文本消息发送完成
                 await sender_container.close()
                 await sender_container.join()
 
-
                 logger.info(f"[{self.stream_name}] 所有文本块处理完毕。")
 
             except asyncio.CancelledError:
@@ -322,7 +322,7 @@ class ChatMood:
             "joy": mood_values.get("joy", 5),
             "anger": mood_values.get("anger", 1),
             "sorrow": mood_values.get("sorrow", 1),
-            "fear": mood_values.get("fear", 1)
+            "fear": mood_values.get("fear", 1),
         }
 
         await send_api.custom_to_stream(
@@ -379,14 +379,18 @@ class MoodRegressionTask(AsyncTask):
                     logger.debug(f"[回归任务] {chat_info} 已达到最大回归次数(3次),停止回归")
                     continue
 
-                logger.info(f"[回归任务] {chat_info} 开始情绪回归 ({regress_reason},第{mood.regression_count + 1}次回归)")
+                logger.info(
+                    f"[回归任务] {chat_info} 开始情绪回归 ({regress_reason},第{mood.regression_count + 1}次回归)"
+                )
                 await mood.regress_mood()
                 regression_executed += 1
             else:
                 if has_extreme_emotion:
                     remaining_time = 5 - time_since_last_change
                     high_emotion_str = ", ".join([f"{k}={v}" for k, v in high_emotions.items()])
-                    logger.debug(f"[回归任务] {chat_info} 存在极端情绪({high_emotion_str}),距离快速回归还需等待{int(remaining_time)}秒")
+                    logger.debug(
+                        f"[回归任务] {chat_info} 存在极端情绪({high_emotion_str}),距离快速回归还需等待{int(remaining_time)}秒"
+                    )
                 else:
                     remaining_time = 120 - time_since_last_change
                     logger.debug(f"[回归任务] {chat_info} 距离回归还需等待{int(remaining_time)}秒")
@@ -107,7 +107,6 @@ class S4UStreamGenerator:
         model_name: str,
         **kwargs,
     ) -> AsyncGenerator[str, None]:
-
         buffer = ""
         delimiters = ",。!?,.!?\n\r"  # For final trimming
         punctuation_buffer = ""
@@ -43,22 +43,23 @@ logger = get_logger("watching")
 
 class WatchingState(Enum):
     """视线状态枚举"""
+
     WANDERING = "wandering"  # 随意看
     DANMU = "danmu"  # 看弹幕
     LENS = "lens"  # 看镜头
 
 
 class ChatWatching:
     def __init__(self, chat_id: str):
         self.chat_id: str = chat_id
         self.current_state: WatchingState = WatchingState.LENS  # 默认看镜头
         self.last_sent_state: Optional[WatchingState] = None  # 上次发送的状态
         self.state_needs_update: bool = True  # 是否需要更新状态
 
         # 状态切换相关
         self.is_replying: bool = False  # 是否正在生成回复
         self.reply_finished_time: Optional[float] = None  # 回复完成时间
         self.danmu_viewing_duration: float = 1.0  # 看弹幕持续时间(秒)
 
         logger.info(f"[{self.chat_id}] 视线管理器初始化,默认状态: {self.current_state.value}")
 
@@ -109,20 +110,14 @@ class ChatWatching:
             await asyncio.sleep(self.danmu_viewing_duration)
 
             # 检查是否仍需要切换(可能状态已经被其他事件改变)
-            if (self.reply_finished_time is not None and
-                self.current_state == WatchingState.DANMU and
-                not self.is_replying):
-
+            if self.reply_finished_time is not None and self.current_state == WatchingState.DANMU and not self.is_replying:
                 await self._change_state(WatchingState.LENS, "看弹幕时间结束")
                 self.reply_finished_time = None  # 重置完成时间
 
     async def _send_watching_update(self):
         """立即发送视线状态更新"""
         await send_api.custom_to_stream(
-            message_type="watching",
-            content=self.current_state.value,
-            stream_id=self.chat_id,
-            storage_message=False
+            message_type="watching", content=self.current_state.value, stream_id=self.chat_id, storage_message=False
         )
 
         logger.info(f"[{self.chat_id}] 发送视线状态更新: {self.current_state.value}")
@@ -139,11 +134,10 @@ class ChatWatching:
             "current_state": self.current_state.value,
             "is_replying": self.is_replying,
             "reply_finished_time": self.reply_finished_time,
-            "state_needs_update": self.state_needs_update
+            "state_needs_update": self.state_needs_update,
         }
 
 
-
 class WatchingManager:
     def __init__(self):
         self.watching_list: list[ChatWatching] = []
@@ -200,10 +194,7 @@ class WatchingManager:
 
     def get_all_watching_info(self) -> dict:
         """获取所有聊天的视线状态信息(用于调试)"""
-        return {
-            watching.chat_id: watching.get_state_info()
-            for watching in self.watching_list
-        }
+        return {watching.chat_id: watching.get_state_info() for watching in self.watching_list}
 
 
 # 全局视线管理器实例
@@ -92,7 +92,7 @@ class ChatMood:
             chat_id=self.chat_id,
             timestamp_start=self.last_change_time,
             timestamp_end=message_time,
-            limit=int(global_config.chat.max_context_size/3),
+            limit=int(global_config.chat.max_context_size / 3),
             limit_mode="last",
         )
         chat_talking_prompt = build_readable_messages(
@@ -121,8 +121,6 @@ class ChatMood:
             mood_state=self.mood_state,
         )
 
-
-
         response, (reasoning_content, model_name) = await self.mood_model.generate_response_async(prompt=prompt)
         if global_config.debug.show_prompt:
             logger.info(f"{self.log_prefix} prompt: {prompt}")
@@ -170,7 +168,6 @@ class ChatMood:
             mood_state=self.mood_state,
         )
 
-
         response, (reasoning_content, model_name) = await self.mood_model.generate_response_async(prompt=prompt)
 
         if global_config.debug.show_prompt:
@@ -39,7 +39,12 @@ class ChatManager:
 
         Returns:
             List[ChatStream]: 聊天流列表
+
+        Raises:
+            TypeError: 如果 platform 不是字符串或 SpecialTypes 枚举类型
         """
+        if not isinstance(platform, (str, SpecialTypes)):
+            raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
         streams = []
         try:
             for _, stream in get_chat_manager().streams.items():
@@ -60,6 +65,8 @@ class ChatManager:
         Returns:
             List[ChatStream]: 群聊聊天流列表
         """
+        if not isinstance(platform, (str, SpecialTypes)):
+            raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
         streams = []
         try:
             for _, stream in get_chat_manager().streams.items():
@@ -79,7 +86,12 @@ class ChatManager:
 
         Returns:
             List[ChatStream]: 私聊聊天流列表
+
+        Raises:
+            TypeError: 如果 platform 不是字符串或 SpecialTypes 枚举类型
         """
+        if not isinstance(platform, (str, SpecialTypes)):
+            raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
         streams = []
         try:
             for _, stream in get_chat_manager().streams.items():
@@ -102,7 +114,17 @@ class ChatManager:
 
         Returns:
             Optional[ChatStream]: 聊天流对象,如果未找到返回None
+
+        Raises:
+            ValueError: 如果 group_id 为空字符串
+            TypeError: 如果 group_id 不是字符串类型或 platform 不是字符串或 SpecialTypes
         """
+        if not isinstance(group_id, str):
+            raise TypeError("group_id 必须是字符串类型")
+        if not isinstance(platform, (str, SpecialTypes)):
+            raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
+        if not group_id:
+            raise ValueError("group_id 不能为空")
         try:
             for _, stream in get_chat_manager().streams.items():
                 if (
@@ -129,7 +151,17 @@ class ChatManager:
 
         Returns:
             Optional[ChatStream]: 聊天流对象,如果未找到返回None
+
+        Raises:
+            ValueError: 如果 user_id 为空字符串
+            TypeError: 如果 user_id 不是字符串类型或 platform 不是字符串或 SpecialTypes
         """
+        if not isinstance(user_id, str):
+            raise TypeError("user_id 必须是字符串类型")
+        if not isinstance(platform, (str, SpecialTypes)):
+            raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
+        if not user_id:
+            raise ValueError("user_id 不能为空")
         try:
             for _, stream in get_chat_manager().streams.items():
                 if (
@@ -153,9 +185,15 @@ class ChatManager:
 
         Returns:
             str: 聊天类型 ("group", "private", "unknown")
+
+        Raises:
+            TypeError: 如果 chat_stream 不是 ChatStream 类型
+            ValueError: 如果 chat_stream 为空
         """
+        if not isinstance(chat_stream, ChatStream):
+            raise TypeError("chat_stream 必须是 ChatStream 类型")
         if not chat_stream:
-            raise ValueError("chat_stream cannot be None")
+            raise ValueError("chat_stream 不能为 None")
 
         if hasattr(chat_stream, "group_info"):
             return "group" if chat_stream.group_info else "private"
@@ -170,9 +208,15 @@ class ChatManager:
 
         Returns:
             Dict[str, Any]: 聊天流信息字典
+
+        Raises:
+            TypeError: 如果 chat_stream 不是 ChatStream 类型
+            ValueError: 如果 chat_stream 为空
         """
         if not chat_stream:
-            return {}
+            raise ValueError("chat_stream 不能为 None")
+        if not isinstance(chat_stream, ChatStream):
+            raise TypeError("chat_stream 必须是 ChatStream 类型")
 
         try:
             info: Dict[str, Any] = {
@@ -8,6 +8,8 @@
     count = emoji_api.get_count()
 """
 
+import random
+
 from typing import Optional, Tuple, List
 from src.common.logger import get_logger
 from src.chat.emoji_system.emoji_manager import get_emoji_manager
@@ -29,7 +31,15 @@ async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]
 
     Returns:
         Optional[Tuple[str, str, str]]: (base64编码, 表情包描述, 匹配的情感标签) 或 None
+
+    Raises:
+        ValueError: 如果描述为空字符串
+        TypeError: 如果描述不是字符串类型
     """
+    if not description:
+        raise ValueError("描述不能为空")
+    if not isinstance(description, str):
+        raise TypeError("描述必须是字符串类型")
     try:
         logger.debug(f"[EmojiAPI] 根据描述获取表情包: {description}")
 
@@ -55,7 +65,7 @@ async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]
         return None
 
 
-async def get_random(count: int = 1) -> Optional[List[Tuple[str, str, str]]]:
+async def get_random(count: Optional[int] = 1) -> Optional[List[Tuple[str, str, str]]]:
     """随机获取指定数量的表情包
 
     Args:
@@ -63,8 +73,17 @@ async def get_random(count: int = 1) -> Optional[List[Tuple[str, str, str]]]:
 
     Returns:
         Optional[List[Tuple[str, str, str]]]: 包含(base64编码, 表情包描述, 随机情感标签)的元组列表,如果失败则为None
+
+    Raises:
+        TypeError: 如果count不是整数类型
+        ValueError: 如果count为负数
     """
-    if count <= 0:
+    if not isinstance(count, int):
+        raise TypeError("count 必须是整数类型")
+    if count < 0:
+        raise ValueError("count 不能为负数")
+    if count == 0:
+        logger.warning("[EmojiAPI] count 为0,返回空列表")
         return []
 
     try:
@@ -90,8 +109,6 @@ async def get_random(count: int = 1) -> Optional[List[Tuple[str, str, str]]]:
         count = len(valid_emojis)
 
         # 随机选择
-        import random
-
         selected_emojis = random.sample(valid_emojis, count)
 
         results = []
@@ -128,7 +145,15 @@ async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]:
 
     Returns:
         Optional[Tuple[str, str, str]]: (base64编码, 表情包描述, 匹配的情感标签) 或 None
+
+    Raises:
+        ValueError: 如果情感标签为空字符串
+        TypeError: 如果情感标签不是字符串类型
     """
+    if not emotion:
+        raise ValueError("情感标签不能为空")
+    if not isinstance(emotion, str):
+        raise TypeError("情感标签必须是字符串类型")
     try:
         logger.info(f"[EmojiAPI] 根据情感获取表情包: {emotion}")
 
@@ -146,8 +171,6 @@ async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]:
             return None
 
         # 随机选择匹配的表情包
-        import random
-
         selected_emoji = random.choice(matching_emojis)
         emoji_base64 = image_path_to_base64(selected_emoji.full_path)
 
@@ -185,11 +208,11 @@ def get_count() -> int:
         return 0
 
 
-def get_info() -> dict:
+def get_info():
     """获取表情包系统信息
 
     Returns:
-        dict: 包含表情包数量、最大数量等信息
+        dict: 包含表情包数量、最大数量、可用数量信息
     """
     try:
         emoji_manager = get_emoji_manager()
@@ -203,7 +226,7 @@ def get_info() -> dict:
         return {"current_count": 0, "max_count": 0, "available_emojis": 0}
 
 
-def get_emotions() -> list:
+def get_emotions() -> List[str]:
     """获取所有可用的情感标签
 
     Returns:
@@ -223,7 +246,7 @@ def get_emotions() -> list:
         return []
 
 
-def get_descriptions() -> list:
+def get_descriptions() -> List[str]:
     """获取所有表情包描述
 
     Returns:
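结合上面 emoji_api 各函数新增的参数校验,补充一段示意性调用代码(非本次提交的改动);函数名取自上方差异内容,模块导入路径为假设。

```python
# 示意代码:演示 emoji_api 参数校验收紧后的行为;导入路径为假设
from src.plugin_system.apis import emoji_api


async def demo_emoji_api():
    # 正常调用:返回 (base64编码, 表情包描述, 匹配的情感标签) 或 None
    happy = await emoji_api.get_by_emotion("开心")

    # count 为 0:记录 warning 并返回空列表;非整数抛 TypeError,负数抛 ValueError
    empty = await emoji_api.get_random(0)

    try:
        await emoji_api.get_by_description("")  # 空描述会抛出 ValueError
    except ValueError:
        pass
    try:
        await emoji_api.get_random(-1)  # 负数会抛出 ValueError
    except ValueError:
        pass
    return happy, empty
```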
@@ -5,11 +5,12 @@
 使用方式:
     from src.plugin_system.apis import generator_api
     replyer = generator_api.get_replyer(chat_stream)
-    success, reply_set = await generator_api.generate_reply(chat_stream, action_data, reasoning)
+    success, reply_set, _ = await generator_api.generate_reply(chat_stream, action_data, reasoning)
 """
 
 import traceback
 from typing import Tuple, Any, Dict, List, Optional
+from rich.traceback import install
 from src.common.logger import get_logger
 from src.chat.replyer.default_generator import DefaultReplyer
 from src.chat.message_receive.chat_stream import ChatStream
@@ -17,6 +18,8 @@ from src.chat.utils.utils import process_llm_response
 from src.chat.replyer.replyer_manager import replyer_manager
 from src.plugin_system.base.component_types import ActionInfo
 
+install(extra_lines=3)
+
 logger = get_logger("generator_api")
 
 
@@ -44,7 +47,12 @@ def get_replyer(
 
     Returns:
         Optional[DefaultReplyer]: 回复器对象,如果获取失败则返回None
+
+    Raises:
+        ValueError: chat_stream 和 chat_id 均为空
     """
+    if not chat_id and not chat_stream:
+        raise ValueError("chat_stream 和 chat_id 不可均为空")
     try:
         logger.debug(f"[GeneratorAPI] 正在获取回复器,chat_id: {chat_id}, chat_stream: {'有' if chat_stream else '无'}")
         return replyer_manager.get_replyer(
@@ -14,7 +14,6 @@ from src.config.config import global_config
 
 logger = get_logger("llm_api")
 
-
 # =============================================================================
 # LLM模型API函数
 # =============================================================================
@@ -31,8 +30,21 @@ def get_available_models() -> Dict[str, Any]:
             logger.error("[LLMAPI] 无法获取模型列表:全局配置中未找到 model 配置")
             return {}
+
+        # 自动获取所有属性并转换为字典形式
+        rets = {}
         models = global_config.model
-        return models
+        attrs = dir(models)
+        for attr in attrs:
+            if not attr.startswith("__"):
+                try:
+                    value = getattr(models, attr)
+                    if not callable(value):  # 排除方法
+                        rets[attr] = value
+                except Exception as e:
+                    logger.debug(f"[LLMAPI] 获取属性 {attr} 失败: {e}")
+                    continue
+        return rets
 
     except Exception as e:
         logger.error(f"[LLMAPI] 获取可用模型失败: {e}")
         return {}
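按上面重写后的 `get_available_models`,返回值是由 `global_config.model` 的各非调用属性组成的字典;下面是一段示意性用法(非本次提交的改动,导入路径为假设)。

```python
# 示意代码:get_available_models 现在返回 {属性名: 配置值} 形式的字典;导入路径为假设
from src.plugin_system.apis import llm_api


def list_model_configs():
    models = llm_api.get_available_models()
    for name, cfg in models.items():
        print(name, cfg)  # 每一项对应 global_config.model 中的一个非调用属性
    return models
```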
@@ -114,7 +114,11 @@ async def _send_to_target(
 
         # 发送消息
         sent_msg = await heart_fc_sender.send_message(
-            bot_message, typing=typing, set_reply=(anchor_message is not None), storage_message=storage_message, show_log=show_log
+            bot_message,
+            typing=typing,
+            set_reply=(anchor_message is not None),
+            storage_message=storage_message,
+            show_log=show_log,
         )
 
         if sent_msg:
@@ -363,7 +367,9 @@ async def custom_to_stream(
     Returns:
         bool: 是否发送成功
     """
-    return await _send_to_target(message_type, content, stream_id, display_message, typing, reply_to, storage_message, show_log)
+    return await _send_to_target(
+        message_type, content, stream_id, display_message, typing, reply_to, storage_message, show_log
+    )
 
 
 async def text_to_group(