34 .github/workflows/create-prerelease.yml vendored
@@ -1,34 +0,0 @@
-# Automatically create a pre-release whenever code is pushed to the master branch
-
-name: Create Pre-release
-
-on:
-  push:
-    branches:
-      - master
-
-jobs:
-  create-prerelease:
-    runs-on: ubuntu-latest
-    permissions:
-      contents: write
-    steps:
-      - name: Checkout code
-        uses: actions/checkout@v4
-        with:
-          # Fetch the full commit history so release notes can be generated
-          fetch-depth: 0
-
-      - name: Generate tag name
-        id: generate_tag
-        run: echo "TAG_NAME=MoFox-prerelease-$(date -u +'%Y%m%d%H%M%S')" >> $GITHUB_OUTPUT
-
-      - name: Create Pre-release
-        env:
-          # Authenticate with the repository's built-in GITHUB_TOKEN
-          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-        run: |
-          gh release create ${{ steps.generate_tag.outputs.TAG_NAME }} \
-            --title "Pre-release ${{ steps.generate_tag.outputs.TAG_NAME }}" \
-            --prerelease \
-            --generate-notes
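As a quick sanity check, the tag-naming scheme from the "Generate tag name" step above can be previewed locally with plain `date`; no GitHub context is needed:

```shell
# Preview the tag name the workflow would generate: a UTC timestamp
# (14 digits, YYYYMMDDHHMMSS) appended to the "MoFox-prerelease-" prefix.
TAG_NAME="MoFox-prerelease-$(date -u +'%Y%m%d%H%M%S')"
echo "$TAG_NAME"
```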
94 EULA.md
@@ -1,94 +0,0 @@
-# **Welcome to MoFox_Bot (Third-Party Modified Edition)!**
-
-**Version: V2.1**
-**Updated: August 30, 2025**
-
----
-
-Hello! Thank you for choosing MoFox_Bot. Before you begin, please take a few minutes to read this agreement. It is written in question-and-answer form to help you quickly understand your rights and responsibilities when using this **third-party modified** software.
-
-**In short, you must agree to this agreement in order to use the software.** If you are a minor, please make sure your guardian has also read and agreed to it.
-
----
-
-### **1. How is this software related to the original project?**
-
-This is a very important question!
-
-* **Third-party modification**: First, you need to clearly understand that MoFox_Bot is a **third-party modified edition** based on the [MaiCore](https://mai-mai.org/) open-source project. We (the MoFox_Bot team) have **no affiliation whatsoever** with the developers of the original project.
-* **Independent maintenance**: We are independently responsible for maintaining, developing, and updating this modified edition. The original project's developers therefore provide **no** technical support for MoFox_Bot and bear **no** responsibility for any problems arising from the use of this modified edition.
-* **Division of responsibility**: We are responsible only for the features we modified or added. We accept no responsibility for any issues that may exist in the original code.
-
----
-
-### **2. Is this software free and open source?**
-
-**Yes, the core code is open source!**
-
-* **Open-source license**: This project inherits the original project's **GPLv3 license**. This means you are free to use, copy, study, modify, and redistribute it.
-* **Your obligations**: When you modify or distribute this software, you must likewise comply with the GPLv3, ensuring that your derivative works are also open source. See the `LICENSE` file in the project root for details.
-* **Third-party code**: Note that the project may also include other third-party libraries or components, each with its own open-source license. You must comply with those licenses as well.
-
----
-
-### **3. How is my personal data handled?**
-
-Your privacy matters to us. Understanding how your data flows helps you protect yourself.
-
-* **How it works**: When you interact with the bot, your **input** (text, commands), your **configuration**, and the bot's **generated replies** are sent to third-party API services (for example, large-language-model providers such as OpenAI or Google) in order to obtain intelligent responses.
-* **Your explicit authorization**: By starting to use the software, you authorize us to use your data to:
-  1. **Call external APIs**: this is the core of how the bot converses with you.
-  2. **Build a local knowledge base and memory**: to make the bot more personal, the software creates and stores knowledge bases, memory stores, and conversation logs **on your own device**. **This data is stored locally; we cannot access it.**
-  3. **Record local logs**: the software keeps runtime logs on your device to help diagnose technical problems.
-* **Third-party service risk**: We cannot control the service quality, data-handling policies, stability, or security of third-party API providers. When using these services you are also bound by their terms of service and privacy policies; we recommend familiarizing yourself with them.
-
----
-
-### **4. What should I know about the plugin system?**
-
-MoFox_Bot extends its functionality through a plugin system, but this comes with responsibilities on your side.
-
-* **Who writes the plugins?**: Most plugins are created and maintained by **third-party community developers** who are **not members of the MoFox_Bot core team**.
-* **Responsibility lies entirely with the authors**: The functionality, quality, security, and legality of a plugin are **entirely the responsibility of its developer**. We merely provide a technical platform on which plugins can run and accept **no responsibility for the content, behavior, or consequences of any third-party plugin**.
-* **Use at your own risk**: You bear **all** risk of using any third-party plugin. Before installing and using one, we strongly recommend that you:
-  * Carefully read and understand the license and documentation provided by the plugin developer.
-  * **Only obtain and install plugins from sources you fully trust.**
-  * Assess for yourself the plugin's security, legality, and impact on your data privacy.
-
----
-
-### **5. What are the rules of conduct?**
-
-Please use this software lawfully and responsibly.
-
-* **Prohibited content**: Do not input, process, or distribute any content that violates the laws and regulations of your jurisdiction, including but not limited to state secrets, trade secrets, content infringing others' intellectual property or privacy, and any illegal, harassing, defamatory, or obscene material.
-* **Lawful use**: You agree not to use this project for any illegal purpose or activity, such as network attacks or fraud.
-* **Data responsibility**: You are fully responsible for the legality of everything stored in your local knowledge base, memory store, and logs.
-* **Plugin hygiene**: Do not use any plugin known to contain malicious code, security vulnerabilities, or illegal content.
-
-**You bear full legal responsibility for everything you do with this project (including all third-party plugins) and for all resulting consequences.**
-
----
-
-### **6. Disclaimer (very important!)**
-
-* **Provided "as is"**: This project is provided "as is", and we make **no warranties of any kind, express or implied**, including but not limited to warranties of merchantability, fitness for a particular purpose, and non-infringement.
-* **AI-generated content**: All of the bot's replies are generated by third-party large language models; their views and information **do not represent the position of the MoFox_Bot team**. We are not responsible for their accuracy, completeness, or reliability.
-* **No liability**: In no event shall the MoFox_Bot team be liable for any direct, indirect, incidental, special, or consequential damages (including but not limited to data loss, lost profits, or business interruption) arising from the use of, or inability to use, this project (especially third-party plugins).
-* **Plugin support**: For technical support, feature updates, and bug fixes of any third-party plugin, **contact the plugin's developer directly**.
-
----
-
-### **7. Miscellaneous**
-
-* **Changes to this agreement**: We reserve the right to modify this agreement at any time. A revised agreement takes effect when a new version is released. We recommend checking periodically for the latest version; continued use of the project constitutes acceptance of the revised agreement.
-* **Final interpretation**: To the extent permitted by law, the MoFox_Bot team reserves the right of final interpretation of this agreement.
-* **Governing law**: The formation, performance, and interpretation of this agreement, and the resolution of disputes under it, are governed by the laws of China.
-
----
-
-### **Risk notice (please confirm once more that you understand!)**
-
-* **Privacy risk**: Your conversation data is sent to third-party APIs beyond our control. **Never** include personally identifiable information, financial information, passwords, or other sensitive data in conversations.
-* **Mental-health risk**: An AI bot is only a program; it cannot provide genuine emotional support or professional psychological advice. If you experience any psychological distress, please seek help from a professional (for example, China's national psychological assistance hotline: 12355).
-* **Plugin risk**: This is one of the biggest risks. Third-party plugins can introduce serious security vulnerabilities, system instability, performance degradation, and even leakage of private data. Choose and use them with great care, and accept full responsibility for your choices.
4 bot.py
@@ -229,10 +229,10 @@ if __name__ == "__main__":
     asyncio.set_event_loop(loop)

     try:
-        # Run initialization and task scheduling
-        loop.run_until_complete(main_system.initialize())
         # Asynchronously initialize the database schema
         loop.run_until_complete(maibot.initialize_database_async())
+        # Run initialization and task scheduling
+        loop.run_until_complete(main_system.initialize())
         initialize_lpmm_knowledge()
         # Schedule tasks returns a future that runs forever.
         # We can run console_input_loop concurrently.
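The reordering above (database schema first, then system initialization) can be sketched with a toy event loop; the coroutine names below are illustrative stand-ins, not the real bot.py APIs:

```python
import asyncio

order = []

async def init_db():
    # Stand-in for maibot.initialize_database_async()
    order.append("db")

async def init_system():
    # Stand-in for main_system.initialize()
    order.append("system")

loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
    # Mirror the hunk above: create the schema before the system boots
    loop.run_until_complete(init_db())
    loop.run_until_complete(init_system())
finally:
    loop.close()
```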
@@ -10,7 +10,6 @@ from src.plugin_system import (
     ConfigField,
 )
 from src.common.logger import get_logger
-from src.plugin_system.apis import send_api
 from .qq_emoji_list import qq_face
 from src.plugin_system.base.component_types import ChatType

@@ -125,31 +124,25 @@ class SetEmojiLikeAction(BaseAction):

         try:
-            # Send the emoji-reaction command through the adapter API
-            response = await send_api.adapter_command_to_stream(
-                action="set_msg_emoji_like",
-                params={"message_id": message_id, "emoji_id": emoji_id, "set": set_like},
-                stream_id=self.chat_stream.stream_id if self.chat_stream else None,
-                timeout=30.0,
-                storage_message=False,
+            success = await self.send_command(
+                command_name="set_emoji_like", args={"message_id": message_id, "emoji_id": emoji_id, "set": set_like}, storage_message=False
             )

-            if response["status"] == "ok":
-                logger.info(f"Emoji reaction set successfully: {response}")
+            if success:
+                logger.info("Emoji reaction set successfully")
                 await self.store_action_info(
                     action_build_into_prompt=True,
                     action_prompt_display=f"Executed the set_emoji_like action, {emoji_input}, emoji reaction: {emoji_id}, set: {set_like}",
                     action_done=True,
                 )
-                return True, f"Emoji reaction set successfully: {response.get('message', 'ok')}"
+                return True, "Emoji reaction set successfully"
             else:
-                error_msg = response.get("message", "unknown error")
-                logger.error(f"Failed to set emoji reaction: {error_msg}")
+                logger.error("Failed to set emoji reaction")
                 await self.store_action_info(
                     action_build_into_prompt=True,
-                    action_prompt_display=f"Executed the set_emoji_like action: {self.action_name}, failed: {error_msg}",
+                    action_prompt_display=f"Executed the set_emoji_like action: {self.action_name}, failed",
                     action_done=False,
                 )
-                return False, f"Failed to set emoji reaction: {error_msg}"
+                return False, "Failed to set emoji reaction"

         except Exception as e:
             logger.error(f"Failed to set emoji reaction: {e}")
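A minimal sketch of the boolean-returning command style that the refactor above moves to; `send_command` and its handler table here are hypothetical stand-ins, not the real plugin-system API:

```python
import asyncio

async def send_command(command_name: str, args: dict) -> bool:
    # Hypothetical dispatch table: a command succeeds when its required
    # arguments are present (purely illustrative).
    handlers = {"set_emoji_like": lambda a: "message_id" in a and "emoji_id" in a}
    handler = handlers.get(command_name)
    return bool(handler and handler(args))

# The caller only branches on a boolean, as in the new diff hunk above
ok = asyncio.run(send_command("set_emoji_like", {"message_id": "m1", "emoji_id": "66", "set": True}))
```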
@@ -1,13 +1,14 @@
 import asyncio
 import os
 import shutil
 import sys
 import glob
 import orjson
 import datetime
 from pathlib import Path
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from threading import Lock
 from typing import Optional
+from json_repair import repair_json

 # Add the project root to sys.path
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
@@ -37,13 +38,30 @@ OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie")
 TEMP_DIR = os.path.join(ROOT_PATH, "temp", "lpmm_cache")
 file_lock = Lock()

+# --- Cache cleanup ---
+
+def clear_cache():
+    """Remove the cache files generated by lpmm_learning_tool.py."""
+    logger.info("--- Starting cache cleanup ---")
+    if os.path.exists(TEMP_DIR):
+        try:
+            shutil.rmtree(TEMP_DIR)
+            logger.info(f"Cache directory removed: {TEMP_DIR}")
+        except OSError as e:
+            logger.error(f"Error while removing the cache: {e}")
+    else:
+        logger.info("Cache directory does not exist; nothing to clean.")
+    logger.info("--- Cache cleanup finished ---")
+
 # --- Module 1: data preprocessing ---

 def process_text_file(file_path):
     with open(file_path, "r", encoding="utf-8") as f:
         raw = f.read()
     return [p.strip() for p in raw.split("\n\n") if p.strip()]

 def preprocess_raw_data():
     logger.info("--- Step 1: starting data preprocessing ---")
     os.makedirs(RAW_DATA_PATH, exist_ok=True)
@@ -51,7 +69,7 @@ def preprocess_raw_data():
     if not raw_files:
         logger.warning(f"Warning: no .txt files found in '{RAW_DATA_PATH}'")
         return []

     all_paragraphs = []
     for file in raw_files:
         logger.info(f"Processing file: {file.name}")
@@ -62,8 +80,57 @@ def preprocess_raw_data():
     logger.info("--- Data preprocessing finished ---")
     return unique_paragraphs

 # --- Module 2: information extraction ---

+def _parse_and_repair_json(json_string: str) -> Optional[dict]:
+    """Try to parse a JSON string; on failure, repair it and parse again.
+
+    The function first cleans the string by stripping common Markdown
+    code-fence markers, then attempts a direct parse. If that fails, it
+    calls `repair_json` and parses once more.
+
+    Args:
+        json_string: A possibly malformed JSON string returned by the LLM.
+
+    Returns:
+        The parsed dict, or None (with detailed error logs) if the string
+        ultimately cannot be parsed.
+    """
+    if not isinstance(json_string, str):
+        logger.error(f"Input is not a string and cannot be parsed: {type(json_string)}")
+        return None
+
+    # 1. Preprocess: strip common extras such as Markdown code-fence markers
+    cleaned_string = json_string.strip()
+    if cleaned_string.startswith("```json"):
+        cleaned_string = cleaned_string[7:].strip()
+    elif cleaned_string.startswith("```"):
+        cleaned_string = cleaned_string[3:].strip()
+
+    if cleaned_string.endswith("```"):
+        cleaned_string = cleaned_string[:-3].strip()
+
+    # 2. Fast path: optimistically try a direct parse
+    try:
+        return orjson.loads(cleaned_string)
+    except orjson.JSONDecodeError:
+        logger.warning("Direct JSON parse failed; attempting repair...")
+
+    # 3. Repair, then parse one final time
+    repaired_json_str = ""
+    try:
+        repaired_json_str = repair_json(cleaned_string)
+        return orjson.loads(repaired_json_str)
+    except Exception as e:
+        # 4. Richer error handling: log the failure in detail
+        logger.error(f"Parsing still failed after repair: {e}")
+        logger.error(f"Original string (cleaned): {cleaned_string}")
+        logger.error(f"String attempted after repair: {repaired_json_str}")
+        return None
+
 def get_extraction_prompt(paragraph: str) -> str:
     return f"""
 Please extract the key information from the paragraph below. You need to extract two kinds of information:
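The cleaning-then-repair flow of `_parse_and_repair_json` can be reproduced with the standard library alone; here `json.loads` plus a naive trailing-comma fix stands in for the `json_repair.repair_json` dependency:

```python
import json
from typing import Optional

def parse_llm_json(raw: str) -> Optional[dict]:
    """Strip Markdown fences, then parse; on failure try a naive repair."""
    cleaned = raw.strip()
    if cleaned.startswith("```json"):
        cleaned = cleaned[7:].strip()
    elif cleaned.startswith("```"):
        cleaned = cleaned[3:].strip()
    if cleaned.endswith("```"):
        cleaned = cleaned[:-3].strip()
    try:
        return json.loads(cleaned)
    except json.JSONDecodeError:
        # Stand-in "repair": drop trailing commas, a common LLM mistake
        repaired = cleaned.replace(",}", "}").replace(",]", "]")
        try:
            return json.loads(repaired)
        except json.JSONDecodeError:
            return None
```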
@@ -82,6 +149,7 @@ def get_extraction_prompt(paragraph: str) -> str:
 ---
 """

 async def extract_info_async(pg_hash, paragraph, llm_api):
     temp_file_path = os.path.join(TEMP_DIR, f"{pg_hash}.json")
     with file_lock:
@@ -93,11 +161,20 @@ async def extract_info_async(pg_hash, paragraph, llm_api):
             os.remove(temp_file_path)

     prompt = get_extraction_prompt(paragraph)
+    content = None
     try:
         content, (_, _, _) = await llm_api.generate_response_async(prompt)
-        extracted_data = orjson.loads(content)
+
+        # Improvement: delegate JSON parsing and repair to the helper function
+        extracted_data = _parse_and_repair_json(content)
+
+        if extracted_data is None:
+            # Raise so that the unified error handling below is triggered
+            raise ValueError("Could not parse valid JSON from the LLM output")
+
         doc_item = {
-            "idx": pg_hash, "passage": paragraph,
+            "idx": pg_hash,
+            "passage": paragraph,
             "extracted_entities": extracted_data.get("entities", []),
             "extracted_triples": extracted_data.get("triples", []),
         }
@@ -107,27 +184,45 @@ async def extract_info_async(pg_hash, paragraph, llm_api):
         return doc_item, None
     except Exception as e:
         logger.error(f"Information extraction failed: {pg_hash}, error: {e}")
+        if content:
+            logger.error(f"Raw output that failed to parse: {content}")
         return None, pg_hash

 def extract_info_sync(pg_hash, paragraph, llm_api):
     return asyncio.run(extract_info_async(pg_hash, paragraph, llm_api))

 def extract_information(paragraphs_dict, model_set):
     logger.info("--- Step 2: starting information extraction ---")
     os.makedirs(OPENIE_OUTPUT_DIR, exist_ok=True)
     os.makedirs(TEMP_DIR, exist_ok=True)

     llm_api = LLMRequest(model_set=model_set)
     failed_hashes, open_ie_docs = [], []

     with ThreadPoolExecutor(max_workers=5) as executor:
-        f_to_hash = {executor.submit(extract_info_sync, p_hash, p, llm_api): p_hash for p_hash, p in paragraphs_dict.items()}
-        with Progress(SpinnerColumn(), TextColumn("[progress.description]{task.description}"), BarColumn(), TaskProgressColumn(), MofNCompleteColumn(), "•", TimeElapsedColumn(), "<", TimeRemainingColumn()) as progress:
+        f_to_hash = {
+            executor.submit(extract_info_sync, p_hash, p, llm_api): p_hash for p_hash, p in paragraphs_dict.items()
+        }
+        with Progress(
+            SpinnerColumn(),
+            TextColumn("[progress.description]{task.description}"),
+            BarColumn(),
+            TaskProgressColumn(),
+            MofNCompleteColumn(),
+            "•",
+            TimeElapsedColumn(),
+            "<",
+            TimeRemainingColumn(),
+        ) as progress:
             task = progress.add_task("[cyan]Extracting information...", total=len(paragraphs_dict))
             for future in as_completed(f_to_hash):
                 doc_item, failed_hash = future.result()
-                if failed_hash: failed_hashes.append(failed_hash)
-                elif doc_item: open_ie_docs.append(doc_item)
+                if failed_hash:
+                    failed_hashes.append(failed_hash)
+                elif doc_item:
+                    open_ie_docs.append(doc_item)
                 progress.update(task, advance=1)

     if open_ie_docs:
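The submit/as_completed bookkeeping used by `extract_information` boils down to the following pattern, where `work` is a stand-in for `extract_info_sync` and returns the same `(doc_item, failed_hash)` pair:

```python
from concurrent.futures import ThreadPoolExecutor, as_completed

def work(p_hash, text):
    # Stand-in worker: returns (doc_item, failed_hash), failing nothing
    return {"idx": p_hash, "length": len(text)}, None

paragraphs = {"h1": "alpha", "h2": "beta"}
docs, failed = [], []

with ThreadPoolExecutor(max_workers=2) as executor:
    # Map each future back to its paragraph hash
    futures = {executor.submit(work, h, p): h for h, p in paragraphs.items()}
    for fut in as_completed(futures):
        doc_item, failed_hash = fut.result()
        if failed_hash:
            failed.append(failed_hash)
        elif doc_item:
            docs.append(doc_item)
```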
@@ -136,19 +231,22 @@ def extract_information(paragraphs_dict, model_set):
         avg_ent_chars = round(sum(len(e) for e in all_entities) / num_entities, 4) if num_entities else 0
         avg_ent_words = round(sum(len(e.split()) for e in all_entities) / num_entities, 4) if num_entities else 0
         openie_obj = OpenIE(docs=open_ie_docs, avg_ent_chars=avg_ent_chars, avg_ent_words=avg_ent_words)

         now = datetime.datetime.now()
         filename = now.strftime("%Y-%m-%d-%H-%M-%S-openie.json")
         output_path = os.path.join(OPENIE_OUTPUT_DIR, filename)
         with open(output_path, "wb") as f:
             f.write(orjson.dumps(openie_obj._to_dict()))
         logger.info(f"Extraction results saved to: {output_path}")

-    if failed_hashes: logger.error(f"Extraction failed for the following {len(failed_hashes)} paragraphs: {failed_hashes}")
+    if failed_hashes:
+        logger.error(f"Extraction failed for the following {len(failed_hashes)} paragraphs: {failed_hashes}")
     logger.info("--- Information extraction finished ---")

 # --- Module 3: data import ---

 async def import_data(openie_obj: Optional[OpenIE] = None):
     """
     Import OpenIE data into the knowledge base (embedding store and KG)
@@ -160,15 +258,19 @@ async def import_data(openie_obj: Optional[OpenIE] = None):
     """
     logger.info("--- Step 3: starting data import ---")
     embed_manager, kg_manager = EmbeddingManager(), KGManager()

     logger.info("Loading the existing embedding store...")
-    try: embed_manager.load_from_file()
-    except Exception as e: logger.warning(f"Failed to load the embedding store: {e}")
+    try:
+        embed_manager.load_from_file()
+    except Exception as e:
+        logger.warning(f"Failed to load the embedding store: {e}")

     logger.info("Loading the existing KG...")
-    try: kg_manager.load_from_file()
-    except Exception as e: logger.warning(f"Failed to load the KG: {e}")
+    try:
+        kg_manager.load_from_file()
+    except Exception as e:
+        logger.warning(f"Failed to load the KG: {e}")

     try:
         if openie_obj:
             openie_data = openie_obj
@@ -181,7 +283,7 @@ async def import_data(openie_obj: Optional[OpenIE] = None):

         raw_paragraphs = openie_data.extract_raw_paragraph_dict()
         triple_list_data = openie_data.extract_triple_dict()

         new_raw_paragraphs, new_triple_list_data = {}, {}
         stored_embeds = embed_manager.stored_pg_hashes
         stored_kgs = kg_manager.stored_paragraph_hashes
@@ -190,7 +292,7 @@ async def import_data(openie_obj: Optional[OpenIE] = None):
             if p_hash not in stored_embeds and p_hash not in stored_kgs:
                 new_raw_paragraphs[p_hash] = raw_p
                 new_triple_list_data[p_hash] = triple_list_data.get(p_hash, [])

         if not new_raw_paragraphs:
             logger.info("No new paragraphs to process.")
         else:
@@ -208,60 +310,68 @@ async def import_data(openie_obj: Optional[OpenIE] = None):

     logger.info("--- Data import finished ---")

 def import_from_specific_file():
     """Import data from a user-specified openie.json file."""
     file_path = input("Enter the full path to the openie.json file: ").strip()

     if not os.path.exists(file_path):
         logger.error(f"File path does not exist: {file_path}")
         return

     if not file_path.endswith(".json"):
         logger.error("Please enter a valid .json file path.")
         return

     try:
         logger.info(f"Loading OpenIE data from {file_path}...")
         openie_obj = OpenIE.load(filepath=file_path)
         asyncio.run(import_data(openie_obj=openie_obj))
     except Exception as e:
         logger.error(f"Error while importing data from the specified file: {e}")

 # --- Main entry point ---

 def main():
     # Use os.path.relpath to build friendly paths relative to the project root
     raw_data_relpath = os.path.relpath(RAW_DATA_PATH, os.path.join(ROOT_PATH, ".."))
     openie_output_relpath = os.path.relpath(OPENIE_OUTPUT_DIR, os.path.join(ROOT_PATH, ".."))

     print("=== LPMM Knowledge Base Learning Tool ===")
     print(f"1. [Preprocess] -> read .txt files (source: ./{raw_data_relpath}/)")
     print(f"2. [Extract] -> extract information and save as .json (output: ./{openie_output_relpath}/)")
     print("3. [Import] -> automatically import the latest knowledge from the openie folder")
     print("4. [Full pipeline] -> run 1 -> 2 -> 3 in order")
     print("5. [Targeted import] -> import knowledge from a specific openie.json file")
+    print("6. [Clear cache] -> delete all cached extraction results")
     print("0. [Exit]")
     print("-" * 30)
     choice = input("Enter your choice (0-6): ").strip()

-    if choice == '1':
+    if choice == "1":
         preprocess_raw_data()
-    elif choice == '2':
+    elif choice == "2":
         paragraphs = preprocess_raw_data()
-        if paragraphs: extract_information(paragraphs, model_config.model_task_config.lpmm_qa)
-    elif choice == '3':
+        if paragraphs:
+            extract_information(paragraphs, model_config.model_task_config.lpmm_qa)
+    elif choice == "3":
         asyncio.run(import_data())
-    elif choice == '4':
+    elif choice == "4":
         paragraphs = preprocess_raw_data()
         if paragraphs:
             extract_information(paragraphs, model_config.model_task_config.lpmm_qa)
             asyncio.run(import_data())
-    elif choice == '5':
+    elif choice == "5":
         import_from_specific_file()
-    elif choice == '0':
+    elif choice == "6":
+        clear_cache()
+    elif choice == "0":
         sys.exit(0)
     else:
         print("Invalid input; please rerun the script.")

 if __name__ == "__main__":
     main()
@@ -233,6 +233,5 @@ class MessageShield:

 def create_default_shield() -> MessageShield:
     """Create the default message shield."""
-    from .config import default_config

     return MessageShield()
@@ -206,13 +206,6 @@ class CycleProcessor:
             raise UserWarning(f"Plugin {result.get_summary().get('stopped_handlers', '')} interrupted content generation before planning")
         with Timer("Planner", cycle_timers):
             actions, _ = await self.action_planner.plan(mode=mode)
-
-        # Log the finally selected actions clearly here
-        if actions:
-            chosen_actions = [a.get("action_type", "unknown") for a in actions]
-            logger.info(f"{self.log_prefix} Actions finally chosen by the LLM: {chosen_actions}")
-        else:
-            logger.info(f"{self.log_prefix} The LLM ultimately chose no actions")

     async def execute_action(action_info):
         """Generic helper that executes a single action."""
@@ -267,6 +260,7 @@ class CycleProcessor:
                 enable_tool=global_config.tool.enable_tool,
                 request_type="chat.replyer",
                 from_plugin=False,
+                read_mark=action_info.get("action_message", {}).get("time", 0.0),
             )
             if not success or not response_set:
                 logger.info(
@@ -10,7 +10,6 @@ from src.config.config import global_config
 from src.person_info.relationship_builder_manager import relationship_builder_manager
 from src.chat.express.expression_learner import expression_learner_manager
 from src.chat.chat_loop.sleep_manager.sleep_manager import SleepManager, SleepState
-from src.plugin_system.apis import message_api

 from .hfc_context import HfcContext
 from .energy_manager import EnergyManager
@@ -40,6 +39,7 @@ class HeartFChatting:
         """
         self.context = HfcContext(chat_id)
         self.context.new_message_queue = asyncio.Queue()
+        self._processing_lock = asyncio.Lock()

         self.cycle_tracker = CycleTracker(self.context)
         self.response_handler = ResponseHandler(self.context)
@@ -358,130 +358,130 @@
         - FOCUS mode: process all messages directly and check the exit conditions
         - NORMAL mode: check the conditions for entering FOCUS mode and handle messages via normal_mode_handler
         """
+        async with self._processing_lock:
+            # --- Core state update ---
+            await self.sleep_manager.update_sleep_state(self.wakeup_manager)
+            current_sleep_state = self.sleep_manager.get_current_sleep_state()
+            is_sleeping = current_sleep_state == SleepState.SLEEPING
+            is_in_insomnia = current_sleep_state == SleepState.INSOMNIA
+
+            # Core fix: while sleeping (including insomnia), do not filter command
+            # messages when fetching, so that @-mention messages can still be received
+            filter_command_flag = not (is_sleeping or is_in_insomnia)
+
+            # Drain all pending new messages from the queue
+            recent_messages = []
+            while not self.context.new_message_queue.empty():
+                recent_messages.append(await self.context.new_message_queue.get())
+
+            has_new_messages = bool(recent_messages)
+            new_message_count = len(recent_messages)
+
+            # Only run the thinking loop when there are new messages
+            if has_new_messages:
+                self.context.last_message_time = time.time()
+                self.context.last_read_time = time.time()
+
+                # --- Focus-mode quiet-group check ---
+                quiet_groups = global_config.chat.focus_mode_quiet_groups
+                if quiet_groups and self.context.chat_stream:
+                    is_group_chat = self.context.chat_stream.group_info is not None
+                    if is_group_chat:
+                        try:
+                            platform = self.context.chat_stream.platform
+                            group_id = self.context.chat_stream.group_info.group_id
+
+                            # Stay compatible with the platform names of different QQ adapters
+                            is_qq_platform = platform in ["qq", "napcat"]
+
+                            current_chat_identifier = f"{platform}:{group_id}"
+                            config_identifier_for_qq = f"qq:{group_id}"
+
+                            is_in_quiet_list = (current_chat_identifier in quiet_groups or
+                                                (is_qq_platform and config_identifier_for_qq in quiet_groups))
+
+                            if is_in_quiet_list:
+                                is_mentioned_in_batch = False
+                                for msg in recent_messages:
+                                    if msg.get("is_mentioned"):
+                                        is_mentioned_in_batch = True
+                                        break
+
+                                if not is_mentioned_in_batch:
+                                    logger.info(f"{self.context.log_prefix} Ignored messages in focus quiet mode because the bot was not mentioned.")
+                                    return True  # Consume the messages without replying
+                        except Exception as e:
+                            logger.error(f"{self.context.log_prefix} Error while checking focus quiet groups: {e}")
+
+                # Handle wakefulness logic
+                if current_sleep_state in [SleepState.SLEEPING, SleepState.PREPARING_SLEEP, SleepState.INSOMNIA]:
+                    self._handle_wakeup_messages(recent_messages)
+
+                # Re-read the latest state, since handle_wakeup may have switched it to WOKEN_UP
+                current_sleep_state = self.sleep_manager.get_current_sleep_state()
+
+                if current_sleep_state == SleepState.SLEEPING:
+                    # Skip message processing only in the pure SLEEPING state
+                    return True
+
+                if current_sleep_state == SleepState.WOKEN_UP:
+                    logger.info(f"{self.context.log_prefix} Woken from sleep; backlogged messages will be processed.")
+
+                # Handle the new messages according to the chat mode
+                should_process, interest_value = await self._should_process_messages(recent_messages)
+                if not should_process:
+                    # Not enough messages or interest; wait
+                    await asyncio.sleep(0.5)
+                    return True  # Skip rest of the logic for this iteration
+
+                # Messages should be processed
+                action_type = await self.cycle_processor.observe(interest_value=interest_value)
+
+                # Try to trigger expression learning
+                if self.context.expression_learner:
+                    try:
+                        await self.context.expression_learner.trigger_learning_for_chat()
+                    except Exception as e:
+                        logger.error(f"{self.context.log_prefix} Failed to trigger expression learning: {e}")
+
+                # Manage the no_reply counter
+                if action_type != "no_reply":
+                    self.recent_interest_records.clear()
+                    self.context.no_reply_consecutive = 0
+                    logger.debug(f"{self.context.log_prefix} Executed the {action_type} action; resetting the no_reply counter")
+                else:  # action_type == "no_reply"
+                    self.context.no_reply_consecutive += 1
+                    self._determine_form_type()
+
+                # After a round of actions has finished, increase the sleep pressure
+                if self.context.energy_manager and global_config.sleep_system.enable_insomnia_system:
+                    if action_type not in ["no_reply", "no_action"]:
+                        self.context.energy_manager.increase_sleep_pressure()
+
+                # On a successful observation, increase the energy value and reset accumulated interest
+                self.context.energy_value += 1 / global_config.chat.focus_value
+                # Reset accumulated interest, since the messages were processed successfully
+                self.context.breaking_accumulated_interest = 0.0
+                logger.info(
+                    f"{self.context.log_prefix} Energy increased; current energy: {self.context.energy_value:.1f}; accumulated interest reset"
+                )
+
+            # Update the previous frame's sleep state
+            self.context.was_sleeping = is_sleeping
+
+            # --- Fall-back-asleep logic ---
+            # If the bot was woken up and no new messages arrive for a while, try to fall asleep again
+            if self.sleep_manager.get_current_sleep_state() == SleepState.WOKEN_UP and not has_new_messages:
+                re_sleep_delay = global_config.sleep_system.re_sleep_delay_minutes * 60
+                # Use last_message_time to measure idle time
+                if time.time() - self.context.last_message_time > re_sleep_delay:
+                    logger.info(
+                        f"{self.context.log_prefix} Woken up and idle for more than {re_sleep_delay / 60} minutes; trying to fall back asleep."
+                    )
+                    self.sleep_manager.reset_sleep_state_after_wakeup()
+
+            # Save the HFC context state
+            self.context.save_context_state()
+
+            return has_new_messages

     def _handle_wakeup_messages(self, messages):
         """
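The quiet-group membership test in the hunk above reduces to a small predicate; the group IDs below are made up purely for illustration:

```python
quiet_groups = ["qq:12345"]  # hypothetical config value

def is_quiet(platform: str, group_id: str) -> bool:
    # A chat is quiet if its own identifier is listed, or — for QQ-family
    # adapters ("qq", "napcat") — if the canonical "qq:<id>" alias is listed.
    current = f"{platform}:{group_id}"
    qq_alias = f"qq:{group_id}"
    is_qq = platform in ["qq", "napcat"]
    return current in quiet_groups or (is_qq and qq_alias in quiet_groups)
```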
@@ -162,107 +162,33 @@ class ProactiveThinker:
|
||||
|
||||
news_block = "暂时没有获取到最新资讯。"
|
||||
if trigger_event.source != "reminder_system":
|
||||
# 升级决策模型
|
||||
should_search_prompt = f"""
|
||||
# 搜索决策
|
||||
|
||||
## 任务
|
||||
分析话题“{topic}”,判断它的展开更依赖于“外部信息”还是“内部信息”,并决定是否需要进行网络搜索。
|
||||
|
||||
## 判断原则
|
||||
- **需要搜索 (SEARCH)**:当话题的有效讨论**必须**依赖于现实世界的、客观的、可被检索的外部信息时。这包括但不限于:
|
||||
- 新闻时事、公共事件
|
||||
- 专业知识、科学概念
|
||||
- 天气、股价等实时数据
|
||||
- 对具体实体(如电影、书籍、地点)的客观描述查询
|
||||
|
||||
- **无需搜索 (SKIP)**:当话题的展开主要依赖于**已有的对话上下文、个人情感、主观体验或社交互动**时。这包括但不限于:
|
||||
- 延续之前的对话、追问细节
|
||||
- 表达关心、问候或个人感受
|
||||
- 分享主观看法或经历
|
||||
- 纯粹的社交性互动
|
||||
|
||||
## 你的决策
|
||||
根据以上原则,对“{topic}”这个话题进行分析,并严格输出`SEARCH`或`SKIP`。
|
||||
"""
|
||||
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config

decision_llm = LLMRequest(
    model_set=model_config.model_task_config.planner,
    request_type="planner"
)

decision, _ = await decision_llm.generate_response_async(prompt=should_search_prompt)

if "SEARCH" in decision:
    try:
        if topic and topic.strip():
            web_search_tool = tool_api.get_tool_instance("web_search")
            if web_search_tool:
                try:
                    search_result_dict = await web_search_tool.execute(
                        function_args={"query": topic, "max_results": 10}
                    )
                    if search_result_dict and not search_result_dict.get("error"):
                        news_block = search_result_dict.get("content", "未能提取有效资讯。")
                    elif search_result_dict:
                        logger.warning(f"{self.context.log_prefix} 网络搜索返回错误: {search_result_dict.get('error')}")
                except Exception as e:
                    logger.error(f"{self.context.log_prefix} 网络搜索执行失败: {e}")
            else:
                logger.warning(f"{self.context.log_prefix} 未找到 web_search 工具实例。")
        else:
            logger.warning(f"{self.context.log_prefix} 主题为空,跳过网络搜索。")
    except Exception as e:
        logger.error(f"{self.context.log_prefix} 主动思考时网络搜索失败: {e}")
message_list = await get_raw_msg_before_timestamp_with_chat(
try:
    web_search_tool = tool_api.get_tool_instance("web_search")
    if web_search_tool:
        try:
            search_result_dict = await web_search_tool.execute(function_args={"keyword": topic, "max_results": 10})
        except TypeError:
            try:
                search_result_dict = await web_search_tool.execute(function_args={"keyword": topic, "max_results": 10})
            except TypeError:
                logger.warning(f"{self.context.log_prefix} 网络搜索工具参数不匹配,跳过搜索")
                news_block = "跳过网络搜索。"
                search_result_dict = None

        if search_result_dict and not search_result_dict.get("error"):
            news_block = search_result_dict.get("content", "未能提取有效资讯。")
        elif search_result_dict:
            logger.warning(f"{self.context.log_prefix} 网络搜索返回错误: {search_result_dict.get('error')}")
    else:
        logger.warning(f"{self.context.log_prefix} 未找到 web_search 工具实例。")
except Exception as e:
    logger.error(f"{self.context.log_prefix} 主动思考时网络搜索失败: {e}")
message_list = get_raw_msg_before_timestamp_with_chat(
    chat_id=self.context.stream_id,
    timestamp=time.time(),
    limit=int(global_config.chat.max_context_size * 0.3),
)
chat_context_block, _ = await build_readable_messages_with_id(messages=message_list)

from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config

bot_name = global_config.bot.nickname

confirmation_prompt = f"""# 主动回复二次确认

## 基本信息
你的名字是{bot_name},准备主动发起关于"{topic}"的话题。

## 最近的聊天内容
{chat_context_block}

## 合理判断标准
请检查以下条件,如果**所有条件都合理**就可以回复:

1. **回应检查**:检查你({bot_name})发送的最后一条消息之后,是否有其他人发言。如果没有,则大概率应该保持沉默。
2. **话题补充**:只有当你认为准备发起的话题是对上一条无人回应消息的**有价值的补充**时,才可以在上一条消息无人回应的情况下继续发言。
3. **时间合理性**:当前时间是否在深夜(凌晨2点-6点)这种不适合主动聊天的时段?
4. **内容价值**:这个话题"{topic}"是否有意义,不是完全无关紧要的内容?
5. **重复避免**:你准备说的话题是否与你自己的上一条消息明显重复?
6. **自然性**:在当前上下文中主动提起这个话题是否自然合理?

## 输出要求
如果判断应该跳过(比如上一条消息无人回应、深夜时段、无意义话题、重复内容),输出:SKIP_PROACTIVE_REPLY
其他情况都应该输出:PROCEED_TO_REPLY

请严格按照上述格式输出,不要添加任何解释。"""

planner_llm = LLMRequest(
    model_set=model_config.model_task_config.planner,
    request_type="planner"
)

confirmation_result, _ = await planner_llm.generate_response_async(prompt=confirmation_prompt)

if not confirmation_result or "SKIP_PROACTIVE_REPLY" in confirmation_result:
    logger.info(f"{self.context.log_prefix} 决策模型二次确认决定跳过主动回复")
    return

chat_context_block, _ = await build_readable_messages_with_id(messages=message_list)
bot_name = global_config.bot.nickname
personality = global_config.personality
identity_block = (
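Both planner prompts above rely on the model emitting a single strict token (`SEARCH`/`SKIP`, `SKIP_PROACTIVE_REPLY`/`PROCEED_TO_REPLY`), which the code then matches with a bare substring test. A hedged sketch of a slightly more defensive parser for the search verdict — `parse_search_decision` is a hypothetical helper, not part of the diff, and it defaults to `SKIP` so a malformed LLM reply cannot trigger a web search:

```python
from typing import Optional

def parse_search_decision(raw: Optional[str], default: str = "SKIP") -> str:
    # Normalize the planner output; anything unexpected falls back to SKIP.
    if not raw:
        return default
    verdict = raw.strip().upper()
    if "SEARCH" in verdict and "SKIP" not in verdict:
        return "SEARCH"
    if "SKIP" in verdict:
        return "SKIP"
    return default
```

The extra `"SKIP" not in verdict` guard avoids misreading a reply like "SKIP (no SEARCH needed)" as a search request, which the plain `if "SEARCH" in decision` check in the hunk would.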
@@ -4,7 +4,7 @@ import orjson
import os
from datetime import datetime

from typing import List, Dict, Optional, Any, Tuple, Coroutine
from typing import List, Dict, Optional, Any, Tuple

from src.common.logger import get_logger
from src.common.database.sqlalchemy_database_api import get_db_session
@@ -24,7 +24,7 @@ class SubHeartflow:
self.subheartflow_id = subheartflow_id
self.chat_id = subheartflow_id

self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_id)
self.is_group_chat, self.chat_target_info = (None, None)
self.log_prefix = get_chat_manager().get_stream_name(self.subheartflow_id) or self.subheartflow_id

# focus模式退出冷却时间管理
@@ -38,4 +38,5 @@ class SubHeartflow:

async def initialize(self):
    """异步初始化方法,创建兴趣流并确定聊天类型"""
    self.is_group_chat, self.chat_target_info = await get_chat_type_and_target_info(self.chat_id)
    await self.heart_fc_instance.start()
@@ -415,7 +415,6 @@ class ChatBot:
return

get_chat_manager().register_message(message)

chat = await get_chat_manager().get_or_create_stream(
    platform=message.message_info.platform,  # type: ignore
    user_info=user_info,  # type: ignore
@@ -427,11 +426,11 @@ class ChatBot:
# 处理消息内容,生成纯文本
await message.process()

# 过滤检查
if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex(  # type: ignore
    message.raw_message,  # type: ignore
    chat,
    user_info,  # type: ignore
# 过滤检查 (在消息处理之后进行)
if _check_ban_words(
    message.processed_plain_text, chat, user_info  # type: ignore
) or _check_ban_regex(
    message.processed_plain_text, chat, user_info  # type: ignore
):
    return
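The hunk above changes the regex filter to run against `processed_plain_text` instead of `raw_message`, so both checks now see the same post-processing text. A small sketch of that shared-input shape, with standalone stand-ins for the two `_check_ban_*` helpers (the real ones take a chat stream and user info):

```python
import re
from typing import Iterable

def should_filter(text: str, ban_words: Iterable[str],
                  ban_patterns: Iterable[str]) -> bool:
    # Word check and regex check both run over the same processed text,
    # mirroring the change in the hunk above.
    if any(word in text for word in ban_words):
        return True
    return any(re.search(pattern, text) for pattern in ban_patterns)
```

Feeding both checks the processed text avoids the old asymmetry where a regex could match markup in `raw_message` that had already been stripped from what the bot would actually reply to.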
@@ -254,7 +254,7 @@ class ChatManager:
model_instance = await _db_find_stream_async(stream_id)

if model_instance:
    # 从 Peewee 模型转换回 ChatStream.from_dict 期望的格式
    # 从 SQLAlchemy 模型转换回 ChatStream.from_dict 期望的格式
    user_info_data = {
        "platform": model_instance.user_platform,
        "user_id": model_instance.user_id,
@@ -382,7 +382,7 @@ class ChatManager:
await _db_save_stream_async(stream_data_dict)
stream.saved = True
except Exception as e:
    logger.error(f"保存聊天流 {stream.stream_id} 到数据库失败 (Peewee): {e}", exc_info=True)
    logger.error(f"保存聊天流 {stream.stream_id} 到数据库失败 (SQLAlchemy): {e}", exc_info=True)

async def _save_all_streams(self):
    """保存所有聊天流"""
@@ -435,7 +435,7 @@ class ChatManager:
if stream.stream_id in self.last_messages:
    stream.set_context(self.last_messages[stream.stream_id])
except Exception as e:
    logger.error(f"从数据库加载所有聊天流失败 (Peewee): {e}", exc_info=True)
    logger.error(f"从数据库加载所有聊天流失败 (SQLAlchemy): {e}", exc_info=True)


chat_manager = None
@@ -2,7 +2,7 @@ import base64
import time
from abc import abstractmethod, ABCMeta
from dataclasses import dataclass
from typing import Optional, Any, TYPE_CHECKING
from typing import Optional, Any

import urllib3
from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
@@ -72,7 +72,7 @@ class ActionModifier:
from src.chat.utils.utils import get_chat_type_and_target_info

# 获取聊天类型
is_group_chat, _ = get_chat_type_and_target_info(self.chat_id)
is_group_chat, _ = await get_chat_type_and_target_info(self.chat_id)
all_registered_actions = component_registry.get_components_by_type(ComponentType.ACTION)

chat_type_removals = []
@@ -9,7 +9,6 @@ from typing import Any, Dict, List, Optional

from json_repair import repair_json

from . import planner_prompts
from src.chat.memory_system.Hippocampus import hippocampus_manager
from src.chat.utils.chat_message_builder import (
    build_readable_actions,
@@ -47,12 +46,12 @@ class PlanFilter:
try:
    prompt, used_message_id_list = await self._build_prompt(plan)
    plan.llm_prompt = prompt
    logger.debug(f"墨墨在这里加了日志 -> LLM prompt: {prompt}")
    logger.info(f"规划器原始提示词: {prompt}")

    llm_content, _ = await self.planner_llm.generate_response_async(prompt=prompt)

    if llm_content:
        logger.debug(f"墨墨在这里加了日志 -> LLM 原始返回: {llm_content}")
        logger.info(f"规划器原始返回: {llm_content}")
        parsed_json = orjson.loads(repair_json(llm_content))
        logger.debug(f"墨墨在这里加了日志 -> 解析后的 JSON: {parsed_json}")
@@ -51,7 +51,7 @@ class PlanGenerator:
Returns:
    Plan: 一个填充了初始上下文信息的 Plan 对象。
"""
_is_group_chat, chat_target_info_dict = get_chat_type_and_target_info(self.chat_id)
_is_group_chat, chat_target_info_dict = await get_chat_type_and_target_info(self.chat_id)

target_info = None
if chat_target_info_dict:
@@ -10,7 +10,7 @@ from src.chat.planner_actions.plan_filter import PlanFilter
from src.chat.planner_actions.plan_generator import PlanGenerator
from src.common.logger import get_logger
from src.plugin_system.base.component_types import ChatMode

import src.chat.planner_actions.planner_prompts  # noqa: F401
# 导入提示词模块以确保其被初始化

logger = get_logger("planner")
@@ -83,12 +83,12 @@ def init_prompt():
- {schedule_block}

## 历史记录
### 当前群聊中的所有人的聊天记录:
### {chat_context_type}中的所有人的聊天记录:
{background_dialogue_prompt}

{cross_context_block}

### 当前群聊中正在与你对话的聊天记录
### {chat_context_type}中正在与你对话的聊天记录
{core_dialogue_prompt}

## 表达方式
@@ -110,12 +110,11 @@ def init_prompt():

## 任务

*你正在一个QQ群里聊天,你需要理解整个群的聊天动态和话题走向,并做出自然的回应。*
*你正在一个{chat_context_type}里聊天,你需要理解整个{chat_context_type}的聊天动态和话题走向,并做出自然的回应。*

### 核心任务
- 你现在的主要任务是和 {sender_name} 聊天。同时,也有其他用户会参与聊天,你可以参考他们的回复内容,但是你现在想回复{sender_name}的发言。

- {reply_target_block} ,你需要生成一段紧密相关且能推动对话的回复。
- 你现在的主要任务是和 {sender_name} 聊天。
- {reply_target_block} ,你需要生成一段紧密相关且能推动对话的回复。

## 规则
{safety_guidelines_block}
@@ -203,7 +202,9 @@ class DefaultReplyer:
):
    self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type)
    self.chat_stream = chat_stream
    self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.stream_id)
    self.is_group_chat: Optional[bool] = None
    self.chat_target_info: Optional[Dict[str, Any]] = None
    self._initialized = False

    self.heart_fc_sender = HeartFCSender()
    self.memory_activator = MemoryActivator()
@@ -236,6 +237,7 @@ class DefaultReplyer:
from_plugin: bool = True,
stream_id: Optional[str] = None,
reply_message: Optional[Dict[str, Any]] = None,
read_mark: float = 0.0,
) -> Tuple[bool, Optional[Dict[str, Any]], Optional[str]]:
    # sourcery skip: merge-nested-ifs
    """
@@ -268,6 +270,7 @@ class DefaultReplyer:
available_actions=available_actions,
enable_tool=enable_tool,
reply_message=reply_message,
read_mark=read_mark,
)

if not prompt:
@@ -723,10 +726,8 @@ class DefaultReplyer:
truncate=True,
show_actions=True,
)
core_dialogue_prompt = f"""--------------------------------
这是你和{sender}的对话,你们正在交流中:
core_dialogue_prompt = f"""
{core_dialogue_prompt_str}
--------------------------------
"""

return core_dialogue_prompt, all_dialogue_prompt
@@ -776,6 +777,12 @@ class DefaultReplyer:
mai_think.target = target
return mai_think

async def _async_init(self):
    if self._initialized:
        return
    self.is_group_chat, self.chat_target_info = await get_chat_type_and_target_info(self.chat_stream.stream_id)
    self._initialized = True

async def build_prompt_reply_context(
    self,
    reply_to: str,
@@ -783,6 +790,7 @@ class DefaultReplyer:
available_actions: Optional[Dict[str, ActionInfo]] = None,
enable_tool: bool = True,
reply_message: Optional[Dict[str, Any]] = None,
read_mark: float = 0.0,
) -> str:
    """
    构建回复器上下文
@@ -800,10 +808,11 @@ class DefaultReplyer:
"""
if available_actions is None:
    available_actions = {}
await self._async_init()
chat_stream = self.chat_stream
chat_id = chat_stream.stream_id
person_info_manager = get_person_info_manager()
is_group_chat = bool(chat_stream.group_info)
is_group_chat = self.is_group_chat

if global_config.mood.enable_mood:
    chat_mood = mood_manager.get_mood_by_chat_id(chat_id)
@@ -859,7 +868,7 @@ class DefaultReplyer:
target = "(无消息内容)"

person_info_manager = get_person_info_manager()
person_id = await person_info_manager.get_person_id_by_person_name(sender)
person_id = person_info_manager.get_person_id(platform, reply_message.get("user_id")) if reply_message else None
platform = chat_stream.platform

target = replace_user_references_sync(target, chat_stream.platform, replace_bot_name=True)
@@ -891,7 +900,7 @@ class DefaultReplyer:
replace_bot_name=True,
merge_messages=False,
timestamp_mode="relative",
read_mark=0.0,
read_mark=read_mark,
show_actions=True,
)
# 获取目标用户信息,用于s4u模式
@@ -1081,6 +1090,7 @@ class DefaultReplyer:
reply_target_block=reply_target_block,
mood_prompt=mood_prompt,
action_descriptions=action_descriptions,
read_mark=read_mark,
)

# 使用新的统一Prompt系统 - 使用正确的模板名称
@@ -1127,9 +1137,10 @@ class DefaultReplyer:
reply_to: str,
reply_message: Optional[Dict[str, Any]] = None,
) -> str:  # sourcery skip: merge-else-if-into-elif, remove-redundant-if
    await self._async_init()
    chat_stream = self.chat_stream
    chat_id = chat_stream.stream_id
    is_group_chat = bool(chat_stream.group_info)
    is_group_chat = self.is_group_chat

    if reply_message:
        sender = reply_message.get("sender")
@@ -1167,7 +1178,7 @@ class DefaultReplyer:
replace_bot_name=True,
merge_messages=False,
timestamp_mode="relative",
read_mark=0.0,
read_mark=read_mark,
show_actions=True,
)
@@ -520,6 +520,7 @@ async def _build_readable_messages_internal(
pic_counter: int = 1,
show_pic: bool = True,
message_id_list: Optional[List[Dict[str, Any]]] = None,
read_mark: float = 0.0,
) -> Tuple[str, List[Tuple[float, str, str]], Dict[str, str], int]:
    """
    内部辅助函数,构建可读消息字符串和原始消息详情列表。
@@ -642,6 +643,10 @@ async def _build_readable_messages_internal(
else:
    person_name = "某人"

# 在用户名后面添加 QQ 号, 但机器人本体不用
if user_id != global_config.bot.qq_account:
    person_name = f"{person_name}({user_id})"

# 使用独立函数处理用户引用格式
content = replace_user_references_sync(content, platform, replace_bot_name=replace_bot_name)

@@ -726,11 +731,10 @@ async def _build_readable_messages_internal(
    "is_action": is_action,
}
continue

# 如果是同一个人发送的连续消息且时间间隔小于等于60秒
if name == current_merge["name"] and (timestamp - current_merge["end_time"] <= 60):
    current_merge["content"].append(content)
    current_merge["end_time"] = timestamp  # 更新最后消息时间
    current_merge["end_time"] = timestamp
else:
    # 保存上一个合并块
    merged_messages.append(current_merge)
@@ -758,8 +762,14 @@ async def _build_readable_messages_internal(

# 4 & 5: 格式化为字符串
output_lines = []
read_mark_inserted = False

for _i, merged in enumerate(merged_messages):
    # 检查是否需要插入已读标记
    if read_mark > 0 and not read_mark_inserted and merged["start_time"] >= read_mark:
        output_lines.append("\n--- 以上消息是你已经看过,请关注以下未读的新消息---\n")
        read_mark_inserted = True

    # 使用指定的 timestamp_mode 格式化时间
    readable_time = translate_timestamp_to_human_readable(merged["start_time"], mode=timestamp_mode)
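The hunks above combine two mechanisms: consecutive messages from the same sender within 60 seconds are merged into one block, and a single unread divider is inserted at the first block starting at or after `read_mark`. A self-contained sketch of both on plain `(name, timestamp, content)` tuples (the real code works on richer message dicts and inserts the divider between already-merged blocks, as shown in the diff):

```python
from typing import List, Optional, Tuple

def merge_messages(msgs: List[Tuple[str, float, str]], gap: float = 60.0,
                   read_mark: float = 0.0) -> List[str]:
    out: List[str] = []
    mark_done = read_mark <= 0
    cur: Optional[tuple] = None  # (name, start_ts, end_ts, [contents])

    def flush():
        nonlocal cur
        if cur:
            out.append(f"{cur[0]}: " + " ".join(cur[3]))
            cur = None

    for name, ts, content in msgs:
        # Insert the unread divider once, before the first message past the mark.
        if not mark_done and ts >= read_mark:
            flush()
            out.append("--- unread messages below ---")
            mark_done = True
        # Merge consecutive messages from the same sender within the gap.
        if cur and cur[0] == name and ts - cur[2] <= gap:
            cur[3].append(content)
            cur = (cur[0], cur[1], ts, cur[3])
        else:
            flush()
            cur = (name, ts, ts, [content])
    flush()
    return out
```

With `gap=60`, two messages from "a" at t=0 and t=30 collapse into one line, while a message at t=200 after a `read_mark` of 150 lands below the divider.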
@@ -78,6 +78,7 @@ class PromptParameters:

# 可用动作信息
available_actions: Optional[Dict[str, Any]] = None
read_mark: float = 0.0

def validate(self) -> List[str]:
    """参数验证"""
@@ -449,7 +450,8 @@ class Prompt:
core_dialogue, background_dialogue = await self._build_s4u_chat_history_prompts(
    self.parameters.message_list_before_now_long,
    self.parameters.target_user_info.get("user_id") if self.parameters.target_user_info else "",
    self.parameters.sender
    self.parameters.sender,
    read_mark=self.parameters.read_mark,
)

context_data["core_dialogue_prompt"] = core_dialogue
@@ -465,7 +467,7 @@ class Prompt:

@staticmethod
async def _build_s4u_chat_history_prompts(
    message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str
    message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str, read_mark: float = 0.0
) -> Tuple[str, str]:
    """构建S4U风格的分离对话prompt"""
    # 实现逻辑与原有SmartPromptBuilder相同
@@ -491,6 +493,7 @@ class Prompt:
replace_bot_name=True,
timestamp_mode="normal",
truncate=True,
read_mark=read_mark,
)
all_dialogue_prompt = f"所有用户的发言:\n{all_dialogue_prompt_str}"
@@ -510,7 +513,7 @@ class Prompt:
replace_bot_name=True,
merge_messages=False,
timestamp_mode="normal_no_YMD",
read_mark=0.0,
read_mark=read_mark,
truncate=True,
show_actions=True,
)
@@ -764,6 +767,7 @@ class Prompt:
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt or context_data.get("keywords_reaction_prompt", ""),
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
"safety_guidelines_block": self.parameters.safety_guidelines_block or context_data.get("safety_guidelines_block", ""),
"chat_context_type": "群聊" if self.parameters.is_group_chat else "私聊",
}

def _prepare_normal_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
@@ -341,9 +341,9 @@ def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese
split_sentences = [s.strip() for s in split_sentences_raw if s.strip()]
else:
    if split_mode == "llm":
        logger.debug("未检测到 [SPLIT] 标记,本次不进行分割。")
        split_sentences = [cleaned_text]
    else:  # mode == "punctuation"
        logger.debug("未检测到 [SPLIT] 标记,回退到基于标点的传统模式进行分割。")
        split_sentences = split_into_sentences_w_remove_punctuation(cleaned_text)
else:  # mode == "punctuation"
    logger.debug("使用基于标点的传统模式进行分割。")
    split_sentences = split_into_sentences_w_remove_punctuation(cleaned_text)
else:
@@ -619,7 +619,7 @@ def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal"
return time.strftime("%H:%M:%S", time.localtime(timestamp))
def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
async def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
    """
    获取聊天类型(是否群聊)和私聊对象信息。

@@ -663,7 +663,8 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
if person_id:
    # get_value is async, so await it directly
    person_info_manager = get_person_info_manager()
    person_name = person_info_manager.get_value(person_id, "person_name")
    person_data = await person_info_manager.get_values(person_id, ["person_name"])
    person_name = person_data.get("person_name")

    target_info["person_id"] = person_id
    target_info["person_name"] = person_name
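This hunk is the core of a cross-cutting change in the PR: `get_chat_type_and_target_info` becomes a coroutine, so every call site (SubHeartflow, ActionModifier, PlanGenerator, DefaultReplyer) must add an `await`, and constructors that can no longer call it synchronously defer the lookup to an async `initialize`/`_async_init` step. A toy sketch of that migration pattern, with a stand-in implementation (the real helper queries the chat manager and person info store):

```python
import asyncio
from typing import Dict, Optional, Tuple

async def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
    # Stand-in: group chats have no private-chat target info.
    is_group = chat_id.startswith("group:")
    target = None if is_group else {"person_id": chat_id}
    return is_group, target

async def main() -> Tuple[bool, bool, Optional[Dict]]:
    # Every former sync call site must now await the helper.
    g, _ = await get_chat_type_and_target_info("group:42")
    p, info = await get_chat_type_and_target_info("user:7")
    return g, p, info

result = asyncio.run(main())
```

Constructors that used to call the helper directly instead initialize the fields to `(None, None)` and fill them in on first async use, exactly as `SubHeartflow.__init__` and `DefaultReplyer._async_init` do above.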
@@ -1,6 +1,4 @@
import os
from pymongo import MongoClient
from pymongo.database import Database
from rich.traceback import install
from src.common.logger import get_logger

@@ -10,8 +8,6 @@ from src.common.database.sqlalchemy_models import get_engine, get_db_session

install(extra_lines=3)

_client = None
_db = None
_sql_engine = None

logger = get_logger("database")
@@ -64,43 +60,6 @@ class SQLAlchemyTransaction:
db = DatabaseProxy()


def __create_database_instance():
    uri = os.getenv("MONGODB_URI")
    host = os.getenv("MONGODB_HOST", "127.0.0.1")
    port = int(os.getenv("MONGODB_PORT", "27017"))
    # db_name 变量在创建连接时不需要,在获取数据库实例时才使用
    username = os.getenv("MONGODB_USERNAME")
    password = os.getenv("MONGODB_PASSWORD")
    auth_source = os.getenv("MONGODB_AUTH_SOURCE")

    if uri:
        # 支持标准mongodb://和mongodb+srv://连接字符串
        if uri.startswith(("mongodb://", "mongodb+srv://")):
            return MongoClient(uri)
        else:
            raise ValueError(
                "Invalid MongoDB URI format. URI must start with 'mongodb://' or 'mongodb+srv://'. "
                "For MongoDB Atlas, use 'mongodb+srv://' format. "
                "See: https://www.mongodb.com/docs/manual/reference/connection-string/"
            )

    if username and password:
        # 如果有用户名和密码,使用认证连接
        return MongoClient(host, port, username=username, password=password, authSource=auth_source)

    # 否则使用无认证连接
    return MongoClient(host, port)


def get_db():
    """获取MongoDB连接实例,延迟初始化。"""
    global _client, _db
    if _client is None:
        _client = __create_database_instance()
        _db = _client[os.getenv("DATABASE_NAME", "MegBot")]
    return _db


async def initialize_sql_database(database_config):
    """
    根据配置初始化SQL数据库连接(SQLAlchemy版本)
@@ -141,17 +100,3 @@ async def initialize_sql_database(database_config):
except Exception as e:
    logger.error(f"初始化SQL数据库失败: {e}")
    return None


class DBWrapper:
    """数据库代理类,保持接口兼容性同时实现懒加载。"""

    def __getattr__(self, name):
        return getattr(get_db(), name)

    def __getitem__(self, key):
        return get_db()[key]  # type: ignore


# 全局MongoDB数据库访问点
memory_db: Database = DBWrapper()  # type: ignore
@@ -1,7 +1,6 @@
# mmc/src/common/database/db_migration.py

from sqlalchemy import inspect
from sqlalchemy.schema import CreateIndex
from sqlalchemy.sql import text

from src.common.database.sqlalchemy_models import Base, get_engine
@@ -70,24 +69,32 @@ async def check_and_migrate_database():

def add_columns_sync(conn):
    dialect = conn.dialect
    compiler = dialect.ddl_compiler(dialect, None)

    for column_name in missing_columns:
        column = table.c[column_name]

        # 使用DDLCompiler为特定方言编译列
        compiler = dialect.ddl_compiler(dialect, None)

        # 编译列的数据类型
        column_type = compiler.get_column_specification(column)

        # 构建原生SQL
        sql = f"ALTER TABLE {table.name} ADD COLUMN {column.name} {column_type}"

        # 添加默认值(如果存在)
        if column.default:
            default_value = compiler.render_literal_value(column.default.arg, column.type)
            # 手动处理不同方言的默认值
            default_arg = column.default.arg
            if dialect.name == "sqlite" and isinstance(default_arg, bool):
                # SQLite 将布尔值存储为 0 或 1
                default_value = "1" if default_arg else "0"
            elif hasattr(compiler, 'render_literal_value'):
                try:
                    # 尝试使用 render_literal_value
                    default_value = compiler.render_literal_value(default_arg, column.type)
                except AttributeError:
                    # 如果失败,则回退到简单的字符串转换
                    default_value = f"'{default_arg}'" if isinstance(default_arg, str) else str(default_arg)
            else:
                # 对于没有 render_literal_value 的旧版或特定方言
                default_value = f"'{default_arg}'" if isinstance(default_arg, str) else str(default_arg)

            sql += f" DEFAULT {default_value}"

        # 添加非空约束(如果存在)
        if not column.nullable:
            sql += " NOT NULL"

@@ -109,12 +116,11 @@ async def check_and_migrate_database():
logger.info(f"在表 '{table_name}' 中发现缺失的索引: {', '.join(missing_indexes)}")

def add_indexes_sync(conn):
    with conn.begin():
        for index_name in missing_indexes:
            index_obj = next((idx for idx in table.indexes if idx.name == index_name), None)
            if index_obj is not None:
                conn.execute(CreateIndex(index_obj))
                logger.info(f"成功为表 '{table_name}' 创建索引 '{index_name}'。")
    for index_name in missing_indexes:
        index_obj = next((idx for idx in table.indexes if idx.name == index_name), None)
        if index_obj is not None:
            index_obj.create(conn)
            logger.info(f"成功为表 '{table_name}' 创建索引 '{index_name}'。")

await connection.run_sync(add_indexes_sync)
else:
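The default-value branch in the migration hunk reduces to pure string assembly once the SQLAlchemy objects are stripped away: render the default per dialect (SQLite stores booleans as 0/1), quote strings, then append the `DEFAULT` and `NOT NULL` clauses. A standalone sketch of just that assembly logic, with the dialect passed as a plain string for illustration:

```python
def render_add_column(table: str, column: str, col_type: str,
                      default=None, nullable: bool = True,
                      dialect: str = "sqlite") -> str:
    """Build the raw ALTER TABLE statement the migration would emit."""
    sql = f"ALTER TABLE {table} ADD COLUMN {column} {col_type}"
    if default is not None:
        if dialect == "sqlite" and isinstance(default, bool):
            # SQLite stores booleans as 0 or 1.
            rendered = "1" if default else "0"
        elif isinstance(default, str):
            rendered = f"'{default}'"
        else:
            rendered = str(default)
        sql += f" DEFAULT {rendered}"
    if not nullable:
        sql += " NOT NULL"
    return sql
```

Note the bool check must precede the generic string/number rendering, which is exactly why the diff special-cases `isinstance(default_arg, bool)` before falling back to `render_literal_value`.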
@@ -1,5 +1,5 @@
from typing import List, Dict, Any, Literal, Union
from pydantic import Field, field_validator
from pydantic import Field
from threading import Lock

from src.config.config_base import ValidatedConfigBase

@@ -8,7 +8,7 @@ from tomlkit import TOMLDocument
from tomlkit.items import Table, KeyType
from rich.traceback import install
from typing import List, Optional
from pydantic import Field, field_validator
from pydantic import Field

from src.common.logger import get_logger
from src.config.config_base import ValidatedConfigBase
@@ -164,6 +164,18 @@ def _version_tuple(v):
    return tuple(int(x) if x.isdigit() else 0 for x in str(v).replace("v", "").split("-")[0].split("."))


def _remove_obsolete_keys(target: TOMLDocument | dict | Table, reference: TOMLDocument | dict | Table):
    """
    递归地从目标字典中移除所有不存在于参考字典中的键。
    """
    # 使用 list() 创建键的副本,以便在迭代期间安全地修改字典
    for key in list(target.keys()):
        if key not in reference:
            del target[key]
        elif isinstance(target.get(key), (dict, Table)) and isinstance(reference.get(key), (dict, Table)):
            _remove_obsolete_keys(target[key], reference[key])


def _update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict):
    """
    将source字典的值更新到target字典中
@@ -334,6 +346,13 @@ def _update_config_generic(config_name: str, template_name: str):
logger.info(f"开始合并{config_name}新旧配置...")
_update_dict(new_config, old_config)

# 移除在新模板中已不存在的旧配置项
logger.info(f"开始移除{config_name}中已废弃的配置项...")
with open(template_path, "r", encoding="utf-8") as f:
    template_doc = tomlkit.load(f)
_remove_obsolete_keys(new_config, template_doc)
logger.info(f"已移除{config_name}中已废弃的配置项")

# 保存更新后的配置(保留注释和格式)
with open(new_config_path, "w", encoding="utf-8") as f:
    f.write(tomlkit.dumps(new_config))
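The new `_remove_obsolete_keys` walks the merged config and drops any key the new template no longer defines, recursing into nested tables. The same logic on plain dicts (no tomlkit types, so this is a simplified restatement, not the PR's exact function):

```python
def remove_obsolete_keys(target: dict, reference: dict) -> None:
    """Recursively drop keys from `target` that `reference` no longer defines."""
    # list() snapshots the keys so we can delete while iterating.
    for key in list(target.keys()):
        if key not in reference:
            del target[key]
        elif isinstance(target[key], dict) and isinstance(reference[key], dict):
            remove_obsolete_keys(target[key], reference[key])

cfg = {"bot": {"name": "MoFox", "legacy_flag": True}, "old_section": {"x": 1}}
template = {"bot": {"name": "", "prefix": "/"}}
remove_obsolete_keys(cfg, template)
```

After the call, `old_section` and `bot.legacy_flag` are gone while `bot.name` keeps the user's value; keys present only in the template (like `bot.prefix`) are handled separately by the `_update_dict` merge step.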
@@ -20,6 +20,26 @@ from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall
logger = get_logger("AioHTTP-Gemini客户端")


# gemini_thinking参数(默认范围)
# 不同模型的思考预算范围配置
THINKING_BUDGET_LIMITS = {
    "gemini-2.5-flash": {"min": 1, "max": 24576, "can_disable": True},
    "gemini-2.5-flash-lite": {"min": 512, "max": 24576, "can_disable": True},
    "gemini-2.5-pro": {"min": 128, "max": 32768, "can_disable": False},
}
# 思维预算特殊值
THINKING_BUDGET_AUTO = -1  # 自动调整思考预算,由模型决定
THINKING_BUDGET_DISABLED = 0  # 禁用思考预算(如果模型允许禁用)

gemini_safe_settings = [
    {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
    {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
    {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
    {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
    {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
]


def _format_to_mime_type(image_format: str) -> str:
    """
    将图片格式转换为正确的MIME类型
@@ -130,7 +150,11 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[dict]:


def _build_generation_config(
    max_tokens: int, temperature: float, response_format: RespFormat | None = None, extra_params: dict | None = None
    max_tokens: int,
    temperature: float,
    thinking_budget: int,
    response_format: RespFormat | None = None,
    extra_params: dict | None = None,
) -> dict:
    """构建生成配置"""
    config = {
@@ -138,6 +162,8 @@ def _build_generation_config(
        "temperature": temperature,
        "topK": 1,
        "topP": 1,
        "safetySettings": gemini_safe_settings,
        "thinkingConfig": {"includeThoughts": True, "thinkingBudget": thinking_budget},
    }

    # 处理响应格式
@@ -150,7 +176,11 @@ def _build_generation_config(

    # 合并额外参数
    if extra_params:
        config.update(extra_params)
        # 拷贝一份以防修改原始字典
        safe_extra_params = extra_params.copy()
        # 移除已单独处理的 thinking_budget
        safe_extra_params.pop("thinking_budget", None)
        config.update(safe_extra_params)

    return config
@@ -317,6 +347,41 @@ class AiohttpGeminiClient(BaseClient):
if api_provider.base_url:
    self.base_url = api_provider.base_url.rstrip("/")

@staticmethod
def clamp_thinking_budget(tb: int, model_id: str) -> int:
    """
    按模型限制思考预算范围,仅支持指定的模型(支持带数字后缀的新版本)
    """
    limits = None

    # 优先尝试精确匹配
    if model_id in THINKING_BUDGET_LIMITS:
        limits = THINKING_BUDGET_LIMITS[model_id]
    else:
        # 按 key 长度倒序,保证更长的(更具体的,如 -lite)优先
        sorted_keys = sorted(THINKING_BUDGET_LIMITS.keys(), key=len, reverse=True)
        for key in sorted_keys:
            # 必须满足:完全等于 或者 前缀匹配(带 "-" 边界)
            if model_id == key or model_id.startswith(f"{key}-"):
                limits = THINKING_BUDGET_LIMITS[key]
                break

    # 特殊值处理
    if tb == THINKING_BUDGET_AUTO:
        return THINKING_BUDGET_AUTO
    if tb == THINKING_BUDGET_DISABLED:
        if limits and limits.get("can_disable", False):
            return THINKING_BUDGET_DISABLED
        return limits["min"] if limits else THINKING_BUDGET_AUTO

    # 已知模型裁剪到范围
    if limits:
        return max(limits["min"], min(tb, limits["max"]))

    # 未知模型,返回动态模式
    logger.warning(f"模型 {model_id} 未在 THINKING_BUDGET_LIMITS 中定义,将使用动态模式 tb=-1 兼容。")
    return tb

# 移除全局 session,全部请求都用 with aiohttp.ClientSession() as session:

async def _make_request(
@@ -376,10 +441,21 @@ class AiohttpGeminiClient(BaseClient):
# 转换消息格式
contents, system_instructions = _convert_messages(message_list)

# 处理思考预算
tb = THINKING_BUDGET_AUTO
if extra_params and "thinking_budget" in extra_params:
    try:
        tb = int(extra_params["thinking_budget"])
    except (ValueError, TypeError):
        logger.warning(f"无效的 thinking_budget 值 {extra_params['thinking_budget']},将使用默认动态模式 {tb}")
tb = self.clamp_thinking_budget(tb, model_info.model_identifier)

# 构建请求体
request_data = {
    "contents": contents,
    "generationConfig": _build_generation_config(max_tokens, temperature, response_format, extra_params),
    "generationConfig": _build_generation_config(
        max_tokens, temperature, tb, response_format, extra_params
    ),
}

# 添加系统指令
@@ -475,7 +551,7 @@ class AiohttpGeminiClient(BaseClient):

request_data = {
    "contents": contents,
    "generationConfig": _build_generation_config(2048, 0.1, None, extra_params),
    "generationConfig": _build_generation_config(2048, 0.1, THINKING_BUDGET_AUTO, None, extra_params),
}

try:
File diff suppressed because it is too large
@@ -84,6 +84,7 @@ async def generate_reply(
    return_prompt: bool = False,
    request_type: str = "generator_api",
    from_plugin: bool = True,
+    read_mark: float = 0.0,
) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]:
    """Generate a reply

@@ -129,6 +130,7 @@ async def generate_reply(
        from_plugin=from_plugin,
        stream_id=chat_stream.stream_id if chat_stream else chat_id,
        reply_message=reply_message,
+        read_mark=read_mark,
    )
    if not success:
        logger.warning("[GeneratorAPI] 回复生成失败")
@@ -5,6 +5,7 @@ import toml
import orjson
import shutil
import datetime
+from pathlib import Path

from src.common.logger import get_logger
from src.config.config import CONFIG_DIR
@@ -268,100 +269,64 @@ class PluginBase(ABC):
        except IOError as e:
            logger.error(f"{self.log_prefix} 保存默认配置文件失败: {e}", exc_info=True)

-    def _get_expected_config_version(self) -> str:
-        """Return the config version the plugin expects"""
-        # Read from the plugin.config_version field of config_schema
-        if "plugin" in self.config_schema and isinstance(self.config_schema["plugin"], dict):
-            config_version_field = self.config_schema["plugin"].get("config_version")
-            if isinstance(config_version_field, ConfigField):
-                return config_version_field.default
-        return "1.0.0"
-
-    @staticmethod
-    def _get_current_config_version(config: Dict[str, Any]) -> str:
-        """Read the current version number from the config file"""
-        if "plugin" in config and "config_version" in config["plugin"]:
-            return str(config["plugin"]["config_version"])
-        # Without a config_version field, treat it as the earliest version
-        return "0.0.0"
-
    def _backup_config_file(self, config_file_path: str) -> str:
-        """Back up the config file"""
-        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
-        backup_path = f"{config_file_path}.backup_{timestamp}"
-
+        """Back up the config file into a dedicated backup subdirectory"""
        try:
+            config_path = Path(config_file_path)
+            backup_dir = config_path.parent / "backup"
+            backup_dir.mkdir(exist_ok=True)
+
+            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+            backup_filename = f"{config_path.name}.backup_{timestamp}"
+            backup_path = backup_dir / backup_filename
+
            shutil.copy2(config_file_path, backup_path)
            logger.info(f"{self.log_prefix} 配置文件已备份到: {backup_path}")
-            return backup_path
+            return str(backup_path)
        except Exception as e:
-            logger.error(f"{self.log_prefix} 备份配置文件失败: {e}")
+            logger.error(f"{self.log_prefix} 备份配置文件失败: {e}", exc_info=True)
            return ""
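The new backup scheme copies the config into a sibling `backup` directory with a timestamp suffix instead of cluttering the config directory itself. A self-contained sketch of that behavior (the demo paths are throwaway temp files, not the real `config/plugins/` layout):

```python
import datetime
import os
import shutil
import tempfile
from pathlib import Path

def backup_config_file(config_file_path: str) -> str:
    """Copy the config into a sibling 'backup' dir, suffixed with a timestamp."""
    config_path = Path(config_file_path)
    backup_dir = config_path.parent / "backup"
    backup_dir.mkdir(exist_ok=True)
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    backup_path = backup_dir / f"{config_path.name}.backup_{timestamp}"
    shutil.copy2(config_file_path, backup_path)  # copy2 preserves mtime/metadata
    return str(backup_path)

# Demo against a throwaway config file
workdir = tempfile.mkdtemp()
cfg = os.path.join(workdir, "config.toml")
with open(cfg, "w", encoding="utf-8") as f:
    f.write("[plugin]\nenabled = true\n")
backup = backup_config_file(cfg)
```

Returning `str(backup_path)` (rather than the `Path`) matches the method's declared `-> str` signature, which the old `return backup_path` only satisfied by accident of the f-string.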
-    def _migrate_config_values(self, old_config: Dict[str, Any], new_config: Dict[str, Any]) -> Dict[str, Any]:
-        """Migrate old config values into the new config structure
-
-        Args:
-            old_config: the old config data
-            new_config: the default config generated from the new schema
-
-        Returns:
-            Dict[str, Any]: the migrated config
-        """
-
-        def migrate_section(
-            old_section: Dict[str, Any], new_section: Dict[str, Any], section_name: str
-        ) -> Dict[str, Any]:
-            """Migrate a single config section"""
-            result = new_section.copy()
-
-            for key, value in old_section.items():
-                if key in new_section:
-                    # Special case: config_version always takes the new version
-                    if section_name == "plugin" and key == "config_version":
-                        # Keep the new version number, do not migrate the old value
-                        logger.debug(
-                            f"{self.log_prefix} 更新配置版本: {section_name}.{key} = {result[key]} (旧值: {value})"
-                        )
-                        continue
-
-                    # Key exists in the new config: copy the value
-                    if isinstance(value, dict) and isinstance(new_section[key], dict):
-                        # Recurse into nested dicts
-                        result[key] = migrate_section(value, new_section[key], f"{section_name}.{key}")
-                    else:
-                        result[key] = value
-                        logger.debug(f"{self.log_prefix} 迁移配置: {section_name}.{key} = {value}")
-                else:
-                    # Key no longer exists in the new config: warn
-                    logger.warning(f"{self.log_prefix} 配置项 {section_name}.{key} 在新版本中已被移除")
-
-            return result
-
-        migrated_config = {}
-
-        # Migrate each config section
-        for section_name, new_section_data in new_config.items():
-            if (
-                section_name in old_config
-                and isinstance(old_config[section_name], dict)
-                and isinstance(new_section_data, dict)
-            ):
-                migrated_config[section_name] = migrate_section(
-                    old_config[section_name], new_section_data, section_name
-                )
-            else:
-                # New section, or a type mismatch: use the defaults
-                migrated_config[section_name] = new_section_data
-                if section_name in old_config:
-                    logger.warning(f"{self.log_prefix} 配置节 {section_name} 结构已改变,使用默认值")
-
-        # Check for sections present in the old config but missing from the new one
-        for section_name in old_config:
-            if section_name not in migrated_config:
-                logger.warning(f"{self.log_prefix} 配置节 {section_name} 在新版本中已被移除")
-
-        return migrated_config
+    def _synchronize_config(
+        self, schema_config: Dict[str, Any], user_config: Dict[str, Any]
+    ) -> tuple[Dict[str, Any], bool]:
+        """Recursively synchronize the user config with the schema; return the synced config and a changed flag"""
+        changed = False
+
+        # Inner recursive helper
+        def _sync_dicts(
+            schema_dict: Dict[str, Any], user_dict: Dict[str, Any], parent_key: str = ""
+        ) -> Dict[str, Any]:
+            nonlocal changed
+            synced_dict = schema_dict.copy()
+
+            # Detect and log extra keys in the user config that the schema does not define
+            for key in user_dict:
+                if key not in schema_dict:
+                    logger.warning(f"{self.log_prefix} 发现废弃配置项 '{parent_key}{key}',将被移除。")
+                    changed = True
+
+            # Walk the schema, keep the user's values, fill in anything missing
+            for key, schema_value in schema_dict.items():
+                full_key = f"{parent_key}{key}"
+                if key in user_dict:
+                    user_value = user_dict[key]
+                    if isinstance(schema_value, dict) and isinstance(user_value, dict):
+                        # Recursively sync nested dicts
+                        synced_dict[key] = _sync_dicts(schema_value, user_value, f"{full_key}.")
+                    else:
+                        # Key present: keep the user's value
+                        synced_dict[key] = user_value
+                else:
+                    # Key missing from the user config: fill it in
+                    logger.info(f"{self.log_prefix} 补全缺失的配置项: '{full_key}' = {schema_value}")
+                    changed = True
+                    # synced_dict[key] already holds the default from schema_dict.copy()
+
+            return synced_dict
+
+        final_config = _sync_dicts(schema_config, user_config)
+        return final_config, changed
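The recursive schema/user synchronization can be exercised standalone. This sketch strips the logging and `self` context and uses hypothetical schema values; it is not the plugin API itself:

```python
def sync_dicts(schema_dict: dict, user_dict: dict) -> tuple[dict, bool]:
    """Return (synced, changed): the schema defines the structure, user values win."""
    # Any user key absent from the schema is obsolete and gets dropped
    changed = any(k not in schema_dict for k in user_dict)
    synced = {}
    for key, schema_value in schema_dict.items():
        if key in user_dict:
            user_value = user_dict[key]
            if isinstance(schema_value, dict) and isinstance(user_value, dict):
                synced[key], child_changed = sync_dicts(schema_value, user_value)
                changed = changed or child_changed
            else:
                synced[key] = user_value  # keep the user's value
        else:
            synced[key] = schema_value  # fill in the missing default
            changed = True
    return synced, changed

schema = {"plugin": {"enabled": True, "timeout": 30}}
user = {"plugin": {"enabled": False, "old_flag": 1}}
merged, changed = sync_dicts(schema, user)
# merged keeps the user's `enabled`, drops `old_flag`, and adds `timeout`
```

The `changed` flag is what drives the backup-then-rewrite step in `_load_plugin_config`: the file is only touched when the structures actually diverged.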
    def _generate_config_from_schema(self) -> Dict[str, Any]:
        # sourcery skip: dict-comprehension
@@ -393,11 +358,7 @@ class PluginBase(ABC):

        toml_str = f"# {self.plugin_name} - 配置文件\n"
        plugin_description = self.get_manifest_info("description", "插件配置文件")
-        toml_str += f"# {plugin_description}\n"
-
-        # Fetch the currently expected config version
-        expected_version = self._get_expected_config_version()
-        toml_str += f"# 配置版本: {expected_version}\n\n"
+        toml_str += f"# {plugin_description}\n\n"

        # Iterate over each config section
        for section, fields in self.config_schema.items():
@@ -456,77 +417,74 @@ class PluginBase(ABC):

    def _load_plugin_config(self):  # sourcery skip: extract-method
        """
-        Load the plugin config file, with centralized management and automatic migration.
+        Load and synchronize the plugin config file.

        Steps:
-        1. Determine the user config path (under the `config/plugins/` directory).
-        2. If the user config does not exist, generate one in the central directory from config_schema.
-        3. Load the user config, then run a version check and automatic migration if needed.
-        4. The config that is ultimately loaded is the user config file.
+        1. Determine the user config path and the plugin's bundled config path.
+        2. If the user config does not exist, try to migrate (move) one from the plugin directory.
+        3. If it still does not exist after migration, generate one from the schema.
+        4. Load the user config file.
+        5. Synchronize it against the schema, filling in missing items and removing obsolete ones.
+        6. If the sync found differences, back up the original file, then write the synced config back.
+        7. Load the final synced config into self.config.
        """
        if not self.config_file_name:
            logger.debug(f"{self.log_prefix} 未指定配置文件,跳过加载")
            return

-        # 1. Determine the user config path and make sure it exists
        user_config_path = os.path.join(CONFIG_DIR, "plugins", self.plugin_name, self.config_file_name)
+        plugin_config_path = os.path.join(self.plugin_dir, self.config_file_name)
        os.makedirs(os.path.dirname(user_config_path), exist_ok=True)

-        # 2. If the user config does not exist, generate it directly in the central directory
+        # First-load migration: if the user config is missing but the plugin directory has one, move it over
+        if not os.path.exists(user_config_path) and os.path.exists(plugin_config_path):
+            try:
+                shutil.move(plugin_config_path, user_config_path)
+                logger.info(f"{self.log_prefix} 已将配置文件从 {plugin_config_path} 迁移到 {user_config_path}")
+            except OSError as e:
+                logger.error(f"{self.log_prefix} 迁移配置文件失败: {e}", exc_info=True)
+
+        # If the user config still does not exist, generate a default one
        if not os.path.exists(user_config_path):
            logger.info(f"{self.log_prefix} 用户配置文件 {user_config_path} 不存在,将生成默认配置。")
            self._generate_and_save_default_config(user_config_path)

        # Check whether the user config file finally exists
        if not os.path.exists(user_config_path):
            # If the plugin defines no config_schema, creating no file is expected behavior
            if not self.config_schema:
-                logger.debug(f"{self.log_prefix} 插件未定义config_schema,使用空的配置.")
+                logger.debug(f"{self.log_prefix} 插件未定义 config_schema,使用空配置。")
                self.config = {}
                return
-
-            logger.warning(f"{self.log_prefix} 用户配置文件 {user_config_path} 不存在且无法创建。")
+            else:
+                logger.warning(f"{self.log_prefix} 用户配置文件 {user_config_path} 不存在且无法创建。")
            return

-        # 3. Load, check, and migrate the user config file
        _, file_ext = os.path.splitext(self.config_file_name)
        if file_ext.lower() != ".toml":
            logger.warning(f"{self.log_prefix} 不支持的配置文件格式: {file_ext},仅支持 .toml")
            self.config = {}
            return
        try:
            with open(user_config_path, "r", encoding="utf-8") as f:
-                existing_config = toml.load(f) or {}
+                user_config = toml.load(f) or {}
        except Exception as e:
            logger.error(f"{self.log_prefix} 加载用户配置文件 {user_config_path} 失败: {e}", exc_info=True)
-            self.config = {}
+            self.config = self._generate_config_from_schema()  # fall back to schema defaults when loading fails
            return

-        current_version = self._get_current_config_version(existing_config)
-        expected_version = self._get_expected_config_version()
+        # Build the ideal config structure from the schema
+        schema_config = self._generate_config_from_schema()

-        if current_version == "0.0.0":
-            logger.debug(f"{self.log_prefix} 用户配置文件无版本信息,跳过版本检查")
-            self.config = existing_config
-        elif current_version != expected_version:
-            logger.info(
-                f"{self.log_prefix} 检测到用户配置版本需要更新: 当前=v{current_version}, 期望=v{expected_version}"
-            )
-            new_config_structure = self._generate_config_from_schema()
-            migrated_config = self._migrate_config_values(existing_config, new_config_structure)
-            self._save_config_to_file(migrated_config, user_config_path)
-            logger.info(f"{self.log_prefix} 用户配置文件已从 v{current_version} 更新到 v{expected_version}")
-            self.config = migrated_config
-        else:
-            logger.debug(f"{self.log_prefix} 用户配置版本匹配 (v{current_version}),直接加载")
-            self.config = existing_config
+        # Synchronize the user config with the schema
+        synced_config, was_changed = self._synchronize_config(schema_config, user_config)

-        logger.debug(f"{self.log_prefix} 配置已从 {user_config_path} 加载")
+        # If the structure changed (items filled in or removed), back up and rewrite the file
+        if was_changed:
+            logger.info(f"{self.log_prefix} 检测到配置结构不匹配,将自动同步并更新配置文件。")
+            self._backup_config_file(user_config_path)
+            self._save_config_to_file(synced_config, user_config_path)
+            logger.info(f"{self.log_prefix} 配置文件已同步更新。")

-        # Update the enable_plugin state from the config
+        self.config = synced_config
+        logger.debug(f"{self.log_prefix} 配置已从 {user_config_path} 加载并同步。")
+
+        # Update the plugin enabled state from the final config
        if "plugin" in self.config and "enabled" in self.config["plugin"]:
            self._is_enabled = self.config["plugin"]["enabled"]
-            logger.debug(f"{self.log_prefix} 从配置更新插件启用状态: {self._is_enabled}")
+            logger.info(f"{self.log_prefix} 从配置更新插件启用状态: {self._is_enabled}")

    def _check_dependencies(self) -> bool:
        """Check plugin dependencies"""
@@ -39,76 +39,6 @@ class PluginManager:
        self._ensure_plugin_directories()
        logger.info("插件管理器初始化完成")

-    def _synchronize_plugin_config(self, plugin_name: str, plugin_dir: str):
-        """
-        Synchronize a single plugin's config.
-
-        This keeps the central config and the plugin-local config in sync, in two steps:
-        1. If the central config does not exist, copy the default config from the plugin directory into the central config directory.
-        2. Overwrite the plugin-local config with the central config, so the plugin runs with the latest user config.
-        """
-        try:
-            plugin_path = Path(plugin_dir)
-            # Fix: the plugin's config path should be the config.toml file, not a directory
-            plugin_config_file = plugin_path / "config.toml"
-            central_config_dir = Path("config") / "plugins" / plugin_name
-
-            # Make sure the central config directory exists
-            central_config_dir.mkdir(parents=True, exist_ok=True)
-
-            # Step 1: copy the default config from the plugin directory to the central directory
-            self._copy_default_config_to_central(plugin_name, plugin_config_file, central_config_dir)
-
-            # Step 2: sync the config from the central directory back to the plugin directory
-            self._sync_central_config_to_plugin(plugin_name, plugin_config_file, central_config_dir)
-
-        except OSError as e:
-            logger.error(f"处理插件 '{plugin_name}' 的配置时发生文件操作错误: {e}")
-        except Exception as e:
-            logger.error(f"同步插件 '{plugin_name}' 配置时发生未知错误: {e}")
-
-    @staticmethod
-    def _copy_default_config_to_central(plugin_name: str, plugin_config_file: Path, central_config_dir: Path):
-        """
-        If the central config does not exist, copy the plugin's default config.toml into the central directory.
-        """
-        if not plugin_config_file.is_file():
-            return  # The plugin ships no default config file; skip
-
-        central_config_file = central_config_dir / plugin_config_file.name
-        if not central_config_file.exists():
-            shutil.copy2(plugin_config_file, central_config_file)
-            logger.info(f"为插件 '{plugin_name}' 从模板复制了默认配置: {plugin_config_file.name}")
-
-    def _sync_central_config_to_plugin(self, plugin_name: str, plugin_config_file: Path, central_config_dir: Path):
-        """
-        Sync (overwrite) the plugin-local config from the central config.
-        """
-        # Iterate over every file in the central config directory
-        for central_file in central_config_dir.iterdir():
-            if not central_file.is_file():
-                continue
-
-            # The target file shares the central file's name; here it is forced to config.toml
-            target_plugin_file = plugin_config_file
-
-            # Copy only when the contents differ, to cut unnecessary IO
-            if not self._is_file_content_identical(central_file, target_plugin_file):
-                shutil.copy2(central_file, target_plugin_file)
-                logger.info(f"已将中央配置 '{central_file.name}' 同步到插件 '{plugin_name}'")
-
-    @staticmethod
-    def _is_file_content_identical(file1: Path, file2: Path) -> bool:
-        """
-        Compare two files' contents by their MD5 hashes.
-        """
-        if not file2.exists():
-            return False  # Target file missing: treat as different
-
-        # Read both files in binary mode ('rb') so the hashes are computed consistently
-        with open(file1, "rb") as f1, open(file2, "rb") as f2:
-            return hashlib.md5(f1.read()).hexdigest() == hashlib.md5(f2.read()).hexdigest()
-
    # === Plugin directory management ===

    def add_plugin_directory(self, directory: str) -> bool:

@@ -176,8 +106,6 @@ class PluginManager:
        if not plugin_dir:
            return False, 1

-        # Synchronize the plugin config
-        self._synchronize_plugin_config(plugin_name, plugin_dir)
-
        plugin_instance = plugin_class(plugin_dir=plugin_dir)  # Instantiate the plugin (may fail if the manifest is missing)
        if not plugin_instance:
@@ -7,7 +7,6 @@
from functools import wraps
from typing import Callable, Optional
from inspect import iscoroutinefunction
-import inspect

from src.plugin_system.apis.permission_api import permission_api
from src.plugin_system.apis.send_api import text_to_stream
@@ -255,7 +255,7 @@ class EmojiAction(BaseAction):

            if not success:
                logger.error(f"{self.log_prefix} 表情包发送失败")
-                await self.store_action_info(action_build_into_prompt = True,action_prompt_display =f"发送了一个表情包,但失败了",action_done= False)
+                await self.store_action_info(action_build_into_prompt = True,action_prompt_display ="发送了一个表情包,但失败了",action_done= False)
                return False, "表情包发送失败"

            # After a successful send, record it to the history
@@ -264,7 +264,7 @@ class EmojiAction(BaseAction):
            except Exception as e:
                logger.error(f"{self.log_prefix} 添加表情到历史记录时出错: {e}")

-            await self.store_action_info(action_build_into_prompt = True,action_prompt_display =f"发送了一个表情包",action_done= True)
+            await self.store_action_info(action_build_into_prompt = True,action_prompt_display ="发送了一个表情包",action_done= True)

            return True, f"发送表情包: {emoji_description}"
@@ -53,6 +53,6 @@ class SendFeedCommand(PlusCommand):
            return False, result.get("message", "未知错误"), True

        except Exception as e:
-            logger.error(f"执行发送说说命令时发生未知异常: {e}", exc_info=True)
+            logger.error(f"执行发送说说命令时发生未知异常: {e},它的类型是:{type(e)}", exc_info=True)
            await self.send_text("呜... 发送过程中好像出了点问题。")
            return False, "命令执行异常", True
@@ -119,12 +119,10 @@ class ContentService:
        logger.error(f"生成说说内容时发生异常: {e}")
        return ""

-    async def generate_comment(self, content: str, target_name: str, rt_con: str = "", images=None) -> str:
+    async def generate_comment(self, content: str, target_name: str, rt_con: str = "", images: list = []) -> str:
        """
        Generate a comment for a specific feed post.
        """
-        if images is None:
-            images = []
        for i in range(3):  # retry up to 3 times
            try:
                chat_manager = get_chat_manager()

@@ -182,8 +180,7 @@ class ContentService:
                return ""
        return ""

-    @staticmethod
-    async def generate_comment_reply(story_content: str, comment_content: str, commenter_name: str) -> str:
+    async def generate_comment_reply(self, story_content: str, comment_content: str, commenter_name: str) -> str:
        """
        Generate a reply to a comment on one's own post.
        """
@@ -50,8 +50,7 @@ class CookieService:
        logger.error(f"无法读取或解析Cookie文件 {cookie_file_path}: {e}")
        return None

-    @staticmethod
-    async def _get_cookies_from_adapter(stream_id: Optional[str]) -> Optional[Dict[str, str]]:
+    async def _get_cookies_from_adapter(self, stream_id: Optional[str]) -> Optional[Dict[str, str]]:
        """Fetch cookies via the Adapter API"""
        try:
            params = {"domain": "user.qzone.qq.com"}

@@ -59,8 +59,7 @@ class ImageService:
        logger.error(f"处理AI配图时发生异常: {e}")
        return False

-    @staticmethod
-    async def _call_siliconflow_api(api_key: str, story: str, image_dir: str, batch_size: int) -> bool:
+    async def _call_siliconflow_api(self, api_key: str, story: str, image_dir: str, batch_size: int) -> bool:
        """
        Call the SiliconFlow API to generate images.
        """

@@ -187,8 +187,7 @@ class QZoneService:

    # --- Internal Helper Methods ---

-    @staticmethod
-    async def _get_intercom_context(stream_id: str) -> Optional[str]:
+    async def _get_intercom_context(self, stream_id: str) -> Optional[str]:
        """
        Look up the intercom group the stream_id belongs to and build that group's chat context.
        """
@@ -399,8 +398,7 @@ class QZoneService:
        logger.error(f"加载本地图片失败: {e}")
        return []

-    @staticmethod
-    def _generate_gtk(skey: str) -> str:
+    def _generate_gtk(self, skey: str) -> str:
        hash_val = 5381
        for char in skey:
            hash_val += (hash_val << 5) + ord(char)

@@ -437,8 +435,7 @@ class QZoneService:
        logger.error(f"更新或加载Cookie时发生异常: {e}")
        return None

-    @staticmethod
-    async def _fetch_cookies_http(host: str, port: str, napcat_token: str) -> Optional[Dict]:
+    async def _fetch_cookies_http(self, host: str, port: str, napcat_token: str) -> Optional[Dict]:
        """Fetch cookies via the HTTP server"""
        url = f"http://{host}:{port}/get_cookies"
        max_retries = 5
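`_generate_gtk` derives QZone's `g_tk` request token from the `skey`/`p_skey` cookie with a DJB-style rolling hash. The diff truncates the method body, so the trailing 31-bit mask in this sketch is an assumption based on the commonly used `g_tk` implementation, not something visible in the commit:

```python
def generate_gtk(skey: str) -> str:
    """Derive the g_tk token from a QZone skey (31-bit mask assumed)."""
    hash_val = 5381  # DJB hash seed
    for char in skey:
        # hash = hash + hash * 32 + ord(char), accumulated without overflow in Python
        hash_val += (hash_val << 5) + ord(char)
    return str(hash_val & 0x7FFFFFFF)  # assumed conventional ending: clamp to 31 bits
```

Since Python integers never overflow, the mask at the end is what keeps the result in the 32-bit-signed range that the JavaScript original produces.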
@@ -657,20 +654,30 @@ class QZoneService:
                    end_idx = resp_text.rfind("}") + 1
                    if start_idx != -1 and end_idx != -1:
                        json_str = resp_text[start_idx:end_idx]
-                        upload_result = eval(json_str)  # kept eval for parity with the original implementation
+                        try:
+                            upload_result = orjson.loads(json_str)
+                        except orjson.JSONDecodeError:
+                            logger.error(f"图片上传响应JSON解析失败,原始响应: {resp_text}")
+                            return None

-                        logger.info(f"图片上传解析结果: {upload_result}")
+                        logger.debug(f"图片上传解析结果: {upload_result}")

                        if upload_result.get("ret") == 0:
-                            # Parameter extraction logic from the original implementation
-                            picbo, richval = _get_picbo_and_richval(upload_result)
-                            logger.info(f"图片 {index + 1} 上传成功: picbo={picbo}")
-                            return {"pic_bo": picbo, "richval": richval}
+                            try:
+                                # Parameter extraction logic from the original implementation
+                                picbo, richval = _get_picbo_and_richval(upload_result)
+                                logger.info(f"图片 {index + 1} 上传成功: picbo={picbo}")
+                                return {"pic_bo": picbo, "richval": richval}
+                            except Exception as e:
+                                logger.error(
+                                    f"从上传结果中提取图片参数失败: {e}, 上传结果: {upload_result}", exc_info=True
+                                )
+                                return None
                        else:
                            logger.error(f"图片 {index + 1} 上传失败: {upload_result}")
                            return None
                    else:
-                        logger.error("无法解析上传响应")
+                        logger.error(f"无法从响应中提取JSON内容: {resp_text}")
                        return None
                else:
                    error_text = await response.text()
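The key fix in this hunk is replacing `eval()` on server-controlled text with strict JSON parsing. A minimal sketch of that extraction, using stdlib `json` in place of `orjson` and an illustrative callback-wrapped response (the exact format QZone returns is not shown in the diff):

```python
import json

def extract_upload_result(resp_text: str):
    """Pull the first {...} span out of a callback-wrapped response and parse it strictly."""
    start_idx = resp_text.find("{")
    end_idx = resp_text.rfind("}") + 1  # rfind returns -1 on miss, so end_idx is 0 then
    if start_idx == -1 or end_idx == 0:
        return None  # no JSON object found in the response
    try:
        return json.loads(resp_text[start_idx:end_idx])
    except json.JSONDecodeError:
        return None  # malformed payload: fail closed instead of executing it

ok = extract_upload_result('frameElement.callback({"ret": 0, "data": {"url": "x"}});')
bad = extract_upload_result("<html>error</html>")
```

Unlike `eval`, a parse failure here can only yield `None`, never code execution, which is why the commit's move to `orjson.loads` with an explicit `JSONDecodeError` handler is a real security fix rather than a style change.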
@@ -36,8 +36,7 @@ class ReplyTrackerService:
        self._load_data()
        logger.debug(f"ReplyTrackerService initialized with data file: {self.reply_record_file}")

-    @staticmethod
-    def _validate_data(data: Any) -> bool:
+    def _validate_data(self, data: Any) -> bool:
        """Validate that the loaded data has the expected format"""
        if not isinstance(data, dict):
            logger.error("加载的数据不是字典格式")
@@ -129,8 +129,7 @@ class SchedulerService:
            logger.error(f"定时任务循环中发生未知错误: {e}\n{traceback.format_exc()}")
            await asyncio.sleep(300)  # after an error, wait a while before retrying

-    @staticmethod
-    async def _is_processed(hour_str: str, activity: str) -> bool:
+    async def _is_processed(self, hour_str: str, activity: str) -> bool:
        """
        Check whether the given task (an activity in a given hour) has already been processed successfully.
        """

@@ -153,8 +152,7 @@ class SchedulerService:
            logger.error(f"检查日程处理状态时发生数据库错误: {e}")
            return False  # on database errors, treat as unprocessed so the task can be retried

-    @staticmethod
-    async def _mark_as_processed(hour_str: str, activity: str, success: bool, content: str):
+    async def _mark_as_processed(self, hour_str: str, activity: str, success: bool, content: str):
        """
        Write the task's processing status and result to the database.
        """

@@ -187,7 +185,7 @@ class SchedulerService:
                send_success=success,
            )
            session.add(new_record)
-            await session.commit()
+            session.commit()
            logger.info(f"已更新日程处理状态: {hour_str} - {activity} - 成功: {success}")
        except Exception as e:
            logger.error(f"更新日程处理状态时发生数据库错误: {e}")
@@ -49,8 +49,7 @@ class _SimpleQZoneAPI:
        if p_skey:
            self.gtk2 = self._generate_gtk(p_skey)

-    @staticmethod
-    def _generate_gtk(skey: str) -> str:
+    def _generate_gtk(self, skey: str) -> str:
        hash_val = 5381
        for char in skey:
            hash_val += (hash_val << 5) + ord(char)
@@ -26,7 +26,7 @@ import json
import websockets as Server
import base64
from pathlib import Path
-from typing import List, Tuple, Optional, Dict, Any, Coroutine
+from typing import List, Tuple, Optional, Dict, Any
import uuid

from maim_message import (
@@ -53,14 +53,14 @@ price_out = 8.0  # 输出价格(用于API调用统计,单
#use_anti_truncation = true # [可选] 启用反截断功能。当模型输出不完整时,系统会自动重试。建议只为有需要的模型(如Gemini)开启。

[[models]]
-model_identifier = "Pro/deepseek-ai/DeepSeek-V3"
+model_identifier = "deepseek-ai/DeepSeek-V3"
name = "siliconflow-deepseek-v3"
api_provider = "SiliconFlow"
price_in = 2.0
price_out = 8.0

[[models]]
-model_identifier = "Pro/deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
+model_identifier = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
name = "deepseek-r1-distill-qwen-32b"
api_provider = "SiliconFlow"
price_in = 4.0
Block a user