18
changes.md
Normal file
18
changes.md
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
# 插件API与规范修改
|
||||||
|
|
||||||
|
1. 现在`plugin_system`的`__init__.py`文件中包含了所有插件API的导入,用户可以直接使用`from plugin_system import *`来导入所有API。
|
||||||
|
|
||||||
|
2. register_plugin函数现在转移到了`plugin_system.apis.plugin_register_api`模块中,用户可以通过`from plugin_system.apis.plugin_register_api import register_plugin`来导入。
|
||||||
|
|
||||||
|
3. 现在强制要求的property如下:
|
||||||
|
- `plugin_name`: 插件名称,必须是唯一的。(与文件夹相同)
|
||||||
|
- `enable_plugin`: 是否启用插件,默认为`True`。
|
||||||
|
- `dependencies`: 插件依赖的其他插件列表,默认为空。**现在并不检查(也许)**
|
||||||
|
- `python_dependencies`: 插件依赖的Python包列表,默认为空。**现在并不检查**
|
||||||
|
- `config_file_name`: 插件配置文件名,默认为`config.toml`。
|
||||||
|
- `config_schema`: 插件配置文件的schema,用于自动生成配置文件。
|
||||||
|
|
||||||
|
# 插件系统修改
|
||||||
|
1. 现在所有的匹配模式不再是关键字了,而是枚举类。**(可能有遗漏)**
|
||||||
|
2. 修复了一下显示插件信息不显示的问题。同时精简了一下显示内容
|
||||||
|
3. 修复了插件系统混用了`plugin_name`和`display_name`的问题。现在所有的插件信息都使用`display_name`来显示,而内部标识仍然使用`plugin_name`。**(可能有遗漏)**
|
||||||
@@ -103,6 +103,8 @@ class HelloWorldPlugin(BasePlugin):
|
|||||||
# 插件基本信息
|
# 插件基本信息
|
||||||
plugin_name = "hello_world_plugin" # 内部标识符
|
plugin_name = "hello_world_plugin" # 内部标识符
|
||||||
enable_plugin = True
|
enable_plugin = True
|
||||||
|
dependencies = [] # 插件依赖列表
|
||||||
|
python_dependencies = [] # Python包依赖列表
|
||||||
config_file_name = "config.toml" # 配置文件名
|
config_file_name = "config.toml" # 配置文件名
|
||||||
|
|
||||||
# 配置节描述
|
# 配置节描述
|
||||||
|
|||||||
@@ -36,11 +36,12 @@ import urllib.error
|
|||||||
import base64
|
import base64
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
from src.plugin_system.base.base_plugin import BasePlugin, register_plugin
|
from src.plugin_system.base.base_plugin import BasePlugin
|
||||||
from src.plugin_system.base.base_action import BaseAction
|
from src.plugin_system.base.base_action import BaseAction
|
||||||
from src.plugin_system.base.base_command import BaseCommand
|
from src.plugin_system.base.base_command import BaseCommand
|
||||||
from src.plugin_system.base.component_types import ComponentInfo, ActionActivationType, ChatMode
|
from src.plugin_system.base.component_types import ComponentInfo, ActionActivationType, ChatMode
|
||||||
from src.plugin_system.base.config_types import ConfigField
|
from src.plugin_system.base.config_types import ConfigField
|
||||||
|
from src.plugin_system import register_plugin
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
logger = get_logger("take_picture_plugin")
|
logger = get_logger("take_picture_plugin")
|
||||||
@@ -442,6 +443,8 @@ class TakePicturePlugin(BasePlugin):
|
|||||||
|
|
||||||
plugin_name = "take_picture_plugin" # 内部标识符
|
plugin_name = "take_picture_plugin" # 内部标识符
|
||||||
enable_plugin = True
|
enable_plugin = True
|
||||||
|
dependencies = [] # 插件依赖列表
|
||||||
|
python_dependencies = [] # Python包依赖列表
|
||||||
config_file_name = "config.toml"
|
config_file_name = "config.toml"
|
||||||
|
|
||||||
# 配置节描述
|
# 配置节描述
|
||||||
|
|||||||
@@ -1,26 +0,0 @@
|
|||||||
from src.chat.heart_flow.heartflow import heartflow
|
|
||||||
from src.chat.heart_flow.sub_heartflow import ChatState
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
|
|
||||||
logger = get_logger("api")
|
|
||||||
|
|
||||||
|
|
||||||
async def get_all_subheartflow_ids() -> list:
|
|
||||||
"""获取所有子心流的ID列表"""
|
|
||||||
all_subheartflows = heartflow.subheartflow_manager.get_all_subheartflows()
|
|
||||||
return [subheartflow.subheartflow_id for subheartflow in all_subheartflows]
|
|
||||||
|
|
||||||
|
|
||||||
async def forced_change_subheartflow_status(subheartflow_id: str, status: ChatState) -> bool:
|
|
||||||
"""强制改变子心流的状态"""
|
|
||||||
subheartflow = await heartflow.get_or_create_subheartflow(subheartflow_id)
|
|
||||||
if subheartflow:
|
|
||||||
return await heartflow.force_change_subheartflow_status(subheartflow_id, status)
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
async def get_all_states():
|
|
||||||
"""获取所有状态"""
|
|
||||||
all_states = await heartflow.api_get_all_states()
|
|
||||||
logger.debug(f"所有状态: {all_states}")
|
|
||||||
return all_states
|
|
||||||
@@ -1,169 +0,0 @@
|
|||||||
import platform
|
|
||||||
import psutil
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
|
|
||||||
|
|
||||||
def get_system_info():
|
|
||||||
"""获取操作系统信息"""
|
|
||||||
return {
|
|
||||||
"system": platform.system(),
|
|
||||||
"release": platform.release(),
|
|
||||||
"version": platform.version(),
|
|
||||||
"machine": platform.machine(),
|
|
||||||
"processor": platform.processor(),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def get_python_version():
|
|
||||||
"""获取 Python 版本信息"""
|
|
||||||
return sys.version
|
|
||||||
|
|
||||||
|
|
||||||
def get_cpu_usage():
|
|
||||||
"""获取系统总CPU使用率"""
|
|
||||||
return psutil.cpu_percent(interval=1)
|
|
||||||
|
|
||||||
|
|
||||||
def get_process_cpu_usage():
|
|
||||||
"""获取当前进程CPU使用率"""
|
|
||||||
process = psutil.Process(os.getpid())
|
|
||||||
return process.cpu_percent(interval=1)
|
|
||||||
|
|
||||||
|
|
||||||
def get_memory_usage():
|
|
||||||
"""获取系统内存使用情况 (单位 MB)"""
|
|
||||||
mem = psutil.virtual_memory()
|
|
||||||
bytes_to_mb = lambda x: round(x / (1024 * 1024), 2) # noqa
|
|
||||||
return {
|
|
||||||
"total_mb": bytes_to_mb(mem.total),
|
|
||||||
"available_mb": bytes_to_mb(mem.available),
|
|
||||||
"percent": mem.percent,
|
|
||||||
"used_mb": bytes_to_mb(mem.used),
|
|
||||||
"free_mb": bytes_to_mb(mem.free),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def get_process_memory_usage():
|
|
||||||
"""获取当前进程内存使用情况 (单位 MB)"""
|
|
||||||
process = psutil.Process(os.getpid())
|
|
||||||
mem_info = process.memory_info()
|
|
||||||
bytes_to_mb = lambda x: round(x / (1024 * 1024), 2) # noqa
|
|
||||||
return {
|
|
||||||
"rss_mb": bytes_to_mb(mem_info.rss), # Resident Set Size: 实际使用物理内存
|
|
||||||
"vms_mb": bytes_to_mb(mem_info.vms), # Virtual Memory Size: 虚拟内存大小
|
|
||||||
"percent": process.memory_percent(), # 进程内存使用百分比
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def get_disk_usage(path="/"):
|
|
||||||
"""获取指定路径磁盘使用情况 (单位 GB)"""
|
|
||||||
disk = psutil.disk_usage(path)
|
|
||||||
bytes_to_gb = lambda x: round(x / (1024 * 1024 * 1024), 2) # noqa
|
|
||||||
return {
|
|
||||||
"total_gb": bytes_to_gb(disk.total),
|
|
||||||
"used_gb": bytes_to_gb(disk.used),
|
|
||||||
"free_gb": bytes_to_gb(disk.free),
|
|
||||||
"percent": disk.percent,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def get_all_basic_info():
|
|
||||||
"""获取所有基本信息并封装返回"""
|
|
||||||
# 对于进程CPU使用率,需要先初始化
|
|
||||||
process = psutil.Process(os.getpid())
|
|
||||||
process.cpu_percent(interval=None) # 初始化调用
|
|
||||||
process_cpu = process.cpu_percent(interval=0.1) # 短暂间隔获取
|
|
||||||
|
|
||||||
return {
|
|
||||||
"system_info": get_system_info(),
|
|
||||||
"python_version": get_python_version(),
|
|
||||||
"cpu_usage_percent": get_cpu_usage(),
|
|
||||||
"process_cpu_usage_percent": process_cpu,
|
|
||||||
"memory_usage": get_memory_usage(),
|
|
||||||
"process_memory_usage": get_process_memory_usage(),
|
|
||||||
"disk_usage_root": get_disk_usage("/"),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def get_all_basic_info_string() -> str:
|
|
||||||
"""获取所有基本信息并以带解释的字符串形式返回"""
|
|
||||||
info = get_all_basic_info()
|
|
||||||
|
|
||||||
sys_info = info["system_info"]
|
|
||||||
mem_usage = info["memory_usage"]
|
|
||||||
proc_mem_usage = info["process_memory_usage"]
|
|
||||||
disk_usage = info["disk_usage_root"]
|
|
||||||
|
|
||||||
# 对进程内存使用百分比进行格式化,保留两位小数
|
|
||||||
proc_mem_percent = round(proc_mem_usage["percent"], 2)
|
|
||||||
|
|
||||||
output_string = f"""[系统信息]
|
|
||||||
- 操作系统: {sys_info["system"]} (例如: Windows, Linux)
|
|
||||||
- 发行版本: {sys_info["release"]} (例如: 11, Ubuntu 20.04)
|
|
||||||
- 详细版本: {sys_info["version"]}
|
|
||||||
- 硬件架构: {sys_info["machine"]} (例如: AMD64)
|
|
||||||
- 处理器信息: {sys_info["processor"]}
|
|
||||||
|
|
||||||
[Python 环境]
|
|
||||||
- Python 版本: {info["python_version"]}
|
|
||||||
|
|
||||||
[CPU 状态]
|
|
||||||
- 系统总 CPU 使用率: {info["cpu_usage_percent"]}%
|
|
||||||
- 当前进程 CPU 使用率: {info["process_cpu_usage_percent"]}%
|
|
||||||
|
|
||||||
[系统内存使用情况]
|
|
||||||
- 总物理内存: {mem_usage["total_mb"]} MB
|
|
||||||
- 可用物理内存: {mem_usage["available_mb"]} MB
|
|
||||||
- 物理内存使用率: {mem_usage["percent"]}%
|
|
||||||
- 已用物理内存: {mem_usage["used_mb"]} MB
|
|
||||||
- 空闲物理内存: {mem_usage["free_mb"]} MB
|
|
||||||
|
|
||||||
[当前进程内存使用情况]
|
|
||||||
- 实际使用物理内存 (RSS): {proc_mem_usage["rss_mb"]} MB
|
|
||||||
- 占用虚拟内存 (VMS): {proc_mem_usage["vms_mb"]} MB
|
|
||||||
- 进程内存使用率: {proc_mem_percent}%
|
|
||||||
|
|
||||||
[磁盘使用情况 (根目录)]
|
|
||||||
- 总空间: {disk_usage["total_gb"]} GB
|
|
||||||
- 已用空间: {disk_usage["used_gb"]} GB
|
|
||||||
- 可用空间: {disk_usage["free_gb"]} GB
|
|
||||||
- 磁盘使用率: {disk_usage["percent"]}%
|
|
||||||
"""
|
|
||||||
return output_string
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
print(f"System Info: {get_system_info()}")
|
|
||||||
print(f"Python Version: {get_python_version()}")
|
|
||||||
print(f"CPU Usage: {get_cpu_usage()}%")
|
|
||||||
# 第一次调用 process.cpu_percent() 会返回0.0或一个无意义的值,需要间隔一段时间再调用
|
|
||||||
# 或者在初始化Process对象后,先调用一次cpu_percent(interval=None),然后再调用cpu_percent(interval=1)
|
|
||||||
current_process = psutil.Process(os.getpid())
|
|
||||||
current_process.cpu_percent(interval=None) # 初始化
|
|
||||||
print(f"Process CPU Usage: {current_process.cpu_percent(interval=1)}%") # 实际获取
|
|
||||||
|
|
||||||
memory_usage_info = get_memory_usage()
|
|
||||||
print(
|
|
||||||
f"Memory Usage: Total={memory_usage_info['total_mb']}MB, Used={memory_usage_info['used_mb']}MB, Percent={memory_usage_info['percent']}%"
|
|
||||||
)
|
|
||||||
|
|
||||||
process_memory_info = get_process_memory_usage()
|
|
||||||
print(
|
|
||||||
f"Process Memory Usage: RSS={process_memory_info['rss_mb']}MB, VMS={process_memory_info['vms_mb']}MB, Percent={process_memory_info['percent']}%"
|
|
||||||
)
|
|
||||||
|
|
||||||
disk_usage_info = get_disk_usage("/")
|
|
||||||
print(
|
|
||||||
f"Disk Usage (Root): Total={disk_usage_info['total_gb']}GB, Used={disk_usage_info['used_gb']}GB, Percent={disk_usage_info['percent']}%"
|
|
||||||
)
|
|
||||||
|
|
||||||
print("\n--- All Basic Info (JSON) ---")
|
|
||||||
all_info = get_all_basic_info()
|
|
||||||
import json
|
|
||||||
|
|
||||||
print(json.dumps(all_info, indent=4, ensure_ascii=False))
|
|
||||||
|
|
||||||
print("\n--- All Basic Info (String with Explanations) ---")
|
|
||||||
info_string = get_all_basic_info_string()
|
|
||||||
print(info_string)
|
|
||||||
@@ -1,317 +0,0 @@
|
|||||||
from typing import List, Optional, Dict, Any
|
|
||||||
import strawberry
|
|
||||||
|
|
||||||
# from packaging.version import Version
|
|
||||||
import os
|
|
||||||
|
|
||||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
|
|
||||||
|
|
||||||
|
|
||||||
@strawberry.type
|
|
||||||
class APIBotConfig:
|
|
||||||
"""机器人配置类"""
|
|
||||||
|
|
||||||
INNER_VERSION: str # 配置文件内部版本号(toml为字符串)
|
|
||||||
MAI_VERSION: str # 硬编码的版本信息
|
|
||||||
|
|
||||||
# bot
|
|
||||||
BOT_QQ: Optional[int] # 机器人QQ号
|
|
||||||
BOT_NICKNAME: Optional[str] # 机器人昵称
|
|
||||||
BOT_ALIAS_NAMES: List[str] # 机器人别名列表
|
|
||||||
|
|
||||||
# group
|
|
||||||
talk_allowed_groups: List[int] # 允许回复消息的群号列表
|
|
||||||
talk_frequency_down_groups: List[int] # 降低回复频率的群号列表
|
|
||||||
ban_user_id: List[int] # 禁止回复和读取消息的QQ号列表
|
|
||||||
|
|
||||||
# personality
|
|
||||||
personality_core: str # 人格核心特点描述
|
|
||||||
personality_sides: List[str] # 人格细节描述列表
|
|
||||||
|
|
||||||
# identity
|
|
||||||
identity_detail: List[str] # 身份特点列表
|
|
||||||
age: int # 年龄(岁)
|
|
||||||
gender: str # 性别
|
|
||||||
appearance: str # 外貌特征描述
|
|
||||||
|
|
||||||
# platforms
|
|
||||||
platforms: Dict[str, str] # 平台信息
|
|
||||||
|
|
||||||
# chat
|
|
||||||
allow_focus_mode: bool # 是否允许专注聊天状态
|
|
||||||
base_normal_chat_num: int # 最多允许多少个群进行普通聊天
|
|
||||||
base_focused_chat_num: int # 最多允许多少个群进行专注聊天
|
|
||||||
observation_context_size: int # 观察到的最长上下文大小
|
|
||||||
message_buffer: bool # 是否启用消息缓冲
|
|
||||||
ban_words: List[str] # 禁止词列表
|
|
||||||
ban_msgs_regex: List[str] # 禁止消息的正则表达式列表
|
|
||||||
|
|
||||||
# normal_chat
|
|
||||||
model_reasoning_probability: float # 推理模型概率
|
|
||||||
model_normal_probability: float # 普通模型概率
|
|
||||||
emoji_chance: float # 表情符号出现概率
|
|
||||||
thinking_timeout: int # 思考超时时间
|
|
||||||
willing_mode: str # 意愿模式
|
|
||||||
response_interested_rate_amplifier: float # 回复兴趣率放大器
|
|
||||||
emoji_response_penalty: float # 表情回复惩罚
|
|
||||||
mentioned_bot_inevitable_reply: bool # 提及 bot 必然回复
|
|
||||||
at_bot_inevitable_reply: bool # @bot 必然回复
|
|
||||||
|
|
||||||
# focus_chat
|
|
||||||
reply_trigger_threshold: float # 回复触发阈值
|
|
||||||
default_decay_rate_per_second: float # 默认每秒衰减率
|
|
||||||
|
|
||||||
# compressed
|
|
||||||
compressed_length: int # 压缩长度
|
|
||||||
compress_length_limit: int # 压缩长度限制
|
|
||||||
|
|
||||||
# emoji
|
|
||||||
max_emoji_num: int # 最大表情符号数量
|
|
||||||
max_reach_deletion: bool # 达到最大数量时是否删除
|
|
||||||
check_interval: int # 检查表情包的时间间隔(分钟)
|
|
||||||
save_emoji: bool # 是否保存表情包
|
|
||||||
steal_emoji: bool # 是否偷取表情包
|
|
||||||
enable_check: bool # 是否启用表情包过滤
|
|
||||||
check_prompt: str # 表情包过滤要求
|
|
||||||
|
|
||||||
# memory
|
|
||||||
build_memory_interval: int # 记忆构建间隔
|
|
||||||
build_memory_distribution: List[float] # 记忆构建分布
|
|
||||||
build_memory_sample_num: int # 采样数量
|
|
||||||
build_memory_sample_length: int # 采样长度
|
|
||||||
memory_compress_rate: float # 记忆压缩率
|
|
||||||
forget_memory_interval: int # 记忆遗忘间隔
|
|
||||||
memory_forget_time: int # 记忆遗忘时间(小时)
|
|
||||||
memory_forget_percentage: float # 记忆遗忘比例
|
|
||||||
consolidate_memory_interval: int # 记忆整合间隔
|
|
||||||
consolidation_similarity_threshold: float # 相似度阈值
|
|
||||||
consolidation_check_percentage: float # 检查节点比例
|
|
||||||
memory_ban_words: List[str] # 记忆禁止词列表
|
|
||||||
|
|
||||||
# mood
|
|
||||||
mood_update_interval: float # 情绪更新间隔
|
|
||||||
mood_decay_rate: float # 情绪衰减率
|
|
||||||
mood_intensity_factor: float # 情绪强度因子
|
|
||||||
|
|
||||||
# keywords_reaction
|
|
||||||
keywords_reaction_enable: bool # 是否启用关键词反应
|
|
||||||
keywords_reaction_rules: List[Dict[str, Any]] # 关键词反应规则
|
|
||||||
|
|
||||||
# chinese_typo
|
|
||||||
chinese_typo_enable: bool # 是否启用中文错别字
|
|
||||||
chinese_typo_error_rate: float # 中文错别字错误率
|
|
||||||
chinese_typo_min_freq: int # 中文错别字最小频率
|
|
||||||
chinese_typo_tone_error_rate: float # 中文错别字声调错误率
|
|
||||||
chinese_typo_word_replace_rate: float # 中文错别字单词替换率
|
|
||||||
|
|
||||||
# response_splitter
|
|
||||||
enable_response_splitter: bool # 是否启用回复分割器
|
|
||||||
response_max_length: int # 回复最大长度
|
|
||||||
response_max_sentence_num: int # 回复最大句子数
|
|
||||||
enable_kaomoji_protection: bool # 是否启用颜文字保护
|
|
||||||
|
|
||||||
model_max_output_length: int # 模型最大输出长度
|
|
||||||
|
|
||||||
# remote
|
|
||||||
remote_enable: bool # 是否启用远程功能
|
|
||||||
|
|
||||||
# experimental
|
|
||||||
enable_friend_chat: bool # 是否启用好友聊天
|
|
||||||
talk_allowed_private: List[int] # 允许私聊的QQ号列表
|
|
||||||
pfc_chatting: bool # 是否启用PFC聊天
|
|
||||||
|
|
||||||
# 模型配置
|
|
||||||
llm_reasoning: Dict[str, Any] # 推理模型配置
|
|
||||||
llm_normal: Dict[str, Any] # 普通模型配置
|
|
||||||
llm_topic_judge: Dict[str, Any] # 主题判断模型配置
|
|
||||||
summary: Dict[str, Any] # 总结模型配置
|
|
||||||
vlm: Dict[str, Any] # VLM模型配置
|
|
||||||
llm_heartflow: Dict[str, Any] # 心流模型配置
|
|
||||||
llm_observation: Dict[str, Any] # 观察模型配置
|
|
||||||
llm_sub_heartflow: Dict[str, Any] # 子心流模型配置
|
|
||||||
llm_plan: Optional[Dict[str, Any]] # 计划模型配置
|
|
||||||
embedding: Dict[str, Any] # 嵌入模型配置
|
|
||||||
llm_PFC_action_planner: Optional[Dict[str, Any]] # PFC行动计划模型配置
|
|
||||||
llm_PFC_chat: Optional[Dict[str, Any]] # PFC聊天模型配置
|
|
||||||
llm_PFC_reply_checker: Optional[Dict[str, Any]] # PFC回复检查模型配置
|
|
||||||
llm_tool_use: Optional[Dict[str, Any]] # 工具使用模型配置
|
|
||||||
|
|
||||||
api_urls: Optional[Dict[str, str]] # API地址配置
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def validate_config(config: dict):
|
|
||||||
"""
|
|
||||||
校验传入的 toml 配置字典是否合法。
|
|
||||||
:param config: toml库load后的配置字典
|
|
||||||
:raises: ValueError, KeyError, TypeError
|
|
||||||
"""
|
|
||||||
# 检查主层级
|
|
||||||
required_sections = [
|
|
||||||
"inner",
|
|
||||||
"bot",
|
|
||||||
"groups",
|
|
||||||
"personality",
|
|
||||||
"identity",
|
|
||||||
"platforms",
|
|
||||||
"chat",
|
|
||||||
"normal_chat",
|
|
||||||
"focus_chat",
|
|
||||||
"emoji",
|
|
||||||
"memory",
|
|
||||||
"mood",
|
|
||||||
"keywords_reaction",
|
|
||||||
"chinese_typo",
|
|
||||||
"response_splitter",
|
|
||||||
"remote",
|
|
||||||
"experimental",
|
|
||||||
"model",
|
|
||||||
]
|
|
||||||
for section in required_sections:
|
|
||||||
if section not in config:
|
|
||||||
raise KeyError(f"缺少配置段: [{section}]")
|
|
||||||
|
|
||||||
# 检查部分关键字段
|
|
||||||
if "version" not in config["inner"]:
|
|
||||||
raise KeyError("缺少 inner.version 字段")
|
|
||||||
if not isinstance(config["inner"]["version"], str):
|
|
||||||
raise TypeError("inner.version 必须为字符串")
|
|
||||||
|
|
||||||
if "qq" not in config["bot"]:
|
|
||||||
raise KeyError("缺少 bot.qq 字段")
|
|
||||||
if not isinstance(config["bot"]["qq"], int):
|
|
||||||
raise TypeError("bot.qq 必须为整数")
|
|
||||||
|
|
||||||
if "personality_core" not in config["personality"]:
|
|
||||||
raise KeyError("缺少 personality.personality_core 字段")
|
|
||||||
if not isinstance(config["personality"]["personality_core"], str):
|
|
||||||
raise TypeError("personality.personality_core 必须为字符串")
|
|
||||||
|
|
||||||
if "identity_detail" not in config["identity"]:
|
|
||||||
raise KeyError("缺少 identity.identity_detail 字段")
|
|
||||||
if not isinstance(config["identity"]["identity_detail"], list):
|
|
||||||
raise TypeError("identity.identity_detail 必须为列表")
|
|
||||||
|
|
||||||
# 可继续添加更多字段的类型和值检查
|
|
||||||
# ...
|
|
||||||
|
|
||||||
# 检查模型配置
|
|
||||||
model_keys = [
|
|
||||||
"llm_reasoning",
|
|
||||||
"llm_normal",
|
|
||||||
"llm_topic_judge",
|
|
||||||
"summary",
|
|
||||||
"vlm",
|
|
||||||
"llm_heartflow",
|
|
||||||
"llm_observation",
|
|
||||||
"llm_sub_heartflow",
|
|
||||||
"embedding",
|
|
||||||
]
|
|
||||||
if "model" not in config:
|
|
||||||
raise KeyError("缺少 [model] 配置段")
|
|
||||||
for key in model_keys:
|
|
||||||
if key not in config["model"]:
|
|
||||||
raise KeyError(f"缺少 model.{key} 配置")
|
|
||||||
|
|
||||||
# 检查通过
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
@strawberry.type
|
|
||||||
class APIEnvConfig:
|
|
||||||
"""环境变量配置"""
|
|
||||||
|
|
||||||
HOST: str # 服务主机地址
|
|
||||||
PORT: int # 服务端口
|
|
||||||
|
|
||||||
PLUGINS: List[str] # 插件列表
|
|
||||||
|
|
||||||
MONGODB_HOST: str # MongoDB 主机地址
|
|
||||||
MONGODB_PORT: int # MongoDB 端口
|
|
||||||
DATABASE_NAME: str # 数据库名称
|
|
||||||
|
|
||||||
CHAT_ANY_WHERE_BASE_URL: str # ChatAnywhere 基础URL
|
|
||||||
SILICONFLOW_BASE_URL: str # SiliconFlow 基础URL
|
|
||||||
DEEP_SEEK_BASE_URL: str # DeepSeek 基础URL
|
|
||||||
|
|
||||||
DEEP_SEEK_KEY: Optional[str] # DeepSeek API Key
|
|
||||||
CHAT_ANY_WHERE_KEY: Optional[str] # ChatAnywhere API Key
|
|
||||||
SILICONFLOW_KEY: Optional[str] # SiliconFlow API Key
|
|
||||||
|
|
||||||
SIMPLE_OUTPUT: Optional[bool] # 是否简化输出
|
|
||||||
CONSOLE_LOG_LEVEL: Optional[str] # 控制台日志等级
|
|
||||||
FILE_LOG_LEVEL: Optional[str] # 文件日志等级
|
|
||||||
DEFAULT_CONSOLE_LOG_LEVEL: Optional[str] # 默认控制台日志等级
|
|
||||||
DEFAULT_FILE_LOG_LEVEL: Optional[str] # 默认文件日志等级
|
|
||||||
|
|
||||||
@strawberry.field
|
|
||||||
def get_env(self) -> str:
|
|
||||||
return "env"
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def validate_config(config: dict):
|
|
||||||
"""
|
|
||||||
校验环境变量配置字典是否合法。
|
|
||||||
:param config: 环境变量配置字典
|
|
||||||
:raises: KeyError, TypeError
|
|
||||||
"""
|
|
||||||
required_fields = [
|
|
||||||
"HOST",
|
|
||||||
"PORT",
|
|
||||||
"PLUGINS",
|
|
||||||
"MONGODB_HOST",
|
|
||||||
"MONGODB_PORT",
|
|
||||||
"DATABASE_NAME",
|
|
||||||
"CHAT_ANY_WHERE_BASE_URL",
|
|
||||||
"SILICONFLOW_BASE_URL",
|
|
||||||
"DEEP_SEEK_BASE_URL",
|
|
||||||
]
|
|
||||||
for field in required_fields:
|
|
||||||
if field not in config:
|
|
||||||
raise KeyError(f"缺少环境变量配置字段: {field}")
|
|
||||||
|
|
||||||
if not isinstance(config["HOST"], str):
|
|
||||||
raise TypeError("HOST 必须为字符串")
|
|
||||||
if not isinstance(config["PORT"], int):
|
|
||||||
raise TypeError("PORT 必须为整数")
|
|
||||||
if not isinstance(config["PLUGINS"], list):
|
|
||||||
raise TypeError("PLUGINS 必须为列表")
|
|
||||||
if not isinstance(config["MONGODB_HOST"], str):
|
|
||||||
raise TypeError("MONGODB_HOST 必须为字符串")
|
|
||||||
if not isinstance(config["MONGODB_PORT"], int):
|
|
||||||
raise TypeError("MONGODB_PORT 必须为整数")
|
|
||||||
if not isinstance(config["DATABASE_NAME"], str):
|
|
||||||
raise TypeError("DATABASE_NAME 必须为字符串")
|
|
||||||
if not isinstance(config["CHAT_ANY_WHERE_BASE_URL"], str):
|
|
||||||
raise TypeError("CHAT_ANY_WHERE_BASE_URL 必须为字符串")
|
|
||||||
if not isinstance(config["SILICONFLOW_BASE_URL"], str):
|
|
||||||
raise TypeError("SILICONFLOW_BASE_URL 必须为字符串")
|
|
||||||
if not isinstance(config["DEEP_SEEK_BASE_URL"], str):
|
|
||||||
raise TypeError("DEEP_SEEK_BASE_URL 必须为字符串")
|
|
||||||
|
|
||||||
# 可选字段类型检查
|
|
||||||
optional_str_fields = [
|
|
||||||
"DEEP_SEEK_KEY",
|
|
||||||
"CHAT_ANY_WHERE_KEY",
|
|
||||||
"SILICONFLOW_KEY",
|
|
||||||
"CONSOLE_LOG_LEVEL",
|
|
||||||
"FILE_LOG_LEVEL",
|
|
||||||
"DEFAULT_CONSOLE_LOG_LEVEL",
|
|
||||||
"DEFAULT_FILE_LOG_LEVEL",
|
|
||||||
]
|
|
||||||
for field in optional_str_fields:
|
|
||||||
if field in config and config[field] is not None and not isinstance(config[field], str):
|
|
||||||
raise TypeError(f"{field} 必须为字符串或None")
|
|
||||||
|
|
||||||
if (
|
|
||||||
"SIMPLE_OUTPUT" in config
|
|
||||||
and config["SIMPLE_OUTPUT"] is not None
|
|
||||||
and not isinstance(config["SIMPLE_OUTPUT"], bool)
|
|
||||||
):
|
|
||||||
raise TypeError("SIMPLE_OUTPUT 必须为布尔值或None")
|
|
||||||
|
|
||||||
# 检查通过
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
print("当前路径:")
|
|
||||||
print(ROOT_PATH)
|
|
||||||
@@ -1,22 +0,0 @@
|
|||||||
import strawberry
|
|
||||||
|
|
||||||
from fastapi import FastAPI
|
|
||||||
from strawberry.fastapi import GraphQLRouter
|
|
||||||
|
|
||||||
from src.common.server import get_global_server
|
|
||||||
|
|
||||||
|
|
||||||
@strawberry.type
|
|
||||||
class Query:
|
|
||||||
@strawberry.field
|
|
||||||
def hello(self) -> str:
|
|
||||||
return "Hello World"
|
|
||||||
|
|
||||||
|
|
||||||
schema = strawberry.Schema(Query)
|
|
||||||
|
|
||||||
graphql_app = GraphQLRouter(schema)
|
|
||||||
|
|
||||||
fast_api_app: FastAPI = get_global_server().get_app()
|
|
||||||
|
|
||||||
fast_api_app.include_router(graphql_app, prefix="/graphql")
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
pass
|
|
||||||
112
src/api/main.py
112
src/api/main.py
@@ -1,112 +0,0 @@
|
|||||||
from fastapi import APIRouter
|
|
||||||
from strawberry.fastapi import GraphQLRouter
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
|
|
||||||
# from src.chat.heart_flow.heartflow import heartflow
|
|
||||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
|
|
||||||
# from src.config.config import BotConfig
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.api.reload_config import reload_config as reload_config_func
|
|
||||||
from src.common.server import get_global_server
|
|
||||||
from src.api.apiforgui import (
|
|
||||||
get_all_subheartflow_ids,
|
|
||||||
forced_change_subheartflow_status,
|
|
||||||
get_subheartflow_cycle_info,
|
|
||||||
get_all_states,
|
|
||||||
)
|
|
||||||
from src.chat.heart_flow.sub_heartflow import ChatState
|
|
||||||
from src.api.basic_info_api import get_all_basic_info # 新增导入
|
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter()
|
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger("api")
|
|
||||||
|
|
||||||
logger.info("麦麦API服务器已启动")
|
|
||||||
graphql_router = GraphQLRouter(schema=None, path="/") # Replace `None` with your actual schema
|
|
||||||
|
|
||||||
router.include_router(graphql_router, prefix="/graphql", tags=["GraphQL"])
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/config/reload")
|
|
||||||
async def reload_config():
|
|
||||||
return await reload_config_func()
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/gui/subheartflow/get/all")
|
|
||||||
async def get_subheartflow_ids():
|
|
||||||
"""获取所有子心流的ID列表"""
|
|
||||||
return await get_all_subheartflow_ids()
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/gui/subheartflow/forced_change_status")
|
|
||||||
async def forced_change_subheartflow_status_api(subheartflow_id: str, status: ChatState): # noqa
|
|
||||||
"""强制改变子心流的状态"""
|
|
||||||
# 参数检查
|
|
||||||
if not isinstance(status, ChatState):
|
|
||||||
logger.warning(f"无效的状态参数: {status}")
|
|
||||||
return {"status": "failed", "reason": "invalid status"}
|
|
||||||
logger.info(f"尝试将子心流 {subheartflow_id} 状态更改为 {status.value}")
|
|
||||||
success = await forced_change_subheartflow_status(subheartflow_id, status)
|
|
||||||
if success:
|
|
||||||
logger.info(f"子心流 {subheartflow_id} 状态更改为 {status.value} 成功")
|
|
||||||
return {"status": "success"}
|
|
||||||
else:
|
|
||||||
logger.error(f"子心流 {subheartflow_id} 状态更改为 {status.value} 失败")
|
|
||||||
return {"status": "failed"}
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/stop")
|
|
||||||
async def force_stop_maibot():
|
|
||||||
"""强制停止MAI Bot"""
|
|
||||||
from bot import request_shutdown
|
|
||||||
|
|
||||||
success = await request_shutdown()
|
|
||||||
if success:
|
|
||||||
logger.info("MAI Bot已强制停止")
|
|
||||||
return {"status": "success"}
|
|
||||||
else:
|
|
||||||
logger.error("MAI Bot强制停止失败")
|
|
||||||
return {"status": "failed"}
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/gui/subheartflow/cycleinfo")
|
|
||||||
async def get_subheartflow_cycle_info_api(subheartflow_id: str, history_len: int):
|
|
||||||
"""获取子心流的循环信息"""
|
|
||||||
cycle_info = await get_subheartflow_cycle_info(subheartflow_id, history_len)
|
|
||||||
if cycle_info:
|
|
||||||
return {"status": "success", "data": cycle_info}
|
|
||||||
else:
|
|
||||||
logger.warning(f"子心流 {subheartflow_id} 循环信息未找到")
|
|
||||||
return {"status": "failed", "reason": "subheartflow not found"}
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/gui/get_all_states")
|
|
||||||
async def get_all_states_api():
|
|
||||||
"""获取所有状态"""
|
|
||||||
all_states = await get_all_states()
|
|
||||||
if all_states:
|
|
||||||
return {"status": "success", "data": all_states}
|
|
||||||
else:
|
|
||||||
logger.warning("获取所有状态失败")
|
|
||||||
return {"status": "failed", "reason": "failed to get all states"}
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/info")
|
|
||||||
async def get_system_basic_info():
|
|
||||||
"""获取系统基本信息"""
|
|
||||||
logger.info("请求系统基本信息")
|
|
||||||
try:
|
|
||||||
info = get_all_basic_info()
|
|
||||||
return {"status": "success", "data": info}
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取系统基本信息失败: {e}")
|
|
||||||
return {"status": "failed", "reason": str(e)}
|
|
||||||
|
|
||||||
|
|
||||||
def start_api_server():
|
|
||||||
"""启动API服务器"""
|
|
||||||
get_global_server().register_router(router, prefix="/api/v1")
|
|
||||||
# pass
|
|
||||||
@@ -1,24 +0,0 @@
|
|||||||
from fastapi import HTTPException
|
|
||||||
from rich.traceback import install
|
|
||||||
from src.config.config import get_config_dir, load_config
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
import os
|
|
||||||
|
|
||||||
install(extra_lines=3)
|
|
||||||
|
|
||||||
logger = get_logger("api")
|
|
||||||
|
|
||||||
|
|
||||||
async def reload_config():
|
|
||||||
try:
|
|
||||||
from src.config import config as config_module
|
|
||||||
|
|
||||||
logger.debug("正在重载配置文件...")
|
|
||||||
bot_config_path = os.path.join(get_config_dir(), "bot_config.toml")
|
|
||||||
config_module.global_config = load_config(config_path=bot_config_path)
|
|
||||||
logger.debug("配置文件重载成功")
|
|
||||||
return {"status": "reloaded"}
|
|
||||||
except FileNotFoundError as e:
|
|
||||||
raise HTTPException(status_code=404, detail=str(e)) from e
|
|
||||||
except Exception as e:
|
|
||||||
raise HTTPException(status_code=500, detail=f"重载配置时发生错误: {str(e)}") from e
|
|
||||||
@@ -5,20 +5,19 @@ import os
|
|||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Optional, Tuple, List, Any
|
|
||||||
from PIL import Image
|
|
||||||
import io
|
import io
|
||||||
import re
|
import re
|
||||||
|
import binascii
|
||||||
# from gradio_client import file
|
from typing import Optional, Tuple, List, Any
|
||||||
|
from PIL import Image
|
||||||
|
from rich.traceback import install
|
||||||
|
|
||||||
from src.common.database.database_model import Emoji
|
from src.common.database.database_model import Emoji
|
||||||
from src.common.database.database import db as peewee_db
|
from src.common.database.database import db as peewee_db
|
||||||
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.chat.utils.utils_image import image_path_to_base64, get_image_manager
|
from src.chat.utils.utils_image import image_path_to_base64, get_image_manager
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.common.logger import get_logger
|
|
||||||
from rich.traceback import install
|
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
|
|
||||||
@@ -26,7 +25,7 @@ logger = get_logger("emoji")
|
|||||||
|
|
||||||
BASE_DIR = os.path.join("data")
|
BASE_DIR = os.path.join("data")
|
||||||
EMOJI_DIR = os.path.join(BASE_DIR, "emoji") # 表情包存储目录
|
EMOJI_DIR = os.path.join(BASE_DIR, "emoji") # 表情包存储目录
|
||||||
EMOJI_REGISTED_DIR = os.path.join(BASE_DIR, "emoji_registed") # 已注册的表情包注册目录
|
EMOJI_REGISTERED_DIR = os.path.join(BASE_DIR, "emoji_registed") # 已注册的表情包注册目录
|
||||||
MAX_EMOJI_FOR_PROMPT = 20 # 最大允许的表情包描述数量于图片替换的 prompt 中
|
MAX_EMOJI_FOR_PROMPT = 20 # 最大允许的表情包描述数量于图片替换的 prompt 中
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@@ -85,7 +84,7 @@ class MaiEmoji:
|
|||||||
logger.debug(f"[初始化] 正在使用Pillow获取格式: {self.filename}")
|
logger.debug(f"[初始化] 正在使用Pillow获取格式: {self.filename}")
|
||||||
try:
|
try:
|
||||||
with Image.open(io.BytesIO(image_bytes)) as img:
|
with Image.open(io.BytesIO(image_bytes)) as img:
|
||||||
self.format = img.format.lower()
|
self.format = img.format.lower() # type: ignore
|
||||||
logger.debug(f"[初始化] 格式获取成功: {self.format}")
|
logger.debug(f"[初始化] 格式获取成功: {self.format}")
|
||||||
except Exception as pil_error:
|
except Exception as pil_error:
|
||||||
logger.error(f"[初始化错误] Pillow无法处理图片 ({self.filename}): {pil_error}")
|
logger.error(f"[初始化错误] Pillow无法处理图片 ({self.filename}): {pil_error}")
|
||||||
@@ -100,7 +99,7 @@ class MaiEmoji:
|
|||||||
logger.error(f"[初始化错误] 文件在处理过程中丢失: {self.full_path}")
|
logger.error(f"[初始化错误] 文件在处理过程中丢失: {self.full_path}")
|
||||||
self.is_deleted = True
|
self.is_deleted = True
|
||||||
return None
|
return None
|
||||||
except base64.binascii.Error as b64_error:
|
except (binascii.Error, ValueError) as b64_error:
|
||||||
logger.error(f"[初始化错误] Base64解码失败 ({self.filename}): {b64_error}")
|
logger.error(f"[初始化错误] Base64解码失败 ({self.filename}): {b64_error}")
|
||||||
self.is_deleted = True
|
self.is_deleted = True
|
||||||
return None
|
return None
|
||||||
@@ -113,7 +112,7 @@ class MaiEmoji:
|
|||||||
async def register_to_db(self) -> bool:
|
async def register_to_db(self) -> bool:
|
||||||
"""
|
"""
|
||||||
注册表情包
|
注册表情包
|
||||||
将表情包对应的文件,从当前路径移动到EMOJI_REGISTED_DIR目录下
|
将表情包对应的文件,从当前路径移动到EMOJI_REGISTERED_DIR目录下
|
||||||
并修改对应的实例属性,然后将表情包信息保存到数据库中
|
并修改对应的实例属性,然后将表情包信息保存到数据库中
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
@@ -122,7 +121,7 @@ class MaiEmoji:
|
|||||||
# 源路径是当前实例的完整路径 self.full_path
|
# 源路径是当前实例的完整路径 self.full_path
|
||||||
source_full_path = self.full_path
|
source_full_path = self.full_path
|
||||||
# 目标完整路径
|
# 目标完整路径
|
||||||
destination_full_path = os.path.join(EMOJI_REGISTED_DIR, self.filename)
|
destination_full_path = os.path.join(EMOJI_REGISTERED_DIR, self.filename)
|
||||||
|
|
||||||
# 检查源文件是否存在
|
# 检查源文件是否存在
|
||||||
if not os.path.exists(source_full_path):
|
if not os.path.exists(source_full_path):
|
||||||
@@ -139,7 +138,7 @@ class MaiEmoji:
|
|||||||
logger.debug(f"[移动] 文件从 {source_full_path} 移动到 {destination_full_path}")
|
logger.debug(f"[移动] 文件从 {source_full_path} 移动到 {destination_full_path}")
|
||||||
# 更新实例的路径属性为新路径
|
# 更新实例的路径属性为新路径
|
||||||
self.full_path = destination_full_path
|
self.full_path = destination_full_path
|
||||||
self.path = EMOJI_REGISTED_DIR
|
self.path = EMOJI_REGISTERED_DIR
|
||||||
# self.filename 保持不变
|
# self.filename 保持不变
|
||||||
except Exception as move_error:
|
except Exception as move_error:
|
||||||
logger.error(f"[错误] 移动文件失败: {str(move_error)}")
|
logger.error(f"[错误] 移动文件失败: {str(move_error)}")
|
||||||
@@ -202,7 +201,7 @@ class MaiEmoji:
|
|||||||
try:
|
try:
|
||||||
will_delete_emoji = Emoji.get(Emoji.emoji_hash == self.hash)
|
will_delete_emoji = Emoji.get(Emoji.emoji_hash == self.hash)
|
||||||
result = will_delete_emoji.delete_instance() # Returns the number of rows deleted.
|
result = will_delete_emoji.delete_instance() # Returns the number of rows deleted.
|
||||||
except Emoji.DoesNotExist:
|
except Emoji.DoesNotExist: # type: ignore
|
||||||
logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。")
|
logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。")
|
||||||
result = 0 # Indicate no DB record was deleted
|
result = 0 # Indicate no DB record was deleted
|
||||||
|
|
||||||
@@ -298,7 +297,7 @@ def _to_emoji_objects(data: Any) -> Tuple[List["MaiEmoji"], int]:
|
|||||||
def _ensure_emoji_dir() -> None:
|
def _ensure_emoji_dir() -> None:
|
||||||
"""确保表情存储目录存在"""
|
"""确保表情存储目录存在"""
|
||||||
os.makedirs(EMOJI_DIR, exist_ok=True)
|
os.makedirs(EMOJI_DIR, exist_ok=True)
|
||||||
os.makedirs(EMOJI_REGISTED_DIR, exist_ok=True)
|
os.makedirs(EMOJI_REGISTERED_DIR, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
async def clear_temp_emoji() -> None:
|
async def clear_temp_emoji() -> None:
|
||||||
@@ -331,10 +330,10 @@ async def clean_unused_emojis(emoji_dir: str, emoji_objects: List["MaiEmoji"], r
|
|||||||
logger.warning(f"[清理] 目标目录不存在,跳过清理: {emoji_dir}")
|
logger.warning(f"[清理] 目标目录不存在,跳过清理: {emoji_dir}")
|
||||||
return removed_count
|
return removed_count
|
||||||
|
|
||||||
|
cleaned_count = 0
|
||||||
try:
|
try:
|
||||||
# 获取内存中所有有效表情包的完整路径集合
|
# 获取内存中所有有效表情包的完整路径集合
|
||||||
tracked_full_paths = {emoji.full_path for emoji in emoji_objects if not emoji.is_deleted}
|
tracked_full_paths = {emoji.full_path for emoji in emoji_objects if not emoji.is_deleted}
|
||||||
cleaned_count = 0
|
|
||||||
|
|
||||||
# 遍历指定目录中的所有文件
|
# 遍历指定目录中的所有文件
|
||||||
for file_name in os.listdir(emoji_dir):
|
for file_name in os.listdir(emoji_dir):
|
||||||
@@ -358,11 +357,11 @@ async def clean_unused_emojis(emoji_dir: str, emoji_objects: List["MaiEmoji"], r
|
|||||||
else:
|
else:
|
||||||
logger.info(f"[清理] 目录 {emoji_dir} 中没有需要清理的。")
|
logger.info(f"[清理] 目录 {emoji_dir} 中没有需要清理的。")
|
||||||
|
|
||||||
return removed_count + cleaned_count
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[错误] 清理未使用表情包文件时出错 ({emoji_dir}): {str(e)}")
|
logger.error(f"[错误] 清理未使用表情包文件时出错 ({emoji_dir}): {str(e)}")
|
||||||
|
|
||||||
|
return removed_count + cleaned_count
|
||||||
|
|
||||||
|
|
||||||
class EmojiManager:
|
class EmojiManager:
|
||||||
_instance = None
|
_instance = None
|
||||||
@@ -414,7 +413,7 @@ class EmojiManager:
|
|||||||
emoji_update.usage_count += 1
|
emoji_update.usage_count += 1
|
||||||
emoji_update.last_used_time = time.time() # Update last used time
|
emoji_update.last_used_time = time.time() # Update last used time
|
||||||
emoji_update.save() # Persist changes to DB
|
emoji_update.save() # Persist changes to DB
|
||||||
except Emoji.DoesNotExist:
|
except Emoji.DoesNotExist: # type: ignore
|
||||||
logger.error(f"记录表情使用失败: 未找到 hash 为 {emoji_hash} 的表情包")
|
logger.error(f"记录表情使用失败: 未找到 hash 为 {emoji_hash} 的表情包")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"记录表情使用失败: {str(e)}")
|
logger.error(f"记录表情使用失败: {str(e)}")
|
||||||
@@ -570,8 +569,8 @@ class EmojiManager:
|
|||||||
if objects_to_remove:
|
if objects_to_remove:
|
||||||
self.emoji_objects = [e for e in self.emoji_objects if e not in objects_to_remove]
|
self.emoji_objects = [e for e in self.emoji_objects if e not in objects_to_remove]
|
||||||
|
|
||||||
# 清理 EMOJI_REGISTED_DIR 目录中未被追踪的文件
|
# 清理 EMOJI_REGISTERED_DIR 目录中未被追踪的文件
|
||||||
removed_count = await clean_unused_emojis(EMOJI_REGISTED_DIR, self.emoji_objects, removed_count)
|
removed_count = await clean_unused_emojis(EMOJI_REGISTERED_DIR, self.emoji_objects, removed_count)
|
||||||
|
|
||||||
# 输出清理结果
|
# 输出清理结果
|
||||||
if removed_count > 0:
|
if removed_count > 0:
|
||||||
@@ -850,11 +849,13 @@ class EmojiManager:
|
|||||||
if isinstance(image_base64, str):
|
if isinstance(image_base64, str):
|
||||||
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
||||||
image_bytes = base64.b64decode(image_base64)
|
image_bytes = base64.b64decode(image_base64)
|
||||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
|
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
|
||||||
|
|
||||||
# 调用AI获取描述
|
# 调用AI获取描述
|
||||||
if image_format == "gif" or image_format == "GIF":
|
if image_format == "gif" or image_format == "GIF":
|
||||||
image_base64 = get_image_manager().transform_gif(image_base64)
|
image_base64 = get_image_manager().transform_gif(image_base64) # type: ignore
|
||||||
|
if not image_base64:
|
||||||
|
raise RuntimeError("GIF表情包转换失败")
|
||||||
prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
|
prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
|
||||||
description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, "jpg")
|
description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, "jpg")
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -1,14 +1,16 @@
|
|||||||
import time
|
import time
|
||||||
import random
|
import random
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
from typing import List, Dict, Optional, Any, Tuple
|
from typing import List, Dict, Optional, Any, Tuple
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_random, build_anonymous_messages
|
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_random, build_anonymous_messages
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
import os
|
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
import json
|
|
||||||
|
|
||||||
|
|
||||||
MAX_EXPRESSION_COUNT = 300
|
MAX_EXPRESSION_COUNT = 300
|
||||||
@@ -74,7 +76,8 @@ class ExpressionLearner:
|
|||||||
)
|
)
|
||||||
self.llm_model = None
|
self.llm_model = None
|
||||||
|
|
||||||
def get_expression_by_chat_id(self, chat_id: str) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]:
|
def get_expression_by_chat_id(self, chat_id: str) -> Tuple[List[Dict[str, float]], List[Dict[str, float]]]:
|
||||||
|
# sourcery skip: extract-duplicate-method, remove-unnecessary-cast
|
||||||
"""
|
"""
|
||||||
获取指定chat_id的style和grammar表达方式
|
获取指定chat_id的style和grammar表达方式
|
||||||
返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作
|
返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作
|
||||||
@@ -119,10 +122,10 @@ class ExpressionLearner:
|
|||||||
min_len = min(len(s1), len(s2))
|
min_len = min(len(s1), len(s2))
|
||||||
if min_len < 5:
|
if min_len < 5:
|
||||||
return False
|
return False
|
||||||
same = sum(1 for a, b in zip(s1, s2, strict=False) if a == b)
|
same = sum(a == b for a, b in zip(s1, s2, strict=False))
|
||||||
return same / min_len > 0.8
|
return same / min_len > 0.8
|
||||||
|
|
||||||
async def learn_and_store_expression(self) -> List[Tuple[str, str, str]]:
|
async def learn_and_store_expression(self) -> Tuple[List[Tuple[str, str, str]], List[Tuple[str, str, str]]]:
|
||||||
"""
|
"""
|
||||||
学习并存储表达方式,分别学习语言风格和句法特点
|
学习并存储表达方式,分别学习语言风格和句法特点
|
||||||
同时对所有已存储的表达方式进行全局衰减
|
同时对所有已存储的表达方式进行全局衰减
|
||||||
@@ -158,12 +161,12 @@ class ExpressionLearner:
|
|||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
learnt_style: Optional[List[Tuple[str, str, str]]] = await self.learn_and_store(type="style", num=25)
|
learnt_style: Optional[List[Tuple[str, str, str]]] = await self.learn_and_store(type="style", num=25)
|
||||||
if not learnt_style:
|
if not learnt_style:
|
||||||
return []
|
return [], []
|
||||||
|
|
||||||
for _ in range(1):
|
for _ in range(1):
|
||||||
learnt_grammar: Optional[List[Tuple[str, str, str]]] = await self.learn_and_store(type="grammar", num=10)
|
learnt_grammar: Optional[List[Tuple[str, str, str]]] = await self.learn_and_store(type="grammar", num=10)
|
||||||
if not learnt_grammar:
|
if not learnt_grammar:
|
||||||
return []
|
return [], []
|
||||||
|
|
||||||
return learnt_style, learnt_grammar
|
return learnt_style, learnt_grammar
|
||||||
|
|
||||||
@@ -214,6 +217,7 @@ class ExpressionLearner:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
async def learn_and_store(self, type: str, num: int = 10) -> List[Tuple[str, str, str]]:
|
async def learn_and_store(self, type: str, num: int = 10) -> List[Tuple[str, str, str]]:
|
||||||
|
# sourcery skip: use-join
|
||||||
"""
|
"""
|
||||||
选择从当前到最近1小时内的随机num条消息,然后学习这些消息的表达方式
|
选择从当前到最近1小时内的随机num条消息,然后学习这些消息的表达方式
|
||||||
type: "style" or "grammar"
|
type: "style" or "grammar"
|
||||||
@@ -249,7 +253,7 @@ class ExpressionLearner:
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
# 按chat_id分组
|
# 按chat_id分组
|
||||||
chat_dict: Dict[str, List[Dict[str, str]]] = {}
|
chat_dict: Dict[str, List[Dict[str, Any]]] = {}
|
||||||
for chat_id, situation, style in learnt_expressions:
|
for chat_id, situation, style in learnt_expressions:
|
||||||
if chat_id not in chat_dict:
|
if chat_id not in chat_dict:
|
||||||
chat_dict[chat_id] = []
|
chat_dict[chat_id] = []
|
||||||
|
|||||||
@@ -1,14 +1,16 @@
|
|||||||
from .expression_learner import get_expression_learner
|
|
||||||
import random
|
|
||||||
from typing import List, Dict, Tuple
|
|
||||||
from json_repair import repair_json
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
import random
|
||||||
|
|
||||||
|
from typing import List, Dict, Tuple, Optional
|
||||||
|
from json_repair import repair_json
|
||||||
|
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
|
from .expression_learner import get_expression_learner
|
||||||
|
|
||||||
logger = get_logger("expression_selector")
|
logger = get_logger("expression_selector")
|
||||||
|
|
||||||
@@ -82,6 +84,7 @@ class ExpressionSelector:
|
|||||||
def get_random_expressions(
|
def get_random_expressions(
|
||||||
self, chat_id: str, total_num: int, style_percentage: float, grammar_percentage: float
|
self, chat_id: str, total_num: int, style_percentage: float, grammar_percentage: float
|
||||||
) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]:
|
) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]:
|
||||||
|
# sourcery skip: extract-duplicate-method, move-assign
|
||||||
(
|
(
|
||||||
learnt_style_expressions,
|
learnt_style_expressions,
|
||||||
learnt_grammar_expressions,
|
learnt_grammar_expressions,
|
||||||
@@ -165,8 +168,14 @@ class ExpressionSelector:
|
|||||||
logger.error(f"批量更新表达方式count失败 for {file_path}: {e}")
|
logger.error(f"批量更新表达方式count失败 for {file_path}: {e}")
|
||||||
|
|
||||||
async def select_suitable_expressions_llm(
|
async def select_suitable_expressions_llm(
|
||||||
self, chat_id: str, chat_info: str, max_num: int = 10, min_num: int = 5, target_message: str = None
|
self,
|
||||||
|
chat_id: str,
|
||||||
|
chat_info: str,
|
||||||
|
max_num: int = 10,
|
||||||
|
min_num: int = 5,
|
||||||
|
target_message: Optional[str] = None,
|
||||||
) -> List[Dict[str, str]]:
|
) -> List[Dict[str, str]]:
|
||||||
|
# sourcery skip: inline-variable, list-comprehension
|
||||||
"""使用LLM选择适合的表达方式"""
|
"""使用LLM选择适合的表达方式"""
|
||||||
|
|
||||||
# 1. 获取35个随机表达方式(现在按权重抽取)
|
# 1. 获取35个随机表达方式(现在按权重抽取)
|
||||||
|
|||||||
@@ -1,19 +1,21 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Optional, List
|
from collections import deque
|
||||||
|
from typing import List, Optional, Dict, Any, Deque, Callable, Awaitable
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
from src.chat.utils.prompt_builder import global_prompt_manager
|
|
||||||
|
from src.config.config import global_config
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||||
|
from src.chat.utils.prompt_builder import global_prompt_manager
|
||||||
from src.chat.utils.timer_calculator import Timer
|
from src.chat.utils.timer_calculator import Timer
|
||||||
from src.chat.planner_actions.planner import ActionPlanner
|
from src.chat.planner_actions.planner import ActionPlanner
|
||||||
from src.chat.planner_actions.action_modifier import ActionModifier
|
from src.chat.planner_actions.action_modifier import ActionModifier
|
||||||
from src.chat.planner_actions.action_manager import ActionManager
|
from src.chat.planner_actions.action_manager import ActionManager
|
||||||
from src.config.config import global_config
|
|
||||||
from src.person_info.relationship_builder_manager import relationship_builder_manager
|
|
||||||
from src.chat.focus_chat.hfc_utils import CycleDetail
|
from src.chat.focus_chat.hfc_utils import CycleDetail
|
||||||
|
from src.person_info.relationship_builder_manager import relationship_builder_manager
|
||||||
|
from src.plugin_system.base.component_types import ChatMode
|
||||||
import random
|
import random
|
||||||
from src.chat.focus_chat.hfc_utils import get_recent_message_stats
|
from src.chat.focus_chat.hfc_utils import get_recent_message_stats
|
||||||
from src.person_info.person_info import get_person_info_manager
|
from src.person_info.person_info import get_person_info_manager
|
||||||
@@ -23,7 +25,6 @@ from .priority_manager import PriorityManager
|
|||||||
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat
|
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
ERROR_LOOP_INFO = {
|
ERROR_LOOP_INFO = {
|
||||||
"loop_plan_info": {
|
"loop_plan_info": {
|
||||||
"action_result": {
|
"action_result": {
|
||||||
@@ -79,7 +80,9 @@ class HeartFChatting:
|
|||||||
"""
|
"""
|
||||||
# 基础属性
|
# 基础属性
|
||||||
self.stream_id: str = chat_id # 聊天流ID
|
self.stream_id: str = chat_id # 聊天流ID
|
||||||
self.chat_stream = get_chat_manager().get_stream(self.stream_id)
|
self.chat_stream: ChatStream = get_chat_manager().get_stream(self.stream_id) # type: ignore
|
||||||
|
if not self.chat_stream:
|
||||||
|
raise ValueError(f"无法找到聊天流: {self.stream_id}")
|
||||||
self.log_prefix = f"[{get_chat_manager().get_stream_name(self.stream_id) or self.stream_id}]"
|
self.log_prefix = f"[{get_chat_manager().get_stream_name(self.stream_id) or self.stream_id}]"
|
||||||
|
|
||||||
self.relationship_builder = relationship_builder_manager.get_or_create_builder(self.stream_id)
|
self.relationship_builder = relationship_builder_manager.get_or_create_builder(self.stream_id)
|
||||||
@@ -93,7 +96,6 @@ class HeartFChatting:
|
|||||||
self._message_threshold = max(10, int(30 * global_config.chat.exit_focus_threshold))
|
self._message_threshold = max(10, int(30 * global_config.chat.exit_focus_threshold))
|
||||||
self._fatigue_triggered = False # 是否已触发疲惫退出
|
self._fatigue_triggered = False # 是否已触发疲惫退出
|
||||||
|
|
||||||
|
|
||||||
self.action_manager = ActionManager()
|
self.action_manager = ActionManager()
|
||||||
self.action_planner = ActionPlanner(chat_id=self.stream_id, action_manager=self.action_manager)
|
self.action_planner = ActionPlanner(chat_id=self.stream_id, action_manager=self.action_manager)
|
||||||
self.action_modifier = ActionModifier(action_manager=self.action_manager, chat_id=self.stream_id)
|
self.action_modifier = ActionModifier(action_manager=self.action_manager, chat_id=self.stream_id)
|
||||||
@@ -112,11 +114,9 @@ class HeartFChatting:
|
|||||||
|
|
||||||
self.last_read_time = time.time() - 1
|
self.last_read_time = time.time() - 1
|
||||||
|
|
||||||
|
|
||||||
self.willing_amplifier = 1
|
self.willing_amplifier = 1
|
||||||
self.willing_manager = get_willing_manager()
|
self.willing_manager = get_willing_manager()
|
||||||
|
|
||||||
|
|
||||||
self.reply_mode = self.chat_stream.context.get_priority_mode()
|
self.reply_mode = self.chat_stream.context.get_priority_mode()
|
||||||
if self.reply_mode == "priority":
|
if self.reply_mode == "priority":
|
||||||
self.priority_manager = PriorityManager(
|
self.priority_manager = PriorityManager(
|
||||||
@@ -126,12 +126,10 @@ class HeartFChatting:
|
|||||||
else:
|
else:
|
||||||
self.priority_manager = None
|
self.priority_manager = None
|
||||||
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"{self.log_prefix} HeartFChatting 初始化完成,消息疲惫阈值: {self._message_threshold}条(基于exit_focus_threshold={global_config.chat.exit_focus_threshold}计算,仅在auto模式下生效)"
|
f"{self.log_prefix} HeartFChatting 初始化完成,消息疲惫阈值: {self._message_threshold}条(基于exit_focus_threshold={global_config.chat.exit_focus_threshold}计算,仅在auto模式下生效)"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
self.energy_value = 100
|
self.energy_value = 100
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
@@ -160,8 +158,7 @@ class HeartFChatting:
|
|||||||
def _handle_loop_completion(self, task: asyncio.Task):
|
def _handle_loop_completion(self, task: asyncio.Task):
|
||||||
"""当 _hfc_loop 任务完成时执行的回调。"""
|
"""当 _hfc_loop 任务完成时执行的回调。"""
|
||||||
try:
|
try:
|
||||||
exception = task.exception()
|
if exception := task.exception():
|
||||||
if exception:
|
|
||||||
logger.error(f"{self.log_prefix} HeartFChatting: 脱离了聊天(异常): {exception}")
|
logger.error(f"{self.log_prefix} HeartFChatting: 脱离了聊天(异常): {exception}")
|
||||||
logger.error(traceback.format_exc()) # Log full traceback for exceptions
|
logger.error(traceback.format_exc()) # Log full traceback for exceptions
|
||||||
else:
|
else:
|
||||||
@@ -191,26 +188,27 @@ class HeartFChatting:
|
|||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"{self.log_prefix} 第{self._current_cycle_detail.cycle_id}次思考,"
|
f"{self.log_prefix} 第{self._current_cycle_detail.cycle_id}次思考,"
|
||||||
f"耗时: {self._current_cycle_detail.end_time - self._current_cycle_detail.start_time:.1f}秒, "
|
f"耗时: {self._current_cycle_detail.end_time - self._current_cycle_detail.start_time:.1f}秒, " # type: ignore
|
||||||
f"选择动作: {self._current_cycle_detail.loop_plan_info.get('action_result', {}).get('action_type', '未知动作')}"
|
f"选择动作: {self._current_cycle_detail.loop_plan_info.get('action_result', {}).get('action_type', '未知动作')}"
|
||||||
+ (f"\n详情: {'; '.join(timer_strings)}" if timer_strings else "")
|
+ (f"\n详情: {'; '.join(timer_strings)}" if timer_strings else "")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def _loopbody(self):
|
async def _loopbody(self):
|
||||||
if self.loop_mode == "focus":
|
if self.loop_mode == "focus":
|
||||||
|
|
||||||
self.energy_value -= 5 * (1 / global_config.chat.exit_focus_threshold)
|
self.energy_value -= 5 * (1 / global_config.chat.exit_focus_threshold)
|
||||||
if self.energy_value <= 0:
|
if self.energy_value <= 0:
|
||||||
self.loop_mode = "normal"
|
self.loop_mode = "normal"
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
return await self._observe()
|
return await self._observe()
|
||||||
elif self.loop_mode == "normal":
|
elif self.loop_mode == "normal":
|
||||||
new_messages_data = get_raw_msg_by_timestamp_with_chat(
|
new_messages_data = get_raw_msg_by_timestamp_with_chat(
|
||||||
chat_id=self.stream_id, timestamp_start=self.last_read_time, timestamp_end=time.time(),limit=10,limit_mode="earliest",fliter_bot=True
|
chat_id=self.stream_id,
|
||||||
|
timestamp_start=self.last_read_time,
|
||||||
|
timestamp_end=time.time(),
|
||||||
|
limit=10,
|
||||||
|
limit_mode="earliest",
|
||||||
|
fliter_bot=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(new_messages_data) > 4 * global_config.chat.auto_focus_threshold:
|
if len(new_messages_data) > 4 * global_config.chat.auto_focus_threshold:
|
||||||
@@ -238,18 +236,13 @@ class HeartFChatting:
|
|||||||
reply_to_str = f"{person_name}:{message_data.get('processed_plain_text')}"
|
reply_to_str = f"{person_name}:{message_data.get('processed_plain_text')}"
|
||||||
return reply_to_str
|
return reply_to_str
|
||||||
|
|
||||||
|
|
||||||
async def _observe(self, message_data: dict = None):
|
async def _observe(self, message_data: dict = None):
|
||||||
# 创建新的循环信息
|
# 创建新的循环信息
|
||||||
cycle_timers, thinking_id = self.start_cycle()
|
cycle_timers, thinking_id = self.start_cycle()
|
||||||
|
|
||||||
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考[模式:{self.loop_mode}]")
|
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考[模式:{self.loop_mode}]")
|
||||||
|
|
||||||
|
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
|
||||||
async with global_prompt_manager.async_message_scope(
|
|
||||||
self.chat_stream.context.get_template_name()
|
|
||||||
):
|
|
||||||
|
|
||||||
loop_start_time = time.time()
|
loop_start_time = time.time()
|
||||||
# await self.loop_info.observe()
|
# await self.loop_info.observe()
|
||||||
await self.relationship_builder.build_relation()
|
await self.relationship_builder.build_relation()
|
||||||
@@ -267,12 +260,9 @@ class HeartFChatting:
|
|||||||
reply_to_str = await self.build_reply_to_str(message_data)
|
reply_to_str = await self.build_reply_to_str(message_data)
|
||||||
gen_task = asyncio.create_task(self._generate_response(message_data, available_actions, reply_to_str))
|
gen_task = asyncio.create_task(self._generate_response(message_data, available_actions, reply_to_str))
|
||||||
|
|
||||||
|
|
||||||
with Timer("规划器", cycle_timers):
|
with Timer("规划器", cycle_timers):
|
||||||
plan_result = await self.action_planner.plan(mode=self.loop_mode)
|
plan_result = await self.action_planner.plan(mode=self.loop_mode)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
action_result = plan_result.get("action_result", {})
|
action_result = plan_result.get("action_result", {})
|
||||||
action_type, action_data, reasoning, is_parallel = (
|
action_type, action_data, reasoning, is_parallel = (
|
||||||
action_result.get("action_type", "error"),
|
action_result.get("action_type", "error"),
|
||||||
@@ -293,8 +283,6 @@ class HeartFChatting:
|
|||||||
else:
|
else:
|
||||||
logger.info(f"[{self.log_prefix}] {global_config.bot.nickname} 决定执行{action_type}动作")
|
logger.info(f"[{self.log_prefix}] {global_config.bot.nickname} 决定执行{action_type}动作")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if action_type == "no_action":
|
if action_type == "no_action":
|
||||||
# 等待回复生成完毕
|
# 等待回复生成完毕
|
||||||
gather_timeout = global_config.chat.thinking_timeout
|
gather_timeout = global_config.chat.thinking_timeout
|
||||||
@@ -307,9 +295,7 @@ class HeartFChatting:
|
|||||||
content = " ".join([item[1] for item in response_set if item[0] == "text"])
|
content = " ".join([item[1] for item in response_set if item[0] == "text"])
|
||||||
|
|
||||||
# 模型炸了,没有回复内容生成
|
# 模型炸了,没有回复内容生成
|
||||||
if not response_set or (
|
if not response_set or (action_type not in ["no_action"] and not is_parallel):
|
||||||
action_type not in ["no_action"] and not is_parallel
|
|
||||||
):
|
|
||||||
if not response_set:
|
if not response_set:
|
||||||
logger.warning(f"[{self.log_prefix}] 模型未生成回复内容")
|
logger.warning(f"[{self.log_prefix}] 模型未生成回复内容")
|
||||||
elif action_type not in ["no_action"] and not is_parallel:
|
elif action_type not in ["no_action"] and not is_parallel:
|
||||||
@@ -320,14 +306,11 @@ class HeartFChatting:
|
|||||||
|
|
||||||
logger.info(f"[{self.log_prefix}] {global_config.bot.nickname} 决定的回复内容: {content}")
|
logger.info(f"[{self.log_prefix}] {global_config.bot.nickname} 决定的回复内容: {content}")
|
||||||
|
|
||||||
|
|
||||||
# 发送回复 (不再需要传入 chat)
|
# 发送回复 (不再需要传入 chat)
|
||||||
await self._send_response(response_set, reply_to_str, loop_start_time)
|
await self._send_response(response_set, reply_to_str, loop_start_time)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# 动作执行计时
|
# 动作执行计时
|
||||||
with Timer("动作执行", cycle_timers):
|
with Timer("动作执行", cycle_timers):
|
||||||
@@ -360,8 +343,6 @@ class HeartFChatting:
|
|||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def _main_chat_loop(self):
|
async def _main_chat_loop(self):
|
||||||
"""主循环,持续进行计划并可能回复消息,直到被外部取消。"""
|
"""主循环,持续进行计划并可能回复消息,直到被外部取消。"""
|
||||||
try:
|
try:
|
||||||
@@ -425,9 +406,6 @@ class HeartFChatting:
|
|||||||
|
|
||||||
# 处理动作并获取结果
|
# 处理动作并获取结果
|
||||||
result = await action_handler.handle_action()
|
result = await action_handler.handle_action()
|
||||||
if len(result) == 3:
|
|
||||||
success, reply_text, command = result
|
|
||||||
else:
|
|
||||||
success, reply_text = result
|
success, reply_text = result
|
||||||
command = ""
|
command = ""
|
||||||
|
|
||||||
@@ -447,8 +425,6 @@ class HeartFChatting:
|
|||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return False, "", ""
|
return False, "", ""
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def shutdown(self):
|
async def shutdown(self):
|
||||||
"""优雅关闭HeartFChatting实例,取消活动循环任务"""
|
"""优雅关闭HeartFChatting实例,取消活动循环任务"""
|
||||||
logger.info(f"{self.log_prefix} 正在关闭HeartFChatting...")
|
logger.info(f"{self.log_prefix} 正在关闭HeartFChatting...")
|
||||||
@@ -483,7 +459,6 @@ class HeartFChatting:
|
|||||||
|
|
||||||
logger.info(f"{self.log_prefix} HeartFChatting关闭完成")
|
logger.info(f"{self.log_prefix} HeartFChatting关闭完成")
|
||||||
|
|
||||||
|
|
||||||
def adjust_reply_frequency(self):
|
def adjust_reply_frequency(self):
|
||||||
"""
|
"""
|
||||||
根据预设规则动态调整回复意愿(willing_amplifier)。
|
根据预设规则动态调整回复意愿(willing_amplifier)。
|
||||||
@@ -554,8 +529,6 @@ class HeartFChatting:
|
|||||||
f"意愿放大器更新为: {self.willing_amplifier:.2f}"
|
f"意愿放大器更新为: {self.willing_amplifier:.2f}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def normal_response(self, message_data: dict) -> None:
|
async def normal_response(self, message_data: dict) -> None:
|
||||||
"""
|
"""
|
||||||
处理接收到的消息。
|
处理接收到的消息。
|
||||||
@@ -587,7 +560,6 @@ class HeartFChatting:
|
|||||||
if message_data.get("is_emoji") or message_data.get("is_picid"):
|
if message_data.get("is_emoji") or message_data.get("is_picid"):
|
||||||
reply_probability = 0
|
reply_probability = 0
|
||||||
|
|
||||||
|
|
||||||
# 打印消息信息
|
# 打印消息信息
|
||||||
mes_name = self.chat_stream.group_info.group_name if self.chat_stream.group_info else "私聊"
|
mes_name = self.chat_stream.group_info.group_name if self.chat_stream.group_info else "私聊"
|
||||||
if reply_probability > 0.1:
|
if reply_probability > 0.1:
|
||||||
@@ -606,7 +578,6 @@ class HeartFChatting:
|
|||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
async def _generate_response(
|
async def _generate_response(
|
||||||
self, message_data: dict, available_actions: Optional[list], reply_to: str
|
self, message_data: dict, available_actions: Optional[list], reply_to: str
|
||||||
) -> Optional[list]:
|
) -> Optional[list]:
|
||||||
@@ -630,10 +601,7 @@ class HeartFChatting:
|
|||||||
logger.error(f"[{self.log_prefix}] 回复生成出现错误:{str(e)} {traceback.format_exc()}")
|
logger.error(f"[{self.log_prefix}] 回复生成出现错误:{str(e)} {traceback.format_exc()}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
async def _send_response(self, reply_set, reply_to, thinking_start_time):
|
||||||
async def _send_response(
|
|
||||||
self, reply_set, reply_to, thinking_start_time
|
|
||||||
):
|
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
new_message_count = message_api.count_new_messages(
|
new_message_count = message_api.count_new_messages(
|
||||||
chat_id=self.chat_stream.stream_id, start_time=thinking_start_time, end_time=current_time
|
chat_id=self.chat_stream.stream_id, start_time=thinking_start_time, end_time=current_time
|
||||||
@@ -651,7 +619,9 @@ class HeartFChatting:
|
|||||||
data = reply_seg[1]
|
data = reply_seg[1]
|
||||||
if not first_replyed:
|
if not first_replyed:
|
||||||
if need_reply:
|
if need_reply:
|
||||||
await send_api.text_to_stream(text=data, stream_id=self.chat_stream.stream_id, reply_to=reply_to, typing=False)
|
await send_api.text_to_stream(
|
||||||
|
text=data, stream_id=self.chat_stream.stream_id, reply_to=reply_to, typing=False
|
||||||
|
)
|
||||||
first_replyed = True
|
first_replyed = True
|
||||||
else:
|
else:
|
||||||
await send_api.text_to_stream(text=data, stream_id=self.chat_stream.stream_id, typing=False)
|
await send_api.text_to_stream(text=data, stream_id=self.chat_stream.stream_id, typing=False)
|
||||||
@@ -661,5 +631,3 @@ class HeartFChatting:
|
|||||||
reply_text += data
|
reply_text += data
|
||||||
|
|
||||||
return reply_text
|
return reply_text
|
||||||
|
|
||||||
|
|
||||||
@@ -1,11 +1,14 @@
|
|||||||
import time
|
import time
|
||||||
from typing import Optional
|
import json
|
||||||
from src.common.logger import get_logger
|
|
||||||
from typing import Dict, Any
|
from typing import Optional, Dict, Any
|
||||||
|
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.common.message_repository import count_messages
|
from src.common.message_repository import count_messages
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.chat.message_receive.message import MessageRecv, BaseMessageInfo
|
||||||
|
from src.chat.message_receive.chat_stream import ChatStream
|
||||||
|
from src.chat.message_receive.message import UserInfo
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@@ -82,7 +85,6 @@ class CycleDetail:
|
|||||||
self.loop_action_info = loop_info["loop_action_info"]
|
self.loop_action_info = loop_info["loop_action_info"]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_recent_message_stats(minutes: int = 30, chat_id: str = None) -> dict:
|
def get_recent_message_stats(minutes: int = 30, chat_id: str = None) -> dict:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ import heapq
|
|||||||
import math
|
import math
|
||||||
import json
|
import json
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
logger = get_logger("normal_chat")
|
logger = get_logger("normal_chat")
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
import traceback
|
import traceback
|
||||||
from src.chat.heart_flow.sub_heartflow import SubHeartflow
|
from typing import Any, Optional, Dict
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from typing import Any, Optional
|
from src.chat.heart_flow.sub_heartflow import SubHeartflow
|
||||||
from typing import Dict
|
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
|
|
||||||
logger = get_logger("heartflow")
|
logger = get_logger("heartflow")
|
||||||
@@ -17,14 +17,11 @@ class Heartflow:
|
|||||||
async def get_or_create_subheartflow(self, subheartflow_id: Any) -> Optional["SubHeartflow"]:
|
async def get_or_create_subheartflow(self, subheartflow_id: Any) -> Optional["SubHeartflow"]:
|
||||||
"""获取或创建一个新的SubHeartflow实例"""
|
"""获取或创建一个新的SubHeartflow实例"""
|
||||||
if subheartflow_id in self.subheartflows:
|
if subheartflow_id in self.subheartflows:
|
||||||
subflow = self.subheartflows.get(subheartflow_id)
|
if subflow := self.subheartflows.get(subheartflow_id):
|
||||||
if subflow:
|
|
||||||
return subflow
|
return subflow
|
||||||
|
|
||||||
try:
|
try:
|
||||||
new_subflow = SubHeartflow(
|
new_subflow = SubHeartflow(subheartflow_id)
|
||||||
subheartflow_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
await new_subflow.initialize()
|
await new_subflow.initialize()
|
||||||
|
|
||||||
|
|||||||
@@ -1,21 +1,23 @@
|
|||||||
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
|
||||||
from src.config.config import global_config
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from src.chat.message_receive.message import MessageRecv
|
|
||||||
from src.chat.message_receive.storage import MessageStorage
|
|
||||||
from src.chat.heart_flow.heartflow import heartflow
|
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
|
||||||
from src.chat.utils.utils import is_mentioned_bot_in_message
|
|
||||||
from src.chat.utils.timer_calculator import Timer
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
import re
|
import re
|
||||||
import math
|
import math
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Tuple
|
|
||||||
|
|
||||||
|
from typing import Tuple, TYPE_CHECKING
|
||||||
|
|
||||||
|
from src.config.config import global_config
|
||||||
|
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
||||||
|
from src.chat.message_receive.message import MessageRecv
|
||||||
|
from src.chat.message_receive.storage import MessageStorage
|
||||||
|
from src.chat.heart_flow.heartflow import heartflow
|
||||||
|
from src.chat.utils.utils import is_mentioned_bot_in_message
|
||||||
|
from src.chat.utils.timer_calculator import Timer
|
||||||
|
from src.common.logger import get_logger
|
||||||
from src.person_info.relationship_manager import get_relationship_manager
|
from src.person_info.relationship_manager import get_relationship_manager
|
||||||
from src.mood.mood_manager import mood_manager
|
from src.mood.mood_manager import mood_manager
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from src.chat.heart_flow.sub_heartflow import SubHeartflow
|
||||||
|
|
||||||
logger = get_logger("chat")
|
logger = get_logger("chat")
|
||||||
|
|
||||||
@@ -27,16 +29,16 @@ async def _process_relationship(message: MessageRecv) -> None:
|
|||||||
message: 消息对象,包含用户信息
|
message: 消息对象,包含用户信息
|
||||||
"""
|
"""
|
||||||
platform = message.message_info.platform
|
platform = message.message_info.platform
|
||||||
user_id = message.message_info.user_info.user_id
|
user_id = message.message_info.user_info.user_id # type: ignore
|
||||||
nickname = message.message_info.user_info.user_nickname
|
nickname = message.message_info.user_info.user_nickname # type: ignore
|
||||||
cardname = message.message_info.user_info.user_cardname or nickname
|
cardname = message.message_info.user_info.user_cardname or nickname # type: ignore
|
||||||
|
|
||||||
relationship_manager = get_relationship_manager()
|
relationship_manager = get_relationship_manager()
|
||||||
is_known = await relationship_manager.is_known_some_one(platform, user_id)
|
is_known = await relationship_manager.is_known_some_one(platform, user_id)
|
||||||
|
|
||||||
if not is_known:
|
if not is_known:
|
||||||
logger.info(f"首次认识用户: {nickname}")
|
logger.info(f"首次认识用户: {nickname}")
|
||||||
await relationship_manager.first_knowing_some_one(platform, user_id, nickname, cardname)
|
await relationship_manager.first_knowing_some_one(platform, user_id, nickname, cardname) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]:
|
async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]:
|
||||||
@@ -96,31 +98,24 @@ class HeartFCMessageReceiver:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 1. 消息解析与初始化
|
# 1. 消息解析与初始化
|
||||||
groupinfo = message.message_info.group_info
|
|
||||||
userinfo = message.message_info.user_info
|
userinfo = message.message_info.user_info
|
||||||
messageinfo = message.message_info
|
chat = message.chat_stream
|
||||||
|
|
||||||
chat = await get_chat_manager().get_or_create_stream(
|
|
||||||
platform=messageinfo.platform,
|
|
||||||
user_info=userinfo,
|
|
||||||
group_info=groupinfo,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
# 2. 兴趣度计算与更新
|
||||||
interested_rate, is_mentioned = await _calculate_interest(message)
|
interested_rate, is_mentioned = await _calculate_interest(message)
|
||||||
message.interest_value = interested_rate
|
message.interest_value = interested_rate
|
||||||
message.is_mentioned = is_mentioned
|
message.is_mentioned = is_mentioned
|
||||||
|
|
||||||
await self.storage.store_message(message, chat)
|
await self.storage.store_message(message, chat)
|
||||||
|
|
||||||
subheartflow = await heartflow.get_or_create_subheartflow(chat.stream_id)
|
subheartflow: SubHeartflow = await heartflow.get_or_create_subheartflow(chat.stream_id) # type: ignore
|
||||||
message.update_chat_stream(chat)
|
|
||||||
|
|
||||||
# subheartflow.add_message_to_normal_chat_cache(message, interested_rate, is_mentioned)
|
# subheartflow.add_message_to_normal_chat_cache(message, interested_rate, is_mentioned)
|
||||||
|
|
||||||
chat_mood = mood_manager.get_mood_by_chat_id(subheartflow.chat_id)
|
chat_mood = mood_manager.get_mood_by_chat_id(subheartflow.chat_id) # type: ignore
|
||||||
asyncio.create_task(chat_mood.update_mood_by_message(message, interested_rate))
|
asyncio.create_task(chat_mood.update_mood_by_message(message, interested_rate))
|
||||||
|
|
||||||
# 7. 日志记录
|
# 3. 日志记录
|
||||||
mes_name = chat.group_info.group_name if chat.group_info else "私聊"
|
mes_name = chat.group_info.group_name if chat.group_info else "私聊"
|
||||||
# current_time = time.strftime("%H:%M:%S", time.localtime(message.message_info.time))
|
# current_time = time.strftime("%H:%M:%S", time.localtime(message.message_info.time))
|
||||||
current_talk_frequency = global_config.chat.get_current_talk_frequency(chat.stream_id)
|
current_talk_frequency = global_config.chat.get_current_talk_frequency(chat.stream_id)
|
||||||
@@ -129,11 +124,11 @@ class HeartFCMessageReceiver:
|
|||||||
picid_pattern = r"\[picid:([^\]]+)\]"
|
picid_pattern = r"\[picid:([^\]]+)\]"
|
||||||
processed_plain_text = re.sub(picid_pattern, "[图片]", message.processed_plain_text)
|
processed_plain_text = re.sub(picid_pattern, "[图片]", message.processed_plain_text)
|
||||||
|
|
||||||
logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}")
|
logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}") # type: ignore
|
||||||
|
|
||||||
logger.debug(f"[{mes_name}][当前时段回复频率: {current_talk_frequency}]")
|
logger.debug(f"[{mes_name}][当前时段回复频率: {current_talk_frequency}]")
|
||||||
|
|
||||||
# 8. 关系处理
|
# 4. 关系处理
|
||||||
if global_config.relationship.enable_relationship:
|
if global_config.relationship.enable_relationship:
|
||||||
await _process_relationship(message)
|
await _process_relationship(message)
|
||||||
|
|
||||||
|
|||||||
@@ -1,13 +1,16 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
from typing import Optional
|
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
|
from typing import Optional, List, Dict, Tuple
|
||||||
|
from rich.traceback import install
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
from src.config.config import global_config
|
||||||
|
from src.chat.message_receive.message import MessageRecv
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
from src.chat.focus_chat.heartFC_chat import HeartFChatting
|
from src.chat.focus_chat.heartFC_chat import HeartFChatting
|
||||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||||
from src.config.config import global_config
|
|
||||||
from rich.traceback import install
|
|
||||||
|
|
||||||
logger = get_logger("sub_heartflow")
|
logger = get_logger("sub_heartflow")
|
||||||
|
|
||||||
@@ -28,7 +31,6 @@ class SubHeartflow:
|
|||||||
self.subheartflow_id = subheartflow_id
|
self.subheartflow_id = subheartflow_id
|
||||||
self.chat_id = subheartflow_id
|
self.chat_id = subheartflow_id
|
||||||
|
|
||||||
|
|
||||||
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_id)
|
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_id)
|
||||||
self.log_prefix = get_chat_manager().get_stream_name(self.subheartflow_id) or self.subheartflow_id
|
self.log_prefix = get_chat_manager().get_stream_name(self.subheartflow_id) or self.subheartflow_id
|
||||||
|
|
||||||
@@ -45,9 +47,6 @@ class SubHeartflow:
|
|||||||
"""异步初始化方法,创建兴趣流并确定聊天类型"""
|
"""异步初始化方法,创建兴趣流并确定聊天类型"""
|
||||||
await self.heart_fc_instance.start()
|
await self.heart_fc_instance.start()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def _stop_heart_fc_chat(self):
|
async def _stop_heart_fc_chat(self):
|
||||||
"""停止并清理 HeartFChatting 实例"""
|
"""停止并清理 HeartFChatting 实例"""
|
||||||
if self.heart_fc_instance.running:
|
if self.heart_fc_instance.running:
|
||||||
@@ -86,7 +85,6 @@ class SubHeartflow:
|
|||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def is_in_focus_cooldown(self) -> bool:
|
def is_in_focus_cooldown(self) -> bool:
|
||||||
"""检查是否在focus模式的冷却期内
|
"""检查是否在focus模式的冷却期内
|
||||||
|
|
||||||
@@ -133,6 +131,4 @@ class SubHeartflow:
|
|||||||
if elapsed_since_exit >= cooldown_duration:
|
if elapsed_since_exit >= cooldown_duration:
|
||||||
return 1.0 # 冷却完成
|
return 1.0 # 冷却完成
|
||||||
|
|
||||||
# 计算进度:0表示刚开始冷却,1表示冷却完成
|
return elapsed_since_exit / cooldown_duration
|
||||||
progress = elapsed_since_exit / cooldown_duration
|
|
||||||
return progress
|
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ def calculate_information_content(text):
|
|||||||
return entropy
|
return entropy
|
||||||
|
|
||||||
|
|
||||||
def cosine_similarity(v1, v2):
|
def cosine_similarity(v1, v2): # sourcery skip: assign-if-exp, reintroduce-else
|
||||||
"""计算余弦相似度"""
|
"""计算余弦相似度"""
|
||||||
dot_product = np.dot(v1, v2)
|
dot_product = np.dot(v1, v2)
|
||||||
norm1 = np.linalg.norm(v1)
|
norm1 = np.linalg.norm(v1)
|
||||||
@@ -89,13 +89,12 @@ class MemoryGraph:
|
|||||||
if not isinstance(self.G.nodes[concept]["memory_items"], list):
|
if not isinstance(self.G.nodes[concept]["memory_items"], list):
|
||||||
self.G.nodes[concept]["memory_items"] = [self.G.nodes[concept]["memory_items"]]
|
self.G.nodes[concept]["memory_items"] = [self.G.nodes[concept]["memory_items"]]
|
||||||
self.G.nodes[concept]["memory_items"].append(memory)
|
self.G.nodes[concept]["memory_items"].append(memory)
|
||||||
# 更新最后修改时间
|
|
||||||
self.G.nodes[concept]["last_modified"] = current_time
|
|
||||||
else:
|
else:
|
||||||
self.G.nodes[concept]["memory_items"] = [memory]
|
self.G.nodes[concept]["memory_items"] = [memory]
|
||||||
# 如果节点存在但没有memory_items,说明是第一次添加memory,设置created_time
|
# 如果节点存在但没有memory_items,说明是第一次添加memory,设置created_time
|
||||||
if "created_time" not in self.G.nodes[concept]:
|
if "created_time" not in self.G.nodes[concept]:
|
||||||
self.G.nodes[concept]["created_time"] = current_time
|
self.G.nodes[concept]["created_time"] = current_time
|
||||||
|
# 更新最后修改时间
|
||||||
self.G.nodes[concept]["last_modified"] = current_time
|
self.G.nodes[concept]["last_modified"] = current_time
|
||||||
else:
|
else:
|
||||||
# 如果是新节点,创建新的记忆列表
|
# 如果是新节点,创建新的记忆列表
|
||||||
@@ -108,11 +107,7 @@ class MemoryGraph:
|
|||||||
|
|
||||||
def get_dot(self, concept):
|
def get_dot(self, concept):
|
||||||
# 检查节点是否存在于图中
|
# 检查节点是否存在于图中
|
||||||
if concept in self.G:
|
return (concept, self.G.nodes[concept]) if concept in self.G else None
|
||||||
# 从图中获取节点数据
|
|
||||||
node_data = self.G.nodes[concept]
|
|
||||||
return concept, node_data
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_related_item(self, topic, depth=1):
|
def get_related_item(self, topic, depth=1):
|
||||||
if topic not in self.G:
|
if topic not in self.G:
|
||||||
@@ -139,8 +134,7 @@ class MemoryGraph:
|
|||||||
if depth >= 2:
|
if depth >= 2:
|
||||||
# 获取相邻节点的记忆项
|
# 获取相邻节点的记忆项
|
||||||
for neighbor in neighbors:
|
for neighbor in neighbors:
|
||||||
node_data = self.get_dot(neighbor)
|
if node_data := self.get_dot(neighbor):
|
||||||
if node_data:
|
|
||||||
concept, data = node_data
|
concept, data = node_data
|
||||||
if "memory_items" in data:
|
if "memory_items" in data:
|
||||||
memory_items = data["memory_items"]
|
memory_items = data["memory_items"]
|
||||||
@@ -194,9 +188,9 @@ class MemoryGraph:
|
|||||||
class Hippocampus:
|
class Hippocampus:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.memory_graph = MemoryGraph()
|
self.memory_graph = MemoryGraph()
|
||||||
self.model_summary = None
|
self.model_summary: LLMRequest = None # type: ignore
|
||||||
self.entorhinal_cortex = None
|
self.entorhinal_cortex: EntorhinalCortex = None # type: ignore
|
||||||
self.parahippocampal_gyrus = None
|
self.parahippocampal_gyrus: ParahippocampalGyrus = None # type: ignore
|
||||||
|
|
||||||
def initialize(self):
|
def initialize(self):
|
||||||
# 初始化子组件
|
# 初始化子组件
|
||||||
@@ -218,7 +212,7 @@ class Hippocampus:
|
|||||||
memory_items = [memory_items] if memory_items else []
|
memory_items = [memory_items] if memory_items else []
|
||||||
|
|
||||||
# 使用集合来去重,避免排序
|
# 使用集合来去重,避免排序
|
||||||
unique_items = set(str(item) for item in memory_items)
|
unique_items = {str(item) for item in memory_items}
|
||||||
# 使用frozenset来保证顺序一致性
|
# 使用frozenset来保证顺序一致性
|
||||||
content = f"{concept}:{frozenset(unique_items)}"
|
content = f"{concept}:{frozenset(unique_items)}"
|
||||||
return hash(content)
|
return hash(content)
|
||||||
@@ -231,6 +225,7 @@ class Hippocampus:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def find_topic_llm(text, topic_num):
|
def find_topic_llm(text, topic_num):
|
||||||
|
# sourcery skip: inline-immediately-returned-variable
|
||||||
prompt = (
|
prompt = (
|
||||||
f"这是一段文字:\n{text}\n\n请你从这段话中总结出最多{topic_num}个关键的概念,可以是名词,动词,或者特定人物,帮我列出来,"
|
f"这是一段文字:\n{text}\n\n请你从这段话中总结出最多{topic_num}个关键的概念,可以是名词,动词,或者特定人物,帮我列出来,"
|
||||||
f"将主题用逗号隔开,并加上<>,例如<主题1>,<主题2>......尽可能精简。只需要列举最多{topic_num}个话题就好,不要有序号,不要告诉我其他内容。"
|
f"将主题用逗号隔开,并加上<>,例如<主题1>,<主题2>......尽可能精简。只需要列举最多{topic_num}个话题就好,不要有序号,不要告诉我其他内容。"
|
||||||
@@ -240,6 +235,7 @@ class Hippocampus:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def topic_what(text, topic):
|
def topic_what(text, topic):
|
||||||
|
# sourcery skip: inline-immediately-returned-variable
|
||||||
# 不再需要 time_info 参数
|
# 不再需要 time_info 参数
|
||||||
prompt = (
|
prompt = (
|
||||||
f'这是一段文字:\n{text}\n\n我想让你基于这段文字来概括"{topic}"这个概念,帮我总结成一句自然的话,'
|
f'这是一段文字:\n{text}\n\n我想让你基于这段文字来概括"{topic}"这个概念,帮我总结成一句自然的话,'
|
||||||
@@ -480,9 +476,7 @@ class Hippocampus:
|
|||||||
top_memories = memory_similarities[:max_memory_length]
|
top_memories = memory_similarities[:max_memory_length]
|
||||||
|
|
||||||
# 添加到结果中
|
# 添加到结果中
|
||||||
for memory, similarity in top_memories:
|
all_memories.extend((node, [memory], similarity) for memory, similarity in top_memories)
|
||||||
all_memories.append((node, [memory], similarity))
|
|
||||||
# logger.info(f"选中记忆: {memory} (相似度: {similarity:.2f})")
|
|
||||||
else:
|
else:
|
||||||
logger.info("节点没有记忆")
|
logger.info("节点没有记忆")
|
||||||
|
|
||||||
@@ -646,9 +640,7 @@ class Hippocampus:
|
|||||||
top_memories = memory_similarities[:max_memory_length]
|
top_memories = memory_similarities[:max_memory_length]
|
||||||
|
|
||||||
# 添加到结果中
|
# 添加到结果中
|
||||||
for memory, similarity in top_memories:
|
all_memories.extend((node, [memory], similarity) for memory, similarity in top_memories)
|
||||||
all_memories.append((node, [memory], similarity))
|
|
||||||
# logger.info(f"选中记忆: {memory} (相似度: {similarity:.2f})")
|
|
||||||
else:
|
else:
|
||||||
logger.info("节点没有记忆")
|
logger.info("节点没有记忆")
|
||||||
|
|
||||||
@@ -823,11 +815,11 @@ class EntorhinalCortex:
|
|||||||
logger.debug(f"回忆往事: {readable_timestamp}")
|
logger.debug(f"回忆往事: {readable_timestamp}")
|
||||||
chat_samples = []
|
chat_samples = []
|
||||||
for timestamp in timestamps:
|
for timestamp in timestamps:
|
||||||
# 调用修改后的 random_get_msg_snippet
|
if messages := self.random_get_msg_snippet(
|
||||||
messages = self.random_get_msg_snippet(
|
timestamp,
|
||||||
timestamp, global_config.memory.memory_build_sample_length, max_memorized_time_per_msg
|
global_config.memory.memory_build_sample_length,
|
||||||
)
|
max_memorized_time_per_msg,
|
||||||
if messages:
|
):
|
||||||
time_diff = (datetime.datetime.now().timestamp() - timestamp) / 3600
|
time_diff = (datetime.datetime.now().timestamp() - timestamp) / 3600
|
||||||
logger.info(f"成功抽取 {time_diff:.1f} 小时前的消息样本,共{len(messages)}条")
|
logger.info(f"成功抽取 {time_diff:.1f} 小时前的消息样本,共{len(messages)}条")
|
||||||
chat_samples.append(messages)
|
chat_samples.append(messages)
|
||||||
@@ -838,6 +830,7 @@ class EntorhinalCortex:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def random_get_msg_snippet(target_timestamp: float, chat_size: int, max_memorized_time_per_msg: int) -> list | None:
|
def random_get_msg_snippet(target_timestamp: float, chat_size: int, max_memorized_time_per_msg: int) -> list | None:
|
||||||
|
# sourcery skip: invert-any-all, use-any, use-named-expression, use-next
|
||||||
"""从数据库中随机获取指定时间戳附近的消息片段 (使用 chat_message_builder)"""
|
"""从数据库中随机获取指定时间戳附近的消息片段 (使用 chat_message_builder)"""
|
||||||
try_count = 0
|
try_count = 0
|
||||||
time_window_seconds = random.randint(300, 1800) # 随机时间窗口,5到30分钟
|
time_window_seconds = random.randint(300, 1800) # 随机时间窗口,5到30分钟
|
||||||
@@ -847,22 +840,21 @@ class EntorhinalCortex:
|
|||||||
timestamp_start = target_timestamp
|
timestamp_start = target_timestamp
|
||||||
timestamp_end = target_timestamp + time_window_seconds
|
timestamp_end = target_timestamp + time_window_seconds
|
||||||
|
|
||||||
chosen_message = get_raw_msg_by_timestamp(
|
if chosen_message := get_raw_msg_by_timestamp(
|
||||||
timestamp_start=timestamp_start, timestamp_end=timestamp_end, limit=1, limit_mode="earliest"
|
timestamp_start=timestamp_start,
|
||||||
)
|
timestamp_end=timestamp_end,
|
||||||
|
limit=1,
|
||||||
|
limit_mode="earliest",
|
||||||
|
):
|
||||||
|
chat_id: str = chosen_message[0].get("chat_id") # type: ignore
|
||||||
|
|
||||||
if chosen_message:
|
if messages := get_raw_msg_by_timestamp_with_chat(
|
||||||
chat_id = chosen_message[0].get("chat_id")
|
|
||||||
|
|
||||||
messages = get_raw_msg_by_timestamp_with_chat(
|
|
||||||
timestamp_start=timestamp_start,
|
timestamp_start=timestamp_start,
|
||||||
timestamp_end=timestamp_end,
|
timestamp_end=timestamp_end,
|
||||||
limit=chat_size,
|
limit=chat_size,
|
||||||
limit_mode="earliest",
|
limit_mode="earliest",
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
)
|
):
|
||||||
|
|
||||||
if messages:
|
|
||||||
# 检查获取到的所有消息是否都未达到最大记忆次数
|
# 检查获取到的所有消息是否都未达到最大记忆次数
|
||||||
all_valid = True
|
all_valid = True
|
||||||
for message in messages:
|
for message in messages:
|
||||||
@@ -975,7 +967,7 @@ class EntorhinalCortex:
|
|||||||
).execute()
|
).execute()
|
||||||
|
|
||||||
if nodes_to_delete:
|
if nodes_to_delete:
|
||||||
GraphNodes.delete().where(GraphNodes.concept.in_(nodes_to_delete)).execute()
|
GraphNodes.delete().where(GraphNodes.concept.in_(nodes_to_delete)).execute() # type: ignore
|
||||||
|
|
||||||
# 处理边的信息
|
# 处理边的信息
|
||||||
db_edges = list(GraphEdges.select())
|
db_edges = list(GraphEdges.select())
|
||||||
@@ -1114,7 +1106,7 @@ class EntorhinalCortex:
|
|||||||
node_start = time.time()
|
node_start = time.time()
|
||||||
if nodes_data:
|
if nodes_data:
|
||||||
batch_size = 500 # 增加批量大小
|
batch_size = 500 # 增加批量大小
|
||||||
with GraphNodes._meta.database.atomic():
|
with GraphNodes._meta.database.atomic(): # type: ignore
|
||||||
for i in range(0, len(nodes_data), batch_size):
|
for i in range(0, len(nodes_data), batch_size):
|
||||||
batch = nodes_data[i : i + batch_size]
|
batch = nodes_data[i : i + batch_size]
|
||||||
GraphNodes.insert_many(batch).execute()
|
GraphNodes.insert_many(batch).execute()
|
||||||
@@ -1125,7 +1117,7 @@ class EntorhinalCortex:
|
|||||||
edge_start = time.time()
|
edge_start = time.time()
|
||||||
if edges_data:
|
if edges_data:
|
||||||
batch_size = 500 # 增加批量大小
|
batch_size = 500 # 增加批量大小
|
||||||
with GraphEdges._meta.database.atomic():
|
with GraphEdges._meta.database.atomic(): # type: ignore
|
||||||
for i in range(0, len(edges_data), batch_size):
|
for i in range(0, len(edges_data), batch_size):
|
||||||
batch = edges_data[i : i + batch_size]
|
batch = edges_data[i : i + batch_size]
|
||||||
GraphEdges.insert_many(batch).execute()
|
GraphEdges.insert_many(batch).execute()
|
||||||
@@ -1489,9 +1481,7 @@ class ParahippocampalGyrus:
|
|||||||
# --- 如果节点不为空,则执行原来的不活跃检查和随机移除逻辑 ---
|
# --- 如果节点不为空,则执行原来的不活跃检查和随机移除逻辑 ---
|
||||||
last_modified = node_data.get("last_modified", current_time)
|
last_modified = node_data.get("last_modified", current_time)
|
||||||
# 条件1:检查是否长时间未修改 (超过24小时)
|
# 条件1:检查是否长时间未修改 (超过24小时)
|
||||||
if current_time - last_modified > 3600 * 24:
|
if current_time - last_modified > 3600 * 24 and memory_items:
|
||||||
# 条件2:再次确认节点包含记忆项(理论上已确认,但作为保险)
|
|
||||||
if memory_items:
|
|
||||||
current_count = len(memory_items)
|
current_count = len(memory_items)
|
||||||
# 如果列表非空,才进行随机选择
|
# 如果列表非空,才进行随机选择
|
||||||
if current_count > 0:
|
if current_count > 0:
|
||||||
@@ -1669,7 +1659,7 @@ class ParahippocampalGyrus:
|
|||||||
|
|
||||||
class HippocampusManager:
|
class HippocampusManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._hippocampus = None
|
self._hippocampus: Hippocampus = None # type: ignore
|
||||||
self._initialized = False
|
self._initialized = False
|
||||||
|
|
||||||
def initialize(self):
|
def initialize(self):
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from json_repair import repair_json
|
|||||||
logger = get_logger("memory_activator")
|
logger = get_logger("memory_activator")
|
||||||
|
|
||||||
|
|
||||||
def get_keywords_from_json(json_str):
|
def get_keywords_from_json(json_str) -> List:
|
||||||
"""
|
"""
|
||||||
从JSON字符串中提取关键词列表
|
从JSON字符串中提取关键词列表
|
||||||
|
|
||||||
@@ -28,15 +28,8 @@ def get_keywords_from_json(json_str):
|
|||||||
fixed_json = repair_json(json_str)
|
fixed_json = repair_json(json_str)
|
||||||
|
|
||||||
# 如果repair_json返回的是字符串,需要解析为Python对象
|
# 如果repair_json返回的是字符串,需要解析为Python对象
|
||||||
if isinstance(fixed_json, str):
|
result = json.loads(fixed_json) if isinstance(fixed_json, str) else fixed_json
|
||||||
result = json.loads(fixed_json)
|
return result.get("keywords", [])
|
||||||
else:
|
|
||||||
# 如果repair_json直接返回了字典对象,直接使用
|
|
||||||
result = fixed_json
|
|
||||||
|
|
||||||
# 提取关键词
|
|
||||||
keywords = result.get("keywords", [])
|
|
||||||
return keywords
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"解析关键词JSON失败: {e}")
|
logger.error(f"解析关键词JSON失败: {e}")
|
||||||
return []
|
return []
|
||||||
|
|||||||
@@ -1,52 +1,10 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from scipy import stats
|
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
|
|
||||||
|
|
||||||
class DistributionVisualizer:
|
|
||||||
def __init__(self, mean=0, std=1, skewness=0, sample_size=10):
|
|
||||||
"""
|
|
||||||
初始化分布可视化器
|
|
||||||
|
|
||||||
参数:
|
|
||||||
mean (float): 期望均值
|
|
||||||
std (float): 标准差
|
|
||||||
skewness (float): 偏度
|
|
||||||
sample_size (int): 样本大小
|
|
||||||
"""
|
|
||||||
self.mean = mean
|
|
||||||
self.std = std
|
|
||||||
self.skewness = skewness
|
|
||||||
self.sample_size = sample_size
|
|
||||||
self.samples = None
|
|
||||||
|
|
||||||
def generate_samples(self):
|
|
||||||
"""生成具有指定参数的样本"""
|
|
||||||
if self.skewness == 0:
|
|
||||||
# 对于无偏度的情况,直接使用正态分布
|
|
||||||
self.samples = np.random.normal(loc=self.mean, scale=self.std, size=self.sample_size)
|
|
||||||
else:
|
|
||||||
# 使用 scipy.stats 生成具有偏度的分布
|
|
||||||
self.samples = stats.skewnorm.rvs(a=self.skewness, loc=self.mean, scale=self.std, size=self.sample_size)
|
|
||||||
|
|
||||||
def get_weighted_samples(self):
|
|
||||||
"""获取加权后的样本数列"""
|
|
||||||
if self.samples is None:
|
|
||||||
self.generate_samples()
|
|
||||||
# 将样本值乘以样本大小
|
|
||||||
return self.samples * self.sample_size
|
|
||||||
|
|
||||||
def get_statistics(self):
|
|
||||||
"""获取分布的统计信息"""
|
|
||||||
if self.samples is None:
|
|
||||||
self.generate_samples()
|
|
||||||
|
|
||||||
return {"均值": np.mean(self.samples), "标准差": np.std(self.samples), "实际偏度": stats.skew(self.samples)}
|
|
||||||
|
|
||||||
|
|
||||||
class MemoryBuildScheduler:
|
class MemoryBuildScheduler:
|
||||||
def __init__(self, n_hours1, std_hours1, weight1, n_hours2, std_hours2, weight2, total_samples=50):
|
def __init__(self, n_hours1, std_hours1, weight1, n_hours2, std_hours2, weight2, total_samples=50):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,22 +1,25 @@
|
|||||||
import traceback
|
import traceback
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
|
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
|
from maim_message import UserInfo
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
from src.config.config import global_config
|
||||||
from src.mood.mood_manager import mood_manager # 导入情绪管理器
|
from src.mood.mood_manager import mood_manager # 导入情绪管理器
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream
|
||||||
from src.chat.message_receive.message import MessageRecv
|
from src.chat.message_receive.message import MessageRecv
|
||||||
from src.experimental.only_message_process import MessageProcessor
|
|
||||||
from src.chat.message_receive.storage import MessageStorage
|
from src.chat.message_receive.storage import MessageStorage
|
||||||
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
|
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from src.config.config import global_config
|
from src.experimental.only_message_process import MessageProcessor
|
||||||
|
from src.experimental.PFC.pfc_manager import PFCManager
|
||||||
from src.plugin_system.core.component_registry import component_registry # 导入新插件系统
|
from src.plugin_system.core.component_registry import component_registry # 导入新插件系统
|
||||||
from src.plugin_system.base.base_command import BaseCommand
|
from src.plugin_system.base.base_command import BaseCommand
|
||||||
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
|
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
|
||||||
from maim_message import UserInfo
|
|
||||||
from src.chat.message_receive.chat_stream import ChatStream
|
|
||||||
import re
|
|
||||||
# 定义日志配置
|
# 定义日志配置
|
||||||
|
|
||||||
# 获取项目根目录(假设本文件在src/chat/message_receive/下,根目录为上上上级目录)
|
# 获取项目根目录(假设本文件在src/chat/message_receive/下,根目录为上上上级目录)
|
||||||
@@ -182,8 +185,8 @@ class ChatBot:
|
|||||||
get_chat_manager().register_message(message)
|
get_chat_manager().register_message(message)
|
||||||
|
|
||||||
chat = await get_chat_manager().get_or_create_stream(
|
chat = await get_chat_manager().get_or_create_stream(
|
||||||
platform=message.message_info.platform,
|
platform=message.message_info.platform, # type: ignore
|
||||||
user_info=user_info,
|
user_info=user_info, # type: ignore
|
||||||
group_info=group_info,
|
group_info=group_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -193,8 +196,10 @@ class ChatBot:
|
|||||||
await message.process()
|
await message.process()
|
||||||
|
|
||||||
# 过滤检查
|
# 过滤检查
|
||||||
if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex(
|
if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore
|
||||||
message.raw_message, chat, user_info
|
message.raw_message, # type: ignore
|
||||||
|
chat,
|
||||||
|
user_info, # type: ignore
|
||||||
):
|
):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
@@ -3,18 +3,17 @@ import hashlib
|
|||||||
import time
|
import time
|
||||||
import copy
|
import copy
|
||||||
from typing import Dict, Optional, TYPE_CHECKING
|
from typing import Dict, Optional, TYPE_CHECKING
|
||||||
|
from rich.traceback import install
|
||||||
|
|
||||||
from ...common.database.database import db
|
|
||||||
from ...common.database.database_model import ChatStreams # 新增导入
|
|
||||||
from maim_message import GroupInfo, UserInfo
|
from maim_message import GroupInfo, UserInfo
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.common.database.database import db
|
||||||
|
from src.common.database.database_model import ChatStreams # 新增导入
|
||||||
|
|
||||||
# 避免循环导入,使用TYPE_CHECKING进行类型提示
|
# 避免循环导入,使用TYPE_CHECKING进行类型提示
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .message import MessageRecv
|
from .message import MessageRecv
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from rich.traceback import install
|
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
|
|
||||||
@@ -28,7 +27,7 @@ class ChatMessageContext:
|
|||||||
def __init__(self, message: "MessageRecv"):
|
def __init__(self, message: "MessageRecv"):
|
||||||
self.message = message
|
self.message = message
|
||||||
|
|
||||||
def get_template_name(self) -> str:
|
def get_template_name(self) -> Optional[str]:
|
||||||
"""获取模板名称"""
|
"""获取模板名称"""
|
||||||
if self.message.message_info.template_info and not self.message.message_info.template_info.template_default:
|
if self.message.message_info.template_info and not self.message.message_info.template_info.template_default:
|
||||||
return self.message.message_info.template_info.template_name
|
return self.message.message_info.template_info.template_name
|
||||||
@@ -39,11 +38,12 @@ class ChatMessageContext:
|
|||||||
return self.message
|
return self.message
|
||||||
|
|
||||||
def check_types(self, types: list) -> bool:
|
def check_types(self, types: list) -> bool:
|
||||||
|
# sourcery skip: invert-any-all, use-any, use-next
|
||||||
"""检查消息类型"""
|
"""检查消息类型"""
|
||||||
if not self.message.message_info.format_info.accept_format:
|
if not self.message.message_info.format_info.accept_format: # type: ignore
|
||||||
return False
|
return False
|
||||||
for t in types:
|
for t in types:
|
||||||
if t not in self.message.message_info.format_info.accept_format:
|
if t not in self.message.message_info.format_info.accept_format: # type: ignore
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -67,7 +67,7 @@ class ChatStream:
|
|||||||
platform: str,
|
platform: str,
|
||||||
user_info: UserInfo,
|
user_info: UserInfo,
|
||||||
group_info: Optional[GroupInfo] = None,
|
group_info: Optional[GroupInfo] = None,
|
||||||
data: dict = None,
|
data: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
self.stream_id = stream_id
|
self.stream_id = stream_id
|
||||||
self.platform = platform
|
self.platform = platform
|
||||||
@@ -76,7 +76,7 @@ class ChatStream:
|
|||||||
self.create_time = data.get("create_time", time.time()) if data else time.time()
|
self.create_time = data.get("create_time", time.time()) if data else time.time()
|
||||||
self.last_active_time = data.get("last_active_time", self.create_time) if data else self.create_time
|
self.last_active_time = data.get("last_active_time", self.create_time) if data else self.create_time
|
||||||
self.saved = False
|
self.saved = False
|
||||||
self.context: ChatMessageContext = None # 用于存储该聊天的上下文信息
|
self.context: ChatMessageContext = None # type: ignore # 用于存储该聊天的上下文信息
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
def to_dict(self) -> dict:
|
||||||
"""转换为字典格式"""
|
"""转换为字典格式"""
|
||||||
@@ -98,7 +98,7 @@ class ChatStream:
|
|||||||
return cls(
|
return cls(
|
||||||
stream_id=data["stream_id"],
|
stream_id=data["stream_id"],
|
||||||
platform=data["platform"],
|
platform=data["platform"],
|
||||||
user_info=user_info,
|
user_info=user_info, # type: ignore
|
||||||
group_info=group_info,
|
group_info=group_info,
|
||||||
data=data,
|
data=data,
|
||||||
)
|
)
|
||||||
@@ -162,8 +162,8 @@ class ChatManager:
|
|||||||
def register_message(self, message: "MessageRecv"):
|
def register_message(self, message: "MessageRecv"):
|
||||||
"""注册消息到聊天流"""
|
"""注册消息到聊天流"""
|
||||||
stream_id = self._generate_stream_id(
|
stream_id = self._generate_stream_id(
|
||||||
message.message_info.platform,
|
message.message_info.platform, # type: ignore
|
||||||
message.message_info.user_info,
|
message.message_info.user_info, # type: ignore
|
||||||
message.message_info.group_info,
|
message.message_info.group_info,
|
||||||
)
|
)
|
||||||
self.last_messages[stream_id] = message
|
self.last_messages[stream_id] = message
|
||||||
@@ -184,10 +184,7 @@ class ChatManager:
|
|||||||
|
|
||||||
def get_stream_id(self, platform: str, id: str, is_group: bool = True) -> str:
|
def get_stream_id(self, platform: str, id: str, is_group: bool = True) -> str:
|
||||||
"""获取聊天流ID"""
|
"""获取聊天流ID"""
|
||||||
if is_group:
|
components = [platform, id] if is_group else [platform, id, "private"]
|
||||||
components = [platform, str(id)]
|
|
||||||
else:
|
|
||||||
components = [platform, str(id), "private"]
|
|
||||||
key = "_".join(components)
|
key = "_".join(components)
|
||||||
return hashlib.md5(key.encode()).hexdigest()
|
return hashlib.md5(key.encode()).hexdigest()
|
||||||
|
|
||||||
|
|||||||
@@ -1,17 +1,15 @@
|
|||||||
import time
|
import time
|
||||||
from abc import abstractmethod
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Optional, Any, TYPE_CHECKING
|
|
||||||
|
|
||||||
import urllib3
|
import urllib3
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from abc import abstractmethod
|
||||||
|
from dataclasses import dataclass
|
||||||
if TYPE_CHECKING:
|
|
||||||
from .chat_stream import ChatStream
|
|
||||||
from ..utils.utils_image import get_image_manager
|
|
||||||
from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
|
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
from typing import Optional, Any
|
||||||
|
from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.chat.utils.utils_image import get_image_manager
|
||||||
|
from .chat_stream import ChatStream
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
|
|
||||||
@@ -27,7 +25,7 @@ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Message(MessageBase):
|
class Message(MessageBase):
|
||||||
chat_stream: "ChatStream" = None
|
chat_stream: "ChatStream" = None # type: ignore
|
||||||
reply: Optional["Message"] = None
|
reply: Optional["Message"] = None
|
||||||
processed_plain_text: str = ""
|
processed_plain_text: str = ""
|
||||||
memorized_times: int = 0
|
memorized_times: int = 0
|
||||||
@@ -55,7 +53,7 @@ class Message(MessageBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 调用父类初始化
|
# 调用父类初始化
|
||||||
super().__init__(message_info=message_info, message_segment=message_segment, raw_message=None)
|
super().__init__(message_info=message_info, message_segment=message_segment, raw_message=None) # type: ignore
|
||||||
|
|
||||||
self.chat_stream = chat_stream
|
self.chat_stream = chat_stream
|
||||||
# 文本处理相关属性
|
# 文本处理相关属性
|
||||||
@@ -66,6 +64,7 @@ class Message(MessageBase):
|
|||||||
self.reply = reply
|
self.reply = reply
|
||||||
|
|
||||||
async def _process_message_segments(self, segment: Seg) -> str:
|
async def _process_message_segments(self, segment: Seg) -> str:
|
||||||
|
# sourcery skip: remove-unnecessary-else, swap-if-else-branches
|
||||||
"""递归处理消息段,转换为文字描述
|
"""递归处理消息段,转换为文字描述
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -78,13 +77,13 @@ class Message(MessageBase):
|
|||||||
# 处理消息段列表
|
# 处理消息段列表
|
||||||
segments_text = []
|
segments_text = []
|
||||||
for seg in segment.data:
|
for seg in segment.data:
|
||||||
processed = await self._process_message_segments(seg)
|
processed = await self._process_message_segments(seg) # type: ignore
|
||||||
if processed:
|
if processed:
|
||||||
segments_text.append(processed)
|
segments_text.append(processed)
|
||||||
return " ".join(segments_text)
|
return " ".join(segments_text)
|
||||||
else:
|
else:
|
||||||
# 处理单个消息段
|
# 处理单个消息段
|
||||||
return await self._process_single_segment(segment)
|
return await self._process_single_segment(segment) # type: ignore
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def _process_single_segment(self, segment):
|
async def _process_single_segment(self, segment):
|
||||||
@@ -113,7 +112,7 @@ class MessageRecv(Message):
|
|||||||
self.is_mentioned = None
|
self.is_mentioned = None
|
||||||
self.priority_mode = "interest"
|
self.priority_mode = "interest"
|
||||||
self.priority_info = None
|
self.priority_info = None
|
||||||
self.interest_value = None
|
self.interest_value: float = None # type: ignore
|
||||||
|
|
||||||
def update_chat_stream(self, chat_stream: "ChatStream"):
|
def update_chat_stream(self, chat_stream: "ChatStream"):
|
||||||
self.chat_stream = chat_stream
|
self.chat_stream = chat_stream
|
||||||
@@ -139,7 +138,7 @@ class MessageRecv(Message):
|
|||||||
if segment.type == "text":
|
if segment.type == "text":
|
||||||
self.is_picid = False
|
self.is_picid = False
|
||||||
self.is_emoji = False
|
self.is_emoji = False
|
||||||
return segment.data
|
return segment.data # type: ignore
|
||||||
elif segment.type == "image":
|
elif segment.type == "image":
|
||||||
# 如果是base64图片数据
|
# 如果是base64图片数据
|
||||||
if isinstance(segment.data, str):
|
if isinstance(segment.data, str):
|
||||||
@@ -161,7 +160,7 @@ class MessageRecv(Message):
|
|||||||
elif segment.type == "mention_bot":
|
elif segment.type == "mention_bot":
|
||||||
self.is_picid = False
|
self.is_picid = False
|
||||||
self.is_emoji = False
|
self.is_emoji = False
|
||||||
self.is_mentioned = float(segment.data)
|
self.is_mentioned = float(segment.data) # type: ignore
|
||||||
return ""
|
return ""
|
||||||
elif segment.type == "priority_info":
|
elif segment.type == "priority_info":
|
||||||
self.is_picid = False
|
self.is_picid = False
|
||||||
@@ -187,7 +186,7 @@ class MessageRecv(Message):
|
|||||||
"""生成详细文本,包含时间和用户信息"""
|
"""生成详细文本,包含时间和用户信息"""
|
||||||
timestamp = self.message_info.time
|
timestamp = self.message_info.time
|
||||||
user_info = self.message_info.user_info
|
user_info = self.message_info.user_info
|
||||||
name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>"
|
name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>" # type: ignore
|
||||||
return f"[{timestamp}] {name}: {self.processed_plain_text}\n"
|
return f"[{timestamp}] {name}: {self.processed_plain_text}\n"
|
||||||
|
|
||||||
|
|
||||||
@@ -235,7 +234,7 @@ class MessageProcessBase(Message):
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
if seg.type == "text":
|
if seg.type == "text":
|
||||||
return seg.data
|
return seg.data # type: ignore
|
||||||
elif seg.type == "image":
|
elif seg.type == "image":
|
||||||
# 如果是base64图片数据
|
# 如果是base64图片数据
|
||||||
if isinstance(seg.data, str):
|
if isinstance(seg.data, str):
|
||||||
@@ -251,7 +250,7 @@ class MessageProcessBase(Message):
|
|||||||
if self.reply and hasattr(self.reply, "processed_plain_text"):
|
if self.reply and hasattr(self.reply, "processed_plain_text"):
|
||||||
# print(f"self.reply.processed_plain_text: {self.reply.processed_plain_text}")
|
# print(f"self.reply.processed_plain_text: {self.reply.processed_plain_text}")
|
||||||
# print(f"reply: {self.reply}")
|
# print(f"reply: {self.reply}")
|
||||||
return f"[回复<{self.reply.message_info.user_info.user_nickname}:{self.reply.message_info.user_info.user_id}> 的消息:{self.reply.processed_plain_text}]"
|
return f"[回复<{self.reply.message_info.user_info.user_nickname}:{self.reply.message_info.user_info.user_id}> 的消息:{self.reply.processed_plain_text}]" # type: ignore
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
return f"[{seg.type}:{str(seg.data)}]"
|
return f"[{seg.type}:{str(seg.data)}]"
|
||||||
@@ -265,7 +264,7 @@ class MessageProcessBase(Message):
|
|||||||
timestamp = self.message_info.time
|
timestamp = self.message_info.time
|
||||||
user_info = self.message_info.user_info
|
user_info = self.message_info.user_info
|
||||||
|
|
||||||
name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>"
|
name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>" # type: ignore
|
||||||
return f"[{timestamp}],{name} 说:{self.processed_plain_text}\n"
|
return f"[{timestamp}],{name} 说:{self.processed_plain_text}\n"
|
||||||
|
|
||||||
|
|
||||||
@@ -314,7 +313,7 @@ class MessageSending(MessageProcessBase):
|
|||||||
is_emoji: bool = False,
|
is_emoji: bool = False,
|
||||||
thinking_start_time: float = 0,
|
thinking_start_time: float = 0,
|
||||||
apply_set_reply_logic: bool = False,
|
apply_set_reply_logic: bool = False,
|
||||||
reply_to: str = None,
|
reply_to: str = None, # type: ignore
|
||||||
):
|
):
|
||||||
# 调用父类初始化
|
# 调用父类初始化
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@@ -347,7 +346,7 @@ class MessageSending(MessageProcessBase):
|
|||||||
self.message_segment = Seg(
|
self.message_segment = Seg(
|
||||||
type="seglist",
|
type="seglist",
|
||||||
data=[
|
data=[
|
||||||
Seg(type="reply", data=self.reply.message_info.message_id),
|
Seg(type="reply", data=self.reply.message_info.message_id), # type: ignore
|
||||||
self.message_segment,
|
self.message_segment,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@@ -367,10 +366,10 @@ class MessageSending(MessageProcessBase):
|
|||||||
) -> "MessageSending":
|
) -> "MessageSending":
|
||||||
"""从思考状态消息创建发送状态消息"""
|
"""从思考状态消息创建发送状态消息"""
|
||||||
return cls(
|
return cls(
|
||||||
message_id=thinking.message_info.message_id,
|
message_id=thinking.message_info.message_id, # type: ignore
|
||||||
chat_stream=thinking.chat_stream,
|
chat_stream=thinking.chat_stream,
|
||||||
message_segment=message_segment,
|
message_segment=message_segment,
|
||||||
bot_user_info=thinking.message_info.user_info,
|
bot_user_info=thinking.message_info.user_info, # type: ignore
|
||||||
reply=thinking.reply,
|
reply=thinking.reply,
|
||||||
is_head=is_head,
|
is_head=is_head,
|
||||||
is_emoji=is_emoji,
|
is_emoji=is_emoji,
|
||||||
@@ -402,13 +401,11 @@ class MessageSet:
|
|||||||
if not isinstance(message, MessageSending):
|
if not isinstance(message, MessageSending):
|
||||||
raise TypeError("MessageSet只能添加MessageSending类型的消息")
|
raise TypeError("MessageSet只能添加MessageSending类型的消息")
|
||||||
self.messages.append(message)
|
self.messages.append(message)
|
||||||
self.messages.sort(key=lambda x: x.message_info.time)
|
self.messages.sort(key=lambda x: x.message_info.time) # type: ignore
|
||||||
|
|
||||||
def get_message_by_index(self, index: int) -> Optional[MessageSending]:
|
def get_message_by_index(self, index: int) -> Optional[MessageSending]:
|
||||||
"""通过索引获取消息"""
|
"""通过索引获取消息"""
|
||||||
if 0 <= index < len(self.messages):
|
return self.messages[index] if 0 <= index < len(self.messages) else None
|
||||||
return self.messages[index]
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_message_by_time(self, target_time: float) -> Optional[MessageSending]:
|
def get_message_by_time(self, target_time: float) -> Optional[MessageSending]:
|
||||||
"""获取最接近指定时间的消息"""
|
"""获取最接近指定时间的消息"""
|
||||||
@@ -418,7 +415,7 @@ class MessageSet:
|
|||||||
left, right = 0, len(self.messages) - 1
|
left, right = 0, len(self.messages) - 1
|
||||||
while left < right:
|
while left < right:
|
||||||
mid = (left + right) // 2
|
mid = (left + right) // 2
|
||||||
if self.messages[mid].message_info.time < target_time:
|
if self.messages[mid].message_info.time < target_time: # type: ignore
|
||||||
left = mid + 1
|
left = mid + 1
|
||||||
else:
|
else:
|
||||||
right = mid
|
right = mid
|
||||||
@@ -444,11 +441,8 @@ class MessageSet:
|
|||||||
|
|
||||||
|
|
||||||
def message_recv_from_dict(message_dict: dict) -> MessageRecv:
|
def message_recv_from_dict(message_dict: dict) -> MessageRecv:
|
||||||
return MessageRecv(
|
return MessageRecv(message_dict)
|
||||||
|
|
||||||
message_dict
|
|
||||||
|
|
||||||
)
|
|
||||||
|
|
||||||
def message_from_db_dict(db_dict: dict) -> MessageRecv:
|
def message_from_db_dict(db_dict: dict) -> MessageRecv:
|
||||||
"""从数据库字典创建MessageRecv实例"""
|
"""从数据库字典创建MessageRecv实例"""
|
||||||
|
|||||||
@@ -2,11 +2,10 @@ import re
|
|||||||
import traceback
|
import traceback
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
# from ...common.database.database import db # db is now Peewee's SqliteDatabase instance
|
from src.common.database.database_model import Messages, RecalledMessages, Images
|
||||||
from .message import MessageSending, MessageRecv
|
|
||||||
from .chat_stream import ChatStream
|
|
||||||
from ...common.database.database_model import Messages, RecalledMessages, Images # Import Peewee models
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
from .chat_stream import ChatStream
|
||||||
|
from .message import MessageSending, MessageRecv
|
||||||
|
|
||||||
logger = get_logger("message_storage")
|
logger = get_logger("message_storage")
|
||||||
|
|
||||||
@@ -55,7 +54,7 @@ class MessageStorage:
|
|||||||
is_picid = message.is_picid
|
is_picid = message.is_picid
|
||||||
|
|
||||||
chat_info_dict = chat_stream.to_dict()
|
chat_info_dict = chat_stream.to_dict()
|
||||||
user_info_dict = message.message_info.user_info.to_dict()
|
user_info_dict = message.message_info.user_info.to_dict() # type: ignore
|
||||||
|
|
||||||
# message_id 现在是 TextField,直接使用字符串值
|
# message_id 现在是 TextField,直接使用字符串值
|
||||||
msg_id = message.message_info.message_id
|
msg_id = message.message_info.message_id
|
||||||
@@ -67,7 +66,7 @@ class MessageStorage:
|
|||||||
|
|
||||||
Messages.create(
|
Messages.create(
|
||||||
message_id=msg_id,
|
message_id=msg_id,
|
||||||
time=float(message.message_info.time),
|
time=float(message.message_info.time), # type: ignore
|
||||||
chat_id=chat_stream.stream_id,
|
chat_id=chat_stream.stream_id,
|
||||||
# Flattened chat_info
|
# Flattened chat_info
|
||||||
reply_to=reply_to,
|
reply_to=reply_to,
|
||||||
@@ -121,7 +120,7 @@ class MessageStorage:
|
|||||||
try:
|
try:
|
||||||
# Assuming input 'time' is a string timestamp that can be converted to float
|
# Assuming input 'time' is a string timestamp that can be converted to float
|
||||||
current_time_float = float(time)
|
current_time_float = float(time)
|
||||||
RecalledMessages.delete().where(RecalledMessages.time < (current_time_float - 300)).execute()
|
RecalledMessages.delete().where(RecalledMessages.time < (current_time_float - 300)).execute() # type: ignore
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("删除撤回消息失败")
|
logger.exception("删除撤回消息失败")
|
||||||
|
|
||||||
@@ -133,22 +132,19 @@ class MessageStorage:
|
|||||||
"""更新最新一条匹配消息的message_id"""
|
"""更新最新一条匹配消息的message_id"""
|
||||||
try:
|
try:
|
||||||
if message.message_segment.type == "notify":
|
if message.message_segment.type == "notify":
|
||||||
mmc_message_id = message.message_segment.data.get("echo")
|
mmc_message_id = message.message_segment.data.get("echo") # type: ignore
|
||||||
qq_message_id = message.message_segment.data.get("actual_id")
|
qq_message_id = message.message_segment.data.get("actual_id") # type: ignore
|
||||||
else:
|
else:
|
||||||
logger.info(f"更新消息ID错误,seg类型为{message.message_segment.type}")
|
logger.info(f"更新消息ID错误,seg类型为{message.message_segment.type}")
|
||||||
return
|
return
|
||||||
if not qq_message_id:
|
if not qq_message_id:
|
||||||
logger.info("消息不存在message_id,无法更新")
|
logger.info("消息不存在message_id,无法更新")
|
||||||
return
|
return
|
||||||
# 查询最新一条匹配消息
|
if matched_message := (
|
||||||
matched_message = (
|
|
||||||
Messages.select().where((Messages.message_id == mmc_message_id)).order_by(Messages.time.desc()).first()
|
Messages.select().where((Messages.message_id == mmc_message_id)).order_by(Messages.time.desc()).first()
|
||||||
)
|
):
|
||||||
|
|
||||||
if matched_message:
|
|
||||||
# 更新找到的消息记录
|
# 更新找到的消息记录
|
||||||
Messages.update(message_id=qq_message_id).where(Messages.id == matched_message.id).execute()
|
Messages.update(message_id=qq_message_id).where(Messages.id == matched_message.id).execute() # type: ignore
|
||||||
logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}")
|
logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}")
|
||||||
else:
|
else:
|
||||||
logger.debug("未找到匹配的消息")
|
logger.debug("未找到匹配的消息")
|
||||||
@@ -173,10 +169,7 @@ class MessageStorage:
|
|||||||
image_record = (
|
image_record = (
|
||||||
Images.select().where(Images.description == description).order_by(Images.timestamp.desc()).first()
|
Images.select().where(Images.description == description).order_by(Images.timestamp.desc()).first()
|
||||||
)
|
)
|
||||||
if image_record:
|
return f"[picid:{image_record.image_id}]" if image_record else match.group(0)
|
||||||
return f"[picid:{image_record.image_id}]"
|
|
||||||
else:
|
|
||||||
return match.group(0) # 保持原样
|
|
||||||
except Exception:
|
except Exception:
|
||||||
return match.group(0)
|
return match.group(0)
|
||||||
|
|
||||||
|
|||||||
@@ -1,15 +1,16 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from src.chat.message_receive.message import MessageSending
|
|
||||||
from src.common.message.api import get_global_api
|
|
||||||
from src.chat.message_receive.storage import MessageStorage
|
|
||||||
from src.chat.utils.utils import truncate_message
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.chat.utils.utils import calculate_typing_time
|
|
||||||
from rich.traceback import install
|
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
install(extra_lines=3)
|
from rich.traceback import install
|
||||||
|
|
||||||
|
from src.common.message.api import get_global_api
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.chat.message_receive.message import MessageSending
|
||||||
|
from src.chat.message_receive.storage import MessageStorage
|
||||||
|
from src.chat.utils.utils import truncate_message
|
||||||
|
from src.chat.utils.utils import calculate_typing_time
|
||||||
|
|
||||||
|
install(extra_lines=3)
|
||||||
|
|
||||||
logger = get_logger("sender")
|
logger = get_logger("sender")
|
||||||
|
|
||||||
@@ -49,10 +50,10 @@ class HeartFCSender:
|
|||||||
"""
|
"""
|
||||||
if not message.chat_stream:
|
if not message.chat_stream:
|
||||||
logger.error("消息缺少 chat_stream,无法发送")
|
logger.error("消息缺少 chat_stream,无法发送")
|
||||||
raise Exception("消息缺少 chat_stream,无法发送")
|
raise ValueError("消息缺少 chat_stream,无法发送")
|
||||||
if not message.message_info or not message.message_info.message_id:
|
if not message.message_info or not message.message_info.message_id:
|
||||||
logger.error("消息缺少 message_info 或 message_id,无法发送")
|
logger.error("消息缺少 message_info 或 message_id,无法发送")
|
||||||
raise Exception("消息缺少 message_info 或 message_id,无法发送")
|
raise ValueError("消息缺少 message_info 或 message_id,无法发送")
|
||||||
|
|
||||||
chat_id = message.chat_stream.stream_id
|
chat_id = message.chat_stream.stream_id
|
||||||
message_id = message.message_info.message_id
|
message_id = message.message_info.message_id
|
||||||
@@ -84,4 +85,3 @@ class HeartFCSender:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[{chat_id}] 处理或存储消息 {message_id} 时出错: {e}")
|
logger.error(f"[{chat_id}] 处理或存储消息 {message_id} 时出错: {e}")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|||||||
@@ -1,28 +1,30 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
|
import traceback
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
from src.plugin_system.base.component_types import ChatMode
|
||||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||||
from ..message_receive.message import MessageThinking
|
from src.chat.message_receive.message import MessageThinking
|
||||||
from src.chat.message_receive.normal_message_sender import message_manager
|
|
||||||
from src.chat.normal_chat.willing.willing_manager import get_willing_manager
|
|
||||||
from src.chat.planner_actions.action_manager import ActionManager
|
from src.chat.planner_actions.action_manager import ActionManager
|
||||||
from src.person_info.relationship_builder_manager import relationship_builder_manager
|
from src.person_info.relationship_builder_manager import relationship_builder_manager
|
||||||
from ..focus_chat.priority_manager import PriorityManager
|
from src.chat.focus_chat.priority_manager import PriorityManager
|
||||||
import traceback
|
|
||||||
from src.chat.planner_actions.planner import ActionPlanner
|
from src.chat.planner_actions.planner import ActionPlanner
|
||||||
from src.chat.planner_actions.action_modifier import ActionModifier
|
from src.chat.planner_actions.action_modifier import ActionModifier
|
||||||
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive
|
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive
|
||||||
|
|
||||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||||
|
|
||||||
|
|
||||||
willing_manager = get_willing_manager()
|
willing_manager = get_willing_manager()
|
||||||
|
|
||||||
logger = get_logger("normal_chat")
|
logger = get_logger("normal_chat")
|
||||||
|
|
||||||
LOOP_INTERVAL = 0.3
|
LOOP_INTERVAL = 0.3
|
||||||
|
|
||||||
|
|
||||||
class NormalChat:
|
class NormalChat:
|
||||||
"""
|
"""
|
||||||
普通聊天处理类,负责处理非核心对话的聊天逻辑。
|
普通聊天处理类,负责处理非核心对话的聊天逻辑。
|
||||||
@@ -63,7 +65,7 @@ class NormalChat:
|
|||||||
|
|
||||||
# Planner相关初始化
|
# Planner相关初始化
|
||||||
self.action_manager = ActionManager()
|
self.action_manager = ActionManager()
|
||||||
self.planner = ActionPlanner(self.stream_id, self.action_manager, mode="normal")
|
self.planner = ActionPlanner(self.stream_id, self.action_manager, mode=ChatMode.NORMAL)
|
||||||
self.action_modifier = ActionModifier(self.action_manager, self.stream_id)
|
self.action_modifier = ActionModifier(self.action_manager, self.stream_id)
|
||||||
self.enable_planner = global_config.normal_chat.enable_planner # 从配置中读取是否启用planner
|
self.enable_planner = global_config.normal_chat.enable_planner # 从配置中读取是否启用planner
|
||||||
|
|
||||||
@@ -134,7 +136,6 @@ class NormalChat:
|
|||||||
# except Exception as e:
|
# except Exception as e:
|
||||||
# logger.error(f"[{self.stream_name}] 处理消息时出错: {e} {traceback.format_exc()}")
|
# logger.error(f"[{self.stream_name}] 处理消息时出错: {e} {traceback.format_exc()}")
|
||||||
|
|
||||||
|
|
||||||
# except asyncio.CancelledError:
|
# except asyncio.CancelledError:
|
||||||
# logger.info(f"[{self.stream_name}] 兴趣模式轮询任务被取消")
|
# logger.info(f"[{self.stream_name}] 兴趣模式轮询任务被取消")
|
||||||
# return False
|
# return False
|
||||||
@@ -165,7 +166,6 @@ class NormalChat:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[{self.stream_name}] 添加消息到优先级队列时出错: {e} {traceback.format_exc()}")
|
logger.error(f"[{self.stream_name}] 添加消息到优先级队列时出错: {e} {traceback.format_exc()}")
|
||||||
|
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
logger.info(f"[{self.stream_name}] 优先级消息生产者任务被取消")
|
logger.info(f"[{self.stream_name}] 优先级消息生产者任务被取消")
|
||||||
return False
|
return False
|
||||||
@@ -188,9 +188,6 @@ class NormalChat:
|
|||||||
# except asyncio.CancelledError:
|
# except asyncio.CancelledError:
|
||||||
# logger.info(f"[{self.stream_name}] 兴趣模式消息轮询任务被优雅地取消了")
|
# logger.info(f"[{self.stream_name}] 兴趣模式消息轮询任务被优雅地取消了")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def _priority_chat_loop(self):
|
async def _priority_chat_loop(self):
|
||||||
"""
|
"""
|
||||||
使用优先级队列的消息处理循环。
|
使用优先级队列的消息处理循环。
|
||||||
@@ -275,7 +272,6 @@ class NormalChat:
|
|||||||
|
|
||||||
# reply = message_from_db_dict(message_data)
|
# reply = message_from_db_dict(message_data)
|
||||||
|
|
||||||
|
|
||||||
# mark_head = False
|
# mark_head = False
|
||||||
# first_bot_msg = None
|
# first_bot_msg = None
|
||||||
# for msg in response_set:
|
# for msg in response_set:
|
||||||
@@ -652,7 +648,9 @@ class NormalChat:
|
|||||||
# Start consumer loop
|
# Start consumer loop
|
||||||
consumer_task = asyncio.create_task(self._priority_chat_loop())
|
consumer_task = asyncio.create_task(self._priority_chat_loop())
|
||||||
self._priority_chat_task = consumer_task
|
self._priority_chat_task = consumer_task
|
||||||
self._priority_chat_task.add_done_callback(lambda t: self._handle_task_completion(t, "priority_consumer"))
|
self._priority_chat_task.add_done_callback(
|
||||||
|
lambda t: self._handle_task_completion(t, "priority_consumer")
|
||||||
|
)
|
||||||
else: # Interest mode
|
else: # Interest mode
|
||||||
polling_task = asyncio.create_task(self._interest_message_polling_loop())
|
polling_task = asyncio.create_task(self._interest_message_polling_loop())
|
||||||
self._chat_task = polling_task
|
self._chat_task = polling_task
|
||||||
@@ -712,7 +710,6 @@ class NormalChat:
|
|||||||
self._chat_task = None
|
self._chat_task = None
|
||||||
self._priority_chat_task = None
|
self._priority_chat_task = None
|
||||||
|
|
||||||
|
|
||||||
# def adjust_reply_frequency(self):
|
# def adjust_reply_frequency(self):
|
||||||
# """
|
# """
|
||||||
# 根据预设规则动态调整回复意愿(willing_amplifier)。
|
# 根据预设规则动态调整回复意愿(willing_amplifier)。
|
||||||
|
|||||||
@@ -1,15 +1,12 @@
|
|||||||
from typing import Dict, List, Optional, Type, Any
|
from typing import Dict, List, Optional, Type
|
||||||
from src.plugin_system.base.base_action import BaseAction
|
from src.plugin_system.base.base_action import BaseAction
|
||||||
from src.chat.message_receive.chat_stream import ChatStream
|
from src.chat.message_receive.chat_stream import ChatStream
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.plugin_system.core.component_registry import component_registry
|
from src.plugin_system.core.component_registry import component_registry
|
||||||
from src.plugin_system.base.component_types import ComponentType
|
from src.plugin_system.base.component_types import ComponentType, ActionActivationType, ChatMode, ActionInfo
|
||||||
|
|
||||||
logger = get_logger("action_manager")
|
logger = get_logger("action_manager")
|
||||||
|
|
||||||
# 定义动作信息类型
|
|
||||||
ActionInfo = Dict[str, Any]
|
|
||||||
|
|
||||||
|
|
||||||
class ActionManager:
|
class ActionManager:
|
||||||
"""
|
"""
|
||||||
@@ -20,8 +17,8 @@ class ActionManager:
|
|||||||
|
|
||||||
# 类常量
|
# 类常量
|
||||||
DEFAULT_RANDOM_PROBABILITY = 0.3
|
DEFAULT_RANDOM_PROBABILITY = 0.3
|
||||||
DEFAULT_MODE = "all"
|
DEFAULT_MODE = ChatMode.ALL
|
||||||
DEFAULT_ACTIVATION_TYPE = "always"
|
DEFAULT_ACTIVATION_TYPE = ActionActivationType.ALWAYS
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""初始化动作管理器"""
|
"""初始化动作管理器"""
|
||||||
@@ -30,14 +27,11 @@ class ActionManager:
|
|||||||
# 当前正在使用的动作集合,默认加载默认动作
|
# 当前正在使用的动作集合,默认加载默认动作
|
||||||
self._using_actions: Dict[str, ActionInfo] = {}
|
self._using_actions: Dict[str, ActionInfo] = {}
|
||||||
|
|
||||||
# 默认动作集,仅作为快照,用于恢复默认
|
|
||||||
self._default_actions: Dict[str, ActionInfo] = {}
|
|
||||||
|
|
||||||
# 加载插件动作
|
# 加载插件动作
|
||||||
self._load_plugin_actions()
|
self._load_plugin_actions()
|
||||||
|
|
||||||
# 初始化时将默认动作加载到使用中的动作
|
# 初始化时将默认动作加载到使用中的动作
|
||||||
self._using_actions = self._default_actions.copy()
|
self._using_actions = component_registry.get_default_actions()
|
||||||
|
|
||||||
def _load_plugin_actions(self) -> None:
|
def _load_plugin_actions(self) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -54,43 +48,15 @@ class ActionManager:
|
|||||||
def _load_plugin_system_actions(self) -> None:
|
def _load_plugin_system_actions(self) -> None:
|
||||||
"""从插件系统的component_registry加载Action组件"""
|
"""从插件系统的component_registry加载Action组件"""
|
||||||
try:
|
try:
|
||||||
from src.plugin_system.core.component_registry import component_registry
|
|
||||||
from src.plugin_system.base.component_types import ComponentType
|
|
||||||
|
|
||||||
# 获取所有Action组件
|
# 获取所有Action组件
|
||||||
action_components = component_registry.get_components_by_type(ComponentType.ACTION)
|
action_components: Dict[str, ActionInfo] = component_registry.get_components_by_type(ComponentType.ACTION) # type: ignore
|
||||||
|
|
||||||
for action_name, action_info in action_components.items():
|
for action_name, action_info in action_components.items():
|
||||||
if action_name in self._registered_actions:
|
if action_name in self._registered_actions:
|
||||||
logger.debug(f"Action组件 {action_name} 已存在,跳过")
|
logger.debug(f"Action组件 {action_name} 已存在,跳过")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 将插件系统的ActionInfo转换为ActionManager格式
|
self._registered_actions[action_name] = action_info
|
||||||
converted_action_info = {
|
|
||||||
"description": action_info.description,
|
|
||||||
"parameters": getattr(action_info, "action_parameters", {}),
|
|
||||||
"require": getattr(action_info, "action_require", []),
|
|
||||||
"associated_types": getattr(action_info, "associated_types", []),
|
|
||||||
"enable_plugin": action_info.enabled,
|
|
||||||
# 激活类型相关
|
|
||||||
"focus_activation_type": action_info.focus_activation_type.value,
|
|
||||||
"normal_activation_type": action_info.normal_activation_type.value,
|
|
||||||
"random_activation_probability": action_info.random_activation_probability,
|
|
||||||
"llm_judge_prompt": action_info.llm_judge_prompt,
|
|
||||||
"activation_keywords": action_info.activation_keywords,
|
|
||||||
"keyword_case_sensitive": action_info.keyword_case_sensitive,
|
|
||||||
# 模式和并行设置
|
|
||||||
"mode_enable": action_info.mode_enable.value,
|
|
||||||
"parallel_action": action_info.parallel_action,
|
|
||||||
# 插件信息
|
|
||||||
"_plugin_name": getattr(action_info, "plugin_name", ""),
|
|
||||||
}
|
|
||||||
|
|
||||||
self._registered_actions[action_name] = converted_action_info
|
|
||||||
|
|
||||||
# 如果启用,也添加到默认动作集
|
|
||||||
if action_info.enabled:
|
|
||||||
self._default_actions[action_name] = converted_action_info
|
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"从插件系统加载Action组件: {action_name} (插件: {getattr(action_info, 'plugin_name', 'unknown')})"
|
f"从插件系统加载Action组件: {action_name} (插件: {getattr(action_info, 'plugin_name', 'unknown')})"
|
||||||
@@ -133,7 +99,9 @@ class ActionManager:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 获取组件类 - 明确指定查询Action类型
|
# 获取组件类 - 明确指定查询Action类型
|
||||||
component_class = component_registry.get_component_class(action_name, ComponentType.ACTION)
|
component_class: Type[BaseAction] = component_registry.get_component_class(
|
||||||
|
action_name, ComponentType.ACTION
|
||||||
|
) # type: ignore
|
||||||
if not component_class:
|
if not component_class:
|
||||||
logger.warning(f"{log_prefix} 未找到Action组件: {action_name}")
|
logger.warning(f"{log_prefix} 未找到Action组件: {action_name}")
|
||||||
return None
|
return None
|
||||||
@@ -173,10 +141,6 @@ class ActionManager:
|
|||||||
"""获取所有已注册的动作集"""
|
"""获取所有已注册的动作集"""
|
||||||
return self._registered_actions.copy()
|
return self._registered_actions.copy()
|
||||||
|
|
||||||
def get_default_actions(self) -> Dict[str, ActionInfo]:
|
|
||||||
"""获取默认动作集"""
|
|
||||||
return self._default_actions.copy()
|
|
||||||
|
|
||||||
def get_using_actions(self) -> Dict[str, ActionInfo]:
|
def get_using_actions(self) -> Dict[str, ActionInfo]:
|
||||||
"""获取当前正在使用的动作集合"""
|
"""获取当前正在使用的动作集合"""
|
||||||
return self._using_actions.copy()
|
return self._using_actions.copy()
|
||||||
@@ -221,31 +185,31 @@ class ActionManager:
|
|||||||
logger.debug(f"已从使用集中移除动作 {action_name}")
|
logger.debug(f"已从使用集中移除动作 {action_name}")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def add_action(self, action_name: str, description: str, parameters: Dict = None, require: List = None) -> bool:
|
# def add_action(self, action_name: str, description: str, parameters: Dict = None, require: List = None) -> bool:
|
||||||
"""
|
# """
|
||||||
添加新的动作到注册集
|
# 添加新的动作到注册集
|
||||||
|
|
||||||
Args:
|
# Args:
|
||||||
action_name: 动作名称
|
# action_name: 动作名称
|
||||||
description: 动作描述
|
# description: 动作描述
|
||||||
parameters: 动作参数定义,默认为空字典
|
# parameters: 动作参数定义,默认为空字典
|
||||||
require: 动作依赖项,默认为空列表
|
# require: 动作依赖项,默认为空列表
|
||||||
|
|
||||||
Returns:
|
# Returns:
|
||||||
bool: 添加是否成功
|
# bool: 添加是否成功
|
||||||
"""
|
# """
|
||||||
if action_name in self._registered_actions:
|
# if action_name in self._registered_actions:
|
||||||
return False
|
# return False
|
||||||
|
|
||||||
if parameters is None:
|
# if parameters is None:
|
||||||
parameters = {}
|
# parameters = {}
|
||||||
if require is None:
|
# if require is None:
|
||||||
require = []
|
# require = []
|
||||||
|
|
||||||
action_info = {"description": description, "parameters": parameters, "require": require}
|
# action_info = {"description": description, "parameters": parameters, "require": require}
|
||||||
|
|
||||||
self._registered_actions[action_name] = action_info
|
# self._registered_actions[action_name] = action_info
|
||||||
return True
|
# return True
|
||||||
|
|
||||||
def remove_action(self, action_name: str) -> bool:
|
def remove_action(self, action_name: str) -> bool:
|
||||||
"""从注册集移除指定动作"""
|
"""从注册集移除指定动作"""
|
||||||
@@ -264,10 +228,9 @@ class ActionManager:
|
|||||||
|
|
||||||
def restore_actions(self) -> None:
|
def restore_actions(self) -> None:
|
||||||
"""恢复到默认动作集"""
|
"""恢复到默认动作集"""
|
||||||
logger.debug(
|
actions_to_restore = list(self._using_actions.keys())
|
||||||
f"恢复动作集: 从 {list(self._using_actions.keys())} 恢复到默认动作集 {list(self._default_actions.keys())}"
|
self._using_actions = component_registry.get_default_actions()
|
||||||
)
|
logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}")
|
||||||
self._using_actions = self._default_actions.copy()
|
|
||||||
|
|
||||||
def add_system_action_if_needed(self, action_name: str) -> bool:
|
def add_system_action_if_needed(self, action_name: str) -> bool:
|
||||||
"""
|
"""
|
||||||
@@ -297,4 +260,4 @@ class ActionManager:
|
|||||||
"""
|
"""
|
||||||
from src.plugin_system.core.component_registry import component_registry
|
from src.plugin_system.core.component_registry import component_registry
|
||||||
|
|
||||||
return component_registry.get_component_class(action_name)
|
return component_registry.get_component_class(action_name) # type: ignore
|
||||||
|
|||||||
@@ -1,15 +1,20 @@
|
|||||||
from typing import List, Any, Dict
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.chat.focus_chat.hfc_utils import CycleDetail
|
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
|
||||||
from src.config.config import global_config
|
|
||||||
from src.llm_models.utils_model import LLMRequest
|
|
||||||
import random
|
import random
|
||||||
import asyncio
|
import asyncio
|
||||||
import hashlib
|
import hashlib
|
||||||
import time
|
import time
|
||||||
|
from typing import List, Any, Dict, TYPE_CHECKING
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.config.config import global_config
|
||||||
|
from src.llm_models.utils_model import LLMRequest
|
||||||
|
from src.chat.focus_chat.hfc_utils import CycleDetail
|
||||||
|
from src.chat.message_receive.chat_stream import get_chat_manager, ChatMessageContext
|
||||||
from src.chat.planner_actions.action_manager import ActionManager
|
from src.chat.planner_actions.action_manager import ActionManager
|
||||||
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages
|
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages
|
||||||
|
from src.plugin_system.base.component_types import ChatMode, ActionInfo, ActionActivationType
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from src.chat.message_receive.chat_stream import ChatStream
|
||||||
|
|
||||||
logger = get_logger("action_manager")
|
logger = get_logger("action_manager")
|
||||||
|
|
||||||
@@ -25,7 +30,7 @@ class ActionModifier:
|
|||||||
def __init__(self, action_manager: ActionManager, chat_id: str):
|
def __init__(self, action_manager: ActionManager, chat_id: str):
|
||||||
"""初始化动作处理器"""
|
"""初始化动作处理器"""
|
||||||
self.chat_id = chat_id
|
self.chat_id = chat_id
|
||||||
self.chat_stream = get_chat_manager().get_stream(self.chat_id)
|
self.chat_stream: ChatStream = get_chat_manager().get_stream(self.chat_id) # type: ignore
|
||||||
self.log_prefix = f"[{get_chat_manager().get_stream_name(self.chat_id) or self.chat_id}]"
|
self.log_prefix = f"[{get_chat_manager().get_stream_name(self.chat_id) or self.chat_id}]"
|
||||||
|
|
||||||
self.action_manager = action_manager
|
self.action_manager = action_manager
|
||||||
@@ -45,7 +50,7 @@ class ActionModifier:
|
|||||||
self,
|
self,
|
||||||
history_loop=None,
|
history_loop=None,
|
||||||
message_content: str = "",
|
message_content: str = "",
|
||||||
):
|
): # sourcery skip: use-named-expression
|
||||||
"""
|
"""
|
||||||
动作修改流程,整合传统观察处理和新的激活类型判定
|
动作修改流程,整合传统观察处理和新的激活类型判定
|
||||||
|
|
||||||
@@ -125,12 +130,11 @@ class ActionModifier:
|
|||||||
f"{self.log_prefix} 动作修改流程结束,最终可用动作: {list(self.action_manager.get_using_actions().keys())}||移除记录: {removals_summary}"
|
f"{self.log_prefix} 动作修改流程结束,最终可用动作: {list(self.action_manager.get_using_actions().keys())}||移除记录: {removals_summary}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def _check_action_associated_types(self, all_actions, chat_context):
|
def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: ChatMessageContext):
|
||||||
type_mismatched_actions = []
|
type_mismatched_actions = []
|
||||||
for action_name, data in all_actions.items():
|
for action_name, action_info in all_actions.items():
|
||||||
if data.get("associated_types"):
|
if action_info.associated_types and not chat_context.check_types(action_info.associated_types):
|
||||||
if not chat_context.check_types(data["associated_types"]):
|
associated_types_str = ", ".join(action_info.associated_types)
|
||||||
associated_types_str = ", ".join(data["associated_types"])
|
|
||||||
reason = f"适配器不支持(需要: {associated_types_str})"
|
reason = f"适配器不支持(需要: {associated_types_str})"
|
||||||
type_mismatched_actions.append((action_name, reason))
|
type_mismatched_actions.append((action_name, reason))
|
||||||
logger.debug(f"{self.log_prefix}决定移除动作: {action_name},原因: {reason}")
|
logger.debug(f"{self.log_prefix}决定移除动作: {action_name},原因: {reason}")
|
||||||
@@ -167,21 +171,21 @@ class ActionModifier:
|
|||||||
if activation_type == "always":
|
if activation_type == "always":
|
||||||
continue # 总是激活,无需处理
|
continue # 总是激活,无需处理
|
||||||
|
|
||||||
elif activation_type == "random":
|
elif activation_type == ActionActivationType.RANDOM:
|
||||||
probability = action_info.get("random_activation_probability", ActionManager.DEFAULT_RANDOM_PROBABILITY)
|
probability = action_info.random_activation_probability or ActionManager.DEFAULT_RANDOM_PROBABILITY
|
||||||
if not (random.random() < probability):
|
if random.random() >= probability:
|
||||||
reason = f"RANDOM类型未触发(概率{probability})"
|
reason = f"RANDOM类型未触发(概率{probability})"
|
||||||
deactivated_actions.append((action_name, reason))
|
deactivated_actions.append((action_name, reason))
|
||||||
logger.debug(f"{self.log_prefix}未激活动作: {action_name},原因: {reason}")
|
logger.debug(f"{self.log_prefix}未激活动作: {action_name},原因: {reason}")
|
||||||
|
|
||||||
elif activation_type == "keyword":
|
elif activation_type == ActionActivationType.KEYWORD:
|
||||||
if not self._check_keyword_activation(action_name, action_info, chat_content):
|
if not self._check_keyword_activation(action_name, action_info, chat_content):
|
||||||
keywords = action_info.get("activation_keywords", [])
|
keywords = action_info.activation_keywords
|
||||||
reason = f"关键词未匹配(关键词: {keywords})"
|
reason = f"关键词未匹配(关键词: {keywords})"
|
||||||
deactivated_actions.append((action_name, reason))
|
deactivated_actions.append((action_name, reason))
|
||||||
logger.debug(f"{self.log_prefix}未激活动作: {action_name},原因: {reason}")
|
logger.debug(f"{self.log_prefix}未激活动作: {action_name},原因: {reason}")
|
||||||
|
|
||||||
elif activation_type == "llm_judge":
|
elif activation_type == ActionActivationType.LLM_JUDGE:
|
||||||
llm_judge_actions[action_name] = action_info
|
llm_judge_actions[action_name] = action_info
|
||||||
|
|
||||||
elif activation_type == "never":
|
elif activation_type == "never":
|
||||||
@@ -273,7 +277,7 @@ class ActionModifier:
|
|||||||
task_results = await asyncio.gather(*tasks, return_exceptions=True)
|
task_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
# 处理结果并更新缓存
|
# 处理结果并更新缓存
|
||||||
for _, (action_name, result) in enumerate(zip(task_names, task_results, strict=False)):
|
for action_name, result in zip(task_names, task_results, strict=False):
|
||||||
if isinstance(result, Exception):
|
if isinstance(result, Exception):
|
||||||
logger.error(f"{self.log_prefix}LLM判定action {action_name} 时出错: {result}")
|
logger.error(f"{self.log_prefix}LLM判定action {action_name} 时出错: {result}")
|
||||||
results[action_name] = False
|
results[action_name] = False
|
||||||
@@ -289,7 +293,7 @@ class ActionModifier:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"{self.log_prefix}并行LLM判定失败: {e}")
|
logger.error(f"{self.log_prefix}并行LLM判定失败: {e}")
|
||||||
# 如果并行执行失败,为所有任务返回False
|
# 如果并行执行失败,为所有任务返回False
|
||||||
for action_name in tasks_to_run.keys():
|
for action_name in tasks_to_run:
|
||||||
results[action_name] = False
|
results[action_name] = False
|
||||||
|
|
||||||
# 清理过期缓存
|
# 清理过期缓存
|
||||||
@@ -300,10 +304,11 @@ class ActionModifier:
|
|||||||
def _cleanup_expired_cache(self, current_time: float):
|
def _cleanup_expired_cache(self, current_time: float):
|
||||||
"""清理过期的缓存条目"""
|
"""清理过期的缓存条目"""
|
||||||
expired_keys = []
|
expired_keys = []
|
||||||
for cache_key, cache_data in self._llm_judge_cache.items():
|
expired_keys.extend(
|
||||||
if current_time - cache_data["timestamp"] > self._cache_expiry_time:
|
cache_key
|
||||||
expired_keys.append(cache_key)
|
for cache_key, cache_data in self._llm_judge_cache.items()
|
||||||
|
if current_time - cache_data["timestamp"] > self._cache_expiry_time
|
||||||
|
)
|
||||||
for key in expired_keys:
|
for key in expired_keys:
|
||||||
del self._llm_judge_cache[key]
|
del self._llm_judge_cache[key]
|
||||||
|
|
||||||
@@ -382,7 +387,7 @@ class ActionModifier:
|
|||||||
def _check_keyword_activation(
|
def _check_keyword_activation(
|
||||||
self,
|
self,
|
||||||
action_name: str,
|
action_name: str,
|
||||||
action_info: Dict[str, Any],
|
action_info: ActionInfo,
|
||||||
chat_content: str = "",
|
chat_content: str = "",
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
@@ -399,8 +404,8 @@ class ActionModifier:
|
|||||||
bool: 是否应该激活此action
|
bool: 是否应该激活此action
|
||||||
"""
|
"""
|
||||||
|
|
||||||
activation_keywords = action_info.get("activation_keywords", [])
|
activation_keywords = action_info.activation_keywords
|
||||||
case_sensitive = action_info.get("keyword_case_sensitive", False)
|
case_sensitive = action_info.keyword_case_sensitive
|
||||||
|
|
||||||
if not activation_keywords:
|
if not activation_keywords:
|
||||||
logger.warning(f"{self.log_prefix}动作 {action_name} 设置为关键词触发但未配置关键词")
|
logger.warning(f"{self.log_prefix}动作 {action_name} 设置为关键词触发但未配置关键词")
|
||||||
|
|||||||
@@ -1,23 +1,26 @@
|
|||||||
import json # <--- 确保导入 json
|
import json
|
||||||
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Dict, Any, Optional
|
from typing import Dict, Any, Optional
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
from datetime import datetime
|
||||||
|
from json_repair import repair_json
|
||||||
|
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from src.chat.planner_actions.action_manager import ActionManager
|
|
||||||
from json_repair import repair_json
|
|
||||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
|
||||||
from datetime import datetime
|
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
|
||||||
from src.chat.utils.chat_message_builder import (
|
from src.chat.utils.chat_message_builder import (
|
||||||
build_readable_actions,
|
build_readable_actions,
|
||||||
build_readable_messages,
|
|
||||||
get_actions_by_timestamp_with_chat,
|
get_actions_by_timestamp_with_chat,
|
||||||
|
build_readable_messages,
|
||||||
get_raw_msg_before_timestamp_with_chat,
|
get_raw_msg_before_timestamp_with_chat,
|
||||||
)
|
)
|
||||||
import time
|
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||||
|
from src.chat.planner_actions.action_manager import ActionManager
|
||||||
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
|
from src.plugin_system.base.component_types import ChatMode, ActionInfo
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger("planner")
|
logger = get_logger("planner")
|
||||||
|
|
||||||
@@ -28,7 +31,7 @@ def init_prompt():
|
|||||||
Prompt(
|
Prompt(
|
||||||
"""
|
"""
|
||||||
{time_block}
|
{time_block}
|
||||||
{indentify_block}
|
{identity_block}
|
||||||
你现在需要根据聊天内容,选择的合适的action来参与聊天。
|
你现在需要根据聊天内容,选择的合适的action来参与聊天。
|
||||||
{chat_context_description},以下是具体的聊天内容:
|
{chat_context_description},以下是具体的聊天内容:
|
||||||
{chat_content_block}
|
{chat_content_block}
|
||||||
@@ -76,7 +79,7 @@ class ActionPlanner:
|
|||||||
|
|
||||||
self.last_obs_time_mark = 0.0
|
self.last_obs_time_mark = 0.0
|
||||||
|
|
||||||
async def plan(self,mode:str = "focus") -> Dict[str, Any]:
|
async def plan(self, mode: str = "focus") -> Dict[str, Any]: # sourcery skip: dict-comprehension
|
||||||
"""
|
"""
|
||||||
规划器 (Planner): 使用LLM根据上下文决定做出什么动作。
|
规划器 (Planner): 使用LLM根据上下文决定做出什么动作。
|
||||||
"""
|
"""
|
||||||
@@ -84,6 +87,7 @@ class ActionPlanner:
|
|||||||
action = "no_reply" # 默认动作
|
action = "no_reply" # 默认动作
|
||||||
reasoning = "规划器初始化默认"
|
reasoning = "规划器初始化默认"
|
||||||
action_data = {}
|
action_data = {}
|
||||||
|
current_available_actions: Dict[str, ActionInfo] = {}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
is_group_chat = True
|
is_group_chat = True
|
||||||
@@ -95,7 +99,7 @@ class ActionPlanner:
|
|||||||
|
|
||||||
# 获取完整的动作信息
|
# 获取完整的动作信息
|
||||||
all_registered_actions = self.action_manager.get_registered_actions()
|
all_registered_actions = self.action_manager.get_registered_actions()
|
||||||
current_available_actions = {}
|
|
||||||
for action_name in current_available_actions_dict.keys():
|
for action_name in current_available_actions_dict.keys():
|
||||||
if action_name in all_registered_actions:
|
if action_name in all_registered_actions:
|
||||||
current_available_actions[action_name] = all_registered_actions[action_name]
|
current_available_actions[action_name] = all_registered_actions[action_name]
|
||||||
@@ -107,13 +111,17 @@ class ActionPlanner:
|
|||||||
len(current_available_actions) == 1 and "no_reply" in current_available_actions
|
len(current_available_actions) == 1 and "no_reply" in current_available_actions
|
||||||
):
|
):
|
||||||
action = "no_reply"
|
action = "no_reply"
|
||||||
reasoning = "没有可用的动作" if not current_available_actions else "只有no_reply动作可用,跳过规划"
|
reasoning = "只有no_reply动作可用,跳过规划" if current_available_actions else "没有可用的动作"
|
||||||
logger.info(f"{self.log_prefix}{reasoning}")
|
logger.info(f"{self.log_prefix}{reasoning}")
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"{self.log_prefix}[focus]沉默后恢复到默认动作集, 当前可用: {list(self.action_manager.get_using_actions().keys())}"
|
f"{self.log_prefix}[focus]沉默后恢复到默认动作集, 当前可用: {list(self.action_manager.get_using_actions().keys())}"
|
||||||
)
|
)
|
||||||
return {
|
return {
|
||||||
"action_result": {"action_type": action, "action_data": action_data, "reasoning": reasoning},
|
"action_result": {
|
||||||
|
"action_type": action,
|
||||||
|
"action_data": action_data,
|
||||||
|
"reasoning": reasoning,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
# --- 构建提示词 (调用修改后的 PromptBuilder 方法) ---
|
# --- 构建提示词 (调用修改后的 PromptBuilder 方法) ---
|
||||||
@@ -142,7 +150,7 @@ class ActionPlanner:
|
|||||||
|
|
||||||
except Exception as req_e:
|
except Exception as req_e:
|
||||||
logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}")
|
logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}")
|
||||||
reasoning = f"LLM 请求失败,你的模型出现问题: {req_e}"
|
reasoning = f"LLM 请求失败,模型出现问题: {req_e}"
|
||||||
action = "no_reply"
|
action = "no_reply"
|
||||||
|
|
||||||
if llm_content:
|
if llm_content:
|
||||||
@@ -164,7 +172,6 @@ class ActionPlanner:
|
|||||||
reasoning = parsed_json.get("reasoning", "未提供原因")
|
reasoning = parsed_json.get("reasoning", "未提供原因")
|
||||||
|
|
||||||
# 将所有其他属性添加到action_data
|
# 将所有其他属性添加到action_data
|
||||||
action_data = {}
|
|
||||||
for key, value in parsed_json.items():
|
for key, value in parsed_json.items():
|
||||||
if key not in ["action", "reasoning"]:
|
if key not in ["action", "reasoning"]:
|
||||||
action_data[key] = value
|
action_data[key] = value
|
||||||
@@ -175,8 +182,8 @@ class ActionPlanner:
|
|||||||
logger.warning(
|
logger.warning(
|
||||||
f"{self.log_prefix}LLM 返回了当前不可用或无效的动作: '{action}' (可用: {list(current_available_actions.keys())}),将强制使用 'no_reply'"
|
f"{self.log_prefix}LLM 返回了当前不可用或无效的动作: '{action}' (可用: {list(current_available_actions.keys())}),将强制使用 'no_reply'"
|
||||||
)
|
)
|
||||||
action = "no_reply"
|
|
||||||
reasoning = f"LLM 返回了当前不可用的动作 '{action}' (可用: {list(current_available_actions.keys())})。原始理由: {reasoning}"
|
reasoning = f"LLM 返回了当前不可用的动作 '{action}' (可用: {list(current_available_actions.keys())})。原始理由: {reasoning}"
|
||||||
|
action = "no_reply"
|
||||||
|
|
||||||
except Exception as json_e:
|
except Exception as json_e:
|
||||||
logger.warning(f"{self.log_prefix}解析LLM响应JSON失败 {json_e}. LLM原始输出: '{llm_content}'")
|
logger.warning(f"{self.log_prefix}解析LLM响应JSON失败 {json_e}. LLM原始输出: '{llm_content}'")
|
||||||
@@ -192,8 +199,7 @@ class ActionPlanner:
|
|||||||
|
|
||||||
is_parallel = False
|
is_parallel = False
|
||||||
if action in current_available_actions:
|
if action in current_available_actions:
|
||||||
action_info = current_available_actions[action]
|
is_parallel = current_available_actions[action].parallel_action
|
||||||
is_parallel = action_info.get("parallel_action", False)
|
|
||||||
|
|
||||||
action_result = {
|
action_result = {
|
||||||
"action_type": action,
|
"action_type": action,
|
||||||
@@ -203,20 +209,18 @@ class ActionPlanner:
|
|||||||
"is_parallel": is_parallel,
|
"is_parallel": is_parallel,
|
||||||
}
|
}
|
||||||
|
|
||||||
plan_result = {
|
return {
|
||||||
"action_result": action_result,
|
"action_result": action_result,
|
||||||
"action_prompt": prompt,
|
"action_prompt": prompt,
|
||||||
}
|
}
|
||||||
|
|
||||||
return plan_result
|
|
||||||
|
|
||||||
async def build_planner_prompt(
|
async def build_planner_prompt(
|
||||||
self,
|
self,
|
||||||
is_group_chat: bool, # Now passed as argument
|
is_group_chat: bool, # Now passed as argument
|
||||||
chat_target_info: Optional[dict], # Now passed as argument
|
chat_target_info: Optional[dict], # Now passed as argument
|
||||||
current_available_actions,
|
current_available_actions: Dict[str, ActionInfo],
|
||||||
mode: str = "focus",
|
mode: str = "focus",
|
||||||
) -> str:
|
) -> str: # sourcery skip: use-join
|
||||||
"""构建 Planner LLM 的提示词 (获取模板并填充数据)"""
|
"""构建 Planner LLM 的提示词 (获取模板并填充数据)"""
|
||||||
try:
|
try:
|
||||||
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
|
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
|
||||||
@@ -281,23 +285,23 @@ class ActionPlanner:
|
|||||||
action_options_block = ""
|
action_options_block = ""
|
||||||
|
|
||||||
for using_actions_name, using_actions_info in current_available_actions.items():
|
for using_actions_name, using_actions_info in current_available_actions.items():
|
||||||
if using_actions_info["parameters"]:
|
if using_actions_info.action_parameters:
|
||||||
param_text = "\n"
|
param_text = "\n"
|
||||||
for param_name, param_description in using_actions_info["parameters"].items():
|
for param_name, param_description in using_actions_info.action_parameters.items():
|
||||||
param_text += f' "{param_name}":"{param_description}"\n'
|
param_text += f' "{param_name}":"{param_description}"\n'
|
||||||
param_text = param_text.rstrip("\n")
|
param_text = param_text.rstrip("\n")
|
||||||
else:
|
else:
|
||||||
param_text = ""
|
param_text = ""
|
||||||
|
|
||||||
require_text = ""
|
require_text = ""
|
||||||
for require_item in using_actions_info["require"]:
|
for require_item in using_actions_info.action_require:
|
||||||
require_text += f"- {require_item}\n"
|
require_text += f"- {require_item}\n"
|
||||||
require_text = require_text.rstrip("\n")
|
require_text = require_text.rstrip("\n")
|
||||||
|
|
||||||
using_action_prompt = await global_prompt_manager.get_prompt_async("action_prompt")
|
using_action_prompt = await global_prompt_manager.get_prompt_async("action_prompt")
|
||||||
using_action_prompt = using_action_prompt.format(
|
using_action_prompt = using_action_prompt.format(
|
||||||
action_name=using_actions_name,
|
action_name=using_actions_name,
|
||||||
action_description=using_actions_info["description"],
|
action_description=using_actions_info.description,
|
||||||
action_parameters=param_text,
|
action_parameters=param_text,
|
||||||
action_require=require_text,
|
action_require=require_text,
|
||||||
)
|
)
|
||||||
@@ -314,10 +318,10 @@ class ActionPlanner:
|
|||||||
else:
|
else:
|
||||||
bot_nickname = ""
|
bot_nickname = ""
|
||||||
bot_core_personality = global_config.personality.personality_core
|
bot_core_personality = global_config.personality.personality_core
|
||||||
indentify_block = f"你的名字是{bot_name}{bot_nickname},你{bot_core_personality}:"
|
identity_block = f"你的名字是{bot_name}{bot_nickname},你{bot_core_personality}:"
|
||||||
|
|
||||||
planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt")
|
planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt")
|
||||||
prompt = planner_prompt_template.format(
|
return planner_prompt_template.format(
|
||||||
time_block=time_block,
|
time_block=time_block,
|
||||||
by_what=by_what,
|
by_what=by_what,
|
||||||
chat_context_description=chat_context_description,
|
chat_context_description=chat_context_description,
|
||||||
@@ -326,10 +330,8 @@ class ActionPlanner:
|
|||||||
no_action_block=no_action_block,
|
no_action_block=no_action_block,
|
||||||
action_options_text=action_options_block,
|
action_options_text=action_options_block,
|
||||||
moderation_prompt=moderation_prompt_block,
|
moderation_prompt=moderation_prompt_block,
|
||||||
indentify_block=indentify_block,
|
identity_block=identity_block,
|
||||||
)
|
)
|
||||||
return prompt
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"构建 Planner 提示词时出错: {e}")
|
logger.error(f"构建 Planner 提示词时出错: {e}")
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
|
|||||||
@@ -1,31 +1,31 @@
|
|||||||
import traceback
|
import traceback
|
||||||
from typing import List, Optional, Dict, Any, Tuple
|
|
||||||
|
|
||||||
from src.chat.message_receive.message import MessageRecv, MessageSending
|
|
||||||
from src.chat.message_receive.message import Seg # Local import needed after move
|
|
||||||
from src.chat.message_receive.message import UserInfo
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.llm_models.utils_model import LLMRequest
|
|
||||||
from src.config.config import global_config
|
|
||||||
from src.chat.utils.timer_calculator import Timer # <--- Import Timer
|
|
||||||
from src.chat.message_receive.uni_message_sender import HeartFCSender
|
|
||||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
|
||||||
from src.chat.message_receive.chat_stream import ChatStream
|
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
|
||||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
|
|
||||||
import time
|
import time
|
||||||
import asyncio
|
import asyncio
|
||||||
from src.chat.express.expression_selector import expression_selector
|
|
||||||
from src.mood.mood_manager import mood_manager
|
|
||||||
from src.person_info.relationship_fetcher import relationship_fetcher_manager
|
|
||||||
import random
|
import random
|
||||||
import ast
|
import ast
|
||||||
from src.person_info.person_info import get_person_info_manager
|
|
||||||
from datetime import datetime
|
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
from typing import List, Optional, Dict, Any, Tuple
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.config.config import global_config
|
||||||
|
from src.llm_models.utils_model import LLMRequest
|
||||||
|
from src.chat.message_receive.message import UserInfo, Seg, MessageRecv, MessageThinking, MessageSending
|
||||||
|
from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream
|
||||||
|
from src.chat.message_receive.uni_message_sender import HeartFCSender
|
||||||
|
from src.chat.utils.timer_calculator import Timer # <--- Import Timer
|
||||||
|
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||||
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
|
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
|
||||||
|
from src.chat.express.expression_selector import expression_selector
|
||||||
from src.chat.knowledge.knowledge_lib import qa_manager
|
from src.chat.knowledge.knowledge_lib import qa_manager
|
||||||
from src.chat.memory_system.memory_activator import MemoryActivator
|
from src.chat.memory_system.memory_activator import MemoryActivator
|
||||||
|
from src.mood.mood_manager import mood_manager
|
||||||
|
from src.person_info.relationship_fetcher import relationship_fetcher_manager
|
||||||
|
from src.person_info.person_info import get_person_info_manager
|
||||||
from src.tools.tool_executor import ToolExecutor
|
from src.tools.tool_executor import ToolExecutor
|
||||||
|
from src.plugin_system.base.component_types import ActionInfo
|
||||||
|
|
||||||
logger = get_logger("replyer")
|
logger = get_logger("replyer")
|
||||||
|
|
||||||
@@ -132,25 +132,23 @@ class DefaultReplyer:
|
|||||||
# 提取权重,如果模型配置中没有'weight'键,则默认为1.0
|
# 提取权重,如果模型配置中没有'weight'键,则默认为1.0
|
||||||
weights = [config.get("weight", 1.0) for config in configs]
|
weights = [config.get("weight", 1.0) for config in configs]
|
||||||
|
|
||||||
# random.choices 返回一个列表,我们取第一个元素
|
return random.choices(population=configs, weights=weights, k=1)[0]
|
||||||
selected_config = random.choices(population=configs, weights=weights, k=1)[0]
|
|
||||||
return selected_config
|
|
||||||
|
|
||||||
async def generate_reply_with_context(
|
async def generate_reply_with_context(
|
||||||
self,
|
self,
|
||||||
reply_data: Dict[str, Any] = None,
|
reply_data: Optional[Dict[str, Any]] = None,
|
||||||
reply_to: str = "",
|
reply_to: str = "",
|
||||||
extra_info: str = "",
|
extra_info: str = "",
|
||||||
available_actions: List[str] = None,
|
available_actions: Optional[Dict[str, ActionInfo]] = None,
|
||||||
enable_tool: bool = True,
|
enable_tool: bool = True,
|
||||||
enable_timeout: bool = False,
|
enable_timeout: bool = False,
|
||||||
) -> Tuple[bool, Optional[str]]:
|
) -> Tuple[bool, Optional[str], Optional[str]]:
|
||||||
"""
|
"""
|
||||||
回复器 (Replier): 核心逻辑,负责生成回复文本。
|
回复器 (Replier): 核心逻辑,负责生成回复文本。
|
||||||
(已整合原 HeartFCGenerator 的功能)
|
(已整合原 HeartFCGenerator 的功能)
|
||||||
"""
|
"""
|
||||||
if available_actions is None:
|
if available_actions is None:
|
||||||
available_actions = []
|
available_actions = {}
|
||||||
if reply_data is None:
|
if reply_data is None:
|
||||||
reply_data = {}
|
reply_data = {}
|
||||||
try:
|
try:
|
||||||
@@ -202,14 +200,14 @@ class DefaultReplyer:
|
|||||||
except Exception as llm_e:
|
except Exception as llm_e:
|
||||||
# 精简报错信息
|
# 精简报错信息
|
||||||
logger.error(f"{self.log_prefix}LLM 生成失败: {llm_e}")
|
logger.error(f"{self.log_prefix}LLM 生成失败: {llm_e}")
|
||||||
return False, None # LLM 调用失败则无法生成回复
|
return False, None, prompt # LLM 调用失败则无法生成回复
|
||||||
|
|
||||||
return True, content, prompt
|
return True, content, prompt
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"{self.log_prefix}回复生成意外失败: {e}")
|
logger.error(f"{self.log_prefix}回复生成意外失败: {e}")
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return False, None
|
return False, None, prompt
|
||||||
|
|
||||||
async def rewrite_reply_with_context(
|
async def rewrite_reply_with_context(
|
||||||
self,
|
self,
|
||||||
@@ -289,15 +287,14 @@ class DefaultReplyer:
|
|||||||
logger.warning(f"{self.log_prefix} 未找到用户 {sender} 的ID,跳过信息提取")
|
logger.warning(f"{self.log_prefix} 未找到用户 {sender} 的ID,跳过信息提取")
|
||||||
return f"你完全不认识{sender},不理解ta的相关信息。"
|
return f"你完全不认识{sender},不理解ta的相关信息。"
|
||||||
|
|
||||||
relation_info = await relationship_fetcher.build_relation_info(person_id, text, chat_history)
|
return await relationship_fetcher.build_relation_info(person_id, text, chat_history)
|
||||||
return relation_info
|
|
||||||
|
|
||||||
async def build_expression_habits(self, chat_history, target):
|
async def build_expression_habits(self, chat_history, target):
|
||||||
if not global_config.expression.enable_expression:
|
if not global_config.expression.enable_expression:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
style_habbits = []
|
style_habits = []
|
||||||
grammar_habbits = []
|
grammar_habits = []
|
||||||
|
|
||||||
# 使用从处理器传来的选中表达方式
|
# 使用从处理器传来的选中表达方式
|
||||||
# LLM模式:调用LLM选择5-10个,然后随机选5个
|
# LLM模式:调用LLM选择5-10个,然后随机选5个
|
||||||
@@ -311,22 +308,22 @@ class DefaultReplyer:
|
|||||||
if isinstance(expr, dict) and "situation" in expr and "style" in expr:
|
if isinstance(expr, dict) and "situation" in expr and "style" in expr:
|
||||||
expr_type = expr.get("type", "style")
|
expr_type = expr.get("type", "style")
|
||||||
if expr_type == "grammar":
|
if expr_type == "grammar":
|
||||||
grammar_habbits.append(f"当{expr['situation']}时,使用 {expr['style']}")
|
grammar_habits.append(f"当{expr['situation']}时,使用 {expr['style']}")
|
||||||
else:
|
else:
|
||||||
style_habbits.append(f"当{expr['situation']}时,使用 {expr['style']}")
|
style_habits.append(f"当{expr['situation']}时,使用 {expr['style']}")
|
||||||
else:
|
else:
|
||||||
logger.debug(f"{self.log_prefix} 没有从处理器获得表达方式,将使用空的表达方式")
|
logger.debug(f"{self.log_prefix} 没有从处理器获得表达方式,将使用空的表达方式")
|
||||||
# 不再在replyer中进行随机选择,全部交给处理器处理
|
# 不再在replyer中进行随机选择,全部交给处理器处理
|
||||||
|
|
||||||
style_habbits_str = "\n".join(style_habbits)
|
style_habits_str = "\n".join(style_habits)
|
||||||
grammar_habbits_str = "\n".join(grammar_habbits)
|
grammar_habits_str = "\n".join(grammar_habits)
|
||||||
|
|
||||||
# 动态构建expression habits块
|
# 动态构建expression habits块
|
||||||
expression_habits_block = ""
|
expression_habits_block = ""
|
||||||
if style_habbits_str.strip():
|
if style_habits_str.strip():
|
||||||
expression_habits_block += f"你可以参考以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中:\n{style_habbits_str}\n\n"
|
expression_habits_block += f"你可以参考以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中:\n{style_habits_str}\n\n"
|
||||||
if grammar_habbits_str.strip():
|
if grammar_habits_str.strip():
|
||||||
expression_habits_block += f"请你根据情景使用以下句法:\n{grammar_habbits_str}\n"
|
expression_habits_block += f"请你根据情景使用以下句法:\n{grammar_habits_str}\n"
|
||||||
|
|
||||||
return expression_habits_block
|
return expression_habits_block
|
||||||
|
|
||||||
@@ -334,21 +331,19 @@ class DefaultReplyer:
|
|||||||
if not global_config.memory.enable_memory:
|
if not global_config.memory.enable_memory:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
running_memorys = await self.memory_activator.activate_memory_with_chat_history(
|
running_memories = await self.memory_activator.activate_memory_with_chat_history(
|
||||||
target_message=target, chat_history_prompt=chat_history
|
target_message=target, chat_history_prompt=chat_history
|
||||||
)
|
)
|
||||||
|
|
||||||
if running_memorys:
|
if not running_memories:
|
||||||
|
return ""
|
||||||
|
|
||||||
memory_str = "以下是当前在聊天中,你回忆起的记忆:\n"
|
memory_str = "以下是当前在聊天中,你回忆起的记忆:\n"
|
||||||
for running_memory in running_memorys:
|
for running_memory in running_memories:
|
||||||
memory_str += f"- {running_memory['content']}\n"
|
memory_str += f"- {running_memory['content']}\n"
|
||||||
memory_block = memory_str
|
return memory_str
|
||||||
else:
|
|
||||||
memory_block = ""
|
|
||||||
|
|
||||||
return memory_block
|
async def build_tool_info(self, chat_history, reply_data: Optional[Dict], enable_tool: bool = True):
|
||||||
|
|
||||||
async def build_tool_info(self, reply_data=None, chat_history=None, enable_tool: bool = True):
|
|
||||||
"""构建工具信息块
|
"""构建工具信息块
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -373,7 +368,7 @@ class DefaultReplyer:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# 使用工具执行器获取信息
|
# 使用工具执行器获取信息
|
||||||
tool_results = await self.tool_executor.execute_from_chat_message(
|
tool_results, _, _ = await self.tool_executor.execute_from_chat_message(
|
||||||
sender=sender, target_message=text, chat_history=chat_history, return_details=False
|
sender=sender, target_message=text, chat_history=chat_history, return_details=False
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -428,7 +423,7 @@ class DefaultReplyer:
|
|||||||
for name, content in result.groupdict().items():
|
for name, content in result.groupdict().items():
|
||||||
reaction = reaction.replace(f"[{name}]", content)
|
reaction = reaction.replace(f"[{name}]", content)
|
||||||
logger.info(f"匹配到正则表达式:{pattern_str},触发反应:{reaction}")
|
logger.info(f"匹配到正则表达式:{pattern_str},触发反应:{reaction}")
|
||||||
keywords_reaction_prompt += reaction + ","
|
keywords_reaction_prompt += f"{reaction},"
|
||||||
break
|
break
|
||||||
except re.error as e:
|
except re.error as e:
|
||||||
logger.error(f"正则表达式编译错误: {pattern_str}, 错误信息: {str(e)}")
|
logger.error(f"正则表达式编译错误: {pattern_str}, 错误信息: {str(e)}")
|
||||||
@@ -438,21 +433,21 @@ class DefaultReplyer:
|
|||||||
|
|
||||||
return keywords_reaction_prompt
|
return keywords_reaction_prompt
|
||||||
|
|
||||||
async def _time_and_run_task(self, coro, name: str):
|
async def _time_and_run_task(self, coroutine, name: str):
|
||||||
"""一个简单的帮助函数,用于计时和运行异步任务,返回任务名、结果和耗时"""
|
"""一个简单的帮助函数,用于计时和运行异步任务,返回任务名、结果和耗时"""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
result = await coro
|
result = await coroutine
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
duration = end_time - start_time
|
duration = end_time - start_time
|
||||||
return name, result, duration
|
return name, result, duration
|
||||||
|
|
||||||
async def build_prompt_reply_context(
|
async def build_prompt_reply_context(
|
||||||
self,
|
self,
|
||||||
reply_data=None,
|
reply_data: Dict[str, Any],
|
||||||
available_actions: List[str] = None,
|
available_actions: Optional[Dict[str, ActionInfo]] = None,
|
||||||
enable_timeout: bool = False,
|
enable_timeout: bool = False,
|
||||||
enable_tool: bool = True,
|
enable_tool: bool = True,
|
||||||
) -> str:
|
) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
|
||||||
"""
|
"""
|
||||||
构建回复器上下文
|
构建回复器上下文
|
||||||
|
|
||||||
@@ -468,7 +463,7 @@ class DefaultReplyer:
|
|||||||
str: 构建好的上下文
|
str: 构建好的上下文
|
||||||
"""
|
"""
|
||||||
if available_actions is None:
|
if available_actions is None:
|
||||||
available_actions = []
|
available_actions = {}
|
||||||
chat_stream = self.chat_stream
|
chat_stream = self.chat_stream
|
||||||
chat_id = chat_stream.stream_id
|
chat_id = chat_stream.stream_id
|
||||||
person_info_manager = get_person_info_manager()
|
person_info_manager = get_person_info_manager()
|
||||||
@@ -487,10 +482,9 @@ class DefaultReplyer:
|
|||||||
if available_actions:
|
if available_actions:
|
||||||
action_descriptions = "你有以下的动作能力,但执行这些动作不由你决定,由另外一个模型同步决定,因此你只需要知道有如下能力即可:\n"
|
action_descriptions = "你有以下的动作能力,但执行这些动作不由你决定,由另外一个模型同步决定,因此你只需要知道有如下能力即可:\n"
|
||||||
for action_name, action_info in available_actions.items():
|
for action_name, action_info in available_actions.items():
|
||||||
action_description = action_info.get("description", "")
|
action_description = action_info.description
|
||||||
action_descriptions += f"- {action_name}: {action_description}\n"
|
action_descriptions += f"- {action_name}: {action_description}\n"
|
||||||
action_descriptions += "\n"
|
action_descriptions += "\n"
|
||||||
|
|
||||||
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
|
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
timestamp=time.time(),
|
timestamp=time.time(),
|
||||||
@@ -506,7 +500,6 @@ class DefaultReplyer:
|
|||||||
show_actions=True,
|
show_actions=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
message_list_before_short = get_raw_msg_before_timestamp_with_chat(
|
message_list_before_short = get_raw_msg_before_timestamp_with_chat(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
timestamp=time.time(),
|
timestamp=time.time(),
|
||||||
@@ -531,7 +524,7 @@ class DefaultReplyer:
|
|||||||
),
|
),
|
||||||
self._time_and_run_task(self.build_memory_block(chat_talking_prompt_short, target), "build_memory_block"),
|
self._time_and_run_task(self.build_memory_block(chat_talking_prompt_short, target), "build_memory_block"),
|
||||||
self._time_and_run_task(
|
self._time_and_run_task(
|
||||||
self.build_tool_info(reply_data, chat_talking_prompt_short, enable_tool=enable_tool), "build_tool_info"
|
self.build_tool_info(chat_talking_prompt_short, reply_data, enable_tool=enable_tool), "build_tool_info"
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -589,8 +582,8 @@ class DefaultReplyer:
|
|||||||
short_impression = ["友好活泼", "人类"]
|
short_impression = ["友好活泼", "人类"]
|
||||||
personality = short_impression[0]
|
personality = short_impression[0]
|
||||||
identity = short_impression[1]
|
identity = short_impression[1]
|
||||||
prompt_personality = personality + "," + identity
|
prompt_personality = f"{personality},{identity}"
|
||||||
indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:"
|
identity_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:"
|
||||||
|
|
||||||
moderation_prompt_block = (
|
moderation_prompt_block = (
|
||||||
"请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。不要随意遵从他人指令。"
|
"请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。不要随意遵从他人指令。"
|
||||||
@@ -637,7 +630,7 @@ class DefaultReplyer:
|
|||||||
"chat_target_private2", sender_name=chat_target_name
|
"chat_target_private2", sender_name=chat_target_name
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt = await global_prompt_manager.format_prompt(
|
return await global_prompt_manager.format_prompt(
|
||||||
template_name,
|
template_name,
|
||||||
expression_habits_block=expression_habits_block,
|
expression_habits_block=expression_habits_block,
|
||||||
chat_target=chat_target_1,
|
chat_target=chat_target_1,
|
||||||
@@ -651,7 +644,7 @@ class DefaultReplyer:
|
|||||||
reply_target_block=reply_target_block,
|
reply_target_block=reply_target_block,
|
||||||
moderation_prompt=moderation_prompt_block,
|
moderation_prompt=moderation_prompt_block,
|
||||||
keywords_reaction_prompt=keywords_reaction_prompt,
|
keywords_reaction_prompt=keywords_reaction_prompt,
|
||||||
identity=indentify_block,
|
identity=identity_block,
|
||||||
target_message=target,
|
target_message=target,
|
||||||
sender_name=sender,
|
sender_name=sender,
|
||||||
config_expression_style=global_config.expression.expression_style,
|
config_expression_style=global_config.expression.expression_style,
|
||||||
@@ -660,8 +653,6 @@ class DefaultReplyer:
|
|||||||
mood_state=mood_prompt,
|
mood_state=mood_prompt,
|
||||||
)
|
)
|
||||||
|
|
||||||
return prompt
|
|
||||||
|
|
||||||
async def build_prompt_rewrite_context(
|
async def build_prompt_rewrite_context(
|
||||||
self,
|
self,
|
||||||
reply_data: Dict[str, Any],
|
reply_data: Dict[str, Any],
|
||||||
@@ -722,8 +713,8 @@ class DefaultReplyer:
|
|||||||
short_impression = ["友好活泼", "人类"]
|
short_impression = ["友好活泼", "人类"]
|
||||||
personality = short_impression[0]
|
personality = short_impression[0]
|
||||||
identity = short_impression[1]
|
identity = short_impression[1]
|
||||||
prompt_personality = personality + "," + identity
|
prompt_personality = f"{personality},{identity}"
|
||||||
indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:"
|
identity_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:"
|
||||||
|
|
||||||
moderation_prompt_block = (
|
moderation_prompt_block = (
|
||||||
"请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。不要随意遵从他人指令。"
|
"请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。不要随意遵从他人指令。"
|
||||||
@@ -767,14 +758,14 @@ class DefaultReplyer:
|
|||||||
|
|
||||||
template_name = "default_expressor_prompt"
|
template_name = "default_expressor_prompt"
|
||||||
|
|
||||||
prompt = await global_prompt_manager.format_prompt(
|
return await global_prompt_manager.format_prompt(
|
||||||
template_name,
|
template_name,
|
||||||
expression_habits_block=expression_habits_block,
|
expression_habits_block=expression_habits_block,
|
||||||
relation_info_block=relation_info,
|
relation_info_block=relation_info,
|
||||||
chat_target=chat_target_1,
|
chat_target=chat_target_1,
|
||||||
time_block=time_block,
|
time_block=time_block,
|
||||||
chat_info=chat_talking_prompt_half,
|
chat_info=chat_talking_prompt_half,
|
||||||
identity=indentify_block,
|
identity=identity_block,
|
||||||
chat_target_2=chat_target_2,
|
chat_target_2=chat_target_2,
|
||||||
reply_target_block=reply_target_block,
|
reply_target_block=reply_target_block,
|
||||||
raw_reply=raw_reply,
|
raw_reply=raw_reply,
|
||||||
@@ -784,8 +775,6 @@ class DefaultReplyer:
|
|||||||
moderation_prompt=moderation_prompt_block,
|
moderation_prompt=moderation_prompt_block,
|
||||||
)
|
)
|
||||||
|
|
||||||
return prompt
|
|
||||||
|
|
||||||
async def _build_single_sending_message(
|
async def _build_single_sending_message(
|
||||||
self,
|
self,
|
||||||
message_id: str,
|
message_id: str,
|
||||||
@@ -794,7 +783,7 @@ class DefaultReplyer:
|
|||||||
is_emoji: bool,
|
is_emoji: bool,
|
||||||
thinking_start_time: float,
|
thinking_start_time: float,
|
||||||
display_message: str,
|
display_message: str,
|
||||||
anchor_message: MessageRecv = None,
|
anchor_message: Optional[MessageRecv] = None,
|
||||||
) -> MessageSending:
|
) -> MessageSending:
|
||||||
"""构建单个发送消息"""
|
"""构建单个发送消息"""
|
||||||
|
|
||||||
@@ -805,12 +794,9 @@ class DefaultReplyer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# await anchor_message.process()
|
# await anchor_message.process()
|
||||||
if anchor_message:
|
sender_info = anchor_message.message_info.user_info if anchor_message else None
|
||||||
sender_info = anchor_message.message_info.user_info
|
|
||||||
else:
|
|
||||||
sender_info = None
|
|
||||||
|
|
||||||
bot_message = MessageSending(
|
return MessageSending(
|
||||||
message_id=message_id, # 使用片段的唯一ID
|
message_id=message_id, # 使用片段的唯一ID
|
||||||
chat_stream=self.chat_stream,
|
chat_stream=self.chat_stream,
|
||||||
bot_user_info=bot_user_info,
|
bot_user_info=bot_user_info,
|
||||||
@@ -823,8 +809,6 @@ class DefaultReplyer:
|
|||||||
display_message=display_message,
|
display_message=display_message,
|
||||||
)
|
)
|
||||||
|
|
||||||
return bot_message
|
|
||||||
|
|
||||||
|
|
||||||
def weighted_sample_no_replacement(items, weights, k) -> list:
|
def weighted_sample_no_replacement(items, weights, k) -> list:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,14 +1,15 @@
|
|||||||
from typing import Dict, Any, Optional, List
|
from typing import Dict, Any, Optional, List
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||||
from src.chat.replyer.default_generator import DefaultReplyer
|
from src.chat.replyer.default_generator import DefaultReplyer
|
||||||
from src.common.logger import get_logger
|
|
||||||
|
|
||||||
logger = get_logger("ReplyerManager")
|
logger = get_logger("ReplyerManager")
|
||||||
|
|
||||||
|
|
||||||
class ReplyerManager:
|
class ReplyerManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._replyers: Dict[str, DefaultReplyer] = {}
|
self._repliers: Dict[str, DefaultReplyer] = {}
|
||||||
|
|
||||||
def get_replyer(
|
def get_replyer(
|
||||||
self,
|
self,
|
||||||
@@ -29,17 +30,16 @@ class ReplyerManager:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
# 如果已有缓存实例,直接返回
|
# 如果已有缓存实例,直接返回
|
||||||
if stream_id in self._replyers:
|
if stream_id in self._repliers:
|
||||||
logger.debug(f"[ReplyerManager] 为 stream_id '{stream_id}' 返回已存在的回复器实例。")
|
logger.debug(f"[ReplyerManager] 为 stream_id '{stream_id}' 返回已存在的回复器实例。")
|
||||||
return self._replyers[stream_id]
|
return self._repliers[stream_id]
|
||||||
|
|
||||||
# 如果没有缓存,则创建新实例(首次初始化)
|
# 如果没有缓存,则创建新实例(首次初始化)
|
||||||
logger.debug(f"[ReplyerManager] 为 stream_id '{stream_id}' 创建新的回复器实例并缓存。")
|
logger.debug(f"[ReplyerManager] 为 stream_id '{stream_id}' 创建新的回复器实例并缓存。")
|
||||||
|
|
||||||
target_stream = chat_stream
|
target_stream = chat_stream
|
||||||
if not target_stream:
|
if not target_stream:
|
||||||
chat_manager = get_chat_manager()
|
if chat_manager := get_chat_manager():
|
||||||
if chat_manager:
|
|
||||||
target_stream = chat_manager.get_stream(stream_id)
|
target_stream = chat_manager.get_stream(stream_id)
|
||||||
|
|
||||||
if not target_stream:
|
if not target_stream:
|
||||||
@@ -52,7 +52,7 @@ class ReplyerManager:
|
|||||||
model_configs=model_configs, # 可以是None,此时使用默认模型
|
model_configs=model_configs, # 可以是None,此时使用默认模型
|
||||||
request_type=request_type,
|
request_type=request_type,
|
||||||
)
|
)
|
||||||
self._replyers[stream_id] = replyer
|
self._repliers[stream_id] = replyer
|
||||||
return replyer
|
return replyer
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,14 +1,16 @@
|
|||||||
from src.config.config import global_config
|
|
||||||
from typing import List, Dict, Any, Tuple # 确保类型提示被导入
|
|
||||||
import time # 导入 time 模块以获取当前时间
|
import time # 导入 time 模块以获取当前时间
|
||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
from src.common.message_repository import find_messages, count_messages
|
|
||||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
from typing import List, Dict, Any, Tuple, Optional
|
||||||
from src.chat.utils.utils import translate_timestamp_to_human_readable
|
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
|
||||||
|
from src.config.config import global_config
|
||||||
|
from src.common.message_repository import find_messages, count_messages
|
||||||
from src.common.database.database_model import ActionRecords
|
from src.common.database.database_model import ActionRecords
|
||||||
from src.common.database.database_model import Images
|
from src.common.database.database_model import Images
|
||||||
|
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||||
|
from src.chat.utils.utils import translate_timestamp_to_human_readable
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
|
|
||||||
@@ -28,7 +30,12 @@ def get_raw_msg_by_timestamp(
|
|||||||
|
|
||||||
|
|
||||||
def get_raw_msg_by_timestamp_with_chat(
|
def get_raw_msg_by_timestamp_with_chat(
|
||||||
chat_id: str, timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest", fliter_bot = False
|
chat_id: str,
|
||||||
|
timestamp_start: float,
|
||||||
|
timestamp_end: float,
|
||||||
|
limit: int = 0,
|
||||||
|
limit_mode: str = "latest",
|
||||||
|
fliter_bot=False,
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""获取在特定聊天从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
|
"""获取在特定聊天从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
|
||||||
limit: 限制返回的消息数量,0为不限制
|
limit: 限制返回的消息数量,0为不限制
|
||||||
@@ -38,11 +45,18 @@ def get_raw_msg_by_timestamp_with_chat(
|
|||||||
# 只有当 limit 为 0 时才应用外部 sort
|
# 只有当 limit 为 0 时才应用外部 sort
|
||||||
sort_order = [("time", 1)] if limit == 0 else None
|
sort_order = [("time", 1)] if limit == 0 else None
|
||||||
# 直接将 limit_mode 传递给 find_messages
|
# 直接将 limit_mode 传递给 find_messages
|
||||||
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode, fliter_bot=fliter_bot)
|
return find_messages(
|
||||||
|
message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode, fliter_bot=fliter_bot
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_raw_msg_by_timestamp_with_chat_inclusive(
|
def get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||||
chat_id: str, timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest", fliter_bot = False
|
chat_id: str,
|
||||||
|
timestamp_start: float,
|
||||||
|
timestamp_end: float,
|
||||||
|
limit: int = 0,
|
||||||
|
limit_mode: str = "latest",
|
||||||
|
fliter_bot=False,
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""获取在特定聊天从指定时间戳到指定时间戳的消息(包含边界),按时间升序排序,返回消息列表
|
"""获取在特定聊天从指定时间戳到指定时间戳的消息(包含边界),按时间升序排序,返回消息列表
|
||||||
limit: 限制返回的消息数量,0为不限制
|
limit: 限制返回的消息数量,0为不限制
|
||||||
@@ -53,7 +67,9 @@ def get_raw_msg_by_timestamp_with_chat_inclusive(
|
|||||||
sort_order = [("time", 1)] if limit == 0 else None
|
sort_order = [("time", 1)] if limit == 0 else None
|
||||||
# 直接将 limit_mode 传递给 find_messages
|
# 直接将 limit_mode 传递给 find_messages
|
||||||
|
|
||||||
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode, fliter_bot=fliter_bot)
|
return find_messages(
|
||||||
|
message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode, fliter_bot=fliter_bot
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_raw_msg_by_timestamp_with_chat_users(
|
def get_raw_msg_by_timestamp_with_chat_users(
|
||||||
@@ -88,8 +104,8 @@ def get_actions_by_timestamp_with_chat(
|
|||||||
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录,按时间升序排序,返回动作记录列表"""
|
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录,按时间升序排序,返回动作记录列表"""
|
||||||
query = ActionRecords.select().where(
|
query = ActionRecords.select().where(
|
||||||
(ActionRecords.chat_id == chat_id)
|
(ActionRecords.chat_id == chat_id)
|
||||||
& (ActionRecords.time > timestamp_start)
|
& (ActionRecords.time > timestamp_start) # type: ignore
|
||||||
& (ActionRecords.time < timestamp_end)
|
& (ActionRecords.time < timestamp_end) # type: ignore
|
||||||
)
|
)
|
||||||
|
|
||||||
if limit > 0:
|
if limit > 0:
|
||||||
@@ -113,8 +129,8 @@ def get_actions_by_timestamp_with_chat_inclusive(
|
|||||||
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录(包含边界),按时间升序排序,返回动作记录列表"""
|
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录(包含边界),按时间升序排序,返回动作记录列表"""
|
||||||
query = ActionRecords.select().where(
|
query = ActionRecords.select().where(
|
||||||
(ActionRecords.chat_id == chat_id)
|
(ActionRecords.chat_id == chat_id)
|
||||||
& (ActionRecords.time >= timestamp_start)
|
& (ActionRecords.time >= timestamp_start) # type: ignore
|
||||||
& (ActionRecords.time <= timestamp_end)
|
& (ActionRecords.time <= timestamp_end) # type: ignore
|
||||||
)
|
)
|
||||||
|
|
||||||
if limit > 0:
|
if limit > 0:
|
||||||
@@ -190,7 +206,7 @@ def get_raw_msg_before_timestamp_with_users(timestamp: float, person_ids: list,
|
|||||||
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
|
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
|
||||||
|
|
||||||
|
|
||||||
def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: float = None) -> int:
|
def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: Optional[float] = None) -> int:
|
||||||
"""
|
"""
|
||||||
检查特定聊天从 timestamp_start (不含) 到 timestamp_end (不含) 之间有多少新消息。
|
检查特定聊天从 timestamp_start (不含) 到 timestamp_end (不含) 之间有多少新消息。
|
||||||
如果 timestamp_end 为 None,则检查从 timestamp_start (不含) 到当前时间的消息。
|
如果 timestamp_end 为 None,则检查从 timestamp_start (不含) 到当前时间的消息。
|
||||||
@@ -227,7 +243,7 @@ def _build_readable_messages_internal(
|
|||||||
merge_messages: bool = False,
|
merge_messages: bool = False,
|
||||||
timestamp_mode: str = "relative",
|
timestamp_mode: str = "relative",
|
||||||
truncate: bool = False,
|
truncate: bool = False,
|
||||||
pic_id_mapping: Dict[str, str] = None,
|
pic_id_mapping: Optional[Dict[str, str]] = None,
|
||||||
pic_counter: int = 1,
|
pic_counter: int = 1,
|
||||||
show_pic: bool = True,
|
show_pic: bool = True,
|
||||||
) -> Tuple[str, List[Tuple[float, str, str]], Dict[str, str], int]:
|
) -> Tuple[str, List[Tuple[float, str, str]], Dict[str, str], int]:
|
||||||
@@ -249,7 +265,7 @@ def _build_readable_messages_internal(
|
|||||||
if not messages:
|
if not messages:
|
||||||
return "", [], pic_id_mapping or {}, pic_counter
|
return "", [], pic_id_mapping or {}, pic_counter
|
||||||
|
|
||||||
message_details_raw: List[Tuple[float, str, str]] = []
|
message_details_raw: List[Tuple[float, str, str, bool]] = []
|
||||||
|
|
||||||
# 使用传入的映射字典,如果没有则创建新的
|
# 使用传入的映射字典,如果没有则创建新的
|
||||||
if pic_id_mapping is None:
|
if pic_id_mapping is None:
|
||||||
@@ -280,7 +296,7 @@ def _build_readable_messages_internal(
|
|||||||
# 检查是否是动作记录
|
# 检查是否是动作记录
|
||||||
if msg.get("is_action_record", False):
|
if msg.get("is_action_record", False):
|
||||||
is_action = True
|
is_action = True
|
||||||
timestamp = msg.get("time")
|
timestamp: float = msg.get("time") # type: ignore
|
||||||
content = msg.get("display_message", "")
|
content = msg.get("display_message", "")
|
||||||
# 对于动作记录,也处理图片ID
|
# 对于动作记录,也处理图片ID
|
||||||
content = process_pic_ids(content)
|
content = process_pic_ids(content)
|
||||||
@@ -304,9 +320,10 @@ def _build_readable_messages_internal(
|
|||||||
user_nickname = user_info.get("user_nickname")
|
user_nickname = user_info.get("user_nickname")
|
||||||
user_cardname = user_info.get("user_cardname")
|
user_cardname = user_info.get("user_cardname")
|
||||||
|
|
||||||
timestamp = msg.get("time")
|
timestamp: float = msg.get("time") # type: ignore
|
||||||
|
content: str
|
||||||
if msg.get("display_message"):
|
if msg.get("display_message"):
|
||||||
content = msg.get("display_message")
|
content = msg.get("display_message", "")
|
||||||
else:
|
else:
|
||||||
content = msg.get("processed_plain_text", "") # 默认空字符串
|
content = msg.get("processed_plain_text", "") # 默认空字符串
|
||||||
|
|
||||||
@@ -326,10 +343,11 @@ def _build_readable_messages_internal(
|
|||||||
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
||||||
person_info_manager = get_person_info_manager()
|
person_info_manager = get_person_info_manager()
|
||||||
# 根据 replace_bot_name 参数决定是否替换机器人名称
|
# 根据 replace_bot_name 参数决定是否替换机器人名称
|
||||||
|
person_name: str
|
||||||
if replace_bot_name and user_id == global_config.bot.qq_account:
|
if replace_bot_name and user_id == global_config.bot.qq_account:
|
||||||
person_name = f"{global_config.bot.nickname}(你)"
|
person_name = f"{global_config.bot.nickname}(你)"
|
||||||
else:
|
else:
|
||||||
person_name = person_info_manager.get_value_sync(person_id, "person_name")
|
person_name = person_info_manager.get_value_sync(person_id, "person_name") # type: ignore
|
||||||
|
|
||||||
# 如果 person_name 未设置,则使用消息中的 nickname 或默认名称
|
# 如果 person_name 未设置,则使用消息中的 nickname 或默认名称
|
||||||
if not person_name:
|
if not person_name:
|
||||||
@@ -344,12 +362,10 @@ def _build_readable_messages_internal(
|
|||||||
reply_pattern = r"回复<([^:<>]+):([^:<>]+)>"
|
reply_pattern = r"回复<([^:<>]+):([^:<>]+)>"
|
||||||
match = re.search(reply_pattern, content)
|
match = re.search(reply_pattern, content)
|
||||||
if match:
|
if match:
|
||||||
aaa = match.group(1)
|
aaa: str = match[1]
|
||||||
bbb = match.group(2)
|
bbb: str = match[2]
|
||||||
reply_person_id = PersonInfoManager.get_person_id(platform, bbb)
|
reply_person_id = PersonInfoManager.get_person_id(platform, bbb)
|
||||||
reply_person_name = person_info_manager.get_value_sync(reply_person_id, "person_name")
|
reply_person_name = person_info_manager.get_value_sync(reply_person_id, "person_name") or aaa
|
||||||
if not reply_person_name:
|
|
||||||
reply_person_name = aaa
|
|
||||||
# 在内容前加上回复信息
|
# 在内容前加上回复信息
|
||||||
content = re.sub(reply_pattern, lambda m, name=reply_person_name: f"回复 {name}", content, count=1)
|
content = re.sub(reply_pattern, lambda m, name=reply_person_name: f"回复 {name}", content, count=1)
|
||||||
|
|
||||||
@@ -364,17 +380,14 @@ def _build_readable_messages_internal(
|
|||||||
aaa = m.group(1)
|
aaa = m.group(1)
|
||||||
bbb = m.group(2)
|
bbb = m.group(2)
|
||||||
at_person_id = PersonInfoManager.get_person_id(platform, bbb)
|
at_person_id = PersonInfoManager.get_person_id(platform, bbb)
|
||||||
at_person_name = person_info_manager.get_value_sync(at_person_id, "person_name")
|
at_person_name = person_info_manager.get_value_sync(at_person_id, "person_name") or aaa
|
||||||
if not at_person_name:
|
|
||||||
at_person_name = aaa
|
|
||||||
new_content += f"@{at_person_name}"
|
new_content += f"@{at_person_name}"
|
||||||
last_end = m.end()
|
last_end = m.end()
|
||||||
new_content += content[last_end:]
|
new_content += content[last_end:]
|
||||||
content = new_content
|
content = new_content
|
||||||
|
|
||||||
target_str = "这是QQ的一个功能,用于提及某人,但没那么明显"
|
target_str = "这是QQ的一个功能,用于提及某人,但没那么明显"
|
||||||
if target_str in content:
|
if target_str in content and random.random() < 0.6:
|
||||||
if random.random() < 0.6:
|
|
||||||
content = content.replace(target_str, "")
|
content = content.replace(target_str, "")
|
||||||
|
|
||||||
if content != "":
|
if content != "":
|
||||||
@@ -525,6 +538,7 @@ def _build_readable_messages_internal(
|
|||||||
|
|
||||||
|
|
||||||
def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
|
def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
|
||||||
|
# sourcery skip: use-contextlib-suppress
|
||||||
"""
|
"""
|
||||||
构建图片映射信息字符串,显示图片的具体描述内容
|
构建图片映射信息字符串,显示图片的具体描述内容
|
||||||
|
|
||||||
@@ -584,7 +598,6 @@ def build_readable_actions(actions: List[Dict[str, Any]]) -> str:
|
|||||||
if action_name == "no_action" or action_name == "no_reply":
|
if action_name == "no_action" or action_name == "no_reply":
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
||||||
action_prompt_display = action.get("action_prompt_display", "无具体内容")
|
action_prompt_display = action.get("action_prompt_display", "无具体内容")
|
||||||
|
|
||||||
time_diff_seconds = current_time - action_time
|
time_diff_seconds = current_time - action_time
|
||||||
@@ -616,9 +629,7 @@ async def build_readable_messages_with_list(
|
|||||||
messages, replace_bot_name, merge_messages, timestamp_mode, truncate
|
messages, replace_bot_name, merge_messages, timestamp_mode, truncate
|
||||||
)
|
)
|
||||||
|
|
||||||
# 生成图片映射信息并添加到最前面
|
if pic_mapping_info := build_pic_mapping_info(pic_id_mapping):
|
||||||
pic_mapping_info = build_pic_mapping_info(pic_id_mapping)
|
|
||||||
if pic_mapping_info:
|
|
||||||
formatted_string = f"{pic_mapping_info}\n\n{formatted_string}"
|
formatted_string = f"{pic_mapping_info}\n\n{formatted_string}"
|
||||||
|
|
||||||
return formatted_string, details_list
|
return formatted_string, details_list
|
||||||
@@ -633,7 +644,7 @@ def build_readable_messages(
|
|||||||
truncate: bool = False,
|
truncate: bool = False,
|
||||||
show_actions: bool = False,
|
show_actions: bool = False,
|
||||||
show_pic: bool = True,
|
show_pic: bool = True,
|
||||||
) -> str:
|
) -> str: # sourcery skip: extract-method
|
||||||
"""
|
"""
|
||||||
将消息列表转换为可读的文本格式。
|
将消息列表转换为可读的文本格式。
|
||||||
如果提供了 read_mark,则在相应位置插入已读标记。
|
如果提供了 read_mark,则在相应位置插入已读标记。
|
||||||
@@ -756,9 +767,7 @@ def build_readable_messages(
|
|||||||
# 组合结果
|
# 组合结果
|
||||||
result_parts = []
|
result_parts = []
|
||||||
if pic_mapping_info:
|
if pic_mapping_info:
|
||||||
result_parts.append(pic_mapping_info)
|
result_parts.extend((pic_mapping_info, "\n"))
|
||||||
result_parts.append("\n")
|
|
||||||
|
|
||||||
if formatted_before and formatted_after:
|
if formatted_before and formatted_after:
|
||||||
result_parts.extend([formatted_before, read_mark_line, formatted_after])
|
result_parts.extend([formatted_before, read_mark_line, formatted_after])
|
||||||
elif formatted_before:
|
elif formatted_before:
|
||||||
@@ -831,8 +840,9 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
|
|||||||
platform = msg.get("chat_info_platform")
|
platform = msg.get("chat_info_platform")
|
||||||
user_id = msg.get("user_id")
|
user_id = msg.get("user_id")
|
||||||
_timestamp = msg.get("time")
|
_timestamp = msg.get("time")
|
||||||
|
content: str = ""
|
||||||
if msg.get("display_message"):
|
if msg.get("display_message"):
|
||||||
content = msg.get("display_message")
|
content = msg.get("display_message", "")
|
||||||
else:
|
else:
|
||||||
content = msg.get("processed_plain_text", "")
|
content = msg.get("processed_plain_text", "")
|
||||||
|
|
||||||
@@ -920,17 +930,14 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
|
|||||||
person_ids_set = set() # 使用集合来自动去重
|
person_ids_set = set() # 使用集合来自动去重
|
||||||
|
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
platform = msg.get("user_platform")
|
platform: str = msg.get("user_platform") # type: ignore
|
||||||
user_id = msg.get("user_id")
|
user_id: str = msg.get("user_id") # type: ignore
|
||||||
|
|
||||||
# 检查必要信息是否存在 且 不是机器人自己
|
# 检查必要信息是否存在 且 不是机器人自己
|
||||||
if not all([platform, user_id]) or user_id == global_config.bot.qq_account:
|
if not all([platform, user_id]) or user_id == global_config.bot.qq_account:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
if person_id := PersonInfoManager.get_person_id(platform, user_id):
|
||||||
|
|
||||||
# 只有当获取到有效 person_id 时才添加
|
|
||||||
if person_id:
|
|
||||||
person_ids_set.add(person_id)
|
person_ids_set.add(person_id)
|
||||||
|
|
||||||
return list(person_ids_set) # 将集合转换为列表返回
|
return list(person_ids_set) # 将集合转换为列表返回
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
|
import ast
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, TypeVar, List, Union, Tuple
|
|
||||||
import ast
|
from typing import Any, Dict, TypeVar, List, Union, Tuple, Optional
|
||||||
|
|
||||||
# 定义类型变量用于泛型类型提示
|
# 定义类型变量用于泛型类型提示
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
@@ -30,16 +31,12 @@ def safe_json_loads(json_str: str, default_value: T = None) -> Union[Any, T]:
|
|||||||
# 尝试标准的 JSON 解析
|
# 尝试标准的 JSON 解析
|
||||||
return json.loads(json_str)
|
return json.loads(json_str)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
# 如果标准解析失败,尝试将单引号替换为双引号再解析
|
# 如果标准解析失败,尝试用 ast.literal_eval 解析
|
||||||
# (注意:这种替换可能不安全,如果字符串内容本身包含引号)
|
|
||||||
# 更安全的方式是用 ast.literal_eval
|
|
||||||
try:
|
try:
|
||||||
# logger.debug(f"标准JSON解析失败,尝试用 ast.literal_eval 解析: {json_str[:100]}...")
|
# logger.debug(f"标准JSON解析失败,尝试用 ast.literal_eval 解析: {json_str[:100]}...")
|
||||||
result = ast.literal_eval(json_str)
|
result = ast.literal_eval(json_str)
|
||||||
# 确保结果是字典(因为我们通常期望参数是字典)
|
|
||||||
if isinstance(result, dict):
|
if isinstance(result, dict):
|
||||||
return result
|
return result
|
||||||
else:
|
|
||||||
logger.warning(f"ast.literal_eval 解析成功但结果不是字典: {type(result)}, 内容: {result}")
|
logger.warning(f"ast.literal_eval 解析成功但结果不是字典: {type(result)}, 内容: {result}")
|
||||||
return default_value
|
return default_value
|
||||||
except (ValueError, SyntaxError, MemoryError, RecursionError) as ast_e:
|
except (ValueError, SyntaxError, MemoryError, RecursionError) as ast_e:
|
||||||
@@ -53,7 +50,9 @@ def safe_json_loads(json_str: str, default_value: T = None) -> Union[Any, T]:
|
|||||||
return default_value
|
return default_value
|
||||||
|
|
||||||
|
|
||||||
def extract_tool_call_arguments(tool_call: Dict[str, Any], default_value: Dict[str, Any] = None) -> Dict[str, Any]:
|
def extract_tool_call_arguments(
|
||||||
|
tool_call: Dict[str, Any], default_value: Optional[Dict[str, Any]] = None
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
从LLM工具调用对象中提取参数
|
从LLM工具调用对象中提取参数
|
||||||
|
|
||||||
@@ -77,13 +76,11 @@ def extract_tool_call_arguments(tool_call: Dict[str, Any], default_value: Dict[s
|
|||||||
logger.error(f"工具调用缺少function字段或格式不正确: {tool_call}")
|
logger.error(f"工具调用缺少function字段或格式不正确: {tool_call}")
|
||||||
return default_result
|
return default_result
|
||||||
|
|
||||||
# 提取arguments
|
if arguments_str := function_data.get("arguments", "{}"):
|
||||||
arguments_str = function_data.get("arguments", "{}")
|
|
||||||
if not arguments_str:
|
|
||||||
return default_result
|
|
||||||
|
|
||||||
# 解析JSON
|
# 解析JSON
|
||||||
return safe_json_loads(arguments_str, default_result)
|
return safe_json_loads(arguments_str, default_result)
|
||||||
|
else:
|
||||||
|
return default_result
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"提取工具调用参数时出错: {e}")
|
logger.error(f"提取工具调用参数时出错: {e}")
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
from typing import Dict, Any, Optional, List, Union
|
|
||||||
import re
|
import re
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import contextvars
|
import contextvars
|
||||||
from src.common.logger import get_logger
|
|
||||||
|
|
||||||
# import traceback
|
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import Dict, Any, Optional, List, Union
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
|
|
||||||
@@ -32,6 +32,7 @@ class PromptContext:
|
|||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def async_scope(self, context_id: Optional[str] = None):
|
async def async_scope(self, context_id: Optional[str] = None):
|
||||||
|
# sourcery skip: hoist-statement-from-if, use-contextlib-suppress
|
||||||
"""创建一个异步的临时提示模板作用域"""
|
"""创建一个异步的临时提示模板作用域"""
|
||||||
# 保存当前上下文并设置新上下文
|
# 保存当前上下文并设置新上下文
|
||||||
if context_id is not None:
|
if context_id is not None:
|
||||||
@@ -88,8 +89,7 @@ class PromptContext:
|
|||||||
async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None:
|
async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None:
|
||||||
"""异步注册提示模板到指定作用域"""
|
"""异步注册提示模板到指定作用域"""
|
||||||
async with self._context_lock:
|
async with self._context_lock:
|
||||||
target_context = context_id or self._current_context
|
if target_context := context_id or self._current_context:
|
||||||
if target_context:
|
|
||||||
self._context_prompts.setdefault(target_context, {})[prompt.name] = prompt
|
self._context_prompts.setdefault(target_context, {})[prompt.name] = prompt
|
||||||
|
|
||||||
|
|
||||||
@@ -151,7 +151,7 @@ class Prompt(str):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _process_escaped_braces(template) -> str:
|
def _process_escaped_braces(template) -> str:
|
||||||
"""处理模板中的转义花括号,将 \{ 和 \} 替换为临时标记"""
|
"""处理模板中的转义花括号,将 \{ 和 \} 替换为临时标记""" # type: ignore
|
||||||
# 如果传入的是列表,将其转换为字符串
|
# 如果传入的是列表,将其转换为字符串
|
||||||
if isinstance(template, list):
|
if isinstance(template, list):
|
||||||
template = "\n".join(str(item) for item in template)
|
template = "\n".join(str(item) for item in template)
|
||||||
@@ -195,13 +195,7 @@ class Prompt(str):
|
|||||||
obj._kwargs = kwargs
|
obj._kwargs = kwargs
|
||||||
|
|
||||||
# 修改自动注册逻辑
|
# 修改自动注册逻辑
|
||||||
if should_register:
|
if should_register and not global_prompt_manager._context._current_context:
|
||||||
if global_prompt_manager._context._current_context:
|
|
||||||
# 如果存在当前上下文,则注册到上下文中
|
|
||||||
# asyncio.create_task(global_prompt_manager._context.register_async(obj))
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
# 否则注册到全局管理器
|
|
||||||
global_prompt_manager.register(obj)
|
global_prompt_manager.register(obj)
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
@@ -276,15 +270,13 @@ class Prompt(str):
|
|||||||
self.name,
|
self.name,
|
||||||
args=list(args) if args else self._args,
|
args=list(args) if args else self._args,
|
||||||
_should_register=False,
|
_should_register=False,
|
||||||
**kwargs if kwargs else self._kwargs,
|
**kwargs or self._kwargs,
|
||||||
)
|
)
|
||||||
# print(f"prompt build result: {ret} name: {ret.name} ")
|
# print(f"prompt build result: {ret} name: {ret.name} ")
|
||||||
return str(ret)
|
return str(ret)
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
if self._kwargs or self._args:
|
return super().__str__() if self._kwargs or self._args else self.template
|
||||||
return super().__str__()
|
|
||||||
return self.template
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"Prompt(template='{self.template}', name='{self.name}')"
|
return f"Prompt(template='{self.template}', name='{self.name}')"
|
||||||
|
|||||||
@@ -1,18 +1,17 @@
|
|||||||
from collections import defaultdict
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from typing import Any, Dict, Tuple, List
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import glob
|
import glob
|
||||||
|
|
||||||
|
from collections import defaultdict
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Any, Dict, Tuple, List
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
from src.common.database.database import db
|
||||||
|
from src.common.database.database_model import OnlineTime, LLMUsage, Messages
|
||||||
from src.manager.async_task_manager import AsyncTask
|
from src.manager.async_task_manager import AsyncTask
|
||||||
|
|
||||||
from ...common.database.database import db # This db is the Peewee database instance
|
|
||||||
from ...common.database.database_model import OnlineTime, LLMUsage, Messages # Import the Peewee model
|
|
||||||
from src.manager.local_store_manager import local_storage
|
from src.manager.local_store_manager import local_storage
|
||||||
|
|
||||||
logger = get_logger("maibot_statistic")
|
logger = get_logger("maibot_statistic")
|
||||||
@@ -76,14 +75,14 @@ class OnlineTimeRecordTask(AsyncTask):
|
|||||||
with db.atomic(): # Use atomic operations for schema changes
|
with db.atomic(): # Use atomic operations for schema changes
|
||||||
OnlineTime.create_table(safe=True) # Creates table if it doesn't exist, Peewee handles indexes from model
|
OnlineTime.create_table(safe=True) # Creates table if it doesn't exist, Peewee handles indexes from model
|
||||||
|
|
||||||
async def run(self):
|
async def run(self): # sourcery skip: use-named-expression
|
||||||
try:
|
try:
|
||||||
current_time = datetime.now()
|
current_time = datetime.now()
|
||||||
extended_end_time = current_time + timedelta(minutes=1)
|
extended_end_time = current_time + timedelta(minutes=1)
|
||||||
|
|
||||||
if self.record_id:
|
if self.record_id:
|
||||||
# 如果有记录,则更新结束时间
|
# 如果有记录,则更新结束时间
|
||||||
query = OnlineTime.update(end_timestamp=extended_end_time).where(OnlineTime.id == self.record_id)
|
query = OnlineTime.update(end_timestamp=extended_end_time).where(OnlineTime.id == self.record_id) # type: ignore
|
||||||
updated_rows = query.execute()
|
updated_rows = query.execute()
|
||||||
if updated_rows == 0:
|
if updated_rows == 0:
|
||||||
# Record might have been deleted or ID is stale, try to find/create
|
# Record might have been deleted or ID is stale, try to find/create
|
||||||
@@ -94,7 +93,7 @@ class OnlineTimeRecordTask(AsyncTask):
|
|||||||
# Look for a record whose end_timestamp is recent enough to be considered ongoing
|
# Look for a record whose end_timestamp is recent enough to be considered ongoing
|
||||||
recent_record = (
|
recent_record = (
|
||||||
OnlineTime.select()
|
OnlineTime.select()
|
||||||
.where(OnlineTime.end_timestamp >= (current_time - timedelta(minutes=1)))
|
.where(OnlineTime.end_timestamp >= (current_time - timedelta(minutes=1))) # type: ignore
|
||||||
.order_by(OnlineTime.end_timestamp.desc())
|
.order_by(OnlineTime.end_timestamp.desc())
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
@@ -123,15 +122,15 @@ def _format_online_time(online_seconds: int) -> str:
|
|||||||
:param online_seconds: 在线时间(秒)
|
:param online_seconds: 在线时间(秒)
|
||||||
:return: 格式化后的在线时间字符串
|
:return: 格式化后的在线时间字符串
|
||||||
"""
|
"""
|
||||||
total_oneline_time = timedelta(seconds=online_seconds)
|
total_online_time = timedelta(seconds=online_seconds)
|
||||||
|
|
||||||
days = total_oneline_time.days
|
days = total_online_time.days
|
||||||
hours = total_oneline_time.seconds // 3600
|
hours = total_online_time.seconds // 3600
|
||||||
minutes = (total_oneline_time.seconds // 60) % 60
|
minutes = (total_online_time.seconds // 60) % 60
|
||||||
seconds = total_oneline_time.seconds % 60
|
seconds = total_online_time.seconds % 60
|
||||||
if days > 0:
|
if days > 0:
|
||||||
# 如果在线时间超过1天,则格式化为"X天X小时X分钟"
|
# 如果在线时间超过1天,则格式化为"X天X小时X分钟"
|
||||||
return f"{total_oneline_time.days}天{hours}小时{minutes}分钟{seconds}秒"
|
return f"{total_online_time.days}天{hours}小时{minutes}分钟{seconds}秒"
|
||||||
elif hours > 0:
|
elif hours > 0:
|
||||||
# 如果在线时间超过1小时,则格式化为"X小时X分钟X秒"
|
# 如果在线时间超过1小时,则格式化为"X小时X分钟X秒"
|
||||||
return f"{hours}小时{minutes}分钟{seconds}秒"
|
return f"{hours}小时{minutes}分钟{seconds}秒"
|
||||||
@@ -163,7 +162,7 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
if "deploy_time" in local_storage:
|
if "deploy_time" in local_storage:
|
||||||
# 如果存在部署时间,则使用该时间作为全量统计的起始时间
|
# 如果存在部署时间,则使用该时间作为全量统计的起始时间
|
||||||
deploy_time = datetime.fromtimestamp(local_storage["deploy_time"])
|
deploy_time = datetime.fromtimestamp(local_storage["deploy_time"]) # type: ignore
|
||||||
else:
|
else:
|
||||||
# 否则,使用最大时间范围,并记录部署时间为当前时间
|
# 否则,使用最大时间范围,并记录部署时间为当前时间
|
||||||
deploy_time = datetime(2000, 1, 1)
|
deploy_time = datetime(2000, 1, 1)
|
||||||
@@ -252,7 +251,7 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
|
|
||||||
# 创建后台任务,不等待完成
|
# 创建后台任务,不等待完成
|
||||||
collect_task = asyncio.create_task(
|
collect_task = asyncio.create_task(
|
||||||
loop.run_in_executor(executor, self._collect_all_statistics, now)
|
loop.run_in_executor(executor, self._collect_all_statistics, now) # type: ignore
|
||||||
)
|
)
|
||||||
|
|
||||||
stats = await collect_task
|
stats = await collect_task
|
||||||
@@ -260,8 +259,8 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
|
|
||||||
# 创建并发的输出任务
|
# 创建并发的输出任务
|
||||||
output_tasks = [
|
output_tasks = [
|
||||||
asyncio.create_task(loop.run_in_executor(executor, self._statistic_console_output, stats, now)),
|
asyncio.create_task(loop.run_in_executor(executor, self._statistic_console_output, stats, now)), # type: ignore
|
||||||
asyncio.create_task(loop.run_in_executor(executor, self._generate_html_report, stats, now)),
|
asyncio.create_task(loop.run_in_executor(executor, self._generate_html_report, stats, now)), # type: ignore
|
||||||
]
|
]
|
||||||
|
|
||||||
# 等待所有输出任务完成
|
# 等待所有输出任务完成
|
||||||
@@ -320,7 +319,7 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
# 以最早的时间戳为起始时间获取记录
|
# 以最早的时间戳为起始时间获取记录
|
||||||
# Assuming LLMUsage.timestamp is a DateTimeField
|
# Assuming LLMUsage.timestamp is a DateTimeField
|
||||||
query_start_time = collect_period[-1][1]
|
query_start_time = collect_period[-1][1]
|
||||||
for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time):
|
for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time): # type: ignore
|
||||||
record_timestamp = record.timestamp # This is already a datetime object
|
record_timestamp = record.timestamp # This is already a datetime object
|
||||||
for idx, (_, period_start) in enumerate(collect_period):
|
for idx, (_, period_start) in enumerate(collect_period):
|
||||||
if record_timestamp >= period_start:
|
if record_timestamp >= period_start:
|
||||||
@@ -388,7 +387,7 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
|
|
||||||
query_start_time = collect_period[-1][1]
|
query_start_time = collect_period[-1][1]
|
||||||
# Assuming OnlineTime.end_timestamp is a DateTimeField
|
# Assuming OnlineTime.end_timestamp is a DateTimeField
|
||||||
for record in OnlineTime.select().where(OnlineTime.end_timestamp >= query_start_time):
|
for record in OnlineTime.select().where(OnlineTime.end_timestamp >= query_start_time): # type: ignore
|
||||||
# record.end_timestamp and record.start_timestamp are datetime objects
|
# record.end_timestamp and record.start_timestamp are datetime objects
|
||||||
record_end_timestamp = record.end_timestamp
|
record_end_timestamp = record.end_timestamp
|
||||||
record_start_timestamp = record.start_timestamp
|
record_start_timestamp = record.start_timestamp
|
||||||
@@ -428,7 +427,7 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
}
|
}
|
||||||
|
|
||||||
query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp)
|
query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp)
|
||||||
for message in Messages.select().where(Messages.time >= query_start_timestamp):
|
for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore
|
||||||
message_time_ts = message.time # This is a float timestamp
|
message_time_ts = message.time # This is a float timestamp
|
||||||
|
|
||||||
chat_id = None
|
chat_id = None
|
||||||
@@ -661,7 +660,7 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
|
|
||||||
if "last_full_statistics" in local_storage:
|
if "last_full_statistics" in local_storage:
|
||||||
# 如果存在上次完整统计数据,则使用该数据进行增量统计
|
# 如果存在上次完整统计数据,则使用该数据进行增量统计
|
||||||
last_stat = local_storage["last_full_statistics"] # 上次完整统计数据
|
last_stat: Dict[str, Any] = local_storage["last_full_statistics"] # 上次完整统计数据 # type: ignore
|
||||||
|
|
||||||
self.name_mapping = last_stat["name_mapping"] # 上次完整统计数据的名称映射
|
self.name_mapping = last_stat["name_mapping"] # 上次完整统计数据的名称映射
|
||||||
last_all_time_stat = last_stat["stat_data"] # 上次完整统计的统计数据
|
last_all_time_stat = last_stat["stat_data"] # 上次完整统计的统计数据
|
||||||
@@ -727,6 +726,7 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
return stat
|
return stat
|
||||||
|
|
||||||
def _convert_defaultdict_to_dict(self, data):
|
def _convert_defaultdict_to_dict(self, data):
|
||||||
|
# sourcery skip: dict-comprehension, extract-duplicate-method, inline-immediately-returned-variable, merge-duplicate-blocks
|
||||||
"""递归转换defaultdict为普通dict"""
|
"""递归转换defaultdict为普通dict"""
|
||||||
if isinstance(data, defaultdict):
|
if isinstance(data, defaultdict):
|
||||||
# 转换defaultdict为普通dict
|
# 转换defaultdict为普通dict
|
||||||
@@ -812,8 +812,7 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
# 全局阶段平均时间
|
# 全局阶段平均时间
|
||||||
if stats[FOCUS_AVG_TIMES_BY_STAGE]:
|
if stats[FOCUS_AVG_TIMES_BY_STAGE]:
|
||||||
output.append("全局阶段平均时间:")
|
output.append("全局阶段平均时间:")
|
||||||
for stage, avg_time in stats[FOCUS_AVG_TIMES_BY_STAGE].items():
|
output.extend(f" {stage}: {avg_time:.3f}秒" for stage, avg_time in stats[FOCUS_AVG_TIMES_BY_STAGE].items())
|
||||||
output.append(f" {stage}: {avg_time:.3f}秒")
|
|
||||||
output.append("")
|
output.append("")
|
||||||
|
|
||||||
# Action类型比例
|
# Action类型比例
|
||||||
@@ -1050,7 +1049,7 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
]
|
]
|
||||||
|
|
||||||
tab_content_list.append(
|
tab_content_list.append(
|
||||||
_format_stat_data(stat["all_time"], "all_time", datetime.fromtimestamp(local_storage["deploy_time"]))
|
_format_stat_data(stat["all_time"], "all_time", datetime.fromtimestamp(local_storage["deploy_time"])) # type: ignore
|
||||||
)
|
)
|
||||||
|
|
||||||
# 添加Focus统计内容
|
# 添加Focus统计内容
|
||||||
@@ -1212,6 +1211,7 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
f.write(html_template)
|
f.write(html_template)
|
||||||
|
|
||||||
def _generate_focus_tab(self, stat: dict[str, Any]) -> str:
|
def _generate_focus_tab(self, stat: dict[str, Any]) -> str:
|
||||||
|
# sourcery skip: for-append-to-extend, list-comprehension, use-any
|
||||||
"""生成Focus统计独立分页的HTML内容"""
|
"""生成Focus统计独立分页的HTML内容"""
|
||||||
|
|
||||||
# 为每个时间段准备Focus数据
|
# 为每个时间段准备Focus数据
|
||||||
@@ -1313,12 +1313,11 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
# 聊天流Action选择比例对比表(横向表格)
|
# 聊天流Action选择比例对比表(横向表格)
|
||||||
focus_chat_action_ratios_rows = ""
|
focus_chat_action_ratios_rows = ""
|
||||||
if stat_data.get("focus_action_ratios_by_chat"):
|
if stat_data.get("focus_action_ratios_by_chat"):
|
||||||
# 获取所有action类型(按全局频率排序)
|
if all_action_types_for_ratio := sorted(
|
||||||
all_action_types_for_ratio = sorted(
|
stat_data[FOCUS_ACTION_RATIOS].keys(),
|
||||||
stat_data[FOCUS_ACTION_RATIOS].keys(), key=lambda x: stat_data[FOCUS_ACTION_RATIOS][x], reverse=True
|
key=lambda x: stat_data[FOCUS_ACTION_RATIOS][x],
|
||||||
)
|
reverse=True,
|
||||||
|
):
|
||||||
if all_action_types_for_ratio:
|
|
||||||
# 为每个聊天流生成数据行(按循环数排序)
|
# 为每个聊天流生成数据行(按循环数排序)
|
||||||
chat_ratio_rows = []
|
chat_ratio_rows = []
|
||||||
for chat_id in sorted(
|
for chat_id in sorted(
|
||||||
@@ -1379,16 +1378,11 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
if period_name == "all_time":
|
if period_name == "all_time":
|
||||||
from src.manager.local_store_manager import local_storage
|
from src.manager.local_store_manager import local_storage
|
||||||
|
|
||||||
start_time = datetime.fromtimestamp(local_storage["deploy_time"])
|
start_time = datetime.fromtimestamp(local_storage["deploy_time"]) # type: ignore
|
||||||
time_range = (
|
|
||||||
f"{start_time.strftime('%Y-%m-%d %H:%M:%S')} ~ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
start_time = datetime.now() - period_delta
|
start_time = datetime.now() - period_delta
|
||||||
time_range = (
|
|
||||||
f"{start_time.strftime('%Y-%m-%d %H:%M:%S')} ~ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
time_range = f"{start_time.strftime('%Y-%m-%d %H:%M:%S')} ~ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||||
# 生成该时间段的Focus统计HTML
|
# 生成该时间段的Focus统计HTML
|
||||||
section_html = f"""
|
section_html = f"""
|
||||||
<div class="focus-period-section">
|
<div class="focus-period-section">
|
||||||
@@ -1681,16 +1675,10 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
if period_name == "all_time":
|
if period_name == "all_time":
|
||||||
from src.manager.local_store_manager import local_storage
|
from src.manager.local_store_manager import local_storage
|
||||||
|
|
||||||
start_time = datetime.fromtimestamp(local_storage["deploy_time"])
|
start_time = datetime.fromtimestamp(local_storage["deploy_time"]) # type: ignore
|
||||||
time_range = (
|
|
||||||
f"{start_time.strftime('%Y-%m-%d %H:%M:%S')} ~ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
start_time = datetime.now() - period_delta
|
start_time = datetime.now() - period_delta
|
||||||
time_range = (
|
time_range = f"{start_time.strftime('%Y-%m-%d %H:%M:%S')} ~ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||||
f"{start_time.strftime('%Y-%m-%d %H:%M:%S')} ~ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 生成该时间段的版本对比HTML
|
# 生成该时间段的版本对比HTML
|
||||||
section_html = f"""
|
section_html = f"""
|
||||||
<div class="version-period-section">
|
<div class="version-period-section">
|
||||||
@@ -1865,7 +1853,7 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
|
|
||||||
# 查询LLM使用记录
|
# 查询LLM使用记录
|
||||||
query_start_time = start_time
|
query_start_time = start_time
|
||||||
for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time):
|
for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time): # type: ignore
|
||||||
record_time = record.timestamp
|
record_time = record.timestamp
|
||||||
|
|
||||||
# 找到对应的时间间隔索引
|
# 找到对应的时间间隔索引
|
||||||
@@ -1875,7 +1863,7 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
if 0 <= interval_index < len(time_points):
|
if 0 <= interval_index < len(time_points):
|
||||||
# 累加总花费数据
|
# 累加总花费数据
|
||||||
cost = record.cost or 0.0
|
cost = record.cost or 0.0
|
||||||
total_cost_data[interval_index] += cost
|
total_cost_data[interval_index] += cost # type: ignore
|
||||||
|
|
||||||
# 累加按模型分类的花费
|
# 累加按模型分类的花费
|
||||||
model_name = record.model_name or "unknown"
|
model_name = record.model_name or "unknown"
|
||||||
@@ -1892,7 +1880,7 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
|
|
||||||
# 查询消息记录
|
# 查询消息记录
|
||||||
query_start_timestamp = start_time.timestamp()
|
query_start_timestamp = start_time.timestamp()
|
||||||
for message in Messages.select().where(Messages.time >= query_start_timestamp):
|
for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore
|
||||||
message_time_ts = message.time
|
message_time_ts = message.time
|
||||||
|
|
||||||
# 找到对应的时间间隔索引
|
# 找到对应的时间间隔索引
|
||||||
@@ -1982,6 +1970,7 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
}
|
}
|
||||||
|
|
||||||
def _generate_chart_tab(self, chart_data: dict) -> str:
|
def _generate_chart_tab(self, chart_data: dict) -> str:
|
||||||
|
# sourcery skip: extract-duplicate-method, move-assign-in-block
|
||||||
"""生成图表选项卡HTML内容"""
|
"""生成图表选项卡HTML内容"""
|
||||||
|
|
||||||
# 生成不同颜色的调色板
|
# 生成不同颜色的调色板
|
||||||
@@ -2293,7 +2282,7 @@ class AsyncStatisticOutputTask(AsyncTask):
|
|||||||
|
|
||||||
# 数据收集任务
|
# 数据收集任务
|
||||||
collect_task = asyncio.create_task(
|
collect_task = asyncio.create_task(
|
||||||
loop.run_in_executor(executor, self._collect_all_statistics, now)
|
loop.run_in_executor(executor, self._collect_all_statistics, now) # type: ignore
|
||||||
)
|
)
|
||||||
|
|
||||||
stats = await collect_task
|
stats = await collect_task
|
||||||
@@ -2301,8 +2290,8 @@ class AsyncStatisticOutputTask(AsyncTask):
|
|||||||
|
|
||||||
# 创建并发的输出任务
|
# 创建并发的输出任务
|
||||||
output_tasks = [
|
output_tasks = [
|
||||||
asyncio.create_task(loop.run_in_executor(executor, self._statistic_console_output, stats, now)),
|
asyncio.create_task(loop.run_in_executor(executor, self._statistic_console_output, stats, now)), # type: ignore
|
||||||
asyncio.create_task(loop.run_in_executor(executor, self._generate_html_report, stats, now)),
|
asyncio.create_task(loop.run_in_executor(executor, self._generate_html_report, stats, now)), # type: ignore
|
||||||
]
|
]
|
||||||
|
|
||||||
# 等待所有输出任务完成
|
# 等待所有输出任务完成
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
|
import asyncio
|
||||||
|
|
||||||
from time import perf_counter
|
from time import perf_counter
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Optional, Dict, Callable
|
from typing import Optional, Dict, Callable
|
||||||
import asyncio
|
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
@@ -88,10 +89,10 @@ class Timer:
|
|||||||
|
|
||||||
self.name = name
|
self.name = name
|
||||||
self.storage = storage
|
self.storage = storage
|
||||||
self.elapsed = None
|
self.elapsed: float = None # type: ignore
|
||||||
|
|
||||||
self.auto_unit = auto_unit
|
self.auto_unit = auto_unit
|
||||||
self.start = None
|
self.start: float = None # type: ignore
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _validate_types(name, storage):
|
def _validate_types(name, storage):
|
||||||
@@ -120,7 +121,7 @@ class Timer:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
wrapper = async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
wrapper = async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||||||
wrapper.__timer__ = self # 保留计时器引用
|
wrapper.__timer__ = self # 保留计时器引用 # type: ignore
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
|
|||||||
@@ -7,10 +7,10 @@ import math
|
|||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
|
import jieba
|
||||||
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import jieba
|
|
||||||
from pypinyin import Style, pinyin
|
from pypinyin import Style, pinyin
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
@@ -104,7 +104,7 @@ class ChineseTypoGenerator:
|
|||||||
try:
|
try:
|
||||||
return "\u4e00" <= char <= "\u9fff"
|
return "\u4e00" <= char <= "\u9fff"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(e)
|
logger.debug(str(e))
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _get_pinyin(self, sentence):
|
def _get_pinyin(self, sentence):
|
||||||
@@ -138,7 +138,7 @@ class ChineseTypoGenerator:
|
|||||||
# 如果最后一个字符不是数字,说明可能是轻声或其他特殊情况
|
# 如果最后一个字符不是数字,说明可能是轻声或其他特殊情况
|
||||||
if not py[-1].isdigit():
|
if not py[-1].isdigit():
|
||||||
# 为非数字结尾的拼音添加数字声调1
|
# 为非数字结尾的拼音添加数字声调1
|
||||||
return py + "1"
|
return f"{py}1"
|
||||||
|
|
||||||
base = py[:-1] # 去掉声调
|
base = py[:-1] # 去掉声调
|
||||||
tone = int(py[-1]) # 获取声调
|
tone = int(py[-1]) # 获取声调
|
||||||
|
|||||||
@@ -1,23 +1,21 @@
|
|||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
from collections import Counter
|
|
||||||
|
|
||||||
import jieba
|
import jieba
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from collections import Counter
|
||||||
from maim_message import UserInfo
|
from maim_message import UserInfo
|
||||||
|
from typing import Optional, Tuple, Dict
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
from src.common.message_repository import find_messages, count_messages
|
||||||
# from src.mood.mood_manager import mood_manager
|
from src.config.config import global_config
|
||||||
from ..message_receive.message import MessageRecv
|
from src.chat.message_receive.message import MessageRecv
|
||||||
from src.llm_models.utils_model import LLMRequest
|
|
||||||
from .typo_generator import ChineseTypoGenerator
|
|
||||||
from ...config.config import global_config
|
|
||||||
from ...common.message_repository import find_messages, count_messages
|
|
||||||
from typing import Optional, Tuple, Dict
|
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||||
|
from .typo_generator import ChineseTypoGenerator
|
||||||
|
|
||||||
logger = get_logger("chat_utils")
|
logger = get_logger("chat_utils")
|
||||||
|
|
||||||
@@ -31,11 +29,7 @@ def db_message_to_str(message_dict: dict) -> str:
|
|||||||
logger.debug(f"message_dict: {message_dict}")
|
logger.debug(f"message_dict: {message_dict}")
|
||||||
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message_dict["time"]))
|
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message_dict["time"]))
|
||||||
try:
|
try:
|
||||||
name = "[(%s)%s]%s" % (
|
name = f"[({message_dict['user_id']}){message_dict.get('user_nickname', '')}]{message_dict.get('user_cardname', '')}"
|
||||||
message_dict["user_id"],
|
|
||||||
message_dict.get("user_nickname", ""),
|
|
||||||
message_dict.get("user_cardname", ""),
|
|
||||||
)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
name = message_dict.get("user_nickname", "") or f"用户{message_dict['user_id']}"
|
name = message_dict.get("user_nickname", "") or f"用户{message_dict['user_id']}"
|
||||||
content = message_dict.get("processed_plain_text", "")
|
content = message_dict.get("processed_plain_text", "")
|
||||||
@@ -58,11 +52,11 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
|
|||||||
and message.message_info.additional_config.get("is_mentioned") is not None
|
and message.message_info.additional_config.get("is_mentioned") is not None
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
reply_probability = float(message.message_info.additional_config.get("is_mentioned"))
|
reply_probability = float(message.message_info.additional_config.get("is_mentioned")) # type: ignore
|
||||||
is_mentioned = True
|
is_mentioned = True
|
||||||
return is_mentioned, reply_probability
|
return is_mentioned, reply_probability
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(e)
|
logger.warning(str(e))
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"消息中包含不合理的设置 is_mentioned: {message.message_info.additional_config.get('is_mentioned')}"
|
f"消息中包含不合理的设置 is_mentioned: {message.message_info.additional_config.get('is_mentioned')}"
|
||||||
)
|
)
|
||||||
@@ -135,17 +129,14 @@ def get_recent_group_detailed_plain_text(chat_stream_id: str, limit: int = 12, c
|
|||||||
if not recent_messages:
|
if not recent_messages:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
message_detailed_plain_text = ""
|
|
||||||
message_detailed_plain_text_list = []
|
|
||||||
|
|
||||||
# 反转消息列表,使最新的消息在最后
|
# 反转消息列表,使最新的消息在最后
|
||||||
recent_messages.reverse()
|
recent_messages.reverse()
|
||||||
|
|
||||||
if combine:
|
if combine:
|
||||||
for msg_db_data in recent_messages:
|
return "".join(str(msg_db_data["detailed_plain_text"]) for msg_db_data in recent_messages)
|
||||||
message_detailed_plain_text += str(msg_db_data["detailed_plain_text"])
|
|
||||||
return message_detailed_plain_text
|
message_detailed_plain_text_list = []
|
||||||
else:
|
|
||||||
for msg_db_data in recent_messages:
|
for msg_db_data in recent_messages:
|
||||||
message_detailed_plain_text_list.append(msg_db_data["detailed_plain_text"])
|
message_detailed_plain_text_list.append(msg_db_data["detailed_plain_text"])
|
||||||
return message_detailed_plain_text_list
|
return message_detailed_plain_text_list
|
||||||
@@ -204,10 +195,7 @@ def split_into_sentences_w_remove_punctuation(text: str) -> list[str]:
|
|||||||
|
|
||||||
len_text = len(text)
|
len_text = len(text)
|
||||||
if len_text < 3:
|
if len_text < 3:
|
||||||
if random.random() < 0.01:
|
return list(text) if random.random() < 0.01 else [text]
|
||||||
return list(text) # 如果文本很短且触发随机条件,直接按字符分割
|
|
||||||
else:
|
|
||||||
return [text]
|
|
||||||
|
|
||||||
# 定义分隔符
|
# 定义分隔符
|
||||||
separators = {",", ",", " ", "。", ";"}
|
separators = {",", ",", " ", "。", ";"}
|
||||||
@@ -352,8 +340,7 @@ def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese
|
|||||||
max_length = global_config.response_splitter.max_length * 2
|
max_length = global_config.response_splitter.max_length * 2
|
||||||
max_sentence_num = global_config.response_splitter.max_sentence_num
|
max_sentence_num = global_config.response_splitter.max_sentence_num
|
||||||
# 如果基本上是中文,则进行长度过滤
|
# 如果基本上是中文,则进行长度过滤
|
||||||
if get_western_ratio(cleaned_text) < 0.1:
|
if get_western_ratio(cleaned_text) < 0.1 and len(cleaned_text) > max_length:
|
||||||
if len(cleaned_text) > max_length:
|
|
||||||
logger.warning(f"回复过长 ({len(cleaned_text)} 字符),返回默认回复")
|
logger.warning(f"回复过长 ({len(cleaned_text)} 字符),返回默认回复")
|
||||||
return ["懒得说"]
|
return ["懒得说"]
|
||||||
|
|
||||||
@@ -420,7 +407,7 @@ def calculate_typing_time(
|
|||||||
# chinese_time *= 1 / typing_speed_multiplier
|
# chinese_time *= 1 / typing_speed_multiplier
|
||||||
# english_time *= 1 / typing_speed_multiplier
|
# english_time *= 1 / typing_speed_multiplier
|
||||||
# 计算中文字符数
|
# 计算中文字符数
|
||||||
chinese_chars = sum(1 for char in input_string if "\u4e00" <= char <= "\u9fff")
|
chinese_chars = sum("\u4e00" <= char <= "\u9fff" for char in input_string)
|
||||||
|
|
||||||
# 如果只有一个中文字符,使用3倍时间
|
# 如果只有一个中文字符,使用3倍时间
|
||||||
if chinese_chars == 1 and len(input_string.strip()) == 1:
|
if chinese_chars == 1 and len(input_string.strip()) == 1:
|
||||||
@@ -429,11 +416,7 @@ def calculate_typing_time(
|
|||||||
# 正常计算所有字符的输入时间
|
# 正常计算所有字符的输入时间
|
||||||
total_time = 0.0
|
total_time = 0.0
|
||||||
for char in input_string:
|
for char in input_string:
|
||||||
if "\u4e00" <= char <= "\u9fff": # 判断是否为中文字符
|
total_time += chinese_time if "\u4e00" <= char <= "\u9fff" else english_time
|
||||||
total_time += chinese_time
|
|
||||||
else: # 其他字符(如英文)
|
|
||||||
total_time += english_time
|
|
||||||
|
|
||||||
if is_emoji:
|
if is_emoji:
|
||||||
total_time = 1
|
total_time = 1
|
||||||
|
|
||||||
@@ -453,18 +436,14 @@ def cosine_similarity(v1, v2):
|
|||||||
dot_product = np.dot(v1, v2)
|
dot_product = np.dot(v1, v2)
|
||||||
norm1 = np.linalg.norm(v1)
|
norm1 = np.linalg.norm(v1)
|
||||||
norm2 = np.linalg.norm(v2)
|
norm2 = np.linalg.norm(v2)
|
||||||
if norm1 == 0 or norm2 == 0:
|
return 0 if norm1 == 0 or norm2 == 0 else dot_product / (norm1 * norm2)
|
||||||
return 0
|
|
||||||
return dot_product / (norm1 * norm2)
|
|
||||||
|
|
||||||
|
|
||||||
def text_to_vector(text):
|
def text_to_vector(text):
|
||||||
"""将文本转换为词频向量"""
|
"""将文本转换为词频向量"""
|
||||||
# 分词
|
# 分词
|
||||||
words = jieba.lcut(text)
|
words = jieba.lcut(text)
|
||||||
# 统计词频
|
return Counter(words)
|
||||||
word_freq = Counter(words)
|
|
||||||
return word_freq
|
|
||||||
|
|
||||||
|
|
||||||
def find_similar_topics_simple(text: str, topics: list, top_k: int = 5) -> list:
|
def find_similar_topics_simple(text: str, topics: list, top_k: int = 5) -> list:
|
||||||
@@ -491,9 +470,7 @@ def find_similar_topics_simple(text: str, topics: list, top_k: int = 5) -> list:
|
|||||||
|
|
||||||
def truncate_message(message: str, max_length=20) -> str:
|
def truncate_message(message: str, max_length=20) -> str:
|
||||||
"""截断消息,使其不超过指定长度"""
|
"""截断消息,使其不超过指定长度"""
|
||||||
if len(message) > max_length:
|
return f"{message[:max_length]}..." if len(message) > max_length else message
|
||||||
return message[:max_length] + "..."
|
|
||||||
return message
|
|
||||||
|
|
||||||
|
|
||||||
def protect_kaomoji(sentence):
|
def protect_kaomoji(sentence):
|
||||||
@@ -522,7 +499,7 @@ def protect_kaomoji(sentence):
|
|||||||
placeholder_to_kaomoji = {}
|
placeholder_to_kaomoji = {}
|
||||||
|
|
||||||
for idx, match in enumerate(kaomoji_matches):
|
for idx, match in enumerate(kaomoji_matches):
|
||||||
kaomoji = match[0] if match[0] else match[1]
|
kaomoji = match[0] or match[1]
|
||||||
placeholder = f"__KAOMOJI_{idx}__"
|
placeholder = f"__KAOMOJI_{idx}__"
|
||||||
sentence = sentence.replace(kaomoji, placeholder, 1)
|
sentence = sentence.replace(kaomoji, placeholder, 1)
|
||||||
placeholder_to_kaomoji[placeholder] = kaomoji
|
placeholder_to_kaomoji[placeholder] = kaomoji
|
||||||
@@ -563,7 +540,7 @@ def get_western_ratio(paragraph):
|
|||||||
if not alnum_chars:
|
if not alnum_chars:
|
||||||
return 0.0
|
return 0.0
|
||||||
|
|
||||||
western_count = sum(1 for char in alnum_chars if is_english_letter(char))
|
western_count = sum(bool(is_english_letter(char)) for char in alnum_chars)
|
||||||
return western_count / len(alnum_chars)
|
return western_count / len(alnum_chars)
|
||||||
|
|
||||||
|
|
||||||
@@ -610,6 +587,7 @@ def count_messages_between(start_time: float, end_time: float, stream_id: str) -
|
|||||||
|
|
||||||
|
|
||||||
def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal") -> str:
|
def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal") -> str:
|
||||||
|
# sourcery skip: merge-comparisons, merge-duplicate-blocks, switch
|
||||||
"""将时间戳转换为人类可读的时间格式
|
"""将时间戳转换为人类可读的时间格式
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -621,7 +599,7 @@ def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal"
|
|||||||
"""
|
"""
|
||||||
if mode == "normal":
|
if mode == "normal":
|
||||||
return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(timestamp))
|
return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(timestamp))
|
||||||
if mode == "normal_no_YMD":
|
elif mode == "normal_no_YMD":
|
||||||
return time.strftime("%H:%M:%S", time.localtime(timestamp))
|
return time.strftime("%H:%M:%S", time.localtime(timestamp))
|
||||||
elif mode == "relative":
|
elif mode == "relative":
|
||||||
now = time.time()
|
now = time.time()
|
||||||
@@ -640,7 +618,7 @@ def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal"
|
|||||||
else:
|
else:
|
||||||
return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(timestamp)) + ":"
|
return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(timestamp)) + ":"
|
||||||
else: # mode = "lite" or unknown
|
else: # mode = "lite" or unknown
|
||||||
# 只返回时分秒格式,喵~
|
# 只返回时分秒格式
|
||||||
return time.strftime("%H:%M:%S", time.localtime(timestamp))
|
return time.strftime("%H:%M:%S", time.localtime(timestamp))
|
||||||
|
|
||||||
|
|
||||||
@@ -670,8 +648,8 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
|
|||||||
elif chat_stream.user_info: # It's a private chat
|
elif chat_stream.user_info: # It's a private chat
|
||||||
is_group_chat = False
|
is_group_chat = False
|
||||||
user_info = chat_stream.user_info
|
user_info = chat_stream.user_info
|
||||||
platform = chat_stream.platform
|
platform: str = chat_stream.platform # type: ignore
|
||||||
user_id = user_info.user_id
|
user_id: str = user_info.user_id # type: ignore
|
||||||
|
|
||||||
# Initialize target_info with basic info
|
# Initialize target_info with basic info
|
||||||
target_info = {
|
target_info = {
|
||||||
|
|||||||
@@ -3,21 +3,20 @@ import os
|
|||||||
import time
|
import time
|
||||||
import hashlib
|
import hashlib
|
||||||
import uuid
|
import uuid
|
||||||
|
import io
|
||||||
|
import asyncio
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import io
|
from rich.traceback import install
|
||||||
import numpy as np
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
from src.common.database.database import db
|
from src.common.database.database import db
|
||||||
from src.common.database.database_model import Images, ImageDescriptions
|
from src.common.database.database_model import Images, ImageDescriptions
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from rich.traceback import install
|
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
|
|
||||||
logger = get_logger("chat_image")
|
logger = get_logger("chat_image")
|
||||||
@@ -103,7 +102,7 @@ class ImageManager:
|
|||||||
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
||||||
image_bytes = base64.b64decode(image_base64)
|
image_bytes = base64.b64decode(image_base64)
|
||||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
|
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
|
||||||
|
|
||||||
# 查询缓存的描述
|
# 查询缓存的描述
|
||||||
cached_description = self._get_description_from_db(image_hash, "emoji")
|
cached_description = self._get_description_from_db(image_hash, "emoji")
|
||||||
@@ -111,7 +110,7 @@ class ImageManager:
|
|||||||
return f"[表情包,含义看起来是:{cached_description}]"
|
return f"[表情包,含义看起来是:{cached_description}]"
|
||||||
|
|
||||||
# 调用AI获取描述
|
# 调用AI获取描述
|
||||||
if image_format == "gif" or image_format == "GIF":
|
if image_format in ["gif", "GIF"]:
|
||||||
image_base64_processed = self.transform_gif(image_base64)
|
image_base64_processed = self.transform_gif(image_base64)
|
||||||
if image_base64_processed is None:
|
if image_base64_processed is None:
|
||||||
logger.warning("GIF转换失败,无法获取描述")
|
logger.warning("GIF转换失败,无法获取描述")
|
||||||
@@ -154,7 +153,7 @@ class ImageManager:
|
|||||||
img_obj.description = description
|
img_obj.description = description
|
||||||
img_obj.timestamp = current_timestamp
|
img_obj.timestamp = current_timestamp
|
||||||
img_obj.save()
|
img_obj.save()
|
||||||
except Images.DoesNotExist:
|
except Images.DoesNotExist: # type: ignore
|
||||||
Images.create(
|
Images.create(
|
||||||
emoji_hash=image_hash,
|
emoji_hash=image_hash,
|
||||||
path=file_path,
|
path=file_path,
|
||||||
@@ -204,7 +203,7 @@ class ImageManager:
|
|||||||
return f"[图片:{cached_description}]"
|
return f"[图片:{cached_description}]"
|
||||||
|
|
||||||
# 调用AI获取描述
|
# 调用AI获取描述
|
||||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
|
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
|
||||||
prompt = "请用中文描述这张图片的内容。如果有文字,请把文字都描述出来,请留意其主题,直观感受,输出为一段平文本,最多50字"
|
prompt = "请用中文描述这张图片的内容。如果有文字,请把文字都描述出来,请留意其主题,直观感受,输出为一段平文本,最多50字"
|
||||||
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
|
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
|
||||||
|
|
||||||
@@ -258,6 +257,7 @@ class ImageManager:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def transform_gif(gif_base64: str, similarity_threshold: float = 1000.0, max_frames: int = 15) -> Optional[str]:
|
def transform_gif(gif_base64: str, similarity_threshold: float = 1000.0, max_frames: int = 15) -> Optional[str]:
|
||||||
|
# sourcery skip: use-contextlib-suppress
|
||||||
"""将GIF转换为水平拼接的静态图像, 跳过相似的帧
|
"""将GIF转换为水平拼接的静态图像, 跳过相似的帧
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -351,7 +351,7 @@ class ImageManager:
|
|||||||
# 创建拼接图像
|
# 创建拼接图像
|
||||||
total_width = target_width * len(resized_frames)
|
total_width = target_width * len(resized_frames)
|
||||||
# 防止总宽度为0
|
# 防止总宽度为0
|
||||||
if total_width == 0 and len(resized_frames) > 0:
|
if total_width == 0 and resized_frames:
|
||||||
logger.warning("计算出的总宽度为0,但有选中帧,可能目标宽度太小")
|
logger.warning("计算出的总宽度为0,但有选中帧,可能目标宽度太小")
|
||||||
# 至少给点宽度吧
|
# 至少给点宽度吧
|
||||||
total_width = len(resized_frames)
|
total_width = len(resized_frames)
|
||||||
@@ -368,10 +368,7 @@ class ImageManager:
|
|||||||
# 转换为base64
|
# 转换为base64
|
||||||
buffer = io.BytesIO()
|
buffer = io.BytesIO()
|
||||||
combined_image.save(buffer, format="JPEG", quality=85) # 保存为JPEG
|
combined_image.save(buffer, format="JPEG", quality=85) # 保存为JPEG
|
||||||
result_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
return base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||||
|
|
||||||
return result_base64
|
|
||||||
|
|
||||||
except MemoryError:
|
except MemoryError:
|
||||||
logger.error("GIF转换失败: 内存不足,可能是GIF太大或帧数太多")
|
logger.error("GIF转换失败: 内存不足,可能是GIF太大或帧数太多")
|
||||||
return None # 内存不够啦
|
return None # 内存不够啦
|
||||||
@@ -380,6 +377,7 @@ class ImageManager:
|
|||||||
return None # 其他错误也返回None
|
return None # 其他错误也返回None
|
||||||
|
|
||||||
async def process_image(self, image_base64: str) -> Tuple[str, str]:
|
async def process_image(self, image_base64: str) -> Tuple[str, str]:
|
||||||
|
# sourcery skip: hoist-if-from-if
|
||||||
"""处理图片并返回图片ID和描述
|
"""处理图片并返回图片ID和描述
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -422,14 +420,6 @@ class ImageManager:
|
|||||||
existing_image.save()
|
existing_image.save()
|
||||||
return existing_image.image_id, f"[picid:{existing_image.image_id}]"
|
return existing_image.image_id, f"[picid:{existing_image.image_id}]"
|
||||||
else:
|
else:
|
||||||
# print(f"图片已存在: {existing_image.image_id}")
|
|
||||||
# print(f"图片描述: {existing_image.description}")
|
|
||||||
# print(f"图片计数: {existing_image.count}")
|
|
||||||
# 更新计数
|
|
||||||
existing_image.count += 1
|
|
||||||
existing_image.save()
|
|
||||||
return existing_image.image_id, f"[picid:{existing_image.image_id}]"
|
|
||||||
else:
|
|
||||||
# print(f"图片不存在: {image_hash}")
|
# print(f"图片不存在: {image_hash}")
|
||||||
image_id = str(uuid.uuid4())
|
image_id = str(uuid.uuid4())
|
||||||
|
|
||||||
@@ -491,7 +481,7 @@ class ImageManager:
|
|||||||
return
|
return
|
||||||
|
|
||||||
# 获取图片格式
|
# 获取图片格式
|
||||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
|
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
|
||||||
|
|
||||||
# 构建prompt
|
# 构建prompt
|
||||||
prompt = """请用中文描述这张图片的内容。如果有文字,请把文字描述概括出来,请留意其主题,直观感受,输出为一段平文本,最多30字,请注意不要分点,就输出一段文本"""
|
prompt = """请用中文描述这张图片的内容。如果有文字,请把文字描述概括出来,请留意其主题,直观感受,输出为一段平文本,最多30字,请注意不要分点,就输出一段文本"""
|
||||||
|
|||||||
@@ -35,9 +35,7 @@ class ClassicalWillingManager(BaseWillingManager):
|
|||||||
|
|
||||||
self.chat_reply_willing[chat_id] = min(current_willing, 3.0)
|
self.chat_reply_willing[chat_id] = min(current_willing, 3.0)
|
||||||
|
|
||||||
reply_probability = min(max((current_willing - 0.5), 0.01) * 2, 1)
|
return min(max((current_willing - 0.5), 0.01) * 2, 1)
|
||||||
|
|
||||||
return reply_probability
|
|
||||||
|
|
||||||
async def before_generate_reply_handle(self, message_id):
|
async def before_generate_reply_handle(self, message_id):
|
||||||
chat_id = self.ongoing_messages[message_id].chat_id
|
chat_id = self.ongoing_messages[message_id].chat_id
|
||||||
|
|||||||
@@ -1,14 +1,16 @@
|
|||||||
from src.common.logger import get_logger
|
import importlib
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Dict, Optional
|
||||||
|
from rich.traceback import install
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.chat.message_receive.chat_stream import ChatStream, GroupInfo
|
from src.chat.message_receive.chat_stream import ChatStream, GroupInfo
|
||||||
from src.chat.message_receive.message import MessageRecv
|
from src.chat.message_receive.message import MessageRecv
|
||||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
import importlib
|
|
||||||
from typing import Dict, Optional
|
|
||||||
import asyncio
|
|
||||||
from rich.traceback import install
|
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
|
|
||||||
@@ -92,8 +94,8 @@ class BaseWillingManager(ABC):
|
|||||||
self.logger = logger
|
self.logger = logger
|
||||||
|
|
||||||
def setup(self, message: dict, chat: ChatStream):
|
def setup(self, message: dict, chat: ChatStream):
|
||||||
person_id = PersonInfoManager.get_person_id(chat.platform, chat.user_info.user_id)
|
person_id = PersonInfoManager.get_person_id(chat.platform, chat.user_info.user_id) # type: ignore
|
||||||
self.ongoing_messages[message.get("message_id")] = WillingInfo(
|
self.ongoing_messages[message.message_info.message_id] = WillingInfo( # type: ignore
|
||||||
message=message,
|
message=message,
|
||||||
chat=chat,
|
chat=chat,
|
||||||
person_info_manager=get_person_info_manager(),
|
person_info_manager=get_person_info_manager(),
|
||||||
|
|||||||
@@ -54,11 +54,11 @@ class DBWrapper:
|
|||||||
return getattr(get_db(), name)
|
return getattr(get_db(), name)
|
||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key):
|
||||||
return get_db()[key]
|
return get_db()[key] # type: ignore
|
||||||
|
|
||||||
|
|
||||||
# 全局数据库访问点
|
# 全局数据库访问点
|
||||||
memory_db: Database = DBWrapper()
|
memory_db: Database = DBWrapper() # type: ignore
|
||||||
|
|
||||||
# 定义数据库文件路径
|
# 定义数据库文件路径
|
||||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||||
|
|||||||
@@ -414,9 +414,7 @@ def initialize_database():
|
|||||||
existing_columns = {row[1] for row in cursor.fetchall()}
|
existing_columns = {row[1] for row in cursor.fetchall()}
|
||||||
model_fields = set(model._meta.fields.keys())
|
model_fields = set(model._meta.fields.keys())
|
||||||
|
|
||||||
# 检查并添加缺失字段(原有逻辑)
|
if missing_fields := model_fields - existing_columns:
|
||||||
missing_fields = model_fields - existing_columns
|
|
||||||
if missing_fields:
|
|
||||||
logger.warning(f"表 '{table_name}' 缺失字段: {missing_fields}")
|
logger.warning(f"表 '{table_name}' 缺失字段: {missing_fields}")
|
||||||
|
|
||||||
for field_name, field_obj in model._meta.fields.items():
|
for field_name, field_obj in model._meta.fields.items():
|
||||||
@@ -432,10 +430,7 @@ def initialize_database():
|
|||||||
"DateTimeField": "DATETIME",
|
"DateTimeField": "DATETIME",
|
||||||
}.get(field_type, "TEXT")
|
}.get(field_type, "TEXT")
|
||||||
alter_sql = f"ALTER TABLE {table_name} ADD COLUMN {field_name} {sql_type}"
|
alter_sql = f"ALTER TABLE {table_name} ADD COLUMN {field_name} {sql_type}"
|
||||||
if field_obj.null:
|
alter_sql += " NULL" if field_obj.null else " NOT NULL"
|
||||||
alter_sql += " NULL"
|
|
||||||
else:
|
|
||||||
alter_sql += " NOT NULL"
|
|
||||||
if hasattr(field_obj, "default") and field_obj.default is not None:
|
if hasattr(field_obj, "default") and field_obj.default is not None:
|
||||||
# 正确处理不同类型的默认值
|
# 正确处理不同类型的默认值
|
||||||
default_value = field_obj.default
|
default_value = field_obj.default
|
||||||
|
|||||||
@@ -1,16 +1,16 @@
|
|||||||
import logging
|
|
||||||
|
|
||||||
# 使用基于时间戳的文件处理器,简单的轮转份数限制
|
# 使用基于时间戳的文件处理器,简单的轮转份数限制
|
||||||
from pathlib import Path
|
|
||||||
from typing import Callable, Optional
|
import logging
|
||||||
import json
|
import json
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from datetime import datetime, timedelta
|
|
||||||
|
|
||||||
import structlog
|
import structlog
|
||||||
import toml
|
import toml
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Callable, Optional
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
# 创建logs目录
|
# 创建logs目录
|
||||||
LOG_DIR = Path("logs")
|
LOG_DIR = Path("logs")
|
||||||
LOG_DIR.mkdir(exist_ok=True)
|
LOG_DIR.mkdir(exist_ok=True)
|
||||||
@@ -160,7 +160,7 @@ def close_handlers():
|
|||||||
_console_handler = None
|
_console_handler = None
|
||||||
|
|
||||||
|
|
||||||
def remove_duplicate_handlers():
|
def remove_duplicate_handlers(): # sourcery skip: for-append-to-extend, list-comprehension
|
||||||
"""移除重复的handler,特别是文件handler"""
|
"""移除重复的handler,特别是文件handler"""
|
||||||
root_logger = logging.getLogger()
|
root_logger = logging.getLogger()
|
||||||
|
|
||||||
@@ -184,7 +184,7 @@ def remove_duplicate_handlers():
|
|||||||
|
|
||||||
|
|
||||||
# 读取日志配置
|
# 读取日志配置
|
||||||
def load_log_config():
|
def load_log_config(): # sourcery skip: use-contextlib-suppress
|
||||||
"""从配置文件加载日志设置"""
|
"""从配置文件加载日志设置"""
|
||||||
config_path = Path("config/bot_config.toml")
|
config_path = Path("config/bot_config.toml")
|
||||||
default_config = {
|
default_config = {
|
||||||
@@ -365,7 +365,7 @@ MODULE_COLORS = {
|
|||||||
"component_registry": "\033[38;5;214m", # 橙黄色
|
"component_registry": "\033[38;5;214m", # 橙黄色
|
||||||
"stream_api": "\033[38;5;220m", # 黄色
|
"stream_api": "\033[38;5;220m", # 黄色
|
||||||
"config_api": "\033[38;5;226m", # 亮黄色
|
"config_api": "\033[38;5;226m", # 亮黄色
|
||||||
"hearflow_api": "\033[38;5;154m", # 黄绿色
|
"heartflow_api": "\033[38;5;154m", # 黄绿色
|
||||||
"action_apis": "\033[38;5;118m", # 绿色
|
"action_apis": "\033[38;5;118m", # 绿色
|
||||||
"independent_apis": "\033[38;5;82m", # 绿色
|
"independent_apis": "\033[38;5;82m", # 绿色
|
||||||
"llm_api": "\033[38;5;46m", # 亮绿色
|
"llm_api": "\033[38;5;46m", # 亮绿色
|
||||||
@@ -412,6 +412,7 @@ class ModuleColoredConsoleRenderer:
|
|||||||
"""自定义控制台渲染器,为不同模块提供不同颜色"""
|
"""自定义控制台渲染器,为不同模块提供不同颜色"""
|
||||||
|
|
||||||
def __init__(self, colors=True):
|
def __init__(self, colors=True):
|
||||||
|
# sourcery skip: merge-duplicate-blocks, remove-redundant-if
|
||||||
self._colors = colors
|
self._colors = colors
|
||||||
self._config = LOG_CONFIG
|
self._config = LOG_CONFIG
|
||||||
|
|
||||||
@@ -443,6 +444,7 @@ class ModuleColoredConsoleRenderer:
|
|||||||
self._enable_full_content_colors = False
|
self._enable_full_content_colors = False
|
||||||
|
|
||||||
def __call__(self, logger, method_name, event_dict):
|
def __call__(self, logger, method_name, event_dict):
|
||||||
|
# sourcery skip: merge-duplicate-blocks
|
||||||
"""渲染日志消息"""
|
"""渲染日志消息"""
|
||||||
# 获取基本信息
|
# 获取基本信息
|
||||||
timestamp = event_dict.get("timestamp", "")
|
timestamp = event_dict.get("timestamp", "")
|
||||||
@@ -662,7 +664,7 @@ def get_logger(name: Optional[str]) -> structlog.stdlib.BoundLogger:
|
|||||||
"""获取logger实例,支持按名称绑定"""
|
"""获取logger实例,支持按名称绑定"""
|
||||||
if name is None:
|
if name is None:
|
||||||
return raw_logger
|
return raw_logger
|
||||||
logger = binds.get(name)
|
logger = binds.get(name) # type: ignore
|
||||||
if logger is None:
|
if logger is None:
|
||||||
logger: structlog.stdlib.BoundLogger = structlog.get_logger(name).bind(logger_name=name)
|
logger: structlog.stdlib.BoundLogger = structlog.get_logger(name).bind(logger_name=name)
|
||||||
binds[name] = logger
|
binds[name] = logger
|
||||||
@@ -671,8 +673,8 @@ def get_logger(name: Optional[str]) -> structlog.stdlib.BoundLogger:
|
|||||||
|
|
||||||
def configure_logging(
|
def configure_logging(
|
||||||
level: str = "INFO",
|
level: str = "INFO",
|
||||||
console_level: str = None,
|
console_level: Optional[str] = None,
|
||||||
file_level: str = None,
|
file_level: Optional[str] = None,
|
||||||
max_bytes: int = 5 * 1024 * 1024,
|
max_bytes: int = 5 * 1024 * 1024,
|
||||||
backup_count: int = 30,
|
backup_count: int = 30,
|
||||||
log_dir: str = "logs",
|
log_dir: str = "logs",
|
||||||
@@ -729,14 +731,11 @@ def reload_log_config():
|
|||||||
global LOG_CONFIG
|
global LOG_CONFIG
|
||||||
LOG_CONFIG = load_log_config()
|
LOG_CONFIG = load_log_config()
|
||||||
|
|
||||||
# 重新设置handler的日志级别
|
if file_handler := get_file_handler():
|
||||||
file_handler = get_file_handler()
|
|
||||||
if file_handler:
|
|
||||||
file_level = LOG_CONFIG.get("file_log_level", LOG_CONFIG.get("log_level", "INFO"))
|
file_level = LOG_CONFIG.get("file_log_level", LOG_CONFIG.get("log_level", "INFO"))
|
||||||
file_handler.setLevel(getattr(logging, file_level.upper(), logging.INFO))
|
file_handler.setLevel(getattr(logging, file_level.upper(), logging.INFO))
|
||||||
|
|
||||||
console_handler = get_console_handler()
|
if console_handler := get_console_handler():
|
||||||
if console_handler:
|
|
||||||
console_level = LOG_CONFIG.get("console_log_level", LOG_CONFIG.get("log_level", "INFO"))
|
console_level = LOG_CONFIG.get("console_log_level", LOG_CONFIG.get("log_level", "INFO"))
|
||||||
console_handler.setLevel(getattr(logging, console_level.upper(), logging.INFO))
|
console_handler.setLevel(getattr(logging, console_level.upper(), logging.INFO))
|
||||||
|
|
||||||
@@ -780,8 +779,7 @@ def set_console_log_level(level: str):
|
|||||||
global LOG_CONFIG
|
global LOG_CONFIG
|
||||||
LOG_CONFIG["console_log_level"] = level.upper()
|
LOG_CONFIG["console_log_level"] = level.upper()
|
||||||
|
|
||||||
console_handler = get_console_handler()
|
if console_handler := get_console_handler():
|
||||||
if console_handler:
|
|
||||||
console_handler.setLevel(getattr(logging, level.upper(), logging.INFO))
|
console_handler.setLevel(getattr(logging, level.upper(), logging.INFO))
|
||||||
|
|
||||||
# 重新设置root logger级别
|
# 重新设置root logger级别
|
||||||
@@ -800,8 +798,7 @@ def set_file_log_level(level: str):
|
|||||||
global LOG_CONFIG
|
global LOG_CONFIG
|
||||||
LOG_CONFIG["file_log_level"] = level.upper()
|
LOG_CONFIG["file_log_level"] = level.upper()
|
||||||
|
|
||||||
file_handler = get_file_handler()
|
if file_handler := get_file_handler():
|
||||||
if file_handler:
|
|
||||||
file_handler.setLevel(getattr(logging, level.upper(), logging.INFO))
|
file_handler.setLevel(getattr(logging, level.upper(), logging.INFO))
|
||||||
|
|
||||||
# 重新设置root logger级别
|
# 重新设置root logger级别
|
||||||
@@ -933,13 +930,12 @@ def format_json_for_logging(data, indent=2, ensure_ascii=False):
|
|||||||
Returns:
|
Returns:
|
||||||
str: 格式化后的JSON字符串
|
str: 格式化后的JSON字符串
|
||||||
"""
|
"""
|
||||||
if isinstance(data, str):
|
if not isinstance(data, str):
|
||||||
|
# 如果是对象,直接格式化
|
||||||
|
return json.dumps(data, indent=indent, ensure_ascii=ensure_ascii)
|
||||||
# 如果是JSON字符串,先解析再格式化
|
# 如果是JSON字符串,先解析再格式化
|
||||||
parsed_data = json.loads(data)
|
parsed_data = json.loads(data)
|
||||||
return json.dumps(parsed_data, indent=indent, ensure_ascii=ensure_ascii)
|
return json.dumps(parsed_data, indent=indent, ensure_ascii=ensure_ascii)
|
||||||
else:
|
|
||||||
# 如果是对象,直接格式化
|
|
||||||
return json.dumps(data, indent=indent, ensure_ascii=ensure_ascii)
|
|
||||||
|
|
||||||
|
|
||||||
def cleanup_old_logs():
|
def cleanup_old_logs():
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from src.config.config import global_config
|
|||||||
global_api = None
|
global_api = None
|
||||||
|
|
||||||
|
|
||||||
def get_global_api() -> MessageServer:
|
def get_global_api() -> MessageServer: # sourcery skip: extract-method
|
||||||
"""获取全局MessageServer实例"""
|
"""获取全局MessageServer实例"""
|
||||||
global global_api
|
global global_api
|
||||||
if global_api is None:
|
if global_api is None:
|
||||||
@@ -36,8 +36,7 @@ def get_global_api() -> MessageServer:
|
|||||||
kwargs["custom_logger"] = maim_message_logger
|
kwargs["custom_logger"] = maim_message_logger
|
||||||
|
|
||||||
# 添加token认证
|
# 添加token认证
|
||||||
if maim_message_config.auth_token:
|
if maim_message_config.auth_token and len(maim_message_config.auth_token) > 0:
|
||||||
if len(maim_message_config.auth_token) > 0:
|
|
||||||
kwargs["enable_token"] = True
|
kwargs["enable_token"] = True
|
||||||
|
|
||||||
if maim_message_config.use_custom:
|
if maim_message_config.use_custom:
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
from src.common.database.database_model import Messages # 更改导入
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
from typing import List, Any, Optional
|
from typing import List, Any, Optional
|
||||||
from peewee import Model # 添加 Peewee Model 导入
|
from peewee import Model # 添加 Peewee Model 导入
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
||||||
|
from src.common.database.database_model import Messages
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -20,7 +22,7 @@ def find_messages(
|
|||||||
sort: Optional[List[tuple[str, int]]] = None,
|
sort: Optional[List[tuple[str, int]]] = None,
|
||||||
limit: int = 0,
|
limit: int = 0,
|
||||||
limit_mode: str = "latest",
|
limit_mode: str = "latest",
|
||||||
fliter_bot = False
|
fliter_bot=False,
|
||||||
) -> List[dict[str, Any]]:
|
) -> List[dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
根据提供的过滤器、排序和限制条件查找消息。
|
根据提供的过滤器、排序和限制条件查找消息。
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ class TelemetryHeartBeatTask(AsyncTask):
|
|||||||
self.server_url = TELEMETRY_SERVER_URL
|
self.server_url = TELEMETRY_SERVER_URL
|
||||||
"""遥测服务地址"""
|
"""遥测服务地址"""
|
||||||
|
|
||||||
self.client_uuid = local_storage["mmc_uuid"] if "mmc_uuid" in local_storage else None
|
self.client_uuid: str | None = local_storage["mmc_uuid"] if "mmc_uuid" in local_storage else None # type: ignore
|
||||||
"""客户端UUID"""
|
"""客户端UUID"""
|
||||||
|
|
||||||
self.info_dict = self._get_sys_info()
|
self.info_dict = self._get_sys_info()
|
||||||
@@ -72,7 +72,7 @@ class TelemetryHeartBeatTask(AsyncTask):
|
|||||||
timeout=aiohttp.ClientTimeout(total=5), # 设置超时时间为5秒
|
timeout=aiohttp.ClientTimeout(total=5), # 设置超时时间为5秒
|
||||||
) as response:
|
) as response:
|
||||||
logger.debug(f"{TELEMETRY_SERVER_URL}/stat/reg_client")
|
logger.debug(f"{TELEMETRY_SERVER_URL}/stat/reg_client")
|
||||||
logger.debug(local_storage["deploy_time"])
|
logger.debug(local_storage["deploy_time"]) # type: ignore
|
||||||
logger.debug(f"Response status: {response.status}")
|
logger.debug(f"Response status: {response.status}")
|
||||||
|
|
||||||
if response.status == 200:
|
if response.status == 200:
|
||||||
@@ -93,7 +93,7 @@ class TelemetryHeartBeatTask(AsyncTask):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
error_msg = str(e) if str(e) else "未知错误"
|
error_msg = str(e) or "未知错误"
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"请求UUID出错,不过你还是可以正常使用麦麦: {type(e).__name__}: {error_msg}"
|
f"请求UUID出错,不过你还是可以正常使用麦麦: {type(e).__name__}: {error_msg}"
|
||||||
) # 可能是网络问题
|
) # 可能是网络问题
|
||||||
@@ -114,11 +114,11 @@ class TelemetryHeartBeatTask(AsyncTask):
|
|||||||
"""向服务器发送心跳"""
|
"""向服务器发送心跳"""
|
||||||
headers = {
|
headers = {
|
||||||
"Client-UUID": self.client_uuid,
|
"Client-UUID": self.client_uuid,
|
||||||
"User-Agent": f"HeartbeatClient/{self.client_uuid[:8]}",
|
"User-Agent": f"HeartbeatClient/{self.client_uuid[:8]}", # type: ignore
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.debug(f"正在发送心跳到服务器: {self.server_url}")
|
logger.debug(f"正在发送心跳到服务器: {self.server_url}")
|
||||||
logger.debug(headers)
|
logger.debug(str(headers))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with aiohttp.ClientSession(connector=await get_tcp_connector()) as session:
|
async with aiohttp.ClientSession(connector=await get_tcp_connector()) as session:
|
||||||
@@ -151,7 +151,7 @@ class TelemetryHeartBeatTask(AsyncTask):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
error_msg = str(e) if str(e) else "未知错误"
|
error_msg = str(e) or "未知错误"
|
||||||
logger.warning(f"(此消息不会影响正常使用)状态未发生: {type(e).__name__}: {error_msg}")
|
logger.warning(f"(此消息不会影响正常使用)状态未发生: {type(e).__name__}: {error_msg}")
|
||||||
logger.debug(f"完整错误信息: {traceback.format_exc()}")
|
logger.debug(f"完整错误信息: {traceback.format_exc()}")
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import shutil
|
import shutil
|
||||||
import tomlkit
|
import tomlkit
|
||||||
|
from tomlkit.items import Table
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
@@ -45,8 +46,8 @@ def update_config():
|
|||||||
|
|
||||||
# 检查version是否相同
|
# 检查version是否相同
|
||||||
if old_config and "inner" in old_config and "inner" in new_config:
|
if old_config and "inner" in old_config and "inner" in new_config:
|
||||||
old_version = old_config["inner"].get("version")
|
old_version = old_config["inner"].get("version") # type: ignore
|
||||||
new_version = new_config["inner"].get("version")
|
new_version = new_config["inner"].get("version") # type: ignore
|
||||||
if old_version and new_version and old_version == new_version:
|
if old_version and new_version and old_version == new_version:
|
||||||
print(f"检测到版本号相同 (v{old_version}),跳过更新")
|
print(f"检测到版本号相同 (v{old_version}),跳过更新")
|
||||||
# 如果version相同,恢复旧配置文件并返回
|
# 如果version相同,恢复旧配置文件并返回
|
||||||
@@ -62,7 +63,7 @@ def update_config():
|
|||||||
if key == "version":
|
if key == "version":
|
||||||
continue
|
continue
|
||||||
if key in target:
|
if key in target:
|
||||||
if isinstance(value, dict) and isinstance(target[key], (dict, tomlkit.items.Table)):
|
if isinstance(value, dict) and isinstance(target[key], (dict, Table)):
|
||||||
update_dict(target[key], value)
|
update_dict(target[key], value)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
@@ -85,10 +86,7 @@ def update_config():
|
|||||||
if value and isinstance(value[0], dict) and "regex" in value[0]:
|
if value and isinstance(value[0], dict) and "regex" in value[0]:
|
||||||
contains_regex = True
|
contains_regex = True
|
||||||
|
|
||||||
if contains_regex:
|
target[key] = value if contains_regex else tomlkit.array(str(value))
|
||||||
target[key] = value
|
|
||||||
else:
|
|
||||||
target[key] = tomlkit.array(value)
|
|
||||||
else:
|
else:
|
||||||
# 其他类型使用item方法创建新值
|
# 其他类型使用item方法创建新值
|
||||||
target[key] = tomlkit.item(value)
|
target[key] = tomlkit.item(value)
|
||||||
|
|||||||
@@ -1,16 +1,14 @@
|
|||||||
import os
|
import os
|
||||||
from dataclasses import field, dataclass
|
|
||||||
|
|
||||||
import tomlkit
|
import tomlkit
|
||||||
import shutil
|
import shutil
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
from tomlkit import TOMLDocument
|
from tomlkit import TOMLDocument
|
||||||
from tomlkit.items import Table
|
from tomlkit.items import Table
|
||||||
|
from dataclasses import field, dataclass
|
||||||
from src.common.logger import get_logger
|
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
from src.config.config_base import ConfigBase
|
from src.config.config_base import ConfigBase
|
||||||
from src.config.official_configs import (
|
from src.config.official_configs import (
|
||||||
BotConfig,
|
BotConfig,
|
||||||
@@ -80,8 +78,8 @@ def update_config():
|
|||||||
|
|
||||||
# 检查version是否相同
|
# 检查version是否相同
|
||||||
if old_config and "inner" in old_config and "inner" in new_config:
|
if old_config and "inner" in old_config and "inner" in new_config:
|
||||||
old_version = old_config["inner"].get("version")
|
old_version = old_config["inner"].get("version") # type: ignore
|
||||||
new_version = new_config["inner"].get("version")
|
new_version = new_config["inner"].get("version") # type: ignore
|
||||||
if old_version and new_version and old_version == new_version:
|
if old_version and new_version and old_version == new_version:
|
||||||
logger.info(f"检测到配置文件版本号相同 (v{old_version}),跳过更新")
|
logger.info(f"检测到配置文件版本号相同 (v{old_version}),跳过更新")
|
||||||
return
|
return
|
||||||
@@ -103,7 +101,7 @@ def update_config():
|
|||||||
shutil.copy2(template_path, new_config_path)
|
shutil.copy2(template_path, new_config_path)
|
||||||
logger.info(f"已创建新配置文件: {new_config_path}")
|
logger.info(f"已创建新配置文件: {new_config_path}")
|
||||||
|
|
||||||
def update_dict(target: TOMLDocument | dict, source: TOMLDocument | dict):
|
def update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict):
|
||||||
"""
|
"""
|
||||||
将source字典的值更新到target字典中(如果target中存在相同的键)
|
将source字典的值更新到target字典中(如果target中存在相同的键)
|
||||||
"""
|
"""
|
||||||
@@ -112,8 +110,9 @@ def update_config():
|
|||||||
if key == "version":
|
if key == "version":
|
||||||
continue
|
continue
|
||||||
if key in target:
|
if key in target:
|
||||||
if isinstance(value, dict) and isinstance(target[key], (dict, Table)):
|
target_value = target[key]
|
||||||
update_dict(target[key], value)
|
if isinstance(value, dict) and isinstance(target_value, (dict, Table)):
|
||||||
|
update_dict(target_value, value)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
# 对数组类型进行特殊处理
|
# 对数组类型进行特殊处理
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ class ConfigBase:
|
|||||||
field_type = f.type
|
field_type = f.type
|
||||||
|
|
||||||
try:
|
try:
|
||||||
init_args[field_name] = cls._convert_field(value, field_type)
|
init_args[field_name] = cls._convert_field(value, field_type) # type: ignore
|
||||||
except TypeError as e:
|
except TypeError as e:
|
||||||
raise TypeError(f"Field '{field_name}' has a type error: {e}") from e
|
raise TypeError(f"Field '{field_name}' has a type error: {e}") from e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Any, Literal
|
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Literal, Optional
|
||||||
|
|
||||||
from src.config.config_base import ConfigBase
|
from src.config.config_base import ConfigBase
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@@ -110,7 +111,7 @@ class ChatConfig(ConfigBase):
|
|||||||
exit_focus_threshold: float = 1.0
|
exit_focus_threshold: float = 1.0
|
||||||
"""自动退出专注聊天的阈值,越低越容易退出专注聊天"""
|
"""自动退出专注聊天的阈值,越低越容易退出专注聊天"""
|
||||||
|
|
||||||
def get_current_talk_frequency(self, chat_stream_id: str = None) -> float:
|
def get_current_talk_frequency(self, chat_stream_id: Optional[str] = None) -> float:
|
||||||
"""
|
"""
|
||||||
根据当前时间和聊天流获取对应的 talk_frequency
|
根据当前时间和聊天流获取对应的 talk_frequency
|
||||||
|
|
||||||
@@ -135,7 +136,7 @@ class ChatConfig(ConfigBase):
|
|||||||
# 如果都没有匹配,返回默认值
|
# 如果都没有匹配,返回默认值
|
||||||
return self.talk_frequency
|
return self.talk_frequency
|
||||||
|
|
||||||
def _get_time_based_frequency(self, time_freq_list: list[str]) -> float:
|
def _get_time_based_frequency(self, time_freq_list: list[str]) -> Optional[float]:
|
||||||
"""
|
"""
|
||||||
根据时间配置列表获取当前时段的频率
|
根据时间配置列表获取当前时段的频率
|
||||||
|
|
||||||
@@ -183,7 +184,7 @@ class ChatConfig(ConfigBase):
|
|||||||
|
|
||||||
return current_frequency
|
return current_frequency
|
||||||
|
|
||||||
def _get_stream_specific_frequency(self, chat_stream_id: str) -> float:
|
def _get_stream_specific_frequency(self, chat_stream_id: str):
|
||||||
"""
|
"""
|
||||||
获取特定聊天流在当前时间的频率
|
获取特定聊天流在当前时间的频率
|
||||||
|
|
||||||
@@ -214,7 +215,7 @@ class ChatConfig(ConfigBase):
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> str:
|
def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
解析流配置字符串并生成对应的 chat_id
|
解析流配置字符串并生成对应的 chat_id
|
||||||
|
|
||||||
@@ -278,7 +279,6 @@ class NormalChatConfig(ConfigBase):
|
|||||||
"""@bot 必然回复"""
|
"""@bot 必然回复"""
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FocusChatConfig(ConfigBase):
|
class FocusChatConfig(ConfigBase):
|
||||||
"""专注聊天配置类"""
|
"""专注聊天配置类"""
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ from typing import Dict, Any, Optional
|
|||||||
from src.chat.message_receive.message import Message
|
from src.chat.message_receive.message import Message
|
||||||
from .pfc_types import ConversationState
|
from .pfc_types import ConversationState
|
||||||
from .pfc import ChatObserver, GoalAnalyzer
|
from .pfc import ChatObserver, GoalAnalyzer
|
||||||
from .message_sender import DirectMessageSender
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from .action_planner import ActionPlanner
|
from .action_planner import ActionPlanner
|
||||||
from .observation_info import ObservationInfo
|
from .observation_info import ObservationInfo
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -8,7 +8,7 @@ class Identity:
|
|||||||
|
|
||||||
identity_detail: List[str] # 身份细节描述
|
identity_detail: List[str] # 身份细节描述
|
||||||
|
|
||||||
def __init__(self, identity_detail: List[str] = None):
|
def __init__(self, identity_detail: Optional[List[str]] = None):
|
||||||
"""初始化身份特征
|
"""初始化身份特征
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -1,17 +1,18 @@
|
|||||||
from typing import Optional
|
|
||||||
import ast
|
import ast
|
||||||
|
|
||||||
from src.llm_models.utils_model import LLMRequest
|
|
||||||
from .personality import Personality
|
|
||||||
from .identity import Identity
|
|
||||||
import random
|
import random
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import hashlib
|
import hashlib
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.person_info.person_info import get_person_info_manager
|
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
from src.llm_models.utils_model import LLMRequest
|
||||||
|
from src.person_info.person_info import get_person_info_manager
|
||||||
|
from .personality import Personality
|
||||||
|
from .identity import Identity
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
|
|
||||||
@@ -23,7 +24,7 @@ class Individuality:
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# 正常初始化实例属性
|
# 正常初始化实例属性
|
||||||
self.personality: Optional[Personality] = None
|
self.personality: Personality = None # type: ignore
|
||||||
self.identity: Optional[Identity] = None
|
self.identity: Optional[Identity] = None
|
||||||
|
|
||||||
self.name = ""
|
self.name = ""
|
||||||
@@ -109,7 +110,7 @@ class Individuality:
|
|||||||
existing_short_impression = await person_info_manager.get_value(self.bot_person_id, "short_impression")
|
existing_short_impression = await person_info_manager.get_value(self.bot_person_id, "short_impression")
|
||||||
if existing_short_impression:
|
if existing_short_impression:
|
||||||
try:
|
try:
|
||||||
existing_data = ast.literal_eval(existing_short_impression)
|
existing_data = ast.literal_eval(existing_short_impression) # type: ignore
|
||||||
if isinstance(existing_data, list) and len(existing_data) >= 1:
|
if isinstance(existing_data, list) and len(existing_data) >= 1:
|
||||||
personality_result = existing_data[0]
|
personality_result = existing_data[0]
|
||||||
except (json.JSONDecodeError, TypeError, IndexError):
|
except (json.JSONDecodeError, TypeError, IndexError):
|
||||||
@@ -128,7 +129,7 @@ class Individuality:
|
|||||||
existing_short_impression = await person_info_manager.get_value(self.bot_person_id, "short_impression")
|
existing_short_impression = await person_info_manager.get_value(self.bot_person_id, "short_impression")
|
||||||
if existing_short_impression:
|
if existing_short_impression:
|
||||||
try:
|
try:
|
||||||
existing_data = ast.literal_eval(existing_short_impression)
|
existing_data = ast.literal_eval(existing_short_impression) # type: ignore
|
||||||
if isinstance(existing_data, list) and len(existing_data) >= 2:
|
if isinstance(existing_data, list) and len(existing_data) >= 2:
|
||||||
identity_result = existing_data[1]
|
identity_result = existing_data[1]
|
||||||
except (json.JSONDecodeError, TypeError, IndexError):
|
except (json.JSONDecodeError, TypeError, IndexError):
|
||||||
@@ -204,6 +205,7 @@ class Individuality:
|
|||||||
return prompt_personality
|
return prompt_personality
|
||||||
|
|
||||||
def get_identity_prompt(self, level: int, x_person: int = 2) -> str:
|
def get_identity_prompt(self, level: int, x_person: int = 2) -> str:
|
||||||
|
# sourcery skip: assign-if-exp, merge-else-if-into-elif
|
||||||
"""
|
"""
|
||||||
获取身份特征的prompt
|
获取身份特征的prompt
|
||||||
|
|
||||||
@@ -240,13 +242,13 @@ class Individuality:
|
|||||||
|
|
||||||
if identity_parts:
|
if identity_parts:
|
||||||
details_str = ",".join(identity_parts)
|
details_str = ",".join(identity_parts)
|
||||||
if x_person in [1, 2]:
|
if x_person in {1, 2}:
|
||||||
return f"{i_pronoun},{details_str}。"
|
return f"{i_pronoun},{details_str}。"
|
||||||
else: # x_person == 0
|
else: # x_person == 0
|
||||||
# 无人称时,直接返回细节,不加代词和开头的逗号
|
# 无人称时,直接返回细节,不加代词和开头的逗号
|
||||||
return f"{details_str}。"
|
return f"{details_str}。"
|
||||||
else:
|
else:
|
||||||
if x_person in [1, 2]:
|
if x_person in {1, 2}:
|
||||||
return f"{i_pronoun}的身份信息不完整。"
|
return f"{i_pronoun}的身份信息不完整。"
|
||||||
else: # x_person == 0
|
else: # x_person == 0
|
||||||
return "身份信息不完整。"
|
return "身份信息不完整。"
|
||||||
@@ -441,14 +443,15 @@ class Individuality:
|
|||||||
if info_list_json:
|
if info_list_json:
|
||||||
try:
|
try:
|
||||||
info_list = json.loads(info_list_json) if isinstance(info_list_json, str) else info_list_json
|
info_list = json.loads(info_list_json) if isinstance(info_list_json, str) else info_list_json
|
||||||
for item in info_list:
|
keywords.extend(
|
||||||
if isinstance(item, dict) and "info_type" in item:
|
item["info_type"] for item in info_list if isinstance(item, dict) and "info_type" in item
|
||||||
keywords.append(item["info_type"])
|
)
|
||||||
except (json.JSONDecodeError, TypeError):
|
except (json.JSONDecodeError, TypeError):
|
||||||
logger.error(f"解析info_list失败: {info_list_json}")
|
logger.error(f"解析info_list失败: {info_list_json}")
|
||||||
return keywords
|
return keywords
|
||||||
|
|
||||||
async def _create_personality(self, personality_core: str, personality_sides: list) -> str:
|
async def _create_personality(self, personality_core: str, personality_sides: list) -> str:
|
||||||
|
# sourcery skip: merge-list-append, move-assign
|
||||||
"""使用LLM创建压缩版本的impression
|
"""使用LLM创建压缩版本的impression
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from dataclasses import dataclass
|
|
||||||
from typing import Dict, List
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Dict, List, Optional
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
@@ -24,7 +25,7 @@ class Personality:
|
|||||||
cls._instance = super().__new__(cls)
|
cls._instance = super().__new__(cls)
|
||||||
return cls._instance
|
return cls._instance
|
||||||
|
|
||||||
def __init__(self, personality_core: str = "", personality_sides: List[str] = None):
|
def __init__(self, personality_core: str = "", personality_sides: Optional[List[str]] = None):
|
||||||
if personality_sides is None:
|
if personality_sides is None:
|
||||||
personality_sides = []
|
personality_sides = []
|
||||||
self.personality_core = personality_core
|
self.personality_core = personality_core
|
||||||
@@ -41,7 +42,7 @@ class Personality:
|
|||||||
cls._instance = cls()
|
cls._instance = cls()
|
||||||
return cls._instance
|
return cls._instance
|
||||||
|
|
||||||
def _init_big_five_personality(self):
|
def _init_big_five_personality(self): # sourcery skip: extract-method
|
||||||
"""初始化大五人格特质"""
|
"""初始化大五人格特质"""
|
||||||
# 构建文件路径
|
# 构建文件路径
|
||||||
personality_file = Path("data/personality") / f"{self.bot_nickname}_personality.per"
|
personality_file = Path("data/personality") / f"{self.bot_nickname}_personality.per"
|
||||||
@@ -63,7 +64,6 @@ class Personality:
|
|||||||
else:
|
else:
|
||||||
self.extraversion = 0.3
|
self.extraversion = 0.3
|
||||||
self.neuroticism = 0.5
|
self.neuroticism = 0.5
|
||||||
|
|
||||||
if "认真" in self.personality_core or "负责" in self.personality_sides:
|
if "认真" in self.personality_core or "负责" in self.personality_sides:
|
||||||
self.conscientiousness = 0.9
|
self.conscientiousness = 0.9
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -125,7 +125,6 @@ class MainSystem:
|
|||||||
logger.info("个体特征初始化成功")
|
logger.info("个体特征初始化成功")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
init_time = int(1000 * (time.time() - init_start_time))
|
init_time = int(1000 * (time.time() - init_start_time))
|
||||||
logger.info(f"初始化完成,神经元放电{init_time}次")
|
logger.info(f"初始化完成,神经元放电{init_time}次")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -83,10 +83,10 @@ class S4UMessageProcessor:
|
|||||||
# 1. 消息解析与初始化
|
# 1. 消息解析与初始化
|
||||||
groupinfo = message.message_info.group_info
|
groupinfo = message.message_info.group_info
|
||||||
userinfo = message.message_info.user_info
|
userinfo = message.message_info.user_info
|
||||||
messageinfo = message.message_info
|
message_info = message.message_info
|
||||||
|
|
||||||
chat = await get_chat_manager().get_or_create_stream(
|
chat = await get_chat_manager().get_or_create_stream(
|
||||||
platform=messageinfo.platform,
|
platform=message_info.platform,
|
||||||
user_info=userinfo,
|
user_info=userinfo,
|
||||||
group_info=groupinfo,
|
group_info=groupinfo,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -120,12 +120,7 @@ class AsyncTaskManager:
|
|||||||
"""
|
"""
|
||||||
获取所有任务的状态
|
获取所有任务的状态
|
||||||
"""
|
"""
|
||||||
tasks_status = {}
|
return {task_name: {"status": "done" if task.done() else "running"} for task_name, task in self.tasks.items()}
|
||||||
for task_name, task in self.tasks.items():
|
|
||||||
tasks_status[task_name] = {
|
|
||||||
"status": "running" if not task.done() else "done",
|
|
||||||
}
|
|
||||||
return tasks_status
|
|
||||||
|
|
||||||
async def stop_and_wait_all_tasks(self):
|
async def stop_and_wait_all_tasks(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -2,12 +2,12 @@ import math
|
|||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from src.chat.message_receive.message import MessageRecv
|
from src.common.logger import get_logger
|
||||||
from src.llm_models.utils_model import LLMRequest
|
|
||||||
from ..common.logger import get_logger
|
|
||||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
|
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
from src.chat.message_receive.message import MessageRecv
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
|
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
|
||||||
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.manager.async_task_manager import AsyncTask, async_task_manager
|
from src.manager.async_task_manager import AsyncTask, async_task_manager
|
||||||
|
|
||||||
logger = get_logger("mood")
|
logger = get_logger("mood")
|
||||||
@@ -19,7 +19,7 @@ def init_prompt():
|
|||||||
{chat_talking_prompt}
|
{chat_talking_prompt}
|
||||||
以上是群里正在进行的聊天记录
|
以上是群里正在进行的聊天记录
|
||||||
|
|
||||||
{indentify_block}
|
{identity_block}
|
||||||
你刚刚的情绪状态是:{mood_state}
|
你刚刚的情绪状态是:{mood_state}
|
||||||
|
|
||||||
现在,发送了消息,引起了你的注意,你对其进行了阅读和思考,请你输出一句话描述你新的情绪状态
|
现在,发送了消息,引起了你的注意,你对其进行了阅读和思考,请你输出一句话描述你新的情绪状态
|
||||||
@@ -32,7 +32,7 @@ def init_prompt():
|
|||||||
{chat_talking_prompt}
|
{chat_talking_prompt}
|
||||||
以上是群里最近的聊天记录
|
以上是群里最近的聊天记录
|
||||||
|
|
||||||
{indentify_block}
|
{identity_block}
|
||||||
你之前的情绪状态是:{mood_state}
|
你之前的情绪状态是:{mood_state}
|
||||||
|
|
||||||
距离你上次关注群里消息已经过去了一段时间,你冷静了下来,请你输出一句话描述你现在的情绪状态
|
距离你上次关注群里消息已经过去了一段时间,你冷静了下来,请你输出一句话描述你现在的情绪状态
|
||||||
@@ -55,12 +55,12 @@ class ChatMood:
|
|||||||
request_type="mood",
|
request_type="mood",
|
||||||
)
|
)
|
||||||
|
|
||||||
self.last_change_time = 0
|
self.last_change_time: float = 0
|
||||||
|
|
||||||
async def update_mood_by_message(self, message: MessageRecv, interested_rate: float):
|
async def update_mood_by_message(self, message: MessageRecv, interested_rate: float):
|
||||||
self.regression_count = 0
|
self.regression_count = 0
|
||||||
|
|
||||||
during_last_time = message.message_info.time - self.last_change_time
|
during_last_time = message.message_info.time - self.last_change_time # type: ignore
|
||||||
|
|
||||||
base_probability = 0.05
|
base_probability = 0.05
|
||||||
time_multiplier = 4 * (1 - math.exp(-0.01 * during_last_time))
|
time_multiplier = 4 * (1 - math.exp(-0.01 * during_last_time))
|
||||||
@@ -80,7 +80,7 @@ class ChatMood:
|
|||||||
|
|
||||||
logger.info(f"更新情绪状态,感兴趣度: {interested_rate}, 更新概率: {update_probability}")
|
logger.info(f"更新情绪状态,感兴趣度: {interested_rate}, 更新概率: {update_probability}")
|
||||||
|
|
||||||
message_time = message.message_info.time
|
message_time: float = message.message_info.time # type: ignore
|
||||||
message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive(
|
message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||||
chat_id=self.chat_id,
|
chat_id=self.chat_id,
|
||||||
timestamp_start=self.last_change_time,
|
timestamp_start=self.last_change_time,
|
||||||
@@ -105,12 +105,12 @@ class ChatMood:
|
|||||||
bot_nickname = ""
|
bot_nickname = ""
|
||||||
|
|
||||||
prompt_personality = global_config.personality.personality_core
|
prompt_personality = global_config.personality.personality_core
|
||||||
indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:"
|
identity_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:"
|
||||||
|
|
||||||
prompt = await global_prompt_manager.format_prompt(
|
prompt = await global_prompt_manager.format_prompt(
|
||||||
"change_mood_prompt",
|
"change_mood_prompt",
|
||||||
chat_talking_prompt=chat_talking_prompt,
|
chat_talking_prompt=chat_talking_prompt,
|
||||||
indentify_block=indentify_block,
|
identity_block=identity_block,
|
||||||
mood_state=self.mood_state,
|
mood_state=self.mood_state,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -121,7 +121,7 @@ class ChatMood:
|
|||||||
|
|
||||||
self.mood_state = response
|
self.mood_state = response
|
||||||
|
|
||||||
self.last_change_time = message_time
|
self.last_change_time = message_time # type: ignore
|
||||||
|
|
||||||
async def regress_mood(self):
|
async def regress_mood(self):
|
||||||
message_time = time.time()
|
message_time = time.time()
|
||||||
@@ -149,12 +149,12 @@ class ChatMood:
|
|||||||
bot_nickname = ""
|
bot_nickname = ""
|
||||||
|
|
||||||
prompt_personality = global_config.personality.personality_core
|
prompt_personality = global_config.personality.personality_core
|
||||||
indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:"
|
identity_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:"
|
||||||
|
|
||||||
prompt = await global_prompt_manager.format_prompt(
|
prompt = await global_prompt_manager.format_prompt(
|
||||||
"regress_mood_prompt",
|
"regress_mood_prompt",
|
||||||
chat_talking_prompt=chat_talking_prompt,
|
chat_talking_prompt=chat_talking_prompt,
|
||||||
indentify_block=indentify_block,
|
identity_block=identity_block,
|
||||||
mood_state=self.mood_state,
|
mood_state=self.mood_state,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,17 +1,18 @@
|
|||||||
from src.common.logger import get_logger
|
|
||||||
from src.common.database.database import db
|
|
||||||
from src.common.database.database_model import PersonInfo # 新增导入
|
|
||||||
import copy
|
import copy
|
||||||
import hashlib
|
import hashlib
|
||||||
from typing import Any, Callable, Dict
|
|
||||||
import datetime
|
import datetime
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import json
|
||||||
|
|
||||||
|
from json_repair import repair_json
|
||||||
|
from typing import Any, Callable, Dict, Union, Optional
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.common.database.database import db
|
||||||
|
from src.common.database.database_model import PersonInfo
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
||||||
import json # 新增导入
|
|
||||||
from json_repair import repair_json
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
PersonInfoManager 类方法功能摘要:
|
PersonInfoManager 类方法功能摘要:
|
||||||
@@ -42,7 +43,7 @@ person_info_default = {
|
|||||||
"last_know": None,
|
"last_know": None,
|
||||||
# "user_cardname": None, # This field is not in Peewee model PersonInfo
|
# "user_cardname": None, # This field is not in Peewee model PersonInfo
|
||||||
# "user_avatar": None, # This field is not in Peewee model PersonInfo
|
# "user_avatar": None, # This field is not in Peewee model PersonInfo
|
||||||
"impression": None, # Corrected from persion_impression
|
"impression": None, # Corrected from person_impression
|
||||||
"short_impression": None,
|
"short_impression": None,
|
||||||
"info_list": None,
|
"info_list": None,
|
||||||
"points": None,
|
"points": None,
|
||||||
@@ -84,7 +85,7 @@ class PersonInfoManager:
|
|||||||
logger.error(f"从 Peewee 加载 person_name_list 失败: {e}")
|
logger.error(f"从 Peewee 加载 person_name_list 失败: {e}")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_person_id(platform: str, user_id: int):
|
def get_person_id(platform: str, user_id: Union[int, str]) -> str:
|
||||||
"""获取唯一id"""
|
"""获取唯一id"""
|
||||||
if "-" in platform:
|
if "-" in platform:
|
||||||
platform = platform.split("-")[1]
|
platform = platform.split("-")[1]
|
||||||
@@ -106,27 +107,24 @@ class PersonInfoManager:
|
|||||||
logger.error(f"检查用户 {person_id} 是否已知时出错 (Peewee): {e}")
|
logger.error(f"检查用户 {person_id} 是否已知时出错 (Peewee): {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def get_person_id_by_person_name(self, person_name: str):
|
def get_person_id_by_person_name(self, person_name: str) -> str:
|
||||||
"""根据用户名获取用户ID"""
|
"""根据用户名获取用户ID"""
|
||||||
try:
|
try:
|
||||||
record = PersonInfo.get_or_none(PersonInfo.person_name == person_name)
|
record = PersonInfo.get_or_none(PersonInfo.person_name == person_name)
|
||||||
if record:
|
return record.person_id if record else ""
|
||||||
return record.person_id
|
|
||||||
else:
|
|
||||||
return ""
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (Peewee): {e}")
|
logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (Peewee): {e}")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def create_person_info(person_id: str, data: dict = None):
|
async def create_person_info(person_id: str, data: Optional[dict] = None):
|
||||||
"""创建一个项"""
|
"""创建一个项"""
|
||||||
if not person_id:
|
if not person_id:
|
||||||
logger.debug("创建失败,personid不存在")
|
logger.debug("创建失败,person_id不存在")
|
||||||
return
|
return
|
||||||
|
|
||||||
_person_info_default = copy.deepcopy(person_info_default)
|
_person_info_default = copy.deepcopy(person_info_default)
|
||||||
model_fields = PersonInfo._meta.fields.keys()
|
model_fields = PersonInfo._meta.fields.keys() # type: ignore
|
||||||
|
|
||||||
final_data = {"person_id": person_id}
|
final_data = {"person_id": person_id}
|
||||||
|
|
||||||
@@ -163,9 +161,9 @@ class PersonInfoManager:
|
|||||||
|
|
||||||
await asyncio.to_thread(_db_create_sync, final_data)
|
await asyncio.to_thread(_db_create_sync, final_data)
|
||||||
|
|
||||||
async def update_one_field(self, person_id: str, field_name: str, value, data: dict = None):
|
async def update_one_field(self, person_id: str, field_name: str, value, data: Optional[Dict] = None):
|
||||||
"""更新某一个字段,会补全"""
|
"""更新某一个字段,会补全"""
|
||||||
if field_name not in PersonInfo._meta.fields:
|
if field_name not in PersonInfo._meta.fields: # type: ignore
|
||||||
logger.debug(f"更新'{field_name}'失败,未在 PersonInfo Peewee 模型中定义的字段。")
|
logger.debug(f"更新'{field_name}'失败,未在 PersonInfo Peewee 模型中定义的字段。")
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -228,15 +226,13 @@ class PersonInfoManager:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
async def has_one_field(person_id: str, field_name: str):
|
async def has_one_field(person_id: str, field_name: str):
|
||||||
"""判断是否存在某一个字段"""
|
"""判断是否存在某一个字段"""
|
||||||
if field_name not in PersonInfo._meta.fields:
|
if field_name not in PersonInfo._meta.fields: # type: ignore
|
||||||
logger.debug(f"检查字段'{field_name}'失败,未在 PersonInfo Peewee 模型中定义。")
|
logger.debug(f"检查字段'{field_name}'失败,未在 PersonInfo Peewee 模型中定义。")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _db_has_field_sync(p_id: str, f_name: str):
|
def _db_has_field_sync(p_id: str, f_name: str):
|
||||||
record = PersonInfo.get_or_none(PersonInfo.person_id == p_id)
|
record = PersonInfo.get_or_none(PersonInfo.person_id == p_id)
|
||||||
if record:
|
return bool(record)
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return await asyncio.to_thread(_db_has_field_sync, person_id, field_name)
|
return await asyncio.to_thread(_db_has_field_sync, person_id, field_name)
|
||||||
@@ -435,9 +431,7 @@ class PersonInfoManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取字段 {field_name} for {person_id} 时出错 (Peewee): {e}")
|
logger.error(f"获取字段 {field_name} for {person_id} 时出错 (Peewee): {e}")
|
||||||
# Fallback to default in case of any error during DB access
|
# Fallback to default in case of any error during DB access
|
||||||
if field_name in person_info_default:
|
return default_value_for_field if field_name in person_info_default else None
|
||||||
return default_value_for_field
|
|
||||||
return None
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_value_sync(person_id: str, field_name: str):
|
def get_value_sync(person_id: str, field_name: str):
|
||||||
@@ -446,8 +440,7 @@ class PersonInfoManager:
|
|||||||
if field_name in JSON_SERIALIZED_FIELDS and default_value_for_field is None:
|
if field_name in JSON_SERIALIZED_FIELDS and default_value_for_field is None:
|
||||||
default_value_for_field = []
|
default_value_for_field = []
|
||||||
|
|
||||||
record = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
|
if record := PersonInfo.get_or_none(PersonInfo.person_id == person_id):
|
||||||
if record:
|
|
||||||
val = getattr(record, field_name, None)
|
val = getattr(record, field_name, None)
|
||||||
if field_name in JSON_SERIALIZED_FIELDS:
|
if field_name in JSON_SERIALIZED_FIELDS:
|
||||||
if isinstance(val, str):
|
if isinstance(val, str):
|
||||||
@@ -481,7 +474,7 @@ class PersonInfoManager:
|
|||||||
record = await asyncio.to_thread(_db_get_record_sync, person_id)
|
record = await asyncio.to_thread(_db_get_record_sync, person_id)
|
||||||
|
|
||||||
for field_name in field_names:
|
for field_name in field_names:
|
||||||
if field_name not in PersonInfo._meta.fields:
|
if field_name not in PersonInfo._meta.fields: # type: ignore
|
||||||
if field_name in person_info_default:
|
if field_name in person_info_default:
|
||||||
result[field_name] = copy.deepcopy(person_info_default[field_name])
|
result[field_name] = copy.deepcopy(person_info_default[field_name])
|
||||||
logger.debug(f"字段'{field_name}'不在Peewee模型中,使用默认配置值。")
|
logger.debug(f"字段'{field_name}'不在Peewee模型中,使用默认配置值。")
|
||||||
@@ -509,7 +502,7 @@ class PersonInfoManager:
|
|||||||
"""
|
"""
|
||||||
获取满足条件的字段值字典
|
获取满足条件的字段值字典
|
||||||
"""
|
"""
|
||||||
if field_name not in PersonInfo._meta.fields:
|
if field_name not in PersonInfo._meta.fields: # type: ignore
|
||||||
logger.error(f"字段检查失败:'{field_name}'未在 PersonInfo Peewee 模型中定义")
|
logger.error(f"字段检查失败:'{field_name}'未在 PersonInfo Peewee 模型中定义")
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
@@ -531,7 +524,7 @@ class PersonInfoManager:
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
async def get_or_create_person(
|
async def get_or_create_person(
|
||||||
self, platform: str, user_id: int, nickname: str = None, user_cardname: str = None, user_avatar: str = None
|
self, platform: str, user_id: int, nickname: str, user_cardname: str, user_avatar: Optional[str] = None
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
根据 platform 和 user_id 获取 person_id。
|
根据 platform 和 user_id 获取 person_id。
|
||||||
@@ -561,7 +554,7 @@ class PersonInfoManager:
|
|||||||
"points": [],
|
"points": [],
|
||||||
"forgotten_points": [],
|
"forgotten_points": [],
|
||||||
}
|
}
|
||||||
model_fields = PersonInfo._meta.fields.keys()
|
model_fields = PersonInfo._meta.fields.keys() # type: ignore
|
||||||
filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields}
|
filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields}
|
||||||
|
|
||||||
await self.create_person_info(person_id, data=filtered_initial_data)
|
await self.create_person_info(person_id, data=filtered_initial_data)
|
||||||
@@ -610,7 +603,9 @@ class PersonInfoManager:
|
|||||||
"name_reason",
|
"name_reason",
|
||||||
]
|
]
|
||||||
valid_fields_to_get = [
|
valid_fields_to_get = [
|
||||||
f for f in required_fields if f in PersonInfo._meta.fields or f in person_info_default
|
f
|
||||||
|
for f in required_fields
|
||||||
|
if f in PersonInfo._meta.fields or f in person_info_default # type: ignore
|
||||||
]
|
]
|
||||||
|
|
||||||
person_data = await self.get_values(found_person_id, valid_fields_to_get)
|
person_data = await self.get_values(found_person_id, valid_fields_to_get)
|
||||||
|
|||||||
@@ -3,12 +3,12 @@ import traceback
|
|||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
import random
|
import random
|
||||||
from typing import List, Dict
|
from typing import List, Dict, Any
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
|
||||||
from src.person_info.relationship_manager import get_relationship_manager
|
from src.person_info.relationship_manager import get_relationship_manager
|
||||||
from src.person_info.person_info import get_person_info_manager, PersonInfoManager
|
from src.person_info.person_info import get_person_info_manager, PersonInfoManager
|
||||||
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
from src.chat.utils.chat_message_builder import (
|
from src.chat.utils.chat_message_builder import (
|
||||||
get_raw_msg_by_timestamp_with_chat,
|
get_raw_msg_by_timestamp_with_chat,
|
||||||
get_raw_msg_by_timestamp_with_chat_inclusive,
|
get_raw_msg_by_timestamp_with_chat_inclusive,
|
||||||
@@ -45,7 +45,7 @@ class RelationshipBuilder:
|
|||||||
self.chat_id = chat_id
|
self.chat_id = chat_id
|
||||||
# 新的消息段缓存结构:
|
# 新的消息段缓存结构:
|
||||||
# {person_id: [{"start_time": float, "end_time": float, "last_msg_time": float, "message_count": int}, ...]}
|
# {person_id: [{"start_time": float, "end_time": float, "last_msg_time": float, "message_count": int}, ...]}
|
||||||
self.person_engaged_cache: Dict[str, List[Dict[str, any]]] = {}
|
self.person_engaged_cache: Dict[str, List[Dict[str, Any]]] = {}
|
||||||
|
|
||||||
# 持久化存储文件路径
|
# 持久化存储文件路径
|
||||||
self.cache_file_path = os.path.join("data", "relationship", f"relationship_cache_{self.chat_id}.pkl")
|
self.cache_file_path = os.path.join("data", "relationship", f"relationship_cache_{self.chat_id}.pkl")
|
||||||
@@ -210,11 +210,7 @@ class RelationshipBuilder:
|
|||||||
if person_id not in self.person_engaged_cache:
|
if person_id not in self.person_engaged_cache:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
total_count = 0
|
return sum(segment["message_count"] for segment in self.person_engaged_cache[person_id])
|
||||||
for segment in self.person_engaged_cache[person_id]:
|
|
||||||
total_count += segment["message_count"]
|
|
||||||
|
|
||||||
return total_count
|
|
||||||
|
|
||||||
def _cleanup_old_segments(self) -> bool:
|
def _cleanup_old_segments(self) -> bool:
|
||||||
"""清理老旧的消息段"""
|
"""清理老旧的消息段"""
|
||||||
@@ -289,7 +285,7 @@ class RelationshipBuilder:
|
|||||||
self.last_cleanup_time = current_time
|
self.last_cleanup_time = current_time
|
||||||
|
|
||||||
# 保存缓存
|
# 保存缓存
|
||||||
if cleanup_stats["segments_removed"] > 0 or len(users_to_remove) > 0:
|
if cleanup_stats["segments_removed"] > 0 or users_to_remove:
|
||||||
self._save_cache()
|
self._save_cache()
|
||||||
logger.info(
|
logger.info(
|
||||||
f"{self.log_prefix} 清理完成 - 影响用户: {cleanup_stats['users_cleaned']}, 移除消息段: {cleanup_stats['segments_removed']}, 移除用户: {len(users_to_remove)}"
|
f"{self.log_prefix} 清理完成 - 影响用户: {cleanup_stats['users_cleaned']}, 移除消息段: {cleanup_stats['segments_removed']}, 移除用户: {len(users_to_remove)}"
|
||||||
@@ -313,6 +309,7 @@ class RelationshipBuilder:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def get_cache_status(self) -> str:
|
def get_cache_status(self) -> str:
|
||||||
|
# sourcery skip: merge-list-append, merge-list-appends-into-extend
|
||||||
"""获取缓存状态信息,用于调试和监控"""
|
"""获取缓存状态信息,用于调试和监控"""
|
||||||
if not self.person_engaged_cache:
|
if not self.person_engaged_cache:
|
||||||
return f"{self.log_prefix} 关系缓存为空"
|
return f"{self.log_prefix} 关系缓存为空"
|
||||||
@@ -357,13 +354,12 @@ class RelationshipBuilder:
|
|||||||
self._cleanup_old_segments()
|
self._cleanup_old_segments()
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
|
|
||||||
latest_messages = get_raw_msg_by_timestamp_with_chat(
|
if latest_messages := get_raw_msg_by_timestamp_with_chat(
|
||||||
self.chat_id,
|
self.chat_id,
|
||||||
self.last_processed_message_time,
|
self.last_processed_message_time,
|
||||||
current_time,
|
current_time,
|
||||||
limit=50, # 获取自上次处理后的消息
|
limit=50, # 获取自上次处理后的消息
|
||||||
)
|
):
|
||||||
if latest_messages:
|
|
||||||
# 处理所有新的非bot消息
|
# 处理所有新的非bot消息
|
||||||
for latest_msg in latest_messages:
|
for latest_msg in latest_messages:
|
||||||
user_id = latest_msg.get("user_id")
|
user_id = latest_msg.get("user_id")
|
||||||
@@ -414,7 +410,7 @@ class RelationshipBuilder:
|
|||||||
# 负责触发关系构建、整合消息段、更新用户印象
|
# 负责触发关系构建、整合消息段、更新用户印象
|
||||||
# ================================
|
# ================================
|
||||||
|
|
||||||
async def update_impression_on_segments(self, person_id: str, chat_id: str, segments: List[Dict[str, any]]):
|
async def update_impression_on_segments(self, person_id: str, chat_id: str, segments: List[Dict[str, Any]]):
|
||||||
"""基于消息段更新用户印象"""
|
"""基于消息段更新用户印象"""
|
||||||
original_segment_count = len(segments)
|
original_segment_count = len(segments)
|
||||||
logger.debug(f"开始为 {person_id} 基于 {original_segment_count} 个消息段更新印象")
|
logger.debug(f"开始为 {person_id} 基于 {original_segment_count} 个消息段更新印象")
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
from typing import Dict, Optional, List
|
from typing import Dict, Optional, List, Any
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from .relationship_builder import RelationshipBuilder
|
from .relationship_builder import RelationshipBuilder
|
||||||
|
|
||||||
@@ -63,7 +64,7 @@ class RelationshipBuilderManager:
|
|||||||
"""
|
"""
|
||||||
return list(self.builders.keys())
|
return list(self.builders.keys())
|
||||||
|
|
||||||
def get_status(self) -> Dict[str, any]:
|
def get_status(self) -> Dict[str, Any]:
|
||||||
"""获取管理器状态
|
"""获取管理器状态
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -94,9 +95,7 @@ class RelationshipBuilderManager:
|
|||||||
bool: 是否成功清理
|
bool: 是否成功清理
|
||||||
"""
|
"""
|
||||||
builder = self.get_builder(chat_id)
|
builder = self.get_builder(chat_id)
|
||||||
if builder:
|
return builder.force_cleanup_user_segments(person_id) if builder else False
|
||||||
return builder.force_cleanup_user_segments(person_id)
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
# 全局管理器实例
|
# 全局管理器实例
|
||||||
|
|||||||
@@ -1,16 +1,19 @@
|
|||||||
from src.config.config import global_config
|
|
||||||
from src.llm_models.utils_model import LLMRequest
|
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
|
||||||
from src.person_info.person_info import get_person_info_manager
|
|
||||||
from typing import List, Dict
|
|
||||||
from json_repair import repair_json
|
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
|
||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
|
|
||||||
|
from typing import List, Dict, Any
|
||||||
|
from json_repair import repair_json
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.config.config import global_config
|
||||||
|
from src.llm_models.utils_model import LLMRequest
|
||||||
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
|
from src.person_info.person_info import get_person_info_manager
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger("relationship_fetcher")
|
logger = get_logger("relationship_fetcher")
|
||||||
|
|
||||||
|
|
||||||
@@ -62,11 +65,11 @@ class RelationshipFetcher:
|
|||||||
self.chat_id = chat_id
|
self.chat_id = chat_id
|
||||||
|
|
||||||
# 信息获取缓存:记录正在获取的信息请求
|
# 信息获取缓存:记录正在获取的信息请求
|
||||||
self.info_fetching_cache: List[Dict[str, any]] = []
|
self.info_fetching_cache: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
# 信息结果缓存:存储已获取的信息结果,带TTL
|
# 信息结果缓存:存储已获取的信息结果,带TTL
|
||||||
self.info_fetched_cache: Dict[str, Dict[str, any]] = {}
|
self.info_fetched_cache: Dict[str, Dict[str, Any]] = {}
|
||||||
# 结构:{person_id: {info_type: {"info": str, "ttl": int, "start_time": float, "person_name": str, "unknow": bool}}}
|
# 结构:{person_id: {info_type: {"info": str, "ttl": int, "start_time": float, "person_name": str, "unknown": bool}}}
|
||||||
|
|
||||||
# LLM模型配置
|
# LLM模型配置
|
||||||
self.llm_model = LLMRequest(
|
self.llm_model = LLMRequest(
|
||||||
@@ -184,7 +187,7 @@ class RelationshipFetcher:
|
|||||||
nickname_str = ",".join(global_config.bot.alias_names)
|
nickname_str = ",".join(global_config.bot.alias_names)
|
||||||
name_block = f"你的名字是{global_config.bot.nickname},你的昵称有{nickname_str},有人也会用这些昵称称呼你。"
|
name_block = f"你的名字是{global_config.bot.nickname},你的昵称有{nickname_str},有人也会用这些昵称称呼你。"
|
||||||
person_info_manager = get_person_info_manager()
|
person_info_manager = get_person_info_manager()
|
||||||
person_name = await person_info_manager.get_value(person_id, "person_name")
|
person_name: str = await person_info_manager.get_value(person_id, "person_name") # type: ignore
|
||||||
|
|
||||||
info_cache_block = self._build_info_cache_block()
|
info_cache_block = self._build_info_cache_block()
|
||||||
|
|
||||||
@@ -208,8 +211,7 @@ class RelationshipFetcher:
|
|||||||
logger.debug(f"{self.log_prefix} LLM判断当前不需要查询任何信息:{content_json.get('none', '')}")
|
logger.debug(f"{self.log_prefix} LLM判断当前不需要查询任何信息:{content_json.get('none', '')}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
info_type = content_json.get("info_type")
|
if info_type := content_json.get("info_type"):
|
||||||
if info_type:
|
|
||||||
# 记录信息获取请求
|
# 记录信息获取请求
|
||||||
self.info_fetching_cache.append(
|
self.info_fetching_cache.append(
|
||||||
{
|
{
|
||||||
@@ -287,7 +289,7 @@ class RelationshipFetcher:
|
|||||||
"ttl": 2,
|
"ttl": 2,
|
||||||
"start_time": start_time,
|
"start_time": start_time,
|
||||||
"person_name": person_name,
|
"person_name": person_name,
|
||||||
"unknow": cached_info == "none",
|
"unknown": cached_info == "none",
|
||||||
}
|
}
|
||||||
logger.info(f"{self.log_prefix} 记得 {person_name} 的 {info_type}: {cached_info}")
|
logger.info(f"{self.log_prefix} 记得 {person_name} 的 {info_type}: {cached_info}")
|
||||||
return
|
return
|
||||||
@@ -321,7 +323,7 @@ class RelationshipFetcher:
|
|||||||
"ttl": 2,
|
"ttl": 2,
|
||||||
"start_time": start_time,
|
"start_time": start_time,
|
||||||
"person_name": person_name,
|
"person_name": person_name,
|
||||||
"unknow": True,
|
"unknown": True,
|
||||||
}
|
}
|
||||||
logger.info(f"{self.log_prefix} 完全不认识 {person_name}")
|
logger.info(f"{self.log_prefix} 完全不认识 {person_name}")
|
||||||
await self._save_info_to_cache(person_id, info_type, "none")
|
await self._save_info_to_cache(person_id, info_type, "none")
|
||||||
@@ -353,15 +355,15 @@ class RelationshipFetcher:
|
|||||||
if person_id not in self.info_fetched_cache:
|
if person_id not in self.info_fetched_cache:
|
||||||
self.info_fetched_cache[person_id] = {}
|
self.info_fetched_cache[person_id] = {}
|
||||||
self.info_fetched_cache[person_id][info_type] = {
|
self.info_fetched_cache[person_id][info_type] = {
|
||||||
"info": "unknow" if is_unknown else info_content,
|
"info": "unknown" if is_unknown else info_content,
|
||||||
"ttl": 3,
|
"ttl": 3,
|
||||||
"start_time": start_time,
|
"start_time": start_time,
|
||||||
"person_name": person_name,
|
"person_name": person_name,
|
||||||
"unknow": is_unknown,
|
"unknown": is_unknown,
|
||||||
}
|
}
|
||||||
|
|
||||||
# 保存到持久化缓存 (info_list)
|
# 保存到持久化缓存 (info_list)
|
||||||
await self._save_info_to_cache(person_id, info_type, info_content if not is_unknown else "none")
|
await self._save_info_to_cache(person_id, info_type, "none" if is_unknown else info_content)
|
||||||
|
|
||||||
if not is_unknown:
|
if not is_unknown:
|
||||||
logger.info(f"{self.log_prefix} 思考得到,{person_name} 的 {info_type}: {info_content}")
|
logger.info(f"{self.log_prefix} 思考得到,{person_name} 的 {info_type}: {info_content}")
|
||||||
@@ -393,7 +395,7 @@ class RelationshipFetcher:
|
|||||||
|
|
||||||
for info_type in self.info_fetched_cache[person_id]:
|
for info_type in self.info_fetched_cache[person_id]:
|
||||||
person_name = self.info_fetched_cache[person_id][info_type]["person_name"]
|
person_name = self.info_fetched_cache[person_id][info_type]["person_name"]
|
||||||
if not self.info_fetched_cache[person_id][info_type]["unknow"]:
|
if not self.info_fetched_cache[person_id][info_type]["unknown"]:
|
||||||
info_content = self.info_fetched_cache[person_id][info_type]["info"]
|
info_content = self.info_fetched_cache[person_id][info_type]["info"]
|
||||||
person_known_infos.append(f"[{info_type}]:{info_content}")
|
person_known_infos.append(f"[{info_type}]:{info_content}")
|
||||||
else:
|
else:
|
||||||
@@ -430,6 +432,7 @@ class RelationshipFetcher:
|
|||||||
return persons_infos_str
|
return persons_infos_str
|
||||||
|
|
||||||
async def _save_info_to_cache(self, person_id: str, info_type: str, info_content: str):
|
async def _save_info_to_cache(self, person_id: str, info_type: str, info_content: str):
|
||||||
|
# sourcery skip: use-next
|
||||||
"""将提取到的信息保存到 person_info 的 info_list 字段中
|
"""将提取到的信息保存到 person_info 的 info_list 字段中
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
from .person_info import PersonInfoManager, get_person_info_manager
|
||||||
import time
|
import time
|
||||||
import random
|
import random
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
@@ -12,7 +12,7 @@ from difflib import SequenceMatcher
|
|||||||
import jieba
|
import jieba
|
||||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||||
from sklearn.metrics.pairwise import cosine_similarity
|
from sklearn.metrics.pairwise import cosine_similarity
|
||||||
|
from typing import List, Dict, Any
|
||||||
|
|
||||||
logger = get_logger("relation")
|
logger = get_logger("relation")
|
||||||
|
|
||||||
@@ -28,8 +28,7 @@ class RelationshipManager:
|
|||||||
async def is_known_some_one(platform, user_id):
|
async def is_known_some_one(platform, user_id):
|
||||||
"""判断是否认识某人"""
|
"""判断是否认识某人"""
|
||||||
person_info_manager = get_person_info_manager()
|
person_info_manager = get_person_info_manager()
|
||||||
is_known = await person_info_manager.is_person_known(platform, user_id)
|
return await person_info_manager.is_person_known(platform, user_id)
|
||||||
return is_known
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def first_knowing_some_one(platform: str, user_id: str, user_nickname: str, user_cardname: str):
|
async def first_knowing_some_one(platform: str, user_id: str, user_nickname: str, user_cardname: str):
|
||||||
@@ -110,7 +109,7 @@ class RelationshipManager:
|
|||||||
|
|
||||||
return relation_prompt
|
return relation_prompt
|
||||||
|
|
||||||
async def update_person_impression(self, person_id, timestamp, bot_engaged_messages=None):
|
async def update_person_impression(self, person_id, timestamp, bot_engaged_messages: List[Dict[str, Any]]):
|
||||||
"""更新用户印象
|
"""更新用户印象
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -123,7 +122,7 @@ class RelationshipManager:
|
|||||||
person_info_manager = get_person_info_manager()
|
person_info_manager = get_person_info_manager()
|
||||||
person_name = await person_info_manager.get_value(person_id, "person_name")
|
person_name = await person_info_manager.get_value(person_id, "person_name")
|
||||||
nickname = await person_info_manager.get_value(person_id, "nickname")
|
nickname = await person_info_manager.get_value(person_id, "nickname")
|
||||||
know_times = await person_info_manager.get_value(person_id, "know_times") or 0
|
know_times: float = await person_info_manager.get_value(person_id, "know_times") or 0 # type: ignore
|
||||||
|
|
||||||
alias_str = ", ".join(global_config.bot.alias_names)
|
alias_str = ", ".join(global_config.bot.alias_names)
|
||||||
# personality_block =get_individuality().get_personality_prompt(x_person=2, level=2)
|
# personality_block =get_individuality().get_personality_prompt(x_person=2, level=2)
|
||||||
@@ -142,13 +141,13 @@ class RelationshipManager:
|
|||||||
# 遍历消息,构建映射
|
# 遍历消息,构建映射
|
||||||
for msg in user_messages:
|
for msg in user_messages:
|
||||||
await person_info_manager.get_or_create_person(
|
await person_info_manager.get_or_create_person(
|
||||||
platform=msg.get("chat_info_platform"),
|
platform=msg.get("chat_info_platform"), # type: ignore
|
||||||
user_id=msg.get("user_id"),
|
user_id=msg.get("user_id"), # type: ignore
|
||||||
nickname=msg.get("user_nickname"),
|
nickname=msg.get("user_nickname"), # type: ignore
|
||||||
user_cardname=msg.get("user_cardname"),
|
user_cardname=msg.get("user_cardname"), # type: ignore
|
||||||
)
|
)
|
||||||
replace_user_id = msg.get("user_id")
|
replace_user_id: str = msg.get("user_id") # type: ignore
|
||||||
replace_platform = msg.get("chat_info_platform")
|
replace_platform: str = msg.get("chat_info_platform") # type: ignore
|
||||||
replace_person_id = PersonInfoManager.get_person_id(replace_platform, replace_user_id)
|
replace_person_id = PersonInfoManager.get_person_id(replace_platform, replace_user_id)
|
||||||
replace_person_name = await person_info_manager.get_value(replace_person_id, "person_name")
|
replace_person_name = await person_info_manager.get_value(replace_person_id, "person_name")
|
||||||
|
|
||||||
@@ -354,8 +353,8 @@ class RelationshipManager:
|
|||||||
|
|
||||||
person_name = await person_info_manager.get_value(person_id, "person_name")
|
person_name = await person_info_manager.get_value(person_id, "person_name")
|
||||||
nickname = await person_info_manager.get_value(person_id, "nickname")
|
nickname = await person_info_manager.get_value(person_id, "nickname")
|
||||||
know_times = await person_info_manager.get_value(person_id, "know_times") or 0
|
know_times: float = await person_info_manager.get_value(person_id, "know_times") or 0 # type: ignore
|
||||||
attitude = await person_info_manager.get_value(person_id, "attitude") or 50
|
attitude: float = await person_info_manager.get_value(person_id, "attitude") or 50 # type: ignore
|
||||||
|
|
||||||
# 根据熟悉度,调整印象和简短印象的最大长度
|
# 根据熟悉度,调整印象和简短印象的最大长度
|
||||||
if know_times > 300:
|
if know_times > 300:
|
||||||
@@ -414,9 +413,7 @@ class RelationshipManager:
|
|||||||
if len(remaining_points) < 10:
|
if len(remaining_points) < 10:
|
||||||
# 如果还没达到30条,直接保留
|
# 如果还没达到30条,直接保留
|
||||||
remaining_points.append(point)
|
remaining_points.append(point)
|
||||||
else:
|
elif random.random() < keep_probability:
|
||||||
# 随机决定是否保留
|
|
||||||
if random.random() < keep_probability:
|
|
||||||
# 保留这个点,随机移除一个已保留的点
|
# 保留这个点,随机移除一个已保留的点
|
||||||
idx_to_remove = random.randrange(len(remaining_points))
|
idx_to_remove = random.randrange(len(remaining_points))
|
||||||
points_to_move.append(remaining_points[idx_to_remove])
|
points_to_move.append(remaining_points[idx_to_remove])
|
||||||
@@ -520,7 +517,7 @@ class RelationshipManager:
|
|||||||
new_attitude = int(relation_value_json.get("attitude", 50))
|
new_attitude = int(relation_value_json.get("attitude", 50))
|
||||||
|
|
||||||
# 获取当前的关系值
|
# 获取当前的关系值
|
||||||
old_attitude = await person_info_manager.get_value(person_id, "attitude") or 50
|
old_attitude: float = await person_info_manager.get_value(person_id, "attitude") or 50 # type: ignore
|
||||||
|
|
||||||
# 更新熟悉度
|
# 更新熟悉度
|
||||||
if new_attitude > 25:
|
if new_attitude > 25:
|
||||||
|
|||||||
@@ -5,11 +5,11 @@ MaiBot 插件系统
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# 导出主要的公共接口
|
# 导出主要的公共接口
|
||||||
from src.plugin_system.base.base_plugin import BasePlugin, register_plugin
|
from .base import (
|
||||||
from src.plugin_system.base.base_action import BaseAction
|
BasePlugin,
|
||||||
from src.plugin_system.base.base_command import BaseCommand
|
BaseAction,
|
||||||
from src.plugin_system.base.config_types import ConfigField
|
BaseCommand,
|
||||||
from src.plugin_system.base.component_types import (
|
ConfigField,
|
||||||
ComponentType,
|
ComponentType,
|
||||||
ActionActivationType,
|
ActionActivationType,
|
||||||
ChatMode,
|
ChatMode,
|
||||||
@@ -19,18 +19,22 @@ from src.plugin_system.base.component_types import (
|
|||||||
PluginInfo,
|
PluginInfo,
|
||||||
PythonDependency,
|
PythonDependency,
|
||||||
)
|
)
|
||||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
from .core.plugin_manager import (
|
||||||
from src.plugin_system.core.component_registry import component_registry
|
plugin_manager,
|
||||||
from src.plugin_system.core.dependency_manager import dependency_manager
|
component_registry,
|
||||||
|
dependency_manager,
|
||||||
|
)
|
||||||
|
|
||||||
# 导入工具模块
|
# 导入工具模块
|
||||||
from src.plugin_system.utils import (
|
from .utils import (
|
||||||
ManifestValidator,
|
ManifestValidator,
|
||||||
ManifestGenerator,
|
ManifestGenerator,
|
||||||
validate_plugin_manifest,
|
validate_plugin_manifest,
|
||||||
generate_plugin_manifest,
|
generate_plugin_manifest,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from .apis.plugin_register_api import register_plugin
|
||||||
|
|
||||||
|
|
||||||
__version__ = "1.0.0"
|
__version__ = "1.0.0"
|
||||||
|
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from src.plugin_system.apis import (
|
|||||||
person_api,
|
person_api,
|
||||||
send_api,
|
send_api,
|
||||||
utils_api,
|
utils_api,
|
||||||
|
plugin_register_api,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 导出所有API模块,使它们可以通过 apis.xxx 方式访问
|
# 导出所有API模块,使它们可以通过 apis.xxx 方式访问
|
||||||
@@ -30,4 +31,5 @@ __all__ = [
|
|||||||
"person_api",
|
"person_api",
|
||||||
"send_api",
|
"send_api",
|
||||||
"utils_api",
|
"utils_api",
|
||||||
|
"plugin_register_api",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ from src.chat.replyer.default_generator import DefaultReplyer
|
|||||||
from src.chat.message_receive.chat_stream import ChatStream
|
from src.chat.message_receive.chat_stream import ChatStream
|
||||||
from src.chat.utils.utils import process_llm_response
|
from src.chat.utils.utils import process_llm_response
|
||||||
from src.chat.replyer.replyer_manager import replyer_manager
|
from src.chat.replyer.replyer_manager import replyer_manager
|
||||||
|
from src.plugin_system.base.component_types import ActionInfo
|
||||||
|
|
||||||
logger = get_logger("generator_api")
|
logger = get_logger("generator_api")
|
||||||
|
|
||||||
@@ -64,12 +65,12 @@ def get_replyer(
|
|||||||
|
|
||||||
|
|
||||||
async def generate_reply(
|
async def generate_reply(
|
||||||
chat_stream=None,
|
chat_stream: Optional[ChatStream] = None,
|
||||||
chat_id: str = None,
|
chat_id: Optional[str] = None,
|
||||||
action_data: Dict[str, Any] = None,
|
action_data: Optional[Dict[str, Any]] = None,
|
||||||
reply_to: str = "",
|
reply_to: str = "",
|
||||||
extra_info: str = "",
|
extra_info: str = "",
|
||||||
available_actions: List[str] = None,
|
available_actions: Optional[Dict[str, ActionInfo]] = None,
|
||||||
enable_tool: bool = False,
|
enable_tool: bool = False,
|
||||||
enable_splitter: bool = True,
|
enable_splitter: bool = True,
|
||||||
enable_chinese_typo: bool = True,
|
enable_chinese_typo: bool = True,
|
||||||
@@ -77,25 +78,25 @@ async def generate_reply(
|
|||||||
model_configs: Optional[List[Dict[str, Any]]] = None,
|
model_configs: Optional[List[Dict[str, Any]]] = None,
|
||||||
request_type: str = "",
|
request_type: str = "",
|
||||||
enable_timeout: bool = False,
|
enable_timeout: bool = False,
|
||||||
) -> Tuple[bool, List[Tuple[str, Any]]]:
|
) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]:
|
||||||
"""生成回复
|
"""生成回复
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
chat_stream: 聊天流对象(优先)
|
chat_stream: 聊天流对象(优先)
|
||||||
action_data: 动作数据
|
|
||||||
chat_id: 聊天ID(备用)
|
chat_id: 聊天ID(备用)
|
||||||
|
action_data: 动作数据
|
||||||
enable_splitter: 是否启用消息分割器
|
enable_splitter: 是否启用消息分割器
|
||||||
enable_chinese_typo: 是否启用错字生成器
|
enable_chinese_typo: 是否启用错字生成器
|
||||||
return_prompt: 是否返回提示词
|
return_prompt: 是否返回提示词
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[bool, List[Tuple[str, Any]]]: (是否成功, 回复集合)
|
Tuple[bool, List[Tuple[str, Any]], Optional[str]]: (是否成功, 回复集合, 提示词)
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 获取回复器
|
# 获取回复器
|
||||||
replyer = get_replyer(chat_stream, chat_id, model_configs=model_configs, request_type=request_type)
|
replyer = get_replyer(chat_stream, chat_id, model_configs=model_configs, request_type=request_type)
|
||||||
if not replyer:
|
if not replyer:
|
||||||
logger.error("[GeneratorAPI] 无法获取回复器")
|
logger.error("[GeneratorAPI] 无法获取回复器")
|
||||||
return False, []
|
return False, [], None
|
||||||
|
|
||||||
logger.debug("[GeneratorAPI] 开始生成回复")
|
logger.debug("[GeneratorAPI] 开始生成回复")
|
||||||
|
|
||||||
@@ -108,7 +109,8 @@ async def generate_reply(
|
|||||||
enable_timeout=enable_timeout,
|
enable_timeout=enable_timeout,
|
||||||
enable_tool=enable_tool,
|
enable_tool=enable_tool,
|
||||||
)
|
)
|
||||||
|
reply_set = []
|
||||||
|
if content:
|
||||||
reply_set = await process_human_text(content, enable_splitter, enable_chinese_typo)
|
reply_set = await process_human_text(content, enable_splitter, enable_chinese_typo)
|
||||||
|
|
||||||
if success:
|
if success:
|
||||||
@@ -117,19 +119,19 @@ async def generate_reply(
|
|||||||
logger.warning("[GeneratorAPI] 回复生成失败")
|
logger.warning("[GeneratorAPI] 回复生成失败")
|
||||||
|
|
||||||
if return_prompt:
|
if return_prompt:
|
||||||
return success, reply_set or [], prompt
|
return success, reply_set, prompt
|
||||||
else:
|
else:
|
||||||
return success, reply_set or []
|
return success, reply_set, None
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[GeneratorAPI] 生成回复时出错: {e}")
|
logger.error(f"[GeneratorAPI] 生成回复时出错: {e}")
|
||||||
return False, []
|
return False, [], None
|
||||||
|
|
||||||
|
|
||||||
async def rewrite_reply(
|
async def rewrite_reply(
|
||||||
chat_stream=None,
|
chat_stream: Optional[ChatStream] = None,
|
||||||
reply_data: Dict[str, Any] = None,
|
reply_data: Optional[Dict[str, Any]] = None,
|
||||||
chat_id: str = None,
|
chat_id: Optional[str] = None,
|
||||||
enable_splitter: bool = True,
|
enable_splitter: bool = True,
|
||||||
enable_chinese_typo: bool = True,
|
enable_chinese_typo: bool = True,
|
||||||
model_configs: Optional[List[Dict[str, Any]]] = None,
|
model_configs: Optional[List[Dict[str, Any]]] = None,
|
||||||
@@ -157,7 +159,8 @@ async def rewrite_reply(
|
|||||||
|
|
||||||
# 调用回复器重写回复
|
# 调用回复器重写回复
|
||||||
success, content = await replyer.rewrite_reply_with_context(reply_data=reply_data or {})
|
success, content = await replyer.rewrite_reply_with_context(reply_data=reply_data or {})
|
||||||
|
reply_set = []
|
||||||
|
if content:
|
||||||
reply_set = await process_human_text(content, enable_splitter, enable_chinese_typo)
|
reply_set = await process_human_text(content, enable_splitter, enable_chinese_typo)
|
||||||
|
|
||||||
if success:
|
if success:
|
||||||
@@ -165,7 +168,7 @@ async def rewrite_reply(
|
|||||||
else:
|
else:
|
||||||
logger.warning("[GeneratorAPI] 重写回复失败")
|
logger.warning("[GeneratorAPI] 重写回复失败")
|
||||||
|
|
||||||
return success, reply_set or []
|
return success, reply_set
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[GeneratorAPI] 重写回复时出错: {e}")
|
logger.error(f"[GeneratorAPI] 重写回复时出错: {e}")
|
||||||
|
|||||||
@@ -56,7 +56,12 @@ def get_messages_by_time(
|
|||||||
|
|
||||||
|
|
||||||
def get_messages_by_time_in_chat(
|
def get_messages_by_time_in_chat(
|
||||||
chat_id: str, start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False
|
chat_id: str,
|
||||||
|
start_time: float,
|
||||||
|
end_time: float,
|
||||||
|
limit: int = 0,
|
||||||
|
limit_mode: str = "latest",
|
||||||
|
filter_mai: bool = False,
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
获取指定聊天中指定时间范围内的消息
|
获取指定聊天中指定时间范围内的消息
|
||||||
@@ -78,7 +83,12 @@ def get_messages_by_time_in_chat(
|
|||||||
|
|
||||||
|
|
||||||
def get_messages_by_time_in_chat_inclusive(
|
def get_messages_by_time_in_chat_inclusive(
|
||||||
chat_id: str, start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False
|
chat_id: str,
|
||||||
|
start_time: float,
|
||||||
|
end_time: float,
|
||||||
|
limit: int = 0,
|
||||||
|
limit_mode: str = "latest",
|
||||||
|
filter_mai: bool = False,
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
获取指定聊天中指定时间范围内的消息(包含边界)
|
获取指定聊天中指定时间范围内的消息(包含边界)
|
||||||
@@ -95,7 +105,9 @@ def get_messages_by_time_in_chat_inclusive(
|
|||||||
消息列表
|
消息列表
|
||||||
"""
|
"""
|
||||||
if filter_mai:
|
if filter_mai:
|
||||||
return filter_mai_messages(get_raw_msg_by_timestamp_with_chat_inclusive(chat_id, start_time, end_time, limit, limit_mode))
|
return filter_mai_messages(
|
||||||
|
get_raw_msg_by_timestamp_with_chat_inclusive(chat_id, start_time, end_time, limit, limit_mode)
|
||||||
|
)
|
||||||
return get_raw_msg_by_timestamp_with_chat_inclusive(chat_id, start_time, end_time, limit, limit_mode)
|
return get_raw_msg_by_timestamp_with_chat_inclusive(chat_id, start_time, end_time, limit, limit_mode)
|
||||||
|
|
||||||
|
|
||||||
@@ -181,7 +193,9 @@ def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool
|
|||||||
return get_raw_msg_before_timestamp(timestamp, limit)
|
return get_raw_msg_before_timestamp(timestamp, limit)
|
||||||
|
|
||||||
|
|
||||||
def get_messages_before_time_in_chat(chat_id: str, timestamp: float, limit: int = 0, filter_mai: bool = False) -> List[Dict[str, Any]]:
|
def get_messages_before_time_in_chat(
|
||||||
|
chat_id: str, timestamp: float, limit: int = 0, filter_mai: bool = False
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
获取指定聊天中指定时间戳之前的消息
|
获取指定聊天中指定时间戳之前的消息
|
||||||
|
|
||||||
@@ -342,10 +356,12 @@ async def get_person_ids_from_messages(messages: List[Dict[str, Any]]) -> List[s
|
|||||||
"""
|
"""
|
||||||
return await get_person_id_list(messages)
|
return await get_person_id_list(messages)
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# 消息过滤函数
|
# 消息过滤函数
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
def filter_mai_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
def filter_mai_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
从消息列表中移除麦麦的消息
|
从消息列表中移除麦麦的消息
|
||||||
|
|||||||
29
src/plugin_system/apis/plugin_register_api.py
Normal file
29
src/plugin_system/apis/plugin_register_api.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger("plugin_register")
|
||||||
|
|
||||||
|
|
||||||
|
def register_plugin(cls):
|
||||||
|
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||||
|
from src.plugin_system.base.base_plugin import BasePlugin
|
||||||
|
|
||||||
|
"""插件注册装饰器
|
||||||
|
|
||||||
|
用法:
|
||||||
|
@register_plugin
|
||||||
|
class MyPlugin(BasePlugin):
|
||||||
|
plugin_name = "my_plugin"
|
||||||
|
plugin_description = "我的插件"
|
||||||
|
...
|
||||||
|
"""
|
||||||
|
if not issubclass(cls, BasePlugin):
|
||||||
|
logger.error(f"类 {cls.__name__} 不是 BasePlugin 的子类")
|
||||||
|
return cls
|
||||||
|
|
||||||
|
# 只是注册插件类,不立即实例化
|
||||||
|
# 插件管理器会负责实例化和注册
|
||||||
|
plugin_name = cls.plugin_name or cls.__name__
|
||||||
|
plugin_manager.plugin_classes[plugin_name] = cls
|
||||||
|
logger.debug(f"插件类已注册: {plugin_name}")
|
||||||
|
|
||||||
|
return cls
|
||||||
@@ -4,10 +4,10 @@
|
|||||||
提供插件开发的基础类和类型定义
|
提供插件开发的基础类和类型定义
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from src.plugin_system.base.base_plugin import BasePlugin, register_plugin
|
from .base_plugin import BasePlugin
|
||||||
from src.plugin_system.base.base_action import BaseAction
|
from .base_action import BaseAction
|
||||||
from src.plugin_system.base.base_command import BaseCommand
|
from .base_command import BaseCommand
|
||||||
from src.plugin_system.base.component_types import (
|
from .component_types import (
|
||||||
ComponentType,
|
ComponentType,
|
||||||
ActionActivationType,
|
ActionActivationType,
|
||||||
ChatMode,
|
ChatMode,
|
||||||
@@ -15,13 +15,14 @@ from src.plugin_system.base.component_types import (
|
|||||||
ActionInfo,
|
ActionInfo,
|
||||||
CommandInfo,
|
CommandInfo,
|
||||||
PluginInfo,
|
PluginInfo,
|
||||||
|
PythonDependency,
|
||||||
)
|
)
|
||||||
|
from .config_types import ConfigField
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BasePlugin",
|
"BasePlugin",
|
||||||
"BaseAction",
|
"BaseAction",
|
||||||
"BaseCommand",
|
"BaseCommand",
|
||||||
"register_plugin",
|
|
||||||
"ComponentType",
|
"ComponentType",
|
||||||
"ActionActivationType",
|
"ActionActivationType",
|
||||||
"ChatMode",
|
"ChatMode",
|
||||||
@@ -29,4 +30,6 @@ __all__ = [
|
|||||||
"ActionInfo",
|
"ActionInfo",
|
||||||
"CommandInfo",
|
"CommandInfo",
|
||||||
"PluginInfo",
|
"PluginInfo",
|
||||||
|
"PythonDependency",
|
||||||
|
"ConfigField",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Tuple, Optional
|
from typing import Tuple, Optional
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
from src.chat.message_receive.chat_stream import ChatStream
|
||||||
from src.plugin_system.base.component_types import ActionActivationType, ChatMode, ActionInfo, ComponentType
|
from src.plugin_system.base.component_types import ActionActivationType, ChatMode, ActionInfo, ComponentType
|
||||||
from src.plugin_system.apis import send_api, database_api, message_api
|
from src.plugin_system.apis import send_api, database_api, message_api
|
||||||
import time
|
import time
|
||||||
@@ -31,9 +32,9 @@ class BaseAction(ABC):
|
|||||||
reasoning: str,
|
reasoning: str,
|
||||||
cycle_timers: dict,
|
cycle_timers: dict,
|
||||||
thinking_id: str,
|
thinking_id: str,
|
||||||
chat_stream=None,
|
chat_stream: ChatStream,
|
||||||
log_prefix: str = "",
|
log_prefix: str = "",
|
||||||
plugin_config: dict = None,
|
plugin_config: Optional[dict] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""初始化Action组件
|
"""初始化Action组件
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ class BaseCommand(ABC):
|
|||||||
command_examples: List[str] = []
|
command_examples: List[str] = []
|
||||||
intercept_message: bool = True # 默认拦截消息,不继续处理
|
intercept_message: bool = True # 默认拦截消息,不继续处理
|
||||||
|
|
||||||
def __init__(self, message: MessageRecv, plugin_config: dict = None):
|
def __init__(self, message: MessageRecv, plugin_config: Optional[dict] = None):
|
||||||
"""初始化Command组件
|
"""初始化Command组件
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -1,9 +1,12 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Dict, List, Type, Optional, Any, Union
|
from typing import Dict, List, Type, Any, Union
|
||||||
import os
|
import os
|
||||||
import inspect
|
import inspect
|
||||||
import toml
|
import toml
|
||||||
import json
|
import json
|
||||||
|
import shutil
|
||||||
|
import datetime
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.plugin_system.base.component_types import (
|
from src.plugin_system.base.component_types import (
|
||||||
PluginInfo,
|
PluginInfo,
|
||||||
@@ -11,13 +14,10 @@ from src.plugin_system.base.component_types import (
|
|||||||
PythonDependency,
|
PythonDependency,
|
||||||
)
|
)
|
||||||
from src.plugin_system.base.config_types import ConfigField
|
from src.plugin_system.base.config_types import ConfigField
|
||||||
from src.plugin_system.core.component_registry import component_registry
|
from src.plugin_system.utils.manifest_utils import ManifestValidator
|
||||||
|
|
||||||
logger = get_logger("base_plugin")
|
logger = get_logger("base_plugin")
|
||||||
|
|
||||||
# 全局插件类注册表
|
|
||||||
_plugin_classes: Dict[str, Type["BasePlugin"]] = {}
|
|
||||||
|
|
||||||
|
|
||||||
class BasePlugin(ABC):
|
class BasePlugin(ABC):
|
||||||
"""插件基类
|
"""插件基类
|
||||||
@@ -29,21 +29,44 @@ class BasePlugin(ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# 插件基本信息(子类必须定义)
|
# 插件基本信息(子类必须定义)
|
||||||
plugin_name: str = "" # 插件内部标识符(如 "doubao_pic_plugin")
|
@property
|
||||||
enable_plugin: bool = False # 是否启用插件
|
@abstractmethod
|
||||||
dependencies: List[str] = [] # 依赖的其他插件
|
def plugin_name(self) -> str:
|
||||||
python_dependencies: List[PythonDependency] = [] # Python包依赖
|
return "" # 插件内部标识符(如 "hello_world_plugin")
|
||||||
config_file_name: Optional[str] = None # 配置文件名
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def enable_plugin(self) -> bool:
|
||||||
|
return True # 是否启用插件
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def dependencies(self) -> List[str]:
|
||||||
|
return [] # 依赖的其他插件
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def python_dependencies(self) -> List[PythonDependency]:
|
||||||
|
return [] # Python包依赖
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def config_file_name(self) -> str:
|
||||||
|
return "" # 配置文件名
|
||||||
|
|
||||||
# manifest文件相关
|
# manifest文件相关
|
||||||
manifest_file_name: str = "_manifest.json" # manifest文件名
|
manifest_file_name: str = "_manifest.json" # manifest文件名
|
||||||
manifest_data: Dict[str, Any] = {} # manifest数据
|
manifest_data: Dict[str, Any] = {} # manifest数据
|
||||||
|
|
||||||
# 配置定义
|
# 配置定义
|
||||||
config_schema: Dict[str, Union[Dict[str, ConfigField], str]] = {}
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def config_schema(self) -> Dict[str, Union[Dict[str, ConfigField], str]]:
|
||||||
|
return {}
|
||||||
|
|
||||||
config_section_descriptions: Dict[str, str] = {}
|
config_section_descriptions: Dict[str, str] = {}
|
||||||
|
|
||||||
def __init__(self, plugin_dir: str = None):
|
def __init__(self, plugin_dir: str):
|
||||||
"""初始化插件
|
"""初始化插件
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -70,7 +93,8 @@ class BasePlugin(ABC):
|
|||||||
|
|
||||||
# 创建插件信息对象
|
# 创建插件信息对象
|
||||||
self.plugin_info = PluginInfo(
|
self.plugin_info = PluginInfo(
|
||||||
name=self.display_name, # 使用显示名称
|
name=self.plugin_name,
|
||||||
|
display_name=self.display_name,
|
||||||
description=self.plugin_description,
|
description=self.plugin_description,
|
||||||
version=self.plugin_version,
|
version=self.plugin_version,
|
||||||
author=self.plugin_author,
|
author=self.plugin_author,
|
||||||
@@ -103,7 +127,7 @@ class BasePlugin(ABC):
|
|||||||
if not self.get_manifest_info("description"):
|
if not self.get_manifest_info("description"):
|
||||||
raise ValueError(f"插件 {self.plugin_name} 的manifest中缺少description字段")
|
raise ValueError(f"插件 {self.plugin_name} 的manifest中缺少description字段")
|
||||||
|
|
||||||
def _load_manifest(self):
|
def _load_manifest(self): # sourcery skip: raise-from-previous-error
|
||||||
"""加载manifest文件(强制要求)"""
|
"""加载manifest文件(强制要求)"""
|
||||||
if not self.plugin_dir:
|
if not self.plugin_dir:
|
||||||
raise ValueError(f"{self.log_prefix} 没有插件目录路径,无法加载manifest")
|
raise ValueError(f"{self.log_prefix} 没有插件目录路径,无法加载manifest")
|
||||||
@@ -124,9 +148,6 @@ class BasePlugin(ABC):
|
|||||||
# 验证manifest格式
|
# 验证manifest格式
|
||||||
self._validate_manifest()
|
self._validate_manifest()
|
||||||
|
|
||||||
# 从manifest覆盖插件基本信息(如果插件类中未定义)
|
|
||||||
self._apply_manifest_overrides()
|
|
||||||
|
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
error_msg = f"{self.log_prefix} manifest文件格式错误: {e}"
|
error_msg = f"{self.log_prefix} manifest文件格式错误: {e}"
|
||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
@@ -136,15 +157,6 @@ class BasePlugin(ABC):
|
|||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
raise IOError(error_msg) # noqa
|
raise IOError(error_msg) # noqa
|
||||||
|
|
||||||
def _apply_manifest_overrides(self):
|
|
||||||
"""从manifest文件覆盖插件信息(现在只处理内部标识符的fallback)"""
|
|
||||||
if not self.manifest_data:
|
|
||||||
return
|
|
||||||
|
|
||||||
# 只有当插件类中没有定义plugin_name时,才从manifest中获取作为fallback
|
|
||||||
if not self.plugin_name:
|
|
||||||
self.plugin_name = self.manifest_data.get("name", "").replace(" ", "_").lower()
|
|
||||||
|
|
||||||
def _get_author_name(self) -> str:
|
def _get_author_name(self) -> str:
|
||||||
"""从manifest获取作者名称"""
|
"""从manifest获取作者名称"""
|
||||||
author_info = self.get_manifest_info("author", {})
|
author_info = self.get_manifest_info("author", {})
|
||||||
@@ -156,10 +168,7 @@ class BasePlugin(ABC):
|
|||||||
def _validate_manifest(self):
|
def _validate_manifest(self):
|
||||||
"""验证manifest文件格式(使用强化的验证器)"""
|
"""验证manifest文件格式(使用强化的验证器)"""
|
||||||
if not self.manifest_data:
|
if not self.manifest_data:
|
||||||
return
|
raise ValueError(f"{self.log_prefix} manifest数据为空,验证失败")
|
||||||
|
|
||||||
# 导入验证器
|
|
||||||
from src.plugin_system.utils.manifest_utils import ManifestValidator
|
|
||||||
|
|
||||||
validator = ManifestValidator()
|
validator = ManifestValidator()
|
||||||
is_valid = validator.validate_manifest(self.manifest_data)
|
is_valid = validator.validate_manifest(self.manifest_data)
|
||||||
@@ -176,36 +185,6 @@ class BasePlugin(ABC):
|
|||||||
error_msg += f": {'; '.join(validator.validation_errors)}"
|
error_msg += f": {'; '.join(validator.validation_errors)}"
|
||||||
raise ValueError(error_msg)
|
raise ValueError(error_msg)
|
||||||
|
|
||||||
def _generate_default_manifest(self, manifest_path: str):
|
|
||||||
"""生成默认的manifest文件"""
|
|
||||||
if not self.plugin_name:
|
|
||||||
logger.debug(f"{self.log_prefix} 插件名称未定义,无法生成默认manifest")
|
|
||||||
return
|
|
||||||
|
|
||||||
# 从plugin_name生成友好的显示名称
|
|
||||||
display_name = self.plugin_name.replace("_", " ").title()
|
|
||||||
|
|
||||||
default_manifest = {
|
|
||||||
"manifest_version": 1,
|
|
||||||
"name": display_name,
|
|
||||||
"version": "1.0.0",
|
|
||||||
"description": "插件描述",
|
|
||||||
"author": {"name": "Unknown", "url": ""},
|
|
||||||
"license": "MIT",
|
|
||||||
"host_application": {"min_version": "1.0.0", "max_version": "4.0.0"},
|
|
||||||
"keywords": [],
|
|
||||||
"categories": [],
|
|
||||||
"default_locale": "zh-CN",
|
|
||||||
"locales_path": "_locales",
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(manifest_path, "w", encoding="utf-8") as f:
|
|
||||||
json.dump(default_manifest, f, ensure_ascii=False, indent=2)
|
|
||||||
logger.info(f"{self.log_prefix} 已生成默认manifest文件: {manifest_path}")
|
|
||||||
except IOError as e:
|
|
||||||
logger.error(f"{self.log_prefix} 保存默认manifest文件失败: {e}")
|
|
||||||
|
|
||||||
def get_manifest_info(self, key: str, default: Any = None) -> Any:
|
def get_manifest_info(self, key: str, default: Any = None) -> Any:
|
||||||
"""获取manifest信息
|
"""获取manifest信息
|
||||||
|
|
||||||
@@ -304,9 +283,6 @@ class BasePlugin(ABC):
|
|||||||
|
|
||||||
def _backup_config_file(self, config_file_path: str) -> str:
|
def _backup_config_file(self, config_file_path: str) -> str:
|
||||||
"""备份配置文件"""
|
"""备份配置文件"""
|
||||||
import shutil
|
|
||||||
import datetime
|
|
||||||
|
|
||||||
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
backup_path = f"{config_file_path}.backup_{timestamp}"
|
backup_path = f"{config_file_path}.backup_{timestamp}"
|
||||||
|
|
||||||
@@ -377,13 +353,14 @@ class BasePlugin(ABC):
|
|||||||
logger.warning(f"{self.log_prefix} 配置节 {section_name} 结构已改变,使用默认值")
|
logger.warning(f"{self.log_prefix} 配置节 {section_name} 结构已改变,使用默认值")
|
||||||
|
|
||||||
# 检查旧配置中是否有新配置没有的节
|
# 检查旧配置中是否有新配置没有的节
|
||||||
for section_name in old_config.keys():
|
for section_name in old_config:
|
||||||
if section_name not in migrated_config:
|
if section_name not in migrated_config:
|
||||||
logger.warning(f"{self.log_prefix} 配置节 {section_name} 在新版本中已被移除")
|
logger.warning(f"{self.log_prefix} 配置节 {section_name} 在新版本中已被移除")
|
||||||
|
|
||||||
return migrated_config
|
return migrated_config
|
||||||
|
|
||||||
def _generate_config_from_schema(self) -> Dict[str, Any]:
|
def _generate_config_from_schema(self) -> Dict[str, Any]:
|
||||||
|
# sourcery skip: dict-comprehension
|
||||||
"""根据schema生成配置数据结构(不写入文件)"""
|
"""根据schema生成配置数据结构(不写入文件)"""
|
||||||
if not self.config_schema:
|
if not self.config_schema:
|
||||||
return {}
|
return {}
|
||||||
@@ -473,7 +450,7 @@ class BasePlugin(ABC):
|
|||||||
except IOError as e:
|
except IOError as e:
|
||||||
logger.error(f"{self.log_prefix} 保存配置文件失败: {e}", exc_info=True)
|
logger.error(f"{self.log_prefix} 保存配置文件失败: {e}", exc_info=True)
|
||||||
|
|
||||||
def _load_plugin_config(self):
|
def _load_plugin_config(self): # sourcery skip: extract-method
|
||||||
"""加载插件配置文件,支持版本检查和自动迁移"""
|
"""加载插件配置文件,支持版本检查和自动迁移"""
|
||||||
if not self.config_file_name:
|
if not self.config_file_name:
|
||||||
logger.debug(f"{self.log_prefix} 未指定配置文件,跳过加载")
|
logger.debug(f"{self.log_prefix} 未指定配置文件,跳过加载")
|
||||||
@@ -549,7 +526,7 @@ class BasePlugin(ABC):
|
|||||||
|
|
||||||
# 从配置中更新 enable_plugin
|
# 从配置中更新 enable_plugin
|
||||||
if "plugin" in self.config and "enabled" in self.config["plugin"]:
|
if "plugin" in self.config and "enabled" in self.config["plugin"]:
|
||||||
self.enable_plugin = self.config["plugin"]["enabled"]
|
self.enable_plugin = self.config["plugin"]["enabled"] # type: ignore
|
||||||
logger.debug(f"{self.log_prefix} 从配置更新插件启用状态: {self.enable_plugin}")
|
logger.debug(f"{self.log_prefix} 从配置更新插件启用状态: {self.enable_plugin}")
|
||||||
else:
|
else:
|
||||||
logger.warning(f"{self.log_prefix} 不支持的配置文件格式: {file_ext},仅支持 .toml")
|
logger.warning(f"{self.log_prefix} 不支持的配置文件格式: {file_ext},仅支持 .toml")
|
||||||
@@ -568,9 +545,7 @@ class BasePlugin(ABC):
|
|||||||
|
|
||||||
def register_plugin(self) -> bool:
|
def register_plugin(self) -> bool:
|
||||||
"""注册插件及其所有组件"""
|
"""注册插件及其所有组件"""
|
||||||
if not self.enable_plugin:
|
from src.plugin_system.core.component_registry import component_registry
|
||||||
logger.info(f"{self.log_prefix} 插件已禁用,跳过注册")
|
|
||||||
return False
|
|
||||||
|
|
||||||
components = self.get_plugin_components()
|
components = self.get_plugin_components()
|
||||||
|
|
||||||
@@ -601,6 +576,8 @@ class BasePlugin(ABC):
|
|||||||
|
|
||||||
def _check_dependencies(self) -> bool:
|
def _check_dependencies(self) -> bool:
|
||||||
"""检查插件依赖"""
|
"""检查插件依赖"""
|
||||||
|
from src.plugin_system.core.component_registry import component_registry
|
||||||
|
|
||||||
if not self.dependencies:
|
if not self.dependencies:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -632,52 +609,3 @@ class BasePlugin(ABC):
|
|||||||
return default
|
return default
|
||||||
|
|
||||||
return current
|
return current
|
||||||
|
|
||||||
|
|
||||||
def register_plugin(cls):
|
|
||||||
"""插件注册装饰器
|
|
||||||
|
|
||||||
用法:
|
|
||||||
@register_plugin
|
|
||||||
class MyPlugin(BasePlugin):
|
|
||||||
plugin_name = "my_plugin"
|
|
||||||
plugin_description = "我的插件"
|
|
||||||
...
|
|
||||||
"""
|
|
||||||
if not issubclass(cls, BasePlugin):
|
|
||||||
logger.error(f"类 {cls.__name__} 不是 BasePlugin 的子类")
|
|
||||||
return cls
|
|
||||||
|
|
||||||
# 只是注册插件类,不立即实例化
|
|
||||||
# 插件管理器会负责实例化和注册
|
|
||||||
plugin_name = cls.plugin_name or cls.__name__
|
|
||||||
_plugin_classes[plugin_name] = cls
|
|
||||||
logger.debug(f"插件类已注册: {plugin_name}")
|
|
||||||
|
|
||||||
return cls
|
|
||||||
|
|
||||||
|
|
||||||
def get_registered_plugin_classes() -> Dict[str, Type["BasePlugin"]]:
|
|
||||||
"""获取所有已注册的插件类"""
|
|
||||||
return _plugin_classes.copy()
|
|
||||||
|
|
||||||
|
|
||||||
def instantiate_and_register_plugin(plugin_class: Type["BasePlugin"], plugin_dir: str = None) -> bool:
|
|
||||||
"""实例化并注册插件
|
|
||||||
|
|
||||||
Args:
|
|
||||||
plugin_class: 插件类
|
|
||||||
plugin_dir: 插件目录路径
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否成功
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
plugin_instance = plugin_class(plugin_dir=plugin_dir)
|
|
||||||
return plugin_instance.register_plugin()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"注册插件 {plugin_class.__name__} 时出错: {e}")
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
return False
|
|
||||||
|
|||||||
@@ -23,6 +23,9 @@ class ActionActivationType(Enum):
|
|||||||
RANDOM = "random" # 随机启用action到planner
|
RANDOM = "random" # 随机启用action到planner
|
||||||
KEYWORD = "keyword" # 关键词触发启用action到planner
|
KEYWORD = "keyword" # 关键词触发启用action到planner
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.value
|
||||||
|
|
||||||
|
|
||||||
# 聊天模式枚举
|
# 聊天模式枚举
|
||||||
class ChatMode(Enum):
|
class ChatMode(Enum):
|
||||||
@@ -32,6 +35,9 @@ class ChatMode(Enum):
|
|||||||
NORMAL = "normal" # Normal聊天模式
|
NORMAL = "normal" # Normal聊天模式
|
||||||
ALL = "all" # 所有聊天模式
|
ALL = "all" # 所有聊天模式
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.value
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PythonDependency:
|
class PythonDependency:
|
||||||
@@ -60,7 +66,7 @@ class ComponentInfo:
|
|||||||
|
|
||||||
name: str # 组件名称
|
name: str # 组件名称
|
||||||
component_type: ComponentType # 组件类型
|
component_type: ComponentType # 组件类型
|
||||||
description: str # 组件描述
|
description: str = "" # 组件描述
|
||||||
enabled: bool = True # 是否启用
|
enabled: bool = True # 是否启用
|
||||||
plugin_name: str = "" # 所属插件名称
|
plugin_name: str = "" # 所属插件名称
|
||||||
is_built_in: bool = False # 是否为内置组件
|
is_built_in: bool = False # 是否为内置组件
|
||||||
@@ -75,17 +81,21 @@ class ComponentInfo:
|
|||||||
class ActionInfo(ComponentInfo):
|
class ActionInfo(ComponentInfo):
|
||||||
"""动作组件信息"""
|
"""动作组件信息"""
|
||||||
|
|
||||||
|
action_parameters: Dict[str, str] = field(
|
||||||
|
default_factory=dict
|
||||||
|
) # 动作参数与描述,例如 {"param1": "描述1", "param2": "描述2"}
|
||||||
|
action_require: List[str] = field(default_factory=list) # 动作需求说明
|
||||||
|
associated_types: List[str] = field(default_factory=list) # 关联的消息类型
|
||||||
|
# 激活类型相关
|
||||||
focus_activation_type: ActionActivationType = ActionActivationType.ALWAYS
|
focus_activation_type: ActionActivationType = ActionActivationType.ALWAYS
|
||||||
normal_activation_type: ActionActivationType = ActionActivationType.ALWAYS
|
normal_activation_type: ActionActivationType = ActionActivationType.ALWAYS
|
||||||
random_activation_probability: float = 0.0
|
random_activation_probability: float = 0.0
|
||||||
llm_judge_prompt: str = ""
|
llm_judge_prompt: str = ""
|
||||||
activation_keywords: List[str] = field(default_factory=list) # 激活关键词列表
|
activation_keywords: List[str] = field(default_factory=list) # 激活关键词列表
|
||||||
keyword_case_sensitive: bool = False
|
keyword_case_sensitive: bool = False
|
||||||
|
# 模式和并行设置
|
||||||
mode_enable: ChatMode = ChatMode.ALL
|
mode_enable: ChatMode = ChatMode.ALL
|
||||||
parallel_action: bool = False
|
parallel_action: bool = False
|
||||||
action_parameters: Dict[str, Any] = field(default_factory=dict) # 动作参数
|
|
||||||
action_require: List[str] = field(default_factory=list) # 动作需求说明
|
|
||||||
associated_types: List[str] = field(default_factory=list) # 关联的消息类型
|
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
@@ -120,6 +130,7 @@ class CommandInfo(ComponentInfo):
|
|||||||
class PluginInfo:
|
class PluginInfo:
|
||||||
"""插件信息"""
|
"""插件信息"""
|
||||||
|
|
||||||
|
display_name: str # 插件显示名称
|
||||||
name: str # 插件名称
|
name: str # 插件名称
|
||||||
description: str # 插件描述
|
description: str # 插件描述
|
||||||
version: str = "1.0.0" # 插件版本
|
version: str = "1.0.0" # 插件版本
|
||||||
|
|||||||
@@ -6,8 +6,10 @@
|
|||||||
|
|
||||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||||
from src.plugin_system.core.component_registry import component_registry
|
from src.plugin_system.core.component_registry import component_registry
|
||||||
|
from src.plugin_system.core.dependency_manager import dependency_manager
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"plugin_manager",
|
"plugin_manager",
|
||||||
"component_registry",
|
"component_registry",
|
||||||
|
"dependency_manager",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Dict, List, Optional, Any, Pattern, Union
|
from typing import Dict, List, Optional, Any, Pattern, Tuple, Union, Type
|
||||||
import re
|
import re
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.plugin_system.base.component_types import (
|
from src.plugin_system.base.component_types import (
|
||||||
@@ -9,8 +9,8 @@ from src.plugin_system.base.component_types import (
|
|||||||
ComponentType,
|
ComponentType,
|
||||||
)
|
)
|
||||||
|
|
||||||
from ..base.base_command import BaseCommand
|
from src.plugin_system.base.base_command import BaseCommand
|
||||||
from ..base.base_action import BaseAction
|
from src.plugin_system.base.base_action import BaseAction
|
||||||
|
|
||||||
logger = get_logger("component_registry")
|
logger = get_logger("component_registry")
|
||||||
|
|
||||||
@@ -28,25 +28,25 @@ class ComponentRegistry:
|
|||||||
ComponentType.ACTION: {},
|
ComponentType.ACTION: {},
|
||||||
ComponentType.COMMAND: {},
|
ComponentType.COMMAND: {},
|
||||||
}
|
}
|
||||||
self._component_classes: Dict[str, Union[BaseCommand, BaseAction]] = {} # 组件名 -> 组件类
|
self._component_classes: Dict[str, Union[Type[BaseCommand], Type[BaseAction]]] = {} # 组件名 -> 组件类
|
||||||
|
|
||||||
# 插件注册表
|
# 插件注册表
|
||||||
self._plugins: Dict[str, PluginInfo] = {} # 插件名 -> 插件信息
|
self._plugins: Dict[str, PluginInfo] = {} # 插件名 -> 插件信息
|
||||||
|
|
||||||
# Action特定注册表
|
# Action特定注册表
|
||||||
self._action_registry: Dict[str, BaseAction] = {} # action名 -> action类
|
self._action_registry: Dict[str, Type[BaseAction]] = {} # action名 -> action类
|
||||||
self._default_actions: Dict[str, str] = {} # 启用的action名 -> 描述
|
self._default_actions: Dict[str, ActionInfo] = {} # 默认动作集,即启用的Action集,用于重置ActionManager状态
|
||||||
|
|
||||||
# Command特定注册表
|
# Command特定注册表
|
||||||
self._command_registry: Dict[str, BaseCommand] = {} # command名 -> command类
|
self._command_registry: Dict[str, Type[BaseCommand]] = {} # command名 -> command类
|
||||||
self._command_patterns: Dict[Pattern, BaseCommand] = {} # 编译后的正则 -> command类
|
self._command_patterns: Dict[Pattern, Type[BaseCommand]] = {} # 编译后的正则 -> command类
|
||||||
|
|
||||||
logger.info("组件注册中心初始化完成")
|
logger.info("组件注册中心初始化完成")
|
||||||
|
|
||||||
# === 通用组件注册方法 ===
|
# === 通用组件注册方法 ===
|
||||||
|
|
||||||
def register_component(
|
def register_component(
|
||||||
self, component_info: ComponentInfo, component_class: Union[BaseCommand, BaseAction]
|
self, component_info: ComponentInfo, component_class: Union[Type[BaseCommand], Type[BaseAction]]
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""注册组件
|
"""注册组件
|
||||||
|
|
||||||
@@ -88,9 +88,9 @@ class ComponentRegistry:
|
|||||||
|
|
||||||
# 根据组件类型进行特定注册(使用原始名称)
|
# 根据组件类型进行特定注册(使用原始名称)
|
||||||
if component_type == ComponentType.ACTION:
|
if component_type == ComponentType.ACTION:
|
||||||
self._register_action_component(component_info, component_class)
|
self._register_action_component(component_info, component_class) # type: ignore
|
||||||
elif component_type == ComponentType.COMMAND:
|
elif component_type == ComponentType.COMMAND:
|
||||||
self._register_command_component(component_info, component_class)
|
self._register_command_component(component_info, component_class) # type: ignore
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"已注册{component_type.value}组件: '{component_name}' -> '{namespaced_name}' "
|
f"已注册{component_type.value}组件: '{component_name}' -> '{namespaced_name}' "
|
||||||
@@ -98,16 +98,18 @@ class ComponentRegistry:
|
|||||||
)
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _register_action_component(self, action_info: ActionInfo, action_class: BaseAction):
|
def _register_action_component(self, action_info: ActionInfo, action_class: Type[BaseAction]):
|
||||||
|
# -------------------------------- NEED REFACTORING --------------------------------
|
||||||
|
# -------------------------------- LOGIC ERROR -------------------------------------
|
||||||
"""注册Action组件到Action特定注册表"""
|
"""注册Action组件到Action特定注册表"""
|
||||||
action_name = action_info.name
|
action_name = action_info.name
|
||||||
self._action_registry[action_name] = action_class
|
self._action_registry[action_name] = action_class
|
||||||
|
|
||||||
# 如果启用,添加到默认动作集
|
# 如果启用,添加到默认动作集
|
||||||
if action_info.enabled:
|
if action_info.enabled:
|
||||||
self._default_actions[action_name] = action_info.description
|
self._default_actions[action_name] = action_info
|
||||||
|
|
||||||
def _register_command_component(self, command_info: CommandInfo, command_class: BaseCommand):
|
def _register_command_component(self, command_info: CommandInfo, command_class: Type[BaseCommand]):
|
||||||
"""注册Command组件到Command特定注册表"""
|
"""注册Command组件到Command特定注册表"""
|
||||||
command_name = command_info.name
|
command_name = command_info.name
|
||||||
self._command_registry[command_name] = command_class
|
self._command_registry[command_name] = command_class
|
||||||
@@ -119,7 +121,7 @@ class ComponentRegistry:
|
|||||||
|
|
||||||
# === 组件查询方法 ===
|
# === 组件查询方法 ===
|
||||||
|
|
||||||
def get_component_info(self, component_name: str, component_type: ComponentType = None) -> Optional[ComponentInfo]:
|
def get_component_info(self, component_name: str, component_type: ComponentType = None) -> Optional[ComponentInfo]: # type: ignore
|
||||||
# sourcery skip: class-extract-method
|
# sourcery skip: class-extract-method
|
||||||
"""获取组件信息,支持自动命名空间解析
|
"""获取组件信息,支持自动命名空间解析
|
||||||
|
|
||||||
@@ -167,8 +169,10 @@ class ComponentRegistry:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def get_component_class(
|
def get_component_class(
|
||||||
self, component_name: str, component_type: ComponentType = None
|
self,
|
||||||
) -> Optional[Union[BaseCommand, BaseAction]]:
|
component_name: str,
|
||||||
|
component_type: ComponentType = None, # type: ignore
|
||||||
|
) -> Optional[Union[Type[BaseCommand], Type[BaseAction]]]:
|
||||||
"""获取组件类,支持自动命名空间解析
|
"""获取组件类,支持自动命名空间解析
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -227,26 +231,26 @@ class ComponentRegistry:
|
|||||||
|
|
||||||
# === Action特定查询方法 ===
|
# === Action特定查询方法 ===
|
||||||
|
|
||||||
def get_action_registry(self) -> Dict[str, BaseAction]:
|
def get_action_registry(self) -> Dict[str, Type[BaseAction]]:
|
||||||
"""获取Action注册表(用于兼容现有系统)"""
|
"""获取Action注册表(用于兼容现有系统)"""
|
||||||
return self._action_registry.copy()
|
return self._action_registry.copy()
|
||||||
|
|
||||||
def get_default_actions(self) -> Dict[str, str]:
|
|
||||||
"""获取默认启用的Action列表(用于兼容现有系统)"""
|
|
||||||
return self._default_actions.copy()
|
|
||||||
|
|
||||||
def get_action_info(self, action_name: str) -> Optional[ActionInfo]:
|
def get_action_info(self, action_name: str) -> Optional[ActionInfo]:
|
||||||
"""获取Action信息"""
|
"""获取Action信息"""
|
||||||
info = self.get_component_info(action_name, ComponentType.ACTION)
|
info = self.get_component_info(action_name, ComponentType.ACTION)
|
||||||
return info if isinstance(info, ActionInfo) else None
|
return info if isinstance(info, ActionInfo) else None
|
||||||
|
|
||||||
|
def get_default_actions(self) -> Dict[str, ActionInfo]:
|
||||||
|
"""获取默认动作集"""
|
||||||
|
return self._default_actions.copy()
|
||||||
|
|
||||||
# === Command特定查询方法 ===
|
# === Command特定查询方法 ===
|
||||||
|
|
||||||
def get_command_registry(self) -> Dict[str, BaseCommand]:
|
def get_command_registry(self) -> Dict[str, Type[BaseCommand]]:
|
||||||
"""获取Command注册表(用于兼容现有系统)"""
|
"""获取Command注册表(用于兼容现有系统)"""
|
||||||
return self._command_registry.copy()
|
return self._command_registry.copy()
|
||||||
|
|
||||||
def get_command_patterns(self) -> Dict[Pattern, BaseCommand]:
|
def get_command_patterns(self) -> Dict[Pattern, Type[BaseCommand]]:
|
||||||
"""获取Command模式注册表(用于兼容现有系统)"""
|
"""获取Command模式注册表(用于兼容现有系统)"""
|
||||||
return self._command_patterns.copy()
|
return self._command_patterns.copy()
|
||||||
|
|
||||||
@@ -255,7 +259,7 @@ class ComponentRegistry:
|
|||||||
info = self.get_component_info(command_name, ComponentType.COMMAND)
|
info = self.get_component_info(command_name, ComponentType.COMMAND)
|
||||||
return info if isinstance(info, CommandInfo) else None
|
return info if isinstance(info, CommandInfo) else None
|
||||||
|
|
||||||
def find_command_by_text(self, text: str) -> Optional[tuple[BaseCommand, dict, bool, str]]:
|
def find_command_by_text(self, text: str) -> Optional[Tuple[Type[BaseCommand], dict, bool, str]]:
|
||||||
# sourcery skip: use-named-expression, use-next
|
# sourcery skip: use-named-expression, use-next
|
||||||
"""根据文本查找匹配的命令
|
"""根据文本查找匹配的命令
|
||||||
|
|
||||||
@@ -263,7 +267,7 @@ class ComponentRegistry:
|
|||||||
text: 输入文本
|
text: 输入文本
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Optional[tuple[BaseCommand, dict, bool, str]]: (命令类, 匹配的命名组, 是否拦截消息, 插件名) 或 None
|
Tuple: (命令类, 匹配的命名组, 是否拦截消息, 插件名) 或 None
|
||||||
"""
|
"""
|
||||||
|
|
||||||
for pattern, command_class in self._command_patterns.items():
|
for pattern, command_class in self._command_patterns.items():
|
||||||
@@ -343,6 +347,8 @@ class ComponentRegistry:
|
|||||||
# === 状态管理方法 ===
|
# === 状态管理方法 ===
|
||||||
|
|
||||||
def enable_component(self, component_name: str, component_type: ComponentType = None) -> bool:
|
def enable_component(self, component_name: str, component_type: ComponentType = None) -> bool:
|
||||||
|
# -------------------------------- NEED REFACTORING --------------------------------
|
||||||
|
# -------------------------------- LOGIC ERROR -------------------------------------
|
||||||
"""启用组件,支持命名空间解析"""
|
"""启用组件,支持命名空间解析"""
|
||||||
# 首先尝试找到正确的命名空间化名称
|
# 首先尝试找到正确的命名空间化名称
|
||||||
component_info = self.get_component_info(component_name, component_type)
|
component_info = self.get_component_info(component_name, component_type)
|
||||||
@@ -364,13 +370,16 @@ class ComponentRegistry:
|
|||||||
if namespaced_name in self._components:
|
if namespaced_name in self._components:
|
||||||
self._components[namespaced_name].enabled = True
|
self._components[namespaced_name].enabled = True
|
||||||
# 如果是Action,更新默认动作集
|
# 如果是Action,更新默认动作集
|
||||||
if isinstance(component_info, ActionInfo):
|
# ---- HERE ----
|
||||||
self._default_actions[component_name] = component_info.description
|
# if isinstance(component_info, ActionInfo):
|
||||||
|
# self._action_descriptions[component_name] = component_info.description
|
||||||
logger.debug(f"已启用组件: {component_name} -> {namespaced_name}")
|
logger.debug(f"已启用组件: {component_name} -> {namespaced_name}")
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def disable_component(self, component_name: str, component_type: ComponentType = None) -> bool:
|
def disable_component(self, component_name: str, component_type: ComponentType = None) -> bool:
|
||||||
|
# -------------------------------- NEED REFACTORING --------------------------------
|
||||||
|
# -------------------------------- LOGIC ERROR -------------------------------------
|
||||||
"""禁用组件,支持命名空间解析"""
|
"""禁用组件,支持命名空间解析"""
|
||||||
# 首先尝试找到正确的命名空间化名称
|
# 首先尝试找到正确的命名空间化名称
|
||||||
component_info = self.get_component_info(component_name, component_type)
|
component_info = self.get_component_info(component_name, component_type)
|
||||||
@@ -392,8 +401,9 @@ class ComponentRegistry:
|
|||||||
if namespaced_name in self._components:
|
if namespaced_name in self._components:
|
||||||
self._components[namespaced_name].enabled = False
|
self._components[namespaced_name].enabled = False
|
||||||
# 如果是Action,从默认动作集中移除
|
# 如果是Action,从默认动作集中移除
|
||||||
if component_name in self._default_actions:
|
# ---- HERE ----
|
||||||
del self._default_actions[component_name]
|
# if component_name in self._action_descriptions:
|
||||||
|
# del self._action_descriptions[component_name]
|
||||||
logger.debug(f"已禁用组件: {component_name} -> {namespaced_name}")
|
logger.debug(f"已禁用组件: {component_name} -> {namespaced_name}")
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -37,16 +37,14 @@ class DependencyManager:
|
|||||||
missing_optional = []
|
missing_optional = []
|
||||||
|
|
||||||
for dep in dependencies:
|
for dep in dependencies:
|
||||||
if not self._is_package_available(dep.package_name):
|
if self._is_package_available(dep.package_name):
|
||||||
if dep.optional:
|
logger.debug(f"依赖包已存在: {dep.package_name}")
|
||||||
|
elif dep.optional:
|
||||||
missing_optional.append(dep)
|
missing_optional.append(dep)
|
||||||
logger.warning(f"可选依赖包缺失: {dep.package_name} - {dep.description}")
|
logger.warning(f"可选依赖包缺失: {dep.package_name} - {dep.description}")
|
||||||
else:
|
else:
|
||||||
missing_required.append(dep)
|
missing_required.append(dep)
|
||||||
logger.error(f"必需依赖包缺失: {dep.package_name} - {dep.description}")
|
logger.error(f"必需依赖包缺失: {dep.package_name} - {dep.description}")
|
||||||
else:
|
|
||||||
logger.debug(f"依赖包已存在: {dep.package_name}")
|
|
||||||
|
|
||||||
return missing_required, missing_optional
|
return missing_required, missing_optional
|
||||||
|
|
||||||
def _is_package_available(self, package_name: str) -> bool:
|
def _is_package_available(self, package_name: str) -> bool:
|
||||||
|
|||||||
@@ -1,64 +1,71 @@
|
|||||||
from typing import Dict, List, Optional, Any, TYPE_CHECKING, Tuple
|
from typing import Dict, List, Optional, Callable, Tuple, Type, Any
|
||||||
import os
|
import os
|
||||||
import importlib
|
from importlib.util import spec_from_file_location, module_from_spec
|
||||||
import importlib.util
|
from inspect import getmodule
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from src.plugin_system.base.base_plugin import BasePlugin
|
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
from src.plugin_system.events.events import EventType
|
||||||
from src.plugin_system.core.component_registry import component_registry
|
from src.plugin_system.core.component_registry import component_registry
|
||||||
from src.plugin_system.core.dependency_manager import dependency_manager
|
from src.plugin_system.core.dependency_manager import dependency_manager
|
||||||
from src.plugin_system.base.component_types import ComponentType, PluginInfo
|
from src.plugin_system.base.base_plugin import BasePlugin
|
||||||
|
from src.plugin_system.base.component_types import ComponentType, PluginInfo, PythonDependency
|
||||||
|
from src.plugin_system.utils.manifest_utils import VersionComparator
|
||||||
|
|
||||||
logger = get_logger("plugin_manager")
|
logger = get_logger("plugin_manager")
|
||||||
|
|
||||||
|
|
||||||
class PluginManager:
|
class PluginManager:
|
||||||
"""插件管理器
|
"""
|
||||||
|
插件管理器类
|
||||||
|
|
||||||
负责加载、初始化和管理所有插件及其组件
|
负责加载,重载和卸载插件,同时管理插件的所有组件
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.plugin_directories: List[str] = []
|
self.plugin_directories: List[str] = [] # 插件根目录列表
|
||||||
self.loaded_plugins: Dict[str, "BasePlugin"] = {}
|
self.plugin_classes: Dict[str, Type[BasePlugin]] = {} # 全局插件类注册表,插件名 -> 插件类
|
||||||
self.failed_plugins: Dict[str, str] = {}
|
self.plugin_paths: Dict[str, str] = {} # 记录插件名到目录路径的映射,插件名 -> 目录路径
|
||||||
self.plugin_paths: Dict[str, str] = {} # 记录插件名到目录路径的映射
|
|
||||||
|
self.loaded_plugins: Dict[str, BasePlugin] = {} # 已加载的插件类实例注册表,插件名 -> 插件类实例
|
||||||
|
self.failed_plugins: Dict[str, str] = {} # 记录加载失败的插件类及其错误信息,插件名 -> 错误信息
|
||||||
|
|
||||||
|
self.events_subscriptions: Dict[EventType, List[Callable]] = {}
|
||||||
|
|
||||||
# 确保插件目录存在
|
# 确保插件目录存在
|
||||||
self._ensure_plugin_directories()
|
self._ensure_plugin_directories()
|
||||||
logger.info("插件管理器初始化完成")
|
logger.info("插件管理器初始化完成")
|
||||||
|
|
||||||
def _ensure_plugin_directories(self):
|
def _ensure_plugin_directories(self) -> None:
|
||||||
"""确保所有插件目录存在,如果不存在则创建"""
|
"""确保所有插件根目录存在,如果不存在则创建"""
|
||||||
default_directories = ["src/plugins/built_in", "plugins"]
|
default_directories = ["src/plugins/built_in", "plugins"]
|
||||||
|
|
||||||
for directory in default_directories:
|
for directory in default_directories:
|
||||||
if not os.path.exists(directory):
|
if not os.path.exists(directory):
|
||||||
os.makedirs(directory, exist_ok=True)
|
os.makedirs(directory, exist_ok=True)
|
||||||
logger.info(f"创建插件目录: {directory}")
|
logger.info(f"创建插件根目录: {directory}")
|
||||||
if directory not in self.plugin_directories:
|
if directory not in self.plugin_directories:
|
||||||
self.plugin_directories.append(directory)
|
self.plugin_directories.append(directory)
|
||||||
logger.debug(f"已添加插件目录: {directory}")
|
logger.debug(f"已添加插件根目录: {directory}")
|
||||||
else:
|
else:
|
||||||
logger.warning(f"插件不可重复加载: {directory}")
|
logger.warning(f"根目录不可重复加载: {directory}")
|
||||||
|
|
||||||
def add_plugin_directory(self, directory: str):
|
def add_plugin_directory(self, directory: str) -> bool:
|
||||||
"""添加插件目录"""
|
"""添加插件目录"""
|
||||||
if os.path.exists(directory):
|
if os.path.exists(directory):
|
||||||
if directory not in self.plugin_directories:
|
if directory not in self.plugin_directories:
|
||||||
self.plugin_directories.append(directory)
|
self.plugin_directories.append(directory)
|
||||||
logger.debug(f"已添加插件目录: {directory}")
|
logger.debug(f"已添加插件目录: {directory}")
|
||||||
|
return True
|
||||||
else:
|
else:
|
||||||
logger.warning(f"插件不可重复加载: {directory}")
|
logger.warning(f"插件不可重复加载: {directory}")
|
||||||
else:
|
else:
|
||||||
logger.warning(f"插件目录不存在: {directory}")
|
logger.warning(f"插件目录不存在: {directory}")
|
||||||
|
return False
|
||||||
|
|
||||||
def load_all_plugins(self) -> tuple[int, int]:
|
def load_all_plugins(self) -> Tuple[int, int]:
|
||||||
"""加载所有插件目录中的插件
|
"""加载所有插件
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple[int, int]: (插件数量, 组件数量)
|
tuple[int, int]: (插件数量, 组件数量)
|
||||||
@@ -76,14 +83,29 @@ class PluginManager:
|
|||||||
|
|
||||||
logger.debug(f"插件模块加载完成 - 成功: {total_loaded_modules}, 失败: {total_failed_modules}")
|
logger.debug(f"插件模块加载完成 - 成功: {total_loaded_modules}, 失败: {total_failed_modules}")
|
||||||
|
|
||||||
# 第二阶段:实例化所有已注册的插件类
|
|
||||||
from src.plugin_system.base.base_plugin import get_registered_plugin_classes
|
|
||||||
|
|
||||||
plugin_classes = get_registered_plugin_classes()
|
|
||||||
total_registered = 0
|
total_registered = 0
|
||||||
total_failed_registration = 0
|
total_failed_registration = 0
|
||||||
|
|
||||||
for plugin_name, plugin_class in plugin_classes.items():
|
for plugin_name in self.plugin_classes.keys():
|
||||||
|
load_status, count = self.load_registered_plugin_classes(plugin_name)
|
||||||
|
if load_status:
|
||||||
|
total_registered += 1
|
||||||
|
else:
|
||||||
|
total_failed_registration += count
|
||||||
|
|
||||||
|
self._show_stats(total_registered, total_failed_registration)
|
||||||
|
|
||||||
|
return total_registered, total_failed_registration
|
||||||
|
|
||||||
|
def load_registered_plugin_classes(self, plugin_name: str) -> Tuple[bool, int]:
|
||||||
|
# sourcery skip: extract-duplicate-method, extract-method
|
||||||
|
"""
|
||||||
|
加载已经注册的插件类
|
||||||
|
"""
|
||||||
|
plugin_class: Type[BasePlugin] = self.plugin_classes.get(plugin_name)
|
||||||
|
if not plugin_class:
|
||||||
|
logger.error(f"插件 {plugin_name} 的插件类未注册或不存在")
|
||||||
|
return False, 1
|
||||||
try:
|
try:
|
||||||
# 使用记录的插件目录路径
|
# 使用记录的插件目录路径
|
||||||
plugin_dir = self.plugin_paths.get(plugin_name)
|
plugin_dir = self.plugin_paths.get(plugin_name)
|
||||||
@@ -92,260 +114,77 @@ class PluginManager:
|
|||||||
if not plugin_dir:
|
if not plugin_dir:
|
||||||
plugin_dir = self._find_plugin_directory(plugin_class)
|
plugin_dir = self._find_plugin_directory(plugin_class)
|
||||||
if plugin_dir:
|
if plugin_dir:
|
||||||
self.plugin_paths[plugin_name] = plugin_dir # 实例化插件(可能因为缺少manifest而失败)
|
self.plugin_paths[plugin_name] = plugin_dir # 更新路径
|
||||||
plugin_instance = plugin_class(plugin_dir=plugin_dir)
|
plugin_instance = plugin_class(plugin_dir=plugin_dir) # 实例化插件(可能因为缺少manifest而失败)
|
||||||
|
|
||||||
# 检查插件是否启用
|
# 检查插件是否启用
|
||||||
if not plugin_instance.enable_plugin:
|
if not plugin_instance.enable_plugin:
|
||||||
logger.info(f"插件 {plugin_name} 已禁用,跳过加载")
|
logger.info(f"插件 {plugin_name} 已禁用,跳过加载")
|
||||||
continue
|
return False, 0
|
||||||
|
|
||||||
# 检查版本兼容性
|
# 检查版本兼容性
|
||||||
is_compatible, compatibility_error = self.check_plugin_version_compatibility(
|
is_compatible, compatibility_error = self._check_plugin_version_compatibility(
|
||||||
plugin_name, plugin_instance.manifest_data
|
plugin_name, plugin_instance.manifest_data
|
||||||
)
|
)
|
||||||
if not is_compatible:
|
if not is_compatible:
|
||||||
total_failed_registration += 1
|
|
||||||
self.failed_plugins[plugin_name] = compatibility_error
|
self.failed_plugins[plugin_name] = compatibility_error
|
||||||
logger.error(f"❌ 插件加载失败: {plugin_name} - {compatibility_error}")
|
logger.error(f"❌ 插件加载失败: {plugin_name} - {compatibility_error}")
|
||||||
continue
|
return False, 1
|
||||||
|
|
||||||
if plugin_instance.register_plugin():
|
if plugin_instance.register_plugin():
|
||||||
total_registered += 1
|
|
||||||
self.loaded_plugins[plugin_name] = plugin_instance
|
self.loaded_plugins[plugin_name] = plugin_instance
|
||||||
|
self._show_plugin_components(plugin_name)
|
||||||
# 📊 显示插件详细信息
|
return True, 1
|
||||||
plugin_info = component_registry.get_plugin_info(plugin_name)
|
|
||||||
if plugin_info:
|
|
||||||
component_types = {}
|
|
||||||
for comp in plugin_info.components:
|
|
||||||
comp_type = comp.component_type.name
|
|
||||||
component_types[comp_type] = component_types.get(comp_type, 0) + 1
|
|
||||||
|
|
||||||
components_str = ", ".join([f"{count}个{ctype}" for ctype, count in component_types.items()])
|
|
||||||
|
|
||||||
# 显示manifest信息
|
|
||||||
manifest_info = ""
|
|
||||||
if plugin_info.license:
|
|
||||||
manifest_info += f" [{plugin_info.license}]"
|
|
||||||
if plugin_info.keywords:
|
|
||||||
manifest_info += f" 关键词: {', '.join(plugin_info.keywords[:3])}" # 只显示前3个关键词
|
|
||||||
if len(plugin_info.keywords) > 3:
|
|
||||||
manifest_info += "..."
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"✅ 插件加载成功: {plugin_name} v{plugin_info.version} ({components_str}){manifest_info} - {plugin_info.description}"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
logger.info(f"✅ 插件加载成功: {plugin_name}")
|
|
||||||
else:
|
|
||||||
total_failed_registration += 1
|
|
||||||
self.failed_plugins[plugin_name] = "插件注册失败"
|
self.failed_plugins[plugin_name] = "插件注册失败"
|
||||||
logger.error(f"❌ 插件注册失败: {plugin_name}")
|
logger.error(f"❌ 插件注册失败: {plugin_name}")
|
||||||
|
return False, 1
|
||||||
|
|
||||||
except FileNotFoundError as e:
|
except FileNotFoundError as e:
|
||||||
# manifest文件缺失
|
# manifest文件缺失
|
||||||
total_failed_registration += 1
|
|
||||||
error_msg = f"缺少manifest文件: {str(e)}"
|
error_msg = f"缺少manifest文件: {str(e)}"
|
||||||
self.failed_plugins[plugin_name] = error_msg
|
self.failed_plugins[plugin_name] = error_msg
|
||||||
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
|
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
|
||||||
|
return False, 1
|
||||||
|
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
# manifest文件格式错误或验证失败
|
# manifest文件格式错误或验证失败
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
total_failed_registration += 1
|
|
||||||
error_msg = f"manifest验证失败: {str(e)}"
|
error_msg = f"manifest验证失败: {str(e)}"
|
||||||
self.failed_plugins[plugin_name] = error_msg
|
self.failed_plugins[plugin_name] = error_msg
|
||||||
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
|
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
|
||||||
|
return False, 1
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# 其他错误
|
# 其他错误
|
||||||
total_failed_registration += 1
|
|
||||||
error_msg = f"未知错误: {str(e)}"
|
error_msg = f"未知错误: {str(e)}"
|
||||||
self.failed_plugins[plugin_name] = error_msg
|
self.failed_plugins[plugin_name] = error_msg
|
||||||
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
|
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
|
||||||
logger.debug("详细错误信息: ", exc_info=True)
|
logger.debug("详细错误信息: ", exc_info=True)
|
||||||
|
return False, 1
|
||||||
|
|
||||||
# 获取组件统计信息
|
def unload_registered_plugin_module(self, plugin_name: str) -> None:
|
||||||
stats = component_registry.get_registry_stats()
|
"""
|
||||||
action_count = stats.get("action_components", 0)
|
卸载插件模块
|
||||||
command_count = stats.get("command_components", 0)
|
"""
|
||||||
total_components = stats.get("total_components", 0)
|
pass
|
||||||
|
|
||||||
# 📋 显示插件加载总览
|
def reload_registered_plugin_module(self, plugin_name: str) -> None:
|
||||||
if total_registered > 0:
|
"""
|
||||||
logger.info("🎉 插件系统加载完成!")
|
重载插件模块
|
||||||
logger.info(
|
"""
|
||||||
f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count})"
|
self.unload_registered_plugin_module(plugin_name)
|
||||||
)
|
self.load_registered_plugin_classes(plugin_name)
|
||||||
|
|
||||||
# 显示详细的插件列表 logger.info("📋 已加载插件详情:")
|
def rescan_plugin_directory(self) -> None:
|
||||||
for plugin_name, _plugin_class in self.loaded_plugins.items():
|
"""
|
||||||
plugin_info = component_registry.get_plugin_info(plugin_name)
|
重新扫描插件根目录
|
||||||
if plugin_info:
|
"""
|
||||||
# 插件基本信息
|
# --------------------------------------- NEED REFACTORING ---------------------------------------
|
||||||
version_info = f"v{plugin_info.version}" if plugin_info.version else ""
|
|
||||||
author_info = f"by {plugin_info.author}" if plugin_info.author else "unknown"
|
|
||||||
license_info = f"[{plugin_info.license}]" if plugin_info.license else ""
|
|
||||||
info_parts = [part for part in [version_info, author_info, license_info] if part]
|
|
||||||
extra_info = f" ({', '.join(info_parts)})" if info_parts else ""
|
|
||||||
|
|
||||||
logger.info(f" 📦 {plugin_name}{extra_info}")
|
|
||||||
|
|
||||||
# Manifest信息
|
|
||||||
if plugin_info.manifest_data:
|
|
||||||
if plugin_info.keywords:
|
|
||||||
logger.info(f" 🏷️ 关键词: {', '.join(plugin_info.keywords)}")
|
|
||||||
if plugin_info.categories:
|
|
||||||
logger.info(f" 📁 分类: {', '.join(plugin_info.categories)}")
|
|
||||||
if plugin_info.homepage_url:
|
|
||||||
logger.info(f" 🌐 主页: {plugin_info.homepage_url}")
|
|
||||||
|
|
||||||
# 组件列表
|
|
||||||
if plugin_info.components:
|
|
||||||
action_components = [c for c in plugin_info.components if c.component_type.name == "ACTION"]
|
|
||||||
command_components = [c for c in plugin_info.components if c.component_type.name == "COMMAND"]
|
|
||||||
|
|
||||||
if action_components:
|
|
||||||
action_names = [c.name for c in action_components]
|
|
||||||
logger.info(f" 🎯 Action组件: {', '.join(action_names)}")
|
|
||||||
|
|
||||||
if command_components:
|
|
||||||
command_names = [c.name for c in command_components]
|
|
||||||
logger.info(f" ⚡ Command组件: {', '.join(command_names)}")
|
|
||||||
|
|
||||||
# 版本兼容性信息
|
|
||||||
if plugin_info.min_host_version or plugin_info.max_host_version:
|
|
||||||
version_range = ""
|
|
||||||
if plugin_info.min_host_version:
|
|
||||||
version_range += f">={plugin_info.min_host_version}"
|
|
||||||
if plugin_info.max_host_version:
|
|
||||||
if version_range:
|
|
||||||
version_range += f", <={plugin_info.max_host_version}"
|
|
||||||
else:
|
|
||||||
version_range += f"<={plugin_info.max_host_version}"
|
|
||||||
logger.info(f" 📋 兼容版本: {version_range}")
|
|
||||||
|
|
||||||
# 依赖信息
|
|
||||||
if plugin_info.dependencies:
|
|
||||||
logger.info(f" 🔗 依赖: {', '.join(plugin_info.dependencies)}")
|
|
||||||
|
|
||||||
# 配置文件信息
|
|
||||||
if plugin_info.config_file:
|
|
||||||
config_status = "✅" if self.plugin_paths.get(plugin_name) else "❌"
|
|
||||||
logger.info(f" ⚙️ 配置: {plugin_info.config_file} {config_status}")
|
|
||||||
|
|
||||||
# 显示目录统计
|
|
||||||
logger.info("📂 加载目录统计:")
|
|
||||||
for directory in self.plugin_directories:
|
for directory in self.plugin_directories:
|
||||||
if os.path.exists(directory):
|
if os.path.exists(directory):
|
||||||
plugins_in_dir = []
|
logger.debug(f"重新扫描插件根目录: {directory}")
|
||||||
for plugin_name in self.loaded_plugins.keys():
|
self._load_plugin_modules_from_directory(directory)
|
||||||
plugin_path = self.plugin_paths.get(plugin_name, "")
|
|
||||||
if plugin_path.startswith(directory):
|
|
||||||
plugins_in_dir.append(plugin_name)
|
|
||||||
|
|
||||||
if plugins_in_dir:
|
|
||||||
logger.info(f" 📁 {directory}: {len(plugins_in_dir)}个插件 ({', '.join(plugins_in_dir)})")
|
|
||||||
else:
|
else:
|
||||||
logger.info(f" 📁 {directory}: 0个插件")
|
logger.warning(f"插件根目录不存在: {directory}")
|
||||||
|
|
||||||
# 失败信息
|
|
||||||
if total_failed_registration > 0:
|
|
||||||
logger.info(f"⚠️ 失败统计: {total_failed_registration}个插件加载失败")
|
|
||||||
for failed_plugin, error in self.failed_plugins.items():
|
|
||||||
logger.info(f" ❌ {failed_plugin}: {error}")
|
|
||||||
else:
|
|
||||||
logger.warning("😕 没有成功加载任何插件")
|
|
||||||
|
|
||||||
# 返回插件数量和组件数量
|
|
||||||
return total_registered, total_components
|
|
||||||
|
|
||||||
def _find_plugin_directory(self, plugin_class) -> Optional[str]:
|
|
||||||
"""查找插件类对应的目录路径"""
|
|
||||||
try:
|
|
||||||
import inspect
|
|
||||||
|
|
||||||
module = inspect.getmodule(plugin_class)
|
|
||||||
if module and hasattr(module, "__file__") and module.__file__:
|
|
||||||
return os.path.dirname(module.__file__)
|
|
||||||
except Exception as e:
|
|
||||||
logger.debug(f"通过inspect获取插件目录失败: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _load_plugin_modules_from_directory(self, directory: str) -> tuple[int, int]:
|
|
||||||
"""从指定目录加载插件模块"""
|
|
||||||
loaded_count = 0
|
|
||||||
failed_count = 0
|
|
||||||
|
|
||||||
if not os.path.exists(directory):
|
|
||||||
logger.warning(f"插件目录不存在: {directory}")
|
|
||||||
return loaded_count, failed_count
|
|
||||||
|
|
||||||
logger.debug(f"正在扫描插件目录: {directory}")
|
|
||||||
|
|
||||||
# 遍历目录中的所有Python文件和包
|
|
||||||
for item in os.listdir(directory):
|
|
||||||
item_path = os.path.join(directory, item)
|
|
||||||
|
|
||||||
if os.path.isfile(item_path) and item.endswith(".py") and item != "__init__.py":
|
|
||||||
# 单文件插件
|
|
||||||
plugin_name = Path(item_path).stem
|
|
||||||
if self._load_plugin_module_file(item_path, plugin_name, directory):
|
|
||||||
loaded_count += 1
|
|
||||||
else:
|
|
||||||
failed_count += 1
|
|
||||||
|
|
||||||
elif os.path.isdir(item_path) and not item.startswith(".") and not item.startswith("__"):
|
|
||||||
# 插件包
|
|
||||||
plugin_file = os.path.join(item_path, "plugin.py")
|
|
||||||
if os.path.exists(plugin_file):
|
|
||||||
plugin_name = item # 使用目录名作为插件名
|
|
||||||
if self._load_plugin_module_file(plugin_file, plugin_name, item_path):
|
|
||||||
loaded_count += 1
|
|
||||||
else:
|
|
||||||
failed_count += 1
|
|
||||||
|
|
||||||
return loaded_count, failed_count
|
|
||||||
|
|
||||||
def _load_plugin_module_file(self, plugin_file: str, plugin_name: str, plugin_dir: str) -> bool:
|
|
||||||
"""加载单个插件模块文件
|
|
||||||
|
|
||||||
Args:
|
|
||||||
plugin_file: 插件文件路径
|
|
||||||
plugin_name: 插件名称
|
|
||||||
plugin_dir: 插件目录路径
|
|
||||||
"""
|
|
||||||
# 生成模块名
|
|
||||||
plugin_path = Path(plugin_file)
|
|
||||||
if plugin_path.parent.name != "plugins":
|
|
||||||
# 插件包格式:parent_dir.plugin
|
|
||||||
module_name = f"plugins.{plugin_path.parent.name}.plugin"
|
|
||||||
else:
|
|
||||||
# 单文件格式:plugins.filename
|
|
||||||
module_name = f"plugins.{plugin_path.stem}"
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 动态导入插件模块
|
|
||||||
spec = importlib.util.spec_from_file_location(module_name, plugin_file)
|
|
||||||
if spec is None or spec.loader is None:
|
|
||||||
logger.error(f"无法创建模块规范: {plugin_file}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
module = importlib.util.module_from_spec(spec)
|
|
||||||
spec.loader.exec_module(module)
|
|
||||||
|
|
||||||
# 记录插件名和目录路径的映射
|
|
||||||
self.plugin_paths[plugin_name] = plugin_dir
|
|
||||||
|
|
||||||
logger.debug(f"插件模块加载成功: {plugin_file}")
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
error_msg = f"加载插件模块 {plugin_file} 失败: {e}"
|
|
||||||
logger.error(error_msg)
|
|
||||||
self.failed_plugins[plugin_name] = error_msg
|
|
||||||
return False
|
|
||||||
|
|
||||||
def get_loaded_plugins(self) -> List[PluginInfo]:
|
def get_loaded_plugins(self) -> List[PluginInfo]:
|
||||||
"""获取所有已加载的插件信息"""
|
"""获取所有已加载的插件信息"""
|
||||||
@@ -356,9 +195,9 @@ class PluginManager:
|
|||||||
return list(component_registry.get_enabled_plugins().values())
|
return list(component_registry.get_enabled_plugins().values())
|
||||||
|
|
||||||
def enable_plugin(self, plugin_name: str) -> bool:
|
def enable_plugin(self, plugin_name: str) -> bool:
|
||||||
|
# -------------------------------- NEED REFACTORING --------------------------------
|
||||||
"""启用插件"""
|
"""启用插件"""
|
||||||
plugin_info = component_registry.get_plugin_info(plugin_name)
|
if plugin_info := component_registry.get_plugin_info(plugin_name):
|
||||||
if plugin_info:
|
|
||||||
plugin_info.enabled = True
|
plugin_info.enabled = True
|
||||||
# 启用插件的所有组件
|
# 启用插件的所有组件
|
||||||
for component in plugin_info.components:
|
for component in plugin_info.components:
|
||||||
@@ -368,9 +207,9 @@ class PluginManager:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def disable_plugin(self, plugin_name: str) -> bool:
|
def disable_plugin(self, plugin_name: str) -> bool:
|
||||||
|
# -------------------------------- NEED REFACTORING --------------------------------
|
||||||
"""禁用插件"""
|
"""禁用插件"""
|
||||||
plugin_info = component_registry.get_plugin_info(plugin_name)
|
if plugin_info := component_registry.get_plugin_info(plugin_name):
|
||||||
if plugin_info:
|
|
||||||
plugin_info.enabled = False
|
plugin_info.enabled = False
|
||||||
# 禁用插件的所有组件
|
# 禁用插件的所有组件
|
||||||
for component in plugin_info.components:
|
for component in plugin_info.components:
|
||||||
@@ -409,12 +248,6 @@ class PluginManager:
|
|||||||
"failed_plugin_details": self.failed_plugins.copy(),
|
"failed_plugin_details": self.failed_plugins.copy(),
|
||||||
}
|
}
|
||||||
|
|
||||||
def reload_plugin(self, plugin_name: str) -> bool:
|
|
||||||
"""重新加载插件(高级功能,需要谨慎使用)"""
|
|
||||||
# TODO: 实现插件热重载功能
|
|
||||||
logger.warning("插件热重载功能尚未实现")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def check_all_dependencies(self, auto_install: bool = False) -> Dict[str, any]:
|
def check_all_dependencies(self, auto_install: bool = False) -> Dict[str, any]:
|
||||||
"""检查所有插件的Python依赖包
|
"""检查所有插件的Python依赖包
|
||||||
|
|
||||||
@@ -426,11 +259,11 @@ class PluginManager:
|
|||||||
"""
|
"""
|
||||||
logger.info("开始检查所有插件的Python依赖包...")
|
logger.info("开始检查所有插件的Python依赖包...")
|
||||||
|
|
||||||
all_required_missing = []
|
all_required_missing: List[PythonDependency] = []
|
||||||
all_optional_missing = []
|
all_optional_missing: List[PythonDependency] = []
|
||||||
plugin_status = {}
|
plugin_status = {}
|
||||||
|
|
||||||
for plugin_name, _plugin_instance in self.loaded_plugins.items():
|
for plugin_name in self.loaded_plugins:
|
||||||
plugin_info = component_registry.get_plugin_info(plugin_name)
|
plugin_info = component_registry.get_plugin_info(plugin_name)
|
||||||
if not plugin_info or not plugin_info.python_dependencies:
|
if not plugin_info or not plugin_info.python_dependencies:
|
||||||
plugin_status[plugin_name] = {"status": "no_dependencies", "missing": []}
|
plugin_status[plugin_name] = {"status": "no_dependencies", "missing": []}
|
||||||
@@ -461,19 +294,15 @@ class PluginManager:
|
|||||||
logger.info(f"插件 {plugin_name} 依赖检查通过")
|
logger.info(f"插件 {plugin_name} 依赖检查通过")
|
||||||
|
|
||||||
# 汇总结果
|
# 汇总结果
|
||||||
total_missing = len(set(dep.package_name for dep in all_required_missing))
|
total_missing = len({dep.package_name for dep in all_required_missing})
|
||||||
total_optional_missing = len(set(dep.package_name for dep in all_optional_missing))
|
total_optional_missing = len({dep.package_name for dep in all_optional_missing})
|
||||||
|
|
||||||
logger.info(f"依赖检查完成 - 缺少必需包: {total_missing}个, 缺少可选包: {total_optional_missing}个")
|
logger.info(f"依赖检查完成 - 缺少必需包: {total_missing}个, 缺少可选包: {total_optional_missing}个")
|
||||||
|
|
||||||
# 如果需要自动安装
|
# 如果需要自动安装
|
||||||
install_success = True
|
install_success = True
|
||||||
if auto_install and all_required_missing:
|
if auto_install and all_required_missing:
|
||||||
# 去重
|
unique_required = {dep.package_name: dep for dep in all_required_missing}
|
||||||
unique_required = {}
|
|
||||||
for dep in all_required_missing:
|
|
||||||
unique_required[dep.package_name] = dep
|
|
||||||
|
|
||||||
logger.info(f"开始自动安装 {len(unique_required)} 个必需依赖包...")
|
logger.info(f"开始自动安装 {len(unique_required)} 个必需依赖包...")
|
||||||
install_success = dependency_manager.install_dependencies(list(unique_required.values()), auto_install=True)
|
install_success = dependency_manager.install_dependencies(list(unique_required.values()), auto_install=True)
|
||||||
|
|
||||||
@@ -506,7 +335,7 @@ class PluginManager:
|
|||||||
|
|
||||||
all_dependencies = []
|
all_dependencies = []
|
||||||
|
|
||||||
for plugin_name, _plugin_instance in self.loaded_plugins.items():
|
for plugin_name in self.loaded_plugins:
|
||||||
plugin_info = component_registry.get_plugin_info(plugin_name)
|
plugin_info = component_registry.get_plugin_info(plugin_name)
|
||||||
if plugin_info and plugin_info.python_dependencies:
|
if plugin_info and plugin_info.python_dependencies:
|
||||||
all_dependencies.append(plugin_info.python_dependencies)
|
all_dependencies.append(plugin_info.python_dependencies)
|
||||||
@@ -517,7 +346,92 @@ class PluginManager:
|
|||||||
|
|
||||||
return dependency_manager.generate_requirements_file(all_dependencies, output_path)
|
return dependency_manager.generate_requirements_file(all_dependencies, output_path)
|
||||||
|
|
||||||
def check_plugin_version_compatibility(self, plugin_name: str, manifest_data: Dict[str, Any]) -> Tuple[bool, str]:
|
def _load_plugin_modules_from_directory(self, directory: str) -> tuple[int, int]:
|
||||||
|
"""从指定目录加载插件模块"""
|
||||||
|
loaded_count = 0
|
||||||
|
failed_count = 0
|
||||||
|
|
||||||
|
if not os.path.exists(directory):
|
||||||
|
logger.warning(f"插件根目录不存在: {directory}")
|
||||||
|
return 0, 1
|
||||||
|
|
||||||
|
logger.debug(f"正在扫描插件根目录: {directory}")
|
||||||
|
|
||||||
|
# 遍历目录中的所有Python文件和包
|
||||||
|
for item in os.listdir(directory):
|
||||||
|
item_path = os.path.join(directory, item)
|
||||||
|
|
||||||
|
if os.path.isfile(item_path) and item.endswith(".py") and item != "__init__.py":
|
||||||
|
# 单文件插件
|
||||||
|
plugin_name = Path(item_path).stem
|
||||||
|
if self._load_plugin_module_file(item_path, plugin_name, directory):
|
||||||
|
loaded_count += 1
|
||||||
|
else:
|
||||||
|
failed_count += 1
|
||||||
|
|
||||||
|
elif os.path.isdir(item_path) and not item.startswith(".") and not item.startswith("__"):
|
||||||
|
# 插件包
|
||||||
|
plugin_file = os.path.join(item_path, "plugin.py")
|
||||||
|
if os.path.exists(plugin_file):
|
||||||
|
plugin_name = item # 使用目录名作为插件名
|
||||||
|
if self._load_plugin_module_file(plugin_file, plugin_name, item_path):
|
||||||
|
loaded_count += 1
|
||||||
|
else:
|
||||||
|
failed_count += 1
|
||||||
|
|
||||||
|
return loaded_count, failed_count
|
||||||
|
|
||||||
|
def _find_plugin_directory(self, plugin_class: str) -> Optional[str]:
|
||||||
|
"""查找插件类对应的目录路径"""
|
||||||
|
try:
|
||||||
|
module = getmodule(plugin_class)
|
||||||
|
if module and hasattr(module, "__file__") and module.__file__:
|
||||||
|
return os.path.dirname(module.__file__)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"通过inspect获取插件目录失败: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _load_plugin_module_file(self, plugin_file: str, plugin_name: str, plugin_dir: str) -> bool:
|
||||||
|
# sourcery skip: extract-method
|
||||||
|
"""加载单个插件模块文件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin_file: 插件文件路径
|
||||||
|
plugin_name: 插件名称
|
||||||
|
plugin_dir: 插件目录路径
|
||||||
|
"""
|
||||||
|
# 生成模块名
|
||||||
|
plugin_path = Path(plugin_file)
|
||||||
|
if plugin_path.parent.name != "plugins":
|
||||||
|
# 插件包格式:parent_dir.plugin
|
||||||
|
module_name = f"plugins.{plugin_path.parent.name}.plugin"
|
||||||
|
else:
|
||||||
|
# 单文件格式:plugins.filename
|
||||||
|
module_name = f"plugins.{plugin_path.stem}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 动态导入插件模块
|
||||||
|
spec = spec_from_file_location(module_name, plugin_file)
|
||||||
|
if spec is None or spec.loader is None:
|
||||||
|
logger.error(f"无法创建模块规范: {plugin_file}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
module = module_from_spec(spec)
|
||||||
|
spec.loader.exec_module(module)
|
||||||
|
|
||||||
|
# 记录插件名和目录路径的映射
|
||||||
|
self.plugin_paths[plugin_name] = plugin_dir
|
||||||
|
|
||||||
|
logger.debug(f"插件模块加载成功: {plugin_file}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"加载插件模块 {plugin_file} 失败: {e}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
self.failed_plugins[plugin_name] = error_msg
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _check_plugin_version_compatibility(self, plugin_name: str, manifest_data: Dict[str, Any]) -> Tuple[bool, str]:
|
||||||
"""检查插件版本兼容性
|
"""检查插件版本兼容性
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -528,8 +442,7 @@ class PluginManager:
|
|||||||
Tuple[bool, str]: (是否兼容, 错误信息)
|
Tuple[bool, str]: (是否兼容, 错误信息)
|
||||||
"""
|
"""
|
||||||
if "host_application" not in manifest_data:
|
if "host_application" not in manifest_data:
|
||||||
# 没有版本要求,默认兼容
|
return True, "" # 没有版本要求,默认兼容
|
||||||
return True, ""
|
|
||||||
|
|
||||||
host_app = manifest_data["host_application"]
|
host_app = manifest_data["host_application"]
|
||||||
if not isinstance(host_app, dict):
|
if not isinstance(host_app, dict):
|
||||||
@@ -539,31 +452,128 @@ class PluginManager:
|
|||||||
max_version = host_app.get("max_version", "")
|
max_version = host_app.get("max_version", "")
|
||||||
|
|
||||||
if not min_version and not max_version:
|
if not min_version and not max_version:
|
||||||
return True, ""
|
return True, "" # 没有版本要求,默认兼容
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from src.plugin_system.utils.manifest_utils import VersionComparator
|
|
||||||
|
|
||||||
current_version = VersionComparator.get_current_host_version()
|
current_version = VersionComparator.get_current_host_version()
|
||||||
is_compatible, error_msg = VersionComparator.is_version_in_range(current_version, min_version, max_version)
|
is_compatible, error_msg = VersionComparator.is_version_in_range(current_version, min_version, max_version)
|
||||||
|
|
||||||
if not is_compatible:
|
if not is_compatible:
|
||||||
return False, f"版本不兼容: {error_msg}"
|
return False, f"版本不兼容: {error_msg}"
|
||||||
else:
|
|
||||||
logger.debug(f"插件 {plugin_name} 版本兼容性检查通过")
|
logger.debug(f"插件 {plugin_name} 版本兼容性检查通过")
|
||||||
return True, ""
|
return True, ""
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"插件 {plugin_name} 版本兼容性检查失败: {e}")
|
logger.warning(f"插件 {plugin_name} 版本兼容性检查失败: {e}")
|
||||||
return True, "" # 检查失败时默认允许加载
|
return False, f"插件 {plugin_name} 版本兼容性检查失败: {e}" # 检查失败时默认不允许加载
|
||||||
|
|
||||||
|
def _show_stats(self, total_registered: int, total_failed_registration: int):
|
||||||
|
# sourcery skip: low-code-quality
|
||||||
|
# 获取组件统计信息
|
||||||
|
stats = component_registry.get_registry_stats()
|
||||||
|
action_count = stats.get("action_components", 0)
|
||||||
|
command_count = stats.get("command_components", 0)
|
||||||
|
total_components = stats.get("total_components", 0)
|
||||||
|
|
||||||
|
# 📋 显示插件加载总览
|
||||||
|
if total_registered > 0:
|
||||||
|
logger.info("🎉 插件系统加载完成!")
|
||||||
|
logger.info(
|
||||||
|
f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 显示详细的插件列表
|
||||||
|
logger.info("📋 已加载插件详情:")
|
||||||
|
for plugin_name in self.loaded_plugins.keys():
|
||||||
|
if plugin_info := component_registry.get_plugin_info(plugin_name):
|
||||||
|
# 插件基本信息
|
||||||
|
version_info = f"v{plugin_info.version}" if plugin_info.version else ""
|
||||||
|
author_info = f"by {plugin_info.author}" if plugin_info.author else "unknown"
|
||||||
|
license_info = f"[{plugin_info.license}]" if plugin_info.license else ""
|
||||||
|
info_parts = [part for part in [version_info, author_info, license_info] if part]
|
||||||
|
extra_info = f" ({', '.join(info_parts)})" if info_parts else ""
|
||||||
|
|
||||||
|
logger.info(f" 📦 {plugin_info.display_name}{extra_info}")
|
||||||
|
|
||||||
|
# Manifest信息
|
||||||
|
if plugin_info.manifest_data:
|
||||||
|
"""
|
||||||
|
if plugin_info.keywords:
|
||||||
|
logger.info(f" 🏷️ 关键词: {', '.join(plugin_info.keywords)}")
|
||||||
|
if plugin_info.categories:
|
||||||
|
logger.info(f" 📁 分类: {', '.join(plugin_info.categories)}")
|
||||||
|
"""
|
||||||
|
if plugin_info.homepage_url:
|
||||||
|
logger.info(f" 🌐 主页: {plugin_info.homepage_url}")
|
||||||
|
|
||||||
|
# 组件列表
|
||||||
|
if plugin_info.components:
|
||||||
|
action_components = [c for c in plugin_info.components if c.component_type.name == "ACTION"]
|
||||||
|
command_components = [c for c in plugin_info.components if c.component_type.name == "COMMAND"]
|
||||||
|
|
||||||
|
if action_components:
|
||||||
|
action_names = [c.name for c in action_components]
|
||||||
|
logger.info(f" 🎯 Action组件: {', '.join(action_names)}")
|
||||||
|
|
||||||
|
if command_components:
|
||||||
|
command_names = [c.name for c in command_components]
|
||||||
|
logger.info(f" ⚡ Command组件: {', '.join(command_names)}")
|
||||||
|
|
||||||
|
# 依赖信息
|
||||||
|
if plugin_info.dependencies:
|
||||||
|
logger.info(f" 🔗 依赖: {', '.join(plugin_info.dependencies)}")
|
||||||
|
|
||||||
|
# 配置文件信息
|
||||||
|
if plugin_info.config_file:
|
||||||
|
config_status = "✅" if self.plugin_paths.get(plugin_name) else "❌"
|
||||||
|
logger.info(f" ⚙️ 配置: {plugin_info.config_file} {config_status}")
|
||||||
|
|
||||||
|
# 显示目录统计
|
||||||
|
logger.info("📂 加载目录统计:")
|
||||||
|
for directory in self.plugin_directories:
|
||||||
|
if os.path.exists(directory):
|
||||||
|
plugins_in_dir = []
|
||||||
|
for plugin_name in self.loaded_plugins.keys():
|
||||||
|
plugin_path = self.plugin_paths.get(plugin_name, "")
|
||||||
|
if plugin_path.startswith(directory):
|
||||||
|
plugins_in_dir.append(plugin_name)
|
||||||
|
|
||||||
|
if plugins_in_dir:
|
||||||
|
logger.info(f" 📁 {directory}: {len(plugins_in_dir)}个插件 ({', '.join(plugins_in_dir)})")
|
||||||
|
else:
|
||||||
|
logger.info(f" 📁 {directory}: 0个插件")
|
||||||
|
|
||||||
|
# 失败信息
|
||||||
|
if total_failed_registration > 0:
|
||||||
|
logger.info(f"⚠️ 失败统计: {total_failed_registration}个插件加载失败")
|
||||||
|
for failed_plugin, error in self.failed_plugins.items():
|
||||||
|
logger.info(f" ❌ {failed_plugin}: {error}")
|
||||||
|
else:
|
||||||
|
logger.warning("😕 没有成功加载任何插件")
|
||||||
|
|
||||||
|
def _show_plugin_components(self, plugin_name: str) -> None:
|
||||||
|
if plugin_info := component_registry.get_plugin_info(plugin_name):
|
||||||
|
component_types = {}
|
||||||
|
for comp in plugin_info.components:
|
||||||
|
comp_type = comp.component_type.name
|
||||||
|
component_types[comp_type] = component_types.get(comp_type, 0) + 1
|
||||||
|
|
||||||
|
components_str = ", ".join([f"{count}个{ctype}" for ctype, count in component_types.items()])
|
||||||
|
|
||||||
|
# 显示manifest信息
|
||||||
|
manifest_info = ""
|
||||||
|
if plugin_info.license:
|
||||||
|
manifest_info += f" [{plugin_info.license}]"
|
||||||
|
if plugin_info.keywords:
|
||||||
|
manifest_info += f" 关键词: {', '.join(plugin_info.keywords[:3])}" # 只显示前3个关键词
|
||||||
|
if len(plugin_info.keywords) > 3:
|
||||||
|
manifest_info += "..."
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"✅ 插件加载成功: {plugin_name} v{plugin_info.version} ({components_str}){manifest_info} - {plugin_info.description}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info(f"✅ 插件加载成功: {plugin_name}")
|
||||||
|
|
||||||
|
|
||||||
# 全局插件管理器实例
|
# 全局插件管理器实例
|
||||||
plugin_manager = PluginManager()
|
plugin_manager = PluginManager()
|
||||||
|
|
||||||
# 注释掉以解决插件目录重复加载的情况
|
|
||||||
# 默认插件目录
|
|
||||||
# plugin_manager.add_plugin_directory("src/plugins/built_in")
|
|
||||||
# plugin_manager.add_plugin_directory("src/plugins/examples")
|
|
||||||
# 用户插件目录
|
|
||||||
# plugin_manager.add_plugin_directory("plugins")
|
|
||||||
|
|||||||
9
src/plugin_system/events/__init__.py
Normal file
9
src/plugin_system/events/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
"""
|
||||||
|
插件的事件系统模块
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .events import EventType
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"EventType",
|
||||||
|
]
|
||||||
14
src/plugin_system/events/events.py
Normal file
14
src/plugin_system/events/events.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class EventType(Enum):
|
||||||
|
"""
|
||||||
|
事件类型枚举类
|
||||||
|
"""
|
||||||
|
|
||||||
|
ON_MESSAGE = "on_message"
|
||||||
|
ON_PLAN = "on_plan"
|
||||||
|
POST_LLM = "post_llm"
|
||||||
|
AFTER_LLM = "after_llm"
|
||||||
|
POST_SEND = "post_send"
|
||||||
|
AFTER_SEND = "after_send"
|
||||||
@@ -4,11 +4,16 @@
|
|||||||
提供插件开发和管理的实用工具
|
提供插件开发和管理的实用工具
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from src.plugin_system.utils.manifest_utils import (
|
from .manifest_utils import (
|
||||||
ManifestValidator,
|
ManifestValidator,
|
||||||
ManifestGenerator,
|
ManifestGenerator,
|
||||||
validate_plugin_manifest,
|
validate_plugin_manifest,
|
||||||
generate_plugin_manifest,
|
generate_plugin_manifest,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = ["ManifestValidator", "ManifestGenerator", "validate_plugin_manifest", "generate_plugin_manifest"]
|
__all__ = [
|
||||||
|
"ManifestValidator",
|
||||||
|
"ManifestGenerator",
|
||||||
|
"validate_plugin_manifest",
|
||||||
|
"generate_plugin_manifest",
|
||||||
|
]
|
||||||
|
|||||||
@@ -305,7 +305,7 @@ class ManifestValidator:
|
|||||||
# 检查URL格式(可选字段)
|
# 检查URL格式(可选字段)
|
||||||
for url_field in ["homepage_url", "repository_url"]:
|
for url_field in ["homepage_url", "repository_url"]:
|
||||||
if url_field in manifest_data and manifest_data[url_field]:
|
if url_field in manifest_data and manifest_data[url_field]:
|
||||||
url = manifest_data[url_field]
|
url: str = manifest_data[url_field]
|
||||||
if not (url.startswith("http://") or url.startswith("https://")):
|
if not (url.startswith("http://") or url.startswith("https://")):
|
||||||
self.validation_warnings.append(f"{url_field}建议使用完整的URL格式")
|
self.validation_warnings.append(f"{url_field}建议使用完整的URL格式")
|
||||||
|
|
||||||
|
|||||||
@@ -156,6 +156,8 @@ class CoreActionsPlugin(BasePlugin):
|
|||||||
# 插件基本信息
|
# 插件基本信息
|
||||||
plugin_name = "core_actions" # 内部标识符
|
plugin_name = "core_actions" # 内部标识符
|
||||||
enable_plugin = True
|
enable_plugin = True
|
||||||
|
dependencies = [] # 插件依赖列表
|
||||||
|
python_dependencies = [] # Python包依赖列表
|
||||||
config_file_name = "config.toml"
|
config_file_name = "config.toml"
|
||||||
|
|
||||||
# 配置节描述
|
# 配置节描述
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
from src.plugin_system.base.base_plugin import BasePlugin, register_plugin
|
from src.plugin_system.apis.plugin_register_api import register_plugin
|
||||||
|
from src.plugin_system.base.base_plugin import BasePlugin
|
||||||
from src.plugin_system.base.component_types import ComponentInfo
|
from src.plugin_system.base.component_types import ComponentInfo
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.plugin_system.base.base_action import BaseAction, ActionActivationType, ChatMode
|
from src.plugin_system.base.base_action import BaseAction, ActionActivationType, ChatMode
|
||||||
@@ -108,6 +109,8 @@ class TTSPlugin(BasePlugin):
|
|||||||
# 插件基本信息
|
# 插件基本信息
|
||||||
plugin_name = "tts_plugin" # 内部标识符
|
plugin_name = "tts_plugin" # 内部标识符
|
||||||
enable_plugin = True
|
enable_plugin = True
|
||||||
|
dependencies = [] # 插件依赖列表
|
||||||
|
python_dependencies = [] # Python包依赖列表
|
||||||
config_file_name = "config.toml"
|
config_file_name = "config.toml"
|
||||||
|
|
||||||
# 配置节描述
|
# 配置节描述
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
from src.plugin_system.base.base_plugin import BasePlugin, register_plugin
|
from src.plugin_system.apis.plugin_register_api import register_plugin
|
||||||
|
from src.plugin_system.base.base_plugin import BasePlugin
|
||||||
from src.plugin_system.base.component_types import ComponentInfo
|
from src.plugin_system.base.component_types import ComponentInfo
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.plugin_system.base.base_action import BaseAction, ActionActivationType, ChatMode
|
from src.plugin_system.base.base_action import BaseAction, ActionActivationType, ChatMode
|
||||||
@@ -109,6 +110,8 @@ class VTBPlugin(BasePlugin):
|
|||||||
# 插件基本信息
|
# 插件基本信息
|
||||||
plugin_name = "vtb_plugin" # 内部标识符
|
plugin_name = "vtb_plugin" # 内部标识符
|
||||||
enable_plugin = True
|
enable_plugin = True
|
||||||
|
dependencies = [] # 插件依赖列表
|
||||||
|
python_dependencies = [] # Python包依赖列表
|
||||||
config_file_name = "config.toml"
|
config_file_name = "config.toml"
|
||||||
|
|
||||||
# 配置节描述
|
# 配置节描述
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ class ToolExecutor:
|
|||||||
可以直接输入聊天消息内容,自动判断并执行相应的工具,返回结构化的工具执行结果。
|
可以直接输入聊天消息内容,自动判断并执行相应的工具,返回结构化的工具执行结果。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, chat_id: str = None, enable_cache: bool = True, cache_ttl: int = 3):
|
def __init__(self, chat_id: str, enable_cache: bool = True, cache_ttl: int = 3):
|
||||||
"""初始化工具执行器
|
"""初始化工具执行器
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -62,8 +62,8 @@ class ToolExecutor:
|
|||||||
logger.info(f"{self.log_prefix}工具执行器初始化完成,缓存{'启用' if enable_cache else '禁用'},TTL={cache_ttl}")
|
logger.info(f"{self.log_prefix}工具执行器初始化完成,缓存{'启用' if enable_cache else '禁用'},TTL={cache_ttl}")
|
||||||
|
|
||||||
async def execute_from_chat_message(
|
async def execute_from_chat_message(
|
||||||
self, target_message: str, chat_history: list[str], sender: str, return_details: bool = False
|
self, target_message: str, chat_history: str, sender: str, return_details: bool = False
|
||||||
) -> List[Dict] | Tuple[List[Dict], List[str], str]:
|
) -> Tuple[List[Dict], List[str], str]:
|
||||||
"""从聊天消息执行工具
|
"""从聊天消息执行工具
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -79,16 +79,14 @@ class ToolExecutor:
|
|||||||
|
|
||||||
# 首先检查缓存
|
# 首先检查缓存
|
||||||
cache_key = self._generate_cache_key(target_message, chat_history, sender)
|
cache_key = self._generate_cache_key(target_message, chat_history, sender)
|
||||||
cached_result = self._get_from_cache(cache_key)
|
if cached_result := self._get_from_cache(cache_key):
|
||||||
|
|
||||||
if cached_result:
|
|
||||||
logger.info(f"{self.log_prefix}使用缓存结果,跳过工具执行")
|
logger.info(f"{self.log_prefix}使用缓存结果,跳过工具执行")
|
||||||
if return_details:
|
if not return_details:
|
||||||
|
return cached_result, [], "使用缓存结果"
|
||||||
|
|
||||||
# 从缓存结果中提取工具名称
|
# 从缓存结果中提取工具名称
|
||||||
used_tools = [result.get("tool_name", "unknown") for result in cached_result]
|
used_tools = [result.get("tool_name", "unknown") for result in cached_result]
|
||||||
return cached_result, used_tools, "使用缓存结果"
|
return cached_result, used_tools, "使用缓存结果"
|
||||||
else:
|
|
||||||
return cached_result
|
|
||||||
|
|
||||||
# 缓存未命中,执行工具调用
|
# 缓存未命中,执行工具调用
|
||||||
# 获取可用工具
|
# 获取可用工具
|
||||||
@@ -134,7 +132,7 @@ class ToolExecutor:
|
|||||||
if return_details:
|
if return_details:
|
||||||
return tool_results, used_tools, prompt
|
return tool_results, used_tools, prompt
|
||||||
else:
|
else:
|
||||||
return tool_results
|
return tool_results, [], ""
|
||||||
|
|
||||||
async def _execute_tool_calls(self, tool_calls) -> Tuple[List[Dict], List[str]]:
|
async def _execute_tool_calls(self, tool_calls) -> Tuple[List[Dict], List[str]]:
|
||||||
"""执行工具调用
|
"""执行工具调用
|
||||||
@@ -207,7 +205,7 @@ class ToolExecutor:
|
|||||||
|
|
||||||
return tool_results, used_tools
|
return tool_results, used_tools
|
||||||
|
|
||||||
def _generate_cache_key(self, target_message: str, chat_history: list[str], sender: str) -> str:
|
def _generate_cache_key(self, target_message: str, chat_history: str, sender: str) -> str:
|
||||||
"""生成缓存键
|
"""生成缓存键
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -267,10 +265,7 @@ class ToolExecutor:
|
|||||||
return
|
return
|
||||||
|
|
||||||
expired_keys = []
|
expired_keys = []
|
||||||
for cache_key, cache_item in self.tool_cache.items():
|
expired_keys.extend(cache_key for cache_key, cache_item in self.tool_cache.items() if cache_item["ttl"] <= 0)
|
||||||
if cache_item["ttl"] <= 0:
|
|
||||||
expired_keys.append(cache_key)
|
|
||||||
|
|
||||||
for key in expired_keys:
|
for key in expired_keys:
|
||||||
del self.tool_cache[key]
|
del self.tool_cache[key]
|
||||||
|
|
||||||
@@ -355,7 +350,7 @@ class ToolExecutor:
|
|||||||
"ttl_distribution": ttl_distribution,
|
"ttl_distribution": ttl_distribution,
|
||||||
}
|
}
|
||||||
|
|
||||||
def set_cache_config(self, enable_cache: bool = None, cache_ttl: int = None):
|
def set_cache_config(self, enable_cache: Optional[bool] = None, cache_ttl: int = -1):
|
||||||
"""动态修改缓存配置
|
"""动态修改缓存配置
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -366,7 +361,7 @@ class ToolExecutor:
|
|||||||
self.enable_cache = enable_cache
|
self.enable_cache = enable_cache
|
||||||
logger.info(f"{self.log_prefix}缓存状态修改为: {'启用' if enable_cache else '禁用'}")
|
logger.info(f"{self.log_prefix}缓存状态修改为: {'启用' if enable_cache else '禁用'}")
|
||||||
|
|
||||||
if cache_ttl is not None and cache_ttl > 0:
|
if cache_ttl > 0:
|
||||||
self.cache_ttl = cache_ttl
|
self.cache_ttl = cache_ttl
|
||||||
logger.info(f"{self.log_prefix}缓存TTL修改为: {cache_ttl}")
|
logger.info(f"{self.log_prefix}缓存TTL修改为: {cache_ttl}")
|
||||||
|
|
||||||
@@ -380,7 +375,7 @@ init_tool_executor_prompt()
|
|||||||
|
|
||||||
# 1. 基础使用 - 从聊天消息执行工具(启用缓存,默认TTL=3)
|
# 1. 基础使用 - 从聊天消息执行工具(启用缓存,默认TTL=3)
|
||||||
executor = ToolExecutor(executor_id="my_executor")
|
executor = ToolExecutor(executor_id="my_executor")
|
||||||
results = await executor.execute_from_chat_message(
|
results, _, _ = await executor.execute_from_chat_message(
|
||||||
talking_message_str="今天天气怎么样?现在几点了?",
|
talking_message_str="今天天气怎么样?现在几点了?",
|
||||||
is_group_chat=False
|
is_group_chat=False
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,88 +0,0 @@
|
|||||||
@echo off
|
|
||||||
CHCP 65001 > nul
|
|
||||||
setlocal enabledelayedexpansion
|
|
||||||
|
|
||||||
echo 你需要选择启动方式,输入字母来选择:
|
|
||||||
echo V = 不知道什么意思就输入 V
|
|
||||||
echo C = 输入 C 使用 Conda 环境
|
|
||||||
echo.
|
|
||||||
choice /C CV /N /M "不知道什么意思就输入 V (C/V)?" /T 10 /D V
|
|
||||||
|
|
||||||
set "ENV_TYPE="
|
|
||||||
if %ERRORLEVEL% == 1 set "ENV_TYPE=CONDA"
|
|
||||||
if %ERRORLEVEL% == 2 set "ENV_TYPE=VENV"
|
|
||||||
|
|
||||||
if "%ENV_TYPE%" == "CONDA" goto activate_conda
|
|
||||||
if "%ENV_TYPE%" == "VENV" goto activate_venv
|
|
||||||
|
|
||||||
REM 如果 choice 超时或返回意外值,默认使用 venv
|
|
||||||
echo WARN: Invalid selection or timeout from choice. Defaulting to VENV.
|
|
||||||
set "ENV_TYPE=VENV"
|
|
||||||
goto activate_venv
|
|
||||||
|
|
||||||
:activate_conda
|
|
||||||
set /p CONDA_ENV_NAME="请输入要使用的 Conda 环境名称: "
|
|
||||||
if not defined CONDA_ENV_NAME (
|
|
||||||
echo 错误: 未输入 Conda 环境名称.
|
|
||||||
pause
|
|
||||||
exit /b 1
|
|
||||||
)
|
|
||||||
echo 选择: Conda '!CONDA_ENV_NAME!'
|
|
||||||
REM 激活Conda环境
|
|
||||||
call conda activate !CONDA_ENV_NAME!
|
|
||||||
if !ERRORLEVEL! neq 0 (
|
|
||||||
echo 错误: Conda环境 '!CONDA_ENV_NAME!' 激活失败. 请确保Conda已安装并正确配置, 且 '!CONDA_ENV_NAME!' 环境存在.
|
|
||||||
pause
|
|
||||||
exit /b 1
|
|
||||||
)
|
|
||||||
goto env_activated
|
|
||||||
|
|
||||||
:activate_venv
|
|
||||||
echo Selected: venv (default or selected)
|
|
||||||
REM 查找venv虚拟环境
|
|
||||||
set "venv_path=%~dp0venv\Scripts\activate.bat"
|
|
||||||
if not exist "%venv_path%" (
|
|
||||||
echo Error: venv not found. Ensure the venv directory exists alongside the script.
|
|
||||||
pause
|
|
||||||
exit /b 1
|
|
||||||
)
|
|
||||||
REM 激活虚拟环境
|
|
||||||
call "%venv_path%"
|
|
||||||
if %ERRORLEVEL% neq 0 (
|
|
||||||
echo Error: Failed to activate venv virtual environment.
|
|
||||||
pause
|
|
||||||
exit /b 1
|
|
||||||
)
|
|
||||||
goto env_activated
|
|
||||||
|
|
||||||
:env_activated
|
|
||||||
echo Environment activated successfully!
|
|
||||||
|
|
||||||
REM --- 后续脚本执行 ---
|
|
||||||
|
|
||||||
REM 运行预处理脚本
|
|
||||||
python "%~dp0scripts\raw_data_preprocessor.py"
|
|
||||||
if %ERRORLEVEL% neq 0 (
|
|
||||||
echo Error: raw_data_preprocessor.py execution failed.
|
|
||||||
pause
|
|
||||||
exit /b 1
|
|
||||||
)
|
|
||||||
|
|
||||||
REM 运行信息提取脚本
|
|
||||||
python "%~dp0scripts\info_extraction.py"
|
|
||||||
if %ERRORLEVEL% neq 0 (
|
|
||||||
echo Error: info_extraction.py execution failed.
|
|
||||||
pause
|
|
||||||
exit /b 1
|
|
||||||
)
|
|
||||||
|
|
||||||
REM 运行OpenIE导入脚本
|
|
||||||
python "%~dp0scripts\import_openie.py"
|
|
||||||
if %ERRORLEVEL% neq 0 (
|
|
||||||
echo Error: import_openie.py execution failed.
|
|
||||||
pause
|
|
||||||
exit /b 1
|
|
||||||
)
|
|
||||||
|
|
||||||
echo All processing steps completed!
|
|
||||||
pause
|
|
||||||
Reference in New Issue
Block a user