ruff, typing, api, bug fix

This commit is contained in:
UnCLASPrommer
2025-07-15 16:50:29 +08:00
parent 4ebcf4e056
commit b5fd959fe1
23 changed files with 335 additions and 238 deletions

View File

@@ -20,3 +20,4 @@
- `chat_api.py`中获取流的参数中可以使用一个特殊的枚举类型来获得所有平台的 ChatStream 了。 - `chat_api.py`中获取流的参数中可以使用一个特殊的枚举类型来获得所有平台的 ChatStream 了。
- `config_api.py`中的`get_global_config``get_plugin_config`方法现在支持嵌套访问的配置键名。 - `config_api.py`中的`get_global_config``get_plugin_config`方法现在支持嵌套访问的配置键名。
- `database_api.py`中的`db_query`方法调整了参数顺序以增强参数限制的同时保证了typing正确`db_get`方法增加了`single_result`参数,与`db_query`保持一致。 - `database_api.py`中的`db_query`方法调整了参数顺序以增强参数限制的同时保证了typing正确`db_get`方法增加了`single_result`参数,与`db_query`保持一致。
4. 现在增加了参数类型检查,完善了对应注释

View File

@@ -47,7 +47,7 @@ class MaiEmoji:
self.embedding = [] self.embedding = []
self.hash = "" # 初始为空,在创建实例时会计算 self.hash = "" # 初始为空,在创建实例时会计算
self.description = "" self.description = ""
self.emotion = [] self.emotion: List[str] = []
self.usage_count = 0 self.usage_count = 0
self.last_used_time = time.time() self.last_used_time = time.time()
self.register_time = time.time() self.register_time = time.time()

View File

@@ -243,6 +243,8 @@ class HeartFChatting:
loop_start_time = time.time() loop_start_time = time.time()
await self.relationship_builder.build_relation() await self.relationship_builder.build_relation()
available_actions = {}
# 第一步:动作修改 # 第一步:动作修改
with Timer("动作修改", cycle_timers): with Timer("动作修改", cycle_timers):
try: try:

View File

@@ -38,7 +38,9 @@ class HeartFCSender:
def __init__(self): def __init__(self):
self.storage = MessageStorage() self.storage = MessageStorage()
async def send_message(self, message: MessageSending, typing=False, set_reply=False, storage_message=True, show_log=True): async def send_message(
self, message: MessageSending, typing=False, set_reply=False, storage_message=True, show_log=True
):
""" """
处理、发送并存储一条消息。 处理、发送并存储一条消息。

View File

@@ -79,7 +79,9 @@ class ActionPlanner:
self.last_obs_time_mark = 0.0 self.last_obs_time_mark = 0.0
async def plan(self, mode: ChatMode = ChatMode.FOCUS) -> Dict[str, Dict[str, Any] | str]: # sourcery skip: dict-comprehension async def plan(
self, mode: ChatMode = ChatMode.FOCUS
) -> Dict[str, Dict[str, Any] | str]: # sourcery skip: dict-comprehension
""" """
规划器 (Planner): 使用LLM根据上下文决定做出什么动作。 规划器 (Planner): 使用LLM根据上下文决定做出什么动作。
""" """

View File

@@ -508,7 +508,7 @@ class DefaultReplyer:
# 构建背景对话 prompt # 构建背景对话 prompt
background_dialogue_prompt = "" background_dialogue_prompt = ""
if background_dialogue_list: if background_dialogue_list:
latest_25_msgs = background_dialogue_list[-int(global_config.chat.max_context_size*0.6):] latest_25_msgs = background_dialogue_list[-int(global_config.chat.max_context_size * 0.6) :]
background_dialogue_prompt_str = build_readable_messages( background_dialogue_prompt_str = build_readable_messages(
latest_25_msgs, latest_25_msgs,
replace_bot_name=True, replace_bot_name=True,
@@ -521,7 +521,7 @@ class DefaultReplyer:
# 构建核心对话 prompt # 构建核心对话 prompt
core_dialogue_prompt = "" core_dialogue_prompt = ""
if core_dialogue_list: if core_dialogue_list:
core_dialogue_list = core_dialogue_list[-int(global_config.chat.max_context_size*2):] # 限制消息数量 core_dialogue_list = core_dialogue_list[-int(global_config.chat.max_context_size * 2) :] # 限制消息数量
core_dialogue_prompt_str = build_readable_messages( core_dialogue_prompt_str = build_readable_messages(
core_dialogue_list, core_dialogue_list,
@@ -586,7 +586,6 @@ class DefaultReplyer:
limit=global_config.chat.max_context_size * 2, limit=global_config.chat.max_context_size * 2,
) )
message_list_before_now = get_raw_msg_before_timestamp_with_chat( message_list_before_now = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_id, chat_id=chat_id,
timestamp=time.time(), timestamp=time.time(),
@@ -713,8 +712,6 @@ class DefaultReplyer:
# 根据sender通过person_info_manager反向查找person_id再获取user_id # 根据sender通过person_info_manager反向查找person_id再获取user_id
person_id = person_info_manager.get_person_id_by_person_name(sender) person_id = person_info_manager.get_person_id_by_person_name(sender)
# 根据配置选择使用哪种 prompt 构建模式 # 根据配置选择使用哪种 prompt 构建模式
if global_config.chat.use_s4u_prompt_mode and person_id: if global_config.chat.use_s4u_prompt_mode and person_id:
# 使用 s4u 对话构建模式:分离当前对话对象和其他对话 # 使用 s4u 对话构建模式:分离当前对话对象和其他对话
@@ -726,7 +723,6 @@ class DefaultReplyer:
logger.warning(f"无法从person_id {person_id} 获取user_id: {e}") logger.warning(f"无法从person_id {person_id} 获取user_id: {e}")
target_user_id = "" target_user_id = ""
# 构建分离的对话 prompt # 构建分离的对话 prompt
core_dialogue_prompt, background_dialogue_prompt = self.build_s4u_chat_history_prompts( core_dialogue_prompt, background_dialogue_prompt = self.build_s4u_chat_history_prompts(
message_list_before_now_long, target_user_id message_list_before_now_long, target_user_id

View File

@@ -106,7 +106,6 @@ class ChatConfig(ConfigBase):
focus_value: float = 1.0 focus_value: float = 1.0
"""麦麦的专注思考能力越低越容易专注消耗token也越多""" """麦麦的专注思考能力越低越容易专注消耗token也越多"""
def get_current_talk_frequency(self, chat_stream_id: Optional[str] = None) -> float: def get_current_talk_frequency(self, chat_stream_id: Optional[str] = None) -> float:
""" """
根据当前时间和聊天流获取对应的 talk_frequency 根据当前时间和聊天流获取对应的 talk_frequency
@@ -246,6 +245,7 @@ class ChatConfig(ConfigBase):
except (ValueError, IndexError): except (ValueError, IndexError):
return None return None
@dataclass @dataclass
class MessageReceiveConfig(ConfigBase): class MessageReceiveConfig(ConfigBase):
"""消息接收配置类""" """消息接收配置类"""
@@ -274,8 +274,6 @@ class NormalChatConfig(ConfigBase):
"""@bot 必然回复""" """@bot 必然回复"""
@dataclass @dataclass
class ExpressionConfig(ConfigBase): class ExpressionConfig(ConfigBase):
"""表达配置类""" """表达配置类"""

View File

@@ -63,10 +63,10 @@ class Individuality:
personality_side: 人格侧面描述 personality_side: 人格侧面描述
identity: 身份细节描述 identity: 身份细节描述
""" """
bot_nickname=global_config.bot.nickname bot_nickname = global_config.bot.nickname
personality_core=global_config.personality.personality_core personality_core = global_config.personality.personality_core
personality_side=global_config.personality.personality_side personality_side = global_config.personality.personality_side
identity=global_config.personality.identity identity = global_config.personality.identity
logger.info("正在初始化个体特征") logger.info("正在初始化个体特征")
person_info_manager = get_person_info_manager() person_info_manager = get_person_info_manager()
@@ -168,7 +168,6 @@ class Individuality:
else: else:
logger.error("人设构建失败") logger.error("人设构建失败")
async def get_personality_block(self) -> str: async def get_personality_block(self) -> str:
person_info_manager = get_person_info_manager() person_info_manager = get_person_info_manager()
bot_person_id = person_info_manager.get_person_id("system", "bot_id") bot_person_id = person_info_manager.get_person_id("system", "bot_id")
@@ -200,7 +199,6 @@ class Individuality:
return identity_block return identity_block
def _get_config_hash( def _get_config_hash(
self, bot_nickname: str, personality_core: str, personality_side: str, identity: list self, bot_nickname: str, personality_core: str, personality_side: str, identity: list
) -> tuple[str, str]: ) -> tuple[str, str]:
@@ -295,7 +293,6 @@ class Individuality:
except IOError as e: except IOError as e:
logger.error(f"保存meta_info文件失败: {e}") logger.error(f"保存meta_info文件失败: {e}")
async def _create_personality(self, personality_core: str, personality_side: str) -> str: async def _create_personality(self, personality_core: str, personality_side: str) -> str:
# sourcery skip: merge-list-append, move-assign # sourcery skip: merge-list-append, move-assign
"""使用LLM创建压缩版本的impression """使用LLM创建压缩版本的impression

View File

@@ -42,7 +42,15 @@ class Personality:
return cls._instance return cls._instance
@classmethod @classmethod
def initialize(cls, bot_nickname: str, personality_core: str, personality_side: str, identity: List[str] = None, compress_personality: bool = True, compress_identity: bool = True) -> "Personality": def initialize(
cls,
bot_nickname: str,
personality_core: str,
personality_side: str,
identity: List[str] = None,
compress_personality: bool = True,
compress_identity: bool = True,
) -> "Personality":
"""初始化人格特质 """初始化人格特质
Args: Args:

View File

@@ -30,7 +30,7 @@ class ContextMessage:
"user_id": self.user_id, "user_id": self.user_id,
"content": self.content, "content": self.content,
"timestamp": self.timestamp.strftime("%m-%d %H:%M:%S"), "timestamp": self.timestamp.strftime("%m-%d %H:%M:%S"),
"group_name": self.group_name "group_name": self.group_name,
} }
@@ -66,20 +66,20 @@ class ContextWebManager:
self.app = web.Application() self.app = web.Application()
# 设置CORS # 设置CORS
cors = aiohttp_cors.setup(self.app, defaults={ cors = aiohttp_cors.setup(
self.app,
defaults={
"*": aiohttp_cors.ResourceOptions( "*": aiohttp_cors.ResourceOptions(
allow_credentials=True, allow_credentials=True, expose_headers="*", allow_headers="*", allow_methods="*"
expose_headers="*", )
allow_headers="*", },
allow_methods="*"
) )
})
# 添加路由 # 添加路由
self.app.router.add_get('/', self.index_handler) self.app.router.add_get("/", self.index_handler)
self.app.router.add_get('/ws', self.websocket_handler) self.app.router.add_get("/ws", self.websocket_handler)
self.app.router.add_get('/api/contexts', self.get_contexts_handler) self.app.router.add_get("/api/contexts", self.get_contexts_handler)
self.app.router.add_get('/debug', self.debug_handler) self.app.router.add_get("/debug", self.debug_handler)
# 为所有路由添加CORS # 为所有路由添加CORS
for route in list(self.app.router.routes()): for route in list(self.app.router.routes()):
@@ -88,7 +88,7 @@ class ContextWebManager:
self.runner = web.AppRunner(self.app) self.runner = web.AppRunner(self.app)
await self.runner.setup() await self.runner.setup()
self.site = web.TCPSite(self.runner, 'localhost', self.port) self.site = web.TCPSite(self.runner, "localhost", self.port)
await self.site.start() await self.site.start()
logger.info(f"🌐 上下文网页服务器启动成功在 http://localhost:{self.port}") logger.info(f"🌐 上下文网页服务器启动成功在 http://localhost:{self.port}")
@@ -118,7 +118,8 @@ class ContextWebManager:
async def index_handler(self, request): async def index_handler(self, request):
"""主页处理器""" """主页处理器"""
html_content = ''' html_content = (
"""
<!DOCTYPE html> <!DOCTYPE html>
<html> <html>
<head> <head>
@@ -231,7 +232,9 @@ class ContextWebManager:
function connectWebSocket() { function connectWebSocket() {
console.log('正在连接WebSocket...'); console.log('正在连接WebSocket...');
ws = new WebSocket('ws://localhost:''' + str(self.port) + '''/ws'); ws = new WebSocket('ws://localhost:"""
+ str(self.port)
+ """/ws');
ws.onopen = function() { ws.onopen = function() {
console.log('WebSocket连接已建立'); console.log('WebSocket连接已建立');
@@ -402,8 +405,9 @@ class ContextWebManager:
</script> </script>
</body> </body>
</html> </html>
''' """
return web.Response(text=html_content, content_type='text/html') )
return web.Response(text=html_content, content_type="text/html")
async def websocket_handler(self, request): async def websocket_handler(self, request):
"""WebSocket处理器""" """WebSocket处理器"""
@@ -418,7 +422,7 @@ class ContextWebManager:
async for msg in ws: async for msg in ws:
if msg.type == WSMsgType.ERROR: if msg.type == WSMsgType.ERROR:
logger.error(f'WebSocket错误: {ws.exception()}') logger.error(f"WebSocket错误: {ws.exception()}")
break break
# 清理断开的连接 # 清理断开的连接
@@ -438,7 +442,7 @@ class ContextWebManager:
all_context_msgs.sort(key=lambda x: x.timestamp) all_context_msgs.sort(key=lambda x: x.timestamp)
# 转换为字典格式 # 转换为字典格式
contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages:]] contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages :]]
logger.debug(f"返回上下文数据,共 {len(contexts_data)} 条消息") logger.debug(f"返回上下文数据,共 {len(contexts_data)} 条消息")
return web.json_response({"contexts": contexts_data}) return web.json_response({"contexts": contexts_data})
@@ -461,14 +465,14 @@ class ContextWebManager:
content = msg.content[:50] + "..." if len(msg.content) > 50 else msg.content content = msg.content[:50] + "..." if len(msg.content) > 50 else msg.content
messages_html += f'<div class="message">[{timestamp}] {msg.user_name}: {content}</div>' messages_html += f'<div class="message">[{timestamp}] {msg.user_name}: {content}</div>'
chats_html += f''' chats_html += f"""
<div class="chat"> <div class="chat">
<h3>聊天 {chat_id} ({len(contexts)} 条消息)</h3> <h3>聊天 {chat_id} ({len(contexts)} 条消息)</h3>
{messages_html} {messages_html}
</div> </div>
''' """
html_content = f''' html_content = f"""
<!DOCTYPE html> <!DOCTYPE html>
<html> <html>
<head> <head>
@@ -510,9 +514,9 @@ class ContextWebManager:
</script> </script>
</body> </body>
</html> </html>
''' """
return web.Response(text=html_content, content_type='text/html') return web.Response(text=html_content, content_type="text/html")
async def add_message(self, chat_id: str, message: MessageRecv): async def add_message(self, chat_id: str, message: MessageRecv):
"""添加新消息到上下文""" """添加新消息到上下文"""
@@ -526,14 +530,18 @@ class ContextWebManager:
# 统计当前总消息数 # 统计当前总消息数
total_messages = sum(len(contexts) for contexts in self.contexts.values()) total_messages = sum(len(contexts) for contexts in self.contexts.values())
logger.info(f"✅ 添加消息到上下文 [总数: {total_messages}]: [{context_msg.group_name}] {context_msg.user_name}: {context_msg.content}") logger.info(
f"✅ 添加消息到上下文 [总数: {total_messages}]: [{context_msg.group_name}] {context_msg.user_name}: {context_msg.content}"
)
# 调试:打印当前所有消息 # 调试:打印当前所有消息
logger.info(f"📝 当前上下文中的所有消息:") logger.info(f"📝 当前上下文中的所有消息:")
for cid, contexts in self.contexts.items(): for cid, contexts in self.contexts.items():
logger.info(f" 聊天 {cid}: {len(contexts)} 条消息") logger.info(f" 聊天 {cid}: {len(contexts)} 条消息")
for i, msg in enumerate(contexts): for i, msg in enumerate(contexts):
logger.info(f" {i+1}. [{msg.timestamp.strftime('%H:%M:%S')}] {msg.user_name}: {msg.content[:30]}...") logger.info(
f" {i + 1}. [{msg.timestamp.strftime('%H:%M:%S')}] {msg.user_name}: {msg.content[:30]}..."
)
# 广播更新给所有WebSocket连接 # 广播更新给所有WebSocket连接
await self.broadcast_contexts() await self.broadcast_contexts()
@@ -548,7 +556,7 @@ class ContextWebManager:
all_context_msgs.sort(key=lambda x: x.timestamp) all_context_msgs.sort(key=lambda x: x.timestamp)
# 转换为字典格式 # 转换为字典格式
contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages:]] contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages :]]
data = {"contexts": contexts_data} data = {"contexts": contexts_data}
await ws.send_str(json.dumps(data, ensure_ascii=False)) await ws.send_str(json.dumps(data, ensure_ascii=False))
@@ -567,7 +575,7 @@ class ContextWebManager:
all_context_msgs.sort(key=lambda x: x.timestamp) all_context_msgs.sort(key=lambda x: x.timestamp)
# 转换为字典格式 # 转换为字典格式
contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages:]] contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages :]]
data = {"contexts": contexts_data} data = {"contexts": contexts_data}
message = json.dumps(data, ensure_ascii=False) message = json.dumps(data, ensure_ascii=False)
@@ -614,4 +622,3 @@ async def init_context_web_manager():
manager = get_context_web_manager() manager = get_context_web_manager()
await manager.start_server() await manager.start_server()
return manager return manager

View File

@@ -11,6 +11,7 @@ from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.manager.async_task_manager import AsyncTask, async_task_manager from src.manager.async_task_manager import AsyncTask, async_task_manager
from src.plugin_system.apis import send_api from src.plugin_system.apis import send_api
async def send_loading(chat_id: str, content: str): async def send_loading(chat_id: str, content: str):
await send_api.custom_to_stream( await send_api.custom_to_stream(
message_type="loading", message_type="loading",
@@ -20,6 +21,7 @@ async def send_loading(chat_id: str, content: str):
show_log=True, show_log=True,
) )
async def send_unloading(chat_id: str): async def send_unloading(chat_id: str):
await send_api.custom_to_stream( await send_api.custom_to_stream(
message_type="loading", message_type="loading",
@@ -28,4 +30,3 @@ async def send_unloading(chat_id: str):
storage_message=False, storage_message=False,
show_log=True, show_log=True,
) )

View File

@@ -30,7 +30,6 @@ class MessageSenderContainer:
self._paused_event = asyncio.Event() self._paused_event = asyncio.Event()
self._paused_event.set() # 默认设置为非暂停状态 self._paused_event.set() # 默认设置为非暂停状态
async def add_message(self, chunk: str): async def add_message(self, chunk: str):
"""向队列中添加一个消息块。""" """向队列中添加一个消息块。"""
await self.queue.put(chunk) await self.queue.put(chunk)
@@ -302,7 +301,9 @@ class S4UChat:
self._normal_queue.put_nowait(item) self._normal_queue.put_nowait(item)
if removed_count > 0: if removed_count > 0:
logger.info(f"[{self.stream_name}] Cleaned up {removed_count} old normal messages outside recent {self.recent_message_keep_count} range.") logger.info(
f"[{self.stream_name}] Cleaned up {removed_count} old normal messages outside recent {self.recent_message_keep_count} range."
)
async def _message_processor(self): async def _message_processor(self):
"""调度器优先处理VIP队列然后处理普通队列。""" """调度器优先处理VIP队列然后处理普通队列。"""
@@ -396,12 +397,10 @@ class S4UChat:
# a. 发送文本块 # a. 发送文本块
await sender_container.add_message(chunk) await sender_container.add_message(chunk)
# 等待所有文本消息发送完成 # 等待所有文本消息发送完成
await sender_container.close() await sender_container.close()
await sender_container.join() await sender_container.join()
logger.info(f"[{self.stream_name}] 所有文本块处理完毕。") logger.info(f"[{self.stream_name}] 所有文本块处理完毕。")
except asyncio.CancelledError: except asyncio.CancelledError:

View File

@@ -322,7 +322,7 @@ class ChatMood:
"joy": mood_values.get("joy", 5), "joy": mood_values.get("joy", 5),
"anger": mood_values.get("anger", 1), "anger": mood_values.get("anger", 1),
"sorrow": mood_values.get("sorrow", 1), "sorrow": mood_values.get("sorrow", 1),
"fear": mood_values.get("fear", 1) "fear": mood_values.get("fear", 1),
} }
await send_api.custom_to_stream( await send_api.custom_to_stream(
@@ -379,14 +379,18 @@ class MoodRegressionTask(AsyncTask):
logger.debug(f"[回归任务] {chat_info} 已达到最大回归次数(3次),停止回归") logger.debug(f"[回归任务] {chat_info} 已达到最大回归次数(3次),停止回归")
continue continue
logger.info(f"[回归任务] {chat_info} 开始情绪回归 ({regress_reason},第{mood.regression_count + 1}次回归)") logger.info(
f"[回归任务] {chat_info} 开始情绪回归 ({regress_reason},第{mood.regression_count + 1}次回归)"
)
await mood.regress_mood() await mood.regress_mood()
regression_executed += 1 regression_executed += 1
else: else:
if has_extreme_emotion: if has_extreme_emotion:
remaining_time = 5 - time_since_last_change remaining_time = 5 - time_since_last_change
high_emotion_str = ", ".join([f"{k}={v}" for k, v in high_emotions.items()]) high_emotion_str = ", ".join([f"{k}={v}" for k, v in high_emotions.items()])
logger.debug(f"[回归任务] {chat_info} 存在极端情绪({high_emotion_str}),距离快速回归还需等待{int(remaining_time)}") logger.debug(
f"[回归任务] {chat_info} 存在极端情绪({high_emotion_str}),距离快速回归还需等待{int(remaining_time)}"
)
else: else:
remaining_time = 120 - time_since_last_change remaining_time = 120 - time_since_last_change
logger.debug(f"[回归任务] {chat_info} 距离回归还需等待{int(remaining_time)}") logger.debug(f"[回归任务] {chat_info} 距离回归还需等待{int(remaining_time)}")

View File

@@ -107,7 +107,6 @@ class S4UStreamGenerator:
model_name: str, model_name: str,
**kwargs, **kwargs,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
buffer = "" buffer = ""
delimiters = ",。!?,.!?\n\r" # For final trimming delimiters = ",。!?,.!?\n\r" # For final trimming
punctuation_buffer = "" punctuation_buffer = ""

View File

@@ -43,6 +43,7 @@ logger = get_logger("watching")
class WatchingState(Enum): class WatchingState(Enum):
"""视线状态枚举""" """视线状态枚举"""
WANDERING = "wandering" # 随意看 WANDERING = "wandering" # 随意看
DANMU = "danmu" # 看弹幕 DANMU = "danmu" # 看弹幕
LENS = "lens" # 看镜头 LENS = "lens" # 看镜头
@@ -109,20 +110,14 @@ class ChatWatching:
await asyncio.sleep(self.danmu_viewing_duration) await asyncio.sleep(self.danmu_viewing_duration)
# 检查是否仍需要切换(可能状态已经被其他事件改变) # 检查是否仍需要切换(可能状态已经被其他事件改变)
if (self.reply_finished_time is not None and if self.reply_finished_time is not None and self.current_state == WatchingState.DANMU and not self.is_replying:
self.current_state == WatchingState.DANMU and
not self.is_replying):
await self._change_state(WatchingState.LENS, "看弹幕时间结束") await self._change_state(WatchingState.LENS, "看弹幕时间结束")
self.reply_finished_time = None # 重置完成时间 self.reply_finished_time = None # 重置完成时间
async def _send_watching_update(self): async def _send_watching_update(self):
"""立即发送视线状态更新""" """立即发送视线状态更新"""
await send_api.custom_to_stream( await send_api.custom_to_stream(
message_type="watching", message_type="watching", content=self.current_state.value, stream_id=self.chat_id, storage_message=False
content=self.current_state.value,
stream_id=self.chat_id,
storage_message=False
) )
logger.info(f"[{self.chat_id}] 发送视线状态更新: {self.current_state.value}") logger.info(f"[{self.chat_id}] 发送视线状态更新: {self.current_state.value}")
@@ -139,11 +134,10 @@ class ChatWatching:
"current_state": self.current_state.value, "current_state": self.current_state.value,
"is_replying": self.is_replying, "is_replying": self.is_replying,
"reply_finished_time": self.reply_finished_time, "reply_finished_time": self.reply_finished_time,
"state_needs_update": self.state_needs_update "state_needs_update": self.state_needs_update,
} }
class WatchingManager: class WatchingManager:
def __init__(self): def __init__(self):
self.watching_list: list[ChatWatching] = [] self.watching_list: list[ChatWatching] = []
@@ -200,10 +194,7 @@ class WatchingManager:
def get_all_watching_info(self) -> dict: def get_all_watching_info(self) -> dict:
"""获取所有聊天的视线状态信息(用于调试)""" """获取所有聊天的视线状态信息(用于调试)"""
return { return {watching.chat_id: watching.get_state_info() for watching in self.watching_list}
watching.chat_id: watching.get_state_info()
for watching in self.watching_list
}
# 全局视线管理器实例 # 全局视线管理器实例

View File

@@ -92,7 +92,7 @@ class ChatMood:
chat_id=self.chat_id, chat_id=self.chat_id,
timestamp_start=self.last_change_time, timestamp_start=self.last_change_time,
timestamp_end=message_time, timestamp_end=message_time,
limit=int(global_config.chat.max_context_size/3), limit=int(global_config.chat.max_context_size / 3),
limit_mode="last", limit_mode="last",
) )
chat_talking_prompt = build_readable_messages( chat_talking_prompt = build_readable_messages(
@@ -121,8 +121,6 @@ class ChatMood:
mood_state=self.mood_state, mood_state=self.mood_state,
) )
response, (reasoning_content, model_name) = await self.mood_model.generate_response_async(prompt=prompt) response, (reasoning_content, model_name) = await self.mood_model.generate_response_async(prompt=prompt)
if global_config.debug.show_prompt: if global_config.debug.show_prompt:
logger.info(f"{self.log_prefix} prompt: {prompt}") logger.info(f"{self.log_prefix} prompt: {prompt}")
@@ -170,7 +168,6 @@ class ChatMood:
mood_state=self.mood_state, mood_state=self.mood_state,
) )
response, (reasoning_content, model_name) = await self.mood_model.generate_response_async(prompt=prompt) response, (reasoning_content, model_name) = await self.mood_model.generate_response_async(prompt=prompt)
if global_config.debug.show_prompt: if global_config.debug.show_prompt:

View File

@@ -39,7 +39,12 @@ class ChatManager:
Returns: Returns:
List[ChatStream]: 聊天流列表 List[ChatStream]: 聊天流列表
Raises:
TypeError: 如果 platform 不是字符串或 SpecialTypes 枚举类型
""" """
if not isinstance(platform, (str, SpecialTypes)):
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
streams = [] streams = []
try: try:
for _, stream in get_chat_manager().streams.items(): for _, stream in get_chat_manager().streams.items():
@@ -60,6 +65,8 @@ class ChatManager:
Returns: Returns:
List[ChatStream]: 群聊聊天流列表 List[ChatStream]: 群聊聊天流列表
""" """
if not isinstance(platform, (str, SpecialTypes)):
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
streams = [] streams = []
try: try:
for _, stream in get_chat_manager().streams.items(): for _, stream in get_chat_manager().streams.items():
@@ -79,7 +86,12 @@ class ChatManager:
Returns: Returns:
List[ChatStream]: 私聊聊天流列表 List[ChatStream]: 私聊聊天流列表
Raises:
TypeError: 如果 platform 不是字符串或 SpecialTypes 枚举类型
""" """
if not isinstance(platform, (str, SpecialTypes)):
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
streams = [] streams = []
try: try:
for _, stream in get_chat_manager().streams.items(): for _, stream in get_chat_manager().streams.items():
@@ -102,7 +114,17 @@ class ChatManager:
Returns: Returns:
Optional[ChatStream]: 聊天流对象如果未找到返回None Optional[ChatStream]: 聊天流对象如果未找到返回None
Raises:
ValueError: 如果 group_id 为空字符串
TypeError: 如果 group_id 不是字符串类型或 platform 不是字符串或 SpecialTypes
""" """
if not isinstance(group_id, str):
raise TypeError("group_id 必须是字符串类型")
if not isinstance(platform, (str, SpecialTypes)):
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
if not group_id:
raise ValueError("group_id 不能为空")
try: try:
for _, stream in get_chat_manager().streams.items(): for _, stream in get_chat_manager().streams.items():
if ( if (
@@ -129,7 +151,17 @@ class ChatManager:
Returns: Returns:
Optional[ChatStream]: 聊天流对象如果未找到返回None Optional[ChatStream]: 聊天流对象如果未找到返回None
Raises:
ValueError: 如果 user_id 为空字符串
TypeError: 如果 user_id 不是字符串类型或 platform 不是字符串或 SpecialTypes
""" """
if not isinstance(user_id, str):
raise TypeError("user_id 必须是字符串类型")
if not isinstance(platform, (str, SpecialTypes)):
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
if not user_id:
raise ValueError("user_id 不能为空")
try: try:
for _, stream in get_chat_manager().streams.items(): for _, stream in get_chat_manager().streams.items():
if ( if (
@@ -153,9 +185,15 @@ class ChatManager:
Returns: Returns:
str: 聊天类型 ("group", "private", "unknown") str: 聊天类型 ("group", "private", "unknown")
Raises:
TypeError: 如果 chat_stream 不是 ChatStream 类型
ValueError: 如果 chat_stream 为空
""" """
if not isinstance(chat_stream, ChatStream):
raise TypeError("chat_stream 必须是 ChatStream 类型")
if not chat_stream: if not chat_stream:
raise ValueError("chat_stream cannot be None") raise ValueError("chat_stream 不能为 None")
if hasattr(chat_stream, "group_info"): if hasattr(chat_stream, "group_info"):
return "group" if chat_stream.group_info else "private" return "group" if chat_stream.group_info else "private"
@@ -170,9 +208,15 @@ class ChatManager:
Returns: Returns:
Dict[str, Any]: 聊天流信息字典 Dict[str, Any]: 聊天流信息字典
Raises:
TypeError: 如果 chat_stream 不是 ChatStream 类型
ValueError: 如果 chat_stream 为空
""" """
if not chat_stream: if not chat_stream:
return {} raise ValueError("chat_stream 不能为 None")
if not isinstance(chat_stream, ChatStream):
raise TypeError("chat_stream 必须是 ChatStream 类型")
try: try:
info: Dict[str, Any] = { info: Dict[str, Any] = {

View File

@@ -8,6 +8,8 @@
count = emoji_api.get_count() count = emoji_api.get_count()
""" """
import random
from typing import Optional, Tuple, List from typing import Optional, Tuple, List
from src.common.logger import get_logger from src.common.logger import get_logger
from src.chat.emoji_system.emoji_manager import get_emoji_manager from src.chat.emoji_system.emoji_manager import get_emoji_manager
@@ -29,7 +31,15 @@ async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]
Returns: Returns:
Optional[Tuple[str, str, str]]: (base64编码, 表情包描述, 匹配的情感标签) 或 None Optional[Tuple[str, str, str]]: (base64编码, 表情包描述, 匹配的情感标签) 或 None
Raises:
ValueError: 如果描述为空字符串
TypeError: 如果描述不是字符串类型
""" """
if not description:
raise ValueError("描述不能为空")
if not isinstance(description, str):
raise TypeError("描述必须是字符串类型")
try: try:
logger.debug(f"[EmojiAPI] 根据描述获取表情包: {description}") logger.debug(f"[EmojiAPI] 根据描述获取表情包: {description}")
@@ -55,7 +65,7 @@ async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]
return None return None
async def get_random(count: int = 1) -> Optional[List[Tuple[str, str, str]]]: async def get_random(count: Optional[int] = 1) -> Optional[List[Tuple[str, str, str]]]:
"""随机获取指定数量的表情包 """随机获取指定数量的表情包
Args: Args:
@@ -63,8 +73,17 @@ async def get_random(count: int = 1) -> Optional[List[Tuple[str, str, str]]]:
Returns: Returns:
Optional[List[Tuple[str, str, str]]]: 包含(base64编码, 表情包描述, 随机情感标签)的元组列表如果失败则为None Optional[List[Tuple[str, str, str]]]: 包含(base64编码, 表情包描述, 随机情感标签)的元组列表如果失败则为None
Raises:
TypeError: 如果count不是整数类型
ValueError: 如果count为负数
""" """
if count <= 0: if not isinstance(count, int):
raise TypeError("count 必须是整数类型")
if count < 0:
raise ValueError("count 不能为负数")
if count == 0:
logger.warning("[EmojiAPI] count 为0返回空列表")
return [] return []
try: try:
@@ -90,8 +109,6 @@ async def get_random(count: int = 1) -> Optional[List[Tuple[str, str, str]]]:
count = len(valid_emojis) count = len(valid_emojis)
# 随机选择 # 随机选择
import random
selected_emojis = random.sample(valid_emojis, count) selected_emojis = random.sample(valid_emojis, count)
results = [] results = []
@@ -128,7 +145,15 @@ async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]:
Returns: Returns:
Optional[Tuple[str, str, str]]: (base64编码, 表情包描述, 匹配的情感标签) 或 None Optional[Tuple[str, str, str]]: (base64编码, 表情包描述, 匹配的情感标签) 或 None
Raises:
ValueError: 如果情感标签为空字符串
TypeError: 如果情感标签不是字符串类型
""" """
if not emotion:
raise ValueError("情感标签不能为空")
if not isinstance(emotion, str):
raise TypeError("情感标签必须是字符串类型")
try: try:
logger.info(f"[EmojiAPI] 根据情感获取表情包: {emotion}") logger.info(f"[EmojiAPI] 根据情感获取表情包: {emotion}")
@@ -146,8 +171,6 @@ async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]:
return None return None
# 随机选择匹配的表情包 # 随机选择匹配的表情包
import random
selected_emoji = random.choice(matching_emojis) selected_emoji = random.choice(matching_emojis)
emoji_base64 = image_path_to_base64(selected_emoji.full_path) emoji_base64 = image_path_to_base64(selected_emoji.full_path)
@@ -185,11 +208,11 @@ def get_count() -> int:
return 0 return 0
def get_info() -> dict: def get_info():
"""获取表情包系统信息 """获取表情包系统信息
Returns: Returns:
dict: 包含表情包数量、最大数量信息 dict: 包含表情包数量、最大数量、可用数量信息
""" """
try: try:
emoji_manager = get_emoji_manager() emoji_manager = get_emoji_manager()
@@ -203,7 +226,7 @@ def get_info() -> dict:
return {"current_count": 0, "max_count": 0, "available_emojis": 0} return {"current_count": 0, "max_count": 0, "available_emojis": 0}
def get_emotions() -> list: def get_emotions() -> List[str]:
"""获取所有可用的情感标签 """获取所有可用的情感标签
Returns: Returns:
@@ -223,7 +246,7 @@ def get_emotions() -> list:
return [] return []
def get_descriptions() -> list: def get_descriptions() -> List[str]:
"""获取所有表情包描述 """获取所有表情包描述
Returns: Returns:

View File

@@ -5,11 +5,12 @@
使用方式: 使用方式:
from src.plugin_system.apis import generator_api from src.plugin_system.apis import generator_api
replyer = generator_api.get_replyer(chat_stream) replyer = generator_api.get_replyer(chat_stream)
success, reply_set = await generator_api.generate_reply(chat_stream, action_data, reasoning) success, reply_set, _ = await generator_api.generate_reply(chat_stream, action_data, reasoning)
""" """
import traceback import traceback
from typing import Tuple, Any, Dict, List, Optional from typing import Tuple, Any, Dict, List, Optional
from rich.traceback import install
from src.common.logger import get_logger from src.common.logger import get_logger
from src.chat.replyer.default_generator import DefaultReplyer from src.chat.replyer.default_generator import DefaultReplyer
from src.chat.message_receive.chat_stream import ChatStream from src.chat.message_receive.chat_stream import ChatStream
@@ -17,6 +18,8 @@ from src.chat.utils.utils import process_llm_response
from src.chat.replyer.replyer_manager import replyer_manager from src.chat.replyer.replyer_manager import replyer_manager
from src.plugin_system.base.component_types import ActionInfo from src.plugin_system.base.component_types import ActionInfo
install(extra_lines=3)
logger = get_logger("generator_api") logger = get_logger("generator_api")
@@ -44,7 +47,12 @@ def get_replyer(
Returns: Returns:
Optional[DefaultReplyer]: 回复器对象如果获取失败则返回None Optional[DefaultReplyer]: 回复器对象如果获取失败则返回None
Raises:
ValueError: chat_stream 和 chat_id 均为空
""" """
if not chat_id and not chat_stream:
raise ValueError("chat_stream 和 chat_id 不可均为空")
try: try:
logger.debug(f"[GeneratorAPI] 正在获取回复器chat_id: {chat_id}, chat_stream: {'' if chat_stream else ''}") logger.debug(f"[GeneratorAPI] 正在获取回复器chat_id: {chat_id}, chat_stream: {'' if chat_stream else ''}")
return replyer_manager.get_replyer( return replyer_manager.get_replyer(

View File

@@ -14,7 +14,6 @@ from src.config.config import global_config
logger = get_logger("llm_api") logger = get_logger("llm_api")
# ============================================================================= # =============================================================================
# LLM模型API函数 # LLM模型API函数
# ============================================================================= # =============================================================================
@@ -31,8 +30,21 @@ def get_available_models() -> Dict[str, Any]:
logger.error("[LLMAPI] 无法获取模型列表:全局配置中未找到 model 配置") logger.error("[LLMAPI] 无法获取模型列表:全局配置中未找到 model 配置")
return {} return {}
# 自动获取所有属性并转换为字典形式
rets = {}
models = global_config.model models = global_config.model
return models attrs = dir(models)
for attr in attrs:
if not attr.startswith("__"):
try:
value = getattr(models, attr)
if not callable(value): # 排除方法
rets[attr] = value
except Exception as e:
logger.debug(f"[LLMAPI] 获取属性 {attr} 失败: {e}")
continue
return rets
except Exception as e: except Exception as e:
logger.error(f"[LLMAPI] 获取可用模型失败: {e}") logger.error(f"[LLMAPI] 获取可用模型失败: {e}")
return {} return {}

View File

@@ -114,7 +114,11 @@ async def _send_to_target(
# 发送消息 # 发送消息
sent_msg = await heart_fc_sender.send_message( sent_msg = await heart_fc_sender.send_message(
bot_message, typing=typing, set_reply=(anchor_message is not None), storage_message=storage_message, show_log=show_log bot_message,
typing=typing,
set_reply=(anchor_message is not None),
storage_message=storage_message,
show_log=show_log,
) )
if sent_msg: if sent_msg:
@@ -363,7 +367,9 @@ async def custom_to_stream(
Returns: Returns:
bool: 是否发送成功 bool: 是否发送成功
""" """
return await _send_to_target(message_type, content, stream_id, display_message, typing, reply_to, storage_message, show_log) return await _send_to_target(
message_type, content, stream_id, display_message, typing, reply_to, storage_message, show_log
)
async def text_to_group( async def text_to_group(