Merge pull request #2 from MoFox-Studio/dev

merge Dev
This commit is contained in:
Windpicker-owo
2025-09-07 16:15:55 +08:00
committed by GitHub
131 changed files with 5889 additions and 5484 deletions

2
.gitignore vendored
View File

@@ -336,3 +336,5 @@ MaiBot.code-workspace
/tests
/tests
.kilocode/rules/MoFox.md
src/chat/planner_actions/planner (2).py
rust_video/Cargo.lock

1
bot.py
View File

@@ -82,7 +82,6 @@ def easter_egg():
async def graceful_shutdown():
try:
logger.info("正在优雅关闭麦麦...")
# 停止所有异步任务
await async_task_manager.stop_and_wait_all_tasks()

View File

@@ -1,5 +0,0 @@
from .config import global_config
__all__ = [
"global_config",
]

View File

@@ -1,151 +0,0 @@
import os
from dataclasses import dataclass
from datetime import datetime
import tomlkit
import shutil
from tomlkit import TOMLDocument
from tomlkit.items import Table
from src.common.logger import get_logger
logger = get_logger("napcat_adapter")
from rich.traceback import install
from .config_base import ConfigBase
from .official_configs import (
DebugConfig,
MaiBotServerConfig,
NapcatServerConfig,
NicknameConfig,
SlicingConfig,
VoiceConfig,
)
install(extra_lines=3)
TEMPLATE_DIR = "plugins/napcat_adapter_plugin/template"
CONFIG_DIR = "plugins/napcat_adapter_plugin/config"
OLD_CONFIG_DIR = "plugins/napcat_adapter_plugin/config/old"
def ensure_config_directories():
"""确保配置目录存在"""
os.makedirs(CONFIG_DIR, exist_ok=True)
os.makedirs(OLD_CONFIG_DIR, exist_ok=True)
def update_config():
"""更新配置文件,统一使用 config/old 目录进行备份"""
# 确保目录存在
ensure_config_directories()
# 定义文件路径
template_path = f"{TEMPLATE_DIR}/template_config.toml"
config_path = f"{CONFIG_DIR}/config.toml"
# 检查配置文件是否存在
if not os.path.exists(config_path):
logger.info("主配置文件不存在,从模板创建新配置")
shutil.copy2(template_path, config_path)
logger.info(f"已创建新配置文件: {config_path}")
logger.info("程序将退出,请检查配置文件后重启")
# 读取配置文件和模板文件
with open(config_path, "r", encoding="utf-8") as f:
old_config = tomlkit.load(f)
with open(template_path, "r", encoding="utf-8") as f:
new_config = tomlkit.load(f)
# 检查version是否相同
if old_config and "inner" in old_config and "inner" in new_config:
old_version = old_config["inner"].get("version")
new_version = new_config["inner"].get("version")
if old_version and new_version and old_version == new_version:
logger.info(f"检测到配置文件版本号相同 (v{old_version}),跳过更新")
return
else:
logger.info(f"检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}")
else:
logger.info("已有配置文件未检测到版本号,可能是旧版本。将进行更新")
# 创建备份文件
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
backup_path = os.path.join(OLD_CONFIG_DIR, f"config.toml.bak.{timestamp}")
# 备份旧配置文件
shutil.copy2(config_path, backup_path)
logger.info(f"已备份旧配置文件到: {backup_path}")
# 复制模板文件到配置目录
shutil.copy2(template_path, config_path)
logger.info(f"已创建新配置文件: {config_path}")
def update_dict(target: TOMLDocument | dict, source: TOMLDocument | dict):
"""将source字典的值更新到target字典中如果target中存在相同的键"""
for key, value in source.items():
# 跳过version字段的更新
if key == "version":
continue
if key in target:
if isinstance(value, dict) and isinstance(target[key], (dict, Table)):
update_dict(target[key], value)
else:
try:
# 对数组类型进行特殊处理
if isinstance(value, list):
# 如果是空数组,确保它保持为空数组
target[key] = tomlkit.array(str(value)) if value else tomlkit.array()
else:
# 其他类型使用item方法创建新值
target[key] = tomlkit.item(value)
except (TypeError, ValueError):
# 如果转换失败,直接赋值
target[key] = value
# 将旧配置的值更新到新配置中
logger.info("开始合并新旧配置...")
update_dict(new_config, old_config)
# 保存更新后的配置(保留注释和格式)
with open(config_path, "w", encoding="utf-8") as f:
f.write(tomlkit.dumps(new_config))
logger.info("配置文件更新完成,建议检查新配置文件中的内容,以免丢失重要信息")
@dataclass
class Config(ConfigBase):
"""总配置类"""
nickname: NicknameConfig
napcat_server: NapcatServerConfig
maibot_server: MaiBotServerConfig
voice: VoiceConfig
slicing: SlicingConfig
debug: DebugConfig
def load_config(config_path: str) -> Config:
"""
加载配置文件
:param config_path: 配置文件路径
:return: Config对象
"""
# 读取配置文件
with open(config_path, "r", encoding="utf-8") as f:
config_data = tomlkit.load(f)
# 创建Config对象
try:
return Config.from_dict(config_data)
except Exception as e:
logger.critical("配置文件解析失败")
raise e
# 更新配置
update_config()
logger.info("正在品鉴配置文件...")
global_config = load_config(config_path=f"{CONFIG_DIR}/config.toml")
logger.info("非常的新鲜,非常的美味!")

View File

@@ -1,136 +0,0 @@
from dataclasses import dataclass, fields, MISSING
from typing import TypeVar, Type, Any, get_origin, get_args, Literal, Dict, Union
T = TypeVar("T", bound="ConfigBase")
TOML_DICT_TYPE = {
int,
float,
str,
bool,
list,
dict,
}
@dataclass
class ConfigBase:
"""配置类的基类"""
@classmethod
def from_dict(cls: Type[T], data: Dict[str, Any]) -> T:
"""从字典加载配置字段"""
if not isinstance(data, dict):
raise TypeError(f"Expected a dictionary, got {type(data).__name__}")
init_args: Dict[str, Any] = {}
for f in fields(cls):
field_name = f.name
field_type = f.type
if field_name.startswith("_"):
# 跳过以 _ 开头的字段
continue
if field_name not in data:
if f.default is not MISSING or f.default_factory is not MISSING:
# 跳过未提供且有默认值/默认构造方法的字段
continue
else:
raise ValueError(f"Missing required field: '{field_name}'")
value = data[field_name]
try:
init_args[field_name] = cls._convert_field(value, field_type)
except TypeError as e:
raise TypeError(f"字段 '{field_name}' 出现类型错误: {e}") from e
except Exception as e:
raise RuntimeError(f"无法将字段 '{field_name}' 转换为目标类型,出现错误: {e}") from e
return cls(**init_args)
@classmethod
def _convert_field(cls, value: Any, field_type: Type[Any]) -> Any:
"""
转换字段值为指定类型
1. 对于嵌套的 dataclass递归调用相应的 from_dict 方法
2. 对于泛型集合类型list, set, tuple递归转换每个元素
3. 对于基础类型int, str, float, bool直接转换
4. 对于其他类型,尝试直接转换,如果失败则抛出异常
"""
# 如果是嵌套的 dataclass递归调用 from_dict 方法
if isinstance(field_type, type) and issubclass(field_type, ConfigBase):
return field_type.from_dict(value)
field_origin_type = get_origin(field_type)
field_args_type = get_args(field_type)
# 处理泛型集合类型list, set, tuple
if field_origin_type in {list, set, tuple}:
# 检查提供的value是否为list
if not isinstance(value, list):
raise TypeError(f"Expected an list for {field_type.__name__}, got {type(value).__name__}")
if field_origin_type is list:
return [cls._convert_field(item, field_args_type[0]) for item in value]
if field_origin_type is set:
return {cls._convert_field(item, field_args_type[0]) for item in value}
if field_origin_type is tuple:
# 检查提供的value长度是否与类型参数一致
if len(value) != len(field_args_type):
raise TypeError(
f"Expected {len(field_args_type)} items for {field_type.__name__}, got {len(value)}"
)
return tuple(cls._convert_field(item, arg_type) for item, arg_type in zip(value, field_args_type))
if field_origin_type is dict:
# 检查提供的value是否为dict
if not isinstance(value, dict):
raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}")
# 检查字典的键值类型
if len(field_args_type) != 2:
raise TypeError(f"Expected a dictionary with two type arguments for {field_type.__name__}")
key_type, value_type = field_args_type
return {cls._convert_field(k, key_type): cls._convert_field(v, value_type) for k, v in value.items()}
# 处理Optional类型
if field_origin_type is Union: # assert get_origin(Optional[Any]) is Union
if value is None:
return None
# 如果有数据,检查实际类型
if type(value) not in field_args_type:
raise TypeError(f"Expected {field_args_type} for {field_type.__name__}, got {type(value).__name__}")
return cls._convert_field(value, field_args_type[0])
# 处理int, str, float, bool等基础类型
if field_origin_type is None:
if isinstance(value, field_type):
return field_type(value)
else:
raise TypeError(f"Expected {field_type.__name__}, got {type(value).__name__}")
# 处理Literal类型
if field_origin_type is Literal:
# 获取Literal的允许值
allowed_values = get_args(field_type)
if value in allowed_values:
return value
else:
raise TypeError(f"Value '{value}' is not in allowed values {allowed_values} for Literal type")
# 处理其他类型
if field_type is Any:
return value
# 其他类型直接转换
try:
return field_type(value)
except (ValueError, TypeError) as e:
raise TypeError(f"无法将 {type(value).__name__} 转换为 {field_type.__name__}") from e
def __str__(self):
"""返回配置类的字符串表示"""
return f"{self.__class__.__name__}({', '.join(f'{f.name}={getattr(self, f.name)}' for f in fields(self))})"

View File

@@ -1,145 +0,0 @@
"""
配置文件工具模块
提供统一的配置文件生成和管理功能
"""
import os
import shutil
from pathlib import Path
from datetime import datetime
from typing import Optional
from src.common.logger import get_logger
logger = get_logger("napcat_adapter")
def ensure_config_directories():
"""确保配置目录存在"""
os.makedirs("config", exist_ok=True)
os.makedirs("config/old", exist_ok=True)
def create_config_from_template(
config_path: str, template_path: str, config_name: str = "配置文件", should_exit: bool = True
) -> bool:
"""
从模板创建配置文件的统一函数
Args:
config_path: 配置文件路径
template_path: 模板文件路径
config_name: 配置文件名称(用于日志显示)
should_exit: 创建后是否退出程序
Returns:
bool: 是否成功创建配置文件
"""
try:
# 确保配置目录存在
ensure_config_directories()
config_path_obj = Path(config_path)
template_path_obj = Path(template_path)
# 检查配置文件是否存在
if config_path_obj.exists():
return False # 配置文件已存在,无需创建
logger.info(f"{config_name}不存在,从模板创建新配置")
# 检查模板文件是否存在
if not template_path_obj.exists():
logger.error(f"模板文件不存在: {template_path}")
if should_exit:
logger.critical("无法创建配置文件,程序退出")
quit(1)
return False
# 确保配置文件目录存在
config_path_obj.parent.mkdir(parents=True, exist_ok=True)
# 复制模板文件到配置目录
shutil.copy2(template_path_obj, config_path_obj)
logger.info(f"已创建新{config_name}: {config_path}")
if should_exit:
logger.info("程序将退出,请检查配置文件后重启")
quit(0)
return True
except Exception as e:
logger.error(f"创建{config_name}失败: {e}")
if should_exit:
logger.critical("无法创建配置文件,程序退出")
quit(1)
return False
def create_default_config_dict(default_values: dict, config_path: str, config_name: str = "配置文件") -> bool:
"""
创建默认配置文件(使用字典数据)
Args:
default_values: 默认配置值字典
config_path: 配置文件路径
config_name: 配置文件名称(用于日志显示)
Returns:
bool: 是否成功创建配置文件
"""
try:
import tomlkit
config_path_obj = Path(config_path)
# 确保配置文件目录存在
config_path_obj.parent.mkdir(parents=True, exist_ok=True)
# 写入默认配置
with open(config_path_obj, "w", encoding="utf-8") as f:
tomlkit.dump(default_values, f)
logger.info(f"已创建默认{config_name}: {config_path}")
return True
except Exception as e:
logger.error(f"创建默认{config_name}失败: {e}")
return False
def backup_config_file(config_path: str, backup_dir: str = "config/old") -> Optional[str]:
"""
备份配置文件
Args:
config_path: 要备份的配置文件路径
backup_dir: 备份目录
Returns:
Optional[str]: 备份文件路径失败时返回None
"""
try:
config_path_obj = Path(config_path)
if not config_path_obj.exists():
return None
# 确保备份目录存在
backup_dir_obj = Path(backup_dir)
backup_dir_obj.mkdir(parents=True, exist_ok=True)
# 创建备份文件名
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
backup_filename = f"{config_path_obj.stem}.toml.bak.{timestamp}"
backup_path = backup_dir_obj / backup_filename
# 备份文件
shutil.copy2(config_path_obj, backup_path)
logger.info(f"已备份配置文件到: {backup_path}")
return str(backup_path)
except Exception as e:
logger.error(f"备份配置文件失败: {e}")
return None

View File

@@ -1,359 +0,0 @@
import asyncio
from dataclasses import dataclass, field
from typing import Literal, Optional
from pathlib import Path
import tomlkit
from src.common.logger import get_logger
logger = get_logger("napcat_adapter")
from .config_base import ConfigBase
from .config_utils import create_config_from_template, create_default_config_dict
@dataclass
class FeaturesConfig(ConfigBase):
"""功能配置类"""
group_list_type: Literal["whitelist", "blacklist"] = "whitelist"
"""群聊列表类型 白名单/黑名单"""
group_list: list[int] = field(default_factory=list)
"""群聊列表"""
private_list_type: Literal["whitelist", "blacklist"] = "whitelist"
"""私聊列表类型 白名单/黑名单"""
private_list: list[int] = field(default_factory=list)
"""私聊列表"""
ban_user_id: list[int] = field(default_factory=list)
"""被封禁的用户ID列表封禁后将无法与其进行交互"""
ban_qq_bot: bool = False
"""是否屏蔽QQ官方机器人若为True则所有QQ官方机器人将无法与MaiMCore进行交互"""
enable_poke: bool = True
"""是否启用戳一戳功能"""
ignore_non_self_poke: bool = False
"""是否无视不是针对自己的戳一戳"""
poke_debounce_seconds: int = 3
"""戳一戳防抖时间(秒),在指定时间内第二次针对机器人的戳一戳将被忽略"""
enable_reply_at: bool = True
"""是否启用引用回复时艾特用户的功能"""
reply_at_rate: float = 0.5
"""引用回复时艾特用户的几率 (0.0 ~ 1.0)"""
enable_video_analysis: bool = True
"""是否启用视频识别功能"""
max_video_size_mb: int = 100
"""视频文件最大大小限制MB"""
download_timeout: int = 60
"""视频下载超时时间(秒)"""
supported_formats: list[str] = field(default_factory=lambda: ["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"])
"""支持的视频格式"""
# 消息缓冲配置
enable_message_buffer: bool = True
"""是否启用消息缓冲合并功能"""
message_buffer_enable_group: bool = True
"""是否启用群消息缓冲合并"""
message_buffer_enable_private: bool = True
"""是否启用私聊消息缓冲合并"""
message_buffer_interval: float = 3.0
"""消息合并间隔时间(秒),在此时间内的连续消息将被合并"""
message_buffer_initial_delay: float = 0.5
"""消息缓冲初始延迟(秒),收到第一条消息后等待此时间开始合并"""
message_buffer_max_components: int = 50
"""单个会话最大缓冲消息组件数量,超过此数量将强制合并"""
message_buffer_block_prefixes: list[str] = field(default_factory=lambda: ["/", "!", "", ".", "", "#", "%"])
"""消息缓冲屏蔽前缀,以这些前缀开头的消息不会被缓冲"""
class FeaturesManager:
"""功能管理器,支持热重载"""
def __init__(self, config_path: str = "plugins/napcat_adapter_plugin/config/features.toml"):
self.config_path = Path(config_path)
self.config: Optional[FeaturesConfig] = None
self._file_watcher_task: Optional[asyncio.Task] = None
self._last_modified: Optional[float] = None
self._callbacks: list = []
def add_reload_callback(self, callback):
"""添加配置重载回调函数"""
self._callbacks.append(callback)
def remove_reload_callback(self, callback):
"""移除配置重载回调函数"""
if callback in self._callbacks:
self._callbacks.remove(callback)
async def _notify_callbacks(self):
"""通知所有回调函数配置已重载"""
for callback in self._callbacks:
try:
if asyncio.iscoroutinefunction(callback):
await callback(self.config)
else:
callback(self.config)
except Exception as e:
logger.error(f"配置重载回调执行失败: {e}")
def load_config(self) -> FeaturesConfig:
"""加载功能配置文件"""
try:
# 检查配置文件是否存在,如果不存在则创建并退出程序
if not self.config_path.exists():
logger.info(f"功能配置文件不存在: {self.config_path}")
self._create_default_config()
# 配置文件创建后程序应该退出,让用户检查配置
logger.info("程序将退出,请检查功能配置文件后重启")
quit(0)
with open(self.config_path, "r", encoding="utf-8") as f:
config_data = tomlkit.load(f)
self.config = FeaturesConfig.from_dict(config_data)
self._last_modified = self.config_path.stat().st_mtime
logger.info(f"功能配置加载成功: {self.config_path}")
return self.config
except Exception as e:
logger.error(f"功能配置加载失败: {e}")
logger.critical("无法加载功能配置文件,程序退出")
quit(1)
def _create_default_config(self):
"""创建默认功能配置文件"""
template_path = "template/features_template.toml"
# 尝试从模板创建配置文件
if create_config_from_template(
str(self.config_path),
template_path,
"功能配置文件",
should_exit=False, # 不在这里退出,由调用方决定
):
return
# 如果模板文件不存在,创建基本配置
logger.info("模板文件不存在,创建基本功能配置")
default_config = {
"group_list_type": "whitelist",
"group_list": [],
"private_list_type": "whitelist",
"private_list": [],
"ban_user_id": [],
"ban_qq_bot": False,
"enable_poke": True,
"ignore_non_self_poke": False,
"poke_debounce_seconds": 3,
"enable_reply_at": True,
"reply_at_rate": 0.5,
"enable_video_analysis": True,
"max_video_size_mb": 100,
"download_timeout": 60,
"supported_formats": ["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"],
# 消息缓冲配置
"enable_message_buffer": True,
"message_buffer_enable_group": True,
"message_buffer_enable_private": True,
"message_buffer_interval": 3.0,
"message_buffer_initial_delay": 0.5,
"message_buffer_max_components": 50,
"message_buffer_block_prefixes": ["/", "!", "", ".", "", "#", "%"],
}
if not create_default_config_dict(default_config, str(self.config_path), "功能配置文件"):
logger.critical("无法创建功能配置文件")
quit(1)
async def reload_config(self) -> bool:
"""重新加载配置文件"""
try:
if not self.config_path.exists():
logger.warning(f"功能配置文件不存在,无法重载: {self.config_path}")
return False
current_modified = self.config_path.stat().st_mtime
if self._last_modified and current_modified <= self._last_modified:
return False # 文件未修改
old_config = self.config
new_config = self.load_config()
# 检查配置是否真的发生了变化
if old_config and self._configs_equal(old_config, new_config):
return False
logger.info("功能配置已重载")
await self._notify_callbacks()
return True
except Exception as e:
logger.error(f"功能配置重载失败: {e}")
return False
def _configs_equal(self, config1: FeaturesConfig, config2: FeaturesConfig) -> bool:
"""比较两个配置是否相等"""
return (
config1.group_list_type == config2.group_list_type
and set(config1.group_list) == set(config2.group_list)
and config1.private_list_type == config2.private_list_type
and set(config1.private_list) == set(config2.private_list)
and set(config1.ban_user_id) == set(config2.ban_user_id)
and config1.ban_qq_bot == config2.ban_qq_bot
and config1.enable_poke == config2.enable_poke
and config1.ignore_non_self_poke == config2.ignore_non_self_poke
and config1.poke_debounce_seconds == config2.poke_debounce_seconds
and config1.enable_reply_at == config2.enable_reply_at
and config1.reply_at_rate == config2.reply_at_rate
and config1.enable_video_analysis == config2.enable_video_analysis
and config1.max_video_size_mb == config2.max_video_size_mb
and config1.download_timeout == config2.download_timeout
and set(config1.supported_formats) == set(config2.supported_formats)
and
# 消息缓冲配置比较
config1.enable_message_buffer == config2.enable_message_buffer
and config1.message_buffer_enable_group == config2.message_buffer_enable_group
and config1.message_buffer_enable_private == config2.message_buffer_enable_private
and config1.message_buffer_interval == config2.message_buffer_interval
and config1.message_buffer_initial_delay == config2.message_buffer_initial_delay
and config1.message_buffer_max_components == config2.message_buffer_max_components
and set(config1.message_buffer_block_prefixes) == set(config2.message_buffer_block_prefixes)
)
async def start_file_watcher(self, check_interval: float = 1.0):
"""启动文件监控,定期检查配置文件变化"""
if self._file_watcher_task and not self._file_watcher_task.done():
logger.warning("文件监控已在运行")
return
self._file_watcher_task = asyncio.create_task(self._file_watcher_loop(check_interval))
logger.info(f"功能配置文件监控已启动,检查间隔: {check_interval}")
async def stop_file_watcher(self):
"""停止文件监控"""
if self._file_watcher_task and not self._file_watcher_task.done():
self._file_watcher_task.cancel()
try:
await self._file_watcher_task
except asyncio.CancelledError:
pass
logger.info("功能配置文件监控已停止")
async def _file_watcher_loop(self, check_interval: float):
"""文件监控循环"""
while True:
try:
await asyncio.sleep(check_interval)
await self.reload_config()
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"文件监控循环出错: {e}")
await asyncio.sleep(check_interval)
def get_config(self) -> FeaturesConfig:
"""获取当前功能配置"""
if self.config is None:
return self.load_config()
return self.config
def is_group_allowed(self, group_id: int) -> bool:
"""检查群聊是否被允许"""
config = self.get_config()
if config.group_list_type == "whitelist":
return group_id in config.group_list
else: # blacklist
return group_id not in config.group_list
def is_private_allowed(self, user_id: int) -> bool:
"""检查私聊是否被允许"""
config = self.get_config()
if config.private_list_type == "whitelist":
return user_id in config.private_list
else: # blacklist
return user_id not in config.private_list
def is_user_banned(self, user_id: int) -> bool:
"""检查用户是否被全局禁止"""
config = self.get_config()
return user_id in config.ban_user_id
def is_qq_bot_banned(self) -> bool:
"""检查是否禁止QQ官方机器人"""
config = self.get_config()
return config.ban_qq_bot
def is_poke_enabled(self) -> bool:
"""检查戳一戳功能是否启用"""
config = self.get_config()
return config.enable_poke
def is_non_self_poke_ignored(self) -> bool:
"""检查是否忽略非自己戳一戳"""
config = self.get_config()
return config.ignore_non_self_poke
def is_message_buffer_enabled(self) -> bool:
"""检查消息缓冲功能是否启用"""
config = self.get_config()
return config.enable_message_buffer
def is_message_buffer_group_enabled(self) -> bool:
"""检查群消息缓冲是否启用"""
config = self.get_config()
return config.message_buffer_enable_group
def is_message_buffer_private_enabled(self) -> bool:
"""检查私聊消息缓冲是否启用"""
config = self.get_config()
return config.message_buffer_enable_private
def get_message_buffer_interval(self) -> float:
"""获取消息缓冲间隔时间"""
config = self.get_config()
return config.message_buffer_interval
def get_message_buffer_initial_delay(self) -> float:
"""获取消息缓冲初始延迟"""
config = self.get_config()
return config.message_buffer_initial_delay
def get_message_buffer_max_components(self) -> int:
"""获取消息缓冲最大组件数量"""
config = self.get_config()
return config.message_buffer_max_components
def is_message_buffer_group_enabled(self) -> bool:
"""检查是否启用群聊消息缓冲"""
config = self.get_config()
return config.message_buffer_enable_group
def is_message_buffer_private_enabled(self) -> bool:
"""检查是否启用私聊消息缓冲"""
config = self.get_config()
return config.message_buffer_enable_private
def get_message_buffer_block_prefixes(self) -> list[str]:
"""获取消息缓冲屏蔽前缀列表"""
config = self.get_config()
return config.message_buffer_block_prefixes
# 全局功能管理器实例
features_manager = FeaturesManager()

View File

@@ -1,215 +0,0 @@
"""
功能配置迁移脚本
用于将旧的配置文件中的聊天、权限、视频处理等设置迁移到新的独立功能配置文件
"""
import os
import shutil
from pathlib import Path
import tomlkit
from src.common.logger import get_logger
logger = get_logger("napcat_adapter")
def migrate_features_from_config(
old_config_path: str = "plugins/napcat_adapter_plugin/config/config.toml",
new_features_path: str = "plugins/napcat_adapter_plugin/config/features.toml",
template_path: str = "plugins/napcat_adapter_plugin/template/features_template.toml",
):
"""
从旧配置文件迁移功能设置到新的功能配置文件
Args:
old_config_path: 旧配置文件路径
new_features_path: 新功能配置文件路径
template_path: 功能配置模板路径
"""
try:
# 检查旧配置文件是否存在
if not os.path.exists(old_config_path):
logger.warning(f"旧配置文件不存在: {old_config_path}")
return False
# 读取旧配置文件
with open(old_config_path, "r", encoding="utf-8") as f:
old_config = tomlkit.load(f)
# 检查是否有chat配置段和video配置段
chat_config = old_config.get("chat", {})
video_config = old_config.get("video", {})
# 检查是否有权限相关配置
permission_keys = [
"group_list_type",
"group_list",
"private_list_type",
"private_list",
"ban_user_id",
"ban_qq_bot",
"enable_poke",
"ignore_non_self_poke",
"poke_debounce_seconds",
]
video_keys = ["enable_video_analysis", "max_video_size_mb", "download_timeout", "supported_formats"]
has_permission_config = any(key in chat_config for key in permission_keys)
has_video_config = any(key in video_config for key in video_keys)
if not has_permission_config and not has_video_config:
logger.info("旧配置文件中没有找到功能相关配置,无需迁移")
return False
# 确保新功能配置目录存在
new_features_dir = Path(new_features_path).parent
new_features_dir.mkdir(parents=True, exist_ok=True)
# 如果新功能配置文件已存在,先备份
if os.path.exists(new_features_path):
backup_path = f"{new_features_path}.backup"
shutil.copy2(new_features_path, backup_path)
logger.info(f"已备份现有功能配置文件到: {backup_path}")
# 创建新的功能配置
new_features_config = {
"group_list_type": chat_config.get("group_list_type", "whitelist"),
"group_list": chat_config.get("group_list", []),
"private_list_type": chat_config.get("private_list_type", "whitelist"),
"private_list": chat_config.get("private_list", []),
"ban_user_id": chat_config.get("ban_user_id", []),
"ban_qq_bot": chat_config.get("ban_qq_bot", False),
"enable_poke": chat_config.get("enable_poke", True),
"ignore_non_self_poke": chat_config.get("ignore_non_self_poke", False),
"poke_debounce_seconds": chat_config.get("poke_debounce_seconds", 3),
"enable_video_analysis": video_config.get("enable_video_analysis", True),
"max_video_size_mb": video_config.get("max_video_size_mb", 100),
"download_timeout": video_config.get("download_timeout", 60),
"supported_formats": video_config.get(
"supported_formats", ["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"]
),
}
# 写入新的功能配置文件
with open(new_features_path, "w", encoding="utf-8") as f:
tomlkit.dump(new_features_config, f)
logger.info(f"功能配置已成功迁移到: {new_features_path}")
# 显示迁移的配置内容
logger.info("迁移的配置内容:")
for key, value in new_features_config.items():
logger.info(f" {key}: {value}")
return True
except Exception as e:
logger.error(f"功能配置迁移失败: {e}")
return False
def remove_features_from_old_config(config_path: str = "plugins/napcat_adapter_plugin/config/config.toml"):
"""
从旧配置文件中移除功能相关配置,并将旧配置移动到 config/old/ 目录
Args:
config_path: 配置文件路径
"""
try:
if not os.path.exists(config_path):
logger.warning(f"配置文件不存在: {config_path}")
return False
# 确保 config/old 目录存在
old_config_dir = "plugins/napcat_adapter_plugin/config/old"
os.makedirs(old_config_dir, exist_ok=True)
# 备份原配置文件到 config/old 目录
old_config_path = os.path.join(old_config_dir, "config_with_features.toml")
shutil.copy2(config_path, old_config_path)
logger.info(f"已备份包含功能配置的原文件到: {old_config_path}")
# 读取配置文件
with open(config_path, "r", encoding="utf-8") as f:
config = tomlkit.load(f)
# 移除chat段中的功能相关配置
removed_keys = []
if "chat" in config:
chat_config = config["chat"]
permission_keys = [
"group_list_type",
"group_list",
"private_list_type",
"private_list",
"ban_user_id",
"ban_qq_bot",
"enable_poke",
"ignore_non_self_poke",
"poke_debounce_seconds",
]
for key in permission_keys:
if key in chat_config:
del chat_config[key]
removed_keys.append(key)
if removed_keys:
logger.info(f"已从chat配置段中移除功能相关配置: {removed_keys}")
# 移除video段中的配置
if "video" in config:
video_config = config["video"]
video_keys = ["enable_video_analysis", "max_video_size_mb", "download_timeout", "supported_formats"]
video_removed_keys = []
for key in video_keys:
if key in video_config:
del video_config[key]
video_removed_keys.append(key)
if video_removed_keys:
logger.info(f"已从video配置段中移除配置: {video_removed_keys}")
removed_keys.extend(video_removed_keys)
# 如果video段为空则删除整个段
if not video_config:
del config["video"]
logger.info("已删除空的video配置段")
if removed_keys:
logger.info(f"总共移除的配置项: {removed_keys}")
# 写回配置文件
with open(config_path, "w", encoding="utf-8") as f:
f.write(tomlkit.dumps(config))
logger.info(f"已更新配置文件: {config_path}")
return True
except Exception as e:
logger.error(f"移除功能配置失败: {e}")
return False
def auto_migrate_features():
"""
自动执行功能配置迁移
"""
logger.info("开始自动功能配置迁移...")
# 执行迁移
if migrate_features_from_config():
logger.info("功能配置迁移成功")
# 询问是否要从旧配置文件中移除功能配置
logger.info("功能配置已迁移到独立文件,建议从主配置文件中移除相关配置")
# 在实际使用中,这里可以添加用户确认逻辑
# 为了自动化,这里直接执行移除
remove_features_from_old_config()
else:
logger.info("功能配置迁移跳过或失败")
if __name__ == "__main__":
auto_migrate_features()

View File

@@ -1,72 +0,0 @@
from dataclasses import dataclass, field
from typing import Literal
from .config_base import ConfigBase
"""
须知:
1. 本文件中记录了所有的配置项
2. 所有新增的class都需要继承自ConfigBase
3. 所有新增的class都应在config.py中的Config类中添加字段
4. 对于新增的字段若为可选项则应在其后添加field()并设置default_factory或default
"""
ADAPTER_PLATFORM = "qq"
@dataclass
class NicknameConfig(ConfigBase):
nickname: str
"""机器人昵称"""
@dataclass
class NapcatServerConfig(ConfigBase):
mode: Literal["reverse", "forward"] = "reverse"
"""连接模式reverse=反向连接(作为服务器), forward=正向连接(作为客户端)"""
host: str = "localhost"
"""主机地址"""
port: int = 8095
"""端口号"""
url: str = ""
"""正向连接时的完整WebSocket URL如 ws://localhost:8080/ws"""
access_token: str = ""
"""WebSocket 连接的访问令牌,用于身份验证"""
heartbeat_interval: int = 30
"""心跳间隔时间,单位为秒"""
@dataclass
class MaiBotServerConfig(ConfigBase):
platform_name: str = field(default=ADAPTER_PLATFORM, init=False)
"""平台名称“qq”"""
host: str = "localhost"
"""MaiMCore的主机地址"""
port: int = 8000
"""MaiMCore的端口号"""
@dataclass
class VoiceConfig(ConfigBase):
use_tts: bool = False
"""是否启用TTS功能"""
@dataclass
class SlicingConfig(ConfigBase):
max_frame_size: int = 64
"""WebSocket帧的最大大小单位为字节默认64KB"""
delay_ms: int = 10
"""切片发送间隔时间,单位为毫秒"""
@dataclass
class DebugConfig(ConfigBase):
level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO"
"""日志级别默认为INFO"""

View File

@@ -1,26 +0,0 @@
from maim_message import Router, RouteConfig, TargetConfig
from .config import global_config
from src.common.logger import get_logger
from .send_handler import send_handler
logger = get_logger("napcat_adapter")
route_config = RouteConfig(
route_config={
global_config.maibot_server.platform_name: TargetConfig(
url=f"ws://{global_config.maibot_server.host}:{global_config.maibot_server.port}/ws",
token=None,
)
}
)
router = Router(route_config)
async def mmc_start_com():
logger.info("正在连接MaiBot")
router.register_class_handler(send_handler.handle_message)
await router.run()
async def mmc_stop_com():
await router.stop()

View File

@@ -1,76 +0,0 @@
import json
import asyncio
from src.common.logger import get_logger
from ..message_chunker import chunker
from ..config import global_config
logger = get_logger("napcat_adapter")
from maim_message import MessageBase, Router
class MessageSending:
"""
负责把消息发送到麦麦
"""
maibot_router: Router = None
def __init__(self):
pass
async def message_send(self, message_base: MessageBase) -> bool:
"""
发送消息Ada -> MMC 方向,需要实现切片)
Parameters:
message_base: MessageBase: 消息基类,包含发送目标和消息内容等信息
"""
try:
# 检查是否需要切片发送
message_dict = message_base.to_dict()
if chunker.should_chunk_message(message_dict):
logger.info(f"消息过大,进行切片发送到 MaiBot")
# 切片消息
chunks = chunker.chunk_message(message_dict)
# 逐个发送切片
for i, chunk in enumerate(chunks):
logger.debug(f"发送切片 {i+1}/{len(chunks)} 到 MaiBot")
# 获取对应的客户端并发送切片
platform = message_base.message_info.platform
if platform not in self.maibot_router.clients:
logger.error(f"平台 {platform} 未连接")
return False
client = self.maibot_router.clients[platform]
send_status = await client.send_message(chunk)
if not send_status:
logger.error(f"发送切片 {i+1}/{len(chunks)} 失败")
return False
# 使用配置中的延迟时间
if i < len(chunks) - 1:
delay_seconds = global_config.slicing.delay_ms / 1000.0
logger.debug(f"切片发送延迟: {global_config.slicing.delay_ms}毫秒")
await asyncio.sleep(delay_seconds)
logger.debug("所有切片发送完成")
return True
else:
# 直接发送小消息
send_status = await self.maibot_router.send_message(message_base)
if not send_status:
raise RuntimeError("可能是路由未正确配置或连接异常")
return send_status
except Exception as e:
logger.error(f"发送消息失败: {str(e)}")
logger.error("请检查与MaiBot之间的连接")
return False
message_send_instance = MessageSending()

View File

@@ -0,0 +1,59 @@
"""
更新Prompt类导入脚本
将旧的prompt_builder.Prompt导入更新为unified_prompt.Prompt
"""
import os
# 需要更新的文件列表
files_to_update = [
"src/person_info/relationship_fetcher.py",
"src/mood/mood_manager.py",
"src/mais4u/mais4u_chat/body_emotion_action_manager.py",
"src/chat/express/expression_learner.py",
"src/chat/planner_actions/planner.py",
"src/mais4u/mais4u_chat/s4u_prompt.py",
"src/chat/message_receive/bot.py",
"src/chat/replyer/default_generator.py",
"src/chat/express/expression_selector.py",
"src/mais4u/mai_think.py",
"src/mais4u/mais4u_chat/s4u_mood_manager.py",
"src/plugin_system/core/tool_use.py",
"src/chat/memory_system/memory_activator.py",
"src/chat/utils/smart_prompt.py"
]
def update_prompt_imports(file_path):
"""更新文件中的Prompt导入"""
if not os.path.exists(file_path):
print(f"文件不存在: {file_path}")
return False
with open(file_path, 'r', encoding='utf-8') as f:
content = f.read()
# 替换导入语句
old_import = "from src.chat.utils.prompt_builder import Prompt, global_prompt_manager"
new_import = "from src.chat.utils.prompt import Prompt, global_prompt_manager"
if old_import in content:
new_content = content.replace(old_import, new_import)
with open(file_path, 'w', encoding='utf-8') as f:
f.write(new_content)
print(f"已更新: {file_path}")
return True
else:
print(f"无需更新: {file_path}")
return False
def main():
"""主函数"""
updated_count = 0
for file_path in files_to_update:
if update_prompt_imports(file_path):
updated_count += 1
print(f"\n更新完成!共更新了 {updated_count} 个文件")
if __name__ == "__main__":
main()

View File

@@ -3,9 +3,8 @@ import time
import traceback
import math
import random
from typing import Optional, Dict, Any, Tuple
from typing import Dict, Any, Tuple
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.timer_calculator import Timer
from src.common.logger import get_logger
from src.config.config import global_config
@@ -20,10 +19,14 @@ from .hfc_context import HfcContext
from .response_handler import ResponseHandler
from .cycle_tracker import CycleTracker
# 日志记录器
logger = get_logger("hfc.processor")
class CycleProcessor:
"""
循环处理器类,负责处理单次思考循环的逻辑。
"""
def __init__(self, context: HfcContext, response_handler: ResponseHandler, cycle_tracker: CycleTracker):
"""
初始化循环处理器
@@ -52,6 +55,21 @@ class CycleProcessor:
thinking_id,
actions,
) -> Tuple[Dict[str, Any], str, Dict[str, float]]:
"""
发送并存储回复信息
Args:
response_set: 回复内容集合
loop_start_time: 循环开始时间
action_message: 动作消息
cycle_timers: 循环计时器
thinking_id: 思考ID
actions: 动作列表
Returns:
Tuple[Dict[str, Any], str, Dict[str, float]]: 循环信息, 回复文本, 循环计时器
"""
# 发送回复
with Timer("回复发送", cycle_timers):
reply_text = await self.response_handler.send_response(response_set, loop_start_time, action_message)
@@ -63,6 +81,7 @@ class CycleProcessor:
if platform is None:
platform = getattr(self.context.chat_stream, "platform", "unknown")
# 获取用户信息并生成回复提示
person_id = person_info_manager.get_person_id(
platform,
action_message.get("user_id", ""),
@@ -70,6 +89,7 @@ class CycleProcessor:
person_name = await person_info_manager.get_value(person_id, "person_name")
action_prompt_display = f"你对{person_name}进行了回复:{reply_text}"
# 存储动作信息到数据库
await database_api.store_action_info(
chat_stream=self.context.chat_stream,
action_build_into_prompt=False,
@@ -95,7 +115,7 @@ class CycleProcessor:
return loop_info, reply_text, cycle_timers
async def observe(self,interest_value:float = 0.0) -> bool:
async def observe(self, interest_value: float = 0.0) -> str:
"""
观察和处理单次思考循环的核心方法
@@ -103,7 +123,7 @@ class CycleProcessor:
interest_value: 兴趣值
Returns:
bool: 处理是否成功
str: 动作类型
功能说明:
- 开始新的思考循环并记录计时
@@ -119,6 +139,15 @@ class CycleProcessor:
# 当interest_value为0时概率接近0使用Focus模式
# 当interest_value很高时概率接近1使用Normal模式
def calculate_normal_mode_probability(interest_val: float) -> float:
"""
计算普通模式的概率
Args:
interest_val: 兴趣值
Returns:
float: 概率
"""
# 使用sigmoid函数调整参数使概率分布更合理
# 当interest_value = 0时概率约为0.1
# 当interest_value = 1时概率约为0.5
@@ -128,21 +157,31 @@ class CycleProcessor:
x0 = 1.0 # 控制曲线中心点
return 1.0 / (1.0 + math.exp(-k * (interest_val - x0)))
normal_mode_probability = calculate_normal_mode_probability(interest_value) * 0.5 / global_config.chat.get_current_talk_frequency(self.context.stream_id)
# 计算普通模式概率
normal_mode_probability = (
calculate_normal_mode_probability(interest_value)
* 0.5
/ global_config.chat.get_current_talk_frequency(self.context.stream_id)
)
# 根据概率决定使用哪种模式
if random.random() < normal_mode_probability:
mode = ChatMode.NORMAL
logger.info(f"{self.log_prefix} 基于兴趣值 {interest_value:.2f},概率 {normal_mode_probability:.2f}选择Normal planner模式")
logger.info(
f"{self.log_prefix} 基于兴趣值 {interest_value:.2f},概率 {normal_mode_probability:.2f}选择Normal planner模式"
)
else:
mode = ChatMode.FOCUS
logger.info(f"{self.log_prefix} 基于兴趣值 {interest_value:.2f},概率 {normal_mode_probability:.2f}选择Focus planner模式")
logger.info(
f"{self.log_prefix} 基于兴趣值 {interest_value:.2f},概率 {normal_mode_probability:.2f}选择Focus planner模式"
)
# 开始新的思考循环
cycle_timers, thinking_id = self.cycle_tracker.start_cycle()
logger.info(f"{self.log_prefix} 开始第{self.context.cycle_counter}次思考")
if ENABLE_S4U:
await send_typing()
await send_typing(self.context.chat_stream.user_info.user_id)
loop_start_time = time.time()
@@ -165,10 +204,14 @@ class CycleProcessor:
from src.plugin_system.core.event_manager import event_manager
from src.plugin_system import EventType
result = await event_manager.trigger_event(EventType.ON_PLAN,plugin_name="SYSTEM", stream_id=self.context.chat_stream)
# 触发规划前事件
result = await event_manager.trigger_event(
EventType.ON_PLAN, plugin_name="SYSTEM", stream_id=self.context.chat_stream
)
if not result.all_continue_process():
raise UserWarning(f"插件{result.get_summary().get('stopped_handlers', '')}于规划前中断了内容生成")
# 规划动作
with Timer("规划器", cycle_timers):
actions, _ = await self.action_planner.plan(
mode=mode,
@@ -179,6 +222,8 @@ class CycleProcessor:
async def execute_action(action_info):
"""执行单个动作的通用函数"""
try:
if action_info["action_type"] == "no_action":
return {"action_type": "no_action", "success": True, "reply_text": "", "command": ""}
if action_info["action_type"] == "no_reply":
# 直接处理no_reply逻辑不再通过动作系统
reason = action_info.get("reasoning", "选择不回复")
@@ -195,13 +240,8 @@ class CycleProcessor:
action_name="no_reply",
)
return {
"action_type": "no_reply",
"success": True,
"reply_text": "",
"command": ""
}
elif action_info["action_type"] != "reply":
return {"action_type": "no_reply", "success": True, "reply_text": "", "command": ""}
elif action_info["action_type"] != "reply" and action_info["action_type"] != "no_action":
# 执行普通动作
with Timer("动作执行", cycle_timers):
success, reply_text, command = await self._handle_action(
@@ -210,15 +250,16 @@ class CycleProcessor:
action_info["action_data"],
cycle_timers,
thinking_id,
action_info["action_message"]
action_info["action_message"],
)
return {
"action_type": action_info["action_type"],
"success": success,
"reply_text": reply_text,
"command": command
"command": command,
}
else:
# 生成回复
try:
success, response_set, _ = await generator_api.generate_reply(
chat_stream=self.context.chat_stream,
@@ -229,22 +270,15 @@ class CycleProcessor:
from_plugin=False,
)
if not success or not response_set:
logger.info(f"{action_info['action_message'].get('processed_plain_text')} 的回复生成失败")
return {
"action_type": "reply",
"success": False,
"reply_text": "",
"loop_info": None
}
logger.info(
f"{action_info['action_message'].get('processed_plain_text')} 的回复生成失败"
)
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
except asyncio.CancelledError:
logger.debug(f"{self.log_prefix} 并行执行:回复生成任务已被取消")
return {
"action_type": "reply",
"success": False,
"reply_text": "",
"loop_info": None
}
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
# 发送并存储回复
loop_info, reply_text, cycle_timers_reply = await self._send_and_store_reply(
response_set,
loop_start_time,
@@ -253,12 +287,7 @@ class CycleProcessor:
thinking_id,
actions,
)
return {
"action_type": "reply",
"success": True,
"reply_text": reply_text,
"loop_info": loop_info
}
return {"action_type": "reply", "success": True, "reply_text": reply_text, "loop_info": loop_info}
except Exception as e:
logger.error(f"{self.log_prefix} 执行动作时出错: {e}")
logger.error(f"{self.log_prefix} 错误信息: {traceback.format_exc()}")
@@ -267,7 +296,7 @@ class CycleProcessor:
"success": False,
"reply_text": "",
"loop_info": None,
"error": str(e)
"error": str(e),
}
# 创建所有动作的后台任务
@@ -328,233 +357,16 @@ class CycleProcessor:
}
reply_text = action_reply_text
# 停止正在输入状态
if ENABLE_S4U:
await stop_typing()
# 结束循环
self.context.chat_instance.cycle_tracker.end_cycle(loop_info, cycle_timers)
self.context.chat_instance.cycle_tracker.print_cycle_info(cycle_timers)
action_type = actions[0]["action_type"] if actions else "no_action"
# 管理no_reply计数器当执行了非no_reply动作时重置计数器
if action_type != "no_reply":
# no_reply逻辑已集成到heartFC_chat.py中直接重置计数器
self.context.chat_instance.recent_interest_records.clear()
self.context.no_reply_consecutive = 0
logger.debug(f"{self.log_prefix} 执行了{action_type}动作重置no_reply计数器")
return True
if action_type == "no_reply":
self.context.no_reply_consecutive += 1
self.context.chat_instance._determine_form_type()
# 在一轮动作执行完毕后,增加睡眠压力
if self.context.energy_manager and global_config.sleep_system.enable_insomnia_system:
if action_type not in ["no_reply", "no_action"]:
self.context.energy_manager.increase_sleep_pressure()
return True
async def execute_plan(self, action_result: Dict[str, Any], target_message: Optional[Dict[str, Any]]):
"""
执行一个已经制定好的计划
"""
action_type = action_result.get("action_type", "error")
# 这里我们需要为执行计划创建一个新的循环追踪
cycle_timers, thinking_id = self.cycle_tracker.start_cycle(is_proactive=True)
loop_start_time = time.time()
if action_type == "reply":
# 主动思考不应该直接触发简单回复但为了逻辑完整性我们假设它会调用response_handler
# 注意:这里的 available_actions 和 plan_result 是缺失的,需要根据实际情况处理
await self._handle_reply_action(
target_message, {}, None, loop_start_time, cycle_timers, thinking_id, {"action_result": action_result}
)
else:
await self._handle_other_actions(
action_type,
action_result.get("reasoning", ""),
action_result.get("action_data", {}),
action_result.get("is_parallel", False),
None,
target_message,
cycle_timers,
thinking_id,
{"action_result": action_result},
loop_start_time,
)
async def _handle_reply_action(
self, message_data, available_actions, gen_task, loop_start_time, cycle_timers, thinking_id, plan_result
):
"""
处理回复类型的动作
Args:
message_data: 消息数据
available_actions: 可用动作列表
gen_task: 预先创建的生成任务可能为None
loop_start_time: 循环开始时间
cycle_timers: 循环计时器
thinking_id: 思考ID
plan_result: 规划结果
功能说明:
- 根据聊天模式决定是否使用预生成的回复或实时生成
- 在NORMAL模式下使用异步生成提高效率
- 在FOCUS模式下同步生成确保及时响应
- 发送生成的回复并结束循环
"""
# 初始化reply_to_str以避免UnboundLocalError
reply_to_str = None
if self.context.loop_mode == ChatMode.NORMAL:
if not gen_task:
reply_to_str = await self._build_reply_to_str(message_data)
gen_task = asyncio.create_task(
self.response_handler.generate_response(
message_data=message_data,
available_actions=available_actions,
reply_to=reply_to_str,
request_type="chat.replyer.normal",
)
)
else:
# 如果gen_task已存在但reply_to_str还未构建需要构建它
if reply_to_str is None:
reply_to_str = await self._build_reply_to_str(message_data)
try:
response_set = await asyncio.wait_for(gen_task, timeout=global_config.chat.thinking_timeout)
except asyncio.TimeoutError:
response_set = None
else:
reply_to_str = await self._build_reply_to_str(message_data)
response_set = await self.response_handler.generate_response(
message_data=message_data,
available_actions=available_actions,
reply_to=reply_to_str,
request_type="chat.replyer.focus",
)
if response_set:
loop_info, _, _ = await self.response_handler.generate_and_send_reply(
response_set, reply_to_str, loop_start_time, message_data, cycle_timers, thinking_id, plan_result
)
self.cycle_tracker.end_cycle(loop_info, cycle_timers)
async def _handle_other_actions(
self,
action_type,
reasoning,
action_data,
is_parallel,
gen_task,
action_message,
cycle_timers,
thinking_id,
plan_result,
loop_start_time,
):
"""
处理非回复类型的动作如no_reply、自定义动作等
Args:
action_type: 动作类型
reasoning: 动作理由
action_data: 动作数据
is_parallel: 是否并行执行
gen_task: 生成任务
action_message: 动作消息
cycle_timers: 循环计时器
thinking_id: 思考ID
plan_result: 规划结果
loop_start_time: 循环开始时间
功能说明:
- 在NORMAL模式下可能并行执行回复生成和动作处理
- 等待所有异步任务完成
- 整合回复和动作的执行结果
- 构建最终循环信息并结束循环
"""
background_reply_task = None
if self.context.loop_mode == ChatMode.NORMAL and is_parallel and gen_task:
background_reply_task = asyncio.create_task(
self._handle_parallel_reply(
gen_task, loop_start_time, action_message, cycle_timers, thinking_id, plan_result
)
)
background_action_task = asyncio.create_task(
self._handle_action(action_type, reasoning, action_data, cycle_timers, thinking_id, action_message)
)
reply_loop_info, action_success, action_reply_text, action_command = None, False, "", ""
if background_reply_task:
results = await asyncio.gather(background_reply_task, background_action_task, return_exceptions=True)
reply_result, action_result_val = results
if not isinstance(reply_result, BaseException) and reply_result is not None:
reply_loop_info, _, _ = reply_result
else:
reply_loop_info = None
if not isinstance(action_result_val, BaseException) and action_result_val is not None:
action_success, action_reply_text, action_command = action_result_val
else:
action_success, action_reply_text, action_command = False, "", ""
else:
results = await asyncio.gather(background_action_task, return_exceptions=True)
if results and len(results) > 0:
action_result_val = results[0] # Get the actual result from the tuple
else:
action_result_val = (False, "", "")
if not isinstance(action_result_val, BaseException) and action_result_val is not None:
action_success, action_reply_text, action_command = action_result_val
else:
action_success, action_reply_text, action_command = False, "", ""
loop_info = self._build_final_loop_info(
reply_loop_info, action_success, action_reply_text, action_command, plan_result
)
self.cycle_tracker.end_cycle(loop_info, cycle_timers)
async def _handle_parallel_reply(
self, gen_task, loop_start_time, action_message, cycle_timers, thinking_id, plan_result
):
"""
处理并行回复生成
Args:
gen_task: 回复生成任务
loop_start_time: 循环开始时间
action_message: 动作消息
cycle_timers: 循环计时器
thinking_id: 思考ID
plan_result: 规划结果
Returns:
tuple: (循环信息, 回复文本, 计时器信息) 或 None
功能说明:
- 等待并行回复生成任务完成(带超时)
- 构建回复目标字符串
- 发送生成的回复
- 返回循环信息供上级方法使用
"""
try:
response_set = await asyncio.wait_for(gen_task, timeout=global_config.chat.thinking_timeout)
except asyncio.TimeoutError:
return None, "", {}
if not response_set:
return None, "", {}
reply_to_str = await self._build_reply_to_str(action_message)
return await self.response_handler.generate_and_send_reply(
response_set, reply_to_str, loop_start_time, action_message, cycle_timers, thinking_id, plan_result
)
return action_type
async def _handle_action(
self, action, reasoning, action_data, cycle_timers, thinking_id, action_message
@@ -581,6 +393,7 @@ class CycleProcessor:
if not self.context.chat_stream:
return False, "", ""
try:
# 创建动作处理器
action_handler = self.context.action_manager.create_action(
action_name=action,
action_data=action_data,
@@ -608,7 +421,7 @@ class CycleProcessor:
if fallback_action and fallback_action != action:
logger.info(f"{self.context.log_prefix} 使用回退动作: {fallback_action}")
action_handler = self.context.action_manager.create_action(
action_name=fallback_action,
action_name=fallback_action if isinstance(fallback_action, list) else fallback_action,
action_data=action_data,
reasoning=f"原动作'{action}'不可用,自动回退。{reasoning}",
cycle_timers=cycle_timers,
@@ -622,74 +435,10 @@ class CycleProcessor:
logger.error(f"{self.context.log_prefix} 回退方案也失败,无法创建任何动作处理器")
return False, "", ""
# 执行动作
success, reply_text = await action_handler.handle_action()
return success, reply_text, ""
except Exception as e:
logger.error(f"{self.context.log_prefix} 处理{action}时出错: {e}")
traceback.print_exc()
return False, "", ""
def _get_direct_reply_plan(self, loop_start_time):
"""
获取直接回复的规划结果
Args:
loop_start_time: 循环开始时间
Returns:
dict: 包含直接回复动作的规划结果
功能说明:
- 在某些情况下跳过复杂规划,直接返回回复动作
- 主要用于NORMAL模式下没有其他可用动作时的简化处理
"""
return {
"action_result": {
"action_type": "reply",
"action_data": {"loop_start_time": loop_start_time},
"reasoning": "",
"timestamp": time.time(),
"is_parallel": False,
},
"action_prompt": "",
}
def _build_final_loop_info(self, reply_loop_info, action_success, action_reply_text, action_command, plan_result):
"""
构建最终的循环信息
Args:
reply_loop_info: 回复循环信息可能为None
action_success: 动作执行是否成功
action_reply_text: 动作回复文本
action_command: 动作命令
plan_result: 规划结果
Returns:
dict: 完整的循环信息,包含规划信息和动作信息
功能说明:
- 如果有回复循环信息,则在其基础上添加动作信息
- 如果没有回复信息,则创建新的循环信息结构
- 整合所有执行结果供循环跟踪器记录
"""
if reply_loop_info:
loop_info = reply_loop_info
loop_info["loop_action_info"].update(
{
"action_taken": action_success,
"command": action_command,
"taken_time": time.time(),
}
)
else:
loop_info = {
"loop_plan_info": {"action_result": plan_result.get("action_result", {})},
"loop_action_info": {
"action_taken": action_success,
"reply_text": action_reply_text,
"command": action_command,
"taken_time": time.time(),
},
}
return loop_info

View File

@@ -91,25 +91,24 @@ class CycleTracker:
# 获取动作类型,兼容新旧格式
action_type = "未知动作"
if hasattr(self, '_current_cycle_detail') and self._current_cycle_detail:
if hasattr(self, "_current_cycle_detail") and self._current_cycle_detail:
loop_plan_info = self._current_cycle_detail.loop_plan_info
if isinstance(loop_plan_info, dict):
action_result = loop_plan_info.get('action_result', {})
action_result = loop_plan_info.get("action_result", {})
if isinstance(action_result, dict):
# 旧格式action_result是字典
action_type = action_result.get('action_type', '未知动作')
action_type = action_result.get("action_type", "未知动作")
elif isinstance(action_result, list) and action_result:
# 新格式action_result是actions列表
action_type = action_result[0].get('action_type', '未知动作')
action_type = action_result[0].get("action_type", "未知动作")
elif isinstance(loop_plan_info, list) and loop_plan_info:
# 直接是actions列表的情况
action_type = loop_plan_info[0].get('action_type', '未知动作')
action_type = loop_plan_info[0].get("action_type", "未知动作")
if self.context.current_cycle_detail.end_time and self.context.current_cycle_detail.start_time:
duration = self.context.current_cycle_detail.end_time - self.context.current_cycle_detail.start_time
logger.info(
f"{self.context.log_prefix}{self.context.current_cycle_detail.cycle_id}次思考,"
f"耗时: {duration:.1f}秒, "
f"选择动作: {action_type}"
+ (f"\n详情: {'; '.join(timer_strings)}" if timer_strings else "")
f"选择动作: {action_type}" + (f"\n详情: {'; '.join(timer_strings)}" if timer_strings else "")
)

View File

@@ -3,10 +3,8 @@ import time
from typing import Optional
from src.common.logger import get_logger
from src.config.config import global_config
from src.plugin_system.base.component_types import ChatMode
from .hfc_context import HfcContext
from src.schedule.schedule_manager import schedule_manager
from src.chat.chat_loop.sleep_manager import sleep_manager
logger = get_logger("hfc")
@@ -75,7 +73,7 @@ class EnergyManager:
continue
# 判断当前是否为睡眠时间
is_sleeping = schedule_manager.is_sleeping()
is_sleeping = sleep_manager.SleepManager().is_sleeping()
if is_sleeping:
# 睡眠中:减少睡眠压力

View File

@@ -2,24 +2,24 @@ import asyncio
import time
import traceback
import random
from typing import Optional, List, Dict, Any, Tuple
from typing import Optional, List, Dict, Any
from collections import deque
from src.common.logger import get_logger
from src.config.config import global_config
from src.person_info.relationship_builder_manager import relationship_builder_manager
from src.chat.express.expression_learner import expression_learner_manager
from src.plugin_system.base.component_types import ChatMode
from src.schedule.schedule_manager import schedule_manager, SleepState
from src.chat.chat_loop.sleep_manager.sleep_manager import SleepManager, SleepState
from src.plugin_system.apis import message_api
from .hfc_context import HfcContext
from .energy_manager import EnergyManager
from .proactive_thinker import ProactiveThinker
from .proactive.proactive_thinker import ProactiveThinker
from .cycle_processor import CycleProcessor
from .response_handler import ResponseHandler
from .cycle_tracker import CycleTracker
from .wakeup_manager import WakeUpManager
from .sleep_manager.wakeup_manager import WakeUpManager
from .proactive.events import ProactiveTriggerEvent
logger = get_logger("hfc")
@@ -46,14 +46,17 @@ class HeartFChatting:
self.energy_manager = EnergyManager(self.context)
self.proactive_thinker = ProactiveThinker(self.context, self.cycle_processor)
self.wakeup_manager = WakeUpManager(self.context)
self.sleep_manager = SleepManager()
# 将唤醒度管理器设置到上下文中
self.context.wakeup_manager = self.wakeup_manager
self.context.energy_manager = self.energy_manager
self.context.sleep_manager = self.sleep_manager
# 将HeartFChatting实例设置到上下文中以便其他组件可以调用其方法
self.context.chat_instance = self
self._loop_task: Optional[asyncio.Task] = None
self._proactive_monitor_task: Optional[asyncio.Task] = None
# 记录最近3次的兴趣度
self.recent_interest_records: deque = deque(maxlen=3)
@@ -93,8 +96,12 @@ class HeartFChatting:
self.context.relationship_builder = relationship_builder_manager.get_or_create_builder(self.context.stream_id)
self.context.expression_learner = expression_learner_manager.get_expression_learner(self.context.stream_id)
#await self.energy_manager.start()
await self.proactive_thinker.start()
# 启动主动思考监视器
if global_config.chat.enable_proactive_thinking:
self._proactive_monitor_task = asyncio.create_task(self._proactive_monitor_loop())
self._proactive_monitor_task.add_done_callback(self._handle_proactive_monitor_completion)
logger.info(f"{self.context.log_prefix} 主动思考监视器已启动")
await self.wakeup_manager.start()
self._loop_task = asyncio.create_task(self._main_chat_loop())
@@ -116,8 +123,12 @@ class HeartFChatting:
return
self.context.running = False
#await self.energy_manager.stop()
await self.proactive_thinker.stop()
# 停止主动思考监视器
if self._proactive_monitor_task and not self._proactive_monitor_task.done():
self._proactive_monitor_task.cancel()
await asyncio.sleep(0)
logger.info(f"{self.context.log_prefix} 主动思考监视器已停止")
await self.wakeup_manager.stop()
if self._loop_task and not self._loop_task.done():
@@ -147,6 +158,151 @@ class HeartFChatting:
except asyncio.CancelledError:
logger.info(f"{self.context.log_prefix} HeartFChatting: 结束了聊天")
def _handle_proactive_monitor_completion(self, task: asyncio.Task):
"""
处理主动思考监视器任务完成
Args:
task: 完成的异步任务对象
功能说明:
- 处理任务异常完成的情况
- 记录任务正常结束或被取消的日志
"""
try:
if exception := task.exception():
logger.error(f"{self.context.log_prefix} 主动思考监视器异常: {exception}")
else:
logger.info(f"{self.context.log_prefix} 主动思考监视器正常结束")
except asyncio.CancelledError:
logger.info(f"{self.context.log_prefix} 主动思考监视器被取消")
async def _proactive_monitor_loop(self):
"""
主动思考监视器循环
功能说明:
- 定期检查是否需要进行主动思考
- 计算聊天沉默时间,并与动态思考间隔比较
- 当沉默时间超过阈值时,触发主动思考
- 处理思考过程中的异常
"""
while self.context.running:
await asyncio.sleep(15)
if not self._should_enable_proactive_thinking():
continue
current_time = time.time()
silence_duration = current_time - self.context.last_message_time
target_interval = self._get_dynamic_thinking_interval()
if silence_duration >= target_interval:
try:
formatted_time = self._format_duration(silence_duration)
event = ProactiveTriggerEvent(
source="silence_monitor",
reason=f"聊天已沉默 {formatted_time}",
metadata={"silence_duration": silence_duration},
)
await self.proactive_thinker.think(event)
self.context.last_message_time = current_time
except Exception as e:
logger.error(f"{self.context.log_prefix} 主动思考触发执行出错: {e}")
logger.error(traceback.format_exc())
def _should_enable_proactive_thinking(self) -> bool:
"""
判断是否应启用主动思考
Returns:
bool: 如果应启用主动思考则返回True否则返回False
功能说明:
- 检查全局配置和特定聊天设置
- 支持按群聊和私聊分别配置
- 支持白名单模式,只在特定聊天中启用
"""
if not self.context.chat_stream:
return False
is_group_chat = self.context.chat_stream.group_info is not None
if is_group_chat and not global_config.chat.proactive_thinking_in_group:
return False
if not is_group_chat and not global_config.chat.proactive_thinking_in_private:
return False
stream_parts = self.context.stream_id.split(":")
current_chat_identifier = f"{stream_parts}:{stream_parts}" if len(stream_parts) >= 2 else self.context.stream_id
enable_list = getattr(
global_config.chat,
"proactive_thinking_enable_in_groups" if is_group_chat else "proactive_thinking_enable_in_private",
[],
)
return not enable_list or current_chat_identifier in enable_list
def _get_dynamic_thinking_interval(self) -> float:
"""
获取动态思考间隔时间
Returns:
float: 思考间隔秒数
功能说明:
- 尝试从timing_utils导入正态分布间隔函数
- 根据配置计算动态间隔,增加随机性
- 在无法导入或计算出错时,回退到固定的间隔
"""
try:
from src.utils.timing_utils import get_normal_distributed_interval
base_interval = global_config.chat.proactive_thinking_interval
delta_sigma = getattr(global_config.chat, "delta_sigma", 120)
if base_interval <= 0:
base_interval = abs(base_interval)
if delta_sigma < 0:
delta_sigma = abs(delta_sigma)
if base_interval == 0 and delta_sigma == 0:
return 300
if delta_sigma == 0:
return base_interval
sigma_percentage = delta_sigma / base_interval if base_interval > 0 else delta_sigma / 1000
return get_normal_distributed_interval(base_interval, sigma_percentage, 1, 86400, use_3sigma_rule=True)
except ImportError:
logger.warning(f"{self.context.log_prefix} timing_utils不可用使用固定间隔")
return max(300, abs(global_config.chat.proactive_thinking_interval))
except Exception as e:
logger.error(f"{self.context.log_prefix} 动态间隔计算出错: {e},使用固定间隔")
return max(300, abs(global_config.chat.proactive_thinking_interval))
def _format_duration(self, seconds: float) -> str:
"""
格式化时长为可读字符串
Args:
seconds: 时长秒数
Returns:
str: 格式化后的字符串 (例如 "1小时2分3秒")
"""
hours = int(seconds // 3600)
minutes = int((seconds % 3600) // 60)
secs = int(seconds % 60)
parts = []
if hours > 0:
parts.append(f"{hours}小时")
if minutes > 0:
parts.append(f"{minutes}")
if secs > 0 or not parts:
parts.append(f"{secs}")
return "".join(parts)
async def _main_chat_loop(self):
"""
主聊天循环
@@ -197,8 +353,8 @@ class HeartFChatting:
- NORMAL模式检查进入FOCUS模式的条件并通过normal_mode_handler处理消息
"""
# --- 核心状态更新 ---
await schedule_manager.update_sleep_state(self.wakeup_manager)
current_sleep_state = schedule_manager.get_current_sleep_state()
await self.sleep_manager.update_sleep_state(self.wakeup_manager)
current_sleep_state = self.sleep_manager.get_current_sleep_state()
is_sleeping = current_sleep_state == SleepState.SLEEPING
is_in_insomnia = current_sleep_state == SleepState.INSOMNIA
@@ -228,7 +384,7 @@ class HeartFChatting:
self._handle_wakeup_messages(recent_messages)
# 再次获取最新状态,因为 handle_wakeup 可能导致状态变为 WOKEN_UP
current_sleep_state = schedule_manager.get_current_sleep_state()
current_sleep_state = self.sleep_manager.get_current_sleep_state()
if current_sleep_state == SleepState.SLEEPING:
# 只有在纯粹的 SLEEPING 状态下才跳过消息处理
@@ -238,113 +394,56 @@ class HeartFChatting:
logger.info(f"{self.context.log_prefix} 从睡眠中被唤醒,将处理积压的消息。")
# 根据聊天模式处理新消息
# 统一使用 _should_process_messages 判断是否应该处理
should_process,interest_value = await self._should_process_messages(recent_messages if has_new_messages else None)
if should_process:
self.context.last_read_time = time.time()
await self.cycle_processor.observe(interest_value = interest_value)
else:
# Normal模式消息数量不足等待
should_process, interest_value = await self._should_process_messages(recent_messages)
if not should_process:
# 消息数量不足或兴趣不够,等待
await asyncio.sleep(0.5)
return True
return True # Skip rest of the logic for this iteration
if not await self._should_process_messages(recent_messages if has_new_messages else None):
return has_new_messages
# Messages should be processed
action_type = await self.cycle_processor.observe(interest_value=interest_value)
# 处理新消息
for message in recent_messages:
await self.cycle_processor.observe(interest_value = interest_value)
# 管理no_reply计数器
if action_type != "no_reply":
self.recent_interest_records.clear()
self.context.no_reply_consecutive = 0
logger.debug(f"{self.context.log_prefix} 执行了{action_type}动作重置no_reply计数器")
else: # action_type == "no_reply"
self.context.no_reply_consecutive += 1
self._determine_form_type()
# 在一轮动作执行完毕后,增加睡眠压力
if self.context.energy_manager and global_config.sleep_system.enable_insomnia_system:
if action_type not in ["no_reply", "no_action"]:
self.context.energy_manager.increase_sleep_pressure()
# 如果成功观察,增加能量值并重置累积兴趣值
if has_new_messages:
self.context.energy_value += 1 / global_config.chat.focus_value
# 重置累积兴趣值,因为消息已经被成功处理
self.context.breaking_accumulated_interest = 0.0
logger.info(f"{self.context.log_prefix} 能量值增加,当前能量值:{self.context.energy_value:.1f},重置累积兴趣值")
self._check_focus_exit()
else:
# 无新消息时,只进行模式检查,不进行思考循环
self._check_focus_exit()
logger.info(
f"{self.context.log_prefix} 能量值增加,当前能量值:{self.context.energy_value:.1f},重置累积兴趣值"
)
# 更新上一帧的睡眠状态
self.context.was_sleeping = is_sleeping
# --- 重新入睡逻辑 ---
# 如果被吵醒了,并且在一定时间内没有新消息,则尝试重新入睡
if schedule_manager.get_current_sleep_state() == SleepState.WOKEN_UP and not has_new_messages:
if self.sleep_manager.get_current_sleep_state() == SleepState.WOKEN_UP and not has_new_messages:
re_sleep_delay = global_config.sleep_system.re_sleep_delay_minutes * 60
# 使用 last_message_time 来判断空闲时间
if time.time() - self.context.last_message_time > re_sleep_delay:
logger.info(
f"{self.context.log_prefix} 已被唤醒且超过 {re_sleep_delay / 60} 分钟无新消息,尝试重新入睡。"
)
schedule_manager.reset_sleep_state_after_wakeup()
self.sleep_manager.reset_sleep_state_after_wakeup()
# 保存HFC上下文状态
self.context.save_context_state()
return has_new_messages
def _check_focus_exit(self):
"""
检查是否应该退出FOCUS模式
功能说明:
- 区分私聊和群聊环境
- 在强制私聊focus模式下能量值低于1时重置为5但不退出
- 在群聊focus模式下如果配置为focus则不退出
- 其他情况下能量值低于1时退出到NORMAL模式
"""
is_private_chat = self.context.chat_stream.group_info is None if self.context.chat_stream else False
is_group_chat = not is_private_chat
if global_config.chat.force_focus_private and is_private_chat:
if self.context.energy_value <= 1:
self.context.energy_value = 5
return
if is_group_chat and global_config.chat.group_chat_mode == "focus":
return
if self.context.energy_value <= 1: # 如果能量值小于等于1非强制情况
self.context.energy_value = 1 # 将能量值设置为1
def _check_focus_entry(self, new_message_count: int):
"""
检查是否应该进入FOCUS模式
Args:
new_message_count: 新消息数量
功能说明:
- 区分私聊和群聊环境
- 强制私聊focus模式直接进入FOCUS模式并设置能量值为10
- 群聊normal模式不进入FOCUS模式
- 根据focus_value配置和消息数量决定是否进入FOCUS模式
- 当消息数量超过阈值或能量值达到30时进入FOCUS模式
"""
is_private_chat = self.context.chat_stream.group_info is None if self.context.chat_stream else False
is_group_chat = not is_private_chat
if global_config.chat.force_focus_private and is_private_chat:
self.context.energy_value = 10
return
if is_group_chat and global_config.chat.group_chat_mode == "normal":
return
if global_config.chat.focus_value != 0: # 如果专注值配置不为0启用自动专注
if new_message_count > 3 / pow(
global_config.chat.focus_value, 0.5
): # 如果新消息数超过阈值(基于专注值计算)
self.context.energy_value = (
10 + (new_message_count / (3 / pow(global_config.chat.focus_value, 0.5))) * 10
) # 根据消息数量计算能量值
return # 返回,不再检查其他条件
def _handle_wakeup_messages(self, messages):
"""
处理休眠状态下的消息,累积唤醒度
@@ -382,9 +481,16 @@ class HeartFChatting:
def _determine_form_type(self) -> str:
"""判断使用哪种形式的no_reply"""
# 检查是否启用breaking模式
if not getattr(global_config.chat, "enable_breaking_mode", False):
logger.info(f"{self.context.log_prefix} breaking模式已禁用使用waiting形式")
self.context.focus_energy = 1
return "waiting"
# 如果连续no_reply次数少于3次使用waiting形式
if self.context.no_reply_consecutive <= 3:
self.context.focus_energy = 1
return "waiting"
else:
# 使用累积兴趣值而不是最近3次的记录
total_interest = self.context.breaking_accumulated_interest
@@ -392,24 +498,31 @@ class HeartFChatting:
# 计算调整后的阈值
adjusted_threshold = 1 / global_config.chat.get_current_talk_frequency(self.context.stream_id)
logger.info(f"{self.context.log_prefix} 累积兴趣值: {total_interest:.2f}, 调整后阈值: {adjusted_threshold:.2f}")
logger.info(
f"{self.context.log_prefix} 累积兴趣值: {total_interest:.2f}, 调整后阈值: {adjusted_threshold:.2f}"
)
# 如果累积兴趣值小于阈值进入breaking形式
if total_interest < adjusted_threshold:
logger.info(f"{self.context.log_prefix} 累积兴趣度不足进入breaking形式")
self.context.focus_energy = random.randint(3, 6)
return "breaking"
else:
logger.info(f"{self.context.log_prefix} 累积兴趣度充足使用waiting形式")
self.context.focus_energy = 1
return "waiting"
async def _should_process_messages(self, new_message: List[Dict[str, Any]]) -> tuple[bool, float]:
"""
统一判断是否应该处理消息的函数
根据当前循环模式和消息内容决定是否继续处理
"""
if not new_message:
return False, 0.0
new_message_count = len(new_message)
talk_frequency = global_config.chat.get_current_talk_frequency(self.context.chat_stream.stream_id)
talk_frequency = global_config.chat.get_current_talk_frequency(self.context.stream_id)
modified_exit_count_threshold = self.context.focus_energy * 0.5 / talk_frequency
modified_exit_interest_threshold = 1.5 / talk_frequency
@@ -443,7 +556,9 @@ class HeartFChatting:
if new_message_count > 0:
# 只在兴趣值变化时输出log
if not hasattr(self, "_last_accumulated_interest") or total_interest != self._last_accumulated_interest:
logger.info(f"{self.context.log_prefix} breaking形式当前累积兴趣值: {total_interest:.2f}, 专注度: {global_config.chat.focus_value:.1f}")
logger.info(
f"{self.context.log_prefix} breaking形式当前累积兴趣值: {total_interest:.2f}, 专注度: {global_config.chat.focus_value:.1f}"
)
self._last_accumulated_interest = total_interest
if total_interest >= modified_exit_interest_threshold:
# 记录兴趣度到列表
@@ -456,64 +571,13 @@ class HeartFChatting:
return True, total_interest / new_message_count
# 每10秒输出一次等待状态
if int(time.time() - self.context.last_read_time) > 0 and int(time.time() - self.context.last_read_time) % 10 == 0:
if (
int(time.time() - self.context.last_read_time) > 0
and int(time.time() - self.context.last_read_time) % 10 == 0
):
logger.info(
f"{self.context.log_prefix} 已等待{time.time() - self.last_read_time:.0f}秒,累计{new_message_count}条消息,累积兴趣{total_interest:.1f},继续等待..."
f"{self.context.log_prefix} 已等待{time.time() - self.context.last_read_time:.0f}秒,累计{new_message_count}条消息,累积兴趣{total_interest:.1f},继续等待..."
)
await asyncio.sleep(0.5)
return False, 0.0
async def _execute_no_reply(self, new_message: List[Dict[str, Any]]) -> bool:
"""执行breaking形式的no_reply原有逻辑"""
new_message_count = len(new_message)
# 检查消息数量是否达到阈值
talk_frequency = global_config.chat.get_current_talk_frequency(self.context.stream_id)
modified_exit_count_threshold = self.context.focus_energy / talk_frequency
if new_message_count >= modified_exit_count_threshold:
# 记录兴趣度到列表
total_interest = 0.0
for msg_dict in new_message:
interest_value = msg_dict.get("interest_value", 0.0)
if msg_dict.get("processed_plain_text", ""):
total_interest += interest_value
self.recent_interest_records.append(total_interest)
logger.info(
f"{self.context.log_prefix} 累计消息数量达到{new_message_count}条(>{modified_exit_count_threshold}),结束等待"
)
return True
# 检查累计兴趣值
if new_message_count > 0:
accumulated_interest = 0.0
for msg_dict in new_message:
text = msg_dict.get("processed_plain_text", "")
interest_value = msg_dict.get("interest_value", 0.0)
if text:
accumulated_interest += interest_value
# 只在兴趣值变化时输出log
if not hasattr(self, "_last_accumulated_interest") or accumulated_interest != self._last_accumulated_interest:
logger.info(f"{self.context.log_prefix} breaking形式当前累计兴趣值: {accumulated_interest:.2f}, 当前聊天频率: {talk_frequency:.2f}")
self._last_accumulated_interest = accumulated_interest
if accumulated_interest >= 3 / talk_frequency:
# 记录兴趣度到列表
self.recent_interest_records.append(accumulated_interest)
logger.info(
f"{self.context.log_prefix} 累计兴趣值达到{accumulated_interest:.2f}(>{3 / talk_frequency}),结束等待"
)
return True
# 每10秒输出一次等待状态
if int(time.time() - self.context.last_read_time) > 0 and int(time.time() - self.context.last_read_time) % 10 == 0:
logger.info(
f"{self.context.log_prefix} 已等待{time.time() - self.context.last_read_time:.0f}秒,累计{new_message_count}条消息,继续等待..."
)
return False

View File

@@ -1,16 +1,16 @@
from typing import List, Optional, TYPE_CHECKING
import time
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.common.logger import get_logger
from src.person_info.relationship_builder_manager import RelationshipBuilder
from src.chat.express.expression_learner import ExpressionLearner
from src.plugin_system.base.component_types import ChatMode
from src.chat.planner_actions.action_manager import ActionManager
from src.chat.chat_loop.hfc_utils import CycleDetail
if TYPE_CHECKING:
from .wakeup_manager import WakeUpManager
from .sleep_manager.wakeup_manager import WakeUpManager
from .energy_manager import EnergyManager
from .heartFC_chat import HeartFChatting
from .sleep_manager.sleep_manager import SleepManager
class HfcContext:
@@ -49,7 +49,7 @@ class HfcContext:
self.last_read_time = time.time() - 10
# 从聊天流恢复breaking累积兴趣值
self.breaking_accumulated_interest = getattr(self.chat_stream, 'breaking_accumulated_interest', 0.0)
self.breaking_accumulated_interest = getattr(self.chat_stream, "breaking_accumulated_interest", 0.0)
self.action_manager = ActionManager()
@@ -62,6 +62,7 @@ class HfcContext:
# 唤醒度管理器 - 延迟初始化以避免循环导入
self.wakeup_manager: Optional["WakeUpManager"] = None
self.energy_manager: Optional["EnergyManager"] = None
self.sleep_manager: Optional["SleepManager"] = None
self.focus_energy = 1
self.no_reply_consecutive = 0
@@ -69,7 +70,7 @@ class HfcContext:
# breaking形式下的累积兴趣值
self.breaking_accumulated_interest = 0.0
# 引用HeartFChatting实例以便其他组件可以调用其方法
self.chat_instance = None
self.chat_instance: "HeartFChatting"
def save_context_state(self):
"""将当前状态保存到聊天流"""

View File

@@ -1,13 +1,11 @@
import time
from typing import Optional, Dict, Any, Union
from src.config.config import global_config
from src.common.logger import get_logger
from src.chat.message_receive.chat_stream import get_chat_manager
from src.plugin_system.apis import send_api
from maim_message.message_base import GroupInfo
from src.common.message_repository import count_messages
logger = get_logger("hfc")
@@ -123,43 +121,7 @@ class CycleDetail:
self.loop_action_info = loop_info["loop_action_info"]
def get_recent_message_stats(minutes: float = 30, chat_id: Optional[str] = None) -> dict:
"""
获取最近消息统计信息
Args:
minutes: 检索的分钟数默认30分钟
chat_id: 指定的chat_id仅统计该chat下的消息。为None时统计全部
Returns:
dict: {"bot_reply_count": int, "total_message_count": int}
功能说明:
- 统计指定时间范围内的消息数量
- 区分机器人回复和总消息数
- 可以针对特定聊天或全局统计
- 用于分析聊天活跃度和机器人参与度
"""
now = time.time()
start_time = now - minutes * 60
bot_id = global_config.bot.qq_account
filter_base: Dict[str, Any] = {"time": {"$gte": start_time}}
if chat_id is not None:
filter_base["chat_id"] = chat_id
# 总消息数
total_message_count = count_messages(filter_base)
# bot自身回复数
bot_filter = filter_base.copy()
bot_filter["user_id"] = bot_id
bot_reply_count = count_messages(bot_filter)
return {"bot_reply_count": bot_reply_count, "total_message_count": total_message_count}
async def send_typing():
async def send_typing(user_id):
"""
发送打字状态指示
@@ -177,6 +139,11 @@ async def send_typing():
group_info=group_info,
)
from plugin_system.core.event_manager import event_manager
from src.plugins.built_in.napcat_adapter_plugin.event_types import NapcatEvent
# 设置正在输入状态
await event_manager.trigger_event(NapcatEvent.PERSONAL.SET_INPUT_STATUS,user_id=user_id,event_type=1)
await send_api.custom_to_stream(
message_type="state", content="typing", stream_id=chat.stream_id, storage_message=False
)

View File

@@ -0,0 +1,13 @@
from dataclasses import dataclass, field
from typing import Optional, Dict, Any
@dataclass
class ProactiveTriggerEvent:
"""
主动思考触发事件的数据类
"""
source: str # 触发源的标识,例如 "silence_monitor", "insomnia_manager"
reason: str # 触发的具体原因,例如 "聊天已沉默10分钟", "深夜emo"
metadata: Optional[Dict[str, Any]] = field(default_factory=dict) # 可选的元数据,用于传递额外信息

View File

@@ -0,0 +1,253 @@
import time
import traceback
import orjson
from typing import TYPE_CHECKING, Dict, Any
from src.common.logger import get_logger
from src.plugin_system.base.component_types import ChatMode
from ..hfc_context import HfcContext
from .events import ProactiveTriggerEvent
from src.plugin_system.apis import generator_api
from src.plugin_system.apis.generator_api import process_human_text
from src.schedule.schedule_manager import schedule_manager
from src.plugin_system import tool_api
from src.plugin_system.base.component_types import ComponentType
from src.config.config import global_config
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages_with_id
from src.mood.mood_manager import mood_manager
from src.common.database.sqlalchemy_database_api import store_action_info
if TYPE_CHECKING:
from ..cycle_processor import CycleProcessor
logger = get_logger("hfc")
class ProactiveThinker:
"""
主动思考器,负责处理和执行主动思考事件。
当接收到 ProactiveTriggerEvent 时,它会根据事件内容进行一系列决策和操作,
例如调整情绪、调用规划器生成行动,并最终可能产生一个主动的回复。
"""
def __init__(self, context: HfcContext, cycle_processor: "CycleProcessor"):
"""
初始化主动思考器。
Args:
context (HfcContext): HFC聊天上下文对象提供了当前聊天会话的所有背景信息。
cycle_processor (CycleProcessor): 循环处理器,用于执行主动思考后产生的动作。
功能说明:
- 接收并处理主动思考事件 (ProactiveTriggerEvent)。
- 在思考前根据事件类型执行预处理操作,如修改当前情绪状态。
- 调用行动规划器 (Action Planner) 来决定下一步应该做什么。
- 如果规划结果是发送消息则调用生成器API生成回复并发送。
"""
self.context = context
self.cycle_processor = cycle_processor
async def think(self, trigger_event: ProactiveTriggerEvent):
"""
主动思考的统一入口API。
这是外部触发主动思考时调用的主要方法。
Args:
trigger_event (ProactiveTriggerEvent): 描述触发上下文的事件对象,包含了思考的来源和原因。
"""
logger.info(
f"{self.context.log_prefix} 接收到主动思考事件: "
f"来源='{trigger_event.source}', 原因='{trigger_event.reason}'"
)
try:
# 步骤 1: 根据事件类型执行思考前的准备工作,例如调整情绪。
await self._prepare_for_thinking(trigger_event)
# 步骤 2: 执行核心的思考和决策逻辑。
await self._execute_proactive_thinking(trigger_event)
except Exception as e:
# 捕获并记录在思考过程中发生的任何异常。
logger.error(f"{self.context.log_prefix} 主动思考 think 方法执行异常: {e}")
logger.error(traceback.format_exc())
async def _prepare_for_thinking(self, trigger_event: ProactiveTriggerEvent):
"""
根据事件类型,在正式思考前执行准备工作。
目前主要是处理来自失眠管理器的事件,并据此调整情绪。
Args:
trigger_event (ProactiveTriggerEvent): 触发事件。
"""
# 目前只处理来自失眠管理器(insomnia_manager)的事件
if trigger_event.source != "insomnia_manager":
return
try:
# 获取当前聊天的情绪对象
mood_obj = mood_manager.get_mood_by_chat_id(self.context.stream_id)
new_mood = None
# 根据失眠的不同原因设置对应的情绪
if trigger_event.reason == "low_pressure":
new_mood = "精力过剩,毫无睡意"
elif trigger_event.reason == "random":
new_mood = "深夜emo胡思乱想"
elif trigger_event.reason == "goodnight":
new_mood = "有点困了,准备睡觉了"
# 如果成功匹配到了新的情绪,则更新情绪状态
if new_mood:
mood_obj.mood_state = new_mood
mood_obj.last_change_time = time.time()
logger.info(
f"{self.context.log_prefix}'{trigger_event.reason}'"
f"情绪状态被强制更新为: {mood_obj.mood_state}"
)
except Exception as e:
logger.error(f"{self.context.log_prefix} 设置失眠情绪时出错: {e}")
async def _execute_proactive_thinking(self, trigger_event: ProactiveTriggerEvent):
"""
执行主动思考的核心逻辑。
它会调用规划器来决定是否要采取行动,以及采取什么行动。
Args:
trigger_event (ProactiveTriggerEvent): 触发事件。
"""
try:
# 调用规划器的 PROACTIVE 模式,让其决定下一步的行动
actions, _ = await self.cycle_processor.action_planner.plan(mode=ChatMode.PROACTIVE)
# 通常只关心规划出的第一个动作
action_result = actions[0] if actions else {}
action_type = action_result.get("action_type")
if action_type == "proactive_reply":
await self._generate_proactive_content_and_send(action_result)
elif action_type != "do_nothing":
logger.warning(f"{self.context.log_prefix} 主动思考返回了未知的动作类型: {action_type}")
else:
# 如果规划结果是“什么都不做”,则记录日志
logger.info(f"{self.context.log_prefix} 主动思考决策: 保持沉默")
except Exception as e:
logger.error(f"{self.context.log_prefix} 主动思考执行异常: {e}")
logger.error(traceback.format_exc())
async def _generate_proactive_content_and_send(self, action_result: Dict[str, Any]):
"""
获取实时信息,构建最终的生成提示词,并生成和发送主动回复。
Args:
action_result (Dict[str, Any]): 规划器返回的动作结果。
"""
try:
topic = action_result.get("action_data", {}).get("topic", "随便聊聊")
logger.info(f"{self.context.log_prefix} 主动思考确定主题: '{topic}'")
# 1. 获取日程信息
schedule_block = "你今天没有日程安排。"
if global_config.planning_system.schedule_enable:
if current_activity := schedule_manager.get_current_activity():
schedule_block = f"你当前正在:{current_activity}"
# 2. 网络搜索
news_block = "暂时没有获取到最新资讯。"
try:
web_search_tool = tool_api.get_tool_instance("web_search")
if web_search_tool:
tool_args = {"query": topic, "max_results": 10}
# 调用工具,并传递参数
search_result_dict = await web_search_tool.execute(**tool_args)
if search_result_dict and not search_result_dict.get("error"):
news_block = search_result_dict.get("content", "未能提取有效资讯。")
else:
logger.warning(f"{self.context.log_prefix} 网络搜索返回错误: {search_result_dict.get('error')}")
else:
logger.warning(f"{self.context.log_prefix} 未找到 web_search 工具实例。")
except Exception as e:
logger.error(f"{self.context.log_prefix} 主动思考时网络搜索失败: {e}")
# 3. 获取最新的聊天上下文
message_list = get_raw_msg_before_timestamp_with_chat(
chat_id=self.context.stream_id,
timestamp=time.time(),
limit=int(global_config.chat.max_context_size * 0.3),
)
chat_context_block, _ = build_readable_messages_with_id(messages=message_list)
# 4. 构建最终的生成提示词
bot_name = global_config.bot.nickname
personality = global_config.personality
identity_block = (
f"你的名字是{bot_name}\n"
f"关于你:{personality.personality_core},并且{personality.personality_side}\n"
f"你的身份是{personality.identity},平时说话风格是{personality.reply_style}"
)
mood_block = f"你现在的心情是:{mood_manager.get_mood_by_chat_id(self.context.stream_id).mood_state}"
final_prompt = f"""
## 你的角色
{identity_block}
## 你的心情
{mood_block}
## 你今天的日程安排
{schedule_block}
## 关于你准备讨论的话题“{topic}”的最新信息
{news_block}
## 最近的聊天内容
{chat_context_block}
## 任务
你之前决定要发起一个关于“{topic}”的对话。现在,请结合以上所有信息,自然地开启这个话题。
## 要求
- 你的发言要听起来像是自发的,而不是在念报告。
- 巧妙地将日程安排或最新信息融入到你的开场白中。
- 风格要符合你的角色设定。
- 直接输出你想要说的内容,不要包含其他额外信息。
你的回复应该:
1. 可以分享你的看法、提出相关问题,或者开个合适的玩笑。
2. 目的是让对话更有趣、更深入。
3. 不要浮夸,不要夸张修辞,不要输出多余内容(包括前后缀,冒号和引号,括号()表情包at或 @等 )。
最终请输出一条简短、完整且口语化的回复。
"""
# 5. 调用生成器API并发送
response_text = await generator_api.generate_response_custom(
chat_stream=self.context.chat_stream,
prompt=final_prompt,
request_type="chat.replyer.proactive",
)
if response_text:
response_set = process_human_text(
content=response_text,
enable_splitter=global_config.response_splitter.enable,
enable_chinese_typo=global_config.chinese_typo.enable,
)
await self.cycle_processor.response_handler.send_response(
response_set, time.time(), action_result.get("action_message")
)
await store_action_info(
chat_stream=self.context.chat_stream,
action_name="proactive_reply",
action_data={"topic": topic, "response": response_text},
action_prompt_display=f"主动发起对话: {topic}",
action_done=True,
)
else:
logger.error(f"{self.context.log_prefix} 主动思考生成回复失败。")
except Exception as e:
logger.error(f"{self.context.log_prefix} 生成主动回复内容时异常: {e}")
logger.error(traceback.format_exc())

View File

@@ -1,353 +0,0 @@
import asyncio
import time
import traceback
from typing import Optional, TYPE_CHECKING
from src.common.logger import get_logger
from src.config.config import global_config
from src.plugin_system.base.component_types import ChatMode
from .hfc_context import HfcContext
if TYPE_CHECKING:
from .cycle_processor import CycleProcessor
logger = get_logger("hfc")
class ProactiveThinker:
def __init__(self, context: HfcContext, cycle_processor: "CycleProcessor"):
"""
初始化主动思考器
Args:
context: HFC聊天上下文对象
cycle_processor: 循环处理器,用于执行主动思考的结果
功能说明:
- 管理机器人的主动发言功能
- 根据沉默时间和配置触发主动思考
- 提供私聊和群聊不同的思考提示模板
- 使用3-sigma规则计算动态思考间隔
"""
self.context = context
self.cycle_processor = cycle_processor
self._proactive_thinking_task: Optional[asyncio.Task] = None
self.proactive_thinking_prompts = {
"private": """现在你和你朋友的私聊里面已经隔了{time}没有发送消息了,请你结合上下文以及你和你朋友之前聊过的话题和你的人设来决定要不要主动发送消息,你可以选择:
1. 继续保持沉默(当{time}以前已经结束了一个话题并且你不想挑起新话题时)
2. 选择回复(当{time}以前你发送了一条消息且没有人回复你时、你想主动挑起一个话题时)
请根据当前情况做出选择。如果选择回复,请直接发送你想说的内容;如果选择保持沉默,请只回复"沉默"(注意:这个词不会被发送到群聊中)。""",
"group": """现在群里面已经隔了{time}没有人发送消息了,请你结合上下文以及群聊里面之前聊过的话题和你的人设来决定要不要主动发送消息,你可以选择:
1. 继续保持沉默(当{time}以前已经结束了一个话题并且你不想挑起新话题时)
2. 选择回复(当{time}以前你发送了一条消息且没有人回复你时、你想主动挑起一个话题时)
请根据当前情况做出选择。如果选择回复,请直接发送你想说的内容;如果选择保持沉默,请只回复"沉默"(注意:这个词不会被发送到群聊中)。""",
}
async def start(self):
"""
启动主动思考器
功能说明:
- 检查运行状态和配置,避免重复启动
- 只有在启用主动思考功能时才启动
- 创建主动思考循环异步任务
- 设置任务完成回调处理
- 记录启动日志
"""
if self.context.running and not self._proactive_thinking_task and global_config.chat.enable_proactive_thinking:
self._proactive_thinking_task = asyncio.create_task(self._proactive_thinking_loop())
self._proactive_thinking_task.add_done_callback(self._handle_proactive_thinking_completion)
logger.info(f"{self.context.log_prefix} 主动思考器已启动")
async def stop(self):
"""
停止主动思考器
功能说明:
- 取消正在运行的主动思考任务
- 等待任务完全停止
- 记录停止日志
"""
if self._proactive_thinking_task and not self._proactive_thinking_task.done():
self._proactive_thinking_task.cancel()
await asyncio.sleep(0)
logger.info(f"{self.context.log_prefix} 主动思考器已停止")
def _handle_proactive_thinking_completion(self, task: asyncio.Task):
"""
处理主动思考任务完成
Args:
task: 完成的异步任务对象
功能说明:
- 处理任务正常完成或异常情况
- 记录相应的日志信息
- 区分取消和异常终止的情况
"""
try:
if exception := task.exception():
logger.error(f"{self.context.log_prefix} 主动思考循环异常: {exception}")
else:
logger.info(f"{self.context.log_prefix} 主动思考循环正常结束")
except asyncio.CancelledError:
logger.info(f"{self.context.log_prefix} 主动思考循环被取消")
async def _proactive_thinking_loop(self):
"""
主动思考的主循环
功能说明:
- 每15秒检查一次是否需要主动思考
- 只在FOCUS模式下进行主动思考
- 检查是否启用主动思考功能
- 计算沉默时间并与动态间隔比较
- 达到条件时执行主动思考并更新最后消息时间
- 处理执行过程中的异常
"""
while self.context.running:
await asyncio.sleep(15)
if self.context.loop_mode != ChatMode.FOCUS:
continue
if not self._should_enable_proactive_thinking():
continue
current_time = time.time()
silence_duration = current_time - self.context.last_message_time
target_interval = self._get_dynamic_thinking_interval()
if silence_duration >= target_interval:
try:
await self._execute_proactive_thinking(silence_duration)
self.context.last_message_time = current_time
except Exception as e:
logger.error(f"{self.context.log_prefix} 主动思考执行出错: {e}")
logger.error(traceback.format_exc())
def _should_enable_proactive_thinking(self) -> bool:
"""
检查是否应该启用主动思考
Returns:
bool: 如果应该启用主动思考则返回True
功能说明:
- 检查聊天流是否存在
- 检查当前聊天是否在启用列表中(按平台和类型分别检查)
- 根据聊天类型(群聊/私聊)和配置决定是否启用
- 群聊需要proactive_thinking_in_group为True
- 私聊需要proactive_thinking_in_private为True
"""
if not self.context.chat_stream:
return False
is_group_chat = self.context.chat_stream.group_info is not None
# 检查基础开关
if is_group_chat and not global_config.chat.proactive_thinking_in_group:
return False
if not is_group_chat and not global_config.chat.proactive_thinking_in_private:
return False
# 获取当前聊天的完整标识 (platform:chat_id)
stream_parts = self.context.stream_id.split(":")
if len(stream_parts) >= 2:
platform = stream_parts[0]
chat_id = stream_parts[1]
current_chat_identifier = f"{platform}:{chat_id}"
else:
# 如果无法解析则使用原始stream_id
current_chat_identifier = self.context.stream_id
# 检查是否在启用列表中
if is_group_chat:
# 群聊检查
enable_list = getattr(global_config.chat, "proactive_thinking_enable_in_groups", [])
if enable_list and current_chat_identifier not in enable_list:
return False
else:
# 私聊检查
enable_list = getattr(global_config.chat, "proactive_thinking_enable_in_private", [])
if enable_list and current_chat_identifier not in enable_list:
return False
return True
def _get_dynamic_thinking_interval(self) -> float:
"""
获取动态思考间隔
Returns:
float: 计算得出的思考间隔时间(秒)
功能说明:
- 使用3-sigma规则计算正态分布的思考间隔
- 基于base_interval和delta_sigma配置计算
- 处理特殊情况为0或负数的配置
- 如果timing_utils不可用则使用固定间隔
- 间隔范围被限制在1秒到86400秒1天之间
"""
try:
from src.utils.timing_utils import get_normal_distributed_interval
base_interval = global_config.chat.proactive_thinking_interval
delta_sigma = getattr(global_config.chat, "delta_sigma", 120)
if base_interval < 0:
base_interval = abs(base_interval)
if delta_sigma < 0:
delta_sigma = abs(delta_sigma)
if base_interval == 0 and delta_sigma == 0:
return 300
elif base_interval == 0:
sigma_percentage = delta_sigma / 1000
return get_normal_distributed_interval(0, sigma_percentage, 1, 86400, use_3sigma_rule=True)
elif delta_sigma == 0:
return base_interval
sigma_percentage = delta_sigma / base_interval
return get_normal_distributed_interval(base_interval, sigma_percentage, 1, 86400, use_3sigma_rule=True)
except ImportError:
logger.warning(f"{self.context.log_prefix} timing_utils不可用使用固定间隔")
return max(300, abs(global_config.chat.proactive_thinking_interval))
except Exception as e:
logger.error(f"{self.context.log_prefix} 动态间隔计算出错: {e},使用固定间隔")
return max(300, abs(global_config.chat.proactive_thinking_interval))
def _format_duration(self, seconds: float) -> str:
"""
格式化持续时间为中文描述
Args:
seconds: 持续时间(秒)
Returns:
str: 格式化后的时间字符串,如"1小时30分45秒"
功能说明:
- 将秒数转换为小时、分钟、秒的组合
- 只显示非零的时间单位
- 如果所有单位都为0则显示"0秒"
- 用于主动思考日志的时间显示
"""
hours = int(seconds // 3600)
minutes = int((seconds % 3600) // 60)
secs = int(seconds % 60)
parts = []
if hours > 0:
parts.append(f"{hours}小时")
if minutes > 0:
parts.append(f"{minutes}")
if secs > 0 or not parts:
parts.append(f"{secs}")
return "".join(parts)
async def _execute_proactive_thinking(self, silence_duration: float):
"""
执行主动思考
Args:
silence_duration: 沉默持续时间(秒)
"""
formatted_time = self._format_duration(silence_duration)
logger.info(f"{self.context.log_prefix} 触发主动思考,已沉默{formatted_time}")
try:
# 直接调用 planner 的 PROACTIVE 模式
action_result_tuple, target_message = await self.cycle_processor.action_planner.plan(
mode=ChatMode.PROACTIVE
)
action_result = action_result_tuple.get("action_result")
# 如果决策不是 do_nothing则执行
if action_result and action_result.get("action_type") != "do_nothing":
logger.info(f"{self.context.log_prefix} 主动思考决策: {action_result.get('action_type')}, 原因: {action_result.get('reasoning')}")
# 在主动思考时,如果 target_message 为 None则默认选取最新 message 作为 target_message
if target_message is None and self.context.chat_stream and self.context.chat_stream.context:
from src.chat.message_receive.message import MessageRecv
latest_message = self.context.chat_stream.context.get_last_message()
if isinstance(latest_message, MessageRecv):
user_info = latest_message.message_info.user_info
target_message = {
"chat_info_platform": latest_message.message_info.platform,
"user_platform": user_info.platform if user_info else None,
"user_id": user_info.user_id if user_info else None,
"processed_plain_text": latest_message.processed_plain_text,
"is_mentioned": latest_message.is_mentioned,
}
# 将决策结果交给 cycle_processor 的后续流程处理
await self.cycle_processor.execute_plan(action_result, target_message)
else:
logger.info(f"{self.context.log_prefix} 主动思考决策: 保持沉默")
except Exception as e:
logger.error(f"{self.context.log_prefix} 主动思考执行异常: {e}")
logger.error(traceback.format_exc())
async def trigger_insomnia_thinking(self, reason: str):
"""
由外部事件(如失眠)触发的一次性主动思考
Args:
reason: 触发的原因 (e.g., "low_pressure", "random")
"""
logger.info(f"{self.context.log_prefix} 因“{reason}”触发失眠,开始深夜思考...")
# 1. 根据原因修改情绪
try:
from src.mood.mood_manager import mood_manager
mood_obj = mood_manager.get_mood_by_chat_id(self.context.stream_id)
if reason == "low_pressure":
mood_obj.mood_state = "精力过剩,毫无睡意"
elif reason == "random":
mood_obj.mood_state = "深夜emo胡思乱想"
mood_obj.last_change_time = time.time() # 更新时间戳以允许后续的情绪回归
logger.info(f"{self.context.log_prefix} 因失眠,情绪状态被强制更新为: {mood_obj.mood_state}")
except Exception as e:
logger.error(f"{self.context.log_prefix} 设置失眠情绪时出错: {e}")
# 2. 直接执行主动思考逻辑
try:
# 传入一个象征性的silence_duration因为它在这里不重要
await self._execute_proactive_thinking(silence_duration=1)
except Exception as e:
logger.error(f"{self.context.log_prefix} 失眠思考执行出错: {e}")
logger.error(traceback.format_exc())
async def trigger_goodnight_thinking(self):
"""
在失眠状态结束后,触发一次准备睡觉的主动思考
"""
logger.info(f"{self.context.log_prefix} 失眠状态结束,准备睡觉,触发告别思考...")
# 1. 设置一个准备睡觉的特定情绪
try:
from src.mood.mood_manager import mood_manager
mood_obj = mood_manager.get_mood_by_chat_id(self.context.stream_id)
mood_obj.mood_state = "有点困了,准备睡觉了"
mood_obj.last_change_time = time.time()
logger.info(f"{self.context.log_prefix} 情绪状态更新为: {mood_obj.mood_state}")
except Exception as e:
logger.error(f"{self.context.log_prefix} 设置睡前情绪时出错: {e}")
# 2. 直接执行主动思考逻辑
try:
await self._execute_proactive_thinking(silence_duration=1)
except Exception as e:
logger.error(f"{self.context.log_prefix} 睡前告别思考执行出错: {e}")
logger.error(traceback.format_exc())

View File

@@ -1,24 +1,24 @@
import time
import orjson
import random
import traceback
from typing import Optional, Dict, Any, Tuple
from typing import Dict, Any, Tuple
from src.common.logger import get_logger
from src.config.config import global_config
from src.plugin_system.apis import generator_api, send_api, message_api, database_api
from src.plugin_system.apis import send_api, message_api, database_api
from src.person_info.person_info import get_person_info_manager
from .hfc_context import HfcContext
# 导入反注入系统
from src.chat.antipromptinjector import get_anti_injector
from src.chat.antipromptinjector.types import ProcessResult
from src.chat.utils.prompt_builder import Prompt
# 日志记录器
logger = get_logger("hfc")
anti_injector_logger = get_logger("anti_injector")
class ResponseHandler:
"""
响应处理器类,负责生成和发送机器人的回复。
"""
def __init__(self, context: HfcContext):
"""
初始化响应处理器
@@ -68,6 +68,7 @@ class ResponseHandler:
person_info_manager = get_person_info_manager()
# 获取平台信息
platform = "default"
if self.context.chat_stream:
platform = (
@@ -76,11 +77,13 @@ class ResponseHandler:
or self.context.chat_stream.platform
)
# 获取用户信息并生成回复提示
user_id = action_message.get("user_id", "")
person_id = person_info_manager.get_person_id(platform, user_id)
person_name = await person_info_manager.get_value(person_id, "person_name")
action_prompt_display = f"你对{person_name}进行了回复:{reply_text}"
# 存储动作信息到数据库
await database_api.store_action_info(
chat_stream=self.context.chat_stream,
action_build_into_prompt=False,
@@ -91,6 +94,7 @@ class ResponseHandler:
action_name="reply",
)
# 构建循环信息
loop_info: Dict[str, Any] = {
"loop_plan_info": {
"action_result": plan_result.get("action_result", {}),
@@ -126,10 +130,12 @@ class ResponseHandler:
- 正确处理元组格式的回复段
"""
current_time = time.time()
# 计算新消息数量
new_message_count = message_api.count_new_messages(
chat_id=self.context.stream_id, start_time=thinking_start_time, end_time=current_time
)
# 根据新消息数量决定是否需要引用回复
need_reply = new_message_count >= random.randint(2, 4)
reply_text = ""
@@ -147,12 +153,16 @@ class ResponseHandler:
# 向下兼容:如果已经是字符串,则直接使用
data = str(reply_seg)
if isinstance(data, list):
data = "".join(map(str, data))
reply_text += data
# 如果是主动思考且内容为“沉默”,则不发送
if is_proactive_thinking and data.strip() == "沉默":
logger.info(f"{self.context.log_prefix} 主动思考决定保持沉默,不发送消息")
continue
# 发送第一段回复
if not first_replied:
await send_api.text_to_stream(
text=data,
@@ -163,7 +173,8 @@ class ResponseHandler:
)
first_replied = True
else:
await send_api.text_to_stream(
# 发送后续回复
sent_message = await send_api.text_to_stream(
text=data,
stream_id=self.context.stream_id,
reply_to_message=None,
@@ -172,101 +183,3 @@ class ResponseHandler:
)
return reply_text
# TODO: 已废弃
async def generate_response(
self,
message_data: dict,
available_actions: Optional[Dict[str, Any]],
reply_to: str,
request_type: str = "chat.replyer.normal",
) -> Optional[list]:
"""
生成回复内容
Args:
message_data: 消息数据
available_actions: 可用动作列表
reply_to: 回复目标
request_type: 请求类型,默认为普通回复
Returns:
list: 生成的回复内容列表失败时返回None
功能说明:
- 在生成回复前进行反注入检测(提高效率)
- 调用生成器API生成回复
- 根据配置启用或禁用工具功能
- 处理生成失败的情况
- 记录生成过程中的错误和异常
"""
try:
# === 反注入检测(仅在需要生成回复时) ===
# 执行反注入检测(直接使用字典格式)
anti_injector = get_anti_injector()
result, modified_content, reason = await anti_injector.process_message(
message_data, self.context.chat_stream
)
# 根据反注入结果处理消息数据
await anti_injector.handle_message_storage(result, modified_content, reason or "", message_data)
if result == ProcessResult.BLOCKED_BAN:
# 用户被封禁 - 直接阻止回复生成
anti_injector_logger.warning(f"用户被反注入系统封禁,阻止回复生成: {reason}")
return None
elif result == ProcessResult.BLOCKED_INJECTION:
# 消息被阻止(危险内容等) - 直接阻止回复生成
anti_injector_logger.warning(f"消息被反注入系统阻止,阻止回复生成: {reason}")
return None
elif result == ProcessResult.COUNTER_ATTACK:
# 反击模式:生成反击消息作为回复
anti_injector_logger.info(f"反击模式启动,生成反击回复: {reason}")
if modified_content:
# 返回反击消息作为回复内容
return [("text", modified_content)]
else:
# 没有反击内容时阻止回复生成
return None
# 检查是否需要加盾处理
safety_prompt = None
if result == ProcessResult.SHIELDED:
# 获取安全系统提示词并注入
shield = anti_injector.shield
safety_prompt = shield.get_safety_system_prompt()
await Prompt.create_async(safety_prompt, "anti_injection_safety_prompt")
anti_injector_logger.info(f"消息已被反注入系统加盾处理,已注入安全提示词: {reason}")
# 处理被修改的消息内容(用于生成回复)
modified_reply_to = reply_to
if modified_content:
# 更新消息内容用于生成回复
anti_injector_logger.info(f"消息内容已被反注入系统修改,使用修改后内容生成回复: {reason}")
# 解析原始reply_to格式"发送者:消息内容"
if ":" in reply_to:
sender_part, _ = reply_to.split(":", 1)
modified_reply_to = f"{sender_part}:{modified_content}"
else:
# 如果格式不标准,直接使用修改后的内容
modified_reply_to = modified_content
# === 正常的回复生成流程 ===
success, reply_set, _ = await generator_api.generate_reply(
chat_stream=self.context.chat_stream,
reply_to=modified_reply_to, # 使用可能被修改的内容
available_actions=available_actions,
enable_tool=global_config.tool.enable_tool,
request_type=request_type,
from_plugin=False,
)
if not success or not reply_set:
logger.info(f"{message_data.get('processed_plain_text')} 的回复生成失败")
return None
return reply_set
except Exception as e:
logger.error(f"{self.context.log_prefix}回复生成出现错误:{str(e)} {traceback.format_exc()}")
return None

View File

@@ -0,0 +1,33 @@
import asyncio
from src.common.logger import get_logger
from ..hfc_context import HfcContext
logger = get_logger("notification_sender")
class NotificationSender:
@staticmethod
async def send_goodnight_notification(context: HfcContext):
"""发送晚安通知"""
try:
from ..proactive.events import ProactiveTriggerEvent
from ..proactive.proactive_thinker import ProactiveThinker
event = ProactiveTriggerEvent(source="sleep_manager", reason="goodnight")
proactive_thinker = ProactiveThinker(context, context.chat_instance.cycle_processor)
await proactive_thinker.think(event)
except Exception as e:
logger.error(f"发送晚安通知失败: {e}")
@staticmethod
async def send_insomnia_notification(context: HfcContext, reason: str):
"""发送失眠通知"""
try:
from ..proactive.events import ProactiveTriggerEvent
from ..proactive.proactive_thinker import ProactiveThinker
event = ProactiveTriggerEvent(source="sleep_manager", reason=reason)
proactive_thinker = ProactiveThinker(context, context.chat_instance.cycle_processor)
await proactive_thinker.think(event)
except Exception as e:
logger.error(f"发送失眠通知失败: {e}")

View File

@@ -0,0 +1,304 @@
import asyncio
import random
from datetime import datetime, timedelta, date
from typing import Optional, TYPE_CHECKING, List, Dict, Any
from src.common.logger import get_logger
from src.config.config import global_config
from .sleep_state import SleepState, SleepStateSerializer
from .time_checker import TimeChecker
from .notification_sender import NotificationSender
if TYPE_CHECKING:
from .wakeup_manager import WakeUpManager
logger = get_logger("sleep_manager")
class SleepManager:
"""
睡眠管理器,核心组件之一,负责管理角色的睡眠周期和状态转换。
它实现了一个状态机,根据预设的时间表、睡眠压力和随机因素,
在不同的睡眠状态(如清醒、准备入睡、睡眠、失眠)之间进行切换。
"""
def __init__(self):
"""
初始化睡眠管理器。
"""
self.time_checker = TimeChecker() # 时间检查器,用于判断当前是否处于理论睡眠时间
self.last_sleep_log_time = 0 # 上次记录睡眠日志的时间戳
self.sleep_log_interval = 35 # 睡眠日志记录间隔(秒)
# --- 统一睡眠状态管理 ---
self._current_state: SleepState = SleepState.AWAKE # 当前睡眠状态
self._sleep_buffer_end_time: Optional[datetime] = None # 睡眠缓冲结束时间,用于状态转换
self._total_delayed_minutes_today: float = 0.0 # 今天总共延迟入睡的分钟数
self._last_sleep_check_date: Optional[date] = None # 上次检查睡眠状态的日期
self._last_fully_slept_log_time: float = 0 # 上次完全进入睡眠状态的时间戳
self._re_sleep_attempt_time: Optional[datetime] = None # 被吵醒后,尝试重新入睡的时间点
# 从本地存储加载上一次的睡眠状态
self._load_sleep_state()
def get_current_sleep_state(self) -> SleepState:
"""获取当前的睡眠状态。"""
return self._current_state
def is_sleeping(self) -> bool:
"""判断当前是否处于正在睡觉的状态。"""
return self._current_state == SleepState.SLEEPING
async def update_sleep_state(self, wakeup_manager: Optional["WakeUpManager"] = None):
"""
更新睡眠状态的核心方法,实现状态机的主要逻辑。
该方法会被周期性调用,以检查并更新当前的睡眠状态。
Args:
wakeup_manager (Optional["WakeUpManager"]): 唤醒管理器,用于获取睡眠压力等上下文信息。
"""
# 如果全局禁用了睡眠系统,则强制设置为清醒状态并返回
if not global_config.sleep_system.enable:
if self._current_state != SleepState.AWAKE:
logger.debug("睡眠系统禁用,强制设为 AWAKE")
self._current_state = SleepState.AWAKE
return
now = datetime.now()
today = now.date()
# 跨天处理:如果日期变化,重置每日相关的睡眠状态
if self._last_sleep_check_date != today:
logger.info(f"新的一天 ({today}),重置睡眠状态。")
self._total_delayed_minutes_today = 0
self._current_state = SleepState.AWAKE
self._sleep_buffer_end_time = None
self._last_sleep_check_date = today
self._save_sleep_state()
# 检查当前是否处于理论上的睡眠时间段
is_in_theoretical_sleep, activity = self.time_checker.is_in_theoretical_sleep_time(now.time())
# --- 状态机核心处理逻辑 ---
if self._current_state == SleepState.AWAKE:
if is_in_theoretical_sleep:
self._handle_awake_to_sleep(now, activity, wakeup_manager)
elif self._current_state == SleepState.PREPARING_SLEEP:
self._handle_preparing_sleep(now, is_in_theoretical_sleep, wakeup_manager)
elif self._current_state == SleepState.SLEEPING:
self._handle_sleeping(now, is_in_theoretical_sleep, activity, wakeup_manager)
elif self._current_state == SleepState.INSOMNIA:
self._handle_insomnia(now, is_in_theoretical_sleep)
elif self._current_state == SleepState.WOKEN_UP:
self._handle_woken_up(now, is_in_theoretical_sleep, wakeup_manager)
def _handle_awake_to_sleep(self, now: datetime, activity: Optional[str], wakeup_manager: Optional["WakeUpManager"]):
"""处理从“清醒”到“准备入睡”的状态转换。"""
if activity:
logger.info(f"进入理论休眠时间 '{activity}',开始进行睡眠决策...")
else:
logger.info("进入理论休眠时间,开始进行睡眠决策...")
if global_config.sleep_system.enable_flexible_sleep:
# --- 新的弹性睡眠逻辑 ---
if wakeup_manager:
sleep_pressure = wakeup_manager.context.sleep_pressure
pressure_threshold = global_config.sleep_system.flexible_sleep_pressure_threshold
max_delay_minutes = global_config.sleep_system.max_sleep_delay_minutes
buffer_seconds = 0
# 如果睡眠压力低于阈值,则计算延迟时间
if sleep_pressure <= pressure_threshold:
# 压力差,归一化到 (0, 1]
pressure_diff = (pressure_threshold - sleep_pressure) / pressure_threshold
# 延迟分钟数,压力越低,延迟越长
delay_minutes = int(pressure_diff * max_delay_minutes)
# 确保总延迟不超过当日最大值
remaining_delay = max_delay_minutes - self._total_delayed_minutes_today
delay_minutes = min(delay_minutes, remaining_delay)
if delay_minutes > 0:
# 增加一些随机性
buffer_seconds = random.randint(int(delay_minutes * 0.8 * 60), int(delay_minutes * 1.2 * 60))
self._total_delayed_minutes_today += buffer_seconds / 60.0
logger.info(f"睡眠压力 ({sleep_pressure:.1f}) 较低,延迟 {buffer_seconds / 60:.1f} 分钟入睡。")
else:
# 延迟额度已用完,设置一个较短的准备时间
buffer_seconds = random.randint(1 * 60, 2 * 60)
logger.info("今日延迟入睡额度已用完,进入短暂准备后入睡。")
else:
# 睡眠压力较高,设置一个较短的准备时间
buffer_seconds = random.randint(1 * 60, 2 * 60)
logger.info(f"睡眠压力 ({sleep_pressure:.1f}) 较高,将在短暂准备后入睡。")
# 发送睡前通知
if global_config.sleep_system.enable_pre_sleep_notification:
asyncio.create_task(NotificationSender.send_goodnight_notification(wakeup_manager.context))
self._sleep_buffer_end_time = now + timedelta(seconds=buffer_seconds)
self._current_state = SleepState.PREPARING_SLEEP
logger.info(f"进入准备入睡状态,将在 {buffer_seconds / 60:.1f} 分钟内入睡。")
self._save_sleep_state()
else:
# 无法获取 wakeup_manager退回旧逻辑
buffer_seconds = random.randint(1 * 60, 3 * 60)
self._sleep_buffer_end_time = now + timedelta(seconds=buffer_seconds)
self._current_state = SleepState.PREPARING_SLEEP
logger.warning("无法获取 WakeUpManager弹性睡眠采用默认1-3分钟延迟。")
self._save_sleep_state()
else:
# 非弹性睡眠模式
if wakeup_manager and global_config.sleep_system.enable_pre_sleep_notification:
asyncio.create_task(NotificationSender.send_goodnight_notification(wakeup_manager.context))
self._current_state = SleepState.SLEEPING
def _handle_preparing_sleep(self, now: datetime, is_in_theoretical_sleep: bool, wakeup_manager: Optional["WakeUpManager"]):
"""处理“准备入睡”状态下的逻辑。"""
# 如果在准备期间离开了理论睡眠时间,则取消入睡
if not is_in_theoretical_sleep:
logger.info("准备入睡期间离开理论休眠时间,取消入睡,恢复清醒。")
self._current_state = SleepState.AWAKE
self._sleep_buffer_end_time = None
self._save_sleep_state()
# 如果缓冲时间结束,则正式进入睡眠状态
elif self._sleep_buffer_end_time and now >= self._sleep_buffer_end_time:
logger.info("睡眠缓冲期结束,正式进入休眠状态。")
self._current_state = SleepState.SLEEPING
self._last_fully_slept_log_time = now.timestamp()
# 设置一个随机的延迟,用于触发“睡后失眠”检查
delay_minutes_range = global_config.sleep_system.insomnia_trigger_delay_minutes
delay_minutes = random.randint(delay_minutes_range[0], delay_minutes_range[1])
self._sleep_buffer_end_time = now + timedelta(minutes=delay_minutes)
logger.info(f"已设置睡后失眠检查,将在 {delay_minutes} 分钟后触发。")
self._save_sleep_state()
def _handle_sleeping(self, now: datetime, is_in_theoretical_sleep: bool, activity: Optional[str], wakeup_manager: Optional["WakeUpManager"]):
"""处理“正在睡觉”状态下的逻辑。"""
# 如果理论睡眠时间结束,则自然醒来
if not is_in_theoretical_sleep:
logger.info("理论休眠时间结束,自然醒来。")
self._current_state = SleepState.AWAKE
self._save_sleep_state()
# 检查是否到了触发“睡后失眠”的时间点
elif self._sleep_buffer_end_time and now >= self._sleep_buffer_end_time:
if wakeup_manager:
sleep_pressure = wakeup_manager.context.sleep_pressure
pressure_threshold = global_config.sleep_system.flexible_sleep_pressure_threshold
# 检查是否触发失眠
insomnia_reason = None
if sleep_pressure < pressure_threshold:
insomnia_reason = "low_pressure"
logger.info(f"睡眠压力 ({sleep_pressure:.1f}) 低于阈值 ({pressure_threshold}),触发睡后失眠。")
elif random.random() < getattr(global_config.sleep_system, "random_insomnia_chance", 0.1):
insomnia_reason = "random"
logger.info("随机触发失眠。")
if insomnia_reason:
self._current_state = SleepState.INSOMNIA
# 设置失眠的持续时间
duration_minutes_range = global_config.sleep_system.insomnia_duration_minutes
duration_minutes = random.randint(*duration_minutes_range)
self._sleep_buffer_end_time = now + timedelta(minutes=duration_minutes)
# 发送失眠通知
asyncio.create_task(NotificationSender.send_insomnia_notification(wakeup_manager.context, insomnia_reason))
logger.info(f"进入失眠状态 (原因: {insomnia_reason}),将持续 {duration_minutes} 分钟。")
else:
# 睡眠压力正常,不触发失眠,清除检查时间点
logger.info(f"睡眠压力 ({sleep_pressure:.1f}) 正常,未触发睡后失眠。")
self._sleep_buffer_end_time = None
self._save_sleep_state()
else:
# 定期记录睡眠日志
current_timestamp = now.timestamp()
if current_timestamp - self.last_sleep_log_time > self.sleep_log_interval and activity:
logger.info(f"当前处于休眠活动 '{activity}' 中。")
self.last_sleep_log_time = current_timestamp
def _handle_insomnia(self, now: datetime, is_in_theoretical_sleep: bool):
"""处理“失眠”状态下的逻辑。"""
# 如果离开理论睡眠时间,则失眠结束
if not is_in_theoretical_sleep:
logger.info("已离开理论休眠时间,失眠结束,恢复清醒。")
self._current_state = SleepState.AWAKE
self._sleep_buffer_end_time = None
self._save_sleep_state()
# 如果失眠持续时间已过,则恢复睡眠
elif self._sleep_buffer_end_time and now >= self._sleep_buffer_end_time:
logger.info("失眠状态持续时间已过,恢复睡眠。")
self._current_state = SleepState.SLEEPING
self._sleep_buffer_end_time = None
self._save_sleep_state()
def _handle_woken_up(self, now: datetime, is_in_theoretical_sleep: bool, wakeup_manager: Optional["WakeUpManager"]):
"""处理“被吵醒”状态下的逻辑。"""
# 如果理论睡眠时间结束,则状态自动结束
if not is_in_theoretical_sleep:
logger.info("理论休眠时间结束,被吵醒的状态自动结束。")
self._current_state = SleepState.AWAKE
self._re_sleep_attempt_time = None
self._save_sleep_state()
# 到了尝试重新入睡的时间点
elif self._re_sleep_attempt_time and now >= self._re_sleep_attempt_time:
logger.info("被吵醒后经过一段时间,尝试重新入睡...")
if wakeup_manager:
sleep_pressure = wakeup_manager.context.sleep_pressure
pressure_threshold = global_config.sleep_system.flexible_sleep_pressure_threshold
# 如果睡眠压力足够,则尝试重新入睡
if sleep_pressure >= pressure_threshold:
logger.info("睡眠压力足够,从被吵醒状态转换到准备入睡。")
buffer_seconds = random.randint(3 * 60, 8 * 60)
self._sleep_buffer_end_time = now + timedelta(seconds=buffer_seconds)
self._current_state = SleepState.PREPARING_SLEEP
self._re_sleep_attempt_time = None
else:
# 睡眠压力不足,延迟一段时间后再次尝试
delay_minutes = 15
self._re_sleep_attempt_time = now + timedelta(minutes=delay_minutes)
logger.info(
f"睡眠压力({sleep_pressure:.1f})仍然较低,暂时保持清醒,在 {delay_minutes} 分钟后再次尝试。"
)
self._save_sleep_state()
def reset_sleep_state_after_wakeup(self):
"""
当角色被用户消息等外部因素唤醒时调用此方法。
将状态强制转换为 WOKEN_UP并设置一个延迟之后会尝试重新入睡。
"""
if self._current_state in [SleepState.PREPARING_SLEEP, SleepState.SLEEPING, SleepState.INSOMNIA]:
logger.info("被唤醒,进入 WOKEN_UP 状态!")
self._current_state = SleepState.WOKEN_UP
self._sleep_buffer_end_time = None
re_sleep_delay_minutes = getattr(global_config.sleep_system, "re_sleep_delay_minutes", 10)
self._re_sleep_attempt_time = datetime.now() + timedelta(minutes=re_sleep_delay_minutes)
logger.info(f"将在 {re_sleep_delay_minutes} 分钟后尝试重新入睡。")
self._save_sleep_state()
def _save_sleep_state(self):
"""将当前所有睡眠相关的状态打包并保存到本地存储。"""
state_data = {
"_current_state": self._current_state,
"_sleep_buffer_end_time": self._sleep_buffer_end_time,
"_total_delayed_minutes_today": self._total_delayed_minutes_today,
"_last_sleep_check_date": self._last_sleep_check_date,
"_re_sleep_attempt_time": self._re_sleep_attempt_time,
}
SleepStateSerializer.save(state_data)
def _load_sleep_state(self):
"""从本地存储加载并恢复所有睡眠相关的状态。"""
state_data = SleepStateSerializer.load()
self._current_state = state_data["_current_state"]
self._sleep_buffer_end_time = state_data["_sleep_buffer_end_time"]
self._total_delayed_minutes_today = state_data["_total_delayed_minutes_today"]
self._last_sleep_check_date = state_data["_last_sleep_check_date"]
self._re_sleep_attempt_time = state_data["_re_sleep_attempt_time"]

View File

@@ -0,0 +1,110 @@
from enum import Enum, auto
from datetime import datetime
from src.common.logger import get_logger
from src.manager.local_store_manager import local_storage
logger = get_logger("sleep_state")
class SleepState(Enum):
"""
定义了角色可能处于的几种睡眠状态。
这是一个状态机,用于管理角色的睡眠周期。
"""
AWAKE = auto() # 清醒状态
INSOMNIA = auto() # 失眠状态
PREPARING_SLEEP = auto() # 准备入睡状态,一个短暂的过渡期
SLEEPING = auto() # 正在睡觉状态
WOKEN_UP = auto() # 被吵醒状态
class SleepStateSerializer:
"""
睡眠状态序列化器。
负责将内存中的睡眠状态对象持久化到本地存储如JSON文件
以及在程序启动时从本地存储中恢复状态。
这样可以确保即使程序重启,角色的睡眠状态也能得以保留。
"""
@staticmethod
def save(state_data: dict):
"""
将当前的睡眠状态数据保存到本地存储。
Args:
state_data (dict): 包含睡眠状态信息的字典。
datetime对象会被转换为时间戳Enum成员会被转换为其名称字符串。
"""
try:
# 准备要序列化的数据字典
state = {
# 保存当前状态的枚举名称
"current_state": state_data["_current_state"].name,
# 将datetime对象转换为Unix时间戳以便序列化
"sleep_buffer_end_time_ts": state_data["_sleep_buffer_end_time"].timestamp()
if state_data["_sleep_buffer_end_time"]
else None,
"total_delayed_minutes_today": state_data["_total_delayed_minutes_today"],
# 将date对象转换为ISO格式的字符串
"last_sleep_check_date_str": state_data["_last_sleep_check_date"].isoformat()
if state_data["_last_sleep_check_date"]
else None,
"re_sleep_attempt_time_ts": state_data["_re_sleep_attempt_time"].timestamp()
if state_data["_re_sleep_attempt_time"]
else None,
}
# 写入本地存储
local_storage["schedule_sleep_state"] = state
logger.debug(f"已保存睡眠状态: {state}")
except Exception as e:
logger.error(f"保存睡眠状态失败: {e}")
@staticmethod
def load() -> dict:
"""
从本地存储加载并解析睡眠状态。
Returns:
dict: 包含恢复后睡眠状态信息的字典。
如果加载失败或没有找到数据,则返回一个默认的清醒状态。
"""
# 定义一个默认的状态,以防加载失败
state_data = {
"_current_state": SleepState.AWAKE,
"_sleep_buffer_end_time": None,
"_total_delayed_minutes_today": 0,
"_last_sleep_check_date": None,
"_re_sleep_attempt_time": None,
}
try:
# 从本地存储读取数据
state = local_storage["schedule_sleep_state"]
if state and isinstance(state, dict):
# 恢复当前状态枚举
state_name = state.get("current_state")
if state_name and hasattr(SleepState, state_name):
state_data["_current_state"] = SleepState[state_name]
# 从时间戳恢复datetime对象
end_time_ts = state.get("sleep_buffer_end_time_ts")
if end_time_ts:
state_data["_sleep_buffer_end_time"] = datetime.fromtimestamp(end_time_ts)
# 恢复重新入睡尝试时间
re_sleep_ts = state.get("re_sleep_attempt_time_ts")
if re_sleep_ts:
state_data["_re_sleep_attempt_time"] = datetime.fromtimestamp(re_sleep_ts)
# 恢复今日延迟睡眠总分钟数
state_data["_total_delayed_minutes_today"] = state.get("total_delayed_minutes_today", 0)
# 从ISO格式字符串恢复date对象
date_str = state.get("last_sleep_check_date_str")
if date_str:
state_data["_last_sleep_check_date"] = datetime.fromisoformat(date_str).date()
logger.info(f"成功从本地存储加载睡眠状态: {state}")
except Exception as e:
# 如果加载过程中出现任何问题,记录警告并返回默认状态
logger.warning(f"加载睡眠状态失败,将使用默认值: {e}")
return state_data

View File

@@ -0,0 +1,108 @@
from datetime import datetime, time, timedelta
from typing import Optional, List, Dict, Any
import random
from src.common.logger import get_logger
from src.config.config import global_config
from src.schedule.schedule_manager import schedule_manager
logger = get_logger("time_checker")
class TimeChecker:
def __init__(self):
# 缓存当天的偏移量,确保一天内使用相同的偏移量
self._daily_sleep_offset: int = 0
self._daily_wake_offset: int = 0
self._offset_date = None
def _get_daily_offsets(self):
"""获取当天的睡眠和起床时间偏移量,每天生成一次"""
today = datetime.now().date()
# 如果是新的一天,重新生成偏移量
if self._offset_date != today:
sleep_offset_range = global_config.sleep_system.sleep_time_offset_minutes
wake_offset_range = global_config.sleep_system.wake_up_time_offset_minutes
# 生成 ±offset_range 范围内的随机偏移量
self._daily_sleep_offset = random.randint(-sleep_offset_range, sleep_offset_range)
self._daily_wake_offset = random.randint(-wake_offset_range, wake_offset_range)
self._offset_date = today
logger.debug(f"生成新的每日偏移量 - 睡觉时间偏移: {self._daily_sleep_offset}分钟, 起床时间偏移: {self._daily_wake_offset}分钟")
return self._daily_sleep_offset, self._daily_wake_offset
def get_today_schedule(self) -> Optional[List[Dict[str, Any]]]:
"""从全局 ScheduleManager 获取今天的日程安排。"""
return schedule_manager.today_schedule
def is_in_theoretical_sleep_time(self, now_time: time) -> tuple[bool, Optional[str]]:
if global_config.sleep_system.sleep_by_schedule:
if self.get_today_schedule():
return self._is_in_schedule_sleep_time(now_time)
else:
return self._is_in_sleep_time(now_time)
else:
return self._is_in_sleep_time(now_time)
def _is_in_schedule_sleep_time(self, now_time: time) -> tuple[bool, Optional[str]]:
"""检查当前时间是否落在日程表的任何一个睡眠活动中"""
sleep_keywords = ["休眠", "睡觉", "梦乡"]
today_schedule = self.get_today_schedule()
if today_schedule:
for event in today_schedule:
try:
activity = event.get("activity", "").strip()
time_range = event.get("time_range")
if not activity or not time_range:
continue
if any(keyword in activity for keyword in sleep_keywords):
start_str, end_str = time_range.split("-")
start_time = datetime.strptime(start_str.strip(), "%H:%M").time()
end_time = datetime.strptime(end_str.strip(), "%H:%M").time()
if start_time <= end_time: # 同一天
if start_time <= now_time < end_time:
return True, activity
else: # 跨天
if now_time >= start_time or now_time < end_time:
return True, activity
except (ValueError, KeyError, AttributeError) as e:
logger.warning(f"解析日程事件时出错: {event}, 错误: {e}")
continue
return False, None
def _is_in_sleep_time(self, now_time: time) -> tuple[bool, Optional[str]]:
"""检查当前时间是否在固定的睡眠时间内(应用偏移量)"""
try:
start_time_str = global_config.sleep_system.fixed_sleep_time
end_time_str = global_config.sleep_system.fixed_wake_up_time
# 获取当天的偏移量
sleep_offset, wake_offset = self._get_daily_offsets()
# 解析基础时间
base_start_time = datetime.strptime(start_time_str, "%H:%M")
base_end_time = datetime.strptime(end_time_str, "%H:%M")
# 应用偏移量
actual_start_time = (base_start_time + timedelta(minutes=sleep_offset)).time()
actual_end_time = (base_end_time + timedelta(minutes=wake_offset)).time()
logger.debug(f"固定睡眠时间检查 - 基础时间: {start_time_str}-{end_time_str}, "
f"偏移后时间: {actual_start_time.strftime('%H:%M')}-{actual_end_time.strftime('%H:%M')}, "
f"当前时间: {now_time.strftime('%H:%M')}")
if actual_start_time <= actual_end_time:
if actual_start_time <= now_time < actual_end_time:
return True, f"固定睡眠时间(偏移后: {actual_start_time.strftime('%H:%M')}-{actual_end_time.strftime('%H:%M')})"
else:
if now_time >= actual_start_time or now_time < actual_end_time:
return True, f"固定睡眠时间(偏移后: {actual_start_time.strftime('%H:%M')}-{actual_end_time.strftime('%H:%M')})"
except ValueError as e:
logger.error(f"固定的睡眠时间格式不正确,请使用 HH:MM 格式: {e}")
return False, None

View File

@@ -4,7 +4,7 @@ from typing import Optional
from src.common.logger import get_logger
from src.config.config import global_config
from src.manager.local_store_manager import local_storage
from .hfc_context import HfcContext
from ..hfc_context import HfcContext
logger = get_logger("wakeup")
@@ -138,10 +138,13 @@ class WakeUpManager:
return False
# 只有在休眠且非失眠状态下才累积唤醒度
from src.schedule.schedule_manager import schedule_manager
from src.schedule.sleep_manager import SleepState
from .sleep_state import SleepState
current_sleep_state = schedule_manager.get_current_sleep_state()
sleep_manager = self.context.sleep_manager
if not sleep_manager:
return False
current_sleep_state = sleep_manager.get_current_sleep_state()
if current_sleep_state != SleepState.SLEEPING:
return False
@@ -191,10 +194,9 @@ class WakeUpManager:
mood_manager.set_angry_from_wakeup(self.context.stream_id)
# 通知日程管理器重置睡眠状态
from src.schedule.schedule_manager import schedule_manager
schedule_manager.reset_sleep_state_after_wakeup()
# 通知SleepManager重置睡眠状态
if self.context.sleep_manager:
self.context.sleep_manager.reset_sleep_state_after_wakeup()
logger.info(f"{self.context.log_prefix} 唤醒度达到阈值({self.wakeup_threshold}),被吵醒进入愤怒状态!")

View File

@@ -13,7 +13,7 @@ from src.common.database.sqlalchemy_models import Expression
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive, build_anonymous_messages
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager

View File

@@ -11,7 +11,7 @@ from src.config.config import global_config, model_config
from src.common.logger import get_logger
from sqlalchemy import select
from src.common.database.sqlalchemy_models import Expression
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.common.database.sqlalchemy_database_api import get_db_session
logger = get_logger("expression_selector")

View File

@@ -0,0 +1,144 @@
"""
Chat Frequency Analyzer
=======================
本模块负责分析用户的聊天时间戳,以识别出他们最活跃的聊天时段(高峰时段)。
核心功能:
- 使用滑动窗口算法来检测时间戳集中的区域。
- 提供接口查询指定用户当前是否处于其聊天高峰时段内。
- 结果会被缓存以提高性能。
可配置参数:
- ANALYSIS_WINDOW_HOURS: 用于分析的时间窗口大小(小时)。
- MIN_CHATS_FOR_PEAK: 在一个窗口内需要多少次聊天才能被认为是高峰时段。
- MIN_GAP_BETWEEN_PEAKS_HOURS: 两个独立高峰时段之间的最小间隔(小时)。
"""
import time as time_module
from datetime import datetime, timedelta, time
from typing import List, Tuple, Optional
from .tracker import chat_frequency_tracker
# --- 可配置参数 ---
# 用于分析的时间窗口大小(小时)
ANALYSIS_WINDOW_HOURS = 2
# 触发高峰时段所需的最小聊天次数
MIN_CHATS_FOR_PEAK = 4
# 两个独立高峰时段之间的最小间隔(小时)
MIN_GAP_BETWEEN_PEAKS_HOURS = 1
class ChatFrequencyAnalyzer:
"""
分析聊天时间戳,以识别用户的高频聊天时段。
"""
def __init__(self):
# 缓存分析结果,避免重复计算
# 格式: { "chat_id": (timestamp_of_analysis, [peak_windows]) }
self._analysis_cache: dict[str, tuple[float, list[tuple[time, time]]]] = {}
self._cache_ttl_seconds = 60 * 30 # 缓存30分钟
def _find_peak_windows(self, timestamps: List[float]) -> List[Tuple[datetime, datetime]]:
"""
使用滑动窗口算法来识别时间戳列表中的高峰时段。
Args:
timestamps (List[float]): 按时间排序的聊天时间戳。
Returns:
List[Tuple[datetime, datetime]]: 识别出的高峰时段列表,每个元组代表一个时间窗口的开始和结束。
"""
if len(timestamps) < MIN_CHATS_FOR_PEAK:
return []
# 将时间戳转换为 datetime 对象
datetimes = [datetime.fromtimestamp(ts) for ts in timestamps]
datetimes.sort()
peak_windows: List[Tuple[datetime, datetime]] = []
window_start_idx = 0
for i in range(len(datetimes)):
# 移动窗口的起始点
while datetimes[i] - datetimes[window_start_idx] > timedelta(hours=ANALYSIS_WINDOW_HOURS):
window_start_idx += 1
# 检查当前窗口是否满足高峰条件
if i - window_start_idx + 1 >= MIN_CHATS_FOR_PEAK:
current_window_start = datetimes[window_start_idx]
current_window_end = datetimes[i]
# 合并重叠或相邻的高峰时段
if peak_windows and current_window_start - peak_windows[-1][1] < timedelta(hours=MIN_GAP_BETWEEN_PEAKS_HOURS):
# 扩展上一个窗口的结束时间
peak_windows[-1] = (peak_windows[-1][0], current_window_end)
else:
peak_windows.append((current_window_start, current_window_end))
return peak_windows
def get_peak_chat_times(self, chat_id: str) -> List[Tuple[time, time]]:
"""
获取指定用户的高峰聊天时间段。
Args:
chat_id (str): 聊天标识符。
Returns:
List[Tuple[time, time]]: 高峰时段的列表,每个元组包含开始和结束时间 (time 对象)。
"""
# 检查缓存
cached_timestamp, cached_windows = self._analysis_cache.get(chat_id, (0, []))
if time_module.time() - cached_timestamp < self._cache_ttl_seconds:
return cached_windows
timestamps = chat_frequency_tracker.get_timestamps_for_chat(chat_id)
if not timestamps:
return []
peak_datetime_windows = self._find_peak_windows(timestamps)
# 将 datetime 窗口转换为 time 窗口,并进行归一化处理
peak_time_windows = []
for start_dt, end_dt in peak_datetime_windows:
# TODO:这里可以添加更复杂的逻辑来处理跨天的平均时间
# 为简化,我们直接使用窗口的起止时间
peak_time_windows.append((start_dt.time(), end_dt.time()))
# 更新缓存
self._analysis_cache[chat_id] = (time_module.time(), peak_time_windows)
return peak_time_windows
def is_in_peak_time(self, chat_id: str, now: Optional[datetime] = None) -> bool:
"""
检查当前时间是否处于用户的高峰聊天时段内。
Args:
chat_id (str): 聊天标识符。
now (Optional[datetime]): 要检查的时间,默认为当前时间。
Returns:
bool: 如果处于高峰时段则返回 True否则返回 False。
"""
if now is None:
now = datetime.now()
now_time = now.time()
peak_times = self.get_peak_chat_times(chat_id)
for start_time, end_time in peak_times:
if start_time <= end_time: # 同一天
if start_time <= now_time <= end_time:
return True
else: # 跨天
if now_time >= start_time or now_time <= end_time:
return True
return False
# 创建一个全局单例
chat_frequency_analyzer = ChatFrequencyAnalyzer()

View File

@@ -0,0 +1,77 @@
import orjson
import time
from typing import Dict, List, Optional
from pathlib import Path
from src.common.logger import get_logger
# 数据存储路径
DATA_DIR = Path("data/frequency_analyzer")
DATA_DIR.mkdir(parents=True, exist_ok=True)
TRACKER_FILE = DATA_DIR / "chat_timestamps.json"
logger = get_logger("ChatFrequencyTracker")
class ChatFrequencyTracker:
"""
负责跟踪和存储用户聊天启动时间戳。
"""
def __init__(self):
self._timestamps: Dict[str, List[float]] = self._load_timestamps()
def _load_timestamps(self) -> Dict[str, List[float]]:
"""从本地文件加载时间戳数据。"""
if not TRACKER_FILE.exists():
return {}
try:
with open(TRACKER_FILE, "rb") as f:
data = orjson.loads(f.read())
logger.info(f"成功从 {TRACKER_FILE} 加载了聊天时间戳数据。")
return data
except orjson.JSONDecodeError:
logger.warning(f"无法解析 {TRACKER_FILE},将创建一个新的空数据文件。")
return {}
except Exception as e:
logger.error(f"加载聊天时间戳数据时发生未知错误: {e}")
return {}
def _save_timestamps(self):
"""将当前的时间戳数据保存到本地文件。"""
try:
with open(TRACKER_FILE, "wb") as f:
f.write(orjson.dumps(self._timestamps))
except Exception as e:
logger.error(f"保存聊天时间戳数据到 {TRACKER_FILE} 时失败: {e}")
def record_chat_start(self, chat_id: str):
"""
记录一次聊天会话的开始。
Args:
chat_id (str): 唯一的聊天标识符 (例如用户ID)。
"""
now = time.time()
if chat_id not in self._timestamps:
self._timestamps[chat_id] = []
self._timestamps[chat_id].append(now)
logger.debug(f"为 chat_id '{chat_id}' 记录了新的聊天时间: {now}")
self._save_timestamps()
def get_timestamps_for_chat(self, chat_id: str) -> Optional[List[float]]:
"""
获取指定聊天的所有时间戳记录。
Args:
chat_id (str): 聊天标识符。
Returns:
Optional[List[float]]: 时间戳列表,如果不存在则返回 None。
"""
return self._timestamps.get(chat_id)
# 创建一个全局单例
chat_frequency_tracker = ChatFrequencyTracker()

View File

@@ -0,0 +1,119 @@
"""
Frequency-Based Proactive Trigger
=================================
本模块实现了一个周期性任务,用于根据用户的聊天频率来智能地触发主动思考。
核心功能:
- 定期运行,检查所有已知的私聊用户。
- 调用 ChatFrequencyAnalyzer 判断当前是否处于用户的高峰聊天时段。
- 如果满足条件(高峰时段、角色清醒、聊天循环空闲),则触发一次主动思考。
- 包含冷却机制,以避免在同一个高峰时段内重复打扰用户。
可配置参数:
- TRIGGER_CHECK_INTERVAL_SECONDS: 触发器检查的周期(秒)。
- COOLDOWN_HOURS: 在同一个高峰时段内触发一次后的冷却时间(小时)。
"""
import asyncio
import time
from datetime import datetime
from typing import Dict, Optional
from src.common.logger import get_logger
from src.chat.chat_loop.proactive.events import ProactiveTriggerEvent
from src.chat.heart_flow.heartflow import heartflow
from src.chat.chat_loop.sleep_manager.sleep_manager import SleepManager
from .analyzer import chat_frequency_analyzer
logger = get_logger("FrequencyBasedTrigger")
# --- 可配置参数 ---
# 触发器检查周期(秒)
TRIGGER_CHECK_INTERVAL_SECONDS = 60 * 5 # 5分钟
# 冷却时间(小时),确保在一个高峰时段只触发一次
COOLDOWN_HOURS = 3
class FrequencyBasedTrigger:
"""
一个周期性任务,根据聊天频率分析结果来触发主动思考。
"""
def __init__(self, sleep_manager: SleepManager):
self._sleep_manager = sleep_manager
self._task: Optional[asyncio.Task] = None
# 记录上次为用户触发的时间,用于冷却控制
# 格式: { "chat_id": timestamp }
self._last_triggered: Dict[str, float] = {}
async def _run_trigger_cycle(self):
"""触发器的主要循环逻辑。"""
while True:
try:
await asyncio.sleep(TRIGGER_CHECK_INTERVAL_SECONDS)
logger.debug("开始执行频率触发器检查...")
# 1. 检查角色是否清醒
if self._sleep_manager.is_sleeping():
logger.debug("角色正在睡眠,跳过本次频率触发检查。")
continue
# 2. 获取所有已知的聊天ID
# 【注意】这里我们假设所有 subheartflow 的 ID 就是 chat_id
all_chat_ids = list(heartflow.subheartflows.keys())
if not all_chat_ids:
continue
now = datetime.now()
for chat_id in all_chat_ids:
# 3. 检查是否处于冷却时间内
last_triggered_time = self._last_triggered.get(chat_id, 0)
if time.time() - last_triggered_time < COOLDOWN_HOURS * 3600:
continue
# 4. 检查当前是否是该用户的高峰聊天时间
if chat_frequency_analyzer.is_in_peak_time(chat_id, now):
sub_heartflow = await heartflow.get_or_create_subheartflow(chat_id)
if not sub_heartflow:
logger.warning(f"无法为 {chat_id} 获取或创建 sub_heartflow。")
continue
# 5. 检查用户当前是否已有活跃的思考或回复任务
cycle_detail = sub_heartflow.heart_fc_instance.context.current_cycle_detail
if cycle_detail and not cycle_detail.end_time:
logger.debug(f"用户 {chat_id} 的聊天循环正忙(仍在周期 {cycle_detail.cycle_id} 中),本次不触发。")
continue
logger.info(f"检测到用户 {chat_id} 处于聊天高峰期,且聊天循环空闲,准备触发主动思考。")
# 6. 直接调用 proactive_thinker
event = ProactiveTriggerEvent(
source="frequency_analyzer",
reason=f"User is in a high-frequency chat period."
)
await sub_heartflow.heart_fc_instance.proactive_thinker.think(event)
# 7. 更新触发时间,进入冷却
self._last_triggered[chat_id] = time.time()
except asyncio.CancelledError:
logger.info("频率触发器任务被取消。")
break
except Exception as e:
logger.error(f"频率触发器循环发生未知错误: {e}", exc_info=True)
# 发生错误后,等待更长时间再重试,避免刷屏
await asyncio.sleep(TRIGGER_CHECK_INTERVAL_SECONDS * 2)
def start(self):
"""启动触发器任务。"""
if self._task is None or self._task.done():
self._task = asyncio.create_task(self._run_trigger_cycle())
logger.info("基于聊天频率的主动思考触发器已启动。")
def stop(self):
"""停止触发器任务。"""
if self._task and not self._task.done():
self._task.cancel()
logger.info("基于聊天频率的主动思考触发器已停止。")

View File

@@ -928,6 +928,7 @@ class EntorhinalCortex:
"concept": concept,
"memory_items": memory_items_json,
"hash": memory_hash,
"weight": 1.0, # 默认权重为1.0
"created_time": created_time,
"last_modified": last_modified,
}
@@ -1084,6 +1085,7 @@ class EntorhinalCortex:
"concept": concept,
"memory_items": memory_items_json,
"hash": self.hippocampus.calculate_node_hash(concept, memory_items),
"weight": 1.0, # 默认权重为1.0
"created_time": data.get("created_time", current_time),
"last_modified": data.get("last_modified", current_time),
}

View File

@@ -8,7 +8,7 @@ from datetime import datetime
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.chat.memory_system.Hippocampus import hippocampus_manager

View File

@@ -12,7 +12,7 @@ from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream
from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
from src.chat.message_receive.storage import MessageStorage
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.plugin_system.core import component_registry, event_manager, global_announcement_manager
from src.plugin_system.base import BaseCommand, EventType
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor

View File

@@ -85,6 +85,7 @@ class ChatStream:
self.context: ChatMessageContext = None # type: ignore # 用于存储该聊天的上下文信息
self.focus_energy = 1
self.no_reply_consecutive = 0
self.breaking_accumulated_interest = 0.0
def to_dict(self) -> dict:
"""转换为字典格式"""
@@ -97,6 +98,7 @@ class ChatStream:
"last_active_time": self.last_active_time,
"energy_value": self.energy_value,
"sleep_pressure": self.sleep_pressure,
"breaking_accumulated_interest": self.breaking_accumulated_interest,
}
@classmethod
@@ -257,7 +259,7 @@ class ChatManager:
"user_cardname": model_instance.user_cardname or "",
}
group_info_data = None
if model_instance.group_id:
if model_instance and getattr(model_instance, "group_id", None):
group_info_data = {
"platform": model_instance.group_platform,
"group_id": model_instance.group_id,
@@ -403,7 +405,7 @@ class ChatManager:
"user_cardname": model_instance.user_cardname or "",
}
group_info_data = None
if model_instance.group_id:
if model_instance and getattr(model_instance, "group_id", None):
group_info_data = {
"platform": model_instance.group_platform,
"group_id": model_instance.group_id,

View File

@@ -162,9 +162,7 @@ class ActionModifier:
available_actions = list(self.action_manager.get_using_actions().keys())
available_actions_text = "".join(available_actions) if available_actions else ""
logger.info(
f"{self.log_prefix} 当前可用动作: {available_actions_text}||移除: {removals_summary}"
)
logger.info(f"{self.log_prefix} 当前可用动作: {available_actions_text}||移除: {removals_summary}")
def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: ChatMessageContext):
type_mismatched_actions: List[Tuple[str, str]] = []

View File

@@ -1,7 +1,11 @@
import orjson
import time
import traceback
from typing import Dict, Any, Optional, Tuple, List
import asyncio
import math
import random
import json
from typing import Dict, Any, Optional, Tuple, List, TYPE_CHECKING
from rich.traceback import install
from datetime import datetime
from json_repair import repair_json
@@ -9,7 +13,7 @@ from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import (
build_readable_actions,
get_actions_by_timestamp_with_chat,
@@ -19,12 +23,20 @@ from src.chat.utils.chat_message_builder import (
from src.chat.utils.utils import get_chat_type_and_target_info
from src.chat.planner_actions.action_manager import ActionManager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.plugin_system.base.component_types import ActionInfo, ChatMode, ComponentType
from src.plugin_system.base.component_types import (
ActionInfo,
ChatMode,
ComponentType,
ActionActivationType,
)
from src.plugin_system.core.component_registry import component_registry
from src.schedule.schedule_manager import schedule_manager
from src.mood.mood_manager import mood_manager
from src.chat.memory_system.Hippocampus import hippocampus_manager
if TYPE_CHECKING:
pass
logger = get_logger("planner")
install(extra_lines=3)
@@ -83,12 +95,34 @@ def init_prompt():
## 长期记忆摘要
{long_term_memory_block}
## 最近的聊天内容
{chat_content_block}
## 最近的动作历史
{actions_before_now_block}
## 任务
基于以上所有信息,分析当前情况,决定是否需要主动做些什么
如果你认为不需要,就选择 'do_nothing'
基于以上所有信息(特别是最近的聊天内容),分析当前情况,决定是否适合主动开启一个**新的、但又与当前氛围相关**的话题
## 可用动作
{action_options_text}
动作proactive_reply
动作描述:在当前对话的基础上,主动发起一个新的对话,分享一个有趣的想法、见闻或者对未来的计划。
- 当你觉得可以说些什么来活跃气氛,并且内容与当前聊天氛围不冲突时
- 当你有一些新的想法或计划想要分享,并且可以自然地衔接当前话题时
{{
"action": "proactive_reply",
"reason": "决定主动发起对话的具体原因",
"topic": "你想要发起对话的主题或内容(需要简洁)"
}}
动作do_nothing
动作描述:保持沉默,不主动发起任何动作或对话。
- 当你分析了所有信息后,觉得当前不是一个发起互动的好时机时
- 当最近的聊天内容很连贯,你的插入会打断别人时
{{
"action": "do_nothing",
"reason":"决定保持沉默的具体原因"
}}
你必须从上面列出的可用action中选择一个。
请以严格的 JSON 格式输出,且仅包含 JSON 内容:
@@ -110,6 +144,37 @@ def init_prompt():
"action_prompt",
)
Prompt(
"""
{name_block}
{chat_context_description}{time_block}现在请你根据以下聊天内容选择一个或多个合适的action。如果没有合适的action请选择no_action。,
{chat_content_block}
**要求**
1.action必须符合使用条件如果符合条件就选择
2.如果聊天内容不适合使用action即使符合条件也不要使用
3.{moderation_prompt}
4.请注意如果相同的内容已经被执行,请不要重复执行
这是你最近执行过的动作:
{actions_before_now_block}
**可用的action**
no_action不选择任何动作
{{
"action": "no_action",
"reason":"不动作的原因"
}}
{action_options_text}
请选择并说明触发action的消息id和选择该action的原因。消息id格式:m+数字
请根据动作示例,以严格的 JSON 格式输出,且仅包含 JSON 内容:
""",
"sub_planner_prompt",
)
class ActionPlanner:
def __init__(self, chat_id: str, action_manager: ActionManager):
@@ -117,14 +182,16 @@ class ActionPlanner:
self.log_prefix = f"[{get_chat_manager().get_stream_name(chat_id) or chat_id}]"
self.action_manager = action_manager
# LLM规划器配置
# --- 大脑 ---
self.planner_llm = LLMRequest(
model_set=model_config.model_task_config.planner, request_type="planner"
) # 用于动作规划
)
# --- 小脑 (新增) ---
self.planner_small_llm = LLMRequest(
model_set=model_config.model_task_config.planner_small, request_type="planner_small"
)
self.last_obs_time_mark = 0.0
# 添加重试计数器
self.plan_retry_count = 0
self.max_plan_retries = 3
async def _get_long_term_memory_context(self) -> str:
"""
@@ -171,32 +238,18 @@ class ActionPlanner:
构建动作选项
"""
action_options_block = ""
if mode == ChatMode.PROACTIVE:
action_options_block += """动作do_nothing
动作描述:保持沉默,不主动发起任何动作或对话。
- 当你分析了所有信息后,觉得当前不是一个发起互动的好时机时
{{
"action": "do_nothing",
"reason":"决定保持沉默的具体原因"
}}
"""
for action_name, action_info in current_available_actions.items():
# TODO: 增加一个字段来判断action是否支持在PROACTIVE模式下使用
param_text = ""
if action_info.action_parameters:
param_text = "\n" + "\n".join(
f' "{p_name}":"{p_desc}"'
for p_name, p_desc in action_info.action_parameters.items()
f' "{p_name}":"{p_desc}"' for p_name, p_desc in action_info.action_parameters.items()
)
require_text = "\n".join(f"- {req}" for req in action_info.action_require)
using_action_prompt = await global_prompt_manager.get_prompt_async(
"action_prompt"
)
using_action_prompt = await global_prompt_manager.get_prompt_async("action_prompt")
action_options_block += using_action_prompt.format(
action_name=action_name,
action_description=action_info.description,
@@ -205,9 +258,7 @@ class ActionPlanner:
)
return action_options_block
def find_message_by_id(
self, message_id: str, message_id_list: list
) -> Optional[Dict[str, Any]]:
def find_message_by_id(self, message_id: str, message_id_list: list) -> Optional[Dict[str, Any]]:
# sourcery skip: use-next
"""
根据message_id从message_id_list中查找对应的原始消息
@@ -242,6 +293,168 @@ class ActionPlanner:
# 假设消息列表是按时间顺序排列的,最后一个是最新的
return message_id_list[-1].get("message")
def _parse_single_action(
self,
action_json: dict,
message_id_list: list, # 使用 planner.py 的 list of dict
current_available_actions: list, # 使用 planner.py 的 list of tuple
) -> List[Dict[str, Any]]:
"""
[注释] 解析单个小脑LLM返回的action JSON并将其转换为标准化的字典。
"""
parsed_actions = []
try:
action = action_json.get("action", "no_action")
reasoning = action_json.get("reason", "未提供原因")
action_data = {k: v for k, v in action_json.items() if k not in ["action", "reason"]}
target_message = None
if action != "no_action":
if target_message_id := action_json.get("target_message_id"):
target_message = self.find_message_by_id(target_message_id, message_id_list)
if target_message is None:
logger.warning(f"{self.log_prefix}无法找到target_message_id '{target_message_id}'")
target_message = self.get_latest_message(message_id_list)
else:
logger.warning(f"{self.log_prefix}动作'{action}'缺少target_message_id")
available_action_names = [name for name, _ in current_available_actions]
if action not in ["no_action", "reply"] and action not in available_action_names:
logger.warning(
f"{self.log_prefix}LLM 返回了当前不可用或无效的动作: '{action}' (可用: {available_action_names}),将强制使用 'no_action'"
)
reasoning = f"LLM 返回了当前不可用的动作 '{action}' (可用: {available_action_names})。原始理由: {reasoning}"
action = "no_action"
# 将列表转换为字典格式以供将来使用
available_actions_dict = dict(current_available_actions)
parsed_actions.append(
{
"action_type": action,
"reasoning": reasoning,
"action_data": action_data,
"action_message": target_message,
"available_actions": available_actions_dict,
}
)
except Exception as e:
logger.error(f"{self.log_prefix}解析单个action时出错: {e}")
parsed_actions.append(
{
"action_type": "no_action",
"reasoning": f"解析action时出错: {e}",
"action_data": {},
"action_message": None,
"available_actions": dict(current_available_actions),
}
)
return parsed_actions
def _filter_no_actions(self, action_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
[注释] 从一个action字典列表中过滤掉所有的 'no_action'
如果过滤后列表为空, 则返回一个空的列表, 或者根据需要返回一个默认的no_action字典。
"""
non_no_actions = [a for a in action_list if a.get("action_type") not in ["no_action", "no_reply"]]
if non_no_actions:
return non_no_actions
# 如果都是 no_action则返回一个包含第一个 no_action 的列表,以保留 reason
return action_list[:1] if action_list else []
async def sub_plan(
self,
action_list: list, # 使用 planner.py 的 list of tuple
chat_content_block: str,
message_id_list: list, # 使用 planner.py 的 list of dict
is_group_chat: bool = False,
chat_target_info: Optional[dict] = None,
) -> List[Dict[str, Any]]:
"""
[注释] "小脑"规划器。接收一小组actions使用轻量级LLM判断其中哪些应该被触发。
这是一个独立的、并行的思考单元。返回一个包含action字典的列表。
"""
try:
actions_before_now = get_actions_by_timestamp_with_chat(
chat_id=self.chat_id,
timestamp_start=time.time() - 1200,
timestamp_end=time.time(),
limit=20,
)
action_names_in_list = [name for name, _ in action_list]
filtered_actions = [
record for record in actions_before_now if record.get("action_name") in action_names_in_list
]
actions_before_now_block = build_readable_actions(actions=filtered_actions)
chat_context_description = "你现在正在一个群聊中"
if not is_group_chat and chat_target_info:
chat_target_name = chat_target_info.get("person_name") or chat_target_info.get("user_nickname") or "对方"
chat_context_description = f"你正在和 {chat_target_name} 私聊"
action_options_block = ""
for using_actions_name, using_actions_info in action_list:
param_text = ""
if using_actions_info.action_parameters:
param_text = "\n" + "\n".join(
f' "{p_name}":"{p_desc}"'
for p_name, p_desc in using_actions_info.action_parameters.items()
)
require_text = "\n".join(f"- {req}" for req in using_actions_info.action_require)
using_action_prompt = await global_prompt_manager.get_prompt_async("action_prompt")
action_options_block += using_action_prompt.format(
action_name=using_actions_name,
action_description=using_actions_info.description,
action_parameters=param_text,
action_require=require_text,
)
moderation_prompt_block = "请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。"
time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
bot_name = global_config.bot.nickname
bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}" if global_config.bot.alias_names else ""
name_block = f"你的名字是{bot_name}{bot_nickname},请注意哪些是你自己的发言。"
planner_prompt_template = await global_prompt_manager.get_prompt_async("sub_planner_prompt")
prompt = planner_prompt_template.format(
time_block=time_block,
chat_context_description=chat_context_description,
chat_content_block=chat_content_block,
actions_before_now_block=actions_before_now_block,
action_options_text=action_options_block,
moderation_prompt=moderation_prompt_block,
name_block=name_block,
)
except Exception as e:
logger.error(f"构建小脑提示词时出错: {e}\n{traceback.format_exc()}")
return [{"action_type": "no_action", "reasoning": f"构建小脑Prompt时出错: {e}"}]
action_dicts: List[Dict[str, Any]] = []
try:
llm_content, (reasoning_content, _, _) = await self.planner_small_llm.generate_response_async(prompt=prompt)
if global_config.debug.show_prompt:
logger.info(f"{self.log_prefix}小脑原始提示词: {prompt}")
logger.info(f"{self.log_prefix}小脑原始响应: {llm_content}")
else:
logger.debug(f"{self.log_prefix}小脑原始响应: {llm_content}")
if llm_content:
parsed_json = orjson.loads(repair_json(llm_content))
if isinstance(parsed_json, list):
for item in parsed_json:
if isinstance(item, dict):
action_dicts.extend(self._parse_single_action(item, message_id_list, action_list))
elif isinstance(parsed_json, dict):
action_dicts.extend(self._parse_single_action(parsed_json, message_id_list, action_list))
except Exception as e:
logger.warning(f"{self.log_prefix}解析小脑响应JSON失败: {e}. LLM原始输出: '{llm_content}'")
action_dicts.append({"action_type": "no_action", "reasoning": f"解析小脑响应失败: {e}"})
if not action_dicts:
action_dicts.append({"action_type": "no_action", "reasoning": "小脑未返回有效action"})
return action_dicts
async def plan(
self,
mode: ChatMode = ChatMode.FOCUS,
@@ -249,172 +462,201 @@ class ActionPlanner:
available_actions: Optional[Dict[str, ActionInfo]] = None,
) -> Tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
"""
规划器 (Planner): 使用LLM根据上下文决定做出什么动作
[注释] "大脑"规划器
1. 启动多个并行的"小脑"(sub_plan)来决定是否执行具体的actions。
2. 自己(大脑)则专注于决定是否进行聊天回复(reply)。
3. 整合大脑和小脑的决策,返回最终要执行的动作列表。
"""
action = "no_reply" # 默认动作
reasoning = "规划器初始化默认"
action_data = {}
current_available_actions: Dict[str, ActionInfo] = {}
target_message: Optional[Dict[str, Any]] = None # 初始化target_message变量
prompt: str = ""
message_id_list: list = []
try:
is_group_chat, chat_target_info, current_available_actions = self.get_necessary_info()
# --- 构建提示词 (调用修改后的 PromptBuilder 方法) ---
prompt, message_id_list = await self.build_planner_prompt(
is_group_chat=is_group_chat, # <-- Pass HFC state
chat_target_info=chat_target_info, # <-- 传递获取到的聊天目标信息
current_available_actions=current_available_actions, # <-- Pass determined actions
mode=mode,
# --- 1. 准备上下文信息 ---
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
chat_id=self.chat_id,
timestamp=time.time(),
limit=int(global_config.chat.max_context_size * 0.6),
)
# 大脑使用较长的上下文
chat_content_block, message_id_list = build_readable_messages_with_id(
messages=message_list_before_now,
timestamp_mode="normal",
read_mark=self.last_obs_time_mark,
truncate=True,
show_actions=True,
)
# 小脑使用较短、较新的上下文
message_list_before_now_short = message_list_before_now[-int(global_config.chat.max_context_size * 0.3) :]
chat_content_block_short, message_id_list_short = build_readable_messages_with_id(
messages=message_list_before_now_short,
timestamp_mode="normal",
truncate=False,
show_actions=False,
)
self.last_obs_time_mark = time.time()
# --- 调用 LLM (普通文本生成) ---
llm_content = None
is_group_chat, chat_target_info, current_available_actions = self.get_necessary_info()
if available_actions is None:
available_actions = current_available_actions
# --- 2. 启动小脑并行思考 ---
all_sub_planner_results: List[Dict[str, Any]] = []
try:
llm_content, (reasoning_content, _, _) = await self.planner_llm.generate_response_async(prompt=prompt)
sub_planner_actions: Dict[str, ActionInfo] = {}
for action_name, action_info in available_actions.items():
if global_config.debug.show_prompt:
logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
logger.info(f"{self.log_prefix}规划器原始响应: {llm_content}")
if reasoning_content:
logger.info(f"{self.log_prefix}规划器推理: {reasoning_content}")
else:
logger.debug(f"{self.log_prefix}规划器原始提示词: {prompt}")
logger.debug(f"{self.log_prefix}规划器原始响应: {llm_content}")
if reasoning_content:
logger.debug(f"{self.log_prefix}规划器推理: {reasoning_content}")
if action_info.activation_type in [ActionActivationType.LLM_JUDGE, ActionActivationType.ALWAYS]:
sub_planner_actions[action_name] = action_info
elif action_info.activation_type == ActionActivationType.RANDOM:
if random.random() < action_info.random_activation_probability:
sub_planner_actions[action_name] = action_info
elif action_info.activation_type == ActionActivationType.KEYWORD:
if any(keyword in chat_content_block_short for keyword in action_info.activation_keywords):
sub_planner_actions[action_name] = action_info
except Exception as req_e:
logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}")
reasoning = f"LLM 请求失败,模型出现问题: {req_e}"
action = "no_reply"
if sub_planner_actions:
sub_planner_actions_num = len(sub_planner_actions)
planner_size_config = global_config.chat.planner_size
sub_planner_size = int(planner_size_config) + (
1 if random.random() < planner_size_config - int(planner_size_config) else 0
)
sub_planner_num = math.ceil(sub_planner_actions_num / sub_planner_size)
logger.info(f"{self.log_prefix}使用{sub_planner_num}个小脑进行思考 (尺寸: {sub_planner_size})")
action_items = list(sub_planner_actions.items())
random.shuffle(action_items)
sub_planner_lists = [action_items[i::sub_planner_num] for i in range(sub_planner_num)]
sub_plan_tasks = [
self.sub_plan(
action_list=action_group,
chat_content_block=chat_content_block_short,
message_id_list=message_id_list_short,
is_group_chat=is_group_chat,
chat_target_info=chat_target_info,
)
for action_group in sub_planner_lists
]
sub_plan_results = await asyncio.gather(*sub_plan_tasks)
for sub_result in sub_plan_results:
all_sub_planner_results.extend(sub_result)
sub_actions_str = ", ".join(
a["action_type"] for a in all_sub_planner_results if a["action_type"] != "no_action"
) or "no_action"
logger.info(f"{self.log_prefix}小脑决策: [{sub_actions_str}]")
except Exception as e:
logger.error(f"{self.log_prefix}小脑调度过程中出错: {e}\n{traceback.format_exc()}")
# --- 3. 大脑独立思考是否回复 ---
action, reasoning, action_data, target_message = "no_reply", "大脑初始化默认", {}, None
try:
prompt, _ = await self.build_planner_prompt(
is_group_chat=is_group_chat,
chat_target_info=chat_target_info,
current_available_actions={},
mode=mode,
chat_content_block_override=chat_content_block,
message_id_list_override=message_id_list,
)
llm_content, _ = await self.planner_llm.generate_response_async(prompt=prompt)
if llm_content:
try:
parsed_json = orjson.loads(repair_json(llm_content))
if isinstance(parsed_json, list):
if parsed_json:
parsed_json = parsed_json[-1]
logger.warning(f"{self.log_prefix}LLM返回了多个JSON对象使用最后一个: {parsed_json}")
else:
parsed_json = {}
if not isinstance(parsed_json, dict):
logger.error(f"{self.log_prefix}解析后的JSON不是字典类型: {type(parsed_json)}")
parsed_json = {}
parsed_json = parsed_json[-1] if isinstance(parsed_json, list) and parsed_json else parsed_json
if isinstance(parsed_json, dict):
action = parsed_json.get("action", "no_reply")
reasoning = parsed_json.get("reason", "未提供原因")
# 将所有其他属性添加到action_data
for key, value in parsed_json.items():
if key not in ["action", "reason"]:
action_data[key] = value
# 非no_reply动作需要target_message_id
action_data = {k: v for k, v in parsed_json.items() if k not in ["action", "reason"]}
if action != "no_reply":
if target_message_id := parsed_json.get("target_message_id"):
# 根据target_message_id查找原始消息
target_message = self.find_message_by_id(target_message_id, message_id_list)
# 如果获取的target_message为None输出warning并重新plan
if target_message is None:
self.plan_retry_count += 1
logger.warning(f"{self.log_prefix}无法找到target_message_id '{target_message_id}' 对应的消息,重试次数: {self.plan_retry_count}/{self.max_plan_retries}")
# 如果连续三次plan均为None输出error并选取最新消息
if self.plan_retry_count >= self.max_plan_retries:
logger.error(f"{self.log_prefix}连续{self.max_plan_retries}次plan获取target_message失败选择最新消息作为target_message")
if target_id := parsed_json.get("target_message_id"):
target_message = self.find_message_by_id(target_id, message_id_list)
if not target_message:
target_message = self.get_latest_message(message_id_list)
self.plan_retry_count = 0 # 重置计数器
else:
# 递归重新plan
return await self.plan(mode, loop_start_time, available_actions)
else:
# 成功获取到target_message重置计数器
self.plan_retry_count = 0
else:
logger.warning(f"{self.log_prefix}动作'{action}'缺少target_message_id")
logger.info(f"{self.log_prefix}大脑决策: [{action}]")
except Exception as e:
logger.error(f"{self.log_prefix}大脑处理过程中发生意外错误: {e}\n{traceback.format_exc()}")
action, reasoning = "no_reply", f"大脑处理错误: {e}"
if action != "no_reply" and action != "reply" and action not in current_available_actions:
logger.warning(
f"{self.log_prefix}LLM 返回了当前不可用或无效的动作: '{action}' (可用: {list(current_available_actions.keys())}),将强制使用 'no_reply'"
)
reasoning = f"LLM 返回了当前不可用的动作 '{action}' (可用: {list(current_available_actions.keys())})。原始理由: {reasoning}"
action = "no_reply"
except Exception as json_e:
logger.warning(f"{self.log_prefix}解析LLM响应JSON失败 {json_e}. LLM原始输出: '{llm_content}'")
traceback.print_exc()
reasoning = f"解析LLM响应JSON失败: {json_e}. 将使用默认动作 'no_reply'."
action = "no_reply"
except Exception as outer_e:
logger.error(f"{self.log_prefix}Planner 处理过程中发生意外错误,规划失败,将执行 no_reply: {outer_e}")
traceback.print_exc()
action = "no_reply"
reasoning = f"Planner 内部处理错误: {outer_e}"
# --- 4. 整合大脑和小脑的决策 ---
# 如果是私聊且开启了强制回复则将no_reply强制改为reply
if not is_group_chat and global_config.chat.force_reply_private and action == "no_reply":
action = "reply"
reasoning = "私聊强制回复"
logger.info(f"{self.log_prefix}私聊强制回复已触发,将动作从 'no_reply' 修改为 'reply'")
is_parallel = True
for info in all_sub_planner_results:
action_type = info.get("action_type")
if action_type and action_type not in ["no_action", "no_reply"]:
action_info = available_actions.get(action_type)
if action_info and not action_info.parallel_action:
is_parallel = False
if mode == ChatMode.NORMAL and action in current_available_actions:
is_parallel = current_available_actions[action].parallel_action
break
action_data["loop_start_time"] = loop_start_time
final_actions: List[Dict[str, Any]] = []
actions = []
# 1. 添加Planner取得的动作
actions.append({
if is_parallel:
logger.info(f"{self.log_prefix}决策模式: 大脑与小脑并行")
if action not in ["no_action", "no_reply"]:
final_actions.append(
{
"action_type": action,
"reasoning": reasoning,
"action_data": action_data,
"action_message": target_message,
"available_actions": available_actions # 添加这个字段
})
"available_actions": available_actions,
}
)
final_actions.extend(all_sub_planner_results)
else:
logger.info(f"{self.log_prefix}决策模式: 小脑优先 (检测到非并行action)")
final_actions.extend(all_sub_planner_results)
if action != "reply" and is_parallel:
actions.append({
"action_type": "reply",
"action_message": target_message,
"available_actions": available_actions
})
final_actions = self._filter_no_actions(final_actions)
return actions,target_message
if not final_actions:
final_actions = [
{
"action_type": "no_action",
"reasoning": "所有规划器都选择不执行动作",
"action_data": {}, "action_message": None, "available_actions": available_actions
}
]
final_target_message = target_message
if not final_target_message and final_actions:
final_target_message = next((act.get("action_message") for act in final_actions if act.get("action_message")), None)
actions_str = ", ".join([a.get('action_type', 'N/A') for a in final_actions])
logger.info(f"{self.log_prefix}最终执行动作 ({len(final_actions)}): [{actions_str}]")
return final_actions, final_target_message
async def build_planner_prompt(
self,
is_group_chat: bool, # Now passed as argument
chat_target_info: Optional[dict], # Now passed as argument
is_group_chat: bool,
chat_target_info: Optional[dict],
current_available_actions: Dict[str, ActionInfo],
refresh_time :bool = False,
mode: ChatMode = ChatMode.FOCUS,
) -> tuple[str, list]: # sourcery skip: use-join
chat_content_block_override: Optional[str] = None,
message_id_list_override: Optional[List] = None,
refresh_time: bool = False, # 添加缺失的参数
) -> tuple[str, list]:
"""构建 Planner LLM 的提示词 (获取模板并填充数据)"""
try:
# --- 通用信息获取 ---
time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
bot_name = global_config.bot.nickname
bot_nickname = (
f",也有人叫你{','.join(global_config.bot.alias_names)}"
if global_config.bot.alias_names
else ""
f",也有人叫你{','.join(global_config.bot.alias_names)}" if global_config.bot.alias_names else ""
)
bot_core_personality = global_config.personality.personality_core
identity_block = (
f"你的名字是{bot_name}{bot_nickname},你{bot_core_personality}"
)
identity_block = f"你的名字是{bot_name}{bot_nickname},你{bot_core_personality}"
schedule_block = ""
if global_config.schedule.enable:
if global_config.planning_system.schedule_enable:
if current_activity := schedule_manager.get_current_activity():
schedule_block = (
f"你当前正在:{current_activity},但注意它与群聊的聊天无关。"
)
schedule_block = f"你当前正在:{current_activity},但注意它与群聊的聊天无关。"
mood_block = ""
if global_config.mood.enable_mood:
@@ -424,20 +666,38 @@ class ActionPlanner:
# --- 根据模式构建不同的Prompt ---
if mode == ChatMode.PROACTIVE:
long_term_memory_block = await self._get_long_term_memory_context()
action_options_text = await self._build_action_options(
current_available_actions, mode
# 获取最近的聊天记录用于主动思考决策
message_list_short = get_raw_msg_before_timestamp_with_chat(
chat_id=self.chat_id,
timestamp=time.time(),
limit=int(global_config.chat.max_context_size * 0.2), # 主动思考时只看少量最近消息
)
chat_content_block, _ = build_readable_messages_with_id(
messages=message_list_short,
timestamp_mode="normal",
truncate=False,
show_actions=False,
)
prompt_template = await global_prompt_manager.get_prompt_async(
"proactive_planner_prompt"
prompt_template = await global_prompt_manager.get_prompt_async("proactive_planner_prompt")
actions_before_now = get_actions_by_timestamp_with_chat(
chat_id=self.chat_id,
timestamp_start=time.time() - 3600,
timestamp_end=time.time(),
limit=5,
)
actions_before_now_block = build_readable_actions(actions=actions_before_now)
actions_before_now_block = f"你刚刚选择并执行过的action是\n{actions_before_now_block}"
prompt = prompt_template.format(
time_block=time_block,
identity_block=identity_block,
schedule_block=schedule_block,
mood_block=mood_block,
long_term_memory_block=long_term_memory_block,
action_options_text=action_options_text,
chat_content_block=chat_content_block or "最近没有聊天内容。",
actions_before_now_block=actions_before_now_block,
)
return prompt, []
@@ -463,12 +723,8 @@ class ActionPlanner:
limit=5,
)
actions_before_now_block = build_readable_actions(
actions=actions_before_now
)
actions_before_now_block = (
f"你刚刚选择并执行过的action是\n{actions_before_now_block}"
)
actions_before_now_block = build_readable_actions(actions=actions_before_now)
actions_before_now_block = f"你刚刚选择并执行过的action是\n{actions_before_now_block}"
if refresh_time:
self.last_obs_time_mark = time.time()
@@ -507,27 +763,19 @@ class ActionPlanner:
chat_target_name = None
if not is_group_chat and chat_target_info:
chat_target_name = (
chat_target_info.get("person_name")
or chat_target_info.get("user_nickname")
or "对方"
chat_target_info.get("person_name") or chat_target_info.get("user_nickname") or "对方"
)
chat_context_description = f"你正在和 {chat_target_name} 私聊"
action_options_block = await self._build_action_options(
current_available_actions, mode
)
action_options_block = await self._build_action_options(current_available_actions, mode)
moderation_prompt_block = "请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。"
custom_prompt_block = ""
if global_config.custom_prompt.planner_custom_prompt_content:
custom_prompt_block = (
global_config.custom_prompt.planner_custom_prompt_content
)
custom_prompt_block = global_config.custom_prompt.planner_custom_prompt_content
planner_prompt_template = await global_prompt_manager.get_prompt_async(
"planner_prompt"
)
planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt")
prompt = planner_prompt_template.format(
schedule_block=schedule_block,
mood_block=mood_block,
@@ -555,9 +803,7 @@ class ActionPlanner:
"""
is_group_chat = True
is_group_chat, chat_target_info = get_chat_type_and_target_info(self.chat_id)
logger.debug(
f"{self.log_prefix}获取到聊天信息 - 群聊: {is_group_chat}, 目标信息: {chat_target_info}"
)
logger.debug(f"{self.log_prefix}获取到聊天信息 - 群聊: {is_group_chat}, 目标信息: {chat_target_info}")
current_available_actions_dict = self.action_manager.get_using_actions()
@@ -568,13 +814,9 @@ class ActionPlanner:
current_available_actions = {}
for action_name in current_available_actions_dict:
if action_name in all_registered_actions:
current_available_actions[action_name] = all_registered_actions[
action_name
]
current_available_actions[action_name] = all_registered_actions[action_name]
else:
logger.warning(
f"{self.log_prefix}使用中的动作 {action_name} 未在已注册动作中找到"
)
logger.warning(f"{self.log_prefix}使用中的动作 {action_name} 未在已注册动作中找到")
# 将no_reply作为系统级特殊动作添加到可用动作中
# no_reply虽然是系统级决策但需要让规划器认为它是可用的

View File

@@ -1,6 +1,6 @@
"""
默认回复生成器 - 集成SmartPrompt系统
使用重构后的SmartPrompt系统替换原有的复杂提示词构建逻辑
默认回复生成器 - 集成统一Prompt系统
使用重构后的统一Prompt系统替换原有的复杂提示词构建逻辑
"""
import traceback
@@ -11,11 +11,9 @@ import re
from typing import List, Optional, Dict, Any, Tuple
from datetime import datetime
from src.chat.utils.prompt_utils import PromptUtils
from src.mais4u.mai_think import mai_thinking_manager
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.config.api_ada_configs import TaskConfig
from src.individuality.individuality import get_individuality
from src.llm_models.utils_model import LLMRequest
from src.chat.message_receive.message import UserInfo, Seg, MessageRecv, MessageSending
@@ -23,7 +21,7 @@ from src.chat.message_receive.chat_stream import ChatStream
from src.chat.message_receive.uni_message_sender import HeartFCSender
from src.chat.utils.timer_calculator import Timer
from src.chat.utils.utils import get_chat_type_and_target_info
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import (
build_readable_messages,
get_raw_msg_before_timestamp_with_chat,
@@ -37,10 +35,9 @@ from src.person_info.relationship_fetcher import relationship_fetcher_manager
from src.person_info.person_info import get_person_info_manager
from src.plugin_system.base.component_types import ActionInfo, EventType
from src.plugin_system.apis import llm_api
from src.schedule.schedule_manager import schedule_manager
# 导入新的智能Prompt系统
from src.chat.utils.smart_prompt import SmartPrompt, SmartPromptParameters
# 导入新的统一Prompt系统
from src.chat.utils.prompt import PromptParameters
logger = get_logger("replyer")
@@ -286,6 +283,7 @@ class DefaultReplyer:
return False, None, None
from src.plugin_system.core.event_manager import event_manager
# 触发 POST_LLM 事件(请求 LLM 之前)
if not from_plugin:
result = await event_manager.trigger_event(
EventType.POST_LLM, plugin_name="SYSTEM", prompt=prompt, stream_id=stream_id
@@ -307,6 +305,7 @@ class DefaultReplyer:
"model": model_name,
"tool_calls": tool_call,
}
# 触发 AFTER_LLM 事件
if not from_plugin:
result = await event_manager.trigger_event(
@@ -600,7 +599,8 @@ class DefaultReplyer:
def _parse_reply_target(self, target_message: str) -> Tuple[str, str]:
"""解析回复目标消息 - 使用共享工具"""
return PromptUtils.parse_reply_target(target_message)
from src.chat.utils.prompt import Prompt
return Prompt.parse_reply_target(target_message)
async def build_keywords_reaction_prompt(self, target: Optional[str]) -> str:
"""构建关键词反应提示
@@ -830,7 +830,7 @@ class DefaultReplyer:
)
person_name = await person_info_manager.get_value(person_id, "person_name")
sender = person_name
target = reply_message.get('processed_plain_text')
target = reply_message.get("processed_plain_text")
person_info_manager = get_person_info_manager()
person_id = person_info_manager.get_person_id_by_person_name(sender)
@@ -876,6 +876,7 @@ class DefaultReplyer:
if sender:
target_user_info = await person_info_manager.get_person_info_by_name(sender)
from src.chat.utils.prompt import Prompt
# 并行执行六个构建任务
task_results = await asyncio.gather(
self._time_and_run_task(
@@ -888,7 +889,7 @@ class DefaultReplyer:
),
self._time_and_run_task(self.get_prompt_info(chat_talking_prompt_short, sender, target), "prompt_info"),
self._time_and_run_task(
PromptUtils.build_cross_context(chat_id, target_user_info, global_config.personality.prompt_mode),
Prompt.build_cross_context(chat_id, global_config.personality.prompt_mode, target_user_info),
"cross_context",
),
)
@@ -939,7 +940,8 @@ class DefaultReplyer:
identity_block = await get_individuality().get_personality_block()
schedule_block = ""
if global_config.schedule.enable:
if global_config.planning_system.schedule_enable:
from src.schedule.schedule_manager import schedule_manager
current_activity = schedule_manager.get_current_activity()
if current_activity:
schedule_block = f"你当前正在:{current_activity}"
@@ -971,8 +973,8 @@ class DefaultReplyer:
# 根据配置选择模板
current_prompt_mode = global_config.personality.prompt_mode
# 使用重构后的SmartPrompt系统
prompt_params = SmartPromptParameters(
# 使用新的统一Prompt系统 - 创建PromptParameters
prompt_parameters = PromptParameters(
chat_id=chat_id,
is_group_chat=is_group_chat,
sender=sender,
@@ -1005,12 +1007,19 @@ class DefaultReplyer:
action_descriptions=action_descriptions,
)
# 使用重构后的SmartPrompt系统
smart_prompt = SmartPrompt(
template_name=None, # 由current_prompt_mode自动选择
parameters=prompt_params,
)
prompt_text = await smart_prompt.build_prompt()
# 使用新的统一Prompt系统 - 使用正确的模板名称
template_name = None
if current_prompt_mode == "s4u":
template_name = "s4u_style_prompt"
elif current_prompt_mode == "normal":
template_name = "normal_style_prompt"
elif current_prompt_mode == "minimal":
template_name = "default_expressor_prompt"
# 获取模板内容
template_prompt = await global_prompt_manager.get_prompt_async(template_name)
prompt = Prompt(template=template_prompt.template, parameters=prompt_parameters)
prompt_text = await prompt.build()
return prompt_text
@@ -1111,8 +1120,8 @@ class DefaultReplyer:
template_name = "default_expressor_prompt"
# 使用重构后的SmartPrompt系统 - Expressor模式
prompt_params = SmartPromptParameters(
# 使用新的统一Prompt系统 - Expressor模式创建PromptParameters
prompt_parameters = PromptParameters(
chat_id=chat_id,
is_group_chat=is_group_chat,
sender=sender,
@@ -1132,8 +1141,10 @@ class DefaultReplyer:
relation_info_block=relation_info,
)
smart_prompt = SmartPrompt(parameters=prompt_params)
prompt_text = await smart_prompt.build_prompt()
# 使用新的统一Prompt系统 - Expressor模式
template_prompt = await global_prompt_manager.get_prompt_async("default_expressor_prompt")
prompt = Prompt(template=template_prompt.template, parameters=prompt_parameters)
prompt_text = await prompt.build()
return prompt_text
@@ -1181,7 +1192,9 @@ class DefaultReplyer:
else:
logger.debug(f"\n{prompt}\n")
content, (reasoning_content, model_name, tool_calls) = await self.express_model.generate_response_async(prompt)
content, (reasoning_content, model_name, tool_calls) = await self.express_model.generate_response_async(
prompt
)
logger.debug(f"replyer生成内容: {content}")
return content, reasoning_content, model_name, tool_calls

View File

@@ -1,7 +1,6 @@
from typing import Dict, Optional, List, Tuple
from typing import Dict, Optional
from src.common.logger import get_logger
from src.config.api_ada_configs import TaskConfig
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.replyer.default_generator import DefaultReplyer

823
src/chat/utils/prompt.py Normal file
View File

@@ -0,0 +1,823 @@
"""
统一提示词系统 - 合并模板管理和智能构建功能
将原有的Prompt类和SmartPrompt功能整合为一个真正的Prompt类
"""
import re
import asyncio
import time
import contextvars
from dataclasses import dataclass, field
from typing import Dict, Any, Optional, List, Literal, Tuple
from contextlib import asynccontextmanager
from rich.traceback import install
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.utils.chat_message_builder import build_readable_messages
from src.chat.message_receive.chat_stream import get_chat_manager
from src.person_info.person_info import get_person_info_manager
install(extra_lines=3)
logger = get_logger("unified_prompt")
@dataclass
class PromptParameters:
"""统一提示词参数系统"""
# 基础参数
chat_id: str = ""
is_group_chat: bool = False
sender: str = ""
target: str = ""
reply_to: str = ""
extra_info: str = ""
prompt_mode: Literal["s4u", "normal", "minimal"] = "s4u"
# 功能开关
enable_tool: bool = True
enable_memory: bool = True
enable_expression: bool = True
enable_relation: bool = True
enable_cross_context: bool = True
enable_knowledge: bool = True
# 性能控制
max_context_messages: int = 50
# 调试选项
debug_mode: bool = False
# 聊天历史和上下文
chat_target_info: Optional[Dict[str, Any]] = None
message_list_before_now_long: List[Dict[str, Any]] = field(default_factory=list)
message_list_before_short: List[Dict[str, Any]] = field(default_factory=list)
chat_talking_prompt_short: str = ""
target_user_info: Optional[Dict[str, Any]] = None
# 已构建的内容块
expression_habits_block: str = ""
relation_info_block: str = ""
memory_block: str = ""
tool_info_block: str = ""
knowledge_prompt: str = ""
cross_context_block: str = ""
# 其他内容块
keywords_reaction_prompt: str = ""
extra_info_block: str = ""
time_block: str = ""
identity_block: str = ""
schedule_block: str = ""
moderation_prompt_block: str = ""
reply_target_block: str = ""
mood_prompt: str = ""
action_descriptions: str = ""
# 可用动作信息
available_actions: Optional[Dict[str, Any]] = None
def validate(self) -> List[str]:
"""参数验证"""
errors = []
if not self.chat_id:
errors.append("chat_id不能为空")
if self.prompt_mode not in ["s4u", "normal", "minimal"]:
errors.append("prompt_mode必须是's4u''normal''minimal'")
if self.max_context_messages <= 0:
errors.append("max_context_messages必须大于0")
return errors
class PromptContext:
"""提示词上下文管理器"""
def __init__(self):
self._context_prompts: Dict[str, Dict[str, "Prompt"]] = {}
self._current_context_var = contextvars.ContextVar("current_context", default=None)
self._context_lock = asyncio.Lock()
@property
def _current_context(self) -> Optional[str]:
"""获取当前协程的上下文ID"""
return self._current_context_var.get()
@_current_context.setter
def _current_context(self, value: Optional[str]):
"""设置当前协程的上下文ID"""
self._current_context_var.set(value) # type: ignore
@asynccontextmanager
async def async_scope(self, context_id: Optional[str] = None):
"""创建一个异步的临时提示模板作用域"""
if context_id is not None:
try:
await asyncio.wait_for(self._context_lock.acquire(), timeout=5.0)
try:
if context_id not in self._context_prompts:
self._context_prompts[context_id] = {}
finally:
self._context_lock.release()
except asyncio.TimeoutError:
logger.warning(f"获取上下文锁超时context_id: {context_id}")
context_id = None
previous_context = self._current_context
token = self._current_context_var.set(context_id) if context_id else None
else:
previous_context = self._current_context
token = None
try:
yield self
finally:
if context_id is not None and token is not None:
try:
self._current_context_var.reset(token)
except Exception as e:
logger.warning(f"恢复上下文时出错: {e}")
try:
self._current_context = previous_context
except Exception:
...
async def get_prompt_async(self, name: str) -> Optional["Prompt"]:
"""异步获取当前作用域中的提示模板"""
async with self._context_lock:
current_context = self._current_context
logger.debug(f"获取提示词: {name} 当前上下文: {current_context}")
if (
current_context
and current_context in self._context_prompts
and name in self._context_prompts[current_context]
):
return self._context_prompts[current_context][name]
return None
async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None:
"""异步注册提示模板到指定作用域"""
async with self._context_lock:
if target_context := context_id or self._current_context:
if prompt.name:
self._context_prompts.setdefault(target_context, {})[prompt.name] = prompt
class PromptManager:
"""统一提示词管理器"""
def __init__(self):
self._prompts = {}
self._counter = 0
self._context = PromptContext()
self._lock = asyncio.Lock()
@asynccontextmanager
async def async_message_scope(self, message_id: Optional[str] = None):
"""为消息处理创建异步临时作用域"""
async with self._context.async_scope(message_id):
yield self
async def get_prompt_async(self, name: str) -> "Prompt":
"""异步获取提示模板"""
context_prompt = await self._context.get_prompt_async(name)
if context_prompt is not None:
logger.debug(f"从上下文中获取提示词: {name} {context_prompt}")
return context_prompt
async with self._lock:
if name not in self._prompts:
raise KeyError(f"Prompt '{name}' not found")
return self._prompts[name]
def generate_name(self, template: str) -> str:
"""为未命名的prompt生成名称"""
self._counter += 1
return f"prompt_{self._counter}"
def register(self, prompt: "Prompt") -> None:
"""注册一个prompt"""
if not prompt.name:
prompt.name = self.generate_name(prompt.template)
self._prompts[prompt.name] = prompt
def add_prompt(self, name: str, fstr: str) -> "Prompt":
"""添加新提示模板"""
prompt = Prompt(fstr, name=name)
if prompt.name:
self._prompts[prompt.name] = prompt
return prompt
async def format_prompt(self, name: str, **kwargs) -> str:
"""格式化提示模板"""
prompt = await self.get_prompt_async(name)
result = prompt.format(**kwargs)
return result
# 全局单例
global_prompt_manager = PromptManager()
class Prompt:
"""
统一提示词类 - 合并模板管理和智能构建功能
真正的Prompt类支持模板管理和智能上下文构建
"""
# 临时标记,作为类常量
_TEMP_LEFT_BRACE = "__ESCAPED_LEFT_BRACE__"
_TEMP_RIGHT_BRACE = "__ESCAPED_RIGHT_BRACE__"
def __init__(
self,
template: str,
name: Optional[str] = None,
parameters: Optional[PromptParameters] = None,
should_register: bool = True
):
"""
初始化统一提示词
Args:
template: 提示词模板字符串
name: 提示词名称
parameters: 构建参数
should_register: 是否自动注册到全局管理器
"""
self.template = template
self.name = name
self.parameters = parameters or PromptParameters()
self.args = self._parse_template_args(template)
self._formatted_result = ""
# 预处理模板中的转义花括号
self._processed_template = self._process_escaped_braces(template)
# 自动注册
if should_register and not global_prompt_manager._context._current_context:
global_prompt_manager.register(self)
@staticmethod
def _process_escaped_braces(template) -> str:
"""处理模板中的转义花括号"""
if isinstance(template, list):
template = "\n".join(str(item) for item in template)
elif not isinstance(template, str):
template = str(template)
return template.replace("\\{", Prompt._TEMP_LEFT_BRACE).replace("\\}", Prompt._TEMP_RIGHT_BRACE)
@staticmethod
def _restore_escaped_braces(template: str) -> str:
"""将临时标记还原为实际的花括号字符"""
return template.replace(Prompt._TEMP_LEFT_BRACE, "{").replace(Prompt._TEMP_RIGHT_BRACE, "}")
def _parse_template_args(self, template: str) -> List[str]:
"""解析模板参数"""
template_args = []
processed_template = self._process_escaped_braces(template)
result = re.findall(r"\{(.*?)}", processed_template)
for expr in result:
if expr and expr not in template_args:
template_args.append(expr)
return template_args
async def build(self) -> str:
"""
构建完整的提示词,包含智能上下文
Returns:
str: 构建完成的提示词文本
"""
# 参数验证
errors = self.parameters.validate()
if errors:
logger.error(f"参数验证失败: {', '.join(errors)}")
raise ValueError(f"参数验证失败: {', '.join(errors)}")
start_time = time.time()
try:
# 构建上下文数据
context_data = await self._build_context_data()
# 格式化模板
result = await self._format_with_context(context_data)
total_time = time.time() - start_time
logger.debug(f"Prompt构建完成模式: {self.parameters.prompt_mode}, 耗时: {total_time:.2f}s")
self._formatted_result = result
return result
except asyncio.TimeoutError as e:
logger.error(f"构建Prompt超时: {e}")
raise TimeoutError(f"构建Prompt超时: {e}")
except Exception as e:
logger.error(f"构建Prompt失败: {e}")
raise RuntimeError(f"构建Prompt失败: {e}")
async def _build_context_data(self) -> Dict[str, Any]:
"""构建智能上下文数据"""
# 并行执行所有构建任务
start_time = time.time()
timing_logs = {}
try:
# 准备构建任务
tasks = []
task_names = []
# 初始化预构建参数
pre_built_params = {}
if self.parameters.expression_habits_block:
pre_built_params["expression_habits_block"] = self.parameters.expression_habits_block
if self.parameters.relation_info_block:
pre_built_params["relation_info_block"] = self.parameters.relation_info_block
if self.parameters.memory_block:
pre_built_params["memory_block"] = self.parameters.memory_block
if self.parameters.tool_info_block:
pre_built_params["tool_info_block"] = self.parameters.tool_info_block
if self.parameters.knowledge_prompt:
pre_built_params["knowledge_prompt"] = self.parameters.knowledge_prompt
if self.parameters.cross_context_block:
pre_built_params["cross_context_block"] = self.parameters.cross_context_block
# 根据参数确定要构建的项
if self.parameters.enable_expression and not pre_built_params.get("expression_habits_block"):
tasks.append(self._build_expression_habits())
task_names.append("expression_habits")
if self.parameters.enable_memory and not pre_built_params.get("memory_block"):
tasks.append(self._build_memory_block())
task_names.append("memory_block")
if self.parameters.enable_relation and not pre_built_params.get("relation_info_block"):
tasks.append(self._build_relation_info())
task_names.append("relation_info")
if self.parameters.enable_tool and not pre_built_params.get("tool_info_block"):
tasks.append(self._build_tool_info())
task_names.append("tool_info")
if self.parameters.enable_knowledge and not pre_built_params.get("knowledge_prompt"):
tasks.append(self._build_knowledge_info())
task_names.append("knowledge_info")
if self.parameters.enable_cross_context and not pre_built_params.get("cross_context_block"):
tasks.append(self._build_cross_context())
task_names.append("cross_context")
# 性能优化
base_timeout = 10.0
task_timeout = 2.0
timeout_seconds = min(
max(base_timeout, len(tasks) * task_timeout),
30.0,
)
max_concurrent_tasks = 5
if len(tasks) > max_concurrent_tasks:
results = []
for i in range(0, len(tasks), max_concurrent_tasks):
batch_tasks = tasks[i : i + max_concurrent_tasks]
batch_names = task_names[i : i + max_concurrent_tasks]
batch_results = await asyncio.wait_for(
asyncio.gather(*batch_tasks, return_exceptions=True), timeout=timeout_seconds
)
results.extend(batch_results)
else:
results = await asyncio.wait_for(
asyncio.gather(*tasks, return_exceptions=True), timeout=timeout_seconds
)
# 处理结果
context_data = {}
for i, result in enumerate(results):
task_name = task_names[i] if i < len(task_names) else f"task_{i}"
if isinstance(result, Exception):
logger.error(f"构建任务{task_name}失败: {str(result)}")
elif isinstance(result, dict):
context_data.update(result)
# 添加预构建的参数
for key, value in pre_built_params.items():
if value:
context_data[key] = value
except asyncio.TimeoutError:
logger.error(f"构建超时 ({timeout_seconds}s)")
context_data = {}
for key, value in pre_built_params.items():
if value:
context_data[key] = value
# 构建聊天历史
if self.parameters.prompt_mode == "s4u":
await self._build_s4u_chat_context(context_data)
else:
await self._build_normal_chat_context(context_data)
# 补充基础信息
context_data.update({
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt,
"extra_info_block": self.parameters.extra_info_block,
"time_block": self.parameters.time_block or f"当前时间:{time.strftime('%Y-%m-%d %H:%M:%S')}",
"identity": self.parameters.identity_block,
"schedule_block": self.parameters.schedule_block,
"moderation_prompt": self.parameters.moderation_prompt_block,
"reply_target_block": self.parameters.reply_target_block,
"mood_state": self.parameters.mood_prompt,
"action_descriptions": self.parameters.action_descriptions,
})
total_time = time.time() - start_time
logger.debug(f"上下文构建完成,总耗时: {total_time:.2f}s")
return context_data
async def _build_s4u_chat_context(self, context_data: Dict[str, Any]) -> None:
"""构建S4U模式的聊天上下文"""
if not self.parameters.message_list_before_now_long:
return
core_dialogue, background_dialogue = await self._build_s4u_chat_history_prompts(
self.parameters.message_list_before_now_long,
self.parameters.target_user_info.get("user_id") if self.parameters.target_user_info else "",
self.parameters.sender
)
context_data["core_dialogue_prompt"] = core_dialogue
context_data["background_dialogue_prompt"] = background_dialogue
async def _build_normal_chat_context(self, context_data: Dict[str, Any]) -> None:
"""构建normal模式的聊天上下文"""
if not self.parameters.chat_talking_prompt_short:
return
context_data["chat_info"] = f"""群里的聊天内容:
{self.parameters.chat_talking_prompt_short}"""
async def _build_s4u_chat_history_prompts(
self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str
) -> Tuple[str, str]:
"""构建S4U风格的分离对话prompt"""
# 实现逻辑与原有SmartPromptBuilder相同
core_dialogue_list = []
bot_id = str(global_config.bot.qq_account)
for msg_dict in message_list_before_now:
try:
msg_user_id = str(msg_dict.get("user_id"))
reply_to = msg_dict.get("reply_to", "")
platform, reply_to_user_id = Prompt.parse_reply_target(reply_to)
if (msg_user_id == bot_id and reply_to_user_id == target_user_id) or msg_user_id == target_user_id:
core_dialogue_list.append(msg_dict)
except Exception as e:
logger.error(f"处理消息记录时出错: {msg_dict}, 错误: {e}")
# 构建背景对话 prompt
all_dialogue_prompt = ""
if message_list_before_now:
latest_25_msgs = message_list_before_now[-int(global_config.chat.max_context_size) :]
all_dialogue_prompt_str = build_readable_messages(
latest_25_msgs,
replace_bot_name=True,
timestamp_mode="normal",
truncate=True,
)
all_dialogue_prompt = f"所有用户的发言:\n{all_dialogue_prompt_str}"
# 构建核心对话 prompt
core_dialogue_prompt = ""
if core_dialogue_list:
latest_5_messages = core_dialogue_list[-5:] if len(core_dialogue_list) >= 5 else core_dialogue_list
has_bot_message = any(str(msg.get("user_id")) == bot_id for msg in latest_5_messages)
if not has_bot_message:
core_dialogue_prompt = ""
else:
core_dialogue_list = core_dialogue_list[-int(global_config.chat.max_context_size * 2) :]
core_dialogue_prompt_str = build_readable_messages(
core_dialogue_list,
replace_bot_name=True,
merge_messages=False,
timestamp_mode="normal_no_YMD",
read_mark=0.0,
truncate=True,
show_actions=True,
)
core_dialogue_prompt = f"""--------------------------------
这是你和{sender}的对话,你们正在交流中:
{core_dialogue_prompt_str}
--------------------------------
"""
return core_dialogue_prompt, all_dialogue_prompt
async def _build_expression_habits(self) -> Dict[str, Any]:
"""构建表达习惯"""
# 简化的实现,完整实现需要导入相关模块
return {"expression_habits_block": ""}
async def _build_memory_block(self) -> Dict[str, Any]:
"""构建记忆块"""
# 简化的实现
return {"memory_block": ""}
async def _build_relation_info(self) -> Dict[str, Any]:
"""构建关系信息"""
try:
relation_info = await Prompt.build_relation_info(self.parameters.chat_id, self.parameters.reply_to)
return {"relation_info_block": relation_info}
except Exception as e:
logger.error(f"构建关系信息失败: {e}")
return {"relation_info_block": ""}
async def _build_tool_info(self) -> Dict[str, Any]:
"""构建工具信息"""
# 简化的实现
return {"tool_info_block": ""}
async def _build_knowledge_info(self) -> Dict[str, Any]:
"""构建知识信息"""
# 简化的实现
return {"knowledge_prompt": ""}
async def _build_cross_context(self) -> Dict[str, Any]:
"""构建跨群上下文"""
try:
cross_context = await Prompt.build_cross_context(
self.parameters.chat_id, self.parameters.prompt_mode, self.parameters.target_user_info
)
return {"cross_context_block": cross_context}
except Exception as e:
logger.error(f"构建跨群上下文失败: {e}")
return {"cross_context_block": ""}
async def _format_with_context(self, context_data: Dict[str, Any]) -> str:
"""使用上下文数据格式化模板"""
if self.parameters.prompt_mode == "s4u":
params = self._prepare_s4u_params(context_data)
elif self.parameters.prompt_mode == "normal":
params = self._prepare_normal_params(context_data)
else:
params = self._prepare_default_params(context_data)
return await global_prompt_manager.format_prompt(self.name, **params) if self.name else self.format(**params)
def _prepare_s4u_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
"""准备S4U模式的参数"""
return {
**context_data,
"expression_habits_block": context_data.get("expression_habits_block", ""),
"tool_info_block": context_data.get("tool_info_block", ""),
"knowledge_prompt": context_data.get("knowledge_prompt", ""),
"memory_block": context_data.get("memory_block", ""),
"relation_info_block": context_data.get("relation_info_block", ""),
"extra_info_block": self.parameters.extra_info_block or context_data.get("extra_info_block", ""),
"cross_context_block": context_data.get("cross_context_block", ""),
"identity": self.parameters.identity_block or context_data.get("identity", ""),
"action_descriptions": self.parameters.action_descriptions or context_data.get("action_descriptions", ""),
"sender_name": self.parameters.sender,
"mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""),
"background_dialogue_prompt": context_data.get("background_dialogue_prompt", ""),
"time_block": context_data.get("time_block", ""),
"core_dialogue_prompt": context_data.get("core_dialogue_prompt", ""),
"reply_target_block": context_data.get("reply_target_block", ""),
"reply_style": global_config.personality.reply_style,
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt or context_data.get("keywords_reaction_prompt", ""),
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
}
def _prepare_normal_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
"""准备Normal模式的参数"""
return {
**context_data,
"expression_habits_block": context_data.get("expression_habits_block", ""),
"tool_info_block": context_data.get("tool_info_block", ""),
"knowledge_prompt": context_data.get("knowledge_prompt", ""),
"memory_block": context_data.get("memory_block", ""),
"relation_info_block": context_data.get("relation_info_block", ""),
"extra_info_block": self.parameters.extra_info_block or context_data.get("extra_info_block", ""),
"cross_context_block": context_data.get("cross_context_block", ""),
"identity": self.parameters.identity_block or context_data.get("identity", ""),
"action_descriptions": self.parameters.action_descriptions or context_data.get("action_descriptions", ""),
"schedule_block": self.parameters.schedule_block or context_data.get("schedule_block", ""),
"time_block": context_data.get("time_block", ""),
"chat_info": context_data.get("chat_info", ""),
"reply_target_block": context_data.get("reply_target_block", ""),
"config_expression_style": global_config.personality.reply_style,
"mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""),
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt or context_data.get("keywords_reaction_prompt", ""),
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
}
def _prepare_default_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
"""准备默认模式的参数"""
return {
"expression_habits_block": context_data.get("expression_habits_block", ""),
"relation_info_block": context_data.get("relation_info_block", ""),
"chat_target": "",
"time_block": context_data.get("time_block", ""),
"chat_info": context_data.get("chat_info", ""),
"identity": self.parameters.identity_block or context_data.get("identity", ""),
"chat_target_2": "",
"reply_target_block": context_data.get("reply_target_block", ""),
"raw_reply": self.parameters.target,
"reason": "",
"mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""),
"reply_style": global_config.personality.reply_style,
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt or context_data.get("keywords_reaction_prompt", ""),
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
}
def format(self, *args, **kwargs) -> str:
"""格式化模板,支持位置参数和关键字参数"""
try:
# 先用位置参数格式化
if args:
formatted_args = {}
for i in range(len(args)):
if i < len(self.args):
formatted_args[self.args[i]] = args[i]
processed_template = self._processed_template.format(**formatted_args)
else:
processed_template = self._processed_template
# 再用关键字参数格式化
if kwargs:
processed_template = processed_template.format(**kwargs)
# 将临时标记还原为实际的花括号
result = self._restore_escaped_braces(processed_template)
return result
except (IndexError, KeyError) as e:
raise ValueError(f"格式化模板失败: {self.template}, args={args}, kwargs={kwargs} {str(e)}") from e
def __str__(self) -> str:
"""返回格式化后的结果或原始模板"""
return self._formatted_result if self._formatted_result else self.template
def __repr__(self) -> str:
"""返回提示词的表示形式"""
return f"Prompt(template='{self.template}', name='{self.name}')"
# =============================================================================
# PromptUtils功能迁移 - 静态工具方法
# 这些方法原来在PromptUtils类中现在作为Prompt类的静态方法
# 解决循环导入问题
# =============================================================================
@staticmethod
def parse_reply_target(target_message: str) -> Tuple[str, str]:
"""
解析回复目标消息 - 统一实现
Args:
target_message: 目标消息,格式为 "发送者:消息内容""发送者:消息内容"
Returns:
Tuple[str, str]: (发送者名称, 消息内容)
"""
sender = ""
target = ""
# 添加None检查防止NoneType错误
if target_message is None:
return sender, target
if ":" in target_message or "" in target_message:
# 使用正则表达式匹配中文或英文冒号
parts = re.split(pattern=r"[:]", string=target_message, maxsplit=1)
if len(parts) == 2:
sender = parts[0].strip()
target = parts[1].strip()
return sender, target
@staticmethod
async def build_relation_info(chat_id: str, reply_to: str) -> str:
"""
构建关系信息 - 统一实现
Args:
chat_id: 聊天ID
reply_to: 回复目标字符串
Returns:
str: 关系信息字符串
"""
if not global_config.relationship.enable_relationship:
return ""
from src.person_info.relationship_fetcher import relationship_fetcher_manager
relationship_fetcher = relationship_fetcher_manager.get_fetcher(chat_id)
if not reply_to:
return ""
sender, text = Prompt.parse_reply_target(reply_to)
if not sender or not text:
return ""
# 获取用户ID
person_info_manager = get_person_info_manager()
person_id = person_info_manager.get_person_id_by_person_name(sender)
if not person_id:
logger.warning(f"未找到用户 {sender} 的ID跳过信息提取")
return f"你完全不认识{sender}不理解ta的相关信息。"
return await relationship_fetcher.build_relation_info(person_id, points_num=5)
@staticmethod
async def build_cross_context(
chat_id: str, prompt_mode: str, target_user_info: Optional[Dict[str, Any]]
) -> str:
"""
构建跨群聊上下文 - 统一实现
Args:
chat_id: 聊天ID
prompt_mode: 当前提示词模式
target_user_info: 目标用户信息
Returns:
str: 跨群聊上下文字符串
"""
if not global_config.cross_context.enable:
return ""
from src.plugin_system.apis import cross_context_api
other_chat_raw_ids = cross_context_api.get_context_groups(chat_id)
if not other_chat_raw_ids:
return ""
chat_stream = get_chat_manager().get_stream(chat_id)
if not chat_stream:
return ""
if prompt_mode == "normal":
return await cross_context_api.build_cross_context_normal(chat_stream, other_chat_raw_ids)
elif prompt_mode == "s4u":
return await cross_context_api.build_cross_context_s4u(chat_stream, other_chat_raw_ids, target_user_info)
return ""
@staticmethod
def parse_reply_target_id(reply_to: str) -> str:
"""
解析回复目标中的用户ID
Args:
reply_to: 回复目标字符串
Returns:
str: 用户ID
"""
if not reply_to:
return ""
# 复用parse_reply_target方法的逻辑
sender, _ = Prompt.parse_reply_target(reply_to)
if not sender:
return ""
# 获取用户ID
person_info_manager = get_person_info_manager()
person_id = person_info_manager.get_person_id_by_person_name(sender)
if person_id:
user_id = person_info_manager.get_value_sync(person_id, "user_id")
return str(user_id) if user_id else ""
return ""
# 工厂函数
def create_prompt(
template: str,
name: Optional[str] = None,
parameters: Optional[PromptParameters] = None,
**kwargs
) -> Prompt:
"""快速创建Prompt实例的工厂函数"""
if parameters is None:
parameters = PromptParameters(**kwargs)
return Prompt(template, name, parameters)
async def create_prompt_async(
template: str,
name: Optional[str] = None,
parameters: Optional[PromptParameters] = None,
**kwargs
) -> Prompt:
"""异步创建Prompt实例"""
prompt = create_prompt(template, name, parameters, **kwargs)
if global_prompt_manager._context._current_context:
await global_prompt_manager._context.register_async(prompt)
return prompt

View File

@@ -1,299 +0,0 @@
import re
import asyncio
import contextvars
from rich.traceback import install
from contextlib import asynccontextmanager
from typing import Dict, Any, Optional, List, Union
from src.common.logger import get_logger
install(extra_lines=3)
logger = get_logger("prompt_build")
class PromptContext:
def __init__(self):
self._context_prompts: Dict[str, Dict[str, "Prompt"]] = {}
# 使用contextvars创建协程上下文变量
self._current_context_var = contextvars.ContextVar("current_context", default=None)
self._context_lock = asyncio.Lock() # 保留锁用于其他操作
@property
def _current_context(self) -> Optional[str]:
"""获取当前协程的上下文ID"""
return self._current_context_var.get()
@_current_context.setter
def _current_context(self, value: Optional[str]):
"""设置当前协程的上下文ID"""
self._current_context_var.set(value) # type: ignore
@asynccontextmanager
async def async_scope(self, context_id: Optional[str] = None):
# sourcery skip: hoist-statement-from-if, use-contextlib-suppress
"""创建一个异步的临时提示模板作用域"""
# 保存当前上下文并设置新上下文
if context_id is not None:
try:
# 添加超时保护,避免长时间等待锁
await asyncio.wait_for(self._context_lock.acquire(), timeout=5.0)
try:
if context_id not in self._context_prompts:
self._context_prompts[context_id] = {}
finally:
self._context_lock.release()
except asyncio.TimeoutError:
logger.warning(f"获取上下文锁超时context_id: {context_id}")
# 超时时直接进入,不设置上下文
context_id = None
# 保存当前协程的上下文值,不影响其他协程
previous_context = self._current_context
# 设置当前协程的新上下文
token = self._current_context_var.set(context_id) if context_id else None # type: ignore
else:
# 如果没有提供新上下文,保持当前上下文不变
previous_context = self._current_context
token = None
try:
yield self
finally:
# 恢复之前的上下文,添加异常保护
if context_id is not None and token is not None:
try:
self._current_context_var.reset(token)
except Exception as e:
logger.warning(f"恢复上下文时出错: {e}")
# 如果reset失败尝试直接设置
try:
self._current_context = previous_context
except Exception:
...
# 静默忽略恢复失败
async def get_prompt_async(self, name: str) -> Optional["Prompt"]:
"""异步获取当前作用域中的提示模板"""
async with self._context_lock:
current_context = self._current_context
logger.debug(f"获取提示词: {name} 当前上下文: {current_context}")
if (
current_context
and current_context in self._context_prompts
and name in self._context_prompts[current_context]
):
return self._context_prompts[current_context][name]
return None
async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None:
"""异步注册提示模板到指定作用域"""
async with self._context_lock:
if target_context := context_id or self._current_context:
if prompt.name:
self._context_prompts.setdefault(target_context, {})[prompt.name] = prompt
class PromptManager:
def __init__(self):
self._prompts = {}
self._counter = 0
self._context = PromptContext()
self._lock = asyncio.Lock()
@asynccontextmanager
async def async_message_scope(self, message_id: Optional[str] = None):
"""为消息处理创建异步临时作用域,支持 message_id 为 None 的情况"""
async with self._context.async_scope(message_id):
yield self
async def get_prompt_async(self, name: str) -> "Prompt":
# 首先尝试从当前上下文获取
context_prompt = await self._context.get_prompt_async(name)
if context_prompt is not None:
logger.debug(f"从上下文中获取提示词: {name} {context_prompt}")
return context_prompt
# 如果上下文中不存在,则使用全局提示模板
async with self._lock:
# logger.debug(f"从全局获取提示词: {name}")
if name not in self._prompts:
raise KeyError(f"Prompt '{name}' not found")
return self._prompts[name]
def generate_name(self, template: str) -> str:
"""为未命名的prompt生成名称"""
self._counter += 1
return f"prompt_{self._counter}"
def register(self, prompt: "Prompt") -> None:
"""注册一个prompt"""
if not prompt.name:
prompt.name = self.generate_name(prompt.template)
self._prompts[prompt.name] = prompt
def add_prompt(self, name: str, fstr: str) -> "Prompt":
prompt = Prompt(fstr, name=name)
if prompt.name:
self._prompts[prompt.name] = prompt
return prompt
async def format_prompt(self, name: str, **kwargs) -> str:
# 获取当前提示词
prompt = await self.get_prompt_async(name)
# 获取基本格式化结果
result = prompt.format(**kwargs)
return result
# 全局单例
global_prompt_manager = PromptManager()
class Prompt(str):
template: str
name: Optional[str]
args: List[str]
_args: List[Any]
_kwargs: Dict[str, Any]
# 临时标记,作为类常量
_TEMP_LEFT_BRACE = "__ESCAPED_LEFT_BRACE__"
_TEMP_RIGHT_BRACE = "__ESCAPED_RIGHT_BRACE__"
@staticmethod
def _process_escaped_braces(template) -> str:
"""处理模板中的转义花括号,将 \\{\\} 替换为临时标记""" # type: ignore
# 如果传入的是列表,将其转换为字符串
if isinstance(template, list):
template = "\n".join(str(item) for item in template)
elif not isinstance(template, str):
template = str(template)
return template.replace("\\{", Prompt._TEMP_LEFT_BRACE).replace("\\}", Prompt._TEMP_RIGHT_BRACE)
@staticmethod
def _restore_escaped_braces(template: str) -> str:
"""将临时标记还原为实际的花括号字符"""
return template.replace(Prompt._TEMP_LEFT_BRACE, "{").replace(Prompt._TEMP_RIGHT_BRACE, "}")
def __new__(
cls, fstr, name: Optional[str] = None, args: Optional[Union[List[Any], tuple[Any, ...]]] = None, **kwargs
):
# 如果传入的是元组,转换为列表
if isinstance(args, tuple):
args = list(args)
should_register = kwargs.pop("_should_register", True)
# 预处理模板中的转义花括号
processed_fstr = cls._process_escaped_braces(fstr)
# 解析模板
template_args = []
result = re.findall(r"\{(.*?)}", processed_fstr)
for expr in result:
if expr and expr not in template_args:
template_args.append(expr)
# 如果提供了初始参数,立即格式化
if kwargs or args:
formatted = cls._format_template(fstr, args=args, kwargs=kwargs)
obj = super().__new__(cls, formatted)
else:
obj = super().__new__(cls, "")
obj.template = fstr
obj.name = name
obj.args = template_args
obj._args = args or []
obj._kwargs = kwargs
# 修改自动注册逻辑
if should_register and not global_prompt_manager._context._current_context:
global_prompt_manager.register(obj)
return obj
@classmethod
async def create_async(
cls, fstr, name: Optional[str] = None, args: Optional[Union[List[Any], tuple[Any, ...]]] = None, **kwargs
):
"""异步创建Prompt实例"""
prompt = cls(fstr, name, args, **kwargs)
if global_prompt_manager._context._current_context:
await global_prompt_manager._context.register_async(prompt)
return prompt
@classmethod
def _format_template(
cls, template, args: Optional[List[Any]] = None, kwargs: Optional[Dict[str, Any]] = None
) -> str:
if kwargs is None:
kwargs = {}
# 预处理模板中的转义花括号
processed_template = cls._process_escaped_braces(template)
template_args = []
result = re.findall(r"\{(.*?)}", processed_template)
for expr in result:
if expr and expr not in template_args:
template_args.append(expr)
formatted_args = {}
formatted_kwargs = {}
# 处理位置参数
if args:
# print(len(template_args), len(args), template_args, args)
for i in range(len(args)):
if i < len(template_args):
arg = args[i]
if isinstance(arg, Prompt):
formatted_args[template_args[i]] = arg.format(**kwargs)
else:
formatted_args[template_args[i]] = arg
else:
logger.error(
f"构建提示词模板失败,解析到的参数列表{template_args},长度为{len(template_args)},输入的参数列表为{args},提示词模板为{template}"
)
raise ValueError("格式化模板失败")
# 处理关键字参数
if kwargs:
for key, value in kwargs.items():
if isinstance(value, Prompt):
remaining_kwargs = {k: v for k, v in kwargs.items() if k != key}
formatted_kwargs[key] = value.format(**remaining_kwargs)
else:
formatted_kwargs[key] = value
try:
# 先用位置参数格式化
if args:
processed_template = processed_template.format(**formatted_args)
# 再用关键字参数格式化
if kwargs:
processed_template = processed_template.format(**formatted_kwargs)
# 将临时标记还原为实际的花括号
result = cls._restore_escaped_braces(processed_template)
return result
except (IndexError, KeyError) as e:
raise ValueError(
f"格式化模板失败: {template}, args={formatted_args}, kwargs={formatted_kwargs} {str(e)}"
) from e
def format(self, *args, **kwargs) -> "str":
"""支持位置参数和关键字参数的格式化,使用"""
ret = type(self)(
self.template,
self.name,
args=list(args) if args else self._args,
_should_register=False,
**kwargs or self._kwargs,
)
# print(f"prompt build result: {ret} name: {ret.name} ")
return str(ret)
def __str__(self) -> str:
return super().__str__() if self._kwargs or self._args else self.template
def __repr__(self) -> str:
return f"Prompt(template='{self.template}', name='{self.name}')"

View File

@@ -1,156 +0,0 @@
"""
智能提示词参数模块 - 优化参数结构
简化SmartPromptParameters减少冗余和重复
"""
from dataclasses import dataclass, field
from typing import Dict, Any, Optional, List, Literal
@dataclass
class SmartPromptParameters:
"""简化的智能提示词参数系统"""
# 基础参数
chat_id: str = ""
is_group_chat: bool = False
sender: str = ""
target: str = ""
reply_to: str = ""
extra_info: str = ""
prompt_mode: Literal["s4u", "normal", "minimal"] = "s4u"
# 功能开关
enable_tool: bool = True
enable_memory: bool = True
enable_expression: bool = True
enable_relation: bool = True
enable_cross_context: bool = True
enable_knowledge: bool = True
# 性能控制
max_context_messages: int = 50
# 调试选项
debug_mode: bool = False
# 聊天历史和上下文
chat_target_info: Optional[Dict[str, Any]] = None
message_list_before_now_long: List[Dict[str, Any]] = field(default_factory=list)
message_list_before_short: List[Dict[str, Any]] = field(default_factory=list)
chat_talking_prompt_short: str = ""
target_user_info: Optional[Dict[str, Any]] = None
# 已构建的内容块
expression_habits_block: str = ""
relation_info_block: str = ""
memory_block: str = ""
tool_info_block: str = ""
knowledge_prompt: str = ""
cross_context_block: str = ""
# 其他内容块
keywords_reaction_prompt: str = ""
extra_info_block: str = ""
time_block: str = ""
identity_block: str = ""
schedule_block: str = ""
moderation_prompt_block: str = ""
reply_target_block: str = ""
mood_prompt: str = ""
action_descriptions: str = ""
# 可用动作信息
available_actions: Optional[Dict[str, Any]] = None
def validate(self) -> List[str]:
"""统一的参数验证"""
errors = []
if not self.chat_id:
errors.append("chat_id不能为空")
if self.prompt_mode not in ["s4u", "normal", "minimal"]:
errors.append("prompt_mode必须是's4u''normal''minimal'")
if self.max_context_messages <= 0:
errors.append("max_context_messages必须大于0")
return errors
def get_needed_build_tasks(self) -> List[str]:
"""获取需要执行的任务列表"""
tasks = []
if self.enable_expression and not self.expression_habits_block:
tasks.append("expression_habits")
if self.enable_memory and not self.memory_block:
tasks.append("memory_block")
if self.enable_relation and not self.relation_info_block:
tasks.append("relation_info")
if self.enable_tool and not self.tool_info_block:
tasks.append("tool_info")
if self.enable_knowledge and not self.knowledge_prompt:
tasks.append("knowledge_info")
if self.enable_cross_context and not self.cross_context_block:
tasks.append("cross_context")
return tasks
@classmethod
def from_legacy_params(cls, **kwargs) -> "SmartPromptParameters":
"""
从旧版参数创建新参数对象
Args:
**kwargs: 旧版参数
Returns:
SmartPromptParameters: 新参数对象
"""
return cls(
# 基础参数
chat_id=kwargs.get("chat_id", ""),
is_group_chat=kwargs.get("is_group_chat", False),
sender=kwargs.get("sender", ""),
target=kwargs.get("target", ""),
reply_to=kwargs.get("reply_to", ""),
extra_info=kwargs.get("extra_info", ""),
prompt_mode=kwargs.get("current_prompt_mode", "s4u"),
# 功能开关
enable_tool=kwargs.get("enable_tool", True),
enable_memory=kwargs.get("enable_memory", True),
enable_expression=kwargs.get("enable_expression", True),
enable_relation=kwargs.get("enable_relation", True),
enable_cross_context=kwargs.get("enable_cross_context", True),
enable_knowledge=kwargs.get("enable_knowledge", True),
# 性能控制
max_context_messages=kwargs.get("max_context_messages", 50),
debug_mode=kwargs.get("debug_mode", False),
# 聊天历史和上下文
chat_target_info=kwargs.get("chat_target_info"),
message_list_before_now_long=kwargs.get("message_list_before_now_long", []),
message_list_before_short=kwargs.get("message_list_before_short", []),
chat_talking_prompt_short=kwargs.get("chat_talking_prompt_short", ""),
target_user_info=kwargs.get("target_user_info"),
# 已构建的内容块
expression_habits_block=kwargs.get("expression_habits_block", ""),
relation_info_block=kwargs.get("relation_info", ""),
memory_block=kwargs.get("memory_block", ""),
tool_info_block=kwargs.get("tool_info", ""),
knowledge_prompt=kwargs.get("knowledge_prompt", ""),
cross_context_block=kwargs.get("cross_context_block", ""),
# 其他内容块
keywords_reaction_prompt=kwargs.get("keywords_reaction_prompt", ""),
extra_info_block=kwargs.get("extra_info_block", ""),
time_block=kwargs.get("time_block", ""),
identity_block=kwargs.get("identity_block", ""),
schedule_block=kwargs.get("schedule_block", ""),
moderation_prompt_block=kwargs.get("moderation_prompt_block", ""),
reply_target_block=kwargs.get("reply_target_block", ""),
mood_prompt=kwargs.get("mood_prompt", ""),
action_descriptions=kwargs.get("action_descriptions", ""),
# 可用动作信息
available_actions=kwargs.get("available_actions", None),
)

View File

@@ -1,132 +0,0 @@
"""
共享提示词工具模块 - 消除重复代码
提供统一的工具函数供DefaultReplyer和SmartPrompt使用
"""
import re
import time
from typing import Dict, Any, Optional, Tuple
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.message_receive.chat_stream import get_chat_manager
from src.person_info.person_info import get_person_info_manager
from src.plugin_system.apis import cross_context_api
logger = get_logger("prompt_utils")
class PromptUtils:
"""提示词工具类 - 提供共享功能,移除缓存相关功能和依赖检查"""
@staticmethod
def parse_reply_target(target_message: str) -> Tuple[str, str]:
"""
解析回复目标消息 - 统一实现
Args:
target_message: 目标消息,格式为 "发送者:消息内容""发送者:消息内容"
Returns:
Tuple[str, str]: (发送者名称, 消息内容)
"""
sender = ""
target = ""
# 添加None检查防止NoneType错误
if target_message is None:
return sender, target
if ":" in target_message or "" in target_message:
# 使用正则表达式匹配中文或英文冒号
parts = re.split(pattern=r"[:]", string=target_message, maxsplit=1)
if len(parts) == 2:
sender = parts[0].strip()
target = parts[1].strip()
return sender, target
@staticmethod
async def build_relation_info(chat_id: str, reply_to: str) -> str:
"""
构建关系信息 - 统一实现
Args:
chat_id: 聊天ID
reply_to: 回复目标字符串
Returns:
str: 关系信息字符串
"""
if not global_config.relationship.enable_relationship:
return ""
from src.person_info.relationship_fetcher import relationship_fetcher_manager
relationship_fetcher = relationship_fetcher_manager.get_fetcher(chat_id)
if not reply_to:
return ""
sender, text = PromptUtils.parse_reply_target(reply_to)
if not sender or not text:
return ""
# 获取用户ID
person_info_manager = get_person_info_manager()
person_id = person_info_manager.get_person_id_by_person_name(sender)
if not person_id:
logger.warning(f"未找到用户 {sender} 的ID跳过信息提取")
return f"你完全不认识{sender}不理解ta的相关信息。"
return await relationship_fetcher.build_relation_info(person_id, points_num=5)
@staticmethod
async def build_cross_context(
chat_id: str, target_user_info: Optional[Dict[str, Any]], current_prompt_mode: str
) -> str:
"""
构建跨群聊上下文 - 统一实现完全继承DefaultReplyer功能
"""
if not global_config.cross_context.enable:
return ""
other_chat_raw_ids = cross_context_api.get_context_groups(chat_id)
if not other_chat_raw_ids:
return ""
chat_stream = get_chat_manager().get_stream(chat_id)
if not chat_stream:
return ""
if current_prompt_mode == "normal":
return await cross_context_api.build_cross_context_normal(chat_stream, other_chat_raw_ids)
elif current_prompt_mode == "s4u":
return await cross_context_api.build_cross_context_s4u(chat_stream, other_chat_raw_ids, target_user_info)
return ""
@staticmethod
def parse_reply_target_id(reply_to: str) -> str:
"""
解析回复目标中的用户ID
Args:
reply_to: 回复目标字符串
Returns:
str: 用户ID
"""
if not reply_to:
return ""
# 复用parse_reply_target方法的逻辑
sender, _ = PromptUtils.parse_reply_target(reply_to)
if not sender:
return ""
# 获取用户ID
person_info_manager = get_person_info_manager()
person_id = person_info_manager.get_person_id_by_person_name(sender)
if person_id:
user_id = person_info_manager.get_value_sync(person_id, "user_id")
return str(user_id) if user_id else ""
return ""

View File

@@ -1,938 +0,0 @@
"""
智能Prompt系统 - 完全重构版本
基于原有DefaultReplyer的完整功能集成使用新的参数结构
解决实现质量不高、功能集成不完整和错误处理不足的问题
"""
import asyncio
import time
from datetime import datetime
from dataclasses import dataclass, field
from typing import Dict, Any, Optional, List, Tuple
from src.chat.utils.prompt_builder import global_prompt_manager, Prompt
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.utils.chat_message_builder import (
build_readable_messages,
)
from src.person_info.person_info import get_person_info_manager
from src.chat.utils.prompt_utils import PromptUtils
from src.chat.utils.prompt_parameters import SmartPromptParameters
logger = get_logger("smart_prompt")
@dataclass
class ChatContext:
"""聊天上下文信息"""
chat_id: str = ""
platform: str = ""
is_group: bool = False
user_id: str = ""
user_nickname: str = ""
group_id: Optional[str] = None
timestamp: datetime = field(default_factory=datetime.now)
class SmartPromptBuilder:
"""重构的智能提示词构建器 - 统一错误处理和功能集成,移除缓存机制和依赖检查"""
def __init__(self):
# 移除缓存相关初始化
pass
async def build_context_data(self, params: SmartPromptParameters) -> Dict[str, Any]:
"""并行构建完整的上下文数据 - 移除缓存机制和依赖检查"""
# 并行执行所有构建任务
start_time = time.time()
timing_logs = {}
try:
# 准备构建任务
tasks = []
task_names = []
# 初始化预构建参数,使用新的结构
pre_built_params = {}
if params.expression_habits_block:
pre_built_params["expression_habits_block"] = params.expression_habits_block
if params.relation_info_block:
pre_built_params["relation_info_block"] = params.relation_info_block
if params.memory_block:
pre_built_params["memory_block"] = params.memory_block
if params.tool_info_block:
pre_built_params["tool_info_block"] = params.tool_info_block
if params.knowledge_prompt:
pre_built_params["knowledge_prompt"] = params.knowledge_prompt
if params.cross_context_block:
pre_built_params["cross_context_block"] = params.cross_context_block
# 根据新的参数结构确定要构建的项
if params.enable_expression and not pre_built_params.get("expression_habits_block"):
tasks.append(self._build_expression_habits(params))
task_names.append("expression_habits")
if params.enable_memory and not pre_built_params.get("memory_block"):
tasks.append(self._build_memory_block(params))
task_names.append("memory_block")
if params.enable_relation and not pre_built_params.get("relation_info_block"):
tasks.append(self._build_relation_info(params))
task_names.append("relation_info")
# 添加mai_think上下文构建任务
if not pre_built_params.get("mai_think"):
tasks.append(self._build_mai_think_context(params))
task_names.append("mai_think_context")
if params.enable_tool and not pre_built_params.get("tool_info_block"):
tasks.append(self._build_tool_info(params))
task_names.append("tool_info")
if params.enable_knowledge and not pre_built_params.get("knowledge_prompt"):
tasks.append(self._build_knowledge_info(params))
task_names.append("knowledge_info")
if params.enable_cross_context and not pre_built_params.get("cross_context_block"):
tasks.append(self._build_cross_context(params))
task_names.append("cross_context")
# 性能优化:根据任务数量动态调整超时时间
base_timeout = 10.0 # 基础超时时间
task_timeout = 2.0 # 每个任务的超时时间
timeout_seconds = min(
max(base_timeout, len(tasks) * task_timeout), # 根据任务数量计算超时
30.0, # 最大超时时间
)
# 性能优化:限制并发任务数量,避免资源耗尽
max_concurrent_tasks = 5 # 最大并发任务数
if len(tasks) > max_concurrent_tasks:
# 分批执行任务
results = []
for i in range(0, len(tasks), max_concurrent_tasks):
batch_tasks = tasks[i : i + max_concurrent_tasks]
batch_names = task_names[i : i + max_concurrent_tasks]
batch_results = await asyncio.wait_for(
asyncio.gather(*batch_tasks, return_exceptions=True), timeout=timeout_seconds
)
results.extend(batch_results)
else:
# 一次性执行所有任务
results = await asyncio.wait_for(
asyncio.gather(*tasks, return_exceptions=True), timeout=timeout_seconds
)
# 处理结果并收集性能数据
context_data = {}
for i, result in enumerate(results):
task_name = task_names[i] if i < len(task_names) else f"task_{i}"
if isinstance(result, Exception):
logger.error(f"构建任务{task_name}失败: {str(result)}")
elif isinstance(result, dict):
# 结果格式: {component_name: value}
context_data.update(result)
# 记录耗时过长的任务
if task_name in timing_logs and timing_logs[task_name] > 8.0:
logger.warning(f"构建任务{task_name}耗时过长: {timing_logs[task_name]:.2f}s")
# 添加预构建的参数
for key, value in pre_built_params.items():
if value:
context_data[key] = value
except asyncio.TimeoutError:
logger.error(f"构建超时 ({timeout_seconds}s)")
context_data = {}
# 添加预构建的参数,即使在超时情况下
for key, value in pre_built_params.items():
if value:
context_data[key] = value
# 构建聊天历史 - 根据模式不同
if params.prompt_mode == "s4u":
await self._build_s4u_chat_context(context_data, params)
else:
await self._build_normal_chat_context(context_data, params)
# 补充基础信息
context_data.update(
{
"keywords_reaction_prompt": params.keywords_reaction_prompt,
"extra_info_block": params.extra_info_block,
"time_block": params.time_block or f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
"identity": params.identity_block,
"schedule_block": params.schedule_block,
"moderation_prompt": params.moderation_prompt_block,
"reply_target_block": params.reply_target_block,
"mood_state": params.mood_prompt,
"action_descriptions": params.action_descriptions,
}
)
total_time = time.time() - start_time
if timing_logs:
timing_str = "; ".join([f"{name}: {time:.2f}s" for name, time in timing_logs.items()])
logger.info(f"构建任务耗时: {timing_str}")
logger.debug(f"构建完成,总耗时: {total_time:.2f}s")
return context_data
async def _build_s4u_chat_context(self, context_data: Dict[str, Any], params: SmartPromptParameters) -> None:
"""构建S4U模式的聊天上下文 - 使用新参数结构"""
if not params.message_list_before_now_long:
return
# 使用共享工具构建分离历史
core_dialogue, background_dialogue = await self._build_s4u_chat_history_prompts(
params.message_list_before_now_long,
params.target_user_info.get("user_id") if params.target_user_info else "",
params.sender
)
context_data["core_dialogue_prompt"] = core_dialogue
context_data["background_dialogue_prompt"] = background_dialogue
async def _build_normal_chat_context(self, context_data: Dict[str, Any], params: SmartPromptParameters) -> None:
"""构建normal模式的聊天上下文 - 使用新参数结构"""
if not params.chat_talking_prompt_short:
return
context_data["chat_info"] = f"""群里的聊天内容:
{params.chat_talking_prompt_short}"""
async def _build_s4u_chat_history_prompts(
self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str
) -> Tuple[str, str]:
"""构建S4U风格的分离对话prompt - 完整实现"""
core_dialogue_list = []
bot_id = str(global_config.bot.qq_account)
# 过滤消息分离bot和目标用户的对话 vs 其他用户的对话
for msg_dict in message_list_before_now:
try:
msg_user_id = str(msg_dict.get("user_id"))
reply_to = msg_dict.get("reply_to", "")
_platform, reply_to_user_id = self._parse_reply_target(reply_to)
if (msg_user_id == bot_id and reply_to_user_id == target_user_id) or msg_user_id == target_user_id:
# bot 和目标用户的对话
core_dialogue_list.append(msg_dict)
except Exception as e:
logger.error(f"处理消息记录时出错: {msg_dict}, 错误: {e}")
# 构建背景对话 prompt
all_dialogue_prompt = ""
if message_list_before_now:
latest_25_msgs = message_list_before_now[-int(global_config.chat.max_context_size) :]
all_dialogue_prompt_str = build_readable_messages(
latest_25_msgs,
replace_bot_name=True,
timestamp_mode="normal",
truncate=True,
)
all_dialogue_prompt = f"所有用户的发言:\n{all_dialogue_prompt_str}"
# 构建核心对话 prompt
core_dialogue_prompt = ""
if core_dialogue_list:
# 检查最新五条消息中是否包含bot自己说的消息
latest_5_messages = core_dialogue_list[-5:] if len(core_dialogue_list) >= 5 else core_dialogue_list
has_bot_message = any(str(msg.get("user_id")) == bot_id for msg in latest_5_messages)
# logger.info(f"最新五条消息:{latest_5_messages}")
# logger.info(f"最新五条消息中是否包含bot自己说的消息{has_bot_message}")
# 如果最新五条消息中不包含bot的消息则返回空字符串
if not has_bot_message:
core_dialogue_prompt = ""
else:
core_dialogue_list = core_dialogue_list[-int(global_config.chat.max_context_size * 2) :] # 限制消息数量
core_dialogue_prompt_str = build_readable_messages(
core_dialogue_list,
replace_bot_name=True,
merge_messages=False,
timestamp_mode="normal_no_YMD",
read_mark=0.0,
truncate=True,
show_actions=True,
)
core_dialogue_prompt = f"""--------------------------------
这是你和{sender}的对话,你们正在交流中:
{core_dialogue_prompt_str}
--------------------------------
"""
return core_dialogue_prompt, all_dialogue_prompt
async def _build_mai_think_context(self, params: SmartPromptParameters) -> Any:
"""构建mai_think上下文 - 完全继承DefaultReplyer功能"""
from src.mais4u.mai_think import mai_thinking_manager
# 获取mai_think实例
mai_think = mai_thinking_manager.get_mai_think(params.chat_id)
# 设置mai_think的上下文信息
mai_think.memory_block = params.memory_block or ""
mai_think.relation_info_block = params.relation_info_block or ""
mai_think.time_block = params.time_block or f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
# 设置聊天目标信息
if params.is_group_chat:
chat_target_1 = await global_prompt_manager.get_prompt_async("chat_target_group1")
chat_target_2 = await global_prompt_manager.get_prompt_async("chat_target_group2")
else:
chat_target_name = "对方"
if params.chat_target_info:
chat_target_name = (
params.chat_target_info.get("person_name") or params.chat_target_info.get("user_nickname") or "对方"
)
chat_target_1 = await global_prompt_manager.format_prompt(
"chat_target_private1", sender_name=chat_target_name
)
chat_target_2 = await global_prompt_manager.format_prompt(
"chat_target_private2", sender_name=chat_target_name
)
mai_think.chat_target = chat_target_1
mai_think.chat_target_2 = chat_target_2
mai_think.chat_info = params.chat_talking_prompt_short or ""
mai_think.mood_state = params.mood_prompt or ""
mai_think.identity = params.identity_block or ""
mai_think.sender = params.sender
mai_think.target = params.target
# 返回mai_think实例以便后续使用
return mai_think
def _parse_reply_target_id(self, reply_to: str) -> str:
"""解析回复目标中的用户ID"""
if not reply_to:
return ""
# 复用_parse_reply_target方法的逻辑
sender, _ = self._parse_reply_target(reply_to)
if not sender:
return ""
# 获取用户ID
person_info_manager = get_person_info_manager()
person_id = person_info_manager.get_person_id_by_person_name(sender)
if person_id:
user_id = person_info_manager.get_value_sync(person_id, "user_id")
return str(user_id) if user_id else ""
async def _build_expression_habits(self, params: SmartPromptParameters) -> Dict[str, Any]:
"""构建表达习惯 - 使用共享工具类完全继承DefaultReplyer功能"""
# 检查是否允许在此聊天流中使用表达
use_expression, _, _ = global_config.expression.get_expression_config_for_chat(params.chat_id)
if not use_expression:
return {"expression_habits_block": ""}
from src.chat.express.expression_selector import expression_selector
style_habits = []
grammar_habits = []
# 使用从处理器传来的选中表达方式
# LLM模式调用LLM选择5-10个然后随机选5个
try:
selected_expressions = await expression_selector.select_suitable_expressions_llm(
params.chat_id, params.chat_talking_prompt_short, max_num=8, min_num=2, target_message=params.target
)
except Exception as e:
logger.error(f"选择表达方式失败: {e}")
selected_expressions = []
if selected_expressions:
logger.debug(f"使用处理器选中的{len(selected_expressions)}个表达方式")
for expr in selected_expressions:
if isinstance(expr, dict) and "situation" in expr and "style" in expr:
expr_type = expr.get("type", "style")
if expr_type == "grammar":
grammar_habits.append(f"{expr['situation']}时,使用 {expr['style']}")
else:
style_habits.append(f"{expr['situation']}时,使用 {expr['style']}")
else:
logger.debug("没有从处理器获得表达方式,将使用空的表达方式")
# 不再在replyer中进行随机选择全部交给处理器处理
style_habits_str = "\n".join(style_habits)
grammar_habits_str = "\n".join(grammar_habits)
# 动态构建expression habits块
expression_habits_block = ""
expression_habits_title = ""
if style_habits_str.strip():
expression_habits_title = (
"你可以参考以下的语言习惯,当情景合适就使用,但不要生硬使用,以合理的方式结合到你的回复中:"
)
expression_habits_block += f"{style_habits_str}\n"
if grammar_habits_str.strip():
expression_habits_title = (
"你可以选择下面的句法进行回复,如果情景合适就使用,不要盲目使用,不要生硬使用,以合理的方式使用:"
)
expression_habits_block += f"{grammar_habits_str}\n"
if style_habits_str.strip() and grammar_habits_str.strip():
expression_habits_title = "你可以参考以下的语言习惯和句法,如果情景合适就使用,不要盲目使用,不要生硬使用,以合理的方式结合到你的回复中。"
return {"expression_habits_block": f"{expression_habits_title}\n{expression_habits_block}"}
async def _build_memory_block(self, params: SmartPromptParameters) -> Dict[str, Any]:
"""构建记忆块 - 使用共享工具类完全继承DefaultReplyer功能"""
if not global_config.memory.enable_memory:
return {"memory_block": ""}
from src.chat.memory_system.memory_activator import MemoryActivator
from src.chat.memory_system.vector_instant_memory import VectorInstantMemoryV2
instant_memory = None
# 初始化记忆激活器
try:
memory_activator = MemoryActivator()
# 获取长期记忆
running_memories = await memory_activator.activate_memory_with_chat_history(
target_message=params.target, chat_history_prompt=params.chat_talking_prompt_short
)
except Exception as e:
logger.error(f"激活记忆失败: {e}")
running_memories = []
# 处理瞬时记忆
if global_config.memory.enable_instant_memory:
# 使用异步记忆包装器(最优化的非阻塞模式)
try:
from src.chat.memory_system.async_instant_memory_wrapper import get_async_instant_memory
# 获取异步记忆包装器
async_memory = get_async_instant_memory(params.chat_id)
# 后台存储聊天历史(完全非阻塞)
async_memory.store_memory_background(params.chat_talking_prompt_short)
# 快速检索记忆最大超时2秒
instant_memory = await async_memory.get_memory_with_fallback(params.target, max_timeout=2.0)
logger.info(f"异步瞬时记忆:{instant_memory}")
except ImportError:
# 如果异步包装器不可用,尝试使用异步记忆管理器
try:
from src.chat.memory_system.async_memory_optimizer import (
retrieve_memory_nonblocking,
store_memory_nonblocking,
)
# 异步存储聊天历史(非阻塞)
asyncio.create_task(
store_memory_nonblocking(chat_id=params.chat_id, content=params.chat_talking_prompt_short)
)
# 尝试从缓存获取瞬时记忆
instant_memory = await retrieve_memory_nonblocking(chat_id=params.chat_id, query=params.target)
# 如果没有缓存结果,快速检索一次
if instant_memory is None:
try:
# 使用VectorInstantMemoryV2实例
instant_memory_system = VectorInstantMemoryV2(chat_id=params.chat_id, retention_hours=1)
instant_memory = await asyncio.wait_for(
instant_memory_system.get_memory_for_context(params.target), timeout=1.5
)
except asyncio.TimeoutError:
logger.warning("瞬时记忆检索超时,使用空结果")
instant_memory = ""
logger.info(f"向量瞬时记忆:{instant_memory}")
except ImportError:
# 最后的fallback使用原有逻辑但加上超时控制
logger.warning("异步记忆系统不可用,使用带超时的同步方式")
# 使用VectorInstantMemoryV2实例
instant_memory_system = VectorInstantMemoryV2(chat_id=params.chat_id, retention_hours=1)
# 异步存储聊天历史
asyncio.create_task(instant_memory_system.store_message(params.chat_talking_prompt_short))
# 带超时的记忆检索
try:
instant_memory = await asyncio.wait_for(
instant_memory_system.get_memory_for_context(params.target),
timeout=1.0, # 最保守的1秒超时
)
except asyncio.TimeoutError:
logger.warning("瞬时记忆检索超时,跳过记忆获取")
instant_memory = ""
except Exception as e:
logger.error(f"瞬时记忆检索失败: {e}")
instant_memory = ""
logger.info(f"同步瞬时记忆:{instant_memory}")
except Exception as e:
logger.error(f"瞬时记忆系统异常: {e}")
instant_memory = ""
# 构建记忆字符串,即使某种记忆为空也要继续
memory_str = ""
has_any_memory = False
# 添加长期记忆
if running_memories:
if not memory_str:
memory_str = "以下是当前在聊天中,你回忆起的记忆:\n"
for running_memory in running_memories:
memory_str += f"- {running_memory['content']}\n"
has_any_memory = True
# 添加瞬时记忆
if instant_memory:
if not memory_str:
memory_str = "以下是当前在聊天中,你回忆起的记忆:\n"
memory_str += f"- {instant_memory}\n"
has_any_memory = True
# 注入视频分析结果引导语
memory_str = self._inject_video_prompt_if_needed(params.target, memory_str)
# 只有当完全没有任何记忆时才返回空字符串
return {"memory_block": memory_str if has_any_memory else ""}
def _inject_video_prompt_if_needed(self, target: str, memory_str: str) -> str:
"""统一视频分析结果注入逻辑"""
if target and ("[视频内容]" in target or "好的,我将根据您提供的" in target):
video_prompt_injection = (
"\n请注意,以上内容是你刚刚观看的视频,请以第一人称分享你的观后感,而不是在分析一份报告。"
)
return memory_str + video_prompt_injection
return memory_str
async def _build_relation_info(self, params: SmartPromptParameters) -> Dict[str, Any]:
"""构建关系信息 - 使用共享工具类"""
try:
relation_info = await PromptUtils.build_relation_info(params.chat_id, params.reply_to)
return {"relation_info_block": relation_info}
except Exception as e:
logger.error(f"构建关系信息失败: {e}")
return {"relation_info_block": ""}
async def _build_tool_info(self, params: SmartPromptParameters) -> Dict[str, Any]:
"""构建工具信息 - 使用共享工具类完全继承DefaultReplyer功能"""
if not params.enable_tool:
return {"tool_info_block": ""}
if not params.reply_to:
return {"tool_info_block": ""}
sender, text = PromptUtils.parse_reply_target(params.reply_to)
if not text:
return {"tool_info_block": ""}
from src.plugin_system.core.tool_use import ToolExecutor
# 使用工具执行器获取信息
try:
tool_executor = ToolExecutor(chat_id=params.chat_id)
tool_results, _, _ = await tool_executor.execute_from_chat_message(
sender=sender, target_message=text, chat_history=params.chat_talking_prompt_short, return_details=False
)
if tool_results:
tool_info_str = "以下是你通过工具获取到的实时信息:\n"
for tool_result in tool_results:
tool_name = tool_result.get("tool_name", "unknown")
content = tool_result.get("content", "")
result_type = tool_result.get("type", "tool_result")
tool_info_str += f"- 【{tool_name}{result_type}: {content}\n"
tool_info_str += "以上是你获取到的实时信息,请在回复时参考这些信息。"
logger.info(f"获取到 {len(tool_results)} 个工具结果")
return {"tool_info_block": tool_info_str}
else:
logger.debug("未获取到任何工具结果")
return {"tool_info_block": ""}
except Exception as e:
logger.error(f"工具信息获取失败: {e}")
return {"tool_info_block": ""}
async def _build_knowledge_info(self, params: SmartPromptParameters) -> Dict[str, Any]:
"""构建知识信息 - 使用共享工具类完全继承DefaultReplyer功能"""
if not params.reply_to:
logger.debug("没有回复对象,跳过获取知识库内容")
return {"knowledge_prompt": ""}
sender, content = PromptUtils.parse_reply_target(params.reply_to)
if not content:
logger.debug("回复对象内容为空,跳过获取知识库内容")
return {"knowledge_prompt": ""}
logger.debug(
f"获取知识库内容,元消息:{params.chat_talking_prompt_short[:30]}...,消息长度: {len(params.chat_talking_prompt_short)}"
)
# 从LPMM知识库获取知识
try:
# 检查LPMM知识库是否启用
if not global_config.lpmm_knowledge.enable:
logger.debug("LPMM知识库未启用跳过获取知识库内容")
return {"knowledge_prompt": ""}
from src.plugins.built_in.knowledge.lpmm_get_knowledge import SearchKnowledgeFromLPMMTool
from src.plugin_system.apis import llm_api
from src.config.config import model_config
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
bot_name = global_config.bot.nickname
prompt = await global_prompt_manager.format_prompt(
"lpmm_get_knowledge_prompt",
bot_name=bot_name,
time_now=time_now,
chat_history=params.chat_talking_prompt_short,
sender=sender,
target_message=content,
)
_, _, _, _, tool_calls = await llm_api.generate_with_model_with_tools(
prompt,
model_config=model_config.model_task_config.tool_use,
tool_options=[SearchKnowledgeFromLPMMTool.get_tool_definition()],
)
if tool_calls:
from src.plugin_system.core.tool_use import ToolExecutor
tool_executor = ToolExecutor(chat_id=params.chat_id)
result = await tool_executor.execute_tool_call(tool_calls[0], SearchKnowledgeFromLPMMTool())
if not result or not result.get("content"):
logger.debug("从LPMM知识库获取知识失败返回空知识...")
return {"knowledge_prompt": ""}
found_knowledge_from_lpmm = result.get("content", "")
logger.debug(
f"从LPMM知识库获取知识相关信息{found_knowledge_from_lpmm[:100]}...,信息长度: {len(found_knowledge_from_lpmm)}"
)
return {
"knowledge_prompt": f"你有以下这些**知识**\n{found_knowledge_from_lpmm}\n请你**记住上面的知识**,之后可能会用到。\n"
}
else:
logger.debug("从LPMM知识库获取知识失败可能是从未导入过知识返回空知识...")
return {"knowledge_prompt": ""}
except Exception as e:
logger.error(f"获取知识库内容时发生异常: {str(e)}")
return {"knowledge_prompt": ""}
async def _build_cross_context(self, params: SmartPromptParameters) -> Dict[str, Any]:
"""构建跨群上下文 - 使用共享工具类"""
try:
cross_context = await PromptUtils.build_cross_context(
params.chat_id, params.prompt_mode, params.target_user_info
)
return {"cross_context_block": cross_context}
except Exception as e:
logger.error(f"构建跨群上下文失败: {e}")
return {"cross_context_block": ""}
def _parse_reply_target(self, target_message: str) -> Tuple[str, str]:
"""解析回复目标消息 - 使用共享工具类"""
return PromptUtils.parse_reply_target(target_message)
class SmartPrompt:
"""重构的智能提示词核心类 - 移除缓存机制和依赖检查,简化架构"""
def __init__(
self,
template_name: Optional[str] = None,
parameters: Optional[SmartPromptParameters] = None,
):
self.parameters = parameters or SmartPromptParameters()
self.template_name = template_name or self._get_default_template()
self.builder = SmartPromptBuilder()
def _get_default_template(self) -> str:
"""根据模式选择默认模板"""
if self.parameters.prompt_mode == "s4u":
return "s4u_style_prompt"
elif self.parameters.prompt_mode == "normal":
return "normal_style_prompt"
else:
return "default_expressor_prompt"
async def build_prompt(self) -> str:
"""构建最终的Prompt文本 - 移除缓存机制和依赖检查"""
# 参数验证
errors = self.parameters.validate()
if errors:
logger.error(f"参数验证失败: {', '.join(errors)}")
raise ValueError(f"参数验证失败: {', '.join(errors)}")
start_time = time.time()
try:
# 构建基础上下文的完整映射
context_data = await self.builder.build_context_data(self.parameters)
# 检查关键上下文数据
if not context_data or not isinstance(context_data, dict):
logger.error("构建的上下文数据无效")
raise ValueError("构建的上下文数据无效")
# 获取模板
template = await self._get_template()
if template is None:
logger.error("无法获取模板")
raise ValueError("无法获取模板")
# 根据模式传递不同的参数
if self.parameters.prompt_mode == "s4u":
result = await self._build_s4u_prompt(template, context_data)
elif self.parameters.prompt_mode == "normal":
result = await self._build_normal_prompt(template, context_data)
else:
result = await self._build_default_prompt(template, context_data)
# 记录性能数据
total_time = time.time() - start_time
logger.debug(f"SmartPrompt构建完成模式: {self.parameters.prompt_mode}, 耗时: {total_time:.2f}s")
return result
except asyncio.TimeoutError as e:
logger.error(f"构建Prompt超时: {e}")
raise TimeoutError(f"构建Prompt超时: {e}")
except Exception as e:
logger.error(f"构建Prompt失败: {e}")
raise RuntimeError(f"构建Prompt失败: {e}")
async def _get_template(self) -> Optional[Prompt]:
"""获取模板"""
try:
return await global_prompt_manager.get_prompt_async(self.template_name)
except Exception as e:
logger.error(f"获取模板 {self.template_name} 失败: {e}")
raise RuntimeError(f"获取模板 {self.template_name} 失败: {e}")
async def _build_s4u_prompt(self, template: Prompt, context_data: Dict[str, Any]) -> str:
"""构建S4U模式的完整Prompt - 使用新参数结构"""
params = {
**context_data,
"expression_habits_block": context_data.get("expression_habits_block", ""),
"tool_info_block": context_data.get("tool_info_block", ""),
"knowledge_prompt": context_data.get("knowledge_prompt", ""),
"memory_block": context_data.get("memory_block", ""),
"relation_info_block": context_data.get("relation_info_block", ""),
"extra_info_block": self.parameters.extra_info_block or context_data.get("extra_info_block", ""),
"cross_context_block": context_data.get("cross_context_block", ""),
"identity": self.parameters.identity_block or context_data.get("identity", ""),
"action_descriptions": self.parameters.action_descriptions or context_data.get("action_descriptions", ""),
"sender_name": self.parameters.sender,
"mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""),
"background_dialogue_prompt": context_data.get("background_dialogue_prompt", ""),
"time_block": context_data.get("time_block", ""),
"core_dialogue_prompt": context_data.get("core_dialogue_prompt", ""),
"reply_target_block": context_data.get("reply_target_block", ""),
"reply_style": global_config.personality.reply_style,
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt
or context_data.get("keywords_reaction_prompt", ""),
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
}
return await global_prompt_manager.format_prompt(self.template_name, **params)
async def _build_normal_prompt(self, template: Prompt, context_data: Dict[str, Any]) -> str:
"""构建Normal模式的完整Prompt - 使用新参数结构"""
params = {
**context_data,
"expression_habits_block": context_data.get("expression_habits_block", ""),
"tool_info_block": context_data.get("tool_info_block", ""),
"knowledge_prompt": context_data.get("knowledge_prompt", ""),
"memory_block": context_data.get("memory_block", ""),
"relation_info_block": context_data.get("relation_info_block", ""),
"extra_info_block": self.parameters.extra_info_block or context_data.get("extra_info_block", ""),
"cross_context_block": context_data.get("cross_context_block", ""),
"identity": self.parameters.identity_block or context_data.get("identity", ""),
"action_descriptions": self.parameters.action_descriptions or context_data.get("action_descriptions", ""),
"schedule_block": self.parameters.schedule_block or context_data.get("schedule_block", ""),
"time_block": context_data.get("time_block", ""),
"chat_info": context_data.get("chat_info", ""),
"reply_target_block": context_data.get("reply_target_block", ""),
"config_expression_style": global_config.personality.reply_style,
"mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""),
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt
or context_data.get("keywords_reaction_prompt", ""),
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
}
return await global_prompt_manager.format_prompt(self.template_name, **params)
async def _build_default_prompt(self, template: Prompt, context_data: Dict[str, Any]) -> str:
"""构建默认模式的Prompt - 使用新参数结构"""
params = {
"expression_habits_block": context_data.get("expression_habits_block", ""),
"relation_info_block": context_data.get("relation_info_block", ""),
"chat_target": "",
"time_block": context_data.get("time_block", ""),
"chat_info": context_data.get("chat_info", ""),
"identity": self.parameters.identity_block or context_data.get("identity", ""),
"chat_target_2": "",
"reply_target_block": context_data.get("reply_target_block", ""),
"raw_reply": self.parameters.target,
"reason": "",
"mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""),
"reply_style": global_config.personality.reply_style,
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt
or context_data.get("keywords_reaction_prompt", ""),
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
}
return await global_prompt_manager.format_prompt(self.template_name, **params)
# 工厂函数 - 简化创建 - 更新参数结构
def create_smart_prompt(
chat_id: str = "", sender_name: str = "", target_message: str = "", reply_to: str = "", **kwargs
) -> SmartPrompt:
"""快速创建智能Prompt实例的工厂函数 - 使用新参数结构"""
# 使用新的参数结构
parameters = SmartPromptParameters(
chat_id=chat_id, sender=sender_name, target=target_message, reply_to=reply_to, **kwargs
)
return SmartPrompt(parameters=parameters)
class SmartPromptHealthChecker:
"""SmartPrompt健康检查器 - 移除依赖检查"""
@staticmethod
async def check_system_health() -> Dict[str, Any]:
"""检查系统健康状态 - 移除依赖检查"""
health_status = {"status": "healthy", "components": {}, "issues": []}
try:
# 检查配置
try:
from src.config.config import global_config
health_status["components"]["config"] = "ok"
# 检查关键配置项
if not hasattr(global_config, "personality") or not hasattr(global_config.personality, "prompt_mode"):
health_status["issues"].append("缺少personality.prompt_mode配置")
health_status["status"] = "degraded"
if not hasattr(global_config, "memory") or not hasattr(global_config.memory, "enable_memory"):
health_status["issues"].append("缺少memory.enable_memory配置")
except Exception as e:
health_status["components"]["config"] = f"failed: {str(e)}"
health_status["issues"].append("配置加载失败")
health_status["status"] = "unhealthy"
# 检查Prompt模板
try:
required_templates = ["s4u_style_prompt", "normal_style_prompt", "default_expressor_prompt"]
for template_name in required_templates:
try:
await global_prompt_manager.get_prompt_async(template_name)
health_status["components"][f"template_{template_name}"] = "ok"
except Exception as e:
health_status["components"][f"template_{template_name}"] = f"failed: {str(e)}"
health_status["issues"].append(f"模板{template_name}加载失败")
health_status["status"] = "degraded"
except Exception as e:
health_status["components"]["prompt_templates"] = f"failed: {str(e)}"
health_status["issues"].append("Prompt模板检查失败")
health_status["status"] = "unhealthy"
return health_status
except Exception as e:
return {"status": "unhealthy", "components": {}, "issues": [f"健康检查异常: {str(e)}"]}
@staticmethod
async def run_performance_test() -> Dict[str, Any]:
"""运行性能测试"""
test_results = {"status": "completed", "tests": {}, "summary": {}}
try:
# 创建测试参数
test_params = SmartPromptParameters(
chat_id="test_chat",
sender="test_user",
target="test_message",
reply_to="test_user:test_message",
prompt_mode="s4u",
)
# 测试不同模式下的构建性能
modes = ["s4u", "normal", "minimal"]
for mode in modes:
test_params.prompt_mode = mode
smart_prompt = SmartPrompt(parameters=test_params)
# 运行多次测试取平均值
times = []
for _ in range(3):
start_time = time.time()
try:
await smart_prompt.build_prompt()
end_time = time.time()
times.append(end_time - start_time)
except Exception as e:
times.append(float("inf"))
logger.error(f"性能测试失败 (模式: {mode}): {e}")
# 计算统计信息
valid_times = [t for t in times if t != float("inf")]
if valid_times:
avg_time = sum(valid_times) / len(valid_times)
min_time = min(valid_times)
max_time = max(valid_times)
test_results["tests"][mode] = {
"avg_time": avg_time,
"min_time": min_time,
"max_time": max_time,
"success_rate": len(valid_times) / len(times),
}
else:
test_results["tests"][mode] = {
"avg_time": float("inf"),
"min_time": float("inf"),
"max_time": float("inf"),
"success_rate": 0,
}
# 计算总体统计
all_avg_times = [
test["avg_time"] for test in test_results["tests"].values() if test["avg_time"] != float("inf")
]
if all_avg_times:
test_results["summary"] = {
"overall_avg_time": sum(all_avg_times) / len(all_avg_times),
"fastest_mode": min(test_results["tests"].items(), key=lambda x: x[1]["avg_time"])[0],
"slowest_mode": max(test_results["tests"].items(), key=lambda x: x[1]["avg_time"])[0],
}
return test_results
except Exception as e:
return {"status": "failed", "tests": {}, "summary": {}, "error": str(e)}

View File

@@ -4,7 +4,6 @@ import time
import hashlib
import uuid
import io
import asyncio
import numpy as np
from typing import Optional, Tuple, Dict, Any
@@ -35,8 +34,7 @@ def is_image_message(message: Dict[str, Any]) -> bool:
bool: 是否为图片消息
"""
return message.get("type") == "image" or (
isinstance(message.get("content"), dict) and
message["content"].get("type") == "image"
isinstance(message.get("content"), dict) and message["content"].get("type") == "image"
)
@@ -596,7 +594,6 @@ class ImageManager:
return "", "[图片]"
# 创建全局单例
image_manager = None

View File

@@ -361,6 +361,7 @@ class GraphNodes(Base):
concept = Column(get_string_field(255), nullable=False, unique=True, index=True)
memory_items = Column(Text, nullable=False)
hash = Column(Text, nullable=False)
weight = Column(Float, nullable=False, default=1.0)
created_time = Column(Float, nullable=False)
last_modified = Column(Float, nullable=False)

View File

@@ -443,6 +443,12 @@ MODULE_COLORS = {
"manifest_utils": "\033[38;5;39m", # 蓝色
"schedule_manager": "\033[38;5;27m", # 深蓝色
"monthly_plan_manager": "\033[38;5;171m",
"plan_manager": "\033[38;5;171m",
"llm_generator": "\033[38;5;171m",
"schedule_bridge": "\033[38;5;171m",
"sleep_manager": "\033[38;5;171m",
"official_configs": "\033[38;5;171m",
"mmc_com_layer": "\033[38;5;67m",
# 聊天和多媒体扩展
"chat_voice": "\033[38;5;87m", # 浅青色
"typo_gen": "\033[38;5;123m", # 天蓝色
@@ -564,8 +570,14 @@ MODULE_ALIASES = {
"dependency_config": "依赖配置",
"dependency_manager": "依赖管理",
"manifest_utils": "清单工具",
"schedule_manager": "计划管理",
"monthly_plan_manager": "月度计划",
"schedule_manager": "规划系统-日程表管理",
"monthly_plan_manager": "规划系统-月度计划",
"plan_manager": "规划系统-计划管理",
"llm_generator": "规划系统-LLM生成",
"schedule_bridge": "计划桥接",
"sleep_manager": "睡眠管理",
"official_configs": "官方配置",
"mmc_com_layer": "MMC通信层",
# 聊天和多媒体扩展
"chat_voice": "语音处理",
"typo_gen": "错字生成",

View File

@@ -1,5 +1,6 @@
from typing import List, Dict, Any, Literal
from typing import List, Dict, Any, Literal, Union
from pydantic import Field, field_validator
from threading import Lock
from src.config.config_base import ValidatedConfigBase
@@ -9,7 +10,7 @@ class APIProvider(ValidatedConfigBase):
name: str = Field(..., min_length=1, description="API提供商名称")
base_url: str = Field(..., description="API基础URL")
api_key: str = Field(..., min_length=1, description="API密钥")
api_key: Union[str, List[str]] = Field(..., min_length=1, description="API密钥,支持单个密钥或密钥列表轮询")
client_type: Literal["openai", "gemini", "aiohttp_gemini"] = Field(
default="openai", description="客户端类型如openai/google等默认为openai"
)
@@ -33,12 +34,33 @@ class APIProvider(ValidatedConfigBase):
@classmethod
def validate_api_key(cls, v):
"""验证API密钥不能为空"""
if not v or not v.strip():
if isinstance(v, str):
if not v.strip():
raise ValueError("API密钥不能为空")
elif isinstance(v, list):
if not v:
raise ValueError("API密钥列表不能为空")
for key in v:
if not isinstance(key, str) or not key.strip():
raise ValueError("API密钥列表中的密钥不能为空")
else:
raise ValueError("API密钥必须是字符串或字符串列表")
return v
def __init__(self, **data):
super().__init__(**data)
self._api_key_lock = Lock()
self._api_key_index = 0
def get_api_key(self) -> str:
with self._api_key_lock:
if isinstance(self.api_key, str):
return self.api_key
if not self.api_key:
raise ValueError("API密钥列表为空")
key = self.api_key[self._api_key_index]
self._api_key_index = (self._api_key_index + 1) % len(self.api_key)
return key
class ModelInfo(ValidatedConfigBase):
@@ -113,6 +135,7 @@ class ModelTaskConfig(ValidatedConfigBase):
voice: TaskConfig = Field(..., description="语音识别模型配置")
tool_use: TaskConfig = Field(..., description="专注工具使用模型配置")
planner: TaskConfig = Field(..., description="规划模型配置")
planner_small: TaskConfig = Field(..., description="小脑sub-planner规划模型配置")
embedding: TaskConfig = Field(..., description="嵌入模型配置")
lpmm_entity_extract: TaskConfig = Field(..., description="LPMM实体提取模型配置")
lpmm_rdf_build: TaskConfig = Field(..., description="LPMM RDF构建模型配置")
@@ -147,9 +170,9 @@ class ModelTaskConfig(ValidatedConfigBase):
class APIAdapterConfig(ValidatedConfigBase):
"""API Adapter配置类"""
models: List[ModelInfo] = Field(..., min_items=1, description="模型列表")
models: List[ModelInfo] = Field(..., min_length=1, description="模型列表")
model_task_config: ModelTaskConfig = Field(..., description="模型任务配置")
api_providers: List[APIProvider] = Field(..., min_items=1, description="API提供商列表")
api_providers: List[APIProvider] = Field(..., min_length=1, description="API提供商列表")
def __init__(self, **data):
super().__init__(**data)

View File

@@ -35,17 +35,16 @@ from src.config.official_configs import (
VoiceConfig,
DebugConfig,
CustomPromptConfig,
ScheduleConfig,
VideoAnalysisConfig,
DependencyManagementConfig,
WebSearchConfig,
AntiPromptInjectionConfig,
SleepSystemConfig,
MonthlyPlanSystemConfig,
CrossContextConfig,
PermissionConfig,
CommandConfig,
MaizoneIntercomConfig,
PlanningSystemConfig,
)
from .api_ada_configs import (
@@ -379,7 +378,6 @@ class Config(ValidatedConfigBase):
debug: DebugConfig = Field(..., description="调试配置")
custom_prompt: CustomPromptConfig = Field(..., description="自定义提示配置")
voice: VoiceConfig = Field(..., description="语音配置")
schedule: ScheduleConfig = Field(..., description="调度配置")
permission: PermissionConfig = Field(..., description="权限配置")
command: CommandConfig = Field(..., description="命令系统配置")
@@ -395,8 +393,8 @@ class Config(ValidatedConfigBase):
)
web_search: WebSearchConfig = Field(default_factory=lambda: WebSearchConfig(), description="网络搜索配置")
sleep_system: SleepSystemConfig = Field(default_factory=lambda: SleepSystemConfig(), description="睡眠系统配置")
monthly_plan_system: MonthlyPlanSystemConfig = Field(
default_factory=lambda: MonthlyPlanSystemConfig(), description="月层计划系统配置"
planning_system: PlanningSystemConfig = Field(
default_factory=lambda: PlanningSystemConfig(), description="划系统配置"
)
cross_context: CrossContextConfig = Field(
default_factory=lambda: CrossContextConfig(), description="跨群聊上下文共享配置"

View File

@@ -75,7 +75,7 @@ class ChatConfig(ValidatedConfigBase):
at_bot_inevitable_reply: bool = Field(default=False, description="@机器人的必然回复")
talk_frequency_adjust: list[list[str]] = Field(default_factory=lambda: [], description="聊天频率调整")
focus_value: float = Field(default=1.0, description="专注值")
force_focus_private: bool = Field(default=False, description="强制专注私聊")
force_reply_private: bool = Field(default=False, description="强制回复私聊")
group_chat_mode: Literal["auto", "normal", "focus"] = Field(default="auto", description="群聊模式")
timestamp_display_mode: Literal["normal", "normal_no_YMD", "relative"] = Field(
default="normal_no_YMD", description="时间戳显示模式"
@@ -92,6 +92,7 @@ class ChatConfig(ValidatedConfigBase):
default_factory=list, description="启用主动思考的群聊范围格式platform:group_id为空则不限制"
)
delta_sigma: int = Field(default=120, description="采用正态分布随机时间间隔")
planner_size: float = Field(default=5.0, ge=1.0, description="小脑sub-planner的尺寸决定每个小脑处理多少个action")
def get_current_talk_frequency(self, chat_stream_id: Optional[str] = None) -> float:
"""
@@ -259,7 +260,6 @@ class NormalChatConfig(ValidatedConfigBase):
"""普通聊天配置类"""
class ExpressionRule(ValidatedConfigBase):
"""表达学习规则"""
@@ -519,11 +519,19 @@ class LPMMKnowledgeConfig(ValidatedConfigBase):
embedding_dimension: int = Field(default=1024, description="嵌入维度")
class ScheduleConfig(ValidatedConfigBase):
"""日程配置类"""
class PlanningSystemConfig(ValidatedConfigBase):
"""规划系统配置 (日程与月度计划)"""
enable: bool = Field(default=True, description="启用")
guidelines: Optional[str] = Field(default=None, description="指导方针")
# --- 日程生成 (原 ScheduleConfig) ---
schedule_enable: bool = Field(True, description="是否启用每日日程生成功能")
schedule_guidelines: str = Field("", description="日程生成指导原则")
# --- 月度计划 (原 MonthlyPlanSystemConfig) ---
monthly_plan_enable: bool = Field(True, description="是否启用月度计划系统")
monthly_plan_guidelines: str = Field("", description="月度计划生成指导原则")
max_plans_per_month: int = Field(10, description="每月最多生成的计划数量")
avoid_repetition_days: int = Field(7, description="避免在多少天内重复使用同一个月度计划")
completion_threshold: int = Field(3, description="一个月度计划被使用多少次后算作完成")
class DependencyManagementConfig(ValidatedConfigBase):
@@ -602,6 +610,11 @@ class SleepSystemConfig(ValidatedConfigBase):
"""睡眠系统配置类"""
enable: bool = Field(default=True, description="是否启用睡眠系统")
sleep_by_schedule: bool = Field(default=True, description="是否根据日程表进行睡觉")
fixed_sleep_time: str = Field(default="23:00", description="固定的睡觉时间")
fixed_wake_up_time: str = Field(default="07:00", description="固定的起床时间")
sleep_time_offset_minutes: int = Field(default=15, ge=0, le=60, description="睡觉时间随机偏移量范围(分钟),实际睡觉时间会在±该值范围内随机")
wake_up_time_offset_minutes: int = Field(default=15, ge=0, le=60, description="起床时间随机偏移量范围(分钟),实际起床时间会在±该值范围内随机")
wakeup_threshold: float = Field(default=15.0, ge=1.0, description="唤醒阈值,达到此值时会被唤醒")
private_message_increment: float = Field(default=3.0, ge=0.1, description="私聊消息增加的唤醒度")
group_mention_increment: float = Field(default=2.0, ge=0.1, description="群聊艾特增加的唤醒度")
@@ -615,7 +628,12 @@ class SleepSystemConfig(ValidatedConfigBase):
# --- 失眠机制相关参数 ---
enable_insomnia_system: bool = Field(default=True, description="是否启用失眠系统")
insomnia_duration_minutes: int = Field(default=30, ge=1, description="单次失眠状态的持续时间(分钟)")
insomnia_trigger_delay_minutes: List[int] = Field(
default_factory=lambda:[30, 60], description="入睡后触发失眠判定的延迟时间范围(分钟)"
)
insomnia_duration_minutes: List[int] = Field(
default_factory=lambda:[15, 45], description="单次失眠状态的持续时间范围(分钟)"
)
sleep_pressure_threshold: float = Field(default=30.0, description="触发“压力不足型失眠”的睡眠压力阈值")
deep_sleep_threshold: float = Field(default=80.0, description="进入“深度睡眠”的睡眠压力阈值")
insomnia_chance_low_pressure: float = Field(default=0.6, ge=0.0, le=1.0, description="压力不足时的失眠基础概率")
@@ -638,22 +656,13 @@ class SleepSystemConfig(ValidatedConfigBase):
)
class MonthlyPlanSystemConfig(ValidatedConfigBase):
"""月度计划系统配置类"""
enable: bool = Field(default=True, description="是否启用本功能")
max_plans_per_month: int = Field(default=20, ge=1, description="每个月允许存在的最大计划数量")
completion_threshold: int = Field(default=3, ge=1, description="计划使用多少次后自动标记为已完成")
avoid_repetition_days: int = Field(default=7, ge=1, description="多少天内不重复抽取同一个计划")
guidelines: Optional[str] = Field(default=None, description="月度计划生成的指导原则")
class ContextGroup(ValidatedConfigBase):
"""上下文共享组配置"""
name: str = Field(..., description="共享组的名称")
chat_ids: List[List[str]] = Field(
..., description='属于该组的聊天ID列表格式为 [["type", "chat_id"], ...],例如 [["group", "123456"], ["private", "789012"]]'
...,
description='属于该组的聊天ID列表格式为 [["type", "chat_id"], ...],例如 [["group", "123456"], ["private", "789012"]]',
)

View File

@@ -8,6 +8,7 @@ from maim_message import MessageServer
from src.common.remote import TelemetryHeartBeatTask
from src.manager.async_task_manager import async_task_manager
from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
from src.common.remote import TelemetryHeartBeatTask
from src.chat.emoji_system.emoji_manager import get_emoji_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.config.config import global_config
@@ -31,28 +32,49 @@ from src.plugin_system.core.plugin_hot_reload import hot_reload_manager
from src.common.message import get_global_api
from src.chat.memory_system.Hippocampus import hippocampus_manager
if not global_config.memory.enable_memory:
import src.chat.memory_system.Hippocampus as hippocampus_module
class MockHippocampusManager:
def initialize(self):
pass
def get_hippocampus(self):
return None
async def build_memory(self):
pass
async def forget_memory(self, percentage: float = 0.005):
pass
async def consolidate_memory(self):
pass
async def get_memory_from_text(self, text: str, max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3, fast_retrieval: bool = False) -> list:
async def get_memory_from_text(
self,
text: str,
max_memory_num: int = 3,
max_memory_length: int = 2,
max_depth: int = 3,
fast_retrieval: bool = False,
) -> list:
return []
async def get_memory_from_topic(self, valid_keywords: list[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3) -> list:
async def get_memory_from_topic(
self, valid_keywords: list[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3
) -> list:
return []
async def get_activate_from_text(self, text: str, max_depth: int = 3, fast_retrieval: bool = False) -> tuple[float, list[str]]:
async def get_activate_from_text(
self, text: str, max_depth: int = 3, fast_retrieval: bool = False
) -> tuple[float, list[str]]:
return 0.0, []
def get_memory_from_keyword(self, keyword: str, max_depth: int = 2) -> list:
return []
def get_all_node_names(self) -> list:
return []
@@ -93,6 +115,9 @@ class MainSystem:
"""清理资源"""
try:
# 停止消息重组器
from src.plugin_system.core.event_manager import event_manager
from src.plugin_system import EventType
asyncio.run(event_manager.trigger_event(EventType.ON_STOP,plugin_name="SYSTEM"))
from src.utils.message_chunker import reassembler
import asyncio
@@ -211,7 +236,6 @@ MoFox_Bot(第三方修改版)
get_emoji_manager().initialize()
logger.info("表情包管理器初始化成功")
# 启动情绪管理器
await mood_manager.start()
logger.info("情绪管理器初始化成功")
@@ -251,7 +275,7 @@ MoFox_Bot(第三方修改版)
await self.individuality.initialize()
# 初始化月度计划管理器
if global_config.monthly_plan_system.enable:
if global_config.planning_system.monthly_plan_enable:
logger.info("正在初始化月度计划管理器...")
try:
await monthly_plan_manager.start_monthly_plan_generation()
@@ -260,7 +284,7 @@ MoFox_Bot(第三方修改版)
logger.error(f"月度计划管理器初始化失败: {e}")
# 初始化日程管理器
if global_config.schedule.enable:
if global_config.planning_system.schedule_enable:
logger.info("日程表功能已启用,正在初始化管理器...")
await schedule_manager.load_or_generate_today_schedule()
await schedule_manager.start_daily_schedule_generation()

View File

@@ -1,6 +1,6 @@
from src.chat.message_receive.chat_stream import get_chat_manager
import time
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
from src.chat.message_receive.message import MessageRecvS4U

View File

@@ -7,7 +7,7 @@ from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
from src.config.config import global_config, model_config
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.manager.async_task_manager import AsyncTask, async_task_manager
from src.plugin_system.apis import send_api

View File

@@ -7,7 +7,7 @@ from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
from src.config.config import global_config, model_config
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.manager.async_task_manager import AsyncTask, async_task_manager
from src.plugin_system.apis import send_api
from src.mais4u.constant_s4u import ENABLE_S4U

View File

@@ -1,6 +1,6 @@
from src.config.config import global_config
from src.common.logger import get_logger
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
import time
from src.chat.utils.utils import get_recent_group_speaker

View File

@@ -6,7 +6,7 @@ from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.chat.message_receive.message import MessageRecv
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
from src.llm_models.utils_model import LLMRequest
from src.manager.async_task_manager import AsyncTask, async_task_manager

View File

@@ -150,6 +150,18 @@ class PersonInfoManager:
# Ensure person_id is correctly set from the argument
final_data["person_id"] = person_id
# 你们的英文注释是何意味?
# 检查并修复关键字段为None的情况喵
if final_data.get("user_id") is None:
logger.warning(f"user_id为None使用'unknown'作为默认值 person_id={person_id}")
final_data["user_id"] = "unknown"
if final_data.get("platform") is None:
logger.warning(f"platform为None使用'unknown'作为默认值 person_id={person_id}")
final_data["platform"] = "unknown"
# 这里的目的是为了防止在识别出错的情况下有一个最小回退,不只是针对@消息识别成视频后的报错问题
# Serialize JSON fields
for key in JSON_SERIALIZED_FIELDS:
@@ -200,6 +212,15 @@ class PersonInfoManager:
# Ensure person_id is correctly set from the argument
final_data["person_id"] = person_id
# 检查并修复关键字段为None的情况
if final_data.get("user_id") is None:
logger.warning(f"user_id为None使用'unknown'作为默认值 person_id={person_id}")
final_data["user_id"] = "unknown"
if final_data.get("platform") is None:
logger.warning(f"platform为None使用'unknown'作为默认值 person_id={person_id}")
final_data["platform"] = "unknown"
# Serialize JSON fields
for key in JSON_SERIALIZED_FIELDS:
if key in final_data:
@@ -296,6 +317,15 @@ class PersonInfoManager:
if data and "user_id" in data:
creation_data["user_id"] = data["user_id"]
# 额外检查关键字段如果为None则使用默认值
if creation_data.get("user_id") is None:
logger.warning(f"创建用户时user_id为None使用'unknown'作为默认值 person_id={person_id}")
creation_data["user_id"] = "unknown"
if creation_data.get("platform") is None:
logger.warning(f"创建用户时platform为None使用'unknown'作为默认值 person_id={person_id}")
creation_data["platform"] = "unknown"
# 使用安全的创建方法,处理竞态条件
await self._safe_create_person_info(person_id, creation_data)

View File

@@ -9,7 +9,7 @@ from json_repair import repair_json
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.person_info.person_info import get_person_info_manager

View File

@@ -25,36 +25,30 @@ def get_context_groups(chat_id: str) -> Optional[List[List[str]]]:
return None
is_group = current_stream.group_info is not None
current_chat_raw_id = (
current_stream.group_info.group_id if is_group else current_stream.user_info.user_id
)
if is_group:
assert current_stream.group_info is not None
current_chat_raw_id = current_stream.group_info.group_id
else:
current_chat_raw_id = current_stream.user_info.user_id
current_type = "group" if is_group else "private"
for group in global_config.cross_context.groups:
# 检查当前聊天的ID和类型是否在组的chat_ids中
if [current_type, str(current_chat_raw_id)] in group.chat_ids:
# 返回组内其他聊天的 [type, id] 列表
return [
chat_info
for chat_info in group.chat_ids
if chat_info != [current_type, str(current_chat_raw_id)]
]
return [chat_info for chat_info in group.chat_ids if chat_info != [current_type, str(current_chat_raw_id)]]
return None
async def build_cross_context_normal(
chat_stream: ChatStream, other_chat_infos: List[List[str]]
) -> str:
async def build_cross_context_normal(chat_stream: ChatStream, other_chat_infos: List[List[str]]) -> str:
"""
构建跨群聊/私聊上下文 (Normal模式)
"""
cross_context_messages = []
for chat_type, chat_raw_id in other_chat_infos:
is_group = chat_type == "group"
stream_id = get_chat_manager().get_stream_id(
chat_stream.platform, chat_raw_id, is_group=is_group
)
stream_id = get_chat_manager().get_stream_id(chat_stream.platform, chat_raw_id, is_group=is_group)
if not stream_id:
continue
@@ -66,9 +60,7 @@ async def build_cross_context_normal(
)
if messages:
chat_name = get_chat_manager().get_stream_name(stream_id) or chat_raw_id
formatted_messages, _ = build_readable_messages_with_id(
messages, timestamp_mode="relative"
)
formatted_messages, _ = build_readable_messages_with_id(messages, timestamp_mode="relative")
cross_context_messages.append(f'[以下是来自"{chat_name}"的近期消息]\n{formatted_messages}')
except Exception as e:
logger.error(f"获取聊天 {chat_raw_id} 的消息失败: {e}")
@@ -95,9 +87,7 @@ async def build_cross_context_s4u(
if user_id:
for chat_type, chat_raw_id in other_chat_infos:
is_group = chat_type == "group"
stream_id = get_chat_manager().get_stream_id(
chat_stream.platform, chat_raw_id, is_group=is_group
)
stream_id = get_chat_manager().get_stream_id(chat_stream.platform, chat_raw_id, is_group=is_group)
if not stream_id:
continue
@@ -112,9 +102,7 @@ async def build_cross_context_s4u(
if user_messages:
chat_name = get_chat_manager().get_stream_name(stream_id) or chat_raw_id
user_name = (
target_user_info.get("person_name")
or target_user_info.get("user_nickname")
or user_id
target_user_info.get("person_name") or target_user_info.get("user_nickname") or user_id
)
formatted_messages, _ = build_readable_messages_with_id(
user_messages, timestamp_mode="relative"
@@ -130,3 +118,63 @@ async def build_cross_context_s4u(
return ""
return "# 跨上下文参考\n" + "\n\n".join(cross_context_messages) + "\n"
async def get_chat_history_by_group_name(group_name: str) -> str:
"""
根据互通组名字获取聊天记录
"""
target_group = None
for group in global_config.cross_context.groups:
if group.name == group_name:
target_group = group
break
if not target_group:
return f"找不到名为 {group_name} 的互通组。"
if not target_group.chat_ids:
return f"互通组 {group_name} 中没有配置任何聊天。"
chat_infos = target_group.chat_ids
chat_manager = get_chat_manager()
cross_context_messages = []
for chat_type, chat_raw_id in chat_infos:
is_group = chat_type == "group"
found_stream = None
for stream in chat_manager.streams.values():
if is_group:
if stream.group_info and stream.group_info.group_id == chat_raw_id:
found_stream = stream
break
else: # private
if stream.user_info and stream.user_info.user_id == chat_raw_id and not stream.group_info:
found_stream = stream
break
if not found_stream:
logger.warning(f"在已加载的聊天流中找不到ID为 {chat_raw_id} 的聊天。")
continue
stream_id = found_stream.stream_id
try:
messages = get_raw_msg_before_timestamp_with_chat(
chat_id=stream_id,
timestamp=time.time(),
limit=5, # 可配置
)
if messages:
chat_name = get_chat_manager().get_stream_name(stream_id) or chat_raw_id
formatted_messages, _ = build_readable_messages_with_id(messages, timestamp_mode="relative")
cross_context_messages.append(f'[以下是来自"{chat_name}"的近期消息]\n{formatted_messages}')
except Exception as e:
logger.error(f"获取聊天 {chat_raw_id} 的消息失败: {e}")
continue
if not cross_context_messages:
return f"无法从互通组 {group_name} 中获取任何聊天记录。"
return "# 跨上下文参考\n" + "\n\n".join(cross_context_messages) + "\n"

View File

@@ -12,7 +12,6 @@ import traceback
from typing import Tuple, Any, Dict, List, Optional
from rich.traceback import install
from src.common.logger import get_logger
from src.config.api_ada_configs import TaskConfig
from src.chat.replyer.default_generator import DefaultReplyer
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.utils.utils import process_llm_response
@@ -21,6 +20,7 @@ from src.plugin_system.base.component_types import ActionInfo
install(extra_lines=3)
# 日志记录器
logger = get_logger("generator_api")
@@ -107,15 +107,14 @@ async def generate_reply(
"""
try:
# 获取回复器
replyer = get_replyer(
chat_stream, chat_id, request_type=request_type
)
replyer = get_replyer(chat_stream, chat_id, request_type=request_type)
if not replyer:
logger.error("[GeneratorAPI] 无法获取回复器")
return False, [], None
logger.debug("[GeneratorAPI] 开始生成回复")
# 向下兼容从action_data中获取reply_to和extra_info
if not reply_to and action_data:
reply_to = action_data.get("reply_to", "")
if not extra_info and action_data:
@@ -136,6 +135,7 @@ async def generate_reply(
return False, [], None
assert llm_response_dict is not None, "llm_response_dict不应为None" # 虽然说不会出现llm_response为空的情况
if content := llm_response_dict.get("content", ""):
# 处理为拟人化文本
reply_set = process_human_text(content, enable_splitter, enable_chinese_typo)
else:
reply_set = []
@@ -211,6 +211,7 @@ async def rewrite_reply(
)
reply_set = []
if content:
# 处理为拟人化文本
reply_set = process_human_text(content, enable_splitter, enable_chinese_typo)
if success:
@@ -236,9 +237,12 @@ def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo:
enable_splitter: 是否启用消息分割器
enable_chinese_typo: 是否启用错字生成器
"""
if isinstance(content, list):
content = "".join(map(str, content))
if not isinstance(content, str):
raise ValueError("content 必须是字符串类型")
try:
# 处理LLM响应
processed_response = process_llm_response(content, enable_splitter, enable_chinese_typo)
reply_set = []
@@ -259,6 +263,18 @@ async def generate_response_custom(
request_type: str = "generator_api",
prompt: str = "",
) -> Optional[str]:
"""
使用自定义提示生成回复
Args:
chat_stream: 聊天流对象
chat_id: 聊天ID
request_type: 请求类型
prompt: 自定义提示
Returns:
Optional[str]: 生成的回复内容
"""
replyer = get_replyer(chat_stream, chat_id, request_type=request_type)
if not replyer:
logger.error("[GeneratorAPI] 无法获取回复器")

View File

@@ -30,7 +30,6 @@
import traceback
import time
import difflib
import asyncio
from typing import Optional, Union, Dict, Any
from src.common.logger import get_logger
@@ -41,16 +40,16 @@ from maim_message import UserInfo
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.message_receive.uni_message_sender import HeartFCSender
from src.chat.message_receive.message import MessageSending, MessageRecv
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, replace_user_references_async
from src.person_info.person_info import get_person_info_manager
from maim_message import Seg
from src.config.config import global_config
# 日志记录器
logger = get_logger("send_api")
# 适配器命令响应等待池
_adapter_response_pool: Dict[str, asyncio.Future] = {}
def message_dict_to_message_recv(message_dict: Dict[str, Any]) -> Optional[MessageRecv]:
"""查找要回复的消息
@@ -101,6 +100,7 @@ def message_dict_to_message_recv(message_dict: Dict[str, Any]) -> Optional[Messa
logger.info(f"[SendAPI] 找到匹配的回复消息,发送者: {message_dict.get('user_nickname', '')}")
return message_recv
def put_adapter_response(request_id: str, response_data: dict) -> None:
"""将适配器响应放入响应池"""
if request_id in _adapter_response_pool:
@@ -187,6 +187,7 @@ async def _send_to_target(
# 创建消息段
message_segment = Seg(type=message_type, data=content) # type: ignore
# 处理回复消息
if reply_to_message:
anchor_message = message_dict_to_message_recv(message_dict=reply_to_message)
anchor_message.update_chat_stream(target_stream)
@@ -195,6 +196,7 @@ async def _send_to_target(
)
else:
anchor_message = None
reply_to_platform_id = None
# 构建发送消息对象
bot_message = MessageSending(
@@ -233,7 +235,6 @@ async def _send_to_target(
return False
# =============================================================================
# 公共API函数 - 预定义类型的发送函数
# =============================================================================
@@ -273,7 +274,9 @@ async def text_to_stream(
)
async def emoji_to_stream(emoji_base64: str, stream_id: str, storage_message: bool = True, set_reply: bool = False) -> bool:
async def emoji_to_stream(
emoji_base64: str, stream_id: str, storage_message: bool = True, set_reply: bool = False
) -> bool:
"""向指定流发送表情包
Args:
@@ -284,10 +287,14 @@ async def emoji_to_stream(emoji_base64: str, stream_id: str, storage_message: bo
Returns:
bool: 是否发送成功
"""
return await _send_to_target("emoji", emoji_base64, stream_id, "", typing=False, storage_message=storage_message, set_reply=set_reply)
return await _send_to_target(
"emoji", emoji_base64, stream_id, "", typing=False, storage_message=storage_message, set_reply=set_reply
)
async def image_to_stream(image_base64: str, stream_id: str, storage_message: bool = True, set_reply: bool = False) -> bool:
async def image_to_stream(
image_base64: str, stream_id: str, storage_message: bool = True, set_reply: bool = False
) -> bool:
"""向指定流发送图片
Args:
@@ -298,11 +305,17 @@ async def image_to_stream(image_base64: str, stream_id: str, storage_message: bo
Returns:
bool: 是否发送成功
"""
return await _send_to_target("image", image_base64, stream_id, "", typing=False, storage_message=storage_message, set_reply=set_reply)
return await _send_to_target(
"image", image_base64, stream_id, "", typing=False, storage_message=storage_message, set_reply=set_reply
)
async def command_to_stream(
command: Union[str, dict], stream_id: str, storage_message: bool = True, display_message: str = "", set_reply: bool = False
command: Union[str, dict],
stream_id: str,
storage_message: bool = True,
display_message: str = "",
set_reply: bool = False,
) -> bool:
"""向指定流发送命令

View File

@@ -93,6 +93,7 @@ class BaseAction(ABC):
self.associated_types: list[str] = getattr(self.__class__, "associated_types", []).copy()
self.chat_type_allow: ChatType = getattr(self.__class__, "chat_type_allow", ChatType.ALL)
# =============================================================================
# 便捷属性 - 直接在初始化时获取常用聊天信息(带类型注解)
# =============================================================================

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Tuple, Optional, Dict, List, Union
from typing import Tuple, Optional, List, Union
from src.common.logger import get_logger
from .component_types import EventType, EventHandlerInfo, ComponentType
@@ -23,17 +23,26 @@ class BaseEventHandler(ABC):
"""是否拦截消息,默认为否"""
init_subscribe: List[Union[EventType, str]] = [EventType.UNKNOWN]
"""初始化时订阅的事件名称"""
plugin_name = None
def __init__(self):
self.log_prefix = "[EventHandler]"
"""对应插件名"""
self.plugin_config: Optional[Dict] = None
"""插件配置字典"""
self.subscribed_events = []
"""订阅的事件列表"""
if EventType.UNKNOWN in self.init_subscribe:
raise NotImplementedError("事件处理器必须指定 event_type")
# 优先使用实例级别的 plugin_config如果没有则使用类级别的配置
# 事件管理器会在注册时通过 set_plugin_config 设置实例级别的配置
instance_config = getattr(self, "plugin_config", None)
if instance_config is not None:
self.plugin_config = instance_config
else:
# 如果实例级别没有配置,则使用类级别的配置(向后兼容)
self.plugin_config = getattr(self.__class__, "plugin_config", {})
@abstractmethod
async def execute(self, kwargs: dict | None) -> Tuple[bool, bool, Optional[str]]:
"""执行事件处理的抽象方法,子类必须实现
@@ -90,14 +99,6 @@ class BaseEventHandler(ABC):
intercept_message=cls.intercept_message,
)
def set_plugin_config(self, plugin_config: Dict) -> None:
"""设置插件配置
Args:
plugin_config (dict): 插件配置字典
"""
self.plugin_config = plugin_config
def set_plugin_name(self, plugin_name: str) -> None:
"""设置插件名称
@@ -106,6 +107,9 @@ class BaseEventHandler(ABC):
"""
self.plugin_name = plugin_name
def set_plugin_config(self,plugin_config) -> None:
self.plugin_config = plugin_config
def get_config(self, key: str, default=None):
"""获取插件配置值,支持嵌套键访问

View File

@@ -69,6 +69,7 @@ class EventType(Enum):
"""
ON_START = "on_start" # 启动事件,用于调用按时任务
ON_STOP ="on_stop"
ON_MESSAGE = "on_message"
ON_PLAN = "on_plan"
POST_LLM = "post_llm"
@@ -215,27 +216,7 @@ class EventInfo(ComponentInfo):
def __post_init__(self):
super().__post_init__()
self.component_type = ComponentType.EVENT
# 事件类型枚举
class EventType(Enum):
"""
事件类型枚举类
"""
ON_START = "on_start" # 启动事件,用于调用按时任务
ON_STOP = "on_stop" # 停止事件,用于调用按时任务
ON_MESSAGE = "on_message"
ON_PLAN = "on_plan"
POST_LLM = "post_llm"
AFTER_LLM = "after_llm"
POST_SEND = "post_send"
AFTER_SEND = "after_send"
UNKNOWN = "unknown" # 未知事件类型
def __str__(self) -> str:
return self.value
self.component_type = ComponentType.EVENT_HANDLER
@dataclass

View File

@@ -1,3 +1,4 @@
from pathlib import Path
import re
from typing import Dict, List, Optional, Any, Pattern, Tuple, Union, Type
@@ -170,6 +171,8 @@ class ComponentRegistry:
return False
action_class.plugin_name = action_info.plugin_name
# 设置插件配置
action_class.plugin_config = self.get_plugin_config(action_info.plugin_name) or {}
self._action_registry[action_name] = action_class
# 如果启用,添加到默认动作集
@@ -188,6 +191,8 @@ class ComponentRegistry:
return False
command_class.plugin_name = command_info.plugin_name
# 设置插件配置
command_class.plugin_config = self.get_plugin_config(command_info.plugin_name) or {}
self._command_registry[command_name] = command_class
# 如果启用了且有匹配模式
@@ -220,6 +225,8 @@ class ComponentRegistry:
self._plus_command_registry: Dict[str, Type[PlusCommand]] = {}
plus_command_class.plugin_name = plus_command_info.plugin_name
# 设置插件配置
plus_command_class.plugin_config = self.get_plugin_config(plus_command_info.plugin_name) or {}
self._plus_command_registry[plus_command_name] = plus_command_class
logger.debug(f"已注册PlusCommand组件: {plus_command_name}")
@@ -230,6 +237,8 @@ class ComponentRegistry:
tool_name = tool_info.name
tool_class.plugin_name = tool_info.plugin_name
# 设置插件配置
tool_class.plugin_config = self.get_plugin_config(tool_info.plugin_name) or {}
self._tool_registry[tool_name] = tool_class
# 如果是llm可用的且启用的工具,添加到 llm可用工具列表
@@ -248,6 +257,9 @@ class ComponentRegistry:
logger.error(f"注册失败: {handler_name} 不是有效的EventHandler")
return False
handler_class.plugin_name = handler_info.plugin_name
# 设置插件配置
handler_class.plugin_config = self.get_plugin_config(handler_info.plugin_name) or {}
self._event_handler_registry[handler_name] = handler_class
if not handler_info.enabled:
@@ -258,7 +270,7 @@ class ComponentRegistry:
# 使用EventManager进行事件处理器注册
from src.plugin_system.core.event_manager import event_manager
return event_manager.register_event_handler(handler_class)
return event_manager.register_event_handler(handler_class,self.get_plugin_config(handler_info.plugin_name) or {})
# === 组件移除相关 ===
@@ -655,20 +667,35 @@ class ComponentRegistry:
plugin_info = self.get_plugin_info(plugin_name)
return plugin_info.components if plugin_info else []
def get_plugin_config(self, plugin_name: str) -> Optional[dict]:
def get_plugin_config(self, plugin_name: str) -> dict:
"""获取插件配置
Args:
plugin_name: 插件名称
Returns:
Optional[dict]: 插件配置字典或None
dict: 插件配置字典,如果插件实例不存在或配置为空,返回空字典
"""
# 从插件管理器获取插件实例的配置
from src.plugin_system.core.plugin_manager import plugin_manager
plugin_instance = plugin_manager.get_plugin_instance(plugin_name)
return plugin_instance.config if plugin_instance else None
if plugin_instance and plugin_instance.config:
return plugin_instance.config
# 如果插件实例不存在,尝试从配置文件读取
try:
import toml
config_path = Path("config") / "plugins" / plugin_name / "config.toml"
if config_path.exists():
with open(config_path, 'r', encoding='utf-8') as f:
config_data = toml.load(f)
logger.debug(f"从配置文件读取插件 {plugin_name} 的配置")
return config_data
except Exception as e:
logger.debug(f"读取插件 {plugin_name} 配置文件失败: {e}")
return {}
def get_registry_stats(self) -> Dict[str, Any]:
"""获取注册中心统计信息"""

View File

@@ -145,11 +145,12 @@ class EventManager:
logger.info(f"事件 {event_name} 已禁用")
return True
def register_event_handler(self, handler_class: Type[BaseEventHandler]) -> bool:
def register_event_handler(self, handler_class: Type[BaseEventHandler], plugin_config: Optional[dict] = None) -> bool:
"""注册事件处理器
Args:
handler_class (Type[BaseEventHandler]): 事件处理器类
plugin_config (Optional[dict]): 插件配置字典默认为None
Returns:
bool: 注册成功返回True已存在返回False
@@ -163,7 +164,13 @@ class EventManager:
logger.warning(f"事件处理器 {handler_name} 已存在,跳过注册")
return False
self._event_handlers[handler_name] = handler_class()
# 创建事件处理器实例,传递插件配置
handler_instance = handler_class()
handler_instance.plugin_config = plugin_config
if plugin_config is not None and hasattr(handler_instance, 'set_plugin_config'):
handler_instance.set_plugin_config(plugin_config)
self._event_handlers[handler_name] = handler_instance
# 处理init_subscribe缓存失败的订阅
if self._event_handlers[handler_name].init_subscribe:

View File

@@ -200,7 +200,7 @@ class PluginManager:
# 检查并调用 on_plugin_loaded 钩子(如果存在)
if hasattr(plugin_instance, "on_plugin_loaded") and callable(
getattr(plugin_instance, "on_plugin_loaded")
plugin_instance.on_plugin_loaded
):
logger.debug(f"为插件 '{plugin_name}' 调用 on_plugin_loaded 钩子")
try:

View File

@@ -6,7 +6,7 @@ from src.plugin_system.core.global_announcement_manager import global_announceme
from src.llm_models.utils_model import LLMRequest
from src.llm_models.payload_content import ToolCall
from src.config.config import global_config, model_config
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.prompt import Prompt, global_prompt_manager
import inspect
from src.chat.message_receive.chat_stream import get_chat_manager
from src.common.logger import get_logger

View File

@@ -59,19 +59,79 @@ class AtAction(BaseAction):
if not user_info or not user_info.get("user_id"):
logger.info(f"找不到名为 '{user_name}' 的用户。")
return False, "用户不存在"
try:
# 使用回复器生成艾特回复,而不是直接发送命令
from src.chat.replyer.default_generator import DefaultReplyer
from src.chat.message_receive.chat_stream import get_chat_manager
# 获取当前聊天流
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(self.chat_id)
if not chat_stream:
logger.error(f"找不到聊天流: {self.stream_id}")
return False, "聊天流不存在"
# 创建回复器实例
replyer = DefaultReplyer(chat_stream)
# 构建回复对象,将艾特消息作为回复目标
reply_to = f"{user_name}:{at_message}"
extra_info = f"你需要艾特用户 {user_name} 并回复他们说: {at_message}"
from src.plugin_system.core.event_manager import event_manager
from src.plugin_system import EventType
# 触发post_llm
result = await event_manager.trigger_event(EventType.POST_LLM,plugin_name="SYSTEM")
if not result.all_continue_process():
return False, f"被组件{result.get_summary().get("stopped_handlers","")}打断"
# 使用回复器生成回复
success, llm_response, prompt = await replyer.generate_reply_with_context(
reply_to=reply_to,
extra_info=extra_info,
enable_tool=False, # 艾特回复通常不需要工具调用
from_plugin=True # 标识来自插件
)
if success and llm_response:
# 获取生成的回复内容
reply_content = llm_response.get("content", "")
if reply_content:
# 获取用户QQ号发送真正的艾特消息
user_id = user_info.get("user_id")
# 发送真正的艾特命令,使用回复器生成的智能内容
await self.send_command(
"SEND_AT_MESSAGE",
args={"qq_id": user_info.get("user_id"), "text": at_message},
display_message=f"艾特用户 {user_name} 并发送消息: {at_message}",
args={"qq_id": user_id, "text": reply_content},
display_message=f"艾特用户 {user_name} 并发送智能回复: {reply_content}",
)
await self.store_action_info(
action_build_into_prompt=True,
action_prompt_display=f"执行了艾特用户动作:艾特用户 {user_name} 并发送消息: {at_message}",
action_prompt_display=f"执行了艾特用户动作:艾特用户 {user_name} 并发送智能回复: {reply_content}",
action_done=True,
)
logger.info("艾特用户的动作已触发,但具体实现待完成。")
return True, "艾特用户的动作已触发,但具体实现待完成。"
logger.info(f"成功通过回复器生成智能内容并发送真正的艾特消息给 {user_name}: {reply_content}")
return True, "智能艾特消息发送成功"
else:
logger.warning("回复器生成了空内容")
return False, "回复内容为空"
else:
logger.error("回复器生成回复失败")
return False, "回复生成失败"
except Exception as e:
logger.error(f"执行艾特用户动作时发生异常: {e}", exc_info=True)
await self.store_action_info(
action_build_into_prompt=True,
action_prompt_display=f"执行艾特用户动作失败:{str(e)}",
action_done=False,
)
return False, f"执行失败: {str(e)}"
class AtCommand(BaseCommand):

View File

@@ -39,8 +39,9 @@ class EmojiAction(BaseAction):
llm_judge_prompt = """
判定是否需要使用表情动作的条件:
1. 用户明确要求使用表情包
2. 这是一个适合表达强烈情绪的场合
3. 不要发送太多表情包,如果你已经发送过多个表情包则回答""
2. 这是一个适合表达情绪的场合
3. 发表情包能使当前对话更有趣
4. 不要发送太多表情包,如果你已经发送过多个表情包则回答""
请回答""""
"""

View File

@@ -53,7 +53,9 @@ class MaiZoneRefactoredPlugin(BasePlugin):
"enable_reply": ConfigField(type=bool, default=True, description="完成后是否回复"),
"ai_image_number": ConfigField(type=int, default=1, description="AI生成图片数量"),
"image_number": ConfigField(type=int, default=1, description="本地配图数量1-9张"),
"image_directory": ConfigField(type=str, default=(Path(__file__).parent / "images").as_posix(), description="图片存储目录")
"image_directory": ConfigField(
type=str, default=(Path(__file__).parent / "images").as_posix(), description="图片存储目录"
),
},
"read": {
"permission": ConfigField(type=list, default=[], description="阅读权限QQ号列表"),
@@ -75,7 +77,9 @@ class MaiZoneRefactoredPlugin(BasePlugin):
"forbidden_hours_end": ConfigField(type=int, default=6, description="禁止发送的结束小时(24小时制)"),
},
"cookie": {
"http_fallback_host": ConfigField(type=str, default="127.0.0.1", description="备用Cookie获取服务的主机地址"),
"http_fallback_host": ConfigField(
type=str, default="127.0.0.1", description="备用Cookie获取服务的主机地址"
),
"http_fallback_port": ConfigField(type=int, default=9999, description="备用Cookie获取服务的端口"),
"napcat_token": ConfigField(type=str, default="", description="Napcat服务的认证Token可选"),
},
@@ -102,7 +106,7 @@ class MaiZoneRefactoredPlugin(BasePlugin):
content_service,
image_service,
cookie_service,
reply_tracker_service # 传入已创建的实例
reply_tracker_service, # 传入已创建的实例
)
scheduler_service = SchedulerService(self.get_config, qzone_service)
monitor_service = MonitorService(self.get_config, qzone_service)

View File

@@ -13,6 +13,7 @@ from src.common.logger import get_logger
import imghdr
import asyncio
from src.plugin_system.apis import llm_api, config_api, generator_api
from src.plugin_system.apis.cross_context_api import get_chat_history_by_group_name
from src.chat.message_receive.chat_stream import get_chat_manager
from maim_message import UserInfo
from src.llm_models.utils_model import LLMRequest
@@ -87,6 +88,11 @@ class ContentService:
if context:
prompt += f"\n作为参考,这里有一些最近的聊天记录:\n---\n{context}\n---"
# 添加跨群聊上下文
cross_context = await get_chat_history_by_group_name("maizone_context_group")
if cross_context and "找不到名为" not in cross_context:
prompt += f"\n\n---跨群聊参考---\n{cross_context}\n---"
# 添加历史记录以避免重复
prompt += "\n\n---历史说说记录---\n"
history_block = await get_send_history(qq_account)
@@ -232,7 +238,7 @@ class ContentService:
for i in range(3): # 重试3次
try:
async with aiohttp.ClientSession() as session:
async with session.get(image_url, timeout=30) as resp:
async with session.get(image_url, timeout=aiohttp.ClientTimeout(total=30)) as resp:
if resp.status != 200:
logger.error(f"下载图片失败: {image_url}, status: {resp.status}")
await asyncio.sleep(2)

View File

@@ -272,8 +272,10 @@ class QZoneService:
# 检查是否已经在持久化记录中标记为已回复
if not self.reply_tracker.has_replied(fid, comment_tid):
# 记录日志以便追踪
logger.debug(f"发现新评论需要回复 - 说说ID: {fid}, 评论ID: {comment_tid}, "
f"评论人: {comment.get('nickname', '')}, 内容: {comment.get('content', '')}")
logger.debug(
f"发现新评论需要回复 - 说说ID: {fid}, 评论ID: {comment_tid}, "
f"评论人: {comment.get('nickname', '')}, 内容: {comment.get('content', '')}"
)
comments_to_reply.append(comment)
if not comments_to_reply:
@@ -791,10 +793,11 @@ class QZoneService:
try:
# 修复回复逻辑:确保能正确提醒被回复的人
data = {
"topicId": f"{host_qq}_{fid}__1", # 使用标准评论格式,而不是针对特定评论
"topicId": f"{host_qq}_{fid}__1",
"parent_tid": comment_tid,
"uin": uin,
"hostUin": host_qq,
"content": f"回复@{target_name}{content}", # 内容中明确标示回复对象
"content": content,
"format": "fs",
"plat": "qzone",
"source": "ic",
@@ -802,11 +805,13 @@ class QZoneService:
"ref": "feeds",
"richtype": "",
"richval": "",
"paramstr": f"@{target_name}", # 确保触发@提醒机制
"paramstr": "",
}
# 记录详细的请求参数用于调试
logger.info(f"子回复请求参数: topicId={data['topicId']}, parent_tid={data['parent_tid']}, content='{content[:50]}...'")
logger.info(
f"子回复请求参数: topicId={data['topicId']}, parent_tid={data['parent_tid']}, content='{content[:50]}...'"
)
await _request("POST", self.REPLY_URL, params={"g_tk": gtk}, data=data)
return True

View File

@@ -74,8 +74,10 @@ class ReplyTrackerService:
data = json.loads(file_content)
if self._validate_data(data):
self.replied_comments = data
logger.info(f"已加载 {len(self.replied_comments)} 条说说的回复记录,"
f"总计 {sum(len(comments) for comments in self.replied_comments.values())}评论")
logger.info(
f"已加载 {len(self.replied_comments)}说说的回复记录,"
f"总计 {sum(len(comments) for comments in self.replied_comments.values())} 条评论"
)
else:
logger.error("加载的数据格式无效,将创建新的记录")
self.replied_comments = {}
@@ -112,7 +114,7 @@ class ReplyTrackerService:
self._cleanup_old_records()
# 创建临时文件
temp_file = self.reply_record_file.with_suffix('.tmp')
temp_file = self.reply_record_file.with_suffix(".tmp")
# 先写入临时文件
with open(temp_file, "w", encoding="utf-8") as f:

View File

@@ -1746,3 +1746,32 @@ class SetGroupSignHandler(BaseEventHandler):
else:
logger.error("事件 napcat_set_group_sign 请求失败!")
return HandlerResult(False, False, {"status": "error"})
# ===PERSONAL===
class SetInputStatusHandler(BaseEventHandler):
handler_name: str = "napcat_set_input_status_handler"
handler_description: str = "设置输入状态"
weight: int = 100
intercept_message: bool = False
init_subscribe = [NapcatEvent.PERSONAL.SET_INPUT_STATUS]
async def execute(self, params: dict):
raw = params.get("raw", {})
user_id = params.get("user_id", "")
event_type = params.get("event_type", 0)
if params.get("raw", ""):
user_id = raw.get("user_id", "")
event_type = raw.get("event_type", 0)
if not user_id or event_type is None:
logger.error("事件 napcat_set_input_status 缺少必要参数: user_id 或 event_type")
return HandlerResult(False, False, {"status": "error"})
payload = {"user_id": str(user_id), "event_type": int(event_type)}
response = await send_handler.send_message_to_napcat(action="set_input_status", params=payload)
if response.get("status", "") == "ok":
return HandlerResult(True, True, response)
else:
logger.error("事件 napcat_set_input_status 请求失败!")
return HandlerResult(False, False, {"status": "error"})

View File

@@ -1816,3 +1816,27 @@ class NapcatEvent:
"""
class FILE(Enum): ...
class PERSONAL(Enum):
SET_INPUT_STATUS = "napcat_set_input_status"
"""
设置输入状态
Args:
user_id (Optional[str|int]): 用户id(必需)
event_type (Optional[int]): 输入状态id(必需)
raw (Optional[dict]): 原始请求体
Returns:
dict: {
"status": "ok",
"retcode": 0,
"data": {
"result": 0,
"errMsg": "string"
},
"message": "string",
"wording": "string",
"echo": "string"
}
"""

View File

@@ -7,8 +7,8 @@ from . import event_types, CONSTS, event_handlers
from typing import List
from src.plugin_system import BasePlugin, BaseEventHandler, register_plugin, EventType, ConfigField
from src.plugin_system.base.base_event import HandlerResult
from src.plugin_system.core.event_manager import event_manager
from src.plugin_system.apis import config_api
from src.common.logger import get_logger
@@ -18,9 +18,6 @@ from .src.recv_handler.meta_event_handler import meta_event_handler
from .src.recv_handler.notice_handler import notice_handler
from .src.recv_handler.message_sending import message_send_instance
from .src.send_handler import send_handler
from .src.config import global_config
from .src.config.features_config import features_manager
from .src.config.migrate_features import auto_migrate_features
from .src.mmc_com_layer import mmc_start_com, router, mmc_stop_com
from .src.response_pool import put_response, check_timeout_response
from .src.websocket_manager import websocket_manager
@@ -37,11 +34,14 @@ def get_classes_in_module(module):
classes.append(member)
return classes
async def message_recv(server_connection: Server.ServerConnection):
await message_handler.set_server_connection(server_connection)
asyncio.create_task(notice_handler.set_server_connection(server_connection))
await send_handler.set_server_connection(server_connection)
async for raw_message in server_connection:
# 只在debug模式下记录原始消息
if logger.level <= 10: # DEBUG level
logger.debug(f"{raw_message[:1500]}..." if (len(raw_message) > 1500) else raw_message)
decoded_raw_message: dict = json.loads(raw_message)
try:
@@ -76,6 +76,7 @@ async def message_recv(server_connection: Server.ServerConnection):
logger.error(f"处理消息时出错: {e}")
logger.debug(f"原始消息: {raw_message[:500]}...")
async def message_process():
"""消息处理主循环"""
logger.info("消息处理器已启动")
@@ -132,17 +133,20 @@ async def message_process():
except Exception as e:
logger.debug(f"清理消息队列时出错: {e}")
async def napcat_server():
async def napcat_server(plugin_config: dict):
"""启动 Napcat WebSocket 连接(支持正向和反向连接)"""
mode = global_config.napcat_server.mode
# 使用插件系统配置API获取配置
mode = config_api.get_plugin_config(plugin_config, "napcat_server.mode")
logger.info(f"正在启动 adapter连接模式: {mode}")
try:
await websocket_manager.start_connection(message_recv)
await websocket_manager.start_connection(message_recv, plugin_config)
except Exception as e:
logger.error(f"启动 WebSocket 连接失败: {e}")
raise
async def graceful_shutdown():
"""优雅关闭所有组件"""
try:
@@ -154,11 +158,7 @@ async def graceful_shutdown():
except Exception as e:
logger.warning(f"停止消息重组器清理任务时出错: {e}")
# 停止功能管理器文件监控
try:
await features_manager.stop_file_watcher()
except Exception as e:
logger.warning(f"停止功能管理器文件监控时出错: {e}")
# 停止功能管理器文件监控(已迁移到插件系统配置,无需操作)
# 关闭消息处理器(包括消息缓冲器)
try:
@@ -189,10 +189,7 @@ async def graceful_shutdown():
# 等待任务取消完成,忽略 CancelledError
try:
await asyncio.wait_for(
asyncio.gather(*tasks, return_exceptions=True),
timeout=10
)
await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=10)
except asyncio.TimeoutError:
logger.warning("部分任务取消超时")
except Exception as e:
@@ -214,6 +211,7 @@ async def graceful_shutdown():
except Exception:
pass
class LauchNapcatAdapterHandler(BaseEventHandler):
"""自动启动Adapter"""
@@ -224,27 +222,44 @@ class LauchNapcatAdapterHandler(BaseEventHandler):
init_subscribe = [EventType.ON_START]
async def execute(self, kwargs):
# 执行功能配置迁移(如果需要)
logger.info("检查功能配置迁移...")
auto_migrate_features()
# 启动消息重组器的清理任务
logger.info("启动消息重组器...")
await reassembler.start_cleanup_task()
# 初始化功能管理器
logger.info("正在初始化功能管理器...")
features_manager.load_config()
await features_manager.start_file_watcher(check_interval=2.0)
logger.info("功能管理器初始化完成")
logger.info("开始启动Napcat Adapter")
message_send_instance.maibot_router = router
# 创建单独的异步任务,防止阻塞主线程
asyncio.create_task(napcat_server())
asyncio.create_task(mmc_start_com())
asyncio.create_task(self._start_maibot_connection())
asyncio.create_task(napcat_server(self.plugin_config))
asyncio.create_task(message_process())
asyncio.create_task(check_timeout_response())
async def _start_maibot_connection(self):
"""非阻塞方式启动MaiBot连接等待主服务启动后再连接"""
# 等待一段时间让MaiBot主服务完全启动
await asyncio.sleep(5)
max_attempts = 10
attempt = 0
while attempt < max_attempts:
try:
logger.info(f"尝试连接MaiBot (第{attempt + 1}次)")
await mmc_start_com(self.plugin_config)
message_send_instance.maibot_router = router
logger.info("MaiBot router连接已建立")
return
except Exception as e:
attempt += 1
if attempt >= max_attempts:
logger.error(f"MaiBot连接失败已达到最大重试次数: {e}")
return
else:
delay = min(2 + attempt, 10) # 逐渐增加延迟最大10秒
logger.warning(f"MaiBot连接失败: {e}{delay}秒后重试")
await asyncio.sleep(delay)
class StopNapcatAdapterHandler(BaseEventHandler):
"""关闭Adapter"""
@@ -262,11 +277,19 @@ class StopNapcatAdapterHandler(BaseEventHandler):
@register_plugin
class NapcatAdapterPlugin(BasePlugin):
plugin_name = CONSTS.PLUGIN_NAME
enable_plugin: bool = True
dependencies: List[str] = [] # 插件依赖列表
python_dependencies: List[str] = [] # Python包依赖列表
config_file_name: str = "config.toml" # 配置文件名
@property
def enable_plugin(self) -> bool:
"""通过配置文件动态控制插件启用状态"""
# 如果已经通过配置加载了状态,使用配置中的值
if hasattr(self, '_is_enabled'):
return self._is_enabled
# 否则使用默认值(禁用状态)
return False
# 配置节描述
config_section_descriptions = {"plugin": "插件基本信息"}
@@ -275,10 +298,84 @@ class NapcatAdapterPlugin(BasePlugin):
"plugin": {
"name": ConfigField(type=str, default="napcat_adapter_plugin", description="插件名称"),
"version": ConfigField(type=str, default="1.0.0", description="插件版本"),
"config_version": ConfigField(type=str, default="1.3.0", description="配置文件版本"),
"enabled": ConfigField(type=bool, default=False, description="是否启用插件"),
},
"inner": {
"version": ConfigField(type=str, default="0.2.1", description="配置版本号,请勿修改"),
},
"nickname": {
"nickname": ConfigField(type=str, default="", description="昵称配置(目前未使用)"),
},
"napcat_server": {
"mode": ConfigField(type=str, default="reverse", description="连接模式reverse=反向连接(作为服务器), forward=正向连接(作为客户端)", choices=["reverse", "forward"]),
"host": ConfigField(type=str, default="localhost", description="主机地址"),
"port": ConfigField(type=int, default=8095, description="端口号"),
"url": ConfigField(type=str, default="", description="正向连接时的完整WebSocket URL如 ws://localhost:8080/ws (仅在forward模式下使用)"),
"access_token": ConfigField(type=str, default="", description="WebSocket 连接的访问令牌,用于身份验证(可选)"),
"heartbeat_interval": ConfigField(type=int, default=30, description="心跳间隔时间(按秒计)"),
},
"maibot_server": {
"host": ConfigField(type=str, default="localhost", description="麦麦在.env文件中设置的主机地址即HOST字段"),
"port": ConfigField(type=int, default=8000, description="麦麦在.env文件中设置的端口即PORT字段"),
"platform_name": ConfigField(type=str, default="napcat", description="平台名称,用于消息路由"),
},
"voice": {
"use_tts": ConfigField(type=bool, default=False, description="是否使用tts语音请确保你配置了tts并有对应的adapter"),
},
"slicing": {
"max_frame_size": ConfigField(type=int, default=64, description="WebSocket帧的最大大小单位为字节默认64KB"),
"delay_ms": ConfigField(type=int, default=10, description="切片发送间隔时间,单位为毫秒"),
},
"debug": {
"level": ConfigField(type=str, default="INFO", description="日志等级DEBUG, INFO, WARNING, ERROR, CRITICAL", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]),
},
"features": {
# 权限设置
"group_list_type": ConfigField(type=str, default="blacklist", description="群聊列表类型whitelist白名单或 blacklist黑名单", choices=["whitelist", "blacklist"]),
"group_list": ConfigField(type=list, default=[], description="群聊ID列表"),
"private_list_type": ConfigField(type=str, default="blacklist", description="私聊列表类型whitelist白名单或 blacklist黑名单", choices=["whitelist", "blacklist"]),
"private_list": ConfigField(type=list, default=[], description="用户ID列表"),
"ban_user_id": ConfigField(type=list, default=[], description="全局禁止用户ID列表这些用户无法在任何地方使用机器人"),
"ban_qq_bot": ConfigField(type=bool, default=False, description="是否屏蔽QQ官方机器人消息"),
# 聊天功能设置
"enable_poke": ConfigField(type=bool, default=True, description="是否启用戳一戳功能"),
"ignore_non_self_poke": ConfigField(type=bool, default=False, description="是否无视不是针对自己的戳一戳"),
"poke_debounce_seconds": ConfigField(type=int, default=3, description="戳一戳防抖时间(秒),在指定时间内第二次针对机器人的戳一戳将被忽略"),
"enable_reply_at": ConfigField(type=bool, default=True, description="是否启用引用回复时艾特用户的功能"),
"reply_at_rate": ConfigField(type=float, default=0.5, description="引用回复时艾特用户的几率 (0.0 ~ 1.0)"),
# 视频处理设置
"enable_video_analysis": ConfigField(type=bool, default=True, description="是否启用视频识别功能"),
"max_video_size_mb": ConfigField(type=int, default=100, description="视频文件最大大小限制MB"),
"download_timeout": ConfigField(type=int, default=60, description="视频下载超时时间(秒)"),
"supported_formats": ConfigField(type=list, default=["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"], description="支持的视频格式"),
# 消息缓冲设置
"enable_message_buffer": ConfigField(type=bool, default=True, description="是否启用消息缓冲合并功能"),
"message_buffer_enable_group": ConfigField(type=bool, default=True, description="是否启用群聊消息缓冲合并"),
"message_buffer_enable_private": ConfigField(type=bool, default=True, description="是否启用私聊消息缓冲合并"),
"message_buffer_interval": ConfigField(type=float, default=3.0, description="消息合并间隔时间(秒),在此时间内的连续消息将被合并"),
"message_buffer_initial_delay": ConfigField(type=float, default=0.5, description="消息缓冲初始延迟(秒),收到第一条消息后等待此时间开始合并"),
"message_buffer_max_components": ConfigField(type=int, default=50, description="单个会话最大缓冲消息组件数量,超过此数量将强制合并"),
"message_buffer_block_prefixes": ConfigField(type=list, default=["/", "!", "", ".", "", "#", "%"], description="消息缓冲屏蔽前缀,以这些前缀开头的消息不会被缓冲"),
}
}
# 配置节描述
config_section_descriptions = {
"plugin": "插件基本信息",
"inner": "内部配置信息(请勿修改)",
"nickname": "昵称配置(目前未使用)",
"napcat_server": "Napcat连接的ws服务设置",
"maibot_server": "连接麦麦的ws服务设置",
"voice": "发送语音设置",
"slicing": "WebSocket消息切片设置",
"debug": "调试设置",
"features": "功能设置(权限控制、聊天功能、视频处理、消息缓冲等)"
}
def register_events(self):
# 注册事件
for e in event_types.NapcatEvent.ON_RECEIVED:
@@ -303,3 +400,21 @@ class NapcatAdapterPlugin(BasePlugin):
if issubclass(handler, BaseEventHandler):
components.append((handler.get_handler_info(), handler))
return components
async def on_plugin_loaded(self):
# 设置插件配置
message_send_instance.set_plugin_config(self.config)
# 设置chunker的插件配置
chunker.set_plugin_config(self.config)
# 设置response_pool的插件配置
from .src.response_pool import set_plugin_config as set_response_pool_config
set_response_pool_config(self.config)
# 设置send_handler的插件配置
send_handler.set_plugin_config(self.config)
# 设置message_handler的插件配置
message_handler.set_plugin_config(self.config)
# 设置notice_handler的插件配置
notice_handler.set_plugin_config(self.config)
# 设置meta_event_handler的插件配置
meta_event_handler.set_plugin_config(self.config)
# 设置其他handler的插件配置现在由component_registry在注册时自动设置

View File

@@ -7,7 +7,7 @@ from src.common.logger import get_logger
logger = get_logger("napcat_adapter")
from .config.features_config import features_manager
from src.plugin_system.apis import config_api
from .recv_handler import RealMessageType
@@ -43,6 +43,11 @@ class SimpleMessageBuffer:
self.lock = asyncio.Lock()
self.merge_callback = merge_callback
self._shutdown = False
self.plugin_config = None
def set_plugin_config(self, plugin_config: dict):
"""设置插件配置"""
self.plugin_config = plugin_config
def get_session_id(self, event_data: Dict[str, Any]) -> str:
"""根据事件数据生成会话ID"""
@@ -97,8 +102,7 @@ class SimpleMessageBuffer:
return True
# 检查屏蔽前缀
config = features_manager.get_config()
block_prefixes = tuple(config.message_buffer_block_prefixes)
block_prefixes = tuple(config_api.get_plugin_config(self.plugin_config, "features.message_buffer_block_prefixes", []))
text = text.strip()
if text.startswith(block_prefixes):
@@ -124,15 +128,15 @@ class SimpleMessageBuffer:
if self._shutdown:
return False
config = features_manager.get_config()
if not config.enable_message_buffer:
# 检查是否启用消息缓冲
if not config_api.get_plugin_config(self.plugin_config, "features.enable_message_buffer", False):
return False
# 检查是否启用对应类型的缓冲
message_type = event_data.get("message_type", "")
if message_type == "group" and not config.message_buffer_enable_group:
if message_type == "group" and not config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enable_group", False):
return False
elif message_type == "private" and not config.message_buffer_enable_private:
elif message_type == "private" and not config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enable_private", False):
return False
# 提取文本
@@ -154,8 +158,8 @@ class SimpleMessageBuffer:
session = self.buffer_pool[session_id]
# 检查是否超过最大组件数量
if len(session.messages) >= config.message_buffer_max_components:
logger.info(f"会话 {session_id} 消息数量达到上限,强制合并")
if len(session.messages) >= config_api.get_plugin_config(self.plugin_config, "features.message_buffer_max_components", 5):
logger.debug(f"会话 {session_id} 消息数量达到上限,强制合并")
asyncio.create_task(self._force_merge_session(session_id))
self.buffer_pool[session_id] = BufferedSession(session_id=session_id, original_event=original_event)
session = self.buffer_pool[session_id]
@@ -187,8 +191,8 @@ class SimpleMessageBuffer:
async def _wait_and_start_merge(self, session_id: str):
"""等待初始延迟后开始合并定时器"""
config = features_manager.get_config()
await asyncio.sleep(config.message_buffer_initial_delay)
initial_delay = config_api.get_plugin_config(self.plugin_config, "features.message_buffer_initial_delay", 0.5)
await asyncio.sleep(initial_delay)
async with self.lock:
session = self.buffer_pool.get(session_id)
@@ -206,8 +210,8 @@ class SimpleMessageBuffer:
async def _wait_and_merge(self, session_id: str):
"""等待合并间隔后执行合并"""
config = features_manager.get_config()
await asyncio.sleep(config.message_buffer_interval)
interval = config_api.get_plugin_config(self.plugin_config, "features.message_buffer_interval", 2.0)
await asyncio.sleep(interval)
await self._merge_session(session_id)
async def _force_merge_session(self, session_id: str):
@@ -236,7 +240,7 @@ class SimpleMessageBuffer:
merged_text = "".join(text_parts) # 使用中文逗号连接
message_count = len(session.messages)
logger.info(f"合并会话 {session_id}{message_count} 条文本消息: {merged_text[:100]}...")
logger.debug(f"合并会话 {session_id}{message_count} 条文本消息: {merged_text[:100]}...")
# 调用回调函数
if self.merge_callback:
@@ -290,13 +294,13 @@ class SimpleMessageBuffer:
expired_sessions.append(session_id)
for session_id in expired_sessions:
logger.info(f"清理过期会话: {session_id}")
logger.debug(f"清理过期会话: {session_id}")
await self._force_merge_session(session_id)
async def shutdown(self):
"""关闭消息缓冲器"""
self._shutdown = True
logger.info("正在关闭简化消息缓冲器...")
logger.debug("正在关闭简化消息缓冲器...")
# 刷新所有缓冲区
await self.flush_all()
@@ -307,4 +311,4 @@ class SimpleMessageBuffer:
await self._cancel_session_timers(session)
self.buffer_pool.clear()
logger.info("简化消息缓冲器已关闭")
logger.debug("简化消息缓冲器已关闭")

View File

@@ -3,24 +3,32 @@
用于在 Ada 发送给 MMC 时进行消息切片利用 WebSocket 协议的自动重组特性
仅在 Ada -> MMC 方向进行切片其他方向MMC -> AdaAda <-> Napcat不切片
"""
import json
import uuid
import asyncio
import time
from typing import List, Dict, Any, Optional, Union
from .config import global_config
from src.plugin_system.apis import config_api
from src.common.logger import get_logger
logger = get_logger("napcat_adapter")
class MessageChunker:
"""消息切片器,用于处理大消息的分片发送"""
def __init__(self):
self.max_chunk_size = global_config.slicing.max_frame_size * 1024
self.max_chunk_size = 64 * 1024 # 默认值,将在设置配置时更新
self.plugin_config = None
def set_plugin_config(self, plugin_config: dict):
"""设置插件配置"""
self.plugin_config = plugin_config
if plugin_config:
max_frame_size = config_api.get_plugin_config(plugin_config, "slicing.max_frame_size", 64)
self.max_chunk_size = max_frame_size * 1024
def should_chunk_message(self, message: Union[str, Dict[str, Any]]) -> bool:
"""判断消息是否需要切片"""
@@ -29,12 +37,14 @@ class MessageChunker:
message_str = json.dumps(message, ensure_ascii=False)
else:
message_str = message
return len(message_str.encode('utf-8')) > self.max_chunk_size
return len(message_str.encode("utf-8")) > self.max_chunk_size
except Exception as e:
logger.error(f"检查消息大小时出错: {e}")
return False
def chunk_message(self, message: Union[str, Dict[str, Any]], chunk_id: Optional[str] = None) -> List[Dict[str, Any]]:
def chunk_message(
self, message: Union[str, Dict[str, Any]], chunk_id: Optional[str] = None
) -> List[Dict[str, Any]]:
"""
将消息切片
@@ -62,7 +72,7 @@ class MessageChunker:
if chunk_id is None:
chunk_id = str(uuid.uuid4())
message_bytes = message_str.encode('utf-8')
message_bytes = message_str.encode("utf-8")
total_size = len(message_bytes)
# 计算需要多少个切片
@@ -83,10 +93,10 @@ class MessageChunker:
"total_chunks": num_chunks,
"chunk_size": len(chunk_data),
"total_size": total_size,
"timestamp": time.time()
"timestamp": time.time(),
},
"__mmc_chunk_data__": chunk_data.decode('utf-8', errors='ignore'),
"__mmc_is_chunked__": True
"__mmc_chunk_data__": chunk_data.decode("utf-8", errors="ignore"),
"__mmc_is_chunked__": True,
}
chunks.append(chunk_message)
@@ -111,10 +121,10 @@ class MessageChunker:
data = message
return (
isinstance(data, dict) and
"__mmc_chunk_info__" in data and
"__mmc_chunk_data__" in data and
"__mmc_is_chunked__" in data
isinstance(data, dict)
and "__mmc_chunk_info__" in data
and "__mmc_chunk_data__" in data
and "__mmc_is_chunked__" in data
)
except (json.JSONDecodeError, TypeError):
return False
@@ -152,7 +162,7 @@ class MessageReassembler:
expired_chunks = []
for chunk_id, buffer_info in self.chunk_buffers.items():
if current_time - buffer_info['timestamp'] > self.timeout:
if current_time - buffer_info["timestamp"] > self.timeout:
expired_chunks.append(chunk_id)
for chunk_id in expired_chunks:
@@ -207,7 +217,7 @@ class MessageReassembler:
"chunks": {},
"total_chunks": total_chunks,
"received_chunks": 0,
"timestamp": chunk_timestamp
"timestamp": chunk_timestamp,
}
buffer = self.chunk_buffers[chunk_id]
@@ -260,7 +270,7 @@ class MessageReassembler:
"received": buffer["received_chunks"],
"total": buffer["total_chunks"],
"progress": f"{buffer['received_chunks']}/{buffer['total_chunks']}",
"age_seconds": time.time() - buffer["timestamp"]
"age_seconds": time.time() - buffer["timestamp"],
}
return info

View File

@@ -0,0 +1,44 @@
from maim_message import Router, RouteConfig, TargetConfig
from src.common.logger import get_logger
from .send_handler import send_handler
from src.plugin_system.apis import config_api
logger = get_logger("napcat_adapter")
router = None
def create_router(plugin_config: dict):
"""创建路由器实例"""
global router
platform_name = config_api.get_plugin_config(plugin_config, "maibot_server.platform_name", "napcat")
host = config_api.get_plugin_config(plugin_config, "maibot_server.host", "localhost")
port = config_api.get_plugin_config(plugin_config, "maibot_server.port", 8000)
route_config = RouteConfig(
route_config={
platform_name: TargetConfig(
url=f"ws://{host}:{port}/ws",
token=None,
)
}
)
router = Router(route_config)
return router
async def mmc_start_com(plugin_config: dict = None):
"""启动MaiBot连接"""
logger.info("正在连接MaiBot")
if plugin_config:
create_router(plugin_config)
if router:
router.register_class_handler(send_handler.handle_message)
await router.run()
async def mmc_stop_com():
"""停止MaiBot连接"""
if router:
await router.stop()

View File

@@ -5,8 +5,7 @@ from ...CONSTS import PLUGIN_NAME
logger = get_logger("napcat_adapter")
from ..config import global_config
from ..config.features_config import features_manager
from src.plugin_system.apis import config_api
from ..message_buffer import SimpleMessageBuffer
from ..utils import (
get_group_info,
@@ -48,9 +47,17 @@ class MessageHandler:
def __init__(self):
self.server_connection: Server.ServerConnection = None
self.bot_id_list: Dict[int, bool] = {}
self.plugin_config = None
# 初始化简化消息缓冲器,传入回调函数
self.message_buffer = SimpleMessageBuffer(merge_callback=self._send_buffered_message)
def set_plugin_config(self, plugin_config: dict):
"""设置插件配置"""
self.plugin_config = plugin_config
# 将配置传递给消息缓冲器
if self.message_buffer:
self.message_buffer.set_plugin_config(plugin_config)
async def shutdown(self):
"""关闭消息处理器,清理资源"""
if self.message_buffer:
@@ -90,21 +97,41 @@ class MessageHandler:
# 使用新的权限管理器检查权限
if group_id:
if not features_manager.is_group_allowed(group_id):
logger.warning("群聊不在聊天权限范围内,消息被丢弃")
# 检查群聊黑白名单
group_list_type = config_api.get_plugin_config(self.plugin_config, "features.group_list_type", "blacklist")
group_list = config_api.get_plugin_config(self.plugin_config, "features.group_list", [])
if group_list_type == "whitelist":
if group_id not in group_list:
logger.warning("群聊不在白名单中,消息被丢弃")
return False
else: # blacklist
if group_id in group_list:
logger.warning("群聊在黑名单中,消息被丢弃")
return False
else:
if not features_manager.is_private_allowed(user_id):
logger.warning("私聊不在聊天权限范围内,消息被丢弃")
# 检查私聊黑白名单
private_list_type = config_api.get_plugin_config(self.plugin_config, "features.private_list_type", "blacklist")
private_list = config_api.get_plugin_config(self.plugin_config, "features.private_list", [])
if private_list_type == "whitelist":
if user_id not in private_list:
logger.warning("私聊不在白名单中,消息被丢弃")
return False
else: # blacklist
if user_id in private_list:
logger.warning("私聊在黑名单中,消息被丢弃")
return False
# 检查全局禁止名单
if not ignore_global_list and features_manager.is_user_banned(user_id):
ban_user_id = config_api.get_plugin_config(self.plugin_config, "features.ban_user_id", [])
if not ignore_global_list and user_id in ban_user_id:
logger.warning("用户在全局黑名单中,消息被丢弃")
return False
# 检查QQ官方机器人
if features_manager.is_qq_bot_banned() and group_id and not ignore_bot:
ban_qq_bot = config_api.get_plugin_config(self.plugin_config, "features.ban_qq_bot", False)
if ban_qq_bot and group_id and not ignore_bot:
logger.debug("开始判断是否为机器人")
member_info = await get_member_info(self.get_server_connection(), group_id, user_id)
if member_info:
@@ -129,6 +156,21 @@ class MessageHandler:
Parameters:
raw_message: dict: 原始消息
"""
# 添加原始消息调试日志特别关注message字段
logger.debug(f"收到原始消息: message_type={raw_message.get('message_type')}, message_id={raw_message.get('message_id')}")
logger.debug(f"原始消息内容: {raw_message.get('message', [])}")
# 检查是否包含@或video消息段
message_segments = raw_message.get('message', [])
if message_segments:
for i, seg in enumerate(message_segments):
seg_type = seg.get('type')
if seg_type in ['at', 'video']:
logger.info(f"检测到 {seg_type.upper()} 消息段 [{i}]: {seg}")
elif seg_type not in ['text', 'face', 'image']:
logger.warning(f"检测到特殊消息段 [{i}]: type={seg_type}, data={seg.get('data', {})}")
message_type: str = raw_message.get("message_type")
message_id: int = raw_message.get("message_id")
# message_time: int = raw_message.get("time")
@@ -149,7 +191,7 @@ class MessageHandler:
# 发送者用户信息
user_info: UserInfo = UserInfo(
platform=global_config.maibot_server.platform_name,
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name"),
user_id=sender_info.get("user_id"),
user_nickname=sender_info.get("nickname"),
user_cardname=sender_info.get("card"),
@@ -175,7 +217,7 @@ class MessageHandler:
nickname = fetched_member_info.get("nickname") if fetched_member_info else None
# 发送者用户信息
user_info: UserInfo = UserInfo(
platform=global_config.maibot_server.platform_name,
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name"),
user_id=sender_info.get("user_id"),
user_nickname=nickname,
user_cardname=None,
@@ -192,7 +234,7 @@ class MessageHandler:
group_name = fetched_group_info.get("group_name")
group_info: GroupInfo = GroupInfo(
platform=global_config.maibot_server.platform_name,
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name"),
group_id=raw_message.get("group_id"),
group_name=group_name,
)
@@ -210,7 +252,7 @@ class MessageHandler:
# 发送者用户信息
user_info: UserInfo = UserInfo(
platform=global_config.maibot_server.platform_name,
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name"),
user_id=sender_info.get("user_id"),
user_nickname=sender_info.get("nickname"),
user_cardname=sender_info.get("card"),
@@ -223,7 +265,7 @@ class MessageHandler:
group_name = fetched_group_info.get("group_name")
group_info: GroupInfo = GroupInfo(
platform=global_config.maibot_server.platform_name,
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name"),
group_id=raw_message.get("group_id"),
group_name=group_name,
)
@@ -233,12 +275,12 @@ class MessageHandler:
return None
additional_config: dict = {}
if global_config.voice.use_tts:
if config_api.get_plugin_config(self.plugin_config, "voice.use_tts"):
additional_config["allow_tts"] = True
# 消息信息
message_info: BaseMessageInfo = BaseMessageInfo(
platform=global_config.maibot_server.platform_name,
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name"),
message_id=message_id,
time=message_time,
user_info=user_info,
@@ -260,19 +302,19 @@ class MessageHandler:
return None
# 检查是否需要使用消息缓冲
if features_manager.is_message_buffer_enabled():
enable_message_buffer = config_api.get_plugin_config(self.plugin_config, "features.enable_message_buffer", True)
if enable_message_buffer:
# 检查消息类型是否启用缓冲
message_type = raw_message.get("message_type")
should_use_buffer = False
if message_type == "group" and features_manager.is_message_buffer_group_enabled():
if message_type == "group" and config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enable_group", True):
should_use_buffer = True
elif message_type == "private" and features_manager.is_message_buffer_private_enabled():
elif message_type == "private" and config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enable_private", True):
should_use_buffer = True
if should_use_buffer:
logger.debug(f"尝试缓冲消息,消息类型: {message_type}, 用户: {user_info.user_id}")
logger.debug(f"原始消息段: {raw_message.get('message', [])}")
# 尝试添加到缓冲器
buffered = await self.message_buffer.add_text_message(
@@ -286,10 +328,10 @@ class MessageHandler:
)
if buffered:
logger.info(f"✅ 文本消息已成功缓冲: {user_info.user_id}")
logger.debug(f"✅ 文本消息已成功缓冲: {user_info.user_id}")
return None # 缓冲成功,不立即发送
# 如果缓冲失败(消息包含非文本元素),走正常处理流程
logger.info(f"❌ 消息缓冲失败,包含非文本元素,走正常处理流程: {user_info.user_id}")
logger.debug(f"❌ 消息缓冲失败,包含非文本元素,走正常处理流程: {user_info.user_id}")
# 缓冲失败时继续执行后面的正常处理流程,不要直接返回
logger.debug(f"准备发送消息到MaiBot消息段数量: {len(seg_message)}")
@@ -307,7 +349,7 @@ class MessageHandler:
raw_message=raw_message.get("raw_message"),
)
logger.info("发送到Maibot处理信息")
logger.debug("发送到Maibot处理信息")
await message_send_instance.message_send(message_base)
async def handle_real_message(self, raw_message: dict, in_reply: bool = False) -> List[Seg] | None:
@@ -326,6 +368,18 @@ class MessageHandler:
for sub_message in real_message:
sub_message: dict
sub_message_type = sub_message.get("type")
# 添加详细的消息类型调试信息
logger.debug(f"处理消息段: type={sub_message_type}, data={sub_message.get('data', {})}")
# 特别关注 at 和 video 消息的识别
if sub_message_type == "at":
logger.debug(f"检测到@消息: {sub_message}")
elif sub_message_type == "video":
logger.debug(f"检测到VIDEO消息: {sub_message}")
elif sub_message_type not in ["text", "face", "image", "record"]:
logger.warning(f"检测到特殊消息类型: {sub_message_type}, 完整消息: {sub_message}")
match sub_message_type:
case RealMessageType.text:
ret_seg = await self.handle_text_message(sub_message)
@@ -379,6 +433,7 @@ class MessageHandler:
else:
logger.warning("record处理失败或不支持")
case RealMessageType.video:
logger.debug(f"开始处理VIDEO消息段: {sub_message}")
ret_seg = await self.handle_video_message(sub_message)
if ret_seg:
await event_manager.trigger_event(
@@ -386,8 +441,9 @@ class MessageHandler:
)
seg_message.append(ret_seg)
else:
logger.warning("video处理失败")
logger.warning(f"video处理失败,原始消息: {sub_message}")
case RealMessageType.at:
logger.debug(f"开始处理AT消息段: {sub_message}")
ret_seg = await self.handle_at_message(
sub_message,
raw_message.get("self_id"),
@@ -399,7 +455,7 @@ class MessageHandler:
)
seg_message.append(ret_seg)
else:
logger.warning("at处理失败")
logger.warning(f"at处理失败,原始消息: {sub_message}")
case RealMessageType.rps:
ret_seg = await self.handle_rps_message(sub_message)
if ret_seg:
@@ -502,9 +558,7 @@ class MessageHandler:
message_data: dict = raw_message.get("data")
image_sub_type = message_data.get("sub_type")
try:
logger.debug(f"开始下载图片: {message_data.get('url')}")
image_base64 = await get_image_base64(message_data.get("url"))
logger.debug(f"图片下载成功,大小: {len(image_base64)} 字符")
except Exception as e:
logger.error(f"图片消息处理失败: {str(e)}")
return None
@@ -595,8 +649,8 @@ class MessageHandler:
video_url = message_data.get("url")
file_path = message_data.get("filePath") or message_data.get("file_path")
logger.info(f"视频URL: {video_url}")
logger.info(f"视频文件路径: {file_path}")
logger.debug(f"视频URL: {video_url}")
logger.debug(f"视频文件路径: {file_path}")
# 优先使用本地文件路径其次使用URL
video_source = file_path if file_path else video_url
@@ -609,14 +663,14 @@ class MessageHandler:
try:
# 检查是否为本地文件路径
if file_path and Path(file_path).exists():
logger.info(f"使用本地视频文件: {file_path}")
logger.debug(f"使用本地视频文件: {file_path}")
# 直接读取本地文件
with open(file_path, "rb") as f:
video_data = f.read()
# 将视频数据编码为base64用于传输
video_base64 = base64.b64encode(video_data).decode("utf-8")
logger.info(f"视频文件大小: {len(video_data) / (1024 * 1024):.2f} MB")
logger.debug(f"视频文件大小: {len(video_data) / (1024 * 1024):.2f} MB")
# 返回包含详细信息的字典格式
return Seg(
@@ -629,7 +683,7 @@ class MessageHandler:
)
elif video_url:
logger.info(f"使用视频URL下载: {video_url}")
logger.debug(f"使用视频URL下载: {video_url}")
# 使用video_handler下载视频
video_downloader = get_video_downloader()
download_result = await video_downloader.download_video(video_url)
@@ -641,7 +695,7 @@ class MessageHandler:
# 将视频数据编码为base64用于传输
video_base64 = base64.b64encode(download_result["data"]).decode("utf-8")
logger.info(f"视频下载成功,大小: {len(download_result['data']) / (1024 * 1024):.2f} MB")
logger.debug(f"视频下载成功,大小: {len(download_result['data']) / (1024 * 1024):.2f} MB")
# 返回包含详细信息的字典格式
return Seg(
@@ -710,15 +764,15 @@ class MessageHandler:
processed_message: Seg
if image_count < 5 and image_count > 0:
# 处理图片数量小于5的情况此时解析图片为base64
logger.info("图片数量小于5开始解析图片为base64")
logger.debug("图片数量小于5开始解析图片为base64")
processed_message = await self._recursive_parse_image_seg(handled_message, True)
elif image_count > 0:
logger.info("图片数量大于等于5开始解析图片为占位符")
logger.debug("图片数量大于等于5开始解析图片为占位符")
# 处理图片数量大于等于5的情况此时解析图片为占位符
processed_message = await self._recursive_parse_image_seg(handled_message, False)
else:
# 处理没有图片的情况,此时直接返回
logger.info("没有图片,直接返回")
logger.debug("没有图片,直接返回")
processed_message = handled_message
# 添加转发消息提示
@@ -800,7 +854,10 @@ class MessageHandler:
content_parts.append(f"链接: {extracted_info['short_url']}")
formatted_content = "\n".join(content_parts)
return Seg(type="text", data=f"这是一条小程序分享消息,可以根据来源,考虑使用对应解析工具\n{formatted_content}")
return Seg(
type="text",
data=f"这是一条小程序分享消息,可以根据来源,考虑使用对应解析工具\n{formatted_content}",
)
# 如果没有提取到关键信息返回None
return None
@@ -849,7 +906,7 @@ class MessageHandler:
return Seg(type="text", data="[表情包]")
return Seg(type="emoji", data=encoded_image)
else:
logger.info(f"不处理类型: {seg_data.type}")
logger.debug(f"不处理类型: {seg_data.type}")
return seg_data
else:
if seg_data.type == "seglist":
@@ -863,7 +920,7 @@ class MessageHandler:
elif seg_data.type == "emoji":
return Seg(type="text", data="[动画表情]")
else:
logger.info(f"不处理类型: {seg_data.type}")
logger.debug(f"不处理类型: {seg_data.type}")
return seg_data
async def _handle_forward_message(self, message_list: list, layer: int) -> Tuple[Seg, int] | Tuple[None, int]:
@@ -1038,7 +1095,7 @@ class MessageHandler:
raw_message=raw_message.get("raw_message", ""),
)
logger.info(f"发送缓冲合并消息到Maibot处理: {session_id}")
logger.debug(f"发送缓冲合并消息到Maibot处理: {session_id}")
await message_send_instance.message_send(message_base)
except Exception as e:

View File

@@ -0,0 +1,117 @@
import asyncio
from src.common.logger import get_logger
from ..message_chunker import chunker
from src.plugin_system.apis import config_api
logger = get_logger("napcat_adapter")
from maim_message import MessageBase, Router
class MessageSending:
"""
负责把消息发送到麦麦
"""
maibot_router: Router = None
plugin_config = None
_connection_retries = 0
_max_retries = 3
def __init__(self):
pass
def set_plugin_config(self, plugin_config: dict):
"""设置插件配置"""
self.plugin_config = plugin_config
async def _attempt_reconnect(self):
"""尝试重新连接MaiBot router"""
if self._connection_retries < self._max_retries:
self._connection_retries += 1
logger.warning(f"尝试重新连接MaiBot router (第{self._connection_retries}次)")
try:
# 重新导入router
from ..mmc_com_layer import router
self.maibot_router = router
if self.maibot_router is not None:
logger.info("MaiBot router重连成功")
self._connection_retries = 0 # 重置重试计数
return True
except Exception as e:
logger.error(f"重连失败: {e}")
else:
logger.error(f"已达到最大重连次数({self._max_retries}),停止重试")
return False
async def message_send(self, message_base: MessageBase) -> bool:
"""
发送消息Ada -> MMC 方向,需要实现切片)
Parameters:
message_base: MessageBase: 消息基类,包含发送目标和消息内容等信息
"""
try:
# 检查maibot_router是否已初始化
if self.maibot_router is None:
logger.warning("MaiBot router未初始化尝试重新连接")
if not await self._attempt_reconnect():
logger.error("MaiBot router重连失败无法发送消息")
logger.error("请检查与MaiBot之间的连接")
return False
# 检查是否需要切片发送
message_dict = message_base.to_dict()
if chunker.should_chunk_message(message_dict):
logger.info("消息过大,进行切片发送到 MaiBot")
# 切片消息
chunks = chunker.chunk_message(message_dict)
# 逐个发送切片
for i, chunk in enumerate(chunks):
logger.debug(f"发送切片 {i + 1}/{len(chunks)} 到 MaiBot")
# 获取对应的客户端并发送切片
platform = message_base.message_info.platform
# 再次检查router状态防止运行时被重置
if self.maibot_router is None or not hasattr(self.maibot_router, 'clients'):
logger.warning("MaiBot router连接已断开尝试重新连接")
if not await self._attempt_reconnect():
logger.error("MaiBot router重连失败切片发送中止")
return False
if platform not in self.maibot_router.clients:
logger.error(f"平台 {platform} 未连接")
return False
client = self.maibot_router.clients[platform]
send_status = await client.send_message(chunk)
if not send_status:
logger.error(f"发送切片 {i + 1}/{len(chunks)} 失败")
return False
# 使用配置中的延迟时间
if i < len(chunks) - 1 and self.plugin_config:
delay_ms = config_api.get_plugin_config(self.plugin_config, "slicing.delay_ms", 10)
delay_seconds = delay_ms / 1000.0
logger.debug(f"切片发送延迟: {delay_ms}毫秒")
await asyncio.sleep(delay_seconds)
logger.debug("所有切片发送完成")
return True
else:
# 直接发送小消息
send_status = await self.maibot_router.send_message(message_base)
if not send_status:
raise RuntimeError("可能是路由未正确配置或连接异常")
return send_status
except Exception as e:
logger.error(f"发送消息失败: {str(e)}")
logger.error("请检查与MaiBot之间的连接")
return False
message_send_instance = MessageSending()

View File

@@ -1,7 +1,7 @@
from src.common.logger import get_logger
logger = get_logger("napcat_adapter")
from ..config import global_config
from src.plugin_system.apis import config_api
import time
import asyncio
@@ -14,8 +14,15 @@ class MetaEventHandler:
"""
def __init__(self):
self.interval = global_config.napcat_server.heartbeat_interval
self.interval = 5.0 # 默认值稍后通过set_plugin_config设置
self._interval_checking = False
self.plugin_config = None
def set_plugin_config(self, plugin_config: dict):
"""设置插件配置"""
self.plugin_config = plugin_config
# 更新interval值
self.interval = config_api.get_plugin_config(self.plugin_config, "napcat_server.heartbeat_interval", 5000) / 1000
async def handle_meta_event(self, message: dict) -> None:
event_type = message.get("meta_event_type")

View File

@@ -8,8 +8,7 @@ from src.common.logger import get_logger
logger = get_logger("napcat_adapter")
from ..config import global_config
from ..config.features_config import features_manager
from src.plugin_system.apis import config_api
from ..database import BanUser, db_manager, is_identical
from . import NoticeType, ACCEPT_FORMAT
from .message_sending import message_send_instance
@@ -38,6 +37,11 @@ class NoticeHandler:
def __init__(self):
self.server_connection: Server.ServerConnection | None = None
self.last_poke_time: float = 0.0 # 记录最后一次针对机器人的戳一戳时间
self.plugin_config = None
def set_plugin_config(self, plugin_config: dict):
"""设置插件配置"""
self.plugin_config = plugin_config
async def set_server_connection(self, server_connection: Server.ServerConnection) -> None:
"""设置Napcat连接"""
@@ -112,10 +116,10 @@ class NoticeHandler:
sub_type = raw_message.get("sub_type")
match sub_type:
case NoticeType.Notify.poke:
if features_manager.is_poke_enabled() and await message_handler.check_allow_to_chat(
if config_api.get_plugin_config(self.plugin_config, "features.enable_poke", True) and await message_handler.check_allow_to_chat(
user_id, group_id, False, False
):
logger.info("处理戳一戳消息")
logger.debug("处理戳一戳消息")
handled_message, user_info = await self.handle_poke_notify(raw_message, group_id, user_id)
else:
logger.warning("戳一戳消息被禁用,取消戳一戳处理")
@@ -159,13 +163,13 @@ class NoticeHandler:
else:
logger.warning("无法获取notice消息所在群的名称")
group_info = GroupInfo(
platform=global_config.maibot_server.platform_name,
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"),
group_id=group_id,
group_name=group_name,
)
message_info: BaseMessageInfo = BaseMessageInfo(
platform=global_config.maibot_server.platform_name,
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"),
message_id="notice",
time=message_time,
user_info=user_info,
@@ -187,7 +191,7 @@ class NoticeHandler:
if system_notice:
await self.put_notice(message_base)
else:
logger.info("发送到Maibot处理通知信息")
logger.debug("发送到Maibot处理通知信息")
await message_send_instance.message_send(message_base)
async def handle_poke_notify(
@@ -206,12 +210,12 @@ class NoticeHandler:
# 防抖检查:如果是针对机器人的戳一戳,检查防抖时间
if self_id == target_id:
current_time = time.time()
debounce_seconds = features_manager.get_config().poke_debounce_seconds
debounce_seconds = config_api.get_plugin_config(self.plugin_config, "features.poke_debounce_seconds", 2.0)
if self.last_poke_time > 0:
time_diff = current_time - self.last_poke_time
if time_diff < debounce_seconds:
logger.info(f"戳一戳防抖:用户 {user_id} 的戳一戳被忽略(距离上次戳一戳 {time_diff:.2f} 秒)")
logger.debug(f"戳一戳防抖:用户 {user_id} 的戳一戳被忽略(距离上次戳一戳 {time_diff:.2f} 秒)")
return None, None
# 记录这次戳一戳的时间
@@ -230,7 +234,7 @@ class NoticeHandler:
else:
user_name = "QQ用户"
user_cardname = "QQ用户"
logger.info("无法获取戳一戳对方的用户昵称")
logger.debug("无法获取戳一戳对方的用户昵称")
# 计算Seg
if self_id == target_id:
@@ -243,8 +247,8 @@ class NoticeHandler:
else:
# 如果配置为忽略不是针对自己的戳一戳则直接返回None
if features_manager.is_non_self_poke_ignored():
logger.info("忽略不是针对自己的戳一戳消息")
if config_api.get_plugin_config(self.plugin_config, "features.ignore_non_self_poke", False):
logger.debug("忽略不是针对自己的戳一戳消息")
return None, None
# 老实说这一步判定没啥意义,毕竟私聊是没有其他人之间的戳一戳,但是感觉可以有这个判定来强限制群聊环境
@@ -254,7 +258,7 @@ class NoticeHandler:
target_name = fetched_member_info.get("nickname")
else:
target_name = "QQ用户"
logger.info("无法获取被戳一戳方的用户昵称")
logger.debug("无法获取被戳一戳方的用户昵称")
display_name = user_name
else:
return None, None
@@ -268,7 +272,7 @@ class NoticeHandler:
logger.warning(f"解析戳一戳消息失败: {str(e)},将使用默认文本")
user_info: UserInfo = UserInfo(
platform=global_config.maibot_server.platform_name,
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"),
user_id=user_id,
user_nickname=user_name,
user_cardname=user_cardname,
@@ -299,7 +303,7 @@ class NoticeHandler:
operator_nickname = "QQ用户"
operator_info: UserInfo = UserInfo(
platform=global_config.maibot_server.platform_name,
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"),
user_id=operator_id,
user_nickname=operator_nickname,
user_cardname=operator_cardname,
@@ -328,7 +332,7 @@ class NoticeHandler:
user_nickname = fetched_member_info.get("nickname")
user_cardname = fetched_member_info.get("card")
banned_user_info: UserInfo = UserInfo(
platform=global_config.maibot_server.platform_name,
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"),
user_id=user_id,
user_nickname=user_nickname,
user_cardname=user_cardname,
@@ -367,7 +371,7 @@ class NoticeHandler:
operator_nickname = "QQ用户"
operator_info: UserInfo = UserInfo(
platform=global_config.maibot_server.platform_name,
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"),
user_id=operator_id,
user_nickname=operator_nickname,
user_cardname=operator_cardname,
@@ -393,7 +397,7 @@ class NoticeHandler:
else:
logger.warning("无法获取解除禁言消息发送者的昵称,消息可能会无效")
lifted_user_info: UserInfo = UserInfo(
platform=global_config.maibot_server.platform_name,
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"),
user_id=user_id,
user_nickname=user_nickname,
user_cardname=user_cardname,
@@ -436,13 +440,13 @@ class NoticeHandler:
else:
logger.warning("无法获取notice消息所在群的名称")
group_info = GroupInfo(
platform=global_config.maibot_server.platform_name,
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"),
group_id=group_id,
group_name=group_name,
)
message_info: BaseMessageInfo = BaseMessageInfo(
platform=global_config.maibot_server.platform_name,
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"),
message_id="notice",
time=time.time(),
user_info=None, # 自然解除禁言没有操作者
@@ -493,7 +497,7 @@ class NoticeHandler:
user_cardname = fetched_member_info.get("card")
lifted_user_info: UserInfo = UserInfo(
platform=global_config.maibot_server.platform_name,
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"),
user_id=user_id,
user_nickname=user_nickname,
user_cardname=user_cardname,
@@ -517,7 +521,7 @@ class NoticeHandler:
continue
if ban_record.lift_time <= int(time.time()):
# 触发自然解除禁言
logger.info(f"检测到用户 {ban_record.user_id} 在群 {ban_record.group_id} 的禁言已解除")
logger.debug(f"检测到用户 {ban_record.user_id} 在群 {ban_record.group_id} 的禁言已解除")
self.lifted_list.append(ban_record)
self.banned_list.remove(ban_record)
await asyncio.sleep(5)

View File

@@ -1,19 +1,26 @@
import asyncio
import time
from typing import Dict
from .config import global_config
from src.common.logger import get_logger
from src.plugin_system.apis import config_api
logger = get_logger("napcat_adapter")
response_dict: Dict = {}
response_time_dict: Dict = {}
plugin_config = None
def set_plugin_config(config: dict):
"""设置插件配置"""
global plugin_config
plugin_config = config
async def get_response(request_id: str, timeout: int = 10) -> dict:
response = await asyncio.wait_for(_get_response(request_id), timeout)
_ = response_time_dict.pop(request_id)
logger.info(f"响应信息id: {request_id} 已从响应字典中取出")
logger.debug(f"响应信息id: {request_id} 已从响应字典中取出")
return response
@@ -31,18 +38,25 @@ async def put_response(response: dict):
now_time = time.time()
response_dict[echo_id] = response
response_time_dict[echo_id] = now_time
logger.info(f"响应信息id: {echo_id} 已存入响应字典")
logger.debug(f"响应信息id: {echo_id} 已存入响应字典")
async def check_timeout_response() -> None:
while True:
cleaned_message_count: int = 0
now_time = time.time()
# 获取心跳间隔配置
heartbeat_interval = 30 # 默认值
if plugin_config:
heartbeat_interval = config_api.get_plugin_config(plugin_config, "napcat_server.heartbeat_interval", 30)
for echo_id, response_time in list(response_time_dict.items()):
if now_time - response_time > global_config.napcat_server.heartbeat_interval:
if now_time - response_time > heartbeat_interval:
cleaned_message_count += 1
response_dict.pop(echo_id)
response_time_dict.pop(echo_id)
logger.warning(f"响应消息 {echo_id} 超时,已删除")
if cleaned_message_count > 0:
logger.info(f"已删除 {cleaned_message_count} 条超时响应消息")
await asyncio.sleep(global_config.napcat_server.heartbeat_interval)
await asyncio.sleep(heartbeat_interval)

View File

@@ -12,9 +12,9 @@ from maim_message import (
MessageBase,
)
from typing import Dict, Any, Tuple, Optional
from src.plugin_system.apis import config_api
from . import CommandType
from .config import global_config
from .response_pool import get_response
from src.common.logger import get_logger
@@ -22,12 +22,16 @@ logger = get_logger("napcat_adapter")
from .utils import get_image_format, convert_image_to_gif
from .recv_handler.message_sending import message_send_instance
from .websocket_manager import websocket_manager
from .config.features_config import features_manager
class SendHandler:
def __init__(self):
self.server_connection: Optional[Server.ServerConnection] = None
self.plugin_config = None
def set_plugin_config(self, plugin_config: dict):
"""设置插件配置"""
self.plugin_config = plugin_config
async def set_server_connection(self, server_connection: Server.ServerConnection) -> None:
"""设置Napcat连接"""
@@ -287,11 +291,8 @@ class SendHandler:
"""处理回复消息"""
reply_seg = {"type": "reply", "data": {"id": id}}
# 获取功能配置
ft_config = features_manager.get_config()
# 检查是否启用引用艾特功能
if not ft_config.enable_reply_at:
if not config_api.get_plugin_config(self.plugin_config, "features.enable_reply_at", False):
return reply_seg
try:
@@ -310,7 +311,7 @@ class SendHandler:
return reply_seg
# 根据概率决定是否艾特用户
if random.random() < ft_config.reply_at_rate:
if random.random() < config_api.get_plugin_config(self.plugin_config, "features.reply_at_rate", 0.5):
at_seg = {"type": "at", "data": {"qq": str(replied_user_id)}}
# 在艾特后面添加一个空格
text_seg = {"type": "text", "data": {"text": " "}}
@@ -354,7 +355,11 @@ class SendHandler:
def handle_voice_message(self, encoded_voice: str) -> dict:
"""处理语音消息"""
if not global_config.voice.use_tts:
use_tts = False
if self.plugin_config:
use_tts = config_api.get_plugin_config(self.plugin_config, "voice.use_tts", False)
if not use_tts:
logger.warning("未启用语音消息处理")
return {}
if not encoded_voice:

View File

@@ -2,9 +2,9 @@ import asyncio
import websockets as Server
from typing import Optional, Callable, Any
from src.common.logger import get_logger
from src.plugin_system.apis import config_api
logger = get_logger("napcat_adapter")
from .config import global_config
class WebSocketManager:
@@ -16,10 +16,12 @@ class WebSocketManager:
self.is_running = False
self.reconnect_interval = 5 # 重连间隔(秒)
self.max_reconnect_attempts = 10 # 最大重连次数
self.plugin_config = None
async def start_connection(self, message_handler: Callable[[Server.ServerConnection], Any]) -> None:
async def start_connection(self, message_handler: Callable[[Server.ServerConnection], Any], plugin_config: dict) -> None:
"""根据配置启动 WebSocket 连接"""
mode = global_config.napcat_server.mode
self.plugin_config = plugin_config
mode = config_api.get_plugin_config(plugin_config, "napcat_server.mode")
if mode == "reverse":
await self._start_reverse_connection(message_handler)
@@ -30,8 +32,8 @@ class WebSocketManager:
async def _start_reverse_connection(self, message_handler: Callable[[Server.ServerConnection], Any]) -> None:
"""启动反向连接(作为服务器)"""
host = global_config.napcat_server.host
port = global_config.napcat_server.port
host = config_api.get_plugin_config(self.plugin_config, "napcat_server.host")
port = config_api.get_plugin_config(self.plugin_config, "napcat_server.port")
logger.info(f"正在启动反向连接模式,监听地址: ws://{host}:{port}")
@@ -68,9 +70,10 @@ class WebSocketManager:
connect_kwargs = {"max_size": 2**26}
# 如果配置了访问令牌,添加到请求头
if global_config.napcat_server.access_token:
access_token = config_api.get_plugin_config(self.plugin_config, "napcat_server.access_token")
if access_token:
connect_kwargs["additional_headers"] = {
"Authorization": f"Bearer {global_config.napcat_server.access_token}"
"Authorization": f"Bearer {access_token}"
}
logger.info("已添加访问令牌到连接请求头")
@@ -112,15 +115,14 @@ class WebSocketManager:
def _get_forward_url(self) -> str:
"""获取正向连接的 URL"""
config = global_config.napcat_server
# 如果配置了完整的 URL直接使用
if config.url:
return config.url
url = config_api.get_plugin_config(self.plugin_config, "napcat_server.url")
if url:
return url
# 否则根据 host 和 port 构建 URL
host = config.host
port = config.port
host = config_api.get_plugin_config(self.plugin_config, "napcat_server.host")
port = config_api.get_plugin_config(self.plugin_config, "napcat_server.port")
return f"ws://{host}:{port}"
async def stop_connection(self) -> None:

Some files were not shown because too many files have changed in this diff Show More