ruff
This commit is contained in:
@@ -8,19 +8,22 @@ logger = get_logger("base_action")
|
||||
_ACTION_REGISTRY: Dict[str, Type["BaseAction"]] = {}
|
||||
_DEFAULT_ACTIONS: Dict[str, str] = {}
|
||||
|
||||
|
||||
# 动作激活类型枚举
|
||||
class ActionActivationType:
|
||||
ALWAYS = "always" # 默认参与到planner
|
||||
LLM_JUDGE = "llm_judge" # LLM判定是否启动该action到planner
|
||||
LLM_JUDGE = "llm_judge" # LLM判定是否启动该action到planner
|
||||
RANDOM = "random" # 随机启用action到planner
|
||||
KEYWORD = "keyword" # 关键词触发启用action到planner
|
||||
|
||||
|
||||
# 聊天模式枚举
|
||||
class ChatMode:
|
||||
FOCUS = "focus" # Focus聊天模式
|
||||
NORMAL = "normal" # Normal聊天模式
|
||||
ALL = "all" # 所有聊天模式
|
||||
|
||||
|
||||
def register_action(cls):
|
||||
"""
|
||||
动作注册装饰器
|
||||
@@ -81,13 +84,13 @@ class BaseAction(ABC):
|
||||
self.action_description: str = "基础动作"
|
||||
self.action_parameters: dict = {}
|
||||
self.action_require: list[str] = []
|
||||
|
||||
|
||||
# 动作激活类型设置
|
||||
# Focus模式下的激活类型,默认为always
|
||||
self.focus_activation_type: str = ActionActivationType.ALWAYS
|
||||
# Normal模式下的激活类型,默认为always
|
||||
# Normal模式下的激活类型,默认为always
|
||||
self.normal_activation_type: str = ActionActivationType.ALWAYS
|
||||
|
||||
|
||||
# 随机激活的概率(0.0-1.0),用于RANDOM激活类型
|
||||
self.random_activation_probability: float = 0.3
|
||||
# LLM判定的提示词,用于LLM_JUDGE激活类型
|
||||
|
||||
@@ -4,4 +4,4 @@ from . import no_reply_action # noqa
|
||||
from . import exit_focus_chat_action # noqa
|
||||
from . import emoji_action # noqa
|
||||
|
||||
# 在此处添加更多动作模块导入
|
||||
# 在此处添加更多动作模块导入
|
||||
|
||||
@@ -22,29 +22,26 @@ class EmojiAction(BaseAction):
|
||||
action_parameters: dict[str:str] = {
|
||||
"description": "文字描述你想要发送的表情包内容",
|
||||
}
|
||||
action_require: list[str] = [
|
||||
"表达情绪时可以选择使用",
|
||||
"重点:不要连续发,如果你已经发过[表情包],就不要选择此动作"]
|
||||
action_require: list[str] = ["表达情绪时可以选择使用", "重点:不要连续发,如果你已经发过[表情包],就不要选择此动作"]
|
||||
|
||||
associated_types: list[str] = ["emoji"]
|
||||
|
||||
enable_plugin = True
|
||||
|
||||
|
||||
focus_activation_type = ActionActivationType.LLM_JUDGE
|
||||
normal_activation_type = ActionActivationType.RANDOM
|
||||
|
||||
|
||||
random_activation_probability = global_config.normal_chat.emoji_chance
|
||||
|
||||
|
||||
parallel_action = True
|
||||
|
||||
|
||||
|
||||
llm_judge_prompt = """
|
||||
判定是否需要使用表情动作的条件:
|
||||
1. 用户明确要求使用表情包
|
||||
2. 这是一个适合表达强烈情绪的场合
|
||||
3. 不要发送太多表情包,如果你已经发送过多个表情包
|
||||
"""
|
||||
|
||||
|
||||
# 模式启用设置 - 表情动作只在Focus模式下使用
|
||||
mode_enable = ChatMode.ALL
|
||||
|
||||
@@ -147,4 +144,4 @@ class EmojiAction(BaseAction):
|
||||
elif type == "emoji":
|
||||
reply_text += data
|
||||
|
||||
return success, reply_text
|
||||
return success, reply_text
|
||||
|
||||
@@ -27,7 +27,7 @@ class ExitFocusChatAction(BaseAction):
|
||||
]
|
||||
# 退出专注聊天是系统核心功能,不是插件,但默认不启用(需要特定条件触发)
|
||||
enable_plugin = False
|
||||
|
||||
|
||||
# 模式启用设置 - 退出专注聊天动作只在Focus模式下使用
|
||||
mode_enable = ChatMode.FOCUS
|
||||
|
||||
|
||||
@@ -29,10 +29,10 @@ class NoReplyAction(BaseAction):
|
||||
"想要休息一下",
|
||||
]
|
||||
enable_plugin = True
|
||||
|
||||
|
||||
# 激活类型设置
|
||||
focus_activation_type = ActionActivationType.ALWAYS
|
||||
|
||||
|
||||
# 模式启用设置 - no_reply动作只在Focus模式下使用
|
||||
mode_enable = ChatMode.FOCUS
|
||||
|
||||
|
||||
@@ -31,16 +31,16 @@ class ReplyAction(BaseAction):
|
||||
action_require: list[str] = [
|
||||
"你想要闲聊或者随便附和",
|
||||
"有人提到你",
|
||||
"如果你刚刚进行了回复,不要对同一个话题重复回应"
|
||||
"如果你刚刚进行了回复,不要对同一个话题重复回应",
|
||||
]
|
||||
|
||||
associated_types: list[str] = ["text"]
|
||||
|
||||
enable_plugin = True
|
||||
|
||||
|
||||
# 激活类型设置
|
||||
focus_activation_type = ActionActivationType.ALWAYS
|
||||
|
||||
|
||||
# 模式启用设置 - 回复动作只在Focus模式下使用
|
||||
mode_enable = ChatMode.FOCUS
|
||||
|
||||
@@ -89,12 +89,12 @@ class ReplyAction(BaseAction):
|
||||
cycle_timers=self.cycle_timers,
|
||||
thinking_id=self.thinking_id,
|
||||
)
|
||||
|
||||
|
||||
await self.store_action_info(
|
||||
action_build_into_prompt=False,
|
||||
action_prompt_display=f"{reply_text}",
|
||||
)
|
||||
|
||||
|
||||
return success, reply_text
|
||||
|
||||
async def _handle_reply(
|
||||
@@ -115,22 +115,22 @@ class ReplyAction(BaseAction):
|
||||
chatting_observation: ChattingObservation = next(
|
||||
obs for obs in self.observations if isinstance(obs, ChattingObservation)
|
||||
)
|
||||
|
||||
|
||||
reply_to = reply_data.get("reply_to", "none")
|
||||
|
||||
|
||||
# sender = ""
|
||||
target = ""
|
||||
if ":" in reply_to or ":" in reply_to:
|
||||
# 使用正则表达式匹配中文或英文冒号
|
||||
parts = re.split(pattern=r'[::]', string=reply_to, maxsplit=1)
|
||||
parts = re.split(pattern=r"[::]", string=reply_to, maxsplit=1)
|
||||
if len(parts) == 2:
|
||||
# sender = parts[0].strip()
|
||||
target = parts[1].strip()
|
||||
anchor_message = chatting_observation.search_message_by_text(target)
|
||||
else:
|
||||
anchor_message = None
|
||||
|
||||
if anchor_message:
|
||||
|
||||
if anchor_message:
|
||||
anchor_message.update_chat_stream(self.chat_stream)
|
||||
else:
|
||||
logger.info(f"{self.log_prefix} 未找到锚点消息,创建占位符")
|
||||
@@ -138,7 +138,6 @@ class ReplyAction(BaseAction):
|
||||
self.chat_stream.platform, self.chat_stream.group_info, self.chat_stream
|
||||
)
|
||||
|
||||
|
||||
success, reply_set = await self.replyer.deal_reply(
|
||||
cycle_timers=cycle_timers,
|
||||
action_data=reply_data,
|
||||
@@ -158,8 +157,9 @@ class ReplyAction(BaseAction):
|
||||
|
||||
return success, reply_text
|
||||
|
||||
|
||||
async def store_action_info(self, action_build_into_prompt: bool = False, action_prompt_display: str = "", action_done: bool = True) -> None:
|
||||
async def store_action_info(
|
||||
self, action_build_into_prompt: bool = False, action_prompt_display: str = "", action_done: bool = True
|
||||
) -> None:
|
||||
"""存储action执行信息到数据库
|
||||
|
||||
Args:
|
||||
@@ -188,9 +188,9 @@ class ReplyAction(BaseAction):
|
||||
chat_info_platform=chat_stream.platform,
|
||||
user_id=chat_stream.user_info.user_id if chat_stream.user_info else "",
|
||||
user_nickname=chat_stream.user_info.user_nickname if chat_stream.user_info else "",
|
||||
user_cardname=chat_stream.user_info.user_cardname if chat_stream.user_info else ""
|
||||
user_cardname=chat_stream.user_info.user_cardname if chat_stream.user_info else "",
|
||||
)
|
||||
logger.debug(f"{self.log_prefix} 已存储action信息: {action_prompt_display}")
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 存储action信息时出错: {e}")
|
||||
traceback.print_exc()
|
||||
traceback.print_exc()
|
||||
|
||||
@@ -1,10 +1,6 @@
|
||||
import traceback
|
||||
from typing import Tuple, Dict, List, Any, Optional, Union, Type
|
||||
from typing import Tuple, Dict, Any, Optional
|
||||
from src.chat.actions.base_action import BaseAction, register_action, ActionActivationType, ChatMode # noqa F401
|
||||
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
|
||||
from src.chat.focus_chat.hfc_utils import create_empty_anchor_message
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.config.config import global_config
|
||||
import os
|
||||
import inspect
|
||||
import toml # 导入 toml 库
|
||||
@@ -20,10 +16,10 @@ from src.chat.actions.plugin_api.stream_api import StreamAPI
|
||||
from src.chat.actions.plugin_api.hearflow_api import HearflowAPI
|
||||
|
||||
# 以下为类型注解需要
|
||||
from src.chat.message_receive.chat_stream import ChatStream # noqa
|
||||
from src.chat.focus_chat.expressors.default_expressor import DefaultExpressor # noqa
|
||||
from src.chat.focus_chat.replyer.default_replyer import DefaultReplyer # noqa
|
||||
from src.chat.focus_chat.info.obs_info import ObsInfo # noqa
|
||||
from src.chat.message_receive.chat_stream import ChatStream # noqa
|
||||
from src.chat.focus_chat.expressors.default_expressor import DefaultExpressor # noqa
|
||||
from src.chat.focus_chat.replyer.default_replyer import DefaultReplyer # noqa
|
||||
from src.chat.focus_chat.info.obs_info import ObsInfo # noqa
|
||||
|
||||
logger = get_logger("plugin_action")
|
||||
|
||||
@@ -35,7 +31,7 @@ class PluginAction(BaseAction, MessageAPI, LLMAPI, DatabaseAPI, ConfigAPI, Utils
|
||||
"""
|
||||
|
||||
action_config_file_name: Optional[str] = None # 插件可以覆盖此属性来指定配置文件名
|
||||
|
||||
|
||||
# 默认激活类型设置,插件可以覆盖
|
||||
focus_activation_type = ActionActivationType.ALWAYS
|
||||
normal_activation_type = ActionActivationType.ALWAYS
|
||||
@@ -43,7 +39,7 @@ class PluginAction(BaseAction, MessageAPI, LLMAPI, DatabaseAPI, ConfigAPI, Utils
|
||||
llm_judge_prompt: str = ""
|
||||
activation_keywords: list[str] = []
|
||||
keyword_case_sensitive: bool = False
|
||||
|
||||
|
||||
# 默认模式启用设置 - 插件动作默认在所有模式下可用,插件可以覆盖
|
||||
mode_enable = ChatMode.ALL
|
||||
|
||||
|
||||
@@ -7,11 +7,11 @@ from src.chat.actions.plugin_api.stream_api import StreamAPI
|
||||
from src.chat.actions.plugin_api.hearflow_api import HearflowAPI
|
||||
|
||||
__all__ = [
|
||||
'MessageAPI',
|
||||
'LLMAPI',
|
||||
'DatabaseAPI',
|
||||
'ConfigAPI',
|
||||
'UtilsAPI',
|
||||
'StreamAPI',
|
||||
'HearflowAPI',
|
||||
]
|
||||
"MessageAPI",
|
||||
"LLMAPI",
|
||||
"DatabaseAPI",
|
||||
"ConfigAPI",
|
||||
"UtilsAPI",
|
||||
"StreamAPI",
|
||||
"HearflowAPI",
|
||||
]
|
||||
|
||||
@@ -5,32 +5,33 @@ from src.person_info.person_info import person_info_manager
|
||||
|
||||
logger = get_logger("config_api")
|
||||
|
||||
|
||||
class ConfigAPI:
|
||||
"""配置API模块
|
||||
|
||||
|
||||
提供了配置读取和用户信息获取等功能
|
||||
"""
|
||||
|
||||
|
||||
def get_global_config(self, key: str, default: Any = None) -> Any:
|
||||
"""
|
||||
安全地从全局配置中获取一个值。
|
||||
插件应使用此方法读取全局配置,以保证只读和隔离性。
|
||||
|
||||
|
||||
Args:
|
||||
key: 配置键名
|
||||
default: 如果配置不存在时返回的默认值
|
||||
|
||||
|
||||
Returns:
|
||||
Any: 配置值或默认值
|
||||
"""
|
||||
return global_config.get(key, default)
|
||||
|
||||
|
||||
async def get_user_id_by_person_name(self, person_name: str) -> tuple[str, str]:
|
||||
"""根据用户名获取用户ID
|
||||
|
||||
|
||||
Args:
|
||||
person_name: 用户名
|
||||
|
||||
|
||||
Returns:
|
||||
tuple[str, str]: (平台, 用户ID)
|
||||
"""
|
||||
@@ -38,16 +39,16 @@ class ConfigAPI:
|
||||
user_id = await person_info_manager.get_value(person_id, "user_id")
|
||||
platform = await person_info_manager.get_value(person_id, "platform")
|
||||
return platform, user_id
|
||||
|
||||
|
||||
async def get_person_info(self, person_id: str, key: str, default: Any = None) -> Any:
|
||||
"""获取用户信息
|
||||
|
||||
|
||||
Args:
|
||||
person_id: 用户ID
|
||||
key: 信息键名
|
||||
default: 默认值
|
||||
|
||||
|
||||
Returns:
|
||||
Any: 用户信息值或默认值
|
||||
"""
|
||||
return await person_info_manager.get_value(person_id, key, default)
|
||||
return await person_info_manager.get_value(person_id, key, default)
|
||||
|
||||
@@ -8,13 +8,16 @@ from peewee import Model, DoesNotExist
|
||||
|
||||
logger = get_logger("database_api")
|
||||
|
||||
|
||||
class DatabaseAPI:
|
||||
"""数据库API模块
|
||||
|
||||
|
||||
提供了数据库操作相关的功能
|
||||
"""
|
||||
|
||||
async def store_action_info(self, action_build_into_prompt: bool = False, action_prompt_display: str = "", action_done: bool = True) -> None:
|
||||
|
||||
async def store_action_info(
|
||||
self, action_build_into_prompt: bool = False, action_prompt_display: str = "", action_done: bool = True
|
||||
) -> None:
|
||||
"""存储action执行信息到数据库
|
||||
|
||||
Args:
|
||||
@@ -44,13 +47,13 @@ class DatabaseAPI:
|
||||
chat_info_platform=chat_stream.platform,
|
||||
user_id=chat_stream.user_info.user_id if chat_stream.user_info else "",
|
||||
user_nickname=chat_stream.user_info.user_nickname if chat_stream.user_info else "",
|
||||
user_cardname=chat_stream.user_info.user_cardname if chat_stream.user_info else ""
|
||||
user_cardname=chat_stream.user_info.user_cardname if chat_stream.user_info else "",
|
||||
)
|
||||
logger.debug(f"{self.log_prefix} 已存储action信息: {action_prompt_display}")
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 存储action信息时出错: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
async def db_query(
|
||||
self,
|
||||
model_class: Type[Model],
|
||||
@@ -59,12 +62,12 @@ class DatabaseAPI:
|
||||
data: Dict[str, Any] = None,
|
||||
limit: int = None,
|
||||
order_by: List[str] = None,
|
||||
single_result: bool = False
|
||||
single_result: bool = False,
|
||||
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
|
||||
"""执行数据库查询操作
|
||||
|
||||
|
||||
这个方法提供了一个通用接口来执行数据库操作,包括查询、创建、更新和删除记录。
|
||||
|
||||
|
||||
Args:
|
||||
model_class: Peewee 模型类,例如 ActionRecords, Messages 等
|
||||
query_type: 查询类型,可选值: "get", "create", "update", "delete", "count"
|
||||
@@ -73,7 +76,7 @@ class DatabaseAPI:
|
||||
limit: 限制结果数量
|
||||
order_by: 排序字段列表,使用字段名,前缀'-'表示降序
|
||||
single_result: 是否只返回单个结果
|
||||
|
||||
|
||||
Returns:
|
||||
根据查询类型返回不同的结果:
|
||||
- "get": 返回查询结果列表或单个结果(如果 single_result=True)
|
||||
@@ -81,24 +84,24 @@ class DatabaseAPI:
|
||||
- "update": 返回受影响的行数
|
||||
- "delete": 返回受影响的行数
|
||||
- "count": 返回记录数量
|
||||
|
||||
|
||||
示例:
|
||||
# 查询最近10条消息
|
||||
messages = await self.db_query(
|
||||
Messages,
|
||||
Messages,
|
||||
query_type="get",
|
||||
filters={"chat_id": chat_stream.stream_id},
|
||||
limit=10,
|
||||
order_by=["-time"]
|
||||
)
|
||||
|
||||
|
||||
# 创建一条记录
|
||||
new_record = await self.db_query(
|
||||
ActionRecords,
|
||||
query_type="create",
|
||||
data={"action_id": "123", "time": time.time(), "action_name": "TestAction"}
|
||||
)
|
||||
|
||||
|
||||
# 更新记录
|
||||
updated_count = await self.db_query(
|
||||
ActionRecords,
|
||||
@@ -106,14 +109,14 @@ class DatabaseAPI:
|
||||
filters={"action_id": "123"},
|
||||
data={"action_done": True}
|
||||
)
|
||||
|
||||
|
||||
# 删除记录
|
||||
deleted_count = await self.db_query(
|
||||
ActionRecords,
|
||||
query_type="delete",
|
||||
filters={"action_id": "123"}
|
||||
)
|
||||
|
||||
|
||||
# 计数
|
||||
count = await self.db_query(
|
||||
Messages,
|
||||
@@ -125,12 +128,12 @@ class DatabaseAPI:
|
||||
# 构建基本查询
|
||||
if query_type in ["get", "update", "delete", "count"]:
|
||||
query = model_class.select()
|
||||
|
||||
|
||||
# 应用过滤条件
|
||||
if filters:
|
||||
for field, value in filters.items():
|
||||
query = query.where(getattr(model_class, field) == value)
|
||||
|
||||
|
||||
# 执行查询
|
||||
if query_type == "get":
|
||||
# 应用排序
|
||||
@@ -140,56 +143,56 @@ class DatabaseAPI:
|
||||
query = query.order_by(getattr(model_class, field[1:]).desc())
|
||||
else:
|
||||
query = query.order_by(getattr(model_class, field))
|
||||
|
||||
|
||||
# 应用限制
|
||||
if limit:
|
||||
query = query.limit(limit)
|
||||
|
||||
|
||||
# 执行查询
|
||||
results = list(query.dicts())
|
||||
|
||||
|
||||
# 返回结果
|
||||
if single_result:
|
||||
return results[0] if results else None
|
||||
return results
|
||||
|
||||
|
||||
elif query_type == "create":
|
||||
if not data:
|
||||
raise ValueError("创建记录需要提供data参数")
|
||||
|
||||
|
||||
# 创建记录
|
||||
record = model_class.create(**data)
|
||||
# 返回创建的记录
|
||||
return model_class.select().where(model_class.id == record.id).dicts().get()
|
||||
|
||||
|
||||
elif query_type == "update":
|
||||
if not data:
|
||||
raise ValueError("更新记录需要提供data参数")
|
||||
|
||||
|
||||
# 更新记录
|
||||
return query.update(**data).execute()
|
||||
|
||||
|
||||
elif query_type == "delete":
|
||||
# 删除记录
|
||||
return query.delete().execute()
|
||||
|
||||
|
||||
elif query_type == "count":
|
||||
# 计数
|
||||
return query.count()
|
||||
|
||||
|
||||
else:
|
||||
raise ValueError(f"不支持的查询类型: {query_type}")
|
||||
|
||||
|
||||
except DoesNotExist:
|
||||
# 记录不存在
|
||||
if query_type == "get" and single_result:
|
||||
return None
|
||||
return []
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 数据库操作出错: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
# 根据查询类型返回合适的默认值
|
||||
if query_type == "get":
|
||||
return None if single_result else []
|
||||
@@ -198,21 +201,18 @@ class DatabaseAPI:
|
||||
raise "unknown query type"
|
||||
|
||||
async def db_raw_query(
|
||||
self,
|
||||
sql: str,
|
||||
params: List[Any] = None,
|
||||
fetch_results: bool = True
|
||||
self, sql: str, params: List[Any] = None, fetch_results: bool = True
|
||||
) -> Union[List[Dict[str, Any]], int, None]:
|
||||
"""执行原始SQL查询
|
||||
|
||||
|
||||
警告: 使用此方法需要小心,确保SQL语句已正确构造以避免SQL注入风险。
|
||||
|
||||
|
||||
Args:
|
||||
sql: 原始SQL查询字符串
|
||||
params: 查询参数列表,用于替换SQL中的占位符
|
||||
fetch_results: 是否获取查询结果,对于SELECT查询设为True,对于
|
||||
UPDATE/INSERT/DELETE等操作设为False
|
||||
|
||||
|
||||
Returns:
|
||||
如果fetch_results为True,返回查询结果列表;
|
||||
如果fetch_results为False,返回受影响的行数;
|
||||
@@ -220,55 +220,51 @@ class DatabaseAPI:
|
||||
"""
|
||||
try:
|
||||
cursor = db.execute_sql(sql, params or [])
|
||||
|
||||
|
||||
if fetch_results:
|
||||
# 获取列名
|
||||
columns = [col[0] for col in cursor.description]
|
||||
|
||||
|
||||
# 构建结果字典列表
|
||||
results = []
|
||||
for row in cursor.fetchall():
|
||||
results.append(dict(zip(columns, row)))
|
||||
|
||||
|
||||
return results
|
||||
else:
|
||||
# 返回受影响的行数
|
||||
return cursor.rowcount
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 执行原始SQL查询出错: {e}")
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
|
||||
async def db_save(
|
||||
self,
|
||||
model_class: Type[Model],
|
||||
data: Dict[str, Any],
|
||||
key_field: str = None,
|
||||
key_value: Any = None
|
||||
self, model_class: Type[Model], data: Dict[str, Any], key_field: str = None, key_value: Any = None
|
||||
) -> Union[Dict[str, Any], None]:
|
||||
"""保存数据到数据库(创建或更新)
|
||||
|
||||
|
||||
如果提供了key_field和key_value,会先尝试查找匹配的记录进行更新;
|
||||
如果没有找到匹配记录,或未提供key_field和key_value,则创建新记录。
|
||||
|
||||
|
||||
Args:
|
||||
model_class: Peewee模型类,如ActionRecords, Messages等
|
||||
data: 要保存的数据字典
|
||||
key_field: 用于查找现有记录的字段名,例如"action_id"
|
||||
key_value: 用于查找现有记录的字段值
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 保存后的记录数据
|
||||
None: 如果操作失败
|
||||
|
||||
|
||||
示例:
|
||||
# 创建或更新一条记录
|
||||
record = await self.db_save(
|
||||
ActionRecords,
|
||||
{
|
||||
"action_id": "123",
|
||||
"time": time.time(),
|
||||
"action_id": "123",
|
||||
"time": time.time(),
|
||||
"action_name": "TestAction",
|
||||
"action_done": True
|
||||
},
|
||||
@@ -280,58 +276,50 @@ class DatabaseAPI:
|
||||
# 如果提供了key_field和key_value,尝试更新现有记录
|
||||
if key_field and key_value is not None:
|
||||
# 查找现有记录
|
||||
existing_records = list(model_class.select().where(
|
||||
getattr(model_class, key_field) == key_value
|
||||
).limit(1))
|
||||
|
||||
existing_records = list(
|
||||
model_class.select().where(getattr(model_class, key_field) == key_value).limit(1)
|
||||
)
|
||||
|
||||
if existing_records:
|
||||
# 更新现有记录
|
||||
existing_record = existing_records[0]
|
||||
for field, value in data.items():
|
||||
setattr(existing_record, field, value)
|
||||
existing_record.save()
|
||||
|
||||
|
||||
# 返回更新后的记录
|
||||
updated_record = model_class.select().where(
|
||||
model_class.id == existing_record.id
|
||||
).dicts().get()
|
||||
updated_record = model_class.select().where(model_class.id == existing_record.id).dicts().get()
|
||||
return updated_record
|
||||
|
||||
|
||||
# 如果没有找到现有记录或未提供key_field和key_value,创建新记录
|
||||
new_record = model_class.create(**data)
|
||||
|
||||
|
||||
# 返回创建的记录
|
||||
created_record = model_class.select().where(
|
||||
model_class.id == new_record.id
|
||||
).dicts().get()
|
||||
created_record = model_class.select().where(model_class.id == new_record.id).dicts().get()
|
||||
return created_record
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 保存数据库记录出错: {e}")
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
|
||||
async def db_get(
|
||||
self,
|
||||
model_class: Type[Model],
|
||||
filters: Dict[str, Any] = None,
|
||||
order_by: str = None,
|
||||
limit: int = None
|
||||
self, model_class: Type[Model], filters: Dict[str, Any] = None, order_by: str = None, limit: int = None
|
||||
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
|
||||
"""从数据库获取记录
|
||||
|
||||
|
||||
这是db_query方法的简化版本,专注于数据检索操作。
|
||||
|
||||
|
||||
Args:
|
||||
model_class: Peewee模型类
|
||||
filters: 过滤条件,字段名和值的字典
|
||||
order_by: 排序字段,前缀'-'表示降序,例如'-time'表示按时间降序
|
||||
limit: 结果数量限制,如果为1则返回单个记录而不是列表
|
||||
|
||||
|
||||
Returns:
|
||||
如果limit=1,返回单个记录字典或None;
|
||||
否则返回记录字典列表或空列表。
|
||||
|
||||
|
||||
示例:
|
||||
# 获取单个记录
|
||||
record = await self.db_get(
|
||||
@@ -339,7 +327,7 @@ class DatabaseAPI:
|
||||
filters={"action_id": "123"},
|
||||
limit=1
|
||||
)
|
||||
|
||||
|
||||
# 获取最近10条记录
|
||||
records = await self.db_get(
|
||||
Messages,
|
||||
@@ -351,32 +339,32 @@ class DatabaseAPI:
|
||||
try:
|
||||
# 构建查询
|
||||
query = model_class.select()
|
||||
|
||||
|
||||
# 应用过滤条件
|
||||
if filters:
|
||||
for field, value in filters.items():
|
||||
query = query.where(getattr(model_class, field) == value)
|
||||
|
||||
|
||||
# 应用排序
|
||||
if order_by:
|
||||
if order_by.startswith("-"):
|
||||
query = query.order_by(getattr(model_class, order_by[1:]).desc())
|
||||
else:
|
||||
query = query.order_by(getattr(model_class, order_by))
|
||||
|
||||
|
||||
# 应用限制
|
||||
if limit:
|
||||
query = query.limit(limit)
|
||||
|
||||
|
||||
# 执行查询
|
||||
results = list(query.dicts())
|
||||
|
||||
|
||||
# 返回结果
|
||||
if limit == 1:
|
||||
return results[0] if results else None
|
||||
return results
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 获取数据库记录出错: {e}")
|
||||
traceback.print_exc()
|
||||
return None if limit == 1 else []
|
||||
return None if limit == 1 else []
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Optional, List, Any, Tuple
|
||||
from typing import Optional, List, Any
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.chat.heart_flow.heartflow import heartflow
|
||||
from src.chat.heart_flow.sub_heartflow import SubHeartflow, ChatState
|
||||
@@ -8,16 +8,16 @@ logger = get_logger("hearflow_api")
|
||||
|
||||
class HearflowAPI:
|
||||
"""心流API模块
|
||||
|
||||
|
||||
提供与心流和子心流相关的操作接口
|
||||
"""
|
||||
|
||||
|
||||
async def get_sub_hearflow_by_chat_id(self, chat_id: str) -> Optional[SubHeartflow]:
|
||||
"""根据chat_id获取指定的sub_hearflow实例
|
||||
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID,与sub_hearflow的subheartflow_id相同
|
||||
|
||||
|
||||
Returns:
|
||||
Optional[SubHeartflow]: sub_hearflow实例,如果不存在则返回None
|
||||
"""
|
||||
@@ -35,11 +35,10 @@ class HearflowAPI:
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 获取子心流实例时出错: {e}")
|
||||
return None
|
||||
|
||||
|
||||
|
||||
def get_all_sub_hearflow_ids(self) -> List[str]:
|
||||
"""获取所有子心流的ID列表
|
||||
|
||||
|
||||
Returns:
|
||||
List[str]: 所有子心流的ID列表
|
||||
"""
|
||||
@@ -51,10 +50,10 @@ class HearflowAPI:
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 获取子心流ID列表时出错: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def get_all_sub_hearflows(self) -> List[SubHeartflow]:
|
||||
"""获取所有子心流实例
|
||||
|
||||
|
||||
Returns:
|
||||
List[SubHeartflow]: 所有活跃的子心流实例列表
|
||||
"""
|
||||
@@ -66,13 +65,13 @@ class HearflowAPI:
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 获取子心流实例列表时出错: {e}")
|
||||
return []
|
||||
|
||||
|
||||
async def get_sub_hearflow_chat_state(self, chat_id: str) -> Optional[ChatState]:
|
||||
"""获取指定子心流的聊天状态
|
||||
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
|
||||
|
||||
Returns:
|
||||
Optional[ChatState]: 聊天状态,如果子心流不存在则返回None
|
||||
"""
|
||||
@@ -84,14 +83,14 @@ class HearflowAPI:
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 获取子心流聊天状态时出错: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def set_sub_hearflow_chat_state(self, chat_id: str, target_state: ChatState) -> bool:
|
||||
"""设置指定子心流的聊天状态
|
||||
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
target_state: 目标状态
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否设置成功
|
||||
"""
|
||||
@@ -100,13 +99,13 @@ class HearflowAPI:
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 设置子心流聊天状态时出错: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def get_sub_hearflow_replyer(self, chat_id: str) -> Optional[Any]:
|
||||
"""根据chat_id获取指定子心流的replyer实例
|
||||
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
|
||||
|
||||
Returns:
|
||||
Optional[Any]: replyer实例,如果不存在则返回None
|
||||
"""
|
||||
@@ -116,13 +115,13 @@ class HearflowAPI:
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 获取子心流replyer时出错: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def get_sub_hearflow_expressor(self, chat_id: str) -> Optional[Any]:
|
||||
"""根据chat_id获取指定子心流的expressor实例
|
||||
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
|
||||
|
||||
Returns:
|
||||
Optional[Any]: expressor实例,如果不存在则返回None
|
||||
"""
|
||||
@@ -131,4 +130,4 @@ class HearflowAPI:
|
||||
return expressor
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 获取子心流expressor时出错: {e}")
|
||||
return None
|
||||
return None
|
||||
|
||||
@@ -5,12 +5,13 @@ from src.config.config import global_config
|
||||
|
||||
logger = get_logger("llm_api")
|
||||
|
||||
|
||||
class LLMAPI:
|
||||
"""LLM API模块
|
||||
|
||||
|
||||
提供了与LLM模型交互的功能
|
||||
"""
|
||||
|
||||
|
||||
def get_available_models(self) -> Dict[str, Any]:
|
||||
"""获取所有可用的模型配置
|
||||
|
||||
@@ -20,17 +21,13 @@ class LLMAPI:
|
||||
if not hasattr(global_config, "model"):
|
||||
logger.error(f"{self.log_prefix} 无法获取模型列表:全局配置中未找到 model 配置")
|
||||
return {}
|
||||
|
||||
|
||||
models = global_config.model
|
||||
|
||||
|
||||
return models
|
||||
|
||||
async def generate_with_model(
|
||||
self,
|
||||
prompt: str,
|
||||
model_config: Dict[str, Any],
|
||||
request_type: str = "plugin.generate",
|
||||
**kwargs
|
||||
self, prompt: str, model_config: Dict[str, Any], request_type: str = "plugin.generate", **kwargs
|
||||
) -> Tuple[bool, str, str, str]:
|
||||
"""使用指定模型生成内容
|
||||
|
||||
@@ -45,17 +42,13 @@ class LLMAPI:
|
||||
"""
|
||||
try:
|
||||
logger.info(f"{self.log_prefix} 使用模型生成内容,提示词: {prompt[:100]}...")
|
||||
|
||||
llm_request = LLMRequest(
|
||||
model=model_config,
|
||||
request_type=request_type,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
|
||||
llm_request = LLMRequest(model=model_config, request_type=request_type, **kwargs)
|
||||
|
||||
response, (reasoning, model_name) = await llm_request.generate_response_async(prompt)
|
||||
return True, response, reasoning, model_name
|
||||
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"生成内容时出错: {str(e)}"
|
||||
logger.error(f"{self.log_prefix} {error_msg}")
|
||||
return False, error_msg, "", ""
|
||||
return False, error_msg, "", ""
|
||||
|
||||
@@ -14,17 +14,18 @@ from src.chat.focus_chat.info.obs_info import ObsInfo
|
||||
# 新增导入
|
||||
from src.chat.focus_chat.heartFC_sender import HeartFCSender
|
||||
from src.chat.message_receive.message import MessageSending
|
||||
from maim_message import Seg, UserInfo, GroupInfo
|
||||
from maim_message import Seg, UserInfo
|
||||
from src.config.config import global_config
|
||||
|
||||
logger = get_logger("message_api")
|
||||
|
||||
|
||||
class MessageAPI:
|
||||
"""消息API模块
|
||||
|
||||
|
||||
提供了发送消息、获取消息历史等功能
|
||||
"""
|
||||
|
||||
|
||||
async def send_message_to_target(
|
||||
self,
|
||||
message_type: str,
|
||||
@@ -35,7 +36,7 @@ class MessageAPI:
|
||||
display_message: str = "",
|
||||
) -> bool:
|
||||
"""直接向指定目标发送消息
|
||||
|
||||
|
||||
Args:
|
||||
message_type: 消息类型,如"text"、"image"、"emoji"等
|
||||
content: 消息内容
|
||||
@@ -43,7 +44,7 @@ class MessageAPI:
|
||||
target_id: 目标ID(群ID或用户ID)
|
||||
is_group: 是否为群聊,True为群聊,False为私聊
|
||||
display_message: 显示消息(可选)
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
@@ -53,12 +54,14 @@ class MessageAPI:
|
||||
# 群聊:从数据库查找对应的聊天流
|
||||
target_stream = None
|
||||
for stream_id, stream in chat_manager.streams.items():
|
||||
if (stream.group_info and
|
||||
str(stream.group_info.group_id) == str(target_id) and
|
||||
stream.platform == platform):
|
||||
if (
|
||||
stream.group_info
|
||||
and str(stream.group_info.group_id) == str(target_id)
|
||||
and stream.platform == platform
|
||||
):
|
||||
target_stream = stream
|
||||
break
|
||||
|
||||
|
||||
if not target_stream:
|
||||
logger.error(f"{getattr(self, 'log_prefix', '')} 未找到群ID为 {target_id} 的聊天流")
|
||||
return False
|
||||
@@ -66,39 +69,39 @@ class MessageAPI:
|
||||
# 私聊:从数据库查找对应的聊天流
|
||||
target_stream = None
|
||||
for stream_id, stream in chat_manager.streams.items():
|
||||
if (not stream.group_info and
|
||||
str(stream.user_info.user_id) == str(target_id) and
|
||||
stream.platform == platform):
|
||||
if (
|
||||
not stream.group_info
|
||||
and str(stream.user_info.user_id) == str(target_id)
|
||||
and stream.platform == platform
|
||||
):
|
||||
target_stream = stream
|
||||
break
|
||||
|
||||
|
||||
if not target_stream:
|
||||
logger.error(f"{getattr(self, 'log_prefix', '')} 未找到用户ID为 {target_id} 的私聊流")
|
||||
return False
|
||||
|
||||
|
||||
# 创建HeartFCSender实例
|
||||
heart_fc_sender = HeartFCSender()
|
||||
|
||||
|
||||
# 生成消息ID和thinking_id
|
||||
current_time = time.time()
|
||||
message_id = f"plugin_msg_{int(current_time * 1000)}"
|
||||
thinking_id = f"plugin_thinking_{int(current_time * 1000)}"
|
||||
|
||||
|
||||
# 构建机器人用户信息
|
||||
bot_user_info = UserInfo(
|
||||
user_id=global_config.bot.qq_account,
|
||||
user_nickname=global_config.bot.nickname,
|
||||
platform=platform,
|
||||
)
|
||||
|
||||
|
||||
# 创建消息段
|
||||
message_segment = Seg(type=message_type, data=content)
|
||||
|
||||
|
||||
# 创建空锚点消息(用于回复)
|
||||
anchor_message = await create_empty_anchor_message(
|
||||
platform, target_stream.group_info, target_stream
|
||||
)
|
||||
|
||||
anchor_message = await create_empty_anchor_message(platform, target_stream.group_info, target_stream)
|
||||
|
||||
# 构建发送消息对象
|
||||
bot_message = MessageSending(
|
||||
message_id=message_id,
|
||||
@@ -112,22 +115,17 @@ class MessageAPI:
|
||||
is_emoji=(message_type == "emoji"),
|
||||
thinking_start_time=current_time,
|
||||
)
|
||||
|
||||
|
||||
# 发送消息
|
||||
sent_msg = await heart_fc_sender.send_message(
|
||||
bot_message,
|
||||
has_thinking=True,
|
||||
typing=False,
|
||||
set_reply=False
|
||||
)
|
||||
|
||||
sent_msg = await heart_fc_sender.send_message(bot_message, has_thinking=True, typing=False, set_reply=False)
|
||||
|
||||
if sent_msg:
|
||||
logger.info(f"{getattr(self, 'log_prefix', '')} 成功发送消息到 {platform}:{target_id}")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"{getattr(self, 'log_prefix', '')} 发送消息失败")
|
||||
return False
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{getattr(self, 'log_prefix', '')} 向目标发送消息时出错: {e}")
|
||||
traceback.print_exc()
|
||||
@@ -135,42 +133,34 @@ class MessageAPI:
|
||||
|
||||
async def send_text_to_group(self, text: str, group_id: str, platform: str = "qq") -> bool:
|
||||
"""便捷方法:向指定群聊发送文本消息
|
||||
|
||||
|
||||
Args:
|
||||
text: 要发送的文本内容
|
||||
group_id: 群聊ID
|
||||
platform: 平台,默认为"qq"
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
return await self.send_message_to_target(
|
||||
message_type="text",
|
||||
content=text,
|
||||
platform=platform,
|
||||
target_id=group_id,
|
||||
is_group=True
|
||||
message_type="text", content=text, platform=platform, target_id=group_id, is_group=True
|
||||
)
|
||||
|
||||
async def send_text_to_user(self, text: str, user_id: str, platform: str = "qq") -> bool:
|
||||
"""便捷方法:向指定用户发送私聊文本消息
|
||||
|
||||
|
||||
Args:
|
||||
text: 要发送的文本内容
|
||||
user_id: 用户ID
|
||||
platform: 平台,默认为"qq"
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
return await self.send_message_to_target(
|
||||
message_type="text",
|
||||
content=text,
|
||||
platform=platform,
|
||||
target_id=user_id,
|
||||
is_group=False
|
||||
message_type="text", content=text, platform=platform, target_id=user_id, is_group=False
|
||||
)
|
||||
|
||||
|
||||
async def send_message(self, type: str, data: str, target: Optional[str] = "", display_message: str = "") -> bool:
|
||||
"""发送消息的简化方法
|
||||
|
||||
@@ -288,7 +278,9 @@ class MessageAPI:
|
||||
|
||||
return success
|
||||
|
||||
async def send_message_by_replyer(self, target: Optional[str] = None, extra_info_block: Optional[str] = None) -> bool:
|
||||
async def send_message_by_replyer(
|
||||
self, target: Optional[str] = None, extra_info_block: Optional[str] = None
|
||||
) -> bool:
|
||||
"""通过replyer发送消息的简化方法
|
||||
|
||||
Args:
|
||||
@@ -381,4 +373,4 @@ class MessageAPI:
|
||||
}
|
||||
messages.append(simple_msg)
|
||||
|
||||
return messages
|
||||
return messages
|
||||
|
||||
@@ -1,147 +1,142 @@
|
||||
import hashlib
|
||||
from typing import Optional, List, Dict, Any
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.chat.message_receive.chat_stream import ChatManager, ChatStream
|
||||
from maim_message import GroupInfo, UserInfo
|
||||
|
||||
logger = get_logger("stream_api")
|
||||
|
||||
|
||||
class StreamAPI:
|
||||
"""聊天流API模块
|
||||
|
||||
|
||||
提供了获取聊天流、通过群ID查找聊天流等功能
|
||||
"""
|
||||
|
||||
|
||||
def get_chat_stream_by_group_id(self, group_id: str, platform: str = "qq") -> Optional[ChatStream]:
|
||||
"""通过QQ群ID获取聊天流
|
||||
|
||||
|
||||
Args:
|
||||
group_id: QQ群ID
|
||||
platform: 平台标识,默认为"qq"
|
||||
|
||||
|
||||
Returns:
|
||||
Optional[ChatStream]: 找到的聊天流对象,如果未找到则返回None
|
||||
"""
|
||||
try:
|
||||
chat_manager = ChatManager()
|
||||
|
||||
|
||||
# 遍历所有已加载的聊天流,查找匹配的群ID
|
||||
for stream_id, stream in chat_manager.streams.items():
|
||||
if (stream.group_info and
|
||||
str(stream.group_info.group_id) == str(group_id) and
|
||||
stream.platform == platform):
|
||||
if (
|
||||
stream.group_info
|
||||
and str(stream.group_info.group_id) == str(group_id)
|
||||
and stream.platform == platform
|
||||
):
|
||||
logger.info(f"{self.log_prefix} 通过群ID {group_id} 找到聊天流: {stream_id}")
|
||||
return stream
|
||||
|
||||
|
||||
logger.warning(f"{self.log_prefix} 未找到群ID为 {group_id} 的聊天流")
|
||||
return None
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 通过群ID获取聊天流时出错: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_all_group_chat_streams(self, platform: str = "qq") -> List[ChatStream]:
|
||||
"""获取所有群聊的聊天流
|
||||
|
||||
|
||||
Args:
|
||||
platform: 平台标识,默认为"qq"
|
||||
|
||||
|
||||
Returns:
|
||||
List[ChatStream]: 所有群聊的聊天流列表
|
||||
"""
|
||||
try:
|
||||
chat_manager = ChatManager()
|
||||
group_streams = []
|
||||
|
||||
|
||||
for stream in chat_manager.streams.values():
|
||||
if (stream.group_info and
|
||||
stream.platform == platform):
|
||||
if stream.group_info and stream.platform == platform:
|
||||
group_streams.append(stream)
|
||||
|
||||
|
||||
logger.info(f"{self.log_prefix} 找到 {len(group_streams)} 个群聊聊天流")
|
||||
return group_streams
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 获取所有群聊聊天流时出错: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def get_chat_stream_by_user_id(self, user_id: str, platform: str = "qq") -> Optional[ChatStream]:
|
||||
"""通过用户ID获取私聊聊天流
|
||||
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
platform: 平台标识,默认为"qq"
|
||||
|
||||
|
||||
Returns:
|
||||
Optional[ChatStream]: 找到的私聊聊天流对象,如果未找到则返回None
|
||||
"""
|
||||
try:
|
||||
chat_manager = ChatManager()
|
||||
|
||||
|
||||
# 遍历所有已加载的聊天流,查找匹配的用户ID(私聊)
|
||||
for stream_id, stream in chat_manager.streams.items():
|
||||
if (not stream.group_info and # 私聊没有群信息
|
||||
stream.user_info and
|
||||
str(stream.user_info.user_id) == str(user_id) and
|
||||
stream.platform == platform):
|
||||
if (
|
||||
not stream.group_info # 私聊没有群信息
|
||||
and stream.user_info
|
||||
and str(stream.user_info.user_id) == str(user_id)
|
||||
and stream.platform == platform
|
||||
):
|
||||
logger.info(f"{self.log_prefix} 通过用户ID {user_id} 找到私聊聊天流: {stream_id}")
|
||||
return stream
|
||||
|
||||
|
||||
logger.warning(f"{self.log_prefix} 未找到用户ID为 {user_id} 的私聊聊天流")
|
||||
return None
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 通过用户ID获取私聊聊天流时出错: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_chat_streams_info(self) -> List[Dict[str, Any]]:
|
||||
"""获取所有聊天流的基本信息
|
||||
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 包含聊天流基本信息的字典列表
|
||||
"""
|
||||
try:
|
||||
chat_manager = ChatManager()
|
||||
streams_info = []
|
||||
|
||||
|
||||
for stream_id, stream in chat_manager.streams.items():
|
||||
info = {
|
||||
"stream_id": stream_id,
|
||||
"platform": stream.platform,
|
||||
"chat_type": "group" if stream.group_info else "private",
|
||||
"create_time": stream.create_time,
|
||||
"last_active_time": stream.last_active_time
|
||||
"last_active_time": stream.last_active_time,
|
||||
}
|
||||
|
||||
|
||||
if stream.group_info:
|
||||
info.update({
|
||||
"group_id": stream.group_info.group_id,
|
||||
"group_name": stream.group_info.group_name
|
||||
})
|
||||
|
||||
info.update({"group_id": stream.group_info.group_id, "group_name": stream.group_info.group_name})
|
||||
|
||||
if stream.user_info:
|
||||
info.update({
|
||||
"user_id": stream.user_info.user_id,
|
||||
"user_nickname": stream.user_info.user_nickname
|
||||
})
|
||||
|
||||
info.update({"user_id": stream.user_info.user_id, "user_nickname": stream.user_info.user_nickname})
|
||||
|
||||
streams_info.append(info)
|
||||
|
||||
|
||||
logger.info(f"{self.log_prefix} 获取到 {len(streams_info)} 个聊天流信息")
|
||||
return streams_info
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 获取聊天流信息时出错: {e}")
|
||||
return []
|
||||
|
||||
|
||||
async def get_chat_stream_by_group_id_async(self, group_id: str, platform: str = "qq") -> Optional[ChatStream]:
|
||||
"""异步通过QQ群ID获取聊天流(包括从数据库搜索)
|
||||
|
||||
|
||||
Args:
|
||||
group_id: QQ群ID
|
||||
platform: 平台标识,默认为"qq"
|
||||
|
||||
|
||||
Returns:
|
||||
Optional[ChatStream]: 找到的聊天流对象,如果未找到则返回None
|
||||
"""
|
||||
@@ -150,15 +145,15 @@ class StreamAPI:
|
||||
stream = self.get_chat_stream_by_group_id(group_id, platform)
|
||||
if stream:
|
||||
return stream
|
||||
|
||||
|
||||
# 如果内存中没有,尝试从数据库加载所有聊天流后再查找
|
||||
chat_manager = ChatManager()
|
||||
await chat_manager.load_all_streams()
|
||||
|
||||
|
||||
# 再次尝试从内存中查找
|
||||
stream = self.get_chat_stream_by_group_id(group_id, platform)
|
||||
return stream
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 异步通过群ID获取聊天流时出错: {e}")
|
||||
return None
|
||||
return None
|
||||
|
||||
@@ -1,35 +1,37 @@
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Optional
|
||||
from src.common.logger_manager import get_logger
|
||||
|
||||
logger = get_logger("utils_api")
|
||||
|
||||
|
||||
class UtilsAPI:
|
||||
"""工具类API模块
|
||||
|
||||
|
||||
提供了各种辅助功能
|
||||
"""
|
||||
|
||||
|
||||
def get_plugin_path(self) -> str:
|
||||
"""获取当前插件的路径
|
||||
|
||||
|
||||
Returns:
|
||||
str: 插件目录的绝对路径
|
||||
"""
|
||||
import inspect
|
||||
|
||||
plugin_module_path = inspect.getfile(self.__class__)
|
||||
plugin_dir = os.path.dirname(plugin_module_path)
|
||||
return plugin_dir
|
||||
|
||||
|
||||
def read_json_file(self, file_path: str, default: Any = None) -> Any:
|
||||
"""读取JSON文件
|
||||
|
||||
|
||||
Args:
|
||||
file_path: 文件路径,可以是相对于插件目录的路径
|
||||
default: 如果文件不存在或读取失败时返回的默认值
|
||||
|
||||
|
||||
Returns:
|
||||
Any: JSON数据或默认值
|
||||
"""
|
||||
@@ -37,25 +39,25 @@ class UtilsAPI:
|
||||
# 如果是相对路径,则相对于插件目录
|
||||
if not os.path.isabs(file_path):
|
||||
file_path = os.path.join(self.get_plugin_path(), file_path)
|
||||
|
||||
|
||||
if not os.path.exists(file_path):
|
||||
logger.warning(f"{self.log_prefix} 文件不存在: {file_path}")
|
||||
return default
|
||||
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 读取JSON文件出错: {e}")
|
||||
return default
|
||||
|
||||
|
||||
def write_json_file(self, file_path: str, data: Any, indent: int = 2) -> bool:
|
||||
"""写入JSON文件
|
||||
|
||||
|
||||
Args:
|
||||
file_path: 文件路径,可以是相对于插件目录的路径
|
||||
data: 要写入的数据
|
||||
indent: JSON缩进
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否写入成功
|
||||
"""
|
||||
@@ -63,59 +65,62 @@ class UtilsAPI:
|
||||
# 如果是相对路径,则相对于插件目录
|
||||
if not os.path.isabs(file_path):
|
||||
file_path = os.path.join(self.get_plugin_path(), file_path)
|
||||
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
|
||||
with open(file_path, 'w', encoding='utf-8') as f:
|
||||
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=indent)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 写入JSON文件出错: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def get_timestamp(self) -> int:
|
||||
"""获取当前时间戳
|
||||
|
||||
|
||||
Returns:
|
||||
int: 当前时间戳(秒)
|
||||
"""
|
||||
return int(time.time())
|
||||
|
||||
|
||||
def format_time(self, timestamp: Optional[int] = None, format_str: str = "%Y-%m-%d %H:%M:%S") -> str:
|
||||
"""格式化时间
|
||||
|
||||
|
||||
Args:
|
||||
timestamp: 时间戳,如果为None则使用当前时间
|
||||
format_str: 时间格式字符串
|
||||
|
||||
|
||||
Returns:
|
||||
str: 格式化后的时间字符串
|
||||
"""
|
||||
import datetime
|
||||
|
||||
if timestamp is None:
|
||||
timestamp = time.time()
|
||||
return datetime.datetime.fromtimestamp(timestamp).strftime(format_str)
|
||||
|
||||
|
||||
def parse_time(self, time_str: str, format_str: str = "%Y-%m-%d %H:%M:%S") -> int:
|
||||
"""解析时间字符串为时间戳
|
||||
|
||||
|
||||
Args:
|
||||
time_str: 时间字符串
|
||||
format_str: 时间格式字符串
|
||||
|
||||
|
||||
Returns:
|
||||
int: 时间戳(秒)
|
||||
"""
|
||||
import datetime
|
||||
|
||||
dt = datetime.datetime.strptime(time_str, format_str)
|
||||
return int(dt.timestamp())
|
||||
|
||||
|
||||
def generate_unique_id(self) -> str:
|
||||
"""生成唯一ID
|
||||
|
||||
|
||||
Returns:
|
||||
str: 唯一ID
|
||||
"""
|
||||
import uuid
|
||||
return str(uuid.uuid4())
|
||||
|
||||
return str(uuid.uuid4())
|
||||
|
||||
@@ -3,7 +3,6 @@ from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Type, Optional, Tuple, Pattern
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.chat.actions.plugin_api.message_api import MessageAPI
|
||||
from src.chat.focus_chat.hfc_utils import create_empty_anchor_message
|
||||
from src.chat.focus_chat.expressors.default_expressor import DefaultExpressor
|
||||
|
||||
@@ -13,9 +12,10 @@ logger = get_logger("command_handler")
|
||||
_COMMAND_REGISTRY: Dict[str, Type["BaseCommand"]] = {}
|
||||
_COMMAND_PATTERNS: Dict[Pattern, Type["BaseCommand"]] = {}
|
||||
|
||||
|
||||
class BaseCommand(ABC):
|
||||
"""命令基类,所有自定义命令都应该继承这个类"""
|
||||
|
||||
|
||||
# 命令的基本属性
|
||||
command_name: str = "" # 命令名称
|
||||
command_description: str = "" # 命令描述
|
||||
@@ -23,43 +23,43 @@ class BaseCommand(ABC):
|
||||
command_help: str = "" # 命令帮助信息
|
||||
command_examples: List[str] = [] # 命令使用示例
|
||||
enable_command: bool = True # 是否启用命令
|
||||
|
||||
|
||||
def __init__(self, message: MessageRecv):
|
||||
"""初始化命令处理器
|
||||
|
||||
|
||||
Args:
|
||||
message: 接收到的消息对象
|
||||
"""
|
||||
self.message = message
|
||||
self.matched_groups: Dict[str, str] = {} # 存储正则表达式匹配的命名组
|
||||
self._services = {} # 存储内部服务
|
||||
|
||||
|
||||
# 设置服务
|
||||
self._services["chat_stream"] = message.chat_stream
|
||||
|
||||
|
||||
# 日志前缀
|
||||
self.log_prefix = f"[Command:{self.command_name}]"
|
||||
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self) -> Tuple[bool, Optional[str]]:
|
||||
"""执行命令的抽象方法,需要被子类实现
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple[bool, Optional[str]]: (是否执行成功, 可选的回复消息)
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def set_matched_groups(self, groups: Dict[str, str]) -> None:
|
||||
"""设置正则表达式匹配的命名组
|
||||
|
||||
|
||||
Args:
|
||||
groups: 正则表达式匹配的命名组
|
||||
"""
|
||||
self.matched_groups = groups
|
||||
|
||||
|
||||
async def send_reply(self, content: str) -> None:
|
||||
"""发送回复消息
|
||||
|
||||
|
||||
Args:
|
||||
content: 回复内容
|
||||
"""
|
||||
@@ -69,43 +69,42 @@ class BaseCommand(ABC):
|
||||
if not chat_stream:
|
||||
logger.error(f"{self.log_prefix} 无法发送消息:缺少chat_stream")
|
||||
return
|
||||
|
||||
|
||||
# 创建空的锚定消息
|
||||
anchor_message = await create_empty_anchor_message(
|
||||
chat_stream.platform,
|
||||
chat_stream.group_info,
|
||||
chat_stream
|
||||
chat_stream.platform, chat_stream.group_info, chat_stream
|
||||
)
|
||||
|
||||
|
||||
# 创建表达器,传入chat_stream参数
|
||||
expressor = DefaultExpressor(chat_stream)
|
||||
|
||||
|
||||
# 设置服务
|
||||
self._services["expressor"] = expressor
|
||||
|
||||
|
||||
# 发送消息
|
||||
response_set = [
|
||||
("text", content),
|
||||
]
|
||||
|
||||
|
||||
# 调用表达器发送消息
|
||||
await expressor.send_response_messages(
|
||||
anchor_message=anchor_message,
|
||||
response_set=response_set,
|
||||
display_message="",
|
||||
)
|
||||
|
||||
|
||||
logger.info(f"{self.log_prefix} 命令回复消息发送成功: {content[:30]}...")
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 发送命令回复消息失败: {e}")
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
|
||||
def register_command(cls):
|
||||
"""
|
||||
命令注册装饰器
|
||||
|
||||
|
||||
用法:
|
||||
@register_command
|
||||
class MyCommand(BaseCommand):
|
||||
@@ -115,21 +114,25 @@ def register_command(cls):
|
||||
...
|
||||
"""
|
||||
# 检查类是否有必要的属性
|
||||
if not hasattr(cls, "command_name") or not hasattr(cls, "command_description") or not hasattr(cls, "command_pattern"):
|
||||
if (
|
||||
not hasattr(cls, "command_name")
|
||||
or not hasattr(cls, "command_description")
|
||||
or not hasattr(cls, "command_pattern")
|
||||
):
|
||||
logger.error(f"命令类 {cls.__name__} 缺少必要的属性: command_name, command_description 或 command_pattern")
|
||||
return cls
|
||||
|
||||
|
||||
command_name = cls.command_name
|
||||
command_pattern = cls.command_pattern
|
||||
is_enabled = getattr(cls, "enable_command", True) # 默认启用命令
|
||||
|
||||
|
||||
if not command_name or not command_pattern:
|
||||
logger.error(f"命令类 {cls.__name__} 的 command_name 或 command_pattern 为空")
|
||||
return cls
|
||||
|
||||
|
||||
# 将命令类注册到全局注册表
|
||||
_COMMAND_REGISTRY[command_name] = cls
|
||||
|
||||
|
||||
# 编译正则表达式并注册
|
||||
try:
|
||||
pattern = re.compile(command_pattern, re.IGNORECASE | re.DOTALL)
|
||||
@@ -137,47 +140,47 @@ def register_command(cls):
|
||||
logger.info(f"已注册命令: {command_name} -> {cls.__name__},命令启用: {is_enabled}")
|
||||
except re.error as e:
|
||||
logger.error(f"命令 {command_name} 的正则表达式编译失败: {e}")
|
||||
|
||||
|
||||
return cls
|
||||
|
||||
|
||||
class CommandManager:
|
||||
"""命令管理器,负责处理命令(不再负责加载,加载由统一的插件加载器处理)"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
"""初始化命令管理器"""
|
||||
# 命令加载现在由统一的插件加载器处理,这里只需要初始化
|
||||
logger.info("命令管理器初始化完成")
|
||||
|
||||
|
||||
async def process_command(self, message: MessageRecv) -> Tuple[bool, Optional[str], bool]:
|
||||
"""处理消息中的命令
|
||||
|
||||
|
||||
Args:
|
||||
message: 接收到的消息对象
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple[bool, Optional[str], bool]: (是否找到并执行了命令, 命令执行结果, 是否继续处理消息)
|
||||
"""
|
||||
if not message.processed_plain_text:
|
||||
await message.process()
|
||||
|
||||
|
||||
text = message.processed_plain_text
|
||||
|
||||
|
||||
# 检查是否匹配任何命令模式
|
||||
for pattern, command_cls in _COMMAND_PATTERNS.items():
|
||||
match = pattern.match(text)
|
||||
if match and getattr(command_cls, "enable_command", True):
|
||||
# 创建命令实例
|
||||
command_instance = command_cls(message)
|
||||
|
||||
|
||||
# 提取命名组并设置
|
||||
groups = match.groupdict()
|
||||
command_instance.set_matched_groups(groups)
|
||||
|
||||
|
||||
try:
|
||||
# 执行命令
|
||||
success, response = await command_instance.execute()
|
||||
|
||||
|
||||
# 记录命令执行结果
|
||||
if success:
|
||||
logger.info(f"命令 {command_cls.command_name} 执行成功")
|
||||
@@ -189,27 +192,28 @@ class CommandManager:
|
||||
if response:
|
||||
# 使用命令实例的send_reply方法发送错误信息
|
||||
await command_instance.send_reply(f"命令执行失败: {response}")
|
||||
|
||||
|
||||
# 命令执行后不再继续处理消息
|
||||
return True, response, False
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"执行命令 {command_cls.command_name} 时出错: {e}")
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
|
||||
try:
|
||||
# 使用命令实例的send_reply方法发送错误信息
|
||||
await command_instance.send_reply(f"命令执行出错: {str(e)}")
|
||||
except Exception as send_error:
|
||||
logger.error(f"发送错误消息失败: {send_error}")
|
||||
|
||||
|
||||
# 命令执行出错后不再继续处理消息
|
||||
return True, str(e), False
|
||||
|
||||
|
||||
# 没有匹配到任何命令,继续处理消息
|
||||
return False, None, True
|
||||
|
||||
|
||||
# 创建全局命令管理器实例
|
||||
command_manager = CommandManager()
|
||||
command_manager = CommandManager()
|
||||
|
||||
@@ -227,8 +227,6 @@ class DefaultExpressor:
|
||||
logger.info(f"想要表达:{in_mind_reply}||理由:{reason}")
|
||||
logger.info(f"最终回复: {content}\n")
|
||||
|
||||
|
||||
|
||||
except Exception as llm_e:
|
||||
# 精简报错信息
|
||||
logger.error(f"{self.log_prefix}LLM 生成失败: {llm_e}")
|
||||
|
||||
@@ -113,25 +113,25 @@ class ExpressionLearner:
|
||||
同时对所有已存储的表达方式进行全局衰减
|
||||
"""
|
||||
current_time = time.time()
|
||||
|
||||
|
||||
# 全局衰减所有已存储的表达方式
|
||||
for type in ["style", "grammar"]:
|
||||
base_dir = os.path.join("data", "expression", f"learnt_{type}")
|
||||
if not os.path.exists(base_dir):
|
||||
continue
|
||||
|
||||
|
||||
for chat_id in os.listdir(base_dir):
|
||||
file_path = os.path.join(base_dir, chat_id, "expressions.json")
|
||||
if not os.path.exists(file_path):
|
||||
continue
|
||||
|
||||
|
||||
try:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
expressions = json.load(f)
|
||||
|
||||
|
||||
# 应用全局衰减
|
||||
decayed_expressions = self.apply_decay_to_expressions(expressions, current_time)
|
||||
|
||||
|
||||
# 保存衰减后的结果
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
json.dump(decayed_expressions, f, ensure_ascii=False, indent=2)
|
||||
@@ -162,23 +162,25 @@ class ExpressionLearner:
|
||||
"""
|
||||
if time_diff_days <= 0 or time_diff_days >= DECAY_DAYS:
|
||||
return 0.001
|
||||
|
||||
|
||||
# 使用二次函数进行插值
|
||||
# 将7天作为顶点,0天和30天作为两个端点
|
||||
# 使用顶点式:y = a(x-h)^2 + k,其中(h,k)为顶点
|
||||
h = 7.0 # 顶点x坐标
|
||||
k = 0.001 # 顶点y坐标
|
||||
|
||||
|
||||
# 计算a值,使得x=0和x=30时y=0.001
|
||||
# 0.001 = a(0-7)^2 + 0.001
|
||||
# 解得a = 0
|
||||
a = 0
|
||||
|
||||
|
||||
# 计算衰减值
|
||||
decay = a * (time_diff_days - h) ** 2 + k
|
||||
return min(0.001, decay)
|
||||
|
||||
def apply_decay_to_expressions(self, expressions: List[Dict[str, Any]], current_time: float) -> List[Dict[str, Any]]:
|
||||
def apply_decay_to_expressions(
|
||||
self, expressions: List[Dict[str, Any]], current_time: float
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
对表达式列表应用衰减
|
||||
返回衰减后的表达式列表,移除count小于0的项
|
||||
@@ -188,16 +190,16 @@ class ExpressionLearner:
|
||||
# 确保last_active_time存在,如果不存在则使用current_time
|
||||
if "last_active_time" not in expr:
|
||||
expr["last_active_time"] = current_time
|
||||
|
||||
|
||||
last_active = expr["last_active_time"]
|
||||
time_diff_days = (current_time - last_active) / (24 * 3600) # 转换为天
|
||||
|
||||
|
||||
decay_value = self.calculate_decay_factor(time_diff_days)
|
||||
expr["count"] = max(0.01, expr.get("count", 1) - decay_value)
|
||||
|
||||
|
||||
if expr["count"] > 0:
|
||||
result.append(expr)
|
||||
|
||||
|
||||
return result
|
||||
|
||||
async def learn_and_store(self, type: str, num: int = 10) -> List[Tuple[str, str, str]]:
|
||||
@@ -211,7 +213,7 @@ class ExpressionLearner:
|
||||
type_str = "句法特点"
|
||||
else:
|
||||
raise ValueError(f"Invalid type: {type}")
|
||||
|
||||
|
||||
res = await self.learn_expression(type, num)
|
||||
|
||||
if res is None:
|
||||
@@ -238,15 +240,15 @@ class ExpressionLearner:
|
||||
if chat_id not in chat_dict:
|
||||
chat_dict[chat_id] = []
|
||||
chat_dict[chat_id].append({"situation": situation, "style": style})
|
||||
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
|
||||
# 存储到/data/expression/对应chat_id/expressions.json
|
||||
for chat_id, expr_list in chat_dict.items():
|
||||
dir_path = os.path.join("data", "expression", f"learnt_{type}", str(chat_id))
|
||||
os.makedirs(dir_path, exist_ok=True)
|
||||
file_path = os.path.join(dir_path, "expressions.json")
|
||||
|
||||
|
||||
# 若已存在,先读出合并
|
||||
old_data: List[Dict[str, Any]] = []
|
||||
if os.path.exists(file_path):
|
||||
@@ -255,10 +257,10 @@ class ExpressionLearner:
|
||||
old_data = json.load(f)
|
||||
except Exception:
|
||||
old_data = []
|
||||
|
||||
|
||||
# 应用衰减
|
||||
# old_data = self.apply_decay_to_expressions(old_data, current_time)
|
||||
|
||||
|
||||
# 合并逻辑
|
||||
for new_expr in expr_list:
|
||||
found = False
|
||||
@@ -278,43 +280,43 @@ class ExpressionLearner:
|
||||
new_expr["count"] = 1
|
||||
new_expr["last_active_time"] = current_time
|
||||
old_data.append(new_expr)
|
||||
|
||||
|
||||
# 处理超限问题
|
||||
if len(old_data) > MAX_EXPRESSION_COUNT:
|
||||
# 计算每个表达方式的权重(count的倒数,这样count越小的越容易被选中)
|
||||
weights = [1 / (expr.get("count", 1) + 0.1) for expr in old_data]
|
||||
|
||||
|
||||
# 随机选择要移除的表达方式,避免重复索引
|
||||
remove_count = len(old_data) - MAX_EXPRESSION_COUNT
|
||||
|
||||
|
||||
# 使用一种不会选到重复索引的方法
|
||||
indices = list(range(len(old_data)))
|
||||
|
||||
|
||||
# 方法1:使用numpy.random.choice
|
||||
# 把列表转成一个映射字典,保证不会有重复
|
||||
remove_set = set()
|
||||
total_attempts = 0
|
||||
|
||||
|
||||
# 尝试按权重随机选择,直到选够数量
|
||||
while len(remove_set) < remove_count and total_attempts < len(old_data) * 2:
|
||||
idx = random.choices(indices, weights=weights, k=1)[0]
|
||||
remove_set.add(idx)
|
||||
total_attempts += 1
|
||||
|
||||
|
||||
# 如果没选够,随机补充
|
||||
if len(remove_set) < remove_count:
|
||||
remaining = set(indices) - remove_set
|
||||
remove_set.update(random.sample(list(remaining), remove_count - len(remove_set)))
|
||||
|
||||
|
||||
remove_indices = list(remove_set)
|
||||
|
||||
|
||||
# 从后往前删除,避免索引变化
|
||||
for idx in sorted(remove_indices, reverse=True):
|
||||
old_data.pop(idx)
|
||||
|
||||
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
json.dump(old_data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
return learnt_expressions
|
||||
|
||||
async def learn_expression(self, type: str, num: int = 10) -> Optional[Tuple[List[Tuple[str, str, str]], str]]:
|
||||
|
||||
@@ -97,7 +97,7 @@ class CycleDetail:
|
||||
)
|
||||
|
||||
# current_time_minute = time.strftime("%Y%m%d_%H%M", time.localtime())
|
||||
|
||||
|
||||
# try:
|
||||
# self.log_cycle_to_file(
|
||||
# log_dir + self.prefix + f"/{current_time_minute}_cycle_" + str(self.cycle_id) + ".json"
|
||||
@@ -117,7 +117,6 @@ class CycleDetail:
|
||||
if dir_name and not os.path.exists(dir_name):
|
||||
os.makedirs(dir_name, exist_ok=True)
|
||||
# 写入文件
|
||||
|
||||
|
||||
file_path = os.path.join(dir_name, os.path.basename(file_path))
|
||||
# print("file_path:", file_path)
|
||||
|
||||
@@ -99,22 +99,23 @@ class HeartFChatting:
|
||||
self.stream_id: str = chat_id # 聊天流ID
|
||||
self.chat_stream = chat_manager.get_stream(self.stream_id)
|
||||
self.log_prefix = f"[{chat_manager.get_stream_name(self.stream_id) or self.stream_id}]"
|
||||
|
||||
|
||||
self.memory_activator = MemoryActivator()
|
||||
|
||||
|
||||
# 初始化观察器
|
||||
self.observations: List[Observation] = []
|
||||
self._register_observations()
|
||||
|
||||
|
||||
# 根据配置文件和默认规则确定启用的处理器
|
||||
config_processor_settings = global_config.focus_chat_processor
|
||||
self.enabled_processor_names = []
|
||||
|
||||
|
||||
for proc_name, (_proc_class, config_key) in PROCESSOR_CLASSES.items():
|
||||
# 对于关系处理器,需要同时检查两个配置项
|
||||
if proc_name == "RelationshipProcessor":
|
||||
if (global_config.relationship.enable_relationship and
|
||||
getattr(config_processor_settings, config_key, True)):
|
||||
if global_config.relationship.enable_relationship and getattr(
|
||||
config_processor_settings, config_key, True
|
||||
):
|
||||
self.enabled_processor_names.append(proc_name)
|
||||
else:
|
||||
# 其他处理器的原有逻辑
|
||||
@@ -122,14 +123,13 @@ class HeartFChatting:
|
||||
self.enabled_processor_names.append(proc_name)
|
||||
|
||||
# logger.info(f"{self.log_prefix} 将启用的处理器: {self.enabled_processor_names}")
|
||||
|
||||
|
||||
self.processors: List[BaseProcessor] = []
|
||||
self._register_default_processors()
|
||||
|
||||
self.expressor = DefaultExpressor(chat_stream=self.chat_stream)
|
||||
self.replyer = DefaultReplyer(chat_stream=self.chat_stream)
|
||||
|
||||
|
||||
|
||||
self.action_manager = ActionManager()
|
||||
self.action_planner = PlannerFactory.create_planner(
|
||||
log_prefix=self.log_prefix, action_manager=self.action_manager
|
||||
@@ -138,7 +138,6 @@ class HeartFChatting:
|
||||
self.action_observation = ActionObservation(observe_id=self.stream_id)
|
||||
self.action_observation.set_action_manager(self.action_manager)
|
||||
|
||||
|
||||
self._processing_lock = asyncio.Lock()
|
||||
|
||||
# 循环控制内部状态
|
||||
@@ -182,7 +181,13 @@ class HeartFChatting:
|
||||
if processor_info:
|
||||
processor_actual_class = processor_info[0] # 获取实际的类定义
|
||||
# 根据处理器类名判断是否需要 subheartflow_id
|
||||
if name in ["MindProcessor", "ToolProcessor", "WorkingMemoryProcessor", "SelfProcessor", "RelationshipProcessor"]:
|
||||
if name in [
|
||||
"MindProcessor",
|
||||
"ToolProcessor",
|
||||
"WorkingMemoryProcessor",
|
||||
"SelfProcessor",
|
||||
"RelationshipProcessor",
|
||||
]:
|
||||
self.processors.append(processor_actual_class(subheartflow_id=self.stream_id))
|
||||
elif name == "ChattingInfoProcessor":
|
||||
self.processors.append(processor_actual_class())
|
||||
@@ -203,9 +208,7 @@ class HeartFChatting:
|
||||
)
|
||||
|
||||
if self.processors:
|
||||
logger.info(
|
||||
f"{self.log_prefix} 已注册处理器: {[p.__class__.__name__ for p in self.processors]}"
|
||||
)
|
||||
logger.info(f"{self.log_prefix} 已注册处理器: {[p.__class__.__name__ for p in self.processors]}")
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} 没有注册任何处理器。这可能是由于配置错误或所有处理器都被禁用了。")
|
||||
|
||||
@@ -292,7 +295,9 @@ class HeartFChatting:
|
||||
self._current_cycle_detail.set_loop_info(loop_info)
|
||||
|
||||
# 从observations列表中获取HFCloopObservation
|
||||
hfcloop_observation = next((obs for obs in self.observations if isinstance(obs, HFCloopObservation)), None)
|
||||
hfcloop_observation = next(
|
||||
(obs for obs in self.observations if isinstance(obs, HFCloopObservation)), None
|
||||
)
|
||||
if hfcloop_observation:
|
||||
hfcloop_observation.add_loop_info(self._current_cycle_detail)
|
||||
else:
|
||||
@@ -451,19 +456,19 @@ class HeartFChatting:
|
||||
|
||||
# 根据配置决定是否并行执行调整动作、回忆和处理器阶段
|
||||
|
||||
# 并行执行调整动作、回忆和处理器阶段
|
||||
# 并行执行调整动作、回忆和处理器阶段
|
||||
with Timer("并行调整动作、处理", cycle_timers):
|
||||
# 创建并行任务
|
||||
async def modify_actions_task():
|
||||
async def modify_actions_task():
|
||||
# 调用完整的动作修改流程
|
||||
await self.action_modifier.modify_actions(
|
||||
observations=self.observations,
|
||||
)
|
||||
|
||||
|
||||
await self.action_observation.observe()
|
||||
self.observations.append(self.action_observation)
|
||||
return True
|
||||
|
||||
|
||||
# 创建三个并行任务
|
||||
action_modify_task = asyncio.create_task(modify_actions_task())
|
||||
memory_task = asyncio.create_task(self.memory_activator.activate_memory(self.observations))
|
||||
@@ -474,9 +479,6 @@ class HeartFChatting:
|
||||
action_modify_task, memory_task, processor_task
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
loop_processor_info = {
|
||||
"all_plan_info": all_plan_info,
|
||||
"processor_time_costs": processor_time_costs,
|
||||
@@ -594,9 +596,7 @@ class HeartFChatting:
|
||||
else:
|
||||
success, reply_text = result
|
||||
command = ""
|
||||
logger.debug(
|
||||
f"{self.log_prefix} 麦麦执行了'{action}', 返回结果'{success}', '{reply_text}', '{command}'"
|
||||
)
|
||||
logger.debug(f"{self.log_prefix} 麦麦执行了'{action}', 返回结果'{success}', '{reply_text}', '{command}'")
|
||||
|
||||
return success, reply_text, command
|
||||
|
||||
|
||||
@@ -51,8 +51,8 @@ async def _process_relationship(message: MessageRecv) -> None:
|
||||
logger.info(f"首次认识用户: {nickname}")
|
||||
await relationship_manager.first_knowing_some_one(platform, user_id, nickname, cardname)
|
||||
# elif not await relationship_manager.is_qved_name(platform, user_id):
|
||||
# logger.info(f"给用户({nickname},{cardname})取名: {nickname}")
|
||||
# await relationship_manager.first_knowing_some_one(platform, user_id, nickname, cardname, "")
|
||||
# logger.info(f"给用户({nickname},{cardname})取名: {nickname}")
|
||||
# await relationship_manager.first_knowing_some_one(platform, user_id, nickname, cardname, "")
|
||||
|
||||
|
||||
async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]:
|
||||
@@ -74,7 +74,7 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]:
|
||||
fast_retrieval=True,
|
||||
)
|
||||
logger.trace(f"记忆激活率: {interested_rate:.2f}")
|
||||
|
||||
|
||||
text_len = len(message.processed_plain_text)
|
||||
# 根据文本长度调整兴趣度,长度越大兴趣度越高,但增长率递减,最低0.01,最高0.05
|
||||
# 采用对数函数实现递减增长
|
||||
@@ -181,7 +181,6 @@ class HeartFCMessageReceiver:
|
||||
userinfo = message.message_info.user_info
|
||||
messageinfo = message.message_info
|
||||
|
||||
|
||||
chat = await chat_manager.get_or_create_stream(
|
||||
platform=messageinfo.platform,
|
||||
user_info=userinfo,
|
||||
|
||||
@@ -11,7 +11,6 @@ from datetime import datetime
|
||||
from typing import Dict
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
import asyncio
|
||||
|
||||
logger = get_logger("processor")
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_ch
|
||||
# 配置常量:是否启用小模型即时信息提取
|
||||
# 开启时:使用小模型并行即时提取,速度更快,但精度可能略低
|
||||
# 关闭时:使用原来的异步模式,精度更高但速度较慢
|
||||
ENABLE_INSTANT_INFO_EXTRACTION = True
|
||||
ENABLE_INSTANT_INFO_EXTRACTION = True
|
||||
|
||||
logger = get_logger("processor")
|
||||
|
||||
@@ -63,7 +63,7 @@ def init_prompt():
|
||||
|
||||
"""
|
||||
Prompt(relationship_prompt, "relationship_prompt")
|
||||
|
||||
|
||||
fetch_info_prompt = """
|
||||
|
||||
{name_block}
|
||||
@@ -84,7 +84,6 @@ def init_prompt():
|
||||
Prompt(fetch_info_prompt, "fetch_info_prompt")
|
||||
|
||||
|
||||
|
||||
class RelationshipProcessor(BaseProcessor):
|
||||
log_prefix = "关系"
|
||||
|
||||
@@ -92,8 +91,10 @@ class RelationshipProcessor(BaseProcessor):
|
||||
super().__init__()
|
||||
|
||||
self.subheartflow_id = subheartflow_id
|
||||
self.info_fetching_cache: List[Dict[str, any]] = []
|
||||
self.info_fetched_cache: Dict[str, Dict[str, any]] = {} # {person_id: {"info": str, "ttl": int, "start_time": float}}
|
||||
self.info_fetching_cache: List[Dict[str, any]] = []
|
||||
self.info_fetched_cache: Dict[
|
||||
str, Dict[str, any]
|
||||
] = {} # {person_id: {"info": str, "ttl": int, "start_time": float}}
|
||||
self.person_engaged_cache: List[Dict[str, any]] = [] # [{person_id: str, start_time: float, rounds: int}]
|
||||
self.grace_period_rounds = 5
|
||||
|
||||
@@ -101,7 +102,7 @@ class RelationshipProcessor(BaseProcessor):
|
||||
model=global_config.model.relation,
|
||||
request_type="focus.relationship",
|
||||
)
|
||||
|
||||
|
||||
# 小模型用于即时信息提取
|
||||
if ENABLE_INSTANT_INFO_EXTRACTION:
|
||||
self.instant_llm_model = LLMRequest(
|
||||
@@ -156,26 +157,27 @@ class RelationshipProcessor(BaseProcessor):
|
||||
for record in list(self.person_engaged_cache):
|
||||
record["rounds"] += 1
|
||||
time_elapsed = current_time - record["start_time"]
|
||||
message_count = len(get_raw_msg_by_timestamp_with_chat(self.subheartflow_id, record["start_time"], current_time))
|
||||
|
||||
message_count = len(
|
||||
get_raw_msg_by_timestamp_with_chat(self.subheartflow_id, record["start_time"], current_time)
|
||||
)
|
||||
|
||||
print(record)
|
||||
|
||||
|
||||
# 根据消息数量和时间设置不同的触发条件
|
||||
should_trigger = (
|
||||
message_count >= 50 or # 50条消息必定满足
|
||||
(message_count >= 35 and time_elapsed >= 300) or # 35条且10分钟
|
||||
(message_count >= 25 and time_elapsed >= 900) or # 25条且30分钟
|
||||
(message_count >= 10 and time_elapsed >= 2000) # 10条且1小时
|
||||
message_count >= 50 # 50条消息必定满足
|
||||
or (message_count >= 35 and time_elapsed >= 300) # 35条且10分钟
|
||||
or (message_count >= 25 and time_elapsed >= 900) # 25条且30分钟
|
||||
or (message_count >= 10 and time_elapsed >= 2000) # 10条且1小时
|
||||
)
|
||||
|
||||
|
||||
if should_trigger:
|
||||
logger.info(f"{self.log_prefix} 用户 {record['person_id']} 满足关系构建条件,开始构建关系。消息数:{message_count},时长:{time_elapsed:.0f}秒")
|
||||
logger.info(
|
||||
f"{self.log_prefix} 用户 {record['person_id']} 满足关系构建条件,开始构建关系。消息数:{message_count},时长:{time_elapsed:.0f}秒"
|
||||
)
|
||||
asyncio.create_task(
|
||||
self.update_impression_on_cache_expiry(
|
||||
record["person_id"],
|
||||
self.subheartflow_id,
|
||||
record["start_time"],
|
||||
current_time
|
||||
record["person_id"], self.subheartflow_id, record["start_time"], current_time
|
||||
)
|
||||
)
|
||||
self.person_engaged_cache.remove(record)
|
||||
@@ -187,20 +189,24 @@ class RelationshipProcessor(BaseProcessor):
|
||||
if self.info_fetched_cache[person_id][info_type]["ttl"] <= 0:
|
||||
# 在删除前查找匹配的info_fetching_cache记录
|
||||
matched_record = None
|
||||
min_time_diff = float('inf')
|
||||
min_time_diff = float("inf")
|
||||
for record in self.info_fetching_cache:
|
||||
if (record["person_id"] == person_id and
|
||||
record["info_type"] == info_type and
|
||||
not record["forget"]):
|
||||
time_diff = abs(record["start_time"] - self.info_fetched_cache[person_id][info_type]["start_time"])
|
||||
if (
|
||||
record["person_id"] == person_id
|
||||
and record["info_type"] == info_type
|
||||
and not record["forget"]
|
||||
):
|
||||
time_diff = abs(
|
||||
record["start_time"] - self.info_fetched_cache[person_id][info_type]["start_time"]
|
||||
)
|
||||
if time_diff < min_time_diff:
|
||||
min_time_diff = time_diff
|
||||
matched_record = record
|
||||
|
||||
|
||||
if matched_record:
|
||||
matched_record["forget"] = True
|
||||
logger.info(f"{self.log_prefix} 用户 {person_id} 的 {info_type} 信息已过期,标记为遗忘。")
|
||||
|
||||
|
||||
del self.info_fetched_cache[person_id][info_type]
|
||||
if not self.info_fetched_cache[person_id]:
|
||||
del self.info_fetched_cache[person_id]
|
||||
@@ -208,7 +214,7 @@ class RelationshipProcessor(BaseProcessor):
|
||||
# 5. 为需要处理的人员准备LLM prompt
|
||||
nickname_str = ",".join(global_config.bot.alias_names)
|
||||
name_block = f"你的名字是{global_config.bot.nickname},你的昵称有{nickname_str},有人也会用这些昵称称呼你。"
|
||||
|
||||
|
||||
info_cache_block = ""
|
||||
if self.info_fetching_cache:
|
||||
for info_fetching in self.info_fetching_cache:
|
||||
@@ -223,7 +229,7 @@ class RelationshipProcessor(BaseProcessor):
|
||||
chat_observe_info=chat_observe_info,
|
||||
info_cache_block=info_cache_block,
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
logger.debug(f"{self.log_prefix} 人物信息prompt: \n{prompt}\n")
|
||||
content, _ = await self.llm_model.generate_response_async(prompt=prompt)
|
||||
@@ -234,45 +240,47 @@ class RelationshipProcessor(BaseProcessor):
|
||||
# 收集即时提取任务
|
||||
instant_tasks = []
|
||||
async_tasks = []
|
||||
|
||||
|
||||
for person_name, info_type in content_json.items():
|
||||
person_id = person_info_manager.get_person_id_by_person_name(person_name)
|
||||
if person_id:
|
||||
self.info_fetching_cache.append({
|
||||
"person_id": person_id,
|
||||
"person_name": person_name,
|
||||
"info_type": info_type,
|
||||
"start_time": time.time(),
|
||||
"forget": False,
|
||||
})
|
||||
self.info_fetching_cache.append(
|
||||
{
|
||||
"person_id": person_id,
|
||||
"person_name": person_name,
|
||||
"info_type": info_type,
|
||||
"start_time": time.time(),
|
||||
"forget": False,
|
||||
}
|
||||
)
|
||||
if len(self.info_fetching_cache) > 20:
|
||||
self.info_fetching_cache.pop(0)
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} 未找到用户 {person_name} 的ID,跳过调取信息。")
|
||||
continue
|
||||
|
||||
|
||||
logger.info(f"{self.log_prefix} 调取用户 {person_name} 的 {info_type} 信息。")
|
||||
|
||||
|
||||
# 检查person_engaged_cache中是否已存在该person_id
|
||||
person_exists = any(record["person_id"] == person_id for record in self.person_engaged_cache)
|
||||
if not person_exists:
|
||||
self.person_engaged_cache.append({
|
||||
"person_id": person_id,
|
||||
"start_time": time.time(),
|
||||
"rounds": 0
|
||||
})
|
||||
|
||||
self.person_engaged_cache.append(
|
||||
{"person_id": person_id, "start_time": time.time(), "rounds": 0}
|
||||
)
|
||||
|
||||
if ENABLE_INSTANT_INFO_EXTRACTION:
|
||||
# 收集即时提取任务
|
||||
instant_tasks.append((person_id, info_type, time.time()))
|
||||
else:
|
||||
# 使用原来的异步模式
|
||||
async_tasks.append(asyncio.create_task(self.fetch_person_info(person_id, [info_type], start_time=time.time())))
|
||||
async_tasks.append(
|
||||
asyncio.create_task(self.fetch_person_info(person_id, [info_type], start_time=time.time()))
|
||||
)
|
||||
|
||||
# 执行即时提取任务
|
||||
if ENABLE_INSTANT_INFO_EXTRACTION and instant_tasks:
|
||||
await self._execute_instant_extraction_batch(instant_tasks)
|
||||
|
||||
|
||||
# 启动异步任务(如果不是即时模式)
|
||||
if async_tasks:
|
||||
# 异步任务不需要等待完成
|
||||
@@ -300,7 +308,7 @@ class RelationshipProcessor(BaseProcessor):
|
||||
person_infos_str += f"你不了解{person_name}有关[{info_type}]的信息,不要胡乱回答,你可以直接说你不知道,或者你忘记了;"
|
||||
if person_infos_str:
|
||||
persons_infos_str += f"你对 {person_name} 的了解:{person_infos_str}\n"
|
||||
|
||||
|
||||
# 处理正在调取但还没有结果的项目(只在非即时提取模式下显示)
|
||||
if not ENABLE_INSTANT_INFO_EXTRACTION:
|
||||
pending_info_dict = {}
|
||||
@@ -312,50 +320,47 @@ class RelationshipProcessor(BaseProcessor):
|
||||
person_id = record["person_id"]
|
||||
person_name = record["person_name"]
|
||||
info_type = record["info_type"]
|
||||
|
||||
|
||||
# 检查是否已经在info_fetched_cache中有结果
|
||||
if (person_id in self.info_fetched_cache and
|
||||
info_type in self.info_fetched_cache[person_id]):
|
||||
if person_id in self.info_fetched_cache and info_type in self.info_fetched_cache[person_id]:
|
||||
continue
|
||||
|
||||
|
||||
# 按人物组织正在调取的信息
|
||||
if person_name not in pending_info_dict:
|
||||
pending_info_dict[person_name] = []
|
||||
pending_info_dict[person_name].append(info_type)
|
||||
|
||||
|
||||
# 添加正在调取的信息到返回字符串
|
||||
for person_name, info_types in pending_info_dict.items():
|
||||
info_types_str = "、".join(info_types)
|
||||
persons_infos_str += f"你正在识图回忆有关 {person_name} 的 {info_types_str} 信息,稍等一下再回答...\n"
|
||||
|
||||
return persons_infos_str
|
||||
|
||||
|
||||
async def _execute_instant_extraction_batch(self, instant_tasks: list):
|
||||
"""
|
||||
批量执行即时提取任务
|
||||
"""
|
||||
if not instant_tasks:
|
||||
return
|
||||
|
||||
|
||||
logger.info(f"{self.log_prefix} [即时提取] 开始批量提取 {len(instant_tasks)} 个信息")
|
||||
|
||||
|
||||
# 创建所有提取任务
|
||||
extraction_tasks = []
|
||||
for person_id, info_type, start_time in instant_tasks:
|
||||
# 检查缓存中是否已存在且未过期的信息
|
||||
if (person_id in self.info_fetched_cache and
|
||||
info_type in self.info_fetched_cache[person_id]):
|
||||
if person_id in self.info_fetched_cache and info_type in self.info_fetched_cache[person_id]:
|
||||
logger.info(f"{self.log_prefix} 用户 {person_id} 的 {info_type} 信息已存在且未过期,跳过调取。")
|
||||
continue
|
||||
|
||||
|
||||
task = asyncio.create_task(self._fetch_single_info_instant(person_id, info_type, start_time))
|
||||
extraction_tasks.append(task)
|
||||
|
||||
|
||||
# 并行执行所有提取任务并等待完成
|
||||
if extraction_tasks:
|
||||
await asyncio.gather(*extraction_tasks, return_exceptions=True)
|
||||
logger.info(f"{self.log_prefix} [即时提取] 批量提取完成")
|
||||
|
||||
|
||||
async def _fetch_single_info_instant(self, person_id: str, info_type: str, start_time: float):
|
||||
"""
|
||||
@@ -363,24 +368,21 @@ class RelationshipProcessor(BaseProcessor):
|
||||
"""
|
||||
nickname_str = ",".join(global_config.bot.alias_names)
|
||||
name_block = f"你的名字是{global_config.bot.nickname},你的昵称有{nickname_str},有人也会用这些昵称称呼你。"
|
||||
|
||||
|
||||
person_name = await person_info_manager.get_value(person_id, "person_name")
|
||||
|
||||
|
||||
person_impression = await person_info_manager.get_value(person_id, "impression")
|
||||
if not person_impression:
|
||||
impression_block = "你对ta没有什么深刻的印象"
|
||||
else:
|
||||
impression_block = f"{person_impression}"
|
||||
|
||||
|
||||
points = await person_info_manager.get_value(person_id, "points")
|
||||
if points:
|
||||
points_text = "\n".join([
|
||||
f"{point[2]}:{point[0]}"
|
||||
for point in points
|
||||
])
|
||||
points_text = "\n".join([f"{point[2]}:{point[0]}" for point in points])
|
||||
else:
|
||||
points_text = "你不记得ta最近发生了什么"
|
||||
|
||||
|
||||
prompt = (await global_prompt_manager.get_prompt_async("fetch_info_prompt")).format(
|
||||
name_block=name_block,
|
||||
info_type=info_type,
|
||||
@@ -393,9 +395,9 @@ class RelationshipProcessor(BaseProcessor):
|
||||
try:
|
||||
# 使用小模型进行即时提取
|
||||
content, _ = await self.instant_llm_model.generate_response_async(prompt=prompt)
|
||||
|
||||
|
||||
logger.info(f"{self.log_prefix} [即时提取] {person_name} 的 {info_type} 结果: {content}")
|
||||
|
||||
|
||||
if content:
|
||||
content_json = json.loads(repair_json(content))
|
||||
if info_type in content_json:
|
||||
@@ -410,7 +412,9 @@ class RelationshipProcessor(BaseProcessor):
|
||||
"person_name": person_name,
|
||||
"unknow": False,
|
||||
}
|
||||
logger.info(f"{self.log_prefix} [即时提取] 成功获取 {person_name} 的 {info_type}: {info_content}")
|
||||
logger.info(
|
||||
f"{self.log_prefix} [即时提取] 成功获取 {person_name} 的 {info_type}: {info_content}"
|
||||
)
|
||||
else:
|
||||
if person_id not in self.info_fetched_cache:
|
||||
self.info_fetched_cache[person_id] = {}
|
||||
@@ -423,59 +427,55 @@ class RelationshipProcessor(BaseProcessor):
|
||||
}
|
||||
logger.info(f"{self.log_prefix} [即时提取] {person_name} 的 {info_type} 信息不明确")
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} [即时提取] 小模型返回空结果,获取 {person_name} 的 {info_type} 信息失败。")
|
||||
logger.warning(
|
||||
f"{self.log_prefix} [即时提取] 小模型返回空结果,获取 {person_name} 的 {info_type} 信息失败。"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} [即时提取] 执行小模型请求获取用户信息时出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
|
||||
async def fetch_person_info(self, person_id: str, info_types: list[str], start_time: float):
|
||||
"""
|
||||
获取某个人的信息
|
||||
"""
|
||||
# 检查缓存中是否已存在且未过期的信息
|
||||
info_types_to_fetch = []
|
||||
|
||||
|
||||
for info_type in info_types:
|
||||
if (person_id in self.info_fetched_cache and
|
||||
info_type in self.info_fetched_cache[person_id]):
|
||||
if person_id in self.info_fetched_cache and info_type in self.info_fetched_cache[person_id]:
|
||||
logger.info(f"{self.log_prefix} 用户 {person_id} 的 {info_type} 信息已存在且未过期,跳过调取。")
|
||||
continue
|
||||
info_types_to_fetch.append(info_type)
|
||||
|
||||
|
||||
if not info_types_to_fetch:
|
||||
return
|
||||
|
||||
|
||||
nickname_str = ",".join(global_config.bot.alias_names)
|
||||
name_block = f"你的名字是{global_config.bot.nickname},你的昵称有{nickname_str},有人也会用这些昵称称呼你。"
|
||||
|
||||
|
||||
person_name = await person_info_manager.get_value(person_id, "person_name")
|
||||
|
||||
|
||||
info_type_str = ""
|
||||
info_json_str = ""
|
||||
for info_type in info_types_to_fetch:
|
||||
info_type_str += f"{info_type},"
|
||||
info_json_str += f"\"{info_type}\": \"信息内容\","
|
||||
info_json_str += f'"{info_type}": "信息内容",'
|
||||
info_type_str = info_type_str[:-1]
|
||||
info_json_str = info_json_str[:-1]
|
||||
|
||||
|
||||
person_impression = await person_info_manager.get_value(person_id, "impression")
|
||||
if not person_impression:
|
||||
impression_block = "你对ta没有什么深刻的印象"
|
||||
else:
|
||||
impression_block = f"{person_impression}"
|
||||
|
||||
|
||||
|
||||
points = await person_info_manager.get_value(person_id, "points")
|
||||
|
||||
if points:
|
||||
points_text = "\n".join([
|
||||
f"{point[2]}:{point[0]}"
|
||||
for point in points
|
||||
])
|
||||
points_text = "\n".join([f"{point[2]}:{point[0]}" for point in points])
|
||||
else:
|
||||
points_text = "你不记得ta最近发生了什么"
|
||||
|
||||
|
||||
|
||||
prompt = (await global_prompt_manager.get_prompt_async("fetch_info_prompt")).format(
|
||||
name_block=name_block,
|
||||
info_type=info_type_str,
|
||||
@@ -487,10 +487,10 @@ class RelationshipProcessor(BaseProcessor):
|
||||
|
||||
try:
|
||||
content, _ = await self.llm_model.generate_response_async(prompt=prompt)
|
||||
|
||||
|
||||
# logger.info(f"{self.log_prefix} fetch_person_info prompt: \n{prompt}\n")
|
||||
logger.info(f"{self.log_prefix} fetch_person_info 结果: {content}")
|
||||
|
||||
|
||||
if content:
|
||||
try:
|
||||
content_json = json.loads(repair_json(content))
|
||||
@@ -508,9 +508,9 @@ class RelationshipProcessor(BaseProcessor):
|
||||
else:
|
||||
if person_id not in self.info_fetched_cache:
|
||||
self.info_fetched_cache[person_id] = {}
|
||||
|
||||
|
||||
self.info_fetched_cache[person_id][info_type] = {
|
||||
"info":"unknow",
|
||||
"info": "unknow",
|
||||
"ttl": 10,
|
||||
"start_time": start_time,
|
||||
"person_name": person_name,
|
||||
@@ -525,16 +525,12 @@ class RelationshipProcessor(BaseProcessor):
|
||||
logger.error(f"{self.log_prefix} 执行LLM请求获取用户信息时出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
async def update_impression_on_cache_expiry(
|
||||
self, person_id: str, chat_id: str, start_time: float, end_time: float
|
||||
):
|
||||
async def update_impression_on_cache_expiry(self, person_id: str, chat_id: str, start_time: float, end_time: float):
|
||||
"""
|
||||
在缓存过期时,获取聊天记录并更新用户印象
|
||||
"""
|
||||
logger.info(f"缓存过期,开始为 {person_id} 更新印象。时间范围:{start_time} -> {end_time}")
|
||||
try:
|
||||
|
||||
|
||||
impression_messages = get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time)
|
||||
if impression_messages:
|
||||
logger.info(f"为 {person_id} 获取到 {len(impression_messages)} 条消息用于印象更新。")
|
||||
|
||||
@@ -122,9 +122,7 @@ class SelfProcessor(BaseProcessor):
|
||||
)
|
||||
# 获取聊天内容
|
||||
chat_observe_info = observation.get_observe_info()
|
||||
person_list = observation.person_list
|
||||
if isinstance(observation, HFCloopObservation):
|
||||
# hfcloop_observe_info = observation.get_observe_info()
|
||||
pass
|
||||
|
||||
nickname_str = ""
|
||||
@@ -133,9 +131,7 @@ class SelfProcessor(BaseProcessor):
|
||||
name_block = f"你的名字是{global_config.bot.nickname},你的昵称有{nickname_str},有人也会用这些昵称称呼你。"
|
||||
|
||||
personality_block = individuality.get_personality_prompt(x_person=2, level=2)
|
||||
|
||||
|
||||
|
||||
|
||||
identity_block = individuality.get_identity_prompt(x_person=2, level=2)
|
||||
|
||||
prompt = (await global_prompt_manager.get_prompt_async("indentify_prompt")).format(
|
||||
|
||||
@@ -118,7 +118,7 @@ class ToolProcessor(BaseProcessor):
|
||||
is_group_chat = observation.is_group_chat
|
||||
|
||||
chat_observe_info = observation.get_observe_info()
|
||||
person_list = observation.person_list
|
||||
# person_list = observation.person_list
|
||||
|
||||
memory_str = ""
|
||||
if running_memorys:
|
||||
@@ -141,9 +141,7 @@ class ToolProcessor(BaseProcessor):
|
||||
|
||||
# 调用LLM,专注于工具使用
|
||||
# logger.info(f"开始执行工具调用{prompt}")
|
||||
response, other_info = await self.llm_model.generate_response_async(
|
||||
prompt=prompt, tools=tools
|
||||
)
|
||||
response, other_info = await self.llm_model.generate_response_async(prompt=prompt, tools=tools)
|
||||
|
||||
if len(other_info) == 3:
|
||||
reasoning_content, model_name, tool_calls = other_info
|
||||
|
||||
@@ -118,9 +118,7 @@ class WorkingMemoryProcessor(BaseProcessor):
|
||||
memory_str=memory_choose_str,
|
||||
)
|
||||
|
||||
|
||||
# print(f"prompt: {prompt}")
|
||||
|
||||
|
||||
# 调用LLM处理记忆
|
||||
content = ""
|
||||
|
||||
@@ -90,7 +90,7 @@ class MemoryActivator:
|
||||
# 如果记忆系统被禁用,直接返回空列表
|
||||
if not global_config.memory.enable_memory:
|
||||
return []
|
||||
|
||||
|
||||
obs_info_text = ""
|
||||
for observation in observations:
|
||||
if isinstance(observation, ChattingObservation):
|
||||
|
||||
@@ -5,9 +5,6 @@ from src.chat.focus_chat.replyer.default_replyer import DefaultReplyer
|
||||
from src.chat.focus_chat.expressors.default_expressor import DefaultExpressor
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.common.logger_manager import get_logger
|
||||
import importlib
|
||||
import pkgutil
|
||||
import os
|
||||
|
||||
# 不再需要导入动作类,因为已经在main.py中导入
|
||||
# import src.chat.actions.default_actions # noqa
|
||||
@@ -41,7 +38,7 @@ class ActionManager:
|
||||
|
||||
# 初始化时将默认动作加载到使用中的动作
|
||||
self._using_actions = self._default_actions.copy()
|
||||
|
||||
|
||||
# 添加系统核心动作
|
||||
self._add_system_core_actions()
|
||||
|
||||
@@ -63,19 +60,19 @@ class ActionManager:
|
||||
action_require: list[str] = getattr(action_class, "action_require", [])
|
||||
associated_types: list[str] = getattr(action_class, "associated_types", [])
|
||||
is_enabled: bool = getattr(action_class, "enable_plugin", True)
|
||||
|
||||
|
||||
# 获取激活类型相关属性
|
||||
focus_activation_type: str = getattr(action_class, "focus_activation_type", "always")
|
||||
normal_activation_type: str = getattr(action_class, "normal_activation_type", "always")
|
||||
|
||||
|
||||
random_probability: float = getattr(action_class, "random_activation_probability", 0.3)
|
||||
llm_judge_prompt: str = getattr(action_class, "llm_judge_prompt", "")
|
||||
activation_keywords: list[str] = getattr(action_class, "activation_keywords", [])
|
||||
keyword_case_sensitive: bool = getattr(action_class, "keyword_case_sensitive", False)
|
||||
|
||||
|
||||
# 获取模式启用属性
|
||||
mode_enable: str = getattr(action_class, "mode_enable", "all")
|
||||
|
||||
|
||||
# 获取并行执行属性
|
||||
parallel_action: bool = getattr(action_class, "parallel_action", False)
|
||||
|
||||
@@ -114,13 +111,13 @@ class ActionManager:
|
||||
def _load_plugin_actions(self) -> None:
|
||||
"""
|
||||
加载所有插件目录中的动作
|
||||
|
||||
|
||||
注意:插件动作的实际导入已经在main.py中完成,这里只需要从_ACTION_REGISTRY获取
|
||||
"""
|
||||
try:
|
||||
# 插件动作已在main.py中加载,这里只需要从_ACTION_REGISTRY获取
|
||||
self._load_registered_actions()
|
||||
logger.info(f"从注册表加载插件动作成功")
|
||||
logger.info("从注册表加载插件动作成功")
|
||||
except Exception as e:
|
||||
logger.error(f"加载插件动作失败: {e}")
|
||||
|
||||
@@ -203,25 +200,25 @@ class ActionManager:
|
||||
def get_using_actions_for_mode(self, mode: str) -> Dict[str, ActionInfo]:
|
||||
"""
|
||||
根据聊天模式获取可用的动作集合
|
||||
|
||||
|
||||
Args:
|
||||
mode: 聊天模式 ("focus", "normal", "all")
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, ActionInfo]: 在指定模式下可用的动作集合
|
||||
"""
|
||||
filtered_actions = {}
|
||||
|
||||
|
||||
for action_name, action_info in self._using_actions.items():
|
||||
action_mode = action_info.get("mode_enable", "all")
|
||||
|
||||
|
||||
# 检查动作是否在当前模式下启用
|
||||
if action_mode == "all" or action_mode == mode:
|
||||
filtered_actions[action_name] = action_info
|
||||
logger.debug(f"动作 {action_name} 在模式 {mode} 下可用 (mode_enable: {action_mode})")
|
||||
else:
|
||||
logger.debug(f"动作 {action_name} 在模式 {mode} 下不可用 (mode_enable: {action_mode})")
|
||||
|
||||
|
||||
logger.debug(f"模式 {mode} 下可用动作: {list(filtered_actions.keys())}")
|
||||
return filtered_actions
|
||||
|
||||
@@ -325,7 +322,7 @@ class ActionManager:
|
||||
系统核心动作是那些enable_plugin为False但是系统必需的动作
|
||||
"""
|
||||
system_core_actions = ["exit_focus_chat"] # 可以根据需要扩展
|
||||
|
||||
|
||||
for action_name in system_core_actions:
|
||||
if action_name in self._registered_actions and action_name not in self._using_actions:
|
||||
self._using_actions[action_name] = self._registered_actions[action_name]
|
||||
@@ -334,10 +331,10 @@ class ActionManager:
|
||||
def add_system_action_if_needed(self, action_name: str) -> bool:
|
||||
"""
|
||||
根据需要添加系统动作到使用集
|
||||
|
||||
|
||||
Args:
|
||||
action_name: 动作名称
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否成功添加
|
||||
"""
|
||||
|
||||
@@ -30,13 +30,13 @@ class ActionModifier:
|
||||
"""初始化动作处理器"""
|
||||
self.action_manager = action_manager
|
||||
self.all_actions = self.action_manager.get_using_actions_for_mode(ChatMode.FOCUS)
|
||||
|
||||
|
||||
# 用于LLM判定的小模型
|
||||
self.llm_judge = LLMRequest(
|
||||
model=global_config.model.utils_small,
|
||||
request_type="action.judge",
|
||||
)
|
||||
|
||||
|
||||
# 缓存相关属性
|
||||
self._llm_judge_cache = {} # 缓存LLM判定结果
|
||||
self._cache_expiry_time = 30 # 缓存过期时间(秒)
|
||||
@@ -49,15 +49,15 @@ class ActionModifier:
|
||||
):
|
||||
"""
|
||||
完整的动作修改流程,整合传统观察处理和新的激活类型判定
|
||||
|
||||
|
||||
这个方法处理完整的动作管理流程:
|
||||
1. 基于观察的传统动作修改(循环历史分析、类型匹配等)
|
||||
2. 基于激活类型的智能动作判定,最终确定可用动作集
|
||||
|
||||
|
||||
处理后,ActionManager 将包含最终的可用动作集,供规划器直接使用
|
||||
"""
|
||||
logger.debug(f"{self.log_prefix}开始完整动作修改流程")
|
||||
|
||||
|
||||
# === 第一阶段:传统观察处理 ===
|
||||
if observations:
|
||||
hfc_obs = None
|
||||
@@ -86,7 +86,7 @@ class ActionModifier:
|
||||
merged_action_changes["add"].extend(action_changes["add"])
|
||||
merged_action_changes["remove"].extend(action_changes["remove"])
|
||||
reasons.append("基于循环历史分析")
|
||||
|
||||
|
||||
# 详细记录循环历史分析的变更原因
|
||||
for action_name in action_changes["add"]:
|
||||
logger.info(f"{self.log_prefix}添加动作: {action_name},原因: 循环历史分析建议添加")
|
||||
@@ -106,7 +106,9 @@ class ActionModifier:
|
||||
if not chat_context.check_types(data["associated_types"]):
|
||||
type_mismatched_actions.append(action_name)
|
||||
associated_types_str = ", ".join(data["associated_types"])
|
||||
logger.info(f"{self.log_prefix}移除动作: {action_name},原因: 关联类型不匹配(需要: {associated_types_str})")
|
||||
logger.info(
|
||||
f"{self.log_prefix}移除动作: {action_name},原因: 关联类型不匹配(需要: {associated_types_str})"
|
||||
)
|
||||
|
||||
if type_mismatched_actions:
|
||||
# 合并到移除列表中
|
||||
@@ -123,17 +125,19 @@ class ActionModifier:
|
||||
self.action_manager.remove_action_from_using(action_name)
|
||||
logger.debug(f"{self.log_prefix}应用移除动作: {action_name},原因集合: {reasons}")
|
||||
|
||||
logger.info(f"{self.log_prefix}传统动作修改完成,当前使用动作: {list(self.action_manager.get_using_actions().keys())}")
|
||||
logger.info(
|
||||
f"{self.log_prefix}传统动作修改完成,当前使用动作: {list(self.action_manager.get_using_actions().keys())}"
|
||||
)
|
||||
|
||||
# === 第二阶段:激活类型判定 ===
|
||||
# 如果提供了聊天上下文,则进行激活类型判定
|
||||
if chat_content is not None:
|
||||
logger.debug(f"{self.log_prefix}开始激活类型判定阶段")
|
||||
|
||||
|
||||
# 获取当前使用的动作集(经过第一阶段处理,且适用于FOCUS模式)
|
||||
current_using_actions = self.action_manager.get_using_actions()
|
||||
all_registered_actions = self.action_manager.get_using_actions_for_mode(ChatMode.FOCUS)
|
||||
|
||||
|
||||
# 构建完整的动作信息
|
||||
current_actions_with_info = {}
|
||||
for action_name in current_using_actions.keys():
|
||||
@@ -141,17 +145,17 @@ class ActionModifier:
|
||||
current_actions_with_info[action_name] = all_registered_actions[action_name]
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix}使用中的动作 {action_name} 未在已注册动作中找到")
|
||||
|
||||
|
||||
# 应用激活类型判定
|
||||
final_activated_actions = await self._apply_activation_type_filtering(
|
||||
current_actions_with_info,
|
||||
chat_content,
|
||||
)
|
||||
|
||||
|
||||
# 更新ActionManager,移除未激活的动作
|
||||
actions_to_remove = []
|
||||
removal_reasons = {}
|
||||
|
||||
|
||||
for action_name in current_using_actions.keys():
|
||||
if action_name not in final_activated_actions:
|
||||
actions_to_remove.append(action_name)
|
||||
@@ -159,7 +163,7 @@ class ActionModifier:
|
||||
if action_name in all_registered_actions:
|
||||
action_info = all_registered_actions[action_name]
|
||||
activation_type = action_info.get("focus_activation_type", ActionActivationType.ALWAYS)
|
||||
|
||||
|
||||
if activation_type == ActionActivationType.RANDOM:
|
||||
probability = action_info.get("random_probability", 0.3)
|
||||
removal_reasons[action_name] = f"RANDOM类型未触发(概率{probability})"
|
||||
@@ -172,15 +176,17 @@ class ActionModifier:
|
||||
removal_reasons[action_name] = "激活判定未通过"
|
||||
else:
|
||||
removal_reasons[action_name] = "动作信息不完整"
|
||||
|
||||
|
||||
for action_name in actions_to_remove:
|
||||
self.action_manager.remove_action_from_using(action_name)
|
||||
reason = removal_reasons.get(action_name, "未知原因")
|
||||
logger.info(f"{self.log_prefix}移除动作: {action_name},原因: {reason}")
|
||||
|
||||
|
||||
logger.info(f"{self.log_prefix}激活类型判定完成,最终可用动作: {list(final_activated_actions.keys())}")
|
||||
|
||||
logger.info(f"{self.log_prefix}完整动作修改流程结束,最终动作集: {list(self.action_manager.get_using_actions().keys())}")
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix}完整动作修改流程结束,最终动作集: {list(self.action_manager.get_using_actions().keys())}"
|
||||
)
|
||||
|
||||
async def _apply_activation_type_filtering(
|
||||
self,
|
||||
@@ -189,27 +195,27 @@ class ActionModifier:
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
应用激活类型过滤逻辑,支持四种激活类型的并行处理
|
||||
|
||||
|
||||
Args:
|
||||
actions_with_info: 带完整信息的动作字典
|
||||
observed_messages_str: 观察到的聊天消息
|
||||
chat_context: 聊天上下文信息
|
||||
extra_context: 额外的上下文信息
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 过滤后激活的actions字典
|
||||
"""
|
||||
activated_actions = {}
|
||||
|
||||
|
||||
# 分类处理不同激活类型的actions
|
||||
always_actions = {}
|
||||
random_actions = {}
|
||||
llm_judge_actions = {}
|
||||
keyword_actions = {}
|
||||
|
||||
|
||||
for action_name, action_info in actions_with_info.items():
|
||||
activation_type = action_info.get("focus_activation_type", ActionActivationType.ALWAYS)
|
||||
|
||||
|
||||
if activation_type == ActionActivationType.ALWAYS:
|
||||
always_actions[action_name] = action_info
|
||||
elif activation_type == ActionActivationType.RANDOM:
|
||||
@@ -220,12 +226,12 @@ class ActionModifier:
|
||||
keyword_actions[action_name] = action_info
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix}未知的激活类型: {activation_type},跳过处理")
|
||||
|
||||
|
||||
# 1. 处理ALWAYS类型(直接激活)
|
||||
for action_name, action_info in always_actions.items():
|
||||
activated_actions[action_name] = action_info
|
||||
logger.debug(f"{self.log_prefix}激活动作: {action_name},原因: ALWAYS类型直接激活")
|
||||
|
||||
|
||||
# 2. 处理RANDOM类型
|
||||
for action_name, action_info in random_actions.items():
|
||||
probability = action_info.get("random_probability", 0.3)
|
||||
@@ -235,7 +241,7 @@ class ActionModifier:
|
||||
logger.debug(f"{self.log_prefix}激活动作: {action_name},原因: RANDOM类型触发(概率{probability})")
|
||||
else:
|
||||
logger.debug(f"{self.log_prefix}未激活动作: {action_name},原因: RANDOM类型未触发(概率{probability})")
|
||||
|
||||
|
||||
# 3. 处理KEYWORD类型(快速判定)
|
||||
for action_name, action_info in keyword_actions.items():
|
||||
should_activate = self._check_keyword_activation(
|
||||
@@ -250,7 +256,7 @@ class ActionModifier:
|
||||
else:
|
||||
keywords = action_info.get("activation_keywords", [])
|
||||
logger.debug(f"{self.log_prefix}未激活动作: {action_name},原因: KEYWORD类型未匹配关键词({keywords})")
|
||||
|
||||
|
||||
# 4. 处理LLM_JUDGE类型(并行判定)
|
||||
if llm_judge_actions:
|
||||
# 直接并行处理所有LLM判定actions
|
||||
@@ -258,7 +264,7 @@ class ActionModifier:
|
||||
llm_judge_actions,
|
||||
chat_content,
|
||||
)
|
||||
|
||||
|
||||
# 添加激活的LLM判定actions
|
||||
for action_name, should_activate in llm_results.items():
|
||||
if should_activate:
|
||||
@@ -266,46 +272,43 @@ class ActionModifier:
|
||||
logger.debug(f"{self.log_prefix}激活动作: {action_name},原因: LLM_JUDGE类型判定通过")
|
||||
else:
|
||||
logger.debug(f"{self.log_prefix}未激活动作: {action_name},原因: LLM_JUDGE类型判定未通过")
|
||||
|
||||
|
||||
logger.debug(f"{self.log_prefix}激活类型过滤完成: {list(activated_actions.keys())}")
|
||||
return activated_actions
|
||||
|
||||
async def process_actions_for_planner(
|
||||
self,
|
||||
observed_messages_str: str = "",
|
||||
chat_context: Optional[str] = None,
|
||||
extra_context: Optional[str] = None
|
||||
self, observed_messages_str: str = "", chat_context: Optional[str] = None, extra_context: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
[已废弃] 此方法现在已被整合到 modify_actions() 中
|
||||
|
||||
|
||||
为了保持向后兼容性而保留,但建议直接使用 ActionManager.get_using_actions()
|
||||
规划器应该直接从 ActionManager 获取最终的可用动作集,而不是调用此方法
|
||||
|
||||
|
||||
新的架构:
|
||||
1. 主循环调用 modify_actions() 处理完整的动作管理流程
|
||||
2. 规划器直接使用 ActionManager.get_using_actions() 获取最终动作集
|
||||
"""
|
||||
logger.warning(f"{self.log_prefix}process_actions_for_planner() 已废弃,建议规划器直接使用 ActionManager.get_using_actions()")
|
||||
|
||||
logger.warning(
|
||||
f"{self.log_prefix}process_actions_for_planner() 已废弃,建议规划器直接使用 ActionManager.get_using_actions()"
|
||||
)
|
||||
|
||||
# 为了向后兼容,仍然返回当前使用的动作集
|
||||
current_using_actions = self.action_manager.get_using_actions()
|
||||
all_registered_actions = self.action_manager.get_registered_actions()
|
||||
|
||||
|
||||
# 构建完整的动作信息
|
||||
result = {}
|
||||
for action_name in current_using_actions.keys():
|
||||
if action_name in all_registered_actions:
|
||||
result[action_name] = all_registered_actions[action_name]
|
||||
|
||||
|
||||
return result
|
||||
|
||||
def _generate_context_hash(self, chat_content: str) -> str:
|
||||
"""生成上下文的哈希值用于缓存"""
|
||||
context_content = f"{chat_content}"
|
||||
return hashlib.md5(context_content.encode('utf-8')).hexdigest()
|
||||
|
||||
|
||||
return hashlib.md5(context_content.encode("utf-8")).hexdigest()
|
||||
|
||||
async def _process_llm_judge_actions_parallel(
|
||||
self,
|
||||
@@ -314,85 +317,85 @@ class ActionModifier:
|
||||
) -> Dict[str, bool]:
|
||||
"""
|
||||
并行处理LLM判定actions,支持智能缓存
|
||||
|
||||
|
||||
Args:
|
||||
llm_judge_actions: 需要LLM判定的actions
|
||||
observed_messages_str: 观察到的聊天消息
|
||||
chat_context: 聊天上下文
|
||||
extra_context: 额外上下文
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, bool]: action名称到激活结果的映射
|
||||
"""
|
||||
|
||||
|
||||
# 生成当前上下文的哈希值
|
||||
current_context_hash = self._generate_context_hash(chat_content)
|
||||
current_time = time.time()
|
||||
|
||||
|
||||
results = {}
|
||||
tasks_to_run = {}
|
||||
|
||||
|
||||
# 检查缓存
|
||||
for action_name, action_info in llm_judge_actions.items():
|
||||
cache_key = f"{action_name}_{current_context_hash}"
|
||||
|
||||
|
||||
# 检查是否有有效的缓存
|
||||
if (cache_key in self._llm_judge_cache and
|
||||
current_time - self._llm_judge_cache[cache_key]["timestamp"] < self._cache_expiry_time):
|
||||
|
||||
if (
|
||||
cache_key in self._llm_judge_cache
|
||||
and current_time - self._llm_judge_cache[cache_key]["timestamp"] < self._cache_expiry_time
|
||||
):
|
||||
results[action_name] = self._llm_judge_cache[cache_key]["result"]
|
||||
logger.debug(f"{self.log_prefix}使用缓存结果 {action_name}: {'激活' if results[action_name] else '未激活'}")
|
||||
logger.debug(
|
||||
f"{self.log_prefix}使用缓存结果 {action_name}: {'激活' if results[action_name] else '未激活'}"
|
||||
)
|
||||
else:
|
||||
# 需要进行LLM判定
|
||||
tasks_to_run[action_name] = action_info
|
||||
|
||||
|
||||
# 如果有需要运行的任务,并行执行
|
||||
if tasks_to_run:
|
||||
logger.debug(f"{self.log_prefix}并行执行LLM判定,任务数: {len(tasks_to_run)}")
|
||||
|
||||
|
||||
# 创建并行任务
|
||||
tasks = []
|
||||
task_names = []
|
||||
|
||||
|
||||
for action_name, action_info in tasks_to_run.items():
|
||||
task = self._llm_judge_action(
|
||||
action_name,
|
||||
action_info,
|
||||
chat_content,
|
||||
action_name,
|
||||
action_info,
|
||||
chat_content,
|
||||
)
|
||||
tasks.append(task)
|
||||
task_names.append(action_name)
|
||||
|
||||
|
||||
# 并行执行所有任务
|
||||
try:
|
||||
task_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
|
||||
# 处理结果并更新缓存
|
||||
for i, (action_name, result) in enumerate(zip(task_names, task_results)):
|
||||
for _, (action_name, result) in enumerate(zip(task_names, task_results)):
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"{self.log_prefix}LLM判定action {action_name} 时出错: {result}")
|
||||
results[action_name] = False
|
||||
else:
|
||||
results[action_name] = result
|
||||
|
||||
|
||||
# 更新缓存
|
||||
cache_key = f"{action_name}_{current_context_hash}"
|
||||
self._llm_judge_cache[cache_key] = {
|
||||
"result": result,
|
||||
"timestamp": current_time
|
||||
}
|
||||
|
||||
self._llm_judge_cache[cache_key] = {"result": result, "timestamp": current_time}
|
||||
|
||||
logger.debug(f"{self.log_prefix}并行LLM判定完成,耗时: {time.time() - current_time:.2f}s")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}并行LLM判定失败: {e}")
|
||||
# 如果并行执行失败,为所有任务返回False
|
||||
for action_name in tasks_to_run.keys():
|
||||
results[action_name] = False
|
||||
|
||||
|
||||
# 清理过期缓存
|
||||
self._cleanup_expired_cache(current_time)
|
||||
|
||||
|
||||
return results
|
||||
|
||||
def _cleanup_expired_cache(self, current_time: float):
|
||||
@@ -401,40 +404,39 @@ class ActionModifier:
|
||||
for cache_key, cache_data in self._llm_judge_cache.items():
|
||||
if current_time - cache_data["timestamp"] > self._cache_expiry_time:
|
||||
expired_keys.append(cache_key)
|
||||
|
||||
|
||||
for key in expired_keys:
|
||||
del self._llm_judge_cache[key]
|
||||
|
||||
|
||||
if expired_keys:
|
||||
logger.debug(f"{self.log_prefix}清理了 {len(expired_keys)} 个过期缓存条目")
|
||||
|
||||
async def _llm_judge_action(
|
||||
self,
|
||||
action_name: str,
|
||||
self,
|
||||
action_name: str,
|
||||
action_info: Dict[str, Any],
|
||||
chat_content: str = "",
|
||||
) -> bool:
|
||||
"""
|
||||
使用LLM判定是否应该激活某个action
|
||||
|
||||
|
||||
Args:
|
||||
action_name: 动作名称
|
||||
action_info: 动作信息
|
||||
observed_messages_str: 观察到的聊天消息
|
||||
chat_context: 聊天上下文
|
||||
extra_context: 额外上下文
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否应该激活此action
|
||||
"""
|
||||
|
||||
|
||||
try:
|
||||
# 构建判定提示词
|
||||
action_description = action_info.get("description", "")
|
||||
action_require = action_info.get("require", [])
|
||||
custom_prompt = action_info.get("llm_judge_prompt", "")
|
||||
|
||||
|
||||
|
||||
# 构建基础判定提示词
|
||||
base_prompt = f"""
|
||||
你需要判断在当前聊天情况下,是否应该激活名为"{action_name}"的动作。
|
||||
@@ -445,34 +447,34 @@ class ActionModifier:
|
||||
"""
|
||||
for req in action_require:
|
||||
base_prompt += f"- {req}\n"
|
||||
|
||||
|
||||
if custom_prompt:
|
||||
base_prompt += f"\n额外判定条件:\n{custom_prompt}\n"
|
||||
|
||||
|
||||
if chat_content:
|
||||
base_prompt += f"\n当前聊天记录:\n{chat_content}\n"
|
||||
|
||||
|
||||
|
||||
base_prompt += """
|
||||
请根据以上信息判断是否应该激活这个动作。
|
||||
只需要回答"是"或"否",不要有其他内容。
|
||||
"""
|
||||
|
||||
|
||||
# 调用LLM进行判定
|
||||
response, _ = await self.llm_judge.generate_response_async(prompt=base_prompt)
|
||||
|
||||
|
||||
# 解析响应
|
||||
response = response.strip().lower()
|
||||
|
||||
|
||||
# print(base_prompt)
|
||||
print(f"LLM判定动作 {action_name}:响应='{response}'")
|
||||
|
||||
|
||||
|
||||
should_activate = "是" in response or "yes" in response or "true" in response
|
||||
|
||||
logger.debug(f"{self.log_prefix}LLM判定动作 {action_name}:响应='{response}',结果={'激活' if should_activate else '不激活'}")
|
||||
|
||||
logger.debug(
|
||||
f"{self.log_prefix}LLM判定动作 {action_name}:响应='{response}',结果={'激活' if should_activate else '不激活'}"
|
||||
)
|
||||
return should_activate
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}LLM判定动作 {action_name} 时出错: {e}")
|
||||
# 出错时默认不激活
|
||||
@@ -486,45 +488,45 @@ class ActionModifier:
|
||||
) -> bool:
|
||||
"""
|
||||
检查是否匹配关键词触发条件
|
||||
|
||||
|
||||
Args:
|
||||
action_name: 动作名称
|
||||
action_info: 动作信息
|
||||
observed_messages_str: 观察到的聊天消息
|
||||
chat_context: 聊天上下文
|
||||
extra_context: 额外上下文
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否应该激活此action
|
||||
"""
|
||||
|
||||
|
||||
activation_keywords = action_info.get("activation_keywords", [])
|
||||
case_sensitive = action_info.get("keyword_case_sensitive", False)
|
||||
|
||||
|
||||
if not activation_keywords:
|
||||
logger.warning(f"{self.log_prefix}动作 {action_name} 设置为关键词触发但未配置关键词")
|
||||
return False
|
||||
|
||||
|
||||
# 构建检索文本
|
||||
search_text = ""
|
||||
if chat_content:
|
||||
search_text += chat_content
|
||||
# if chat_context:
|
||||
# search_text += f" {chat_context}"
|
||||
# search_text += f" {chat_context}"
|
||||
# if extra_context:
|
||||
# search_text += f" {extra_context}"
|
||||
|
||||
# search_text += f" {extra_context}"
|
||||
|
||||
# 如果不区分大小写,转换为小写
|
||||
if not case_sensitive:
|
||||
search_text = search_text.lower()
|
||||
|
||||
|
||||
# 检查每个关键词
|
||||
matched_keywords = []
|
||||
for keyword in activation_keywords:
|
||||
check_keyword = keyword if case_sensitive else keyword.lower()
|
||||
if check_keyword in search_text:
|
||||
matched_keywords.append(keyword)
|
||||
|
||||
|
||||
if matched_keywords:
|
||||
logger.debug(f"{self.log_prefix}动作 {action_name} 匹配到关键词: {matched_keywords}")
|
||||
return True
|
||||
@@ -568,7 +570,9 @@ class ActionModifier:
|
||||
result["remove"].append("no_reply")
|
||||
result["remove"].append("reply")
|
||||
no_reply_ratio = no_reply_count / len(recent_cycles)
|
||||
logger.info(f"{self.log_prefix}检测到高no_reply比例: {no_reply_ratio:.2f},达到退出聊天阈值,将添加exit_focus_chat并移除no_reply/reply动作")
|
||||
logger.info(
|
||||
f"{self.log_prefix}检测到高no_reply比例: {no_reply_ratio:.2f},达到退出聊天阈值,将添加exit_focus_chat并移除no_reply/reply动作"
|
||||
)
|
||||
|
||||
# 计算连续回复的相关阈值
|
||||
|
||||
@@ -593,7 +597,7 @@ class ActionModifier:
|
||||
if len(last_max_reply_num) >= max_reply_num and all(last_max_reply_num):
|
||||
# 如果最近max_reply_num次都是reply,直接移除
|
||||
result["remove"].append("reply")
|
||||
reply_count = len(last_max_reply_num) - no_reply_count
|
||||
# reply_count = len(last_max_reply_num) - no_reply_count
|
||||
logger.info(
|
||||
f"{self.log_prefix}移除reply动作,原因: 连续回复过多(最近{len(last_max_reply_num)}次全是reply,超过阈值{max_reply_num})"
|
||||
)
|
||||
@@ -622,8 +626,6 @@ class ActionModifier:
|
||||
f"{self.log_prefix}连续回复检测:最近{one_thres_reply_num}次全是reply,{removal_probability:.2f}概率移除,未触发"
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
f"{self.log_prefix}连续回复检测:无需移除reply动作,最近回复模式正常"
|
||||
)
|
||||
logger.debug(f"{self.log_prefix}连续回复检测:无需移除reply动作,最近回复模式正常")
|
||||
|
||||
return result
|
||||
|
||||
@@ -146,7 +146,7 @@ class ActionPlanner(BasePlanner):
|
||||
# 注意:动作的激活判定现在在主循环的modify_actions中完成
|
||||
# 使用Focus模式过滤动作
|
||||
current_available_actions_dict = self.action_manager.get_using_actions_for_mode(ChatMode.FOCUS)
|
||||
|
||||
|
||||
# 获取完整的动作信息
|
||||
all_registered_actions = self.action_manager.get_registered_actions()
|
||||
current_available_actions = {}
|
||||
@@ -192,12 +192,11 @@ class ActionPlanner(BasePlanner):
|
||||
try:
|
||||
prompt = f"{prompt}"
|
||||
llm_content, (reasoning_content, _) = await self.planner_llm.generate_response_async(prompt=prompt)
|
||||
|
||||
|
||||
logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
|
||||
logger.info(f"{self.log_prefix}规划器原始响应: {llm_content}")
|
||||
logger.info(f"{self.log_prefix}规划器推理: {reasoning_content}")
|
||||
|
||||
|
||||
|
||||
except Exception as req_e:
|
||||
logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}")
|
||||
reasoning = f"LLM 请求失败,你的模型出现问题: {req_e}"
|
||||
@@ -237,10 +236,10 @@ class ActionPlanner(BasePlanner):
|
||||
extra_info_block = ""
|
||||
|
||||
action_data["extra_info_block"] = extra_info_block
|
||||
|
||||
|
||||
if relation_info:
|
||||
action_data["relation_info_block"] = relation_info
|
||||
|
||||
|
||||
# 对于reply动作不需要额外处理,因为相关字段已经在上面的循环中添加到action_data
|
||||
|
||||
if extracted_action not in current_available_actions:
|
||||
@@ -303,12 +302,11 @@ class ActionPlanner(BasePlanner):
|
||||
) -> str:
|
||||
"""构建 Planner LLM 的提示词 (获取模板并填充数据)"""
|
||||
try:
|
||||
|
||||
if relation_info_block:
|
||||
relation_info_block = f"以下是你和别人的关系描述:\n{relation_info_block}"
|
||||
else:
|
||||
relation_info_block = ""
|
||||
|
||||
|
||||
memory_str = ""
|
||||
if running_memorys:
|
||||
memory_str = "以下是当前在聊天中,你回忆起的记忆:\n"
|
||||
@@ -331,9 +329,9 @@ class ActionPlanner(BasePlanner):
|
||||
|
||||
# mind_info_block = ""
|
||||
# if current_mind:
|
||||
# mind_info_block = f"对聊天的规划:{current_mind}"
|
||||
# mind_info_block = f"对聊天的规划:{current_mind}"
|
||||
# else:
|
||||
# mind_info_block = "你刚参与聊天"
|
||||
# mind_info_block = "你刚参与聊天"
|
||||
|
||||
personality_block = individuality.get_prompt(x_person=2, level=2)
|
||||
|
||||
@@ -351,16 +349,14 @@ class ActionPlanner(BasePlanner):
|
||||
param_text = "\n"
|
||||
for param_name, param_description in using_actions_info["parameters"].items():
|
||||
param_text += f' "{param_name}":"{param_description}"\n'
|
||||
param_text = param_text.rstrip('\n')
|
||||
param_text = param_text.rstrip("\n")
|
||||
else:
|
||||
param_text = ""
|
||||
|
||||
|
||||
require_text = ""
|
||||
for require_item in using_actions_info["require"]:
|
||||
require_text += f"- {require_item}\n"
|
||||
require_text = require_text.rstrip('\n')
|
||||
|
||||
require_text = require_text.rstrip("\n")
|
||||
|
||||
using_action_prompt = using_action_prompt.format(
|
||||
action_name=using_actions_name,
|
||||
|
||||
@@ -93,7 +93,7 @@ class DefaultReplyer:
|
||||
|
||||
self.chat_id = chat_stream.stream_id
|
||||
self.chat_stream = chat_stream
|
||||
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_id)
|
||||
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_id)
|
||||
|
||||
async def _create_thinking_message(self, anchor_message: Optional[MessageRecv], thinking_id: str):
|
||||
"""创建思考消息 (尝试锚定到 anchor_message)"""
|
||||
@@ -141,7 +141,7 @@ class DefaultReplyer:
|
||||
# text_part = action_data.get("text", [])
|
||||
# if text_part:
|
||||
sent_msg_list = []
|
||||
|
||||
|
||||
with Timer("生成回复", cycle_timers):
|
||||
# 可以保留原有的文本处理逻辑或进行适当调整
|
||||
reply = await self.reply(
|
||||
@@ -240,22 +240,21 @@ class DefaultReplyer:
|
||||
# current_temp = float(global_config.model.normal["temp"]) * arousal_multiplier
|
||||
# self.express_model.params["temperature"] = current_temp # 动态调整温度
|
||||
|
||||
|
||||
reply_to = action_data.get("reply_to", "none")
|
||||
|
||||
|
||||
sender = ""
|
||||
targer = ""
|
||||
if ":" in reply_to or ":" in reply_to:
|
||||
# 使用正则表达式匹配中文或英文冒号
|
||||
parts = re.split(pattern=r'[::]', string=reply_to, maxsplit=1)
|
||||
parts = re.split(pattern=r"[::]", string=reply_to, maxsplit=1)
|
||||
if len(parts) == 2:
|
||||
sender = parts[0].strip()
|
||||
targer = parts[1].strip()
|
||||
|
||||
|
||||
identity = action_data.get("identity", "")
|
||||
extra_info_block = action_data.get("extra_info_block", "")
|
||||
relation_info_block = action_data.get("relation_info_block", "")
|
||||
|
||||
|
||||
# 3. 构建 Prompt
|
||||
with Timer("构建Prompt", {}): # 内部计时器,可选保留
|
||||
prompt = await self.build_prompt_focus(
|
||||
@@ -374,8 +373,6 @@ class DefaultReplyer:
|
||||
|
||||
style_habbits_str = "\n".join(style_habbits)
|
||||
grammar_habbits_str = "\n".join(grammar_habbits)
|
||||
|
||||
|
||||
|
||||
# 关键词检测与反应
|
||||
keywords_reaction_prompt = ""
|
||||
@@ -407,16 +404,15 @@ class DefaultReplyer:
|
||||
time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
|
||||
# logger.debug("开始构建 focus prompt")
|
||||
|
||||
|
||||
if sender_name:
|
||||
reply_target_block = f"现在{sender_name}说的:{target_message}。引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
reply_target_block = (
|
||||
f"现在{sender_name}说的:{target_message}。引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
)
|
||||
elif target_message:
|
||||
reply_target_block = f"现在{target_message}引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
else:
|
||||
reply_target_block = "现在,你想要在群里发言或者回复消息。"
|
||||
|
||||
|
||||
|
||||
|
||||
# --- Choose template based on chat type ---
|
||||
if is_group_chat:
|
||||
@@ -665,30 +661,30 @@ def find_similar_expressions(input_text: str, expressions: List[Dict], top_k: in
|
||||
"""使用TF-IDF和余弦相似度找出与输入文本最相似的top_k个表达方式"""
|
||||
if not expressions:
|
||||
return []
|
||||
|
||||
|
||||
# 准备文本数据
|
||||
texts = [expr['situation'] for expr in expressions]
|
||||
texts = [expr["situation"] for expr in expressions]
|
||||
texts.append(input_text) # 添加输入文本
|
||||
|
||||
|
||||
# 使用TF-IDF向量化
|
||||
vectorizer = TfidfVectorizer()
|
||||
tfidf_matrix = vectorizer.fit_transform(texts)
|
||||
|
||||
|
||||
# 计算余弦相似度
|
||||
similarity_matrix = cosine_similarity(tfidf_matrix)
|
||||
|
||||
|
||||
# 获取输入文本的相似度分数(最后一行)
|
||||
scores = similarity_matrix[-1][:-1] # 排除与自身的相似度
|
||||
|
||||
|
||||
# 获取top_k的索引
|
||||
top_indices = np.argsort(scores)[::-1][:top_k]
|
||||
|
||||
|
||||
# 获取相似表达
|
||||
similar_exprs = []
|
||||
for idx in top_indices:
|
||||
if scores[idx] > 0: # 只保留有相似度的
|
||||
similar_exprs.append(expressions[idx])
|
||||
|
||||
|
||||
return similar_exprs
|
||||
|
||||
|
||||
|
||||
@@ -62,13 +62,12 @@ class ChattingObservation(Observation):
|
||||
self.oldest_messages = []
|
||||
self.oldest_messages_str = ""
|
||||
self.compressor_prompt = ""
|
||||
|
||||
|
||||
initial_messages = get_raw_msg_before_timestamp_with_chat(self.chat_id, self.last_observe_time, 10)
|
||||
self.last_observe_time = initial_messages[-1]["time"] if initial_messages else self.last_observe_time
|
||||
self.talking_message = initial_messages
|
||||
self.talking_message_str = build_readable_messages(self.talking_message, show_actions=True)
|
||||
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""将观察对象转换为可序列化的字典"""
|
||||
return {
|
||||
@@ -283,7 +282,7 @@ class ChattingObservation(Observation):
|
||||
show_actions=True,
|
||||
)
|
||||
# print(f"构建中:self.talking_message_str_truncate: {self.talking_message_str_truncate}")
|
||||
|
||||
|
||||
self.person_list = await get_person_id_list(self.talking_message)
|
||||
|
||||
# print(f"构建中:self.person_list: {self.person_list}")
|
||||
|
||||
@@ -42,9 +42,7 @@ class SubHeartflow:
|
||||
self.history_chat_state: List[Tuple[ChatState, float]] = []
|
||||
|
||||
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_id)
|
||||
self.log_prefix = (
|
||||
chat_manager.get_stream_name(self.subheartflow_id) or self.subheartflow_id
|
||||
)
|
||||
self.log_prefix = chat_manager.get_stream_name(self.subheartflow_id) or self.subheartflow_id
|
||||
# 兴趣消息集合
|
||||
self.interest_dict: Dict[str, tuple[MessageRecv, float, bool]] = {}
|
||||
|
||||
@@ -199,7 +197,6 @@ class SubHeartflow:
|
||||
# 如果实例不存在,则创建并启动
|
||||
logger.info(f"{log_prefix} 麦麦准备开始专注聊天...")
|
||||
try:
|
||||
|
||||
self.heart_fc_instance = HeartFChatting(
|
||||
chat_id=self.subheartflow_id,
|
||||
# observations=self.observations,
|
||||
|
||||
@@ -23,7 +23,7 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
|
||||
chat_target_info = None
|
||||
|
||||
try:
|
||||
chat_stream = chat_manager.get_stream(chat_id)
|
||||
chat_stream = chat_manager.get_stream(chat_id)
|
||||
|
||||
if chat_stream:
|
||||
if chat_stream.group_info:
|
||||
|
||||
@@ -3,7 +3,7 @@ import os
|
||||
|
||||
from .global_logger import logger
|
||||
from .lpmmconfig import global_config
|
||||
from src.chat.knowledge.utils import get_sha256
|
||||
from src.chat.knowledge.utils.hash import get_sha256
|
||||
|
||||
|
||||
def load_raw_data(path: str = None) -> tuple[list[str], list[str]]:
|
||||
|
||||
@@ -346,7 +346,9 @@ class Hippocampus:
|
||||
# 使用LLM提取关键词
|
||||
topic_num = min(5, max(1, int(len(text) * 0.1))) # 根据文本长度动态调整关键词数量
|
||||
# logger.info(f"提取关键词数量: {topic_num}")
|
||||
topics_response, (reasoning_content, model_name) = await self.model_summary.generate_response_async(self.find_topic_llm(text, topic_num))
|
||||
topics_response, (reasoning_content, model_name) = await self.model_summary.generate_response_async(
|
||||
self.find_topic_llm(text, topic_num)
|
||||
)
|
||||
|
||||
# 提取关键词
|
||||
keywords = re.findall(r"<([^>]+)>", topics_response)
|
||||
@@ -701,7 +703,9 @@ class Hippocampus:
|
||||
# 使用LLM提取关键词
|
||||
topic_num = min(5, max(1, int(len(text) * 0.1))) # 根据文本长度动态调整关键词数量
|
||||
# logger.info(f"提取关键词数量: {topic_num}")
|
||||
topics_response, (reasoning_content, model_name) = await self.model_summary.generate_response_async(self.find_topic_llm(text, topic_num))
|
||||
topics_response, (reasoning_content, model_name) = await self.model_summary.generate_response_async(
|
||||
self.find_topic_llm(text, topic_num)
|
||||
)
|
||||
|
||||
# 提取关键词
|
||||
keywords = re.findall(r"<([^>]+)>", topics_response)
|
||||
@@ -893,7 +897,7 @@ class EntorhinalCortex:
|
||||
# 获取数据库中所有节点和内存中所有节点
|
||||
db_nodes = {node.concept: node for node in GraphNodes.select()}
|
||||
memory_nodes = list(self.memory_graph.G.nodes(data=True))
|
||||
|
||||
|
||||
# 批量准备节点数据
|
||||
nodes_to_create = []
|
||||
nodes_to_update = []
|
||||
@@ -929,22 +933,26 @@ class EntorhinalCortex:
|
||||
continue
|
||||
|
||||
if concept not in db_nodes:
|
||||
nodes_to_create.append({
|
||||
"concept": concept,
|
||||
"memory_items": memory_items_json,
|
||||
"hash": memory_hash,
|
||||
"created_time": created_time,
|
||||
"last_modified": last_modified,
|
||||
})
|
||||
else:
|
||||
db_node = db_nodes[concept]
|
||||
if db_node.hash != memory_hash:
|
||||
nodes_to_update.append({
|
||||
nodes_to_create.append(
|
||||
{
|
||||
"concept": concept,
|
||||
"memory_items": memory_items_json,
|
||||
"hash": memory_hash,
|
||||
"created_time": created_time,
|
||||
"last_modified": last_modified,
|
||||
})
|
||||
}
|
||||
)
|
||||
else:
|
||||
db_node = db_nodes[concept]
|
||||
if db_node.hash != memory_hash:
|
||||
nodes_to_update.append(
|
||||
{
|
||||
"concept": concept,
|
||||
"memory_items": memory_items_json,
|
||||
"hash": memory_hash,
|
||||
"last_modified": last_modified,
|
||||
}
|
||||
)
|
||||
|
||||
# 计算需要删除的节点
|
||||
memory_concepts = {concept for concept, _ in memory_nodes}
|
||||
@@ -954,13 +962,13 @@ class EntorhinalCortex:
|
||||
if nodes_to_create:
|
||||
batch_size = 100
|
||||
for i in range(0, len(nodes_to_create), batch_size):
|
||||
batch = nodes_to_create[i:i + batch_size]
|
||||
batch = nodes_to_create[i : i + batch_size]
|
||||
GraphNodes.insert_many(batch).execute()
|
||||
|
||||
if nodes_to_update:
|
||||
batch_size = 100
|
||||
for i in range(0, len(nodes_to_update), batch_size):
|
||||
batch = nodes_to_update[i:i + batch_size]
|
||||
batch = nodes_to_update[i : i + batch_size]
|
||||
for node_data in batch:
|
||||
GraphNodes.update(**{k: v for k, v in node_data.items() if k != "concept"}).where(
|
||||
GraphNodes.concept == node_data["concept"]
|
||||
@@ -992,22 +1000,26 @@ class EntorhinalCortex:
|
||||
last_modified = data.get("last_modified", current_time)
|
||||
|
||||
if edge_key not in db_edge_dict:
|
||||
edges_to_create.append({
|
||||
"source": source,
|
||||
"target": target,
|
||||
"strength": strength,
|
||||
"hash": edge_hash,
|
||||
"created_time": created_time,
|
||||
"last_modified": last_modified,
|
||||
})
|
||||
edges_to_create.append(
|
||||
{
|
||||
"source": source,
|
||||
"target": target,
|
||||
"strength": strength,
|
||||
"hash": edge_hash,
|
||||
"created_time": created_time,
|
||||
"last_modified": last_modified,
|
||||
}
|
||||
)
|
||||
elif db_edge_dict[edge_key]["hash"] != edge_hash:
|
||||
edges_to_update.append({
|
||||
"source": source,
|
||||
"target": target,
|
||||
"strength": strength,
|
||||
"hash": edge_hash,
|
||||
"last_modified": last_modified,
|
||||
})
|
||||
edges_to_update.append(
|
||||
{
|
||||
"source": source,
|
||||
"target": target,
|
||||
"strength": strength,
|
||||
"hash": edge_hash,
|
||||
"last_modified": last_modified,
|
||||
}
|
||||
)
|
||||
|
||||
# 计算需要删除的边
|
||||
memory_edge_keys = {(source, target) for source, target, _ in memory_edges}
|
||||
@@ -1017,13 +1029,13 @@ class EntorhinalCortex:
|
||||
if edges_to_create:
|
||||
batch_size = 100
|
||||
for i in range(0, len(edges_to_create), batch_size):
|
||||
batch = edges_to_create[i:i + batch_size]
|
||||
batch = edges_to_create[i : i + batch_size]
|
||||
GraphEdges.insert_many(batch).execute()
|
||||
|
||||
if edges_to_update:
|
||||
batch_size = 100
|
||||
for i in range(0, len(edges_to_update), batch_size):
|
||||
batch = edges_to_update[i:i + batch_size]
|
||||
batch = edges_to_update[i : i + batch_size]
|
||||
for edge_data in batch:
|
||||
GraphEdges.update(**{k: v for k, v in edge_data.items() if k not in ["source", "target"]}).where(
|
||||
(GraphEdges.source == edge_data["source"]) & (GraphEdges.target == edge_data["target"])
|
||||
@@ -1031,9 +1043,7 @@ class EntorhinalCortex:
|
||||
|
||||
if edges_to_delete:
|
||||
for source, target in edges_to_delete:
|
||||
GraphEdges.delete().where(
|
||||
(GraphEdges.source == source) & (GraphEdges.target == target)
|
||||
).execute()
|
||||
GraphEdges.delete().where((GraphEdges.source == source) & (GraphEdges.target == target)).execute()
|
||||
|
||||
end_time = time.time()
|
||||
logger.success(f"[同步] 总耗时: {end_time - start_time:.2f}秒")
|
||||
@@ -1069,13 +1079,15 @@ class EntorhinalCortex:
|
||||
if not memory_items_json:
|
||||
continue
|
||||
|
||||
nodes_data.append({
|
||||
"concept": concept,
|
||||
"memory_items": memory_items_json,
|
||||
"hash": self.hippocampus.calculate_node_hash(concept, memory_items),
|
||||
"created_time": data.get("created_time", current_time),
|
||||
"last_modified": data.get("last_modified", current_time),
|
||||
})
|
||||
nodes_data.append(
|
||||
{
|
||||
"concept": concept,
|
||||
"memory_items": memory_items_json,
|
||||
"hash": self.hippocampus.calculate_node_hash(concept, memory_items),
|
||||
"created_time": data.get("created_time", current_time),
|
||||
"last_modified": data.get("last_modified", current_time),
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"准备节点 {concept} 数据时发生错误: {e}")
|
||||
continue
|
||||
@@ -1084,14 +1096,16 @@ class EntorhinalCortex:
|
||||
edges_data = []
|
||||
for source, target, data in memory_edges:
|
||||
try:
|
||||
edges_data.append({
|
||||
"source": source,
|
||||
"target": target,
|
||||
"strength": data.get("strength", 1),
|
||||
"hash": self.hippocampus.calculate_edge_hash(source, target),
|
||||
"created_time": data.get("created_time", current_time),
|
||||
"last_modified": data.get("last_modified", current_time),
|
||||
})
|
||||
edges_data.append(
|
||||
{
|
||||
"source": source,
|
||||
"target": target,
|
||||
"strength": data.get("strength", 1),
|
||||
"hash": self.hippocampus.calculate_edge_hash(source, target),
|
||||
"created_time": data.get("created_time", current_time),
|
||||
"last_modified": data.get("last_modified", current_time),
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"准备边 {source}-{target} 数据时发生错误: {e}")
|
||||
continue
|
||||
@@ -1102,7 +1116,7 @@ class EntorhinalCortex:
|
||||
batch_size = 500 # 增加批量大小
|
||||
with GraphNodes._meta.database.atomic():
|
||||
for i in range(0, len(nodes_data), batch_size):
|
||||
batch = nodes_data[i:i + batch_size]
|
||||
batch = nodes_data[i : i + batch_size]
|
||||
GraphNodes.insert_many(batch).execute()
|
||||
node_end = time.time()
|
||||
logger.info(f"[数据库] 写入 {len(nodes_data)} 个节点耗时: {node_end - node_start:.2f}秒")
|
||||
@@ -1113,7 +1127,7 @@ class EntorhinalCortex:
|
||||
batch_size = 500 # 增加批量大小
|
||||
with GraphEdges._meta.database.atomic():
|
||||
for i in range(0, len(edges_data), batch_size):
|
||||
batch = edges_data[i:i + batch_size]
|
||||
batch = edges_data[i : i + batch_size]
|
||||
GraphEdges.insert_many(batch).execute()
|
||||
edge_end = time.time()
|
||||
logger.info(f"[数据库] 写入 {len(edges_data)} 条边耗时: {edge_end - edge_start:.2f}秒")
|
||||
|
||||
@@ -79,7 +79,7 @@ class ChatBot:
|
||||
group_info = message.message_info.group_info
|
||||
user_info = message.message_info.user_info
|
||||
chat_manager.register_message(message)
|
||||
|
||||
|
||||
# 创建聊天流
|
||||
chat = await chat_manager.get_or_create_stream(
|
||||
platform=message.message_info.platform,
|
||||
@@ -87,13 +87,13 @@ class ChatBot:
|
||||
group_info=group_info,
|
||||
)
|
||||
message.update_chat_stream(chat)
|
||||
|
||||
|
||||
# 处理消息内容,生成纯文本
|
||||
await message.process()
|
||||
|
||||
|
||||
# 命令处理 - 在消息处理的早期阶段检查并处理命令
|
||||
is_command, cmd_result, continue_process = await command_manager.process_command(message)
|
||||
|
||||
|
||||
# 如果是命令且不需要继续处理,则直接返回
|
||||
if is_command and not continue_process:
|
||||
logger.info(f"命令处理完成,跳过后续消息处理: {cmd_result}")
|
||||
|
||||
@@ -24,7 +24,6 @@ class MessageStorage:
|
||||
else:
|
||||
filtered_processed_plain_text = ""
|
||||
|
||||
|
||||
if isinstance(message, MessageSending):
|
||||
display_message = message.display_message
|
||||
if display_message:
|
||||
|
||||
@@ -13,8 +13,6 @@ from src.chat.utils.prompt_builder import global_prompt_manager
|
||||
from .normal_chat_generator import NormalChatGenerator
|
||||
from ..message_receive.message import MessageSending, MessageRecv, MessageThinking, MessageSet
|
||||
from src.chat.message_receive.message_sender import message_manager
|
||||
from src.chat.utils.utils_image import image_path_to_base64
|
||||
from src.chat.emoji_system.emoji_manager import emoji_manager
|
||||
from src.chat.normal_chat.willing.willing_manager import willing_manager
|
||||
from src.chat.normal_chat.normal_chat_utils import get_recent_message_stats
|
||||
from src.config.config import global_config
|
||||
@@ -69,7 +67,7 @@ class NormalChat:
|
||||
self.on_switch_to_focus_callback = on_switch_to_focus_callback
|
||||
|
||||
self._disabled = False # 增加停用标志
|
||||
|
||||
|
||||
logger.debug(f"[{self.stream_name}] NormalChat 初始化完成 (异步部分)。")
|
||||
|
||||
# 改为实例方法
|
||||
@@ -193,7 +191,9 @@ class NormalChat:
|
||||
return
|
||||
|
||||
timing_results = {}
|
||||
reply_probability = 1.0 if is_mentioned and global_config.normal_chat.mentioned_bot_inevitable_reply else 0.0 # 如果被提及,且开启了提及必回复,则基础概率为1,否则需要意愿判断
|
||||
reply_probability = (
|
||||
1.0 if is_mentioned and global_config.normal_chat.mentioned_bot_inevitable_reply else 0.0
|
||||
) # 如果被提及,且开启了提及必回复,则基础概率为1,否则需要意愿判断
|
||||
|
||||
# 意愿管理器:设置当前message信息
|
||||
willing_manager.setup(message, self.chat_stream, is_mentioned, interested_rate)
|
||||
@@ -267,13 +267,17 @@ class NormalChat:
|
||||
try:
|
||||
# 获取发送者名称(动作修改已在并行执行前完成)
|
||||
sender_name = self._get_sender_name(message)
|
||||
|
||||
|
||||
no_action = {
|
||||
"action_result": {"action_type": "no_action", "action_data": {}, "reasoning": "规划器初始化默认", "is_parallel": True},
|
||||
"action_result": {
|
||||
"action_type": "no_action",
|
||||
"action_data": {},
|
||||
"reasoning": "规划器初始化默认",
|
||||
"is_parallel": True,
|
||||
},
|
||||
"chat_context": "",
|
||||
"action_prompt": "",
|
||||
}
|
||||
|
||||
|
||||
# 检查是否应该跳过规划
|
||||
if self.action_modifier.should_skip_planning():
|
||||
@@ -288,7 +292,9 @@ class NormalChat:
|
||||
reasoning = plan_result["action_result"]["reasoning"]
|
||||
is_parallel = plan_result["action_result"].get("is_parallel", False)
|
||||
|
||||
logger.info(f"[{self.stream_name}] Planner决策: {action_type}, 理由: {reasoning}, 并行执行: {is_parallel}")
|
||||
logger.info(
|
||||
f"[{self.stream_name}] Planner决策: {action_type}, 理由: {reasoning}, 并行执行: {is_parallel}"
|
||||
)
|
||||
self.action_type = action_type # 更新实例属性
|
||||
self.is_parallel_action = is_parallel # 新增:保存并行执行标志
|
||||
|
||||
@@ -307,7 +313,12 @@ class NormalChat:
|
||||
else:
|
||||
logger.warning(f"[{self.stream_name}] 额外动作 {action_type} 执行失败")
|
||||
|
||||
return {"action_type": action_type, "action_data": action_data, "reasoning": reasoning, "is_parallel": is_parallel}
|
||||
return {
|
||||
"action_type": action_type,
|
||||
"action_data": action_data,
|
||||
"reasoning": reasoning,
|
||||
"is_parallel": is_parallel,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.stream_name}] Planner执行失败: {e}")
|
||||
@@ -331,13 +342,19 @@ class NormalChat:
|
||||
logger.error(f"[{self.stream_name}] 动作规划异常: {plan_result}")
|
||||
elif plan_result:
|
||||
logger.debug(f"[{self.stream_name}] 额外动作处理完成: {self.action_type}")
|
||||
|
||||
|
||||
if not response_set or (
|
||||
self.enable_planner and self.action_type not in ["no_action", "change_to_focus_chat"] and not self.is_parallel_action
|
||||
self.enable_planner
|
||||
and self.action_type not in ["no_action", "change_to_focus_chat"]
|
||||
and not self.is_parallel_action
|
||||
):
|
||||
if not response_set:
|
||||
logger.info(f"[{self.stream_name}] 模型未生成回复内容")
|
||||
elif self.enable_planner and self.action_type not in ["no_action", "change_to_focus_chat"] and not self.is_parallel_action:
|
||||
elif (
|
||||
self.enable_planner
|
||||
and self.action_type not in ["no_action", "change_to_focus_chat"]
|
||||
and not self.is_parallel_action
|
||||
):
|
||||
logger.info(f"[{self.stream_name}] 模型选择其他动作(非并行动作)")
|
||||
# 如果模型未生成回复,移除思考消息
|
||||
container = await message_manager.get_container(self.stream_id) # 使用 self.stream_id
|
||||
@@ -364,7 +381,6 @@ class NormalChat:
|
||||
|
||||
# 检查 first_bot_msg 是否为 None (例如思考消息已被移除的情况)
|
||||
if first_bot_msg:
|
||||
|
||||
# 记录回复信息到最近回复列表中
|
||||
reply_info = {
|
||||
"time": time.time(),
|
||||
@@ -396,7 +412,6 @@ class NormalChat:
|
||||
await self._check_switch_to_focus()
|
||||
pass
|
||||
|
||||
|
||||
# with Timer("关系更新", timing_results):
|
||||
# await self._update_relationship(message, response_set)
|
||||
|
||||
@@ -605,7 +620,7 @@ class NormalChat:
|
||||
# 执行动作
|
||||
result = await action_handler.handle_action()
|
||||
success = False
|
||||
|
||||
|
||||
if result and isinstance(result, tuple) and len(result) >= 2:
|
||||
# handle_action返回 (success: bool, message: str)
|
||||
success = result[0]
|
||||
|
||||
@@ -35,7 +35,7 @@ class NormalChatActionModifier:
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""为Normal Chat修改可用动作集合
|
||||
|
||||
|
||||
实现动作激活策略:
|
||||
1. 基于关联类型的动态过滤
|
||||
2. 基于激活类型的智能判定(LLM_JUDGE转为概率激活)
|
||||
@@ -49,7 +49,7 @@ class NormalChatActionModifier:
|
||||
reasons = []
|
||||
merged_action_changes = {"add": [], "remove": []}
|
||||
type_mismatched_actions = [] # 在外层定义避免作用域问题
|
||||
|
||||
|
||||
self.action_manager.restore_default_actions()
|
||||
|
||||
# 第一阶段:基于关联类型的动态过滤
|
||||
@@ -74,7 +74,7 @@ class NormalChatActionModifier:
|
||||
# 第二阶段:应用激活类型判定
|
||||
# 构建聊天内容 - 使用与planner一致的方式
|
||||
chat_content = ""
|
||||
if chat_stream and hasattr(chat_stream, 'stream_id'):
|
||||
if chat_stream and hasattr(chat_stream, "stream_id"):
|
||||
try:
|
||||
# 获取消息历史,使用与normal_chat_planner相同的方法
|
||||
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
|
||||
@@ -82,7 +82,7 @@ class NormalChatActionModifier:
|
||||
timestamp=time.time(),
|
||||
limit=global_config.focus_chat.observation_context_size, # 使用相同的配置
|
||||
)
|
||||
|
||||
|
||||
# 构建可读的聊天上下文
|
||||
chat_content = build_readable_messages(
|
||||
message_list_before_now,
|
||||
@@ -92,39 +92,41 @@ class NormalChatActionModifier:
|
||||
read_mark=0.0,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
|
||||
logger.debug(f"{self.log_prefix} 成功构建聊天内容,长度: {len(chat_content)}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"{self.log_prefix} 构建聊天内容失败: {e}")
|
||||
chat_content = ""
|
||||
|
||||
|
||||
# 获取当前Normal模式下的动作集进行激活判定
|
||||
current_actions = self.action_manager.get_using_actions_for_mode(ChatMode.NORMAL)
|
||||
|
||||
# print(f"current_actions: {current_actions}")
|
||||
# print(f"chat_content: {chat_content}")
|
||||
final_activated_actions = await self._apply_normal_activation_filtering(
|
||||
current_actions,
|
||||
chat_content,
|
||||
message_content
|
||||
current_actions, chat_content, message_content
|
||||
)
|
||||
# print(f"final_activated_actions: {final_activated_actions}")
|
||||
|
||||
|
||||
# 统一处理所有需要移除的动作,避免重复移除
|
||||
all_actions_to_remove = set() # 使用set避免重复
|
||||
|
||||
|
||||
# 添加关联类型不匹配的动作
|
||||
if type_mismatched_actions:
|
||||
all_actions_to_remove.update(type_mismatched_actions)
|
||||
|
||||
|
||||
# 添加激活类型判定未通过的动作
|
||||
for action_name in current_actions.keys():
|
||||
if action_name not in final_activated_actions:
|
||||
all_actions_to_remove.add(action_name)
|
||||
|
||||
|
||||
# 统计移除原因(避免重复)
|
||||
activation_failed_actions = [name for name in current_actions.keys() if name not in final_activated_actions and name not in type_mismatched_actions]
|
||||
activation_failed_actions = [
|
||||
name
|
||||
for name in current_actions.keys()
|
||||
if name not in final_activated_actions and name not in type_mismatched_actions
|
||||
]
|
||||
if activation_failed_actions:
|
||||
reasons.append(f"移除{activation_failed_actions}(激活类型判定未通过)")
|
||||
|
||||
@@ -146,7 +148,7 @@ class NormalChatActionModifier:
|
||||
# 记录变更原因
|
||||
if reasons:
|
||||
logger.info(f"{self.log_prefix} 动作调整完成: {' | '.join(reasons)}")
|
||||
|
||||
|
||||
# 获取最终的Normal模式可用动作并记录
|
||||
final_actions = self.action_manager.get_using_actions_for_mode(ChatMode.NORMAL)
|
||||
logger.debug(f"{self.log_prefix} 当前Normal模式可用动作: {list(final_actions.keys())}")
|
||||
@@ -159,31 +161,31 @@ class NormalChatActionModifier:
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
应用Normal模式的激活类型过滤逻辑
|
||||
|
||||
|
||||
与Focus模式的区别:
|
||||
1. LLM_JUDGE类型转换为概率激活(避免LLM调用)
|
||||
2. RANDOM类型保持概率激活
|
||||
3. KEYWORD类型保持关键词匹配
|
||||
4. ALWAYS类型直接激活
|
||||
|
||||
|
||||
Args:
|
||||
actions_with_info: 带完整信息的动作字典
|
||||
chat_content: 聊天内容
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 过滤后激活的actions字典
|
||||
"""
|
||||
activated_actions = {}
|
||||
|
||||
|
||||
# 分类处理不同激活类型的actions
|
||||
always_actions = {}
|
||||
random_actions = {}
|
||||
keyword_actions = {}
|
||||
|
||||
|
||||
for action_name, action_info in actions_with_info.items():
|
||||
# 使用normal_activation_type
|
||||
activation_type = action_info.get("normal_activation_type", ActionActivationType.ALWAYS)
|
||||
|
||||
|
||||
if activation_type == ActionActivationType.ALWAYS:
|
||||
always_actions[action_name] = action_info
|
||||
elif activation_type == ActionActivationType.RANDOM or activation_type == ActionActivationType.LLM_JUDGE:
|
||||
@@ -192,12 +194,12 @@ class NormalChatActionModifier:
|
||||
keyword_actions[action_name] = action_info
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix}未知的激活类型: {activation_type},跳过处理")
|
||||
|
||||
|
||||
# 1. 处理ALWAYS类型(直接激活)
|
||||
for action_name, action_info in always_actions.items():
|
||||
activated_actions[action_name] = action_info
|
||||
logger.debug(f"{self.log_prefix}激活动作: {action_name},原因: ALWAYS类型直接激活")
|
||||
|
||||
|
||||
# 2. 处理RANDOM类型(概率激活)
|
||||
for action_name, action_info in random_actions.items():
|
||||
probability = action_info.get("random_probability", 0.3)
|
||||
@@ -207,15 +209,10 @@ class NormalChatActionModifier:
|
||||
logger.debug(f"{self.log_prefix}激活动作: {action_name},原因: RANDOM类型触发(概率{probability})")
|
||||
else:
|
||||
logger.debug(f"{self.log_prefix}未激活动作: {action_name},原因: RANDOM类型未触发(概率{probability})")
|
||||
|
||||
|
||||
# 3. 处理KEYWORD类型(关键词匹配)
|
||||
for action_name, action_info in keyword_actions.items():
|
||||
should_activate = self._check_keyword_activation(
|
||||
action_name,
|
||||
action_info,
|
||||
chat_content,
|
||||
message_content
|
||||
)
|
||||
should_activate = self._check_keyword_activation(action_name, action_info, chat_content, message_content)
|
||||
if should_activate:
|
||||
activated_actions[action_name] = action_info
|
||||
keywords = action_info.get("activation_keywords", [])
|
||||
@@ -225,7 +222,7 @@ class NormalChatActionModifier:
|
||||
logger.debug(f"{self.log_prefix}未激活动作: {action_name},原因: KEYWORD类型未匹配关键词({keywords})")
|
||||
# print(f"keywords: {keywords}")
|
||||
# print(f"chat_content: {chat_content}")
|
||||
|
||||
|
||||
logger.debug(f"{self.log_prefix}Normal模式激活类型过滤完成: {list(activated_actions.keys())}")
|
||||
return activated_actions
|
||||
|
||||
@@ -238,41 +235,40 @@ class NormalChatActionModifier:
|
||||
) -> bool:
|
||||
"""
|
||||
检查是否匹配关键词触发条件
|
||||
|
||||
|
||||
Args:
|
||||
action_name: 动作名称
|
||||
action_info: 动作信息
|
||||
chat_content: 聊天内容(已经是格式化后的可读消息)
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否应该激活此action
|
||||
"""
|
||||
|
||||
|
||||
activation_keywords = action_info.get("activation_keywords", [])
|
||||
case_sensitive = action_info.get("keyword_case_sensitive", False)
|
||||
|
||||
|
||||
if not activation_keywords:
|
||||
logger.warning(f"{self.log_prefix}动作 {action_name} 设置为关键词触发但未配置关键词")
|
||||
return False
|
||||
|
||||
|
||||
# 使用构建好的聊天内容作为检索文本
|
||||
search_text = chat_content +message_content
|
||||
|
||||
search_text = chat_content + message_content
|
||||
|
||||
# 如果不区分大小写,转换为小写
|
||||
if not case_sensitive:
|
||||
search_text = search_text.lower()
|
||||
|
||||
|
||||
# 检查每个关键词
|
||||
matched_keywords = []
|
||||
for keyword in activation_keywords:
|
||||
check_keyword = keyword if case_sensitive else keyword.lower()
|
||||
if check_keyword in search_text:
|
||||
matched_keywords.append(keyword)
|
||||
|
||||
|
||||
|
||||
# print(f"search_text: {search_text}")
|
||||
# print(f"activation_keywords: {activation_keywords}")
|
||||
|
||||
|
||||
if matched_keywords:
|
||||
logger.debug(f"{self.log_prefix}动作 {action_name} 匹配到关键词: {matched_keywords}")
|
||||
return True
|
||||
|
||||
@@ -9,7 +9,7 @@ import time
|
||||
from typing import List, Optional, Tuple, Dict, Any
|
||||
from src.chat.message_receive.message import MessageRecv, MessageSending, MessageThinking, Seg
|
||||
from src.chat.message_receive.message import UserInfo
|
||||
from src.chat.message_receive.chat_stream import ChatStream,chat_manager
|
||||
from src.chat.message_receive.chat_stream import ChatStream, chat_manager
|
||||
from src.chat.message_receive.message_sender import message_manager
|
||||
from src.config.config import global_config
|
||||
from src.common.logger_manager import get_logger
|
||||
@@ -37,7 +37,7 @@ class NormalChatExpressor:
|
||||
self.chat_stream = chat_stream
|
||||
self.stream_name = chat_manager.get_stream_name(self.chat_stream.stream_id) or self.chat_stream.stream_id
|
||||
self.log_prefix = f"[{self.stream_name}]Normal表达器"
|
||||
|
||||
|
||||
logger.debug(f"{self.log_prefix} 初始化完成")
|
||||
|
||||
async def create_thinking_message(
|
||||
|
||||
@@ -25,9 +25,7 @@ class NormalChatGenerator:
|
||||
request_type="normal.chat_2",
|
||||
)
|
||||
|
||||
self.model_sum = LLMRequest(
|
||||
model=global_config.model.memory_summary, temperature=0.7, request_type="relation"
|
||||
)
|
||||
self.model_sum = LLMRequest(model=global_config.model.memory_summary, temperature=0.7, request_type="relation")
|
||||
self.current_model_type = "r1" # 默认使用 R1
|
||||
self.current_model_name = "unknown model"
|
||||
|
||||
@@ -68,7 +66,6 @@ class NormalChatGenerator:
|
||||
enable_planner: bool = False,
|
||||
available_actions=None,
|
||||
):
|
||||
|
||||
person_id = person_info_manager.get_person_id(
|
||||
message.chat_stream.user_info.platform, message.chat_stream.user_info.user_id
|
||||
)
|
||||
@@ -103,7 +100,6 @@ class NormalChatGenerator:
|
||||
|
||||
logger.info(f"对 {message.processed_plain_text} 的回复:{content}")
|
||||
|
||||
|
||||
except Exception:
|
||||
logger.exception("生成回复时出错")
|
||||
return None
|
||||
|
||||
@@ -101,7 +101,7 @@ class NormalChatPlanner:
|
||||
|
||||
# 获取当前可用的动作,使用Normal模式过滤
|
||||
current_available_actions = self.action_manager.get_using_actions_for_mode(ChatMode.NORMAL)
|
||||
|
||||
|
||||
# 注意:动作的激活判定现在在 normal_chat_action_modifier 中完成
|
||||
# 这里直接使用经过 action_modifier 处理后的最终动作集
|
||||
# 符合职责分离原则:ActionModifier负责动作管理,Planner专注于决策
|
||||
@@ -110,7 +110,12 @@ class NormalChatPlanner:
|
||||
if not current_available_actions:
|
||||
logger.debug(f"{self.log_prefix}规划器: 没有可用动作,返回no_action")
|
||||
return {
|
||||
"action_result": {"action_type": action, "action_data": action_data, "reasoning": reasoning, "is_parallel": True},
|
||||
"action_result": {
|
||||
"action_type": action,
|
||||
"action_data": action_data,
|
||||
"reasoning": reasoning,
|
||||
"is_parallel": True,
|
||||
},
|
||||
"chat_context": "",
|
||||
"action_prompt": "",
|
||||
}
|
||||
@@ -121,7 +126,7 @@ class NormalChatPlanner:
|
||||
timestamp=time.time(),
|
||||
limit=global_config.focus_chat.observation_context_size,
|
||||
)
|
||||
|
||||
|
||||
chat_context = build_readable_messages(
|
||||
message_list_before_now,
|
||||
replace_bot_name=True,
|
||||
@@ -130,7 +135,7 @@ class NormalChatPlanner:
|
||||
read_mark=0.0,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
|
||||
# 构建planner的prompt
|
||||
prompt = await self.build_planner_prompt(
|
||||
self_info_block=self_info,
|
||||
@@ -141,7 +146,12 @@ class NormalChatPlanner:
|
||||
if not prompt:
|
||||
logger.warning(f"{self.log_prefix}规划器: 构建提示词失败")
|
||||
return {
|
||||
"action_result": {"action_type": action, "action_data": action_data, "reasoning": reasoning, "is_parallel": False},
|
||||
"action_result": {
|
||||
"action_type": action,
|
||||
"action_data": action_data,
|
||||
"reasoning": reasoning,
|
||||
"is_parallel": False,
|
||||
},
|
||||
"chat_context": chat_context,
|
||||
"action_prompt": "",
|
||||
}
|
||||
@@ -149,7 +159,7 @@ class NormalChatPlanner:
|
||||
# 使用LLM生成动作决策
|
||||
try:
|
||||
content, (reasoning_content, model_name) = await self.planner_llm.generate_response_async(prompt)
|
||||
|
||||
|
||||
# logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
|
||||
logger.info(f"{self.log_prefix}规划器原始响应: {content}")
|
||||
logger.info(f"{self.log_prefix}规划器推理: {reasoning_content}")
|
||||
@@ -201,8 +211,10 @@ class NormalChatPlanner:
|
||||
if action in current_available_actions:
|
||||
action_info = current_available_actions[action]
|
||||
is_parallel = action_info.get("parallel_action", False)
|
||||
|
||||
logger.debug(f"{self.log_prefix}规划器决策动作:{action}, 动作信息: '{action_data}', 理由: {reasoning}, 并行执行: {is_parallel}")
|
||||
|
||||
logger.debug(
|
||||
f"{self.log_prefix}规划器决策动作:{action}, 动作信息: '{action_data}', 理由: {reasoning}, 并行执行: {is_parallel}"
|
||||
)
|
||||
|
||||
# 恢复到默认动作集
|
||||
self.action_manager.restore_actions()
|
||||
@@ -216,15 +228,15 @@ class NormalChatPlanner:
|
||||
"action_data": action_data,
|
||||
"reasoning": reasoning,
|
||||
"timestamp": time.time(),
|
||||
"model_name": model_name if 'model_name' in locals() else None
|
||||
"model_name": model_name if "model_name" in locals() else None,
|
||||
}
|
||||
|
||||
action_result = {
|
||||
"action_type": action,
|
||||
"action_data": action_data,
|
||||
"action_type": action,
|
||||
"action_data": action_data,
|
||||
"reasoning": reasoning,
|
||||
"is_parallel": is_parallel,
|
||||
"action_record": json.dumps(action_record, ensure_ascii=False)
|
||||
"action_record": json.dumps(action_record, ensure_ascii=False),
|
||||
}
|
||||
|
||||
plan_result = {
|
||||
@@ -248,24 +260,19 @@ class NormalChatPlanner:
|
||||
|
||||
# 添加特殊的change_to_focus_chat动作
|
||||
action_options_text += "动作:change_to_focus_chat\n"
|
||||
action_options_text += (
|
||||
"该动作的描述:当聊天变得热烈、自己回复条数很多或需要深入交流时使用,正常回复消息并切换到focus_chat模式\n"
|
||||
)
|
||||
action_options_text += "该动作的描述:当聊天变得热烈、自己回复条数很多或需要深入交流时使用,正常回复消息并切换到focus_chat模式\n"
|
||||
|
||||
action_options_text += "使用该动作的场景:\n"
|
||||
action_options_text += "- 聊天上下文中自己的回复条数较多(超过3-4条)\n"
|
||||
action_options_text += "- 对话进行得非常热烈活跃\n"
|
||||
action_options_text += "- 用户表现出深入交流的意图\n"
|
||||
action_options_text += "- 话题需要更专注和深入的讨论\n\n"
|
||||
|
||||
|
||||
action_options_text += "输出要求:\n"
|
||||
action_options_text += "{{"
|
||||
action_options_text += " \"action\": \"change_to_focus_chat\""
|
||||
action_options_text += ' "action": "change_to_focus_chat"'
|
||||
action_options_text += "}}\n\n"
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
for action_name, action_info in current_available_actions.items():
|
||||
action_description = action_info.get("description", "")
|
||||
action_parameters = action_info.get("parameters", {})
|
||||
@@ -276,15 +283,14 @@ class NormalChatPlanner:
|
||||
print(action_parameters)
|
||||
for param_name, param_description in action_parameters.items():
|
||||
param_text += f' "{param_name}":"{param_description}"\n'
|
||||
param_text = param_text.rstrip('\n')
|
||||
param_text = param_text.rstrip("\n")
|
||||
else:
|
||||
param_text = ""
|
||||
|
||||
|
||||
require_text = ""
|
||||
for require_item in action_require:
|
||||
require_text += f"- {require_item}\n"
|
||||
require_text = require_text.rstrip('\n')
|
||||
require_text = require_text.rstrip("\n")
|
||||
|
||||
# 构建单个动作的提示
|
||||
action_prompt = await global_prompt_manager.format_prompt(
|
||||
@@ -316,6 +322,4 @@ class NormalChatPlanner:
|
||||
return ""
|
||||
|
||||
|
||||
|
||||
|
||||
init_prompt()
|
||||
|
||||
@@ -214,7 +214,6 @@ class PromptBuilder:
|
||||
except Exception as e:
|
||||
logger.error(f"关键词检测与反应时发生异常: {str(e)}", exc_info=True)
|
||||
|
||||
|
||||
moderation_prompt_block = "请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。"
|
||||
|
||||
# 构建action描述 (如果启用planner)
|
||||
|
||||
@@ -42,9 +42,7 @@ class ClassicalWillingManager(BaseWillingManager):
|
||||
|
||||
self.chat_reply_willing[chat_id] = min(current_willing, 3.0)
|
||||
|
||||
reply_probability = min(
|
||||
max((current_willing - 0.5), 0.01) * 2, 1
|
||||
)
|
||||
reply_probability = min(max((current_willing - 0.5), 0.01) * 2, 1)
|
||||
|
||||
# 检查群组权限(如果是群聊)
|
||||
if (
|
||||
|
||||
@@ -286,7 +286,7 @@ def _build_readable_messages_internal(
|
||||
message_details_with_flags.append((timestamp, name, content, is_action))
|
||||
# print(f"content:{content}")
|
||||
# print(f"is_action:{is_action}")
|
||||
|
||||
|
||||
# print(f"message_details_with_flags:{message_details_with_flags}")
|
||||
|
||||
# 应用截断逻辑 (如果 truncate 为 True)
|
||||
@@ -324,7 +324,7 @@ def _build_readable_messages_internal(
|
||||
else:
|
||||
# 如果不截断,直接使用原始列表
|
||||
message_details = message_details_with_flags
|
||||
|
||||
|
||||
# print(f"message_details:{message_details}")
|
||||
|
||||
# 3: 合并连续消息 (如果 merge_messages 为 True)
|
||||
@@ -336,12 +336,12 @@ def _build_readable_messages_internal(
|
||||
"start_time": message_details[0][0],
|
||||
"end_time": message_details[0][0],
|
||||
"content": [message_details[0][2]],
|
||||
"is_action": message_details[0][3]
|
||||
"is_action": message_details[0][3],
|
||||
}
|
||||
|
||||
for i in range(1, len(message_details)):
|
||||
timestamp, name, content, is_action = message_details[i]
|
||||
|
||||
|
||||
# 对于动作记录,不进行合并
|
||||
if is_action or current_merge["is_action"]:
|
||||
# 保存当前的合并块
|
||||
@@ -352,7 +352,7 @@ def _build_readable_messages_internal(
|
||||
"start_time": timestamp,
|
||||
"end_time": timestamp,
|
||||
"content": [content],
|
||||
"is_action": is_action
|
||||
"is_action": is_action,
|
||||
}
|
||||
continue
|
||||
|
||||
@@ -365,11 +365,11 @@ def _build_readable_messages_internal(
|
||||
merged_messages.append(current_merge)
|
||||
# 开始新的合并块
|
||||
current_merge = {
|
||||
"name": name,
|
||||
"start_time": timestamp,
|
||||
"end_time": timestamp,
|
||||
"name": name,
|
||||
"start_time": timestamp,
|
||||
"end_time": timestamp,
|
||||
"content": [content],
|
||||
"is_action": is_action
|
||||
"is_action": is_action,
|
||||
}
|
||||
# 添加最后一个合并块
|
||||
merged_messages.append(current_merge)
|
||||
@@ -381,10 +381,9 @@ def _build_readable_messages_internal(
|
||||
"start_time": timestamp, # 起始和结束时间相同
|
||||
"end_time": timestamp,
|
||||
"content": [content], # 内容只有一个元素
|
||||
"is_action": is_action
|
||||
"is_action": is_action,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# 4 & 5: 格式化为字符串
|
||||
output_lines = []
|
||||
@@ -451,7 +450,7 @@ def build_readable_messages(
|
||||
将消息列表转换为可读的文本格式。
|
||||
如果提供了 read_mark,则在相应位置插入已读标记。
|
||||
允许通过参数控制格式化行为。
|
||||
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
replace_bot_name: 是否替换机器人名称为"你"
|
||||
@@ -463,22 +462,24 @@ def build_readable_messages(
|
||||
"""
|
||||
# 创建messages的深拷贝,避免修改原始列表
|
||||
copy_messages = [msg.copy() for msg in messages]
|
||||
|
||||
|
||||
if show_actions and copy_messages:
|
||||
# 获取所有消息的时间范围
|
||||
min_time = min(msg.get("time", 0) for msg in copy_messages)
|
||||
max_time = max(msg.get("time", 0) for msg in copy_messages)
|
||||
|
||||
|
||||
# 从第一条消息中获取chat_id
|
||||
chat_id = copy_messages[0].get("chat_id") if copy_messages else None
|
||||
|
||||
|
||||
# 获取这个时间范围内的动作记录,并匹配chat_id
|
||||
actions = ActionRecords.select().where(
|
||||
(ActionRecords.time >= min_time) &
|
||||
(ActionRecords.time <= max_time) &
|
||||
(ActionRecords.chat_id == chat_id)
|
||||
).order_by(ActionRecords.time)
|
||||
|
||||
actions = (
|
||||
ActionRecords.select()
|
||||
.where(
|
||||
(ActionRecords.time >= min_time) & (ActionRecords.time <= max_time) & (ActionRecords.chat_id == chat_id)
|
||||
)
|
||||
.order_by(ActionRecords.time)
|
||||
)
|
||||
|
||||
# 将动作记录转换为消息格式
|
||||
for action in actions:
|
||||
# 只有当build_into_prompt为True时才添加动作记录
|
||||
@@ -495,25 +496,22 @@ def build_readable_messages(
|
||||
"action_name": action.action_name, # 保存动作名称
|
||||
}
|
||||
copy_messages.append(action_msg)
|
||||
|
||||
|
||||
# 重新按时间排序
|
||||
copy_messages.sort(key=lambda x: x.get("time", 0))
|
||||
|
||||
if read_mark <= 0:
|
||||
# 没有有效的 read_mark,直接格式化所有消息
|
||||
|
||||
|
||||
# for message in messages:
|
||||
# print(f"message:{message}")
|
||||
|
||||
|
||||
# print(f"message:{message}")
|
||||
|
||||
formatted_string, _ = _build_readable_messages_internal(
|
||||
copy_messages, replace_bot_name, merge_messages, timestamp_mode, truncate
|
||||
)
|
||||
|
||||
|
||||
# print(f"formatted_string:{formatted_string}")
|
||||
|
||||
|
||||
|
||||
|
||||
return formatted_string
|
||||
else:
|
||||
# 按 read_mark 分割消息
|
||||
@@ -521,10 +519,10 @@ def build_readable_messages(
|
||||
messages_after_mark = [msg for msg in copy_messages if msg.get("time", 0) > read_mark]
|
||||
|
||||
# for message in messages_before_mark:
|
||||
# print(f"message:{message}")
|
||||
|
||||
# print(f"message:{message}")
|
||||
|
||||
# for message in messages_after_mark:
|
||||
# print(f"message:{message}")
|
||||
# print(f"message:{message}")
|
||||
|
||||
# 分别格式化
|
||||
formatted_before, _ = _build_readable_messages_internal(
|
||||
@@ -536,7 +534,7 @@ def build_readable_messages(
|
||||
merge_messages,
|
||||
timestamp_mode,
|
||||
)
|
||||
|
||||
|
||||
# print(f"formatted_before:{formatted_before}")
|
||||
# print(f"formatted_after:{formatted_after}")
|
||||
|
||||
|
||||
@@ -154,7 +154,8 @@ class Messages(BaseModel):
|
||||
class Meta:
|
||||
# database = db # 继承自 BaseModel
|
||||
table_name = "messages"
|
||||
|
||||
|
||||
|
||||
class ActionRecords(BaseModel):
|
||||
"""
|
||||
用于存储动作记录数据的模型。
|
||||
@@ -162,11 +163,11 @@ class ActionRecords(BaseModel):
|
||||
|
||||
action_id = TextField(index=True) # 消息 ID (更改自 IntegerField)
|
||||
time = DoubleField() # 消息时间戳
|
||||
|
||||
|
||||
action_name = TextField()
|
||||
action_data = TextField()
|
||||
action_done = BooleanField(default=False)
|
||||
|
||||
|
||||
action_build_into_prompt = BooleanField(default=False)
|
||||
action_prompt_display = TextField()
|
||||
|
||||
@@ -241,11 +242,10 @@ class PersonInfo(BaseModel):
|
||||
points = TextField(null=True) # 个人印象的点
|
||||
forgotten_points = TextField(null=True) # 被遗忘的点
|
||||
info_list = TextField(null=True) # 与Bot的互动
|
||||
|
||||
|
||||
know_times = FloatField(null=True) # 认识时间 (时间戳)
|
||||
know_since = FloatField(null=True) # 首次印象总结时间
|
||||
last_know = FloatField(null=True) # 最后一次印象总结时间
|
||||
|
||||
|
||||
class Meta:
|
||||
# database = db # 继承自 BaseModel
|
||||
@@ -403,20 +403,20 @@ def initialize_database():
|
||||
logger.info(f"表 '{table_name}' 缺失字段 '{field_name}',正在添加...")
|
||||
field_type = field_obj.__class__.__name__
|
||||
sql_type = {
|
||||
'TextField': 'TEXT',
|
||||
'IntegerField': 'INTEGER',
|
||||
'FloatField': 'FLOAT',
|
||||
'DoubleField': 'DOUBLE',
|
||||
'BooleanField': 'INTEGER',
|
||||
'DateTimeField': 'DATETIME'
|
||||
}.get(field_type, 'TEXT')
|
||||
alter_sql = f'ALTER TABLE {table_name} ADD COLUMN {field_name} {sql_type}'
|
||||
"TextField": "TEXT",
|
||||
"IntegerField": "INTEGER",
|
||||
"FloatField": "FLOAT",
|
||||
"DoubleField": "DOUBLE",
|
||||
"BooleanField": "INTEGER",
|
||||
"DateTimeField": "DATETIME",
|
||||
}.get(field_type, "TEXT")
|
||||
alter_sql = f"ALTER TABLE {table_name} ADD COLUMN {field_name} {sql_type}"
|
||||
if field_obj.null:
|
||||
alter_sql += ' NULL'
|
||||
alter_sql += " NULL"
|
||||
else:
|
||||
alter_sql += ' NOT NULL'
|
||||
if hasattr(field_obj, 'default') and field_obj.default is not None:
|
||||
alter_sql += f' DEFAULT {field_obj.default}'
|
||||
alter_sql += " NOT NULL"
|
||||
if hasattr(field_obj, "default") and field_obj.default is not None:
|
||||
alter_sql += f" DEFAULT {field_obj.default}"
|
||||
db.execute_sql(alter_sql)
|
||||
logger.info(f"字段 '{field_name}' 添加成功")
|
||||
|
||||
|
||||
@@ -84,7 +84,7 @@ def update_config():
|
||||
contains_regex = False
|
||||
if value and isinstance(value[0], dict) and "regex" in value[0]:
|
||||
contains_regex = True
|
||||
|
||||
|
||||
if contains_regex:
|
||||
target[key] = value
|
||||
else:
|
||||
|
||||
@@ -49,7 +49,7 @@ class IdentityConfig(ConfigBase):
|
||||
@dataclass
|
||||
class RelationshipConfig(ConfigBase):
|
||||
"""关系配置类"""
|
||||
|
||||
|
||||
enable_relationship: bool = True
|
||||
|
||||
give_name: bool = False
|
||||
@@ -58,6 +58,7 @@ class RelationshipConfig(ConfigBase):
|
||||
build_relationship_interval: int = 600
|
||||
"""构建关系间隔 单位秒,如果为0则不构建关系"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatConfig(ConfigBase):
|
||||
"""聊天配置类"""
|
||||
@@ -222,7 +223,7 @@ class EmojiConfig(ConfigBase):
|
||||
@dataclass
|
||||
class MemoryConfig(ConfigBase):
|
||||
"""记忆配置类"""
|
||||
|
||||
|
||||
enable_memory: bool = True
|
||||
|
||||
memory_build_interval: int = 600
|
||||
@@ -329,6 +330,7 @@ class KeywordReactionConfig(ConfigBase):
|
||||
if not isinstance(rule, KeywordRuleConfig):
|
||||
raise ValueError(f"规则必须是KeywordRuleConfig类型,而不是{type(rule).__name__}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResponsePostProcessConfig(ConfigBase):
|
||||
"""回复后处理配置类"""
|
||||
@@ -461,7 +463,7 @@ class LPMMKnowledgeConfig(ConfigBase):
|
||||
|
||||
qa_res_top_k: int = 10
|
||||
"""QA最终结果的Top K数量"""
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelConfig(ConfigBase):
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict, Any
|
||||
from typing import List, Dict, Any, Callable
|
||||
|
||||
# from src.common.database.database import db # Peewee db 导入
|
||||
from playhouse import shortcuts
|
||||
|
||||
# from src.common.database.database import db # Peewee db 导入
|
||||
from src.common.database.database_model import Messages # Peewee Messages 模型导入
|
||||
from playhouse.shortcuts import model_to_dict # 用于将模型实例转换为字典
|
||||
|
||||
model_to_dict: Callable[..., dict] = shortcuts.model_to_dict # Peewee 模型转换为字典的快捷函数
|
||||
|
||||
|
||||
class MessageStorage(ABC):
|
||||
|
||||
@@ -90,7 +90,7 @@ class PersonalityExpression:
|
||||
|
||||
current_style_text = global_config.expression.expression_style
|
||||
current_personality = global_config.personality.personality_core
|
||||
|
||||
|
||||
meta_data = self._read_meta_data()
|
||||
|
||||
last_style_text = meta_data.get("last_style_text")
|
||||
@@ -98,9 +98,10 @@ class PersonalityExpression:
|
||||
count = meta_data.get("count", 0)
|
||||
|
||||
# 检查是否有任何变化
|
||||
if (current_style_text != last_style_text or
|
||||
current_personality != last_personality):
|
||||
logger.info(f"检测到变化:\n风格: '{last_style_text}' -> '{current_style_text}'\n人格: '{last_personality}' -> '{current_personality}'")
|
||||
if current_style_text != last_style_text or current_personality != last_personality:
|
||||
logger.info(
|
||||
f"检测到变化:\n风格: '{last_style_text}' -> '{current_style_text}'\n人格: '{last_personality}' -> '{current_personality}'"
|
||||
)
|
||||
count = 0
|
||||
if os.path.exists(self.expressions_file_path):
|
||||
try:
|
||||
@@ -196,7 +197,7 @@ class PersonalityExpression:
|
||||
"last_style_text": current_style_text,
|
||||
"last_personality": current_personality,
|
||||
"count": count,
|
||||
"last_update_time": current_time
|
||||
"last_update_time": current_time,
|
||||
}
|
||||
)
|
||||
logger.info(f"成功处理。当前配置的计数现在是 {count},最后更新时间:{current_time}。")
|
||||
|
||||
39
src/main.py
39
src/main.py
@@ -19,6 +19,7 @@ from .common.server import global_server, Server
|
||||
from rich.traceback import install
|
||||
from .chat.focus_chat.expressors.exprssion_learner import expression_learner
|
||||
from .api.main import start_api_server
|
||||
|
||||
# 导入actions模块,确保装饰器被执行
|
||||
import src.chat.actions.default_actions # noqa
|
||||
|
||||
@@ -40,7 +41,7 @@ class MainSystem:
|
||||
self.hippocampus_manager = hippocampus_manager
|
||||
else:
|
||||
self.hippocampus_manager = None
|
||||
|
||||
|
||||
self.individuality: Individuality = individuality
|
||||
|
||||
# 使用消息API替代直接的FastAPI实例
|
||||
@@ -74,11 +75,11 @@ class MainSystem:
|
||||
# 启动API服务器
|
||||
start_api_server()
|
||||
logger.success("API服务器启动成功")
|
||||
|
||||
|
||||
# 加载所有actions,包括默认的和插件的
|
||||
self._load_all_actions()
|
||||
logger.success("动作系统加载成功")
|
||||
|
||||
|
||||
# 初始化表情管理器
|
||||
emoji_manager.initialize()
|
||||
logger.success("表情包管理器初始化成功")
|
||||
@@ -137,23 +138,25 @@ class MainSystem:
|
||||
try:
|
||||
# 导入统一的插件加载器
|
||||
from src.plugins.plugin_loader import plugin_loader
|
||||
|
||||
|
||||
# 使用统一的插件加载器加载所有插件组件
|
||||
loaded_actions, loaded_commands = plugin_loader.load_all_plugins()
|
||||
|
||||
|
||||
# 加载命令处理系统
|
||||
try:
|
||||
# 导入命令处理系统
|
||||
from src.chat.command.command_handler import command_manager
|
||||
|
||||
logger.success("命令处理系统加载成功")
|
||||
except Exception as e:
|
||||
logger.error(f"加载命令处理系统失败: {e}")
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"加载插件失败: {e}")
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
async def schedule_tasks(self):
|
||||
@@ -165,17 +168,19 @@ class MainSystem:
|
||||
self.app.run(),
|
||||
self.server.run(),
|
||||
]
|
||||
|
||||
|
||||
# 根据配置条件性地添加记忆系统相关任务
|
||||
if global_config.memory.enable_memory and self.hippocampus_manager:
|
||||
tasks.extend([
|
||||
self.build_memory_task(),
|
||||
self.forget_memory_task(),
|
||||
self.consolidate_memory_task(),
|
||||
])
|
||||
|
||||
tasks.extend(
|
||||
[
|
||||
self.build_memory_task(),
|
||||
self.forget_memory_task(),
|
||||
self.consolidate_memory_task(),
|
||||
]
|
||||
)
|
||||
|
||||
tasks.append(self.learn_and_store_expression_task())
|
||||
|
||||
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
async def build_memory_task(self):
|
||||
@@ -190,9 +195,7 @@ class MainSystem:
|
||||
while True:
|
||||
await asyncio.sleep(global_config.memory.forget_memory_interval)
|
||||
logger.info("[记忆遗忘] 开始遗忘记忆...")
|
||||
await self.hippocampus_manager.forget_memory(
|
||||
percentage=global_config.memory.memory_forget_percentage
|
||||
)
|
||||
await self.hippocampus_manager.forget_memory(percentage=global_config.memory.memory_forget_percentage)
|
||||
logger.info("[记忆遗忘] 记忆遗忘完成")
|
||||
|
||||
async def consolidate_memory_task(self):
|
||||
|
||||
@@ -11,6 +11,7 @@ from collections import defaultdict
|
||||
|
||||
logger = get_logger("relation")
|
||||
|
||||
|
||||
# 暂时弃用,改为实时更新
|
||||
class ImpressionUpdateTask(AsyncTask):
|
||||
def __init__(self):
|
||||
@@ -25,10 +26,10 @@ class ImpressionUpdateTask(AsyncTask):
|
||||
# 获取最近的消息
|
||||
current_time = int(time.time())
|
||||
start_time = current_time - global_config.relationship.build_relationship_interval # 100分钟前
|
||||
|
||||
|
||||
# 获取所有消息
|
||||
messages = get_raw_msg_by_timestamp(timestamp_start=start_time, timestamp_end=current_time)
|
||||
|
||||
|
||||
if not messages:
|
||||
logger.info("没有找到需要处理的消息")
|
||||
return
|
||||
@@ -48,7 +49,7 @@ class ImpressionUpdateTask(AsyncTask):
|
||||
if len(msgs) < 30:
|
||||
logger.info(f"聊天组 {chat_id} 消息数小于30,跳过处理")
|
||||
continue
|
||||
|
||||
|
||||
chat_stream = chat_manager.get_stream(chat_id)
|
||||
if not chat_stream:
|
||||
logger.warning(f"未找到聊天组 {chat_id} 的chat_stream,跳过处理")
|
||||
@@ -56,26 +57,26 @@ class ImpressionUpdateTask(AsyncTask):
|
||||
|
||||
# 找到bot的消息
|
||||
bot_messages = [msg for msg in msgs if msg["user_nickname"] == global_config.bot.nickname]
|
||||
|
||||
|
||||
if not bot_messages:
|
||||
logger.info(f"聊天组 {chat_id} 没有bot消息,跳过处理")
|
||||
continue
|
||||
|
||||
# 按时间排序所有消息
|
||||
sorted_messages = sorted(msgs, key=lambda x: x["time"])
|
||||
|
||||
|
||||
# 找到第一条和最后一条bot消息
|
||||
first_bot_msg = bot_messages[0]
|
||||
last_bot_msg = bot_messages[-1]
|
||||
|
||||
|
||||
# 获取第一条bot消息前15条消息
|
||||
first_bot_index = sorted_messages.index(first_bot_msg)
|
||||
start_index = max(0, first_bot_index - 25)
|
||||
|
||||
|
||||
# 获取最后一条bot消息后15条消息
|
||||
last_bot_index = sorted_messages.index(last_bot_msg)
|
||||
end_index = min(len(sorted_messages), last_bot_index + 26)
|
||||
|
||||
|
||||
# 获取相关消息
|
||||
relevant_messages = sorted_messages[start_index:end_index]
|
||||
|
||||
@@ -85,7 +86,9 @@ class ImpressionUpdateTask(AsyncTask):
|
||||
# 计算权重
|
||||
for bot_msg in bot_messages:
|
||||
bot_time = bot_msg["time"]
|
||||
context_messages = [msg for msg in relevant_messages if abs(msg["time"] - bot_time) <= 600] # 前后10分钟
|
||||
context_messages = [
|
||||
msg for msg in relevant_messages if abs(msg["time"] - bot_time) <= 600
|
||||
] # 前后10分钟
|
||||
logger.debug(f"Bot消息 {bot_time} 的上下文消息数: {len(context_messages)}")
|
||||
|
||||
for msg in context_messages:
|
||||
@@ -121,7 +124,7 @@ class ImpressionUpdateTask(AsyncTask):
|
||||
weights = [user[1]["weight"] for user in sorted_users]
|
||||
total_weight = sum(weights)
|
||||
# 计算每个用户的概率
|
||||
probabilities = [w/total_weight for w in weights]
|
||||
probabilities = [w / total_weight for w in weights]
|
||||
# 使用累积概率进行选择
|
||||
selected_indices = []
|
||||
remaining_indices = list(range(len(sorted_users)))
|
||||
@@ -131,12 +134,12 @@ class ImpressionUpdateTask(AsyncTask):
|
||||
# 计算剩余索引的累积概率
|
||||
remaining_probs = [probabilities[i] for i in remaining_indices]
|
||||
# 归一化概率
|
||||
remaining_probs = [p/sum(remaining_probs) for p in remaining_probs]
|
||||
remaining_probs = [p / sum(remaining_probs) for p in remaining_probs]
|
||||
# 选择索引
|
||||
chosen_idx = random.choices(remaining_indices, weights=remaining_probs, k=1)[0]
|
||||
selected_indices.append(chosen_idx)
|
||||
remaining_indices.remove(chosen_idx)
|
||||
|
||||
|
||||
selected_users = [sorted_users[i] for i in selected_indices]
|
||||
logger.info(
|
||||
f"开始进一步了解这些用户: {[msg[1]['messages'][0]['user_nickname'] for msg in selected_users]}"
|
||||
@@ -153,19 +156,16 @@ class ImpressionUpdateTask(AsyncTask):
|
||||
platform = data["messages"][0]["chat_info_platform"]
|
||||
user_id = data["messages"][0]["user_id"]
|
||||
cardname = data["messages"][0]["user_cardname"]
|
||||
|
||||
|
||||
is_known = await relationship_manager.is_known_some_one(platform, user_id)
|
||||
|
||||
if not is_known:
|
||||
logger.info(f"首次认识用户: {user_nickname}")
|
||||
await relationship_manager.first_knowing_some_one(platform, user_id, user_nickname, cardname)
|
||||
|
||||
|
||||
|
||||
logger.info(f"开始更新用户 {user_nickname} 的印象")
|
||||
await relationship_manager.update_person_impression(
|
||||
person_id=person_id,
|
||||
timestamp=last_bot_msg["time"],
|
||||
bot_engaged_messages=relevant_messages
|
||||
person_id=person_id, timestamp=last_bot_msg["time"], bot_engaged_messages=relevant_messages
|
||||
)
|
||||
|
||||
logger.debug("印象更新任务执行完成")
|
||||
|
||||
@@ -33,7 +33,7 @@ JSON_SERIALIZED_FIELDS = ["points", "forgotten_points", "info_list"]
|
||||
person_info_default = {
|
||||
"person_id": None,
|
||||
"person_name": None,
|
||||
"name_reason": None, # Corrected from person_name_reason to match common usage if intended
|
||||
"name_reason": None, # Corrected from person_name_reason to match common usage if intended
|
||||
"platform": "unknown",
|
||||
"user_id": "unknown",
|
||||
"nickname": "Unknown",
|
||||
@@ -42,11 +42,10 @@ person_info_default = {
|
||||
"last_know": None,
|
||||
# "user_cardname": None, # This field is not in Peewee model PersonInfo
|
||||
# "user_avatar": None, # This field is not in Peewee model PersonInfo
|
||||
"impression": None, # Corrected from persion_impression
|
||||
"impression": None, # Corrected from persion_impression
|
||||
"info_list": None,
|
||||
"points": None,
|
||||
"forgotten_points": None,
|
||||
|
||||
}
|
||||
|
||||
|
||||
@@ -126,7 +125,7 @@ class PersonInfoManager:
|
||||
for key, default_value in _person_info_default.items():
|
||||
if key in model_fields:
|
||||
final_data[key] = default_value
|
||||
|
||||
|
||||
# Override with provided data
|
||||
if data:
|
||||
for key, value in data.items():
|
||||
@@ -141,7 +140,7 @@ class PersonInfoManager:
|
||||
if key in final_data:
|
||||
if isinstance(final_data[key], (list, dict)):
|
||||
final_data[key] = json.dumps(final_data[key], ensure_ascii=False)
|
||||
elif final_data[key] is None: # Default for lists is [], store as "[]"
|
||||
elif final_data[key] is None: # Default for lists is [], store as "[]"
|
||||
final_data[key] = json.dumps([], ensure_ascii=False)
|
||||
# If it's already a string, assume it's valid JSON or a non-JSON string field
|
||||
|
||||
@@ -165,12 +164,12 @@ class PersonInfoManager:
|
||||
return
|
||||
|
||||
print(f"更新字段: {field_name},值: {value}")
|
||||
|
||||
|
||||
processed_value = value
|
||||
if field_name in JSON_SERIALIZED_FIELDS:
|
||||
if isinstance(value, (list, dict)):
|
||||
processed_value = json.dumps(value, ensure_ascii=False, indent=None)
|
||||
elif value is None: # Store None as "[]" for JSON list fields
|
||||
elif value is None: # Store None as "[]" for JSON list fields
|
||||
processed_value = json.dumps([], ensure_ascii=False, indent=None)
|
||||
# If value is already a string, assume it's pre-serialized or a non-JSON string.
|
||||
|
||||
@@ -180,7 +179,7 @@ class PersonInfoManager:
|
||||
setattr(record, f_name, val_to_set)
|
||||
record.save()
|
||||
return True, False # Found and updated, no creation needed
|
||||
return False, True # Not found, needs creation
|
||||
return False, True # Not found, needs creation
|
||||
|
||||
found, needs_creation = await asyncio.to_thread(_db_update_sync, person_id, field_name, processed_value)
|
||||
|
||||
@@ -190,15 +189,14 @@ class PersonInfoManager:
|
||||
# Ensure platform and user_id are present for context if available from 'data'
|
||||
# but primarily, set the field that triggered the update.
|
||||
# The create_person_info will handle defaults and serialization.
|
||||
creation_data[field_name] = value # Pass original value to create_person_info
|
||||
|
||||
creation_data[field_name] = value # Pass original value to create_person_info
|
||||
|
||||
# Ensure platform and user_id are in creation_data if available,
|
||||
# otherwise create_person_info will use defaults.
|
||||
if data and "platform" in data:
|
||||
creation_data["platform"] = data["platform"]
|
||||
creation_data["platform"] = data["platform"]
|
||||
if data and "user_id" in data:
|
||||
creation_data["user_id"] = data["user_id"]
|
||||
|
||||
creation_data["user_id"] = data["user_id"]
|
||||
|
||||
await self.create_person_info(person_id, creation_data)
|
||||
|
||||
@@ -233,7 +231,7 @@ class PersonInfoManager:
|
||||
|
||||
if isinstance(parsed_json, list) and parsed_json:
|
||||
parsed_json = parsed_json[0]
|
||||
|
||||
|
||||
if isinstance(parsed_json, dict):
|
||||
return parsed_json
|
||||
|
||||
@@ -249,11 +247,11 @@ class PersonInfoManager:
|
||||
# 处理空昵称的情况
|
||||
if not base_name or base_name.isspace():
|
||||
base_name = "空格"
|
||||
|
||||
|
||||
# 检查基础名称是否已存在
|
||||
if base_name not in self.person_name_list.values():
|
||||
return base_name
|
||||
|
||||
|
||||
# 如果存在,添加数字后缀
|
||||
counter = 1
|
||||
while True:
|
||||
@@ -331,9 +329,11 @@ class PersonInfoManager:
|
||||
if not is_duplicate:
|
||||
await self.update_one_field(person_id, "person_name", generated_nickname)
|
||||
await self.update_one_field(person_id, "name_reason", result.get("reason", "未提供理由"))
|
||||
|
||||
logger.info(f"成功给用户{user_nickname} {person_id} 取名 {generated_nickname},理由:{result.get('reason', '未提供理由')}")
|
||||
|
||||
|
||||
logger.info(
|
||||
f"成功给用户{user_nickname} {person_id} 取名 {generated_nickname},理由:{result.get('reason', '未提供理由')}"
|
||||
)
|
||||
|
||||
self.person_name_list[person_id] = generated_nickname
|
||||
return result
|
||||
else:
|
||||
@@ -379,7 +379,7 @@ class PersonInfoManager:
|
||||
"""获取指定用户指定字段的值"""
|
||||
default_value_for_field = person_info_default.get(field_name)
|
||||
if field_name in JSON_SERIALIZED_FIELDS and default_value_for_field is None:
|
||||
default_value_for_field = [] # Ensure JSON fields default to [] if not in DB
|
||||
default_value_for_field = [] # Ensure JSON fields default to [] if not in DB
|
||||
|
||||
def _db_get_value_sync(p_id: str, f_name: str):
|
||||
record = PersonInfo.get_or_none(PersonInfo.person_id == p_id)
|
||||
@@ -391,32 +391,32 @@ class PersonInfoManager:
|
||||
return json.loads(val)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"字段 {f_name} for {p_id} 包含无效JSON: {val}. 返回默认值.")
|
||||
return [] # Default for JSON fields on error
|
||||
elif val is None: # Field exists in DB but is None
|
||||
return [] # Default for JSON fields
|
||||
return [] # Default for JSON fields on error
|
||||
elif val is None: # Field exists in DB but is None
|
||||
return [] # Default for JSON fields
|
||||
# If val is already a list/dict (e.g. if somehow set without serialization)
|
||||
return val # Should ideally not happen if update_one_field is always used
|
||||
return val # Should ideally not happen if update_one_field is always used
|
||||
return val
|
||||
return None # Record not found
|
||||
return None # Record not found
|
||||
|
||||
try:
|
||||
value_from_db = await asyncio.to_thread(_db_get_value_sync, person_id, field_name)
|
||||
if value_from_db is not None:
|
||||
return value_from_db
|
||||
if field_name in person_info_default:
|
||||
return default_value_for_field
|
||||
return default_value_for_field
|
||||
logger.warning(f"字段 {field_name} 在 person_info_default 中未定义,且在数据库中未找到。")
|
||||
return None # Ultimate fallback
|
||||
return None # Ultimate fallback
|
||||
except Exception as e:
|
||||
logger.error(f"获取字段 {field_name} for {person_id} 时出错 (Peewee): {e}")
|
||||
# Fallback to default in case of any error during DB access
|
||||
if field_name in person_info_default:
|
||||
return default_value_for_field
|
||||
return None
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_value_sync(person_id: str, field_name: str):
|
||||
""" 同步获取指定用户指定字段的值 """
|
||||
"""同步获取指定用户指定字段的值"""
|
||||
default_value_for_field = person_info_default.get(field_name)
|
||||
if field_name in JSON_SERIALIZED_FIELDS and default_value_for_field is None:
|
||||
default_value_for_field = []
|
||||
@@ -430,12 +430,12 @@ class PersonInfoManager:
|
||||
return json.loads(val)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"字段 {field_name} for {person_id} 包含无效JSON: {val}. 返回默认值.")
|
||||
return []
|
||||
return []
|
||||
elif val is None:
|
||||
return []
|
||||
return val
|
||||
return val
|
||||
return val
|
||||
|
||||
|
||||
if field_name in person_info_default:
|
||||
return default_value_for_field
|
||||
logger.warning(f"字段 {field_name} 在 person_info_default 中未定义,且在数据库中未找到。")
|
||||
@@ -534,7 +534,7 @@ class PersonInfoManager:
|
||||
"last_know": int(datetime.datetime.now().timestamp()),
|
||||
"impression": None,
|
||||
"points": [],
|
||||
"forgotten_points": []
|
||||
"forgotten_points": [],
|
||||
}
|
||||
model_fields = PersonInfo._meta.fields.keys()
|
||||
filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields}
|
||||
|
||||
@@ -7,7 +7,6 @@ from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||
from src.manager.mood_manager import mood_manager
|
||||
from src.individuality.individuality import individuality
|
||||
import json
|
||||
from json_repair import repair_json
|
||||
from datetime import datetime
|
||||
@@ -90,9 +89,7 @@ class RelationshipManager:
|
||||
return is_known
|
||||
|
||||
@staticmethod
|
||||
async def first_knowing_some_one(
|
||||
platform: str, user_id: str, user_nickname: str, user_cardname: str
|
||||
):
|
||||
async def first_knowing_some_one(platform: str, user_id: str, user_nickname: str, user_cardname: str):
|
||||
"""判断是否认识某人"""
|
||||
person_id = person_info_manager.get_person_id(platform, user_id)
|
||||
# 生成唯一的 person_name
|
||||
@@ -112,7 +109,7 @@ class RelationshipManager:
|
||||
)
|
||||
# 尝试生成更好的名字
|
||||
# await person_info_manager.qv_person_name(
|
||||
# person_id=person_id, user_nickname=user_nickname, user_cardname=user_cardname, user_avatar=user_avatar
|
||||
# person_id=person_id, user_nickname=user_nickname, user_cardname=user_cardname, user_avatar=user_avatar
|
||||
# )
|
||||
|
||||
async def build_relationship_info(self, person, is_id: bool = False) -> str:
|
||||
@@ -124,26 +121,24 @@ class RelationshipManager:
|
||||
person_name = await person_info_manager.get_value(person_id, "person_name")
|
||||
if not person_name or person_name == "none":
|
||||
return ""
|
||||
impression = await person_info_manager.get_value(person_id, "impression")
|
||||
# impression = await person_info_manager.get_value(person_id, "impression")
|
||||
points = await person_info_manager.get_value(person_id, "points") or []
|
||||
|
||||
|
||||
if isinstance(points, str):
|
||||
try:
|
||||
points = ast.literal_eval(points)
|
||||
except (SyntaxError, ValueError):
|
||||
points = []
|
||||
|
||||
|
||||
random_points = random.sample(points, min(5, len(points))) if points else []
|
||||
|
||||
|
||||
nickname_str = await person_info_manager.get_value(person_id, "nickname")
|
||||
platform = await person_info_manager.get_value(person_id, "platform")
|
||||
relation_prompt = f"'{person_name}' ,ta在{platform}上的昵称是{nickname_str}。"
|
||||
|
||||
|
||||
# if impression:
|
||||
# relation_prompt += f"你对ta的印象是:{impression}。"
|
||||
# relation_prompt += f"你对ta的印象是:{impression}。"
|
||||
|
||||
|
||||
if random_points:
|
||||
for point in random_points:
|
||||
# print(f"point: {point}")
|
||||
@@ -151,13 +146,12 @@ class RelationshipManager:
|
||||
# print(f"point[0]: {point[0]}")
|
||||
point_str = f"时间:{point[2]}。内容:{point[0]}"
|
||||
relation_prompt += f"你记得{person_name}最近的点是:{point_str}。"
|
||||
|
||||
|
||||
|
||||
return relation_prompt
|
||||
|
||||
async def _update_list_field(self, person_id: str, field_name: str, new_items: list) -> None:
|
||||
"""更新列表类型的字段,将新项目添加到现有列表中
|
||||
|
||||
|
||||
Args:
|
||||
person_id: 用户ID
|
||||
field_name: 字段名称
|
||||
@@ -179,21 +173,21 @@ class RelationshipManager:
|
||||
"""
|
||||
person_name = await person_info_manager.get_value(person_id, "person_name")
|
||||
nickname = await person_info_manager.get_value(person_id, "nickname")
|
||||
|
||||
|
||||
alias_str = ", ".join(global_config.bot.alias_names)
|
||||
personality_block = individuality.get_personality_prompt(x_person=2, level=2)
|
||||
identity_block = individuality.get_identity_prompt(x_person=2, level=2)
|
||||
# personality_block = individuality.get_personality_prompt(x_person=2, level=2)
|
||||
# identity_block = individuality.get_identity_prompt(x_person=2, level=2)
|
||||
|
||||
user_messages = bot_engaged_messages
|
||||
|
||||
|
||||
current_time = datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
|
||||
# 匿名化消息
|
||||
# 创建用户名称映射
|
||||
name_mapping = {}
|
||||
current_user = "A"
|
||||
user_count = 1
|
||||
|
||||
|
||||
# 遍历消息,构建映射
|
||||
for msg in user_messages:
|
||||
await person_info_manager.get_or_create_person(
|
||||
@@ -206,37 +200,31 @@ class RelationshipManager:
|
||||
replace_platform = msg.get("chat_info_platform")
|
||||
replace_person_id = person_info_manager.get_person_id(replace_platform, replace_user_id)
|
||||
replace_person_name = await person_info_manager.get_value(replace_person_id, "person_name")
|
||||
|
||||
|
||||
# 跳过机器人自己
|
||||
if replace_user_id == global_config.bot.qq_account:
|
||||
name_mapping[f"{global_config.bot.nickname}"] = f"{global_config.bot.nickname}"
|
||||
continue
|
||||
|
||||
|
||||
# 跳过目标用户
|
||||
if replace_person_name == person_name:
|
||||
name_mapping[replace_person_name] = f"{person_name}"
|
||||
continue
|
||||
|
||||
|
||||
# 其他用户映射
|
||||
if replace_person_name not in name_mapping:
|
||||
if current_user > 'Z':
|
||||
current_user = 'A'
|
||||
if current_user > "Z":
|
||||
current_user = "A"
|
||||
user_count += 1
|
||||
name_mapping[replace_person_name] = f"用户{current_user}{user_count if user_count > 1 else ''}"
|
||||
current_user = chr(ord(current_user) + 1)
|
||||
|
||||
|
||||
|
||||
|
||||
readable_messages = self.build_focus_readable_messages(
|
||||
messages=user_messages,
|
||||
target_person_id=person_id
|
||||
)
|
||||
|
||||
readable_messages = self.build_focus_readable_messages(messages=user_messages, target_person_id=person_id)
|
||||
|
||||
for original_name, mapped_name in name_mapping.items():
|
||||
# print(f"original_name: {original_name}, mapped_name: {mapped_name}")
|
||||
readable_messages = readable_messages.replace(f"{original_name}", f"{mapped_name}")
|
||||
|
||||
|
||||
prompt = f"""
|
||||
你的名字是{global_config.bot.nickname},{global_config.bot.nickname}的别名是{alias_str}。
|
||||
请不要混淆你自己和{global_config.bot.nickname}和{person_name}。
|
||||
@@ -271,22 +259,22 @@ class RelationshipManager:
|
||||
"weight": 0
|
||||
}}
|
||||
"""
|
||||
|
||||
|
||||
# 调用LLM生成印象
|
||||
points, _ = await self.relationship_llm.generate_response_async(prompt=prompt)
|
||||
points = points.strip()
|
||||
|
||||
|
||||
# 还原用户名称
|
||||
for original_name, mapped_name in name_mapping.items():
|
||||
points = points.replace(mapped_name, original_name)
|
||||
|
||||
|
||||
# logger.info(f"prompt: {prompt}")
|
||||
# logger.info(f"points: {points}")
|
||||
|
||||
|
||||
if not points:
|
||||
logger.warning(f"未能从LLM获取 {person_name} 的新印象")
|
||||
return
|
||||
|
||||
|
||||
# 解析JSON并转换为元组列表
|
||||
try:
|
||||
points = repair_json(points)
|
||||
@@ -307,7 +295,7 @@ class RelationshipManager:
|
||||
except (KeyError, TypeError) as e:
|
||||
logger.error(f"处理points数据失败: {e}, points: {points}")
|
||||
return
|
||||
|
||||
|
||||
current_points = await person_info_manager.get_value(person_id, "points") or []
|
||||
if isinstance(current_points, str):
|
||||
try:
|
||||
@@ -318,7 +306,9 @@ class RelationshipManager:
|
||||
elif not isinstance(current_points, list):
|
||||
current_points = []
|
||||
current_points.extend(points_list)
|
||||
await person_info_manager.update_one_field(person_id, "points", json.dumps(current_points, ensure_ascii=False, indent=None))
|
||||
await person_info_manager.update_one_field(
|
||||
person_id, "points", json.dumps(current_points, ensure_ascii=False, indent=None)
|
||||
)
|
||||
|
||||
# 将新记录添加到现有记录中
|
||||
if isinstance(current_points, list):
|
||||
@@ -326,14 +316,14 @@ class RelationshipManager:
|
||||
for new_point in points_list:
|
||||
similar_points = []
|
||||
similar_indices = []
|
||||
|
||||
|
||||
# 在现有points中查找相似的点
|
||||
for i, existing_point in enumerate(current_points):
|
||||
# 使用组合的相似度检查方法
|
||||
if self.check_similarity(new_point[0], existing_point[0]):
|
||||
similar_points.append(existing_point)
|
||||
similar_indices.append(i)
|
||||
|
||||
|
||||
if similar_points:
|
||||
# 合并相似的点
|
||||
all_points = [new_point] + similar_points
|
||||
@@ -343,14 +333,14 @@ class RelationshipManager:
|
||||
total_weight = sum(p[1] for p in all_points)
|
||||
# 使用最长的描述
|
||||
longest_desc = max(all_points, key=lambda x: len(x[0]))[0]
|
||||
|
||||
|
||||
# 创建合并后的点
|
||||
merged_point = (longest_desc, total_weight, latest_time)
|
||||
|
||||
|
||||
# 从现有points中移除已合并的点
|
||||
for idx in sorted(similar_indices, reverse=True):
|
||||
current_points.pop(idx)
|
||||
|
||||
|
||||
# 添加合并后的点
|
||||
current_points.append(merged_point)
|
||||
else:
|
||||
@@ -359,7 +349,7 @@ class RelationshipManager:
|
||||
else:
|
||||
current_points = points_list
|
||||
|
||||
# 如果points超过10条,按权重随机选择多余的条目移动到forgotten_points
|
||||
# 如果points超过10条,按权重随机选择多余的条目移动到forgotten_points
|
||||
if len(current_points) > 10:
|
||||
# 获取现有forgotten_points
|
||||
forgotten_points = await person_info_manager.get_value(person_id, "forgotten_points") or []
|
||||
@@ -371,29 +361,29 @@ class RelationshipManager:
|
||||
forgotten_points = []
|
||||
elif not isinstance(forgotten_points, list):
|
||||
forgotten_points = []
|
||||
|
||||
|
||||
# 计算当前时间
|
||||
current_time = datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
|
||||
# 计算每个点的最终权重(原始权重 * 时间权重)
|
||||
weighted_points = []
|
||||
for point in current_points:
|
||||
time_weight = self.calculate_time_weight(point[2], current_time)
|
||||
final_weight = point[1] * time_weight
|
||||
weighted_points.append((point, final_weight))
|
||||
|
||||
|
||||
# 计算总权重
|
||||
total_weight = sum(w for _, w in weighted_points)
|
||||
|
||||
|
||||
# 按权重随机选择要保留的点
|
||||
remaining_points = []
|
||||
points_to_move = []
|
||||
|
||||
|
||||
# 对每个点进行随机选择
|
||||
for point, weight in weighted_points:
|
||||
# 计算保留概率(权重越高越可能保留)
|
||||
keep_probability = weight / total_weight
|
||||
|
||||
|
||||
if len(remaining_points) < 10:
|
||||
# 如果还没达到30条,直接保留
|
||||
remaining_points.append(point)
|
||||
@@ -407,28 +397,26 @@ class RelationshipManager:
|
||||
else:
|
||||
# 不保留这个点
|
||||
points_to_move.append(point)
|
||||
|
||||
|
||||
# 更新points和forgotten_points
|
||||
current_points = remaining_points
|
||||
forgotten_points.extend(points_to_move)
|
||||
|
||||
|
||||
# 检查forgotten_points是否达到5条
|
||||
if len(forgotten_points) >= 10:
|
||||
# 构建压缩总结提示词
|
||||
alias_str = ", ".join(global_config.bot.alias_names)
|
||||
|
||||
|
||||
# 按时间排序forgotten_points
|
||||
forgotten_points.sort(key=lambda x: x[2])
|
||||
|
||||
|
||||
# 构建points文本
|
||||
points_text = "\n".join([
|
||||
f"时间:{point[2]}\n权重:{point[1]}\n内容:{point[0]}"
|
||||
for point in forgotten_points
|
||||
])
|
||||
|
||||
|
||||
points_text = "\n".join(
|
||||
[f"时间:{point[2]}\n权重:{point[1]}\n内容:{point[0]}" for point in forgotten_points]
|
||||
)
|
||||
|
||||
impression = await person_info_manager.get_value(person_id, "impression") or ""
|
||||
|
||||
|
||||
compress_prompt = f"""
|
||||
你的名字是{global_config.bot.nickname},{global_config.bot.nickname}的别名是{alias_str}。
|
||||
请不要混淆你自己和{global_config.bot.nickname}和{person_name}。
|
||||
@@ -449,88 +437,85 @@ class RelationshipManager:
|
||||
"""
|
||||
# 调用LLM生成压缩总结
|
||||
compressed_summary, _ = await self.relationship_llm.generate_response_async(prompt=compress_prompt)
|
||||
|
||||
|
||||
current_time = datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
|
||||
compressed_summary = f"截至{current_time},你对{person_name}的了解:{compressed_summary}"
|
||||
|
||||
|
||||
await person_info_manager.update_one_field(person_id, "impression", compressed_summary)
|
||||
|
||||
|
||||
forgotten_points = []
|
||||
|
||||
|
||||
# 这句代码的作用是:将更新后的 forgotten_points(遗忘的记忆点)列表,序列化为 JSON 字符串后,写回到数据库中的 forgotten_points 字段
|
||||
await person_info_manager.update_one_field(person_id, "forgotten_points", json.dumps(forgotten_points, ensure_ascii=False, indent=None))
|
||||
|
||||
await person_info_manager.update_one_field(
|
||||
person_id, "forgotten_points", json.dumps(forgotten_points, ensure_ascii=False, indent=None)
|
||||
)
|
||||
|
||||
# 更新数据库
|
||||
await person_info_manager.update_one_field(person_id, "points", json.dumps(current_points, ensure_ascii=False, indent=None))
|
||||
await person_info_manager.update_one_field(
|
||||
person_id, "points", json.dumps(current_points, ensure_ascii=False, indent=None)
|
||||
)
|
||||
know_times = await person_info_manager.get_value(person_id, "know_times") or 0
|
||||
await person_info_manager.update_one_field(person_id, "know_times", know_times + 1)
|
||||
await person_info_manager.update_one_field(person_id, "last_know", timestamp)
|
||||
|
||||
|
||||
logger.info(f"印象更新完成 for {person_name}")
|
||||
|
||||
|
||||
|
||||
|
||||
def build_focus_readable_messages(self, messages: list, target_person_id: str = None) -> str:
|
||||
"""格式化消息,只保留目标用户和bot消息附近的内容"""
|
||||
# 找到目标用户和bot的消息索引
|
||||
target_indices = []
|
||||
for i, msg in enumerate(messages):
|
||||
user_id = msg.get("user_id")
|
||||
platform = msg.get("chat_info_platform")
|
||||
person_id = person_info_manager.get_person_id(platform, user_id)
|
||||
if person_id == target_person_id:
|
||||
target_indices.append(i)
|
||||
|
||||
if not target_indices:
|
||||
return ""
|
||||
|
||||
# 获取需要保留的消息索引
|
||||
keep_indices = set()
|
||||
for idx in target_indices:
|
||||
# 获取前后5条消息的索引
|
||||
start_idx = max(0, idx - 5)
|
||||
end_idx = min(len(messages), idx + 6)
|
||||
keep_indices.update(range(start_idx, end_idx))
|
||||
|
||||
print(keep_indices)
|
||||
|
||||
# 将索引排序
|
||||
keep_indices = sorted(list(keep_indices))
|
||||
|
||||
# 按顺序构建消息组
|
||||
message_groups = []
|
||||
current_group = []
|
||||
|
||||
for i in range(len(messages)):
|
||||
if i in keep_indices:
|
||||
current_group.append(messages[i])
|
||||
elif current_group:
|
||||
# 如果当前组不为空,且遇到不保留的消息,则结束当前组
|
||||
if current_group:
|
||||
message_groups.append(current_group)
|
||||
current_group = []
|
||||
|
||||
# 添加最后一组
|
||||
if current_group:
|
||||
message_groups.append(current_group)
|
||||
|
||||
# 构建最终的消息文本
|
||||
result = []
|
||||
for i, group in enumerate(message_groups):
|
||||
if i > 0:
|
||||
result.append("...")
|
||||
group_text = build_readable_messages(
|
||||
messages=group,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
truncate=False
|
||||
)
|
||||
result.append(group_text)
|
||||
|
||||
return "\n".join(result)
|
||||
|
||||
"""格式化消息,只保留目标用户和bot消息附近的内容"""
|
||||
# 找到目标用户和bot的消息索引
|
||||
target_indices = []
|
||||
for i, msg in enumerate(messages):
|
||||
user_id = msg.get("user_id")
|
||||
platform = msg.get("chat_info_platform")
|
||||
person_id = person_info_manager.get_person_id(platform, user_id)
|
||||
if person_id == target_person_id:
|
||||
target_indices.append(i)
|
||||
|
||||
if not target_indices:
|
||||
return ""
|
||||
|
||||
# 获取需要保留的消息索引
|
||||
keep_indices = set()
|
||||
for idx in target_indices:
|
||||
# 获取前后5条消息的索引
|
||||
start_idx = max(0, idx - 5)
|
||||
end_idx = min(len(messages), idx + 6)
|
||||
keep_indices.update(range(start_idx, end_idx))
|
||||
|
||||
print(keep_indices)
|
||||
|
||||
# 将索引排序
|
||||
keep_indices = sorted(list(keep_indices))
|
||||
|
||||
# 按顺序构建消息组
|
||||
message_groups = []
|
||||
current_group = []
|
||||
|
||||
for i in range(len(messages)):
|
||||
if i in keep_indices:
|
||||
current_group.append(messages[i])
|
||||
elif current_group:
|
||||
# 如果当前组不为空,且遇到不保留的消息,则结束当前组
|
||||
if current_group:
|
||||
message_groups.append(current_group)
|
||||
current_group = []
|
||||
|
||||
# 添加最后一组
|
||||
if current_group:
|
||||
message_groups.append(current_group)
|
||||
|
||||
# 构建最终的消息文本
|
||||
result = []
|
||||
for i, group in enumerate(message_groups):
|
||||
if i > 0:
|
||||
result.append("...")
|
||||
group_text = build_readable_messages(
|
||||
messages=group, replace_bot_name=True, timestamp_mode="normal_no_YMD", truncate=False
|
||||
)
|
||||
result.append(group_text)
|
||||
|
||||
return "\n".join(result)
|
||||
|
||||
def calculate_time_weight(self, point_time: str, current_time: str) -> float:
|
||||
"""计算基于时间的权重系数"""
|
||||
try:
|
||||
@@ -538,7 +523,7 @@ class RelationshipManager:
|
||||
current_timestamp = datetime.strptime(current_time, "%Y-%m-%d %H:%M:%S")
|
||||
time_diff = current_timestamp - point_timestamp
|
||||
hours_diff = time_diff.total_seconds() / 3600
|
||||
|
||||
|
||||
if hours_diff <= 1: # 1小时内
|
||||
return 1.0
|
||||
elif hours_diff <= 24: # 1-24小时
|
||||
@@ -564,18 +549,18 @@ class RelationshipManager:
|
||||
s1 = " ".join(str(x) for x in s1)
|
||||
if isinstance(s2, list):
|
||||
s2 = " ".join(str(x) for x in s2)
|
||||
|
||||
|
||||
# 转换为字符串类型
|
||||
s1 = str(s1)
|
||||
s2 = str(s2)
|
||||
|
||||
|
||||
# 1. 使用 jieba 进行分词
|
||||
s1_words = " ".join(jieba.cut(s1))
|
||||
s2_words = " ".join(jieba.cut(s2))
|
||||
|
||||
|
||||
# 2. 将两句话放入一个列表中
|
||||
corpus = [s1_words, s2_words]
|
||||
|
||||
|
||||
# 3. 创建 TF-IDF 向量化器并进行计算
|
||||
try:
|
||||
vectorizer = TfidfVectorizer()
|
||||
@@ -586,7 +571,7 @@ class RelationshipManager:
|
||||
|
||||
# 4. 计算余弦相似度
|
||||
similarity_matrix = cosine_similarity(tfidf_matrix)
|
||||
|
||||
|
||||
# 返回 s1 和 s2 的相似度
|
||||
return similarity_matrix[0, 1]
|
||||
|
||||
@@ -599,20 +584,20 @@ class RelationshipManager:
|
||||
def check_similarity(self, text1, text2, tfidf_threshold=0.5, seq_threshold=0.6):
|
||||
"""
|
||||
使用两种方法检查文本相似度,只要其中一种方法达到阈值就认为是相似的。
|
||||
|
||||
|
||||
Args:
|
||||
text1: 第一个文本
|
||||
text2: 第二个文本
|
||||
tfidf_threshold: TF-IDF相似度阈值
|
||||
seq_threshold: SequenceMatcher相似度阈值
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 如果任一方法达到阈值则返回True
|
||||
"""
|
||||
# 计算两种相似度
|
||||
tfidf_sim = self.tfidf_similarity(text1, text2)
|
||||
seq_sim = self.sequence_similarity(text1, text2)
|
||||
|
||||
|
||||
# 只要其中一种方法达到阈值就认为是相似的
|
||||
return tfidf_sim > tfidf_threshold or seq_sim > seq_threshold
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ DEFAULT_CONFIG = {
|
||||
"default_guidance_scale": 2.5,
|
||||
"default_seed": 42,
|
||||
"cache_enabled": True,
|
||||
"cache_max_size": 10
|
||||
"cache_max_size": 10,
|
||||
}
|
||||
|
||||
|
||||
@@ -49,37 +49,37 @@ def validate_and_fix_config(config_path: str) -> bool:
|
||||
try:
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config = toml.load(f)
|
||||
|
||||
|
||||
# 检查缺失的配置项
|
||||
missing_keys = []
|
||||
fixed = False
|
||||
|
||||
|
||||
for key, default_value in DEFAULT_CONFIG.items():
|
||||
if key not in config:
|
||||
missing_keys.append(key)
|
||||
config[key] = default_value
|
||||
fixed = True
|
||||
logger.info(f"添加缺失的配置项: {key} = {default_value}")
|
||||
|
||||
|
||||
# 验证配置值的类型和范围
|
||||
if isinstance(config.get("default_guidance_scale"), (int, float)):
|
||||
if not 0.1 <= config["default_guidance_scale"] <= 20.0:
|
||||
config["default_guidance_scale"] = 2.5
|
||||
fixed = True
|
||||
logger.info("修复无效的 default_guidance_scale 值")
|
||||
|
||||
|
||||
if isinstance(config.get("default_seed"), (int, float)):
|
||||
config["default_seed"] = int(config["default_seed"])
|
||||
else:
|
||||
config["default_seed"] = 42
|
||||
fixed = True
|
||||
logger.info("修复无效的 default_seed 值")
|
||||
|
||||
|
||||
if config.get("cache_max_size") and not isinstance(config["cache_max_size"], int):
|
||||
config["cache_max_size"] = 10
|
||||
fixed = True
|
||||
logger.info("修复无效的 cache_max_size 值")
|
||||
|
||||
|
||||
# 如果有修复,写回文件
|
||||
if fixed:
|
||||
# 创建备份
|
||||
@@ -87,14 +87,14 @@ def validate_and_fix_config(config_path: str) -> bool:
|
||||
if os.path.exists(config_path):
|
||||
os.rename(config_path, backup_path)
|
||||
logger.info(f"已创建配置备份: {backup_path}")
|
||||
|
||||
|
||||
# 写入修复后的配置
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
toml.dump(config, f)
|
||||
logger.info(f"配置文件已修复: {config_path}")
|
||||
|
||||
|
||||
return True
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"验证配置文件时出错: {e}")
|
||||
return False
|
||||
|
||||
@@ -37,15 +37,15 @@ class PicAction(PluginAction):
|
||||
]
|
||||
enable_plugin = False
|
||||
action_config_file_name = "pic_action_config.toml"
|
||||
|
||||
|
||||
# 激活类型设置
|
||||
focus_activation_type = ActionActivationType.LLM_JUDGE # Focus模式使用LLM判定,精确理解需求
|
||||
normal_activation_type = ActionActivationType.KEYWORD # Normal模式使用关键词激活,快速响应
|
||||
|
||||
normal_activation_type = ActionActivationType.KEYWORD # Normal模式使用关键词激活,快速响应
|
||||
|
||||
# 关键词设置(用于Normal模式)
|
||||
activation_keywords = ["画", "绘制", "生成图片", "画图", "draw", "paint", "图片生成"]
|
||||
keyword_case_sensitive = False
|
||||
|
||||
|
||||
# LLM判定提示词(用于Focus模式)
|
||||
llm_judge_prompt = """
|
||||
判定是否需要使用图片生成动作的条件:
|
||||
@@ -67,31 +67,31 @@ class PicAction(PluginAction):
|
||||
4. 技术讨论中提到绘图概念但无生成需求
|
||||
5. 用户明确表示不需要图片时
|
||||
"""
|
||||
|
||||
|
||||
# Random激活概率(备用)
|
||||
random_activation_probability = 0.15 # 适中概率,图片生成比较有趣
|
||||
|
||||
|
||||
# 简单的请求缓存,避免短时间内重复请求
|
||||
_request_cache = {}
|
||||
_cache_max_size = 10
|
||||
|
||||
|
||||
# 模式启用设置 - 图片生成在所有模式下可用
|
||||
mode_enable = ChatMode.ALL
|
||||
|
||||
|
||||
# 并行执行设置 - 图片生成可以与回复并行执行,不覆盖回复内容
|
||||
parallel_action = False
|
||||
|
||||
|
||||
@classmethod
|
||||
def _get_cache_key(cls, description: str, model: str, size: str) -> str:
|
||||
"""生成缓存键"""
|
||||
return f"{description[:100]}|{model}|{size}" # 限制描述长度避免键过长
|
||||
|
||||
|
||||
@classmethod
|
||||
def _cleanup_cache(cls):
|
||||
"""清理缓存,保持大小在限制内"""
|
||||
if len(cls._request_cache) > cls._cache_max_size:
|
||||
# 简单的FIFO策略,移除最旧的条目
|
||||
keys_to_remove = list(cls._request_cache.keys())[:-cls._cache_max_size//2]
|
||||
keys_to_remove = list(cls._request_cache.keys())[: -cls._cache_max_size // 2]
|
||||
for key in keys_to_remove:
|
||||
del cls._request_cache[key]
|
||||
|
||||
@@ -169,7 +169,7 @@ class PicAction(PluginAction):
|
||||
cached_result = self._request_cache[cache_key]
|
||||
logger.info(f"{self.log_prefix} 使用缓存的图片结果")
|
||||
await self.send_message_by_expressor("我之前画过类似的图片,用之前的结果~")
|
||||
|
||||
|
||||
# 直接发送缓存的结果
|
||||
send_success = await self.send_message(type="image", data=cached_result)
|
||||
if send_success:
|
||||
@@ -258,7 +258,7 @@ class PicAction(PluginAction):
|
||||
# 缓存成功的结果
|
||||
self._request_cache[cache_key] = base64_image_string
|
||||
self._cleanup_cache()
|
||||
|
||||
|
||||
await self.send_message_by_expressor("图片表情已发送!")
|
||||
return True, "图片表情已发送"
|
||||
else:
|
||||
@@ -370,7 +370,7 @@ class PicAction(PluginAction):
|
||||
def _validate_image_size(self, image_size: str) -> bool:
|
||||
"""验证图片尺寸格式"""
|
||||
try:
|
||||
width, height = map(int, image_size.split('x'))
|
||||
width, height = map(int, image_size.split("x"))
|
||||
return 100 <= width <= 10000 and 100 <= height <= 10000
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
|
||||
@@ -11,4 +11,4 @@
|
||||
- 用户输入特定格式的命令时触发
|
||||
- 通过命令前缀(如/)快速执行特定功能
|
||||
- 提供快速响应的交互方式
|
||||
"""
|
||||
"""
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""示例命令包
|
||||
|
||||
包含示例命令的实现
|
||||
"""
|
||||
"""
|
||||
|
||||
@@ -5,27 +5,28 @@ import random
|
||||
|
||||
logger = get_logger("custom_prefix_command")
|
||||
|
||||
|
||||
@register_command
|
||||
class DiceCommand(BaseCommand):
|
||||
"""骰子命令,使用!前缀而不是/前缀"""
|
||||
|
||||
|
||||
command_name = "dice"
|
||||
command_description = "骰子命令,随机生成1-6的数字"
|
||||
command_pattern = r"^[!!](?:dice|骰子)(?:\s+(?P<count>\d+))?$" # 匹配 !dice 或 !骰子,可选参数为骰子数量
|
||||
command_help = "使用方法: !dice [数量] 或 !骰子 [数量] - 掷骰子,默认掷1个"
|
||||
command_examples = ["!dice", "!骰子", "!dice 3", "!骰子 5"]
|
||||
enable_command = True
|
||||
|
||||
|
||||
async def execute(self) -> Tuple[bool, Optional[str]]:
|
||||
"""执行骰子命令
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple[bool, Optional[str]]: (是否执行成功, 回复消息)
|
||||
"""
|
||||
try:
|
||||
# 获取骰子数量,默认为1
|
||||
count_str = self.matched_groups.get("count")
|
||||
|
||||
|
||||
# 确保count_str不为None
|
||||
if count_str is None:
|
||||
count = 1 # 默认值
|
||||
@@ -38,10 +39,10 @@ class DiceCommand(BaseCommand):
|
||||
return False, "一次最多只能掷10个骰子"
|
||||
except ValueError:
|
||||
return False, "骰子数量必须是整数"
|
||||
|
||||
|
||||
# 生成随机数
|
||||
results = [random.randint(1, 6) for _ in range(count)]
|
||||
|
||||
|
||||
# 构建回复消息
|
||||
if count == 1:
|
||||
message = f"🎲 掷出了 {results[0]} 点"
|
||||
@@ -49,10 +50,10 @@ class DiceCommand(BaseCommand):
|
||||
dice_results = ", ".join(map(str, results))
|
||||
total = sum(results)
|
||||
message = f"🎲 掷出了 {count} 个骰子: [{dice_results}],总点数: {total}"
|
||||
|
||||
|
||||
logger.info(f"{self.log_prefix} 执行骰子命令: {message}")
|
||||
return True, message
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 执行骰子命令时出错: {e}")
|
||||
return False, f"执行命令时出错: {str(e)}"
|
||||
return False, f"执行命令时出错: {str(e)}"
|
||||
|
||||
@@ -4,90 +4,86 @@ from typing import Tuple, Optional
|
||||
|
||||
logger = get_logger("help_command")
|
||||
|
||||
|
||||
@register_command
|
||||
class HelpCommand(BaseCommand):
|
||||
"""帮助命令,显示所有可用命令的帮助信息"""
|
||||
|
||||
|
||||
command_name = "help"
|
||||
command_description = "显示所有可用命令的帮助信息"
|
||||
command_pattern = r"^/help(?:\s+(?P<command>\w+))?$" # 匹配 /help 或 /help 命令名
|
||||
command_help = "使用方法: /help [命令名] - 显示所有命令或特定命令的帮助信息"
|
||||
command_examples = ["/help", "/help echo"]
|
||||
enable_command = True
|
||||
|
||||
|
||||
async def execute(self) -> Tuple[bool, Optional[str]]:
|
||||
"""执行帮助命令
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple[bool, Optional[str]]: (是否执行成功, 回复消息)
|
||||
"""
|
||||
try:
|
||||
# 获取匹配到的命令名(如果有)
|
||||
command_name = self.matched_groups.get("command")
|
||||
|
||||
|
||||
# 如果指定了命令名,显示该命令的详细帮助
|
||||
if command_name:
|
||||
logger.info(f"{self.log_prefix} 查询命令帮助: {command_name}")
|
||||
return self._show_command_help(command_name)
|
||||
|
||||
|
||||
# 否则,显示所有命令的简要帮助
|
||||
logger.info(f"{self.log_prefix} 查询所有命令帮助")
|
||||
return self._show_all_commands()
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 执行帮助命令时出错: {e}")
|
||||
return False, f"执行命令时出错: {str(e)}"
|
||||
|
||||
|
||||
def _show_command_help(self, command_name: str) -> Tuple[bool, str]:
|
||||
"""显示特定命令的详细帮助信息
|
||||
|
||||
|
||||
Args:
|
||||
command_name: 命令名称
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: (是否执行成功, 回复消息)
|
||||
"""
|
||||
# 查找命令
|
||||
command_cls = _COMMAND_REGISTRY.get(command_name)
|
||||
|
||||
|
||||
if not command_cls:
|
||||
return False, f"未找到命令: {command_name}"
|
||||
|
||||
|
||||
# 获取命令信息
|
||||
description = getattr(command_cls, "command_description", "无描述")
|
||||
help_text = getattr(command_cls, "command_help", "无帮助信息")
|
||||
examples = getattr(command_cls, "command_examples", [])
|
||||
|
||||
|
||||
# 构建帮助信息
|
||||
help_info = [
|
||||
f"【命令】: {command_name}",
|
||||
f"【描述】: {description}",
|
||||
f"【用法】: {help_text}"
|
||||
]
|
||||
|
||||
help_info = [f"【命令】: {command_name}", f"【描述】: {description}", f"【用法】: {help_text}"]
|
||||
|
||||
# 添加示例
|
||||
if examples:
|
||||
help_info.append("【示例】:")
|
||||
for example in examples:
|
||||
help_info.append(f" {example}")
|
||||
|
||||
|
||||
return True, "\n".join(help_info)
|
||||
|
||||
|
||||
def _show_all_commands(self) -> Tuple[bool, str]:
|
||||
"""显示所有可用命令的简要帮助信息
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: (是否执行成功, 回复消息)
|
||||
"""
|
||||
# 获取所有已启用的命令
|
||||
enabled_commands = {
|
||||
name: cls for name, cls in _COMMAND_REGISTRY.items()
|
||||
if getattr(cls, "enable_command", True)
|
||||
name: cls for name, cls in _COMMAND_REGISTRY.items() if getattr(cls, "enable_command", True)
|
||||
}
|
||||
|
||||
|
||||
if not enabled_commands:
|
||||
return True, "当前没有可用的命令"
|
||||
|
||||
|
||||
# 构建命令列表
|
||||
command_list = ["可用命令列表:"]
|
||||
for name, cls in sorted(enabled_commands.items()):
|
||||
@@ -107,9 +103,9 @@ class HelpCommand(BaseCommand):
|
||||
else:
|
||||
# 默认使用/name作为前缀
|
||||
prefix = f"/{name}"
|
||||
|
||||
|
||||
command_list.append(f"{prefix} - {description}")
|
||||
|
||||
|
||||
command_list.append("\n使用 /help <命令名> 获取特定命令的详细帮助")
|
||||
|
||||
return True, "\n".join(command_list)
|
||||
|
||||
return True, "\n".join(command_list)
|
||||
|
||||
@@ -1,43 +1,43 @@
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.chat.command.command_handler import BaseCommand, register_command
|
||||
from typing import Tuple, Optional
|
||||
import json
|
||||
|
||||
logger = get_logger("message_info_command")
|
||||
|
||||
|
||||
@register_command
|
||||
class MessageInfoCommand(BaseCommand):
|
||||
"""消息信息查看命令,展示发送命令的原始消息和相关信息"""
|
||||
|
||||
|
||||
command_name = "msginfo"
|
||||
command_description = "查看发送命令的原始消息信息"
|
||||
command_pattern = r"^/msginfo(?:\s+(?P<detail>full|simple))?$"
|
||||
command_help = "使用方法: /msginfo [full|simple] - 查看当前消息的详细信息"
|
||||
command_examples = ["/msginfo", "/msginfo full", "/msginfo simple"]
|
||||
enable_command = True
|
||||
|
||||
|
||||
async def execute(self) -> Tuple[bool, Optional[str]]:
|
||||
"""执行消息信息查看命令"""
|
||||
try:
|
||||
detail_level = self.matched_groups.get("detail", "simple")
|
||||
|
||||
|
||||
logger.info(f"{self.log_prefix} 查看消息信息,详细级别: {detail_level}")
|
||||
|
||||
|
||||
if detail_level == "full":
|
||||
info_text = self._get_full_message_info()
|
||||
else:
|
||||
info_text = self._get_simple_message_info()
|
||||
|
||||
|
||||
return True, info_text
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 获取消息信息时出错: {e}")
|
||||
return False, f"获取消息信息失败: {str(e)}"
|
||||
|
||||
|
||||
def _get_simple_message_info(self) -> str:
|
||||
"""获取简化的消息信息"""
|
||||
message = self.message
|
||||
|
||||
|
||||
# 基础信息
|
||||
info_lines = [
|
||||
"📨 消息信息概览",
|
||||
@@ -45,157 +45,181 @@ class MessageInfoCommand(BaseCommand):
|
||||
f"⏰ 时间: {message.message_info.time}",
|
||||
f"🌐 平台: {message.message_info.platform}",
|
||||
]
|
||||
|
||||
|
||||
# 发送者信息
|
||||
user = message.message_info.user_info
|
||||
info_lines.extend([
|
||||
"",
|
||||
"👤 发送者信息:",
|
||||
f" 用户ID: {user.user_id}",
|
||||
f" 昵称: {user.user_nickname}",
|
||||
f" 群名片: {user.user_cardname or '无'}",
|
||||
])
|
||||
|
||||
info_lines.extend(
|
||||
[
|
||||
"",
|
||||
"👤 发送者信息:",
|
||||
f" 用户ID: {user.user_id}",
|
||||
f" 昵称: {user.user_nickname}",
|
||||
f" 群名片: {user.user_cardname or '无'}",
|
||||
]
|
||||
)
|
||||
|
||||
# 群聊信息(如果是群聊)
|
||||
if message.message_info.group_info:
|
||||
group = message.message_info.group_info
|
||||
info_lines.extend([
|
||||
"",
|
||||
"👥 群聊信息:",
|
||||
f" 群ID: {group.group_id}",
|
||||
f" 群名: {group.group_name or '未知'}",
|
||||
])
|
||||
info_lines.extend(
|
||||
[
|
||||
"",
|
||||
"👥 群聊信息:",
|
||||
f" 群ID: {group.group_id}",
|
||||
f" 群名: {group.group_name or '未知'}",
|
||||
]
|
||||
)
|
||||
else:
|
||||
info_lines.extend([
|
||||
"",
|
||||
"💬 消息类型: 私聊消息",
|
||||
])
|
||||
|
||||
info_lines.extend(
|
||||
[
|
||||
"",
|
||||
"💬 消息类型: 私聊消息",
|
||||
]
|
||||
)
|
||||
|
||||
# 消息内容
|
||||
info_lines.extend([
|
||||
"",
|
||||
"📝 消息内容:",
|
||||
f" 原始文本: {message.processed_plain_text}",
|
||||
f" 是否表情: {'是' if getattr(message, 'is_emoji', False) else '否'}",
|
||||
])
|
||||
|
||||
# 聊天流信息
|
||||
if hasattr(message, 'chat_stream') and message.chat_stream:
|
||||
chat_stream = message.chat_stream
|
||||
info_lines.extend([
|
||||
info_lines.extend(
|
||||
[
|
||||
"",
|
||||
"🔄 聊天流信息:",
|
||||
f" 流ID: {chat_stream.stream_id}",
|
||||
f" 是否激活: {'是' if chat_stream.is_active else '否'}",
|
||||
])
|
||||
|
||||
"📝 消息内容:",
|
||||
f" 原始文本: {message.processed_plain_text}",
|
||||
f" 是否表情: {'是' if getattr(message, 'is_emoji', False) else '否'}",
|
||||
]
|
||||
)
|
||||
|
||||
# 聊天流信息
|
||||
if hasattr(message, "chat_stream") and message.chat_stream:
|
||||
chat_stream = message.chat_stream
|
||||
info_lines.extend(
|
||||
[
|
||||
"",
|
||||
"🔄 聊天流信息:",
|
||||
f" 流ID: {chat_stream.stream_id}",
|
||||
f" 是否激活: {'是' if chat_stream.is_active else '否'}",
|
||||
]
|
||||
)
|
||||
|
||||
return "\n".join(info_lines)
|
||||
|
||||
|
||||
def _get_full_message_info(self) -> str:
|
||||
"""获取完整的消息信息(包含技术细节)"""
|
||||
message = self.message
|
||||
|
||||
|
||||
info_lines = [
|
||||
"📨 完整消息信息",
|
||||
"=" * 40,
|
||||
]
|
||||
|
||||
|
||||
# 消息基础信息
|
||||
info_lines.extend([
|
||||
"",
|
||||
"🔍 基础消息信息:",
|
||||
f" 消息ID: {message.message_info.message_id}",
|
||||
f" 时间戳: {message.message_info.time}",
|
||||
f" 平台: {message.message_info.platform}",
|
||||
f" 处理后文本: {message.processed_plain_text}",
|
||||
f" 详细文本: {message.detailed_plain_text[:100]}{'...' if len(message.detailed_plain_text) > 100 else ''}",
|
||||
])
|
||||
|
||||
info_lines.extend(
|
||||
[
|
||||
"",
|
||||
"🔍 基础消息信息:",
|
||||
f" 消息ID: {message.message_info.message_id}",
|
||||
f" 时间戳: {message.message_info.time}",
|
||||
f" 平台: {message.message_info.platform}",
|
||||
f" 处理后文本: {message.processed_plain_text}",
|
||||
f" 详细文本: {message.detailed_plain_text[:100]}{'...' if len(message.detailed_plain_text) > 100 else ''}",
|
||||
]
|
||||
)
|
||||
|
||||
# 用户详细信息
|
||||
user = message.message_info.user_info
|
||||
info_lines.extend([
|
||||
"",
|
||||
"👤 发送者详细信息:",
|
||||
f" 用户ID: {user.user_id}",
|
||||
f" 昵称: {user.user_nickname}",
|
||||
f" 群名片: {user.user_cardname or '无'}",
|
||||
f" 平台: {user.platform}",
|
||||
])
|
||||
|
||||
info_lines.extend(
|
||||
[
|
||||
"",
|
||||
"👤 发送者详细信息:",
|
||||
f" 用户ID: {user.user_id}",
|
||||
f" 昵称: {user.user_nickname}",
|
||||
f" 群名片: {user.user_cardname or '无'}",
|
||||
f" 平台: {user.platform}",
|
||||
]
|
||||
)
|
||||
|
||||
# 群聊详细信息
|
||||
if message.message_info.group_info:
|
||||
group = message.message_info.group_info
|
||||
info_lines.extend([
|
||||
"",
|
||||
"👥 群聊详细信息:",
|
||||
f" 群ID: {group.group_id}",
|
||||
f" 群名: {group.group_name or '未知'}",
|
||||
f" 平台: {group.platform}",
|
||||
])
|
||||
info_lines.extend(
|
||||
[
|
||||
"",
|
||||
"👥 群聊详细信息:",
|
||||
f" 群ID: {group.group_id}",
|
||||
f" 群名: {group.group_name or '未知'}",
|
||||
f" 平台: {group.platform}",
|
||||
]
|
||||
)
|
||||
else:
|
||||
info_lines.append("\n💬 消息类型: 私聊消息")
|
||||
|
||||
|
||||
# 消息段信息
|
||||
if message.message_segment:
|
||||
info_lines.extend([
|
||||
"",
|
||||
"📦 消息段信息:",
|
||||
f" 类型: {message.message_segment.type}",
|
||||
f" 数据类型: {type(message.message_segment.data).__name__}",
|
||||
f" 数据预览: {str(message.message_segment.data)[:200]}{'...' if len(str(message.message_segment.data)) > 200 else ''}",
|
||||
])
|
||||
|
||||
info_lines.extend(
|
||||
[
|
||||
"",
|
||||
"📦 消息段信息:",
|
||||
f" 类型: {message.message_segment.type}",
|
||||
f" 数据类型: {type(message.message_segment.data).__name__}",
|
||||
f" 数据预览: {str(message.message_segment.data)[:200]}{'...' if len(str(message.message_segment.data)) > 200 else ''}",
|
||||
]
|
||||
)
|
||||
|
||||
# 聊天流详细信息
|
||||
if hasattr(message, 'chat_stream') and message.chat_stream:
|
||||
if hasattr(message, "chat_stream") and message.chat_stream:
|
||||
chat_stream = message.chat_stream
|
||||
info_lines.extend([
|
||||
"",
|
||||
"🔄 聊天流详细信息:",
|
||||
f" 流ID: {chat_stream.stream_id}",
|
||||
f" 平台: {chat_stream.platform}",
|
||||
f" 是否激活: {'是' if chat_stream.is_active else '否'}",
|
||||
f" 用户信息: {chat_stream.user_info.user_nickname} ({chat_stream.user_info.user_id})",
|
||||
f" 群信息: {getattr(chat_stream.group_info, 'group_name', '私聊') if chat_stream.group_info else '私聊'}",
|
||||
])
|
||||
|
||||
info_lines.extend(
|
||||
[
|
||||
"",
|
||||
"🔄 聊天流详细信息:",
|
||||
f" 流ID: {chat_stream.stream_id}",
|
||||
f" 平台: {chat_stream.platform}",
|
||||
f" 是否激活: {'是' if chat_stream.is_active else '否'}",
|
||||
f" 用户信息: {chat_stream.user_info.user_nickname} ({chat_stream.user_info.user_id})",
|
||||
f" 群信息: {getattr(chat_stream.group_info, 'group_name', '私聊') if chat_stream.group_info else '私聊'}",
|
||||
]
|
||||
)
|
||||
|
||||
# 回复信息
|
||||
if hasattr(message, 'reply') and message.reply:
|
||||
info_lines.extend([
|
||||
"",
|
||||
"↩️ 回复信息:",
|
||||
f" 回复消息ID: {message.reply.message_info.message_id}",
|
||||
f" 回复内容: {message.reply.processed_plain_text[:100]}{'...' if len(message.reply.processed_plain_text) > 100 else ''}",
|
||||
])
|
||||
|
||||
if hasattr(message, "reply") and message.reply:
|
||||
info_lines.extend(
|
||||
[
|
||||
"",
|
||||
"↩️ 回复信息:",
|
||||
f" 回复消息ID: {message.reply.message_info.message_id}",
|
||||
f" 回复内容: {message.reply.processed_plain_text[:100]}{'...' if len(message.reply.processed_plain_text) > 100 else ''}",
|
||||
]
|
||||
)
|
||||
|
||||
# 原始消息数据(如果存在)
|
||||
if hasattr(message, 'raw_message') and message.raw_message:
|
||||
info_lines.extend([
|
||||
"",
|
||||
"🗂️ 原始消息数据:",
|
||||
f" 数据类型: {type(message.raw_message).__name__}",
|
||||
f" 数据大小: {len(str(message.raw_message))} 字符",
|
||||
])
|
||||
|
||||
if hasattr(message, "raw_message") and message.raw_message:
|
||||
info_lines.extend(
|
||||
[
|
||||
"",
|
||||
"🗂️ 原始消息数据:",
|
||||
f" 数据类型: {type(message.raw_message).__name__}",
|
||||
f" 数据大小: {len(str(message.raw_message))} 字符",
|
||||
]
|
||||
)
|
||||
|
||||
return "\n".join(info_lines)
|
||||
|
||||
|
||||
@register_command
|
||||
class SenderInfoCommand(BaseCommand):
|
||||
"""发送者信息命令,快速查看发送者信息"""
|
||||
|
||||
|
||||
command_name = "whoami"
|
||||
command_description = "查看发送命令的用户信息"
|
||||
command_pattern = r"^/whoami$"
|
||||
command_help = "使用方法: /whoami - 查看你的用户信息"
|
||||
command_examples = ["/whoami"]
|
||||
enable_command = True
|
||||
|
||||
|
||||
async def execute(self) -> Tuple[bool, Optional[str]]:
|
||||
"""执行发送者信息查看命令"""
|
||||
try:
|
||||
user = self.message.message_info.user_info
|
||||
group = self.message.message_info.group_info
|
||||
|
||||
|
||||
info_lines = [
|
||||
"👤 你的身份信息",
|
||||
f"🆔 用户ID: {user.user_id}",
|
||||
@@ -203,19 +227,21 @@ class SenderInfoCommand(BaseCommand):
|
||||
f"🏷️ 群名片: {user.user_cardname or '无'}",
|
||||
f"🌐 平台: {user.platform}",
|
||||
]
|
||||
|
||||
|
||||
if group:
|
||||
info_lines.extend([
|
||||
"",
|
||||
"👥 当前群聊:",
|
||||
f"🆔 群ID: {group.group_id}",
|
||||
f"📝 群名: {group.group_name or '未知'}",
|
||||
])
|
||||
info_lines.extend(
|
||||
[
|
||||
"",
|
||||
"👥 当前群聊:",
|
||||
f"🆔 群ID: {group.group_id}",
|
||||
f"📝 群名: {group.group_name or '未知'}",
|
||||
]
|
||||
)
|
||||
else:
|
||||
info_lines.append("\n💬 当前在私聊中")
|
||||
|
||||
|
||||
return True, "\n".join(info_lines)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 获取发送者信息时出错: {e}")
|
||||
return False, f"获取发送者信息失败: {str(e)}"
|
||||
@@ -224,59 +250,65 @@ class SenderInfoCommand(BaseCommand):
|
||||
@register_command
|
||||
class ChatStreamInfoCommand(BaseCommand):
|
||||
"""聊天流信息命令"""
|
||||
|
||||
|
||||
command_name = "streaminfo"
|
||||
command_description = "查看当前聊天流的详细信息"
|
||||
command_pattern = r"^/streaminfo$"
|
||||
command_help = "使用方法: /streaminfo - 查看当前聊天流信息"
|
||||
command_examples = ["/streaminfo"]
|
||||
enable_command = True
|
||||
|
||||
|
||||
async def execute(self) -> Tuple[bool, Optional[str]]:
|
||||
"""执行聊天流信息查看命令"""
|
||||
try:
|
||||
if not hasattr(self.message, 'chat_stream') or not self.message.chat_stream:
|
||||
if not hasattr(self.message, "chat_stream") or not self.message.chat_stream:
|
||||
return False, "无法获取聊天流信息"
|
||||
|
||||
|
||||
chat_stream = self.message.chat_stream
|
||||
|
||||
|
||||
info_lines = [
|
||||
"🔄 聊天流信息",
|
||||
f"🆔 流ID: {chat_stream.stream_id}",
|
||||
f"🌐 平台: {chat_stream.platform}",
|
||||
f"⚡ 状态: {'激活' if chat_stream.is_active else '非激活'}",
|
||||
]
|
||||
|
||||
|
||||
# 用户信息
|
||||
if chat_stream.user_info:
|
||||
info_lines.extend([
|
||||
"",
|
||||
"👤 关联用户:",
|
||||
f" ID: {chat_stream.user_info.user_id}",
|
||||
f" 昵称: {chat_stream.user_info.user_nickname}",
|
||||
])
|
||||
|
||||
info_lines.extend(
|
||||
[
|
||||
"",
|
||||
"👤 关联用户:",
|
||||
f" ID: {chat_stream.user_info.user_id}",
|
||||
f" 昵称: {chat_stream.user_info.user_nickname}",
|
||||
]
|
||||
)
|
||||
|
||||
# 群信息
|
||||
if chat_stream.group_info:
|
||||
info_lines.extend([
|
||||
"",
|
||||
"👥 关联群聊:",
|
||||
f" 群ID: {chat_stream.group_info.group_id}",
|
||||
f" 群名: {chat_stream.group_info.group_name or '未知'}",
|
||||
])
|
||||
info_lines.extend(
|
||||
[
|
||||
"",
|
||||
"👥 关联群聊:",
|
||||
f" 群ID: {chat_stream.group_info.group_id}",
|
||||
f" 群名: {chat_stream.group_info.group_name or '未知'}",
|
||||
]
|
||||
)
|
||||
else:
|
||||
info_lines.append("\n💬 类型: 私聊流")
|
||||
|
||||
|
||||
# 最近消息统计
|
||||
if hasattr(chat_stream, 'last_messages'):
|
||||
if hasattr(chat_stream, "last_messages"):
|
||||
msg_count = len(chat_stream.last_messages)
|
||||
info_lines.extend([
|
||||
"",
|
||||
f"📈 消息统计: 记录了 {msg_count} 条最近消息",
|
||||
])
|
||||
|
||||
info_lines.extend(
|
||||
[
|
||||
"",
|
||||
f"📈 消息统计: 记录了 {msg_count} 条最近消息",
|
||||
]
|
||||
)
|
||||
|
||||
return True, "\n".join(info_lines)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 获取聊天流信息时出错: {e}")
|
||||
return False, f"获取聊天流信息失败: {str(e)}"
|
||||
return False, f"获取聊天流信息失败: {str(e)}"
|
||||
|
||||
@@ -5,43 +5,41 @@ from typing import Tuple, Optional
|
||||
|
||||
logger = get_logger("send_msg_command")
|
||||
|
||||
|
||||
@register_command
|
||||
class SendMessageCommand(BaseCommand, MessageAPI):
|
||||
"""发送消息命令,可以向指定群聊或私聊发送消息"""
|
||||
|
||||
|
||||
command_name = "send"
|
||||
command_description = "向指定群聊或私聊发送消息"
|
||||
command_pattern = r"^/send\s+(?P<target_type>group|user)\s+(?P<target_id>\d+)\s+(?P<content>.+)$"
|
||||
command_help = "使用方法: /send <group|user> <ID> <消息内容> - 发送消息到指定群聊或用户"
|
||||
command_examples = [
|
||||
"/send group 123456789 大家好!",
|
||||
"/send user 987654321 私聊消息"
|
||||
]
|
||||
command_examples = ["/send group 123456789 大家好!", "/send user 987654321 私聊消息"]
|
||||
enable_command = True
|
||||
|
||||
|
||||
def __init__(self, message):
|
||||
super().__init__(message)
|
||||
# 初始化MessageAPI需要的服务(虽然这里不会用到,但保持一致性)
|
||||
self._services = {}
|
||||
self.log_prefix = f"[Command:{self.command_name}]"
|
||||
|
||||
|
||||
async def execute(self) -> Tuple[bool, Optional[str]]:
|
||||
"""执行发送消息命令
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple[bool, Optional[str]]: (是否执行成功, 回复消息)
|
||||
"""
|
||||
try:
|
||||
# 获取匹配到的参数
|
||||
target_type = self.matched_groups.get("target_type") # group 或 user
|
||||
target_id = self.matched_groups.get("target_id") # 群ID或用户ID
|
||||
content = self.matched_groups.get("content") # 消息内容
|
||||
|
||||
target_id = self.matched_groups.get("target_id") # 群ID或用户ID
|
||||
content = self.matched_groups.get("content") # 消息内容
|
||||
|
||||
if not all([target_type, target_id, content]):
|
||||
return False, "命令参数不完整,请检查格式"
|
||||
|
||||
|
||||
logger.info(f"{self.log_prefix} 执行发送消息命令: {target_type}:{target_id} -> {content[:50]}...")
|
||||
|
||||
|
||||
# 根据目标类型调用不同的发送方法
|
||||
if target_type == "group":
|
||||
success = await self._send_to_group(target_id, content)
|
||||
@@ -51,24 +49,24 @@ class SendMessageCommand(BaseCommand, MessageAPI):
|
||||
target_desc = f"用户 {target_id}"
|
||||
else:
|
||||
return False, f"不支持的目标类型: {target_type},只支持 group 或 user"
|
||||
|
||||
|
||||
# 返回执行结果
|
||||
if success:
|
||||
return True, f"✅ 消息已成功发送到 {target_desc}"
|
||||
else:
|
||||
return False, f"❌ 消息发送失败,可能是目标 {target_desc} 不存在或没有权限"
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 执行发送消息命令时出错: {e}")
|
||||
return False, f"命令执行出错: {str(e)}"
|
||||
|
||||
|
||||
async def _send_to_group(self, group_id: str, content: str) -> bool:
|
||||
"""发送消息到群聊
|
||||
|
||||
|
||||
Args:
|
||||
group_id: 群聊ID
|
||||
content: 消息内容
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
@@ -76,27 +74,27 @@ class SendMessageCommand(BaseCommand, MessageAPI):
|
||||
success = await self.send_text_to_group(
|
||||
text=content,
|
||||
group_id=group_id,
|
||||
platform="qq" # 默认使用QQ平台
|
||||
platform="qq", # 默认使用QQ平台
|
||||
)
|
||||
|
||||
|
||||
if success:
|
||||
logger.info(f"{self.log_prefix} 成功发送消息到群聊 {group_id}")
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} 发送消息到群聊 {group_id} 失败")
|
||||
|
||||
|
||||
return success
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 发送群聊消息时出错: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def _send_to_user(self, user_id: str, content: str) -> bool:
|
||||
"""发送消息到私聊
|
||||
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
content: 消息内容
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
@@ -104,16 +102,16 @@ class SendMessageCommand(BaseCommand, MessageAPI):
|
||||
success = await self.send_text_to_user(
|
||||
text=content,
|
||||
user_id=user_id,
|
||||
platform="qq" # 默认使用QQ平台
|
||||
platform="qq", # 默认使用QQ平台
|
||||
)
|
||||
|
||||
|
||||
if success:
|
||||
logger.info(f"{self.log_prefix} 成功发送消息到用户 {user_id}")
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} 发送消息到用户 {user_id} 失败")
|
||||
|
||||
|
||||
return success
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 发送私聊消息时出错: {e}")
|
||||
return False
|
||||
return False
|
||||
|
||||
@@ -5,10 +5,11 @@ from typing import Tuple, Optional
|
||||
|
||||
logger = get_logger("send_msg_enhanced")
|
||||
|
||||
|
||||
@register_command
|
||||
class SendMessageEnhancedCommand(BaseCommand, MessageAPI):
|
||||
"""增强版发送消息命令,支持多种消息类型和平台"""
|
||||
|
||||
|
||||
command_name = "sendfull"
|
||||
command_description = "增强版消息发送命令,支持多种类型和平台"
|
||||
command_pattern = r"^/sendfull\s+(?P<msg_type>text|image|emoji)\s+(?P<target_type>group|user)\s+(?P<target_id>\d+)(?:\s+(?P<platform>\w+))?\s+(?P<content>.+)$"
|
||||
@@ -17,108 +18,93 @@ class SendMessageEnhancedCommand(BaseCommand, MessageAPI):
|
||||
"/sendfull text group 123456789 qq 大家好!这是文本消息",
|
||||
"/sendfull image user 987654321 https://example.com/image.jpg",
|
||||
"/sendfull emoji group 123456789 😄",
|
||||
"/sendfull text user 987654321 qq 私聊消息"
|
||||
"/sendfull text user 987654321 qq 私聊消息",
|
||||
]
|
||||
enable_command = True
|
||||
|
||||
|
||||
def __init__(self, message):
|
||||
super().__init__(message)
|
||||
self._services = {}
|
||||
self.log_prefix = f"[Command:{self.command_name}]"
|
||||
|
||||
|
||||
async def execute(self) -> Tuple[bool, Optional[str]]:
|
||||
"""执行增强版发送消息命令"""
|
||||
try:
|
||||
# 获取匹配参数
|
||||
msg_type = self.matched_groups.get("msg_type") # 消息类型: text/image/emoji
|
||||
target_type = self.matched_groups.get("target_type") # 目标类型: group/user
|
||||
target_id = self.matched_groups.get("target_id") # 目标ID
|
||||
platform = self.matched_groups.get("platform") or "qq" # 平台,默认qq
|
||||
content = self.matched_groups.get("content") # 内容
|
||||
|
||||
msg_type = self.matched_groups.get("msg_type") # 消息类型: text/image/emoji
|
||||
target_type = self.matched_groups.get("target_type") # 目标类型: group/user
|
||||
target_id = self.matched_groups.get("target_id") # 目标ID
|
||||
platform = self.matched_groups.get("platform") or "qq" # 平台,默认qq
|
||||
content = self.matched_groups.get("content") # 内容
|
||||
|
||||
if not all([msg_type, target_type, target_id, content]):
|
||||
return False, "命令参数不完整,请检查格式"
|
||||
|
||||
|
||||
# 验证消息类型
|
||||
valid_types = ["text", "image", "emoji"]
|
||||
if msg_type not in valid_types:
|
||||
return False, f"不支持的消息类型: {msg_type},支持的类型: {', '.join(valid_types)}"
|
||||
|
||||
|
||||
# 验证目标类型
|
||||
if target_type not in ["group", "user"]:
|
||||
return False, "目标类型只能是 group 或 user"
|
||||
|
||||
|
||||
logger.info(f"{self.log_prefix} 执行发送命令: {msg_type} -> {target_type}:{target_id} (平台:{platform})")
|
||||
|
||||
|
||||
# 根据消息类型和目标类型发送消息
|
||||
is_group = (target_type == "group")
|
||||
is_group = target_type == "group"
|
||||
success = await self.send_message_to_target(
|
||||
message_type=msg_type,
|
||||
content=content,
|
||||
platform=platform,
|
||||
target_id=target_id,
|
||||
is_group=is_group
|
||||
message_type=msg_type, content=content, platform=platform, target_id=target_id, is_group=is_group
|
||||
)
|
||||
|
||||
|
||||
# 构建结果消息
|
||||
target_desc = f"{'群聊' if is_group else '用户'} {target_id} (平台: {platform})"
|
||||
msg_type_desc = {
|
||||
"text": "文本",
|
||||
"image": "图片",
|
||||
"emoji": "表情"
|
||||
}.get(msg_type, msg_type)
|
||||
|
||||
msg_type_desc = {"text": "文本", "image": "图片", "emoji": "表情"}.get(msg_type, msg_type)
|
||||
|
||||
if success:
|
||||
return True, f"✅ {msg_type_desc}消息已成功发送到 {target_desc}"
|
||||
else:
|
||||
return False, f"❌ {msg_type_desc}消息发送失败,可能是目标 {target_desc} 不存在或没有权限"
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 执行增强发送命令时出错: {e}")
|
||||
return False, f"命令执行出错: {str(e)}"
|
||||
|
||||
|
||||
@register_command
|
||||
@register_command
|
||||
class SendQuickCommand(BaseCommand, MessageAPI):
|
||||
"""快速发送文本消息命令"""
|
||||
|
||||
|
||||
command_name = "msg"
|
||||
command_description = "快速发送文本消息到群聊"
|
||||
command_pattern = r"^/msg\s+(?P<group_id>\d+)\s+(?P<content>.+)$"
|
||||
command_help = "使用方法: /msg <群ID> <消息内容> - 快速发送文本到指定群聊"
|
||||
command_examples = [
|
||||
"/msg 123456789 大家好!",
|
||||
"/msg 987654321 这是一条快速消息"
|
||||
]
|
||||
command_examples = ["/msg 123456789 大家好!", "/msg 987654321 这是一条快速消息"]
|
||||
enable_command = True
|
||||
|
||||
|
||||
def __init__(self, message):
|
||||
super().__init__(message)
|
||||
self._services = {}
|
||||
self.log_prefix = f"[Command:{self.command_name}]"
|
||||
|
||||
|
||||
async def execute(self) -> Tuple[bool, Optional[str]]:
|
||||
"""执行快速发送消息命令"""
|
||||
try:
|
||||
group_id = self.matched_groups.get("group_id")
|
||||
content = self.matched_groups.get("content")
|
||||
|
||||
|
||||
if not all([group_id, content]):
|
||||
return False, "命令参数不完整"
|
||||
|
||||
|
||||
logger.info(f"{self.log_prefix} 快速发送到群 {group_id}: {content[:50]}...")
|
||||
|
||||
success = await self.send_text_to_group(
|
||||
text=content,
|
||||
group_id=group_id,
|
||||
platform="qq"
|
||||
)
|
||||
|
||||
|
||||
success = await self.send_text_to_group(text=content, group_id=group_id, platform="qq")
|
||||
|
||||
if success:
|
||||
return True, f"✅ 消息已发送到群 {group_id}"
|
||||
else:
|
||||
return False, f"❌ 发送到群 {group_id} 失败"
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 快速发送命令出错: {e}")
|
||||
return False, f"发送失败: {str(e)}"
|
||||
@@ -127,44 +113,37 @@ class SendQuickCommand(BaseCommand, MessageAPI):
|
||||
@register_command
|
||||
class SendPrivateCommand(BaseCommand, MessageAPI):
|
||||
"""发送私聊消息命令"""
|
||||
|
||||
|
||||
command_name = "pm"
|
||||
command_description = "发送私聊消息到指定用户"
|
||||
command_pattern = r"^/pm\s+(?P<user_id>\d+)\s+(?P<content>.+)$"
|
||||
command_help = "使用方法: /pm <用户ID> <消息内容> - 发送私聊消息"
|
||||
command_examples = [
|
||||
"/pm 123456789 你好!",
|
||||
"/pm 987654321 这是私聊消息"
|
||||
]
|
||||
command_examples = ["/pm 123456789 你好!", "/pm 987654321 这是私聊消息"]
|
||||
enable_command = True
|
||||
|
||||
|
||||
def __init__(self, message):
|
||||
super().__init__(message)
|
||||
self._services = {}
|
||||
self.log_prefix = f"[Command:{self.command_name}]"
|
||||
|
||||
|
||||
async def execute(self) -> Tuple[bool, Optional[str]]:
|
||||
"""执行私聊发送命令"""
|
||||
try:
|
||||
user_id = self.matched_groups.get("user_id")
|
||||
content = self.matched_groups.get("content")
|
||||
|
||||
|
||||
if not all([user_id, content]):
|
||||
return False, "命令参数不完整"
|
||||
|
||||
|
||||
logger.info(f"{self.log_prefix} 发送私聊到用户 {user_id}: {content[:50]}...")
|
||||
|
||||
success = await self.send_text_to_user(
|
||||
text=content,
|
||||
user_id=user_id,
|
||||
platform="qq"
|
||||
)
|
||||
|
||||
|
||||
success = await self.send_text_to_user(text=content, user_id=user_id, platform="qq")
|
||||
|
||||
if success:
|
||||
return True, f"✅ 私聊消息已发送到用户 {user_id}"
|
||||
else:
|
||||
return False, f"❌ 发送私聊到用户 {user_id} 失败"
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 私聊发送命令出错: {e}")
|
||||
return False, f"私聊发送失败: {str(e)}"
|
||||
return False, f"私聊发送失败: {str(e)}"
|
||||
|
||||
@@ -6,173 +6,153 @@ import time
|
||||
|
||||
logger = get_logger("send_msg_with_context")
|
||||
|
||||
|
||||
@register_command
|
||||
class ContextAwareSendCommand(BaseCommand, MessageAPI):
|
||||
"""上下文感知的发送消息命令,展示如何利用原始消息信息"""
|
||||
|
||||
|
||||
command_name = "csend"
|
||||
command_description = "带上下文感知的发送消息命令"
|
||||
command_pattern = r"^/csend\s+(?P<target_type>group|user|here|reply)\s+(?P<target_id_or_content>.*?)(?:\s+(?P<content>.*))?$"
|
||||
command_pattern = (
|
||||
r"^/csend\s+(?P<target_type>group|user|here|reply)\s+(?P<target_id_or_content>.*?)(?:\s+(?P<content>.*))?$"
|
||||
)
|
||||
command_help = "使用方法: /csend <target_type> <参数> [内容]"
|
||||
command_examples = [
|
||||
"/csend group 123456789 大家好!",
|
||||
"/csend user 987654321 私聊消息",
|
||||
"/csend user 987654321 私聊消息",
|
||||
"/csend here 在当前聊天发送",
|
||||
"/csend reply 回复当前群/私聊"
|
||||
"/csend reply 回复当前群/私聊",
|
||||
]
|
||||
enable_command = True
|
||||
|
||||
|
||||
# 管理员用户ID列表(示例)
|
||||
ADMIN_USERS = ["123456789", "987654321"] # 可以从配置文件读取
|
||||
|
||||
|
||||
def __init__(self, message):
|
||||
super().__init__(message)
|
||||
self._services = {}
|
||||
self.log_prefix = f"[Command:{self.command_name}]"
|
||||
|
||||
|
||||
async def execute(self) -> Tuple[bool, Optional[str]]:
|
||||
"""执行上下文感知的发送命令"""
|
||||
try:
|
||||
# 获取命令发送者信息
|
||||
sender = self.message.message_info.user_info
|
||||
current_group = self.message.message_info.group_info
|
||||
|
||||
|
||||
# 权限检查
|
||||
if not self._check_permission(sender.user_id):
|
||||
return False, f"❌ 权限不足,只有管理员可以使用此命令\n你的ID: {sender.user_id}"
|
||||
|
||||
|
||||
# 解析命令参数
|
||||
target_type = self.matched_groups.get("target_type")
|
||||
target_id_or_content = self.matched_groups.get("target_id_or_content", "")
|
||||
content = self.matched_groups.get("content", "")
|
||||
|
||||
|
||||
# 根据目标类型处理不同情况
|
||||
if target_type == "here":
|
||||
# 发送到当前聊天
|
||||
return await self._send_to_current_chat(target_id_or_content, sender, current_group)
|
||||
|
||||
|
||||
elif target_type == "reply":
|
||||
# 回复到当前聊天,带发送者信息
|
||||
return await self._send_reply_with_context(target_id_or_content, sender, current_group)
|
||||
|
||||
|
||||
elif target_type in ["group", "user"]:
|
||||
# 发送到指定目标
|
||||
if not content:
|
||||
return False, "指定群聊或用户时需要提供消息内容"
|
||||
return await self._send_to_target(target_type, target_id_or_content, content, sender)
|
||||
|
||||
|
||||
else:
|
||||
return False, f"不支持的目标类型: {target_type}"
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 执行上下文感知发送命令时出错: {e}")
|
||||
return False, f"命令执行出错: {str(e)}"
|
||||
|
||||
|
||||
def _check_permission(self, user_id: str) -> bool:
|
||||
"""检查用户权限"""
|
||||
return user_id in self.ADMIN_USERS
|
||||
|
||||
|
||||
async def _send_to_current_chat(self, content: str, sender, current_group) -> Tuple[bool, str]:
|
||||
"""发送到当前聊天"""
|
||||
if not content:
|
||||
return False, "消息内容不能为空"
|
||||
|
||||
|
||||
# 构建带发送者信息的消息
|
||||
timestamp = time.strftime("%H:%M:%S", time.localtime())
|
||||
if current_group:
|
||||
# 群聊
|
||||
formatted_content = f"[管理员转发 {timestamp}] {sender.user_nickname}({sender.user_id}): {content}"
|
||||
success = await self.send_text_to_group(
|
||||
text=formatted_content,
|
||||
group_id=current_group.group_id,
|
||||
platform="qq"
|
||||
text=formatted_content, group_id=current_group.group_id, platform="qq"
|
||||
)
|
||||
target_desc = f"当前群聊 {current_group.group_name}({current_group.group_id})"
|
||||
else:
|
||||
# 私聊
|
||||
formatted_content = f"[管理员消息 {timestamp}]: {content}"
|
||||
success = await self.send_text_to_user(
|
||||
text=formatted_content,
|
||||
user_id=sender.user_id,
|
||||
platform="qq"
|
||||
)
|
||||
success = await self.send_text_to_user(text=formatted_content, user_id=sender.user_id, platform="qq")
|
||||
target_desc = "当前私聊"
|
||||
|
||||
|
||||
if success:
|
||||
return True, f"✅ 消息已发送到{target_desc}"
|
||||
else:
|
||||
return False, f"❌ 发送到{target_desc}失败"
|
||||
|
||||
|
||||
async def _send_reply_with_context(self, content: str, sender, current_group) -> Tuple[bool, str]:
|
||||
"""发送回复,带完整上下文信息"""
|
||||
if not content:
|
||||
return False, "回复内容不能为空"
|
||||
|
||||
|
||||
# 获取当前时间和环境信息
|
||||
timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
|
||||
|
||||
# 构建上下文信息
|
||||
context_info = [
|
||||
f"📢 管理员回复 [{timestamp}]",
|
||||
f"👤 发送者: {sender.user_nickname}({sender.user_id})",
|
||||
]
|
||||
|
||||
|
||||
if current_group:
|
||||
context_info.append(f"👥 当前群聊: {current_group.group_name}({current_group.group_id})")
|
||||
target_desc = f"群聊 {current_group.group_name}"
|
||||
else:
|
||||
context_info.append("💬 当前环境: 私聊")
|
||||
target_desc = "私聊"
|
||||
|
||||
context_info.extend([
|
||||
f"📝 回复内容: {content}",
|
||||
"─" * 30
|
||||
])
|
||||
|
||||
|
||||
context_info.extend([f"📝 回复内容: {content}", "─" * 30])
|
||||
|
||||
formatted_content = "\n".join(context_info)
|
||||
|
||||
|
||||
# 发送消息
|
||||
if current_group:
|
||||
success = await self.send_text_to_group(
|
||||
text=formatted_content,
|
||||
group_id=current_group.group_id,
|
||||
platform="qq"
|
||||
text=formatted_content, group_id=current_group.group_id, platform="qq"
|
||||
)
|
||||
else:
|
||||
success = await self.send_text_to_user(
|
||||
text=formatted_content,
|
||||
user_id=sender.user_id,
|
||||
platform="qq"
|
||||
)
|
||||
|
||||
success = await self.send_text_to_user(text=formatted_content, user_id=sender.user_id, platform="qq")
|
||||
|
||||
if success:
|
||||
return True, f"✅ 带上下文的回复已发送到{target_desc}"
|
||||
else:
|
||||
return False, f"❌ 发送上下文回复到{target_desc}失败"
|
||||
|
||||
|
||||
async def _send_to_target(self, target_type: str, target_id: str, content: str, sender) -> Tuple[bool, str]:
|
||||
"""发送到指定目标,带发送者追踪信息"""
|
||||
timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
|
||||
|
||||
# 构建带追踪信息的消息
|
||||
tracking_info = f"[管理转发 {timestamp}] 来自 {sender.user_nickname}({sender.user_id})"
|
||||
formatted_content = f"{tracking_info}\n{content}"
|
||||
|
||||
|
||||
if target_type == "group":
|
||||
success = await self.send_text_to_group(
|
||||
text=formatted_content,
|
||||
group_id=target_id,
|
||||
platform="qq"
|
||||
)
|
||||
success = await self.send_text_to_group(text=formatted_content, group_id=target_id, platform="qq")
|
||||
target_desc = f"群聊 {target_id}"
|
||||
else: # user
|
||||
success = await self.send_text_to_user(
|
||||
text=formatted_content,
|
||||
user_id=target_id,
|
||||
platform="qq"
|
||||
)
|
||||
success = await self.send_text_to_user(text=formatted_content, user_id=target_id, platform="qq")
|
||||
target_desc = f"用户 {target_id}"
|
||||
|
||||
|
||||
if success:
|
||||
return True, f"✅ 带追踪信息的消息已发送到{target_desc}"
|
||||
else:
|
||||
@@ -182,21 +162,21 @@ class ContextAwareSendCommand(BaseCommand, MessageAPI):
|
||||
@register_command
|
||||
class MessageContextCommand(BaseCommand):
|
||||
"""消息上下文命令,展示如何获取和利用上下文信息"""
|
||||
|
||||
|
||||
command_name = "context"
|
||||
command_description = "显示当前消息的完整上下文信息"
|
||||
command_pattern = r"^/context$"
|
||||
command_help = "使用方法: /context - 显示当前环境的上下文信息"
|
||||
command_examples = ["/context"]
|
||||
enable_command = True
|
||||
|
||||
|
||||
async def execute(self) -> Tuple[bool, Optional[str]]:
|
||||
"""显示上下文信息"""
|
||||
try:
|
||||
message = self.message
|
||||
user = message.message_info.user_info
|
||||
group = message.message_info.group_info
|
||||
|
||||
|
||||
# 构建上下文信息
|
||||
context_lines = [
|
||||
"🌐 当前上下文信息",
|
||||
@@ -212,42 +192,50 @@ class MessageContextCommand(BaseCommand):
|
||||
f" 群名片: {user.user_cardname or '无'}",
|
||||
f" 平台: {user.platform}",
|
||||
]
|
||||
|
||||
|
||||
if group:
|
||||
context_lines.extend([
|
||||
"",
|
||||
"👥 群聊环境:",
|
||||
f" 群ID: {group.group_id}",
|
||||
f" 群名: {group.group_name or '未知'}",
|
||||
f" 平台: {group.platform}",
|
||||
])
|
||||
context_lines.extend(
|
||||
[
|
||||
"",
|
||||
"👥 群聊环境:",
|
||||
f" 群ID: {group.group_id}",
|
||||
f" 群名: {group.group_name or '未知'}",
|
||||
f" 平台: {group.platform}",
|
||||
]
|
||||
)
|
||||
else:
|
||||
context_lines.extend([
|
||||
"",
|
||||
"💬 私聊环境",
|
||||
])
|
||||
|
||||
context_lines.extend(
|
||||
[
|
||||
"",
|
||||
"💬 私聊环境",
|
||||
]
|
||||
)
|
||||
|
||||
# 添加聊天流信息
|
||||
if hasattr(message, 'chat_stream') and message.chat_stream:
|
||||
if hasattr(message, "chat_stream") and message.chat_stream:
|
||||
chat_stream = message.chat_stream
|
||||
context_lines.extend([
|
||||
"",
|
||||
"🔄 聊天流:",
|
||||
f" 流ID: {chat_stream.stream_id}",
|
||||
f" 激活状态: {'激活' if chat_stream.is_active else '非激活'}",
|
||||
])
|
||||
|
||||
context_lines.extend(
|
||||
[
|
||||
"",
|
||||
"🔄 聊天流:",
|
||||
f" 流ID: {chat_stream.stream_id}",
|
||||
f" 激活状态: {'激活' if chat_stream.is_active else '非激活'}",
|
||||
]
|
||||
)
|
||||
|
||||
# 添加消息内容信息
|
||||
context_lines.extend([
|
||||
"",
|
||||
"📝 消息内容:",
|
||||
f" 原始内容: {message.processed_plain_text}",
|
||||
f" 消息长度: {len(message.processed_plain_text)} 字符",
|
||||
f" 消息ID: {message.message_info.message_id}",
|
||||
])
|
||||
|
||||
context_lines.extend(
|
||||
[
|
||||
"",
|
||||
"📝 消息内容:",
|
||||
f" 原始内容: {message.processed_plain_text}",
|
||||
f" 消息长度: {len(message.processed_plain_text)} 字符",
|
||||
f" 消息ID: {message.message_info.message_id}",
|
||||
]
|
||||
)
|
||||
|
||||
return True, "\n".join(context_lines)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 获取上下文信息时出错: {e}")
|
||||
return False, f"获取上下文失败: {str(e)}"
|
||||
return False, f"获取上下文失败: {str(e)}"
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
"""测试插件动作模块"""
|
||||
|
||||
from . import mute_action # noqa
|
||||
|
||||
@@ -22,21 +22,20 @@ class MuteAction(PluginAction):
|
||||
"当有人刷屏时使用",
|
||||
"当有人发了擦边,或者色情内容时使用",
|
||||
"当有人要求禁言自己时使用",
|
||||
"如果某人已经被禁言了,就不要再次禁言了,除非你想追加时间!!"
|
||||
"如果某人已经被禁言了,就不要再次禁言了,除非你想追加时间!!",
|
||||
]
|
||||
enable_plugin = False # 启用插件
|
||||
associated_types = ["command", "text"]
|
||||
action_config_file_name = "mute_action_config.toml"
|
||||
|
||||
|
||||
# 激活类型设置
|
||||
focus_activation_type = ActionActivationType.LLM_JUDGE # Focus模式使用LLM判定,确保谨慎
|
||||
normal_activation_type = ActionActivationType.KEYWORD # Normal模式使用关键词激活,快速响应
|
||||
|
||||
|
||||
normal_activation_type = ActionActivationType.KEYWORD # Normal模式使用关键词激活,快速响应
|
||||
|
||||
# 关键词设置(用于Normal模式)
|
||||
activation_keywords = ["禁言", "mute", "ban", "silence"]
|
||||
keyword_case_sensitive = False
|
||||
|
||||
|
||||
# LLM判定提示词(用于Focus模式)
|
||||
llm_judge_prompt = """
|
||||
判定是否需要使用禁言动作的严格条件:
|
||||
@@ -59,13 +58,13 @@ class MuteAction(PluginAction):
|
||||
注意:禁言是严厉措施,只在明确违规或用户主动要求时使用。
|
||||
宁可保守也不要误判,保护用户的发言权利。
|
||||
"""
|
||||
|
||||
|
||||
# Random激活概率(备用)
|
||||
random_activation_probability = 0.05 # 设置很低的概率作为兜底
|
||||
|
||||
# 模式启用设置 - 禁言功能在所有模式下都可用
|
||||
mode_enable = ChatMode.ALL
|
||||
|
||||
|
||||
# 并行执行设置 - 禁言动作可以与回复并行执行,不覆盖回复内容
|
||||
parallel_action = False
|
||||
|
||||
@@ -73,15 +72,15 @@ class MuteAction(PluginAction):
|
||||
super().__init__(*args, **kwargs)
|
||||
# 生成配置文件(如果不存在)
|
||||
self._generate_config_if_needed()
|
||||
|
||||
|
||||
def _generate_config_if_needed(self):
|
||||
"""生成配置文件(如果不存在)"""
|
||||
import os
|
||||
|
||||
|
||||
# 获取动作文件所在目录
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
config_path = os.path.join(current_dir, "mute_action_config.toml")
|
||||
|
||||
|
||||
if not os.path.exists(config_path):
|
||||
config_content = """\
|
||||
# 禁言动作配置文件
|
||||
@@ -130,11 +129,10 @@ log_mute_history = true
|
||||
|
||||
def _get_template_message(self, target: str, duration_str: str, reason: str) -> str:
|
||||
"""获取模板化的禁言消息"""
|
||||
templates = self.config.get("templates", [
|
||||
"好的,禁言 {target} {duration},理由:{reason}"
|
||||
])
|
||||
|
||||
templates = self.config.get("templates", ["好的,禁言 {target} {duration},理由:{reason}"])
|
||||
|
||||
import random
|
||||
|
||||
template = random.choice(templates)
|
||||
return template.format(target=target, duration=duration_str, reason=reason)
|
||||
|
||||
@@ -162,7 +160,7 @@ log_mute_history = true
|
||||
|
||||
# 获取时长限制配置
|
||||
min_duration, max_duration, default_duration = self._get_duration_limits()
|
||||
|
||||
|
||||
# 验证时长格式并转换
|
||||
try:
|
||||
duration_int = int(duration)
|
||||
@@ -170,9 +168,11 @@ log_mute_history = true
|
||||
error_msg = "禁言时长必须大于0"
|
||||
logger.error(f"{self.log_prefix} {error_msg}")
|
||||
error_templates = self.config.get("error_messages", ["禁言时长必须是正数哦~"])
|
||||
await self.send_message_by_expressor(error_templates[2] if len(error_templates) > 2 else "禁言时长必须是正数哦~")
|
||||
await self.send_message_by_expressor(
|
||||
error_templates[2] if len(error_templates) > 2 else "禁言时长必须是正数哦~"
|
||||
)
|
||||
return False, error_msg
|
||||
|
||||
|
||||
# 限制禁言时长范围
|
||||
if duration_int < min_duration:
|
||||
duration_int = min_duration
|
||||
@@ -180,12 +180,14 @@ log_mute_history = true
|
||||
elif duration_int > max_duration:
|
||||
duration_int = max_duration
|
||||
logger.info(f"{self.log_prefix} 禁言时长过长,调整为{max_duration}秒")
|
||||
|
||||
except (ValueError, TypeError) as e:
|
||||
|
||||
except (ValueError, TypeError):
|
||||
error_msg = f"禁言时长格式无效: {duration}"
|
||||
logger.error(f"{self.log_prefix} {error_msg}")
|
||||
error_templates = self.config.get("error_messages", ["禁言时长必须是数字哦~"])
|
||||
await self.send_message_by_expressor(error_templates[3] if len(error_templates) > 3 else "禁言时长必须是数字哦~")
|
||||
await self.send_message_by_expressor(
|
||||
error_templates[3] if len(error_templates) > 3 else "禁言时长必须是数字哦~"
|
||||
)
|
||||
return False, error_msg
|
||||
|
||||
# 获取用户ID
|
||||
@@ -206,7 +208,7 @@ log_mute_history = true
|
||||
# 发送表达情绪的消息
|
||||
enable_formatting = self.config.get("enable_duration_formatting", True)
|
||||
time_str = self._format_duration(duration_int) if enable_formatting else f"{duration_int}秒"
|
||||
|
||||
|
||||
# 使用模板化消息
|
||||
message = self._get_template_message(target, time_str, reason)
|
||||
await self.send_message_by_expressor(message)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import importlib
|
||||
import pkgutil
|
||||
import os
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import Dict, Tuple
|
||||
from src.common.logger_manager import get_logger
|
||||
|
||||
logger = get_logger("plugin_loader")
|
||||
@@ -9,52 +9,53 @@ logger = get_logger("plugin_loader")
|
||||
|
||||
class PluginLoader:
|
||||
"""统一的插件加载器,负责加载插件的所有组件(actions、commands等)"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.loaded_actions = 0
|
||||
self.loaded_commands = 0
|
||||
self.plugin_stats: Dict[str, Dict[str, int]] = {} # 统计每个插件加载的组件数量
|
||||
self.plugin_sources: Dict[str, str] = {} # 记录每个插件来自哪个路径
|
||||
|
||||
|
||||
def load_all_plugins(self) -> Tuple[int, int]:
|
||||
"""加载所有插件的所有组件
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple[int, int]: (加载的动作数量, 加载的命令数量)
|
||||
"""
|
||||
# 定义插件搜索路径(优先级从高到低)
|
||||
plugin_paths = [
|
||||
("plugins", "plugins"), # 项目根目录的plugins文件夹
|
||||
("src.plugins", os.path.join("src", "plugins")) # src下的plugins文件夹
|
||||
("src.plugins", os.path.join("src", "plugins")), # src下的plugins文件夹
|
||||
]
|
||||
|
||||
|
||||
total_plugins_found = 0
|
||||
|
||||
|
||||
for plugin_import_path, plugin_dir_path in plugin_paths:
|
||||
try:
|
||||
plugins_loaded = self._load_plugins_from_path(plugin_import_path, plugin_dir_path)
|
||||
total_plugins_found += plugins_loaded
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"从路径 {plugin_dir_path} 加载插件失败: {e}")
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
|
||||
if total_plugins_found == 0:
|
||||
logger.info("未找到任何插件目录或插件")
|
||||
|
||||
|
||||
# 输出加载统计
|
||||
self._log_loading_stats()
|
||||
|
||||
|
||||
return self.loaded_actions, self.loaded_commands
|
||||
|
||||
|
||||
def _load_plugins_from_path(self, plugin_import_path: str, plugin_dir_path: str) -> int:
|
||||
"""从指定路径加载插件
|
||||
|
||||
|
||||
Args:
|
||||
plugin_import_path: 插件的导入路径 (如 "plugins" 或 "src.plugins")
|
||||
plugin_dir_path: 插件目录的文件系统路径
|
||||
|
||||
|
||||
Returns:
|
||||
int: 找到的插件包数量
|
||||
"""
|
||||
@@ -62,9 +63,9 @@ class PluginLoader:
|
||||
if not os.path.exists(plugin_dir_path):
|
||||
logger.debug(f"插件目录 {plugin_dir_path} 不存在,跳过")
|
||||
return 0
|
||||
|
||||
|
||||
logger.info(f"正在从 {plugin_dir_path} 加载插件...")
|
||||
|
||||
|
||||
# 导入插件包
|
||||
try:
|
||||
plugins_package = importlib.import_module(plugin_import_path)
|
||||
@@ -72,122 +73,120 @@ class PluginLoader:
|
||||
except ImportError as e:
|
||||
logger.warning(f"导入插件包 {plugin_import_path} 失败: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
# 遍历插件包中的所有子包
|
||||
plugins_found = 0
|
||||
for _, plugin_name, is_pkg in pkgutil.iter_modules(
|
||||
plugins_package.__path__, plugins_package.__name__ + "."
|
||||
):
|
||||
for _, plugin_name, is_pkg in pkgutil.iter_modules(plugins_package.__path__, plugins_package.__name__ + "."):
|
||||
if not is_pkg:
|
||||
continue
|
||||
|
||||
|
||||
logger.debug(f"检测到插件: {plugin_name}")
|
||||
# 记录插件来源
|
||||
self.plugin_sources[plugin_name] = plugin_dir_path
|
||||
self._load_single_plugin(plugin_name)
|
||||
plugins_found += 1
|
||||
|
||||
|
||||
if plugins_found > 0:
|
||||
logger.info(f"从 {plugin_dir_path} 找到 {plugins_found} 个插件包")
|
||||
else:
|
||||
logger.debug(f"从 {plugin_dir_path} 未找到任何插件包")
|
||||
|
||||
|
||||
return plugins_found
|
||||
|
||||
|
||||
def _load_single_plugin(self, plugin_name: str) -> None:
|
||||
"""加载单个插件的所有组件
|
||||
|
||||
|
||||
Args:
|
||||
plugin_name: 插件名称
|
||||
"""
|
||||
plugin_stats = {"actions": 0, "commands": 0}
|
||||
|
||||
|
||||
# 加载动作组件
|
||||
actions_count = self._load_plugin_actions(plugin_name)
|
||||
plugin_stats["actions"] = actions_count
|
||||
self.loaded_actions += actions_count
|
||||
|
||||
# 加载命令组件
|
||||
|
||||
# 加载命令组件
|
||||
commands_count = self._load_plugin_commands(plugin_name)
|
||||
plugin_stats["commands"] = commands_count
|
||||
self.loaded_commands += commands_count
|
||||
|
||||
|
||||
# 记录插件统计信息
|
||||
if actions_count > 0 or commands_count > 0:
|
||||
self.plugin_stats[plugin_name] = plugin_stats
|
||||
logger.info(f"插件 {plugin_name} 加载完成: {actions_count} 个动作, {commands_count} 个命令")
|
||||
|
||||
|
||||
def _load_plugin_actions(self, plugin_name: str) -> int:
|
||||
"""加载插件的动作组件
|
||||
|
||||
|
||||
Args:
|
||||
plugin_name: 插件名称
|
||||
|
||||
|
||||
Returns:
|
||||
int: 加载的动作数量
|
||||
"""
|
||||
loaded_count = 0
|
||||
|
||||
|
||||
# 优先检查插件是否有actions子包
|
||||
plugin_actions_path = f"{plugin_name}.actions"
|
||||
plugin_actions_dir = plugin_name.replace(".", os.path.sep) + os.path.sep + "actions"
|
||||
|
||||
|
||||
actions_loaded_from_subdir = False
|
||||
|
||||
|
||||
# 首先尝试从actions子目录加载
|
||||
if os.path.exists(plugin_actions_dir):
|
||||
loaded_count += self._load_from_actions_subdir(plugin_name, plugin_actions_path, plugin_actions_dir)
|
||||
if loaded_count > 0:
|
||||
actions_loaded_from_subdir = True
|
||||
|
||||
|
||||
# 如果actions子目录不存在或加载失败,尝试从插件根目录加载
|
||||
if not actions_loaded_from_subdir:
|
||||
loaded_count += self._load_actions_from_root_dir(plugin_name)
|
||||
|
||||
|
||||
return loaded_count
|
||||
|
||||
|
||||
def _load_plugin_commands(self, plugin_name: str) -> int:
|
||||
"""加载插件的命令组件
|
||||
|
||||
|
||||
Args:
|
||||
plugin_name: 插件名称
|
||||
|
||||
|
||||
Returns:
|
||||
int: 加载的命令数量
|
||||
"""
|
||||
loaded_count = 0
|
||||
|
||||
|
||||
# 优先检查插件是否有commands子包
|
||||
plugin_commands_path = f"{plugin_name}.commands"
|
||||
plugin_commands_dir = plugin_name.replace(".", os.path.sep) + os.path.sep + "commands"
|
||||
|
||||
|
||||
commands_loaded_from_subdir = False
|
||||
|
||||
|
||||
# 首先尝试从commands子目录加载
|
||||
if os.path.exists(plugin_commands_dir):
|
||||
loaded_count += self._load_from_commands_subdir(plugin_name, plugin_commands_path, plugin_commands_dir)
|
||||
if loaded_count > 0:
|
||||
commands_loaded_from_subdir = True
|
||||
|
||||
|
||||
# 如果commands子目录不存在或加载失败,尝试从插件根目录加载
|
||||
if not commands_loaded_from_subdir:
|
||||
loaded_count += self._load_commands_from_root_dir(plugin_name)
|
||||
|
||||
|
||||
return loaded_count
|
||||
|
||||
|
||||
def _load_from_actions_subdir(self, plugin_name: str, plugin_actions_path: str, plugin_actions_dir: str) -> int:
|
||||
"""从actions子目录加载动作"""
|
||||
loaded_count = 0
|
||||
|
||||
|
||||
try:
|
||||
# 尝试导入插件的actions包
|
||||
actions_module = importlib.import_module(plugin_actions_path)
|
||||
logger.debug(f"成功加载插件动作模块: {plugin_actions_path}")
|
||||
|
||||
|
||||
# 遍历actions目录中的所有Python文件
|
||||
actions_dir = os.path.dirname(actions_module.__file__)
|
||||
for file in os.listdir(actions_dir):
|
||||
if file.endswith('.py') and file != '__init__.py':
|
||||
if file.endswith(".py") and file != "__init__.py":
|
||||
action_module_name = f"{plugin_actions_path}.{file[:-3]}"
|
||||
try:
|
||||
importlib.import_module(action_module_name)
|
||||
@@ -195,25 +194,25 @@ class PluginLoader:
|
||||
loaded_count += 1
|
||||
except Exception as e:
|
||||
logger.error(f"加载动作失败: {action_module_name}, 错误: {e}")
|
||||
|
||||
|
||||
except ImportError as e:
|
||||
logger.debug(f"插件 {plugin_name} 的actions子包导入失败: {e}")
|
||||
|
||||
|
||||
return loaded_count
|
||||
|
||||
|
||||
def _load_from_commands_subdir(self, plugin_name: str, plugin_commands_path: str, plugin_commands_dir: str) -> int:
|
||||
"""从commands子目录加载命令"""
|
||||
loaded_count = 0
|
||||
|
||||
|
||||
try:
|
||||
# 尝试导入插件的commands包
|
||||
commands_module = importlib.import_module(plugin_commands_path)
|
||||
logger.debug(f"成功加载插件命令模块: {plugin_commands_path}")
|
||||
|
||||
|
||||
# 遍历commands目录中的所有Python文件
|
||||
commands_dir = os.path.dirname(commands_module.__file__)
|
||||
for file in os.listdir(commands_dir):
|
||||
if file.endswith('.py') and file != '__init__.py':
|
||||
if file.endswith(".py") and file != "__init__.py":
|
||||
command_module_name = f"{plugin_commands_path}.{file[:-3]}"
|
||||
try:
|
||||
importlib.import_module(command_module_name)
|
||||
@@ -221,29 +220,29 @@ class PluginLoader:
|
||||
loaded_count += 1
|
||||
except Exception as e:
|
||||
logger.error(f"加载命令失败: {command_module_name}, 错误: {e}")
|
||||
|
||||
|
||||
except ImportError as e:
|
||||
logger.debug(f"插件 {plugin_name} 的commands子包导入失败: {e}")
|
||||
|
||||
|
||||
return loaded_count
|
||||
|
||||
|
||||
def _load_actions_from_root_dir(self, plugin_name: str) -> int:
|
||||
"""从插件根目录加载动作文件"""
|
||||
loaded_count = 0
|
||||
|
||||
|
||||
try:
|
||||
# 导入插件包本身
|
||||
plugin_module = importlib.import_module(plugin_name)
|
||||
logger.debug(f"尝试从插件根目录加载动作: {plugin_name}")
|
||||
|
||||
|
||||
# 遍历插件根目录中的所有Python文件
|
||||
plugin_dir = os.path.dirname(plugin_module.__file__)
|
||||
for file in os.listdir(plugin_dir):
|
||||
if file.endswith('.py') and file != '__init__.py':
|
||||
if file.endswith(".py") and file != "__init__.py":
|
||||
# 跳过非动作文件(根据命名约定)
|
||||
if not (file.endswith('_action.py') or file.endswith('_actions.py') or 'action' in file):
|
||||
if not (file.endswith("_action.py") or file.endswith("_actions.py") or "action" in file):
|
||||
continue
|
||||
|
||||
|
||||
action_module_name = f"{plugin_name}.{file[:-3]}"
|
||||
try:
|
||||
importlib.import_module(action_module_name)
|
||||
@@ -251,29 +250,29 @@ class PluginLoader:
|
||||
loaded_count += 1
|
||||
except Exception as e:
|
||||
logger.error(f"加载动作失败: {action_module_name}, 错误: {e}")
|
||||
|
||||
|
||||
except ImportError as e:
|
||||
logger.debug(f"插件 {plugin_name} 导入失败: {e}")
|
||||
|
||||
|
||||
return loaded_count
|
||||
|
||||
|
||||
def _load_commands_from_root_dir(self, plugin_name: str) -> int:
|
||||
"""从插件根目录加载命令文件"""
|
||||
loaded_count = 0
|
||||
|
||||
|
||||
try:
|
||||
# 导入插件包本身
|
||||
plugin_module = importlib.import_module(plugin_name)
|
||||
logger.debug(f"尝试从插件根目录加载命令: {plugin_name}")
|
||||
|
||||
|
||||
# 遍历插件根目录中的所有Python文件
|
||||
plugin_dir = os.path.dirname(plugin_module.__file__)
|
||||
for file in os.listdir(plugin_dir):
|
||||
if file.endswith('.py') and file != '__init__.py':
|
||||
if file.endswith(".py") and file != "__init__.py":
|
||||
# 跳过非命令文件(根据命名约定)
|
||||
if not (file.endswith('_command.py') or file.endswith('_commands.py') or 'command' in file):
|
||||
if not (file.endswith("_command.py") or file.endswith("_commands.py") or "command" in file):
|
||||
continue
|
||||
|
||||
|
||||
command_module_name = f"{plugin_name}.{file[:-3]}"
|
||||
try:
|
||||
importlib.import_module(command_module_name)
|
||||
@@ -281,23 +280,25 @@ class PluginLoader:
|
||||
loaded_count += 1
|
||||
except Exception as e:
|
||||
logger.error(f"加载命令失败: {command_module_name}, 错误: {e}")
|
||||
|
||||
|
||||
except ImportError as e:
|
||||
logger.debug(f"插件 {plugin_name} 导入失败: {e}")
|
||||
|
||||
|
||||
return loaded_count
|
||||
|
||||
|
||||
def _log_loading_stats(self) -> None:
|
||||
"""输出加载统计信息"""
|
||||
logger.success(f"插件加载完成: 总计 {self.loaded_actions} 个动作, {self.loaded_commands} 个命令")
|
||||
|
||||
|
||||
if self.plugin_stats:
|
||||
logger.info("插件加载详情:")
|
||||
for plugin_name, stats in self.plugin_stats.items():
|
||||
plugin_display_name = plugin_name.split('.')[-1] # 只显示插件名称,不显示完整路径
|
||||
plugin_display_name = plugin_name.split(".")[-1] # 只显示插件名称,不显示完整路径
|
||||
source_path = self.plugin_sources.get(plugin_name, "未知路径")
|
||||
logger.info(f" {plugin_display_name} (来源: {source_path}): {stats['actions']} 动作, {stats['commands']} 命令")
|
||||
logger.info(
|
||||
f" {plugin_display_name} (来源: {source_path}): {stats['actions']} 动作, {stats['commands']} 命令"
|
||||
)
|
||||
|
||||
|
||||
# 创建全局插件加载器实例
|
||||
plugin_loader = PluginLoader()
|
||||
plugin_loader = PluginLoader()
|
||||
|
||||
@@ -23,14 +23,14 @@ class TTSAction(PluginAction):
|
||||
]
|
||||
enable_plugin = True # 启用插件
|
||||
associated_types = ["tts_text"]
|
||||
|
||||
|
||||
focus_activation_type = ActionActivationType.LLM_JUDGE
|
||||
normal_activation_type = ActionActivationType.KEYWORD
|
||||
|
||||
|
||||
# 关键词配置 - Normal模式下使用关键词触发
|
||||
activation_keywords = ["语音", "tts", "播报", "读出来", "语音播放", "听", "朗读"]
|
||||
keyword_case_sensitive = False
|
||||
|
||||
|
||||
# 并行执行设置 - TTS可以与回复并行执行,不覆盖回复内容
|
||||
parallel_action = False
|
||||
|
||||
|
||||
@@ -22,11 +22,11 @@ class VTBAction(PluginAction):
|
||||
]
|
||||
enable_plugin = True # 启用插件
|
||||
associated_types = ["vtb_text"]
|
||||
|
||||
|
||||
# 激活类型设置
|
||||
focus_activation_type = ActionActivationType.LLM_JUDGE # Focus模式使用LLM判定,精确识别情感表达需求
|
||||
normal_activation_type = ActionActivationType.RANDOM # Normal模式使用随机激活,增加趣味性
|
||||
|
||||
normal_activation_type = ActionActivationType.RANDOM # Normal模式使用随机激活,增加趣味性
|
||||
|
||||
# LLM判定提示词(用于Focus模式)
|
||||
llm_judge_prompt = """
|
||||
判定是否需要使用VTB虚拟主播动作的条件:
|
||||
@@ -41,7 +41,7 @@ class VTBAction(PluginAction):
|
||||
3. 不涉及情感的日常对话
|
||||
4. 已经有足够的情感表达
|
||||
"""
|
||||
|
||||
|
||||
# Random激活概率(用于Normal模式)
|
||||
random_activation_probability = 0.08 # 较低概率,避免过度使用
|
||||
|
||||
|
||||
Reference in New Issue
Block a user