This commit is contained in:
SengokuCola
2025-07-17 00:57:07 +08:00
37 changed files with 644 additions and 406 deletions

122
bot.py
View File

@@ -8,6 +8,7 @@ if os.path.exists(".env"):
print("成功加载环境变量配置")
else:
print("未找到.env文件请确保程序所需的环境变量被正确设置")
raise FileNotFoundError(".env 文件不存在,请创建并配置所需的环境变量")
import sys
import time
import platform
@@ -140,87 +141,88 @@ async def graceful_shutdown():
logger.error(f"麦麦关闭失败: {e}", exc_info=True)
def check_eula():
eula_confirm_file = Path("eula.confirmed")
privacy_confirm_file = Path("privacy.confirmed")
eula_file = Path("EULA.md")
privacy_file = Path("PRIVACY.md")
def _calculate_file_hash(file_path: Path, file_type: str) -> str:
    """Compute the MD5 hash of a text file's content.

    Args:
        file_path: path of the file to hash.
        file_type: human-readable label used in error messages.

    Returns:
        Hex digest of the UTF-8 encoded file content.

    Raises:
        FileNotFoundError: if the file does not exist.
    """
    if not file_path.exists():
        logger.error(f"{file_type} 文件不存在")
        raise FileNotFoundError(f"{file_type} 文件不存在")
    # NOTE(review): stray `eula_updated = True` / `privacy_updated = True`
    # assignments left over from the pre-refactor check_eula() were removed
    # from this function body — they belong to the caller, not the hasher.
    with open(file_path, "r", encoding="utf-8") as f:
        content = f.read()
    return hashlib.md5(content.encode("utf-8")).hexdigest()
eula_confirmed = False
privacy_confirmed = False
# 首先计算当前EULA文件的哈希值
if eula_file.exists():
with open(eula_file, "r", encoding="utf-8") as f:
eula_content = f.read()
eula_new_hash = hashlib.md5(eula_content.encode("utf-8")).hexdigest()
else:
logger.error("EULA.md 文件不存在")
raise FileNotFoundError("EULA.md 文件不存在")
def _check_agreement_status(file_hash: str, confirm_file: Path, env_var: str) -> tuple[bool, bool]:
"""检查协议确认状态
# 首先计算当前隐私条款文件的哈希值
if privacy_file.exists():
with open(privacy_file, "r", encoding="utf-8") as f:
privacy_content = f.read()
privacy_new_hash = hashlib.md5(privacy_content.encode("utf-8")).hexdigest()
else:
logger.error("PRIVACY.md 文件不存在")
raise FileNotFoundError("PRIVACY.md 文件不存在")
Returns:
tuple[bool, bool]: (已确认, 未更新)
"""
# 检查环境变量确认
if file_hash == os.getenv(env_var):
return True, False
# 检查EULA确认文件是否存在
if eula_confirm_file.exists():
with open(eula_confirm_file, "r", encoding="utf-8") as f:
# 检查确认文件
if confirm_file.exists():
with open(confirm_file, "r", encoding="utf-8") as f:
confirmed_content = f.read()
if eula_new_hash == confirmed_content:
eula_confirmed = True
eula_updated = False
if eula_new_hash == os.getenv("EULA_AGREE"):
eula_confirmed = True
eula_updated = False
if file_hash == confirmed_content:
return True, False
# 检查隐私条款确认文件是否存在
if privacy_confirm_file.exists():
with open(privacy_confirm_file, "r", encoding="utf-8") as f:
confirmed_content = f.read()
if privacy_new_hash == confirmed_content:
privacy_confirmed = True
privacy_updated = False
if privacy_new_hash == os.getenv("PRIVACY_AGREE"):
privacy_confirmed = True
privacy_updated = False
return False, True
# 如果EULA或隐私条款有更新提示用户重新确认
if eula_updated or privacy_updated:
def _prompt_user_confirmation(eula_hash: str, privacy_hash: str) -> None:
"""提示用户确认协议"""
confirm_logger.critical("EULA或隐私条款内容已更新请在阅读后重新确认继续运行视为同意更新后的以上两款协议")
confirm_logger.critical(
f'输入"同意""confirmed"或设置环境变量"EULA_AGREE={eula_new_hash}""PRIVACY_AGREE={privacy_new_hash}"继续运行'
f'输入"同意""confirmed"或设置环境变量"EULA_AGREE={eula_hash}""PRIVACY_AGREE={privacy_hash}"继续运行'
)
while True:
user_input = input().strip().lower()
if user_input in ["同意", "confirmed"]:
# print("确认成功,继续运行")
# print(f"确认成功,继续运行{eula_updated} {privacy_updated}")
if eula_updated:
logger.info(f"更新EULA确认文件{eula_new_hash}")
eula_confirm_file.write_text(eula_new_hash, encoding="utf-8")
if privacy_updated:
logger.info(f"更新隐私条款确认文件{privacy_new_hash}")
privacy_confirm_file.write_text(privacy_new_hash, encoding="utf-8")
break
else:
return
confirm_logger.critical('请输入"同意""confirmed"以继续运行')
def _save_confirmations(eula_updated: bool, privacy_updated: bool, eula_hash: str, privacy_hash: str) -> None:
    """Persist the user's agreement by writing current document hashes to marker files.

    Only documents flagged as updated get their marker file rewritten.
    """
    pending = []
    if eula_updated:
        pending.append((f"更新EULA确认文件{eula_hash}", Path("eula.confirmed"), eula_hash))
    if privacy_updated:
        pending.append((f"更新隐私条款确认文件{privacy_hash}", Path("privacy.confirmed"), privacy_hash))
    for message, marker, digest in pending:
        logger.info(message)
        marker.write_text(digest, encoding="utf-8")
def check_eula():
    """Check confirmation status of the EULA and the privacy policy.

    Computes the current document hashes, compares them against previously
    recorded confirmations (environment variable or marker file), and — when
    either document changed — prompts the user to re-confirm and persists the
    new hashes.

    Raises:
        FileNotFoundError: if EULA.md or PRIVACY.md is missing.
    """
    # Hash the current documents.
    eula_hash = _calculate_file_hash(Path("EULA.md"), "EULA.md")
    privacy_hash = _calculate_file_hash(Path("PRIVACY.md"), "PRIVACY.md")

    # Check prior confirmation for each document.
    eula_confirmed, eula_updated = _check_agreement_status(eula_hash, Path("eula.confirmed"), "EULA_AGREE")
    privacy_confirmed, privacy_updated = _check_agreement_status(
        privacy_hash, Path("privacy.confirmed"), "PRIVACY_AGREE"
    )

    # Early return: both documents already confirmed and unchanged.
    # (A duplicated `elif` branch with the identical condition was removed —
    # it could never be reached.)
    if eula_confirmed and privacy_confirmed:
        return

    # One or both documents changed: ask the user, then record the new hashes.
    if eula_updated or privacy_updated:
        _prompt_user_confirmation(eula_hash, privacy_hash)
        _save_confirmations(eula_updated, privacy_updated, eula_hash, privacy_hash)
def raw_main():
# 利用 TZ 环境变量设定程序工作的时区
if platform.system().lower() != "windows":
time.tzset()
time.tzset() # type: ignore
check_eula()
logger.info("检查EULA和隐私条款完成")

View File

@@ -21,3 +21,6 @@
- `config_api.py`中的`get_global_config``get_plugin_config`方法现在支持嵌套访问的配置键名。
- `database_api.py`中的`db_query`方法调整了参数顺序以增强参数限制的同时保证了typing正确`db_get`方法增加了`single_result`参数,与`db_query`保持一致。
4. 现在增加了参数类型检查,完善了对应注释
5. 现在插件抽象出了总基类 `PluginBase`
- 基于`Action``Command`的插件基类现在为`BasePlugin`,它继承自`PluginBase`,由`register_plugin`装饰器注册。
- 基于`Event`的插件基类现在为`BaseEventPlugin`,它也继承自`PluginBase`,由`register_event_plugin`装饰器注册。

View File

@@ -15,6 +15,7 @@ from src.chat.knowledge.kg_manager import KGManager
from src.common.logger import get_logger
from src.chat.knowledge.utils.hash import get_sha256
from src.manager.local_store_manager import local_storage
from dotenv import load_dotenv
# 添加项目根目录到 sys.path
@@ -23,6 +24,45 @@ OPENIE_DIR = os.path.join(ROOT_PATH, "data", "openie")
logger = get_logger("OpenIE导入")
ENV_FILE = os.path.join(ROOT_PATH, ".env")
if os.path.exists(".env"):
load_dotenv(".env", override=True)
print("成功加载环境变量配置")
else:
print("未找到.env文件请确保程序所需的环境变量被正确设置")
raise FileNotFoundError(".env 文件不存在,请创建并配置所需的环境变量")
env_mask = {key: os.getenv(key) for key in os.environ}
def scan_provider(env_config: dict):
    """Validate that every model provider has both a BASE_URL and a KEY.

    Scans environment variables shaped like ``{PROVIDER}_BASE_URL`` /
    ``{PROVIDER}_KEY`` and raises when either half of a pair is missing.

    Args:
        env_config: mapping of environment variable names to values.

    Raises:
        ValueError: if a provider defines only one of BASE_URL / KEY.
    """
    provider: dict = {}
    # Drop variables that already existed before .env was loaded (captured in
    # env_mask), so unrelated variables such as GPG_KEY cannot skew the check.
    env_config = {k: v for k, v in env_config.items() if k not in env_mask}
    for key, value in env_config.items():
        # Only keys of the form {provider}_BASE_URL or {provider}_KEY matter.
        if key.endswith("_BASE_URL") or key.endswith("_KEY"):
            provider_name = key.split("_", 1)[0]  # text before the first "_"
            entry = provider.setdefault(provider_name, {"url": None, "key": None})
            if key.endswith("_BASE_URL"):
                entry["url"] = value
            else:  # must end with "_KEY"
                entry["key"] = value
    # A provider with only half of its configuration is a fatal mistake.
    for provider_name, config in provider.items():
        if config["url"] is None or config["key"] is None:
            logger.error(f"provider 内容:{config}\nenv_config 内容:{env_config}")
            raise ValueError(f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量")
def ensure_openie_dir():
"""确保OpenIE数据目录存在"""
@@ -174,6 +214,8 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k
def main(): # sourcery skip: dict-comprehension
# 新增确认提示
env_config = {key: os.getenv(key) for key in os.environ}
scan_provider(env_config)
print("=== 重要操作确认 ===")
print("OpenIE导入时会大量发送请求可能会撞到请求速度上限请注意选用的模型")
print("同之前样例在本地模型下在70分钟内我们发送了约8万条请求在网络允许下速度会更快")

View File

@@ -27,6 +27,7 @@ from rich.progress import (
from raw_data_preprocessor import RAW_DATA_PATH, load_raw_data
from src.config.config import global_config
from src.llm_models.utils_model import LLMRequest
from dotenv import load_dotenv
logger = get_logger("LPMM知识库-信息提取")
@@ -35,6 +36,45 @@ ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
TEMP_DIR = os.path.join(ROOT_PATH, "temp")
# IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data", "imported_lpmm_data")
OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie")
ENV_FILE = os.path.join(ROOT_PATH, ".env")
if os.path.exists(".env"):
load_dotenv(".env", override=True)
print("成功加载环境变量配置")
else:
print("未找到.env文件请确保程序所需的环境变量被正确设置")
raise FileNotFoundError(".env 文件不存在,请创建并配置所需的环境变量")
env_mask = {key: os.getenv(key) for key in os.environ}
def scan_provider(env_config: dict):
    """Verify that each LLM provider's BASE_URL/KEY environment pair is complete."""
    providers = {}
    # env_mask holds the variables that existed before .env was loaded; use it
    # to exclude pre-existing names such as GPG_KEY from the scan.
    env_config = dict((name, val) for name, val in env_config.items() if name not in env_mask)
    for name in env_config:
        is_url = name.endswith("_BASE_URL")
        is_key = name.endswith("_KEY")
        if not (is_url or is_key):
            continue
        prefix = name.split("_", 1)[0]  # provider name = text before first "_"
        if prefix not in providers:
            providers[prefix] = {"url": None, "key": None}
        if is_url:
            providers[prefix]["url"] = env_config[name]
        elif is_key:
            providers[prefix]["key"] = env_config[name]
    # Both halves of the configuration must be present for every provider.
    for provider_name, config in providers.items():
        if config["url"] is None or config["key"] is None:
            logger.error(f"provider 内容:{config}\nenv_config 内容:{env_config}")
            raise ValueError(f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量")
def ensure_dirs():
"""确保临时目录和输出目录存在"""
@@ -118,6 +158,8 @@ def main(): # sourcery skip: comprehension-to-generator, extract-method
# 设置信号处理器
signal.signal(signal.SIGINT, signal_handler)
ensure_dirs() # 确保目录存在
env_config = {key: os.getenv(key) for key in os.environ}
scan_provider(env_config)
# 新增用户确认提示
print("=== 重要操作确认,请认真阅读以下内容哦 ===")
print("实体提取操作将会花费较多api余额和时间建议在空闲时段执行。")

View File

@@ -107,11 +107,12 @@ class ExpressionLearner:
last_active_time = expr.get("last_active_time", time.time())
# 查重同chat_id+type+situation+style
from src.common.database.database_model import Expression
query = Expression.select().where(
(Expression.chat_id == chat_id) &
(Expression.type == type_str) &
(Expression.situation == situation) &
(Expression.style == style_val)
(Expression.chat_id == chat_id)
& (Expression.type == type_str)
& (Expression.situation == situation)
& (Expression.style == style_val)
)
if query.exists():
expr_obj = query.get()
@@ -125,7 +126,7 @@ class ExpressionLearner:
count=count,
last_active_time=last_active_time,
chat_id=chat_id,
type=type_str
type=type_str,
)
logger.info(f"已迁移 {expr_file} 到数据库")
except Exception as e:
@@ -149,24 +150,28 @@ class ExpressionLearner:
# 直接从数据库查询
style_query = Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == "style"))
for expr in style_query:
learnt_style_expressions.append({
learnt_style_expressions.append(
{
"situation": expr.situation,
"style": expr.style,
"count": expr.count,
"last_active_time": expr.last_active_time,
"source_id": chat_id,
"type": "style"
})
"type": "style",
}
)
grammar_query = Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == "grammar"))
for expr in grammar_query:
learnt_grammar_expressions.append({
learnt_grammar_expressions.append(
{
"situation": expr.situation,
"style": expr.style,
"count": expr.count,
"last_active_time": expr.last_active_time,
"source_id": chat_id,
"type": "grammar"
})
"type": "grammar",
}
)
return learnt_style_expressions, learnt_grammar_expressions
def is_similar(self, s1: str, s2: str) -> bool:
@@ -213,14 +218,16 @@ class ExpressionLearner:
logger.error(f"全局衰减{type}表达方式失败: {e}")
continue
learnt_style: Optional[List[Tuple[str, str, str]]] = []
learnt_grammar: Optional[List[Tuple[str, str, str]]] = []
# 学习新的表达方式(这里会进行局部衰减)
for _ in range(3):
learnt_style: Optional[List[Tuple[str, str, str]]] = await self.learn_and_store(type="style", num=25)
learnt_style = await self.learn_and_store(type="style", num=25)
if not learnt_style:
return [], []
for _ in range(1):
learnt_grammar: Optional[List[Tuple[str, str, str]]] = await self.learn_and_store(type="grammar", num=10)
learnt_grammar = await self.learn_and_store(type="grammar", num=10)
if not learnt_grammar:
return [], []
@@ -321,10 +328,10 @@ class ExpressionLearner:
for new_expr in expr_list:
# 查找是否已存在相似表达方式
query = Expression.select().where(
(Expression.chat_id == chat_id) &
(Expression.type == type) &
(Expression.situation == new_expr["situation"]) &
(Expression.style == new_expr["style"])
(Expression.chat_id == chat_id)
& (Expression.type == type)
& (Expression.situation == new_expr["situation"])
& (Expression.style == new_expr["style"])
)
if query.exists():
expr_obj = query.get()
@@ -342,10 +349,14 @@ class ExpressionLearner:
count=1,
last_active_time=current_time,
chat_id=chat_id,
type=type
type=type,
)
# 限制最大数量
exprs = list(Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == type)).order_by(Expression.count.asc()))
exprs = list(
Expression.select()
.where((Expression.chat_id == chat_id) & (Expression.type == type))
.order_by(Expression.count.asc())
)
if len(exprs) > MAX_EXPRESSION_COUNT:
# 删除count最小的多余表达方式
for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]:

View File

@@ -2,6 +2,7 @@ from dataclasses import dataclass
import json
import os
import math
import asyncio
from typing import Dict, List, Tuple
import numpy as np
@@ -99,7 +100,30 @@ class EmbeddingStore:
self.idx2hash = None
def _get_embedding(self, s: str) -> List[float]:
    """Get the embedding vector for a string, bridging sync and async code.

    ``get_embedding`` is a coroutine; this wrapper makes it callable from
    synchronous code whether or not an event loop is already running.
    (A leftover ``return get_embedding(s)`` line from the pre-async version
    was removed — it made the rest of the body unreachable.)

    Returns:
        The embedding vector, or an empty list if the lookup failed.
    """
    try:
        # If a loop is already running, asyncio.run() would raise, so drive
        # the coroutine in a fresh loop on a worker thread instead.
        asyncio.get_running_loop()
    except RuntimeError:
        # No running loop: run the coroutine directly.
        result = asyncio.run(get_embedding(s))
    else:
        import concurrent.futures

        with concurrent.futures.ThreadPoolExecutor() as executor:
            result = executor.submit(lambda: asyncio.run(get_embedding(s))).result()
    if result is None:
        logger.error(f"获取嵌入失败: {s}")
        return []
    return result
def get_test_file_path(self):
return EMBEDDING_TEST_FILE

View File

@@ -1,3 +1,4 @@
import asyncio
import json
import time
from typing import List, Union
@@ -7,8 +8,12 @@ from . import prompt_template
from .knowledge_lib import INVALID_ENTITY
from src.llm_models.utils_model import LLMRequest
from json_repair import repair_json
def _extract_json_from_text(text: str) -> dict:
def _extract_json_from_text(text: str):
"""从文本中提取JSON数据的高容错方法"""
if text is None:
logger.error("输入文本为None")
return []
try:
fixed_json = repair_json(text)
if isinstance(fixed_json, str):
@@ -16,23 +21,66 @@ def _extract_json_from_text(text: str) -> dict:
else:
parsed_json = fixed_json
if isinstance(parsed_json, list) and parsed_json:
parsed_json = parsed_json[0]
if isinstance(parsed_json, dict):
# 如果是列表,直接返回
if isinstance(parsed_json, list):
return parsed_json
# 如果是字典且只有一个项目,可能包装了列表
if isinstance(parsed_json, dict):
# 如果字典只有一个键,并且值是列表,返回那个列表
if len(parsed_json) == 1:
value = list(parsed_json.values())[0]
if isinstance(value, list):
return value
return parsed_json
# 其他情况,尝试转换为列表
logger.warning(f"解析的JSON不是预期格式: {type(parsed_json)}, 内容: {parsed_json}")
return []
except Exception as e:
logger.error(f"JSON提取失败: {e}, 原始文本: {text[:100]}...")
logger.error(f"JSON提取失败: {e}, 原始文本: {text[:100] if text else 'None'}...")
return []
def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
"""对段落进行实体提取返回提取出的实体列表JSON格式"""
entity_extract_context = prompt_template.build_entity_extract_context(paragraph)
response, (reasoning_content, model_name) = llm_req.generate_response_async(entity_extract_context)
# 使用 asyncio.run 来运行异步方法
try:
# 如果当前已有事件循环在运行,使用它
loop = asyncio.get_running_loop()
future = asyncio.run_coroutine_threadsafe(
llm_req.generate_response_async(entity_extract_context), loop
)
response, (reasoning_content, model_name) = future.result()
except RuntimeError:
# 如果没有运行中的事件循环,直接使用 asyncio.run
response, (reasoning_content, model_name) = asyncio.run(
llm_req.generate_response_async(entity_extract_context)
)
# 添加调试日志
logger.debug(f"LLM返回的原始响应: {response}")
entity_extract_result = _extract_json_from_text(response)
# 尝试load JSON数据
json.loads(entity_extract_result)
# 检查返回的是否为有效的实体列表
if not isinstance(entity_extract_result, list):
# 如果不是列表,可能是字典格式,尝试从中提取列表
if isinstance(entity_extract_result, dict):
# 尝试常见的键名
for key in ['entities', 'result', 'data', 'items']:
if key in entity_extract_result and isinstance(entity_extract_result[key], list):
entity_extract_result = entity_extract_result[key]
break
else:
# 如果找不到合适的列表,抛出异常
raise Exception(f"实体提取结果格式错误,期望列表但得到: {type(entity_extract_result)}")
else:
raise Exception(f"实体提取结果格式错误,期望列表但得到: {type(entity_extract_result)}")
# 过滤无效实体
entity_extract_result = [
entity
for entity in entity_extract_result
@@ -50,16 +98,47 @@ def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) ->
rdf_extract_context = prompt_template.build_rdf_triple_extract_context(
paragraph, entities=json.dumps(entities, ensure_ascii=False)
)
response, (reasoning_content, model_name) = llm_req.generate_response_async(rdf_extract_context)
entity_extract_result = _extract_json_from_text(response)
# 尝试load JSON数据
json.loads(entity_extract_result)
for triple in entity_extract_result:
if len(triple) != 3 or (triple[0] is None or triple[1] is None or triple[2] is None) or "" in triple:
# 使用 asyncio.run 来运行异步方法
try:
# 如果当前已有事件循环在运行,使用它
loop = asyncio.get_running_loop()
future = asyncio.run_coroutine_threadsafe(
llm_req.generate_response_async(rdf_extract_context), loop
)
response, (reasoning_content, model_name) = future.result()
except RuntimeError:
# 如果没有运行中的事件循环,直接使用 asyncio.run
response, (reasoning_content, model_name) = asyncio.run(
llm_req.generate_response_async(rdf_extract_context)
)
# 添加调试日志
logger.debug(f"RDF LLM返回的原始响应: {response}")
rdf_triple_result = _extract_json_from_text(response)
# 检查返回的是否为有效的三元组列表
if not isinstance(rdf_triple_result, list):
# 如果不是列表,可能是字典格式,尝试从中提取列表
if isinstance(rdf_triple_result, dict):
# 尝试常见的键名
for key in ['triples', 'result', 'data', 'items']:
if key in rdf_triple_result and isinstance(rdf_triple_result[key], list):
rdf_triple_result = rdf_triple_result[key]
break
else:
# 如果找不到合适的列表,抛出异常
raise Exception(f"RDF三元组提取结果格式错误期望列表但得到: {type(rdf_triple_result)}")
else:
raise Exception(f"RDF三元组提取结果格式错误期望列表但得到: {type(rdf_triple_result)}")
# 验证三元组格式
for triple in rdf_triple_result:
if not isinstance(triple, list) or len(triple) != 3 or (triple[0] is None or triple[1] is None or triple[2] is None) or "" in triple:
raise Exception("RDF提取结果格式错误")
return entity_extract_result
return rdf_triple_result
def info_extract_from_str(

View File

@@ -184,10 +184,10 @@ class KGManager:
progress.update(task, advance=1)
continue
ent = embedding_manager.entities_embedding_store.store.get(ent_hash)
assert isinstance(ent, EmbeddingStoreItem)
if ent is None:
progress.update(task, advance=1)
continue
assert isinstance(ent, EmbeddingStoreItem)
# 查询相似实体
similar_ents = embedding_manager.entities_embedding_store.search_top_k(
ent.embedding, global_config["rag"]["params"]["synonym_search_top_k"]
@@ -265,7 +265,10 @@ class KGManager:
if node_hash not in existed_nodes:
if node_hash.startswith(local_storage['ent_namespace']):
# 新增实体节点
node = embedding_manager.entities_embedding_store.store[node_hash]
node = embedding_manager.entities_embedding_store.store.get(node_hash)
if node is None:
logger.warning(f"实体节点 {node_hash} 在嵌入库中不存在,跳过")
continue
assert isinstance(node, EmbeddingStoreItem)
node_item = self.graph[node_hash]
node_item["content"] = node.str
@@ -274,7 +277,10 @@ class KGManager:
self.graph.update_node(node_item)
elif node_hash.startswith(local_storage['pg_namespace']):
# 新增文段节点
node = embedding_manager.paragraphs_embedding_store.store[node_hash]
node = embedding_manager.paragraphs_embedding_store.store.get(node_hash)
if node is None:
logger.warning(f"段落节点 {node_hash} 在嵌入库中不存在,跳过")
continue
assert isinstance(node, EmbeddingStoreItem)
content = node.str.replace("\n", " ")
node_item = self.graph[node_hash]

View File

@@ -11,12 +11,14 @@ entity_extract_system_prompt = """你是一个性能优异的实体提取系统
"""
def build_entity_extract_context(paragraph: str) -> list[LLMMessage]:
messages = [
LLMMessage("system", entity_extract_system_prompt).to_dict(),
LLMMessage("user", f"""段落:\n```\n{paragraph}```""").to_dict(),
]
return messages
def build_entity_extract_context(paragraph: str) -> str:
    """Build the complete prompt text for entity extraction.

    The system prompt and the fenced paragraph are joined into one plain
    string — the LLM request layer now accepts raw text instead of a
    system/user message list.
    """
    return f"""{entity_extract_system_prompt}
段落:
```
{paragraph}
```"""
rdf_triple_extract_system_prompt = """你是一个性能优异的RDF资源描述框架由节点和边组成节点表示实体/资源、属性边则表示了实体和实体之间的关系以及实体和属性的关系。构造系统。你的任务是根据给定的段落和实体列表构建RDF图。
@@ -36,12 +38,19 @@ rdf_triple_extract_system_prompt = """你是一个性能优异的RDF资源描
"""
def build_rdf_triple_extract_context(paragraph: str, entities: str) -> list[LLMMessage]:
messages = [
LLMMessage("system", rdf_triple_extract_system_prompt).to_dict(),
LLMMessage("user", f"""段落:\n```\n{paragraph}```\n\n实体列表:\n```\n{entities}```""").to_dict(),
]
return messages
def build_rdf_triple_extract_context(paragraph: str, entities: str) -> str:
    """Build the complete prompt text for RDF triple extraction.

    Args:
        paragraph: source paragraph to extract triples from.
        entities: JSON-encoded entity list, presumably produced by the
            entity-extraction step — TODO confirm against caller.
    """
    return f"""{rdf_triple_extract_system_prompt}
段落:
```
{paragraph}
```
实体列表:
```
{entities}
```"""
qa_system_prompt = """

View File

@@ -13,6 +13,7 @@ from src.common.database.database_model import Memory # Peewee Models导入
logger = get_logger(__name__)
class MemoryItem:
def __init__(self, memory_id: str, chat_id: str, memory_text: str, keywords: list[str]):
self.memory_id = memory_id
@@ -22,15 +23,13 @@ class MemoryItem:
self.create_time: float = time.time()
self.last_view_time: float = time.time()
class MemoryManager:
    """Manager for memory items.

    Currently a stateless placeholder: construction performs no work and the
    class holds no attributes.
    """

    def __init__(self) -> None:
        # Intentionally empty; per-chat state lives in InstantMemory instead.
        pass
class InstantMemory:
def __init__(self, chat_id):
self.chat_id = chat_id
@@ -53,7 +52,6 @@ class InstantMemory:
print(prompt)
print(response)
if "1" in response:
return True
else:
@@ -81,15 +79,15 @@ class InstantMemory:
try:
repaired = repair_json(response)
result = json.loads(repaired)
memory_text = result.get('memory_text', '')
keywords = result.get('keywords', '')
memory_text = result.get("memory_text", "")
keywords = result.get("keywords", "")
if isinstance(keywords, str):
keywords_list = [k.strip() for k in keywords.split('/') if k.strip()]
keywords_list = [k.strip() for k in keywords.split("/") if k.strip()]
elif isinstance(keywords, list):
keywords_list = keywords
else:
keywords_list = []
return {'memory_text': memory_text, 'keywords': keywords_list}
return {"memory_text": memory_text, "keywords": keywords_list}
except Exception as parse_e:
logger.error(f"解析记忆json失败{str(parse_e)} {traceback.format_exc()}")
return None
@@ -97,19 +95,18 @@ class InstantMemory:
logger.error(f"构建记忆出现错误:{str(e)} {traceback.format_exc()}")
return None
async def create_and_store_memory(self, text):
if_need = await self.if_need_build(text)
if if_need:
logger.info(f"需要记忆:{text}")
memory = await self.build_memory(text)
if memory and memory.get('memory_text'):
if memory and memory.get("memory_text"):
memory_id = f"{self.chat_id}_{time.time()}"
memory_item = MemoryItem(
memory_id=memory_id,
chat_id=self.chat_id,
memory_text=memory['memory_text'],
keywords=memory.get('keywords', [])
memory_text=memory["memory_text"],
keywords=memory.get("keywords", []),
)
await self.store_memory(memory_item)
else:
@@ -122,12 +119,13 @@ class InstantMemory:
memory_text=memory_item.memory_text,
keywords=memory_item.keywords,
create_time=memory_item.create_time,
last_view_time=memory_item.last_view_time
last_view_time=memory_item.last_view_time,
)
memory.save()
async def get_memory(self, target: str):
from json_repair import repair_json
prompt = f"""
请根据以下发言内容,判断是否需要提取记忆
{target}
@@ -153,15 +151,15 @@ class InstantMemory:
repaired = repair_json(response)
result = json.loads(repaired)
# 解析keywords
keywords = result.get('keywords', '')
keywords = result.get("keywords", "")
if isinstance(keywords, str):
keywords_list = [k.strip() for k in keywords.split('/') if k.strip()]
keywords_list = [k.strip() for k in keywords.split("/") if k.strip()]
elif isinstance(keywords, list):
keywords_list = keywords
else:
keywords_list = []
# 解析time为时间段
time_str = result.get('time', '').strip()
time_str = result.get("time", "").strip()
start_time, end_time = self._parse_time_range(time_str)
logger.info(f"start_time: {start_time}, end_time: {end_time}")
# 检索包含关键词的记忆
@@ -170,14 +168,13 @@ class InstantMemory:
start_ts = start_time.timestamp()
end_ts = end_time.timestamp()
query = Memory.select().where(
(Memory.chat_id == self.chat_id) &
(Memory.create_time >= start_ts) &
(Memory.create_time < end_ts)
(Memory.chat_id == self.chat_id)
& (Memory.create_time >= start_ts) # type: ignore
& (Memory.create_time < end_ts) # type: ignore
)
else:
query = Memory.select().where(Memory.chat_id == self.chat_id)
for mem in query:
# 对每条记忆
mem_keywords = mem.keywords or []
@@ -212,6 +209,7 @@ class InstantMemory:
- 空字符串:返回(None, None)
"""
from datetime import datetime, timedelta
now = datetime.now()
if not time_str:
return 0, now

View File

@@ -30,7 +30,7 @@ class ChatMessageContext:
def get_template_name(self) -> Optional[str]:
    """Return the configured template name, or None when the default template applies.

    (A duplicated return line — the old and new sides of a merge — was
    collapsed into one; the `# type: ignore` variant is kept.)
    """
    template_info = self.message.message_info.template_info
    # Only a non-default template has a meaningful name to report.
    if template_info and not template_info.template_default:
        return template_info.template_name  # type: ignore
    return None
def get_last_message(self) -> "MessageRecv":

View File

@@ -181,6 +181,7 @@ class MessageRecv(Message):
logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}")
return f"[处理失败的{segment.type}消息]"
@dataclass
class MessageRecvS4U(MessageRecv):
def __init__(self, message_dict: dict[str, Any]):
@@ -254,7 +255,7 @@ class MessageRecvS4U(MessageRecv):
elif segment.type == "gift":
self.is_gift = True
# 解析gift_info格式为"名称:数量"
name, count = segment.data.split(":", 1)
name, count = segment.data.split(":", 1) # type: ignore
self.gift_info = segment.data
self.gift_name = name.strip()
self.gift_count = int(count.strip())
@@ -267,12 +268,14 @@ class MessageRecvS4U(MessageRecv):
elif segment.type == "superchat":
self.is_superchat = True
self.superchat_info = segment.data
price,message_text = segment.data.split(":", 1)
price, message_text = segment.data.split(":", 1) # type: ignore
self.superchat_price = price.strip()
self.superchat_message_text = message_text.strip()
self.processed_plain_text = str(self.superchat_message_text)
self.processed_plain_text += f"(注意:这是一条超级弹幕信息,价值{self.superchat_price}元,请你认真回复)"
self.processed_plain_text += (
f"(注意:这是一条超级弹幕信息,价值{self.superchat_price}元,请你认真回复)"
)
return self.processed_plain_text
elif segment.type == "screen":

View File

@@ -80,7 +80,7 @@ class ActionManager:
chat_stream: ChatStream,
log_prefix: str,
shutting_down: bool = False,
action_message: dict = None,
action_message: Optional[dict] = None,
) -> Optional[BaseAction]:
"""
创建动作处理器实例

View File

@@ -252,7 +252,7 @@ def _build_readable_messages_internal(
pic_id_mapping: Optional[Dict[str, str]] = None,
pic_counter: int = 1,
show_pic: bool = True,
message_id_list: List[Dict[str, Any]] = None,
message_id_list: Optional[List[Dict[str, Any]]] = None,
) -> Tuple[str, List[Tuple[float, str, str]], Dict[str, str], int]:
"""
内部辅助函数,构建可读消息字符串和原始消息详情列表。
@@ -615,7 +615,7 @@ def build_readable_actions(actions: List[Dict[str, Any]]) -> str:
for action in actions:
action_time = action.get("time", current_time)
action_name = action.get("action_name", "未知动作")
if action_name == "no_action" or action_name == "no_reply":
if action_name in ["no_action", "no_reply"]:
continue
action_prompt_display = action.get("action_prompt_display", "无具体内容")
@@ -697,7 +697,7 @@ def build_readable_messages(
truncate: bool = False,
show_actions: bool = False,
show_pic: bool = True,
message_id_list: List[Dict[str, Any]] = None,
message_id_list: Optional[List[Dict[str, Any]]] = None,
) -> str: # sourcery skip: extract-method
"""
将消息列表转换为可读的文本格式。

View File

@@ -1211,7 +1211,7 @@ class StatisticOutputTask(AsyncTask):
f.write(html_template)
def _generate_focus_tab(self, stat: dict[str, Any]) -> str:
# sourcery skip: for-append-to-extend, list-comprehension, use-any
# sourcery skip: for-append-to-extend, list-comprehension, use-any, use-named-expression, use-next
"""生成Focus统计独立分页的HTML内容"""
# 为每个时间段准备Focus数据
@@ -1559,6 +1559,7 @@ class StatisticOutputTask(AsyncTask):
"""
def _generate_versions_tab(self, stat: dict[str, Any]) -> str:
# sourcery skip: use-named-expression, use-next
"""生成版本对比独立分页的HTML内容"""
# 为每个时间段准备版本对比数据
@@ -2306,13 +2307,13 @@ class AsyncStatisticOutputTask(AsyncTask):
# 复用 StatisticOutputTask 的所有方法
def _collect_all_statistics(self, now: datetime):
return StatisticOutputTask._collect_all_statistics(self, now)
return StatisticOutputTask._collect_all_statistics(self, now) # type: ignore
def _statistic_console_output(self, stats: Dict[str, Any], now: datetime):
return StatisticOutputTask._statistic_console_output(self, stats, now)
return StatisticOutputTask._statistic_console_output(self, stats, now) # type: ignore
def _generate_html_report(self, stats: dict[str, Any], now: datetime):
return StatisticOutputTask._generate_html_report(self, stats, now)
return StatisticOutputTask._generate_html_report(self, stats, now) # type: ignore
# 其他需要的方法也可以类似复用...
@staticmethod
@@ -2324,10 +2325,10 @@ class AsyncStatisticOutputTask(AsyncTask):
return StatisticOutputTask._collect_online_time_for_period(collect_period, now)
def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
return StatisticOutputTask._collect_message_count_for_period(self, collect_period)
return StatisticOutputTask._collect_message_count_for_period(self, collect_period) # type: ignore
def _collect_focus_statistics_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
return StatisticOutputTask._collect_focus_statistics_for_period(self, collect_period)
return StatisticOutputTask._collect_focus_statistics_for_period(self, collect_period) # type: ignore
def _process_focus_file_data(
self,
@@ -2336,10 +2337,10 @@ class AsyncStatisticOutputTask(AsyncTask):
collect_period: List[Tuple[str, datetime]],
file_time: datetime,
):
return StatisticOutputTask._process_focus_file_data(self, cycles_data, stats, collect_period, file_time)
return StatisticOutputTask._process_focus_file_data(self, cycles_data, stats, collect_period, file_time) # type: ignore
def _calculate_focus_averages(self, stats: Dict[str, Any]):
return StatisticOutputTask._calculate_focus_averages(self, stats)
return StatisticOutputTask._calculate_focus_averages(self, stats) # type: ignore
@staticmethod
def _format_total_stat(stats: Dict[str, Any]) -> str:
@@ -2347,31 +2348,31 @@ class AsyncStatisticOutputTask(AsyncTask):
@staticmethod
def _format_model_classified_stat(stats: Dict[str, Any]) -> str:
return StatisticOutputTask._format_model_classified_stat(stats)
return StatisticOutputTask._format_model_classified_stat(stats) # type: ignore
def _format_chat_stat(self, stats: Dict[str, Any]) -> str:
return StatisticOutputTask._format_chat_stat(self, stats)
return StatisticOutputTask._format_chat_stat(self, stats) # type: ignore
def _format_focus_stat(self, stats: Dict[str, Any]) -> str:
return StatisticOutputTask._format_focus_stat(self, stats)
return StatisticOutputTask._format_focus_stat(self, stats) # type: ignore
def _generate_chart_data(self, stat: dict[str, Any]) -> dict:
return StatisticOutputTask._generate_chart_data(self, stat)
return StatisticOutputTask._generate_chart_data(self, stat) # type: ignore
def _collect_interval_data(self, now: datetime, hours: int, interval_minutes: int) -> dict:
return StatisticOutputTask._collect_interval_data(self, now, hours, interval_minutes)
return StatisticOutputTask._collect_interval_data(self, now, hours, interval_minutes) # type: ignore
def _generate_chart_tab(self, chart_data: dict) -> str:
return StatisticOutputTask._generate_chart_tab(self, chart_data)
return StatisticOutputTask._generate_chart_tab(self, chart_data) # type: ignore
def _get_chat_display_name_from_id(self, chat_id: str) -> str:
return StatisticOutputTask._get_chat_display_name_from_id(self, chat_id)
return StatisticOutputTask._get_chat_display_name_from_id(self, chat_id) # type: ignore
def _generate_focus_tab(self, stat: dict[str, Any]) -> str:
return StatisticOutputTask._generate_focus_tab(self, stat)
return StatisticOutputTask._generate_focus_tab(self, stat) # type: ignore
def _generate_versions_tab(self, stat: dict[str, Any]) -> str:
return StatisticOutputTask._generate_versions_tab(self, stat)
return StatisticOutputTask._generate_versions_tab(self, stat) # type: ignore
def _convert_defaultdict_to_dict(self, data):
return StatisticOutputTask._convert_defaultdict_to_dict(self, data)
return StatisticOutputTask._convert_defaultdict_to_dict(self, data) # type: ignore

View File

@@ -2,14 +2,13 @@ import importlib
import asyncio
from abc import ABC, abstractmethod
from typing import Dict, Optional
from typing import Dict, Optional, Any
from rich.traceback import install
from dataclasses import dataclass
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.message_receive.chat_stream import ChatStream, GroupInfo
from src.chat.message_receive.message import MessageRecv
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
install(extra_lines=3)
@@ -54,7 +53,7 @@ class WillingInfo:
interested_rate (float): 兴趣度
"""
message: MessageRecv
message: Dict[str, Any] # 原始消息数据
chat: ChatStream
person_info_manager: PersonInfoManager
chat_id: str

View File

@@ -65,7 +65,7 @@ class ChatStreams(BaseModel):
# user_cardname 可能为空字符串或不存在,设置 null=True 更具灵活性。
user_cardname = TextField(null=True)
class Meta:
class Meta: # type: ignore
# 如果 BaseModel.Meta.database 已设置,则此模型将继承该数据库配置。
# 如果不使用带有数据库实例的 BaseModel或者想覆盖它
# 请取消注释并在下面设置数据库实例:
@@ -89,7 +89,7 @@ class LLMUsage(BaseModel):
status = TextField()
timestamp = DateTimeField(index=True) # 更改为 DateTimeField 并添加索引
class Meta:
class Meta: # type: ignore
# 如果 BaseModel.Meta.database 已设置,则此模型将继承该数据库配置。
# database = db
table_name = "llm_usage"
@@ -112,7 +112,7 @@ class Emoji(BaseModel):
usage_count = IntegerField(default=0) # 使用次数(被使用的次数)
last_used_time = FloatField(null=True) # 上次使用时间
class Meta:
class Meta: # type: ignore
# database = db # 继承自 BaseModel
table_name = "emoji"
@@ -162,7 +162,8 @@ class Messages(BaseModel):
is_emoji = BooleanField(default=False)
is_picid = BooleanField(default=False)
is_command = BooleanField(default=False)
class Meta:
class Meta: # type: ignore
# database = db # 继承自 BaseModel
table_name = "messages"
@@ -186,7 +187,7 @@ class ActionRecords(BaseModel):
chat_info_stream_id = TextField()
chat_info_platform = TextField()
class Meta:
class Meta: # type: ignore
# database = db # 继承自 BaseModel
table_name = "action_records"
@@ -206,7 +207,7 @@ class Images(BaseModel):
type = TextField() # 图像类型,例如 "emoji"
vlm_processed = BooleanField(default=False) # 是否已经过VLM处理
class Meta:
class Meta: # type: ignore
table_name = "images"
@@ -220,7 +221,7 @@ class ImageDescriptions(BaseModel):
description = TextField() # 图像的描述
timestamp = FloatField() # 时间戳
class Meta:
class Meta: # type: ignore
# database = db # 继承自 BaseModel
table_name = "image_descriptions"
@@ -236,7 +237,7 @@ class OnlineTime(BaseModel):
start_timestamp = DateTimeField(default=datetime.datetime.now)
end_timestamp = DateTimeField(index=True)
class Meta:
class Meta: # type: ignore
# database = db # 继承自 BaseModel
table_name = "online_time"
@@ -263,10 +264,11 @@ class PersonInfo(BaseModel):
last_know = FloatField(null=True) # 最后一次印象总结时间
attitude = IntegerField(null=True, default=50) # 态度0-100从非常厌恶到十分喜欢
class Meta:
class Meta: # type: ignore
# database = db # 继承自 BaseModel
table_name = "person_info"
class Memory(BaseModel):
memory_id = TextField(index=True)
chat_id = TextField(null=True)
@@ -275,9 +277,10 @@ class Memory(BaseModel):
create_time = FloatField(null=True)
last_view_time = FloatField(null=True)
class Meta:
class Meta: # type: ignore
table_name = "memory"
class Knowledges(BaseModel):
"""
用于存储知识库条目的模型。
@@ -287,10 +290,11 @@ class Knowledges(BaseModel):
embedding = TextField() # 知识内容的嵌入向量,存储为 JSON 字符串的浮点数列表
# 可以添加其他元数据字段,如 source, create_time 等
class Meta:
class Meta: # type: ignore
# database = db # 继承自 BaseModel
table_name = "knowledges"
class Expression(BaseModel):
"""
用于存储表达风格的模型。
@@ -303,9 +307,10 @@ class Expression(BaseModel):
chat_id = TextField(index=True)
type = TextField()
class Meta:
class Meta: # type: ignore
table_name = "expression"
class ThinkingLog(BaseModel):
chat_id = TextField(index=True)
trigger_text = TextField(null=True)
@@ -326,7 +331,7 @@ class ThinkingLog(BaseModel):
# And: import datetime
created_at = DateTimeField(default=datetime.datetime.now)
class Meta:
class Meta: # type: ignore
table_name = "thinking_logs"
@@ -341,7 +346,7 @@ class GraphNodes(BaseModel):
created_time = FloatField() # 创建时间戳
last_modified = FloatField() # 最后修改时间戳
class Meta:
class Meta: # type: ignore
table_name = "graph_nodes"
@@ -357,7 +362,7 @@ class GraphEdges(BaseModel):
created_time = FloatField() # 创建时间戳
last_modified = FloatField() # 最后修改时间戳
class Meta:
class Meta: # type: ignore
table_name = "graph_edges"

View File

@@ -7,13 +7,13 @@ from datetime import datetime
def get_key_comment(toml_table, key):
# 获取key的注释如果有
if hasattr(toml_table, 'trivia') and hasattr(toml_table.trivia, 'comment'):
if hasattr(toml_table, "trivia") and hasattr(toml_table.trivia, "comment"):
return toml_table.trivia.comment
if hasattr(toml_table, 'value') and isinstance(toml_table.value, dict):
if hasattr(toml_table, "value") and isinstance(toml_table.value, dict):
item = toml_table.value.get(key)
if item is not None and hasattr(item, 'trivia'):
if item is not None and hasattr(item, "trivia"):
return item.trivia.comment
if hasattr(toml_table, 'keys'):
if hasattr(toml_table, "keys"):
for k in toml_table.keys():
if isinstance(k, KeyType) and k.key == key:
return k.trivia.comment
@@ -95,7 +95,7 @@ def update_config():
if old_version and new_version and old_version == new_version:
print(f"检测到版本号相同 (v{old_version}),跳过更新")
# 如果version相同恢复旧配置文件并返回
shutil.move(old_backup_path, old_config_path)
shutil.move(old_backup_path, old_config_path) # type: ignore
return
else:
print(f"检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}")

View File

@@ -53,13 +53,13 @@ MMC_VERSION = "0.9.0-snapshot.2"
def get_key_comment(toml_table, key):
# 获取key的注释如果有
if hasattr(toml_table, 'trivia') and hasattr(toml_table.trivia, 'comment'):
if hasattr(toml_table, "trivia") and hasattr(toml_table.trivia, "comment"):
return toml_table.trivia.comment
if hasattr(toml_table, 'value') and isinstance(toml_table.value, dict):
if hasattr(toml_table, "value") and isinstance(toml_table.value, dict):
item = toml_table.value.get(key)
if item is not None and hasattr(item, 'trivia'):
if item is not None and hasattr(item, "trivia"):
return item.trivia.comment
if hasattr(toml_table, 'keys'):
if hasattr(toml_table, "keys"):
for k in toml_table.keys():
if isinstance(k, KeyType) and k.key == key:
return k.trivia.comment
@@ -99,6 +99,7 @@ def get_value_by_path(d, path):
return None
return d
def set_value_by_path(d, path, value):
for k in path[:-1]:
if k not in d or not isinstance(d[k], dict):
@@ -106,6 +107,7 @@ def set_value_by_path(d, path, value):
d = d[k]
d[path[-1]] = value
def compare_default_values(new, old, path=None, logs=None, changes=None):
# 递归比较两个dict找出默认值变化项
if path is None:
@@ -123,7 +125,9 @@ def compare_default_values(new, old, path=None, logs=None, changes=None):
else:
# 只要值发生变化就记录
if new[key] != old[key]:
logs.append(f"默认值变化: {'.'.join(path+[str(key)])} 旧默认值: {old[key]} 新默认值: {new[key]}")
logs.append(
f"默认值变化: {'.'.join(path + [str(key)])} 旧默认值: {old[key]} 新默认值: {new[key]}"
)
changes.append((path + [str(key)], old[key], new[key]))
return logs, changes
@@ -148,8 +152,8 @@ def update_config():
return None
with open(toml_path, "r", encoding="utf-8") as f:
doc = tomlkit.load(f)
if "inner" in doc and "version" in doc["inner"]:
return doc["inner"]["version"]
if "inner" in doc and "version" in doc["inner"]: # type: ignore
return doc["inner"]["version"] # type: ignore
return None
template_version = get_version_from_toml(template_path)
@@ -186,7 +190,9 @@ def update_config():
old_value = get_value_by_path(old_config, path)
if old_value == old_default:
set_value_by_path(old_config, path, new_default)
logger.info(f"已自动将配置 {'.'.join(path)} 的值从旧默认值 {old_default} 更新为新默认值 {new_default}")
logger.info(
f"已自动将配置 {'.'.join(path)} 的值从旧默认值 {old_default} 更新为新默认值 {new_default}"
)
else:
logger.info("未检测到模板默认值变动")
# 保存旧配置的变更(后续合并逻辑会用到 old_config
@@ -229,7 +235,9 @@ def update_config():
logger.info(f"检测到配置文件版本号相同 (v{old_version}),跳过更新")
return
else:
logger.info(f"\n----------------------------------------\n检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}\n----------------------------------------")
logger.info(
f"\n----------------------------------------\n检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}\n----------------------------------------"
)
else:
logger.info("已有配置文件未检测到版本号,可能是旧版本。将进行更新")
@@ -321,6 +329,7 @@ class Config(ConfigBase):
debug: DebugConfig
custom_prompt: CustomPromptConfig
def load_config(config_path: str) -> Config:
"""
加载配置文件

View File

@@ -39,7 +39,7 @@ class LLMRequestOff:
}
# 发送请求到完整的 chat/completions 端点
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
api_url = f"{self.base_url.rstrip('/')}/chat/completions" # type: ignore
logger.info(f"Request URL: {api_url}") # 记录请求的 URL
max_retries = 3
@@ -89,7 +89,7 @@ class LLMRequestOff:
}
# 发送请求到完整的 chat/completions 端点
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
api_url = f"{self.base_url.rstrip('/')}/chat/completions" # type: ignore
logger.info(f"Request URL: {api_url}") # 记录请求的 URL
max_retries = 3

View File

@@ -83,8 +83,8 @@ class PersonalityEvaluatorDirect:
def __init__(self):
self.personality_traits = {"开放性": 0, "严谨性": 0, "外向性": 0, "宜人性": 0, "神经质": 0}
self.scenarios = []
self.final_scores = {"开放性": 0, "严谨性": 0, "外向性": 0, "宜人性": 0, "神经质": 0}
self.dimension_counts = {trait: 0 for trait in self.final_scores.keys()}
self.final_scores: Dict[str, float] = {"开放性": 0, "严谨性": 0, "外向性": 0, "宜人性": 0, "神经质": 0}
self.dimension_counts = {trait: 0 for trait in self.final_scores}
# 为每个人格特质获取对应的场景
for trait in PERSONALITY_SCENES:
@@ -119,8 +119,7 @@ class PersonalityEvaluatorDirect:
# 构建维度描述
dimension_descriptions = []
for dim in dimensions:
desc = FACTOR_DESCRIPTIONS.get(dim, "")
if desc:
if desc := FACTOR_DESCRIPTIONS.get(dim, ""):
dimension_descriptions.append(f"- {dim}{desc}")
dimensions_text = "\n".join(dimension_descriptions)

View File

@@ -153,14 +153,14 @@ class MainSystem:
while True:
await asyncio.sleep(global_config.memory.memory_build_interval)
logger.info("正在进行记忆构建")
await self.hippocampus_manager.build_memory()
await self.hippocampus_manager.build_memory() # type: ignore
async def forget_memory_task(self):
"""记忆遗忘任务"""
while True:
await asyncio.sleep(global_config.memory.forget_memory_interval)
logger.info("[记忆遗忘] 开始遗忘记忆...")
await self.hippocampus_manager.forget_memory(percentage=global_config.memory.memory_forget_percentage)
await self.hippocampus_manager.forget_memory(percentage=global_config.memory.memory_forget_percentage) # type: ignore
logger.info("[记忆遗忘] 记忆遗忘完成")
async def consolidate_memory_task(self):
@@ -168,7 +168,7 @@ class MainSystem:
while True:
await asyncio.sleep(global_config.memory.consolidate_memory_interval)
logger.info("[记忆整合] 开始整合记忆...")
await self.hippocampus_manager.consolidate_memory()
await self.hippocampus_manager.consolidate_memory() # type: ignore
logger.info("[记忆整合] 记忆整合完成")
@staticmethod

View File

@@ -50,6 +50,9 @@ class ChatMood:
chat_manager = get_chat_manager()
self.chat_stream = chat_manager.get_stream(self.chat_id)
if not self.chat_stream:
raise ValueError(f"Chat stream for chat_id {chat_id} not found")
self.log_prefix = f"[{self.chat_stream.group_info.group_name if self.chat_stream.group_info else self.chat_stream.user_info.user_nickname}]"
self.mood_state: str = "感觉很平静"

View File

@@ -26,7 +26,7 @@ SEGMENT_CLEANUP_CONFIG = {
"cleanup_interval_hours": 0.5, # 清理间隔(小时)
}
MAX_MESSAGE_COUNT = 80 / global_config.relationship.relation_frequency
MAX_MESSAGE_COUNT = int(80 / global_config.relationship.relation_frequency)
class RelationshipBuilder:

View File

@@ -61,7 +61,7 @@ __all__ = [
"ConfigField",
# 工具函数
"ManifestValidator",
"ManifestGenerator",
"validate_plugin_manifest",
"generate_plugin_manifest",
# "ManifestGenerator",
# "validate_plugin_manifest",
# "generate_plugin_manifest",
]

View File

@@ -34,7 +34,4 @@ def register_event_plugin(cls, *args, **kwargs):
用法:
@register_event_plugin
class MyEventPlugin:
event_type = EventType.MESSAGE_RECEIVED
...
"""

View File

@@ -111,7 +111,7 @@ async def _send_to_target(
is_head=True,
is_emoji=(message_type == "emoji"),
thinking_start_time=current_time,
reply_to = reply_to_platform_id
reply_to=reply_to_platform_id,
)
# 发送消息
@@ -137,6 +137,7 @@ async def _send_to_target(
async def _find_reply_message(target_stream, reply_to: str) -> Optional[MessageRecv]:
# sourcery skip: inline-variable, use-named-expression
"""查找要回复的消息
Args:
@@ -184,14 +185,11 @@ async def _find_reply_message(target_stream, reply_to: str) -> Optional[MessageR
# 检查是否有 回复<aaa:bbb> 字段
reply_pattern = r"回复<([^:<>]+):([^:<>]+)>"
match = re.search(reply_pattern, translate_text)
if match:
if match := re.search(reply_pattern, translate_text):
aaa = match.group(1)
bbb = match.group(2)
reply_person_id = get_person_info_manager().get_person_id(platform, bbb)
reply_person_name = await get_person_info_manager().get_value(reply_person_id, "person_name")
if not reply_person_name:
reply_person_name = aaa
reply_person_name = await get_person_info_manager().get_value(reply_person_id, "person_name") or aaa
# 在内容前加上回复信息
translate_text = re.sub(reply_pattern, f"回复 {reply_person_name}", translate_text, count=1)
@@ -206,9 +204,7 @@ async def _find_reply_message(target_stream, reply_to: str) -> Optional[MessageR
aaa = m.group(1)
bbb = m.group(2)
at_person_id = get_person_info_manager().get_person_id(platform, bbb)
at_person_name = await get_person_info_manager().get_value(at_person_id, "person_name")
if not at_person_name:
at_person_name = aaa
at_person_name = await get_person_info_manager().get_value(at_person_id, "person_name") or aaa
new_content += f"@{at_person_name}"
last_end = m.end()
new_content += translate_text[last_end:]
@@ -403,7 +399,7 @@ async def text_to_group(
"""
stream_id = get_chat_manager().get_stream_id(platform, group_id, True)
return await _send_to_target("text", text, stream_id, "", typing, reply_to, storage_message)
return await _send_to_target("text", text, stream_id, "", typing, reply_to, storage_message=storage_message)
async def text_to_user(
@@ -427,7 +423,7 @@ async def text_to_user(
bool: 是否发送成功
"""
stream_id = get_chat_manager().get_stream_id(platform, user_id, False)
return await _send_to_target("text", text, stream_id, "", typing, reply_to, storage_message)
return await _send_to_target("text", text, stream_id, "", typing, reply_to, storage_message=storage_message)
async def emoji_to_group(emoji_base64: str, group_id: str, platform: str = "qq", storage_message: bool = True) -> bool:
@@ -550,7 +546,9 @@ async def custom_to_group(
bool: 是否发送成功
"""
stream_id = get_chat_manager().get_stream_id(platform, group_id, True)
return await _send_to_target(message_type, content, stream_id, display_message, typing, reply_to, storage_message)
return await _send_to_target(
message_type, content, stream_id, display_message, typing, reply_to, storage_message=storage_message
)
async def custom_to_user(
@@ -578,7 +576,9 @@ async def custom_to_user(
bool: 是否发送成功
"""
stream_id = get_chat_manager().get_stream_id(platform, user_id, False)
return await _send_to_target(message_type, content, stream_id, display_message, typing, reply_to, storage_message)
return await _send_to_target(
message_type, content, stream_id, display_message, typing, reply_to, storage_message=storage_message
)
async def custom_message(
@@ -618,4 +618,6 @@ async def custom_message(
await send_api.custom_message("audio", audio_base64, "123456", True, reply_to="张三:你好")
"""
stream_id = get_chat_manager().get_stream_id(platform, target_id, is_group)
return await _send_to_target(message_type, content, stream_id, display_message, typing, reply_to, storage_message)
return await _send_to_target(
message_type, content, stream_id, display_message, typing, reply_to, storage_message=storage_message
)

View File

@@ -38,7 +38,7 @@ class BaseAction(ABC):
chat_stream: ChatStream,
log_prefix: str = "",
plugin_config: Optional[dict] = None,
action_message: dict = None,
action_message: Optional[dict] = None,
**kwargs,
):
"""初始化Action组件
@@ -106,6 +106,8 @@ class BaseAction(ABC):
if self.action_message:
self.has_action_message = True
else:
self.action_message = {}
if self.has_action_message:
if self.action_name != "no_reply":
@@ -132,8 +134,6 @@ class BaseAction(ABC):
self.is_group = False
self.target_id = self.user_id
logger.debug(f"{self.log_prefix} Action组件初始化完成")
logger.info(
f"{self.log_prefix} 聊天信息: 类型={'群聊' if self.is_group else '私聊'}, 平台={self.platform}, 目标={self.target_id}"
@@ -199,7 +199,9 @@ class BaseAction(ABC):
logger.error(f"{self.log_prefix} 等待新消息时发生错误: {e}")
return False, f"等待新消息失败: {str(e)}"
async def send_text(self, content: str, reply_to: str = "", reply_to_platform_id: str = "", typing: bool = False) -> bool:
async def send_text(
self, content: str, reply_to: str = "", reply_to_platform_id: str = "", typing: bool = False
) -> bool:
"""发送文本消息
Args:
@@ -299,7 +301,7 @@ class BaseAction(ABC):
)
async def send_command(
self, command_name: str, args: dict = None, display_message: str = None, storage_message: bool = True
self, command_name: str, args: Optional[dict] = None, display_message: str = "", storage_message: bool = True
) -> bool:
"""发送命令消息

View File

@@ -135,7 +135,7 @@ class BaseCommand(ABC):
)
async def send_command(
self, command_name: str, args: dict = None, display_message: str = "", storage_message: bool = True
self, command_name: str, args: Optional[dict] = None, display_message: str = "", storage_message: bool = True
) -> bool:
"""发送命令消息

View File

@@ -1,18 +1,14 @@
from abc import ABC, abstractmethod
from abc import abstractmethod
class BaseEventsPlugin(ABC):
"""
事件触发型插件基类
from .plugin_base import PluginBase
from src.common.logger import get_logger
所有事件触发型插件都应该继承这个基类而不是 BasePlugin
class BaseEventPlugin(PluginBase):
"""基于事件的插件基类
所有事件类型的插件都应该继承这个基类
"""
@property
@abstractmethod
def plugin_name(self) -> str:
return "" # 插件内部标识符(如 "hello_world_plugin"
@property
@abstractmethod
def enable_plugin(self) -> bool:
return False
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

View File

@@ -7,7 +7,16 @@ from src.plugin_system.base.component_types import ComponentInfo
logger = get_logger("base_plugin")
class BasePlugin(PluginBase):
"""基于Action和Command的插件基类
所有上述类型的插件都应该继承这个基类,一个插件可以包含多种组件:
- Action组件处理聊天中的动作
- Command组件处理命令请求
- 未来可扩展Scheduler、Listener等
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

View File

@@ -19,12 +19,9 @@ logger = get_logger("plugin_base")
class PluginBase(ABC):
"""插件基类
"""插件基类
所有插件都应该继承这个基类,一个插件可以包含多种组件:
- Action组件处理聊天中的动作
- Command组件处理命令请求
- 未来可扩展Scheduler、Listener等
所有衍生插件基类都应该继承自此类,这个类定义了插件的基本结构和行为。
"""
# 插件基本信息(子类必须定义)

View File

@@ -346,67 +346,67 @@ class ComponentRegistry:
# === 状态管理方法 ===
def enable_component(self, component_name: str, component_type: ComponentType = None) -> bool:
# -------------------------------- NEED REFACTORING --------------------------------
# -------------------------------- LOGIC ERROR -------------------------------------
"""启用组件,支持命名空间解析"""
# 首先尝试找到正确的命名空间化名称
component_info = self.get_component_info(component_name, component_type)
if not component_info:
return False
# def enable_component(self, component_name: str, component_type: ComponentType = None) -> bool:
# # -------------------------------- NEED REFACTORING --------------------------------
# # -------------------------------- LOGIC ERROR -------------------------------------
# """启用组件,支持命名空间解析"""
# # 首先尝试找到正确的命名空间化名称
# component_info = self.get_component_info(component_name, component_type)
# if not component_info:
# return False
# 根据组件类型构造正确的命名空间化名称
if component_info.component_type == ComponentType.ACTION:
namespaced_name = f"action.{component_name}" if "." not in component_name else component_name
elif component_info.component_type == ComponentType.COMMAND:
namespaced_name = f"command.{component_name}" if "." not in component_name else component_name
else:
namespaced_name = (
f"{component_info.component_type.value}.{component_name}"
if "." not in component_name
else component_name
)
# # 根据组件类型构造正确的命名空间化名称
# if component_info.component_type == ComponentType.ACTION:
# namespaced_name = f"action.{component_name}" if "." not in component_name else component_name
# elif component_info.component_type == ComponentType.COMMAND:
# namespaced_name = f"command.{component_name}" if "." not in component_name else component_name
# else:
# namespaced_name = (
# f"{component_info.component_type.value}.{component_name}"
# if "." not in component_name
# else component_name
# )
if namespaced_name in self._components:
self._components[namespaced_name].enabled = True
# 如果是Action更新默认动作集
# ---- HERE ----
# if isinstance(component_info, ActionInfo):
# self._action_descriptions[component_name] = component_info.description
logger.debug(f"已启用组件: {component_name} -> {namespaced_name}")
return True
return False
# if namespaced_name in self._components:
# self._components[namespaced_name].enabled = True
# # 如果是Action更新默认动作集
# # ---- HERE ----
# # if isinstance(component_info, ActionInfo):
# # self._action_descriptions[component_name] = component_info.description
# logger.debug(f"已启用组件: {component_name} -> {namespaced_name}")
# return True
# return False
def disable_component(self, component_name: str, component_type: ComponentType = None) -> bool:
# -------------------------------- NEED REFACTORING --------------------------------
# -------------------------------- LOGIC ERROR -------------------------------------
"""禁用组件,支持命名空间解析"""
# 首先尝试找到正确的命名空间化名称
component_info = self.get_component_info(component_name, component_type)
if not component_info:
return False
# def disable_component(self, component_name: str, component_type: ComponentType = None) -> bool:
# # -------------------------------- NEED REFACTORING --------------------------------
# # -------------------------------- LOGIC ERROR -------------------------------------
# """禁用组件,支持命名空间解析"""
# # 首先尝试找到正确的命名空间化名称
# component_info = self.get_component_info(component_name, component_type)
# if not component_info:
# return False
# 根据组件类型构造正确的命名空间化名称
if component_info.component_type == ComponentType.ACTION:
namespaced_name = f"action.{component_name}" if "." not in component_name else component_name
elif component_info.component_type == ComponentType.COMMAND:
namespaced_name = f"command.{component_name}" if "." not in component_name else component_name
else:
namespaced_name = (
f"{component_info.component_type.value}.{component_name}"
if "." not in component_name
else component_name
)
# # 根据组件类型构造正确的命名空间化名称
# if component_info.component_type == ComponentType.ACTION:
# namespaced_name = f"action.{component_name}" if "." not in component_name else component_name
# elif component_info.component_type == ComponentType.COMMAND:
# namespaced_name = f"command.{component_name}" if "." not in component_name else component_name
# else:
# namespaced_name = (
# f"{component_info.component_type.value}.{component_name}"
# if "." not in component_name
# else component_name
# )
if namespaced_name in self._components:
self._components[namespaced_name].enabled = False
# 如果是Action从默认动作集中移除
# ---- HERE ----
# if component_name in self._action_descriptions:
# del self._action_descriptions[component_name]
logger.debug(f"已禁用组件: {component_name} -> {namespaced_name}")
return True
return False
# if namespaced_name in self._components:
# self._components[namespaced_name].enabled = False
# # 如果是Action从默认动作集中移除
# # ---- HERE ----
# # if component_name in self._action_descriptions:
# # del self._action_descriptions[component_name]
# logger.debug(f"已禁用组件: {component_name} -> {namespaced_name}")
# return True
# return False
def get_registry_stats(self) -> Dict[str, Any]:
"""获取注册中心统计信息"""

View File

@@ -7,7 +7,7 @@
import subprocess
import sys
import importlib
from typing import List, Dict, Tuple
from typing import List, Dict, Tuple, Any
from src.common.logger import get_logger
from src.plugin_system.base.component_types import PythonDependency
@@ -176,7 +176,7 @@ class DependencyManager:
logger.error(f"生成requirements文件失败: {str(e)}")
return False
def get_install_summary(self) -> Dict[str, any]:
def get_install_summary(self) -> Dict[str, Any]:
"""获取安装摘要"""
return {
"install_log": self.install_log.copy(),

View File

@@ -197,29 +197,29 @@ class PluginManager:
"""获取所有启用的插件信息"""
return list(component_registry.get_enabled_plugins().values())
def enable_plugin(self, plugin_name: str) -> bool:
# -------------------------------- NEED REFACTORING --------------------------------
"""启用插件"""
if plugin_info := component_registry.get_plugin_info(plugin_name):
plugin_info.enabled = True
# 启用插件的所有组件
for component in plugin_info.components:
component_registry.enable_component(component.name)
logger.debug(f"已启用插件: {plugin_name}")
return True
return False
# def enable_plugin(self, plugin_name: str) -> bool:
# # -------------------------------- NEED REFACTORING --------------------------------
# """启用插件"""
# if plugin_info := component_registry.get_plugin_info(plugin_name):
# plugin_info.enabled = True
# # 启用插件的所有组件
# for component in plugin_info.components:
# component_registry.enable_component(component.name)
# logger.debug(f"已启用插件: {plugin_name}")
# return True
# return False
def disable_plugin(self, plugin_name: str) -> bool:
# -------------------------------- NEED REFACTORING --------------------------------
"""禁用插件"""
if plugin_info := component_registry.get_plugin_info(plugin_name):
plugin_info.enabled = False
# 禁用插件的所有组件
for component in plugin_info.components:
component_registry.disable_component(component.name)
logger.debug(f"已禁用插件: {plugin_name}")
return True
return False
# def disable_plugin(self, plugin_name: str) -> bool:
# # -------------------------------- NEED REFACTORING --------------------------------
# """禁用插件"""
# if plugin_info := component_registry.get_plugin_info(plugin_name):
# plugin_info.enabled = False
# # 禁用插件的所有组件
# for component in plugin_info.components:
# component_registry.disable_component(component.name)
# logger.debug(f"已禁用插件: {plugin_name}")
# return True
# return False
def get_plugin_instance(self, plugin_name: str) -> Optional["PluginBase"]:
"""获取插件实例

View File

@@ -28,10 +28,10 @@ class CompareNumbersTool(BaseTool):
Returns:
dict: 工具执行结果
"""
try:
num1 = function_args.get("num1")
num2 = function_args.get("num2")
num1: int | float = function_args.get("num1") # type: ignore
num2: int | float = function_args.get("num2") # type: ignore
try:
if num1 > num2:
result = f"{num1} 大于 {num2}"
elif num1 < num2:

View File

@@ -68,10 +68,10 @@ class RenamePersonTool(BaseTool):
)
result = await person_info_manager.qv_person_name(
person_id=person_id,
user_nickname=user_nickname,
user_cardname=user_cardname,
user_avatar=user_avatar,
request=request_context,
user_nickname=user_nickname, # type: ignore
user_cardname=user_cardname, # type: ignore
user_avatar=user_avatar, # type: ignore
request=request_context, # type: ignore
)
# 3. 处理结果