Feat:添加对Action插件的支持,现在可以编写插件
This commit is contained in:
@@ -1,18 +1,18 @@
|
||||
from typing import Dict, List, Optional, Callable, Coroutine, Type, Any, Union
|
||||
import os
|
||||
import importlib
|
||||
from src.chat.focus_chat.planners.actions.base_action import BaseAction, _ACTION_REGISTRY, _DEFAULT_ACTIONS
|
||||
from typing import Dict, List, Optional, Callable, Coroutine, Type, Any
|
||||
from src.chat.focus_chat.planners.actions.base_action import BaseAction, _ACTION_REGISTRY
|
||||
from src.chat.heart_flow.observation.observation import Observation
|
||||
from src.chat.focus_chat.expressors.default_expressor import DefaultExpressor
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.chat.focus_chat.heartFC_Cycleinfo import CycleDetail
|
||||
from src.common.logger_manager import get_logger
|
||||
import importlib
|
||||
import pkgutil
|
||||
import os
|
||||
|
||||
# 导入动作类,确保装饰器被执行
|
||||
from src.chat.focus_chat.planners.actions.reply_action import ReplyAction
|
||||
from src.chat.focus_chat.planners.actions.no_reply_action import NoReplyAction
|
||||
import src.chat.focus_chat.planners.actions # noqa
|
||||
|
||||
logger = get_logger("action_factory")
|
||||
logger = get_logger("action_manager")
|
||||
|
||||
# 定义动作信息类型
|
||||
ActionInfo = Dict[str, Any]
|
||||
@@ -31,20 +31,18 @@ class ActionManager:
|
||||
self._using_actions: Dict[str, ActionInfo] = {}
|
||||
# 临时备份原始使用中的动作
|
||||
self._original_actions_backup: Optional[Dict[str, ActionInfo]] = None
|
||||
|
||||
|
||||
# 默认动作集,仅作为快照,用于恢复默认
|
||||
self._default_actions: Dict[str, ActionInfo] = {}
|
||||
|
||||
|
||||
# 加载所有已注册动作
|
||||
self._load_registered_actions()
|
||||
|
||||
# 加载插件动作
|
||||
self._load_plugin_actions()
|
||||
|
||||
# 初始化时将默认动作加载到使用中的动作
|
||||
self._using_actions = self._default_actions.copy()
|
||||
|
||||
# logger.info(f"当前可用动作: {list(self._using_actions.keys())}")
|
||||
# for action_name, action_info in self._using_actions.items():
|
||||
# logger.info(f"动作名称: {action_name}, 动作信息: {action_info}")
|
||||
|
||||
|
||||
def _load_registered_actions(self) -> None:
|
||||
"""
|
||||
@@ -54,37 +52,78 @@ class ActionManager:
|
||||
# 从_ACTION_REGISTRY获取所有已注册动作
|
||||
for action_name, action_class in _ACTION_REGISTRY.items():
|
||||
# 获取动作相关信息
|
||||
action_description:str = getattr(action_class, "action_description", "")
|
||||
action_parameters:dict[str:str] = getattr(action_class, "action_parameters", {})
|
||||
action_require:list[str] = getattr(action_class, "action_require", [])
|
||||
is_default:bool = getattr(action_class, "default", False)
|
||||
|
||||
# 不读取插件动作和基类
|
||||
if action_name == "base_action" or action_name == "plugin_action":
|
||||
continue
|
||||
|
||||
action_description: str = getattr(action_class, "action_description", "")
|
||||
action_parameters: dict[str:str] = getattr(action_class, "action_parameters", {})
|
||||
action_require: list[str] = getattr(action_class, "action_require", [])
|
||||
is_default: bool = getattr(action_class, "default", False)
|
||||
|
||||
if action_name and action_description:
|
||||
# 创建动作信息字典
|
||||
action_info = {
|
||||
"description": action_description,
|
||||
"parameters": action_parameters,
|
||||
"require": action_require
|
||||
"require": action_require,
|
||||
}
|
||||
|
||||
# 注册2
|
||||
print("注册2")
|
||||
print(action_info)
|
||||
|
||||
|
||||
# 添加到所有已注册的动作
|
||||
self._registered_actions[action_name] = action_info
|
||||
|
||||
|
||||
# 添加到默认动作(如果是默认动作)
|
||||
if is_default:
|
||||
self._default_actions[action_name] = action_info
|
||||
|
||||
|
||||
logger.info(f"所有注册动作: {list(self._registered_actions.keys())}")
|
||||
logger.info(f"默认动作: {list(self._default_actions.keys())}")
|
||||
# for action_name, action_info in self._default_actions.items():
|
||||
# logger.info(f"动作名称: {action_name}, 动作信息: {action_info}")
|
||||
|
||||
for action_name, action_info in self._default_actions.items():
|
||||
logger.info(f"动作名称: {action_name}, 动作信息: {action_info}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"加载已注册动作失败: {e}")
|
||||
|
||||
def _load_plugin_actions(self) -> None:
|
||||
"""
|
||||
加载所有插件目录中的动作
|
||||
"""
|
||||
try:
|
||||
# 检查插件目录是否存在
|
||||
plugin_path = "src.plugins"
|
||||
plugin_dir = plugin_path.replace('.', os.path.sep)
|
||||
if not os.path.exists(plugin_dir):
|
||||
logger.info(f"插件目录 {plugin_dir} 不存在,跳过插件动作加载")
|
||||
return
|
||||
|
||||
# 导入插件包
|
||||
try:
|
||||
plugins_package = importlib.import_module(plugin_path)
|
||||
except ImportError as e:
|
||||
logger.error(f"导入插件包失败: {e}")
|
||||
return
|
||||
|
||||
# 遍历插件包中的所有子包
|
||||
for _, plugin_name, is_pkg in pkgutil.iter_modules(plugins_package.__path__, plugins_package.__name__ + '.'):
|
||||
if not is_pkg:
|
||||
continue
|
||||
|
||||
# 检查插件是否有actions子包
|
||||
plugin_actions_path = f"{plugin_name}.actions"
|
||||
try:
|
||||
# 尝试导入插件的actions包
|
||||
importlib.import_module(plugin_actions_path)
|
||||
logger.info(f"成功加载插件动作模块: {plugin_actions_path}")
|
||||
except ImportError as e:
|
||||
logger.debug(f"插件 {plugin_name} 没有actions子包或导入失败: {e}")
|
||||
continue
|
||||
|
||||
# 再次从_ACTION_REGISTRY获取所有动作(包括刚刚从插件加载的)
|
||||
self._load_registered_actions()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"加载插件动作失败: {e}")
|
||||
|
||||
def create_action(
|
||||
self,
|
||||
@@ -99,8 +138,8 @@ class ActionManager:
|
||||
current_cycle: CycleDetail,
|
||||
log_prefix: str,
|
||||
on_consecutive_no_reply_callback: Callable[[], Coroutine[None, None, None]],
|
||||
total_no_reply_count: int = 0,
|
||||
total_waiting_time: float = 0.0,
|
||||
# total_no_reply_count: int = 0,
|
||||
# total_waiting_time: float = 0.0,
|
||||
shutting_down: bool = False,
|
||||
) -> Optional[BaseAction]:
|
||||
"""
|
||||
@@ -129,14 +168,14 @@ class ActionManager:
|
||||
if action_name not in self._using_actions:
|
||||
logger.warning(f"当前不可用的动作类型: {action_name}")
|
||||
return None
|
||||
|
||||
|
||||
handler_class = _ACTION_REGISTRY.get(action_name)
|
||||
if not handler_class:
|
||||
logger.warning(f"未注册的动作类型: {action_name}")
|
||||
return None
|
||||
|
||||
try:
|
||||
# 创建动作实例并传递所有必要参数
|
||||
# 创建动作实例
|
||||
instance = handler_class(
|
||||
action_name=action_name,
|
||||
action_data=action_data,
|
||||
@@ -144,16 +183,16 @@ class ActionManager:
|
||||
cycle_timers=cycle_timers,
|
||||
thinking_id=thinking_id,
|
||||
observations=observations,
|
||||
on_consecutive_no_reply_callback=on_consecutive_no_reply_callback,
|
||||
current_cycle=current_cycle,
|
||||
log_prefix=log_prefix,
|
||||
total_no_reply_count=total_no_reply_count,
|
||||
total_waiting_time=total_waiting_time,
|
||||
shutting_down=shutting_down,
|
||||
expressor=expressor,
|
||||
chat_stream=chat_stream,
|
||||
current_cycle=current_cycle,
|
||||
log_prefix=log_prefix,
|
||||
on_consecutive_no_reply_callback=on_consecutive_no_reply_callback,
|
||||
# total_no_reply_count=total_no_reply_count,
|
||||
# total_waiting_time=total_waiting_time,
|
||||
shutting_down=shutting_down,
|
||||
)
|
||||
|
||||
|
||||
return instance
|
||||
|
||||
except Exception as e:
|
||||
@@ -167,7 +206,7 @@ class ActionManager:
|
||||
def get_default_actions(self) -> Dict[str, ActionInfo]:
|
||||
"""获取默认动作集"""
|
||||
return self._default_actions.copy()
|
||||
|
||||
|
||||
def get_using_actions(self) -> Dict[str, ActionInfo]:
|
||||
"""获取当前正在使用的动作集"""
|
||||
return self._using_actions.copy()
|
||||
@@ -175,21 +214,21 @@ class ActionManager:
|
||||
def add_action_to_using(self, action_name: str) -> bool:
|
||||
"""
|
||||
添加已注册的动作到当前使用的动作集
|
||||
|
||||
|
||||
Args:
|
||||
action_name: 动作名称
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 添加是否成功
|
||||
"""
|
||||
if action_name not in self._registered_actions:
|
||||
logger.warning(f"添加失败: 动作 {action_name} 未注册")
|
||||
return False
|
||||
|
||||
|
||||
if action_name in self._using_actions:
|
||||
logger.info(f"动作 {action_name} 已经在使用中")
|
||||
return True
|
||||
|
||||
|
||||
self._using_actions[action_name] = self._registered_actions[action_name]
|
||||
logger.info(f"添加动作 {action_name} 到使用集")
|
||||
return True
|
||||
@@ -197,17 +236,17 @@ class ActionManager:
|
||||
def remove_action_from_using(self, action_name: str) -> bool:
|
||||
"""
|
||||
从当前使用的动作集中移除指定动作
|
||||
|
||||
|
||||
Args:
|
||||
action_name: 动作名称
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 移除是否成功
|
||||
"""
|
||||
if action_name not in self._using_actions:
|
||||
logger.warning(f"移除失败: 动作 {action_name} 不在当前使用的动作集中")
|
||||
return False
|
||||
|
||||
|
||||
del self._using_actions[action_name]
|
||||
logger.info(f"已从使用集中移除动作 {action_name}")
|
||||
return True
|
||||
@@ -215,30 +254,26 @@ class ActionManager:
|
||||
def add_action(self, action_name: str, description: str, parameters: Dict = None, require: List = None) -> bool:
|
||||
"""
|
||||
添加新的动作到注册集
|
||||
|
||||
|
||||
Args:
|
||||
action_name: 动作名称
|
||||
description: 动作描述
|
||||
parameters: 动作参数定义,默认为空字典
|
||||
require: 动作依赖项,默认为空列表
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 添加是否成功
|
||||
"""
|
||||
if action_name in self._registered_actions:
|
||||
return False
|
||||
|
||||
|
||||
if parameters is None:
|
||||
parameters = {}
|
||||
if require is None:
|
||||
require = []
|
||||
|
||||
action_info = {
|
||||
"description": description,
|
||||
"parameters": parameters,
|
||||
"require": require
|
||||
}
|
||||
|
||||
|
||||
action_info = {"description": description, "parameters": parameters, "require": require}
|
||||
|
||||
self._registered_actions[action_name] = action_info
|
||||
return True
|
||||
|
||||
@@ -264,7 +299,7 @@ class ActionManager:
|
||||
if self._original_actions_backup is not None:
|
||||
self._using_actions = self._original_actions_backup.copy()
|
||||
self._original_actions_backup = None
|
||||
|
||||
|
||||
def restore_default_actions(self) -> None:
|
||||
"""恢复默认动作集到使用集"""
|
||||
self._using_actions = self._default_actions.copy()
|
||||
@@ -273,15 +308,12 @@ class ActionManager:
|
||||
def get_action(self, action_name: str) -> Optional[Type[BaseAction]]:
|
||||
"""
|
||||
获取指定动作的处理器类
|
||||
|
||||
|
||||
Args:
|
||||
action_name: 动作名称
|
||||
|
||||
|
||||
Returns:
|
||||
Optional[Type[BaseAction]]: 动作处理器类,如果不存在则返回None
|
||||
"""
|
||||
return _ACTION_REGISTRY.get(action_name)
|
||||
|
||||
|
||||
# 创建全局实例
|
||||
ActionFactory = ActionManager()
|
||||
5
src/chat/focus_chat/planners/actions/__init__.py
Normal file
5
src/chat/focus_chat/planners/actions/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# 导入所有动作模块以确保装饰器被执行
|
||||
from . import reply_action # noqa
|
||||
from . import no_reply_action # noqa
|
||||
|
||||
# 在此处添加更多动作模块导入
|
||||
@@ -12,7 +12,7 @@ _DEFAULT_ACTIONS: Dict[str, str] = {}
|
||||
def register_action(cls):
|
||||
"""
|
||||
动作注册装饰器
|
||||
|
||||
|
||||
用法:
|
||||
@register_action
|
||||
class MyAction(BaseAction):
|
||||
@@ -24,22 +24,22 @@ def register_action(cls):
|
||||
if not hasattr(cls, "action_name") or not hasattr(cls, "action_description"):
|
||||
logger.error(f"动作类 {cls.__name__} 缺少必要的属性: action_name 或 action_description")
|
||||
return cls
|
||||
|
||||
action_name = getattr(cls, "action_name")
|
||||
action_description = getattr(cls, "action_description")
|
||||
|
||||
action_name = cls.action_name
|
||||
action_description = cls.action_description
|
||||
is_default = getattr(cls, "default", False)
|
||||
|
||||
|
||||
if not action_name or not action_description:
|
||||
logger.error(f"动作类 {cls.__name__} 的 action_name 或 action_description 为空")
|
||||
return cls
|
||||
|
||||
|
||||
# 将动作类注册到全局注册表
|
||||
_ACTION_REGISTRY[action_name] = cls
|
||||
|
||||
|
||||
# 如果是默认动作,添加到默认动作集
|
||||
if is_default:
|
||||
_DEFAULT_ACTIONS[action_name] = action_description
|
||||
|
||||
|
||||
logger.info(f"已注册动作: {action_name} -> {cls.__name__},默认: {is_default}")
|
||||
return cls
|
||||
|
||||
@@ -60,15 +60,14 @@ class BaseAction(ABC):
|
||||
cycle_timers: 计时器字典
|
||||
thinking_id: 思考ID
|
||||
"""
|
||||
#每个动作必须实现
|
||||
self.action_name:str = "base_action"
|
||||
self.action_description:str = "基础动作"
|
||||
self.action_parameters:dict = {}
|
||||
self.action_require:list[str] = []
|
||||
|
||||
self.default:bool = False
|
||||
|
||||
|
||||
# 每个动作必须实现
|
||||
self.action_name: str = "base_action"
|
||||
self.action_description: str = "基础动作"
|
||||
self.action_parameters: dict = {}
|
||||
self.action_require: list[str] = []
|
||||
|
||||
self.default: bool = False
|
||||
|
||||
self.action_data = action_data
|
||||
self.reasoning = reasoning
|
||||
self.cycle_timers = cycle_timers
|
||||
|
||||
@@ -29,7 +29,7 @@ class NoReplyAction(BaseAction):
|
||||
action_require = [
|
||||
"话题无关/无聊/不感兴趣/不懂",
|
||||
"最后一条消息是你自己发的且无人回应你",
|
||||
"你发送了太多消息,且无人回复"
|
||||
"你发送了太多消息,且无人回复",
|
||||
]
|
||||
default = True
|
||||
|
||||
@@ -43,10 +43,10 @@ class NoReplyAction(BaseAction):
|
||||
on_consecutive_no_reply_callback: Callable[[], Coroutine[None, None, None]],
|
||||
current_cycle: CycleDetail,
|
||||
log_prefix: str,
|
||||
total_no_reply_count: int = 0,
|
||||
total_waiting_time: float = 0.0,
|
||||
# total_no_reply_count: int = 0,
|
||||
# total_waiting_time: float = 0.0,
|
||||
shutting_down: bool = False,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
"""初始化不回复动作处理器
|
||||
|
||||
@@ -69,8 +69,8 @@ class NoReplyAction(BaseAction):
|
||||
self.on_consecutive_no_reply_callback = on_consecutive_no_reply_callback
|
||||
self._current_cycle = current_cycle
|
||||
self.log_prefix = log_prefix
|
||||
self.total_no_reply_count = total_no_reply_count
|
||||
self.total_waiting_time = total_waiting_time
|
||||
# self.total_no_reply_count = total_no_reply_count
|
||||
# self.total_waiting_time = total_waiting_time
|
||||
self._shutting_down = shutting_down
|
||||
|
||||
async def handle_action(self) -> Tuple[bool, str]:
|
||||
@@ -96,34 +96,6 @@ class NoReplyAction(BaseAction):
|
||||
# 从计时器获取实际等待时间
|
||||
current_waiting = self.cycle_timers.get("等待新消息", 0.0)
|
||||
|
||||
if not self._shutting_down:
|
||||
self.total_no_reply_count += 1
|
||||
self.total_waiting_time += current_waiting # 累加等待时间
|
||||
logger.debug(
|
||||
f"{self.log_prefix} 连续不回复计数增加: {self.total_no_reply_count}/{CONSECUTIVE_NO_REPLY_THRESHOLD}, "
|
||||
f"本次等待: {current_waiting:.2f}秒, 累计等待: {self.total_waiting_time:.2f}秒"
|
||||
)
|
||||
|
||||
# 检查是否同时达到次数和时间阈值
|
||||
time_threshold = 0.66 * WAITING_TIME_THRESHOLD * CONSECUTIVE_NO_REPLY_THRESHOLD
|
||||
if (
|
||||
self.total_no_reply_count >= CONSECUTIVE_NO_REPLY_THRESHOLD
|
||||
and self.total_waiting_time >= time_threshold
|
||||
):
|
||||
logger.info(
|
||||
f"{self.log_prefix} 连续不回复达到阈值 ({self.total_no_reply_count}次) "
|
||||
f"且累计等待时间达到 {self.total_waiting_time:.2f}秒 (阈值 {time_threshold}秒),"
|
||||
f"调用回调请求状态转换"
|
||||
)
|
||||
# 调用回调。注意:这里不重置计数器和时间,依赖回调函数成功改变状态来隐式重置上下文。
|
||||
await self.on_consecutive_no_reply_callback()
|
||||
elif self.total_no_reply_count >= CONSECUTIVE_NO_REPLY_THRESHOLD:
|
||||
# 仅次数达到阈值,但时间未达到
|
||||
logger.debug(
|
||||
f"{self.log_prefix} 连续不回复次数达到阈值 ({self.total_no_reply_count}次) "
|
||||
f"但累计等待时间 {self.total_waiting_time:.2f}秒 未达到时间阈值 ({time_threshold}秒),暂不调用回调"
|
||||
)
|
||||
# else: 次数和时间都未达到阈值,不做处理
|
||||
|
||||
return True, "" # 不回复动作没有回复文本
|
||||
|
||||
|
||||
215
src/chat/focus_chat/planners/actions/plugin_action.py
Normal file
215
src/chat/focus_chat/planners/actions/plugin_action.py
Normal file
@@ -0,0 +1,215 @@
|
||||
import traceback
|
||||
from typing import Tuple, Dict, List, Any, Optional
|
||||
from src.chat.focus_chat.planners.actions.base_action import BaseAction, register_action
|
||||
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
|
||||
from src.chat.focus_chat.hfc_utils import create_empty_anchor_message
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.chat.person_info.person_info import person_info_manager
|
||||
from abc import abstractmethod
|
||||
|
||||
logger = get_logger("plugin_action")
|
||||
|
||||
class PluginAction(BaseAction):
|
||||
"""插件动作基类
|
||||
|
||||
封装了主程序内部依赖,提供简化的API接口给插件开发者
|
||||
"""
|
||||
|
||||
def __init__(self, action_data: dict, reasoning: str, cycle_timers: dict, thinking_id: str, **kwargs):
|
||||
"""初始化插件动作基类"""
|
||||
super().__init__(action_data, reasoning, cycle_timers, thinking_id)
|
||||
|
||||
# 存储内部服务和对象引用
|
||||
self._services = {}
|
||||
|
||||
# 从kwargs提取必要的内部服务
|
||||
if "observations" in kwargs:
|
||||
self._services["observations"] = kwargs["observations"]
|
||||
if "expressor" in kwargs:
|
||||
self._services["expressor"] = kwargs["expressor"]
|
||||
if "chat_stream" in kwargs:
|
||||
self._services["chat_stream"] = kwargs["chat_stream"]
|
||||
if "current_cycle" in kwargs:
|
||||
self._services["current_cycle"] = kwargs["current_cycle"]
|
||||
|
||||
self.log_prefix = kwargs.get("log_prefix", "")
|
||||
|
||||
async def get_user_id_by_person_name(self, person_name: str) -> Tuple[str, str]:
|
||||
"""根据用户名获取用户ID"""
|
||||
person_id = person_info_manager.get_person_id_by_person_name(person_name)
|
||||
user_id = await person_info_manager.get_value(person_id, "user_id")
|
||||
platform = await person_info_manager.get_value(person_id, "platform")
|
||||
return platform, user_id
|
||||
|
||||
# 提供简化的API方法
|
||||
async def send_message(self, text: str, target: Optional[str] = None) -> bool:
|
||||
"""发送消息的简化方法
|
||||
|
||||
Args:
|
||||
text: 要发送的消息文本
|
||||
target: 目标消息(可选)
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
try:
|
||||
expressor = self._services.get("expressor")
|
||||
chat_stream = self._services.get("chat_stream")
|
||||
|
||||
if not expressor or not chat_stream:
|
||||
logger.error(f"{self.log_prefix} 无法发送消息:缺少必要的内部服务")
|
||||
return False
|
||||
|
||||
# 构造简化的动作数据
|
||||
reply_data = {
|
||||
"text": text,
|
||||
"target": target or "",
|
||||
"emojis": []
|
||||
}
|
||||
|
||||
# 获取锚定消息(如果有)
|
||||
observations = self._services.get("observations", [])
|
||||
|
||||
chatting_observation: ChattingObservation = next(
|
||||
obs for obs in observations
|
||||
if isinstance(obs, ChattingObservation)
|
||||
)
|
||||
anchor_message = chatting_observation.search_message_by_text(reply_data["target"])
|
||||
|
||||
# 如果没有找到锚点消息,创建一个占位符
|
||||
if not anchor_message:
|
||||
logger.info(f"{self.log_prefix} 未找到锚点消息,创建占位符")
|
||||
anchor_message = await create_empty_anchor_message(
|
||||
chat_stream.platform, chat_stream.group_info, chat_stream
|
||||
)
|
||||
else:
|
||||
anchor_message.update_chat_stream(chat_stream)
|
||||
|
||||
response_set = [
|
||||
("text", text),
|
||||
]
|
||||
|
||||
# 调用内部方法发送消息
|
||||
success = await expressor.send_response_messages(
|
||||
anchor_message=anchor_message,
|
||||
response_set=response_set,
|
||||
)
|
||||
|
||||
return success
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 发送消息时出错: {e}")
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
async def send_message_by_expressor(self, text: str, target: Optional[str] = None) -> bool:
|
||||
"""发送消息的简化方法
|
||||
|
||||
Args:
|
||||
text: 要发送的消息文本
|
||||
target: 目标消息(可选)
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
try:
|
||||
expressor = self._services.get("expressor")
|
||||
chat_stream = self._services.get("chat_stream")
|
||||
|
||||
if not expressor or not chat_stream:
|
||||
logger.error(f"{self.log_prefix} 无法发送消息:缺少必要的内部服务")
|
||||
return False
|
||||
|
||||
# 构造简化的动作数据
|
||||
reply_data = {
|
||||
"text": text,
|
||||
"target": target or "",
|
||||
"emojis": []
|
||||
}
|
||||
|
||||
# 获取锚定消息(如果有)
|
||||
observations = self._services.get("observations", [])
|
||||
|
||||
chatting_observation: ChattingObservation = next(
|
||||
obs for obs in observations
|
||||
if isinstance(obs, ChattingObservation)
|
||||
)
|
||||
anchor_message = chatting_observation.search_message_by_text(reply_data["target"])
|
||||
|
||||
# 如果没有找到锚点消息,创建一个占位符
|
||||
if not anchor_message:
|
||||
logger.info(f"{self.log_prefix} 未找到锚点消息,创建占位符")
|
||||
anchor_message = await create_empty_anchor_message(
|
||||
chat_stream.platform, chat_stream.group_info, chat_stream
|
||||
)
|
||||
else:
|
||||
anchor_message.update_chat_stream(chat_stream)
|
||||
|
||||
# 调用内部方法发送消息
|
||||
success, _ = await expressor.deal_reply(
|
||||
cycle_timers=self.cycle_timers,
|
||||
action_data=reply_data,
|
||||
anchor_message=anchor_message,
|
||||
reasoning=self.reasoning,
|
||||
thinking_id=self.thinking_id
|
||||
)
|
||||
|
||||
return success
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 发送消息时出错: {e}")
|
||||
return False
|
||||
|
||||
def get_chat_type(self) -> str:
|
||||
"""获取当前聊天类型
|
||||
|
||||
Returns:
|
||||
str: 聊天类型 ("group" 或 "private")
|
||||
"""
|
||||
chat_stream = self._services.get("chat_stream")
|
||||
if chat_stream and hasattr(chat_stream, "group_info"):
|
||||
return "group" if chat_stream.group_info else "private"
|
||||
return "unknown"
|
||||
|
||||
def get_recent_messages(self, count: int = 5) -> List[Dict[str, Any]]:
|
||||
"""获取最近的消息
|
||||
|
||||
Args:
|
||||
count: 要获取的消息数量
|
||||
|
||||
Returns:
|
||||
List[Dict]: 消息列表,每个消息包含发送者、内容等信息
|
||||
"""
|
||||
messages = []
|
||||
observations = self._services.get("observations", [])
|
||||
|
||||
if observations and len(observations) > 0:
|
||||
obs = observations[0]
|
||||
if hasattr(obs, "get_talking_message"):
|
||||
raw_messages = obs.get_talking_message()
|
||||
# 转换为简化格式
|
||||
for msg in raw_messages[-count:]:
|
||||
simple_msg = {
|
||||
"sender": msg.get("sender", "未知"),
|
||||
"content": msg.get("content", ""),
|
||||
"timestamp": msg.get("timestamp", 0)
|
||||
}
|
||||
messages.append(simple_msg)
|
||||
|
||||
return messages
|
||||
|
||||
@abstractmethod
|
||||
async def process(self) -> Tuple[bool, str]:
|
||||
"""插件处理逻辑,子类必须实现此方法
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: (是否执行成功, 回复文本)
|
||||
"""
|
||||
pass
|
||||
|
||||
async def handle_action(self) -> Tuple[bool, str]:
|
||||
"""实现BaseAction的抽象方法,调用子类的process方法
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: (是否执行成功, 回复文本)
|
||||
"""
|
||||
return await self.process()
|
||||
@@ -1,10 +1,8 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.chat.focus_chat.planners.actions.base_action import BaseAction, register_action
|
||||
from typing import Tuple, List, Optional
|
||||
from typing import Tuple, List
|
||||
from src.chat.heart_flow.observation.observation import Observation
|
||||
from src.chat.focus_chat.expressors.default_expressor import DefaultExpressor
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
@@ -22,23 +20,22 @@ class ReplyAction(BaseAction):
|
||||
处理构建和发送消息回复的动作。
|
||||
"""
|
||||
|
||||
action_name:str = "reply"
|
||||
action_description:str = "表达想法,可以只包含文本、表情或两者都有"
|
||||
action_parameters:dict[str:str] = {
|
||||
action_name: str = "reply"
|
||||
action_description: str = "表达想法,可以只包含文本、表情或两者都有"
|
||||
action_parameters: dict[str:str] = {
|
||||
"text": "你想要表达的内容(可选)",
|
||||
"emojis": "描述当前使用表情包的场景(可选)",
|
||||
"target": "你想要回复的原始文本内容(非必须,仅文本,不包含发送者)(可选)",
|
||||
}
|
||||
action_require:list[str] = [
|
||||
action_require: list[str] = [
|
||||
"有实质性内容需要表达",
|
||||
"有人提到你,但你还没有回应他",
|
||||
"在合适的时候添加表情(不要总是添加)",
|
||||
"如果你要回复特定某人的某句话,或者你想回复较早的消息,请在target中指定那句话的原始文本",
|
||||
"除非有明确的回复目标,如果选择了target,不用特别提到某个人的人名",
|
||||
"如果你有明确的,要回复特定某人的某句话,或者你想回复较早的消息,请在target中指定那句话的原始文本",
|
||||
"一次只回复一个人,一次只回复一个话题,突出重点",
|
||||
"如果是自己发的消息想继续,需自然衔接",
|
||||
"避免重复或评价自己的发言,不要和自己聊天",
|
||||
"注意:回复尽量简短一些。可以参考贴吧,知乎和微博的回复风格,回复不要浮夸,不要用夸张修辞,平淡一些。"
|
||||
"注意:回复尽量简短一些。可以参考贴吧,知乎和微博的回复风格,回复不要浮夸,不要用夸张修辞,平淡一些。不要有额外的符号,尽量简单简短",
|
||||
]
|
||||
default = True
|
||||
|
||||
@@ -54,7 +51,7 @@ class ReplyAction(BaseAction):
|
||||
chat_stream: ChatStream,
|
||||
current_cycle: CycleDetail,
|
||||
log_prefix: str,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
"""初始化回复动作处理器
|
||||
|
||||
@@ -89,9 +86,9 @@ class ReplyAction(BaseAction):
|
||||
reasoning=self.reasoning,
|
||||
reply_data=self.action_data,
|
||||
cycle_timers=self.cycle_timers,
|
||||
thinking_id=self.thinking_id
|
||||
thinking_id=self.thinking_id,
|
||||
)
|
||||
|
||||
|
||||
async def _handle_reply(
|
||||
self, reasoning: str, reply_data: dict, cycle_timers: dict, thinking_id: str
|
||||
) -> tuple[bool, str]:
|
||||
@@ -105,13 +102,16 @@ class ReplyAction(BaseAction):
|
||||
"emojis": "微笑" # 表情关键词列表(可选)
|
||||
}
|
||||
"""
|
||||
# 重置连续不回复计数器
|
||||
self.total_no_reply_count = 0
|
||||
self.total_waiting_time = 0.0
|
||||
|
||||
# 从聊天观察获取锚定消息
|
||||
observations: ChattingObservation = self.observations[0]
|
||||
anchor_message = observations.serch_message_by_text(reply_data["target"])
|
||||
chatting_observation: ChattingObservation = next(
|
||||
obs for obs in self.observations
|
||||
if isinstance(obs, ChattingObservation)
|
||||
)
|
||||
if reply_data.get("target"):
|
||||
anchor_message = chatting_observation.search_message_by_text(reply_data["target"])
|
||||
else:
|
||||
anchor_message = None
|
||||
|
||||
# 如果没有找到锚点消息,创建一个占位符
|
||||
if not anchor_message:
|
||||
|
||||
@@ -4,7 +4,6 @@ from typing import List, Dict, Any, Optional
|
||||
from rich.traceback import install
|
||||
from src.chat.models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
from src.chat.focus_chat.heartflow_prompt_builder import prompt_builder
|
||||
from src.chat.focus_chat.info.info_base import InfoBase
|
||||
from src.chat.focus_chat.info.obs_info import ObsInfo
|
||||
from src.chat.focus_chat.info.cycle_info import CycleInfo
|
||||
@@ -13,16 +12,21 @@ from src.chat.focus_chat.info.structured_info import StructuredInfo
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.individuality.individuality import Individuality
|
||||
from src.chat.focus_chat.planners.action_factory import ActionManager
|
||||
from src.chat.focus_chat.planners.action_factory import ActionInfo
|
||||
from src.chat.focus_chat.planners.action_manager import ActionManager
|
||||
from src.chat.focus_chat.planners.action_manager import ActionInfo
|
||||
|
||||
logger = get_logger("planner")
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
|
||||
def init_prompt():
|
||||
Prompt(
|
||||
"""你的名字是{bot_name},{prompt_personality},{chat_context_description}。需要基于以下信息决定如何参与对话:
|
||||
"""{extra_info_block}
|
||||
|
||||
你的名字是{bot_name},{prompt_personality},{chat_context_description}。需要基于以下信息决定如何参与对话:
|
||||
{chat_content_block}
|
||||
|
||||
{mind_info_block}
|
||||
{cycle_info_block}
|
||||
|
||||
@@ -44,20 +48,20 @@ def init_prompt():
|
||||
}}
|
||||
|
||||
请输出你的决策 JSON:""",
|
||||
"planner_prompt",)
|
||||
|
||||
"planner_prompt",
|
||||
)
|
||||
|
||||
Prompt(
|
||||
"""
|
||||
action_name: {action_name}
|
||||
描述:{action_description}
|
||||
参数:
|
||||
{action_parameters}
|
||||
{action_parameters}
|
||||
动作要求:
|
||||
{action_require}
|
||||
""",
|
||||
{action_require}""",
|
||||
"action_prompt",
|
||||
)
|
||||
|
||||
|
||||
|
||||
class ActionPlanner:
|
||||
def __init__(self, log_prefix: str, action_manager: ActionManager):
|
||||
@@ -68,7 +72,7 @@ class ActionPlanner:
|
||||
max_tokens=1000,
|
||||
request_type="action_planning", # 用于动作规划
|
||||
)
|
||||
|
||||
|
||||
self.action_manager = action_manager
|
||||
|
||||
async def plan(self, all_plan_info: List[InfoBase], cycle_timers: dict) -> Dict[str, Any]:
|
||||
@@ -85,6 +89,7 @@ class ActionPlanner:
|
||||
|
||||
try:
|
||||
# 获取观察信息
|
||||
extra_info: list[str] = []
|
||||
for info in all_plan_info:
|
||||
if isinstance(info, ObsInfo):
|
||||
logger.debug(f"{self.log_prefix} 观察信息: {info}")
|
||||
@@ -104,9 +109,11 @@ class ActionPlanner:
|
||||
elif isinstance(info, StructuredInfo):
|
||||
logger.debug(f"{self.log_prefix} 结构化信息: {info}")
|
||||
structured_info = info.get_data()
|
||||
else:
|
||||
extra_info.append(info.get_processed_info())
|
||||
|
||||
current_available_actions = self.action_manager.get_using_actions()
|
||||
|
||||
|
||||
# --- 构建提示词 (调用修改后的 PromptBuilder 方法) ---
|
||||
prompt = await self.build_planner_prompt(
|
||||
is_group_chat=is_group_chat, # <-- Pass HFC state
|
||||
@@ -116,6 +123,7 @@ class ActionPlanner:
|
||||
# structured_info=structured_info, # <-- Pass SubMind info
|
||||
current_available_actions=current_available_actions, # <-- Pass determined actions
|
||||
cycle_info=cycle_info, # <-- Pass cycle info
|
||||
extra_info=extra_info,
|
||||
)
|
||||
|
||||
# --- 调用 LLM (普通文本生成) ---
|
||||
@@ -142,15 +150,13 @@ class ActionPlanner:
|
||||
extracted_action = parsed_json.get("action", "no_reply")
|
||||
extracted_reasoning = parsed_json.get("reasoning", "LLM未提供理由")
|
||||
|
||||
# 新的reply格式
|
||||
if extracted_action == "reply":
|
||||
action_data = {
|
||||
"text": parsed_json.get("text", []),
|
||||
"emojis": parsed_json.get("emojis", []),
|
||||
"target": parsed_json.get("target", ""),
|
||||
}
|
||||
else:
|
||||
action_data = {} # 其他动作可能不需要额外数据
|
||||
# 将所有其他属性添加到action_data
|
||||
action_data = {}
|
||||
for key, value in parsed_json.items():
|
||||
if key not in ["action", "reasoning"]:
|
||||
action_data[key] = value
|
||||
|
||||
# 对于reply动作不需要额外处理,因为相关字段已经在上面的循环中添加到action_data
|
||||
|
||||
if extracted_action not in current_available_actions:
|
||||
logger.warning(
|
||||
@@ -197,7 +203,6 @@ class ActionPlanner:
|
||||
# 返回结果字典
|
||||
return plan_result
|
||||
|
||||
|
||||
async def build_planner_prompt(
|
||||
self,
|
||||
is_group_chat: bool, # Now passed as argument
|
||||
@@ -206,6 +211,7 @@ class ActionPlanner:
|
||||
current_mind: Optional[str],
|
||||
current_available_actions: Dict[str, ActionInfo],
|
||||
cycle_info: Optional[str],
|
||||
extra_info: list[str],
|
||||
) -> str:
|
||||
"""构建 Planner LLM 的提示词 (获取模板并填充数据)"""
|
||||
try:
|
||||
@@ -218,7 +224,6 @@ class ActionPlanner:
|
||||
)
|
||||
chat_context_description = f"你正在和 {chat_target_name} 私聊"
|
||||
|
||||
|
||||
chat_content_block = ""
|
||||
if observed_messages_str:
|
||||
chat_content_block = f"聊天记录:\n{observed_messages_str}"
|
||||
@@ -234,7 +239,6 @@ class ActionPlanner:
|
||||
individuality = Individuality.get_instance()
|
||||
personality_block = individuality.get_prompt(x_person=2, level=2)
|
||||
|
||||
|
||||
action_options_block = ""
|
||||
for using_actions_name, using_actions_info in current_available_actions.items():
|
||||
# print(using_actions_name)
|
||||
@@ -242,29 +246,29 @@ class ActionPlanner:
|
||||
# print(using_actions_info["parameters"])
|
||||
# print(using_actions_info["require"])
|
||||
# print(using_actions_info["description"])
|
||||
|
||||
|
||||
using_action_prompt = await global_prompt_manager.get_prompt_async("action_prompt")
|
||||
|
||||
|
||||
param_text = ""
|
||||
for param_name, param_description in using_actions_info["parameters"].items():
|
||||
param_text += f"{param_name}: {param_description}\n"
|
||||
|
||||
param_text += f" {param_name}: {param_description}\n"
|
||||
|
||||
require_text = ""
|
||||
for require_item in using_actions_info["require"]:
|
||||
require_text += f"- {require_item}\n"
|
||||
|
||||
require_text += f" - {require_item}\n"
|
||||
|
||||
using_action_prompt = using_action_prompt.format(
|
||||
action_name=using_actions_name,
|
||||
action_description=using_actions_info["description"],
|
||||
action_parameters=param_text,
|
||||
action_require=require_text,
|
||||
)
|
||||
|
||||
|
||||
action_options_block += using_action_prompt
|
||||
|
||||
|
||||
extra_info_block = "\n".join(extra_info)
|
||||
extra_info_block = f"以下是一些额外的信息,现在请你阅读以下内容,进行决策\n{extra_info_block}\n以上是一些额外的信息,现在请你阅读以下内容,进行决策"
|
||||
|
||||
|
||||
planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt")
|
||||
prompt = planner_prompt_template.format(
|
||||
bot_name=global_config.BOT_NICKNAME,
|
||||
@@ -274,6 +278,7 @@ class ActionPlanner:
|
||||
mind_info_block=mind_info_block,
|
||||
cycle_info_block=cycle_info,
|
||||
action_options_text=action_options_block,
|
||||
extra_info_block=extra_info_block,
|
||||
)
|
||||
return prompt
|
||||
|
||||
|
||||
Reference in New Issue
Block a user