feat: 增加对动作的重载选项

This commit is contained in:
tcmofashi
2025-08-12 18:42:55 +08:00
parent 1e7f3a92a6
commit 76285ecb8b
2 changed files with 44 additions and 6 deletions

View File

@@ -1 +1 @@
ENABLE_S4U = False
ENABLE_S4U = True

View File

@@ -15,7 +15,8 @@ from src.mais4u.s4u_config import s4u_config
logger = get_logger("action")
HEAD_CODE = {
# 使用字典作为默认值但通过Prompt来注册以便外部重载
DEFAULT_HEAD_CODE = {
"看向上方": "(0,0.5,0)",
"看向下方": "(0,-0.5,0)",
"看向左边": "(-1,0,0)",
@@ -26,7 +27,7 @@ HEAD_CODE = {
"看向正前方": "(0,0,0)",
}
BODY_CODE = {
DEFAULT_BODY_CODE = {
"双手背后向前弯腰": "010_0070",
"歪头双手合十": "010_0100",
"标准文静站立": "010_0101",
@@ -42,7 +43,44 @@ BODY_CODE = {
}
def get_head_code() -> dict:
"""获取头部动作代码字典"""
head_code_str = global_prompt_manager.get_prompt("head_code_prompt")
if not head_code_str:
return DEFAULT_HEAD_CODE
try:
return json.loads(head_code_str)
except Exception as e:
logger.error(f"解析head_code_prompt失败使用默认值: {e}")
return DEFAULT_HEAD_CODE
def get_body_code() -> dict:
"""获取身体动作代码字典"""
body_code_str = global_prompt_manager.get_prompt("body_code_prompt")
if not body_code_str:
return DEFAULT_BODY_CODE
try:
return json.loads(body_code_str)
except Exception as e:
logger.error(f"解析body_code_prompt失败使用默认值: {e}")
return DEFAULT_BODY_CODE
def init_prompt():
# 注册头部动作代码
Prompt(
json.dumps(DEFAULT_HEAD_CODE, ensure_ascii=False, indent=2),
"head_code_prompt",
)
# 注册身体动作代码
Prompt(
json.dumps(DEFAULT_BODY_CODE, ensure_ascii=False, indent=2),
"body_code_prompt",
)
# 注册原有提示模板
Prompt(
"""
{chat_talking_prompt}
@@ -105,7 +143,7 @@ class ChatAction:
async def send_action_update(self):
"""发送动作更新到前端"""
body_code = BODY_CODE.get(self.body_action, "")
body_code = get_body_code().get(self.body_action, "")
await send_api.custom_to_stream(
message_type="body_action",
content=body_code,
@@ -147,7 +185,7 @@ class ChatAction:
try:
# 冷却池处理:过滤掉冷却中的动作
self._update_body_action_cooldown()
available_actions = [k for k in BODY_CODE.keys() if k not in self.body_action_cooldown]
available_actions = [k for k in get_body_code().keys() if k not in self.body_action_cooldown]
all_actions = "\n".join(available_actions)
prompt = await global_prompt_manager.format_prompt(
@@ -210,7 +248,7 @@ class ChatAction:
try:
# 冷却池处理:过滤掉冷却中的动作
self._update_body_action_cooldown()
available_actions = [k for k in BODY_CODE.keys() if k not in self.body_action_cooldown]
available_actions = [k for k in get_body_code().keys() if k not in self.body_action_cooldown]
all_actions = "\n".join(available_actions)
prompt = await global_prompt_manager.format_prompt(