部分类型注解修复,优化import顺序,删除无用API文件

This commit is contained in:
UnCLAS-Prommer
2025-07-12 00:34:49 +08:00
parent 3165a0f8df
commit b303a95f61
44 changed files with 405 additions and 1166 deletions

View File

View File

@@ -1,26 +0,0 @@
from src.chat.heart_flow.heartflow import heartflow
from src.chat.heart_flow.sub_heartflow import ChatState
from src.common.logger import get_logger
logger = get_logger("api")
async def get_all_subheartflow_ids() -> list:
"""获取所有子心流的ID列表"""
all_subheartflows = heartflow.subheartflow_manager.get_all_subheartflows()
return [subheartflow.subheartflow_id for subheartflow in all_subheartflows]
async def forced_change_subheartflow_status(subheartflow_id: str, status: ChatState) -> bool:
"""强制改变子心流的状态"""
subheartflow = await heartflow.get_or_create_subheartflow(subheartflow_id)
if subheartflow:
return await heartflow.force_change_subheartflow_status(subheartflow_id, status)
return False
async def get_all_states():
"""获取所有状态"""
all_states = await heartflow.api_get_all_states()
logger.debug(f"所有状态: {all_states}")
return all_states

View File

@@ -1,169 +0,0 @@
import platform
import psutil
import sys
import os
def get_system_info():
"""获取操作系统信息"""
return {
"system": platform.system(),
"release": platform.release(),
"version": platform.version(),
"machine": platform.machine(),
"processor": platform.processor(),
}
def get_python_version():
"""获取 Python 版本信息"""
return sys.version
def get_cpu_usage():
"""获取系统总CPU使用率"""
return psutil.cpu_percent(interval=1)
def get_process_cpu_usage():
"""获取当前进程CPU使用率"""
process = psutil.Process(os.getpid())
return process.cpu_percent(interval=1)
def get_memory_usage():
"""获取系统内存使用情况 (单位 MB)"""
mem = psutil.virtual_memory()
bytes_to_mb = lambda x: round(x / (1024 * 1024), 2) # noqa
return {
"total_mb": bytes_to_mb(mem.total),
"available_mb": bytes_to_mb(mem.available),
"percent": mem.percent,
"used_mb": bytes_to_mb(mem.used),
"free_mb": bytes_to_mb(mem.free),
}
def get_process_memory_usage():
"""获取当前进程内存使用情况 (单位 MB)"""
process = psutil.Process(os.getpid())
mem_info = process.memory_info()
bytes_to_mb = lambda x: round(x / (1024 * 1024), 2) # noqa
return {
"rss_mb": bytes_to_mb(mem_info.rss), # Resident Set Size: 实际使用物理内存
"vms_mb": bytes_to_mb(mem_info.vms), # Virtual Memory Size: 虚拟内存大小
"percent": process.memory_percent(), # 进程内存使用百分比
}
def get_disk_usage(path="/"):
"""获取指定路径磁盘使用情况 (单位 GB)"""
disk = psutil.disk_usage(path)
bytes_to_gb = lambda x: round(x / (1024 * 1024 * 1024), 2) # noqa
return {
"total_gb": bytes_to_gb(disk.total),
"used_gb": bytes_to_gb(disk.used),
"free_gb": bytes_to_gb(disk.free),
"percent": disk.percent,
}
def get_all_basic_info():
"""获取所有基本信息并封装返回"""
# 对于进程CPU使用率需要先初始化
process = psutil.Process(os.getpid())
process.cpu_percent(interval=None) # 初始化调用
process_cpu = process.cpu_percent(interval=0.1) # 短暂间隔获取
return {
"system_info": get_system_info(),
"python_version": get_python_version(),
"cpu_usage_percent": get_cpu_usage(),
"process_cpu_usage_percent": process_cpu,
"memory_usage": get_memory_usage(),
"process_memory_usage": get_process_memory_usage(),
"disk_usage_root": get_disk_usage("/"),
}
def get_all_basic_info_string() -> str:
"""获取所有基本信息并以带解释的字符串形式返回"""
info = get_all_basic_info()
sys_info = info["system_info"]
mem_usage = info["memory_usage"]
proc_mem_usage = info["process_memory_usage"]
disk_usage = info["disk_usage_root"]
# 对进程内存使用百分比进行格式化,保留两位小数
proc_mem_percent = round(proc_mem_usage["percent"], 2)
output_string = f"""[系统信息]
- 操作系统: {sys_info["system"]} (例如: Windows, Linux)
- 发行版本: {sys_info["release"]} (例如: 11, Ubuntu 20.04)
- 详细版本: {sys_info["version"]}
- 硬件架构: {sys_info["machine"]} (例如: AMD64)
- 处理器信息: {sys_info["processor"]}
[Python 环境]
- Python 版本: {info["python_version"]}
[CPU 状态]
- 系统总 CPU 使用率: {info["cpu_usage_percent"]}%
- 当前进程 CPU 使用率: {info["process_cpu_usage_percent"]}%
[系统内存使用情况]
- 总物理内存: {mem_usage["total_mb"]} MB
- 可用物理内存: {mem_usage["available_mb"]} MB
- 物理内存使用率: {mem_usage["percent"]}%
- 已用物理内存: {mem_usage["used_mb"]} MB
- 空闲物理内存: {mem_usage["free_mb"]} MB
[当前进程内存使用情况]
- 实际使用物理内存 (RSS): {proc_mem_usage["rss_mb"]} MB
- 占用虚拟内存 (VMS): {proc_mem_usage["vms_mb"]} MB
- 进程内存使用率: {proc_mem_percent}%
[磁盘使用情况 (根目录)]
- 总空间: {disk_usage["total_gb"]} GB
- 已用空间: {disk_usage["used_gb"]} GB
- 可用空间: {disk_usage["free_gb"]} GB
- 磁盘使用率: {disk_usage["percent"]}%
"""
return output_string
if __name__ == "__main__":
print(f"System Info: {get_system_info()}")
print(f"Python Version: {get_python_version()}")
print(f"CPU Usage: {get_cpu_usage()}%")
# 第一次调用 process.cpu_percent() 会返回0.0或一个无意义的值,需要间隔一段时间再调用
# 或者在初始化Process对象后先调用一次cpu_percent(interval=None)然后再调用cpu_percent(interval=1)
current_process = psutil.Process(os.getpid())
current_process.cpu_percent(interval=None) # 初始化
print(f"Process CPU Usage: {current_process.cpu_percent(interval=1)}%") # 实际获取
memory_usage_info = get_memory_usage()
print(
f"Memory Usage: Total={memory_usage_info['total_mb']}MB, Used={memory_usage_info['used_mb']}MB, Percent={memory_usage_info['percent']}%"
)
process_memory_info = get_process_memory_usage()
print(
f"Process Memory Usage: RSS={process_memory_info['rss_mb']}MB, VMS={process_memory_info['vms_mb']}MB, Percent={process_memory_info['percent']}%"
)
disk_usage_info = get_disk_usage("/")
print(
f"Disk Usage (Root): Total={disk_usage_info['total_gb']}GB, Used={disk_usage_info['used_gb']}GB, Percent={disk_usage_info['percent']}%"
)
print("\n--- All Basic Info (JSON) ---")
all_info = get_all_basic_info()
import json
print(json.dumps(all_info, indent=4, ensure_ascii=False))
print("\n--- All Basic Info (String with Explanations) ---")
info_string = get_all_basic_info_string()
print(info_string)

View File

@@ -1,317 +0,0 @@
from typing import List, Optional, Dict, Any
import strawberry
# from packaging.version import Version
import os
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
@strawberry.type
class APIBotConfig:
"""机器人配置类"""
INNER_VERSION: str # 配置文件内部版本号toml为字符串
MAI_VERSION: str # 硬编码的版本信息
# bot
BOT_QQ: Optional[int] # 机器人QQ号
BOT_NICKNAME: Optional[str] # 机器人昵称
BOT_ALIAS_NAMES: List[str] # 机器人别名列表
# group
talk_allowed_groups: List[int] # 允许回复消息的群号列表
talk_frequency_down_groups: List[int] # 降低回复频率的群号列表
ban_user_id: List[int] # 禁止回复和读取消息的QQ号列表
# personality
personality_core: str # 人格核心特点描述
personality_sides: List[str] # 人格细节描述列表
# identity
identity_detail: List[str] # 身份特点列表
age: int # 年龄(岁)
gender: str # 性别
appearance: str # 外貌特征描述
# platforms
platforms: Dict[str, str] # 平台信息
# chat
allow_focus_mode: bool # 是否允许专注聊天状态
base_normal_chat_num: int # 最多允许多少个群进行普通聊天
base_focused_chat_num: int # 最多允许多少个群进行专注聊天
observation_context_size: int # 观察到的最长上下文大小
message_buffer: bool # 是否启用消息缓冲
ban_words: List[str] # 禁止词列表
ban_msgs_regex: List[str] # 禁止消息的正则表达式列表
# normal_chat
model_reasoning_probability: float # 推理模型概率
model_normal_probability: float # 普通模型概率
emoji_chance: float # 表情符号出现概率
thinking_timeout: int # 思考超时时间
willing_mode: str # 意愿模式
response_interested_rate_amplifier: float # 回复兴趣率放大器
emoji_response_penalty: float # 表情回复惩罚
mentioned_bot_inevitable_reply: bool # 提及 bot 必然回复
at_bot_inevitable_reply: bool # @bot 必然回复
# focus_chat
reply_trigger_threshold: float # 回复触发阈值
default_decay_rate_per_second: float # 默认每秒衰减率
# compressed
compressed_length: int # 压缩长度
compress_length_limit: int # 压缩长度限制
# emoji
max_emoji_num: int # 最大表情符号数量
max_reach_deletion: bool # 达到最大数量时是否删除
check_interval: int # 检查表情包的时间间隔(分钟)
save_emoji: bool # 是否保存表情包
steal_emoji: bool # 是否偷取表情包
enable_check: bool # 是否启用表情包过滤
check_prompt: str # 表情包过滤要求
# memory
build_memory_interval: int # 记忆构建间隔
build_memory_distribution: List[float] # 记忆构建分布
build_memory_sample_num: int # 采样数量
build_memory_sample_length: int # 采样长度
memory_compress_rate: float # 记忆压缩率
forget_memory_interval: int # 记忆遗忘间隔
memory_forget_time: int # 记忆遗忘时间(小时)
memory_forget_percentage: float # 记忆遗忘比例
consolidate_memory_interval: int # 记忆整合间隔
consolidation_similarity_threshold: float # 相似度阈值
consolidation_check_percentage: float # 检查节点比例
memory_ban_words: List[str] # 记忆禁止词列表
# mood
mood_update_interval: float # 情绪更新间隔
mood_decay_rate: float # 情绪衰减率
mood_intensity_factor: float # 情绪强度因子
# keywords_reaction
keywords_reaction_enable: bool # 是否启用关键词反应
keywords_reaction_rules: List[Dict[str, Any]] # 关键词反应规则
# chinese_typo
chinese_typo_enable: bool # 是否启用中文错别字
chinese_typo_error_rate: float # 中文错别字错误率
chinese_typo_min_freq: int # 中文错别字最小频率
chinese_typo_tone_error_rate: float # 中文错别字声调错误率
chinese_typo_word_replace_rate: float # 中文错别字单词替换率
# response_splitter
enable_response_splitter: bool # 是否启用回复分割器
response_max_length: int # 回复最大长度
response_max_sentence_num: int # 回复最大句子数
enable_kaomoji_protection: bool # 是否启用颜文字保护
model_max_output_length: int # 模型最大输出长度
# remote
remote_enable: bool # 是否启用远程功能
# experimental
enable_friend_chat: bool # 是否启用好友聊天
talk_allowed_private: List[int] # 允许私聊的QQ号列表
pfc_chatting: bool # 是否启用PFC聊天
# 模型配置
llm_reasoning: Dict[str, Any] # 推理模型配置
llm_normal: Dict[str, Any] # 普通模型配置
llm_topic_judge: Dict[str, Any] # 主题判断模型配置
summary: Dict[str, Any] # 总结模型配置
vlm: Dict[str, Any] # VLM模型配置
llm_heartflow: Dict[str, Any] # 心流模型配置
llm_observation: Dict[str, Any] # 观察模型配置
llm_sub_heartflow: Dict[str, Any] # 子心流模型配置
llm_plan: Optional[Dict[str, Any]] # 计划模型配置
embedding: Dict[str, Any] # 嵌入模型配置
llm_PFC_action_planner: Optional[Dict[str, Any]] # PFC行动计划模型配置
llm_PFC_chat: Optional[Dict[str, Any]] # PFC聊天模型配置
llm_PFC_reply_checker: Optional[Dict[str, Any]] # PFC回复检查模型配置
llm_tool_use: Optional[Dict[str, Any]] # 工具使用模型配置
api_urls: Optional[Dict[str, str]] # API地址配置
@staticmethod
def validate_config(config: dict):
"""
校验传入的 toml 配置字典是否合法。
:param config: toml库load后的配置字典
:raises: ValueError, KeyError, TypeError
"""
# 检查主层级
required_sections = [
"inner",
"bot",
"groups",
"personality",
"identity",
"platforms",
"chat",
"normal_chat",
"focus_chat",
"emoji",
"memory",
"mood",
"keywords_reaction",
"chinese_typo",
"response_splitter",
"remote",
"experimental",
"model",
]
for section in required_sections:
if section not in config:
raise KeyError(f"缺少配置段: [{section}]")
# 检查部分关键字段
if "version" not in config["inner"]:
raise KeyError("缺少 inner.version 字段")
if not isinstance(config["inner"]["version"], str):
raise TypeError("inner.version 必须为字符串")
if "qq" not in config["bot"]:
raise KeyError("缺少 bot.qq 字段")
if not isinstance(config["bot"]["qq"], int):
raise TypeError("bot.qq 必须为整数")
if "personality_core" not in config["personality"]:
raise KeyError("缺少 personality.personality_core 字段")
if not isinstance(config["personality"]["personality_core"], str):
raise TypeError("personality.personality_core 必须为字符串")
if "identity_detail" not in config["identity"]:
raise KeyError("缺少 identity.identity_detail 字段")
if not isinstance(config["identity"]["identity_detail"], list):
raise TypeError("identity.identity_detail 必须为列表")
# 可继续添加更多字段的类型和值检查
# ...
# 检查模型配置
model_keys = [
"llm_reasoning",
"llm_normal",
"llm_topic_judge",
"summary",
"vlm",
"llm_heartflow",
"llm_observation",
"llm_sub_heartflow",
"embedding",
]
if "model" not in config:
raise KeyError("缺少 [model] 配置段")
for key in model_keys:
if key not in config["model"]:
raise KeyError(f"缺少 model.{key} 配置")
# 检查通过
return True
@strawberry.type
class APIEnvConfig:
"""环境变量配置"""
HOST: str # 服务主机地址
PORT: int # 服务端口
PLUGINS: List[str] # 插件列表
MONGODB_HOST: str # MongoDB 主机地址
MONGODB_PORT: int # MongoDB 端口
DATABASE_NAME: str # 数据库名称
CHAT_ANY_WHERE_BASE_URL: str # ChatAnywhere 基础URL
SILICONFLOW_BASE_URL: str # SiliconFlow 基础URL
DEEP_SEEK_BASE_URL: str # DeepSeek 基础URL
DEEP_SEEK_KEY: Optional[str] # DeepSeek API Key
CHAT_ANY_WHERE_KEY: Optional[str] # ChatAnywhere API Key
SILICONFLOW_KEY: Optional[str] # SiliconFlow API Key
SIMPLE_OUTPUT: Optional[bool] # 是否简化输出
CONSOLE_LOG_LEVEL: Optional[str] # 控制台日志等级
FILE_LOG_LEVEL: Optional[str] # 文件日志等级
DEFAULT_CONSOLE_LOG_LEVEL: Optional[str] # 默认控制台日志等级
DEFAULT_FILE_LOG_LEVEL: Optional[str] # 默认文件日志等级
@strawberry.field
def get_env(self) -> str:
return "env"
@staticmethod
def validate_config(config: dict):
"""
校验环境变量配置字典是否合法。
:param config: 环境变量配置字典
:raises: KeyError, TypeError
"""
required_fields = [
"HOST",
"PORT",
"PLUGINS",
"MONGODB_HOST",
"MONGODB_PORT",
"DATABASE_NAME",
"CHAT_ANY_WHERE_BASE_URL",
"SILICONFLOW_BASE_URL",
"DEEP_SEEK_BASE_URL",
]
for field in required_fields:
if field not in config:
raise KeyError(f"缺少环境变量配置字段: {field}")
if not isinstance(config["HOST"], str):
raise TypeError("HOST 必须为字符串")
if not isinstance(config["PORT"], int):
raise TypeError("PORT 必须为整数")
if not isinstance(config["PLUGINS"], list):
raise TypeError("PLUGINS 必须为列表")
if not isinstance(config["MONGODB_HOST"], str):
raise TypeError("MONGODB_HOST 必须为字符串")
if not isinstance(config["MONGODB_PORT"], int):
raise TypeError("MONGODB_PORT 必须为整数")
if not isinstance(config["DATABASE_NAME"], str):
raise TypeError("DATABASE_NAME 必须为字符串")
if not isinstance(config["CHAT_ANY_WHERE_BASE_URL"], str):
raise TypeError("CHAT_ANY_WHERE_BASE_URL 必须为字符串")
if not isinstance(config["SILICONFLOW_BASE_URL"], str):
raise TypeError("SILICONFLOW_BASE_URL 必须为字符串")
if not isinstance(config["DEEP_SEEK_BASE_URL"], str):
raise TypeError("DEEP_SEEK_BASE_URL 必须为字符串")
# 可选字段类型检查
optional_str_fields = [
"DEEP_SEEK_KEY",
"CHAT_ANY_WHERE_KEY",
"SILICONFLOW_KEY",
"CONSOLE_LOG_LEVEL",
"FILE_LOG_LEVEL",
"DEFAULT_CONSOLE_LOG_LEVEL",
"DEFAULT_FILE_LOG_LEVEL",
]
for field in optional_str_fields:
if field in config and config[field] is not None and not isinstance(config[field], str):
raise TypeError(f"{field} 必须为字符串或None")
if (
"SIMPLE_OUTPUT" in config
and config["SIMPLE_OUTPUT"] is not None
and not isinstance(config["SIMPLE_OUTPUT"], bool)
):
raise TypeError("SIMPLE_OUTPUT 必须为布尔值或None")
# 检查通过
return True
print("当前路径:")
print(ROOT_PATH)

View File

@@ -1,22 +0,0 @@
import strawberry
from fastapi import FastAPI
from strawberry.fastapi import GraphQLRouter
from src.common.server import get_global_server
@strawberry.type
class Query:
@strawberry.field
def hello(self) -> str:
return "Hello World"
schema = strawberry.Schema(Query)
graphql_app = GraphQLRouter(schema)
fast_api_app: FastAPI = get_global_server().get_app()
fast_api_app.include_router(graphql_app, prefix="/graphql")

View File

@@ -1 +0,0 @@
pass

View File

@@ -1,112 +0,0 @@
from fastapi import APIRouter
from strawberry.fastapi import GraphQLRouter
import os
import sys
# from src.chat.heart_flow.heartflow import heartflow
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
# from src.config.config import BotConfig
from src.common.logger import get_logger
from src.api.reload_config import reload_config as reload_config_func
from src.common.server import get_global_server
from src.api.apiforgui import (
get_all_subheartflow_ids,
forced_change_subheartflow_status,
get_subheartflow_cycle_info,
get_all_states,
)
from src.chat.heart_flow.sub_heartflow import ChatState
from src.api.basic_info_api import get_all_basic_info # 新增导入
router = APIRouter()
logger = get_logger("api")
logger.info("麦麦API服务器已启动")
graphql_router = GraphQLRouter(schema=None, path="/") # Replace `None` with your actual schema
router.include_router(graphql_router, prefix="/graphql", tags=["GraphQL"])
@router.post("/config/reload")
async def reload_config():
return await reload_config_func()
@router.get("/gui/subheartflow/get/all")
async def get_subheartflow_ids():
"""获取所有子心流的ID列表"""
return await get_all_subheartflow_ids()
@router.post("/gui/subheartflow/forced_change_status")
async def forced_change_subheartflow_status_api(subheartflow_id: str, status: ChatState): # noqa
"""强制改变子心流的状态"""
# 参数检查
if not isinstance(status, ChatState):
logger.warning(f"无效的状态参数: {status}")
return {"status": "failed", "reason": "invalid status"}
logger.info(f"尝试将子心流 {subheartflow_id} 状态更改为 {status.value}")
success = await forced_change_subheartflow_status(subheartflow_id, status)
if success:
logger.info(f"子心流 {subheartflow_id} 状态更改为 {status.value} 成功")
return {"status": "success"}
else:
logger.error(f"子心流 {subheartflow_id} 状态更改为 {status.value} 失败")
return {"status": "failed"}
@router.get("/stop")
async def force_stop_maibot():
"""强制停止MAI Bot"""
from bot import request_shutdown
success = await request_shutdown()
if success:
logger.info("MAI Bot已强制停止")
return {"status": "success"}
else:
logger.error("MAI Bot强制停止失败")
return {"status": "failed"}
@router.get("/gui/subheartflow/cycleinfo")
async def get_subheartflow_cycle_info_api(subheartflow_id: str, history_len: int):
"""获取子心流的循环信息"""
cycle_info = await get_subheartflow_cycle_info(subheartflow_id, history_len)
if cycle_info:
return {"status": "success", "data": cycle_info}
else:
logger.warning(f"子心流 {subheartflow_id} 循环信息未找到")
return {"status": "failed", "reason": "subheartflow not found"}
@router.get("/gui/get_all_states")
async def get_all_states_api():
"""获取所有状态"""
all_states = await get_all_states()
if all_states:
return {"status": "success", "data": all_states}
else:
logger.warning("获取所有状态失败")
return {"status": "failed", "reason": "failed to get all states"}
@router.get("/info")
async def get_system_basic_info():
"""获取系统基本信息"""
logger.info("请求系统基本信息")
try:
info = get_all_basic_info()
return {"status": "success", "data": info}
except Exception as e:
logger.error(f"获取系统基本信息失败: {e}")
return {"status": "failed", "reason": str(e)}
def start_api_server():
"""启动API服务器"""
get_global_server().register_router(router, prefix="/api/v1")
# pass

View File

@@ -1,24 +0,0 @@
from fastapi import HTTPException
from rich.traceback import install
from src.config.config import get_config_dir, load_config
from src.common.logger import get_logger
import os
install(extra_lines=3)
logger = get_logger("api")
async def reload_config():
try:
from src.config import config as config_module
logger.debug("正在重载配置文件...")
bot_config_path = os.path.join(get_config_dir(), "bot_config.toml")
config_module.global_config = load_config(config_path=bot_config_path)
logger.debug("配置文件重载成功")
return {"status": "reloaded"}
except FileNotFoundError as e:
raise HTTPException(status_code=404, detail=str(e)) from e
except Exception as e:
raise HTTPException(status_code=500, detail=f"重载配置时发生错误: {str(e)}") from e

View File

@@ -5,20 +5,19 @@ import os
import random
import time
import traceback
from typing import Optional, Tuple, List, Any
from PIL import Image
import io
import re
# from gradio_client import file
import binascii
from typing import Optional, Tuple, List, Any
from PIL import Image
from rich.traceback import install
from src.common.database.database_model import Emoji
from src.common.database.database import db as peewee_db
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.utils.utils_image import image_path_to_base64, get_image_manager
from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger
from rich.traceback import install
install(extra_lines=3)
@@ -26,7 +25,7 @@ logger = get_logger("emoji")
BASE_DIR = os.path.join("data")
EMOJI_DIR = os.path.join(BASE_DIR, "emoji") # 表情包存储目录
EMOJI_REGISTED_DIR = os.path.join(BASE_DIR, "emoji_registed") # 已注册的表情包注册目录
EMOJI_REGISTERED_DIR = os.path.join(BASE_DIR, "emoji_registed") # 已注册的表情包注册目录
MAX_EMOJI_FOR_PROMPT = 20 # 最大允许的表情包描述数量于图片替换的 prompt 中
"""
@@ -85,7 +84,7 @@ class MaiEmoji:
logger.debug(f"[初始化] 正在使用Pillow获取格式: {self.filename}")
try:
with Image.open(io.BytesIO(image_bytes)) as img:
self.format = img.format.lower()
self.format = img.format.lower() # type: ignore
logger.debug(f"[初始化] 格式获取成功: {self.format}")
except Exception as pil_error:
logger.error(f"[初始化错误] Pillow无法处理图片 ({self.filename}): {pil_error}")
@@ -100,7 +99,7 @@ class MaiEmoji:
logger.error(f"[初始化错误] 文件在处理过程中丢失: {self.full_path}")
self.is_deleted = True
return None
except base64.binascii.Error as b64_error:
except (binascii.Error, ValueError) as b64_error:
logger.error(f"[初始化错误] Base64解码失败 ({self.filename}): {b64_error}")
self.is_deleted = True
return None
@@ -113,7 +112,7 @@ class MaiEmoji:
async def register_to_db(self) -> bool:
"""
注册表情包
将表情包对应的文件从当前路径移动到EMOJI_REGISTED_DIR目录下
将表情包对应的文件从当前路径移动到EMOJI_REGISTERED_DIR目录下
并修改对应的实例属性,然后将表情包信息保存到数据库中
"""
try:
@@ -122,7 +121,7 @@ class MaiEmoji:
# 源路径是当前实例的完整路径 self.full_path
source_full_path = self.full_path
# 目标完整路径
destination_full_path = os.path.join(EMOJI_REGISTED_DIR, self.filename)
destination_full_path = os.path.join(EMOJI_REGISTERED_DIR, self.filename)
# 检查源文件是否存在
if not os.path.exists(source_full_path):
@@ -139,7 +138,7 @@ class MaiEmoji:
logger.debug(f"[移动] 文件从 {source_full_path} 移动到 {destination_full_path}")
# 更新实例的路径属性为新路径
self.full_path = destination_full_path
self.path = EMOJI_REGISTED_DIR
self.path = EMOJI_REGISTERED_DIR
# self.filename 保持不变
except Exception as move_error:
logger.error(f"[错误] 移动文件失败: {str(move_error)}")
@@ -202,7 +201,7 @@ class MaiEmoji:
try:
will_delete_emoji = Emoji.get(Emoji.emoji_hash == self.hash)
result = will_delete_emoji.delete_instance() # Returns the number of rows deleted.
except Emoji.DoesNotExist:
except Emoji.DoesNotExist: # type: ignore
logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。")
result = 0 # Indicate no DB record was deleted
@@ -298,7 +297,7 @@ def _to_emoji_objects(data: Any) -> Tuple[List["MaiEmoji"], int]:
def _ensure_emoji_dir() -> None:
"""确保表情存储目录存在"""
os.makedirs(EMOJI_DIR, exist_ok=True)
os.makedirs(EMOJI_REGISTED_DIR, exist_ok=True)
os.makedirs(EMOJI_REGISTERED_DIR, exist_ok=True)
async def clear_temp_emoji() -> None:
@@ -331,10 +330,10 @@ async def clean_unused_emojis(emoji_dir: str, emoji_objects: List["MaiEmoji"], r
logger.warning(f"[清理] 目标目录不存在,跳过清理: {emoji_dir}")
return removed_count
cleaned_count = 0
try:
# 获取内存中所有有效表情包的完整路径集合
tracked_full_paths = {emoji.full_path for emoji in emoji_objects if not emoji.is_deleted}
cleaned_count = 0
# 遍历指定目录中的所有文件
for file_name in os.listdir(emoji_dir):
@@ -358,11 +357,11 @@ async def clean_unused_emojis(emoji_dir: str, emoji_objects: List["MaiEmoji"], r
else:
logger.info(f"[清理] 目录 {emoji_dir} 中没有需要清理的。")
return removed_count + cleaned_count
except Exception as e:
logger.error(f"[错误] 清理未使用表情包文件时出错 ({emoji_dir}): {str(e)}")
return removed_count + cleaned_count
class EmojiManager:
_instance = None
@@ -414,7 +413,7 @@ class EmojiManager:
emoji_update.usage_count += 1
emoji_update.last_used_time = time.time() # Update last used time
emoji_update.save() # Persist changes to DB
except Emoji.DoesNotExist:
except Emoji.DoesNotExist: # type: ignore
logger.error(f"记录表情使用失败: 未找到 hash 为 {emoji_hash} 的表情包")
except Exception as e:
logger.error(f"记录表情使用失败: {str(e)}")
@@ -570,8 +569,8 @@ class EmojiManager:
if objects_to_remove:
self.emoji_objects = [e for e in self.emoji_objects if e not in objects_to_remove]
# 清理 EMOJI_REGISTED_DIR 目录中未被追踪的文件
removed_count = await clean_unused_emojis(EMOJI_REGISTED_DIR, self.emoji_objects, removed_count)
# 清理 EMOJI_REGISTERED_DIR 目录中未被追踪的文件
removed_count = await clean_unused_emojis(EMOJI_REGISTERED_DIR, self.emoji_objects, removed_count)
# 输出清理结果
if removed_count > 0:
@@ -850,11 +849,13 @@ class EmojiManager:
if isinstance(image_base64, str):
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
image_bytes = base64.b64decode(image_base64)
image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
# 调用AI获取描述
if image_format == "gif" or image_format == "GIF":
image_base64 = get_image_manager().transform_gif(image_base64)
image_base64 = get_image_manager().transform_gif(image_base64) # type: ignore
if not image_base64:
raise RuntimeError("GIF表情包转换失败")
prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, "jpg")
else:

View File

@@ -1,14 +1,16 @@
from .exprssion_learner import get_expression_learner
import random
from typing import List, Dict, Tuple
from json_repair import repair_json
import json
import os
import time
import random
from typing import List, Dict, Tuple, Optional
from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.common.logger import get_logger
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from .exprssion_learner import get_expression_learner
logger = get_logger("expression_selector")
@@ -165,7 +167,12 @@ class ExpressionSelector:
logger.error(f"批量更新表达方式count失败 for {file_path}: {e}")
async def select_suitable_expressions_llm(
self, chat_id: str, chat_info: str, max_num: int = 10, min_num: int = 5, target_message: str = None
self,
chat_id: str,
chat_info: str,
max_num: int = 10,
min_num: int = 5,
target_message: Optional[str] = None,
) -> List[Dict[str, str]]:
"""使用LLM选择适合的表达方式"""

View File

@@ -1,14 +1,16 @@
import time
import random
import json
import os
from typing import List, Dict, Optional, Any, Tuple
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_random, build_anonymous_messages
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
import os
from src.chat.message_receive.chat_stream import get_chat_manager
import json
MAX_EXPRESSION_COUNT = 300
@@ -74,7 +76,8 @@ class ExpressionLearner:
)
self.llm_model = None
def get_expression_by_chat_id(self, chat_id: str) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]:
def get_expression_by_chat_id(self, chat_id: str) -> Tuple[List[Dict[str, float]], List[Dict[str, float]]]:
# sourcery skip: extract-duplicate-method, remove-unnecessary-cast
"""
获取指定chat_id的style和grammar表达方式
返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作
@@ -119,10 +122,10 @@ class ExpressionLearner:
min_len = min(len(s1), len(s2))
if min_len < 5:
return False
same = sum(1 for a, b in zip(s1, s2) if a == b)
same = sum(a == b for a, b in zip(s1, s2))
return same / min_len > 0.8
async def learn_and_store_expression(self) -> List[Tuple[str, str, str]]:
async def learn_and_store_expression(self) -> Tuple[List[Tuple[str, str, str]], List[Tuple[str, str, str]]]:
"""
学习并存储表达方式,分别学习语言风格和句法特点
同时对所有已存储的表达方式进行全局衰减
@@ -158,12 +161,12 @@ class ExpressionLearner:
for _ in range(3):
learnt_style: Optional[List[Tuple[str, str, str]]] = await self.learn_and_store(type="style", num=25)
if not learnt_style:
return []
return [], []
for _ in range(1):
learnt_grammar: Optional[List[Tuple[str, str, str]]] = await self.learn_and_store(type="grammar", num=10)
if not learnt_grammar:
return []
return [], []
return learnt_style, learnt_grammar
@@ -214,6 +217,7 @@ class ExpressionLearner:
return result
async def learn_and_store(self, type: str, num: int = 10) -> List[Tuple[str, str, str]]:
# sourcery skip: use-join
"""
选择从当前到最近1小时内的随机num条消息然后学习这些消息的表达方式
type: "style" or "grammar"
@@ -249,7 +253,7 @@ class ExpressionLearner:
return []
# 按chat_id分组
chat_dict: Dict[str, List[Dict[str, str]]] = {}
chat_dict: Dict[str, List[Dict[str, Any]]] = {}
for chat_id, situation, style in learnt_expressions:
if chat_id not in chat_dict:
chat_dict[chat_id] = []

View File

@@ -1,10 +1,10 @@
# 定义了来自外部世界的信息
# 外部世界可以是某个聊天 不同平台的聊天 也可以是任意媒体
from datetime import datetime
from typing import List
from src.common.logger import get_logger
from src.chat.focus_chat.hfc_utils import CycleDetail
from typing import List
# Import the new utility function
logger = get_logger("loop_info")

View File

@@ -8,7 +8,7 @@ from rich.traceback import install
from src.config.config import global_config
from src.common.logger import get_logger
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.utils.prompt_builder import global_prompt_manager
from src.chat.utils.timer_calculator import Timer
from src.chat.planner_actions.planner import ActionPlanner
@@ -49,7 +49,9 @@ class HeartFChatting:
"""
# 基础属性
self.stream_id: str = chat_id # 聊天流ID
self.chat_stream = get_chat_manager().get_stream(self.stream_id)
self.chat_stream: ChatStream = get_chat_manager().get_stream(self.stream_id) # type: ignore
if not self.chat_stream:
raise ValueError(f"无法找到聊天流: {self.stream_id}")
self.log_prefix = f"[{get_chat_manager().get_stream_name(self.stream_id) or self.stream_id}]"
self.relationship_builder = relationship_builder_manager.get_or_create_builder(self.stream_id)
@@ -171,7 +173,7 @@ class HeartFChatting:
# 执行规划和处理阶段
try:
async with self._get_cycle_context():
thinking_id = "tid" + str(round(time.time(), 2))
thinking_id = f"tid{str(round(time.time(), 2))}"
self._current_cycle_detail.set_thinking_id(thinking_id)
# 使用异步上下文管理器处理消息
@@ -245,7 +247,7 @@ class HeartFChatting:
logger.info(
f"{self.log_prefix}{self._current_cycle_detail.cycle_id}次思考,"
f"耗时: {self._current_cycle_detail.end_time - self._current_cycle_detail.start_time:.1f}秒, "
f"耗时: {self._current_cycle_detail.end_time - self._current_cycle_detail.start_time:.1f}秒, " # type: ignore
f"选择动作: {self._current_cycle_detail.loop_plan_info.get('action_result', {}).get('action_type', '未知动作')}"
+ (f"\n详情: {'; '.join(timer_strings)}" if timer_strings else "")
)
@@ -256,7 +258,7 @@ class HeartFChatting:
cycle_performance_data = {
"cycle_id": self._current_cycle_detail.cycle_id,
"action_type": action_result.get("action_type", "unknown"),
"total_time": self._current_cycle_detail.end_time - self._current_cycle_detail.start_time,
"total_time": self._current_cycle_detail.end_time - self._current_cycle_detail.start_time, # type: ignore
"step_times": cycle_timers.copy(),
"reasoning": action_result.get("reasoning", ""),
"success": self._current_cycle_detail.loop_action_info.get("action_taken", False),
@@ -447,11 +449,8 @@ class HeartFChatting:
# 处理动作并获取结果
result = await action_handler.handle_action()
if len(result) == 3:
success, reply_text, command = result
else:
success, reply_text = result
command = ""
success, reply_text = result
command = ""
# 检查action_data中是否有系统命令优先使用系统命令
if "_system_command" in action_data:
@@ -478,15 +477,14 @@ class HeartFChatting:
)
# 设置系统命令,在下次循环检查时触发退出
command = "stop_focus_chat"
else:
if reply_text == "timeout":
self.reply_timeout_count += 1
if self.reply_timeout_count > 5:
logger.warning(
f"[{self.log_prefix} ] 连续回复超时次数过多,{global_config.chat.thinking_timeout}秒 内大模型没有返回有效内容请检查你的api是否速度过慢或配置错误。建议不要使用推理模型推理模型生成速度过慢。或者尝试拉高thinking_timeout参数这可能导致回复时间过长。"
)
logger.warning(f"{self.log_prefix} 回复生成超时{global_config.chat.thinking_timeout}s已跳过")
return False, "", ""
elif reply_text == "timeout":
self.reply_timeout_count += 1
if self.reply_timeout_count > 5:
logger.warning(
f"[{self.log_prefix} ] 连续回复超时次数过多,{global_config.chat.thinking_timeout}秒 内大模型没有返回有效内容请检查你的api是否速度过慢或配置错误。建议不要使用推理模型推理模型生成速度过慢。或者尝试拉高thinking_timeout参数这可能导致回复时间过长。"
)
logger.warning(f"{self.log_prefix} 回复生成超时{global_config.chat.thinking_timeout}s已跳过")
return False, "", ""
return success, reply_text, command

View File

@@ -2,6 +2,7 @@ import json
from datetime import datetime
from typing import Dict, Any
from pathlib import Path
from src.common.logger import get_logger
logger = get_logger("hfc_performance")

View File

@@ -1,11 +1,12 @@
import time
from typing import Optional
import json
from typing import Optional, Dict, Any
from src.chat.message_receive.message import MessageRecv, BaseMessageInfo
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.message_receive.message import UserInfo
from src.common.logger import get_logger
import json
from typing import Dict, Any
logger = get_logger(__name__)
@@ -117,7 +118,7 @@ async def create_empty_anchor_message(
placeholder_msg_info = BaseMessageInfo(
message_id=placeholder_id,
platform=platform,
group_info=group_info,
group_info=group_info, # type: ignore
user_info=placeholder_user,
time=time.time(),
)

View File

@@ -1,7 +1,7 @@
from src.chat.heart_flow.sub_heartflow import SubHeartflow, ChatState
from typing import Any, Optional, Dict
from src.common.logger import get_logger
from typing import Any, Optional
from typing import Dict
from src.chat.heart_flow.sub_heartflow import SubHeartflow, ChatState
from src.chat.message_receive.chat_stream import get_chat_manager
logger = get_logger("heartflow")
@@ -34,7 +34,7 @@ class Heartflow:
logger.error(f"创建子心流 {subheartflow_id} 失败: {e}", exc_info=True)
return None
async def force_change_subheartflow_status(self, subheartflow_id: str, status: ChatState) -> None:
async def force_change_subheartflow_status(self, subheartflow_id: str, status: ChatState) -> bool:
"""强制改变子心流的状态"""
# 这里的 message 是可选的,可能是一个消息对象,也可能是其他类型的数据
return await self.force_change_state(subheartflow_id, status)

View File

@@ -1,21 +1,21 @@
from src.chat.memory_system.Hippocampus import hippocampus_manager
from src.config.config import global_config
import asyncio
import re
import math
import traceback
from typing import Tuple
from src.config.config import global_config
from src.chat.memory_system.Hippocampus import hippocampus_manager
from src.chat.message_receive.message import MessageRecv
from src.chat.message_receive.storage import MessageStorage
from src.chat.heart_flow.heartflow import heartflow
from src.chat.utils.utils import is_mentioned_bot_in_message
from src.chat.utils.timer_calculator import Timer
from src.common.logger import get_logger
import re
import math
import traceback
from typing import Tuple
from src.person_info.relationship_manager import get_relationship_manager
from src.mood.mood_manager import mood_manager
logger = get_logger("chat")
@@ -26,16 +26,16 @@ async def _process_relationship(message: MessageRecv) -> None:
message: 消息对象,包含用户信息
"""
platform = message.message_info.platform
user_id = message.message_info.user_info.user_id
nickname = message.message_info.user_info.user_nickname
cardname = message.message_info.user_info.user_cardname or nickname
user_id = message.message_info.user_info.user_id # type: ignore
nickname = message.message_info.user_info.user_nickname # type: ignore
cardname = message.message_info.user_info.user_cardname or nickname # type: ignore
relationship_manager = get_relationship_manager()
is_known = await relationship_manager.is_known_some_one(platform, user_id)
if not is_known:
logger.info(f"首次认识用户: {nickname}")
await relationship_manager.first_knowing_some_one(platform, user_id, nickname, cardname)
await relationship_manager.first_knowing_some_one(platform, user_id, nickname, cardname) # type: ignore
async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]:
@@ -105,9 +105,9 @@ class HeartFCMessageReceiver:
# 2. 兴趣度计算与更新
interested_rate, is_mentioned = await _calculate_interest(message)
subheartflow.add_message_to_normal_chat_cache(message, interested_rate, is_mentioned)
subheartflow.add_message_to_normal_chat_cache(message, interested_rate, is_mentioned) # type: ignore
chat_mood = mood_manager.get_mood_by_chat_id(subheartflow.chat_id)
chat_mood = mood_manager.get_mood_by_chat_id(subheartflow.chat_id) # type: ignore
asyncio.create_task(chat_mood.update_mood_by_message(message, interested_rate))
# 3. 日志记录
@@ -119,7 +119,7 @@ class HeartFCMessageReceiver:
picid_pattern = r"\[picid:([^\]]+)\]"
processed_plain_text = re.sub(picid_pattern, "[图片]", message.processed_plain_text)
logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}")
logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}") # type: ignore
logger.debug(f"[{mes_name}][当前时段回复频率: {current_talk_frequency}]")

View File

@@ -1,16 +1,18 @@
import asyncio
import time
from typing import Optional, List, Dict, Tuple
import traceback
from typing import Optional, List, Dict, Tuple
from rich.traceback import install
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.message_receive.message import MessageRecv
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.focus_chat.heartFC_chat import HeartFChatting
from src.chat.normal_chat.normal_chat import NormalChat
from src.chat.heart_flow.chat_state_info import ChatState, ChatStateInfo
from src.chat.utils.utils import get_chat_type_and_target_info
from src.config.config import global_config
from rich.traceback import install
logger = get_logger("sub_heartflow")
@@ -40,7 +42,7 @@ class SubHeartflow:
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_id)
self.log_prefix = get_chat_manager().get_stream_name(self.subheartflow_id) or self.subheartflow_id
# 兴趣消息集合
self.interest_dict: Dict[str, tuple[MessageRecv, float, bool]] = {}
self.interest_dict: Dict[str, Tuple[MessageRecv, float, bool]] = {}
# focus模式退出冷却时间管理
self.last_focus_exit_time: float = 0 # 上次退出focus模式的时间
@@ -297,7 +299,7 @@ class SubHeartflow:
)
def add_message_to_normal_chat_cache(self, message: MessageRecv, interest_value: float, is_mentioned: bool):
self.interest_dict[message.message_info.message_id] = (message, interest_value, is_mentioned)
self.interest_dict[message.message_info.message_id] = (message, interest_value, is_mentioned) # type: ignore
# 如果字典长度超过10删除最旧的消息
if len(self.interest_dict) > 30:
oldest_key = next(iter(self.interest_dict))

View File

@@ -42,7 +42,7 @@ def calculate_information_content(text):
return entropy
def cosine_similarity(v1, v2):
def cosine_similarity(v1, v2): # sourcery skip: assign-if-exp, reintroduce-else
"""计算余弦相似度"""
dot_product = np.dot(v1, v2)
norm1 = np.linalg.norm(v1)
@@ -89,14 +89,13 @@ class MemoryGraph:
if not isinstance(self.G.nodes[concept]["memory_items"], list):
self.G.nodes[concept]["memory_items"] = [self.G.nodes[concept]["memory_items"]]
self.G.nodes[concept]["memory_items"].append(memory)
# 更新最后修改时间
self.G.nodes[concept]["last_modified"] = current_time
else:
self.G.nodes[concept]["memory_items"] = [memory]
# 如果节点存在但没有memory_items,说明是第一次添加memory,设置created_time
if "created_time" not in self.G.nodes[concept]:
self.G.nodes[concept]["created_time"] = current_time
self.G.nodes[concept]["last_modified"] = current_time
# 更新最后修改时间
self.G.nodes[concept]["last_modified"] = current_time
else:
# 如果是新节点,创建新的记忆列表
self.G.add_node(
@@ -108,11 +107,7 @@ class MemoryGraph:
def get_dot(self, concept):
# 检查节点是否存在于图中
if concept in self.G:
# 从图中获取节点数据
node_data = self.G.nodes[concept]
return concept, node_data
return None
return (concept, self.G.nodes[concept]) if concept in self.G else None
def get_related_item(self, topic, depth=1):
if topic not in self.G:
@@ -139,8 +134,7 @@ class MemoryGraph:
if depth >= 2:
# 获取相邻节点的记忆项
for neighbor in neighbors:
node_data = self.get_dot(neighbor)
if node_data:
if node_data := self.get_dot(neighbor):
concept, data = node_data
if "memory_items" in data:
memory_items = data["memory_items"]
@@ -194,9 +188,9 @@ class MemoryGraph:
class Hippocampus:
def __init__(self):
self.memory_graph = MemoryGraph()
self.model_summary = None
self.entorhinal_cortex = None
self.parahippocampal_gyrus = None
self.model_summary: LLMRequest = None # type: ignore
self.entorhinal_cortex: EntorhinalCortex = None # type: ignore
self.parahippocampal_gyrus: ParahippocampalGyrus = None # type: ignore
def initialize(self):
# 初始化子组件
@@ -218,7 +212,7 @@ class Hippocampus:
memory_items = [memory_items] if memory_items else []
# 使用集合来去重,避免排序
unique_items = set(str(item) for item in memory_items)
unique_items = {str(item) for item in memory_items}
# 使用frozenset来保证顺序一致性
content = f"{concept}:{frozenset(unique_items)}"
return hash(content)
@@ -231,6 +225,7 @@ class Hippocampus:
@staticmethod
def find_topic_llm(text, topic_num):
# sourcery skip: inline-immediately-returned-variable
prompt = (
f"这是一段文字:\n{text}\n\n请你从这段话中总结出最多{topic_num}个关键的概念,可以是名词,动词,或者特定人物,帮我列出来,"
f"将主题用逗号隔开,并加上<>,例如<主题1>,<主题2>......尽可能精简。只需要列举最多{topic_num}个话题就好,不要有序号,不要告诉我其他内容。"
@@ -240,6 +235,7 @@ class Hippocampus:
@staticmethod
def topic_what(text, topic):
# sourcery skip: inline-immediately-returned-variable
# 不再需要 time_info 参数
prompt = (
f'这是一段文字:\n{text}\n\n我想让你基于这段文字来概括"{topic}"这个概念,帮我总结成一句自然的话,'
@@ -480,9 +476,7 @@ class Hippocampus:
top_memories = memory_similarities[:max_memory_length]
# 添加到结果中
for memory, similarity in top_memories:
all_memories.append((node, [memory], similarity))
# logger.info(f"选中记忆: {memory} (相似度: {similarity:.2f})")
all_memories.extend((node, [memory], similarity) for memory, similarity in top_memories)
else:
logger.info("节点没有记忆")
@@ -646,9 +640,7 @@ class Hippocampus:
top_memories = memory_similarities[:max_memory_length]
# 添加到结果中
for memory, similarity in top_memories:
all_memories.append((node, [memory], similarity))
# logger.info(f"选中记忆: {memory} (相似度: {similarity:.2f})")
all_memories.extend((node, [memory], similarity) for memory, similarity in top_memories)
else:
logger.info("节点没有记忆")
@@ -823,11 +815,11 @@ class EntorhinalCortex:
logger.debug(f"回忆往事: {readable_timestamp}")
chat_samples = []
for timestamp in timestamps:
# 调用修改后的 random_get_msg_snippet
messages = self.random_get_msg_snippet(
timestamp, global_config.memory.memory_build_sample_length, max_memorized_time_per_msg
)
if messages:
if messages := self.random_get_msg_snippet(
timestamp,
global_config.memory.memory_build_sample_length,
max_memorized_time_per_msg,
):
time_diff = (datetime.datetime.now().timestamp() - timestamp) / 3600
logger.info(f"成功抽取 {time_diff:.1f} 小时前的消息样本,共{len(messages)}")
chat_samples.append(messages)
@@ -838,6 +830,7 @@ class EntorhinalCortex:
@staticmethod
def random_get_msg_snippet(target_timestamp: float, chat_size: int, max_memorized_time_per_msg: int) -> list | None:
# sourcery skip: invert-any-all, use-any, use-named-expression, use-next
"""从数据库中随机获取指定时间戳附近的消息片段 (使用 chat_message_builder)"""
try_count = 0
time_window_seconds = random.randint(300, 1800) # 随机时间窗口5到30分钟
@@ -847,22 +840,21 @@ class EntorhinalCortex:
timestamp_start = target_timestamp
timestamp_end = target_timestamp + time_window_seconds
chosen_message = get_raw_msg_by_timestamp(
timestamp_start=timestamp_start, timestamp_end=timestamp_end, limit=1, limit_mode="earliest"
)
if chosen_message := get_raw_msg_by_timestamp(
timestamp_start=timestamp_start,
timestamp_end=timestamp_end,
limit=1,
limit_mode="earliest",
):
chat_id: str = chosen_message[0].get("chat_id") # type: ignore
if chosen_message:
chat_id = chosen_message[0].get("chat_id")
messages = get_raw_msg_by_timestamp_with_chat(
if messages := get_raw_msg_by_timestamp_with_chat(
timestamp_start=timestamp_start,
timestamp_end=timestamp_end,
limit=chat_size,
limit_mode="earliest",
chat_id=chat_id,
)
if messages:
):
# 检查获取到的所有消息是否都未达到最大记忆次数
all_valid = True
for message in messages:
@@ -975,7 +967,7 @@ class EntorhinalCortex:
).execute()
if nodes_to_delete:
GraphNodes.delete().where(GraphNodes.concept.in_(nodes_to_delete)).execute()
GraphNodes.delete().where(GraphNodes.concept.in_(nodes_to_delete)).execute() # type: ignore
# 处理边的信息
db_edges = list(GraphEdges.select())
@@ -1114,7 +1106,7 @@ class EntorhinalCortex:
node_start = time.time()
if nodes_data:
batch_size = 500 # 增加批量大小
with GraphNodes._meta.database.atomic():
with GraphNodes._meta.database.atomic(): # type: ignore
for i in range(0, len(nodes_data), batch_size):
batch = nodes_data[i : i + batch_size]
GraphNodes.insert_many(batch).execute()
@@ -1125,7 +1117,7 @@ class EntorhinalCortex:
edge_start = time.time()
if edges_data:
batch_size = 500 # 增加批量大小
with GraphEdges._meta.database.atomic():
with GraphEdges._meta.database.atomic(): # type: ignore
for i in range(0, len(edges_data), batch_size):
batch = edges_data[i : i + batch_size]
GraphEdges.insert_many(batch).execute()
@@ -1489,32 +1481,30 @@ class ParahippocampalGyrus:
# --- 如果节点不为空,则执行原来的不活跃检查和随机移除逻辑 ---
last_modified = node_data.get("last_modified", current_time)
# 条件1检查是否长时间未修改 (超过24小时)
if current_time - last_modified > 3600 * 24:
# 条件2再次确认节点包含记忆项理论上已确认但作为保险
if memory_items:
current_count = len(memory_items)
# 如果列表非空,才进行随机选择
if current_count > 0:
removed_item = random.choice(memory_items)
try:
memory_items.remove(removed_item)
if current_time - last_modified > 3600 * 24 and memory_items:
current_count = len(memory_items)
# 如果列表非空,才进行随机选择
if current_count > 0:
removed_item = random.choice(memory_items)
try:
memory_items.remove(removed_item)
# 条件3检查移除后 memory_items 是否变空
if memory_items: # 如果移除后列表不为空
# self.memory_graph.G.nodes[node]["memory_items"] = memory_items # 直接修改列表即可
self.memory_graph.G.nodes[node]["last_modified"] = current_time # 更新修改时间
node_changes["reduced"].append(f"{node} (数量: {current_count} -> {len(memory_items)})")
else: # 如果移除后列表为空
# 尝试移除节点,处理可能的错误
try:
self.memory_graph.G.remove_node(node)
node_changes["removed"].append(f"{node}(遗忘清空)") # 标记为遗忘清空
logger.debug(f"[遗忘] 节点 {node} 因移除最后一项而被清空。")
except nx.NetworkXError as e:
logger.warning(f"[遗忘] 尝试移除节点 {node} 时发生错误(可能已被移除):{e}")
except ValueError:
# 这个错误理论上不应发生,因为 removed_item 来自 memory_items
logger.warning(f"[遗忘] 尝试从节点 '{node}' 移除不存在的项目 '{removed_item[:30]}...'")
# 条件3检查移除后 memory_items 是否变空
if memory_items: # 如果移除后列表不为空
# self.memory_graph.G.nodes[node]["memory_items"] = memory_items # 直接修改列表即可
self.memory_graph.G.nodes[node]["last_modified"] = current_time # 更新修改时间
node_changes["reduced"].append(f"{node} (数量: {current_count} -> {len(memory_items)})")
else: # 如果移除后列表为空
# 尝试移除节点,处理可能的错误
try:
self.memory_graph.G.remove_node(node)
node_changes["removed"].append(f"{node}(遗忘清空)") # 标记为遗忘清空
logger.debug(f"[遗忘] 节点 {node} 因移除最后一项而被清空。")
except nx.NetworkXError as e:
logger.warning(f"[遗忘] 尝试移除节点 {node} 时发生错误(可能已被移除):{e}")
except ValueError:
# 这个错误理论上不应发生,因为 removed_item 来自 memory_items
logger.warning(f"[遗忘] 尝试从节点 '{node}' 移除不存在的项目 '{removed_item[:30]}...'")
node_check_end = time.time()
logger.info(f"[遗忘] 节点检查耗时: {node_check_end - node_check_start:.2f}")
@@ -1669,7 +1659,7 @@ class ParahippocampalGyrus:
class HippocampusManager:
def __init__(self):
self._hippocampus = None
self._hippocampus: Hippocampus = None # type: ignore
self._initialized = False
def initialize(self):

View File

@@ -13,7 +13,7 @@ from json_repair import repair_json
logger = get_logger("memory_activator")
def get_keywords_from_json(json_str):
def get_keywords_from_json(json_str) -> List:
"""
从JSON字符串中提取关键词列表
@@ -28,15 +28,8 @@ def get_keywords_from_json(json_str):
fixed_json = repair_json(json_str)
# 如果repair_json返回的是字符串需要解析为Python对象
if isinstance(fixed_json, str):
result = json.loads(fixed_json)
else:
# 如果repair_json直接返回了字典对象直接使用
result = fixed_json
# 提取关键词
keywords = result.get("keywords", [])
return keywords
result = json.loads(fixed_json) if isinstance(fixed_json, str) else fixed_json
return result.get("keywords", [])
except Exception as e:
logger.error(f"解析关键词JSON失败: {e}")
return []

View File

@@ -1,52 +1,10 @@
import numpy as np
from scipy import stats
from datetime import datetime, timedelta
from rich.traceback import install
install(extra_lines=3)
class DistributionVisualizer:
def __init__(self, mean=0, std=1, skewness=0, sample_size=10):
"""
初始化分布可视化器
参数:
mean (float): 期望均值
std (float): 标准差
skewness (float): 偏度
sample_size (int): 样本大小
"""
self.mean = mean
self.std = std
self.skewness = skewness
self.sample_size = sample_size
self.samples = None
def generate_samples(self):
"""生成具有指定参数的样本"""
if self.skewness == 0:
# 对于无偏度的情况,直接使用正态分布
self.samples = np.random.normal(loc=self.mean, scale=self.std, size=self.sample_size)
else:
# 使用 scipy.stats 生成具有偏度的分布
self.samples = stats.skewnorm.rvs(a=self.skewness, loc=self.mean, scale=self.std, size=self.sample_size)
def get_weighted_samples(self):
"""获取加权后的样本数列"""
if self.samples is None:
self.generate_samples()
# 将样本值乘以样本大小
return self.samples * self.sample_size
def get_statistics(self):
"""获取分布的统计信息"""
if self.samples is None:
self.generate_samples()
return {"均值": np.mean(self.samples), "标准差": np.std(self.samples), "实际偏度": stats.skew(self.samples)}
class MemoryBuildScheduler:
def __init__(self, n_hours1, std_hours1, weight1, n_hours2, std_hours2, weight2, total_samples=50):
"""

View File

@@ -1,23 +1,25 @@
import traceback
import os
import re
from typing import Dict, Any
from maim_message import UserInfo
from src.common.logger import get_logger
from src.config.config import global_config
from src.mood.mood_manager import mood_manager # 导入情绪管理器
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream
from src.chat.message_receive.message import MessageRecv
from src.experimental.only_message_process import MessageProcessor
from src.chat.message_receive.storage import MessageStorage
from src.experimental.PFC.pfc_manager import PFCManager
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.config.config import global_config
from src.experimental.only_message_process import MessageProcessor
from src.experimental.PFC.pfc_manager import PFCManager
from src.plugin_system.core.component_registry import component_registry # 导入新插件系统
from src.plugin_system.base.base_command import BaseCommand
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
from maim_message import UserInfo
from src.chat.message_receive.chat_stream import ChatStream
import re
# 定义日志配置
# 获取项目根目录假设本文件在src/chat/message_receive/下,根目录为上上上级目录)
@@ -184,8 +186,8 @@ class ChatBot:
get_chat_manager().register_message(message)
chat = await get_chat_manager().get_or_create_stream(
platform=message.message_info.platform,
user_info=user_info,
platform=message.message_info.platform, # type: ignore
user_info=user_info, # type: ignore
group_info=group_info,
)
@@ -195,8 +197,10 @@ class ChatBot:
await message.process()
# 过滤检查
if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex(
message.raw_message, chat, user_info
if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore
message.raw_message, # type: ignore
chat,
user_info, # type: ignore
):
return

View File

@@ -3,18 +3,17 @@ import hashlib
import time
import copy
from typing import Dict, Optional, TYPE_CHECKING
from ...common.database.database import db
from ...common.database.database_model import ChatStreams # 新增导入
from rich.traceback import install
from maim_message import GroupInfo, UserInfo
from src.common.logger import get_logger
from src.common.database.database import db
from src.common.database.database_model import ChatStreams # 新增导入
# 避免循环导入使用TYPE_CHECKING进行类型提示
if TYPE_CHECKING:
from .message import MessageRecv
from src.common.logger import get_logger
from rich.traceback import install
install(extra_lines=3)
@@ -28,7 +27,7 @@ class ChatMessageContext:
def __init__(self, message: "MessageRecv"):
self.message = message
def get_template_name(self) -> str:
def get_template_name(self) -> Optional[str]:
"""获取模板名称"""
if self.message.message_info.template_info and not self.message.message_info.template_info.template_default:
return self.message.message_info.template_info.template_name
@@ -41,10 +40,10 @@ class ChatMessageContext:
def check_types(self, types: list) -> bool:
# sourcery skip: invert-any-all, use-any, use-next
"""检查消息类型"""
if not self.message.message_info.format_info.accept_format:
if not self.message.message_info.format_info.accept_format: # type: ignore
return False
for t in types:
if t not in self.message.message_info.format_info.accept_format:
if t not in self.message.message_info.format_info.accept_format: # type: ignore
return False
return True
@@ -68,7 +67,7 @@ class ChatStream:
platform: str,
user_info: UserInfo,
group_info: Optional[GroupInfo] = None,
data: dict = None,
data: Optional[dict] = None,
):
self.stream_id = stream_id
self.platform = platform
@@ -77,7 +76,7 @@ class ChatStream:
self.create_time = data.get("create_time", time.time()) if data else time.time()
self.last_active_time = data.get("last_active_time", self.create_time) if data else self.create_time
self.saved = False
self.context: ChatMessageContext = None # 用于存储该聊天的上下文信息
self.context: ChatMessageContext = None # type: ignore # 用于存储该聊天的上下文信息
def to_dict(self) -> dict:
"""转换为字典格式"""
@@ -99,7 +98,7 @@ class ChatStream:
return cls(
stream_id=data["stream_id"],
platform=data["platform"],
user_info=user_info,
user_info=user_info, # type: ignore
group_info=group_info,
data=data,
)
@@ -163,8 +162,8 @@ class ChatManager:
def register_message(self, message: "MessageRecv"):
"""注册消息到聊天流"""
stream_id = self._generate_stream_id(
message.message_info.platform,
message.message_info.user_info,
message.message_info.platform, # type: ignore
message.message_info.user_info, # type: ignore
message.message_info.group_info,
)
self.last_messages[stream_id] = message
@@ -185,10 +184,7 @@ class ChatManager:
def get_stream_id(self, platform: str, id: str, is_group: bool = True) -> str:
"""获取聊天流ID"""
if is_group:
components = [platform, str(id)]
else:
components = [platform, str(id), "private"]
components = [platform, id] if is_group else [platform, id, "private"]
key = "_".join(components)
return hashlib.md5(key.encode()).hexdigest()

View File

@@ -1,17 +1,15 @@
import time
from abc import abstractmethod
from dataclasses import dataclass
from typing import Optional, Any, TYPE_CHECKING
import urllib3
from src.common.logger import get_logger
if TYPE_CHECKING:
from .chat_stream import ChatStream
from ..utils.utils_image import get_image_manager
from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
from abc import abstractmethod
from dataclasses import dataclass
from rich.traceback import install
from typing import Optional, Any
from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
from src.common.logger import get_logger
from src.chat.utils.utils_image import get_image_manager
from .chat_stream import ChatStream
install(extra_lines=3)
@@ -27,7 +25,7 @@ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
@dataclass
class Message(MessageBase):
chat_stream: "ChatStream" = None
chat_stream: "ChatStream" = None # type: ignore
reply: Optional["Message"] = None
processed_plain_text: str = ""
memorized_times: int = 0
@@ -55,7 +53,7 @@ class Message(MessageBase):
)
# 调用父类初始化
super().__init__(message_info=message_info, message_segment=message_segment, raw_message=None)
super().__init__(message_info=message_info, message_segment=message_segment, raw_message=None) # type: ignore
self.chat_stream = chat_stream
# 文本处理相关属性
@@ -66,6 +64,7 @@ class Message(MessageBase):
self.reply = reply
async def _process_message_segments(self, segment: Seg) -> str:
# sourcery skip: remove-unnecessary-else, swap-if-else-branches
"""递归处理消息段,转换为文字描述
Args:
@@ -78,13 +77,13 @@ class Message(MessageBase):
# 处理消息段列表
segments_text = []
for seg in segment.data:
processed = await self._process_message_segments(seg)
processed = await self._process_message_segments(seg) # type: ignore
if processed:
segments_text.append(processed)
return " ".join(segments_text)
else:
# 处理单个消息段
return await self._process_single_segment(segment)
return await self._process_single_segment(segment) # type: ignore
@abstractmethod
async def _process_single_segment(self, segment):
@@ -138,7 +137,7 @@ class MessageRecv(Message):
if segment.type == "text":
self.is_picid = False
self.is_emoji = False
return segment.data
return segment.data # type: ignore
elif segment.type == "image":
# 如果是base64图片数据
if isinstance(segment.data, str):
@@ -160,7 +159,7 @@ class MessageRecv(Message):
elif segment.type == "mention_bot":
self.is_picid = False
self.is_emoji = False
self.is_mentioned = float(segment.data)
self.is_mentioned = float(segment.data) # type: ignore
return ""
elif segment.type == "priority_info":
self.is_picid = False
@@ -186,7 +185,7 @@ class MessageRecv(Message):
"""生成详细文本,包含时间和用户信息"""
timestamp = self.message_info.time
user_info = self.message_info.user_info
name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>"
name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>" # type: ignore
return f"[{timestamp}] {name}: {self.processed_plain_text}\n"
@@ -234,7 +233,7 @@ class MessageProcessBase(Message):
"""
try:
if seg.type == "text":
return seg.data
return seg.data # type: ignore
elif seg.type == "image":
# 如果是base64图片数据
if isinstance(seg.data, str):
@@ -250,7 +249,7 @@ class MessageProcessBase(Message):
if self.reply and hasattr(self.reply, "processed_plain_text"):
# print(f"self.reply.processed_plain_text: {self.reply.processed_plain_text}")
# print(f"reply: {self.reply}")
return f"[回复<{self.reply.message_info.user_info.user_nickname}:{self.reply.message_info.user_info.user_id}> 的消息:{self.reply.processed_plain_text}]"
return f"[回复<{self.reply.message_info.user_info.user_nickname}:{self.reply.message_info.user_info.user_id}> 的消息:{self.reply.processed_plain_text}]" # type: ignore
return None
else:
return f"[{seg.type}:{str(seg.data)}]"
@@ -264,7 +263,7 @@ class MessageProcessBase(Message):
timestamp = self.message_info.time
user_info = self.message_info.user_info
name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>"
name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>" # type: ignore
return f"[{timestamp}]{name} 说:{self.processed_plain_text}\n"
@@ -313,7 +312,7 @@ class MessageSending(MessageProcessBase):
is_emoji: bool = False,
thinking_start_time: float = 0,
apply_set_reply_logic: bool = False,
reply_to: str = None,
reply_to: str = None, # type: ignore
):
# 调用父类初始化
super().__init__(
@@ -344,7 +343,7 @@ class MessageSending(MessageProcessBase):
self.message_segment = Seg(
type="seglist",
data=[
Seg(type="reply", data=self.reply.message_info.message_id),
Seg(type="reply", data=self.reply.message_info.message_id), # type: ignore
self.message_segment,
],
)
@@ -364,10 +363,10 @@ class MessageSending(MessageProcessBase):
) -> "MessageSending":
"""从思考状态消息创建发送状态消息"""
return cls(
message_id=thinking.message_info.message_id,
message_id=thinking.message_info.message_id, # type: ignore
chat_stream=thinking.chat_stream,
message_segment=message_segment,
bot_user_info=thinking.message_info.user_info,
bot_user_info=thinking.message_info.user_info, # type: ignore
reply=thinking.reply,
is_head=is_head,
is_emoji=is_emoji,
@@ -399,13 +398,11 @@ class MessageSet:
if not isinstance(message, MessageSending):
raise TypeError("MessageSet只能添加MessageSending类型的消息")
self.messages.append(message)
self.messages.sort(key=lambda x: x.message_info.time)
self.messages.sort(key=lambda x: x.message_info.time) # type: ignore
def get_message_by_index(self, index: int) -> Optional[MessageSending]:
"""通过索引获取消息"""
if 0 <= index < len(self.messages):
return self.messages[index]
return None
return self.messages[index] if 0 <= index < len(self.messages) else None
def get_message_by_time(self, target_time: float) -> Optional[MessageSending]:
"""获取最接近指定时间的消息"""
@@ -415,7 +412,7 @@ class MessageSet:
left, right = 0, len(self.messages) - 1
while left < right:
mid = (left + right) // 2
if self.messages[mid].message_info.time < target_time:
if self.messages[mid].message_info.time < target_time: # type: ignore
left = mid + 1
else:
right = mid

View File

@@ -1,21 +1,16 @@
# src/plugins/chat/message_sender.py
import asyncio
import time
from asyncio import Task
from typing import Union
from src.common.message.api import get_global_api
# from ...common.database import db # 数据库依赖似乎不需要了,注释掉
from .message import MessageSending, MessageThinking, MessageSet
from src.chat.message_receive.storage import MessageStorage
from ..utils.utils import truncate_message, calculate_typing_time, count_messages_between
from src.common.logger import get_logger
from rich.traceback import install
install(extra_lines=3)
from src.common.logger import get_logger
from src.common.message.api import get_global_api
from src.chat.message_receive.storage import MessageStorage
from src.chat.utils.utils import truncate_message, calculate_typing_time, count_messages_between
from .message import MessageSending, MessageThinking, MessageSet
install(extra_lines=3)
logger = get_logger("sender")
@@ -79,9 +74,10 @@ class MessageContainer:
def count_thinking_messages(self) -> int:
"""计算当前容器中思考消息的数量"""
return sum(1 for msg in self.messages if isinstance(msg, MessageThinking))
return sum(isinstance(msg, MessageThinking) for msg in self.messages)
def get_timeout_sending_messages(self) -> list[MessageSending]:
# sourcery skip: merge-nested-ifs
"""获取所有超时的MessageSending对象思考时间超过20秒按thinking_start_time排序 - 从旧 sender 合并"""
current_time = time.time()
timeout_messages = []
@@ -230,9 +226,7 @@ class MessageManager:
f"[{message.chat_stream.stream_id}] 处理发送消息 {getattr(message.message_info, 'message_id', 'N/A')} 时出错: {e}"
)
logger.exception("详细错误信息:")
# 考虑是否移除出错的消息,防止无限循环
removed = container.remove_message(message)
if removed:
if container.remove_message(message):
logger.warning(f"[{message.chat_stream.stream_id}] 已移除处理出错的消息。")
async def _process_chat_messages(self, chat_id: str):
@@ -261,10 +255,7 @@ class MessageManager:
# --- 处理发送消息 ---
await self._handle_sending_message(container, message_earliest)
# --- 处理超时发送消息 (来自旧 sender) ---
# 在处理完最早的消息后,检查是否有超时的发送消息
timeout_sending_messages = container.get_timeout_sending_messages()
if timeout_sending_messages:
if timeout_sending_messages := container.get_timeout_sending_messages():
logger.debug(f"[{chat_id}] 发现 {len(timeout_sending_messages)} 条超时的发送消息")
for msg in timeout_sending_messages:
# 确保不是刚刚处理过的最早消息 (虽然理论上应该已被移除,但以防万一)
@@ -274,6 +265,7 @@ class MessageManager:
await self._handle_sending_message(container, msg) # 复用处理逻辑
async def _start_processor_loop(self):
# sourcery skip: list-comprehension, move-assign-in-block, use-named-expression
"""消息处理器主循环"""
while self._running:
tasks = []
@@ -282,10 +274,7 @@ class MessageManager:
# 创建 keys 的快照以安全迭代
chat_ids = list(self.containers.keys())
for chat_id in chat_ids:
# 为每个 chat_id 创建一个处理任务
tasks.append(asyncio.create_task(self._process_chat_messages(chat_id)))
tasks.extend(asyncio.create_task(self._process_chat_messages(chat_id)) for chat_id in chat_ids)
if tasks:
try:
# 等待当前批次的所有任务完成

View File

@@ -1,11 +1,10 @@
import re
from typing import Union
# from ...common.database.database import db # db is now Peewee's SqliteDatabase instance
from .message import MessageSending, MessageRecv
from .chat_stream import ChatStream
from ...common.database.database_model import Messages, RecalledMessages, Images # Import Peewee models
from src.common.database.database_model import Messages, RecalledMessages, Images
from src.common.logger import get_logger
from .chat_stream import ChatStream
from .message import MessageSending, MessageRecv
logger = get_logger("message_storage")
@@ -44,7 +43,7 @@ class MessageStorage:
reply_to = ""
chat_info_dict = chat_stream.to_dict()
user_info_dict = message.message_info.user_info.to_dict()
user_info_dict = message.message_info.user_info.to_dict() # type: ignore
# message_id 现在是 TextField直接使用字符串值
msg_id = message.message_info.message_id
@@ -56,7 +55,7 @@ class MessageStorage:
Messages.create(
message_id=msg_id,
time=float(message.message_info.time),
time=float(message.message_info.time), # type: ignore
chat_id=chat_stream.stream_id,
# Flattened chat_info
reply_to=reply_to,
@@ -103,7 +102,7 @@ class MessageStorage:
try:
# Assuming input 'time' is a string timestamp that can be converted to float
current_time_float = float(time)
RecalledMessages.delete().where(RecalledMessages.time < (current_time_float - 300)).execute()
RecalledMessages.delete().where(RecalledMessages.time < (current_time_float - 300)).execute() # type: ignore
except Exception:
logger.exception("删除撤回消息失败")
@@ -115,22 +114,19 @@ class MessageStorage:
"""更新最新一条匹配消息的message_id"""
try:
if message.message_segment.type == "notify":
mmc_message_id = message.message_segment.data.get("echo")
qq_message_id = message.message_segment.data.get("actual_id")
mmc_message_id = message.message_segment.data.get("echo") # type: ignore
qq_message_id = message.message_segment.data.get("actual_id") # type: ignore
else:
logger.info(f"更新消息ID错误seg类型为{message.message_segment.type}")
return
if not qq_message_id:
logger.info("消息不存在message_id无法更新")
return
# 查询最新一条匹配消息
matched_message = (
if matched_message := (
Messages.select().where((Messages.message_id == mmc_message_id)).order_by(Messages.time.desc()).first()
)
if matched_message:
):
# 更新找到的消息记录
Messages.update(message_id=qq_message_id).where(Messages.id == matched_message.id).execute()
Messages.update(message_id=qq_message_id).where(Messages.id == matched_message.id).execute() # type: ignore
logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}")
else:
logger.debug("未找到匹配的消息")
@@ -155,10 +151,7 @@ class MessageStorage:
image_record = (
Images.select().where(Images.description == description).order_by(Images.timestamp.desc()).first()
)
if image_record:
return f"[picid:{image_record.image_id}]"
else:
return match.group(0) # 保持原样
return f"[picid:{image_record.image_id}]" if image_record else match.group(0)
except Exception:
return match.group(0)

View File

@@ -1,16 +1,17 @@
import asyncio
from typing import Dict, Optional # 重新导入类型
from src.chat.message_receive.message import MessageSending, MessageThinking
from src.common.message.api import get_global_api
from src.chat.message_receive.storage import MessageStorage
from src.chat.utils.utils import truncate_message
from src.common.logger import get_logger
from src.chat.utils.utils import calculate_typing_time
from rich.traceback import install
import traceback
install(extra_lines=3)
from typing import Dict, Optional
from rich.traceback import install
from src.common.message.api import get_global_api
from src.common.logger import get_logger
from src.chat.message_receive.message import MessageSending, MessageThinking
from src.chat.message_receive.storage import MessageStorage
from src.chat.utils.utils import truncate_message
from src.chat.utils.utils import calculate_typing_time
install(extra_lines=3)
logger = get_logger("sender")
@@ -86,10 +87,10 @@ class HeartFCSender:
"""
if not message.chat_stream:
logger.error("消息缺少 chat_stream无法发送")
raise Exception("消息缺少 chat_stream无法发送")
raise ValueError("消息缺少 chat_stream无法发送")
if not message.message_info or not message.message_info.message_id:
logger.error("消息缺少 message_info 或 message_id无法发送")
raise Exception("消息缺少 message_info 或 message_id无法发送")
raise ValueError("消息缺少 message_info 或 message_id无法发送")
chat_id = message.chat_stream.stream_id
message_id = message.message_info.message_id

View File

@@ -1,6 +1,7 @@
import asyncio
import time
import traceback
from random import random
from typing import List, Optional, Dict
from maim_message import UserInfo, Seg
@@ -40,7 +41,7 @@ class NormalChat:
def __init__(
self,
chat_stream: ChatStream,
interest_dict: dict = None,
interest_dict: Optional[Dict] = None,
on_switch_to_focus_callback=None,
get_cooldown_progress_callback=None,
):
@@ -147,10 +148,7 @@ class NormalChat:
while not self._disabled:
try:
if not self.priority_manager.is_empty():
# 获取最高优先级的消息
message = self.priority_manager.get_highest_priority_message()
if message:
if message := self.priority_manager.get_highest_priority_message():
logger.info(
f"[{self.stream_name}] 从队列中取出消息进行处理: User {message.message_info.user_info.user_id}, Time: {time.strftime('%H:%M:%S', time.localtime(message.message_info.time))}"
)

View File

@@ -53,7 +53,7 @@ class PriorityManager:
"""
添加新消息到合适的队列中。
"""
user_id = message.message_info.user_info.user_id
user_id = message.message_info.user_info.user_id # type: ignore
is_vip = message.priority_info.get("message_type") == "vip" if message.priority_info else False
message_priority = message.priority_info.get("message_priority", 0.0) if message.priority_info else 0.0

View File

@@ -35,9 +35,7 @@ class ClassicalWillingManager(BaseWillingManager):
self.chat_reply_willing[chat_id] = min(current_willing, 3.0)
reply_probability = min(max((current_willing - 0.5), 0.01) * 2, 1)
return reply_probability
return min(max((current_willing - 0.5), 0.01) * 2, 1)
async def before_generate_reply_handle(self, message_id):
chat_id = self.ongoing_messages[message_id].chat_id

View File

@@ -1,14 +1,16 @@
from src.common.logger import get_logger
import importlib
import asyncio
from abc import ABC, abstractmethod
from typing import Dict, Optional
from rich.traceback import install
from dataclasses import dataclass
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.message_receive.chat_stream import ChatStream, GroupInfo
from src.chat.message_receive.message import MessageRecv
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
from abc import ABC, abstractmethod
import importlib
from typing import Dict, Optional
import asyncio
from rich.traceback import install
install(extra_lines=3)
@@ -92,8 +94,8 @@ class BaseWillingManager(ABC):
self.logger = logger
def setup(self, message: MessageRecv, chat: ChatStream, is_mentioned_bot: bool, interested_rate: float):
person_id = PersonInfoManager.get_person_id(chat.platform, chat.user_info.user_id)
self.ongoing_messages[message.message_info.message_id] = WillingInfo(
person_id = PersonInfoManager.get_person_id(chat.platform, chat.user_info.user_id) # type: ignore
self.ongoing_messages[message.message_info.message_id] = WillingInfo( # type: ignore
message=message,
chat=chat,
person_info_manager=get_person_info_manager(),

View File

@@ -27,14 +27,11 @@ class ActionManager:
# 当前正在使用的动作集合,默认加载默认动作
self._using_actions: Dict[str, ActionInfo] = {}
# 默认动作集,仅作为快照,用于恢复默认
self._default_actions: Dict[str, ActionInfo] = {}
# 加载插件动作
self._load_plugin_actions()
# 初始化时将默认动作加载到使用中的动作
self._using_actions = self._default_actions.copy()
self._using_actions = component_registry.get_default_actions()
def _load_plugin_actions(self) -> None:
"""
@@ -52,7 +49,7 @@ class ActionManager:
"""从插件系统的component_registry加载Action组件"""
try:
# 获取所有Action组件
action_components: Dict[str, ActionInfo] = component_registry.get_components_by_type(ComponentType.ACTION)
action_components: Dict[str, ActionInfo] = component_registry.get_components_by_type(ComponentType.ACTION) # type: ignore
for action_name, action_info in action_components.items():
if action_name in self._registered_actions:
@@ -61,10 +58,6 @@ class ActionManager:
self._registered_actions[action_name] = action_info
# 如果启用,也添加到默认动作集
if action_info.enabled:
self._default_actions[action_name] = action_info
logger.debug(
f"从插件系统加载Action组件: {action_name} (插件: {getattr(action_info, 'plugin_name', 'unknown')})"
)
@@ -106,7 +99,9 @@ class ActionManager:
"""
try:
# 获取组件类 - 明确指定查询Action类型
component_class = component_registry.get_component_class(action_name, ComponentType.ACTION)
component_class: Type[BaseAction] = component_registry.get_component_class(
action_name, ComponentType.ACTION
) # type: ignore
if not component_class:
logger.warning(f"{log_prefix} 未找到Action组件: {action_name}")
return None
@@ -146,10 +141,6 @@ class ActionManager:
"""获取所有已注册的动作集"""
return self._registered_actions.copy()
def get_default_actions(self) -> Dict[str, ActionInfo]:
"""获取默认动作集"""
return self._default_actions.copy()
def get_using_actions(self) -> Dict[str, ActionInfo]:
"""获取当前正在使用的动作集合"""
return self._using_actions.copy()
@@ -217,31 +208,31 @@ class ActionManager:
logger.debug(f"已从使用集中移除动作 {action_name}")
return True
def add_action(self, action_name: str, description: str, parameters: Dict = None, require: List = None) -> bool:
"""
添加新的动作到注册集
# def add_action(self, action_name: str, description: str, parameters: Dict = None, require: List = None) -> bool:
# """
# 添加新的动作到注册集
Args:
action_name: 动作名称
description: 动作描述
parameters: 动作参数定义,默认为空字典
require: 动作依赖项,默认为空列表
# Args:
# action_name: 动作名称
# description: 动作描述
# parameters: 动作参数定义,默认为空字典
# require: 动作依赖项,默认为空列表
Returns:
bool: 添加是否成功
"""
if action_name in self._registered_actions:
return False
# Returns:
# bool: 添加是否成功
# """
# if action_name in self._registered_actions:
# return False
if parameters is None:
parameters = {}
if require is None:
require = []
# if parameters is None:
# parameters = {}
# if require is None:
# require = []
action_info = {"description": description, "parameters": parameters, "require": require}
# action_info = {"description": description, "parameters": parameters, "require": require}
self._registered_actions[action_name] = action_info
return True
# self._registered_actions[action_name] = action_info
# return True
def remove_action(self, action_name: str) -> bool:
"""从注册集移除指定动作"""
@@ -260,10 +251,9 @@ class ActionManager:
def restore_actions(self) -> None:
"""恢复到默认动作集"""
logger.debug(
f"恢复动作集: 从 {list(self._using_actions.keys())} 恢复到默认动作集 {list(self._default_actions.keys())}"
)
self._using_actions = self._default_actions.copy()
actions_to_restore = list(self._using_actions.keys())
self._using_actions = component_registry.get_default_actions()
logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}")
def add_system_action_if_needed(self, action_name: str) -> bool:
"""
@@ -293,4 +283,4 @@ class ActionManager:
"""
from src.plugin_system.core.component_registry import component_registry
return component_registry.get_component_class(action_name)
return component_registry.get_component_class(action_name) # type: ignore

View File

@@ -2,7 +2,7 @@ import random
import asyncio
import hashlib
import time
from typing import List, Any, Dict
from typing import List, Any, Dict, TYPE_CHECKING
from src.common.logger import get_logger
from src.config.config import global_config
@@ -13,6 +13,9 @@ from src.chat.planner_actions.action_manager import ActionManager
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages
from src.plugin_system.base.component_types import ChatMode, ActionInfo, ActionActivationType
if TYPE_CHECKING:
from src.chat.message_receive.chat_stream import ChatStream
logger = get_logger("action_manager")
@@ -27,7 +30,7 @@ class ActionModifier:
def __init__(self, action_manager: ActionManager, chat_id: str):
"""初始化动作处理器"""
self.chat_id = chat_id
self.chat_stream = get_chat_manager().get_stream(self.chat_id)
self.chat_stream: ChatStream = get_chat_manager().get_stream(self.chat_id) # type: ignore
self.log_prefix = f"[{get_chat_manager().get_stream_name(self.chat_id) or self.chat_id}]"
self.action_manager = action_manager
@@ -142,7 +145,7 @@ class ActionModifier:
async def _get_deactivated_actions_by_type(
self,
actions_with_info: Dict[str, ActionInfo],
mode: str = "focus",
mode: ChatMode = ChatMode.FOCUS,
chat_content: str = "",
) -> List[tuple[str, str]]:
"""
@@ -270,7 +273,7 @@ class ActionModifier:
task_results = await asyncio.gather(*tasks, return_exceptions=True)
# 处理结果并更新缓存
for _, (action_name, result) in enumerate(zip(task_names, task_results)):
for action_name, result in zip(task_names, task_results):
if isinstance(result, Exception):
logger.error(f"{self.log_prefix}LLM判定action {action_name} 时出错: {result}")
results[action_name] = False
@@ -286,7 +289,7 @@ class ActionModifier:
except Exception as e:
logger.error(f"{self.log_prefix}并行LLM判定失败: {e}")
# 如果并行执行失败为所有任务返回False
for action_name in tasks_to_run.keys():
for action_name in tasks_to_run:
results[action_name] = False
# 清理过期缓存
@@ -297,10 +300,11 @@ class ActionModifier:
def _cleanup_expired_cache(self, current_time: float):
"""清理过期的缓存条目"""
expired_keys = []
for cache_key, cache_data in self._llm_judge_cache.items():
if current_time - cache_data["timestamp"] > self._cache_expiry_time:
expired_keys.append(cache_key)
expired_keys.extend(
cache_key
for cache_key, cache_data in self._llm_judge_cache.items()
if current_time - cache_data["timestamp"] > self._cache_expiry_time
)
for key in expired_keys:
del self._llm_judge_cache[key]
@@ -379,7 +383,7 @@ class ActionModifier:
def _check_keyword_activation(
self,
action_name: str,
action_info: Dict[str, Any],
action_info: ActionInfo,
chat_content: str = "",
) -> bool:
"""
@@ -396,8 +400,8 @@ class ActionModifier:
bool: 是否应该激活此action
"""
activation_keywords = action_info.get("activation_keywords", [])
case_sensitive = action_info.get("keyword_case_sensitive", False)
activation_keywords = action_info.activation_keywords
case_sensitive = action_info.keyword_case_sensitive
if not activation_keywords:
logger.warning(f"{self.log_prefix}动作 {action_name} 设置为关键词触发但未配置关键词")

View File

@@ -70,7 +70,7 @@ class ActionPlanner:
self.last_obs_time_mark = 0.0
async def plan(self) -> Dict[str, Any]:
async def plan(self) -> Dict[str, Any]: # sourcery skip: dict-comprehension
"""
规划器 (Planner): 使用LLM根据上下文决定做出什么动作。
"""
@@ -162,7 +162,6 @@ class ActionPlanner:
reasoning = parsed_json.get("reasoning", "未提供原因")
# 将所有其他属性添加到action_data
action_data = {}
for key, value in parsed_json.items():
if key not in ["action", "reasoning"]:
action_data[key] = value
@@ -285,7 +284,7 @@ class ActionPlanner:
identity_block = f"你的名字是{bot_name}{bot_nickname},你{bot_core_personality}"
planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt")
prompt = planner_prompt_template.format(
return planner_prompt_template.format(
time_block=time_block,
by_what=by_what,
chat_context_description=chat_context_description,
@@ -295,8 +294,6 @@ class ActionPlanner:
moderation_prompt=moderation_prompt_block,
identity_block=identity_block,
)
return prompt
except Exception as e:
logger.error(f"构建 Planner 提示词时出错: {e}")
logger.error(traceback.format_exc())

View File

@@ -130,9 +130,7 @@ class DefaultReplyer:
# 提取权重,如果模型配置中没有'weight'键则默认为1.0
weights = [config.get("weight", 1.0) for config in configs]
# random.choices 返回一个列表,我们取第一个元素
selected_config = random.choices(population=configs, weights=weights, k=1)[0]
return selected_config
return random.choices(population=configs, weights=weights, k=1)[0]
async def _create_thinking_message(self, anchor_message: Optional[MessageRecv], thinking_id: str):
"""创建思考消息 (尝试锚定到 anchor_message)"""
@@ -314,8 +312,7 @@ class DefaultReplyer:
logger.warning(f"{self.log_prefix} 未找到用户 {sender} 的ID跳过信息提取")
return f"你完全不认识{sender}不理解ta的相关信息。"
relation_info = await relationship_fetcher.build_relation_info(person_id, text, chat_history)
return relation_info
return await relationship_fetcher.build_relation_info(person_id, text, chat_history)
async def build_expression_habits(self, chat_history, target):
if not global_config.expression.enable_expression:
@@ -363,15 +360,13 @@ class DefaultReplyer:
target_message=target, chat_history_prompt=chat_history
)
if running_memories:
memory_str = "以下是当前在聊天中,你回忆起的记忆:\n"
for running_memory in running_memories:
memory_str += f"- {running_memory['content']}\n"
memory_block = memory_str
else:
memory_block = ""
if not running_memories:
return ""
return memory_block
memory_str = "以下是当前在聊天中,你回忆起的记忆:\n"
for running_memory in running_memories:
memory_str += f"- {running_memory['content']}\n"
return memory_str
async def build_tool_info(self, reply_data=None, chat_history=None, enable_tool: bool = True):
"""构建工具信息块
@@ -453,7 +448,7 @@ class DefaultReplyer:
for name, content in result.groupdict().items():
reaction = reaction.replace(f"[{name}]", content)
logger.info(f"匹配到正则表达式:{pattern_str},触发反应:{reaction}")
keywords_reaction_prompt += reaction + ""
keywords_reaction_prompt += f"{reaction}"
break
except re.error as e:
logger.error(f"正则表达式编译错误: {pattern_str}, 错误信息: {str(e)}")
@@ -477,7 +472,7 @@ class DefaultReplyer:
available_actions: Optional[Dict[str, ActionInfo]] = None,
enable_timeout: bool = False,
enable_tool: bool = True,
) -> str:
) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
"""
构建回复器上下文
@@ -612,7 +607,7 @@ class DefaultReplyer:
short_impression = ["友好活泼", "人类"]
personality = short_impression[0]
identity = short_impression[1]
prompt_personality = personality + "" + identity
prompt_personality = f"{personality}{identity}"
identity_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
moderation_prompt_block = (
@@ -660,7 +655,7 @@ class DefaultReplyer:
"chat_target_private2", sender_name=chat_target_name
)
prompt = await global_prompt_manager.format_prompt(
return await global_prompt_manager.format_prompt(
template_name,
expression_habits_block=expression_habits_block,
chat_target=chat_target_1,
@@ -683,8 +678,6 @@ class DefaultReplyer:
mood_state=mood_prompt,
)
return prompt
async def build_prompt_rewrite_context(
self,
reply_data: Dict[str, Any],
@@ -745,7 +738,7 @@ class DefaultReplyer:
short_impression = ["友好活泼", "人类"]
personality = short_impression[0]
identity = short_impression[1]
prompt_personality = personality + "" + identity
prompt_personality = f"{personality}{identity}"
identity_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
moderation_prompt_block = (
@@ -790,7 +783,7 @@ class DefaultReplyer:
template_name = "default_expressor_prompt"
prompt = await global_prompt_manager.format_prompt(
return await global_prompt_manager.format_prompt(
template_name,
expression_habits_block=expression_habits_block,
relation_info_block=relation_info,
@@ -807,8 +800,6 @@ class DefaultReplyer:
moderation_prompt=moderation_prompt_block,
)
return prompt
async def send_response_messages(
self,
anchor_message: Optional[MessageRecv],
@@ -816,6 +807,7 @@ class DefaultReplyer:
thinking_id: str = "",
display_message: str = "",
) -> Optional[MessageSending]:
# sourcery skip: assign-if-exp, boolean-if-exp-identity, remove-unnecessary-cast
"""发送回复消息 (尝试锚定到 anchor_message),使用 HeartFCSender"""
chat = self.chat_stream
chat_id = self.chat_stream.stream_id
@@ -849,16 +841,16 @@ class DefaultReplyer:
for i, msg_text in enumerate(response_set):
# 为每个消息片段生成唯一ID
type = msg_text[0]
msg_type = msg_text[0]
data = msg_text[1]
if global_config.debug.debug_show_chat_mode and type == "text":
if global_config.debug.debug_show_chat_mode and msg_type == "text":
data += ""
part_message_id = f"{thinking_id}_{i}"
message_segment = Seg(type=type, data=data)
message_segment = Seg(type=msg_type, data=data)
if type == "emoji":
if msg_type == "emoji":
is_emoji = True
else:
is_emoji = False
@@ -871,7 +863,6 @@ class DefaultReplyer:
display_message=display_message,
reply_to=reply_to,
is_emoji=is_emoji,
thinking_id=thinking_id,
thinking_start_time=thinking_start_time,
)
@@ -895,7 +886,7 @@ class DefaultReplyer:
reply_message_ids.append(part_message_id) # 记录我们生成的ID
sent_msg_list.append((type, sent_msg))
sent_msg_list.append((msg_type, sent_msg))
except Exception as e:
logger.error(f"{self.log_prefix}发送回复片段 {i} ({part_message_id}) 时失败: {e}")
@@ -930,12 +921,9 @@ class DefaultReplyer:
)
# await anchor_message.process()
if anchor_message:
sender_info = anchor_message.message_info.user_info
else:
sender_info = None
sender_info = anchor_message.message_info.user_info if anchor_message else None
bot_message = MessageSending(
return MessageSending(
message_id=message_id, # 使用片段的唯一ID
chat_stream=self.chat_stream,
bot_user_info=bot_user_info,
@@ -948,8 +936,6 @@ class DefaultReplyer:
display_message=display_message,
)
return bot_message
def weighted_sample_no_replacement(items, weights, k) -> list:
"""

View File

@@ -1,4 +1,5 @@
from typing import Dict, Any, Optional, List
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.replyer.default_generator import DefaultReplyer
from src.common.logger import get_logger
@@ -8,7 +9,7 @@ logger = get_logger("ReplyerManager")
class ReplyerManager:
def __init__(self):
self._replyers: Dict[str, DefaultReplyer] = {}
self._repliers: Dict[str, DefaultReplyer] = {}
def get_replyer(
self,
@@ -29,17 +30,16 @@ class ReplyerManager:
return None
# 如果已有缓存实例,直接返回
if stream_id in self._replyers:
if stream_id in self._repliers:
logger.debug(f"[ReplyerManager] 为 stream_id '{stream_id}' 返回已存在的回复器实例。")
return self._replyers[stream_id]
return self._repliers[stream_id]
# 如果没有缓存,则创建新实例(首次初始化)
logger.debug(f"[ReplyerManager] 为 stream_id '{stream_id}' 创建新的回复器实例并缓存。")
target_stream = chat_stream
if not target_stream:
chat_manager = get_chat_manager()
if chat_manager:
if chat_manager := get_chat_manager():
target_stream = chat_manager.get_stream(stream_id)
if not target_stream:
@@ -52,7 +52,7 @@ class ReplyerManager:
model_configs=model_configs, # 可以是None此时使用默认模型
request_type=request_type,
)
self._replyers[stream_id] = replyer
self._repliers[stream_id] = replyer
return replyer

View File

@@ -1,14 +1,15 @@
from src.config.config import global_config
from typing import List, Dict, Any, Tuple # 确保类型提示被导入
import time # 导入 time 模块以获取当前时间
import random
import re
from src.common.message_repository import find_messages, count_messages
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
from src.chat.utils.utils import translate_timestamp_to_human_readable
from typing import List, Dict, Any, Tuple, Optional
from rich.traceback import install
from src.config.config import global_config
from src.common.message_repository import find_messages, count_messages
from src.common.database.database_model import ActionRecords
from src.common.database.database_model import Images
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
from src.chat.utils.utils import translate_timestamp_to_human_readable
install(extra_lines=3)
@@ -135,7 +136,7 @@ def get_raw_msg_before_timestamp_with_users(timestamp: float, person_ids: list,
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: float = None) -> int:
def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: Optional[float] = None) -> int:
"""
检查特定聊天从 timestamp_start (不含) 到 timestamp_end (不含) 之间有多少新消息。
如果 timestamp_end 为 None则检查从 timestamp_start (不含) 到当前时间的消息。
@@ -172,7 +173,7 @@ def _build_readable_messages_internal(
merge_messages: bool = False,
timestamp_mode: str = "relative",
truncate: bool = False,
pic_id_mapping: Dict[str, str] = None,
pic_id_mapping: Optional[Dict[str, str]] = None,
pic_counter: int = 1,
show_pic: bool = True,
) -> Tuple[str, List[Tuple[float, str, str]], Dict[str, str], int]:
@@ -194,7 +195,7 @@ def _build_readable_messages_internal(
if not messages:
return "", [], pic_id_mapping or {}, pic_counter
message_details_raw: List[Tuple[float, str, str]] = []
message_details_raw: List[Tuple[float, str, str, bool]] = []
# 使用传入的映射字典,如果没有则创建新的
if pic_id_mapping is None:
@@ -225,7 +226,7 @@ def _build_readable_messages_internal(
# 检查是否是动作记录
if msg.get("is_action_record", False):
is_action = True
timestamp = msg.get("time")
timestamp: float = msg.get("time") # type: ignore
content = msg.get("display_message", "")
# 对于动作记录也处理图片ID
content = process_pic_ids(content)
@@ -249,9 +250,10 @@ def _build_readable_messages_internal(
user_nickname = user_info.get("user_nickname")
user_cardname = user_info.get("user_cardname")
timestamp = msg.get("time")
timestamp: float = msg.get("time") # type: ignore
content: str
if msg.get("display_message"):
content = msg.get("display_message")
content = msg.get("display_message", "")
else:
content = msg.get("processed_plain_text", "") # 默认空字符串
@@ -271,6 +273,7 @@ def _build_readable_messages_internal(
person_id = PersonInfoManager.get_person_id(platform, user_id)
person_info_manager = get_person_info_manager()
# 根据 replace_bot_name 参数决定是否替换机器人名称
person_name: str
if replace_bot_name and user_id == global_config.bot.qq_account:
person_name = f"{global_config.bot.nickname}(你)"
else:
@@ -289,12 +292,10 @@ def _build_readable_messages_internal(
reply_pattern = r"回复<([^:<>]+):([^:<>]+)>"
match = re.search(reply_pattern, content)
if match:
aaa = match.group(1)
bbb = match.group(2)
aaa: str = match[1]
bbb: str = match[2]
reply_person_id = PersonInfoManager.get_person_id(platform, bbb)
reply_person_name = person_info_manager.get_value_sync(reply_person_id, "person_name")
if not reply_person_name:
reply_person_name = aaa
reply_person_name = person_info_manager.get_value_sync(reply_person_id, "person_name") or aaa
# 在内容前加上回复信息
content = re.sub(reply_pattern, lambda m, name=reply_person_name: f"回复 {name}", content, count=1)
@@ -309,18 +310,15 @@ def _build_readable_messages_internal(
aaa = m.group(1)
bbb = m.group(2)
at_person_id = PersonInfoManager.get_person_id(platform, bbb)
at_person_name = person_info_manager.get_value_sync(at_person_id, "person_name")
if not at_person_name:
at_person_name = aaa
at_person_name = person_info_manager.get_value_sync(at_person_id, "person_name") or aaa
new_content += f"@{at_person_name}"
last_end = m.end()
new_content += content[last_end:]
content = new_content
target_str = "这是QQ的一个功能用于提及某人但没那么明显"
if target_str in content:
if random.random() < 0.6:
content = content.replace(target_str, "")
if target_str in content and random.random() < 0.6:
content = content.replace(target_str, "")
if content != "":
message_details_raw.append((timestamp, person_name, content, False))
@@ -470,6 +468,7 @@ def _build_readable_messages_internal(
def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
# sourcery skip: use-contextlib-suppress
"""
构建图片映射信息字符串,显示图片的具体描述内容
@@ -518,9 +517,7 @@ async def build_readable_messages_with_list(
messages, replace_bot_name, merge_messages, timestamp_mode, truncate
)
# 生成图片映射信息并添加到最前面
pic_mapping_info = build_pic_mapping_info(pic_id_mapping)
if pic_mapping_info:
if pic_mapping_info := build_pic_mapping_info(pic_id_mapping):
formatted_string = f"{pic_mapping_info}\n\n{formatted_string}"
return formatted_string, details_list
@@ -535,7 +532,7 @@ def build_readable_messages(
truncate: bool = False,
show_actions: bool = False,
show_pic: bool = True,
) -> str:
) -> str: # sourcery skip: extract-method
"""
将消息列表转换为可读的文本格式。
如果提供了 read_mark则在相应位置插入已读标记。
@@ -658,9 +655,7 @@ def build_readable_messages(
# 组合结果
result_parts = []
if pic_mapping_info:
result_parts.append(pic_mapping_info)
result_parts.append("\n")
result_parts.extend((pic_mapping_info, "\n"))
if formatted_before and formatted_after:
result_parts.extend([formatted_before, read_mark_line, formatted_after])
elif formatted_before:
@@ -733,8 +728,9 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
platform = msg.get("chat_info_platform")
user_id = msg.get("user_id")
_timestamp = msg.get("time")
content: str = ""
if msg.get("display_message"):
content = msg.get("display_message")
content = msg.get("display_message", "")
else:
content = msg.get("processed_plain_text", "")
@@ -829,10 +825,7 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
if not all([platform, user_id]) or user_id == global_config.bot.qq_account:
continue
person_id = PersonInfoManager.get_person_id(platform, user_id)
# 只有当获取到有效 person_id 时才添加
if person_id:
if person_id := PersonInfoManager.get_person_id(platform, user_id):
person_ids_set.add(person_id)
return list(person_ids_set) # 将集合转换为列表返回

View File

@@ -103,7 +103,7 @@ class ImageManager:
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
# 查询缓存的描述
cached_description = self._get_description_from_db(image_hash, "emoji")
@@ -154,7 +154,7 @@ class ImageManager:
img_obj.description = description
img_obj.timestamp = current_timestamp
img_obj.save()
except Images.DoesNotExist:
except Images.DoesNotExist: # type: ignore
Images.create(
emoji_hash=image_hash,
path=file_path,
@@ -204,7 +204,7 @@ class ImageManager:
return f"[图片:{cached_description}]"
# 调用AI获取描述
image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
prompt = "请用中文描述这张图片的内容。如果有文字请把文字都描述出来请留意其主题直观感受输出为一段平文本最多50字"
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
@@ -491,7 +491,7 @@ class ImageManager:
return
# 获取图片格式
image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
# 构建prompt
prompt = """请用中文描述这张图片的内容。如果有文字请把文字描述概括出来请留意其主题直观感受输出为一段平文本最多30字请注意不要分点就输出一段文本"""

View File

@@ -3,7 +3,7 @@ from src.common.database.database import db
from src.common.database.database_model import PersonInfo # 新增导入
import copy
import hashlib
from typing import Any, Callable, Dict
from typing import Any, Callable, Dict, Union
import datetime
import asyncio
from src.llm_models.utils_model import LLMRequest
@@ -84,7 +84,7 @@ class PersonInfoManager:
logger.error(f"从 Peewee 加载 person_name_list 失败: {e}")
@staticmethod
def get_person_id(platform: str, user_id: int):
def get_person_id(platform: str, user_id: Union[int, str]) -> str:
"""获取唯一id"""
if "-" in platform:
platform = platform.split("-")[1]

View File

@@ -32,10 +32,10 @@ class BaseAction(ABC):
reasoning: str,
cycle_timers: dict,
thinking_id: str,
chat_stream: ChatStream = None,
chat_stream: ChatStream,
log_prefix: str = "",
shutting_down: bool = False,
plugin_config: dict = None,
plugin_config: Optional[dict] = None,
**kwargs,
):
"""初始化Action组件

View File

@@ -29,7 +29,7 @@ class BaseCommand(ABC):
command_examples: List[str] = []
intercept_message: bool = True # 默认拦截消息,不继续处理
def __init__(self, message: MessageRecv, plugin_config: dict = None):
def __init__(self, message: MessageRecv, plugin_config: Optional[dict] = None):
"""初始化Command组件
Args:

View File

@@ -66,7 +66,7 @@ class BasePlugin(ABC):
config_section_descriptions: Dict[str, str] = {}
def __init__(self, plugin_dir: str = None):
def __init__(self, plugin_dir: str):
"""初始化插件
Args:
@@ -526,7 +526,7 @@ class BasePlugin(ABC):
# 从配置中更新 enable_plugin
if "plugin" in self.config and "enabled" in self.config["plugin"]:
self.enable_plugin = self.config["plugin"]["enabled"]
self.enable_plugin = self.config["plugin"]["enabled"] # type: ignore
logger.debug(f"{self.log_prefix} 从配置更新插件启用状态: {self.enable_plugin}")
else:
logger.warning(f"{self.log_prefix} 不支持的配置文件格式: {file_ext},仅支持 .toml")

View File

@@ -81,7 +81,9 @@ class ComponentInfo:
class ActionInfo(ComponentInfo):
"""动作组件信息"""
action_parameters: Dict[str, str] = field(default_factory=dict) # 动作参数与描述,例如 {"param1": "描述1", "param2": "描述2"}
action_parameters: Dict[str, str] = field(
default_factory=dict
) # 动作参数与描述,例如 {"param1": "描述1", "param2": "描述2"}
action_require: List[str] = field(default_factory=list) # 动作需求说明
associated_types: List[str] = field(default_factory=list) # 关联的消息类型
# 激活类型相关

View File

@@ -1,4 +1,4 @@
from typing import Dict, List, Optional, Any, Pattern, Union
from typing import Dict, List, Optional, Any, Pattern, Tuple, Union, Type
import re
from src.common.logger import get_logger
from src.plugin_system.base.component_types import (
@@ -28,25 +28,25 @@ class ComponentRegistry:
ComponentType.ACTION: {},
ComponentType.COMMAND: {},
}
self._component_classes: Dict[str, Union[BaseCommand, BaseAction]] = {} # 组件名 -> 组件类
self._component_classes: Dict[str, Union[Type[BaseCommand], Type[BaseAction]]] = {} # 组件名 -> 组件类
# 插件注册表
self._plugins: Dict[str, PluginInfo] = {} # 插件名 -> 插件信息
# Action特定注册表
self._action_registry: Dict[str, BaseAction] = {} # action名 -> action类
# self._action_descriptions: Dict[str, str] = {} # 启用的action名 -> 描述
self._action_registry: Dict[str, Type[BaseAction]] = {} # action名 -> action类
self._default_actions: Dict[str, ActionInfo] = {} # 默认动作集即启用的Action集用于重置ActionManager状态
# Command特定注册表
self._command_registry: Dict[str, BaseCommand] = {} # command名 -> command类
self._command_patterns: Dict[Pattern, BaseCommand] = {} # 编译后的正则 -> command类
self._command_registry: Dict[str, Type[BaseCommand]] = {} # command名 -> command类
self._command_patterns: Dict[Pattern, Type[BaseCommand]] = {} # 编译后的正则 -> command类
logger.info("组件注册中心初始化完成")
# === 通用组件注册方法 ===
def register_component(
self, component_info: ComponentInfo, component_class: Union[BaseCommand, BaseAction]
self, component_info: ComponentInfo, component_class: Union[Type[BaseCommand], Type[BaseAction]]
) -> bool:
"""注册组件
@@ -88,9 +88,9 @@ class ComponentRegistry:
# 根据组件类型进行特定注册(使用原始名称)
if component_type == ComponentType.ACTION:
self._register_action_component(component_info, component_class)
self._register_action_component(component_info, component_class) # type: ignore
elif component_type == ComponentType.COMMAND:
self._register_command_component(component_info, component_class)
self._register_command_component(component_info, component_class) # type: ignore
logger.debug(
f"已注册{component_type.value}组件: '{component_name}' -> '{namespaced_name}' "
@@ -98,7 +98,7 @@ class ComponentRegistry:
)
return True
def _register_action_component(self, action_info: ActionInfo, action_class: BaseAction):
def _register_action_component(self, action_info: ActionInfo, action_class: Type[BaseAction]):
# -------------------------------- NEED REFACTORING --------------------------------
# -------------------------------- LOGIC ERROR -------------------------------------
"""注册Action组件到Action特定注册表"""
@@ -106,11 +106,10 @@ class ComponentRegistry:
self._action_registry[action_name] = action_class
# 如果启用,添加到默认动作集
# ---- HERE ----
# if action_info.enabled:
# self._action_descriptions[action_name] = action_info.description
if action_info.enabled:
self._default_actions[action_name] = action_info
def _register_command_component(self, command_info: CommandInfo, command_class: BaseCommand):
def _register_command_component(self, command_info: CommandInfo, command_class: Type[BaseCommand]):
"""注册Command组件到Command特定注册表"""
command_name = command_info.name
self._command_registry[command_name] = command_class
@@ -122,7 +121,7 @@ class ComponentRegistry:
# === 组件查询方法 ===
def get_component_info(self, component_name: str, component_type: ComponentType = None) -> Optional[ComponentInfo]:
def get_component_info(self, component_name: str, component_type: ComponentType = None) -> Optional[ComponentInfo]: # type: ignore
# sourcery skip: class-extract-method
"""获取组件信息,支持自动命名空间解析
@@ -170,8 +169,10 @@ class ComponentRegistry:
return None
def get_component_class(
self, component_name: str, component_type: ComponentType = None
) -> Optional[Union[BaseCommand, BaseAction]]:
self,
component_name: str,
component_type: ComponentType = None, # type: ignore
) -> Optional[Union[Type[BaseCommand], Type[BaseAction]]]:
"""获取组件类,支持自动命名空间解析
Args:
@@ -230,7 +231,7 @@ class ComponentRegistry:
# === Action特定查询方法 ===
def get_action_registry(self) -> Dict[str, BaseAction]:
def get_action_registry(self) -> Dict[str, Type[BaseAction]]:
"""获取Action注册表用于兼容现有系统"""
return self._action_registry.copy()
@@ -239,13 +240,17 @@ class ComponentRegistry:
info = self.get_component_info(action_name, ComponentType.ACTION)
return info if isinstance(info, ActionInfo) else None
def get_default_actions(self) -> Dict[str, ActionInfo]:
"""获取默认动作集"""
return self._default_actions.copy()
# === Command特定查询方法 ===
def get_command_registry(self) -> Dict[str, BaseCommand]:
def get_command_registry(self) -> Dict[str, Type[BaseCommand]]:
"""获取Command注册表用于兼容现有系统"""
return self._command_registry.copy()
def get_command_patterns(self) -> Dict[Pattern, BaseCommand]:
def get_command_patterns(self) -> Dict[Pattern, Type[BaseCommand]]:
"""获取Command模式注册表用于兼容现有系统"""
return self._command_patterns.copy()
@@ -254,7 +259,7 @@ class ComponentRegistry:
info = self.get_component_info(command_name, ComponentType.COMMAND)
return info if isinstance(info, CommandInfo) else None
def find_command_by_text(self, text: str) -> Optional[tuple[BaseCommand, dict, bool, str]]:
def find_command_by_text(self, text: str) -> Optional[Tuple[Type[BaseCommand], dict, bool, str]]:
# sourcery skip: use-named-expression, use-next
"""根据文本查找匹配的命令
@@ -262,7 +267,7 @@ class ComponentRegistry:
text: 输入文本
Returns:
Optional[tuple[BaseCommand, dict, bool, str]]: (命令类, 匹配的命名组, 是否拦截消息, 插件名) 或 None
Tuple: (命令类, 匹配的命名组, 是否拦截消息, 插件名) 或 None
"""
for pattern, command_class in self._command_patterns.items():