2
.gitignore
vendored
2
.gitignore
vendored
@@ -336,3 +336,5 @@ MaiBot.code-workspace
|
||||
/tests
|
||||
/tests
|
||||
.kilocode/rules/MoFox.md
|
||||
src/chat/planner_actions/planner (2).py
|
||||
rust_video/Cargo.lock
|
||||
|
||||
1
bot.py
1
bot.py
@@ -82,7 +82,6 @@ def easter_egg():
|
||||
async def graceful_shutdown():
|
||||
try:
|
||||
logger.info("正在优雅关闭麦麦...")
|
||||
|
||||
# 停止所有异步任务
|
||||
await async_task_manager.stop_and_wait_all_tasks()
|
||||
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
from .config import global_config
|
||||
|
||||
__all__ = [
|
||||
"global_config",
|
||||
]
|
||||
@@ -1,151 +0,0 @@
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
|
||||
import tomlkit
|
||||
import shutil
|
||||
|
||||
from tomlkit import TOMLDocument
|
||||
from tomlkit.items import Table
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("napcat_adapter")
|
||||
from rich.traceback import install
|
||||
|
||||
from .config_base import ConfigBase
|
||||
from .official_configs import (
|
||||
DebugConfig,
|
||||
MaiBotServerConfig,
|
||||
NapcatServerConfig,
|
||||
NicknameConfig,
|
||||
SlicingConfig,
|
||||
VoiceConfig,
|
||||
)
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
TEMPLATE_DIR = "plugins/napcat_adapter_plugin/template"
|
||||
CONFIG_DIR = "plugins/napcat_adapter_plugin/config"
|
||||
OLD_CONFIG_DIR = "plugins/napcat_adapter_plugin/config/old"
|
||||
|
||||
|
||||
def ensure_config_directories():
|
||||
"""确保配置目录存在"""
|
||||
os.makedirs(CONFIG_DIR, exist_ok=True)
|
||||
os.makedirs(OLD_CONFIG_DIR, exist_ok=True)
|
||||
|
||||
|
||||
def update_config():
|
||||
"""更新配置文件,统一使用 config/old 目录进行备份"""
|
||||
# 确保目录存在
|
||||
ensure_config_directories()
|
||||
|
||||
# 定义文件路径
|
||||
template_path = f"{TEMPLATE_DIR}/template_config.toml"
|
||||
config_path = f"{CONFIG_DIR}/config.toml"
|
||||
|
||||
# 检查配置文件是否存在
|
||||
if not os.path.exists(config_path):
|
||||
logger.info("主配置文件不存在,从模板创建新配置")
|
||||
shutil.copy2(template_path, config_path)
|
||||
logger.info(f"已创建新配置文件: {config_path}")
|
||||
logger.info("程序将退出,请检查配置文件后重启")
|
||||
|
||||
# 读取配置文件和模板文件
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
old_config = tomlkit.load(f)
|
||||
with open(template_path, "r", encoding="utf-8") as f:
|
||||
new_config = tomlkit.load(f)
|
||||
|
||||
# 检查version是否相同
|
||||
if old_config and "inner" in old_config and "inner" in new_config:
|
||||
old_version = old_config["inner"].get("version")
|
||||
new_version = new_config["inner"].get("version")
|
||||
if old_version and new_version and old_version == new_version:
|
||||
logger.info(f"检测到配置文件版本号相同 (v{old_version}),跳过更新")
|
||||
return
|
||||
else:
|
||||
logger.info(f"检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}")
|
||||
else:
|
||||
logger.info("已有配置文件未检测到版本号,可能是旧版本。将进行更新")
|
||||
|
||||
# 创建备份文件
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
backup_path = os.path.join(OLD_CONFIG_DIR, f"config.toml.bak.{timestamp}")
|
||||
|
||||
# 备份旧配置文件
|
||||
shutil.copy2(config_path, backup_path)
|
||||
logger.info(f"已备份旧配置文件到: {backup_path}")
|
||||
|
||||
# 复制模板文件到配置目录
|
||||
shutil.copy2(template_path, config_path)
|
||||
logger.info(f"已创建新配置文件: {config_path}")
|
||||
|
||||
def update_dict(target: TOMLDocument | dict, source: TOMLDocument | dict):
|
||||
"""将source字典的值更新到target字典中(如果target中存在相同的键)"""
|
||||
for key, value in source.items():
|
||||
# 跳过version字段的更新
|
||||
if key == "version":
|
||||
continue
|
||||
if key in target:
|
||||
if isinstance(value, dict) and isinstance(target[key], (dict, Table)):
|
||||
update_dict(target[key], value)
|
||||
else:
|
||||
try:
|
||||
# 对数组类型进行特殊处理
|
||||
if isinstance(value, list):
|
||||
# 如果是空数组,确保它保持为空数组
|
||||
target[key] = tomlkit.array(str(value)) if value else tomlkit.array()
|
||||
else:
|
||||
# 其他类型使用item方法创建新值
|
||||
target[key] = tomlkit.item(value)
|
||||
except (TypeError, ValueError):
|
||||
# 如果转换失败,直接赋值
|
||||
target[key] = value
|
||||
|
||||
# 将旧配置的值更新到新配置中
|
||||
logger.info("开始合并新旧配置...")
|
||||
update_dict(new_config, old_config)
|
||||
|
||||
# 保存更新后的配置(保留注释和格式)
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
f.write(tomlkit.dumps(new_config))
|
||||
logger.info("配置文件更新完成,建议检查新配置文件中的内容,以免丢失重要信息")
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config(ConfigBase):
|
||||
"""总配置类"""
|
||||
|
||||
nickname: NicknameConfig
|
||||
napcat_server: NapcatServerConfig
|
||||
maibot_server: MaiBotServerConfig
|
||||
voice: VoiceConfig
|
||||
slicing: SlicingConfig
|
||||
debug: DebugConfig
|
||||
|
||||
|
||||
def load_config(config_path: str) -> Config:
|
||||
"""
|
||||
加载配置文件
|
||||
:param config_path: 配置文件路径
|
||||
:return: Config对象
|
||||
"""
|
||||
# 读取配置文件
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config_data = tomlkit.load(f)
|
||||
|
||||
# 创建Config对象
|
||||
try:
|
||||
return Config.from_dict(config_data)
|
||||
except Exception as e:
|
||||
logger.critical("配置文件解析失败")
|
||||
raise e
|
||||
|
||||
|
||||
# 更新配置
|
||||
update_config()
|
||||
|
||||
logger.info("正在品鉴配置文件...")
|
||||
global_config = load_config(config_path=f"{CONFIG_DIR}/config.toml")
|
||||
logger.info("非常的新鲜,非常的美味!")
|
||||
@@ -1,136 +0,0 @@
|
||||
from dataclasses import dataclass, fields, MISSING
|
||||
from typing import TypeVar, Type, Any, get_origin, get_args, Literal, Dict, Union
|
||||
|
||||
T = TypeVar("T", bound="ConfigBase")
|
||||
|
||||
TOML_DICT_TYPE = {
|
||||
int,
|
||||
float,
|
||||
str,
|
||||
bool,
|
||||
list,
|
||||
dict,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConfigBase:
|
||||
"""配置类的基类"""
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls: Type[T], data: Dict[str, Any]) -> T:
|
||||
"""从字典加载配置字段"""
|
||||
if not isinstance(data, dict):
|
||||
raise TypeError(f"Expected a dictionary, got {type(data).__name__}")
|
||||
|
||||
init_args: Dict[str, Any] = {}
|
||||
|
||||
for f in fields(cls):
|
||||
field_name = f.name
|
||||
field_type = f.type
|
||||
if field_name.startswith("_"):
|
||||
# 跳过以 _ 开头的字段
|
||||
continue
|
||||
|
||||
if field_name not in data:
|
||||
if f.default is not MISSING or f.default_factory is not MISSING:
|
||||
# 跳过未提供且有默认值/默认构造方法的字段
|
||||
continue
|
||||
else:
|
||||
raise ValueError(f"Missing required field: '{field_name}'")
|
||||
|
||||
value = data[field_name]
|
||||
try:
|
||||
init_args[field_name] = cls._convert_field(value, field_type)
|
||||
except TypeError as e:
|
||||
raise TypeError(f"字段 '{field_name}' 出现类型错误: {e}") from e
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"无法将字段 '{field_name}' 转换为目标类型,出现错误: {e}") from e
|
||||
|
||||
return cls(**init_args)
|
||||
|
||||
@classmethod
|
||||
def _convert_field(cls, value: Any, field_type: Type[Any]) -> Any:
|
||||
"""
|
||||
转换字段值为指定类型
|
||||
|
||||
1. 对于嵌套的 dataclass,递归调用相应的 from_dict 方法
|
||||
2. 对于泛型集合类型(list, set, tuple),递归转换每个元素
|
||||
3. 对于基础类型(int, str, float, bool),直接转换
|
||||
4. 对于其他类型,尝试直接转换,如果失败则抛出异常
|
||||
"""
|
||||
# 如果是嵌套的 dataclass,递归调用 from_dict 方法
|
||||
if isinstance(field_type, type) and issubclass(field_type, ConfigBase):
|
||||
return field_type.from_dict(value)
|
||||
|
||||
field_origin_type = get_origin(field_type)
|
||||
field_args_type = get_args(field_type)
|
||||
|
||||
# 处理泛型集合类型(list, set, tuple)
|
||||
if field_origin_type in {list, set, tuple}:
|
||||
# 检查提供的value是否为list
|
||||
if not isinstance(value, list):
|
||||
raise TypeError(f"Expected an list for {field_type.__name__}, got {type(value).__name__}")
|
||||
|
||||
if field_origin_type is list:
|
||||
return [cls._convert_field(item, field_args_type[0]) for item in value]
|
||||
if field_origin_type is set:
|
||||
return {cls._convert_field(item, field_args_type[0]) for item in value}
|
||||
if field_origin_type is tuple:
|
||||
# 检查提供的value长度是否与类型参数一致
|
||||
if len(value) != len(field_args_type):
|
||||
raise TypeError(
|
||||
f"Expected {len(field_args_type)} items for {field_type.__name__}, got {len(value)}"
|
||||
)
|
||||
return tuple(cls._convert_field(item, arg_type) for item, arg_type in zip(value, field_args_type))
|
||||
|
||||
if field_origin_type is dict:
|
||||
# 检查提供的value是否为dict
|
||||
if not isinstance(value, dict):
|
||||
raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}")
|
||||
|
||||
# 检查字典的键值类型
|
||||
if len(field_args_type) != 2:
|
||||
raise TypeError(f"Expected a dictionary with two type arguments for {field_type.__name__}")
|
||||
key_type, value_type = field_args_type
|
||||
|
||||
return {cls._convert_field(k, key_type): cls._convert_field(v, value_type) for k, v in value.items()}
|
||||
|
||||
# 处理Optional类型
|
||||
if field_origin_type is Union: # assert get_origin(Optional[Any]) is Union
|
||||
if value is None:
|
||||
return None
|
||||
# 如果有数据,检查实际类型
|
||||
if type(value) not in field_args_type:
|
||||
raise TypeError(f"Expected {field_args_type} for {field_type.__name__}, got {type(value).__name__}")
|
||||
return cls._convert_field(value, field_args_type[0])
|
||||
|
||||
# 处理int, str, float, bool等基础类型
|
||||
if field_origin_type is None:
|
||||
if isinstance(value, field_type):
|
||||
return field_type(value)
|
||||
else:
|
||||
raise TypeError(f"Expected {field_type.__name__}, got {type(value).__name__}")
|
||||
|
||||
# 处理Literal类型
|
||||
if field_origin_type is Literal:
|
||||
# 获取Literal的允许值
|
||||
allowed_values = get_args(field_type)
|
||||
if value in allowed_values:
|
||||
return value
|
||||
else:
|
||||
raise TypeError(f"Value '{value}' is not in allowed values {allowed_values} for Literal type")
|
||||
|
||||
# 处理其他类型
|
||||
if field_type is Any:
|
||||
return value
|
||||
|
||||
# 其他类型直接转换
|
||||
try:
|
||||
return field_type(value)
|
||||
except (ValueError, TypeError) as e:
|
||||
raise TypeError(f"无法将 {type(value).__name__} 转换为 {field_type.__name__}") from e
|
||||
|
||||
def __str__(self):
|
||||
"""返回配置类的字符串表示"""
|
||||
return f"{self.__class__.__name__}({', '.join(f'{f.name}={getattr(self, f.name)}' for f in fields(self))})"
|
||||
@@ -1,145 +0,0 @@
|
||||
"""
|
||||
配置文件工具模块
|
||||
提供统一的配置文件生成和管理功能
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("napcat_adapter")
|
||||
|
||||
|
||||
def ensure_config_directories():
|
||||
"""确保配置目录存在"""
|
||||
os.makedirs("config", exist_ok=True)
|
||||
os.makedirs("config/old", exist_ok=True)
|
||||
|
||||
|
||||
def create_config_from_template(
|
||||
config_path: str, template_path: str, config_name: str = "配置文件", should_exit: bool = True
|
||||
) -> bool:
|
||||
"""
|
||||
从模板创建配置文件的统一函数
|
||||
|
||||
Args:
|
||||
config_path: 配置文件路径
|
||||
template_path: 模板文件路径
|
||||
config_name: 配置文件名称(用于日志显示)
|
||||
should_exit: 创建后是否退出程序
|
||||
|
||||
Returns:
|
||||
bool: 是否成功创建配置文件
|
||||
"""
|
||||
try:
|
||||
# 确保配置目录存在
|
||||
ensure_config_directories()
|
||||
|
||||
config_path_obj = Path(config_path)
|
||||
template_path_obj = Path(template_path)
|
||||
|
||||
# 检查配置文件是否存在
|
||||
if config_path_obj.exists():
|
||||
return False # 配置文件已存在,无需创建
|
||||
|
||||
logger.info(f"{config_name}不存在,从模板创建新配置")
|
||||
|
||||
# 检查模板文件是否存在
|
||||
if not template_path_obj.exists():
|
||||
logger.error(f"模板文件不存在: {template_path}")
|
||||
if should_exit:
|
||||
logger.critical("无法创建配置文件,程序退出")
|
||||
quit(1)
|
||||
return False
|
||||
|
||||
# 确保配置文件目录存在
|
||||
config_path_obj.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 复制模板文件到配置目录
|
||||
shutil.copy2(template_path_obj, config_path_obj)
|
||||
logger.info(f"已创建新{config_name}: {config_path}")
|
||||
|
||||
if should_exit:
|
||||
logger.info("程序将退出,请检查配置文件后重启")
|
||||
quit(0)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建{config_name}失败: {e}")
|
||||
if should_exit:
|
||||
logger.critical("无法创建配置文件,程序退出")
|
||||
quit(1)
|
||||
return False
|
||||
|
||||
|
||||
def create_default_config_dict(default_values: dict, config_path: str, config_name: str = "配置文件") -> bool:
|
||||
"""
|
||||
创建默认配置文件(使用字典数据)
|
||||
|
||||
Args:
|
||||
default_values: 默认配置值字典
|
||||
config_path: 配置文件路径
|
||||
config_name: 配置文件名称(用于日志显示)
|
||||
|
||||
Returns:
|
||||
bool: 是否成功创建配置文件
|
||||
"""
|
||||
try:
|
||||
import tomlkit
|
||||
|
||||
config_path_obj = Path(config_path)
|
||||
|
||||
# 确保配置文件目录存在
|
||||
config_path_obj.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 写入默认配置
|
||||
with open(config_path_obj, "w", encoding="utf-8") as f:
|
||||
tomlkit.dump(default_values, f)
|
||||
|
||||
logger.info(f"已创建默认{config_name}: {config_path}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建默认{config_name}失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def backup_config_file(config_path: str, backup_dir: str = "config/old") -> Optional[str]:
|
||||
"""
|
||||
备份配置文件
|
||||
|
||||
Args:
|
||||
config_path: 要备份的配置文件路径
|
||||
backup_dir: 备份目录
|
||||
|
||||
Returns:
|
||||
Optional[str]: 备份文件路径,失败时返回None
|
||||
"""
|
||||
try:
|
||||
config_path_obj = Path(config_path)
|
||||
if not config_path_obj.exists():
|
||||
return None
|
||||
|
||||
# 确保备份目录存在
|
||||
backup_dir_obj = Path(backup_dir)
|
||||
backup_dir_obj.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 创建备份文件名
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
backup_filename = f"{config_path_obj.stem}.toml.bak.{timestamp}"
|
||||
backup_path = backup_dir_obj / backup_filename
|
||||
|
||||
# 备份文件
|
||||
shutil.copy2(config_path_obj, backup_path)
|
||||
logger.info(f"已备份配置文件到: {backup_path}")
|
||||
|
||||
return str(backup_path)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"备份配置文件失败: {e}")
|
||||
return None
|
||||
@@ -1,359 +0,0 @@
|
||||
import asyncio
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal, Optional
|
||||
from pathlib import Path
|
||||
import tomlkit
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("napcat_adapter")
|
||||
from .config_base import ConfigBase
|
||||
from .config_utils import create_config_from_template, create_default_config_dict
|
||||
|
||||
|
||||
@dataclass
|
||||
class FeaturesConfig(ConfigBase):
|
||||
"""功能配置类"""
|
||||
|
||||
group_list_type: Literal["whitelist", "blacklist"] = "whitelist"
|
||||
"""群聊列表类型 白名单/黑名单"""
|
||||
|
||||
group_list: list[int] = field(default_factory=list)
|
||||
"""群聊列表"""
|
||||
|
||||
private_list_type: Literal["whitelist", "blacklist"] = "whitelist"
|
||||
"""私聊列表类型 白名单/黑名单"""
|
||||
|
||||
private_list: list[int] = field(default_factory=list)
|
||||
"""私聊列表"""
|
||||
|
||||
ban_user_id: list[int] = field(default_factory=list)
|
||||
"""被封禁的用户ID列表,封禁后将无法与其进行交互"""
|
||||
|
||||
ban_qq_bot: bool = False
|
||||
"""是否屏蔽QQ官方机器人,若为True,则所有QQ官方机器人将无法与MaiMCore进行交互"""
|
||||
|
||||
enable_poke: bool = True
|
||||
"""是否启用戳一戳功能"""
|
||||
|
||||
ignore_non_self_poke: bool = False
|
||||
"""是否无视不是针对自己的戳一戳"""
|
||||
|
||||
poke_debounce_seconds: int = 3
|
||||
"""戳一戳防抖时间(秒),在指定时间内第二次针对机器人的戳一戳将被忽略"""
|
||||
|
||||
enable_reply_at: bool = True
|
||||
"""是否启用引用回复时艾特用户的功能"""
|
||||
|
||||
reply_at_rate: float = 0.5
|
||||
"""引用回复时艾特用户的几率 (0.0 ~ 1.0)"""
|
||||
|
||||
enable_video_analysis: bool = True
|
||||
"""是否启用视频识别功能"""
|
||||
|
||||
max_video_size_mb: int = 100
|
||||
"""视频文件最大大小限制(MB)"""
|
||||
|
||||
download_timeout: int = 60
|
||||
"""视频下载超时时间(秒)"""
|
||||
|
||||
supported_formats: list[str] = field(default_factory=lambda: ["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"])
|
||||
"""支持的视频格式"""
|
||||
|
||||
# 消息缓冲配置
|
||||
enable_message_buffer: bool = True
|
||||
"""是否启用消息缓冲合并功能"""
|
||||
|
||||
message_buffer_enable_group: bool = True
|
||||
"""是否启用群消息缓冲合并"""
|
||||
|
||||
message_buffer_enable_private: bool = True
|
||||
"""是否启用私聊消息缓冲合并"""
|
||||
|
||||
message_buffer_interval: float = 3.0
|
||||
"""消息合并间隔时间(秒),在此时间内的连续消息将被合并"""
|
||||
|
||||
message_buffer_initial_delay: float = 0.5
|
||||
"""消息缓冲初始延迟(秒),收到第一条消息后等待此时间开始合并"""
|
||||
|
||||
message_buffer_max_components: int = 50
|
||||
"""单个会话最大缓冲消息组件数量,超过此数量将强制合并"""
|
||||
|
||||
message_buffer_block_prefixes: list[str] = field(default_factory=lambda: ["/", "!", "!", ".", "。", "#", "%"])
|
||||
"""消息缓冲屏蔽前缀,以这些前缀开头的消息不会被缓冲"""
|
||||
|
||||
|
||||
class FeaturesManager:
|
||||
"""功能管理器,支持热重载"""
|
||||
|
||||
def __init__(self, config_path: str = "plugins/napcat_adapter_plugin/config/features.toml"):
|
||||
self.config_path = Path(config_path)
|
||||
self.config: Optional[FeaturesConfig] = None
|
||||
self._file_watcher_task: Optional[asyncio.Task] = None
|
||||
self._last_modified: Optional[float] = None
|
||||
self._callbacks: list = []
|
||||
|
||||
def add_reload_callback(self, callback):
|
||||
"""添加配置重载回调函数"""
|
||||
self._callbacks.append(callback)
|
||||
|
||||
def remove_reload_callback(self, callback):
|
||||
"""移除配置重载回调函数"""
|
||||
if callback in self._callbacks:
|
||||
self._callbacks.remove(callback)
|
||||
|
||||
async def _notify_callbacks(self):
|
||||
"""通知所有回调函数配置已重载"""
|
||||
for callback in self._callbacks:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(callback):
|
||||
await callback(self.config)
|
||||
else:
|
||||
callback(self.config)
|
||||
except Exception as e:
|
||||
logger.error(f"配置重载回调执行失败: {e}")
|
||||
|
||||
def load_config(self) -> FeaturesConfig:
|
||||
"""加载功能配置文件"""
|
||||
try:
|
||||
# 检查配置文件是否存在,如果不存在则创建并退出程序
|
||||
if not self.config_path.exists():
|
||||
logger.info(f"功能配置文件不存在: {self.config_path}")
|
||||
self._create_default_config()
|
||||
# 配置文件创建后程序应该退出,让用户检查配置
|
||||
logger.info("程序将退出,请检查功能配置文件后重启")
|
||||
quit(0)
|
||||
|
||||
with open(self.config_path, "r", encoding="utf-8") as f:
|
||||
config_data = tomlkit.load(f)
|
||||
|
||||
self.config = FeaturesConfig.from_dict(config_data)
|
||||
self._last_modified = self.config_path.stat().st_mtime
|
||||
logger.info(f"功能配置加载成功: {self.config_path}")
|
||||
return self.config
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"功能配置加载失败: {e}")
|
||||
logger.critical("无法加载功能配置文件,程序退出")
|
||||
quit(1)
|
||||
|
||||
def _create_default_config(self):
|
||||
"""创建默认功能配置文件"""
|
||||
template_path = "template/features_template.toml"
|
||||
|
||||
# 尝试从模板创建配置文件
|
||||
if create_config_from_template(
|
||||
str(self.config_path),
|
||||
template_path,
|
||||
"功能配置文件",
|
||||
should_exit=False, # 不在这里退出,由调用方决定
|
||||
):
|
||||
return
|
||||
|
||||
# 如果模板文件不存在,创建基本配置
|
||||
logger.info("模板文件不存在,创建基本功能配置")
|
||||
default_config = {
|
||||
"group_list_type": "whitelist",
|
||||
"group_list": [],
|
||||
"private_list_type": "whitelist",
|
||||
"private_list": [],
|
||||
"ban_user_id": [],
|
||||
"ban_qq_bot": False,
|
||||
"enable_poke": True,
|
||||
"ignore_non_self_poke": False,
|
||||
"poke_debounce_seconds": 3,
|
||||
"enable_reply_at": True,
|
||||
"reply_at_rate": 0.5,
|
||||
"enable_video_analysis": True,
|
||||
"max_video_size_mb": 100,
|
||||
"download_timeout": 60,
|
||||
"supported_formats": ["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"],
|
||||
# 消息缓冲配置
|
||||
"enable_message_buffer": True,
|
||||
"message_buffer_enable_group": True,
|
||||
"message_buffer_enable_private": True,
|
||||
"message_buffer_interval": 3.0,
|
||||
"message_buffer_initial_delay": 0.5,
|
||||
"message_buffer_max_components": 50,
|
||||
"message_buffer_block_prefixes": ["/", "!", "!", ".", "。", "#", "%"],
|
||||
}
|
||||
|
||||
if not create_default_config_dict(default_config, str(self.config_path), "功能配置文件"):
|
||||
logger.critical("无法创建功能配置文件")
|
||||
quit(1)
|
||||
|
||||
async def reload_config(self) -> bool:
|
||||
"""重新加载配置文件"""
|
||||
try:
|
||||
if not self.config_path.exists():
|
||||
logger.warning(f"功能配置文件不存在,无法重载: {self.config_path}")
|
||||
return False
|
||||
|
||||
current_modified = self.config_path.stat().st_mtime
|
||||
if self._last_modified and current_modified <= self._last_modified:
|
||||
return False # 文件未修改
|
||||
|
||||
old_config = self.config
|
||||
new_config = self.load_config()
|
||||
|
||||
# 检查配置是否真的发生了变化
|
||||
if old_config and self._configs_equal(old_config, new_config):
|
||||
return False
|
||||
|
||||
logger.info("功能配置已重载")
|
||||
await self._notify_callbacks()
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"功能配置重载失败: {e}")
|
||||
return False
|
||||
|
||||
def _configs_equal(self, config1: FeaturesConfig, config2: FeaturesConfig) -> bool:
|
||||
"""比较两个配置是否相等"""
|
||||
return (
|
||||
config1.group_list_type == config2.group_list_type
|
||||
and set(config1.group_list) == set(config2.group_list)
|
||||
and config1.private_list_type == config2.private_list_type
|
||||
and set(config1.private_list) == set(config2.private_list)
|
||||
and set(config1.ban_user_id) == set(config2.ban_user_id)
|
||||
and config1.ban_qq_bot == config2.ban_qq_bot
|
||||
and config1.enable_poke == config2.enable_poke
|
||||
and config1.ignore_non_self_poke == config2.ignore_non_self_poke
|
||||
and config1.poke_debounce_seconds == config2.poke_debounce_seconds
|
||||
and config1.enable_reply_at == config2.enable_reply_at
|
||||
and config1.reply_at_rate == config2.reply_at_rate
|
||||
and config1.enable_video_analysis == config2.enable_video_analysis
|
||||
and config1.max_video_size_mb == config2.max_video_size_mb
|
||||
and config1.download_timeout == config2.download_timeout
|
||||
and set(config1.supported_formats) == set(config2.supported_formats)
|
||||
and
|
||||
# 消息缓冲配置比较
|
||||
config1.enable_message_buffer == config2.enable_message_buffer
|
||||
and config1.message_buffer_enable_group == config2.message_buffer_enable_group
|
||||
and config1.message_buffer_enable_private == config2.message_buffer_enable_private
|
||||
and config1.message_buffer_interval == config2.message_buffer_interval
|
||||
and config1.message_buffer_initial_delay == config2.message_buffer_initial_delay
|
||||
and config1.message_buffer_max_components == config2.message_buffer_max_components
|
||||
and set(config1.message_buffer_block_prefixes) == set(config2.message_buffer_block_prefixes)
|
||||
)
|
||||
|
||||
async def start_file_watcher(self, check_interval: float = 1.0):
|
||||
"""启动文件监控,定期检查配置文件变化"""
|
||||
if self._file_watcher_task and not self._file_watcher_task.done():
|
||||
logger.warning("文件监控已在运行")
|
||||
return
|
||||
|
||||
self._file_watcher_task = asyncio.create_task(self._file_watcher_loop(check_interval))
|
||||
logger.info(f"功能配置文件监控已启动,检查间隔: {check_interval}秒")
|
||||
|
||||
async def stop_file_watcher(self):
|
||||
"""停止文件监控"""
|
||||
if self._file_watcher_task and not self._file_watcher_task.done():
|
||||
self._file_watcher_task.cancel()
|
||||
try:
|
||||
await self._file_watcher_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info("功能配置文件监控已停止")
|
||||
|
||||
async def _file_watcher_loop(self, check_interval: float):
|
||||
"""文件监控循环"""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(check_interval)
|
||||
await self.reload_config()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"文件监控循环出错: {e}")
|
||||
await asyncio.sleep(check_interval)
|
||||
|
||||
def get_config(self) -> FeaturesConfig:
|
||||
"""获取当前功能配置"""
|
||||
if self.config is None:
|
||||
return self.load_config()
|
||||
return self.config
|
||||
|
||||
def is_group_allowed(self, group_id: int) -> bool:
|
||||
"""检查群聊是否被允许"""
|
||||
config = self.get_config()
|
||||
if config.group_list_type == "whitelist":
|
||||
return group_id in config.group_list
|
||||
else: # blacklist
|
||||
return group_id not in config.group_list
|
||||
|
||||
def is_private_allowed(self, user_id: int) -> bool:
|
||||
"""检查私聊是否被允许"""
|
||||
config = self.get_config()
|
||||
if config.private_list_type == "whitelist":
|
||||
return user_id in config.private_list
|
||||
else: # blacklist
|
||||
return user_id not in config.private_list
|
||||
|
||||
def is_user_banned(self, user_id: int) -> bool:
|
||||
"""检查用户是否被全局禁止"""
|
||||
config = self.get_config()
|
||||
return user_id in config.ban_user_id
|
||||
|
||||
def is_qq_bot_banned(self) -> bool:
|
||||
"""检查是否禁止QQ官方机器人"""
|
||||
config = self.get_config()
|
||||
return config.ban_qq_bot
|
||||
|
||||
def is_poke_enabled(self) -> bool:
|
||||
"""检查戳一戳功能是否启用"""
|
||||
config = self.get_config()
|
||||
return config.enable_poke
|
||||
|
||||
def is_non_self_poke_ignored(self) -> bool:
|
||||
"""检查是否忽略非自己戳一戳"""
|
||||
config = self.get_config()
|
||||
return config.ignore_non_self_poke
|
||||
|
||||
def is_message_buffer_enabled(self) -> bool:
|
||||
"""检查消息缓冲功能是否启用"""
|
||||
config = self.get_config()
|
||||
return config.enable_message_buffer
|
||||
|
||||
def is_message_buffer_group_enabled(self) -> bool:
|
||||
"""检查群消息缓冲是否启用"""
|
||||
config = self.get_config()
|
||||
return config.message_buffer_enable_group
|
||||
|
||||
def is_message_buffer_private_enabled(self) -> bool:
|
||||
"""检查私聊消息缓冲是否启用"""
|
||||
config = self.get_config()
|
||||
return config.message_buffer_enable_private
|
||||
|
||||
def get_message_buffer_interval(self) -> float:
|
||||
"""获取消息缓冲间隔时间"""
|
||||
config = self.get_config()
|
||||
return config.message_buffer_interval
|
||||
|
||||
def get_message_buffer_initial_delay(self) -> float:
|
||||
"""获取消息缓冲初始延迟"""
|
||||
config = self.get_config()
|
||||
return config.message_buffer_initial_delay
|
||||
|
||||
def get_message_buffer_max_components(self) -> int:
|
||||
"""获取消息缓冲最大组件数量"""
|
||||
config = self.get_config()
|
||||
return config.message_buffer_max_components
|
||||
|
||||
def is_message_buffer_group_enabled(self) -> bool:
|
||||
"""检查是否启用群聊消息缓冲"""
|
||||
config = self.get_config()
|
||||
return config.message_buffer_enable_group
|
||||
|
||||
def is_message_buffer_private_enabled(self) -> bool:
|
||||
"""检查是否启用私聊消息缓冲"""
|
||||
config = self.get_config()
|
||||
return config.message_buffer_enable_private
|
||||
|
||||
def get_message_buffer_block_prefixes(self) -> list[str]:
|
||||
"""获取消息缓冲屏蔽前缀列表"""
|
||||
config = self.get_config()
|
||||
return config.message_buffer_block_prefixes
|
||||
|
||||
|
||||
# 全局功能管理器实例
|
||||
features_manager = FeaturesManager()
|
||||
@@ -1,215 +0,0 @@
|
||||
"""
|
||||
功能配置迁移脚本
|
||||
用于将旧的配置文件中的聊天、权限、视频处理等设置迁移到新的独立功能配置文件
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
import tomlkit
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("napcat_adapter")
|
||||
|
||||
|
||||
def migrate_features_from_config(
|
||||
old_config_path: str = "plugins/napcat_adapter_plugin/config/config.toml",
|
||||
new_features_path: str = "plugins/napcat_adapter_plugin/config/features.toml",
|
||||
template_path: str = "plugins/napcat_adapter_plugin/template/features_template.toml",
|
||||
):
|
||||
"""
|
||||
从旧配置文件迁移功能设置到新的功能配置文件
|
||||
|
||||
Args:
|
||||
old_config_path: 旧配置文件路径
|
||||
new_features_path: 新功能配置文件路径
|
||||
template_path: 功能配置模板路径
|
||||
"""
|
||||
try:
|
||||
# 检查旧配置文件是否存在
|
||||
if not os.path.exists(old_config_path):
|
||||
logger.warning(f"旧配置文件不存在: {old_config_path}")
|
||||
return False
|
||||
|
||||
# 读取旧配置文件
|
||||
with open(old_config_path, "r", encoding="utf-8") as f:
|
||||
old_config = tomlkit.load(f)
|
||||
|
||||
# 检查是否有chat配置段和video配置段
|
||||
chat_config = old_config.get("chat", {})
|
||||
video_config = old_config.get("video", {})
|
||||
|
||||
# 检查是否有权限相关配置
|
||||
permission_keys = [
|
||||
"group_list_type",
|
||||
"group_list",
|
||||
"private_list_type",
|
||||
"private_list",
|
||||
"ban_user_id",
|
||||
"ban_qq_bot",
|
||||
"enable_poke",
|
||||
"ignore_non_self_poke",
|
||||
"poke_debounce_seconds",
|
||||
]
|
||||
video_keys = ["enable_video_analysis", "max_video_size_mb", "download_timeout", "supported_formats"]
|
||||
|
||||
has_permission_config = any(key in chat_config for key in permission_keys)
|
||||
has_video_config = any(key in video_config for key in video_keys)
|
||||
|
||||
if not has_permission_config and not has_video_config:
|
||||
logger.info("旧配置文件中没有找到功能相关配置,无需迁移")
|
||||
return False
|
||||
|
||||
# 确保新功能配置目录存在
|
||||
new_features_dir = Path(new_features_path).parent
|
||||
new_features_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 如果新功能配置文件已存在,先备份
|
||||
if os.path.exists(new_features_path):
|
||||
backup_path = f"{new_features_path}.backup"
|
||||
shutil.copy2(new_features_path, backup_path)
|
||||
logger.info(f"已备份现有功能配置文件到: {backup_path}")
|
||||
|
||||
# 创建新的功能配置
|
||||
new_features_config = {
|
||||
"group_list_type": chat_config.get("group_list_type", "whitelist"),
|
||||
"group_list": chat_config.get("group_list", []),
|
||||
"private_list_type": chat_config.get("private_list_type", "whitelist"),
|
||||
"private_list": chat_config.get("private_list", []),
|
||||
"ban_user_id": chat_config.get("ban_user_id", []),
|
||||
"ban_qq_bot": chat_config.get("ban_qq_bot", False),
|
||||
"enable_poke": chat_config.get("enable_poke", True),
|
||||
"ignore_non_self_poke": chat_config.get("ignore_non_self_poke", False),
|
||||
"poke_debounce_seconds": chat_config.get("poke_debounce_seconds", 3),
|
||||
"enable_video_analysis": video_config.get("enable_video_analysis", True),
|
||||
"max_video_size_mb": video_config.get("max_video_size_mb", 100),
|
||||
"download_timeout": video_config.get("download_timeout", 60),
|
||||
"supported_formats": video_config.get(
|
||||
"supported_formats", ["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"]
|
||||
),
|
||||
}
|
||||
|
||||
# 写入新的功能配置文件
|
||||
with open(new_features_path, "w", encoding="utf-8") as f:
|
||||
tomlkit.dump(new_features_config, f)
|
||||
|
||||
logger.info(f"功能配置已成功迁移到: {new_features_path}")
|
||||
|
||||
# 显示迁移的配置内容
|
||||
logger.info("迁移的配置内容:")
|
||||
for key, value in new_features_config.items():
|
||||
logger.info(f" {key}: {value}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"功能配置迁移失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def remove_features_from_old_config(config_path: str = "plugins/napcat_adapter_plugin/config/config.toml"):
|
||||
"""
|
||||
从旧配置文件中移除功能相关配置,并将旧配置移动到 config/old/ 目录
|
||||
|
||||
Args:
|
||||
config_path: 配置文件路径
|
||||
"""
|
||||
try:
|
||||
if not os.path.exists(config_path):
|
||||
logger.warning(f"配置文件不存在: {config_path}")
|
||||
return False
|
||||
|
||||
# 确保 config/old 目录存在
|
||||
old_config_dir = "plugins/napcat_adapter_plugin/config/old"
|
||||
os.makedirs(old_config_dir, exist_ok=True)
|
||||
|
||||
# 备份原配置文件到 config/old 目录
|
||||
old_config_path = os.path.join(old_config_dir, "config_with_features.toml")
|
||||
shutil.copy2(config_path, old_config_path)
|
||||
logger.info(f"已备份包含功能配置的原文件到: {old_config_path}")
|
||||
|
||||
# 读取配置文件
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config = tomlkit.load(f)
|
||||
|
||||
# 移除chat段中的功能相关配置
|
||||
removed_keys = []
|
||||
if "chat" in config:
|
||||
chat_config = config["chat"]
|
||||
permission_keys = [
|
||||
"group_list_type",
|
||||
"group_list",
|
||||
"private_list_type",
|
||||
"private_list",
|
||||
"ban_user_id",
|
||||
"ban_qq_bot",
|
||||
"enable_poke",
|
||||
"ignore_non_self_poke",
|
||||
"poke_debounce_seconds",
|
||||
]
|
||||
|
||||
for key in permission_keys:
|
||||
if key in chat_config:
|
||||
del chat_config[key]
|
||||
removed_keys.append(key)
|
||||
|
||||
if removed_keys:
|
||||
logger.info(f"已从chat配置段中移除功能相关配置: {removed_keys}")
|
||||
|
||||
# 移除video段中的配置
|
||||
if "video" in config:
|
||||
video_config = config["video"]
|
||||
video_keys = ["enable_video_analysis", "max_video_size_mb", "download_timeout", "supported_formats"]
|
||||
|
||||
video_removed_keys = []
|
||||
for key in video_keys:
|
||||
if key in video_config:
|
||||
del video_config[key]
|
||||
video_removed_keys.append(key)
|
||||
|
||||
if video_removed_keys:
|
||||
logger.info(f"已从video配置段中移除配置: {video_removed_keys}")
|
||||
removed_keys.extend(video_removed_keys)
|
||||
|
||||
# 如果video段为空,则删除整个段
|
||||
if not video_config:
|
||||
del config["video"]
|
||||
logger.info("已删除空的video配置段")
|
||||
|
||||
if removed_keys:
|
||||
logger.info(f"总共移除的配置项: {removed_keys}")
|
||||
|
||||
# 写回配置文件
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
f.write(tomlkit.dumps(config))
|
||||
|
||||
logger.info(f"已更新配置文件: {config_path}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"移除功能配置失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def auto_migrate_features():
|
||||
"""
|
||||
自动执行功能配置迁移
|
||||
"""
|
||||
logger.info("开始自动功能配置迁移...")
|
||||
|
||||
# 执行迁移
|
||||
if migrate_features_from_config():
|
||||
logger.info("功能配置迁移成功")
|
||||
|
||||
# 询问是否要从旧配置文件中移除功能配置
|
||||
logger.info("功能配置已迁移到独立文件,建议从主配置文件中移除相关配置")
|
||||
# 在实际使用中,这里可以添加用户确认逻辑
|
||||
# 为了自动化,这里直接执行移除
|
||||
remove_features_from_old_config()
|
||||
|
||||
else:
|
||||
logger.info("功能配置迁移跳过或失败")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
auto_migrate_features()
|
||||
@@ -1,72 +0,0 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal
|
||||
|
||||
from .config_base import ConfigBase
|
||||
|
||||
"""
|
||||
须知:
|
||||
1. 本文件中记录了所有的配置项
|
||||
2. 所有新增的class都需要继承自ConfigBase
|
||||
3. 所有新增的class都应在config.py中的Config类中添加字段
|
||||
4. 对于新增的字段,若为可选项,则应在其后添加field()并设置default_factory或default
|
||||
"""
|
||||
|
||||
ADAPTER_PLATFORM = "qq"
|
||||
|
||||
|
||||
@dataclass
|
||||
class NicknameConfig(ConfigBase):
|
||||
nickname: str
|
||||
"""机器人昵称"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class NapcatServerConfig(ConfigBase):
|
||||
mode: Literal["reverse", "forward"] = "reverse"
|
||||
"""连接模式:reverse=反向连接(作为服务器), forward=正向连接(作为客户端)"""
|
||||
|
||||
host: str = "localhost"
|
||||
"""主机地址"""
|
||||
|
||||
port: int = 8095
|
||||
"""端口号"""
|
||||
|
||||
url: str = ""
|
||||
"""正向连接时的完整WebSocket URL,如 ws://localhost:8080/ws"""
|
||||
|
||||
access_token: str = ""
|
||||
"""WebSocket 连接的访问令牌,用于身份验证"""
|
||||
|
||||
heartbeat_interval: int = 30
|
||||
"""心跳间隔时间,单位为秒"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class MaiBotServerConfig(ConfigBase):
|
||||
platform_name: str = field(default=ADAPTER_PLATFORM, init=False)
|
||||
"""平台名称,“qq”"""
|
||||
|
||||
host: str = "localhost"
|
||||
"""MaiMCore的主机地址"""
|
||||
|
||||
port: int = 8000
|
||||
"""MaiMCore的端口号"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class VoiceConfig(ConfigBase):
|
||||
use_tts: bool = False
|
||||
"""是否启用TTS功能"""
|
||||
|
||||
@dataclass
|
||||
class SlicingConfig(ConfigBase):
|
||||
max_frame_size: int = 64
|
||||
"""WebSocket帧的最大大小,单位为字节,默认64KB"""
|
||||
|
||||
delay_ms: int = 10
|
||||
"""切片发送间隔时间,单位为毫秒"""
|
||||
|
||||
@dataclass
|
||||
class DebugConfig(ConfigBase):
|
||||
level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO"
|
||||
"""日志级别,默认为INFO"""
|
||||
@@ -1,26 +0,0 @@
|
||||
from maim_message import Router, RouteConfig, TargetConfig
|
||||
from .config import global_config
|
||||
from src.common.logger import get_logger
|
||||
from .send_handler import send_handler
|
||||
|
||||
logger = get_logger("napcat_adapter")
|
||||
|
||||
route_config = RouteConfig(
|
||||
route_config={
|
||||
global_config.maibot_server.platform_name: TargetConfig(
|
||||
url=f"ws://{global_config.maibot_server.host}:{global_config.maibot_server.port}/ws",
|
||||
token=None,
|
||||
)
|
||||
}
|
||||
)
|
||||
router = Router(route_config)
|
||||
|
||||
|
||||
async def mmc_start_com():
|
||||
logger.info("正在连接MaiBot")
|
||||
router.register_class_handler(send_handler.handle_message)
|
||||
await router.run()
|
||||
|
||||
|
||||
async def mmc_stop_com():
|
||||
await router.stop()
|
||||
@@ -1,76 +0,0 @@
|
||||
import json
|
||||
import asyncio
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from ..message_chunker import chunker
|
||||
from ..config import global_config
|
||||
|
||||
logger = get_logger("napcat_adapter")
|
||||
from maim_message import MessageBase, Router
|
||||
|
||||
|
||||
class MessageSending:
|
||||
"""
|
||||
负责把消息发送到麦麦
|
||||
"""
|
||||
|
||||
maibot_router: Router = None
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
async def message_send(self, message_base: MessageBase) -> bool:
|
||||
"""
|
||||
发送消息(Ada -> MMC 方向,需要实现切片)
|
||||
Parameters:
|
||||
message_base: MessageBase: 消息基类,包含发送目标和消息内容等信息
|
||||
"""
|
||||
try:
|
||||
# 检查是否需要切片发送
|
||||
message_dict = message_base.to_dict()
|
||||
|
||||
if chunker.should_chunk_message(message_dict):
|
||||
logger.info(f"消息过大,进行切片发送到 MaiBot")
|
||||
|
||||
# 切片消息
|
||||
chunks = chunker.chunk_message(message_dict)
|
||||
|
||||
# 逐个发送切片
|
||||
for i, chunk in enumerate(chunks):
|
||||
logger.debug(f"发送切片 {i+1}/{len(chunks)} 到 MaiBot")
|
||||
|
||||
# 获取对应的客户端并发送切片
|
||||
platform = message_base.message_info.platform
|
||||
if platform not in self.maibot_router.clients:
|
||||
logger.error(f"平台 {platform} 未连接")
|
||||
return False
|
||||
|
||||
client = self.maibot_router.clients[platform]
|
||||
send_status = await client.send_message(chunk)
|
||||
|
||||
if not send_status:
|
||||
logger.error(f"发送切片 {i+1}/{len(chunks)} 失败")
|
||||
return False
|
||||
|
||||
# 使用配置中的延迟时间
|
||||
if i < len(chunks) - 1:
|
||||
delay_seconds = global_config.slicing.delay_ms / 1000.0
|
||||
logger.debug(f"切片发送延迟: {global_config.slicing.delay_ms}毫秒")
|
||||
await asyncio.sleep(delay_seconds)
|
||||
|
||||
logger.debug("所有切片发送完成")
|
||||
return True
|
||||
else:
|
||||
# 直接发送小消息
|
||||
send_status = await self.maibot_router.send_message(message_base)
|
||||
if not send_status:
|
||||
raise RuntimeError("可能是路由未正确配置或连接异常")
|
||||
return send_status
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {str(e)}")
|
||||
logger.error("请检查与MaiBot之间的连接")
|
||||
return False
|
||||
|
||||
|
||||
message_send_instance = MessageSending()
|
||||
59
scripts/update_prompt_imports.py
Normal file
59
scripts/update_prompt_imports.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""
|
||||
更新Prompt类导入脚本
|
||||
将旧的prompt_builder.Prompt导入更新为unified_prompt.Prompt
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
# 需要更新的文件列表
|
||||
files_to_update = [
|
||||
"src/person_info/relationship_fetcher.py",
|
||||
"src/mood/mood_manager.py",
|
||||
"src/mais4u/mais4u_chat/body_emotion_action_manager.py",
|
||||
"src/chat/express/expression_learner.py",
|
||||
"src/chat/planner_actions/planner.py",
|
||||
"src/mais4u/mais4u_chat/s4u_prompt.py",
|
||||
"src/chat/message_receive/bot.py",
|
||||
"src/chat/replyer/default_generator.py",
|
||||
"src/chat/express/expression_selector.py",
|
||||
"src/mais4u/mai_think.py",
|
||||
"src/mais4u/mais4u_chat/s4u_mood_manager.py",
|
||||
"src/plugin_system/core/tool_use.py",
|
||||
"src/chat/memory_system/memory_activator.py",
|
||||
"src/chat/utils/smart_prompt.py"
|
||||
]
|
||||
|
||||
def update_prompt_imports(file_path):
|
||||
"""更新文件中的Prompt导入"""
|
||||
if not os.path.exists(file_path):
|
||||
print(f"文件不存在: {file_path}")
|
||||
return False
|
||||
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
|
||||
# 替换导入语句
|
||||
old_import = "from src.chat.utils.prompt_builder import Prompt, global_prompt_manager"
|
||||
new_import = "from src.chat.utils.prompt import Prompt, global_prompt_manager"
|
||||
|
||||
if old_import in content:
|
||||
new_content = content.replace(old_import, new_import)
|
||||
with open(file_path, 'w', encoding='utf-8') as f:
|
||||
f.write(new_content)
|
||||
print(f"已更新: {file_path}")
|
||||
return True
|
||||
else:
|
||||
print(f"无需更新: {file_path}")
|
||||
return False
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
updated_count = 0
|
||||
for file_path in files_to_update:
|
||||
if update_prompt_imports(file_path):
|
||||
updated_count += 1
|
||||
|
||||
print(f"\n更新完成!共更新了 {updated_count} 个文件")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -3,9 +3,8 @@ import time
|
||||
import traceback
|
||||
import math
|
||||
import random
|
||||
from typing import Optional, Dict, Any, Tuple
|
||||
from typing import Dict, Any, Tuple
|
||||
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
@@ -20,10 +19,14 @@ from .hfc_context import HfcContext
|
||||
from .response_handler import ResponseHandler
|
||||
from .cycle_tracker import CycleTracker
|
||||
|
||||
# 日志记录器
|
||||
logger = get_logger("hfc.processor")
|
||||
|
||||
|
||||
class CycleProcessor:
|
||||
"""
|
||||
循环处理器类,负责处理单次思考循环的逻辑。
|
||||
"""
|
||||
def __init__(self, context: HfcContext, response_handler: ResponseHandler, cycle_tracker: CycleTracker):
|
||||
"""
|
||||
初始化循环处理器
|
||||
@@ -32,7 +35,7 @@ class CycleProcessor:
|
||||
context: HFC聊天上下文对象,包含聊天流、能量值等信息
|
||||
response_handler: 响应处理器,负责生成和发送回复
|
||||
cycle_tracker: 循环跟踪器,负责记录和管理每次思考循环的信息
|
||||
"""
|
||||
"""
|
||||
self.context = context
|
||||
self.response_handler = response_handler
|
||||
self.cycle_tracker = cycle_tracker
|
||||
@@ -52,17 +55,33 @@ class CycleProcessor:
|
||||
thinking_id,
|
||||
actions,
|
||||
) -> Tuple[Dict[str, Any], str, Dict[str, float]]:
|
||||
"""
|
||||
发送并存储回复信息
|
||||
|
||||
Args:
|
||||
response_set: 回复内容集合
|
||||
loop_start_time: 循环开始时间
|
||||
action_message: 动作消息
|
||||
cycle_timers: 循环计时器
|
||||
thinking_id: 思考ID
|
||||
actions: 动作列表
|
||||
|
||||
Returns:
|
||||
Tuple[Dict[str, Any], str, Dict[str, float]]: 循环信息, 回复文本, 循环计时器
|
||||
"""
|
||||
# 发送回复
|
||||
with Timer("回复发送", cycle_timers):
|
||||
reply_text = await self.response_handler.send_response(response_set, loop_start_time, action_message)
|
||||
|
||||
# 存储reply action信息
|
||||
# 存储reply action信息
|
||||
person_info_manager = get_person_info_manager()
|
||||
|
||||
|
||||
# 获取 platform,如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值
|
||||
platform = action_message.get("chat_info_platform")
|
||||
if platform is None:
|
||||
platform = getattr(self.context.chat_stream, "platform", "unknown")
|
||||
|
||||
|
||||
# 获取用户信息并生成回复提示
|
||||
person_id = person_info_manager.get_person_id(
|
||||
platform,
|
||||
action_message.get("user_id", ""),
|
||||
@@ -70,6 +89,7 @@ class CycleProcessor:
|
||||
person_name = await person_info_manager.get_value(person_id, "person_name")
|
||||
action_prompt_display = f"你对{person_name}进行了回复:{reply_text}"
|
||||
|
||||
# 存储动作信息到数据库
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.context.chat_stream,
|
||||
action_build_into_prompt=False,
|
||||
@@ -94,8 +114,8 @@ class CycleProcessor:
|
||||
}
|
||||
|
||||
return loop_info, reply_text, cycle_timers
|
||||
|
||||
async def observe(self,interest_value:float = 0.0) -> bool:
|
||||
|
||||
async def observe(self, interest_value: float = 0.0) -> str:
|
||||
"""
|
||||
观察和处理单次思考循环的核心方法
|
||||
|
||||
@@ -103,7 +123,7 @@ class CycleProcessor:
|
||||
interest_value: 兴趣值
|
||||
|
||||
Returns:
|
||||
bool: 处理是否成功
|
||||
str: 动作类型
|
||||
|
||||
功能说明:
|
||||
- 开始新的思考循环并记录计时
|
||||
@@ -114,11 +134,20 @@ class CycleProcessor:
|
||||
"""
|
||||
action_type = "no_action"
|
||||
reply_text = "" # 初始化reply_text变量,避免UnboundLocalError
|
||||
|
||||
|
||||
# 使用sigmoid函数将interest_value转换为概率
|
||||
# 当interest_value为0时,概率接近0(使用Focus模式)
|
||||
# 当interest_value很高时,概率接近1(使用Normal模式)
|
||||
def calculate_normal_mode_probability(interest_val: float) -> float:
|
||||
"""
|
||||
计算普通模式的概率
|
||||
|
||||
Args:
|
||||
interest_val: 兴趣值
|
||||
|
||||
Returns:
|
||||
float: 概率
|
||||
"""
|
||||
# 使用sigmoid函数,调整参数使概率分布更合理
|
||||
# 当interest_value = 0时,概率约为0.1
|
||||
# 当interest_value = 1时,概率约为0.5
|
||||
@@ -127,22 +156,32 @@ class CycleProcessor:
|
||||
k = 2.0 # 控制曲线陡峭程度
|
||||
x0 = 1.0 # 控制曲线中心点
|
||||
return 1.0 / (1.0 + math.exp(-k * (interest_val - x0)))
|
||||
|
||||
normal_mode_probability = calculate_normal_mode_probability(interest_value) * 0.5 / global_config.chat.get_current_talk_frequency(self.context.stream_id)
|
||||
|
||||
|
||||
# 计算普通模式概率
|
||||
normal_mode_probability = (
|
||||
calculate_normal_mode_probability(interest_value)
|
||||
* 0.5
|
||||
/ global_config.chat.get_current_talk_frequency(self.context.stream_id)
|
||||
)
|
||||
|
||||
# 根据概率决定使用哪种模式
|
||||
if random.random() < normal_mode_probability:
|
||||
mode = ChatMode.NORMAL
|
||||
logger.info(f"{self.log_prefix} 基于兴趣值 {interest_value:.2f},概率 {normal_mode_probability:.2f},选择Normal planner模式")
|
||||
logger.info(
|
||||
f"{self.log_prefix} 基于兴趣值 {interest_value:.2f},概率 {normal_mode_probability:.2f},选择Normal planner模式"
|
||||
)
|
||||
else:
|
||||
mode = ChatMode.FOCUS
|
||||
logger.info(f"{self.log_prefix} 基于兴趣值 {interest_value:.2f},概率 {normal_mode_probability:.2f},选择Focus planner模式")
|
||||
logger.info(
|
||||
f"{self.log_prefix} 基于兴趣值 {interest_value:.2f},概率 {normal_mode_probability:.2f},选择Focus planner模式"
|
||||
)
|
||||
|
||||
# 开始新的思考循环
|
||||
cycle_timers, thinking_id = self.cycle_tracker.start_cycle()
|
||||
logger.info(f"{self.log_prefix} 开始第{self.context.cycle_counter}次思考")
|
||||
|
||||
if ENABLE_S4U:
|
||||
await send_typing()
|
||||
await send_typing(self.context.chat_stream.user_info.user_id)
|
||||
|
||||
loop_start_time = time.time()
|
||||
|
||||
@@ -165,12 +204,16 @@ class CycleProcessor:
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
from src.plugin_system import EventType
|
||||
|
||||
result = await event_manager.trigger_event(EventType.ON_PLAN,plugin_name="SYSTEM", stream_id=self.context.chat_stream)
|
||||
# 触发规划前事件
|
||||
result = await event_manager.trigger_event(
|
||||
EventType.ON_PLAN, plugin_name="SYSTEM", stream_id=self.context.chat_stream
|
||||
)
|
||||
if not result.all_continue_process():
|
||||
raise UserWarning(f"插件{result.get_summary().get('stopped_handlers', '')}于规划前中断了内容生成")
|
||||
|
||||
|
||||
# 规划动作
|
||||
with Timer("规划器", cycle_timers):
|
||||
actions, _= await self.action_planner.plan(
|
||||
actions, _ = await self.action_planner.plan(
|
||||
mode=mode,
|
||||
loop_start_time=loop_start_time,
|
||||
available_actions=available_actions,
|
||||
@@ -179,11 +222,13 @@ class CycleProcessor:
|
||||
async def execute_action(action_info):
|
||||
"""执行单个动作的通用函数"""
|
||||
try:
|
||||
if action_info["action_type"] == "no_action":
|
||||
return {"action_type": "no_action", "success": True, "reply_text": "", "command": ""}
|
||||
if action_info["action_type"] == "no_reply":
|
||||
# 直接处理no_reply逻辑,不再通过动作系统
|
||||
reason = action_info.get("reasoning", "选择不回复")
|
||||
logger.info(f"{self.log_prefix} 选择不回复,原因: {reason}")
|
||||
|
||||
|
||||
# 存储no_reply信息到数据库
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.context.chat_stream,
|
||||
@@ -194,14 +239,9 @@ class CycleProcessor:
|
||||
action_data={"reason": reason},
|
||||
action_name="no_reply",
|
||||
)
|
||||
|
||||
return {
|
||||
"action_type": "no_reply",
|
||||
"success": True,
|
||||
"reply_text": "",
|
||||
"command": ""
|
||||
}
|
||||
elif action_info["action_type"] != "reply":
|
||||
|
||||
return {"action_type": "no_reply", "success": True, "reply_text": "", "command": ""}
|
||||
elif action_info["action_type"] != "reply" and action_info["action_type"] != "no_action":
|
||||
# 执行普通动作
|
||||
with Timer("动作执行", cycle_timers):
|
||||
success, reply_text, command = await self._handle_action(
|
||||
@@ -210,41 +250,35 @@ class CycleProcessor:
|
||||
action_info["action_data"],
|
||||
cycle_timers,
|
||||
thinking_id,
|
||||
action_info["action_message"]
|
||||
action_info["action_message"],
|
||||
)
|
||||
return {
|
||||
"action_type": action_info["action_type"],
|
||||
"success": success,
|
||||
"reply_text": reply_text,
|
||||
"command": command
|
||||
"command": command,
|
||||
}
|
||||
else:
|
||||
# 生成回复
|
||||
try:
|
||||
success, response_set, _ = await generator_api.generate_reply(
|
||||
chat_stream=self.context.chat_stream,
|
||||
reply_message = action_info["action_message"],
|
||||
reply_message=action_info["action_message"],
|
||||
available_actions=available_actions,
|
||||
enable_tool=global_config.tool.enable_tool,
|
||||
request_type="chat.replyer",
|
||||
from_plugin=False,
|
||||
)
|
||||
)
|
||||
if not success or not response_set:
|
||||
logger.info(f"对 {action_info['action_message'].get('processed_plain_text')} 的回复生成失败")
|
||||
return {
|
||||
"action_type": "reply",
|
||||
"success": False,
|
||||
"reply_text": "",
|
||||
"loop_info": None
|
||||
}
|
||||
logger.info(
|
||||
f"对 {action_info['action_message'].get('processed_plain_text')} 的回复生成失败"
|
||||
)
|
||||
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
|
||||
except asyncio.CancelledError:
|
||||
logger.debug(f"{self.log_prefix} 并行执行:回复生成任务已被取消")
|
||||
return {
|
||||
"action_type": "reply",
|
||||
"success": False,
|
||||
"reply_text": "",
|
||||
"loop_info": None
|
||||
}
|
||||
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
|
||||
|
||||
# 发送并存储回复
|
||||
loop_info, reply_text, cycle_timers_reply = await self._send_and_store_reply(
|
||||
response_set,
|
||||
loop_start_time,
|
||||
@@ -253,12 +287,7 @@ class CycleProcessor:
|
||||
thinking_id,
|
||||
actions,
|
||||
)
|
||||
return {
|
||||
"action_type": "reply",
|
||||
"success": True,
|
||||
"reply_text": reply_text,
|
||||
"loop_info": loop_info
|
||||
}
|
||||
return {"action_type": "reply", "success": True, "reply_text": reply_text, "loop_info": loop_info}
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 执行动作时出错: {e}")
|
||||
logger.error(f"{self.log_prefix} 错误信息: {traceback.format_exc()}")
|
||||
@@ -267,9 +296,9 @@ class CycleProcessor:
|
||||
"success": False,
|
||||
"reply_text": "",
|
||||
"loop_info": None,
|
||||
"error": str(e)
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
|
||||
# 创建所有动作的后台任务
|
||||
action_tasks = [asyncio.create_task(execute_action(action)) for action in actions]
|
||||
|
||||
@@ -282,12 +311,12 @@ class CycleProcessor:
|
||||
action_success = False
|
||||
action_reply_text = ""
|
||||
action_command = ""
|
||||
|
||||
|
||||
for i, result in enumerate(results):
|
||||
if isinstance(result, BaseException):
|
||||
logger.error(f"{self.log_prefix} 动作执行异常: {result}")
|
||||
continue
|
||||
|
||||
|
||||
action_info = actions[i]
|
||||
if result["action_type"] != "reply":
|
||||
action_success = result["success"]
|
||||
@@ -327,234 +356,17 @@ class CycleProcessor:
|
||||
},
|
||||
}
|
||||
reply_text = action_reply_text
|
||||
|
||||
|
||||
# 停止正在输入状态
|
||||
if ENABLE_S4U:
|
||||
await stop_typing()
|
||||
|
||||
# 结束循环
|
||||
self.context.chat_instance.cycle_tracker.end_cycle(loop_info, cycle_timers)
|
||||
self.context.chat_instance.cycle_tracker.print_cycle_info(cycle_timers)
|
||||
|
||||
action_type = actions[0]["action_type"] if actions else "no_action"
|
||||
# 管理no_reply计数器:当执行了非no_reply动作时,重置计数器
|
||||
if action_type != "no_reply":
|
||||
# no_reply逻辑已集成到heartFC_chat.py中,直接重置计数器
|
||||
self.context.chat_instance.recent_interest_records.clear()
|
||||
self.context.no_reply_consecutive = 0
|
||||
logger.debug(f"{self.log_prefix} 执行了{action_type}动作,重置no_reply计数器")
|
||||
return True
|
||||
|
||||
if action_type == "no_reply":
|
||||
self.context.no_reply_consecutive += 1
|
||||
self.context.chat_instance._determine_form_type()
|
||||
|
||||
# 在一轮动作执行完毕后,增加睡眠压力
|
||||
if self.context.energy_manager and global_config.sleep_system.enable_insomnia_system:
|
||||
if action_type not in ["no_reply", "no_action"]:
|
||||
self.context.energy_manager.increase_sleep_pressure()
|
||||
|
||||
return True
|
||||
|
||||
async def execute_plan(self, action_result: Dict[str, Any], target_message: Optional[Dict[str, Any]]):
|
||||
"""
|
||||
执行一个已经制定好的计划
|
||||
"""
|
||||
action_type = action_result.get("action_type", "error")
|
||||
|
||||
# 这里我们需要为执行计划创建一个新的循环追踪
|
||||
cycle_timers, thinking_id = self.cycle_tracker.start_cycle(is_proactive=True)
|
||||
loop_start_time = time.time()
|
||||
|
||||
if action_type == "reply":
|
||||
# 主动思考不应该直接触发简单回复,但为了逻辑完整性,我们假设它会调用response_handler
|
||||
# 注意:这里的 available_actions 和 plan_result 是缺失的,需要根据实际情况处理
|
||||
await self._handle_reply_action(
|
||||
target_message, {}, None, loop_start_time, cycle_timers, thinking_id, {"action_result": action_result}
|
||||
)
|
||||
else:
|
||||
await self._handle_other_actions(
|
||||
action_type,
|
||||
action_result.get("reasoning", ""),
|
||||
action_result.get("action_data", {}),
|
||||
action_result.get("is_parallel", False),
|
||||
None,
|
||||
target_message,
|
||||
cycle_timers,
|
||||
thinking_id,
|
||||
{"action_result": action_result},
|
||||
loop_start_time,
|
||||
)
|
||||
|
||||
async def _handle_reply_action(
|
||||
self, message_data, available_actions, gen_task, loop_start_time, cycle_timers, thinking_id, plan_result
|
||||
):
|
||||
"""
|
||||
处理回复类型的动作
|
||||
|
||||
Args:
|
||||
message_data: 消息数据
|
||||
available_actions: 可用动作列表
|
||||
gen_task: 预先创建的生成任务(可能为None)
|
||||
loop_start_time: 循环开始时间
|
||||
cycle_timers: 循环计时器
|
||||
thinking_id: 思考ID
|
||||
plan_result: 规划结果
|
||||
|
||||
功能说明:
|
||||
- 根据聊天模式决定是否使用预生成的回复或实时生成
|
||||
- 在NORMAL模式下使用异步生成提高效率
|
||||
- 在FOCUS模式下同步生成确保及时响应
|
||||
- 发送生成的回复并结束循环
|
||||
"""
|
||||
# 初始化reply_to_str以避免UnboundLocalError
|
||||
reply_to_str = None
|
||||
|
||||
if self.context.loop_mode == ChatMode.NORMAL:
|
||||
if not gen_task:
|
||||
reply_to_str = await self._build_reply_to_str(message_data)
|
||||
gen_task = asyncio.create_task(
|
||||
self.response_handler.generate_response(
|
||||
message_data=message_data,
|
||||
available_actions=available_actions,
|
||||
reply_to=reply_to_str,
|
||||
request_type="chat.replyer.normal",
|
||||
)
|
||||
)
|
||||
else:
|
||||
# 如果gen_task已存在但reply_to_str还未构建,需要构建它
|
||||
if reply_to_str is None:
|
||||
reply_to_str = await self._build_reply_to_str(message_data)
|
||||
|
||||
try:
|
||||
response_set = await asyncio.wait_for(gen_task, timeout=global_config.chat.thinking_timeout)
|
||||
except asyncio.TimeoutError:
|
||||
response_set = None
|
||||
else:
|
||||
reply_to_str = await self._build_reply_to_str(message_data)
|
||||
response_set = await self.response_handler.generate_response(
|
||||
message_data=message_data,
|
||||
available_actions=available_actions,
|
||||
reply_to=reply_to_str,
|
||||
request_type="chat.replyer.focus",
|
||||
)
|
||||
|
||||
if response_set:
|
||||
loop_info, _, _ = await self.response_handler.generate_and_send_reply(
|
||||
response_set, reply_to_str, loop_start_time, message_data, cycle_timers, thinking_id, plan_result
|
||||
)
|
||||
self.cycle_tracker.end_cycle(loop_info, cycle_timers)
|
||||
|
||||
async def _handle_other_actions(
|
||||
self,
|
||||
action_type,
|
||||
reasoning,
|
||||
action_data,
|
||||
is_parallel,
|
||||
gen_task,
|
||||
action_message,
|
||||
cycle_timers,
|
||||
thinking_id,
|
||||
plan_result,
|
||||
loop_start_time,
|
||||
):
|
||||
"""
|
||||
处理非回复类型的动作(如no_reply、自定义动作等)
|
||||
|
||||
Args:
|
||||
action_type: 动作类型
|
||||
reasoning: 动作理由
|
||||
action_data: 动作数据
|
||||
is_parallel: 是否并行执行
|
||||
gen_task: 生成任务
|
||||
action_message: 动作消息
|
||||
cycle_timers: 循环计时器
|
||||
thinking_id: 思考ID
|
||||
plan_result: 规划结果
|
||||
loop_start_time: 循环开始时间
|
||||
|
||||
功能说明:
|
||||
- 在NORMAL模式下可能并行执行回复生成和动作处理
|
||||
- 等待所有异步任务完成
|
||||
- 整合回复和动作的执行结果
|
||||
- 构建最终循环信息并结束循环
|
||||
"""
|
||||
background_reply_task = None
|
||||
if self.context.loop_mode == ChatMode.NORMAL and is_parallel and gen_task:
|
||||
background_reply_task = asyncio.create_task(
|
||||
self._handle_parallel_reply(
|
||||
gen_task, loop_start_time, action_message, cycle_timers, thinking_id, plan_result
|
||||
)
|
||||
)
|
||||
|
||||
background_action_task = asyncio.create_task(
|
||||
self._handle_action(action_type, reasoning, action_data, cycle_timers, thinking_id, action_message)
|
||||
)
|
||||
|
||||
reply_loop_info, action_success, action_reply_text, action_command = None, False, "", ""
|
||||
|
||||
if background_reply_task:
|
||||
results = await asyncio.gather(background_reply_task, background_action_task, return_exceptions=True)
|
||||
reply_result, action_result_val = results
|
||||
if not isinstance(reply_result, BaseException) and reply_result is not None:
|
||||
reply_loop_info, _, _ = reply_result
|
||||
else:
|
||||
reply_loop_info = None
|
||||
|
||||
if not isinstance(action_result_val, BaseException) and action_result_val is not None:
|
||||
action_success, action_reply_text, action_command = action_result_val
|
||||
else:
|
||||
action_success, action_reply_text, action_command = False, "", ""
|
||||
else:
|
||||
results = await asyncio.gather(background_action_task, return_exceptions=True)
|
||||
if results and len(results) > 0:
|
||||
action_result_val = results[0] # Get the actual result from the tuple
|
||||
else:
|
||||
action_result_val = (False, "", "")
|
||||
|
||||
if not isinstance(action_result_val, BaseException) and action_result_val is not None:
|
||||
action_success, action_reply_text, action_command = action_result_val
|
||||
else:
|
||||
action_success, action_reply_text, action_command = False, "", ""
|
||||
|
||||
loop_info = self._build_final_loop_info(
|
||||
reply_loop_info, action_success, action_reply_text, action_command, plan_result
|
||||
)
|
||||
self.cycle_tracker.end_cycle(loop_info, cycle_timers)
|
||||
|
||||
async def _handle_parallel_reply(
|
||||
self, gen_task, loop_start_time, action_message, cycle_timers, thinking_id, plan_result
|
||||
):
|
||||
"""
|
||||
处理并行回复生成
|
||||
|
||||
Args:
|
||||
gen_task: 回复生成任务
|
||||
loop_start_time: 循环开始时间
|
||||
action_message: 动作消息
|
||||
cycle_timers: 循环计时器
|
||||
thinking_id: 思考ID
|
||||
plan_result: 规划结果
|
||||
|
||||
Returns:
|
||||
tuple: (循环信息, 回复文本, 计时器信息) 或 None
|
||||
|
||||
功能说明:
|
||||
- 等待并行回复生成任务完成(带超时)
|
||||
- 构建回复目标字符串
|
||||
- 发送生成的回复
|
||||
- 返回循环信息供上级方法使用
|
||||
"""
|
||||
try:
|
||||
response_set = await asyncio.wait_for(gen_task, timeout=global_config.chat.thinking_timeout)
|
||||
except asyncio.TimeoutError:
|
||||
return None, "", {}
|
||||
|
||||
if not response_set:
|
||||
return None, "", {}
|
||||
|
||||
reply_to_str = await self._build_reply_to_str(action_message)
|
||||
return await self.response_handler.generate_and_send_reply(
|
||||
response_set, reply_to_str, loop_start_time, action_message, cycle_timers, thinking_id, plan_result
|
||||
)
|
||||
return action_type
|
||||
|
||||
async def _handle_action(
|
||||
self, action, reasoning, action_data, cycle_timers, thinking_id, action_message
|
||||
@@ -581,6 +393,7 @@ class CycleProcessor:
|
||||
if not self.context.chat_stream:
|
||||
return False, "", ""
|
||||
try:
|
||||
# 创建动作处理器
|
||||
action_handler = self.context.action_manager.create_action(
|
||||
action_name=action,
|
||||
action_data=action_data,
|
||||
@@ -608,7 +421,7 @@ class CycleProcessor:
|
||||
if fallback_action and fallback_action != action:
|
||||
logger.info(f"{self.context.log_prefix} 使用回退动作: {fallback_action}")
|
||||
action_handler = self.context.action_manager.create_action(
|
||||
action_name=fallback_action,
|
||||
action_name=fallback_action if isinstance(fallback_action, list) else fallback_action,
|
||||
action_data=action_data,
|
||||
reasoning=f"原动作'{action}'不可用,自动回退。{reasoning}",
|
||||
cycle_timers=cycle_timers,
|
||||
@@ -622,74 +435,10 @@ class CycleProcessor:
|
||||
logger.error(f"{self.context.log_prefix} 回退方案也失败,无法创建任何动作处理器")
|
||||
return False, "", ""
|
||||
|
||||
# 执行动作
|
||||
success, reply_text = await action_handler.handle_action()
|
||||
return success, reply_text, ""
|
||||
except Exception as e:
|
||||
logger.error(f"{self.context.log_prefix} 处理{action}时出错: {e}")
|
||||
traceback.print_exc()
|
||||
return False, "", ""
|
||||
|
||||
def _get_direct_reply_plan(self, loop_start_time):
|
||||
"""
|
||||
获取直接回复的规划结果
|
||||
|
||||
Args:
|
||||
loop_start_time: 循环开始时间
|
||||
|
||||
Returns:
|
||||
dict: 包含直接回复动作的规划结果
|
||||
|
||||
功能说明:
|
||||
- 在某些情况下跳过复杂规划,直接返回回复动作
|
||||
- 主要用于NORMAL模式下没有其他可用动作时的简化处理
|
||||
"""
|
||||
return {
|
||||
"action_result": {
|
||||
"action_type": "reply",
|
||||
"action_data": {"loop_start_time": loop_start_time},
|
||||
"reasoning": "",
|
||||
"timestamp": time.time(),
|
||||
"is_parallel": False,
|
||||
},
|
||||
"action_prompt": "",
|
||||
}
|
||||
|
||||
def _build_final_loop_info(self, reply_loop_info, action_success, action_reply_text, action_command, plan_result):
|
||||
"""
|
||||
构建最终的循环信息
|
||||
|
||||
Args:
|
||||
reply_loop_info: 回复循环信息(可能为None)
|
||||
action_success: 动作执行是否成功
|
||||
action_reply_text: 动作回复文本
|
||||
action_command: 动作命令
|
||||
plan_result: 规划结果
|
||||
|
||||
Returns:
|
||||
dict: 完整的循环信息,包含规划信息和动作信息
|
||||
|
||||
功能说明:
|
||||
- 如果有回复循环信息,则在其基础上添加动作信息
|
||||
- 如果没有回复信息,则创建新的循环信息结构
|
||||
- 整合所有执行结果供循环跟踪器记录
|
||||
"""
|
||||
if reply_loop_info:
|
||||
loop_info = reply_loop_info
|
||||
loop_info["loop_action_info"].update(
|
||||
{
|
||||
"action_taken": action_success,
|
||||
"command": action_command,
|
||||
"taken_time": time.time(),
|
||||
}
|
||||
)
|
||||
else:
|
||||
loop_info = {
|
||||
"loop_plan_info": {"action_result": plan_result.get("action_result", {})},
|
||||
"loop_action_info": {
|
||||
"action_taken": action_success,
|
||||
"reply_text": action_reply_text,
|
||||
"command": action_command,
|
||||
"taken_time": time.time(),
|
||||
},
|
||||
}
|
||||
return loop_info
|
||||
|
||||
@@ -91,25 +91,24 @@ class CycleTracker:
|
||||
|
||||
# 获取动作类型,兼容新旧格式
|
||||
action_type = "未知动作"
|
||||
if hasattr(self, '_current_cycle_detail') and self._current_cycle_detail:
|
||||
if hasattr(self, "_current_cycle_detail") and self._current_cycle_detail:
|
||||
loop_plan_info = self._current_cycle_detail.loop_plan_info
|
||||
if isinstance(loop_plan_info, dict):
|
||||
action_result = loop_plan_info.get('action_result', {})
|
||||
action_result = loop_plan_info.get("action_result", {})
|
||||
if isinstance(action_result, dict):
|
||||
# 旧格式:action_result是字典
|
||||
action_type = action_result.get('action_type', '未知动作')
|
||||
action_type = action_result.get("action_type", "未知动作")
|
||||
elif isinstance(action_result, list) and action_result:
|
||||
# 新格式:action_result是actions列表
|
||||
action_type = action_result[0].get('action_type', '未知动作')
|
||||
action_type = action_result[0].get("action_type", "未知动作")
|
||||
elif isinstance(loop_plan_info, list) and loop_plan_info:
|
||||
# 直接是actions列表的情况
|
||||
action_type = loop_plan_info[0].get('action_type', '未知动作')
|
||||
action_type = loop_plan_info[0].get("action_type", "未知动作")
|
||||
|
||||
if self.context.current_cycle_detail.end_time and self.context.current_cycle_detail.start_time:
|
||||
duration = self.context.current_cycle_detail.end_time - self.context.current_cycle_detail.start_time
|
||||
logger.info(
|
||||
f"{self.context.log_prefix} 第{self.context.current_cycle_detail.cycle_id}次思考,"
|
||||
f"耗时: {duration:.1f}秒, "
|
||||
f"选择动作: {action_type}"
|
||||
+ (f"\n详情: {'; '.join(timer_strings)}" if timer_strings else "")
|
||||
f"选择动作: {action_type}" + (f"\n详情: {'; '.join(timer_strings)}" if timer_strings else "")
|
||||
)
|
||||
|
||||
@@ -3,10 +3,8 @@ import time
|
||||
from typing import Optional
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.plugin_system.base.component_types import ChatMode
|
||||
from .hfc_context import HfcContext
|
||||
from src.schedule.schedule_manager import schedule_manager
|
||||
|
||||
from src.chat.chat_loop.sleep_manager import sleep_manager
|
||||
logger = get_logger("hfc")
|
||||
|
||||
|
||||
@@ -75,7 +73,7 @@ class EnergyManager:
|
||||
continue
|
||||
|
||||
# 判断当前是否为睡眠时间
|
||||
is_sleeping = schedule_manager.is_sleeping()
|
||||
is_sleeping = sleep_manager.SleepManager().is_sleeping()
|
||||
|
||||
if is_sleeping:
|
||||
# 睡眠中:减少睡眠压力
|
||||
|
||||
@@ -2,24 +2,24 @@ import asyncio
|
||||
import time
|
||||
import traceback
|
||||
import random
|
||||
from typing import Optional, List, Dict, Any, Tuple
|
||||
from typing import Optional, List, Dict, Any
|
||||
from collections import deque
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.person_info.relationship_builder_manager import relationship_builder_manager
|
||||
from src.chat.express.expression_learner import expression_learner_manager
|
||||
from src.plugin_system.base.component_types import ChatMode
|
||||
from src.schedule.schedule_manager import schedule_manager, SleepState
|
||||
from src.chat.chat_loop.sleep_manager.sleep_manager import SleepManager, SleepState
|
||||
from src.plugin_system.apis import message_api
|
||||
|
||||
from .hfc_context import HfcContext
|
||||
from .energy_manager import EnergyManager
|
||||
from .proactive_thinker import ProactiveThinker
|
||||
from .proactive.proactive_thinker import ProactiveThinker
|
||||
from .cycle_processor import CycleProcessor
|
||||
from .response_handler import ResponseHandler
|
||||
from .cycle_tracker import CycleTracker
|
||||
from .wakeup_manager import WakeUpManager
|
||||
from .sleep_manager.wakeup_manager import WakeUpManager
|
||||
from .proactive.events import ProactiveTriggerEvent
|
||||
|
||||
logger = get_logger("hfc")
|
||||
|
||||
@@ -46,15 +46,18 @@ class HeartFChatting:
|
||||
self.energy_manager = EnergyManager(self.context)
|
||||
self.proactive_thinker = ProactiveThinker(self.context, self.cycle_processor)
|
||||
self.wakeup_manager = WakeUpManager(self.context)
|
||||
self.sleep_manager = SleepManager()
|
||||
|
||||
# 将唤醒度管理器设置到上下文中
|
||||
self.context.wakeup_manager = self.wakeup_manager
|
||||
self.context.energy_manager = self.energy_manager
|
||||
self.context.sleep_manager = self.sleep_manager
|
||||
# 将HeartFChatting实例设置到上下文中,以便其他组件可以调用其方法
|
||||
self.context.chat_instance = self
|
||||
|
||||
self._loop_task: Optional[asyncio.Task] = None
|
||||
|
||||
self._proactive_monitor_task: Optional[asyncio.Task] = None
|
||||
|
||||
# 记录最近3次的兴趣度
|
||||
self.recent_interest_records: deque = deque(maxlen=3)
|
||||
self._initialize_chat_mode()
|
||||
@@ -93,8 +96,12 @@ class HeartFChatting:
|
||||
self.context.relationship_builder = relationship_builder_manager.get_or_create_builder(self.context.stream_id)
|
||||
self.context.expression_learner = expression_learner_manager.get_expression_learner(self.context.stream_id)
|
||||
|
||||
#await self.energy_manager.start()
|
||||
await self.proactive_thinker.start()
|
||||
# 启动主动思考监视器
|
||||
if global_config.chat.enable_proactive_thinking:
|
||||
self._proactive_monitor_task = asyncio.create_task(self._proactive_monitor_loop())
|
||||
self._proactive_monitor_task.add_done_callback(self._handle_proactive_monitor_completion)
|
||||
logger.info(f"{self.context.log_prefix} 主动思考监视器已启动")
|
||||
|
||||
await self.wakeup_manager.start()
|
||||
|
||||
self._loop_task = asyncio.create_task(self._main_chat_loop())
|
||||
@@ -116,8 +123,12 @@ class HeartFChatting:
|
||||
return
|
||||
self.context.running = False
|
||||
|
||||
#await self.energy_manager.stop()
|
||||
await self.proactive_thinker.stop()
|
||||
# 停止主动思考监视器
|
||||
if self._proactive_monitor_task and not self._proactive_monitor_task.done():
|
||||
self._proactive_monitor_task.cancel()
|
||||
await asyncio.sleep(0)
|
||||
logger.info(f"{self.context.log_prefix} 主动思考监视器已停止")
|
||||
|
||||
await self.wakeup_manager.stop()
|
||||
|
||||
if self._loop_task and not self._loop_task.done():
|
||||
@@ -147,6 +158,151 @@ class HeartFChatting:
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"{self.context.log_prefix} HeartFChatting: 结束了聊天")
|
||||
|
||||
def _handle_proactive_monitor_completion(self, task: asyncio.Task):
|
||||
"""
|
||||
处理主动思考监视器任务完成
|
||||
|
||||
Args:
|
||||
task: 完成的异步任务对象
|
||||
|
||||
功能说明:
|
||||
- 处理任务异常完成的情况
|
||||
- 记录任务正常结束或被取消的日志
|
||||
"""
|
||||
try:
|
||||
if exception := task.exception():
|
||||
logger.error(f"{self.context.log_prefix} 主动思考监视器异常: {exception}")
|
||||
else:
|
||||
logger.info(f"{self.context.log_prefix} 主动思考监视器正常结束")
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"{self.context.log_prefix} 主动思考监视器被取消")
|
||||
|
||||
async def _proactive_monitor_loop(self):
|
||||
"""
|
||||
主动思考监视器循环
|
||||
|
||||
功能说明:
|
||||
- 定期检查是否需要进行主动思考
|
||||
- 计算聊天沉默时间,并与动态思考间隔比较
|
||||
- 当沉默时间超过阈值时,触发主动思考
|
||||
- 处理思考过程中的异常
|
||||
"""
|
||||
while self.context.running:
|
||||
await asyncio.sleep(15)
|
||||
|
||||
if not self._should_enable_proactive_thinking():
|
||||
continue
|
||||
|
||||
current_time = time.time()
|
||||
silence_duration = current_time - self.context.last_message_time
|
||||
target_interval = self._get_dynamic_thinking_interval()
|
||||
|
||||
if silence_duration >= target_interval:
|
||||
try:
|
||||
formatted_time = self._format_duration(silence_duration)
|
||||
event = ProactiveTriggerEvent(
|
||||
source="silence_monitor",
|
||||
reason=f"聊天已沉默 {formatted_time}",
|
||||
metadata={"silence_duration": silence_duration},
|
||||
)
|
||||
await self.proactive_thinker.think(event)
|
||||
self.context.last_message_time = current_time
|
||||
except Exception as e:
|
||||
logger.error(f"{self.context.log_prefix} 主动思考触发执行出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
def _should_enable_proactive_thinking(self) -> bool:
|
||||
"""
|
||||
判断是否应启用主动思考
|
||||
|
||||
Returns:
|
||||
bool: 如果应启用主动思考则返回True,否则返回False
|
||||
|
||||
功能说明:
|
||||
- 检查全局配置和特定聊天设置
|
||||
- 支持按群聊和私聊分别配置
|
||||
- 支持白名单模式,只在特定聊天中启用
|
||||
"""
|
||||
if not self.context.chat_stream:
|
||||
return False
|
||||
|
||||
is_group_chat = self.context.chat_stream.group_info is not None
|
||||
|
||||
if is_group_chat and not global_config.chat.proactive_thinking_in_group:
|
||||
return False
|
||||
if not is_group_chat and not global_config.chat.proactive_thinking_in_private:
|
||||
return False
|
||||
|
||||
stream_parts = self.context.stream_id.split(":")
|
||||
current_chat_identifier = f"{stream_parts}:{stream_parts}" if len(stream_parts) >= 2 else self.context.stream_id
|
||||
|
||||
enable_list = getattr(
|
||||
global_config.chat,
|
||||
"proactive_thinking_enable_in_groups" if is_group_chat else "proactive_thinking_enable_in_private",
|
||||
[],
|
||||
)
|
||||
return not enable_list or current_chat_identifier in enable_list
|
||||
|
||||
def _get_dynamic_thinking_interval(self) -> float:
|
||||
"""
|
||||
获取动态思考间隔时间
|
||||
|
||||
Returns:
|
||||
float: 思考间隔秒数
|
||||
|
||||
功能说明:
|
||||
- 尝试从timing_utils导入正态分布间隔函数
|
||||
- 根据配置计算动态间隔,增加随机性
|
||||
- 在无法导入或计算出错时,回退到固定的间隔
|
||||
"""
|
||||
try:
|
||||
from src.utils.timing_utils import get_normal_distributed_interval
|
||||
|
||||
base_interval = global_config.chat.proactive_thinking_interval
|
||||
delta_sigma = getattr(global_config.chat, "delta_sigma", 120)
|
||||
|
||||
if base_interval <= 0:
|
||||
base_interval = abs(base_interval)
|
||||
if delta_sigma < 0:
|
||||
delta_sigma = abs(delta_sigma)
|
||||
|
||||
if base_interval == 0 and delta_sigma == 0:
|
||||
return 300
|
||||
if delta_sigma == 0:
|
||||
return base_interval
|
||||
|
||||
sigma_percentage = delta_sigma / base_interval if base_interval > 0 else delta_sigma / 1000
|
||||
return get_normal_distributed_interval(base_interval, sigma_percentage, 1, 86400, use_3sigma_rule=True)
|
||||
|
||||
except ImportError:
|
||||
logger.warning(f"{self.context.log_prefix} timing_utils不可用,使用固定间隔")
|
||||
return max(300, abs(global_config.chat.proactive_thinking_interval))
|
||||
except Exception as e:
|
||||
logger.error(f"{self.context.log_prefix} 动态间隔计算出错: {e},使用固定间隔")
|
||||
return max(300, abs(global_config.chat.proactive_thinking_interval))
|
||||
|
||||
def _format_duration(self, seconds: float) -> str:
|
||||
"""
|
||||
格式化时长为可读字符串
|
||||
|
||||
Args:
|
||||
seconds: 时长秒数
|
||||
|
||||
Returns:
|
||||
str: 格式化后的字符串 (例如 "1小时2分3秒")
|
||||
"""
|
||||
hours = int(seconds // 3600)
|
||||
minutes = int((seconds % 3600) // 60)
|
||||
secs = int(seconds % 60)
|
||||
parts = []
|
||||
if hours > 0:
|
||||
parts.append(f"{hours}小时")
|
||||
if minutes > 0:
|
||||
parts.append(f"{minutes}分")
|
||||
if secs > 0 or not parts:
|
||||
parts.append(f"{secs}秒")
|
||||
return "".join(parts)
|
||||
|
||||
async def _main_chat_loop(self):
|
||||
"""
|
||||
主聊天循环
|
||||
@@ -197,8 +353,8 @@ class HeartFChatting:
|
||||
- NORMAL模式:检查进入FOCUS模式的条件,并通过normal_mode_handler处理消息
|
||||
"""
|
||||
# --- 核心状态更新 ---
|
||||
await schedule_manager.update_sleep_state(self.wakeup_manager)
|
||||
current_sleep_state = schedule_manager.get_current_sleep_state()
|
||||
await self.sleep_manager.update_sleep_state(self.wakeup_manager)
|
||||
current_sleep_state = self.sleep_manager.get_current_sleep_state()
|
||||
is_sleeping = current_sleep_state == SleepState.SLEEPING
|
||||
is_in_insomnia = current_sleep_state == SleepState.INSOMNIA
|
||||
|
||||
@@ -228,7 +384,7 @@ class HeartFChatting:
|
||||
self._handle_wakeup_messages(recent_messages)
|
||||
|
||||
# 再次获取最新状态,因为 handle_wakeup 可能导致状态变为 WOKEN_UP
|
||||
current_sleep_state = schedule_manager.get_current_sleep_state()
|
||||
current_sleep_state = self.sleep_manager.get_current_sleep_state()
|
||||
|
||||
if current_sleep_state == SleepState.SLEEPING:
|
||||
# 只有在纯粹的 SLEEPING 状态下才跳过消息处理
|
||||
@@ -238,113 +394,56 @@ class HeartFChatting:
|
||||
logger.info(f"{self.context.log_prefix} 从睡眠中被唤醒,将处理积压的消息。")
|
||||
|
||||
# 根据聊天模式处理新消息
|
||||
# 统一使用 _should_process_messages 判断是否应该处理
|
||||
should_process,interest_value = await self._should_process_messages(recent_messages if has_new_messages else None)
|
||||
if should_process:
|
||||
self.context.last_read_time = time.time()
|
||||
await self.cycle_processor.observe(interest_value = interest_value)
|
||||
else:
|
||||
# Normal模式:消息数量不足,等待
|
||||
should_process, interest_value = await self._should_process_messages(recent_messages)
|
||||
if not should_process:
|
||||
# 消息数量不足或兴趣不够,等待
|
||||
await asyncio.sleep(0.5)
|
||||
return True
|
||||
|
||||
if not await self._should_process_messages(recent_messages if has_new_messages else None):
|
||||
return has_new_messages
|
||||
|
||||
# 处理新消息
|
||||
for message in recent_messages:
|
||||
await self.cycle_processor.observe(interest_value = interest_value)
|
||||
|
||||
return True # Skip rest of the logic for this iteration
|
||||
|
||||
# Messages should be processed
|
||||
action_type = await self.cycle_processor.observe(interest_value=interest_value)
|
||||
|
||||
# 管理no_reply计数器
|
||||
if action_type != "no_reply":
|
||||
self.recent_interest_records.clear()
|
||||
self.context.no_reply_consecutive = 0
|
||||
logger.debug(f"{self.context.log_prefix} 执行了{action_type}动作,重置no_reply计数器")
|
||||
else: # action_type == "no_reply"
|
||||
self.context.no_reply_consecutive += 1
|
||||
self._determine_form_type()
|
||||
|
||||
# 在一轮动作执行完毕后,增加睡眠压力
|
||||
if self.context.energy_manager and global_config.sleep_system.enable_insomnia_system:
|
||||
if action_type not in ["no_reply", "no_action"]:
|
||||
self.context.energy_manager.increase_sleep_pressure()
|
||||
|
||||
# 如果成功观察,增加能量值并重置累积兴趣值
|
||||
if has_new_messages:
|
||||
self.context.energy_value += 1 / global_config.chat.focus_value
|
||||
# 重置累积兴趣值,因为消息已经被成功处理
|
||||
self.context.breaking_accumulated_interest = 0.0
|
||||
logger.info(f"{self.context.log_prefix} 能量值增加,当前能量值:{self.context.energy_value:.1f},重置累积兴趣值")
|
||||
|
||||
self._check_focus_exit()
|
||||
|
||||
else:
|
||||
# 无新消息时,只进行模式检查,不进行思考循环
|
||||
self._check_focus_exit()
|
||||
|
||||
self.context.energy_value += 1 / global_config.chat.focus_value
|
||||
# 重置累积兴趣值,因为消息已经被成功处理
|
||||
self.context.breaking_accumulated_interest = 0.0
|
||||
logger.info(
|
||||
f"{self.context.log_prefix} 能量值增加,当前能量值:{self.context.energy_value:.1f},重置累积兴趣值"
|
||||
)
|
||||
|
||||
# 更新上一帧的睡眠状态
|
||||
self.context.was_sleeping = is_sleeping
|
||||
|
||||
# --- 重新入睡逻辑 ---
|
||||
# 如果被吵醒了,并且在一定时间内没有新消息,则尝试重新入睡
|
||||
if schedule_manager.get_current_sleep_state() == SleepState.WOKEN_UP and not has_new_messages:
|
||||
if self.sleep_manager.get_current_sleep_state() == SleepState.WOKEN_UP and not has_new_messages:
|
||||
re_sleep_delay = global_config.sleep_system.re_sleep_delay_minutes * 60
|
||||
# 使用 last_message_time 来判断空闲时间
|
||||
if time.time() - self.context.last_message_time > re_sleep_delay:
|
||||
logger.info(
|
||||
f"{self.context.log_prefix} 已被唤醒且超过 {re_sleep_delay / 60} 分钟无新消息,尝试重新入睡。"
|
||||
)
|
||||
schedule_manager.reset_sleep_state_after_wakeup()
|
||||
self.sleep_manager.reset_sleep_state_after_wakeup()
|
||||
|
||||
# 保存HFC上下文状态
|
||||
self.context.save_context_state()
|
||||
|
||||
return has_new_messages
|
||||
|
||||
def _check_focus_exit(self):
|
||||
"""
|
||||
检查是否应该退出FOCUS模式
|
||||
|
||||
功能说明:
|
||||
- 区分私聊和群聊环境
|
||||
- 在强制私聊focus模式下,能量值低于1时重置为5但不退出
|
||||
- 在群聊focus模式下,如果配置为focus则不退出
|
||||
- 其他情况下,能量值低于1时退出到NORMAL模式
|
||||
"""
|
||||
is_private_chat = self.context.chat_stream.group_info is None if self.context.chat_stream else False
|
||||
is_group_chat = not is_private_chat
|
||||
|
||||
if global_config.chat.force_focus_private and is_private_chat:
|
||||
if self.context.energy_value <= 1:
|
||||
self.context.energy_value = 5
|
||||
return
|
||||
|
||||
if is_group_chat and global_config.chat.group_chat_mode == "focus":
|
||||
return
|
||||
|
||||
if self.context.energy_value <= 1: # 如果能量值小于等于1(非强制情况)
|
||||
self.context.energy_value = 1 # 将能量值设置为1
|
||||
|
||||
def _check_focus_entry(self, new_message_count: int):
|
||||
"""
|
||||
检查是否应该进入FOCUS模式
|
||||
|
||||
Args:
|
||||
new_message_count: 新消息数量
|
||||
|
||||
功能说明:
|
||||
- 区分私聊和群聊环境
|
||||
- 强制私聊focus模式:直接进入FOCUS模式并设置能量值为10
|
||||
- 群聊normal模式:不进入FOCUS模式
|
||||
- 根据focus_value配置和消息数量决定是否进入FOCUS模式
|
||||
- 当消息数量超过阈值或能量值达到30时进入FOCUS模式
|
||||
"""
|
||||
is_private_chat = self.context.chat_stream.group_info is None if self.context.chat_stream else False
|
||||
is_group_chat = not is_private_chat
|
||||
|
||||
if global_config.chat.force_focus_private and is_private_chat:
|
||||
self.context.energy_value = 10
|
||||
return
|
||||
|
||||
if is_group_chat and global_config.chat.group_chat_mode == "normal":
|
||||
return
|
||||
|
||||
if global_config.chat.focus_value != 0: # 如果专注值配置不为0(启用自动专注)
|
||||
if new_message_count > 3 / pow(
|
||||
global_config.chat.focus_value, 0.5
|
||||
): # 如果新消息数超过阈值(基于专注值计算)
|
||||
self.context.energy_value = (
|
||||
10 + (new_message_count / (3 / pow(global_config.chat.focus_value, 0.5))) * 10
|
||||
) # 根据消息数量计算能量值
|
||||
return # 返回,不再检查其他条件
|
||||
|
||||
def _handle_wakeup_messages(self, messages):
|
||||
"""
|
||||
处理休眠状态下的消息,累积唤醒度
|
||||
@@ -382,68 +481,84 @@ class HeartFChatting:
|
||||
|
||||
def _determine_form_type(self) -> str:
|
||||
"""判断使用哪种形式的no_reply"""
|
||||
# 检查是否启用breaking模式
|
||||
if not getattr(global_config.chat, "enable_breaking_mode", False):
|
||||
logger.info(f"{self.context.log_prefix} breaking模式已禁用,使用waiting形式")
|
||||
self.context.focus_energy = 1
|
||||
return "waiting"
|
||||
|
||||
# 如果连续no_reply次数少于3次,使用waiting形式
|
||||
if self.context.no_reply_consecutive <= 3:
|
||||
self.context.focus_energy = 1
|
||||
return "waiting"
|
||||
else:
|
||||
# 使用累积兴趣值而不是最近3次的记录
|
||||
total_interest = self.context.breaking_accumulated_interest
|
||||
|
||||
|
||||
# 计算调整后的阈值
|
||||
adjusted_threshold = 1 / global_config.chat.get_current_talk_frequency(self.context.stream_id)
|
||||
|
||||
logger.info(f"{self.context.log_prefix} 累积兴趣值: {total_interest:.2f}, 调整后阈值: {adjusted_threshold:.2f}")
|
||||
|
||||
|
||||
logger.info(
|
||||
f"{self.context.log_prefix} 累积兴趣值: {total_interest:.2f}, 调整后阈值: {adjusted_threshold:.2f}"
|
||||
)
|
||||
|
||||
# 如果累积兴趣值小于阈值,进入breaking形式
|
||||
if total_interest < adjusted_threshold:
|
||||
logger.info(f"{self.context.log_prefix} 累积兴趣度不足,进入breaking形式")
|
||||
self.context.focus_energy = random.randint(3, 6)
|
||||
return "breaking"
|
||||
else:
|
||||
logger.info(f"{self.context.log_prefix} 累积兴趣度充足,使用waiting形式")
|
||||
self.context.focus_energy = 1
|
||||
return "waiting"
|
||||
|
||||
async def _should_process_messages(self, new_message: List[Dict[str, Any]]) -> tuple[bool,float]:
|
||||
async def _should_process_messages(self, new_message: List[Dict[str, Any]]) -> tuple[bool, float]:
|
||||
"""
|
||||
统一判断是否应该处理消息的函数
|
||||
根据当前循环模式和消息内容决定是否继续处理
|
||||
"""
|
||||
if not new_message:
|
||||
return False, 0.0
|
||||
|
||||
new_message_count = len(new_message)
|
||||
|
||||
talk_frequency = global_config.chat.get_current_talk_frequency(self.context.chat_stream.stream_id)
|
||||
talk_frequency = global_config.chat.get_current_talk_frequency(self.context.stream_id)
|
||||
|
||||
modified_exit_count_threshold = self.context.focus_energy * 0.5 / talk_frequency
|
||||
modified_exit_interest_threshold = 1.5 / talk_frequency
|
||||
|
||||
|
||||
# 计算当前批次消息的兴趣值
|
||||
batch_interest = 0.0
|
||||
for msg_dict in new_message:
|
||||
interest_value = msg_dict.get("interest_value", 0.0)
|
||||
if msg_dict.get("processed_plain_text", ""):
|
||||
batch_interest += interest_value
|
||||
|
||||
|
||||
# 在breaking形式下累积所有消息的兴趣值
|
||||
if new_message_count > 0:
|
||||
self.context.breaking_accumulated_interest += batch_interest
|
||||
total_interest = self.context.breaking_accumulated_interest
|
||||
else:
|
||||
total_interest = self.context.breaking_accumulated_interest
|
||||
|
||||
|
||||
if new_message_count >= modified_exit_count_threshold:
|
||||
# 记录兴趣度到列表
|
||||
self.recent_interest_records.append(total_interest)
|
||||
# 重置累积兴趣值,因为已经达到了消息数量阈值
|
||||
self.context.breaking_accumulated_interest = 0.0
|
||||
|
||||
|
||||
logger.info(
|
||||
f"{self.context.log_prefix} 累计消息数量达到{new_message_count}条(>{modified_exit_count_threshold:.1f}),结束等待,累积兴趣值: {total_interest:.2f}"
|
||||
)
|
||||
return True,total_interest/new_message_count
|
||||
return True, total_interest / new_message_count
|
||||
|
||||
# 检查累计兴趣值
|
||||
if new_message_count > 0:
|
||||
# 只在兴趣值变化时输出log
|
||||
if not hasattr(self, "_last_accumulated_interest") or total_interest != self._last_accumulated_interest:
|
||||
logger.info(f"{self.context.log_prefix} breaking形式当前累积兴趣值: {total_interest:.2f}, 专注度: {global_config.chat.focus_value:.1f}")
|
||||
logger.info(
|
||||
f"{self.context.log_prefix} breaking形式当前累积兴趣值: {total_interest:.2f}, 专注度: {global_config.chat.focus_value:.1f}"
|
||||
)
|
||||
self._last_accumulated_interest = total_interest
|
||||
if total_interest >= modified_exit_interest_threshold:
|
||||
# 记录兴趣度到列表
|
||||
@@ -453,67 +568,16 @@ class HeartFChatting:
|
||||
logger.info(
|
||||
f"{self.context.log_prefix} 累计兴趣值达到{total_interest:.2f}(>{modified_exit_interest_threshold:.1f}),结束等待"
|
||||
)
|
||||
return True,total_interest/new_message_count
|
||||
|
||||
return True, total_interest / new_message_count
|
||||
|
||||
# 每10秒输出一次等待状态
|
||||
if int(time.time() - self.context.last_read_time) > 0 and int(time.time() - self.context.last_read_time) % 10 == 0:
|
||||
if (
|
||||
int(time.time() - self.context.last_read_time) > 0
|
||||
and int(time.time() - self.context.last_read_time) % 10 == 0
|
||||
):
|
||||
logger.info(
|
||||
f"{self.context.log_prefix} 已等待{time.time() - self.last_read_time:.0f}秒,累计{new_message_count}条消息,累积兴趣{total_interest:.1f},继续等待..."
|
||||
f"{self.context.log_prefix} 已等待{time.time() - self.context.last_read_time:.0f}秒,累计{new_message_count}条消息,累积兴趣{total_interest:.1f},继续等待..."
|
||||
)
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
return False,0.0
|
||||
|
||||
async def _execute_no_reply(self, new_message: List[Dict[str, Any]]) -> bool:
|
||||
"""执行breaking形式的no_reply(原有逻辑)"""
|
||||
new_message_count = len(new_message)
|
||||
# 检查消息数量是否达到阈值
|
||||
talk_frequency = global_config.chat.get_current_talk_frequency(self.context.stream_id)
|
||||
modified_exit_count_threshold = self.context.focus_energy / talk_frequency
|
||||
|
||||
if new_message_count >= modified_exit_count_threshold:
|
||||
# 记录兴趣度到列表
|
||||
total_interest = 0.0
|
||||
for msg_dict in new_message:
|
||||
interest_value = msg_dict.get("interest_value", 0.0)
|
||||
if msg_dict.get("processed_plain_text", ""):
|
||||
total_interest += interest_value
|
||||
|
||||
self.recent_interest_records.append(total_interest)
|
||||
|
||||
logger.info(
|
||||
f"{self.context.log_prefix} 累计消息数量达到{new_message_count}条(>{modified_exit_count_threshold}),结束等待"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
# 检查累计兴趣值
|
||||
if new_message_count > 0:
|
||||
accumulated_interest = 0.0
|
||||
for msg_dict in new_message:
|
||||
text = msg_dict.get("processed_plain_text", "")
|
||||
interest_value = msg_dict.get("interest_value", 0.0)
|
||||
if text:
|
||||
accumulated_interest += interest_value
|
||||
|
||||
# 只在兴趣值变化时输出log
|
||||
if not hasattr(self, "_last_accumulated_interest") or accumulated_interest != self._last_accumulated_interest:
|
||||
logger.info(f"{self.context.log_prefix} breaking形式当前累计兴趣值: {accumulated_interest:.2f}, 当前聊天频率: {talk_frequency:.2f}")
|
||||
self._last_accumulated_interest = accumulated_interest
|
||||
|
||||
if accumulated_interest >= 3 / talk_frequency:
|
||||
# 记录兴趣度到列表
|
||||
self.recent_interest_records.append(accumulated_interest)
|
||||
|
||||
logger.info(
|
||||
f"{self.context.log_prefix} 累计兴趣值达到{accumulated_interest:.2f}(>{3 / talk_frequency}),结束等待"
|
||||
)
|
||||
return True
|
||||
|
||||
# 每10秒输出一次等待状态
|
||||
if int(time.time() - self.context.last_read_time) > 0 and int(time.time() - self.context.last_read_time) % 10 == 0:
|
||||
logger.info(
|
||||
f"{self.context.log_prefix} 已等待{time.time() - self.context.last_read_time:.0f}秒,累计{new_message_count}条消息,继续等待..."
|
||||
)
|
||||
|
||||
return False
|
||||
return False, 0.0
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
from typing import List, Optional, TYPE_CHECKING
|
||||
import time
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.person_info.relationship_builder_manager import RelationshipBuilder
|
||||
from src.chat.express.expression_learner import ExpressionLearner
|
||||
from src.plugin_system.base.component_types import ChatMode
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.chat_loop.hfc_utils import CycleDetail
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .wakeup_manager import WakeUpManager
|
||||
from .sleep_manager.wakeup_manager import WakeUpManager
|
||||
from .energy_manager import EnergyManager
|
||||
from .heartFC_chat import HeartFChatting
|
||||
from .sleep_manager.sleep_manager import SleepManager
|
||||
|
||||
|
||||
class HfcContext:
|
||||
@@ -43,13 +43,13 @@ class HfcContext:
|
||||
|
||||
self.energy_value = self.chat_stream.energy_value
|
||||
self.sleep_pressure = self.chat_stream.sleep_pressure
|
||||
self.was_sleeping = False # 用于检测睡眠状态的切换
|
||||
|
||||
self.was_sleeping = False # 用于检测睡眠状态的切换
|
||||
|
||||
self.last_message_time = time.time()
|
||||
self.last_read_time = time.time() - 10
|
||||
|
||||
|
||||
# 从聊天流恢复breaking累积兴趣值
|
||||
self.breaking_accumulated_interest = getattr(self.chat_stream, 'breaking_accumulated_interest', 0.0)
|
||||
self.breaking_accumulated_interest = getattr(self.chat_stream, "breaking_accumulated_interest", 0.0)
|
||||
|
||||
self.action_manager = ActionManager()
|
||||
|
||||
@@ -62,6 +62,7 @@ class HfcContext:
|
||||
# 唤醒度管理器 - 延迟初始化以避免循环导入
|
||||
self.wakeup_manager: Optional["WakeUpManager"] = None
|
||||
self.energy_manager: Optional["EnergyManager"] = None
|
||||
self.sleep_manager: Optional["SleepManager"] = None
|
||||
|
||||
self.focus_energy = 1
|
||||
self.no_reply_consecutive = 0
|
||||
@@ -69,7 +70,7 @@ class HfcContext:
|
||||
# breaking形式下的累积兴趣值
|
||||
self.breaking_accumulated_interest = 0.0
|
||||
# 引用HeartFChatting实例,以便其他组件可以调用其方法
|
||||
self.chat_instance = None
|
||||
self.chat_instance: "HeartFChatting"
|
||||
|
||||
def save_context_state(self):
|
||||
"""将当前状态保存到聊天流"""
|
||||
@@ -78,4 +79,4 @@ class HfcContext:
|
||||
self.chat_stream.sleep_pressure = self.sleep_pressure
|
||||
self.chat_stream.focus_energy = self.focus_energy
|
||||
self.chat_stream.no_reply_consecutive = self.no_reply_consecutive
|
||||
self.chat_stream.breaking_accumulated_interest = self.breaking_accumulated_interest
|
||||
self.chat_stream.breaking_accumulated_interest = self.breaking_accumulated_interest
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
import time
|
||||
from typing import Optional, Dict, Any, Union
|
||||
|
||||
from src.config.config import global_config
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.plugin_system.apis import send_api
|
||||
from maim_message.message_base import GroupInfo
|
||||
|
||||
from src.common.message_repository import count_messages
|
||||
|
||||
logger = get_logger("hfc")
|
||||
|
||||
@@ -123,43 +121,7 @@ class CycleDetail:
|
||||
self.loop_action_info = loop_info["loop_action_info"]
|
||||
|
||||
|
||||
def get_recent_message_stats(minutes: float = 30, chat_id: Optional[str] = None) -> dict:
|
||||
"""
|
||||
获取最近消息统计信息
|
||||
|
||||
Args:
|
||||
minutes: 检索的分钟数,默认30分钟
|
||||
chat_id: 指定的chat_id,仅统计该chat下的消息。为None时统计全部
|
||||
|
||||
Returns:
|
||||
dict: {"bot_reply_count": int, "total_message_count": int}
|
||||
|
||||
功能说明:
|
||||
- 统计指定时间范围内的消息数量
|
||||
- 区分机器人回复和总消息数
|
||||
- 可以针对特定聊天或全局统计
|
||||
- 用于分析聊天活跃度和机器人参与度
|
||||
"""
|
||||
|
||||
now = time.time()
|
||||
start_time = now - minutes * 60
|
||||
bot_id = global_config.bot.qq_account
|
||||
|
||||
filter_base: Dict[str, Any] = {"time": {"$gte": start_time}}
|
||||
if chat_id is not None:
|
||||
filter_base["chat_id"] = chat_id
|
||||
|
||||
# 总消息数
|
||||
total_message_count = count_messages(filter_base)
|
||||
# bot自身回复数
|
||||
bot_filter = filter_base.copy()
|
||||
bot_filter["user_id"] = bot_id
|
||||
bot_reply_count = count_messages(bot_filter)
|
||||
|
||||
return {"bot_reply_count": bot_reply_count, "total_message_count": total_message_count}
|
||||
|
||||
|
||||
async def send_typing():
|
||||
async def send_typing(user_id):
|
||||
"""
|
||||
发送打字状态指示
|
||||
|
||||
@@ -177,6 +139,11 @@ async def send_typing():
|
||||
group_info=group_info,
|
||||
)
|
||||
|
||||
from plugin_system.core.event_manager import event_manager
|
||||
from src.plugins.built_in.napcat_adapter_plugin.event_types import NapcatEvent
|
||||
# 设置正在输入状态
|
||||
await event_manager.trigger_event(NapcatEvent.PERSONAL.SET_INPUT_STATUS,user_id=user_id,event_type=1)
|
||||
|
||||
await send_api.custom_to_stream(
|
||||
message_type="state", content="typing", stream_id=chat.stream_id, storage_message=False
|
||||
)
|
||||
|
||||
13
src/chat/chat_loop/proactive/events.py
Normal file
13
src/chat/chat_loop/proactive/events.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProactiveTriggerEvent:
|
||||
"""
|
||||
主动思考触发事件的数据类
|
||||
"""
|
||||
|
||||
source: str # 触发源的标识,例如 "silence_monitor", "insomnia_manager"
|
||||
reason: str # 触发的具体原因,例如 "聊天已沉默10分钟", "深夜emo"
|
||||
metadata: Optional[Dict[str, Any]] = field(default_factory=dict) # 可选的元数据,用于传递额外信息
|
||||
253
src/chat/chat_loop/proactive/proactive_thinker.py
Normal file
253
src/chat/chat_loop/proactive/proactive_thinker.py
Normal file
@@ -0,0 +1,253 @@
|
||||
import time
|
||||
import traceback
|
||||
import orjson
|
||||
from typing import TYPE_CHECKING, Dict, Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import ChatMode
|
||||
from ..hfc_context import HfcContext
|
||||
from .events import ProactiveTriggerEvent
|
||||
from src.plugin_system.apis import generator_api
|
||||
from src.plugin_system.apis.generator_api import process_human_text
|
||||
from src.schedule.schedule_manager import schedule_manager
|
||||
from src.plugin_system import tool_api
|
||||
from src.plugin_system.base.component_types import ComponentType
|
||||
from src.config.config import global_config
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages_with_id
|
||||
from src.mood.mood_manager import mood_manager
|
||||
from src.common.database.sqlalchemy_database_api import store_action_info
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..cycle_processor import CycleProcessor
|
||||
|
||||
logger = get_logger("hfc")
|
||||
|
||||
|
||||
class ProactiveThinker:
|
||||
"""
|
||||
主动思考器,负责处理和执行主动思考事件。
|
||||
当接收到 ProactiveTriggerEvent 时,它会根据事件内容进行一系列决策和操作,
|
||||
例如调整情绪、调用规划器生成行动,并最终可能产生一个主动的回复。
|
||||
"""
|
||||
|
||||
def __init__(self, context: HfcContext, cycle_processor: "CycleProcessor"):
|
||||
"""
|
||||
初始化主动思考器。
|
||||
|
||||
Args:
|
||||
context (HfcContext): HFC聊天上下文对象,提供了当前聊天会话的所有背景信息。
|
||||
cycle_processor (CycleProcessor): 循环处理器,用于执行主动思考后产生的动作。
|
||||
|
||||
功能说明:
|
||||
- 接收并处理主动思考事件 (ProactiveTriggerEvent)。
|
||||
- 在思考前根据事件类型执行预处理操作,如修改当前情绪状态。
|
||||
- 调用行动规划器 (Action Planner) 来决定下一步应该做什么。
|
||||
- 如果规划结果是发送消息,则调用生成器API生成回复并发送。
|
||||
"""
|
||||
self.context = context
|
||||
self.cycle_processor = cycle_processor
|
||||
|
||||
async def think(self, trigger_event: ProactiveTriggerEvent):
|
||||
"""
|
||||
主动思考的统一入口API。
|
||||
这是外部触发主动思考时调用的主要方法。
|
||||
|
||||
Args:
|
||||
trigger_event (ProactiveTriggerEvent): 描述触发上下文的事件对象,包含了思考的来源和原因。
|
||||
"""
|
||||
logger.info(
|
||||
f"{self.context.log_prefix} 接收到主动思考事件: "
|
||||
f"来源='{trigger_event.source}', 原因='{trigger_event.reason}'"
|
||||
)
|
||||
|
||||
try:
|
||||
# 步骤 1: 根据事件类型执行思考前的准备工作,例如调整情绪。
|
||||
await self._prepare_for_thinking(trigger_event)
|
||||
|
||||
# 步骤 2: 执行核心的思考和决策逻辑。
|
||||
await self._execute_proactive_thinking(trigger_event)
|
||||
|
||||
except Exception as e:
|
||||
# 捕获并记录在思考过程中发生的任何异常。
|
||||
logger.error(f"{self.context.log_prefix} 主动思考 think 方法执行异常: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
async def _prepare_for_thinking(self, trigger_event: ProactiveTriggerEvent):
|
||||
"""
|
||||
根据事件类型,在正式思考前执行准备工作。
|
||||
目前主要是处理来自失眠管理器的事件,并据此调整情绪。
|
||||
|
||||
Args:
|
||||
trigger_event (ProactiveTriggerEvent): 触发事件。
|
||||
"""
|
||||
# 目前只处理来自失眠管理器(insomnia_manager)的事件
|
||||
if trigger_event.source != "insomnia_manager":
|
||||
return
|
||||
|
||||
try:
|
||||
# 获取当前聊天的情绪对象
|
||||
mood_obj = mood_manager.get_mood_by_chat_id(self.context.stream_id)
|
||||
new_mood = None
|
||||
|
||||
# 根据失眠的不同原因设置对应的情绪
|
||||
if trigger_event.reason == "low_pressure":
|
||||
new_mood = "精力过剩,毫无睡意"
|
||||
elif trigger_event.reason == "random":
|
||||
new_mood = "深夜emo,胡思乱想"
|
||||
elif trigger_event.reason == "goodnight":
|
||||
new_mood = "有点困了,准备睡觉了"
|
||||
|
||||
# 如果成功匹配到了新的情绪,则更新情绪状态
|
||||
if new_mood:
|
||||
mood_obj.mood_state = new_mood
|
||||
mood_obj.last_change_time = time.time()
|
||||
logger.info(
|
||||
f"{self.context.log_prefix} 因 '{trigger_event.reason}',"
|
||||
f"情绪状态被强制更新为: {mood_obj.mood_state}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.context.log_prefix} 设置失眠情绪时出错: {e}")
|
||||
|
||||
async def _execute_proactive_thinking(self, trigger_event: ProactiveTriggerEvent):
|
||||
"""
|
||||
执行主动思考的核心逻辑。
|
||||
它会调用规划器来决定是否要采取行动,以及采取什么行动。
|
||||
|
||||
Args:
|
||||
trigger_event (ProactiveTriggerEvent): 触发事件。
|
||||
"""
|
||||
try:
|
||||
# 调用规划器的 PROACTIVE 模式,让其决定下一步的行动
|
||||
actions, _ = await self.cycle_processor.action_planner.plan(mode=ChatMode.PROACTIVE)
|
||||
|
||||
# 通常只关心规划出的第一个动作
|
||||
action_result = actions[0] if actions else {}
|
||||
|
||||
action_type = action_result.get("action_type")
|
||||
|
||||
if action_type == "proactive_reply":
|
||||
await self._generate_proactive_content_and_send(action_result)
|
||||
elif action_type != "do_nothing":
|
||||
logger.warning(f"{self.context.log_prefix} 主动思考返回了未知的动作类型: {action_type}")
|
||||
else:
|
||||
# 如果规划结果是“什么都不做”,则记录日志
|
||||
logger.info(f"{self.context.log_prefix} 主动思考决策: 保持沉默")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.context.log_prefix} 主动思考执行异常: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
async def _generate_proactive_content_and_send(self, action_result: Dict[str, Any]):
|
||||
"""
|
||||
获取实时信息,构建最终的生成提示词,并生成和发送主动回复。
|
||||
|
||||
Args:
|
||||
action_result (Dict[str, Any]): 规划器返回的动作结果。
|
||||
"""
|
||||
try:
|
||||
topic = action_result.get("action_data", {}).get("topic", "随便聊聊")
|
||||
logger.info(f"{self.context.log_prefix} 主动思考确定主题: '{topic}'")
|
||||
|
||||
# 1. 获取日程信息
|
||||
schedule_block = "你今天没有日程安排。"
|
||||
if global_config.planning_system.schedule_enable:
|
||||
if current_activity := schedule_manager.get_current_activity():
|
||||
schedule_block = f"你当前正在:{current_activity}。"
|
||||
|
||||
# 2. 网络搜索
|
||||
news_block = "暂时没有获取到最新资讯。"
|
||||
try:
|
||||
web_search_tool = tool_api.get_tool_instance("web_search")
|
||||
if web_search_tool:
|
||||
tool_args = {"query": topic, "max_results": 10}
|
||||
# 调用工具,并传递参数
|
||||
search_result_dict = await web_search_tool.execute(**tool_args)
|
||||
if search_result_dict and not search_result_dict.get("error"):
|
||||
news_block = search_result_dict.get("content", "未能提取有效资讯。")
|
||||
else:
|
||||
logger.warning(f"{self.context.log_prefix} 网络搜索返回错误: {search_result_dict.get('error')}")
|
||||
else:
|
||||
logger.warning(f"{self.context.log_prefix} 未找到 web_search 工具实例。")
|
||||
except Exception as e:
|
||||
logger.error(f"{self.context.log_prefix} 主动思考时网络搜索失败: {e}")
|
||||
|
||||
# 3. 获取最新的聊天上下文
|
||||
message_list = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=self.context.stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=int(global_config.chat.max_context_size * 0.3),
|
||||
)
|
||||
chat_context_block, _ = build_readable_messages_with_id(messages=message_list)
|
||||
|
||||
# 4. 构建最终的生成提示词
|
||||
bot_name = global_config.bot.nickname
|
||||
personality = global_config.personality
|
||||
identity_block = (
|
||||
f"你的名字是{bot_name}。\n"
|
||||
f"关于你:{personality.personality_core},并且{personality.personality_side}。\n"
|
||||
f"你的身份是{personality.identity},平时说话风格是{personality.reply_style}。"
|
||||
)
|
||||
mood_block = f"你现在的心情是:{mood_manager.get_mood_by_chat_id(self.context.stream_id).mood_state}"
|
||||
|
||||
final_prompt = f"""
|
||||
## 你的角色
|
||||
{identity_block}
|
||||
|
||||
## 你的心情
|
||||
{mood_block}
|
||||
|
||||
## 你今天的日程安排
|
||||
{schedule_block}
|
||||
|
||||
## 关于你准备讨论的话题“{topic}”的最新信息
|
||||
{news_block}
|
||||
|
||||
## 最近的聊天内容
|
||||
{chat_context_block}
|
||||
|
||||
## 任务
|
||||
你之前决定要发起一个关于“{topic}”的对话。现在,请结合以上所有信息,自然地开启这个话题。
|
||||
|
||||
## 要求
|
||||
- 你的发言要听起来像是自发的,而不是在念报告。
|
||||
- 巧妙地将日程安排或最新信息融入到你的开场白中。
|
||||
- 风格要符合你的角色设定。
|
||||
- 直接输出你想要说的内容,不要包含其他额外信息。
|
||||
|
||||
你的回复应该:
|
||||
1. 可以分享你的看法、提出相关问题,或者开个合适的玩笑。
|
||||
2. 目的是让对话更有趣、更深入。
|
||||
3. 不要浮夸,不要夸张修辞,不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。
|
||||
最终请输出一条简短、完整且口语化的回复。
|
||||
"""
|
||||
|
||||
# 5. 调用生成器API并发送
|
||||
response_text = await generator_api.generate_response_custom(
|
||||
chat_stream=self.context.chat_stream,
|
||||
prompt=final_prompt,
|
||||
request_type="chat.replyer.proactive",
|
||||
)
|
||||
|
||||
if response_text:
|
||||
response_set = process_human_text(
|
||||
content=response_text,
|
||||
enable_splitter=global_config.response_splitter.enable,
|
||||
enable_chinese_typo=global_config.chinese_typo.enable,
|
||||
)
|
||||
await self.cycle_processor.response_handler.send_response(
|
||||
response_set, time.time(), action_result.get("action_message")
|
||||
)
|
||||
await store_action_info(
|
||||
chat_stream=self.context.chat_stream,
|
||||
action_name="proactive_reply",
|
||||
action_data={"topic": topic, "response": response_text},
|
||||
action_prompt_display=f"主动发起对话: {topic}",
|
||||
action_done=True,
|
||||
)
|
||||
else:
|
||||
logger.error(f"{self.context.log_prefix} 主动思考生成回复失败。")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.context.log_prefix} 生成主动回复内容时异常: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
@@ -1,353 +0,0 @@
|
||||
import asyncio
|
||||
import time
|
||||
import traceback
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.plugin_system.base.component_types import ChatMode
|
||||
from .hfc_context import HfcContext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .cycle_processor import CycleProcessor
|
||||
|
||||
logger = get_logger("hfc")
|
||||
|
||||
|
||||
class ProactiveThinker:
|
||||
def __init__(self, context: HfcContext, cycle_processor: "CycleProcessor"):
|
||||
"""
|
||||
初始化主动思考器
|
||||
|
||||
Args:
|
||||
context: HFC聊天上下文对象
|
||||
cycle_processor: 循环处理器,用于执行主动思考的结果
|
||||
|
||||
功能说明:
|
||||
- 管理机器人的主动发言功能
|
||||
- 根据沉默时间和配置触发主动思考
|
||||
- 提供私聊和群聊不同的思考提示模板
|
||||
- 使用3-sigma规则计算动态思考间隔
|
||||
"""
|
||||
self.context = context
|
||||
self.cycle_processor = cycle_processor
|
||||
self._proactive_thinking_task: Optional[asyncio.Task] = None
|
||||
|
||||
self.proactive_thinking_prompts = {
|
||||
"private": """现在你和你朋友的私聊里面已经隔了{time}没有发送消息了,请你结合上下文以及你和你朋友之前聊过的话题和你的人设来决定要不要主动发送消息,你可以选择:
|
||||
|
||||
1. 继续保持沉默(当{time}以前已经结束了一个话题并且你不想挑起新话题时)
|
||||
2. 选择回复(当{time}以前你发送了一条消息且没有人回复你时、你想主动挑起一个话题时)
|
||||
|
||||
请根据当前情况做出选择。如果选择回复,请直接发送你想说的内容;如果选择保持沉默,请只回复"沉默"(注意:这个词不会被发送到群聊中)。""",
|
||||
"group": """现在群里面已经隔了{time}没有人发送消息了,请你结合上下文以及群聊里面之前聊过的话题和你的人设来决定要不要主动发送消息,你可以选择:
|
||||
|
||||
1. 继续保持沉默(当{time}以前已经结束了一个话题并且你不想挑起新话题时)
|
||||
2. 选择回复(当{time}以前你发送了一条消息且没有人回复你时、你想主动挑起一个话题时)
|
||||
|
||||
请根据当前情况做出选择。如果选择回复,请直接发送你想说的内容;如果选择保持沉默,请只回复"沉默"(注意:这个词不会被发送到群聊中)。""",
|
||||
}
|
||||
|
||||
async def start(self):
|
||||
"""
|
||||
启动主动思考器
|
||||
|
||||
功能说明:
|
||||
- 检查运行状态和配置,避免重复启动
|
||||
- 只有在启用主动思考功能时才启动
|
||||
- 创建主动思考循环异步任务
|
||||
- 设置任务完成回调处理
|
||||
- 记录启动日志
|
||||
"""
|
||||
if self.context.running and not self._proactive_thinking_task and global_config.chat.enable_proactive_thinking:
|
||||
self._proactive_thinking_task = asyncio.create_task(self._proactive_thinking_loop())
|
||||
self._proactive_thinking_task.add_done_callback(self._handle_proactive_thinking_completion)
|
||||
logger.info(f"{self.context.log_prefix} 主动思考器已启动")
|
||||
|
||||
async def stop(self):
|
||||
"""
|
||||
停止主动思考器
|
||||
|
||||
功能说明:
|
||||
- 取消正在运行的主动思考任务
|
||||
- 等待任务完全停止
|
||||
- 记录停止日志
|
||||
"""
|
||||
if self._proactive_thinking_task and not self._proactive_thinking_task.done():
|
||||
self._proactive_thinking_task.cancel()
|
||||
await asyncio.sleep(0)
|
||||
logger.info(f"{self.context.log_prefix} 主动思考器已停止")
|
||||
|
||||
def _handle_proactive_thinking_completion(self, task: asyncio.Task):
|
||||
"""
|
||||
处理主动思考任务完成
|
||||
|
||||
Args:
|
||||
task: 完成的异步任务对象
|
||||
|
||||
功能说明:
|
||||
- 处理任务正常完成或异常情况
|
||||
- 记录相应的日志信息
|
||||
- 区分取消和异常终止的情况
|
||||
"""
|
||||
try:
|
||||
if exception := task.exception():
|
||||
logger.error(f"{self.context.log_prefix} 主动思考循环异常: {exception}")
|
||||
else:
|
||||
logger.info(f"{self.context.log_prefix} 主动思考循环正常结束")
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"{self.context.log_prefix} 主动思考循环被取消")
|
||||
|
||||
async def _proactive_thinking_loop(self):
|
||||
"""
|
||||
主动思考的主循环
|
||||
|
||||
功能说明:
|
||||
- 每15秒检查一次是否需要主动思考
|
||||
- 只在FOCUS模式下进行主动思考
|
||||
- 检查是否启用主动思考功能
|
||||
- 计算沉默时间并与动态间隔比较
|
||||
- 达到条件时执行主动思考并更新最后消息时间
|
||||
- 处理执行过程中的异常
|
||||
"""
|
||||
while self.context.running:
|
||||
await asyncio.sleep(15)
|
||||
|
||||
if self.context.loop_mode != ChatMode.FOCUS:
|
||||
continue
|
||||
|
||||
if not self._should_enable_proactive_thinking():
|
||||
continue
|
||||
|
||||
current_time = time.time()
|
||||
silence_duration = current_time - self.context.last_message_time
|
||||
|
||||
target_interval = self._get_dynamic_thinking_interval()
|
||||
|
||||
if silence_duration >= target_interval:
|
||||
try:
|
||||
await self._execute_proactive_thinking(silence_duration)
|
||||
self.context.last_message_time = current_time
|
||||
except Exception as e:
|
||||
logger.error(f"{self.context.log_prefix} 主动思考执行出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
def _should_enable_proactive_thinking(self) -> bool:
|
||||
"""
|
||||
检查是否应该启用主动思考
|
||||
|
||||
Returns:
|
||||
bool: 如果应该启用主动思考则返回True
|
||||
|
||||
功能说明:
|
||||
- 检查聊天流是否存在
|
||||
- 检查当前聊天是否在启用列表中(按平台和类型分别检查)
|
||||
- 根据聊天类型(群聊/私聊)和配置决定是否启用
|
||||
- 群聊需要proactive_thinking_in_group为True
|
||||
- 私聊需要proactive_thinking_in_private为True
|
||||
"""
|
||||
if not self.context.chat_stream:
|
||||
return False
|
||||
|
||||
is_group_chat = self.context.chat_stream.group_info is not None
|
||||
|
||||
# 检查基础开关
|
||||
if is_group_chat and not global_config.chat.proactive_thinking_in_group:
|
||||
return False
|
||||
if not is_group_chat and not global_config.chat.proactive_thinking_in_private:
|
||||
return False
|
||||
|
||||
# 获取当前聊天的完整标识 (platform:chat_id)
|
||||
stream_parts = self.context.stream_id.split(":")
|
||||
if len(stream_parts) >= 2:
|
||||
platform = stream_parts[0]
|
||||
chat_id = stream_parts[1]
|
||||
current_chat_identifier = f"{platform}:{chat_id}"
|
||||
else:
|
||||
# 如果无法解析,则使用原始stream_id
|
||||
current_chat_identifier = self.context.stream_id
|
||||
|
||||
# 检查是否在启用列表中
|
||||
if is_group_chat:
|
||||
# 群聊检查
|
||||
enable_list = getattr(global_config.chat, "proactive_thinking_enable_in_groups", [])
|
||||
if enable_list and current_chat_identifier not in enable_list:
|
||||
return False
|
||||
else:
|
||||
# 私聊检查
|
||||
enable_list = getattr(global_config.chat, "proactive_thinking_enable_in_private", [])
|
||||
if enable_list and current_chat_identifier not in enable_list:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _get_dynamic_thinking_interval(self) -> float:
|
||||
"""
|
||||
获取动态思考间隔
|
||||
|
||||
Returns:
|
||||
float: 计算得出的思考间隔时间(秒)
|
||||
|
||||
功能说明:
|
||||
- 使用3-sigma规则计算正态分布的思考间隔
|
||||
- 基于base_interval和delta_sigma配置计算
|
||||
- 处理特殊情况(为0或负数的配置)
|
||||
- 如果timing_utils不可用则使用固定间隔
|
||||
- 间隔范围被限制在1秒到86400秒(1天)之间
|
||||
"""
|
||||
try:
|
||||
from src.utils.timing_utils import get_normal_distributed_interval
|
||||
|
||||
base_interval = global_config.chat.proactive_thinking_interval
|
||||
delta_sigma = getattr(global_config.chat, "delta_sigma", 120)
|
||||
|
||||
if base_interval < 0:
|
||||
base_interval = abs(base_interval)
|
||||
if delta_sigma < 0:
|
||||
delta_sigma = abs(delta_sigma)
|
||||
|
||||
if base_interval == 0 and delta_sigma == 0:
|
||||
return 300
|
||||
elif base_interval == 0:
|
||||
sigma_percentage = delta_sigma / 1000
|
||||
return get_normal_distributed_interval(0, sigma_percentage, 1, 86400, use_3sigma_rule=True)
|
||||
elif delta_sigma == 0:
|
||||
return base_interval
|
||||
|
||||
sigma_percentage = delta_sigma / base_interval
|
||||
return get_normal_distributed_interval(base_interval, sigma_percentage, 1, 86400, use_3sigma_rule=True)
|
||||
|
||||
except ImportError:
|
||||
logger.warning(f"{self.context.log_prefix} timing_utils不可用,使用固定间隔")
|
||||
return max(300, abs(global_config.chat.proactive_thinking_interval))
|
||||
except Exception as e:
|
||||
logger.error(f"{self.context.log_prefix} 动态间隔计算出错: {e},使用固定间隔")
|
||||
return max(300, abs(global_config.chat.proactive_thinking_interval))
|
||||
|
||||
def _format_duration(self, seconds: float) -> str:
|
||||
"""
|
||||
格式化持续时间为中文描述
|
||||
|
||||
Args:
|
||||
seconds: 持续时间(秒)
|
||||
|
||||
Returns:
|
||||
str: 格式化后的时间字符串,如"1小时30分45秒"
|
||||
|
||||
功能说明:
|
||||
- 将秒数转换为小时、分钟、秒的组合
|
||||
- 只显示非零的时间单位
|
||||
- 如果所有单位都为0则显示"0秒"
|
||||
- 用于主动思考日志的时间显示
|
||||
"""
|
||||
hours = int(seconds // 3600)
|
||||
minutes = int((seconds % 3600) // 60)
|
||||
secs = int(seconds % 60)
|
||||
|
||||
parts = []
|
||||
if hours > 0:
|
||||
parts.append(f"{hours}小时")
|
||||
if minutes > 0:
|
||||
parts.append(f"{minutes}分")
|
||||
if secs > 0 or not parts:
|
||||
parts.append(f"{secs}秒")
|
||||
|
||||
return "".join(parts)
|
||||
|
||||
async def _execute_proactive_thinking(self, silence_duration: float):
|
||||
"""
|
||||
执行主动思考
|
||||
|
||||
Args:
|
||||
silence_duration: 沉默持续时间(秒)
|
||||
"""
|
||||
formatted_time = self._format_duration(silence_duration)
|
||||
logger.info(f"{self.context.log_prefix} 触发主动思考,已沉默{formatted_time}")
|
||||
|
||||
try:
|
||||
# 直接调用 planner 的 PROACTIVE 模式
|
||||
action_result_tuple, target_message = await self.cycle_processor.action_planner.plan(
|
||||
mode=ChatMode.PROACTIVE
|
||||
)
|
||||
action_result = action_result_tuple.get("action_result")
|
||||
|
||||
# 如果决策不是 do_nothing,则执行
|
||||
if action_result and action_result.get("action_type") != "do_nothing":
|
||||
logger.info(f"{self.context.log_prefix} 主动思考决策: {action_result.get('action_type')}, 原因: {action_result.get('reasoning')}")
|
||||
# 在主动思考时,如果 target_message 为 None,则默认选取最新 message 作为 target_message
|
||||
if target_message is None and self.context.chat_stream and self.context.chat_stream.context:
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
latest_message = self.context.chat_stream.context.get_last_message()
|
||||
if isinstance(latest_message, MessageRecv):
|
||||
user_info = latest_message.message_info.user_info
|
||||
target_message = {
|
||||
"chat_info_platform": latest_message.message_info.platform,
|
||||
"user_platform": user_info.platform if user_info else None,
|
||||
"user_id": user_info.user_id if user_info else None,
|
||||
"processed_plain_text": latest_message.processed_plain_text,
|
||||
"is_mentioned": latest_message.is_mentioned,
|
||||
}
|
||||
|
||||
# 将决策结果交给 cycle_processor 的后续流程处理
|
||||
await self.cycle_processor.execute_plan(action_result, target_message)
|
||||
else:
|
||||
logger.info(f"{self.context.log_prefix} 主动思考决策: 保持沉默")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.context.log_prefix} 主动思考执行异常: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
async def trigger_insomnia_thinking(self, reason: str):
|
||||
"""
|
||||
由外部事件(如失眠)触发的一次性主动思考
|
||||
|
||||
Args:
|
||||
reason: 触发的原因 (e.g., "low_pressure", "random")
|
||||
"""
|
||||
logger.info(f"{self.context.log_prefix} 因“{reason}”触发失眠,开始深夜思考...")
|
||||
|
||||
# 1. 根据原因修改情绪
|
||||
try:
|
||||
from src.mood.mood_manager import mood_manager
|
||||
|
||||
mood_obj = mood_manager.get_mood_by_chat_id(self.context.stream_id)
|
||||
if reason == "low_pressure":
|
||||
mood_obj.mood_state = "精力过剩,毫无睡意"
|
||||
elif reason == "random":
|
||||
mood_obj.mood_state = "深夜emo,胡思乱想"
|
||||
mood_obj.last_change_time = time.time() # 更新时间戳以允许后续的情绪回归
|
||||
logger.info(f"{self.context.log_prefix} 因失眠,情绪状态被强制更新为: {mood_obj.mood_state}")
|
||||
except Exception as e:
|
||||
logger.error(f"{self.context.log_prefix} 设置失眠情绪时出错: {e}")
|
||||
|
||||
# 2. 直接执行主动思考逻辑
|
||||
try:
|
||||
# 传入一个象征性的silence_duration,因为它在这里不重要
|
||||
await self._execute_proactive_thinking(silence_duration=1)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.context.log_prefix} 失眠思考执行出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
async def trigger_goodnight_thinking(self):
|
||||
"""
|
||||
在失眠状态结束后,触发一次准备睡觉的主动思考
|
||||
"""
|
||||
logger.info(f"{self.context.log_prefix} 失眠状态结束,准备睡觉,触发告别思考...")
|
||||
|
||||
# 1. 设置一个准备睡觉的特定情绪
|
||||
try:
|
||||
from src.mood.mood_manager import mood_manager
|
||||
|
||||
mood_obj = mood_manager.get_mood_by_chat_id(self.context.stream_id)
|
||||
mood_obj.mood_state = "有点困了,准备睡觉了"
|
||||
mood_obj.last_change_time = time.time()
|
||||
logger.info(f"{self.context.log_prefix} 情绪状态更新为: {mood_obj.mood_state}")
|
||||
except Exception as e:
|
||||
logger.error(f"{self.context.log_prefix} 设置睡前情绪时出错: {e}")
|
||||
|
||||
# 2. 直接执行主动思考逻辑
|
||||
try:
|
||||
await self._execute_proactive_thinking(silence_duration=1)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.context.log_prefix} 睡前告别思考执行出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
@@ -1,24 +1,24 @@
|
||||
import time
|
||||
import orjson
|
||||
import random
|
||||
import traceback
|
||||
from typing import Optional, Dict, Any, Tuple
|
||||
from typing import Dict, Any, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.plugin_system.apis import generator_api, send_api, message_api, database_api
|
||||
from src.plugin_system.apis import send_api, message_api, database_api
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
from .hfc_context import HfcContext
|
||||
|
||||
# 导入反注入系统
|
||||
from src.chat.antipromptinjector import get_anti_injector
|
||||
from src.chat.antipromptinjector.types import ProcessResult
|
||||
from src.chat.utils.prompt_builder import Prompt
|
||||
|
||||
# 日志记录器
|
||||
logger = get_logger("hfc")
|
||||
anti_injector_logger = get_logger("anti_injector")
|
||||
|
||||
|
||||
class ResponseHandler:
|
||||
"""
|
||||
响应处理器类,负责生成和发送机器人的回复。
|
||||
"""
|
||||
def __init__(self, context: HfcContext):
|
||||
"""
|
||||
初始化响应处理器
|
||||
@@ -68,6 +68,7 @@ class ResponseHandler:
|
||||
|
||||
person_info_manager = get_person_info_manager()
|
||||
|
||||
# 获取平台信息
|
||||
platform = "default"
|
||||
if self.context.chat_stream:
|
||||
platform = (
|
||||
@@ -76,11 +77,13 @@ class ResponseHandler:
|
||||
or self.context.chat_stream.platform
|
||||
)
|
||||
|
||||
# 获取用户信息并生成回复提示
|
||||
user_id = action_message.get("user_id", "")
|
||||
person_id = person_info_manager.get_person_id(platform, user_id)
|
||||
person_name = await person_info_manager.get_value(person_id, "person_name")
|
||||
action_prompt_display = f"你对{person_name}进行了回复:{reply_text}"
|
||||
|
||||
# 存储动作信息到数据库
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.context.chat_stream,
|
||||
action_build_into_prompt=False,
|
||||
@@ -91,6 +94,7 @@ class ResponseHandler:
|
||||
action_name="reply",
|
||||
)
|
||||
|
||||
# 构建循环信息
|
||||
loop_info: Dict[str, Any] = {
|
||||
"loop_plan_info": {
|
||||
"action_result": plan_result.get("action_result", {}),
|
||||
@@ -126,10 +130,12 @@ class ResponseHandler:
|
||||
- 正确处理元组格式的回复段
|
||||
"""
|
||||
current_time = time.time()
|
||||
# 计算新消息数量
|
||||
new_message_count = message_api.count_new_messages(
|
||||
chat_id=self.context.stream_id, start_time=thinking_start_time, end_time=current_time
|
||||
)
|
||||
|
||||
# 根据新消息数量决定是否需要引用回复
|
||||
need_reply = new_message_count >= random.randint(2, 4)
|
||||
|
||||
reply_text = ""
|
||||
@@ -147,23 +153,28 @@ class ResponseHandler:
|
||||
# 向下兼容:如果已经是字符串,则直接使用
|
||||
data = str(reply_seg)
|
||||
|
||||
if isinstance(data, list):
|
||||
data = "".join(map(str, data))
|
||||
reply_text += data
|
||||
|
||||
# 如果是主动思考且内容为“沉默”,则不发送
|
||||
if is_proactive_thinking and data.strip() == "沉默":
|
||||
logger.info(f"{self.context.log_prefix} 主动思考决定保持沉默,不发送消息")
|
||||
continue
|
||||
|
||||
# 发送第一段回复
|
||||
if not first_replied:
|
||||
await send_api.text_to_stream(
|
||||
text=data,
|
||||
stream_id=self.context.stream_id,
|
||||
reply_to_message = message_data,
|
||||
reply_to_message=message_data,
|
||||
set_reply=need_reply,
|
||||
typing=False,
|
||||
)
|
||||
first_replied = True
|
||||
else:
|
||||
await send_api.text_to_stream(
|
||||
# 发送后续回复
|
||||
sent_message = await send_api.text_to_stream(
|
||||
text=data,
|
||||
stream_id=self.context.stream_id,
|
||||
reply_to_message=None,
|
||||
@@ -172,101 +183,3 @@ class ResponseHandler:
|
||||
)
|
||||
|
||||
return reply_text
|
||||
|
||||
# TODO: 已废弃
|
||||
async def generate_response(
|
||||
self,
|
||||
message_data: dict,
|
||||
available_actions: Optional[Dict[str, Any]],
|
||||
reply_to: str,
|
||||
request_type: str = "chat.replyer.normal",
|
||||
) -> Optional[list]:
|
||||
"""
|
||||
生成回复内容
|
||||
|
||||
Args:
|
||||
message_data: 消息数据
|
||||
available_actions: 可用动作列表
|
||||
reply_to: 回复目标
|
||||
request_type: 请求类型,默认为普通回复
|
||||
|
||||
Returns:
|
||||
list: 生成的回复内容列表,失败时返回None
|
||||
|
||||
功能说明:
|
||||
- 在生成回复前进行反注入检测(提高效率)
|
||||
- 调用生成器API生成回复
|
||||
- 根据配置启用或禁用工具功能
|
||||
- 处理生成失败的情况
|
||||
- 记录生成过程中的错误和异常
|
||||
"""
|
||||
try:
|
||||
# === 反注入检测(仅在需要生成回复时) ===
|
||||
# 执行反注入检测(直接使用字典格式)
|
||||
anti_injector = get_anti_injector()
|
||||
result, modified_content, reason = await anti_injector.process_message(
|
||||
message_data, self.context.chat_stream
|
||||
)
|
||||
|
||||
# 根据反注入结果处理消息数据
|
||||
await anti_injector.handle_message_storage(result, modified_content, reason or "", message_data)
|
||||
|
||||
if result == ProcessResult.BLOCKED_BAN:
|
||||
# 用户被封禁 - 直接阻止回复生成
|
||||
anti_injector_logger.warning(f"用户被反注入系统封禁,阻止回复生成: {reason}")
|
||||
return None
|
||||
elif result == ProcessResult.BLOCKED_INJECTION:
|
||||
# 消息被阻止(危险内容等) - 直接阻止回复生成
|
||||
anti_injector_logger.warning(f"消息被反注入系统阻止,阻止回复生成: {reason}")
|
||||
return None
|
||||
elif result == ProcessResult.COUNTER_ATTACK:
|
||||
# 反击模式:生成反击消息作为回复
|
||||
anti_injector_logger.info(f"反击模式启动,生成反击回复: {reason}")
|
||||
if modified_content:
|
||||
# 返回反击消息作为回复内容
|
||||
return [("text", modified_content)]
|
||||
else:
|
||||
# 没有反击内容时阻止回复生成
|
||||
return None
|
||||
|
||||
# 检查是否需要加盾处理
|
||||
safety_prompt = None
|
||||
if result == ProcessResult.SHIELDED:
|
||||
# 获取安全系统提示词并注入
|
||||
shield = anti_injector.shield
|
||||
safety_prompt = shield.get_safety_system_prompt()
|
||||
await Prompt.create_async(safety_prompt, "anti_injection_safety_prompt")
|
||||
anti_injector_logger.info(f"消息已被反注入系统加盾处理,已注入安全提示词: {reason}")
|
||||
|
||||
# 处理被修改的消息内容(用于生成回复)
|
||||
modified_reply_to = reply_to
|
||||
if modified_content:
|
||||
# 更新消息内容用于生成回复
|
||||
anti_injector_logger.info(f"消息内容已被反注入系统修改,使用修改后内容生成回复: {reason}")
|
||||
# 解析原始reply_to格式:"发送者:消息内容"
|
||||
if ":" in reply_to:
|
||||
sender_part, _ = reply_to.split(":", 1)
|
||||
modified_reply_to = f"{sender_part}:{modified_content}"
|
||||
else:
|
||||
# 如果格式不标准,直接使用修改后的内容
|
||||
modified_reply_to = modified_content
|
||||
|
||||
# === 正常的回复生成流程 ===
|
||||
success, reply_set, _ = await generator_api.generate_reply(
|
||||
chat_stream=self.context.chat_stream,
|
||||
reply_to=modified_reply_to, # 使用可能被修改的内容
|
||||
available_actions=available_actions,
|
||||
enable_tool=global_config.tool.enable_tool,
|
||||
request_type=request_type,
|
||||
from_plugin=False,
|
||||
)
|
||||
|
||||
if not success or not reply_set:
|
||||
logger.info(f"对 {message_data.get('processed_plain_text')} 的回复生成失败")
|
||||
return None
|
||||
|
||||
return reply_set
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.context.log_prefix}回复生成出现错误:{str(e)} {traceback.format_exc()}")
|
||||
return None
|
||||
|
||||
33
src/chat/chat_loop/sleep_manager/notification_sender.py
Normal file
33
src/chat/chat_loop/sleep_manager/notification_sender.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import asyncio
|
||||
from src.common.logger import get_logger
|
||||
from ..hfc_context import HfcContext
|
||||
|
||||
logger = get_logger("notification_sender")
|
||||
|
||||
|
||||
class NotificationSender:
|
||||
@staticmethod
|
||||
async def send_goodnight_notification(context: HfcContext):
|
||||
"""发送晚安通知"""
|
||||
try:
|
||||
from ..proactive.events import ProactiveTriggerEvent
|
||||
from ..proactive.proactive_thinker import ProactiveThinker
|
||||
|
||||
event = ProactiveTriggerEvent(source="sleep_manager", reason="goodnight")
|
||||
proactive_thinker = ProactiveThinker(context, context.chat_instance.cycle_processor)
|
||||
await proactive_thinker.think(event)
|
||||
except Exception as e:
|
||||
logger.error(f"发送晚安通知失败: {e}")
|
||||
|
||||
@staticmethod
|
||||
async def send_insomnia_notification(context: HfcContext, reason: str):
|
||||
"""发送失眠通知"""
|
||||
try:
|
||||
from ..proactive.events import ProactiveTriggerEvent
|
||||
from ..proactive.proactive_thinker import ProactiveThinker
|
||||
|
||||
event = ProactiveTriggerEvent(source="sleep_manager", reason=reason)
|
||||
proactive_thinker = ProactiveThinker(context, context.chat_instance.cycle_processor)
|
||||
await proactive_thinker.think(event)
|
||||
except Exception as e:
|
||||
logger.error(f"发送失眠通知失败: {e}")
|
||||
304
src/chat/chat_loop/sleep_manager/sleep_manager.py
Normal file
304
src/chat/chat_loop/sleep_manager/sleep_manager.py
Normal file
@@ -0,0 +1,304 @@
|
||||
import asyncio
|
||||
import random
|
||||
from datetime import datetime, timedelta, date
|
||||
from typing import Optional, TYPE_CHECKING, List, Dict, Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from .sleep_state import SleepState, SleepStateSerializer
|
||||
from .time_checker import TimeChecker
|
||||
from .notification_sender import NotificationSender
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .wakeup_manager import WakeUpManager
|
||||
|
||||
logger = get_logger("sleep_manager")
|
||||
|
||||
|
||||
class SleepManager:
|
||||
"""
|
||||
睡眠管理器,核心组件之一,负责管理角色的睡眠周期和状态转换。
|
||||
它实现了一个状态机,根据预设的时间表、睡眠压力和随机因素,
|
||||
在不同的睡眠状态(如清醒、准备入睡、睡眠、失眠)之间进行切换。
|
||||
"""
|
||||
def __init__(self):
|
||||
"""
|
||||
初始化睡眠管理器。
|
||||
"""
|
||||
self.time_checker = TimeChecker() # 时间检查器,用于判断当前是否处于理论睡眠时间
|
||||
self.last_sleep_log_time = 0 # 上次记录睡眠日志的时间戳
|
||||
self.sleep_log_interval = 35 # 睡眠日志记录间隔(秒)
|
||||
|
||||
# --- 统一睡眠状态管理 ---
|
||||
self._current_state: SleepState = SleepState.AWAKE # 当前睡眠状态
|
||||
self._sleep_buffer_end_time: Optional[datetime] = None # 睡眠缓冲结束时间,用于状态转换
|
||||
self._total_delayed_minutes_today: float = 0.0 # 今天总共延迟入睡的分钟数
|
||||
self._last_sleep_check_date: Optional[date] = None # 上次检查睡眠状态的日期
|
||||
self._last_fully_slept_log_time: float = 0 # 上次完全进入睡眠状态的时间戳
|
||||
self._re_sleep_attempt_time: Optional[datetime] = None # 被吵醒后,尝试重新入睡的时间点
|
||||
|
||||
# 从本地存储加载上一次的睡眠状态
|
||||
self._load_sleep_state()
|
||||
|
||||
def get_current_sleep_state(self) -> SleepState:
|
||||
"""获取当前的睡眠状态。"""
|
||||
return self._current_state
|
||||
|
||||
def is_sleeping(self) -> bool:
|
||||
"""判断当前是否处于正在睡觉的状态。"""
|
||||
return self._current_state == SleepState.SLEEPING
|
||||
|
||||
async def update_sleep_state(self, wakeup_manager: Optional["WakeUpManager"] = None):
|
||||
"""
|
||||
更新睡眠状态的核心方法,实现状态机的主要逻辑。
|
||||
该方法会被周期性调用,以检查并更新当前的睡眠状态。
|
||||
|
||||
Args:
|
||||
wakeup_manager (Optional["WakeUpManager"]): 唤醒管理器,用于获取睡眠压力等上下文信息。
|
||||
"""
|
||||
# 如果全局禁用了睡眠系统,则强制设置为清醒状态并返回
|
||||
if not global_config.sleep_system.enable:
|
||||
if self._current_state != SleepState.AWAKE:
|
||||
logger.debug("睡眠系统禁用,强制设为 AWAKE")
|
||||
self._current_state = SleepState.AWAKE
|
||||
return
|
||||
|
||||
now = datetime.now()
|
||||
today = now.date()
|
||||
|
||||
# 跨天处理:如果日期变化,重置每日相关的睡眠状态
|
||||
if self._last_sleep_check_date != today:
|
||||
logger.info(f"新的一天 ({today}),重置睡眠状态。")
|
||||
self._total_delayed_minutes_today = 0
|
||||
self._current_state = SleepState.AWAKE
|
||||
self._sleep_buffer_end_time = None
|
||||
self._last_sleep_check_date = today
|
||||
self._save_sleep_state()
|
||||
|
||||
# 检查当前是否处于理论上的睡眠时间段
|
||||
is_in_theoretical_sleep, activity = self.time_checker.is_in_theoretical_sleep_time(now.time())
|
||||
|
||||
# --- 状态机核心处理逻辑 ---
|
||||
if self._current_state == SleepState.AWAKE:
|
||||
if is_in_theoretical_sleep:
|
||||
self._handle_awake_to_sleep(now, activity, wakeup_manager)
|
||||
|
||||
elif self._current_state == SleepState.PREPARING_SLEEP:
|
||||
self._handle_preparing_sleep(now, is_in_theoretical_sleep, wakeup_manager)
|
||||
|
||||
elif self._current_state == SleepState.SLEEPING:
|
||||
self._handle_sleeping(now, is_in_theoretical_sleep, activity, wakeup_manager)
|
||||
|
||||
elif self._current_state == SleepState.INSOMNIA:
|
||||
self._handle_insomnia(now, is_in_theoretical_sleep)
|
||||
|
||||
elif self._current_state == SleepState.WOKEN_UP:
|
||||
self._handle_woken_up(now, is_in_theoretical_sleep, wakeup_manager)
|
||||
|
||||
def _handle_awake_to_sleep(self, now: datetime, activity: Optional[str], wakeup_manager: Optional["WakeUpManager"]):
|
||||
"""处理从“清醒”到“准备入睡”的状态转换。"""
|
||||
if activity:
|
||||
logger.info(f"进入理论休眠时间 '{activity}',开始进行睡眠决策...")
|
||||
else:
|
||||
logger.info("进入理论休眠时间,开始进行睡眠决策...")
|
||||
|
||||
if global_config.sleep_system.enable_flexible_sleep:
|
||||
# --- 新的弹性睡眠逻辑 ---
|
||||
if wakeup_manager:
|
||||
sleep_pressure = wakeup_manager.context.sleep_pressure
|
||||
pressure_threshold = global_config.sleep_system.flexible_sleep_pressure_threshold
|
||||
max_delay_minutes = global_config.sleep_system.max_sleep_delay_minutes
|
||||
|
||||
buffer_seconds = 0
|
||||
# 如果睡眠压力低于阈值,则计算延迟时间
|
||||
if sleep_pressure <= pressure_threshold:
|
||||
# 压力差,归一化到 (0, 1]
|
||||
pressure_diff = (pressure_threshold - sleep_pressure) / pressure_threshold
|
||||
# 延迟分钟数,压力越低,延迟越长
|
||||
delay_minutes = int(pressure_diff * max_delay_minutes)
|
||||
|
||||
# 确保总延迟不超过当日最大值
|
||||
remaining_delay = max_delay_minutes - self._total_delayed_minutes_today
|
||||
delay_minutes = min(delay_minutes, remaining_delay)
|
||||
|
||||
if delay_minutes > 0:
|
||||
# 增加一些随机性
|
||||
buffer_seconds = random.randint(int(delay_minutes * 0.8 * 60), int(delay_minutes * 1.2 * 60))
|
||||
self._total_delayed_minutes_today += buffer_seconds / 60.0
|
||||
logger.info(f"睡眠压力 ({sleep_pressure:.1f}) 较低,延迟 {buffer_seconds / 60:.1f} 分钟入睡。")
|
||||
else:
|
||||
# 延迟额度已用完,设置一个较短的准备时间
|
||||
buffer_seconds = random.randint(1 * 60, 2 * 60)
|
||||
logger.info("今日延迟入睡额度已用完,进入短暂准备后入睡。")
|
||||
else:
|
||||
# 睡眠压力较高,设置一个较短的准备时间
|
||||
buffer_seconds = random.randint(1 * 60, 2 * 60)
|
||||
logger.info(f"睡眠压力 ({sleep_pressure:.1f}) 较高,将在短暂准备后入睡。")
|
||||
|
||||
# 发送睡前通知
|
||||
if global_config.sleep_system.enable_pre_sleep_notification:
|
||||
asyncio.create_task(NotificationSender.send_goodnight_notification(wakeup_manager.context))
|
||||
|
||||
self._sleep_buffer_end_time = now + timedelta(seconds=buffer_seconds)
|
||||
self._current_state = SleepState.PREPARING_SLEEP
|
||||
logger.info(f"进入准备入睡状态,将在 {buffer_seconds / 60:.1f} 分钟内入睡。")
|
||||
self._save_sleep_state()
|
||||
else:
|
||||
# 无法获取 wakeup_manager,退回旧逻辑
|
||||
buffer_seconds = random.randint(1 * 60, 3 * 60)
|
||||
self._sleep_buffer_end_time = now + timedelta(seconds=buffer_seconds)
|
||||
self._current_state = SleepState.PREPARING_SLEEP
|
||||
logger.warning("无法获取 WakeUpManager,弹性睡眠采用默认1-3分钟延迟。")
|
||||
self._save_sleep_state()
|
||||
else:
|
||||
# 非弹性睡眠模式
|
||||
if wakeup_manager and global_config.sleep_system.enable_pre_sleep_notification:
|
||||
asyncio.create_task(NotificationSender.send_goodnight_notification(wakeup_manager.context))
|
||||
self._current_state = SleepState.SLEEPING
|
||||
|
||||
|
||||
def _handle_preparing_sleep(self, now: datetime, is_in_theoretical_sleep: bool, wakeup_manager: Optional["WakeUpManager"]):
|
||||
"""处理“准备入睡”状态下的逻辑。"""
|
||||
# 如果在准备期间离开了理论睡眠时间,则取消入睡
|
||||
if not is_in_theoretical_sleep:
|
||||
logger.info("准备入睡期间离开理论休眠时间,取消入睡,恢复清醒。")
|
||||
self._current_state = SleepState.AWAKE
|
||||
self._sleep_buffer_end_time = None
|
||||
self._save_sleep_state()
|
||||
# 如果缓冲时间结束,则正式进入睡眠状态
|
||||
elif self._sleep_buffer_end_time and now >= self._sleep_buffer_end_time:
|
||||
logger.info("睡眠缓冲期结束,正式进入休眠状态。")
|
||||
self._current_state = SleepState.SLEEPING
|
||||
self._last_fully_slept_log_time = now.timestamp()
|
||||
|
||||
# 设置一个随机的延迟,用于触发“睡后失眠”检查
|
||||
delay_minutes_range = global_config.sleep_system.insomnia_trigger_delay_minutes
|
||||
delay_minutes = random.randint(delay_minutes_range[0], delay_minutes_range[1])
|
||||
self._sleep_buffer_end_time = now + timedelta(minutes=delay_minutes)
|
||||
logger.info(f"已设置睡后失眠检查,将在 {delay_minutes} 分钟后触发。")
|
||||
|
||||
self._save_sleep_state()
|
||||
|
||||
def _handle_sleeping(self, now: datetime, is_in_theoretical_sleep: bool, activity: Optional[str], wakeup_manager: Optional["WakeUpManager"]):
|
||||
"""处理“正在睡觉”状态下的逻辑。"""
|
||||
# 如果理论睡眠时间结束,则自然醒来
|
||||
if not is_in_theoretical_sleep:
|
||||
logger.info("理论休眠时间结束,自然醒来。")
|
||||
self._current_state = SleepState.AWAKE
|
||||
self._save_sleep_state()
|
||||
# 检查是否到了触发“睡后失眠”的时间点
|
||||
elif self._sleep_buffer_end_time and now >= self._sleep_buffer_end_time:
|
||||
if wakeup_manager:
|
||||
sleep_pressure = wakeup_manager.context.sleep_pressure
|
||||
pressure_threshold = global_config.sleep_system.flexible_sleep_pressure_threshold
|
||||
# 检查是否触发失眠
|
||||
insomnia_reason = None
|
||||
if sleep_pressure < pressure_threshold:
|
||||
insomnia_reason = "low_pressure"
|
||||
logger.info(f"睡眠压力 ({sleep_pressure:.1f}) 低于阈值 ({pressure_threshold}),触发睡后失眠。")
|
||||
elif random.random() < getattr(global_config.sleep_system, "random_insomnia_chance", 0.1):
|
||||
insomnia_reason = "random"
|
||||
logger.info("随机触发失眠。")
|
||||
|
||||
if insomnia_reason:
|
||||
self._current_state = SleepState.INSOMNIA
|
||||
|
||||
# 设置失眠的持续时间
|
||||
duration_minutes_range = global_config.sleep_system.insomnia_duration_minutes
|
||||
duration_minutes = random.randint(*duration_minutes_range)
|
||||
self._sleep_buffer_end_time = now + timedelta(minutes=duration_minutes)
|
||||
|
||||
# 发送失眠通知
|
||||
asyncio.create_task(NotificationSender.send_insomnia_notification(wakeup_manager.context, insomnia_reason))
|
||||
logger.info(f"进入失眠状态 (原因: {insomnia_reason}),将持续 {duration_minutes} 分钟。")
|
||||
else:
|
||||
# 睡眠压力正常,不触发失眠,清除检查时间点
|
||||
logger.info(f"睡眠压力 ({sleep_pressure:.1f}) 正常,未触发睡后失眠。")
|
||||
self._sleep_buffer_end_time = None
|
||||
self._save_sleep_state()
|
||||
else:
|
||||
# 定期记录睡眠日志
|
||||
current_timestamp = now.timestamp()
|
||||
if current_timestamp - self.last_sleep_log_time > self.sleep_log_interval and activity:
|
||||
logger.info(f"当前处于休眠活动 '{activity}' 中。")
|
||||
self.last_sleep_log_time = current_timestamp
|
||||
|
||||
def _handle_insomnia(self, now: datetime, is_in_theoretical_sleep: bool):
|
||||
"""处理“失眠”状态下的逻辑。"""
|
||||
# 如果离开理论睡眠时间,则失眠结束
|
||||
if not is_in_theoretical_sleep:
|
||||
logger.info("已离开理论休眠时间,失眠结束,恢复清醒。")
|
||||
self._current_state = SleepState.AWAKE
|
||||
self._sleep_buffer_end_time = None
|
||||
self._save_sleep_state()
|
||||
# 如果失眠持续时间已过,则恢复睡眠
|
||||
elif self._sleep_buffer_end_time and now >= self._sleep_buffer_end_time:
|
||||
logger.info("失眠状态持续时间已过,恢复睡眠。")
|
||||
self._current_state = SleepState.SLEEPING
|
||||
self._sleep_buffer_end_time = None
|
||||
self._save_sleep_state()
|
||||
|
||||
def _handle_woken_up(self, now: datetime, is_in_theoretical_sleep: bool, wakeup_manager: Optional["WakeUpManager"]):
|
||||
"""处理“被吵醒”状态下的逻辑。"""
|
||||
# 如果理论睡眠时间结束,则状态自动结束
|
||||
if not is_in_theoretical_sleep:
|
||||
logger.info("理论休眠时间结束,被吵醒的状态自动结束。")
|
||||
self._current_state = SleepState.AWAKE
|
||||
self._re_sleep_attempt_time = None
|
||||
self._save_sleep_state()
|
||||
# 到了尝试重新入睡的时间点
|
||||
elif self._re_sleep_attempt_time and now >= self._re_sleep_attempt_time:
|
||||
logger.info("被吵醒后经过一段时间,尝试重新入睡...")
|
||||
if wakeup_manager:
|
||||
sleep_pressure = wakeup_manager.context.sleep_pressure
|
||||
pressure_threshold = global_config.sleep_system.flexible_sleep_pressure_threshold
|
||||
|
||||
# 如果睡眠压力足够,则尝试重新入睡
|
||||
if sleep_pressure >= pressure_threshold:
|
||||
logger.info("睡眠压力足够,从被吵醒状态转换到准备入睡。")
|
||||
buffer_seconds = random.randint(3 * 60, 8 * 60)
|
||||
self._sleep_buffer_end_time = now + timedelta(seconds=buffer_seconds)
|
||||
self._current_state = SleepState.PREPARING_SLEEP
|
||||
self._re_sleep_attempt_time = None
|
||||
else:
|
||||
# 睡眠压力不足,延迟一段时间后再次尝试
|
||||
delay_minutes = 15
|
||||
self._re_sleep_attempt_time = now + timedelta(minutes=delay_minutes)
|
||||
logger.info(
|
||||
f"睡眠压力({sleep_pressure:.1f})仍然较低,暂时保持清醒,在 {delay_minutes} 分钟后再次尝试。"
|
||||
)
|
||||
self._save_sleep_state()
|
||||
|
||||
def reset_sleep_state_after_wakeup(self):
|
||||
"""
|
||||
当角色被用户消息等外部因素唤醒时调用此方法。
|
||||
将状态强制转换为 WOKEN_UP,并设置一个延迟,之后会尝试重新入睡。
|
||||
"""
|
||||
if self._current_state in [SleepState.PREPARING_SLEEP, SleepState.SLEEPING, SleepState.INSOMNIA]:
|
||||
logger.info("被唤醒,进入 WOKEN_UP 状态!")
|
||||
self._current_state = SleepState.WOKEN_UP
|
||||
self._sleep_buffer_end_time = None
|
||||
re_sleep_delay_minutes = getattr(global_config.sleep_system, "re_sleep_delay_minutes", 10)
|
||||
self._re_sleep_attempt_time = datetime.now() + timedelta(minutes=re_sleep_delay_minutes)
|
||||
logger.info(f"将在 {re_sleep_delay_minutes} 分钟后尝试重新入睡。")
|
||||
self._save_sleep_state()
|
||||
|
||||
def _save_sleep_state(self):
|
||||
"""将当前所有睡眠相关的状态打包并保存到本地存储。"""
|
||||
state_data = {
|
||||
"_current_state": self._current_state,
|
||||
"_sleep_buffer_end_time": self._sleep_buffer_end_time,
|
||||
"_total_delayed_minutes_today": self._total_delayed_minutes_today,
|
||||
"_last_sleep_check_date": self._last_sleep_check_date,
|
||||
"_re_sleep_attempt_time": self._re_sleep_attempt_time,
|
||||
}
|
||||
SleepStateSerializer.save(state_data)
|
||||
|
||||
def _load_sleep_state(self):
|
||||
"""从本地存储加载并恢复所有睡眠相关的状态。"""
|
||||
state_data = SleepStateSerializer.load()
|
||||
self._current_state = state_data["_current_state"]
|
||||
self._sleep_buffer_end_time = state_data["_sleep_buffer_end_time"]
|
||||
self._total_delayed_minutes_today = state_data["_total_delayed_minutes_today"]
|
||||
self._last_sleep_check_date = state_data["_last_sleep_check_date"]
|
||||
self._re_sleep_attempt_time = state_data["_re_sleep_attempt_time"]
|
||||
110
src/chat/chat_loop/sleep_manager/sleep_state.py
Normal file
110
src/chat/chat_loop/sleep_manager/sleep_state.py
Normal file
@@ -0,0 +1,110 @@
|
||||
from enum import Enum, auto
|
||||
from datetime import datetime
|
||||
from src.common.logger import get_logger
|
||||
from src.manager.local_store_manager import local_storage
|
||||
|
||||
logger = get_logger("sleep_state")
|
||||
|
||||
|
||||
class SleepState(Enum):
|
||||
"""
|
||||
定义了角色可能处于的几种睡眠状态。
|
||||
这是一个状态机,用于管理角色的睡眠周期。
|
||||
"""
|
||||
|
||||
AWAKE = auto() # 清醒状态
|
||||
INSOMNIA = auto() # 失眠状态
|
||||
PREPARING_SLEEP = auto() # 准备入睡状态,一个短暂的过渡期
|
||||
SLEEPING = auto() # 正在睡觉状态
|
||||
WOKEN_UP = auto() # 被吵醒状态
|
||||
|
||||
|
||||
class SleepStateSerializer:
|
||||
"""
|
||||
睡眠状态序列化器。
|
||||
负责将内存中的睡眠状态对象持久化到本地存储(如JSON文件),
|
||||
以及在程序启动时从本地存储中恢复状态。
|
||||
这样可以确保即使程序重启,角色的睡眠状态也能得以保留。
|
||||
"""
|
||||
@staticmethod
|
||||
def save(state_data: dict):
|
||||
"""
|
||||
将当前的睡眠状态数据保存到本地存储。
|
||||
|
||||
Args:
|
||||
state_data (dict): 包含睡眠状态信息的字典。
|
||||
datetime对象会被转换为时间戳,Enum成员会被转换为其名称字符串。
|
||||
"""
|
||||
try:
|
||||
# 准备要序列化的数据字典
|
||||
state = {
|
||||
# 保存当前状态的枚举名称
|
||||
"current_state": state_data["_current_state"].name,
|
||||
# 将datetime对象转换为Unix时间戳以便序列化
|
||||
"sleep_buffer_end_time_ts": state_data["_sleep_buffer_end_time"].timestamp()
|
||||
if state_data["_sleep_buffer_end_time"]
|
||||
else None,
|
||||
"total_delayed_minutes_today": state_data["_total_delayed_minutes_today"],
|
||||
# 将date对象转换为ISO格式的字符串
|
||||
"last_sleep_check_date_str": state_data["_last_sleep_check_date"].isoformat()
|
||||
if state_data["_last_sleep_check_date"]
|
||||
else None,
|
||||
"re_sleep_attempt_time_ts": state_data["_re_sleep_attempt_time"].timestamp()
|
||||
if state_data["_re_sleep_attempt_time"]
|
||||
else None,
|
||||
}
|
||||
# 写入本地存储
|
||||
local_storage["schedule_sleep_state"] = state
|
||||
logger.debug(f"已保存睡眠状态: {state}")
|
||||
except Exception as e:
|
||||
logger.error(f"保存睡眠状态失败: {e}")
|
||||
|
||||
@staticmethod
|
||||
def load() -> dict:
|
||||
"""
|
||||
从本地存储加载并解析睡眠状态。
|
||||
|
||||
Returns:
|
||||
dict: 包含恢复后睡眠状态信息的字典。
|
||||
如果加载失败或没有找到数据,则返回一个默认的清醒状态。
|
||||
"""
|
||||
# 定义一个默认的状态,以防加载失败
|
||||
state_data = {
|
||||
"_current_state": SleepState.AWAKE,
|
||||
"_sleep_buffer_end_time": None,
|
||||
"_total_delayed_minutes_today": 0,
|
||||
"_last_sleep_check_date": None,
|
||||
"_re_sleep_attempt_time": None,
|
||||
}
|
||||
try:
|
||||
# 从本地存储读取数据
|
||||
state = local_storage["schedule_sleep_state"]
|
||||
if state and isinstance(state, dict):
|
||||
# 恢复当前状态枚举
|
||||
state_name = state.get("current_state")
|
||||
if state_name and hasattr(SleepState, state_name):
|
||||
state_data["_current_state"] = SleepState[state_name]
|
||||
|
||||
# 从时间戳恢复datetime对象
|
||||
end_time_ts = state.get("sleep_buffer_end_time_ts")
|
||||
if end_time_ts:
|
||||
state_data["_sleep_buffer_end_time"] = datetime.fromtimestamp(end_time_ts)
|
||||
|
||||
# 恢复重新入睡尝试时间
|
||||
re_sleep_ts = state.get("re_sleep_attempt_time_ts")
|
||||
if re_sleep_ts:
|
||||
state_data["_re_sleep_attempt_time"] = datetime.fromtimestamp(re_sleep_ts)
|
||||
|
||||
# 恢复今日延迟睡眠总分钟数
|
||||
state_data["_total_delayed_minutes_today"] = state.get("total_delayed_minutes_today", 0)
|
||||
|
||||
# 从ISO格式字符串恢复date对象
|
||||
date_str = state.get("last_sleep_check_date_str")
|
||||
if date_str:
|
||||
state_data["_last_sleep_check_date"] = datetime.fromisoformat(date_str).date()
|
||||
|
||||
logger.info(f"成功从本地存储加载睡眠状态: {state}")
|
||||
except Exception as e:
|
||||
# 如果加载过程中出现任何问题,记录警告并返回默认状态
|
||||
logger.warning(f"加载睡眠状态失败,将使用默认值: {e}")
|
||||
return state_data
|
||||
108
src/chat/chat_loop/sleep_manager/time_checker.py
Normal file
108
src/chat/chat_loop/sleep_manager/time_checker.py
Normal file
@@ -0,0 +1,108 @@
|
||||
from datetime import datetime, time, timedelta
|
||||
from typing import Optional, List, Dict, Any
|
||||
import random
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.schedule.schedule_manager import schedule_manager
|
||||
|
||||
logger = get_logger("time_checker")
|
||||
|
||||
|
||||
class TimeChecker:
|
||||
def __init__(self):
|
||||
# 缓存当天的偏移量,确保一天内使用相同的偏移量
|
||||
self._daily_sleep_offset: int = 0
|
||||
self._daily_wake_offset: int = 0
|
||||
self._offset_date = None
|
||||
|
||||
def _get_daily_offsets(self):
|
||||
"""获取当天的睡眠和起床时间偏移量,每天生成一次"""
|
||||
today = datetime.now().date()
|
||||
|
||||
# 如果是新的一天,重新生成偏移量
|
||||
if self._offset_date != today:
|
||||
sleep_offset_range = global_config.sleep_system.sleep_time_offset_minutes
|
||||
wake_offset_range = global_config.sleep_system.wake_up_time_offset_minutes
|
||||
|
||||
# 生成 ±offset_range 范围内的随机偏移量
|
||||
self._daily_sleep_offset = random.randint(-sleep_offset_range, sleep_offset_range)
|
||||
self._daily_wake_offset = random.randint(-wake_offset_range, wake_offset_range)
|
||||
self._offset_date = today
|
||||
|
||||
logger.debug(f"生成新的每日偏移量 - 睡觉时间偏移: {self._daily_sleep_offset}分钟, 起床时间偏移: {self._daily_wake_offset}分钟")
|
||||
|
||||
return self._daily_sleep_offset, self._daily_wake_offset
|
||||
|
||||
def get_today_schedule(self) -> Optional[List[Dict[str, Any]]]:
|
||||
"""从全局 ScheduleManager 获取今天的日程安排。"""
|
||||
return schedule_manager.today_schedule
|
||||
|
||||
def is_in_theoretical_sleep_time(self, now_time: time) -> tuple[bool, Optional[str]]:
|
||||
if global_config.sleep_system.sleep_by_schedule:
|
||||
if self.get_today_schedule():
|
||||
return self._is_in_schedule_sleep_time(now_time)
|
||||
else:
|
||||
return self._is_in_sleep_time(now_time)
|
||||
else:
|
||||
return self._is_in_sleep_time(now_time)
|
||||
|
||||
def _is_in_schedule_sleep_time(self, now_time: time) -> tuple[bool, Optional[str]]:
|
||||
"""检查当前时间是否落在日程表的任何一个睡眠活动中"""
|
||||
sleep_keywords = ["休眠", "睡觉", "梦乡"]
|
||||
today_schedule = self.get_today_schedule()
|
||||
if today_schedule:
|
||||
for event in today_schedule:
|
||||
try:
|
||||
activity = event.get("activity", "").strip()
|
||||
time_range = event.get("time_range")
|
||||
|
||||
if not activity or not time_range:
|
||||
continue
|
||||
|
||||
if any(keyword in activity for keyword in sleep_keywords):
|
||||
start_str, end_str = time_range.split("-")
|
||||
start_time = datetime.strptime(start_str.strip(), "%H:%M").time()
|
||||
end_time = datetime.strptime(end_str.strip(), "%H:%M").time()
|
||||
|
||||
if start_time <= end_time: # 同一天
|
||||
if start_time <= now_time < end_time:
|
||||
return True, activity
|
||||
else: # 跨天
|
||||
if now_time >= start_time or now_time < end_time:
|
||||
return True, activity
|
||||
except (ValueError, KeyError, AttributeError) as e:
|
||||
logger.warning(f"解析日程事件时出错: {event}, 错误: {e}")
|
||||
continue
|
||||
return False, None
|
||||
|
||||
def _is_in_sleep_time(self, now_time: time) -> tuple[bool, Optional[str]]:
|
||||
"""检查当前时间是否在固定的睡眠时间内(应用偏移量)"""
|
||||
try:
|
||||
start_time_str = global_config.sleep_system.fixed_sleep_time
|
||||
end_time_str = global_config.sleep_system.fixed_wake_up_time
|
||||
|
||||
# 获取当天的偏移量
|
||||
sleep_offset, wake_offset = self._get_daily_offsets()
|
||||
|
||||
# 解析基础时间
|
||||
base_start_time = datetime.strptime(start_time_str, "%H:%M")
|
||||
base_end_time = datetime.strptime(end_time_str, "%H:%M")
|
||||
|
||||
# 应用偏移量
|
||||
actual_start_time = (base_start_time + timedelta(minutes=sleep_offset)).time()
|
||||
actual_end_time = (base_end_time + timedelta(minutes=wake_offset)).time()
|
||||
|
||||
logger.debug(f"固定睡眠时间检查 - 基础时间: {start_time_str}-{end_time_str}, "
|
||||
f"偏移后时间: {actual_start_time.strftime('%H:%M')}-{actual_end_time.strftime('%H:%M')}, "
|
||||
f"当前时间: {now_time.strftime('%H:%M')}")
|
||||
|
||||
if actual_start_time <= actual_end_time:
|
||||
if actual_start_time <= now_time < actual_end_time:
|
||||
return True, f"固定睡眠时间(偏移后: {actual_start_time.strftime('%H:%M')}-{actual_end_time.strftime('%H:%M')})"
|
||||
else:
|
||||
if now_time >= actual_start_time or now_time < actual_end_time:
|
||||
return True, f"固定睡眠时间(偏移后: {actual_start_time.strftime('%H:%M')}-{actual_end_time.strftime('%H:%M')})"
|
||||
except ValueError as e:
|
||||
logger.error(f"固定的睡眠时间格式不正确,请使用 HH:MM 格式: {e}")
|
||||
return False, None
|
||||
@@ -4,7 +4,7 @@ from typing import Optional
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.manager.local_store_manager import local_storage
|
||||
from .hfc_context import HfcContext
|
||||
from ..hfc_context import HfcContext
|
||||
|
||||
logger = get_logger("wakeup")
|
||||
|
||||
@@ -138,10 +138,13 @@ class WakeUpManager:
|
||||
return False
|
||||
|
||||
# 只有在休眠且非失眠状态下才累积唤醒度
|
||||
from src.schedule.schedule_manager import schedule_manager
|
||||
from src.schedule.sleep_manager import SleepState
|
||||
from .sleep_state import SleepState
|
||||
|
||||
current_sleep_state = schedule_manager.get_current_sleep_state()
|
||||
sleep_manager = self.context.sleep_manager
|
||||
if not sleep_manager:
|
||||
return False
|
||||
|
||||
current_sleep_state = sleep_manager.get_current_sleep_state()
|
||||
if current_sleep_state != SleepState.SLEEPING:
|
||||
return False
|
||||
|
||||
@@ -191,10 +194,9 @@ class WakeUpManager:
|
||||
|
||||
mood_manager.set_angry_from_wakeup(self.context.stream_id)
|
||||
|
||||
# 通知日程管理器重置睡眠状态
|
||||
from src.schedule.schedule_manager import schedule_manager
|
||||
|
||||
schedule_manager.reset_sleep_state_after_wakeup()
|
||||
# 通知SleepManager重置睡眠状态
|
||||
if self.context.sleep_manager:
|
||||
self.context.sleep_manager.reset_sleep_state_after_wakeup()
|
||||
|
||||
logger.info(f"{self.context.log_prefix} 唤醒度达到阈值({self.wakeup_threshold}),被吵醒进入愤怒状态!")
|
||||
|
||||
@@ -724,7 +724,7 @@ class EmojiManager:
|
||||
if not emoji.is_deleted and emoji.hash == emoji_hash:
|
||||
return emoji
|
||||
return None # 如果循环结束还没找到,则返回 None
|
||||
|
||||
|
||||
async def get_emoji_tag_by_hash(self, emoji_hash: str) -> Optional[str]:
|
||||
"""根据哈希值获取已注册表情包的描述
|
||||
|
||||
@@ -755,7 +755,7 @@ class EmojiManager:
|
||||
except Exception as e:
|
||||
logger.error(f"获取表情包描述失败 (Hash: {emoji_hash}): {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
async def get_emoji_description_by_hash(self, emoji_hash: str) -> Optional[str]:
|
||||
"""根据哈希值获取已注册表情包的描述
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ from src.common.database.sqlalchemy_models import Expression
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config, global_config
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive, build_anonymous_messages
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from src.config.config import global_config, model_config
|
||||
from src.common.logger import get_logger
|
||||
from sqlalchemy import select
|
||||
from src.common.database.sqlalchemy_models import Expression
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
|
||||
logger = get_logger("expression_selector")
|
||||
|
||||
144
src/chat/frequency_analyzer/analyzer.py
Normal file
144
src/chat/frequency_analyzer/analyzer.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""
|
||||
Chat Frequency Analyzer
|
||||
=======================
|
||||
|
||||
本模块负责分析用户的聊天时间戳,以识别出他们最活跃的聊天时段(高峰时段)。
|
||||
|
||||
核心功能:
|
||||
- 使用滑动窗口算法来检测时间戳集中的区域。
|
||||
- 提供接口查询指定用户当前是否处于其聊天高峰时段内。
|
||||
- 结果会被缓存以提高性能。
|
||||
|
||||
可配置参数:
|
||||
- ANALYSIS_WINDOW_HOURS: 用于分析的时间窗口大小(小时)。
|
||||
- MIN_CHATS_FOR_PEAK: 在一个窗口内需要多少次聊天才能被认为是高峰时段。
|
||||
- MIN_GAP_BETWEEN_PEAKS_HOURS: 两个独立高峰时段之间的最小间隔(小时)。
|
||||
"""
|
||||
import time as time_module
|
||||
from datetime import datetime, timedelta, time
|
||||
from typing import List, Tuple, Optional
|
||||
|
||||
from .tracker import chat_frequency_tracker
|
||||
|
||||
# --- 可配置参数 ---
|
||||
# 用于分析的时间窗口大小(小时)
|
||||
ANALYSIS_WINDOW_HOURS = 2
|
||||
# 触发高峰时段所需的最小聊天次数
|
||||
MIN_CHATS_FOR_PEAK = 4
|
||||
# 两个独立高峰时段之间的最小间隔(小时)
|
||||
MIN_GAP_BETWEEN_PEAKS_HOURS = 1
|
||||
|
||||
|
||||
class ChatFrequencyAnalyzer:
|
||||
"""
|
||||
分析聊天时间戳,以识别用户的高频聊天时段。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# 缓存分析结果,避免重复计算
|
||||
# 格式: { "chat_id": (timestamp_of_analysis, [peak_windows]) }
|
||||
self._analysis_cache: dict[str, tuple[float, list[tuple[time, time]]]] = {}
|
||||
self._cache_ttl_seconds = 60 * 30 # 缓存30分钟
|
||||
|
||||
def _find_peak_windows(self, timestamps: List[float]) -> List[Tuple[datetime, datetime]]:
|
||||
"""
|
||||
使用滑动窗口算法来识别时间戳列表中的高峰时段。
|
||||
|
||||
Args:
|
||||
timestamps (List[float]): 按时间排序的聊天时间戳。
|
||||
|
||||
Returns:
|
||||
List[Tuple[datetime, datetime]]: 识别出的高峰时段列表,每个元组代表一个时间窗口的开始和结束。
|
||||
"""
|
||||
if len(timestamps) < MIN_CHATS_FOR_PEAK:
|
||||
return []
|
||||
|
||||
# 将时间戳转换为 datetime 对象
|
||||
datetimes = [datetime.fromtimestamp(ts) for ts in timestamps]
|
||||
datetimes.sort()
|
||||
|
||||
peak_windows: List[Tuple[datetime, datetime]] = []
|
||||
window_start_idx = 0
|
||||
|
||||
for i in range(len(datetimes)):
|
||||
# 移动窗口的起始点
|
||||
while datetimes[i] - datetimes[window_start_idx] > timedelta(hours=ANALYSIS_WINDOW_HOURS):
|
||||
window_start_idx += 1
|
||||
|
||||
# 检查当前窗口是否满足高峰条件
|
||||
if i - window_start_idx + 1 >= MIN_CHATS_FOR_PEAK:
|
||||
current_window_start = datetimes[window_start_idx]
|
||||
current_window_end = datetimes[i]
|
||||
|
||||
# 合并重叠或相邻的高峰时段
|
||||
if peak_windows and current_window_start - peak_windows[-1][1] < timedelta(hours=MIN_GAP_BETWEEN_PEAKS_HOURS):
|
||||
# 扩展上一个窗口的结束时间
|
||||
peak_windows[-1] = (peak_windows[-1][0], current_window_end)
|
||||
else:
|
||||
peak_windows.append((current_window_start, current_window_end))
|
||||
|
||||
return peak_windows
|
||||
|
||||
def get_peak_chat_times(self, chat_id: str) -> List[Tuple[time, time]]:
|
||||
"""
|
||||
获取指定用户的高峰聊天时间段。
|
||||
|
||||
Args:
|
||||
chat_id (str): 聊天标识符。
|
||||
|
||||
Returns:
|
||||
List[Tuple[time, time]]: 高峰时段的列表,每个元组包含开始和结束时间 (time 对象)。
|
||||
"""
|
||||
# 检查缓存
|
||||
cached_timestamp, cached_windows = self._analysis_cache.get(chat_id, (0, []))
|
||||
if time_module.time() - cached_timestamp < self._cache_ttl_seconds:
|
||||
return cached_windows
|
||||
|
||||
timestamps = chat_frequency_tracker.get_timestamps_for_chat(chat_id)
|
||||
if not timestamps:
|
||||
return []
|
||||
|
||||
peak_datetime_windows = self._find_peak_windows(timestamps)
|
||||
|
||||
# 将 datetime 窗口转换为 time 窗口,并进行归一化处理
|
||||
peak_time_windows = []
|
||||
for start_dt, end_dt in peak_datetime_windows:
|
||||
# TODO:这里可以添加更复杂的逻辑来处理跨天的平均时间
|
||||
# 为简化,我们直接使用窗口的起止时间
|
||||
peak_time_windows.append((start_dt.time(), end_dt.time()))
|
||||
|
||||
# 更新缓存
|
||||
self._analysis_cache[chat_id] = (time_module.time(), peak_time_windows)
|
||||
|
||||
return peak_time_windows
|
||||
|
||||
def is_in_peak_time(self, chat_id: str, now: Optional[datetime] = None) -> bool:
|
||||
"""
|
||||
检查当前时间是否处于用户的高峰聊天时段内。
|
||||
|
||||
Args:
|
||||
chat_id (str): 聊天标识符。
|
||||
now (Optional[datetime]): 要检查的时间,默认为当前时间。
|
||||
|
||||
Returns:
|
||||
bool: 如果处于高峰时段则返回 True,否则返回 False。
|
||||
"""
|
||||
if now is None:
|
||||
now = datetime.now()
|
||||
|
||||
now_time = now.time()
|
||||
peak_times = self.get_peak_chat_times(chat_id)
|
||||
|
||||
for start_time, end_time in peak_times:
|
||||
if start_time <= end_time: # 同一天
|
||||
if start_time <= now_time <= end_time:
|
||||
return True
|
||||
else: # 跨天
|
||||
if now_time >= start_time or now_time <= end_time:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
# 创建一个全局单例
|
||||
chat_frequency_analyzer = ChatFrequencyAnalyzer()
|
||||
77
src/chat/frequency_analyzer/tracker.py
Normal file
77
src/chat/frequency_analyzer/tracker.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import orjson
|
||||
import time
|
||||
from typing import Dict, List, Optional
|
||||
from pathlib import Path
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
# 数据存储路径
|
||||
DATA_DIR = Path("data/frequency_analyzer")
|
||||
DATA_DIR.mkdir(parents=True, exist_ok=True)
|
||||
TRACKER_FILE = DATA_DIR / "chat_timestamps.json"
|
||||
|
||||
logger = get_logger("ChatFrequencyTracker")
|
||||
|
||||
|
||||
class ChatFrequencyTracker:
|
||||
"""
|
||||
负责跟踪和存储用户聊天启动时间戳。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._timestamps: Dict[str, List[float]] = self._load_timestamps()
|
||||
|
||||
def _load_timestamps(self) -> Dict[str, List[float]]:
|
||||
"""从本地文件加载时间戳数据。"""
|
||||
if not TRACKER_FILE.exists():
|
||||
return {}
|
||||
try:
|
||||
with open(TRACKER_FILE, "rb") as f:
|
||||
data = orjson.loads(f.read())
|
||||
logger.info(f"成功从 {TRACKER_FILE} 加载了聊天时间戳数据。")
|
||||
return data
|
||||
except orjson.JSONDecodeError:
|
||||
logger.warning(f"无法解析 {TRACKER_FILE},将创建一个新的空数据文件。")
|
||||
return {}
|
||||
except Exception as e:
|
||||
logger.error(f"加载聊天时间戳数据时发生未知错误: {e}")
|
||||
return {}
|
||||
|
||||
def _save_timestamps(self):
|
||||
"""将当前的时间戳数据保存到本地文件。"""
|
||||
try:
|
||||
with open(TRACKER_FILE, "wb") as f:
|
||||
f.write(orjson.dumps(self._timestamps))
|
||||
except Exception as e:
|
||||
logger.error(f"保存聊天时间戳数据到 {TRACKER_FILE} 时失败: {e}")
|
||||
|
||||
def record_chat_start(self, chat_id: str):
|
||||
"""
|
||||
记录一次聊天会话的开始。
|
||||
|
||||
Args:
|
||||
chat_id (str): 唯一的聊天标识符 (例如,用户ID)。
|
||||
"""
|
||||
now = time.time()
|
||||
if chat_id not in self._timestamps:
|
||||
self._timestamps[chat_id] = []
|
||||
|
||||
self._timestamps[chat_id].append(now)
|
||||
logger.debug(f"为 chat_id '{chat_id}' 记录了新的聊天时间: {now}")
|
||||
self._save_timestamps()
|
||||
|
||||
def get_timestamps_for_chat(self, chat_id: str) -> Optional[List[float]]:
|
||||
"""
|
||||
获取指定聊天的所有时间戳记录。
|
||||
|
||||
Args:
|
||||
chat_id (str): 聊天标识符。
|
||||
|
||||
Returns:
|
||||
Optional[List[float]]: 时间戳列表,如果不存在则返回 None。
|
||||
"""
|
||||
return self._timestamps.get(chat_id)
|
||||
|
||||
|
||||
# 创建一个全局单例
|
||||
chat_frequency_tracker = ChatFrequencyTracker()
|
||||
119
src/chat/frequency_analyzer/trigger.py
Normal file
119
src/chat/frequency_analyzer/trigger.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""
|
||||
Frequency-Based Proactive Trigger
|
||||
=================================
|
||||
|
||||
本模块实现了一个周期性任务,用于根据用户的聊天频率来智能地触发主动思考。
|
||||
|
||||
核心功能:
|
||||
- 定期运行,检查所有已知的私聊用户。
|
||||
- 调用 ChatFrequencyAnalyzer 判断当前是否处于用户的高峰聊天时段。
|
||||
- 如果满足条件(高峰时段、角色清醒、聊天循环空闲),则触发一次主动思考。
|
||||
- 包含冷却机制,以避免在同一个高峰时段内重复打扰用户。
|
||||
|
||||
可配置参数:
|
||||
- TRIGGER_CHECK_INTERVAL_SECONDS: 触发器检查的周期(秒)。
|
||||
- COOLDOWN_HOURS: 在同一个高峰时段内触发一次后的冷却时间(小时)。
|
||||
"""
|
||||
import asyncio
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Dict, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.chat_loop.proactive.events import ProactiveTriggerEvent
|
||||
from src.chat.heart_flow.heartflow import heartflow
|
||||
from src.chat.chat_loop.sleep_manager.sleep_manager import SleepManager
|
||||
from .analyzer import chat_frequency_analyzer
|
||||
|
||||
logger = get_logger("FrequencyBasedTrigger")
|
||||
|
||||
# --- 可配置参数 ---
|
||||
# 触发器检查周期(秒)
|
||||
TRIGGER_CHECK_INTERVAL_SECONDS = 60 * 5 # 5分钟
|
||||
# 冷却时间(小时),确保在一个高峰时段只触发一次
|
||||
COOLDOWN_HOURS = 3
|
||||
|
||||
|
||||
class FrequencyBasedTrigger:
|
||||
"""
|
||||
一个周期性任务,根据聊天频率分析结果来触发主动思考。
|
||||
"""
|
||||
|
||||
def __init__(self, sleep_manager: SleepManager):
|
||||
self._sleep_manager = sleep_manager
|
||||
self._task: Optional[asyncio.Task] = None
|
||||
# 记录上次为用户触发的时间,用于冷却控制
|
||||
# 格式: { "chat_id": timestamp }
|
||||
self._last_triggered: Dict[str, float] = {}
|
||||
|
||||
async def _run_trigger_cycle(self):
|
||||
"""触发器的主要循环逻辑。"""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(TRIGGER_CHECK_INTERVAL_SECONDS)
|
||||
logger.debug("开始执行频率触发器检查...")
|
||||
|
||||
# 1. 检查角色是否清醒
|
||||
if self._sleep_manager.is_sleeping():
|
||||
logger.debug("角色正在睡眠,跳过本次频率触发检查。")
|
||||
continue
|
||||
|
||||
# 2. 获取所有已知的聊天ID
|
||||
# 【注意】这里我们假设所有 subheartflow 的 ID 就是 chat_id
|
||||
all_chat_ids = list(heartflow.subheartflows.keys())
|
||||
if not all_chat_ids:
|
||||
continue
|
||||
|
||||
now = datetime.now()
|
||||
|
||||
for chat_id in all_chat_ids:
|
||||
# 3. 检查是否处于冷却时间内
|
||||
last_triggered_time = self._last_triggered.get(chat_id, 0)
|
||||
if time.time() - last_triggered_time < COOLDOWN_HOURS * 3600:
|
||||
continue
|
||||
|
||||
# 4. 检查当前是否是该用户的高峰聊天时间
|
||||
if chat_frequency_analyzer.is_in_peak_time(chat_id, now):
|
||||
|
||||
sub_heartflow = await heartflow.get_or_create_subheartflow(chat_id)
|
||||
if not sub_heartflow:
|
||||
logger.warning(f"无法为 {chat_id} 获取或创建 sub_heartflow。")
|
||||
continue
|
||||
|
||||
# 5. 检查用户当前是否已有活跃的思考或回复任务
|
||||
cycle_detail = sub_heartflow.heart_fc_instance.context.current_cycle_detail
|
||||
if cycle_detail and not cycle_detail.end_time:
|
||||
logger.debug(f"用户 {chat_id} 的聊天循环正忙(仍在周期 {cycle_detail.cycle_id} 中),本次不触发。")
|
||||
continue
|
||||
|
||||
logger.info(f"检测到用户 {chat_id} 处于聊天高峰期,且聊天循环空闲,准备触发主动思考。")
|
||||
|
||||
# 6. 直接调用 proactive_thinker
|
||||
event = ProactiveTriggerEvent(
|
||||
source="frequency_analyzer",
|
||||
reason=f"User is in a high-frequency chat period."
|
||||
)
|
||||
await sub_heartflow.heart_fc_instance.proactive_thinker.think(event)
|
||||
|
||||
# 7. 更新触发时间,进入冷却
|
||||
self._last_triggered[chat_id] = time.time()
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("频率触发器任务被取消。")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"频率触发器循环发生未知错误: {e}", exc_info=True)
|
||||
# 发生错误后,等待更长时间再重试,避免刷屏
|
||||
await asyncio.sleep(TRIGGER_CHECK_INTERVAL_SECONDS * 2)
|
||||
|
||||
def start(self):
|
||||
"""启动触发器任务。"""
|
||||
if self._task is None or self._task.done():
|
||||
self._task = asyncio.create_task(self._run_trigger_cycle())
|
||||
logger.info("基于聊天频率的主动思考触发器已启动。")
|
||||
|
||||
def stop(self):
|
||||
"""停止触发器任务。"""
|
||||
if self._task and not self._task.done():
|
||||
self._task.cancel()
|
||||
logger.info("基于聊天频率的主动思考触发器已停止。")
|
||||
@@ -928,6 +928,7 @@ class EntorhinalCortex:
|
||||
"concept": concept,
|
||||
"memory_items": memory_items_json,
|
||||
"hash": memory_hash,
|
||||
"weight": 1.0, # 默认权重为1.0
|
||||
"created_time": created_time,
|
||||
"last_modified": last_modified,
|
||||
}
|
||||
@@ -1084,6 +1085,7 @@ class EntorhinalCortex:
|
||||
"concept": concept,
|
||||
"memory_items": memory_items_json,
|
||||
"hash": self.hippocampus.calculate_node_hash(concept, memory_items),
|
||||
"weight": 1.0, # 默认权重为1.0
|
||||
"created_time": data.get("created_time", current_time),
|
||||
"last_modified": data.get("last_modified", current_time),
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ from datetime import datetime
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
||||
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream
|
||||
from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.plugin_system.core import component_registry, event_manager, global_announcement_manager
|
||||
from src.plugin_system.base import BaseCommand, EventType
|
||||
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
|
||||
|
||||
@@ -85,6 +85,7 @@ class ChatStream:
|
||||
self.context: ChatMessageContext = None # type: ignore # 用于存储该聊天的上下文信息
|
||||
self.focus_energy = 1
|
||||
self.no_reply_consecutive = 0
|
||||
self.breaking_accumulated_interest = 0.0
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""转换为字典格式"""
|
||||
@@ -97,6 +98,7 @@ class ChatStream:
|
||||
"last_active_time": self.last_active_time,
|
||||
"energy_value": self.energy_value,
|
||||
"sleep_pressure": self.sleep_pressure,
|
||||
"breaking_accumulated_interest": self.breaking_accumulated_interest,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@@ -257,7 +259,7 @@ class ChatManager:
|
||||
"user_cardname": model_instance.user_cardname or "",
|
||||
}
|
||||
group_info_data = None
|
||||
if model_instance.group_id:
|
||||
if model_instance and getattr(model_instance, "group_id", None):
|
||||
group_info_data = {
|
||||
"platform": model_instance.group_platform,
|
||||
"group_id": model_instance.group_id,
|
||||
@@ -403,7 +405,7 @@ class ChatManager:
|
||||
"user_cardname": model_instance.user_cardname or "",
|
||||
}
|
||||
group_info_data = None
|
||||
if model_instance.group_id:
|
||||
if model_instance and getattr(model_instance, "group_id", None):
|
||||
group_info_data = {
|
||||
"platform": model_instance.group_platform,
|
||||
"group_id": model_instance.group_id,
|
||||
|
||||
@@ -120,7 +120,7 @@ class MessageRecv(Message):
|
||||
self.priority_mode = "interest"
|
||||
self.priority_info = None
|
||||
self.interest_value: float = 0.0
|
||||
|
||||
|
||||
self.key_words = []
|
||||
self.key_words_lite = []
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ class MessageStorage:
|
||||
if isinstance(keywords, list):
|
||||
return orjson.dumps(keywords).decode("utf-8")
|
||||
return "[]"
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _deserialize_keywords(keywords_str: str) -> list:
|
||||
"""将JSON字符串反序列化为关键词列表"""
|
||||
|
||||
@@ -161,10 +161,8 @@ class ActionModifier:
|
||||
|
||||
available_actions = list(self.action_manager.get_using_actions().keys())
|
||||
available_actions_text = "、".join(available_actions) if available_actions else "无"
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix} 当前可用动作: {available_actions_text}||移除: {removals_summary}"
|
||||
)
|
||||
|
||||
logger.info(f"{self.log_prefix} 当前可用动作: {available_actions_text}||移除: {removals_summary}")
|
||||
|
||||
def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: ChatMessageContext):
|
||||
type_mismatched_actions: List[Tuple[str, str]] = []
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
import orjson
|
||||
import time
|
||||
import traceback
|
||||
from typing import Dict, Any, Optional, Tuple, List
|
||||
import asyncio
|
||||
import math
|
||||
import random
|
||||
import json
|
||||
from typing import Dict, Any, Optional, Tuple, List, TYPE_CHECKING
|
||||
from rich.traceback import install
|
||||
from datetime import datetime
|
||||
from json_repair import repair_json
|
||||
@@ -9,7 +13,7 @@ from json_repair import repair_json
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
build_readable_actions,
|
||||
get_actions_by_timestamp_with_chat,
|
||||
@@ -19,12 +23,20 @@ from src.chat.utils.chat_message_builder import (
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.plugin_system.base.component_types import ActionInfo, ChatMode, ComponentType
|
||||
from src.plugin_system.base.component_types import (
|
||||
ActionInfo,
|
||||
ChatMode,
|
||||
ComponentType,
|
||||
ActionActivationType,
|
||||
)
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
from src.schedule.schedule_manager import schedule_manager
|
||||
from src.mood.mood_manager import mood_manager
|
||||
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
logger = get_logger("planner")
|
||||
|
||||
install(extra_lines=3)
|
||||
@@ -83,12 +95,34 @@ def init_prompt():
|
||||
## 长期记忆摘要
|
||||
{long_term_memory_block}
|
||||
|
||||
## 最近的聊天内容
|
||||
{chat_content_block}
|
||||
|
||||
## 最近的动作历史
|
||||
{actions_before_now_block}
|
||||
|
||||
## 任务
|
||||
基于以上所有信息,分析当前情况,决定是否需要主动做些什么。
|
||||
如果你认为不需要,就选择 'do_nothing'。
|
||||
基于以上所有信息(特别是最近的聊天内容),分析当前情况,决定是否适合主动开启一个**新的、但又与当前氛围相关**的话题。
|
||||
|
||||
## 可用动作
|
||||
{action_options_text}
|
||||
动作:proactive_reply
|
||||
动作描述:在当前对话的基础上,主动发起一个新的对话,分享一个有趣的想法、见闻或者对未来的计划。
|
||||
- 当你觉得可以说些什么来活跃气氛,并且内容与当前聊天氛围不冲突时
|
||||
- 当你有一些新的想法或计划想要分享,并且可以自然地衔接当前话题时
|
||||
{{
|
||||
"action": "proactive_reply",
|
||||
"reason": "决定主动发起对话的具体原因",
|
||||
"topic": "你想要发起对话的主题或内容(需要简洁)"
|
||||
}}
|
||||
|
||||
动作:do_nothing
|
||||
动作描述:保持沉默,不主动发起任何动作或对话。
|
||||
- 当你分析了所有信息后,觉得当前不是一个发起互动的好时机时
|
||||
- 当最近的聊天内容很连贯,你的插入会打断别人时
|
||||
{{
|
||||
"action": "do_nothing",
|
||||
"reason":"决定保持沉默的具体原因"
|
||||
}}
|
||||
|
||||
你必须从上面列出的可用action中选择一个。
|
||||
请以严格的 JSON 格式输出,且仅包含 JSON 内容:
|
||||
@@ -110,6 +144,37 @@ def init_prompt():
|
||||
"action_prompt",
|
||||
)
|
||||
|
||||
Prompt(
|
||||
"""
|
||||
{name_block}
|
||||
|
||||
{chat_context_description},{time_block},现在请你根据以下聊天内容,选择一个或多个合适的action。如果没有合适的action,请选择no_action。,
|
||||
{chat_content_block}
|
||||
|
||||
**要求**
|
||||
1.action必须符合使用条件,如果符合条件,就选择
|
||||
2.如果聊天内容不适合使用action,即使符合条件,也不要使用
|
||||
3.{moderation_prompt}
|
||||
4.请注意如果相同的内容已经被执行,请不要重复执行
|
||||
这是你最近执行过的动作:
|
||||
{actions_before_now_block}
|
||||
|
||||
**可用的action**
|
||||
|
||||
no_action:不选择任何动作
|
||||
{{
|
||||
"action": "no_action",
|
||||
"reason":"不动作的原因"
|
||||
}}
|
||||
|
||||
{action_options_text}
|
||||
|
||||
请选择,并说明触发action的消息id和选择该action的原因。消息id格式:m+数字
|
||||
请根据动作示例,以严格的 JSON 格式输出,且仅包含 JSON 内容:
|
||||
""",
|
||||
"sub_planner_prompt",
|
||||
)
|
||||
|
||||
|
||||
class ActionPlanner:
|
||||
def __init__(self, chat_id: str, action_manager: ActionManager):
|
||||
@@ -117,14 +182,16 @@ class ActionPlanner:
|
||||
self.log_prefix = f"[{get_chat_manager().get_stream_name(chat_id) or chat_id}]"
|
||||
self.action_manager = action_manager
|
||||
# LLM规划器配置
|
||||
# --- 大脑 ---
|
||||
self.planner_llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.planner, request_type="planner"
|
||||
) # 用于动作规划
|
||||
)
|
||||
# --- 小脑 (新增) ---
|
||||
self.planner_small_llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.planner_small, request_type="planner_small"
|
||||
)
|
||||
|
||||
self.last_obs_time_mark = 0.0
|
||||
# 添加重试计数器
|
||||
self.plan_retry_count = 0
|
||||
self.max_plan_retries = 3
|
||||
|
||||
async def _get_long_term_memory_context(self) -> str:
|
||||
"""
|
||||
@@ -171,32 +238,18 @@ class ActionPlanner:
|
||||
构建动作选项
|
||||
"""
|
||||
action_options_block = ""
|
||||
|
||||
if mode == ChatMode.PROACTIVE:
|
||||
action_options_block += """动作:do_nothing
|
||||
动作描述:保持沉默,不主动发起任何动作或对话。
|
||||
- 当你分析了所有信息后,觉得当前不是一个发起互动的好时机时
|
||||
{{
|
||||
"action": "do_nothing",
|
||||
"reason":"决定保持沉默的具体原因"
|
||||
}}
|
||||
|
||||
"""
|
||||
for action_name, action_info in current_available_actions.items():
|
||||
# TODO: 增加一个字段来判断action是否支持在PROACTIVE模式下使用
|
||||
|
||||
param_text = ""
|
||||
if action_info.action_parameters:
|
||||
param_text = "\n" + "\n".join(
|
||||
f' "{p_name}":"{p_desc}"'
|
||||
for p_name, p_desc in action_info.action_parameters.items()
|
||||
f' "{p_name}":"{p_desc}"' for p_name, p_desc in action_info.action_parameters.items()
|
||||
)
|
||||
|
||||
require_text = "\n".join(f"- {req}" for req in action_info.action_require)
|
||||
|
||||
using_action_prompt = await global_prompt_manager.get_prompt_async(
|
||||
"action_prompt"
|
||||
)
|
||||
using_action_prompt = await global_prompt_manager.get_prompt_async("action_prompt")
|
||||
action_options_block += using_action_prompt.format(
|
||||
action_name=action_name,
|
||||
action_description=action_info.description,
|
||||
@@ -205,9 +258,7 @@ class ActionPlanner:
|
||||
)
|
||||
return action_options_block
|
||||
|
||||
def find_message_by_id(
|
||||
self, message_id: str, message_id_list: list
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
def find_message_by_id(self, message_id: str, message_id_list: list) -> Optional[Dict[str, Any]]:
|
||||
# sourcery skip: use-next
|
||||
"""
|
||||
根据message_id从message_id_list中查找对应的原始消息
|
||||
@@ -242,179 +293,370 @@ class ActionPlanner:
|
||||
# 假设消息列表是按时间顺序排列的,最后一个是最新的
|
||||
return message_id_list[-1].get("message")
|
||||
|
||||
def _parse_single_action(
|
||||
self,
|
||||
action_json: dict,
|
||||
message_id_list: list, # 使用 planner.py 的 list of dict
|
||||
current_available_actions: list, # 使用 planner.py 的 list of tuple
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
[注释] 解析单个小脑LLM返回的action JSON,并将其转换为标准化的字典。
|
||||
"""
|
||||
parsed_actions = []
|
||||
try:
|
||||
action = action_json.get("action", "no_action")
|
||||
reasoning = action_json.get("reason", "未提供原因")
|
||||
action_data = {k: v for k, v in action_json.items() if k not in ["action", "reason"]}
|
||||
|
||||
target_message = None
|
||||
if action != "no_action":
|
||||
if target_message_id := action_json.get("target_message_id"):
|
||||
target_message = self.find_message_by_id(target_message_id, message_id_list)
|
||||
if target_message is None:
|
||||
logger.warning(f"{self.log_prefix}无法找到target_message_id '{target_message_id}'")
|
||||
target_message = self.get_latest_message(message_id_list)
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix}动作'{action}'缺少target_message_id")
|
||||
|
||||
available_action_names = [name for name, _ in current_available_actions]
|
||||
if action not in ["no_action", "reply"] and action not in available_action_names:
|
||||
logger.warning(
|
||||
f"{self.log_prefix}LLM 返回了当前不可用或无效的动作: '{action}' (可用: {available_action_names}),将强制使用 'no_action'"
|
||||
)
|
||||
reasoning = f"LLM 返回了当前不可用的动作 '{action}' (可用: {available_action_names})。原始理由: {reasoning}"
|
||||
action = "no_action"
|
||||
|
||||
# 将列表转换为字典格式以供将来使用
|
||||
available_actions_dict = dict(current_available_actions)
|
||||
parsed_actions.append(
|
||||
{
|
||||
"action_type": action,
|
||||
"reasoning": reasoning,
|
||||
"action_data": action_data,
|
||||
"action_message": target_message,
|
||||
"available_actions": available_actions_dict,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}解析单个action时出错: {e}")
|
||||
parsed_actions.append(
|
||||
{
|
||||
"action_type": "no_action",
|
||||
"reasoning": f"解析action时出错: {e}",
|
||||
"action_data": {},
|
||||
"action_message": None,
|
||||
"available_actions": dict(current_available_actions),
|
||||
}
|
||||
)
|
||||
return parsed_actions
|
||||
|
||||
def _filter_no_actions(self, action_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
[注释] 从一个action字典列表中过滤掉所有的 'no_action'。
|
||||
如果过滤后列表为空, 则返回一个空的列表, 或者根据需要返回一个默认的no_action字典。
|
||||
"""
|
||||
non_no_actions = [a for a in action_list if a.get("action_type") not in ["no_action", "no_reply"]]
|
||||
if non_no_actions:
|
||||
return non_no_actions
|
||||
# 如果都是 no_action,则返回一个包含第一个 no_action 的列表,以保留 reason
|
||||
return action_list[:1] if action_list else []
|
||||
|
||||
async def sub_plan(
|
||||
self,
|
||||
action_list: list, # 使用 planner.py 的 list of tuple
|
||||
chat_content_block: str,
|
||||
message_id_list: list, # 使用 planner.py 的 list of dict
|
||||
is_group_chat: bool = False,
|
||||
chat_target_info: Optional[dict] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
[注释] "小脑"规划器。接收一小组actions,使用轻量级LLM判断其中哪些应该被触发。
|
||||
这是一个独立的、并行的思考单元。返回一个包含action字典的列表。
|
||||
"""
|
||||
try:
|
||||
actions_before_now = get_actions_by_timestamp_with_chat(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=time.time() - 1200,
|
||||
timestamp_end=time.time(),
|
||||
limit=20,
|
||||
)
|
||||
action_names_in_list = [name for name, _ in action_list]
|
||||
filtered_actions = [
|
||||
record for record in actions_before_now if record.get("action_name") in action_names_in_list
|
||||
]
|
||||
actions_before_now_block = build_readable_actions(actions=filtered_actions)
|
||||
|
||||
chat_context_description = "你现在正在一个群聊中"
|
||||
if not is_group_chat and chat_target_info:
|
||||
chat_target_name = chat_target_info.get("person_name") or chat_target_info.get("user_nickname") or "对方"
|
||||
chat_context_description = f"你正在和 {chat_target_name} 私聊"
|
||||
|
||||
action_options_block = ""
|
||||
for using_actions_name, using_actions_info in action_list:
|
||||
param_text = ""
|
||||
if using_actions_info.action_parameters:
|
||||
param_text = "\n" + "\n".join(
|
||||
f' "{p_name}":"{p_desc}"'
|
||||
for p_name, p_desc in using_actions_info.action_parameters.items()
|
||||
)
|
||||
require_text = "\n".join(f"- {req}" for req in using_actions_info.action_require)
|
||||
using_action_prompt = await global_prompt_manager.get_prompt_async("action_prompt")
|
||||
action_options_block += using_action_prompt.format(
|
||||
action_name=using_actions_name,
|
||||
action_description=using_actions_info.description,
|
||||
action_parameters=param_text,
|
||||
action_require=require_text,
|
||||
)
|
||||
|
||||
moderation_prompt_block = "请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。"
|
||||
time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
bot_name = global_config.bot.nickname
|
||||
bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}" if global_config.bot.alias_names else ""
|
||||
name_block = f"你的名字是{bot_name}{bot_nickname},请注意哪些是你自己的发言。"
|
||||
|
||||
planner_prompt_template = await global_prompt_manager.get_prompt_async("sub_planner_prompt")
|
||||
prompt = planner_prompt_template.format(
|
||||
time_block=time_block,
|
||||
chat_context_description=chat_context_description,
|
||||
chat_content_block=chat_content_block,
|
||||
actions_before_now_block=actions_before_now_block,
|
||||
action_options_text=action_options_block,
|
||||
moderation_prompt=moderation_prompt_block,
|
||||
name_block=name_block,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"构建小脑提示词时出错: {e}\n{traceback.format_exc()}")
|
||||
return [{"action_type": "no_action", "reasoning": f"构建小脑Prompt时出错: {e}"}]
|
||||
|
||||
action_dicts: List[Dict[str, Any]] = []
|
||||
try:
|
||||
llm_content, (reasoning_content, _, _) = await self.planner_small_llm.generate_response_async(prompt=prompt)
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"{self.log_prefix}小脑原始提示词: {prompt}")
|
||||
logger.info(f"{self.log_prefix}小脑原始响应: {llm_content}")
|
||||
else:
|
||||
logger.debug(f"{self.log_prefix}小脑原始响应: {llm_content}")
|
||||
|
||||
if llm_content:
|
||||
parsed_json = orjson.loads(repair_json(llm_content))
|
||||
if isinstance(parsed_json, list):
|
||||
for item in parsed_json:
|
||||
if isinstance(item, dict):
|
||||
action_dicts.extend(self._parse_single_action(item, message_id_list, action_list))
|
||||
elif isinstance(parsed_json, dict):
|
||||
action_dicts.extend(self._parse_single_action(parsed_json, message_id_list, action_list))
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"{self.log_prefix}解析小脑响应JSON失败: {e}. LLM原始输出: '{llm_content}'")
|
||||
action_dicts.append({"action_type": "no_action", "reasoning": f"解析小脑响应失败: {e}"})
|
||||
|
||||
if not action_dicts:
|
||||
action_dicts.append({"action_type": "no_action", "reasoning": "小脑未返回有效action"})
|
||||
|
||||
return action_dicts
|
||||
|
||||
async def plan(
|
||||
self,
|
||||
mode: ChatMode = ChatMode.FOCUS,
|
||||
loop_start_time:float = 0.0,
|
||||
loop_start_time: float = 0.0,
|
||||
available_actions: Optional[Dict[str, ActionInfo]] = None,
|
||||
) -> Tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
|
||||
"""
|
||||
规划器 (Planner): 使用LLM根据上下文决定做出什么动作。
|
||||
[注释] "大脑"规划器。
|
||||
1. 启动多个并行的"小脑"(sub_plan)来决定是否执行具体的actions。
|
||||
2. 自己(大脑)则专注于决定是否进行聊天回复(reply)。
|
||||
3. 整合大脑和小脑的决策,返回最终要执行的动作列表。
|
||||
"""
|
||||
# --- 1. 准备上下文信息 ---
|
||||
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=self.chat_id,
|
||||
timestamp=time.time(),
|
||||
limit=int(global_config.chat.max_context_size * 0.6),
|
||||
)
|
||||
# 大脑使用较长的上下文
|
||||
chat_content_block, message_id_list = build_readable_messages_with_id(
|
||||
messages=message_list_before_now,
|
||||
timestamp_mode="normal",
|
||||
read_mark=self.last_obs_time_mark,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
# 小脑使用较短、较新的上下文
|
||||
message_list_before_now_short = message_list_before_now[-int(global_config.chat.max_context_size * 0.3) :]
|
||||
chat_content_block_short, message_id_list_short = build_readable_messages_with_id(
|
||||
messages=message_list_before_now_short,
|
||||
timestamp_mode="normal",
|
||||
truncate=False,
|
||||
show_actions=False,
|
||||
)
|
||||
self.last_obs_time_mark = time.time()
|
||||
|
||||
action = "no_reply" # 默认动作
|
||||
reasoning = "规划器初始化默认"
|
||||
action_data = {}
|
||||
current_available_actions: Dict[str, ActionInfo] = {}
|
||||
target_message: Optional[Dict[str, Any]] = None # 初始化target_message变量
|
||||
prompt: str = ""
|
||||
message_id_list: list = []
|
||||
is_group_chat, chat_target_info, current_available_actions = self.get_necessary_info()
|
||||
if available_actions is None:
|
||||
available_actions = current_available_actions
|
||||
|
||||
# --- 2. 启动小脑并行思考 ---
|
||||
all_sub_planner_results: List[Dict[str, Any]] = []
|
||||
try:
|
||||
is_group_chat, chat_target_info, current_available_actions = self.get_necessary_info()
|
||||
sub_planner_actions: Dict[str, ActionInfo] = {}
|
||||
for action_name, action_info in available_actions.items():
|
||||
|
||||
# --- 构建提示词 (调用修改后的 PromptBuilder 方法) ---
|
||||
prompt, message_id_list = await self.build_planner_prompt(
|
||||
is_group_chat=is_group_chat, # <-- Pass HFC state
|
||||
chat_target_info=chat_target_info, # <-- 传递获取到的聊天目标信息
|
||||
current_available_actions=current_available_actions, # <-- Pass determined actions
|
||||
if action_info.activation_type in [ActionActivationType.LLM_JUDGE, ActionActivationType.ALWAYS]:
|
||||
sub_planner_actions[action_name] = action_info
|
||||
elif action_info.activation_type == ActionActivationType.RANDOM:
|
||||
if random.random() < action_info.random_activation_probability:
|
||||
sub_planner_actions[action_name] = action_info
|
||||
elif action_info.activation_type == ActionActivationType.KEYWORD:
|
||||
if any(keyword in chat_content_block_short for keyword in action_info.activation_keywords):
|
||||
sub_planner_actions[action_name] = action_info
|
||||
|
||||
if sub_planner_actions:
|
||||
sub_planner_actions_num = len(sub_planner_actions)
|
||||
planner_size_config = global_config.chat.planner_size
|
||||
sub_planner_size = int(planner_size_config) + (
|
||||
1 if random.random() < planner_size_config - int(planner_size_config) else 0
|
||||
)
|
||||
sub_planner_num = math.ceil(sub_planner_actions_num / sub_planner_size)
|
||||
logger.info(f"{self.log_prefix}使用{sub_planner_num}个小脑进行思考 (尺寸: {sub_planner_size})")
|
||||
|
||||
action_items = list(sub_planner_actions.items())
|
||||
random.shuffle(action_items)
|
||||
sub_planner_lists = [action_items[i::sub_planner_num] for i in range(sub_planner_num)]
|
||||
|
||||
sub_plan_tasks = [
|
||||
self.sub_plan(
|
||||
action_list=action_group,
|
||||
chat_content_block=chat_content_block_short,
|
||||
message_id_list=message_id_list_short,
|
||||
is_group_chat=is_group_chat,
|
||||
chat_target_info=chat_target_info,
|
||||
)
|
||||
for action_group in sub_planner_lists
|
||||
]
|
||||
sub_plan_results = await asyncio.gather(*sub_plan_tasks)
|
||||
for sub_result in sub_plan_results:
|
||||
all_sub_planner_results.extend(sub_result)
|
||||
|
||||
sub_actions_str = ", ".join(
|
||||
a["action_type"] for a in all_sub_planner_results if a["action_type"] != "no_action"
|
||||
) or "no_action"
|
||||
logger.info(f"{self.log_prefix}小脑决策: [{sub_actions_str}]")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}小脑调度过程中出错: {e}\n{traceback.format_exc()}")
|
||||
|
||||
# --- 3. 大脑独立思考是否回复 ---
|
||||
action, reasoning, action_data, target_message = "no_reply", "大脑初始化默认", {}, None
|
||||
try:
|
||||
prompt, _ = await self.build_planner_prompt(
|
||||
is_group_chat=is_group_chat,
|
||||
chat_target_info=chat_target_info,
|
||||
current_available_actions={},
|
||||
mode=mode,
|
||||
chat_content_block_override=chat_content_block,
|
||||
message_id_list_override=message_id_list,
|
||||
)
|
||||
|
||||
# --- 调用 LLM (普通文本生成) ---
|
||||
llm_content = None
|
||||
try:
|
||||
llm_content, (reasoning_content, _, _) = await self.planner_llm.generate_response_async(prompt=prompt)
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
|
||||
logger.info(f"{self.log_prefix}规划器原始响应: {llm_content}")
|
||||
if reasoning_content:
|
||||
logger.info(f"{self.log_prefix}规划器推理: {reasoning_content}")
|
||||
else:
|
||||
logger.debug(f"{self.log_prefix}规划器原始提示词: {prompt}")
|
||||
logger.debug(f"{self.log_prefix}规划器原始响应: {llm_content}")
|
||||
if reasoning_content:
|
||||
logger.debug(f"{self.log_prefix}规划器推理: {reasoning_content}")
|
||||
|
||||
except Exception as req_e:
|
||||
logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}")
|
||||
reasoning = f"LLM 请求失败,模型出现问题: {req_e}"
|
||||
action = "no_reply"
|
||||
llm_content, _ = await self.planner_llm.generate_response_async(prompt=prompt)
|
||||
|
||||
if llm_content:
|
||||
try:
|
||||
parsed_json = orjson.loads(repair_json(llm_content))
|
||||
|
||||
if isinstance(parsed_json, list):
|
||||
if parsed_json:
|
||||
parsed_json = parsed_json[-1]
|
||||
logger.warning(f"{self.log_prefix}LLM返回了多个JSON对象,使用最后一个: {parsed_json}")
|
||||
else:
|
||||
parsed_json = {}
|
||||
|
||||
if not isinstance(parsed_json, dict):
|
||||
logger.error(f"{self.log_prefix}解析后的JSON不是字典类型: {type(parsed_json)}")
|
||||
parsed_json = {}
|
||||
|
||||
parsed_json = orjson.loads(repair_json(llm_content))
|
||||
parsed_json = parsed_json[-1] if isinstance(parsed_json, list) and parsed_json else parsed_json
|
||||
if isinstance(parsed_json, dict):
|
||||
action = parsed_json.get("action", "no_reply")
|
||||
reasoning = parsed_json.get("reason", "未提供原因")
|
||||
|
||||
# 将所有其他属性添加到action_data
|
||||
for key, value in parsed_json.items():
|
||||
if key not in ["action", "reason"]:
|
||||
action_data[key] = value
|
||||
|
||||
# 非no_reply动作需要target_message_id
|
||||
action_data = {k: v for k, v in parsed_json.items() if k not in ["action", "reason"]}
|
||||
if action != "no_reply":
|
||||
if target_message_id := parsed_json.get("target_message_id"):
|
||||
# 根据target_message_id查找原始消息
|
||||
target_message = self.find_message_by_id(target_message_id, message_id_list)
|
||||
# 如果获取的target_message为None,输出warning并重新plan
|
||||
if target_message is None:
|
||||
self.plan_retry_count += 1
|
||||
logger.warning(f"{self.log_prefix}无法找到target_message_id '{target_message_id}' 对应的消息,重试次数: {self.plan_retry_count}/{self.max_plan_retries}")
|
||||
|
||||
# 如果连续三次plan均为None,输出error并选取最新消息
|
||||
if self.plan_retry_count >= self.max_plan_retries:
|
||||
logger.error(f"{self.log_prefix}连续{self.max_plan_retries}次plan获取target_message失败,选择最新消息作为target_message")
|
||||
target_message = self.get_latest_message(message_id_list)
|
||||
self.plan_retry_count = 0 # 重置计数器
|
||||
else:
|
||||
# 递归重新plan
|
||||
return await self.plan(mode, loop_start_time, available_actions)
|
||||
else:
|
||||
# 成功获取到target_message,重置计数器
|
||||
self.plan_retry_count = 0
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix}动作'{action}'缺少target_message_id")
|
||||
|
||||
|
||||
if action != "no_reply" and action != "reply" and action not in current_available_actions:
|
||||
logger.warning(
|
||||
f"{self.log_prefix}LLM 返回了当前不可用或无效的动作: '{action}' (可用: {list(current_available_actions.keys())}),将强制使用 'no_reply'"
|
||||
)
|
||||
reasoning = f"LLM 返回了当前不可用的动作 '{action}' (可用: {list(current_available_actions.keys())})。原始理由: {reasoning}"
|
||||
action = "no_reply"
|
||||
if target_id := parsed_json.get("target_message_id"):
|
||||
target_message = self.find_message_by_id(target_id, message_id_list)
|
||||
if not target_message:
|
||||
target_message = self.get_latest_message(message_id_list)
|
||||
logger.info(f"{self.log_prefix}大脑决策: [{action}]")
|
||||
|
||||
except Exception as json_e:
|
||||
logger.warning(f"{self.log_prefix}解析LLM响应JSON失败 {json_e}. LLM原始输出: '{llm_content}'")
|
||||
traceback.print_exc()
|
||||
reasoning = f"解析LLM响应JSON失败: {json_e}. 将使用默认动作 'no_reply'."
|
||||
action = "no_reply"
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}大脑处理过程中发生意外错误: {e}\n{traceback.format_exc()}")
|
||||
action, reasoning = "no_reply", f"大脑处理错误: {e}"
|
||||
|
||||
except Exception as outer_e:
|
||||
logger.error(f"{self.log_prefix}Planner 处理过程中发生意外错误,规划失败,将执行 no_reply: {outer_e}")
|
||||
traceback.print_exc()
|
||||
action = "no_reply"
|
||||
reasoning = f"Planner 内部处理错误: {outer_e}"
|
||||
|
||||
is_parallel = False
|
||||
if mode == ChatMode.NORMAL and action in current_available_actions:
|
||||
is_parallel = current_available_actions[action].parallel_action
|
||||
|
||||
# --- 4. 整合大脑和小脑的决策 ---
|
||||
# 如果是私聊且开启了强制回复,则将no_reply强制改为reply
|
||||
if not is_group_chat and global_config.chat.force_reply_private and action == "no_reply":
|
||||
action = "reply"
|
||||
reasoning = "私聊强制回复"
|
||||
logger.info(f"{self.log_prefix}私聊强制回复已触发,将动作从 'no_reply' 修改为 'reply'")
|
||||
|
||||
is_parallel = True
|
||||
for info in all_sub_planner_results:
|
||||
action_type = info.get("action_type")
|
||||
if action_type and action_type not in ["no_action", "no_reply"]:
|
||||
action_info = available_actions.get(action_type)
|
||||
if action_info and not action_info.parallel_action:
|
||||
is_parallel = False
|
||||
break
|
||||
|
||||
action_data["loop_start_time"] = loop_start_time
|
||||
final_actions: List[Dict[str, Any]] = []
|
||||
|
||||
if is_parallel:
|
||||
logger.info(f"{self.log_prefix}决策模式: 大脑与小脑并行")
|
||||
if action not in ["no_action", "no_reply"]:
|
||||
final_actions.append(
|
||||
{
|
||||
"action_type": action,
|
||||
"reasoning": reasoning,
|
||||
"action_data": action_data,
|
||||
"action_message": target_message,
|
||||
"available_actions": available_actions,
|
||||
}
|
||||
)
|
||||
final_actions.extend(all_sub_planner_results)
|
||||
else:
|
||||
logger.info(f"{self.log_prefix}决策模式: 小脑优先 (检测到非并行action)")
|
||||
final_actions.extend(all_sub_planner_results)
|
||||
|
||||
final_actions = self._filter_no_actions(final_actions)
|
||||
|
||||
if not final_actions:
|
||||
final_actions = [
|
||||
{
|
||||
"action_type": "no_action",
|
||||
"reasoning": "所有规划器都选择不执行动作",
|
||||
"action_data": {}, "action_message": None, "available_actions": available_actions
|
||||
}
|
||||
]
|
||||
|
||||
final_target_message = target_message
|
||||
if not final_target_message and final_actions:
|
||||
final_target_message = next((act.get("action_message") for act in final_actions if act.get("action_message")), None)
|
||||
|
||||
actions_str = ", ".join([a.get('action_type', 'N/A') for a in final_actions])
|
||||
logger.info(f"{self.log_prefix}最终执行动作 ({len(final_actions)}): [{actions_str}]")
|
||||
|
||||
actions = []
|
||||
|
||||
# 1. 添加Planner取得的动作
|
||||
actions.append({
|
||||
"action_type": action,
|
||||
"reasoning": reasoning,
|
||||
"action_data": action_data,
|
||||
"action_message": target_message,
|
||||
"available_actions": available_actions # 添加这个字段
|
||||
})
|
||||
|
||||
if action != "reply" and is_parallel:
|
||||
actions.append({
|
||||
"action_type": "reply",
|
||||
"action_message": target_message,
|
||||
"available_actions": available_actions
|
||||
})
|
||||
|
||||
return actions,target_message
|
||||
return final_actions, final_target_message
|
||||
|
||||
async def build_planner_prompt(
|
||||
self,
|
||||
is_group_chat: bool, # Now passed as argument
|
||||
chat_target_info: Optional[dict], # Now passed as argument
|
||||
is_group_chat: bool,
|
||||
chat_target_info: Optional[dict],
|
||||
current_available_actions: Dict[str, ActionInfo],
|
||||
refresh_time :bool = False,
|
||||
mode: ChatMode = ChatMode.FOCUS,
|
||||
) -> tuple[str, list]: # sourcery skip: use-join
|
||||
chat_content_block_override: Optional[str] = None,
|
||||
message_id_list_override: Optional[List] = None,
|
||||
refresh_time: bool = False, # 添加缺失的参数
|
||||
) -> tuple[str, list]:
|
||||
"""构建 Planner LLM 的提示词 (获取模板并填充数据)"""
|
||||
try:
|
||||
# --- 通用信息获取 ---
|
||||
time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
bot_name = global_config.bot.nickname
|
||||
bot_nickname = (
|
||||
f",也有人叫你{','.join(global_config.bot.alias_names)}"
|
||||
if global_config.bot.alias_names
|
||||
else ""
|
||||
f",也有人叫你{','.join(global_config.bot.alias_names)}" if global_config.bot.alias_names else ""
|
||||
)
|
||||
bot_core_personality = global_config.personality.personality_core
|
||||
identity_block = (
|
||||
f"你的名字是{bot_name}{bot_nickname},你{bot_core_personality}:"
|
||||
)
|
||||
identity_block = f"你的名字是{bot_name}{bot_nickname},你{bot_core_personality}:"
|
||||
|
||||
schedule_block = ""
|
||||
if global_config.schedule.enable:
|
||||
if global_config.planning_system.schedule_enable:
|
||||
if current_activity := schedule_manager.get_current_activity():
|
||||
schedule_block = (
|
||||
f"你当前正在:{current_activity},但注意它与群聊的聊天无关。"
|
||||
)
|
||||
schedule_block = f"你当前正在:{current_activity},但注意它与群聊的聊天无关。"
|
||||
|
||||
mood_block = ""
|
||||
if global_config.mood.enable_mood:
|
||||
@@ -424,20 +666,38 @@ class ActionPlanner:
|
||||
# --- 根据模式构建不同的Prompt ---
|
||||
if mode == ChatMode.PROACTIVE:
|
||||
long_term_memory_block = await self._get_long_term_memory_context()
|
||||
action_options_text = await self._build_action_options(
|
||||
current_available_actions, mode
|
||||
|
||||
# 获取最近的聊天记录用于主动思考决策
|
||||
message_list_short = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=self.chat_id,
|
||||
timestamp=time.time(),
|
||||
limit=int(global_config.chat.max_context_size * 0.2), # 主动思考时只看少量最近消息
|
||||
)
|
||||
chat_content_block, _ = build_readable_messages_with_id(
|
||||
messages=message_list_short,
|
||||
timestamp_mode="normal",
|
||||
truncate=False,
|
||||
show_actions=False,
|
||||
)
|
||||
|
||||
prompt_template = await global_prompt_manager.get_prompt_async(
|
||||
"proactive_planner_prompt"
|
||||
prompt_template = await global_prompt_manager.get_prompt_async("proactive_planner_prompt")
|
||||
actions_before_now = get_actions_by_timestamp_with_chat(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=time.time() - 3600,
|
||||
timestamp_end=time.time(),
|
||||
limit=5,
|
||||
)
|
||||
actions_before_now_block = build_readable_actions(actions=actions_before_now)
|
||||
actions_before_now_block = f"你刚刚选择并执行过的action是:\n{actions_before_now_block}"
|
||||
|
||||
prompt = prompt_template.format(
|
||||
time_block=time_block,
|
||||
identity_block=identity_block,
|
||||
schedule_block=schedule_block,
|
||||
mood_block=mood_block,
|
||||
long_term_memory_block=long_term_memory_block,
|
||||
action_options_text=action_options_text,
|
||||
chat_content_block=chat_content_block or "最近没有聊天内容。",
|
||||
actions_before_now_block=actions_before_now_block,
|
||||
)
|
||||
return prompt, []
|
||||
|
||||
@@ -463,12 +723,8 @@ class ActionPlanner:
|
||||
limit=5,
|
||||
)
|
||||
|
||||
actions_before_now_block = build_readable_actions(
|
||||
actions=actions_before_now
|
||||
)
|
||||
actions_before_now_block = (
|
||||
f"你刚刚选择并执行过的action是:\n{actions_before_now_block}"
|
||||
)
|
||||
actions_before_now_block = build_readable_actions(actions=actions_before_now)
|
||||
actions_before_now_block = f"你刚刚选择并执行过的action是:\n{actions_before_now_block}"
|
||||
|
||||
if refresh_time:
|
||||
self.last_obs_time_mark = time.time()
|
||||
@@ -504,30 +760,22 @@ class ActionPlanner:
|
||||
}}"""
|
||||
|
||||
chat_context_description = "你现在正在一个群聊中"
|
||||
chat_target_name = None
|
||||
chat_target_name = None
|
||||
if not is_group_chat and chat_target_info:
|
||||
chat_target_name = (
|
||||
chat_target_info.get("person_name")
|
||||
or chat_target_info.get("user_nickname")
|
||||
or "对方"
|
||||
chat_target_info.get("person_name") or chat_target_info.get("user_nickname") or "对方"
|
||||
)
|
||||
chat_context_description = f"你正在和 {chat_target_name} 私聊"
|
||||
|
||||
action_options_block = await self._build_action_options(
|
||||
current_available_actions, mode
|
||||
)
|
||||
action_options_block = await self._build_action_options(current_available_actions, mode)
|
||||
|
||||
moderation_prompt_block = "请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。"
|
||||
|
||||
custom_prompt_block = ""
|
||||
if global_config.custom_prompt.planner_custom_prompt_content:
|
||||
custom_prompt_block = (
|
||||
global_config.custom_prompt.planner_custom_prompt_content
|
||||
)
|
||||
custom_prompt_block = global_config.custom_prompt.planner_custom_prompt_content
|
||||
|
||||
planner_prompt_template = await global_prompt_manager.get_prompt_async(
|
||||
"planner_prompt"
|
||||
)
|
||||
planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt")
|
||||
prompt = planner_prompt_template.format(
|
||||
schedule_block=schedule_block,
|
||||
mood_block=mood_block,
|
||||
@@ -555,9 +803,7 @@ class ActionPlanner:
|
||||
"""
|
||||
is_group_chat = True
|
||||
is_group_chat, chat_target_info = get_chat_type_and_target_info(self.chat_id)
|
||||
logger.debug(
|
||||
f"{self.log_prefix}获取到聊天信息 - 群聊: {is_group_chat}, 目标信息: {chat_target_info}"
|
||||
)
|
||||
logger.debug(f"{self.log_prefix}获取到聊天信息 - 群聊: {is_group_chat}, 目标信息: {chat_target_info}")
|
||||
|
||||
current_available_actions_dict = self.action_manager.get_using_actions()
|
||||
|
||||
@@ -568,13 +814,9 @@ class ActionPlanner:
|
||||
current_available_actions = {}
|
||||
for action_name in current_available_actions_dict:
|
||||
if action_name in all_registered_actions:
|
||||
current_available_actions[action_name] = all_registered_actions[
|
||||
action_name
|
||||
]
|
||||
current_available_actions[action_name] = all_registered_actions[action_name]
|
||||
else:
|
||||
logger.warning(
|
||||
f"{self.log_prefix}使用中的动作 {action_name} 未在已注册动作中找到"
|
||||
)
|
||||
logger.warning(f"{self.log_prefix}使用中的动作 {action_name} 未在已注册动作中找到")
|
||||
|
||||
# 将no_reply作为系统级特殊动作添加到可用动作中
|
||||
# no_reply虽然是系统级决策,但需要让规划器认为它是可用的
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""
|
||||
默认回复生成器 - 集成SmartPrompt系统
|
||||
使用重构后的SmartPrompt系统替换原有的复杂提示词构建逻辑
|
||||
默认回复生成器 - 集成统一Prompt系统
|
||||
使用重构后的统一Prompt系统替换原有的复杂提示词构建逻辑
|
||||
"""
|
||||
|
||||
import traceback
|
||||
@@ -11,11 +11,9 @@ import re
|
||||
|
||||
from typing import List, Optional, Dict, Any, Tuple
|
||||
from datetime import datetime
|
||||
from src.chat.utils.prompt_utils import PromptUtils
|
||||
from src.mais4u.mai_think import mai_thinking_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.config.api_ada_configs import TaskConfig
|
||||
from src.individuality.individuality import get_individuality
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.chat.message_receive.message import UserInfo, Seg, MessageRecv, MessageSending
|
||||
@@ -23,7 +21,7 @@ from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.chat.message_receive.uni_message_sender import HeartFCSender
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
build_readable_messages,
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
@@ -37,10 +35,9 @@ from src.person_info.relationship_fetcher import relationship_fetcher_manager
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
from src.plugin_system.base.component_types import ActionInfo, EventType
|
||||
from src.plugin_system.apis import llm_api
|
||||
from src.schedule.schedule_manager import schedule_manager
|
||||
|
||||
# 导入新的智能Prompt系统
|
||||
from src.chat.utils.smart_prompt import SmartPrompt, SmartPromptParameters
|
||||
# 导入新的统一Prompt系统
|
||||
from src.chat.utils.prompt import PromptParameters
|
||||
|
||||
logger = get_logger("replyer")
|
||||
|
||||
@@ -286,6 +283,7 @@ class DefaultReplyer:
|
||||
return False, None, None
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
|
||||
# 触发 POST_LLM 事件(请求 LLM 之前)
|
||||
if not from_plugin:
|
||||
result = await event_manager.trigger_event(
|
||||
EventType.POST_LLM, plugin_name="SYSTEM", prompt=prompt, stream_id=stream_id
|
||||
@@ -307,6 +305,7 @@ class DefaultReplyer:
|
||||
"model": model_name,
|
||||
"tool_calls": tool_call,
|
||||
}
|
||||
|
||||
# 触发 AFTER_LLM 事件
|
||||
if not from_plugin:
|
||||
result = await event_manager.trigger_event(
|
||||
@@ -600,7 +599,8 @@ class DefaultReplyer:
|
||||
|
||||
def _parse_reply_target(self, target_message: str) -> Tuple[str, str]:
|
||||
"""解析回复目标消息 - 使用共享工具"""
|
||||
return PromptUtils.parse_reply_target(target_message)
|
||||
from src.chat.utils.prompt import Prompt
|
||||
return Prompt.parse_reply_target(target_message)
|
||||
|
||||
async def build_keywords_reaction_prompt(self, target: Optional[str]) -> str:
|
||||
"""构建关键词反应提示
|
||||
@@ -706,16 +706,16 @@ class DefaultReplyer:
|
||||
# 检查最新五条消息中是否包含bot自己说的消息
|
||||
latest_5_messages = core_dialogue_list[-5:] if len(core_dialogue_list) >= 5 else core_dialogue_list
|
||||
has_bot_message = any(str(msg.get("user_id")) == bot_id for msg in latest_5_messages)
|
||||
|
||||
|
||||
# logger.info(f"最新五条消息:{latest_5_messages}")
|
||||
# logger.info(f"最新五条消息中是否包含bot自己说的消息:{has_bot_message}")
|
||||
|
||||
|
||||
# 如果最新五条消息中不包含bot的消息,则返回空字符串
|
||||
if not has_bot_message:
|
||||
core_dialogue_prompt = ""
|
||||
else:
|
||||
core_dialogue_list = core_dialogue_list[-int(global_config.chat.max_context_size * 2) :] # 限制消息数量
|
||||
|
||||
|
||||
core_dialogue_prompt_str = build_readable_messages(
|
||||
core_dialogue_list,
|
||||
replace_bot_name=True,
|
||||
@@ -819,7 +819,7 @@ class DefaultReplyer:
|
||||
mood_prompt = ""
|
||||
|
||||
if reply_to:
|
||||
#兼容旧的reply_to
|
||||
# 兼容旧的reply_to
|
||||
sender, target = self._parse_reply_target(reply_to)
|
||||
else:
|
||||
# 获取 platform,如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值
|
||||
@@ -830,7 +830,7 @@ class DefaultReplyer:
|
||||
)
|
||||
person_name = await person_info_manager.get_value(person_id, "person_name")
|
||||
sender = person_name
|
||||
target = reply_message.get('processed_plain_text')
|
||||
target = reply_message.get("processed_plain_text")
|
||||
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_id = person_info_manager.get_person_id_by_person_name(sender)
|
||||
@@ -875,7 +875,8 @@ class DefaultReplyer:
|
||||
target_user_info = None
|
||||
if sender:
|
||||
target_user_info = await person_info_manager.get_person_info_by_name(sender)
|
||||
|
||||
|
||||
from src.chat.utils.prompt import Prompt
|
||||
# 并行执行六个构建任务
|
||||
task_results = await asyncio.gather(
|
||||
self._time_and_run_task(
|
||||
@@ -888,7 +889,7 @@ class DefaultReplyer:
|
||||
),
|
||||
self._time_and_run_task(self.get_prompt_info(chat_talking_prompt_short, sender, target), "prompt_info"),
|
||||
self._time_and_run_task(
|
||||
PromptUtils.build_cross_context(chat_id, target_user_info, global_config.personality.prompt_mode),
|
||||
Prompt.build_cross_context(chat_id, global_config.personality.prompt_mode, target_user_info),
|
||||
"cross_context",
|
||||
),
|
||||
)
|
||||
@@ -939,7 +940,8 @@ class DefaultReplyer:
|
||||
identity_block = await get_individuality().get_personality_block()
|
||||
|
||||
schedule_block = ""
|
||||
if global_config.schedule.enable:
|
||||
if global_config.planning_system.schedule_enable:
|
||||
from src.schedule.schedule_manager import schedule_manager
|
||||
current_activity = schedule_manager.get_current_activity()
|
||||
if current_activity:
|
||||
schedule_block = f"你当前正在:{current_activity}。"
|
||||
@@ -971,8 +973,8 @@ class DefaultReplyer:
|
||||
# 根据配置选择模板
|
||||
current_prompt_mode = global_config.personality.prompt_mode
|
||||
|
||||
# 使用重构后的SmartPrompt系统
|
||||
prompt_params = SmartPromptParameters(
|
||||
# 使用新的统一Prompt系统 - 创建PromptParameters
|
||||
prompt_parameters = PromptParameters(
|
||||
chat_id=chat_id,
|
||||
is_group_chat=is_group_chat,
|
||||
sender=sender,
|
||||
@@ -1005,12 +1007,19 @@ class DefaultReplyer:
|
||||
action_descriptions=action_descriptions,
|
||||
)
|
||||
|
||||
# 使用重构后的SmartPrompt系统
|
||||
smart_prompt = SmartPrompt(
|
||||
template_name=None, # 由current_prompt_mode自动选择
|
||||
parameters=prompt_params,
|
||||
)
|
||||
prompt_text = await smart_prompt.build_prompt()
|
||||
# 使用新的统一Prompt系统 - 使用正确的模板名称
|
||||
template_name = None
|
||||
if current_prompt_mode == "s4u":
|
||||
template_name = "s4u_style_prompt"
|
||||
elif current_prompt_mode == "normal":
|
||||
template_name = "normal_style_prompt"
|
||||
elif current_prompt_mode == "minimal":
|
||||
template_name = "default_expressor_prompt"
|
||||
|
||||
# 获取模板内容
|
||||
template_prompt = await global_prompt_manager.get_prompt_async(template_name)
|
||||
prompt = Prompt(template=template_prompt.template, parameters=prompt_parameters)
|
||||
prompt_text = await prompt.build()
|
||||
|
||||
return prompt_text
|
||||
|
||||
@@ -1024,7 +1033,7 @@ class DefaultReplyer:
|
||||
chat_stream = self.chat_stream
|
||||
chat_id = chat_stream.stream_id
|
||||
is_group_chat = bool(chat_stream.group_info)
|
||||
|
||||
|
||||
if reply_message:
|
||||
sender = reply_message.get("sender")
|
||||
target = reply_message.get("target")
|
||||
@@ -1111,8 +1120,8 @@ class DefaultReplyer:
|
||||
|
||||
template_name = "default_expressor_prompt"
|
||||
|
||||
# 使用重构后的SmartPrompt系统 - Expressor模式
|
||||
prompt_params = SmartPromptParameters(
|
||||
# 使用新的统一Prompt系统 - Expressor模式,创建PromptParameters
|
||||
prompt_parameters = PromptParameters(
|
||||
chat_id=chat_id,
|
||||
is_group_chat=is_group_chat,
|
||||
sender=sender,
|
||||
@@ -1132,8 +1141,10 @@ class DefaultReplyer:
|
||||
relation_info_block=relation_info,
|
||||
)
|
||||
|
||||
smart_prompt = SmartPrompt(parameters=prompt_params)
|
||||
prompt_text = await smart_prompt.build_prompt()
|
||||
# 使用新的统一Prompt系统 - Expressor模式
|
||||
template_prompt = await global_prompt_manager.get_prompt_async("default_expressor_prompt")
|
||||
prompt = Prompt(template=template_prompt.template, parameters=prompt_parameters)
|
||||
prompt_text = await prompt.build()
|
||||
|
||||
return prompt_text
|
||||
|
||||
@@ -1181,7 +1192,9 @@ class DefaultReplyer:
|
||||
else:
|
||||
logger.debug(f"\n{prompt}\n")
|
||||
|
||||
content, (reasoning_content, model_name, tool_calls) = await self.express_model.generate_response_async(prompt)
|
||||
content, (reasoning_content, model_name, tool_calls) = await self.express_model.generate_response_async(
|
||||
prompt
|
||||
)
|
||||
|
||||
logger.debug(f"replyer生成内容: {content}")
|
||||
return content, reasoning_content, model_name, tool_calls
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from typing import Dict, Optional, List, Tuple
|
||||
from typing import Dict, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.api_ada_configs import TaskConfig
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.chat.replyer.default_generator import DefaultReplyer
|
||||
|
||||
|
||||
@@ -1250,7 +1250,7 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
|
||||
# 检查必要信息是否存在 且 不是机器人自己
|
||||
if not all([platform, user_id]) or user_id == global_config.bot.qq_account:
|
||||
continue
|
||||
|
||||
|
||||
# 添加空值检查,防止 platform 为 None 时出错
|
||||
if platform is None:
|
||||
platform = "unknown"
|
||||
|
||||
823
src/chat/utils/prompt.py
Normal file
823
src/chat/utils/prompt.py
Normal file
@@ -0,0 +1,823 @@
|
||||
"""
|
||||
统一提示词系统 - 合并模板管理和智能构建功能
|
||||
将原有的Prompt类和SmartPrompt功能整合为一个真正的Prompt类
|
||||
"""
|
||||
|
||||
import re
|
||||
import asyncio
|
||||
import time
|
||||
import contextvars
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Any, Optional, List, Literal, Tuple
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from rich.traceback import install
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
|
||||
install(extra_lines=3)
|
||||
logger = get_logger("unified_prompt")
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptParameters:
|
||||
"""统一提示词参数系统"""
|
||||
|
||||
# 基础参数
|
||||
chat_id: str = ""
|
||||
is_group_chat: bool = False
|
||||
sender: str = ""
|
||||
target: str = ""
|
||||
reply_to: str = ""
|
||||
extra_info: str = ""
|
||||
prompt_mode: Literal["s4u", "normal", "minimal"] = "s4u"
|
||||
|
||||
# 功能开关
|
||||
enable_tool: bool = True
|
||||
enable_memory: bool = True
|
||||
enable_expression: bool = True
|
||||
enable_relation: bool = True
|
||||
enable_cross_context: bool = True
|
||||
enable_knowledge: bool = True
|
||||
|
||||
# 性能控制
|
||||
max_context_messages: int = 50
|
||||
|
||||
# 调试选项
|
||||
debug_mode: bool = False
|
||||
|
||||
# 聊天历史和上下文
|
||||
chat_target_info: Optional[Dict[str, Any]] = None
|
||||
message_list_before_now_long: List[Dict[str, Any]] = field(default_factory=list)
|
||||
message_list_before_short: List[Dict[str, Any]] = field(default_factory=list)
|
||||
chat_talking_prompt_short: str = ""
|
||||
target_user_info: Optional[Dict[str, Any]] = None
|
||||
|
||||
# 已构建的内容块
|
||||
expression_habits_block: str = ""
|
||||
relation_info_block: str = ""
|
||||
memory_block: str = ""
|
||||
tool_info_block: str = ""
|
||||
knowledge_prompt: str = ""
|
||||
cross_context_block: str = ""
|
||||
|
||||
# 其他内容块
|
||||
keywords_reaction_prompt: str = ""
|
||||
extra_info_block: str = ""
|
||||
time_block: str = ""
|
||||
identity_block: str = ""
|
||||
schedule_block: str = ""
|
||||
moderation_prompt_block: str = ""
|
||||
reply_target_block: str = ""
|
||||
mood_prompt: str = ""
|
||||
action_descriptions: str = ""
|
||||
|
||||
# 可用动作信息
|
||||
available_actions: Optional[Dict[str, Any]] = None
|
||||
|
||||
def validate(self) -> List[str]:
|
||||
"""参数验证"""
|
||||
errors = []
|
||||
if not self.chat_id:
|
||||
errors.append("chat_id不能为空")
|
||||
if self.prompt_mode not in ["s4u", "normal", "minimal"]:
|
||||
errors.append("prompt_mode必须是's4u'、'normal'或'minimal'")
|
||||
if self.max_context_messages <= 0:
|
||||
errors.append("max_context_messages必须大于0")
|
||||
return errors
|
||||
|
||||
|
||||
class PromptContext:
|
||||
"""提示词上下文管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self._context_prompts: Dict[str, Dict[str, "Prompt"]] = {}
|
||||
self._current_context_var = contextvars.ContextVar("current_context", default=None)
|
||||
self._context_lock = asyncio.Lock()
|
||||
|
||||
@property
|
||||
def _current_context(self) -> Optional[str]:
|
||||
"""获取当前协程的上下文ID"""
|
||||
return self._current_context_var.get()
|
||||
|
||||
@_current_context.setter
|
||||
def _current_context(self, value: Optional[str]):
|
||||
"""设置当前协程的上下文ID"""
|
||||
self._current_context_var.set(value) # type: ignore
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_scope(self, context_id: Optional[str] = None):
|
||||
"""创建一个异步的临时提示模板作用域"""
|
||||
if context_id is not None:
|
||||
try:
|
||||
await asyncio.wait_for(self._context_lock.acquire(), timeout=5.0)
|
||||
try:
|
||||
if context_id not in self._context_prompts:
|
||||
self._context_prompts[context_id] = {}
|
||||
finally:
|
||||
self._context_lock.release()
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"获取上下文锁超时,context_id: {context_id}")
|
||||
context_id = None
|
||||
|
||||
previous_context = self._current_context
|
||||
token = self._current_context_var.set(context_id) if context_id else None
|
||||
else:
|
||||
previous_context = self._current_context
|
||||
token = None
|
||||
|
||||
try:
|
||||
yield self
|
||||
finally:
|
||||
if context_id is not None and token is not None:
|
||||
try:
|
||||
self._current_context_var.reset(token)
|
||||
except Exception as e:
|
||||
logger.warning(f"恢复上下文时出错: {e}")
|
||||
try:
|
||||
self._current_context = previous_context
|
||||
except Exception:
|
||||
...
|
||||
|
||||
async def get_prompt_async(self, name: str) -> Optional["Prompt"]:
|
||||
"""异步获取当前作用域中的提示模板"""
|
||||
async with self._context_lock:
|
||||
current_context = self._current_context
|
||||
logger.debug(f"获取提示词: {name} 当前上下文: {current_context}")
|
||||
if (
|
||||
current_context
|
||||
and current_context in self._context_prompts
|
||||
and name in self._context_prompts[current_context]
|
||||
):
|
||||
return self._context_prompts[current_context][name]
|
||||
return None
|
||||
|
||||
async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None:
|
||||
"""异步注册提示模板到指定作用域"""
|
||||
async with self._context_lock:
|
||||
if target_context := context_id or self._current_context:
|
||||
if prompt.name:
|
||||
self._context_prompts.setdefault(target_context, {})[prompt.name] = prompt
|
||||
|
||||
|
||||
class PromptManager:
|
||||
"""统一提示词管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self._prompts = {}
|
||||
self._counter = 0
|
||||
self._context = PromptContext()
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_message_scope(self, message_id: Optional[str] = None):
|
||||
"""为消息处理创建异步临时作用域"""
|
||||
async with self._context.async_scope(message_id):
|
||||
yield self
|
||||
|
||||
async def get_prompt_async(self, name: str) -> "Prompt":
|
||||
"""异步获取提示模板"""
|
||||
context_prompt = await self._context.get_prompt_async(name)
|
||||
if context_prompt is not None:
|
||||
logger.debug(f"从上下文中获取提示词: {name} {context_prompt}")
|
||||
return context_prompt
|
||||
|
||||
async with self._lock:
|
||||
if name not in self._prompts:
|
||||
raise KeyError(f"Prompt '{name}' not found")
|
||||
return self._prompts[name]
|
||||
|
||||
def generate_name(self, template: str) -> str:
|
||||
"""为未命名的prompt生成名称"""
|
||||
self._counter += 1
|
||||
return f"prompt_{self._counter}"
|
||||
|
||||
def register(self, prompt: "Prompt") -> None:
|
||||
"""注册一个prompt"""
|
||||
if not prompt.name:
|
||||
prompt.name = self.generate_name(prompt.template)
|
||||
self._prompts[prompt.name] = prompt
|
||||
|
||||
def add_prompt(self, name: str, fstr: str) -> "Prompt":
|
||||
"""添加新提示模板"""
|
||||
prompt = Prompt(fstr, name=name)
|
||||
if prompt.name:
|
||||
self._prompts[prompt.name] = prompt
|
||||
return prompt
|
||||
|
||||
async def format_prompt(self, name: str, **kwargs) -> str:
|
||||
"""格式化提示模板"""
|
||||
prompt = await self.get_prompt_async(name)
|
||||
result = prompt.format(**kwargs)
|
||||
return result
|
||||
|
||||
|
||||
# 全局单例
|
||||
global_prompt_manager = PromptManager()
|
||||
|
||||
|
||||
class Prompt:
|
||||
"""
|
||||
统一提示词类 - 合并模板管理和智能构建功能
|
||||
真正的Prompt类,支持模板管理和智能上下文构建
|
||||
"""
|
||||
|
||||
# 临时标记,作为类常量
|
||||
_TEMP_LEFT_BRACE = "__ESCAPED_LEFT_BRACE__"
|
||||
_TEMP_RIGHT_BRACE = "__ESCAPED_RIGHT_BRACE__"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
template: str,
|
||||
name: Optional[str] = None,
|
||||
parameters: Optional[PromptParameters] = None,
|
||||
should_register: bool = True
|
||||
):
|
||||
"""
|
||||
初始化统一提示词
|
||||
|
||||
Args:
|
||||
template: 提示词模板字符串
|
||||
name: 提示词名称
|
||||
parameters: 构建参数
|
||||
should_register: 是否自动注册到全局管理器
|
||||
"""
|
||||
self.template = template
|
||||
self.name = name
|
||||
self.parameters = parameters or PromptParameters()
|
||||
self.args = self._parse_template_args(template)
|
||||
self._formatted_result = ""
|
||||
|
||||
# 预处理模板中的转义花括号
|
||||
self._processed_template = self._process_escaped_braces(template)
|
||||
|
||||
# 自动注册
|
||||
if should_register and not global_prompt_manager._context._current_context:
|
||||
global_prompt_manager.register(self)
|
||||
|
||||
@staticmethod
|
||||
def _process_escaped_braces(template) -> str:
|
||||
"""处理模板中的转义花括号"""
|
||||
if isinstance(template, list):
|
||||
template = "\n".join(str(item) for item in template)
|
||||
elif not isinstance(template, str):
|
||||
template = str(template)
|
||||
|
||||
return template.replace("\\{", Prompt._TEMP_LEFT_BRACE).replace("\\}", Prompt._TEMP_RIGHT_BRACE)
|
||||
|
||||
@staticmethod
|
||||
def _restore_escaped_braces(template: str) -> str:
|
||||
"""将临时标记还原为实际的花括号字符"""
|
||||
return template.replace(Prompt._TEMP_LEFT_BRACE, "{").replace(Prompt._TEMP_RIGHT_BRACE, "}")
|
||||
|
||||
def _parse_template_args(self, template: str) -> List[str]:
|
||||
"""解析模板参数"""
|
||||
template_args = []
|
||||
processed_template = self._process_escaped_braces(template)
|
||||
result = re.findall(r"\{(.*?)}", processed_template)
|
||||
for expr in result:
|
||||
if expr and expr not in template_args:
|
||||
template_args.append(expr)
|
||||
return template_args
|
||||
|
||||
async def build(self) -> str:
|
||||
"""
|
||||
构建完整的提示词,包含智能上下文
|
||||
|
||||
Returns:
|
||||
str: 构建完成的提示词文本
|
||||
"""
|
||||
# 参数验证
|
||||
errors = self.parameters.validate()
|
||||
if errors:
|
||||
logger.error(f"参数验证失败: {', '.join(errors)}")
|
||||
raise ValueError(f"参数验证失败: {', '.join(errors)}")
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
# 构建上下文数据
|
||||
context_data = await self._build_context_data()
|
||||
|
||||
# 格式化模板
|
||||
result = await self._format_with_context(context_data)
|
||||
|
||||
total_time = time.time() - start_time
|
||||
logger.debug(f"Prompt构建完成,模式: {self.parameters.prompt_mode}, 耗时: {total_time:.2f}s")
|
||||
|
||||
self._formatted_result = result
|
||||
return result
|
||||
|
||||
except asyncio.TimeoutError as e:
|
||||
logger.error(f"构建Prompt超时: {e}")
|
||||
raise TimeoutError(f"构建Prompt超时: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"构建Prompt失败: {e}")
|
||||
raise RuntimeError(f"构建Prompt失败: {e}")
|
||||
|
||||
async def _build_context_data(self) -> Dict[str, Any]:
|
||||
"""构建智能上下文数据"""
|
||||
# 并行执行所有构建任务
|
||||
start_time = time.time()
|
||||
timing_logs = {}
|
||||
|
||||
try:
|
||||
# 准备构建任务
|
||||
tasks = []
|
||||
task_names = []
|
||||
|
||||
# 初始化预构建参数
|
||||
pre_built_params = {}
|
||||
if self.parameters.expression_habits_block:
|
||||
pre_built_params["expression_habits_block"] = self.parameters.expression_habits_block
|
||||
if self.parameters.relation_info_block:
|
||||
pre_built_params["relation_info_block"] = self.parameters.relation_info_block
|
||||
if self.parameters.memory_block:
|
||||
pre_built_params["memory_block"] = self.parameters.memory_block
|
||||
if self.parameters.tool_info_block:
|
||||
pre_built_params["tool_info_block"] = self.parameters.tool_info_block
|
||||
if self.parameters.knowledge_prompt:
|
||||
pre_built_params["knowledge_prompt"] = self.parameters.knowledge_prompt
|
||||
if self.parameters.cross_context_block:
|
||||
pre_built_params["cross_context_block"] = self.parameters.cross_context_block
|
||||
|
||||
# 根据参数确定要构建的项
|
||||
if self.parameters.enable_expression and not pre_built_params.get("expression_habits_block"):
|
||||
tasks.append(self._build_expression_habits())
|
||||
task_names.append("expression_habits")
|
||||
|
||||
if self.parameters.enable_memory and not pre_built_params.get("memory_block"):
|
||||
tasks.append(self._build_memory_block())
|
||||
task_names.append("memory_block")
|
||||
|
||||
if self.parameters.enable_relation and not pre_built_params.get("relation_info_block"):
|
||||
tasks.append(self._build_relation_info())
|
||||
task_names.append("relation_info")
|
||||
|
||||
if self.parameters.enable_tool and not pre_built_params.get("tool_info_block"):
|
||||
tasks.append(self._build_tool_info())
|
||||
task_names.append("tool_info")
|
||||
|
||||
if self.parameters.enable_knowledge and not pre_built_params.get("knowledge_prompt"):
|
||||
tasks.append(self._build_knowledge_info())
|
||||
task_names.append("knowledge_info")
|
||||
|
||||
if self.parameters.enable_cross_context and not pre_built_params.get("cross_context_block"):
|
||||
tasks.append(self._build_cross_context())
|
||||
task_names.append("cross_context")
|
||||
|
||||
# 性能优化
|
||||
base_timeout = 10.0
|
||||
task_timeout = 2.0
|
||||
timeout_seconds = min(
|
||||
max(base_timeout, len(tasks) * task_timeout),
|
||||
30.0,
|
||||
)
|
||||
|
||||
max_concurrent_tasks = 5
|
||||
if len(tasks) > max_concurrent_tasks:
|
||||
results = []
|
||||
for i in range(0, len(tasks), max_concurrent_tasks):
|
||||
batch_tasks = tasks[i : i + max_concurrent_tasks]
|
||||
batch_names = task_names[i : i + max_concurrent_tasks]
|
||||
|
||||
batch_results = await asyncio.wait_for(
|
||||
asyncio.gather(*batch_tasks, return_exceptions=True), timeout=timeout_seconds
|
||||
)
|
||||
results.extend(batch_results)
|
||||
else:
|
||||
results = await asyncio.wait_for(
|
||||
asyncio.gather(*tasks, return_exceptions=True), timeout=timeout_seconds
|
||||
)
|
||||
|
||||
# 处理结果
|
||||
context_data = {}
|
||||
for i, result in enumerate(results):
|
||||
task_name = task_names[i] if i < len(task_names) else f"task_{i}"
|
||||
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"构建任务{task_name}失败: {str(result)}")
|
||||
elif isinstance(result, dict):
|
||||
context_data.update(result)
|
||||
|
||||
# 添加预构建的参数
|
||||
for key, value in pre_built_params.items():
|
||||
if value:
|
||||
context_data[key] = value
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"构建超时 ({timeout_seconds}s)")
|
||||
context_data = {}
|
||||
for key, value in pre_built_params.items():
|
||||
if value:
|
||||
context_data[key] = value
|
||||
|
||||
# 构建聊天历史
|
||||
if self.parameters.prompt_mode == "s4u":
|
||||
await self._build_s4u_chat_context(context_data)
|
||||
else:
|
||||
await self._build_normal_chat_context(context_data)
|
||||
|
||||
# 补充基础信息
|
||||
context_data.update({
|
||||
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt,
|
||||
"extra_info_block": self.parameters.extra_info_block,
|
||||
"time_block": self.parameters.time_block or f"当前时间:{time.strftime('%Y-%m-%d %H:%M:%S')}",
|
||||
"identity": self.parameters.identity_block,
|
||||
"schedule_block": self.parameters.schedule_block,
|
||||
"moderation_prompt": self.parameters.moderation_prompt_block,
|
||||
"reply_target_block": self.parameters.reply_target_block,
|
||||
"mood_state": self.parameters.mood_prompt,
|
||||
"action_descriptions": self.parameters.action_descriptions,
|
||||
})
|
||||
|
||||
total_time = time.time() - start_time
|
||||
logger.debug(f"上下文构建完成,总耗时: {total_time:.2f}s")
|
||||
|
||||
return context_data
|
||||
|
||||
async def _build_s4u_chat_context(self, context_data: Dict[str, Any]) -> None:
|
||||
"""构建S4U模式的聊天上下文"""
|
||||
if not self.parameters.message_list_before_now_long:
|
||||
return
|
||||
|
||||
core_dialogue, background_dialogue = await self._build_s4u_chat_history_prompts(
|
||||
self.parameters.message_list_before_now_long,
|
||||
self.parameters.target_user_info.get("user_id") if self.parameters.target_user_info else "",
|
||||
self.parameters.sender
|
||||
)
|
||||
|
||||
context_data["core_dialogue_prompt"] = core_dialogue
|
||||
context_data["background_dialogue_prompt"] = background_dialogue
|
||||
|
||||
async def _build_normal_chat_context(self, context_data: Dict[str, Any]) -> None:
|
||||
"""构建normal模式的聊天上下文"""
|
||||
if not self.parameters.chat_talking_prompt_short:
|
||||
return
|
||||
|
||||
context_data["chat_info"] = f"""群里的聊天内容:
|
||||
{self.parameters.chat_talking_prompt_short}"""
|
||||
|
||||
async def _build_s4u_chat_history_prompts(
|
||||
self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str
|
||||
) -> Tuple[str, str]:
|
||||
"""构建S4U风格的分离对话prompt"""
|
||||
# 实现逻辑与原有SmartPromptBuilder相同
|
||||
core_dialogue_list = []
|
||||
bot_id = str(global_config.bot.qq_account)
|
||||
|
||||
for msg_dict in message_list_before_now:
|
||||
try:
|
||||
msg_user_id = str(msg_dict.get("user_id"))
|
||||
reply_to = msg_dict.get("reply_to", "")
|
||||
platform, reply_to_user_id = Prompt.parse_reply_target(reply_to)
|
||||
if (msg_user_id == bot_id and reply_to_user_id == target_user_id) or msg_user_id == target_user_id:
|
||||
core_dialogue_list.append(msg_dict)
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息记录时出错: {msg_dict}, 错误: {e}")
|
||||
|
||||
# 构建背景对话 prompt
|
||||
all_dialogue_prompt = ""
|
||||
if message_list_before_now:
|
||||
latest_25_msgs = message_list_before_now[-int(global_config.chat.max_context_size) :]
|
||||
all_dialogue_prompt_str = build_readable_messages(
|
||||
latest_25_msgs,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="normal",
|
||||
truncate=True,
|
||||
)
|
||||
all_dialogue_prompt = f"所有用户的发言:\n{all_dialogue_prompt_str}"
|
||||
|
||||
# 构建核心对话 prompt
|
||||
core_dialogue_prompt = ""
|
||||
if core_dialogue_list:
|
||||
latest_5_messages = core_dialogue_list[-5:] if len(core_dialogue_list) >= 5 else core_dialogue_list
|
||||
has_bot_message = any(str(msg.get("user_id")) == bot_id for msg in latest_5_messages)
|
||||
|
||||
if not has_bot_message:
|
||||
core_dialogue_prompt = ""
|
||||
else:
|
||||
core_dialogue_list = core_dialogue_list[-int(global_config.chat.max_context_size * 2) :]
|
||||
|
||||
core_dialogue_prompt_str = build_readable_messages(
|
||||
core_dialogue_list,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=0.0,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
core_dialogue_prompt = f"""--------------------------------
|
||||
这是你和{sender}的对话,你们正在交流中:
|
||||
{core_dialogue_prompt_str}
|
||||
--------------------------------
|
||||
"""
|
||||
|
||||
return core_dialogue_prompt, all_dialogue_prompt
|
||||
|
||||
async def _build_expression_habits(self) -> Dict[str, Any]:
|
||||
"""构建表达习惯"""
|
||||
# 简化的实现,完整实现需要导入相关模块
|
||||
return {"expression_habits_block": ""}
|
||||
|
||||
async def _build_memory_block(self) -> Dict[str, Any]:
|
||||
"""构建记忆块"""
|
||||
# 简化的实现
|
||||
return {"memory_block": ""}
|
||||
|
||||
async def _build_relation_info(self) -> Dict[str, Any]:
|
||||
"""构建关系信息"""
|
||||
try:
|
||||
relation_info = await Prompt.build_relation_info(self.parameters.chat_id, self.parameters.reply_to)
|
||||
return {"relation_info_block": relation_info}
|
||||
except Exception as e:
|
||||
logger.error(f"构建关系信息失败: {e}")
|
||||
return {"relation_info_block": ""}
|
||||
|
||||
async def _build_tool_info(self) -> Dict[str, Any]:
|
||||
"""构建工具信息"""
|
||||
# 简化的实现
|
||||
return {"tool_info_block": ""}
|
||||
|
||||
async def _build_knowledge_info(self) -> Dict[str, Any]:
|
||||
"""构建知识信息"""
|
||||
# 简化的实现
|
||||
return {"knowledge_prompt": ""}
|
||||
|
||||
async def _build_cross_context(self) -> Dict[str, Any]:
|
||||
"""构建跨群上下文"""
|
||||
try:
|
||||
cross_context = await Prompt.build_cross_context(
|
||||
self.parameters.chat_id, self.parameters.prompt_mode, self.parameters.target_user_info
|
||||
)
|
||||
return {"cross_context_block": cross_context}
|
||||
except Exception as e:
|
||||
logger.error(f"构建跨群上下文失败: {e}")
|
||||
return {"cross_context_block": ""}
|
||||
|
||||
async def _format_with_context(self, context_data: Dict[str, Any]) -> str:
|
||||
"""使用上下文数据格式化模板"""
|
||||
if self.parameters.prompt_mode == "s4u":
|
||||
params = self._prepare_s4u_params(context_data)
|
||||
elif self.parameters.prompt_mode == "normal":
|
||||
params = self._prepare_normal_params(context_data)
|
||||
else:
|
||||
params = self._prepare_default_params(context_data)
|
||||
|
||||
return await global_prompt_manager.format_prompt(self.name, **params) if self.name else self.format(**params)
|
||||
|
||||
def _prepare_s4u_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""准备S4U模式的参数"""
|
||||
return {
|
||||
**context_data,
|
||||
"expression_habits_block": context_data.get("expression_habits_block", ""),
|
||||
"tool_info_block": context_data.get("tool_info_block", ""),
|
||||
"knowledge_prompt": context_data.get("knowledge_prompt", ""),
|
||||
"memory_block": context_data.get("memory_block", ""),
|
||||
"relation_info_block": context_data.get("relation_info_block", ""),
|
||||
"extra_info_block": self.parameters.extra_info_block or context_data.get("extra_info_block", ""),
|
||||
"cross_context_block": context_data.get("cross_context_block", ""),
|
||||
"identity": self.parameters.identity_block or context_data.get("identity", ""),
|
||||
"action_descriptions": self.parameters.action_descriptions or context_data.get("action_descriptions", ""),
|
||||
"sender_name": self.parameters.sender,
|
||||
"mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""),
|
||||
"background_dialogue_prompt": context_data.get("background_dialogue_prompt", ""),
|
||||
"time_block": context_data.get("time_block", ""),
|
||||
"core_dialogue_prompt": context_data.get("core_dialogue_prompt", ""),
|
||||
"reply_target_block": context_data.get("reply_target_block", ""),
|
||||
"reply_style": global_config.personality.reply_style,
|
||||
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt or context_data.get("keywords_reaction_prompt", ""),
|
||||
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
|
||||
}
|
||||
|
||||
def _prepare_normal_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""准备Normal模式的参数"""
|
||||
return {
|
||||
**context_data,
|
||||
"expression_habits_block": context_data.get("expression_habits_block", ""),
|
||||
"tool_info_block": context_data.get("tool_info_block", ""),
|
||||
"knowledge_prompt": context_data.get("knowledge_prompt", ""),
|
||||
"memory_block": context_data.get("memory_block", ""),
|
||||
"relation_info_block": context_data.get("relation_info_block", ""),
|
||||
"extra_info_block": self.parameters.extra_info_block or context_data.get("extra_info_block", ""),
|
||||
"cross_context_block": context_data.get("cross_context_block", ""),
|
||||
"identity": self.parameters.identity_block or context_data.get("identity", ""),
|
||||
"action_descriptions": self.parameters.action_descriptions or context_data.get("action_descriptions", ""),
|
||||
"schedule_block": self.parameters.schedule_block or context_data.get("schedule_block", ""),
|
||||
"time_block": context_data.get("time_block", ""),
|
||||
"chat_info": context_data.get("chat_info", ""),
|
||||
"reply_target_block": context_data.get("reply_target_block", ""),
|
||||
"config_expression_style": global_config.personality.reply_style,
|
||||
"mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""),
|
||||
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt or context_data.get("keywords_reaction_prompt", ""),
|
||||
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
|
||||
}
|
||||
|
||||
def _prepare_default_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""准备默认模式的参数"""
|
||||
return {
|
||||
"expression_habits_block": context_data.get("expression_habits_block", ""),
|
||||
"relation_info_block": context_data.get("relation_info_block", ""),
|
||||
"chat_target": "",
|
||||
"time_block": context_data.get("time_block", ""),
|
||||
"chat_info": context_data.get("chat_info", ""),
|
||||
"identity": self.parameters.identity_block or context_data.get("identity", ""),
|
||||
"chat_target_2": "",
|
||||
"reply_target_block": context_data.get("reply_target_block", ""),
|
||||
"raw_reply": self.parameters.target,
|
||||
"reason": "",
|
||||
"mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""),
|
||||
"reply_style": global_config.personality.reply_style,
|
||||
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt or context_data.get("keywords_reaction_prompt", ""),
|
||||
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
|
||||
}
|
||||
|
||||
def format(self, *args, **kwargs) -> str:
|
||||
"""格式化模板,支持位置参数和关键字参数"""
|
||||
try:
|
||||
# 先用位置参数格式化
|
||||
if args:
|
||||
formatted_args = {}
|
||||
for i in range(len(args)):
|
||||
if i < len(self.args):
|
||||
formatted_args[self.args[i]] = args[i]
|
||||
processed_template = self._processed_template.format(**formatted_args)
|
||||
else:
|
||||
processed_template = self._processed_template
|
||||
|
||||
# 再用关键字参数格式化
|
||||
if kwargs:
|
||||
processed_template = processed_template.format(**kwargs)
|
||||
|
||||
# 将临时标记还原为实际的花括号
|
||||
result = self._restore_escaped_braces(processed_template)
|
||||
return result
|
||||
except (IndexError, KeyError) as e:
|
||||
raise ValueError(f"格式化模板失败: {self.template}, args={args}, kwargs={kwargs} {str(e)}") from e
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""返回格式化后的结果或原始模板"""
|
||||
return self._formatted_result if self._formatted_result else self.template
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""返回提示词的表示形式"""
|
||||
return f"Prompt(template='{self.template}', name='{self.name}')"
|
||||
|
||||
# =============================================================================
|
||||
# PromptUtils功能迁移 - 静态工具方法
|
||||
# 这些方法原来在PromptUtils类中,现在作为Prompt类的静态方法
|
||||
# 解决循环导入问题
|
||||
# =============================================================================
|
||||
|
||||
@staticmethod
|
||||
def parse_reply_target(target_message: str) -> Tuple[str, str]:
|
||||
"""
|
||||
解析回复目标消息 - 统一实现
|
||||
|
||||
Args:
|
||||
target_message: 目标消息,格式为 "发送者:消息内容" 或 "发送者:消息内容"
|
||||
|
||||
Returns:
|
||||
Tuple[str, str]: (发送者名称, 消息内容)
|
||||
"""
|
||||
sender = ""
|
||||
target = ""
|
||||
|
||||
# 添加None检查,防止NoneType错误
|
||||
if target_message is None:
|
||||
return sender, target
|
||||
|
||||
if ":" in target_message or ":" in target_message:
|
||||
# 使用正则表达式匹配中文或英文冒号
|
||||
parts = re.split(pattern=r"[::]", string=target_message, maxsplit=1)
|
||||
if len(parts) == 2:
|
||||
sender = parts[0].strip()
|
||||
target = parts[1].strip()
|
||||
return sender, target
|
||||
|
||||
@staticmethod
|
||||
async def build_relation_info(chat_id: str, reply_to: str) -> str:
|
||||
"""
|
||||
构建关系信息 - 统一实现
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
reply_to: 回复目标字符串
|
||||
|
||||
Returns:
|
||||
str: 关系信息字符串
|
||||
"""
|
||||
if not global_config.relationship.enable_relationship:
|
||||
return ""
|
||||
|
||||
from src.person_info.relationship_fetcher import relationship_fetcher_manager
|
||||
|
||||
relationship_fetcher = relationship_fetcher_manager.get_fetcher(chat_id)
|
||||
|
||||
if not reply_to:
|
||||
return ""
|
||||
sender, text = Prompt.parse_reply_target(reply_to)
|
||||
if not sender or not text:
|
||||
return ""
|
||||
|
||||
# 获取用户ID
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_id = person_info_manager.get_person_id_by_person_name(sender)
|
||||
if not person_id:
|
||||
logger.warning(f"未找到用户 {sender} 的ID,跳过信息提取")
|
||||
return f"你完全不认识{sender},不理解ta的相关信息。"
|
||||
|
||||
return await relationship_fetcher.build_relation_info(person_id, points_num=5)
|
||||
|
||||
@staticmethod
|
||||
async def build_cross_context(
|
||||
chat_id: str, prompt_mode: str, target_user_info: Optional[Dict[str, Any]]
|
||||
) -> str:
|
||||
"""
|
||||
构建跨群聊上下文 - 统一实现
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
prompt_mode: 当前提示词模式
|
||||
target_user_info: 目标用户信息
|
||||
|
||||
Returns:
|
||||
str: 跨群聊上下文字符串
|
||||
"""
|
||||
if not global_config.cross_context.enable:
|
||||
return ""
|
||||
|
||||
from src.plugin_system.apis import cross_context_api
|
||||
|
||||
other_chat_raw_ids = cross_context_api.get_context_groups(chat_id)
|
||||
if not other_chat_raw_ids:
|
||||
return ""
|
||||
|
||||
chat_stream = get_chat_manager().get_stream(chat_id)
|
||||
if not chat_stream:
|
||||
return ""
|
||||
|
||||
if prompt_mode == "normal":
|
||||
return await cross_context_api.build_cross_context_normal(chat_stream, other_chat_raw_ids)
|
||||
elif prompt_mode == "s4u":
|
||||
return await cross_context_api.build_cross_context_s4u(chat_stream, other_chat_raw_ids, target_user_info)
|
||||
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def parse_reply_target_id(reply_to: str) -> str:
|
||||
"""
|
||||
解析回复目标中的用户ID
|
||||
|
||||
Args:
|
||||
reply_to: 回复目标字符串
|
||||
|
||||
Returns:
|
||||
str: 用户ID
|
||||
"""
|
||||
if not reply_to:
|
||||
return ""
|
||||
|
||||
# 复用parse_reply_target方法的逻辑
|
||||
sender, _ = Prompt.parse_reply_target(reply_to)
|
||||
if not sender:
|
||||
return ""
|
||||
|
||||
# 获取用户ID
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_id = person_info_manager.get_person_id_by_person_name(sender)
|
||||
if person_id:
|
||||
user_id = person_info_manager.get_value_sync(person_id, "user_id")
|
||||
return str(user_id) if user_id else ""
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
# 工厂函数
|
||||
def create_prompt(
|
||||
template: str,
|
||||
name: Optional[str] = None,
|
||||
parameters: Optional[PromptParameters] = None,
|
||||
**kwargs
|
||||
) -> Prompt:
|
||||
"""快速创建Prompt实例的工厂函数"""
|
||||
if parameters is None:
|
||||
parameters = PromptParameters(**kwargs)
|
||||
return Prompt(template, name, parameters)
|
||||
|
||||
|
||||
async def create_prompt_async(
|
||||
template: str,
|
||||
name: Optional[str] = None,
|
||||
parameters: Optional[PromptParameters] = None,
|
||||
**kwargs
|
||||
) -> Prompt:
|
||||
"""异步创建Prompt实例"""
|
||||
prompt = create_prompt(template, name, parameters, **kwargs)
|
||||
if global_prompt_manager._context._current_context:
|
||||
await global_prompt_manager._context.register_async(prompt)
|
||||
return prompt
|
||||
|
||||
@@ -1,299 +0,0 @@
|
||||
import re
|
||||
import asyncio
|
||||
import contextvars
|
||||
|
||||
from rich.traceback import install
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Dict, Any, Optional, List, Union
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("prompt_build")
|
||||
|
||||
|
||||
class PromptContext:
|
||||
def __init__(self):
|
||||
self._context_prompts: Dict[str, Dict[str, "Prompt"]] = {}
|
||||
# 使用contextvars创建协程上下文变量
|
||||
self._current_context_var = contextvars.ContextVar("current_context", default=None)
|
||||
self._context_lock = asyncio.Lock() # 保留锁用于其他操作
|
||||
|
||||
@property
|
||||
def _current_context(self) -> Optional[str]:
|
||||
"""获取当前协程的上下文ID"""
|
||||
return self._current_context_var.get()
|
||||
|
||||
@_current_context.setter
|
||||
def _current_context(self, value: Optional[str]):
|
||||
"""设置当前协程的上下文ID"""
|
||||
self._current_context_var.set(value) # type: ignore
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_scope(self, context_id: Optional[str] = None):
|
||||
# sourcery skip: hoist-statement-from-if, use-contextlib-suppress
|
||||
"""创建一个异步的临时提示模板作用域"""
|
||||
# 保存当前上下文并设置新上下文
|
||||
if context_id is not None:
|
||||
try:
|
||||
# 添加超时保护,避免长时间等待锁
|
||||
await asyncio.wait_for(self._context_lock.acquire(), timeout=5.0)
|
||||
try:
|
||||
if context_id not in self._context_prompts:
|
||||
self._context_prompts[context_id] = {}
|
||||
finally:
|
||||
self._context_lock.release()
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"获取上下文锁超时,context_id: {context_id}")
|
||||
# 超时时直接进入,不设置上下文
|
||||
context_id = None
|
||||
|
||||
# 保存当前协程的上下文值,不影响其他协程
|
||||
previous_context = self._current_context
|
||||
# 设置当前协程的新上下文
|
||||
token = self._current_context_var.set(context_id) if context_id else None # type: ignore
|
||||
else:
|
||||
# 如果没有提供新上下文,保持当前上下文不变
|
||||
previous_context = self._current_context
|
||||
token = None
|
||||
|
||||
try:
|
||||
yield self
|
||||
finally:
|
||||
# 恢复之前的上下文,添加异常保护
|
||||
if context_id is not None and token is not None:
|
||||
try:
|
||||
self._current_context_var.reset(token)
|
||||
except Exception as e:
|
||||
logger.warning(f"恢复上下文时出错: {e}")
|
||||
# 如果reset失败,尝试直接设置
|
||||
try:
|
||||
self._current_context = previous_context
|
||||
except Exception:
|
||||
...
|
||||
# 静默忽略恢复失败
|
||||
|
||||
async def get_prompt_async(self, name: str) -> Optional["Prompt"]:
|
||||
"""异步获取当前作用域中的提示模板"""
|
||||
async with self._context_lock:
|
||||
current_context = self._current_context
|
||||
logger.debug(f"获取提示词: {name} 当前上下文: {current_context}")
|
||||
if (
|
||||
current_context
|
||||
and current_context in self._context_prompts
|
||||
and name in self._context_prompts[current_context]
|
||||
):
|
||||
return self._context_prompts[current_context][name]
|
||||
return None
|
||||
|
||||
async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None:
|
||||
"""异步注册提示模板到指定作用域"""
|
||||
async with self._context_lock:
|
||||
if target_context := context_id or self._current_context:
|
||||
if prompt.name:
|
||||
self._context_prompts.setdefault(target_context, {})[prompt.name] = prompt
|
||||
|
||||
|
||||
class PromptManager:
|
||||
def __init__(self):
|
||||
self._prompts = {}
|
||||
self._counter = 0
|
||||
self._context = PromptContext()
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_message_scope(self, message_id: Optional[str] = None):
|
||||
"""为消息处理创建异步临时作用域,支持 message_id 为 None 的情况"""
|
||||
async with self._context.async_scope(message_id):
|
||||
yield self
|
||||
|
||||
async def get_prompt_async(self, name: str) -> "Prompt":
|
||||
# 首先尝试从当前上下文获取
|
||||
context_prompt = await self._context.get_prompt_async(name)
|
||||
if context_prompt is not None:
|
||||
logger.debug(f"从上下文中获取提示词: {name} {context_prompt}")
|
||||
return context_prompt
|
||||
# 如果上下文中不存在,则使用全局提示模板
|
||||
async with self._lock:
|
||||
# logger.debug(f"从全局获取提示词: {name}")
|
||||
if name not in self._prompts:
|
||||
raise KeyError(f"Prompt '{name}' not found")
|
||||
return self._prompts[name]
|
||||
|
||||
def generate_name(self, template: str) -> str:
|
||||
"""为未命名的prompt生成名称"""
|
||||
self._counter += 1
|
||||
return f"prompt_{self._counter}"
|
||||
|
||||
def register(self, prompt: "Prompt") -> None:
|
||||
"""注册一个prompt"""
|
||||
if not prompt.name:
|
||||
prompt.name = self.generate_name(prompt.template)
|
||||
self._prompts[prompt.name] = prompt
|
||||
|
||||
def add_prompt(self, name: str, fstr: str) -> "Prompt":
|
||||
prompt = Prompt(fstr, name=name)
|
||||
if prompt.name:
|
||||
self._prompts[prompt.name] = prompt
|
||||
return prompt
|
||||
|
||||
async def format_prompt(self, name: str, **kwargs) -> str:
|
||||
# 获取当前提示词
|
||||
prompt = await self.get_prompt_async(name)
|
||||
# 获取基本格式化结果
|
||||
result = prompt.format(**kwargs)
|
||||
return result
|
||||
|
||||
|
||||
# 全局单例
|
||||
global_prompt_manager = PromptManager()
|
||||
|
||||
|
||||
class Prompt(str):
|
||||
template: str
|
||||
name: Optional[str]
|
||||
args: List[str]
|
||||
_args: List[Any]
|
||||
_kwargs: Dict[str, Any]
|
||||
# 临时标记,作为类常量
|
||||
_TEMP_LEFT_BRACE = "__ESCAPED_LEFT_BRACE__"
|
||||
_TEMP_RIGHT_BRACE = "__ESCAPED_RIGHT_BRACE__"
|
||||
|
||||
@staticmethod
|
||||
def _process_escaped_braces(template) -> str:
|
||||
"""处理模板中的转义花括号,将 \\{ 和 \\} 替换为临时标记""" # type: ignore
|
||||
# 如果传入的是列表,将其转换为字符串
|
||||
if isinstance(template, list):
|
||||
template = "\n".join(str(item) for item in template)
|
||||
elif not isinstance(template, str):
|
||||
template = str(template)
|
||||
|
||||
return template.replace("\\{", Prompt._TEMP_LEFT_BRACE).replace("\\}", Prompt._TEMP_RIGHT_BRACE)
|
||||
|
||||
@staticmethod
|
||||
def _restore_escaped_braces(template: str) -> str:
|
||||
"""将临时标记还原为实际的花括号字符"""
|
||||
return template.replace(Prompt._TEMP_LEFT_BRACE, "{").replace(Prompt._TEMP_RIGHT_BRACE, "}")
|
||||
|
||||
def __new__(
|
||||
cls, fstr, name: Optional[str] = None, args: Optional[Union[List[Any], tuple[Any, ...]]] = None, **kwargs
|
||||
):
|
||||
# 如果传入的是元组,转换为列表
|
||||
if isinstance(args, tuple):
|
||||
args = list(args)
|
||||
should_register = kwargs.pop("_should_register", True)
|
||||
|
||||
# 预处理模板中的转义花括号
|
||||
processed_fstr = cls._process_escaped_braces(fstr)
|
||||
|
||||
# 解析模板
|
||||
template_args = []
|
||||
result = re.findall(r"\{(.*?)}", processed_fstr)
|
||||
for expr in result:
|
||||
if expr and expr not in template_args:
|
||||
template_args.append(expr)
|
||||
|
||||
# 如果提供了初始参数,立即格式化
|
||||
if kwargs or args:
|
||||
formatted = cls._format_template(fstr, args=args, kwargs=kwargs)
|
||||
obj = super().__new__(cls, formatted)
|
||||
else:
|
||||
obj = super().__new__(cls, "")
|
||||
|
||||
obj.template = fstr
|
||||
obj.name = name
|
||||
obj.args = template_args
|
||||
obj._args = args or []
|
||||
obj._kwargs = kwargs
|
||||
|
||||
# 修改自动注册逻辑
|
||||
if should_register and not global_prompt_manager._context._current_context:
|
||||
global_prompt_manager.register(obj)
|
||||
return obj
|
||||
|
||||
@classmethod
|
||||
async def create_async(
|
||||
cls, fstr, name: Optional[str] = None, args: Optional[Union[List[Any], tuple[Any, ...]]] = None, **kwargs
|
||||
):
|
||||
"""异步创建Prompt实例"""
|
||||
prompt = cls(fstr, name, args, **kwargs)
|
||||
if global_prompt_manager._context._current_context:
|
||||
await global_prompt_manager._context.register_async(prompt)
|
||||
return prompt
|
||||
|
||||
@classmethod
|
||||
def _format_template(
|
||||
cls, template, args: Optional[List[Any]] = None, kwargs: Optional[Dict[str, Any]] = None
|
||||
) -> str:
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
# 预处理模板中的转义花括号
|
||||
processed_template = cls._process_escaped_braces(template)
|
||||
|
||||
template_args = []
|
||||
result = re.findall(r"\{(.*?)}", processed_template)
|
||||
for expr in result:
|
||||
if expr and expr not in template_args:
|
||||
template_args.append(expr)
|
||||
formatted_args = {}
|
||||
formatted_kwargs = {}
|
||||
|
||||
# 处理位置参数
|
||||
if args:
|
||||
# print(len(template_args), len(args), template_args, args)
|
||||
for i in range(len(args)):
|
||||
if i < len(template_args):
|
||||
arg = args[i]
|
||||
if isinstance(arg, Prompt):
|
||||
formatted_args[template_args[i]] = arg.format(**kwargs)
|
||||
else:
|
||||
formatted_args[template_args[i]] = arg
|
||||
else:
|
||||
logger.error(
|
||||
f"构建提示词模板失败,解析到的参数列表{template_args},长度为{len(template_args)},输入的参数列表为{args},提示词模板为{template}"
|
||||
)
|
||||
raise ValueError("格式化模板失败")
|
||||
|
||||
# 处理关键字参数
|
||||
if kwargs:
|
||||
for key, value in kwargs.items():
|
||||
if isinstance(value, Prompt):
|
||||
remaining_kwargs = {k: v for k, v in kwargs.items() if k != key}
|
||||
formatted_kwargs[key] = value.format(**remaining_kwargs)
|
||||
else:
|
||||
formatted_kwargs[key] = value
|
||||
|
||||
try:
|
||||
# 先用位置参数格式化
|
||||
if args:
|
||||
processed_template = processed_template.format(**formatted_args)
|
||||
# 再用关键字参数格式化
|
||||
if kwargs:
|
||||
processed_template = processed_template.format(**formatted_kwargs)
|
||||
|
||||
# 将临时标记还原为实际的花括号
|
||||
result = cls._restore_escaped_braces(processed_template)
|
||||
return result
|
||||
except (IndexError, KeyError) as e:
|
||||
raise ValueError(
|
||||
f"格式化模板失败: {template}, args={formatted_args}, kwargs={formatted_kwargs} {str(e)}"
|
||||
) from e
|
||||
|
||||
def format(self, *args, **kwargs) -> "str":
|
||||
"""支持位置参数和关键字参数的格式化,使用"""
|
||||
ret = type(self)(
|
||||
self.template,
|
||||
self.name,
|
||||
args=list(args) if args else self._args,
|
||||
_should_register=False,
|
||||
**kwargs or self._kwargs,
|
||||
)
|
||||
# print(f"prompt build result: {ret} name: {ret.name} ")
|
||||
return str(ret)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return super().__str__() if self._kwargs or self._args else self.template
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Prompt(template='{self.template}', name='{self.name}')"
|
||||
@@ -1,156 +0,0 @@
|
||||
"""
|
||||
智能提示词参数模块 - 优化参数结构
|
||||
简化SmartPromptParameters,减少冗余和重复
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Any, Optional, List, Literal
|
||||
|
||||
|
||||
@dataclass
|
||||
class SmartPromptParameters:
|
||||
"""简化的智能提示词参数系统"""
|
||||
|
||||
# 基础参数
|
||||
chat_id: str = ""
|
||||
is_group_chat: bool = False
|
||||
sender: str = ""
|
||||
target: str = ""
|
||||
reply_to: str = ""
|
||||
extra_info: str = ""
|
||||
prompt_mode: Literal["s4u", "normal", "minimal"] = "s4u"
|
||||
|
||||
# 功能开关
|
||||
enable_tool: bool = True
|
||||
enable_memory: bool = True
|
||||
enable_expression: bool = True
|
||||
enable_relation: bool = True
|
||||
enable_cross_context: bool = True
|
||||
enable_knowledge: bool = True
|
||||
|
||||
# 性能控制
|
||||
max_context_messages: int = 50
|
||||
|
||||
# 调试选项
|
||||
debug_mode: bool = False
|
||||
|
||||
# 聊天历史和上下文
|
||||
chat_target_info: Optional[Dict[str, Any]] = None
|
||||
message_list_before_now_long: List[Dict[str, Any]] = field(default_factory=list)
|
||||
message_list_before_short: List[Dict[str, Any]] = field(default_factory=list)
|
||||
chat_talking_prompt_short: str = ""
|
||||
target_user_info: Optional[Dict[str, Any]] = None
|
||||
|
||||
# 已构建的内容块
|
||||
expression_habits_block: str = ""
|
||||
relation_info_block: str = ""
|
||||
memory_block: str = ""
|
||||
tool_info_block: str = ""
|
||||
knowledge_prompt: str = ""
|
||||
cross_context_block: str = ""
|
||||
|
||||
# 其他内容块
|
||||
keywords_reaction_prompt: str = ""
|
||||
extra_info_block: str = ""
|
||||
time_block: str = ""
|
||||
identity_block: str = ""
|
||||
schedule_block: str = ""
|
||||
moderation_prompt_block: str = ""
|
||||
reply_target_block: str = ""
|
||||
mood_prompt: str = ""
|
||||
action_descriptions: str = ""
|
||||
|
||||
# 可用动作信息
|
||||
available_actions: Optional[Dict[str, Any]] = None
|
||||
|
||||
def validate(self) -> List[str]:
|
||||
"""统一的参数验证"""
|
||||
errors = []
|
||||
if not self.chat_id:
|
||||
errors.append("chat_id不能为空")
|
||||
if self.prompt_mode not in ["s4u", "normal", "minimal"]:
|
||||
errors.append("prompt_mode必须是's4u'、'normal'或'minimal'")
|
||||
if self.max_context_messages <= 0:
|
||||
errors.append("max_context_messages必须大于0")
|
||||
return errors
|
||||
|
||||
def get_needed_build_tasks(self) -> List[str]:
|
||||
"""获取需要执行的任务列表"""
|
||||
tasks = []
|
||||
|
||||
if self.enable_expression and not self.expression_habits_block:
|
||||
tasks.append("expression_habits")
|
||||
|
||||
if self.enable_memory and not self.memory_block:
|
||||
tasks.append("memory_block")
|
||||
|
||||
if self.enable_relation and not self.relation_info_block:
|
||||
tasks.append("relation_info")
|
||||
|
||||
if self.enable_tool and not self.tool_info_block:
|
||||
tasks.append("tool_info")
|
||||
|
||||
if self.enable_knowledge and not self.knowledge_prompt:
|
||||
tasks.append("knowledge_info")
|
||||
|
||||
if self.enable_cross_context and not self.cross_context_block:
|
||||
tasks.append("cross_context")
|
||||
|
||||
return tasks
|
||||
|
||||
@classmethod
|
||||
def from_legacy_params(cls, **kwargs) -> "SmartPromptParameters":
|
||||
"""
|
||||
从旧版参数创建新参数对象
|
||||
|
||||
Args:
|
||||
**kwargs: 旧版参数
|
||||
|
||||
Returns:
|
||||
SmartPromptParameters: 新参数对象
|
||||
"""
|
||||
return cls(
|
||||
# 基础参数
|
||||
chat_id=kwargs.get("chat_id", ""),
|
||||
is_group_chat=kwargs.get("is_group_chat", False),
|
||||
sender=kwargs.get("sender", ""),
|
||||
target=kwargs.get("target", ""),
|
||||
reply_to=kwargs.get("reply_to", ""),
|
||||
extra_info=kwargs.get("extra_info", ""),
|
||||
prompt_mode=kwargs.get("current_prompt_mode", "s4u"),
|
||||
# 功能开关
|
||||
enable_tool=kwargs.get("enable_tool", True),
|
||||
enable_memory=kwargs.get("enable_memory", True),
|
||||
enable_expression=kwargs.get("enable_expression", True),
|
||||
enable_relation=kwargs.get("enable_relation", True),
|
||||
enable_cross_context=kwargs.get("enable_cross_context", True),
|
||||
enable_knowledge=kwargs.get("enable_knowledge", True),
|
||||
# 性能控制
|
||||
max_context_messages=kwargs.get("max_context_messages", 50),
|
||||
debug_mode=kwargs.get("debug_mode", False),
|
||||
# 聊天历史和上下文
|
||||
chat_target_info=kwargs.get("chat_target_info"),
|
||||
message_list_before_now_long=kwargs.get("message_list_before_now_long", []),
|
||||
message_list_before_short=kwargs.get("message_list_before_short", []),
|
||||
chat_talking_prompt_short=kwargs.get("chat_talking_prompt_short", ""),
|
||||
target_user_info=kwargs.get("target_user_info"),
|
||||
# 已构建的内容块
|
||||
expression_habits_block=kwargs.get("expression_habits_block", ""),
|
||||
relation_info_block=kwargs.get("relation_info", ""),
|
||||
memory_block=kwargs.get("memory_block", ""),
|
||||
tool_info_block=kwargs.get("tool_info", ""),
|
||||
knowledge_prompt=kwargs.get("knowledge_prompt", ""),
|
||||
cross_context_block=kwargs.get("cross_context_block", ""),
|
||||
# 其他内容块
|
||||
keywords_reaction_prompt=kwargs.get("keywords_reaction_prompt", ""),
|
||||
extra_info_block=kwargs.get("extra_info_block", ""),
|
||||
time_block=kwargs.get("time_block", ""),
|
||||
identity_block=kwargs.get("identity_block", ""),
|
||||
schedule_block=kwargs.get("schedule_block", ""),
|
||||
moderation_prompt_block=kwargs.get("moderation_prompt_block", ""),
|
||||
reply_target_block=kwargs.get("reply_target_block", ""),
|
||||
mood_prompt=kwargs.get("mood_prompt", ""),
|
||||
action_descriptions=kwargs.get("action_descriptions", ""),
|
||||
# 可用动作信息
|
||||
available_actions=kwargs.get("available_actions", None),
|
||||
)
|
||||
@@ -1,132 +0,0 @@
|
||||
"""
|
||||
共享提示词工具模块 - 消除重复代码
|
||||
提供统一的工具函数供DefaultReplyer和SmartPrompt使用
|
||||
"""
|
||||
|
||||
import re
|
||||
import time
|
||||
from typing import Dict, Any, Optional, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
from src.plugin_system.apis import cross_context_api
|
||||
logger = get_logger("prompt_utils")
|
||||
|
||||
|
||||
class PromptUtils:
|
||||
"""提示词工具类 - 提供共享功能,移除缓存相关功能和依赖检查"""
|
||||
|
||||
@staticmethod
|
||||
def parse_reply_target(target_message: str) -> Tuple[str, str]:
|
||||
"""
|
||||
解析回复目标消息 - 统一实现
|
||||
|
||||
Args:
|
||||
target_message: 目标消息,格式为 "发送者:消息内容" 或 "发送者:消息内容"
|
||||
|
||||
Returns:
|
||||
Tuple[str, str]: (发送者名称, 消息内容)
|
||||
"""
|
||||
sender = ""
|
||||
target = ""
|
||||
|
||||
# 添加None检查,防止NoneType错误
|
||||
if target_message is None:
|
||||
return sender, target
|
||||
|
||||
if ":" in target_message or ":" in target_message:
|
||||
# 使用正则表达式匹配中文或英文冒号
|
||||
parts = re.split(pattern=r"[::]", string=target_message, maxsplit=1)
|
||||
if len(parts) == 2:
|
||||
sender = parts[0].strip()
|
||||
target = parts[1].strip()
|
||||
return sender, target
|
||||
|
||||
@staticmethod
|
||||
async def build_relation_info(chat_id: str, reply_to: str) -> str:
|
||||
"""
|
||||
构建关系信息 - 统一实现
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
reply_to: 回复目标字符串
|
||||
|
||||
Returns:
|
||||
str: 关系信息字符串
|
||||
"""
|
||||
if not global_config.relationship.enable_relationship:
|
||||
return ""
|
||||
|
||||
from src.person_info.relationship_fetcher import relationship_fetcher_manager
|
||||
|
||||
relationship_fetcher = relationship_fetcher_manager.get_fetcher(chat_id)
|
||||
|
||||
if not reply_to:
|
||||
return ""
|
||||
sender, text = PromptUtils.parse_reply_target(reply_to)
|
||||
if not sender or not text:
|
||||
return ""
|
||||
|
||||
# 获取用户ID
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_id = person_info_manager.get_person_id_by_person_name(sender)
|
||||
if not person_id:
|
||||
logger.warning(f"未找到用户 {sender} 的ID,跳过信息提取")
|
||||
return f"你完全不认识{sender},不理解ta的相关信息。"
|
||||
|
||||
return await relationship_fetcher.build_relation_info(person_id, points_num=5)
|
||||
|
||||
@staticmethod
|
||||
async def build_cross_context(
|
||||
chat_id: str, target_user_info: Optional[Dict[str, Any]], current_prompt_mode: str
|
||||
) -> str:
|
||||
"""
|
||||
构建跨群聊上下文 - 统一实现,完全继承DefaultReplyer功能
|
||||
"""
|
||||
if not global_config.cross_context.enable:
|
||||
return ""
|
||||
|
||||
other_chat_raw_ids = cross_context_api.get_context_groups(chat_id)
|
||||
if not other_chat_raw_ids:
|
||||
return ""
|
||||
|
||||
chat_stream = get_chat_manager().get_stream(chat_id)
|
||||
if not chat_stream:
|
||||
return ""
|
||||
|
||||
if current_prompt_mode == "normal":
|
||||
return await cross_context_api.build_cross_context_normal(chat_stream, other_chat_raw_ids)
|
||||
elif current_prompt_mode == "s4u":
|
||||
return await cross_context_api.build_cross_context_s4u(chat_stream, other_chat_raw_ids, target_user_info)
|
||||
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def parse_reply_target_id(reply_to: str) -> str:
|
||||
"""
|
||||
解析回复目标中的用户ID
|
||||
|
||||
Args:
|
||||
reply_to: 回复目标字符串
|
||||
|
||||
Returns:
|
||||
str: 用户ID
|
||||
"""
|
||||
if not reply_to:
|
||||
return ""
|
||||
|
||||
# 复用parse_reply_target方法的逻辑
|
||||
sender, _ = PromptUtils.parse_reply_target(reply_to)
|
||||
if not sender:
|
||||
return ""
|
||||
|
||||
# 获取用户ID
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_id = person_info_manager.get_person_id_by_person_name(sender)
|
||||
if person_id:
|
||||
user_id = person_info_manager.get_value_sync(person_id, "user_id")
|
||||
return str(user_id) if user_id else ""
|
||||
|
||||
return ""
|
||||
@@ -1,938 +0,0 @@
|
||||
"""
|
||||
智能Prompt系统 - 完全重构版本
|
||||
基于原有DefaultReplyer的完整功能集成,使用新的参数结构
|
||||
解决实现质量不高、功能集成不完整和错误处理不足的问题
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from datetime import datetime
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Any, Optional, List, Tuple
|
||||
|
||||
from src.chat.utils.prompt_builder import global_prompt_manager, Prompt
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
build_readable_messages,
|
||||
)
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
from src.chat.utils.prompt_utils import PromptUtils
|
||||
from src.chat.utils.prompt_parameters import SmartPromptParameters
|
||||
|
||||
logger = get_logger("smart_prompt")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatContext:
|
||||
"""聊天上下文信息"""
|
||||
|
||||
chat_id: str = ""
|
||||
platform: str = ""
|
||||
is_group: bool = False
|
||||
user_id: str = ""
|
||||
user_nickname: str = ""
|
||||
group_id: Optional[str] = None
|
||||
timestamp: datetime = field(default_factory=datetime.now)
|
||||
|
||||
|
||||
class SmartPromptBuilder:
|
||||
"""重构的智能提示词构建器 - 统一错误处理和功能集成,移除缓存机制和依赖检查"""
|
||||
|
||||
def __init__(self):
|
||||
# 移除缓存相关初始化
|
||||
pass
|
||||
|
||||
async def build_context_data(self, params: SmartPromptParameters) -> Dict[str, Any]:
|
||||
"""并行构建完整的上下文数据 - 移除缓存机制和依赖检查"""
|
||||
|
||||
# 并行执行所有构建任务
|
||||
start_time = time.time()
|
||||
timing_logs = {}
|
||||
|
||||
try:
|
||||
# 准备构建任务
|
||||
tasks = []
|
||||
task_names = []
|
||||
|
||||
# 初始化预构建参数,使用新的结构
|
||||
pre_built_params = {}
|
||||
if params.expression_habits_block:
|
||||
pre_built_params["expression_habits_block"] = params.expression_habits_block
|
||||
if params.relation_info_block:
|
||||
pre_built_params["relation_info_block"] = params.relation_info_block
|
||||
if params.memory_block:
|
||||
pre_built_params["memory_block"] = params.memory_block
|
||||
if params.tool_info_block:
|
||||
pre_built_params["tool_info_block"] = params.tool_info_block
|
||||
if params.knowledge_prompt:
|
||||
pre_built_params["knowledge_prompt"] = params.knowledge_prompt
|
||||
if params.cross_context_block:
|
||||
pre_built_params["cross_context_block"] = params.cross_context_block
|
||||
|
||||
# 根据新的参数结构确定要构建的项
|
||||
if params.enable_expression and not pre_built_params.get("expression_habits_block"):
|
||||
tasks.append(self._build_expression_habits(params))
|
||||
task_names.append("expression_habits")
|
||||
|
||||
if params.enable_memory and not pre_built_params.get("memory_block"):
|
||||
tasks.append(self._build_memory_block(params))
|
||||
task_names.append("memory_block")
|
||||
|
||||
if params.enable_relation and not pre_built_params.get("relation_info_block"):
|
||||
tasks.append(self._build_relation_info(params))
|
||||
task_names.append("relation_info")
|
||||
|
||||
# 添加mai_think上下文构建任务
|
||||
if not pre_built_params.get("mai_think"):
|
||||
tasks.append(self._build_mai_think_context(params))
|
||||
task_names.append("mai_think_context")
|
||||
|
||||
if params.enable_tool and not pre_built_params.get("tool_info_block"):
|
||||
tasks.append(self._build_tool_info(params))
|
||||
task_names.append("tool_info")
|
||||
|
||||
if params.enable_knowledge and not pre_built_params.get("knowledge_prompt"):
|
||||
tasks.append(self._build_knowledge_info(params))
|
||||
task_names.append("knowledge_info")
|
||||
|
||||
if params.enable_cross_context and not pre_built_params.get("cross_context_block"):
|
||||
tasks.append(self._build_cross_context(params))
|
||||
task_names.append("cross_context")
|
||||
|
||||
# 性能优化:根据任务数量动态调整超时时间
|
||||
base_timeout = 10.0 # 基础超时时间
|
||||
task_timeout = 2.0 # 每个任务的超时时间
|
||||
timeout_seconds = min(
|
||||
max(base_timeout, len(tasks) * task_timeout), # 根据任务数量计算超时
|
||||
30.0, # 最大超时时间
|
||||
)
|
||||
|
||||
# 性能优化:限制并发任务数量,避免资源耗尽
|
||||
max_concurrent_tasks = 5 # 最大并发任务数
|
||||
if len(tasks) > max_concurrent_tasks:
|
||||
# 分批执行任务
|
||||
results = []
|
||||
for i in range(0, len(tasks), max_concurrent_tasks):
|
||||
batch_tasks = tasks[i : i + max_concurrent_tasks]
|
||||
batch_names = task_names[i : i + max_concurrent_tasks]
|
||||
|
||||
batch_results = await asyncio.wait_for(
|
||||
asyncio.gather(*batch_tasks, return_exceptions=True), timeout=timeout_seconds
|
||||
)
|
||||
results.extend(batch_results)
|
||||
else:
|
||||
# 一次性执行所有任务
|
||||
results = await asyncio.wait_for(
|
||||
asyncio.gather(*tasks, return_exceptions=True), timeout=timeout_seconds
|
||||
)
|
||||
|
||||
# 处理结果并收集性能数据
|
||||
context_data = {}
|
||||
for i, result in enumerate(results):
|
||||
task_name = task_names[i] if i < len(task_names) else f"task_{i}"
|
||||
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"构建任务{task_name}失败: {str(result)}")
|
||||
elif isinstance(result, dict):
|
||||
# 结果格式: {component_name: value}
|
||||
context_data.update(result)
|
||||
|
||||
# 记录耗时过长的任务
|
||||
if task_name in timing_logs and timing_logs[task_name] > 8.0:
|
||||
logger.warning(f"构建任务{task_name}耗时过长: {timing_logs[task_name]:.2f}s")
|
||||
|
||||
# 添加预构建的参数
|
||||
for key, value in pre_built_params.items():
|
||||
if value:
|
||||
context_data[key] = value
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"构建超时 ({timeout_seconds}s)")
|
||||
context_data = {}
|
||||
|
||||
# 添加预构建的参数,即使在超时情况下
|
||||
for key, value in pre_built_params.items():
|
||||
if value:
|
||||
context_data[key] = value
|
||||
|
||||
# 构建聊天历史 - 根据模式不同
|
||||
if params.prompt_mode == "s4u":
|
||||
await self._build_s4u_chat_context(context_data, params)
|
||||
else:
|
||||
await self._build_normal_chat_context(context_data, params)
|
||||
|
||||
# 补充基础信息
|
||||
context_data.update(
|
||||
{
|
||||
"keywords_reaction_prompt": params.keywords_reaction_prompt,
|
||||
"extra_info_block": params.extra_info_block,
|
||||
"time_block": params.time_block or f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
|
||||
"identity": params.identity_block,
|
||||
"schedule_block": params.schedule_block,
|
||||
"moderation_prompt": params.moderation_prompt_block,
|
||||
"reply_target_block": params.reply_target_block,
|
||||
"mood_state": params.mood_prompt,
|
||||
"action_descriptions": params.action_descriptions,
|
||||
}
|
||||
)
|
||||
|
||||
total_time = time.time() - start_time
|
||||
if timing_logs:
|
||||
timing_str = "; ".join([f"{name}: {time:.2f}s" for name, time in timing_logs.items()])
|
||||
logger.info(f"构建任务耗时: {timing_str}")
|
||||
logger.debug(f"构建完成,总耗时: {total_time:.2f}s")
|
||||
|
||||
return context_data
|
||||
|
||||
async def _build_s4u_chat_context(self, context_data: Dict[str, Any], params: SmartPromptParameters) -> None:
|
||||
"""构建S4U模式的聊天上下文 - 使用新参数结构"""
|
||||
if not params.message_list_before_now_long:
|
||||
return
|
||||
|
||||
# 使用共享工具构建分离历史
|
||||
core_dialogue, background_dialogue = await self._build_s4u_chat_history_prompts(
|
||||
params.message_list_before_now_long,
|
||||
params.target_user_info.get("user_id") if params.target_user_info else "",
|
||||
params.sender
|
||||
)
|
||||
|
||||
context_data["core_dialogue_prompt"] = core_dialogue
|
||||
context_data["background_dialogue_prompt"] = background_dialogue
|
||||
|
||||
async def _build_normal_chat_context(self, context_data: Dict[str, Any], params: SmartPromptParameters) -> None:
|
||||
"""构建normal模式的聊天上下文 - 使用新参数结构"""
|
||||
if not params.chat_talking_prompt_short:
|
||||
return
|
||||
|
||||
context_data["chat_info"] = f"""群里的聊天内容:
|
||||
{params.chat_talking_prompt_short}"""
|
||||
|
||||
async def _build_s4u_chat_history_prompts(
|
||||
self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str
|
||||
) -> Tuple[str, str]:
|
||||
"""构建S4U风格的分离对话prompt - 完整实现"""
|
||||
core_dialogue_list = []
|
||||
bot_id = str(global_config.bot.qq_account)
|
||||
|
||||
# 过滤消息:分离bot和目标用户的对话 vs 其他用户的对话
|
||||
for msg_dict in message_list_before_now:
|
||||
try:
|
||||
msg_user_id = str(msg_dict.get("user_id"))
|
||||
reply_to = msg_dict.get("reply_to", "")
|
||||
_platform, reply_to_user_id = self._parse_reply_target(reply_to)
|
||||
if (msg_user_id == bot_id and reply_to_user_id == target_user_id) or msg_user_id == target_user_id:
|
||||
# bot 和目标用户的对话
|
||||
core_dialogue_list.append(msg_dict)
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息记录时出错: {msg_dict}, 错误: {e}")
|
||||
|
||||
# 构建背景对话 prompt
|
||||
all_dialogue_prompt = ""
|
||||
if message_list_before_now:
|
||||
latest_25_msgs = message_list_before_now[-int(global_config.chat.max_context_size) :]
|
||||
all_dialogue_prompt_str = build_readable_messages(
|
||||
latest_25_msgs,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="normal",
|
||||
truncate=True,
|
||||
)
|
||||
all_dialogue_prompt = f"所有用户的发言:\n{all_dialogue_prompt_str}"
|
||||
|
||||
# 构建核心对话 prompt
|
||||
core_dialogue_prompt = ""
|
||||
if core_dialogue_list:
|
||||
# 检查最新五条消息中是否包含bot自己说的消息
|
||||
latest_5_messages = core_dialogue_list[-5:] if len(core_dialogue_list) >= 5 else core_dialogue_list
|
||||
has_bot_message = any(str(msg.get("user_id")) == bot_id for msg in latest_5_messages)
|
||||
|
||||
# logger.info(f"最新五条消息:{latest_5_messages}")
|
||||
# logger.info(f"最新五条消息中是否包含bot自己说的消息:{has_bot_message}")
|
||||
|
||||
# 如果最新五条消息中不包含bot的消息,则返回空字符串
|
||||
if not has_bot_message:
|
||||
core_dialogue_prompt = ""
|
||||
else:
|
||||
core_dialogue_list = core_dialogue_list[-int(global_config.chat.max_context_size * 2) :] # 限制消息数量
|
||||
|
||||
core_dialogue_prompt_str = build_readable_messages(
|
||||
core_dialogue_list,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=0.0,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
core_dialogue_prompt = f"""--------------------------------
|
||||
这是你和{sender}的对话,你们正在交流中:
|
||||
{core_dialogue_prompt_str}
|
||||
--------------------------------
|
||||
"""
|
||||
|
||||
return core_dialogue_prompt, all_dialogue_prompt
|
||||
|
||||
async def _build_mai_think_context(self, params: SmartPromptParameters) -> Any:
|
||||
"""构建mai_think上下文 - 完全继承DefaultReplyer功能"""
|
||||
from src.mais4u.mai_think import mai_thinking_manager
|
||||
|
||||
# 获取mai_think实例
|
||||
mai_think = mai_thinking_manager.get_mai_think(params.chat_id)
|
||||
|
||||
# 设置mai_think的上下文信息
|
||||
mai_think.memory_block = params.memory_block or ""
|
||||
mai_think.relation_info_block = params.relation_info_block or ""
|
||||
mai_think.time_block = params.time_block or f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
|
||||
# 设置聊天目标信息
|
||||
if params.is_group_chat:
|
||||
chat_target_1 = await global_prompt_manager.get_prompt_async("chat_target_group1")
|
||||
chat_target_2 = await global_prompt_manager.get_prompt_async("chat_target_group2")
|
||||
else:
|
||||
chat_target_name = "对方"
|
||||
if params.chat_target_info:
|
||||
chat_target_name = (
|
||||
params.chat_target_info.get("person_name") or params.chat_target_info.get("user_nickname") or "对方"
|
||||
)
|
||||
chat_target_1 = await global_prompt_manager.format_prompt(
|
||||
"chat_target_private1", sender_name=chat_target_name
|
||||
)
|
||||
chat_target_2 = await global_prompt_manager.format_prompt(
|
||||
"chat_target_private2", sender_name=chat_target_name
|
||||
)
|
||||
|
||||
mai_think.chat_target = chat_target_1
|
||||
mai_think.chat_target_2 = chat_target_2
|
||||
mai_think.chat_info = params.chat_talking_prompt_short or ""
|
||||
mai_think.mood_state = params.mood_prompt or ""
|
||||
mai_think.identity = params.identity_block or ""
|
||||
mai_think.sender = params.sender
|
||||
mai_think.target = params.target
|
||||
|
||||
# 返回mai_think实例,以便后续使用
|
||||
return mai_think
|
||||
|
||||
def _parse_reply_target_id(self, reply_to: str) -> str:
|
||||
"""解析回复目标中的用户ID"""
|
||||
if not reply_to:
|
||||
return ""
|
||||
|
||||
# 复用_parse_reply_target方法的逻辑
|
||||
sender, _ = self._parse_reply_target(reply_to)
|
||||
if not sender:
|
||||
return ""
|
||||
|
||||
# 获取用户ID
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_id = person_info_manager.get_person_id_by_person_name(sender)
|
||||
if person_id:
|
||||
user_id = person_info_manager.get_value_sync(person_id, "user_id")
|
||||
return str(user_id) if user_id else ""
|
||||
|
||||
async def _build_expression_habits(self, params: SmartPromptParameters) -> Dict[str, Any]:
|
||||
"""构建表达习惯 - 使用共享工具类,完全继承DefaultReplyer功能"""
|
||||
# 检查是否允许在此聊天流中使用表达
|
||||
use_expression, _, _ = global_config.expression.get_expression_config_for_chat(params.chat_id)
|
||||
if not use_expression:
|
||||
return {"expression_habits_block": ""}
|
||||
|
||||
from src.chat.express.expression_selector import expression_selector
|
||||
|
||||
style_habits = []
|
||||
grammar_habits = []
|
||||
|
||||
# 使用从处理器传来的选中表达方式
|
||||
# LLM模式:调用LLM选择5-10个,然后随机选5个
|
||||
try:
|
||||
selected_expressions = await expression_selector.select_suitable_expressions_llm(
|
||||
params.chat_id, params.chat_talking_prompt_short, max_num=8, min_num=2, target_message=params.target
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"选择表达方式失败: {e}")
|
||||
selected_expressions = []
|
||||
|
||||
if selected_expressions:
|
||||
logger.debug(f"使用处理器选中的{len(selected_expressions)}个表达方式")
|
||||
for expr in selected_expressions:
|
||||
if isinstance(expr, dict) and "situation" in expr and "style" in expr:
|
||||
expr_type = expr.get("type", "style")
|
||||
if expr_type == "grammar":
|
||||
grammar_habits.append(f"当{expr['situation']}时,使用 {expr['style']}")
|
||||
else:
|
||||
style_habits.append(f"当{expr['situation']}时,使用 {expr['style']}")
|
||||
else:
|
||||
logger.debug("没有从处理器获得表达方式,将使用空的表达方式")
|
||||
# 不再在replyer中进行随机选择,全部交给处理器处理
|
||||
|
||||
style_habits_str = "\n".join(style_habits)
|
||||
grammar_habits_str = "\n".join(grammar_habits)
|
||||
|
||||
# 动态构建expression habits块
|
||||
expression_habits_block = ""
|
||||
expression_habits_title = ""
|
||||
if style_habits_str.strip():
|
||||
expression_habits_title = (
|
||||
"你可以参考以下的语言习惯,当情景合适就使用,但不要生硬使用,以合理的方式结合到你的回复中:"
|
||||
)
|
||||
expression_habits_block += f"{style_habits_str}\n"
|
||||
if grammar_habits_str.strip():
|
||||
expression_habits_title = (
|
||||
"你可以选择下面的句法进行回复,如果情景合适就使用,不要盲目使用,不要生硬使用,以合理的方式使用:"
|
||||
)
|
||||
expression_habits_block += f"{grammar_habits_str}\n"
|
||||
|
||||
if style_habits_str.strip() and grammar_habits_str.strip():
|
||||
expression_habits_title = "你可以参考以下的语言习惯和句法,如果情景合适就使用,不要盲目使用,不要生硬使用,以合理的方式结合到你的回复中。"
|
||||
|
||||
return {"expression_habits_block": f"{expression_habits_title}\n{expression_habits_block}"}
|
||||
|
||||
async def _build_memory_block(self, params: SmartPromptParameters) -> Dict[str, Any]:
|
||||
"""构建记忆块 - 使用共享工具类,完全继承DefaultReplyer功能"""
|
||||
if not global_config.memory.enable_memory:
|
||||
return {"memory_block": ""}
|
||||
|
||||
from src.chat.memory_system.memory_activator import MemoryActivator
|
||||
from src.chat.memory_system.vector_instant_memory import VectorInstantMemoryV2
|
||||
|
||||
instant_memory = None
|
||||
|
||||
# 初始化记忆激活器
|
||||
try:
|
||||
memory_activator = MemoryActivator()
|
||||
|
||||
# 获取长期记忆
|
||||
running_memories = await memory_activator.activate_memory_with_chat_history(
|
||||
target_message=params.target, chat_history_prompt=params.chat_talking_prompt_short
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"激活记忆失败: {e}")
|
||||
running_memories = []
|
||||
|
||||
# 处理瞬时记忆
|
||||
if global_config.memory.enable_instant_memory:
|
||||
# 使用异步记忆包装器(最优化的非阻塞模式)
|
||||
try:
|
||||
from src.chat.memory_system.async_instant_memory_wrapper import get_async_instant_memory
|
||||
|
||||
# 获取异步记忆包装器
|
||||
async_memory = get_async_instant_memory(params.chat_id)
|
||||
|
||||
# 后台存储聊天历史(完全非阻塞)
|
||||
async_memory.store_memory_background(params.chat_talking_prompt_short)
|
||||
|
||||
# 快速检索记忆,最大超时2秒
|
||||
instant_memory = await async_memory.get_memory_with_fallback(params.target, max_timeout=2.0)
|
||||
|
||||
logger.info(f"异步瞬时记忆:{instant_memory}")
|
||||
|
||||
except ImportError:
|
||||
# 如果异步包装器不可用,尝试使用异步记忆管理器
|
||||
try:
|
||||
from src.chat.memory_system.async_memory_optimizer import (
|
||||
retrieve_memory_nonblocking,
|
||||
store_memory_nonblocking,
|
||||
)
|
||||
|
||||
# 异步存储聊天历史(非阻塞)
|
||||
asyncio.create_task(
|
||||
store_memory_nonblocking(chat_id=params.chat_id, content=params.chat_talking_prompt_short)
|
||||
)
|
||||
|
||||
# 尝试从缓存获取瞬时记忆
|
||||
instant_memory = await retrieve_memory_nonblocking(chat_id=params.chat_id, query=params.target)
|
||||
|
||||
# 如果没有缓存结果,快速检索一次
|
||||
if instant_memory is None:
|
||||
try:
|
||||
# 使用VectorInstantMemoryV2实例
|
||||
instant_memory_system = VectorInstantMemoryV2(chat_id=params.chat_id, retention_hours=1)
|
||||
instant_memory = await asyncio.wait_for(
|
||||
instant_memory_system.get_memory_for_context(params.target), timeout=1.5
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("瞬时记忆检索超时,使用空结果")
|
||||
instant_memory = ""
|
||||
|
||||
logger.info(f"向量瞬时记忆:{instant_memory}")
|
||||
|
||||
except ImportError:
|
||||
# 最后的fallback:使用原有逻辑但加上超时控制
|
||||
logger.warning("异步记忆系统不可用,使用带超时的同步方式")
|
||||
|
||||
# 使用VectorInstantMemoryV2实例
|
||||
instant_memory_system = VectorInstantMemoryV2(chat_id=params.chat_id, retention_hours=1)
|
||||
|
||||
# 异步存储聊天历史
|
||||
asyncio.create_task(instant_memory_system.store_message(params.chat_talking_prompt_short))
|
||||
|
||||
# 带超时的记忆检索
|
||||
try:
|
||||
instant_memory = await asyncio.wait_for(
|
||||
instant_memory_system.get_memory_for_context(params.target),
|
||||
timeout=1.0, # 最保守的1秒超时
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("瞬时记忆检索超时,跳过记忆获取")
|
||||
instant_memory = ""
|
||||
except Exception as e:
|
||||
logger.error(f"瞬时记忆检索失败: {e}")
|
||||
instant_memory = ""
|
||||
|
||||
logger.info(f"同步瞬时记忆:{instant_memory}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"瞬时记忆系统异常: {e}")
|
||||
instant_memory = ""
|
||||
|
||||
# 构建记忆字符串,即使某种记忆为空也要继续
|
||||
memory_str = ""
|
||||
has_any_memory = False
|
||||
|
||||
# 添加长期记忆
|
||||
if running_memories:
|
||||
if not memory_str:
|
||||
memory_str = "以下是当前在聊天中,你回忆起的记忆:\n"
|
||||
for running_memory in running_memories:
|
||||
memory_str += f"- {running_memory['content']}\n"
|
||||
has_any_memory = True
|
||||
|
||||
# 添加瞬时记忆
|
||||
if instant_memory:
|
||||
if not memory_str:
|
||||
memory_str = "以下是当前在聊天中,你回忆起的记忆:\n"
|
||||
memory_str += f"- {instant_memory}\n"
|
||||
has_any_memory = True
|
||||
|
||||
# 注入视频分析结果引导语
|
||||
memory_str = self._inject_video_prompt_if_needed(params.target, memory_str)
|
||||
|
||||
# 只有当完全没有任何记忆时才返回空字符串
|
||||
return {"memory_block": memory_str if has_any_memory else ""}
|
||||
|
||||
def _inject_video_prompt_if_needed(self, target: str, memory_str: str) -> str:
|
||||
"""统一视频分析结果注入逻辑"""
|
||||
if target and ("[视频内容]" in target or "好的,我将根据您提供的" in target):
|
||||
video_prompt_injection = (
|
||||
"\n请注意,以上内容是你刚刚观看的视频,请以第一人称分享你的观后感,而不是在分析一份报告。"
|
||||
)
|
||||
return memory_str + video_prompt_injection
|
||||
return memory_str
|
||||
|
||||
async def _build_relation_info(self, params: SmartPromptParameters) -> Dict[str, Any]:
|
||||
"""构建关系信息 - 使用共享工具类"""
|
||||
try:
|
||||
relation_info = await PromptUtils.build_relation_info(params.chat_id, params.reply_to)
|
||||
return {"relation_info_block": relation_info}
|
||||
except Exception as e:
|
||||
logger.error(f"构建关系信息失败: {e}")
|
||||
return {"relation_info_block": ""}
|
||||
|
||||
async def _build_tool_info(self, params: SmartPromptParameters) -> Dict[str, Any]:
|
||||
"""构建工具信息 - 使用共享工具类,完全继承DefaultReplyer功能"""
|
||||
if not params.enable_tool:
|
||||
return {"tool_info_block": ""}
|
||||
|
||||
if not params.reply_to:
|
||||
return {"tool_info_block": ""}
|
||||
|
||||
sender, text = PromptUtils.parse_reply_target(params.reply_to)
|
||||
|
||||
if not text:
|
||||
return {"tool_info_block": ""}
|
||||
|
||||
from src.plugin_system.core.tool_use import ToolExecutor
|
||||
|
||||
# 使用工具执行器获取信息
|
||||
try:
|
||||
tool_executor = ToolExecutor(chat_id=params.chat_id)
|
||||
tool_results, _, _ = await tool_executor.execute_from_chat_message(
|
||||
sender=sender, target_message=text, chat_history=params.chat_talking_prompt_short, return_details=False
|
||||
)
|
||||
|
||||
if tool_results:
|
||||
tool_info_str = "以下是你通过工具获取到的实时信息:\n"
|
||||
for tool_result in tool_results:
|
||||
tool_name = tool_result.get("tool_name", "unknown")
|
||||
content = tool_result.get("content", "")
|
||||
result_type = tool_result.get("type", "tool_result")
|
||||
|
||||
tool_info_str += f"- 【{tool_name}】{result_type}: {content}\n"
|
||||
|
||||
tool_info_str += "以上是你获取到的实时信息,请在回复时参考这些信息。"
|
||||
logger.info(f"获取到 {len(tool_results)} 个工具结果")
|
||||
|
||||
return {"tool_info_block": tool_info_str}
|
||||
else:
|
||||
logger.debug("未获取到任何工具结果")
|
||||
return {"tool_info_block": ""}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工具信息获取失败: {e}")
|
||||
return {"tool_info_block": ""}
|
||||
|
||||
async def _build_knowledge_info(self, params: SmartPromptParameters) -> Dict[str, Any]:
|
||||
"""构建知识信息 - 使用共享工具类,完全继承DefaultReplyer功能"""
|
||||
if not params.reply_to:
|
||||
logger.debug("没有回复对象,跳过获取知识库内容")
|
||||
return {"knowledge_prompt": ""}
|
||||
|
||||
sender, content = PromptUtils.parse_reply_target(params.reply_to)
|
||||
if not content:
|
||||
logger.debug("回复对象内容为空,跳过获取知识库内容")
|
||||
return {"knowledge_prompt": ""}
|
||||
|
||||
logger.debug(
|
||||
f"获取知识库内容,元消息:{params.chat_talking_prompt_short[:30]}...,消息长度: {len(params.chat_talking_prompt_short)}"
|
||||
)
|
||||
|
||||
# 从LPMM知识库获取知识
|
||||
try:
|
||||
# 检查LPMM知识库是否启用
|
||||
if not global_config.lpmm_knowledge.enable:
|
||||
logger.debug("LPMM知识库未启用,跳过获取知识库内容")
|
||||
return {"knowledge_prompt": ""}
|
||||
|
||||
from src.plugins.built_in.knowledge.lpmm_get_knowledge import SearchKnowledgeFromLPMMTool
|
||||
from src.plugin_system.apis import llm_api
|
||||
from src.config.config import model_config
|
||||
|
||||
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
bot_name = global_config.bot.nickname
|
||||
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"lpmm_get_knowledge_prompt",
|
||||
bot_name=bot_name,
|
||||
time_now=time_now,
|
||||
chat_history=params.chat_talking_prompt_short,
|
||||
sender=sender,
|
||||
target_message=content,
|
||||
)
|
||||
|
||||
_, _, _, _, tool_calls = await llm_api.generate_with_model_with_tools(
|
||||
prompt,
|
||||
model_config=model_config.model_task_config.tool_use,
|
||||
tool_options=[SearchKnowledgeFromLPMMTool.get_tool_definition()],
|
||||
)
|
||||
|
||||
if tool_calls:
|
||||
from src.plugin_system.core.tool_use import ToolExecutor
|
||||
|
||||
tool_executor = ToolExecutor(chat_id=params.chat_id)
|
||||
result = await tool_executor.execute_tool_call(tool_calls[0], SearchKnowledgeFromLPMMTool())
|
||||
|
||||
if not result or not result.get("content"):
|
||||
logger.debug("从LPMM知识库获取知识失败,返回空知识...")
|
||||
return {"knowledge_prompt": ""}
|
||||
|
||||
found_knowledge_from_lpmm = result.get("content", "")
|
||||
logger.debug(
|
||||
f"从LPMM知识库获取知识,相关信息:{found_knowledge_from_lpmm[:100]}...,信息长度: {len(found_knowledge_from_lpmm)}"
|
||||
)
|
||||
|
||||
return {
|
||||
"knowledge_prompt": f"你有以下这些**知识**:\n{found_knowledge_from_lpmm}\n请你**记住上面的知识**,之后可能会用到。\n"
|
||||
}
|
||||
else:
|
||||
logger.debug("从LPMM知识库获取知识失败,可能是从未导入过知识,返回空知识...")
|
||||
return {"knowledge_prompt": ""}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取知识库内容时发生异常: {str(e)}")
|
||||
return {"knowledge_prompt": ""}
|
||||
|
||||
async def _build_cross_context(self, params: SmartPromptParameters) -> Dict[str, Any]:
|
||||
"""构建跨群上下文 - 使用共享工具类"""
|
||||
try:
|
||||
cross_context = await PromptUtils.build_cross_context(
|
||||
params.chat_id, params.prompt_mode, params.target_user_info
|
||||
)
|
||||
return {"cross_context_block": cross_context}
|
||||
except Exception as e:
|
||||
logger.error(f"构建跨群上下文失败: {e}")
|
||||
return {"cross_context_block": ""}
|
||||
|
||||
def _parse_reply_target(self, target_message: str) -> Tuple[str, str]:
|
||||
"""解析回复目标消息 - 使用共享工具类"""
|
||||
return PromptUtils.parse_reply_target(target_message)
|
||||
|
||||
|
||||
class SmartPrompt:
|
||||
"""重构的智能提示词核心类 - 移除缓存机制和依赖检查,简化架构"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
template_name: Optional[str] = None,
|
||||
parameters: Optional[SmartPromptParameters] = None,
|
||||
):
|
||||
self.parameters = parameters or SmartPromptParameters()
|
||||
self.template_name = template_name or self._get_default_template()
|
||||
self.builder = SmartPromptBuilder()
|
||||
|
||||
def _get_default_template(self) -> str:
|
||||
"""根据模式选择默认模板"""
|
||||
if self.parameters.prompt_mode == "s4u":
|
||||
return "s4u_style_prompt"
|
||||
elif self.parameters.prompt_mode == "normal":
|
||||
return "normal_style_prompt"
|
||||
else:
|
||||
return "default_expressor_prompt"
|
||||
|
||||
async def build_prompt(self) -> str:
|
||||
"""构建最终的Prompt文本 - 移除缓存机制和依赖检查"""
|
||||
# 参数验证
|
||||
errors = self.parameters.validate()
|
||||
if errors:
|
||||
logger.error(f"参数验证失败: {', '.join(errors)}")
|
||||
raise ValueError(f"参数验证失败: {', '.join(errors)}")
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
# 构建基础上下文的完整映射
|
||||
context_data = await self.builder.build_context_data(self.parameters)
|
||||
|
||||
# 检查关键上下文数据
|
||||
if not context_data or not isinstance(context_data, dict):
|
||||
logger.error("构建的上下文数据无效")
|
||||
raise ValueError("构建的上下文数据无效")
|
||||
|
||||
# 获取模板
|
||||
template = await self._get_template()
|
||||
if template is None:
|
||||
logger.error("无法获取模板")
|
||||
raise ValueError("无法获取模板")
|
||||
|
||||
# 根据模式传递不同的参数
|
||||
if self.parameters.prompt_mode == "s4u":
|
||||
result = await self._build_s4u_prompt(template, context_data)
|
||||
elif self.parameters.prompt_mode == "normal":
|
||||
result = await self._build_normal_prompt(template, context_data)
|
||||
else:
|
||||
result = await self._build_default_prompt(template, context_data)
|
||||
|
||||
# 记录性能数据
|
||||
total_time = time.time() - start_time
|
||||
logger.debug(f"SmartPrompt构建完成,模式: {self.parameters.prompt_mode}, 耗时: {total_time:.2f}s")
|
||||
|
||||
return result
|
||||
|
||||
except asyncio.TimeoutError as e:
|
||||
logger.error(f"构建Prompt超时: {e}")
|
||||
raise TimeoutError(f"构建Prompt超时: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"构建Prompt失败: {e}")
|
||||
raise RuntimeError(f"构建Prompt失败: {e}")
|
||||
|
||||
async def _get_template(self) -> Optional[Prompt]:
|
||||
"""获取模板"""
|
||||
try:
|
||||
return await global_prompt_manager.get_prompt_async(self.template_name)
|
||||
except Exception as e:
|
||||
logger.error(f"获取模板 {self.template_name} 失败: {e}")
|
||||
raise RuntimeError(f"获取模板 {self.template_name} 失败: {e}")
|
||||
|
||||
async def _build_s4u_prompt(self, template: Prompt, context_data: Dict[str, Any]) -> str:
|
||||
"""构建S4U模式的完整Prompt - 使用新参数结构"""
|
||||
params = {
|
||||
**context_data,
|
||||
"expression_habits_block": context_data.get("expression_habits_block", ""),
|
||||
"tool_info_block": context_data.get("tool_info_block", ""),
|
||||
"knowledge_prompt": context_data.get("knowledge_prompt", ""),
|
||||
"memory_block": context_data.get("memory_block", ""),
|
||||
"relation_info_block": context_data.get("relation_info_block", ""),
|
||||
"extra_info_block": self.parameters.extra_info_block or context_data.get("extra_info_block", ""),
|
||||
"cross_context_block": context_data.get("cross_context_block", ""),
|
||||
"identity": self.parameters.identity_block or context_data.get("identity", ""),
|
||||
"action_descriptions": self.parameters.action_descriptions or context_data.get("action_descriptions", ""),
|
||||
"sender_name": self.parameters.sender,
|
||||
"mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""),
|
||||
"background_dialogue_prompt": context_data.get("background_dialogue_prompt", ""),
|
||||
"time_block": context_data.get("time_block", ""),
|
||||
"core_dialogue_prompt": context_data.get("core_dialogue_prompt", ""),
|
||||
"reply_target_block": context_data.get("reply_target_block", ""),
|
||||
"reply_style": global_config.personality.reply_style,
|
||||
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt
|
||||
or context_data.get("keywords_reaction_prompt", ""),
|
||||
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
|
||||
}
|
||||
return await global_prompt_manager.format_prompt(self.template_name, **params)
|
||||
|
||||
async def _build_normal_prompt(self, template: Prompt, context_data: Dict[str, Any]) -> str:
|
||||
"""构建Normal模式的完整Prompt - 使用新参数结构"""
|
||||
params = {
|
||||
**context_data,
|
||||
"expression_habits_block": context_data.get("expression_habits_block", ""),
|
||||
"tool_info_block": context_data.get("tool_info_block", ""),
|
||||
"knowledge_prompt": context_data.get("knowledge_prompt", ""),
|
||||
"memory_block": context_data.get("memory_block", ""),
|
||||
"relation_info_block": context_data.get("relation_info_block", ""),
|
||||
"extra_info_block": self.parameters.extra_info_block or context_data.get("extra_info_block", ""),
|
||||
"cross_context_block": context_data.get("cross_context_block", ""),
|
||||
"identity": self.parameters.identity_block or context_data.get("identity", ""),
|
||||
"action_descriptions": self.parameters.action_descriptions or context_data.get("action_descriptions", ""),
|
||||
"schedule_block": self.parameters.schedule_block or context_data.get("schedule_block", ""),
|
||||
"time_block": context_data.get("time_block", ""),
|
||||
"chat_info": context_data.get("chat_info", ""),
|
||||
"reply_target_block": context_data.get("reply_target_block", ""),
|
||||
"config_expression_style": global_config.personality.reply_style,
|
||||
"mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""),
|
||||
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt
|
||||
or context_data.get("keywords_reaction_prompt", ""),
|
||||
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
|
||||
}
|
||||
return await global_prompt_manager.format_prompt(self.template_name, **params)
|
||||
|
||||
async def _build_default_prompt(self, template: Prompt, context_data: Dict[str, Any]) -> str:
|
||||
"""构建默认模式的Prompt - 使用新参数结构"""
|
||||
params = {
|
||||
"expression_habits_block": context_data.get("expression_habits_block", ""),
|
||||
"relation_info_block": context_data.get("relation_info_block", ""),
|
||||
"chat_target": "",
|
||||
"time_block": context_data.get("time_block", ""),
|
||||
"chat_info": context_data.get("chat_info", ""),
|
||||
"identity": self.parameters.identity_block or context_data.get("identity", ""),
|
||||
"chat_target_2": "",
|
||||
"reply_target_block": context_data.get("reply_target_block", ""),
|
||||
"raw_reply": self.parameters.target,
|
||||
"reason": "",
|
||||
"mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""),
|
||||
"reply_style": global_config.personality.reply_style,
|
||||
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt
|
||||
or context_data.get("keywords_reaction_prompt", ""),
|
||||
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
|
||||
}
|
||||
return await global_prompt_manager.format_prompt(self.template_name, **params)
|
||||
|
||||
|
||||
# 工厂函数 - 简化创建 - 更新参数结构
|
||||
def create_smart_prompt(
|
||||
chat_id: str = "", sender_name: str = "", target_message: str = "", reply_to: str = "", **kwargs
|
||||
) -> SmartPrompt:
|
||||
"""快速创建智能Prompt实例的工厂函数 - 使用新参数结构"""
|
||||
|
||||
# 使用新的参数结构
|
||||
parameters = SmartPromptParameters(
|
||||
chat_id=chat_id, sender=sender_name, target=target_message, reply_to=reply_to, **kwargs
|
||||
)
|
||||
|
||||
return SmartPrompt(parameters=parameters)
|
||||
|
||||
|
||||
class SmartPromptHealthChecker:
|
||||
"""SmartPrompt健康检查器 - 移除依赖检查"""
|
||||
|
||||
@staticmethod
|
||||
async def check_system_health() -> Dict[str, Any]:
|
||||
"""检查系统健康状态 - 移除依赖检查"""
|
||||
health_status = {"status": "healthy", "components": {}, "issues": []}
|
||||
|
||||
try:
|
||||
# 检查配置
|
||||
try:
|
||||
from src.config.config import global_config
|
||||
|
||||
health_status["components"]["config"] = "ok"
|
||||
|
||||
# 检查关键配置项
|
||||
if not hasattr(global_config, "personality") or not hasattr(global_config.personality, "prompt_mode"):
|
||||
health_status["issues"].append("缺少personality.prompt_mode配置")
|
||||
health_status["status"] = "degraded"
|
||||
|
||||
if not hasattr(global_config, "memory") or not hasattr(global_config.memory, "enable_memory"):
|
||||
health_status["issues"].append("缺少memory.enable_memory配置")
|
||||
|
||||
except Exception as e:
|
||||
health_status["components"]["config"] = f"failed: {str(e)}"
|
||||
health_status["issues"].append("配置加载失败")
|
||||
health_status["status"] = "unhealthy"
|
||||
|
||||
# 检查Prompt模板
|
||||
try:
|
||||
required_templates = ["s4u_style_prompt", "normal_style_prompt", "default_expressor_prompt"]
|
||||
for template_name in required_templates:
|
||||
try:
|
||||
await global_prompt_manager.get_prompt_async(template_name)
|
||||
health_status["components"][f"template_{template_name}"] = "ok"
|
||||
except Exception as e:
|
||||
health_status["components"][f"template_{template_name}"] = f"failed: {str(e)}"
|
||||
health_status["issues"].append(f"模板{template_name}加载失败")
|
||||
health_status["status"] = "degraded"
|
||||
|
||||
except Exception as e:
|
||||
health_status["components"]["prompt_templates"] = f"failed: {str(e)}"
|
||||
health_status["issues"].append("Prompt模板检查失败")
|
||||
health_status["status"] = "unhealthy"
|
||||
|
||||
return health_status
|
||||
|
||||
except Exception as e:
|
||||
return {"status": "unhealthy", "components": {}, "issues": [f"健康检查异常: {str(e)}"]}
|
||||
|
||||
@staticmethod
|
||||
async def run_performance_test() -> Dict[str, Any]:
|
||||
"""运行性能测试"""
|
||||
test_results = {"status": "completed", "tests": {}, "summary": {}}
|
||||
|
||||
try:
|
||||
# 创建测试参数
|
||||
test_params = SmartPromptParameters(
|
||||
chat_id="test_chat",
|
||||
sender="test_user",
|
||||
target="test_message",
|
||||
reply_to="test_user:test_message",
|
||||
prompt_mode="s4u",
|
||||
)
|
||||
|
||||
# 测试不同模式下的构建性能
|
||||
modes = ["s4u", "normal", "minimal"]
|
||||
for mode in modes:
|
||||
test_params.prompt_mode = mode
|
||||
smart_prompt = SmartPrompt(parameters=test_params)
|
||||
|
||||
# 运行多次测试取平均值
|
||||
times = []
|
||||
for _ in range(3):
|
||||
start_time = time.time()
|
||||
try:
|
||||
await smart_prompt.build_prompt()
|
||||
end_time = time.time()
|
||||
times.append(end_time - start_time)
|
||||
except Exception as e:
|
||||
times.append(float("inf"))
|
||||
logger.error(f"性能测试失败 (模式: {mode}): {e}")
|
||||
|
||||
# 计算统计信息
|
||||
valid_times = [t for t in times if t != float("inf")]
|
||||
if valid_times:
|
||||
avg_time = sum(valid_times) / len(valid_times)
|
||||
min_time = min(valid_times)
|
||||
max_time = max(valid_times)
|
||||
|
||||
test_results["tests"][mode] = {
|
||||
"avg_time": avg_time,
|
||||
"min_time": min_time,
|
||||
"max_time": max_time,
|
||||
"success_rate": len(valid_times) / len(times),
|
||||
}
|
||||
else:
|
||||
test_results["tests"][mode] = {
|
||||
"avg_time": float("inf"),
|
||||
"min_time": float("inf"),
|
||||
"max_time": float("inf"),
|
||||
"success_rate": 0,
|
||||
}
|
||||
|
||||
# 计算总体统计
|
||||
all_avg_times = [
|
||||
test["avg_time"] for test in test_results["tests"].values() if test["avg_time"] != float("inf")
|
||||
]
|
||||
if all_avg_times:
|
||||
test_results["summary"] = {
|
||||
"overall_avg_time": sum(all_avg_times) / len(all_avg_times),
|
||||
"fastest_mode": min(test_results["tests"].items(), key=lambda x: x[1]["avg_time"])[0],
|
||||
"slowest_mode": max(test_results["tests"].items(), key=lambda x: x[1]["avg_time"])[0],
|
||||
}
|
||||
|
||||
return test_results
|
||||
|
||||
except Exception as e:
|
||||
return {"status": "failed", "tests": {}, "summary": {}, "error": str(e)}
|
||||
@@ -4,7 +4,6 @@ import time
|
||||
import hashlib
|
||||
import uuid
|
||||
import io
|
||||
import asyncio
|
||||
import numpy as np
|
||||
|
||||
from typing import Optional, Tuple, Dict, Any
|
||||
@@ -27,16 +26,15 @@ logger = get_logger("chat_image")
|
||||
def is_image_message(message: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
判断消息是否为图片消息
|
||||
|
||||
|
||||
Args:
|
||||
message: 消息字典
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否为图片消息
|
||||
"""
|
||||
return message.get("type") == "image" or (
|
||||
isinstance(message.get("content"), dict) and
|
||||
message["content"].get("type") == "image"
|
||||
isinstance(message.get("content"), dict) and message["content"].get("type") == "image"
|
||||
)
|
||||
|
||||
|
||||
@@ -596,7 +594,6 @@ class ImageManager:
|
||||
return "", "[图片]"
|
||||
|
||||
|
||||
|
||||
# 创建全局单例
|
||||
image_manager = None
|
||||
|
||||
|
||||
@@ -361,6 +361,7 @@ class GraphNodes(Base):
|
||||
concept = Column(get_string_field(255), nullable=False, unique=True, index=True)
|
||||
memory_items = Column(Text, nullable=False)
|
||||
hash = Column(Text, nullable=False)
|
||||
weight = Column(Float, nullable=False, default=1.0)
|
||||
created_time = Column(Float, nullable=False)
|
||||
last_modified = Column(Float, nullable=False)
|
||||
|
||||
|
||||
@@ -443,6 +443,12 @@ MODULE_COLORS = {
|
||||
"manifest_utils": "\033[38;5;39m", # 蓝色
|
||||
"schedule_manager": "\033[38;5;27m", # 深蓝色
|
||||
"monthly_plan_manager": "\033[38;5;171m",
|
||||
"plan_manager": "\033[38;5;171m",
|
||||
"llm_generator": "\033[38;5;171m",
|
||||
"schedule_bridge": "\033[38;5;171m",
|
||||
"sleep_manager": "\033[38;5;171m",
|
||||
"official_configs": "\033[38;5;171m",
|
||||
"mmc_com_layer": "\033[38;5;67m",
|
||||
# 聊天和多媒体扩展
|
||||
"chat_voice": "\033[38;5;87m", # 浅青色
|
||||
"typo_gen": "\033[38;5;123m", # 天蓝色
|
||||
@@ -564,8 +570,14 @@ MODULE_ALIASES = {
|
||||
"dependency_config": "依赖配置",
|
||||
"dependency_manager": "依赖管理",
|
||||
"manifest_utils": "清单工具",
|
||||
"schedule_manager": "计划管理",
|
||||
"monthly_plan_manager": "月度计划",
|
||||
"schedule_manager": "规划系统-日程表管理",
|
||||
"monthly_plan_manager": "规划系统-月度计划",
|
||||
"plan_manager": "规划系统-计划管理",
|
||||
"llm_generator": "规划系统-LLM生成",
|
||||
"schedule_bridge": "计划桥接",
|
||||
"sleep_manager": "睡眠管理",
|
||||
"official_configs": "官方配置",
|
||||
"mmc_com_layer": "MMC通信层",
|
||||
# 聊天和多媒体扩展
|
||||
"chat_voice": "语音处理",
|
||||
"typo_gen": "错字生成",
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import List, Dict, Any, Literal
|
||||
from typing import List, Dict, Any, Literal, Union
|
||||
from pydantic import Field, field_validator
|
||||
from threading import Lock
|
||||
|
||||
from src.config.config_base import ValidatedConfigBase
|
||||
|
||||
@@ -9,7 +10,7 @@ class APIProvider(ValidatedConfigBase):
|
||||
|
||||
name: str = Field(..., min_length=1, description="API提供商名称")
|
||||
base_url: str = Field(..., description="API基础URL")
|
||||
api_key: str = Field(..., min_length=1, description="API密钥")
|
||||
api_key: Union[str, List[str]] = Field(..., min_length=1, description="API密钥,支持单个密钥或密钥列表轮询")
|
||||
client_type: Literal["openai", "gemini", "aiohttp_gemini"] = Field(
|
||||
default="openai", description="客户端类型(如openai/google等,默认为openai)"
|
||||
)
|
||||
@@ -33,12 +34,33 @@ class APIProvider(ValidatedConfigBase):
|
||||
@classmethod
|
||||
def validate_api_key(cls, v):
|
||||
"""验证API密钥不能为空"""
|
||||
if not v or not v.strip():
|
||||
raise ValueError("API密钥不能为空")
|
||||
if isinstance(v, str):
|
||||
if not v.strip():
|
||||
raise ValueError("API密钥不能为空")
|
||||
elif isinstance(v, list):
|
||||
if not v:
|
||||
raise ValueError("API密钥列表不能为空")
|
||||
for key in v:
|
||||
if not isinstance(key, str) or not key.strip():
|
||||
raise ValueError("API密钥列表中的密钥不能为空")
|
||||
else:
|
||||
raise ValueError("API密钥必须是字符串或字符串列表")
|
||||
return v
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
self._api_key_lock = Lock()
|
||||
self._api_key_index = 0
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
return self.api_key
|
||||
with self._api_key_lock:
|
||||
if isinstance(self.api_key, str):
|
||||
return self.api_key
|
||||
if not self.api_key:
|
||||
raise ValueError("API密钥列表为空")
|
||||
key = self.api_key[self._api_key_index]
|
||||
self._api_key_index = (self._api_key_index + 1) % len(self.api_key)
|
||||
return key
|
||||
|
||||
|
||||
class ModelInfo(ValidatedConfigBase):
|
||||
@@ -113,6 +135,7 @@ class ModelTaskConfig(ValidatedConfigBase):
|
||||
voice: TaskConfig = Field(..., description="语音识别模型配置")
|
||||
tool_use: TaskConfig = Field(..., description="专注工具使用模型配置")
|
||||
planner: TaskConfig = Field(..., description="规划模型配置")
|
||||
planner_small: TaskConfig = Field(..., description="小脑(sub-planner)规划模型配置")
|
||||
embedding: TaskConfig = Field(..., description="嵌入模型配置")
|
||||
lpmm_entity_extract: TaskConfig = Field(..., description="LPMM实体提取模型配置")
|
||||
lpmm_rdf_build: TaskConfig = Field(..., description="LPMM RDF构建模型配置")
|
||||
@@ -147,9 +170,9 @@ class ModelTaskConfig(ValidatedConfigBase):
|
||||
class APIAdapterConfig(ValidatedConfigBase):
|
||||
"""API Adapter配置类"""
|
||||
|
||||
models: List[ModelInfo] = Field(..., min_items=1, description="模型列表")
|
||||
models: List[ModelInfo] = Field(..., min_length=1, description="模型列表")
|
||||
model_task_config: ModelTaskConfig = Field(..., description="模型任务配置")
|
||||
api_providers: List[APIProvider] = Field(..., min_items=1, description="API提供商列表")
|
||||
api_providers: List[APIProvider] = Field(..., min_length=1, description="API提供商列表")
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
|
||||
@@ -35,17 +35,16 @@ from src.config.official_configs import (
|
||||
VoiceConfig,
|
||||
DebugConfig,
|
||||
CustomPromptConfig,
|
||||
ScheduleConfig,
|
||||
VideoAnalysisConfig,
|
||||
DependencyManagementConfig,
|
||||
WebSearchConfig,
|
||||
AntiPromptInjectionConfig,
|
||||
SleepSystemConfig,
|
||||
MonthlyPlanSystemConfig,
|
||||
CrossContextConfig,
|
||||
PermissionConfig,
|
||||
CommandConfig,
|
||||
MaizoneIntercomConfig,
|
||||
PlanningSystemConfig,
|
||||
)
|
||||
|
||||
from .api_ada_configs import (
|
||||
@@ -81,8 +80,8 @@ def get_key_comment(toml_table, key):
|
||||
return item.trivia.comment
|
||||
if hasattr(toml_table, "keys"):
|
||||
for k in toml_table.keys():
|
||||
if isinstance(k, KeyType) and k.key == key: # type: ignore
|
||||
return k.trivia.comment # type: ignore
|
||||
if isinstance(k, KeyType) and k.key == key: # type: ignore
|
||||
return k.trivia.comment # type: ignore
|
||||
return None
|
||||
|
||||
|
||||
@@ -379,7 +378,6 @@ class Config(ValidatedConfigBase):
|
||||
debug: DebugConfig = Field(..., description="调试配置")
|
||||
custom_prompt: CustomPromptConfig = Field(..., description="自定义提示配置")
|
||||
voice: VoiceConfig = Field(..., description="语音配置")
|
||||
schedule: ScheduleConfig = Field(..., description="调度配置")
|
||||
permission: PermissionConfig = Field(..., description="权限配置")
|
||||
command: CommandConfig = Field(..., description="命令系统配置")
|
||||
|
||||
@@ -395,8 +393,8 @@ class Config(ValidatedConfigBase):
|
||||
)
|
||||
web_search: WebSearchConfig = Field(default_factory=lambda: WebSearchConfig(), description="网络搜索配置")
|
||||
sleep_system: SleepSystemConfig = Field(default_factory=lambda: SleepSystemConfig(), description="睡眠系统配置")
|
||||
monthly_plan_system: MonthlyPlanSystemConfig = Field(
|
||||
default_factory=lambda: MonthlyPlanSystemConfig(), description="月层计划系统配置"
|
||||
planning_system: PlanningSystemConfig = Field(
|
||||
default_factory=lambda: PlanningSystemConfig(), description="规划系统配置"
|
||||
)
|
||||
cross_context: CrossContextConfig = Field(
|
||||
default_factory=lambda: CrossContextConfig(), description="跨群聊上下文共享配置"
|
||||
|
||||
@@ -75,7 +75,7 @@ class ChatConfig(ValidatedConfigBase):
|
||||
at_bot_inevitable_reply: bool = Field(default=False, description="@机器人的必然回复")
|
||||
talk_frequency_adjust: list[list[str]] = Field(default_factory=lambda: [], description="聊天频率调整")
|
||||
focus_value: float = Field(default=1.0, description="专注值")
|
||||
force_focus_private: bool = Field(default=False, description="强制专注私聊")
|
||||
force_reply_private: bool = Field(default=False, description="强制回复私聊")
|
||||
group_chat_mode: Literal["auto", "normal", "focus"] = Field(default="auto", description="群聊模式")
|
||||
timestamp_display_mode: Literal["normal", "normal_no_YMD", "relative"] = Field(
|
||||
default="normal_no_YMD", description="时间戳显示模式"
|
||||
@@ -92,6 +92,7 @@ class ChatConfig(ValidatedConfigBase):
|
||||
default_factory=list, description="启用主动思考的群聊范围,格式:platform:group_id,为空则不限制"
|
||||
)
|
||||
delta_sigma: int = Field(default=120, description="采用正态分布随机时间间隔")
|
||||
planner_size: float = Field(default=5.0, ge=1.0, description="小脑(sub-planner)的尺寸,决定每个小脑处理多少个action")
|
||||
|
||||
def get_current_talk_frequency(self, chat_stream_id: Optional[str] = None) -> float:
|
||||
"""
|
||||
@@ -259,7 +260,6 @@ class NormalChatConfig(ValidatedConfigBase):
|
||||
"""普通聊天配置类"""
|
||||
|
||||
|
||||
|
||||
class ExpressionRule(ValidatedConfigBase):
|
||||
"""表达学习规则"""
|
||||
|
||||
@@ -519,11 +519,19 @@ class LPMMKnowledgeConfig(ValidatedConfigBase):
|
||||
embedding_dimension: int = Field(default=1024, description="嵌入维度")
|
||||
|
||||
|
||||
class ScheduleConfig(ValidatedConfigBase):
|
||||
"""日程配置类"""
|
||||
class PlanningSystemConfig(ValidatedConfigBase):
|
||||
"""规划系统配置 (日程与月度计划)"""
|
||||
|
||||
enable: bool = Field(default=True, description="启用")
|
||||
guidelines: Optional[str] = Field(default=None, description="指导方针")
|
||||
# --- 日程生成 (原 ScheduleConfig) ---
|
||||
schedule_enable: bool = Field(True, description="是否启用每日日程生成功能")
|
||||
schedule_guidelines: str = Field("", description="日程生成指导原则")
|
||||
|
||||
# --- 月度计划 (原 MonthlyPlanSystemConfig) ---
|
||||
monthly_plan_enable: bool = Field(True, description="是否启用月度计划系统")
|
||||
monthly_plan_guidelines: str = Field("", description="月度计划生成指导原则")
|
||||
max_plans_per_month: int = Field(10, description="每月最多生成的计划数量")
|
||||
avoid_repetition_days: int = Field(7, description="避免在多少天内重复使用同一个月度计划")
|
||||
completion_threshold: int = Field(3, description="一个月度计划被使用多少次后算作完成")
|
||||
|
||||
|
||||
class DependencyManagementConfig(ValidatedConfigBase):
|
||||
@@ -602,6 +610,11 @@ class SleepSystemConfig(ValidatedConfigBase):
|
||||
"""睡眠系统配置类"""
|
||||
|
||||
enable: bool = Field(default=True, description="是否启用睡眠系统")
|
||||
sleep_by_schedule: bool = Field(default=True, description="是否根据日程表进行睡觉")
|
||||
fixed_sleep_time: str = Field(default="23:00", description="固定的睡觉时间")
|
||||
fixed_wake_up_time: str = Field(default="07:00", description="固定的起床时间")
|
||||
sleep_time_offset_minutes: int = Field(default=15, ge=0, le=60, description="睡觉时间随机偏移量范围(分钟),实际睡觉时间会在±该值范围内随机")
|
||||
wake_up_time_offset_minutes: int = Field(default=15, ge=0, le=60, description="起床时间随机偏移量范围(分钟),实际起床时间会在±该值范围内随机")
|
||||
wakeup_threshold: float = Field(default=15.0, ge=1.0, description="唤醒阈值,达到此值时会被唤醒")
|
||||
private_message_increment: float = Field(default=3.0, ge=0.1, description="私聊消息增加的唤醒度")
|
||||
group_mention_increment: float = Field(default=2.0, ge=0.1, description="群聊艾特增加的唤醒度")
|
||||
@@ -615,7 +628,12 @@ class SleepSystemConfig(ValidatedConfigBase):
|
||||
|
||||
# --- 失眠机制相关参数 ---
|
||||
enable_insomnia_system: bool = Field(default=True, description="是否启用失眠系统")
|
||||
insomnia_duration_minutes: int = Field(default=30, ge=1, description="单次失眠状态的持续时间(分钟)")
|
||||
insomnia_trigger_delay_minutes: List[int] = Field(
|
||||
default_factory=lambda:[30, 60], description="入睡后触发失眠判定的延迟时间范围(分钟)"
|
||||
)
|
||||
insomnia_duration_minutes: List[int] = Field(
|
||||
default_factory=lambda:[15, 45], description="单次失眠状态的持续时间范围(分钟)"
|
||||
)
|
||||
sleep_pressure_threshold: float = Field(default=30.0, description="触发“压力不足型失眠”的睡眠压力阈值")
|
||||
deep_sleep_threshold: float = Field(default=80.0, description="进入“深度睡眠”的睡眠压力阈值")
|
||||
insomnia_chance_low_pressure: float = Field(default=0.6, ge=0.0, le=1.0, description="压力不足时的失眠基础概率")
|
||||
@@ -638,22 +656,13 @@ class SleepSystemConfig(ValidatedConfigBase):
|
||||
)
|
||||
|
||||
|
||||
class MonthlyPlanSystemConfig(ValidatedConfigBase):
|
||||
"""月度计划系统配置类"""
|
||||
|
||||
enable: bool = Field(default=True, description="是否启用本功能")
|
||||
max_plans_per_month: int = Field(default=20, ge=1, description="每个月允许存在的最大计划数量")
|
||||
completion_threshold: int = Field(default=3, ge=1, description="计划使用多少次后自动标记为已完成")
|
||||
avoid_repetition_days: int = Field(default=7, ge=1, description="多少天内不重复抽取同一个计划")
|
||||
guidelines: Optional[str] = Field(default=None, description="月度计划生成的指导原则")
|
||||
|
||||
|
||||
class ContextGroup(ValidatedConfigBase):
|
||||
"""上下文共享组配置"""
|
||||
|
||||
name: str = Field(..., description="共享组的名称")
|
||||
chat_ids: List[List[str]] = Field(
|
||||
..., description='属于该组的聊天ID列表,格式为 [["type", "chat_id"], ...],例如 [["group", "123456"], ["private", "789012"]]'
|
||||
...,
|
||||
description='属于该组的聊天ID列表,格式为 [["type", "chat_id"], ...],例如 [["group", "123456"], ["private", "789012"]]',
|
||||
)
|
||||
|
||||
|
||||
|
||||
52
src/main.py
52
src/main.py
@@ -8,6 +8,7 @@ from maim_message import MessageServer
|
||||
from src.common.remote import TelemetryHeartBeatTask
|
||||
from src.manager.async_task_manager import async_task_manager
|
||||
from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
|
||||
from src.common.remote import TelemetryHeartBeatTask
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.config.config import global_config
|
||||
@@ -29,36 +30,57 @@ from src.plugin_system.core.plugin_hot_reload import hot_reload_manager
|
||||
|
||||
# 导入消息API和traceback模块
|
||||
from src.common.message import get_global_api
|
||||
|
||||
|
||||
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
||||
|
||||
if not global_config.memory.enable_memory:
|
||||
import src.chat.memory_system.Hippocampus as hippocampus_module
|
||||
|
||||
|
||||
class MockHippocampusManager:
|
||||
def initialize(self):
|
||||
pass
|
||||
|
||||
def get_hippocampus(self):
|
||||
return None
|
||||
|
||||
async def build_memory(self):
|
||||
pass
|
||||
|
||||
async def forget_memory(self, percentage: float = 0.005):
|
||||
pass
|
||||
|
||||
async def consolidate_memory(self):
|
||||
pass
|
||||
async def get_memory_from_text(self, text: str, max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3, fast_retrieval: bool = False) -> list:
|
||||
|
||||
async def get_memory_from_text(
|
||||
self,
|
||||
text: str,
|
||||
max_memory_num: int = 3,
|
||||
max_memory_length: int = 2,
|
||||
max_depth: int = 3,
|
||||
fast_retrieval: bool = False,
|
||||
) -> list:
|
||||
return []
|
||||
async def get_memory_from_topic(self, valid_keywords: list[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3) -> list:
|
||||
|
||||
async def get_memory_from_topic(
|
||||
self, valid_keywords: list[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3
|
||||
) -> list:
|
||||
return []
|
||||
async def get_activate_from_text(self, text: str, max_depth: int = 3, fast_retrieval: bool = False) -> tuple[float, list[str]]:
|
||||
|
||||
async def get_activate_from_text(
|
||||
self, text: str, max_depth: int = 3, fast_retrieval: bool = False
|
||||
) -> tuple[float, list[str]]:
|
||||
return 0.0, []
|
||||
|
||||
def get_memory_from_keyword(self, keyword: str, max_depth: int = 2) -> list:
|
||||
return []
|
||||
|
||||
def get_all_node_names(self) -> list:
|
||||
return []
|
||||
|
||||
|
||||
hippocampus_module.hippocampus_manager = MockHippocampusManager()
|
||||
|
||||
# 插件系统现在使用统一的插件加载器
|
||||
|
||||
# 插件系统现在使用统一的插件加载器
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
@@ -68,7 +90,7 @@ logger = get_logger("main")
|
||||
class MainSystem:
|
||||
def __init__(self):
|
||||
self.hippocampus_manager = hippocampus_manager
|
||||
|
||||
|
||||
self.individuality: Individuality = get_individuality()
|
||||
|
||||
# 使用消息API替代直接的FastAPI实例
|
||||
@@ -93,6 +115,9 @@ class MainSystem:
|
||||
"""清理资源"""
|
||||
try:
|
||||
# 停止消息重组器
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
from src.plugin_system import EventType
|
||||
asyncio.run(event_manager.trigger_event(EventType.ON_STOP,plugin_name="SYSTEM"))
|
||||
from src.utils.message_chunker import reassembler
|
||||
import asyncio
|
||||
|
||||
@@ -211,7 +236,6 @@ MoFox_Bot(第三方修改版)
|
||||
get_emoji_manager().initialize()
|
||||
logger.info("表情包管理器初始化成功")
|
||||
|
||||
|
||||
# 启动情绪管理器
|
||||
await mood_manager.start()
|
||||
logger.info("情绪管理器初始化成功")
|
||||
@@ -226,11 +250,11 @@ MoFox_Bot(第三方修改版)
|
||||
# 初始化记忆系统
|
||||
self.hippocampus_manager.initialize()
|
||||
logger.info("记忆系统初始化成功")
|
||||
|
||||
|
||||
# 初始化异步记忆管理器
|
||||
try:
|
||||
from src.chat.memory_system.async_memory_optimizer import async_memory_manager
|
||||
|
||||
|
||||
await async_memory_manager.initialize()
|
||||
logger.info("记忆管理器初始化成功")
|
||||
except Exception as e:
|
||||
@@ -251,7 +275,7 @@ MoFox_Bot(第三方修改版)
|
||||
await self.individuality.initialize()
|
||||
|
||||
# 初始化月度计划管理器
|
||||
if global_config.monthly_plan_system.enable:
|
||||
if global_config.planning_system.monthly_plan_enable:
|
||||
logger.info("正在初始化月度计划管理器...")
|
||||
try:
|
||||
await monthly_plan_manager.start_monthly_plan_generation()
|
||||
@@ -260,7 +284,7 @@ MoFox_Bot(第三方修改版)
|
||||
logger.error(f"月度计划管理器初始化失败: {e}")
|
||||
|
||||
# 初始化日程管理器
|
||||
if global_config.schedule.enable:
|
||||
if global_config.planning_system.schedule_enable:
|
||||
logger.info("日程表功能已启用,正在初始化管理器...")
|
||||
await schedule_manager.load_or_generate_today_schedule()
|
||||
await schedule_manager.start_daily_schedule_generation()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
import time
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
from src.chat.message_receive.message import MessageRecvS4U
|
||||
|
||||
@@ -7,7 +7,7 @@ from src.llm_models.utils_model import LLMRequest
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.manager.async_task_manager import AsyncTask, async_task_manager
|
||||
from src.plugin_system.apis import send_api
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from src.llm_models.utils_model import LLMRequest
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.manager.async_task_manager import AsyncTask, async_task_manager
|
||||
from src.plugin_system.apis import send_api
|
||||
from src.mais4u.constant_s4u import ENABLE_S4U
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from src.config.config import global_config
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
|
||||
import time
|
||||
from src.chat.utils.utils import get_recent_group_speaker
|
||||
|
||||
@@ -6,7 +6,7 @@ from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.manager.async_task_manager import AsyncTask, async_task_manager
|
||||
|
||||
@@ -150,6 +150,18 @@ class PersonInfoManager:
|
||||
|
||||
# Ensure person_id is correctly set from the argument
|
||||
final_data["person_id"] = person_id
|
||||
# 你们的英文注释是何意味?
|
||||
|
||||
# 检查并修复关键字段为None的情况喵
|
||||
if final_data.get("user_id") is None:
|
||||
logger.warning(f"user_id为None,使用'unknown'作为默认值 person_id={person_id}")
|
||||
final_data["user_id"] = "unknown"
|
||||
|
||||
if final_data.get("platform") is None:
|
||||
logger.warning(f"platform为None,使用'unknown'作为默认值 person_id={person_id}")
|
||||
final_data["platform"] = "unknown"
|
||||
|
||||
# 这里的目的是为了防止在识别出错的情况下有一个最小回退,不只是针对@消息识别成视频后的报错问题
|
||||
|
||||
# Serialize JSON fields
|
||||
for key in JSON_SERIALIZED_FIELDS:
|
||||
@@ -199,6 +211,15 @@ class PersonInfoManager:
|
||||
|
||||
# Ensure person_id is correctly set from the argument
|
||||
final_data["person_id"] = person_id
|
||||
|
||||
# 检查并修复关键字段为None的情况
|
||||
if final_data.get("user_id") is None:
|
||||
logger.warning(f"user_id为None,使用'unknown'作为默认值 person_id={person_id}")
|
||||
final_data["user_id"] = "unknown"
|
||||
|
||||
if final_data.get("platform") is None:
|
||||
logger.warning(f"platform为None,使用'unknown'作为默认值 person_id={person_id}")
|
||||
final_data["platform"] = "unknown"
|
||||
|
||||
# Serialize JSON fields
|
||||
for key in JSON_SERIALIZED_FIELDS:
|
||||
@@ -295,6 +316,15 @@ class PersonInfoManager:
|
||||
creation_data["platform"] = data["platform"]
|
||||
if data and "user_id" in data:
|
||||
creation_data["user_id"] = data["user_id"]
|
||||
|
||||
# 额外检查关键字段,如果为None则使用默认值
|
||||
if creation_data.get("user_id") is None:
|
||||
logger.warning(f"创建用户时user_id为None,使用'unknown'作为默认值 person_id={person_id}")
|
||||
creation_data["user_id"] = "unknown"
|
||||
|
||||
if creation_data.get("platform") is None:
|
||||
logger.warning(f"创建用户时platform为None,使用'unknown'作为默认值 person_id={person_id}")
|
||||
creation_data["platform"] = "unknown"
|
||||
|
||||
# 使用安全的创建方法,处理竞态条件
|
||||
await self._safe_create_person_info(person_id, creation_data)
|
||||
|
||||
@@ -9,7 +9,7 @@ from json_repair import repair_json
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
|
||||
|
||||
@@ -25,36 +25,30 @@ def get_context_groups(chat_id: str) -> Optional[List[List[str]]]:
|
||||
return None
|
||||
|
||||
is_group = current_stream.group_info is not None
|
||||
current_chat_raw_id = (
|
||||
current_stream.group_info.group_id if is_group else current_stream.user_info.user_id
|
||||
)
|
||||
if is_group:
|
||||
assert current_stream.group_info is not None
|
||||
current_chat_raw_id = current_stream.group_info.group_id
|
||||
else:
|
||||
current_chat_raw_id = current_stream.user_info.user_id
|
||||
current_type = "group" if is_group else "private"
|
||||
|
||||
for group in global_config.cross_context.groups:
|
||||
# 检查当前聊天的ID和类型是否在组的chat_ids中
|
||||
if [current_type, str(current_chat_raw_id)] in group.chat_ids:
|
||||
# 返回组内其他聊天的 [type, id] 列表
|
||||
return [
|
||||
chat_info
|
||||
for chat_info in group.chat_ids
|
||||
if chat_info != [current_type, str(current_chat_raw_id)]
|
||||
]
|
||||
return [chat_info for chat_info in group.chat_ids if chat_info != [current_type, str(current_chat_raw_id)]]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def build_cross_context_normal(
|
||||
chat_stream: ChatStream, other_chat_infos: List[List[str]]
|
||||
) -> str:
|
||||
async def build_cross_context_normal(chat_stream: ChatStream, other_chat_infos: List[List[str]]) -> str:
|
||||
"""
|
||||
构建跨群聊/私聊上下文 (Normal模式)
|
||||
"""
|
||||
cross_context_messages = []
|
||||
for chat_type, chat_raw_id in other_chat_infos:
|
||||
is_group = chat_type == "group"
|
||||
stream_id = get_chat_manager().get_stream_id(
|
||||
chat_stream.platform, chat_raw_id, is_group=is_group
|
||||
)
|
||||
stream_id = get_chat_manager().get_stream_id(chat_stream.platform, chat_raw_id, is_group=is_group)
|
||||
if not stream_id:
|
||||
continue
|
||||
|
||||
@@ -66,9 +60,7 @@ async def build_cross_context_normal(
|
||||
)
|
||||
if messages:
|
||||
chat_name = get_chat_manager().get_stream_name(stream_id) or chat_raw_id
|
||||
formatted_messages, _ = build_readable_messages_with_id(
|
||||
messages, timestamp_mode="relative"
|
||||
)
|
||||
formatted_messages, _ = build_readable_messages_with_id(messages, timestamp_mode="relative")
|
||||
cross_context_messages.append(f'[以下是来自"{chat_name}"的近期消息]\n{formatted_messages}')
|
||||
except Exception as e:
|
||||
logger.error(f"获取聊天 {chat_raw_id} 的消息失败: {e}")
|
||||
@@ -95,9 +87,7 @@ async def build_cross_context_s4u(
|
||||
if user_id:
|
||||
for chat_type, chat_raw_id in other_chat_infos:
|
||||
is_group = chat_type == "group"
|
||||
stream_id = get_chat_manager().get_stream_id(
|
||||
chat_stream.platform, chat_raw_id, is_group=is_group
|
||||
)
|
||||
stream_id = get_chat_manager().get_stream_id(chat_stream.platform, chat_raw_id, is_group=is_group)
|
||||
if not stream_id:
|
||||
continue
|
||||
|
||||
@@ -112,9 +102,7 @@ async def build_cross_context_s4u(
|
||||
if user_messages:
|
||||
chat_name = get_chat_manager().get_stream_name(stream_id) or chat_raw_id
|
||||
user_name = (
|
||||
target_user_info.get("person_name")
|
||||
or target_user_info.get("user_nickname")
|
||||
or user_id
|
||||
target_user_info.get("person_name") or target_user_info.get("user_nickname") or user_id
|
||||
)
|
||||
formatted_messages, _ = build_readable_messages_with_id(
|
||||
user_messages, timestamp_mode="relative"
|
||||
@@ -129,4 +117,64 @@ async def build_cross_context_s4u(
|
||||
if not cross_context_messages:
|
||||
return ""
|
||||
|
||||
return "# 跨上下文参考\n" + "\n\n".join(cross_context_messages) + "\n"
|
||||
return "# 跨上下文参考\n" + "\n\n".join(cross_context_messages) + "\n"
|
||||
|
||||
|
||||
async def get_chat_history_by_group_name(group_name: str) -> str:
|
||||
"""
|
||||
根据互通组名字获取聊天记录
|
||||
"""
|
||||
target_group = None
|
||||
for group in global_config.cross_context.groups:
|
||||
if group.name == group_name:
|
||||
target_group = group
|
||||
break
|
||||
|
||||
if not target_group:
|
||||
return f"找不到名为 {group_name} 的互通组。"
|
||||
|
||||
if not target_group.chat_ids:
|
||||
return f"互通组 {group_name} 中没有配置任何聊天。"
|
||||
|
||||
chat_infos = target_group.chat_ids
|
||||
chat_manager = get_chat_manager()
|
||||
|
||||
cross_context_messages = []
|
||||
for chat_type, chat_raw_id in chat_infos:
|
||||
is_group = chat_type == "group"
|
||||
|
||||
found_stream = None
|
||||
for stream in chat_manager.streams.values():
|
||||
if is_group:
|
||||
if stream.group_info and stream.group_info.group_id == chat_raw_id:
|
||||
found_stream = stream
|
||||
break
|
||||
else: # private
|
||||
if stream.user_info and stream.user_info.user_id == chat_raw_id and not stream.group_info:
|
||||
found_stream = stream
|
||||
break
|
||||
|
||||
if not found_stream:
|
||||
logger.warning(f"在已加载的聊天流中找不到ID为 {chat_raw_id} 的聊天。")
|
||||
continue
|
||||
|
||||
stream_id = found_stream.stream_id
|
||||
|
||||
try:
|
||||
messages = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=5, # 可配置
|
||||
)
|
||||
if messages:
|
||||
chat_name = get_chat_manager().get_stream_name(stream_id) or chat_raw_id
|
||||
formatted_messages, _ = build_readable_messages_with_id(messages, timestamp_mode="relative")
|
||||
cross_context_messages.append(f'[以下是来自"{chat_name}"的近期消息]\n{formatted_messages}')
|
||||
except Exception as e:
|
||||
logger.error(f"获取聊天 {chat_raw_id} 的消息失败: {e}")
|
||||
continue
|
||||
|
||||
if not cross_context_messages:
|
||||
return f"无法从互通组 {group_name} 中获取任何聊天记录。"
|
||||
|
||||
return "# 跨上下文参考\n" + "\n\n".join(cross_context_messages) + "\n"
|
||||
|
||||
@@ -12,7 +12,6 @@ import traceback
|
||||
from typing import Tuple, Any, Dict, List, Optional
|
||||
from rich.traceback import install
|
||||
from src.common.logger import get_logger
|
||||
from src.config.api_ada_configs import TaskConfig
|
||||
from src.chat.replyer.default_generator import DefaultReplyer
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.chat.utils.utils import process_llm_response
|
||||
@@ -21,6 +20,7 @@ from src.plugin_system.base.component_types import ActionInfo
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
# 日志记录器
|
||||
logger = get_logger("generator_api")
|
||||
|
||||
|
||||
@@ -107,15 +107,14 @@ async def generate_reply(
|
||||
"""
|
||||
try:
|
||||
# 获取回复器
|
||||
replyer = get_replyer(
|
||||
chat_stream, chat_id, request_type=request_type
|
||||
)
|
||||
replyer = get_replyer(chat_stream, chat_id, request_type=request_type)
|
||||
if not replyer:
|
||||
logger.error("[GeneratorAPI] 无法获取回复器")
|
||||
return False, [], None
|
||||
|
||||
logger.debug("[GeneratorAPI] 开始生成回复")
|
||||
|
||||
# 向下兼容,从action_data中获取reply_to和extra_info
|
||||
if not reply_to and action_data:
|
||||
reply_to = action_data.get("reply_to", "")
|
||||
if not extra_info and action_data:
|
||||
@@ -136,6 +135,7 @@ async def generate_reply(
|
||||
return False, [], None
|
||||
assert llm_response_dict is not None, "llm_response_dict不应为None" # 虽然说不会出现llm_response为空的情况
|
||||
if content := llm_response_dict.get("content", ""):
|
||||
# 处理为拟人化文本
|
||||
reply_set = process_human_text(content, enable_splitter, enable_chinese_typo)
|
||||
else:
|
||||
reply_set = []
|
||||
@@ -211,6 +211,7 @@ async def rewrite_reply(
|
||||
)
|
||||
reply_set = []
|
||||
if content:
|
||||
# 处理为拟人化文本
|
||||
reply_set = process_human_text(content, enable_splitter, enable_chinese_typo)
|
||||
|
||||
if success:
|
||||
@@ -236,9 +237,12 @@ def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo:
|
||||
enable_splitter: 是否启用消息分割器
|
||||
enable_chinese_typo: 是否启用错字生成器
|
||||
"""
|
||||
if isinstance(content, list):
|
||||
content = "".join(map(str, content))
|
||||
if not isinstance(content, str):
|
||||
raise ValueError("content 必须是字符串类型")
|
||||
try:
|
||||
# 处理LLM响应
|
||||
processed_response = process_llm_response(content, enable_splitter, enable_chinese_typo)
|
||||
|
||||
reply_set = []
|
||||
@@ -259,6 +263,18 @@ async def generate_response_custom(
|
||||
request_type: str = "generator_api",
|
||||
prompt: str = "",
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
使用自定义提示生成回复
|
||||
|
||||
Args:
|
||||
chat_stream: 聊天流对象
|
||||
chat_id: 聊天ID
|
||||
request_type: 请求类型
|
||||
prompt: 自定义提示
|
||||
|
||||
Returns:
|
||||
Optional[str]: 生成的回复内容
|
||||
"""
|
||||
replyer = get_replyer(chat_stream, chat_id, request_type=request_type)
|
||||
if not replyer:
|
||||
logger.error("[GeneratorAPI] 无法获取回复器")
|
||||
|
||||
@@ -30,7 +30,6 @@
|
||||
|
||||
import traceback
|
||||
import time
|
||||
import difflib
|
||||
import asyncio
|
||||
from typing import Optional, Union, Dict, Any
|
||||
from src.common.logger import get_logger
|
||||
@@ -41,16 +40,16 @@ from maim_message import UserInfo
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.chat.message_receive.uni_message_sender import HeartFCSender
|
||||
from src.chat.message_receive.message import MessageSending, MessageRecv
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, replace_user_references_async
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
from maim_message import Seg
|
||||
from src.config.config import global_config
|
||||
|
||||
# 日志记录器
|
||||
logger = get_logger("send_api")
|
||||
|
||||
# 适配器命令响应等待池
|
||||
_adapter_response_pool: Dict[str, asyncio.Future] = {}
|
||||
|
||||
|
||||
def message_dict_to_message_recv(message_dict: Dict[str, Any]) -> Optional[MessageRecv]:
|
||||
"""查找要回复的消息
|
||||
|
||||
@@ -97,10 +96,11 @@ def message_dict_to_message_recv(message_dict: Dict[str, Any]) -> Optional[Messa
|
||||
}
|
||||
|
||||
message_recv = MessageRecv(message_dict)
|
||||
|
||||
|
||||
logger.info(f"[SendAPI] 找到匹配的回复消息,发送者: {message_dict.get('user_nickname', '')}")
|
||||
return message_recv
|
||||
|
||||
|
||||
def put_adapter_response(request_id: str, response_data: dict) -> None:
|
||||
"""将适配器响应放入响应池"""
|
||||
if request_id in _adapter_response_pool:
|
||||
@@ -187,14 +187,16 @@ async def _send_to_target(
|
||||
# 创建消息段
|
||||
message_segment = Seg(type=message_type, data=content) # type: ignore
|
||||
|
||||
# 处理回复消息
|
||||
if reply_to_message:
|
||||
anchor_message = message_dict_to_message_recv(message_dict=reply_to_message)
|
||||
anchor_message.update_chat_stream(target_stream)
|
||||
reply_to_platform_id = (
|
||||
f"{anchor_message.message_info.platform}:{anchor_message.message_info.user_info.user_id}"
|
||||
)
|
||||
)
|
||||
else:
|
||||
anchor_message = None
|
||||
reply_to_platform_id = None
|
||||
|
||||
# 构建发送消息对象
|
||||
bot_message = MessageSending(
|
||||
@@ -233,7 +235,6 @@ async def _send_to_target(
|
||||
return False
|
||||
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 公共API函数 - 预定义类型的发送函数
|
||||
# =============================================================================
|
||||
@@ -273,7 +274,9 @@ async def text_to_stream(
|
||||
)
|
||||
|
||||
|
||||
async def emoji_to_stream(emoji_base64: str, stream_id: str, storage_message: bool = True, set_reply: bool = False) -> bool:
|
||||
async def emoji_to_stream(
|
||||
emoji_base64: str, stream_id: str, storage_message: bool = True, set_reply: bool = False
|
||||
) -> bool:
|
||||
"""向指定流发送表情包
|
||||
|
||||
Args:
|
||||
@@ -284,10 +287,14 @@ async def emoji_to_stream(emoji_base64: str, stream_id: str, storage_message: bo
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
return await _send_to_target("emoji", emoji_base64, stream_id, "", typing=False, storage_message=storage_message, set_reply=set_reply)
|
||||
return await _send_to_target(
|
||||
"emoji", emoji_base64, stream_id, "", typing=False, storage_message=storage_message, set_reply=set_reply
|
||||
)
|
||||
|
||||
|
||||
async def image_to_stream(image_base64: str, stream_id: str, storage_message: bool = True, set_reply: bool = False) -> bool:
|
||||
async def image_to_stream(
|
||||
image_base64: str, stream_id: str, storage_message: bool = True, set_reply: bool = False
|
||||
) -> bool:
|
||||
"""向指定流发送图片
|
||||
|
||||
Args:
|
||||
@@ -298,11 +305,17 @@ async def image_to_stream(image_base64: str, stream_id: str, storage_message: bo
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
return await _send_to_target("image", image_base64, stream_id, "", typing=False, storage_message=storage_message, set_reply=set_reply)
|
||||
return await _send_to_target(
|
||||
"image", image_base64, stream_id, "", typing=False, storage_message=storage_message, set_reply=set_reply
|
||||
)
|
||||
|
||||
|
||||
async def command_to_stream(
|
||||
command: Union[str, dict], stream_id: str, storage_message: bool = True, display_message: str = "", set_reply: bool = False
|
||||
command: Union[str, dict],
|
||||
stream_id: str,
|
||||
storage_message: bool = True,
|
||||
display_message: str = "",
|
||||
set_reply: bool = False,
|
||||
) -> bool:
|
||||
"""向指定流发送命令
|
||||
|
||||
|
||||
@@ -93,6 +93,7 @@ class BaseAction(ABC):
|
||||
self.associated_types: list[str] = getattr(self.__class__, "associated_types", []).copy()
|
||||
self.chat_type_allow: ChatType = getattr(self.__class__, "chat_type_allow", ChatType.ALL)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 便捷属性 - 直接在初始化时获取常用聊天信息(带类型注解)
|
||||
# =============================================================================
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Tuple, Optional, Dict, List, Union
|
||||
from typing import Tuple, Optional, List, Union
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from .component_types import EventType, EventHandlerInfo, ComponentType
|
||||
@@ -23,17 +23,26 @@ class BaseEventHandler(ABC):
|
||||
"""是否拦截消息,默认为否"""
|
||||
init_subscribe: List[Union[EventType, str]] = [EventType.UNKNOWN]
|
||||
"""初始化时订阅的事件名称"""
|
||||
plugin_name = None
|
||||
|
||||
def __init__(self):
|
||||
self.log_prefix = "[EventHandler]"
|
||||
"""对应插件名"""
|
||||
self.plugin_config: Optional[Dict] = None
|
||||
"""插件配置字典"""
|
||||
|
||||
self.subscribed_events = []
|
||||
"""订阅的事件列表"""
|
||||
if EventType.UNKNOWN in self.init_subscribe:
|
||||
raise NotImplementedError("事件处理器必须指定 event_type")
|
||||
|
||||
# 优先使用实例级别的 plugin_config,如果没有则使用类级别的配置
|
||||
# 事件管理器会在注册时通过 set_plugin_config 设置实例级别的配置
|
||||
instance_config = getattr(self, "plugin_config", None)
|
||||
if instance_config is not None:
|
||||
self.plugin_config = instance_config
|
||||
else:
|
||||
# 如果实例级别没有配置,则使用类级别的配置(向后兼容)
|
||||
self.plugin_config = getattr(self.__class__, "plugin_config", {})
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, kwargs: dict | None) -> Tuple[bool, bool, Optional[str]]:
|
||||
"""执行事件处理的抽象方法,子类必须实现
|
||||
@@ -89,15 +98,7 @@ class BaseEventHandler(ABC):
|
||||
weight=cls.weight,
|
||||
intercept_message=cls.intercept_message,
|
||||
)
|
||||
|
||||
def set_plugin_config(self, plugin_config: Dict) -> None:
|
||||
"""设置插件配置
|
||||
|
||||
Args:
|
||||
plugin_config (dict): 插件配置字典
|
||||
"""
|
||||
self.plugin_config = plugin_config
|
||||
|
||||
|
||||
def set_plugin_name(self, plugin_name: str) -> None:
|
||||
"""设置插件名称
|
||||
|
||||
@@ -106,6 +107,9 @@ class BaseEventHandler(ABC):
|
||||
"""
|
||||
self.plugin_name = plugin_name
|
||||
|
||||
def set_plugin_config(self,plugin_config) -> None:
|
||||
self.plugin_config = plugin_config
|
||||
|
||||
def get_config(self, key: str, default=None):
|
||||
"""获取插件配置值,支持嵌套键访问
|
||||
|
||||
|
||||
@@ -69,6 +69,7 @@ class EventType(Enum):
|
||||
"""
|
||||
|
||||
ON_START = "on_start" # 启动事件,用于调用按时任务
|
||||
ON_STOP ="on_stop"
|
||||
ON_MESSAGE = "on_message"
|
||||
ON_PLAN = "on_plan"
|
||||
POST_LLM = "post_llm"
|
||||
@@ -215,27 +216,7 @@ class EventInfo(ComponentInfo):
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
self.component_type = ComponentType.EVENT
|
||||
|
||||
|
||||
# 事件类型枚举
|
||||
class EventType(Enum):
|
||||
"""
|
||||
事件类型枚举类
|
||||
"""
|
||||
|
||||
ON_START = "on_start" # 启动事件,用于调用按时任务
|
||||
ON_STOP = "on_stop" # 停止事件,用于调用按时任务
|
||||
ON_MESSAGE = "on_message"
|
||||
ON_PLAN = "on_plan"
|
||||
POST_LLM = "post_llm"
|
||||
AFTER_LLM = "after_llm"
|
||||
POST_SEND = "post_send"
|
||||
AFTER_SEND = "after_send"
|
||||
UNKNOWN = "unknown" # 未知事件类型
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
self.component_type = ComponentType.EVENT_HANDLER
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from pathlib import Path
|
||||
import re
|
||||
|
||||
from typing import Dict, List, Optional, Any, Pattern, Tuple, Union, Type
|
||||
@@ -170,6 +171,8 @@ class ComponentRegistry:
|
||||
return False
|
||||
|
||||
action_class.plugin_name = action_info.plugin_name
|
||||
# 设置插件配置
|
||||
action_class.plugin_config = self.get_plugin_config(action_info.plugin_name) or {}
|
||||
self._action_registry[action_name] = action_class
|
||||
|
||||
# 如果启用,添加到默认动作集
|
||||
@@ -188,6 +191,8 @@ class ComponentRegistry:
|
||||
return False
|
||||
|
||||
command_class.plugin_name = command_info.plugin_name
|
||||
# 设置插件配置
|
||||
command_class.plugin_config = self.get_plugin_config(command_info.plugin_name) or {}
|
||||
self._command_registry[command_name] = command_class
|
||||
|
||||
# 如果启用了且有匹配模式
|
||||
@@ -220,6 +225,8 @@ class ComponentRegistry:
|
||||
self._plus_command_registry: Dict[str, Type[PlusCommand]] = {}
|
||||
|
||||
plus_command_class.plugin_name = plus_command_info.plugin_name
|
||||
# 设置插件配置
|
||||
plus_command_class.plugin_config = self.get_plugin_config(plus_command_info.plugin_name) or {}
|
||||
self._plus_command_registry[plus_command_name] = plus_command_class
|
||||
|
||||
logger.debug(f"已注册PlusCommand组件: {plus_command_name}")
|
||||
@@ -230,6 +237,8 @@ class ComponentRegistry:
|
||||
tool_name = tool_info.name
|
||||
|
||||
tool_class.plugin_name = tool_info.plugin_name
|
||||
# 设置插件配置
|
||||
tool_class.plugin_config = self.get_plugin_config(tool_info.plugin_name) or {}
|
||||
self._tool_registry[tool_name] = tool_class
|
||||
|
||||
# 如果是llm可用的且启用的工具,添加到 llm可用工具列表
|
||||
@@ -248,6 +257,9 @@ class ComponentRegistry:
|
||||
logger.error(f"注册失败: {handler_name} 不是有效的EventHandler")
|
||||
return False
|
||||
|
||||
handler_class.plugin_name = handler_info.plugin_name
|
||||
# 设置插件配置
|
||||
handler_class.plugin_config = self.get_plugin_config(handler_info.plugin_name) or {}
|
||||
self._event_handler_registry[handler_name] = handler_class
|
||||
|
||||
if not handler_info.enabled:
|
||||
@@ -258,7 +270,7 @@ class ComponentRegistry:
|
||||
# 使用EventManager进行事件处理器注册
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
|
||||
return event_manager.register_event_handler(handler_class)
|
||||
return event_manager.register_event_handler(handler_class,self.get_plugin_config(handler_info.plugin_name) or {})
|
||||
|
||||
# === 组件移除相关 ===
|
||||
|
||||
@@ -655,20 +667,35 @@ class ComponentRegistry:
|
||||
plugin_info = self.get_plugin_info(plugin_name)
|
||||
return plugin_info.components if plugin_info else []
|
||||
|
||||
def get_plugin_config(self, plugin_name: str) -> Optional[dict]:
|
||||
def get_plugin_config(self, plugin_name: str) -> dict:
|
||||
"""获取插件配置
|
||||
|
||||
Args:
|
||||
plugin_name: 插件名称
|
||||
|
||||
Returns:
|
||||
Optional[dict]: 插件配置字典或None
|
||||
dict: 插件配置字典,如果插件实例不存在或配置为空,返回空字典
|
||||
"""
|
||||
# 从插件管理器获取插件实例的配置
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
|
||||
plugin_instance = plugin_manager.get_plugin_instance(plugin_name)
|
||||
return plugin_instance.config if plugin_instance else None
|
||||
if plugin_instance and plugin_instance.config:
|
||||
return plugin_instance.config
|
||||
|
||||
# 如果插件实例不存在,尝试从配置文件读取
|
||||
try:
|
||||
import toml
|
||||
config_path = Path("config") / "plugins" / plugin_name / "config.toml"
|
||||
if config_path.exists():
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
config_data = toml.load(f)
|
||||
logger.debug(f"从配置文件读取插件 {plugin_name} 的配置")
|
||||
return config_data
|
||||
except Exception as e:
|
||||
logger.debug(f"读取插件 {plugin_name} 配置文件失败: {e}")
|
||||
|
||||
return {}
|
||||
|
||||
def get_registry_stats(self) -> Dict[str, Any]:
|
||||
"""获取注册中心统计信息"""
|
||||
|
||||
@@ -68,7 +68,7 @@ class EventManager:
|
||||
event = BaseEvent(event_name, allowed_subscribers, allowed_triggers)
|
||||
self._events[event_name] = event
|
||||
logger.debug(f"事件 {event_name} 注册成功")
|
||||
|
||||
|
||||
# 检查是否有缓存的订阅需要处理
|
||||
self._process_pending_subscriptions(event_name)
|
||||
|
||||
@@ -145,11 +145,12 @@ class EventManager:
|
||||
logger.info(f"事件 {event_name} 已禁用")
|
||||
return True
|
||||
|
||||
def register_event_handler(self, handler_class: Type[BaseEventHandler]) -> bool:
|
||||
def register_event_handler(self, handler_class: Type[BaseEventHandler], plugin_config: Optional[dict] = None) -> bool:
|
||||
"""注册事件处理器
|
||||
|
||||
Args:
|
||||
handler_class (Type[BaseEventHandler]): 事件处理器类
|
||||
plugin_config (Optional[dict]): 插件配置字典,默认为None
|
||||
|
||||
Returns:
|
||||
bool: 注册成功返回True,已存在返回False
|
||||
@@ -163,7 +164,13 @@ class EventManager:
|
||||
logger.warning(f"事件处理器 {handler_name} 已存在,跳过注册")
|
||||
return False
|
||||
|
||||
self._event_handlers[handler_name] = handler_class()
|
||||
# 创建事件处理器实例,传递插件配置
|
||||
handler_instance = handler_class()
|
||||
handler_instance.plugin_config = plugin_config
|
||||
if plugin_config is not None and hasattr(handler_instance, 'set_plugin_config'):
|
||||
handler_instance.set_plugin_config(plugin_config)
|
||||
|
||||
self._event_handlers[handler_name] = handler_instance
|
||||
|
||||
# 处理init_subscribe,缓存失败的订阅
|
||||
if self._event_handlers[handler_name].init_subscribe:
|
||||
|
||||
@@ -200,7 +200,7 @@ class PluginManager:
|
||||
|
||||
# 检查并调用 on_plugin_loaded 钩子(如果存在)
|
||||
if hasattr(plugin_instance, "on_plugin_loaded") and callable(
|
||||
getattr(plugin_instance, "on_plugin_loaded")
|
||||
plugin_instance.on_plugin_loaded
|
||||
):
|
||||
logger.debug(f"为插件 '{plugin_name}' 调用 on_plugin_loaded 钩子")
|
||||
try:
|
||||
|
||||
@@ -6,7 +6,7 @@ from src.plugin_system.core.global_announcement_manager import global_announceme
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.llm_models.payload_content import ToolCall
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
import inspect
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -59,19 +59,79 @@ class AtAction(BaseAction):
|
||||
if not user_info or not user_info.get("user_id"):
|
||||
logger.info(f"找不到名为 '{user_name}' 的用户。")
|
||||
return False, "用户不存在"
|
||||
await self.send_command(
|
||||
"SEND_AT_MESSAGE",
|
||||
args={"qq_id": user_info.get("user_id"), "text": at_message},
|
||||
display_message=f"艾特用户 {user_name} 并发送消息: {at_message}",
|
||||
)
|
||||
await self.store_action_info(
|
||||
action_build_into_prompt=True,
|
||||
action_prompt_display=f"执行了艾特用户动作:艾特用户 {user_name} 并发送消息: {at_message}",
|
||||
action_done=True,
|
||||
)
|
||||
|
||||
logger.info("艾特用户的动作已触发,但具体实现待完成。")
|
||||
return True, "艾特用户的动作已触发,但具体实现待完成。"
|
||||
try:
|
||||
# 使用回复器生成艾特回复,而不是直接发送命令
|
||||
from src.chat.replyer.default_generator import DefaultReplyer
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
# 获取当前聊天流
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream = chat_manager.get_stream(self.chat_id)
|
||||
|
||||
if not chat_stream:
|
||||
logger.error(f"找不到聊天流: {self.stream_id}")
|
||||
return False, "聊天流不存在"
|
||||
|
||||
# 创建回复器实例
|
||||
replyer = DefaultReplyer(chat_stream)
|
||||
|
||||
# 构建回复对象,将艾特消息作为回复目标
|
||||
reply_to = f"{user_name}:{at_message}"
|
||||
extra_info = f"你需要艾特用户 {user_name} 并回复他们说: {at_message}"
|
||||
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
from src.plugin_system import EventType
|
||||
# 触发post_llm
|
||||
result = await event_manager.trigger_event(EventType.POST_LLM,plugin_name="SYSTEM")
|
||||
if not result.all_continue_process():
|
||||
return False, f"被组件{result.get_summary().get("stopped_handlers","")}打断"
|
||||
|
||||
# 使用回复器生成回复
|
||||
success, llm_response, prompt = await replyer.generate_reply_with_context(
|
||||
reply_to=reply_to,
|
||||
extra_info=extra_info,
|
||||
enable_tool=False, # 艾特回复通常不需要工具调用
|
||||
from_plugin=True # 标识来自插件
|
||||
)
|
||||
|
||||
if success and llm_response:
|
||||
# 获取生成的回复内容
|
||||
reply_content = llm_response.get("content", "")
|
||||
if reply_content:
|
||||
# 获取用户QQ号,发送真正的艾特消息
|
||||
user_id = user_info.get("user_id")
|
||||
|
||||
# 发送真正的艾特命令,使用回复器生成的智能内容
|
||||
await self.send_command(
|
||||
"SEND_AT_MESSAGE",
|
||||
args={"qq_id": user_id, "text": reply_content},
|
||||
display_message=f"艾特用户 {user_name} 并发送智能回复: {reply_content}",
|
||||
)
|
||||
|
||||
await self.store_action_info(
|
||||
action_build_into_prompt=True,
|
||||
action_prompt_display=f"执行了艾特用户动作:艾特用户 {user_name} 并发送智能回复: {reply_content}",
|
||||
action_done=True,
|
||||
)
|
||||
|
||||
logger.info(f"成功通过回复器生成智能内容并发送真正的艾特消息给 {user_name}: {reply_content}")
|
||||
return True, "智能艾特消息发送成功"
|
||||
else:
|
||||
logger.warning("回复器生成了空内容")
|
||||
return False, "回复内容为空"
|
||||
else:
|
||||
logger.error("回复器生成回复失败")
|
||||
return False, "回复生成失败"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"执行艾特用户动作时发生异常: {e}", exc_info=True)
|
||||
await self.store_action_info(
|
||||
action_build_into_prompt=True,
|
||||
action_prompt_display=f"执行艾特用户动作失败:{str(e)}",
|
||||
action_done=False,
|
||||
)
|
||||
return False, f"执行失败: {str(e)}"
|
||||
|
||||
|
||||
class AtCommand(BaseCommand):
|
||||
|
||||
@@ -39,8 +39,9 @@ class EmojiAction(BaseAction):
|
||||
llm_judge_prompt = """
|
||||
判定是否需要使用表情动作的条件:
|
||||
1. 用户明确要求使用表情包
|
||||
2. 这是一个适合表达强烈情绪的场合
|
||||
3. 不要发送太多表情包,如果你已经发送过多个表情包则回答"否"
|
||||
2. 这是一个适合表达情绪的场合
|
||||
3. 发表情包能使当前对话更有趣
|
||||
4. 不要发送太多表情包,如果你已经发送过多个表情包则回答"否"
|
||||
|
||||
请回答"是"或"否"。
|
||||
"""
|
||||
|
||||
@@ -53,7 +53,9 @@ class MaiZoneRefactoredPlugin(BasePlugin):
|
||||
"enable_reply": ConfigField(type=bool, default=True, description="完成后是否回复"),
|
||||
"ai_image_number": ConfigField(type=int, default=1, description="AI生成图片数量"),
|
||||
"image_number": ConfigField(type=int, default=1, description="本地配图数量(1-9张)"),
|
||||
"image_directory": ConfigField(type=str, default=(Path(__file__).parent / "images").as_posix(), description="图片存储目录")
|
||||
"image_directory": ConfigField(
|
||||
type=str, default=(Path(__file__).parent / "images").as_posix(), description="图片存储目录"
|
||||
),
|
||||
},
|
||||
"read": {
|
||||
"permission": ConfigField(type=list, default=[], description="阅读权限QQ号列表"),
|
||||
@@ -75,7 +77,9 @@ class MaiZoneRefactoredPlugin(BasePlugin):
|
||||
"forbidden_hours_end": ConfigField(type=int, default=6, description="禁止发送的结束小时(24小时制)"),
|
||||
},
|
||||
"cookie": {
|
||||
"http_fallback_host": ConfigField(type=str, default="127.0.0.1", description="备用Cookie获取服务的主机地址"),
|
||||
"http_fallback_host": ConfigField(
|
||||
type=str, default="127.0.0.1", description="备用Cookie获取服务的主机地址"
|
||||
),
|
||||
"http_fallback_port": ConfigField(type=int, default=9999, description="备用Cookie获取服务的端口"),
|
||||
"napcat_token": ConfigField(type=str, default="", description="Napcat服务的认证Token(可选)"),
|
||||
},
|
||||
@@ -95,14 +99,14 @@ class MaiZoneRefactoredPlugin(BasePlugin):
|
||||
image_service = ImageService(self.get_config)
|
||||
cookie_service = CookieService(self.get_config)
|
||||
reply_tracker_service = ReplyTrackerService()
|
||||
|
||||
|
||||
# 使用已创建的 reply_tracker_service 实例
|
||||
qzone_service = QZoneService(
|
||||
self.get_config,
|
||||
content_service,
|
||||
image_service,
|
||||
self.get_config,
|
||||
content_service,
|
||||
image_service,
|
||||
cookie_service,
|
||||
reply_tracker_service # 传入已创建的实例
|
||||
reply_tracker_service, # 传入已创建的实例
|
||||
)
|
||||
scheduler_service = SchedulerService(self.get_config, qzone_service)
|
||||
monitor_service = MonitorService(self.get_config, qzone_service)
|
||||
|
||||
@@ -13,6 +13,7 @@ from src.common.logger import get_logger
|
||||
import imghdr
|
||||
import asyncio
|
||||
from src.plugin_system.apis import llm_api, config_api, generator_api
|
||||
from src.plugin_system.apis.cross_context_api import get_chat_history_by_group_name
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from maim_message import UserInfo
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
@@ -87,6 +88,11 @@ class ContentService:
|
||||
if context:
|
||||
prompt += f"\n作为参考,这里有一些最近的聊天记录:\n---\n{context}\n---"
|
||||
|
||||
# 添加跨群聊上下文
|
||||
cross_context = await get_chat_history_by_group_name("maizone_context_group")
|
||||
if cross_context and "找不到名为" not in cross_context:
|
||||
prompt += f"\n\n---跨群聊参考---\n{cross_context}\n---"
|
||||
|
||||
# 添加历史记录以避免重复
|
||||
prompt += "\n\n---历史说说记录---\n"
|
||||
history_block = await get_send_history(qq_account)
|
||||
@@ -232,7 +238,7 @@ class ContentService:
|
||||
for i in range(3): # 重试3次
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(image_url, timeout=30) as resp:
|
||||
async with session.get(image_url, timeout=aiohttp.ClientTimeout(total=30)) as resp:
|
||||
if resp.status != 200:
|
||||
logger.error(f"下载图片失败: {image_url}, status: {resp.status}")
|
||||
await asyncio.sleep(2)
|
||||
|
||||
@@ -272,8 +272,10 @@ class QZoneService:
|
||||
# 检查是否已经在持久化记录中标记为已回复
|
||||
if not self.reply_tracker.has_replied(fid, comment_tid):
|
||||
# 记录日志以便追踪
|
||||
logger.debug(f"发现新评论需要回复 - 说说ID: {fid}, 评论ID: {comment_tid}, "
|
||||
f"评论人: {comment.get('nickname', '')}, 内容: {comment.get('content', '')}")
|
||||
logger.debug(
|
||||
f"发现新评论需要回复 - 说说ID: {fid}, 评论ID: {comment_tid}, "
|
||||
f"评论人: {comment.get('nickname', '')}, 内容: {comment.get('content', '')}"
|
||||
)
|
||||
comments_to_reply.append(comment)
|
||||
|
||||
if not comments_to_reply:
|
||||
@@ -791,10 +793,11 @@ class QZoneService:
|
||||
try:
|
||||
# 修复回复逻辑:确保能正确提醒被回复的人
|
||||
data = {
|
||||
"topicId": f"{host_qq}_{fid}__1", # 使用标准评论格式,而不是针对特定评论
|
||||
"topicId": f"{host_qq}_{fid}__1",
|
||||
"parent_tid": comment_tid,
|
||||
"uin": uin,
|
||||
"hostUin": host_qq,
|
||||
"content": f"回复@{target_name}:{content}", # 内容中明确标示回复对象
|
||||
"content": content,
|
||||
"format": "fs",
|
||||
"plat": "qzone",
|
||||
"source": "ic",
|
||||
@@ -802,12 +805,14 @@ class QZoneService:
|
||||
"ref": "feeds",
|
||||
"richtype": "",
|
||||
"richval": "",
|
||||
"paramstr": f"@{target_name}", # 确保触发@提醒机制
|
||||
"paramstr": "",
|
||||
}
|
||||
|
||||
|
||||
# 记录详细的请求参数用于调试
|
||||
logger.info(f"子回复请求参数: topicId={data['topicId']}, parent_tid={data['parent_tid']}, content='{content[:50]}...'")
|
||||
|
||||
logger.info(
|
||||
f"子回复请求参数: topicId={data['topicId']}, parent_tid={data['parent_tid']}, content='{content[:50]}...'"
|
||||
)
|
||||
|
||||
await _request("POST", self.REPLY_URL, params={"g_tk": gtk}, data=data)
|
||||
return True
|
||||
except Exception as e:
|
||||
|
||||
@@ -41,7 +41,7 @@ class ReplyTrackerService:
|
||||
if not isinstance(data, dict):
|
||||
logger.error("加载的数据不是字典格式")
|
||||
return False
|
||||
|
||||
|
||||
for feed_id, comments in data.items():
|
||||
if not isinstance(feed_id, str):
|
||||
logger.error(f"无效的说说ID格式: {feed_id}")
|
||||
@@ -70,12 +70,14 @@ class ReplyTrackerService:
|
||||
logger.warning("回复记录文件为空,将创建新的记录")
|
||||
self.replied_comments = {}
|
||||
return
|
||||
|
||||
|
||||
data = json.loads(file_content)
|
||||
if self._validate_data(data):
|
||||
self.replied_comments = data
|
||||
logger.info(f"已加载 {len(self.replied_comments)} 条说说的回复记录,"
|
||||
f"总计 {sum(len(comments) for comments in self.replied_comments.values())} 条评论")
|
||||
logger.info(
|
||||
f"已加载 {len(self.replied_comments)} 条说说的回复记录,"
|
||||
f"总计 {sum(len(comments) for comments in self.replied_comments.values())} 条评论"
|
||||
)
|
||||
else:
|
||||
logger.error("加载的数据格式无效,将创建新的记录")
|
||||
self.replied_comments = {}
|
||||
@@ -112,12 +114,12 @@ class ReplyTrackerService:
|
||||
self._cleanup_old_records()
|
||||
|
||||
# 创建临时文件
|
||||
temp_file = self.reply_record_file.with_suffix('.tmp')
|
||||
|
||||
temp_file = self.reply_record_file.with_suffix(".tmp")
|
||||
|
||||
# 先写入临时文件
|
||||
with open(temp_file, "w", encoding="utf-8") as f:
|
||||
json.dump(self.replied_comments, f, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
# 如果写入成功,重命名为正式文件
|
||||
if temp_file.stat().st_size > 0: # 确保写入成功
|
||||
# 在Windows上,如果目标文件已存在,需要先删除它
|
||||
@@ -128,7 +130,7 @@ class ReplyTrackerService:
|
||||
else:
|
||||
logger.error("临时文件写入失败,文件大小为0")
|
||||
temp_file.unlink() # 删除空的临时文件
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"保存回复记录失败: {e}", exc_info=True)
|
||||
# 尝试删除可能存在的临时文件
|
||||
@@ -204,7 +206,7 @@ class ReplyTrackerService:
|
||||
|
||||
# 确保将comment_id转换为字符串格式
|
||||
comment_id_str = str(comment_id)
|
||||
|
||||
|
||||
if feed_id not in self.replied_comments:
|
||||
self.replied_comments[feed_id] = {}
|
||||
|
||||
|
||||
@@ -1746,3 +1746,32 @@ class SetGroupSignHandler(BaseEventHandler):
|
||||
else:
|
||||
logger.error("事件 napcat_set_group_sign 请求失败!")
|
||||
return HandlerResult(False, False, {"status": "error"})
|
||||
|
||||
# ===PERSONAL===
|
||||
class SetInputStatusHandler(BaseEventHandler):
|
||||
handler_name: str = "napcat_set_input_status_handler"
|
||||
handler_description: str = "设置输入状态"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.PERSONAL.SET_INPUT_STATUS]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
user_id = params.get("user_id", "")
|
||||
event_type = params.get("event_type", 0)
|
||||
|
||||
if params.get("raw", ""):
|
||||
user_id = raw.get("user_id", "")
|
||||
event_type = raw.get("event_type", 0)
|
||||
|
||||
if not user_id or event_type is None:
|
||||
logger.error("事件 napcat_set_input_status 缺少必要参数: user_id 或 event_type")
|
||||
return HandlerResult(False, False, {"status": "error"})
|
||||
|
||||
payload = {"user_id": str(user_id), "event_type": int(event_type)}
|
||||
response = await send_handler.send_message_to_napcat(action="set_input_status", params=payload)
|
||||
if response.get("status", "") == "ok":
|
||||
return HandlerResult(True, True, response)
|
||||
else:
|
||||
logger.error("事件 napcat_set_input_status 请求失败!")
|
||||
return HandlerResult(False, False, {"status": "error"})
|
||||
@@ -1816,3 +1816,27 @@ class NapcatEvent:
|
||||
"""
|
||||
|
||||
class FILE(Enum): ...
|
||||
|
||||
class PERSONAL(Enum):
|
||||
SET_INPUT_STATUS = "napcat_set_input_status"
|
||||
"""
|
||||
设置输入状态
|
||||
|
||||
Args:
|
||||
user_id (Optional[str|int]): 用户id(必需)
|
||||
event_type (Optional[int]): 输入状态id(必需)
|
||||
raw (Optional[dict]): 原始请求体
|
||||
|
||||
Returns:
|
||||
dict: {
|
||||
"status": "ok",
|
||||
"retcode": 0,
|
||||
"data": {
|
||||
"result": 0,
|
||||
"errMsg": "string"
|
||||
},
|
||||
"message": "string",
|
||||
"wording": "string",
|
||||
"echo": "string"
|
||||
}
|
||||
"""
|
||||
@@ -7,8 +7,8 @@ from . import event_types, CONSTS, event_handlers
|
||||
from typing import List
|
||||
|
||||
from src.plugin_system import BasePlugin, BaseEventHandler, register_plugin, EventType, ConfigField
|
||||
from src.plugin_system.base.base_event import HandlerResult
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
from src.plugin_system.apis import config_api
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -18,9 +18,6 @@ from .src.recv_handler.meta_event_handler import meta_event_handler
|
||||
from .src.recv_handler.notice_handler import notice_handler
|
||||
from .src.recv_handler.message_sending import message_send_instance
|
||||
from .src.send_handler import send_handler
|
||||
from .src.config import global_config
|
||||
from .src.config.features_config import features_manager
|
||||
from .src.config.migrate_features import auto_migrate_features
|
||||
from .src.mmc_com_layer import mmc_start_com, router, mmc_stop_com
|
||||
from .src.response_pool import put_response, check_timeout_response
|
||||
from .src.websocket_manager import websocket_manager
|
||||
@@ -37,17 +34,20 @@ def get_classes_in_module(module):
|
||||
classes.append(member)
|
||||
return classes
|
||||
|
||||
|
||||
async def message_recv(server_connection: Server.ServerConnection):
|
||||
await message_handler.set_server_connection(server_connection)
|
||||
asyncio.create_task(notice_handler.set_server_connection(server_connection))
|
||||
await send_handler.set_server_connection(server_connection)
|
||||
async for raw_message in server_connection:
|
||||
logger.debug(f"{raw_message[:1500]}..." if (len(raw_message) > 1500) else raw_message)
|
||||
# 只在debug模式下记录原始消息
|
||||
if logger.level <= 10: # DEBUG level
|
||||
logger.debug(f"{raw_message[:1500]}..." if (len(raw_message) > 1500) else raw_message)
|
||||
decoded_raw_message: dict = json.loads(raw_message)
|
||||
try:
|
||||
# 首先尝试解析原始消息
|
||||
decoded_raw_message: dict = json.loads(raw_message)
|
||||
|
||||
|
||||
# 检查是否是切片消息 (来自 MMC)
|
||||
if chunker.is_chunk_message(decoded_raw_message):
|
||||
logger.debug("接收到切片消息,尝试重组")
|
||||
@@ -61,14 +61,14 @@ async def message_recv(server_connection: Server.ServerConnection):
|
||||
# 切片尚未完整,继续等待更多切片
|
||||
logger.debug("等待更多切片...")
|
||||
continue
|
||||
|
||||
|
||||
# 处理完整消息(可能是重组后的,也可能是原本就完整的)
|
||||
post_type = decoded_raw_message.get("post_type")
|
||||
if post_type in ["meta_event", "message", "notice"]:
|
||||
await message_queue.put(decoded_raw_message)
|
||||
elif post_type is None:
|
||||
await put_response(decoded_raw_message)
|
||||
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"消息解析失败: {e}")
|
||||
logger.debug(f"原始消息: {raw_message[:500]}...")
|
||||
@@ -76,6 +76,7 @@ async def message_recv(server_connection: Server.ServerConnection):
|
||||
logger.error(f"处理消息时出错: {e}")
|
||||
logger.debug(f"原始消息: {raw_message[:500]}...")
|
||||
|
||||
|
||||
async def message_process():
|
||||
"""消息处理主循环"""
|
||||
logger.info("消息处理器已启动")
|
||||
@@ -84,7 +85,7 @@ async def message_process():
|
||||
try:
|
||||
# 使用超时等待,以便能够响应取消请求
|
||||
message = await asyncio.wait_for(message_queue.get(), timeout=1.0)
|
||||
|
||||
|
||||
post_type = message.get("post_type")
|
||||
if post_type == "message":
|
||||
await message_handler.handle_raw_message(message)
|
||||
@@ -94,10 +95,10 @@ async def message_process():
|
||||
await notice_handler.handle_notice(message)
|
||||
else:
|
||||
logger.warning(f"未知的post_type: {post_type}")
|
||||
|
||||
|
||||
message_queue.task_done()
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# 超时是正常的,继续循环
|
||||
continue
|
||||
@@ -112,7 +113,7 @@ async def message_process():
|
||||
except ValueError:
|
||||
pass
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("消息处理器已停止")
|
||||
raise
|
||||
@@ -132,74 +133,70 @@ async def message_process():
|
||||
except Exception as e:
|
||||
logger.debug(f"清理消息队列时出错: {e}")
|
||||
|
||||
async def napcat_server():
|
||||
|
||||
async def napcat_server(plugin_config: dict):
|
||||
"""启动 Napcat WebSocket 连接(支持正向和反向连接)"""
|
||||
mode = global_config.napcat_server.mode
|
||||
# 使用插件系统配置API获取配置
|
||||
mode = config_api.get_plugin_config(plugin_config, "napcat_server.mode")
|
||||
logger.info(f"正在启动 adapter,连接模式: {mode}")
|
||||
|
||||
try:
|
||||
await websocket_manager.start_connection(message_recv)
|
||||
await websocket_manager.start_connection(message_recv, plugin_config)
|
||||
except Exception as e:
|
||||
logger.error(f"启动 WebSocket 连接失败: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def graceful_shutdown():
|
||||
"""优雅关闭所有组件"""
|
||||
try:
|
||||
logger.info("正在关闭adapter...")
|
||||
|
||||
|
||||
# 停止消息重组器的清理任务
|
||||
try:
|
||||
await reassembler.stop_cleanup_task()
|
||||
except Exception as e:
|
||||
logger.warning(f"停止消息重组器清理任务时出错: {e}")
|
||||
|
||||
# 停止功能管理器文件监控
|
||||
try:
|
||||
await features_manager.stop_file_watcher()
|
||||
except Exception as e:
|
||||
logger.warning(f"停止功能管理器文件监控时出错: {e}")
|
||||
|
||||
|
||||
# 停止功能管理器文件监控(已迁移到插件系统配置,无需操作)
|
||||
|
||||
# 关闭消息处理器(包括消息缓冲器)
|
||||
try:
|
||||
await message_handler.shutdown()
|
||||
except Exception as e:
|
||||
logger.warning(f"关闭消息处理器时出错: {e}")
|
||||
|
||||
|
||||
# 关闭 WebSocket 连接
|
||||
try:
|
||||
await websocket_manager.stop_connection()
|
||||
except Exception as e:
|
||||
logger.warning(f"关闭WebSocket连接时出错: {e}")
|
||||
|
||||
|
||||
# 关闭 MaiBot 连接
|
||||
try:
|
||||
await mmc_stop_com()
|
||||
except Exception as e:
|
||||
logger.warning(f"关闭MaiBot连接时出错: {e}")
|
||||
|
||||
|
||||
# 取消所有剩余任务
|
||||
current_task = asyncio.current_task()
|
||||
tasks = [t for t in asyncio.all_tasks() if t is not current_task and not t.done()]
|
||||
|
||||
|
||||
if tasks:
|
||||
logger.info(f"正在取消 {len(tasks)} 个剩余任务...")
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
|
||||
|
||||
# 等待任务取消完成,忽略 CancelledError
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
asyncio.gather(*tasks, return_exceptions=True),
|
||||
timeout=10
|
||||
)
|
||||
await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=10)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("部分任务取消超时")
|
||||
except Exception as e:
|
||||
logger.debug(f"任务取消过程中的异常(可忽略): {e}")
|
||||
|
||||
|
||||
logger.info("Adapter已成功关闭")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Adapter关闭中出现错误: {e}")
|
||||
finally:
|
||||
@@ -214,6 +211,7 @@ async def graceful_shutdown():
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
class LauchNapcatAdapterHandler(BaseEventHandler):
|
||||
"""自动启动Adapter"""
|
||||
|
||||
@@ -224,27 +222,44 @@ class LauchNapcatAdapterHandler(BaseEventHandler):
|
||||
init_subscribe = [EventType.ON_START]
|
||||
|
||||
async def execute(self, kwargs):
|
||||
# 执行功能配置迁移(如果需要)
|
||||
logger.info("检查功能配置迁移...")
|
||||
auto_migrate_features()
|
||||
|
||||
# 启动消息重组器的清理任务
|
||||
logger.info("启动消息重组器...")
|
||||
await reassembler.start_cleanup_task()
|
||||
|
||||
# 初始化功能管理器
|
||||
logger.info("正在初始化功能管理器...")
|
||||
features_manager.load_config()
|
||||
await features_manager.start_file_watcher(check_interval=2.0)
|
||||
logger.info("功能管理器初始化完成")
|
||||
logger.info("开始启动Napcat Adapter")
|
||||
message_send_instance.maibot_router = router
|
||||
|
||||
# 创建单独的异步任务,防止阻塞主线程
|
||||
asyncio.create_task(napcat_server())
|
||||
asyncio.create_task(mmc_start_com())
|
||||
asyncio.create_task(self._start_maibot_connection())
|
||||
asyncio.create_task(napcat_server(self.plugin_config))
|
||||
asyncio.create_task(message_process())
|
||||
asyncio.create_task(check_timeout_response())
|
||||
|
||||
async def _start_maibot_connection(self):
|
||||
"""非阻塞方式启动MaiBot连接,等待主服务启动后再连接"""
|
||||
# 等待一段时间让MaiBot主服务完全启动
|
||||
await asyncio.sleep(5)
|
||||
|
||||
max_attempts = 10
|
||||
attempt = 0
|
||||
|
||||
while attempt < max_attempts:
|
||||
try:
|
||||
logger.info(f"尝试连接MaiBot (第{attempt + 1}次)")
|
||||
await mmc_start_com(self.plugin_config)
|
||||
message_send_instance.maibot_router = router
|
||||
logger.info("MaiBot router连接已建立")
|
||||
return
|
||||
except Exception as e:
|
||||
attempt += 1
|
||||
if attempt >= max_attempts:
|
||||
logger.error(f"MaiBot连接失败,已达到最大重试次数: {e}")
|
||||
return
|
||||
else:
|
||||
delay = min(2 + attempt, 10) # 逐渐增加延迟,最大10秒
|
||||
logger.warning(f"MaiBot连接失败: {e},{delay}秒后重试")
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
|
||||
class StopNapcatAdapterHandler(BaseEventHandler):
|
||||
"""关闭Adapter"""
|
||||
|
||||
@@ -257,16 +272,24 @@ class StopNapcatAdapterHandler(BaseEventHandler):
|
||||
async def execute(self, kwargs):
|
||||
await graceful_shutdown()
|
||||
return
|
||||
|
||||
|
||||
|
||||
@register_plugin
|
||||
class NapcatAdapterPlugin(BasePlugin):
|
||||
plugin_name = CONSTS.PLUGIN_NAME
|
||||
enable_plugin: bool = True
|
||||
dependencies: List[str] = [] # 插件依赖列表
|
||||
python_dependencies: List[str] = [] # Python包依赖列表
|
||||
config_file_name: str = "config.toml" # 配置文件名
|
||||
|
||||
@property
|
||||
def enable_plugin(self) -> bool:
|
||||
"""通过配置文件动态控制插件启用状态"""
|
||||
# 如果已经通过配置加载了状态,使用配置中的值
|
||||
if hasattr(self, '_is_enabled'):
|
||||
return self._is_enabled
|
||||
# 否则使用默认值(禁用状态)
|
||||
return False
|
||||
|
||||
# 配置节描述
|
||||
config_section_descriptions = {"plugin": "插件基本信息"}
|
||||
|
||||
@@ -275,10 +298,84 @@ class NapcatAdapterPlugin(BasePlugin):
|
||||
"plugin": {
|
||||
"name": ConfigField(type=str, default="napcat_adapter_plugin", description="插件名称"),
|
||||
"version": ConfigField(type=str, default="1.0.0", description="插件版本"),
|
||||
"config_version": ConfigField(type=str, default="1.3.0", description="配置文件版本"),
|
||||
"enabled": ConfigField(type=bool, default=False, description="是否启用插件"),
|
||||
},
|
||||
"inner": {
|
||||
"version": ConfigField(type=str, default="0.2.1", description="配置版本号,请勿修改"),
|
||||
},
|
||||
"nickname": {
|
||||
"nickname": ConfigField(type=str, default="", description="昵称配置(目前未使用)"),
|
||||
},
|
||||
"napcat_server": {
|
||||
"mode": ConfigField(type=str, default="reverse", description="连接模式:reverse=反向连接(作为服务器), forward=正向连接(作为客户端)", choices=["reverse", "forward"]),
|
||||
"host": ConfigField(type=str, default="localhost", description="主机地址"),
|
||||
"port": ConfigField(type=int, default=8095, description="端口号"),
|
||||
"url": ConfigField(type=str, default="", description="正向连接时的完整WebSocket URL,如 ws://localhost:8080/ws (仅在forward模式下使用)"),
|
||||
"access_token": ConfigField(type=str, default="", description="WebSocket 连接的访问令牌,用于身份验证(可选)"),
|
||||
"heartbeat_interval": ConfigField(type=int, default=30, description="心跳间隔时间(按秒计)"),
|
||||
},
|
||||
"maibot_server": {
|
||||
"host": ConfigField(type=str, default="localhost", description="麦麦在.env文件中设置的主机地址,即HOST字段"),
|
||||
"port": ConfigField(type=int, default=8000, description="麦麦在.env文件中设置的端口,即PORT字段"),
|
||||
"platform_name": ConfigField(type=str, default="napcat", description="平台名称,用于消息路由"),
|
||||
},
|
||||
"voice": {
|
||||
"use_tts": ConfigField(type=bool, default=False, description="是否使用tts语音(请确保你配置了tts并有对应的adapter)"),
|
||||
},
|
||||
"slicing": {
|
||||
"max_frame_size": ConfigField(type=int, default=64, description="WebSocket帧的最大大小,单位为字节,默认64KB"),
|
||||
"delay_ms": ConfigField(type=int, default=10, description="切片发送间隔时间,单位为毫秒"),
|
||||
},
|
||||
"debug": {
|
||||
"level": ConfigField(type=str, default="INFO", description="日志等级(DEBUG, INFO, WARNING, ERROR, CRITICAL)", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]),
|
||||
},
|
||||
"features": {
|
||||
# 权限设置
|
||||
"group_list_type": ConfigField(type=str, default="blacklist", description="群聊列表类型:whitelist(白名单)或 blacklist(黑名单)", choices=["whitelist", "blacklist"]),
|
||||
"group_list": ConfigField(type=list, default=[], description="群聊ID列表"),
|
||||
"private_list_type": ConfigField(type=str, default="blacklist", description="私聊列表类型:whitelist(白名单)或 blacklist(黑名单)", choices=["whitelist", "blacklist"]),
|
||||
"private_list": ConfigField(type=list, default=[], description="用户ID列表"),
|
||||
"ban_user_id": ConfigField(type=list, default=[], description="全局禁止用户ID列表,这些用户无法在任何地方使用机器人"),
|
||||
"ban_qq_bot": ConfigField(type=bool, default=False, description="是否屏蔽QQ官方机器人消息"),
|
||||
|
||||
# 聊天功能设置
|
||||
"enable_poke": ConfigField(type=bool, default=True, description="是否启用戳一戳功能"),
|
||||
"ignore_non_self_poke": ConfigField(type=bool, default=False, description="是否无视不是针对自己的戳一戳"),
|
||||
"poke_debounce_seconds": ConfigField(type=int, default=3, description="戳一戳防抖时间(秒),在指定时间内第二次针对机器人的戳一戳将被忽略"),
|
||||
"enable_reply_at": ConfigField(type=bool, default=True, description="是否启用引用回复时艾特用户的功能"),
|
||||
"reply_at_rate": ConfigField(type=float, default=0.5, description="引用回复时艾特用户的几率 (0.0 ~ 1.0)"),
|
||||
|
||||
# 视频处理设置
|
||||
"enable_video_analysis": ConfigField(type=bool, default=True, description="是否启用视频识别功能"),
|
||||
"max_video_size_mb": ConfigField(type=int, default=100, description="视频文件最大大小限制(MB)"),
|
||||
"download_timeout": ConfigField(type=int, default=60, description="视频下载超时时间(秒)"),
|
||||
"supported_formats": ConfigField(type=list, default=["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"], description="支持的视频格式"),
|
||||
|
||||
# 消息缓冲设置
|
||||
"enable_message_buffer": ConfigField(type=bool, default=True, description="是否启用消息缓冲合并功能"),
|
||||
"message_buffer_enable_group": ConfigField(type=bool, default=True, description="是否启用群聊消息缓冲合并"),
|
||||
"message_buffer_enable_private": ConfigField(type=bool, default=True, description="是否启用私聊消息缓冲合并"),
|
||||
"message_buffer_interval": ConfigField(type=float, default=3.0, description="消息合并间隔时间(秒),在此时间内的连续消息将被合并"),
|
||||
"message_buffer_initial_delay": ConfigField(type=float, default=0.5, description="消息缓冲初始延迟(秒),收到第一条消息后等待此时间开始合并"),
|
||||
"message_buffer_max_components": ConfigField(type=int, default=50, description="单个会话最大缓冲消息组件数量,超过此数量将强制合并"),
|
||||
"message_buffer_block_prefixes": ConfigField(type=list, default=["/", "!", "!", ".", "。", "#", "%"], description="消息缓冲屏蔽前缀,以这些前缀开头的消息不会被缓冲"),
|
||||
}
|
||||
}
|
||||
|
||||
# 配置节描述
|
||||
config_section_descriptions = {
|
||||
"plugin": "插件基本信息",
|
||||
"inner": "内部配置信息(请勿修改)",
|
||||
"nickname": "昵称配置(目前未使用)",
|
||||
"napcat_server": "Napcat连接的ws服务设置",
|
||||
"maibot_server": "连接麦麦的ws服务设置",
|
||||
"voice": "发送语音设置",
|
||||
"slicing": "WebSocket消息切片设置",
|
||||
"debug": "调试设置",
|
||||
"features": "功能设置(权限控制、聊天功能、视频处理、消息缓冲等)"
|
||||
}
|
||||
|
||||
def register_events(self):
|
||||
# 注册事件
|
||||
for e in event_types.NapcatEvent.ON_RECEIVED:
|
||||
@@ -295,7 +392,7 @@ class NapcatAdapterPlugin(BasePlugin):
|
||||
|
||||
def get_plugin_components(self):
|
||||
self.register_events()
|
||||
|
||||
|
||||
components = []
|
||||
components.append((LauchNapcatAdapterHandler.get_handler_info(), LauchNapcatAdapterHandler))
|
||||
components.append((StopNapcatAdapterHandler.get_handler_info(), StopNapcatAdapterHandler))
|
||||
@@ -303,3 +400,21 @@ class NapcatAdapterPlugin(BasePlugin):
|
||||
if issubclass(handler, BaseEventHandler):
|
||||
components.append((handler.get_handler_info(), handler))
|
||||
return components
|
||||
|
||||
async def on_plugin_loaded(self):
|
||||
# 设置插件配置
|
||||
message_send_instance.set_plugin_config(self.config)
|
||||
# 设置chunker的插件配置
|
||||
chunker.set_plugin_config(self.config)
|
||||
# 设置response_pool的插件配置
|
||||
from .src.response_pool import set_plugin_config as set_response_pool_config
|
||||
set_response_pool_config(self.config)
|
||||
# 设置send_handler的插件配置
|
||||
send_handler.set_plugin_config(self.config)
|
||||
# 设置message_handler的插件配置
|
||||
message_handler.set_plugin_config(self.config)
|
||||
# 设置notice_handler的插件配置
|
||||
notice_handler.set_plugin_config(self.config)
|
||||
# 设置meta_event_handler的插件配置
|
||||
meta_event_handler.set_plugin_config(self.config)
|
||||
# 设置其他handler的插件配置(现在由component_registry在注册时自动设置)
|
||||
@@ -7,7 +7,7 @@ from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("napcat_adapter")
|
||||
|
||||
from .config.features_config import features_manager
|
||||
from src.plugin_system.apis import config_api
|
||||
from .recv_handler import RealMessageType
|
||||
|
||||
|
||||
@@ -43,6 +43,11 @@ class SimpleMessageBuffer:
|
||||
self.lock = asyncio.Lock()
|
||||
self.merge_callback = merge_callback
|
||||
self._shutdown = False
|
||||
self.plugin_config = None
|
||||
|
||||
def set_plugin_config(self, plugin_config: dict):
|
||||
"""设置插件配置"""
|
||||
self.plugin_config = plugin_config
|
||||
|
||||
def get_session_id(self, event_data: Dict[str, Any]) -> str:
|
||||
"""根据事件数据生成会话ID"""
|
||||
@@ -97,8 +102,7 @@ class SimpleMessageBuffer:
|
||||
return True
|
||||
|
||||
# 检查屏蔽前缀
|
||||
config = features_manager.get_config()
|
||||
block_prefixes = tuple(config.message_buffer_block_prefixes)
|
||||
block_prefixes = tuple(config_api.get_plugin_config(self.plugin_config, "features.message_buffer_block_prefixes", []))
|
||||
|
||||
text = text.strip()
|
||||
if text.startswith(block_prefixes):
|
||||
@@ -124,15 +128,15 @@ class SimpleMessageBuffer:
|
||||
if self._shutdown:
|
||||
return False
|
||||
|
||||
config = features_manager.get_config()
|
||||
if not config.enable_message_buffer:
|
||||
# 检查是否启用消息缓冲
|
||||
if not config_api.get_plugin_config(self.plugin_config, "features.enable_message_buffer", False):
|
||||
return False
|
||||
|
||||
# 检查是否启用对应类型的缓冲
|
||||
message_type = event_data.get("message_type", "")
|
||||
if message_type == "group" and not config.message_buffer_enable_group:
|
||||
if message_type == "group" and not config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enable_group", False):
|
||||
return False
|
||||
elif message_type == "private" and not config.message_buffer_enable_private:
|
||||
elif message_type == "private" and not config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enable_private", False):
|
||||
return False
|
||||
|
||||
# 提取文本
|
||||
@@ -154,8 +158,8 @@ class SimpleMessageBuffer:
|
||||
session = self.buffer_pool[session_id]
|
||||
|
||||
# 检查是否超过最大组件数量
|
||||
if len(session.messages) >= config.message_buffer_max_components:
|
||||
logger.info(f"会话 {session_id} 消息数量达到上限,强制合并")
|
||||
if len(session.messages) >= config_api.get_plugin_config(self.plugin_config, "features.message_buffer_max_components", 5):
|
||||
logger.debug(f"会话 {session_id} 消息数量达到上限,强制合并")
|
||||
asyncio.create_task(self._force_merge_session(session_id))
|
||||
self.buffer_pool[session_id] = BufferedSession(session_id=session_id, original_event=original_event)
|
||||
session = self.buffer_pool[session_id]
|
||||
@@ -187,8 +191,8 @@ class SimpleMessageBuffer:
|
||||
|
||||
async def _wait_and_start_merge(self, session_id: str):
|
||||
"""等待初始延迟后开始合并定时器"""
|
||||
config = features_manager.get_config()
|
||||
await asyncio.sleep(config.message_buffer_initial_delay)
|
||||
initial_delay = config_api.get_plugin_config(self.plugin_config, "features.message_buffer_initial_delay", 0.5)
|
||||
await asyncio.sleep(initial_delay)
|
||||
|
||||
async with self.lock:
|
||||
session = self.buffer_pool.get(session_id)
|
||||
@@ -206,8 +210,8 @@ class SimpleMessageBuffer:
|
||||
|
||||
async def _wait_and_merge(self, session_id: str):
|
||||
"""等待合并间隔后执行合并"""
|
||||
config = features_manager.get_config()
|
||||
await asyncio.sleep(config.message_buffer_interval)
|
||||
interval = config_api.get_plugin_config(self.plugin_config, "features.message_buffer_interval", 2.0)
|
||||
await asyncio.sleep(interval)
|
||||
await self._merge_session(session_id)
|
||||
|
||||
async def _force_merge_session(self, session_id: str):
|
||||
@@ -236,7 +240,7 @@ class SimpleMessageBuffer:
|
||||
merged_text = ",".join(text_parts) # 使用中文逗号连接
|
||||
message_count = len(session.messages)
|
||||
|
||||
logger.info(f"合并会话 {session_id} 的 {message_count} 条文本消息: {merged_text[:100]}...")
|
||||
logger.debug(f"合并会话 {session_id} 的 {message_count} 条文本消息: {merged_text[:100]}...")
|
||||
|
||||
# 调用回调函数
|
||||
if self.merge_callback:
|
||||
@@ -290,13 +294,13 @@ class SimpleMessageBuffer:
|
||||
expired_sessions.append(session_id)
|
||||
|
||||
for session_id in expired_sessions:
|
||||
logger.info(f"清理过期会话: {session_id}")
|
||||
logger.debug(f"清理过期会话: {session_id}")
|
||||
await self._force_merge_session(session_id)
|
||||
|
||||
async def shutdown(self):
|
||||
"""关闭消息缓冲器"""
|
||||
self._shutdown = True
|
||||
logger.info("正在关闭简化消息缓冲器...")
|
||||
logger.debug("正在关闭简化消息缓冲器...")
|
||||
|
||||
# 刷新所有缓冲区
|
||||
await self.flush_all()
|
||||
@@ -307,4 +311,4 @@ class SimpleMessageBuffer:
|
||||
await self._cancel_session_timers(session)
|
||||
self.buffer_pool.clear()
|
||||
|
||||
logger.info("简化消息缓冲器已关闭")
|
||||
logger.debug("简化消息缓冲器已关闭")
|
||||
@@ -3,24 +3,32 @@
|
||||
用于在 Ada 发送给 MMC 时进行消息切片,利用 WebSocket 协议的自动重组特性
|
||||
仅在 Ada -> MMC 方向进行切片,其他方向(MMC -> Ada,Ada <-> Napcat)不切片
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
import asyncio
|
||||
import time
|
||||
from typing import List, Dict, Any, Optional, Union
|
||||
from .config import global_config
|
||||
from src.plugin_system.apis import config_api
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("napcat_adapter")
|
||||
|
||||
|
||||
|
||||
class MessageChunker:
|
||||
"""消息切片器,用于处理大消息的分片发送"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.max_chunk_size = global_config.slicing.max_frame_size * 1024
|
||||
self.max_chunk_size = 64 * 1024 # 默认值,将在设置配置时更新
|
||||
self.plugin_config = None
|
||||
|
||||
def set_plugin_config(self, plugin_config: dict):
|
||||
"""设置插件配置"""
|
||||
self.plugin_config = plugin_config
|
||||
if plugin_config:
|
||||
max_frame_size = config_api.get_plugin_config(plugin_config, "slicing.max_frame_size", 64)
|
||||
self.max_chunk_size = max_frame_size * 1024
|
||||
|
||||
def should_chunk_message(self, message: Union[str, Dict[str, Any]]) -> bool:
|
||||
"""判断消息是否需要切片"""
|
||||
@@ -29,19 +37,21 @@ class MessageChunker:
|
||||
message_str = json.dumps(message, ensure_ascii=False)
|
||||
else:
|
||||
message_str = message
|
||||
return len(message_str.encode('utf-8')) > self.max_chunk_size
|
||||
return len(message_str.encode("utf-8")) > self.max_chunk_size
|
||||
except Exception as e:
|
||||
logger.error(f"检查消息大小时出错: {e}")
|
||||
return False
|
||||
|
||||
def chunk_message(self, message: Union[str, Dict[str, Any]], chunk_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
|
||||
def chunk_message(
|
||||
self, message: Union[str, Dict[str, Any]], chunk_id: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
将消息切片
|
||||
|
||||
|
||||
Args:
|
||||
message: 要切片的消息(字符串或字典)
|
||||
chunk_id: 切片组ID,如果不提供则自动生成
|
||||
|
||||
|
||||
Returns:
|
||||
切片后的消息字典列表
|
||||
"""
|
||||
@@ -51,30 +61,30 @@ class MessageChunker:
|
||||
message_str = json.dumps(message, ensure_ascii=False)
|
||||
else:
|
||||
message_str = message
|
||||
|
||||
|
||||
if not self.should_chunk_message(message_str):
|
||||
# 不需要切片的情况,如果输入是字典则返回字典,如果是字符串则包装成非切片标记的字典
|
||||
if isinstance(message, dict):
|
||||
return [message]
|
||||
else:
|
||||
return [{"_original_message": message_str}]
|
||||
|
||||
|
||||
if chunk_id is None:
|
||||
chunk_id = str(uuid.uuid4())
|
||||
|
||||
message_bytes = message_str.encode('utf-8')
|
||||
|
||||
message_bytes = message_str.encode("utf-8")
|
||||
total_size = len(message_bytes)
|
||||
|
||||
|
||||
# 计算需要多少个切片
|
||||
num_chunks = (total_size + self.max_chunk_size - 1) // self.max_chunk_size
|
||||
|
||||
|
||||
chunks = []
|
||||
for i in range(num_chunks):
|
||||
start_pos = i * self.max_chunk_size
|
||||
end_pos = min(start_pos + self.max_chunk_size, total_size)
|
||||
|
||||
|
||||
chunk_data = message_bytes[start_pos:end_pos]
|
||||
|
||||
|
||||
# 构建切片消息
|
||||
chunk_message = {
|
||||
"__mmc_chunk_info__": {
|
||||
@@ -83,17 +93,17 @@ class MessageChunker:
|
||||
"total_chunks": num_chunks,
|
||||
"chunk_size": len(chunk_data),
|
||||
"total_size": total_size,
|
||||
"timestamp": time.time()
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
"__mmc_chunk_data__": chunk_data.decode('utf-8', errors='ignore'),
|
||||
"__mmc_is_chunked__": True
|
||||
"__mmc_chunk_data__": chunk_data.decode("utf-8", errors="ignore"),
|
||||
"__mmc_is_chunked__": True,
|
||||
}
|
||||
|
||||
|
||||
chunks.append(chunk_message)
|
||||
|
||||
|
||||
logger.debug(f"消息切片完成: {total_size} bytes -> {num_chunks} chunks (ID: {chunk_id})")
|
||||
return chunks
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"消息切片时出错: {e}")
|
||||
# 出错时返回原消息
|
||||
@@ -101,7 +111,7 @@ class MessageChunker:
|
||||
return [message]
|
||||
else:
|
||||
return [{"_original_message": message}]
|
||||
|
||||
|
||||
def is_chunk_message(self, message: Union[str, Dict[str, Any]]) -> bool:
|
||||
"""判断是否是切片消息"""
|
||||
try:
|
||||
@@ -109,12 +119,12 @@ class MessageChunker:
|
||||
data = json.loads(message)
|
||||
else:
|
||||
data = message
|
||||
|
||||
|
||||
return (
|
||||
isinstance(data, dict) and
|
||||
"__mmc_chunk_info__" in data and
|
||||
"__mmc_chunk_data__" in data and
|
||||
"__mmc_is_chunked__" in data
|
||||
isinstance(data, dict)
|
||||
and "__mmc_chunk_info__" in data
|
||||
and "__mmc_chunk_data__" in data
|
||||
and "__mmc_is_chunked__" in data
|
||||
)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return False
|
||||
@@ -122,17 +132,17 @@ class MessageChunker:
|
||||
|
||||
class MessageReassembler:
|
||||
"""消息重组器,用于重组接收到的切片消息"""
|
||||
|
||||
|
||||
def __init__(self, timeout: int = 30):
|
||||
self.timeout = timeout
|
||||
self.chunk_buffers: Dict[str, Dict[str, Any]] = {}
|
||||
self._cleanup_task = None
|
||||
|
||||
|
||||
async def start_cleanup_task(self):
|
||||
"""启动清理任务"""
|
||||
if self._cleanup_task is None:
|
||||
self._cleanup_task = asyncio.create_task(self._cleanup_expired_chunks())
|
||||
|
||||
|
||||
async def stop_cleanup_task(self):
|
||||
"""停止清理任务"""
|
||||
if self._cleanup_task:
|
||||
@@ -142,35 +152,35 @@ class MessageReassembler:
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._cleanup_task = None
|
||||
|
||||
|
||||
async def _cleanup_expired_chunks(self):
|
||||
"""清理过期的切片缓冲区"""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(10) # 每10秒检查一次
|
||||
current_time = time.time()
|
||||
|
||||
|
||||
expired_chunks = []
|
||||
for chunk_id, buffer_info in self.chunk_buffers.items():
|
||||
if current_time - buffer_info['timestamp'] > self.timeout:
|
||||
if current_time - buffer_info["timestamp"] > self.timeout:
|
||||
expired_chunks.append(chunk_id)
|
||||
|
||||
|
||||
for chunk_id in expired_chunks:
|
||||
logger.warning(f"清理过期的切片缓冲区: {chunk_id}")
|
||||
del self.chunk_buffers[chunk_id]
|
||||
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"清理过期切片时出错: {e}")
|
||||
|
||||
|
||||
async def add_chunk(self, message: Union[str, Dict[str, Any]]) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
添加切片,如果切片完整则返回重组后的消息
|
||||
|
||||
|
||||
Args:
|
||||
message: 切片消息(字符串或字典)
|
||||
|
||||
|
||||
Returns:
|
||||
如果切片完整则返回重组后的原始消息字典,否则返回None
|
||||
"""
|
||||
@@ -180,7 +190,7 @@ class MessageReassembler:
|
||||
chunk_data = json.loads(message)
|
||||
else:
|
||||
chunk_data = message
|
||||
|
||||
|
||||
# 检查是否是切片消息
|
||||
if not chunker.is_chunk_message(chunk_data):
|
||||
# 不是切片消息,直接返回
|
||||
@@ -192,38 +202,38 @@ class MessageReassembler:
|
||||
return {"text_message": chunk_data["_original_message"]}
|
||||
else:
|
||||
return chunk_data
|
||||
|
||||
|
||||
chunk_info = chunk_data["__mmc_chunk_info__"]
|
||||
chunk_content = chunk_data["__mmc_chunk_data__"]
|
||||
|
||||
|
||||
chunk_id = chunk_info["chunk_id"]
|
||||
chunk_index = chunk_info["chunk_index"]
|
||||
total_chunks = chunk_info["total_chunks"]
|
||||
chunk_timestamp = chunk_info.get("timestamp", time.time())
|
||||
|
||||
|
||||
# 初始化缓冲区
|
||||
if chunk_id not in self.chunk_buffers:
|
||||
self.chunk_buffers[chunk_id] = {
|
||||
"chunks": {},
|
||||
"total_chunks": total_chunks,
|
||||
"received_chunks": 0,
|
||||
"timestamp": chunk_timestamp
|
||||
"timestamp": chunk_timestamp,
|
||||
}
|
||||
|
||||
|
||||
buffer = self.chunk_buffers[chunk_id]
|
||||
|
||||
|
||||
# 检查切片是否已经接收过
|
||||
if chunk_index in buffer["chunks"]:
|
||||
logger.warning(f"重复接收切片: {chunk_id}#{chunk_index}")
|
||||
return None
|
||||
|
||||
|
||||
# 添加切片
|
||||
buffer["chunks"][chunk_index] = chunk_content
|
||||
buffer["received_chunks"] += 1
|
||||
buffer["timestamp"] = time.time() # 更新时间戳
|
||||
|
||||
|
||||
logger.debug(f"接收切片: {chunk_id}#{chunk_index} ({buffer['received_chunks']}/{total_chunks})")
|
||||
|
||||
|
||||
# 检查是否接收完整
|
||||
if buffer["received_chunks"] == total_chunks:
|
||||
# 重组消息
|
||||
@@ -233,25 +243,25 @@ class MessageReassembler:
|
||||
logger.error(f"切片 {chunk_id}#{i} 缺失,无法重组")
|
||||
return None
|
||||
reassembled_message += buffer["chunks"][i]
|
||||
|
||||
|
||||
# 清理缓冲区
|
||||
del self.chunk_buffers[chunk_id]
|
||||
|
||||
|
||||
logger.debug(f"消息重组完成: {chunk_id} ({len(reassembled_message)} chars)")
|
||||
|
||||
|
||||
# 尝试反序列化重组后的消息
|
||||
try:
|
||||
return json.loads(reassembled_message)
|
||||
except json.JSONDecodeError:
|
||||
# 如果不能反序列化为JSON,则作为文本消息返回
|
||||
return {"text_message": reassembled_message}
|
||||
|
||||
|
||||
return None
|
||||
|
||||
|
||||
except (json.JSONDecodeError, KeyError, TypeError) as e:
|
||||
logger.error(f"处理切片消息时出错: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_pending_chunks_info(self) -> Dict[str, Any]:
|
||||
"""获取待处理切片信息"""
|
||||
info = {}
|
||||
@@ -260,11 +270,11 @@ class MessageReassembler:
|
||||
"received": buffer["received_chunks"],
|
||||
"total": buffer["total_chunks"],
|
||||
"progress": f"{buffer['received_chunks']}/{buffer['total_chunks']}",
|
||||
"age_seconds": time.time() - buffer["timestamp"]
|
||||
"age_seconds": time.time() - buffer["timestamp"],
|
||||
}
|
||||
return info
|
||||
|
||||
|
||||
# 全局实例
|
||||
chunker = MessageChunker()
|
||||
reassembler = MessageReassembler()
|
||||
reassembler = MessageReassembler()
|
||||
@@ -0,0 +1,44 @@
|
||||
from maim_message import Router, RouteConfig, TargetConfig
|
||||
from src.common.logger import get_logger
|
||||
from .send_handler import send_handler
|
||||
from src.plugin_system.apis import config_api
|
||||
|
||||
logger = get_logger("napcat_adapter")
|
||||
|
||||
router = None
|
||||
|
||||
|
||||
def create_router(plugin_config: dict):
|
||||
"""创建路由器实例"""
|
||||
global router
|
||||
platform_name = config_api.get_plugin_config(plugin_config, "maibot_server.platform_name", "napcat")
|
||||
host = config_api.get_plugin_config(plugin_config, "maibot_server.host", "localhost")
|
||||
port = config_api.get_plugin_config(plugin_config, "maibot_server.port", 8000)
|
||||
|
||||
route_config = RouteConfig(
|
||||
route_config={
|
||||
platform_name: TargetConfig(
|
||||
url=f"ws://{host}:{port}/ws",
|
||||
token=None,
|
||||
)
|
||||
}
|
||||
)
|
||||
router = Router(route_config)
|
||||
return router
|
||||
|
||||
|
||||
async def mmc_start_com(plugin_config: dict = None):
|
||||
"""启动MaiBot连接"""
|
||||
logger.info("正在连接MaiBot")
|
||||
if plugin_config:
|
||||
create_router(plugin_config)
|
||||
|
||||
if router:
|
||||
router.register_class_handler(send_handler.handle_message)
|
||||
await router.run()
|
||||
|
||||
|
||||
async def mmc_stop_com():
|
||||
"""停止MaiBot连接"""
|
||||
if router:
|
||||
await router.stop()
|
||||
@@ -5,8 +5,7 @@ from ...CONSTS import PLUGIN_NAME
|
||||
|
||||
logger = get_logger("napcat_adapter")
|
||||
|
||||
from ..config import global_config
|
||||
from ..config.features_config import features_manager
|
||||
from src.plugin_system.apis import config_api
|
||||
from ..message_buffer import SimpleMessageBuffer
|
||||
from ..utils import (
|
||||
get_group_info,
|
||||
@@ -48,9 +47,17 @@ class MessageHandler:
|
||||
def __init__(self):
|
||||
self.server_connection: Server.ServerConnection = None
|
||||
self.bot_id_list: Dict[int, bool] = {}
|
||||
self.plugin_config = None
|
||||
# 初始化简化消息缓冲器,传入回调函数
|
||||
self.message_buffer = SimpleMessageBuffer(merge_callback=self._send_buffered_message)
|
||||
|
||||
def set_plugin_config(self, plugin_config: dict):
|
||||
"""设置插件配置"""
|
||||
self.plugin_config = plugin_config
|
||||
# 将配置传递给消息缓冲器
|
||||
if self.message_buffer:
|
||||
self.message_buffer.set_plugin_config(plugin_config)
|
||||
|
||||
async def shutdown(self):
|
||||
"""关闭消息处理器,清理资源"""
|
||||
if self.message_buffer:
|
||||
@@ -90,21 +97,41 @@ class MessageHandler:
|
||||
|
||||
# 使用新的权限管理器检查权限
|
||||
if group_id:
|
||||
if not features_manager.is_group_allowed(group_id):
|
||||
logger.warning("群聊不在聊天权限范围内,消息被丢弃")
|
||||
return False
|
||||
# 检查群聊黑白名单
|
||||
group_list_type = config_api.get_plugin_config(self.plugin_config, "features.group_list_type", "blacklist")
|
||||
group_list = config_api.get_plugin_config(self.plugin_config, "features.group_list", [])
|
||||
|
||||
if group_list_type == "whitelist":
|
||||
if group_id not in group_list:
|
||||
logger.warning("群聊不在白名单中,消息被丢弃")
|
||||
return False
|
||||
else: # blacklist
|
||||
if group_id in group_list:
|
||||
logger.warning("群聊在黑名单中,消息被丢弃")
|
||||
return False
|
||||
else:
|
||||
if not features_manager.is_private_allowed(user_id):
|
||||
logger.warning("私聊不在聊天权限范围内,消息被丢弃")
|
||||
return False
|
||||
# 检查私聊黑白名单
|
||||
private_list_type = config_api.get_plugin_config(self.plugin_config, "features.private_list_type", "blacklist")
|
||||
private_list = config_api.get_plugin_config(self.plugin_config, "features.private_list", [])
|
||||
|
||||
if private_list_type == "whitelist":
|
||||
if user_id not in private_list:
|
||||
logger.warning("私聊不在白名单中,消息被丢弃")
|
||||
return False
|
||||
else: # blacklist
|
||||
if user_id in private_list:
|
||||
logger.warning("私聊在黑名单中,消息被丢弃")
|
||||
return False
|
||||
|
||||
# 检查全局禁止名单
|
||||
if not ignore_global_list and features_manager.is_user_banned(user_id):
|
||||
ban_user_id = config_api.get_plugin_config(self.plugin_config, "features.ban_user_id", [])
|
||||
if not ignore_global_list and user_id in ban_user_id:
|
||||
logger.warning("用户在全局黑名单中,消息被丢弃")
|
||||
return False
|
||||
|
||||
# 检查QQ官方机器人
|
||||
if features_manager.is_qq_bot_banned() and group_id and not ignore_bot:
|
||||
ban_qq_bot = config_api.get_plugin_config(self.plugin_config, "features.ban_qq_bot", False)
|
||||
if ban_qq_bot and group_id and not ignore_bot:
|
||||
logger.debug("开始判断是否为机器人")
|
||||
member_info = await get_member_info(self.get_server_connection(), group_id, user_id)
|
||||
if member_info:
|
||||
@@ -129,6 +156,21 @@ class MessageHandler:
|
||||
Parameters:
|
||||
raw_message: dict: 原始消息
|
||||
"""
|
||||
|
||||
# 添加原始消息调试日志,特别关注message字段
|
||||
logger.debug(f"收到原始消息: message_type={raw_message.get('message_type')}, message_id={raw_message.get('message_id')}")
|
||||
logger.debug(f"原始消息内容: {raw_message.get('message', [])}")
|
||||
|
||||
# 检查是否包含@或video消息段
|
||||
message_segments = raw_message.get('message', [])
|
||||
if message_segments:
|
||||
for i, seg in enumerate(message_segments):
|
||||
seg_type = seg.get('type')
|
||||
if seg_type in ['at', 'video']:
|
||||
logger.info(f"检测到 {seg_type.upper()} 消息段 [{i}]: {seg}")
|
||||
elif seg_type not in ['text', 'face', 'image']:
|
||||
logger.warning(f"检测到特殊消息段 [{i}]: type={seg_type}, data={seg.get('data', {})}")
|
||||
|
||||
message_type: str = raw_message.get("message_type")
|
||||
message_id: int = raw_message.get("message_id")
|
||||
# message_time: int = raw_message.get("time")
|
||||
@@ -149,7 +191,7 @@ class MessageHandler:
|
||||
|
||||
# 发送者用户信息
|
||||
user_info: UserInfo = UserInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name"),
|
||||
user_id=sender_info.get("user_id"),
|
||||
user_nickname=sender_info.get("nickname"),
|
||||
user_cardname=sender_info.get("card"),
|
||||
@@ -175,7 +217,7 @@ class MessageHandler:
|
||||
nickname = fetched_member_info.get("nickname") if fetched_member_info else None
|
||||
# 发送者用户信息
|
||||
user_info: UserInfo = UserInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name"),
|
||||
user_id=sender_info.get("user_id"),
|
||||
user_nickname=nickname,
|
||||
user_cardname=None,
|
||||
@@ -192,7 +234,7 @@ class MessageHandler:
|
||||
group_name = fetched_group_info.get("group_name")
|
||||
|
||||
group_info: GroupInfo = GroupInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name"),
|
||||
group_id=raw_message.get("group_id"),
|
||||
group_name=group_name,
|
||||
)
|
||||
@@ -210,7 +252,7 @@ class MessageHandler:
|
||||
|
||||
# 发送者用户信息
|
||||
user_info: UserInfo = UserInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name"),
|
||||
user_id=sender_info.get("user_id"),
|
||||
user_nickname=sender_info.get("nickname"),
|
||||
user_cardname=sender_info.get("card"),
|
||||
@@ -223,7 +265,7 @@ class MessageHandler:
|
||||
group_name = fetched_group_info.get("group_name")
|
||||
|
||||
group_info: GroupInfo = GroupInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name"),
|
||||
group_id=raw_message.get("group_id"),
|
||||
group_name=group_name,
|
||||
)
|
||||
@@ -233,12 +275,12 @@ class MessageHandler:
|
||||
return None
|
||||
|
||||
additional_config: dict = {}
|
||||
if global_config.voice.use_tts:
|
||||
if config_api.get_plugin_config(self.plugin_config, "voice.use_tts"):
|
||||
additional_config["allow_tts"] = True
|
||||
|
||||
# 消息信息
|
||||
message_info: BaseMessageInfo = BaseMessageInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name"),
|
||||
message_id=message_id,
|
||||
time=message_time,
|
||||
user_info=user_info,
|
||||
@@ -260,19 +302,19 @@ class MessageHandler:
|
||||
return None
|
||||
|
||||
# 检查是否需要使用消息缓冲
|
||||
if features_manager.is_message_buffer_enabled():
|
||||
enable_message_buffer = config_api.get_plugin_config(self.plugin_config, "features.enable_message_buffer", True)
|
||||
if enable_message_buffer:
|
||||
# 检查消息类型是否启用缓冲
|
||||
message_type = raw_message.get("message_type")
|
||||
should_use_buffer = False
|
||||
|
||||
if message_type == "group" and features_manager.is_message_buffer_group_enabled():
|
||||
if message_type == "group" and config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enable_group", True):
|
||||
should_use_buffer = True
|
||||
elif message_type == "private" and features_manager.is_message_buffer_private_enabled():
|
||||
elif message_type == "private" and config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enable_private", True):
|
||||
should_use_buffer = True
|
||||
|
||||
if should_use_buffer:
|
||||
logger.debug(f"尝试缓冲消息,消息类型: {message_type}, 用户: {user_info.user_id}")
|
||||
logger.debug(f"原始消息段: {raw_message.get('message', [])}")
|
||||
|
||||
# 尝试添加到缓冲器
|
||||
buffered = await self.message_buffer.add_text_message(
|
||||
@@ -286,10 +328,10 @@ class MessageHandler:
|
||||
)
|
||||
|
||||
if buffered:
|
||||
logger.info(f"✅ 文本消息已成功缓冲: {user_info.user_id}")
|
||||
logger.debug(f"✅ 文本消息已成功缓冲: {user_info.user_id}")
|
||||
return None # 缓冲成功,不立即发送
|
||||
# 如果缓冲失败(消息包含非文本元素),走正常处理流程
|
||||
logger.info(f"❌ 消息缓冲失败,包含非文本元素,走正常处理流程: {user_info.user_id}")
|
||||
logger.debug(f"❌ 消息缓冲失败,包含非文本元素,走正常处理流程: {user_info.user_id}")
|
||||
# 缓冲失败时继续执行后面的正常处理流程,不要直接返回
|
||||
|
||||
logger.debug(f"准备发送消息到MaiBot,消息段数量: {len(seg_message)}")
|
||||
@@ -307,7 +349,7 @@ class MessageHandler:
|
||||
raw_message=raw_message.get("raw_message"),
|
||||
)
|
||||
|
||||
logger.info("发送到Maibot处理信息")
|
||||
logger.debug("发送到Maibot处理信息")
|
||||
await message_send_instance.message_send(message_base)
|
||||
|
||||
async def handle_real_message(self, raw_message: dict, in_reply: bool = False) -> List[Seg] | None:
|
||||
@@ -326,6 +368,18 @@ class MessageHandler:
|
||||
for sub_message in real_message:
|
||||
sub_message: dict
|
||||
sub_message_type = sub_message.get("type")
|
||||
|
||||
# 添加详细的消息类型调试信息
|
||||
logger.debug(f"处理消息段: type={sub_message_type}, data={sub_message.get('data', {})}")
|
||||
|
||||
# 特别关注 at 和 video 消息的识别
|
||||
if sub_message_type == "at":
|
||||
logger.debug(f"检测到@消息: {sub_message}")
|
||||
elif sub_message_type == "video":
|
||||
logger.debug(f"检测到VIDEO消息: {sub_message}")
|
||||
elif sub_message_type not in ["text", "face", "image", "record"]:
|
||||
logger.warning(f"检测到特殊消息类型: {sub_message_type}, 完整消息: {sub_message}")
|
||||
|
||||
match sub_message_type:
|
||||
case RealMessageType.text:
|
||||
ret_seg = await self.handle_text_message(sub_message)
|
||||
@@ -379,6 +433,7 @@ class MessageHandler:
|
||||
else:
|
||||
logger.warning("record处理失败或不支持")
|
||||
case RealMessageType.video:
|
||||
logger.debug(f"开始处理VIDEO消息段: {sub_message}")
|
||||
ret_seg = await self.handle_video_message(sub_message)
|
||||
if ret_seg:
|
||||
await event_manager.trigger_event(
|
||||
@@ -386,8 +441,9 @@ class MessageHandler:
|
||||
)
|
||||
seg_message.append(ret_seg)
|
||||
else:
|
||||
logger.warning("video处理失败")
|
||||
logger.warning(f"video处理失败,原始消息: {sub_message}")
|
||||
case RealMessageType.at:
|
||||
logger.debug(f"开始处理AT消息段: {sub_message}")
|
||||
ret_seg = await self.handle_at_message(
|
||||
sub_message,
|
||||
raw_message.get("self_id"),
|
||||
@@ -399,7 +455,7 @@ class MessageHandler:
|
||||
)
|
||||
seg_message.append(ret_seg)
|
||||
else:
|
||||
logger.warning("at处理失败")
|
||||
logger.warning(f"at处理失败,原始消息: {sub_message}")
|
||||
case RealMessageType.rps:
|
||||
ret_seg = await self.handle_rps_message(sub_message)
|
||||
if ret_seg:
|
||||
@@ -502,9 +558,7 @@ class MessageHandler:
|
||||
message_data: dict = raw_message.get("data")
|
||||
image_sub_type = message_data.get("sub_type")
|
||||
try:
|
||||
logger.debug(f"开始下载图片: {message_data.get('url')}")
|
||||
image_base64 = await get_image_base64(message_data.get("url"))
|
||||
logger.debug(f"图片下载成功,大小: {len(image_base64)} 字符")
|
||||
except Exception as e:
|
||||
logger.error(f"图片消息处理失败: {str(e)}")
|
||||
return None
|
||||
@@ -595,8 +649,8 @@ class MessageHandler:
|
||||
video_url = message_data.get("url")
|
||||
file_path = message_data.get("filePath") or message_data.get("file_path")
|
||||
|
||||
logger.info(f"视频URL: {video_url}")
|
||||
logger.info(f"视频文件路径: {file_path}")
|
||||
logger.debug(f"视频URL: {video_url}")
|
||||
logger.debug(f"视频文件路径: {file_path}")
|
||||
|
||||
# 优先使用本地文件路径,其次使用URL
|
||||
video_source = file_path if file_path else video_url
|
||||
@@ -609,14 +663,14 @@ class MessageHandler:
|
||||
try:
|
||||
# 检查是否为本地文件路径
|
||||
if file_path and Path(file_path).exists():
|
||||
logger.info(f"使用本地视频文件: {file_path}")
|
||||
logger.debug(f"使用本地视频文件: {file_path}")
|
||||
# 直接读取本地文件
|
||||
with open(file_path, "rb") as f:
|
||||
video_data = f.read()
|
||||
|
||||
# 将视频数据编码为base64用于传输
|
||||
video_base64 = base64.b64encode(video_data).decode("utf-8")
|
||||
logger.info(f"视频文件大小: {len(video_data) / (1024 * 1024):.2f} MB")
|
||||
logger.debug(f"视频文件大小: {len(video_data) / (1024 * 1024):.2f} MB")
|
||||
|
||||
# 返回包含详细信息的字典格式
|
||||
return Seg(
|
||||
@@ -629,7 +683,7 @@ class MessageHandler:
|
||||
)
|
||||
|
||||
elif video_url:
|
||||
logger.info(f"使用视频URL下载: {video_url}")
|
||||
logger.debug(f"使用视频URL下载: {video_url}")
|
||||
# 使用video_handler下载视频
|
||||
video_downloader = get_video_downloader()
|
||||
download_result = await video_downloader.download_video(video_url)
|
||||
@@ -641,7 +695,7 @@ class MessageHandler:
|
||||
|
||||
# 将视频数据编码为base64用于传输
|
||||
video_base64 = base64.b64encode(download_result["data"]).decode("utf-8")
|
||||
logger.info(f"视频下载成功,大小: {len(download_result['data']) / (1024 * 1024):.2f} MB")
|
||||
logger.debug(f"视频下载成功,大小: {len(download_result['data']) / (1024 * 1024):.2f} MB")
|
||||
|
||||
# 返回包含详细信息的字典格式
|
||||
return Seg(
|
||||
@@ -710,15 +764,15 @@ class MessageHandler:
|
||||
processed_message: Seg
|
||||
if image_count < 5 and image_count > 0:
|
||||
# 处理图片数量小于5的情况,此时解析图片为base64
|
||||
logger.info("图片数量小于5,开始解析图片为base64")
|
||||
logger.debug("图片数量小于5,开始解析图片为base64")
|
||||
processed_message = await self._recursive_parse_image_seg(handled_message, True)
|
||||
elif image_count > 0:
|
||||
logger.info("图片数量大于等于5,开始解析图片为占位符")
|
||||
logger.debug("图片数量大于等于5,开始解析图片为占位符")
|
||||
# 处理图片数量大于等于5的情况,此时解析图片为占位符
|
||||
processed_message = await self._recursive_parse_image_seg(handled_message, False)
|
||||
else:
|
||||
# 处理没有图片的情况,此时直接返回
|
||||
logger.info("没有图片,直接返回")
|
||||
logger.debug("没有图片,直接返回")
|
||||
processed_message = handled_message
|
||||
|
||||
# 添加转发消息提示
|
||||
@@ -743,31 +797,31 @@ class MessageHandler:
|
||||
"""
|
||||
message_data: dict = raw_message.get("data", {})
|
||||
json_data = message_data.get("data", "")
|
||||
|
||||
|
||||
# 检查JSON消息格式
|
||||
if not message_data or "data" not in message_data:
|
||||
logger.warning("JSON消息格式不正确")
|
||||
return Seg(type="json", data=json.dumps(message_data))
|
||||
|
||||
|
||||
try:
|
||||
nested_data = json.loads(json_data)
|
||||
|
||||
|
||||
# 检查是否是QQ小程序分享消息
|
||||
if "app" in nested_data and "com.tencent.miniapp" in str(nested_data.get("app", "")):
|
||||
logger.debug("检测到QQ小程序分享消息,开始提取信息")
|
||||
|
||||
|
||||
# 提取目标字段
|
||||
extracted_info = {}
|
||||
|
||||
|
||||
# 提取 meta.detail_1 中的信息
|
||||
meta = nested_data.get("meta", {})
|
||||
detail_1 = meta.get("detail_1", {})
|
||||
|
||||
|
||||
if detail_1:
|
||||
extracted_info["title"] = detail_1.get("title", "")
|
||||
extracted_info["desc"] = detail_1.get("desc", "")
|
||||
qqdocurl = detail_1.get("qqdocurl", "")
|
||||
|
||||
|
||||
# 从qqdocurl中提取b23.tv短链接
|
||||
if qqdocurl and "b23.tv" in qqdocurl:
|
||||
# 查找b23.tv链接的起始位置
|
||||
@@ -785,26 +839,29 @@ class MessageHandler:
|
||||
extracted_info["short_url"] = qqdocurl
|
||||
else:
|
||||
extracted_info["short_url"] = qqdocurl
|
||||
|
||||
|
||||
# 如果成功提取到关键信息,返回格式化的文本
|
||||
if extracted_info.get("title") or extracted_info.get("desc") or extracted_info.get("short_url"):
|
||||
content_parts = []
|
||||
|
||||
|
||||
if extracted_info.get("title"):
|
||||
content_parts.append(f"来源: {extracted_info['title']}")
|
||||
|
||||
|
||||
if extracted_info.get("desc"):
|
||||
content_parts.append(f"标题: {extracted_info['desc']}")
|
||||
|
||||
|
||||
if extracted_info.get("short_url"):
|
||||
content_parts.append(f"链接: {extracted_info['short_url']}")
|
||||
|
||||
|
||||
formatted_content = "\n".join(content_parts)
|
||||
return Seg(type="text", data=f"这是一条小程序分享消息,可以根据来源,考虑使用对应解析工具\n{formatted_content}")
|
||||
|
||||
return Seg(
|
||||
type="text",
|
||||
data=f"这是一条小程序分享消息,可以根据来源,考虑使用对应解析工具\n{formatted_content}",
|
||||
)
|
||||
|
||||
# 如果没有提取到关键信息,返回None
|
||||
return None
|
||||
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"解析JSON消息失败: {e}")
|
||||
return None
|
||||
@@ -849,7 +906,7 @@ class MessageHandler:
|
||||
return Seg(type="text", data="[表情包]")
|
||||
return Seg(type="emoji", data=encoded_image)
|
||||
else:
|
||||
logger.info(f"不处理类型: {seg_data.type}")
|
||||
logger.debug(f"不处理类型: {seg_data.type}")
|
||||
return seg_data
|
||||
else:
|
||||
if seg_data.type == "seglist":
|
||||
@@ -863,7 +920,7 @@ class MessageHandler:
|
||||
elif seg_data.type == "emoji":
|
||||
return Seg(type="text", data="[动画表情]")
|
||||
else:
|
||||
logger.info(f"不处理类型: {seg_data.type}")
|
||||
logger.debug(f"不处理类型: {seg_data.type}")
|
||||
return seg_data
|
||||
|
||||
async def _handle_forward_message(self, message_list: list, layer: int) -> Tuple[Seg, int] | Tuple[None, int]:
|
||||
@@ -1038,7 +1095,7 @@ class MessageHandler:
|
||||
raw_message=raw_message.get("raw_message", ""),
|
||||
)
|
||||
|
||||
logger.info(f"发送缓冲合并消息到Maibot处理: {session_id}")
|
||||
logger.debug(f"发送缓冲合并消息到Maibot处理: {session_id}")
|
||||
await message_send_instance.message_send(message_base)
|
||||
|
||||
except Exception as e:
|
||||
@@ -0,0 +1,117 @@
|
||||
import asyncio
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from ..message_chunker import chunker
|
||||
from src.plugin_system.apis import config_api
|
||||
|
||||
logger = get_logger("napcat_adapter")
|
||||
from maim_message import MessageBase, Router
|
||||
|
||||
|
||||
class MessageSending:
|
||||
"""
|
||||
负责把消息发送到麦麦
|
||||
"""
|
||||
|
||||
maibot_router: Router = None
|
||||
plugin_config = None
|
||||
_connection_retries = 0
|
||||
_max_retries = 3
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def set_plugin_config(self, plugin_config: dict):
|
||||
"""设置插件配置"""
|
||||
self.plugin_config = plugin_config
|
||||
|
||||
async def _attempt_reconnect(self):
|
||||
"""尝试重新连接MaiBot router"""
|
||||
if self._connection_retries < self._max_retries:
|
||||
self._connection_retries += 1
|
||||
logger.warning(f"尝试重新连接MaiBot router (第{self._connection_retries}次)")
|
||||
try:
|
||||
# 重新导入router
|
||||
from ..mmc_com_layer import router
|
||||
self.maibot_router = router
|
||||
if self.maibot_router is not None:
|
||||
logger.info("MaiBot router重连成功")
|
||||
self._connection_retries = 0 # 重置重试计数
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"重连失败: {e}")
|
||||
else:
|
||||
logger.error(f"已达到最大重连次数({self._max_retries}),停止重试")
|
||||
return False
|
||||
|
||||
async def message_send(self, message_base: MessageBase) -> bool:
|
||||
"""
|
||||
发送消息(Ada -> MMC 方向,需要实现切片)
|
||||
Parameters:
|
||||
message_base: MessageBase: 消息基类,包含发送目标和消息内容等信息
|
||||
"""
|
||||
try:
|
||||
# 检查maibot_router是否已初始化
|
||||
if self.maibot_router is None:
|
||||
logger.warning("MaiBot router未初始化,尝试重新连接")
|
||||
if not await self._attempt_reconnect():
|
||||
logger.error("MaiBot router重连失败,无法发送消息")
|
||||
logger.error("请检查与MaiBot之间的连接")
|
||||
return False
|
||||
# 检查是否需要切片发送
|
||||
message_dict = message_base.to_dict()
|
||||
|
||||
if chunker.should_chunk_message(message_dict):
|
||||
logger.info("消息过大,进行切片发送到 MaiBot")
|
||||
|
||||
# 切片消息
|
||||
chunks = chunker.chunk_message(message_dict)
|
||||
|
||||
# 逐个发送切片
|
||||
for i, chunk in enumerate(chunks):
|
||||
logger.debug(f"发送切片 {i + 1}/{len(chunks)} 到 MaiBot")
|
||||
|
||||
# 获取对应的客户端并发送切片
|
||||
platform = message_base.message_info.platform
|
||||
|
||||
# 再次检查router状态(防止运行时被重置)
|
||||
if self.maibot_router is None or not hasattr(self.maibot_router, 'clients'):
|
||||
logger.warning("MaiBot router连接已断开,尝试重新连接")
|
||||
if not await self._attempt_reconnect():
|
||||
logger.error("MaiBot router重连失败,切片发送中止")
|
||||
return False
|
||||
|
||||
if platform not in self.maibot_router.clients:
|
||||
logger.error(f"平台 {platform} 未连接")
|
||||
return False
|
||||
|
||||
client = self.maibot_router.clients[platform]
|
||||
send_status = await client.send_message(chunk)
|
||||
|
||||
if not send_status:
|
||||
logger.error(f"发送切片 {i + 1}/{len(chunks)} 失败")
|
||||
return False
|
||||
|
||||
# 使用配置中的延迟时间
|
||||
if i < len(chunks) - 1 and self.plugin_config:
|
||||
delay_ms = config_api.get_plugin_config(self.plugin_config, "slicing.delay_ms", 10)
|
||||
delay_seconds = delay_ms / 1000.0
|
||||
logger.debug(f"切片发送延迟: {delay_ms}毫秒")
|
||||
await asyncio.sleep(delay_seconds)
|
||||
|
||||
logger.debug("所有切片发送完成")
|
||||
return True
|
||||
else:
|
||||
# 直接发送小消息
|
||||
send_status = await self.maibot_router.send_message(message_base)
|
||||
if not send_status:
|
||||
raise RuntimeError("可能是路由未正确配置或连接异常")
|
||||
return send_status
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {str(e)}")
|
||||
logger.error("请检查与MaiBot之间的连接")
|
||||
return False
|
||||
|
||||
|
||||
message_send_instance = MessageSending()
|
||||
@@ -1,7 +1,7 @@
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("napcat_adapter")
|
||||
from ..config import global_config
|
||||
from src.plugin_system.apis import config_api
|
||||
import time
|
||||
import asyncio
|
||||
|
||||
@@ -14,8 +14,15 @@ class MetaEventHandler:
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.interval = global_config.napcat_server.heartbeat_interval
|
||||
self.interval = 5.0 # 默认值,稍后通过set_plugin_config设置
|
||||
self._interval_checking = False
|
||||
self.plugin_config = None
|
||||
|
||||
def set_plugin_config(self, plugin_config: dict):
|
||||
"""设置插件配置"""
|
||||
self.plugin_config = plugin_config
|
||||
# 更新interval值
|
||||
self.interval = config_api.get_plugin_config(self.plugin_config, "napcat_server.heartbeat_interval", 5000) / 1000
|
||||
|
||||
async def handle_meta_event(self, message: dict) -> None:
|
||||
event_type = message.get("meta_event_type")
|
||||
@@ -8,8 +8,7 @@ from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("napcat_adapter")
|
||||
|
||||
from ..config import global_config
|
||||
from ..config.features_config import features_manager
|
||||
from src.plugin_system.apis import config_api
|
||||
from ..database import BanUser, db_manager, is_identical
|
||||
from . import NoticeType, ACCEPT_FORMAT
|
||||
from .message_sending import message_send_instance
|
||||
@@ -38,6 +37,11 @@ class NoticeHandler:
|
||||
def __init__(self):
|
||||
self.server_connection: Server.ServerConnection | None = None
|
||||
self.last_poke_time: float = 0.0 # 记录最后一次针对机器人的戳一戳时间
|
||||
self.plugin_config = None
|
||||
|
||||
def set_plugin_config(self, plugin_config: dict):
|
||||
"""设置插件配置"""
|
||||
self.plugin_config = plugin_config
|
||||
|
||||
async def set_server_connection(self, server_connection: Server.ServerConnection) -> None:
|
||||
"""设置Napcat连接"""
|
||||
@@ -112,10 +116,10 @@ class NoticeHandler:
|
||||
sub_type = raw_message.get("sub_type")
|
||||
match sub_type:
|
||||
case NoticeType.Notify.poke:
|
||||
if features_manager.is_poke_enabled() and await message_handler.check_allow_to_chat(
|
||||
if config_api.get_plugin_config(self.plugin_config, "features.enable_poke", True) and await message_handler.check_allow_to_chat(
|
||||
user_id, group_id, False, False
|
||||
):
|
||||
logger.info("处理戳一戳消息")
|
||||
logger.debug("处理戳一戳消息")
|
||||
handled_message, user_info = await self.handle_poke_notify(raw_message, group_id, user_id)
|
||||
else:
|
||||
logger.warning("戳一戳消息被禁用,取消戳一戳处理")
|
||||
@@ -159,13 +163,13 @@ class NoticeHandler:
|
||||
else:
|
||||
logger.warning("无法获取notice消息所在群的名称")
|
||||
group_info = GroupInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"),
|
||||
group_id=group_id,
|
||||
group_name=group_name,
|
||||
)
|
||||
|
||||
message_info: BaseMessageInfo = BaseMessageInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"),
|
||||
message_id="notice",
|
||||
time=message_time,
|
||||
user_info=user_info,
|
||||
@@ -187,7 +191,7 @@ class NoticeHandler:
|
||||
if system_notice:
|
||||
await self.put_notice(message_base)
|
||||
else:
|
||||
logger.info("发送到Maibot处理通知信息")
|
||||
logger.debug("发送到Maibot处理通知信息")
|
||||
await message_send_instance.message_send(message_base)
|
||||
|
||||
async def handle_poke_notify(
|
||||
@@ -206,12 +210,12 @@ class NoticeHandler:
|
||||
# 防抖检查:如果是针对机器人的戳一戳,检查防抖时间
|
||||
if self_id == target_id:
|
||||
current_time = time.time()
|
||||
debounce_seconds = features_manager.get_config().poke_debounce_seconds
|
||||
debounce_seconds = config_api.get_plugin_config(self.plugin_config, "features.poke_debounce_seconds", 2.0)
|
||||
|
||||
if self.last_poke_time > 0:
|
||||
time_diff = current_time - self.last_poke_time
|
||||
if time_diff < debounce_seconds:
|
||||
logger.info(f"戳一戳防抖:用户 {user_id} 的戳一戳被忽略(距离上次戳一戳 {time_diff:.2f} 秒)")
|
||||
logger.debug(f"戳一戳防抖:用户 {user_id} 的戳一戳被忽略(距离上次戳一戳 {time_diff:.2f} 秒)")
|
||||
return None, None
|
||||
|
||||
# 记录这次戳一戳的时间
|
||||
@@ -230,7 +234,7 @@ class NoticeHandler:
|
||||
else:
|
||||
user_name = "QQ用户"
|
||||
user_cardname = "QQ用户"
|
||||
logger.info("无法获取戳一戳对方的用户昵称")
|
||||
logger.debug("无法获取戳一戳对方的用户昵称")
|
||||
|
||||
# 计算Seg
|
||||
if self_id == target_id:
|
||||
@@ -243,8 +247,8 @@ class NoticeHandler:
|
||||
|
||||
else:
|
||||
# 如果配置为忽略不是针对自己的戳一戳,则直接返回None
|
||||
if features_manager.is_non_self_poke_ignored():
|
||||
logger.info("忽略不是针对自己的戳一戳消息")
|
||||
if config_api.get_plugin_config(self.plugin_config, "features.ignore_non_self_poke", False):
|
||||
logger.debug("忽略不是针对自己的戳一戳消息")
|
||||
return None, None
|
||||
|
||||
# 老实说这一步判定没啥意义,毕竟私聊是没有其他人之间的戳一戳,但是感觉可以有这个判定来强限制群聊环境
|
||||
@@ -254,7 +258,7 @@ class NoticeHandler:
|
||||
target_name = fetched_member_info.get("nickname")
|
||||
else:
|
||||
target_name = "QQ用户"
|
||||
logger.info("无法获取被戳一戳方的用户昵称")
|
||||
logger.debug("无法获取被戳一戳方的用户昵称")
|
||||
display_name = user_name
|
||||
else:
|
||||
return None, None
|
||||
@@ -268,7 +272,7 @@ class NoticeHandler:
|
||||
logger.warning(f"解析戳一戳消息失败: {str(e)},将使用默认文本")
|
||||
|
||||
user_info: UserInfo = UserInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"),
|
||||
user_id=user_id,
|
||||
user_nickname=user_name,
|
||||
user_cardname=user_cardname,
|
||||
@@ -299,7 +303,7 @@ class NoticeHandler:
|
||||
operator_nickname = "QQ用户"
|
||||
|
||||
operator_info: UserInfo = UserInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"),
|
||||
user_id=operator_id,
|
||||
user_nickname=operator_nickname,
|
||||
user_cardname=operator_cardname,
|
||||
@@ -328,7 +332,7 @@ class NoticeHandler:
|
||||
user_nickname = fetched_member_info.get("nickname")
|
||||
user_cardname = fetched_member_info.get("card")
|
||||
banned_user_info: UserInfo = UserInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"),
|
||||
user_id=user_id,
|
||||
user_nickname=user_nickname,
|
||||
user_cardname=user_cardname,
|
||||
@@ -367,7 +371,7 @@ class NoticeHandler:
|
||||
operator_nickname = "QQ用户"
|
||||
|
||||
operator_info: UserInfo = UserInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"),
|
||||
user_id=operator_id,
|
||||
user_nickname=operator_nickname,
|
||||
user_cardname=operator_cardname,
|
||||
@@ -393,7 +397,7 @@ class NoticeHandler:
|
||||
else:
|
||||
logger.warning("无法获取解除禁言消息发送者的昵称,消息可能会无效")
|
||||
lifted_user_info: UserInfo = UserInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"),
|
||||
user_id=user_id,
|
||||
user_nickname=user_nickname,
|
||||
user_cardname=user_cardname,
|
||||
@@ -436,13 +440,13 @@ class NoticeHandler:
|
||||
else:
|
||||
logger.warning("无法获取notice消息所在群的名称")
|
||||
group_info = GroupInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"),
|
||||
group_id=group_id,
|
||||
group_name=group_name,
|
||||
)
|
||||
|
||||
message_info: BaseMessageInfo = BaseMessageInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"),
|
||||
message_id="notice",
|
||||
time=time.time(),
|
||||
user_info=None, # 自然解除禁言没有操作者
|
||||
@@ -493,7 +497,7 @@ class NoticeHandler:
|
||||
user_cardname = fetched_member_info.get("card")
|
||||
|
||||
lifted_user_info: UserInfo = UserInfo(
|
||||
platform=global_config.maibot_server.platform_name,
|
||||
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"),
|
||||
user_id=user_id,
|
||||
user_nickname=user_nickname,
|
||||
user_cardname=user_cardname,
|
||||
@@ -517,7 +521,7 @@ class NoticeHandler:
|
||||
continue
|
||||
if ban_record.lift_time <= int(time.time()):
|
||||
# 触发自然解除禁言
|
||||
logger.info(f"检测到用户 {ban_record.user_id} 在群 {ban_record.group_id} 的禁言已解除")
|
||||
logger.debug(f"检测到用户 {ban_record.user_id} 在群 {ban_record.group_id} 的禁言已解除")
|
||||
self.lifted_list.append(ban_record)
|
||||
self.banned_list.remove(ban_record)
|
||||
await asyncio.sleep(5)
|
||||
@@ -1,19 +1,26 @@
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict
|
||||
from .config import global_config
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis import config_api
|
||||
|
||||
logger = get_logger("napcat_adapter")
|
||||
|
||||
response_dict: Dict = {}
|
||||
response_time_dict: Dict = {}
|
||||
plugin_config = None
|
||||
|
||||
|
||||
def set_plugin_config(config: dict):
|
||||
"""设置插件配置"""
|
||||
global plugin_config
|
||||
plugin_config = config
|
||||
|
||||
|
||||
async def get_response(request_id: str, timeout: int = 10) -> dict:
|
||||
response = await asyncio.wait_for(_get_response(request_id), timeout)
|
||||
_ = response_time_dict.pop(request_id)
|
||||
logger.info(f"响应信息id: {request_id} 已从响应字典中取出")
|
||||
logger.debug(f"响应信息id: {request_id} 已从响应字典中取出")
|
||||
return response
|
||||
|
||||
|
||||
@@ -31,18 +38,25 @@ async def put_response(response: dict):
|
||||
now_time = time.time()
|
||||
response_dict[echo_id] = response
|
||||
response_time_dict[echo_id] = now_time
|
||||
logger.info(f"响应信息id: {echo_id} 已存入响应字典")
|
||||
logger.debug(f"响应信息id: {echo_id} 已存入响应字典")
|
||||
|
||||
|
||||
async def check_timeout_response() -> None:
|
||||
while True:
|
||||
cleaned_message_count: int = 0
|
||||
now_time = time.time()
|
||||
|
||||
# 获取心跳间隔配置
|
||||
heartbeat_interval = 30 # 默认值
|
||||
if plugin_config:
|
||||
heartbeat_interval = config_api.get_plugin_config(plugin_config, "napcat_server.heartbeat_interval", 30)
|
||||
|
||||
for echo_id, response_time in list(response_time_dict.items()):
|
||||
if now_time - response_time > global_config.napcat_server.heartbeat_interval:
|
||||
if now_time - response_time > heartbeat_interval:
|
||||
cleaned_message_count += 1
|
||||
response_dict.pop(echo_id)
|
||||
response_time_dict.pop(echo_id)
|
||||
logger.warning(f"响应消息 {echo_id} 超时,已删除")
|
||||
logger.info(f"已删除 {cleaned_message_count} 条超时响应消息")
|
||||
await asyncio.sleep(global_config.napcat_server.heartbeat_interval)
|
||||
if cleaned_message_count > 0:
|
||||
logger.info(f"已删除 {cleaned_message_count} 条超时响应消息")
|
||||
await asyncio.sleep(heartbeat_interval)
|
||||
@@ -12,9 +12,9 @@ from maim_message import (
|
||||
MessageBase,
|
||||
)
|
||||
from typing import Dict, Any, Tuple, Optional
|
||||
from src.plugin_system.apis import config_api
|
||||
|
||||
from . import CommandType
|
||||
from .config import global_config
|
||||
from .response_pool import get_response
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -22,12 +22,16 @@ logger = get_logger("napcat_adapter")
|
||||
from .utils import get_image_format, convert_image_to_gif
|
||||
from .recv_handler.message_sending import message_send_instance
|
||||
from .websocket_manager import websocket_manager
|
||||
from .config.features_config import features_manager
|
||||
|
||||
|
||||
class SendHandler:
|
||||
def __init__(self):
|
||||
self.server_connection: Optional[Server.ServerConnection] = None
|
||||
self.plugin_config = None
|
||||
|
||||
def set_plugin_config(self, plugin_config: dict):
|
||||
"""设置插件配置"""
|
||||
self.plugin_config = plugin_config
|
||||
|
||||
async def set_server_connection(self, server_connection: Server.ServerConnection) -> None:
|
||||
"""设置Napcat连接"""
|
||||
@@ -287,11 +291,8 @@ class SendHandler:
|
||||
"""处理回复消息"""
|
||||
reply_seg = {"type": "reply", "data": {"id": id}}
|
||||
|
||||
# 获取功能配置
|
||||
ft_config = features_manager.get_config()
|
||||
|
||||
# 检查是否启用引用艾特功能
|
||||
if not ft_config.enable_reply_at:
|
||||
if not config_api.get_plugin_config(self.plugin_config, "features.enable_reply_at", False):
|
||||
return reply_seg
|
||||
|
||||
try:
|
||||
@@ -310,7 +311,7 @@ class SendHandler:
|
||||
return reply_seg
|
||||
|
||||
# 根据概率决定是否艾特用户
|
||||
if random.random() < ft_config.reply_at_rate:
|
||||
if random.random() < config_api.get_plugin_config(self.plugin_config, "features.reply_at_rate", 0.5):
|
||||
at_seg = {"type": "at", "data": {"qq": str(replied_user_id)}}
|
||||
# 在艾特后面添加一个空格
|
||||
text_seg = {"type": "text", "data": {"text": " "}}
|
||||
@@ -354,7 +355,11 @@ class SendHandler:
|
||||
|
||||
def handle_voice_message(self, encoded_voice: str) -> dict:
|
||||
"""处理语音消息"""
|
||||
if not global_config.voice.use_tts:
|
||||
use_tts = False
|
||||
if self.plugin_config:
|
||||
use_tts = config_api.get_plugin_config(self.plugin_config, "voice.use_tts", False)
|
||||
|
||||
if not use_tts:
|
||||
logger.warning("未启用语音消息处理")
|
||||
return {}
|
||||
if not encoded_voice:
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user