refc:重构插件api,补全文档,合并expressor和replyer,分离reply和sender,新log浏览器
This commit is contained in:
@@ -1,534 +0,0 @@
|
||||
import traceback
|
||||
from typing import List, Optional, Dict, Any, Tuple
|
||||
|
||||
from src.chat.focus_chat.expressors.exprssion_learner import get_expression_learner
|
||||
from src.chat.message_receive.message import MessageRecv, MessageThinking, MessageSending
|
||||
from src.chat.message_receive.message import Seg # Local import needed after move
|
||||
from src.chat.message_receive.message import UserInfo
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
from src.chat.utils.utils_image import image_path_to_base64 # Local import needed after move
|
||||
from src.chat.utils.timer_calculator import Timer # <--- Import Timer
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
from src.chat.focus_chat.heartFC_sender import HeartFCSender
|
||||
from src.chat.utils.utils import process_llm_response
|
||||
from src.chat.heart_flow.utils_chat import get_chat_type_and_target_info
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.chat.focus_chat.hfc_utils import parse_thinking_id_to_timestamp
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
|
||||
import time
|
||||
import random
|
||||
|
||||
logger = get_logger("expressor")
|
||||
|
||||
|
||||
def init_prompt():
|
||||
Prompt(
|
||||
"""
|
||||
你可以参考你的以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中:
|
||||
{style_habbits}
|
||||
|
||||
你现在正在群里聊天,以下是群里正在进行的聊天内容:
|
||||
{chat_info}
|
||||
|
||||
以上是聊天内容,你需要了解聊天记录中的内容
|
||||
|
||||
{chat_target}
|
||||
你的名字是{bot_name},{prompt_personality},在这聊天中,"{target_message}"引起了你的注意,对这句话,你想表达:{in_mind_reply},原因是:{reason}。你现在要思考怎么回复
|
||||
你需要使用合适的语法和句法,参考聊天内容,组织一条日常且口语化的回复。请你修改你想表达的原句,符合你的表达风格和语言习惯
|
||||
请你根据情景使用以下句法:
|
||||
{grammar_habbits}
|
||||
{config_expression_style},你可以完全重组回复,保留最基本的表达含义就好,但重组后保持语意通顺。
|
||||
不要浮夸,不要夸张修辞,平淡且不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 ),只输出一条回复就好。
|
||||
现在,你说:
|
||||
""",
|
||||
"default_expressor_prompt",
|
||||
)
|
||||
|
||||
Prompt(
|
||||
"""
|
||||
你可以参考以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中:
|
||||
{style_habbits}
|
||||
|
||||
你现在正在群里聊天,以下是群里正在进行的聊天内容:
|
||||
{chat_info}
|
||||
|
||||
以上是聊天内容,你需要了解聊天记录中的内容
|
||||
|
||||
{chat_target}
|
||||
你的名字是{bot_name},{prompt_personality},在这聊天中,"{target_message}"引起了你的注意,对这句话,你想表达:{in_mind_reply},原因是:{reason}。你现在要思考怎么回复
|
||||
你需要使用合适的语法和句法,参考聊天内容,组织一条日常且口语化的回复。
|
||||
请你根据情景使用以下句法:
|
||||
{grammar_habbits}
|
||||
{config_expression_style},你可以完全重组回复,保留最基本的表达含义就好,但重组后保持语意通顺。
|
||||
不要浮夸,不要夸张修辞,平淡且不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 ),只输出一条回复就好。
|
||||
现在,你说:
|
||||
""",
|
||||
"default_expressor_private_prompt", # New template for private FOCUSED chat
|
||||
)
|
||||
|
||||
|
||||
class DefaultExpressor:
|
||||
def __init__(self, chat_stream: ChatStream):
|
||||
self.log_prefix = "expressor"
|
||||
# TODO: API-Adapter修改标记
|
||||
self.express_model = LLMRequest(
|
||||
model=global_config.model.replyer_1,
|
||||
request_type="focus.expressor",
|
||||
)
|
||||
self.heart_fc_sender = HeartFCSender()
|
||||
|
||||
self.chat_id = chat_stream.stream_id
|
||||
self.chat_stream = chat_stream
|
||||
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_id)
|
||||
|
||||
async def _create_thinking_message(self, anchor_message: Optional[MessageRecv], thinking_id: str):
|
||||
"""创建思考消息 (尝试锚定到 anchor_message)"""
|
||||
if not anchor_message or not anchor_message.chat_stream:
|
||||
logger.error(f"{self.log_prefix} 无法创建思考消息,缺少有效的锚点消息或聊天流。")
|
||||
return None
|
||||
|
||||
chat = anchor_message.chat_stream
|
||||
messageinfo = anchor_message.message_info
|
||||
thinking_time_point = parse_thinking_id_to_timestamp(thinking_id)
|
||||
bot_user_info = UserInfo(
|
||||
user_id=global_config.bot.qq_account,
|
||||
user_nickname=global_config.bot.nickname,
|
||||
platform=messageinfo.platform,
|
||||
)
|
||||
|
||||
thinking_message = MessageThinking(
|
||||
message_id=thinking_id,
|
||||
chat_stream=chat,
|
||||
bot_user_info=bot_user_info,
|
||||
reply=anchor_message, # 回复的是锚点消息
|
||||
thinking_start_time=thinking_time_point,
|
||||
)
|
||||
# logger.debug(f"创建思考消息thinking_message:{thinking_message}")
|
||||
|
||||
await self.heart_fc_sender.register_thinking(thinking_message)
|
||||
return None
|
||||
|
||||
async def deal_reply(
|
||||
self,
|
||||
cycle_timers: dict,
|
||||
action_data: Dict[str, Any],
|
||||
reasoning: str,
|
||||
anchor_message: MessageRecv,
|
||||
thinking_id: str,
|
||||
) -> tuple[bool, Optional[List[Tuple[str, str]]]]:
|
||||
# 创建思考消息
|
||||
await self._create_thinking_message(anchor_message, thinking_id)
|
||||
|
||||
reply = [] # 初始化 reply,防止未定义
|
||||
try:
|
||||
has_sent_something = False
|
||||
|
||||
# 处理文本部分
|
||||
text_part = action_data.get("text", [])
|
||||
if text_part:
|
||||
with Timer("生成回复", cycle_timers):
|
||||
# 可以保留原有的文本处理逻辑或进行适当调整
|
||||
reply = await self.express(
|
||||
in_mind_reply=text_part,
|
||||
anchor_message=anchor_message,
|
||||
thinking_id=thinking_id,
|
||||
reason=reasoning,
|
||||
action_data=action_data,
|
||||
)
|
||||
|
||||
with Timer("选择表情", cycle_timers):
|
||||
emoji_keyword = action_data.get("emojis", [])
|
||||
emoji_base64 = await self._choose_emoji(emoji_keyword)
|
||||
if emoji_base64:
|
||||
reply.append(("emoji", emoji_base64))
|
||||
|
||||
if reply:
|
||||
with Timer("发送消息", cycle_timers):
|
||||
sent_msg_list = await self.send_response_messages(
|
||||
anchor_message=anchor_message,
|
||||
thinking_id=thinking_id,
|
||||
response_set=reply,
|
||||
)
|
||||
has_sent_something = True
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} 文本回复生成失败")
|
||||
|
||||
if not has_sent_something:
|
||||
logger.warning(f"{self.log_prefix} 回复动作未包含任何有效内容")
|
||||
|
||||
return has_sent_something, sent_msg_list
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"回复失败: {e}")
|
||||
traceback.print_exc()
|
||||
return False, None
|
||||
|
||||
# --- 回复器 (Replier) 的定义 --- #
|
||||
|
||||
async def express(
|
||||
self,
|
||||
in_mind_reply: str,
|
||||
reason: str,
|
||||
anchor_message: MessageRecv,
|
||||
thinking_id: str,
|
||||
action_data: Dict[str, Any],
|
||||
) -> Optional[List[str]]:
|
||||
"""
|
||||
回复器 (Replier): 核心逻辑,负责生成回复文本。
|
||||
(已整合原 HeartFCGenerator 的功能)
|
||||
"""
|
||||
try:
|
||||
# --- Determine sender_name for private chat ---
|
||||
sender_name_for_prompt = "某人" # Default for group or if info unavailable
|
||||
if not self.is_group_chat and self.chat_target_info:
|
||||
# Prioritize person_name, then nickname
|
||||
sender_name_for_prompt = (
|
||||
self.chat_target_info.get("person_name")
|
||||
or self.chat_target_info.get("user_nickname")
|
||||
or sender_name_for_prompt
|
||||
)
|
||||
# --- End determining sender_name ---
|
||||
|
||||
target_message = action_data.get("target", "")
|
||||
|
||||
# 3. 构建 Prompt
|
||||
with Timer("构建Prompt", {}): # 内部计时器,可选保留
|
||||
prompt = await self.build_prompt_focus(
|
||||
chat_stream=self.chat_stream, # Pass the stream object
|
||||
in_mind_reply=in_mind_reply,
|
||||
reason=reason,
|
||||
sender_name=sender_name_for_prompt, # Pass determined name
|
||||
target_message=target_message,
|
||||
config_expression_style=global_config.expression.expression_style,
|
||||
)
|
||||
|
||||
# 4. 调用 LLM 生成回复
|
||||
content = None
|
||||
reasoning_content = None
|
||||
model_name = "unknown_model"
|
||||
if not prompt:
|
||||
logger.error(f"{self.log_prefix}[Replier-{thinking_id}] Prompt 构建失败,无法生成回复。")
|
||||
return None
|
||||
|
||||
try:
|
||||
with Timer("LLM生成", {}): # 内部计时器,可选保留
|
||||
# TODO: API-Adapter修改标记
|
||||
# logger.info(f"{self.log_prefix}[Replier-{thinking_id}]\nPrompt:\n{prompt}\n")
|
||||
content, (reasoning_content, model_name) = await self.express_model.generate_response_async(prompt)
|
||||
|
||||
logger.info(f"想要表达:{in_mind_reply}||理由:{reason}")
|
||||
logger.info(f"最终回复: {content}\n")
|
||||
|
||||
except Exception as llm_e:
|
||||
# 精简报错信息
|
||||
logger.error(f"{self.log_prefix}LLM 生成失败: {llm_e}")
|
||||
return None # LLM 调用失败则无法生成回复
|
||||
|
||||
processed_response = process_llm_response(content)
|
||||
|
||||
# 5. 处理 LLM 响应
|
||||
if not content:
|
||||
logger.warning(f"{self.log_prefix}LLM 生成了空内容。")
|
||||
return None
|
||||
if not processed_response:
|
||||
logger.warning(f"{self.log_prefix}处理后的回复为空。")
|
||||
return None
|
||||
|
||||
reply_set = []
|
||||
for str in processed_response:
|
||||
reply_seg = ("text", str)
|
||||
reply_set.append(reply_seg)
|
||||
|
||||
return reply_set
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}回复生成意外失败: {e}")
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
async def build_prompt_focus(
|
||||
self,
|
||||
reason,
|
||||
chat_stream,
|
||||
sender_name,
|
||||
in_mind_reply,
|
||||
target_message,
|
||||
config_expression_style,
|
||||
) -> str:
|
||||
is_group_chat = bool(chat_stream.group_info)
|
||||
|
||||
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_stream.stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=global_config.focus_chat.observation_context_size,
|
||||
)
|
||||
chat_talking_prompt = build_readable_messages(
|
||||
message_list_before_now,
|
||||
replace_bot_name=True,
|
||||
merge_messages=True,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
truncate=True,
|
||||
)
|
||||
|
||||
expression_learner = get_expression_learner()
|
||||
(
|
||||
learnt_style_expressions,
|
||||
learnt_grammar_expressions,
|
||||
personality_expressions,
|
||||
) = await expression_learner.get_expression_by_chat_id(chat_stream.stream_id)
|
||||
|
||||
style_habbits = []
|
||||
grammar_habbits = []
|
||||
# 1. learnt_expressions加权随机选3条
|
||||
if learnt_style_expressions:
|
||||
weights = [expr["count"] for expr in learnt_style_expressions]
|
||||
selected_learnt = weighted_sample_no_replacement(learnt_style_expressions, weights, 3)
|
||||
for expr in selected_learnt:
|
||||
if isinstance(expr, dict) and "situation" in expr and "style" in expr:
|
||||
style_habbits.append(f"当{expr['situation']}时,使用 {expr['style']}")
|
||||
# 2. learnt_grammar_expressions加权随机选3条
|
||||
if learnt_grammar_expressions:
|
||||
weights = [expr["count"] for expr in learnt_grammar_expressions]
|
||||
selected_learnt = weighted_sample_no_replacement(learnt_grammar_expressions, weights, 3)
|
||||
for expr in selected_learnt:
|
||||
if isinstance(expr, dict) and "situation" in expr and "style" in expr:
|
||||
grammar_habbits.append(f"当{expr['situation']}时,使用 {expr['style']}")
|
||||
# 3. personality_expressions随机选1条
|
||||
if personality_expressions:
|
||||
expr = random.choice(personality_expressions)
|
||||
if isinstance(expr, dict) and "situation" in expr and "style" in expr:
|
||||
style_habbits.append(f"当{expr['situation']}时,使用 {expr['style']}")
|
||||
|
||||
style_habbits_str = "\n".join(style_habbits)
|
||||
grammar_habbits_str = "\n".join(grammar_habbits)
|
||||
|
||||
logger.debug("开始构建 focus prompt")
|
||||
|
||||
# --- Choose template based on chat type ---
|
||||
if is_group_chat:
|
||||
template_name = "default_expressor_prompt"
|
||||
# Group specific formatting variables (already fetched or default)
|
||||
chat_target_1 = await global_prompt_manager.get_prompt_async("chat_target_group1")
|
||||
# chat_target_2 = await global_prompt_manager.get_prompt_async("chat_target_group2")
|
||||
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
template_name,
|
||||
style_habbits=style_habbits_str,
|
||||
grammar_habbits=grammar_habbits_str,
|
||||
chat_target=chat_target_1,
|
||||
chat_info=chat_talking_prompt,
|
||||
bot_name=global_config.bot.nickname,
|
||||
prompt_personality="",
|
||||
reason=reason,
|
||||
in_mind_reply=in_mind_reply,
|
||||
target_message=target_message,
|
||||
config_expression_style=config_expression_style,
|
||||
)
|
||||
else: # Private chat
|
||||
template_name = "default_expressor_private_prompt"
|
||||
chat_target_1 = "你正在和人私聊"
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
template_name,
|
||||
style_habbits=style_habbits_str,
|
||||
grammar_habbits=grammar_habbits_str,
|
||||
chat_target=chat_target_1,
|
||||
chat_info=chat_talking_prompt,
|
||||
bot_name=global_config.bot.nickname,
|
||||
prompt_personality="",
|
||||
reason=reason,
|
||||
in_mind_reply=in_mind_reply,
|
||||
target_message=target_message,
|
||||
config_expression_style=config_expression_style,
|
||||
)
|
||||
|
||||
return prompt
|
||||
|
||||
# --- 发送器 (Sender) --- #
|
||||
|
||||
async def send_response_messages(
|
||||
self,
|
||||
anchor_message: Optional[MessageRecv],
|
||||
response_set: List[Tuple[str, str]],
|
||||
thinking_id: str = "",
|
||||
display_message: str = "",
|
||||
) -> Optional[MessageSending]:
|
||||
"""发送回复消息 (尝试锚定到 anchor_message),使用 HeartFCSender"""
|
||||
chat = self.chat_stream
|
||||
chat_id = self.chat_id
|
||||
if chat is None:
|
||||
logger.error(f"{self.log_prefix} 无法发送回复,chat_stream 为空。")
|
||||
return None
|
||||
if not anchor_message:
|
||||
logger.error(f"{self.log_prefix} 无法发送回复,anchor_message 为空。")
|
||||
return None
|
||||
|
||||
stream_name = get_chat_manager().get_stream_name(chat_id) or chat_id # 获取流名称用于日志
|
||||
|
||||
# 检查思考过程是否仍在进行,并获取开始时间
|
||||
if thinking_id:
|
||||
thinking_start_time = await self.heart_fc_sender.get_thinking_start_time(chat_id, thinking_id)
|
||||
else:
|
||||
thinking_id = "ds" + str(round(time.time(), 2))
|
||||
thinking_start_time = time.time()
|
||||
|
||||
if thinking_start_time is None:
|
||||
logger.error(f"[{stream_name}]expressor思考过程未找到或已结束,无法发送回复。")
|
||||
return None
|
||||
|
||||
mark_head = False
|
||||
# first_bot_msg: Optional[MessageSending] = None
|
||||
reply_message_ids = [] # 记录实际发送的消息ID
|
||||
|
||||
sent_msg_list = []
|
||||
|
||||
for i, msg_text in enumerate(response_set):
|
||||
# 为每个消息片段生成唯一ID
|
||||
type = msg_text[0]
|
||||
data = msg_text[1]
|
||||
|
||||
if global_config.experimental.debug_show_chat_mode and type == "text":
|
||||
data += "ᶠ"
|
||||
|
||||
part_message_id = f"{thinking_id}_{i}"
|
||||
message_segment = Seg(type=type, data=data)
|
||||
|
||||
if type == "emoji":
|
||||
is_emoji = True
|
||||
else:
|
||||
is_emoji = False
|
||||
reply_to = not mark_head
|
||||
|
||||
bot_message = await self._build_single_sending_message(
|
||||
anchor_message=anchor_message,
|
||||
message_id=part_message_id,
|
||||
message_segment=message_segment,
|
||||
display_message=display_message,
|
||||
reply_to=reply_to,
|
||||
is_emoji=is_emoji,
|
||||
thinking_id=thinking_id,
|
||||
thinking_start_time=thinking_start_time,
|
||||
)
|
||||
|
||||
try:
|
||||
if not mark_head:
|
||||
mark_head = True
|
||||
# first_bot_msg = bot_message # 保存第一个成功发送的消息对象
|
||||
typing = False
|
||||
else:
|
||||
typing = True
|
||||
|
||||
if type == "emoji":
|
||||
typing = False
|
||||
|
||||
if anchor_message.raw_message:
|
||||
set_reply = True
|
||||
else:
|
||||
set_reply = False
|
||||
sent_msg = await self.heart_fc_sender.send_message(
|
||||
bot_message, has_thinking=True, typing=typing, set_reply=set_reply
|
||||
)
|
||||
|
||||
reply_message_ids.append(part_message_id) # 记录我们生成的ID
|
||||
|
||||
sent_msg_list.append((type, sent_msg))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}发送回复片段 {i} ({part_message_id}) 时失败: {e}")
|
||||
traceback.print_exc()
|
||||
# 这里可以选择是继续发送下一个片段还是中止
|
||||
|
||||
# 在尝试发送完所有片段后,完成原始的 thinking_id 状态
|
||||
try:
|
||||
await self.heart_fc_sender.complete_thinking(chat_id, thinking_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}完成思考状态 {thinking_id} 时出错: {e}")
|
||||
|
||||
return sent_msg_list
|
||||
|
||||
async def _choose_emoji(self, send_emoji: str):
|
||||
"""
|
||||
选择表情,根据send_emoji文本选择表情,返回表情base64
|
||||
"""
|
||||
emoji_base64 = ""
|
||||
emoji_raw = await get_emoji_manager().get_emoji_for_text(send_emoji)
|
||||
if emoji_raw:
|
||||
emoji_path, _description, _emotion = emoji_raw
|
||||
emoji_base64 = image_path_to_base64(emoji_path)
|
||||
return emoji_base64
|
||||
|
||||
async def _build_single_sending_message(
|
||||
self,
|
||||
anchor_message: MessageRecv,
|
||||
message_id: str,
|
||||
message_segment: Seg,
|
||||
reply_to: bool,
|
||||
is_emoji: bool,
|
||||
thinking_id: str,
|
||||
thinking_start_time: float,
|
||||
display_message: str,
|
||||
) -> MessageSending:
|
||||
"""构建单个发送消息"""
|
||||
|
||||
bot_user_info = UserInfo(
|
||||
user_id=global_config.bot.qq_account,
|
||||
user_nickname=global_config.bot.nickname,
|
||||
platform=self.chat_stream.platform,
|
||||
)
|
||||
|
||||
bot_message = MessageSending(
|
||||
message_id=message_id, # 使用片段的唯一ID
|
||||
chat_stream=self.chat_stream,
|
||||
bot_user_info=bot_user_info,
|
||||
sender_info=anchor_message.message_info.user_info,
|
||||
message_segment=message_segment,
|
||||
reply=anchor_message, # 回复原始锚点
|
||||
is_head=reply_to,
|
||||
is_emoji=is_emoji,
|
||||
thinking_start_time=thinking_start_time, # 传递原始思考开始时间
|
||||
display_message=display_message,
|
||||
)
|
||||
|
||||
return bot_message
|
||||
|
||||
|
||||
def weighted_sample_no_replacement(items, weights, k) -> list:
|
||||
"""
|
||||
加权且不放回地随机抽取k个元素。
|
||||
|
||||
参数:
|
||||
items: 待抽取的元素列表
|
||||
weights: 每个元素对应的权重(与items等长,且为正数)
|
||||
k: 需要抽取的元素个数
|
||||
返回:
|
||||
selected: 按权重加权且不重复抽取的k个元素组成的列表
|
||||
|
||||
如果 items 中的元素不足 k 个,就只会返回所有可用的元素
|
||||
|
||||
实现思路:
|
||||
每次从当前池中按权重加权随机选出一个元素,选中后将其从池中移除,重复k次。
|
||||
这样保证了:
|
||||
1. count越大被选中概率越高
|
||||
2. 不会重复选中同一个元素
|
||||
"""
|
||||
selected = []
|
||||
pool = list(zip(items, weights))
|
||||
for _ in range(min(k, len(pool))):
|
||||
total = sum(w for _, w in pool)
|
||||
r = random.uniform(0, total)
|
||||
upto = 0
|
||||
for idx, (item, weight) in enumerate(pool):
|
||||
upto += weight
|
||||
if upto >= r:
|
||||
selected.append(item)
|
||||
pool.pop(idx)
|
||||
break
|
||||
return selected
|
||||
|
||||
|
||||
init_prompt()
|
||||
@@ -21,8 +21,6 @@ from src.chat.heart_flow.observation.chatting_observation import ChattingObserva
|
||||
from src.chat.heart_flow.observation.structure_observation import StructureObservation
|
||||
from src.chat.heart_flow.observation.actions_observation import ActionObservation
|
||||
from src.chat.focus_chat.info_processors.tool_processor import ToolProcessor
|
||||
from src.chat.focus_chat.expressors.default_expressor import DefaultExpressor
|
||||
from src.chat.focus_chat.replyer.default_replyer import DefaultReplyer
|
||||
from src.chat.focus_chat.memory_activator import MemoryActivator
|
||||
from src.chat.focus_chat.info_processors.base_processor import BaseProcessor
|
||||
from src.chat.focus_chat.info_processors.self_processor import SelfProcessor
|
||||
@@ -125,9 +123,6 @@ class HeartFChatting:
|
||||
self.processors: List[BaseProcessor] = []
|
||||
self._register_default_processors()
|
||||
|
||||
self.expressor = DefaultExpressor(chat_stream=self.chat_stream)
|
||||
self.replyer = DefaultReplyer(chat_stream=self.chat_stream)
|
||||
|
||||
self.action_manager = ActionManager()
|
||||
self.action_planner = PlannerFactory.create_planner(
|
||||
log_prefix=self.log_prefix, action_manager=self.action_manager
|
||||
@@ -543,6 +538,7 @@ class HeartFChatting:
|
||||
|
||||
async def _observe_process_plan_action_loop(self, cycle_timers: dict, thinking_id: str) -> dict:
|
||||
try:
|
||||
loop_start_time = time.time()
|
||||
with Timer("观察", cycle_timers):
|
||||
# 执行所有观察器的观察
|
||||
for observation in self.observations:
|
||||
@@ -583,7 +579,7 @@ class HeartFChatting:
|
||||
}
|
||||
|
||||
with Timer("规划器", cycle_timers):
|
||||
plan_result = await self.action_planner.plan(all_plan_info, running_memorys)
|
||||
plan_result = await self.action_planner.plan(all_plan_info, running_memorys, loop_start_time)
|
||||
|
||||
loop_plan_info = {
|
||||
"action_result": plan_result.get("action_result", {}),
|
||||
@@ -607,7 +603,7 @@ class HeartFChatting:
|
||||
logger.debug(f"{self.log_prefix} 麦麦想要:'{action_str}'")
|
||||
|
||||
success, reply_text, command = await self._handle_action(
|
||||
action_type, reasoning, action_data, cycle_timers, thinking_id, self.observations
|
||||
action_type, reasoning, action_data, cycle_timers, thinking_id
|
||||
)
|
||||
|
||||
loop_action_info = {
|
||||
@@ -646,7 +642,6 @@ class HeartFChatting:
|
||||
action_data: dict,
|
||||
cycle_timers: dict,
|
||||
thinking_id: str,
|
||||
observations: List[Observation],
|
||||
) -> tuple[bool, str, str]:
|
||||
"""
|
||||
处理规划动作,使用动作工厂创建相应的动作处理器
|
||||
@@ -670,9 +665,6 @@ class HeartFChatting:
|
||||
reasoning=reasoning,
|
||||
cycle_timers=cycle_timers,
|
||||
thinking_id=thinking_id,
|
||||
observations=observations,
|
||||
expressor=self.expressor,
|
||||
replyer=self.replyer,
|
||||
chat_stream=self.chat_stream,
|
||||
log_prefix=self.log_prefix,
|
||||
shutting_down=self._shutting_down,
|
||||
|
||||
@@ -15,7 +15,7 @@ install(extra_lines=3)
|
||||
logger = get_logger("sender")
|
||||
|
||||
|
||||
async def send_message(message: MessageSending) -> str:
|
||||
async def send_message(message: MessageSending) -> bool:
|
||||
"""合并后的消息发送函数,包含WS发送和日志记录"""
|
||||
message_preview = truncate_message(message.processed_plain_text, max_length=40)
|
||||
|
||||
@@ -23,7 +23,7 @@ async def send_message(message: MessageSending) -> str:
|
||||
# 直接调用API发送消息
|
||||
await get_global_api().send_message(message)
|
||||
logger.info(f"已将消息 '{message_preview}' 发往平台'{message.message_info.platform}'")
|
||||
return message.processed_plain_text
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息 '{message_preview}' 发往平台'{message.message_info.platform}' 失败: {str(e)}")
|
||||
@@ -73,17 +73,15 @@ class HeartFCSender:
|
||||
thinking_message = self.thinking_messages.get(chat_id, {}).get(message_id)
|
||||
return thinking_message.thinking_start_time if thinking_message else None
|
||||
|
||||
async def send_message(self, message: MessageSending, has_thinking=False, typing=False, set_reply=False):
|
||||
async def send_message(self, message: MessageSending, typing=False, set_reply=False, storage_message=True):
|
||||
"""
|
||||
处理、发送并存储一条消息。
|
||||
|
||||
参数:
|
||||
message: MessageSending 对象,待发送的消息。
|
||||
has_thinking: 是否管理思考状态,表情包无思考状态(如需调用 register_thinking/complete_thinking)。
|
||||
typing: 是否模拟打字等待(根据 has_thinking 控制等待时长)。
|
||||
typing: 是否模拟打字等待。
|
||||
|
||||
用法:
|
||||
- has_thinking=True 时,自动处理思考消息的时间和清理。
|
||||
- typing=True 时,发送前会有打字等待。
|
||||
"""
|
||||
if not message.chat_stream:
|
||||
@@ -98,40 +96,29 @@ class HeartFCSender:
|
||||
|
||||
try:
|
||||
if set_reply:
|
||||
_ = message.update_thinking_time()
|
||||
message.build_reply()
|
||||
logger.debug(f"[{chat_id}] 选择回复引用消息: {message.processed_plain_text[:20]}...")
|
||||
|
||||
# --- 条件应用 set_reply 逻辑 ---
|
||||
if (
|
||||
message.is_head
|
||||
and not message.is_private_message()
|
||||
and message.reply.processed_plain_text != "[System Trigger Context]"
|
||||
):
|
||||
# message.set_reply(message.reply)
|
||||
message.set_reply()
|
||||
logger.debug(f"[{chat_id}] 应用 set_reply 逻辑: {message.processed_plain_text[:20]}...")
|
||||
|
||||
# print(f"message.display_message: {message.display_message}")
|
||||
await message.process()
|
||||
# print(f"message.display_message: {message.display_message}")
|
||||
|
||||
if typing:
|
||||
if has_thinking:
|
||||
typing_time = calculate_typing_time(
|
||||
input_string=message.processed_plain_text,
|
||||
thinking_start_time=message.thinking_start_time,
|
||||
is_emoji=message.is_emoji,
|
||||
)
|
||||
await asyncio.sleep(typing_time)
|
||||
else:
|
||||
await asyncio.sleep(0.5)
|
||||
typing_time = calculate_typing_time(
|
||||
input_string=message.processed_plain_text,
|
||||
thinking_start_time=message.thinking_start_time,
|
||||
is_emoji=message.is_emoji,
|
||||
)
|
||||
await asyncio.sleep(typing_time)
|
||||
|
||||
|
||||
sent_msg = await send_message(message)
|
||||
await self.storage.store_message(message, message.chat_stream)
|
||||
if not sent_msg:
|
||||
return False
|
||||
|
||||
if sent_msg:
|
||||
return sent_msg
|
||||
else:
|
||||
return "发送失败"
|
||||
if storage_message:
|
||||
await self.storage.store_message(message, message.chat_stream)
|
||||
|
||||
return sent_msg
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{chat_id}] 处理或存储消息 {message_id} 时出错: {e}")
|
||||
|
||||
@@ -172,6 +172,7 @@ class HeartFCMessageReceiver:
|
||||
return
|
||||
|
||||
# 5. 消息存储
|
||||
print(f"message: {message.message_info.time}")
|
||||
await self.storage.store_message(message, chat)
|
||||
|
||||
# 6. 兴趣度计算与更新
|
||||
|
||||
@@ -11,7 +11,7 @@ from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from .base_processor import BaseProcessor
|
||||
from src.chat.focus_chat.info.info_base import InfoBase
|
||||
from src.chat.focus_chat.info.expression_selection_info import ExpressionSelectionInfo
|
||||
from src.chat.focus_chat.expressors.exprssion_learner import get_expression_learner
|
||||
from src.chat.express.exprssion_learner import get_expression_learner
|
||||
from json_repair import repair_json
|
||||
import json
|
||||
|
||||
|
||||
@@ -1,13 +1,9 @@
|
||||
from typing import Dict, List, Optional, Type, Any
|
||||
from src.plugin_system.base.base_action import BaseAction
|
||||
from src.chat.heart_flow.observation.observation import Observation
|
||||
from src.chat.focus_chat.replyer.default_replyer import DefaultReplyer
|
||||
from src.chat.focus_chat.expressors.default_expressor import DefaultExpressor
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.common.logger import get_logger
|
||||
|
||||
# 不再需要导入动作类,因为已经在main.py中导入
|
||||
# import src.chat.actions.default_actions # noqa
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
from src.plugin_system.base.component_types import ComponentType
|
||||
|
||||
logger = get_logger("action_manager")
|
||||
|
||||
@@ -15,87 +11,11 @@ logger = get_logger("action_manager")
|
||||
ActionInfo = Dict[str, Any]
|
||||
|
||||
|
||||
class PluginActionWrapper(BaseAction):
|
||||
"""
|
||||
新插件系统Action组件的兼容性包装器
|
||||
|
||||
将新插件系统的Action组件包装为旧系统兼容的BaseAction接口
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, plugin_action, action_name: str, action_data: dict, reasoning: str, cycle_timers: dict, thinking_id: str
|
||||
):
|
||||
"""初始化包装器"""
|
||||
# 调用旧系统BaseAction初始化,只传递它能接受的参数
|
||||
super().__init__(
|
||||
action_data=action_data, reasoning=reasoning, cycle_timers=cycle_timers, thinking_id=thinking_id
|
||||
)
|
||||
|
||||
# 存储插件Action实例(它已经包含了所有必要的服务对象)
|
||||
self.plugin_action = plugin_action
|
||||
self.action_name = action_name
|
||||
|
||||
# 从插件Action实例复制属性到包装器
|
||||
self._sync_attributes_from_plugin_action()
|
||||
|
||||
def _sync_attributes_from_plugin_action(self):
|
||||
"""从插件Action实例同步属性到包装器"""
|
||||
# 基本属性
|
||||
self.action_name = getattr(self.plugin_action, "action_name", self.action_name)
|
||||
|
||||
# 设置兼容的默认值
|
||||
self.action_description = f"插件Action: {self.action_name}"
|
||||
self.action_parameters = {}
|
||||
self.action_require = []
|
||||
|
||||
# 激活类型属性(从新插件系统转换)
|
||||
plugin_focus_type = getattr(self.plugin_action, "focus_activation_type", None)
|
||||
plugin_normal_type = getattr(self.plugin_action, "normal_activation_type", None)
|
||||
|
||||
if plugin_focus_type:
|
||||
self.focus_activation_type = (
|
||||
plugin_focus_type.value if hasattr(plugin_focus_type, "value") else str(plugin_focus_type)
|
||||
)
|
||||
if plugin_normal_type:
|
||||
self.normal_activation_type = (
|
||||
plugin_normal_type.value if hasattr(plugin_normal_type, "value") else str(plugin_normal_type)
|
||||
)
|
||||
|
||||
# 其他属性
|
||||
self.random_activation_probability = getattr(self.plugin_action, "random_activation_probability", 0.0)
|
||||
self.llm_judge_prompt = getattr(self.plugin_action, "llm_judge_prompt", "")
|
||||
self.activation_keywords = getattr(self.plugin_action, "activation_keywords", [])
|
||||
self.keyword_case_sensitive = getattr(self.plugin_action, "keyword_case_sensitive", False)
|
||||
|
||||
# 模式和并行设置
|
||||
plugin_mode = getattr(self.plugin_action, "mode_enable", None)
|
||||
if plugin_mode:
|
||||
self.mode_enable = plugin_mode.value if hasattr(plugin_mode, "value") else str(plugin_mode)
|
||||
|
||||
self.parallel_action = getattr(self.plugin_action, "parallel_action", True)
|
||||
self.enable_plugin = True
|
||||
|
||||
async def execute(self) -> tuple[bool, str]:
|
||||
"""实现抽象方法execute,委托给插件Action的execute方法"""
|
||||
try:
|
||||
# 调用插件Action的execute方法
|
||||
success, response = await self.plugin_action.execute()
|
||||
|
||||
logger.debug(f"插件Action {self.action_name} 执行{'成功' if success else '失败'}: {response}")
|
||||
return success, response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"插件Action {self.action_name} 执行异常: {e}")
|
||||
return False, f"插件Action执行失败: {str(e)}"
|
||||
|
||||
async def handle_action(self) -> tuple[bool, str]:
|
||||
"""兼容旧系统的动作处理接口,委托给execute方法"""
|
||||
return await self.execute()
|
||||
|
||||
|
||||
class ActionManager:
|
||||
"""
|
||||
动作管理器,用于管理各种类型的动作
|
||||
|
||||
现在统一使用新插件系统,简化了原有的新旧兼容逻辑。
|
||||
"""
|
||||
|
||||
# 类常量
|
||||
@@ -119,23 +39,20 @@ class ActionManager:
|
||||
# 初始化时将默认动作加载到使用中的动作
|
||||
self._using_actions = self._default_actions.copy()
|
||||
|
||||
# 添加系统核心动作
|
||||
# self._add_system_core_actions()
|
||||
|
||||
def _load_plugin_actions(self) -> None:
|
||||
"""
|
||||
加载所有插件目录中的动作
|
||||
加载所有插件系统中的动作
|
||||
"""
|
||||
try:
|
||||
# 从新插件系统获取Action组件
|
||||
self._load_plugin_system_actions()
|
||||
logger.debug("从新插件系统加载Action组件成功")
|
||||
logger.debug("从插件系统加载Action组件成功")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"加载插件动作失败: {e}")
|
||||
|
||||
def _load_plugin_system_actions(self) -> None:
|
||||
"""从新插件系统的component_registry加载Action组件"""
|
||||
"""从插件系统的component_registry加载Action组件"""
|
||||
try:
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
from src.plugin_system.base.component_types import ComponentType
|
||||
@@ -148,7 +65,7 @@ class ActionManager:
|
||||
logger.debug(f"Action组件 {action_name} 已存在,跳过")
|
||||
continue
|
||||
|
||||
# 将新插件系统的ActionInfo转换为旧系统格式
|
||||
# 将插件系统的ActionInfo转换为ActionManager格式
|
||||
converted_action_info = {
|
||||
"description": action_info.description,
|
||||
"parameters": getattr(action_info, "action_parameters", {}),
|
||||
@@ -165,8 +82,7 @@ class ActionManager:
|
||||
# 模式和并行设置
|
||||
"mode_enable": action_info.mode_enable.value,
|
||||
"parallel_action": action_info.parallel_action,
|
||||
# 标记这是来自新插件系统的组件
|
||||
"_plugin_system_component": True,
|
||||
# 插件信息
|
||||
"_plugin_name": getattr(action_info, "plugin_name", ""),
|
||||
}
|
||||
|
||||
@@ -180,7 +96,7 @@ class ActionManager:
|
||||
f"从插件系统加载Action组件: {action_name} (插件: {getattr(action_info, 'plugin_name', 'unknown')})"
|
||||
)
|
||||
|
||||
logger.info(f"从新插件系统加载了 {len(action_components)} 个Action组件")
|
||||
logger.info(f"从插件系统加载了 {len(action_components)} 个Action组件")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"从插件系统加载Action组件失败: {e}")
|
||||
@@ -195,12 +111,9 @@ class ActionManager:
|
||||
reasoning: str,
|
||||
cycle_timers: dict,
|
||||
thinking_id: str,
|
||||
observations: List[Observation],
|
||||
chat_stream: ChatStream,
|
||||
log_prefix: str,
|
||||
shutting_down: bool = False,
|
||||
expressor: DefaultExpressor = None,
|
||||
replyer: DefaultReplyer = None,
|
||||
) -> Optional[BaseAction]:
|
||||
"""
|
||||
创建动作处理器实例
|
||||
@@ -211,9 +124,6 @@ class ActionManager:
|
||||
reasoning: 执行理由
|
||||
cycle_timers: 计时器字典
|
||||
thinking_id: 思考ID
|
||||
observations: 观察列表
|
||||
expressor: 表达器
|
||||
replyer: 回复器
|
||||
chat_stream: 聊天流
|
||||
log_prefix: 日志前缀
|
||||
shutting_down: 是否正在关闭
|
||||
@@ -221,122 +131,39 @@ class ActionManager:
|
||||
Returns:
|
||||
Optional[BaseAction]: 创建的动作处理器实例,如果动作名称未注册则返回None
|
||||
"""
|
||||
# 检查动作是否在当前使用的动作集中
|
||||
# if action_name not in self._using_actions:
|
||||
# logger.warning(f"当前不可用的动作类型: {action_name}")
|
||||
# return None
|
||||
|
||||
# 检查是否是新插件系统的Action组件
|
||||
action_info = self._registered_actions.get(action_name)
|
||||
if action_info and action_info.get("_plugin_system_component", False):
|
||||
return self._create_plugin_system_action(
|
||||
action_name,
|
||||
action_data,
|
||||
reasoning,
|
||||
cycle_timers,
|
||||
thinking_id,
|
||||
observations,
|
||||
chat_stream,
|
||||
log_prefix,
|
||||
shutting_down,
|
||||
expressor,
|
||||
replyer,
|
||||
)
|
||||
|
||||
# 旧系统的动作创建逻辑
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
action_registry = component_registry.get_action_registry()
|
||||
handler_class = action_registry.get(action_name)
|
||||
if not handler_class:
|
||||
logger.warning(f"未注册的动作类型: {action_name}")
|
||||
return None
|
||||
|
||||
try:
|
||||
# 创建动作实例
|
||||
instance = handler_class(
|
||||
action_data=action_data,
|
||||
reasoning=reasoning,
|
||||
cycle_timers=cycle_timers,
|
||||
thinking_id=thinking_id,
|
||||
observations=observations,
|
||||
expressor=expressor,
|
||||
replyer=replyer,
|
||||
chat_stream=chat_stream,
|
||||
log_prefix=log_prefix,
|
||||
shutting_down=shutting_down,
|
||||
)
|
||||
|
||||
return instance
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建动作处理器实例失败: {e}")
|
||||
return None
|
||||
|
||||
def _create_plugin_system_action(
|
||||
self,
|
||||
action_name: str,
|
||||
action_data: dict,
|
||||
reasoning: str,
|
||||
cycle_timers: dict,
|
||||
thinking_id: str,
|
||||
observations: List[Observation],
|
||||
chat_stream: ChatStream,
|
||||
log_prefix: str,
|
||||
shutting_down: bool = False,
|
||||
expressor: DefaultExpressor = None,
|
||||
replyer: DefaultReplyer = None,
|
||||
) -> Optional["PluginActionWrapper"]:
|
||||
"""
|
||||
创建新插件系统的Action组件实例,并包装为兼容旧系统的接口
|
||||
|
||||
Returns:
|
||||
Optional[PluginActionWrapper]: 包装后的Action实例
|
||||
"""
|
||||
try:
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
# 获取组件类
|
||||
component_class = component_registry.get_component_class(action_name)
|
||||
# 获取组件类 - 明确指定查询Action类型
|
||||
component_class = component_registry.get_component_class(action_name, ComponentType.ACTION)
|
||||
if not component_class:
|
||||
logger.error(f"未找到插件Action组件类: {action_name}")
|
||||
logger.warning(f"{log_prefix} 未找到Action组件: {action_name}")
|
||||
return None
|
||||
|
||||
# 获取组件信息
|
||||
component_info = component_registry.get_component_info(action_name, ComponentType.ACTION)
|
||||
if not component_info:
|
||||
logger.warning(f"{log_prefix} 未找到Action组件信息: {action_name}")
|
||||
return None
|
||||
|
||||
# 获取插件配置
|
||||
component_info = component_registry.get_component_info(action_name)
|
||||
plugin_config = None
|
||||
if component_info and component_info.plugin_name:
|
||||
plugin_config = component_registry.get_plugin_config(component_info.plugin_name)
|
||||
plugin_config = component_registry.get_plugin_config(component_info.plugin_name)
|
||||
|
||||
# 创建插件Action实例
|
||||
plugin_action_instance = component_class(
|
||||
# 创建动作实例
|
||||
instance = component_class(
|
||||
action_data=action_data,
|
||||
reasoning=reasoning,
|
||||
cycle_timers=cycle_timers,
|
||||
thinking_id=thinking_id,
|
||||
chat_stream=chat_stream,
|
||||
expressor=expressor,
|
||||
replyer=replyer,
|
||||
observations=observations,
|
||||
log_prefix=log_prefix,
|
||||
shutting_down=shutting_down,
|
||||
plugin_config=plugin_config,
|
||||
)
|
||||
|
||||
# 创建兼容性包装器
|
||||
wrapper = PluginActionWrapper(
|
||||
plugin_action=plugin_action_instance,
|
||||
action_name=action_name,
|
||||
action_data=action_data,
|
||||
reasoning=reasoning,
|
||||
cycle_timers=cycle_timers,
|
||||
thinking_id=thinking_id,
|
||||
)
|
||||
|
||||
logger.debug(f"创建插件Action实例成功: {action_name}")
|
||||
return wrapper
|
||||
logger.debug(f"创建Action实例成功: {action_name}")
|
||||
return instance
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建插件Action实例失败 {action_name}: {e}")
|
||||
logger.error(f"创建Action实例失败 {action_name}: {e}")
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
@@ -366,19 +193,13 @@ class ActionManager:
|
||||
"""
|
||||
filtered_actions = {}
|
||||
|
||||
# print(self._using_actions)
|
||||
|
||||
for action_name, action_info in self._using_actions.items():
|
||||
# print(f"action_info: {action_info}")
|
||||
# print(f"action_name: {action_name}")
|
||||
action_mode = action_info.get("mode_enable", "all")
|
||||
|
||||
# 检查动作是否在当前模式下启用
|
||||
if action_mode == "all" or action_mode == mode:
|
||||
filtered_actions[action_name] = action_info
|
||||
logger.debug(f"动作 {action_name} 在模式 {mode} 下可用 (mode_enable: {action_mode})")
|
||||
# else:
|
||||
# logger.debug(f"动作 {action_name} 在模式 {mode} 下不可用 (mode_enable: {action_mode})")
|
||||
|
||||
logger.debug(f"模式 {mode} 下可用动作: {list(filtered_actions.keys())}")
|
||||
return filtered_actions
|
||||
@@ -474,20 +295,6 @@ class ActionManager:
|
||||
def restore_default_actions(self) -> None:
|
||||
"""恢复默认动作集到使用集"""
|
||||
self._using_actions = self._default_actions.copy()
|
||||
# 添加系统核心动作(即使enable_plugin为False的系统动作)
|
||||
# self._add_system_core_actions()
|
||||
|
||||
# def _add_system_core_actions(self) -> None:
|
||||
# """
|
||||
# 添加系统核心动作到使用集
|
||||
# 系统核心动作是那些enable_plugin为False但是系统必需的动作
|
||||
# """
|
||||
# system_core_actions = ["exit_focus_chat"] # 可以根据需要扩展
|
||||
|
||||
# for action_name in system_core_actions:
|
||||
# if action_name in self._registered_actions and action_name not in self._using_actions:
|
||||
# self._using_actions[action_name] = self._registered_actions[action_name]
|
||||
# logger.debug(f"添加系统核心动作到使用集: {action_name}")
|
||||
|
||||
def add_system_action_if_needed(self, action_name: str) -> bool:
|
||||
"""
|
||||
@@ -517,5 +324,4 @@ class ActionManager:
|
||||
"""
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
action_registry = component_registry.get_action_registry()
|
||||
return action_registry.get(action_name)
|
||||
return component_registry.get_component_class(action_name)
|
||||
|
||||
@@ -12,14 +12,14 @@ class BasePlanner(ABC):
|
||||
self.action_manager = action_manager
|
||||
|
||||
@abstractmethod
|
||||
async def plan(self, all_plan_info: List[InfoBase], running_memorys: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
async def plan(self, all_plan_info: List[InfoBase], running_memorys: List[Dict[str, Any]], loop_start_time: float) -> Dict[str, Any]:
|
||||
"""
|
||||
规划下一步行动
|
||||
|
||||
Args:
|
||||
all_plan_info: 所有计划信息
|
||||
running_memorys: 回忆信息
|
||||
|
||||
loop_start_time: 循环开始时间
|
||||
Returns:
|
||||
Dict[str, Any]: 规划结果
|
||||
"""
|
||||
|
||||
@@ -242,6 +242,8 @@ class ActionModifier:
|
||||
|
||||
for action_name, action_info in actions_with_info.items():
|
||||
activation_type = action_info.get("focus_activation_type", "always")
|
||||
|
||||
print(f"action_name: {action_name}, activation_type: {activation_type}")
|
||||
|
||||
# 现在统一是字符串格式的激活类型值
|
||||
if activation_type == "always":
|
||||
|
||||
@@ -32,11 +32,7 @@ def init_prompt():
|
||||
{self_info_block}
|
||||
请记住你的性格,身份和特点。
|
||||
|
||||
{extra_info_block}
|
||||
{memory_str}
|
||||
|
||||
{time_block}
|
||||
|
||||
你是群内的一员,你现在正在参与群内的闲聊,以下是群内的聊天内容:
|
||||
|
||||
{chat_content_block}
|
||||
@@ -86,13 +82,14 @@ class ActionPlanner(BasePlanner):
|
||||
request_type="focus.planner", # 用于动作规划
|
||||
)
|
||||
|
||||
async def plan(self, all_plan_info: List[InfoBase], running_memorys: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
async def plan(self, all_plan_info: List[InfoBase], running_memorys: List[Dict[str, Any]], loop_start_time: float) -> Dict[str, Any]:
|
||||
"""
|
||||
规划器 (Planner): 使用LLM根据上下文决定做出什么动作。
|
||||
|
||||
参数:
|
||||
all_plan_info: 所有计划信息
|
||||
running_memorys: 回忆信息
|
||||
loop_start_time: 循环开始时间
|
||||
"""
|
||||
|
||||
action = "no_reply" # 默认动作
|
||||
@@ -246,6 +243,8 @@ class ActionPlanner(BasePlanner):
|
||||
if selected_expressions:
|
||||
action_data["selected_expressions"] = selected_expressions
|
||||
logger.debug(f"{self.log_prefix} 传递{len(selected_expressions)}个选中的表达方式到action_data")
|
||||
|
||||
action_data["loop_start_time"] = loop_start_time
|
||||
|
||||
# 对于reply动作不需要额外处理,因为相关字段已经在上面的循环中添加到action_data
|
||||
|
||||
@@ -326,7 +325,7 @@ class ActionPlanner(BasePlanner):
|
||||
|
||||
chat_content_block = ""
|
||||
if observed_messages_str:
|
||||
chat_content_block = f"聊天记录:\n{observed_messages_str}"
|
||||
chat_content_block = f"\n{observed_messages_str}"
|
||||
else:
|
||||
chat_content_block = "你还未开始聊天"
|
||||
|
||||
@@ -387,7 +386,7 @@ class ActionPlanner(BasePlanner):
|
||||
prompt = planner_prompt_template.format(
|
||||
relation_info_block=relation_info_block,
|
||||
self_info_block=self_info_block,
|
||||
memory_str=memory_str,
|
||||
# memory_str=memory_str,
|
||||
time_block=time_block,
|
||||
# bot_name=global_config.bot.nickname,
|
||||
prompt_personality=personality_block,
|
||||
@@ -397,7 +396,7 @@ class ActionPlanner(BasePlanner):
|
||||
cycle_info_block=cycle_info,
|
||||
action_options_text=action_options_block,
|
||||
# action_available_block=action_available_block,
|
||||
extra_info_block=extra_info_block,
|
||||
# extra_info_block=extra_info_block,
|
||||
moderation_prompt=moderation_prompt_block,
|
||||
)
|
||||
return prompt
|
||||
|
||||
@@ -8,9 +8,7 @@ from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
from src.chat.utils.utils_image import image_path_to_base64 # Local import needed after move
|
||||
from src.chat.utils.timer_calculator import Timer # <--- Import Timer
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
from src.chat.focus_chat.heartFC_sender import HeartFCSender
|
||||
from src.chat.utils.utils import process_llm_response
|
||||
from src.chat.heart_flow.utils_chat import get_chat_type_and_target_info
|
||||
@@ -18,6 +16,7 @@ from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.chat.focus_chat.hfc_utils import parse_thinking_id_to_timestamp
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
|
||||
from src.chat.express.exprssion_learner import get_expression_learner
|
||||
import time
|
||||
import random
|
||||
from datetime import datetime
|
||||
@@ -50,7 +49,7 @@ def init_prompt():
|
||||
不要浮夸,不要夸张修辞,只输出一条回复就好。
|
||||
现在,你说:
|
||||
""",
|
||||
"default_replyer_prompt",
|
||||
"default_generator_prompt",
|
||||
)
|
||||
|
||||
Prompt(
|
||||
@@ -70,7 +69,51 @@ def init_prompt():
|
||||
不要浮夸,不要夸张修辞,只输出一条回复就好。
|
||||
现在,你说:
|
||||
""",
|
||||
"default_replyer_private_prompt",
|
||||
"default_generator_private_prompt",
|
||||
)
|
||||
|
||||
Prompt(
|
||||
"""
|
||||
你可以参考你的以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中:
|
||||
{style_habbits}
|
||||
|
||||
你现在正在群里聊天,以下是群里正在进行的聊天内容:
|
||||
{chat_info}
|
||||
|
||||
以上是聊天内容,你需要了解聊天记录中的内容
|
||||
|
||||
{chat_target}
|
||||
你的名字是{bot_name},{prompt_personality},在这聊天中,"{sender_name}"说的"{target_message}"引起了你的注意,对这句话,你想表达:{raw_reply},原因是:{reason}。你现在要思考怎么回复
|
||||
你需要使用合适的语法和句法,参考聊天内容,组织一条日常且口语化的回复。请你修改你想表达的原句,符合你的表达风格和语言习惯
|
||||
请你根据情景使用以下句法:
|
||||
{grammar_habbits}
|
||||
{config_expression_style},你可以完全重组回复,保留最基本的表达含义就好,但重组后保持语意通顺。
|
||||
不要浮夸,不要夸张修辞,平淡且不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 ),只输出一条回复就好。
|
||||
现在,你说:
|
||||
""",
|
||||
"default_expressor_prompt",
|
||||
)
|
||||
|
||||
Prompt(
|
||||
"""
|
||||
你可以参考以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中:
|
||||
{style_habbits}
|
||||
|
||||
你现在正在群里聊天,以下是群里正在进行的聊天内容:
|
||||
{chat_info}
|
||||
|
||||
以上是聊天内容,你需要了解聊天记录中的内容
|
||||
|
||||
{chat_target}
|
||||
你的名字是{bot_name},{prompt_personality},在这聊天中,"{sender_name}"说的"{target_message}"引起了你的注意,对这句话,你想表达:{raw_reply},原因是:{reason}。你现在要思考怎么回复
|
||||
你需要使用合适的语法和句法,参考聊天内容,组织一条日常且口语化的回复。
|
||||
请你根据情景使用以下句法:
|
||||
{grammar_habbits}
|
||||
{config_expression_style},你可以完全重组回复,保留最基本的表达含义就好,但重组后保持语意通顺。
|
||||
不要浮夸,不要夸张修辞,平淡且不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 ),只输出一条回复就好。
|
||||
现在,你说:
|
||||
""",
|
||||
"default_expressor_private_prompt", # New template for private FOCUSED chat
|
||||
)
|
||||
|
||||
|
||||
@@ -84,9 +127,8 @@ class DefaultReplyer:
|
||||
)
|
||||
self.heart_fc_sender = HeartFCSender()
|
||||
|
||||
self.chat_id = chat_stream.stream_id
|
||||
self.chat_stream = chat_stream
|
||||
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_id)
|
||||
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.stream_id)
|
||||
|
||||
async def _create_thinking_message(self, anchor_message: Optional[MessageRecv], thinking_id: str):
|
||||
"""创建思考消息 (尝试锚定到 anchor_message)"""
|
||||
@@ -114,213 +156,152 @@ class DefaultReplyer:
|
||||
|
||||
await self.heart_fc_sender.register_thinking(thinking_message)
|
||||
return None
|
||||
|
||||
async def deal_reply(
|
||||
|
||||
async def generate_reply_with_context(
|
||||
self,
|
||||
cycle_timers: dict,
|
||||
action_data: Dict[str, Any],
|
||||
reasoning: str,
|
||||
anchor_message: MessageRecv,
|
||||
thinking_id: str,
|
||||
) -> tuple[bool, Optional[List[Tuple[str, str]]]]:
|
||||
# 创建思考消息
|
||||
await self._create_thinking_message(anchor_message, thinking_id)
|
||||
|
||||
reply = [] # 初始化 reply,防止未定义
|
||||
try:
|
||||
has_sent_something = False
|
||||
|
||||
# 处理文本部分
|
||||
# text_part = action_data.get("text", [])
|
||||
# if text_part:
|
||||
sent_msg_list = []
|
||||
|
||||
with Timer("生成回复", cycle_timers):
|
||||
# 可以保留原有的文本处理逻辑或进行适当调整
|
||||
reply = await self.reply(
|
||||
# in_mind_reply=text_part,
|
||||
anchor_message=anchor_message,
|
||||
thinking_id=thinking_id,
|
||||
reason=reasoning,
|
||||
action_data=action_data,
|
||||
)
|
||||
|
||||
if reply:
|
||||
with Timer("发送消息", cycle_timers):
|
||||
sent_msg_list = await self.send_response_messages(
|
||||
anchor_message=anchor_message,
|
||||
thinking_id=thinking_id,
|
||||
response_set=reply,
|
||||
)
|
||||
has_sent_something = True
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} 文本回复生成失败")
|
||||
|
||||
if not has_sent_something:
|
||||
logger.warning(f"{self.log_prefix} 回复动作未包含任何有效内容")
|
||||
|
||||
return has_sent_something, sent_msg_list
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"回复失败: {e}")
|
||||
traceback.print_exc()
|
||||
return False, None
|
||||
|
||||
# --- 回复器 (Replier) 的定义 --- #
|
||||
|
||||
async def deal_emoji(
|
||||
self,
|
||||
anchor_message: MessageRecv,
|
||||
thinking_id: str,
|
||||
action_data: Dict[str, Any],
|
||||
cycle_timers: dict,
|
||||
) -> Optional[List[str]]:
|
||||
"""
|
||||
表情动作处理类
|
||||
"""
|
||||
|
||||
await self._create_thinking_message(anchor_message, thinking_id)
|
||||
|
||||
try:
|
||||
has_sent_something = False
|
||||
sent_msg_list = []
|
||||
reply = []
|
||||
with Timer("选择表情", cycle_timers):
|
||||
emoji_keyword = action_data.get("description", [])
|
||||
emoji_base64, _description, emotion = await self._choose_emoji(emoji_keyword)
|
||||
if emoji_base64:
|
||||
# logger.info(f"选择表情: {_description}")
|
||||
reply.append(("emoji", emoji_base64))
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} 没有找到合适表情")
|
||||
|
||||
if reply:
|
||||
with Timer("发送表情", cycle_timers):
|
||||
sent_msg_list = await self.send_response_messages(
|
||||
anchor_message=anchor_message,
|
||||
thinking_id=thinking_id,
|
||||
response_set=reply,
|
||||
)
|
||||
has_sent_something = True
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} 表情发送失败")
|
||||
|
||||
if not has_sent_something:
|
||||
logger.warning(f"{self.log_prefix} 表情发送失败")
|
||||
|
||||
return has_sent_something, sent_msg_list
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"回复失败: {e}")
|
||||
traceback.print_exc()
|
||||
return False, None
|
||||
|
||||
async def reply(
|
||||
self,
|
||||
# in_mind_reply: str,
|
||||
reason: str,
|
||||
anchor_message: MessageRecv,
|
||||
thinking_id: str,
|
||||
action_data: Dict[str, Any],
|
||||
) -> Optional[List[str]]:
|
||||
reply_data: Dict[str, Any],
|
||||
) -> Tuple[bool, Optional[List[str]]]:
|
||||
"""
|
||||
回复器 (Replier): 核心逻辑,负责生成回复文本。
|
||||
(已整合原 HeartFCGenerator 的功能)
|
||||
"""
|
||||
try:
|
||||
# 1. 获取情绪影响因子并调整模型温度
|
||||
# arousal_multiplier = mood_manager.get_arousal_multiplier()
|
||||
# current_temp = float(global_config.model.normal["temp"]) * arousal_multiplier
|
||||
# self.express_model.params["temperature"] = current_temp # 动态调整温度
|
||||
|
||||
reply_to = action_data.get("reply_to", "none")
|
||||
|
||||
sender = ""
|
||||
targer = ""
|
||||
if ":" in reply_to or ":" in reply_to:
|
||||
# 使用正则表达式匹配中文或英文冒号
|
||||
parts = re.split(pattern=r"[::]", string=reply_to, maxsplit=1)
|
||||
if len(parts) == 2:
|
||||
sender = parts[0].strip()
|
||||
targer = parts[1].strip()
|
||||
|
||||
identity = action_data.get("identity", "")
|
||||
extra_info_block = action_data.get("extra_info_block", "")
|
||||
relation_info_block = action_data.get("relation_info_block", "")
|
||||
|
||||
# 3. 构建 Prompt
|
||||
with Timer("构建Prompt", {}): # 内部计时器,可选保留
|
||||
prompt = await self.build_prompt_focus(
|
||||
chat_stream=self.chat_stream, # Pass the stream object
|
||||
# in_mind_reply=in_mind_reply,
|
||||
identity=identity,
|
||||
extra_info_block=extra_info_block,
|
||||
relation_info_block=relation_info_block,
|
||||
reason=reason,
|
||||
sender_name=sender, # Pass determined name
|
||||
target_message=targer,
|
||||
config_expression_style=global_config.expression.expression_style,
|
||||
action_data=action_data, # 传递action_data
|
||||
prompt = await self.build_prompt_reply_context(
|
||||
reply_data=reply_data, # 传递action_data
|
||||
)
|
||||
|
||||
# 4. 调用 LLM 生成回复
|
||||
content = None
|
||||
reasoning_content = None
|
||||
model_name = "unknown_model"
|
||||
if not prompt:
|
||||
logger.error(f"{self.log_prefix}[Replier-{thinking_id}] Prompt 构建失败,无法生成回复。")
|
||||
return None
|
||||
|
||||
try:
|
||||
with Timer("LLM生成", {}): # 内部计时器,可选保留
|
||||
logger.info(f"{self.log_prefix}Prompt:\n{prompt}\n")
|
||||
content, (reasoning_content, model_name) = await self.express_model.generate_response_async(prompt)
|
||||
|
||||
# logger.info(f"prompt: {prompt}")
|
||||
logger.info(f"最终回复: {content}")
|
||||
|
||||
except Exception as llm_e:
|
||||
# 精简报错信息
|
||||
logger.error(f"{self.log_prefix}LLM 生成失败: {llm_e}")
|
||||
return None # LLM 调用失败则无法生成回复
|
||||
return False, None # LLM 调用失败则无法生成回复
|
||||
|
||||
processed_response = process_llm_response(content)
|
||||
|
||||
# 5. 处理 LLM 响应
|
||||
if not content:
|
||||
logger.warning(f"{self.log_prefix}LLM 生成了空内容。")
|
||||
return None
|
||||
return False, None
|
||||
if not processed_response:
|
||||
logger.warning(f"{self.log_prefix}处理后的回复为空。")
|
||||
return None
|
||||
return False, None
|
||||
|
||||
reply_set = []
|
||||
for str in processed_response:
|
||||
reply_seg = ("text", str)
|
||||
reply_set.append(reply_seg)
|
||||
|
||||
return reply_set
|
||||
return True , reply_set
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}回复生成意外失败: {e}")
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
async def build_prompt_focus(
|
||||
return False, None
|
||||
|
||||
async def rewrite_reply_with_context(
|
||||
self,
|
||||
reason,
|
||||
chat_stream,
|
||||
sender_name,
|
||||
# in_mind_reply,
|
||||
extra_info_block,
|
||||
relation_info_block,
|
||||
identity,
|
||||
target_message,
|
||||
config_expression_style,
|
||||
action_data=None,
|
||||
# stuation,
|
||||
reply_data: Dict[str, Any],
|
||||
) -> Tuple[bool, Optional[List[str]]]:
|
||||
"""
|
||||
表达器 (Expressor): 核心逻辑,负责生成回复文本。
|
||||
"""
|
||||
try:
|
||||
|
||||
|
||||
reply_to = reply_data.get("reply_to", "")
|
||||
raw_reply = reply_data.get("raw_reply", "")
|
||||
reason = reply_data.get("reason", "")
|
||||
|
||||
with Timer("构建Prompt", {}): # 内部计时器,可选保留
|
||||
prompt = await self.build_prompt_rewrite_context(
|
||||
raw_reply=raw_reply,
|
||||
reason=reason,
|
||||
reply_to=reply_to,
|
||||
)
|
||||
|
||||
content = None
|
||||
reasoning_content = None
|
||||
model_name = "unknown_model"
|
||||
if not prompt:
|
||||
logger.error(f"{self.log_prefix}Prompt 构建失败,无法生成回复。")
|
||||
return False, None
|
||||
|
||||
try:
|
||||
with Timer("LLM生成", {}): # 内部计时器,可选保留
|
||||
# TODO: API-Adapter修改标记
|
||||
content, (reasoning_content, model_name) = await self.express_model.generate_response_async(prompt)
|
||||
|
||||
logger.info(f"想要表达:{raw_reply}||理由:{reason}")
|
||||
logger.info(f"最终回复: {content}\n")
|
||||
|
||||
except Exception as llm_e:
|
||||
# 精简报错信息
|
||||
logger.error(f"{self.log_prefix}LLM 生成失败: {llm_e}")
|
||||
return False, None # LLM 调用失败则无法生成回复
|
||||
|
||||
processed_response = process_llm_response(content)
|
||||
|
||||
# 5. 处理 LLM 响应
|
||||
if not content:
|
||||
logger.warning(f"{self.log_prefix}LLM 生成了空内容。")
|
||||
return False, None
|
||||
if not processed_response:
|
||||
logger.warning(f"{self.log_prefix}处理后的回复为空。")
|
||||
return False, None
|
||||
|
||||
reply_set = []
|
||||
for str in processed_response:
|
||||
reply_seg = ("text", str)
|
||||
reply_set.append(reply_seg)
|
||||
|
||||
return True, reply_set
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}回复生成意外失败: {e}")
|
||||
traceback.print_exc()
|
||||
return False, None
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
async def build_prompt_reply_context(
|
||||
self,
|
||||
reply_data=None,
|
||||
) -> str:
|
||||
chat_stream = self.chat_stream
|
||||
|
||||
is_group_chat = bool(chat_stream.group_info)
|
||||
|
||||
identity = reply_data.get("identity", "")
|
||||
extra_info_block = reply_data.get("extra_info_block", "")
|
||||
relation_info_block = reply_data.get("relation_info_block", "")
|
||||
reply_to = reply_data.get("reply_to", "none")
|
||||
|
||||
sender = ""
|
||||
target = ""
|
||||
if ":" in reply_to or ":" in reply_to:
|
||||
# 使用正则表达式匹配中文或英文冒号
|
||||
parts = re.split(pattern=r"[::]", string=reply_to, maxsplit=1)
|
||||
if len(parts) == 2:
|
||||
sender = parts[0].strip()
|
||||
target = parts[1].strip()
|
||||
|
||||
|
||||
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_stream.stream_id,
|
||||
@@ -341,7 +322,7 @@ class DefaultReplyer:
|
||||
grammar_habbits = []
|
||||
|
||||
# 使用从处理器传来的选中表达方式
|
||||
selected_expressions = action_data.get("selected_expressions", []) if action_data else []
|
||||
selected_expressions = reply_data.get("selected_expressions", []) if reply_data else []
|
||||
|
||||
if selected_expressions:
|
||||
logger.info(f"{self.log_prefix} 使用处理器选中的{len(selected_expressions)}个表达方式")
|
||||
@@ -371,7 +352,7 @@ class DefaultReplyer:
|
||||
try:
|
||||
# 处理关键词规则
|
||||
for rule in global_config.keyword_reaction.keyword_rules:
|
||||
if any(keyword in target_message for keyword in rule.keywords):
|
||||
if any(keyword in target for keyword in rule.keywords):
|
||||
logger.info(f"检测到关键词规则:{rule.keywords},触发反应:{rule.reaction}")
|
||||
keywords_reaction_prompt += f"{rule.reaction},"
|
||||
|
||||
@@ -380,7 +361,7 @@ class DefaultReplyer:
|
||||
for pattern_str in rule.regex:
|
||||
try:
|
||||
pattern = re.compile(pattern_str)
|
||||
if result := pattern.search(target_message):
|
||||
if result := pattern.search(target):
|
||||
reaction = rule.reaction
|
||||
for name, content in result.groupdict().items():
|
||||
reaction = reaction.replace(f"[{name}]", content)
|
||||
@@ -397,18 +378,18 @@ class DefaultReplyer:
|
||||
|
||||
# logger.debug("开始构建 focus prompt")
|
||||
|
||||
if sender_name:
|
||||
if sender:
|
||||
reply_target_block = (
|
||||
f"现在{sender_name}说的:{target_message}。引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
f"现在{sender}说的:{target}。引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
)
|
||||
elif target_message:
|
||||
reply_target_block = f"现在{target_message}引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
elif target:
|
||||
reply_target_block = f"现在{target}引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
else:
|
||||
reply_target_block = "现在,你想要在群里发言或者回复消息。"
|
||||
|
||||
# --- Choose template based on chat type ---
|
||||
if is_group_chat:
|
||||
template_name = "default_replyer_prompt"
|
||||
template_name = "default_generator_prompt"
|
||||
# Group specific formatting variables (already fetched or default)
|
||||
chat_target_1 = await global_prompt_manager.get_prompt_async("chat_target_group1")
|
||||
# chat_target_2 = await global_prompt_manager.get_prompt_async("chat_target_group2")
|
||||
@@ -422,18 +403,14 @@ class DefaultReplyer:
|
||||
relation_info_block=relation_info_block,
|
||||
time_block=time_block,
|
||||
reply_target_block=reply_target_block,
|
||||
# bot_name=global_config.bot.nickname,
|
||||
# prompt_personality="",
|
||||
# reason=reason,
|
||||
# in_mind_reply=in_mind_reply,
|
||||
keywords_reaction_prompt=keywords_reaction_prompt,
|
||||
identity=identity,
|
||||
target_message=target_message,
|
||||
sender_name=sender_name,
|
||||
config_expression_style=config_expression_style,
|
||||
target_message=target,
|
||||
sender_name=sender,
|
||||
config_expression_style=global_config.expression.expression_style,
|
||||
)
|
||||
else: # Private chat
|
||||
template_name = "default_replyer_private_prompt"
|
||||
template_name = "default_generator_private_prompt"
|
||||
chat_target_1 = "你正在和人私聊"
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
template_name,
|
||||
@@ -444,20 +421,125 @@ class DefaultReplyer:
|
||||
relation_info_block=relation_info_block,
|
||||
time_block=time_block,
|
||||
reply_target_block=reply_target_block,
|
||||
# bot_name=global_config.bot.nickname,
|
||||
# prompt_personality="",
|
||||
# reason=reason,
|
||||
# in_mind_reply=in_mind_reply,
|
||||
keywords_reaction_prompt=keywords_reaction_prompt,
|
||||
identity=identity,
|
||||
target_message=target_message,
|
||||
sender_name=sender_name,
|
||||
config_expression_style=config_expression_style,
|
||||
target_message=target,
|
||||
sender_name=sender,
|
||||
config_expression_style=global_config.expression.expression_style,
|
||||
)
|
||||
|
||||
return prompt
|
||||
|
||||
# --- 发送器 (Sender) --- #
|
||||
async def build_prompt_rewrite_context(
|
||||
self,
|
||||
reason,
|
||||
raw_reply,
|
||||
reply_to,
|
||||
) -> str:
|
||||
|
||||
|
||||
|
||||
sender = ""
|
||||
target = ""
|
||||
if ":" in reply_to or ":" in reply_to:
|
||||
# 使用正则表达式匹配中文或英文冒号
|
||||
parts = re.split(pattern=r"[::]", string=reply_to, maxsplit=1)
|
||||
if len(parts) == 2:
|
||||
sender = parts[0].strip()
|
||||
target = parts[1].strip()
|
||||
|
||||
chat_stream = self.chat_stream
|
||||
|
||||
is_group_chat = bool(chat_stream.group_info)
|
||||
|
||||
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_stream.stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=global_config.focus_chat.observation_context_size,
|
||||
)
|
||||
chat_talking_prompt = build_readable_messages(
|
||||
message_list_before_now,
|
||||
replace_bot_name=True,
|
||||
merge_messages=True,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
truncate=True,
|
||||
)
|
||||
|
||||
expression_learner = get_expression_learner()
|
||||
(
|
||||
learnt_style_expressions,
|
||||
learnt_grammar_expressions,
|
||||
personality_expressions,
|
||||
) = await expression_learner.get_expression_by_chat_id(chat_stream.stream_id)
|
||||
|
||||
style_habbits = []
|
||||
grammar_habbits = []
|
||||
# 1. learnt_expressions加权随机选3条
|
||||
if learnt_style_expressions:
|
||||
weights = [expr["count"] for expr in learnt_style_expressions]
|
||||
selected_learnt = weighted_sample_no_replacement(learnt_style_expressions, weights, 3)
|
||||
for expr in selected_learnt:
|
||||
if isinstance(expr, dict) and "situation" in expr and "style" in expr:
|
||||
style_habbits.append(f"当{expr['situation']}时,使用 {expr['style']}")
|
||||
# 2. learnt_grammar_expressions加权随机选3条
|
||||
if learnt_grammar_expressions:
|
||||
weights = [expr["count"] for expr in learnt_grammar_expressions]
|
||||
selected_learnt = weighted_sample_no_replacement(learnt_grammar_expressions, weights, 3)
|
||||
for expr in selected_learnt:
|
||||
if isinstance(expr, dict) and "situation" in expr and "style" in expr:
|
||||
grammar_habbits.append(f"当{expr['situation']}时,使用 {expr['style']}")
|
||||
# 3. personality_expressions随机选1条
|
||||
if personality_expressions:
|
||||
expr = random.choice(personality_expressions)
|
||||
if isinstance(expr, dict) and "situation" in expr and "style" in expr:
|
||||
style_habbits.append(f"当{expr['situation']}时,使用 {expr['style']}")
|
||||
|
||||
style_habbits_str = "\n".join(style_habbits)
|
||||
grammar_habbits_str = "\n".join(grammar_habbits)
|
||||
|
||||
logger.debug("开始构建 focus prompt")
|
||||
|
||||
# --- Choose template based on chat type ---
|
||||
if is_group_chat:
|
||||
template_name = "default_expressor_prompt"
|
||||
# Group specific formatting variables (already fetched or default)
|
||||
chat_target_1 = await global_prompt_manager.get_prompt_async("chat_target_group1")
|
||||
# chat_target_2 = await global_prompt_manager.get_prompt_async("chat_target_group2")
|
||||
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
template_name,
|
||||
style_habbits=style_habbits_str,
|
||||
grammar_habbits=grammar_habbits_str,
|
||||
chat_target=chat_target_1,
|
||||
chat_info=chat_talking_prompt,
|
||||
bot_name=global_config.bot.nickname,
|
||||
prompt_personality="",
|
||||
reason=reason,
|
||||
raw_reply=raw_reply,
|
||||
sender_name=sender,
|
||||
target_message=target,
|
||||
config_expression_style=global_config.expression.expression_style,
|
||||
)
|
||||
else: # Private chat
|
||||
template_name = "default_expressor_private_prompt"
|
||||
chat_target_1 = "你正在和人私聊"
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
template_name,
|
||||
style_habbits=style_habbits_str,
|
||||
grammar_habbits=grammar_habbits_str,
|
||||
chat_target=chat_target_1,
|
||||
chat_info=chat_talking_prompt,
|
||||
bot_name=global_config.bot.nickname,
|
||||
prompt_personality="",
|
||||
reason=reason,
|
||||
raw_reply=raw_reply,
|
||||
sender_name=sender,
|
||||
target_message=target,
|
||||
config_expression_style=global_config.expression.expression_style,
|
||||
)
|
||||
|
||||
return prompt
|
||||
|
||||
async def send_response_messages(
|
||||
self,
|
||||
@@ -468,7 +550,7 @@ class DefaultReplyer:
|
||||
) -> Optional[MessageSending]:
|
||||
"""发送回复消息 (尝试锚定到 anchor_message),使用 HeartFCSender"""
|
||||
chat = self.chat_stream
|
||||
chat_id = self.chat_id
|
||||
chat_id = self.chat_stream.stream_id
|
||||
if chat is None:
|
||||
logger.error(f"{self.log_prefix} 无法发送回复,chat_stream 为空。")
|
||||
return None
|
||||
@@ -514,7 +596,7 @@ class DefaultReplyer:
|
||||
is_emoji = False
|
||||
reply_to = not mark_head
|
||||
|
||||
bot_message = await self._build_single_sending_message(
|
||||
bot_message: MessageSending = await self._build_single_sending_message(
|
||||
anchor_message=anchor_message,
|
||||
message_id=part_message_id,
|
||||
message_segment=message_segment,
|
||||
@@ -526,22 +608,22 @@ class DefaultReplyer:
|
||||
)
|
||||
|
||||
try:
|
||||
if (bot_message.is_private_message() or
|
||||
bot_message.reply.processed_plain_text != "[System Trigger Context]" or
|
||||
mark_head):
|
||||
set_reply = False
|
||||
else:
|
||||
set_reply = True
|
||||
|
||||
if not mark_head:
|
||||
mark_head = True
|
||||
# first_bot_msg = bot_message # 保存第一个成功发送的消息对象
|
||||
typing = False
|
||||
else:
|
||||
typing = True
|
||||
|
||||
if type == "emoji":
|
||||
typing = False
|
||||
|
||||
if anchor_message.raw_message:
|
||||
set_reply = True
|
||||
else:
|
||||
set_reply = False
|
||||
|
||||
|
||||
sent_msg = await self.heart_fc_sender.send_message(
|
||||
bot_message, has_thinking=True, typing=typing, set_reply=set_reply
|
||||
bot_message, typing=typing, set_reply=set_reply
|
||||
)
|
||||
|
||||
reply_message_ids.append(part_message_id) # 记录我们生成的ID
|
||||
@@ -562,30 +644,15 @@ class DefaultReplyer:
|
||||
|
||||
return sent_msg_list
|
||||
|
||||
async def _choose_emoji(self, send_emoji: str):
|
||||
"""
|
||||
选择表情,根据send_emoji文本选择表情,返回表情base64
|
||||
"""
|
||||
emoji_base64 = ""
|
||||
description = ""
|
||||
emoji_raw = await get_emoji_manager().get_emoji_for_text(send_emoji)
|
||||
if emoji_raw:
|
||||
emoji_path, description, _emotion = emoji_raw
|
||||
emoji_base64 = image_path_to_base64(emoji_path)
|
||||
return emoji_base64, description, _emotion
|
||||
else:
|
||||
return None, None, None
|
||||
|
||||
async def _build_single_sending_message(
|
||||
self,
|
||||
anchor_message: MessageRecv,
|
||||
message_id: str,
|
||||
message_segment: Seg,
|
||||
reply_to: bool,
|
||||
is_emoji: bool,
|
||||
thinking_id: str,
|
||||
thinking_start_time: float,
|
||||
display_message: str,
|
||||
anchor_message: MessageRecv = None
|
||||
) -> MessageSending:
|
||||
"""构建单个发送消息"""
|
||||
|
||||
@@ -596,12 +663,16 @@ class DefaultReplyer:
|
||||
)
|
||||
|
||||
# await anchor_message.process()
|
||||
if anchor_message:
|
||||
sender_info = anchor_message.message_info.user_info
|
||||
else:
|
||||
sender_info = None
|
||||
|
||||
bot_message = MessageSending(
|
||||
message_id=message_id, # 使用片段的唯一ID
|
||||
chat_stream=self.chat_stream,
|
||||
bot_user_info=bot_user_info,
|
||||
sender_info=anchor_message.message_info.user_info,
|
||||
sender_info=sender_info,
|
||||
message_segment=message_segment,
|
||||
reply=anchor_message, # 回复原始锚点
|
||||
is_head=reply_to,
|
||||
@@ -14,7 +14,8 @@ from src.chat.message_receive.message import MessageRecv
|
||||
from src.chat.heart_flow.observation.observation import Observation
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.heart_flow.utils_chat import get_chat_type_and_target_info
|
||||
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
logger = get_logger("observation")
|
||||
|
||||
# 定义提示模板
|
||||
@@ -70,6 +71,8 @@ class ChattingObservation(Observation):
|
||||
self.oldest_messages = []
|
||||
self.oldest_messages_str = ""
|
||||
|
||||
self.last_observe_time = datetime.now().timestamp() -1
|
||||
print(f"last_observe_time: {self.last_observe_time}")
|
||||
initial_messages = get_raw_msg_before_timestamp_with_chat(self.chat_id, self.last_observe_time, 10)
|
||||
self.last_observe_time = initial_messages[-1]["time"] if initial_messages else self.last_observe_time
|
||||
self.talking_message = initial_messages
|
||||
@@ -92,39 +95,28 @@ class ChattingObservation(Observation):
|
||||
def get_observe_info(self, ids=None):
|
||||
return self.talking_message_str
|
||||
|
||||
def search_message_by_text(self, text: str) -> Optional[MessageRecv]:
|
||||
def get_recv_message_by_text(self, sender: str, text: str) -> Optional[MessageRecv]:
|
||||
"""
|
||||
根据回复的纯文本
|
||||
1. 在talking_message中查找最新的,最匹配的消息
|
||||
2. 如果找到,则返回消息
|
||||
"""
|
||||
msg_list = []
|
||||
find_msg = None
|
||||
reverse_talking_message = list(reversed(self.talking_message))
|
||||
|
||||
for message in reverse_talking_message:
|
||||
if message["processed_plain_text"] == text:
|
||||
find_msg = message
|
||||
break
|
||||
else:
|
||||
raw_message = message.get("raw_message")
|
||||
if raw_message:
|
||||
similarity = difflib.SequenceMatcher(None, text, raw_message).ratio()
|
||||
else:
|
||||
similarity = difflib.SequenceMatcher(None, text, message.get("processed_plain_text", "")).ratio()
|
||||
msg_list.append({"message": message, "similarity": similarity})
|
||||
user_id = message["user_id"]
|
||||
platform = message["platform"]
|
||||
person_id = get_person_info_manager().get_person_id(platform, user_id)
|
||||
person_name = get_person_info_manager().get_value(person_id, "person_name")
|
||||
if person_name == sender:
|
||||
similarity = difflib.SequenceMatcher(None, text, message["processed_plain_text"]).ratio()
|
||||
if similarity >= 0.9:
|
||||
find_msg = message
|
||||
break
|
||||
|
||||
if not find_msg:
|
||||
if msg_list:
|
||||
msg_list.sort(key=lambda x: x["similarity"], reverse=True)
|
||||
if msg_list[0]["similarity"] >= 0.9:
|
||||
find_msg = msg_list[0]["message"]
|
||||
else:
|
||||
logger.debug("没有找到锚定消息,相似度低")
|
||||
return None
|
||||
else:
|
||||
logger.debug("没有找到锚定消息,没有消息捕获")
|
||||
return None
|
||||
return None
|
||||
|
||||
user_info = {
|
||||
"platform": find_msg.get("user_platform", ""),
|
||||
@@ -167,6 +159,10 @@ class ChattingObservation(Observation):
|
||||
"processed_plain_text": find_msg.get("processed_plain_text"),
|
||||
}
|
||||
find_rec_msg = MessageRecv(message_dict)
|
||||
|
||||
find_rec_msg.update_chat_stream(get_chat_manager().get_or_create_stream(self.chat_id))
|
||||
|
||||
|
||||
return find_rec_msg
|
||||
|
||||
async def observe(self):
|
||||
@@ -179,6 +175,8 @@ class ChattingObservation(Observation):
|
||||
limit_mode="latest",
|
||||
)
|
||||
|
||||
print(f"new_messages_list: {new_messages_list}")
|
||||
|
||||
last_obs_time_mark = self.last_observe_time
|
||||
if new_messages_list:
|
||||
self.last_observe_time = new_messages_list[-1]["time"]
|
||||
|
||||
@@ -171,6 +171,15 @@ class ChatManager:
|
||||
# 使用MD5生成唯一ID
|
||||
key = "_".join(components)
|
||||
return hashlib.md5(key.encode()).hexdigest()
|
||||
|
||||
def get_stream_id(self, platform: str, chat_id: str, is_group: bool = True) -> str:
|
||||
"""获取聊天流ID"""
|
||||
if is_group:
|
||||
components = [platform, str(chat_id)]
|
||||
else:
|
||||
components = [platform, str(chat_id), "private"]
|
||||
key = "_".join(components)
|
||||
return hashlib.md5(key.encode()).hexdigest()
|
||||
|
||||
async def get_or_create_stream(
|
||||
self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
|
||||
|
||||
@@ -275,7 +275,7 @@ class MessageSending(MessageProcessBase):
|
||||
message_id: str,
|
||||
chat_stream: "ChatStream",
|
||||
bot_user_info: UserInfo,
|
||||
sender_info: UserInfo | None, # 用来记录发送者信息,用于私聊回复
|
||||
sender_info: UserInfo | None, # 用来记录发送者信息
|
||||
message_segment: Seg,
|
||||
display_message: str = "",
|
||||
reply: Optional["MessageRecv"] = None,
|
||||
@@ -304,20 +304,17 @@ class MessageSending(MessageProcessBase):
|
||||
# 用于显示发送内容与显示不一致的情况
|
||||
self.display_message = display_message
|
||||
|
||||
def set_reply(self, reply: Optional["MessageRecv"] = None):
|
||||
def build_reply(self):
|
||||
"""设置回复消息"""
|
||||
if True:
|
||||
if reply:
|
||||
self.reply = reply
|
||||
if self.reply:
|
||||
self.reply_to_message_id = self.reply.message_info.message_id
|
||||
self.message_segment = Seg(
|
||||
type="seglist",
|
||||
data=[
|
||||
Seg(type="reply", data=self.reply.message_info.message_id),
|
||||
self.message_segment,
|
||||
],
|
||||
)
|
||||
if self.reply:
|
||||
self.reply_to_message_id = self.reply.message_info.message_id
|
||||
self.message_segment = Seg(
|
||||
type="seglist",
|
||||
data=[
|
||||
Seg(type="reply", data=self.reply.message_info.message_id),
|
||||
self.message_segment,
|
||||
],
|
||||
)
|
||||
|
||||
async def process(self) -> None:
|
||||
"""处理消息内容,生成纯文本和详细文本"""
|
||||
|
||||
@@ -230,7 +230,7 @@ class MessageManager:
|
||||
logger.debug(
|
||||
f"[{message.chat_stream.stream_id}] 应用 set_reply 逻辑: {message.processed_plain_text[:20]}..."
|
||||
)
|
||||
message.set_reply(message.reply)
|
||||
message.build_reply()
|
||||
# --- 结束条件 set_reply ---
|
||||
|
||||
await message.process() # 预处理消息内容
|
||||
|
||||
@@ -22,7 +22,7 @@ from src.chat.focus_chat.planners.action_manager import ActionManager
|
||||
from src.chat.normal_chat.normal_chat_planner import NormalChatPlanner
|
||||
from src.chat.normal_chat.normal_chat_action_modifier import NormalChatActionModifier
|
||||
from src.chat.normal_chat.normal_chat_expressor import NormalChatExpressor
|
||||
from src.chat.focus_chat.replyer.default_replyer import DefaultReplyer
|
||||
from src.chat.focus_chat.replyer.default_generator import DefaultReplyer
|
||||
from src.person_info.person_info import PersonInfoManager
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
get_raw_msg_by_timestamp_with_chat,
|
||||
@@ -1063,9 +1063,6 @@ class NormalChat:
|
||||
reasoning=action_data.get("reasoning", ""),
|
||||
cycle_timers={}, # normal_chat使用空的cycle_timers
|
||||
thinking_id=thinking_id,
|
||||
observations=[], # normal_chat不使用observations
|
||||
expressor=self.expressor, # 使用normal_chat专用的expressor
|
||||
replyer=self.replyer,
|
||||
chat_stream=self.chat_stream,
|
||||
log_prefix=self.stream_name,
|
||||
shutting_down=self._disabled,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from src.chat.focus_chat.expressors.exprssion_learner import get_expression_learner
|
||||
from src.chat.express.exprssion_learner import get_expression_learner
|
||||
from src.config.config import global_config
|
||||
from src.common.logger import get_logger
|
||||
from src.individuality.individuality import get_individuality
|
||||
|
||||
@@ -2,7 +2,7 @@ import asyncio
|
||||
import time
|
||||
from maim_message import MessageServer
|
||||
|
||||
from src.chat.focus_chat.expressors.exprssion_learner import get_expression_learner
|
||||
from src.chat.express.exprssion_learner import get_expression_learner
|
||||
from src.common.remote import TelemetryHeartBeatTask
|
||||
from src.manager.async_task_manager import async_task_manager
|
||||
from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
|
||||
|
||||
@@ -8,6 +8,7 @@ MaiBot 插件系统
|
||||
from src.plugin_system.base.base_plugin import BasePlugin, register_plugin
|
||||
from src.plugin_system.base.base_action import BaseAction
|
||||
from src.plugin_system.base.base_command import BaseCommand
|
||||
from src.plugin_system.base.config_types import ConfigField
|
||||
from src.plugin_system.base.component_types import (
|
||||
ComponentType,
|
||||
ActionActivationType,
|
||||
@@ -18,11 +19,11 @@ from src.plugin_system.base.component_types import (
|
||||
PluginInfo,
|
||||
PythonDependency,
|
||||
)
|
||||
from src.plugin_system.apis.plugin_api import PluginAPI, create_plugin_api, create_command_api
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
from src.plugin_system.core.dependency_manager import dependency_manager
|
||||
|
||||
|
||||
__version__ = "1.0.0"
|
||||
|
||||
__all__ = [
|
||||
@@ -39,14 +40,11 @@ __all__ = [
|
||||
"CommandInfo",
|
||||
"PluginInfo",
|
||||
"PythonDependency",
|
||||
# API接口
|
||||
"PluginAPI",
|
||||
"create_plugin_api",
|
||||
"create_command_api",
|
||||
# 管理器
|
||||
"plugin_manager",
|
||||
"component_registry",
|
||||
"dependency_manager",
|
||||
# 装饰器
|
||||
"register_plugin",
|
||||
"ConfigField",
|
||||
]
|
||||
|
||||
@@ -1,37 +1,33 @@
|
||||
"""
|
||||
插件API模块
|
||||
插件系统API模块
|
||||
|
||||
提供插件可以使用的各种API接口
|
||||
提供了插件开发所需的各种API
|
||||
"""
|
||||
|
||||
from src.plugin_system.apis.plugin_api import PluginAPI, create_plugin_api, create_command_api
|
||||
from src.plugin_system.apis.message_api import MessageAPI
|
||||
from src.plugin_system.apis.llm_api import LLMAPI
|
||||
from src.plugin_system.apis.database_api import DatabaseAPI
|
||||
from src.plugin_system.apis.config_api import ConfigAPI
|
||||
from src.plugin_system.apis.utils_api import UtilsAPI
|
||||
from src.plugin_system.apis.stream_api import StreamAPI
|
||||
from src.plugin_system.apis.hearflow_api import HearflowAPI
|
||||
|
||||
# 新增:分类的API聚合
|
||||
from src.plugin_system.apis.action_apis import ActionAPI
|
||||
from src.plugin_system.apis.independent_apis import IndependentAPI, StaticAPI
|
||||
# 导入所有API模块
|
||||
from src.plugin_system.apis import (
|
||||
chat_api,
|
||||
config_api,
|
||||
database_api,
|
||||
emoji_api,
|
||||
generator_api,
|
||||
llm_api,
|
||||
message_api,
|
||||
person_api,
|
||||
send_api,
|
||||
utils_api
|
||||
)
|
||||
|
||||
# 导出所有API模块,使它们可以通过 apis.xxx 方式访问
|
||||
__all__ = [
|
||||
# 原有统一API
|
||||
"PluginAPI",
|
||||
"create_plugin_api",
|
||||
"create_command_api",
|
||||
# 原有单独API
|
||||
"MessageAPI",
|
||||
"LLMAPI",
|
||||
"DatabaseAPI",
|
||||
"ConfigAPI",
|
||||
"UtilsAPI",
|
||||
"StreamAPI",
|
||||
"HearflowAPI",
|
||||
# 新增分类API
|
||||
"ActionAPI", # 需要Action依赖的API
|
||||
"IndependentAPI", # 独立API
|
||||
"StaticAPI", # 静态API
|
||||
"chat_api",
|
||||
"config_api",
|
||||
"database_api",
|
||||
"emoji_api",
|
||||
"generator_api",
|
||||
"llm_api",
|
||||
"message_api",
|
||||
"person_api",
|
||||
"send_api",
|
||||
"utils_api"
|
||||
]
|
||||
|
||||
@@ -1,88 +0,0 @@
|
||||
"""
|
||||
Action相关API聚合模块
|
||||
|
||||
聚合了需要Action组件依赖的API,这些API需要通过Action初始化时注入的服务对象才能正常工作。
|
||||
包括:MessageAPI、DatabaseAPI等需要chat_stream、expressor等服务的API。
|
||||
"""
|
||||
|
||||
from src.plugin_system.apis.message_api import MessageAPI
|
||||
from src.plugin_system.apis.database_api import DatabaseAPI
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("action_apis")
|
||||
|
||||
|
||||
class ActionAPI(MessageAPI, DatabaseAPI):
|
||||
"""
|
||||
Action相关API聚合类
|
||||
|
||||
聚合了需要Action组件依赖的API功能。这些API需要以下依赖:
|
||||
- _services: 包含chat_stream、expressor、replyer、observations等服务对象
|
||||
- log_prefix: 日志前缀
|
||||
- thinking_id: 思考ID
|
||||
- cycle_timers: 计时器
|
||||
- action_data: Action数据
|
||||
|
||||
使用场景:
|
||||
- 在Action组件中使用,需要发送消息、存储数据等功能
|
||||
- 需要访问聊天上下文和执行环境的操作
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chat_stream=None,
|
||||
expressor=None,
|
||||
replyer=None,
|
||||
observations=None,
|
||||
log_prefix: str = "[ActionAPI]",
|
||||
thinking_id: str = "",
|
||||
cycle_timers: dict = None,
|
||||
action_data: dict = None,
|
||||
):
|
||||
"""
|
||||
初始化Action相关API
|
||||
|
||||
Args:
|
||||
chat_stream: 聊天流对象
|
||||
expressor: 表达器对象
|
||||
replyer: 回复器对象
|
||||
observations: 观察列表
|
||||
log_prefix: 日志前缀
|
||||
thinking_id: 思考ID
|
||||
cycle_timers: 计时器字典
|
||||
action_data: Action数据
|
||||
"""
|
||||
# 存储依赖对象
|
||||
self._services = {
|
||||
"chat_stream": chat_stream,
|
||||
"expressor": expressor,
|
||||
"replyer": replyer,
|
||||
"observations": observations or [],
|
||||
}
|
||||
|
||||
self.log_prefix = log_prefix
|
||||
self.thinking_id = thinking_id
|
||||
self.cycle_timers = cycle_timers or {}
|
||||
self.action_data = action_data or {}
|
||||
|
||||
logger.debug(f"{self.log_prefix} ActionAPI 初始化完成")
|
||||
|
||||
def set_chat_stream(self, chat_stream):
|
||||
"""设置聊天流对象"""
|
||||
self._services["chat_stream"] = chat_stream
|
||||
logger.debug(f"{self.log_prefix} 设置聊天流")
|
||||
|
||||
def set_expressor(self, expressor):
|
||||
"""设置表达器对象"""
|
||||
self._services["expressor"] = expressor
|
||||
logger.debug(f"{self.log_prefix} 设置表达器")
|
||||
|
||||
def set_replyer(self, replyer):
|
||||
"""设置回复器对象"""
|
||||
self._services["replyer"] = replyer
|
||||
logger.debug(f"{self.log_prefix} 设置回复器")
|
||||
|
||||
def set_observations(self, observations):
|
||||
"""设置观察列表"""
|
||||
self._services["observations"] = observations or []
|
||||
logger.debug(f"{self.log_prefix} 设置观察列表")
|
||||
292
src/plugin_system/apis/chat_api.py
Normal file
292
src/plugin_system/apis/chat_api.py
Normal file
@@ -0,0 +1,292 @@
|
||||
"""
|
||||
聊天API模块
|
||||
|
||||
专门负责聊天信息的查询和管理,采用标准Python包设计模式
|
||||
使用方式:
|
||||
from src.plugin_system.apis import chat_api
|
||||
streams = chat_api.get_all_group_streams()
|
||||
chat_type = chat_api.get_stream_type(stream)
|
||||
|
||||
或者:
|
||||
from src.plugin_system.apis.chat_api import ChatManager as chat
|
||||
streams = chat.get_all_group_streams()
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
from src.common.logger import get_logger
|
||||
|
||||
# 导入依赖
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.chat.focus_chat.info.obs_info import ObsInfo
|
||||
|
||||
logger = get_logger("chat_api")
|
||||
|
||||
|
||||
class ChatManager:
|
||||
"""聊天管理器 - 专门负责聊天信息的查询和管理"""
|
||||
|
||||
@staticmethod
|
||||
def get_all_streams(platform: str = "qq") -> List[ChatStream]:
|
||||
"""获取所有聊天流
|
||||
|
||||
Args:
|
||||
platform: 平台筛选,默认为"qq"
|
||||
|
||||
Returns:
|
||||
List[ChatStream]: 聊天流列表
|
||||
"""
|
||||
streams = []
|
||||
try:
|
||||
for _, stream in get_chat_manager().streams.items():
|
||||
if stream.platform == platform:
|
||||
streams.append(stream)
|
||||
logger.debug(f"[ChatAPI] 获取到 {len(streams)} 个 {platform} 平台的聊天流")
|
||||
except Exception as e:
|
||||
logger.error(f"[ChatAPI] 获取聊天流失败: {e}")
|
||||
return streams
|
||||
|
||||
@staticmethod
|
||||
def get_group_streams(platform: str = "qq") -> List[ChatStream]:
|
||||
"""获取所有群聊聊天流
|
||||
|
||||
Args:
|
||||
platform: 平台筛选,默认为"qq"
|
||||
|
||||
Returns:
|
||||
List[ChatStream]: 群聊聊天流列表
|
||||
"""
|
||||
streams = []
|
||||
try:
|
||||
for _, stream in get_chat_manager().streams.items():
|
||||
if stream.platform == platform and stream.group_info:
|
||||
streams.append(stream)
|
||||
logger.debug(f"[ChatAPI] 获取到 {len(streams)} 个 {platform} 平台的群聊流")
|
||||
except Exception as e:
|
||||
logger.error(f"[ChatAPI] 获取群聊流失败: {e}")
|
||||
return streams
|
||||
|
||||
@staticmethod
|
||||
def get_private_streams(platform: str = "qq") -> List[ChatStream]:
|
||||
"""获取所有私聊聊天流
|
||||
|
||||
Args:
|
||||
platform: 平台筛选,默认为"qq"
|
||||
|
||||
Returns:
|
||||
List[ChatStream]: 私聊聊天流列表
|
||||
"""
|
||||
streams = []
|
||||
try:
|
||||
for _, stream in get_chat_manager().streams.items():
|
||||
if stream.platform == platform and not stream.group_info:
|
||||
streams.append(stream)
|
||||
logger.debug(f"[ChatAPI] 获取到 {len(streams)} 个 {platform} 平台的私聊流")
|
||||
except Exception as e:
|
||||
logger.error(f"[ChatAPI] 获取私聊流失败: {e}")
|
||||
return streams
|
||||
|
||||
@staticmethod
|
||||
def get_stream_by_group_id(group_id: str, platform: str = "qq") -> Optional[ChatStream]:
|
||||
"""根据群ID获取聊天流
|
||||
|
||||
Args:
|
||||
group_id: 群聊ID
|
||||
platform: 平台,默认为"qq"
|
||||
|
||||
Returns:
|
||||
Optional[ChatStream]: 聊天流对象,如果未找到返回None
|
||||
"""
|
||||
try:
|
||||
for _, stream in get_chat_manager().streams.items():
|
||||
if (
|
||||
stream.group_info
|
||||
and str(stream.group_info.group_id) == str(group_id)
|
||||
and stream.platform == platform
|
||||
):
|
||||
logger.debug(f"[ChatAPI] 找到群ID {group_id} 的聊天流")
|
||||
return stream
|
||||
logger.warning(f"[ChatAPI] 未找到群ID {group_id} 的聊天流")
|
||||
except Exception as e:
|
||||
logger.error(f"[ChatAPI] 查找群聊流失败: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_stream_by_user_id(user_id: str, platform: str = "qq") -> Optional[ChatStream]:
|
||||
"""根据用户ID获取私聊流
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
platform: 平台,默认为"qq"
|
||||
|
||||
Returns:
|
||||
Optional[ChatStream]: 聊天流对象,如果未找到返回None
|
||||
"""
|
||||
try:
|
||||
for _, stream in get_chat_manager().streams.items():
|
||||
if (
|
||||
not stream.group_info
|
||||
and str(stream.user_info.user_id) == str(user_id)
|
||||
and stream.platform == platform
|
||||
):
|
||||
logger.debug(f"[ChatAPI] 找到用户ID {user_id} 的私聊流")
|
||||
return stream
|
||||
logger.warning(f"[ChatAPI] 未找到用户ID {user_id} 的私聊流")
|
||||
except Exception as e:
|
||||
logger.error(f"[ChatAPI] 查找私聊流失败: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_stream_type(chat_stream: ChatStream) -> str:
|
||||
"""获取聊天流类型
|
||||
|
||||
Args:
|
||||
chat_stream: 聊天流对象
|
||||
|
||||
Returns:
|
||||
str: 聊天类型 ("group", "private", "unknown")
|
||||
"""
|
||||
if not chat_stream:
|
||||
return "unknown"
|
||||
|
||||
if hasattr(chat_stream, "group_info"):
|
||||
return "group" if chat_stream.group_info else "private"
|
||||
return "unknown"
|
||||
|
||||
@staticmethod
|
||||
def get_stream_info(chat_stream: ChatStream) -> Dict[str, Any]:
|
||||
"""获取聊天流详细信息
|
||||
|
||||
Args:
|
||||
chat_stream: 聊天流对象
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 聊天流信息字典
|
||||
"""
|
||||
if not chat_stream:
|
||||
return {}
|
||||
|
||||
try:
|
||||
info = {
|
||||
"stream_id": chat_stream.stream_id,
|
||||
"platform": chat_stream.platform,
|
||||
"type": ChatManager.get_stream_type(chat_stream),
|
||||
}
|
||||
|
||||
if chat_stream.group_info:
|
||||
info.update({
|
||||
"group_id": chat_stream.group_info.group_id,
|
||||
"group_name": getattr(chat_stream.group_info, "group_name", "未知群聊"),
|
||||
})
|
||||
|
||||
if chat_stream.user_info:
|
||||
info.update({
|
||||
"user_id": chat_stream.user_info.user_id,
|
||||
"user_name": chat_stream.user_info.user_nickname,
|
||||
})
|
||||
|
||||
return info
|
||||
except Exception as e:
|
||||
logger.error(f"[ChatAPI] 获取聊天流信息失败: {e}")
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def get_recent_messages_from_obs(observations: List[Any], count: int = 5) -> List[Dict[str, Any]]:
|
||||
"""从观察对象获取最近的消息
|
||||
|
||||
Args:
|
||||
observations: 观察对象列表
|
||||
count: 要获取的消息数量
|
||||
|
||||
Returns:
|
||||
List[Dict]: 消息列表,每个消息包含发送者、内容等信息
|
||||
"""
|
||||
messages = []
|
||||
|
||||
try:
|
||||
if observations and len(observations) > 0:
|
||||
obs = observations[0]
|
||||
if hasattr(obs, "get_talking_message"):
|
||||
obs: ObsInfo
|
||||
raw_messages = obs.get_talking_message()
|
||||
# 转换为简化格式
|
||||
for msg in raw_messages[-count:]:
|
||||
simple_msg = {
|
||||
"sender": msg.get("sender", "未知"),
|
||||
"content": msg.get("content", ""),
|
||||
"timestamp": msg.get("timestamp", 0),
|
||||
}
|
||||
messages.append(simple_msg)
|
||||
logger.debug(f"[ChatAPI] 获取到 {len(messages)} 条最近消息")
|
||||
except Exception as e:
|
||||
logger.error(f"[ChatAPI] 获取最近消息失败: {e}")
|
||||
|
||||
return messages
|
||||
|
||||
@staticmethod
|
||||
def get_streams_summary() -> Dict[str, int]:
|
||||
"""获取聊天流统计摘要
|
||||
|
||||
Returns:
|
||||
Dict[str, int]: 包含各种统计信息的字典
|
||||
"""
|
||||
try:
|
||||
all_streams = ChatManager.get_all_streams()
|
||||
group_streams = ChatManager.get_group_streams()
|
||||
private_streams = ChatManager.get_private_streams()
|
||||
|
||||
summary = {
|
||||
"total_streams": len(all_streams),
|
||||
"group_streams": len(group_streams),
|
||||
"private_streams": len(private_streams),
|
||||
"qq_streams": len([s for s in all_streams if s.platform == "qq"]),
|
||||
}
|
||||
|
||||
logger.debug(f"[ChatAPI] 聊天流统计: {summary}")
|
||||
return summary
|
||||
except Exception as e:
|
||||
logger.error(f"[ChatAPI] 获取聊天流统计失败: {e}")
|
||||
return {"total_streams": 0, "group_streams": 0, "private_streams": 0, "qq_streams": 0}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 模块级别的便捷函数 - 类似 requests.get(), requests.post() 的设计
|
||||
# =============================================================================
|
||||
|
||||
def get_all_streams(platform: str = "qq") -> List[ChatStream]:
|
||||
"""获取所有聊天流的便捷函数"""
|
||||
return ChatManager.get_all_streams(platform)
|
||||
|
||||
|
||||
def get_group_streams(platform: str = "qq") -> List[ChatStream]:
|
||||
"""获取群聊聊天流的便捷函数"""
|
||||
return ChatManager.get_group_streams(platform)
|
||||
|
||||
|
||||
def get_private_streams(platform: str = "qq") -> List[ChatStream]:
|
||||
"""获取私聊聊天流的便捷函数"""
|
||||
return ChatManager.get_private_streams(platform)
|
||||
|
||||
|
||||
def get_stream_by_group_id(group_id: str, platform: str = "qq") -> Optional[ChatStream]:
|
||||
"""根据群ID获取聊天流的便捷函数"""
|
||||
return ChatManager.get_stream_by_group_id(group_id, platform)
|
||||
|
||||
|
||||
def get_stream_by_user_id(user_id: str, platform: str = "qq") -> Optional[ChatStream]:
|
||||
"""根据用户ID获取私聊流的便捷函数"""
|
||||
return ChatManager.get_stream_by_user_id(user_id, platform)
|
||||
|
||||
|
||||
def get_stream_type(chat_stream: ChatStream) -> str:
|
||||
"""获取聊天流类型的便捷函数"""
|
||||
return ChatManager.get_stream_type(chat_stream)
|
||||
|
||||
|
||||
def get_stream_info(chat_stream: ChatStream) -> Dict[str, Any]:
|
||||
"""获取聊天流信息的便捷函数"""
|
||||
return ChatManager.get_stream_info(chat_stream)
|
||||
|
||||
|
||||
def get_streams_summary() -> Dict[str, int]:
|
||||
"""获取聊天流统计摘要的便捷函数"""
|
||||
return ChatManager.get_streams_summary()
|
||||
@@ -1,3 +1,12 @@
|
||||
"""配置API模块
|
||||
|
||||
提供了配置读取和用户信息获取等功能
|
||||
使用方式:
|
||||
from src.plugin_system.apis import config_api
|
||||
value = config_api.get_global_config("section.key")
|
||||
platform, user_id = await config_api.get_user_id_by_person_name("用户名")
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
@@ -6,92 +15,104 @@ from src.person_info.person_info import get_person_info_manager
|
||||
logger = get_logger("config_api")
|
||||
|
||||
|
||||
class ConfigAPI:
|
||||
"""配置API模块
|
||||
# =============================================================================
|
||||
# 配置访问API函数
|
||||
# =============================================================================
|
||||
|
||||
提供了配置读取和用户信息获取等功能
|
||||
def get_global_config(key: str, default: Any = None) -> Any:
|
||||
"""
|
||||
安全地从全局配置中获取一个值。
|
||||
插件应使用此方法读取全局配置,以保证只读和隔离性。
|
||||
|
||||
def get_global_config(self, key: str, default: Any = None) -> Any:
|
||||
"""
|
||||
安全地从全局配置中获取一个值。
|
||||
插件应使用此方法读取全局配置,以保证只读和隔离性。
|
||||
|
||||
Args:
|
||||
key: 配置键名,支持嵌套访问如 "section.subsection.key"
|
||||
default: 如果配置不存在时返回的默认值
|
||||
|
||||
Returns:
|
||||
Any: 配置值或默认值
|
||||
"""
|
||||
# 支持嵌套键访问
|
||||
keys = key.split(".")
|
||||
current = global_config
|
||||
|
||||
try:
|
||||
for k in keys:
|
||||
if hasattr(current, k):
|
||||
current = getattr(current, k)
|
||||
else:
|
||||
return default
|
||||
return current
|
||||
except Exception as e:
|
||||
logger.warning(f"获取全局配置 {key} 失败: {e}")
|
||||
return default
|
||||
|
||||
def get_config(self, key: str, default: Any = None) -> Any:
|
||||
"""
|
||||
从插件配置中获取值,支持嵌套键访问
|
||||
|
||||
Args:
|
||||
key: 配置键名,支持嵌套访问如 "section.subsection.key"
|
||||
default: 如果配置不存在时返回的默认值
|
||||
|
||||
Returns:
|
||||
Any: 配置值或默认值
|
||||
"""
|
||||
# 获取插件配置
|
||||
plugin_config = getattr(self, "_plugin_config", {})
|
||||
if not plugin_config:
|
||||
return default
|
||||
|
||||
# 支持嵌套键访问
|
||||
keys = key.split(".")
|
||||
current = plugin_config
|
||||
Args:
|
||||
key: 配置键名,支持嵌套访问如 "section.subsection.key"
|
||||
default: 如果配置不存在时返回的默认值
|
||||
|
||||
Returns:
|
||||
Any: 配置值或默认值
|
||||
"""
|
||||
# 支持嵌套键访问
|
||||
keys = key.split(".")
|
||||
current = global_config
|
||||
|
||||
try:
|
||||
for k in keys:
|
||||
if isinstance(current, dict) and k in current:
|
||||
current = current[k]
|
||||
if hasattr(current, k):
|
||||
current = getattr(current, k)
|
||||
else:
|
||||
return default
|
||||
|
||||
return current
|
||||
except Exception as e:
|
||||
logger.warning(f"[ConfigAPI] 获取全局配置 {key} 失败: {e}")
|
||||
return default
|
||||
|
||||
async def get_user_id_by_person_name(self, person_name: str) -> tuple[str, str]:
|
||||
"""根据用户名获取用户ID
|
||||
|
||||
Args:
|
||||
person_name: 用户名
|
||||
def get_plugin_config(plugin_config: dict, key: str, default: Any = None) -> Any:
|
||||
"""
|
||||
从插件配置中获取值,支持嵌套键访问
|
||||
|
||||
Returns:
|
||||
tuple[str, str]: (平台, 用户ID)
|
||||
"""
|
||||
Args:
|
||||
plugin_config: 插件配置字典
|
||||
key: 配置键名,支持嵌套访问如 "section.subsection.key"
|
||||
default: 如果配置不存在时返回的默认值
|
||||
|
||||
Returns:
|
||||
Any: 配置值或默认值
|
||||
"""
|
||||
if not plugin_config:
|
||||
return default
|
||||
|
||||
# 支持嵌套键访问
|
||||
keys = key.split(".")
|
||||
current = plugin_config
|
||||
|
||||
for k in keys:
|
||||
if isinstance(current, dict) and k in current:
|
||||
current = current[k]
|
||||
else:
|
||||
return default
|
||||
|
||||
return current
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 用户信息API函数
|
||||
# =============================================================================
|
||||
|
||||
async def get_user_id_by_person_name(person_name: str) -> tuple[str, str]:
|
||||
"""根据用户名获取用户ID
|
||||
|
||||
Args:
|
||||
person_name: 用户名
|
||||
|
||||
Returns:
|
||||
tuple[str, str]: (平台, 用户ID)
|
||||
"""
|
||||
try:
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_id = person_info_manager.get_person_id_by_person_name(person_name)
|
||||
user_id = await person_info_manager.get_value(person_id, "user_id")
|
||||
platform = await person_info_manager.get_value(person_id, "platform")
|
||||
return platform, user_id
|
||||
except Exception as e:
|
||||
logger.error(f"[ConfigAPI] 根据用户名获取用户ID失败: {e}")
|
||||
return "", ""
|
||||
|
||||
async def get_person_info(self, person_id: str, key: str, default: Any = None) -> Any:
|
||||
"""获取用户信息
|
||||
|
||||
Args:
|
||||
person_id: 用户ID
|
||||
key: 信息键名
|
||||
default: 默认值
|
||||
async def get_person_info(person_id: str, key: str, default: Any = None) -> Any:
|
||||
"""获取用户信息
|
||||
|
||||
Returns:
|
||||
Any: 用户信息值或默认值
|
||||
"""
|
||||
Args:
|
||||
person_id: 用户ID
|
||||
key: 信息键名
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
Any: 用户信息值或默认值
|
||||
"""
|
||||
try:
|
||||
person_info_manager = get_person_info_manager()
|
||||
return await person_info_manager.get_value(person_id, key, default)
|
||||
except Exception as e:
|
||||
logger.error(f"[ConfigAPI] 获取用户信息失败: {e}")
|
||||
return default
|
||||
|
||||
@@ -1,352 +1,97 @@
|
||||
"""数据库API模块
|
||||
|
||||
提供数据库操作相关功能,采用标准Python包设计模式
|
||||
使用方式:
|
||||
from src.plugin_system.apis import database_api
|
||||
records = await database_api.db_query(ActionRecords, query_type="get")
|
||||
record = await database_api.db_save(ActionRecords, data={"action_id": "123"})
|
||||
"""
|
||||
|
||||
import traceback
|
||||
import time
|
||||
from typing import Dict, List, Any, Union, Type
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import ActionRecords
|
||||
from src.common.database.database import db
|
||||
from peewee import Model, DoesNotExist
|
||||
|
||||
logger = get_logger("database_api")
|
||||
|
||||
# =============================================================================
|
||||
# 通用数据库查询API函数
|
||||
# =============================================================================
|
||||
|
||||
class DatabaseAPI:
|
||||
"""数据库API模块
|
||||
async def db_query(
|
||||
model_class: Type[Model],
|
||||
query_type: str = "get",
|
||||
filters: Dict[str, Any] = None,
|
||||
data: Dict[str, Any] = None,
|
||||
limit: int = None,
|
||||
order_by: List[str] = None,
|
||||
single_result: bool = False,
|
||||
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
|
||||
"""执行数据库查询操作
|
||||
|
||||
提供了数据库操作相关的功能
|
||||
这个方法提供了一个通用接口来执行数据库操作,包括查询、创建、更新和删除记录。
|
||||
|
||||
Args:
|
||||
model_class: Peewee 模型类,例如 ActionRecords, Messages 等
|
||||
query_type: 查询类型,可选值: "get", "create", "update", "delete", "count"
|
||||
filters: 过滤条件字典,键为字段名,值为要匹配的值
|
||||
data: 用于创建或更新的数据字典
|
||||
limit: 限制结果数量
|
||||
order_by: 排序字段列表,使用字段名,前缀'-'表示降序
|
||||
single_result: 是否只返回单个结果
|
||||
|
||||
Returns:
|
||||
根据查询类型返回不同的结果:
|
||||
- "get": 返回查询结果列表或单个结果(如果 single_result=True)
|
||||
- "create": 返回创建的记录
|
||||
- "update": 返回受影响的行数
|
||||
- "delete": 返回受影响的行数
|
||||
- "count": 返回记录数量
|
||||
|
||||
示例:
|
||||
# 查询最近10条消息
|
||||
messages = await database_api.db_query(
|
||||
Messages,
|
||||
query_type="get",
|
||||
filters={"chat_id": chat_stream.stream_id},
|
||||
limit=10,
|
||||
order_by=["-time"]
|
||||
)
|
||||
|
||||
# 创建一条记录
|
||||
new_record = await database_api.db_query(
|
||||
ActionRecords,
|
||||
query_type="create",
|
||||
data={"action_id": "123", "time": time.time(), "action_name": "TestAction"}
|
||||
)
|
||||
|
||||
# 更新记录
|
||||
updated_count = await database_api.db_query(
|
||||
ActionRecords,
|
||||
query_type="update",
|
||||
filters={"action_id": "123"},
|
||||
data={"action_done": True}
|
||||
)
|
||||
|
||||
# 删除记录
|
||||
deleted_count = await database_api.db_query(
|
||||
ActionRecords,
|
||||
query_type="delete",
|
||||
filters={"action_id": "123"}
|
||||
)
|
||||
|
||||
# 计数
|
||||
count = await database_api.db_query(
|
||||
Messages,
|
||||
query_type="count",
|
||||
filters={"chat_id": chat_stream.stream_id}
|
||||
)
|
||||
"""
|
||||
|
||||
async def store_action_info(
|
||||
self,
|
||||
action_build_into_prompt: bool = False,
|
||||
action_prompt_display: str = "",
|
||||
action_done: bool = True,
|
||||
thinking_id: str = "",
|
||||
action_data: dict = None,
|
||||
) -> None:
|
||||
"""存储action信息到数据库
|
||||
|
||||
Args:
|
||||
action_build_into_prompt: 是否构建到提示中
|
||||
action_prompt_display: 显示的action提示信息
|
||||
action_done: action是否完成
|
||||
thinking_id: 思考ID
|
||||
action_data: action数据,如果不提供则使用空字典
|
||||
"""
|
||||
try:
|
||||
chat_stream = self.get_service("chat_stream")
|
||||
if not chat_stream:
|
||||
logger.error(f"{self.log_prefix} 无法存储action信息:缺少chat_stream服务")
|
||||
return
|
||||
|
||||
action_time = time.time()
|
||||
action_id = f"{action_time}_{thinking_id}"
|
||||
|
||||
ActionRecords.create(
|
||||
action_id=action_id,
|
||||
time=action_time,
|
||||
action_name=self.__class__.__name__,
|
||||
action_data=str(action_data or {}),
|
||||
action_done=action_done,
|
||||
action_build_into_prompt=action_build_into_prompt,
|
||||
action_prompt_display=action_prompt_display,
|
||||
chat_id=chat_stream.stream_id,
|
||||
chat_info_stream_id=chat_stream.stream_id,
|
||||
chat_info_platform=chat_stream.platform,
|
||||
user_id=chat_stream.user_info.user_id if chat_stream.user_info else "",
|
||||
user_nickname=chat_stream.user_info.user_nickname if chat_stream.user_info else "",
|
||||
user_cardname=chat_stream.user_info.user_cardname if chat_stream.user_info else "",
|
||||
)
|
||||
logger.debug(f"{self.log_prefix} 已存储action信息: {action_prompt_display}")
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 存储action信息时出错: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
async def db_query(
|
||||
self,
|
||||
model_class: Type[Model],
|
||||
query_type: str = "get",
|
||||
filters: Dict[str, Any] = None,
|
||||
data: Dict[str, Any] = None,
|
||||
limit: int = None,
|
||||
order_by: List[str] = None,
|
||||
single_result: bool = False,
|
||||
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
|
||||
"""执行数据库查询操作
|
||||
|
||||
这个方法提供了一个通用接口来执行数据库操作,包括查询、创建、更新和删除记录。
|
||||
|
||||
Args:
|
||||
model_class: Peewee 模型类,例如 ActionRecords, Messages 等
|
||||
query_type: 查询类型,可选值: "get", "create", "update", "delete", "count"
|
||||
filters: 过滤条件字典,键为字段名,值为要匹配的值
|
||||
data: 用于创建或更新的数据字典
|
||||
limit: 限制结果数量
|
||||
order_by: 排序字段列表,使用字段名,前缀'-'表示降序
|
||||
single_result: 是否只返回单个结果
|
||||
|
||||
Returns:
|
||||
根据查询类型返回不同的结果:
|
||||
- "get": 返回查询结果列表或单个结果(如果 single_result=True)
|
||||
- "create": 返回创建的记录
|
||||
- "update": 返回受影响的行数
|
||||
- "delete": 返回受影响的行数
|
||||
- "count": 返回记录数量
|
||||
|
||||
示例:
|
||||
# 查询最近10条消息
|
||||
messages = await self.db_query(
|
||||
Messages,
|
||||
query_type="get",
|
||||
filters={"chat_id": chat_stream.stream_id},
|
||||
limit=10,
|
||||
order_by=["-time"]
|
||||
)
|
||||
|
||||
# 创建一条记录
|
||||
new_record = await self.db_query(
|
||||
ActionRecords,
|
||||
query_type="create",
|
||||
data={"action_id": "123", "time": time.time(), "action_name": "TestAction"}
|
||||
)
|
||||
|
||||
# 更新记录
|
||||
updated_count = await self.db_query(
|
||||
ActionRecords,
|
||||
query_type="update",
|
||||
filters={"action_id": "123"},
|
||||
data={"action_done": True}
|
||||
)
|
||||
|
||||
# 删除记录
|
||||
deleted_count = await self.db_query(
|
||||
ActionRecords,
|
||||
query_type="delete",
|
||||
filters={"action_id": "123"}
|
||||
)
|
||||
|
||||
# 计数
|
||||
count = await self.db_query(
|
||||
Messages,
|
||||
query_type="count",
|
||||
filters={"chat_id": chat_stream.stream_id}
|
||||
)
|
||||
"""
|
||||
try:
|
||||
if query_type not in ["get", "create", "update", "delete", "count"]:
|
||||
raise ValueError("query_type must be 'get' or 'create' or 'update' or 'delete' or 'count'")
|
||||
# 构建基本查询
|
||||
if query_type in ["get", "update", "delete", "count"]:
|
||||
query = model_class.select()
|
||||
|
||||
# 应用过滤条件
|
||||
if filters:
|
||||
for field, value in filters.items():
|
||||
query = query.where(getattr(model_class, field) == value)
|
||||
|
||||
# 执行查询
|
||||
if query_type == "get":
|
||||
# 应用排序
|
||||
if order_by:
|
||||
for field in order_by:
|
||||
if field.startswith("-"):
|
||||
query = query.order_by(getattr(model_class, field[1:]).desc())
|
||||
else:
|
||||
query = query.order_by(getattr(model_class, field))
|
||||
|
||||
# 应用限制
|
||||
if limit:
|
||||
query = query.limit(limit)
|
||||
|
||||
# 执行查询
|
||||
results = list(query.dicts())
|
||||
|
||||
# 返回结果
|
||||
if single_result:
|
||||
return results[0] if results else None
|
||||
return results
|
||||
|
||||
elif query_type == "create":
|
||||
if not data:
|
||||
raise ValueError("创建记录需要提供data参数")
|
||||
|
||||
# 创建记录
|
||||
record = model_class.create(**data)
|
||||
# 返回创建的记录
|
||||
return model_class.select().where(model_class.id == record.id).dicts().get()
|
||||
|
||||
elif query_type == "update":
|
||||
if not data:
|
||||
raise ValueError("更新记录需要提供data参数")
|
||||
|
||||
# 更新记录
|
||||
return query.update(**data).execute()
|
||||
|
||||
elif query_type == "delete":
|
||||
# 删除记录
|
||||
return query.delete().execute()
|
||||
|
||||
elif query_type == "count":
|
||||
# 计数
|
||||
return query.count()
|
||||
|
||||
else:
|
||||
raise ValueError(f"不支持的查询类型: {query_type}")
|
||||
|
||||
except DoesNotExist:
|
||||
# 记录不存在
|
||||
if query_type == "get" and single_result:
|
||||
return None
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 数据库操作出错: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
# 根据查询类型返回合适的默认值
|
||||
if query_type == "get":
|
||||
return None if single_result else []
|
||||
elif query_type in ["create", "update", "delete", "count"]:
|
||||
return None
|
||||
return None
|
||||
|
||||
async def db_raw_query(
|
||||
self, sql: str, params: List[Any] = None, fetch_results: bool = True
|
||||
) -> Union[List[Dict[str, Any]], int, None]:
|
||||
"""执行原始SQL查询
|
||||
|
||||
警告: 使用此方法需要小心,确保SQL语句已正确构造以避免SQL注入风险。
|
||||
|
||||
Args:
|
||||
sql: 原始SQL查询字符串
|
||||
params: 查询参数列表,用于替换SQL中的占位符
|
||||
fetch_results: 是否获取查询结果,对于SELECT查询设为True,对于
|
||||
UPDATE/INSERT/DELETE等操作设为False
|
||||
|
||||
Returns:
|
||||
如果fetch_results为True,返回查询结果列表;
|
||||
如果fetch_results为False,返回受影响的行数;
|
||||
如果出错,返回None
|
||||
"""
|
||||
try:
|
||||
cursor = db.execute_sql(sql, params or [])
|
||||
|
||||
if fetch_results:
|
||||
# 获取列名
|
||||
columns = [col[0] for col in cursor.description]
|
||||
|
||||
# 构建结果字典列表
|
||||
results = []
|
||||
for row in cursor.fetchall():
|
||||
results.append(dict(zip(columns, row)))
|
||||
|
||||
return results
|
||||
else:
|
||||
# 返回受影响的行数
|
||||
return cursor.rowcount
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 执行原始SQL查询出错: {e}")
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
async def db_save(
|
||||
self, model_class: Type[Model], data: Dict[str, Any], key_field: str = None, key_value: Any = None
|
||||
) -> Union[Dict[str, Any], None]:
|
||||
"""保存数据到数据库(创建或更新)
|
||||
|
||||
如果提供了key_field和key_value,会先尝试查找匹配的记录进行更新;
|
||||
如果没有找到匹配记录,或未提供key_field和key_value,则创建新记录。
|
||||
|
||||
Args:
|
||||
model_class: Peewee模型类,如ActionRecords, Messages等
|
||||
data: 要保存的数据字典
|
||||
key_field: 用于查找现有记录的字段名,例如"action_id"
|
||||
key_value: 用于查找现有记录的字段值
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 保存后的记录数据
|
||||
None: 如果操作失败
|
||||
|
||||
示例:
|
||||
# 创建或更新一条记录
|
||||
record = await self.db_save(
|
||||
ActionRecords,
|
||||
{
|
||||
"action_id": "123",
|
||||
"time": time.time(),
|
||||
"action_name": "TestAction",
|
||||
"action_done": True
|
||||
},
|
||||
key_field="action_id",
|
||||
key_value="123"
|
||||
)
|
||||
"""
|
||||
try:
|
||||
# 如果提供了key_field和key_value,尝试更新现有记录
|
||||
if key_field and key_value is not None:
|
||||
# 查找现有记录
|
||||
existing_records = list(
|
||||
model_class.select().where(getattr(model_class, key_field) == key_value).limit(1)
|
||||
)
|
||||
|
||||
if existing_records:
|
||||
# 更新现有记录
|
||||
existing_record = existing_records[0]
|
||||
for field, value in data.items():
|
||||
setattr(existing_record, field, value)
|
||||
existing_record.save()
|
||||
|
||||
# 返回更新后的记录
|
||||
updated_record = model_class.select().where(model_class.id == existing_record.id).dicts().get()
|
||||
return updated_record
|
||||
|
||||
# 如果没有找到现有记录或未提供key_field和key_value,创建新记录
|
||||
new_record = model_class.create(**data)
|
||||
|
||||
# 返回创建的记录
|
||||
created_record = model_class.select().where(model_class.id == new_record.id).dicts().get()
|
||||
return created_record
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 保存数据库记录出错: {e}")
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
async def db_get(
|
||||
self, model_class: Type[Model], filters: Dict[str, Any] = None, order_by: str = None, limit: int = None
|
||||
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
|
||||
"""从数据库获取记录
|
||||
|
||||
这是db_query方法的简化版本,专注于数据检索操作。
|
||||
|
||||
Args:
|
||||
model_class: Peewee模型类
|
||||
filters: 过滤条件,字段名和值的字典
|
||||
order_by: 排序字段,前缀'-'表示降序,例如'-time'表示按时间降序
|
||||
limit: 结果数量限制,如果为1则返回单个记录而不是列表
|
||||
|
||||
Returns:
|
||||
如果limit=1,返回单个记录字典或None;
|
||||
否则返回记录字典列表或空列表。
|
||||
|
||||
示例:
|
||||
# 获取单个记录
|
||||
record = await self.db_get(
|
||||
ActionRecords,
|
||||
filters={"action_id": "123"},
|
||||
limit=1
|
||||
)
|
||||
|
||||
# 获取最近10条记录
|
||||
records = await self.db_get(
|
||||
Messages,
|
||||
filters={"chat_id": chat_stream.stream_id},
|
||||
order_by="-time",
|
||||
limit=10
|
||||
)
|
||||
"""
|
||||
try:
|
||||
# 构建查询
|
||||
try:
|
||||
if query_type not in ["get", "create", "update", "delete", "count"]:
|
||||
raise ValueError("query_type must be 'get' or 'create' or 'update' or 'delete' or 'count'")
|
||||
# 构建基本查询
|
||||
if query_type in ["get", "update", "delete", "count"]:
|
||||
query = model_class.select()
|
||||
|
||||
# 应用过滤条件
|
||||
@@ -354,12 +99,15 @@ class DatabaseAPI:
|
||||
for field, value in filters.items():
|
||||
query = query.where(getattr(model_class, field) == value)
|
||||
|
||||
# 执行查询
|
||||
if query_type == "get":
|
||||
# 应用排序
|
||||
if order_by:
|
||||
if order_by.startswith("-"):
|
||||
query = query.order_by(getattr(model_class, order_by[1:]).desc())
|
||||
else:
|
||||
query = query.order_by(getattr(model_class, order_by))
|
||||
for field in order_by:
|
||||
if field.startswith("-"):
|
||||
query = query.order_by(getattr(model_class, field[1:]).desc())
|
||||
else:
|
||||
query = query.order_by(getattr(model_class, field))
|
||||
|
||||
# 应用限制
|
||||
if limit:
|
||||
@@ -369,11 +117,270 @@ class DatabaseAPI:
|
||||
results = list(query.dicts())
|
||||
|
||||
# 返回结果
|
||||
if limit == 1:
|
||||
if single_result:
|
||||
return results[0] if results else None
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 获取数据库记录出错: {e}")
|
||||
traceback.print_exc()
|
||||
return None if limit == 1 else []
|
||||
elif query_type == "create":
|
||||
if not data:
|
||||
raise ValueError("创建记录需要提供data参数")
|
||||
|
||||
# 创建记录
|
||||
record = model_class.create(**data)
|
||||
# 返回创建的记录
|
||||
return model_class.select().where(model_class.id == record.id).dicts().get()
|
||||
|
||||
elif query_type == "update":
|
||||
if not data:
|
||||
raise ValueError("更新记录需要提供data参数")
|
||||
|
||||
# 更新记录
|
||||
return query.update(**data).execute()
|
||||
|
||||
elif query_type == "delete":
|
||||
# 删除记录
|
||||
return query.delete().execute()
|
||||
|
||||
elif query_type == "count":
|
||||
# 计数
|
||||
return query.count()
|
||||
|
||||
else:
|
||||
raise ValueError(f"不支持的查询类型: {query_type}")
|
||||
|
||||
except DoesNotExist:
|
||||
# 记录不存在
|
||||
if query_type == "get" and single_result:
|
||||
return None
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[DatabaseAPI] 数据库操作出错: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
# 根据查询类型返回合适的默认值
|
||||
if query_type == "get":
|
||||
return None if single_result else []
|
||||
elif query_type in ["create", "update", "delete", "count"]:
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
async def db_save(
|
||||
model_class: Type[Model], data: Dict[str, Any], key_field: str = None, key_value: Any = None
|
||||
) -> Union[Dict[str, Any], None]:
|
||||
"""保存数据到数据库(创建或更新)
|
||||
|
||||
如果提供了key_field和key_value,会先尝试查找匹配的记录进行更新;
|
||||
如果没有找到匹配记录,或未提供key_field和key_value,则创建新记录。
|
||||
|
||||
Args:
|
||||
model_class: Peewee模型类,如ActionRecords, Messages等
|
||||
data: 要保存的数据字典
|
||||
key_field: 用于查找现有记录的字段名,例如"action_id"
|
||||
key_value: 用于查找现有记录的字段值
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 保存后的记录数据
|
||||
None: 如果操作失败
|
||||
|
||||
示例:
|
||||
# 创建或更新一条记录
|
||||
record = await database_api.db_save(
|
||||
ActionRecords,
|
||||
{
|
||||
"action_id": "123",
|
||||
"time": time.time(),
|
||||
"action_name": "TestAction",
|
||||
"action_done": True
|
||||
},
|
||||
key_field="action_id",
|
||||
key_value="123"
|
||||
)
|
||||
"""
|
||||
try:
|
||||
# 如果提供了key_field和key_value,尝试更新现有记录
|
||||
if key_field and key_value is not None:
|
||||
# 查找现有记录
|
||||
existing_records = list(
|
||||
model_class.select().where(getattr(model_class, key_field) == key_value).limit(1)
|
||||
)
|
||||
|
||||
if existing_records:
|
||||
# 更新现有记录
|
||||
existing_record = existing_records[0]
|
||||
for field, value in data.items():
|
||||
setattr(existing_record, field, value)
|
||||
existing_record.save()
|
||||
|
||||
# 返回更新后的记录
|
||||
updated_record = model_class.select().where(model_class.id == existing_record.id).dicts().get()
|
||||
return updated_record
|
||||
|
||||
# 如果没有找到现有记录或未提供key_field和key_value,创建新记录
|
||||
new_record = model_class.create(**data)
|
||||
|
||||
# 返回创建的记录
|
||||
created_record = model_class.select().where(model_class.id == new_record.id).dicts().get()
|
||||
return created_record
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[DatabaseAPI] 保存数据库记录出错: {e}")
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
|
||||
async def db_get(
|
||||
model_class: Type[Model], filters: Dict[str, Any] = None, order_by: str = None, limit: int = None
|
||||
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
|
||||
"""从数据库获取记录
|
||||
|
||||
这是db_query方法的简化版本,专注于数据检索操作。
|
||||
|
||||
Args:
|
||||
model_class: Peewee模型类
|
||||
filters: 过滤条件,字段名和值的字典
|
||||
order_by: 排序字段,前缀'-'表示降序,例如'-time'表示按时间降序
|
||||
limit: 结果数量限制,如果为1则返回单个记录而不是列表
|
||||
|
||||
Returns:
|
||||
如果limit=1,返回单个记录字典或None;
|
||||
否则返回记录字典列表或空列表。
|
||||
|
||||
示例:
|
||||
# 获取单个记录
|
||||
record = await database_api.db_get(
|
||||
ActionRecords,
|
||||
filters={"action_id": "123"},
|
||||
limit=1
|
||||
)
|
||||
|
||||
# 获取最近10条记录
|
||||
records = await database_api.db_get(
|
||||
Messages,
|
||||
filters={"chat_id": chat_stream.stream_id},
|
||||
order_by="-time",
|
||||
limit=10
|
||||
)
|
||||
"""
|
||||
try:
|
||||
# 构建查询
|
||||
query = model_class.select()
|
||||
|
||||
# 应用过滤条件
|
||||
if filters:
|
||||
for field, value in filters.items():
|
||||
query = query.where(getattr(model_class, field) == value)
|
||||
|
||||
# 应用排序
|
||||
if order_by:
|
||||
if order_by.startswith("-"):
|
||||
query = query.order_by(getattr(model_class, order_by[1:]).desc())
|
||||
else:
|
||||
query = query.order_by(getattr(model_class, order_by))
|
||||
|
||||
# 应用限制
|
||||
if limit:
|
||||
query = query.limit(limit)
|
||||
|
||||
# 执行查询
|
||||
results = list(query.dicts())
|
||||
|
||||
# 返回结果
|
||||
if limit == 1:
|
||||
return results[0] if results else None
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[DatabaseAPI] 获取数据库记录出错: {e}")
|
||||
traceback.print_exc()
|
||||
return None if limit == 1 else []
|
||||
|
||||
|
||||
async def store_action_info(
|
||||
chat_stream=None,
|
||||
action_build_into_prompt: bool = False,
|
||||
action_prompt_display: str = "",
|
||||
action_done: bool = True,
|
||||
thinking_id: str = "",
|
||||
action_data: dict = None,
|
||||
action_name: str = "",
|
||||
) -> Union[Dict[str, Any], None]:
|
||||
"""存储动作信息到数据库
|
||||
|
||||
将Action执行的相关信息保存到ActionRecords表中,用于后续的记忆和上下文构建。
|
||||
|
||||
Args:
|
||||
chat_stream: 聊天流对象,包含聊天相关信息
|
||||
action_build_into_prompt: 是否将此动作构建到提示中
|
||||
action_prompt_display: 动作的提示显示文本
|
||||
action_done: 动作是否完成
|
||||
thinking_id: 关联的思考ID
|
||||
action_data: 动作数据字典
|
||||
action_name: 动作名称
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 保存的记录数据
|
||||
None: 如果保存失败
|
||||
|
||||
示例:
|
||||
record = await database_api.store_action_info(
|
||||
chat_stream=chat_stream,
|
||||
action_build_into_prompt=True,
|
||||
action_prompt_display="执行了回复动作",
|
||||
action_done=True,
|
||||
thinking_id="thinking_123",
|
||||
action_data={"content": "Hello"},
|
||||
action_name="reply_action"
|
||||
)
|
||||
"""
|
||||
try:
|
||||
import time
|
||||
import json
|
||||
from src.common.database.database_model import ActionRecords
|
||||
|
||||
# 构建动作记录数据
|
||||
record_data = {
|
||||
"action_id": thinking_id or str(int(time.time() * 1000000)), # 使用thinking_id或生成唯一ID
|
||||
"time": time.time(),
|
||||
"action_name": action_name,
|
||||
"action_data": json.dumps(action_data or {}, ensure_ascii=False),
|
||||
"action_done": action_done,
|
||||
"action_build_into_prompt": action_build_into_prompt,
|
||||
"action_prompt_display": action_prompt_display,
|
||||
}
|
||||
|
||||
# 从chat_stream获取聊天信息
|
||||
if chat_stream:
|
||||
record_data.update({
|
||||
"chat_id": getattr(chat_stream, 'stream_id', ''),
|
||||
"chat_info_stream_id": getattr(chat_stream, 'stream_id', ''),
|
||||
"chat_info_platform": getattr(chat_stream, 'platform', ''),
|
||||
})
|
||||
else:
|
||||
# 如果没有chat_stream,设置默认值
|
||||
record_data.update({
|
||||
"chat_id": "",
|
||||
"chat_info_stream_id": "",
|
||||
"chat_info_platform": "",
|
||||
})
|
||||
|
||||
# 使用已有的db_save函数保存记录
|
||||
saved_record = await db_save(
|
||||
ActionRecords,
|
||||
data=record_data,
|
||||
key_field="action_id",
|
||||
key_value=record_data["action_id"]
|
||||
)
|
||||
|
||||
if saved_record:
|
||||
logger.info(f"[DatabaseAPI] 成功存储动作信息: {action_name} (ID: {record_data['action_id']})")
|
||||
else:
|
||||
logger.error(f"[DatabaseAPI] 存储动作信息失败: {action_name}")
|
||||
|
||||
return saved_record
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[DatabaseAPI] 存储动作信息时发生错误: {e}")
|
||||
traceback.print_exc()
|
||||
return None
|
||||
219
src/plugin_system/apis/emoji_api.py
Normal file
219
src/plugin_system/apis/emoji_api.py
Normal file
@@ -0,0 +1,219 @@
|
||||
"""
|
||||
表情API模块
|
||||
|
||||
提供表情包相关功能,采用标准Python包设计模式
|
||||
使用方式:
|
||||
from src.plugin_system.apis import emoji_api
|
||||
result = await emoji_api.get_by_description("开心")
|
||||
count = emoji_api.get_count()
|
||||
"""
|
||||
|
||||
from typing import Optional, Tuple
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
from src.chat.utils.utils_image import image_path_to_base64
|
||||
|
||||
logger = get_logger("emoji_api")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 表情包获取API函数
|
||||
# =============================================================================
|
||||
|
||||
async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]:
|
||||
"""根据描述选择表情包
|
||||
|
||||
Args:
|
||||
description: 表情包的描述文本,例如"开心"、"难过"、"愤怒"等
|
||||
|
||||
Returns:
|
||||
Optional[Tuple[str, str, str]]: (base64编码, 表情包描述, 匹配的情感标签) 或 None
|
||||
"""
|
||||
try:
|
||||
logger.info(f"[EmojiAPI] 根据描述获取表情包: {description}")
|
||||
|
||||
emoji_manager = get_emoji_manager()
|
||||
emoji_result = await emoji_manager.get_emoji_for_text(description)
|
||||
|
||||
if not emoji_result:
|
||||
logger.warning(f"[EmojiAPI] 未找到匹配描述 '{description}' 的表情包")
|
||||
return None
|
||||
|
||||
emoji_path, emoji_description, matched_emotion = emoji_result
|
||||
emoji_base64 = image_path_to_base64(emoji_path)
|
||||
|
||||
if not emoji_base64:
|
||||
logger.error(f"[EmojiAPI] 无法将表情包文件转换为base64: {emoji_path}")
|
||||
return None
|
||||
|
||||
logger.info(f"[EmojiAPI] 成功获取表情包: {emoji_description}, 匹配情感: {matched_emotion}")
|
||||
return emoji_base64, emoji_description, matched_emotion
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[EmojiAPI] 获取表情包失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def get_random() -> Optional[Tuple[str, str, str]]:
|
||||
"""随机获取表情包
|
||||
|
||||
Returns:
|
||||
Optional[Tuple[str, str, str]]: (base64编码, 表情包描述, 随机情感标签) 或 None
|
||||
"""
|
||||
try:
|
||||
logger.info("[EmojiAPI] 随机获取表情包")
|
||||
|
||||
emoji_manager = get_emoji_manager()
|
||||
all_emojis = emoji_manager.emoji_objects
|
||||
|
||||
if not all_emojis:
|
||||
logger.warning("[EmojiAPI] 没有可用的表情包")
|
||||
return None
|
||||
|
||||
# 过滤有效表情包
|
||||
valid_emojis = [emoji for emoji in all_emojis if not emoji.is_deleted]
|
||||
if not valid_emojis:
|
||||
logger.warning("[EmojiAPI] 没有有效的表情包")
|
||||
return None
|
||||
|
||||
# 随机选择
|
||||
import random
|
||||
selected_emoji = random.choice(valid_emojis)
|
||||
emoji_base64 = image_path_to_base64(selected_emoji.full_path)
|
||||
|
||||
if not emoji_base64:
|
||||
logger.error(f"[EmojiAPI] 无法转换表情包为base64: {selected_emoji.full_path}")
|
||||
return None
|
||||
|
||||
matched_emotion = random.choice(selected_emoji.emotion) if selected_emoji.emotion else "随机表情"
|
||||
|
||||
# 记录使用次数
|
||||
emoji_manager.record_usage(selected_emoji.hash)
|
||||
|
||||
logger.info(f"[EmojiAPI] 成功获取随机表情包: {selected_emoji.description}")
|
||||
return emoji_base64, selected_emoji.description, matched_emotion
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[EmojiAPI] 获取随机表情包失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]:
|
||||
"""根据情感标签获取表情包
|
||||
|
||||
Args:
|
||||
emotion: 情感标签,如"happy"、"sad"、"angry"等
|
||||
|
||||
Returns:
|
||||
Optional[Tuple[str, str, str]]: (base64编码, 表情包描述, 匹配的情感标签) 或 None
|
||||
"""
|
||||
try:
|
||||
logger.info(f"[EmojiAPI] 根据情感获取表情包: {emotion}")
|
||||
|
||||
emoji_manager = get_emoji_manager()
|
||||
all_emojis = emoji_manager.emoji_objects
|
||||
|
||||
# 筛选匹配情感的表情包
|
||||
matching_emojis = []
|
||||
for emoji_obj in all_emojis:
|
||||
if not emoji_obj.is_deleted and emotion.lower() in [e.lower() for e in emoji_obj.emotion]:
|
||||
matching_emojis.append(emoji_obj)
|
||||
|
||||
if not matching_emojis:
|
||||
logger.warning(f"[EmojiAPI] 未找到匹配情感 '{emotion}' 的表情包")
|
||||
return None
|
||||
|
||||
# 随机选择匹配的表情包
|
||||
import random
|
||||
selected_emoji = random.choice(matching_emojis)
|
||||
emoji_base64 = image_path_to_base64(selected_emoji.full_path)
|
||||
|
||||
if not emoji_base64:
|
||||
logger.error(f"[EmojiAPI] 无法转换表情包为base64: {selected_emoji.full_path}")
|
||||
return None
|
||||
|
||||
# 记录使用次数
|
||||
emoji_manager.record_usage(selected_emoji.hash)
|
||||
|
||||
logger.info(f"[EmojiAPI] 成功获取情感表情包: {selected_emoji.description}")
|
||||
return emoji_base64, selected_emoji.description, emotion
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[EmojiAPI] 根据情感获取表情包失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 表情包信息查询API函数
|
||||
# =============================================================================
|
||||
|
||||
def get_count() -> int:
|
||||
"""获取表情包数量
|
||||
|
||||
Returns:
|
||||
int: 当前可用的表情包数量
|
||||
"""
|
||||
try:
|
||||
emoji_manager = get_emoji_manager()
|
||||
return emoji_manager.emoji_num
|
||||
except Exception as e:
|
||||
logger.error(f"[EmojiAPI] 获取表情包数量失败: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
def get_info() -> dict:
|
||||
"""获取表情包系统信息
|
||||
|
||||
Returns:
|
||||
dict: 包含表情包数量、最大数量等信息
|
||||
"""
|
||||
try:
|
||||
emoji_manager = get_emoji_manager()
|
||||
return {
|
||||
"current_count": emoji_manager.emoji_num,
|
||||
"max_count": emoji_manager.emoji_num_max,
|
||||
"available_emojis": len([e for e in emoji_manager.emoji_objects if not e.is_deleted]),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"[EmojiAPI] 获取表情包信息失败: {e}")
|
||||
return {"current_count": 0, "max_count": 0, "available_emojis": 0}
|
||||
|
||||
|
||||
def get_emotions() -> list:
|
||||
"""获取所有可用的情感标签
|
||||
|
||||
Returns:
|
||||
list: 所有表情包的情感标签列表(去重)
|
||||
"""
|
||||
try:
|
||||
emoji_manager = get_emoji_manager()
|
||||
emotions = set()
|
||||
|
||||
for emoji_obj in emoji_manager.emoji_objects:
|
||||
if not emoji_obj.is_deleted and emoji_obj.emotion:
|
||||
emotions.update(emoji_obj.emotion)
|
||||
|
||||
return sorted(list(emotions))
|
||||
except Exception as e:
|
||||
logger.error(f"[EmojiAPI] 获取情感标签失败: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def get_descriptions() -> list:
|
||||
"""获取所有表情包描述
|
||||
|
||||
Returns:
|
||||
list: 所有可用表情包的描述列表
|
||||
"""
|
||||
try:
|
||||
emoji_manager = get_emoji_manager()
|
||||
descriptions = []
|
||||
|
||||
for emoji_obj in emoji_manager.emoji_objects:
|
||||
if not emoji_obj.is_deleted and emoji_obj.description:
|
||||
descriptions.append(emoji_obj.description)
|
||||
|
||||
return descriptions
|
||||
except Exception as e:
|
||||
logger.error(f"[EmojiAPI] 获取表情包描述失败: {e}")
|
||||
return []
|
||||
170
src/plugin_system/apis/generator_api.py
Normal file
170
src/plugin_system/apis/generator_api.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""
|
||||
回复器API模块
|
||||
|
||||
提供回复器相关功能,采用标准Python包设计模式
|
||||
使用方式:
|
||||
from src.plugin_system.apis import generator_api
|
||||
replyer = generator_api.get_replyer(chat_stream)
|
||||
success, reply_set = await generator_api.generate_reply(chat_stream, action_data, reasoning)
|
||||
"""
|
||||
|
||||
from typing import Tuple, Any, Dict, List
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.focus_chat.replyer.default_generator import DefaultReplyer
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
logger = get_logger("generator_api")
|
||||
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 回复器获取API函数
|
||||
# =============================================================================
|
||||
|
||||
def get_replyer(chat_stream=None, platform: str = None, chat_id: str = None, is_group: bool = True) -> DefaultReplyer:
|
||||
"""获取回复器对象
|
||||
|
||||
优先使用chat_stream,如果没有则使用platform和chat_id组合
|
||||
|
||||
Args:
|
||||
chat_stream: 聊天流对象(优先)
|
||||
platform: 平台名称,如"qq"
|
||||
chat_id: 聊天ID(群ID或用户ID)
|
||||
is_group: 是否为群聊
|
||||
|
||||
Returns:
|
||||
Optional[Any]: 回复器对象,如果获取失败则返回None
|
||||
"""
|
||||
try:
|
||||
# 优先使用聊天流
|
||||
if chat_stream:
|
||||
logger.debug("[GeneratorAPI] 使用聊天流获取回复器")
|
||||
return DefaultReplyer(chat_stream=chat_stream)
|
||||
|
||||
# 使用平台和ID组合
|
||||
if platform and chat_id:
|
||||
logger.debug("[GeneratorAPI] 使用平台和ID获取回复器")
|
||||
chat_manager = get_chat_manager()
|
||||
if not chat_manager:
|
||||
logger.warning("[GeneratorAPI] 无法获取聊天管理器")
|
||||
return None
|
||||
|
||||
# 查找对应的聊天流
|
||||
target_stream = None
|
||||
for _stream_id, stream in chat_manager.streams.items():
|
||||
if stream.platform == platform:
|
||||
if is_group and stream.group_info:
|
||||
if str(stream.group_info.group_id) == str(chat_id):
|
||||
target_stream = stream
|
||||
break
|
||||
elif not is_group and stream.user_info:
|
||||
if str(stream.user_info.user_id) == str(chat_id):
|
||||
target_stream = stream
|
||||
break
|
||||
|
||||
return DefaultReplyer(chat_stream=target_stream)
|
||||
|
||||
logger.warning("[GeneratorAPI] 缺少必要参数,无法获取回复器")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[GeneratorAPI] 获取回复器失败: {e}")
|
||||
return None
|
||||
|
||||
# =============================================================================
|
||||
# 回复生成API函数
|
||||
# =============================================================================
|
||||
|
||||
async def generate_reply(
|
||||
chat_stream=None,
|
||||
action_data: Dict[str, Any] = None,
|
||||
platform: str = None,
|
||||
chat_id: str = None,
|
||||
is_group: bool = True
|
||||
) -> Tuple[bool, List[Tuple[str, Any]]]:
|
||||
"""生成回复
|
||||
|
||||
Args:
|
||||
chat_stream: 聊天流对象(优先)
|
||||
action_data: 动作数据
|
||||
reasoning: 推理原因
|
||||
thinking_id: 思考ID
|
||||
cycle_timers: 循环计时器
|
||||
anchor_message: 锚点消息
|
||||
platform: 平台名称(备用)
|
||||
chat_id: 聊天ID(备用)
|
||||
is_group: 是否为群聊(备用)
|
||||
|
||||
Returns:
|
||||
Tuple[bool, List[Tuple[str, Any]]]: (是否成功, 回复集合)
|
||||
"""
|
||||
try:
|
||||
# 获取回复器
|
||||
replyer = get_replyer(chat_stream, platform, chat_id, is_group)
|
||||
if not replyer:
|
||||
logger.error("[GeneratorAPI] 无法获取回复器")
|
||||
return False, []
|
||||
|
||||
logger.info("[GeneratorAPI] 开始生成回复")
|
||||
|
||||
# 调用回复器生成回复
|
||||
success, reply_set = await replyer.generate_reply_with_context(
|
||||
reply_data=action_data or {},
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info(f"[GeneratorAPI] 回复生成成功,生成了 {len(reply_set)} 个回复项")
|
||||
else:
|
||||
logger.warning("[GeneratorAPI] 回复生成失败")
|
||||
|
||||
return success, reply_set or []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[GeneratorAPI] 生成回复时出错: {e}")
|
||||
return False, []
|
||||
|
||||
async def rewrite_reply(
|
||||
chat_stream=None,
|
||||
reply_data: Dict[str, Any] = None,
|
||||
platform: str = None,
|
||||
chat_id: str = None,
|
||||
is_group: bool = True
|
||||
) -> Tuple[bool, List[Tuple[str, Any]]]:
|
||||
"""重写回复
|
||||
|
||||
Args:
|
||||
chat_stream: 聊天流对象(优先)
|
||||
action_data: 动作数据
|
||||
platform: 平台名称(备用)
|
||||
chat_id: 聊天ID(备用)
|
||||
is_group: 是否为群聊(备用)
|
||||
|
||||
Returns:
|
||||
Tuple[bool, List[Tuple[str, Any]]]: (是否成功, 回复集合)
|
||||
"""
|
||||
try:
|
||||
# 获取回复器
|
||||
replyer = get_replyer(chat_stream, platform, chat_id, is_group)
|
||||
if not replyer:
|
||||
logger.error("[GeneratorAPI] 无法获取回复器")
|
||||
return False, []
|
||||
|
||||
logger.info("[GeneratorAPI] 开始重写回复")
|
||||
|
||||
# 调用回复器重写回复
|
||||
success, reply_set = await replyer.rewrite_reply_with_context(
|
||||
reply_data=reply_data or {},
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info(f"[GeneratorAPI] 重写回复成功,生成了 {len(reply_set)} 个回复项")
|
||||
else:
|
||||
logger.warning("[GeneratorAPI] 重写回复失败")
|
||||
|
||||
return success, reply_set or []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[GeneratorAPI] 重写回复时出错: {e}")
|
||||
return False, []
|
||||
|
||||
|
||||
@@ -1,177 +0,0 @@
|
||||
from typing import Optional, List, Any, Tuple
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("hearflow_api")
|
||||
|
||||
|
||||
def _get_heartflow():
|
||||
"""获取heartflow实例的延迟导入函数"""
|
||||
from src.chat.heart_flow.heartflow import heartflow
|
||||
|
||||
return heartflow
|
||||
|
||||
|
||||
def _get_subheartflow_types():
|
||||
"""获取SubHeartflow和ChatState类型的延迟导入函数"""
|
||||
from src.chat.heart_flow.sub_heartflow import SubHeartflow, ChatState
|
||||
|
||||
return SubHeartflow, ChatState
|
||||
|
||||
|
||||
class HearflowAPI:
|
||||
"""心流API模块
|
||||
|
||||
提供与心流和子心流相关的操作接口
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.log_prefix = "[HearflowAPI]"
|
||||
|
||||
async def get_sub_hearflow_by_chat_id(self, chat_id: str) -> Optional[Any]:
|
||||
"""根据chat_id获取指定的sub_hearflow实例
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID,与sub_hearflow的subheartflow_id相同
|
||||
|
||||
Returns:
|
||||
Optional[SubHeartflow]: sub_hearflow实例,如果不存在则返回None
|
||||
"""
|
||||
# 使用延迟导入
|
||||
heartflow = _get_heartflow()
|
||||
|
||||
# 直接从subheartflow_manager获取已存在的子心流
|
||||
# 使用锁来确保线程安全
|
||||
async with heartflow.subheartflow_manager._lock:
|
||||
subflow = heartflow.subheartflow_manager.subheartflows.get(chat_id)
|
||||
if subflow and not subflow.should_stop:
|
||||
logger.debug(f"{self.log_prefix} 成功获取子心流实例: {chat_id}")
|
||||
return subflow
|
||||
else:
|
||||
logger.debug(f"{self.log_prefix} 子心流不存在或已停止: {chat_id}")
|
||||
return None
|
||||
|
||||
async def get_or_create_sub_hearflow_by_chat_id(self, chat_id: str) -> Optional[Any]:
|
||||
"""根据chat_id获取或创建sub_hearflow实例
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
|
||||
Returns:
|
||||
Optional[SubHeartflow]: sub_hearflow实例,创建失败时返回None
|
||||
"""
|
||||
heartflow = _get_heartflow()
|
||||
return await heartflow.get_or_create_subheartflow(chat_id)
|
||||
|
||||
def get_all_sub_hearflow_ids(self) -> List[str]:
|
||||
"""获取所有子心流的ID列表
|
||||
|
||||
Returns:
|
||||
List[str]: 所有子心流的ID列表
|
||||
"""
|
||||
heartflow = _get_heartflow()
|
||||
all_subflows = heartflow.subheartflow_manager.get_all_subheartflows()
|
||||
chat_ids = [subflow.chat_id for subflow in all_subflows if not subflow.should_stop]
|
||||
logger.debug(f"{self.log_prefix} 获取到 {len(chat_ids)} 个活跃的子心流ID")
|
||||
return chat_ids
|
||||
|
||||
def get_all_sub_hearflows(self) -> List[Any]:
|
||||
"""获取所有子心流实例
|
||||
|
||||
Returns:
|
||||
List[SubHeartflow]: 所有活跃的子心流实例列表
|
||||
"""
|
||||
heartflow = _get_heartflow()
|
||||
all_subflows = heartflow.subheartflow_manager.get_all_subheartflows()
|
||||
active_subflows = [subflow for subflow in all_subflows if not subflow.should_stop]
|
||||
logger.debug(f"{self.log_prefix} 获取到 {len(active_subflows)} 个活跃的子心流实例")
|
||||
return active_subflows
|
||||
|
||||
async def get_sub_hearflow_chat_state(self, chat_id: str) -> Optional[Any]:
|
||||
"""获取指定子心流的聊天状态
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
|
||||
Returns:
|
||||
Optional[ChatState]: 聊天状态,如果子心流不存在则返回None
|
||||
"""
|
||||
subflow = await self.get_sub_hearflow_by_chat_id(chat_id)
|
||||
if subflow:
|
||||
return subflow.chat_state.chat_status
|
||||
return None
|
||||
|
||||
async def set_sub_hearflow_chat_state(self, chat_id: str, target_state: Any) -> bool:
|
||||
"""设置指定子心流的聊天状态
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
target_state: 目标状态(ChatState枚举值)
|
||||
|
||||
Returns:
|
||||
bool: 是否设置成功
|
||||
"""
|
||||
heartflow = _get_heartflow()
|
||||
return await heartflow.subheartflow_manager.force_change_state(chat_id, target_state)
|
||||
|
||||
async def get_sub_hearflow_replyer_and_expressor(self, chat_id: str) -> Tuple[Optional[Any], Optional[Any]]:
|
||||
"""根据chat_id获取指定子心流的replyer和expressor实例
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
|
||||
Returns:
|
||||
Tuple[Optional[Any], Optional[Any]]: (replyer实例, expressor实例),如果子心流不存在或未处于FOCUSED状态,返回(None, None)
|
||||
"""
|
||||
subflow = await self.get_sub_hearflow_by_chat_id(chat_id)
|
||||
if not subflow:
|
||||
logger.debug(f"{self.log_prefix} 子心流不存在: {chat_id}")
|
||||
return None, None
|
||||
|
||||
# 使用延迟导入获取ChatState
|
||||
_, ChatState = _get_subheartflow_types()
|
||||
|
||||
# 检查子心流是否处于FOCUSED状态且有HeartFC实例
|
||||
if subflow.chat_state.chat_status != ChatState.FOCUSED:
|
||||
logger.debug(
|
||||
f"{self.log_prefix} 子心流 {chat_id} 未处于FOCUSED状态,当前状态: {subflow.chat_state.chat_status.value}"
|
||||
)
|
||||
return None, None
|
||||
|
||||
if not subflow.heart_fc_instance:
|
||||
logger.debug(f"{self.log_prefix} 子心流 {chat_id} 没有HeartFC实例")
|
||||
return None, None
|
||||
|
||||
# 返回replyer和expressor实例
|
||||
replyer = subflow.heart_fc_instance.replyer
|
||||
expressor = subflow.heart_fc_instance.expressor
|
||||
|
||||
if replyer and expressor:
|
||||
logger.debug(f"{self.log_prefix} 成功获取子心流 {chat_id} 的replyer和expressor")
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} 子心流 {chat_id} 的replyer或expressor为空")
|
||||
|
||||
return replyer, expressor
|
||||
|
||||
async def get_sub_hearflow_replyer(self, chat_id: str) -> Optional[Any]:
|
||||
"""根据chat_id获取指定子心流的replyer实例
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
|
||||
Returns:
|
||||
Optional[Any]: replyer实例,如果不存在则返回None
|
||||
"""
|
||||
replyer, _ = await self.get_sub_hearflow_replyer_and_expressor(chat_id)
|
||||
return replyer
|
||||
|
||||
async def get_sub_hearflow_expressor(self, chat_id: str) -> Optional[Any]:
|
||||
"""根据chat_id获取指定子心流的expressor实例
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
|
||||
Returns:
|
||||
Optional[Any]: expressor实例,如果不存在则返回None
|
||||
"""
|
||||
_, expressor = await self.get_sub_hearflow_replyer_and_expressor(chat_id)
|
||||
return expressor
|
||||
@@ -1,134 +0,0 @@
|
||||
"""
|
||||
独立API聚合模块
|
||||
|
||||
聚合了不需要Action组件依赖的API,这些API可以独立使用,不需要注入服务对象。
|
||||
包括:LLMAPI、ConfigAPI、UtilsAPI、StreamAPI、HearflowAPI等独立功能的API。
|
||||
"""
|
||||
|
||||
from src.plugin_system.apis.llm_api import LLMAPI
|
||||
from src.plugin_system.apis.config_api import ConfigAPI
|
||||
from src.plugin_system.apis.utils_api import UtilsAPI
|
||||
from src.plugin_system.apis.stream_api import StreamAPI
|
||||
from src.plugin_system.apis.hearflow_api import HearflowAPI
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("independent_apis")
|
||||
|
||||
|
||||
class IndependentAPI(LLMAPI, ConfigAPI, UtilsAPI, StreamAPI, HearflowAPI):
|
||||
"""
|
||||
独立API聚合类
|
||||
|
||||
聚合了不需要Action组件依赖的API功能。这些API的特点:
|
||||
- 不需要chat_stream、expressor等服务对象
|
||||
- 可以独立调用,不依赖Action执行上下文
|
||||
- 主要是工具类方法和配置查询方法
|
||||
|
||||
包含的API:
|
||||
- LLMAPI: LLM模型调用(仅需要全局配置)
|
||||
- ConfigAPI: 配置读取(使用全局配置)
|
||||
- UtilsAPI: 工具方法(文件操作、时间处理等)
|
||||
- StreamAPI: 聊天流查询(使用ChatManager)
|
||||
- HearflowAPI: 心流状态控制(使用heartflow)
|
||||
|
||||
使用场景:
|
||||
- 在Command组件中使用
|
||||
- 独立的工具函数调用
|
||||
- 配置查询和系统状态检查
|
||||
"""
|
||||
|
||||
def __init__(self, log_prefix: str = "[IndependentAPI]"):
|
||||
"""
|
||||
初始化独立API
|
||||
|
||||
Args:
|
||||
log_prefix: 日志前缀,用于区分不同的调用来源
|
||||
"""
|
||||
self.log_prefix = log_prefix
|
||||
|
||||
logger.debug(f"{self.log_prefix} IndependentAPI 初始化完成")
|
||||
|
||||
|
||||
# 提供便捷的静态访问方式
|
||||
class StaticAPI:
|
||||
"""
|
||||
静态API类
|
||||
|
||||
提供完全静态的API访问方式,不需要实例化,适合简单的工具调用。
|
||||
"""
|
||||
|
||||
# LLM相关
|
||||
@staticmethod
|
||||
def get_available_models():
|
||||
"""获取可用的LLM模型"""
|
||||
api = LLMAPI()
|
||||
return api.get_available_models()
|
||||
|
||||
@staticmethod
|
||||
async def generate_with_model(prompt: str, model_config: dict, **kwargs):
|
||||
"""使用LLM生成内容"""
|
||||
api = LLMAPI()
|
||||
api.log_prefix = "[StaticAPI]"
|
||||
return await api.generate_with_model(prompt, model_config, **kwargs)
|
||||
|
||||
# 配置相关
|
||||
@staticmethod
|
||||
def get_global_config(key: str, default=None):
|
||||
"""获取全局配置"""
|
||||
api = ConfigAPI()
|
||||
return api.get_global_config(key, default)
|
||||
|
||||
@staticmethod
|
||||
async def get_user_id_by_name(person_name: str):
|
||||
"""根据用户名获取用户ID"""
|
||||
api = ConfigAPI()
|
||||
return await api.get_user_id_by_person_name(person_name)
|
||||
|
||||
# 工具相关
|
||||
@staticmethod
|
||||
def get_timestamp():
|
||||
"""获取当前时间戳"""
|
||||
api = UtilsAPI()
|
||||
return api.get_timestamp()
|
||||
|
||||
@staticmethod
|
||||
def format_time(timestamp=None, format_str="%Y-%m-%d %H:%M:%S"):
|
||||
"""格式化时间"""
|
||||
api = UtilsAPI()
|
||||
return api.format_time(timestamp, format_str)
|
||||
|
||||
@staticmethod
|
||||
def generate_unique_id():
|
||||
"""生成唯一ID"""
|
||||
api = UtilsAPI()
|
||||
return api.generate_unique_id()
|
||||
|
||||
# 聊天流相关
|
||||
@staticmethod
|
||||
def get_chat_stream_by_group_id(group_id: str, platform: str = "qq"):
|
||||
"""通过群ID获取聊天流"""
|
||||
api = StreamAPI()
|
||||
api.log_prefix = "[StaticAPI]"
|
||||
return api.get_chat_stream_by_group_id(group_id, platform)
|
||||
|
||||
@staticmethod
|
||||
def get_all_group_chat_streams(platform: str = "qq"):
|
||||
"""获取所有群聊聊天流"""
|
||||
api = StreamAPI()
|
||||
api.log_prefix = "[StaticAPI]"
|
||||
return api.get_all_group_chat_streams(platform)
|
||||
|
||||
# 心流相关
|
||||
@staticmethod
|
||||
async def get_sub_hearflow_by_chat_id(chat_id: str):
|
||||
"""获取子心流"""
|
||||
api = HearflowAPI()
|
||||
api.log_prefix = "[StaticAPI]"
|
||||
return await api.get_sub_hearflow_by_chat_id(chat_id)
|
||||
|
||||
@staticmethod
|
||||
async def set_sub_hearflow_chat_state(chat_id: str, target_state):
|
||||
"""设置子心流状态"""
|
||||
api = HearflowAPI()
|
||||
api.log_prefix = "[StaticAPI]"
|
||||
return await api.set_sub_hearflow_chat_state(chat_id, target_state)
|
||||
@@ -1,3 +1,12 @@
|
||||
"""LLM API模块
|
||||
|
||||
提供了与LLM模型交互的功能
|
||||
使用方式:
|
||||
from src.plugin_system.apis import llm_api
|
||||
models = llm_api.get_available_models()
|
||||
success, response, reasoning, model_name = await llm_api.generate_with_model(prompt, model_config)
|
||||
"""
|
||||
|
||||
from typing import Tuple, Dict, Any
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
@@ -6,49 +15,51 @@ from src.config.config import global_config
|
||||
logger = get_logger("llm_api")
|
||||
|
||||
|
||||
class LLMAPI:
|
||||
"""LLM API模块
|
||||
# =============================================================================
|
||||
# LLM模型API函数
|
||||
# =============================================================================
|
||||
|
||||
提供了与LLM模型交互的功能
|
||||
def get_available_models() -> Dict[str, Any]:
|
||||
"""获取所有可用的模型配置
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 模型配置字典,key为模型名称,value为模型配置
|
||||
"""
|
||||
|
||||
def get_available_models(self) -> Dict[str, Any]:
|
||||
"""获取所有可用的模型配置
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 模型配置字典,key为模型名称,value为模型配置
|
||||
"""
|
||||
try:
|
||||
if not hasattr(global_config, "model"):
|
||||
logger.error(f"{self.log_prefix} 无法获取模型列表:全局配置中未找到 model 配置")
|
||||
logger.error("[LLMAPI] 无法获取模型列表:全局配置中未找到 model 配置")
|
||||
return {}
|
||||
|
||||
models = global_config.model
|
||||
|
||||
return models
|
||||
except Exception as e:
|
||||
logger.error(f"[LLMAPI] 获取可用模型失败: {e}")
|
||||
return {}
|
||||
|
||||
async def generate_with_model(
|
||||
self, prompt: str, model_config: Dict[str, Any], request_type: str = "plugin.generate", **kwargs
|
||||
) -> Tuple[bool, str, str, str]:
|
||||
"""使用指定模型生成内容
|
||||
|
||||
Args:
|
||||
prompt: 提示词
|
||||
model_config: 模型配置(从 get_available_models 获取的模型配置)
|
||||
request_type: 请求类型标识
|
||||
**kwargs: 其他模型特定参数,如temperature、max_tokens等
|
||||
async def generate_with_model(
|
||||
prompt: str, model_config: Dict[str, Any], request_type: str = "plugin.generate", **kwargs
|
||||
) -> Tuple[bool, str, str, str]:
|
||||
"""使用指定模型生成内容
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str, str, str]: (是否成功, 生成的内容, 推理过程, 模型名称)
|
||||
"""
|
||||
try:
|
||||
logger.info(f"{self.log_prefix} 使用模型生成内容,提示词: {prompt[:100]}...")
|
||||
Args:
|
||||
prompt: 提示词
|
||||
model_config: 模型配置(从 get_available_models 获取的模型配置)
|
||||
request_type: 请求类型标识
|
||||
**kwargs: 其他模型特定参数,如temperature、max_tokens等
|
||||
|
||||
llm_request = LLMRequest(model=model_config, request_type=request_type, **kwargs)
|
||||
Returns:
|
||||
Tuple[bool, str, str, str]: (是否成功, 生成的内容, 推理过程, 模型名称)
|
||||
"""
|
||||
try:
|
||||
logger.info(f"[LLMAPI] 使用模型生成内容,提示词: {prompt[:100]}...")
|
||||
|
||||
response, (reasoning, model_name) = await llm_request.generate_response_async(prompt)
|
||||
return True, response, reasoning, model_name
|
||||
llm_request = LLMRequest(model=model_config, request_type=request_type, **kwargs)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"生成内容时出错: {str(e)}"
|
||||
logger.error(f"{self.log_prefix} {error_msg}")
|
||||
return False, error_msg, "", ""
|
||||
response, (reasoning, model_name) = await llm_request.generate_response_async(prompt)
|
||||
return True, response, reasoning, model_name
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"生成内容时出错: {str(e)}"
|
||||
logger.error(f"[LLMAPI] {error_msg}")
|
||||
return False, error_msg, "", ""
|
||||
|
||||
@@ -1,202 +1,329 @@
|
||||
import traceback
|
||||
"""
|
||||
消息API模块
|
||||
|
||||
提供消息查询和构建成字符串的功能,采用标准Python包设计模式
|
||||
使用方式:
|
||||
from src.plugin_system.apis import message_api
|
||||
messages = message_api.get_messages_by_time_in_chat(chat_id, start_time, end_time)
|
||||
readable_text = message_api.build_readable_messages(messages)
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Tuple, Optional
|
||||
import time
|
||||
from typing import List, Dict, Any
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.focus_chat.hfc_utils import create_empty_anchor_message
|
||||
|
||||
# 以下为类型注解需要
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.chat.focus_chat.info.obs_info import ObsInfo
|
||||
|
||||
# 新增导入
|
||||
from src.chat.focus_chat.heartFC_sender import HeartFCSender
|
||||
from src.chat.message_receive.message import MessageSending
|
||||
from maim_message import Seg, UserInfo
|
||||
from src.config.config import global_config
|
||||
|
||||
logger = get_logger("message_api")
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
get_raw_msg_by_timestamp,
|
||||
get_raw_msg_by_timestamp_with_chat,
|
||||
get_raw_msg_by_timestamp_with_chat_inclusive,
|
||||
get_raw_msg_by_timestamp_with_chat_users,
|
||||
get_raw_msg_by_timestamp_random,
|
||||
get_raw_msg_by_timestamp_with_users,
|
||||
get_raw_msg_before_timestamp,
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
get_raw_msg_before_timestamp_with_users,
|
||||
num_new_messages_since,
|
||||
num_new_messages_since_with_users,
|
||||
build_readable_messages,
|
||||
build_readable_messages_with_list,
|
||||
get_person_id_list,
|
||||
)
|
||||
|
||||
|
||||
class MessageAPI:
|
||||
"""消息API模块
|
||||
# =============================================================================
|
||||
# 消息查询API函数
|
||||
# =============================================================================
|
||||
|
||||
提供了发送消息、获取消息历史等功能
|
||||
def get_messages_by_time(
|
||||
start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest"
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定时间范围内的消息
|
||||
|
||||
Args:
|
||||
start_time: 开始时间戳
|
||||
end_time: 结束时间戳
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
limit_mode: 当limit>0时生效,'earliest'表示获取最早的记录,'latest'表示获取最新的记录
|
||||
|
||||
Returns:
|
||||
消息列表
|
||||
"""
|
||||
return get_raw_msg_by_timestamp(start_time, end_time, limit, limit_mode)
|
||||
|
||||
async def send_message_to_target(
|
||||
self,
|
||||
message_type: str,
|
||||
content: str,
|
||||
platform: str,
|
||||
target_id: str,
|
||||
is_group: bool = True,
|
||||
display_message: str = "",
|
||||
typing: bool = False,
|
||||
) -> bool:
|
||||
"""直接向指定目标发送消息
|
||||
|
||||
Args:
|
||||
message_type: 消息类型,如"text"、"image"、"emoji"等
|
||||
content: 消息内容
|
||||
platform: 目标平台,如"qq"
|
||||
target_id: 目标ID(群ID或用户ID)
|
||||
is_group: 是否为群聊,True为群聊,False为私聊
|
||||
display_message: 显示消息(可选)
|
||||
def get_messages_by_time_in_chat(
|
||||
chat_id: str, start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest"
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定聊天中指定时间范围内的消息
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
start_time: 开始时间戳
|
||||
end_time: 结束时间戳
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
limit_mode: 当limit>0时生效,'earliest'表示获取最早的记录,'latest'表示获取最新的记录
|
||||
|
||||
Returns:
|
||||
消息列表
|
||||
"""
|
||||
return get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time, limit, limit_mode)
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
try:
|
||||
# 构建目标聊天流ID
|
||||
if is_group:
|
||||
# 群聊:从数据库查找对应的聊天流
|
||||
target_stream = None
|
||||
for _, stream in get_chat_manager().streams.items():
|
||||
if (
|
||||
stream.group_info
|
||||
and str(stream.group_info.group_id) == str(target_id)
|
||||
and stream.platform == platform
|
||||
):
|
||||
target_stream = stream
|
||||
break
|
||||
|
||||
if not target_stream:
|
||||
logger.error(f"{getattr(self, 'log_prefix', '')} 未找到群ID为 {target_id} 的聊天流")
|
||||
return False
|
||||
else:
|
||||
# 私聊:从数据库查找对应的聊天流
|
||||
target_stream = None
|
||||
for _, stream in get_chat_manager().streams.items():
|
||||
if (
|
||||
not stream.group_info
|
||||
and str(stream.user_info.user_id) == str(target_id)
|
||||
and stream.platform == platform
|
||||
):
|
||||
target_stream = stream
|
||||
break
|
||||
def get_messages_by_time_in_chat_inclusive(
|
||||
chat_id: str, start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest"
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定聊天中指定时间范围内的消息(包含边界)
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
start_time: 开始时间戳(包含)
|
||||
end_time: 结束时间戳(包含)
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
limit_mode: 当limit>0时生效,'earliest'表示获取最早的记录,'latest'表示获取最新的记录
|
||||
|
||||
Returns:
|
||||
消息列表
|
||||
"""
|
||||
return get_raw_msg_by_timestamp_with_chat_inclusive(chat_id, start_time, end_time, limit, limit_mode)
|
||||
|
||||
if not target_stream:
|
||||
logger.error(f"{getattr(self, 'log_prefix', '')} 未找到用户ID为 {target_id} 的私聊流")
|
||||
return False
|
||||
|
||||
# 创建HeartFCSender实例
|
||||
heart_fc_sender = HeartFCSender()
|
||||
def get_messages_by_time_in_chat_for_users(
|
||||
chat_id: str,
|
||||
start_time: float,
|
||||
end_time: float,
|
||||
person_ids: list,
|
||||
limit: int = 0,
|
||||
limit_mode: str = "latest",
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定聊天中指定用户在指定时间范围内的消息
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
start_time: 开始时间戳
|
||||
end_time: 结束时间戳
|
||||
person_ids: 用户ID列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
limit_mode: 当limit>0时生效,'earliest'表示获取最早的记录,'latest'表示获取最新的记录
|
||||
|
||||
Returns:
|
||||
消息列表
|
||||
"""
|
||||
return get_raw_msg_by_timestamp_with_chat_users(chat_id, start_time, end_time, person_ids, limit, limit_mode)
|
||||
|
||||
# 生成消息ID和thinking_id
|
||||
current_time = time.time()
|
||||
message_id = f"plugin_msg_{int(current_time * 1000)}"
|
||||
|
||||
# 构建机器人用户信息
|
||||
bot_user_info = UserInfo(
|
||||
user_id=global_config.bot.qq_account,
|
||||
user_nickname=global_config.bot.nickname,
|
||||
platform=platform,
|
||||
)
|
||||
def get_random_chat_messages(
|
||||
start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest"
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
随机选择一个聊天,返回该聊天在指定时间范围内的消息
|
||||
|
||||
Args:
|
||||
start_time: 开始时间戳
|
||||
end_time: 结束时间戳
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
limit_mode: 当limit>0时生效,'earliest'表示获取最早的记录,'latest'表示获取最新的记录
|
||||
|
||||
Returns:
|
||||
消息列表
|
||||
"""
|
||||
return get_raw_msg_by_timestamp_random(start_time, end_time, limit, limit_mode)
|
||||
|
||||
# 创建消息段
|
||||
message_segment = Seg(type=message_type, data=content)
|
||||
|
||||
# 创建空锚点消息(用于回复)
|
||||
anchor_message = await create_empty_anchor_message(platform, target_stream.group_info, target_stream)
|
||||
def get_messages_by_time_for_users(
|
||||
start_time: float, end_time: float, person_ids: list, limit: int = 0, limit_mode: str = "latest"
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定用户在所有聊天中指定时间范围内的消息
|
||||
|
||||
Args:
|
||||
start_time: 开始时间戳
|
||||
end_time: 结束时间戳
|
||||
person_ids: 用户ID列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
limit_mode: 当limit>0时生效,'earliest'表示获取最早的记录,'latest'表示获取最新的记录
|
||||
|
||||
Returns:
|
||||
消息列表
|
||||
"""
|
||||
return get_raw_msg_by_timestamp_with_users(start_time, end_time, person_ids, limit, limit_mode)
|
||||
|
||||
# 构建发送消息对象
|
||||
bot_message = MessageSending(
|
||||
message_id=message_id,
|
||||
chat_stream=target_stream,
|
||||
bot_user_info=bot_user_info,
|
||||
sender_info=target_stream.user_info, # 目标用户信息
|
||||
message_segment=message_segment,
|
||||
display_message=display_message,
|
||||
reply=anchor_message,
|
||||
is_head=True,
|
||||
is_emoji=(message_type == "emoji"),
|
||||
thinking_start_time=current_time,
|
||||
)
|
||||
|
||||
# 发送消息
|
||||
sent_msg = await heart_fc_sender.send_message(
|
||||
bot_message, has_thinking=False, typing=typing, set_reply=False
|
||||
)
|
||||
def get_messages_before_time(timestamp: float, limit: int = 0) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定时间戳之前的消息
|
||||
|
||||
Args:
|
||||
timestamp: 时间戳
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
|
||||
Returns:
|
||||
消息列表
|
||||
"""
|
||||
return get_raw_msg_before_timestamp(timestamp, limit)
|
||||
|
||||
if sent_msg:
|
||||
logger.info(f"{getattr(self, 'log_prefix', '')} 成功发送消息到 {platform}:{target_id}")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"{getattr(self, 'log_prefix', '')} 发送消息失败")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{getattr(self, 'log_prefix', '')} 向目标发送消息时出错: {e}")
|
||||
traceback.print_exc()
|
||||
return False
|
||||
def get_messages_before_time_in_chat(chat_id: str, timestamp: float, limit: int = 0) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定聊天中指定时间戳之前的消息
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
timestamp: 时间戳
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
|
||||
Returns:
|
||||
消息列表
|
||||
"""
|
||||
return get_raw_msg_before_timestamp_with_chat(chat_id, timestamp, limit)
|
||||
|
||||
async def send_text_to_group(self, text: str, group_id: str, platform: str = "qq") -> bool:
|
||||
"""便捷方法:向指定群聊发送文本消息
|
||||
|
||||
Args:
|
||||
text: 要发送的文本内容
|
||||
group_id: 群聊ID
|
||||
platform: 平台,默认为"qq"
|
||||
def get_messages_before_time_for_users(
|
||||
timestamp: float, person_ids: list, limit: int = 0
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定用户在指定时间戳之前的消息
|
||||
|
||||
Args:
|
||||
timestamp: 时间戳
|
||||
person_ids: 用户ID列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
|
||||
Returns:
|
||||
消息列表
|
||||
"""
|
||||
return get_raw_msg_before_timestamp_with_users(timestamp, person_ids, limit)
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
return await self.send_message_to_target(
|
||||
message_type="text", content=text, platform=platform, target_id=group_id, is_group=True
|
||||
)
|
||||
|
||||
async def send_text_to_user(self, text: str, user_id: str, platform: str = "qq") -> bool:
|
||||
"""便捷方法:向指定用户发送私聊文本消息
|
||||
def get_recent_messages(
|
||||
chat_id: str,
|
||||
hours: float = 24.0,
|
||||
limit: int = 100,
|
||||
limit_mode: str = "latest"
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定聊天中最近一段时间的消息
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
hours: 最近多少小时,默认24小时
|
||||
limit: 限制返回的消息数量,默认100条
|
||||
limit_mode: 当limit>0时生效,'earliest'表示获取最早的记录,'latest'表示获取最新的记录
|
||||
|
||||
Returns:
|
||||
消息列表
|
||||
"""
|
||||
now = time.time()
|
||||
start_time = now - hours * 3600
|
||||
return get_raw_msg_by_timestamp_with_chat(chat_id, start_time, now, limit, limit_mode)
|
||||
|
||||
Args:
|
||||
text: 要发送的文本内容
|
||||
user_id: 用户ID
|
||||
platform: 平台,默认为"qq"
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
return await self.send_message_to_target(
|
||||
message_type="text", content=text, platform=platform, target_id=user_id, is_group=False
|
||||
)
|
||||
# =============================================================================
|
||||
# 消息计数API函数
|
||||
# =============================================================================
|
||||
|
||||
def get_chat_type(self) -> str:
|
||||
"""获取当前聊天类型
|
||||
def count_new_messages(
|
||||
chat_id: str, start_time: float = 0.0, end_time: Optional[float] = None
|
||||
) -> int:
|
||||
"""
|
||||
计算指定聊天中从开始时间到结束时间的新消息数量
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
start_time: 开始时间戳
|
||||
end_time: 结束时间戳,如果为None则使用当前时间
|
||||
|
||||
Returns:
|
||||
新消息数量
|
||||
"""
|
||||
return num_new_messages_since(chat_id, start_time, end_time)
|
||||
|
||||
Returns:
|
||||
str: 聊天类型 ("group" 或 "private")
|
||||
"""
|
||||
services = getattr(self, "_services", {})
|
||||
chat_stream: ChatStream = services.get("chat_stream")
|
||||
if chat_stream and hasattr(chat_stream, "group_info"):
|
||||
return "group" if chat_stream.group_info else "private"
|
||||
return "unknown"
|
||||
|
||||
def get_recent_messages(self, count: int = 5) -> List[Dict[str, Any]]:
|
||||
"""获取最近的消息
|
||||
def count_new_messages_for_users(
|
||||
chat_id: str, start_time: float, end_time: float, person_ids: list
|
||||
) -> int:
|
||||
"""
|
||||
计算指定聊天中指定用户从开始时间到结束时间的新消息数量
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
start_time: 开始时间戳
|
||||
end_time: 结束时间戳
|
||||
person_ids: 用户ID列表
|
||||
|
||||
Returns:
|
||||
新消息数量
|
||||
"""
|
||||
return num_new_messages_since_with_users(chat_id, start_time, end_time, person_ids)
|
||||
|
||||
Args:
|
||||
count: 要获取的消息数量
|
||||
|
||||
Returns:
|
||||
List[Dict]: 消息列表,每个消息包含发送者、内容等信息
|
||||
"""
|
||||
messages = []
|
||||
services = getattr(self, "_services", {})
|
||||
observations = services.get("observations", [])
|
||||
# =============================================================================
|
||||
# 消息格式化API函数
|
||||
# =============================================================================
|
||||
|
||||
if observations and len(observations) > 0:
|
||||
obs = observations[0]
|
||||
if hasattr(obs, "get_talking_message"):
|
||||
obs: ObsInfo
|
||||
raw_messages = obs.get_talking_message()
|
||||
# 转换为简化格式
|
||||
for msg in raw_messages[-count:]:
|
||||
simple_msg = {
|
||||
"sender": msg.get("sender", "未知"),
|
||||
"content": msg.get("content", ""),
|
||||
"timestamp": msg.get("timestamp", 0),
|
||||
}
|
||||
messages.append(simple_msg)
|
||||
def build_readable_messages_to_str(
|
||||
messages: List[Dict[str, Any]],
|
||||
replace_bot_name: bool = True,
|
||||
merge_messages: bool = False,
|
||||
timestamp_mode: str = "relative",
|
||||
read_mark: float = 0.0,
|
||||
truncate: bool = False,
|
||||
show_actions: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
将消息列表构建成可读的字符串
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
replace_bot_name: 是否将机器人的名称替换为"你"
|
||||
merge_messages: 是否合并连续消息
|
||||
timestamp_mode: 时间戳显示模式,'relative'或'absolute'
|
||||
read_mark: 已读标记时间戳,用于分割已读和未读消息
|
||||
truncate: 是否截断长消息
|
||||
show_actions: 是否显示动作记录
|
||||
|
||||
Returns:
|
||||
格式化后的可读字符串
|
||||
"""
|
||||
return build_readable_messages(
|
||||
messages, replace_bot_name, merge_messages, timestamp_mode, read_mark, truncate, show_actions
|
||||
)
|
||||
|
||||
return messages
|
||||
|
||||
async def build_readable_messages_with_details(
|
||||
messages: List[Dict[str, Any]],
|
||||
replace_bot_name: bool = True,
|
||||
merge_messages: bool = False,
|
||||
timestamp_mode: str = "relative",
|
||||
truncate: bool = False,
|
||||
) -> Tuple[str, List[Tuple[float, str, str]]]:
|
||||
"""
|
||||
将消息列表构建成可读的字符串,并返回详细信息
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
replace_bot_name: 是否将机器人的名称替换为"你"
|
||||
merge_messages: 是否合并连续消息
|
||||
timestamp_mode: 时间戳显示模式,'relative'或'absolute'
|
||||
truncate: 是否截断长消息
|
||||
|
||||
Returns:
|
||||
格式化后的可读字符串和详细信息元组列表(时间戳, 昵称, 内容)
|
||||
"""
|
||||
return await build_readable_messages_with_list(
|
||||
messages, replace_bot_name, merge_messages, timestamp_mode, truncate
|
||||
)
|
||||
|
||||
|
||||
async def get_person_ids_from_messages(messages: List[Dict[str, Any]]) -> List[str]:
|
||||
"""
|
||||
从消息列表中提取不重复的用户ID列表
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
|
||||
Returns:
|
||||
用户ID列表
|
||||
"""
|
||||
return await get_person_id_list(messages)
|
||||
|
||||
153
src/plugin_system/apis/person_api.py
Normal file
153
src/plugin_system/apis/person_api.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""个人信息API模块
|
||||
|
||||
提供个人信息查询功能,用于插件获取用户相关信息
|
||||
使用方式:
|
||||
from src.plugin_system.apis import person_api
|
||||
person_id = person_api.get_person_id("qq", 123456)
|
||||
value = await person_api.get_person_value(person_id, "nickname")
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from src.common.logger import get_logger
|
||||
from src.person_info.person_info import get_person_info_manager, PersonInfoManager
|
||||
|
||||
logger = get_logger("person_api")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 个人信息API函数
|
||||
# =============================================================================
|
||||
|
||||
def get_person_id(platform: str, user_id: int) -> str:
|
||||
"""根据平台和用户ID获取person_id
|
||||
|
||||
Args:
|
||||
platform: 平台名称,如 "qq", "telegram" 等
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
str: 唯一的person_id(MD5哈希值)
|
||||
|
||||
示例:
|
||||
person_id = person_api.get_person_id("qq", 123456)
|
||||
"""
|
||||
try:
|
||||
return PersonInfoManager.get_person_id(platform, user_id)
|
||||
except Exception as e:
|
||||
logger.error(f"[PersonAPI] 获取person_id失败: platform={platform}, user_id={user_id}, error={e}")
|
||||
return ""
|
||||
|
||||
|
||||
async def get_person_value(person_id: str, field_name: str, default: Any = None) -> Any:
|
||||
"""根据person_id和字段名获取某个值
|
||||
|
||||
Args:
|
||||
person_id: 用户的唯一标识ID
|
||||
field_name: 要获取的字段名,如 "nickname", "impression" 等
|
||||
default: 当字段不存在或获取失败时返回的默认值
|
||||
|
||||
Returns:
|
||||
Any: 字段值或默认值
|
||||
|
||||
示例:
|
||||
nickname = await person_api.get_person_value(person_id, "nickname", "未知用户")
|
||||
impression = await person_api.get_person_value(person_id, "impression")
|
||||
"""
|
||||
try:
|
||||
person_info_manager = get_person_info_manager()
|
||||
value = await person_info_manager.get_value(person_id, field_name)
|
||||
return value if value is not None else default
|
||||
except Exception as e:
|
||||
logger.error(f"[PersonAPI] 获取用户信息失败: person_id={person_id}, field={field_name}, error={e}")
|
||||
return default
|
||||
|
||||
|
||||
async def get_person_values(person_id: str, field_names: list, default_dict: dict = None) -> dict:
|
||||
"""批量获取用户信息字段值
|
||||
|
||||
Args:
|
||||
person_id: 用户的唯一标识ID
|
||||
field_names: 要获取的字段名列表
|
||||
default_dict: 默认值字典,键为字段名,值为默认值
|
||||
|
||||
Returns:
|
||||
dict: 字段名到值的映射字典
|
||||
|
||||
示例:
|
||||
values = await person_api.get_person_values(
|
||||
person_id,
|
||||
["nickname", "impression", "know_times"],
|
||||
{"nickname": "未知用户", "know_times": 0}
|
||||
)
|
||||
"""
|
||||
try:
|
||||
person_info_manager = get_person_info_manager()
|
||||
values = await person_info_manager.get_values(person_id, field_names)
|
||||
|
||||
# 如果获取成功,返回结果
|
||||
if values:
|
||||
return values
|
||||
|
||||
# 如果获取失败,构建默认值字典
|
||||
result = {}
|
||||
if default_dict:
|
||||
for field in field_names:
|
||||
result[field] = default_dict.get(field, None)
|
||||
else:
|
||||
for field in field_names:
|
||||
result[field] = None
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[PersonAPI] 批量获取用户信息失败: person_id={person_id}, fields={field_names}, error={e}")
|
||||
# 返回默认值字典
|
||||
result = {}
|
||||
if default_dict:
|
||||
for field in field_names:
|
||||
result[field] = default_dict.get(field, None)
|
||||
else:
|
||||
for field in field_names:
|
||||
result[field] = None
|
||||
return result
|
||||
|
||||
|
||||
async def is_person_known(platform: str, user_id: int) -> bool:
|
||||
"""判断是否认识某个用户
|
||||
|
||||
Args:
|
||||
platform: 平台名称
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
bool: 是否认识该用户
|
||||
|
||||
示例:
|
||||
known = await person_api.is_person_known("qq", 123456)
|
||||
"""
|
||||
try:
|
||||
person_info_manager = get_person_info_manager()
|
||||
return await person_info_manager.is_person_known(platform, user_id)
|
||||
except Exception as e:
|
||||
logger.error(f"[PersonAPI] 检查用户是否已知失败: platform={platform}, user_id={user_id}, error={e}")
|
||||
return False
|
||||
|
||||
|
||||
def get_person_id_by_name(person_name: str) -> str:
|
||||
"""根据用户名获取person_id
|
||||
|
||||
Args:
|
||||
person_name: 用户名
|
||||
|
||||
Returns:
|
||||
str: person_id,如果未找到返回空字符串
|
||||
|
||||
示例:
|
||||
person_id = person_api.get_person_id_by_name("张三")
|
||||
"""
|
||||
try:
|
||||
person_info_manager = get_person_info_manager()
|
||||
return person_info_manager.get_person_id_by_person_name(person_name)
|
||||
except Exception as e:
|
||||
logger.error(f"[PersonAPI] 根据用户名获取person_id失败: person_name={person_name}, error={e}")
|
||||
return ""
|
||||
@@ -1,234 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
统一的插件API聚合模块
|
||||
|
||||
提供所有插件API功能的统一访问入口
|
||||
"""
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
# 导入所有API模块
|
||||
from src.plugin_system.apis.message_api import MessageAPI
|
||||
from src.plugin_system.apis.llm_api import LLMAPI
|
||||
from src.plugin_system.apis.database_api import DatabaseAPI
|
||||
from src.plugin_system.apis.config_api import ConfigAPI
|
||||
from src.plugin_system.apis.utils_api import UtilsAPI
|
||||
from src.plugin_system.apis.stream_api import StreamAPI
|
||||
from src.plugin_system.apis.hearflow_api import HearflowAPI
|
||||
|
||||
logger = get_logger("plugin_api")
|
||||
|
||||
|
||||
class PluginAPI(MessageAPI, LLMAPI, DatabaseAPI, ConfigAPI, UtilsAPI, StreamAPI, HearflowAPI):
|
||||
"""
|
||||
插件API聚合类
|
||||
|
||||
集成了所有可供插件使用的API功能,提供统一的访问接口。
|
||||
插件组件可以直接使用此API实例来访问各种功能。
|
||||
|
||||
特性:
|
||||
- 聚合所有API模块的功能
|
||||
- 支持依赖注入和配置
|
||||
- 提供统一的错误处理和日志记录
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chat_stream=None,
|
||||
expressor=None,
|
||||
replyer=None,
|
||||
observations=None,
|
||||
log_prefix: str = "[PluginAPI]",
|
||||
plugin_config: dict = None,
|
||||
):
|
||||
"""
|
||||
初始化插件API
|
||||
|
||||
Args:
|
||||
chat_stream: 聊天流对象
|
||||
expressor: 表达器对象
|
||||
replyer: 回复器对象
|
||||
observations: 观察列表
|
||||
log_prefix: 日志前缀
|
||||
plugin_config: 插件配置字典
|
||||
"""
|
||||
# 存储依赖对象
|
||||
self._services = {
|
||||
"chat_stream": chat_stream,
|
||||
"expressor": expressor,
|
||||
"replyer": replyer,
|
||||
"observations": observations or [],
|
||||
}
|
||||
|
||||
self.log_prefix = log_prefix
|
||||
|
||||
# 存储action上下文信息
|
||||
self._action_context = {}
|
||||
|
||||
# 调用所有父类的初始化
|
||||
super().__init__()
|
||||
|
||||
# 存储插件配置
|
||||
self._plugin_config = plugin_config or {}
|
||||
|
||||
def set_chat_stream(self, chat_stream):
|
||||
"""设置聊天流对象"""
|
||||
self._services["chat_stream"] = chat_stream
|
||||
logger.debug(f"{self.log_prefix} 设置聊天流: {getattr(chat_stream, 'stream_id', 'Unknown')}")
|
||||
|
||||
def set_expressor(self, expressor):
|
||||
"""设置表达器对象"""
|
||||
self._services["expressor"] = expressor
|
||||
logger.debug(f"{self.log_prefix} 设置表达器")
|
||||
|
||||
def set_replyer(self, replyer):
|
||||
"""设置回复器对象"""
|
||||
self._services["replyer"] = replyer
|
||||
logger.debug(f"{self.log_prefix} 设置回复器")
|
||||
|
||||
def set_observations(self, observations):
|
||||
"""设置观察列表"""
|
||||
self._services["observations"] = observations or []
|
||||
logger.debug(f"{self.log_prefix} 设置观察列表,数量: {len(observations or [])}")
|
||||
|
||||
def get_service(self, service_name: str):
|
||||
"""获取指定的服务对象"""
|
||||
return self._services.get(service_name)
|
||||
|
||||
def has_service(self, service_name: str) -> bool:
|
||||
"""检查是否有指定的服务对象"""
|
||||
return service_name in self._services and self._services[service_name] is not None
|
||||
|
||||
def set_action_context(self, thinking_id: str = None, shutting_down: bool = False, **kwargs):
|
||||
"""设置action上下文信息"""
|
||||
if thinking_id:
|
||||
self._action_context["thinking_id"] = thinking_id
|
||||
self._action_context["shutting_down"] = shutting_down
|
||||
self._action_context.update(kwargs)
|
||||
|
||||
def get_action_context(self, key: str, default=None):
|
||||
"""获取action上下文信息"""
|
||||
return self._action_context.get(key, default)
|
||||
|
||||
def get_config(self, key: str, default=None):
|
||||
"""获取插件配置值,支持嵌套键访问
|
||||
|
||||
Args:
|
||||
key: 配置键名,支持嵌套访问如 "section.subsection.key"
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
Any: 配置值或默认值
|
||||
"""
|
||||
if not self._plugin_config:
|
||||
return default
|
||||
|
||||
# 支持嵌套键访问
|
||||
keys = key.split(".")
|
||||
current = self._plugin_config
|
||||
|
||||
for k in keys:
|
||||
if isinstance(current, dict) and k in current:
|
||||
current = current[k]
|
||||
else:
|
||||
return default
|
||||
|
||||
return current
|
||||
|
||||
def has_config(self, key: str) -> bool:
|
||||
"""检查是否存在指定的配置项
|
||||
|
||||
Args:
|
||||
key: 配置键名,支持嵌套访问如 "section.subsection.key"
|
||||
|
||||
Returns:
|
||||
bool: 是否存在该配置项
|
||||
"""
|
||||
if not self._plugin_config:
|
||||
return False
|
||||
|
||||
keys = key.split(".")
|
||||
current = self._plugin_config
|
||||
|
||||
for k in keys:
|
||||
if isinstance(current, dict) and k in current:
|
||||
current = current[k]
|
||||
else:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def get_all_config(self) -> dict:
|
||||
"""获取所有插件配置
|
||||
|
||||
Returns:
|
||||
dict: 插件配置字典的副本
|
||||
"""
|
||||
return self._plugin_config.copy() if self._plugin_config else {}
|
||||
|
||||
|
||||
# 便捷的工厂函数
|
||||
def create_plugin_api(
|
||||
chat_stream=None,
|
||||
expressor=None,
|
||||
replyer=None,
|
||||
observations=None,
|
||||
log_prefix: str = "[Plugin]",
|
||||
plugin_config: dict = None,
|
||||
) -> PluginAPI:
|
||||
"""
|
||||
创建插件API实例的便捷函数
|
||||
|
||||
Args:
|
||||
chat_stream: 聊天流对象
|
||||
expressor: 表达器对象
|
||||
replyer: 回复器对象
|
||||
observations: 观察列表
|
||||
log_prefix: 日志前缀
|
||||
plugin_config: 插件配置字典
|
||||
|
||||
Returns:
|
||||
PluginAPI: 配置好的插件API实例
|
||||
"""
|
||||
return PluginAPI(
|
||||
chat_stream=chat_stream,
|
||||
expressor=expressor,
|
||||
replyer=replyer,
|
||||
observations=observations,
|
||||
log_prefix=log_prefix,
|
||||
plugin_config=plugin_config,
|
||||
)
|
||||
|
||||
|
||||
def create_command_api(message, log_prefix: str = "[Command]") -> PluginAPI:
|
||||
"""
|
||||
为命令创建插件API实例的便捷函数
|
||||
|
||||
Args:
|
||||
message: 消息对象,应该包含 chat_stream 等信息
|
||||
log_prefix: 日志前缀
|
||||
|
||||
Returns:
|
||||
PluginAPI: 配置好的插件API实例
|
||||
"""
|
||||
chat_stream = getattr(message, "chat_stream", None)
|
||||
|
||||
api = PluginAPI(chat_stream=chat_stream, log_prefix=log_prefix)
|
||||
|
||||
return api
|
||||
|
||||
|
||||
# 导出主要接口
|
||||
__all__ = [
|
||||
"PluginAPI",
|
||||
"create_plugin_api",
|
||||
"create_command_api",
|
||||
# 也可以导出各个API类供单独使用
|
||||
"MessageAPI",
|
||||
"LLMAPI",
|
||||
"DatabaseAPI",
|
||||
"ConfigAPI",
|
||||
"UtilsAPI",
|
||||
"StreamAPI",
|
||||
"HearflowAPI",
|
||||
]
|
||||
445
src/plugin_system/apis/send_api.py
Normal file
445
src/plugin_system/apis/send_api.py
Normal file
@@ -0,0 +1,445 @@
|
||||
"""
|
||||
发送API模块
|
||||
|
||||
专门负责发送各种类型的消息,采用标准Python包设计模式
|
||||
使用方式:
|
||||
from src.plugin_system.apis import send_api
|
||||
await send_api.text_to_group("hello", "123456")
|
||||
await send_api.emoji_to_group(emoji_base64, "123456")
|
||||
await send_api.custom_message("video", video_data, "123456", True)
|
||||
"""
|
||||
|
||||
import traceback
|
||||
import time
|
||||
import difflib
|
||||
from typing import Optional
|
||||
from src.common.logger import get_logger
|
||||
|
||||
# 导入依赖
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.focus_chat.heartFC_sender import HeartFCSender
|
||||
from src.chat.message_receive.message import MessageSending, MessageRecv
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
from maim_message import Seg, UserInfo
|
||||
from src.config.config import global_config
|
||||
|
||||
logger = get_logger("send_api")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 内部实现函数(不暴露给外部)
|
||||
# =============================================================================
|
||||
|
||||
async def _send_to_target(
|
||||
message_type: str,
|
||||
content: str,
|
||||
stream_id: str,
|
||||
display_message: str = "",
|
||||
typing: bool = False,
|
||||
reply_to: str = "",
|
||||
storage_message: bool = True,
|
||||
) -> bool:
|
||||
"""向指定目标发送消息的内部实现
|
||||
|
||||
Args:
|
||||
message_type: 消息类型,如"text"、"image"、"emoji"等
|
||||
content: 消息内容
|
||||
stream_id: 目标流ID
|
||||
display_message: 显示消息
|
||||
typing: 是否显示正在输入
|
||||
reply_to: 回复消息的格式,如"发送者:消息内容"
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
try:
|
||||
logger.info(f"[SendAPI] 发送{message_type}消息到 {stream_id}")
|
||||
|
||||
# 查找目标聊天流
|
||||
target_stream = get_chat_manager().get_stream(stream_id)
|
||||
if not target_stream:
|
||||
logger.error(f"[SendAPI] 未找到聊天流: {stream_id}")
|
||||
return False
|
||||
|
||||
# 创建发送器
|
||||
heart_fc_sender = HeartFCSender()
|
||||
|
||||
# 生成消息ID
|
||||
current_time = time.time()
|
||||
message_id = f"send_api_{int(current_time * 1000)}"
|
||||
|
||||
# 构建机器人用户信息
|
||||
bot_user_info = UserInfo(
|
||||
user_id=global_config.bot.qq_account,
|
||||
user_nickname=global_config.bot.nickname,
|
||||
platform=target_stream.platform,
|
||||
)
|
||||
|
||||
# 创建消息段
|
||||
message_segment = Seg(type=message_type, data=content)
|
||||
|
||||
# 处理回复消息
|
||||
anchor_message = None
|
||||
if reply_to:
|
||||
anchor_message = await _find_reply_message(target_stream, reply_to)
|
||||
|
||||
# 构建发送消息对象
|
||||
bot_message = MessageSending(
|
||||
message_id=message_id,
|
||||
chat_stream=target_stream,
|
||||
bot_user_info=bot_user_info,
|
||||
sender_info=target_stream.user_info,
|
||||
message_segment=message_segment,
|
||||
display_message=display_message,
|
||||
reply=anchor_message,
|
||||
is_head=True,
|
||||
is_emoji=(message_type == "emoji"),
|
||||
thinking_start_time=current_time,
|
||||
)
|
||||
|
||||
# 发送消息
|
||||
sent_msg = await heart_fc_sender.send_message(
|
||||
bot_message, typing=typing, set_reply=(anchor_message is not None), storage_message=storage_message
|
||||
)
|
||||
|
||||
if sent_msg:
|
||||
logger.info(f"[SendAPI] 成功发送消息到 {stream_id}")
|
||||
return True
|
||||
else:
|
||||
logger.error("[SendAPI] 发送消息失败")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[SendAPI] 发送消息时出错: {e}")
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
async def _find_reply_message(target_stream, reply_to: str) -> Optional[MessageRecv]:
|
||||
"""查找要回复的消息
|
||||
|
||||
Args:
|
||||
target_stream: 目标聊天流
|
||||
reply_to: 回复格式,如"发送者:消息内容"或"发送者:消息内容"
|
||||
|
||||
Returns:
|
||||
Optional[MessageRecv]: 找到的消息,如果没找到则返回None
|
||||
"""
|
||||
try:
|
||||
# 解析reply_to参数
|
||||
if ":" in reply_to:
|
||||
parts = reply_to.split(":", 1)
|
||||
elif ":" in reply_to:
|
||||
parts = reply_to.split(":", 1)
|
||||
else:
|
||||
logger.warning(f"[SendAPI] reply_to格式不正确: {reply_to}")
|
||||
return None
|
||||
|
||||
if len(parts) != 2:
|
||||
logger.warning(f"[SendAPI] reply_to格式不正确: {reply_to}")
|
||||
return None
|
||||
|
||||
sender = parts[0].strip()
|
||||
text = parts[1].strip()
|
||||
|
||||
# 获取聊天流的最新20条消息
|
||||
reverse_talking_message = get_raw_msg_before_timestamp_with_chat(
|
||||
target_stream.stream_id,
|
||||
time.time(), # 当前时间之前的消息
|
||||
20 # 最新的20条消息
|
||||
)
|
||||
|
||||
# 反转列表,使最新的消息在前面
|
||||
reverse_talking_message = list(reversed(reverse_talking_message))
|
||||
|
||||
find_msg = None
|
||||
for message in reverse_talking_message:
|
||||
user_id = message["user_id"]
|
||||
platform = message["chat_info_platform"]
|
||||
person_id = get_person_info_manager().get_person_id(platform, user_id)
|
||||
person_name = await get_person_info_manager().get_value(person_id, "person_name")
|
||||
if person_name == sender:
|
||||
similarity = difflib.SequenceMatcher(None, text, message["processed_plain_text"]).ratio()
|
||||
if similarity >= 0.9:
|
||||
find_msg = message
|
||||
break
|
||||
|
||||
if not find_msg:
|
||||
logger.info("[SendAPI] 未找到匹配的回复消息")
|
||||
return None
|
||||
|
||||
# 构建MessageRecv对象
|
||||
user_info = {
|
||||
"platform": find_msg.get("user_platform", ""),
|
||||
"user_id": find_msg.get("user_id", ""),
|
||||
"user_nickname": find_msg.get("user_nickname", ""),
|
||||
"user_cardname": find_msg.get("user_cardname", ""),
|
||||
}
|
||||
|
||||
group_info = {}
|
||||
if find_msg.get("chat_info_group_id"):
|
||||
group_info = {
|
||||
"platform": find_msg.get("chat_info_group_platform", ""),
|
||||
"group_id": find_msg.get("chat_info_group_id", ""),
|
||||
"group_name": find_msg.get("chat_info_group_name", ""),
|
||||
}
|
||||
|
||||
format_info = {"content_format": "", "accept_format": ""}
|
||||
template_info = {"template_items": {}}
|
||||
|
||||
message_info = {
|
||||
"platform": target_stream.platform,
|
||||
"message_id": find_msg.get("message_id"),
|
||||
"time": find_msg.get("time"),
|
||||
"group_info": group_info,
|
||||
"user_info": user_info,
|
||||
"additional_config": find_msg.get("additional_config"),
|
||||
"format_info": format_info,
|
||||
"template_info": template_info,
|
||||
}
|
||||
|
||||
message_dict = {
|
||||
"message_info": message_info,
|
||||
"raw_message": find_msg.get("processed_plain_text"),
|
||||
"detailed_plain_text": find_msg.get("processed_plain_text"),
|
||||
"processed_plain_text": find_msg.get("processed_plain_text"),
|
||||
}
|
||||
|
||||
find_rec_msg = MessageRecv(message_dict)
|
||||
find_rec_msg.update_chat_stream(target_stream)
|
||||
|
||||
logger.info(f"[SendAPI] 找到匹配的回复消息,发送者: {sender}")
|
||||
return find_rec_msg
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[SendAPI] 查找回复消息时出错: {e}")
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 公共API函数 - 预定义类型的发送函数
|
||||
# =============================================================================
|
||||
|
||||
async def text_to_group(text: str, group_id: str, platform: str = "qq", typing: bool = False, reply_to: str = "", storage_message: bool = True) -> bool:
|
||||
"""向群聊发送文本消息
|
||||
|
||||
Args:
|
||||
text: 要发送的文本内容
|
||||
group_id: 群聊ID
|
||||
platform: 平台,默认为"qq"
|
||||
typing: 是否显示正在输入
|
||||
reply_to: 回复消息,格式为"发送者:消息内容"
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
stream_id = get_chat_manager().get_stream_id(platform, group_id, True)
|
||||
|
||||
return await _send_to_target("text", text, stream_id, "", typing, reply_to, storage_message)
|
||||
|
||||
|
||||
async def text_to_user(text: str, user_id: str, platform: str = "qq", typing: bool = False, reply_to: str = "", storage_message: bool = True) -> bool:
|
||||
"""向用户发送私聊文本消息
|
||||
|
||||
Args:
|
||||
text: 要发送的文本内容
|
||||
user_id: 用户ID
|
||||
platform: 平台,默认为"qq"
|
||||
typing: 是否显示正在输入
|
||||
reply_to: 回复消息,格式为"发送者:消息内容"
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
stream_id = get_chat_manager().get_stream_id(platform, user_id, False)
|
||||
return await _send_to_target("text", text, stream_id, "", typing, reply_to, storage_message)
|
||||
|
||||
|
||||
async def emoji_to_group(emoji_base64: str, group_id: str, platform: str = "qq", storage_message: bool = True) -> bool:
|
||||
"""向群聊发送表情包
|
||||
|
||||
Args:
|
||||
emoji_base64: 表情包的base64编码
|
||||
group_id: 群聊ID
|
||||
platform: 平台,默认为"qq"
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
stream_id = get_chat_manager().get_stream_id(platform, group_id, True)
|
||||
return await _send_to_target("emoji", emoji_base64, stream_id, "", typing=False, storage_message=storage_message)
|
||||
|
||||
|
||||
async def emoji_to_user(emoji_base64: str, user_id: str, platform: str = "qq", storage_message: bool = True) -> bool:
|
||||
"""向用户发送表情包
|
||||
|
||||
Args:
|
||||
emoji_base64: 表情包的base64编码
|
||||
user_id: 用户ID
|
||||
platform: 平台,默认为"qq"
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
stream_id = get_chat_manager().get_stream_id(platform, user_id, False)
|
||||
return await _send_to_target("emoji", emoji_base64, stream_id, "", typing=False, storage_message=storage_message)
|
||||
|
||||
|
||||
async def image_to_group(image_base64: str, group_id: str, platform: str = "qq", storage_message: bool = True) -> bool:
|
||||
"""向群聊发送图片
|
||||
|
||||
Args:
|
||||
image_base64: 图片的base64编码
|
||||
group_id: 群聊ID
|
||||
platform: 平台,默认为"qq"
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
stream_id = get_chat_manager().get_stream_id(platform, group_id, True)
|
||||
return await _send_to_target("image", image_base64, stream_id, "", typing=False, storage_message=storage_message)
|
||||
|
||||
|
||||
async def image_to_user(image_base64: str, user_id: str, platform: str = "qq", storage_message: bool = True) -> bool:
|
||||
"""向用户发送图片
|
||||
|
||||
Args:
|
||||
image_base64: 图片的base64编码
|
||||
user_id: 用户ID
|
||||
platform: 平台,默认为"qq"
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
stream_id = get_chat_manager().get_stream_id(platform, user_id, False)
|
||||
return await _send_to_target("image", image_base64, stream_id, "", typing=False)
|
||||
|
||||
async def command_to_group(command: str, group_id: str, platform: str = "qq", storage_message: bool = True) -> bool:
|
||||
"""向群聊发送命令
|
||||
|
||||
Args:
|
||||
command: 命令
|
||||
group_id: 群聊ID
|
||||
platform: 平台,默认为"qq"
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
stream_id = get_chat_manager().get_stream_id(platform, group_id, True)
|
||||
return await _send_to_target("command", command, stream_id, "", typing=False, storage_message=storage_message)
|
||||
|
||||
async def command_to_user(command: str, user_id: str, platform: str = "qq", storage_message: bool = True) -> bool:
|
||||
"""向用户发送命令
|
||||
|
||||
Args:
|
||||
command: 命令
|
||||
user_id: 用户ID
|
||||
platform: 平台,默认为"qq"
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
stream_id = get_chat_manager().get_stream_id(platform, user_id, False)
|
||||
return await _send_to_target("command", command, stream_id, "", typing=False, storage_message=storage_message)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 通用发送函数 - 支持任意消息类型
|
||||
# =============================================================================
|
||||
|
||||
async def custom_to_group(
|
||||
message_type: str,
|
||||
content: str,
|
||||
group_id: str,
|
||||
platform: str = "qq",
|
||||
display_message: str = "",
|
||||
typing: bool = False,
|
||||
reply_to: str = "",
|
||||
storage_message: bool = True
|
||||
) -> bool:
|
||||
"""向群聊发送自定义类型消息
|
||||
|
||||
Args:
|
||||
message_type: 消息类型,如"text"、"image"、"emoji"、"video"、"file"等
|
||||
content: 消息内容(通常是base64编码或文本)
|
||||
group_id: 群聊ID
|
||||
platform: 平台,默认为"qq"
|
||||
display_message: 显示消息
|
||||
typing: 是否显示正在输入
|
||||
reply_to: 回复消息,格式为"发送者:消息内容"
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
stream_id = get_chat_manager().get_stream_id(platform, group_id, True)
|
||||
return await _send_to_target(message_type, content, stream_id, display_message, typing, reply_to, storage_message)
|
||||
|
||||
|
||||
async def custom_to_user(
|
||||
message_type: str,
|
||||
content: str,
|
||||
user_id: str,
|
||||
platform: str = "qq",
|
||||
display_message: str = "",
|
||||
typing: bool = False,
|
||||
reply_to: str = "",
|
||||
storage_message: bool = True
|
||||
) -> bool:
|
||||
"""向用户发送自定义类型消息
|
||||
|
||||
Args:
|
||||
message_type: 消息类型,如"text"、"image"、"emoji"、"video"、"file"等
|
||||
content: 消息内容(通常是base64编码或文本)
|
||||
user_id: 用户ID
|
||||
platform: 平台,默认为"qq"
|
||||
display_message: 显示消息
|
||||
typing: 是否显示正在输入
|
||||
reply_to: 回复消息,格式为"发送者:消息内容"
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
stream_id = get_chat_manager().get_stream_id(platform, user_id, False)
|
||||
return await _send_to_target(message_type, content, stream_id, display_message, typing, reply_to, storage_message)
|
||||
|
||||
|
||||
async def custom_message(
|
||||
message_type: str,
|
||||
content: str,
|
||||
target_id: str,
|
||||
is_group: bool = True,
|
||||
platform: str = "qq",
|
||||
display_message: str = "",
|
||||
typing: bool = False,
|
||||
reply_to: str = "",
|
||||
storage_message: bool = True
|
||||
) -> bool:
|
||||
"""发送自定义消息的通用接口
|
||||
|
||||
Args:
|
||||
message_type: 消息类型,如"text"、"image"、"emoji"、"video"、"file"、"audio"等
|
||||
content: 消息内容
|
||||
target_id: 目标ID(群ID或用户ID)
|
||||
is_group: 是否为群聊,True为群聊,False为私聊
|
||||
platform: 平台,默认为"qq"
|
||||
display_message: 显示消息
|
||||
typing: 是否显示正在输入
|
||||
reply_to: 回复消息,格式为"发送者:消息内容"
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
|
||||
示例:
|
||||
# 发送视频到群聊
|
||||
await send_api.custom_message("video", video_base64, "123456", True)
|
||||
|
||||
# 发送文件到用户
|
||||
await send_api.custom_message("file", file_base64, "987654", False)
|
||||
|
||||
# 发送音频到群聊并回复特定消息
|
||||
await send_api.custom_message("audio", audio_base64, "123456", True, reply_to="张三:你好")
|
||||
"""
|
||||
stream_id = get_chat_manager().get_stream_id(platform, target_id, is_group)
|
||||
return await _send_to_target(message_type, content, stream_id, display_message, typing, reply_to, storage_message)
|
||||
@@ -1,220 +0,0 @@
|
||||
from typing import Optional, List, Dict, Any, Tuple
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.chat_stream import ChatManager, ChatStream
|
||||
from src.chat.focus_chat.hfc_utils import parse_thinking_id_to_timestamp
|
||||
import asyncio
|
||||
|
||||
logger = get_logger("stream_api")
|
||||
|
||||
|
||||
class StreamAPI:
|
||||
"""聊天流API模块
|
||||
|
||||
提供了获取聊天流、通过群ID查找聊天流等功能
|
||||
"""
|
||||
|
||||
def get_chat_stream_by_group_id(self, group_id: str, platform: str = "qq") -> Optional[ChatStream]:
|
||||
"""通过QQ群ID获取聊天流
|
||||
|
||||
Args:
|
||||
group_id: QQ群ID
|
||||
platform: 平台标识,默认为"qq"
|
||||
|
||||
Returns:
|
||||
Optional[ChatStream]: 找到的聊天流对象,如果未找到则返回None
|
||||
"""
|
||||
try:
|
||||
chat_manager = ChatManager()
|
||||
|
||||
# 遍历所有已加载的聊天流,查找匹配的群ID
|
||||
for stream_id, stream in chat_manager.streams.items():
|
||||
if (
|
||||
stream.group_info
|
||||
and str(stream.group_info.group_id) == str(group_id)
|
||||
and stream.platform == platform
|
||||
):
|
||||
logger.info(f"{self.log_prefix} 通过群ID {group_id} 找到聊天流: {stream_id}")
|
||||
return stream
|
||||
|
||||
logger.warning(f"{self.log_prefix} 未找到群ID为 {group_id} 的聊天流")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 通过群ID获取聊天流时出错: {e}")
|
||||
return None
|
||||
|
||||
def get_all_group_chat_streams(self, platform: str = "qq") -> List[ChatStream]:
|
||||
"""获取所有群聊的聊天流
|
||||
|
||||
Args:
|
||||
platform: 平台标识,默认为"qq"
|
||||
|
||||
Returns:
|
||||
List[ChatStream]: 所有群聊的聊天流列表
|
||||
"""
|
||||
try:
|
||||
chat_manager = ChatManager()
|
||||
group_streams = []
|
||||
|
||||
for stream in chat_manager.streams.values():
|
||||
if stream.group_info and stream.platform == platform:
|
||||
group_streams.append(stream)
|
||||
|
||||
logger.info(f"{self.log_prefix} 找到 {len(group_streams)} 个群聊聊天流")
|
||||
return group_streams
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 获取所有群聊聊天流时出错: {e}")
|
||||
return []
|
||||
|
||||
def get_chat_stream_by_user_id(self, user_id: str, platform: str = "qq") -> Optional[ChatStream]:
|
||||
"""通过用户ID获取私聊聊天流
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
platform: 平台标识,默认为"qq"
|
||||
|
||||
Returns:
|
||||
Optional[ChatStream]: 找到的私聊聊天流对象,如果未找到则返回None
|
||||
"""
|
||||
try:
|
||||
chat_manager = ChatManager()
|
||||
|
||||
# 遍历所有已加载的聊天流,查找匹配的用户ID(私聊)
|
||||
for stream_id, stream in chat_manager.streams.items():
|
||||
if (
|
||||
not stream.group_info # 私聊没有群信息
|
||||
and stream.user_info
|
||||
and str(stream.user_info.user_id) == str(user_id)
|
||||
and stream.platform == platform
|
||||
):
|
||||
logger.info(f"{self.log_prefix} 通过用户ID {user_id} 找到私聊聊天流: {stream_id}")
|
||||
return stream
|
||||
|
||||
logger.warning(f"{self.log_prefix} 未找到用户ID为 {user_id} 的私聊聊天流")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 通过用户ID获取私聊聊天流时出错: {e}")
|
||||
return None
|
||||
|
||||
def get_chat_streams_info(self) -> List[Dict[str, Any]]:
|
||||
"""获取所有聊天流的基本信息
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 包含聊天流基本信息的字典列表
|
||||
"""
|
||||
try:
|
||||
chat_manager = ChatManager()
|
||||
streams_info = []
|
||||
|
||||
for stream_id, stream in chat_manager.streams.items():
|
||||
info = {
|
||||
"stream_id": stream_id,
|
||||
"platform": stream.platform,
|
||||
"chat_type": "group" if stream.group_info else "private",
|
||||
"create_time": stream.create_time,
|
||||
"last_active_time": stream.last_active_time,
|
||||
}
|
||||
|
||||
if stream.group_info:
|
||||
info.update({"group_id": stream.group_info.group_id, "group_name": stream.group_info.group_name})
|
||||
|
||||
if stream.user_info:
|
||||
info.update({"user_id": stream.user_info.user_id, "user_nickname": stream.user_info.user_nickname})
|
||||
|
||||
streams_info.append(info)
|
||||
|
||||
logger.info(f"{self.log_prefix} 获取到 {len(streams_info)} 个聊天流信息")
|
||||
return streams_info
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 获取聊天流信息时出错: {e}")
|
||||
return []
|
||||
|
||||
async def get_chat_stream_by_group_id_async(self, group_id: str, platform: str = "qq") -> Optional[ChatStream]:
|
||||
"""异步通过QQ群ID获取聊天流(包括从数据库搜索)
|
||||
|
||||
Args:
|
||||
group_id: QQ群ID
|
||||
platform: 平台标识,默认为"qq"
|
||||
|
||||
Returns:
|
||||
Optional[ChatStream]: 找到的聊天流对象,如果未找到则返回None
|
||||
"""
|
||||
try:
|
||||
# 首先尝试从内存中查找
|
||||
stream = self.get_chat_stream_by_group_id(group_id, platform)
|
||||
if stream:
|
||||
return stream
|
||||
|
||||
# 如果内存中没有,尝试从数据库加载所有聊天流后再查找
|
||||
chat_manager = ChatManager()
|
||||
await chat_manager.load_all_streams()
|
||||
|
||||
# 再次尝试从内存中查找
|
||||
stream = self.get_chat_stream_by_group_id(group_id, platform)
|
||||
return stream
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 异步通过群ID获取聊天流时出错: {e}")
|
||||
return None
|
||||
|
||||
async def wait_for_new_message(self, timeout: int = 1200) -> Tuple[bool, str]:
|
||||
"""等待新消息或超时
|
||||
|
||||
Args:
|
||||
timeout: 超时时间(秒),默认1200秒
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: (是否收到新消息, 空字符串)
|
||||
"""
|
||||
try:
|
||||
# 获取必要的服务对象
|
||||
observations = self.get_service("observations")
|
||||
if not observations:
|
||||
logger.warning(f"{self.log_prefix} 无法获取observations服务,无法等待新消息")
|
||||
return False, ""
|
||||
|
||||
# 获取第一个观察对象(通常是ChattingObservation)
|
||||
observation = observations[0] if observations else None
|
||||
if not observation:
|
||||
logger.warning(f"{self.log_prefix} 无观察对象,无法等待新消息")
|
||||
return False, ""
|
||||
|
||||
# 从action上下文获取thinking_id
|
||||
thinking_id = self.get_action_context("thinking_id")
|
||||
if not thinking_id:
|
||||
logger.warning(f"{self.log_prefix} 无thinking_id,无法等待新消息")
|
||||
return False, ""
|
||||
|
||||
logger.info(f"{self.log_prefix} 开始等待新消息... (超时: {timeout}秒)")
|
||||
|
||||
wait_start_time = asyncio.get_event_loop().time()
|
||||
while True:
|
||||
# 检查关闭标志
|
||||
shutting_down = self.get_action_context("shutting_down", False)
|
||||
if shutting_down:
|
||||
logger.info(f"{self.log_prefix} 等待新消息时检测到关闭信号,中断等待")
|
||||
return False, ""
|
||||
|
||||
# 检查新消息
|
||||
thinking_id_timestamp = parse_thinking_id_to_timestamp(thinking_id)
|
||||
if await observation.has_new_messages_since(thinking_id_timestamp):
|
||||
logger.info(f"{self.log_prefix} 检测到新消息")
|
||||
return True, ""
|
||||
|
||||
# 检查超时
|
||||
if asyncio.get_event_loop().time() - wait_start_time > timeout:
|
||||
logger.warning(f"{self.log_prefix} 等待新消息超时({timeout}秒)")
|
||||
return False, ""
|
||||
|
||||
# 短暂休眠
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"{self.log_prefix} 等待新消息被中断 (CancelledError)")
|
||||
return False, ""
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 等待新消息时发生错误: {e}")
|
||||
return False, f"等待新消息失败: {str(e)}"
|
||||
@@ -1,126 +1,165 @@
|
||||
"""工具类API模块
|
||||
|
||||
提供了各种辅助功能
|
||||
使用方式:
|
||||
from src.plugin_system.apis import utils_api
|
||||
plugin_path = utils_api.get_plugin_path()
|
||||
data = utils_api.read_json_file("data.json")
|
||||
timestamp = utils_api.get_timestamp()
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import inspect
|
||||
import datetime
|
||||
import uuid
|
||||
from typing import Any, Optional
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("utils_api")
|
||||
|
||||
|
||||
class UtilsAPI:
|
||||
"""工具类API模块
|
||||
# =============================================================================
|
||||
# 文件操作API函数
|
||||
# =============================================================================
|
||||
|
||||
提供了各种辅助功能
|
||||
def get_plugin_path(caller_frame=None) -> str:
|
||||
"""获取调用者插件的路径
|
||||
|
||||
Args:
|
||||
caller_frame: 调用者的栈帧,默认为None(自动获取)
|
||||
|
||||
Returns:
|
||||
str: 插件目录的绝对路径
|
||||
"""
|
||||
try:
|
||||
if caller_frame is None:
|
||||
caller_frame = inspect.currentframe().f_back
|
||||
|
||||
def get_plugin_path(self) -> str:
|
||||
"""获取当前插件的路径
|
||||
|
||||
Returns:
|
||||
str: 插件目录的绝对路径
|
||||
"""
|
||||
import inspect
|
||||
|
||||
plugin_module_path = inspect.getfile(self.__class__)
|
||||
plugin_module_path = inspect.getfile(caller_frame)
|
||||
plugin_dir = os.path.dirname(plugin_module_path)
|
||||
return plugin_dir
|
||||
except Exception as e:
|
||||
logger.error(f"[UtilsAPI] 获取插件路径失败: {e}")
|
||||
return ""
|
||||
|
||||
def read_json_file(self, file_path: str, default: Any = None) -> Any:
|
||||
"""读取JSON文件
|
||||
|
||||
Args:
|
||||
file_path: 文件路径,可以是相对于插件目录的路径
|
||||
default: 如果文件不存在或读取失败时返回的默认值
|
||||
def read_json_file(file_path: str, default: Any = None) -> Any:
|
||||
"""读取JSON文件
|
||||
|
||||
Returns:
|
||||
Any: JSON数据或默认值
|
||||
"""
|
||||
try:
|
||||
# 如果是相对路径,则相对于插件目录
|
||||
if not os.path.isabs(file_path):
|
||||
file_path = os.path.join(self.get_plugin_path(), file_path)
|
||||
Args:
|
||||
file_path: 文件路径,可以是相对于插件目录的路径
|
||||
default: 如果文件不存在或读取失败时返回的默认值
|
||||
|
||||
if not os.path.exists(file_path):
|
||||
logger.warning(f"{self.log_prefix} 文件不存在: {file_path}")
|
||||
return default
|
||||
Returns:
|
||||
Any: JSON数据或默认值
|
||||
"""
|
||||
try:
|
||||
# 如果是相对路径,则相对于调用者的插件目录
|
||||
if not os.path.isabs(file_path):
|
||||
caller_frame = inspect.currentframe().f_back
|
||||
plugin_dir = get_plugin_path(caller_frame)
|
||||
file_path = os.path.join(plugin_dir, file_path)
|
||||
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 读取JSON文件出错: {e}")
|
||||
if not os.path.exists(file_path):
|
||||
logger.warning(f"[UtilsAPI] 文件不存在: {file_path}")
|
||||
return default
|
||||
|
||||
def write_json_file(self, file_path: str, data: Any, indent: int = 2) -> bool:
|
||||
"""写入JSON文件
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
logger.error(f"[UtilsAPI] 读取JSON文件出错: {e}")
|
||||
return default
|
||||
|
||||
Args:
|
||||
file_path: 文件路径,可以是相对于插件目录的路径
|
||||
data: 要写入的数据
|
||||
indent: JSON缩进
|
||||
|
||||
Returns:
|
||||
bool: 是否写入成功
|
||||
"""
|
||||
try:
|
||||
# 如果是相对路径,则相对于插件目录
|
||||
if not os.path.isabs(file_path):
|
||||
file_path = os.path.join(self.get_plugin_path(), file_path)
|
||||
def write_json_file(file_path: str, data: Any, indent: int = 2) -> bool:
|
||||
"""写入JSON文件
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
Args:
|
||||
file_path: 文件路径,可以是相对于插件目录的路径
|
||||
data: 要写入的数据
|
||||
indent: JSON缩进
|
||||
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=indent)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 写入JSON文件出错: {e}")
|
||||
return False
|
||||
Returns:
|
||||
bool: 是否写入成功
|
||||
"""
|
||||
try:
|
||||
# 如果是相对路径,则相对于调用者的插件目录
|
||||
if not os.path.isabs(file_path):
|
||||
caller_frame = inspect.currentframe().f_back
|
||||
plugin_dir = get_plugin_path(caller_frame)
|
||||
file_path = os.path.join(plugin_dir, file_path)
|
||||
|
||||
def get_timestamp(self) -> int:
|
||||
"""获取当前时间戳
|
||||
# 确保目录存在
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
|
||||
Returns:
|
||||
int: 当前时间戳(秒)
|
||||
"""
|
||||
return int(time.time())
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=indent)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"[UtilsAPI] 写入JSON文件出错: {e}")
|
||||
return False
|
||||
|
||||
def format_time(self, timestamp: Optional[int] = None, format_str: str = "%Y-%m-%d %H:%M:%S") -> str:
|
||||
"""格式化时间
|
||||
|
||||
Args:
|
||||
timestamp: 时间戳,如果为None则使用当前时间
|
||||
format_str: 时间格式字符串
|
||||
# =============================================================================
|
||||
# 时间相关API函数
|
||||
# =============================================================================
|
||||
|
||||
Returns:
|
||||
str: 格式化后的时间字符串
|
||||
"""
|
||||
import datetime
|
||||
def get_timestamp() -> int:
|
||||
"""获取当前时间戳
|
||||
|
||||
Returns:
|
||||
int: 当前时间戳(秒)
|
||||
"""
|
||||
return int(time.time())
|
||||
|
||||
|
||||
def format_time(timestamp: Optional[int] = None, format_str: str = "%Y-%m-%d %H:%M:%S") -> str:
|
||||
"""格式化时间
|
||||
|
||||
Args:
|
||||
timestamp: 时间戳,如果为None则使用当前时间
|
||||
format_str: 时间格式字符串
|
||||
|
||||
Returns:
|
||||
str: 格式化后的时间字符串
|
||||
"""
|
||||
try:
|
||||
if timestamp is None:
|
||||
timestamp = time.time()
|
||||
return datetime.datetime.fromtimestamp(timestamp).strftime(format_str)
|
||||
except Exception as e:
|
||||
logger.error(f"[UtilsAPI] 格式化时间失败: {e}")
|
||||
return ""
|
||||
|
||||
def parse_time(self, time_str: str, format_str: str = "%Y-%m-%d %H:%M:%S") -> int:
|
||||
"""解析时间字符串为时间戳
|
||||
|
||||
Args:
|
||||
time_str: 时间字符串
|
||||
format_str: 时间格式字符串
|
||||
def parse_time(time_str: str, format_str: str = "%Y-%m-%d %H:%M:%S") -> int:
|
||||
"""解析时间字符串为时间戳
|
||||
|
||||
Returns:
|
||||
int: 时间戳(秒)
|
||||
"""
|
||||
import datetime
|
||||
Args:
|
||||
time_str: 时间字符串
|
||||
format_str: 时间格式字符串
|
||||
|
||||
Returns:
|
||||
int: 时间戳(秒)
|
||||
"""
|
||||
try:
|
||||
dt = datetime.datetime.strptime(time_str, format_str)
|
||||
return int(dt.timestamp())
|
||||
except Exception as e:
|
||||
logger.error(f"[UtilsAPI] 解析时间失败: {e}")
|
||||
return 0
|
||||
|
||||
def generate_unique_id(self) -> str:
|
||||
"""生成唯一ID
|
||||
|
||||
Returns:
|
||||
str: 唯一ID
|
||||
"""
|
||||
import uuid
|
||||
# =============================================================================
|
||||
# 其他工具函数
|
||||
# =============================================================================
|
||||
|
||||
return str(uuid.uuid4())
|
||||
def generate_unique_id() -> str:
|
||||
"""生成唯一ID
|
||||
|
||||
Returns:
|
||||
str: 唯一ID
|
||||
"""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
@@ -1,435 +1,469 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Tuple
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis.plugin_api import PluginAPI
|
||||
from src.plugin_system.base.component_types import ActionActivationType, ChatMode, ActionInfo, ComponentType
|
||||
|
||||
logger = get_logger("base_action")
|
||||
|
||||
|
||||
class BaseAction(ABC):
|
||||
"""Action组件基类
|
||||
|
||||
Action是插件的一种组件类型,用于处理聊天中的动作逻辑
|
||||
|
||||
子类可以通过类属性定义激活条件,这些会在实例化时转换为实例属性:
|
||||
- focus_activation_type: 专注模式激活类型
|
||||
- normal_activation_type: 普通模式激活类型
|
||||
- activation_keywords: 激活关键词列表
|
||||
- keyword_case_sensitive: 关键词是否区分大小写
|
||||
- mode_enable: 启用的聊天模式
|
||||
- parallel_action: 是否允许并行执行
|
||||
- random_activation_probability: 随机激活概率
|
||||
- llm_judge_prompt: LLM判断提示词
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
action_data: dict,
|
||||
reasoning: str,
|
||||
cycle_timers: dict,
|
||||
thinking_id: str,
|
||||
observations: list = None,
|
||||
expressor=None,
|
||||
replyer=None,
|
||||
chat_stream=None,
|
||||
log_prefix: str = "",
|
||||
shutting_down: bool = False,
|
||||
plugin_config: dict = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""初始化Action组件
|
||||
|
||||
Args:
|
||||
action_data: 动作数据
|
||||
reasoning: 执行该动作的理由
|
||||
cycle_timers: 计时器字典
|
||||
thinking_id: 思考ID
|
||||
observations: 观察列表
|
||||
expressor: 表达器对象
|
||||
replyer: 回复器对象
|
||||
chat_stream: 聊天流对象
|
||||
log_prefix: 日志前缀
|
||||
shutting_down: 是否正在关闭
|
||||
plugin_config: 插件配置字典
|
||||
**kwargs: 其他参数
|
||||
"""
|
||||
self.action_data = action_data
|
||||
self.reasoning = reasoning
|
||||
self.cycle_timers = cycle_timers
|
||||
self.thinking_id = thinking_id
|
||||
self.log_prefix = log_prefix
|
||||
self.shutting_down = shutting_down
|
||||
|
||||
# 设置动作基本信息实例属性
|
||||
self.action_name: str = getattr(self, "action_name", self.__class__.__name__.lower().replace("action", ""))
|
||||
self.action_description: str = getattr(self, "action_description", self.__doc__ or "Action组件")
|
||||
self.action_parameters: dict = getattr(self.__class__, "action_parameters", {}).copy()
|
||||
self.action_require: list[str] = getattr(self.__class__, "action_require", []).copy()
|
||||
|
||||
# 设置激活类型实例属性(从类属性复制,提供默认值)
|
||||
self.focus_activation_type: str = self._get_activation_type_value("focus_activation_type", "never")
|
||||
self.normal_activation_type: str = self._get_activation_type_value("normal_activation_type", "never")
|
||||
self.random_activation_probability: float = getattr(self.__class__, "random_activation_probability", 0.0)
|
||||
self.llm_judge_prompt: str = getattr(self.__class__, "llm_judge_prompt", "")
|
||||
self.activation_keywords: list[str] = getattr(self.__class__, "activation_keywords", []).copy()
|
||||
self.keyword_case_sensitive: bool = getattr(self.__class__, "keyword_case_sensitive", False)
|
||||
self.mode_enable: str = self._get_mode_value("mode_enable", "all")
|
||||
self.parallel_action: bool = getattr(self.__class__, "parallel_action", True)
|
||||
self.associated_types: list[str] = getattr(self.__class__, "associated_types", []).copy()
|
||||
self.enable_plugin: bool = True # 默认启用
|
||||
|
||||
# 创建API实例,传递所有服务对象
|
||||
self.api = PluginAPI(
|
||||
chat_stream=chat_stream or kwargs.get("chat_stream"),
|
||||
expressor=expressor or kwargs.get("expressor"),
|
||||
replyer=replyer or kwargs.get("replyer"),
|
||||
observations=observations or kwargs.get("observations", []),
|
||||
log_prefix=log_prefix,
|
||||
plugin_config=plugin_config or kwargs.get("plugin_config"),
|
||||
)
|
||||
|
||||
# 设置API的action上下文
|
||||
self.api.set_action_context(thinking_id=thinking_id, shutting_down=shutting_down)
|
||||
|
||||
logger.debug(f"{self.log_prefix} Action组件初始化完成")
|
||||
|
||||
def _get_activation_type_value(self, attr_name: str, default: str) -> str:
|
||||
"""获取激活类型的字符串值"""
|
||||
attr = getattr(self.__class__, attr_name, None)
|
||||
if attr is None:
|
||||
return default
|
||||
if hasattr(attr, "value"):
|
||||
return attr.value
|
||||
return str(attr)
|
||||
|
||||
def _get_mode_value(self, attr_name: str, default: str) -> str:
|
||||
"""获取模式的字符串值"""
|
||||
attr = getattr(self.__class__, attr_name, None)
|
||||
if attr is None:
|
||||
return default
|
||||
if hasattr(attr, "value"):
|
||||
return attr.value
|
||||
return str(attr)
|
||||
|
||||
async def send_text(self, content: str) -> bool:
|
||||
"""发送回复消息
|
||||
|
||||
Args:
|
||||
content: 回复内容
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
chat_stream = self.api.get_service("chat_stream")
|
||||
if not chat_stream:
|
||||
logger.error(f"{self.log_prefix} 没有可用的聊天流发送回复")
|
||||
return False
|
||||
|
||||
if chat_stream.group_info:
|
||||
# 群聊
|
||||
return await self.api.send_text_to_group(
|
||||
text=content, group_id=str(chat_stream.group_info.group_id), platform=chat_stream.platform
|
||||
)
|
||||
else:
|
||||
# 私聊
|
||||
return await self.api.send_text_to_user(
|
||||
text=content, user_id=str(chat_stream.user_info.user_id), platform=chat_stream.platform
|
||||
)
|
||||
|
||||
async def send_type(self, type: str, text: str, typing: bool = False) -> bool:
|
||||
"""发送回复消息
|
||||
|
||||
Args:
|
||||
text: 回复内容
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
chat_stream = self.api.get_service("chat_stream")
|
||||
if not chat_stream:
|
||||
logger.error(f"{self.log_prefix} 没有可用的聊天流发送回复")
|
||||
return False
|
||||
|
||||
if chat_stream.group_info:
|
||||
# 群聊
|
||||
return await self.api.send_message_to_target(
|
||||
message_type=type,
|
||||
content=text,
|
||||
platform=chat_stream.platform,
|
||||
target_id=str(chat_stream.group_info.group_id),
|
||||
is_group=True,
|
||||
typing=typing,
|
||||
)
|
||||
else:
|
||||
# 私聊
|
||||
return await self.api.send_message_to_target(
|
||||
message_type=type,
|
||||
content=text,
|
||||
platform=chat_stream.platform,
|
||||
target_id=str(chat_stream.user_info.user_id),
|
||||
is_group=False,
|
||||
typing=typing,
|
||||
)
|
||||
|
||||
async def send_command(self, command_name: str, args: dict = None, display_message: str = None) -> bool:
|
||||
"""发送命令消息
|
||||
|
||||
使用和send_text相同的方式通过MessageAPI发送命令
|
||||
|
||||
Args:
|
||||
command_name: 命令名称
|
||||
args: 命令参数
|
||||
display_message: 显示消息
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
try:
|
||||
# 构造命令数据
|
||||
command_data = {"name": command_name, "args": args or {}}
|
||||
|
||||
# 使用send_message_to_target方法发送命令
|
||||
chat_stream = self.api.get_service("chat_stream")
|
||||
if not chat_stream:
|
||||
logger.error(f"{self.log_prefix} 没有可用的聊天流发送命令")
|
||||
return False
|
||||
|
||||
if chat_stream.group_info:
|
||||
# 群聊
|
||||
success = await self.api.send_message_to_target(
|
||||
message_type="command",
|
||||
content=command_data,
|
||||
platform=chat_stream.platform,
|
||||
target_id=str(chat_stream.group_info.group_id),
|
||||
is_group=True,
|
||||
display_message=display_message or f"执行命令: {command_name}",
|
||||
)
|
||||
else:
|
||||
# 私聊
|
||||
success = await self.api.send_message_to_target(
|
||||
message_type="command",
|
||||
content=command_data,
|
||||
platform=chat_stream.platform,
|
||||
target_id=str(chat_stream.user_info.user_id),
|
||||
is_group=False,
|
||||
display_message=display_message or f"执行命令: {command_name}",
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info(f"{self.log_prefix} 成功发送命令: {command_name}")
|
||||
else:
|
||||
logger.error(f"{self.log_prefix} 发送命令失败: {command_name}")
|
||||
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 发送命令时出错: {e}")
|
||||
return False
|
||||
|
||||
async def send_message_by_expressor(self, text: str, target: str = "") -> bool:
|
||||
"""通过expressor发送文本消息的Action专用方法
|
||||
|
||||
Args:
|
||||
text: 要发送的消息文本
|
||||
target: 目标消息(可选)
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
try:
|
||||
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
|
||||
from src.chat.focus_chat.hfc_utils import create_empty_anchor_message
|
||||
|
||||
# 获取服务
|
||||
expressor = self.api.get_service("expressor")
|
||||
chat_stream = self.api.get_service("chat_stream")
|
||||
observations = self.api.get_service("observations") or []
|
||||
|
||||
if not expressor or not chat_stream:
|
||||
logger.error(f"{self.log_prefix} 无法通过expressor发送消息:缺少必要的服务")
|
||||
return False
|
||||
|
||||
# 构造动作数据
|
||||
reply_data = {"text": text, "target": target, "emojis": []}
|
||||
|
||||
# 查找 ChattingObservation 实例
|
||||
chatting_observation = None
|
||||
for obs in observations:
|
||||
if isinstance(obs, ChattingObservation):
|
||||
chatting_observation = obs
|
||||
break
|
||||
|
||||
if not chatting_observation:
|
||||
logger.warning(f"{self.log_prefix} 未找到 ChattingObservation 实例,创建占位符")
|
||||
anchor_message = await create_empty_anchor_message(
|
||||
chat_stream.platform, chat_stream.group_info, chat_stream
|
||||
)
|
||||
else:
|
||||
anchor_message = chatting_observation.search_message_by_text(target)
|
||||
if not anchor_message:
|
||||
logger.info(f"{self.log_prefix} 未找到锚点消息,创建占位符")
|
||||
anchor_message = await create_empty_anchor_message(
|
||||
chat_stream.platform, chat_stream.group_info, chat_stream
|
||||
)
|
||||
else:
|
||||
anchor_message.update_chat_stream(chat_stream)
|
||||
|
||||
# 使用Action上下文信息发送消息
|
||||
success, _ = await expressor.deal_reply(
|
||||
cycle_timers=self.cycle_timers,
|
||||
action_data=reply_data,
|
||||
anchor_message=anchor_message,
|
||||
reasoning=self.reasoning,
|
||||
thinking_id=self.thinking_id,
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info(f"{self.log_prefix} 成功通过expressor发送消息")
|
||||
else:
|
||||
logger.error(f"{self.log_prefix} 通过expressor发送消息失败")
|
||||
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 通过expressor发送消息时出错: {e}")
|
||||
return False
|
||||
|
||||
async def send_message_by_replyer(self, target: str = "", extra_info_block: str = None) -> bool:
|
||||
"""通过replyer发送消息的Action专用方法
|
||||
|
||||
Args:
|
||||
target: 目标消息(可选)
|
||||
extra_info_block: 额外信息块(可选)
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
try:
|
||||
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
|
||||
from src.chat.focus_chat.hfc_utils import create_empty_anchor_message
|
||||
|
||||
# 获取服务
|
||||
replyer = self.api.get_service("replyer")
|
||||
chat_stream = self.api.get_service("chat_stream")
|
||||
observations = self.api.get_service("observations") or []
|
||||
|
||||
if not replyer or not chat_stream:
|
||||
logger.error(f"{self.log_prefix} 无法通过replyer发送消息:缺少必要的服务")
|
||||
return False
|
||||
|
||||
# 构造动作数据
|
||||
reply_data = {"target": target, "extra_info_block": extra_info_block}
|
||||
|
||||
# 查找 ChattingObservation 实例
|
||||
chatting_observation = None
|
||||
for obs in observations:
|
||||
if isinstance(obs, ChattingObservation):
|
||||
chatting_observation = obs
|
||||
break
|
||||
|
||||
if not chatting_observation:
|
||||
logger.warning(f"{self.log_prefix} 未找到 ChattingObservation 实例,创建占位符")
|
||||
anchor_message = await create_empty_anchor_message(
|
||||
chat_stream.platform, chat_stream.group_info, chat_stream
|
||||
)
|
||||
else:
|
||||
anchor_message = chatting_observation.search_message_by_text(target)
|
||||
if not anchor_message:
|
||||
logger.info(f"{self.log_prefix} 未找到锚点消息,创建占位符")
|
||||
anchor_message = await create_empty_anchor_message(
|
||||
chat_stream.platform, chat_stream.group_info, chat_stream
|
||||
)
|
||||
else:
|
||||
anchor_message.update_chat_stream(chat_stream)
|
||||
|
||||
# 使用Action上下文信息发送消息
|
||||
success, _ = await replyer.deal_reply(
|
||||
cycle_timers=self.cycle_timers,
|
||||
action_data=reply_data,
|
||||
anchor_message=anchor_message,
|
||||
reasoning=self.reasoning,
|
||||
thinking_id=self.thinking_id,
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info(f"{self.log_prefix} 成功通过replyer发送消息")
|
||||
else:
|
||||
logger.error(f"{self.log_prefix} 通过replyer发送消息失败")
|
||||
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 通过replyer发送消息时出错: {e}")
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_action_info(cls) -> "ActionInfo":
|
||||
"""从类属性生成ActionInfo
|
||||
|
||||
所有信息都从类属性中读取,确保一致性和完整性。
|
||||
Action类必须定义所有必要的类属性。
|
||||
|
||||
Returns:
|
||||
ActionInfo: 生成的Action信息对象
|
||||
"""
|
||||
|
||||
# 从类属性读取名称,如果没有定义则使用类名自动生成
|
||||
name = getattr(cls, "action_name", cls.__name__.lower().replace("action", ""))
|
||||
|
||||
# 从类属性读取描述,如果没有定义则使用文档字符串的第一行
|
||||
description = getattr(cls, "action_description", None)
|
||||
if description is None:
|
||||
description = "Action动作"
|
||||
|
||||
# 安全获取激活类型值
|
||||
def get_enum_value(attr_name, default):
|
||||
attr = getattr(cls, attr_name, None)
|
||||
if attr is None:
|
||||
# 如果没有定义,返回默认的枚举值
|
||||
return getattr(ActionActivationType, default.upper(), ActionActivationType.NEVER)
|
||||
return attr
|
||||
|
||||
def get_mode_value(attr_name, default):
|
||||
attr = getattr(cls, attr_name, None)
|
||||
if attr is None:
|
||||
return getattr(ChatMode, default.upper(), ChatMode.ALL)
|
||||
return attr
|
||||
|
||||
return ActionInfo(
|
||||
name=name,
|
||||
component_type=ComponentType.ACTION,
|
||||
description=description,
|
||||
focus_activation_type=get_enum_value("focus_activation_type", "never"),
|
||||
normal_activation_type=get_enum_value("normal_activation_type", "never"),
|
||||
activation_keywords=getattr(cls, "activation_keywords", []).copy(),
|
||||
keyword_case_sensitive=getattr(cls, "keyword_case_sensitive", False),
|
||||
mode_enable=get_mode_value("mode_enable", "all"),
|
||||
parallel_action=getattr(cls, "parallel_action", True),
|
||||
random_activation_probability=getattr(cls, "random_activation_probability", 0.0),
|
||||
llm_judge_prompt=getattr(cls, "llm_judge_prompt", ""),
|
||||
# 使用正确的字段名
|
||||
action_parameters=getattr(cls, "action_parameters", {}).copy(),
|
||||
action_require=getattr(cls, "action_require", []).copy(),
|
||||
associated_types=getattr(cls, "associated_types", []).copy(),
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self) -> Tuple[bool, str]:
|
||||
"""执行Action的抽象方法,子类必须实现
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: (是否执行成功, 回复文本)
|
||||
"""
|
||||
pass
|
||||
|
||||
async def handle_action(self) -> Tuple[bool, str]:
|
||||
"""兼容旧系统的handle_action接口,委托给execute方法
|
||||
|
||||
为了保持向后兼容性,旧系统的代码可能会调用handle_action方法。
|
||||
此方法将调用委托给新的execute方法。
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: (是否执行成功, 回复文本)
|
||||
"""
|
||||
return await self.execute()
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Tuple, Optional
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import ActionActivationType, ChatMode, ActionInfo, ComponentType
|
||||
from src.plugin_system.apis import send_api, database_api,message_api
|
||||
import time
|
||||
import asyncio
|
||||
|
||||
logger = get_logger("base_action")
|
||||
|
||||
|
||||
class BaseAction(ABC):
|
||||
"""Action组件基类
|
||||
|
||||
Action是插件的一种组件类型,用于处理聊天中的动作逻辑
|
||||
|
||||
子类可以通过类属性定义激活条件,这些会在实例化时转换为实例属性:
|
||||
- focus_activation_type: 专注模式激活类型
|
||||
- normal_activation_type: 普通模式激活类型
|
||||
- activation_keywords: 激活关键词列表
|
||||
- keyword_case_sensitive: 关键词是否区分大小写
|
||||
- mode_enable: 启用的聊天模式
|
||||
- parallel_action: 是否允许并行执行
|
||||
- random_activation_probability: 随机激活概率
|
||||
- llm_judge_prompt: LLM判断提示词
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
action_data: dict,
|
||||
reasoning: str,
|
||||
cycle_timers: dict,
|
||||
thinking_id: str,
|
||||
chat_stream=None,
|
||||
log_prefix: str = "",
|
||||
shutting_down: bool = False,
|
||||
plugin_config: dict = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""初始化Action组件
|
||||
|
||||
Args:
|
||||
action_data: 动作数据
|
||||
reasoning: 执行该动作的理由
|
||||
cycle_timers: 计时器字典
|
||||
thinking_id: 思考ID
|
||||
observations: 观察列表
|
||||
expressor: 表达器对象
|
||||
replyer: 回复器对象
|
||||
chat_stream: 聊天流对象
|
||||
log_prefix: 日志前缀
|
||||
shutting_down: 是否正在关闭
|
||||
plugin_config: 插件配置字典
|
||||
**kwargs: 其他参数
|
||||
"""
|
||||
self.action_data = action_data
|
||||
self.reasoning = reasoning
|
||||
self.cycle_timers = cycle_timers
|
||||
self.thinking_id = thinking_id
|
||||
self.log_prefix = log_prefix
|
||||
self.shutting_down = shutting_down
|
||||
|
||||
# 保存插件配置
|
||||
self.plugin_config = plugin_config or {}
|
||||
|
||||
# 设置动作基本信息实例属性
|
||||
self.action_name: str = getattr(self, "action_name", self.__class__.__name__.lower().replace("action", ""))
|
||||
self.action_description: str = getattr(self, "action_description", self.__doc__ or "Action组件")
|
||||
self.action_parameters: dict = getattr(self.__class__, "action_parameters", {}).copy()
|
||||
self.action_require: list[str] = getattr(self.__class__, "action_require", []).copy()
|
||||
|
||||
# 设置激活类型实例属性(从类属性复制,提供默认值)
|
||||
self.focus_activation_type: str = self._get_activation_type_value("focus_activation_type", "always")
|
||||
self.normal_activation_type: str = self._get_activation_type_value("normal_activation_type", "always")
|
||||
self.random_activation_probability: float = getattr(self.__class__, "random_activation_probability", 0.0)
|
||||
self.llm_judge_prompt: str = getattr(self.__class__, "llm_judge_prompt", "")
|
||||
self.activation_keywords: list[str] = getattr(self.__class__, "activation_keywords", []).copy()
|
||||
self.keyword_case_sensitive: bool = getattr(self.__class__, "keyword_case_sensitive", False)
|
||||
self.mode_enable: str = self._get_mode_value("mode_enable", "all")
|
||||
self.parallel_action: bool = getattr(self.__class__, "parallel_action", True)
|
||||
self.associated_types: list[str] = getattr(self.__class__, "associated_types", []).copy()
|
||||
|
||||
# =============================================================================
|
||||
# 便捷属性 - 直接在初始化时获取常用聊天信息(带类型注解)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
# 获取聊天流对象
|
||||
self.chat_stream = chat_stream or kwargs.get("chat_stream")
|
||||
|
||||
self.chat_id = self.chat_stream.stream_id
|
||||
# 初始化基础信息(带类型注解)
|
||||
self.is_group: bool = False
|
||||
self.platform: Optional[str] = None
|
||||
self.group_id: Optional[str] = None
|
||||
self.user_id: Optional[str] = None
|
||||
self.target_id: Optional[str] = None
|
||||
self.group_name: Optional[str] = None
|
||||
self.user_nickname: Optional[str] = None
|
||||
|
||||
# 如果有聊天流,提取所有信息
|
||||
if self.chat_stream:
|
||||
self.platform = getattr(self.chat_stream, 'platform', None)
|
||||
|
||||
# 获取群聊信息
|
||||
# print(self.chat_stream)
|
||||
# print(self.chat_stream.group_info)
|
||||
if self.chat_stream.group_info:
|
||||
self.is_group = True
|
||||
self.group_id = str(self.chat_stream.group_info.group_id)
|
||||
self.group_name = getattr(self.chat_stream.group_info, 'group_name', None)
|
||||
else:
|
||||
self.is_group = False
|
||||
self.user_id = str(self.chat_stream.user_info.user_id)
|
||||
self.user_nickname = getattr(self.chat_stream.user_info, 'user_nickname', None)
|
||||
|
||||
# 设置目标ID(群聊用群ID,私聊用户ID)
|
||||
self.target_id = self.group_id if self.is_group else self.user_id
|
||||
|
||||
logger.debug(f"{self.log_prefix} Action组件初始化完成")
|
||||
logger.debug(f"{self.log_prefix} 聊天信息: 类型={'群聊' if self.is_group else '私聊'}, 平台={self.platform}, 目标={self.target_id}")
|
||||
|
||||
def _get_activation_type_value(self, attr_name: str, default: str) -> str:
|
||||
"""获取激活类型的字符串值"""
|
||||
attr = getattr(self.__class__, attr_name, None)
|
||||
if attr is None:
|
||||
return default
|
||||
if hasattr(attr, "value"):
|
||||
return attr.value
|
||||
return str(attr)
|
||||
|
||||
def _get_mode_value(self, attr_name: str, default: str) -> str:
|
||||
"""获取模式的字符串值"""
|
||||
attr = getattr(self.__class__, attr_name, None)
|
||||
if attr is None:
|
||||
return default
|
||||
if hasattr(attr, "value"):
|
||||
return attr.value
|
||||
return str(attr)
|
||||
|
||||
|
||||
async def wait_for_new_message(self, timeout: int = 1200) -> Tuple[bool, str]:
|
||||
"""等待新消息或超时
|
||||
|
||||
在loop_start_time之后等待新消息,如果没有新消息且没有超时,就一直等待。
|
||||
使用message_api检查self.chat_id对应的聊天中是否有新消息。
|
||||
|
||||
Args:
|
||||
timeout: 超时时间(秒),默认1200秒
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: (是否收到新消息, 空字符串)
|
||||
"""
|
||||
try:
|
||||
# 获取循环开始时间,如果没有则使用当前时间
|
||||
loop_start_time = self.action_data.get("loop_start_time", time.time())
|
||||
logger.info(f"{self.log_prefix} 开始等待新消息... (最长等待: {timeout}秒, 从时间点: {loop_start_time})")
|
||||
|
||||
# 确保有有效的chat_id
|
||||
if not self.chat_id:
|
||||
logger.error(f"{self.log_prefix} 等待新消息失败: 没有有效的chat_id")
|
||||
return False, "没有有效的chat_id"
|
||||
|
||||
wait_start_time = asyncio.get_event_loop().time()
|
||||
while True:
|
||||
# 检查关闭标志
|
||||
# shutting_down = self.get_action_context("shutting_down", False)
|
||||
# if shutting_down:
|
||||
# logger.info(f"{self.log_prefix} 等待新消息时检测到关闭信号,中断等待")
|
||||
# return False, ""
|
||||
|
||||
# 检查新消息
|
||||
current_time = time.time()
|
||||
new_message_count = message_api.count_new_messages(
|
||||
chat_id=self.chat_id,
|
||||
start_time=loop_start_time,
|
||||
end_time=current_time
|
||||
)
|
||||
|
||||
if new_message_count > 0:
|
||||
logger.info(f"{self.log_prefix} 检测到{new_message_count}条新消息,聊天ID: {self.chat_id}")
|
||||
return True, ""
|
||||
|
||||
# 检查超时
|
||||
elapsed_time = asyncio.get_event_loop().time() - wait_start_time
|
||||
if elapsed_time > timeout:
|
||||
logger.warning(f"{self.log_prefix} 等待新消息超时({timeout}秒),聊天ID: {self.chat_id}")
|
||||
return False, ""
|
||||
|
||||
# 每30秒记录一次等待状态
|
||||
if int(elapsed_time) % 15 == 0 and int(elapsed_time) > 0:
|
||||
logger.debug(f"{self.log_prefix} 已等待{int(elapsed_time)}秒,继续等待新消息...")
|
||||
|
||||
# 短暂休眠
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"{self.log_prefix} 等待新消息被中断 (CancelledError)")
|
||||
return False, ""
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 等待新消息时发生错误: {e}")
|
||||
return False, f"等待新消息失败: {str(e)}"
|
||||
|
||||
async def send_text(self, content: str, reply_to: str = "") -> bool:
|
||||
"""发送文本消息
|
||||
|
||||
Args:
|
||||
content: 文本内容
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
if not self.target_id or not self.platform:
|
||||
logger.error(f"{self.log_prefix} 缺少发送消息所需的信息")
|
||||
return False
|
||||
|
||||
if self.is_group:
|
||||
return await send_api.text_to_group(
|
||||
text=content, group_id=self.target_id, platform=self.platform, reply_to=reply_to
|
||||
)
|
||||
else:
|
||||
return await send_api.text_to_user(
|
||||
text=content, user_id=self.target_id, platform=self.platform, reply_to=reply_to
|
||||
)
|
||||
|
||||
async def send_emoji(self, emoji_base64: str) -> bool:
|
||||
"""发送表情包
|
||||
|
||||
Args:
|
||||
emoji_base64: 表情包的base64编码
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
# 导入send_api
|
||||
from src.plugin_system.apis import send_api
|
||||
|
||||
if not self.target_id or not self.platform:
|
||||
logger.error(f"{self.log_prefix} 缺少发送消息所需的信息")
|
||||
return False
|
||||
|
||||
if self.is_group:
|
||||
return await send_api.emoji_to_group(emoji_base64, self.target_id, self.platform)
|
||||
else:
|
||||
return await send_api.emoji_to_user(emoji_base64, self.target_id, self.platform)
|
||||
|
||||
async def send_image(self, image_base64: str) -> bool:
|
||||
"""发送图片
|
||||
|
||||
Args:
|
||||
image_base64: 图片的base64编码
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
# 导入send_api
|
||||
from src.plugin_system.apis import send_api
|
||||
|
||||
if not self.target_id or not self.platform:
|
||||
logger.error(f"{self.log_prefix} 缺少发送消息所需的信息")
|
||||
return False
|
||||
|
||||
if self.is_group:
|
||||
return await send_api.image_to_group(image_base64, self.target_id, self.platform)
|
||||
else:
|
||||
return await send_api.image_to_user(image_base64, self.target_id, self.platform)
|
||||
|
||||
async def send_custom(self, message_type: str, content: str, typing: bool = False) -> bool:
|
||||
"""发送自定义类型消息
|
||||
|
||||
Args:
|
||||
message_type: 消息类型,如"video"、"file"、"audio"等
|
||||
content: 消息内容
|
||||
typing: 是否显示正在输入
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
# 导入send_api
|
||||
from src.plugin_system.apis import send_api
|
||||
|
||||
if not self.target_id or not self.platform:
|
||||
logger.error(f"{self.log_prefix} 缺少发送消息所需的信息")
|
||||
return False
|
||||
|
||||
return await send_api.custom_message(
|
||||
message_type=message_type,
|
||||
content=content,
|
||||
target_id=self.target_id,
|
||||
is_group=self.is_group,
|
||||
platform=self.platform,
|
||||
typing=typing
|
||||
)
|
||||
|
||||
async def store_action_info(
|
||||
self,
|
||||
action_build_into_prompt: bool = False,
|
||||
action_prompt_display: str = "",
|
||||
action_done: bool = True,
|
||||
) -> None:
|
||||
"""存储动作信息到数据库
|
||||
|
||||
Args:
|
||||
action_build_into_prompt: 是否构建到提示中
|
||||
action_prompt_display: 显示的action提示信息
|
||||
action_done: action是否完成
|
||||
"""
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
action_build_into_prompt=action_build_into_prompt,
|
||||
action_prompt_display=action_prompt_display,
|
||||
action_done=action_done,
|
||||
thinking_id=self.thinking_id,
|
||||
action_data=self.action_data,
|
||||
action_name=self.action_name,
|
||||
)
|
||||
|
||||
async def send_command(self, command_name: str, args: dict = None, display_message: str = None, storage_message: bool = True) -> bool:
|
||||
"""发送命令消息
|
||||
|
||||
使用和send_text相同的方式通过MessageAPI发送命令
|
||||
|
||||
Args:
|
||||
command_name: 命令名称
|
||||
args: 命令参数
|
||||
display_message: 显示消息
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
try:
|
||||
# 构造命令数据
|
||||
command_data = {"name": command_name, "args": args or {}}
|
||||
|
||||
if self.is_group:
|
||||
# 群聊
|
||||
success = await send_api.command_to_group(
|
||||
command=command_data,
|
||||
group_id=str(self.group_id),
|
||||
platform=self.platform,
|
||||
storage_message=storage_message
|
||||
)
|
||||
else:
|
||||
# 私聊
|
||||
success = await send_api.command_to_user(
|
||||
command=command_data,
|
||||
user_id=str(self.user_id),
|
||||
platform=self.platform,
|
||||
storage_message=storage_message
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info(f"{self.log_prefix} 成功发送命令: {command_name}")
|
||||
else:
|
||||
logger.error(f"{self.log_prefix} 发送命令失败: {command_name}")
|
||||
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 发送命令时出错: {e}")
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_action_info(cls) -> "ActionInfo":
|
||||
"""从类属性生成ActionInfo
|
||||
|
||||
所有信息都从类属性中读取,确保一致性和完整性。
|
||||
Action类必须定义所有必要的类属性。
|
||||
|
||||
Returns:
|
||||
ActionInfo: 生成的Action信息对象
|
||||
"""
|
||||
|
||||
# 从类属性读取名称,如果没有定义则使用类名自动生成
|
||||
name = getattr(cls, "action_name", cls.__name__.lower().replace("action", ""))
|
||||
|
||||
# 从类属性读取描述,如果没有定义则使用文档字符串的第一行
|
||||
description = getattr(cls, "action_description", None)
|
||||
if description is None:
|
||||
description = "Action动作"
|
||||
|
||||
# 安全获取激活类型值
|
||||
def get_enum_value(attr_name, default):
|
||||
attr = getattr(cls, attr_name, None)
|
||||
if attr is None:
|
||||
# 如果没有定义,返回默认的枚举值
|
||||
return getattr(ActionActivationType, default.upper(), ActionActivationType.NEVER)
|
||||
return attr
|
||||
|
||||
def get_mode_value(attr_name, default):
|
||||
attr = getattr(cls, attr_name, None)
|
||||
if attr is None:
|
||||
return getattr(ChatMode, default.upper(), ChatMode.ALL)
|
||||
return attr
|
||||
|
||||
return ActionInfo(
|
||||
name=name,
|
||||
component_type=ComponentType.ACTION,
|
||||
description=description,
|
||||
focus_activation_type=get_enum_value("focus_activation_type", "always"),
|
||||
normal_activation_type=get_enum_value("normal_activation_type", "always"),
|
||||
activation_keywords=getattr(cls, "activation_keywords", []).copy(),
|
||||
keyword_case_sensitive=getattr(cls, "keyword_case_sensitive", False),
|
||||
mode_enable=get_mode_value("mode_enable", "all"),
|
||||
parallel_action=getattr(cls, "parallel_action", True),
|
||||
random_activation_probability=getattr(cls, "random_activation_probability", 0.0),
|
||||
llm_judge_prompt=getattr(cls, "llm_judge_prompt", ""),
|
||||
# 使用正确的字段名
|
||||
action_parameters=getattr(cls, "action_parameters", {}).copy(),
|
||||
action_require=getattr(cls, "action_require", []).copy(),
|
||||
associated_types=getattr(cls, "associated_types", []).copy(),
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self) -> Tuple[bool, str]:
|
||||
"""执行Action的抽象方法,子类必须实现
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: (是否执行成功, 回复文本)
|
||||
"""
|
||||
pass
|
||||
|
||||
async def handle_action(self) -> Tuple[bool, str]:
|
||||
"""兼容旧系统的handle_action接口,委托给execute方法
|
||||
|
||||
为了保持向后兼容性,旧系统的代码可能会调用handle_action方法。
|
||||
此方法将调用委托给新的execute方法。
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: (是否执行成功, 回复文本)
|
||||
"""
|
||||
return await self.execute()
|
||||
|
||||
def get_action_context(self, key: str, default=None):
|
||||
"""获取action上下文信息
|
||||
|
||||
Args:
|
||||
key: 上下文键名
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
Any: 上下文值或默认值
|
||||
"""
|
||||
return self.api.get_action_context(key, default)
|
||||
|
||||
def get_config(self, key: str, default=None):
|
||||
"""获取插件配置值,支持嵌套键访问
|
||||
|
||||
Args:
|
||||
key: 配置键名,支持嵌套访问如 "section.subsection.key"
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
Any: 配置值或默认值
|
||||
"""
|
||||
if not self.plugin_config:
|
||||
return default
|
||||
|
||||
# 支持嵌套键访问
|
||||
keys = key.split(".")
|
||||
current = self.plugin_config
|
||||
|
||||
for k in keys:
|
||||
if isinstance(current, dict) and k in current:
|
||||
current = current[k]
|
||||
else:
|
||||
return default
|
||||
|
||||
return current
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Tuple, Optional, List
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis.plugin_api import PluginAPI
|
||||
from src.plugin_system.base.component_types import CommandInfo, ComponentType
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.plugin_system.apis import send_api
|
||||
|
||||
logger = get_logger("base_command")
|
||||
|
||||
@@ -20,6 +20,9 @@ class BaseCommand(ABC):
|
||||
- intercept_message: 是否拦截消息处理(默认True拦截,False继续传递)
|
||||
"""
|
||||
|
||||
command_name: str = ""
|
||||
command_description: str = ""
|
||||
|
||||
# 默认命令设置(子类可以覆盖)
|
||||
command_pattern: str = ""
|
||||
command_help: str = ""
|
||||
@@ -35,9 +38,7 @@ class BaseCommand(ABC):
|
||||
"""
|
||||
self.message = message
|
||||
self.matched_groups: Dict[str, str] = {} # 存储正则表达式匹配的命名组
|
||||
|
||||
# 创建API实例
|
||||
self.api = PluginAPI(chat_stream=message.chat_stream, log_prefix="[Command]", plugin_config=plugin_config)
|
||||
self.plugin_config = plugin_config or {} # 直接存储插件配置字典
|
||||
|
||||
self.log_prefix = "[Command]"
|
||||
|
||||
@@ -60,6 +61,31 @@ class BaseCommand(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_config(self, key: str, default=None):
|
||||
"""获取插件配置值,支持嵌套键访问
|
||||
|
||||
Args:
|
||||
key: 配置键名,支持嵌套访问如 "section.subsection.key"
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
Any: 配置值或默认值
|
||||
"""
|
||||
if not self.plugin_config:
|
||||
return default
|
||||
|
||||
# 支持嵌套键访问
|
||||
keys = key.split(".")
|
||||
current = self.plugin_config
|
||||
|
||||
for k in keys:
|
||||
if isinstance(current, dict) and k in current:
|
||||
current = current[k]
|
||||
else:
|
||||
return default
|
||||
|
||||
return current
|
||||
|
||||
async def send_text(self, content: str) -> None:
|
||||
"""发送回复消息
|
||||
|
||||
@@ -71,13 +97,19 @@ class BaseCommand(ABC):
|
||||
|
||||
if chat_stream.group_info:
|
||||
# 群聊
|
||||
await self.api.send_text_to_group(
|
||||
text=content, group_id=str(chat_stream.group_info.group_id), platform=chat_stream.platform
|
||||
|
||||
await send_api.text_to_group(
|
||||
text=content,
|
||||
group_id=str(chat_stream.group_info.group_id),
|
||||
platform=chat_stream.platform
|
||||
)
|
||||
else:
|
||||
# 私聊
|
||||
await self.api.send_text_to_user(
|
||||
text=content, user_id=str(chat_stream.user_info.user_id), platform=chat_stream.platform
|
||||
|
||||
await send_api.text_to_user(
|
||||
text=content,
|
||||
user_id=str(chat_stream.user_info.user_id),
|
||||
platform=chat_stream.platform
|
||||
)
|
||||
|
||||
async def send_type(
|
||||
@@ -98,31 +130,30 @@ class BaseCommand(ABC):
|
||||
|
||||
if chat_stream.group_info:
|
||||
# 群聊
|
||||
return await self.api.send_message_to_target(
|
||||
from src.plugin_system.apis import send_api
|
||||
return await send_api.custom_message(
|
||||
message_type=message_type,
|
||||
content=content,
|
||||
platform=chat_stream.platform,
|
||||
target_id=str(chat_stream.group_info.group_id),
|
||||
is_group=True,
|
||||
display_message=display_message,
|
||||
platform=chat_stream.platform,
|
||||
typing=typing,
|
||||
)
|
||||
else:
|
||||
# 私聊
|
||||
return await self.api.send_message_to_target(
|
||||
from src.plugin_system.apis import send_api
|
||||
return await send_api.custom_message(
|
||||
message_type=message_type,
|
||||
content=content,
|
||||
platform=chat_stream.platform,
|
||||
target_id=str(chat_stream.user_info.user_id),
|
||||
is_group=False,
|
||||
display_message=display_message,
|
||||
platform=chat_stream.platform,
|
||||
typing=typing,
|
||||
)
|
||||
|
||||
async def send_command(self, command_name: str, args: dict = None, display_message: str = None) -> bool:
|
||||
"""发送命令消息
|
||||
|
||||
使用和send_text相同的方式通过MessageAPI发送命令
|
||||
|
||||
Args:
|
||||
command_name: 命令名称
|
||||
args: 命令参数
|
||||
@@ -135,29 +166,28 @@ class BaseCommand(ABC):
|
||||
# 构造命令数据
|
||||
command_data = {"name": command_name, "args": args or {}}
|
||||
|
||||
# 使用send_message_to_target方法发送命令
|
||||
# 获取聊天流信息
|
||||
chat_stream = self.message.chat_stream
|
||||
command_content = command_data
|
||||
|
||||
if chat_stream.group_info:
|
||||
# 群聊
|
||||
success = await self.api.send_message_to_target(
|
||||
from src.plugin_system.apis import send_api
|
||||
success = await send_api.custom_message(
|
||||
message_type="command",
|
||||
content=command_content,
|
||||
platform=chat_stream.platform,
|
||||
content=command_data,
|
||||
target_id=str(chat_stream.group_info.group_id),
|
||||
is_group=True,
|
||||
display_message=display_message or f"执行命令: {command_name}",
|
||||
platform=chat_stream.platform,
|
||||
)
|
||||
else:
|
||||
# 私聊
|
||||
success = await self.api.send_message_to_target(
|
||||
from src.plugin_system.apis import send_api
|
||||
success = await send_api.custom_message(
|
||||
message_type="command",
|
||||
content=command_content,
|
||||
platform=chat_stream.platform,
|
||||
content=command_data,
|
||||
target_id=str(chat_stream.user_info.user_id),
|
||||
is_group=False,
|
||||
display_message=display_message or f"执行命令: {command_name}",
|
||||
platform=chat_stream.platform,
|
||||
)
|
||||
|
||||
if success:
|
||||
@@ -172,7 +202,7 @@ class BaseCommand(ABC):
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_command_info(cls, name: str = None, description: str = None) -> "CommandInfo":
|
||||
def get_command_info(cls) -> "CommandInfo":
|
||||
"""从类属性生成CommandInfo
|
||||
|
||||
Args:
|
||||
@@ -183,19 +213,10 @@ class BaseCommand(ABC):
|
||||
CommandInfo: 生成的Command信息对象
|
||||
"""
|
||||
|
||||
# 优先使用类属性,然后自动生成
|
||||
if name is None:
|
||||
name = getattr(cls, "command_name", cls.__name__.lower().replace("command", ""))
|
||||
if description is None:
|
||||
description = getattr(cls, "command_description", None)
|
||||
if description is None:
|
||||
description = cls.__doc__ or f"{cls.__name__} Command组件"
|
||||
description = description.strip().split("\n")[0] # 取第一行作为描述
|
||||
|
||||
return CommandInfo(
|
||||
name=name,
|
||||
name=cls.command_name,
|
||||
component_type=ComponentType.COMMAND,
|
||||
description=description,
|
||||
description=cls.command_description,
|
||||
command_pattern=cls.command_pattern,
|
||||
command_help=cls.command_help,
|
||||
command_examples=cls.command_examples.copy() if cls.command_examples else [],
|
||||
|
||||
@@ -54,23 +54,43 @@ class ComponentRegistry:
|
||||
"""
|
||||
component_name = component_info.name
|
||||
component_type = component_info.component_type
|
||||
plugin_name = getattr(component_info, 'plugin_name', 'unknown')
|
||||
|
||||
if component_name in self._components:
|
||||
logger.warning(f"组件 {component_name} 已存在,跳过注册")
|
||||
# 🔥 系统级别自动区分:为不同类型的组件添加命名空间前缀
|
||||
if component_type == ComponentType.ACTION:
|
||||
namespaced_name = f"action.{component_name}"
|
||||
elif component_type == ComponentType.COMMAND:
|
||||
namespaced_name = f"command.{component_name}"
|
||||
else:
|
||||
# 未来扩展的组件类型
|
||||
namespaced_name = f"{component_type.value}.{component_name}"
|
||||
|
||||
# 检查命名空间化的名称是否冲突
|
||||
if namespaced_name in self._components:
|
||||
existing_info = self._components[namespaced_name]
|
||||
existing_plugin = getattr(existing_info, 'plugin_name', 'unknown')
|
||||
|
||||
logger.warning(
|
||||
f"组件冲突: {component_type.value}组件 '{component_name}' "
|
||||
f"已被插件 '{existing_plugin}' 注册,跳过插件 '{plugin_name}' 的注册"
|
||||
)
|
||||
return False
|
||||
|
||||
# 注册到通用注册表
|
||||
self._components[component_name] = component_info
|
||||
self._components_by_type[component_type][component_name] = component_info
|
||||
self._component_classes[component_name] = component_class
|
||||
# 注册到通用注册表(使用命名空间化的名称)
|
||||
self._components[namespaced_name] = component_info
|
||||
self._components_by_type[component_type][component_name] = component_info # 类型内部仍使用原名
|
||||
self._component_classes[namespaced_name] = component_class
|
||||
|
||||
# 根据组件类型进行特定注册
|
||||
# 根据组件类型进行特定注册(使用原始名称)
|
||||
if component_type == ComponentType.ACTION:
|
||||
self._register_action_component(component_info, component_class)
|
||||
elif component_type == ComponentType.COMMAND:
|
||||
self._register_command_component(component_info, component_class)
|
||||
|
||||
logger.debug(f"已注册{component_type.value}组件: {component_name} ({component_class.__name__})")
|
||||
logger.debug(
|
||||
f"已注册{component_type.value}组件: '{component_name}' -> '{namespaced_name}' "
|
||||
f"({component_class.__name__}) [插件: {plugin_name}]"
|
||||
)
|
||||
return True
|
||||
|
||||
def _register_action_component(self, action_info: ActionInfo, action_class: Type):
|
||||
@@ -94,13 +114,103 @@ class ComponentRegistry:
|
||||
|
||||
# === 组件查询方法 ===
|
||||
|
||||
def get_component_info(self, component_name: str) -> Optional[ComponentInfo]:
|
||||
"""获取组件信息"""
|
||||
return self._components.get(component_name)
|
||||
def get_component_info(self, component_name: str, component_type: ComponentType = None) -> Optional[ComponentInfo]:
|
||||
"""获取组件信息,支持自动命名空间解析
|
||||
|
||||
Args:
|
||||
component_name: 组件名称,可以是原始名称或命名空间化的名称
|
||||
component_type: 组件类型,如果提供则优先在该类型中查找
|
||||
|
||||
Returns:
|
||||
Optional[ComponentInfo]: 组件信息或None
|
||||
"""
|
||||
# 1. 如果已经是命名空间化的名称,直接查找
|
||||
if '.' in component_name:
|
||||
return self._components.get(component_name)
|
||||
|
||||
# 2. 如果指定了组件类型,构造命名空间化的名称查找
|
||||
if component_type:
|
||||
if component_type == ComponentType.ACTION:
|
||||
namespaced_name = f"action.{component_name}"
|
||||
elif component_type == ComponentType.COMMAND:
|
||||
namespaced_name = f"command.{component_name}"
|
||||
else:
|
||||
namespaced_name = f"{component_type.value}.{component_name}"
|
||||
|
||||
return self._components.get(namespaced_name)
|
||||
|
||||
# 3. 如果没有指定类型,尝试在所有命名空间中查找
|
||||
candidates = []
|
||||
for namespace_prefix in ["action", "command"]:
|
||||
namespaced_name = f"{namespace_prefix}.{component_name}"
|
||||
component_info = self._components.get(namespaced_name)
|
||||
if component_info:
|
||||
candidates.append((namespace_prefix, namespaced_name, component_info))
|
||||
|
||||
if len(candidates) == 1:
|
||||
# 只有一个匹配,直接返回
|
||||
return candidates[0][2]
|
||||
elif len(candidates) > 1:
|
||||
# 多个匹配,记录警告并返回第一个
|
||||
namespaces = [ns for ns, _, _ in candidates]
|
||||
logger.warning(
|
||||
f"组件名称 '{component_name}' 在多个命名空间中存在: {namespaces},"
|
||||
f"使用第一个匹配项: {candidates[0][1]}"
|
||||
)
|
||||
return candidates[0][2]
|
||||
|
||||
# 4. 都没找到
|
||||
return None
|
||||
|
||||
def get_component_class(self, component_name: str) -> Optional[Type]:
|
||||
"""获取组件类"""
|
||||
return self._component_classes.get(component_name)
|
||||
def get_component_class(self, component_name: str, component_type: ComponentType = None) -> Optional[Type]:
|
||||
"""获取组件类,支持自动命名空间解析
|
||||
|
||||
Args:
|
||||
component_name: 组件名称,可以是原始名称或命名空间化的名称
|
||||
component_type: 组件类型,如果提供则优先在该类型中查找
|
||||
|
||||
Returns:
|
||||
Optional[Type]: 组件类或None
|
||||
"""
|
||||
# 1. 如果已经是命名空间化的名称,直接查找
|
||||
if '.' in component_name:
|
||||
return self._component_classes.get(component_name)
|
||||
|
||||
# 2. 如果指定了组件类型,构造命名空间化的名称查找
|
||||
if component_type:
|
||||
if component_type == ComponentType.ACTION:
|
||||
namespaced_name = f"action.{component_name}"
|
||||
elif component_type == ComponentType.COMMAND:
|
||||
namespaced_name = f"command.{component_name}"
|
||||
else:
|
||||
namespaced_name = f"{component_type.value}.{component_name}"
|
||||
|
||||
return self._component_classes.get(namespaced_name)
|
||||
|
||||
# 3. 如果没有指定类型,尝试在所有命名空间中查找
|
||||
candidates = []
|
||||
for namespace_prefix in ["action", "command"]:
|
||||
namespaced_name = f"{namespace_prefix}.{component_name}"
|
||||
component_class = self._component_classes.get(namespaced_name)
|
||||
if component_class:
|
||||
candidates.append((namespace_prefix, namespaced_name, component_class))
|
||||
|
||||
if len(candidates) == 1:
|
||||
# 只有一个匹配,直接返回
|
||||
namespace, full_name, cls = candidates[0]
|
||||
logger.debug(f"自动解析组件: '{component_name}' -> '{full_name}'")
|
||||
return cls
|
||||
elif len(candidates) > 1:
|
||||
# 多个匹配,记录警告并返回第一个
|
||||
namespaces = [ns for ns, _, _ in candidates]
|
||||
logger.warning(
|
||||
f"组件名称 '{component_name}' 在多个命名空间中存在: {namespaces},"
|
||||
f"使用第一个匹配项: {candidates[0][1]}"
|
||||
)
|
||||
return candidates[0][2]
|
||||
|
||||
# 4. 都没找到
|
||||
return None
|
||||
|
||||
def get_components_by_type(self, component_type: ComponentType) -> Dict[str, ComponentInfo]:
|
||||
"""获取指定类型的所有组件"""
|
||||
@@ -123,7 +233,7 @@ class ComponentRegistry:
|
||||
|
||||
def get_action_info(self, action_name: str) -> Optional[ActionInfo]:
|
||||
"""获取Action信息"""
|
||||
info = self.get_component_info(action_name)
|
||||
info = self.get_component_info(action_name, ComponentType.ACTION)
|
||||
return info if isinstance(info, ActionInfo) else None
|
||||
|
||||
# === Command特定查询方法 ===
|
||||
@@ -138,7 +248,7 @@ class ComponentRegistry:
|
||||
|
||||
def get_command_info(self, command_name: str) -> Optional[CommandInfo]:
|
||||
"""获取Command信息"""
|
||||
info = self.get_component_info(command_name)
|
||||
info = self.get_component_info(command_name, ComponentType.COMMAND)
|
||||
return info if isinstance(info, CommandInfo) else None
|
||||
|
||||
def find_command_by_text(self, text: str) -> Optional[tuple[Type, dict, bool, str]]:
|
||||
@@ -150,7 +260,9 @@ class ComponentRegistry:
|
||||
Returns:
|
||||
Optional[tuple[Type, dict, bool, str]]: (命令类, 匹配的命名组, 是否拦截消息, 插件名) 或 None
|
||||
"""
|
||||
|
||||
for pattern, command_class in self._command_patterns.items():
|
||||
|
||||
match = pattern.match(text)
|
||||
if match:
|
||||
command_name = None
|
||||
@@ -159,17 +271,18 @@ class ComponentRegistry:
|
||||
if cls == command_class:
|
||||
command_name = name
|
||||
break
|
||||
|
||||
|
||||
# 检查命令是否启用
|
||||
if command_name:
|
||||
command_info = self.get_command_info(command_name)
|
||||
if command_info and command_info.enabled:
|
||||
return (
|
||||
command_class,
|
||||
match.groupdict(),
|
||||
command_info.intercept_message,
|
||||
command_info.plugin_name,
|
||||
)
|
||||
if command_info:
|
||||
if command_info.enabled:
|
||||
return (
|
||||
command_class,
|
||||
match.groupdict(),
|
||||
command_info.intercept_message,
|
||||
command_info.plugin_name,
|
||||
)
|
||||
return None
|
||||
|
||||
# === 插件管理方法 ===
|
||||
@@ -227,26 +340,51 @@ class ComponentRegistry:
|
||||
|
||||
# === 状态管理方法 ===
|
||||
|
||||
def enable_component(self, component_name: str) -> bool:
|
||||
"""启用组件"""
|
||||
if component_name in self._components:
|
||||
self._components[component_name].enabled = True
|
||||
def enable_component(self, component_name: str, component_type: ComponentType = None) -> bool:
|
||||
"""启用组件,支持命名空间解析"""
|
||||
# 首先尝试找到正确的命名空间化名称
|
||||
component_info = self.get_component_info(component_name, component_type)
|
||||
if not component_info:
|
||||
return False
|
||||
|
||||
# 根据组件类型构造正确的命名空间化名称
|
||||
if component_info.component_type == ComponentType.ACTION:
|
||||
namespaced_name = f"action.{component_name}" if '.' not in component_name else component_name
|
||||
elif component_info.component_type == ComponentType.COMMAND:
|
||||
namespaced_name = f"command.{component_name}" if '.' not in component_name else component_name
|
||||
else:
|
||||
namespaced_name = f"{component_info.component_type.value}.{component_name}" if '.' not in component_name else component_name
|
||||
|
||||
if namespaced_name in self._components:
|
||||
self._components[namespaced_name].enabled = True
|
||||
# 如果是Action,更新默认动作集
|
||||
component_info = self._components[component_name]
|
||||
if isinstance(component_info, ActionInfo):
|
||||
self._default_actions[component_name] = component_info.description
|
||||
logger.debug(f"已启用组件: {component_name}")
|
||||
logger.debug(f"已启用组件: {component_name} -> {namespaced_name}")
|
||||
return True
|
||||
return False
|
||||
|
||||
def disable_component(self, component_name: str) -> bool:
|
||||
"""禁用组件"""
|
||||
if component_name in self._components:
|
||||
self._components[component_name].enabled = False
|
||||
def disable_component(self, component_name: str, component_type: ComponentType = None) -> bool:
|
||||
"""禁用组件,支持命名空间解析"""
|
||||
# 首先尝试找到正确的命名空间化名称
|
||||
component_info = self.get_component_info(component_name, component_type)
|
||||
if not component_info:
|
||||
return False
|
||||
|
||||
# 根据组件类型构造正确的命名空间化名称
|
||||
if component_info.component_type == ComponentType.ACTION:
|
||||
namespaced_name = f"action.{component_name}" if '.' not in component_name else component_name
|
||||
elif component_info.component_type == ComponentType.COMMAND:
|
||||
namespaced_name = f"command.{component_name}" if '.' not in component_name else component_name
|
||||
else:
|
||||
namespaced_name = f"{component_info.component_type.value}.{component_name}" if '.' not in component_name else component_name
|
||||
|
||||
if namespaced_name in self._components:
|
||||
self._components[namespaced_name].enabled = False
|
||||
# 如果是Action,从默认动作集中移除
|
||||
if component_name in self._default_actions:
|
||||
del self._default_actions[component_name]
|
||||
logger.debug(f"已禁用组件: {component_name}")
|
||||
logger.debug(f"已禁用组件: {component_name} -> {namespaced_name}")
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
# 核心动作插件配置文件
|
||||
|
||||
[plugin]
|
||||
name = "core_actions"
|
||||
description = "系统核心动作插件"
|
||||
version = "0.2"
|
||||
author = "built-in"
|
||||
enabled = true
|
||||
|
||||
[no_reply]
|
||||
# 等待新消息的超时时间(秒)
|
||||
waiting_timeout = 1200
|
||||
|
||||
[emoji]
|
||||
# 表情动作配置
|
||||
enabled = true
|
||||
# 在Normal模式下的随机激活概率
|
||||
random_probability = 0.1
|
||||
# 是否启用智能表情选择
|
||||
smart_selection = true
|
||||
|
||||
# LLM判断相关配置
|
||||
[emoji.llm_judge]
|
||||
# 是否启用LLM智能判断
|
||||
enabled = true
|
||||
# 自定义判断提示词(可选)
|
||||
custom_prompt = ""
|
||||
@@ -5,18 +5,18 @@
|
||||
这是系统的内置插件,提供基础的聊天交互功能
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import List, Tuple, Type, Optional
|
||||
import time
|
||||
from typing import List, Tuple, Type
|
||||
|
||||
# 导入新插件系统
|
||||
from src.plugin_system import BasePlugin, register_plugin, BaseAction, ComponentInfo, ActionActivationType, ChatMode
|
||||
from src.plugin_system.base.base_command import BaseCommand
|
||||
from src.plugin_system.base.config_types import ConfigField
|
||||
|
||||
# 导入依赖的系统组件
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
|
||||
from src.chat.focus_chat.hfc_utils import create_empty_anchor_message
|
||||
|
||||
# 导入API模块 - 标准Python包方式
|
||||
from src.plugin_system.apis import emoji_api, generator_api, message_api
|
||||
|
||||
logger = get_logger("core_actions")
|
||||
|
||||
@@ -35,11 +35,11 @@ class ReplyAction(BaseAction):
|
||||
|
||||
# 动作基本信息
|
||||
action_name = "reply"
|
||||
action_description = "参与聊天回复,处理文本和表情的发送"
|
||||
action_description = "参与聊天回复,发送文本进行表达"
|
||||
|
||||
# 动作参数定义
|
||||
action_parameters = {
|
||||
"reply_to": "如果是明确回复某个人的发言,请在reply_to参数中指定,格式:(用户名:发言内容),如果不是,reply_to的值设为none"
|
||||
"reply_to": "你要回复的对方的发言内容,格式:(用户名:发言内容),可以为none"
|
||||
}
|
||||
|
||||
# 动作使用场景
|
||||
@@ -52,40 +52,52 @@ class ReplyAction(BaseAction):
|
||||
"""执行回复动作"""
|
||||
logger.info(f"{self.log_prefix} 决定回复: {self.reasoning}")
|
||||
|
||||
start_time = self.action_data.get("loop_start_time", time.time())
|
||||
|
||||
try:
|
||||
# 获取聊天观察
|
||||
chatting_observation = self._get_chatting_observation()
|
||||
if not chatting_observation:
|
||||
return False, "未找到聊天观察"
|
||||
|
||||
# 处理回复目标
|
||||
anchor_message = await self._resolve_reply_target(chatting_observation)
|
||||
|
||||
# 获取回复器服务
|
||||
replyer = self.api.get_service("replyer")
|
||||
if not replyer:
|
||||
logger.error(f"{self.log_prefix} 未找到回复器服务")
|
||||
return False, "回复器服务不可用"
|
||||
|
||||
# 执行回复
|
||||
success, reply_set = await replyer.deal_reply(
|
||||
cycle_timers=self.cycle_timers,
|
||||
|
||||
success, reply_set = await generator_api.generate_reply(
|
||||
chat_stream=self.chat_stream,
|
||||
action_data=self.action_data,
|
||||
anchor_message=anchor_message,
|
||||
reasoning=self.reasoning,
|
||||
thinking_id=self.thinking_id,
|
||||
platform=self.platform,
|
||||
chat_id=self.chat_id,
|
||||
is_group=self.is_group
|
||||
)
|
||||
|
||||
# 检查从start_time以来的新消息数量
|
||||
# 获取动作触发时间或使用默认值
|
||||
current_time = time.time()
|
||||
new_message_count = message_api.count_new_messages(
|
||||
chat_id=self.chat_id,
|
||||
start_time=start_time,
|
||||
end_time=current_time
|
||||
)
|
||||
|
||||
# 根据新消息数量决定是否使用reply_to
|
||||
need_reply = new_message_count >= 4
|
||||
logger.info(f"{self.log_prefix} 从{start_time}到{current_time}共有{new_message_count}条新消息,{'使用' if need_reply else '不使用'}reply_to")
|
||||
|
||||
# 构建回复文本
|
||||
reply_text = self._build_reply_text(reply_set)
|
||||
reply_text = ""
|
||||
first_reply = False
|
||||
for reply_seg in reply_set:
|
||||
data = reply_seg[1]
|
||||
if not first_reply and need_reply:
|
||||
await self.send_text(
|
||||
content=data,
|
||||
reply_to=self.action_data.get("reply_to", "")
|
||||
)
|
||||
else:
|
||||
await self.send_text(content=data)
|
||||
first_reply = True
|
||||
reply_text += data
|
||||
|
||||
|
||||
# 存储动作记录
|
||||
await self.api.store_action_info(
|
||||
await self.store_action_info(
|
||||
action_build_into_prompt=False,
|
||||
action_prompt_display=reply_text,
|
||||
action_done=True,
|
||||
thinking_id=self.thinking_id,
|
||||
action_data=self.action_data,
|
||||
)
|
||||
|
||||
# 重置NoReplyAction的连续计数器
|
||||
@@ -97,47 +109,6 @@ class ReplyAction(BaseAction):
|
||||
logger.error(f"{self.log_prefix} 回复动作执行失败: {e}")
|
||||
return False, f"回复失败: {str(e)}"
|
||||
|
||||
def _get_chatting_observation(self) -> Optional[ChattingObservation]:
|
||||
"""获取聊天观察对象"""
|
||||
observations = self.api.get_service("observations") or []
|
||||
for obs in observations:
|
||||
if isinstance(obs, ChattingObservation):
|
||||
return obs
|
||||
return None
|
||||
|
||||
async def _resolve_reply_target(self, chatting_observation: ChattingObservation):
|
||||
"""解析回复目标消息"""
|
||||
reply_to = self.action_data.get("reply_to", "none")
|
||||
|
||||
if ":" in reply_to or ":" in reply_to:
|
||||
# 解析回复目标格式:用户名:消息内容
|
||||
parts = re.split(pattern=r"[::]", string=reply_to, maxsplit=1)
|
||||
if len(parts) == 2:
|
||||
target = parts[1].strip()
|
||||
anchor_message = chatting_observation.search_message_by_text(target)
|
||||
if anchor_message:
|
||||
chat_stream = self.api.get_service("chat_stream")
|
||||
if chat_stream:
|
||||
anchor_message.update_chat_stream(chat_stream)
|
||||
return anchor_message
|
||||
|
||||
# 创建空锚点消息
|
||||
logger.info(f"{self.log_prefix} 未找到锚点消息,创建占位符")
|
||||
chat_stream = self.api.get_service("chat_stream")
|
||||
if chat_stream:
|
||||
return await create_empty_anchor_message(chat_stream.platform, chat_stream.group_info, chat_stream)
|
||||
return None
|
||||
|
||||
def _build_reply_text(self, reply_set) -> str:
|
||||
"""构建回复文本"""
|
||||
reply_text = ""
|
||||
if reply_set:
|
||||
for reply in reply_set:
|
||||
reply_type = reply[0]
|
||||
data = reply[1]
|
||||
if reply_type in ["text", "emoji"]:
|
||||
reply_text += data
|
||||
return reply_text
|
||||
|
||||
|
||||
class NoReplyAction(BaseAction):
|
||||
@@ -178,30 +149,26 @@ class NoReplyAction(BaseAction):
|
||||
count = NoReplyAction._consecutive_count
|
||||
|
||||
# 计算本次等待时间
|
||||
timeout = self._calculate_waiting_time(count)
|
||||
if count <= len(self._waiting_stages):
|
||||
# 前3次使用预设时间
|
||||
stage_time = self._waiting_stages[count - 1]
|
||||
# 如果WAITING_TIME_THRESHOLD更小,则使用它
|
||||
timeout = min(stage_time, self.waiting_timeout)
|
||||
else:
|
||||
# 第4次及以后使用WAITING_TIME_THRESHOLD
|
||||
timeout = self.waiting_timeout
|
||||
|
||||
logger.info(f"{self.log_prefix} 选择不回复(第{count}次连续),等待新消息中... (超时: {timeout}秒)")
|
||||
|
||||
# 等待新消息或达到时间上限
|
||||
result = await self.api.wait_for_new_message(timeout)
|
||||
result = await self.wait_for_new_message(timeout)
|
||||
|
||||
# 如果有新消息或者超时,都不重置计数器,因为可能还会继续no_reply
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 不回复动作执行失败: {e}")
|
||||
return False, f"不回复动作执行失败: {e}"
|
||||
|
||||
def _calculate_waiting_time(self, consecutive_count: int) -> int:
|
||||
"""根据连续次数计算等待时间"""
|
||||
if consecutive_count <= len(self._waiting_stages):
|
||||
# 前3次使用预设时间
|
||||
stage_time = self._waiting_stages[consecutive_count - 1]
|
||||
# 如果WAITING_TIME_THRESHOLD更小,则使用它
|
||||
return min(stage_time, self.waiting_timeout)
|
||||
else:
|
||||
# 第4次及以后使用WAITING_TIME_THRESHOLD
|
||||
return self.waiting_timeout
|
||||
return False, f"不回复动作执行失败: {e}"
|
||||
|
||||
@classmethod
|
||||
def reset_consecutive_count(cls):
|
||||
@@ -248,56 +215,33 @@ class EmojiAction(BaseAction):
|
||||
logger.info(f"{self.log_prefix} 决定发送表情")
|
||||
|
||||
try:
|
||||
# 创建空锚点消息
|
||||
anchor_message = await self._create_anchor_message()
|
||||
if not anchor_message:
|
||||
return False, "无法创建锚点消息"
|
||||
|
||||
# 获取回复器服务
|
||||
replyer = self.api.get_service("replyer")
|
||||
if not replyer:
|
||||
logger.error(f"{self.log_prefix} 未找到回复器服务")
|
||||
return False, "回复器服务不可用"
|
||||
|
||||
# 执行表情处理
|
||||
success, reply_set = await replyer.deal_emoji(
|
||||
cycle_timers=self.cycle_timers,
|
||||
action_data=self.action_data,
|
||||
anchor_message=anchor_message,
|
||||
thinking_id=self.thinking_id,
|
||||
)
|
||||
|
||||
# 构建回复文本
|
||||
reply_text = self._build_reply_text(reply_set)
|
||||
# 1. 根据描述选择表情包
|
||||
description = self.action_data.get("description", "")
|
||||
emoji_result = await emoji_api.get_by_description(description)
|
||||
|
||||
if not emoji_result:
|
||||
logger.warning(f"{self.log_prefix} 未找到匹配描述 '{description}' 的表情包")
|
||||
return False, f"未找到匹配 '{description}' 的表情包"
|
||||
|
||||
emoji_base64, emoji_description, matched_emotion = emoji_result
|
||||
logger.info(f"{self.log_prefix} 找到表情包: {emoji_description}, 匹配情感: {matched_emotion}")
|
||||
|
||||
# 使用BaseAction的便捷方法发送表情包
|
||||
success = await self.send_emoji(emoji_base64)
|
||||
|
||||
if not success:
|
||||
logger.error(f"{self.log_prefix} 表情包发送失败")
|
||||
return False, "表情包发送失败"
|
||||
|
||||
# 重置NoReplyAction的连续计数器
|
||||
NoReplyAction.reset_consecutive_count()
|
||||
|
||||
return success, reply_text
|
||||
return True, f"发送表情包: {emoji_description}"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 表情动作执行失败: {e}")
|
||||
return False, f"表情发送失败: {str(e)}"
|
||||
|
||||
async def _create_anchor_message(self):
|
||||
"""创建锚点消息"""
|
||||
chat_stream = self.api.get_service("chat_stream")
|
||||
if chat_stream:
|
||||
logger.info(f"{self.log_prefix} 为表情包创建占位符")
|
||||
return await create_empty_anchor_message(chat_stream.platform, chat_stream.group_info, chat_stream)
|
||||
return None
|
||||
|
||||
def _build_reply_text(self, reply_set) -> str:
|
||||
"""构建回复文本"""
|
||||
reply_text = ""
|
||||
if reply_set:
|
||||
for reply in reply_set:
|
||||
reply_type = reply[0]
|
||||
data = reply[1]
|
||||
if reply_type in ["text", "emoji"]:
|
||||
reply_text += data
|
||||
return reply_text
|
||||
|
||||
|
||||
class ChangeToFocusChatAction(BaseAction):
|
||||
"""切换到专注聊天动作 - 从普通模式切换到专注模式"""
|
||||
@@ -314,6 +258,7 @@ class ChangeToFocusChatAction(BaseAction):
|
||||
# 动作参数定义
|
||||
action_parameters = {}
|
||||
|
||||
apex = 111
|
||||
# 动作使用场景
|
||||
action_require = [
|
||||
"你想要进入专注聊天模式",
|
||||
@@ -437,8 +382,6 @@ class CoreActionsPlugin(BasePlugin):
|
||||
"enable_emoji": ConfigField(type=bool, default=True, description="是否启用'表情'动作"),
|
||||
"enable_change_to_focus": ConfigField(type=bool, default=True, description="是否启用'切换到专注模式'动作"),
|
||||
"enable_exit_focus": ConfigField(type=bool, default=True, description="是否启用'退出专注模式'动作"),
|
||||
"enable_ping_command": ConfigField(type=bool, default=True, description="是否启用'/ping'测试命令"),
|
||||
"enable_log_command": ConfigField(type=bool, default=True, description="是否启用'/log'日志命令"),
|
||||
},
|
||||
"no_reply": {
|
||||
"waiting_timeout": ConfigField(
|
||||
@@ -482,73 +425,137 @@ class CoreActionsPlugin(BasePlugin):
|
||||
components.append((ExitFocusChatAction.get_action_info(), ExitFocusChatAction))
|
||||
if self.get_config("components.enable_change_to_focus", True):
|
||||
components.append((ChangeToFocusChatAction.get_action_info(), ChangeToFocusChatAction))
|
||||
if self.get_config("components.enable_ping_command", True):
|
||||
components.append(
|
||||
(PingCommand.get_command_info(name="ping", description="测试机器人响应,拦截后续处理"), PingCommand)
|
||||
)
|
||||
if self.get_config("components.enable_log_command", True):
|
||||
components.append(
|
||||
(LogCommand.get_command_info(name="log", description="记录消息到日志,不拦截后续处理"), LogCommand)
|
||||
)
|
||||
# components.append((DeepReplyAction.get_action_info(), DeepReplyAction))
|
||||
|
||||
return components
|
||||
|
||||
|
||||
# ===== 示例Command组件 =====
|
||||
|
||||
|
||||
class PingCommand(BaseCommand):
|
||||
"""Ping命令 - 测试响应,拦截消息处理"""
|
||||
|
||||
command_pattern = r"^/ping(\s+(?P<message>.+))?$"
|
||||
command_help = "测试机器人响应 - 拦截后续处理"
|
||||
command_examples = ["/ping", "/ping 测试消息"]
|
||||
intercept_message = True # 拦截消息,不继续处理
|
||||
|
||||
async def execute(self) -> Tuple[bool, Optional[str]]:
|
||||
"""执行ping命令"""
|
||||
try:
|
||||
message = self.matched_groups.get("message", "")
|
||||
reply_text = f"🏓 Pong! {message}" if message else "🏓 Pong!"
|
||||
|
||||
await self.send_text(reply_text)
|
||||
return True, f"发送ping响应: {reply_text}"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Ping命令执行失败: {e}")
|
||||
return False, f"执行失败: {str(e)}"
|
||||
|
||||
|
||||
class LogCommand(BaseCommand):
|
||||
"""日志命令 - 记录消息但不拦截后续处理"""
|
||||
# class DeepReplyAction(BaseAction):
|
||||
# """回复动作 - 参与聊天回复"""
|
||||
|
||||
command_pattern = r"^/log(\s+(?P<level>debug|info|warn|error))?$"
|
||||
command_help = "记录当前消息到日志 - 不拦截后续处理"
|
||||
command_examples = ["/log", "/log info", "/log debug"]
|
||||
intercept_message = False # 不拦截消息,继续后续处理
|
||||
# # 激活设置
|
||||
# focus_activation_type = ActionActivationType.ALWAYS
|
||||
# normal_activation_type = ActionActivationType.NEVER
|
||||
# mode_enable = ChatMode.FOCUS
|
||||
# parallel_action = False
|
||||
|
||||
async def execute(self) -> Tuple[bool, Optional[str]]:
|
||||
"""执行日志命令"""
|
||||
try:
|
||||
level = self.matched_groups.get("level", "info")
|
||||
user_nickname = self.message.message_info.user_info.user_nickname
|
||||
content = self.message.processed_plain_text
|
||||
# # 动作基本信息
|
||||
# action_name = "deep_reply"
|
||||
# action_description = "参与聊天回复,关注某个话题,对聊天内容进行深度思考,给出回复"
|
||||
|
||||
log_message = f"[{level.upper()}] 用户 {user_nickname}: {content}"
|
||||
# # 动作参数定义
|
||||
# action_parameters = {
|
||||
# "topic": "想要思考的话题"
|
||||
# }
|
||||
|
||||
# 根据级别记录日志
|
||||
if level == "debug":
|
||||
logger.debug(log_message)
|
||||
elif level == "warn":
|
||||
logger.warning(log_message)
|
||||
elif level == "error":
|
||||
logger.error(log_message)
|
||||
else:
|
||||
logger.info(log_message)
|
||||
# # 动作使用场景
|
||||
# action_require = ["有些问题需要深度思考", "某个问题可能涉及多个方面", "某个问题涉及专业领域或者需要专业知识","这个问题讨论的很激烈,需要深度思考"]
|
||||
|
||||
# 不发送回复,让消息继续处理
|
||||
return True, f"已记录到{level}级别日志"
|
||||
# # 关联类型
|
||||
# associated_types = ["text"]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Log命令执行失败: {e}")
|
||||
return False, f"执行失败: {str(e)}"
|
||||
# async def execute(self) -> Tuple[bool, str]:
|
||||
# """执行回复动作"""
|
||||
# logger.info(f"{self.log_prefix} 决定深度思考")
|
||||
|
||||
# try:
|
||||
# # 获取聊天观察
|
||||
# chatting_observation = self._get_chatting_observation()
|
||||
# if not chatting_observation:
|
||||
# return False, "未找到聊天观察"
|
||||
|
||||
# talking_message_str = chatting_observation.talking_message_str
|
||||
|
||||
# # 处理回复目标
|
||||
# chat_stream = self.api.get_service("chat_stream")
|
||||
# anchor_message = await create_empty_anchor_message(chat_stream.platform, chat_stream.group_info, chat_stream)
|
||||
|
||||
|
||||
|
||||
# llm_model = self.api.get_available_models().replyer_1
|
||||
|
||||
# prompt = f"""
|
||||
# {talking_message_str}
|
||||
|
||||
# 在上面的聊天中,你对{self.action_data.get("topic", "")}感兴趣,形成深刻观点,请你思考,总结成一份学术论文,APA标准格式
|
||||
# """
|
||||
|
||||
# success, response, reasoning, model_name = await self.api.generate_with_model(prompt, llm_model)
|
||||
|
||||
# print(prompt)
|
||||
# print(f"DeepReplyAction: {response}")
|
||||
|
||||
# # prompt = f"""
|
||||
# # {talking_message_str}
|
||||
|
||||
# # 在上面的聊天中,你对{self.action_data.get("topic", "")}感兴趣,请你思考
|
||||
# # """
|
||||
|
||||
# extra_info_block = self.action_data.get("extra_info_block", "")
|
||||
# extra_info_block += response
|
||||
# # extra_info_block += f"\n--------------------------------\n注意,这是最重要的内容!!!!!你现在可以用比较长的篇幅来表达你的观点,不要只回复一个字或者几个字\n由于你进入了深度思考模式,所以下方聊天记录的回复要求不再适用,请你自由的表达,不论字数长短限制\n\n--------------------------------\n注意,你现在可以用比较长的篇幅来表达你的观点,不要只回复一个字或者几个字\n由于你进入了深度思考模式,所以下方聊天记录的回复要求不再适用,请你自由的表达,不论字数长短限制\n"
|
||||
# # extra_info_block += f"\n--------------------------------\n注意,优先关注这句!!!!你现在可以用比较长的篇幅来表达你的观点,不要只回复一个字或者几个字\n由于你进入了深度思考模式,所以下方聊天记录的回复要求不再适用,请你自由的表达,不论字数长短限制\n\n--------------------------------\n注意,你现在可以用比较长的篇幅来表达你的观点,不要只回复一个字或者几个字\n由于你进入了深度思考模式,所以其他的回复要求不再适用,请你自由的表达,不论字数长短限制\n"
|
||||
# self.action_data["extra_info_block"] = extra_info_block
|
||||
|
||||
|
||||
|
||||
|
||||
# # 获取回复器服务
|
||||
# # replyer = self.api.get_service("replyer")
|
||||
# # if not replyer:
|
||||
# # logger.error(f"{self.log_prefix} 未找到回复器服务")
|
||||
# # return False, "回复器服务不可用"
|
||||
|
||||
# # await self.send_message_by_expressor(extra_info_block)
|
||||
# await self.send_text(extra_info_block)
|
||||
# # 执行回复
|
||||
# # success, reply_set = await replyer.deal_reply(
|
||||
# # cycle_timers=self.cycle_timers,
|
||||
# # action_data=self.action_data,
|
||||
# # anchor_message=anchor_message,
|
||||
# # reasoning=self.reasoning,
|
||||
# # thinking_id=self.thinking_id,
|
||||
# # )
|
||||
|
||||
# # 构建回复文本
|
||||
# reply_text = "self._build_reply_text(reply_set)"
|
||||
|
||||
# # 存储动作记录
|
||||
# await self.api.store_action_info(
|
||||
# action_build_into_prompt=False,
|
||||
# action_prompt_display=reply_text,
|
||||
# action_done=True,
|
||||
# thinking_id=self.thinking_id,
|
||||
# action_data=self.action_data,
|
||||
# )
|
||||
|
||||
# # 重置NoReplyAction的连续计数器
|
||||
# NoReplyAction.reset_consecutive_count()
|
||||
|
||||
# return success, reply_text
|
||||
|
||||
# except Exception as e:
|
||||
# logger.error(f"{self.log_prefix} 回复动作执行失败: {e}")
|
||||
# return False, f"回复失败: {str(e)}"
|
||||
|
||||
# def _get_chatting_observation(self) -> Optional[ChattingObservation]:
|
||||
# """获取聊天观察对象"""
|
||||
# observations = self.api.get_service("observations") or []
|
||||
# for obs in observations:
|
||||
# if isinstance(obs, ChattingObservation):
|
||||
# return obs
|
||||
# return None
|
||||
|
||||
|
||||
# def _build_reply_text(self, reply_set) -> str:
|
||||
# """构建回复文本"""
|
||||
# reply_text = ""
|
||||
# if reply_set:
|
||||
# for reply in reply_set:
|
||||
# data = reply[1]
|
||||
# reply_text += data
|
||||
# return reply_text
|
||||
@@ -26,6 +26,8 @@ from src.plugin_system.base.base_command import BaseCommand
|
||||
from src.plugin_system.base.component_types import ComponentInfo, ActionActivationType, ChatMode
|
||||
from src.plugin_system.base.config_types import ConfigField
|
||||
from src.common.logger import get_logger
|
||||
# 导入配置API(可选的简便方法)
|
||||
from src.plugin_system.apis import person_api, generator_api
|
||||
|
||||
logger = get_logger("mute_plugin")
|
||||
|
||||
@@ -110,8 +112,8 @@ class MuteAction(BaseAction):
|
||||
return False, error_msg
|
||||
|
||||
# 获取时长限制配置
|
||||
min_duration = self.api.get_config("mute.min_duration", 60)
|
||||
max_duration = self.api.get_config("mute.max_duration", 2592000)
|
||||
min_duration = self.get_config("mute.min_duration", 60)
|
||||
max_duration = self.get_config("mute.max_duration", 2592000)
|
||||
|
||||
# 验证时长格式并转换
|
||||
try:
|
||||
@@ -133,72 +135,65 @@ class MuteAction(BaseAction):
|
||||
except (ValueError, TypeError):
|
||||
error_msg = f"禁言时长格式无效: {duration}"
|
||||
logger.error(f"{self.log_prefix} {error_msg}")
|
||||
await self.send_text("禁言时长必须是数字哦~")
|
||||
# await self.send_text("禁言时长必须是数字哦~")
|
||||
return False, error_msg
|
||||
|
||||
# 获取用户ID
|
||||
try:
|
||||
platform, user_id = await self.api.get_user_id_by_person_name(target)
|
||||
except Exception as e:
|
||||
error_msg = f"查找用户ID时出错: {e}"
|
||||
logger.error(f"{self.log_prefix} {error_msg}")
|
||||
await self.send_text("查找用户信息时出现问题~")
|
||||
return False, error_msg
|
||||
|
||||
person_id = person_api.get_person_id_by_name(target)
|
||||
user_id = await person_api.get_person_value(person_id,"user_id")
|
||||
if not user_id:
|
||||
error_msg = f"未找到用户 {target} 的ID"
|
||||
await self.send_text(f"找不到 {target} 这个人呢~")
|
||||
logger.error(f"{self.log_prefix} {error_msg}")
|
||||
return False, error_msg
|
||||
|
||||
|
||||
# 格式化时长显示
|
||||
enable_formatting = self.api.get_config("mute.enable_duration_formatting", True)
|
||||
enable_formatting = self.get_config("mute.enable_duration_formatting", True)
|
||||
time_str = self._format_duration(duration_int) if enable_formatting else f"{duration_int}秒"
|
||||
|
||||
# 获取模板化消息
|
||||
message = self._get_template_message(target, time_str, reason)
|
||||
# await self.send_text(message)
|
||||
await self.send_message_by_expressor(message)
|
||||
|
||||
result_status,result_message = await generator_api.rewrite_reply(
|
||||
chat_stream=self.chat_stream,
|
||||
reply_data={
|
||||
"raw_reply": message,
|
||||
"reason": reason,
|
||||
}
|
||||
)
|
||||
|
||||
if result_status:
|
||||
for reply_seg in result_message:
|
||||
data = reply_seg[1]
|
||||
await self.send_text(data)
|
||||
|
||||
# 发送群聊禁言命令
|
||||
success = await self.send_command(
|
||||
command_name="GROUP_BAN",
|
||||
args={"qq_id": str(user_id), "duration": str(duration_int)},
|
||||
display_message="发送禁言命令",
|
||||
storage_message=False
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info(f"{self.log_prefix} 成功发送禁言命令,用户 {target}({user_id}),时长 {duration_int} 秒")
|
||||
# 存储动作信息
|
||||
await self.api.store_action_info(
|
||||
await self.store_action_info(
|
||||
action_build_into_prompt=True,
|
||||
action_prompt_display=f"尝试禁言了用户 {target},时长 {time_str},原因:{reason}",
|
||||
action_done=True,
|
||||
thinking_id=self.thinking_id,
|
||||
action_data={
|
||||
"target": target,
|
||||
"user_id": user_id,
|
||||
"duration": duration_int,
|
||||
"duration_str": time_str,
|
||||
"reason": reason,
|
||||
},
|
||||
)
|
||||
return True, f"成功禁言 {target},时长 {time_str}"
|
||||
else:
|
||||
error_msg = "发送禁言命令失败"
|
||||
logger.error(f"{self.log_prefix} {error_msg}")
|
||||
|
||||
await self.send_text("执行禁言动作失败")
|
||||
return False, error_msg
|
||||
|
||||
def _get_template_message(self, target: str, duration_str: str, reason: str) -> str:
|
||||
"""获取模板化的禁言消息"""
|
||||
templates = self.api.get_config(
|
||||
"mute.templates",
|
||||
[
|
||||
"好的,禁言 {target} {duration},理由:{reason}",
|
||||
"收到,对 {target} 执行禁言 {duration},因为{reason}",
|
||||
"明白了,禁言 {target} {duration},原因是{reason}",
|
||||
],
|
||||
templates = self.get_config(
|
||||
"mute.templates"
|
||||
)
|
||||
|
||||
template = random.choice(templates)
|
||||
@@ -258,8 +253,8 @@ class MuteCommand(BaseCommand):
|
||||
return False, "参数不完整"
|
||||
|
||||
# 获取时长限制配置
|
||||
min_duration = self.api.get_config("mute.min_duration", 60)
|
||||
max_duration = self.api.get_config("mute.max_duration", 2592000)
|
||||
min_duration = self.get_config("mute.min_duration", 60)
|
||||
max_duration = self.get_config("mute.max_duration", 2592000)
|
||||
|
||||
# 验证时长
|
||||
try:
|
||||
@@ -281,19 +276,16 @@ class MuteCommand(BaseCommand):
|
||||
return False, "时长格式错误"
|
||||
|
||||
# 获取用户ID
|
||||
try:
|
||||
platform, user_id = await self.api.get_user_id_by_person_name(target)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 查找用户ID时出错: {e}")
|
||||
await self.send_text("❌ 查找用户信息时出现问题")
|
||||
return False, str(e)
|
||||
|
||||
person_id = person_api.get_person_id_by_name(target)
|
||||
user_id = person_api.get_person_value(person_id, "user_id")
|
||||
if not user_id:
|
||||
error_msg = f"未找到用户 {target} 的ID"
|
||||
await self.send_text(f"❌ 找不到用户: {target}")
|
||||
return False, "用户不存在"
|
||||
logger.error(f"{self.log_prefix} {error_msg}")
|
||||
return False, error_msg
|
||||
|
||||
# 格式化时长显示
|
||||
enable_formatting = self.api.get_config("mute.enable_duration_formatting", True)
|
||||
enable_formatting = self.get_config("mute.enable_duration_formatting", True)
|
||||
time_str = self._format_duration(duration_int) if enable_formatting else f"{duration_int}秒"
|
||||
|
||||
logger.info(f"{self.log_prefix} 执行禁言命令: {target}({user_id}) -> {time_str}")
|
||||
@@ -323,14 +315,7 @@ class MuteCommand(BaseCommand):
|
||||
|
||||
def _get_template_message(self, target: str, duration_str: str, reason: str) -> str:
|
||||
"""获取模板化的禁言消息"""
|
||||
templates = self.api.get_config(
|
||||
"mute.templates",
|
||||
[
|
||||
"✅ 已禁言 {target} {duration},理由:{reason}",
|
||||
"🔇 对 {target} 执行禁言 {duration},因为{reason}",
|
||||
"⛔ 禁言 {target} {duration},原因:{reason}",
|
||||
],
|
||||
)
|
||||
templates = self.get_config("mute.templates")
|
||||
|
||||
template = random.choice(templates)
|
||||
return template.format(target=target, duration=duration_str, reason=reason)
|
||||
|
||||
@@ -1,534 +0,0 @@
|
||||
"""
|
||||
拍照插件
|
||||
|
||||
功能特性:
|
||||
- Action: 生成一张自拍照,prompt由人设和模板生成
|
||||
- Command: 展示最近生成的照片
|
||||
|
||||
#此插件并不完善
|
||||
#此插件并不完善
|
||||
|
||||
#此插件并不完善
|
||||
|
||||
#此插件并不完善
|
||||
|
||||
#此插件并不完善
|
||||
|
||||
#此插件并不完善
|
||||
|
||||
#此插件并不完善
|
||||
|
||||
|
||||
|
||||
包含组件:
|
||||
- 拍照Action - 生成自拍照
|
||||
- 展示照片Command - 展示最近生成的照片
|
||||
"""
|
||||
from typing import List, Tuple, Type, Optional
|
||||
import random
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
import asyncio
|
||||
import urllib.request
|
||||
import urllib.error
|
||||
import base64
|
||||
import traceback
|
||||
|
||||
from src.plugin_system.base.base_plugin import BasePlugin, register_plugin
|
||||
from src.plugin_system.base.base_action import BaseAction
|
||||
from src.plugin_system.base.base_command import BaseCommand
|
||||
from src.plugin_system.base.component_types import ComponentInfo, ActionActivationType, ChatMode
|
||||
from src.plugin_system.base.config_types import ConfigField
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("take_picture_plugin")
|
||||
|
||||
# 定义数据目录常量
|
||||
DATA_DIR = os.path.join("data", "take_picture_data")
|
||||
# 确保数据目录存在
|
||||
os.makedirs(DATA_DIR, exist_ok=True)
|
||||
# 创建全局锁
|
||||
file_lock = asyncio.Lock()
|
||||
|
||||
|
||||
class TakePictureAction(BaseAction):
|
||||
"""生成一张自拍照"""
|
||||
|
||||
focus_activation_type = ActionActivationType.KEYWORD
|
||||
normal_activation_type = ActionActivationType.KEYWORD
|
||||
mode_enable = ChatMode.ALL
|
||||
parallel_action = False
|
||||
|
||||
action_name = "take_picture"
|
||||
action_description = "生成一张用手机拍摄,比如自拍或者近照"
|
||||
activation_keywords = ["拍张照", "自拍", "发张照片", "看看你", "你的照片"]
|
||||
keyword_case_sensitive = False
|
||||
|
||||
action_parameters = {}
|
||||
|
||||
action_require = [
|
||||
"当用户想看你的照片时使用",
|
||||
"当用户让你发自拍时使用"
|
||||
"当想随手拍眼前的场景时使用"
|
||||
]
|
||||
|
||||
associated_types = ["text","image"]
|
||||
|
||||
# 内置的Prompt模板,如果配置文件中没有定义,将使用这些模板
|
||||
DEFAULT_PROMPT_TEMPLATES = [
|
||||
"极其频繁无奇的iPhone自拍照,没有明确的主体或构图感,就是随手一拍的快照照片略带运动模糊,阳光或室内打光不均匀导致的轻微曝光过度,整体呈现出一种刻意的平庸感,就像是从口袋里拿手机时不小心拍到的一张自拍。主角是{name},{personality}"
|
||||
]
|
||||
|
||||
# 简单的请求缓存,避免短时间内重复请求
|
||||
_request_cache = {}
|
||||
|
||||
async def execute(self) -> Tuple[bool, Optional[str]]:
|
||||
logger.info(f"{self.log_prefix} 执行拍照动作")
|
||||
|
||||
try:
|
||||
# 配置验证
|
||||
http_base_url = self.api.get_config("api.base_url")
|
||||
http_api_key = self.api.get_config("api.volcano_generate_api_key")
|
||||
|
||||
if not (http_base_url and http_api_key):
|
||||
error_msg = "抱歉,照片生成功能所需的API配置(如API地址或密钥)不完整,无法提供服务。"
|
||||
await self.send_text(error_msg)
|
||||
logger.error(f"{self.log_prefix} HTTP调用配置缺失: base_url 或 volcano_generate_api_key.")
|
||||
return False, "API配置不完整"
|
||||
|
||||
# API密钥验证
|
||||
if http_api_key == "YOUR_DOUBAO_API_KEY_HERE":
|
||||
error_msg = "照片生成功能尚未配置,请设置正确的API密钥。"
|
||||
await self.send_text(error_msg)
|
||||
logger.error(f"{self.log_prefix} API密钥未配置")
|
||||
return False, "API密钥未配置"
|
||||
|
||||
# 获取全局配置信息
|
||||
bot_nickname = self.api.get_global_config("bot.nickname", "麦麦")
|
||||
bot_personality = self.api.get_global_config("personality.personality_core", "")
|
||||
|
||||
|
||||
personality_sides = self.api.get_global_config("personality.personality_sides", [])
|
||||
if personality_sides:
|
||||
bot_personality += random.choice(personality_sides)
|
||||
|
||||
# 准备模板变量
|
||||
template_vars = {
|
||||
"name": bot_nickname,
|
||||
"personality": bot_personality
|
||||
}
|
||||
|
||||
logger.info(f"{self.log_prefix} 使用的全局配置: name={bot_nickname}, personality={bot_personality}")
|
||||
|
||||
# 尝试从配置文件获取模板,如果没有则使用默认模板
|
||||
templates = self.api.get_config("picture.prompt_templates", self.DEFAULT_PROMPT_TEMPLATES)
|
||||
if not templates:
|
||||
logger.warning(f"{self.log_prefix} 未找到有效的提示词模板,使用默认模板")
|
||||
templates = self.DEFAULT_PROMPT_TEMPLATES
|
||||
|
||||
prompt_template = random.choice(templates)
|
||||
|
||||
# 填充模板
|
||||
final_prompt = prompt_template.format(**template_vars)
|
||||
|
||||
logger.info(f"{self.log_prefix} 生成的最终Prompt: {final_prompt}")
|
||||
|
||||
# 从配置获取参数
|
||||
model = self.api.get_config("picture.default_model", "doubao-seedream-3-0-t2i-250415")
|
||||
style = self.api.get_config("picture.default_style", "动漫")
|
||||
size = self.api.get_config("picture.default_size", "1024x1024")
|
||||
watermark = self.api.get_config("picture.default_watermark", True)
|
||||
guidance_scale = self.api.get_config("picture.default_guidance_scale", 2.5)
|
||||
seed = self.api.get_config("picture.default_seed", 42)
|
||||
|
||||
# 检查缓存
|
||||
enable_cache = self.api.get_config("storage.enable_cache", True)
|
||||
if enable_cache:
|
||||
cache_key = self._get_cache_key(final_prompt, model, size)
|
||||
if cache_key in self._request_cache:
|
||||
cached_result = self._request_cache[cache_key]
|
||||
logger.info(f"{self.log_prefix} 使用缓存的图片结果")
|
||||
await self.send_text("我之前拍过类似的照片,用之前的结果~")
|
||||
|
||||
# 直接发送缓存的结果
|
||||
send_success = await self._send_image(cached_result)
|
||||
if send_success:
|
||||
await self.send_text("这是我的照片,好看吗?")
|
||||
return True, "照片已发送(缓存)"
|
||||
else:
|
||||
# 缓存失败,清除这个缓存项并继续正常流程
|
||||
del self._request_cache[cache_key]
|
||||
|
||||
await self.send_text(f"正在为你拍照,请稍候...")
|
||||
|
||||
try:
|
||||
seed = random.randint(1, 1000000)
|
||||
success, result = await asyncio.to_thread(
|
||||
self._make_http_image_request,
|
||||
prompt=final_prompt,
|
||||
model=model,
|
||||
size=size,
|
||||
seed=seed,
|
||||
guidance_scale=guidance_scale,
|
||||
watermark=watermark,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} (HTTP) 异步请求执行失败: {e!r}", exc_info=True)
|
||||
traceback.print_exc()
|
||||
success = False
|
||||
result = f"照片生成服务遇到意外问题: {str(e)[:100]}"
|
||||
|
||||
if success:
|
||||
image_url = result
|
||||
logger.info(f"{self.log_prefix} 图片URL获取成功: {image_url[:70]}... 下载并编码.")
|
||||
|
||||
try:
|
||||
encode_success, encode_result = await asyncio.to_thread(self._download_and_encode_base64, image_url)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} (B64) 异步下载/编码失败: {e!r}", exc_info=True)
|
||||
traceback.print_exc()
|
||||
encode_success = False
|
||||
encode_result = f"图片下载或编码时发生内部错误: {str(e)[:100]}"
|
||||
|
||||
if encode_success:
|
||||
base64_image_string = encode_result
|
||||
# 更新缓存
|
||||
if enable_cache:
|
||||
self._update_cache(final_prompt, model, size, base64_image_string)
|
||||
|
||||
# 发送图片
|
||||
send_success = await self._send_image(base64_image_string)
|
||||
if send_success:
|
||||
# 存储到文件
|
||||
await self._store_picture_info(final_prompt, image_url)
|
||||
logger.info(f"{self.log_prefix} 成功生成并存储照片: {image_url}")
|
||||
await self.send_text("当当当当~这是我刚拍的照片,好看吗?")
|
||||
return True, f"成功生成照片: {image_url}"
|
||||
else:
|
||||
await self.send_text("照片生成了,但发送失败了,可能是格式问题...")
|
||||
return False, "照片发送失败"
|
||||
else:
|
||||
await self.send_text(f"照片下载失败: {encode_result}")
|
||||
return False, encode_result
|
||||
else:
|
||||
await self.send_text(f"哎呀,拍照失败了: {result}")
|
||||
return False, result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 执行拍照动作失败: {e}", exc_info=True)
|
||||
traceback.print_exc()
|
||||
await self.send_text("呜呜,拍照的时候出了一点小问题...")
|
||||
return False, str(e)
|
||||
|
||||
async def _store_picture_info(self, prompt: str, image_url: str):
|
||||
"""将照片信息存入日志文件"""
|
||||
log_file = self.api.get_config("storage.log_file", "picture_log.json")
|
||||
log_path = os.path.join(DATA_DIR, log_file)
|
||||
max_photos = self.api.get_config("storage.max_photos", 50)
|
||||
|
||||
async with file_lock:
|
||||
try:
|
||||
if os.path.exists(log_path):
|
||||
with open(log_path, 'r', encoding='utf-8') as f:
|
||||
log_data = json.load(f)
|
||||
else:
|
||||
log_data = []
|
||||
except (json.JSONDecodeError, FileNotFoundError):
|
||||
log_data = []
|
||||
|
||||
# 添加新照片
|
||||
log_data.append({
|
||||
"prompt": prompt,
|
||||
"image_url": image_url,
|
||||
"timestamp": datetime.datetime.now().isoformat()
|
||||
})
|
||||
|
||||
# 如果超过最大数量,删除最旧的
|
||||
if len(log_data) > max_photos:
|
||||
log_data = sorted(log_data, key=lambda x: x.get('timestamp', ''), reverse=True)[:max_photos]
|
||||
|
||||
try:
|
||||
with open(log_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(log_data, f, ensure_ascii=False, indent=4)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 写入照片日志文件失败: {e}", exc_info=True)
|
||||
|
||||
def _make_http_image_request(
|
||||
self, prompt: str, model: str, size: str, seed: int, guidance_scale: float, watermark: bool
|
||||
) -> Tuple[bool, str]:
|
||||
"""发送HTTP请求到火山引擎豆包API生成图片"""
|
||||
try:
|
||||
base_url = self.api.get_config("api.base_url")
|
||||
api_key = self.api.get_config("api.volcano_generate_api_key")
|
||||
|
||||
# 构建请求URL和头部
|
||||
endpoint = f"{base_url.rstrip('/')}/images/generations"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
}
|
||||
|
||||
# 构建请求体
|
||||
request_body = {
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
"response_format": "url",
|
||||
"size": size,
|
||||
"seed": seed,
|
||||
"guidance_scale": guidance_scale,
|
||||
"watermark": watermark,
|
||||
"api-key": api_key,
|
||||
}
|
||||
|
||||
# 创建请求对象
|
||||
req = urllib.request.Request(
|
||||
endpoint,
|
||||
data=json.dumps(request_body).encode("utf-8"),
|
||||
headers=headers,
|
||||
method="POST",
|
||||
)
|
||||
|
||||
# 发送请求并获取响应
|
||||
with urllib.request.urlopen(req, timeout=60) as response:
|
||||
response_data = json.loads(response.read().decode("utf-8"))
|
||||
|
||||
# 解析响应
|
||||
image_url = None
|
||||
if (
|
||||
isinstance(response_data.get("data"), list)
|
||||
and response_data["data"]
|
||||
and isinstance(response_data["data"][0], dict)
|
||||
):
|
||||
image_url = response_data["data"][0].get("url")
|
||||
elif response_data.get("url"):
|
||||
image_url = response_data.get("url")
|
||||
|
||||
if image_url:
|
||||
return True, image_url
|
||||
else:
|
||||
error_msg = response_data.get("error", {}).get("message", "未知错误")
|
||||
logger.error(f"API返回错误: {error_msg}")
|
||||
return False, f"API错误: {error_msg}"
|
||||
|
||||
except urllib.error.HTTPError as e:
|
||||
error_body = e.read().decode("utf-8")
|
||||
logger.error(f"HTTP错误 {e.code}: {error_body}")
|
||||
return False, f"HTTP错误 {e.code}: {error_body[:100]}..."
|
||||
except Exception as e:
|
||||
logger.error(f"请求异常: {e}", exc_info=True)
|
||||
return False, f"请求异常: {str(e)}"
|
||||
|
||||
def _download_and_encode_base64(self, image_url: str) -> Tuple[bool, str]:
|
||||
"""下载图片并转换为Base64编码"""
|
||||
try:
|
||||
with urllib.request.urlopen(image_url) as response:
|
||||
image_data = response.read()
|
||||
|
||||
base64_encoded = base64.b64encode(image_data).decode('utf-8')
|
||||
return True, base64_encoded
|
||||
except Exception as e:
|
||||
logger.error(f"图片下载编码失败: {e}", exc_info=True)
|
||||
return False, str(e)
|
||||
|
||||
async def _send_image(self, base64_image: str) -> bool:
|
||||
"""发送图片"""
|
||||
try:
|
||||
# 使用聊天流信息确定发送目标
|
||||
chat_stream = self.api.get_service("chat_stream")
|
||||
if not chat_stream:
|
||||
logger.error(f"{self.log_prefix} 没有可用的聊天流发送图片")
|
||||
return False
|
||||
|
||||
if chat_stream.group_info:
|
||||
# 群聊
|
||||
return await self.api.send_message_to_target(
|
||||
message_type="image",
|
||||
content=base64_image,
|
||||
platform=chat_stream.platform,
|
||||
target_id=str(chat_stream.group_info.group_id),
|
||||
is_group=True,
|
||||
display_message="发送生成的照片",
|
||||
)
|
||||
else:
|
||||
# 私聊
|
||||
return await self.api.send_message_to_target(
|
||||
message_type="image",
|
||||
content=base64_image,
|
||||
platform=chat_stream.platform,
|
||||
target_id=str(chat_stream.user_info.user_id),
|
||||
is_group=False,
|
||||
display_message="发送生成的照片",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 发送图片时出错: {e}")
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def _get_cache_key(cls, description: str, model: str, size: str) -> str:
|
||||
"""生成缓存键"""
|
||||
return f"{description}|{model}|{size}"
|
||||
|
||||
def _update_cache(self, description: str, model: str, size: str, base64_image: str):
|
||||
"""更新缓存"""
|
||||
max_cache_size = self.api.get_config("storage.max_cache_size", 10)
|
||||
cache_key = self._get_cache_key(description, model, size)
|
||||
|
||||
# 添加到缓存
|
||||
self._request_cache[cache_key] = base64_image
|
||||
|
||||
# 如果缓存超过最大大小,删除最旧的项
|
||||
if len(self._request_cache) > max_cache_size:
|
||||
oldest_key = next(iter(self._request_cache))
|
||||
del self._request_cache[oldest_key]
|
||||
|
||||
|
||||
class ShowRecentPicturesCommand(BaseCommand):
|
||||
"""展示最近生成的照片"""
|
||||
|
||||
command_name = "show_recent_pictures"
|
||||
command_description = "展示最近生成的5张照片"
|
||||
command_pattern = r"^/show_pics$"
|
||||
command_help = "用法: /show_pics"
|
||||
command_examples = ["/show_pics"]
|
||||
intercept_message = True
|
||||
|
||||
async def execute(self) -> Tuple[bool, Optional[str]]:
|
||||
logger.info(f"{self.log_prefix} 执行展示最近照片命令")
|
||||
log_file = self.api.get_config("storage.log_file", "picture_log.json")
|
||||
log_path = os.path.join(DATA_DIR, log_file)
|
||||
|
||||
async with file_lock:
|
||||
try:
|
||||
if not os.path.exists(log_path):
|
||||
await self.send_text("最近还没有拍过照片哦,快让我自拍一张吧!")
|
||||
return True, "没有照片日志文件"
|
||||
|
||||
with open(log_path, 'r', encoding='utf-8') as f:
|
||||
log_data = json.load(f)
|
||||
|
||||
if not log_data:
|
||||
await self.send_text("最近还没有拍过照片哦,快让我自拍一张吧!")
|
||||
return True, "没有照片"
|
||||
|
||||
# 获取最新的5张照片
|
||||
recent_pics = sorted(log_data, key=lambda x: x['timestamp'], reverse=True)[:5]
|
||||
|
||||
# 先发送文本消息
|
||||
await self.send_text("这是我最近拍的几张照片~")
|
||||
|
||||
# 逐个发送图片
|
||||
for pic in recent_pics:
|
||||
# 尝试获取图片URL
|
||||
image_url = pic.get('image_url')
|
||||
if image_url:
|
||||
try:
|
||||
# 下载图片并转换为Base64
|
||||
with urllib.request.urlopen(image_url) as response:
|
||||
image_data = response.read()
|
||||
base64_encoded = base64.b64encode(image_data).decode('utf-8')
|
||||
|
||||
# 发送图片
|
||||
await self.send_type(
|
||||
message_type="image",
|
||||
content=base64_encoded,
|
||||
display_message="发送最近的照片"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 下载或发送照片失败: {e}", exc_info=True)
|
||||
|
||||
return True, "成功展示最近的照片"
|
||||
|
||||
except json.JSONDecodeError:
|
||||
await self.send_text("照片记录文件好像损坏了...")
|
||||
return False, "JSON解码错误"
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 展示照片失败: {e}", exc_info=True)
|
||||
await self.send_text("哎呀,查找照片的时候出错了。")
|
||||
return False, str(e)
|
||||
|
||||
|
||||
@register_plugin
|
||||
class TakePicturePlugin(BasePlugin):
|
||||
"""拍照插件"""
|
||||
plugin_name = "take_picture_plugin"
|
||||
plugin_description = "提供生成自拍照和展示最近照片的功能"
|
||||
plugin_version = "1.0.0"
|
||||
plugin_author = "SengokuCola"
|
||||
enable_plugin = True
|
||||
config_file_name = "config.toml"
|
||||
|
||||
# 配置节描述
|
||||
config_section_descriptions = {
|
||||
"plugin": "插件基本信息配置",
|
||||
"api": "API相关配置,包含火山引擎API的访问信息",
|
||||
"components": "组件启用控制",
|
||||
"picture": "拍照功能核心配置",
|
||||
"storage": "照片存储相关配置",
|
||||
}
|
||||
|
||||
# 配置Schema定义
|
||||
config_schema = {
|
||||
"plugin": {
|
||||
"name": ConfigField(type=str, default="take_picture_plugin", description="插件名称", required=True),
|
||||
"version": ConfigField(type=str, default="1.3.0", description="插件版本号"),
|
||||
"enabled": ConfigField(type=bool, default=False, description="是否启用插件"),
|
||||
"description": ConfigField(type=str, default="提供生成自拍照和展示最近照片的功能", description="插件描述", required=True),
|
||||
},
|
||||
"api": {
|
||||
"base_url": ConfigField(
|
||||
type=str,
|
||||
default="https://ark.cn-beijing.volces.com/api/v3",
|
||||
description="API基础URL",
|
||||
example="https://api.example.com/v1",
|
||||
),
|
||||
"volcano_generate_api_key": ConfigField(
|
||||
type=str, default="YOUR_DOUBAO_API_KEY_HERE", description="火山引擎豆包API密钥", required=True
|
||||
),
|
||||
},
|
||||
"components": {
|
||||
"enable_take_picture_action": ConfigField(type=bool, default=True, description="是否启用拍照Action"),
|
||||
"enable_show_pics_command": ConfigField(type=bool, default=True, description="是否启用展示照片Command"),
|
||||
},
|
||||
"picture": {
|
||||
"default_model": ConfigField(
|
||||
type=str,
|
||||
default="doubao-seedream-3-0-t2i-250415",
|
||||
description="默认使用的文生图模型",
|
||||
choices=["doubao-seedream-3-0-t2i-250415", "doubao-seedream-2-0-t2i"],
|
||||
),
|
||||
"default_style": ConfigField(type=str, default="动漫", description="默认图片风格"),
|
||||
"default_size": ConfigField(
|
||||
type=str,
|
||||
default="1024x1024",
|
||||
description="默认图片尺寸",
|
||||
example="1024x1024",
|
||||
choices=["1024x1024", "1024x1280", "1280x1024", "1024x1536", "1536x1024"],
|
||||
),
|
||||
"default_watermark": ConfigField(type=bool, default=True, description="是否默认添加水印"),
|
||||
"default_guidance_scale": ConfigField(
|
||||
type=float, default=2.5, description="模型指导强度,影响图片与提示的关联性", example="2.0"
|
||||
),
|
||||
"default_seed": ConfigField(type=int, default=42, description="随机种子,用于复现图片"),
|
||||
"prompt_templates": ConfigField(
|
||||
type=list,
|
||||
default=TakePictureAction.DEFAULT_PROMPT_TEMPLATES,
|
||||
description="用于生成自拍照的prompt模板"
|
||||
),
|
||||
},
|
||||
"storage": {
|
||||
"max_photos": ConfigField(type=int, default=50, description="最大保存的照片数量"),
|
||||
"log_file": ConfigField(type=str, default="picture_log.json", description="照片日志文件名"),
|
||||
"enable_cache": ConfigField(type=bool, default=True, description="是否启用请求缓存"),
|
||||
"max_cache_size": ConfigField(type=int, default=10, description="最大缓存数量"),
|
||||
}
|
||||
}
|
||||
|
||||
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
||||
"""返回插件包含的组件列表"""
|
||||
components = []
|
||||
if self.get_config("components.enable_take_picture_action", True):
|
||||
components.append((TakePictureAction.get_action_info(), TakePictureAction))
|
||||
if self.get_config("components.enable_show_pics_command", True):
|
||||
components.append((ShowRecentPicturesCommand.get_command_info(), ShowRecentPicturesCommand))
|
||||
return components
|
||||
@@ -57,7 +57,7 @@ class TTSAction(BaseAction):
|
||||
|
||||
try:
|
||||
# 发送TTS消息
|
||||
await self.send_type(type="tts_text", text=processed_text)
|
||||
await self.send_custom(message_type="tts_text", content=processed_text)
|
||||
|
||||
logger.info(f"{self.log_prefix} TTS动作执行成功,文本长度: {len(processed_text)}")
|
||||
return True, "TTS动作执行成功"
|
||||
|
||||
@@ -62,7 +62,7 @@ class VTBAction(BaseAction):
|
||||
|
||||
try:
|
||||
# 发送VTB动作消息 - 使用新版本的send_type方法
|
||||
await self.send_type(type="vtb_text", text=processed_text)
|
||||
await self.send_custom(message_type="vtb_text", content=processed_text)
|
||||
|
||||
logger.info(f"{self.log_prefix} VTB动作执行成功,文本内容: {processed_text}")
|
||||
return True, "VTB动作执行成功"
|
||||
|
||||
Reference in New Issue
Block a user