部分类型注解修复,优化import顺序,删除无用API文件

This commit is contained in:
UnCLAS-Prommer
2025-07-12 00:34:49 +08:00
parent 3165a0f8df
commit b303a95f61
44 changed files with 405 additions and 1166 deletions

View File

View File

@@ -1,26 +0,0 @@
from src.chat.heart_flow.heartflow import heartflow
from src.chat.heart_flow.sub_heartflow import ChatState
from src.common.logger import get_logger
logger = get_logger("api")
async def get_all_subheartflow_ids() -> list:
"""获取所有子心流的ID列表"""
all_subheartflows = heartflow.subheartflow_manager.get_all_subheartflows()
return [subheartflow.subheartflow_id for subheartflow in all_subheartflows]
async def forced_change_subheartflow_status(subheartflow_id: str, status: ChatState) -> bool:
"""强制改变子心流的状态"""
subheartflow = await heartflow.get_or_create_subheartflow(subheartflow_id)
if subheartflow:
return await heartflow.force_change_subheartflow_status(subheartflow_id, status)
return False
async def get_all_states():
"""获取所有状态"""
all_states = await heartflow.api_get_all_states()
logger.debug(f"所有状态: {all_states}")
return all_states

View File

@@ -1,169 +0,0 @@
import platform
import psutil
import sys
import os
def get_system_info():
"""获取操作系统信息"""
return {
"system": platform.system(),
"release": platform.release(),
"version": platform.version(),
"machine": platform.machine(),
"processor": platform.processor(),
}
def get_python_version():
"""获取 Python 版本信息"""
return sys.version
def get_cpu_usage():
"""获取系统总CPU使用率"""
return psutil.cpu_percent(interval=1)
def get_process_cpu_usage():
"""获取当前进程CPU使用率"""
process = psutil.Process(os.getpid())
return process.cpu_percent(interval=1)
def get_memory_usage():
"""获取系统内存使用情况 (单位 MB)"""
mem = psutil.virtual_memory()
bytes_to_mb = lambda x: round(x / (1024 * 1024), 2) # noqa
return {
"total_mb": bytes_to_mb(mem.total),
"available_mb": bytes_to_mb(mem.available),
"percent": mem.percent,
"used_mb": bytes_to_mb(mem.used),
"free_mb": bytes_to_mb(mem.free),
}
def get_process_memory_usage():
"""获取当前进程内存使用情况 (单位 MB)"""
process = psutil.Process(os.getpid())
mem_info = process.memory_info()
bytes_to_mb = lambda x: round(x / (1024 * 1024), 2) # noqa
return {
"rss_mb": bytes_to_mb(mem_info.rss), # Resident Set Size: 实际使用物理内存
"vms_mb": bytes_to_mb(mem_info.vms), # Virtual Memory Size: 虚拟内存大小
"percent": process.memory_percent(), # 进程内存使用百分比
}
def get_disk_usage(path="/"):
"""获取指定路径磁盘使用情况 (单位 GB)"""
disk = psutil.disk_usage(path)
bytes_to_gb = lambda x: round(x / (1024 * 1024 * 1024), 2) # noqa
return {
"total_gb": bytes_to_gb(disk.total),
"used_gb": bytes_to_gb(disk.used),
"free_gb": bytes_to_gb(disk.free),
"percent": disk.percent,
}
def get_all_basic_info():
"""获取所有基本信息并封装返回"""
# 对于进程CPU使用率需要先初始化
process = psutil.Process(os.getpid())
process.cpu_percent(interval=None) # 初始化调用
process_cpu = process.cpu_percent(interval=0.1) # 短暂间隔获取
return {
"system_info": get_system_info(),
"python_version": get_python_version(),
"cpu_usage_percent": get_cpu_usage(),
"process_cpu_usage_percent": process_cpu,
"memory_usage": get_memory_usage(),
"process_memory_usage": get_process_memory_usage(),
"disk_usage_root": get_disk_usage("/"),
}
def get_all_basic_info_string() -> str:
"""获取所有基本信息并以带解释的字符串形式返回"""
info = get_all_basic_info()
sys_info = info["system_info"]
mem_usage = info["memory_usage"]
proc_mem_usage = info["process_memory_usage"]
disk_usage = info["disk_usage_root"]
# 对进程内存使用百分比进行格式化,保留两位小数
proc_mem_percent = round(proc_mem_usage["percent"], 2)
output_string = f"""[系统信息]
- 操作系统: {sys_info["system"]} (例如: Windows, Linux)
- 发行版本: {sys_info["release"]} (例如: 11, Ubuntu 20.04)
- 详细版本: {sys_info["version"]}
- 硬件架构: {sys_info["machine"]} (例如: AMD64)
- 处理器信息: {sys_info["processor"]}
[Python 环境]
- Python 版本: {info["python_version"]}
[CPU 状态]
- 系统总 CPU 使用率: {info["cpu_usage_percent"]}%
- 当前进程 CPU 使用率: {info["process_cpu_usage_percent"]}%
[系统内存使用情况]
- 总物理内存: {mem_usage["total_mb"]} MB
- 可用物理内存: {mem_usage["available_mb"]} MB
- 物理内存使用率: {mem_usage["percent"]}%
- 已用物理内存: {mem_usage["used_mb"]} MB
- 空闲物理内存: {mem_usage["free_mb"]} MB
[当前进程内存使用情况]
- 实际使用物理内存 (RSS): {proc_mem_usage["rss_mb"]} MB
- 占用虚拟内存 (VMS): {proc_mem_usage["vms_mb"]} MB
- 进程内存使用率: {proc_mem_percent}%
[磁盘使用情况 (根目录)]
- 总空间: {disk_usage["total_gb"]} GB
- 已用空间: {disk_usage["used_gb"]} GB
- 可用空间: {disk_usage["free_gb"]} GB
- 磁盘使用率: {disk_usage["percent"]}%
"""
return output_string
if __name__ == "__main__":
print(f"System Info: {get_system_info()}")
print(f"Python Version: {get_python_version()}")
print(f"CPU Usage: {get_cpu_usage()}%")
# 第一次调用 process.cpu_percent() 会返回0.0或一个无意义的值,需要间隔一段时间再调用
# 或者在初始化Process对象后先调用一次cpu_percent(interval=None)然后再调用cpu_percent(interval=1)
current_process = psutil.Process(os.getpid())
current_process.cpu_percent(interval=None) # 初始化
print(f"Process CPU Usage: {current_process.cpu_percent(interval=1)}%") # 实际获取
memory_usage_info = get_memory_usage()
print(
f"Memory Usage: Total={memory_usage_info['total_mb']}MB, Used={memory_usage_info['used_mb']}MB, Percent={memory_usage_info['percent']}%"
)
process_memory_info = get_process_memory_usage()
print(
f"Process Memory Usage: RSS={process_memory_info['rss_mb']}MB, VMS={process_memory_info['vms_mb']}MB, Percent={process_memory_info['percent']}%"
)
disk_usage_info = get_disk_usage("/")
print(
f"Disk Usage (Root): Total={disk_usage_info['total_gb']}GB, Used={disk_usage_info['used_gb']}GB, Percent={disk_usage_info['percent']}%"
)
print("\n--- All Basic Info (JSON) ---")
all_info = get_all_basic_info()
import json
print(json.dumps(all_info, indent=4, ensure_ascii=False))
print("\n--- All Basic Info (String with Explanations) ---")
info_string = get_all_basic_info_string()
print(info_string)

View File

@@ -1,317 +0,0 @@
from typing import List, Optional, Dict, Any
import strawberry
# from packaging.version import Version
import os
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
@strawberry.type
class APIBotConfig:
"""机器人配置类"""
INNER_VERSION: str # 配置文件内部版本号toml为字符串
MAI_VERSION: str # 硬编码的版本信息
# bot
BOT_QQ: Optional[int] # 机器人QQ号
BOT_NICKNAME: Optional[str] # 机器人昵称
BOT_ALIAS_NAMES: List[str] # 机器人别名列表
# group
talk_allowed_groups: List[int] # 允许回复消息的群号列表
talk_frequency_down_groups: List[int] # 降低回复频率的群号列表
ban_user_id: List[int] # 禁止回复和读取消息的QQ号列表
# personality
personality_core: str # 人格核心特点描述
personality_sides: List[str] # 人格细节描述列表
# identity
identity_detail: List[str] # 身份特点列表
age: int # 年龄(岁)
gender: str # 性别
appearance: str # 外貌特征描述
# platforms
platforms: Dict[str, str] # 平台信息
# chat
allow_focus_mode: bool # 是否允许专注聊天状态
base_normal_chat_num: int # 最多允许多少个群进行普通聊天
base_focused_chat_num: int # 最多允许多少个群进行专注聊天
observation_context_size: int # 观察到的最长上下文大小
message_buffer: bool # 是否启用消息缓冲
ban_words: List[str] # 禁止词列表
ban_msgs_regex: List[str] # 禁止消息的正则表达式列表
# normal_chat
model_reasoning_probability: float # 推理模型概率
model_normal_probability: float # 普通模型概率
emoji_chance: float # 表情符号出现概率
thinking_timeout: int # 思考超时时间
willing_mode: str # 意愿模式
response_interested_rate_amplifier: float # 回复兴趣率放大器
emoji_response_penalty: float # 表情回复惩罚
mentioned_bot_inevitable_reply: bool # 提及 bot 必然回复
at_bot_inevitable_reply: bool # @bot 必然回复
# focus_chat
reply_trigger_threshold: float # 回复触发阈值
default_decay_rate_per_second: float # 默认每秒衰减率
# compressed
compressed_length: int # 压缩长度
compress_length_limit: int # 压缩长度限制
# emoji
max_emoji_num: int # 最大表情符号数量
max_reach_deletion: bool # 达到最大数量时是否删除
check_interval: int # 检查表情包的时间间隔(分钟)
save_emoji: bool # 是否保存表情包
steal_emoji: bool # 是否偷取表情包
enable_check: bool # 是否启用表情包过滤
check_prompt: str # 表情包过滤要求
# memory
build_memory_interval: int # 记忆构建间隔
build_memory_distribution: List[float] # 记忆构建分布
build_memory_sample_num: int # 采样数量
build_memory_sample_length: int # 采样长度
memory_compress_rate: float # 记忆压缩率
forget_memory_interval: int # 记忆遗忘间隔
memory_forget_time: int # 记忆遗忘时间(小时)
memory_forget_percentage: float # 记忆遗忘比例
consolidate_memory_interval: int # 记忆整合间隔
consolidation_similarity_threshold: float # 相似度阈值
consolidation_check_percentage: float # 检查节点比例
memory_ban_words: List[str] # 记忆禁止词列表
# mood
mood_update_interval: float # 情绪更新间隔
mood_decay_rate: float # 情绪衰减率
mood_intensity_factor: float # 情绪强度因子
# keywords_reaction
keywords_reaction_enable: bool # 是否启用关键词反应
keywords_reaction_rules: List[Dict[str, Any]] # 关键词反应规则
# chinese_typo
chinese_typo_enable: bool # 是否启用中文错别字
chinese_typo_error_rate: float # 中文错别字错误率
chinese_typo_min_freq: int # 中文错别字最小频率
chinese_typo_tone_error_rate: float # 中文错别字声调错误率
chinese_typo_word_replace_rate: float # 中文错别字单词替换率
# response_splitter
enable_response_splitter: bool # 是否启用回复分割器
response_max_length: int # 回复最大长度
response_max_sentence_num: int # 回复最大句子数
enable_kaomoji_protection: bool # 是否启用颜文字保护
model_max_output_length: int # 模型最大输出长度
# remote
remote_enable: bool # 是否启用远程功能
# experimental
enable_friend_chat: bool # 是否启用好友聊天
talk_allowed_private: List[int] # 允许私聊的QQ号列表
pfc_chatting: bool # 是否启用PFC聊天
# 模型配置
llm_reasoning: Dict[str, Any] # 推理模型配置
llm_normal: Dict[str, Any] # 普通模型配置
llm_topic_judge: Dict[str, Any] # 主题判断模型配置
summary: Dict[str, Any] # 总结模型配置
vlm: Dict[str, Any] # VLM模型配置
llm_heartflow: Dict[str, Any] # 心流模型配置
llm_observation: Dict[str, Any] # 观察模型配置
llm_sub_heartflow: Dict[str, Any] # 子心流模型配置
llm_plan: Optional[Dict[str, Any]] # 计划模型配置
embedding: Dict[str, Any] # 嵌入模型配置
llm_PFC_action_planner: Optional[Dict[str, Any]] # PFC行动计划模型配置
llm_PFC_chat: Optional[Dict[str, Any]] # PFC聊天模型配置
llm_PFC_reply_checker: Optional[Dict[str, Any]] # PFC回复检查模型配置
llm_tool_use: Optional[Dict[str, Any]] # 工具使用模型配置
api_urls: Optional[Dict[str, str]] # API地址配置
@staticmethod
def validate_config(config: dict):
"""
校验传入的 toml 配置字典是否合法。
:param config: toml库load后的配置字典
:raises: ValueError, KeyError, TypeError
"""
# 检查主层级
required_sections = [
"inner",
"bot",
"groups",
"personality",
"identity",
"platforms",
"chat",
"normal_chat",
"focus_chat",
"emoji",
"memory",
"mood",
"keywords_reaction",
"chinese_typo",
"response_splitter",
"remote",
"experimental",
"model",
]
for section in required_sections:
if section not in config:
raise KeyError(f"缺少配置段: [{section}]")
# 检查部分关键字段
if "version" not in config["inner"]:
raise KeyError("缺少 inner.version 字段")
if not isinstance(config["inner"]["version"], str):
raise TypeError("inner.version 必须为字符串")
if "qq" not in config["bot"]:
raise KeyError("缺少 bot.qq 字段")
if not isinstance(config["bot"]["qq"], int):
raise TypeError("bot.qq 必须为整数")
if "personality_core" not in config["personality"]:
raise KeyError("缺少 personality.personality_core 字段")
if not isinstance(config["personality"]["personality_core"], str):
raise TypeError("personality.personality_core 必须为字符串")
if "identity_detail" not in config["identity"]:
raise KeyError("缺少 identity.identity_detail 字段")
if not isinstance(config["identity"]["identity_detail"], list):
raise TypeError("identity.identity_detail 必须为列表")
# 可继续添加更多字段的类型和值检查
# ...
# 检查模型配置
model_keys = [
"llm_reasoning",
"llm_normal",
"llm_topic_judge",
"summary",
"vlm",
"llm_heartflow",
"llm_observation",
"llm_sub_heartflow",
"embedding",
]
if "model" not in config:
raise KeyError("缺少 [model] 配置段")
for key in model_keys:
if key not in config["model"]:
raise KeyError(f"缺少 model.{key} 配置")
# 检查通过
return True
@strawberry.type
class APIEnvConfig:
"""环境变量配置"""
HOST: str # 服务主机地址
PORT: int # 服务端口
PLUGINS: List[str] # 插件列表
MONGODB_HOST: str # MongoDB 主机地址
MONGODB_PORT: int # MongoDB 端口
DATABASE_NAME: str # 数据库名称
CHAT_ANY_WHERE_BASE_URL: str # ChatAnywhere 基础URL
SILICONFLOW_BASE_URL: str # SiliconFlow 基础URL
DEEP_SEEK_BASE_URL: str # DeepSeek 基础URL
DEEP_SEEK_KEY: Optional[str] # DeepSeek API Key
CHAT_ANY_WHERE_KEY: Optional[str] # ChatAnywhere API Key
SILICONFLOW_KEY: Optional[str] # SiliconFlow API Key
SIMPLE_OUTPUT: Optional[bool] # 是否简化输出
CONSOLE_LOG_LEVEL: Optional[str] # 控制台日志等级
FILE_LOG_LEVEL: Optional[str] # 文件日志等级
DEFAULT_CONSOLE_LOG_LEVEL: Optional[str] # 默认控制台日志等级
DEFAULT_FILE_LOG_LEVEL: Optional[str] # 默认文件日志等级
@strawberry.field
def get_env(self) -> str:
return "env"
@staticmethod
def validate_config(config: dict):
"""
校验环境变量配置字典是否合法。
:param config: 环境变量配置字典
:raises: KeyError, TypeError
"""
required_fields = [
"HOST",
"PORT",
"PLUGINS",
"MONGODB_HOST",
"MONGODB_PORT",
"DATABASE_NAME",
"CHAT_ANY_WHERE_BASE_URL",
"SILICONFLOW_BASE_URL",
"DEEP_SEEK_BASE_URL",
]
for field in required_fields:
if field not in config:
raise KeyError(f"缺少环境变量配置字段: {field}")
if not isinstance(config["HOST"], str):
raise TypeError("HOST 必须为字符串")
if not isinstance(config["PORT"], int):
raise TypeError("PORT 必须为整数")
if not isinstance(config["PLUGINS"], list):
raise TypeError("PLUGINS 必须为列表")
if not isinstance(config["MONGODB_HOST"], str):
raise TypeError("MONGODB_HOST 必须为字符串")
if not isinstance(config["MONGODB_PORT"], int):
raise TypeError("MONGODB_PORT 必须为整数")
if not isinstance(config["DATABASE_NAME"], str):
raise TypeError("DATABASE_NAME 必须为字符串")
if not isinstance(config["CHAT_ANY_WHERE_BASE_URL"], str):
raise TypeError("CHAT_ANY_WHERE_BASE_URL 必须为字符串")
if not isinstance(config["SILICONFLOW_BASE_URL"], str):
raise TypeError("SILICONFLOW_BASE_URL 必须为字符串")
if not isinstance(config["DEEP_SEEK_BASE_URL"], str):
raise TypeError("DEEP_SEEK_BASE_URL 必须为字符串")
# 可选字段类型检查
optional_str_fields = [
"DEEP_SEEK_KEY",
"CHAT_ANY_WHERE_KEY",
"SILICONFLOW_KEY",
"CONSOLE_LOG_LEVEL",
"FILE_LOG_LEVEL",
"DEFAULT_CONSOLE_LOG_LEVEL",
"DEFAULT_FILE_LOG_LEVEL",
]
for field in optional_str_fields:
if field in config and config[field] is not None and not isinstance(config[field], str):
raise TypeError(f"{field} 必须为字符串或None")
if (
"SIMPLE_OUTPUT" in config
and config["SIMPLE_OUTPUT"] is not None
and not isinstance(config["SIMPLE_OUTPUT"], bool)
):
raise TypeError("SIMPLE_OUTPUT 必须为布尔值或None")
# 检查通过
return True
print("当前路径:")
print(ROOT_PATH)

View File

@@ -1,22 +0,0 @@
import strawberry
from fastapi import FastAPI
from strawberry.fastapi import GraphQLRouter
from src.common.server import get_global_server
@strawberry.type
class Query:
@strawberry.field
def hello(self) -> str:
return "Hello World"
schema = strawberry.Schema(Query)
graphql_app = GraphQLRouter(schema)
fast_api_app: FastAPI = get_global_server().get_app()
fast_api_app.include_router(graphql_app, prefix="/graphql")

View File

@@ -1 +0,0 @@
pass

View File

@@ -1,112 +0,0 @@
from fastapi import APIRouter
from strawberry.fastapi import GraphQLRouter
import os
import sys
# from src.chat.heart_flow.heartflow import heartflow
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
# from src.config.config import BotConfig
from src.common.logger import get_logger
from src.api.reload_config import reload_config as reload_config_func
from src.common.server import get_global_server
from src.api.apiforgui import (
get_all_subheartflow_ids,
forced_change_subheartflow_status,
get_subheartflow_cycle_info,
get_all_states,
)
from src.chat.heart_flow.sub_heartflow import ChatState
from src.api.basic_info_api import get_all_basic_info # 新增导入
router = APIRouter()
logger = get_logger("api")
logger.info("麦麦API服务器已启动")
graphql_router = GraphQLRouter(schema=None, path="/") # Replace `None` with your actual schema
router.include_router(graphql_router, prefix="/graphql", tags=["GraphQL"])
@router.post("/config/reload")
async def reload_config():
return await reload_config_func()
@router.get("/gui/subheartflow/get/all")
async def get_subheartflow_ids():
"""获取所有子心流的ID列表"""
return await get_all_subheartflow_ids()
@router.post("/gui/subheartflow/forced_change_status")
async def forced_change_subheartflow_status_api(subheartflow_id: str, status: ChatState): # noqa
"""强制改变子心流的状态"""
# 参数检查
if not isinstance(status, ChatState):
logger.warning(f"无效的状态参数: {status}")
return {"status": "failed", "reason": "invalid status"}
logger.info(f"尝试将子心流 {subheartflow_id} 状态更改为 {status.value}")
success = await forced_change_subheartflow_status(subheartflow_id, status)
if success:
logger.info(f"子心流 {subheartflow_id} 状态更改为 {status.value} 成功")
return {"status": "success"}
else:
logger.error(f"子心流 {subheartflow_id} 状态更改为 {status.value} 失败")
return {"status": "failed"}
@router.get("/stop")
async def force_stop_maibot():
"""强制停止MAI Bot"""
from bot import request_shutdown
success = await request_shutdown()
if success:
logger.info("MAI Bot已强制停止")
return {"status": "success"}
else:
logger.error("MAI Bot强制停止失败")
return {"status": "failed"}
@router.get("/gui/subheartflow/cycleinfo")
async def get_subheartflow_cycle_info_api(subheartflow_id: str, history_len: int):
"""获取子心流的循环信息"""
cycle_info = await get_subheartflow_cycle_info(subheartflow_id, history_len)
if cycle_info:
return {"status": "success", "data": cycle_info}
else:
logger.warning(f"子心流 {subheartflow_id} 循环信息未找到")
return {"status": "failed", "reason": "subheartflow not found"}
@router.get("/gui/get_all_states")
async def get_all_states_api():
"""获取所有状态"""
all_states = await get_all_states()
if all_states:
return {"status": "success", "data": all_states}
else:
logger.warning("获取所有状态失败")
return {"status": "failed", "reason": "failed to get all states"}
@router.get("/info")
async def get_system_basic_info():
"""获取系统基本信息"""
logger.info("请求系统基本信息")
try:
info = get_all_basic_info()
return {"status": "success", "data": info}
except Exception as e:
logger.error(f"获取系统基本信息失败: {e}")
return {"status": "failed", "reason": str(e)}
def start_api_server():
"""启动API服务器"""
get_global_server().register_router(router, prefix="/api/v1")
# pass

View File

@@ -1,24 +0,0 @@
from fastapi import HTTPException
from rich.traceback import install
from src.config.config import get_config_dir, load_config
from src.common.logger import get_logger
import os
install(extra_lines=3)
logger = get_logger("api")
async def reload_config():
try:
from src.config import config as config_module
logger.debug("正在重载配置文件...")
bot_config_path = os.path.join(get_config_dir(), "bot_config.toml")
config_module.global_config = load_config(config_path=bot_config_path)
logger.debug("配置文件重载成功")
return {"status": "reloaded"}
except FileNotFoundError as e:
raise HTTPException(status_code=404, detail=str(e)) from e
except Exception as e:
raise HTTPException(status_code=500, detail=f"重载配置时发生错误: {str(e)}") from e

View File

@@ -5,20 +5,19 @@ import os
import random import random
import time import time
import traceback import traceback
from typing import Optional, Tuple, List, Any
from PIL import Image
import io import io
import re import re
import binascii
# from gradio_client import file from typing import Optional, Tuple, List, Any
from PIL import Image
from rich.traceback import install
from src.common.database.database_model import Emoji from src.common.database.database_model import Emoji
from src.common.database.database import db as peewee_db from src.common.database.database import db as peewee_db
from src.common.logger import get_logger
from src.config.config import global_config from src.config.config import global_config
from src.chat.utils.utils_image import image_path_to_base64, get_image_manager from src.chat.utils.utils_image import image_path_to_base64, get_image_manager
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger
from rich.traceback import install
install(extra_lines=3) install(extra_lines=3)
@@ -26,7 +25,7 @@ logger = get_logger("emoji")
BASE_DIR = os.path.join("data") BASE_DIR = os.path.join("data")
EMOJI_DIR = os.path.join(BASE_DIR, "emoji") # 表情包存储目录 EMOJI_DIR = os.path.join(BASE_DIR, "emoji") # 表情包存储目录
EMOJI_REGISTED_DIR = os.path.join(BASE_DIR, "emoji_registed") # 已注册的表情包注册目录 EMOJI_REGISTERED_DIR = os.path.join(BASE_DIR, "emoji_registed") # 已注册的表情包注册目录
MAX_EMOJI_FOR_PROMPT = 20 # 最大允许的表情包描述数量于图片替换的 prompt 中 MAX_EMOJI_FOR_PROMPT = 20 # 最大允许的表情包描述数量于图片替换的 prompt 中
""" """
@@ -85,7 +84,7 @@ class MaiEmoji:
logger.debug(f"[初始化] 正在使用Pillow获取格式: {self.filename}") logger.debug(f"[初始化] 正在使用Pillow获取格式: {self.filename}")
try: try:
with Image.open(io.BytesIO(image_bytes)) as img: with Image.open(io.BytesIO(image_bytes)) as img:
self.format = img.format.lower() self.format = img.format.lower() # type: ignore
logger.debug(f"[初始化] 格式获取成功: {self.format}") logger.debug(f"[初始化] 格式获取成功: {self.format}")
except Exception as pil_error: except Exception as pil_error:
logger.error(f"[初始化错误] Pillow无法处理图片 ({self.filename}): {pil_error}") logger.error(f"[初始化错误] Pillow无法处理图片 ({self.filename}): {pil_error}")
@@ -100,7 +99,7 @@ class MaiEmoji:
logger.error(f"[初始化错误] 文件在处理过程中丢失: {self.full_path}") logger.error(f"[初始化错误] 文件在处理过程中丢失: {self.full_path}")
self.is_deleted = True self.is_deleted = True
return None return None
except base64.binascii.Error as b64_error: except (binascii.Error, ValueError) as b64_error:
logger.error(f"[初始化错误] Base64解码失败 ({self.filename}): {b64_error}") logger.error(f"[初始化错误] Base64解码失败 ({self.filename}): {b64_error}")
self.is_deleted = True self.is_deleted = True
return None return None
@@ -113,7 +112,7 @@ class MaiEmoji:
async def register_to_db(self) -> bool: async def register_to_db(self) -> bool:
""" """
注册表情包 注册表情包
将表情包对应的文件从当前路径移动到EMOJI_REGISTED_DIR目录下 将表情包对应的文件从当前路径移动到EMOJI_REGISTERED_DIR目录下
并修改对应的实例属性,然后将表情包信息保存到数据库中 并修改对应的实例属性,然后将表情包信息保存到数据库中
""" """
try: try:
@@ -122,7 +121,7 @@ class MaiEmoji:
# 源路径是当前实例的完整路径 self.full_path # 源路径是当前实例的完整路径 self.full_path
source_full_path = self.full_path source_full_path = self.full_path
# 目标完整路径 # 目标完整路径
destination_full_path = os.path.join(EMOJI_REGISTED_DIR, self.filename) destination_full_path = os.path.join(EMOJI_REGISTERED_DIR, self.filename)
# 检查源文件是否存在 # 检查源文件是否存在
if not os.path.exists(source_full_path): if not os.path.exists(source_full_path):
@@ -139,7 +138,7 @@ class MaiEmoji:
logger.debug(f"[移动] 文件从 {source_full_path} 移动到 {destination_full_path}") logger.debug(f"[移动] 文件从 {source_full_path} 移动到 {destination_full_path}")
# 更新实例的路径属性为新路径 # 更新实例的路径属性为新路径
self.full_path = destination_full_path self.full_path = destination_full_path
self.path = EMOJI_REGISTED_DIR self.path = EMOJI_REGISTERED_DIR
# self.filename 保持不变 # self.filename 保持不变
except Exception as move_error: except Exception as move_error:
logger.error(f"[错误] 移动文件失败: {str(move_error)}") logger.error(f"[错误] 移动文件失败: {str(move_error)}")
@@ -202,7 +201,7 @@ class MaiEmoji:
try: try:
will_delete_emoji = Emoji.get(Emoji.emoji_hash == self.hash) will_delete_emoji = Emoji.get(Emoji.emoji_hash == self.hash)
result = will_delete_emoji.delete_instance() # Returns the number of rows deleted. result = will_delete_emoji.delete_instance() # Returns the number of rows deleted.
except Emoji.DoesNotExist: except Emoji.DoesNotExist: # type: ignore
logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。") logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。")
result = 0 # Indicate no DB record was deleted result = 0 # Indicate no DB record was deleted
@@ -298,7 +297,7 @@ def _to_emoji_objects(data: Any) -> Tuple[List["MaiEmoji"], int]:
def _ensure_emoji_dir() -> None: def _ensure_emoji_dir() -> None:
"""确保表情存储目录存在""" """确保表情存储目录存在"""
os.makedirs(EMOJI_DIR, exist_ok=True) os.makedirs(EMOJI_DIR, exist_ok=True)
os.makedirs(EMOJI_REGISTED_DIR, exist_ok=True) os.makedirs(EMOJI_REGISTERED_DIR, exist_ok=True)
async def clear_temp_emoji() -> None: async def clear_temp_emoji() -> None:
@@ -331,10 +330,10 @@ async def clean_unused_emojis(emoji_dir: str, emoji_objects: List["MaiEmoji"], r
logger.warning(f"[清理] 目标目录不存在,跳过清理: {emoji_dir}") logger.warning(f"[清理] 目标目录不存在,跳过清理: {emoji_dir}")
return removed_count return removed_count
cleaned_count = 0
try: try:
# 获取内存中所有有效表情包的完整路径集合 # 获取内存中所有有效表情包的完整路径集合
tracked_full_paths = {emoji.full_path for emoji in emoji_objects if not emoji.is_deleted} tracked_full_paths = {emoji.full_path for emoji in emoji_objects if not emoji.is_deleted}
cleaned_count = 0
# 遍历指定目录中的所有文件 # 遍历指定目录中的所有文件
for file_name in os.listdir(emoji_dir): for file_name in os.listdir(emoji_dir):
@@ -358,11 +357,11 @@ async def clean_unused_emojis(emoji_dir: str, emoji_objects: List["MaiEmoji"], r
else: else:
logger.info(f"[清理] 目录 {emoji_dir} 中没有需要清理的。") logger.info(f"[清理] 目录 {emoji_dir} 中没有需要清理的。")
return removed_count + cleaned_count
except Exception as e: except Exception as e:
logger.error(f"[错误] 清理未使用表情包文件时出错 ({emoji_dir}): {str(e)}") logger.error(f"[错误] 清理未使用表情包文件时出错 ({emoji_dir}): {str(e)}")
return removed_count + cleaned_count
class EmojiManager: class EmojiManager:
_instance = None _instance = None
@@ -414,7 +413,7 @@ class EmojiManager:
emoji_update.usage_count += 1 emoji_update.usage_count += 1
emoji_update.last_used_time = time.time() # Update last used time emoji_update.last_used_time = time.time() # Update last used time
emoji_update.save() # Persist changes to DB emoji_update.save() # Persist changes to DB
except Emoji.DoesNotExist: except Emoji.DoesNotExist: # type: ignore
logger.error(f"记录表情使用失败: 未找到 hash 为 {emoji_hash} 的表情包") logger.error(f"记录表情使用失败: 未找到 hash 为 {emoji_hash} 的表情包")
except Exception as e: except Exception as e:
logger.error(f"记录表情使用失败: {str(e)}") logger.error(f"记录表情使用失败: {str(e)}")
@@ -570,8 +569,8 @@ class EmojiManager:
if objects_to_remove: if objects_to_remove:
self.emoji_objects = [e for e in self.emoji_objects if e not in objects_to_remove] self.emoji_objects = [e for e in self.emoji_objects if e not in objects_to_remove]
# 清理 EMOJI_REGISTED_DIR 目录中未被追踪的文件 # 清理 EMOJI_REGISTERED_DIR 目录中未被追踪的文件
removed_count = await clean_unused_emojis(EMOJI_REGISTED_DIR, self.emoji_objects, removed_count) removed_count = await clean_unused_emojis(EMOJI_REGISTERED_DIR, self.emoji_objects, removed_count)
# 输出清理结果 # 输出清理结果
if removed_count > 0: if removed_count > 0:
@@ -850,11 +849,13 @@ class EmojiManager:
if isinstance(image_base64, str): if isinstance(image_base64, str):
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii") image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
image_bytes = base64.b64decode(image_base64) image_bytes = base64.b64decode(image_base64)
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
# 调用AI获取描述 # 调用AI获取描述
if image_format == "gif" or image_format == "GIF": if image_format == "gif" or image_format == "GIF":
image_base64 = get_image_manager().transform_gif(image_base64) image_base64 = get_image_manager().transform_gif(image_base64) # type: ignore
if not image_base64:
raise RuntimeError("GIF表情包转换失败")
prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析" prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, "jpg") description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, "jpg")
else: else:

View File

@@ -1,14 +1,16 @@
from .exprssion_learner import get_expression_learner
import random
from typing import List, Dict, Tuple
from json_repair import repair_json
import json import json
import os import os
import time import time
import random
from typing import List, Dict, Tuple, Optional
from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config from src.config.config import global_config
from src.common.logger import get_logger from src.common.logger import get_logger
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from .exprssion_learner import get_expression_learner
logger = get_logger("expression_selector") logger = get_logger("expression_selector")
@@ -165,7 +167,12 @@ class ExpressionSelector:
logger.error(f"批量更新表达方式count失败 for {file_path}: {e}") logger.error(f"批量更新表达方式count失败 for {file_path}: {e}")
async def select_suitable_expressions_llm( async def select_suitable_expressions_llm(
self, chat_id: str, chat_info: str, max_num: int = 10, min_num: int = 5, target_message: str = None self,
chat_id: str,
chat_info: str,
max_num: int = 10,
min_num: int = 5,
target_message: Optional[str] = None,
) -> List[Dict[str, str]]: ) -> List[Dict[str, str]]:
"""使用LLM选择适合的表达方式""" """使用LLM选择适合的表达方式"""

View File

@@ -1,14 +1,16 @@
import time import time
import random import random
import json
import os
from typing import List, Dict, Optional, Any, Tuple from typing import List, Dict, Optional, Any, Tuple
from src.common.logger import get_logger from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config from src.config.config import global_config
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_random, build_anonymous_messages from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_random, build_anonymous_messages
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
import os
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager
import json
MAX_EXPRESSION_COUNT = 300 MAX_EXPRESSION_COUNT = 300
@@ -74,7 +76,8 @@ class ExpressionLearner:
) )
self.llm_model = None self.llm_model = None
def get_expression_by_chat_id(self, chat_id: str) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]: def get_expression_by_chat_id(self, chat_id: str) -> Tuple[List[Dict[str, float]], List[Dict[str, float]]]:
# sourcery skip: extract-duplicate-method, remove-unnecessary-cast
""" """
获取指定chat_id的style和grammar表达方式 获取指定chat_id的style和grammar表达方式
返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作 返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作
@@ -119,10 +122,10 @@ class ExpressionLearner:
min_len = min(len(s1), len(s2)) min_len = min(len(s1), len(s2))
if min_len < 5: if min_len < 5:
return False return False
same = sum(1 for a, b in zip(s1, s2) if a == b) same = sum(a == b for a, b in zip(s1, s2))
return same / min_len > 0.8 return same / min_len > 0.8
async def learn_and_store_expression(self) -> List[Tuple[str, str, str]]: async def learn_and_store_expression(self) -> Tuple[List[Tuple[str, str, str]], List[Tuple[str, str, str]]]:
""" """
学习并存储表达方式,分别学习语言风格和句法特点 学习并存储表达方式,分别学习语言风格和句法特点
同时对所有已存储的表达方式进行全局衰减 同时对所有已存储的表达方式进行全局衰减
@@ -158,12 +161,12 @@ class ExpressionLearner:
for _ in range(3): for _ in range(3):
learnt_style: Optional[List[Tuple[str, str, str]]] = await self.learn_and_store(type="style", num=25) learnt_style: Optional[List[Tuple[str, str, str]]] = await self.learn_and_store(type="style", num=25)
if not learnt_style: if not learnt_style:
return [] return [], []
for _ in range(1): for _ in range(1):
learnt_grammar: Optional[List[Tuple[str, str, str]]] = await self.learn_and_store(type="grammar", num=10) learnt_grammar: Optional[List[Tuple[str, str, str]]] = await self.learn_and_store(type="grammar", num=10)
if not learnt_grammar: if not learnt_grammar:
return [] return [], []
return learnt_style, learnt_grammar return learnt_style, learnt_grammar
@@ -214,6 +217,7 @@ class ExpressionLearner:
return result return result
async def learn_and_store(self, type: str, num: int = 10) -> List[Tuple[str, str, str]]: async def learn_and_store(self, type: str, num: int = 10) -> List[Tuple[str, str, str]]:
# sourcery skip: use-join
""" """
选择从当前到最近1小时内的随机num条消息然后学习这些消息的表达方式 选择从当前到最近1小时内的随机num条消息然后学习这些消息的表达方式
type: "style" or "grammar" type: "style" or "grammar"
@@ -249,7 +253,7 @@ class ExpressionLearner:
return [] return []
# 按chat_id分组 # 按chat_id分组
chat_dict: Dict[str, List[Dict[str, str]]] = {} chat_dict: Dict[str, List[Dict[str, Any]]] = {}
for chat_id, situation, style in learnt_expressions: for chat_id, situation, style in learnt_expressions:
if chat_id not in chat_dict: if chat_id not in chat_dict:
chat_dict[chat_id] = [] chat_dict[chat_id] = []

View File

@@ -1,10 +1,10 @@
# 定义了来自外部世界的信息 # 定义了来自外部世界的信息
# 外部世界可以是某个聊天 不同平台的聊天 也可以是任意媒体 # 外部世界可以是某个聊天 不同平台的聊天 也可以是任意媒体
from datetime import datetime from datetime import datetime
from typing import List
from src.common.logger import get_logger from src.common.logger import get_logger
from src.chat.focus_chat.hfc_utils import CycleDetail from src.chat.focus_chat.hfc_utils import CycleDetail
from typing import List
# Import the new utility function
logger = get_logger("loop_info") logger = get_logger("loop_info")

View File

@@ -8,7 +8,7 @@ from rich.traceback import install
from src.config.config import global_config from src.config.config import global_config
from src.common.logger import get_logger from src.common.logger import get_logger
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.utils.prompt_builder import global_prompt_manager from src.chat.utils.prompt_builder import global_prompt_manager
from src.chat.utils.timer_calculator import Timer from src.chat.utils.timer_calculator import Timer
from src.chat.planner_actions.planner import ActionPlanner from src.chat.planner_actions.planner import ActionPlanner
@@ -49,7 +49,9 @@ class HeartFChatting:
""" """
# 基础属性 # 基础属性
self.stream_id: str = chat_id # 聊天流ID self.stream_id: str = chat_id # 聊天流ID
self.chat_stream = get_chat_manager().get_stream(self.stream_id) self.chat_stream: ChatStream = get_chat_manager().get_stream(self.stream_id) # type: ignore
if not self.chat_stream:
raise ValueError(f"无法找到聊天流: {self.stream_id}")
self.log_prefix = f"[{get_chat_manager().get_stream_name(self.stream_id) or self.stream_id}]" self.log_prefix = f"[{get_chat_manager().get_stream_name(self.stream_id) or self.stream_id}]"
self.relationship_builder = relationship_builder_manager.get_or_create_builder(self.stream_id) self.relationship_builder = relationship_builder_manager.get_or_create_builder(self.stream_id)
@@ -171,7 +173,7 @@ class HeartFChatting:
# 执行规划和处理阶段 # 执行规划和处理阶段
try: try:
async with self._get_cycle_context(): async with self._get_cycle_context():
thinking_id = "tid" + str(round(time.time(), 2)) thinking_id = f"tid{str(round(time.time(), 2))}"
self._current_cycle_detail.set_thinking_id(thinking_id) self._current_cycle_detail.set_thinking_id(thinking_id)
# 使用异步上下文管理器处理消息 # 使用异步上下文管理器处理消息
@@ -245,7 +247,7 @@ class HeartFChatting:
logger.info( logger.info(
f"{self.log_prefix}{self._current_cycle_detail.cycle_id}次思考," f"{self.log_prefix}{self._current_cycle_detail.cycle_id}次思考,"
f"耗时: {self._current_cycle_detail.end_time - self._current_cycle_detail.start_time:.1f}秒, " f"耗时: {self._current_cycle_detail.end_time - self._current_cycle_detail.start_time:.1f}秒, " # type: ignore
f"选择动作: {self._current_cycle_detail.loop_plan_info.get('action_result', {}).get('action_type', '未知动作')}" f"选择动作: {self._current_cycle_detail.loop_plan_info.get('action_result', {}).get('action_type', '未知动作')}"
+ (f"\n详情: {'; '.join(timer_strings)}" if timer_strings else "") + (f"\n详情: {'; '.join(timer_strings)}" if timer_strings else "")
) )
@@ -256,7 +258,7 @@ class HeartFChatting:
cycle_performance_data = { cycle_performance_data = {
"cycle_id": self._current_cycle_detail.cycle_id, "cycle_id": self._current_cycle_detail.cycle_id,
"action_type": action_result.get("action_type", "unknown"), "action_type": action_result.get("action_type", "unknown"),
"total_time": self._current_cycle_detail.end_time - self._current_cycle_detail.start_time, "total_time": self._current_cycle_detail.end_time - self._current_cycle_detail.start_time, # type: ignore
"step_times": cycle_timers.copy(), "step_times": cycle_timers.copy(),
"reasoning": action_result.get("reasoning", ""), "reasoning": action_result.get("reasoning", ""),
"success": self._current_cycle_detail.loop_action_info.get("action_taken", False), "success": self._current_cycle_detail.loop_action_info.get("action_taken", False),
@@ -447,9 +449,6 @@ class HeartFChatting:
# 处理动作并获取结果 # 处理动作并获取结果
result = await action_handler.handle_action() result = await action_handler.handle_action()
if len(result) == 3:
success, reply_text, command = result
else:
success, reply_text = result success, reply_text = result
command = "" command = ""
@@ -478,8 +477,7 @@ class HeartFChatting:
) )
# 设置系统命令,在下次循环检查时触发退出 # 设置系统命令,在下次循环检查时触发退出
command = "stop_focus_chat" command = "stop_focus_chat"
else: elif reply_text == "timeout":
if reply_text == "timeout":
self.reply_timeout_count += 1 self.reply_timeout_count += 1
if self.reply_timeout_count > 5: if self.reply_timeout_count > 5:
logger.warning( logger.warning(

View File

@@ -2,6 +2,7 @@ import json
from datetime import datetime from datetime import datetime
from typing import Dict, Any from typing import Dict, Any
from pathlib import Path from pathlib import Path
from src.common.logger import get_logger from src.common.logger import get_logger
logger = get_logger("hfc_performance") logger = get_logger("hfc_performance")

View File

@@ -1,11 +1,12 @@
import time import time
from typing import Optional import json
from typing import Optional, Dict, Any
from src.chat.message_receive.message import MessageRecv, BaseMessageInfo from src.chat.message_receive.message import MessageRecv, BaseMessageInfo
from src.chat.message_receive.chat_stream import ChatStream from src.chat.message_receive.chat_stream import ChatStream
from src.chat.message_receive.message import UserInfo from src.chat.message_receive.message import UserInfo
from src.common.logger import get_logger from src.common.logger import get_logger
import json
from typing import Dict, Any
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -117,7 +118,7 @@ async def create_empty_anchor_message(
placeholder_msg_info = BaseMessageInfo( placeholder_msg_info = BaseMessageInfo(
message_id=placeholder_id, message_id=placeholder_id,
platform=platform, platform=platform,
group_info=group_info, group_info=group_info, # type: ignore
user_info=placeholder_user, user_info=placeholder_user,
time=time.time(), time=time.time(),
) )

View File

@@ -1,7 +1,7 @@
from src.chat.heart_flow.sub_heartflow import SubHeartflow, ChatState from typing import Any, Optional, Dict
from src.common.logger import get_logger from src.common.logger import get_logger
from typing import Any, Optional from src.chat.heart_flow.sub_heartflow import SubHeartflow, ChatState
from typing import Dict
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager
logger = get_logger("heartflow") logger = get_logger("heartflow")
@@ -34,7 +34,7 @@ class Heartflow:
logger.error(f"创建子心流 {subheartflow_id} 失败: {e}", exc_info=True) logger.error(f"创建子心流 {subheartflow_id} 失败: {e}", exc_info=True)
return None return None
async def force_change_subheartflow_status(self, subheartflow_id: str, status: ChatState) -> None: async def force_change_subheartflow_status(self, subheartflow_id: str, status: ChatState) -> bool:
"""强制改变子心流的状态""" """强制改变子心流的状态"""
# 这里的 message 是可选的,可能是一个消息对象,也可能是其他类型的数据 # 这里的 message 是可选的,可能是一个消息对象,也可能是其他类型的数据
return await self.force_change_state(subheartflow_id, status) return await self.force_change_state(subheartflow_id, status)

View File

@@ -1,21 +1,21 @@
from src.chat.memory_system.Hippocampus import hippocampus_manager
from src.config.config import global_config
import asyncio import asyncio
import re
import math
import traceback
from typing import Tuple
from src.config.config import global_config
from src.chat.memory_system.Hippocampus import hippocampus_manager
from src.chat.message_receive.message import MessageRecv from src.chat.message_receive.message import MessageRecv
from src.chat.message_receive.storage import MessageStorage from src.chat.message_receive.storage import MessageStorage
from src.chat.heart_flow.heartflow import heartflow from src.chat.heart_flow.heartflow import heartflow
from src.chat.utils.utils import is_mentioned_bot_in_message from src.chat.utils.utils import is_mentioned_bot_in_message
from src.chat.utils.timer_calculator import Timer from src.chat.utils.timer_calculator import Timer
from src.common.logger import get_logger from src.common.logger import get_logger
import re
import math
import traceback
from typing import Tuple
from src.person_info.relationship_manager import get_relationship_manager from src.person_info.relationship_manager import get_relationship_manager
from src.mood.mood_manager import mood_manager from src.mood.mood_manager import mood_manager
logger = get_logger("chat") logger = get_logger("chat")
@@ -26,16 +26,16 @@ async def _process_relationship(message: MessageRecv) -> None:
message: 消息对象,包含用户信息 message: 消息对象,包含用户信息
""" """
platform = message.message_info.platform platform = message.message_info.platform
user_id = message.message_info.user_info.user_id user_id = message.message_info.user_info.user_id # type: ignore
nickname = message.message_info.user_info.user_nickname nickname = message.message_info.user_info.user_nickname # type: ignore
cardname = message.message_info.user_info.user_cardname or nickname cardname = message.message_info.user_info.user_cardname or nickname # type: ignore
relationship_manager = get_relationship_manager() relationship_manager = get_relationship_manager()
is_known = await relationship_manager.is_known_some_one(platform, user_id) is_known = await relationship_manager.is_known_some_one(platform, user_id)
if not is_known: if not is_known:
logger.info(f"首次认识用户: {nickname}") logger.info(f"首次认识用户: {nickname}")
await relationship_manager.first_knowing_some_one(platform, user_id, nickname, cardname) await relationship_manager.first_knowing_some_one(platform, user_id, nickname, cardname) # type: ignore
async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]: async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]:
@@ -105,9 +105,9 @@ class HeartFCMessageReceiver:
# 2. 兴趣度计算与更新 # 2. 兴趣度计算与更新
interested_rate, is_mentioned = await _calculate_interest(message) interested_rate, is_mentioned = await _calculate_interest(message)
subheartflow.add_message_to_normal_chat_cache(message, interested_rate, is_mentioned) subheartflow.add_message_to_normal_chat_cache(message, interested_rate, is_mentioned) # type: ignore
chat_mood = mood_manager.get_mood_by_chat_id(subheartflow.chat_id) chat_mood = mood_manager.get_mood_by_chat_id(subheartflow.chat_id) # type: ignore
asyncio.create_task(chat_mood.update_mood_by_message(message, interested_rate)) asyncio.create_task(chat_mood.update_mood_by_message(message, interested_rate))
# 3. 日志记录 # 3. 日志记录
@@ -119,7 +119,7 @@ class HeartFCMessageReceiver:
picid_pattern = r"\[picid:([^\]]+)\]" picid_pattern = r"\[picid:([^\]]+)\]"
processed_plain_text = re.sub(picid_pattern, "[图片]", message.processed_plain_text) processed_plain_text = re.sub(picid_pattern, "[图片]", message.processed_plain_text)
logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}") logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}") # type: ignore
logger.debug(f"[{mes_name}][当前时段回复频率: {current_talk_frequency}]") logger.debug(f"[{mes_name}][当前时段回复频率: {current_talk_frequency}]")

View File

@@ -1,16 +1,18 @@
import asyncio import asyncio
import time import time
from typing import Optional, List, Dict, Tuple
import traceback import traceback
from typing import Optional, List, Dict, Tuple
from rich.traceback import install
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.message_receive.message import MessageRecv from src.chat.message_receive.message import MessageRecv
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.focus_chat.heartFC_chat import HeartFChatting from src.chat.focus_chat.heartFC_chat import HeartFChatting
from src.chat.normal_chat.normal_chat import NormalChat from src.chat.normal_chat.normal_chat import NormalChat
from src.chat.heart_flow.chat_state_info import ChatState, ChatStateInfo from src.chat.heart_flow.chat_state_info import ChatState, ChatStateInfo
from src.chat.utils.utils import get_chat_type_and_target_info from src.chat.utils.utils import get_chat_type_and_target_info
from src.config.config import global_config
from rich.traceback import install
logger = get_logger("sub_heartflow") logger = get_logger("sub_heartflow")
@@ -40,7 +42,7 @@ class SubHeartflow:
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_id) self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_id)
self.log_prefix = get_chat_manager().get_stream_name(self.subheartflow_id) or self.subheartflow_id self.log_prefix = get_chat_manager().get_stream_name(self.subheartflow_id) or self.subheartflow_id
# 兴趣消息集合 # 兴趣消息集合
self.interest_dict: Dict[str, tuple[MessageRecv, float, bool]] = {} self.interest_dict: Dict[str, Tuple[MessageRecv, float, bool]] = {}
# focus模式退出冷却时间管理 # focus模式退出冷却时间管理
self.last_focus_exit_time: float = 0 # 上次退出focus模式的时间 self.last_focus_exit_time: float = 0 # 上次退出focus模式的时间
@@ -297,7 +299,7 @@ class SubHeartflow:
) )
def add_message_to_normal_chat_cache(self, message: MessageRecv, interest_value: float, is_mentioned: bool): def add_message_to_normal_chat_cache(self, message: MessageRecv, interest_value: float, is_mentioned: bool):
self.interest_dict[message.message_info.message_id] = (message, interest_value, is_mentioned) self.interest_dict[message.message_info.message_id] = (message, interest_value, is_mentioned) # type: ignore
# 如果字典长度超过10删除最旧的消息 # 如果字典长度超过10删除最旧的消息
if len(self.interest_dict) > 30: if len(self.interest_dict) > 30:
oldest_key = next(iter(self.interest_dict)) oldest_key = next(iter(self.interest_dict))

View File

@@ -42,7 +42,7 @@ def calculate_information_content(text):
return entropy return entropy
def cosine_similarity(v1, v2): def cosine_similarity(v1, v2): # sourcery skip: assign-if-exp, reintroduce-else
"""计算余弦相似度""" """计算余弦相似度"""
dot_product = np.dot(v1, v2) dot_product = np.dot(v1, v2)
norm1 = np.linalg.norm(v1) norm1 = np.linalg.norm(v1)
@@ -89,13 +89,12 @@ class MemoryGraph:
if not isinstance(self.G.nodes[concept]["memory_items"], list): if not isinstance(self.G.nodes[concept]["memory_items"], list):
self.G.nodes[concept]["memory_items"] = [self.G.nodes[concept]["memory_items"]] self.G.nodes[concept]["memory_items"] = [self.G.nodes[concept]["memory_items"]]
self.G.nodes[concept]["memory_items"].append(memory) self.G.nodes[concept]["memory_items"].append(memory)
# 更新最后修改时间
self.G.nodes[concept]["last_modified"] = current_time
else: else:
self.G.nodes[concept]["memory_items"] = [memory] self.G.nodes[concept]["memory_items"] = [memory]
# 如果节点存在但没有memory_items,说明是第一次添加memory,设置created_time # 如果节点存在但没有memory_items,说明是第一次添加memory,设置created_time
if "created_time" not in self.G.nodes[concept]: if "created_time" not in self.G.nodes[concept]:
self.G.nodes[concept]["created_time"] = current_time self.G.nodes[concept]["created_time"] = current_time
# 更新最后修改时间
self.G.nodes[concept]["last_modified"] = current_time self.G.nodes[concept]["last_modified"] = current_time
else: else:
# 如果是新节点,创建新的记忆列表 # 如果是新节点,创建新的记忆列表
@@ -108,11 +107,7 @@ class MemoryGraph:
def get_dot(self, concept): def get_dot(self, concept):
# 检查节点是否存在于图中 # 检查节点是否存在于图中
if concept in self.G: return (concept, self.G.nodes[concept]) if concept in self.G else None
# 从图中获取节点数据
node_data = self.G.nodes[concept]
return concept, node_data
return None
def get_related_item(self, topic, depth=1): def get_related_item(self, topic, depth=1):
if topic not in self.G: if topic not in self.G:
@@ -139,8 +134,7 @@ class MemoryGraph:
if depth >= 2: if depth >= 2:
# 获取相邻节点的记忆项 # 获取相邻节点的记忆项
for neighbor in neighbors: for neighbor in neighbors:
node_data = self.get_dot(neighbor) if node_data := self.get_dot(neighbor):
if node_data:
concept, data = node_data concept, data = node_data
if "memory_items" in data: if "memory_items" in data:
memory_items = data["memory_items"] memory_items = data["memory_items"]
@@ -194,9 +188,9 @@ class MemoryGraph:
class Hippocampus: class Hippocampus:
def __init__(self): def __init__(self):
self.memory_graph = MemoryGraph() self.memory_graph = MemoryGraph()
self.model_summary = None self.model_summary: LLMRequest = None # type: ignore
self.entorhinal_cortex = None self.entorhinal_cortex: EntorhinalCortex = None # type: ignore
self.parahippocampal_gyrus = None self.parahippocampal_gyrus: ParahippocampalGyrus = None # type: ignore
def initialize(self): def initialize(self):
# 初始化子组件 # 初始化子组件
@@ -218,7 +212,7 @@ class Hippocampus:
memory_items = [memory_items] if memory_items else [] memory_items = [memory_items] if memory_items else []
# 使用集合来去重,避免排序 # 使用集合来去重,避免排序
unique_items = set(str(item) for item in memory_items) unique_items = {str(item) for item in memory_items}
# 使用frozenset来保证顺序一致性 # 使用frozenset来保证顺序一致性
content = f"{concept}:{frozenset(unique_items)}" content = f"{concept}:{frozenset(unique_items)}"
return hash(content) return hash(content)
@@ -231,6 +225,7 @@ class Hippocampus:
@staticmethod @staticmethod
def find_topic_llm(text, topic_num): def find_topic_llm(text, topic_num):
# sourcery skip: inline-immediately-returned-variable
prompt = ( prompt = (
f"这是一段文字:\n{text}\n\n请你从这段话中总结出最多{topic_num}个关键的概念,可以是名词,动词,或者特定人物,帮我列出来," f"这是一段文字:\n{text}\n\n请你从这段话中总结出最多{topic_num}个关键的概念,可以是名词,动词,或者特定人物,帮我列出来,"
f"将主题用逗号隔开,并加上<>,例如<主题1>,<主题2>......尽可能精简。只需要列举最多{topic_num}个话题就好,不要有序号,不要告诉我其他内容。" f"将主题用逗号隔开,并加上<>,例如<主题1>,<主题2>......尽可能精简。只需要列举最多{topic_num}个话题就好,不要有序号,不要告诉我其他内容。"
@@ -240,6 +235,7 @@ class Hippocampus:
@staticmethod @staticmethod
def topic_what(text, topic): def topic_what(text, topic):
# sourcery skip: inline-immediately-returned-variable
# 不再需要 time_info 参数 # 不再需要 time_info 参数
prompt = ( prompt = (
f'这是一段文字:\n{text}\n\n我想让你基于这段文字来概括"{topic}"这个概念,帮我总结成一句自然的话,' f'这是一段文字:\n{text}\n\n我想让你基于这段文字来概括"{topic}"这个概念,帮我总结成一句自然的话,'
@@ -480,9 +476,7 @@ class Hippocampus:
top_memories = memory_similarities[:max_memory_length] top_memories = memory_similarities[:max_memory_length]
# 添加到结果中 # 添加到结果中
for memory, similarity in top_memories: all_memories.extend((node, [memory], similarity) for memory, similarity in top_memories)
all_memories.append((node, [memory], similarity))
# logger.info(f"选中记忆: {memory} (相似度: {similarity:.2f})")
else: else:
logger.info("节点没有记忆") logger.info("节点没有记忆")
@@ -646,9 +640,7 @@ class Hippocampus:
top_memories = memory_similarities[:max_memory_length] top_memories = memory_similarities[:max_memory_length]
# 添加到结果中 # 添加到结果中
for memory, similarity in top_memories: all_memories.extend((node, [memory], similarity) for memory, similarity in top_memories)
all_memories.append((node, [memory], similarity))
# logger.info(f"选中记忆: {memory} (相似度: {similarity:.2f})")
else: else:
logger.info("节点没有记忆") logger.info("节点没有记忆")
@@ -823,11 +815,11 @@ class EntorhinalCortex:
logger.debug(f"回忆往事: {readable_timestamp}") logger.debug(f"回忆往事: {readable_timestamp}")
chat_samples = [] chat_samples = []
for timestamp in timestamps: for timestamp in timestamps:
# 调用修改后的 random_get_msg_snippet if messages := self.random_get_msg_snippet(
messages = self.random_get_msg_snippet( timestamp,
timestamp, global_config.memory.memory_build_sample_length, max_memorized_time_per_msg global_config.memory.memory_build_sample_length,
) max_memorized_time_per_msg,
if messages: ):
time_diff = (datetime.datetime.now().timestamp() - timestamp) / 3600 time_diff = (datetime.datetime.now().timestamp() - timestamp) / 3600
logger.info(f"成功抽取 {time_diff:.1f} 小时前的消息样本,共{len(messages)}") logger.info(f"成功抽取 {time_diff:.1f} 小时前的消息样本,共{len(messages)}")
chat_samples.append(messages) chat_samples.append(messages)
@@ -838,6 +830,7 @@ class EntorhinalCortex:
@staticmethod @staticmethod
def random_get_msg_snippet(target_timestamp: float, chat_size: int, max_memorized_time_per_msg: int) -> list | None: def random_get_msg_snippet(target_timestamp: float, chat_size: int, max_memorized_time_per_msg: int) -> list | None:
# sourcery skip: invert-any-all, use-any, use-named-expression, use-next
"""从数据库中随机获取指定时间戳附近的消息片段 (使用 chat_message_builder)""" """从数据库中随机获取指定时间戳附近的消息片段 (使用 chat_message_builder)"""
try_count = 0 try_count = 0
time_window_seconds = random.randint(300, 1800) # 随机时间窗口5到30分钟 time_window_seconds = random.randint(300, 1800) # 随机时间窗口5到30分钟
@@ -847,22 +840,21 @@ class EntorhinalCortex:
timestamp_start = target_timestamp timestamp_start = target_timestamp
timestamp_end = target_timestamp + time_window_seconds timestamp_end = target_timestamp + time_window_seconds
chosen_message = get_raw_msg_by_timestamp( if chosen_message := get_raw_msg_by_timestamp(
timestamp_start=timestamp_start, timestamp_end=timestamp_end, limit=1, limit_mode="earliest" timestamp_start=timestamp_start,
) timestamp_end=timestamp_end,
limit=1,
limit_mode="earliest",
):
chat_id: str = chosen_message[0].get("chat_id") # type: ignore
if chosen_message: if messages := get_raw_msg_by_timestamp_with_chat(
chat_id = chosen_message[0].get("chat_id")
messages = get_raw_msg_by_timestamp_with_chat(
timestamp_start=timestamp_start, timestamp_start=timestamp_start,
timestamp_end=timestamp_end, timestamp_end=timestamp_end,
limit=chat_size, limit=chat_size,
limit_mode="earliest", limit_mode="earliest",
chat_id=chat_id, chat_id=chat_id,
) ):
if messages:
# 检查获取到的所有消息是否都未达到最大记忆次数 # 检查获取到的所有消息是否都未达到最大记忆次数
all_valid = True all_valid = True
for message in messages: for message in messages:
@@ -975,7 +967,7 @@ class EntorhinalCortex:
).execute() ).execute()
if nodes_to_delete: if nodes_to_delete:
GraphNodes.delete().where(GraphNodes.concept.in_(nodes_to_delete)).execute() GraphNodes.delete().where(GraphNodes.concept.in_(nodes_to_delete)).execute() # type: ignore
# 处理边的信息 # 处理边的信息
db_edges = list(GraphEdges.select()) db_edges = list(GraphEdges.select())
@@ -1114,7 +1106,7 @@ class EntorhinalCortex:
node_start = time.time() node_start = time.time()
if nodes_data: if nodes_data:
batch_size = 500 # 增加批量大小 batch_size = 500 # 增加批量大小
with GraphNodes._meta.database.atomic(): with GraphNodes._meta.database.atomic(): # type: ignore
for i in range(0, len(nodes_data), batch_size): for i in range(0, len(nodes_data), batch_size):
batch = nodes_data[i : i + batch_size] batch = nodes_data[i : i + batch_size]
GraphNodes.insert_many(batch).execute() GraphNodes.insert_many(batch).execute()
@@ -1125,7 +1117,7 @@ class EntorhinalCortex:
edge_start = time.time() edge_start = time.time()
if edges_data: if edges_data:
batch_size = 500 # 增加批量大小 batch_size = 500 # 增加批量大小
with GraphEdges._meta.database.atomic(): with GraphEdges._meta.database.atomic(): # type: ignore
for i in range(0, len(edges_data), batch_size): for i in range(0, len(edges_data), batch_size):
batch = edges_data[i : i + batch_size] batch = edges_data[i : i + batch_size]
GraphEdges.insert_many(batch).execute() GraphEdges.insert_many(batch).execute()
@@ -1489,9 +1481,7 @@ class ParahippocampalGyrus:
# --- 如果节点不为空,则执行原来的不活跃检查和随机移除逻辑 --- # --- 如果节点不为空,则执行原来的不活跃检查和随机移除逻辑 ---
last_modified = node_data.get("last_modified", current_time) last_modified = node_data.get("last_modified", current_time)
# 条件1检查是否长时间未修改 (超过24小时) # 条件1检查是否长时间未修改 (超过24小时)
if current_time - last_modified > 3600 * 24: if current_time - last_modified > 3600 * 24 and memory_items:
# 条件2再次确认节点包含记忆项理论上已确认但作为保险
if memory_items:
current_count = len(memory_items) current_count = len(memory_items)
# 如果列表非空,才进行随机选择 # 如果列表非空,才进行随机选择
if current_count > 0: if current_count > 0:
@@ -1669,7 +1659,7 @@ class ParahippocampalGyrus:
class HippocampusManager: class HippocampusManager:
def __init__(self): def __init__(self):
self._hippocampus = None self._hippocampus: Hippocampus = None # type: ignore
self._initialized = False self._initialized = False
def initialize(self): def initialize(self):

View File

@@ -13,7 +13,7 @@ from json_repair import repair_json
logger = get_logger("memory_activator") logger = get_logger("memory_activator")
def get_keywords_from_json(json_str): def get_keywords_from_json(json_str) -> List:
""" """
从JSON字符串中提取关键词列表 从JSON字符串中提取关键词列表
@@ -28,15 +28,8 @@ def get_keywords_from_json(json_str):
fixed_json = repair_json(json_str) fixed_json = repair_json(json_str)
# 如果repair_json返回的是字符串需要解析为Python对象 # 如果repair_json返回的是字符串需要解析为Python对象
if isinstance(fixed_json, str): result = json.loads(fixed_json) if isinstance(fixed_json, str) else fixed_json
result = json.loads(fixed_json) return result.get("keywords", [])
else:
# 如果repair_json直接返回了字典对象直接使用
result = fixed_json
# 提取关键词
keywords = result.get("keywords", [])
return keywords
except Exception as e: except Exception as e:
logger.error(f"解析关键词JSON失败: {e}") logger.error(f"解析关键词JSON失败: {e}")
return [] return []

View File

@@ -1,52 +1,10 @@
import numpy as np import numpy as np
from scipy import stats
from datetime import datetime, timedelta from datetime import datetime, timedelta
from rich.traceback import install from rich.traceback import install
install(extra_lines=3) install(extra_lines=3)
class DistributionVisualizer:
def __init__(self, mean=0, std=1, skewness=0, sample_size=10):
"""
初始化分布可视化器
参数:
mean (float): 期望均值
std (float): 标准差
skewness (float): 偏度
sample_size (int): 样本大小
"""
self.mean = mean
self.std = std
self.skewness = skewness
self.sample_size = sample_size
self.samples = None
def generate_samples(self):
"""生成具有指定参数的样本"""
if self.skewness == 0:
# 对于无偏度的情况,直接使用正态分布
self.samples = np.random.normal(loc=self.mean, scale=self.std, size=self.sample_size)
else:
# 使用 scipy.stats 生成具有偏度的分布
self.samples = stats.skewnorm.rvs(a=self.skewness, loc=self.mean, scale=self.std, size=self.sample_size)
def get_weighted_samples(self):
"""获取加权后的样本数列"""
if self.samples is None:
self.generate_samples()
# 将样本值乘以样本大小
return self.samples * self.sample_size
def get_statistics(self):
"""获取分布的统计信息"""
if self.samples is None:
self.generate_samples()
return {"均值": np.mean(self.samples), "标准差": np.std(self.samples), "实际偏度": stats.skew(self.samples)}
class MemoryBuildScheduler: class MemoryBuildScheduler:
def __init__(self, n_hours1, std_hours1, weight1, n_hours2, std_hours2, weight2, total_samples=50): def __init__(self, n_hours1, std_hours1, weight1, n_hours2, std_hours2, weight2, total_samples=50):
""" """

View File

@@ -1,23 +1,25 @@
import traceback import traceback
import os import os
import re
from typing import Dict, Any from typing import Dict, Any
from maim_message import UserInfo
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config
from src.mood.mood_manager import mood_manager # 导入情绪管理器 from src.mood.mood_manager import mood_manager # 导入情绪管理器
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream
from src.chat.message_receive.message import MessageRecv from src.chat.message_receive.message import MessageRecv
from src.experimental.only_message_process import MessageProcessor
from src.chat.message_receive.storage import MessageStorage from src.chat.message_receive.storage import MessageStorage
from src.experimental.PFC.pfc_manager import PFCManager
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.config.config import global_config from src.experimental.only_message_process import MessageProcessor
from src.experimental.PFC.pfc_manager import PFCManager
from src.plugin_system.core.component_registry import component_registry # 导入新插件系统 from src.plugin_system.core.component_registry import component_registry # 导入新插件系统
from src.plugin_system.base.base_command import BaseCommand from src.plugin_system.base.base_command import BaseCommand
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
from maim_message import UserInfo
from src.chat.message_receive.chat_stream import ChatStream
import re
# 定义日志配置 # 定义日志配置
# 获取项目根目录假设本文件在src/chat/message_receive/下,根目录为上上上级目录) # 获取项目根目录假设本文件在src/chat/message_receive/下,根目录为上上上级目录)
@@ -184,8 +186,8 @@ class ChatBot:
get_chat_manager().register_message(message) get_chat_manager().register_message(message)
chat = await get_chat_manager().get_or_create_stream( chat = await get_chat_manager().get_or_create_stream(
platform=message.message_info.platform, platform=message.message_info.platform, # type: ignore
user_info=user_info, user_info=user_info, # type: ignore
group_info=group_info, group_info=group_info,
) )
@@ -195,8 +197,10 @@ class ChatBot:
await message.process() await message.process()
# 过滤检查 # 过滤检查
if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore
message.raw_message, chat, user_info message.raw_message, # type: ignore
chat,
user_info, # type: ignore
): ):
return return

View File

@@ -3,18 +3,17 @@ import hashlib
import time import time
import copy import copy
from typing import Dict, Optional, TYPE_CHECKING from typing import Dict, Optional, TYPE_CHECKING
from rich.traceback import install
from ...common.database.database import db
from ...common.database.database_model import ChatStreams # 新增导入
from maim_message import GroupInfo, UserInfo from maim_message import GroupInfo, UserInfo
from src.common.logger import get_logger
from src.common.database.database import db
from src.common.database.database_model import ChatStreams # 新增导入
# 避免循环导入使用TYPE_CHECKING进行类型提示 # 避免循环导入使用TYPE_CHECKING进行类型提示
if TYPE_CHECKING: if TYPE_CHECKING:
from .message import MessageRecv from .message import MessageRecv
from src.common.logger import get_logger
from rich.traceback import install
install(extra_lines=3) install(extra_lines=3)
@@ -28,7 +27,7 @@ class ChatMessageContext:
def __init__(self, message: "MessageRecv"): def __init__(self, message: "MessageRecv"):
self.message = message self.message = message
def get_template_name(self) -> str: def get_template_name(self) -> Optional[str]:
"""获取模板名称""" """获取模板名称"""
if self.message.message_info.template_info and not self.message.message_info.template_info.template_default: if self.message.message_info.template_info and not self.message.message_info.template_info.template_default:
return self.message.message_info.template_info.template_name return self.message.message_info.template_info.template_name
@@ -41,10 +40,10 @@ class ChatMessageContext:
def check_types(self, types: list) -> bool: def check_types(self, types: list) -> bool:
# sourcery skip: invert-any-all, use-any, use-next # sourcery skip: invert-any-all, use-any, use-next
"""检查消息类型""" """检查消息类型"""
if not self.message.message_info.format_info.accept_format: if not self.message.message_info.format_info.accept_format: # type: ignore
return False return False
for t in types: for t in types:
if t not in self.message.message_info.format_info.accept_format: if t not in self.message.message_info.format_info.accept_format: # type: ignore
return False return False
return True return True
@@ -68,7 +67,7 @@ class ChatStream:
platform: str, platform: str,
user_info: UserInfo, user_info: UserInfo,
group_info: Optional[GroupInfo] = None, group_info: Optional[GroupInfo] = None,
data: dict = None, data: Optional[dict] = None,
): ):
self.stream_id = stream_id self.stream_id = stream_id
self.platform = platform self.platform = platform
@@ -77,7 +76,7 @@ class ChatStream:
self.create_time = data.get("create_time", time.time()) if data else time.time() self.create_time = data.get("create_time", time.time()) if data else time.time()
self.last_active_time = data.get("last_active_time", self.create_time) if data else self.create_time self.last_active_time = data.get("last_active_time", self.create_time) if data else self.create_time
self.saved = False self.saved = False
self.context: ChatMessageContext = None # 用于存储该聊天的上下文信息 self.context: ChatMessageContext = None # type: ignore # 用于存储该聊天的上下文信息
def to_dict(self) -> dict: def to_dict(self) -> dict:
"""转换为字典格式""" """转换为字典格式"""
@@ -99,7 +98,7 @@ class ChatStream:
return cls( return cls(
stream_id=data["stream_id"], stream_id=data["stream_id"],
platform=data["platform"], platform=data["platform"],
user_info=user_info, user_info=user_info, # type: ignore
group_info=group_info, group_info=group_info,
data=data, data=data,
) )
@@ -163,8 +162,8 @@ class ChatManager:
def register_message(self, message: "MessageRecv"): def register_message(self, message: "MessageRecv"):
"""注册消息到聊天流""" """注册消息到聊天流"""
stream_id = self._generate_stream_id( stream_id = self._generate_stream_id(
message.message_info.platform, message.message_info.platform, # type: ignore
message.message_info.user_info, message.message_info.user_info, # type: ignore
message.message_info.group_info, message.message_info.group_info,
) )
self.last_messages[stream_id] = message self.last_messages[stream_id] = message
@@ -185,10 +184,7 @@ class ChatManager:
def get_stream_id(self, platform: str, id: str, is_group: bool = True) -> str: def get_stream_id(self, platform: str, id: str, is_group: bool = True) -> str:
"""获取聊天流ID""" """获取聊天流ID"""
if is_group: components = [platform, id] if is_group else [platform, id, "private"]
components = [platform, str(id)]
else:
components = [platform, str(id), "private"]
key = "_".join(components) key = "_".join(components)
return hashlib.md5(key.encode()).hexdigest() return hashlib.md5(key.encode()).hexdigest()

View File

@@ -1,17 +1,15 @@
import time import time
from abc import abstractmethod
from dataclasses import dataclass
from typing import Optional, Any, TYPE_CHECKING
import urllib3 import urllib3
from src.common.logger import get_logger from abc import abstractmethod
from dataclasses import dataclass
if TYPE_CHECKING:
from .chat_stream import ChatStream
from ..utils.utils_image import get_image_manager
from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
from rich.traceback import install from rich.traceback import install
from typing import Optional, Any
from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
from src.common.logger import get_logger
from src.chat.utils.utils_image import get_image_manager
from .chat_stream import ChatStream
install(extra_lines=3) install(extra_lines=3)
@@ -27,7 +25,7 @@ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
@dataclass @dataclass
class Message(MessageBase): class Message(MessageBase):
chat_stream: "ChatStream" = None chat_stream: "ChatStream" = None # type: ignore
reply: Optional["Message"] = None reply: Optional["Message"] = None
processed_plain_text: str = "" processed_plain_text: str = ""
memorized_times: int = 0 memorized_times: int = 0
@@ -55,7 +53,7 @@ class Message(MessageBase):
) )
# 调用父类初始化 # 调用父类初始化
super().__init__(message_info=message_info, message_segment=message_segment, raw_message=None) super().__init__(message_info=message_info, message_segment=message_segment, raw_message=None) # type: ignore
self.chat_stream = chat_stream self.chat_stream = chat_stream
# 文本处理相关属性 # 文本处理相关属性
@@ -66,6 +64,7 @@ class Message(MessageBase):
self.reply = reply self.reply = reply
async def _process_message_segments(self, segment: Seg) -> str: async def _process_message_segments(self, segment: Seg) -> str:
# sourcery skip: remove-unnecessary-else, swap-if-else-branches
"""递归处理消息段,转换为文字描述 """递归处理消息段,转换为文字描述
Args: Args:
@@ -78,13 +77,13 @@ class Message(MessageBase):
# 处理消息段列表 # 处理消息段列表
segments_text = [] segments_text = []
for seg in segment.data: for seg in segment.data:
processed = await self._process_message_segments(seg) processed = await self._process_message_segments(seg) # type: ignore
if processed: if processed:
segments_text.append(processed) segments_text.append(processed)
return " ".join(segments_text) return " ".join(segments_text)
else: else:
# 处理单个消息段 # 处理单个消息段
return await self._process_single_segment(segment) return await self._process_single_segment(segment) # type: ignore
@abstractmethod @abstractmethod
async def _process_single_segment(self, segment): async def _process_single_segment(self, segment):
@@ -138,7 +137,7 @@ class MessageRecv(Message):
if segment.type == "text": if segment.type == "text":
self.is_picid = False self.is_picid = False
self.is_emoji = False self.is_emoji = False
return segment.data return segment.data # type: ignore
elif segment.type == "image": elif segment.type == "image":
# 如果是base64图片数据 # 如果是base64图片数据
if isinstance(segment.data, str): if isinstance(segment.data, str):
@@ -160,7 +159,7 @@ class MessageRecv(Message):
elif segment.type == "mention_bot": elif segment.type == "mention_bot":
self.is_picid = False self.is_picid = False
self.is_emoji = False self.is_emoji = False
self.is_mentioned = float(segment.data) self.is_mentioned = float(segment.data) # type: ignore
return "" return ""
elif segment.type == "priority_info": elif segment.type == "priority_info":
self.is_picid = False self.is_picid = False
@@ -186,7 +185,7 @@ class MessageRecv(Message):
"""生成详细文本,包含时间和用户信息""" """生成详细文本,包含时间和用户信息"""
timestamp = self.message_info.time timestamp = self.message_info.time
user_info = self.message_info.user_info user_info = self.message_info.user_info
name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>" name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>" # type: ignore
return f"[{timestamp}] {name}: {self.processed_plain_text}\n" return f"[{timestamp}] {name}: {self.processed_plain_text}\n"
@@ -234,7 +233,7 @@ class MessageProcessBase(Message):
""" """
try: try:
if seg.type == "text": if seg.type == "text":
return seg.data return seg.data # type: ignore
elif seg.type == "image": elif seg.type == "image":
# 如果是base64图片数据 # 如果是base64图片数据
if isinstance(seg.data, str): if isinstance(seg.data, str):
@@ -250,7 +249,7 @@ class MessageProcessBase(Message):
if self.reply and hasattr(self.reply, "processed_plain_text"): if self.reply and hasattr(self.reply, "processed_plain_text"):
# print(f"self.reply.processed_plain_text: {self.reply.processed_plain_text}") # print(f"self.reply.processed_plain_text: {self.reply.processed_plain_text}")
# print(f"reply: {self.reply}") # print(f"reply: {self.reply}")
return f"[回复<{self.reply.message_info.user_info.user_nickname}:{self.reply.message_info.user_info.user_id}> 的消息:{self.reply.processed_plain_text}]" return f"[回复<{self.reply.message_info.user_info.user_nickname}:{self.reply.message_info.user_info.user_id}> 的消息:{self.reply.processed_plain_text}]" # type: ignore
return None return None
else: else:
return f"[{seg.type}:{str(seg.data)}]" return f"[{seg.type}:{str(seg.data)}]"
@@ -264,7 +263,7 @@ class MessageProcessBase(Message):
timestamp = self.message_info.time timestamp = self.message_info.time
user_info = self.message_info.user_info user_info = self.message_info.user_info
name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>" name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>" # type: ignore
return f"[{timestamp}]{name} 说:{self.processed_plain_text}\n" return f"[{timestamp}]{name} 说:{self.processed_plain_text}\n"
@@ -313,7 +312,7 @@ class MessageSending(MessageProcessBase):
is_emoji: bool = False, is_emoji: bool = False,
thinking_start_time: float = 0, thinking_start_time: float = 0,
apply_set_reply_logic: bool = False, apply_set_reply_logic: bool = False,
reply_to: str = None, reply_to: str = None, # type: ignore
): ):
# 调用父类初始化 # 调用父类初始化
super().__init__( super().__init__(
@@ -344,7 +343,7 @@ class MessageSending(MessageProcessBase):
self.message_segment = Seg( self.message_segment = Seg(
type="seglist", type="seglist",
data=[ data=[
Seg(type="reply", data=self.reply.message_info.message_id), Seg(type="reply", data=self.reply.message_info.message_id), # type: ignore
self.message_segment, self.message_segment,
], ],
) )
@@ -364,10 +363,10 @@ class MessageSending(MessageProcessBase):
) -> "MessageSending": ) -> "MessageSending":
"""从思考状态消息创建发送状态消息""" """从思考状态消息创建发送状态消息"""
return cls( return cls(
message_id=thinking.message_info.message_id, message_id=thinking.message_info.message_id, # type: ignore
chat_stream=thinking.chat_stream, chat_stream=thinking.chat_stream,
message_segment=message_segment, message_segment=message_segment,
bot_user_info=thinking.message_info.user_info, bot_user_info=thinking.message_info.user_info, # type: ignore
reply=thinking.reply, reply=thinking.reply,
is_head=is_head, is_head=is_head,
is_emoji=is_emoji, is_emoji=is_emoji,
@@ -399,13 +398,11 @@ class MessageSet:
if not isinstance(message, MessageSending): if not isinstance(message, MessageSending):
raise TypeError("MessageSet只能添加MessageSending类型的消息") raise TypeError("MessageSet只能添加MessageSending类型的消息")
self.messages.append(message) self.messages.append(message)
self.messages.sort(key=lambda x: x.message_info.time) self.messages.sort(key=lambda x: x.message_info.time) # type: ignore
def get_message_by_index(self, index: int) -> Optional[MessageSending]: def get_message_by_index(self, index: int) -> Optional[MessageSending]:
"""通过索引获取消息""" """通过索引获取消息"""
if 0 <= index < len(self.messages): return self.messages[index] if 0 <= index < len(self.messages) else None
return self.messages[index]
return None
def get_message_by_time(self, target_time: float) -> Optional[MessageSending]: def get_message_by_time(self, target_time: float) -> Optional[MessageSending]:
"""获取最接近指定时间的消息""" """获取最接近指定时间的消息"""
@@ -415,7 +412,7 @@ class MessageSet:
left, right = 0, len(self.messages) - 1 left, right = 0, len(self.messages) - 1
while left < right: while left < right:
mid = (left + right) // 2 mid = (left + right) // 2
if self.messages[mid].message_info.time < target_time: if self.messages[mid].message_info.time < target_time: # type: ignore
left = mid + 1 left = mid + 1
else: else:
right = mid right = mid

View File

@@ -1,21 +1,16 @@
# src/plugins/chat/message_sender.py
import asyncio import asyncio
import time import time
from asyncio import Task from asyncio import Task
from typing import Union from typing import Union
from src.common.message.api import get_global_api
# from ...common.database import db # 数据库依赖似乎不需要了,注释掉
from .message import MessageSending, MessageThinking, MessageSet
from src.chat.message_receive.storage import MessageStorage
from ..utils.utils import truncate_message, calculate_typing_time, count_messages_between
from src.common.logger import get_logger
from rich.traceback import install from rich.traceback import install
install(extra_lines=3) from src.common.logger import get_logger
from src.common.message.api import get_global_api
from src.chat.message_receive.storage import MessageStorage
from src.chat.utils.utils import truncate_message, calculate_typing_time, count_messages_between
from .message import MessageSending, MessageThinking, MessageSet
install(extra_lines=3)
logger = get_logger("sender") logger = get_logger("sender")
@@ -79,9 +74,10 @@ class MessageContainer:
def count_thinking_messages(self) -> int: def count_thinking_messages(self) -> int:
"""计算当前容器中思考消息的数量""" """计算当前容器中思考消息的数量"""
return sum(1 for msg in self.messages if isinstance(msg, MessageThinking)) return sum(isinstance(msg, MessageThinking) for msg in self.messages)
def get_timeout_sending_messages(self) -> list[MessageSending]: def get_timeout_sending_messages(self) -> list[MessageSending]:
# sourcery skip: merge-nested-ifs
"""获取所有超时的MessageSending对象思考时间超过20秒按thinking_start_time排序 - 从旧 sender 合并""" """获取所有超时的MessageSending对象思考时间超过20秒按thinking_start_time排序 - 从旧 sender 合并"""
current_time = time.time() current_time = time.time()
timeout_messages = [] timeout_messages = []
@@ -230,9 +226,7 @@ class MessageManager:
f"[{message.chat_stream.stream_id}] 处理发送消息 {getattr(message.message_info, 'message_id', 'N/A')} 时出错: {e}" f"[{message.chat_stream.stream_id}] 处理发送消息 {getattr(message.message_info, 'message_id', 'N/A')} 时出错: {e}"
) )
logger.exception("详细错误信息:") logger.exception("详细错误信息:")
# 考虑是否移除出错的消息,防止无限循环 if container.remove_message(message):
removed = container.remove_message(message)
if removed:
logger.warning(f"[{message.chat_stream.stream_id}] 已移除处理出错的消息。") logger.warning(f"[{message.chat_stream.stream_id}] 已移除处理出错的消息。")
async def _process_chat_messages(self, chat_id: str): async def _process_chat_messages(self, chat_id: str):
@@ -261,10 +255,7 @@ class MessageManager:
# --- 处理发送消息 --- # --- 处理发送消息 ---
await self._handle_sending_message(container, message_earliest) await self._handle_sending_message(container, message_earliest)
# --- 处理超时发送消息 (来自旧 sender) --- if timeout_sending_messages := container.get_timeout_sending_messages():
# 在处理完最早的消息后,检查是否有超时的发送消息
timeout_sending_messages = container.get_timeout_sending_messages()
if timeout_sending_messages:
logger.debug(f"[{chat_id}] 发现 {len(timeout_sending_messages)} 条超时的发送消息") logger.debug(f"[{chat_id}] 发现 {len(timeout_sending_messages)} 条超时的发送消息")
for msg in timeout_sending_messages: for msg in timeout_sending_messages:
# 确保不是刚刚处理过的最早消息 (虽然理论上应该已被移除,但以防万一) # 确保不是刚刚处理过的最早消息 (虽然理论上应该已被移除,但以防万一)
@@ -274,6 +265,7 @@ class MessageManager:
await self._handle_sending_message(container, msg) # 复用处理逻辑 await self._handle_sending_message(container, msg) # 复用处理逻辑
async def _start_processor_loop(self): async def _start_processor_loop(self):
# sourcery skip: list-comprehension, move-assign-in-block, use-named-expression
"""消息处理器主循环""" """消息处理器主循环"""
while self._running: while self._running:
tasks = [] tasks = []
@@ -282,10 +274,7 @@ class MessageManager:
# 创建 keys 的快照以安全迭代 # 创建 keys 的快照以安全迭代
chat_ids = list(self.containers.keys()) chat_ids = list(self.containers.keys())
for chat_id in chat_ids: tasks.extend(asyncio.create_task(self._process_chat_messages(chat_id)) for chat_id in chat_ids)
# 为每个 chat_id 创建一个处理任务
tasks.append(asyncio.create_task(self._process_chat_messages(chat_id)))
if tasks: if tasks:
try: try:
# 等待当前批次的所有任务完成 # 等待当前批次的所有任务完成

View File

@@ -1,11 +1,10 @@
import re import re
from typing import Union from typing import Union
# from ...common.database.database import db # db is now Peewee's SqliteDatabase instance from src.common.database.database_model import Messages, RecalledMessages, Images
from .message import MessageSending, MessageRecv
from .chat_stream import ChatStream
from ...common.database.database_model import Messages, RecalledMessages, Images # Import Peewee models
from src.common.logger import get_logger from src.common.logger import get_logger
from .chat_stream import ChatStream
from .message import MessageSending, MessageRecv
logger = get_logger("message_storage") logger = get_logger("message_storage")
@@ -44,7 +43,7 @@ class MessageStorage:
reply_to = "" reply_to = ""
chat_info_dict = chat_stream.to_dict() chat_info_dict = chat_stream.to_dict()
user_info_dict = message.message_info.user_info.to_dict() user_info_dict = message.message_info.user_info.to_dict() # type: ignore
# message_id 现在是 TextField直接使用字符串值 # message_id 现在是 TextField直接使用字符串值
msg_id = message.message_info.message_id msg_id = message.message_info.message_id
@@ -56,7 +55,7 @@ class MessageStorage:
Messages.create( Messages.create(
message_id=msg_id, message_id=msg_id,
time=float(message.message_info.time), time=float(message.message_info.time), # type: ignore
chat_id=chat_stream.stream_id, chat_id=chat_stream.stream_id,
# Flattened chat_info # Flattened chat_info
reply_to=reply_to, reply_to=reply_to,
@@ -103,7 +102,7 @@ class MessageStorage:
try: try:
# Assuming input 'time' is a string timestamp that can be converted to float # Assuming input 'time' is a string timestamp that can be converted to float
current_time_float = float(time) current_time_float = float(time)
RecalledMessages.delete().where(RecalledMessages.time < (current_time_float - 300)).execute() RecalledMessages.delete().where(RecalledMessages.time < (current_time_float - 300)).execute() # type: ignore
except Exception: except Exception:
logger.exception("删除撤回消息失败") logger.exception("删除撤回消息失败")
@@ -115,22 +114,19 @@ class MessageStorage:
"""更新最新一条匹配消息的message_id""" """更新最新一条匹配消息的message_id"""
try: try:
if message.message_segment.type == "notify": if message.message_segment.type == "notify":
mmc_message_id = message.message_segment.data.get("echo") mmc_message_id = message.message_segment.data.get("echo") # type: ignore
qq_message_id = message.message_segment.data.get("actual_id") qq_message_id = message.message_segment.data.get("actual_id") # type: ignore
else: else:
logger.info(f"更新消息ID错误seg类型为{message.message_segment.type}") logger.info(f"更新消息ID错误seg类型为{message.message_segment.type}")
return return
if not qq_message_id: if not qq_message_id:
logger.info("消息不存在message_id无法更新") logger.info("消息不存在message_id无法更新")
return return
# 查询最新一条匹配消息 if matched_message := (
matched_message = (
Messages.select().where((Messages.message_id == mmc_message_id)).order_by(Messages.time.desc()).first() Messages.select().where((Messages.message_id == mmc_message_id)).order_by(Messages.time.desc()).first()
) ):
if matched_message:
# 更新找到的消息记录 # 更新找到的消息记录
Messages.update(message_id=qq_message_id).where(Messages.id == matched_message.id).execute() Messages.update(message_id=qq_message_id).where(Messages.id == matched_message.id).execute() # type: ignore
logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}") logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}")
else: else:
logger.debug("未找到匹配的消息") logger.debug("未找到匹配的消息")
@@ -155,10 +151,7 @@ class MessageStorage:
image_record = ( image_record = (
Images.select().where(Images.description == description).order_by(Images.timestamp.desc()).first() Images.select().where(Images.description == description).order_by(Images.timestamp.desc()).first()
) )
if image_record: return f"[picid:{image_record.image_id}]" if image_record else match.group(0)
return f"[picid:{image_record.image_id}]"
else:
return match.group(0) # 保持原样
except Exception: except Exception:
return match.group(0) return match.group(0)

View File

@@ -1,16 +1,17 @@
import asyncio import asyncio
from typing import Dict, Optional # 重新导入类型
from src.chat.message_receive.message import MessageSending, MessageThinking
from src.common.message.api import get_global_api
from src.chat.message_receive.storage import MessageStorage
from src.chat.utils.utils import truncate_message
from src.common.logger import get_logger
from src.chat.utils.utils import calculate_typing_time
from rich.traceback import install
import traceback import traceback
install(extra_lines=3) from typing import Dict, Optional
from rich.traceback import install
from src.common.message.api import get_global_api
from src.common.logger import get_logger
from src.chat.message_receive.message import MessageSending, MessageThinking
from src.chat.message_receive.storage import MessageStorage
from src.chat.utils.utils import truncate_message
from src.chat.utils.utils import calculate_typing_time
install(extra_lines=3)
logger = get_logger("sender") logger = get_logger("sender")
@@ -86,10 +87,10 @@ class HeartFCSender:
""" """
if not message.chat_stream: if not message.chat_stream:
logger.error("消息缺少 chat_stream无法发送") logger.error("消息缺少 chat_stream无法发送")
raise Exception("消息缺少 chat_stream无法发送") raise ValueError("消息缺少 chat_stream无法发送")
if not message.message_info or not message.message_info.message_id: if not message.message_info or not message.message_info.message_id:
logger.error("消息缺少 message_info 或 message_id无法发送") logger.error("消息缺少 message_info 或 message_id无法发送")
raise Exception("消息缺少 message_info 或 message_id无法发送") raise ValueError("消息缺少 message_info 或 message_id无法发送")
chat_id = message.chat_stream.stream_id chat_id = message.chat_stream.stream_id
message_id = message.message_info.message_id message_id = message.message_info.message_id

View File

@@ -1,6 +1,7 @@
import asyncio import asyncio
import time import time
import traceback import traceback
from random import random from random import random
from typing import List, Optional, Dict from typing import List, Optional, Dict
from maim_message import UserInfo, Seg from maim_message import UserInfo, Seg
@@ -40,7 +41,7 @@ class NormalChat:
def __init__( def __init__(
self, self,
chat_stream: ChatStream, chat_stream: ChatStream,
interest_dict: dict = None, interest_dict: Optional[Dict] = None,
on_switch_to_focus_callback=None, on_switch_to_focus_callback=None,
get_cooldown_progress_callback=None, get_cooldown_progress_callback=None,
): ):
@@ -147,10 +148,7 @@ class NormalChat:
while not self._disabled: while not self._disabled:
try: try:
if not self.priority_manager.is_empty(): if not self.priority_manager.is_empty():
# 获取最高优先级的消息 if message := self.priority_manager.get_highest_priority_message():
message = self.priority_manager.get_highest_priority_message()
if message:
logger.info( logger.info(
f"[{self.stream_name}] 从队列中取出消息进行处理: User {message.message_info.user_info.user_id}, Time: {time.strftime('%H:%M:%S', time.localtime(message.message_info.time))}" f"[{self.stream_name}] 从队列中取出消息进行处理: User {message.message_info.user_info.user_id}, Time: {time.strftime('%H:%M:%S', time.localtime(message.message_info.time))}"
) )

View File

@@ -53,7 +53,7 @@ class PriorityManager:
""" """
添加新消息到合适的队列中。 添加新消息到合适的队列中。
""" """
user_id = message.message_info.user_info.user_id user_id = message.message_info.user_info.user_id # type: ignore
is_vip = message.priority_info.get("message_type") == "vip" if message.priority_info else False is_vip = message.priority_info.get("message_type") == "vip" if message.priority_info else False
message_priority = message.priority_info.get("message_priority", 0.0) if message.priority_info else 0.0 message_priority = message.priority_info.get("message_priority", 0.0) if message.priority_info else 0.0

View File

@@ -35,9 +35,7 @@ class ClassicalWillingManager(BaseWillingManager):
self.chat_reply_willing[chat_id] = min(current_willing, 3.0) self.chat_reply_willing[chat_id] = min(current_willing, 3.0)
reply_probability = min(max((current_willing - 0.5), 0.01) * 2, 1) return min(max((current_willing - 0.5), 0.01) * 2, 1)
return reply_probability
async def before_generate_reply_handle(self, message_id): async def before_generate_reply_handle(self, message_id):
chat_id = self.ongoing_messages[message_id].chat_id chat_id = self.ongoing_messages[message_id].chat_id

View File

@@ -1,14 +1,16 @@
from src.common.logger import get_logger import importlib
import asyncio
from abc import ABC, abstractmethod
from typing import Dict, Optional
from rich.traceback import install
from dataclasses import dataclass from dataclasses import dataclass
from src.common.logger import get_logger
from src.config.config import global_config from src.config.config import global_config
from src.chat.message_receive.chat_stream import ChatStream, GroupInfo from src.chat.message_receive.chat_stream import ChatStream, GroupInfo
from src.chat.message_receive.message import MessageRecv from src.chat.message_receive.message import MessageRecv
from src.person_info.person_info import PersonInfoManager, get_person_info_manager from src.person_info.person_info import PersonInfoManager, get_person_info_manager
from abc import ABC, abstractmethod
import importlib
from typing import Dict, Optional
import asyncio
from rich.traceback import install
install(extra_lines=3) install(extra_lines=3)
@@ -92,8 +94,8 @@ class BaseWillingManager(ABC):
self.logger = logger self.logger = logger
def setup(self, message: MessageRecv, chat: ChatStream, is_mentioned_bot: bool, interested_rate: float): def setup(self, message: MessageRecv, chat: ChatStream, is_mentioned_bot: bool, interested_rate: float):
person_id = PersonInfoManager.get_person_id(chat.platform, chat.user_info.user_id) person_id = PersonInfoManager.get_person_id(chat.platform, chat.user_info.user_id) # type: ignore
self.ongoing_messages[message.message_info.message_id] = WillingInfo( self.ongoing_messages[message.message_info.message_id] = WillingInfo( # type: ignore
message=message, message=message,
chat=chat, chat=chat,
person_info_manager=get_person_info_manager(), person_info_manager=get_person_info_manager(),

View File

@@ -27,14 +27,11 @@ class ActionManager:
# 当前正在使用的动作集合,默认加载默认动作 # 当前正在使用的动作集合,默认加载默认动作
self._using_actions: Dict[str, ActionInfo] = {} self._using_actions: Dict[str, ActionInfo] = {}
# 默认动作集,仅作为快照,用于恢复默认
self._default_actions: Dict[str, ActionInfo] = {}
# 加载插件动作 # 加载插件动作
self._load_plugin_actions() self._load_plugin_actions()
# 初始化时将默认动作加载到使用中的动作 # 初始化时将默认动作加载到使用中的动作
self._using_actions = self._default_actions.copy() self._using_actions = component_registry.get_default_actions()
def _load_plugin_actions(self) -> None: def _load_plugin_actions(self) -> None:
""" """
@@ -52,7 +49,7 @@ class ActionManager:
"""从插件系统的component_registry加载Action组件""" """从插件系统的component_registry加载Action组件"""
try: try:
# 获取所有Action组件 # 获取所有Action组件
action_components: Dict[str, ActionInfo] = component_registry.get_components_by_type(ComponentType.ACTION) action_components: Dict[str, ActionInfo] = component_registry.get_components_by_type(ComponentType.ACTION) # type: ignore
for action_name, action_info in action_components.items(): for action_name, action_info in action_components.items():
if action_name in self._registered_actions: if action_name in self._registered_actions:
@@ -61,10 +58,6 @@ class ActionManager:
self._registered_actions[action_name] = action_info self._registered_actions[action_name] = action_info
# 如果启用,也添加到默认动作集
if action_info.enabled:
self._default_actions[action_name] = action_info
logger.debug( logger.debug(
f"从插件系统加载Action组件: {action_name} (插件: {getattr(action_info, 'plugin_name', 'unknown')})" f"从插件系统加载Action组件: {action_name} (插件: {getattr(action_info, 'plugin_name', 'unknown')})"
) )
@@ -106,7 +99,9 @@ class ActionManager:
""" """
try: try:
# 获取组件类 - 明确指定查询Action类型 # 获取组件类 - 明确指定查询Action类型
component_class = component_registry.get_component_class(action_name, ComponentType.ACTION) component_class: Type[BaseAction] = component_registry.get_component_class(
action_name, ComponentType.ACTION
) # type: ignore
if not component_class: if not component_class:
logger.warning(f"{log_prefix} 未找到Action组件: {action_name}") logger.warning(f"{log_prefix} 未找到Action组件: {action_name}")
return None return None
@@ -146,10 +141,6 @@ class ActionManager:
"""获取所有已注册的动作集""" """获取所有已注册的动作集"""
return self._registered_actions.copy() return self._registered_actions.copy()
def get_default_actions(self) -> Dict[str, ActionInfo]:
"""获取默认动作集"""
return self._default_actions.copy()
def get_using_actions(self) -> Dict[str, ActionInfo]: def get_using_actions(self) -> Dict[str, ActionInfo]:
"""获取当前正在使用的动作集合""" """获取当前正在使用的动作集合"""
return self._using_actions.copy() return self._using_actions.copy()
@@ -217,31 +208,31 @@ class ActionManager:
logger.debug(f"已从使用集中移除动作 {action_name}") logger.debug(f"已从使用集中移除动作 {action_name}")
return True return True
def add_action(self, action_name: str, description: str, parameters: Dict = None, require: List = None) -> bool: # def add_action(self, action_name: str, description: str, parameters: Dict = None, require: List = None) -> bool:
""" # """
添加新的动作到注册集 # 添加新的动作到注册集
Args: # Args:
action_name: 动作名称 # action_name: 动作名称
description: 动作描述 # description: 动作描述
parameters: 动作参数定义,默认为空字典 # parameters: 动作参数定义,默认为空字典
require: 动作依赖项,默认为空列表 # require: 动作依赖项,默认为空列表
Returns: # Returns:
bool: 添加是否成功 # bool: 添加是否成功
""" # """
if action_name in self._registered_actions: # if action_name in self._registered_actions:
return False # return False
if parameters is None: # if parameters is None:
parameters = {} # parameters = {}
if require is None: # if require is None:
require = [] # require = []
action_info = {"description": description, "parameters": parameters, "require": require} # action_info = {"description": description, "parameters": parameters, "require": require}
self._registered_actions[action_name] = action_info # self._registered_actions[action_name] = action_info
return True # return True
def remove_action(self, action_name: str) -> bool: def remove_action(self, action_name: str) -> bool:
"""从注册集移除指定动作""" """从注册集移除指定动作"""
@@ -260,10 +251,9 @@ class ActionManager:
def restore_actions(self) -> None: def restore_actions(self) -> None:
"""恢复到默认动作集""" """恢复到默认动作集"""
logger.debug( actions_to_restore = list(self._using_actions.keys())
f"恢复动作集: 从 {list(self._using_actions.keys())} 恢复到默认动作集 {list(self._default_actions.keys())}" self._using_actions = component_registry.get_default_actions()
) logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}")
self._using_actions = self._default_actions.copy()
def add_system_action_if_needed(self, action_name: str) -> bool: def add_system_action_if_needed(self, action_name: str) -> bool:
""" """
@@ -293,4 +283,4 @@ class ActionManager:
""" """
from src.plugin_system.core.component_registry import component_registry from src.plugin_system.core.component_registry import component_registry
return component_registry.get_component_class(action_name) return component_registry.get_component_class(action_name) # type: ignore

View File

@@ -2,7 +2,7 @@ import random
import asyncio import asyncio
import hashlib import hashlib
import time import time
from typing import List, Any, Dict from typing import List, Any, Dict, TYPE_CHECKING
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config from src.config.config import global_config
@@ -13,6 +13,9 @@ from src.chat.planner_actions.action_manager import ActionManager
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages
from src.plugin_system.base.component_types import ChatMode, ActionInfo, ActionActivationType from src.plugin_system.base.component_types import ChatMode, ActionInfo, ActionActivationType
if TYPE_CHECKING:
from src.chat.message_receive.chat_stream import ChatStream
logger = get_logger("action_manager") logger = get_logger("action_manager")
@@ -27,7 +30,7 @@ class ActionModifier:
def __init__(self, action_manager: ActionManager, chat_id: str): def __init__(self, action_manager: ActionManager, chat_id: str):
"""初始化动作处理器""" """初始化动作处理器"""
self.chat_id = chat_id self.chat_id = chat_id
self.chat_stream = get_chat_manager().get_stream(self.chat_id) self.chat_stream: ChatStream = get_chat_manager().get_stream(self.chat_id) # type: ignore
self.log_prefix = f"[{get_chat_manager().get_stream_name(self.chat_id) or self.chat_id}]" self.log_prefix = f"[{get_chat_manager().get_stream_name(self.chat_id) or self.chat_id}]"
self.action_manager = action_manager self.action_manager = action_manager
@@ -142,7 +145,7 @@ class ActionModifier:
async def _get_deactivated_actions_by_type( async def _get_deactivated_actions_by_type(
self, self,
actions_with_info: Dict[str, ActionInfo], actions_with_info: Dict[str, ActionInfo],
mode: str = "focus", mode: ChatMode = ChatMode.FOCUS,
chat_content: str = "", chat_content: str = "",
) -> List[tuple[str, str]]: ) -> List[tuple[str, str]]:
""" """
@@ -270,7 +273,7 @@ class ActionModifier:
task_results = await asyncio.gather(*tasks, return_exceptions=True) task_results = await asyncio.gather(*tasks, return_exceptions=True)
# 处理结果并更新缓存 # 处理结果并更新缓存
for _, (action_name, result) in enumerate(zip(task_names, task_results)): for action_name, result in zip(task_names, task_results):
if isinstance(result, Exception): if isinstance(result, Exception):
logger.error(f"{self.log_prefix}LLM判定action {action_name} 时出错: {result}") logger.error(f"{self.log_prefix}LLM判定action {action_name} 时出错: {result}")
results[action_name] = False results[action_name] = False
@@ -286,7 +289,7 @@ class ActionModifier:
except Exception as e: except Exception as e:
logger.error(f"{self.log_prefix}并行LLM判定失败: {e}") logger.error(f"{self.log_prefix}并行LLM判定失败: {e}")
# 如果并行执行失败为所有任务返回False # 如果并行执行失败为所有任务返回False
for action_name in tasks_to_run.keys(): for action_name in tasks_to_run:
results[action_name] = False results[action_name] = False
# 清理过期缓存 # 清理过期缓存
@@ -297,10 +300,11 @@ class ActionModifier:
def _cleanup_expired_cache(self, current_time: float): def _cleanup_expired_cache(self, current_time: float):
"""清理过期的缓存条目""" """清理过期的缓存条目"""
expired_keys = [] expired_keys = []
for cache_key, cache_data in self._llm_judge_cache.items(): expired_keys.extend(
if current_time - cache_data["timestamp"] > self._cache_expiry_time: cache_key
expired_keys.append(cache_key) for cache_key, cache_data in self._llm_judge_cache.items()
if current_time - cache_data["timestamp"] > self._cache_expiry_time
)
for key in expired_keys: for key in expired_keys:
del self._llm_judge_cache[key] del self._llm_judge_cache[key]
@@ -379,7 +383,7 @@ class ActionModifier:
def _check_keyword_activation( def _check_keyword_activation(
self, self,
action_name: str, action_name: str,
action_info: Dict[str, Any], action_info: ActionInfo,
chat_content: str = "", chat_content: str = "",
) -> bool: ) -> bool:
""" """
@@ -396,8 +400,8 @@ class ActionModifier:
bool: 是否应该激活此action bool: 是否应该激活此action
""" """
activation_keywords = action_info.get("activation_keywords", []) activation_keywords = action_info.activation_keywords
case_sensitive = action_info.get("keyword_case_sensitive", False) case_sensitive = action_info.keyword_case_sensitive
if not activation_keywords: if not activation_keywords:
logger.warning(f"{self.log_prefix}动作 {action_name} 设置为关键词触发但未配置关键词") logger.warning(f"{self.log_prefix}动作 {action_name} 设置为关键词触发但未配置关键词")

View File

@@ -70,7 +70,7 @@ class ActionPlanner:
self.last_obs_time_mark = 0.0 self.last_obs_time_mark = 0.0
async def plan(self) -> Dict[str, Any]: async def plan(self) -> Dict[str, Any]: # sourcery skip: dict-comprehension
""" """
规划器 (Planner): 使用LLM根据上下文决定做出什么动作。 规划器 (Planner): 使用LLM根据上下文决定做出什么动作。
""" """
@@ -162,7 +162,6 @@ class ActionPlanner:
reasoning = parsed_json.get("reasoning", "未提供原因") reasoning = parsed_json.get("reasoning", "未提供原因")
# 将所有其他属性添加到action_data # 将所有其他属性添加到action_data
action_data = {}
for key, value in parsed_json.items(): for key, value in parsed_json.items():
if key not in ["action", "reasoning"]: if key not in ["action", "reasoning"]:
action_data[key] = value action_data[key] = value
@@ -285,7 +284,7 @@ class ActionPlanner:
identity_block = f"你的名字是{bot_name}{bot_nickname},你{bot_core_personality}" identity_block = f"你的名字是{bot_name}{bot_nickname},你{bot_core_personality}"
planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt") planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt")
prompt = planner_prompt_template.format( return planner_prompt_template.format(
time_block=time_block, time_block=time_block,
by_what=by_what, by_what=by_what,
chat_context_description=chat_context_description, chat_context_description=chat_context_description,
@@ -295,8 +294,6 @@ class ActionPlanner:
moderation_prompt=moderation_prompt_block, moderation_prompt=moderation_prompt_block,
identity_block=identity_block, identity_block=identity_block,
) )
return prompt
except Exception as e: except Exception as e:
logger.error(f"构建 Planner 提示词时出错: {e}") logger.error(f"构建 Planner 提示词时出错: {e}")
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())

View File

@@ -130,9 +130,7 @@ class DefaultReplyer:
# 提取权重,如果模型配置中没有'weight'键则默认为1.0 # 提取权重,如果模型配置中没有'weight'键则默认为1.0
weights = [config.get("weight", 1.0) for config in configs] weights = [config.get("weight", 1.0) for config in configs]
# random.choices 返回一个列表,我们取第一个元素 return random.choices(population=configs, weights=weights, k=1)[0]
selected_config = random.choices(population=configs, weights=weights, k=1)[0]
return selected_config
async def _create_thinking_message(self, anchor_message: Optional[MessageRecv], thinking_id: str): async def _create_thinking_message(self, anchor_message: Optional[MessageRecv], thinking_id: str):
"""创建思考消息 (尝试锚定到 anchor_message)""" """创建思考消息 (尝试锚定到 anchor_message)"""
@@ -314,8 +312,7 @@ class DefaultReplyer:
logger.warning(f"{self.log_prefix} 未找到用户 {sender} 的ID跳过信息提取") logger.warning(f"{self.log_prefix} 未找到用户 {sender} 的ID跳过信息提取")
return f"你完全不认识{sender}不理解ta的相关信息。" return f"你完全不认识{sender}不理解ta的相关信息。"
relation_info = await relationship_fetcher.build_relation_info(person_id, text, chat_history) return await relationship_fetcher.build_relation_info(person_id, text, chat_history)
return relation_info
async def build_expression_habits(self, chat_history, target): async def build_expression_habits(self, chat_history, target):
if not global_config.expression.enable_expression: if not global_config.expression.enable_expression:
@@ -363,15 +360,13 @@ class DefaultReplyer:
target_message=target, chat_history_prompt=chat_history target_message=target, chat_history_prompt=chat_history
) )
if running_memories: if not running_memories:
return ""
memory_str = "以下是当前在聊天中,你回忆起的记忆:\n" memory_str = "以下是当前在聊天中,你回忆起的记忆:\n"
for running_memory in running_memories: for running_memory in running_memories:
memory_str += f"- {running_memory['content']}\n" memory_str += f"- {running_memory['content']}\n"
memory_block = memory_str return memory_str
else:
memory_block = ""
return memory_block
async def build_tool_info(self, reply_data=None, chat_history=None, enable_tool: bool = True): async def build_tool_info(self, reply_data=None, chat_history=None, enable_tool: bool = True):
"""构建工具信息块 """构建工具信息块
@@ -453,7 +448,7 @@ class DefaultReplyer:
for name, content in result.groupdict().items(): for name, content in result.groupdict().items():
reaction = reaction.replace(f"[{name}]", content) reaction = reaction.replace(f"[{name}]", content)
logger.info(f"匹配到正则表达式:{pattern_str},触发反应:{reaction}") logger.info(f"匹配到正则表达式:{pattern_str},触发反应:{reaction}")
keywords_reaction_prompt += reaction + "" keywords_reaction_prompt += f"{reaction}"
break break
except re.error as e: except re.error as e:
logger.error(f"正则表达式编译错误: {pattern_str}, 错误信息: {str(e)}") logger.error(f"正则表达式编译错误: {pattern_str}, 错误信息: {str(e)}")
@@ -477,7 +472,7 @@ class DefaultReplyer:
available_actions: Optional[Dict[str, ActionInfo]] = None, available_actions: Optional[Dict[str, ActionInfo]] = None,
enable_timeout: bool = False, enable_timeout: bool = False,
enable_tool: bool = True, enable_tool: bool = True,
) -> str: ) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
""" """
构建回复器上下文 构建回复器上下文
@@ -612,7 +607,7 @@ class DefaultReplyer:
short_impression = ["友好活泼", "人类"] short_impression = ["友好活泼", "人类"]
personality = short_impression[0] personality = short_impression[0]
identity = short_impression[1] identity = short_impression[1]
prompt_personality = personality + "" + identity prompt_personality = f"{personality}{identity}"
identity_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}" identity_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
moderation_prompt_block = ( moderation_prompt_block = (
@@ -660,7 +655,7 @@ class DefaultReplyer:
"chat_target_private2", sender_name=chat_target_name "chat_target_private2", sender_name=chat_target_name
) )
prompt = await global_prompt_manager.format_prompt( return await global_prompt_manager.format_prompt(
template_name, template_name,
expression_habits_block=expression_habits_block, expression_habits_block=expression_habits_block,
chat_target=chat_target_1, chat_target=chat_target_1,
@@ -683,8 +678,6 @@ class DefaultReplyer:
mood_state=mood_prompt, mood_state=mood_prompt,
) )
return prompt
async def build_prompt_rewrite_context( async def build_prompt_rewrite_context(
self, self,
reply_data: Dict[str, Any], reply_data: Dict[str, Any],
@@ -745,7 +738,7 @@ class DefaultReplyer:
short_impression = ["友好活泼", "人类"] short_impression = ["友好活泼", "人类"]
personality = short_impression[0] personality = short_impression[0]
identity = short_impression[1] identity = short_impression[1]
prompt_personality = personality + "" + identity prompt_personality = f"{personality}{identity}"
identity_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}" identity_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
moderation_prompt_block = ( moderation_prompt_block = (
@@ -790,7 +783,7 @@ class DefaultReplyer:
template_name = "default_expressor_prompt" template_name = "default_expressor_prompt"
prompt = await global_prompt_manager.format_prompt( return await global_prompt_manager.format_prompt(
template_name, template_name,
expression_habits_block=expression_habits_block, expression_habits_block=expression_habits_block,
relation_info_block=relation_info, relation_info_block=relation_info,
@@ -807,8 +800,6 @@ class DefaultReplyer:
moderation_prompt=moderation_prompt_block, moderation_prompt=moderation_prompt_block,
) )
return prompt
async def send_response_messages( async def send_response_messages(
self, self,
anchor_message: Optional[MessageRecv], anchor_message: Optional[MessageRecv],
@@ -816,6 +807,7 @@ class DefaultReplyer:
thinking_id: str = "", thinking_id: str = "",
display_message: str = "", display_message: str = "",
) -> Optional[MessageSending]: ) -> Optional[MessageSending]:
# sourcery skip: assign-if-exp, boolean-if-exp-identity, remove-unnecessary-cast
"""发送回复消息 (尝试锚定到 anchor_message),使用 HeartFCSender""" """发送回复消息 (尝试锚定到 anchor_message),使用 HeartFCSender"""
chat = self.chat_stream chat = self.chat_stream
chat_id = self.chat_stream.stream_id chat_id = self.chat_stream.stream_id
@@ -849,16 +841,16 @@ class DefaultReplyer:
for i, msg_text in enumerate(response_set): for i, msg_text in enumerate(response_set):
# 为每个消息片段生成唯一ID # 为每个消息片段生成唯一ID
type = msg_text[0] msg_type = msg_text[0]
data = msg_text[1] data = msg_text[1]
if global_config.debug.debug_show_chat_mode and type == "text": if global_config.debug.debug_show_chat_mode and msg_type == "text":
data += "" data += ""
part_message_id = f"{thinking_id}_{i}" part_message_id = f"{thinking_id}_{i}"
message_segment = Seg(type=type, data=data) message_segment = Seg(type=msg_type, data=data)
if type == "emoji": if msg_type == "emoji":
is_emoji = True is_emoji = True
else: else:
is_emoji = False is_emoji = False
@@ -871,7 +863,6 @@ class DefaultReplyer:
display_message=display_message, display_message=display_message,
reply_to=reply_to, reply_to=reply_to,
is_emoji=is_emoji, is_emoji=is_emoji,
thinking_id=thinking_id,
thinking_start_time=thinking_start_time, thinking_start_time=thinking_start_time,
) )
@@ -895,7 +886,7 @@ class DefaultReplyer:
reply_message_ids.append(part_message_id) # 记录我们生成的ID reply_message_ids.append(part_message_id) # 记录我们生成的ID
sent_msg_list.append((type, sent_msg)) sent_msg_list.append((msg_type, sent_msg))
except Exception as e: except Exception as e:
logger.error(f"{self.log_prefix}发送回复片段 {i} ({part_message_id}) 时失败: {e}") logger.error(f"{self.log_prefix}发送回复片段 {i} ({part_message_id}) 时失败: {e}")
@@ -930,12 +921,9 @@ class DefaultReplyer:
) )
# await anchor_message.process() # await anchor_message.process()
if anchor_message: sender_info = anchor_message.message_info.user_info if anchor_message else None
sender_info = anchor_message.message_info.user_info
else:
sender_info = None
bot_message = MessageSending( return MessageSending(
message_id=message_id, # 使用片段的唯一ID message_id=message_id, # 使用片段的唯一ID
chat_stream=self.chat_stream, chat_stream=self.chat_stream,
bot_user_info=bot_user_info, bot_user_info=bot_user_info,
@@ -948,8 +936,6 @@ class DefaultReplyer:
display_message=display_message, display_message=display_message,
) )
return bot_message
def weighted_sample_no_replacement(items, weights, k) -> list: def weighted_sample_no_replacement(items, weights, k) -> list:
""" """

View File

@@ -1,4 +1,5 @@
from typing import Dict, Any, Optional, List from typing import Dict, Any, Optional, List
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.replyer.default_generator import DefaultReplyer from src.chat.replyer.default_generator import DefaultReplyer
from src.common.logger import get_logger from src.common.logger import get_logger
@@ -8,7 +9,7 @@ logger = get_logger("ReplyerManager")
class ReplyerManager: class ReplyerManager:
def __init__(self): def __init__(self):
self._replyers: Dict[str, DefaultReplyer] = {} self._repliers: Dict[str, DefaultReplyer] = {}
def get_replyer( def get_replyer(
self, self,
@@ -29,17 +30,16 @@ class ReplyerManager:
return None return None
# 如果已有缓存实例,直接返回 # 如果已有缓存实例,直接返回
if stream_id in self._replyers: if stream_id in self._repliers:
logger.debug(f"[ReplyerManager] 为 stream_id '{stream_id}' 返回已存在的回复器实例。") logger.debug(f"[ReplyerManager] 为 stream_id '{stream_id}' 返回已存在的回复器实例。")
return self._replyers[stream_id] return self._repliers[stream_id]
# 如果没有缓存,则创建新实例(首次初始化) # 如果没有缓存,则创建新实例(首次初始化)
logger.debug(f"[ReplyerManager] 为 stream_id '{stream_id}' 创建新的回复器实例并缓存。") logger.debug(f"[ReplyerManager] 为 stream_id '{stream_id}' 创建新的回复器实例并缓存。")
target_stream = chat_stream target_stream = chat_stream
if not target_stream: if not target_stream:
chat_manager = get_chat_manager() if chat_manager := get_chat_manager():
if chat_manager:
target_stream = chat_manager.get_stream(stream_id) target_stream = chat_manager.get_stream(stream_id)
if not target_stream: if not target_stream:
@@ -52,7 +52,7 @@ class ReplyerManager:
model_configs=model_configs, # 可以是None此时使用默认模型 model_configs=model_configs, # 可以是None此时使用默认模型
request_type=request_type, request_type=request_type,
) )
self._replyers[stream_id] = replyer self._repliers[stream_id] = replyer
return replyer return replyer

View File

@@ -1,14 +1,15 @@
from src.config.config import global_config
from typing import List, Dict, Any, Tuple # 确保类型提示被导入
import time # 导入 time 模块以获取当前时间 import time # 导入 time 模块以获取当前时间
import random import random
import re import re
from src.common.message_repository import find_messages, count_messages from typing import List, Dict, Any, Tuple, Optional
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
from src.chat.utils.utils import translate_timestamp_to_human_readable
from rich.traceback import install from rich.traceback import install
from src.config.config import global_config
from src.common.message_repository import find_messages, count_messages
from src.common.database.database_model import ActionRecords from src.common.database.database_model import ActionRecords
from src.common.database.database_model import Images from src.common.database.database_model import Images
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
from src.chat.utils.utils import translate_timestamp_to_human_readable
install(extra_lines=3) install(extra_lines=3)
@@ -135,7 +136,7 @@ def get_raw_msg_before_timestamp_with_users(timestamp: float, person_ids: list,
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit) return find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: float = None) -> int: def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: Optional[float] = None) -> int:
""" """
检查特定聊天从 timestamp_start (不含) 到 timestamp_end (不含) 之间有多少新消息。 检查特定聊天从 timestamp_start (不含) 到 timestamp_end (不含) 之间有多少新消息。
如果 timestamp_end 为 None则检查从 timestamp_start (不含) 到当前时间的消息。 如果 timestamp_end 为 None则检查从 timestamp_start (不含) 到当前时间的消息。
@@ -172,7 +173,7 @@ def _build_readable_messages_internal(
merge_messages: bool = False, merge_messages: bool = False,
timestamp_mode: str = "relative", timestamp_mode: str = "relative",
truncate: bool = False, truncate: bool = False,
pic_id_mapping: Dict[str, str] = None, pic_id_mapping: Optional[Dict[str, str]] = None,
pic_counter: int = 1, pic_counter: int = 1,
show_pic: bool = True, show_pic: bool = True,
) -> Tuple[str, List[Tuple[float, str, str]], Dict[str, str], int]: ) -> Tuple[str, List[Tuple[float, str, str]], Dict[str, str], int]:
@@ -194,7 +195,7 @@ def _build_readable_messages_internal(
if not messages: if not messages:
return "", [], pic_id_mapping or {}, pic_counter return "", [], pic_id_mapping or {}, pic_counter
message_details_raw: List[Tuple[float, str, str]] = [] message_details_raw: List[Tuple[float, str, str, bool]] = []
# 使用传入的映射字典,如果没有则创建新的 # 使用传入的映射字典,如果没有则创建新的
if pic_id_mapping is None: if pic_id_mapping is None:
@@ -225,7 +226,7 @@ def _build_readable_messages_internal(
# 检查是否是动作记录 # 检查是否是动作记录
if msg.get("is_action_record", False): if msg.get("is_action_record", False):
is_action = True is_action = True
timestamp = msg.get("time") timestamp: float = msg.get("time") # type: ignore
content = msg.get("display_message", "") content = msg.get("display_message", "")
# 对于动作记录也处理图片ID # 对于动作记录也处理图片ID
content = process_pic_ids(content) content = process_pic_ids(content)
@@ -249,9 +250,10 @@ def _build_readable_messages_internal(
user_nickname = user_info.get("user_nickname") user_nickname = user_info.get("user_nickname")
user_cardname = user_info.get("user_cardname") user_cardname = user_info.get("user_cardname")
timestamp = msg.get("time") timestamp: float = msg.get("time") # type: ignore
content: str
if msg.get("display_message"): if msg.get("display_message"):
content = msg.get("display_message") content = msg.get("display_message", "")
else: else:
content = msg.get("processed_plain_text", "") # 默认空字符串 content = msg.get("processed_plain_text", "") # 默认空字符串
@@ -271,6 +273,7 @@ def _build_readable_messages_internal(
person_id = PersonInfoManager.get_person_id(platform, user_id) person_id = PersonInfoManager.get_person_id(platform, user_id)
person_info_manager = get_person_info_manager() person_info_manager = get_person_info_manager()
# 根据 replace_bot_name 参数决定是否替换机器人名称 # 根据 replace_bot_name 参数决定是否替换机器人名称
person_name: str
if replace_bot_name and user_id == global_config.bot.qq_account: if replace_bot_name and user_id == global_config.bot.qq_account:
person_name = f"{global_config.bot.nickname}(你)" person_name = f"{global_config.bot.nickname}(你)"
else: else:
@@ -289,12 +292,10 @@ def _build_readable_messages_internal(
reply_pattern = r"回复<([^:<>]+):([^:<>]+)>" reply_pattern = r"回复<([^:<>]+):([^:<>]+)>"
match = re.search(reply_pattern, content) match = re.search(reply_pattern, content)
if match: if match:
aaa = match.group(1) aaa: str = match[1]
bbb = match.group(2) bbb: str = match[2]
reply_person_id = PersonInfoManager.get_person_id(platform, bbb) reply_person_id = PersonInfoManager.get_person_id(platform, bbb)
reply_person_name = person_info_manager.get_value_sync(reply_person_id, "person_name") reply_person_name = person_info_manager.get_value_sync(reply_person_id, "person_name") or aaa
if not reply_person_name:
reply_person_name = aaa
# 在内容前加上回复信息 # 在内容前加上回复信息
content = re.sub(reply_pattern, lambda m, name=reply_person_name: f"回复 {name}", content, count=1) content = re.sub(reply_pattern, lambda m, name=reply_person_name: f"回复 {name}", content, count=1)
@@ -309,17 +310,14 @@ def _build_readable_messages_internal(
aaa = m.group(1) aaa = m.group(1)
bbb = m.group(2) bbb = m.group(2)
at_person_id = PersonInfoManager.get_person_id(platform, bbb) at_person_id = PersonInfoManager.get_person_id(platform, bbb)
at_person_name = person_info_manager.get_value_sync(at_person_id, "person_name") at_person_name = person_info_manager.get_value_sync(at_person_id, "person_name") or aaa
if not at_person_name:
at_person_name = aaa
new_content += f"@{at_person_name}" new_content += f"@{at_person_name}"
last_end = m.end() last_end = m.end()
new_content += content[last_end:] new_content += content[last_end:]
content = new_content content = new_content
target_str = "这是QQ的一个功能用于提及某人但没那么明显" target_str = "这是QQ的一个功能用于提及某人但没那么明显"
if target_str in content: if target_str in content and random.random() < 0.6:
if random.random() < 0.6:
content = content.replace(target_str, "") content = content.replace(target_str, "")
if content != "": if content != "":
@@ -470,6 +468,7 @@ def _build_readable_messages_internal(
def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str: def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
# sourcery skip: use-contextlib-suppress
""" """
构建图片映射信息字符串,显示图片的具体描述内容 构建图片映射信息字符串,显示图片的具体描述内容
@@ -518,9 +517,7 @@ async def build_readable_messages_with_list(
messages, replace_bot_name, merge_messages, timestamp_mode, truncate messages, replace_bot_name, merge_messages, timestamp_mode, truncate
) )
# 生成图片映射信息并添加到最前面 if pic_mapping_info := build_pic_mapping_info(pic_id_mapping):
pic_mapping_info = build_pic_mapping_info(pic_id_mapping)
if pic_mapping_info:
formatted_string = f"{pic_mapping_info}\n\n{formatted_string}" formatted_string = f"{pic_mapping_info}\n\n{formatted_string}"
return formatted_string, details_list return formatted_string, details_list
@@ -535,7 +532,7 @@ def build_readable_messages(
truncate: bool = False, truncate: bool = False,
show_actions: bool = False, show_actions: bool = False,
show_pic: bool = True, show_pic: bool = True,
) -> str: ) -> str: # sourcery skip: extract-method
""" """
将消息列表转换为可读的文本格式。 将消息列表转换为可读的文本格式。
如果提供了 read_mark则在相应位置插入已读标记。 如果提供了 read_mark则在相应位置插入已读标记。
@@ -658,9 +655,7 @@ def build_readable_messages(
# 组合结果 # 组合结果
result_parts = [] result_parts = []
if pic_mapping_info: if pic_mapping_info:
result_parts.append(pic_mapping_info) result_parts.extend((pic_mapping_info, "\n"))
result_parts.append("\n")
if formatted_before and formatted_after: if formatted_before and formatted_after:
result_parts.extend([formatted_before, read_mark_line, formatted_after]) result_parts.extend([formatted_before, read_mark_line, formatted_after])
elif formatted_before: elif formatted_before:
@@ -733,8 +728,9 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
platform = msg.get("chat_info_platform") platform = msg.get("chat_info_platform")
user_id = msg.get("user_id") user_id = msg.get("user_id")
_timestamp = msg.get("time") _timestamp = msg.get("time")
content: str = ""
if msg.get("display_message"): if msg.get("display_message"):
content = msg.get("display_message") content = msg.get("display_message", "")
else: else:
content = msg.get("processed_plain_text", "") content = msg.get("processed_plain_text", "")
@@ -829,10 +825,7 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
if not all([platform, user_id]) or user_id == global_config.bot.qq_account: if not all([platform, user_id]) or user_id == global_config.bot.qq_account:
continue continue
person_id = PersonInfoManager.get_person_id(platform, user_id) if person_id := PersonInfoManager.get_person_id(platform, user_id):
# 只有当获取到有效 person_id 时才添加
if person_id:
person_ids_set.add(person_id) person_ids_set.add(person_id)
return list(person_ids_set) # 将集合转换为列表返回 return list(person_ids_set) # 将集合转换为列表返回

View File

@@ -103,7 +103,7 @@ class ImageManager:
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii") image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
image_bytes = base64.b64decode(image_base64) image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest() image_hash = hashlib.md5(image_bytes).hexdigest()
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
# 查询缓存的描述 # 查询缓存的描述
cached_description = self._get_description_from_db(image_hash, "emoji") cached_description = self._get_description_from_db(image_hash, "emoji")
@@ -154,7 +154,7 @@ class ImageManager:
img_obj.description = description img_obj.description = description
img_obj.timestamp = current_timestamp img_obj.timestamp = current_timestamp
img_obj.save() img_obj.save()
except Images.DoesNotExist: except Images.DoesNotExist: # type: ignore
Images.create( Images.create(
emoji_hash=image_hash, emoji_hash=image_hash,
path=file_path, path=file_path,
@@ -204,7 +204,7 @@ class ImageManager:
return f"[图片:{cached_description}]" return f"[图片:{cached_description}]"
# 调用AI获取描述 # 调用AI获取描述
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
prompt = "请用中文描述这张图片的内容。如果有文字请把文字都描述出来请留意其主题直观感受输出为一段平文本最多50字" prompt = "请用中文描述这张图片的内容。如果有文字请把文字都描述出来请留意其主题直观感受输出为一段平文本最多50字"
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format) description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
@@ -491,7 +491,7 @@ class ImageManager:
return return
# 获取图片格式 # 获取图片格式
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
# 构建prompt # 构建prompt
prompt = """请用中文描述这张图片的内容。如果有文字请把文字描述概括出来请留意其主题直观感受输出为一段平文本最多30字请注意不要分点就输出一段文本""" prompt = """请用中文描述这张图片的内容。如果有文字请把文字描述概括出来请留意其主题直观感受输出为一段平文本最多30字请注意不要分点就输出一段文本"""

View File

@@ -3,7 +3,7 @@ from src.common.database.database import db
from src.common.database.database_model import PersonInfo # 新增导入 from src.common.database.database_model import PersonInfo # 新增导入
import copy import copy
import hashlib import hashlib
from typing import Any, Callable, Dict from typing import Any, Callable, Dict, Union
import datetime import datetime
import asyncio import asyncio
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
@@ -84,7 +84,7 @@ class PersonInfoManager:
logger.error(f"从 Peewee 加载 person_name_list 失败: {e}") logger.error(f"从 Peewee 加载 person_name_list 失败: {e}")
@staticmethod @staticmethod
def get_person_id(platform: str, user_id: int): def get_person_id(platform: str, user_id: Union[int, str]) -> str:
"""获取唯一id""" """获取唯一id"""
if "-" in platform: if "-" in platform:
platform = platform.split("-")[1] platform = platform.split("-")[1]

View File

@@ -32,10 +32,10 @@ class BaseAction(ABC):
reasoning: str, reasoning: str,
cycle_timers: dict, cycle_timers: dict,
thinking_id: str, thinking_id: str,
chat_stream: ChatStream = None, chat_stream: ChatStream,
log_prefix: str = "", log_prefix: str = "",
shutting_down: bool = False, shutting_down: bool = False,
plugin_config: dict = None, plugin_config: Optional[dict] = None,
**kwargs, **kwargs,
): ):
"""初始化Action组件 """初始化Action组件

View File

@@ -29,7 +29,7 @@ class BaseCommand(ABC):
command_examples: List[str] = [] command_examples: List[str] = []
intercept_message: bool = True # 默认拦截消息,不继续处理 intercept_message: bool = True # 默认拦截消息,不继续处理
def __init__(self, message: MessageRecv, plugin_config: dict = None): def __init__(self, message: MessageRecv, plugin_config: Optional[dict] = None):
"""初始化Command组件 """初始化Command组件
Args: Args:

View File

@@ -66,7 +66,7 @@ class BasePlugin(ABC):
config_section_descriptions: Dict[str, str] = {} config_section_descriptions: Dict[str, str] = {}
def __init__(self, plugin_dir: str = None): def __init__(self, plugin_dir: str):
"""初始化插件 """初始化插件
Args: Args:
@@ -526,7 +526,7 @@ class BasePlugin(ABC):
# 从配置中更新 enable_plugin # 从配置中更新 enable_plugin
if "plugin" in self.config and "enabled" in self.config["plugin"]: if "plugin" in self.config and "enabled" in self.config["plugin"]:
self.enable_plugin = self.config["plugin"]["enabled"] self.enable_plugin = self.config["plugin"]["enabled"] # type: ignore
logger.debug(f"{self.log_prefix} 从配置更新插件启用状态: {self.enable_plugin}") logger.debug(f"{self.log_prefix} 从配置更新插件启用状态: {self.enable_plugin}")
else: else:
logger.warning(f"{self.log_prefix} 不支持的配置文件格式: {file_ext},仅支持 .toml") logger.warning(f"{self.log_prefix} 不支持的配置文件格式: {file_ext},仅支持 .toml")

View File

@@ -81,7 +81,9 @@ class ComponentInfo:
class ActionInfo(ComponentInfo): class ActionInfo(ComponentInfo):
"""动作组件信息""" """动作组件信息"""
action_parameters: Dict[str, str] = field(default_factory=dict) # 动作参数与描述,例如 {"param1": "描述1", "param2": "描述2"} action_parameters: Dict[str, str] = field(
default_factory=dict
) # 动作参数与描述,例如 {"param1": "描述1", "param2": "描述2"}
action_require: List[str] = field(default_factory=list) # 动作需求说明 action_require: List[str] = field(default_factory=list) # 动作需求说明
associated_types: List[str] = field(default_factory=list) # 关联的消息类型 associated_types: List[str] = field(default_factory=list) # 关联的消息类型
# 激活类型相关 # 激活类型相关

View File

@@ -1,4 +1,4 @@
from typing import Dict, List, Optional, Any, Pattern, Union from typing import Dict, List, Optional, Any, Pattern, Tuple, Union, Type
import re import re
from src.common.logger import get_logger from src.common.logger import get_logger
from src.plugin_system.base.component_types import ( from src.plugin_system.base.component_types import (
@@ -28,25 +28,25 @@ class ComponentRegistry:
ComponentType.ACTION: {}, ComponentType.ACTION: {},
ComponentType.COMMAND: {}, ComponentType.COMMAND: {},
} }
self._component_classes: Dict[str, Union[BaseCommand, BaseAction]] = {} # 组件名 -> 组件类 self._component_classes: Dict[str, Union[Type[BaseCommand], Type[BaseAction]]] = {} # 组件名 -> 组件类
# 插件注册表 # 插件注册表
self._plugins: Dict[str, PluginInfo] = {} # 插件名 -> 插件信息 self._plugins: Dict[str, PluginInfo] = {} # 插件名 -> 插件信息
# Action特定注册表 # Action特定注册表
self._action_registry: Dict[str, BaseAction] = {} # action名 -> action类 self._action_registry: Dict[str, Type[BaseAction]] = {} # action名 -> action类
# self._action_descriptions: Dict[str, str] = {} # 启用的action名 -> 描述 self._default_actions: Dict[str, ActionInfo] = {} # 默认动作集即启用的Action集用于重置ActionManager状态
# Command特定注册表 # Command特定注册表
self._command_registry: Dict[str, BaseCommand] = {} # command名 -> command类 self._command_registry: Dict[str, Type[BaseCommand]] = {} # command名 -> command类
self._command_patterns: Dict[Pattern, BaseCommand] = {} # 编译后的正则 -> command类 self._command_patterns: Dict[Pattern, Type[BaseCommand]] = {} # 编译后的正则 -> command类
logger.info("组件注册中心初始化完成") logger.info("组件注册中心初始化完成")
# === 通用组件注册方法 === # === 通用组件注册方法 ===
def register_component( def register_component(
self, component_info: ComponentInfo, component_class: Union[BaseCommand, BaseAction] self, component_info: ComponentInfo, component_class: Union[Type[BaseCommand], Type[BaseAction]]
) -> bool: ) -> bool:
"""注册组件 """注册组件
@@ -88,9 +88,9 @@ class ComponentRegistry:
# 根据组件类型进行特定注册(使用原始名称) # 根据组件类型进行特定注册(使用原始名称)
if component_type == ComponentType.ACTION: if component_type == ComponentType.ACTION:
self._register_action_component(component_info, component_class) self._register_action_component(component_info, component_class) # type: ignore
elif component_type == ComponentType.COMMAND: elif component_type == ComponentType.COMMAND:
self._register_command_component(component_info, component_class) self._register_command_component(component_info, component_class) # type: ignore
logger.debug( logger.debug(
f"已注册{component_type.value}组件: '{component_name}' -> '{namespaced_name}' " f"已注册{component_type.value}组件: '{component_name}' -> '{namespaced_name}' "
@@ -98,7 +98,7 @@ class ComponentRegistry:
) )
return True return True
def _register_action_component(self, action_info: ActionInfo, action_class: BaseAction): def _register_action_component(self, action_info: ActionInfo, action_class: Type[BaseAction]):
# -------------------------------- NEED REFACTORING -------------------------------- # -------------------------------- NEED REFACTORING --------------------------------
# -------------------------------- LOGIC ERROR ------------------------------------- # -------------------------------- LOGIC ERROR -------------------------------------
"""注册Action组件到Action特定注册表""" """注册Action组件到Action特定注册表"""
@@ -106,11 +106,10 @@ class ComponentRegistry:
self._action_registry[action_name] = action_class self._action_registry[action_name] = action_class
# 如果启用,添加到默认动作集 # 如果启用,添加到默认动作集
# ---- HERE ---- if action_info.enabled:
# if action_info.enabled: self._default_actions[action_name] = action_info
# self._action_descriptions[action_name] = action_info.description
def _register_command_component(self, command_info: CommandInfo, command_class: BaseCommand): def _register_command_component(self, command_info: CommandInfo, command_class: Type[BaseCommand]):
"""注册Command组件到Command特定注册表""" """注册Command组件到Command特定注册表"""
command_name = command_info.name command_name = command_info.name
self._command_registry[command_name] = command_class self._command_registry[command_name] = command_class
@@ -122,7 +121,7 @@ class ComponentRegistry:
# === 组件查询方法 === # === 组件查询方法 ===
def get_component_info(self, component_name: str, component_type: ComponentType = None) -> Optional[ComponentInfo]: def get_component_info(self, component_name: str, component_type: ComponentType = None) -> Optional[ComponentInfo]: # type: ignore
# sourcery skip: class-extract-method # sourcery skip: class-extract-method
"""获取组件信息,支持自动命名空间解析 """获取组件信息,支持自动命名空间解析
@@ -170,8 +169,10 @@ class ComponentRegistry:
return None return None
def get_component_class( def get_component_class(
self, component_name: str, component_type: ComponentType = None self,
) -> Optional[Union[BaseCommand, BaseAction]]: component_name: str,
component_type: ComponentType = None, # type: ignore
) -> Optional[Union[Type[BaseCommand], Type[BaseAction]]]:
"""获取组件类,支持自动命名空间解析 """获取组件类,支持自动命名空间解析
Args: Args:
@@ -230,7 +231,7 @@ class ComponentRegistry:
# === Action特定查询方法 === # === Action特定查询方法 ===
def get_action_registry(self) -> Dict[str, BaseAction]: def get_action_registry(self) -> Dict[str, Type[BaseAction]]:
"""获取Action注册表用于兼容现有系统""" """获取Action注册表用于兼容现有系统"""
return self._action_registry.copy() return self._action_registry.copy()
@@ -239,13 +240,17 @@ class ComponentRegistry:
info = self.get_component_info(action_name, ComponentType.ACTION) info = self.get_component_info(action_name, ComponentType.ACTION)
return info if isinstance(info, ActionInfo) else None return info if isinstance(info, ActionInfo) else None
def get_default_actions(self) -> Dict[str, ActionInfo]:
"""获取默认动作集"""
return self._default_actions.copy()
# === Command特定查询方法 === # === Command特定查询方法 ===
def get_command_registry(self) -> Dict[str, BaseCommand]: def get_command_registry(self) -> Dict[str, Type[BaseCommand]]:
"""获取Command注册表用于兼容现有系统""" """获取Command注册表用于兼容现有系统"""
return self._command_registry.copy() return self._command_registry.copy()
def get_command_patterns(self) -> Dict[Pattern, BaseCommand]: def get_command_patterns(self) -> Dict[Pattern, Type[BaseCommand]]:
"""获取Command模式注册表用于兼容现有系统""" """获取Command模式注册表用于兼容现有系统"""
return self._command_patterns.copy() return self._command_patterns.copy()
@@ -254,7 +259,7 @@ class ComponentRegistry:
info = self.get_component_info(command_name, ComponentType.COMMAND) info = self.get_component_info(command_name, ComponentType.COMMAND)
return info if isinstance(info, CommandInfo) else None return info if isinstance(info, CommandInfo) else None
def find_command_by_text(self, text: str) -> Optional[tuple[BaseCommand, dict, bool, str]]: def find_command_by_text(self, text: str) -> Optional[Tuple[Type[BaseCommand], dict, bool, str]]:
# sourcery skip: use-named-expression, use-next # sourcery skip: use-named-expression, use-next
"""根据文本查找匹配的命令 """根据文本查找匹配的命令
@@ -262,7 +267,7 @@ class ComponentRegistry:
text: 输入文本 text: 输入文本
Returns: Returns:
Optional[tuple[BaseCommand, dict, bool, str]]: (命令类, 匹配的命名组, 是否拦截消息, 插件名) 或 None Tuple: (命令类, 匹配的命名组, 是否拦截消息, 插件名) 或 None
""" """
for pattern, command_class in self._command_patterns.items(): for pattern, command_class in self._command_patterns.items():