🤖 自动格式化代码 [skip ci]

This commit is contained in:
github-actions[bot]
2025-06-19 12:22:36 +00:00
parent 0467f97e7c
commit 7ed3ecb561
26 changed files with 450 additions and 450 deletions

View File

@@ -1183,4 +1183,3 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -370,7 +370,6 @@ class VirtualLogDisplay:
self.text_widget.tag_add(tag_name, f"{start_pos}+{tag_info[1]}c", f"{start_pos}+{tag_info[2]}c") self.text_widget.tag_add(tag_name, f"{start_pos}+{tag_info[1]}c", f"{start_pos}+{tag_info[2]}c")
class AsyncLogLoader: class AsyncLogLoader:
"""异步日志加载器""" """异步日志加载器"""

View File

@@ -99,7 +99,6 @@ class HeartFCSender:
message.build_reply() message.build_reply()
logger.debug(f"[{chat_id}] 选择回复引用消息: {message.processed_plain_text[:20]}...") logger.debug(f"[{chat_id}] 选择回复引用消息: {message.processed_plain_text[:20]}...")
await message.process() await message.process()
if typing: if typing:
@@ -110,7 +109,6 @@ class HeartFCSender:
) )
await asyncio.sleep(typing_time) await asyncio.sleep(typing_time)
sent_msg = await send_message(message) sent_msg = await send_message(message)
if not sent_msg: if not sent_msg:
return False return False

View File

@@ -12,7 +12,9 @@ class BasePlanner(ABC):
self.action_manager = action_manager self.action_manager = action_manager
@abstractmethod @abstractmethod
async def plan(self, all_plan_info: List[InfoBase], running_memorys: List[Dict[str, Any]], loop_start_time: float) -> Dict[str, Any]: async def plan(
self, all_plan_info: List[InfoBase], running_memorys: List[Dict[str, Any]], loop_start_time: float
) -> Dict[str, Any]:
""" """
规划下一步行动 规划下一步行动

View File

@@ -82,7 +82,9 @@ class ActionPlanner(BasePlanner):
request_type="focus.planner", # 用于动作规划 request_type="focus.planner", # 用于动作规划
) )
async def plan(self, all_plan_info: List[InfoBase], running_memorys: List[Dict[str, Any]], loop_start_time: float) -> Dict[str, Any]: async def plan(
self, all_plan_info: List[InfoBase], running_memorys: List[Dict[str, Any]], loop_start_time: float
) -> Dict[str, Any]:
""" """
规划器 (Planner): 使用LLM根据上下文决定做出什么动作。 规划器 (Planner): 使用LLM根据上下文决定做出什么动作。

View File

@@ -166,8 +166,6 @@ class DefaultReplyer:
(已整合原 HeartFCGenerator 的功能) (已整合原 HeartFCGenerator 的功能)
""" """
try: try:
# 3. 构建 Prompt # 3. 构建 Prompt
with Timer("构建Prompt", {}): # 内部计时器,可选保留 with Timer("构建Prompt", {}): # 内部计时器,可选保留
prompt = await self.build_prompt_reply_context( prompt = await self.build_prompt_reply_context(
@@ -206,7 +204,7 @@ class DefaultReplyer:
reply_seg = ("text", str) reply_seg = ("text", str)
reply_set.append(reply_seg) reply_set.append(reply_seg)
return True , reply_set return True, reply_set
except Exception as e: except Exception as e:
logger.error(f"{self.log_prefix}回复生成意外失败: {e}") logger.error(f"{self.log_prefix}回复生成意外失败: {e}")
@@ -221,8 +219,6 @@ class DefaultReplyer:
表达器 (Expressor): 核心逻辑,负责生成回复文本。 表达器 (Expressor): 核心逻辑,负责生成回复文本。
""" """
try: try:
reply_to = reply_data.get("reply_to", "") reply_to = reply_data.get("reply_to", "")
raw_reply = reply_data.get("raw_reply", "") raw_reply = reply_data.get("raw_reply", "")
reason = reply_data.get("reason", "") reason = reply_data.get("reason", "")
@@ -276,10 +272,6 @@ class DefaultReplyer:
traceback.print_exc() traceback.print_exc()
return False, None return False, None
async def build_prompt_reply_context( async def build_prompt_reply_context(
self, self,
reply_data=None, reply_data=None,
@@ -302,7 +294,6 @@ class DefaultReplyer:
sender = parts[0].strip() sender = parts[0].strip()
target = parts[1].strip() target = parts[1].strip()
message_list_before_now = get_raw_msg_before_timestamp_with_chat( message_list_before_now = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_stream.stream_id, chat_id=chat_stream.stream_id,
timestamp=time.time(), timestamp=time.time(),
@@ -379,9 +370,7 @@ class DefaultReplyer:
# logger.debug("开始构建 focus prompt") # logger.debug("开始构建 focus prompt")
if sender: if sender:
reply_target_block = ( reply_target_block = f"现在{sender}说的:{target}。引起了你的注意,你想要在群里发言或者回复这条消息。"
f"现在{sender}说的:{target}。引起了你的注意,你想要在群里发言或者回复这条消息。"
)
elif target: elif target:
reply_target_block = f"现在{target}引起了你的注意,你想要在群里发言或者回复这条消息。" reply_target_block = f"现在{target}引起了你的注意,你想要在群里发言或者回复这条消息。"
else: else:
@@ -436,9 +425,6 @@ class DefaultReplyer:
raw_reply, raw_reply,
reply_to, reply_to,
) -> str: ) -> str:
sender = "" sender = ""
target = "" target = ""
if ":" in reply_to or "" in reply_to: if ":" in reply_to or "" in reply_to:
@@ -608,9 +594,11 @@ class DefaultReplyer:
) )
try: try:
if (bot_message.is_private_message() or if (
bot_message.reply.processed_plain_text != "[System Trigger Context]" or bot_message.is_private_message()
mark_head): or bot_message.reply.processed_plain_text != "[System Trigger Context]"
or mark_head
):
set_reply = False set_reply = False
else: else:
set_reply = True set_reply = True
@@ -621,10 +609,7 @@ class DefaultReplyer:
else: else:
typing = True typing = True
sent_msg = await self.heart_fc_sender.send_message(bot_message, typing=typing, set_reply=set_reply)
sent_msg = await self.heart_fc_sender.send_message(
bot_message, typing=typing, set_reply=set_reply
)
reply_message_ids.append(part_message_id) # 记录我们生成的ID reply_message_ids.append(part_message_id) # 记录我们生成的ID
@@ -652,7 +637,7 @@ class DefaultReplyer:
is_emoji: bool, is_emoji: bool,
thinking_start_time: float, thinking_start_time: float,
display_message: str, display_message: str,
anchor_message: MessageRecv = None anchor_message: MessageRecv = None,
) -> MessageSending: ) -> MessageSending:
"""构建单个发送消息""" """构建单个发送消息"""

View File

@@ -16,6 +16,7 @@ from src.common.logger import get_logger
from src.chat.heart_flow.utils_chat import get_chat_type_and_target_info from src.chat.heart_flow.utils_chat import get_chat_type_and_target_info
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager
from src.person_info.person_info import get_person_info_manager from src.person_info.person_info import get_person_info_manager
logger = get_logger("observation") logger = get_logger("observation")
# 定义提示模板 # 定义提示模板
@@ -71,7 +72,7 @@ class ChattingObservation(Observation):
self.oldest_messages = [] self.oldest_messages = []
self.oldest_messages_str = "" self.oldest_messages_str = ""
self.last_observe_time = datetime.now().timestamp() -1 self.last_observe_time = datetime.now().timestamp() - 1
print(f"last_observe_time: {self.last_observe_time}") print(f"last_observe_time: {self.last_observe_time}")
initial_messages = get_raw_msg_before_timestamp_with_chat(self.chat_id, self.last_observe_time, 10) initial_messages = get_raw_msg_before_timestamp_with_chat(self.chat_id, self.last_observe_time, 10)
self.last_observe_time = initial_messages[-1]["time"] if initial_messages else self.last_observe_time self.last_observe_time = initial_messages[-1]["time"] if initial_messages else self.last_observe_time
@@ -162,7 +163,6 @@ class ChattingObservation(Observation):
find_rec_msg.update_chat_stream(get_chat_manager().get_or_create_stream(self.chat_id)) find_rec_msg.update_chat_stream(get_chat_manager().get_or_create_stream(self.chat_id))
return find_rec_msg return find_rec_msg
async def observe(self): async def observe(self):

View File

@@ -15,7 +15,7 @@ from src.plugin_system.apis import (
message_api, message_api,
person_api, person_api,
send_api, send_api,
utils_api utils_api,
) )
# 导出所有API模块使它们可以通过 apis.xxx 方式访问 # 导出所有API模块使它们可以通过 apis.xxx 方式访问
@@ -29,5 +29,5 @@ __all__ = [
"message_api", "message_api",
"person_api", "person_api",
"send_api", "send_api",
"utils_api" "utils_api",
] ]

View File

@@ -173,16 +173,20 @@ class ChatManager:
} }
if chat_stream.group_info: if chat_stream.group_info:
info.update({ info.update(
{
"group_id": chat_stream.group_info.group_id, "group_id": chat_stream.group_info.group_id,
"group_name": getattr(chat_stream.group_info, "group_name", "未知群聊"), "group_name": getattr(chat_stream.group_info, "group_name", "未知群聊"),
}) }
)
if chat_stream.user_info: if chat_stream.user_info:
info.update({ info.update(
{
"user_id": chat_stream.user_info.user_id, "user_id": chat_stream.user_info.user_id,
"user_name": chat_stream.user_info.user_nickname, "user_name": chat_stream.user_info.user_nickname,
}) }
)
return info return info
except Exception as e: except Exception as e:
@@ -252,6 +256,7 @@ class ChatManager:
# 模块级别的便捷函数 - 类似 requests.get(), requests.post() 的设计 # 模块级别的便捷函数 - 类似 requests.get(), requests.post() 的设计
# ============================================================================= # =============================================================================
def get_all_streams(platform: str = "qq") -> List[ChatStream]: def get_all_streams(platform: str = "qq") -> List[ChatStream]:
"""获取所有聊天流的便捷函数""" """获取所有聊天流的便捷函数"""
return ChatManager.get_all_streams(platform) return ChatManager.get_all_streams(platform)

View File

@@ -19,6 +19,7 @@ logger = get_logger("config_api")
# 配置访问API函数 # 配置访问API函数
# ============================================================================= # =============================================================================
def get_global_config(key: str, default: Any = None) -> Any: def get_global_config(key: str, default: Any = None) -> Any:
""" """
安全地从全局配置中获取一个值。 安全地从全局配置中获取一个值。
@@ -79,6 +80,7 @@ def get_plugin_config(plugin_config: dict, key: str, default: Any = None) -> Any
# 用户信息API函数 # 用户信息API函数
# ============================================================================= # =============================================================================
async def get_user_id_by_person_name(person_name: str) -> tuple[str, str]: async def get_user_id_by_person_name(person_name: str) -> tuple[str, str]:
"""根据用户名获取用户ID """根据用户名获取用户ID

View File

@@ -18,6 +18,7 @@ logger = get_logger("database_api")
# 通用数据库查询API函数 # 通用数据库查询API函数
# ============================================================================= # =============================================================================
async def db_query( async def db_query(
model_class: Type[Model], model_class: Type[Model],
query_type: str = "get", query_type: str = "get",
@@ -202,9 +203,7 @@ async def db_save(
# 如果提供了key_field和key_value尝试更新现有记录 # 如果提供了key_field和key_value尝试更新现有记录
if key_field and key_value is not None: if key_field and key_value is not None:
# 查找现有记录 # 查找现有记录
existing_records = list( existing_records = list(model_class.select().where(getattr(model_class, key_field) == key_value).limit(1))
model_class.select().where(getattr(model_class, key_field) == key_value).limit(1)
)
if existing_records: if existing_records:
# 更新现有记录 # 更新现有记录
@@ -352,25 +351,26 @@ async def store_action_info(
# 从chat_stream获取聊天信息 # 从chat_stream获取聊天信息
if chat_stream: if chat_stream:
record_data.update({ record_data.update(
"chat_id": getattr(chat_stream, 'stream_id', ''), {
"chat_info_stream_id": getattr(chat_stream, 'stream_id', ''), "chat_id": getattr(chat_stream, "stream_id", ""),
"chat_info_platform": getattr(chat_stream, 'platform', ''), "chat_info_stream_id": getattr(chat_stream, "stream_id", ""),
}) "chat_info_platform": getattr(chat_stream, "platform", ""),
}
)
else: else:
# 如果没有chat_stream设置默认值 # 如果没有chat_stream设置默认值
record_data.update({ record_data.update(
{
"chat_id": "", "chat_id": "",
"chat_info_stream_id": "", "chat_info_stream_id": "",
"chat_info_platform": "", "chat_info_platform": "",
}) }
)
# 使用已有的db_save函数保存记录 # 使用已有的db_save函数保存记录
saved_record = await db_save( saved_record = await db_save(
ActionRecords, ActionRecords, data=record_data, key_field="action_id", key_value=record_data["action_id"]
data=record_data,
key_field="action_id",
key_value=record_data["action_id"]
) )
if saved_record: if saved_record:

View File

@@ -20,6 +20,7 @@ logger = get_logger("emoji_api")
# 表情包获取API函数 # 表情包获取API函数
# ============================================================================= # =============================================================================
async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]: async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]:
"""根据描述选择表情包 """根据描述选择表情包
@@ -78,6 +79,7 @@ async def get_random() -> Optional[Tuple[str, str, str]]:
# 随机选择 # 随机选择
import random import random
selected_emoji = random.choice(valid_emojis) selected_emoji = random.choice(valid_emojis)
emoji_base64 = image_path_to_base64(selected_emoji.full_path) emoji_base64 = image_path_to_base64(selected_emoji.full_path)
@@ -125,6 +127,7 @@ async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]:
# 随机选择匹配的表情包 # 随机选择匹配的表情包
import random import random
selected_emoji = random.choice(matching_emojis) selected_emoji = random.choice(matching_emojis)
emoji_base64 = image_path_to_base64(selected_emoji.full_path) emoji_base64 = image_path_to_base64(selected_emoji.full_path)
@@ -147,6 +150,7 @@ async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]:
# 表情包信息查询API函数 # 表情包信息查询API函数
# ============================================================================= # =============================================================================
def get_count() -> int: def get_count() -> int:
"""获取表情包数量 """获取表情包数量

View File

@@ -16,11 +16,11 @@ from src.chat.message_receive.chat_stream import get_chat_manager
logger = get_logger("generator_api") logger = get_logger("generator_api")
# ============================================================================= # =============================================================================
# 回复器获取API函数 # 回复器获取API函数
# ============================================================================= # =============================================================================
def get_replyer(chat_stream=None, platform: str = None, chat_id: str = None, is_group: bool = True) -> DefaultReplyer: def get_replyer(chat_stream=None, platform: str = None, chat_id: str = None, is_group: bool = True) -> DefaultReplyer:
"""获取回复器对象 """获取回复器对象
@@ -71,16 +71,18 @@ def get_replyer(chat_stream=None, platform: str = None, chat_id: str = None, is_
logger.error(f"[GeneratorAPI] 获取回复器失败: {e}") logger.error(f"[GeneratorAPI] 获取回复器失败: {e}")
return None return None
# ============================================================================= # =============================================================================
# 回复生成API函数 # 回复生成API函数
# ============================================================================= # =============================================================================
async def generate_reply( async def generate_reply(
chat_stream=None, chat_stream=None,
action_data: Dict[str, Any] = None, action_data: Dict[str, Any] = None,
platform: str = None, platform: str = None,
chat_id: str = None, chat_id: str = None,
is_group: bool = True is_group: bool = True,
) -> Tuple[bool, List[Tuple[str, Any]]]: ) -> Tuple[bool, List[Tuple[str, Any]]]:
"""生成回复 """生成回复
@@ -123,12 +125,13 @@ async def generate_reply(
logger.error(f"[GeneratorAPI] 生成回复时出错: {e}") logger.error(f"[GeneratorAPI] 生成回复时出错: {e}")
return False, [] return False, []
async def rewrite_reply( async def rewrite_reply(
chat_stream=None, chat_stream=None,
reply_data: Dict[str, Any] = None, reply_data: Dict[str, Any] = None,
platform: str = None, platform: str = None,
chat_id: str = None, chat_id: str = None,
is_group: bool = True is_group: bool = True,
) -> Tuple[bool, List[Tuple[str, Any]]]: ) -> Tuple[bool, List[Tuple[str, Any]]]:
"""重写回复 """重写回复
@@ -166,5 +169,3 @@ async def rewrite_reply(
except Exception as e: except Exception as e:
logger.error(f"[GeneratorAPI] 重写回复时出错: {e}") logger.error(f"[GeneratorAPI] 重写回复时出错: {e}")
return False, [] return False, []

View File

@@ -19,6 +19,7 @@ logger = get_logger("llm_api")
# LLM模型API函数 # LLM模型API函数
# ============================================================================= # =============================================================================
def get_available_models() -> Dict[str, Any]: def get_available_models() -> Dict[str, Any]:
"""获取所有可用的模型配置 """获取所有可用的模型配置

View File

@@ -32,6 +32,7 @@ from src.chat.utils.chat_message_builder import (
# 消息查询API函数 # 消息查询API函数
# ============================================================================= # =============================================================================
def get_messages_by_time( def get_messages_by_time(
start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest" start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest"
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
@@ -179,9 +180,7 @@ def get_messages_before_time_in_chat(chat_id: str, timestamp: float, limit: int
return get_raw_msg_before_timestamp_with_chat(chat_id, timestamp, limit) return get_raw_msg_before_timestamp_with_chat(chat_id, timestamp, limit)
def get_messages_before_time_for_users( def get_messages_before_time_for_users(timestamp: float, person_ids: list, limit: int = 0) -> List[Dict[str, Any]]:
timestamp: float, person_ids: list, limit: int = 0
) -> List[Dict[str, Any]]:
""" """
获取指定用户在指定时间戳之前的消息 获取指定用户在指定时间戳之前的消息
@@ -197,10 +196,7 @@ def get_messages_before_time_for_users(
def get_recent_messages( def get_recent_messages(
chat_id: str, chat_id: str, hours: float = 24.0, limit: int = 100, limit_mode: str = "latest"
hours: float = 24.0,
limit: int = 100,
limit_mode: str = "latest"
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
获取指定聊天中最近一段时间的消息 获取指定聊天中最近一段时间的消息
@@ -223,9 +219,8 @@ def get_recent_messages(
# 消息计数API函数 # 消息计数API函数
# ============================================================================= # =============================================================================
def count_new_messages(
chat_id: str, start_time: float = 0.0, end_time: Optional[float] = None def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional[float] = None) -> int:
) -> int:
""" """
计算指定聊天中从开始时间到结束时间的新消息数量 计算指定聊天中从开始时间到结束时间的新消息数量
@@ -240,9 +235,7 @@ def count_new_messages(
return num_new_messages_since(chat_id, start_time, end_time) return num_new_messages_since(chat_id, start_time, end_time)
def count_new_messages_for_users( def count_new_messages_for_users(chat_id: str, start_time: float, end_time: float, person_ids: list) -> int:
chat_id: str, start_time: float, end_time: float, person_ids: list
) -> int:
""" """
计算指定聊天中指定用户从开始时间到结束时间的新消息数量 计算指定聊天中指定用户从开始时间到结束时间的新消息数量
@@ -262,6 +255,7 @@ def count_new_messages_for_users(
# 消息格式化API函数 # 消息格式化API函数
# ============================================================================= # =============================================================================
def build_readable_messages_to_str( def build_readable_messages_to_str(
messages: List[Dict[str, Any]], messages: List[Dict[str, Any]],
replace_bot_name: bool = True, replace_bot_name: bool = True,
@@ -311,9 +305,7 @@ async def build_readable_messages_with_details(
Returns: Returns:
格式化后的可读字符串和详细信息元组列表(时间戳, 昵称, 内容) 格式化后的可读字符串和详细信息元组列表(时间戳, 昵称, 内容)
""" """
return await build_readable_messages_with_list( return await build_readable_messages_with_list(messages, replace_bot_name, merge_messages, timestamp_mode, truncate)
messages, replace_bot_name, merge_messages, timestamp_mode, truncate
)
async def get_person_ids_from_messages(messages: List[Dict[str, Any]]) -> List[str]: async def get_person_ids_from_messages(messages: List[Dict[str, Any]]) -> List[str]:

View File

@@ -18,6 +18,7 @@ logger = get_logger("person_api")
# 个人信息API函数 # 个人信息API函数
# ============================================================================= # =============================================================================
def get_person_id(platform: str, user_id: int) -> str: def get_person_id(platform: str, user_id: int) -> str:
"""根据平台和用户ID获取person_id """根据平台和用户ID获取person_id

View File

@@ -31,6 +31,7 @@ logger = get_logger("send_api")
# 内部实现函数(不暴露给外部) # 内部实现函数(不暴露给外部)
# ============================================================================= # =============================================================================
async def _send_to_target( async def _send_to_target(
message_type: str, message_type: str,
content: str, content: str,
@@ -147,7 +148,7 @@ async def _find_reply_message(target_stream, reply_to: str) -> Optional[MessageR
reverse_talking_message = get_raw_msg_before_timestamp_with_chat( reverse_talking_message = get_raw_msg_before_timestamp_with_chat(
target_stream.stream_id, target_stream.stream_id,
time.time(), # 当前时间之前的消息 time.time(), # 当前时间之前的消息
20 # 最新的20条消息 20, # 最新的20条消息
) )
# 反转列表,使最新的消息在前面 # 反转列表,使最新的消息在前面
@@ -222,7 +223,15 @@ async def _find_reply_message(target_stream, reply_to: str) -> Optional[MessageR
# 公共API函数 - 预定义类型的发送函数 # 公共API函数 - 预定义类型的发送函数
# ============================================================================= # =============================================================================
async def text_to_group(text: str, group_id: str, platform: str = "qq", typing: bool = False, reply_to: str = "", storage_message: bool = True) -> bool:
async def text_to_group(
text: str,
group_id: str,
platform: str = "qq",
typing: bool = False,
reply_to: str = "",
storage_message: bool = True,
) -> bool:
"""向群聊发送文本消息 """向群聊发送文本消息
Args: Args:
@@ -240,7 +249,14 @@ async def text_to_group(text: str, group_id: str, platform: str = "qq", typing:
return await _send_to_target("text", text, stream_id, "", typing, reply_to, storage_message) return await _send_to_target("text", text, stream_id, "", typing, reply_to, storage_message)
async def text_to_user(text: str, user_id: str, platform: str = "qq", typing: bool = False, reply_to: str = "", storage_message: bool = True) -> bool: async def text_to_user(
text: str,
user_id: str,
platform: str = "qq",
typing: bool = False,
reply_to: str = "",
storage_message: bool = True,
) -> bool:
"""向用户发送私聊文本消息 """向用户发送私聊文本消息
Args: Args:
@@ -316,6 +332,7 @@ async def image_to_user(image_base64: str, user_id: str, platform: str = "qq", s
stream_id = get_chat_manager().get_stream_id(platform, user_id, False) stream_id = get_chat_manager().get_stream_id(platform, user_id, False)
return await _send_to_target("image", image_base64, stream_id, "", typing=False) return await _send_to_target("image", image_base64, stream_id, "", typing=False)
async def command_to_group(command: str, group_id: str, platform: str = "qq", storage_message: bool = True) -> bool: async def command_to_group(command: str, group_id: str, platform: str = "qq", storage_message: bool = True) -> bool:
"""向群聊发送命令 """向群聊发送命令
@@ -330,6 +347,7 @@ async def command_to_group(command: str, group_id: str, platform: str = "qq", st
stream_id = get_chat_manager().get_stream_id(platform, group_id, True) stream_id = get_chat_manager().get_stream_id(platform, group_id, True)
return await _send_to_target("command", command, stream_id, "", typing=False, storage_message=storage_message) return await _send_to_target("command", command, stream_id, "", typing=False, storage_message=storage_message)
async def command_to_user(command: str, user_id: str, platform: str = "qq", storage_message: bool = True) -> bool: async def command_to_user(command: str, user_id: str, platform: str = "qq", storage_message: bool = True) -> bool:
"""向用户发送命令 """向用户发送命令
@@ -349,6 +367,7 @@ async def command_to_user(command: str, user_id: str, platform: str = "qq", stor
# 通用发送函数 - 支持任意消息类型 # 通用发送函数 - 支持任意消息类型
# ============================================================================= # =============================================================================
async def custom_to_group( async def custom_to_group(
message_type: str, message_type: str,
content: str, content: str,
@@ -357,7 +376,7 @@ async def custom_to_group(
display_message: str = "", display_message: str = "",
typing: bool = False, typing: bool = False,
reply_to: str = "", reply_to: str = "",
storage_message: bool = True storage_message: bool = True,
) -> bool: ) -> bool:
"""向群聊发送自定义类型消息 """向群聊发送自定义类型消息
@@ -385,7 +404,7 @@ async def custom_to_user(
display_message: str = "", display_message: str = "",
typing: bool = False, typing: bool = False,
reply_to: str = "", reply_to: str = "",
storage_message: bool = True storage_message: bool = True,
) -> bool: ) -> bool:
"""向用户发送自定义类型消息 """向用户发送自定义类型消息
@@ -414,7 +433,7 @@ async def custom_message(
display_message: str = "", display_message: str = "",
typing: bool = False, typing: bool = False,
reply_to: str = "", reply_to: str = "",
storage_message: bool = True storage_message: bool = True,
) -> bool: ) -> bool:
"""发送自定义消息的通用接口 """发送自定义消息的通用接口

View File

@@ -24,6 +24,7 @@ logger = get_logger("utils_api")
# 文件操作API函数 # 文件操作API函数
# ============================================================================= # =============================================================================
def get_plugin_path(caller_frame=None) -> str: def get_plugin_path(caller_frame=None) -> str:
"""获取调用者插件的路径 """获取调用者插件的路径
@@ -106,6 +107,7 @@ def write_json_file(file_path: str, data: Any, indent: int = 2) -> bool:
# 时间相关API函数 # 时间相关API函数
# ============================================================================= # =============================================================================
def get_timestamp() -> int: def get_timestamp() -> int:
"""获取当前时间戳 """获取当前时间戳
@@ -156,6 +158,7 @@ def parse_time(time_str: str, format_str: str = "%Y-%m-%d %H:%M:%S") -> int:
# 其他工具函数 # 其他工具函数
# ============================================================================= # =============================================================================
def generate_unique_id() -> str: def generate_unique_id() -> str:
"""生成唯一ID """生成唯一ID

View File

@@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
from typing import Tuple, Optional from typing import Tuple, Optional
from src.common.logger import get_logger from src.common.logger import get_logger
from src.plugin_system.base.component_types import ActionActivationType, ChatMode, ActionInfo, ComponentType from src.plugin_system.base.component_types import ActionActivationType, ChatMode, ActionInfo, ComponentType
from src.plugin_system.apis import send_api, database_api,message_api from src.plugin_system.apis import send_api, database_api, message_api
import time import time
import asyncio import asyncio
@@ -84,7 +84,6 @@ class BaseAction(ABC):
# 便捷属性 - 直接在初始化时获取常用聊天信息(带类型注解) # 便捷属性 - 直接在初始化时获取常用聊天信息(带类型注解)
# ============================================================================= # =============================================================================
# 获取聊天流对象 # 获取聊天流对象
self.chat_stream = chat_stream or kwargs.get("chat_stream") self.chat_stream = chat_stream or kwargs.get("chat_stream")
@@ -100,7 +99,7 @@ class BaseAction(ABC):
# 如果有聊天流,提取所有信息 # 如果有聊天流,提取所有信息
if self.chat_stream: if self.chat_stream:
self.platform = getattr(self.chat_stream, 'platform', None) self.platform = getattr(self.chat_stream, "platform", None)
# 获取群聊信息 # 获取群聊信息
# print(self.chat_stream) # print(self.chat_stream)
@@ -108,17 +107,19 @@ class BaseAction(ABC):
if self.chat_stream.group_info: if self.chat_stream.group_info:
self.is_group = True self.is_group = True
self.group_id = str(self.chat_stream.group_info.group_id) self.group_id = str(self.chat_stream.group_info.group_id)
self.group_name = getattr(self.chat_stream.group_info, 'group_name', None) self.group_name = getattr(self.chat_stream.group_info, "group_name", None)
else: else:
self.is_group = False self.is_group = False
self.user_id = str(self.chat_stream.user_info.user_id) self.user_id = str(self.chat_stream.user_info.user_id)
self.user_nickname = getattr(self.chat_stream.user_info, 'user_nickname', None) self.user_nickname = getattr(self.chat_stream.user_info, "user_nickname", None)
# 设置目标ID群聊用群ID私聊用户ID # 设置目标ID群聊用群ID私聊用户ID
self.target_id = self.group_id if self.is_group else self.user_id self.target_id = self.group_id if self.is_group else self.user_id
logger.debug(f"{self.log_prefix} Action组件初始化完成") logger.debug(f"{self.log_prefix} Action组件初始化完成")
logger.debug(f"{self.log_prefix} 聊天信息: 类型={'群聊' if self.is_group else '私聊'}, 平台={self.platform}, 目标={self.target_id}") logger.debug(
f"{self.log_prefix} 聊天信息: 类型={'群聊' if self.is_group else '私聊'}, 平台={self.platform}, 目标={self.target_id}"
)
def _get_activation_type_value(self, attr_name: str, default: str) -> str: def _get_activation_type_value(self, attr_name: str, default: str) -> str:
"""获取激活类型的字符串值""" """获取激活类型的字符串值"""
@@ -138,7 +139,6 @@ class BaseAction(ABC):
return attr.value return attr.value
return str(attr) return str(attr)
async def wait_for_new_message(self, timeout: int = 1200) -> Tuple[bool, str]: async def wait_for_new_message(self, timeout: int = 1200) -> Tuple[bool, str]:
"""等待新消息或超时 """等待新消息或超时
@@ -172,9 +172,7 @@ class BaseAction(ABC):
# 检查新消息 # 检查新消息
current_time = time.time() current_time = time.time()
new_message_count = message_api.count_new_messages( new_message_count = message_api.count_new_messages(
chat_id=self.chat_id, chat_id=self.chat_id, start_time=loop_start_time, end_time=current_time
start_time=loop_start_time,
end_time=current_time
) )
if new_message_count > 0: if new_message_count > 0:
@@ -289,7 +287,7 @@ class BaseAction(ABC):
target_id=self.target_id, target_id=self.target_id,
is_group=self.is_group, is_group=self.is_group,
platform=self.platform, platform=self.platform,
typing=typing typing=typing,
) )
async def store_action_info( async def store_action_info(
@@ -315,7 +313,9 @@ class BaseAction(ABC):
action_name=self.action_name, action_name=self.action_name,
) )
async def send_command(self, command_name: str, args: dict = None, display_message: str = None, storage_message: bool = True) -> bool: async def send_command(
self, command_name: str, args: dict = None, display_message: str = None, storage_message: bool = True
) -> bool:
"""发送命令消息 """发送命令消息
使用和send_text相同的方式通过MessageAPI发送命令 使用和send_text相同的方式通过MessageAPI发送命令
@@ -338,7 +338,7 @@ class BaseAction(ABC):
command=command_data, command=command_data,
group_id=str(self.group_id), group_id=str(self.group_id),
platform=self.platform, platform=self.platform,
storage_message=storage_message storage_message=storage_message,
) )
else: else:
# 私聊 # 私聊
@@ -346,7 +346,7 @@ class BaseAction(ABC):
command=command_data, command=command_data,
user_id=str(self.user_id), user_id=str(self.user_id),
platform=self.platform, platform=self.platform,
storage_message=storage_message storage_message=storage_message,
) )
if success: if success:

View File

@@ -99,17 +99,13 @@ class BaseCommand(ABC):
# 群聊 # 群聊
await send_api.text_to_group( await send_api.text_to_group(
text=content, text=content, group_id=str(chat_stream.group_info.group_id), platform=chat_stream.platform
group_id=str(chat_stream.group_info.group_id),
platform=chat_stream.platform
) )
else: else:
# 私聊 # 私聊
await send_api.text_to_user( await send_api.text_to_user(
text=content, text=content, user_id=str(chat_stream.user_info.user_id), platform=chat_stream.platform
user_id=str(chat_stream.user_info.user_id),
platform=chat_stream.platform
) )
async def send_type( async def send_type(
@@ -131,6 +127,7 @@ class BaseCommand(ABC):
if chat_stream.group_info: if chat_stream.group_info:
# 群聊 # 群聊
from src.plugin_system.apis import send_api from src.plugin_system.apis import send_api
return await send_api.custom_message( return await send_api.custom_message(
message_type=message_type, message_type=message_type,
content=content, content=content,
@@ -142,6 +139,7 @@ class BaseCommand(ABC):
else: else:
# 私聊 # 私聊
from src.plugin_system.apis import send_api from src.plugin_system.apis import send_api
return await send_api.custom_message( return await send_api.custom_message(
message_type=message_type, message_type=message_type,
content=content, content=content,
@@ -172,6 +170,7 @@ class BaseCommand(ABC):
if chat_stream.group_info: if chat_stream.group_info:
# 群聊 # 群聊
from src.plugin_system.apis import send_api from src.plugin_system.apis import send_api
success = await send_api.custom_message( success = await send_api.custom_message(
message_type="command", message_type="command",
content=command_data, content=command_data,
@@ -182,6 +181,7 @@ class BaseCommand(ABC):
else: else:
# 私聊 # 私聊
from src.plugin_system.apis import send_api from src.plugin_system.apis import send_api
success = await send_api.custom_message( success = await send_api.custom_message(
message_type="command", message_type="command",
content=command_data, content=command_data,

View File

@@ -54,7 +54,7 @@ class ComponentRegistry:
""" """
component_name = component_info.name component_name = component_info.name
component_type = component_info.component_type component_type = component_info.component_type
plugin_name = getattr(component_info, 'plugin_name', 'unknown') plugin_name = getattr(component_info, "plugin_name", "unknown")
# 🔥 系统级别自动区分:为不同类型的组件添加命名空间前缀 # 🔥 系统级别自动区分:为不同类型的组件添加命名空间前缀
if component_type == ComponentType.ACTION: if component_type == ComponentType.ACTION:
@@ -68,7 +68,7 @@ class ComponentRegistry:
# 检查命名空间化的名称是否冲突 # 检查命名空间化的名称是否冲突
if namespaced_name in self._components: if namespaced_name in self._components:
existing_info = self._components[namespaced_name] existing_info = self._components[namespaced_name]
existing_plugin = getattr(existing_info, 'plugin_name', 'unknown') existing_plugin = getattr(existing_info, "plugin_name", "unknown")
logger.warning( logger.warning(
f"组件冲突: {component_type.value}组件 '{component_name}' " f"组件冲突: {component_type.value}组件 '{component_name}' "
@@ -125,7 +125,7 @@ class ComponentRegistry:
Optional[ComponentInfo]: 组件信息或None Optional[ComponentInfo]: 组件信息或None
""" """
# 1. 如果已经是命名空间化的名称,直接查找 # 1. 如果已经是命名空间化的名称,直接查找
if '.' in component_name: if "." in component_name:
return self._components.get(component_name) return self._components.get(component_name)
# 2. 如果指定了组件类型,构造命名空间化的名称查找 # 2. 如果指定了组件类型,构造命名空间化的名称查找
@@ -154,8 +154,7 @@ class ComponentRegistry:
# 多个匹配,记录警告并返回第一个 # 多个匹配,记录警告并返回第一个
namespaces = [ns for ns, _, _ in candidates] namespaces = [ns for ns, _, _ in candidates]
logger.warning( logger.warning(
f"组件名称 '{component_name}' 在多个命名空间中存在: {namespaces}" f"组件名称 '{component_name}' 在多个命名空间中存在: {namespaces}使用第一个匹配项: {candidates[0][1]}"
f"使用第一个匹配项: {candidates[0][1]}"
) )
return candidates[0][2] return candidates[0][2]
@@ -173,7 +172,7 @@ class ComponentRegistry:
Optional[Type]: 组件类或None Optional[Type]: 组件类或None
""" """
# 1. 如果已经是命名空间化的名称,直接查找 # 1. 如果已经是命名空间化的名称,直接查找
if '.' in component_name: if "." in component_name:
return self._component_classes.get(component_name) return self._component_classes.get(component_name)
# 2. 如果指定了组件类型,构造命名空间化的名称查找 # 2. 如果指定了组件类型,构造命名空间化的名称查找
@@ -204,8 +203,7 @@ class ComponentRegistry:
# 多个匹配,记录警告并返回第一个 # 多个匹配,记录警告并返回第一个
namespaces = [ns for ns, _, _ in candidates] namespaces = [ns for ns, _, _ in candidates]
logger.warning( logger.warning(
f"组件名称 '{component_name}' 在多个命名空间中存在: {namespaces}" f"组件名称 '{component_name}' 在多个命名空间中存在: {namespaces}使用第一个匹配项: {candidates[0][1]}"
f"使用第一个匹配项: {candidates[0][1]}"
) )
return candidates[0][2] return candidates[0][2]
@@ -262,7 +260,6 @@ class ComponentRegistry:
""" """
for pattern, command_class in self._command_patterns.items(): for pattern, command_class in self._command_patterns.items():
match = pattern.match(text) match = pattern.match(text)
if match: if match:
command_name = None command_name = None
@@ -349,11 +346,15 @@ class ComponentRegistry:
# 根据组件类型构造正确的命名空间化名称 # 根据组件类型构造正确的命名空间化名称
if component_info.component_type == ComponentType.ACTION: if component_info.component_type == ComponentType.ACTION:
namespaced_name = f"action.{component_name}" if '.' not in component_name else component_name namespaced_name = f"action.{component_name}" if "." not in component_name else component_name
elif component_info.component_type == ComponentType.COMMAND: elif component_info.component_type == ComponentType.COMMAND:
namespaced_name = f"command.{component_name}" if '.' not in component_name else component_name namespaced_name = f"command.{component_name}" if "." not in component_name else component_name
else: else:
namespaced_name = f"{component_info.component_type.value}.{component_name}" if '.' not in component_name else component_name namespaced_name = (
f"{component_info.component_type.value}.{component_name}"
if "." not in component_name
else component_name
)
if namespaced_name in self._components: if namespaced_name in self._components:
self._components[namespaced_name].enabled = True self._components[namespaced_name].enabled = True
@@ -373,11 +374,15 @@ class ComponentRegistry:
# 根据组件类型构造正确的命名空间化名称 # 根据组件类型构造正确的命名空间化名称
if component_info.component_type == ComponentType.ACTION: if component_info.component_type == ComponentType.ACTION:
namespaced_name = f"action.{component_name}" if '.' not in component_name else component_name namespaced_name = f"action.{component_name}" if "." not in component_name else component_name
elif component_info.component_type == ComponentType.COMMAND: elif component_info.component_type == ComponentType.COMMAND:
namespaced_name = f"command.{component_name}" if '.' not in component_name else component_name namespaced_name = f"command.{component_name}" if "." not in component_name else component_name
else: else:
namespaced_name = f"{component_info.component_type.value}.{component_name}" if '.' not in component_name else component_name namespaced_name = (
f"{component_info.component_type.value}.{component_name}"
if "." not in component_name
else component_name
)
if namespaced_name in self._components: if namespaced_name in self._components:
self._components[namespaced_name].enabled = False self._components[namespaced_name].enabled = False

View File

@@ -38,9 +38,7 @@ class ReplyAction(BaseAction):
action_description = "参与聊天回复,发送文本进行表达" action_description = "参与聊天回复,发送文本进行表达"
# 动作参数定义 # 动作参数定义
action_parameters = { action_parameters = {"reply_to": "你要回复的对方的发言内容,格式:(用户名:发言内容可以为none"}
"reply_to": "你要回复的对方的发言内容,格式:(用户名:发言内容可以为none"
}
# 动作使用场景 # 动作使用场景
action_require = ["你想要闲聊或者随便附和", "有人提到你", "如果你刚刚进行了回复,不要对同一个话题重复回应"] action_require = ["你想要闲聊或者随便附和", "有人提到你", "如果你刚刚进行了回复,不要对同一个话题重复回应"]
@@ -55,27 +53,26 @@ class ReplyAction(BaseAction):
start_time = self.action_data.get("loop_start_time", time.time()) start_time = self.action_data.get("loop_start_time", time.time())
try: try:
success, reply_set = await generator_api.generate_reply( success, reply_set = await generator_api.generate_reply(
chat_stream=self.chat_stream, chat_stream=self.chat_stream,
action_data=self.action_data, action_data=self.action_data,
platform=self.platform, platform=self.platform,
chat_id=self.chat_id, chat_id=self.chat_id,
is_group=self.is_group is_group=self.is_group,
) )
# 检查从start_time以来的新消息数量 # 检查从start_time以来的新消息数量
# 获取动作触发时间或使用默认值 # 获取动作触发时间或使用默认值
current_time = time.time() current_time = time.time()
new_message_count = message_api.count_new_messages( new_message_count = message_api.count_new_messages(
chat_id=self.chat_id, chat_id=self.chat_id, start_time=start_time, end_time=current_time
start_time=start_time,
end_time=current_time
) )
# 根据新消息数量决定是否使用reply_to # 根据新消息数量决定是否使用reply_to
need_reply = new_message_count >= 4 need_reply = new_message_count >= 4
logger.info(f"{self.log_prefix}{start_time}{current_time}共有{new_message_count}条新消息,{'使用' if need_reply else '不使用'}reply_to") logger.info(
f"{self.log_prefix}{start_time}{current_time}共有{new_message_count}条新消息,{'使用' if need_reply else '不使用'}reply_to"
)
# 构建回复文本 # 构建回复文本
reply_text = "" reply_text = ""
@@ -83,16 +80,12 @@ class ReplyAction(BaseAction):
for reply_seg in reply_set: for reply_seg in reply_set:
data = reply_seg[1] data = reply_seg[1]
if not first_reply and need_reply: if not first_reply and need_reply:
await self.send_text( await self.send_text(content=data, reply_to=self.action_data.get("reply_to", ""))
content=data,
reply_to=self.action_data.get("reply_to", "")
)
else: else:
await self.send_text(content=data) await self.send_text(content=data)
first_reply = True first_reply = True
reply_text += data reply_text += data
# 存储动作记录 # 存储动作记录
await self.store_action_info( await self.store_action_info(
action_build_into_prompt=False, action_build_into_prompt=False,
@@ -110,7 +103,6 @@ class ReplyAction(BaseAction):
return False, f"回复失败: {str(e)}" return False, f"回复失败: {str(e)}"
class NoReplyAction(BaseAction): class NoReplyAction(BaseAction):
"""不回复动作,继承时会等待新消息或超时""" """不回复动作,继承时会等待新消息或超时"""
@@ -430,10 +422,6 @@ class CoreActionsPlugin(BasePlugin):
return components return components
# class DeepReplyAction(BaseAction): # class DeepReplyAction(BaseAction):
# """回复动作 - 参与聊天回复""" # """回复动作 - 参与聊天回复"""
@@ -475,7 +463,6 @@ class CoreActionsPlugin(BasePlugin):
# anchor_message = await create_empty_anchor_message(chat_stream.platform, chat_stream.group_info, chat_stream) # anchor_message = await create_empty_anchor_message(chat_stream.platform, chat_stream.group_info, chat_stream)
# llm_model = self.api.get_available_models().replyer_1 # llm_model = self.api.get_available_models().replyer_1
# prompt = f""" # prompt = f"""
@@ -502,8 +489,6 @@ class CoreActionsPlugin(BasePlugin):
# self.action_data["extra_info_block"] = extra_info_block # self.action_data["extra_info_block"] = extra_info_block
# # 获取回复器服务 # # 获取回复器服务
# # replyer = self.api.get_service("replyer") # # replyer = self.api.get_service("replyer")
# # if not replyer: # # if not replyer:

View File

@@ -26,6 +26,7 @@ from src.plugin_system.base.base_command import BaseCommand
from src.plugin_system.base.component_types import ComponentInfo, ActionActivationType, ChatMode from src.plugin_system.base.component_types import ComponentInfo, ActionActivationType, ChatMode
from src.plugin_system.base.config_types import ConfigField from src.plugin_system.base.config_types import ConfigField
from src.common.logger import get_logger from src.common.logger import get_logger
# 导入配置API可选的简便方法 # 导入配置API可选的简便方法
from src.plugin_system.apis import person_api, generator_api from src.plugin_system.apis import person_api, generator_api
@@ -140,7 +141,7 @@ class MuteAction(BaseAction):
# 获取用户ID # 获取用户ID
person_id = person_api.get_person_id_by_name(target) person_id = person_api.get_person_id_by_name(target)
user_id = await person_api.get_person_value(person_id,"user_id") user_id = await person_api.get_person_value(person_id, "user_id")
if not user_id: if not user_id:
error_msg = f"未找到用户 {target} 的ID" error_msg = f"未找到用户 {target} 的ID"
await self.send_text(f"找不到 {target} 这个人呢~") await self.send_text(f"找不到 {target} 这个人呢~")
@@ -154,12 +155,12 @@ class MuteAction(BaseAction):
# 获取模板化消息 # 获取模板化消息
message = self._get_template_message(target, time_str, reason) message = self._get_template_message(target, time_str, reason)
result_status,result_message = await generator_api.rewrite_reply( result_status, result_message = await generator_api.rewrite_reply(
chat_stream=self.chat_stream, chat_stream=self.chat_stream,
reply_data={ reply_data={
"raw_reply": message, "raw_reply": message,
"reason": reason, "reason": reason,
} },
) )
if result_status: if result_status:
@@ -169,9 +170,7 @@ class MuteAction(BaseAction):
# 发送群聊禁言命令 # 发送群聊禁言命令
success = await self.send_command( success = await self.send_command(
command_name="GROUP_BAN", command_name="GROUP_BAN", args={"qq_id": str(user_id), "duration": str(duration_int)}, storage_message=False
args={"qq_id": str(user_id), "duration": str(duration_int)},
storage_message=False
) )
if success: if success:
@@ -192,9 +191,7 @@ class MuteAction(BaseAction):
def _get_template_message(self, target: str, duration_str: str, reason: str) -> str: def _get_template_message(self, target: str, duration_str: str, reason: str) -> str:
"""获取模板化的禁言消息""" """获取模板化的禁言消息"""
templates = self.get_config( templates = self.get_config("mute.templates")
"mute.templates"
)
template = random.choice(templates) template = random.choice(templates)
return template.format(target=target, duration=duration_str, reason=reason) return template.format(target=target, duration=duration_str, reason=reason)