部分类型注解修复,优化import顺序,删除无用API文件
This commit is contained in:
@@ -1,26 +0,0 @@
|
|||||||
from src.chat.heart_flow.heartflow import heartflow
|
|
||||||
from src.chat.heart_flow.sub_heartflow import ChatState
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
|
|
||||||
logger = get_logger("api")
|
|
||||||
|
|
||||||
|
|
||||||
async def get_all_subheartflow_ids() -> list:
|
|
||||||
"""获取所有子心流的ID列表"""
|
|
||||||
all_subheartflows = heartflow.subheartflow_manager.get_all_subheartflows()
|
|
||||||
return [subheartflow.subheartflow_id for subheartflow in all_subheartflows]
|
|
||||||
|
|
||||||
|
|
||||||
async def forced_change_subheartflow_status(subheartflow_id: str, status: ChatState) -> bool:
|
|
||||||
"""强制改变子心流的状态"""
|
|
||||||
subheartflow = await heartflow.get_or_create_subheartflow(subheartflow_id)
|
|
||||||
if subheartflow:
|
|
||||||
return await heartflow.force_change_subheartflow_status(subheartflow_id, status)
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
async def get_all_states():
|
|
||||||
"""获取所有状态"""
|
|
||||||
all_states = await heartflow.api_get_all_states()
|
|
||||||
logger.debug(f"所有状态: {all_states}")
|
|
||||||
return all_states
|
|
||||||
@@ -1,169 +0,0 @@
|
|||||||
import platform
|
|
||||||
import psutil
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
|
|
||||||
|
|
||||||
def get_system_info():
|
|
||||||
"""获取操作系统信息"""
|
|
||||||
return {
|
|
||||||
"system": platform.system(),
|
|
||||||
"release": platform.release(),
|
|
||||||
"version": platform.version(),
|
|
||||||
"machine": platform.machine(),
|
|
||||||
"processor": platform.processor(),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def get_python_version():
|
|
||||||
"""获取 Python 版本信息"""
|
|
||||||
return sys.version
|
|
||||||
|
|
||||||
|
|
||||||
def get_cpu_usage():
|
|
||||||
"""获取系统总CPU使用率"""
|
|
||||||
return psutil.cpu_percent(interval=1)
|
|
||||||
|
|
||||||
|
|
||||||
def get_process_cpu_usage():
|
|
||||||
"""获取当前进程CPU使用率"""
|
|
||||||
process = psutil.Process(os.getpid())
|
|
||||||
return process.cpu_percent(interval=1)
|
|
||||||
|
|
||||||
|
|
||||||
def get_memory_usage():
|
|
||||||
"""获取系统内存使用情况 (单位 MB)"""
|
|
||||||
mem = psutil.virtual_memory()
|
|
||||||
bytes_to_mb = lambda x: round(x / (1024 * 1024), 2) # noqa
|
|
||||||
return {
|
|
||||||
"total_mb": bytes_to_mb(mem.total),
|
|
||||||
"available_mb": bytes_to_mb(mem.available),
|
|
||||||
"percent": mem.percent,
|
|
||||||
"used_mb": bytes_to_mb(mem.used),
|
|
||||||
"free_mb": bytes_to_mb(mem.free),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def get_process_memory_usage():
|
|
||||||
"""获取当前进程内存使用情况 (单位 MB)"""
|
|
||||||
process = psutil.Process(os.getpid())
|
|
||||||
mem_info = process.memory_info()
|
|
||||||
bytes_to_mb = lambda x: round(x / (1024 * 1024), 2) # noqa
|
|
||||||
return {
|
|
||||||
"rss_mb": bytes_to_mb(mem_info.rss), # Resident Set Size: 实际使用物理内存
|
|
||||||
"vms_mb": bytes_to_mb(mem_info.vms), # Virtual Memory Size: 虚拟内存大小
|
|
||||||
"percent": process.memory_percent(), # 进程内存使用百分比
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def get_disk_usage(path="/"):
|
|
||||||
"""获取指定路径磁盘使用情况 (单位 GB)"""
|
|
||||||
disk = psutil.disk_usage(path)
|
|
||||||
bytes_to_gb = lambda x: round(x / (1024 * 1024 * 1024), 2) # noqa
|
|
||||||
return {
|
|
||||||
"total_gb": bytes_to_gb(disk.total),
|
|
||||||
"used_gb": bytes_to_gb(disk.used),
|
|
||||||
"free_gb": bytes_to_gb(disk.free),
|
|
||||||
"percent": disk.percent,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def get_all_basic_info():
|
|
||||||
"""获取所有基本信息并封装返回"""
|
|
||||||
# 对于进程CPU使用率,需要先初始化
|
|
||||||
process = psutil.Process(os.getpid())
|
|
||||||
process.cpu_percent(interval=None) # 初始化调用
|
|
||||||
process_cpu = process.cpu_percent(interval=0.1) # 短暂间隔获取
|
|
||||||
|
|
||||||
return {
|
|
||||||
"system_info": get_system_info(),
|
|
||||||
"python_version": get_python_version(),
|
|
||||||
"cpu_usage_percent": get_cpu_usage(),
|
|
||||||
"process_cpu_usage_percent": process_cpu,
|
|
||||||
"memory_usage": get_memory_usage(),
|
|
||||||
"process_memory_usage": get_process_memory_usage(),
|
|
||||||
"disk_usage_root": get_disk_usage("/"),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def get_all_basic_info_string() -> str:
|
|
||||||
"""获取所有基本信息并以带解释的字符串形式返回"""
|
|
||||||
info = get_all_basic_info()
|
|
||||||
|
|
||||||
sys_info = info["system_info"]
|
|
||||||
mem_usage = info["memory_usage"]
|
|
||||||
proc_mem_usage = info["process_memory_usage"]
|
|
||||||
disk_usage = info["disk_usage_root"]
|
|
||||||
|
|
||||||
# 对进程内存使用百分比进行格式化,保留两位小数
|
|
||||||
proc_mem_percent = round(proc_mem_usage["percent"], 2)
|
|
||||||
|
|
||||||
output_string = f"""[系统信息]
|
|
||||||
- 操作系统: {sys_info["system"]} (例如: Windows, Linux)
|
|
||||||
- 发行版本: {sys_info["release"]} (例如: 11, Ubuntu 20.04)
|
|
||||||
- 详细版本: {sys_info["version"]}
|
|
||||||
- 硬件架构: {sys_info["machine"]} (例如: AMD64)
|
|
||||||
- 处理器信息: {sys_info["processor"]}
|
|
||||||
|
|
||||||
[Python 环境]
|
|
||||||
- Python 版本: {info["python_version"]}
|
|
||||||
|
|
||||||
[CPU 状态]
|
|
||||||
- 系统总 CPU 使用率: {info["cpu_usage_percent"]}%
|
|
||||||
- 当前进程 CPU 使用率: {info["process_cpu_usage_percent"]}%
|
|
||||||
|
|
||||||
[系统内存使用情况]
|
|
||||||
- 总物理内存: {mem_usage["total_mb"]} MB
|
|
||||||
- 可用物理内存: {mem_usage["available_mb"]} MB
|
|
||||||
- 物理内存使用率: {mem_usage["percent"]}%
|
|
||||||
- 已用物理内存: {mem_usage["used_mb"]} MB
|
|
||||||
- 空闲物理内存: {mem_usage["free_mb"]} MB
|
|
||||||
|
|
||||||
[当前进程内存使用情况]
|
|
||||||
- 实际使用物理内存 (RSS): {proc_mem_usage["rss_mb"]} MB
|
|
||||||
- 占用虚拟内存 (VMS): {proc_mem_usage["vms_mb"]} MB
|
|
||||||
- 进程内存使用率: {proc_mem_percent}%
|
|
||||||
|
|
||||||
[磁盘使用情况 (根目录)]
|
|
||||||
- 总空间: {disk_usage["total_gb"]} GB
|
|
||||||
- 已用空间: {disk_usage["used_gb"]} GB
|
|
||||||
- 可用空间: {disk_usage["free_gb"]} GB
|
|
||||||
- 磁盘使用率: {disk_usage["percent"]}%
|
|
||||||
"""
|
|
||||||
return output_string
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
print(f"System Info: {get_system_info()}")
|
|
||||||
print(f"Python Version: {get_python_version()}")
|
|
||||||
print(f"CPU Usage: {get_cpu_usage()}%")
|
|
||||||
# 第一次调用 process.cpu_percent() 会返回0.0或一个无意义的值,需要间隔一段时间再调用
|
|
||||||
# 或者在初始化Process对象后,先调用一次cpu_percent(interval=None),然后再调用cpu_percent(interval=1)
|
|
||||||
current_process = psutil.Process(os.getpid())
|
|
||||||
current_process.cpu_percent(interval=None) # 初始化
|
|
||||||
print(f"Process CPU Usage: {current_process.cpu_percent(interval=1)}%") # 实际获取
|
|
||||||
|
|
||||||
memory_usage_info = get_memory_usage()
|
|
||||||
print(
|
|
||||||
f"Memory Usage: Total={memory_usage_info['total_mb']}MB, Used={memory_usage_info['used_mb']}MB, Percent={memory_usage_info['percent']}%"
|
|
||||||
)
|
|
||||||
|
|
||||||
process_memory_info = get_process_memory_usage()
|
|
||||||
print(
|
|
||||||
f"Process Memory Usage: RSS={process_memory_info['rss_mb']}MB, VMS={process_memory_info['vms_mb']}MB, Percent={process_memory_info['percent']}%"
|
|
||||||
)
|
|
||||||
|
|
||||||
disk_usage_info = get_disk_usage("/")
|
|
||||||
print(
|
|
||||||
f"Disk Usage (Root): Total={disk_usage_info['total_gb']}GB, Used={disk_usage_info['used_gb']}GB, Percent={disk_usage_info['percent']}%"
|
|
||||||
)
|
|
||||||
|
|
||||||
print("\n--- All Basic Info (JSON) ---")
|
|
||||||
all_info = get_all_basic_info()
|
|
||||||
import json
|
|
||||||
|
|
||||||
print(json.dumps(all_info, indent=4, ensure_ascii=False))
|
|
||||||
|
|
||||||
print("\n--- All Basic Info (String with Explanations) ---")
|
|
||||||
info_string = get_all_basic_info_string()
|
|
||||||
print(info_string)
|
|
||||||
@@ -1,317 +0,0 @@
|
|||||||
from typing import List, Optional, Dict, Any
|
|
||||||
import strawberry
|
|
||||||
|
|
||||||
# from packaging.version import Version
|
|
||||||
import os
|
|
||||||
|
|
||||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
|
|
||||||
|
|
||||||
|
|
||||||
@strawberry.type
|
|
||||||
class APIBotConfig:
|
|
||||||
"""机器人配置类"""
|
|
||||||
|
|
||||||
INNER_VERSION: str # 配置文件内部版本号(toml为字符串)
|
|
||||||
MAI_VERSION: str # 硬编码的版本信息
|
|
||||||
|
|
||||||
# bot
|
|
||||||
BOT_QQ: Optional[int] # 机器人QQ号
|
|
||||||
BOT_NICKNAME: Optional[str] # 机器人昵称
|
|
||||||
BOT_ALIAS_NAMES: List[str] # 机器人别名列表
|
|
||||||
|
|
||||||
# group
|
|
||||||
talk_allowed_groups: List[int] # 允许回复消息的群号列表
|
|
||||||
talk_frequency_down_groups: List[int] # 降低回复频率的群号列表
|
|
||||||
ban_user_id: List[int] # 禁止回复和读取消息的QQ号列表
|
|
||||||
|
|
||||||
# personality
|
|
||||||
personality_core: str # 人格核心特点描述
|
|
||||||
personality_sides: List[str] # 人格细节描述列表
|
|
||||||
|
|
||||||
# identity
|
|
||||||
identity_detail: List[str] # 身份特点列表
|
|
||||||
age: int # 年龄(岁)
|
|
||||||
gender: str # 性别
|
|
||||||
appearance: str # 外貌特征描述
|
|
||||||
|
|
||||||
# platforms
|
|
||||||
platforms: Dict[str, str] # 平台信息
|
|
||||||
|
|
||||||
# chat
|
|
||||||
allow_focus_mode: bool # 是否允许专注聊天状态
|
|
||||||
base_normal_chat_num: int # 最多允许多少个群进行普通聊天
|
|
||||||
base_focused_chat_num: int # 最多允许多少个群进行专注聊天
|
|
||||||
observation_context_size: int # 观察到的最长上下文大小
|
|
||||||
message_buffer: bool # 是否启用消息缓冲
|
|
||||||
ban_words: List[str] # 禁止词列表
|
|
||||||
ban_msgs_regex: List[str] # 禁止消息的正则表达式列表
|
|
||||||
|
|
||||||
# normal_chat
|
|
||||||
model_reasoning_probability: float # 推理模型概率
|
|
||||||
model_normal_probability: float # 普通模型概率
|
|
||||||
emoji_chance: float # 表情符号出现概率
|
|
||||||
thinking_timeout: int # 思考超时时间
|
|
||||||
willing_mode: str # 意愿模式
|
|
||||||
response_interested_rate_amplifier: float # 回复兴趣率放大器
|
|
||||||
emoji_response_penalty: float # 表情回复惩罚
|
|
||||||
mentioned_bot_inevitable_reply: bool # 提及 bot 必然回复
|
|
||||||
at_bot_inevitable_reply: bool # @bot 必然回复
|
|
||||||
|
|
||||||
# focus_chat
|
|
||||||
reply_trigger_threshold: float # 回复触发阈值
|
|
||||||
default_decay_rate_per_second: float # 默认每秒衰减率
|
|
||||||
|
|
||||||
# compressed
|
|
||||||
compressed_length: int # 压缩长度
|
|
||||||
compress_length_limit: int # 压缩长度限制
|
|
||||||
|
|
||||||
# emoji
|
|
||||||
max_emoji_num: int # 最大表情符号数量
|
|
||||||
max_reach_deletion: bool # 达到最大数量时是否删除
|
|
||||||
check_interval: int # 检查表情包的时间间隔(分钟)
|
|
||||||
save_emoji: bool # 是否保存表情包
|
|
||||||
steal_emoji: bool # 是否偷取表情包
|
|
||||||
enable_check: bool # 是否启用表情包过滤
|
|
||||||
check_prompt: str # 表情包过滤要求
|
|
||||||
|
|
||||||
# memory
|
|
||||||
build_memory_interval: int # 记忆构建间隔
|
|
||||||
build_memory_distribution: List[float] # 记忆构建分布
|
|
||||||
build_memory_sample_num: int # 采样数量
|
|
||||||
build_memory_sample_length: int # 采样长度
|
|
||||||
memory_compress_rate: float # 记忆压缩率
|
|
||||||
forget_memory_interval: int # 记忆遗忘间隔
|
|
||||||
memory_forget_time: int # 记忆遗忘时间(小时)
|
|
||||||
memory_forget_percentage: float # 记忆遗忘比例
|
|
||||||
consolidate_memory_interval: int # 记忆整合间隔
|
|
||||||
consolidation_similarity_threshold: float # 相似度阈值
|
|
||||||
consolidation_check_percentage: float # 检查节点比例
|
|
||||||
memory_ban_words: List[str] # 记忆禁止词列表
|
|
||||||
|
|
||||||
# mood
|
|
||||||
mood_update_interval: float # 情绪更新间隔
|
|
||||||
mood_decay_rate: float # 情绪衰减率
|
|
||||||
mood_intensity_factor: float # 情绪强度因子
|
|
||||||
|
|
||||||
# keywords_reaction
|
|
||||||
keywords_reaction_enable: bool # 是否启用关键词反应
|
|
||||||
keywords_reaction_rules: List[Dict[str, Any]] # 关键词反应规则
|
|
||||||
|
|
||||||
# chinese_typo
|
|
||||||
chinese_typo_enable: bool # 是否启用中文错别字
|
|
||||||
chinese_typo_error_rate: float # 中文错别字错误率
|
|
||||||
chinese_typo_min_freq: int # 中文错别字最小频率
|
|
||||||
chinese_typo_tone_error_rate: float # 中文错别字声调错误率
|
|
||||||
chinese_typo_word_replace_rate: float # 中文错别字单词替换率
|
|
||||||
|
|
||||||
# response_splitter
|
|
||||||
enable_response_splitter: bool # 是否启用回复分割器
|
|
||||||
response_max_length: int # 回复最大长度
|
|
||||||
response_max_sentence_num: int # 回复最大句子数
|
|
||||||
enable_kaomoji_protection: bool # 是否启用颜文字保护
|
|
||||||
|
|
||||||
model_max_output_length: int # 模型最大输出长度
|
|
||||||
|
|
||||||
# remote
|
|
||||||
remote_enable: bool # 是否启用远程功能
|
|
||||||
|
|
||||||
# experimental
|
|
||||||
enable_friend_chat: bool # 是否启用好友聊天
|
|
||||||
talk_allowed_private: List[int] # 允许私聊的QQ号列表
|
|
||||||
pfc_chatting: bool # 是否启用PFC聊天
|
|
||||||
|
|
||||||
# 模型配置
|
|
||||||
llm_reasoning: Dict[str, Any] # 推理模型配置
|
|
||||||
llm_normal: Dict[str, Any] # 普通模型配置
|
|
||||||
llm_topic_judge: Dict[str, Any] # 主题判断模型配置
|
|
||||||
summary: Dict[str, Any] # 总结模型配置
|
|
||||||
vlm: Dict[str, Any] # VLM模型配置
|
|
||||||
llm_heartflow: Dict[str, Any] # 心流模型配置
|
|
||||||
llm_observation: Dict[str, Any] # 观察模型配置
|
|
||||||
llm_sub_heartflow: Dict[str, Any] # 子心流模型配置
|
|
||||||
llm_plan: Optional[Dict[str, Any]] # 计划模型配置
|
|
||||||
embedding: Dict[str, Any] # 嵌入模型配置
|
|
||||||
llm_PFC_action_planner: Optional[Dict[str, Any]] # PFC行动计划模型配置
|
|
||||||
llm_PFC_chat: Optional[Dict[str, Any]] # PFC聊天模型配置
|
|
||||||
llm_PFC_reply_checker: Optional[Dict[str, Any]] # PFC回复检查模型配置
|
|
||||||
llm_tool_use: Optional[Dict[str, Any]] # 工具使用模型配置
|
|
||||||
|
|
||||||
api_urls: Optional[Dict[str, str]] # API地址配置
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def validate_config(config: dict):
|
|
||||||
"""
|
|
||||||
校验传入的 toml 配置字典是否合法。
|
|
||||||
:param config: toml库load后的配置字典
|
|
||||||
:raises: ValueError, KeyError, TypeError
|
|
||||||
"""
|
|
||||||
# 检查主层级
|
|
||||||
required_sections = [
|
|
||||||
"inner",
|
|
||||||
"bot",
|
|
||||||
"groups",
|
|
||||||
"personality",
|
|
||||||
"identity",
|
|
||||||
"platforms",
|
|
||||||
"chat",
|
|
||||||
"normal_chat",
|
|
||||||
"focus_chat",
|
|
||||||
"emoji",
|
|
||||||
"memory",
|
|
||||||
"mood",
|
|
||||||
"keywords_reaction",
|
|
||||||
"chinese_typo",
|
|
||||||
"response_splitter",
|
|
||||||
"remote",
|
|
||||||
"experimental",
|
|
||||||
"model",
|
|
||||||
]
|
|
||||||
for section in required_sections:
|
|
||||||
if section not in config:
|
|
||||||
raise KeyError(f"缺少配置段: [{section}]")
|
|
||||||
|
|
||||||
# 检查部分关键字段
|
|
||||||
if "version" not in config["inner"]:
|
|
||||||
raise KeyError("缺少 inner.version 字段")
|
|
||||||
if not isinstance(config["inner"]["version"], str):
|
|
||||||
raise TypeError("inner.version 必须为字符串")
|
|
||||||
|
|
||||||
if "qq" not in config["bot"]:
|
|
||||||
raise KeyError("缺少 bot.qq 字段")
|
|
||||||
if not isinstance(config["bot"]["qq"], int):
|
|
||||||
raise TypeError("bot.qq 必须为整数")
|
|
||||||
|
|
||||||
if "personality_core" not in config["personality"]:
|
|
||||||
raise KeyError("缺少 personality.personality_core 字段")
|
|
||||||
if not isinstance(config["personality"]["personality_core"], str):
|
|
||||||
raise TypeError("personality.personality_core 必须为字符串")
|
|
||||||
|
|
||||||
if "identity_detail" not in config["identity"]:
|
|
||||||
raise KeyError("缺少 identity.identity_detail 字段")
|
|
||||||
if not isinstance(config["identity"]["identity_detail"], list):
|
|
||||||
raise TypeError("identity.identity_detail 必须为列表")
|
|
||||||
|
|
||||||
# 可继续添加更多字段的类型和值检查
|
|
||||||
# ...
|
|
||||||
|
|
||||||
# 检查模型配置
|
|
||||||
model_keys = [
|
|
||||||
"llm_reasoning",
|
|
||||||
"llm_normal",
|
|
||||||
"llm_topic_judge",
|
|
||||||
"summary",
|
|
||||||
"vlm",
|
|
||||||
"llm_heartflow",
|
|
||||||
"llm_observation",
|
|
||||||
"llm_sub_heartflow",
|
|
||||||
"embedding",
|
|
||||||
]
|
|
||||||
if "model" not in config:
|
|
||||||
raise KeyError("缺少 [model] 配置段")
|
|
||||||
for key in model_keys:
|
|
||||||
if key not in config["model"]:
|
|
||||||
raise KeyError(f"缺少 model.{key} 配置")
|
|
||||||
|
|
||||||
# 检查通过
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
@strawberry.type
|
|
||||||
class APIEnvConfig:
|
|
||||||
"""环境变量配置"""
|
|
||||||
|
|
||||||
HOST: str # 服务主机地址
|
|
||||||
PORT: int # 服务端口
|
|
||||||
|
|
||||||
PLUGINS: List[str] # 插件列表
|
|
||||||
|
|
||||||
MONGODB_HOST: str # MongoDB 主机地址
|
|
||||||
MONGODB_PORT: int # MongoDB 端口
|
|
||||||
DATABASE_NAME: str # 数据库名称
|
|
||||||
|
|
||||||
CHAT_ANY_WHERE_BASE_URL: str # ChatAnywhere 基础URL
|
|
||||||
SILICONFLOW_BASE_URL: str # SiliconFlow 基础URL
|
|
||||||
DEEP_SEEK_BASE_URL: str # DeepSeek 基础URL
|
|
||||||
|
|
||||||
DEEP_SEEK_KEY: Optional[str] # DeepSeek API Key
|
|
||||||
CHAT_ANY_WHERE_KEY: Optional[str] # ChatAnywhere API Key
|
|
||||||
SILICONFLOW_KEY: Optional[str] # SiliconFlow API Key
|
|
||||||
|
|
||||||
SIMPLE_OUTPUT: Optional[bool] # 是否简化输出
|
|
||||||
CONSOLE_LOG_LEVEL: Optional[str] # 控制台日志等级
|
|
||||||
FILE_LOG_LEVEL: Optional[str] # 文件日志等级
|
|
||||||
DEFAULT_CONSOLE_LOG_LEVEL: Optional[str] # 默认控制台日志等级
|
|
||||||
DEFAULT_FILE_LOG_LEVEL: Optional[str] # 默认文件日志等级
|
|
||||||
|
|
||||||
@strawberry.field
|
|
||||||
def get_env(self) -> str:
|
|
||||||
return "env"
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def validate_config(config: dict):
|
|
||||||
"""
|
|
||||||
校验环境变量配置字典是否合法。
|
|
||||||
:param config: 环境变量配置字典
|
|
||||||
:raises: KeyError, TypeError
|
|
||||||
"""
|
|
||||||
required_fields = [
|
|
||||||
"HOST",
|
|
||||||
"PORT",
|
|
||||||
"PLUGINS",
|
|
||||||
"MONGODB_HOST",
|
|
||||||
"MONGODB_PORT",
|
|
||||||
"DATABASE_NAME",
|
|
||||||
"CHAT_ANY_WHERE_BASE_URL",
|
|
||||||
"SILICONFLOW_BASE_URL",
|
|
||||||
"DEEP_SEEK_BASE_URL",
|
|
||||||
]
|
|
||||||
for field in required_fields:
|
|
||||||
if field not in config:
|
|
||||||
raise KeyError(f"缺少环境变量配置字段: {field}")
|
|
||||||
|
|
||||||
if not isinstance(config["HOST"], str):
|
|
||||||
raise TypeError("HOST 必须为字符串")
|
|
||||||
if not isinstance(config["PORT"], int):
|
|
||||||
raise TypeError("PORT 必须为整数")
|
|
||||||
if not isinstance(config["PLUGINS"], list):
|
|
||||||
raise TypeError("PLUGINS 必须为列表")
|
|
||||||
if not isinstance(config["MONGODB_HOST"], str):
|
|
||||||
raise TypeError("MONGODB_HOST 必须为字符串")
|
|
||||||
if not isinstance(config["MONGODB_PORT"], int):
|
|
||||||
raise TypeError("MONGODB_PORT 必须为整数")
|
|
||||||
if not isinstance(config["DATABASE_NAME"], str):
|
|
||||||
raise TypeError("DATABASE_NAME 必须为字符串")
|
|
||||||
if not isinstance(config["CHAT_ANY_WHERE_BASE_URL"], str):
|
|
||||||
raise TypeError("CHAT_ANY_WHERE_BASE_URL 必须为字符串")
|
|
||||||
if not isinstance(config["SILICONFLOW_BASE_URL"], str):
|
|
||||||
raise TypeError("SILICONFLOW_BASE_URL 必须为字符串")
|
|
||||||
if not isinstance(config["DEEP_SEEK_BASE_URL"], str):
|
|
||||||
raise TypeError("DEEP_SEEK_BASE_URL 必须为字符串")
|
|
||||||
|
|
||||||
# 可选字段类型检查
|
|
||||||
optional_str_fields = [
|
|
||||||
"DEEP_SEEK_KEY",
|
|
||||||
"CHAT_ANY_WHERE_KEY",
|
|
||||||
"SILICONFLOW_KEY",
|
|
||||||
"CONSOLE_LOG_LEVEL",
|
|
||||||
"FILE_LOG_LEVEL",
|
|
||||||
"DEFAULT_CONSOLE_LOG_LEVEL",
|
|
||||||
"DEFAULT_FILE_LOG_LEVEL",
|
|
||||||
]
|
|
||||||
for field in optional_str_fields:
|
|
||||||
if field in config and config[field] is not None and not isinstance(config[field], str):
|
|
||||||
raise TypeError(f"{field} 必须为字符串或None")
|
|
||||||
|
|
||||||
if (
|
|
||||||
"SIMPLE_OUTPUT" in config
|
|
||||||
and config["SIMPLE_OUTPUT"] is not None
|
|
||||||
and not isinstance(config["SIMPLE_OUTPUT"], bool)
|
|
||||||
):
|
|
||||||
raise TypeError("SIMPLE_OUTPUT 必须为布尔值或None")
|
|
||||||
|
|
||||||
# 检查通过
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
print("当前路径:")
|
|
||||||
print(ROOT_PATH)
|
|
||||||
@@ -1,22 +0,0 @@
|
|||||||
import strawberry
|
|
||||||
|
|
||||||
from fastapi import FastAPI
|
|
||||||
from strawberry.fastapi import GraphQLRouter
|
|
||||||
|
|
||||||
from src.common.server import get_global_server
|
|
||||||
|
|
||||||
|
|
||||||
@strawberry.type
|
|
||||||
class Query:
|
|
||||||
@strawberry.field
|
|
||||||
def hello(self) -> str:
|
|
||||||
return "Hello World"
|
|
||||||
|
|
||||||
|
|
||||||
schema = strawberry.Schema(Query)
|
|
||||||
|
|
||||||
graphql_app = GraphQLRouter(schema)
|
|
||||||
|
|
||||||
fast_api_app: FastAPI = get_global_server().get_app()
|
|
||||||
|
|
||||||
fast_api_app.include_router(graphql_app, prefix="/graphql")
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
pass
|
|
||||||
112
src/api/main.py
112
src/api/main.py
@@ -1,112 +0,0 @@
|
|||||||
from fastapi import APIRouter
|
|
||||||
from strawberry.fastapi import GraphQLRouter
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
|
|
||||||
# from src.chat.heart_flow.heartflow import heartflow
|
|
||||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
|
|
||||||
# from src.config.config import BotConfig
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.api.reload_config import reload_config as reload_config_func
|
|
||||||
from src.common.server import get_global_server
|
|
||||||
from src.api.apiforgui import (
|
|
||||||
get_all_subheartflow_ids,
|
|
||||||
forced_change_subheartflow_status,
|
|
||||||
get_subheartflow_cycle_info,
|
|
||||||
get_all_states,
|
|
||||||
)
|
|
||||||
from src.chat.heart_flow.sub_heartflow import ChatState
|
|
||||||
from src.api.basic_info_api import get_all_basic_info # 新增导入
|
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter()
|
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger("api")
|
|
||||||
|
|
||||||
logger.info("麦麦API服务器已启动")
|
|
||||||
graphql_router = GraphQLRouter(schema=None, path="/") # Replace `None` with your actual schema
|
|
||||||
|
|
||||||
router.include_router(graphql_router, prefix="/graphql", tags=["GraphQL"])
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/config/reload")
|
|
||||||
async def reload_config():
|
|
||||||
return await reload_config_func()
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/gui/subheartflow/get/all")
|
|
||||||
async def get_subheartflow_ids():
|
|
||||||
"""获取所有子心流的ID列表"""
|
|
||||||
return await get_all_subheartflow_ids()
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/gui/subheartflow/forced_change_status")
|
|
||||||
async def forced_change_subheartflow_status_api(subheartflow_id: str, status: ChatState): # noqa
|
|
||||||
"""强制改变子心流的状态"""
|
|
||||||
# 参数检查
|
|
||||||
if not isinstance(status, ChatState):
|
|
||||||
logger.warning(f"无效的状态参数: {status}")
|
|
||||||
return {"status": "failed", "reason": "invalid status"}
|
|
||||||
logger.info(f"尝试将子心流 {subheartflow_id} 状态更改为 {status.value}")
|
|
||||||
success = await forced_change_subheartflow_status(subheartflow_id, status)
|
|
||||||
if success:
|
|
||||||
logger.info(f"子心流 {subheartflow_id} 状态更改为 {status.value} 成功")
|
|
||||||
return {"status": "success"}
|
|
||||||
else:
|
|
||||||
logger.error(f"子心流 {subheartflow_id} 状态更改为 {status.value} 失败")
|
|
||||||
return {"status": "failed"}
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/stop")
|
|
||||||
async def force_stop_maibot():
|
|
||||||
"""强制停止MAI Bot"""
|
|
||||||
from bot import request_shutdown
|
|
||||||
|
|
||||||
success = await request_shutdown()
|
|
||||||
if success:
|
|
||||||
logger.info("MAI Bot已强制停止")
|
|
||||||
return {"status": "success"}
|
|
||||||
else:
|
|
||||||
logger.error("MAI Bot强制停止失败")
|
|
||||||
return {"status": "failed"}
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/gui/subheartflow/cycleinfo")
|
|
||||||
async def get_subheartflow_cycle_info_api(subheartflow_id: str, history_len: int):
|
|
||||||
"""获取子心流的循环信息"""
|
|
||||||
cycle_info = await get_subheartflow_cycle_info(subheartflow_id, history_len)
|
|
||||||
if cycle_info:
|
|
||||||
return {"status": "success", "data": cycle_info}
|
|
||||||
else:
|
|
||||||
logger.warning(f"子心流 {subheartflow_id} 循环信息未找到")
|
|
||||||
return {"status": "failed", "reason": "subheartflow not found"}
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/gui/get_all_states")
|
|
||||||
async def get_all_states_api():
|
|
||||||
"""获取所有状态"""
|
|
||||||
all_states = await get_all_states()
|
|
||||||
if all_states:
|
|
||||||
return {"status": "success", "data": all_states}
|
|
||||||
else:
|
|
||||||
logger.warning("获取所有状态失败")
|
|
||||||
return {"status": "failed", "reason": "failed to get all states"}
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/info")
|
|
||||||
async def get_system_basic_info():
|
|
||||||
"""获取系统基本信息"""
|
|
||||||
logger.info("请求系统基本信息")
|
|
||||||
try:
|
|
||||||
info = get_all_basic_info()
|
|
||||||
return {"status": "success", "data": info}
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取系统基本信息失败: {e}")
|
|
||||||
return {"status": "failed", "reason": str(e)}
|
|
||||||
|
|
||||||
|
|
||||||
def start_api_server():
|
|
||||||
"""启动API服务器"""
|
|
||||||
get_global_server().register_router(router, prefix="/api/v1")
|
|
||||||
# pass
|
|
||||||
@@ -1,24 +0,0 @@
|
|||||||
from fastapi import HTTPException
|
|
||||||
from rich.traceback import install
|
|
||||||
from src.config.config import get_config_dir, load_config
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
import os
|
|
||||||
|
|
||||||
install(extra_lines=3)
|
|
||||||
|
|
||||||
logger = get_logger("api")
|
|
||||||
|
|
||||||
|
|
||||||
async def reload_config():
|
|
||||||
try:
|
|
||||||
from src.config import config as config_module
|
|
||||||
|
|
||||||
logger.debug("正在重载配置文件...")
|
|
||||||
bot_config_path = os.path.join(get_config_dir(), "bot_config.toml")
|
|
||||||
config_module.global_config = load_config(config_path=bot_config_path)
|
|
||||||
logger.debug("配置文件重载成功")
|
|
||||||
return {"status": "reloaded"}
|
|
||||||
except FileNotFoundError as e:
|
|
||||||
raise HTTPException(status_code=404, detail=str(e)) from e
|
|
||||||
except Exception as e:
|
|
||||||
raise HTTPException(status_code=500, detail=f"重载配置时发生错误: {str(e)}") from e
|
|
||||||
@@ -5,20 +5,19 @@ import os
|
|||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Optional, Tuple, List, Any
|
|
||||||
from PIL import Image
|
|
||||||
import io
|
import io
|
||||||
import re
|
import re
|
||||||
|
import binascii
|
||||||
# from gradio_client import file
|
from typing import Optional, Tuple, List, Any
|
||||||
|
from PIL import Image
|
||||||
|
from rich.traceback import install
|
||||||
|
|
||||||
from src.common.database.database_model import Emoji
|
from src.common.database.database_model import Emoji
|
||||||
from src.common.database.database import db as peewee_db
|
from src.common.database.database import db as peewee_db
|
||||||
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.chat.utils.utils_image import image_path_to_base64, get_image_manager
|
from src.chat.utils.utils_image import image_path_to_base64, get_image_manager
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.common.logger import get_logger
|
|
||||||
from rich.traceback import install
|
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
|
|
||||||
@@ -26,7 +25,7 @@ logger = get_logger("emoji")
|
|||||||
|
|
||||||
BASE_DIR = os.path.join("data")
|
BASE_DIR = os.path.join("data")
|
||||||
EMOJI_DIR = os.path.join(BASE_DIR, "emoji") # 表情包存储目录
|
EMOJI_DIR = os.path.join(BASE_DIR, "emoji") # 表情包存储目录
|
||||||
EMOJI_REGISTED_DIR = os.path.join(BASE_DIR, "emoji_registed") # 已注册的表情包注册目录
|
EMOJI_REGISTERED_DIR = os.path.join(BASE_DIR, "emoji_registed") # 已注册的表情包注册目录
|
||||||
MAX_EMOJI_FOR_PROMPT = 20 # 最大允许的表情包描述数量于图片替换的 prompt 中
|
MAX_EMOJI_FOR_PROMPT = 20 # 最大允许的表情包描述数量于图片替换的 prompt 中
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@@ -85,7 +84,7 @@ class MaiEmoji:
|
|||||||
logger.debug(f"[初始化] 正在使用Pillow获取格式: {self.filename}")
|
logger.debug(f"[初始化] 正在使用Pillow获取格式: {self.filename}")
|
||||||
try:
|
try:
|
||||||
with Image.open(io.BytesIO(image_bytes)) as img:
|
with Image.open(io.BytesIO(image_bytes)) as img:
|
||||||
self.format = img.format.lower()
|
self.format = img.format.lower() # type: ignore
|
||||||
logger.debug(f"[初始化] 格式获取成功: {self.format}")
|
logger.debug(f"[初始化] 格式获取成功: {self.format}")
|
||||||
except Exception as pil_error:
|
except Exception as pil_error:
|
||||||
logger.error(f"[初始化错误] Pillow无法处理图片 ({self.filename}): {pil_error}")
|
logger.error(f"[初始化错误] Pillow无法处理图片 ({self.filename}): {pil_error}")
|
||||||
@@ -100,7 +99,7 @@ class MaiEmoji:
|
|||||||
logger.error(f"[初始化错误] 文件在处理过程中丢失: {self.full_path}")
|
logger.error(f"[初始化错误] 文件在处理过程中丢失: {self.full_path}")
|
||||||
self.is_deleted = True
|
self.is_deleted = True
|
||||||
return None
|
return None
|
||||||
except base64.binascii.Error as b64_error:
|
except (binascii.Error, ValueError) as b64_error:
|
||||||
logger.error(f"[初始化错误] Base64解码失败 ({self.filename}): {b64_error}")
|
logger.error(f"[初始化错误] Base64解码失败 ({self.filename}): {b64_error}")
|
||||||
self.is_deleted = True
|
self.is_deleted = True
|
||||||
return None
|
return None
|
||||||
@@ -113,7 +112,7 @@ class MaiEmoji:
|
|||||||
async def register_to_db(self) -> bool:
|
async def register_to_db(self) -> bool:
|
||||||
"""
|
"""
|
||||||
注册表情包
|
注册表情包
|
||||||
将表情包对应的文件,从当前路径移动到EMOJI_REGISTED_DIR目录下
|
将表情包对应的文件,从当前路径移动到EMOJI_REGISTERED_DIR目录下
|
||||||
并修改对应的实例属性,然后将表情包信息保存到数据库中
|
并修改对应的实例属性,然后将表情包信息保存到数据库中
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
@@ -122,7 +121,7 @@ class MaiEmoji:
|
|||||||
# 源路径是当前实例的完整路径 self.full_path
|
# 源路径是当前实例的完整路径 self.full_path
|
||||||
source_full_path = self.full_path
|
source_full_path = self.full_path
|
||||||
# 目标完整路径
|
# 目标完整路径
|
||||||
destination_full_path = os.path.join(EMOJI_REGISTED_DIR, self.filename)
|
destination_full_path = os.path.join(EMOJI_REGISTERED_DIR, self.filename)
|
||||||
|
|
||||||
# 检查源文件是否存在
|
# 检查源文件是否存在
|
||||||
if not os.path.exists(source_full_path):
|
if not os.path.exists(source_full_path):
|
||||||
@@ -139,7 +138,7 @@ class MaiEmoji:
|
|||||||
logger.debug(f"[移动] 文件从 {source_full_path} 移动到 {destination_full_path}")
|
logger.debug(f"[移动] 文件从 {source_full_path} 移动到 {destination_full_path}")
|
||||||
# 更新实例的路径属性为新路径
|
# 更新实例的路径属性为新路径
|
||||||
self.full_path = destination_full_path
|
self.full_path = destination_full_path
|
||||||
self.path = EMOJI_REGISTED_DIR
|
self.path = EMOJI_REGISTERED_DIR
|
||||||
# self.filename 保持不变
|
# self.filename 保持不变
|
||||||
except Exception as move_error:
|
except Exception as move_error:
|
||||||
logger.error(f"[错误] 移动文件失败: {str(move_error)}")
|
logger.error(f"[错误] 移动文件失败: {str(move_error)}")
|
||||||
@@ -202,7 +201,7 @@ class MaiEmoji:
|
|||||||
try:
|
try:
|
||||||
will_delete_emoji = Emoji.get(Emoji.emoji_hash == self.hash)
|
will_delete_emoji = Emoji.get(Emoji.emoji_hash == self.hash)
|
||||||
result = will_delete_emoji.delete_instance() # Returns the number of rows deleted.
|
result = will_delete_emoji.delete_instance() # Returns the number of rows deleted.
|
||||||
except Emoji.DoesNotExist:
|
except Emoji.DoesNotExist: # type: ignore
|
||||||
logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。")
|
logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。")
|
||||||
result = 0 # Indicate no DB record was deleted
|
result = 0 # Indicate no DB record was deleted
|
||||||
|
|
||||||
@@ -298,7 +297,7 @@ def _to_emoji_objects(data: Any) -> Tuple[List["MaiEmoji"], int]:
|
|||||||
def _ensure_emoji_dir() -> None:
|
def _ensure_emoji_dir() -> None:
|
||||||
"""确保表情存储目录存在"""
|
"""确保表情存储目录存在"""
|
||||||
os.makedirs(EMOJI_DIR, exist_ok=True)
|
os.makedirs(EMOJI_DIR, exist_ok=True)
|
||||||
os.makedirs(EMOJI_REGISTED_DIR, exist_ok=True)
|
os.makedirs(EMOJI_REGISTERED_DIR, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
async def clear_temp_emoji() -> None:
|
async def clear_temp_emoji() -> None:
|
||||||
@@ -331,10 +330,10 @@ async def clean_unused_emojis(emoji_dir: str, emoji_objects: List["MaiEmoji"], r
|
|||||||
logger.warning(f"[清理] 目标目录不存在,跳过清理: {emoji_dir}")
|
logger.warning(f"[清理] 目标目录不存在,跳过清理: {emoji_dir}")
|
||||||
return removed_count
|
return removed_count
|
||||||
|
|
||||||
|
cleaned_count = 0
|
||||||
try:
|
try:
|
||||||
# 获取内存中所有有效表情包的完整路径集合
|
# 获取内存中所有有效表情包的完整路径集合
|
||||||
tracked_full_paths = {emoji.full_path for emoji in emoji_objects if not emoji.is_deleted}
|
tracked_full_paths = {emoji.full_path for emoji in emoji_objects if not emoji.is_deleted}
|
||||||
cleaned_count = 0
|
|
||||||
|
|
||||||
# 遍历指定目录中的所有文件
|
# 遍历指定目录中的所有文件
|
||||||
for file_name in os.listdir(emoji_dir):
|
for file_name in os.listdir(emoji_dir):
|
||||||
@@ -358,11 +357,11 @@ async def clean_unused_emojis(emoji_dir: str, emoji_objects: List["MaiEmoji"], r
|
|||||||
else:
|
else:
|
||||||
logger.info(f"[清理] 目录 {emoji_dir} 中没有需要清理的。")
|
logger.info(f"[清理] 目录 {emoji_dir} 中没有需要清理的。")
|
||||||
|
|
||||||
return removed_count + cleaned_count
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[错误] 清理未使用表情包文件时出错 ({emoji_dir}): {str(e)}")
|
logger.error(f"[错误] 清理未使用表情包文件时出错 ({emoji_dir}): {str(e)}")
|
||||||
|
|
||||||
|
return removed_count + cleaned_count
|
||||||
|
|
||||||
|
|
||||||
class EmojiManager:
|
class EmojiManager:
|
||||||
_instance = None
|
_instance = None
|
||||||
@@ -414,7 +413,7 @@ class EmojiManager:
|
|||||||
emoji_update.usage_count += 1
|
emoji_update.usage_count += 1
|
||||||
emoji_update.last_used_time = time.time() # Update last used time
|
emoji_update.last_used_time = time.time() # Update last used time
|
||||||
emoji_update.save() # Persist changes to DB
|
emoji_update.save() # Persist changes to DB
|
||||||
except Emoji.DoesNotExist:
|
except Emoji.DoesNotExist: # type: ignore
|
||||||
logger.error(f"记录表情使用失败: 未找到 hash 为 {emoji_hash} 的表情包")
|
logger.error(f"记录表情使用失败: 未找到 hash 为 {emoji_hash} 的表情包")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"记录表情使用失败: {str(e)}")
|
logger.error(f"记录表情使用失败: {str(e)}")
|
||||||
@@ -570,8 +569,8 @@ class EmojiManager:
|
|||||||
if objects_to_remove:
|
if objects_to_remove:
|
||||||
self.emoji_objects = [e for e in self.emoji_objects if e not in objects_to_remove]
|
self.emoji_objects = [e for e in self.emoji_objects if e not in objects_to_remove]
|
||||||
|
|
||||||
# 清理 EMOJI_REGISTED_DIR 目录中未被追踪的文件
|
# 清理 EMOJI_REGISTERED_DIR 目录中未被追踪的文件
|
||||||
removed_count = await clean_unused_emojis(EMOJI_REGISTED_DIR, self.emoji_objects, removed_count)
|
removed_count = await clean_unused_emojis(EMOJI_REGISTERED_DIR, self.emoji_objects, removed_count)
|
||||||
|
|
||||||
# 输出清理结果
|
# 输出清理结果
|
||||||
if removed_count > 0:
|
if removed_count > 0:
|
||||||
@@ -850,11 +849,13 @@ class EmojiManager:
|
|||||||
if isinstance(image_base64, str):
|
if isinstance(image_base64, str):
|
||||||
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
||||||
image_bytes = base64.b64decode(image_base64)
|
image_bytes = base64.b64decode(image_base64)
|
||||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
|
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
|
||||||
|
|
||||||
# 调用AI获取描述
|
# 调用AI获取描述
|
||||||
if image_format == "gif" or image_format == "GIF":
|
if image_format == "gif" or image_format == "GIF":
|
||||||
image_base64 = get_image_manager().transform_gif(image_base64)
|
image_base64 = get_image_manager().transform_gif(image_base64) # type: ignore
|
||||||
|
if not image_base64:
|
||||||
|
raise RuntimeError("GIF表情包转换失败")
|
||||||
prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
|
prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
|
||||||
description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, "jpg")
|
description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, "jpg")
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -1,14 +1,16 @@
|
|||||||
from .exprssion_learner import get_expression_learner
|
|
||||||
import random
|
|
||||||
from typing import List, Dict, Tuple
|
|
||||||
from json_repair import repair_json
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
import random
|
||||||
|
|
||||||
|
from typing import List, Dict, Tuple, Optional
|
||||||
|
from json_repair import repair_json
|
||||||
|
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
|
from .exprssion_learner import get_expression_learner
|
||||||
|
|
||||||
logger = get_logger("expression_selector")
|
logger = get_logger("expression_selector")
|
||||||
|
|
||||||
@@ -165,7 +167,12 @@ class ExpressionSelector:
|
|||||||
logger.error(f"批量更新表达方式count失败 for {file_path}: {e}")
|
logger.error(f"批量更新表达方式count失败 for {file_path}: {e}")
|
||||||
|
|
||||||
async def select_suitable_expressions_llm(
|
async def select_suitable_expressions_llm(
|
||||||
self, chat_id: str, chat_info: str, max_num: int = 10, min_num: int = 5, target_message: str = None
|
self,
|
||||||
|
chat_id: str,
|
||||||
|
chat_info: str,
|
||||||
|
max_num: int = 10,
|
||||||
|
min_num: int = 5,
|
||||||
|
target_message: Optional[str] = None,
|
||||||
) -> List[Dict[str, str]]:
|
) -> List[Dict[str, str]]:
|
||||||
"""使用LLM选择适合的表达方式"""
|
"""使用LLM选择适合的表达方式"""
|
||||||
|
|
||||||
|
|||||||
@@ -1,14 +1,16 @@
|
|||||||
import time
|
import time
|
||||||
import random
|
import random
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
from typing import List, Dict, Optional, Any, Tuple
|
from typing import List, Dict, Optional, Any, Tuple
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_random, build_anonymous_messages
|
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_random, build_anonymous_messages
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
import os
|
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
import json
|
|
||||||
|
|
||||||
|
|
||||||
MAX_EXPRESSION_COUNT = 300
|
MAX_EXPRESSION_COUNT = 300
|
||||||
@@ -74,7 +76,8 @@ class ExpressionLearner:
|
|||||||
)
|
)
|
||||||
self.llm_model = None
|
self.llm_model = None
|
||||||
|
|
||||||
def get_expression_by_chat_id(self, chat_id: str) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]:
|
def get_expression_by_chat_id(self, chat_id: str) -> Tuple[List[Dict[str, float]], List[Dict[str, float]]]:
|
||||||
|
# sourcery skip: extract-duplicate-method, remove-unnecessary-cast
|
||||||
"""
|
"""
|
||||||
获取指定chat_id的style和grammar表达方式
|
获取指定chat_id的style和grammar表达方式
|
||||||
返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作
|
返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作
|
||||||
@@ -119,10 +122,10 @@ class ExpressionLearner:
|
|||||||
min_len = min(len(s1), len(s2))
|
min_len = min(len(s1), len(s2))
|
||||||
if min_len < 5:
|
if min_len < 5:
|
||||||
return False
|
return False
|
||||||
same = sum(1 for a, b in zip(s1, s2) if a == b)
|
same = sum(a == b for a, b in zip(s1, s2))
|
||||||
return same / min_len > 0.8
|
return same / min_len > 0.8
|
||||||
|
|
||||||
async def learn_and_store_expression(self) -> List[Tuple[str, str, str]]:
|
async def learn_and_store_expression(self) -> Tuple[List[Tuple[str, str, str]], List[Tuple[str, str, str]]]:
|
||||||
"""
|
"""
|
||||||
学习并存储表达方式,分别学习语言风格和句法特点
|
学习并存储表达方式,分别学习语言风格和句法特点
|
||||||
同时对所有已存储的表达方式进行全局衰减
|
同时对所有已存储的表达方式进行全局衰减
|
||||||
@@ -158,12 +161,12 @@ class ExpressionLearner:
|
|||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
learnt_style: Optional[List[Tuple[str, str, str]]] = await self.learn_and_store(type="style", num=25)
|
learnt_style: Optional[List[Tuple[str, str, str]]] = await self.learn_and_store(type="style", num=25)
|
||||||
if not learnt_style:
|
if not learnt_style:
|
||||||
return []
|
return [], []
|
||||||
|
|
||||||
for _ in range(1):
|
for _ in range(1):
|
||||||
learnt_grammar: Optional[List[Tuple[str, str, str]]] = await self.learn_and_store(type="grammar", num=10)
|
learnt_grammar: Optional[List[Tuple[str, str, str]]] = await self.learn_and_store(type="grammar", num=10)
|
||||||
if not learnt_grammar:
|
if not learnt_grammar:
|
||||||
return []
|
return [], []
|
||||||
|
|
||||||
return learnt_style, learnt_grammar
|
return learnt_style, learnt_grammar
|
||||||
|
|
||||||
@@ -214,6 +217,7 @@ class ExpressionLearner:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
async def learn_and_store(self, type: str, num: int = 10) -> List[Tuple[str, str, str]]:
|
async def learn_and_store(self, type: str, num: int = 10) -> List[Tuple[str, str, str]]:
|
||||||
|
# sourcery skip: use-join
|
||||||
"""
|
"""
|
||||||
选择从当前到最近1小时内的随机num条消息,然后学习这些消息的表达方式
|
选择从当前到最近1小时内的随机num条消息,然后学习这些消息的表达方式
|
||||||
type: "style" or "grammar"
|
type: "style" or "grammar"
|
||||||
@@ -249,7 +253,7 @@ class ExpressionLearner:
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
# 按chat_id分组
|
# 按chat_id分组
|
||||||
chat_dict: Dict[str, List[Dict[str, str]]] = {}
|
chat_dict: Dict[str, List[Dict[str, Any]]] = {}
|
||||||
for chat_id, situation, style in learnt_expressions:
|
for chat_id, situation, style in learnt_expressions:
|
||||||
if chat_id not in chat_dict:
|
if chat_id not in chat_dict:
|
||||||
chat_dict[chat_id] = []
|
chat_dict[chat_id] = []
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
# 定义了来自外部世界的信息
|
# 定义了来自外部世界的信息
|
||||||
# 外部世界可以是某个聊天 不同平台的聊天 也可以是任意媒体
|
# 外部世界可以是某个聊天 不同平台的聊天 也可以是任意媒体
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from typing import List
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.chat.focus_chat.hfc_utils import CycleDetail
|
from src.chat.focus_chat.hfc_utils import CycleDetail
|
||||||
from typing import List
|
|
||||||
# Import the new utility function
|
|
||||||
|
|
||||||
logger = get_logger("loop_info")
|
logger = get_logger("loop_info")
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from rich.traceback import install
|
|||||||
|
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||||
from src.chat.utils.prompt_builder import global_prompt_manager
|
from src.chat.utils.prompt_builder import global_prompt_manager
|
||||||
from src.chat.utils.timer_calculator import Timer
|
from src.chat.utils.timer_calculator import Timer
|
||||||
from src.chat.planner_actions.planner import ActionPlanner
|
from src.chat.planner_actions.planner import ActionPlanner
|
||||||
@@ -49,7 +49,9 @@ class HeartFChatting:
|
|||||||
"""
|
"""
|
||||||
# 基础属性
|
# 基础属性
|
||||||
self.stream_id: str = chat_id # 聊天流ID
|
self.stream_id: str = chat_id # 聊天流ID
|
||||||
self.chat_stream = get_chat_manager().get_stream(self.stream_id)
|
self.chat_stream: ChatStream = get_chat_manager().get_stream(self.stream_id) # type: ignore
|
||||||
|
if not self.chat_stream:
|
||||||
|
raise ValueError(f"无法找到聊天流: {self.stream_id}")
|
||||||
self.log_prefix = f"[{get_chat_manager().get_stream_name(self.stream_id) or self.stream_id}]"
|
self.log_prefix = f"[{get_chat_manager().get_stream_name(self.stream_id) or self.stream_id}]"
|
||||||
|
|
||||||
self.relationship_builder = relationship_builder_manager.get_or_create_builder(self.stream_id)
|
self.relationship_builder = relationship_builder_manager.get_or_create_builder(self.stream_id)
|
||||||
@@ -171,7 +173,7 @@ class HeartFChatting:
|
|||||||
# 执行规划和处理阶段
|
# 执行规划和处理阶段
|
||||||
try:
|
try:
|
||||||
async with self._get_cycle_context():
|
async with self._get_cycle_context():
|
||||||
thinking_id = "tid" + str(round(time.time(), 2))
|
thinking_id = f"tid{str(round(time.time(), 2))}"
|
||||||
self._current_cycle_detail.set_thinking_id(thinking_id)
|
self._current_cycle_detail.set_thinking_id(thinking_id)
|
||||||
|
|
||||||
# 使用异步上下文管理器处理消息
|
# 使用异步上下文管理器处理消息
|
||||||
@@ -245,7 +247,7 @@ class HeartFChatting:
|
|||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"{self.log_prefix} 第{self._current_cycle_detail.cycle_id}次思考,"
|
f"{self.log_prefix} 第{self._current_cycle_detail.cycle_id}次思考,"
|
||||||
f"耗时: {self._current_cycle_detail.end_time - self._current_cycle_detail.start_time:.1f}秒, "
|
f"耗时: {self._current_cycle_detail.end_time - self._current_cycle_detail.start_time:.1f}秒, " # type: ignore
|
||||||
f"选择动作: {self._current_cycle_detail.loop_plan_info.get('action_result', {}).get('action_type', '未知动作')}"
|
f"选择动作: {self._current_cycle_detail.loop_plan_info.get('action_result', {}).get('action_type', '未知动作')}"
|
||||||
+ (f"\n详情: {'; '.join(timer_strings)}" if timer_strings else "")
|
+ (f"\n详情: {'; '.join(timer_strings)}" if timer_strings else "")
|
||||||
)
|
)
|
||||||
@@ -256,7 +258,7 @@ class HeartFChatting:
|
|||||||
cycle_performance_data = {
|
cycle_performance_data = {
|
||||||
"cycle_id": self._current_cycle_detail.cycle_id,
|
"cycle_id": self._current_cycle_detail.cycle_id,
|
||||||
"action_type": action_result.get("action_type", "unknown"),
|
"action_type": action_result.get("action_type", "unknown"),
|
||||||
"total_time": self._current_cycle_detail.end_time - self._current_cycle_detail.start_time,
|
"total_time": self._current_cycle_detail.end_time - self._current_cycle_detail.start_time, # type: ignore
|
||||||
"step_times": cycle_timers.copy(),
|
"step_times": cycle_timers.copy(),
|
||||||
"reasoning": action_result.get("reasoning", ""),
|
"reasoning": action_result.get("reasoning", ""),
|
||||||
"success": self._current_cycle_detail.loop_action_info.get("action_taken", False),
|
"success": self._current_cycle_detail.loop_action_info.get("action_taken", False),
|
||||||
@@ -447,9 +449,6 @@ class HeartFChatting:
|
|||||||
|
|
||||||
# 处理动作并获取结果
|
# 处理动作并获取结果
|
||||||
result = await action_handler.handle_action()
|
result = await action_handler.handle_action()
|
||||||
if len(result) == 3:
|
|
||||||
success, reply_text, command = result
|
|
||||||
else:
|
|
||||||
success, reply_text = result
|
success, reply_text = result
|
||||||
command = ""
|
command = ""
|
||||||
|
|
||||||
@@ -478,8 +477,7 @@ class HeartFChatting:
|
|||||||
)
|
)
|
||||||
# 设置系统命令,在下次循环检查时触发退出
|
# 设置系统命令,在下次循环检查时触发退出
|
||||||
command = "stop_focus_chat"
|
command = "stop_focus_chat"
|
||||||
else:
|
elif reply_text == "timeout":
|
||||||
if reply_text == "timeout":
|
|
||||||
self.reply_timeout_count += 1
|
self.reply_timeout_count += 1
|
||||||
if self.reply_timeout_count > 5:
|
if self.reply_timeout_count > 5:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import json
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
logger = get_logger("hfc_performance")
|
logger = get_logger("hfc_performance")
|
||||||
|
|||||||
@@ -1,11 +1,12 @@
|
|||||||
import time
|
import time
|
||||||
from typing import Optional
|
import json
|
||||||
|
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
|
||||||
from src.chat.message_receive.message import MessageRecv, BaseMessageInfo
|
from src.chat.message_receive.message import MessageRecv, BaseMessageInfo
|
||||||
from src.chat.message_receive.chat_stream import ChatStream
|
from src.chat.message_receive.chat_stream import ChatStream
|
||||||
from src.chat.message_receive.message import UserInfo
|
from src.chat.message_receive.message import UserInfo
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
import json
|
|
||||||
from typing import Dict, Any
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@@ -117,7 +118,7 @@ async def create_empty_anchor_message(
|
|||||||
placeholder_msg_info = BaseMessageInfo(
|
placeholder_msg_info = BaseMessageInfo(
|
||||||
message_id=placeholder_id,
|
message_id=placeholder_id,
|
||||||
platform=platform,
|
platform=platform,
|
||||||
group_info=group_info,
|
group_info=group_info, # type: ignore
|
||||||
user_info=placeholder_user,
|
user_info=placeholder_user,
|
||||||
time=time.time(),
|
time=time.time(),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from src.chat.heart_flow.sub_heartflow import SubHeartflow, ChatState
|
from typing import Any, Optional, Dict
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from typing import Any, Optional
|
from src.chat.heart_flow.sub_heartflow import SubHeartflow, ChatState
|
||||||
from typing import Dict
|
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
|
|
||||||
logger = get_logger("heartflow")
|
logger = get_logger("heartflow")
|
||||||
@@ -34,7 +34,7 @@ class Heartflow:
|
|||||||
logger.error(f"创建子心流 {subheartflow_id} 失败: {e}", exc_info=True)
|
logger.error(f"创建子心流 {subheartflow_id} 失败: {e}", exc_info=True)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def force_change_subheartflow_status(self, subheartflow_id: str, status: ChatState) -> None:
|
async def force_change_subheartflow_status(self, subheartflow_id: str, status: ChatState) -> bool:
|
||||||
"""强制改变子心流的状态"""
|
"""强制改变子心流的状态"""
|
||||||
# 这里的 message 是可选的,可能是一个消息对象,也可能是其他类型的数据
|
# 这里的 message 是可选的,可能是一个消息对象,也可能是其他类型的数据
|
||||||
return await self.force_change_state(subheartflow_id, status)
|
return await self.force_change_state(subheartflow_id, status)
|
||||||
|
|||||||
@@ -1,21 +1,21 @@
|
|||||||
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
|
||||||
from src.config.config import global_config
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import re
|
||||||
|
import math
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
from src.config.config import global_config
|
||||||
|
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
||||||
from src.chat.message_receive.message import MessageRecv
|
from src.chat.message_receive.message import MessageRecv
|
||||||
from src.chat.message_receive.storage import MessageStorage
|
from src.chat.message_receive.storage import MessageStorage
|
||||||
from src.chat.heart_flow.heartflow import heartflow
|
from src.chat.heart_flow.heartflow import heartflow
|
||||||
from src.chat.utils.utils import is_mentioned_bot_in_message
|
from src.chat.utils.utils import is_mentioned_bot_in_message
|
||||||
from src.chat.utils.timer_calculator import Timer
|
from src.chat.utils.timer_calculator import Timer
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
import re
|
|
||||||
import math
|
|
||||||
import traceback
|
|
||||||
from typing import Tuple
|
|
||||||
|
|
||||||
from src.person_info.relationship_manager import get_relationship_manager
|
from src.person_info.relationship_manager import get_relationship_manager
|
||||||
from src.mood.mood_manager import mood_manager
|
from src.mood.mood_manager import mood_manager
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger("chat")
|
logger = get_logger("chat")
|
||||||
|
|
||||||
|
|
||||||
@@ -26,16 +26,16 @@ async def _process_relationship(message: MessageRecv) -> None:
|
|||||||
message: 消息对象,包含用户信息
|
message: 消息对象,包含用户信息
|
||||||
"""
|
"""
|
||||||
platform = message.message_info.platform
|
platform = message.message_info.platform
|
||||||
user_id = message.message_info.user_info.user_id
|
user_id = message.message_info.user_info.user_id # type: ignore
|
||||||
nickname = message.message_info.user_info.user_nickname
|
nickname = message.message_info.user_info.user_nickname # type: ignore
|
||||||
cardname = message.message_info.user_info.user_cardname or nickname
|
cardname = message.message_info.user_info.user_cardname or nickname # type: ignore
|
||||||
|
|
||||||
relationship_manager = get_relationship_manager()
|
relationship_manager = get_relationship_manager()
|
||||||
is_known = await relationship_manager.is_known_some_one(platform, user_id)
|
is_known = await relationship_manager.is_known_some_one(platform, user_id)
|
||||||
|
|
||||||
if not is_known:
|
if not is_known:
|
||||||
logger.info(f"首次认识用户: {nickname}")
|
logger.info(f"首次认识用户: {nickname}")
|
||||||
await relationship_manager.first_knowing_some_one(platform, user_id, nickname, cardname)
|
await relationship_manager.first_knowing_some_one(platform, user_id, nickname, cardname) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]:
|
async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]:
|
||||||
@@ -105,9 +105,9 @@ class HeartFCMessageReceiver:
|
|||||||
|
|
||||||
# 2. 兴趣度计算与更新
|
# 2. 兴趣度计算与更新
|
||||||
interested_rate, is_mentioned = await _calculate_interest(message)
|
interested_rate, is_mentioned = await _calculate_interest(message)
|
||||||
subheartflow.add_message_to_normal_chat_cache(message, interested_rate, is_mentioned)
|
subheartflow.add_message_to_normal_chat_cache(message, interested_rate, is_mentioned) # type: ignore
|
||||||
|
|
||||||
chat_mood = mood_manager.get_mood_by_chat_id(subheartflow.chat_id)
|
chat_mood = mood_manager.get_mood_by_chat_id(subheartflow.chat_id) # type: ignore
|
||||||
asyncio.create_task(chat_mood.update_mood_by_message(message, interested_rate))
|
asyncio.create_task(chat_mood.update_mood_by_message(message, interested_rate))
|
||||||
|
|
||||||
# 3. 日志记录
|
# 3. 日志记录
|
||||||
@@ -119,7 +119,7 @@ class HeartFCMessageReceiver:
|
|||||||
picid_pattern = r"\[picid:([^\]]+)\]"
|
picid_pattern = r"\[picid:([^\]]+)\]"
|
||||||
processed_plain_text = re.sub(picid_pattern, "[图片]", message.processed_plain_text)
|
processed_plain_text = re.sub(picid_pattern, "[图片]", message.processed_plain_text)
|
||||||
|
|
||||||
logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}")
|
logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}") # type: ignore
|
||||||
|
|
||||||
logger.debug(f"[{mes_name}][当前时段回复频率: {current_talk_frequency}]")
|
logger.debug(f"[{mes_name}][当前时段回复频率: {current_talk_frequency}]")
|
||||||
|
|
||||||
|
|||||||
@@ -1,16 +1,18 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
from typing import Optional, List, Dict, Tuple
|
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
|
from typing import Optional, List, Dict, Tuple
|
||||||
|
from rich.traceback import install
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
from src.config.config import global_config
|
||||||
from src.chat.message_receive.message import MessageRecv
|
from src.chat.message_receive.message import MessageRecv
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
from src.chat.focus_chat.heartFC_chat import HeartFChatting
|
from src.chat.focus_chat.heartFC_chat import HeartFChatting
|
||||||
from src.chat.normal_chat.normal_chat import NormalChat
|
from src.chat.normal_chat.normal_chat import NormalChat
|
||||||
from src.chat.heart_flow.chat_state_info import ChatState, ChatStateInfo
|
from src.chat.heart_flow.chat_state_info import ChatState, ChatStateInfo
|
||||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||||
from src.config.config import global_config
|
|
||||||
from rich.traceback import install
|
|
||||||
|
|
||||||
logger = get_logger("sub_heartflow")
|
logger = get_logger("sub_heartflow")
|
||||||
|
|
||||||
@@ -40,7 +42,7 @@ class SubHeartflow:
|
|||||||
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_id)
|
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_id)
|
||||||
self.log_prefix = get_chat_manager().get_stream_name(self.subheartflow_id) or self.subheartflow_id
|
self.log_prefix = get_chat_manager().get_stream_name(self.subheartflow_id) or self.subheartflow_id
|
||||||
# 兴趣消息集合
|
# 兴趣消息集合
|
||||||
self.interest_dict: Dict[str, tuple[MessageRecv, float, bool]] = {}
|
self.interest_dict: Dict[str, Tuple[MessageRecv, float, bool]] = {}
|
||||||
|
|
||||||
# focus模式退出冷却时间管理
|
# focus模式退出冷却时间管理
|
||||||
self.last_focus_exit_time: float = 0 # 上次退出focus模式的时间
|
self.last_focus_exit_time: float = 0 # 上次退出focus模式的时间
|
||||||
@@ -297,7 +299,7 @@ class SubHeartflow:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def add_message_to_normal_chat_cache(self, message: MessageRecv, interest_value: float, is_mentioned: bool):
|
def add_message_to_normal_chat_cache(self, message: MessageRecv, interest_value: float, is_mentioned: bool):
|
||||||
self.interest_dict[message.message_info.message_id] = (message, interest_value, is_mentioned)
|
self.interest_dict[message.message_info.message_id] = (message, interest_value, is_mentioned) # type: ignore
|
||||||
# 如果字典长度超过10,删除最旧的消息
|
# 如果字典长度超过10,删除最旧的消息
|
||||||
if len(self.interest_dict) > 30:
|
if len(self.interest_dict) > 30:
|
||||||
oldest_key = next(iter(self.interest_dict))
|
oldest_key = next(iter(self.interest_dict))
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ def calculate_information_content(text):
|
|||||||
return entropy
|
return entropy
|
||||||
|
|
||||||
|
|
||||||
def cosine_similarity(v1, v2):
|
def cosine_similarity(v1, v2): # sourcery skip: assign-if-exp, reintroduce-else
|
||||||
"""计算余弦相似度"""
|
"""计算余弦相似度"""
|
||||||
dot_product = np.dot(v1, v2)
|
dot_product = np.dot(v1, v2)
|
||||||
norm1 = np.linalg.norm(v1)
|
norm1 = np.linalg.norm(v1)
|
||||||
@@ -89,13 +89,12 @@ class MemoryGraph:
|
|||||||
if not isinstance(self.G.nodes[concept]["memory_items"], list):
|
if not isinstance(self.G.nodes[concept]["memory_items"], list):
|
||||||
self.G.nodes[concept]["memory_items"] = [self.G.nodes[concept]["memory_items"]]
|
self.G.nodes[concept]["memory_items"] = [self.G.nodes[concept]["memory_items"]]
|
||||||
self.G.nodes[concept]["memory_items"].append(memory)
|
self.G.nodes[concept]["memory_items"].append(memory)
|
||||||
# 更新最后修改时间
|
|
||||||
self.G.nodes[concept]["last_modified"] = current_time
|
|
||||||
else:
|
else:
|
||||||
self.G.nodes[concept]["memory_items"] = [memory]
|
self.G.nodes[concept]["memory_items"] = [memory]
|
||||||
# 如果节点存在但没有memory_items,说明是第一次添加memory,设置created_time
|
# 如果节点存在但没有memory_items,说明是第一次添加memory,设置created_time
|
||||||
if "created_time" not in self.G.nodes[concept]:
|
if "created_time" not in self.G.nodes[concept]:
|
||||||
self.G.nodes[concept]["created_time"] = current_time
|
self.G.nodes[concept]["created_time"] = current_time
|
||||||
|
# 更新最后修改时间
|
||||||
self.G.nodes[concept]["last_modified"] = current_time
|
self.G.nodes[concept]["last_modified"] = current_time
|
||||||
else:
|
else:
|
||||||
# 如果是新节点,创建新的记忆列表
|
# 如果是新节点,创建新的记忆列表
|
||||||
@@ -108,11 +107,7 @@ class MemoryGraph:
|
|||||||
|
|
||||||
def get_dot(self, concept):
|
def get_dot(self, concept):
|
||||||
# 检查节点是否存在于图中
|
# 检查节点是否存在于图中
|
||||||
if concept in self.G:
|
return (concept, self.G.nodes[concept]) if concept in self.G else None
|
||||||
# 从图中获取节点数据
|
|
||||||
node_data = self.G.nodes[concept]
|
|
||||||
return concept, node_data
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_related_item(self, topic, depth=1):
|
def get_related_item(self, topic, depth=1):
|
||||||
if topic not in self.G:
|
if topic not in self.G:
|
||||||
@@ -139,8 +134,7 @@ class MemoryGraph:
|
|||||||
if depth >= 2:
|
if depth >= 2:
|
||||||
# 获取相邻节点的记忆项
|
# 获取相邻节点的记忆项
|
||||||
for neighbor in neighbors:
|
for neighbor in neighbors:
|
||||||
node_data = self.get_dot(neighbor)
|
if node_data := self.get_dot(neighbor):
|
||||||
if node_data:
|
|
||||||
concept, data = node_data
|
concept, data = node_data
|
||||||
if "memory_items" in data:
|
if "memory_items" in data:
|
||||||
memory_items = data["memory_items"]
|
memory_items = data["memory_items"]
|
||||||
@@ -194,9 +188,9 @@ class MemoryGraph:
|
|||||||
class Hippocampus:
|
class Hippocampus:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.memory_graph = MemoryGraph()
|
self.memory_graph = MemoryGraph()
|
||||||
self.model_summary = None
|
self.model_summary: LLMRequest = None # type: ignore
|
||||||
self.entorhinal_cortex = None
|
self.entorhinal_cortex: EntorhinalCortex = None # type: ignore
|
||||||
self.parahippocampal_gyrus = None
|
self.parahippocampal_gyrus: ParahippocampalGyrus = None # type: ignore
|
||||||
|
|
||||||
def initialize(self):
|
def initialize(self):
|
||||||
# 初始化子组件
|
# 初始化子组件
|
||||||
@@ -218,7 +212,7 @@ class Hippocampus:
|
|||||||
memory_items = [memory_items] if memory_items else []
|
memory_items = [memory_items] if memory_items else []
|
||||||
|
|
||||||
# 使用集合来去重,避免排序
|
# 使用集合来去重,避免排序
|
||||||
unique_items = set(str(item) for item in memory_items)
|
unique_items = {str(item) for item in memory_items}
|
||||||
# 使用frozenset来保证顺序一致性
|
# 使用frozenset来保证顺序一致性
|
||||||
content = f"{concept}:{frozenset(unique_items)}"
|
content = f"{concept}:{frozenset(unique_items)}"
|
||||||
return hash(content)
|
return hash(content)
|
||||||
@@ -231,6 +225,7 @@ class Hippocampus:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def find_topic_llm(text, topic_num):
|
def find_topic_llm(text, topic_num):
|
||||||
|
# sourcery skip: inline-immediately-returned-variable
|
||||||
prompt = (
|
prompt = (
|
||||||
f"这是一段文字:\n{text}\n\n请你从这段话中总结出最多{topic_num}个关键的概念,可以是名词,动词,或者特定人物,帮我列出来,"
|
f"这是一段文字:\n{text}\n\n请你从这段话中总结出最多{topic_num}个关键的概念,可以是名词,动词,或者特定人物,帮我列出来,"
|
||||||
f"将主题用逗号隔开,并加上<>,例如<主题1>,<主题2>......尽可能精简。只需要列举最多{topic_num}个话题就好,不要有序号,不要告诉我其他内容。"
|
f"将主题用逗号隔开,并加上<>,例如<主题1>,<主题2>......尽可能精简。只需要列举最多{topic_num}个话题就好,不要有序号,不要告诉我其他内容。"
|
||||||
@@ -240,6 +235,7 @@ class Hippocampus:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def topic_what(text, topic):
|
def topic_what(text, topic):
|
||||||
|
# sourcery skip: inline-immediately-returned-variable
|
||||||
# 不再需要 time_info 参数
|
# 不再需要 time_info 参数
|
||||||
prompt = (
|
prompt = (
|
||||||
f'这是一段文字:\n{text}\n\n我想让你基于这段文字来概括"{topic}"这个概念,帮我总结成一句自然的话,'
|
f'这是一段文字:\n{text}\n\n我想让你基于这段文字来概括"{topic}"这个概念,帮我总结成一句自然的话,'
|
||||||
@@ -480,9 +476,7 @@ class Hippocampus:
|
|||||||
top_memories = memory_similarities[:max_memory_length]
|
top_memories = memory_similarities[:max_memory_length]
|
||||||
|
|
||||||
# 添加到结果中
|
# 添加到结果中
|
||||||
for memory, similarity in top_memories:
|
all_memories.extend((node, [memory], similarity) for memory, similarity in top_memories)
|
||||||
all_memories.append((node, [memory], similarity))
|
|
||||||
# logger.info(f"选中记忆: {memory} (相似度: {similarity:.2f})")
|
|
||||||
else:
|
else:
|
||||||
logger.info("节点没有记忆")
|
logger.info("节点没有记忆")
|
||||||
|
|
||||||
@@ -646,9 +640,7 @@ class Hippocampus:
|
|||||||
top_memories = memory_similarities[:max_memory_length]
|
top_memories = memory_similarities[:max_memory_length]
|
||||||
|
|
||||||
# 添加到结果中
|
# 添加到结果中
|
||||||
for memory, similarity in top_memories:
|
all_memories.extend((node, [memory], similarity) for memory, similarity in top_memories)
|
||||||
all_memories.append((node, [memory], similarity))
|
|
||||||
# logger.info(f"选中记忆: {memory} (相似度: {similarity:.2f})")
|
|
||||||
else:
|
else:
|
||||||
logger.info("节点没有记忆")
|
logger.info("节点没有记忆")
|
||||||
|
|
||||||
@@ -823,11 +815,11 @@ class EntorhinalCortex:
|
|||||||
logger.debug(f"回忆往事: {readable_timestamp}")
|
logger.debug(f"回忆往事: {readable_timestamp}")
|
||||||
chat_samples = []
|
chat_samples = []
|
||||||
for timestamp in timestamps:
|
for timestamp in timestamps:
|
||||||
# 调用修改后的 random_get_msg_snippet
|
if messages := self.random_get_msg_snippet(
|
||||||
messages = self.random_get_msg_snippet(
|
timestamp,
|
||||||
timestamp, global_config.memory.memory_build_sample_length, max_memorized_time_per_msg
|
global_config.memory.memory_build_sample_length,
|
||||||
)
|
max_memorized_time_per_msg,
|
||||||
if messages:
|
):
|
||||||
time_diff = (datetime.datetime.now().timestamp() - timestamp) / 3600
|
time_diff = (datetime.datetime.now().timestamp() - timestamp) / 3600
|
||||||
logger.info(f"成功抽取 {time_diff:.1f} 小时前的消息样本,共{len(messages)}条")
|
logger.info(f"成功抽取 {time_diff:.1f} 小时前的消息样本,共{len(messages)}条")
|
||||||
chat_samples.append(messages)
|
chat_samples.append(messages)
|
||||||
@@ -838,6 +830,7 @@ class EntorhinalCortex:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def random_get_msg_snippet(target_timestamp: float, chat_size: int, max_memorized_time_per_msg: int) -> list | None:
|
def random_get_msg_snippet(target_timestamp: float, chat_size: int, max_memorized_time_per_msg: int) -> list | None:
|
||||||
|
# sourcery skip: invert-any-all, use-any, use-named-expression, use-next
|
||||||
"""从数据库中随机获取指定时间戳附近的消息片段 (使用 chat_message_builder)"""
|
"""从数据库中随机获取指定时间戳附近的消息片段 (使用 chat_message_builder)"""
|
||||||
try_count = 0
|
try_count = 0
|
||||||
time_window_seconds = random.randint(300, 1800) # 随机时间窗口,5到30分钟
|
time_window_seconds = random.randint(300, 1800) # 随机时间窗口,5到30分钟
|
||||||
@@ -847,22 +840,21 @@ class EntorhinalCortex:
|
|||||||
timestamp_start = target_timestamp
|
timestamp_start = target_timestamp
|
||||||
timestamp_end = target_timestamp + time_window_seconds
|
timestamp_end = target_timestamp + time_window_seconds
|
||||||
|
|
||||||
chosen_message = get_raw_msg_by_timestamp(
|
if chosen_message := get_raw_msg_by_timestamp(
|
||||||
timestamp_start=timestamp_start, timestamp_end=timestamp_end, limit=1, limit_mode="earliest"
|
timestamp_start=timestamp_start,
|
||||||
)
|
timestamp_end=timestamp_end,
|
||||||
|
limit=1,
|
||||||
|
limit_mode="earliest",
|
||||||
|
):
|
||||||
|
chat_id: str = chosen_message[0].get("chat_id") # type: ignore
|
||||||
|
|
||||||
if chosen_message:
|
if messages := get_raw_msg_by_timestamp_with_chat(
|
||||||
chat_id = chosen_message[0].get("chat_id")
|
|
||||||
|
|
||||||
messages = get_raw_msg_by_timestamp_with_chat(
|
|
||||||
timestamp_start=timestamp_start,
|
timestamp_start=timestamp_start,
|
||||||
timestamp_end=timestamp_end,
|
timestamp_end=timestamp_end,
|
||||||
limit=chat_size,
|
limit=chat_size,
|
||||||
limit_mode="earliest",
|
limit_mode="earliest",
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
)
|
):
|
||||||
|
|
||||||
if messages:
|
|
||||||
# 检查获取到的所有消息是否都未达到最大记忆次数
|
# 检查获取到的所有消息是否都未达到最大记忆次数
|
||||||
all_valid = True
|
all_valid = True
|
||||||
for message in messages:
|
for message in messages:
|
||||||
@@ -975,7 +967,7 @@ class EntorhinalCortex:
|
|||||||
).execute()
|
).execute()
|
||||||
|
|
||||||
if nodes_to_delete:
|
if nodes_to_delete:
|
||||||
GraphNodes.delete().where(GraphNodes.concept.in_(nodes_to_delete)).execute()
|
GraphNodes.delete().where(GraphNodes.concept.in_(nodes_to_delete)).execute() # type: ignore
|
||||||
|
|
||||||
# 处理边的信息
|
# 处理边的信息
|
||||||
db_edges = list(GraphEdges.select())
|
db_edges = list(GraphEdges.select())
|
||||||
@@ -1114,7 +1106,7 @@ class EntorhinalCortex:
|
|||||||
node_start = time.time()
|
node_start = time.time()
|
||||||
if nodes_data:
|
if nodes_data:
|
||||||
batch_size = 500 # 增加批量大小
|
batch_size = 500 # 增加批量大小
|
||||||
with GraphNodes._meta.database.atomic():
|
with GraphNodes._meta.database.atomic(): # type: ignore
|
||||||
for i in range(0, len(nodes_data), batch_size):
|
for i in range(0, len(nodes_data), batch_size):
|
||||||
batch = nodes_data[i : i + batch_size]
|
batch = nodes_data[i : i + batch_size]
|
||||||
GraphNodes.insert_many(batch).execute()
|
GraphNodes.insert_many(batch).execute()
|
||||||
@@ -1125,7 +1117,7 @@ class EntorhinalCortex:
|
|||||||
edge_start = time.time()
|
edge_start = time.time()
|
||||||
if edges_data:
|
if edges_data:
|
||||||
batch_size = 500 # 增加批量大小
|
batch_size = 500 # 增加批量大小
|
||||||
with GraphEdges._meta.database.atomic():
|
with GraphEdges._meta.database.atomic(): # type: ignore
|
||||||
for i in range(0, len(edges_data), batch_size):
|
for i in range(0, len(edges_data), batch_size):
|
||||||
batch = edges_data[i : i + batch_size]
|
batch = edges_data[i : i + batch_size]
|
||||||
GraphEdges.insert_many(batch).execute()
|
GraphEdges.insert_many(batch).execute()
|
||||||
@@ -1489,9 +1481,7 @@ class ParahippocampalGyrus:
|
|||||||
# --- 如果节点不为空,则执行原来的不活跃检查和随机移除逻辑 ---
|
# --- 如果节点不为空,则执行原来的不活跃检查和随机移除逻辑 ---
|
||||||
last_modified = node_data.get("last_modified", current_time)
|
last_modified = node_data.get("last_modified", current_time)
|
||||||
# 条件1:检查是否长时间未修改 (超过24小时)
|
# 条件1:检查是否长时间未修改 (超过24小时)
|
||||||
if current_time - last_modified > 3600 * 24:
|
if current_time - last_modified > 3600 * 24 and memory_items:
|
||||||
# 条件2:再次确认节点包含记忆项(理论上已确认,但作为保险)
|
|
||||||
if memory_items:
|
|
||||||
current_count = len(memory_items)
|
current_count = len(memory_items)
|
||||||
# 如果列表非空,才进行随机选择
|
# 如果列表非空,才进行随机选择
|
||||||
if current_count > 0:
|
if current_count > 0:
|
||||||
@@ -1669,7 +1659,7 @@ class ParahippocampalGyrus:
|
|||||||
|
|
||||||
class HippocampusManager:
|
class HippocampusManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._hippocampus = None
|
self._hippocampus: Hippocampus = None # type: ignore
|
||||||
self._initialized = False
|
self._initialized = False
|
||||||
|
|
||||||
def initialize(self):
|
def initialize(self):
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from json_repair import repair_json
|
|||||||
logger = get_logger("memory_activator")
|
logger = get_logger("memory_activator")
|
||||||
|
|
||||||
|
|
||||||
def get_keywords_from_json(json_str):
|
def get_keywords_from_json(json_str) -> List:
|
||||||
"""
|
"""
|
||||||
从JSON字符串中提取关键词列表
|
从JSON字符串中提取关键词列表
|
||||||
|
|
||||||
@@ -28,15 +28,8 @@ def get_keywords_from_json(json_str):
|
|||||||
fixed_json = repair_json(json_str)
|
fixed_json = repair_json(json_str)
|
||||||
|
|
||||||
# 如果repair_json返回的是字符串,需要解析为Python对象
|
# 如果repair_json返回的是字符串,需要解析为Python对象
|
||||||
if isinstance(fixed_json, str):
|
result = json.loads(fixed_json) if isinstance(fixed_json, str) else fixed_json
|
||||||
result = json.loads(fixed_json)
|
return result.get("keywords", [])
|
||||||
else:
|
|
||||||
# 如果repair_json直接返回了字典对象,直接使用
|
|
||||||
result = fixed_json
|
|
||||||
|
|
||||||
# 提取关键词
|
|
||||||
keywords = result.get("keywords", [])
|
|
||||||
return keywords
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"解析关键词JSON失败: {e}")
|
logger.error(f"解析关键词JSON失败: {e}")
|
||||||
return []
|
return []
|
||||||
|
|||||||
@@ -1,52 +1,10 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from scipy import stats
|
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
|
|
||||||
|
|
||||||
class DistributionVisualizer:
|
|
||||||
def __init__(self, mean=0, std=1, skewness=0, sample_size=10):
|
|
||||||
"""
|
|
||||||
初始化分布可视化器
|
|
||||||
|
|
||||||
参数:
|
|
||||||
mean (float): 期望均值
|
|
||||||
std (float): 标准差
|
|
||||||
skewness (float): 偏度
|
|
||||||
sample_size (int): 样本大小
|
|
||||||
"""
|
|
||||||
self.mean = mean
|
|
||||||
self.std = std
|
|
||||||
self.skewness = skewness
|
|
||||||
self.sample_size = sample_size
|
|
||||||
self.samples = None
|
|
||||||
|
|
||||||
def generate_samples(self):
|
|
||||||
"""生成具有指定参数的样本"""
|
|
||||||
if self.skewness == 0:
|
|
||||||
# 对于无偏度的情况,直接使用正态分布
|
|
||||||
self.samples = np.random.normal(loc=self.mean, scale=self.std, size=self.sample_size)
|
|
||||||
else:
|
|
||||||
# 使用 scipy.stats 生成具有偏度的分布
|
|
||||||
self.samples = stats.skewnorm.rvs(a=self.skewness, loc=self.mean, scale=self.std, size=self.sample_size)
|
|
||||||
|
|
||||||
def get_weighted_samples(self):
|
|
||||||
"""获取加权后的样本数列"""
|
|
||||||
if self.samples is None:
|
|
||||||
self.generate_samples()
|
|
||||||
# 将样本值乘以样本大小
|
|
||||||
return self.samples * self.sample_size
|
|
||||||
|
|
||||||
def get_statistics(self):
|
|
||||||
"""获取分布的统计信息"""
|
|
||||||
if self.samples is None:
|
|
||||||
self.generate_samples()
|
|
||||||
|
|
||||||
return {"均值": np.mean(self.samples), "标准差": np.std(self.samples), "实际偏度": stats.skew(self.samples)}
|
|
||||||
|
|
||||||
|
|
||||||
class MemoryBuildScheduler:
|
class MemoryBuildScheduler:
|
||||||
def __init__(self, n_hours1, std_hours1, weight1, n_hours2, std_hours2, weight2, total_samples=50):
|
def __init__(self, n_hours1, std_hours1, weight1, n_hours2, std_hours2, weight2, total_samples=50):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,23 +1,25 @@
|
|||||||
import traceback
|
import traceback
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
|
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
|
from maim_message import UserInfo
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
from src.config.config import global_config
|
||||||
from src.mood.mood_manager import mood_manager # 导入情绪管理器
|
from src.mood.mood_manager import mood_manager # 导入情绪管理器
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream
|
||||||
from src.chat.message_receive.message import MessageRecv
|
from src.chat.message_receive.message import MessageRecv
|
||||||
from src.experimental.only_message_process import MessageProcessor
|
|
||||||
from src.chat.message_receive.storage import MessageStorage
|
from src.chat.message_receive.storage import MessageStorage
|
||||||
from src.experimental.PFC.pfc_manager import PFCManager
|
|
||||||
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
|
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from src.config.config import global_config
|
from src.experimental.only_message_process import MessageProcessor
|
||||||
|
from src.experimental.PFC.pfc_manager import PFCManager
|
||||||
from src.plugin_system.core.component_registry import component_registry # 导入新插件系统
|
from src.plugin_system.core.component_registry import component_registry # 导入新插件系统
|
||||||
from src.plugin_system.base.base_command import BaseCommand
|
from src.plugin_system.base.base_command import BaseCommand
|
||||||
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
|
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
|
||||||
from maim_message import UserInfo
|
|
||||||
from src.chat.message_receive.chat_stream import ChatStream
|
|
||||||
import re
|
|
||||||
# 定义日志配置
|
# 定义日志配置
|
||||||
|
|
||||||
# 获取项目根目录(假设本文件在src/chat/message_receive/下,根目录为上上上级目录)
|
# 获取项目根目录(假设本文件在src/chat/message_receive/下,根目录为上上上级目录)
|
||||||
@@ -184,8 +186,8 @@ class ChatBot:
|
|||||||
get_chat_manager().register_message(message)
|
get_chat_manager().register_message(message)
|
||||||
|
|
||||||
chat = await get_chat_manager().get_or_create_stream(
|
chat = await get_chat_manager().get_or_create_stream(
|
||||||
platform=message.message_info.platform,
|
platform=message.message_info.platform, # type: ignore
|
||||||
user_info=user_info,
|
user_info=user_info, # type: ignore
|
||||||
group_info=group_info,
|
group_info=group_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -195,8 +197,10 @@ class ChatBot:
|
|||||||
await message.process()
|
await message.process()
|
||||||
|
|
||||||
# 过滤检查
|
# 过滤检查
|
||||||
if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex(
|
if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore
|
||||||
message.raw_message, chat, user_info
|
message.raw_message, # type: ignore
|
||||||
|
chat,
|
||||||
|
user_info, # type: ignore
|
||||||
):
|
):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
@@ -3,18 +3,17 @@ import hashlib
|
|||||||
import time
|
import time
|
||||||
import copy
|
import copy
|
||||||
from typing import Dict, Optional, TYPE_CHECKING
|
from typing import Dict, Optional, TYPE_CHECKING
|
||||||
|
from rich.traceback import install
|
||||||
|
|
||||||
from ...common.database.database import db
|
|
||||||
from ...common.database.database_model import ChatStreams # 新增导入
|
|
||||||
from maim_message import GroupInfo, UserInfo
|
from maim_message import GroupInfo, UserInfo
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.common.database.database import db
|
||||||
|
from src.common.database.database_model import ChatStreams # 新增导入
|
||||||
|
|
||||||
# 避免循环导入,使用TYPE_CHECKING进行类型提示
|
# 避免循环导入,使用TYPE_CHECKING进行类型提示
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .message import MessageRecv
|
from .message import MessageRecv
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from rich.traceback import install
|
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
|
|
||||||
@@ -28,7 +27,7 @@ class ChatMessageContext:
|
|||||||
def __init__(self, message: "MessageRecv"):
|
def __init__(self, message: "MessageRecv"):
|
||||||
self.message = message
|
self.message = message
|
||||||
|
|
||||||
def get_template_name(self) -> str:
|
def get_template_name(self) -> Optional[str]:
|
||||||
"""获取模板名称"""
|
"""获取模板名称"""
|
||||||
if self.message.message_info.template_info and not self.message.message_info.template_info.template_default:
|
if self.message.message_info.template_info and not self.message.message_info.template_info.template_default:
|
||||||
return self.message.message_info.template_info.template_name
|
return self.message.message_info.template_info.template_name
|
||||||
@@ -41,10 +40,10 @@ class ChatMessageContext:
|
|||||||
def check_types(self, types: list) -> bool:
|
def check_types(self, types: list) -> bool:
|
||||||
# sourcery skip: invert-any-all, use-any, use-next
|
# sourcery skip: invert-any-all, use-any, use-next
|
||||||
"""检查消息类型"""
|
"""检查消息类型"""
|
||||||
if not self.message.message_info.format_info.accept_format:
|
if not self.message.message_info.format_info.accept_format: # type: ignore
|
||||||
return False
|
return False
|
||||||
for t in types:
|
for t in types:
|
||||||
if t not in self.message.message_info.format_info.accept_format:
|
if t not in self.message.message_info.format_info.accept_format: # type: ignore
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -68,7 +67,7 @@ class ChatStream:
|
|||||||
platform: str,
|
platform: str,
|
||||||
user_info: UserInfo,
|
user_info: UserInfo,
|
||||||
group_info: Optional[GroupInfo] = None,
|
group_info: Optional[GroupInfo] = None,
|
||||||
data: dict = None,
|
data: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
self.stream_id = stream_id
|
self.stream_id = stream_id
|
||||||
self.platform = platform
|
self.platform = platform
|
||||||
@@ -77,7 +76,7 @@ class ChatStream:
|
|||||||
self.create_time = data.get("create_time", time.time()) if data else time.time()
|
self.create_time = data.get("create_time", time.time()) if data else time.time()
|
||||||
self.last_active_time = data.get("last_active_time", self.create_time) if data else self.create_time
|
self.last_active_time = data.get("last_active_time", self.create_time) if data else self.create_time
|
||||||
self.saved = False
|
self.saved = False
|
||||||
self.context: ChatMessageContext = None # 用于存储该聊天的上下文信息
|
self.context: ChatMessageContext = None # type: ignore # 用于存储该聊天的上下文信息
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
def to_dict(self) -> dict:
|
||||||
"""转换为字典格式"""
|
"""转换为字典格式"""
|
||||||
@@ -99,7 +98,7 @@ class ChatStream:
|
|||||||
return cls(
|
return cls(
|
||||||
stream_id=data["stream_id"],
|
stream_id=data["stream_id"],
|
||||||
platform=data["platform"],
|
platform=data["platform"],
|
||||||
user_info=user_info,
|
user_info=user_info, # type: ignore
|
||||||
group_info=group_info,
|
group_info=group_info,
|
||||||
data=data,
|
data=data,
|
||||||
)
|
)
|
||||||
@@ -163,8 +162,8 @@ class ChatManager:
|
|||||||
def register_message(self, message: "MessageRecv"):
|
def register_message(self, message: "MessageRecv"):
|
||||||
"""注册消息到聊天流"""
|
"""注册消息到聊天流"""
|
||||||
stream_id = self._generate_stream_id(
|
stream_id = self._generate_stream_id(
|
||||||
message.message_info.platform,
|
message.message_info.platform, # type: ignore
|
||||||
message.message_info.user_info,
|
message.message_info.user_info, # type: ignore
|
||||||
message.message_info.group_info,
|
message.message_info.group_info,
|
||||||
)
|
)
|
||||||
self.last_messages[stream_id] = message
|
self.last_messages[stream_id] = message
|
||||||
@@ -185,10 +184,7 @@ class ChatManager:
|
|||||||
|
|
||||||
def get_stream_id(self, platform: str, id: str, is_group: bool = True) -> str:
|
def get_stream_id(self, platform: str, id: str, is_group: bool = True) -> str:
|
||||||
"""获取聊天流ID"""
|
"""获取聊天流ID"""
|
||||||
if is_group:
|
components = [platform, id] if is_group else [platform, id, "private"]
|
||||||
components = [platform, str(id)]
|
|
||||||
else:
|
|
||||||
components = [platform, str(id), "private"]
|
|
||||||
key = "_".join(components)
|
key = "_".join(components)
|
||||||
return hashlib.md5(key.encode()).hexdigest()
|
return hashlib.md5(key.encode()).hexdigest()
|
||||||
|
|
||||||
|
|||||||
@@ -1,17 +1,15 @@
|
|||||||
import time
|
import time
|
||||||
from abc import abstractmethod
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Optional, Any, TYPE_CHECKING
|
|
||||||
|
|
||||||
import urllib3
|
import urllib3
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from abc import abstractmethod
|
||||||
|
from dataclasses import dataclass
|
||||||
if TYPE_CHECKING:
|
|
||||||
from .chat_stream import ChatStream
|
|
||||||
from ..utils.utils_image import get_image_manager
|
|
||||||
from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
|
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
from typing import Optional, Any
|
||||||
|
from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.chat.utils.utils_image import get_image_manager
|
||||||
|
from .chat_stream import ChatStream
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
|
|
||||||
@@ -27,7 +25,7 @@ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Message(MessageBase):
|
class Message(MessageBase):
|
||||||
chat_stream: "ChatStream" = None
|
chat_stream: "ChatStream" = None # type: ignore
|
||||||
reply: Optional["Message"] = None
|
reply: Optional["Message"] = None
|
||||||
processed_plain_text: str = ""
|
processed_plain_text: str = ""
|
||||||
memorized_times: int = 0
|
memorized_times: int = 0
|
||||||
@@ -55,7 +53,7 @@ class Message(MessageBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 调用父类初始化
|
# 调用父类初始化
|
||||||
super().__init__(message_info=message_info, message_segment=message_segment, raw_message=None)
|
super().__init__(message_info=message_info, message_segment=message_segment, raw_message=None) # type: ignore
|
||||||
|
|
||||||
self.chat_stream = chat_stream
|
self.chat_stream = chat_stream
|
||||||
# 文本处理相关属性
|
# 文本处理相关属性
|
||||||
@@ -66,6 +64,7 @@ class Message(MessageBase):
|
|||||||
self.reply = reply
|
self.reply = reply
|
||||||
|
|
||||||
async def _process_message_segments(self, segment: Seg) -> str:
|
async def _process_message_segments(self, segment: Seg) -> str:
|
||||||
|
# sourcery skip: remove-unnecessary-else, swap-if-else-branches
|
||||||
"""递归处理消息段,转换为文字描述
|
"""递归处理消息段,转换为文字描述
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -78,13 +77,13 @@ class Message(MessageBase):
|
|||||||
# 处理消息段列表
|
# 处理消息段列表
|
||||||
segments_text = []
|
segments_text = []
|
||||||
for seg in segment.data:
|
for seg in segment.data:
|
||||||
processed = await self._process_message_segments(seg)
|
processed = await self._process_message_segments(seg) # type: ignore
|
||||||
if processed:
|
if processed:
|
||||||
segments_text.append(processed)
|
segments_text.append(processed)
|
||||||
return " ".join(segments_text)
|
return " ".join(segments_text)
|
||||||
else:
|
else:
|
||||||
# 处理单个消息段
|
# 处理单个消息段
|
||||||
return await self._process_single_segment(segment)
|
return await self._process_single_segment(segment) # type: ignore
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def _process_single_segment(self, segment):
|
async def _process_single_segment(self, segment):
|
||||||
@@ -138,7 +137,7 @@ class MessageRecv(Message):
|
|||||||
if segment.type == "text":
|
if segment.type == "text":
|
||||||
self.is_picid = False
|
self.is_picid = False
|
||||||
self.is_emoji = False
|
self.is_emoji = False
|
||||||
return segment.data
|
return segment.data # type: ignore
|
||||||
elif segment.type == "image":
|
elif segment.type == "image":
|
||||||
# 如果是base64图片数据
|
# 如果是base64图片数据
|
||||||
if isinstance(segment.data, str):
|
if isinstance(segment.data, str):
|
||||||
@@ -160,7 +159,7 @@ class MessageRecv(Message):
|
|||||||
elif segment.type == "mention_bot":
|
elif segment.type == "mention_bot":
|
||||||
self.is_picid = False
|
self.is_picid = False
|
||||||
self.is_emoji = False
|
self.is_emoji = False
|
||||||
self.is_mentioned = float(segment.data)
|
self.is_mentioned = float(segment.data) # type: ignore
|
||||||
return ""
|
return ""
|
||||||
elif segment.type == "priority_info":
|
elif segment.type == "priority_info":
|
||||||
self.is_picid = False
|
self.is_picid = False
|
||||||
@@ -186,7 +185,7 @@ class MessageRecv(Message):
|
|||||||
"""生成详细文本,包含时间和用户信息"""
|
"""生成详细文本,包含时间和用户信息"""
|
||||||
timestamp = self.message_info.time
|
timestamp = self.message_info.time
|
||||||
user_info = self.message_info.user_info
|
user_info = self.message_info.user_info
|
||||||
name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>"
|
name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>" # type: ignore
|
||||||
return f"[{timestamp}] {name}: {self.processed_plain_text}\n"
|
return f"[{timestamp}] {name}: {self.processed_plain_text}\n"
|
||||||
|
|
||||||
|
|
||||||
@@ -234,7 +233,7 @@ class MessageProcessBase(Message):
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
if seg.type == "text":
|
if seg.type == "text":
|
||||||
return seg.data
|
return seg.data # type: ignore
|
||||||
elif seg.type == "image":
|
elif seg.type == "image":
|
||||||
# 如果是base64图片数据
|
# 如果是base64图片数据
|
||||||
if isinstance(seg.data, str):
|
if isinstance(seg.data, str):
|
||||||
@@ -250,7 +249,7 @@ class MessageProcessBase(Message):
|
|||||||
if self.reply and hasattr(self.reply, "processed_plain_text"):
|
if self.reply and hasattr(self.reply, "processed_plain_text"):
|
||||||
# print(f"self.reply.processed_plain_text: {self.reply.processed_plain_text}")
|
# print(f"self.reply.processed_plain_text: {self.reply.processed_plain_text}")
|
||||||
# print(f"reply: {self.reply}")
|
# print(f"reply: {self.reply}")
|
||||||
return f"[回复<{self.reply.message_info.user_info.user_nickname}:{self.reply.message_info.user_info.user_id}> 的消息:{self.reply.processed_plain_text}]"
|
return f"[回复<{self.reply.message_info.user_info.user_nickname}:{self.reply.message_info.user_info.user_id}> 的消息:{self.reply.processed_plain_text}]" # type: ignore
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
return f"[{seg.type}:{str(seg.data)}]"
|
return f"[{seg.type}:{str(seg.data)}]"
|
||||||
@@ -264,7 +263,7 @@ class MessageProcessBase(Message):
|
|||||||
timestamp = self.message_info.time
|
timestamp = self.message_info.time
|
||||||
user_info = self.message_info.user_info
|
user_info = self.message_info.user_info
|
||||||
|
|
||||||
name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>"
|
name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>" # type: ignore
|
||||||
return f"[{timestamp}],{name} 说:{self.processed_plain_text}\n"
|
return f"[{timestamp}],{name} 说:{self.processed_plain_text}\n"
|
||||||
|
|
||||||
|
|
||||||
@@ -313,7 +312,7 @@ class MessageSending(MessageProcessBase):
|
|||||||
is_emoji: bool = False,
|
is_emoji: bool = False,
|
||||||
thinking_start_time: float = 0,
|
thinking_start_time: float = 0,
|
||||||
apply_set_reply_logic: bool = False,
|
apply_set_reply_logic: bool = False,
|
||||||
reply_to: str = None,
|
reply_to: str = None, # type: ignore
|
||||||
):
|
):
|
||||||
# 调用父类初始化
|
# 调用父类初始化
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@@ -344,7 +343,7 @@ class MessageSending(MessageProcessBase):
|
|||||||
self.message_segment = Seg(
|
self.message_segment = Seg(
|
||||||
type="seglist",
|
type="seglist",
|
||||||
data=[
|
data=[
|
||||||
Seg(type="reply", data=self.reply.message_info.message_id),
|
Seg(type="reply", data=self.reply.message_info.message_id), # type: ignore
|
||||||
self.message_segment,
|
self.message_segment,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@@ -364,10 +363,10 @@ class MessageSending(MessageProcessBase):
|
|||||||
) -> "MessageSending":
|
) -> "MessageSending":
|
||||||
"""从思考状态消息创建发送状态消息"""
|
"""从思考状态消息创建发送状态消息"""
|
||||||
return cls(
|
return cls(
|
||||||
message_id=thinking.message_info.message_id,
|
message_id=thinking.message_info.message_id, # type: ignore
|
||||||
chat_stream=thinking.chat_stream,
|
chat_stream=thinking.chat_stream,
|
||||||
message_segment=message_segment,
|
message_segment=message_segment,
|
||||||
bot_user_info=thinking.message_info.user_info,
|
bot_user_info=thinking.message_info.user_info, # type: ignore
|
||||||
reply=thinking.reply,
|
reply=thinking.reply,
|
||||||
is_head=is_head,
|
is_head=is_head,
|
||||||
is_emoji=is_emoji,
|
is_emoji=is_emoji,
|
||||||
@@ -399,13 +398,11 @@ class MessageSet:
|
|||||||
if not isinstance(message, MessageSending):
|
if not isinstance(message, MessageSending):
|
||||||
raise TypeError("MessageSet只能添加MessageSending类型的消息")
|
raise TypeError("MessageSet只能添加MessageSending类型的消息")
|
||||||
self.messages.append(message)
|
self.messages.append(message)
|
||||||
self.messages.sort(key=lambda x: x.message_info.time)
|
self.messages.sort(key=lambda x: x.message_info.time) # type: ignore
|
||||||
|
|
||||||
def get_message_by_index(self, index: int) -> Optional[MessageSending]:
|
def get_message_by_index(self, index: int) -> Optional[MessageSending]:
|
||||||
"""通过索引获取消息"""
|
"""通过索引获取消息"""
|
||||||
if 0 <= index < len(self.messages):
|
return self.messages[index] if 0 <= index < len(self.messages) else None
|
||||||
return self.messages[index]
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_message_by_time(self, target_time: float) -> Optional[MessageSending]:
|
def get_message_by_time(self, target_time: float) -> Optional[MessageSending]:
|
||||||
"""获取最接近指定时间的消息"""
|
"""获取最接近指定时间的消息"""
|
||||||
@@ -415,7 +412,7 @@ class MessageSet:
|
|||||||
left, right = 0, len(self.messages) - 1
|
left, right = 0, len(self.messages) - 1
|
||||||
while left < right:
|
while left < right:
|
||||||
mid = (left + right) // 2
|
mid = (left + right) // 2
|
||||||
if self.messages[mid].message_info.time < target_time:
|
if self.messages[mid].message_info.time < target_time: # type: ignore
|
||||||
left = mid + 1
|
left = mid + 1
|
||||||
else:
|
else:
|
||||||
right = mid
|
right = mid
|
||||||
|
|||||||
@@ -1,21 +1,16 @@
|
|||||||
# src/plugins/chat/message_sender.py
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
from asyncio import Task
|
from asyncio import Task
|
||||||
from typing import Union
|
from typing import Union
|
||||||
from src.common.message.api import get_global_api
|
|
||||||
|
|
||||||
# from ...common.database import db # 数据库依赖似乎不需要了,注释掉
|
|
||||||
from .message import MessageSending, MessageThinking, MessageSet
|
|
||||||
|
|
||||||
from src.chat.message_receive.storage import MessageStorage
|
|
||||||
from ..utils.utils import truncate_message, calculate_typing_time, count_messages_between
|
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
|
||||||
install(extra_lines=3)
|
from src.common.logger import get_logger
|
||||||
|
from src.common.message.api import get_global_api
|
||||||
|
from src.chat.message_receive.storage import MessageStorage
|
||||||
|
from src.chat.utils.utils import truncate_message, calculate_typing_time, count_messages_between
|
||||||
|
from .message import MessageSending, MessageThinking, MessageSet
|
||||||
|
|
||||||
|
install(extra_lines=3)
|
||||||
|
|
||||||
logger = get_logger("sender")
|
logger = get_logger("sender")
|
||||||
|
|
||||||
@@ -79,9 +74,10 @@ class MessageContainer:
|
|||||||
|
|
||||||
def count_thinking_messages(self) -> int:
|
def count_thinking_messages(self) -> int:
|
||||||
"""计算当前容器中思考消息的数量"""
|
"""计算当前容器中思考消息的数量"""
|
||||||
return sum(1 for msg in self.messages if isinstance(msg, MessageThinking))
|
return sum(isinstance(msg, MessageThinking) for msg in self.messages)
|
||||||
|
|
||||||
def get_timeout_sending_messages(self) -> list[MessageSending]:
|
def get_timeout_sending_messages(self) -> list[MessageSending]:
|
||||||
|
# sourcery skip: merge-nested-ifs
|
||||||
"""获取所有超时的MessageSending对象(思考时间超过20秒),按thinking_start_time排序 - 从旧 sender 合并"""
|
"""获取所有超时的MessageSending对象(思考时间超过20秒),按thinking_start_time排序 - 从旧 sender 合并"""
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
timeout_messages = []
|
timeout_messages = []
|
||||||
@@ -230,9 +226,7 @@ class MessageManager:
|
|||||||
f"[{message.chat_stream.stream_id}] 处理发送消息 {getattr(message.message_info, 'message_id', 'N/A')} 时出错: {e}"
|
f"[{message.chat_stream.stream_id}] 处理发送消息 {getattr(message.message_info, 'message_id', 'N/A')} 时出错: {e}"
|
||||||
)
|
)
|
||||||
logger.exception("详细错误信息:")
|
logger.exception("详细错误信息:")
|
||||||
# 考虑是否移除出错的消息,防止无限循环
|
if container.remove_message(message):
|
||||||
removed = container.remove_message(message)
|
|
||||||
if removed:
|
|
||||||
logger.warning(f"[{message.chat_stream.stream_id}] 已移除处理出错的消息。")
|
logger.warning(f"[{message.chat_stream.stream_id}] 已移除处理出错的消息。")
|
||||||
|
|
||||||
async def _process_chat_messages(self, chat_id: str):
|
async def _process_chat_messages(self, chat_id: str):
|
||||||
@@ -261,10 +255,7 @@ class MessageManager:
|
|||||||
# --- 处理发送消息 ---
|
# --- 处理发送消息 ---
|
||||||
await self._handle_sending_message(container, message_earliest)
|
await self._handle_sending_message(container, message_earliest)
|
||||||
|
|
||||||
# --- 处理超时发送消息 (来自旧 sender) ---
|
if timeout_sending_messages := container.get_timeout_sending_messages():
|
||||||
# 在处理完最早的消息后,检查是否有超时的发送消息
|
|
||||||
timeout_sending_messages = container.get_timeout_sending_messages()
|
|
||||||
if timeout_sending_messages:
|
|
||||||
logger.debug(f"[{chat_id}] 发现 {len(timeout_sending_messages)} 条超时的发送消息")
|
logger.debug(f"[{chat_id}] 发现 {len(timeout_sending_messages)} 条超时的发送消息")
|
||||||
for msg in timeout_sending_messages:
|
for msg in timeout_sending_messages:
|
||||||
# 确保不是刚刚处理过的最早消息 (虽然理论上应该已被移除,但以防万一)
|
# 确保不是刚刚处理过的最早消息 (虽然理论上应该已被移除,但以防万一)
|
||||||
@@ -274,6 +265,7 @@ class MessageManager:
|
|||||||
await self._handle_sending_message(container, msg) # 复用处理逻辑
|
await self._handle_sending_message(container, msg) # 复用处理逻辑
|
||||||
|
|
||||||
async def _start_processor_loop(self):
|
async def _start_processor_loop(self):
|
||||||
|
# sourcery skip: list-comprehension, move-assign-in-block, use-named-expression
|
||||||
"""消息处理器主循环"""
|
"""消息处理器主循环"""
|
||||||
while self._running:
|
while self._running:
|
||||||
tasks = []
|
tasks = []
|
||||||
@@ -282,10 +274,7 @@ class MessageManager:
|
|||||||
# 创建 keys 的快照以安全迭代
|
# 创建 keys 的快照以安全迭代
|
||||||
chat_ids = list(self.containers.keys())
|
chat_ids = list(self.containers.keys())
|
||||||
|
|
||||||
for chat_id in chat_ids:
|
tasks.extend(asyncio.create_task(self._process_chat_messages(chat_id)) for chat_id in chat_ids)
|
||||||
# 为每个 chat_id 创建一个处理任务
|
|
||||||
tasks.append(asyncio.create_task(self._process_chat_messages(chat_id)))
|
|
||||||
|
|
||||||
if tasks:
|
if tasks:
|
||||||
try:
|
try:
|
||||||
# 等待当前批次的所有任务完成
|
# 等待当前批次的所有任务完成
|
||||||
|
|||||||
@@ -1,11 +1,10 @@
|
|||||||
import re
|
import re
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
# from ...common.database.database import db # db is now Peewee's SqliteDatabase instance
|
from src.common.database.database_model import Messages, RecalledMessages, Images
|
||||||
from .message import MessageSending, MessageRecv
|
|
||||||
from .chat_stream import ChatStream
|
|
||||||
from ...common.database.database_model import Messages, RecalledMessages, Images # Import Peewee models
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
from .chat_stream import ChatStream
|
||||||
|
from .message import MessageSending, MessageRecv
|
||||||
|
|
||||||
logger = get_logger("message_storage")
|
logger = get_logger("message_storage")
|
||||||
|
|
||||||
@@ -44,7 +43,7 @@ class MessageStorage:
|
|||||||
reply_to = ""
|
reply_to = ""
|
||||||
|
|
||||||
chat_info_dict = chat_stream.to_dict()
|
chat_info_dict = chat_stream.to_dict()
|
||||||
user_info_dict = message.message_info.user_info.to_dict()
|
user_info_dict = message.message_info.user_info.to_dict() # type: ignore
|
||||||
|
|
||||||
# message_id 现在是 TextField,直接使用字符串值
|
# message_id 现在是 TextField,直接使用字符串值
|
||||||
msg_id = message.message_info.message_id
|
msg_id = message.message_info.message_id
|
||||||
@@ -56,7 +55,7 @@ class MessageStorage:
|
|||||||
|
|
||||||
Messages.create(
|
Messages.create(
|
||||||
message_id=msg_id,
|
message_id=msg_id,
|
||||||
time=float(message.message_info.time),
|
time=float(message.message_info.time), # type: ignore
|
||||||
chat_id=chat_stream.stream_id,
|
chat_id=chat_stream.stream_id,
|
||||||
# Flattened chat_info
|
# Flattened chat_info
|
||||||
reply_to=reply_to,
|
reply_to=reply_to,
|
||||||
@@ -103,7 +102,7 @@ class MessageStorage:
|
|||||||
try:
|
try:
|
||||||
# Assuming input 'time' is a string timestamp that can be converted to float
|
# Assuming input 'time' is a string timestamp that can be converted to float
|
||||||
current_time_float = float(time)
|
current_time_float = float(time)
|
||||||
RecalledMessages.delete().where(RecalledMessages.time < (current_time_float - 300)).execute()
|
RecalledMessages.delete().where(RecalledMessages.time < (current_time_float - 300)).execute() # type: ignore
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("删除撤回消息失败")
|
logger.exception("删除撤回消息失败")
|
||||||
|
|
||||||
@@ -115,22 +114,19 @@ class MessageStorage:
|
|||||||
"""更新最新一条匹配消息的message_id"""
|
"""更新最新一条匹配消息的message_id"""
|
||||||
try:
|
try:
|
||||||
if message.message_segment.type == "notify":
|
if message.message_segment.type == "notify":
|
||||||
mmc_message_id = message.message_segment.data.get("echo")
|
mmc_message_id = message.message_segment.data.get("echo") # type: ignore
|
||||||
qq_message_id = message.message_segment.data.get("actual_id")
|
qq_message_id = message.message_segment.data.get("actual_id") # type: ignore
|
||||||
else:
|
else:
|
||||||
logger.info(f"更新消息ID错误,seg类型为{message.message_segment.type}")
|
logger.info(f"更新消息ID错误,seg类型为{message.message_segment.type}")
|
||||||
return
|
return
|
||||||
if not qq_message_id:
|
if not qq_message_id:
|
||||||
logger.info("消息不存在message_id,无法更新")
|
logger.info("消息不存在message_id,无法更新")
|
||||||
return
|
return
|
||||||
# 查询最新一条匹配消息
|
if matched_message := (
|
||||||
matched_message = (
|
|
||||||
Messages.select().where((Messages.message_id == mmc_message_id)).order_by(Messages.time.desc()).first()
|
Messages.select().where((Messages.message_id == mmc_message_id)).order_by(Messages.time.desc()).first()
|
||||||
)
|
):
|
||||||
|
|
||||||
if matched_message:
|
|
||||||
# 更新找到的消息记录
|
# 更新找到的消息记录
|
||||||
Messages.update(message_id=qq_message_id).where(Messages.id == matched_message.id).execute()
|
Messages.update(message_id=qq_message_id).where(Messages.id == matched_message.id).execute() # type: ignore
|
||||||
logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}")
|
logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}")
|
||||||
else:
|
else:
|
||||||
logger.debug("未找到匹配的消息")
|
logger.debug("未找到匹配的消息")
|
||||||
@@ -155,10 +151,7 @@ class MessageStorage:
|
|||||||
image_record = (
|
image_record = (
|
||||||
Images.select().where(Images.description == description).order_by(Images.timestamp.desc()).first()
|
Images.select().where(Images.description == description).order_by(Images.timestamp.desc()).first()
|
||||||
)
|
)
|
||||||
if image_record:
|
return f"[picid:{image_record.image_id}]" if image_record else match.group(0)
|
||||||
return f"[picid:{image_record.image_id}]"
|
|
||||||
else:
|
|
||||||
return match.group(0) # 保持原样
|
|
||||||
except Exception:
|
except Exception:
|
||||||
return match.group(0)
|
return match.group(0)
|
||||||
|
|
||||||
|
|||||||
@@ -1,16 +1,17 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from typing import Dict, Optional # 重新导入类型
|
|
||||||
from src.chat.message_receive.message import MessageSending, MessageThinking
|
|
||||||
from src.common.message.api import get_global_api
|
|
||||||
from src.chat.message_receive.storage import MessageStorage
|
|
||||||
from src.chat.utils.utils import truncate_message
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.chat.utils.utils import calculate_typing_time
|
|
||||||
from rich.traceback import install
|
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
install(extra_lines=3)
|
from typing import Dict, Optional
|
||||||
|
from rich.traceback import install
|
||||||
|
|
||||||
|
from src.common.message.api import get_global_api
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.chat.message_receive.message import MessageSending, MessageThinking
|
||||||
|
from src.chat.message_receive.storage import MessageStorage
|
||||||
|
from src.chat.utils.utils import truncate_message
|
||||||
|
from src.chat.utils.utils import calculate_typing_time
|
||||||
|
|
||||||
|
install(extra_lines=3)
|
||||||
|
|
||||||
logger = get_logger("sender")
|
logger = get_logger("sender")
|
||||||
|
|
||||||
@@ -86,10 +87,10 @@ class HeartFCSender:
|
|||||||
"""
|
"""
|
||||||
if not message.chat_stream:
|
if not message.chat_stream:
|
||||||
logger.error("消息缺少 chat_stream,无法发送")
|
logger.error("消息缺少 chat_stream,无法发送")
|
||||||
raise Exception("消息缺少 chat_stream,无法发送")
|
raise ValueError("消息缺少 chat_stream,无法发送")
|
||||||
if not message.message_info or not message.message_info.message_id:
|
if not message.message_info or not message.message_info.message_id:
|
||||||
logger.error("消息缺少 message_info 或 message_id,无法发送")
|
logger.error("消息缺少 message_info 或 message_id,无法发送")
|
||||||
raise Exception("消息缺少 message_info 或 message_id,无法发送")
|
raise ValueError("消息缺少 message_info 或 message_id,无法发送")
|
||||||
|
|
||||||
chat_id = message.chat_stream.stream_id
|
chat_id = message.chat_stream.stream_id
|
||||||
message_id = message.message_info.message_id
|
message_id = message.message_info.message_id
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
from random import random
|
from random import random
|
||||||
from typing import List, Optional, Dict
|
from typing import List, Optional, Dict
|
||||||
from maim_message import UserInfo, Seg
|
from maim_message import UserInfo, Seg
|
||||||
@@ -40,7 +41,7 @@ class NormalChat:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
chat_stream: ChatStream,
|
chat_stream: ChatStream,
|
||||||
interest_dict: dict = None,
|
interest_dict: Optional[Dict] = None,
|
||||||
on_switch_to_focus_callback=None,
|
on_switch_to_focus_callback=None,
|
||||||
get_cooldown_progress_callback=None,
|
get_cooldown_progress_callback=None,
|
||||||
):
|
):
|
||||||
@@ -147,10 +148,7 @@ class NormalChat:
|
|||||||
while not self._disabled:
|
while not self._disabled:
|
||||||
try:
|
try:
|
||||||
if not self.priority_manager.is_empty():
|
if not self.priority_manager.is_empty():
|
||||||
# 获取最高优先级的消息
|
if message := self.priority_manager.get_highest_priority_message():
|
||||||
message = self.priority_manager.get_highest_priority_message()
|
|
||||||
|
|
||||||
if message:
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[{self.stream_name}] 从队列中取出消息进行处理: User {message.message_info.user_info.user_id}, Time: {time.strftime('%H:%M:%S', time.localtime(message.message_info.time))}"
|
f"[{self.stream_name}] 从队列中取出消息进行处理: User {message.message_info.user_info.user_id}, Time: {time.strftime('%H:%M:%S', time.localtime(message.message_info.time))}"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ class PriorityManager:
|
|||||||
"""
|
"""
|
||||||
添加新消息到合适的队列中。
|
添加新消息到合适的队列中。
|
||||||
"""
|
"""
|
||||||
user_id = message.message_info.user_info.user_id
|
user_id = message.message_info.user_info.user_id # type: ignore
|
||||||
is_vip = message.priority_info.get("message_type") == "vip" if message.priority_info else False
|
is_vip = message.priority_info.get("message_type") == "vip" if message.priority_info else False
|
||||||
message_priority = message.priority_info.get("message_priority", 0.0) if message.priority_info else 0.0
|
message_priority = message.priority_info.get("message_priority", 0.0) if message.priority_info else 0.0
|
||||||
|
|
||||||
|
|||||||
@@ -35,9 +35,7 @@ class ClassicalWillingManager(BaseWillingManager):
|
|||||||
|
|
||||||
self.chat_reply_willing[chat_id] = min(current_willing, 3.0)
|
self.chat_reply_willing[chat_id] = min(current_willing, 3.0)
|
||||||
|
|
||||||
reply_probability = min(max((current_willing - 0.5), 0.01) * 2, 1)
|
return min(max((current_willing - 0.5), 0.01) * 2, 1)
|
||||||
|
|
||||||
return reply_probability
|
|
||||||
|
|
||||||
async def before_generate_reply_handle(self, message_id):
|
async def before_generate_reply_handle(self, message_id):
|
||||||
chat_id = self.ongoing_messages[message_id].chat_id
|
chat_id = self.ongoing_messages[message_id].chat_id
|
||||||
|
|||||||
@@ -1,14 +1,16 @@
|
|||||||
from src.common.logger import get_logger
|
import importlib
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Dict, Optional
|
||||||
|
from rich.traceback import install
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.chat.message_receive.chat_stream import ChatStream, GroupInfo
|
from src.chat.message_receive.chat_stream import ChatStream, GroupInfo
|
||||||
from src.chat.message_receive.message import MessageRecv
|
from src.chat.message_receive.message import MessageRecv
|
||||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
import importlib
|
|
||||||
from typing import Dict, Optional
|
|
||||||
import asyncio
|
|
||||||
from rich.traceback import install
|
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
|
|
||||||
@@ -92,8 +94,8 @@ class BaseWillingManager(ABC):
|
|||||||
self.logger = logger
|
self.logger = logger
|
||||||
|
|
||||||
def setup(self, message: MessageRecv, chat: ChatStream, is_mentioned_bot: bool, interested_rate: float):
|
def setup(self, message: MessageRecv, chat: ChatStream, is_mentioned_bot: bool, interested_rate: float):
|
||||||
person_id = PersonInfoManager.get_person_id(chat.platform, chat.user_info.user_id)
|
person_id = PersonInfoManager.get_person_id(chat.platform, chat.user_info.user_id) # type: ignore
|
||||||
self.ongoing_messages[message.message_info.message_id] = WillingInfo(
|
self.ongoing_messages[message.message_info.message_id] = WillingInfo( # type: ignore
|
||||||
message=message,
|
message=message,
|
||||||
chat=chat,
|
chat=chat,
|
||||||
person_info_manager=get_person_info_manager(),
|
person_info_manager=get_person_info_manager(),
|
||||||
|
|||||||
@@ -27,14 +27,11 @@ class ActionManager:
|
|||||||
# 当前正在使用的动作集合,默认加载默认动作
|
# 当前正在使用的动作集合,默认加载默认动作
|
||||||
self._using_actions: Dict[str, ActionInfo] = {}
|
self._using_actions: Dict[str, ActionInfo] = {}
|
||||||
|
|
||||||
# 默认动作集,仅作为快照,用于恢复默认
|
|
||||||
self._default_actions: Dict[str, ActionInfo] = {}
|
|
||||||
|
|
||||||
# 加载插件动作
|
# 加载插件动作
|
||||||
self._load_plugin_actions()
|
self._load_plugin_actions()
|
||||||
|
|
||||||
# 初始化时将默认动作加载到使用中的动作
|
# 初始化时将默认动作加载到使用中的动作
|
||||||
self._using_actions = self._default_actions.copy()
|
self._using_actions = component_registry.get_default_actions()
|
||||||
|
|
||||||
def _load_plugin_actions(self) -> None:
|
def _load_plugin_actions(self) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -52,7 +49,7 @@ class ActionManager:
|
|||||||
"""从插件系统的component_registry加载Action组件"""
|
"""从插件系统的component_registry加载Action组件"""
|
||||||
try:
|
try:
|
||||||
# 获取所有Action组件
|
# 获取所有Action组件
|
||||||
action_components: Dict[str, ActionInfo] = component_registry.get_components_by_type(ComponentType.ACTION)
|
action_components: Dict[str, ActionInfo] = component_registry.get_components_by_type(ComponentType.ACTION) # type: ignore
|
||||||
|
|
||||||
for action_name, action_info in action_components.items():
|
for action_name, action_info in action_components.items():
|
||||||
if action_name in self._registered_actions:
|
if action_name in self._registered_actions:
|
||||||
@@ -61,10 +58,6 @@ class ActionManager:
|
|||||||
|
|
||||||
self._registered_actions[action_name] = action_info
|
self._registered_actions[action_name] = action_info
|
||||||
|
|
||||||
# 如果启用,也添加到默认动作集
|
|
||||||
if action_info.enabled:
|
|
||||||
self._default_actions[action_name] = action_info
|
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"从插件系统加载Action组件: {action_name} (插件: {getattr(action_info, 'plugin_name', 'unknown')})"
|
f"从插件系统加载Action组件: {action_name} (插件: {getattr(action_info, 'plugin_name', 'unknown')})"
|
||||||
)
|
)
|
||||||
@@ -106,7 +99,9 @@ class ActionManager:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 获取组件类 - 明确指定查询Action类型
|
# 获取组件类 - 明确指定查询Action类型
|
||||||
component_class = component_registry.get_component_class(action_name, ComponentType.ACTION)
|
component_class: Type[BaseAction] = component_registry.get_component_class(
|
||||||
|
action_name, ComponentType.ACTION
|
||||||
|
) # type: ignore
|
||||||
if not component_class:
|
if not component_class:
|
||||||
logger.warning(f"{log_prefix} 未找到Action组件: {action_name}")
|
logger.warning(f"{log_prefix} 未找到Action组件: {action_name}")
|
||||||
return None
|
return None
|
||||||
@@ -146,10 +141,6 @@ class ActionManager:
|
|||||||
"""获取所有已注册的动作集"""
|
"""获取所有已注册的动作集"""
|
||||||
return self._registered_actions.copy()
|
return self._registered_actions.copy()
|
||||||
|
|
||||||
def get_default_actions(self) -> Dict[str, ActionInfo]:
|
|
||||||
"""获取默认动作集"""
|
|
||||||
return self._default_actions.copy()
|
|
||||||
|
|
||||||
def get_using_actions(self) -> Dict[str, ActionInfo]:
|
def get_using_actions(self) -> Dict[str, ActionInfo]:
|
||||||
"""获取当前正在使用的动作集合"""
|
"""获取当前正在使用的动作集合"""
|
||||||
return self._using_actions.copy()
|
return self._using_actions.copy()
|
||||||
@@ -217,31 +208,31 @@ class ActionManager:
|
|||||||
logger.debug(f"已从使用集中移除动作 {action_name}")
|
logger.debug(f"已从使用集中移除动作 {action_name}")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def add_action(self, action_name: str, description: str, parameters: Dict = None, require: List = None) -> bool:
|
# def add_action(self, action_name: str, description: str, parameters: Dict = None, require: List = None) -> bool:
|
||||||
"""
|
# """
|
||||||
添加新的动作到注册集
|
# 添加新的动作到注册集
|
||||||
|
|
||||||
Args:
|
# Args:
|
||||||
action_name: 动作名称
|
# action_name: 动作名称
|
||||||
description: 动作描述
|
# description: 动作描述
|
||||||
parameters: 动作参数定义,默认为空字典
|
# parameters: 动作参数定义,默认为空字典
|
||||||
require: 动作依赖项,默认为空列表
|
# require: 动作依赖项,默认为空列表
|
||||||
|
|
||||||
Returns:
|
# Returns:
|
||||||
bool: 添加是否成功
|
# bool: 添加是否成功
|
||||||
"""
|
# """
|
||||||
if action_name in self._registered_actions:
|
# if action_name in self._registered_actions:
|
||||||
return False
|
# return False
|
||||||
|
|
||||||
if parameters is None:
|
# if parameters is None:
|
||||||
parameters = {}
|
# parameters = {}
|
||||||
if require is None:
|
# if require is None:
|
||||||
require = []
|
# require = []
|
||||||
|
|
||||||
action_info = {"description": description, "parameters": parameters, "require": require}
|
# action_info = {"description": description, "parameters": parameters, "require": require}
|
||||||
|
|
||||||
self._registered_actions[action_name] = action_info
|
# self._registered_actions[action_name] = action_info
|
||||||
return True
|
# return True
|
||||||
|
|
||||||
def remove_action(self, action_name: str) -> bool:
|
def remove_action(self, action_name: str) -> bool:
|
||||||
"""从注册集移除指定动作"""
|
"""从注册集移除指定动作"""
|
||||||
@@ -260,10 +251,9 @@ class ActionManager:
|
|||||||
|
|
||||||
def restore_actions(self) -> None:
|
def restore_actions(self) -> None:
|
||||||
"""恢复到默认动作集"""
|
"""恢复到默认动作集"""
|
||||||
logger.debug(
|
actions_to_restore = list(self._using_actions.keys())
|
||||||
f"恢复动作集: 从 {list(self._using_actions.keys())} 恢复到默认动作集 {list(self._default_actions.keys())}"
|
self._using_actions = component_registry.get_default_actions()
|
||||||
)
|
logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}")
|
||||||
self._using_actions = self._default_actions.copy()
|
|
||||||
|
|
||||||
def add_system_action_if_needed(self, action_name: str) -> bool:
|
def add_system_action_if_needed(self, action_name: str) -> bool:
|
||||||
"""
|
"""
|
||||||
@@ -293,4 +283,4 @@ class ActionManager:
|
|||||||
"""
|
"""
|
||||||
from src.plugin_system.core.component_registry import component_registry
|
from src.plugin_system.core.component_registry import component_registry
|
||||||
|
|
||||||
return component_registry.get_component_class(action_name)
|
return component_registry.get_component_class(action_name) # type: ignore
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import random
|
|||||||
import asyncio
|
import asyncio
|
||||||
import hashlib
|
import hashlib
|
||||||
import time
|
import time
|
||||||
from typing import List, Any, Dict
|
from typing import List, Any, Dict, TYPE_CHECKING
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
@@ -13,6 +13,9 @@ from src.chat.planner_actions.action_manager import ActionManager
|
|||||||
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages
|
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages
|
||||||
from src.plugin_system.base.component_types import ChatMode, ActionInfo, ActionActivationType
|
from src.plugin_system.base.component_types import ChatMode, ActionInfo, ActionActivationType
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from src.chat.message_receive.chat_stream import ChatStream
|
||||||
|
|
||||||
logger = get_logger("action_manager")
|
logger = get_logger("action_manager")
|
||||||
|
|
||||||
|
|
||||||
@@ -27,7 +30,7 @@ class ActionModifier:
|
|||||||
def __init__(self, action_manager: ActionManager, chat_id: str):
|
def __init__(self, action_manager: ActionManager, chat_id: str):
|
||||||
"""初始化动作处理器"""
|
"""初始化动作处理器"""
|
||||||
self.chat_id = chat_id
|
self.chat_id = chat_id
|
||||||
self.chat_stream = get_chat_manager().get_stream(self.chat_id)
|
self.chat_stream: ChatStream = get_chat_manager().get_stream(self.chat_id) # type: ignore
|
||||||
self.log_prefix = f"[{get_chat_manager().get_stream_name(self.chat_id) or self.chat_id}]"
|
self.log_prefix = f"[{get_chat_manager().get_stream_name(self.chat_id) or self.chat_id}]"
|
||||||
|
|
||||||
self.action_manager = action_manager
|
self.action_manager = action_manager
|
||||||
@@ -142,7 +145,7 @@ class ActionModifier:
|
|||||||
async def _get_deactivated_actions_by_type(
|
async def _get_deactivated_actions_by_type(
|
||||||
self,
|
self,
|
||||||
actions_with_info: Dict[str, ActionInfo],
|
actions_with_info: Dict[str, ActionInfo],
|
||||||
mode: str = "focus",
|
mode: ChatMode = ChatMode.FOCUS,
|
||||||
chat_content: str = "",
|
chat_content: str = "",
|
||||||
) -> List[tuple[str, str]]:
|
) -> List[tuple[str, str]]:
|
||||||
"""
|
"""
|
||||||
@@ -270,7 +273,7 @@ class ActionModifier:
|
|||||||
task_results = await asyncio.gather(*tasks, return_exceptions=True)
|
task_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
# 处理结果并更新缓存
|
# 处理结果并更新缓存
|
||||||
for _, (action_name, result) in enumerate(zip(task_names, task_results)):
|
for action_name, result in zip(task_names, task_results):
|
||||||
if isinstance(result, Exception):
|
if isinstance(result, Exception):
|
||||||
logger.error(f"{self.log_prefix}LLM判定action {action_name} 时出错: {result}")
|
logger.error(f"{self.log_prefix}LLM判定action {action_name} 时出错: {result}")
|
||||||
results[action_name] = False
|
results[action_name] = False
|
||||||
@@ -286,7 +289,7 @@ class ActionModifier:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"{self.log_prefix}并行LLM判定失败: {e}")
|
logger.error(f"{self.log_prefix}并行LLM判定失败: {e}")
|
||||||
# 如果并行执行失败,为所有任务返回False
|
# 如果并行执行失败,为所有任务返回False
|
||||||
for action_name in tasks_to_run.keys():
|
for action_name in tasks_to_run:
|
||||||
results[action_name] = False
|
results[action_name] = False
|
||||||
|
|
||||||
# 清理过期缓存
|
# 清理过期缓存
|
||||||
@@ -297,10 +300,11 @@ class ActionModifier:
|
|||||||
def _cleanup_expired_cache(self, current_time: float):
|
def _cleanup_expired_cache(self, current_time: float):
|
||||||
"""清理过期的缓存条目"""
|
"""清理过期的缓存条目"""
|
||||||
expired_keys = []
|
expired_keys = []
|
||||||
for cache_key, cache_data in self._llm_judge_cache.items():
|
expired_keys.extend(
|
||||||
if current_time - cache_data["timestamp"] > self._cache_expiry_time:
|
cache_key
|
||||||
expired_keys.append(cache_key)
|
for cache_key, cache_data in self._llm_judge_cache.items()
|
||||||
|
if current_time - cache_data["timestamp"] > self._cache_expiry_time
|
||||||
|
)
|
||||||
for key in expired_keys:
|
for key in expired_keys:
|
||||||
del self._llm_judge_cache[key]
|
del self._llm_judge_cache[key]
|
||||||
|
|
||||||
@@ -379,7 +383,7 @@ class ActionModifier:
|
|||||||
def _check_keyword_activation(
|
def _check_keyword_activation(
|
||||||
self,
|
self,
|
||||||
action_name: str,
|
action_name: str,
|
||||||
action_info: Dict[str, Any],
|
action_info: ActionInfo,
|
||||||
chat_content: str = "",
|
chat_content: str = "",
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
@@ -396,8 +400,8 @@ class ActionModifier:
|
|||||||
bool: 是否应该激活此action
|
bool: 是否应该激活此action
|
||||||
"""
|
"""
|
||||||
|
|
||||||
activation_keywords = action_info.get("activation_keywords", [])
|
activation_keywords = action_info.activation_keywords
|
||||||
case_sensitive = action_info.get("keyword_case_sensitive", False)
|
case_sensitive = action_info.keyword_case_sensitive
|
||||||
|
|
||||||
if not activation_keywords:
|
if not activation_keywords:
|
||||||
logger.warning(f"{self.log_prefix}动作 {action_name} 设置为关键词触发但未配置关键词")
|
logger.warning(f"{self.log_prefix}动作 {action_name} 设置为关键词触发但未配置关键词")
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ class ActionPlanner:
|
|||||||
|
|
||||||
self.last_obs_time_mark = 0.0
|
self.last_obs_time_mark = 0.0
|
||||||
|
|
||||||
async def plan(self) -> Dict[str, Any]:
|
async def plan(self) -> Dict[str, Any]: # sourcery skip: dict-comprehension
|
||||||
"""
|
"""
|
||||||
规划器 (Planner): 使用LLM根据上下文决定做出什么动作。
|
规划器 (Planner): 使用LLM根据上下文决定做出什么动作。
|
||||||
"""
|
"""
|
||||||
@@ -162,7 +162,6 @@ class ActionPlanner:
|
|||||||
reasoning = parsed_json.get("reasoning", "未提供原因")
|
reasoning = parsed_json.get("reasoning", "未提供原因")
|
||||||
|
|
||||||
# 将所有其他属性添加到action_data
|
# 将所有其他属性添加到action_data
|
||||||
action_data = {}
|
|
||||||
for key, value in parsed_json.items():
|
for key, value in parsed_json.items():
|
||||||
if key not in ["action", "reasoning"]:
|
if key not in ["action", "reasoning"]:
|
||||||
action_data[key] = value
|
action_data[key] = value
|
||||||
@@ -285,7 +284,7 @@ class ActionPlanner:
|
|||||||
identity_block = f"你的名字是{bot_name}{bot_nickname},你{bot_core_personality}:"
|
identity_block = f"你的名字是{bot_name}{bot_nickname},你{bot_core_personality}:"
|
||||||
|
|
||||||
planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt")
|
planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt")
|
||||||
prompt = planner_prompt_template.format(
|
return planner_prompt_template.format(
|
||||||
time_block=time_block,
|
time_block=time_block,
|
||||||
by_what=by_what,
|
by_what=by_what,
|
||||||
chat_context_description=chat_context_description,
|
chat_context_description=chat_context_description,
|
||||||
@@ -295,8 +294,6 @@ class ActionPlanner:
|
|||||||
moderation_prompt=moderation_prompt_block,
|
moderation_prompt=moderation_prompt_block,
|
||||||
identity_block=identity_block,
|
identity_block=identity_block,
|
||||||
)
|
)
|
||||||
return prompt
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"构建 Planner 提示词时出错: {e}")
|
logger.error(f"构建 Planner 提示词时出错: {e}")
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
|
|||||||
@@ -130,9 +130,7 @@ class DefaultReplyer:
|
|||||||
# 提取权重,如果模型配置中没有'weight'键,则默认为1.0
|
# 提取权重,如果模型配置中没有'weight'键,则默认为1.0
|
||||||
weights = [config.get("weight", 1.0) for config in configs]
|
weights = [config.get("weight", 1.0) for config in configs]
|
||||||
|
|
||||||
# random.choices 返回一个列表,我们取第一个元素
|
return random.choices(population=configs, weights=weights, k=1)[0]
|
||||||
selected_config = random.choices(population=configs, weights=weights, k=1)[0]
|
|
||||||
return selected_config
|
|
||||||
|
|
||||||
async def _create_thinking_message(self, anchor_message: Optional[MessageRecv], thinking_id: str):
|
async def _create_thinking_message(self, anchor_message: Optional[MessageRecv], thinking_id: str):
|
||||||
"""创建思考消息 (尝试锚定到 anchor_message)"""
|
"""创建思考消息 (尝试锚定到 anchor_message)"""
|
||||||
@@ -314,8 +312,7 @@ class DefaultReplyer:
|
|||||||
logger.warning(f"{self.log_prefix} 未找到用户 {sender} 的ID,跳过信息提取")
|
logger.warning(f"{self.log_prefix} 未找到用户 {sender} 的ID,跳过信息提取")
|
||||||
return f"你完全不认识{sender},不理解ta的相关信息。"
|
return f"你完全不认识{sender},不理解ta的相关信息。"
|
||||||
|
|
||||||
relation_info = await relationship_fetcher.build_relation_info(person_id, text, chat_history)
|
return await relationship_fetcher.build_relation_info(person_id, text, chat_history)
|
||||||
return relation_info
|
|
||||||
|
|
||||||
async def build_expression_habits(self, chat_history, target):
|
async def build_expression_habits(self, chat_history, target):
|
||||||
if not global_config.expression.enable_expression:
|
if not global_config.expression.enable_expression:
|
||||||
@@ -363,15 +360,13 @@ class DefaultReplyer:
|
|||||||
target_message=target, chat_history_prompt=chat_history
|
target_message=target, chat_history_prompt=chat_history
|
||||||
)
|
)
|
||||||
|
|
||||||
if running_memories:
|
if not running_memories:
|
||||||
|
return ""
|
||||||
|
|
||||||
memory_str = "以下是当前在聊天中,你回忆起的记忆:\n"
|
memory_str = "以下是当前在聊天中,你回忆起的记忆:\n"
|
||||||
for running_memory in running_memories:
|
for running_memory in running_memories:
|
||||||
memory_str += f"- {running_memory['content']}\n"
|
memory_str += f"- {running_memory['content']}\n"
|
||||||
memory_block = memory_str
|
return memory_str
|
||||||
else:
|
|
||||||
memory_block = ""
|
|
||||||
|
|
||||||
return memory_block
|
|
||||||
|
|
||||||
async def build_tool_info(self, reply_data=None, chat_history=None, enable_tool: bool = True):
|
async def build_tool_info(self, reply_data=None, chat_history=None, enable_tool: bool = True):
|
||||||
"""构建工具信息块
|
"""构建工具信息块
|
||||||
@@ -453,7 +448,7 @@ class DefaultReplyer:
|
|||||||
for name, content in result.groupdict().items():
|
for name, content in result.groupdict().items():
|
||||||
reaction = reaction.replace(f"[{name}]", content)
|
reaction = reaction.replace(f"[{name}]", content)
|
||||||
logger.info(f"匹配到正则表达式:{pattern_str},触发反应:{reaction}")
|
logger.info(f"匹配到正则表达式:{pattern_str},触发反应:{reaction}")
|
||||||
keywords_reaction_prompt += reaction + ","
|
keywords_reaction_prompt += f"{reaction},"
|
||||||
break
|
break
|
||||||
except re.error as e:
|
except re.error as e:
|
||||||
logger.error(f"正则表达式编译错误: {pattern_str}, 错误信息: {str(e)}")
|
logger.error(f"正则表达式编译错误: {pattern_str}, 错误信息: {str(e)}")
|
||||||
@@ -477,7 +472,7 @@ class DefaultReplyer:
|
|||||||
available_actions: Optional[Dict[str, ActionInfo]] = None,
|
available_actions: Optional[Dict[str, ActionInfo]] = None,
|
||||||
enable_timeout: bool = False,
|
enable_timeout: bool = False,
|
||||||
enable_tool: bool = True,
|
enable_tool: bool = True,
|
||||||
) -> str:
|
) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
|
||||||
"""
|
"""
|
||||||
构建回复器上下文
|
构建回复器上下文
|
||||||
|
|
||||||
@@ -612,7 +607,7 @@ class DefaultReplyer:
|
|||||||
short_impression = ["友好活泼", "人类"]
|
short_impression = ["友好活泼", "人类"]
|
||||||
personality = short_impression[0]
|
personality = short_impression[0]
|
||||||
identity = short_impression[1]
|
identity = short_impression[1]
|
||||||
prompt_personality = personality + "," + identity
|
prompt_personality = f"{personality},{identity}"
|
||||||
identity_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:"
|
identity_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:"
|
||||||
|
|
||||||
moderation_prompt_block = (
|
moderation_prompt_block = (
|
||||||
@@ -660,7 +655,7 @@ class DefaultReplyer:
|
|||||||
"chat_target_private2", sender_name=chat_target_name
|
"chat_target_private2", sender_name=chat_target_name
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt = await global_prompt_manager.format_prompt(
|
return await global_prompt_manager.format_prompt(
|
||||||
template_name,
|
template_name,
|
||||||
expression_habits_block=expression_habits_block,
|
expression_habits_block=expression_habits_block,
|
||||||
chat_target=chat_target_1,
|
chat_target=chat_target_1,
|
||||||
@@ -683,8 +678,6 @@ class DefaultReplyer:
|
|||||||
mood_state=mood_prompt,
|
mood_state=mood_prompt,
|
||||||
)
|
)
|
||||||
|
|
||||||
return prompt
|
|
||||||
|
|
||||||
async def build_prompt_rewrite_context(
|
async def build_prompt_rewrite_context(
|
||||||
self,
|
self,
|
||||||
reply_data: Dict[str, Any],
|
reply_data: Dict[str, Any],
|
||||||
@@ -745,7 +738,7 @@ class DefaultReplyer:
|
|||||||
short_impression = ["友好活泼", "人类"]
|
short_impression = ["友好活泼", "人类"]
|
||||||
personality = short_impression[0]
|
personality = short_impression[0]
|
||||||
identity = short_impression[1]
|
identity = short_impression[1]
|
||||||
prompt_personality = personality + "," + identity
|
prompt_personality = f"{personality},{identity}"
|
||||||
identity_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:"
|
identity_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:"
|
||||||
|
|
||||||
moderation_prompt_block = (
|
moderation_prompt_block = (
|
||||||
@@ -790,7 +783,7 @@ class DefaultReplyer:
|
|||||||
|
|
||||||
template_name = "default_expressor_prompt"
|
template_name = "default_expressor_prompt"
|
||||||
|
|
||||||
prompt = await global_prompt_manager.format_prompt(
|
return await global_prompt_manager.format_prompt(
|
||||||
template_name,
|
template_name,
|
||||||
expression_habits_block=expression_habits_block,
|
expression_habits_block=expression_habits_block,
|
||||||
relation_info_block=relation_info,
|
relation_info_block=relation_info,
|
||||||
@@ -807,8 +800,6 @@ class DefaultReplyer:
|
|||||||
moderation_prompt=moderation_prompt_block,
|
moderation_prompt=moderation_prompt_block,
|
||||||
)
|
)
|
||||||
|
|
||||||
return prompt
|
|
||||||
|
|
||||||
async def send_response_messages(
|
async def send_response_messages(
|
||||||
self,
|
self,
|
||||||
anchor_message: Optional[MessageRecv],
|
anchor_message: Optional[MessageRecv],
|
||||||
@@ -816,6 +807,7 @@ class DefaultReplyer:
|
|||||||
thinking_id: str = "",
|
thinking_id: str = "",
|
||||||
display_message: str = "",
|
display_message: str = "",
|
||||||
) -> Optional[MessageSending]:
|
) -> Optional[MessageSending]:
|
||||||
|
# sourcery skip: assign-if-exp, boolean-if-exp-identity, remove-unnecessary-cast
|
||||||
"""发送回复消息 (尝试锚定到 anchor_message),使用 HeartFCSender"""
|
"""发送回复消息 (尝试锚定到 anchor_message),使用 HeartFCSender"""
|
||||||
chat = self.chat_stream
|
chat = self.chat_stream
|
||||||
chat_id = self.chat_stream.stream_id
|
chat_id = self.chat_stream.stream_id
|
||||||
@@ -849,16 +841,16 @@ class DefaultReplyer:
|
|||||||
|
|
||||||
for i, msg_text in enumerate(response_set):
|
for i, msg_text in enumerate(response_set):
|
||||||
# 为每个消息片段生成唯一ID
|
# 为每个消息片段生成唯一ID
|
||||||
type = msg_text[0]
|
msg_type = msg_text[0]
|
||||||
data = msg_text[1]
|
data = msg_text[1]
|
||||||
|
|
||||||
if global_config.debug.debug_show_chat_mode and type == "text":
|
if global_config.debug.debug_show_chat_mode and msg_type == "text":
|
||||||
data += "ᶠ"
|
data += "ᶠ"
|
||||||
|
|
||||||
part_message_id = f"{thinking_id}_{i}"
|
part_message_id = f"{thinking_id}_{i}"
|
||||||
message_segment = Seg(type=type, data=data)
|
message_segment = Seg(type=msg_type, data=data)
|
||||||
|
|
||||||
if type == "emoji":
|
if msg_type == "emoji":
|
||||||
is_emoji = True
|
is_emoji = True
|
||||||
else:
|
else:
|
||||||
is_emoji = False
|
is_emoji = False
|
||||||
@@ -871,7 +863,6 @@ class DefaultReplyer:
|
|||||||
display_message=display_message,
|
display_message=display_message,
|
||||||
reply_to=reply_to,
|
reply_to=reply_to,
|
||||||
is_emoji=is_emoji,
|
is_emoji=is_emoji,
|
||||||
thinking_id=thinking_id,
|
|
||||||
thinking_start_time=thinking_start_time,
|
thinking_start_time=thinking_start_time,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -895,7 +886,7 @@ class DefaultReplyer:
|
|||||||
|
|
||||||
reply_message_ids.append(part_message_id) # 记录我们生成的ID
|
reply_message_ids.append(part_message_id) # 记录我们生成的ID
|
||||||
|
|
||||||
sent_msg_list.append((type, sent_msg))
|
sent_msg_list.append((msg_type, sent_msg))
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"{self.log_prefix}发送回复片段 {i} ({part_message_id}) 时失败: {e}")
|
logger.error(f"{self.log_prefix}发送回复片段 {i} ({part_message_id}) 时失败: {e}")
|
||||||
@@ -930,12 +921,9 @@ class DefaultReplyer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# await anchor_message.process()
|
# await anchor_message.process()
|
||||||
if anchor_message:
|
sender_info = anchor_message.message_info.user_info if anchor_message else None
|
||||||
sender_info = anchor_message.message_info.user_info
|
|
||||||
else:
|
|
||||||
sender_info = None
|
|
||||||
|
|
||||||
bot_message = MessageSending(
|
return MessageSending(
|
||||||
message_id=message_id, # 使用片段的唯一ID
|
message_id=message_id, # 使用片段的唯一ID
|
||||||
chat_stream=self.chat_stream,
|
chat_stream=self.chat_stream,
|
||||||
bot_user_info=bot_user_info,
|
bot_user_info=bot_user_info,
|
||||||
@@ -948,8 +936,6 @@ class DefaultReplyer:
|
|||||||
display_message=display_message,
|
display_message=display_message,
|
||||||
)
|
)
|
||||||
|
|
||||||
return bot_message
|
|
||||||
|
|
||||||
|
|
||||||
def weighted_sample_no_replacement(items, weights, k) -> list:
|
def weighted_sample_no_replacement(items, weights, k) -> list:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
from typing import Dict, Any, Optional, List
|
from typing import Dict, Any, Optional, List
|
||||||
|
|
||||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||||
from src.chat.replyer.default_generator import DefaultReplyer
|
from src.chat.replyer.default_generator import DefaultReplyer
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
@@ -8,7 +9,7 @@ logger = get_logger("ReplyerManager")
|
|||||||
|
|
||||||
class ReplyerManager:
|
class ReplyerManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._replyers: Dict[str, DefaultReplyer] = {}
|
self._repliers: Dict[str, DefaultReplyer] = {}
|
||||||
|
|
||||||
def get_replyer(
|
def get_replyer(
|
||||||
self,
|
self,
|
||||||
@@ -29,17 +30,16 @@ class ReplyerManager:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
# 如果已有缓存实例,直接返回
|
# 如果已有缓存实例,直接返回
|
||||||
if stream_id in self._replyers:
|
if stream_id in self._repliers:
|
||||||
logger.debug(f"[ReplyerManager] 为 stream_id '{stream_id}' 返回已存在的回复器实例。")
|
logger.debug(f"[ReplyerManager] 为 stream_id '{stream_id}' 返回已存在的回复器实例。")
|
||||||
return self._replyers[stream_id]
|
return self._repliers[stream_id]
|
||||||
|
|
||||||
# 如果没有缓存,则创建新实例(首次初始化)
|
# 如果没有缓存,则创建新实例(首次初始化)
|
||||||
logger.debug(f"[ReplyerManager] 为 stream_id '{stream_id}' 创建新的回复器实例并缓存。")
|
logger.debug(f"[ReplyerManager] 为 stream_id '{stream_id}' 创建新的回复器实例并缓存。")
|
||||||
|
|
||||||
target_stream = chat_stream
|
target_stream = chat_stream
|
||||||
if not target_stream:
|
if not target_stream:
|
||||||
chat_manager = get_chat_manager()
|
if chat_manager := get_chat_manager():
|
||||||
if chat_manager:
|
|
||||||
target_stream = chat_manager.get_stream(stream_id)
|
target_stream = chat_manager.get_stream(stream_id)
|
||||||
|
|
||||||
if not target_stream:
|
if not target_stream:
|
||||||
@@ -52,7 +52,7 @@ class ReplyerManager:
|
|||||||
model_configs=model_configs, # 可以是None,此时使用默认模型
|
model_configs=model_configs, # 可以是None,此时使用默认模型
|
||||||
request_type=request_type,
|
request_type=request_type,
|
||||||
)
|
)
|
||||||
self._replyers[stream_id] = replyer
|
self._repliers[stream_id] = replyer
|
||||||
return replyer
|
return replyer
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,14 +1,15 @@
|
|||||||
from src.config.config import global_config
|
|
||||||
from typing import List, Dict, Any, Tuple # 确保类型提示被导入
|
|
||||||
import time # 导入 time 模块以获取当前时间
|
import time # 导入 time 模块以获取当前时间
|
||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
from src.common.message_repository import find_messages, count_messages
|
from typing import List, Dict, Any, Tuple, Optional
|
||||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
|
||||||
from src.chat.utils.utils import translate_timestamp_to_human_readable
|
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
|
||||||
|
from src.config.config import global_config
|
||||||
|
from src.common.message_repository import find_messages, count_messages
|
||||||
from src.common.database.database_model import ActionRecords
|
from src.common.database.database_model import ActionRecords
|
||||||
from src.common.database.database_model import Images
|
from src.common.database.database_model import Images
|
||||||
|
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||||
|
from src.chat.utils.utils import translate_timestamp_to_human_readable
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
|
|
||||||
@@ -135,7 +136,7 @@ def get_raw_msg_before_timestamp_with_users(timestamp: float, person_ids: list,
|
|||||||
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
|
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
|
||||||
|
|
||||||
|
|
||||||
def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: float = None) -> int:
|
def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: Optional[float] = None) -> int:
|
||||||
"""
|
"""
|
||||||
检查特定聊天从 timestamp_start (不含) 到 timestamp_end (不含) 之间有多少新消息。
|
检查特定聊天从 timestamp_start (不含) 到 timestamp_end (不含) 之间有多少新消息。
|
||||||
如果 timestamp_end 为 None,则检查从 timestamp_start (不含) 到当前时间的消息。
|
如果 timestamp_end 为 None,则检查从 timestamp_start (不含) 到当前时间的消息。
|
||||||
@@ -172,7 +173,7 @@ def _build_readable_messages_internal(
|
|||||||
merge_messages: bool = False,
|
merge_messages: bool = False,
|
||||||
timestamp_mode: str = "relative",
|
timestamp_mode: str = "relative",
|
||||||
truncate: bool = False,
|
truncate: bool = False,
|
||||||
pic_id_mapping: Dict[str, str] = None,
|
pic_id_mapping: Optional[Dict[str, str]] = None,
|
||||||
pic_counter: int = 1,
|
pic_counter: int = 1,
|
||||||
show_pic: bool = True,
|
show_pic: bool = True,
|
||||||
) -> Tuple[str, List[Tuple[float, str, str]], Dict[str, str], int]:
|
) -> Tuple[str, List[Tuple[float, str, str]], Dict[str, str], int]:
|
||||||
@@ -194,7 +195,7 @@ def _build_readable_messages_internal(
|
|||||||
if not messages:
|
if not messages:
|
||||||
return "", [], pic_id_mapping or {}, pic_counter
|
return "", [], pic_id_mapping or {}, pic_counter
|
||||||
|
|
||||||
message_details_raw: List[Tuple[float, str, str]] = []
|
message_details_raw: List[Tuple[float, str, str, bool]] = []
|
||||||
|
|
||||||
# 使用传入的映射字典,如果没有则创建新的
|
# 使用传入的映射字典,如果没有则创建新的
|
||||||
if pic_id_mapping is None:
|
if pic_id_mapping is None:
|
||||||
@@ -225,7 +226,7 @@ def _build_readable_messages_internal(
|
|||||||
# 检查是否是动作记录
|
# 检查是否是动作记录
|
||||||
if msg.get("is_action_record", False):
|
if msg.get("is_action_record", False):
|
||||||
is_action = True
|
is_action = True
|
||||||
timestamp = msg.get("time")
|
timestamp: float = msg.get("time") # type: ignore
|
||||||
content = msg.get("display_message", "")
|
content = msg.get("display_message", "")
|
||||||
# 对于动作记录,也处理图片ID
|
# 对于动作记录,也处理图片ID
|
||||||
content = process_pic_ids(content)
|
content = process_pic_ids(content)
|
||||||
@@ -249,9 +250,10 @@ def _build_readable_messages_internal(
|
|||||||
user_nickname = user_info.get("user_nickname")
|
user_nickname = user_info.get("user_nickname")
|
||||||
user_cardname = user_info.get("user_cardname")
|
user_cardname = user_info.get("user_cardname")
|
||||||
|
|
||||||
timestamp = msg.get("time")
|
timestamp: float = msg.get("time") # type: ignore
|
||||||
|
content: str
|
||||||
if msg.get("display_message"):
|
if msg.get("display_message"):
|
||||||
content = msg.get("display_message")
|
content = msg.get("display_message", "")
|
||||||
else:
|
else:
|
||||||
content = msg.get("processed_plain_text", "") # 默认空字符串
|
content = msg.get("processed_plain_text", "") # 默认空字符串
|
||||||
|
|
||||||
@@ -271,6 +273,7 @@ def _build_readable_messages_internal(
|
|||||||
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
||||||
person_info_manager = get_person_info_manager()
|
person_info_manager = get_person_info_manager()
|
||||||
# 根据 replace_bot_name 参数决定是否替换机器人名称
|
# 根据 replace_bot_name 参数决定是否替换机器人名称
|
||||||
|
person_name: str
|
||||||
if replace_bot_name and user_id == global_config.bot.qq_account:
|
if replace_bot_name and user_id == global_config.bot.qq_account:
|
||||||
person_name = f"{global_config.bot.nickname}(你)"
|
person_name = f"{global_config.bot.nickname}(你)"
|
||||||
else:
|
else:
|
||||||
@@ -289,12 +292,10 @@ def _build_readable_messages_internal(
|
|||||||
reply_pattern = r"回复<([^:<>]+):([^:<>]+)>"
|
reply_pattern = r"回复<([^:<>]+):([^:<>]+)>"
|
||||||
match = re.search(reply_pattern, content)
|
match = re.search(reply_pattern, content)
|
||||||
if match:
|
if match:
|
||||||
aaa = match.group(1)
|
aaa: str = match[1]
|
||||||
bbb = match.group(2)
|
bbb: str = match[2]
|
||||||
reply_person_id = PersonInfoManager.get_person_id(platform, bbb)
|
reply_person_id = PersonInfoManager.get_person_id(platform, bbb)
|
||||||
reply_person_name = person_info_manager.get_value_sync(reply_person_id, "person_name")
|
reply_person_name = person_info_manager.get_value_sync(reply_person_id, "person_name") or aaa
|
||||||
if not reply_person_name:
|
|
||||||
reply_person_name = aaa
|
|
||||||
# 在内容前加上回复信息
|
# 在内容前加上回复信息
|
||||||
content = re.sub(reply_pattern, lambda m, name=reply_person_name: f"回复 {name}", content, count=1)
|
content = re.sub(reply_pattern, lambda m, name=reply_person_name: f"回复 {name}", content, count=1)
|
||||||
|
|
||||||
@@ -309,17 +310,14 @@ def _build_readable_messages_internal(
|
|||||||
aaa = m.group(1)
|
aaa = m.group(1)
|
||||||
bbb = m.group(2)
|
bbb = m.group(2)
|
||||||
at_person_id = PersonInfoManager.get_person_id(platform, bbb)
|
at_person_id = PersonInfoManager.get_person_id(platform, bbb)
|
||||||
at_person_name = person_info_manager.get_value_sync(at_person_id, "person_name")
|
at_person_name = person_info_manager.get_value_sync(at_person_id, "person_name") or aaa
|
||||||
if not at_person_name:
|
|
||||||
at_person_name = aaa
|
|
||||||
new_content += f"@{at_person_name}"
|
new_content += f"@{at_person_name}"
|
||||||
last_end = m.end()
|
last_end = m.end()
|
||||||
new_content += content[last_end:]
|
new_content += content[last_end:]
|
||||||
content = new_content
|
content = new_content
|
||||||
|
|
||||||
target_str = "这是QQ的一个功能,用于提及某人,但没那么明显"
|
target_str = "这是QQ的一个功能,用于提及某人,但没那么明显"
|
||||||
if target_str in content:
|
if target_str in content and random.random() < 0.6:
|
||||||
if random.random() < 0.6:
|
|
||||||
content = content.replace(target_str, "")
|
content = content.replace(target_str, "")
|
||||||
|
|
||||||
if content != "":
|
if content != "":
|
||||||
@@ -470,6 +468,7 @@ def _build_readable_messages_internal(
|
|||||||
|
|
||||||
|
|
||||||
def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
|
def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
|
||||||
|
# sourcery skip: use-contextlib-suppress
|
||||||
"""
|
"""
|
||||||
构建图片映射信息字符串,显示图片的具体描述内容
|
构建图片映射信息字符串,显示图片的具体描述内容
|
||||||
|
|
||||||
@@ -518,9 +517,7 @@ async def build_readable_messages_with_list(
|
|||||||
messages, replace_bot_name, merge_messages, timestamp_mode, truncate
|
messages, replace_bot_name, merge_messages, timestamp_mode, truncate
|
||||||
)
|
)
|
||||||
|
|
||||||
# 生成图片映射信息并添加到最前面
|
if pic_mapping_info := build_pic_mapping_info(pic_id_mapping):
|
||||||
pic_mapping_info = build_pic_mapping_info(pic_id_mapping)
|
|
||||||
if pic_mapping_info:
|
|
||||||
formatted_string = f"{pic_mapping_info}\n\n{formatted_string}"
|
formatted_string = f"{pic_mapping_info}\n\n{formatted_string}"
|
||||||
|
|
||||||
return formatted_string, details_list
|
return formatted_string, details_list
|
||||||
@@ -535,7 +532,7 @@ def build_readable_messages(
|
|||||||
truncate: bool = False,
|
truncate: bool = False,
|
||||||
show_actions: bool = False,
|
show_actions: bool = False,
|
||||||
show_pic: bool = True,
|
show_pic: bool = True,
|
||||||
) -> str:
|
) -> str: # sourcery skip: extract-method
|
||||||
"""
|
"""
|
||||||
将消息列表转换为可读的文本格式。
|
将消息列表转换为可读的文本格式。
|
||||||
如果提供了 read_mark,则在相应位置插入已读标记。
|
如果提供了 read_mark,则在相应位置插入已读标记。
|
||||||
@@ -658,9 +655,7 @@ def build_readable_messages(
|
|||||||
# 组合结果
|
# 组合结果
|
||||||
result_parts = []
|
result_parts = []
|
||||||
if pic_mapping_info:
|
if pic_mapping_info:
|
||||||
result_parts.append(pic_mapping_info)
|
result_parts.extend((pic_mapping_info, "\n"))
|
||||||
result_parts.append("\n")
|
|
||||||
|
|
||||||
if formatted_before and formatted_after:
|
if formatted_before and formatted_after:
|
||||||
result_parts.extend([formatted_before, read_mark_line, formatted_after])
|
result_parts.extend([formatted_before, read_mark_line, formatted_after])
|
||||||
elif formatted_before:
|
elif formatted_before:
|
||||||
@@ -733,8 +728,9 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
|
|||||||
platform = msg.get("chat_info_platform")
|
platform = msg.get("chat_info_platform")
|
||||||
user_id = msg.get("user_id")
|
user_id = msg.get("user_id")
|
||||||
_timestamp = msg.get("time")
|
_timestamp = msg.get("time")
|
||||||
|
content: str = ""
|
||||||
if msg.get("display_message"):
|
if msg.get("display_message"):
|
||||||
content = msg.get("display_message")
|
content = msg.get("display_message", "")
|
||||||
else:
|
else:
|
||||||
content = msg.get("processed_plain_text", "")
|
content = msg.get("processed_plain_text", "")
|
||||||
|
|
||||||
@@ -829,10 +825,7 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
|
|||||||
if not all([platform, user_id]) or user_id == global_config.bot.qq_account:
|
if not all([platform, user_id]) or user_id == global_config.bot.qq_account:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
if person_id := PersonInfoManager.get_person_id(platform, user_id):
|
||||||
|
|
||||||
# 只有当获取到有效 person_id 时才添加
|
|
||||||
if person_id:
|
|
||||||
person_ids_set.add(person_id)
|
person_ids_set.add(person_id)
|
||||||
|
|
||||||
return list(person_ids_set) # 将集合转换为列表返回
|
return list(person_ids_set) # 将集合转换为列表返回
|
||||||
|
|||||||
@@ -103,7 +103,7 @@ class ImageManager:
|
|||||||
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
||||||
image_bytes = base64.b64decode(image_base64)
|
image_bytes = base64.b64decode(image_base64)
|
||||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
|
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
|
||||||
|
|
||||||
# 查询缓存的描述
|
# 查询缓存的描述
|
||||||
cached_description = self._get_description_from_db(image_hash, "emoji")
|
cached_description = self._get_description_from_db(image_hash, "emoji")
|
||||||
@@ -154,7 +154,7 @@ class ImageManager:
|
|||||||
img_obj.description = description
|
img_obj.description = description
|
||||||
img_obj.timestamp = current_timestamp
|
img_obj.timestamp = current_timestamp
|
||||||
img_obj.save()
|
img_obj.save()
|
||||||
except Images.DoesNotExist:
|
except Images.DoesNotExist: # type: ignore
|
||||||
Images.create(
|
Images.create(
|
||||||
emoji_hash=image_hash,
|
emoji_hash=image_hash,
|
||||||
path=file_path,
|
path=file_path,
|
||||||
@@ -204,7 +204,7 @@ class ImageManager:
|
|||||||
return f"[图片:{cached_description}]"
|
return f"[图片:{cached_description}]"
|
||||||
|
|
||||||
# 调用AI获取描述
|
# 调用AI获取描述
|
||||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
|
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
|
||||||
prompt = "请用中文描述这张图片的内容。如果有文字,请把文字都描述出来,请留意其主题,直观感受,输出为一段平文本,最多50字"
|
prompt = "请用中文描述这张图片的内容。如果有文字,请把文字都描述出来,请留意其主题,直观感受,输出为一段平文本,最多50字"
|
||||||
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
|
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
|
||||||
|
|
||||||
@@ -491,7 +491,7 @@ class ImageManager:
|
|||||||
return
|
return
|
||||||
|
|
||||||
# 获取图片格式
|
# 获取图片格式
|
||||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
|
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
|
||||||
|
|
||||||
# 构建prompt
|
# 构建prompt
|
||||||
prompt = """请用中文描述这张图片的内容。如果有文字,请把文字描述概括出来,请留意其主题,直观感受,输出为一段平文本,最多30字,请注意不要分点,就输出一段文本"""
|
prompt = """请用中文描述这张图片的内容。如果有文字,请把文字描述概括出来,请留意其主题,直观感受,输出为一段平文本,最多30字,请注意不要分点,就输出一段文本"""
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from src.common.database.database import db
|
|||||||
from src.common.database.database_model import PersonInfo # 新增导入
|
from src.common.database.database_model import PersonInfo # 新增导入
|
||||||
import copy
|
import copy
|
||||||
import hashlib
|
import hashlib
|
||||||
from typing import Any, Callable, Dict
|
from typing import Any, Callable, Dict, Union
|
||||||
import datetime
|
import datetime
|
||||||
import asyncio
|
import asyncio
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
@@ -84,7 +84,7 @@ class PersonInfoManager:
|
|||||||
logger.error(f"从 Peewee 加载 person_name_list 失败: {e}")
|
logger.error(f"从 Peewee 加载 person_name_list 失败: {e}")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_person_id(platform: str, user_id: int):
|
def get_person_id(platform: str, user_id: Union[int, str]) -> str:
|
||||||
"""获取唯一id"""
|
"""获取唯一id"""
|
||||||
if "-" in platform:
|
if "-" in platform:
|
||||||
platform = platform.split("-")[1]
|
platform = platform.split("-")[1]
|
||||||
|
|||||||
@@ -32,10 +32,10 @@ class BaseAction(ABC):
|
|||||||
reasoning: str,
|
reasoning: str,
|
||||||
cycle_timers: dict,
|
cycle_timers: dict,
|
||||||
thinking_id: str,
|
thinking_id: str,
|
||||||
chat_stream: ChatStream = None,
|
chat_stream: ChatStream,
|
||||||
log_prefix: str = "",
|
log_prefix: str = "",
|
||||||
shutting_down: bool = False,
|
shutting_down: bool = False,
|
||||||
plugin_config: dict = None,
|
plugin_config: Optional[dict] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""初始化Action组件
|
"""初始化Action组件
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ class BaseCommand(ABC):
|
|||||||
command_examples: List[str] = []
|
command_examples: List[str] = []
|
||||||
intercept_message: bool = True # 默认拦截消息,不继续处理
|
intercept_message: bool = True # 默认拦截消息,不继续处理
|
||||||
|
|
||||||
def __init__(self, message: MessageRecv, plugin_config: dict = None):
|
def __init__(self, message: MessageRecv, plugin_config: Optional[dict] = None):
|
||||||
"""初始化Command组件
|
"""初始化Command组件
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ class BasePlugin(ABC):
|
|||||||
|
|
||||||
config_section_descriptions: Dict[str, str] = {}
|
config_section_descriptions: Dict[str, str] = {}
|
||||||
|
|
||||||
def __init__(self, plugin_dir: str = None):
|
def __init__(self, plugin_dir: str):
|
||||||
"""初始化插件
|
"""初始化插件
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -526,7 +526,7 @@ class BasePlugin(ABC):
|
|||||||
|
|
||||||
# 从配置中更新 enable_plugin
|
# 从配置中更新 enable_plugin
|
||||||
if "plugin" in self.config and "enabled" in self.config["plugin"]:
|
if "plugin" in self.config and "enabled" in self.config["plugin"]:
|
||||||
self.enable_plugin = self.config["plugin"]["enabled"]
|
self.enable_plugin = self.config["plugin"]["enabled"] # type: ignore
|
||||||
logger.debug(f"{self.log_prefix} 从配置更新插件启用状态: {self.enable_plugin}")
|
logger.debug(f"{self.log_prefix} 从配置更新插件启用状态: {self.enable_plugin}")
|
||||||
else:
|
else:
|
||||||
logger.warning(f"{self.log_prefix} 不支持的配置文件格式: {file_ext},仅支持 .toml")
|
logger.warning(f"{self.log_prefix} 不支持的配置文件格式: {file_ext},仅支持 .toml")
|
||||||
|
|||||||
@@ -81,7 +81,9 @@ class ComponentInfo:
|
|||||||
class ActionInfo(ComponentInfo):
|
class ActionInfo(ComponentInfo):
|
||||||
"""动作组件信息"""
|
"""动作组件信息"""
|
||||||
|
|
||||||
action_parameters: Dict[str, str] = field(default_factory=dict) # 动作参数与描述,例如 {"param1": "描述1", "param2": "描述2"}
|
action_parameters: Dict[str, str] = field(
|
||||||
|
default_factory=dict
|
||||||
|
) # 动作参数与描述,例如 {"param1": "描述1", "param2": "描述2"}
|
||||||
action_require: List[str] = field(default_factory=list) # 动作需求说明
|
action_require: List[str] = field(default_factory=list) # 动作需求说明
|
||||||
associated_types: List[str] = field(default_factory=list) # 关联的消息类型
|
associated_types: List[str] = field(default_factory=list) # 关联的消息类型
|
||||||
# 激活类型相关
|
# 激活类型相关
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Dict, List, Optional, Any, Pattern, Union
|
from typing import Dict, List, Optional, Any, Pattern, Tuple, Union, Type
|
||||||
import re
|
import re
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.plugin_system.base.component_types import (
|
from src.plugin_system.base.component_types import (
|
||||||
@@ -28,25 +28,25 @@ class ComponentRegistry:
|
|||||||
ComponentType.ACTION: {},
|
ComponentType.ACTION: {},
|
||||||
ComponentType.COMMAND: {},
|
ComponentType.COMMAND: {},
|
||||||
}
|
}
|
||||||
self._component_classes: Dict[str, Union[BaseCommand, BaseAction]] = {} # 组件名 -> 组件类
|
self._component_classes: Dict[str, Union[Type[BaseCommand], Type[BaseAction]]] = {} # 组件名 -> 组件类
|
||||||
|
|
||||||
# 插件注册表
|
# 插件注册表
|
||||||
self._plugins: Dict[str, PluginInfo] = {} # 插件名 -> 插件信息
|
self._plugins: Dict[str, PluginInfo] = {} # 插件名 -> 插件信息
|
||||||
|
|
||||||
# Action特定注册表
|
# Action特定注册表
|
||||||
self._action_registry: Dict[str, BaseAction] = {} # action名 -> action类
|
self._action_registry: Dict[str, Type[BaseAction]] = {} # action名 -> action类
|
||||||
# self._action_descriptions: Dict[str, str] = {} # 启用的action名 -> 描述
|
self._default_actions: Dict[str, ActionInfo] = {} # 默认动作集,即启用的Action集,用于重置ActionManager状态
|
||||||
|
|
||||||
# Command特定注册表
|
# Command特定注册表
|
||||||
self._command_registry: Dict[str, BaseCommand] = {} # command名 -> command类
|
self._command_registry: Dict[str, Type[BaseCommand]] = {} # command名 -> command类
|
||||||
self._command_patterns: Dict[Pattern, BaseCommand] = {} # 编译后的正则 -> command类
|
self._command_patterns: Dict[Pattern, Type[BaseCommand]] = {} # 编译后的正则 -> command类
|
||||||
|
|
||||||
logger.info("组件注册中心初始化完成")
|
logger.info("组件注册中心初始化完成")
|
||||||
|
|
||||||
# === 通用组件注册方法 ===
|
# === 通用组件注册方法 ===
|
||||||
|
|
||||||
def register_component(
|
def register_component(
|
||||||
self, component_info: ComponentInfo, component_class: Union[BaseCommand, BaseAction]
|
self, component_info: ComponentInfo, component_class: Union[Type[BaseCommand], Type[BaseAction]]
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""注册组件
|
"""注册组件
|
||||||
|
|
||||||
@@ -88,9 +88,9 @@ class ComponentRegistry:
|
|||||||
|
|
||||||
# 根据组件类型进行特定注册(使用原始名称)
|
# 根据组件类型进行特定注册(使用原始名称)
|
||||||
if component_type == ComponentType.ACTION:
|
if component_type == ComponentType.ACTION:
|
||||||
self._register_action_component(component_info, component_class)
|
self._register_action_component(component_info, component_class) # type: ignore
|
||||||
elif component_type == ComponentType.COMMAND:
|
elif component_type == ComponentType.COMMAND:
|
||||||
self._register_command_component(component_info, component_class)
|
self._register_command_component(component_info, component_class) # type: ignore
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"已注册{component_type.value}组件: '{component_name}' -> '{namespaced_name}' "
|
f"已注册{component_type.value}组件: '{component_name}' -> '{namespaced_name}' "
|
||||||
@@ -98,7 +98,7 @@ class ComponentRegistry:
|
|||||||
)
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _register_action_component(self, action_info: ActionInfo, action_class: BaseAction):
|
def _register_action_component(self, action_info: ActionInfo, action_class: Type[BaseAction]):
|
||||||
# -------------------------------- NEED REFACTORING --------------------------------
|
# -------------------------------- NEED REFACTORING --------------------------------
|
||||||
# -------------------------------- LOGIC ERROR -------------------------------------
|
# -------------------------------- LOGIC ERROR -------------------------------------
|
||||||
"""注册Action组件到Action特定注册表"""
|
"""注册Action组件到Action特定注册表"""
|
||||||
@@ -106,11 +106,10 @@ class ComponentRegistry:
|
|||||||
self._action_registry[action_name] = action_class
|
self._action_registry[action_name] = action_class
|
||||||
|
|
||||||
# 如果启用,添加到默认动作集
|
# 如果启用,添加到默认动作集
|
||||||
# ---- HERE ----
|
if action_info.enabled:
|
||||||
# if action_info.enabled:
|
self._default_actions[action_name] = action_info
|
||||||
# self._action_descriptions[action_name] = action_info.description
|
|
||||||
|
|
||||||
def _register_command_component(self, command_info: CommandInfo, command_class: BaseCommand):
|
def _register_command_component(self, command_info: CommandInfo, command_class: Type[BaseCommand]):
|
||||||
"""注册Command组件到Command特定注册表"""
|
"""注册Command组件到Command特定注册表"""
|
||||||
command_name = command_info.name
|
command_name = command_info.name
|
||||||
self._command_registry[command_name] = command_class
|
self._command_registry[command_name] = command_class
|
||||||
@@ -122,7 +121,7 @@ class ComponentRegistry:
|
|||||||
|
|
||||||
# === 组件查询方法 ===
|
# === 组件查询方法 ===
|
||||||
|
|
||||||
def get_component_info(self, component_name: str, component_type: ComponentType = None) -> Optional[ComponentInfo]:
|
def get_component_info(self, component_name: str, component_type: ComponentType = None) -> Optional[ComponentInfo]: # type: ignore
|
||||||
# sourcery skip: class-extract-method
|
# sourcery skip: class-extract-method
|
||||||
"""获取组件信息,支持自动命名空间解析
|
"""获取组件信息,支持自动命名空间解析
|
||||||
|
|
||||||
@@ -170,8 +169,10 @@ class ComponentRegistry:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def get_component_class(
|
def get_component_class(
|
||||||
self, component_name: str, component_type: ComponentType = None
|
self,
|
||||||
) -> Optional[Union[BaseCommand, BaseAction]]:
|
component_name: str,
|
||||||
|
component_type: ComponentType = None, # type: ignore
|
||||||
|
) -> Optional[Union[Type[BaseCommand], Type[BaseAction]]]:
|
||||||
"""获取组件类,支持自动命名空间解析
|
"""获取组件类,支持自动命名空间解析
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -230,7 +231,7 @@ class ComponentRegistry:
|
|||||||
|
|
||||||
# === Action特定查询方法 ===
|
# === Action特定查询方法 ===
|
||||||
|
|
||||||
def get_action_registry(self) -> Dict[str, BaseAction]:
|
def get_action_registry(self) -> Dict[str, Type[BaseAction]]:
|
||||||
"""获取Action注册表(用于兼容现有系统)"""
|
"""获取Action注册表(用于兼容现有系统)"""
|
||||||
return self._action_registry.copy()
|
return self._action_registry.copy()
|
||||||
|
|
||||||
@@ -239,13 +240,17 @@ class ComponentRegistry:
|
|||||||
info = self.get_component_info(action_name, ComponentType.ACTION)
|
info = self.get_component_info(action_name, ComponentType.ACTION)
|
||||||
return info if isinstance(info, ActionInfo) else None
|
return info if isinstance(info, ActionInfo) else None
|
||||||
|
|
||||||
|
def get_default_actions(self) -> Dict[str, ActionInfo]:
|
||||||
|
"""获取默认动作集"""
|
||||||
|
return self._default_actions.copy()
|
||||||
|
|
||||||
# === Command特定查询方法 ===
|
# === Command特定查询方法 ===
|
||||||
|
|
||||||
def get_command_registry(self) -> Dict[str, BaseCommand]:
|
def get_command_registry(self) -> Dict[str, Type[BaseCommand]]:
|
||||||
"""获取Command注册表(用于兼容现有系统)"""
|
"""获取Command注册表(用于兼容现有系统)"""
|
||||||
return self._command_registry.copy()
|
return self._command_registry.copy()
|
||||||
|
|
||||||
def get_command_patterns(self) -> Dict[Pattern, BaseCommand]:
|
def get_command_patterns(self) -> Dict[Pattern, Type[BaseCommand]]:
|
||||||
"""获取Command模式注册表(用于兼容现有系统)"""
|
"""获取Command模式注册表(用于兼容现有系统)"""
|
||||||
return self._command_patterns.copy()
|
return self._command_patterns.copy()
|
||||||
|
|
||||||
@@ -254,7 +259,7 @@ class ComponentRegistry:
|
|||||||
info = self.get_component_info(command_name, ComponentType.COMMAND)
|
info = self.get_component_info(command_name, ComponentType.COMMAND)
|
||||||
return info if isinstance(info, CommandInfo) else None
|
return info if isinstance(info, CommandInfo) else None
|
||||||
|
|
||||||
def find_command_by_text(self, text: str) -> Optional[tuple[BaseCommand, dict, bool, str]]:
|
def find_command_by_text(self, text: str) -> Optional[Tuple[Type[BaseCommand], dict, bool, str]]:
|
||||||
# sourcery skip: use-named-expression, use-next
|
# sourcery skip: use-named-expression, use-next
|
||||||
"""根据文本查找匹配的命令
|
"""根据文本查找匹配的命令
|
||||||
|
|
||||||
@@ -262,7 +267,7 @@ class ComponentRegistry:
|
|||||||
text: 输入文本
|
text: 输入文本
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Optional[tuple[BaseCommand, dict, bool, str]]: (命令类, 匹配的命名组, 是否拦截消息, 插件名) 或 None
|
Tuple: (命令类, 匹配的命名组, 是否拦截消息, 插件名) 或 None
|
||||||
"""
|
"""
|
||||||
|
|
||||||
for pattern, command_class in self._command_patterns.items():
|
for pattern, command_class in self._command_patterns.items():
|
||||||
|
|||||||
Reference in New Issue
Block a user