Merge branch 'dev' of https://github.com/MaiM-with-u/MaiBot into dev
This commit is contained in:
@@ -20,3 +20,4 @@
|
||||
- `chat_api.py`中获取流的参数中可以使用一个特殊的枚举类型来获得所有平台的 ChatStream 了。
|
||||
- `config_api.py`中的`get_global_config`和`get_plugin_config`方法现在支持嵌套访问的配置键名。
|
||||
- `database_api.py`中的`db_query`方法调整了参数顺序以增强参数限制的同时,保证了typing正确;`db_get`方法增加了`single_result`参数,与`db_query`保持一致。
|
||||
4. 现在增加了参数类型检查,完善了对应注释
|
||||
@@ -47,7 +47,7 @@ class MaiEmoji:
|
||||
self.embedding = []
|
||||
self.hash = "" # 初始为空,在创建实例时会计算
|
||||
self.description = ""
|
||||
self.emotion = []
|
||||
self.emotion: List[str] = []
|
||||
self.usage_count = 0
|
||||
self.last_used_time = time.time()
|
||||
self.register_time = time.time()
|
||||
|
||||
@@ -243,6 +243,8 @@ class HeartFChatting:
|
||||
loop_start_time = time.time()
|
||||
await self.relationship_builder.build_relation()
|
||||
|
||||
available_actions = {}
|
||||
|
||||
# 第一步:动作修改
|
||||
with Timer("动作修改", cycle_timers):
|
||||
try:
|
||||
|
||||
@@ -38,7 +38,9 @@ class HeartFCSender:
|
||||
def __init__(self):
|
||||
self.storage = MessageStorage()
|
||||
|
||||
async def send_message(self, message: MessageSending, typing=False, set_reply=False, storage_message=True, show_log=True):
|
||||
async def send_message(
|
||||
self, message: MessageSending, typing=False, set_reply=False, storage_message=True, show_log=True
|
||||
):
|
||||
"""
|
||||
处理、发送并存储一条消息。
|
||||
|
||||
|
||||
@@ -79,7 +79,9 @@ class ActionPlanner:
|
||||
|
||||
self.last_obs_time_mark = 0.0
|
||||
|
||||
async def plan(self, mode: ChatMode = ChatMode.FOCUS) -> Dict[str, Dict[str, Any] | str]: # sourcery skip: dict-comprehension
|
||||
async def plan(
|
||||
self, mode: ChatMode = ChatMode.FOCUS
|
||||
) -> Dict[str, Dict[str, Any] | str]: # sourcery skip: dict-comprehension
|
||||
"""
|
||||
规划器 (Planner): 使用LLM根据上下文决定做出什么动作。
|
||||
"""
|
||||
|
||||
@@ -508,7 +508,7 @@ class DefaultReplyer:
|
||||
# 构建背景对话 prompt
|
||||
background_dialogue_prompt = ""
|
||||
if background_dialogue_list:
|
||||
latest_25_msgs = background_dialogue_list[-int(global_config.chat.max_context_size*0.6):]
|
||||
latest_25_msgs = background_dialogue_list[-int(global_config.chat.max_context_size * 0.6) :]
|
||||
background_dialogue_prompt_str = build_readable_messages(
|
||||
latest_25_msgs,
|
||||
replace_bot_name=True,
|
||||
@@ -521,7 +521,7 @@ class DefaultReplyer:
|
||||
# 构建核心对话 prompt
|
||||
core_dialogue_prompt = ""
|
||||
if core_dialogue_list:
|
||||
core_dialogue_list = core_dialogue_list[-int(global_config.chat.max_context_size*2):] # 限制消息数量
|
||||
core_dialogue_list = core_dialogue_list[-int(global_config.chat.max_context_size * 2) :] # 限制消息数量
|
||||
|
||||
core_dialogue_prompt_str = build_readable_messages(
|
||||
core_dialogue_list,
|
||||
@@ -586,7 +586,6 @@ class DefaultReplyer:
|
||||
limit=global_config.chat.max_context_size * 2,
|
||||
)
|
||||
|
||||
|
||||
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_id,
|
||||
timestamp=time.time(),
|
||||
@@ -713,8 +712,6 @@ class DefaultReplyer:
|
||||
# 根据sender通过person_info_manager反向查找person_id,再获取user_id
|
||||
person_id = person_info_manager.get_person_id_by_person_name(sender)
|
||||
|
||||
|
||||
|
||||
# 根据配置选择使用哪种 prompt 构建模式
|
||||
if global_config.chat.use_s4u_prompt_mode and person_id:
|
||||
# 使用 s4u 对话构建模式:分离当前对话对象和其他对话
|
||||
@@ -726,7 +723,6 @@ class DefaultReplyer:
|
||||
logger.warning(f"无法从person_id {person_id} 获取user_id: {e}")
|
||||
target_user_id = ""
|
||||
|
||||
|
||||
# 构建分离的对话 prompt
|
||||
core_dialogue_prompt, background_dialogue_prompt = self.build_s4u_chat_history_prompts(
|
||||
message_list_before_now_long, target_user_id
|
||||
|
||||
@@ -106,7 +106,6 @@ class ChatConfig(ConfigBase):
|
||||
focus_value: float = 1.0
|
||||
"""麦麦的专注思考能力,越低越容易专注,消耗token也越多"""
|
||||
|
||||
|
||||
def get_current_talk_frequency(self, chat_stream_id: Optional[str] = None) -> float:
|
||||
"""
|
||||
根据当前时间和聊天流获取对应的 talk_frequency
|
||||
@@ -246,6 +245,7 @@ class ChatConfig(ConfigBase):
|
||||
except (ValueError, IndexError):
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageReceiveConfig(ConfigBase):
|
||||
"""消息接收配置类"""
|
||||
@@ -274,8 +274,6 @@ class NormalChatConfig(ConfigBase):
|
||||
"""@bot 必然回复"""
|
||||
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExpressionConfig(ConfigBase):
|
||||
"""表达配置类"""
|
||||
|
||||
@@ -63,10 +63,10 @@ class Individuality:
|
||||
personality_side: 人格侧面描述
|
||||
identity: 身份细节描述
|
||||
"""
|
||||
bot_nickname=global_config.bot.nickname
|
||||
personality_core=global_config.personality.personality_core
|
||||
personality_side=global_config.personality.personality_side
|
||||
identity=global_config.personality.identity
|
||||
bot_nickname = global_config.bot.nickname
|
||||
personality_core = global_config.personality.personality_core
|
||||
personality_side = global_config.personality.personality_side
|
||||
identity = global_config.personality.identity
|
||||
|
||||
logger.info("正在初始化个体特征")
|
||||
person_info_manager = get_person_info_manager()
|
||||
@@ -168,7 +168,6 @@ class Individuality:
|
||||
else:
|
||||
logger.error("人设构建失败")
|
||||
|
||||
|
||||
async def get_personality_block(self) -> str:
|
||||
person_info_manager = get_person_info_manager()
|
||||
bot_person_id = person_info_manager.get_person_id("system", "bot_id")
|
||||
@@ -200,7 +199,6 @@ class Individuality:
|
||||
|
||||
return identity_block
|
||||
|
||||
|
||||
def _get_config_hash(
|
||||
self, bot_nickname: str, personality_core: str, personality_side: str, identity: list
|
||||
) -> tuple[str, str]:
|
||||
@@ -295,7 +293,6 @@ class Individuality:
|
||||
except IOError as e:
|
||||
logger.error(f"保存meta_info文件失败: {e}")
|
||||
|
||||
|
||||
async def _create_personality(self, personality_core: str, personality_side: str) -> str:
|
||||
# sourcery skip: merge-list-append, move-assign
|
||||
"""使用LLM创建压缩版本的impression
|
||||
|
||||
@@ -42,7 +42,15 @@ class Personality:
|
||||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
def initialize(cls, bot_nickname: str, personality_core: str, personality_side: str, identity: List[str] = None, compress_personality: bool = True, compress_identity: bool = True) -> "Personality":
|
||||
def initialize(
|
||||
cls,
|
||||
bot_nickname: str,
|
||||
personality_core: str,
|
||||
personality_side: str,
|
||||
identity: List[str] = None,
|
||||
compress_personality: bool = True,
|
||||
compress_identity: bool = True,
|
||||
) -> "Personality":
|
||||
"""初始化人格特质
|
||||
|
||||
Args:
|
||||
|
||||
@@ -30,7 +30,7 @@ class ContextMessage:
|
||||
"user_id": self.user_id,
|
||||
"content": self.content,
|
||||
"timestamp": self.timestamp.strftime("%m-%d %H:%M:%S"),
|
||||
"group_name": self.group_name
|
||||
"group_name": self.group_name,
|
||||
}
|
||||
|
||||
|
||||
@@ -66,20 +66,20 @@ class ContextWebManager:
|
||||
self.app = web.Application()
|
||||
|
||||
# 设置CORS
|
||||
cors = aiohttp_cors.setup(self.app, defaults={
|
||||
cors = aiohttp_cors.setup(
|
||||
self.app,
|
||||
defaults={
|
||||
"*": aiohttp_cors.ResourceOptions(
|
||||
allow_credentials=True,
|
||||
expose_headers="*",
|
||||
allow_headers="*",
|
||||
allow_methods="*"
|
||||
allow_credentials=True, expose_headers="*", allow_headers="*", allow_methods="*"
|
||||
)
|
||||
},
|
||||
)
|
||||
})
|
||||
|
||||
# 添加路由
|
||||
self.app.router.add_get('/', self.index_handler)
|
||||
self.app.router.add_get('/ws', self.websocket_handler)
|
||||
self.app.router.add_get('/api/contexts', self.get_contexts_handler)
|
||||
self.app.router.add_get('/debug', self.debug_handler)
|
||||
self.app.router.add_get("/", self.index_handler)
|
||||
self.app.router.add_get("/ws", self.websocket_handler)
|
||||
self.app.router.add_get("/api/contexts", self.get_contexts_handler)
|
||||
self.app.router.add_get("/debug", self.debug_handler)
|
||||
|
||||
# 为所有路由添加CORS
|
||||
for route in list(self.app.router.routes()):
|
||||
@@ -88,7 +88,7 @@ class ContextWebManager:
|
||||
self.runner = web.AppRunner(self.app)
|
||||
await self.runner.setup()
|
||||
|
||||
self.site = web.TCPSite(self.runner, 'localhost', self.port)
|
||||
self.site = web.TCPSite(self.runner, "localhost", self.port)
|
||||
await self.site.start()
|
||||
|
||||
logger.info(f"🌐 上下文网页服务器启动成功在 http://localhost:{self.port}")
|
||||
@@ -118,7 +118,8 @@ class ContextWebManager:
|
||||
|
||||
async def index_handler(self, request):
|
||||
"""主页处理器"""
|
||||
html_content = '''
|
||||
html_content = (
|
||||
"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
@@ -231,7 +232,9 @@ class ContextWebManager:
|
||||
|
||||
function connectWebSocket() {
|
||||
console.log('正在连接WebSocket...');
|
||||
ws = new WebSocket('ws://localhost:''' + str(self.port) + '''/ws');
|
||||
ws = new WebSocket('ws://localhost:"""
|
||||
+ str(self.port)
|
||||
+ """/ws');
|
||||
|
||||
ws.onopen = function() {
|
||||
console.log('WebSocket连接已建立');
|
||||
@@ -402,8 +405,9 @@ class ContextWebManager:
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
'''
|
||||
return web.Response(text=html_content, content_type='text/html')
|
||||
"""
|
||||
)
|
||||
return web.Response(text=html_content, content_type="text/html")
|
||||
|
||||
async def websocket_handler(self, request):
|
||||
"""WebSocket处理器"""
|
||||
@@ -418,7 +422,7 @@ class ContextWebManager:
|
||||
|
||||
async for msg in ws:
|
||||
if msg.type == WSMsgType.ERROR:
|
||||
logger.error(f'WebSocket错误: {ws.exception()}')
|
||||
logger.error(f"WebSocket错误: {ws.exception()}")
|
||||
break
|
||||
|
||||
# 清理断开的连接
|
||||
@@ -438,7 +442,7 @@ class ContextWebManager:
|
||||
all_context_msgs.sort(key=lambda x: x.timestamp)
|
||||
|
||||
# 转换为字典格式
|
||||
contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages:]]
|
||||
contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages :]]
|
||||
|
||||
logger.debug(f"返回上下文数据,共 {len(contexts_data)} 条消息")
|
||||
return web.json_response({"contexts": contexts_data})
|
||||
@@ -461,14 +465,14 @@ class ContextWebManager:
|
||||
content = msg.content[:50] + "..." if len(msg.content) > 50 else msg.content
|
||||
messages_html += f'<div class="message">[{timestamp}] {msg.user_name}: {content}</div>'
|
||||
|
||||
chats_html += f'''
|
||||
chats_html += f"""
|
||||
<div class="chat">
|
||||
<h3>聊天 {chat_id} ({len(contexts)} 条消息)</h3>
|
||||
{messages_html}
|
||||
</div>
|
||||
'''
|
||||
"""
|
||||
|
||||
html_content = f'''
|
||||
html_content = f"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
@@ -510,9 +514,9 @@ class ContextWebManager:
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
'''
|
||||
"""
|
||||
|
||||
return web.Response(text=html_content, content_type='text/html')
|
||||
return web.Response(text=html_content, content_type="text/html")
|
||||
|
||||
async def add_message(self, chat_id: str, message: MessageRecv):
|
||||
"""添加新消息到上下文"""
|
||||
@@ -526,14 +530,18 @@ class ContextWebManager:
|
||||
# 统计当前总消息数
|
||||
total_messages = sum(len(contexts) for contexts in self.contexts.values())
|
||||
|
||||
logger.info(f"✅ 添加消息到上下文 [总数: {total_messages}]: [{context_msg.group_name}] {context_msg.user_name}: {context_msg.content}")
|
||||
logger.info(
|
||||
f"✅ 添加消息到上下文 [总数: {total_messages}]: [{context_msg.group_name}] {context_msg.user_name}: {context_msg.content}"
|
||||
)
|
||||
|
||||
# 调试:打印当前所有消息
|
||||
logger.info(f"📝 当前上下文中的所有消息:")
|
||||
for cid, contexts in self.contexts.items():
|
||||
logger.info(f" 聊天 {cid}: {len(contexts)} 条消息")
|
||||
for i, msg in enumerate(contexts):
|
||||
logger.info(f" {i+1}. [{msg.timestamp.strftime('%H:%M:%S')}] {msg.user_name}: {msg.content[:30]}...")
|
||||
logger.info(
|
||||
f" {i + 1}. [{msg.timestamp.strftime('%H:%M:%S')}] {msg.user_name}: {msg.content[:30]}..."
|
||||
)
|
||||
|
||||
# 广播更新给所有WebSocket连接
|
||||
await self.broadcast_contexts()
|
||||
@@ -548,7 +556,7 @@ class ContextWebManager:
|
||||
all_context_msgs.sort(key=lambda x: x.timestamp)
|
||||
|
||||
# 转换为字典格式
|
||||
contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages:]]
|
||||
contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages :]]
|
||||
|
||||
data = {"contexts": contexts_data}
|
||||
await ws.send_str(json.dumps(data, ensure_ascii=False))
|
||||
@@ -567,7 +575,7 @@ class ContextWebManager:
|
||||
all_context_msgs.sort(key=lambda x: x.timestamp)
|
||||
|
||||
# 转换为字典格式
|
||||
contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages:]]
|
||||
contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages :]]
|
||||
|
||||
data = {"contexts": contexts_data}
|
||||
message = json.dumps(data, ensure_ascii=False)
|
||||
@@ -614,4 +622,3 @@ async def init_context_web_manager():
|
||||
manager = get_context_web_manager()
|
||||
await manager.start_server()
|
||||
return manager
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.manager.async_task_manager import AsyncTask, async_task_manager
|
||||
from src.plugin_system.apis import send_api
|
||||
|
||||
|
||||
async def send_loading(chat_id: str, content: str):
|
||||
await send_api.custom_to_stream(
|
||||
message_type="loading",
|
||||
@@ -20,6 +21,7 @@ async def send_loading(chat_id: str, content: str):
|
||||
show_log=True,
|
||||
)
|
||||
|
||||
|
||||
async def send_unloading(chat_id: str):
|
||||
await send_api.custom_to_stream(
|
||||
message_type="loading",
|
||||
@@ -28,4 +30,3 @@ async def send_unloading(chat_id: str):
|
||||
storage_message=False,
|
||||
show_log=True,
|
||||
)
|
||||
|
||||
@@ -30,7 +30,6 @@ class MessageSenderContainer:
|
||||
self._paused_event = asyncio.Event()
|
||||
self._paused_event.set() # 默认设置为非暂停状态
|
||||
|
||||
|
||||
async def add_message(self, chunk: str):
|
||||
"""向队列中添加一个消息块。"""
|
||||
await self.queue.put(chunk)
|
||||
@@ -302,7 +301,9 @@ class S4UChat:
|
||||
self._normal_queue.put_nowait(item)
|
||||
|
||||
if removed_count > 0:
|
||||
logger.info(f"[{self.stream_name}] Cleaned up {removed_count} old normal messages outside recent {self.recent_message_keep_count} range.")
|
||||
logger.info(
|
||||
f"[{self.stream_name}] Cleaned up {removed_count} old normal messages outside recent {self.recent_message_keep_count} range."
|
||||
)
|
||||
|
||||
async def _message_processor(self):
|
||||
"""调度器:优先处理VIP队列,然后处理普通队列。"""
|
||||
@@ -396,12 +397,10 @@ class S4UChat:
|
||||
# a. 发送文本块
|
||||
await sender_container.add_message(chunk)
|
||||
|
||||
|
||||
# 等待所有文本消息发送完成
|
||||
await sender_container.close()
|
||||
await sender_container.join()
|
||||
|
||||
|
||||
logger.info(f"[{self.stream_name}] 所有文本块处理完毕。")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
|
||||
@@ -322,7 +322,7 @@ class ChatMood:
|
||||
"joy": mood_values.get("joy", 5),
|
||||
"anger": mood_values.get("anger", 1),
|
||||
"sorrow": mood_values.get("sorrow", 1),
|
||||
"fear": mood_values.get("fear", 1)
|
||||
"fear": mood_values.get("fear", 1),
|
||||
}
|
||||
|
||||
await send_api.custom_to_stream(
|
||||
@@ -379,14 +379,18 @@ class MoodRegressionTask(AsyncTask):
|
||||
logger.debug(f"[回归任务] {chat_info} 已达到最大回归次数(3次),停止回归")
|
||||
continue
|
||||
|
||||
logger.info(f"[回归任务] {chat_info} 开始情绪回归 ({regress_reason},第{mood.regression_count + 1}次回归)")
|
||||
logger.info(
|
||||
f"[回归任务] {chat_info} 开始情绪回归 ({regress_reason},第{mood.regression_count + 1}次回归)"
|
||||
)
|
||||
await mood.regress_mood()
|
||||
regression_executed += 1
|
||||
else:
|
||||
if has_extreme_emotion:
|
||||
remaining_time = 5 - time_since_last_change
|
||||
high_emotion_str = ", ".join([f"{k}={v}" for k, v in high_emotions.items()])
|
||||
logger.debug(f"[回归任务] {chat_info} 存在极端情绪({high_emotion_str}),距离快速回归还需等待{int(remaining_time)}秒")
|
||||
logger.debug(
|
||||
f"[回归任务] {chat_info} 存在极端情绪({high_emotion_str}),距离快速回归还需等待{int(remaining_time)}秒"
|
||||
)
|
||||
else:
|
||||
remaining_time = 120 - time_since_last_change
|
||||
logger.debug(f"[回归任务] {chat_info} 距离回归还需等待{int(remaining_time)}秒")
|
||||
|
||||
@@ -107,7 +107,6 @@ class S4UStreamGenerator:
|
||||
model_name: str,
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
|
||||
buffer = ""
|
||||
delimiters = ",。!?,.!?\n\r" # For final trimming
|
||||
punctuation_buffer = ""
|
||||
|
||||
@@ -43,6 +43,7 @@ logger = get_logger("watching")
|
||||
|
||||
class WatchingState(Enum):
|
||||
"""视线状态枚举"""
|
||||
|
||||
WANDERING = "wandering" # 随意看
|
||||
DANMU = "danmu" # 看弹幕
|
||||
LENS = "lens" # 看镜头
|
||||
@@ -109,20 +110,14 @@ class ChatWatching:
|
||||
await asyncio.sleep(self.danmu_viewing_duration)
|
||||
|
||||
# 检查是否仍需要切换(可能状态已经被其他事件改变)
|
||||
if (self.reply_finished_time is not None and
|
||||
self.current_state == WatchingState.DANMU and
|
||||
not self.is_replying):
|
||||
|
||||
if self.reply_finished_time is not None and self.current_state == WatchingState.DANMU and not self.is_replying:
|
||||
await self._change_state(WatchingState.LENS, "看弹幕时间结束")
|
||||
self.reply_finished_time = None # 重置完成时间
|
||||
|
||||
async def _send_watching_update(self):
|
||||
"""立即发送视线状态更新"""
|
||||
await send_api.custom_to_stream(
|
||||
message_type="watching",
|
||||
content=self.current_state.value,
|
||||
stream_id=self.chat_id,
|
||||
storage_message=False
|
||||
message_type="watching", content=self.current_state.value, stream_id=self.chat_id, storage_message=False
|
||||
)
|
||||
|
||||
logger.info(f"[{self.chat_id}] 发送视线状态更新: {self.current_state.value}")
|
||||
@@ -139,11 +134,10 @@ class ChatWatching:
|
||||
"current_state": self.current_state.value,
|
||||
"is_replying": self.is_replying,
|
||||
"reply_finished_time": self.reply_finished_time,
|
||||
"state_needs_update": self.state_needs_update
|
||||
"state_needs_update": self.state_needs_update,
|
||||
}
|
||||
|
||||
|
||||
|
||||
class WatchingManager:
|
||||
def __init__(self):
|
||||
self.watching_list: list[ChatWatching] = []
|
||||
@@ -200,10 +194,7 @@ class WatchingManager:
|
||||
|
||||
def get_all_watching_info(self) -> dict:
|
||||
"""获取所有聊天的视线状态信息(用于调试)"""
|
||||
return {
|
||||
watching.chat_id: watching.get_state_info()
|
||||
for watching in self.watching_list
|
||||
}
|
||||
return {watching.chat_id: watching.get_state_info() for watching in self.watching_list}
|
||||
|
||||
|
||||
# 全局视线管理器实例
|
||||
|
||||
@@ -92,7 +92,7 @@ class ChatMood:
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_change_time,
|
||||
timestamp_end=message_time,
|
||||
limit=int(global_config.chat.max_context_size/3),
|
||||
limit=int(global_config.chat.max_context_size / 3),
|
||||
limit_mode="last",
|
||||
)
|
||||
chat_talking_prompt = build_readable_messages(
|
||||
@@ -121,8 +121,6 @@ class ChatMood:
|
||||
mood_state=self.mood_state,
|
||||
)
|
||||
|
||||
|
||||
|
||||
response, (reasoning_content, model_name) = await self.mood_model.generate_response_async(prompt=prompt)
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"{self.log_prefix} prompt: {prompt}")
|
||||
@@ -170,7 +168,6 @@ class ChatMood:
|
||||
mood_state=self.mood_state,
|
||||
)
|
||||
|
||||
|
||||
response, (reasoning_content, model_name) = await self.mood_model.generate_response_async(prompt=prompt)
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
|
||||
@@ -39,7 +39,12 @@ class ChatManager:
|
||||
|
||||
Returns:
|
||||
List[ChatStream]: 聊天流列表
|
||||
|
||||
Raises:
|
||||
TypeError: 如果 platform 不是字符串或 SpecialTypes 枚举类型
|
||||
"""
|
||||
if not isinstance(platform, (str, SpecialTypes)):
|
||||
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
||||
streams = []
|
||||
try:
|
||||
for _, stream in get_chat_manager().streams.items():
|
||||
@@ -60,6 +65,8 @@ class ChatManager:
|
||||
Returns:
|
||||
List[ChatStream]: 群聊聊天流列表
|
||||
"""
|
||||
if not isinstance(platform, (str, SpecialTypes)):
|
||||
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
||||
streams = []
|
||||
try:
|
||||
for _, stream in get_chat_manager().streams.items():
|
||||
@@ -79,7 +86,12 @@ class ChatManager:
|
||||
|
||||
Returns:
|
||||
List[ChatStream]: 私聊聊天流列表
|
||||
|
||||
Raises:
|
||||
TypeError: 如果 platform 不是字符串或 SpecialTypes 枚举类型
|
||||
"""
|
||||
if not isinstance(platform, (str, SpecialTypes)):
|
||||
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
||||
streams = []
|
||||
try:
|
||||
for _, stream in get_chat_manager().streams.items():
|
||||
@@ -102,7 +114,17 @@ class ChatManager:
|
||||
|
||||
Returns:
|
||||
Optional[ChatStream]: 聊天流对象,如果未找到返回None
|
||||
|
||||
Raises:
|
||||
ValueError: 如果 group_id 为空字符串
|
||||
TypeError: 如果 group_id 不是字符串类型或 platform 不是字符串或 SpecialTypes
|
||||
"""
|
||||
if not isinstance(group_id, str):
|
||||
raise TypeError("group_id 必须是字符串类型")
|
||||
if not isinstance(platform, (str, SpecialTypes)):
|
||||
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
||||
if not group_id:
|
||||
raise ValueError("group_id 不能为空")
|
||||
try:
|
||||
for _, stream in get_chat_manager().streams.items():
|
||||
if (
|
||||
@@ -129,7 +151,17 @@ class ChatManager:
|
||||
|
||||
Returns:
|
||||
Optional[ChatStream]: 聊天流对象,如果未找到返回None
|
||||
|
||||
Raises:
|
||||
ValueError: 如果 user_id 为空字符串
|
||||
TypeError: 如果 user_id 不是字符串类型或 platform 不是字符串或 SpecialTypes
|
||||
"""
|
||||
if not isinstance(user_id, str):
|
||||
raise TypeError("user_id 必须是字符串类型")
|
||||
if not isinstance(platform, (str, SpecialTypes)):
|
||||
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
||||
if not user_id:
|
||||
raise ValueError("user_id 不能为空")
|
||||
try:
|
||||
for _, stream in get_chat_manager().streams.items():
|
||||
if (
|
||||
@@ -153,9 +185,15 @@ class ChatManager:
|
||||
|
||||
Returns:
|
||||
str: 聊天类型 ("group", "private", "unknown")
|
||||
|
||||
Raises:
|
||||
TypeError: 如果 chat_stream 不是 ChatStream 类型
|
||||
ValueError: 如果 chat_stream 为空
|
||||
"""
|
||||
if not isinstance(chat_stream, ChatStream):
|
||||
raise TypeError("chat_stream 必须是 ChatStream 类型")
|
||||
if not chat_stream:
|
||||
raise ValueError("chat_stream cannot be None")
|
||||
raise ValueError("chat_stream 不能为 None")
|
||||
|
||||
if hasattr(chat_stream, "group_info"):
|
||||
return "group" if chat_stream.group_info else "private"
|
||||
@@ -170,9 +208,15 @@ class ChatManager:
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 聊天流信息字典
|
||||
|
||||
Raises:
|
||||
TypeError: 如果 chat_stream 不是 ChatStream 类型
|
||||
ValueError: 如果 chat_stream 为空
|
||||
"""
|
||||
if not chat_stream:
|
||||
return {}
|
||||
raise ValueError("chat_stream 不能为 None")
|
||||
if not isinstance(chat_stream, ChatStream):
|
||||
raise TypeError("chat_stream 必须是 ChatStream 类型")
|
||||
|
||||
try:
|
||||
info: Dict[str, Any] = {
|
||||
|
||||
@@ -8,6 +8,8 @@
|
||||
count = emoji_api.get_count()
|
||||
"""
|
||||
|
||||
import random
|
||||
|
||||
from typing import Optional, Tuple, List
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
@@ -29,7 +31,15 @@ async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]
|
||||
|
||||
Returns:
|
||||
Optional[Tuple[str, str, str]]: (base64编码, 表情包描述, 匹配的情感标签) 或 None
|
||||
|
||||
Raises:
|
||||
ValueError: 如果描述为空字符串
|
||||
TypeError: 如果描述不是字符串类型
|
||||
"""
|
||||
if not description:
|
||||
raise ValueError("描述不能为空")
|
||||
if not isinstance(description, str):
|
||||
raise TypeError("描述必须是字符串类型")
|
||||
try:
|
||||
logger.debug(f"[EmojiAPI] 根据描述获取表情包: {description}")
|
||||
|
||||
@@ -55,7 +65,7 @@ async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]
|
||||
return None
|
||||
|
||||
|
||||
async def get_random(count: int = 1) -> Optional[List[Tuple[str, str, str]]]:
|
||||
async def get_random(count: Optional[int] = 1) -> Optional[List[Tuple[str, str, str]]]:
|
||||
"""随机获取指定数量的表情包
|
||||
|
||||
Args:
|
||||
@@ -63,8 +73,17 @@ async def get_random(count: int = 1) -> Optional[List[Tuple[str, str, str]]]:
|
||||
|
||||
Returns:
|
||||
Optional[List[Tuple[str, str, str]]]: 包含(base64编码, 表情包描述, 随机情感标签)的元组列表,如果失败则为None
|
||||
|
||||
Raises:
|
||||
TypeError: 如果count不是整数类型
|
||||
ValueError: 如果count为负数
|
||||
"""
|
||||
if count <= 0:
|
||||
if not isinstance(count, int):
|
||||
raise TypeError("count 必须是整数类型")
|
||||
if count < 0:
|
||||
raise ValueError("count 不能为负数")
|
||||
if count == 0:
|
||||
logger.warning("[EmojiAPI] count 为0,返回空列表")
|
||||
return []
|
||||
|
||||
try:
|
||||
@@ -90,8 +109,6 @@ async def get_random(count: int = 1) -> Optional[List[Tuple[str, str, str]]]:
|
||||
count = len(valid_emojis)
|
||||
|
||||
# 随机选择
|
||||
import random
|
||||
|
||||
selected_emojis = random.sample(valid_emojis, count)
|
||||
|
||||
results = []
|
||||
@@ -128,7 +145,15 @@ async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]:
|
||||
|
||||
Returns:
|
||||
Optional[Tuple[str, str, str]]: (base64编码, 表情包描述, 匹配的情感标签) 或 None
|
||||
|
||||
Raises:
|
||||
ValueError: 如果情感标签为空字符串
|
||||
TypeError: 如果情感标签不是字符串类型
|
||||
"""
|
||||
if not emotion:
|
||||
raise ValueError("情感标签不能为空")
|
||||
if not isinstance(emotion, str):
|
||||
raise TypeError("情感标签必须是字符串类型")
|
||||
try:
|
||||
logger.info(f"[EmojiAPI] 根据情感获取表情包: {emotion}")
|
||||
|
||||
@@ -146,8 +171,6 @@ async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]:
|
||||
return None
|
||||
|
||||
# 随机选择匹配的表情包
|
||||
import random
|
||||
|
||||
selected_emoji = random.choice(matching_emojis)
|
||||
emoji_base64 = image_path_to_base64(selected_emoji.full_path)
|
||||
|
||||
@@ -185,11 +208,11 @@ def get_count() -> int:
|
||||
return 0
|
||||
|
||||
|
||||
def get_info() -> dict:
|
||||
def get_info():
|
||||
"""获取表情包系统信息
|
||||
|
||||
Returns:
|
||||
dict: 包含表情包数量、最大数量等信息
|
||||
dict: 包含表情包数量、最大数量、可用数量信息
|
||||
"""
|
||||
try:
|
||||
emoji_manager = get_emoji_manager()
|
||||
@@ -203,7 +226,7 @@ def get_info() -> dict:
|
||||
return {"current_count": 0, "max_count": 0, "available_emojis": 0}
|
||||
|
||||
|
||||
def get_emotions() -> list:
|
||||
def get_emotions() -> List[str]:
|
||||
"""获取所有可用的情感标签
|
||||
|
||||
Returns:
|
||||
@@ -223,7 +246,7 @@ def get_emotions() -> list:
|
||||
return []
|
||||
|
||||
|
||||
def get_descriptions() -> list:
|
||||
def get_descriptions() -> List[str]:
|
||||
"""获取所有表情包描述
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -5,11 +5,12 @@
|
||||
使用方式:
|
||||
from src.plugin_system.apis import generator_api
|
||||
replyer = generator_api.get_replyer(chat_stream)
|
||||
success, reply_set = await generator_api.generate_reply(chat_stream, action_data, reasoning)
|
||||
success, reply_set, _ = await generator_api.generate_reply(chat_stream, action_data, reasoning)
|
||||
"""
|
||||
|
||||
import traceback
|
||||
from typing import Tuple, Any, Dict, List, Optional
|
||||
from rich.traceback import install
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.replyer.default_generator import DefaultReplyer
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
@@ -17,6 +18,8 @@ from src.chat.utils.utils import process_llm_response
|
||||
from src.chat.replyer.replyer_manager import replyer_manager
|
||||
from src.plugin_system.base.component_types import ActionInfo
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("generator_api")
|
||||
|
||||
|
||||
@@ -44,7 +47,12 @@ def get_replyer(
|
||||
|
||||
Returns:
|
||||
Optional[DefaultReplyer]: 回复器对象,如果获取失败则返回None
|
||||
|
||||
Raises:
|
||||
ValueError: chat_stream 和 chat_id 均为空
|
||||
"""
|
||||
if not chat_id and not chat_stream:
|
||||
raise ValueError("chat_stream 和 chat_id 不可均为空")
|
||||
try:
|
||||
logger.debug(f"[GeneratorAPI] 正在获取回复器,chat_id: {chat_id}, chat_stream: {'有' if chat_stream else '无'}")
|
||||
return replyer_manager.get_replyer(
|
||||
|
||||
@@ -14,7 +14,6 @@ from src.config.config import global_config
|
||||
|
||||
logger = get_logger("llm_api")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# LLM模型API函数
|
||||
# =============================================================================
|
||||
@@ -31,8 +30,21 @@ def get_available_models() -> Dict[str, Any]:
|
||||
logger.error("[LLMAPI] 无法获取模型列表:全局配置中未找到 model 配置")
|
||||
return {}
|
||||
|
||||
# 自动获取所有属性并转换为字典形式
|
||||
rets = {}
|
||||
models = global_config.model
|
||||
return models
|
||||
attrs = dir(models)
|
||||
for attr in attrs:
|
||||
if not attr.startswith("__"):
|
||||
try:
|
||||
value = getattr(models, attr)
|
||||
if not callable(value): # 排除方法
|
||||
rets[attr] = value
|
||||
except Exception as e:
|
||||
logger.debug(f"[LLMAPI] 获取属性 {attr} 失败: {e}")
|
||||
continue
|
||||
return rets
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[LLMAPI] 获取可用模型失败: {e}")
|
||||
return {}
|
||||
|
||||
@@ -114,7 +114,11 @@ async def _send_to_target(
|
||||
|
||||
# 发送消息
|
||||
sent_msg = await heart_fc_sender.send_message(
|
||||
bot_message, typing=typing, set_reply=(anchor_message is not None), storage_message=storage_message, show_log=show_log
|
||||
bot_message,
|
||||
typing=typing,
|
||||
set_reply=(anchor_message is not None),
|
||||
storage_message=storage_message,
|
||||
show_log=show_log,
|
||||
)
|
||||
|
||||
if sent_msg:
|
||||
@@ -363,7 +367,9 @@ async def custom_to_stream(
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
return await _send_to_target(message_type, content, stream_id, display_message, typing, reply_to, storage_message, show_log)
|
||||
return await _send_to_target(
|
||||
message_type, content, stream_id, display_message, typing, reply_to, storage_message, show_log
|
||||
)
|
||||
|
||||
|
||||
async def text_to_group(
|
||||
|
||||
Reference in New Issue
Block a user