Merge branch 'dev' of https://github.com/MaiM-with-u/MaiBot into dev
This commit is contained in:
144
bot.py
144
bot.py
@@ -8,6 +8,7 @@ if os.path.exists(".env"):
|
|||||||
print("成功加载环境变量配置")
|
print("成功加载环境变量配置")
|
||||||
else:
|
else:
|
||||||
print("未找到.env文件,请确保程序所需的环境变量被正确设置")
|
print("未找到.env文件,请确保程序所需的环境变量被正确设置")
|
||||||
|
raise FileNotFoundError(".env 文件不存在,请创建并配置所需的环境变量")
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import platform
|
import platform
|
||||||
@@ -140,87 +141,88 @@ async def graceful_shutdown():
|
|||||||
logger.error(f"麦麦关闭失败: {e}", exc_info=True)
|
logger.error(f"麦麦关闭失败: {e}", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _calculate_file_hash(file_path: Path, file_type: str) -> str:
|
||||||
|
"""计算文件的MD5哈希值"""
|
||||||
|
if not file_path.exists():
|
||||||
|
logger.error(f"{file_type} 文件不存在")
|
||||||
|
raise FileNotFoundError(f"{file_type} 文件不存在")
|
||||||
|
|
||||||
|
with open(file_path, "r", encoding="utf-8") as f:
|
||||||
|
content = f.read()
|
||||||
|
return hashlib.md5(content.encode("utf-8")).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def _check_agreement_status(file_hash: str, confirm_file: Path, env_var: str) -> tuple[bool, bool]:
|
||||||
|
"""检查协议确认状态
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[bool, bool]: (已确认, 未更新)
|
||||||
|
"""
|
||||||
|
# 检查环境变量确认
|
||||||
|
if file_hash == os.getenv(env_var):
|
||||||
|
return True, False
|
||||||
|
|
||||||
|
# 检查确认文件
|
||||||
|
if confirm_file.exists():
|
||||||
|
with open(confirm_file, "r", encoding="utf-8") as f:
|
||||||
|
confirmed_content = f.read()
|
||||||
|
if file_hash == confirmed_content:
|
||||||
|
return True, False
|
||||||
|
|
||||||
|
return False, True
|
||||||
|
|
||||||
|
|
||||||
|
def _prompt_user_confirmation(eula_hash: str, privacy_hash: str) -> None:
|
||||||
|
"""提示用户确认协议"""
|
||||||
|
confirm_logger.critical("EULA或隐私条款内容已更新,请在阅读后重新确认,继续运行视为同意更新后的以上两款协议")
|
||||||
|
confirm_logger.critical(
|
||||||
|
f'输入"同意"或"confirmed"或设置环境变量"EULA_AGREE={eula_hash}"和"PRIVACY_AGREE={privacy_hash}"继续运行'
|
||||||
|
)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
user_input = input().strip().lower()
|
||||||
|
if user_input in ["同意", "confirmed"]:
|
||||||
|
return
|
||||||
|
confirm_logger.critical('请输入"同意"或"confirmed"以继续运行')
|
||||||
|
|
||||||
|
|
||||||
|
def _save_confirmations(eula_updated: bool, privacy_updated: bool, eula_hash: str, privacy_hash: str) -> None:
|
||||||
|
"""保存用户确认结果"""
|
||||||
|
if eula_updated:
|
||||||
|
logger.info(f"更新EULA确认文件{eula_hash}")
|
||||||
|
Path("eula.confirmed").write_text(eula_hash, encoding="utf-8")
|
||||||
|
|
||||||
|
if privacy_updated:
|
||||||
|
logger.info(f"更新隐私条款确认文件{privacy_hash}")
|
||||||
|
Path("privacy.confirmed").write_text(privacy_hash, encoding="utf-8")
|
||||||
|
|
||||||
|
|
||||||
def check_eula():
|
def check_eula():
|
||||||
eula_confirm_file = Path("eula.confirmed")
|
"""检查EULA和隐私条款确认状态"""
|
||||||
privacy_confirm_file = Path("privacy.confirmed")
|
# 计算文件哈希值
|
||||||
eula_file = Path("EULA.md")
|
eula_hash = _calculate_file_hash(Path("EULA.md"), "EULA.md")
|
||||||
privacy_file = Path("PRIVACY.md")
|
privacy_hash = _calculate_file_hash(Path("PRIVACY.md"), "PRIVACY.md")
|
||||||
|
|
||||||
eula_updated = True
|
# 检查确认状态
|
||||||
privacy_updated = True
|
eula_confirmed, eula_updated = _check_agreement_status(eula_hash, Path("eula.confirmed"), "EULA_AGREE")
|
||||||
|
privacy_confirmed, privacy_updated = _check_agreement_status(
|
||||||
|
privacy_hash, Path("privacy.confirmed"), "PRIVACY_AGREE"
|
||||||
|
)
|
||||||
|
|
||||||
eula_confirmed = False
|
# 早期返回:如果都已确认且未更新
|
||||||
privacy_confirmed = False
|
if eula_confirmed and privacy_confirmed:
|
||||||
|
return
|
||||||
|
|
||||||
# 首先计算当前EULA文件的哈希值
|
# 如果有更新,需要重新确认
|
||||||
if eula_file.exists():
|
|
||||||
with open(eula_file, "r", encoding="utf-8") as f:
|
|
||||||
eula_content = f.read()
|
|
||||||
eula_new_hash = hashlib.md5(eula_content.encode("utf-8")).hexdigest()
|
|
||||||
else:
|
|
||||||
logger.error("EULA.md 文件不存在")
|
|
||||||
raise FileNotFoundError("EULA.md 文件不存在")
|
|
||||||
|
|
||||||
# 首先计算当前隐私条款文件的哈希值
|
|
||||||
if privacy_file.exists():
|
|
||||||
with open(privacy_file, "r", encoding="utf-8") as f:
|
|
||||||
privacy_content = f.read()
|
|
||||||
privacy_new_hash = hashlib.md5(privacy_content.encode("utf-8")).hexdigest()
|
|
||||||
else:
|
|
||||||
logger.error("PRIVACY.md 文件不存在")
|
|
||||||
raise FileNotFoundError("PRIVACY.md 文件不存在")
|
|
||||||
|
|
||||||
# 检查EULA确认文件是否存在
|
|
||||||
if eula_confirm_file.exists():
|
|
||||||
with open(eula_confirm_file, "r", encoding="utf-8") as f:
|
|
||||||
confirmed_content = f.read()
|
|
||||||
if eula_new_hash == confirmed_content:
|
|
||||||
eula_confirmed = True
|
|
||||||
eula_updated = False
|
|
||||||
if eula_new_hash == os.getenv("EULA_AGREE"):
|
|
||||||
eula_confirmed = True
|
|
||||||
eula_updated = False
|
|
||||||
|
|
||||||
# 检查隐私条款确认文件是否存在
|
|
||||||
if privacy_confirm_file.exists():
|
|
||||||
with open(privacy_confirm_file, "r", encoding="utf-8") as f:
|
|
||||||
confirmed_content = f.read()
|
|
||||||
if privacy_new_hash == confirmed_content:
|
|
||||||
privacy_confirmed = True
|
|
||||||
privacy_updated = False
|
|
||||||
if privacy_new_hash == os.getenv("PRIVACY_AGREE"):
|
|
||||||
privacy_confirmed = True
|
|
||||||
privacy_updated = False
|
|
||||||
|
|
||||||
# 如果EULA或隐私条款有更新,提示用户重新确认
|
|
||||||
if eula_updated or privacy_updated:
|
if eula_updated or privacy_updated:
|
||||||
confirm_logger.critical("EULA或隐私条款内容已更新,请在阅读后重新确认,继续运行视为同意更新后的以上两款协议")
|
_prompt_user_confirmation(eula_hash, privacy_hash)
|
||||||
confirm_logger.critical(
|
_save_confirmations(eula_updated, privacy_updated, eula_hash, privacy_hash)
|
||||||
f'输入"同意"或"confirmed"或设置环境变量"EULA_AGREE={eula_new_hash}"和"PRIVACY_AGREE={privacy_new_hash}"继续运行'
|
|
||||||
)
|
|
||||||
while True:
|
|
||||||
user_input = input().strip().lower()
|
|
||||||
if user_input in ["同意", "confirmed"]:
|
|
||||||
# print("确认成功,继续运行")
|
|
||||||
# print(f"确认成功,继续运行{eula_updated} {privacy_updated}")
|
|
||||||
if eula_updated:
|
|
||||||
logger.info(f"更新EULA确认文件{eula_new_hash}")
|
|
||||||
eula_confirm_file.write_text(eula_new_hash, encoding="utf-8")
|
|
||||||
if privacy_updated:
|
|
||||||
logger.info(f"更新隐私条款确认文件{privacy_new_hash}")
|
|
||||||
privacy_confirm_file.write_text(privacy_new_hash, encoding="utf-8")
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
confirm_logger.critical('请输入"同意"或"confirmed"以继续运行')
|
|
||||||
return
|
|
||||||
elif eula_confirmed and privacy_confirmed:
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
def raw_main():
|
def raw_main():
|
||||||
# 利用 TZ 环境变量设定程序工作的时区
|
# 利用 TZ 环境变量设定程序工作的时区
|
||||||
if platform.system().lower() != "windows":
|
if platform.system().lower() != "windows":
|
||||||
time.tzset()
|
time.tzset() # type: ignore
|
||||||
|
|
||||||
check_eula()
|
check_eula()
|
||||||
logger.info("检查EULA和隐私条款完成")
|
logger.info("检查EULA和隐私条款完成")
|
||||||
|
|||||||
@@ -21,3 +21,6 @@
|
|||||||
- `config_api.py`中的`get_global_config`和`get_plugin_config`方法现在支持嵌套访问的配置键名。
|
- `config_api.py`中的`get_global_config`和`get_plugin_config`方法现在支持嵌套访问的配置键名。
|
||||||
- `database_api.py`中的`db_query`方法调整了参数顺序以增强参数限制的同时,保证了typing正确;`db_get`方法增加了`single_result`参数,与`db_query`保持一致。
|
- `database_api.py`中的`db_query`方法调整了参数顺序以增强参数限制的同时,保证了typing正确;`db_get`方法增加了`single_result`参数,与`db_query`保持一致。
|
||||||
4. 现在增加了参数类型检查,完善了对应注释
|
4. 现在增加了参数类型检查,完善了对应注释
|
||||||
|
5. 现在插件抽象出了总基类 `PluginBase`
|
||||||
|
- 基于`Action`和`Command`的插件基类现在为`BasePlugin`,它继承自`PluginBase`,由`register_plugin`装饰器注册。
|
||||||
|
- 基于`Event`的插件基类现在为`BaseEventPlugin`,它也继承自`PluginBase`,由`register_event_plugin`装饰器注册。
|
||||||
@@ -15,6 +15,7 @@ from src.chat.knowledge.kg_manager import KGManager
|
|||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.chat.knowledge.utils.hash import get_sha256
|
from src.chat.knowledge.utils.hash import get_sha256
|
||||||
from src.manager.local_store_manager import local_storage
|
from src.manager.local_store_manager import local_storage
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
|
||||||
# 添加项目根目录到 sys.path
|
# 添加项目根目录到 sys.path
|
||||||
@@ -23,6 +24,45 @@ OPENIE_DIR = os.path.join(ROOT_PATH, "data", "openie")
|
|||||||
|
|
||||||
logger = get_logger("OpenIE导入")
|
logger = get_logger("OpenIE导入")
|
||||||
|
|
||||||
|
ENV_FILE = os.path.join(ROOT_PATH, ".env")
|
||||||
|
|
||||||
|
if os.path.exists(".env"):
|
||||||
|
load_dotenv(".env", override=True)
|
||||||
|
print("成功加载环境变量配置")
|
||||||
|
else:
|
||||||
|
print("未找到.env文件,请确保程序所需的环境变量被正确设置")
|
||||||
|
raise FileNotFoundError(".env 文件不存在,请创建并配置所需的环境变量")
|
||||||
|
|
||||||
|
env_mask = {key: os.getenv(key) for key in os.environ}
|
||||||
|
def scan_provider(env_config: dict):
|
||||||
|
provider = {}
|
||||||
|
|
||||||
|
# 利用未初始化 env 时获取的 env_mask 来对新的环境变量集去重
|
||||||
|
# 避免 GPG_KEY 这样的变量干扰检查
|
||||||
|
env_config = dict(filter(lambda item: item[0] not in env_mask, env_config.items()))
|
||||||
|
|
||||||
|
# 遍历 env_config 的所有键
|
||||||
|
for key in env_config:
|
||||||
|
# 检查键是否符合 {provider}_BASE_URL 或 {provider}_KEY 的格式
|
||||||
|
if key.endswith("_BASE_URL") or key.endswith("_KEY"):
|
||||||
|
# 提取 provider 名称
|
||||||
|
provider_name = key.split("_", 1)[0] # 从左分割一次,取第一部分
|
||||||
|
|
||||||
|
# 初始化 provider 的字典(如果尚未初始化)
|
||||||
|
if provider_name not in provider:
|
||||||
|
provider[provider_name] = {"url": None, "key": None}
|
||||||
|
|
||||||
|
# 根据键的类型填充 url 或 key
|
||||||
|
if key.endswith("_BASE_URL"):
|
||||||
|
provider[provider_name]["url"] = env_config[key]
|
||||||
|
elif key.endswith("_KEY"):
|
||||||
|
provider[provider_name]["key"] = env_config[key]
|
||||||
|
|
||||||
|
# 检查每个 provider 是否同时存在 url 和 key
|
||||||
|
for provider_name, config in provider.items():
|
||||||
|
if config["url"] is None or config["key"] is None:
|
||||||
|
logger.error(f"provider 内容:{config}\nenv_config 内容:{env_config}")
|
||||||
|
raise ValueError(f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量")
|
||||||
|
|
||||||
def ensure_openie_dir():
|
def ensure_openie_dir():
|
||||||
"""确保OpenIE数据目录存在"""
|
"""确保OpenIE数据目录存在"""
|
||||||
@@ -174,6 +214,8 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k
|
|||||||
|
|
||||||
def main(): # sourcery skip: dict-comprehension
|
def main(): # sourcery skip: dict-comprehension
|
||||||
# 新增确认提示
|
# 新增确认提示
|
||||||
|
env_config = {key: os.getenv(key) for key in os.environ}
|
||||||
|
scan_provider(env_config)
|
||||||
print("=== 重要操作确认 ===")
|
print("=== 重要操作确认 ===")
|
||||||
print("OpenIE导入时会大量发送请求,可能会撞到请求速度上限,请注意选用的模型")
|
print("OpenIE导入时会大量发送请求,可能会撞到请求速度上限,请注意选用的模型")
|
||||||
print("同之前样例:在本地模型下,在70分钟内我们发送了约8万条请求,在网络允许下,速度会更快")
|
print("同之前样例:在本地模型下,在70分钟内我们发送了约8万条请求,在网络允许下,速度会更快")
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ from rich.progress import (
|
|||||||
from raw_data_preprocessor import RAW_DATA_PATH, load_raw_data
|
from raw_data_preprocessor import RAW_DATA_PATH, load_raw_data
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
logger = get_logger("LPMM知识库-信息提取")
|
logger = get_logger("LPMM知识库-信息提取")
|
||||||
|
|
||||||
@@ -35,6 +36,45 @@ ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
|||||||
TEMP_DIR = os.path.join(ROOT_PATH, "temp")
|
TEMP_DIR = os.path.join(ROOT_PATH, "temp")
|
||||||
# IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data", "imported_lpmm_data")
|
# IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data", "imported_lpmm_data")
|
||||||
OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie")
|
OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie")
|
||||||
|
ENV_FILE = os.path.join(ROOT_PATH, ".env")
|
||||||
|
|
||||||
|
if os.path.exists(".env"):
|
||||||
|
load_dotenv(".env", override=True)
|
||||||
|
print("成功加载环境变量配置")
|
||||||
|
else:
|
||||||
|
print("未找到.env文件,请确保程序所需的环境变量被正确设置")
|
||||||
|
raise FileNotFoundError(".env 文件不存在,请创建并配置所需的环境变量")
|
||||||
|
|
||||||
|
env_mask = {key: os.getenv(key) for key in os.environ}
|
||||||
|
def scan_provider(env_config: dict):
|
||||||
|
provider = {}
|
||||||
|
|
||||||
|
# 利用未初始化 env 时获取的 env_mask 来对新的环境变量集去重
|
||||||
|
# 避免 GPG_KEY 这样的变量干扰检查
|
||||||
|
env_config = dict(filter(lambda item: item[0] not in env_mask, env_config.items()))
|
||||||
|
|
||||||
|
# 遍历 env_config 的所有键
|
||||||
|
for key in env_config:
|
||||||
|
# 检查键是否符合 {provider}_BASE_URL 或 {provider}_KEY 的格式
|
||||||
|
if key.endswith("_BASE_URL") or key.endswith("_KEY"):
|
||||||
|
# 提取 provider 名称
|
||||||
|
provider_name = key.split("_", 1)[0] # 从左分割一次,取第一部分
|
||||||
|
|
||||||
|
# 初始化 provider 的字典(如果尚未初始化)
|
||||||
|
if provider_name not in provider:
|
||||||
|
provider[provider_name] = {"url": None, "key": None}
|
||||||
|
|
||||||
|
# 根据键的类型填充 url 或 key
|
||||||
|
if key.endswith("_BASE_URL"):
|
||||||
|
provider[provider_name]["url"] = env_config[key]
|
||||||
|
elif key.endswith("_KEY"):
|
||||||
|
provider[provider_name]["key"] = env_config[key]
|
||||||
|
|
||||||
|
# 检查每个 provider 是否同时存在 url 和 key
|
||||||
|
for provider_name, config in provider.items():
|
||||||
|
if config["url"] is None or config["key"] is None:
|
||||||
|
logger.error(f"provider 内容:{config}\nenv_config 内容:{env_config}")
|
||||||
|
raise ValueError(f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量")
|
||||||
|
|
||||||
def ensure_dirs():
|
def ensure_dirs():
|
||||||
"""确保临时目录和输出目录存在"""
|
"""确保临时目录和输出目录存在"""
|
||||||
@@ -118,6 +158,8 @@ def main(): # sourcery skip: comprehension-to-generator, extract-method
|
|||||||
# 设置信号处理器
|
# 设置信号处理器
|
||||||
signal.signal(signal.SIGINT, signal_handler)
|
signal.signal(signal.SIGINT, signal_handler)
|
||||||
ensure_dirs() # 确保目录存在
|
ensure_dirs() # 确保目录存在
|
||||||
|
env_config = {key: os.getenv(key) for key in os.environ}
|
||||||
|
scan_provider(env_config)
|
||||||
# 新增用户确认提示
|
# 新增用户确认提示
|
||||||
print("=== 重要操作确认,请认真阅读以下内容哦 ===")
|
print("=== 重要操作确认,请认真阅读以下内容哦 ===")
|
||||||
print("实体提取操作将会花费较多api余额和时间,建议在空闲时段执行。")
|
print("实体提取操作将会花费较多api余额和时间,建议在空闲时段执行。")
|
||||||
|
|||||||
@@ -107,11 +107,12 @@ class ExpressionLearner:
|
|||||||
last_active_time = expr.get("last_active_time", time.time())
|
last_active_time = expr.get("last_active_time", time.time())
|
||||||
# 查重:同chat_id+type+situation+style
|
# 查重:同chat_id+type+situation+style
|
||||||
from src.common.database.database_model import Expression
|
from src.common.database.database_model import Expression
|
||||||
|
|
||||||
query = Expression.select().where(
|
query = Expression.select().where(
|
||||||
(Expression.chat_id == chat_id) &
|
(Expression.chat_id == chat_id)
|
||||||
(Expression.type == type_str) &
|
& (Expression.type == type_str)
|
||||||
(Expression.situation == situation) &
|
& (Expression.situation == situation)
|
||||||
(Expression.style == style_val)
|
& (Expression.style == style_val)
|
||||||
)
|
)
|
||||||
if query.exists():
|
if query.exists():
|
||||||
expr_obj = query.get()
|
expr_obj = query.get()
|
||||||
@@ -125,7 +126,7 @@ class ExpressionLearner:
|
|||||||
count=count,
|
count=count,
|
||||||
last_active_time=last_active_time,
|
last_active_time=last_active_time,
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
type=type_str
|
type=type_str,
|
||||||
)
|
)
|
||||||
logger.info(f"已迁移 {expr_file} 到数据库")
|
logger.info(f"已迁移 {expr_file} 到数据库")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -149,24 +150,28 @@ class ExpressionLearner:
|
|||||||
# 直接从数据库查询
|
# 直接从数据库查询
|
||||||
style_query = Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == "style"))
|
style_query = Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == "style"))
|
||||||
for expr in style_query:
|
for expr in style_query:
|
||||||
learnt_style_expressions.append({
|
learnt_style_expressions.append(
|
||||||
"situation": expr.situation,
|
{
|
||||||
"style": expr.style,
|
"situation": expr.situation,
|
||||||
"count": expr.count,
|
"style": expr.style,
|
||||||
"last_active_time": expr.last_active_time,
|
"count": expr.count,
|
||||||
"source_id": chat_id,
|
"last_active_time": expr.last_active_time,
|
||||||
"type": "style"
|
"source_id": chat_id,
|
||||||
})
|
"type": "style",
|
||||||
|
}
|
||||||
|
)
|
||||||
grammar_query = Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == "grammar"))
|
grammar_query = Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == "grammar"))
|
||||||
for expr in grammar_query:
|
for expr in grammar_query:
|
||||||
learnt_grammar_expressions.append({
|
learnt_grammar_expressions.append(
|
||||||
"situation": expr.situation,
|
{
|
||||||
"style": expr.style,
|
"situation": expr.situation,
|
||||||
"count": expr.count,
|
"style": expr.style,
|
||||||
"last_active_time": expr.last_active_time,
|
"count": expr.count,
|
||||||
"source_id": chat_id,
|
"last_active_time": expr.last_active_time,
|
||||||
"type": "grammar"
|
"source_id": chat_id,
|
||||||
})
|
"type": "grammar",
|
||||||
|
}
|
||||||
|
)
|
||||||
return learnt_style_expressions, learnt_grammar_expressions
|
return learnt_style_expressions, learnt_grammar_expressions
|
||||||
|
|
||||||
def is_similar(self, s1: str, s2: str) -> bool:
|
def is_similar(self, s1: str, s2: str) -> bool:
|
||||||
@@ -213,14 +218,16 @@ class ExpressionLearner:
|
|||||||
logger.error(f"全局衰减{type}表达方式失败: {e}")
|
logger.error(f"全局衰减{type}表达方式失败: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
learnt_style: Optional[List[Tuple[str, str, str]]] = []
|
||||||
|
learnt_grammar: Optional[List[Tuple[str, str, str]]] = []
|
||||||
# 学习新的表达方式(这里会进行局部衰减)
|
# 学习新的表达方式(这里会进行局部衰减)
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
learnt_style: Optional[List[Tuple[str, str, str]]] = await self.learn_and_store(type="style", num=25)
|
learnt_style = await self.learn_and_store(type="style", num=25)
|
||||||
if not learnt_style:
|
if not learnt_style:
|
||||||
return [], []
|
return [], []
|
||||||
|
|
||||||
for _ in range(1):
|
for _ in range(1):
|
||||||
learnt_grammar: Optional[List[Tuple[str, str, str]]] = await self.learn_and_store(type="grammar", num=10)
|
learnt_grammar = await self.learn_and_store(type="grammar", num=10)
|
||||||
if not learnt_grammar:
|
if not learnt_grammar:
|
||||||
return [], []
|
return [], []
|
||||||
|
|
||||||
@@ -321,10 +328,10 @@ class ExpressionLearner:
|
|||||||
for new_expr in expr_list:
|
for new_expr in expr_list:
|
||||||
# 查找是否已存在相似表达方式
|
# 查找是否已存在相似表达方式
|
||||||
query = Expression.select().where(
|
query = Expression.select().where(
|
||||||
(Expression.chat_id == chat_id) &
|
(Expression.chat_id == chat_id)
|
||||||
(Expression.type == type) &
|
& (Expression.type == type)
|
||||||
(Expression.situation == new_expr["situation"]) &
|
& (Expression.situation == new_expr["situation"])
|
||||||
(Expression.style == new_expr["style"])
|
& (Expression.style == new_expr["style"])
|
||||||
)
|
)
|
||||||
if query.exists():
|
if query.exists():
|
||||||
expr_obj = query.get()
|
expr_obj = query.get()
|
||||||
@@ -342,13 +349,17 @@ class ExpressionLearner:
|
|||||||
count=1,
|
count=1,
|
||||||
last_active_time=current_time,
|
last_active_time=current_time,
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
type=type
|
type=type,
|
||||||
)
|
)
|
||||||
# 限制最大数量
|
# 限制最大数量
|
||||||
exprs = list(Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == type)).order_by(Expression.count.asc()))
|
exprs = list(
|
||||||
|
Expression.select()
|
||||||
|
.where((Expression.chat_id == chat_id) & (Expression.type == type))
|
||||||
|
.order_by(Expression.count.asc())
|
||||||
|
)
|
||||||
if len(exprs) > MAX_EXPRESSION_COUNT:
|
if len(exprs) > MAX_EXPRESSION_COUNT:
|
||||||
# 删除count最小的多余表达方式
|
# 删除count最小的多余表达方式
|
||||||
for expr in exprs[:len(exprs) - MAX_EXPRESSION_COUNT]:
|
for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]:
|
||||||
expr.delete_instance()
|
expr.delete_instance()
|
||||||
return learnt_expressions
|
return learnt_expressions
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ from dataclasses import dataclass
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import math
|
import math
|
||||||
|
import asyncio
|
||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -99,7 +100,30 @@ class EmbeddingStore:
|
|||||||
self.idx2hash = None
|
self.idx2hash = None
|
||||||
|
|
||||||
def _get_embedding(self, s: str) -> List[float]:
|
def _get_embedding(self, s: str) -> List[float]:
|
||||||
return get_embedding(s)
|
"""获取字符串的嵌入向量,处理异步调用"""
|
||||||
|
try:
|
||||||
|
# 尝试获取当前事件循环
|
||||||
|
asyncio.get_running_loop()
|
||||||
|
# 如果在事件循环中,使用线程池执行
|
||||||
|
import concurrent.futures
|
||||||
|
|
||||||
|
def run_in_thread():
|
||||||
|
return asyncio.run(get_embedding(s))
|
||||||
|
|
||||||
|
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||||
|
future = executor.submit(run_in_thread)
|
||||||
|
result = future.result()
|
||||||
|
if result is None:
|
||||||
|
logger.error(f"获取嵌入失败: {s}")
|
||||||
|
return []
|
||||||
|
return result
|
||||||
|
except RuntimeError:
|
||||||
|
# 没有运行的事件循环,直接运行
|
||||||
|
result = asyncio.run(get_embedding(s))
|
||||||
|
if result is None:
|
||||||
|
logger.error(f"获取嵌入失败: {s}")
|
||||||
|
return []
|
||||||
|
return result
|
||||||
|
|
||||||
def get_test_file_path(self):
|
def get_test_file_path(self):
|
||||||
return EMBEDDING_TEST_FILE
|
return EMBEDDING_TEST_FILE
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
@@ -7,8 +8,12 @@ from . import prompt_template
|
|||||||
from .knowledge_lib import INVALID_ENTITY
|
from .knowledge_lib import INVALID_ENTITY
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from json_repair import repair_json
|
from json_repair import repair_json
|
||||||
def _extract_json_from_text(text: str) -> dict:
|
def _extract_json_from_text(text: str):
|
||||||
"""从文本中提取JSON数据的高容错方法"""
|
"""从文本中提取JSON数据的高容错方法"""
|
||||||
|
if text is None:
|
||||||
|
logger.error("输入文本为None")
|
||||||
|
return []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
fixed_json = repair_json(text)
|
fixed_json = repair_json(text)
|
||||||
if isinstance(fixed_json, str):
|
if isinstance(fixed_json, str):
|
||||||
@@ -16,23 +21,66 @@ def _extract_json_from_text(text: str) -> dict:
|
|||||||
else:
|
else:
|
||||||
parsed_json = fixed_json
|
parsed_json = fixed_json
|
||||||
|
|
||||||
if isinstance(parsed_json, list) and parsed_json:
|
# 如果是列表,直接返回
|
||||||
parsed_json = parsed_json[0]
|
if isinstance(parsed_json, list):
|
||||||
|
|
||||||
if isinstance(parsed_json, dict):
|
|
||||||
return parsed_json
|
return parsed_json
|
||||||
|
|
||||||
|
# 如果是字典且只有一个项目,可能包装了列表
|
||||||
|
if isinstance(parsed_json, dict):
|
||||||
|
# 如果字典只有一个键,并且值是列表,返回那个列表
|
||||||
|
if len(parsed_json) == 1:
|
||||||
|
value = list(parsed_json.values())[0]
|
||||||
|
if isinstance(value, list):
|
||||||
|
return value
|
||||||
|
return parsed_json
|
||||||
|
|
||||||
|
# 其他情况,尝试转换为列表
|
||||||
|
logger.warning(f"解析的JSON不是预期格式: {type(parsed_json)}, 内容: {parsed_json}")
|
||||||
|
return []
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"JSON提取失败: {e}, 原始文本: {text[:100]}...")
|
logger.error(f"JSON提取失败: {e}, 原始文本: {text[:100] if text else 'None'}...")
|
||||||
|
return []
|
||||||
|
|
||||||
def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
|
def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
|
||||||
"""对段落进行实体提取,返回提取出的实体列表(JSON格式)"""
|
"""对段落进行实体提取,返回提取出的实体列表(JSON格式)"""
|
||||||
entity_extract_context = prompt_template.build_entity_extract_context(paragraph)
|
entity_extract_context = prompt_template.build_entity_extract_context(paragraph)
|
||||||
response, (reasoning_content, model_name) = llm_req.generate_response_async(entity_extract_context)
|
|
||||||
|
# 使用 asyncio.run 来运行异步方法
|
||||||
|
try:
|
||||||
|
# 如果当前已有事件循环在运行,使用它
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
future = asyncio.run_coroutine_threadsafe(
|
||||||
|
llm_req.generate_response_async(entity_extract_context), loop
|
||||||
|
)
|
||||||
|
response, (reasoning_content, model_name) = future.result()
|
||||||
|
except RuntimeError:
|
||||||
|
# 如果没有运行中的事件循环,直接使用 asyncio.run
|
||||||
|
response, (reasoning_content, model_name) = asyncio.run(
|
||||||
|
llm_req.generate_response_async(entity_extract_context)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 添加调试日志
|
||||||
|
logger.debug(f"LLM返回的原始响应: {response}")
|
||||||
|
|
||||||
entity_extract_result = _extract_json_from_text(response)
|
entity_extract_result = _extract_json_from_text(response)
|
||||||
# 尝试load JSON数据
|
|
||||||
json.loads(entity_extract_result)
|
# 检查返回的是否为有效的实体列表
|
||||||
|
if not isinstance(entity_extract_result, list):
|
||||||
|
# 如果不是列表,可能是字典格式,尝试从中提取列表
|
||||||
|
if isinstance(entity_extract_result, dict):
|
||||||
|
# 尝试常见的键名
|
||||||
|
for key in ['entities', 'result', 'data', 'items']:
|
||||||
|
if key in entity_extract_result and isinstance(entity_extract_result[key], list):
|
||||||
|
entity_extract_result = entity_extract_result[key]
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# 如果找不到合适的列表,抛出异常
|
||||||
|
raise Exception(f"实体提取结果格式错误,期望列表但得到: {type(entity_extract_result)}")
|
||||||
|
else:
|
||||||
|
raise Exception(f"实体提取结果格式错误,期望列表但得到: {type(entity_extract_result)}")
|
||||||
|
|
||||||
|
# 过滤无效实体
|
||||||
entity_extract_result = [
|
entity_extract_result = [
|
||||||
entity
|
entity
|
||||||
for entity in entity_extract_result
|
for entity in entity_extract_result
|
||||||
@@ -50,16 +98,47 @@ def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) ->
|
|||||||
rdf_extract_context = prompt_template.build_rdf_triple_extract_context(
|
rdf_extract_context = prompt_template.build_rdf_triple_extract_context(
|
||||||
paragraph, entities=json.dumps(entities, ensure_ascii=False)
|
paragraph, entities=json.dumps(entities, ensure_ascii=False)
|
||||||
)
|
)
|
||||||
response, (reasoning_content, model_name) = llm_req.generate_response_async(rdf_extract_context)
|
|
||||||
|
|
||||||
entity_extract_result = _extract_json_from_text(response)
|
# 使用 asyncio.run 来运行异步方法
|
||||||
# 尝试load JSON数据
|
try:
|
||||||
json.loads(entity_extract_result)
|
# 如果当前已有事件循环在运行,使用它
|
||||||
for triple in entity_extract_result:
|
loop = asyncio.get_running_loop()
|
||||||
if len(triple) != 3 or (triple[0] is None or triple[1] is None or triple[2] is None) or "" in triple:
|
future = asyncio.run_coroutine_threadsafe(
|
||||||
|
llm_req.generate_response_async(rdf_extract_context), loop
|
||||||
|
)
|
||||||
|
response, (reasoning_content, model_name) = future.result()
|
||||||
|
except RuntimeError:
|
||||||
|
# 如果没有运行中的事件循环,直接使用 asyncio.run
|
||||||
|
response, (reasoning_content, model_name) = asyncio.run(
|
||||||
|
llm_req.generate_response_async(rdf_extract_context)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 添加调试日志
|
||||||
|
logger.debug(f"RDF LLM返回的原始响应: {response}")
|
||||||
|
|
||||||
|
rdf_triple_result = _extract_json_from_text(response)
|
||||||
|
|
||||||
|
# 检查返回的是否为有效的三元组列表
|
||||||
|
if not isinstance(rdf_triple_result, list):
|
||||||
|
# 如果不是列表,可能是字典格式,尝试从中提取列表
|
||||||
|
if isinstance(rdf_triple_result, dict):
|
||||||
|
# 尝试常见的键名
|
||||||
|
for key in ['triples', 'result', 'data', 'items']:
|
||||||
|
if key in rdf_triple_result and isinstance(rdf_triple_result[key], list):
|
||||||
|
rdf_triple_result = rdf_triple_result[key]
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# 如果找不到合适的列表,抛出异常
|
||||||
|
raise Exception(f"RDF三元组提取结果格式错误,期望列表但得到: {type(rdf_triple_result)}")
|
||||||
|
else:
|
||||||
|
raise Exception(f"RDF三元组提取结果格式错误,期望列表但得到: {type(rdf_triple_result)}")
|
||||||
|
|
||||||
|
# 验证三元组格式
|
||||||
|
for triple in rdf_triple_result:
|
||||||
|
if not isinstance(triple, list) or len(triple) != 3 or (triple[0] is None or triple[1] is None or triple[2] is None) or "" in triple:
|
||||||
raise Exception("RDF提取结果格式错误")
|
raise Exception("RDF提取结果格式错误")
|
||||||
|
|
||||||
return entity_extract_result
|
return rdf_triple_result
|
||||||
|
|
||||||
|
|
||||||
def info_extract_from_str(
|
def info_extract_from_str(
|
||||||
|
|||||||
@@ -184,10 +184,10 @@ class KGManager:
|
|||||||
progress.update(task, advance=1)
|
progress.update(task, advance=1)
|
||||||
continue
|
continue
|
||||||
ent = embedding_manager.entities_embedding_store.store.get(ent_hash)
|
ent = embedding_manager.entities_embedding_store.store.get(ent_hash)
|
||||||
assert isinstance(ent, EmbeddingStoreItem)
|
|
||||||
if ent is None:
|
if ent is None:
|
||||||
progress.update(task, advance=1)
|
progress.update(task, advance=1)
|
||||||
continue
|
continue
|
||||||
|
assert isinstance(ent, EmbeddingStoreItem)
|
||||||
# 查询相似实体
|
# 查询相似实体
|
||||||
similar_ents = embedding_manager.entities_embedding_store.search_top_k(
|
similar_ents = embedding_manager.entities_embedding_store.search_top_k(
|
||||||
ent.embedding, global_config["rag"]["params"]["synonym_search_top_k"]
|
ent.embedding, global_config["rag"]["params"]["synonym_search_top_k"]
|
||||||
@@ -265,7 +265,10 @@ class KGManager:
|
|||||||
if node_hash not in existed_nodes:
|
if node_hash not in existed_nodes:
|
||||||
if node_hash.startswith(local_storage['ent_namespace']):
|
if node_hash.startswith(local_storage['ent_namespace']):
|
||||||
# 新增实体节点
|
# 新增实体节点
|
||||||
node = embedding_manager.entities_embedding_store.store[node_hash]
|
node = embedding_manager.entities_embedding_store.store.get(node_hash)
|
||||||
|
if node is None:
|
||||||
|
logger.warning(f"实体节点 {node_hash} 在嵌入库中不存在,跳过")
|
||||||
|
continue
|
||||||
assert isinstance(node, EmbeddingStoreItem)
|
assert isinstance(node, EmbeddingStoreItem)
|
||||||
node_item = self.graph[node_hash]
|
node_item = self.graph[node_hash]
|
||||||
node_item["content"] = node.str
|
node_item["content"] = node.str
|
||||||
@@ -274,7 +277,10 @@ class KGManager:
|
|||||||
self.graph.update_node(node_item)
|
self.graph.update_node(node_item)
|
||||||
elif node_hash.startswith(local_storage['pg_namespace']):
|
elif node_hash.startswith(local_storage['pg_namespace']):
|
||||||
# 新增文段节点
|
# 新增文段节点
|
||||||
node = embedding_manager.paragraphs_embedding_store.store[node_hash]
|
node = embedding_manager.paragraphs_embedding_store.store.get(node_hash)
|
||||||
|
if node is None:
|
||||||
|
logger.warning(f"段落节点 {node_hash} 在嵌入库中不存在,跳过")
|
||||||
|
continue
|
||||||
assert isinstance(node, EmbeddingStoreItem)
|
assert isinstance(node, EmbeddingStoreItem)
|
||||||
content = node.str.replace("\n", " ")
|
content = node.str.replace("\n", " ")
|
||||||
node_item = self.graph[node_hash]
|
node_item = self.graph[node_hash]
|
||||||
|
|||||||
@@ -11,12 +11,14 @@ entity_extract_system_prompt = """你是一个性能优异的实体提取系统
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def build_entity_extract_context(paragraph: str) -> list[LLMMessage]:
|
def build_entity_extract_context(paragraph: str) -> str:
|
||||||
messages = [
|
"""构建实体提取的完整提示文本"""
|
||||||
LLMMessage("system", entity_extract_system_prompt).to_dict(),
|
return f"""{entity_extract_system_prompt}
|
||||||
LLMMessage("user", f"""段落:\n```\n{paragraph}```""").to_dict(),
|
|
||||||
]
|
段落:
|
||||||
return messages
|
```
|
||||||
|
{paragraph}
|
||||||
|
```"""
|
||||||
|
|
||||||
|
|
||||||
rdf_triple_extract_system_prompt = """你是一个性能优异的RDF(资源描述框架,由节点和边组成,节点表示实体/资源、属性,边则表示了实体和实体之间的关系以及实体和属性的关系。)构造系统。你的任务是根据给定的段落和实体列表构建RDF图。
|
rdf_triple_extract_system_prompt = """你是一个性能优异的RDF(资源描述框架,由节点和边组成,节点表示实体/资源、属性,边则表示了实体和实体之间的关系以及实体和属性的关系。)构造系统。你的任务是根据给定的段落和实体列表构建RDF图。
|
||||||
@@ -36,12 +38,19 @@ rdf_triple_extract_system_prompt = """你是一个性能优异的RDF(资源描
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def build_rdf_triple_extract_context(paragraph: str, entities: str) -> list[LLMMessage]:
|
def build_rdf_triple_extract_context(paragraph: str, entities: str) -> str:
|
||||||
messages = [
|
"""构建RDF三元组提取的完整提示文本"""
|
||||||
LLMMessage("system", rdf_triple_extract_system_prompt).to_dict(),
|
return f"""{rdf_triple_extract_system_prompt}
|
||||||
LLMMessage("user", f"""段落:\n```\n{paragraph}```\n\n实体列表:\n```\n{entities}```""").to_dict(),
|
|
||||||
]
|
段落:
|
||||||
return messages
|
```
|
||||||
|
{paragraph}
|
||||||
|
```
|
||||||
|
|
||||||
|
实体列表:
|
||||||
|
```
|
||||||
|
{entities}
|
||||||
|
```"""
|
||||||
|
|
||||||
|
|
||||||
qa_system_prompt = """
|
qa_system_prompt = """
|
||||||
|
|||||||
@@ -9,18 +9,20 @@ from src.common.logger import get_logger
|
|||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.common.database.database_model import Memory # Peewee Models导入
|
from src.common.database.database_model import Memory # Peewee Models导入
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class MemoryItem:
|
class MemoryItem:
|
||||||
def __init__(self,memory_id:str,chat_id:str,memory_text:str,keywords:list[str]):
|
def __init__(self, memory_id: str, chat_id: str, memory_text: str, keywords: list[str]):
|
||||||
self.memory_id = memory_id
|
self.memory_id = memory_id
|
||||||
self.chat_id = chat_id
|
self.chat_id = chat_id
|
||||||
self.memory_text:str = memory_text
|
self.memory_text: str = memory_text
|
||||||
self.keywords:list[str] = keywords
|
self.keywords: list[str] = keywords
|
||||||
self.create_time:float = time.time()
|
self.create_time: float = time.time()
|
||||||
self.last_view_time:float = time.time()
|
self.last_view_time: float = time.time()
|
||||||
|
|
||||||
|
|
||||||
class MemoryManager:
|
class MemoryManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -28,11 +30,8 @@ class MemoryManager:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class InstantMemory:
|
class InstantMemory:
|
||||||
def __init__(self,chat_id):
|
def __init__(self, chat_id):
|
||||||
self.chat_id = chat_id
|
self.chat_id = chat_id
|
||||||
self.last_view_time = time.time()
|
self.last_view_time = time.time()
|
||||||
self.summary_model = LLMRequest(
|
self.summary_model = LLMRequest(
|
||||||
@@ -41,7 +40,7 @@ class InstantMemory:
|
|||||||
request_type="memory.summary",
|
request_type="memory.summary",
|
||||||
)
|
)
|
||||||
|
|
||||||
async def if_need_build(self,text):
|
async def if_need_build(self, text):
|
||||||
prompt = f"""
|
prompt = f"""
|
||||||
请判断以下内容中是否有值得记忆的信息,如果有,请输出1,否则输出0
|
请判断以下内容中是否有值得记忆的信息,如果有,请输出1,否则输出0
|
||||||
{text}
|
{text}
|
||||||
@@ -49,11 +48,10 @@ class InstantMemory:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response,_ = await self.summary_model.generate_response_async(prompt)
|
response, _ = await self.summary_model.generate_response_async(prompt)
|
||||||
print(prompt)
|
print(prompt)
|
||||||
print(response)
|
print(response)
|
||||||
|
|
||||||
|
|
||||||
if "1" in response:
|
if "1" in response:
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
@@ -62,7 +60,7 @@ class InstantMemory:
|
|||||||
logger.error(f"判断是否需要记忆出现错误:{str(e)} {traceback.format_exc()}")
|
logger.error(f"判断是否需要记忆出现错误:{str(e)} {traceback.format_exc()}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def build_memory(self,text):
|
async def build_memory(self, text):
|
||||||
prompt = f"""
|
prompt = f"""
|
||||||
以下内容中存在值得记忆的信息,请你从中总结出一段值得记忆的信息,并输出
|
以下内容中存在值得记忆的信息,请你从中总结出一段值得记忆的信息,并输出
|
||||||
{text}
|
{text}
|
||||||
@@ -73,7 +71,7 @@ class InstantMemory:
|
|||||||
}}
|
}}
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
response,_ = await self.summary_model.generate_response_async(prompt)
|
response, _ = await self.summary_model.generate_response_async(prompt)
|
||||||
print(prompt)
|
print(prompt)
|
||||||
print(response)
|
print(response)
|
||||||
if not response:
|
if not response:
|
||||||
@@ -81,15 +79,15 @@ class InstantMemory:
|
|||||||
try:
|
try:
|
||||||
repaired = repair_json(response)
|
repaired = repair_json(response)
|
||||||
result = json.loads(repaired)
|
result = json.loads(repaired)
|
||||||
memory_text = result.get('memory_text', '')
|
memory_text = result.get("memory_text", "")
|
||||||
keywords = result.get('keywords', '')
|
keywords = result.get("keywords", "")
|
||||||
if isinstance(keywords, str):
|
if isinstance(keywords, str):
|
||||||
keywords_list = [k.strip() for k in keywords.split('/') if k.strip()]
|
keywords_list = [k.strip() for k in keywords.split("/") if k.strip()]
|
||||||
elif isinstance(keywords, list):
|
elif isinstance(keywords, list):
|
||||||
keywords_list = keywords
|
keywords_list = keywords
|
||||||
else:
|
else:
|
||||||
keywords_list = []
|
keywords_list = []
|
||||||
return {'memory_text': memory_text, 'keywords': keywords_list}
|
return {"memory_text": memory_text, "keywords": keywords_list}
|
||||||
except Exception as parse_e:
|
except Exception as parse_e:
|
||||||
logger.error(f"解析记忆json失败:{str(parse_e)} {traceback.format_exc()}")
|
logger.error(f"解析记忆json失败:{str(parse_e)} {traceback.format_exc()}")
|
||||||
return None
|
return None
|
||||||
@@ -97,37 +95,37 @@ class InstantMemory:
|
|||||||
logger.error(f"构建记忆出现错误:{str(e)} {traceback.format_exc()}")
|
logger.error(f"构建记忆出现错误:{str(e)} {traceback.format_exc()}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
async def create_and_store_memory(self, text):
|
||||||
async def create_and_store_memory(self,text):
|
|
||||||
if_need = await self.if_need_build(text)
|
if_need = await self.if_need_build(text)
|
||||||
if if_need:
|
if if_need:
|
||||||
logger.info(f"需要记忆:{text}")
|
logger.info(f"需要记忆:{text}")
|
||||||
memory = await self.build_memory(text)
|
memory = await self.build_memory(text)
|
||||||
if memory and memory.get('memory_text'):
|
if memory and memory.get("memory_text"):
|
||||||
memory_id = f"{self.chat_id}_{time.time()}"
|
memory_id = f"{self.chat_id}_{time.time()}"
|
||||||
memory_item = MemoryItem(
|
memory_item = MemoryItem(
|
||||||
memory_id=memory_id,
|
memory_id=memory_id,
|
||||||
chat_id=self.chat_id,
|
chat_id=self.chat_id,
|
||||||
memory_text=memory['memory_text'],
|
memory_text=memory["memory_text"],
|
||||||
keywords=memory.get('keywords', [])
|
keywords=memory.get("keywords", []),
|
||||||
)
|
)
|
||||||
await self.store_memory(memory_item)
|
await self.store_memory(memory_item)
|
||||||
else:
|
else:
|
||||||
logger.info(f"不需要记忆:{text}")
|
logger.info(f"不需要记忆:{text}")
|
||||||
|
|
||||||
async def store_memory(self,memory_item:MemoryItem):
|
async def store_memory(self, memory_item: MemoryItem):
|
||||||
memory = Memory(
|
memory = Memory(
|
||||||
memory_id=memory_item.memory_id,
|
memory_id=memory_item.memory_id,
|
||||||
chat_id=memory_item.chat_id,
|
chat_id=memory_item.chat_id,
|
||||||
memory_text=memory_item.memory_text,
|
memory_text=memory_item.memory_text,
|
||||||
keywords=memory_item.keywords,
|
keywords=memory_item.keywords,
|
||||||
create_time=memory_item.create_time,
|
create_time=memory_item.create_time,
|
||||||
last_view_time=memory_item.last_view_time
|
last_view_time=memory_item.last_view_time,
|
||||||
)
|
)
|
||||||
memory.save()
|
memory.save()
|
||||||
|
|
||||||
async def get_memory(self,target:str):
|
async def get_memory(self, target: str):
|
||||||
from json_repair import repair_json
|
from json_repair import repair_json
|
||||||
|
|
||||||
prompt = f"""
|
prompt = f"""
|
||||||
请根据以下发言内容,判断是否需要提取记忆
|
请根据以下发言内容,判断是否需要提取记忆
|
||||||
{target}
|
{target}
|
||||||
@@ -144,7 +142,7 @@ class InstantMemory:
|
|||||||
请只输出json格式,不要输出其他多余内容
|
请只输出json格式,不要输出其他多余内容
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
response,_ = await self.summary_model.generate_response_async(prompt)
|
response, _ = await self.summary_model.generate_response_async(prompt)
|
||||||
print(prompt)
|
print(prompt)
|
||||||
print(response)
|
print(response)
|
||||||
if not response:
|
if not response:
|
||||||
@@ -153,15 +151,15 @@ class InstantMemory:
|
|||||||
repaired = repair_json(response)
|
repaired = repair_json(response)
|
||||||
result = json.loads(repaired)
|
result = json.loads(repaired)
|
||||||
# 解析keywords
|
# 解析keywords
|
||||||
keywords = result.get('keywords', '')
|
keywords = result.get("keywords", "")
|
||||||
if isinstance(keywords, str):
|
if isinstance(keywords, str):
|
||||||
keywords_list = [k.strip() for k in keywords.split('/') if k.strip()]
|
keywords_list = [k.strip() for k in keywords.split("/") if k.strip()]
|
||||||
elif isinstance(keywords, list):
|
elif isinstance(keywords, list):
|
||||||
keywords_list = keywords
|
keywords_list = keywords
|
||||||
else:
|
else:
|
||||||
keywords_list = []
|
keywords_list = []
|
||||||
# 解析time为时间段
|
# 解析time为时间段
|
||||||
time_str = result.get('time', '').strip()
|
time_str = result.get("time", "").strip()
|
||||||
start_time, end_time = self._parse_time_range(time_str)
|
start_time, end_time = self._parse_time_range(time_str)
|
||||||
logger.info(f"start_time: {start_time}, end_time: {end_time}")
|
logger.info(f"start_time: {start_time}, end_time: {end_time}")
|
||||||
# 检索包含关键词的记忆
|
# 检索包含关键词的记忆
|
||||||
@@ -170,16 +168,15 @@ class InstantMemory:
|
|||||||
start_ts = start_time.timestamp()
|
start_ts = start_time.timestamp()
|
||||||
end_ts = end_time.timestamp()
|
end_ts = end_time.timestamp()
|
||||||
query = Memory.select().where(
|
query = Memory.select().where(
|
||||||
(Memory.chat_id == self.chat_id) &
|
(Memory.chat_id == self.chat_id)
|
||||||
(Memory.create_time >= start_ts) &
|
& (Memory.create_time >= start_ts) # type: ignore
|
||||||
(Memory.create_time < end_ts)
|
& (Memory.create_time < end_ts) # type: ignore
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
query = Memory.select().where(Memory.chat_id == self.chat_id)
|
query = Memory.select().where(Memory.chat_id == self.chat_id)
|
||||||
|
|
||||||
|
|
||||||
for mem in query:
|
for mem in query:
|
||||||
#对每条记忆
|
# 对每条记忆
|
||||||
mem_keywords = mem.keywords or []
|
mem_keywords = mem.keywords or []
|
||||||
parsed = ast.literal_eval(mem_keywords)
|
parsed = ast.literal_eval(mem_keywords)
|
||||||
if isinstance(parsed, list):
|
if isinstance(parsed, list):
|
||||||
@@ -212,6 +209,7 @@ class InstantMemory:
|
|||||||
- 空字符串:返回(None, None)
|
- 空字符串:返回(None, None)
|
||||||
"""
|
"""
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
if not time_str:
|
if not time_str:
|
||||||
return 0, now
|
return 0, now
|
||||||
@@ -251,7 +249,7 @@ class InstantMemory:
|
|||||||
if m:
|
if m:
|
||||||
months = int(m.group(1))
|
months = int(m.group(1))
|
||||||
# 近似每月30天
|
# 近似每月30天
|
||||||
start = (now - timedelta(days=months*30)).replace(hour=0, minute=0, second=0, microsecond=0)
|
start = (now - timedelta(days=months * 30)).replace(hour=0, minute=0, second=0, microsecond=0)
|
||||||
end = start + timedelta(days=1)
|
end = start + timedelta(days=1)
|
||||||
return start, end
|
return start, end
|
||||||
# 其他无法解析
|
# 其他无法解析
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ class ChatMessageContext:
|
|||||||
def get_template_name(self) -> Optional[str]:
|
def get_template_name(self) -> Optional[str]:
|
||||||
"""获取模板名称"""
|
"""获取模板名称"""
|
||||||
if self.message.message_info.template_info and not self.message.message_info.template_info.template_default:
|
if self.message.message_info.template_info and not self.message.message_info.template_info.template_default:
|
||||||
return self.message.message_info.template_info.template_name
|
return self.message.message_info.template_info.template_name # type: ignore
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_last_message(self) -> "MessageRecv":
|
def get_last_message(self) -> "MessageRecv":
|
||||||
|
|||||||
@@ -181,6 +181,7 @@ class MessageRecv(Message):
|
|||||||
logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}")
|
logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}")
|
||||||
return f"[处理失败的{segment.type}消息]"
|
return f"[处理失败的{segment.type}消息]"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MessageRecvS4U(MessageRecv):
|
class MessageRecvS4U(MessageRecv):
|
||||||
def __init__(self, message_dict: dict[str, Any]):
|
def __init__(self, message_dict: dict[str, Any]):
|
||||||
@@ -254,7 +255,7 @@ class MessageRecvS4U(MessageRecv):
|
|||||||
elif segment.type == "gift":
|
elif segment.type == "gift":
|
||||||
self.is_gift = True
|
self.is_gift = True
|
||||||
# 解析gift_info,格式为"名称:数量"
|
# 解析gift_info,格式为"名称:数量"
|
||||||
name, count = segment.data.split(":", 1)
|
name, count = segment.data.split(":", 1) # type: ignore
|
||||||
self.gift_info = segment.data
|
self.gift_info = segment.data
|
||||||
self.gift_name = name.strip()
|
self.gift_name = name.strip()
|
||||||
self.gift_count = int(count.strip())
|
self.gift_count = int(count.strip())
|
||||||
@@ -267,12 +268,14 @@ class MessageRecvS4U(MessageRecv):
|
|||||||
elif segment.type == "superchat":
|
elif segment.type == "superchat":
|
||||||
self.is_superchat = True
|
self.is_superchat = True
|
||||||
self.superchat_info = segment.data
|
self.superchat_info = segment.data
|
||||||
price,message_text = segment.data.split(":", 1)
|
price, message_text = segment.data.split(":", 1) # type: ignore
|
||||||
self.superchat_price = price.strip()
|
self.superchat_price = price.strip()
|
||||||
self.superchat_message_text = message_text.strip()
|
self.superchat_message_text = message_text.strip()
|
||||||
|
|
||||||
self.processed_plain_text = str(self.superchat_message_text)
|
self.processed_plain_text = str(self.superchat_message_text)
|
||||||
self.processed_plain_text += f"(注意:这是一条超级弹幕信息,价值{self.superchat_price}元,请你认真回复)"
|
self.processed_plain_text += (
|
||||||
|
f"(注意:这是一条超级弹幕信息,价值{self.superchat_price}元,请你认真回复)"
|
||||||
|
)
|
||||||
|
|
||||||
return self.processed_plain_text
|
return self.processed_plain_text
|
||||||
elif segment.type == "screen":
|
elif segment.type == "screen":
|
||||||
|
|||||||
@@ -80,7 +80,7 @@ class ActionManager:
|
|||||||
chat_stream: ChatStream,
|
chat_stream: ChatStream,
|
||||||
log_prefix: str,
|
log_prefix: str,
|
||||||
shutting_down: bool = False,
|
shutting_down: bool = False,
|
||||||
action_message: dict = None,
|
action_message: Optional[dict] = None,
|
||||||
) -> Optional[BaseAction]:
|
) -> Optional[BaseAction]:
|
||||||
"""
|
"""
|
||||||
创建动作处理器实例
|
创建动作处理器实例
|
||||||
|
|||||||
@@ -252,7 +252,7 @@ def _build_readable_messages_internal(
|
|||||||
pic_id_mapping: Optional[Dict[str, str]] = None,
|
pic_id_mapping: Optional[Dict[str, str]] = None,
|
||||||
pic_counter: int = 1,
|
pic_counter: int = 1,
|
||||||
show_pic: bool = True,
|
show_pic: bool = True,
|
||||||
message_id_list: List[Dict[str, Any]] = None,
|
message_id_list: Optional[List[Dict[str, Any]]] = None,
|
||||||
) -> Tuple[str, List[Tuple[float, str, str]], Dict[str, str], int]:
|
) -> Tuple[str, List[Tuple[float, str, str]], Dict[str, str], int]:
|
||||||
"""
|
"""
|
||||||
内部辅助函数,构建可读消息字符串和原始消息详情列表。
|
内部辅助函数,构建可读消息字符串和原始消息详情列表。
|
||||||
@@ -615,7 +615,7 @@ def build_readable_actions(actions: List[Dict[str, Any]]) -> str:
|
|||||||
for action in actions:
|
for action in actions:
|
||||||
action_time = action.get("time", current_time)
|
action_time = action.get("time", current_time)
|
||||||
action_name = action.get("action_name", "未知动作")
|
action_name = action.get("action_name", "未知动作")
|
||||||
if action_name == "no_action" or action_name == "no_reply":
|
if action_name in ["no_action", "no_reply"]:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
action_prompt_display = action.get("action_prompt_display", "无具体内容")
|
action_prompt_display = action.get("action_prompt_display", "无具体内容")
|
||||||
@@ -697,7 +697,7 @@ def build_readable_messages(
|
|||||||
truncate: bool = False,
|
truncate: bool = False,
|
||||||
show_actions: bool = False,
|
show_actions: bool = False,
|
||||||
show_pic: bool = True,
|
show_pic: bool = True,
|
||||||
message_id_list: List[Dict[str, Any]] = None,
|
message_id_list: Optional[List[Dict[str, Any]]] = None,
|
||||||
) -> str: # sourcery skip: extract-method
|
) -> str: # sourcery skip: extract-method
|
||||||
"""
|
"""
|
||||||
将消息列表转换为可读的文本格式。
|
将消息列表转换为可读的文本格式。
|
||||||
|
|||||||
@@ -1211,7 +1211,7 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
f.write(html_template)
|
f.write(html_template)
|
||||||
|
|
||||||
def _generate_focus_tab(self, stat: dict[str, Any]) -> str:
|
def _generate_focus_tab(self, stat: dict[str, Any]) -> str:
|
||||||
# sourcery skip: for-append-to-extend, list-comprehension, use-any
|
# sourcery skip: for-append-to-extend, list-comprehension, use-any, use-named-expression, use-next
|
||||||
"""生成Focus统计独立分页的HTML内容"""
|
"""生成Focus统计独立分页的HTML内容"""
|
||||||
|
|
||||||
# 为每个时间段准备Focus数据
|
# 为每个时间段准备Focus数据
|
||||||
@@ -1559,6 +1559,7 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def _generate_versions_tab(self, stat: dict[str, Any]) -> str:
|
def _generate_versions_tab(self, stat: dict[str, Any]) -> str:
|
||||||
|
# sourcery skip: use-named-expression, use-next
|
||||||
"""生成版本对比独立分页的HTML内容"""
|
"""生成版本对比独立分页的HTML内容"""
|
||||||
|
|
||||||
# 为每个时间段准备版本对比数据
|
# 为每个时间段准备版本对比数据
|
||||||
@@ -2306,13 +2307,13 @@ class AsyncStatisticOutputTask(AsyncTask):
|
|||||||
|
|
||||||
# 复用 StatisticOutputTask 的所有方法
|
# 复用 StatisticOutputTask 的所有方法
|
||||||
def _collect_all_statistics(self, now: datetime):
|
def _collect_all_statistics(self, now: datetime):
|
||||||
return StatisticOutputTask._collect_all_statistics(self, now)
|
return StatisticOutputTask._collect_all_statistics(self, now) # type: ignore
|
||||||
|
|
||||||
def _statistic_console_output(self, stats: Dict[str, Any], now: datetime):
|
def _statistic_console_output(self, stats: Dict[str, Any], now: datetime):
|
||||||
return StatisticOutputTask._statistic_console_output(self, stats, now)
|
return StatisticOutputTask._statistic_console_output(self, stats, now) # type: ignore
|
||||||
|
|
||||||
def _generate_html_report(self, stats: dict[str, Any], now: datetime):
|
def _generate_html_report(self, stats: dict[str, Any], now: datetime):
|
||||||
return StatisticOutputTask._generate_html_report(self, stats, now)
|
return StatisticOutputTask._generate_html_report(self, stats, now) # type: ignore
|
||||||
|
|
||||||
# 其他需要的方法也可以类似复用...
|
# 其他需要的方法也可以类似复用...
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -2324,10 +2325,10 @@ class AsyncStatisticOutputTask(AsyncTask):
|
|||||||
return StatisticOutputTask._collect_online_time_for_period(collect_period, now)
|
return StatisticOutputTask._collect_online_time_for_period(collect_period, now)
|
||||||
|
|
||||||
def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
|
def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
|
||||||
return StatisticOutputTask._collect_message_count_for_period(self, collect_period)
|
return StatisticOutputTask._collect_message_count_for_period(self, collect_period) # type: ignore
|
||||||
|
|
||||||
def _collect_focus_statistics_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
|
def _collect_focus_statistics_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
|
||||||
return StatisticOutputTask._collect_focus_statistics_for_period(self, collect_period)
|
return StatisticOutputTask._collect_focus_statistics_for_period(self, collect_period) # type: ignore
|
||||||
|
|
||||||
def _process_focus_file_data(
|
def _process_focus_file_data(
|
||||||
self,
|
self,
|
||||||
@@ -2336,10 +2337,10 @@ class AsyncStatisticOutputTask(AsyncTask):
|
|||||||
collect_period: List[Tuple[str, datetime]],
|
collect_period: List[Tuple[str, datetime]],
|
||||||
file_time: datetime,
|
file_time: datetime,
|
||||||
):
|
):
|
||||||
return StatisticOutputTask._process_focus_file_data(self, cycles_data, stats, collect_period, file_time)
|
return StatisticOutputTask._process_focus_file_data(self, cycles_data, stats, collect_period, file_time) # type: ignore
|
||||||
|
|
||||||
def _calculate_focus_averages(self, stats: Dict[str, Any]):
|
def _calculate_focus_averages(self, stats: Dict[str, Any]):
|
||||||
return StatisticOutputTask._calculate_focus_averages(self, stats)
|
return StatisticOutputTask._calculate_focus_averages(self, stats) # type: ignore
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _format_total_stat(stats: Dict[str, Any]) -> str:
|
def _format_total_stat(stats: Dict[str, Any]) -> str:
|
||||||
@@ -2347,31 +2348,31 @@ class AsyncStatisticOutputTask(AsyncTask):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _format_model_classified_stat(stats: Dict[str, Any]) -> str:
|
def _format_model_classified_stat(stats: Dict[str, Any]) -> str:
|
||||||
return StatisticOutputTask._format_model_classified_stat(stats)
|
return StatisticOutputTask._format_model_classified_stat(stats) # type: ignore
|
||||||
|
|
||||||
def _format_chat_stat(self, stats: Dict[str, Any]) -> str:
|
def _format_chat_stat(self, stats: Dict[str, Any]) -> str:
|
||||||
return StatisticOutputTask._format_chat_stat(self, stats)
|
return StatisticOutputTask._format_chat_stat(self, stats) # type: ignore
|
||||||
|
|
||||||
def _format_focus_stat(self, stats: Dict[str, Any]) -> str:
|
def _format_focus_stat(self, stats: Dict[str, Any]) -> str:
|
||||||
return StatisticOutputTask._format_focus_stat(self, stats)
|
return StatisticOutputTask._format_focus_stat(self, stats) # type: ignore
|
||||||
|
|
||||||
def _generate_chart_data(self, stat: dict[str, Any]) -> dict:
|
def _generate_chart_data(self, stat: dict[str, Any]) -> dict:
|
||||||
return StatisticOutputTask._generate_chart_data(self, stat)
|
return StatisticOutputTask._generate_chart_data(self, stat) # type: ignore
|
||||||
|
|
||||||
def _collect_interval_data(self, now: datetime, hours: int, interval_minutes: int) -> dict:
|
def _collect_interval_data(self, now: datetime, hours: int, interval_minutes: int) -> dict:
|
||||||
return StatisticOutputTask._collect_interval_data(self, now, hours, interval_minutes)
|
return StatisticOutputTask._collect_interval_data(self, now, hours, interval_minutes) # type: ignore
|
||||||
|
|
||||||
def _generate_chart_tab(self, chart_data: dict) -> str:
|
def _generate_chart_tab(self, chart_data: dict) -> str:
|
||||||
return StatisticOutputTask._generate_chart_tab(self, chart_data)
|
return StatisticOutputTask._generate_chart_tab(self, chart_data) # type: ignore
|
||||||
|
|
||||||
def _get_chat_display_name_from_id(self, chat_id: str) -> str:
|
def _get_chat_display_name_from_id(self, chat_id: str) -> str:
|
||||||
return StatisticOutputTask._get_chat_display_name_from_id(self, chat_id)
|
return StatisticOutputTask._get_chat_display_name_from_id(self, chat_id) # type: ignore
|
||||||
|
|
||||||
def _generate_focus_tab(self, stat: dict[str, Any]) -> str:
|
def _generate_focus_tab(self, stat: dict[str, Any]) -> str:
|
||||||
return StatisticOutputTask._generate_focus_tab(self, stat)
|
return StatisticOutputTask._generate_focus_tab(self, stat) # type: ignore
|
||||||
|
|
||||||
def _generate_versions_tab(self, stat: dict[str, Any]) -> str:
|
def _generate_versions_tab(self, stat: dict[str, Any]) -> str:
|
||||||
return StatisticOutputTask._generate_versions_tab(self, stat)
|
return StatisticOutputTask._generate_versions_tab(self, stat) # type: ignore
|
||||||
|
|
||||||
def _convert_defaultdict_to_dict(self, data):
|
def _convert_defaultdict_to_dict(self, data):
|
||||||
return StatisticOutputTask._convert_defaultdict_to_dict(self, data)
|
return StatisticOutputTask._convert_defaultdict_to_dict(self, data) # type: ignore
|
||||||
|
|||||||
@@ -2,14 +2,13 @@ import importlib
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional, Any
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.chat.message_receive.chat_stream import ChatStream, GroupInfo
|
from src.chat.message_receive.chat_stream import ChatStream, GroupInfo
|
||||||
from src.chat.message_receive.message import MessageRecv
|
|
||||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
@@ -54,7 +53,7 @@ class WillingInfo:
|
|||||||
interested_rate (float): 兴趣度
|
interested_rate (float): 兴趣度
|
||||||
"""
|
"""
|
||||||
|
|
||||||
message: MessageRecv
|
message: Dict[str, Any] # 原始消息数据
|
||||||
chat: ChatStream
|
chat: ChatStream
|
||||||
person_info_manager: PersonInfoManager
|
person_info_manager: PersonInfoManager
|
||||||
chat_id: str
|
chat_id: str
|
||||||
|
|||||||
@@ -65,7 +65,7 @@ class ChatStreams(BaseModel):
|
|||||||
# user_cardname 可能为空字符串或不存在,设置 null=True 更具灵活性。
|
# user_cardname 可能为空字符串或不存在,设置 null=True 更具灵活性。
|
||||||
user_cardname = TextField(null=True)
|
user_cardname = TextField(null=True)
|
||||||
|
|
||||||
class Meta:
|
class Meta: # type: ignore
|
||||||
# 如果 BaseModel.Meta.database 已设置,则此模型将继承该数据库配置。
|
# 如果 BaseModel.Meta.database 已设置,则此模型将继承该数据库配置。
|
||||||
# 如果不使用带有数据库实例的 BaseModel,或者想覆盖它,
|
# 如果不使用带有数据库实例的 BaseModel,或者想覆盖它,
|
||||||
# 请取消注释并在下面设置数据库实例:
|
# 请取消注释并在下面设置数据库实例:
|
||||||
@@ -89,7 +89,7 @@ class LLMUsage(BaseModel):
|
|||||||
status = TextField()
|
status = TextField()
|
||||||
timestamp = DateTimeField(index=True) # 更改为 DateTimeField 并添加索引
|
timestamp = DateTimeField(index=True) # 更改为 DateTimeField 并添加索引
|
||||||
|
|
||||||
class Meta:
|
class Meta: # type: ignore
|
||||||
# 如果 BaseModel.Meta.database 已设置,则此模型将继承该数据库配置。
|
# 如果 BaseModel.Meta.database 已设置,则此模型将继承该数据库配置。
|
||||||
# database = db
|
# database = db
|
||||||
table_name = "llm_usage"
|
table_name = "llm_usage"
|
||||||
@@ -112,7 +112,7 @@ class Emoji(BaseModel):
|
|||||||
usage_count = IntegerField(default=0) # 使用次数(被使用的次数)
|
usage_count = IntegerField(default=0) # 使用次数(被使用的次数)
|
||||||
last_used_time = FloatField(null=True) # 上次使用时间
|
last_used_time = FloatField(null=True) # 上次使用时间
|
||||||
|
|
||||||
class Meta:
|
class Meta: # type: ignore
|
||||||
# database = db # 继承自 BaseModel
|
# database = db # 继承自 BaseModel
|
||||||
table_name = "emoji"
|
table_name = "emoji"
|
||||||
|
|
||||||
@@ -162,7 +162,8 @@ class Messages(BaseModel):
|
|||||||
is_emoji = BooleanField(default=False)
|
is_emoji = BooleanField(default=False)
|
||||||
is_picid = BooleanField(default=False)
|
is_picid = BooleanField(default=False)
|
||||||
is_command = BooleanField(default=False)
|
is_command = BooleanField(default=False)
|
||||||
class Meta:
|
|
||||||
|
class Meta: # type: ignore
|
||||||
# database = db # 继承自 BaseModel
|
# database = db # 继承自 BaseModel
|
||||||
table_name = "messages"
|
table_name = "messages"
|
||||||
|
|
||||||
@@ -186,7 +187,7 @@ class ActionRecords(BaseModel):
|
|||||||
chat_info_stream_id = TextField()
|
chat_info_stream_id = TextField()
|
||||||
chat_info_platform = TextField()
|
chat_info_platform = TextField()
|
||||||
|
|
||||||
class Meta:
|
class Meta: # type: ignore
|
||||||
# database = db # 继承自 BaseModel
|
# database = db # 继承自 BaseModel
|
||||||
table_name = "action_records"
|
table_name = "action_records"
|
||||||
|
|
||||||
@@ -206,7 +207,7 @@ class Images(BaseModel):
|
|||||||
type = TextField() # 图像类型,例如 "emoji"
|
type = TextField() # 图像类型,例如 "emoji"
|
||||||
vlm_processed = BooleanField(default=False) # 是否已经过VLM处理
|
vlm_processed = BooleanField(default=False) # 是否已经过VLM处理
|
||||||
|
|
||||||
class Meta:
|
class Meta: # type: ignore
|
||||||
table_name = "images"
|
table_name = "images"
|
||||||
|
|
||||||
|
|
||||||
@@ -220,7 +221,7 @@ class ImageDescriptions(BaseModel):
|
|||||||
description = TextField() # 图像的描述
|
description = TextField() # 图像的描述
|
||||||
timestamp = FloatField() # 时间戳
|
timestamp = FloatField() # 时间戳
|
||||||
|
|
||||||
class Meta:
|
class Meta: # type: ignore
|
||||||
# database = db # 继承自 BaseModel
|
# database = db # 继承自 BaseModel
|
||||||
table_name = "image_descriptions"
|
table_name = "image_descriptions"
|
||||||
|
|
||||||
@@ -236,7 +237,7 @@ class OnlineTime(BaseModel):
|
|||||||
start_timestamp = DateTimeField(default=datetime.datetime.now)
|
start_timestamp = DateTimeField(default=datetime.datetime.now)
|
||||||
end_timestamp = DateTimeField(index=True)
|
end_timestamp = DateTimeField(index=True)
|
||||||
|
|
||||||
class Meta:
|
class Meta: # type: ignore
|
||||||
# database = db # 继承自 BaseModel
|
# database = db # 继承自 BaseModel
|
||||||
table_name = "online_time"
|
table_name = "online_time"
|
||||||
|
|
||||||
@@ -263,10 +264,11 @@ class PersonInfo(BaseModel):
|
|||||||
last_know = FloatField(null=True) # 最后一次印象总结时间
|
last_know = FloatField(null=True) # 最后一次印象总结时间
|
||||||
attitude = IntegerField(null=True, default=50) # 态度,0-100,从非常厌恶到十分喜欢
|
attitude = IntegerField(null=True, default=50) # 态度,0-100,从非常厌恶到十分喜欢
|
||||||
|
|
||||||
class Meta:
|
class Meta: # type: ignore
|
||||||
# database = db # 继承自 BaseModel
|
# database = db # 继承自 BaseModel
|
||||||
table_name = "person_info"
|
table_name = "person_info"
|
||||||
|
|
||||||
|
|
||||||
class Memory(BaseModel):
|
class Memory(BaseModel):
|
||||||
memory_id = TextField(index=True)
|
memory_id = TextField(index=True)
|
||||||
chat_id = TextField(null=True)
|
chat_id = TextField(null=True)
|
||||||
@@ -275,9 +277,10 @@ class Memory(BaseModel):
|
|||||||
create_time = FloatField(null=True)
|
create_time = FloatField(null=True)
|
||||||
last_view_time = FloatField(null=True)
|
last_view_time = FloatField(null=True)
|
||||||
|
|
||||||
class Meta:
|
class Meta: # type: ignore
|
||||||
table_name = "memory"
|
table_name = "memory"
|
||||||
|
|
||||||
|
|
||||||
class Knowledges(BaseModel):
|
class Knowledges(BaseModel):
|
||||||
"""
|
"""
|
||||||
用于存储知识库条目的模型。
|
用于存储知识库条目的模型。
|
||||||
@@ -287,10 +290,11 @@ class Knowledges(BaseModel):
|
|||||||
embedding = TextField() # 知识内容的嵌入向量,存储为 JSON 字符串的浮点数列表
|
embedding = TextField() # 知识内容的嵌入向量,存储为 JSON 字符串的浮点数列表
|
||||||
# 可以添加其他元数据字段,如 source, create_time 等
|
# 可以添加其他元数据字段,如 source, create_time 等
|
||||||
|
|
||||||
class Meta:
|
class Meta: # type: ignore
|
||||||
# database = db # 继承自 BaseModel
|
# database = db # 继承自 BaseModel
|
||||||
table_name = "knowledges"
|
table_name = "knowledges"
|
||||||
|
|
||||||
|
|
||||||
class Expression(BaseModel):
|
class Expression(BaseModel):
|
||||||
"""
|
"""
|
||||||
用于存储表达风格的模型。
|
用于存储表达风格的模型。
|
||||||
@@ -303,9 +307,10 @@ class Expression(BaseModel):
|
|||||||
chat_id = TextField(index=True)
|
chat_id = TextField(index=True)
|
||||||
type = TextField()
|
type = TextField()
|
||||||
|
|
||||||
class Meta:
|
class Meta: # type: ignore
|
||||||
table_name = "expression"
|
table_name = "expression"
|
||||||
|
|
||||||
|
|
||||||
class ThinkingLog(BaseModel):
|
class ThinkingLog(BaseModel):
|
||||||
chat_id = TextField(index=True)
|
chat_id = TextField(index=True)
|
||||||
trigger_text = TextField(null=True)
|
trigger_text = TextField(null=True)
|
||||||
@@ -326,7 +331,7 @@ class ThinkingLog(BaseModel):
|
|||||||
# And: import datetime
|
# And: import datetime
|
||||||
created_at = DateTimeField(default=datetime.datetime.now)
|
created_at = DateTimeField(default=datetime.datetime.now)
|
||||||
|
|
||||||
class Meta:
|
class Meta: # type: ignore
|
||||||
table_name = "thinking_logs"
|
table_name = "thinking_logs"
|
||||||
|
|
||||||
|
|
||||||
@@ -341,7 +346,7 @@ class GraphNodes(BaseModel):
|
|||||||
created_time = FloatField() # 创建时间戳
|
created_time = FloatField() # 创建时间戳
|
||||||
last_modified = FloatField() # 最后修改时间戳
|
last_modified = FloatField() # 最后修改时间戳
|
||||||
|
|
||||||
class Meta:
|
class Meta: # type: ignore
|
||||||
table_name = "graph_nodes"
|
table_name = "graph_nodes"
|
||||||
|
|
||||||
|
|
||||||
@@ -357,7 +362,7 @@ class GraphEdges(BaseModel):
|
|||||||
created_time = FloatField() # 创建时间戳
|
created_time = FloatField() # 创建时间戳
|
||||||
last_modified = FloatField() # 最后修改时间戳
|
last_modified = FloatField() # 最后修改时间戳
|
||||||
|
|
||||||
class Meta:
|
class Meta: # type: ignore
|
||||||
table_name = "graph_edges"
|
table_name = "graph_edges"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -7,13 +7,13 @@ from datetime import datetime
|
|||||||
|
|
||||||
def get_key_comment(toml_table, key):
|
def get_key_comment(toml_table, key):
|
||||||
# 获取key的注释(如果有)
|
# 获取key的注释(如果有)
|
||||||
if hasattr(toml_table, 'trivia') and hasattr(toml_table.trivia, 'comment'):
|
if hasattr(toml_table, "trivia") and hasattr(toml_table.trivia, "comment"):
|
||||||
return toml_table.trivia.comment
|
return toml_table.trivia.comment
|
||||||
if hasattr(toml_table, 'value') and isinstance(toml_table.value, dict):
|
if hasattr(toml_table, "value") and isinstance(toml_table.value, dict):
|
||||||
item = toml_table.value.get(key)
|
item = toml_table.value.get(key)
|
||||||
if item is not None and hasattr(item, 'trivia'):
|
if item is not None and hasattr(item, "trivia"):
|
||||||
return item.trivia.comment
|
return item.trivia.comment
|
||||||
if hasattr(toml_table, 'keys'):
|
if hasattr(toml_table, "keys"):
|
||||||
for k in toml_table.keys():
|
for k in toml_table.keys():
|
||||||
if isinstance(k, KeyType) and k.key == key:
|
if isinstance(k, KeyType) and k.key == key:
|
||||||
return k.trivia.comment
|
return k.trivia.comment
|
||||||
@@ -36,16 +36,16 @@ def compare_dicts(new, old, path=None, new_comments=None, old_comments=None, log
|
|||||||
continue
|
continue
|
||||||
if key not in old:
|
if key not in old:
|
||||||
comment = get_key_comment(new, key)
|
comment = get_key_comment(new, key)
|
||||||
logs.append(f"新增: {'.'.join(path+[str(key)])} 注释: {comment if comment else '无'}")
|
logs.append(f"新增: {'.'.join(path + [str(key)])} 注释: {comment if comment else '无'}")
|
||||||
elif isinstance(new[key], (dict, Table)) and isinstance(old.get(key), (dict, Table)):
|
elif isinstance(new[key], (dict, Table)) and isinstance(old.get(key), (dict, Table)):
|
||||||
compare_dicts(new[key], old[key], path+[str(key)], new_comments, old_comments, logs)
|
compare_dicts(new[key], old[key], path + [str(key)], new_comments, old_comments, logs)
|
||||||
# 删减项
|
# 删减项
|
||||||
for key in old:
|
for key in old:
|
||||||
if key == "version":
|
if key == "version":
|
||||||
continue
|
continue
|
||||||
if key not in new:
|
if key not in new:
|
||||||
comment = get_key_comment(old, key)
|
comment = get_key_comment(old, key)
|
||||||
logs.append(f"删减: {'.'.join(path+[str(key)])} 注释: {comment if comment else '无'}")
|
logs.append(f"删减: {'.'.join(path + [str(key)])} 注释: {comment if comment else '无'}")
|
||||||
return logs
|
return logs
|
||||||
|
|
||||||
|
|
||||||
@@ -95,7 +95,7 @@ def update_config():
|
|||||||
if old_version and new_version and old_version == new_version:
|
if old_version and new_version and old_version == new_version:
|
||||||
print(f"检测到版本号相同 (v{old_version}),跳过更新")
|
print(f"检测到版本号相同 (v{old_version}),跳过更新")
|
||||||
# 如果version相同,恢复旧配置文件并返回
|
# 如果version相同,恢复旧配置文件并返回
|
||||||
shutil.move(old_backup_path, old_config_path)
|
shutil.move(old_backup_path, old_config_path) # type: ignore
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
print(f"检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}")
|
print(f"检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}")
|
||||||
|
|||||||
@@ -53,13 +53,13 @@ MMC_VERSION = "0.9.0-snapshot.2"
|
|||||||
|
|
||||||
def get_key_comment(toml_table, key):
|
def get_key_comment(toml_table, key):
|
||||||
# 获取key的注释(如果有)
|
# 获取key的注释(如果有)
|
||||||
if hasattr(toml_table, 'trivia') and hasattr(toml_table.trivia, 'comment'):
|
if hasattr(toml_table, "trivia") and hasattr(toml_table.trivia, "comment"):
|
||||||
return toml_table.trivia.comment
|
return toml_table.trivia.comment
|
||||||
if hasattr(toml_table, 'value') and isinstance(toml_table.value, dict):
|
if hasattr(toml_table, "value") and isinstance(toml_table.value, dict):
|
||||||
item = toml_table.value.get(key)
|
item = toml_table.value.get(key)
|
||||||
if item is not None and hasattr(item, 'trivia'):
|
if item is not None and hasattr(item, "trivia"):
|
||||||
return item.trivia.comment
|
return item.trivia.comment
|
||||||
if hasattr(toml_table, 'keys'):
|
if hasattr(toml_table, "keys"):
|
||||||
for k in toml_table.keys():
|
for k in toml_table.keys():
|
||||||
if isinstance(k, KeyType) and k.key == key:
|
if isinstance(k, KeyType) and k.key == key:
|
||||||
return k.trivia.comment
|
return k.trivia.comment
|
||||||
@@ -78,16 +78,16 @@ def compare_dicts(new, old, path=None, logs=None):
|
|||||||
continue
|
continue
|
||||||
if key not in old:
|
if key not in old:
|
||||||
comment = get_key_comment(new, key)
|
comment = get_key_comment(new, key)
|
||||||
logs.append(f"新增: {'.'.join(path+[str(key)])} 注释: {comment if comment else '无'}")
|
logs.append(f"新增: {'.'.join(path + [str(key)])} 注释: {comment if comment else '无'}")
|
||||||
elif isinstance(new[key], (dict, Table)) and isinstance(old.get(key), (dict, Table)):
|
elif isinstance(new[key], (dict, Table)) and isinstance(old.get(key), (dict, Table)):
|
||||||
compare_dicts(new[key], old[key], path+[str(key)], logs)
|
compare_dicts(new[key], old[key], path + [str(key)], logs)
|
||||||
# 删减项
|
# 删减项
|
||||||
for key in old:
|
for key in old:
|
||||||
if key == "version":
|
if key == "version":
|
||||||
continue
|
continue
|
||||||
if key not in new:
|
if key not in new:
|
||||||
comment = get_key_comment(old, key)
|
comment = get_key_comment(old, key)
|
||||||
logs.append(f"删减: {'.'.join(path+[str(key)])} 注释: {comment if comment else '无'}")
|
logs.append(f"删减: {'.'.join(path + [str(key)])} 注释: {comment if comment else '无'}")
|
||||||
return logs
|
return logs
|
||||||
|
|
||||||
|
|
||||||
@@ -99,6 +99,7 @@ def get_value_by_path(d, path):
|
|||||||
return None
|
return None
|
||||||
return d
|
return d
|
||||||
|
|
||||||
|
|
||||||
def set_value_by_path(d, path, value):
|
def set_value_by_path(d, path, value):
|
||||||
for k in path[:-1]:
|
for k in path[:-1]:
|
||||||
if k not in d or not isinstance(d[k], dict):
|
if k not in d or not isinstance(d[k], dict):
|
||||||
@@ -106,6 +107,7 @@ def set_value_by_path(d, path, value):
|
|||||||
d = d[k]
|
d = d[k]
|
||||||
d[path[-1]] = value
|
d[path[-1]] = value
|
||||||
|
|
||||||
|
|
||||||
def compare_default_values(new, old, path=None, logs=None, changes=None):
|
def compare_default_values(new, old, path=None, logs=None, changes=None):
|
||||||
# 递归比较两个dict,找出默认值变化项
|
# 递归比较两个dict,找出默认值变化项
|
||||||
if path is None:
|
if path is None:
|
||||||
@@ -119,12 +121,14 @@ def compare_default_values(new, old, path=None, logs=None, changes=None):
|
|||||||
continue
|
continue
|
||||||
if key in old:
|
if key in old:
|
||||||
if isinstance(new[key], (dict, Table)) and isinstance(old[key], (dict, Table)):
|
if isinstance(new[key], (dict, Table)) and isinstance(old[key], (dict, Table)):
|
||||||
compare_default_values(new[key], old[key], path+[str(key)], logs, changes)
|
compare_default_values(new[key], old[key], path + [str(key)], logs, changes)
|
||||||
else:
|
else:
|
||||||
# 只要值发生变化就记录
|
# 只要值发生变化就记录
|
||||||
if new[key] != old[key]:
|
if new[key] != old[key]:
|
||||||
logs.append(f"默认值变化: {'.'.join(path+[str(key)])} 旧默认值: {old[key]} 新默认值: {new[key]}")
|
logs.append(
|
||||||
changes.append((path+[str(key)], old[key], new[key]))
|
f"默认值变化: {'.'.join(path + [str(key)])} 旧默认值: {old[key]} 新默认值: {new[key]}"
|
||||||
|
)
|
||||||
|
changes.append((path + [str(key)], old[key], new[key]))
|
||||||
return logs, changes
|
return logs, changes
|
||||||
|
|
||||||
|
|
||||||
@@ -148,8 +152,8 @@ def update_config():
|
|||||||
return None
|
return None
|
||||||
with open(toml_path, "r", encoding="utf-8") as f:
|
with open(toml_path, "r", encoding="utf-8") as f:
|
||||||
doc = tomlkit.load(f)
|
doc = tomlkit.load(f)
|
||||||
if "inner" in doc and "version" in doc["inner"]:
|
if "inner" in doc and "version" in doc["inner"]: # type: ignore
|
||||||
return doc["inner"]["version"]
|
return doc["inner"]["version"] # type: ignore
|
||||||
return None
|
return None
|
||||||
|
|
||||||
template_version = get_version_from_toml(template_path)
|
template_version = get_version_from_toml(template_path)
|
||||||
@@ -186,7 +190,9 @@ def update_config():
|
|||||||
old_value = get_value_by_path(old_config, path)
|
old_value = get_value_by_path(old_config, path)
|
||||||
if old_value == old_default:
|
if old_value == old_default:
|
||||||
set_value_by_path(old_config, path, new_default)
|
set_value_by_path(old_config, path, new_default)
|
||||||
logger.info(f"已自动将配置 {'.'.join(path)} 的值从旧默认值 {old_default} 更新为新默认值 {new_default}")
|
logger.info(
|
||||||
|
f"已自动将配置 {'.'.join(path)} 的值从旧默认值 {old_default} 更新为新默认值 {new_default}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.info("未检测到模板默认值变动")
|
logger.info("未检测到模板默认值变动")
|
||||||
# 保存旧配置的变更(后续合并逻辑会用到 old_config)
|
# 保存旧配置的变更(后续合并逻辑会用到 old_config)
|
||||||
@@ -229,7 +235,9 @@ def update_config():
|
|||||||
logger.info(f"检测到配置文件版本号相同 (v{old_version}),跳过更新")
|
logger.info(f"检测到配置文件版本号相同 (v{old_version}),跳过更新")
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
logger.info(f"\n----------------------------------------\n检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}\n----------------------------------------")
|
logger.info(
|
||||||
|
f"\n----------------------------------------\n检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}\n----------------------------------------"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.info("已有配置文件未检测到版本号,可能是旧版本。将进行更新")
|
logger.info("已有配置文件未检测到版本号,可能是旧版本。将进行更新")
|
||||||
|
|
||||||
@@ -321,6 +329,7 @@ class Config(ConfigBase):
|
|||||||
debug: DebugConfig
|
debug: DebugConfig
|
||||||
custom_prompt: CustomPromptConfig
|
custom_prompt: CustomPromptConfig
|
||||||
|
|
||||||
|
|
||||||
def load_config(config_path: str) -> Config:
|
def load_config(config_path: str) -> Config:
|
||||||
"""
|
"""
|
||||||
加载配置文件
|
加载配置文件
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ class LLMRequestOff:
|
|||||||
}
|
}
|
||||||
|
|
||||||
# 发送请求到完整的 chat/completions 端点
|
# 发送请求到完整的 chat/completions 端点
|
||||||
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
|
api_url = f"{self.base_url.rstrip('/')}/chat/completions" # type: ignore
|
||||||
logger.info(f"Request URL: {api_url}") # 记录请求的 URL
|
logger.info(f"Request URL: {api_url}") # 记录请求的 URL
|
||||||
|
|
||||||
max_retries = 3
|
max_retries = 3
|
||||||
@@ -89,7 +89,7 @@ class LLMRequestOff:
|
|||||||
}
|
}
|
||||||
|
|
||||||
# 发送请求到完整的 chat/completions 端点
|
# 发送请求到完整的 chat/completions 端点
|
||||||
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
|
api_url = f"{self.base_url.rstrip('/')}/chat/completions" # type: ignore
|
||||||
logger.info(f"Request URL: {api_url}") # 记录请求的 URL
|
logger.info(f"Request URL: {api_url}") # 记录请求的 URL
|
||||||
|
|
||||||
max_retries = 3
|
max_retries = 3
|
||||||
|
|||||||
@@ -83,8 +83,8 @@ class PersonalityEvaluatorDirect:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.personality_traits = {"开放性": 0, "严谨性": 0, "外向性": 0, "宜人性": 0, "神经质": 0}
|
self.personality_traits = {"开放性": 0, "严谨性": 0, "外向性": 0, "宜人性": 0, "神经质": 0}
|
||||||
self.scenarios = []
|
self.scenarios = []
|
||||||
self.final_scores = {"开放性": 0, "严谨性": 0, "外向性": 0, "宜人性": 0, "神经质": 0}
|
self.final_scores: Dict[str, float] = {"开放性": 0, "严谨性": 0, "外向性": 0, "宜人性": 0, "神经质": 0}
|
||||||
self.dimension_counts = {trait: 0 for trait in self.final_scores.keys()}
|
self.dimension_counts = {trait: 0 for trait in self.final_scores}
|
||||||
|
|
||||||
# 为每个人格特质获取对应的场景
|
# 为每个人格特质获取对应的场景
|
||||||
for trait in PERSONALITY_SCENES:
|
for trait in PERSONALITY_SCENES:
|
||||||
@@ -119,8 +119,7 @@ class PersonalityEvaluatorDirect:
|
|||||||
# 构建维度描述
|
# 构建维度描述
|
||||||
dimension_descriptions = []
|
dimension_descriptions = []
|
||||||
for dim in dimensions:
|
for dim in dimensions:
|
||||||
desc = FACTOR_DESCRIPTIONS.get(dim, "")
|
if desc := FACTOR_DESCRIPTIONS.get(dim, ""):
|
||||||
if desc:
|
|
||||||
dimension_descriptions.append(f"- {dim}:{desc}")
|
dimension_descriptions.append(f"- {dim}:{desc}")
|
||||||
|
|
||||||
dimensions_text = "\n".join(dimension_descriptions)
|
dimensions_text = "\n".join(dimension_descriptions)
|
||||||
|
|||||||
@@ -153,14 +153,14 @@ class MainSystem:
|
|||||||
while True:
|
while True:
|
||||||
await asyncio.sleep(global_config.memory.memory_build_interval)
|
await asyncio.sleep(global_config.memory.memory_build_interval)
|
||||||
logger.info("正在进行记忆构建")
|
logger.info("正在进行记忆构建")
|
||||||
await self.hippocampus_manager.build_memory()
|
await self.hippocampus_manager.build_memory() # type: ignore
|
||||||
|
|
||||||
async def forget_memory_task(self):
|
async def forget_memory_task(self):
|
||||||
"""记忆遗忘任务"""
|
"""记忆遗忘任务"""
|
||||||
while True:
|
while True:
|
||||||
await asyncio.sleep(global_config.memory.forget_memory_interval)
|
await asyncio.sleep(global_config.memory.forget_memory_interval)
|
||||||
logger.info("[记忆遗忘] 开始遗忘记忆...")
|
logger.info("[记忆遗忘] 开始遗忘记忆...")
|
||||||
await self.hippocampus_manager.forget_memory(percentage=global_config.memory.memory_forget_percentage)
|
await self.hippocampus_manager.forget_memory(percentage=global_config.memory.memory_forget_percentage) # type: ignore
|
||||||
logger.info("[记忆遗忘] 记忆遗忘完成")
|
logger.info("[记忆遗忘] 记忆遗忘完成")
|
||||||
|
|
||||||
async def consolidate_memory_task(self):
|
async def consolidate_memory_task(self):
|
||||||
@@ -168,7 +168,7 @@ class MainSystem:
|
|||||||
while True:
|
while True:
|
||||||
await asyncio.sleep(global_config.memory.consolidate_memory_interval)
|
await asyncio.sleep(global_config.memory.consolidate_memory_interval)
|
||||||
logger.info("[记忆整合] 开始整合记忆...")
|
logger.info("[记忆整合] 开始整合记忆...")
|
||||||
await self.hippocampus_manager.consolidate_memory()
|
await self.hippocampus_manager.consolidate_memory() # type: ignore
|
||||||
logger.info("[记忆整合] 记忆整合完成")
|
logger.info("[记忆整合] 记忆整合完成")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -50,6 +50,9 @@ class ChatMood:
|
|||||||
chat_manager = get_chat_manager()
|
chat_manager = get_chat_manager()
|
||||||
self.chat_stream = chat_manager.get_stream(self.chat_id)
|
self.chat_stream = chat_manager.get_stream(self.chat_id)
|
||||||
|
|
||||||
|
if not self.chat_stream:
|
||||||
|
raise ValueError(f"Chat stream for chat_id {chat_id} not found")
|
||||||
|
|
||||||
self.log_prefix = f"[{self.chat_stream.group_info.group_name if self.chat_stream.group_info else self.chat_stream.user_info.user_nickname}]"
|
self.log_prefix = f"[{self.chat_stream.group_info.group_name if self.chat_stream.group_info else self.chat_stream.user_info.user_nickname}]"
|
||||||
|
|
||||||
self.mood_state: str = "感觉很平静"
|
self.mood_state: str = "感觉很平静"
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ SEGMENT_CLEANUP_CONFIG = {
|
|||||||
"cleanup_interval_hours": 0.5, # 清理间隔(小时)
|
"cleanup_interval_hours": 0.5, # 清理间隔(小时)
|
||||||
}
|
}
|
||||||
|
|
||||||
MAX_MESSAGE_COUNT = 80 / global_config.relationship.relation_frequency
|
MAX_MESSAGE_COUNT = int(80 / global_config.relationship.relation_frequency)
|
||||||
|
|
||||||
|
|
||||||
class RelationshipBuilder:
|
class RelationshipBuilder:
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ __all__ = [
|
|||||||
"ConfigField",
|
"ConfigField",
|
||||||
# 工具函数
|
# 工具函数
|
||||||
"ManifestValidator",
|
"ManifestValidator",
|
||||||
"ManifestGenerator",
|
# "ManifestGenerator",
|
||||||
"validate_plugin_manifest",
|
# "validate_plugin_manifest",
|
||||||
"generate_plugin_manifest",
|
# "generate_plugin_manifest",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -34,7 +34,4 @@ def register_event_plugin(cls, *args, **kwargs):
|
|||||||
|
|
||||||
用法:
|
用法:
|
||||||
@register_event_plugin
|
@register_event_plugin
|
||||||
class MyEventPlugin:
|
|
||||||
event_type = EventType.MESSAGE_RECEIVED
|
|
||||||
...
|
|
||||||
"""
|
"""
|
||||||
@@ -111,7 +111,7 @@ async def _send_to_target(
|
|||||||
is_head=True,
|
is_head=True,
|
||||||
is_emoji=(message_type == "emoji"),
|
is_emoji=(message_type == "emoji"),
|
||||||
thinking_start_time=current_time,
|
thinking_start_time=current_time,
|
||||||
reply_to = reply_to_platform_id
|
reply_to=reply_to_platform_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 发送消息
|
# 发送消息
|
||||||
@@ -137,6 +137,7 @@ async def _send_to_target(
|
|||||||
|
|
||||||
|
|
||||||
async def _find_reply_message(target_stream, reply_to: str) -> Optional[MessageRecv]:
|
async def _find_reply_message(target_stream, reply_to: str) -> Optional[MessageRecv]:
|
||||||
|
# sourcery skip: inline-variable, use-named-expression
|
||||||
"""查找要回复的消息
|
"""查找要回复的消息
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -184,14 +185,11 @@ async def _find_reply_message(target_stream, reply_to: str) -> Optional[MessageR
|
|||||||
|
|
||||||
# 检查是否有 回复<aaa:bbb> 字段
|
# 检查是否有 回复<aaa:bbb> 字段
|
||||||
reply_pattern = r"回复<([^:<>]+):([^:<>]+)>"
|
reply_pattern = r"回复<([^:<>]+):([^:<>]+)>"
|
||||||
match = re.search(reply_pattern, translate_text)
|
if match := re.search(reply_pattern, translate_text):
|
||||||
if match:
|
|
||||||
aaa = match.group(1)
|
aaa = match.group(1)
|
||||||
bbb = match.group(2)
|
bbb = match.group(2)
|
||||||
reply_person_id = get_person_info_manager().get_person_id(platform, bbb)
|
reply_person_id = get_person_info_manager().get_person_id(platform, bbb)
|
||||||
reply_person_name = await get_person_info_manager().get_value(reply_person_id, "person_name")
|
reply_person_name = await get_person_info_manager().get_value(reply_person_id, "person_name") or aaa
|
||||||
if not reply_person_name:
|
|
||||||
reply_person_name = aaa
|
|
||||||
# 在内容前加上回复信息
|
# 在内容前加上回复信息
|
||||||
translate_text = re.sub(reply_pattern, f"回复 {reply_person_name}", translate_text, count=1)
|
translate_text = re.sub(reply_pattern, f"回复 {reply_person_name}", translate_text, count=1)
|
||||||
|
|
||||||
@@ -206,9 +204,7 @@ async def _find_reply_message(target_stream, reply_to: str) -> Optional[MessageR
|
|||||||
aaa = m.group(1)
|
aaa = m.group(1)
|
||||||
bbb = m.group(2)
|
bbb = m.group(2)
|
||||||
at_person_id = get_person_info_manager().get_person_id(platform, bbb)
|
at_person_id = get_person_info_manager().get_person_id(platform, bbb)
|
||||||
at_person_name = await get_person_info_manager().get_value(at_person_id, "person_name")
|
at_person_name = await get_person_info_manager().get_value(at_person_id, "person_name") or aaa
|
||||||
if not at_person_name:
|
|
||||||
at_person_name = aaa
|
|
||||||
new_content += f"@{at_person_name}"
|
new_content += f"@{at_person_name}"
|
||||||
last_end = m.end()
|
last_end = m.end()
|
||||||
new_content += translate_text[last_end:]
|
new_content += translate_text[last_end:]
|
||||||
@@ -403,7 +399,7 @@ async def text_to_group(
|
|||||||
"""
|
"""
|
||||||
stream_id = get_chat_manager().get_stream_id(platform, group_id, True)
|
stream_id = get_chat_manager().get_stream_id(platform, group_id, True)
|
||||||
|
|
||||||
return await _send_to_target("text", text, stream_id, "", typing, reply_to, storage_message)
|
return await _send_to_target("text", text, stream_id, "", typing, reply_to, storage_message=storage_message)
|
||||||
|
|
||||||
|
|
||||||
async def text_to_user(
|
async def text_to_user(
|
||||||
@@ -427,7 +423,7 @@ async def text_to_user(
|
|||||||
bool: 是否发送成功
|
bool: 是否发送成功
|
||||||
"""
|
"""
|
||||||
stream_id = get_chat_manager().get_stream_id(platform, user_id, False)
|
stream_id = get_chat_manager().get_stream_id(platform, user_id, False)
|
||||||
return await _send_to_target("text", text, stream_id, "", typing, reply_to, storage_message)
|
return await _send_to_target("text", text, stream_id, "", typing, reply_to, storage_message=storage_message)
|
||||||
|
|
||||||
|
|
||||||
async def emoji_to_group(emoji_base64: str, group_id: str, platform: str = "qq", storage_message: bool = True) -> bool:
|
async def emoji_to_group(emoji_base64: str, group_id: str, platform: str = "qq", storage_message: bool = True) -> bool:
|
||||||
@@ -550,7 +546,9 @@ async def custom_to_group(
|
|||||||
bool: 是否发送成功
|
bool: 是否发送成功
|
||||||
"""
|
"""
|
||||||
stream_id = get_chat_manager().get_stream_id(platform, group_id, True)
|
stream_id = get_chat_manager().get_stream_id(platform, group_id, True)
|
||||||
return await _send_to_target(message_type, content, stream_id, display_message, typing, reply_to, storage_message)
|
return await _send_to_target(
|
||||||
|
message_type, content, stream_id, display_message, typing, reply_to, storage_message=storage_message
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def custom_to_user(
|
async def custom_to_user(
|
||||||
@@ -578,7 +576,9 @@ async def custom_to_user(
|
|||||||
bool: 是否发送成功
|
bool: 是否发送成功
|
||||||
"""
|
"""
|
||||||
stream_id = get_chat_manager().get_stream_id(platform, user_id, False)
|
stream_id = get_chat_manager().get_stream_id(platform, user_id, False)
|
||||||
return await _send_to_target(message_type, content, stream_id, display_message, typing, reply_to, storage_message)
|
return await _send_to_target(
|
||||||
|
message_type, content, stream_id, display_message, typing, reply_to, storage_message=storage_message
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def custom_message(
|
async def custom_message(
|
||||||
@@ -618,4 +618,6 @@ async def custom_message(
|
|||||||
await send_api.custom_message("audio", audio_base64, "123456", True, reply_to="张三:你好")
|
await send_api.custom_message("audio", audio_base64, "123456", True, reply_to="张三:你好")
|
||||||
"""
|
"""
|
||||||
stream_id = get_chat_manager().get_stream_id(platform, target_id, is_group)
|
stream_id = get_chat_manager().get_stream_id(platform, target_id, is_group)
|
||||||
return await _send_to_target(message_type, content, stream_id, display_message, typing, reply_to, storage_message)
|
return await _send_to_target(
|
||||||
|
message_type, content, stream_id, display_message, typing, reply_to, storage_message=storage_message
|
||||||
|
)
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ class BaseAction(ABC):
|
|||||||
chat_stream: ChatStream,
|
chat_stream: ChatStream,
|
||||||
log_prefix: str = "",
|
log_prefix: str = "",
|
||||||
plugin_config: Optional[dict] = None,
|
plugin_config: Optional[dict] = None,
|
||||||
action_message: dict = None,
|
action_message: Optional[dict] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""初始化Action组件
|
"""初始化Action组件
|
||||||
@@ -106,6 +106,8 @@ class BaseAction(ABC):
|
|||||||
|
|
||||||
if self.action_message:
|
if self.action_message:
|
||||||
self.has_action_message = True
|
self.has_action_message = True
|
||||||
|
else:
|
||||||
|
self.action_message = {}
|
||||||
|
|
||||||
if self.has_action_message:
|
if self.has_action_message:
|
||||||
if self.action_name != "no_reply":
|
if self.action_name != "no_reply":
|
||||||
@@ -132,8 +134,6 @@ class BaseAction(ABC):
|
|||||||
self.is_group = False
|
self.is_group = False
|
||||||
self.target_id = self.user_id
|
self.target_id = self.user_id
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
logger.debug(f"{self.log_prefix} Action组件初始化完成")
|
logger.debug(f"{self.log_prefix} Action组件初始化完成")
|
||||||
logger.info(
|
logger.info(
|
||||||
f"{self.log_prefix} 聊天信息: 类型={'群聊' if self.is_group else '私聊'}, 平台={self.platform}, 目标={self.target_id}"
|
f"{self.log_prefix} 聊天信息: 类型={'群聊' if self.is_group else '私聊'}, 平台={self.platform}, 目标={self.target_id}"
|
||||||
@@ -199,7 +199,9 @@ class BaseAction(ABC):
|
|||||||
logger.error(f"{self.log_prefix} 等待新消息时发生错误: {e}")
|
logger.error(f"{self.log_prefix} 等待新消息时发生错误: {e}")
|
||||||
return False, f"等待新消息失败: {str(e)}"
|
return False, f"等待新消息失败: {str(e)}"
|
||||||
|
|
||||||
async def send_text(self, content: str, reply_to: str = "", reply_to_platform_id: str = "", typing: bool = False) -> bool:
|
async def send_text(
|
||||||
|
self, content: str, reply_to: str = "", reply_to_platform_id: str = "", typing: bool = False
|
||||||
|
) -> bool:
|
||||||
"""发送文本消息
|
"""发送文本消息
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -299,7 +301,7 @@ class BaseAction(ABC):
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def send_command(
|
async def send_command(
|
||||||
self, command_name: str, args: dict = None, display_message: str = None, storage_message: bool = True
|
self, command_name: str, args: Optional[dict] = None, display_message: str = "", storage_message: bool = True
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""发送命令消息
|
"""发送命令消息
|
||||||
|
|
||||||
|
|||||||
@@ -135,7 +135,7 @@ class BaseCommand(ABC):
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def send_command(
|
async def send_command(
|
||||||
self, command_name: str, args: dict = None, display_message: str = "", storage_message: bool = True
|
self, command_name: str, args: Optional[dict] = None, display_message: str = "", storage_message: bool = True
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""发送命令消息
|
"""发送命令消息
|
||||||
|
|
||||||
|
|||||||
@@ -1,18 +1,14 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import abstractmethod
|
||||||
|
|
||||||
class BaseEventsPlugin(ABC):
|
from .plugin_base import PluginBase
|
||||||
"""
|
from src.common.logger import get_logger
|
||||||
事件触发型插件基类
|
|
||||||
|
|
||||||
所有事件触发型插件都应该继承这个基类而不是 BasePlugin
|
|
||||||
|
class BaseEventPlugin(PluginBase):
|
||||||
|
"""基于事件的插件基类
|
||||||
|
|
||||||
|
所有事件类型的插件都应该继承这个基类
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@property
|
def __init__(self, *args, **kwargs):
|
||||||
@abstractmethod
|
super().__init__(*args, **kwargs)
|
||||||
def plugin_name(self) -> str:
|
|
||||||
return "" # 插件内部标识符(如 "hello_world_plugin")
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def enable_plugin(self) -> bool:
|
|
||||||
return False
|
|
||||||
|
|||||||
@@ -7,7 +7,16 @@ from src.plugin_system.base.component_types import ComponentInfo
|
|||||||
|
|
||||||
logger = get_logger("base_plugin")
|
logger = get_logger("base_plugin")
|
||||||
|
|
||||||
|
|
||||||
class BasePlugin(PluginBase):
|
class BasePlugin(PluginBase):
|
||||||
|
"""基于Action和Command的插件基类
|
||||||
|
|
||||||
|
所有上述类型的插件都应该继承这个基类,一个插件可以包含多种组件:
|
||||||
|
- Action组件:处理聊天中的动作
|
||||||
|
- Command组件:处理命令请求
|
||||||
|
- 未来可扩展:Scheduler、Listener等
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
|||||||
@@ -19,12 +19,9 @@ logger = get_logger("plugin_base")
|
|||||||
|
|
||||||
|
|
||||||
class PluginBase(ABC):
|
class PluginBase(ABC):
|
||||||
"""插件基类
|
"""插件总基类
|
||||||
|
|
||||||
所有插件都应该继承这个基类,一个插件可以包含多种组件:
|
所有衍生插件基类都应该继承自此类,这个类定义了插件的基本结构和行为。
|
||||||
- Action组件:处理聊天中的动作
|
|
||||||
- Command组件:处理命令请求
|
|
||||||
- 未来可扩展:Scheduler、Listener等
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# 插件基本信息(子类必须定义)
|
# 插件基本信息(子类必须定义)
|
||||||
|
|||||||
@@ -346,67 +346,67 @@ class ComponentRegistry:
|
|||||||
|
|
||||||
# === 状态管理方法 ===
|
# === 状态管理方法 ===
|
||||||
|
|
||||||
def enable_component(self, component_name: str, component_type: ComponentType = None) -> bool:
|
# def enable_component(self, component_name: str, component_type: ComponentType = None) -> bool:
|
||||||
# -------------------------------- NEED REFACTORING --------------------------------
|
# # -------------------------------- NEED REFACTORING --------------------------------
|
||||||
# -------------------------------- LOGIC ERROR -------------------------------------
|
# # -------------------------------- LOGIC ERROR -------------------------------------
|
||||||
"""启用组件,支持命名空间解析"""
|
# """启用组件,支持命名空间解析"""
|
||||||
# 首先尝试找到正确的命名空间化名称
|
# # 首先尝试找到正确的命名空间化名称
|
||||||
component_info = self.get_component_info(component_name, component_type)
|
# component_info = self.get_component_info(component_name, component_type)
|
||||||
if not component_info:
|
# if not component_info:
|
||||||
return False
|
# return False
|
||||||
|
|
||||||
# 根据组件类型构造正确的命名空间化名称
|
# # 根据组件类型构造正确的命名空间化名称
|
||||||
if component_info.component_type == ComponentType.ACTION:
|
# if component_info.component_type == ComponentType.ACTION:
|
||||||
namespaced_name = f"action.{component_name}" if "." not in component_name else component_name
|
# namespaced_name = f"action.{component_name}" if "." not in component_name else component_name
|
||||||
elif component_info.component_type == ComponentType.COMMAND:
|
# elif component_info.component_type == ComponentType.COMMAND:
|
||||||
namespaced_name = f"command.{component_name}" if "." not in component_name else component_name
|
# namespaced_name = f"command.{component_name}" if "." not in component_name else component_name
|
||||||
else:
|
# else:
|
||||||
namespaced_name = (
|
# namespaced_name = (
|
||||||
f"{component_info.component_type.value}.{component_name}"
|
# f"{component_info.component_type.value}.{component_name}"
|
||||||
if "." not in component_name
|
# if "." not in component_name
|
||||||
else component_name
|
# else component_name
|
||||||
)
|
# )
|
||||||
|
|
||||||
if namespaced_name in self._components:
|
# if namespaced_name in self._components:
|
||||||
self._components[namespaced_name].enabled = True
|
# self._components[namespaced_name].enabled = True
|
||||||
# 如果是Action,更新默认动作集
|
# # 如果是Action,更新默认动作集
|
||||||
# ---- HERE ----
|
# # ---- HERE ----
|
||||||
# if isinstance(component_info, ActionInfo):
|
# # if isinstance(component_info, ActionInfo):
|
||||||
# self._action_descriptions[component_name] = component_info.description
|
# # self._action_descriptions[component_name] = component_info.description
|
||||||
logger.debug(f"已启用组件: {component_name} -> {namespaced_name}")
|
# logger.debug(f"已启用组件: {component_name} -> {namespaced_name}")
|
||||||
return True
|
# return True
|
||||||
return False
|
# return False
|
||||||
|
|
||||||
def disable_component(self, component_name: str, component_type: ComponentType = None) -> bool:
|
# def disable_component(self, component_name: str, component_type: ComponentType = None) -> bool:
|
||||||
# -------------------------------- NEED REFACTORING --------------------------------
|
# # -------------------------------- NEED REFACTORING --------------------------------
|
||||||
# -------------------------------- LOGIC ERROR -------------------------------------
|
# # -------------------------------- LOGIC ERROR -------------------------------------
|
||||||
"""禁用组件,支持命名空间解析"""
|
# """禁用组件,支持命名空间解析"""
|
||||||
# 首先尝试找到正确的命名空间化名称
|
# # 首先尝试找到正确的命名空间化名称
|
||||||
component_info = self.get_component_info(component_name, component_type)
|
# component_info = self.get_component_info(component_name, component_type)
|
||||||
if not component_info:
|
# if not component_info:
|
||||||
return False
|
# return False
|
||||||
|
|
||||||
# 根据组件类型构造正确的命名空间化名称
|
# # 根据组件类型构造正确的命名空间化名称
|
||||||
if component_info.component_type == ComponentType.ACTION:
|
# if component_info.component_type == ComponentType.ACTION:
|
||||||
namespaced_name = f"action.{component_name}" if "." not in component_name else component_name
|
# namespaced_name = f"action.{component_name}" if "." not in component_name else component_name
|
||||||
elif component_info.component_type == ComponentType.COMMAND:
|
# elif component_info.component_type == ComponentType.COMMAND:
|
||||||
namespaced_name = f"command.{component_name}" if "." not in component_name else component_name
|
# namespaced_name = f"command.{component_name}" if "." not in component_name else component_name
|
||||||
else:
|
# else:
|
||||||
namespaced_name = (
|
# namespaced_name = (
|
||||||
f"{component_info.component_type.value}.{component_name}"
|
# f"{component_info.component_type.value}.{component_name}"
|
||||||
if "." not in component_name
|
# if "." not in component_name
|
||||||
else component_name
|
# else component_name
|
||||||
)
|
# )
|
||||||
|
|
||||||
if namespaced_name in self._components:
|
# if namespaced_name in self._components:
|
||||||
self._components[namespaced_name].enabled = False
|
# self._components[namespaced_name].enabled = False
|
||||||
# 如果是Action,从默认动作集中移除
|
# # 如果是Action,从默认动作集中移除
|
||||||
# ---- HERE ----
|
# # ---- HERE ----
|
||||||
# if component_name in self._action_descriptions:
|
# # if component_name in self._action_descriptions:
|
||||||
# del self._action_descriptions[component_name]
|
# # del self._action_descriptions[component_name]
|
||||||
logger.debug(f"已禁用组件: {component_name} -> {namespaced_name}")
|
# logger.debug(f"已禁用组件: {component_name} -> {namespaced_name}")
|
||||||
return True
|
# return True
|
||||||
return False
|
# return False
|
||||||
|
|
||||||
def get_registry_stats(self) -> Dict[str, Any]:
|
def get_registry_stats(self) -> Dict[str, Any]:
|
||||||
"""获取注册中心统计信息"""
|
"""获取注册中心统计信息"""
|
||||||
|
|||||||
@@ -7,7 +7,7 @@
|
|||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import importlib
|
import importlib
|
||||||
from typing import List, Dict, Tuple
|
from typing import List, Dict, Tuple, Any
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.plugin_system.base.component_types import PythonDependency
|
from src.plugin_system.base.component_types import PythonDependency
|
||||||
@@ -176,7 +176,7 @@ class DependencyManager:
|
|||||||
logger.error(f"生成requirements文件失败: {str(e)}")
|
logger.error(f"生成requirements文件失败: {str(e)}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def get_install_summary(self) -> Dict[str, any]:
|
def get_install_summary(self) -> Dict[str, Any]:
|
||||||
"""获取安装摘要"""
|
"""获取安装摘要"""
|
||||||
return {
|
return {
|
||||||
"install_log": self.install_log.copy(),
|
"install_log": self.install_log.copy(),
|
||||||
|
|||||||
@@ -197,29 +197,29 @@ class PluginManager:
|
|||||||
"""获取所有启用的插件信息"""
|
"""获取所有启用的插件信息"""
|
||||||
return list(component_registry.get_enabled_plugins().values())
|
return list(component_registry.get_enabled_plugins().values())
|
||||||
|
|
||||||
def enable_plugin(self, plugin_name: str) -> bool:
|
# def enable_plugin(self, plugin_name: str) -> bool:
|
||||||
# -------------------------------- NEED REFACTORING --------------------------------
|
# # -------------------------------- NEED REFACTORING --------------------------------
|
||||||
"""启用插件"""
|
# """启用插件"""
|
||||||
if plugin_info := component_registry.get_plugin_info(plugin_name):
|
# if plugin_info := component_registry.get_plugin_info(plugin_name):
|
||||||
plugin_info.enabled = True
|
# plugin_info.enabled = True
|
||||||
# 启用插件的所有组件
|
# # 启用插件的所有组件
|
||||||
for component in plugin_info.components:
|
# for component in plugin_info.components:
|
||||||
component_registry.enable_component(component.name)
|
# component_registry.enable_component(component.name)
|
||||||
logger.debug(f"已启用插件: {plugin_name}")
|
# logger.debug(f"已启用插件: {plugin_name}")
|
||||||
return True
|
# return True
|
||||||
return False
|
# return False
|
||||||
|
|
||||||
def disable_plugin(self, plugin_name: str) -> bool:
|
# def disable_plugin(self, plugin_name: str) -> bool:
|
||||||
# -------------------------------- NEED REFACTORING --------------------------------
|
# # -------------------------------- NEED REFACTORING --------------------------------
|
||||||
"""禁用插件"""
|
# """禁用插件"""
|
||||||
if plugin_info := component_registry.get_plugin_info(plugin_name):
|
# if plugin_info := component_registry.get_plugin_info(plugin_name):
|
||||||
plugin_info.enabled = False
|
# plugin_info.enabled = False
|
||||||
# 禁用插件的所有组件
|
# # 禁用插件的所有组件
|
||||||
for component in plugin_info.components:
|
# for component in plugin_info.components:
|
||||||
component_registry.disable_component(component.name)
|
# component_registry.disable_component(component.name)
|
||||||
logger.debug(f"已禁用插件: {plugin_name}")
|
# logger.debug(f"已禁用插件: {plugin_name}")
|
||||||
return True
|
# return True
|
||||||
return False
|
# return False
|
||||||
|
|
||||||
def get_plugin_instance(self, plugin_name: str) -> Optional["PluginBase"]:
|
def get_plugin_instance(self, plugin_name: str) -> Optional["PluginBase"]:
|
||||||
"""获取插件实例
|
"""获取插件实例
|
||||||
|
|||||||
@@ -28,10 +28,10 @@ class CompareNumbersTool(BaseTool):
|
|||||||
Returns:
|
Returns:
|
||||||
dict: 工具执行结果
|
dict: 工具执行结果
|
||||||
"""
|
"""
|
||||||
try:
|
num1: int | float = function_args.get("num1") # type: ignore
|
||||||
num1 = function_args.get("num1")
|
num2: int | float = function_args.get("num2") # type: ignore
|
||||||
num2 = function_args.get("num2")
|
|
||||||
|
|
||||||
|
try:
|
||||||
if num1 > num2:
|
if num1 > num2:
|
||||||
result = f"{num1} 大于 {num2}"
|
result = f"{num1} 大于 {num2}"
|
||||||
elif num1 < num2:
|
elif num1 < num2:
|
||||||
|
|||||||
@@ -68,10 +68,10 @@ class RenamePersonTool(BaseTool):
|
|||||||
)
|
)
|
||||||
result = await person_info_manager.qv_person_name(
|
result = await person_info_manager.qv_person_name(
|
||||||
person_id=person_id,
|
person_id=person_id,
|
||||||
user_nickname=user_nickname,
|
user_nickname=user_nickname, # type: ignore
|
||||||
user_cardname=user_cardname,
|
user_cardname=user_cardname, # type: ignore
|
||||||
user_avatar=user_avatar,
|
user_avatar=user_avatar, # type: ignore
|
||||||
request=request_context,
|
request=request_context, # type: ignore
|
||||||
)
|
)
|
||||||
|
|
||||||
# 3. 处理结果
|
# 3. 处理结果
|
||||||
|
|||||||
Reference in New Issue
Block a user