From 2229f98993a72e2a247649d9031104c2cce7ecc0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Wed, 16 Jul 2025 19:58:19 +0800 Subject: [PATCH 1/5] =?UTF-8?q?fix=EF=BC=9A=E4=BF=AE=E5=A4=8DLPMM=E5=AD=A6?= =?UTF-8?q?=E4=B9=A0=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bot.py | 151 +++++++++++++------------- scripts/import_openie.py | 42 +++++++ scripts/info_extraction.py | 42 +++++++ src/chat/knowledge/embedding_store.py | 26 ++++- src/chat/knowledge/ie_process.py | 111 ++++++++++++++++--- src/chat/knowledge/prompt_template.py | 33 ++++-- src/llm_models/utils_model.py | 22 ++-- 7 files changed, 313 insertions(+), 114 deletions(-) diff --git a/bot.py b/bot.py index 1a5e6694b..5548c1725 100644 --- a/bot.py +++ b/bot.py @@ -8,6 +8,7 @@ if os.path.exists(".env"): print("成功加载环境变量配置") else: print("未找到.env文件,请确保程序所需的环境变量被正确设置") + raise FileNotFoundError(".env 文件不存在,请创建并配置所需的环境变量") import sys import time import platform @@ -140,81 +141,85 @@ async def graceful_shutdown(): logger.error(f"麦麦关闭失败: {e}", exc_info=True) +def _calculate_file_hash(file_path: Path, file_type: str) -> str: + """计算文件的MD5哈希值""" + if not file_path.exists(): + logger.error(f"{file_type} 文件不存在") + raise FileNotFoundError(f"{file_type} 文件不存在") + + with open(file_path, "r", encoding="utf-8") as f: + content = f.read() + return hashlib.md5(content.encode("utf-8")).hexdigest() + + +def _check_agreement_status(file_hash: str, confirm_file: Path, env_var: str) -> tuple[bool, bool]: + """检查协议确认状态 + + Returns: + tuple[bool, bool]: (已确认, 未更新) + """ + # 检查环境变量确认 + if file_hash == os.getenv(env_var): + return True, False + + # 检查确认文件 + if confirm_file.exists(): + with open(confirm_file, "r", encoding="utf-8") as f: + confirmed_content = f.read() + if file_hash == confirmed_content: + return True, False + + return False, True + + +def _prompt_user_confirmation(eula_hash: str, privacy_hash: str) -> None: + """提示用户确认协议""" + confirm_logger.critical("EULA或隐私条款内容已更新,请在阅读后重新确认,继续运行视为同意更新后的以上两款协议") + confirm_logger.critical( + f'输入"同意"或"confirmed"或设置环境变量"EULA_AGREE={eula_hash}"和"PRIVACY_AGREE={privacy_hash}"继续运行' + ) + + while True: + user_input = input().strip().lower() + if user_input in ["同意", "confirmed"]: + return + confirm_logger.critical('请输入"同意"或"confirmed"以继续运行') + + +def _save_confirmations(eula_updated: bool, privacy_updated: bool, + eula_hash: str, privacy_hash: str) -> None: + """保存用户确认结果""" + if eula_updated: + logger.info(f"更新EULA确认文件{eula_hash}") + Path("eula.confirmed").write_text(eula_hash, encoding="utf-8") + + if privacy_updated: + logger.info(f"更新隐私条款确认文件{privacy_hash}") + Path("privacy.confirmed").write_text(privacy_hash, encoding="utf-8") + + def check_eula(): - eula_confirm_file = Path("eula.confirmed") - privacy_confirm_file = Path("privacy.confirmed") - eula_file = Path("EULA.md") - privacy_file = Path("PRIVACY.md") - - eula_updated = True - privacy_updated = True - - eula_confirmed = False - privacy_confirmed = False - - # 首先计算当前EULA文件的哈希值 - if eula_file.exists(): - with open(eula_file, "r", encoding="utf-8") as f: - eula_content = f.read() - eula_new_hash = hashlib.md5(eula_content.encode("utf-8")).hexdigest() - else: - logger.error("EULA.md 文件不存在") - raise FileNotFoundError("EULA.md 文件不存在") - - # 首先计算当前隐私条款文件的哈希值 - if privacy_file.exists(): - with open(privacy_file, "r", encoding="utf-8") as f: - privacy_content = f.read() - privacy_new_hash = hashlib.md5(privacy_content.encode("utf-8")).hexdigest() - else: - logger.error("PRIVACY.md 文件不存在") - raise FileNotFoundError("PRIVACY.md 文件不存在") - - # 检查EULA确认文件是否存在 - if eula_confirm_file.exists(): - with open(eula_confirm_file, "r", encoding="utf-8") as f: - confirmed_content = f.read() - if eula_new_hash == confirmed_content: - eula_confirmed = True - eula_updated = False - if eula_new_hash == os.getenv("EULA_AGREE"): - eula_confirmed = True - eula_updated = False - - # 检查隐私条款确认文件是否存在 - if privacy_confirm_file.exists(): - with open(privacy_confirm_file, "r", encoding="utf-8") as f: - confirmed_content = f.read() - if privacy_new_hash == confirmed_content: - privacy_confirmed = True - privacy_updated = False - if privacy_new_hash == os.getenv("PRIVACY_AGREE"): - privacy_confirmed = True - privacy_updated = False - - # 如果EULA或隐私条款有更新,提示用户重新确认 + """检查EULA和隐私条款确认状态""" + # 计算文件哈希值 + eula_hash = _calculate_file_hash(Path("EULA.md"), "EULA.md") + privacy_hash = _calculate_file_hash(Path("PRIVACY.md"), "PRIVACY.md") + + # 检查确认状态 + eula_confirmed, eula_updated = _check_agreement_status( + eula_hash, Path("eula.confirmed"), "EULA_AGREE" + ) + privacy_confirmed, privacy_updated = _check_agreement_status( + privacy_hash, Path("privacy.confirmed"), "PRIVACY_AGREE" + ) + + # 早期返回:如果都已确认且未更新 + if eula_confirmed and privacy_confirmed: + return + + # 如果有更新,需要重新确认 if eula_updated or privacy_updated: - confirm_logger.critical("EULA或隐私条款内容已更新,请在阅读后重新确认,继续运行视为同意更新后的以上两款协议") - confirm_logger.critical( - f'输入"同意"或"confirmed"或设置环境变量"EULA_AGREE={eula_new_hash}"和"PRIVACY_AGREE={privacy_new_hash}"继续运行' - ) - while True: - user_input = input().strip().lower() - if user_input in ["同意", "confirmed"]: - # print("确认成功,继续运行") - # print(f"确认成功,继续运行{eula_updated} {privacy_updated}") - if eula_updated: - logger.info(f"更新EULA确认文件{eula_new_hash}") - eula_confirm_file.write_text(eula_new_hash, encoding="utf-8") - if privacy_updated: - logger.info(f"更新隐私条款确认文件{privacy_new_hash}") - privacy_confirm_file.write_text(privacy_new_hash, encoding="utf-8") - break - else: - confirm_logger.critical('请输入"同意"或"confirmed"以继续运行') - return - elif eula_confirmed and privacy_confirmed: - return + _prompt_user_confirmation(eula_hash, privacy_hash) + _save_confirmations(eula_updated, privacy_updated, eula_hash, privacy_hash) def raw_main(): diff --git a/scripts/import_openie.py b/scripts/import_openie.py index 791c64672..63a4d9852 100644 --- a/scripts/import_openie.py +++ b/scripts/import_openie.py @@ -15,6 +15,7 @@ from src.chat.knowledge.kg_manager import KGManager from src.common.logger import get_logger from src.chat.knowledge.utils.hash import get_sha256 from src.manager.local_store_manager import local_storage +from dotenv import load_dotenv # 添加项目根目录到 sys.path @@ -23,6 +24,45 @@ OPENIE_DIR = os.path.join(ROOT_PATH, "data", "openie") logger = get_logger("OpenIE导入") +ENV_FILE = os.path.join(ROOT_PATH, ".env") + +if os.path.exists(".env"): + load_dotenv(".env", override=True) + print("成功加载环境变量配置") +else: + print("未找到.env文件,请确保程序所需的环境变量被正确设置") + raise FileNotFoundError(".env 文件不存在,请创建并配置所需的环境变量") + +env_mask = {key: os.getenv(key) for key in os.environ} +def scan_provider(env_config: dict): + provider = {} + + # 利用未初始化 env 时获取的 env_mask 来对新的环境变量集去重 + # 避免 GPG_KEY 这样的变量干扰检查 + env_config = dict(filter(lambda item: item[0] not in env_mask, env_config.items())) + + # 遍历 env_config 的所有键 + for key in env_config: + # 检查键是否符合 {provider}_BASE_URL 或 {provider}_KEY 的格式 + if key.endswith("_BASE_URL") or key.endswith("_KEY"): + # 提取 provider 名称 + provider_name = key.split("_", 1)[0] # 从左分割一次,取第一部分 + + # 初始化 provider 的字典(如果尚未初始化) + if provider_name not in provider: + provider[provider_name] = {"url": None, "key": None} + + # 根据键的类型填充 url 或 key + if key.endswith("_BASE_URL"): + provider[provider_name]["url"] = env_config[key] + elif key.endswith("_KEY"): + provider[provider_name]["key"] = env_config[key] + + # 检查每个 provider 是否同时存在 url 和 key + for provider_name, config in provider.items(): + if config["url"] is None or config["key"] is None: + logger.error(f"provider 内容:{config}\nenv_config 内容:{env_config}") + raise ValueError(f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量") def ensure_openie_dir(): """确保OpenIE数据目录存在""" @@ -174,6 +214,8 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k def main(): # sourcery skip: dict-comprehension # 新增确认提示 + env_config = {key: os.getenv(key) for key in os.environ} + scan_provider(env_config) print("=== 重要操作确认 ===") print("OpenIE导入时会大量发送请求,可能会撞到请求速度上限,请注意选用的模型") print("同之前样例:在本地模型下,在70分钟内我们发送了约8万条请求,在网络允许下,速度会更快") diff --git a/scripts/info_extraction.py b/scripts/info_extraction.py index 90f0c80ea..c36a77892 100644 --- a/scripts/info_extraction.py +++ b/scripts/info_extraction.py @@ -27,6 +27,7 @@ from rich.progress import ( from raw_data_preprocessor import RAW_DATA_PATH, load_raw_data from src.config.config import global_config from src.llm_models.utils_model import LLMRequest +from dotenv import load_dotenv logger = get_logger("LPMM知识库-信息提取") @@ -35,6 +36,45 @@ ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) TEMP_DIR = os.path.join(ROOT_PATH, "temp") # IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data", "imported_lpmm_data") OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie") +ENV_FILE = os.path.join(ROOT_PATH, ".env") + +if os.path.exists(".env"): + load_dotenv(".env", override=True) + print("成功加载环境变量配置") +else: + print("未找到.env文件,请确保程序所需的环境变量被正确设置") + raise FileNotFoundError(".env 文件不存在,请创建并配置所需的环境变量") + +env_mask = {key: os.getenv(key) for key in os.environ} +def scan_provider(env_config: dict): + provider = {} + + # 利用未初始化 env 时获取的 env_mask 来对新的环境变量集去重 + # 避免 GPG_KEY 这样的变量干扰检查 + env_config = dict(filter(lambda item: item[0] not in env_mask, env_config.items())) + + # 遍历 env_config 的所有键 + for key in env_config: + # 检查键是否符合 {provider}_BASE_URL 或 {provider}_KEY 的格式 + if key.endswith("_BASE_URL") or key.endswith("_KEY"): + # 提取 provider 名称 + provider_name = key.split("_", 1)[0] # 从左分割一次,取第一部分 + + # 初始化 provider 的字典(如果尚未初始化) + if provider_name not in provider: + provider[provider_name] = {"url": None, "key": None} + + # 根据键的类型填充 url 或 key + if key.endswith("_BASE_URL"): + provider[provider_name]["url"] = env_config[key] + elif key.endswith("_KEY"): + provider[provider_name]["key"] = env_config[key] + + # 检查每个 provider 是否同时存在 url 和 key + for provider_name, config in provider.items(): + if config["url"] is None or config["key"] is None: + logger.error(f"provider 内容:{config}\nenv_config 内容:{env_config}") + raise ValueError(f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量") def ensure_dirs(): """确保临时目录和输出目录存在""" @@ -118,6 +158,8 @@ def main(): # sourcery skip: comprehension-to-generator, extract-method # 设置信号处理器 signal.signal(signal.SIGINT, signal_handler) ensure_dirs() # 确保目录存在 + env_config = {key: os.getenv(key) for key in os.environ} + scan_provider(env_config) # 新增用户确认提示 print("=== 重要操作确认,请认真阅读以下内容哦 ===") print("实体提取操作将会花费较多api余额和时间,建议在空闲时段执行。") diff --git a/src/chat/knowledge/embedding_store.py b/src/chat/knowledge/embedding_store.py index 3eb466d21..808b8013b 100644 --- a/src/chat/knowledge/embedding_store.py +++ b/src/chat/knowledge/embedding_store.py @@ -2,6 +2,7 @@ from dataclasses import dataclass import json import os import math +import asyncio from typing import Dict, List, Tuple import numpy as np @@ -99,7 +100,30 @@ class EmbeddingStore: self.idx2hash = None def _get_embedding(self, s: str) -> List[float]: - return get_embedding(s) + """获取字符串的嵌入向量,处理异步调用""" + try: + # 尝试获取当前事件循环 + asyncio.get_running_loop() + # 如果在事件循环中,使用线程池执行 + import concurrent.futures + + def run_in_thread(): + return asyncio.run(get_embedding(s)) + + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(run_in_thread) + result = future.result() + if result is None: + logger.error(f"获取嵌入失败: {s}") + return [] + return result + except RuntimeError: + # 没有运行的事件循环,直接运行 + result = asyncio.run(get_embedding(s)) + if result is None: + logger.error(f"获取嵌入失败: {s}") + return [] + return result def get_test_file_path(self): return EMBEDDING_TEST_FILE diff --git a/src/chat/knowledge/ie_process.py b/src/chat/knowledge/ie_process.py index bd0e17684..16d4e0804 100644 --- a/src/chat/knowledge/ie_process.py +++ b/src/chat/knowledge/ie_process.py @@ -1,3 +1,4 @@ +import asyncio import json import time from typing import List, Union @@ -7,8 +8,12 @@ from . import prompt_template from .knowledge_lib import INVALID_ENTITY from src.llm_models.utils_model import LLMRequest from json_repair import repair_json -def _extract_json_from_text(text: str) -> dict: +def _extract_json_from_text(text: str): """从文本中提取JSON数据的高容错方法""" + if text is None: + logger.error("输入文本为None") + return [] + try: fixed_json = repair_json(text) if isinstance(fixed_json, str): @@ -16,23 +21,66 @@ def _extract_json_from_text(text: str) -> dict: else: parsed_json = fixed_json - if isinstance(parsed_json, list) and parsed_json: - parsed_json = parsed_json[0] - - if isinstance(parsed_json, dict): + # 如果是列表,直接返回 + if isinstance(parsed_json, list): return parsed_json + + # 如果是字典且只有一个项目,可能包装了列表 + if isinstance(parsed_json, dict): + # 如果字典只有一个键,并且值是列表,返回那个列表 + if len(parsed_json) == 1: + value = list(parsed_json.values())[0] + if isinstance(value, list): + return value + return parsed_json + + # 其他情况,尝试转换为列表 + logger.warning(f"解析的JSON不是预期格式: {type(parsed_json)}, 内容: {parsed_json}") + return [] except Exception as e: - logger.error(f"JSON提取失败: {e}, 原始文本: {text[:100]}...") + logger.error(f"JSON提取失败: {e}, 原始文本: {text[:100] if text else 'None'}...") + return [] def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]: """对段落进行实体提取,返回提取出的实体列表(JSON格式)""" entity_extract_context = prompt_template.build_entity_extract_context(paragraph) - response, (reasoning_content, model_name) = llm_req.generate_response_async(entity_extract_context) + + # 使用 asyncio.run 来运行异步方法 + try: + # 如果当前已有事件循环在运行,使用它 + loop = asyncio.get_running_loop() + future = asyncio.run_coroutine_threadsafe( + llm_req.generate_response_async(entity_extract_context), loop + ) + response, (reasoning_content, model_name) = future.result() + except RuntimeError: + # 如果没有运行中的事件循环,直接使用 asyncio.run + response, (reasoning_content, model_name) = asyncio.run( + llm_req.generate_response_async(entity_extract_context) + ) + # 添加调试日志 + logger.debug(f"LLM返回的原始响应: {response}") + entity_extract_result = _extract_json_from_text(response) - # 尝试load JSON数据 - json.loads(entity_extract_result) + + # 检查返回的是否为有效的实体列表 + if not isinstance(entity_extract_result, list): + # 如果不是列表,可能是字典格式,尝试从中提取列表 + if isinstance(entity_extract_result, dict): + # 尝试常见的键名 + for key in ['entities', 'result', 'data', 'items']: + if key in entity_extract_result and isinstance(entity_extract_result[key], list): + entity_extract_result = entity_extract_result[key] + break + else: + # 如果找不到合适的列表,抛出异常 + raise Exception(f"实体提取结果格式错误,期望列表但得到: {type(entity_extract_result)}") + else: + raise Exception(f"实体提取结果格式错误,期望列表但得到: {type(entity_extract_result)}") + + # 过滤无效实体 entity_extract_result = [ entity for entity in entity_extract_result @@ -50,16 +98,47 @@ def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) -> rdf_extract_context = prompt_template.build_rdf_triple_extract_context( paragraph, entities=json.dumps(entities, ensure_ascii=False) ) - response, (reasoning_content, model_name) = llm_req.generate_response_async(rdf_extract_context) + + # 使用 asyncio.run 来运行异步方法 + try: + # 如果当前已有事件循环在运行,使用它 + loop = asyncio.get_running_loop() + future = asyncio.run_coroutine_threadsafe( + llm_req.generate_response_async(rdf_extract_context), loop + ) + response, (reasoning_content, model_name) = future.result() + except RuntimeError: + # 如果没有运行中的事件循环,直接使用 asyncio.run + response, (reasoning_content, model_name) = asyncio.run( + llm_req.generate_response_async(rdf_extract_context) + ) - entity_extract_result = _extract_json_from_text(response) - # 尝试load JSON数据 - json.loads(entity_extract_result) - for triple in entity_extract_result: - if len(triple) != 3 or (triple[0] is None or triple[1] is None or triple[2] is None) or "" in triple: + # 添加调试日志 + logger.debug(f"RDF LLM返回的原始响应: {response}") + + rdf_triple_result = _extract_json_from_text(response) + + # 检查返回的是否为有效的三元组列表 + if not isinstance(rdf_triple_result, list): + # 如果不是列表,可能是字典格式,尝试从中提取列表 + if isinstance(rdf_triple_result, dict): + # 尝试常见的键名 + for key in ['triples', 'result', 'data', 'items']: + if key in rdf_triple_result and isinstance(rdf_triple_result[key], list): + rdf_triple_result = rdf_triple_result[key] + break + else: + # 如果找不到合适的列表,抛出异常 + raise Exception(f"RDF三元组提取结果格式错误,期望列表但得到: {type(rdf_triple_result)}") + else: + raise Exception(f"RDF三元组提取结果格式错误,期望列表但得到: {type(rdf_triple_result)}") + + # 验证三元组格式 + for triple in rdf_triple_result: + if not isinstance(triple, list) or len(triple) != 3 or (triple[0] is None or triple[1] is None or triple[2] is None) or "" in triple: raise Exception("RDF提取结果格式错误") - return entity_extract_result + return rdf_triple_result def info_extract_from_str( diff --git a/src/chat/knowledge/prompt_template.py b/src/chat/knowledge/prompt_template.py index 14a360083..fe5a293c0 100644 --- a/src/chat/knowledge/prompt_template.py +++ b/src/chat/knowledge/prompt_template.py @@ -11,12 +11,14 @@ entity_extract_system_prompt = """你是一个性能优异的实体提取系统 """ -def build_entity_extract_context(paragraph: str) -> list[LLMMessage]: - messages = [ - LLMMessage("system", entity_extract_system_prompt).to_dict(), - LLMMessage("user", f"""段落:\n```\n{paragraph}```""").to_dict(), - ] - return messages +def build_entity_extract_context(paragraph: str) -> str: + """构建实体提取的完整提示文本""" + return f"""{entity_extract_system_prompt} + +段落: +``` +{paragraph} +```""" rdf_triple_extract_system_prompt = """你是一个性能优异的RDF(资源描述框架,由节点和边组成,节点表示实体/资源、属性,边则表示了实体和实体之间的关系以及实体和属性的关系。)构造系统。你的任务是根据给定的段落和实体列表构建RDF图。 @@ -36,12 +38,19 @@ rdf_triple_extract_system_prompt = """你是一个性能优异的RDF(资源描 """ -def build_rdf_triple_extract_context(paragraph: str, entities: str) -> list[LLMMessage]: - messages = [ - LLMMessage("system", rdf_triple_extract_system_prompt).to_dict(), - LLMMessage("user", f"""段落:\n```\n{paragraph}```\n\n实体列表:\n```\n{entities}```""").to_dict(), - ] - return messages +def build_rdf_triple_extract_context(paragraph: str, entities: str) -> str: + """构建RDF三元组提取的完整提示文本""" + return f"""{rdf_triple_extract_system_prompt} + +段落: +``` +{paragraph} +``` + +实体列表: +``` +{entities} +```""" qa_system_prompt = """ diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index 1077cfa09..b9a419c33 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -255,12 +255,11 @@ class LLMRequest: if self.temp != 0.7: payload["temperature"] = self.temp - # 添加enable_thinking参数(如果不是默认值False) - if not self.enable_thinking: - payload["enable_thinking"] = False - - if self.thinking_budget != 4096: - payload["thinking_budget"] = self.thinking_budget + # 添加enable_thinking参数(仅在启用时添加) + if self.enable_thinking: + payload["enable_thinking"] = True + if self.thinking_budget != 4096: + payload["thinking_budget"] = self.thinking_budget if self.max_tokens: payload["max_tokens"] = self.max_tokens @@ -670,12 +669,11 @@ class LLMRequest: if self.temp != 0.7: payload["temperature"] = self.temp - # 添加enable_thinking参数(如果不是默认值False) - if not self.enable_thinking: - payload["enable_thinking"] = False - - if self.thinking_budget != 4096: - payload["thinking_budget"] = self.thinking_budget + # 添加enable_thinking参数(仅在启用时添加) + if self.enable_thinking: + payload["enable_thinking"] = True + if self.thinking_budget != 4096: + payload["thinking_budget"] = self.thinking_budget if self.max_tokens: payload["max_tokens"] = self.max_tokens From eb716f1e469dc4f32680c7755fd6fd5265674572 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Wed, 16 Jul 2025 21:02:01 +0800 Subject: [PATCH 2/5] =?UTF-8?q?fix=EF=BC=9A=E4=BF=AE=E5=A4=8D=E5=AE=9E?= =?UTF-8?q?=E4=BD=93=E5=92=8C=E6=AE=B5=E8=90=BD=E8=8A=82=E7=82=B9=E4=B8=8D?= =?UTF-8?q?=E5=AD=98=E5=9C=A8=E6=97=B6=E7=9A=84=E5=A4=84=E7=90=86=E9=80=BB?= =?UTF-8?q?=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/knowledge/kg_manager.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/chat/knowledge/kg_manager.py b/src/chat/knowledge/kg_manager.py index 38f883d0e..e18a7da80 100644 --- a/src/chat/knowledge/kg_manager.py +++ b/src/chat/knowledge/kg_manager.py @@ -184,10 +184,10 @@ class KGManager: progress.update(task, advance=1) continue ent = embedding_manager.entities_embedding_store.store.get(ent_hash) - assert isinstance(ent, EmbeddingStoreItem) if ent is None: progress.update(task, advance=1) continue + assert isinstance(ent, EmbeddingStoreItem) # 查询相似实体 similar_ents = embedding_manager.entities_embedding_store.search_top_k( ent.embedding, global_config["rag"]["params"]["synonym_search_top_k"] @@ -265,7 +265,10 @@ class KGManager: if node_hash not in existed_nodes: if node_hash.startswith(local_storage['ent_namespace']): # 新增实体节点 - node = embedding_manager.entities_embedding_store.store[node_hash] + node = embedding_manager.entities_embedding_store.store.get(node_hash) + if node is None: + logger.warning(f"实体节点 {node_hash} 在嵌入库中不存在,跳过") + continue assert isinstance(node, EmbeddingStoreItem) node_item = self.graph[node_hash] node_item["content"] = node.str @@ -274,7 +277,10 @@ class KGManager: self.graph.update_node(node_item) elif node_hash.startswith(local_storage['pg_namespace']): # 新增文段节点 - node = embedding_manager.paragraphs_embedding_store.store[node_hash] + node = embedding_manager.paragraphs_embedding_store.store.get(node_hash) + if node is None: + logger.warning(f"段落节点 {node_hash} 在嵌入库中不存在,跳过") + continue assert isinstance(node, EmbeddingStoreItem) content = node.str.replace("\n", " ") node_item = self.graph[node_hash] From 1aa2734d62d2eebb25fefbe96e56d41a8dcfc216 Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Thu, 17 Jul 2025 00:10:41 +0800 Subject: [PATCH 3/5] typing fix --- bot.py | 27 ++--- src/chat/express/expression_learner.py | 71 ++++++----- src/chat/memory_system/instant_memory.py | 92 +++++++------- src/chat/message_receive/chat_stream.py | 2 +- src/chat/message_receive/message.py | 21 ++-- src/chat/planner_actions/action_manager.py | 2 +- src/chat/utils/chat_message_builder.py | 6 +- src/chat/utils/statistic.py | 37 +++--- src/chat/willing/willing_manager.py | 5 +- src/common/database/database_model.py | 39 +++--- src/config/auto_update.py | 16 +-- src/config/config.py | 37 +++--- src/individuality/not_using/offline_llm.py | 4 +- src/individuality/not_using/per_bf_gen.py | 7 +- src/main.py | 6 +- src/mood/mood_manager.py | 3 + src/person_info/relationship_builder.py | 2 +- src/plugin_system/__init__.py | 6 +- src/plugin_system/apis/send_api.py | 39 +++--- src/plugin_system/base/base_action.py | 24 ++-- src/plugin_system/base/base_command.py | 2 +- src/plugin_system/core/component_registry.py | 112 +++++++++--------- src/plugin_system/core/dependency_manager.py | 4 +- src/plugin_system/core/plugin_manager.py | 44 +++---- .../tool_can_use/compare_numbers_tool.py | 6 +- src/tools/tool_can_use/rename_person_tool.py | 8 +- 26 files changed, 329 insertions(+), 293 deletions(-) diff --git a/bot.py b/bot.py index 5548c1725..72ea65d29 100644 --- a/bot.py +++ b/bot.py @@ -146,7 +146,7 @@ def _calculate_file_hash(file_path: Path, file_type: str) -> str: if not file_path.exists(): logger.error(f"{file_type} 文件不存在") raise FileNotFoundError(f"{file_type} 文件不存在") - + with open(file_path, "r", encoding="utf-8") as f: content = f.read() return hashlib.md5(content.encode("utf-8")).hexdigest() @@ -154,21 +154,21 @@ def _calculate_file_hash(file_path: Path, file_type: str) -> str: def _check_agreement_status(file_hash: str, confirm_file: Path, env_var: str) -> tuple[bool, bool]: """检查协议确认状态 - + Returns: tuple[bool, bool]: (已确认, 未更新) """ # 检查环境变量确认 if file_hash == os.getenv(env_var): return True, False - + # 检查确认文件 if confirm_file.exists(): with open(confirm_file, "r", encoding="utf-8") as f: confirmed_content = f.read() if file_hash == confirmed_content: return True, False - + return False, True @@ -178,7 +178,7 @@ def _prompt_user_confirmation(eula_hash: str, privacy_hash: str) -> None: confirm_logger.critical( f'输入"同意"或"confirmed"或设置环境变量"EULA_AGREE={eula_hash}"和"PRIVACY_AGREE={privacy_hash}"继续运行' ) - + while True: user_input = input().strip().lower() if user_input in ["同意", "confirmed"]: @@ -186,13 +186,12 @@ def _prompt_user_confirmation(eula_hash: str, privacy_hash: str) -> None: confirm_logger.critical('请输入"同意"或"confirmed"以继续运行') -def _save_confirmations(eula_updated: bool, privacy_updated: bool, - eula_hash: str, privacy_hash: str) -> None: +def _save_confirmations(eula_updated: bool, privacy_updated: bool, eula_hash: str, privacy_hash: str) -> None: """保存用户确认结果""" if eula_updated: logger.info(f"更新EULA确认文件{eula_hash}") Path("eula.confirmed").write_text(eula_hash, encoding="utf-8") - + if privacy_updated: logger.info(f"更新隐私条款确认文件{privacy_hash}") Path("privacy.confirmed").write_text(privacy_hash, encoding="utf-8") @@ -203,19 +202,17 @@ def check_eula(): # 计算文件哈希值 eula_hash = _calculate_file_hash(Path("EULA.md"), "EULA.md") privacy_hash = _calculate_file_hash(Path("PRIVACY.md"), "PRIVACY.md") - + # 检查确认状态 - eula_confirmed, eula_updated = _check_agreement_status( - eula_hash, Path("eula.confirmed"), "EULA_AGREE" - ) + eula_confirmed, eula_updated = _check_agreement_status(eula_hash, Path("eula.confirmed"), "EULA_AGREE") privacy_confirmed, privacy_updated = _check_agreement_status( privacy_hash, Path("privacy.confirmed"), "PRIVACY_AGREE" ) - + # 早期返回:如果都已确认且未更新 if eula_confirmed and privacy_confirmed: return - + # 如果有更新,需要重新确认 if eula_updated or privacy_updated: _prompt_user_confirmation(eula_hash, privacy_hash) @@ -225,7 +222,7 @@ def check_eula(): def raw_main(): # 利用 TZ 环境变量设定程序工作的时区 if platform.system().lower() != "windows": - time.tzset() + time.tzset() # type: ignore check_eula() logger.info("检查EULA和隐私条款完成") diff --git a/src/chat/express/expression_learner.py b/src/chat/express/expression_learner.py index 4139c65a5..e02ff7311 100644 --- a/src/chat/express/expression_learner.py +++ b/src/chat/express/expression_learner.py @@ -107,11 +107,12 @@ class ExpressionLearner: last_active_time = expr.get("last_active_time", time.time()) # 查重:同chat_id+type+situation+style from src.common.database.database_model import Expression + query = Expression.select().where( - (Expression.chat_id == chat_id) & - (Expression.type == type_str) & - (Expression.situation == situation) & - (Expression.style == style_val) + (Expression.chat_id == chat_id) + & (Expression.type == type_str) + & (Expression.situation == situation) + & (Expression.style == style_val) ) if query.exists(): expr_obj = query.get() @@ -125,7 +126,7 @@ class ExpressionLearner: count=count, last_active_time=last_active_time, chat_id=chat_id, - type=type_str + type=type_str, ) logger.info(f"已迁移 {expr_file} 到数据库") except Exception as e: @@ -149,24 +150,28 @@ class ExpressionLearner: # 直接从数据库查询 style_query = Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == "style")) for expr in style_query: - learnt_style_expressions.append({ - "situation": expr.situation, - "style": expr.style, - "count": expr.count, - "last_active_time": expr.last_active_time, - "source_id": chat_id, - "type": "style" - }) + learnt_style_expressions.append( + { + "situation": expr.situation, + "style": expr.style, + "count": expr.count, + "last_active_time": expr.last_active_time, + "source_id": chat_id, + "type": "style", + } + ) grammar_query = Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == "grammar")) for expr in grammar_query: - learnt_grammar_expressions.append({ - "situation": expr.situation, - "style": expr.style, - "count": expr.count, - "last_active_time": expr.last_active_time, - "source_id": chat_id, - "type": "grammar" - }) + learnt_grammar_expressions.append( + { + "situation": expr.situation, + "style": expr.style, + "count": expr.count, + "last_active_time": expr.last_active_time, + "source_id": chat_id, + "type": "grammar", + } + ) return learnt_style_expressions, learnt_grammar_expressions def is_similar(self, s1: str, s2: str) -> bool: @@ -213,14 +218,16 @@ class ExpressionLearner: logger.error(f"全局衰减{type}表达方式失败: {e}") continue + learnt_style: Optional[List[Tuple[str, str, str]]] = [] + learnt_grammar: Optional[List[Tuple[str, str, str]]] = [] # 学习新的表达方式(这里会进行局部衰减) for _ in range(3): - learnt_style: Optional[List[Tuple[str, str, str]]] = await self.learn_and_store(type="style", num=25) + learnt_style = await self.learn_and_store(type="style", num=25) if not learnt_style: return [], [] for _ in range(1): - learnt_grammar: Optional[List[Tuple[str, str, str]]] = await self.learn_and_store(type="grammar", num=10) + learnt_grammar = await self.learn_and_store(type="grammar", num=10) if not learnt_grammar: return [], [] @@ -321,10 +328,10 @@ class ExpressionLearner: for new_expr in expr_list: # 查找是否已存在相似表达方式 query = Expression.select().where( - (Expression.chat_id == chat_id) & - (Expression.type == type) & - (Expression.situation == new_expr["situation"]) & - (Expression.style == new_expr["style"]) + (Expression.chat_id == chat_id) + & (Expression.type == type) + & (Expression.situation == new_expr["situation"]) + & (Expression.style == new_expr["style"]) ) if query.exists(): expr_obj = query.get() @@ -342,13 +349,17 @@ class ExpressionLearner: count=1, last_active_time=current_time, chat_id=chat_id, - type=type + type=type, ) # 限制最大数量 - exprs = list(Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == type)).order_by(Expression.count.asc())) + exprs = list( + Expression.select() + .where((Expression.chat_id == chat_id) & (Expression.type == type)) + .order_by(Expression.count.asc()) + ) if len(exprs) > MAX_EXPRESSION_COUNT: # 删除count最小的多余表达方式 - for expr in exprs[:len(exprs) - MAX_EXPRESSION_COUNT]: + for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]: expr.delete_instance() return learnt_expressions diff --git a/src/chat/memory_system/instant_memory.py b/src/chat/memory_system/instant_memory.py index 5b38bbb0b..f7e54f8e9 100644 --- a/src/chat/memory_system/instant_memory.py +++ b/src/chat/memory_system/instant_memory.py @@ -9,51 +9,49 @@ from src.common.logger import get_logger import traceback from src.config.config import global_config -from src.common.database.database_model import Memory # Peewee Models导入 +from src.common.database.database_model import Memory # Peewee Models导入 logger = get_logger(__name__) + class MemoryItem: - def __init__(self,memory_id:str,chat_id:str,memory_text:str,keywords:list[str]): + def __init__(self, memory_id: str, chat_id: str, memory_text: str, keywords: list[str]): self.memory_id = memory_id self.chat_id = chat_id - self.memory_text:str = memory_text - self.keywords:list[str] = keywords - self.create_time:float = time.time() - self.last_view_time:float = time.time() - + self.memory_text: str = memory_text + self.keywords: list[str] = keywords + self.create_time: float = time.time() + self.last_view_time: float = time.time() + + class MemoryManager: def __init__(self): # self.memory_items:list[MemoryItem] = [] pass - - - class InstantMemory: - def __init__(self,chat_id): - self.chat_id = chat_id + def __init__(self, chat_id): + self.chat_id = chat_id self.last_view_time = time.time() self.summary_model = LLMRequest( model=global_config.model.memory, temperature=0.5, request_type="memory.summary", ) - - async def if_need_build(self,text): + + async def if_need_build(self, text): prompt = f""" 请判断以下内容中是否有值得记忆的信息,如果有,请输出1,否则输出0 {text} 请只输出1或0就好 """ - + try: - response,_ = await self.summary_model.generate_response_async(prompt) + response, _ = await self.summary_model.generate_response_async(prompt) print(prompt) print(response) - - + if "1" in response: return True else: @@ -61,8 +59,8 @@ class InstantMemory: except Exception as e: logger.error(f"判断是否需要记忆出现错误:{str(e)} {traceback.format_exc()}") return False - - async def build_memory(self,text): + + async def build_memory(self, text): prompt = f""" 以下内容中存在值得记忆的信息,请你从中总结出一段值得记忆的信息,并输出 {text} @@ -73,7 +71,7 @@ class InstantMemory: }} """ try: - response,_ = await self.summary_model.generate_response_async(prompt) + response, _ = await self.summary_model.generate_response_async(prompt) print(prompt) print(response) if not response: @@ -81,53 +79,53 @@ class InstantMemory: try: repaired = repair_json(response) result = json.loads(repaired) - memory_text = result.get('memory_text', '') - keywords = result.get('keywords', '') + memory_text = result.get("memory_text", "") + keywords = result.get("keywords", "") if isinstance(keywords, str): - keywords_list = [k.strip() for k in keywords.split('/') if k.strip()] + keywords_list = [k.strip() for k in keywords.split("/") if k.strip()] elif isinstance(keywords, list): keywords_list = keywords else: keywords_list = [] - return {'memory_text': memory_text, 'keywords': keywords_list} + return {"memory_text": memory_text, "keywords": keywords_list} except Exception as parse_e: logger.error(f"解析记忆json失败:{str(parse_e)} {traceback.format_exc()}") return None except Exception as e: logger.error(f"构建记忆出现错误:{str(e)} {traceback.format_exc()}") return None - - async def create_and_store_memory(self,text): + async def create_and_store_memory(self, text): if_need = await self.if_need_build(text) if if_need: logger.info(f"需要记忆:{text}") - memory = await self.build_memory(text) - if memory and memory.get('memory_text'): + memory = await self.build_memory(text) + if memory and memory.get("memory_text"): memory_id = f"{self.chat_id}_{time.time()}" memory_item = MemoryItem( memory_id=memory_id, chat_id=self.chat_id, - memory_text=memory['memory_text'], - keywords=memory.get('keywords', []) + memory_text=memory["memory_text"], + keywords=memory.get("keywords", []), ) await self.store_memory(memory_item) else: logger.info(f"不需要记忆:{text}") - - async def store_memory(self,memory_item:MemoryItem): + + async def store_memory(self, memory_item: MemoryItem): memory = Memory( memory_id=memory_item.memory_id, chat_id=memory_item.chat_id, memory_text=memory_item.memory_text, keywords=memory_item.keywords, create_time=memory_item.create_time, - last_view_time=memory_item.last_view_time + last_view_time=memory_item.last_view_time, ) memory.save() - - async def get_memory(self,target:str): + + async def get_memory(self, target: str): from json_repair import repair_json + prompt = f""" 请根据以下发言内容,判断是否需要提取记忆 {target} @@ -144,7 +142,7 @@ class InstantMemory: 请只输出json格式,不要输出其他多余内容 """ try: - response,_ = await self.summary_model.generate_response_async(prompt) + response, _ = await self.summary_model.generate_response_async(prompt) print(prompt) print(response) if not response: @@ -153,15 +151,15 @@ class InstantMemory: repaired = repair_json(response) result = json.loads(repaired) # 解析keywords - keywords = result.get('keywords', '') + keywords = result.get("keywords", "") if isinstance(keywords, str): - keywords_list = [k.strip() for k in keywords.split('/') if k.strip()] + keywords_list = [k.strip() for k in keywords.split("/") if k.strip()] elif isinstance(keywords, list): keywords_list = keywords else: keywords_list = [] # 解析time为时间段 - time_str = result.get('time', '').strip() + time_str = result.get("time", "").strip() start_time, end_time = self._parse_time_range(time_str) logger.info(f"start_time: {start_time}, end_time: {end_time}") # 检索包含关键词的记忆 @@ -170,16 +168,15 @@ class InstantMemory: start_ts = start_time.timestamp() end_ts = end_time.timestamp() query = Memory.select().where( - (Memory.chat_id == self.chat_id) & - (Memory.create_time >= start_ts) & - (Memory.create_time < end_ts) + (Memory.chat_id == self.chat_id) + & (Memory.create_time >= start_ts) # type: ignore + & (Memory.create_time < end_ts) # type: ignore ) else: query = Memory.select().where(Memory.chat_id == self.chat_id) - for mem in query: - #对每条记忆 + # 对每条记忆 mem_keywords = mem.keywords or [] parsed = ast.literal_eval(mem_keywords) if isinstance(parsed, list): @@ -212,6 +209,7 @@ class InstantMemory: - 空字符串:返回(None, None) """ from datetime import datetime, timedelta + now = datetime.now() if not time_str: return 0, now @@ -251,8 +249,8 @@ class InstantMemory: if m: months = int(m.group(1)) # 近似每月30天 - start = (now - timedelta(days=months*30)).replace(hour=0, minute=0, second=0, microsecond=0) + start = (now - timedelta(days=months * 30)).replace(hour=0, minute=0, second=0, microsecond=0) end = start + timedelta(days=1) return start, end # 其他无法解析 - return 0, now \ No newline at end of file + return 0, now diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index 8b71314a6..e4a61900e 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -30,7 +30,7 @@ class ChatMessageContext: def get_template_name(self) -> Optional[str]: """获取模板名称""" if self.message.message_info.template_info and not self.message.message_info.template_info.template_default: - return self.message.message_info.template_info.template_name + return self.message.message_info.template_info.template_name # type: ignore return None def get_last_message(self) -> "MessageRecv": diff --git a/src/chat/message_receive/message.py b/src/chat/message_receive/message.py index e6b6741f0..487c7d036 100644 --- a/src/chat/message_receive/message.py +++ b/src/chat/message_receive/message.py @@ -107,9 +107,9 @@ class MessageRecv(Message): self.is_picid = False self.has_picid = False self.is_mentioned = None - + self.is_command = False - + self.priority_mode = "interest" self.priority_info = None self.interest_value: float = None # type: ignore @@ -181,6 +181,7 @@ class MessageRecv(Message): logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}") return f"[处理失败的{segment.type}消息]" + @dataclass class MessageRecvS4U(MessageRecv): def __init__(self, message_dict: dict[str, Any]): @@ -194,10 +195,10 @@ class MessageRecvS4U(MessageRecv): self.superchat_price = None self.superchat_message_text = None self.is_screen = False - + async def process(self) -> None: self.processed_plain_text = await self._process_message_segments(self.message_segment) - + async def _process_single_segment(self, segment: Seg) -> str: """处理单个消息段 @@ -252,7 +253,7 @@ class MessageRecvS4U(MessageRecv): elif segment.type == "gift": self.is_gift = True # 解析gift_info,格式为"名称:数量" - name, count = segment.data.split(":", 1) + name, count = segment.data.split(":", 1) # type: ignore self.gift_info = segment.data self.gift_name = name.strip() self.gift_count = int(count.strip()) @@ -260,13 +261,15 @@ class MessageRecvS4U(MessageRecv): elif segment.type == "superchat": self.is_superchat = True self.superchat_info = segment.data - price,message_text = segment.data.split(":", 1) + price, message_text = segment.data.split(":", 1) # type: ignore self.superchat_price = price.strip() self.superchat_message_text = message_text.strip() - + self.processed_plain_text = str(self.superchat_message_text) - self.processed_plain_text += f"(注意:这是一条超级弹幕信息,价值{self.superchat_price}元,请你认真回复)" - + self.processed_plain_text += ( + f"(注意:这是一条超级弹幕信息,价值{self.superchat_price}元,请你认真回复)" + ) + return self.processed_plain_text elif segment.type == "screen": self.is_screen = True diff --git a/src/chat/planner_actions/action_manager.py b/src/chat/planner_actions/action_manager.py index 6c82625b3..a4876a46d 100644 --- a/src/chat/planner_actions/action_manager.py +++ b/src/chat/planner_actions/action_manager.py @@ -80,7 +80,7 @@ class ActionManager: chat_stream: ChatStream, log_prefix: str, shutting_down: bool = False, - action_message: dict = None, + action_message: Optional[dict] = None, ) -> Optional[BaseAction]: """ 创建动作处理器实例 diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index bb32e63a2..3a08ca72b 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -252,7 +252,7 @@ def _build_readable_messages_internal( pic_id_mapping: Optional[Dict[str, str]] = None, pic_counter: int = 1, show_pic: bool = True, - message_id_list: List[Dict[str, Any]] = None, + message_id_list: Optional[List[Dict[str, Any]]] = None, ) -> Tuple[str, List[Tuple[float, str, str]], Dict[str, str], int]: """ 内部辅助函数,构建可读消息字符串和原始消息详情列表。 @@ -615,7 +615,7 @@ def build_readable_actions(actions: List[Dict[str, Any]]) -> str: for action in actions: action_time = action.get("time", current_time) action_name = action.get("action_name", "未知动作") - if action_name == "no_action" or action_name == "no_reply": + if action_name in ["no_action", "no_reply"]: continue action_prompt_display = action.get("action_prompt_display", "无具体内容") @@ -697,7 +697,7 @@ def build_readable_messages( truncate: bool = False, show_actions: bool = False, show_pic: bool = True, - message_id_list: List[Dict[str, Any]] = None, + message_id_list: Optional[List[Dict[str, Any]]] = None, ) -> str: # sourcery skip: extract-method """ 将消息列表转换为可读的文本格式。 diff --git a/src/chat/utils/statistic.py b/src/chat/utils/statistic.py index 4e0edd31f..0aff5102e 100644 --- a/src/chat/utils/statistic.py +++ b/src/chat/utils/statistic.py @@ -1211,7 +1211,7 @@ class StatisticOutputTask(AsyncTask): f.write(html_template) def _generate_focus_tab(self, stat: dict[str, Any]) -> str: - # sourcery skip: for-append-to-extend, list-comprehension, use-any + # sourcery skip: for-append-to-extend, list-comprehension, use-any, use-named-expression, use-next """生成Focus统计独立分页的HTML内容""" # 为每个时间段准备Focus数据 @@ -1559,6 +1559,7 @@ class StatisticOutputTask(AsyncTask): """ def _generate_versions_tab(self, stat: dict[str, Any]) -> str: + # sourcery skip: use-named-expression, use-next """生成版本对比独立分页的HTML内容""" # 为每个时间段准备版本对比数据 @@ -2306,13 +2307,13 @@ class AsyncStatisticOutputTask(AsyncTask): # 复用 StatisticOutputTask 的所有方法 def _collect_all_statistics(self, now: datetime): - return StatisticOutputTask._collect_all_statistics(self, now) + return StatisticOutputTask._collect_all_statistics(self, now) # type: ignore def _statistic_console_output(self, stats: Dict[str, Any], now: datetime): - return StatisticOutputTask._statistic_console_output(self, stats, now) + return StatisticOutputTask._statistic_console_output(self, stats, now) # type: ignore def _generate_html_report(self, stats: dict[str, Any], now: datetime): - return StatisticOutputTask._generate_html_report(self, stats, now) + return StatisticOutputTask._generate_html_report(self, stats, now) # type: ignore # 其他需要的方法也可以类似复用... @staticmethod @@ -2324,10 +2325,10 @@ class AsyncStatisticOutputTask(AsyncTask): return StatisticOutputTask._collect_online_time_for_period(collect_period, now) def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]: - return StatisticOutputTask._collect_message_count_for_period(self, collect_period) + return StatisticOutputTask._collect_message_count_for_period(self, collect_period) # type: ignore def _collect_focus_statistics_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]: - return StatisticOutputTask._collect_focus_statistics_for_period(self, collect_period) + return StatisticOutputTask._collect_focus_statistics_for_period(self, collect_period) # type: ignore def _process_focus_file_data( self, @@ -2336,10 +2337,10 @@ class AsyncStatisticOutputTask(AsyncTask): collect_period: List[Tuple[str, datetime]], file_time: datetime, ): - return StatisticOutputTask._process_focus_file_data(self, cycles_data, stats, collect_period, file_time) + return StatisticOutputTask._process_focus_file_data(self, cycles_data, stats, collect_period, file_time) # type: ignore def _calculate_focus_averages(self, stats: Dict[str, Any]): - return StatisticOutputTask._calculate_focus_averages(self, stats) + return StatisticOutputTask._calculate_focus_averages(self, stats) # type: ignore @staticmethod def _format_total_stat(stats: Dict[str, Any]) -> str: @@ -2347,31 +2348,31 @@ class AsyncStatisticOutputTask(AsyncTask): @staticmethod def _format_model_classified_stat(stats: Dict[str, Any]) -> str: - return StatisticOutputTask._format_model_classified_stat(stats) + return StatisticOutputTask._format_model_classified_stat(stats) # type: ignore def _format_chat_stat(self, stats: Dict[str, Any]) -> str: - return StatisticOutputTask._format_chat_stat(self, stats) + return StatisticOutputTask._format_chat_stat(self, stats) # type: ignore def _format_focus_stat(self, stats: Dict[str, Any]) -> str: - return StatisticOutputTask._format_focus_stat(self, stats) + return StatisticOutputTask._format_focus_stat(self, stats) # type: ignore def _generate_chart_data(self, stat: dict[str, Any]) -> dict: - return StatisticOutputTask._generate_chart_data(self, stat) + return StatisticOutputTask._generate_chart_data(self, stat) # type: ignore def _collect_interval_data(self, now: datetime, hours: int, interval_minutes: int) -> dict: - return StatisticOutputTask._collect_interval_data(self, now, hours, interval_minutes) + return StatisticOutputTask._collect_interval_data(self, now, hours, interval_minutes) # type: ignore def _generate_chart_tab(self, chart_data: dict) -> str: - return StatisticOutputTask._generate_chart_tab(self, chart_data) + return StatisticOutputTask._generate_chart_tab(self, chart_data) # type: ignore def _get_chat_display_name_from_id(self, chat_id: str) -> str: - return StatisticOutputTask._get_chat_display_name_from_id(self, chat_id) + return StatisticOutputTask._get_chat_display_name_from_id(self, chat_id) # type: ignore def _generate_focus_tab(self, stat: dict[str, Any]) -> str: - return StatisticOutputTask._generate_focus_tab(self, stat) + return StatisticOutputTask._generate_focus_tab(self, stat) # type: ignore def _generate_versions_tab(self, stat: dict[str, Any]) -> str: - return StatisticOutputTask._generate_versions_tab(self, stat) + return StatisticOutputTask._generate_versions_tab(self, stat) # type: ignore def _convert_defaultdict_to_dict(self, data): - return StatisticOutputTask._convert_defaultdict_to_dict(self, data) + return StatisticOutputTask._convert_defaultdict_to_dict(self, data) # type: ignore diff --git a/src/chat/willing/willing_manager.py b/src/chat/willing/willing_manager.py index 29110ef94..6c53273f5 100644 --- a/src/chat/willing/willing_manager.py +++ b/src/chat/willing/willing_manager.py @@ -2,14 +2,13 @@ import importlib import asyncio from abc import ABC, abstractmethod -from typing import Dict, Optional +from typing import Dict, Optional, Any from rich.traceback import install from dataclasses import dataclass from src.common.logger import get_logger from src.config.config import global_config from src.chat.message_receive.chat_stream import ChatStream, GroupInfo -from src.chat.message_receive.message import MessageRecv from src.person_info.person_info import PersonInfoManager, get_person_info_manager install(extra_lines=3) @@ -54,7 +53,7 @@ class WillingInfo: interested_rate (float): 兴趣度 """ - message: MessageRecv + message: Dict[str, Any] # 原始消息数据 chat: ChatStream person_info_manager: PersonInfoManager chat_id: str diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py index 8258ac9fb..4b60dfa10 100644 --- a/src/common/database/database_model.py +++ b/src/common/database/database_model.py @@ -65,7 +65,7 @@ class ChatStreams(BaseModel): # user_cardname 可能为空字符串或不存在,设置 null=True 更具灵活性。 user_cardname = TextField(null=True) - class Meta: + class Meta: # type: ignore # 如果 BaseModel.Meta.database 已设置,则此模型将继承该数据库配置。 # 如果不使用带有数据库实例的 BaseModel,或者想覆盖它, # 请取消注释并在下面设置数据库实例: @@ -89,7 +89,7 @@ class LLMUsage(BaseModel): status = TextField() timestamp = DateTimeField(index=True) # 更改为 DateTimeField 并添加索引 - class Meta: + class Meta: # type: ignore # 如果 BaseModel.Meta.database 已设置,则此模型将继承该数据库配置。 # database = db table_name = "llm_usage" @@ -112,7 +112,7 @@ class Emoji(BaseModel): usage_count = IntegerField(default=0) # 使用次数(被使用的次数) last_used_time = FloatField(null=True) # 上次使用时间 - class Meta: + class Meta: # type: ignore # database = db # 继承自 BaseModel table_name = "emoji" @@ -162,7 +162,8 @@ class Messages(BaseModel): is_emoji = BooleanField(default=False) is_picid = BooleanField(default=False) is_command = BooleanField(default=False) - class Meta: + + class Meta: # type: ignore # database = db # 继承自 BaseModel table_name = "messages" @@ -186,7 +187,7 @@ class ActionRecords(BaseModel): chat_info_stream_id = TextField() chat_info_platform = TextField() - class Meta: + class Meta: # type: ignore # database = db # 继承自 BaseModel table_name = "action_records" @@ -206,7 +207,7 @@ class Images(BaseModel): type = TextField() # 图像类型,例如 "emoji" vlm_processed = BooleanField(default=False) # 是否已经过VLM处理 - class Meta: + class Meta: # type: ignore table_name = "images" @@ -220,7 +221,7 @@ class ImageDescriptions(BaseModel): description = TextField() # 图像的描述 timestamp = FloatField() # 时间戳 - class Meta: + class Meta: # type: ignore # database = db # 继承自 BaseModel table_name = "image_descriptions" @@ -236,7 +237,7 @@ class OnlineTime(BaseModel): start_timestamp = DateTimeField(default=datetime.datetime.now) end_timestamp = DateTimeField(index=True) - class Meta: + class Meta: # type: ignore # database = db # 继承自 BaseModel table_name = "online_time" @@ -263,10 +264,11 @@ class PersonInfo(BaseModel): last_know = FloatField(null=True) # 最后一次印象总结时间 attitude = IntegerField(null=True, default=50) # 态度,0-100,从非常厌恶到十分喜欢 - class Meta: + class Meta: # type: ignore # database = db # 继承自 BaseModel table_name = "person_info" + class Memory(BaseModel): memory_id = TextField(index=True) chat_id = TextField(null=True) @@ -274,10 +276,11 @@ class Memory(BaseModel): keywords = TextField(null=True) create_time = FloatField(null=True) last_view_time = FloatField(null=True) - - class Meta: + + class Meta: # type: ignore table_name = "memory" + class Knowledges(BaseModel): """ 用于存储知识库条目的模型。 @@ -287,10 +290,11 @@ class Knowledges(BaseModel): embedding = TextField() # 知识内容的嵌入向量,存储为 JSON 字符串的浮点数列表 # 可以添加其他元数据字段,如 source, create_time 等 - class Meta: + class Meta: # type: ignore # database = db # 继承自 BaseModel table_name = "knowledges" + class Expression(BaseModel): """ 用于存储表达风格的模型。 @@ -302,10 +306,11 @@ class Expression(BaseModel): last_active_time = FloatField() chat_id = TextField(index=True) type = TextField() - - class Meta: + + class Meta: # type: ignore table_name = "expression" + class ThinkingLog(BaseModel): chat_id = TextField(index=True) trigger_text = TextField(null=True) @@ -326,7 +331,7 @@ class ThinkingLog(BaseModel): # And: import datetime created_at = DateTimeField(default=datetime.datetime.now) - class Meta: + class Meta: # type: ignore table_name = "thinking_logs" @@ -341,7 +346,7 @@ class GraphNodes(BaseModel): created_time = FloatField() # 创建时间戳 last_modified = FloatField() # 最后修改时间戳 - class Meta: + class Meta: # type: ignore table_name = "graph_nodes" @@ -357,7 +362,7 @@ class GraphEdges(BaseModel): created_time = FloatField() # 创建时间戳 last_modified = FloatField() # 最后修改时间戳 - class Meta: + class Meta: # type: ignore table_name = "graph_edges" diff --git a/src/config/auto_update.py b/src/config/auto_update.py index 355ebc55a..8d097ec49 100644 --- a/src/config/auto_update.py +++ b/src/config/auto_update.py @@ -7,13 +7,13 @@ from datetime import datetime def get_key_comment(toml_table, key): # 获取key的注释(如果有) - if hasattr(toml_table, 'trivia') and hasattr(toml_table.trivia, 'comment'): + if hasattr(toml_table, "trivia") and hasattr(toml_table.trivia, "comment"): return toml_table.trivia.comment - if hasattr(toml_table, 'value') and isinstance(toml_table.value, dict): + if hasattr(toml_table, "value") and isinstance(toml_table.value, dict): item = toml_table.value.get(key) - if item is not None and hasattr(item, 'trivia'): + if item is not None and hasattr(item, "trivia"): return item.trivia.comment - if hasattr(toml_table, 'keys'): + if hasattr(toml_table, "keys"): for k in toml_table.keys(): if isinstance(k, KeyType) and k.key == key: return k.trivia.comment @@ -36,16 +36,16 @@ def compare_dicts(new, old, path=None, new_comments=None, old_comments=None, log continue if key not in old: comment = get_key_comment(new, key) - logs.append(f"新增: {'.'.join(path+[str(key)])} 注释: {comment if comment else '无'}") + logs.append(f"新增: {'.'.join(path + [str(key)])} 注释: {comment if comment else '无'}") elif isinstance(new[key], (dict, Table)) and isinstance(old.get(key), (dict, Table)): - compare_dicts(new[key], old[key], path+[str(key)], new_comments, old_comments, logs) + compare_dicts(new[key], old[key], path + [str(key)], new_comments, old_comments, logs) # 删减项 for key in old: if key == "version": continue if key not in new: comment = get_key_comment(old, key) - logs.append(f"删减: {'.'.join(path+[str(key)])} 注释: {comment if comment else '无'}") + logs.append(f"删减: {'.'.join(path + [str(key)])} 注释: {comment if comment else '无'}") return logs @@ -95,7 +95,7 @@ def update_config(): if old_version and new_version and old_version == new_version: print(f"检测到版本号相同 (v{old_version}),跳过更新") # 如果version相同,恢复旧配置文件并返回 - shutil.move(old_backup_path, old_config_path) + shutil.move(old_backup_path, old_config_path) # type: ignore return else: print(f"检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}") diff --git a/src/config/config.py b/src/config/config.py index ed433dfd1..fcbde9871 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -53,13 +53,13 @@ MMC_VERSION = "0.9.0-snapshot.2" def get_key_comment(toml_table, key): # 获取key的注释(如果有) - if hasattr(toml_table, 'trivia') and hasattr(toml_table.trivia, 'comment'): + if hasattr(toml_table, "trivia") and hasattr(toml_table.trivia, "comment"): return toml_table.trivia.comment - if hasattr(toml_table, 'value') and isinstance(toml_table.value, dict): + if hasattr(toml_table, "value") and isinstance(toml_table.value, dict): item = toml_table.value.get(key) - if item is not None and hasattr(item, 'trivia'): + if item is not None and hasattr(item, "trivia"): return item.trivia.comment - if hasattr(toml_table, 'keys'): + if hasattr(toml_table, "keys"): for k in toml_table.keys(): if isinstance(k, KeyType) and k.key == key: return k.trivia.comment @@ -78,16 +78,16 @@ def compare_dicts(new, old, path=None, logs=None): continue if key not in old: comment = get_key_comment(new, key) - logs.append(f"新增: {'.'.join(path+[str(key)])} 注释: {comment if comment else '无'}") + logs.append(f"新增: {'.'.join(path + [str(key)])} 注释: {comment if comment else '无'}") elif isinstance(new[key], (dict, Table)) and isinstance(old.get(key), (dict, Table)): - compare_dicts(new[key], old[key], path+[str(key)], logs) + compare_dicts(new[key], old[key], path + [str(key)], logs) # 删减项 for key in old: if key == "version": continue if key not in new: comment = get_key_comment(old, key) - logs.append(f"删减: {'.'.join(path+[str(key)])} 注释: {comment if comment else '无'}") + logs.append(f"删减: {'.'.join(path + [str(key)])} 注释: {comment if comment else '无'}") return logs @@ -99,6 +99,7 @@ def get_value_by_path(d, path): return None return d + def set_value_by_path(d, path, value): for k in path[:-1]: if k not in d or not isinstance(d[k], dict): @@ -106,6 +107,7 @@ def set_value_by_path(d, path, value): d = d[k] d[path[-1]] = value + def compare_default_values(new, old, path=None, logs=None, changes=None): # 递归比较两个dict,找出默认值变化项 if path is None: @@ -119,12 +121,14 @@ def compare_default_values(new, old, path=None, logs=None, changes=None): continue if key in old: if isinstance(new[key], (dict, Table)) and isinstance(old[key], (dict, Table)): - compare_default_values(new[key], old[key], path+[str(key)], logs, changes) + compare_default_values(new[key], old[key], path + [str(key)], logs, changes) else: # 只要值发生变化就记录 if new[key] != old[key]: - logs.append(f"默认值变化: {'.'.join(path+[str(key)])} 旧默认值: {old[key]} 新默认值: {new[key]}") - changes.append((path+[str(key)], old[key], new[key])) + logs.append( + f"默认值变化: {'.'.join(path + [str(key)])} 旧默认值: {old[key]} 新默认值: {new[key]}" + ) + changes.append((path + [str(key)], old[key], new[key])) return logs, changes @@ -148,8 +152,8 @@ def update_config(): return None with open(toml_path, "r", encoding="utf-8") as f: doc = tomlkit.load(f) - if "inner" in doc and "version" in doc["inner"]: - return doc["inner"]["version"] + if "inner" in doc and "version" in doc["inner"]: # type: ignore + return doc["inner"]["version"] # type: ignore return None template_version = get_version_from_toml(template_path) @@ -186,7 +190,9 @@ def update_config(): old_value = get_value_by_path(old_config, path) if old_value == old_default: set_value_by_path(old_config, path, new_default) - logger.info(f"已自动将配置 {'.'.join(path)} 的值从旧默认值 {old_default} 更新为新默认值 {new_default}") + logger.info( + f"已自动将配置 {'.'.join(path)} 的值从旧默认值 {old_default} 更新为新默认值 {new_default}" + ) else: logger.info("未检测到模板默认值变动") # 保存旧配置的变更(后续合并逻辑会用到 old_config) @@ -229,7 +235,9 @@ def update_config(): logger.info(f"检测到配置文件版本号相同 (v{old_version}),跳过更新") return else: - logger.info(f"\n----------------------------------------\n检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}\n----------------------------------------") + logger.info( + f"\n----------------------------------------\n检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}\n----------------------------------------" + ) else: logger.info("已有配置文件未检测到版本号,可能是旧版本。将进行更新") @@ -321,6 +329,7 @@ class Config(ConfigBase): debug: DebugConfig custom_prompt: CustomPromptConfig + def load_config(config_path: str) -> Config: """ 加载配置文件 diff --git a/src/individuality/not_using/offline_llm.py b/src/individuality/not_using/offline_llm.py index 83cb263c7..2bafb69aa 100644 --- a/src/individuality/not_using/offline_llm.py +++ b/src/individuality/not_using/offline_llm.py @@ -39,7 +39,7 @@ class LLMRequestOff: } # 发送请求到完整的 chat/completions 端点 - api_url = f"{self.base_url.rstrip('/')}/chat/completions" + api_url = f"{self.base_url.rstrip('/')}/chat/completions" # type: ignore logger.info(f"Request URL: {api_url}") # 记录请求的 URL max_retries = 3 @@ -89,7 +89,7 @@ class LLMRequestOff: } # 发送请求到完整的 chat/completions 端点 - api_url = f"{self.base_url.rstrip('/')}/chat/completions" + api_url = f"{self.base_url.rstrip('/')}/chat/completions" # type: ignore logger.info(f"Request URL: {api_url}") # 记录请求的 URL max_retries = 3 diff --git a/src/individuality/not_using/per_bf_gen.py b/src/individuality/not_using/per_bf_gen.py index 3b66d0551..aedbe00ee 100644 --- a/src/individuality/not_using/per_bf_gen.py +++ b/src/individuality/not_using/per_bf_gen.py @@ -83,8 +83,8 @@ class PersonalityEvaluatorDirect: def __init__(self): self.personality_traits = {"开放性": 0, "严谨性": 0, "外向性": 0, "宜人性": 0, "神经质": 0} self.scenarios = [] - self.final_scores = {"开放性": 0, "严谨性": 0, "外向性": 0, "宜人性": 0, "神经质": 0} - self.dimension_counts = {trait: 0 for trait in self.final_scores.keys()} + self.final_scores: Dict[str, float] = {"开放性": 0, "严谨性": 0, "外向性": 0, "宜人性": 0, "神经质": 0} + self.dimension_counts = {trait: 0 for trait in self.final_scores} # 为每个人格特质获取对应的场景 for trait in PERSONALITY_SCENES: @@ -119,8 +119,7 @@ class PersonalityEvaluatorDirect: # 构建维度描述 dimension_descriptions = [] for dim in dimensions: - desc = FACTOR_DESCRIPTIONS.get(dim, "") - if desc: + if desc := FACTOR_DESCRIPTIONS.get(dim, ""): dimension_descriptions.append(f"- {dim}:{desc}") dimensions_text = "\n".join(dimension_descriptions) diff --git a/src/main.py b/src/main.py index 3dc8c4c9a..dbd12f1a4 100644 --- a/src/main.py +++ b/src/main.py @@ -153,14 +153,14 @@ class MainSystem: while True: await asyncio.sleep(global_config.memory.memory_build_interval) logger.info("正在进行记忆构建") - await self.hippocampus_manager.build_memory() + await self.hippocampus_manager.build_memory() # type: ignore async def forget_memory_task(self): """记忆遗忘任务""" while True: await asyncio.sleep(global_config.memory.forget_memory_interval) logger.info("[记忆遗忘] 开始遗忘记忆...") - await self.hippocampus_manager.forget_memory(percentage=global_config.memory.memory_forget_percentage) + await self.hippocampus_manager.forget_memory(percentage=global_config.memory.memory_forget_percentage) # type: ignore logger.info("[记忆遗忘] 记忆遗忘完成") async def consolidate_memory_task(self): @@ -168,7 +168,7 @@ class MainSystem: while True: await asyncio.sleep(global_config.memory.consolidate_memory_interval) logger.info("[记忆整合] 开始整合记忆...") - await self.hippocampus_manager.consolidate_memory() + await self.hippocampus_manager.consolidate_memory() # type: ignore logger.info("[记忆整合] 记忆整合完成") @staticmethod diff --git a/src/mood/mood_manager.py b/src/mood/mood_manager.py index b47785401..398b1f372 100644 --- a/src/mood/mood_manager.py +++ b/src/mood/mood_manager.py @@ -49,6 +49,9 @@ class ChatMood: chat_manager = get_chat_manager() self.chat_stream = chat_manager.get_stream(self.chat_id) + + if not self.chat_stream: + raise ValueError(f"Chat stream for chat_id {chat_id} not found") self.log_prefix = f"[{self.chat_stream.group_info.group_name if self.chat_stream.group_info else self.chat_stream.user_info.user_nickname}]" diff --git a/src/person_info/relationship_builder.py b/src/person_info/relationship_builder.py index a489a34d5..69b9e84d2 100644 --- a/src/person_info/relationship_builder.py +++ b/src/person_info/relationship_builder.py @@ -26,7 +26,7 @@ SEGMENT_CLEANUP_CONFIG = { "cleanup_interval_hours": 0.5, # 清理间隔(小时) } -MAX_MESSAGE_COUNT = 80 / global_config.relationship.relation_frequency +MAX_MESSAGE_COUNT = int(80 / global_config.relationship.relation_frequency) class RelationshipBuilder: diff --git a/src/plugin_system/__init__.py b/src/plugin_system/__init__.py index b8701839d..59e240811 100644 --- a/src/plugin_system/__init__.py +++ b/src/plugin_system/__init__.py @@ -61,7 +61,7 @@ __all__ = [ "ConfigField", # 工具函数 "ManifestValidator", - "ManifestGenerator", - "validate_plugin_manifest", - "generate_plugin_manifest", + # "ManifestGenerator", + # "validate_plugin_manifest", + # "generate_plugin_manifest", ] diff --git a/src/plugin_system/apis/send_api.py b/src/plugin_system/apis/send_api.py index 97bee9908..c8b03a0a6 100644 --- a/src/plugin_system/apis/send_api.py +++ b/src/plugin_system/apis/send_api.py @@ -111,7 +111,7 @@ async def _send_to_target( is_head=True, is_emoji=(message_type == "emoji"), thinking_start_time=current_time, - reply_to = reply_to_platform_id + reply_to=reply_to_platform_id, ) # 发送消息 @@ -137,6 +137,7 @@ async def _send_to_target( async def _find_reply_message(target_stream, reply_to: str) -> Optional[MessageRecv]: + # sourcery skip: inline-variable, use-named-expression """查找要回复的消息 Args: @@ -184,14 +185,11 @@ async def _find_reply_message(target_stream, reply_to: str) -> Optional[MessageR # 检查是否有 回复 字段 reply_pattern = r"回复<([^:<>]+):([^:<>]+)>" - match = re.search(reply_pattern, translate_text) - if match: + if match := re.search(reply_pattern, translate_text): aaa = match.group(1) bbb = match.group(2) reply_person_id = get_person_info_manager().get_person_id(platform, bbb) - reply_person_name = await get_person_info_manager().get_value(reply_person_id, "person_name") - if not reply_person_name: - reply_person_name = aaa + reply_person_name = await get_person_info_manager().get_value(reply_person_id, "person_name") or aaa # 在内容前加上回复信息 translate_text = re.sub(reply_pattern, f"回复 {reply_person_name}", translate_text, count=1) @@ -206,9 +204,7 @@ async def _find_reply_message(target_stream, reply_to: str) -> Optional[MessageR aaa = m.group(1) bbb = m.group(2) at_person_id = get_person_info_manager().get_person_id(platform, bbb) - at_person_name = await get_person_info_manager().get_value(at_person_id, "person_name") - if not at_person_name: - at_person_name = aaa + at_person_name = await get_person_info_manager().get_value(at_person_id, "person_name") or aaa new_content += f"@{at_person_name}" last_end = m.end() new_content += translate_text[last_end:] @@ -370,7 +366,14 @@ async def custom_to_stream( bool: 是否发送成功 """ return await _send_to_target( - message_type, content, stream_id, display_message, typing, reply_to, storage_message, show_log + message_type, + content, + stream_id, + display_message, + typing, + reply_to, + storage_message=storage_message, + show_log=show_log, ) @@ -396,7 +399,7 @@ async def text_to_group( """ stream_id = get_chat_manager().get_stream_id(platform, group_id, True) - return await _send_to_target("text", text, stream_id, "", typing, reply_to, storage_message) + return await _send_to_target("text", text, stream_id, "", typing, reply_to, storage_message=storage_message) async def text_to_user( @@ -420,7 +423,7 @@ async def text_to_user( bool: 是否发送成功 """ stream_id = get_chat_manager().get_stream_id(platform, user_id, False) - return await _send_to_target("text", text, stream_id, "", typing, reply_to, storage_message) + return await _send_to_target("text", text, stream_id, "", typing, reply_to, storage_message=storage_message) async def emoji_to_group(emoji_base64: str, group_id: str, platform: str = "qq", storage_message: bool = True) -> bool: @@ -543,7 +546,9 @@ async def custom_to_group( bool: 是否发送成功 """ stream_id = get_chat_manager().get_stream_id(platform, group_id, True) - return await _send_to_target(message_type, content, stream_id, display_message, typing, reply_to, storage_message) + return await _send_to_target( + message_type, content, stream_id, display_message, typing, reply_to, storage_message=storage_message + ) async def custom_to_user( @@ -571,7 +576,9 @@ async def custom_to_user( bool: 是否发送成功 """ stream_id = get_chat_manager().get_stream_id(platform, user_id, False) - return await _send_to_target(message_type, content, stream_id, display_message, typing, reply_to, storage_message) + return await _send_to_target( + message_type, content, stream_id, display_message, typing, reply_to, storage_message=storage_message + ) async def custom_message( @@ -611,4 +618,6 @@ async def custom_message( await send_api.custom_message("audio", audio_base64, "123456", True, reply_to="张三:你好") """ stream_id = get_chat_manager().get_stream_id(platform, target_id, is_group) - return await _send_to_target(message_type, content, stream_id, display_message, typing, reply_to, storage_message) + return await _send_to_target( + message_type, content, stream_id, display_message, typing, reply_to, storage_message=storage_message + ) diff --git a/src/plugin_system/base/base_action.py b/src/plugin_system/base/base_action.py index 2c559a2c7..74ab22e67 100644 --- a/src/plugin_system/base/base_action.py +++ b/src/plugin_system/base/base_action.py @@ -38,7 +38,7 @@ class BaseAction(ABC): chat_stream: ChatStream, log_prefix: str = "", plugin_config: Optional[dict] = None, - action_message: dict = None, + action_message: Optional[dict] = None, **kwargs, ): """初始化Action组件 @@ -63,7 +63,7 @@ class BaseAction(ABC): self.cycle_timers = cycle_timers self.thinking_id = thinking_id self.log_prefix = log_prefix - + # 保存插件配置 self.plugin_config = plugin_config or {} @@ -92,10 +92,10 @@ class BaseAction(ABC): self.chat_stream = chat_stream or kwargs.get("chat_stream") self.chat_id = self.chat_stream.stream_id self.platform = getattr(self.chat_stream, "platform", None) - + # 初始化基础信息(带类型注解) self.action_message = action_message - + self.group_id = None self.group_name = None self.user_id = None @@ -103,15 +103,17 @@ class BaseAction(ABC): self.is_group = False self.target_id = None self.has_action_message = False - + if self.action_message: self.has_action_message = True - + else: + self.action_message = {} + if self.has_action_message: if self.action_name != "no_reply": self.group_id = str(self.action_message.get("chat_info_group_id", None)) self.group_name = self.action_message.get("chat_info_group_name", None) - + self.user_id = str(self.action_message.get("user_id", None)) self.user_nickname = self.action_message.get("user_nickname", None) if self.group_id: @@ -132,8 +134,6 @@ class BaseAction(ABC): self.is_group = False self.target_id = self.user_id - - logger.debug(f"{self.log_prefix} Action组件初始化完成") logger.info( f"{self.log_prefix} 聊天信息: 类型={'群聊' if self.is_group else '私聊'}, 平台={self.platform}, 目标={self.target_id}" @@ -199,7 +199,9 @@ class BaseAction(ABC): logger.error(f"{self.log_prefix} 等待新消息时发生错误: {e}") return False, f"等待新消息失败: {str(e)}" - async def send_text(self, content: str, reply_to: str = "", reply_to_platform_id: str = "", typing: bool = False) -> bool: + async def send_text( + self, content: str, reply_to: str = "", reply_to_platform_id: str = "", typing: bool = False + ) -> bool: """发送文本消息 Args: @@ -299,7 +301,7 @@ class BaseAction(ABC): ) async def send_command( - self, command_name: str, args: dict = None, display_message: str = None, storage_message: bool = True + self, command_name: str, args: Optional[dict] = None, display_message: str = "", storage_message: bool = True ) -> bool: """发送命令消息 diff --git a/src/plugin_system/base/base_command.py b/src/plugin_system/base/base_command.py index 2c2ddf81e..caf68567b 100644 --- a/src/plugin_system/base/base_command.py +++ b/src/plugin_system/base/base_command.py @@ -135,7 +135,7 @@ class BaseCommand(ABC): ) async def send_command( - self, command_name: str, args: dict = None, display_message: str = "", storage_message: bool = True + self, command_name: str, args: Optional[dict] = None, display_message: str = "", storage_message: bool = True ) -> bool: """发送命令消息 diff --git a/src/plugin_system/core/component_registry.py b/src/plugin_system/core/component_registry.py index b152a1abc..917069e11 100644 --- a/src/plugin_system/core/component_registry.py +++ b/src/plugin_system/core/component_registry.py @@ -346,67 +346,67 @@ class ComponentRegistry: # === 状态管理方法 === - def enable_component(self, component_name: str, component_type: ComponentType = None) -> bool: - # -------------------------------- NEED REFACTORING -------------------------------- - # -------------------------------- LOGIC ERROR ------------------------------------- - """启用组件,支持命名空间解析""" - # 首先尝试找到正确的命名空间化名称 - component_info = self.get_component_info(component_name, component_type) - if not component_info: - return False + # def enable_component(self, component_name: str, component_type: ComponentType = None) -> bool: + # # -------------------------------- NEED REFACTORING -------------------------------- + # # -------------------------------- LOGIC ERROR ------------------------------------- + # """启用组件,支持命名空间解析""" + # # 首先尝试找到正确的命名空间化名称 + # component_info = self.get_component_info(component_name, component_type) + # if not component_info: + # return False - # 根据组件类型构造正确的命名空间化名称 - if component_info.component_type == ComponentType.ACTION: - namespaced_name = f"action.{component_name}" if "." not in component_name else component_name - elif component_info.component_type == ComponentType.COMMAND: - namespaced_name = f"command.{component_name}" if "." not in component_name else component_name - else: - namespaced_name = ( - f"{component_info.component_type.value}.{component_name}" - if "." not in component_name - else component_name - ) + # # 根据组件类型构造正确的命名空间化名称 + # if component_info.component_type == ComponentType.ACTION: + # namespaced_name = f"action.{component_name}" if "." not in component_name else component_name + # elif component_info.component_type == ComponentType.COMMAND: + # namespaced_name = f"command.{component_name}" if "." not in component_name else component_name + # else: + # namespaced_name = ( + # f"{component_info.component_type.value}.{component_name}" + # if "." not in component_name + # else component_name + # ) - if namespaced_name in self._components: - self._components[namespaced_name].enabled = True - # 如果是Action,更新默认动作集 - # ---- HERE ---- - # if isinstance(component_info, ActionInfo): - # self._action_descriptions[component_name] = component_info.description - logger.debug(f"已启用组件: {component_name} -> {namespaced_name}") - return True - return False + # if namespaced_name in self._components: + # self._components[namespaced_name].enabled = True + # # 如果是Action,更新默认动作集 + # # ---- HERE ---- + # # if isinstance(component_info, ActionInfo): + # # self._action_descriptions[component_name] = component_info.description + # logger.debug(f"已启用组件: {component_name} -> {namespaced_name}") + # return True + # return False - def disable_component(self, component_name: str, component_type: ComponentType = None) -> bool: - # -------------------------------- NEED REFACTORING -------------------------------- - # -------------------------------- LOGIC ERROR ------------------------------------- - """禁用组件,支持命名空间解析""" - # 首先尝试找到正确的命名空间化名称 - component_info = self.get_component_info(component_name, component_type) - if not component_info: - return False + # def disable_component(self, component_name: str, component_type: ComponentType = None) -> bool: + # # -------------------------------- NEED REFACTORING -------------------------------- + # # -------------------------------- LOGIC ERROR ------------------------------------- + # """禁用组件,支持命名空间解析""" + # # 首先尝试找到正确的命名空间化名称 + # component_info = self.get_component_info(component_name, component_type) + # if not component_info: + # return False - # 根据组件类型构造正确的命名空间化名称 - if component_info.component_type == ComponentType.ACTION: - namespaced_name = f"action.{component_name}" if "." not in component_name else component_name - elif component_info.component_type == ComponentType.COMMAND: - namespaced_name = f"command.{component_name}" if "." not in component_name else component_name - else: - namespaced_name = ( - f"{component_info.component_type.value}.{component_name}" - if "." not in component_name - else component_name - ) + # # 根据组件类型构造正确的命名空间化名称 + # if component_info.component_type == ComponentType.ACTION: + # namespaced_name = f"action.{component_name}" if "." not in component_name else component_name + # elif component_info.component_type == ComponentType.COMMAND: + # namespaced_name = f"command.{component_name}" if "." not in component_name else component_name + # else: + # namespaced_name = ( + # f"{component_info.component_type.value}.{component_name}" + # if "." not in component_name + # else component_name + # ) - if namespaced_name in self._components: - self._components[namespaced_name].enabled = False - # 如果是Action,从默认动作集中移除 - # ---- HERE ---- - # if component_name in self._action_descriptions: - # del self._action_descriptions[component_name] - logger.debug(f"已禁用组件: {component_name} -> {namespaced_name}") - return True - return False + # if namespaced_name in self._components: + # self._components[namespaced_name].enabled = False + # # 如果是Action,从默认动作集中移除 + # # ---- HERE ---- + # # if component_name in self._action_descriptions: + # # del self._action_descriptions[component_name] + # logger.debug(f"已禁用组件: {component_name} -> {namespaced_name}") + # return True + # return False def get_registry_stats(self) -> Dict[str, Any]: """获取注册中心统计信息""" diff --git a/src/plugin_system/core/dependency_manager.py b/src/plugin_system/core/dependency_manager.py index 4a995e028..266254e72 100644 --- a/src/plugin_system/core/dependency_manager.py +++ b/src/plugin_system/core/dependency_manager.py @@ -7,7 +7,7 @@ import subprocess import sys import importlib -from typing import List, Dict, Tuple +from typing import List, Dict, Tuple, Any from src.common.logger import get_logger from src.plugin_system.base.component_types import PythonDependency @@ -176,7 +176,7 @@ class DependencyManager: logger.error(f"生成requirements文件失败: {str(e)}") return False - def get_install_summary(self) -> Dict[str, any]: + def get_install_summary(self) -> Dict[str, Any]: """获取安装摘要""" return { "install_log": self.install_log.copy(), diff --git a/src/plugin_system/core/plugin_manager.py b/src/plugin_system/core/plugin_manager.py index cff28cb99..b4050794f 100644 --- a/src/plugin_system/core/plugin_manager.py +++ b/src/plugin_system/core/plugin_manager.py @@ -197,29 +197,29 @@ class PluginManager: """获取所有启用的插件信息""" return list(component_registry.get_enabled_plugins().values()) - def enable_plugin(self, plugin_name: str) -> bool: - # -------------------------------- NEED REFACTORING -------------------------------- - """启用插件""" - if plugin_info := component_registry.get_plugin_info(plugin_name): - plugin_info.enabled = True - # 启用插件的所有组件 - for component in plugin_info.components: - component_registry.enable_component(component.name) - logger.debug(f"已启用插件: {plugin_name}") - return True - return False + # def enable_plugin(self, plugin_name: str) -> bool: + # # -------------------------------- NEED REFACTORING -------------------------------- + # """启用插件""" + # if plugin_info := component_registry.get_plugin_info(plugin_name): + # plugin_info.enabled = True + # # 启用插件的所有组件 + # for component in plugin_info.components: + # component_registry.enable_component(component.name) + # logger.debug(f"已启用插件: {plugin_name}") + # return True + # return False - def disable_plugin(self, plugin_name: str) -> bool: - # -------------------------------- NEED REFACTORING -------------------------------- - """禁用插件""" - if plugin_info := component_registry.get_plugin_info(plugin_name): - plugin_info.enabled = False - # 禁用插件的所有组件 - for component in plugin_info.components: - component_registry.disable_component(component.name) - logger.debug(f"已禁用插件: {plugin_name}") - return True - return False + # def disable_plugin(self, plugin_name: str) -> bool: + # # -------------------------------- NEED REFACTORING -------------------------------- + # """禁用插件""" + # if plugin_info := component_registry.get_plugin_info(plugin_name): + # plugin_info.enabled = False + # # 禁用插件的所有组件 + # for component in plugin_info.components: + # component_registry.disable_component(component.name) + # logger.debug(f"已禁用插件: {plugin_name}") + # return True + # return False def get_plugin_instance(self, plugin_name: str) -> Optional["PluginBase"]: """获取插件实例 diff --git a/src/tools/tool_can_use/compare_numbers_tool.py b/src/tools/tool_can_use/compare_numbers_tool.py index e73f6e79f..2930f8f4b 100644 --- a/src/tools/tool_can_use/compare_numbers_tool.py +++ b/src/tools/tool_can_use/compare_numbers_tool.py @@ -28,10 +28,10 @@ class CompareNumbersTool(BaseTool): Returns: dict: 工具执行结果 """ - try: - num1 = function_args.get("num1") - num2 = function_args.get("num2") + num1: int | float = function_args.get("num1") # type: ignore + num2: int | float = function_args.get("num2") # type: ignore + try: if num1 > num2: result = f"{num1} 大于 {num2}" elif num1 < num2: diff --git a/src/tools/tool_can_use/rename_person_tool.py b/src/tools/tool_can_use/rename_person_tool.py index 0651e0c2c..cfc6ef4b0 100644 --- a/src/tools/tool_can_use/rename_person_tool.py +++ b/src/tools/tool_can_use/rename_person_tool.py @@ -68,10 +68,10 @@ class RenamePersonTool(BaseTool): ) result = await person_info_manager.qv_person_name( person_id=person_id, - user_nickname=user_nickname, - user_cardname=user_cardname, - user_avatar=user_avatar, - request=request_context, + user_nickname=user_nickname, # type: ignore + user_cardname=user_cardname, # type: ignore + user_avatar=user_avatar, # type: ignore + request=request_context, # type: ignore ) # 3. 处理结果 From a83f8948e9f1edc60034276bfabc9f3731ec8fac Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Thu, 17 Jul 2025 00:15:58 +0800 Subject: [PATCH 4/5] =?UTF-8?q?=E5=9B=9E=E9=80=80utils=5Fmodel.py=E4=B8=AD?= =?UTF-8?q?=E7=9A=84=E6=9B=B4=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/llm_models/utils_model.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index b9a419c33..1077cfa09 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -255,11 +255,12 @@ class LLMRequest: if self.temp != 0.7: payload["temperature"] = self.temp - # 添加enable_thinking参数(仅在启用时添加) - if self.enable_thinking: - payload["enable_thinking"] = True - if self.thinking_budget != 4096: - payload["thinking_budget"] = self.thinking_budget + # 添加enable_thinking参数(如果不是默认值False) + if not self.enable_thinking: + payload["enable_thinking"] = False + + if self.thinking_budget != 4096: + payload["thinking_budget"] = self.thinking_budget if self.max_tokens: payload["max_tokens"] = self.max_tokens @@ -669,11 +670,12 @@ class LLMRequest: if self.temp != 0.7: payload["temperature"] = self.temp - # 添加enable_thinking参数(仅在启用时添加) - if self.enable_thinking: - payload["enable_thinking"] = True - if self.thinking_budget != 4096: - payload["thinking_budget"] = self.thinking_budget + # 添加enable_thinking参数(如果不是默认值False) + if not self.enable_thinking: + payload["enable_thinking"] = False + + if self.thinking_budget != 4096: + payload["thinking_budget"] = self.thinking_budget if self.max_tokens: payload["max_tokens"] = self.max_tokens From 696325cb576dfa3bb3ac8cfc7dfbc5a45fa64340 Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Thu, 17 Jul 2025 00:28:14 +0800 Subject: [PATCH 5/5] =?UTF-8?q?=E7=BB=A7=E6=89=BF=E6=8F=92=E4=BB=B6?= =?UTF-8?q?=E6=80=BB=E5=9F=BA=E7=B1=BB=EF=BC=8C=E6=B3=A8=E9=87=8A=E6=9B=B4?= =?UTF-8?q?=E6=96=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- changes.md | 5 +++- src/plugin_system/apis/plugin_register_api.py | 3 -- src/plugin_system/base/base_event_plugin.py | 28 ++++++++----------- src/plugin_system/base/base_plugin.py | 13 +++++++-- src/plugin_system/base/plugin_base.py | 10 ++----- 5 files changed, 30 insertions(+), 29 deletions(-) diff --git a/changes.md b/changes.md index 4986e4d60..1f53d7e50 100644 --- a/changes.md +++ b/changes.md @@ -20,4 +20,7 @@ - `chat_api.py`中获取流的参数中可以使用一个特殊的枚举类型来获得所有平台的 ChatStream 了。 - `config_api.py`中的`get_global_config`和`get_plugin_config`方法现在支持嵌套访问的配置键名。 - `database_api.py`中的`db_query`方法调整了参数顺序以增强参数限制的同时,保证了typing正确;`db_get`方法增加了`single_result`参数,与`db_query`保持一致。 -4. 现在增加了参数类型检查,完善了对应注释 \ No newline at end of file +4. 现在增加了参数类型检查,完善了对应注释 +5. 现在插件抽象出了总基类 `PluginBase` + - 基于`Action`和`Command`的插件基类现在为`BasePlugin`,它继承自`PluginBase`,由`register_plugin`装饰器注册。 + - 基于`Event`的插件基类现在为`BaseEventPlugin`,它也继承自`PluginBase`,由`register_event_plugin`装饰器注册。 \ No newline at end of file diff --git a/src/plugin_system/apis/plugin_register_api.py b/src/plugin_system/apis/plugin_register_api.py index b3cc58450..7970f3421 100644 --- a/src/plugin_system/apis/plugin_register_api.py +++ b/src/plugin_system/apis/plugin_register_api.py @@ -34,7 +34,4 @@ def register_event_plugin(cls, *args, **kwargs): 用法: @register_event_plugin - class MyEventPlugin: - event_type = EventType.MESSAGE_RECEIVED - ... """ \ No newline at end of file diff --git a/src/plugin_system/base/base_event_plugin.py b/src/plugin_system/base/base_event_plugin.py index 2261fee26..859d43f06 100644 --- a/src/plugin_system/base/base_event_plugin.py +++ b/src/plugin_system/base/base_event_plugin.py @@ -1,18 +1,14 @@ -from abc import ABC, abstractmethod +from abc import abstractmethod -class BaseEventsPlugin(ABC): +from .plugin_base import PluginBase +from src.common.logger import get_logger + + +class BaseEventPlugin(PluginBase): + """基于事件的插件基类 + + 所有事件类型的插件都应该继承这个基类 """ - 事件触发型插件基类 - - 所有事件触发型插件都应该继承这个基类而不是 BasePlugin - """ - - @property - @abstractmethod - def plugin_name(self) -> str: - return "" # 插件内部标识符(如 "hello_world_plugin") - - @property - @abstractmethod - def enable_plugin(self) -> bool: - return False \ No newline at end of file + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) diff --git a/src/plugin_system/base/base_plugin.py b/src/plugin_system/base/base_plugin.py index fe79d8e9a..a93de5fab 100644 --- a/src/plugin_system/base/base_plugin.py +++ b/src/plugin_system/base/base_plugin.py @@ -7,10 +7,19 @@ from src.plugin_system.base.component_types import ComponentInfo logger = get_logger("base_plugin") + class BasePlugin(PluginBase): + """基于Action和Command的插件基类 + + 所有上述类型的插件都应该继承这个基类,一个插件可以包含多种组件: + - Action组件:处理聊天中的动作 + - Command组件:处理命令请求 + - 未来可扩展:Scheduler、Listener等 + """ + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - + @abstractmethod def get_plugin_components(self) -> List[tuple[ComponentInfo, Type]]: """获取插件包含的组件列表 @@ -21,7 +30,7 @@ class BasePlugin(PluginBase): List[tuple[ComponentInfo, Type]]: [(组件信息, 组件类), ...] """ raise NotImplementedError("Subclasses must implement this method") - + def register_plugin(self) -> bool: """注册插件及其所有组件""" from src.plugin_system.core.component_registry import component_registry diff --git a/src/plugin_system/base/plugin_base.py b/src/plugin_system/base/plugin_base.py index ceb8dcb61..0b7f15d17 100644 --- a/src/plugin_system/base/plugin_base.py +++ b/src/plugin_system/base/plugin_base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Dict, List, Type, Any, Union +from typing import Dict, List, Any, Union import os import inspect import toml @@ -10,7 +10,6 @@ import datetime from src.common.logger import get_logger from src.plugin_system.base.component_types import ( PluginInfo, - ComponentInfo, PythonDependency, ) from src.plugin_system.base.config_types import ConfigField @@ -20,12 +19,9 @@ logger = get_logger("plugin_base") class PluginBase(ABC): - """插件基类 + """插件总基类 - 所有插件都应该继承这个基类,一个插件可以包含多种组件: - - Action组件:处理聊天中的动作 - - Command组件:处理命令请求 - - 未来可扩展:Scheduler、Listener等 + 所有衍生插件基类都应该继承自此类,这个类定义了插件的基本结构和行为。 """ # 插件基本信息(子类必须定义)