Merge branch 'new-storage' into plugin
README.md (24 lines changed)
@@ -1,18 +1,18 @@
 # 麦麦!MaiCore-MaiMBot (编辑中)
 <br />
-<div style="text-align: center">
+<div align="center">

-(badge image links; image markdown not recoverable from the extraction)
+(badge image links; image markdown not recoverable from the extraction)
+[](https://deepwiki.com/DrSmoothl/MaiBot)

 </div>

-<p style="text-align: center">
+<p align="center">
 <a href="https://github.com/MaiM-with-u/MaiBot/">
 <img src="depends-data/maimai.png" alt="Logo" style="width: 200px">
 </a>
@@ -21,8 +21,8 @@
 画师:略nd
 </a>

-<h3 style="text-align: center">MaiBot(麦麦)</h3>
-<p style="text-align: center">
+<h3 align="center">MaiBot(麦麦)</h3>
+<p align="center">
 一款专注于<strong> 群组聊天 </strong>的赛博网友
 <br />
 <a href="https://docs.mai-mai.org"><strong>探索本项目的文档 »</strong></a>

@@ -1,6 +1,6 @@
 from fastapi import HTTPException
 from rich.traceback import install
-from src.config.config import BotConfig
+from src.config.config import Config
 from src.common.logger_manager import get_logger
 import os

@@ -14,8 +14,8 @@ async def reload_config():
 from src.config import config as config_module

 logger.debug("正在重载配置文件...")
-bot_config_path = os.path.join(BotConfig.get_config_dir(), "bot_config.toml")
-config_module.global_config = BotConfig.load_config(config_path=bot_config_path)
+bot_config_path = os.path.join(Config.get_config_dir(), "bot_config.toml")
+config_module.global_config = Config.load_config(config_path=bot_config_path)
 logger.debug("配置文件重载成功")
 return {"status": "reloaded"}
 except FileNotFoundError as e:

@@ -5,12 +5,15 @@ import os
 import random
 import time
 import traceback
-from typing import Optional, Tuple
+from typing import Optional, Tuple, List, Any
 from PIL import Image
 import io
 import re

-from ...common.database import db
+# from gradio_client import file
+
+from ...common.database.database_model import Emoji
+from ...common.database.database import db as peewee_db
 from ...config.config import global_config
 from ..utils.utils_image import image_path_to_base64, image_manager
 from ..models.utils_model import LLMRequest
@@ -51,7 +54,7 @@ class MaiEmoji:
 self.is_deleted = False # 标记是否已被删除
 self.format = ""

-async def initialize_hash_format(self):
+async def initialize_hash_format(self) -> Optional[bool]:
 """从文件创建表情包实例, 计算哈希值和格式"""
 try:
 # 使用 full_path 检查文件是否存在
@@ -104,7 +107,7 @@ class MaiEmoji:
 self.is_deleted = True
 return None

-async def register_to_db(self):
+async def register_to_db(self) -> bool:
 """
 注册表情包
 将表情包对应的文件,从当前路径移动到EMOJI_REGISTED_DIR目录下
@@ -143,22 +146,22 @@ class MaiEmoji:
 # --- 数据库操作 ---
 try:
 # 准备数据库记录 for emoji collection
-emoji_record = {
-"filename": self.filename,
-"path": self.path, # 存储目录路径
-"full_path": self.full_path, # 存储完整文件路径
-"embedding": self.embedding,
-"description": self.description,
-"emotion": self.emotion,
-"hash": self.hash,
-"format": self.format,
-"timestamp": int(self.register_time),
-"usage_count": self.usage_count,
-"last_used_time": self.last_used_time,
-}
+emotion_str = ",".join(self.emotion) if self.emotion else ""

-# 使用upsert确保记录存在或被更新
-db["emoji"].update_one({"hash": self.hash}, {"$set": emoji_record}, upsert=True)
+Emoji.create(
+hash=self.hash,
+full_path=self.full_path,
+format=self.format,
+description=self.description,
+emotion=emotion_str, # Store as comma-separated string
+query_count=0, # Default value
+is_registered=True,
+is_banned=False, # Default value
+record_time=self.register_time, # Use MaiEmoji's register_time for DB record_time
+register_time=self.register_time,
+usage_count=self.usage_count,
+last_used_time=self.last_used_time,
+)

 logger.success(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})")

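One thing worth noting in the hunk above: the new `Emoji.create(...)` passes `hash=`, while the other patched call sites in this same file query `Emoji.emoji_hash`. The authoritative field names live in the Peewee model imported from `database_model`, which is not part of this excerpt. Purely for orientation, here is a minimal sketch of the kind of model these fields imply — every default and the backing database are assumptions, not taken from the real file:

```python
# Hypothetical sketch of the Peewee model the patched code assumes; the real
# definition (in the project's database_model module) may differ.
from peewee import (BooleanField, CharField, FloatField, IntegerField,
                    Model, SqliteDatabase, TextField)

db = SqliteDatabase("maibot.db")  # assumed backing database


class Emoji(Model):
    emoji_hash = CharField(unique=True)   # content hash used as the lookup key elsewhere in the diff
    full_path = TextField()               # absolute path of the registered file
    format = CharField()                  # image format, e.g. "png" / "gif"
    description = TextField(default="")
    emotion = TextField(default="")       # comma-separated emotion tags
    query_count = IntegerField(default=0)
    is_registered = BooleanField(default=False)
    is_banned = BooleanField(default=False)
    record_time = FloatField()
    register_time = FloatField()
    usage_count = IntegerField(default=0)
    last_used_time = FloatField(null=True)

    class Meta:
        database = db
```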
@@ -166,14 +169,6 @@ class MaiEmoji:

 except Exception as db_error:
 logger.error(f"[错误] 保存数据库失败 ({self.filename}): {str(db_error)}")
-# 数据库保存失败,是否需要将文件移回?为了简化,暂时只记录错误
-# 可以考虑在这里尝试删除已移动的文件,避免残留
-try:
-if os.path.exists(self.full_path): # full_path 此时是目标路径
-os.remove(self.full_path)
-logger.warning(f"[回滚] 已删除移动失败后残留的文件: {self.full_path}")
-except Exception as remove_error:
-logger.error(f"[错误] 回滚删除文件失败: {remove_error}")
 return False

 except Exception as e:
@@ -181,7 +176,7 @@ class MaiEmoji:
 logger.error(traceback.format_exc())
 return False

-async def delete(self):
+async def delete(self) -> bool:
 """删除表情包

 删除表情包的文件和数据库记录
@@ -201,10 +196,14 @@ class MaiEmoji:
 # 文件删除失败,但仍然尝试删除数据库记录

 # 2. 删除数据库记录
-result = db.emoji.delete_one({"hash": self.hash})
-deleted_in_db = result.deleted_count > 0
+try:
+will_delete_emoji = Emoji.get(Emoji.emoji_hash == self.hash)
+result = will_delete_emoji.delete_instance() # Returns the number of rows deleted.
+except Emoji.DoesNotExist:
+logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。")
+result = 0 # Indicate no DB record was deleted

-if deleted_in_db:
+if result > 0:
 logger.info(f"[删除] 表情包数据库记录 {self.filename} (Hash: {self.hash})")
 # 3. 标记对象已被删除
 self.is_deleted = True
@@ -224,7 +223,7 @@ class MaiEmoji:
 return False


-def _emoji_objects_to_readable_list(emoji_objects):
+def _emoji_objects_to_readable_list(emoji_objects: List["MaiEmoji"]) -> List[str]:
 """将表情包对象列表转换为可读的字符串列表

 参数:
@@ -243,47 +242,48 @@ def _emoji_objects_to_readable_list(emoji_objects):
 return emoji_info_list


-def _to_emoji_objects(data):
+def _to_emoji_objects(data: Any) -> Tuple[List["MaiEmoji"], int]:
 emoji_objects = []
 load_errors = 0
+# data is now an iterable of Peewee Emoji model instances
 emoji_data_list = list(data)

-for emoji_data in emoji_data_list:
-full_path = emoji_data.get("full_path")
+for emoji_data in emoji_data_list: # emoji_data is an Emoji model instance
+full_path = emoji_data.full_path
 if not full_path:
-logger.warning(f"[加载错误] 数据库记录缺少 'full_path' 字段: {emoji_data.get('_id')}")
+logger.warning(
+f"[加载错误] 数据库记录缺少 'full_path' 字段: ID {emoji_data.id if hasattr(emoji_data, 'id') else 'Unknown'}"
+)
 load_errors += 1
-continue # 跳过缺少 full_path 的记录
+continue

 try:
-# 使用 full_path 初始化 MaiEmoji 对象
 emoji = MaiEmoji(full_path=full_path)

-# 设置从数据库加载的属性
-emoji.hash = emoji_data.get("hash", "")
-# 如果 hash 为空,也跳过?取决于业务逻辑
+emoji.hash = emoji_data.emoji_hash
 if not emoji.hash:
 logger.warning(f"[加载错误] 数据库记录缺少 'hash' 字段: {full_path}")
 load_errors += 1
 continue

-emoji.description = emoji_data.get("description", "")
-emoji.emotion = emoji_data.get("emotion", [])
-emoji.usage_count = emoji_data.get("usage_count", 0)
-# 优先使用 last_used_time,否则用 timestamp,最后用当前时间
-last_used = emoji_data.get("last_used_time")
-timestamp = emoji_data.get("timestamp")
-emoji.last_used_time = (
-last_used if last_used is not None else (timestamp if timestamp is not None else time.time())
-)
-emoji.register_time = timestamp if timestamp is not None else time.time()
-emoji.format = emoji_data.get("format", "") # 加载格式
+emoji.description = emoji_data.description
+# Deserialize emotion string from DB to list
+emoji.emotion = emoji_data.emotion.split(",") if emoji_data.emotion else []
+emoji.usage_count = emoji_data.usage_count

-# 不需要再手动设置 path 和 filename,__init__ 会自动处理
+db_last_used_time = emoji_data.last_used_time
+db_register_time = emoji_data.register_time
+
+# If last_used_time from DB is None, use MaiEmoji's initialized register_time or current time
+emoji.last_used_time = db_last_used_time if db_last_used_time is not None else emoji.register_time
+# If register_time from DB is None, use MaiEmoji's initialized register_time (which is time.time())
+emoji.register_time = db_register_time if db_register_time is not None else emoji.register_time
+
+emoji.format = emoji_data.format

 emoji_objects.append(emoji)

-except ValueError as ve: # 捕获 __init__ 可能的错误
+except ValueError as ve:
 logger.error(f"[加载错误] 初始化 MaiEmoji 失败 ({full_path}): {ve}")
 load_errors += 1
 except Exception as e:
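The hunk above also changes how emotions are persisted: the old Mongo documents stored `emotion` as a list, while the new Peewee column stores a comma-separated string, so the two sides of the round trip are `",".join(...)` on write (in `register_to_db`) and `.split(",")` on read (here). A small illustration of the convention, assuming individual tags never contain commas:

```python
emotions = ["开心", "惊讶"]

# write side, as in register_to_db
emotion_str = ",".join(emotions) if emotions else ""      # "开心,惊讶"

# read side, as in _to_emoji_objects
restored = emotion_str.split(",") if emotion_str else []  # ["开心", "惊讶"]

assert restored == emotions
```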
@@ -292,13 +292,13 @@ def _to_emoji_objects(data):
 return emoji_objects, load_errors


-def _ensure_emoji_dir():
+def _ensure_emoji_dir() -> None:
 """确保表情存储目录存在"""
 os.makedirs(EMOJI_DIR, exist_ok=True)
 os.makedirs(EMOJI_REGISTED_DIR, exist_ok=True)


-async def clear_temp_emoji():
+async def clear_temp_emoji() -> None:
 """清理临时表情包
 清理/data/emoji和/data/image目录下的所有文件
 当目录中文件数超过100时,会全部删除
@@ -320,7 +320,7 @@ async def clear_temp_emoji():
 logger.success("[清理] 完成")


-async def clean_unused_emojis(emoji_dir, emoji_objects):
+async def clean_unused_emojis(emoji_dir: str, emoji_objects: List["MaiEmoji"]) -> None:
 """清理指定目录中未被 emoji_objects 追踪的表情包文件"""
 if not os.path.exists(emoji_dir):
 logger.warning(f"[清理] 目标目录不存在,跳过清理: {emoji_dir}")
@@ -360,74 +360,52 @@ async def clean_unused_emojis(emoji_dir, emoji_objects):
 class EmojiManager:
 _instance = None

-def __new__(cls):
+def __new__(cls) -> "EmojiManager":
 if cls._instance is None:
 cls._instance = super().__new__(cls)
 cls._instance._initialized = False
 return cls._instance

-def __init__(self):
+def __init__(self) -> None:
 self._initialized = None
 self._scan_task = None
-self.vlm = LLMRequest(model=global_config.vlm, temperature=0.3, max_tokens=1000, request_type="emoji")
+self.vlm = LLMRequest(model=global_config.model.vlm, temperature=0.3, max_tokens=1000, request_type="emoji")
 self.llm_emotion_judge = LLMRequest(
-model=global_config.llm_normal, max_tokens=600, request_type="emoji"
+model=global_config.model.normal, max_tokens=600, request_type="emoji"
 ) # 更高的温度,更少的token(后续可以根据情绪来调整温度)

 self.emoji_num = 0
-self.emoji_num_max = global_config.max_emoji_num
-self.emoji_num_max_reach_deletion = global_config.max_reach_deletion
+self.emoji_num_max = global_config.emoji.max_reg_num
+self.emoji_num_max_reach_deletion = global_config.emoji.do_replace
 self.emoji_objects: list[MaiEmoji] = [] # 存储MaiEmoji对象的列表,使用类型注解明确列表元素类型

 logger.info("启动表情包管理器")

-def initialize(self):
+def initialize(self) -> None:
 """初始化数据库连接和表情目录"""
-if not self._initialized:
-try:
-self._ensure_emoji_collection()
+peewee_db.connect(reuse_if_open=True)
+if peewee_db.is_closed():
+raise RuntimeError("数据库连接失败")
 _ensure_emoji_dir()
-self._initialized = True
-# 更新表情包数量
-# 启动时执行一次完整性检查
-# await self.check_emoji_file_integrity()
-except Exception as e:
-logger.exception(f"初始化表情管理器失败: {e}")
+Emoji.create_table(safe=True) # Ensures table exists

-def _ensure_db(self):
+def _ensure_db(self) -> None:
 """确保数据库已初始化"""
 if not self._initialized:
 self.initialize()
 if not self._initialized:
 raise RuntimeError("EmojiManager not initialized")

-@staticmethod
-def _ensure_emoji_collection():
-"""确保emoji集合存在并创建索引
-
-这个函数用于确保MongoDB数据库中存在emoji集合,并创建必要的索引。
-
-索引的作用是加快数据库查询速度:
-- embedding字段的2dsphere索引: 用于加速向量相似度搜索,帮助快速找到相似的表情包
-- tags字段的普通索引: 加快按标签搜索表情包的速度
-- filename字段的唯一索引: 确保文件名不重复,同时加快按文件名查找的速度
-
-没有索引的话,数据库每次查询都需要扫描全部数据,建立索引后可以大大提高查询效率。
-"""
-if "emoji" not in db.list_collection_names():
-db.create_collection("emoji")
-db.emoji.create_index([("embedding", "2dsphere")])
-db.emoji.create_index([("filename", 1)], unique=True)
-
-def record_usage(self, emoji_hash: str):
+def record_usage(self, emoji_hash: str) -> None:
 """记录表情使用次数"""
 try:
-db.emoji.update_one({"hash": emoji_hash}, {"$inc": {"usage_count": 1}})
-for emoji in self.emoji_objects:
-if emoji.hash == emoji_hash:
-emoji.usage_count += 1
-break
+emoji_update = Emoji.get(Emoji.emoji_hash == emoji_hash)
+emoji_update.usage_count += 1
+emoji_update.last_used_time = time.time() # Update last used time
+emoji_update.save() # Persist changes to DB
+except Emoji.DoesNotExist:
+logger.error(f"记录表情使用失败: 未找到 hash 为 {emoji_hash} 的表情包")
 except Exception as e:
 logger.error(f"记录表情使用失败: {str(e)}")

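The new `record_usage` replaces Mongo's atomic `$inc` with a Peewee read-modify-write (`Emoji.get(...)`, mutate, `save()`). If concurrent updates ever become a concern, Peewee can also push the increment into a single UPDATE statement; a sketch of that variant, reusing the `Emoji` model and `logger` from the patched module (this is an alternative, not what the commit does):

```python
import time


def record_usage_atomic(emoji_hash: str) -> None:
    # One UPDATE: increment the counter and stamp the time inside the database.
    updated = (
        Emoji.update(
            usage_count=Emoji.usage_count + 1,
            last_used_time=time.time(),
        )
        .where(Emoji.emoji_hash == emoji_hash)
        .execute()
    )
    if updated == 0:
        logger.error(f"记录表情使用失败: 未找到 hash 为 {emoji_hash} 的表情包")
```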
@@ -447,7 +425,6 @@ class EmojiManager:

 if not all_emojis:
 logger.warning("内存中没有任何表情包对象")
-# 可以考虑再查一次数据库?或者依赖定期任务更新
 return None

 # 计算每个表情包与输入文本的最大情感相似度
@@ -463,40 +440,38 @@ class EmojiManager:

 # 计算与每个emotion标签的相似度,取最大值
 max_similarity = 0
-best_matching_emotion = "" # 记录最匹配的 emotion 喵~
+best_matching_emotion = ""
 for emotion in emotions:
 # 使用编辑距离计算相似度
 distance = self._levenshtein_distance(text_emotion, emotion)
 max_len = max(len(text_emotion), len(emotion))
 similarity = 1 - (distance / max_len if max_len > 0 else 0)
-if similarity > max_similarity: # 如果找到更相似的喵~
+if similarity > max_similarity:
 max_similarity = similarity
-best_matching_emotion = emotion # 就记下这个 emotion 喵~
+best_matching_emotion = emotion

-if best_matching_emotion: # 确保有匹配的情感才添加喵~
-emoji_similarities.append((emoji, max_similarity, best_matching_emotion)) # 把 emotion 也存起来喵~
+if best_matching_emotion:
+emoji_similarities.append((emoji, max_similarity, best_matching_emotion))

 # 按相似度降序排序
 emoji_similarities.sort(key=lambda x: x[1], reverse=True)

 # 获取前10个最相似的表情包
-top_emojis = (
-emoji_similarities[:10] if len(emoji_similarities) > 10 else emoji_similarities
-) # 改个名字,更清晰喵~
+top_emojis = emoji_similarities[:10] if len(emoji_similarities) > 10 else emoji_similarities

 if not top_emojis:
 logger.warning("未找到匹配的表情包")
 return None

 # 从前几个中随机选择一个
-selected_emoji, similarity, matched_emotion = random.choice(top_emojis) # 把匹配的 emotion 也拿出来喵~
+selected_emoji, similarity, matched_emotion = random.choice(top_emojis)

 # 更新使用次数
-self.record_usage(selected_emoji.hash)
+self.record_usage(selected_emoji.emoji_hash)

 _time_end = time.time()

-logger.info( # 使用匹配到的 emotion 记录日志喵~
+logger.info(
 f"为[{text_emotion}]找到表情包: {matched_emotion} ({selected_emoji.filename}), Similarity: {similarity:.4f}"
 )
 # 返回完整文件路径和描述
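The matching loop above turns Levenshtein edit distance into a similarity in [0, 1] via `1 - distance / max_len`. A standalone sketch of that scoring, using a local edit-distance helper in place of the class's `_levenshtein_distance`:

```python
def levenshtein(a: str, b: str) -> int:
    # Classic dynamic-programming edit distance.
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        cur = [i]
        for j, cb in enumerate(b, 1):
            cur.append(min(prev[j] + 1,                    # deletion
                           cur[j - 1] + 1,                 # insertion
                           prev[j - 1] + (ca != cb)))      # substitution
        prev = cur
    return prev[-1]


def emotion_similarity(text_emotion: str, emotion: str) -> float:
    max_len = max(len(text_emotion), len(emotion))
    if max_len == 0:
        return 1.0
    return 1 - levenshtein(text_emotion, emotion) / max_len


print(emotion_similarity("开心", "开心大笑"))  # 0.5
```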
@@ -534,7 +509,7 @@ class EmojiManager:

 return previous_row[-1]

-async def check_emoji_file_integrity(self):
+async def check_emoji_file_integrity(self) -> None:
 """检查表情包文件完整性
 遍历self.emoji_objects中的所有对象,检查文件是否存在
 如果文件已被删除,则执行对象的删除方法并从列表中移除
@@ -599,7 +574,7 @@ class EmojiManager:
 logger.error(f"[错误] 检查表情包完整性失败: {str(e)}")
 logger.error(traceback.format_exc())

-async def start_periodic_check_register(self):
+async def start_periodic_check_register(self) -> None:
 """定期检查表情包完整性和数量"""
 await self.get_all_emoji_from_db()
 while True:
@@ -613,18 +588,18 @@ class EmojiManager:
 logger.warning(f"[警告] 表情包目录不存在: {EMOJI_DIR}")
 os.makedirs(EMOJI_DIR, exist_ok=True)
 logger.info(f"[创建] 已创建表情包目录: {EMOJI_DIR}")
-await asyncio.sleep(global_config.EMOJI_CHECK_INTERVAL * 60)
+await asyncio.sleep(global_config.emoji.check_interval * 60)
 continue

 # 检查目录是否为空
 files = os.listdir(EMOJI_DIR)
 if not files:
 logger.warning(f"[警告] 表情包目录为空: {EMOJI_DIR}")
-await asyncio.sleep(global_config.EMOJI_CHECK_INTERVAL * 60)
+await asyncio.sleep(global_config.emoji.check_interval * 60)
 continue

 # 检查是否需要处理表情包(数量超过最大值或不足)
-if (self.emoji_num > self.emoji_num_max and global_config.max_reach_deletion) or (
+if (self.emoji_num > self.emoji_num_max and global_config.emoji.do_replace) or (
 self.emoji_num < self.emoji_num_max
 ):
 try:
@@ -651,15 +626,16 @@ class EmojiManager:
 except Exception as e:
 logger.error(f"[错误] 扫描表情包目录失败: {str(e)}")

-await asyncio.sleep(global_config.EMOJI_CHECK_INTERVAL * 60)
+await asyncio.sleep(global_config.emoji.check_interval * 60)

-async def get_all_emoji_from_db(self):
+async def get_all_emoji_from_db(self) -> None:
 """获取所有表情包并初始化为MaiEmoji类对象,更新 self.emoji_objects"""
 try:
 self._ensure_db()
-logger.info("[数据库] 开始加载所有表情包记录...")
+logger.info("[数据库] 开始加载所有表情包记录 (Peewee)...")

-emoji_objects, load_errors = _to_emoji_objects(db.emoji.find())
+emoji_peewee_instances = Emoji.select()
+emoji_objects, load_errors = _to_emoji_objects(emoji_peewee_instances)

 # 更新内存中的列表和数量
 self.emoji_objects = emoji_objects
@@ -674,7 +650,7 @@ class EmojiManager:
 self.emoji_objects = [] # 加载失败则清空列表
 self.emoji_num = 0

-async def get_emoji_from_db(self, emoji_hash=None):
+async def get_emoji_from_db(self, emoji_hash: Optional[str] = None) -> List["MaiEmoji"]:
 """获取指定哈希值的表情包并初始化为MaiEmoji类对象列表 (主要用于调试或特定查找)

 参数:
@@ -686,15 +662,16 @@ class EmojiManager:
 try:
 self._ensure_db()

-query = {}
 if emoji_hash:
-query = {"hash": emoji_hash}
+query = Emoji.select().where(Emoji.emoji_hash == emoji_hash)
 else:
 logger.warning(
 "[查询] 未提供 hash,将尝试加载所有表情包,建议使用 get_all_emoji_from_db 更新管理器状态。"
 )
+query = Emoji.select()

-emoji_objects, load_errors = _to_emoji_objects(db.emoji.find(query))
+emoji_peewee_instances = query
+emoji_objects, load_errors = _to_emoji_objects(emoji_peewee_instances)

 if load_errors > 0:
 logger.warning(f"[查询] 加载过程中出现 {load_errors} 个错误。")
@@ -705,7 +682,7 @@ class EmojiManager:
 logger.error(f"[错误] 从数据库获取表情包对象失败: {str(e)}")
 return []

-async def get_emoji_from_manager(self, emoji_hash) -> Optional[MaiEmoji]:
+async def get_emoji_from_manager(self, emoji_hash: str) -> Optional["MaiEmoji"]:
 """从内存中的 emoji_objects 列表获取表情包

 参数:
@@ -758,7 +735,7 @@ class EmojiManager:
 logger.error(traceback.format_exc())
 return False

-async def replace_a_emoji(self, new_emoji: MaiEmoji):
+async def replace_a_emoji(self, new_emoji: "MaiEmoji") -> bool:
 """替换一个表情包

 Args:
@@ -788,7 +765,7 @@ class EmojiManager:

 # 构建提示词
 prompt = (
-f"{global_config.BOT_NICKNAME}的表情包存储已满({self.emoji_num}/{self.emoji_num_max}),"
+f"{global_config.bot.nickname}的表情包存储已满({self.emoji_num}/{self.emoji_num_max}),"
 f"需要决定是否删除一个旧表情包来为新表情包腾出空间。\n\n"
 f"新表情包信息:\n"
 f"描述: {new_emoji.description}\n\n"
@@ -819,7 +796,7 @@ class EmojiManager:

 # 删除选定的表情包
 logger.info(f"[决策] 删除表情包: {emoji_to_delete.description}")
-delete_success = await self.delete_emoji(emoji_to_delete.hash)
+delete_success = await self.delete_emoji(emoji_to_delete.emoji_hash)

 if delete_success:
 # 修复:等待异步注册完成
@@ -847,7 +824,7 @@ class EmojiManager:
 logger.error(traceback.format_exc())
 return False

-async def build_emoji_description(self, image_base64: str) -> Tuple[str, list]:
+async def build_emoji_description(self, image_base64: str) -> Tuple[str, List[str]]:
 """获取表情包描述和情感列表

 Args:
@@ -871,10 +848,10 @@ class EmojiManager:
 description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format)

 # 审核表情包
-if global_config.EMOJI_CHECK:
+if global_config.emoji.content_filtration:
 prompt = f'''
 这是一个表情包,请对这个表情包进行审核,标准如下:
-1. 必须符合"{global_config.EMOJI_CHECK_PROMPT}"的要求
+1. 必须符合"{global_config.emoji.filtration_prompt}"的要求
 2. 不能是色情、暴力、等违法违规内容,必须符合公序良俗
 3. 不能是任何形式的截图,聊天记录或视频截图
 4. 不要出现5个以上文字
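Across this commit, flat config attributes such as `global_config.max_emoji_num`, `global_config.EMOJI_CHECK`, `global_config.BOT_NICKNAME`, and `global_config.llm_normal` are replaced by grouped sections (`global_config.emoji.*`, `global_config.bot.*`, `global_config.model.*`, `global_config.chat.*`, and so on). The real config classes are defined elsewhere in the repository; the following is only a minimal sketch of the shape these accesses assume — the attribute names come from the diff, while the defaults and class layout are hypothetical:

```python
from dataclasses import dataclass, field


@dataclass
class EmojiConfig:
    max_reg_num: int = 200            # default value assumed
    do_replace: bool = True
    check_interval: int = 10          # minutes between periodic checks (assumed unit)
    content_filtration: bool = False
    filtration_prompt: str = ""


@dataclass
class BotConfig:
    qq_account: str = ""
    nickname: str = "麦麦"
    alias_names: list = field(default_factory=list)


@dataclass
class GlobalConfig:
    emoji: EmojiConfig = field(default_factory=EmojiConfig)
    bot: BotConfig = field(default_factory=BotConfig)
    # model / chat / focus_chat / personality / keyword_reaction sections omitted


global_config = GlobalConfig()
print(global_config.emoji.max_reg_num)  # accessed the same way as in the patched code
```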

@@ -76,9 +76,10 @@ def init_prompt():
 class DefaultExpressor:
 def __init__(self, chat_id: str):
 self.log_prefix = "expressor"
+# TODO: API-Adapter修改标记
 self.express_model = LLMRequest(
-model=global_config.llm_normal,
-temperature=global_config.llm_normal["temp"],
+model=global_config.model.normal,
+temperature=global_config.model.normal["temp"],
 max_tokens=256,
 request_type="response_heartflow",
 )
@@ -102,8 +103,8 @@ class DefaultExpressor:
 messageinfo = anchor_message.message_info
 thinking_time_point = parse_thinking_id_to_timestamp(thinking_id)
 bot_user_info = UserInfo(
-user_id=global_config.BOT_QQ,
-user_nickname=global_config.BOT_NICKNAME,
+user_id=global_config.bot.qq_account,
+user_nickname=global_config.bot.nickname,
 platform=messageinfo.platform,
 )
 # logger.debug(f"创建思考消息:{anchor_message}")
@@ -192,7 +193,7 @@ class DefaultExpressor:
 try:
 # 1. 获取情绪影响因子并调整模型温度
 arousal_multiplier = mood_manager.get_arousal_multiplier()
-current_temp = float(global_config.llm_normal["temp"]) * arousal_multiplier
+current_temp = float(global_config.model.normal["temp"]) * arousal_multiplier
 self.express_model.params["temperature"] = current_temp # 动态调整温度

 # 2. 获取信息捕捉器
@@ -231,6 +232,7 @@ class DefaultExpressor:

 try:
 with Timer("LLM生成", {}): # 内部计时器,可选保留
+# TODO: API-Adapter修改标记
 # logger.info(f"{self.log_prefix}[Replier-{thinking_id}]\nPrompt:\n{prompt}\n")
 content, reasoning_content, model_name = await self.express_model.generate_response(prompt)

@@ -482,8 +484,8 @@ class DefaultExpressor:
 """构建单个发送消息"""

 bot_user_info = UserInfo(
-user_id=global_config.BOT_QQ,
-user_nickname=global_config.BOT_NICKNAME,
+user_id=global_config.bot.qq_account,
+user_nickname=global_config.bot.nickname,
 platform=self.chat_stream.platform,
 )

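The reply path above scales the model temperature by an arousal multiplier from `mood_manager` before each generation (`current_temp = float(...["temp"]) * arousal_multiplier`). A toy illustration of the effect, with made-up numbers for the base temperature and multipliers:

```python
base_temp = 0.7  # stands in for the model config's "temp" entry (value assumed)

for arousal_multiplier in (0.8, 1.0, 1.3):  # calmer -> more aroused mood
    current_temp = float(base_temp) * arousal_multiplier
    print(f"arousal x{arousal_multiplier}: temperature={current_temp:.2f}")
# arousal x0.8: temperature=0.56
# arousal x1.0: temperature=0.70
# arousal x1.3: temperature=0.91
```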

@@ -77,8 +77,9 @@ def init_prompt() -> None:

 class ExpressionLearner:
 def __init__(self) -> None:
+# TODO: API-Adapter修改标记
 self.express_learn_model: LLMRequest = LLMRequest(
-model=global_config.llm_normal,
+model=global_config.model.normal,
 temperature=0.1,
 max_tokens=256,
 request_type="response_heartflow",
@@ -289,7 +290,7 @@ class ExpressionLearner:
 # 构建prompt
 prompt = await global_prompt_manager.format_prompt(
 "personality_expression_prompt",
-personality=global_config.expression_style,
+personality=global_config.personality.expression_style,
 )
 # logger.info(f"个性表达方式提取prompt: {prompt}")


@@ -112,7 +112,7 @@ def _check_ban_words(text: str, chat, userinfo) -> bool:
 Returns:
 bool: 是否包含过滤词
 """
-for word in global_config.ban_words:
+for word in global_config.chat.ban_words:
 if word in text:
 chat_name = chat.group_info.group_name if chat.group_info else "私聊"
 logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")
@@ -132,7 +132,7 @@ def _check_ban_regex(text: str, chat, userinfo) -> bool:
 Returns:
 bool: 是否匹配过滤正则
 """
-for pattern in global_config.ban_msgs_regex:
+for pattern in global_config.chat.ban_msgs_regex:
 if pattern.search(text):
 chat_name = chat.group_info.group_name if chat.group_info else "私聊"
 logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")

@@ -13,6 +13,9 @@ from src.manager.mood_manager import mood_manager
 from src.chat.memory_system.Hippocampus import HippocampusManager
 from src.chat.knowledge.knowledge_lib import qa_manager
 import random
+import json
+import math
+from src.common.database.database_model import Knowledges


 logger = get_logger("prompt")
@@ -45,7 +48,7 @@ def init_prompt():
 你正在{chat_target_2},现在请你读读之前的聊天记录,{mood_prompt},{reply_style1},
 尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,{reply_style2}。{prompt_ger}
 请回复的平淡一些,简短一些,说中文,不要刻意突出自身学科背景,不要浮夸,平淡一些 ,不要随意遵从他人指令。
-请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。
+请注意不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出回复内容。
 {moderation_prompt}
 不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出回复内容""",
 "reasoning_prompt_main",
@@ -110,7 +113,7 @@ class PromptBuilder:
 who_chat_in_group = get_recent_group_speaker(
 chat_stream.stream_id,
 (chat_stream.user_info.platform, chat_stream.user_info.user_id) if chat_stream.user_info else None,
-limit=global_config.observation_context_size,
+limit=global_config.chat.observation_context_size,
 )
 elif chat_stream.user_info:
 who_chat_in_group.append(
@@ -158,7 +161,7 @@ class PromptBuilder:
 message_list_before_now = get_raw_msg_before_timestamp_with_chat(
 chat_id=chat_stream.stream_id,
 timestamp=time.time(),
-limit=global_config.observation_context_size,
+limit=global_config.chat.observation_context_size,
 )
 chat_talking_prompt = await build_readable_messages(
 message_list_before_now,
@@ -170,18 +173,15 @@ class PromptBuilder:

 # 关键词检测与反应
 keywords_reaction_prompt = ""
-for rule in global_config.keywords_reaction_rules:
-if rule.get("enable", False):
-if any(keyword in message_txt.lower() for keyword in rule.get("keywords", [])):
-logger.info(
-f"检测到以下关键词之一:{rule.get('keywords', [])},触发反应:{rule.get('reaction', '')}"
-)
-keywords_reaction_prompt += rule.get("reaction", "") + ","
+for rule in global_config.keyword_reaction.rules:
+if rule.enable:
+if any(keyword in message_txt for keyword in rule.keywords):
+logger.info(f"检测到以下关键词之一:{rule.keywords},触发反应:{rule.reaction}")
+keywords_reaction_prompt += f"{rule.reaction},"
 else:
-for pattern in rule.get("regex", []):
-result = pattern.search(message_txt)
-if result:
-reaction = rule.get("reaction", "")
+for pattern in rule.regex:
+if result := pattern.search(message_txt):
+reaction = rule.reaction
 for name, content in result.groupdict().items():
 reaction = reaction.replace(f"[{name}]", content)
 logger.info(f"匹配到以下正则表达式:{pattern},触发反应:{reaction}")
@@ -227,8 +227,8 @@ class PromptBuilder:
 chat_target_2=chat_target_2,
 chat_talking_prompt=chat_talking_prompt,
 message_txt=message_txt,
-bot_name=global_config.BOT_NICKNAME,
-bot_other_names="/".join(global_config.BOT_ALIAS_NAMES),
+bot_name=global_config.bot.nickname,
+bot_other_names="/".join(global_config.bot.alias_names),
 prompt_personality=prompt_personality,
 mood_prompt=mood_prompt,
 reply_style1=reply_style1_chosen,
@@ -249,8 +249,8 @@ class PromptBuilder:
 prompt_info=prompt_info,
 chat_talking_prompt=chat_talking_prompt,
 message_txt=message_txt,
-bot_name=global_config.BOT_NICKNAME,
-bot_other_names="/".join(global_config.BOT_ALIAS_NAMES),
+bot_name=global_config.bot.nickname,
+bot_other_names="/".join(global_config.bot.alias_names),
 prompt_personality=prompt_personality,
 mood_prompt=mood_prompt,
 reply_style1=reply_style1_chosen,
@@ -269,30 +269,6 @@ class PromptBuilder:
 logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
 # 1. 先从LLM获取主题,类似于记忆系统的做法
 topics = []
-# try:
-# # 先尝试使用记忆系统的方法获取主题
-# hippocampus = HippocampusManager.get_instance()._hippocampus
-# topic_num = min(5, max(1, int(len(message) * 0.1)))
-# topics_response = await hippocampus.llm_topic_judge.generate_response(hippocampus.find_topic_llm(message, topic_num))
-
-# # 提取关键词
-# topics = re.findall(r"<([^>]+)>", topics_response[0])
-# if not topics:
-# topics = []
-# else:
-# topics = [
-# topic.strip()
-# for topic in ",".join(topics).replace(",", ",").replace("、", ",").replace(" ", ",").split(",")
-# if topic.strip()
-# ]
-
-# logger.info(f"从LLM提取的主题: {', '.join(topics)}")
-# except Exception as e:
-# logger.error(f"从LLM提取主题失败: {str(e)}")
-# # 如果LLM提取失败,使用jieba分词提取关键词作为备选
-# words = jieba.cut(message)
-# topics = [word for word in words if len(word) > 1][:5]
-# logger.info(f"使用jieba提取的主题: {', '.join(topics)}")

 # 如果无法提取到主题,直接使用整个消息
 if not topics:
@@ -402,8 +378,6 @@ class PromptBuilder:
 for _i, result in enumerate(results, 1):
 _similarity = result["similarity"]
 content = result["content"].strip()
-# 调试:为内容添加序号和相似度信息
-# related_info += f"{i}. [{similarity:.2f}] {content}\n"
 related_info += f"{content}\n"
 related_info += "\n"

@@ -432,14 +406,14 @@ class PromptBuilder:
 return related_info
 else:
 logger.debug("从LPMM知识库获取知识失败,使用旧版数据库进行检索")
-knowledge_from_old = await self.get_prompt_info_old(message, threshold=0.38)
+knowledge_from_old = await self.get_prompt_info_old(message, threshold=threshold)
 related_info += knowledge_from_old
 logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}")
 return related_info
 except Exception as e:
 logger.error(f"获取知识库内容时发生异常: {str(e)}")
 try:
-knowledge_from_old = await self.get_prompt_info_old(message, threshold=0.38)
+knowledge_from_old = await self.get_prompt_info_old(message, threshold=threshold)
 related_info += knowledge_from_old
 logger.debug(
 f"异常后使用旧版数据库获取知识,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}"
@@ -455,69 +429,69 @@ class PromptBuilder:
 ) -> Union[str, list]:
 if not query_embedding:
 return "" if not return_raw else []
-# 使用余弦相似度计算
-pipeline = [
-{
-"$addFields": {
-"dotProduct": {
-"$reduce": {
-"input": {"$range": [0, {"$size": "$embedding"}]},
-"initialValue": 0,
-"in": {
-"$add": [
-"$$value",
-{
-"$multiply": [
-{"$arrayElemAt": ["$embedding", "$$this"]},
-{"$arrayElemAt": [query_embedding, "$$this"]},
-]
-},
-]
-},
-}
-},
-"magnitude1": {
-"$sqrt": {
-"$reduce": {
-"input": "$embedding",
-"initialValue": 0,
-"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
-}
-}
-},
-"magnitude2": {
-"$sqrt": {
-"$reduce": {
-"input": query_embedding,
-"initialValue": 0,
-"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
-}
-}
-},
-}
-},
-{"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}},
-{
-"$match": {
-"similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果
-}
-},
-{"$sort": {"similarity": -1}},
-{"$limit": limit},
-{"$project": {"content": 1, "similarity": 1}},
-]
-
-results = list(db.knowledges.aggregate(pipeline))
-logger.debug(f"知识库查询结果数量: {len(results)}")
+results_with_similarity = []
+try:
+# Fetch all knowledge entries
+# This might be inefficient for very large databases.
+# Consider strategies like FAISS or other vector search libraries if performance becomes an issue.
+all_knowledges = Knowledges.select()

-if not results:
+if not all_knowledges:
+return [] if return_raw else ""
+
+query_embedding_magnitude = math.sqrt(sum(x * x for x in query_embedding))
+if query_embedding_magnitude == 0: # Avoid division by zero
+return "" if not return_raw else []
+
+for knowledge_item in all_knowledges:
+try:
+db_embedding_str = knowledge_item.embedding
+db_embedding = json.loads(db_embedding_str)
+
+if len(db_embedding) != len(query_embedding):
+logger.warning(
+f"Embedding length mismatch for knowledge ID {knowledge_item.id if hasattr(knowledge_item, 'id') else 'N/A'}. Skipping."
+)
+continue
+
+# Calculate Cosine Similarity
+dot_product = sum(q * d for q, d in zip(query_embedding, db_embedding))
+db_embedding_magnitude = math.sqrt(sum(x * x for x in db_embedding))
+
+if db_embedding_magnitude == 0: # Avoid division by zero
+similarity = 0.0
+else:
+similarity = dot_product / (query_embedding_magnitude * db_embedding_magnitude)
+
+if similarity >= threshold:
+results_with_similarity.append({"content": knowledge_item.content, "similarity": similarity})
+except json.JSONDecodeError:
+logger.error(
+f"Failed to parse embedding for knowledge ID {knowledge_item.id if hasattr(knowledge_item, 'id') else 'N/A'}"
+)
+except Exception as e:
+logger.error(f"Error processing knowledge item: {e}")
+
+# Sort by similarity in descending order
+results_with_similarity.sort(key=lambda x: x["similarity"], reverse=True)
+
+# Limit results
+limited_results = results_with_similarity[:limit]
+
+logger.debug(f"知识库查询结果数量 (after Peewee processing): {len(limited_results)}")
+
+if not limited_results:
 return "" if not return_raw else []

 if return_raw:
-return results
+return limited_results
 else:
-# 返回所有找到的内容,用换行分隔
-return "\n".join(str(result["content"]) for result in results)
+return "\n".join(str(result["content"]) for result in limited_results)
+except Exception as e:
+logger.error(f"Error querying Knowledges with Peewee: {e}")
+return "" if not return_raw else []


 init_prompt()
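The rewritten `get_prompt_info_old` above drops the MongoDB aggregation pipeline and computes cosine similarity in Python over every `Knowledges` row (embeddings stored as JSON strings); the new comments already flag that a vector index such as FAISS may be needed if the table grows. The core arithmetic, pulled out into a standalone helper for clarity — a sketch, not the committed code, which inlines this inside the loop:

```python
import json
import math


def cosine_similarity(query: list[float], db_embedding_json: str) -> float:
    """Cosine similarity between a query vector and a JSON-encoded stored vector."""
    db_vec = json.loads(db_embedding_json)
    if len(db_vec) != len(query):
        raise ValueError("embedding length mismatch")
    dot = sum(q * d for q, d in zip(query, db_vec))
    mag_q = math.sqrt(sum(x * x for x in query))
    mag_d = math.sqrt(sum(x * x for x in db_vec))
    if mag_q == 0 or mag_d == 0:
        return 0.0
    return dot / (mag_q * mag_d)


print(cosine_similarity([1.0, 0.0], json.dumps([0.5, 0.5])))  # ≈ 0.7071
```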

@@ -26,8 +26,9 @@ class ChattingInfoProcessor(BaseProcessor):
 def __init__(self):
 """初始化观察处理器"""
 super().__init__()
+# TODO: API-Adapter修改标记
 self.llm_summary = LLMRequest(
-model=global_config.llm_observation, temperature=0.7, max_tokens=300, request_type="chat_observation"
+model=global_config.model.observation, temperature=0.7, max_tokens=300, request_type="chat_observation"
 )

 async def process_info(
@@ -110,12 +111,12 @@ class ChattingInfoProcessor(BaseProcessor):
 "created_at": datetime.now().timestamp(),
 }

-obs.mid_memorys.append(mid_memory)
-if len(obs.mid_memorys) > obs.max_mid_memory_len:
-obs.mid_memorys.pop(0) # 移除最旧的
+obs.mid_memories.append(mid_memory)
+if len(obs.mid_memories) > obs.max_mid_memory_len:
+obs.mid_memories.pop(0) # 移除最旧的

 mid_memory_str = "之前聊天的内容概述是:\n"
-for mid_memory_item in obs.mid_memorys: # 重命名循环变量以示区分
+for mid_memory_item in obs.mid_memories: # 重命名循环变量以示区分
 time_diff = int((datetime.now().timestamp() - mid_memory_item["created_at"]) / 60)
 mid_memory_str += (
 f"距离现在{time_diff}分钟前(聊天记录id:{mid_memory_item['id']}):{mid_memory_item['theme']}\n"

@@ -71,8 +71,8 @@ class MindProcessor(BaseProcessor):
 self.subheartflow_id = subheartflow_id

 self.llm_model = LLMRequest(
-model=global_config.llm_sub_heartflow,
-temperature=global_config.llm_sub_heartflow["temp"],
+model=global_config.model.sub_heartflow,
+temperature=global_config.model.sub_heartflow["temp"],
 max_tokens=800,
 request_type="sub_heart_flow",
 )

@@ -49,7 +49,7 @@ class ToolProcessor(BaseProcessor):
 self.subheartflow_id = subheartflow_id
 self.log_prefix = f"[{subheartflow_id}:ToolExecutor] "
 self.llm_model = LLMRequest(
-model=global_config.llm_tool_use,
+model=global_config.model.tool_use,
 max_tokens=500,
 request_type="tool_execution",
 )

@@ -34,8 +34,9 @@ def init_prompt():

 class MemoryActivator:
 def __init__(self):
+# TODO: API-Adapter修改标记
 self.summary_model = LLMRequest(
-model=global_config.llm_summary, temperature=0.7, max_tokens=50, request_type="chat_observation"
+model=global_config.model.summary, temperature=0.7, max_tokens=50, request_type="chat_observation"
 )
 self.running_memory = []


@@ -35,8 +35,9 @@ class Heartflow:
 self.subheartflow_manager: SubHeartflowManager = SubHeartflowManager(self.current_state)

 # LLM模型配置
+# TODO: API-Adapter修改标记
 self.llm_model = LLMRequest(
-model=global_config.llm_heartflow, temperature=0.6, max_tokens=1000, request_type="heart_flow"
+model=global_config.model.heartflow, temperature=0.6, max_tokens=1000, request_type="heart_flow"
 )

 # 外部依赖模块

@@ -20,9 +20,9 @@ MAX_REPLY_PROBABILITY = 1
 class InterestChatting:
 def __init__(
 self,
-decay_rate=global_config.default_decay_rate_per_second,
+decay_rate=global_config.focus_chat.default_decay_rate_per_second,
 max_interest=MAX_INTEREST,
-trigger_threshold=global_config.reply_trigger_threshold,
+trigger_threshold=global_config.focus_chat.reply_trigger_threshold,
 max_probability=MAX_REPLY_PROBABILITY,
 ):
 # 基础属性初始化

@@ -18,19 +18,14 @@ enable_unlimited_hfc_chat = True # 调试用:无限专注聊天
 prevent_offline_state = True
 # 目前默认不启用OFFLINE状态

-# 不同状态下普通聊天的最大消息数
-base_normal_chat_num = global_config.base_normal_chat_num
-base_focused_chat_num = global_config.base_focused_chat_num
-
-
-MAX_NORMAL_CHAT_NUM_PEEKING = int(base_normal_chat_num / 2)
-MAX_NORMAL_CHAT_NUM_NORMAL = base_normal_chat_num
-MAX_NORMAL_CHAT_NUM_FOCUSED = base_normal_chat_num + 1
+MAX_NORMAL_CHAT_NUM_PEEKING = int(global_config.chat.base_normal_chat_num / 2)
+MAX_NORMAL_CHAT_NUM_NORMAL = global_config.chat.base_normal_chat_num
+MAX_NORMAL_CHAT_NUM_FOCUSED = global_config.chat.base_normal_chat_num + 1

 # 不同状态下专注聊天的最大消息数
-MAX_FOCUSED_CHAT_NUM_PEEKING = int(base_focused_chat_num / 2)
-MAX_FOCUSED_CHAT_NUM_NORMAL = base_focused_chat_num
-MAX_FOCUSED_CHAT_NUM_FOCUSED = base_focused_chat_num + 2
+MAX_FOCUSED_CHAT_NUM_PEEKING = int(global_config.chat.base_focused_chat_num / 2)
+MAX_FOCUSED_CHAT_NUM_NORMAL = global_config.chat.base_focused_chat_num
+MAX_FOCUSED_CHAT_NUM_FOCUSED = global_config.chat.base_focused_chat_num + 2

 # -- 状态定义 --


@@ -55,19 +55,20 @@ class ChattingObservation(Observation):
 self.talking_message = []
 self.talking_message_str = ""
 self.talking_message_str_truncate = ""
-self.name = global_config.BOT_NICKNAME
-self.nick_name = global_config.BOT_ALIAS_NAMES
-self.max_now_obs_len = global_config.observation_context_size
-self.overlap_len = global_config.compressed_length
-self.mid_memorys = []
-self.max_mid_memory_len = global_config.compress_length_limit
+self.name = global_config.bot.nickname
+self.nick_name = global_config.bot.alias_names
+self.max_now_obs_len = global_config.chat.observation_context_size
+self.overlap_len = global_config.focus_chat.compressed_length
+self.mid_memories = []
+self.max_mid_memory_len = global_config.focus_chat.compress_length_limit
 self.mid_memory_info = ""
 self.person_list = []
 self.oldest_messages = []
 self.oldest_messages_str = ""
 self.compressor_prompt = ""
+# TODO: API-Adapter修改标记
 self.llm_summary = LLMRequest(
-model=global_config.llm_observation, temperature=0.7, max_tokens=300, request_type="chat_observation"
+model=global_config.model.observation, temperature=0.7, max_tokens=300, request_type="chat_observation"
 )

 async def initialize(self):
@@ -85,7 +86,7 @@ class ChattingObservation(Observation):
 for id in ids:
 print(f"id:{id}")
 try:
-for mid_memory in self.mid_memorys:
+for mid_memory in self.mid_memories:
 if mid_memory["id"] == id:
 mid_memory_by_id = mid_memory
 msg_str = ""
@@ -103,7 +104,7 @@ class ChattingObservation(Observation):

 else:
 mid_memory_str = "之前的聊天内容:\n"
-for mid_memory in self.mid_memorys:
+for mid_memory in self.mid_memories:
 mid_memory_str += f"{mid_memory['theme']}\n"
 return mid_memory_str + "现在群里正在聊:\n" + self.talking_message_str


@@ -76,8 +76,9 @@ class SubHeartflowManager:

 # 为 LLM 状态评估创建一个 LLMRequest 实例
 # 使用与 Heartflow 相同的模型和参数
+# TODO: API-Adapter修改标记
 self.llm_state_evaluator = LLMRequest(
-model=global_config.llm_heartflow, # 与 Heartflow 一致
+model=global_config.model.heartflow, # 与 Heartflow 一致
 temperature=0.6, # 与 Heartflow 一致
 max_tokens=1000, # 与 Heartflow 一致 (虽然可能不需要这么多)
 request_type="subheartflow_state_eval", # 保留特定的请求类型
@@ -278,7 +279,7 @@ class SubHeartflowManager:
 focused_limit = current_state.get_focused_chat_max_num()

 # --- 新增:检查是否允许进入 FOCUS 模式 --- #
-if not global_config.allow_focus_mode:
+if not global_config.chat.allow_focus_mode:
 if int(time.time()) % 60 == 0: # 每60秒输出一次日志避免刷屏
 logger.trace("未开启 FOCUSED 状态 (allow_focus_mode=False)")
 return # 如果不允许,直接返回
@@ -766,7 +767,7 @@ class SubHeartflowManager:
 focused_limit = current_mai_state.get_focused_chat_max_num()

 # --- 检查是否允许 FOCUS 模式 --- #
-if not global_config.allow_focus_mode:
+if not global_config.chat.allow_focus_mode:
 # Log less frequently to avoid spam
 # if int(time.time()) % 60 == 0:
 # logger.debug(f"{log_prefix_task} 配置不允许进入 FOCUSED 状态")
|||||||
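Editor's note: the hunks above keep the existing "if int(time.time()) % 60 == 0" guard to avoid log spam. That check only fires when a call happens to land on a second that is a multiple of 60. Purely as an illustration (not part of this commit), a small throttle that remembers its last emit time logs at most once per interval regardless of call timing.

    # Sketch of an interval-based log throttle.
    import time

    class LogThrottle:
        def __init__(self, interval: float = 60.0):
            self.interval = interval
            self._last_emit = 0.0

        def should_log(self) -> bool:
            now = time.monotonic()
            if now - self._last_emit >= self.interval:
                self._last_emit = now
                return True
            return False

    throttle = LogThrottle(60.0)
    if throttle.should_log():
        print("未开启 FOCUSED 状态 (allow_focus_mode=False)")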
@@ -10,7 +10,7 @@ import jieba
|
|||||||
import networkx as nx
|
import networkx as nx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from ...common.database import db
|
from ...common.database.database import memory_db as db
|
||||||
from ...chat.models.utils_model import LLMRequest
|
from ...chat.models.utils_model import LLMRequest
|
||||||
from src.common.logger_manager import get_logger
|
from src.common.logger_manager import get_logger
|
||||||
from src.chat.memory_system.sample_distribution import MemoryBuildScheduler # 分布生成器
|
from src.chat.memory_system.sample_distribution import MemoryBuildScheduler # 分布生成器
|
||||||
@@ -19,9 +19,10 @@ from ..utils.chat_message_builder import (
|
|||||||
build_readable_messages,
|
build_readable_messages,
|
||||||
) # 导入 build_readable_messages
|
) # 导入 build_readable_messages
|
||||||
from ..utils.utils import translate_timestamp_to_human_readable
|
from ..utils.utils import translate_timestamp_to_human_readable
|
||||||
from .memory_config import MemoryConfig
|
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
|
||||||
|
from ...config.config import global_config
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
|
|
||||||
|
|
||||||
@@ -195,18 +196,16 @@ class Hippocampus:
|
|||||||
self.llm_summary = None
|
self.llm_summary = None
|
||||||
self.entorhinal_cortex = None
|
self.entorhinal_cortex = None
|
||||||
self.parahippocampal_gyrus = None
|
self.parahippocampal_gyrus = None
|
||||||
self.config = None
|
|
||||||
|
|
||||||
def initialize(self, global_config):
|
def initialize(self):
|
||||||
# 使用导入的 MemoryConfig dataclass 和其 from_global_config 方法
|
|
||||||
self.config = MemoryConfig.from_global_config(global_config)
|
|
||||||
# 初始化子组件
|
# 初始化子组件
|
||||||
self.entorhinal_cortex = EntorhinalCortex(self)
|
self.entorhinal_cortex = EntorhinalCortex(self)
|
||||||
self.parahippocampal_gyrus = ParahippocampalGyrus(self)
|
self.parahippocampal_gyrus = ParahippocampalGyrus(self)
|
||||||
# 从数据库加载记忆图
|
# 从数据库加载记忆图
|
||||||
self.entorhinal_cortex.sync_memory_from_db()
|
self.entorhinal_cortex.sync_memory_from_db()
|
||||||
self.llm_topic_judge = LLMRequest(self.config.llm_topic_judge, request_type="memory")
|
# TODO: API-Adapter修改标记
|
||||||
self.llm_summary = LLMRequest(self.config.llm_summary, request_type="memory")
|
self.llm_topic_judge = LLMRequest(global_config.model.topic_judge, request_type="memory")
|
||||||
|
self.llm_summary = LLMRequest(global_config.model.summary, request_type="memory")
|
||||||
|
|
||||||
def get_all_node_names(self) -> list:
|
def get_all_node_names(self) -> list:
|
||||||
"""获取记忆图中所有节点的名字列表"""
|
"""获取记忆图中所有节点的名字列表"""
|
||||||
@@ -792,7 +791,6 @@ class EntorhinalCortex:
|
|||||||
def __init__(self, hippocampus: Hippocampus):
|
def __init__(self, hippocampus: Hippocampus):
|
||||||
self.hippocampus = hippocampus
|
self.hippocampus = hippocampus
|
||||||
self.memory_graph = hippocampus.memory_graph
|
self.memory_graph = hippocampus.memory_graph
|
||||||
self.config = hippocampus.config
|
|
||||||
|
|
||||||
def get_memory_sample(self):
|
def get_memory_sample(self):
|
||||||
"""从数据库获取记忆样本"""
|
"""从数据库获取记忆样本"""
|
||||||
@@ -801,13 +799,13 @@ class EntorhinalCortex:
|
|||||||
|
|
||||||
# 创建双峰分布的记忆调度器
|
# 创建双峰分布的记忆调度器
|
||||||
sample_scheduler = MemoryBuildScheduler(
|
sample_scheduler = MemoryBuildScheduler(
|
||||||
n_hours1=self.config.memory_build_distribution[0],
|
n_hours1=global_config.memory.memory_build_distribution[0],
|
||||||
std_hours1=self.config.memory_build_distribution[1],
|
std_hours1=global_config.memory.memory_build_distribution[1],
|
||||||
weight1=self.config.memory_build_distribution[2],
|
weight1=global_config.memory.memory_build_distribution[2],
|
||||||
n_hours2=self.config.memory_build_distribution[3],
|
n_hours2=global_config.memory.memory_build_distribution[3],
|
||||||
std_hours2=self.config.memory_build_distribution[4],
|
std_hours2=global_config.memory.memory_build_distribution[4],
|
||||||
weight2=self.config.memory_build_distribution[5],
|
weight2=global_config.memory.memory_build_distribution[5],
|
||||||
total_samples=self.config.build_memory_sample_num,
|
total_samples=global_config.memory.memory_build_sample_num,
|
||||||
)
|
)
|
||||||
|
|
||||||
timestamps = sample_scheduler.get_timestamp_array()
|
timestamps = sample_scheduler.get_timestamp_array()
|
||||||
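Editor's note: the call above feeds the six values of memory_build_distribution into MemoryBuildScheduler as (hours1, std1, weight1, hours2, std2, weight2) plus a total sample count. The scheduler itself is not shown in this diff; the sketch below is only an approximation of the general idea, sampling look-back offsets (in hours) from a two-component Gaussian mixture and turning them into past timestamps. The example values mirror the defaults listed in the removed memory_config.py further down.

    # Hypothetical approximation of a bimodal sampling scheduler; the real
    # MemoryBuildScheduler implementation may differ.
    import random
    import time

    def bimodal_timestamps(n_hours1, std_hours1, weight1,
                           n_hours2, std_hours2, weight2,
                           total_samples, now=None):
        now = time.time() if now is None else now
        total_weight = weight1 + weight2
        timestamps = []
        for _ in range(total_samples):
            # Pick one of the two Gaussian components by weight.
            if random.random() < weight1 / total_weight:
                hours_ago = random.gauss(n_hours1, std_hours1)
            else:
                hours_ago = random.gauss(n_hours2, std_hours2)
            hours_ago = max(hours_ago, 0.0)          # never sample the future
            timestamps.append(now - hours_ago * 3600)
        return sorted(timestamps)

    # Roughly "about a day ago" and "about a week ago", weighted equally.
    samples = bimodal_timestamps(24, 12, 0.5, 168, 72, 0.5, total_samples=10)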
@@ -818,7 +816,7 @@ class EntorhinalCortex:
|
|||||||
for timestamp in timestamps:
|
for timestamp in timestamps:
|
||||||
# 调用修改后的 random_get_msg_snippet
|
# 调用修改后的 random_get_msg_snippet
|
||||||
messages = self.random_get_msg_snippet(
|
messages = self.random_get_msg_snippet(
|
||||||
timestamp, self.config.build_memory_sample_length, max_memorized_time_per_msg
|
timestamp, global_config.memory.memory_build_sample_length, max_memorized_time_per_msg
|
||||||
)
|
)
|
||||||
if messages:
|
if messages:
|
||||||
time_diff = (datetime.datetime.now().timestamp() - timestamp) / 3600
|
time_diff = (datetime.datetime.now().timestamp() - timestamp) / 3600
|
||||||
@@ -1099,7 +1097,6 @@ class ParahippocampalGyrus:
|
|||||||
def __init__(self, hippocampus: Hippocampus):
|
def __init__(self, hippocampus: Hippocampus):
|
||||||
self.hippocampus = hippocampus
|
self.hippocampus = hippocampus
|
||||||
self.memory_graph = hippocampus.memory_graph
|
self.memory_graph = hippocampus.memory_graph
|
||||||
self.config = hippocampus.config
|
|
||||||
|
|
||||||
async def memory_compress(self, messages: list, compress_rate=0.1):
|
async def memory_compress(self, messages: list, compress_rate=0.1):
|
||||||
"""压缩和总结消息内容,生成记忆主题和摘要。
|
"""压缩和总结消息内容,生成记忆主题和摘要。
|
||||||
@@ -1159,7 +1156,7 @@ class ParahippocampalGyrus:
|
|||||||
|
|
||||||
# 3. 过滤掉包含禁用关键词的topic
|
# 3. 过滤掉包含禁用关键词的topic
|
||||||
filtered_topics = [
|
filtered_topics = [
|
||||||
topic for topic in topics if not any(keyword in topic for keyword in self.config.memory_ban_words)
|
topic for topic in topics if not any(keyword in topic for keyword in global_config.memory.memory_ban_words)
|
||||||
]
|
]
|
||||||
|
|
||||||
logger.debug(f"过滤后话题: {filtered_topics}")
|
logger.debug(f"过滤后话题: {filtered_topics}")
|
||||||
@@ -1222,7 +1219,7 @@ class ParahippocampalGyrus:
|
|||||||
bar = "█" * filled_length + "-" * (bar_length - filled_length)
|
bar = "█" * filled_length + "-" * (bar_length - filled_length)
|
||||||
logger.debug(f"进度: [{bar}] {progress:.1f}% ({i}/{len(memory_samples)})")
|
logger.debug(f"进度: [{bar}] {progress:.1f}% ({i}/{len(memory_samples)})")
|
||||||
|
|
||||||
compress_rate = self.config.memory_compress_rate
|
compress_rate = global_config.memory.memory_compress_rate
|
||||||
try:
|
try:
|
||||||
compressed_memory, similar_topics_dict = await self.memory_compress(messages, compress_rate)
|
compressed_memory, similar_topics_dict = await self.memory_compress(messages, compress_rate)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -1322,7 +1319,7 @@ class ParahippocampalGyrus:
|
|||||||
edge_data = self.memory_graph.G[source][target]
|
edge_data = self.memory_graph.G[source][target]
|
||||||
last_modified = edge_data.get("last_modified")
|
last_modified = edge_data.get("last_modified")
|
||||||
|
|
||||||
if current_time - last_modified > 3600 * self.config.memory_forget_time:
|
if current_time - last_modified > 3600 * global_config.memory.memory_forget_time:
|
||||||
current_strength = edge_data.get("strength", 1)
|
current_strength = edge_data.get("strength", 1)
|
||||||
new_strength = current_strength - 1
|
new_strength = current_strength - 1
|
||||||
|
|
||||||
@@ -1430,8 +1427,8 @@ class ParahippocampalGyrus:
|
|||||||
async def operation_consolidate_memory(self):
|
async def operation_consolidate_memory(self):
|
||||||
"""整合记忆:合并节点内相似的记忆项"""
|
"""整合记忆:合并节点内相似的记忆项"""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
percentage = self.config.consolidate_memory_percentage
|
percentage = global_config.memory.consolidate_memory_percentage
|
||||||
similarity_threshold = self.config.consolidation_similarity_threshold
|
similarity_threshold = global_config.memory.consolidation_similarity_threshold
|
||||||
logger.info(f"[整合] 开始检查记忆节点... 检查比例: {percentage:.2%}, 合并阈值: {similarity_threshold}")
|
logger.info(f"[整合] 开始检查记忆节点... 检查比例: {percentage:.2%}, 合并阈值: {similarity_threshold}")
|
||||||
|
|
||||||
# 获取所有至少有2条记忆项的节点
|
# 获取所有至少有2条记忆项的节点
|
||||||
@@ -1544,7 +1541,6 @@ class ParahippocampalGyrus:
|
|||||||
class HippocampusManager:
|
class HippocampusManager:
|
||||||
_instance = None
|
_instance = None
|
||||||
_hippocampus = None
|
_hippocampus = None
|
||||||
_global_config = None
|
|
||||||
_initialized = False
|
_initialized = False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -1559,19 +1555,15 @@ class HippocampusManager:
|
|||||||
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
|
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
|
||||||
return cls._hippocampus
|
return cls._hippocampus
|
||||||
|
|
||||||
def initialize(self, global_config):
|
def initialize(self):
|
||||||
"""初始化海马体实例"""
|
"""初始化海马体实例"""
|
||||||
if self._initialized:
|
if self._initialized:
|
||||||
return self._hippocampus
|
return self._hippocampus
|
||||||
|
|
||||||
self._global_config = global_config
|
|
||||||
self._hippocampus = Hippocampus()
|
self._hippocampus = Hippocampus()
|
||||||
self._hippocampus.initialize(global_config)
|
self._hippocampus.initialize()
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
|
|
||||||
# 输出记忆系统参数信息
|
|
||||||
config = self._hippocampus.config
|
|
||||||
|
|
||||||
# 输出记忆图统计信息
|
# 输出记忆图统计信息
|
||||||
memory_graph = self._hippocampus.memory_graph.G
|
memory_graph = self._hippocampus.memory_graph.G
|
||||||
node_count = len(memory_graph.nodes())
|
node_count = len(memory_graph.nodes())
|
||||||
@@ -1579,9 +1571,9 @@ class HippocampusManager:
|
|||||||
|
|
||||||
logger.success(f"""--------------------------------
|
logger.success(f"""--------------------------------
|
||||||
记忆系统参数配置:
|
记忆系统参数配置:
|
||||||
构建间隔: {global_config.build_memory_interval}秒|样本数: {config.build_memory_sample_num},长度: {config.build_memory_sample_length}|压缩率: {config.memory_compress_rate}
|
构建间隔: {global_config.memory.memory_build_interval}秒|样本数: {global_config.memory.memory_build_sample_num},长度: {global_config.memory.memory_build_sample_length}|压缩率: {global_config.memory.memory_compress_rate}
|
||||||
记忆构建分布: {config.memory_build_distribution}
|
记忆构建分布: {global_config.memory.memory_build_distribution}
|
||||||
遗忘间隔: {global_config.forget_memory_interval}秒|遗忘比例: {global_config.memory_forget_percentage}|遗忘: {config.memory_forget_time}小时之后
|
遗忘间隔: {global_config.memory.forget_memory_interval}秒|遗忘比例: {global_config.memory.memory_forget_percentage}|遗忘: {global_config.memory.memory_forget_time}小时之后
|
||||||
记忆图统计信息: 节点数量: {node_count}, 连接数量: {edge_count}
|
记忆图统计信息: 节点数量: {node_count}, 连接数量: {edge_count}
|
||||||
--------------------------------""") # noqa: E501
|
--------------------------------""") # noqa: E501
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import os
|
|||||||
# 添加项目根目录到系统路径
|
# 添加项目根目录到系统路径
|
||||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))))
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))))
|
||||||
from src.chat.memory_system.Hippocampus import HippocampusManager
|
from src.chat.memory_system.Hippocampus import HippocampusManager
|
||||||
from src.config.config import global_config
|
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
@@ -19,7 +18,7 @@ async def test_memory_system():
|
|||||||
# 初始化记忆系统
|
# 初始化记忆系统
|
||||||
print("开始初始化记忆系统...")
|
print("开始初始化记忆系统...")
|
||||||
hippocampus_manager = HippocampusManager.get_instance()
|
hippocampus_manager = HippocampusManager.get_instance()
|
||||||
hippocampus_manager.initialize(global_config=global_config)
|
hippocampus_manager.initialize()
|
||||||
print("记忆系统初始化完成")
|
print("记忆系统初始化完成")
|
||||||
|
|
||||||
# 测试记忆构建
|
# 测试记忆构建
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
|
|||||||
sys.path.append(root_path)
|
sys.path.append(root_path)
|
||||||
|
|
||||||
from src.common.logger import get_module_logger # noqa E402
|
from src.common.logger import get_module_logger # noqa E402
|
||||||
from src.common.database import db # noqa E402
|
from common.database.database import db # noqa E402
|
||||||
|
|
||||||
logger = get_module_logger("mem_alter")
|
logger = get_module_logger("mem_alter")
|
||||||
console = Console()
|
console = Console()
|
||||||
|
|||||||
@@ -1,48 +0,0 @@
|
|||||||
from dataclasses import dataclass
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MemoryConfig:
|
|
||||||
"""记忆系统配置类"""
|
|
||||||
|
|
||||||
# 记忆构建相关配置
|
|
||||||
memory_build_distribution: List[float] # 记忆构建的时间分布参数
|
|
||||||
build_memory_sample_num: int # 每次构建记忆的样本数量
|
|
||||||
build_memory_sample_length: int # 每个样本的消息长度
|
|
||||||
memory_compress_rate: float # 记忆压缩率
|
|
||||||
|
|
||||||
# 记忆遗忘相关配置
|
|
||||||
memory_forget_time: int # 记忆遗忘时间(小时)
|
|
||||||
|
|
||||||
# 记忆过滤相关配置
|
|
||||||
memory_ban_words: List[str] # 记忆过滤词列表
|
|
||||||
|
|
||||||
# 新增:记忆整合相关配置
|
|
||||||
consolidation_similarity_threshold: float # 相似度阈值
|
|
||||||
consolidate_memory_percentage: float # 检查节点比例
|
|
||||||
consolidate_memory_interval: int # 记忆整合间隔
|
|
||||||
|
|
||||||
llm_topic_judge: str # 话题判断模型
|
|
||||||
llm_summary: str # 话题总结模型
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_global_config(cls, global_config):
|
|
||||||
"""从全局配置创建记忆系统配置"""
|
|
||||||
# 使用 getattr 提供默认值,防止全局配置缺少这些项
|
|
||||||
return cls(
|
|
||||||
memory_build_distribution=getattr(
|
|
||||||
global_config, "memory_build_distribution", (24, 12, 0.5, 168, 72, 0.5)
|
|
||||||
), # 添加默认值
|
|
||||||
build_memory_sample_num=getattr(global_config, "build_memory_sample_num", 5),
|
|
||||||
build_memory_sample_length=getattr(global_config, "build_memory_sample_length", 30),
|
|
||||||
memory_compress_rate=getattr(global_config, "memory_compress_rate", 0.1),
|
|
||||||
memory_forget_time=getattr(global_config, "memory_forget_time", 24 * 7),
|
|
||||||
memory_ban_words=getattr(global_config, "memory_ban_words", []),
|
|
||||||
# 新增加载整合配置,并提供默认值
|
|
||||||
consolidation_similarity_threshold=getattr(global_config, "consolidation_similarity_threshold", 0.7),
|
|
||||||
consolidate_memory_percentage=getattr(global_config, "consolidate_memory_percentage", 0.01),
|
|
||||||
consolidate_memory_interval=getattr(global_config, "consolidate_memory_interval", 1000),
|
|
||||||
llm_topic_judge=getattr(global_config, "llm_topic_judge", "default_judge_model"), # 添加默认模型名
|
|
||||||
llm_summary=getattr(global_config, "llm_summary", "default_summary_model"), # 添加默认模型名
|
|
||||||
)
|
|
||||||
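Editor's note: with the MemoryConfig dataclass above deleted, the getattr(..., default) fallbacks it supplied disappear as well, so any default values now have to live in the config schema behind global_config.memory. A minimal sketch of the two access styles side by side; the section and field names are illustrative.

    # Old style: ad-hoc fallback at every read site.
    #   rate = getattr(global_config, "memory_compress_rate", 0.1)
    # New style: the default is declared once on a typed section and read as
    #   global_config.memory.memory_compress_rate
    from dataclasses import dataclass, field

    @dataclass
    class MemorySection:
        memory_compress_rate: float = 0.1       # default lives in the schema
        memory_forget_time: int = 24 * 7        # hours before decay kicks in

    @dataclass
    class GlobalConfig:
        memory: MemorySection = field(default_factory=MemorySection)

    global_config = GlobalConfig()
    print(global_config.memory.memory_compress_rate)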
@@ -41,7 +41,7 @@ class ChatBot:
|
|||||||
chat_id = str(message.chat_stream.stream_id)
|
chat_id = str(message.chat_stream.stream_id)
|
||||||
private_name = str(message.message_info.user_info.user_nickname)
|
private_name = str(message.message_info.user_info.user_nickname)
|
||||||
|
|
||||||
if global_config.enable_pfc_chatting:
|
if global_config.experimental.enable_pfc_chatting:
|
||||||
await self.pfc_manager.get_or_create_conversation(chat_id, private_name)
|
await self.pfc_manager.get_or_create_conversation(chat_id, private_name)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -78,19 +78,19 @@ class ChatBot:
|
|||||||
userinfo = message.message_info.user_info
|
userinfo = message.message_info.user_info
|
||||||
|
|
||||||
# 用户黑名单拦截
|
# 用户黑名单拦截
|
||||||
if userinfo.user_id in global_config.ban_user_id:
|
if userinfo.user_id in global_config.chat_target.ban_user_id:
|
||||||
logger.debug(f"用户{userinfo.user_id}被禁止回复")
|
logger.debug(f"用户{userinfo.user_id}被禁止回复")
|
||||||
return
|
return
|
||||||
|
|
||||||
if groupinfo is None:
|
if groupinfo is None:
|
||||||
logger.trace("检测到私聊消息,检查")
|
logger.trace("检测到私聊消息,检查")
|
||||||
# 好友黑名单拦截
|
# 好友黑名单拦截
|
||||||
if userinfo.user_id not in global_config.talk_allowed_private:
|
if userinfo.user_id not in global_config.experimental.talk_allowed_private:
|
||||||
logger.debug(f"用户{userinfo.user_id}没有私聊权限")
|
logger.debug(f"用户{userinfo.user_id}没有私聊权限")
|
||||||
return
|
return
|
||||||
|
|
||||||
# 群聊黑名单拦截
|
# 群聊黑名单拦截
|
||||||
if groupinfo is not None and groupinfo.group_id not in global_config.talk_allowed_groups:
|
if groupinfo is not None and groupinfo.group_id not in global_config.chat_target.talk_allowed_groups:
|
||||||
logger.trace(f"群{groupinfo.group_id}被禁止回复")
|
logger.trace(f"群{groupinfo.group_id}被禁止回复")
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -112,7 +112,7 @@ class ChatBot:
|
|||||||
if groupinfo is None:
|
if groupinfo is None:
|
||||||
logger.trace("检测到私聊消息")
|
logger.trace("检测到私聊消息")
|
||||||
# 是否在配置信息中开启私聊模式
|
# 是否在配置信息中开启私聊模式
|
||||||
if global_config.enable_friend_chat:
|
if global_config.experimental.enable_friend_chat:
|
||||||
logger.trace("私聊模式已启用")
|
logger.trace("私聊模式已启用")
|
||||||
# 是否进入PFC
|
# 是否进入PFC
|
||||||
if global_config.enable_pfc_chatting:
|
if global_config.enable_pfc_chatting:
|
||||||
|
|||||||
@@ -5,7 +5,8 @@ import copy
|
|||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
|
||||||
from ...common.database import db
|
from ...common.database.database import db
|
||||||
|
from ...common.database.database_model import ChatStreams # 新增导入
|
||||||
from maim_message import GroupInfo, UserInfo
|
from maim_message import GroupInfo, UserInfo
|
||||||
|
|
||||||
from src.common.logger_manager import get_logger
|
from src.common.logger_manager import get_logger
|
||||||
@@ -82,7 +83,13 @@ class ChatManager:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
if not self._initialized:
|
if not self._initialized:
|
||||||
self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream
|
self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream
|
||||||
self._ensure_collection()
|
try:
|
||||||
|
db.connect(reuse_if_open=True)
|
||||||
|
# 确保 ChatStreams 表存在
|
||||||
|
db.create_tables([ChatStreams], safe=True)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"数据库连接或 ChatStreams 表创建失败: {e}")
|
||||||
|
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
# 在事件循环中启动初始化
|
# 在事件循环中启动初始化
|
||||||
# asyncio.create_task(self._initialize())
|
# asyncio.create_task(self._initialize())
|
||||||
@@ -107,15 +114,6 @@ class ChatManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"聊天流自动保存失败: {str(e)}")
|
logger.error(f"聊天流自动保存失败: {str(e)}")
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _ensure_collection():
|
|
||||||
"""确保数据库集合存在并创建索引"""
|
|
||||||
if "chat_streams" not in db.list_collection_names():
|
|
||||||
db.create_collection("chat_streams")
|
|
||||||
# 创建索引
|
|
||||||
db.chat_streams.create_index([("stream_id", 1)], unique=True)
|
|
||||||
db.chat_streams.create_index([("platform", 1), ("user_info.user_id", 1), ("group_info.group_id", 1)])
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _generate_stream_id(platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None) -> str:
|
def _generate_stream_id(platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None) -> str:
|
||||||
"""生成聊天流唯一ID"""
|
"""生成聊天流唯一ID"""
|
||||||
@@ -151,16 +149,43 @@ class ChatManager:
|
|||||||
stream = self.streams[stream_id]
|
stream = self.streams[stream_id]
|
||||||
# 更新用户信息和群组信息
|
# 更新用户信息和群组信息
|
||||||
stream.update_active_time()
|
stream.update_active_time()
|
||||||
stream = copy.deepcopy(stream)
|
stream = copy.deepcopy(stream) # 返回副本以避免外部修改影响缓存
|
||||||
stream.user_info = user_info
|
stream.user_info = user_info
|
||||||
if group_info:
|
if group_info:
|
||||||
stream.group_info = group_info
|
stream.group_info = group_info
|
||||||
return stream
|
return stream
|
||||||
|
|
||||||
# 检查数据库中是否存在
|
# 检查数据库中是否存在
|
||||||
data = db.chat_streams.find_one({"stream_id": stream_id})
|
def _db_find_stream_sync(s_id: str):
|
||||||
if data:
|
return ChatStreams.get_or_none(ChatStreams.stream_id == s_id)
|
||||||
stream = ChatStream.from_dict(data)
|
|
||||||
|
model_instance = await asyncio.to_thread(_db_find_stream_sync, stream_id)
|
||||||
|
|
||||||
|
if model_instance:
|
||||||
|
# 从 Peewee 模型转换回 ChatStream.from_dict 期望的格式
|
||||||
|
user_info_data = {
|
||||||
|
"platform": model_instance.user_platform,
|
||||||
|
"user_id": model_instance.user_id,
|
||||||
|
"user_nickname": model_instance.user_nickname,
|
||||||
|
"user_cardname": model_instance.user_cardname or "",
|
||||||
|
}
|
||||||
|
group_info_data = None
|
||||||
|
if model_instance.group_id: # 假设 group_id 为空字符串表示没有群组信息
|
||||||
|
group_info_data = {
|
||||||
|
"platform": model_instance.group_platform,
|
||||||
|
"group_id": model_instance.group_id,
|
||||||
|
"group_name": model_instance.group_name,
|
||||||
|
}
|
||||||
|
|
||||||
|
data_for_from_dict = {
|
||||||
|
"stream_id": model_instance.stream_id,
|
||||||
|
"platform": model_instance.platform,
|
||||||
|
"user_info": user_info_data,
|
||||||
|
"group_info": group_info_data,
|
||||||
|
"create_time": model_instance.create_time,
|
||||||
|
"last_active_time": model_instance.last_active_time,
|
||||||
|
}
|
||||||
|
stream = ChatStream.from_dict(data_for_from_dict)
|
||||||
# 更新用户信息和群组信息
|
# 更新用户信息和群组信息
|
||||||
stream.user_info = user_info
|
stream.user_info = user_info
|
||||||
if group_info:
|
if group_info:
|
||||||
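Editor's note: the hunk above wraps the synchronous Peewee lookup in asyncio.to_thread so the blocking SQLite call does not stall the event loop. A self-contained sketch of the same pattern; the model fields and the temp-file database are illustrative, not the project's ChatStreams definition.

    # Minimal sketch: run a blocking Peewee lookup off the event loop.
    import asyncio
    import os
    import tempfile
    from peewee import SqliteDatabase, Model, TextField

    # Throwaway on-disk database: Peewee keeps per-thread connections, and a
    # plain ":memory:" database would not be visible from the worker thread.
    db_path = os.path.join(tempfile.gettempdir(), "chat_streams_demo.db")
    db = SqliteDatabase(db_path)

    class ChatStreams(Model):
        stream_id = TextField(unique=True)
        platform = TextField(default="")

        class Meta:
            database = db

    db.connect()
    db.create_tables([ChatStreams], safe=True)
    ChatStreams.get_or_create(stream_id="abc123", defaults={"platform": "qq"})

    async def find_stream(stream_id: str):
        def _sync_lookup(s_id: str):
            # Blocking call; returns None when no row matches.
            return ChatStreams.get_or_none(ChatStreams.stream_id == s_id)
        return await asyncio.to_thread(_sync_lookup, stream_id)

    stream = asyncio.run(find_stream("abc123"))
    print(stream.platform if stream else "not found")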
@@ -175,7 +200,7 @@ class ChatManager:
|
|||||||
group_info=group_info,
|
group_info=group_info,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"创建聊天流失败: {e}")
|
logger.error(f"获取或创建聊天流失败: {e}", exc_info=True)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
# 保存到内存和数据库
|
# 保存到内存和数据库
|
||||||
@@ -205,15 +230,38 @@ class ChatManager:
|
|||||||
elif stream.user_info and stream.user_info.user_nickname:
|
elif stream.user_info and stream.user_info.user_nickname:
|
||||||
return f"{stream.user_info.user_nickname}的私聊"
|
return f"{stream.user_info.user_nickname}的私聊"
|
||||||
else:
|
else:
|
||||||
# 如果没有群名或用户昵称,返回 None 或其他默认值
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _save_stream(stream: ChatStream):
|
async def _save_stream(stream: ChatStream):
|
||||||
"""保存聊天流到数据库"""
|
"""保存聊天流到数据库"""
|
||||||
if not stream.saved:
|
if not stream.saved:
|
||||||
db.chat_streams.update_one({"stream_id": stream.stream_id}, {"$set": stream.to_dict()}, upsert=True)
|
stream_data_dict = stream.to_dict()
|
||||||
|
|
||||||
|
def _db_save_stream_sync(s_data_dict: dict):
|
||||||
|
user_info_d = s_data_dict.get("user_info")
|
||||||
|
group_info_d = s_data_dict.get("group_info")
|
||||||
|
|
||||||
|
fields_to_save = {
|
||||||
|
"platform": s_data_dict["platform"],
|
||||||
|
"create_time": s_data_dict["create_time"],
|
||||||
|
"last_active_time": s_data_dict["last_active_time"],
|
||||||
|
"user_platform": user_info_d["platform"] if user_info_d else "",
|
||||||
|
"user_id": user_info_d["user_id"] if user_info_d else "",
|
||||||
|
"user_nickname": user_info_d["user_nickname"] if user_info_d else "",
|
||||||
|
"user_cardname": user_info_d.get("user_cardname", "") if user_info_d else None,
|
||||||
|
"group_platform": group_info_d["platform"] if group_info_d else "",
|
||||||
|
"group_id": group_info_d["group_id"] if group_info_d else "",
|
||||||
|
"group_name": group_info_d["group_name"] if group_info_d else "",
|
||||||
|
}
|
||||||
|
|
||||||
|
ChatStreams.replace(stream_id=s_data_dict["stream_id"], **fields_to_save).execute()
|
||||||
|
|
||||||
|
try:
|
||||||
|
await asyncio.to_thread(_db_save_stream_sync, stream_data_dict)
|
||||||
stream.saved = True
|
stream.saved = True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"保存聊天流 {stream.stream_id} 到数据库失败 (Peewee): {e}", exc_info=True)
|
||||||
|
|
||||||
async def _save_all_streams(self):
|
async def _save_all_streams(self):
|
||||||
"""保存所有聊天流"""
|
"""保存所有聊天流"""
|
||||||
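Editor's note: _save_stream above persists a stream with ChatStreams.replace(...), Peewee's insert-or-replace shortcut. For that to behave like an upsert keyed on stream_id, the column needs a unique constraint (or to be the primary key) in the real model. A compact sketch with made-up fields:

    # Upsert-by-unique-key with Peewee's Model.replace().
    from peewee import SqliteDatabase, Model, TextField, FloatField

    db = SqliteDatabase(":memory:")

    class ChatStreams(Model):
        stream_id = TextField(unique=True)   # REPLACE keys on this constraint
        platform = TextField(default="")
        last_active_time = FloatField(default=0.0)

        class Meta:
            database = db

    db.connect()
    db.create_tables([ChatStreams], safe=True)

    # First call inserts, second call replaces the same logical row.
    ChatStreams.replace(stream_id="abc123", platform="qq", last_active_time=1.0).execute()
    ChatStreams.replace(stream_id="abc123", platform="qq", last_active_time=2.0).execute()

    assert ChatStreams.select().count() == 1
    assert ChatStreams.get(ChatStreams.stream_id == "abc123").last_active_time == 2.0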
@@ -222,10 +270,44 @@ class ChatManager:
|
|||||||
|
|
||||||
async def load_all_streams(self):
|
async def load_all_streams(self):
|
||||||
"""从数据库加载所有聊天流"""
|
"""从数据库加载所有聊天流"""
|
||||||
all_streams = db.chat_streams.find({})
|
|
||||||
for data in all_streams:
|
def _db_load_all_streams_sync():
|
||||||
|
loaded_streams_data = []
|
||||||
|
for model_instance in ChatStreams.select():
|
||||||
|
user_info_data = {
|
||||||
|
"platform": model_instance.user_platform,
|
||||||
|
"user_id": model_instance.user_id,
|
||||||
|
"user_nickname": model_instance.user_nickname,
|
||||||
|
"user_cardname": model_instance.user_cardname or "",
|
||||||
|
}
|
||||||
|
group_info_data = None
|
||||||
|
if model_instance.group_id:
|
||||||
|
group_info_data = {
|
||||||
|
"platform": model_instance.group_platform,
|
||||||
|
"group_id": model_instance.group_id,
|
||||||
|
"group_name": model_instance.group_name,
|
||||||
|
}
|
||||||
|
|
||||||
|
data_for_from_dict = {
|
||||||
|
"stream_id": model_instance.stream_id,
|
||||||
|
"platform": model_instance.platform,
|
||||||
|
"user_info": user_info_data,
|
||||||
|
"group_info": group_info_data,
|
||||||
|
"create_time": model_instance.create_time,
|
||||||
|
"last_active_time": model_instance.last_active_time,
|
||||||
|
}
|
||||||
|
loaded_streams_data.append(data_for_from_dict)
|
||||||
|
return loaded_streams_data
|
||||||
|
|
||||||
|
try:
|
||||||
|
all_streams_data_list = await asyncio.to_thread(_db_load_all_streams_sync)
|
||||||
|
self.streams.clear()
|
||||||
|
for data in all_streams_data_list:
|
||||||
stream = ChatStream.from_dict(data)
|
stream = ChatStream.from_dict(data)
|
||||||
|
stream.saved = True
|
||||||
self.streams[stream.stream_id] = stream
|
self.streams[stream.stream_id] = stream
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"从数据库加载所有聊天流失败 (Peewee): {e}", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
# 创建全局单例
|
# 创建全局单例
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ class MessageBuffer:
|
|||||||
|
|
||||||
async def start_caching_messages(self, message: MessageRecv):
|
async def start_caching_messages(self, message: MessageRecv):
|
||||||
"""添加消息,启动缓冲"""
|
"""添加消息,启动缓冲"""
|
||||||
if not global_config.message_buffer:
|
if not global_config.chat.message_buffer:
|
||||||
person_id = person_info_manager.get_person_id(
|
person_id = person_info_manager.get_person_id(
|
||||||
message.message_info.user_info.platform, message.message_info.user_info.user_id
|
message.message_info.user_info.platform, message.message_info.user_info.user_id
|
||||||
)
|
)
|
||||||
@@ -107,7 +107,7 @@ class MessageBuffer:
|
|||||||
|
|
||||||
async def query_buffer_result(self, message: MessageRecv) -> bool:
|
async def query_buffer_result(self, message: MessageRecv) -> bool:
|
||||||
"""查询缓冲结果,并清理"""
|
"""查询缓冲结果,并清理"""
|
||||||
if not global_config.message_buffer:
|
if not global_config.chat.message_buffer:
|
||||||
return True
|
return True
|
||||||
person_id_ = self.get_person_id_(
|
person_id_ = self.get_person_id_(
|
||||||
message.message_info.platform, message.message_info.user_info.user_id, message.message_info.group_info
|
message.message_info.platform, message.message_info.user_info.user_id, message.message_info.group_info
|
||||||
|
|||||||
@@ -279,7 +279,7 @@ class MessageManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 检查是否超时
|
# 检查是否超时
|
||||||
if thinking_time > global_config.thinking_timeout:
|
if thinking_time > global_config.normal_chat.thinking_timeout:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"[{chat_id}] 消息思考超时 ({thinking_time:.1f}秒),移除消息 {message_earliest.message_info.message_id}"
|
f"[{chat_id}] 消息思考超时 ({thinking_time:.1f}秒),移除消息 {message_earliest.message_info.message_id}"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
import re
|
import re
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
from ...common.database import db
|
# from ...common.database.database import db # db is now Peewee's SqliteDatabase instance
|
||||||
from .message import MessageSending, MessageRecv
|
from .message import MessageSending, MessageRecv
|
||||||
from .chat_stream import ChatStream
|
from .chat_stream import ChatStream
|
||||||
|
from ...common.database.database_model import Messages, RecalledMessages # Import Peewee models
|
||||||
from src.common.logger import get_module_logger
|
from src.common.logger import get_module_logger
|
||||||
|
|
||||||
logger = get_module_logger("message_storage")
|
logger = get_module_logger("message_storage")
|
||||||
@@ -29,34 +30,56 @@ class MessageStorage:
|
|||||||
else:
|
else:
|
||||||
filtered_detailed_plain_text = ""
|
filtered_detailed_plain_text = ""
|
||||||
|
|
||||||
message_data = {
|
chat_info_dict = chat_stream.to_dict()
|
||||||
"message_id": message.message_info.message_id,
|
user_info_dict = message.message_info.user_info.to_dict()
|
||||||
"time": message.message_info.time,
|
|
||||||
"chat_id": chat_stream.stream_id,
|
# message_id 现在是 TextField,直接使用字符串值
|
||||||
"chat_info": chat_stream.to_dict(),
|
msg_id = message.message_info.message_id
|
||||||
"user_info": message.message_info.user_info.to_dict(),
|
|
||||||
# 使用过滤后的文本
|
# 安全地获取 group_info, 如果为 None 则视为空字典
|
||||||
"processed_plain_text": filtered_processed_plain_text,
|
group_info_from_chat = chat_info_dict.get("group_info") or {}
|
||||||
"detailed_plain_text": filtered_detailed_plain_text,
|
# 安全地获取 user_info, 如果为 None 则视为空字典 (以防万一)
|
||||||
"memorized_times": message.memorized_times,
|
user_info_from_chat = chat_info_dict.get("user_info") or {}
|
||||||
}
|
|
||||||
db.messages.insert_one(message_data)
|
Messages.create(
|
||||||
|
message_id=msg_id,
|
||||||
|
time=float(message.message_info.time),
|
||||||
|
chat_id=chat_stream.stream_id,
|
||||||
|
# Flattened chat_info
|
||||||
|
chat_info_stream_id=chat_info_dict.get("stream_id"),
|
||||||
|
chat_info_platform=chat_info_dict.get("platform"),
|
||||||
|
chat_info_user_platform=user_info_from_chat.get("platform"),
|
||||||
|
chat_info_user_id=user_info_from_chat.get("user_id"),
|
||||||
|
chat_info_user_nickname=user_info_from_chat.get("user_nickname"),
|
||||||
|
chat_info_user_cardname=user_info_from_chat.get("user_cardname"),
|
||||||
|
chat_info_group_platform=group_info_from_chat.get("platform"),
|
||||||
|
chat_info_group_id=group_info_from_chat.get("group_id"),
|
||||||
|
chat_info_group_name=group_info_from_chat.get("group_name"),
|
||||||
|
chat_info_create_time=float(chat_info_dict.get("create_time", 0.0)),
|
||||||
|
chat_info_last_active_time=float(chat_info_dict.get("last_active_time", 0.0)),
|
||||||
|
# Flattened user_info (message sender)
|
||||||
|
user_platform=user_info_dict.get("platform"),
|
||||||
|
user_id=user_info_dict.get("user_id"),
|
||||||
|
user_nickname=user_info_dict.get("user_nickname"),
|
||||||
|
user_cardname=user_info_dict.get("user_cardname"),
|
||||||
|
# Text content
|
||||||
|
processed_plain_text=filtered_processed_plain_text,
|
||||||
|
detailed_plain_text=filtered_detailed_plain_text,
|
||||||
|
memorized_times=message.memorized_times,
|
||||||
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("存储消息失败")
|
logger.exception("存储消息失败")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def store_recalled_message(message_id: str, time: str, chat_stream: ChatStream) -> None:
|
async def store_recalled_message(message_id: str, time: str, chat_stream: ChatStream) -> None:
|
||||||
"""存储撤回消息到数据库"""
|
"""存储撤回消息到数据库"""
|
||||||
if "recalled_messages" not in db.list_collection_names():
|
# Table creation is handled by initialize_database in database_model.py
|
||||||
db.create_collection("recalled_messages")
|
|
||||||
else:
|
|
||||||
try:
|
try:
|
||||||
message_data = {
|
RecalledMessages.create(
|
||||||
"message_id": message_id,
|
message_id=message_id,
|
||||||
"time": time,
|
time=float(time), # Assuming time is a string representing a float timestamp
|
||||||
"stream_id": chat_stream.stream_id,
|
stream_id=chat_stream.stream_id,
|
||||||
}
|
)
|
||||||
db.recalled_messages.insert_one(message_data)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("存储撤回消息失败")
|
logger.exception("存储撤回消息失败")
|
||||||
|
|
||||||
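Editor's note: the rewritten store_message above flattens the nested chat_info/user_info dictionaries into individual columns on the Messages model (defined in database_model.py, which this diff does not show). Purely as an illustration of that flattening, a guessed-at minimal model with only a few of the columns used above; the real table has more fields.

    # Hypothetical, trimmed-down version of the flattened Messages table.
    from peewee import SqliteDatabase, Model, TextField, FloatField, IntegerField

    db = SqliteDatabase(":memory:")

    class Messages(Model):
        message_id = TextField()
        time = FloatField()
        chat_id = TextField()
        chat_info_group_id = TextField(null=True)   # flattened from chat_info["group_info"]
        user_id = TextField(null=True)              # flattened from the sender's user_info
        user_nickname = TextField(null=True)
        processed_plain_text = TextField(default="")
        memorized_times = IntegerField(default=0)

        class Meta:
            database = db

    db.connect()
    db.create_tables([Messages], safe=True)

    chat_info = {
        "group_info": {"group_id": "123456"},
        "user_info": {"user_id": "42", "user_nickname": "demo"},
    }
    group_info = chat_info.get("group_info") or {}
    user_info = chat_info.get("user_info") or {}

    Messages.create(
        message_id="m1",
        time=1700000000.0,
        chat_id="abc123",
        chat_info_group_id=group_info.get("group_id"),
        user_id=user_info.get("user_id"),
        user_nickname=user_info.get("user_nickname"),
        processed_plain_text="hello",
        memorized_times=0,
    )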
@@ -64,7 +87,9 @@ class MessageStorage:
|
|||||||
async def remove_recalled_message(time: str) -> None:
|
async def remove_recalled_message(time: str) -> None:
|
||||||
"""删除撤回消息"""
|
"""删除撤回消息"""
|
||||||
try:
|
try:
|
||||||
db.recalled_messages.delete_many({"time": {"$lt": time - 300}})
|
# Assuming input 'time' is a string timestamp that can be converted to float
|
||||||
|
current_time_float = float(time)
|
||||||
|
RecalledMessages.delete().where(RecalledMessages.time < (current_time_float - 300)).execute()
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("删除撤回消息失败")
|
logger.exception("删除撤回消息失败")
|
||||||
|
|
||||||
|
|||||||
@@ -12,7 +12,8 @@ import base64
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
import io
|
import io
|
||||||
import os
|
import os
|
||||||
from ...common.database import db
|
from src.common.database.database import db # 确保 db 被导入用于 create_tables
|
||||||
|
from src.common.database.database_model import LLMUsage # 导入 LLMUsage 模型
|
||||||
from ...config.config import global_config
|
from ...config.config import global_config
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
|
||||||
@@ -85,8 +86,6 @@ async def _safely_record(request_content: Dict[str, Any], payload: Dict[str, Any
|
|||||||
f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,"
|
f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,"
|
||||||
f"{image_base64[:10]}...{image_base64[-10:]}"
|
f"{image_base64[:10]}...{image_base64[-10:]}"
|
||||||
)
|
)
|
||||||
# if isinstance(content, str) and len(content) > 100:
|
|
||||||
# payload["messages"][0]["content"] = content[:100]
|
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
|
|
||||||
@@ -111,8 +110,8 @@ class LLMRequest:
|
|||||||
def __init__(self, model: dict, **kwargs):
|
def __init__(self, model: dict, **kwargs):
|
||||||
# 将大写的配置键转换为小写并从config中获取实际值
|
# 将大写的配置键转换为小写并从config中获取实际值
|
||||||
try:
|
try:
|
||||||
self.api_key = os.environ[model["key"]]
|
self.api_key = os.environ[f"{model['provider']}_KEY"]
|
||||||
self.base_url = os.environ[model["base_url"]]
|
self.base_url = os.environ[f"{model['provider']}_BASE_URL"]
|
||||||
except AttributeError as e:
|
except AttributeError as e:
|
||||||
logger.error(f"原始 model dict 信息:{model}")
|
logger.error(f"原始 model dict 信息:{model}")
|
||||||
logger.error(f"配置错误:找不到对应的配置项 - {str(e)}")
|
logger.error(f"配置错误:找不到对应的配置项 - {str(e)}")
|
||||||
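Editor's note: the hunk above changes how credentials are resolved. Instead of the model dict naming the exact environment variables, it now carries a provider field, and the key and base URL are looked up as <PROVIDER>_KEY and <PROVIDER>_BASE_URL. A small sketch of that convention; the provider name and variable values are examples only, not a statement about which providers the project supports. The sketch catches KeyError, which is what os.environ raises when a variable is absent.

    # Sketch of provider-based credential lookup with a clear error message.
    import os

    def resolve_credentials(model: dict) -> tuple[str, str]:
        provider = model["provider"]            # e.g. "SILICONFLOW" (example value)
        key_var = f"{provider}_KEY"
        url_var = f"{provider}_BASE_URL"
        try:
            return os.environ[key_var], os.environ[url_var]
        except KeyError as e:
            raise RuntimeError(f"Missing environment variable {e.args[0]}; check the .env file") from e

    # Example usage (set the variables first, e.g. via a .env file):
    # os.environ["SILICONFLOW_KEY"] = "sk-..."
    # os.environ["SILICONFLOW_BASE_URL"] = "https://api.siliconflow.cn/v1"
    # api_key, base_url = resolve_credentials({"provider": "SILICONFLOW", "name": "some-model"})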
@@ -134,13 +133,11 @@ class LLMRequest:
|
|||||||
def _init_database():
|
def _init_database():
|
||||||
"""初始化数据库集合"""
|
"""初始化数据库集合"""
|
||||||
try:
|
try:
|
||||||
# 创建llm_usage集合的索引
|
# 使用 Peewee 创建表,safe=True 表示如果表已存在则不会抛出错误
|
||||||
db.llm_usage.create_index([("timestamp", 1)])
|
db.create_tables([LLMUsage], safe=True)
|
||||||
db.llm_usage.create_index([("model_name", 1)])
|
logger.debug("LLMUsage 表已初始化/确保存在。")
|
||||||
db.llm_usage.create_index([("user_id", 1)])
|
|
||||||
db.llm_usage.create_index([("request_type", 1)])
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"创建数据库索引失败: {str(e)}")
|
logger.error(f"创建 LLMUsage 表失败: {str(e)}")
|
||||||
|
|
||||||
def _record_usage(
|
def _record_usage(
|
||||||
self,
|
self,
|
||||||
@@ -165,19 +162,19 @@ class LLMRequest:
|
|||||||
request_type = self.request_type
|
request_type = self.request_type
|
||||||
|
|
||||||
try:
|
try:
|
||||||
usage_data = {
|
# 使用 Peewee 模型创建记录
|
||||||
"model_name": self.model_name,
|
LLMUsage.create(
|
||||||
"user_id": user_id,
|
model_name=self.model_name,
|
||||||
"request_type": request_type,
|
user_id=user_id,
|
||||||
"endpoint": endpoint,
|
request_type=request_type,
|
||||||
"prompt_tokens": prompt_tokens,
|
endpoint=endpoint,
|
||||||
"completion_tokens": completion_tokens,
|
prompt_tokens=prompt_tokens,
|
||||||
"total_tokens": total_tokens,
|
completion_tokens=completion_tokens,
|
||||||
"cost": self._calculate_cost(prompt_tokens, completion_tokens),
|
total_tokens=total_tokens,
|
||||||
"status": "success",
|
cost=self._calculate_cost(prompt_tokens, completion_tokens),
|
||||||
"timestamp": datetime.now(),
|
status="success",
|
||||||
}
|
timestamp=datetime.now(), # Peewee 会处理 DateTimeField
|
||||||
db.llm_usage.insert_one(usage_data)
|
)
|
||||||
logger.trace(
|
logger.trace(
|
||||||
f"Token使用情况 - 模型: {self.model_name}, "
|
f"Token使用情况 - 模型: {self.model_name}, "
|
||||||
f"用户: {user_id}, 类型: {request_type}, "
|
f"用户: {user_id}, 类型: {request_type}, "
|
||||||
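Editor's note: the old Mongo code created indexes on timestamp, model_name, user_id and request_type imperatively; with Peewee those are normally declared on the model itself. The LLMUsage model lives in database_model.py and is not part of this diff, so the following is only a guess at what an indexed version could look like, using the field names passed to LLMUsage.create above.

    # Hypothetical LLMUsage model with the indexes the old code created by hand.
    from datetime import datetime
    from peewee import SqliteDatabase, Model, CharField, IntegerField, FloatField, DateTimeField

    db = SqliteDatabase(":memory:")

    class LLMUsage(Model):
        model_name = CharField(index=True)
        user_id = CharField(index=True)
        request_type = CharField(index=True)
        endpoint = CharField()
        prompt_tokens = IntegerField(default=0)
        completion_tokens = IntegerField(default=0)
        total_tokens = IntegerField(default=0)
        cost = FloatField(default=0.0)
        status = CharField(default="success")
        timestamp = DateTimeField(default=datetime.now, index=True)

        class Meta:
            database = db

    db.connect()
    db.create_tables([LLMUsage], safe=True)
    LLMUsage.create(model_name="demo-model", user_id="system", request_type="chat",
                    endpoint="/chat/completions", prompt_tokens=10, completion_tokens=5,
                    total_tokens=15, cost=0.0)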
@@ -500,11 +497,11 @@ class LLMRequest:
|
|||||||
logger.warning(f"检测到403错误,模型从 {old_model_name} 降级为 {self.model_name}")
|
logger.warning(f"检测到403错误,模型从 {old_model_name} 降级为 {self.model_name}")
|
||||||
|
|
||||||
# 对全局配置进行更新
|
# 对全局配置进行更新
|
||||||
if global_config.llm_normal.get("name") == old_model_name:
|
if global_config.model.normal.get("name") == old_model_name:
|
||||||
global_config.llm_normal["name"] = self.model_name
|
global_config.model.normal["name"] = self.model_name
|
||||||
logger.warning(f"将全局配置中的 llm_normal 模型临时降级至{self.model_name}")
|
logger.warning(f"将全局配置中的 llm_normal 模型临时降级至{self.model_name}")
|
||||||
if global_config.llm_reasoning.get("name") == old_model_name:
|
if global_config.model.reasoning.get("name") == old_model_name:
|
||||||
global_config.llm_reasoning["name"] = self.model_name
|
global_config.model.reasoning["name"] = self.model_name
|
||||||
logger.warning(f"将全局配置中的 llm_reasoning 模型临时降级至{self.model_name}")
|
logger.warning(f"将全局配置中的 llm_reasoning 模型临时降级至{self.model_name}")
|
||||||
|
|
||||||
if payload and "model" in payload:
|
if payload and "model" in payload:
|
||||||
@@ -636,7 +633,7 @@ class LLMRequest:
|
|||||||
**params_copy,
|
**params_copy,
|
||||||
}
|
}
|
||||||
if "max_tokens" not in payload and "max_completion_tokens" not in payload:
|
if "max_tokens" not in payload and "max_completion_tokens" not in payload:
|
||||||
payload["max_tokens"] = global_config.model_max_output_length
|
payload["max_tokens"] = global_config.model.model_max_output_length
|
||||||
# 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查
|
# 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查
|
||||||
if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload:
|
if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload:
|
||||||
payload["max_completion_tokens"] = payload.pop("max_tokens")
|
payload["max_completion_tokens"] = payload.pop("max_tokens")
|
||||||
|
|||||||
@@ -73,8 +73,8 @@ class NormalChat:
|
|||||||
messageinfo = message.message_info
|
messageinfo = message.message_info
|
||||||
|
|
||||||
bot_user_info = UserInfo(
|
bot_user_info = UserInfo(
|
||||||
user_id=global_config.BOT_QQ,
|
user_id=global_config.bot.qq_account,
|
||||||
user_nickname=global_config.BOT_NICKNAME,
|
user_nickname=global_config.bot.nickname,
|
||||||
platform=messageinfo.platform,
|
platform=messageinfo.platform,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -121,8 +121,8 @@ class NormalChat:
|
|||||||
message_id=thinking_id,
|
message_id=thinking_id,
|
||||||
chat_stream=self.chat_stream, # 使用 self.chat_stream
|
chat_stream=self.chat_stream, # 使用 self.chat_stream
|
||||||
bot_user_info=UserInfo(
|
bot_user_info=UserInfo(
|
||||||
user_id=global_config.BOT_QQ,
|
user_id=global_config.bot.qq_account,
|
||||||
user_nickname=global_config.BOT_NICKNAME,
|
user_nickname=global_config.bot.nickname,
|
||||||
platform=message.message_info.platform,
|
platform=message.message_info.platform,
|
||||||
),
|
),
|
||||||
sender_info=message.message_info.user_info,
|
sender_info=message.message_info.user_info,
|
||||||
@@ -147,7 +147,7 @@ class NormalChat:
|
|||||||
# 改为实例方法
|
# 改为实例方法
|
||||||
async def _handle_emoji(self, message: MessageRecv, response: str):
|
async def _handle_emoji(self, message: MessageRecv, response: str):
|
||||||
"""处理表情包"""
|
"""处理表情包"""
|
||||||
if random() < global_config.emoji_chance:
|
if random() < global_config.normal_chat.emoji_chance:
|
||||||
emoji_raw = await emoji_manager.get_emoji_for_text(response)
|
emoji_raw = await emoji_manager.get_emoji_for_text(response)
|
||||||
if emoji_raw:
|
if emoji_raw:
|
||||||
emoji_path, description = emoji_raw
|
emoji_path, description = emoji_raw
|
||||||
@@ -160,8 +160,8 @@ class NormalChat:
|
|||||||
message_id="mt" + str(thinking_time_point),
|
message_id="mt" + str(thinking_time_point),
|
||||||
chat_stream=self.chat_stream, # 使用 self.chat_stream
|
chat_stream=self.chat_stream, # 使用 self.chat_stream
|
||||||
bot_user_info=UserInfo(
|
bot_user_info=UserInfo(
|
||||||
user_id=global_config.BOT_QQ,
|
user_id=global_config.bot.qq_account,
|
||||||
user_nickname=global_config.BOT_NICKNAME,
|
user_nickname=global_config.bot.nickname,
|
||||||
platform=message.message_info.platform,
|
platform=message.message_info.platform,
|
||||||
),
|
),
|
||||||
sender_info=message.message_info.user_info,
|
sender_info=message.message_info.user_info,
|
||||||
@@ -186,7 +186,7 @@ class NormalChat:
|
|||||||
label=emotion,
|
label=emotion,
|
||||||
stance=stance, # 使用 self.chat_stream
|
stance=stance, # 使用 self.chat_stream
|
||||||
)
|
)
|
||||||
self.mood_manager.update_mood_from_emotion(emotion, global_config.mood_intensity_factor)
|
self.mood_manager.update_mood_from_emotion(emotion, global_config.mood.mood_intensity_factor)
|
||||||
|
|
||||||
async def _reply_interested_message(self) -> None:
|
async def _reply_interested_message(self) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -430,7 +430,7 @@ class NormalChat:
|
|||||||
def _check_ban_words(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
|
def _check_ban_words(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
|
||||||
"""检查消息中是否包含过滤词"""
|
"""检查消息中是否包含过滤词"""
|
||||||
stream_name = chat_manager.get_stream_name(chat.stream_id) or chat.stream_id
|
stream_name = chat_manager.get_stream_name(chat.stream_id) or chat.stream_id
|
||||||
for word in global_config.ban_words:
|
for word in global_config.chat.ban_words:
|
||||||
if word in text:
|
if word in text:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[{stream_name}][{chat.group_info.group_name if chat.group_info else '私聊'}]"
|
f"[{stream_name}][{chat.group_info.group_name if chat.group_info else '私聊'}]"
|
||||||
@@ -445,7 +445,7 @@ class NormalChat:
|
|||||||
def _check_ban_regex(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
|
def _check_ban_regex(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
|
||||||
"""检查消息是否匹配过滤正则表达式"""
|
"""检查消息是否匹配过滤正则表达式"""
|
||||||
stream_name = chat_manager.get_stream_name(chat.stream_id) or chat.stream_id
|
stream_name = chat_manager.get_stream_name(chat.stream_id) or chat.stream_id
|
||||||
for pattern in global_config.ban_msgs_regex:
|
for pattern in global_config.chat.ban_msgs_regex:
|
||||||
if pattern.search(text):
|
if pattern.search(text):
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[{stream_name}][{chat.group_info.group_name if chat.group_info else '私聊'}]"
|
f"[{stream_name}][{chat.group_info.group_name if chat.group_info else '私聊'}]"
|
||||||
|
|||||||
@@ -15,21 +15,22 @@ logger = get_logger("llm")
|
|||||||
|
|
||||||
class NormalChatGenerator:
|
class NormalChatGenerator:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
# TODO: API-Adapter修改标记
|
||||||
self.model_reasoning = LLMRequest(
|
self.model_reasoning = LLMRequest(
|
||||||
model=global_config.llm_reasoning,
|
model=global_config.model.reasoning,
|
||||||
temperature=0.7,
|
temperature=0.7,
|
||||||
max_tokens=3000,
|
max_tokens=3000,
|
||||||
request_type="response_reasoning",
|
request_type="response_reasoning",
|
||||||
)
|
)
|
||||||
self.model_normal = LLMRequest(
|
self.model_normal = LLMRequest(
|
||||||
model=global_config.llm_normal,
|
model=global_config.model.normal,
|
||||||
temperature=global_config.llm_normal["temp"],
|
temperature=global_config.model.normal["temp"],
|
||||||
max_tokens=256,
|
max_tokens=256,
|
||||||
request_type="response_reasoning",
|
request_type="response_reasoning",
|
||||||
)
|
)
|
||||||
|
|
||||||
self.model_sum = LLMRequest(
|
self.model_sum = LLMRequest(
|
||||||
model=global_config.llm_summary, temperature=0.7, max_tokens=3000, request_type="relation"
|
model=global_config.model.summary, temperature=0.7, max_tokens=3000, request_type="relation"
|
||||||
)
|
)
|
||||||
self.current_model_type = "r1" # 默认使用 R1
|
self.current_model_type = "r1" # 默认使用 R1
|
||||||
self.current_model_name = "unknown model"
|
self.current_model_name = "unknown model"
|
||||||
@@ -37,7 +38,7 @@ class NormalChatGenerator:
|
|||||||
async def generate_response(self, message: MessageThinking, thinking_id: str) -> Optional[Union[str, List[str]]]:
|
async def generate_response(self, message: MessageThinking, thinking_id: str) -> Optional[Union[str, List[str]]]:
|
||||||
"""根据当前模型类型选择对应的生成函数"""
|
"""根据当前模型类型选择对应的生成函数"""
|
||||||
# 从global_config中获取模型概率值并选择模型
|
# 从global_config中获取模型概率值并选择模型
|
||||||
if random.random() < global_config.model_reasoning_probability:
|
if random.random() < global_config.normal_chat.reasoning_model_probability:
|
||||||
self.current_model_type = "深深地"
|
self.current_model_type = "深深地"
|
||||||
current_model = self.model_reasoning
|
current_model = self.model_reasoning
|
||||||
else:
|
else:
|
||||||
@@ -51,7 +52,7 @@ class NormalChatGenerator:
|
|||||||
model_response = await self._generate_response_with_model(message, current_model, thinking_id)
|
model_response = await self._generate_response_with_model(message, current_model, thinking_id)
|
||||||
|
|
||||||
if model_response:
|
if model_response:
|
||||||
logger.info(f"{global_config.BOT_NICKNAME}的回复是:{model_response}")
|
logger.info(f"{global_config.bot.nickname}的回复是:{model_response}")
|
||||||
model_response = await self._process_response(model_response)
|
model_response = await self._process_response(model_response)
|
||||||
|
|
||||||
return model_response
|
return model_response
|
||||||
@@ -113,7 +114,7 @@ class NormalChatGenerator:
|
|||||||
- "中立":不表达明确立场或无关回应
|
- "中立":不表达明确立场或无关回应
|
||||||
2. 从"开心,愤怒,悲伤,惊讶,平静,害羞,恐惧,厌恶,困惑"中选出最匹配的1个情感标签
|
2. 从"开心,愤怒,悲伤,惊讶,平静,害羞,恐惧,厌恶,困惑"中选出最匹配的1个情感标签
|
||||||
3. 按照"立场-情绪"的格式直接输出结果,例如:"反对-愤怒"
|
3. 按照"立场-情绪"的格式直接输出结果,例如:"反对-愤怒"
|
||||||
4. 考虑回复者的人格设定为{global_config.personality_core}
|
4. 考虑回复者的人格设定为{global_config.personality.personality_core}
|
||||||
|
|
||||||
对话示例:
|
对话示例:
|
||||||
被回复:「A就是笨」
|
被回复:「A就是笨」
|
||||||
|
|||||||
@@ -1,18 +1,20 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
|
from src.config.config import global_config
|
||||||
from .willing_manager import BaseWillingManager
|
from .willing_manager import BaseWillingManager
|
||||||
|
|
||||||
|
|
||||||
class ClassicalWillingManager(BaseWillingManager):
|
class ClassicalWillingManager(BaseWillingManager):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._decay_task: asyncio.Task = None
|
self._decay_task: asyncio.Task | None = None
|
||||||
|
|
||||||
async def _decay_reply_willing(self):
|
async def _decay_reply_willing(self):
|
||||||
"""定期衰减回复意愿"""
|
"""定期衰减回复意愿"""
|
||||||
while True:
|
while True:
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
for chat_id in self.chat_reply_willing:
|
for chat_id in self.chat_reply_willing:
|
||||||
self.chat_reply_willing[chat_id] = max(0, self.chat_reply_willing[chat_id] * 0.9)
|
self.chat_reply_willing[chat_id] = max(0.0, self.chat_reply_willing[chat_id] * 0.9)
|
||||||
|
|
||||||
async def async_task_starter(self):
|
async def async_task_starter(self):
|
||||||
if self._decay_task is None:
|
if self._decay_task is None:
|
||||||
@@ -23,35 +25,33 @@ class ClassicalWillingManager(BaseWillingManager):
|
|||||||
chat_id = willing_info.chat_id
|
chat_id = willing_info.chat_id
|
||||||
current_willing = self.chat_reply_willing.get(chat_id, 0)
|
current_willing = self.chat_reply_willing.get(chat_id, 0)
|
||||||
|
|
||||||
interested_rate = willing_info.interested_rate * self.global_config.response_interested_rate_amplifier
|
interested_rate = willing_info.interested_rate * global_config.normal_chat.response_interested_rate_amplifier
|
||||||
|
|
||||||
if interested_rate > 0.4:
|
if interested_rate > 0.4:
|
||||||
current_willing += interested_rate - 0.3
|
current_willing += interested_rate - 0.3
|
||||||
|
|
||||||
if willing_info.is_mentioned_bot and current_willing < 1.0:
|
if willing_info.is_mentioned_bot:
|
||||||
current_willing += 1
|
current_willing += 1 if current_willing < 1.0 else 0.05
|
||||||
elif willing_info.is_mentioned_bot:
|
|
||||||
current_willing += 0.05
|
|
||||||
|
|
||||||
is_emoji_not_reply = False
|
is_emoji_not_reply = False
|
||||||
if willing_info.is_emoji:
|
if willing_info.is_emoji:
|
||||||
if self.global_config.emoji_response_penalty != 0:
|
if global_config.normal_chat.emoji_response_penalty != 0:
|
||||||
current_willing *= self.global_config.emoji_response_penalty
|
current_willing *= global_config.normal_chat.emoji_response_penalty
|
||||||
else:
|
else:
|
||||||
is_emoji_not_reply = True
|
is_emoji_not_reply = True
|
||||||
|
|
||||||
self.chat_reply_willing[chat_id] = min(current_willing, 3.0)
|
self.chat_reply_willing[chat_id] = min(current_willing, 3.0)
|
||||||
|
|
||||||
reply_probability = min(
|
reply_probability = min(
|
||||||
max((current_willing - 0.5), 0.01) * self.global_config.response_willing_amplifier * 2, 1
|
max((current_willing - 0.5), 0.01) * global_config.normal_chat.response_willing_amplifier * 2, 1
|
||||||
)
|
)
|
||||||
|
|
||||||
# 检查群组权限(如果是群聊)
|
# 检查群组权限(如果是群聊)
|
||||||
if (
|
if (
|
||||||
willing_info.group_info
|
willing_info.group_info
|
||||||
and willing_info.group_info.group_id in self.global_config.talk_frequency_down_groups
|
and willing_info.group_info.group_id in global_config.chat_target.talk_frequency_down_groups
|
||||||
):
|
):
|
||||||
reply_probability = reply_probability / self.global_config.down_frequency_rate
|
reply_probability = reply_probability / global_config.normal_chat.down_frequency_rate
|
||||||
|
|
||||||
if is_emoji_not_reply:
|
if is_emoji_not_reply:
|
||||||
reply_probability = 0
|
reply_probability = 0
|
||||||
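Editor's note: the reply_probability expression above maps a willingness score to a probability: subtract 0.5, floor at 0.01, multiply by response_willing_amplifier and by 2, then cap at 1; for down-weighted groups the result is further divided by down_frequency_rate. A few worked values under the assumption amplifier = 1.0 and down_frequency_rate = 2.0 (the real values come from config):

    # Worked examples of the willingness-to-probability mapping shown above.
    def reply_probability(willing: float, amplifier: float = 1.0,
                          in_down_group: bool = False, down_rate: float = 2.0) -> float:
        p = min(max(willing - 0.5, 0.01) * amplifier * 2, 1)
        if in_down_group:
            p /= down_rate
        return p

    print(reply_probability(0.4))                      # 0.02 (hits the 0.01 floor, times 2)
    print(reply_probability(0.9))                      # 0.8
    print(reply_probability(2.0))                      # 1.0  (capped)
    print(reply_probability(0.9, in_down_group=True))  # 0.4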
@@ -61,7 +61,7 @@ class ClassicalWillingManager(BaseWillingManager):
|
|||||||
async def before_generate_reply_handle(self, message_id):
|
async def before_generate_reply_handle(self, message_id):
|
||||||
chat_id = self.ongoing_messages[message_id].chat_id
|
chat_id = self.ongoing_messages[message_id].chat_id
|
||||||
current_willing = self.chat_reply_willing.get(chat_id, 0)
|
current_willing = self.chat_reply_willing.get(chat_id, 0)
|
||||||
self.chat_reply_willing[chat_id] = max(0, current_willing - 1.8)
|
self.chat_reply_willing[chat_id] = max(0.0, current_willing - 1.8)
|
||||||
|
|
||||||
async def after_generate_reply_handle(self, message_id):
|
async def after_generate_reply_handle(self, message_id):
|
||||||
if message_id not in self.ongoing_messages:
|
if message_id not in self.ongoing_messages:
|
||||||
@@ -70,7 +70,7 @@ class ClassicalWillingManager(BaseWillingManager):
|
|||||||
chat_id = self.ongoing_messages[message_id].chat_id
|
chat_id = self.ongoing_messages[message_id].chat_id
|
||||||
current_willing = self.chat_reply_willing.get(chat_id, 0)
|
current_willing = self.chat_reply_willing.get(chat_id, 0)
|
||||||
if current_willing < 1:
|
if current_willing < 1:
|
||||||
self.chat_reply_willing[chat_id] = min(1, current_willing + 0.4)
|
self.chat_reply_willing[chat_id] = min(1.0, current_willing + 0.4)
|
||||||
|
|
||||||
async def bombing_buffer_message_handle(self, message_id):
|
async def bombing_buffer_message_handle(self, message_id):
|
||||||
return await super().bombing_buffer_message_handle(message_id)
|
return await super().bombing_buffer_message_handle(message_id)
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ Mxp 模式:梦溪畔独家赞助
|
|||||||
下下策是询问一个菜鸟(@梦溪畔)
|
下下策是询问一个菜鸟(@梦溪畔)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from src.config.config import global_config
|
||||||
from .willing_manager import BaseWillingManager
|
from .willing_manager import BaseWillingManager
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
import asyncio
|
import asyncio
|
||||||
@@ -50,8 +51,6 @@ class MxpWillingManager(BaseWillingManager):
|
|||||||
|
|
||||||
self.mention_willing_gain = 0.6 # 提及意愿增益
|
self.mention_willing_gain = 0.6 # 提及意愿增益
|
||||||
self.interest_willing_gain = 0.3 # 兴趣意愿增益
|
self.interest_willing_gain = 0.3 # 兴趣意愿增益
|
||||||
self.emoji_response_penalty = self.global_config.emoji_response_penalty # 表情包回复惩罚
|
|
||||||
self.down_frequency_rate = self.global_config.down_frequency_rate # 降低回复频率的群组惩罚系数
|
|
||||||
self.single_chat_gain = 0.12 # 单聊增益
|
self.single_chat_gain = 0.12 # 单聊增益
|
||||||
|
|
||||||
self.fatigue_messages_triggered_num = self.expected_replies_per_min # 疲劳消息触发数量(int)
|
self.fatigue_messages_triggered_num = self.expected_replies_per_min # 疲劳消息触发数量(int)
|
||||||
@@ -179,10 +178,10 @@ class MxpWillingManager(BaseWillingManager):
|
|||||||
probability = self._willing_to_probability(current_willing)
|
probability = self._willing_to_probability(current_willing)
|
||||||
|
|
||||||
if w_info.is_emoji:
|
if w_info.is_emoji:
|
||||||
probability *= self.emoji_response_penalty
|
probability *= global_config.normal_chat.emoji_response_penalty
|
||||||
|
|
||||||
if w_info.group_info and w_info.group_info.group_id in self.global_config.talk_frequency_down_groups:
|
if w_info.group_info and w_info.group_info.group_id in global_config.chat_target.talk_frequency_down_groups:
|
||||||
probability /= self.down_frequency_rate
|
probability /= global_config.normal_chat.down_frequency_rate
|
||||||
|
|
||||||
self.temporary_willing = current_willing
|
self.temporary_willing = current_willing
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from src.common.logger import LogConfig, WILLING_STYLE_CONFIG, LoguruLogger, get_module_logger
|
from src.common.logger import LogConfig, WILLING_STYLE_CONFIG, LoguruLogger, get_module_logger
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from src.config.config import global_config, BotConfig
|
from src.config.config import global_config
|
||||||
from src.chat.message_receive.chat_stream import ChatStream, GroupInfo
|
from src.chat.message_receive.chat_stream import ChatStream, GroupInfo
|
||||||
from src.chat.message_receive.message import MessageRecv
|
from src.chat.message_receive.message import MessageRecv
|
||||||
from src.chat.person_info.person_info import person_info_manager, PersonInfoManager
|
from src.chat.person_info.person_info import person_info_manager, PersonInfoManager
|
||||||
@@ -93,7 +93,6 @@ class BaseWillingManager(ABC):
self.chat_reply_willing: Dict[str, float] = {} # 存储每个聊天流的回复意愿(chat_id)
self.ongoing_messages: Dict[str, WillingInfo] = {} # 当前正在进行的消息(message_id)
self.lock = asyncio.Lock()
-self.global_config: BotConfig = global_config
self.logger: LoguruLogger = logger

def setup(self, message: MessageRecv, chat: ChatStream, is_mentioned_bot: bool, interested_rate: float):
@@ -173,7 +172,7 @@ def init_willing_manager() -> BaseWillingManager:
Returns:
对应mode的WillingManager实例
"""
-mode = global_config.willing_mode.lower()
+mode = global_config.normal_chat.willing_mode.lower()
return BaseWillingManager.create(mode)


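The factory above now selects the manager class from global_config.normal_chat.willing_mode. A hedged usage sketch (the import path is assumed, not shown in this diff):

```python
# Hypothetical call site: build the willing manager once at startup and reuse it.
from src.chat.willing.willing_manager import init_willing_manager  # assumed module path

willing_manager = init_willing_manager()  # e.g. an MxpWillingManager when willing_mode == "mxp"
```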
@@ -1,5 +1,6 @@
from src.common.logger_manager import get_logger
-from ...common.database import db
+from ...common.database.database import db
+from ...common.database.database_model import PersonInfo # 新增导入
import copy
import hashlib
from typing import Any, Callable, Dict
@@ -16,7 +17,7 @@ matplotlib.use("Agg")
import matplotlib.pyplot as plt
from pathlib import Path
import pandas as pd
-import json
+import json # 新增导入
import re


@@ -38,47 +39,49 @@ logger = get_logger("person_info")

person_info_default = {
"person_id": None,
-"person_name": None,
+"person_name": None, # 模型中已设为 null=True,此默认值OK
"name_reason": None,
-"platform": None,
+"platform": "unknown", # 提供非None的默认值
-"user_id": None,
+"user_id": "unknown", # 提供非None的默认值
-"nickname": None,
+"nickname": "Unknown", # 提供非None的默认值
-# "age" : 0,
"relationship_value": 0,
-# "saved" : True,
+"know_time": 0, # 修正拼写:konw_time -> know_time
-# "impression" : None,
-# "gender" : Unkown,
-"konw_time": 0,
"msg_interval": 2000,
-"msg_interval_list": [],
+"msg_interval_list": [], # 将作为 JSON 字符串存储在 Peewee 的 TextField
-"user_cardname": None, # 添加群名片
+"user_cardname": None, # 注意:此字段不在 PersonInfo Peewee 模型中
-"user_avatar": None, # 添加头像信息(例如URL或标识符)
+"user_avatar": None, # 注意:此字段不在 PersonInfo Peewee 模型中
-} # 个人信息的各项与默认值在此定义,以下处理会自动创建/补全每一项
+}


class PersonInfoManager:
def __init__(self):
self.person_name_list = {}
+# TODO: API-Adapter修改标记
self.qv_name_llm = LLMRequest(
-model=global_config.llm_normal,
+model=global_config.model.normal,
max_tokens=256,
request_type="qv_name",
)
-if "person_info" not in db.list_collection_names():
+try:
-db.create_collection("person_info")
+db.connect(reuse_if_open=True)
-db.person_info.create_index("person_id", unique=True)
+db.create_tables([PersonInfo], safe=True)
+except Exception as e:
+logger.error(f"数据库连接或 PersonInfo 表创建失败: {e}")

# 初始化时读取所有person_name
-cursor = db.person_info.find({"person_name": {"$exists": True}}, {"person_id": 1, "person_name": 1, "_id": 0})
+try:
-for doc in cursor:
+for record in PersonInfo.select(PersonInfo.person_id, PersonInfo.person_name).where(
-if doc.get("person_name"):
+PersonInfo.person_name.is_null(False)
-self.person_name_list[doc["person_id"]] = doc["person_name"]
+):
-logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称")
+if record.person_name:
+self.person_name_list[record.person_id] = record.person_name
+logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称 (Peewee)")
+except Exception as e:
+logger.error(f"从 Peewee 加载 person_name_list 失败: {e}")

@staticmethod
def get_person_id(platform: str, user_id: int):
"""获取唯一id"""
-# 如果platform中存在-,就截取-后面的部分
if "-" in platform:
platform = platform.split("-")[1]

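The rewritten __init__ connects the Peewee database and creates the PersonInfo table instead of managing a MongoDB collection. The actual model is defined in src/common/database/database_model.py, which this diff only imports; the sketch below is an inferred shape based on the fields and defaults used in this file, so the concrete field types are assumptions:

```python
# Hypothetical reconstruction of the PersonInfo Peewee model implied by person_info.py.
from peewee import Model, TextField, IntegerField, FloatField

from src.common.database.database import db  # Peewee database instance (import path from the diff)


class PersonInfo(Model):
    person_id = TextField(unique=True)           # md5 of "<platform>_<user_id>"
    person_name = TextField(null=True)           # the diff notes the model allows null here
    name_reason = TextField(null=True)
    platform = TextField(default="unknown")
    user_id = TextField(default="unknown")
    nickname = TextField(default="Unknown")
    relationship_value = FloatField(default=0)
    know_time = IntegerField(default=0)          # spelling fixed from "konw_time"
    msg_interval = IntegerField(default=2000)
    msg_interval_list = TextField(default="[]")  # list stored as a JSON string

    class Meta:
        database = db
```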
@@ -86,13 +89,17 @@ class PersonInfoManager:
key = "_".join(components)
return hashlib.md5(key.encode()).hexdigest()

-def is_person_known(self, platform: str, user_id: int):
+async def is_person_known(self, platform: str, user_id: int):
"""判断是否认识某人"""
person_id = self.get_person_id(platform, user_id)
-document = db.person_info.find_one({"person_id": person_id})
-if document:
+def _db_check_known_sync(p_id: str):
-return True
+return PersonInfo.get_or_none(PersonInfo.person_id == p_id) is not None
-else:
+try:
+return await asyncio.to_thread(_db_check_known_sync, person_id)
+except Exception as e:
+logger.error(f"检查用户 {person_id} 是否已知时出错 (Peewee): {e}")
return False

def get_person_id_by_person_name(self, person_name: str):
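Peewee queries block, so the async methods in this hunk push each synchronous lookup onto a worker thread with asyncio.to_thread (Python 3.9+). The same pattern reduced to its core, with illustrative names:

```python
import asyncio

from src.common.database.database_model import PersonInfo  # import path as in the diff


async def person_exists(person_id: str) -> bool:
    """Run a blocking Peewee lookup off the event loop and fall back to False on error."""

    def _check_sync(p_id: str) -> bool:
        return PersonInfo.get_or_none(PersonInfo.person_id == p_id) is not None

    try:
        return await asyncio.to_thread(_check_sync, person_id)
    except Exception:
        # The diff logs the error and returns False; same fallback here.
        return False
```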
@@ -111,73 +118,111 @@ class PersonInfoManager:
|
|||||||
return
|
return
|
||||||
|
|
||||||
_person_info_default = copy.deepcopy(person_info_default)
|
_person_info_default = copy.deepcopy(person_info_default)
|
||||||
_person_info_default["person_id"] = person_id
|
model_fields = PersonInfo._meta.fields.keys()
|
||||||
|
|
||||||
|
final_data = {"person_id": person_id}
|
||||||
|
|
||||||
if data:
|
if data:
|
||||||
for key in _person_info_default:
|
for key, value in data.items():
|
||||||
if key != "person_id" and key in data:
|
if key in model_fields:
|
||||||
_person_info_default[key] = data[key]
|
final_data[key] = value
|
||||||
|
|
||||||
db.person_info.insert_one(_person_info_default)
|
for key, default_value in _person_info_default.items():
|
||||||
|
if key in model_fields and key not in final_data:
|
||||||
|
final_data[key] = default_value
|
||||||
|
|
||||||
|
if "msg_interval_list" in final_data and isinstance(final_data["msg_interval_list"], list):
|
||||||
|
final_data["msg_interval_list"] = json.dumps(final_data["msg_interval_list"])
|
||||||
|
elif "msg_interval_list" not in final_data and "msg_interval_list" in model_fields:
|
||||||
|
final_data["msg_interval_list"] = json.dumps([])
|
||||||
|
|
||||||
|
def _db_create_sync(p_data: dict):
|
||||||
|
try:
|
||||||
|
PersonInfo.create(**p_data)
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"创建 PersonInfo 记录 {p_data.get('person_id')} 失败 (Peewee): {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
await asyncio.to_thread(_db_create_sync, final_data)
|
||||||
|
|
||||||
async def update_one_field(self, person_id: str, field_name: str, value, data: dict = None):
|
async def update_one_field(self, person_id: str, field_name: str, value, data: dict = None):
|
||||||
"""更新某一个字段,会补全"""
|
"""更新某一个字段,会补全"""
|
||||||
if field_name not in person_info_default.keys():
|
if field_name not in PersonInfo._meta.fields:
|
||||||
logger.debug(f"更新'{field_name}'失败,未定义的字段")
|
if field_name in person_info_default:
|
||||||
|
logger.debug(f"更新'{field_name}'跳过,字段存在于默认配置但不在 PersonInfo Peewee 模型中。")
|
||||||
|
return
|
||||||
|
logger.debug(f"更新'{field_name}'失败,未在 PersonInfo Peewee 模型中定义的字段。")
|
||||||
return
|
return
|
||||||
|
|
||||||
document = db.person_info.find_one({"person_id": person_id})
|
def _db_update_sync(p_id: str, f_name: str, val):
|
||||||
|
record = PersonInfo.get_or_none(PersonInfo.person_id == p_id)
|
||||||
if document:
|
if record:
|
||||||
db.person_info.update_one({"person_id": person_id}, {"$set": {field_name: value}})
|
if f_name == "msg_interval_list" and isinstance(val, list):
|
||||||
|
setattr(record, f_name, json.dumps(val))
|
||||||
else:
|
else:
|
||||||
data[field_name] = value
|
setattr(record, f_name, val)
|
||||||
logger.debug(f"更新时{person_id}不存在,已新建")
|
record.save()
|
||||||
await self.create_person_info(person_id, data)
|
return True, False
|
||||||
|
return False, True
|
||||||
|
|
||||||
|
found, needs_creation = await asyncio.to_thread(_db_update_sync, person_id, field_name, value)
|
||||||
|
|
||||||
|
if needs_creation:
|
||||||
|
logger.debug(f"更新时 {person_id} 不存在,将新建。")
|
||||||
|
creation_data = data if data is not None else {}
|
||||||
|
creation_data[field_name] = value
|
||||||
|
if "platform" not in creation_data or "user_id" not in creation_data:
|
||||||
|
logger.warning(f"为 {person_id} 创建记录时,platform/user_id 可能缺失。")
|
||||||
|
|
||||||
|
await self.create_person_info(person_id, creation_data)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def has_one_field(person_id: str, field_name: str):
|
async def has_one_field(person_id: str, field_name: str):
|
||||||
"""判断是否存在某一个字段"""
|
"""判断是否存在某一个字段"""
|
||||||
document = db.person_info.find_one({"person_id": person_id}, {field_name: 1})
|
if field_name not in PersonInfo._meta.fields:
|
||||||
if document:
|
logger.debug(f"检查字段'{field_name}'失败,未在 PersonInfo Peewee 模型中定义。")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _db_has_field_sync(p_id: str, f_name: str):
|
||||||
|
record = PersonInfo.get_or_none(PersonInfo.person_id == p_id)
|
||||||
|
if record:
|
||||||
return True
|
return True
|
||||||
else:
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
return await asyncio.to_thread(_db_has_field_sync, person_id, field_name)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"检查字段 {field_name} for {person_id} 时出错 (Peewee): {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _extract_json_from_text(text: str) -> dict:
|
def _extract_json_from_text(text: str) -> dict:
|
||||||
"""从文本中提取JSON数据的高容错方法"""
|
"""从文本中提取JSON数据的高容错方法"""
|
||||||
try:
|
try:
|
||||||
# 尝试直接解析
|
|
||||||
parsed_json = json.loads(text)
|
parsed_json = json.loads(text)
|
||||||
# 如果解析结果是列表,尝试取第一个元素
|
|
||||||
if isinstance(parsed_json, list):
|
if isinstance(parsed_json, list):
|
||||||
if parsed_json: # 检查列表是否为空
|
if parsed_json:
|
||||||
parsed_json = parsed_json[0]
|
parsed_json = parsed_json[0]
|
||||||
else: # 如果列表为空,重置为 None,走后续逻辑
|
else:
|
||||||
parsed_json = None
|
parsed_json = None
|
||||||
# 确保解析结果是字典
|
|
||||||
if isinstance(parsed_json, dict):
|
if isinstance(parsed_json, dict):
|
||||||
return parsed_json
|
return parsed_json
|
||||||
|
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
# 解析失败,继续尝试其他方法
|
|
||||||
pass
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"尝试直接解析JSON时发生意外错误: {e}")
|
logger.warning(f"尝试直接解析JSON时发生意外错误: {e}")
|
||||||
pass # 继续尝试其他方法
|
pass
|
||||||
|
|
||||||
# 如果直接解析失败或结果不是字典
|
|
||||||
try:
|
try:
|
||||||
# 尝试找到JSON对象格式的部分
|
|
||||||
json_pattern = r"\{[^{}]*\}"
|
json_pattern = r"\{[^{}]*\}"
|
||||||
matches = re.findall(json_pattern, text)
|
matches = re.findall(json_pattern, text)
|
||||||
if matches:
|
if matches:
|
||||||
parsed_obj = json.loads(matches[0])
|
parsed_obj = json.loads(matches[0])
|
||||||
if isinstance(parsed_obj, dict): # 确保是字典
|
if isinstance(parsed_obj, dict):
|
||||||
return parsed_obj
|
return parsed_obj
|
||||||
|
|
||||||
# 如果上面都失败了,尝试提取键值对
|
|
||||||
nickname_pattern = r'"nickname"[:\s]+"([^"]+)"'
|
nickname_pattern = r'"nickname"[:\s]+"([^"]+)"'
|
||||||
reason_pattern = r'"reason"[:\s]+"([^"]+)"'
|
reason_pattern = r'"reason"[:\s]+"([^"]+)"'
|
||||||
|
|
||||||
@@ -192,7 +237,6 @@ class PersonInfoManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"后备JSON提取失败: {str(e)}")
|
logger.error(f"后备JSON提取失败: {str(e)}")
|
||||||
|
|
||||||
# 如果所有方法都失败了,返回默认字典
|
|
||||||
logger.warning(f"无法从文本中提取有效的JSON字典: {text}")
|
logger.warning(f"无法从文本中提取有效的JSON字典: {text}")
|
||||||
return {"nickname": "", "reason": ""}
|
return {"nickname": "", "reason": ""}
|
||||||
|
|
||||||
@@ -207,9 +251,11 @@ class PersonInfoManager:
|
|||||||
old_name = await self.get_value(person_id, "person_name")
|
old_name = await self.get_value(person_id, "person_name")
|
||||||
old_reason = await self.get_value(person_id, "name_reason")
|
old_reason = await self.get_value(person_id, "name_reason")
|
||||||
|
|
||||||
max_retries = 5 # 最大重试次数
|
max_retries = 5
|
||||||
current_try = 0
|
current_try = 0
|
||||||
existing_names = ""
|
existing_names_str = ""
|
||||||
|
current_name_set = set(self.person_name_list.values())
|
||||||
|
|
||||||
while current_try < max_retries:
|
while current_try < max_retries:
|
||||||
individuality = Individuality.get_instance()
|
individuality = Individuality.get_instance()
|
||||||
prompt_personality = individuality.get_prompt(x_person=2, level=1)
|
prompt_personality = individuality.get_prompt(x_person=2, level=1)
|
||||||
@@ -224,45 +270,58 @@ class PersonInfoManager:
|
|||||||
qv_name_prompt += f"你之前叫他{old_name},是因为{old_reason},"
|
qv_name_prompt += f"你之前叫他{old_name},是因为{old_reason},"
|
||||||
|
|
||||||
qv_name_prompt += f"\n其他取名的要求是:{request},不要太浮夸"
|
qv_name_prompt += f"\n其他取名的要求是:{request},不要太浮夸"
|
||||||
|
|
||||||
qv_name_prompt += (
|
qv_name_prompt += (
|
||||||
"\n请根据以上用户信息,想想你叫他什么比较好,不要太浮夸,请最好使用用户的qq昵称,可以稍作修改"
|
"\n请根据以上用户信息,想想你叫他什么比较好,不要太浮夸,请最好使用用户的qq昵称,可以稍作修改"
|
||||||
)
|
)
|
||||||
if existing_names:
|
|
||||||
qv_name_prompt += f"\n请注意,以下名称已被使用,不要使用以下昵称:{existing_names}。\n"
|
if existing_names_str:
|
||||||
|
qv_name_prompt += f"\n请注意,以下名称已被你尝试过或已知存在,请避免:{existing_names_str}。\n"
|
||||||
|
|
||||||
|
if len(current_name_set) < 50 and current_name_set:
|
||||||
|
qv_name_prompt += f"已知的其他昵称有: {', '.join(list(current_name_set)[:10])}等。\n"
|
||||||
|
|
||||||
qv_name_prompt += "请用json给出你的想法,并给出理由,示例如下:"
|
qv_name_prompt += "请用json给出你的想法,并给出理由,示例如下:"
|
||||||
qv_name_prompt += """{
|
qv_name_prompt += """{
|
||||||
"nickname": "昵称",
|
"nickname": "昵称",
|
||||||
"reason": "理由"
|
"reason": "理由"
|
||||||
}"""
|
}"""
|
||||||
# logger.debug(f"取名提示词:{qv_name_prompt}")
|
|
||||||
response = await self.qv_name_llm.generate_response(qv_name_prompt)
|
response = await self.qv_name_llm.generate_response(qv_name_prompt)
|
||||||
logger.trace(f"取名提示词:{qv_name_prompt}\n取名回复:{response}")
|
logger.trace(f"取名提示词:{qv_name_prompt}\n取名回复:{response}")
|
||||||
result = self._extract_json_from_text(response[0])
|
result = self._extract_json_from_text(response[0])
|
||||||
|
|
||||||
if not result["nickname"]:
|
if not result or not result.get("nickname"):
|
||||||
logger.error("生成的昵称为空,重试中...")
|
logger.error("生成的昵称为空或结果格式不正确,重试中...")
|
||||||
current_try += 1
|
current_try += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 检查生成的昵称是否已存在
|
generated_nickname = result["nickname"]
|
||||||
if result["nickname"] not in self.person_name_list.values():
|
|
||||||
# 更新数据库和内存中的列表
|
|
||||||
await self.update_one_field(person_id, "person_name", result["nickname"])
|
|
||||||
# await self.update_one_field(person_id, "nickname", user_nickname)
|
|
||||||
# await self.update_one_field(person_id, "avatar", user_avatar)
|
|
||||||
await self.update_one_field(person_id, "name_reason", result["reason"])
|
|
||||||
|
|
||||||
self.person_name_list[person_id] = result["nickname"]
|
is_duplicate = False
|
||||||
# logger.debug(f"用户 {person_id} 的名称已更新为 {result['nickname']},原因:{result['reason']}")
|
if generated_nickname in current_name_set:
|
||||||
|
is_duplicate = True
|
||||||
|
else:
|
||||||
|
|
||||||
|
def _db_check_name_exists_sync(name_to_check):
|
||||||
|
return PersonInfo.select().where(PersonInfo.person_name == name_to_check).exists()
|
||||||
|
|
||||||
|
if await asyncio.to_thread(_db_check_name_exists_sync, generated_nickname):
|
||||||
|
is_duplicate = True
|
||||||
|
current_name_set.add(generated_nickname)
|
||||||
|
|
||||||
|
if not is_duplicate:
|
||||||
|
await self.update_one_field(person_id, "person_name", generated_nickname)
|
||||||
|
await self.update_one_field(person_id, "name_reason", result.get("reason", "未提供理由"))
|
||||||
|
|
||||||
|
self.person_name_list[person_id] = generated_nickname
|
||||||
return result
|
return result
|
||||||
else:
|
else:
|
||||||
existing_names += f"{result['nickname']}、"
|
if existing_names_str:
|
||||||
|
existing_names_str += "、"
|
||||||
logger.debug(f"生成的昵称 {result['nickname']} 已存在,重试中...")
|
existing_names_str += generated_nickname
|
||||||
|
logger.debug(f"生成的昵称 {generated_nickname} 已存在,重试中...")
|
||||||
current_try += 1
|
current_try += 1
|
||||||
|
|
||||||
logger.error(f"在{max_retries}次尝试后仍未能生成唯一昵称")
|
logger.error(f"在{max_retries}次尝试后仍未能生成唯一昵称 for {person_id}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -272,30 +331,56 @@ class PersonInfoManager:
|
|||||||
logger.debug("删除失败:person_id 不能为空")
|
logger.debug("删除失败:person_id 不能为空")
|
||||||
return
|
return
|
||||||
|
|
||||||
result = db.person_info.delete_one({"person_id": person_id})
|
def _db_delete_sync(p_id: str):
|
||||||
if result.deleted_count > 0:
|
try:
|
||||||
logger.debug(f"删除成功:person_id={person_id}")
|
query = PersonInfo.delete().where(PersonInfo.person_id == p_id)
|
||||||
|
deleted_count = query.execute()
|
||||||
|
return deleted_count
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"删除 PersonInfo {p_id} 失败 (Peewee): {e}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
deleted_count = await asyncio.to_thread(_db_delete_sync, person_id)
|
||||||
|
|
||||||
|
if deleted_count > 0:
|
||||||
|
logger.debug(f"删除成功:person_id={person_id} (Peewee)")
|
||||||
else:
|
else:
|
||||||
logger.debug(f"删除失败:未找到 person_id={person_id}")
|
logger.debug(f"删除失败:未找到 person_id={person_id} 或删除未影响行 (Peewee)")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def get_value(person_id: str, field_name: str):
|
async def get_value(person_id: str, field_name: str):
|
||||||
"""获取指定person_id文档的字段值,若不存在该字段,则返回该字段的全局默认值"""
|
"""获取指定person_id文档的字段值,若不存在该字段,则返回该字段的全局默认值"""
|
||||||
if not person_id:
|
if not person_id:
|
||||||
logger.debug("get_value获取失败:person_id不能为空")
|
logger.debug("get_value获取失败:person_id不能为空")
|
||||||
|
return person_info_default.get(field_name)
|
||||||
|
|
||||||
|
if field_name not in PersonInfo._meta.fields:
|
||||||
|
if field_name in person_info_default:
|
||||||
|
logger.trace(f"字段'{field_name}'不在Peewee模型中,但存在于默认配置中。返回配置默认值。")
|
||||||
|
return copy.deepcopy(person_info_default[field_name])
|
||||||
|
logger.debug(f"get_value获取失败:字段'{field_name}'未在Peewee模型和默认配置中定义。")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if field_name not in person_info_default:
|
def _db_get_value_sync(p_id: str, f_name: str):
|
||||||
logger.debug(f"get_value获取失败:字段'{field_name}'未定义")
|
record = PersonInfo.get_or_none(PersonInfo.person_id == p_id)
|
||||||
|
if record:
|
||||||
|
val = getattr(record, f_name)
|
||||||
|
if f_name == "msg_interval_list" and isinstance(val, str):
|
||||||
|
try:
|
||||||
|
return json.loads(val)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.warning(f"无法解析 {p_id} 的 msg_interval_list JSON: {val}")
|
||||||
|
return copy.deepcopy(person_info_default.get(f_name, []))
|
||||||
|
return val
|
||||||
return None
|
return None
|
||||||
|
|
||||||
document = db.person_info.find_one({"person_id": person_id}, {field_name: 1})
|
value = await asyncio.to_thread(_db_get_value_sync, person_id, field_name)
|
||||||
|
|
||||||
if document and field_name in document:
|
if value is not None:
|
||||||
return document[field_name]
|
return value
|
||||||
else:
|
else:
|
||||||
default_value = copy.deepcopy(person_info_default[field_name])
|
default_value = copy.deepcopy(person_info_default.get(field_name))
|
||||||
logger.trace(f"获取{person_id}的{field_name}失败,已返回默认值{default_value}")
|
logger.trace(f"获取{person_id}的{field_name}失败或值为None,已返回默认值{default_value} (Peewee)")
|
||||||
return default_value
|
return default_value
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -305,93 +390,84 @@ class PersonInfoManager:
|
|||||||
logger.debug("get_values获取失败:person_id不能为空")
|
logger.debug("get_values获取失败:person_id不能为空")
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
# 检查所有字段是否有效
|
|
||||||
for field in field_names:
|
|
||||||
if field not in person_info_default:
|
|
||||||
logger.debug(f"get_values获取失败:字段'{field}'未定义")
|
|
||||||
return {}
|
|
||||||
|
|
||||||
# 构建查询投影(所有字段都有效才会执行到这里)
|
|
||||||
projection = {field: 1 for field in field_names}
|
|
||||||
|
|
||||||
document = db.person_info.find_one({"person_id": person_id}, projection)
|
|
||||||
|
|
||||||
result = {}
|
result = {}
|
||||||
for field in field_names:
|
|
||||||
result[field] = copy.deepcopy(
|
def _db_get_record_sync(p_id: str):
|
||||||
document.get(field, person_info_default[field]) if document else person_info_default[field]
|
return PersonInfo.get_or_none(PersonInfo.person_id == p_id)
|
||||||
)
|
|
||||||
|
record = await asyncio.to_thread(_db_get_record_sync, person_id)
|
||||||
|
|
||||||
|
for field_name in field_names:
|
||||||
|
if field_name not in PersonInfo._meta.fields:
|
||||||
|
if field_name in person_info_default:
|
||||||
|
result[field_name] = copy.deepcopy(person_info_default[field_name])
|
||||||
|
logger.trace(f"字段'{field_name}'不在Peewee模型中,使用默认配置值。")
|
||||||
|
else:
|
||||||
|
logger.debug(f"get_values查询失败:字段'{field_name}'未在Peewee模型和默认配置中定义。")
|
||||||
|
result[field_name] = None
|
||||||
|
continue
|
||||||
|
|
||||||
|
if record:
|
||||||
|
value = getattr(record, field_name)
|
||||||
|
if field_name == "msg_interval_list" and isinstance(value, str):
|
||||||
|
try:
|
||||||
|
result[field_name] = json.loads(value)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.warning(f"无法解析 {person_id} 的 msg_interval_list JSON: {value}")
|
||||||
|
result[field_name] = copy.deepcopy(person_info_default.get(field_name, []))
|
||||||
|
elif value is not None:
|
||||||
|
result[field_name] = value
|
||||||
|
else:
|
||||||
|
result[field_name] = copy.deepcopy(person_info_default.get(field_name))
|
||||||
|
else:
|
||||||
|
result[field_name] = copy.deepcopy(person_info_default.get(field_name))
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def del_all_undefined_field():
|
async def del_all_undefined_field():
|
||||||
"""删除所有项里的未定义字段"""
|
"""删除所有项里的未定义字段 - 对于Peewee (SQL),此操作通常不适用,因为模式是固定的。"""
|
||||||
# 获取所有已定义的字段名
|
logger.info(
|
||||||
defined_fields = set(person_info_default.keys())
|
"del_all_undefined_field: 对于使用Peewee的SQL数据库,此操作通常不适用或不需要,因为表结构是预定义的。"
|
||||||
|
|
||||||
try:
|
|
||||||
# 遍历集合中的所有文档
|
|
||||||
for document in db.person_info.find({}):
|
|
||||||
# 找出文档中未定义的字段
|
|
||||||
undefined_fields = set(document.keys()) - defined_fields - {"_id"}
|
|
||||||
|
|
||||||
if undefined_fields:
|
|
||||||
# 构建更新操作,使用$unset删除未定义字段
|
|
||||||
update_result = db.person_info.update_one(
|
|
||||||
{"_id": document["_id"]}, {"$unset": {field: 1 for field in undefined_fields}}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if update_result.modified_count > 0:
|
|
||||||
logger.debug(f"已清理文档 {document['_id']} 的未定义字段: {undefined_fields}")
|
|
||||||
|
|
||||||
return
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"清理未定义字段时出错: {e}")
|
|
||||||
return
|
return
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def get_specific_value_list(
|
async def get_specific_value_list(
|
||||||
field_name: str,
|
field_name: str,
|
||||||
way: Callable[[Any], bool], # 接受任意类型值
|
way: Callable[[Any], bool],
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
获取满足条件的字段值字典
|
获取满足条件的字段值字典
|
||||||
|
|
||||||
Args:
|
|
||||||
field_name: 目标字段名
|
|
||||||
way: 判断函数 (value: Any) -> bool
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
{person_id: value} | {}
|
|
||||||
|
|
||||||
Example:
|
|
||||||
# 查找所有nickname包含"admin"的用户
|
|
||||||
result = manager.specific_value_list(
|
|
||||||
"nickname",
|
|
||||||
lambda x: "admin" in x.lower()
|
|
||||||
)
|
|
||||||
"""
|
"""
|
||||||
if field_name not in person_info_default:
|
if field_name not in PersonInfo._meta.fields:
|
||||||
logger.error(f"字段检查失败:'{field_name}'未定义")
|
logger.error(f"字段检查失败:'{field_name}'未在 PersonInfo Peewee 模型中定义")
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
def _db_get_specific_sync(f_name: str):
|
||||||
|
found_results = {}
|
||||||
try:
|
try:
|
||||||
result = {}
|
for record in PersonInfo.select(PersonInfo.person_id, getattr(PersonInfo, f_name)):
|
||||||
for doc in db.person_info.find({field_name: {"$exists": True}}, {"person_id": 1, field_name: 1, "_id": 0}):
|
value = getattr(record, f_name)
|
||||||
|
if f_name == "msg_interval_list" and isinstance(value, str):
|
||||||
try:
|
try:
|
||||||
value = doc[field_name]
|
processed_value = json.loads(value)
|
||||||
if way(value):
|
except json.JSONDecodeError:
|
||||||
result[doc["person_id"]] = value
|
logger.warning(f"跳过记录 {record.person_id},无法解析 msg_interval_list: {value}")
|
||||||
except (KeyError, TypeError, ValueError) as e:
|
|
||||||
logger.debug(f"记录{doc.get('person_id')}处理失败: {str(e)}")
|
|
||||||
continue
|
continue
|
||||||
|
else:
|
||||||
|
processed_value = value
|
||||||
|
|
||||||
return result
|
if way(processed_value):
|
||||||
|
found_results[record.person_id] = processed_value
|
||||||
|
except Exception as e_query:
|
||||||
|
logger.error(f"数据库查询失败 (Peewee specific_value_list for {f_name}): {str(e_query)}", exc_info=True)
|
||||||
|
return found_results
|
||||||
|
|
||||||
|
try:
|
||||||
|
return await asyncio.to_thread(_db_get_specific_sync, field_name)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"数据库查询失败: {str(e)}", exc_info=True)
|
logger.error(f"执行 get_specific_value_list 线程时出错: {str(e)}", exc_info=True)
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
async def personal_habit_deduction(self):
|
async def personal_habit_deduction(self):
|
||||||
@@ -399,35 +475,31 @@ class PersonInfoManager:
|
|||||||
try:
|
try:
|
||||||
while 1:
|
while 1:
|
||||||
await asyncio.sleep(600)
|
await asyncio.sleep(600)
|
||||||
current_time = datetime.datetime.now()
|
current_time_dt = datetime.datetime.now()
|
||||||
logger.info(f"个人信息推断启动: {current_time.strftime('%Y-%m-%d %H:%M:%S')}")
|
logger.info(f"个人信息推断启动: {current_time_dt.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||||
|
|
||||||
# "msg_interval"推断
|
msg_interval_map_generated = False
|
||||||
msg_interval_map = False
|
msg_interval_lists_map = await self.get_specific_value_list(
|
||||||
msg_interval_lists = await self.get_specific_value_list(
|
|
||||||
"msg_interval_list", lambda x: isinstance(x, list) and len(x) >= 100
|
"msg_interval_list", lambda x: isinstance(x, list) and len(x) >= 100
|
||||||
)
|
)
|
||||||
for person_id, msg_interval_list_ in msg_interval_lists.items():
|
|
||||||
|
for person_id, actual_msg_interval_list in msg_interval_lists_map.items():
|
||||||
await asyncio.sleep(0.3)
|
await asyncio.sleep(0.3)
|
||||||
try:
|
try:
|
||||||
time_interval = []
|
time_interval = []
|
||||||
for t1, t2 in zip(msg_interval_list_, msg_interval_list_[1:]):
|
for t1, t2 in zip(actual_msg_interval_list, actual_msg_interval_list[1:]):
|
||||||
delta = t2 - t1
|
delta = t2 - t1
|
||||||
if delta > 0:
|
if delta > 0:
|
||||||
time_interval.append(delta)
|
time_interval.append(delta)
|
||||||
|
|
||||||
time_interval = [t for t in time_interval if 200 <= t <= 8000]
|
time_interval = [t for t in time_interval if 200 <= t <= 8000]
|
||||||
# --- 修改后的逻辑 ---
|
|
||||||
# 数据量检查 (至少需要 30 条有效间隔,并且足够进行头尾截断)
|
|
||||||
if len(time_interval) >= 30 + 10: # 至少30条有效+头尾各5条
|
|
||||||
time_interval.sort()
|
|
||||||
|
|
||||||
# 画图(log) - 这部分保留
|
if len(time_interval) >= 30 + 10:
|
||||||
msg_interval_map = True
|
time_interval.sort()
|
||||||
|
msg_interval_map_generated = True
|
||||||
log_dir = Path("logs/person_info")
|
log_dir = Path("logs/person_info")
|
||||||
log_dir.mkdir(parents=True, exist_ok=True)
|
log_dir.mkdir(parents=True, exist_ok=True)
|
||||||
plt.figure(figsize=(10, 6))
|
plt.figure(figsize=(10, 6))
|
||||||
# 使用截断前的数据画图,更能反映原始分布
|
|
||||||
time_series_original = pd.Series(time_interval)
|
time_series_original = pd.Series(time_interval)
|
||||||
plt.hist(
|
plt.hist(
|
||||||
time_series_original,
|
time_series_original,
|
||||||
@@ -449,34 +521,29 @@ class PersonInfoManager:
|
|||||||
img_path = log_dir / f"interval_distribution_{person_id[:8]}.png"
|
img_path = log_dir / f"interval_distribution_{person_id[:8]}.png"
|
||||||
plt.savefig(img_path)
|
plt.savefig(img_path)
|
||||||
plt.close()
|
plt.close()
|
||||||
# 画图结束
|
|
||||||
|
|
||||||
# 去掉头尾各 5 个数据点
|
|
||||||
trimmed_interval = time_interval[5:-5]
|
trimmed_interval = time_interval[5:-5]
|
||||||
|
if trimmed_interval:
|
||||||
# 计算截断后数据的 37% 分位数
|
msg_interval_val = int(round(np.percentile(trimmed_interval, 37)))
|
||||||
if trimmed_interval: # 确保截断后列表不为空
|
await self.update_one_field(person_id, "msg_interval", msg_interval_val)
|
||||||
msg_interval = int(round(np.percentile(trimmed_interval, 37)))
|
logger.trace(
|
||||||
# 更新数据库
|
f"用户{person_id}的msg_interval通过头尾截断和37分位数更新为{msg_interval_val}"
|
||||||
await self.update_one_field(person_id, "msg_interval", msg_interval)
|
)
|
||||||
logger.trace(f"用户{person_id}的msg_interval通过头尾截断和37分位数更新为{msg_interval}")
|
|
||||||
else:
|
else:
|
||||||
logger.trace(f"用户{person_id}截断后数据为空,无法计算msg_interval")
|
logger.trace(f"用户{person_id}截断后数据为空,无法计算msg_interval")
|
||||||
else:
|
else:
|
||||||
logger.trace(
|
logger.trace(
|
||||||
f"用户{person_id}有效消息间隔数量 ({len(time_interval)}) 不足进行推断 (需要至少 {30 + 10} 条)"
|
f"用户{person_id}有效消息间隔数量 ({len(time_interval)}) 不足进行推断 (需要至少 {30 + 10} 条)"
|
||||||
)
|
)
|
||||||
# --- 修改结束 ---
|
except Exception as e_inner:
|
||||||
except Exception as e:
|
logger.trace(f"用户{person_id}消息间隔计算失败: {type(e_inner).__name__}: {str(e_inner)}")
|
||||||
logger.trace(f"用户{person_id}消息间隔计算失败: {type(e).__name__}: {str(e)}")
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 其他...
|
if msg_interval_map_generated:
|
||||||
|
|
||||||
if msg_interval_map:
|
|
||||||
logger.trace("已保存分布图到: logs/person_info")
|
logger.trace("已保存分布图到: logs/person_info")
|
||||||
current_time = datetime.datetime.now()
|
|
||||||
logger.trace(f"个人信息推断结束: {current_time.strftime('%Y-%m-%d %H:%M:%S')}")
|
current_time_dt_end = datetime.datetime.now()
|
||||||
|
logger.trace(f"个人信息推断结束: {current_time_dt_end.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||||
await asyncio.sleep(86400)
|
await asyncio.sleep(86400)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -489,41 +556,27 @@ class PersonInfoManager:
|
|||||||
"""
|
"""
|
||||||
根据 platform 和 user_id 获取 person_id。
|
根据 platform 和 user_id 获取 person_id。
|
||||||
如果对应的用户不存在,则使用提供的可选信息创建新用户。
|
如果对应的用户不存在,则使用提供的可选信息创建新用户。
|
||||||
|
|
||||||
Args:
|
|
||||||
platform: 平台标识
|
|
||||||
user_id: 用户在该平台上的ID
|
|
||||||
nickname: 用户的昵称 (可选,用于创建新用户)
|
|
||||||
user_cardname: 用户的群名片 (可选,用于创建新用户)
|
|
||||||
user_avatar: 用户的头像信息 (可选,用于创建新用户)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
对应的 person_id。
|
|
||||||
"""
|
"""
|
||||||
person_id = self.get_person_id(platform, user_id)
|
person_id = self.get_person_id(platform, user_id)
|
||||||
|
|
||||||
# 检查用户是否已存在
|
def _db_check_exists_sync(p_id: str):
|
||||||
# 使用静态方法 get_person_id,因此可以直接调用 db
|
return PersonInfo.get_or_none(PersonInfo.person_id == p_id)
|
||||||
document = db.person_info.find_one({"person_id": person_id})
|
|
||||||
|
|
||||||
if document is None:
|
record = await asyncio.to_thread(_db_check_exists_sync, person_id)
|
||||||
logger.info(f"用户 {platform}:{user_id} (person_id: {person_id}) 不存在,将创建新记录。")
|
|
||||||
|
if record is None:
|
||||||
|
logger.info(f"用户 {platform}:{user_id} (person_id: {person_id}) 不存在,将创建新记录 (Peewee)。")
|
||||||
initial_data = {
|
initial_data = {
|
||||||
"platform": platform,
|
"platform": platform,
|
||||||
"user_id": user_id,
|
"user_id": str(user_id),
|
||||||
"nickname": nickname,
|
"nickname": nickname,
|
||||||
"konw_time": int(datetime.datetime.now().timestamp()), # 添加初次认识时间
|
"know_time": int(datetime.datetime.now().timestamp()), # 修正拼写:konw_time -> know_time
|
||||||
# 注意:这里没有添加 user_cardname 和 user_avatar,因为它们不在 person_info_default 中
|
|
||||||
# 如果需要存储它们,需要先在 person_info_default 中定义
|
|
||||||
}
|
}
|
||||||
# 过滤掉值为 None 的初始数据
|
model_fields = PersonInfo._meta.fields.keys()
|
||||||
initial_data = {k: v for k, v in initial_data.items() if v is not None}
|
filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields}
|
||||||
|
|
||||||
# 注意:create_person_info 是静态方法
|
await self.create_person_info(person_id, data=filtered_initial_data)
|
||||||
await PersonInfoManager.create_person_info(person_id, data=initial_data)
|
logger.debug(f"已为 {person_id} 创建新记录,初始数据 (filtered for model): {filtered_initial_data}")
|
||||||
# 创建后,可以考虑立即为其取名,但这可能会增加延迟
|
|
||||||
# await self.qv_person_name(person_id, nickname, user_cardname, user_avatar)
|
|
||||||
logger.debug(f"已为 {person_id} 创建新记录,初始数据: {initial_data}")
|
|
||||||
|
|
||||||
return person_id
|
return person_id
|
||||||
|
|
||||||
@@ -533,34 +586,54 @@ class PersonInfoManager:
|
|||||||
logger.debug("get_person_info_by_name 获取失败:person_name 不能为空")
|
logger.debug("get_person_info_by_name 获取失败:person_name 不能为空")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 优先从内存缓存查找 person_id
|
|
||||||
found_person_id = None
|
found_person_id = None
|
||||||
for pid, name in self.person_name_list.items():
|
for pid, name_in_cache in self.person_name_list.items():
|
||||||
if name == person_name:
|
if name_in_cache == person_name:
|
||||||
found_person_id = pid
|
found_person_id = pid
|
||||||
break # 找到第一个匹配就停止
|
break
|
||||||
|
|
||||||
if not found_person_id:
|
if not found_person_id:
|
||||||
# 如果内存没有,尝试数据库查询(可能内存未及时更新或启动时未加载)
|
|
||||||
document = db.person_info.find_one({"person_name": person_name})
|
|
||||||
if document:
|
|
||||||
found_person_id = document.get("person_id")
|
|
||||||
else:
|
|
||||||
logger.debug(f"数据库中也未找到名为 '{person_name}' 的用户")
|
|
||||||
return None # 数据库也找不到
|
|
||||||
|
|
||||||
# 根据找到的 person_id 获取所需信息
|
def _db_find_by_name_sync(p_name_to_find: str):
|
||||||
if found_person_id:
|
return PersonInfo.get_or_none(PersonInfo.person_name == p_name_to_find)
|
||||||
required_fields = ["person_id", "platform", "user_id", "nickname", "user_cardname", "user_avatar"]
|
|
||||||
person_data = await self.get_values(found_person_id, required_fields)
|
record = await asyncio.to_thread(_db_find_by_name_sync, person_name)
|
||||||
if person_data: # 确保 get_values 成功返回
|
if record:
|
||||||
return person_data
|
found_person_id = record.person_id
|
||||||
|
if (
|
||||||
|
found_person_id not in self.person_name_list
|
||||||
|
or self.person_name_list[found_person_id] != person_name
|
||||||
|
):
|
||||||
|
self.person_name_list[found_person_id] = person_name
|
||||||
else:
|
else:
|
||||||
logger.warning(f"找到了 person_id '{found_person_id}' 但获取详细信息失败")
|
logger.debug(f"数据库中也未找到名为 '{person_name}' 的用户 (Peewee)")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
if found_person_id:
|
||||||
|
required_fields = [
|
||||||
|
"person_id",
|
||||||
|
"platform",
|
||||||
|
"user_id",
|
||||||
|
"nickname",
|
||||||
|
"user_cardname",
|
||||||
|
"user_avatar",
|
||||||
|
"person_name",
|
||||||
|
"name_reason",
|
||||||
|
]
|
||||||
|
valid_fields_to_get = [
|
||||||
|
f for f in required_fields if f in PersonInfo._meta.fields or f in person_info_default
|
||||||
|
]
|
||||||
|
|
||||||
|
person_data = await self.get_values(found_person_id, valid_fields_to_get)
|
||||||
|
|
||||||
|
if person_data:
|
||||||
|
final_result = {key: person_data.get(key) for key in required_fields}
|
||||||
|
return final_result
|
||||||
else:
|
else:
|
||||||
# 这理论上不应该发生,因为上面已经处理了找不到的情况
|
logger.warning(f"找到了 person_id '{found_person_id}' 但 get_values 返回空 (Peewee)")
|
||||||
logger.error(f"逻辑错误:未能为 '{person_name}' 确定 person_id")
|
return None
|
||||||
|
|
||||||
|
logger.error(f"逻辑错误:未能为 '{person_name}' 确定 person_id (Peewee)")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -190,8 +190,8 @@ async def _build_readable_messages_internal(

person_id = person_info_manager.get_person_id(platform, user_id)
# 根据 replace_bot_name 参数决定是否替换机器人名称
-if replace_bot_name and user_id == global_config.BOT_QQ:
+if replace_bot_name and user_id == global_config.bot.qq_account:
-person_name = f"{global_config.BOT_NICKNAME}(你)"
+person_name = f"{global_config.bot.nickname}(你)"
else:
person_name = await person_info_manager.get_value(person_id, "person_name")

@@ -427,7 +427,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
output_lines = []

def get_anon_name(platform, user_id):
-if user_id == global_config.BOT_QQ:
+if user_id == global_config.bot.qq_account:
return "SELF"
person_id = person_info_manager.get_person_id(platform, user_id)
if person_id not in person_map:
@@ -454,7 +454,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
def reply_replacer(match, platform=platform):
# aaa = match.group(1)
bbb = match.group(2)
-anon_reply = get_anon_name(platform, bbb)
+anon_reply = get_anon_name(platform, bbb) # noqa
return f"回复 {anon_reply}"

content = re.sub(reply_pattern, reply_replacer, content, count=1)
@@ -465,7 +465,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
def at_replacer(match, platform=platform):
# aaa = match.group(1)
bbb = match.group(2)
-anon_at = get_anon_name(platform, bbb)
+anon_at = get_anon_name(platform, bbb) # noqa
return f"@{anon_at}"

content = re.sub(at_pattern, at_replacer, content)
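Both replacers above accept platform=platform as a default argument so that each closure binds the platform of the message currently being processed, instead of whatever the variable holds when the closure is finally called. A standalone sketch of the same re.sub pattern; the regex and helper here are illustrative, not the project's:

```python
import re

# Illustrative pattern only; the real reply/at patterns are defined elsewhere in this module.
AT_PATTERN = r"@<(?P<platform>[^:>]+):(?P<user_id>[^>]+)>"


def anonymize_ats(content: str, platform: str, get_anon_name) -> str:
    def at_replacer(match, platform=platform):  # bind the current platform now, not at call time
        user_id = match.group("user_id")
        return f"@{get_anon_name(platform, user_id)}"

    return re.sub(AT_PATTERN, at_replacer, content)
```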
@@ -501,7 +501,7 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
user_id = user_info.get("user_id")

# 检查必要信息是否存在 且 不是机器人自己
-if not all([platform, user_id]) or user_id == global_config.BOT_QQ:
+if not all([platform, user_id]) or user_id == global_config.bot.qq_account:
continue

person_id = person_info_manager.get_person_id(platform, user_id)
@@ -1,15 +1,15 @@
from src.config.config import global_config
from src.chat.message_receive.message import MessageRecv, MessageSending, Message
-from src.common.database import db
+from src.common.database.database_model import Messages, ThinkingLog
import time
import traceback
from typing import List
+import json


class InfoCatcher:
def __init__(self):
self.chat_history = [] # 聊天历史,长度为三倍使用的上下文喵~
-self.context_length = global_config.observation_context_size
self.chat_history_in_thinking = [] # 思考期间的聊天内容喵~
self.chat_history_after_response = [] # 回复后的聊天内容,长度为一倍上下文喵~

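Further down this file the Mongo find(...).sort(...).limit(...) calls become Peewee selects, and the removed self.context_length is read from global_config.chat.observation_context_size at query time. Condensed from the later hunk, with the field names used there:

```python
# Condensed sketch of the Peewee query behind get_message_from_db_before_msg below.
from src.common.database.database_model import Messages
from src.config.config import global_config


def messages_before(chat_id: str, message_id: str) -> list:
    query = (
        Messages.select()
        .where((Messages.chat_id == chat_id) & (Messages.message_id < message_id))
        .order_by(Messages.time.desc())
        .limit(global_config.chat.observation_context_size * 3)
    )
    return list(query)
```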
@@ -60,8 +60,6 @@ class InfoCatcher:
|
|||||||
def catch_after_observe(self, obs_duration: float): # 这里可以有更多信息
|
def catch_after_observe(self, obs_duration: float): # 这里可以有更多信息
|
||||||
self.timing_results["sub_heartflow_observe_time"] = obs_duration
|
self.timing_results["sub_heartflow_observe_time"] = obs_duration
|
||||||
|
|
||||||
# def catch_shf
|
|
||||||
|
|
||||||
def catch_afer_shf_step(self, step_duration: float, past_mind: str, current_mind: str):
|
def catch_afer_shf_step(self, step_duration: float, past_mind: str, current_mind: str):
|
||||||
self.timing_results["sub_heartflow_step_time"] = step_duration
|
self.timing_results["sub_heartflow_step_time"] = step_duration
|
||||||
if len(past_mind) > 1:
|
if len(past_mind) > 1:
|
||||||
@@ -72,25 +70,10 @@ class InfoCatcher:
|
|||||||
self.heartflow_data["sub_heartflow_now"] = current_mind
|
self.heartflow_data["sub_heartflow_now"] = current_mind
|
||||||
|
|
||||||
def catch_after_llm_generated(self, prompt: str, response: str, reasoning_content: str = "", model_name: str = ""):
|
def catch_after_llm_generated(self, prompt: str, response: str, reasoning_content: str = "", model_name: str = ""):
|
||||||
# if self.response_mode == "heart_flow": # 条件判断不需要了喵~
|
|
||||||
# self.heartflow_data["prompt"] = prompt
|
|
||||||
# self.heartflow_data["response"] = response
|
|
||||||
# self.heartflow_data["model"] = model_name
|
|
||||||
# elif self.response_mode == "reasoning": # 条件判断不需要了喵~
|
|
||||||
# self.reasoning_data["thinking_log"] = reasoning_content
|
|
||||||
# self.reasoning_data["prompt"] = prompt
|
|
||||||
# self.reasoning_data["response"] = response
|
|
||||||
# self.reasoning_data["model"] = model_name
|
|
||||||
|
|
||||||
# 直接记录信息喵~
|
|
||||||
self.reasoning_data["thinking_log"] = reasoning_content
|
self.reasoning_data["thinking_log"] = reasoning_content
|
||||||
self.reasoning_data["prompt"] = prompt
|
self.reasoning_data["prompt"] = prompt
|
||||||
self.reasoning_data["response"] = response
|
self.reasoning_data["response"] = response
|
||||||
self.reasoning_data["model"] = model_name
|
self.reasoning_data["model"] = model_name
|
||||||
# 如果 heartflow 数据也需要通用字段,可以取消下面的注释喵~
|
|
||||||
# self.heartflow_data["prompt"] = prompt
|
|
||||||
# self.heartflow_data["response"] = response
|
|
||||||
# self.heartflow_data["model"] = model_name
|
|
||||||
|
|
||||||
self.response_text = response
|
self.response_text = response
|
||||||
|
|
||||||
@@ -102,6 +85,7 @@ class InfoCatcher:
|
|||||||
):
|
):
|
||||||
self.timing_results["make_response_time"] = response_duration
|
self.timing_results["make_response_time"] = response_duration
|
||||||
self.response_time = time.time()
|
self.response_time = time.time()
|
||||||
|
self.response_messages = []
|
||||||
for msg in response_message:
|
for msg in response_message:
|
||||||
self.response_messages.append(msg)
|
self.response_messages.append(msg)
|
||||||
|
|
||||||
@@ -112,107 +96,112 @@ class InfoCatcher:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def get_message_from_db_between_msgs(message_start: Message, message_end: Message):
|
def get_message_from_db_between_msgs(message_start: Message, message_end: Message):
|
||||||
try:
|
try:
|
||||||
# 从数据库中获取消息的时间戳
|
|
||||||
time_start = message_start.message_info.time
|
time_start = message_start.message_info.time
|
||||||
time_end = message_end.message_info.time
|
time_end = message_end.message_info.time
|
||||||
chat_id = message_start.chat_stream.stream_id
|
chat_id = message_start.chat_stream.stream_id
|
||||||
|
|
||||||
print(f"查询参数: time_start={time_start}, time_end={time_end}, chat_id={chat_id}")
|
print(f"查询参数: time_start={time_start}, time_end={time_end}, chat_id={chat_id}")
|
||||||
|
|
||||||
# 查询数据库,获取 chat_id 相同且时间在 start 和 end 之间的数据
|
messages_between_query = (
|
||||||
messages_between = db.messages.find(
|
Messages.select()
|
||||||
{"chat_id": chat_id, "time": {"$gt": time_start, "$lt": time_end}}
|
.where((Messages.chat_id == chat_id) & (Messages.time > time_start) & (Messages.time < time_end))
|
||||||
).sort("time", -1)
|
.order_by(Messages.time.desc())
|
||||||
|
)
|
||||||
|
|
||||||
result = list(messages_between)
|
result = list(messages_between_query)
|
||||||
print(f"查询结果数量: {len(result)}")
|
print(f"查询结果数量: {len(result)}")
|
||||||
if result:
|
if result:
|
||||||
print(f"第一条消息时间: {result[0]['time']}")
|
print(f"第一条消息时间: {result[0].time}")
|
||||||
print(f"最后一条消息时间: {result[-1]['time']}")
|
print(f"最后一条消息时间: {result[-1].time}")
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"获取消息时出错: {str(e)}")
|
print(f"获取消息时出错: {str(e)}")
|
||||||
|
print(traceback.format_exc())
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def get_message_from_db_before_msg(self, message: MessageRecv):
|
def get_message_from_db_before_msg(self, message: MessageRecv):
|
||||||
# 从数据库中获取消息
|
message_id_val = message.message_info.message_id
|
||||||
message_id = message.message_info.message_id
|
chat_id_val = message.chat_stream.stream_id
|
||||||
chat_id = message.chat_stream.stream_id
|
|
||||||
|
|
||||||
# 查询数据库,获取 chat_id 相同且 message_id 小于当前消息的 30 条数据
|
messages_before_query = (
|
||||||
messages_before = (
|
Messages.select()
|
||||||
db.messages.find({"chat_id": chat_id, "message_id": {"$lt": message_id}})
|
.where((Messages.chat_id == chat_id_val) & (Messages.message_id < message_id_val))
|
||||||
.sort("time", -1)
|
.order_by(Messages.time.desc())
|
||||||
.limit(self.context_length * 3)
|
.limit(global_config.chat.observation_context_size * 3)
|
||||||
) # 获取更多历史信息
|
)
|
||||||
|
|
||||||
return list(messages_before)
|
return list(messages_before_query)
|
||||||
|
|
||||||
def message_list_to_dict(self, message_list):
|
def message_list_to_dict(self, message_list):
|
||||||
# 存储简化的聊天记录
|
|
||||||
result = []
|
result = []
|
||||||
for message in message_list:
|
for msg_item in message_list:
|
||||||
if not isinstance(message, dict):
|
processed_msg_item = msg_item
|
||||||
message = self.message_to_dict(message)
|
if not isinstance(msg_item, dict):
|
||||||
# print(message)
|
processed_msg_item = self.message_to_dict(msg_item)
|
||||||
|
|
||||||
|
if not processed_msg_item:
|
||||||
|
continue
|
||||||
|
|
||||||
lite_message = {
|
lite_message = {
|
||||||
"time": message["time"],
|
"time": processed_msg_item.get("time"),
|
||||||
"user_nickname": message["user_info"]["user_nickname"],
|
"user_nickname": processed_msg_item.get("user_nickname"),
|
||||||
"processed_plain_text": message["processed_plain_text"],
|
"processed_plain_text": processed_msg_item.get("processed_plain_text"),
|
||||||
}
|
}
|
||||||
result.append(lite_message)
|
result.append(lite_message)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def message_to_dict(message):
|
def message_to_dict(msg_obj):
|
||||||
if not message:
|
if not msg_obj:
|
||||||
return None
|
return None
|
||||||
if isinstance(message, dict):
|
if isinstance(msg_obj, dict):
|
||||||
return message
|
return msg_obj
|
||||||
|
|
||||||
|
if isinstance(msg_obj, Messages):
|
||||||
return {
|
return {
|
||||||
# "message_id": message.message_info.message_id,
|
"time": msg_obj.time,
|
||||||
"time": message.message_info.time,
|
"user_id": msg_obj.user_id,
|
||||||
"user_id": message.message_info.user_info.user_id,
|
"user_nickname": msg_obj.user_nickname,
|
||||||
"user_nickname": message.message_info.user_info.user_nickname,
|
"processed_plain_text": msg_obj.processed_plain_text,
|
||||||
"processed_plain_text": message.processed_plain_text,
|
|
||||||
# "detailed_plain_text": message.detailed_plain_text
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if hasattr(msg_obj, "message_info") and hasattr(msg_obj.message_info, "user_info"):
|
||||||
|
return {
|
||||||
|
"time": msg_obj.message_info.time,
|
||||||
|
"user_id": msg_obj.message_info.user_info.user_id,
|
||||||
|
"user_nickname": msg_obj.message_info.user_info.user_nickname,
|
||||||
|
"processed_plain_text": msg_obj.processed_plain_text,
|
||||||
|
}
|
||||||
|
|
||||||
|
print(f"Warning: message_to_dict received an unhandled type: {type(msg_obj)}")
|
||||||
|
return {}
|
||||||
|
|
||||||
def done_catch(self):
|
def done_catch(self):
|
||||||
"""将收集到的信息存储到数据库的 thinking_log 集合中喵~"""
|
"""将收集到的信息存储到数据库的 thinking_log 表中喵~"""
|
||||||
try:
|
try:
|
||||||
# 将消息对象转换为可序列化的字典喵~
|
trigger_info_dict = self.message_to_dict(self.trigger_response_message)
|
||||||
|
response_info_dict = {
|
||||||
thinking_log_data = {
|
|
||||||
"chat_id": self.chat_id,
|
|
||||||
"trigger_text": self.trigger_response_text,
|
|
||||||
"response_text": self.response_text,
|
|
||||||
"trigger_info": {
|
|
||||||
"time": self.trigger_response_time,
|
|
||||||
"message": self.message_to_dict(self.trigger_response_message),
|
|
||||||
},
|
|
||||||
"response_info": {
|
|
||||||
"time": self.response_time,
|
"time": self.response_time,
|
||||||
"message": self.response_messages,
|
"message": self.response_messages,
|
||||||
},
|
|
||||||
"timing_results": self.timing_results,
|
|
||||||
"chat_history": self.message_list_to_dict(self.chat_history),
|
|
||||||
"chat_history_in_thinking": self.message_list_to_dict(self.chat_history_in_thinking),
|
|
||||||
"chat_history_after_response": self.message_list_to_dict(self.chat_history_after_response),
|
|
||||||
"heartflow_data": self.heartflow_data,
|
|
||||||
"reasoning_data": self.reasoning_data,
|
|
||||||
}
|
}
|
||||||
|
chat_history_list = self.message_list_to_dict(self.chat_history)
|
||||||
|
chat_history_in_thinking_list = self.message_list_to_dict(self.chat_history_in_thinking)
|
||||||
|
chat_history_after_response_list = self.message_list_to_dict(self.chat_history_after_response)
|
||||||
|
|
||||||
# 根据不同的响应模式添加相应的数据喵~ # 现在直接都加上去好了喵~
|
log_entry = ThinkingLog(
|
||||||
# if self.response_mode == "heart_flow":
|
chat_id=self.chat_id,
|
||||||
# thinking_log_data["mode_specific_data"] = self.heartflow_data
|
trigger_text=self.trigger_response_text,
|
||||||
# elif self.response_mode == "reasoning":
|
response_text=self.response_text,
|
||||||
# thinking_log_data["mode_specific_data"] = self.reasoning_data
|
trigger_info_json=json.dumps(trigger_info_dict) if trigger_info_dict else None,
|
||||||
|
response_info_json=json.dumps(response_info_dict),
|
||||||
# 将数据插入到 thinking_log 集合中喵~
|
timing_results_json=json.dumps(self.timing_results),
|
||||||
db.thinking_log.insert_one(thinking_log_data)
|
chat_history_json=json.dumps(chat_history_list),
|
||||||
|
chat_history_in_thinking_json=json.dumps(chat_history_in_thinking_list),
|
||||||
|
chat_history_after_response_json=json.dumps(chat_history_after_response_list),
|
||||||
|
heartflow_data_json=json.dumps(self.heartflow_data),
|
||||||
|
reasoning_data_json=json.dumps(self.reasoning_data),
|
||||||
|
)
|
||||||
|
log_entry.save()
|
||||||
|
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -2,10 +2,12 @@ from collections import defaultdict
from datetime import datetime, timedelta
from typing import Any, Dict, Tuple, List


from src.common.logger import get_module_logger
from src.manager.async_task_manager import AsyncTask

-from ...common.database import db
+from ...common.database.database import db # This db is the Peewee database instance
+from ...common.database.database_model import OnlineTime, LLMUsage, Messages # Import the Peewee model
from src.manager.local_store_manager import local_storage

logger = get_module_logger("maibot_statistic")
@@ -39,7 +41,7 @@ class OnlineTimeRecordTask(AsyncTask):
def __init__(self):
super().__init__(task_name="Online Time Record Task", run_interval=60)

-self.record_id: str | None = None
+self.record_id: int | None = None # Changed to int for Peewee's default ID
"""记录ID"""

self._init_database() # 初始化数据库
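With an auto-increment integer key, the run() loop in the next hunk follows an update-then-find-then-create flow. A condensed sketch of that logic using the field names shown in the diff (duration=5 and the extra timestamp field appear there as well):

```python
from datetime import datetime, timedelta

from src.common.database.database_model import OnlineTime  # import path as in the diff


def record_online_minute(record_id):
    """Return the id of the OnlineTime row that now covers the next minute."""
    now = datetime.now()
    extended_end = now + timedelta(minutes=1)

    # 1) Try to extend the row we already know about.
    if record_id is not None:
        if OnlineTime.update(end_timestamp=extended_end).where(OnlineTime.id == record_id).execute():
            return record_id

    # 2) Otherwise reuse a record that ended within the last minute.
    recent = (
        OnlineTime.select()
        .where(OnlineTime.end_timestamp >= now - timedelta(minutes=1))
        .order_by(OnlineTime.end_timestamp.desc())
        .first()
    )
    if recent:
        recent.end_timestamp = extended_end
        recent.save()
        return recent.id

    # 3) Fall back to creating a fresh record.
    return OnlineTime.create(
        timestamp=now.timestamp(),
        start_timestamp=now,
        end_timestamp=extended_end,
        duration=5,
    ).id
```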
@@ -47,49 +49,46 @@ class OnlineTimeRecordTask(AsyncTask):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _init_database():
|
def _init_database():
|
||||||
"""初始化数据库"""
|
"""初始化数据库"""
|
||||||
if "online_time" not in db.list_collection_names():
|
with db.atomic(): # Use atomic operations for schema changes
|
||||||
# 初始化数据库(在线时长)
|
OnlineTime.create_table(safe=True) # Creates table if it doesn't exist, Peewee handles indexes from model
|
||||||
db.create_collection("online_time")
|
|
||||||
# 创建索引
|
|
||||||
if ("end_timestamp", 1) not in db.online_time.list_indexes():
|
|
||||||
db.online_time.create_index([("end_timestamp", 1)])
|
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
try:
|
try:
|
||||||
|
current_time = datetime.now()
|
||||||
|
extended_end_time = current_time + timedelta(minutes=1)
|
||||||
|
|
||||||
if self.record_id:
|
if self.record_id:
|
||||||
# 如果有记录,则更新结束时间
|
# 如果有记录,则更新结束时间
|
||||||
db.online_time.update_one(
|
query = OnlineTime.update(end_timestamp=extended_end_time).where(OnlineTime.id == self.record_id)
|
||||||
{"_id": self.record_id},
|
updated_rows = query.execute()
|
||||||
{
|
if updated_rows == 0:
|
||||||
"$set": {
|
# Record might have been deleted or ID is stale, try to find/create
|
||||||
"end_timestamp": datetime.now() + timedelta(minutes=1),
|
self.record_id = None # Reset record_id to trigger find/create logic below
|
||||||
}
|
|
||||||
},
|
if not self.record_id: # Check again if record_id was reset or initially None
|
||||||
)
|
|
||||||
else:
|
|
||||||
# 如果没有记录,检查一分钟以内是否已有记录
|
# 如果没有记录,检查一分钟以内是否已有记录
|
||||||
current_time = datetime.now()
|
# Look for a record whose end_timestamp is recent enough to be considered ongoing
|
||||||
if recent_record := db.online_time.find_one(
|
recent_record = (
|
||||||
{"end_timestamp": {"$gte": current_time - timedelta(minutes=1)}}
|
OnlineTime.select()
|
||||||
):
|
.where(OnlineTime.end_timestamp >= (current_time - timedelta(minutes=1)))
|
||||||
# 如果有记录,则更新结束时间
|
.order_by(OnlineTime.end_timestamp.desc())
|
||||||
self.record_id = recent_record["_id"]
|
.first()
|
||||||
db.online_time.update_one(
|
|
||||||
{"_id": self.record_id},
|
|
||||||
{
|
|
||||||
"$set": {
|
|
||||||
"end_timestamp": current_time + timedelta(minutes=1),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if recent_record:
|
||||||
|
# 如果有记录,则更新结束时间
|
||||||
|
self.record_id = recent_record.id
|
||||||
|
recent_record.end_timestamp = extended_end_time
|
||||||
|
recent_record.save()
|
||||||
else:
|
else:
|
||||||
# 若没有记录,则插入新的在线时间记录
|
# 若没有记录,则插入新的在线时间记录
|
||||||
self.record_id = db.online_time.insert_one(
|
new_record = OnlineTime.create(
|
||||||
{
|
timestamp=current_time.timestamp(), # 添加此行
|
||||||
"start_timestamp": current_time,
|
start_timestamp=current_time,
|
||||||
"end_timestamp": current_time + timedelta(minutes=1),
|
end_timestamp=extended_end_time,
|
||||||
}
|
duration=5, # 初始时长为5分钟
|
||||||
).inserted_id
|
)
|
||||||
|
self.record_id = new_record.id
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"在线时间记录失败,错误信息:{e}")
|
logger.error(f"在线时间记录失败,错误信息:{e}")
|
||||||
|
|
||||||
@@ -201,35 +200,28 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
|
|
||||||
:param collect_period: 统计时间段
|
:param collect_period: 统计时间段
|
||||||
"""
|
"""
|
||||||
if len(collect_period) <= 0:
|
if not collect_period:
|
||||||
return {}
|
return {}
|
||||||
else:
|
|
||||||
# 排序-按照时间段开始时间降序排列(最晚的时间段在前)
|
# 排序-按照时间段开始时间降序排列(最晚的时间段在前)
|
||||||
collect_period.sort(key=lambda x: x[1], reverse=True)
|
collect_period.sort(key=lambda x: x[1], reverse=True)
|
||||||
|
|
||||||
stats = {
|
stats = {
|
||||||
period_key: {
|
period_key: {
|
||||||
# 总LLM请求数
|
|
||||||
TOTAL_REQ_CNT: 0,
|
TOTAL_REQ_CNT: 0,
|
||||||
# 请求次数统计
|
|
||||||
REQ_CNT_BY_TYPE: defaultdict(int),
|
REQ_CNT_BY_TYPE: defaultdict(int),
|
||||||
REQ_CNT_BY_USER: defaultdict(int),
|
REQ_CNT_BY_USER: defaultdict(int),
|
||||||
REQ_CNT_BY_MODEL: defaultdict(int),
|
REQ_CNT_BY_MODEL: defaultdict(int),
|
||||||
# 输入Token数
|
|
||||||
IN_TOK_BY_TYPE: defaultdict(int),
|
IN_TOK_BY_TYPE: defaultdict(int),
|
||||||
IN_TOK_BY_USER: defaultdict(int),
|
IN_TOK_BY_USER: defaultdict(int),
|
||||||
IN_TOK_BY_MODEL: defaultdict(int),
|
IN_TOK_BY_MODEL: defaultdict(int),
|
||||||
# 输出Token数
|
|
||||||
OUT_TOK_BY_TYPE: defaultdict(int),
|
OUT_TOK_BY_TYPE: defaultdict(int),
|
||||||
OUT_TOK_BY_USER: defaultdict(int),
|
OUT_TOK_BY_USER: defaultdict(int),
|
||||||
OUT_TOK_BY_MODEL: defaultdict(int),
|
OUT_TOK_BY_MODEL: defaultdict(int),
|
||||||
# 总Token数
|
|
||||||
TOTAL_TOK_BY_TYPE: defaultdict(int),
|
TOTAL_TOK_BY_TYPE: defaultdict(int),
|
||||||
TOTAL_TOK_BY_USER: defaultdict(int),
|
TOTAL_TOK_BY_USER: defaultdict(int),
|
||||||
TOTAL_TOK_BY_MODEL: defaultdict(int),
|
TOTAL_TOK_BY_MODEL: defaultdict(int),
|
||||||
# 总开销
|
|
||||||
TOTAL_COST: 0.0,
|
TOTAL_COST: 0.0,
|
||||||
# 请求开销统计
|
|
||||||
COST_BY_TYPE: defaultdict(float),
|
COST_BY_TYPE: defaultdict(float),
|
||||||
COST_BY_USER: defaultdict(float),
|
COST_BY_USER: defaultdict(float),
|
||||||
COST_BY_MODEL: defaultdict(float),
|
COST_BY_MODEL: defaultdict(float),
|
||||||
@@ -238,26 +230,26 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
}
|
}
|
||||||
|
|
||||||
# 以最早的时间戳为起始时间获取记录
|
# 以最早的时间戳为起始时间获取记录
|
||||||
for record in db.llm_usage.find({"timestamp": {"$gte": collect_period[-1][1]}}):
|
# Assuming LLMUsage.timestamp is a DateTimeField
|
||||||
record_timestamp = record.get("timestamp")
|
query_start_time = collect_period[-1][1]
|
||||||
|
for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time):
|
||||||
|
record_timestamp = record.timestamp # This is already a datetime object
|
||||||
for idx, (_, period_start) in enumerate(collect_period):
|
for idx, (_, period_start) in enumerate(collect_period):
|
||||||
if record_timestamp >= period_start:
|
if record_timestamp >= period_start:
|
||||||
# 如果记录时间在当前时间段内,则它一定在更早的时间段内
|
|
||||||
# 因此,我们可以直接跳过更早的时间段的判断,直接更新当前以及更早时间段的统计数据
|
|
||||||
for period_key, _ in collect_period[idx:]:
|
for period_key, _ in collect_period[idx:]:
|
||||||
stats[period_key][TOTAL_REQ_CNT] += 1
|
stats[period_key][TOTAL_REQ_CNT] += 1
|
||||||
|
|
||||||
request_type = record.get("request_type", "unknown") # 请求类型
|
request_type = record.request_type or "unknown"
|
||||||
user_id = str(record.get("user_id", "unknown")) # 用户ID
|
user_id = record.user_id or "unknown" # user_id is TextField, already string
|
||||||
model_name = record.get("model_name", "unknown") # 模型名称
|
model_name = record.model_name or "unknown"
|
||||||
|
|
||||||
stats[period_key][REQ_CNT_BY_TYPE][request_type] += 1
|
stats[period_key][REQ_CNT_BY_TYPE][request_type] += 1
|
||||||
stats[period_key][REQ_CNT_BY_USER][user_id] += 1
|
stats[period_key][REQ_CNT_BY_USER][user_id] += 1
|
||||||
stats[period_key][REQ_CNT_BY_MODEL][model_name] += 1
|
stats[period_key][REQ_CNT_BY_MODEL][model_name] += 1
|
||||||
|
|
||||||
prompt_tokens = record.get("prompt_tokens", 0) # 输入Token数
|
prompt_tokens = record.prompt_tokens or 0
|
||||||
completion_tokens = record.get("completion_tokens", 0) # 输出Token数
|
completion_tokens = record.completion_tokens or 0
|
||||||
total_tokens = prompt_tokens + completion_tokens # Token总数 = 输入Token数 + 输出Token数
|
total_tokens = prompt_tokens + completion_tokens
|
||||||
|
|
||||||
stats[period_key][IN_TOK_BY_TYPE][request_type] += prompt_tokens
|
stats[period_key][IN_TOK_BY_TYPE][request_type] += prompt_tokens
|
||||||
stats[period_key][IN_TOK_BY_USER][user_id] += prompt_tokens
|
stats[period_key][IN_TOK_BY_USER][user_id] += prompt_tokens
|
||||||
@@ -271,13 +263,12 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
stats[period_key][TOTAL_TOK_BY_USER][user_id] += total_tokens
|
stats[period_key][TOTAL_TOK_BY_USER][user_id] += total_tokens
|
||||||
stats[period_key][TOTAL_TOK_BY_MODEL][model_name] += total_tokens
|
stats[period_key][TOTAL_TOK_BY_MODEL][model_name] += total_tokens
|
||||||
|
|
||||||
cost = record.get("cost", 0.0)
|
cost = record.cost or 0.0
|
||||||
stats[period_key][TOTAL_COST] += cost
|
stats[period_key][TOTAL_COST] += cost
|
||||||
stats[period_key][COST_BY_TYPE][request_type] += cost
|
stats[period_key][COST_BY_TYPE][request_type] += cost
|
||||||
stats[period_key][COST_BY_USER][user_id] += cost
|
stats[period_key][COST_BY_USER][user_id] += cost
|
||||||
stats[period_key][COST_BY_MODEL][model_name] += cost
|
stats[period_key][COST_BY_MODEL][model_name] += cost
|
||||||
break # 取消更早时间段的判断
|
break
|
||||||
|
|
||||||
return stats
|
return stats
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
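The `collect_period` argument used by these collectors is a list of `(period_key, period_start)` tuples; the code sorts it most-recent-first and queries records newer than the earliest boundary (`collect_period[-1][1]`), attributing each record to every period whose start it falls after. A plausible construction (the period keys are illustrative, not taken from this diff):

```python
from datetime import datetime, timedelta

now = datetime.now()
collect_period = [
    ("1h", now - timedelta(hours=1)),
    ("1d", now - timedelta(days=1)),
    ("7d", now - timedelta(days=7)),
]
# After the descending sort, a record newer than the "1h" boundary is counted
# in the "1h", "1d" and "7d" buckets; one newer only than "7d" is counted once.
```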
@@ -287,39 +278,38 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
|
|
||||||
:param collect_period: 统计时间段
|
:param collect_period: 统计时间段
|
||||||
"""
|
"""
|
||||||
if len(collect_period) <= 0:
|
if not collect_period:
|
||||||
return {}
|
return {}
|
||||||
else:
|
|
||||||
# 排序-按照时间段开始时间降序排列(最晚的时间段在前)
|
|
||||||
collect_period.sort(key=lambda x: x[1], reverse=True)
|
collect_period.sort(key=lambda x: x[1], reverse=True)
|
||||||
|
|
||||||
stats = {
|
stats = {
|
||||||
period_key: {
|
period_key: {
|
||||||
# 在线时间统计
|
|
||||||
ONLINE_TIME: 0.0,
|
ONLINE_TIME: 0.0,
|
||||||
}
|
}
|
||||||
for period_key, _ in collect_period
|
for period_key, _ in collect_period
|
||||||
}
|
}
|
||||||
|
|
||||||
# 统计在线时间
|
query_start_time = collect_period[-1][1]
|
||||||
for record in db.online_time.find({"end_timestamp": {"$gte": collect_period[-1][1]}}):
|
# Assuming OnlineTime.end_timestamp is a DateTimeField
|
||||||
end_timestamp: datetime = record.get("end_timestamp")
|
for record in OnlineTime.select().where(OnlineTime.end_timestamp >= query_start_time):
|
||||||
for idx, (_, period_start) in enumerate(collect_period):
|
# record.end_timestamp and record.start_timestamp are datetime objects
|
||||||
if end_timestamp >= period_start:
|
record_end_timestamp = record.end_timestamp
|
||||||
# 由于end_timestamp会超前标记时间,所以我们需要判断是否晚于当前时间,如果是,则使用当前时间作为结束时间
|
record_start_timestamp = record.start_timestamp
|
||||||
end_timestamp = min(end_timestamp, now)
|
|
||||||
# 如果记录时间在当前时间段内,则它一定在更早的时间段内
|
|
||||||
# 因此,我们可以直接跳过更早的时间段的判断,直接更新当前以及更早时间段的统计数据
|
|
||||||
for period_key, _period_start in collect_period[idx:]:
|
|
||||||
start_timestamp: datetime = record.get("start_timestamp")
|
|
||||||
if start_timestamp < _period_start:
|
|
||||||
# 如果开始时间在查询边界之前,则使用开始时间
|
|
||||||
stats[period_key][ONLINE_TIME] += (end_timestamp - _period_start).total_seconds()
|
|
||||||
else:
|
|
||||||
# 否则,使用开始时间
|
|
||||||
stats[period_key][ONLINE_TIME] += (end_timestamp - start_timestamp).total_seconds()
|
|
||||||
break # 取消更早时间段的判断
|
|
||||||
|
|
||||||
|
for idx, (_, period_boundary_start) in enumerate(collect_period):
|
||||||
|
if record_end_timestamp >= period_boundary_start:
|
||||||
|
# Calculate effective end time for this record in relation to 'now'
|
||||||
|
effective_end_time = min(record_end_timestamp, now)
|
||||||
|
|
||||||
|
for period_key, current_period_start_time in collect_period[idx:]:
|
||||||
|
# Determine the portion of the record that falls within this specific statistical period
|
||||||
|
overlap_start = max(record_start_timestamp, current_period_start_time)
|
||||||
|
overlap_end = effective_end_time # Already capped by 'now' and record's own end
|
||||||
|
|
||||||
|
if overlap_end > overlap_start:
|
||||||
|
stats[period_key][ONLINE_TIME] += (overlap_end - overlap_start).total_seconds()
|
||||||
|
break
|
||||||
return stats
|
return stats
|
||||||
|
|
||||||
def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
|
def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
|
||||||
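The rewritten online-time loop above credits each record only for the portion that overlaps a statistical period, capped at `now` because `end_timestamp` is written ahead of time. The interval intersection it relies on reduces to a small helper (a sketch, not part of this commit):

```python
from datetime import datetime


def overlap_seconds(record_start: datetime, record_end: datetime,
                    period_start: datetime, now: datetime) -> float:
    """Seconds of [record_start, record_end] that fall inside [period_start, now]."""
    effective_end = min(record_end, now)        # end_timestamp is pre-extended, so cap at now
    overlap_start = max(record_start, period_start)
    return max((effective_end - overlap_start).total_seconds(), 0.0)
```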
@@ -328,55 +318,57 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
|
|
||||||
:param collect_period: 统计时间段
|
:param collect_period: 统计时间段
|
||||||
"""
|
"""
|
||||||
if len(collect_period) <= 0:
|
if not collect_period:
|
||||||
return {}
|
return {}
|
||||||
else:
|
|
||||||
# 排序-按照时间段开始时间降序排列(最晚的时间段在前)
|
|
||||||
collect_period.sort(key=lambda x: x[1], reverse=True)
|
collect_period.sort(key=lambda x: x[1], reverse=True)
|
||||||
|
|
||||||
stats = {
|
stats = {
|
||||||
period_key: {
|
period_key: {
|
||||||
# 消息统计
|
|
||||||
TOTAL_MSG_CNT: 0,
|
TOTAL_MSG_CNT: 0,
|
||||||
MSG_CNT_BY_CHAT: defaultdict(int),
|
MSG_CNT_BY_CHAT: defaultdict(int),
|
||||||
}
|
}
|
||||||
for period_key, _ in collect_period
|
for period_key, _ in collect_period
|
||||||
}
|
}
|
||||||
|
|
||||||
# 统计消息量
|
query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp)
|
||||||
for message in db.messages.find({"time": {"$gte": collect_period[-1][1].timestamp()}}):
|
for message in Messages.select().where(Messages.time >= query_start_timestamp):
|
||||||
chat_info = message.get("chat_info", None) # 聊天信息
|
message_time_ts = message.time # This is a float timestamp
|
||||||
user_info = message.get("user_info", None) # 用户信息(消息发送人)
|
|
||||||
message_time = message.get("time", 0) # 消息时间
|
|
||||||
|
|
||||||
group_info = chat_info.get("group_info") if chat_info else None # 尝试获取群聊信息
|
chat_id = None
|
||||||
if group_info is not None:
|
chat_name = None
|
||||||
# 若有群聊信息
|
|
||||||
chat_id = f"g{group_info.get('group_id')}"
|
# Logic based on Peewee model structure, aiming to replicate original intent
|
||||||
chat_name = group_info.get("group_name", f"群{group_info.get('group_id')}")
|
if message.chat_info_group_id:
|
||||||
elif user_info:
|
chat_id = f"g{message.chat_info_group_id}"
|
||||||
# 若没有群聊信息,则尝试获取用户信息
|
chat_name = message.chat_info_group_name or f"群{message.chat_info_group_id}"
|
||||||
chat_id = f"u{user_info['user_id']}"
|
elif message.user_id: # Fallback to sender's info for chat_id if not a group_info based chat
|
||||||
chat_name = user_info["user_nickname"]
|
# This uses the message SENDER's ID as per original logic's fallback
|
||||||
|
chat_id = f"u{message.user_id}" # SENDER's user_id
|
||||||
|
chat_name = message.user_nickname # SENDER's nickname
|
||||||
else:
|
else:
|
||||||
continue # 如果没有群组信息也没有用户信息,则跳过
|
# If neither group_id nor sender_id is available for chat identification
|
||||||
|
logger.warning(
|
||||||
|
f"Message (PK: {message.id if hasattr(message, 'id') else 'N/A'}) lacks group_id and user_id for chat stats."
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not chat_id: # Should not happen if above logic is correct
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Update name_mapping
|
||||||
if chat_id in self.name_mapping:
|
if chat_id in self.name_mapping:
|
||||||
if chat_name != self.name_mapping[chat_id][0] and message_time > self.name_mapping[chat_id][1]:
|
if chat_name != self.name_mapping[chat_id][0] and message_time_ts > self.name_mapping[chat_id][1]:
|
||||||
# 如果用户名称不同,且新消息时间晚于之前记录的时间,则更新用户名称
|
self.name_mapping[chat_id] = (chat_name, message_time_ts)
|
||||||
self.name_mapping[chat_id] = (chat_name, message_time)
|
|
||||||
else:
|
else:
|
||||||
self.name_mapping[chat_id] = (chat_name, message_time)
|
self.name_mapping[chat_id] = (chat_name, message_time_ts)
|
||||||
|
|
||||||
for idx, (_, period_start) in enumerate(collect_period):
|
for idx, (_, period_start_dt) in enumerate(collect_period):
|
||||||
if message_time >= period_start.timestamp():
|
if message_time_ts >= period_start_dt.timestamp():
|
||||||
# 如果记录时间在当前时间段内,则它一定在更早的时间段内
|
|
||||||
# 因此,我们可以直接跳过更早的时间段的判断,直接更新当前以及更早时间段的统计数据
|
|
||||||
for period_key, _ in collect_period[idx:]:
|
for period_key, _ in collect_period[idx:]:
|
||||||
stats[period_key][TOTAL_MSG_CNT] += 1
|
stats[period_key][TOTAL_MSG_CNT] += 1
|
||||||
stats[period_key][MSG_CNT_BY_CHAT][chat_id] += 1
|
stats[period_key][MSG_CNT_BY_CHAT][chat_id] += 1
|
||||||
break
|
break
|
||||||
|
|
||||||
return stats
|
return stats
|
||||||
|
|
||||||
def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]:
|
def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]:
|
||||||
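Note that `Messages.time` is stored as a float Unix timestamp (a `DoubleField`), so the datetime period boundaries above are converted with `.timestamp()` before comparison. A minimal query sketch under that assumption (the import path is assumed):

```python
import time
from src.common.database.database_model import Messages  # assumed import path

one_day_ago = time.time() - 24 * 3600  # float seconds, same unit as Messages.time
for msg in Messages.select().where(Messages.time >= one_day_ago):
    print(msg.chat_id, msg.time, (msg.processed_plain_text or "")[:30])
```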
|
|||||||
@@ -13,7 +13,7 @@ from src.manager.mood_manager import mood_manager
|
|||||||
from ..message_receive.message import MessageRecv
|
from ..message_receive.message import MessageRecv
|
||||||
from ..models.utils_model import LLMRequest
|
from ..models.utils_model import LLMRequest
|
||||||
from .typo_generator import ChineseTypoGenerator
|
from .typo_generator import ChineseTypoGenerator
|
||||||
from ...common.database import db
|
from ...common.database.database import db
|
||||||
from ...config.config import global_config
|
from ...config.config import global_config
|
||||||
|
|
||||||
logger = get_module_logger("chat_utils")
|
logger = get_module_logger("chat_utils")
|
||||||
@@ -43,8 +43,8 @@ def db_message_to_str(message_dict: dict) -> str:
|
|||||||
|
|
||||||
def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
|
def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
|
||||||
"""检查消息是否提到了机器人"""
|
"""检查消息是否提到了机器人"""
|
||||||
keywords = [global_config.BOT_NICKNAME]
|
keywords = [global_config.bot.nickname]
|
||||||
nicknames = global_config.BOT_ALIAS_NAMES
|
nicknames = global_config.bot.alias_names
|
||||||
reply_probability = 0.0
|
reply_probability = 0.0
|
||||||
is_at = False
|
is_at = False
|
||||||
is_mentioned = False
|
is_mentioned = False
|
||||||
@@ -64,18 +64,18 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 判断是否被@
|
# 判断是否被@
|
||||||
if re.search(f"@[\s\S]*?(id:{global_config.BOT_QQ})", message.processed_plain_text):
|
if re.search(f"@[\s\S]*?(id:{global_config.bot.qq_account})", message.processed_plain_text):
|
||||||
is_at = True
|
is_at = True
|
||||||
is_mentioned = True
|
is_mentioned = True
|
||||||
|
|
||||||
if is_at and global_config.at_bot_inevitable_reply:
|
if is_at and global_config.normal_chat.at_bot_inevitable_reply:
|
||||||
reply_probability = 1.0
|
reply_probability = 1.0
|
||||||
logger.info("被@,回复概率设置为100%")
|
logger.info("被@,回复概率设置为100%")
|
||||||
else:
|
else:
|
||||||
if not is_mentioned:
|
if not is_mentioned:
|
||||||
# 判断是否被回复
|
# 判断是否被回复
|
||||||
if re.match(
|
if re.match(
|
||||||
f"\[回复 [\s\S]*?\({str(global_config.BOT_QQ)}\):[\s\S]*?],说:", message.processed_plain_text
|
f"\[回复 [\s\S]*?\({str(global_config.bot.qq_account)}\):[\s\S]*?],说:", message.processed_plain_text
|
||||||
):
|
):
|
||||||
is_mentioned = True
|
is_mentioned = True
|
||||||
else:
|
else:
|
||||||
@@ -88,7 +88,7 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
|
|||||||
for nickname in nicknames:
|
for nickname in nicknames:
|
||||||
if nickname in message_content:
|
if nickname in message_content:
|
||||||
is_mentioned = True
|
is_mentioned = True
|
||||||
if is_mentioned and global_config.mentioned_bot_inevitable_reply:
|
if is_mentioned and global_config.normal_chat.mentioned_bot_inevitable_reply:
|
||||||
reply_probability = 1.0
|
reply_probability = 1.0
|
||||||
logger.info("被提及,回复概率设置为100%")
|
logger.info("被提及,回复概率设置为100%")
|
||||||
return is_mentioned, reply_probability
|
return is_mentioned, reply_probability
|
||||||
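A side note on the mention checks above: the f-string patterns embed `\s`, `\S`, `\[` and `\(` in non-raw literals, which Python accepts but flags as invalid escape sequences. An equivalent formulation with raw f-strings, written as a standalone helper for clarity (the helper name is illustrative; behavior matches the checks above):

```python
import re


def mention_flags(text: str, qq_account: str) -> tuple[bool, bool]:
    """Return (is_at, is_replied) using raw-string patterns equivalent to the ones above."""
    at_pattern = rf"@[\s\S]*?(id:{qq_account})"
    reply_pattern = rf"\[回复 [\s\S]*?\({qq_account}\):[\s\S]*?],说:"
    return bool(re.search(at_pattern, text)), bool(re.match(reply_pattern, text))
```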
@@ -96,7 +96,8 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
|
|||||||
|
|
||||||
async def get_embedding(text, request_type="embedding"):
|
async def get_embedding(text, request_type="embedding"):
|
||||||
"""获取文本的embedding向量"""
|
"""获取文本的embedding向量"""
|
||||||
llm = LLMRequest(model=global_config.embedding, request_type=request_type)
|
# TODO: API-Adapter修改标记
|
||||||
|
llm = LLMRequest(model=global_config.model.embedding, request_type=request_type)
|
||||||
# return llm.get_embedding_sync(text)
|
# return llm.get_embedding_sync(text)
|
||||||
try:
|
try:
|
||||||
embedding = await llm.get_embedding(text)
|
embedding = await llm.get_embedding(text)
|
||||||
@@ -163,7 +164,7 @@ def get_recent_group_speaker(chat_stream_id: int, sender, limit: int = 12) -> li
|
|||||||
user_info = UserInfo.from_dict(msg_db_data["user_info"])
|
user_info = UserInfo.from_dict(msg_db_data["user_info"])
|
||||||
if (
|
if (
|
||||||
(user_info.platform, user_info.user_id) != sender
|
(user_info.platform, user_info.user_id) != sender
|
||||||
and user_info.user_id != global_config.BOT_QQ
|
and user_info.user_id != global_config.bot.qq_account
|
||||||
and (user_info.platform, user_info.user_id, user_info.user_nickname) not in who_chat_in_group
|
and (user_info.platform, user_info.user_id, user_info.user_nickname) not in who_chat_in_group
|
||||||
and len(who_chat_in_group) < 5
|
and len(who_chat_in_group) < 5
|
||||||
): # 排除重复,排除消息发送者,排除bot,限制加载的关系数目
|
): # 排除重复,排除消息发送者,排除bot,限制加载的关系数目
|
||||||
@@ -321,7 +322,7 @@ def random_remove_punctuation(text: str) -> str:
|
|||||||
|
|
||||||
def process_llm_response(text: str) -> list[str]:
|
def process_llm_response(text: str) -> list[str]:
|
||||||
# 先保护颜文字
|
# 先保护颜文字
|
||||||
if global_config.enable_kaomoji_protection:
|
if global_config.response_splitter.enable_kaomoji_protection:
|
||||||
protected_text, kaomoji_mapping = protect_kaomoji(text)
|
protected_text, kaomoji_mapping = protect_kaomoji(text)
|
||||||
logger.trace(f"保护颜文字后的文本: {protected_text}")
|
logger.trace(f"保护颜文字后的文本: {protected_text}")
|
||||||
else:
|
else:
|
||||||
@@ -340,8 +341,8 @@ def process_llm_response(text: str) -> list[str]:
|
|||||||
logger.debug(f"{text}去除括号处理后的文本: {cleaned_text}")
|
logger.debug(f"{text}去除括号处理后的文本: {cleaned_text}")
|
||||||
|
|
||||||
# 对清理后的文本进行进一步处理
|
# 对清理后的文本进行进一步处理
|
||||||
max_length = global_config.response_max_length * 2
|
max_length = global_config.response_splitter.max_length * 2
|
||||||
max_sentence_num = global_config.response_max_sentence_num
|
max_sentence_num = global_config.response_splitter.max_sentence_num
|
||||||
# 如果基本上是中文,则进行长度过滤
|
# 如果基本上是中文,则进行长度过滤
|
||||||
if get_western_ratio(cleaned_text) < 0.1:
|
if get_western_ratio(cleaned_text) < 0.1:
|
||||||
if len(cleaned_text) > max_length:
|
if len(cleaned_text) > max_length:
|
||||||
@@ -349,20 +350,20 @@ def process_llm_response(text: str) -> list[str]:
|
|||||||
return ["懒得说"]
|
return ["懒得说"]
|
||||||
|
|
||||||
typo_generator = ChineseTypoGenerator(
|
typo_generator = ChineseTypoGenerator(
|
||||||
error_rate=global_config.chinese_typo_error_rate,
|
error_rate=global_config.chinese_typo.error_rate,
|
||||||
min_freq=global_config.chinese_typo_min_freq,
|
min_freq=global_config.chinese_typo.min_freq,
|
||||||
tone_error_rate=global_config.chinese_typo_tone_error_rate,
|
tone_error_rate=global_config.chinese_typo.tone_error_rate,
|
||||||
word_replace_rate=global_config.chinese_typo_word_replace_rate,
|
word_replace_rate=global_config.chinese_typo.word_replace_rate,
|
||||||
)
|
)
|
||||||
|
|
||||||
if global_config.enable_response_splitter:
|
if global_config.response_splitter.enable:
|
||||||
split_sentences = split_into_sentences_w_remove_punctuation(cleaned_text)
|
split_sentences = split_into_sentences_w_remove_punctuation(cleaned_text)
|
||||||
else:
|
else:
|
||||||
split_sentences = [cleaned_text]
|
split_sentences = [cleaned_text]
|
||||||
|
|
||||||
sentences = []
|
sentences = []
|
||||||
for sentence in split_sentences:
|
for sentence in split_sentences:
|
||||||
if global_config.chinese_typo_enable:
|
if global_config.chinese_typo.enable:
|
||||||
typoed_text, typo_corrections = typo_generator.create_typo_sentence(sentence)
|
typoed_text, typo_corrections = typo_generator.create_typo_sentence(sentence)
|
||||||
sentences.append(typoed_text)
|
sentences.append(typoed_text)
|
||||||
if typo_corrections:
|
if typo_corrections:
|
||||||
@@ -372,7 +373,7 @@ def process_llm_response(text: str) -> list[str]:
|
|||||||
|
|
||||||
if len(sentences) > max_sentence_num:
|
if len(sentences) > max_sentence_num:
|
||||||
logger.warning(f"分割后消息数量过多 ({len(sentences)} 条),返回默认回复")
|
logger.warning(f"分割后消息数量过多 ({len(sentences)} 条),返回默认回复")
|
||||||
return [f"{global_config.BOT_NICKNAME}不知道哦"]
|
return [f"{global_config.bot.nickname}不知道哦"]
|
||||||
|
|
||||||
# if extracted_contents:
|
# if extracted_contents:
|
||||||
# for content in extracted_contents:
|
# for content in extracted_contents:
|
||||||
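The call sites above reflect the move from flat attributes (`global_config.response_max_length`, `global_config.chinese_typo_enable`, …) to grouped sections (`global_config.response_splitter.max_length`, `global_config.chinese_typo.enable`). A minimal sketch of what such grouped sections could look like — field names are inferred from the call sites, and the default values are placeholders, not taken from `official_configs.py`:

```python
from dataclasses import dataclass, field


@dataclass
class ResponseSplitterConfig:
    enable: bool = True
    max_length: int = 256
    max_sentence_num: int = 3
    enable_kaomoji_protection: bool = True


@dataclass
class ChineseTypoConfig:
    enable: bool = True
    error_rate: float = 0.01
    min_freq: int = 7
    tone_error_rate: float = 0.1
    word_replace_rate: float = 0.02


@dataclass
class GlobalConfig:
    response_splitter: ResponseSplitterConfig = field(default_factory=ResponseSplitterConfig)
    chinese_typo: ChineseTypoConfig = field(default_factory=ChineseTypoConfig)
```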
|
|||||||
@@ -8,7 +8,8 @@ import io
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
from ...common.database import db
|
from ...common.database.database import db
|
||||||
|
from ...common.database.database_model import Images, ImageDescriptions
|
||||||
from ...config.config import global_config
|
from ...config.config import global_config
|
||||||
from ..models.utils_model import LLMRequest
|
from ..models.utils_model import LLMRequest
|
||||||
|
|
||||||
@@ -32,40 +33,23 @@ class ImageManager:
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
if not self._initialized:
|
if not self._initialized:
|
||||||
self._ensure_image_collection()
|
|
||||||
self._ensure_description_collection()
|
|
||||||
self._ensure_image_dir()
|
self._ensure_image_dir()
|
||||||
|
|
||||||
|
self._initialized = True
|
||||||
|
self._llm = LLMRequest(model=global_config.model.vlm, temperature=0.4, max_tokens=300, request_type="image")
|
||||||
|
|
||||||
|
try:
|
||||||
|
db.connect(reuse_if_open=True)
|
||||||
|
db.create_tables([Images, ImageDescriptions], safe=True)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"数据库连接或表创建失败: {e}")
|
||||||
|
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
self._llm = LLMRequest(model=global_config.vlm, temperature=0.4, max_tokens=300, request_type="image")
|
|
||||||
|
|
||||||
def _ensure_image_dir(self):
|
def _ensure_image_dir(self):
|
||||||
"""确保图像存储目录存在"""
|
"""确保图像存储目录存在"""
|
||||||
os.makedirs(self.IMAGE_DIR, exist_ok=True)
|
os.makedirs(self.IMAGE_DIR, exist_ok=True)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _ensure_image_collection():
|
|
||||||
"""确保images集合存在并创建索引"""
|
|
||||||
if "images" not in db.list_collection_names():
|
|
||||||
db.create_collection("images")
|
|
||||||
|
|
||||||
# 删除旧索引
|
|
||||||
db.images.drop_indexes()
|
|
||||||
# 创建新的复合索引
|
|
||||||
db.images.create_index([("hash", 1), ("type", 1)], unique=True)
|
|
||||||
db.images.create_index([("url", 1)])
|
|
||||||
db.images.create_index([("path", 1)])
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _ensure_description_collection():
|
|
||||||
"""确保image_descriptions集合存在并创建索引"""
|
|
||||||
if "image_descriptions" not in db.list_collection_names():
|
|
||||||
db.create_collection("image_descriptions")
|
|
||||||
|
|
||||||
# 删除旧索引
|
|
||||||
db.image_descriptions.drop_indexes()
|
|
||||||
# 创建新的复合索引
|
|
||||||
db.image_descriptions.create_index([("hash", 1), ("type", 1)], unique=True)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_description_from_db(image_hash: str, description_type: str) -> Optional[str]:
|
def _get_description_from_db(image_hash: str, description_type: str) -> Optional[str]:
|
||||||
"""从数据库获取图片描述
|
"""从数据库获取图片描述
|
||||||
@@ -77,8 +61,14 @@ class ImageManager:
|
|||||||
Returns:
|
Returns:
|
||||||
Optional[str]: 描述文本,如果不存在则返回None
|
Optional[str]: 描述文本,如果不存在则返回None
|
||||||
"""
|
"""
|
||||||
result = db.image_descriptions.find_one({"hash": image_hash, "type": description_type})
|
try:
|
||||||
return result["description"] if result else None
|
record = ImageDescriptions.get_or_none(
|
||||||
|
(ImageDescriptions.image_description_hash == image_hash) & (ImageDescriptions.type == description_type)
|
||||||
|
)
|
||||||
|
return record.description if record else None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"从数据库获取描述失败 (Peewee): {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _save_description_to_db(image_hash: str, description: str, description_type: str) -> None:
|
def _save_description_to_db(image_hash: str, description: str, description_type: str) -> None:
|
||||||
@@ -90,20 +80,17 @@ class ImageManager:
|
|||||||
description_type: 描述类型 ('emoji' 或 'image')
|
description_type: 描述类型 ('emoji' 或 'image')
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
db.image_descriptions.update_one(
|
current_timestamp = time.time()
|
||||||
{"hash": image_hash, "type": description_type},
|
defaults = {"description": description, "timestamp": current_timestamp}
|
||||||
{
|
desc_obj, created = ImageDescriptions.get_or_create(
|
||||||
"$set": {
|
image_description_hash=image_hash, type=description_type, defaults=defaults
|
||||||
"description": description,
|
|
||||||
"timestamp": int(time.time()),
|
|
||||||
"hash": image_hash, # 确保hash字段存在
|
|
||||||
"type": description_type, # 确保type字段存在
|
|
||||||
}
|
|
||||||
},
|
|
||||||
upsert=True,
|
|
||||||
)
|
)
|
||||||
|
if not created: # 如果记录已存在,则更新
|
||||||
|
desc_obj.description = description
|
||||||
|
desc_obj.timestamp = current_timestamp
|
||||||
|
desc_obj.save()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"保存描述到数据库失败: {str(e)}")
|
logger.error(f"保存描述到数据库失败 (Peewee): {str(e)}")
|
||||||
|
|
||||||
async def get_emoji_description(self, image_base64: str) -> str:
|
async def get_emoji_description(self, image_base64: str) -> str:
|
||||||
"""获取表情包描述,带查重和保存功能"""
|
"""获取表情包描述,带查重和保存功能"""
|
||||||
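The new `_save_description_to_db` emulates Mongo's upsert with Peewee's `get_or_create` plus an explicit update when the row already exists, keyed on the model's `image_description_hash` and `type` columns. The pattern in isolation (import path assumed):

```python
import time
from src.common.database.database_model import ImageDescriptions  # assumed import path


def upsert_description(image_hash: str, description: str, description_type: str) -> None:
    """Insert the description, or refresh it if a row for (hash, type) already exists."""
    now = time.time()
    record, created = ImageDescriptions.get_or_create(
        image_description_hash=image_hash,
        type=description_type,
        defaults={"description": description, "timestamp": now},
    )
    if not created:  # row already existed: update in place
        record.description = description
        record.timestamp = now
        record.save()
```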
@@ -116,51 +103,64 @@ class ImageManager:
|
|||||||
# 查询缓存的描述
|
# 查询缓存的描述
|
||||||
cached_description = self._get_description_from_db(image_hash, "emoji")
|
cached_description = self._get_description_from_db(image_hash, "emoji")
|
||||||
if cached_description:
|
if cached_description:
|
||||||
# logger.debug(f"缓存表情包描述: {cached_description}")
|
|
||||||
return f"[表情包,含义看起来是:{cached_description}]"
|
return f"[表情包,含义看起来是:{cached_description}]"
|
||||||
|
|
||||||
# 调用AI获取描述
|
# 调用AI获取描述
|
||||||
if image_format == "gif" or image_format == "GIF":
|
if image_format == "gif" or image_format == "GIF":
|
||||||
image_base64 = self.transform_gif(image_base64)
|
image_base64_processed = self.transform_gif(image_base64)
|
||||||
|
if image_base64_processed is None:
|
||||||
|
logger.warning("GIF转换失败,无法获取描述")
|
||||||
|
return "[表情包(GIF处理失败)]"
|
||||||
prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,使用1-2个词描述一下表情包表达的情感和内容,简短一些"
|
prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,使用1-2个词描述一下表情包表达的情感和内容,简短一些"
|
||||||
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, "jpg")
|
description, _ = await self._llm.generate_response_for_image(prompt, image_base64_processed, "jpg")
|
||||||
else:
|
else:
|
||||||
prompt = "这是一个表情包,请用使用几个词描述一下表情包所表达的情感和内容,简短一些"
|
prompt = "这是一个表情包,请用使用几个词描述一下表情包所表达的情感和内容,简短一些"
|
||||||
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
|
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
|
||||||
|
|
||||||
|
if description is None:
|
||||||
|
logger.warning("AI未能生成表情包描述")
|
||||||
|
return "[表情包(描述生成失败)]"
|
||||||
|
|
||||||
|
# 再次检查缓存,防止并发写入时重复生成
|
||||||
cached_description = self._get_description_from_db(image_hash, "emoji")
|
cached_description = self._get_description_from_db(image_hash, "emoji")
|
||||||
if cached_description:
|
if cached_description:
|
||||||
logger.warning(f"虽然生成了描述,但是找到缓存表情包描述: {cached_description}")
|
logger.warning(f"虽然生成了描述,但是找到缓存表情包描述: {cached_description}")
|
||||||
return f"[表情包,含义看起来是:{cached_description}]"
|
return f"[表情包,含义看起来是:{cached_description}]"
|
||||||
|
|
||||||
# 根据配置决定是否保存图片
|
# 根据配置决定是否保存图片
|
||||||
if global_config.save_emoji:
|
if global_config.emoji.save_emoji:
|
||||||
# 生成文件名和路径
|
# 生成文件名和路径
|
||||||
timestamp = int(time.time())
|
current_timestamp = time.time()
|
||||||
filename = f"{timestamp}_{image_hash[:8]}.{image_format}"
|
filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}"
|
||||||
if not os.path.exists(os.path.join(self.IMAGE_DIR, "emoji")):
|
emoji_dir = os.path.join(self.IMAGE_DIR, "emoji")
|
||||||
os.makedirs(os.path.join(self.IMAGE_DIR, "emoji"))
|
os.makedirs(emoji_dir, exist_ok=True)
|
||||||
file_path = os.path.join(self.IMAGE_DIR, "emoji", filename)
|
file_path = os.path.join(emoji_dir, filename)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 保存文件
|
# 保存文件
|
||||||
with open(file_path, "wb") as f:
|
with open(file_path, "wb") as f:
|
||||||
f.write(image_bytes)
|
f.write(image_bytes)
|
||||||
|
|
||||||
# 保存到数据库
|
# 保存到数据库 (Images表)
|
||||||
image_doc = {
|
try:
|
||||||
"hash": image_hash,
|
img_obj = Images.get((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
|
||||||
"path": file_path,
|
img_obj.path = file_path
|
||||||
"type": "emoji",
|
img_obj.description = description
|
||||||
"description": description,
|
img_obj.timestamp = current_timestamp
|
||||||
"timestamp": timestamp,
|
img_obj.save()
|
||||||
}
|
except Images.DoesNotExist:
|
||||||
db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True)
|
Images.create(
|
||||||
logger.trace(f"保存表情包: {file_path}")
|
emoji_hash=image_hash,
|
||||||
|
path=file_path,
|
||||||
|
type="emoji",
|
||||||
|
description=description,
|
||||||
|
timestamp=current_timestamp,
|
||||||
|
)
|
||||||
|
logger.trace(f"保存表情包元数据: {file_path}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"保存表情包文件失败: {str(e)}")
|
logger.error(f"保存表情包文件或元数据失败: {str(e)}")
|
||||||
|
|
||||||
# 保存描述到数据库
|
# 保存描述到数据库 (ImageDescriptions表)
|
||||||
self._save_description_to_db(image_hash, description, "emoji")
|
self._save_description_to_db(image_hash, description, "emoji")
|
||||||
|
|
||||||
return f"[表情包:{description}]"
|
return f"[表情包:{description}]"
|
||||||
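Saving the emoji metadata above uses the "get, update, or create on `DoesNotExist`" idiom rather than `get_or_create`, because the lookup is a compound condition on hash and type. In isolation, with field names following the `Images` model defined later in this diff (import path assumed):

```python
import time
from src.common.database.database_model import Images  # assumed import path


def upsert_image(image_hash: str, file_path: str, image_type: str, description: str) -> None:
    """Update the Images row matching (hash, type), or create it if none exists."""
    now = time.time()
    try:
        img = Images.get((Images.emoji_hash == image_hash) & (Images.type == image_type))
        img.path = file_path
        img.description = description
        img.timestamp = now
        img.save()
    except Images.DoesNotExist:
        Images.create(
            emoji_hash=image_hash,
            path=file_path,
            type=image_type,
            description=description,
            timestamp=now,
        )
```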
@@ -188,6 +188,11 @@ class ImageManager:
|
|||||||
)
|
)
|
||||||
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
|
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
|
||||||
|
|
||||||
|
if description is None:
|
||||||
|
logger.warning("AI未能生成图片描述")
|
||||||
|
return "[图片(描述生成失败)]"
|
||||||
|
|
||||||
|
# 再次检查缓存
|
||||||
cached_description = self._get_description_from_db(image_hash, "image")
|
cached_description = self._get_description_from_db(image_hash, "image")
|
||||||
if cached_description:
|
if cached_description:
|
||||||
logger.warning(f"虽然生成了描述,但是找到缓存图片描述 {cached_description}")
|
logger.warning(f"虽然生成了描述,但是找到缓存图片描述 {cached_description}")
|
||||||
@@ -195,38 +200,40 @@ class ImageManager:
|
|||||||
|
|
||||||
logger.debug(f"描述是{description}")
|
logger.debug(f"描述是{description}")
|
||||||
|
|
||||||
if description is None:
|
|
||||||
logger.warning("AI未能生成图片描述")
|
|
||||||
return "[图片]"
|
|
||||||
|
|
||||||
# 根据配置决定是否保存图片
|
# 根据配置决定是否保存图片
|
||||||
if global_config.save_pic:
|
if global_config.emoji.save_pic:
|
||||||
# 生成文件名和路径
|
# 生成文件名和路径
|
||||||
timestamp = int(time.time())
|
current_timestamp = time.time()
|
||||||
filename = f"{timestamp}_{image_hash[:8]}.{image_format}"
|
filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}"
|
||||||
if not os.path.exists(os.path.join(self.IMAGE_DIR, "image")):
|
image_dir = os.path.join(self.IMAGE_DIR, "image")
|
||||||
os.makedirs(os.path.join(self.IMAGE_DIR, "image"))
|
os.makedirs(image_dir, exist_ok=True)
|
||||||
file_path = os.path.join(self.IMAGE_DIR, "image", filename)
|
file_path = os.path.join(image_dir, filename)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 保存文件
|
# 保存文件
|
||||||
with open(file_path, "wb") as f:
|
with open(file_path, "wb") as f:
|
||||||
f.write(image_bytes)
|
f.write(image_bytes)
|
||||||
|
|
||||||
# 保存到数据库
|
# 保存到数据库 (Images表)
|
||||||
image_doc = {
|
try:
|
||||||
"hash": image_hash,
|
img_obj = Images.get((Images.emoji_hash == image_hash) & (Images.type == "image"))
|
||||||
"path": file_path,
|
img_obj.path = file_path
|
||||||
"type": "image",
|
img_obj.description = description
|
||||||
"description": description,
|
img_obj.timestamp = current_timestamp
|
||||||
"timestamp": timestamp,
|
img_obj.save()
|
||||||
}
|
except Images.DoesNotExist:
|
||||||
db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True)
|
Images.create(
|
||||||
logger.trace(f"保存图片: {file_path}")
|
emoji_hash=image_hash,
|
||||||
|
path=file_path,
|
||||||
|
type="image",
|
||||||
|
description=description,
|
||||||
|
timestamp=current_timestamp,
|
||||||
|
)
|
||||||
|
logger.trace(f"保存图片元数据: {file_path}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"保存图片文件失败: {str(e)}")
|
logger.error(f"保存图片文件或元数据失败: {str(e)}")
|
||||||
|
|
||||||
# 保存描述到数据库
|
# 保存描述到数据库 (ImageDescriptions表)
|
||||||
self._save_description_to_db(image_hash, description, "image")
|
self._save_description_to_db(image_hash, description, "image")
|
||||||
|
|
||||||
return f"[图片:{description}]"
|
return f"[图片:{description}]"
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
|
|||||||
sys.path.append(root_path)
|
sys.path.append(root_path)
|
||||||
|
|
||||||
# 现在可以导入src模块
|
# 现在可以导入src模块
|
||||||
from src.common.database import db # noqa E402
|
from common.database.database import db # noqa E402
|
||||||
|
|
||||||
|
|
||||||
# 加载根目录下的env.edv文件
|
# 加载根目录下的env.edv文件
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
from pymongo import MongoClient
|
from pymongo import MongoClient
|
||||||
|
from peewee import SqliteDatabase
|
||||||
from pymongo.database import Database
|
from pymongo.database import Database
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
|
||||||
@@ -57,4 +58,15 @@ class DBWrapper:
|
|||||||
|
|
||||||
|
|
||||||
# 全局数据库访问点
|
# 全局数据库访问点
|
||||||
db: Database = DBWrapper()
|
memory_db: Database = DBWrapper()
|
||||||
|
|
||||||
|
# 定义数据库文件路径
|
||||||
|
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||||
|
_DB_DIR = os.path.join(ROOT_PATH, "data")
|
||||||
|
_DB_FILE = os.path.join(_DB_DIR, "MaiBot.db")
|
||||||
|
|
||||||
|
# 确保数据库目录存在
|
||||||
|
os.makedirs(_DB_DIR, exist_ok=True)
|
||||||
|
|
||||||
|
# 全局 Peewee SQLite 数据库访问点
|
||||||
|
db = SqliteDatabase(_DB_FILE)
|
||||||
358
src/common/database/database_model.py
Normal file
@@ -0,0 +1,358 @@
|
|||||||
|
from peewee import Model, DoubleField, IntegerField, BooleanField, TextField, FloatField, DateTimeField
|
||||||
|
from .database import db
|
||||||
|
import datetime
|
||||||
|
from ..logger_manager import get_logger
|
||||||
|
|
||||||
|
logger = get_logger("database_model")
|
||||||
|
# 请在此处定义您的数据库实例。
|
||||||
|
# 您需要取消注释并配置适合您的数据库的部分。
|
||||||
|
# 例如,对于 SQLite:
|
||||||
|
# db = SqliteDatabase('MaiBot.db')
|
||||||
|
#
|
||||||
|
# 对于 PostgreSQL:
|
||||||
|
# db = PostgresqlDatabase('your_db_name', user='your_user', password='your_password',
|
||||||
|
# host='localhost', port=5432)
|
||||||
|
#
|
||||||
|
# 对于 MySQL:
|
||||||
|
# db = MySQLDatabase('your_db_name', user='your_user', password='your_password',
|
||||||
|
# host='localhost', port=3306)
|
||||||
|
|
||||||
|
|
||||||
|
# 定义一个基础模型是一个好习惯,所有其他模型都应继承自它。
|
||||||
|
# 这允许您在一个地方为所有模型指定数据库。
|
||||||
|
class BaseModel(Model):
|
||||||
|
class Meta:
|
||||||
|
# 将下面的 'db' 替换为您实际的数据库实例变量名。
|
||||||
|
database = db # 例如: database = my_actual_db_instance
|
||||||
|
pass # 在用户定义数据库实例之前,此处为占位符
|
||||||
|
|
||||||
|
|
||||||
|
class ChatStreams(BaseModel):
|
||||||
|
"""
|
||||||
|
用于存储流式记录数据的模型,类似于提供的 MongoDB 结构。
|
||||||
|
"""
|
||||||
|
|
||||||
|
# stream_id: "a544edeb1a9b73e3e1d77dff36e41264"
|
||||||
|
# 假设 stream_id 是唯一的,并为其创建索引以提高查询性能。
|
||||||
|
stream_id = TextField(unique=True, index=True)
|
||||||
|
|
||||||
|
# create_time: 1746096761.4490178 (时间戳,精确到小数点后7位)
|
||||||
|
# DoubleField 用于存储浮点数,适合此类时间戳。
|
||||||
|
create_time = DoubleField()
|
||||||
|
|
||||||
|
# group_info 字段:
|
||||||
|
# platform: "qq"
|
||||||
|
# group_id: "941657197"
|
||||||
|
# group_name: "测试"
|
||||||
|
group_platform = TextField()
|
||||||
|
group_id = TextField()
|
||||||
|
group_name = TextField()
|
||||||
|
|
||||||
|
# last_active_time: 1746623771.4825106 (时间戳,精确到小数点后7位)
|
||||||
|
last_active_time = DoubleField()
|
||||||
|
|
||||||
|
# platform: "qq" (顶层平台字段)
|
||||||
|
platform = TextField()
|
||||||
|
|
||||||
|
# user_info 字段:
|
||||||
|
# platform: "qq"
|
||||||
|
# user_id: "1787882683"
|
||||||
|
# user_nickname: "墨梓柒(IceSakurary)"
|
||||||
|
# user_cardname: ""
|
||||||
|
user_platform = TextField()
|
||||||
|
user_id = TextField()
|
||||||
|
user_nickname = TextField()
|
||||||
|
# user_cardname 可能为空字符串或不存在,设置 null=True 更具灵活性。
|
||||||
|
user_cardname = TextField(null=True)
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
# 如果 BaseModel.Meta.database 已设置,则此模型将继承该数据库配置。
|
||||||
|
# 如果不使用带有数据库实例的 BaseModel,或者想覆盖它,
|
||||||
|
# 请取消注释并在下面设置数据库实例:
|
||||||
|
# database = db
|
||||||
|
table_name = "chat_streams" # 可选:明确指定数据库中的表名
|
||||||
|
|
||||||
|
|
||||||
|
class LLMUsage(BaseModel):
|
||||||
|
"""
|
||||||
|
用于存储 API 使用日志数据的模型。
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_name = TextField(index=True) # 添加索引
|
||||||
|
user_id = TextField(index=True) # 添加索引
|
||||||
|
request_type = TextField(index=True) # 添加索引
|
||||||
|
endpoint = TextField()
|
||||||
|
prompt_tokens = IntegerField()
|
||||||
|
completion_tokens = IntegerField()
|
||||||
|
total_tokens = IntegerField()
|
||||||
|
cost = DoubleField()
|
||||||
|
status = TextField()
|
||||||
|
timestamp = DateTimeField(index=True) # 更改为 DateTimeField 并添加索引
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
# 如果 BaseModel.Meta.database 已设置,则此模型将继承该数据库配置。
|
||||||
|
# database = db
|
||||||
|
table_name = "llm_usage"
|
||||||
|
|
||||||
|
|
||||||
|
class Emoji(BaseModel):
|
||||||
|
"""表情包"""
|
||||||
|
|
||||||
|
full_path = TextField(unique=True, index=True) # 文件的完整路径 (包括文件名)
|
||||||
|
format = TextField() # 图片格式
|
||||||
|
emoji_hash = TextField(index=True) # 表情包的哈希值
|
||||||
|
description = TextField() # 表情包的描述
|
||||||
|
query_count = IntegerField(default=0) # 查询次数(用于统计表情包被查询描述的次数)
|
||||||
|
is_registered = BooleanField(default=False) # 是否已注册
|
||||||
|
is_banned = BooleanField(default=False) # 是否被禁止注册
|
||||||
|
# emotion: list[str] # 表情包的情感标签 - 存储为文本,应用层处理序列化/反序列化
|
||||||
|
emotion = TextField(null=True)
|
||||||
|
record_time = FloatField() # 记录时间(被创建的时间)
|
||||||
|
register_time = FloatField(null=True) # 注册时间(被注册为可用表情包的时间)
|
||||||
|
usage_count = IntegerField(default=0) # 使用次数(被使用的次数)
|
||||||
|
last_used_time = FloatField(null=True) # 上次使用时间
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
# database = db # 继承自 BaseModel
|
||||||
|
table_name = "emoji"
|
||||||
|
|
||||||
|
|
||||||
|
class Messages(BaseModel):
|
||||||
|
"""
|
||||||
|
用于存储消息数据的模型。
|
||||||
|
"""
|
||||||
|
|
||||||
|
message_id = TextField(index=True) # 消息 ID (更改自 IntegerField)
|
||||||
|
time = DoubleField() # 消息时间戳
|
||||||
|
|
||||||
|
chat_id = TextField(index=True) # 对应的 ChatStreams stream_id
|
||||||
|
|
||||||
|
# 从 chat_info 扁平化而来的字段
|
||||||
|
chat_info_stream_id = TextField()
|
||||||
|
chat_info_platform = TextField()
|
||||||
|
chat_info_user_platform = TextField()
|
||||||
|
chat_info_user_id = TextField()
|
||||||
|
chat_info_user_nickname = TextField()
|
||||||
|
chat_info_user_cardname = TextField(null=True)
|
||||||
|
chat_info_group_platform = TextField(null=True) # 群聊信息可能不存在
|
||||||
|
chat_info_group_id = TextField(null=True)
|
||||||
|
chat_info_group_name = TextField(null=True)
|
||||||
|
chat_info_create_time = DoubleField()
|
||||||
|
chat_info_last_active_time = DoubleField()
|
||||||
|
|
||||||
|
# 从顶层 user_info 扁平化而来的字段 (消息发送者信息)
|
||||||
|
user_platform = TextField()
|
||||||
|
user_id = TextField()
|
||||||
|
user_nickname = TextField()
|
||||||
|
user_cardname = TextField(null=True)
|
||||||
|
|
||||||
|
processed_plain_text = TextField(null=True) # 处理后的纯文本消息
|
||||||
|
detailed_plain_text = TextField(null=True) # 详细的纯文本消息
|
||||||
|
memorized_times = IntegerField(default=0) # 被记忆的次数
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
# database = db # 继承自 BaseModel
|
||||||
|
table_name = "messages"
|
||||||
|
|
||||||
|
|
||||||
|
class Images(BaseModel):
|
||||||
|
"""
|
||||||
|
用于存储图像信息的模型。
|
||||||
|
"""
|
||||||
|
|
||||||
|
emoji_hash = TextField(index=True) # 图像的哈希值
|
||||||
|
description = TextField(null=True) # 图像的描述
|
||||||
|
path = TextField(unique=True) # 图像文件的路径
|
||||||
|
timestamp = FloatField() # 时间戳
|
||||||
|
type = TextField() # 图像类型,例如 "emoji"
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
# database = db # 继承自 BaseModel
|
||||||
|
table_name = "images"
|
||||||
|
|
||||||
|
|
||||||
|
class ImageDescriptions(BaseModel):
|
||||||
|
"""
|
||||||
|
用于存储图像描述信息的模型。
|
||||||
|
"""
|
||||||
|
|
||||||
|
type = TextField() # 类型,例如 "emoji"
|
||||||
|
image_description_hash = TextField(index=True) # 图像的哈希值
|
||||||
|
description = TextField() # 图像的描述
|
||||||
|
timestamp = FloatField() # 时间戳
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
# database = db # 继承自 BaseModel
|
||||||
|
table_name = "image_descriptions"
|
||||||
|
|
||||||
|
|
||||||
|
class OnlineTime(BaseModel):
|
||||||
|
"""
|
||||||
|
用于存储在线时长记录的模型。
|
||||||
|
"""
|
||||||
|
|
||||||
|
# timestamp: "$date": "2025-05-01T18:52:18.191Z" (存储为字符串)
|
||||||
|
timestamp = TextField(default=datetime.datetime.now) # 时间戳
|
||||||
|
duration = IntegerField() # 时长,单位分钟
|
||||||
|
start_timestamp = DateTimeField(default=datetime.datetime.now)
|
||||||
|
end_timestamp = DateTimeField(index=True)
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
# database = db # 继承自 BaseModel
|
||||||
|
table_name = "online_time"
|
||||||
|
|
||||||
|
|
||||||
|
class PersonInfo(BaseModel):
|
||||||
|
"""
|
||||||
|
用于存储个人信息数据的模型。
|
||||||
|
"""
|
||||||
|
|
||||||
|
person_id = TextField(unique=True, index=True) # 个人唯一ID
|
||||||
|
person_name = TextField(null=True) # 个人名称 (允许为空)
|
||||||
|
name_reason = TextField(null=True) # 名称设定的原因
|
||||||
|
platform = TextField() # 平台
|
||||||
|
user_id = TextField(index=True) # 用户ID
|
||||||
|
nickname = TextField() # 用户昵称
|
||||||
|
relationship_value = IntegerField(default=0) # 关系值
|
||||||
|
know_time = FloatField() # 认识时间 (时间戳)
|
||||||
|
msg_interval = IntegerField() # 消息间隔
|
||||||
|
# msg_interval_list: 存储为 JSON 字符串的列表
|
||||||
|
msg_interval_list = TextField(null=True)
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
# database = db # 继承自 BaseModel
|
||||||
|
table_name = "person_info"
|
||||||
|
|
||||||
|
|
||||||
|
class Knowledges(BaseModel):
|
||||||
|
"""
|
||||||
|
用于存储知识库条目的模型。
|
||||||
|
"""
|
||||||
|
|
||||||
|
content = TextField() # 知识内容的文本
|
||||||
|
embedding = TextField() # 知识内容的嵌入向量,存储为 JSON 字符串的浮点数列表
|
||||||
|
# 可以添加其他元数据字段,如 source, create_time 等
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
# database = db # 继承自 BaseModel
|
||||||
|
table_name = "knowledges"
|
||||||
|
|
||||||
|
|
||||||
|
class ThinkingLog(BaseModel):
|
||||||
|
chat_id = TextField(index=True)
|
||||||
|
trigger_text = TextField(null=True)
|
||||||
|
response_text = TextField(null=True)
|
||||||
|
|
||||||
|
# Store complex dicts/lists as JSON strings
|
||||||
|
trigger_info_json = TextField(null=True)
|
||||||
|
response_info_json = TextField(null=True)
|
||||||
|
timing_results_json = TextField(null=True)
|
||||||
|
chat_history_json = TextField(null=True)
|
||||||
|
chat_history_in_thinking_json = TextField(null=True)
|
||||||
|
chat_history_after_response_json = TextField(null=True)
|
||||||
|
heartflow_data_json = TextField(null=True)
|
||||||
|
reasoning_data_json = TextField(null=True)
|
||||||
|
|
||||||
|
# Add a timestamp for the log entry itself
|
||||||
|
# Ensure you have: from peewee import DateTimeField
|
||||||
|
# And: import datetime
|
||||||
|
created_at = DateTimeField(default=datetime.datetime.now)
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
table_name = "thinking_logs"
|
||||||
|
|
||||||
|
|
||||||
|
class RecalledMessages(BaseModel):
|
||||||
|
"""
|
||||||
|
用于存储撤回消息记录的模型。
|
||||||
|
"""
|
||||||
|
|
||||||
|
message_id = TextField(index=True) # 被撤回的消息 ID
|
||||||
|
time = DoubleField() # 撤回操作发生的时间戳
|
||||||
|
stream_id = TextField() # 对应的 ChatStreams stream_id
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
table_name = "recalled_messages"
|
||||||
|
|
||||||
|
|
||||||
|
def create_tables():
|
||||||
|
"""
|
||||||
|
创建所有在模型中定义的数据库表。
|
||||||
|
"""
|
||||||
|
with db:
|
||||||
|
db.create_tables(
|
||||||
|
[
|
||||||
|
ChatStreams,
|
||||||
|
LLMUsage,
|
||||||
|
Emoji,
|
||||||
|
Messages,
|
||||||
|
Images,
|
||||||
|
ImageDescriptions,
|
||||||
|
OnlineTime,
|
||||||
|
PersonInfo,
|
||||||
|
Knowledges,
|
||||||
|
ThinkingLog,
|
||||||
|
RecalledMessages, # 添加新模型
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_database():
|
||||||
|
"""
|
||||||
|
检查所有定义的表是否存在,如果不存在则创建它们。
|
||||||
|
检查所有表的所有字段是否存在,如果缺失则警告用户并退出程序。
|
||||||
|
"""
|
||||||
|
import sys
|
||||||
|
|
||||||
|
models = [
|
||||||
|
ChatStreams,
|
||||||
|
LLMUsage,
|
||||||
|
Emoji,
|
||||||
|
Messages,
|
||||||
|
Images,
|
||||||
|
ImageDescriptions,
|
||||||
|
OnlineTime,
|
||||||
|
PersonInfo,
|
||||||
|
Knowledges,
|
||||||
|
ThinkingLog,
|
||||||
|
RecalledMessages, # 添加新模型
|
||||||
|
]
|
||||||
|
|
||||||
|
needs_creation = False
|
||||||
|
try:
|
||||||
|
with db: # 管理 table_exists 检查的连接
|
||||||
|
for model in models:
|
||||||
|
table_name = model._meta.table_name
|
||||||
|
if not db.table_exists(model):
|
||||||
|
logger.warning(f"表 '{table_name}' 未找到。")
|
||||||
|
needs_creation = True
|
||||||
|
break # 一个表丢失,无需进一步检查。
|
||||||
|
if not needs_creation:
|
||||||
|
# 检查字段
|
||||||
|
for model in models:
|
||||||
|
table_name = model._meta.table_name
|
||||||
|
cursor = db.execute_sql(f"PRAGMA table_info('{table_name}')")
|
||||||
|
existing_columns = {row[1] for row in cursor.fetchall()}
|
||||||
|
model_fields = model._meta.fields
|
||||||
|
for field_name in model_fields:
|
||||||
|
if field_name not in existing_columns:
|
||||||
|
logger.error(f"表 '{table_name}' 缺失字段 '{field_name}',请手动迁移数据库结构后重启程序。")
|
||||||
|
sys.exit(1)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"检查表或字段是否存在时出错: {e}")
|
||||||
|
# 如果检查失败(例如数据库不可用),则退出
|
||||||
|
return
|
||||||
|
|
||||||
|
if needs_creation:
|
||||||
|
logger.info("正在初始化数据库:一个或多个表丢失。正在尝试创建所有定义的表...")
|
||||||
|
try:
|
||||||
|
create_tables() # 此函数有其自己的 'with db:' 上下文管理。
|
||||||
|
logger.info("数据库表创建过程完成。")
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"创建表期间出错: {e}")
|
||||||
|
else:
|
||||||
|
logger.info("所有数据库表及字段均已存在。")
|
||||||
|
|
||||||
|
|
||||||
|
# 模块加载时调用初始化函数
|
||||||
|
initialize_database()
|
||||||
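With `initialize_database()` running at import time, the models above can be queried directly. A small usage sketch — the stream id is a placeholder, and the import path is assumed:

```python
import time
from src.common.database.database_model import Messages  # assumed import path

# Messages from the last hour for one chat stream, newest first.
one_hour_ago = time.time() - 3600
recent = (Messages
          .select()
          .where((Messages.chat_id == "some_stream_id") & (Messages.time >= one_hour_ago))
          .order_by(Messages.time.desc())
          .limit(20))
for msg in recent:
    print(msg.user_nickname, msg.processed_plain_text)

# Note: if a defined column is missing from an existing table, initialize_database()
# logs an error and exits rather than attempting an automatic migration.
```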
@@ -1,11 +1,19 @@
|
|||||||
from src.common.database import db
|
from src.common.database.database_model import Messages # 更改导入
|
||||||
from src.common.logger import get_module_logger
|
from src.common.logger import get_module_logger
|
||||||
import traceback
|
import traceback
|
||||||
from typing import List, Any, Optional
|
from typing import List, Any, Optional
|
||||||
|
from peewee import Model # 添加 Peewee Model 导入
|
||||||
|
|
||||||
logger = get_module_logger(__name__)
|
logger = get_module_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _model_to_dict(model_instance: Model) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
将 Peewee 模型实例转换为字典。
|
||||||
|
"""
|
||||||
|
return model_instance.__data__
|
||||||
|
|
||||||
|
|
||||||
def find_messages(
|
def find_messages(
|
||||||
message_filter: dict[str, Any],
|
message_filter: dict[str, Any],
|
||||||
sort: Optional[List[tuple[str, int]]] = None,
|
sort: Optional[List[tuple[str, int]]] = None,
|
||||||
@@ -16,39 +24,84 @@ def find_messages(
|
|||||||
根据提供的过滤器、排序和限制条件查找消息。
|
根据提供的过滤器、排序和限制条件查找消息。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
message_filter: MongoDB 查询过滤器。
|
message_filter: 查询过滤器字典,键为模型字段名,值为期望值或包含操作符的字典 (例如 {'$gt': value}).
|
||||||
sort: MongoDB 排序条件列表,例如 [('time', 1)]。仅在 limit 为 0 时生效。
|
sort: 排序条件列表,例如 [('time', 1)] (1 for asc, -1 for desc)。仅在 limit 为 0 时生效。
|
||||||
limit: 返回的最大文档数,0表示不限制。
|
limit: 返回的最大文档数,0表示不限制。
|
||||||
limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录(结果仍按时间正序排列)。默认为 'latest'。
|
limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录(结果仍按时间正序排列)。默认为 'latest'。
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
消息文档列表,如果出错则返回空列表。
|
消息字典列表,如果出错则返回空列表。
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
query = db.messages.find(message_filter)
|
query = Messages.select()
|
||||||
|
|
||||||
|
# 应用过滤器
|
||||||
|
if message_filter:
|
||||||
|
conditions = []
|
||||||
|
for key, value in message_filter.items():
|
||||||
|
if hasattr(Messages, key):
|
||||||
|
field = getattr(Messages, key)
|
||||||
|
if isinstance(value, dict):
|
||||||
|
# 处理 MongoDB 风格的操作符
|
||||||
|
for op, op_value in value.items():
|
||||||
|
if op == "$gt":
|
||||||
|
conditions.append(field > op_value)
|
||||||
|
elif op == "$lt":
|
||||||
|
conditions.append(field < op_value)
|
||||||
|
elif op == "$gte":
|
||||||
|
conditions.append(field >= op_value)
|
||||||
|
elif op == "$lte":
|
||||||
|
conditions.append(field <= op_value)
|
||||||
|
elif op == "$ne":
|
||||||
|
conditions.append(field != op_value)
|
||||||
|
elif op == "$in":
|
||||||
|
conditions.append(field.in_(op_value))
|
||||||
|
elif op == "$nin":
|
||||||
|
conditions.append(field.not_in(op_value))
|
||||||
|
else:
|
||||||
|
logger.warning(f"过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。")
|
||||||
|
else:
|
||||||
|
# 直接相等比较
|
||||||
|
conditions.append(field == value)
|
||||||
|
else:
|
||||||
|
logger.warning(f"过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。")
|
||||||
|
if conditions:
|
||||||
|
query = query.where(*conditions)
|
||||||
|
|
||||||
if limit > 0:
|
if limit > 0:
|
||||||
if limit_mode == "earliest":
|
if limit_mode == "earliest":
|
||||||
# 获取时间最早的 limit 条记录,已经是正序
|
# 获取时间最早的 limit 条记录,已经是正序
|
||||||
query = query.sort([("time", 1)]).limit(limit)
|
query = query.order_by(Messages.time.asc()).limit(limit)
|
||||||
results = list(query)
|
peewee_results = list(query)
|
||||||
else: # 默认为 'latest'
|
else: # 默认为 'latest'
|
||||||
# 获取时间最晚的 limit 条记录
|
# 获取时间最晚的 limit 条记录
|
||||||
query = query.sort([("time", -1)]).limit(limit)
|
query = query.order_by(Messages.time.desc()).limit(limit)
|
||||||
latest_results = list(query)
|
latest_results_peewee = list(query)
|
||||||
# 将结果按时间正序排列
|
# 将结果按时间正序排列
|
||||||
# 假设消息文档中总是有 'time' 字段且可排序
|
peewee_results = sorted(latest_results_peewee, key=lambda msg: msg.time)
|
||||||
results = sorted(latest_results, key=lambda msg: msg.get("time"))
|
|
||||||
else:
|
else:
|
||||||
# limit 为 0 时,应用传入的 sort 参数
|
# limit 为 0 时,应用传入的 sort 参数
|
||||||
if sort:
|
if sort:
|
||||||
query = query.sort(sort)
|
peewee_sort_terms = []
|
||||||
results = list(query)
|
for field_name, direction in sort:
|
||||||
|
if hasattr(Messages, field_name):
|
||||||
|
field = getattr(Messages, field_name)
|
||||||
|
if direction == 1: # ASC
|
||||||
|
peewee_sort_terms.append(field.asc())
|
||||||
|
elif direction == -1: # DESC
|
||||||
|
peewee_sort_terms.append(field.desc())
|
||||||
|
else:
|
||||||
|
logger.warning(f"字段 '{field_name}' 的排序方向 '{direction}' 无效。将跳过此排序条件。")
|
||||||
|
else:
|
||||||
|
logger.warning(f"排序字段 '{field_name}' 在 Messages 模型中未找到。将跳过此排序条件。")
|
||||||
|
if peewee_sort_terms:
|
||||||
|
query = query.order_by(*peewee_sort_terms)
|
||||||
|
peewee_results = list(query)
|
||||||
|
|
||||||
return results
|
return [_model_to_dict(msg) for msg in peewee_results]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log_message = (
|
log_message = (
|
||||||
f"查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n"
|
f"使用 Peewee 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n"
|
||||||
+ traceback.format_exc()
|
+ traceback.format_exc()
|
||||||
)
|
)
|
||||||
logger.error(log_message)
|
logger.error(log_message)
|
||||||
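The rewritten `find_messages` keeps the Mongo-style filter interface while translating it into Peewee expressions, so existing call sites keep working. A call might look like this (the chat id and timestamp are placeholders):

```python
# Latest 20 messages of one chat since a given timestamp, returned oldest-first.
messages = find_messages(
    message_filter={"chat_id": "some_stream_id", "time": {"$gte": 1746096761.0}},
    limit=20,
    limit_mode="latest",
)

# Unlimited query with an explicit sort (only honored when limit == 0).
all_for_chat = find_messages(
    message_filter={"chat_id": "some_stream_id"},
    sort=[("time", 1)],
    limit=0,
)
```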
@@ -60,18 +113,57 @@ def count_messages(message_filter: dict[str, Any]) -> int:
|
|||||||
根据提供的过滤器计算消息数量。
|
根据提供的过滤器计算消息数量。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
message_filter: MongoDB 查询过滤器。
|
message_filter: 查询过滤器字典,键为模型字段名,值为期望值或包含操作符的字典 (例如 {'$gt': value}).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
符合条件的消息数量,如果出错则返回 0。
|
符合条件的消息数量,如果出错则返回 0。
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
count = db.messages.count_documents(message_filter)
|
query = Messages.select()
|
||||||
|
|
||||||
|
# 应用过滤器
|
||||||
|
if message_filter:
|
||||||
|
conditions = []
|
||||||
|
for key, value in message_filter.items():
|
||||||
|
if hasattr(Messages, key):
|
||||||
|
field = getattr(Messages, key)
|
||||||
|
if isinstance(value, dict):
|
||||||
|
# 处理 MongoDB 风格的操作符
|
||||||
|
for op, op_value in value.items():
|
||||||
|
if op == "$gt":
|
||||||
|
conditions.append(field > op_value)
|
||||||
|
elif op == "$lt":
|
||||||
|
conditions.append(field < op_value)
|
||||||
|
elif op == "$gte":
|
||||||
|
conditions.append(field >= op_value)
|
||||||
|
elif op == "$lte":
|
||||||
|
conditions.append(field <= op_value)
|
||||||
|
elif op == "$ne":
|
||||||
|
conditions.append(field != op_value)
|
||||||
|
elif op == "$in":
|
||||||
|
conditions.append(field.in_(op_value))
|
||||||
|
elif op == "$nin":
|
||||||
|
conditions.append(field.not_in(op_value))
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"计数时,过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 直接相等比较
|
||||||
|
conditions.append(field == value)
|
||||||
|
else:
|
||||||
|
logger.warning(f"计数时,过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。")
|
||||||
|
if conditions:
|
||||||
|
query = query.where(*conditions)
|
||||||
|
|
||||||
|
count = query.count()
|
||||||
return count
|
return count
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log_message = f"计数消息失败 (message_filter={message_filter}): {e}\n" + traceback.format_exc()
|
log_message = f"使用 Peewee 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}"
|
||||||
logger.error(log_message)
|
logger.error(log_message)
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
# 你可以在这里添加更多与 messages 集合相关的数据库操作函数,例如 find_one_message, insert_message 等。
|
# 你可以在这里添加更多与 messages 集合相关的数据库操作函数,例如 find_one_message, insert_message 等。
|
||||||
|
# 注意:对于 Peewee,插入操作通常是 Messages.create(...) 或 instance.save()。
|
||||||
|
# 查找单个消息可以是 Messages.get_or_none(...) 或 query.first()。
|
||||||
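`find_messages` and `count_messages` now duplicate the operator-translation loop. If that becomes a maintenance burden, it could be factored into a single helper along these lines — a sketch that is not part of this diff and assumes the module's existing `logger`:

```python
from typing import Any
from peewee import Expression

_OPS = {
    "$gt": lambda f, v: f > v,
    "$lt": lambda f, v: f < v,
    "$gte": lambda f, v: f >= v,
    "$lte": lambda f, v: f <= v,
    "$ne": lambda f, v: f != v,
    "$in": lambda f, v: f.in_(v),
    "$nin": lambda f, v: f.not_in(v),
}


def build_conditions(model, message_filter: dict[str, Any]) -> list[Expression]:
    """Translate a Mongo-style filter dict into a list of Peewee expressions."""
    conditions = []
    for key, value in message_filter.items():
        if not hasattr(model, key):
            logger.warning(f"过滤器键 '{key}' 在 {model.__name__} 模型中未找到。将跳过此条件。")
            continue
        field = getattr(model, key)
        if isinstance(value, dict):
            for op, op_value in value.items():
                if op in _OPS:
                    conditions.append(_OPS[op](field, op_value))
                else:
                    logger.warning(f"过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。")
        else:
            conditions.append(field == value)
    return conditions
```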
|
|||||||
@@ -35,7 +35,7 @@ class TelemetryHeartBeatTask(AsyncTask):
|
|||||||
info_dict = {
|
info_dict = {
|
||||||
"os_type": "Unknown",
|
"os_type": "Unknown",
|
||||||
"py_version": platform.python_version(),
|
"py_version": platform.python_version(),
|
||||||
"mmc_version": global_config.MAI_VERSION,
|
"mmc_version": global_config.MMC_VERSION,
|
||||||
}
|
}
|
||||||
|
|
||||||
match platform.system():
|
match platform.system():
|
||||||
@@ -133,9 +133,8 @@ class TelemetryHeartBeatTask(AsyncTask):
|
|||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
# 发送心跳
|
# 发送心跳
|
||||||
if global_config.remote_enable:
|
if global_config.telemetry.enable:
|
||||||
if self.client_uuid is None:
|
if self.client_uuid is None and not await self._req_uuid():
|
||||||
if not await self._req_uuid():
|
|
||||||
logger.error("获取UUID失败,跳过此次心跳")
|
logger.error("获取UUID失败,跳过此次心跳")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
@@ -1,64 +1,68 @@
import os
-import re
-from dataclasses import dataclass, field
-from typing import Dict, List, Optional
+from dataclasses import field, dataclass

-import tomli
import tomlkit
import shutil
from datetime import datetime
-from pathlib import Path
-from packaging import version
-from packaging.version import Version, InvalidVersion
-from packaging.specifiers import SpecifierSet, InvalidSpecifier
+from tomlkit import TOMLDocument
+from tomlkit.items import Table

from src.common.logger_manager import get_logger
from rich.traceback import install

+from src.config.config_base import ConfigBase
+from src.config.official_configs import (
+    BotConfig,
+    ChatTargetConfig,
+    PersonalityConfig,
+    IdentityConfig,
+    PlatformsConfig,
+    ChatConfig,
+    NormalChatConfig,
+    FocusChatConfig,
+    EmojiConfig,
+    MemoryConfig,
+    MoodConfig,
+    KeywordReactionConfig,
+    ChineseTypoConfig,
+    ResponseSplitterConfig,
+    TelemetryConfig,
+    ExperimentalConfig,
+    ModelConfig,
+)

install(extra_lines=3)


# 配置主程序日志格式
logger = get_logger("config")

-# 考虑到,实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
-is_test = True
-mai_version_main = "0.6.4"
-mai_version_fix = "snapshot-1"
-
-if mai_version_fix:
-    if is_test:
-        mai_version = f"test-{mai_version_main}-{mai_version_fix}"
-    else:
-        mai_version = f"{mai_version_main}-{mai_version_fix}"
-else:
-    if is_test:
-        mai_version = f"test-{mai_version_main}"
-    else:
-        mai_version = mai_version_main
+CONFIG_DIR = "config"
+TEMPLATE_DIR = "template"
+
+# 考虑到,实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
+# 对该字段的更新,请严格参照语义化版本规范:https://semver.org/lang/zh-CN/
+MMC_VERSION = "0.7.0-snapshot.1"


def update_config():
    # 获取根目录路径
-    root_dir = Path(__file__).parent.parent.parent
-    template_dir = root_dir / "template"
-    config_dir = root_dir / "config"
-    old_config_dir = config_dir / "old"
+    old_config_dir = f"{CONFIG_DIR}/old"

    # 定义文件路径
-    template_path = template_dir / "bot_config_template.toml"
-    old_config_path = config_dir / "bot_config.toml"
-    new_config_path = config_dir / "bot_config.toml"
+    template_path = f"{TEMPLATE_DIR}/bot_config_template.toml"
+    old_config_path = f"{CONFIG_DIR}/bot_config.toml"
+    new_config_path = f"{CONFIG_DIR}/bot_config.toml"

    # 检查配置文件是否存在
-    if not old_config_path.exists():
+    if not os.path.exists(old_config_path):
        logger.info("配置文件不存在,从模板创建新配置")
-        # 创建文件夹
-        old_config_dir.mkdir(parents=True, exist_ok=True)
-        shutil.copy2(template_path, old_config_path)
+        os.makedirs(CONFIG_DIR, exist_ok=True)  # 创建文件夹
+        shutil.copy2(template_path, old_config_path)  # 复制模板文件
        logger.info(f"已创建新配置文件,请填写后重新运行: {old_config_path}")
        # 如果是新创建的配置文件,直接返回
-        return quit()
+        quit()

    # 读取旧配置文件和模板文件
    with open(old_config_path, "r", encoding="utf-8") as f:
@@ -75,13 +79,15 @@ def update_config():
            return
        else:
            logger.info(f"检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}")
+    else:
+        logger.info("已有配置文件未检测到版本号,可能是旧版本。将进行更新")

    # 创建old目录(如果不存在)
-    old_config_dir.mkdir(exist_ok=True)
+    os.makedirs(old_config_dir, exist_ok=True)

    # 生成带时间戳的新文件名
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-    old_backup_path = old_config_dir / f"bot_config_{timestamp}.toml"
+    old_backup_path = f"{old_config_dir}/bot_config_{timestamp}.toml"

    # 移动旧配置文件到old目录
    shutil.move(old_config_path, old_backup_path)
@@ -91,24 +97,23 @@ def update_config():
    shutil.copy2(template_path, new_config_path)
    logger.info(f"已创建新配置文件: {new_config_path}")

-    # 递归更新配置
-    def update_dict(target, source):
+    def update_dict(target: TOMLDocument | dict, source: TOMLDocument | dict):
+        """
+        将source字典的值更新到target字典中(如果target中存在相同的键)
+        """
        for key, value in source.items():
            # 跳过version字段的更新
            if key == "version":
                continue
            if key in target:
-                if isinstance(value, dict) and isinstance(target[key], (dict, tomlkit.items.Table)):
+                if isinstance(value, dict) and isinstance(target[key], (dict, Table)):
                    update_dict(target[key], value)
                else:
                    try:
                        # 对数组类型进行特殊处理
                        if isinstance(value, list):
                            # 如果是空数组,确保它保持为空数组
-                            if not value:
-                                target[key] = tomlkit.array()
-                            else:
-                                target[key] = tomlkit.array(value)
+                            target[key] = tomlkit.array(str(value)) if value else tomlkit.array()
                        else:
                            # 其他类型使用item方法创建新值
                            target[key] = tomlkit.item(value)
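A small self-contained sketch of the recursive merge that `update_dict` performs, using a simplified stand-alone copy of the function (`merge_preserving_template` is an illustrative name; the real function is nested inside `update_config` and also handles arrays as shown above):

```python
import tomlkit
from tomlkit.items import Table


def merge_preserving_template(target, source):
    # Simplified restatement of update_dict above: copy the user's old values
    # onto the new template, skip the "version" key, and recurse into tables.
    for key, value in source.items():
        if key == "version":
            continue
        if key in target:
            if isinstance(value, dict) and isinstance(target[key], (dict, Table)):
                merge_preserving_template(target[key], value)
            else:
                target[key] = tomlkit.item(value)


template = tomlkit.parse('[inner]\nversion = "2.0.0"\n\n[bot]\nnickname = "麦麦"\n')
old_values = {"inner": {"version": "1.0.0"}, "bot": {"nickname": "自定义昵称"}}
merge_preserving_template(template, old_values)
print(tomlkit.dumps(template))  # keeps version = "2.0.0", takes the user's nickname
```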
@@ -123,619 +128,57 @@ def update_config():
    # 保存更新后的配置(保留注释和格式)
    with open(new_config_path, "w", encoding="utf-8") as f:
        f.write(tomlkit.dumps(new_config))
-    logger.info("配置文件更新完成")
+    logger.info("配置文件更新完成,建议检查新配置文件中的内容,以免丢失重要信息")
+    quit()


@dataclass
-class BotConfig:
-    """机器人配置类"""
+class Config(ConfigBase):
+    """总配置类"""

-    INNER_VERSION: Version = None
-    MAI_VERSION: str = mai_version  # 硬编码的版本信息
+    MMC_VERSION: str = field(default=MMC_VERSION, repr=False, init=False)  # 硬编码的版本信息

-    # bot
-    BOT_QQ: Optional[str] = "114514"
-    BOT_NICKNAME: Optional[str] = None
-    BOT_ALIAS_NAMES: List[str] = field(default_factory=list)  # 别名,可以通过这个叫它
+    bot: BotConfig
+    chat_target: ChatTargetConfig
+    personality: PersonalityConfig
+    identity: IdentityConfig
+    platforms: PlatformsConfig
+    chat: ChatConfig
+    normal_chat: NormalChatConfig
+    focus_chat: FocusChatConfig
+    emoji: EmojiConfig
+    memory: MemoryConfig
+    mood: MoodConfig
+    keyword_reaction: KeywordReactionConfig
+    chinese_typo: ChineseTypoConfig
+    response_splitter: ResponseSplitterConfig
+    telemetry: TelemetryConfig
+    experimental: ExperimentalConfig
+    model: ModelConfig
|
|
||||||
talk_allowed_groups = set()
|
|
||||||
talk_frequency_down_groups = set()
|
|
||||||
ban_user_id = set()
|
|
||||||
|
|
||||||
# personality
|
def load_config(config_path: str) -> Config:
|
||||||
personality_core = "用一句话或几句话描述人格的核心特点" # 建议20字以内,谁再写3000字小作文敲谁脑袋
|
|
||||||
personality_sides: List[str] = field(
|
|
||||||
default_factory=lambda: [
|
|
||||||
"用一句话或几句话描述人格的一些侧面",
|
|
||||||
"用一句话或几句话描述人格的一些侧面",
|
|
||||||
"用一句话或几句话描述人格的一些侧面",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
expression_style = "描述麦麦说话的表达风格,表达习惯"
|
|
||||||
# identity
|
|
||||||
identity_detail: List[str] = field(
|
|
||||||
default_factory=lambda: [
|
|
||||||
"身份特点",
|
|
||||||
"身份特点",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
height: int = 170 # 身高 单位厘米
|
|
||||||
weight: int = 50 # 体重 单位千克
|
|
||||||
age: int = 20 # 年龄 单位岁
|
|
||||||
gender: str = "男" # 性别
|
|
||||||
appearance: str = "用几句话描述外貌特征" # 外貌特征
|
|
||||||
|
|
||||||
# chat
|
|
||||||
allow_focus_mode: bool = True # 是否允许专注聊天状态
|
|
||||||
|
|
||||||
base_normal_chat_num: int = 3 # 最多允许多少个群进行普通聊天
|
|
||||||
base_focused_chat_num: int = 2 # 最多允许多少个群进行专注聊天
|
|
||||||
|
|
||||||
observation_context_size: int = 12 # 心流观察到的最长上下文大小,超过这个值的上下文会被压缩
|
|
||||||
|
|
||||||
message_buffer: bool = True # 消息缓冲器
|
|
||||||
|
|
||||||
ban_words = set()
|
|
||||||
ban_msgs_regex = set()
|
|
||||||
|
|
||||||
# focus_chat
|
|
||||||
reply_trigger_threshold: float = 3.0 # 心流聊天触发阈值,越低越容易触发
|
|
||||||
default_decay_rate_per_second: float = 0.98 # 默认衰减率,越大衰减越慢
|
|
||||||
consecutive_no_reply_threshold = 3
|
|
||||||
|
|
||||||
compressed_length: int = 5 # 不能大于observation_context_size,心流上下文压缩的最短压缩长度,超过心流观察到的上下文长度,会压缩,最短压缩长度为5
|
|
||||||
compress_length_limit: int = 5 # 最多压缩份数,超过该数值的压缩上下文会被删除
|
|
||||||
|
|
||||||
# normal_chat
|
|
||||||
model_reasoning_probability: float = 0.7 # 麦麦回答时选择推理模型(主要)模型概率
|
|
||||||
model_normal_probability: float = 0.3 # 麦麦回答时选择一般模型(次要)模型概率
|
|
||||||
|
|
||||||
emoji_chance: float = 0.2 # 发送表情包的基础概率
|
|
||||||
thinking_timeout: int = 120 # 思考时间
|
|
||||||
|
|
||||||
willing_mode: str = "classical" # 意愿模式
|
|
||||||
response_willing_amplifier: float = 1.0 # 回复意愿放大系数
|
|
||||||
response_interested_rate_amplifier: float = 1.0 # 回复兴趣度放大系数
|
|
||||||
down_frequency_rate: float = 3 # 降低回复频率的群组回复意愿降低系数
|
|
||||||
emoji_response_penalty: float = 0.0 # 表情包回复惩罚
|
|
||||||
mentioned_bot_inevitable_reply: bool = False # 提及 bot 必然回复
|
|
||||||
at_bot_inevitable_reply: bool = False # @bot 必然回复
|
|
||||||
|
|
||||||
# emoji
|
|
||||||
max_emoji_num: int = 200 # 表情包最大数量
|
|
||||||
max_reach_deletion: bool = True # 开启则在达到最大数量时删除表情包,关闭则不会继续收集表情包
|
|
||||||
EMOJI_CHECK_INTERVAL: int = 120 # 表情包检查间隔(分钟)
|
|
||||||
|
|
||||||
save_pic: bool = False # 是否保存图片
|
|
||||||
save_emoji: bool = False # 是否保存表情包
|
|
||||||
steal_emoji: bool = True # 是否偷取表情包,让麦麦可以发送她保存的这些表情包
|
|
||||||
|
|
||||||
EMOJI_CHECK: bool = False # 是否开启过滤
|
|
||||||
EMOJI_CHECK_PROMPT: str = "符合公序良俗" # 表情包过滤要求
|
|
||||||
|
|
||||||
# memory
|
|
||||||
build_memory_interval: int = 600 # 记忆构建间隔(秒)
|
|
||||||
memory_build_distribution: list = field(
|
|
||||||
default_factory=lambda: [4, 2, 0.6, 24, 8, 0.4]
|
|
||||||
) # 记忆构建分布,参数:分布1均值,标准差,权重,分布2均值,标准差,权重
|
|
||||||
build_memory_sample_num: int = 10 # 记忆构建采样数量
|
|
||||||
build_memory_sample_length: int = 20 # 记忆构建采样长度
|
|
||||||
memory_compress_rate: float = 0.1 # 记忆压缩率
|
|
||||||
|
|
||||||
forget_memory_interval: int = 600 # 记忆遗忘间隔(秒)
|
|
||||||
memory_forget_time: int = 24 # 记忆遗忘时间(小时)
|
|
||||||
memory_forget_percentage: float = 0.01 # 记忆遗忘比例
|
|
||||||
|
|
||||||
consolidate_memory_interval: int = 1000 # 记忆整合间隔(秒)
|
|
||||||
consolidation_similarity_threshold: float = 0.7 # 相似度阈值
|
|
||||||
consolidate_memory_percentage: float = 0.01 # 检查节点比例
|
|
||||||
|
|
||||||
memory_ban_words: list = field(
|
|
||||||
default_factory=lambda: ["表情包", "图片", "回复", "聊天记录"]
|
|
||||||
) # 添加新的配置项默认值
|
|
||||||
|
|
||||||
# mood
|
|
||||||
mood_update_interval: float = 1.0 # 情绪更新间隔 单位秒
|
|
||||||
mood_decay_rate: float = 0.95 # 情绪衰减率
|
|
||||||
mood_intensity_factor: float = 0.7 # 情绪强度因子
|
|
||||||
|
|
||||||
# keywords
|
|
||||||
keywords_reaction_rules = [] # 关键词回复规则
|
|
||||||
|
|
||||||
# chinese_typo
|
|
||||||
chinese_typo_enable = True # 是否启用中文错别字生成器
|
|
||||||
chinese_typo_error_rate = 0.03 # 单字替换概率
|
|
||||||
chinese_typo_min_freq = 7 # 最小字频阈值
|
|
||||||
chinese_typo_tone_error_rate = 0.2 # 声调错误概率
|
|
||||||
chinese_typo_word_replace_rate = 0.02 # 整词替换概率
|
|
||||||
|
|
||||||
# response_splitter
|
|
||||||
enable_kaomoji_protection = False # 是否启用颜文字保护
|
|
||||||
enable_response_splitter = True # 是否启用回复分割器
|
|
||||||
response_max_length = 100 # 回复允许的最大长度
|
|
||||||
response_max_sentence_num = 3 # 回复允许的最大句子数
|
|
||||||
|
|
||||||
model_max_output_length: int = 800 # 最大回复长度
|
|
||||||
|
|
||||||
# remote
|
|
||||||
remote_enable: bool = True # 是否启用远程控制
|
|
||||||
|
|
||||||
# experimental
|
|
||||||
enable_friend_chat: bool = False # 是否启用好友聊天
|
|
||||||
# enable_think_flow: bool = False # 是否启用思考流程
|
|
||||||
talk_allowed_private = set()
|
|
||||||
enable_pfc_chatting: bool = False # 是否启用PFC聊天
|
|
||||||
|
|
||||||
# 模型配置
|
|
||||||
llm_reasoning: dict[str, str] = field(default_factory=lambda: {})
|
|
||||||
# llm_reasoning_minor: dict[str, str] = field(default_factory=lambda: {})
|
|
||||||
llm_normal: Dict[str, str] = field(default_factory=lambda: {})
|
|
||||||
llm_topic_judge: Dict[str, str] = field(default_factory=lambda: {})
|
|
||||||
llm_summary: Dict[str, str] = field(default_factory=lambda: {})
|
|
||||||
embedding: Dict[str, str] = field(default_factory=lambda: {})
|
|
||||||
vlm: Dict[str, str] = field(default_factory=lambda: {})
|
|
||||||
moderation: Dict[str, str] = field(default_factory=lambda: {})
|
|
||||||
|
|
||||||
llm_observation: Dict[str, str] = field(default_factory=lambda: {})
|
|
||||||
llm_sub_heartflow: Dict[str, str] = field(default_factory=lambda: {})
|
|
||||||
llm_heartflow: Dict[str, str] = field(default_factory=lambda: {})
|
|
||||||
llm_tool_use: Dict[str, str] = field(default_factory=lambda: {})
|
|
||||||
llm_plan: Dict[str, str] = field(default_factory=lambda: {})
|
|
||||||
|
|
||||||
api_urls: Dict[str, str] = field(default_factory=lambda: {})
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_config_dir() -> str:
|
|
||||||
"""获取配置文件目录"""
|
|
||||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
root_dir = os.path.abspath(os.path.join(current_dir, "..", ".."))
|
|
||||||
config_dir = os.path.join(root_dir, "config")
|
|
||||||
if not os.path.exists(config_dir):
|
|
||||||
os.makedirs(config_dir)
|
|
||||||
return config_dir
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def convert_to_specifierset(cls, value: str) -> SpecifierSet:
|
|
||||||
"""将 字符串 版本表达式转换成 SpecifierSet
|
|
||||||
Args:
|
|
||||||
value[str]: 版本表达式(字符串)
|
|
||||||
Returns:
|
|
||||||
SpecifierSet
|
|
||||||
"""
|
"""
|
||||||
|
加载配置文件
|
||||||
try:
|
:param config_path: 配置文件路径
|
||||||
converted = SpecifierSet(value)
|
:return: Config对象
|
||||||
except InvalidSpecifier:
|
|
||||||
logger.error(f"{value} 分类使用了错误的版本约束表达式\n", "请阅读 https://semver.org/lang/zh-CN/ 修改代码")
|
|
||||||
exit(1)
|
|
||||||
|
|
||||||
return converted
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_config_version(cls, toml: dict) -> Version:
|
|
||||||
"""提取配置文件的 SpecifierSet 版本数据
|
|
||||||
Args:
|
|
||||||
toml[dict]: 输入的配置文件字典
|
|
||||||
Returns:
|
|
||||||
Version
|
|
||||||
"""
|
"""
|
||||||
|
# 读取配置文件
|
||||||
|
with open(config_path, "r", encoding="utf-8") as f:
|
||||||
|
config_data = tomlkit.load(f)
|
||||||
|
|
||||||
if "inner" in toml:
|
# 创建Config对象
|
||||||
try:
|
try:
|
||||||
config_version: str = toml["inner"]["version"]
|
return Config.from_dict(config_data)
|
||||||
except KeyError as e:
|
except Exception as e:
|
||||||
logger.error("配置文件中 inner 段 不存在, 这是错误的配置文件")
|
logger.critical("配置文件解析失败")
|
||||||
raise KeyError(f"配置文件中 inner 段 不存在 {e}, 这是错误的配置文件") from e
|
raise e
|
||||||
else:
|
|
||||||
toml["inner"] = {"version": "0.0.0"}
|
|
||||||
config_version = toml["inner"]["version"]
|
|
||||||
|
|
||||||
try:
|
|
||||||
ver = version.parse(config_version)
|
|
||||||
except InvalidVersion as e:
|
|
||||||
logger.error(
|
|
||||||
"配置文件中 inner段 的 version 键是错误的版本描述\n"
|
|
||||||
"请阅读 https://semver.org/lang/zh-CN/ 修改配置,并参考本项目指定的模板进行修改\n"
|
|
||||||
"本项目在不同的版本下有不同的模板,请注意识别"
|
|
||||||
)
|
|
||||||
raise InvalidVersion("配置文件中 inner段 的 version 键是错误的版本描述\n") from e
|
|
||||||
|
|
||||||
return ver
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def load_config(cls, config_path: str = None) -> "BotConfig":
|
|
||||||
"""从TOML配置文件加载配置"""
|
|
||||||
config = cls()
|
|
||||||
|
|
||||||
def personality(parent: dict):
|
|
||||||
personality_config = parent["personality"]
|
|
||||||
if config.INNER_VERSION in SpecifierSet(">=1.2.4"):
|
|
||||||
config.personality_core = personality_config.get("personality_core", config.personality_core)
|
|
||||||
config.personality_sides = personality_config.get("personality_sides", config.personality_sides)
|
|
||||||
if config.INNER_VERSION in SpecifierSet(">=1.7.0"):
|
|
||||||
config.expression_style = personality_config.get("expression_style", config.expression_style)
|
|
||||||
|
|
||||||
def identity(parent: dict):
|
|
||||||
identity_config = parent["identity"]
|
|
||||||
if config.INNER_VERSION in SpecifierSet(">=1.2.4"):
|
|
||||||
config.identity_detail = identity_config.get("identity_detail", config.identity_detail)
|
|
||||||
config.height = identity_config.get("height", config.height)
|
|
||||||
config.weight = identity_config.get("weight", config.weight)
|
|
||||||
config.age = identity_config.get("age", config.age)
|
|
||||||
config.gender = identity_config.get("gender", config.gender)
|
|
||||||
config.appearance = identity_config.get("appearance", config.appearance)
|
|
||||||
|
|
||||||
def emoji(parent: dict):
|
|
||||||
emoji_config = parent["emoji"]
|
|
||||||
config.EMOJI_CHECK_INTERVAL = emoji_config.get("check_interval", config.EMOJI_CHECK_INTERVAL)
|
|
||||||
config.EMOJI_CHECK_PROMPT = emoji_config.get("check_prompt", config.EMOJI_CHECK_PROMPT)
|
|
||||||
config.EMOJI_CHECK = emoji_config.get("enable_check", config.EMOJI_CHECK)
|
|
||||||
if config.INNER_VERSION in SpecifierSet(">=1.1.1"):
|
|
||||||
config.max_emoji_num = emoji_config.get("max_emoji_num", config.max_emoji_num)
|
|
||||||
config.max_reach_deletion = emoji_config.get("max_reach_deletion", config.max_reach_deletion)
|
|
||||||
if config.INNER_VERSION in SpecifierSet(">=1.4.2"):
|
|
||||||
config.save_pic = emoji_config.get("save_pic", config.save_pic)
|
|
||||||
config.save_emoji = emoji_config.get("save_emoji", config.save_emoji)
|
|
||||||
config.steal_emoji = emoji_config.get("steal_emoji", config.steal_emoji)
|
|
||||||
|
|
||||||
def bot(parent: dict):
|
|
||||||
# 机器人基础配置
|
|
||||||
bot_config = parent["bot"]
|
|
||||||
bot_qq = bot_config.get("qq")
|
|
||||||
config.BOT_QQ = str(bot_qq)
|
|
||||||
config.BOT_NICKNAME = bot_config.get("nickname", config.BOT_NICKNAME)
|
|
||||||
config.BOT_ALIAS_NAMES = bot_config.get("alias_names", config.BOT_ALIAS_NAMES)
|
|
||||||
|
|
||||||
def chat(parent: dict):
|
|
||||||
chat_config = parent["chat"]
|
|
||||||
config.allow_focus_mode = chat_config.get("allow_focus_mode", config.allow_focus_mode)
|
|
||||||
config.base_normal_chat_num = chat_config.get("base_normal_chat_num", config.base_normal_chat_num)
|
|
||||||
config.base_focused_chat_num = chat_config.get("base_focused_chat_num", config.base_focused_chat_num)
|
|
||||||
config.observation_context_size = chat_config.get(
|
|
||||||
"observation_context_size", config.observation_context_size
|
|
||||||
)
|
|
||||||
config.message_buffer = chat_config.get("message_buffer", config.message_buffer)
|
|
||||||
config.ban_words = chat_config.get("ban_words", config.ban_words)
|
|
||||||
for r in chat_config.get("ban_msgs_regex", config.ban_msgs_regex):
|
|
||||||
config.ban_msgs_regex.add(re.compile(r))
|
|
||||||
|
|
||||||
def normal_chat(parent: dict):
|
|
||||||
normal_chat_config = parent["normal_chat"]
|
|
||||||
config.model_reasoning_probability = normal_chat_config.get(
|
|
||||||
"model_reasoning_probability", config.model_reasoning_probability
|
|
||||||
)
|
|
||||||
config.model_normal_probability = normal_chat_config.get(
|
|
||||||
"model_normal_probability", config.model_normal_probability
|
|
||||||
)
|
|
||||||
config.emoji_chance = normal_chat_config.get("emoji_chance", config.emoji_chance)
|
|
||||||
config.thinking_timeout = normal_chat_config.get("thinking_timeout", config.thinking_timeout)
|
|
||||||
|
|
||||||
config.willing_mode = normal_chat_config.get("willing_mode", config.willing_mode)
|
|
||||||
config.response_willing_amplifier = normal_chat_config.get(
|
|
||||||
"response_willing_amplifier", config.response_willing_amplifier
|
|
||||||
)
|
|
||||||
config.response_interested_rate_amplifier = normal_chat_config.get(
|
|
||||||
"response_interested_rate_amplifier", config.response_interested_rate_amplifier
|
|
||||||
)
|
|
||||||
config.down_frequency_rate = normal_chat_config.get("down_frequency_rate", config.down_frequency_rate)
|
|
||||||
config.emoji_response_penalty = normal_chat_config.get(
|
|
||||||
"emoji_response_penalty", config.emoji_response_penalty
|
|
||||||
)
|
|
||||||
|
|
||||||
config.mentioned_bot_inevitable_reply = normal_chat_config.get(
|
|
||||||
"mentioned_bot_inevitable_reply", config.mentioned_bot_inevitable_reply
|
|
||||||
)
|
|
||||||
config.at_bot_inevitable_reply = normal_chat_config.get(
|
|
||||||
"at_bot_inevitable_reply", config.at_bot_inevitable_reply
|
|
||||||
)
|
|
||||||
|
|
||||||
def focus_chat(parent: dict):
|
|
||||||
focus_chat_config = parent["focus_chat"]
|
|
||||||
config.compressed_length = focus_chat_config.get("compressed_length", config.compressed_length)
|
|
||||||
config.compress_length_limit = focus_chat_config.get("compress_length_limit", config.compress_length_limit)
|
|
||||||
config.reply_trigger_threshold = focus_chat_config.get(
|
|
||||||
"reply_trigger_threshold", config.reply_trigger_threshold
|
|
||||||
)
|
|
||||||
config.default_decay_rate_per_second = focus_chat_config.get(
|
|
||||||
"default_decay_rate_per_second", config.default_decay_rate_per_second
|
|
||||||
)
|
|
||||||
config.consecutive_no_reply_threshold = focus_chat_config.get(
|
|
||||||
"consecutive_no_reply_threshold", config.consecutive_no_reply_threshold
|
|
||||||
)
|
|
||||||
|
|
||||||
def model(parent: dict):
|
|
||||||
# 加载模型配置
|
|
||||||
model_config: dict = parent["model"]
|
|
||||||
|
|
||||||
config_list = [
|
|
||||||
"llm_reasoning",
|
|
||||||
# "llm_reasoning_minor",
|
|
||||||
"llm_normal",
|
|
||||||
"llm_topic_judge",
|
|
||||||
"llm_summary",
|
|
||||||
"vlm",
|
|
||||||
"embedding",
|
|
||||||
"llm_tool_use",
|
|
||||||
"llm_observation",
|
|
||||||
"llm_sub_heartflow",
|
|
||||||
"llm_plan",
|
|
||||||
"llm_heartflow",
|
|
||||||
"llm_PFC_action_planner",
|
|
||||||
"llm_PFC_chat",
|
|
||||||
"llm_PFC_reply_checker",
|
|
||||||
]
|
|
||||||
|
|
||||||
for item in config_list:
|
|
||||||
if item in model_config:
|
|
||||||
cfg_item: dict = model_config[item]
|
|
||||||
|
|
||||||
# base_url 的例子: SILICONFLOW_BASE_URL
|
|
||||||
# key 的例子: SILICONFLOW_KEY
|
|
||||||
cfg_target = {
|
|
||||||
"name": "",
|
|
||||||
"base_url": "",
|
|
||||||
"key": "",
|
|
||||||
"stream": False,
|
|
||||||
"pri_in": 0,
|
|
||||||
"pri_out": 0,
|
|
||||||
"temp": 0.7,
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.INNER_VERSION in SpecifierSet("<=0.0.0"):
|
|
||||||
cfg_target = cfg_item
|
|
||||||
|
|
||||||
elif config.INNER_VERSION in SpecifierSet(">=0.0.1"):
|
|
||||||
stable_item = ["name", "pri_in", "pri_out"]
|
|
||||||
|
|
||||||
stream_item = ["stream"]
|
|
||||||
if config.INNER_VERSION in SpecifierSet(">=1.0.1"):
|
|
||||||
stable_item.append("stream")
|
|
||||||
|
|
||||||
pricing_item = ["pri_in", "pri_out"]
|
|
||||||
|
|
||||||
# 从配置中原始拷贝稳定字段
|
|
||||||
for i in stable_item:
|
|
||||||
# 如果 字段 属于计费项 且获取不到,那默认值是 0
|
|
||||||
if i in pricing_item and i not in cfg_item:
|
|
||||||
cfg_target[i] = 0
|
|
||||||
|
|
||||||
if i in stream_item and i not in cfg_item:
|
|
||||||
cfg_target[i] = False
|
|
||||||
|
|
||||||
else:
|
|
||||||
# 没有特殊情况则原样复制
|
|
||||||
try:
|
|
||||||
cfg_target[i] = cfg_item[i]
|
|
||||||
except KeyError as e:
|
|
||||||
logger.error(f"{item} 中的必要字段不存在,请检查")
|
|
||||||
raise KeyError(f"{item} 中的必要字段 {e} 不存在,请检查") from e
|
|
||||||
|
|
||||||
# 如果配置中有temp参数,就使用配置中的值
|
|
||||||
if "temp" in cfg_item:
|
|
||||||
cfg_target["temp"] = cfg_item["temp"]
|
|
||||||
else:
|
|
||||||
# 如果没有temp参数,就删除默认值
|
|
||||||
cfg_target.pop("temp", None)
|
|
||||||
|
|
||||||
provider = cfg_item.get("provider")
|
|
||||||
if provider is None:
|
|
||||||
logger.error(f"provider 字段在模型配置 {item} 中不存在,请检查")
|
|
||||||
raise KeyError(f"provider 字段在模型配置 {item} 中不存在,请检查")
|
|
||||||
|
|
||||||
cfg_target["base_url"] = f"{provider}_BASE_URL"
|
|
||||||
cfg_target["key"] = f"{provider}_KEY"
|
|
||||||
|
|
||||||
# 如果 列表中的项目在 model_config 中,利用反射来设置对应项目
|
|
||||||
setattr(config, item, cfg_target)
|
|
||||||
else:
|
|
||||||
logger.error(f"模型 {item} 在config中不存在,请检查,或尝试更新配置文件")
|
|
||||||
raise KeyError(f"模型 {item} 在config中不存在,请检查,或尝试更新配置文件")
|
|
||||||
|
|
||||||
def memory(parent: dict):
|
|
||||||
memory_config = parent["memory"]
|
|
||||||
config.build_memory_interval = memory_config.get("build_memory_interval", config.build_memory_interval)
|
|
||||||
config.forget_memory_interval = memory_config.get("forget_memory_interval", config.forget_memory_interval)
|
|
||||||
config.memory_ban_words = set(memory_config.get("memory_ban_words", []))
|
|
||||||
config.memory_forget_time = memory_config.get("memory_forget_time", config.memory_forget_time)
|
|
||||||
config.memory_forget_percentage = memory_config.get(
|
|
||||||
"memory_forget_percentage", config.memory_forget_percentage
|
|
||||||
)
|
|
||||||
config.memory_compress_rate = memory_config.get("memory_compress_rate", config.memory_compress_rate)
|
|
||||||
if config.INNER_VERSION in SpecifierSet(">=0.0.11"):
|
|
||||||
config.memory_build_distribution = memory_config.get(
|
|
||||||
"memory_build_distribution", config.memory_build_distribution
|
|
||||||
)
|
|
||||||
config.build_memory_sample_num = memory_config.get(
|
|
||||||
"build_memory_sample_num", config.build_memory_sample_num
|
|
||||||
)
|
|
||||||
config.build_memory_sample_length = memory_config.get(
|
|
||||||
"build_memory_sample_length", config.build_memory_sample_length
|
|
||||||
)
|
|
||||||
if config.INNER_VERSION in SpecifierSet(">=1.5.1"):
|
|
||||||
config.consolidate_memory_interval = memory_config.get(
|
|
||||||
"consolidate_memory_interval", config.consolidate_memory_interval
|
|
||||||
)
|
|
||||||
config.consolidation_similarity_threshold = memory_config.get(
|
|
||||||
"consolidation_similarity_threshold", config.consolidation_similarity_threshold
|
|
||||||
)
|
|
||||||
config.consolidate_memory_percentage = memory_config.get(
|
|
||||||
"consolidate_memory_percentage", config.consolidate_memory_percentage
|
|
||||||
)
|
|
||||||
|
|
||||||
def remote(parent: dict):
|
|
||||||
remote_config = parent["remote"]
|
|
||||||
config.remote_enable = remote_config.get("enable", config.remote_enable)
|
|
||||||
|
|
||||||
def mood(parent: dict):
|
|
||||||
mood_config = parent["mood"]
|
|
||||||
config.mood_update_interval = mood_config.get("mood_update_interval", config.mood_update_interval)
|
|
||||||
config.mood_decay_rate = mood_config.get("mood_decay_rate", config.mood_decay_rate)
|
|
||||||
config.mood_intensity_factor = mood_config.get("mood_intensity_factor", config.mood_intensity_factor)
|
|
||||||
|
|
||||||
def keywords_reaction(parent: dict):
|
|
||||||
keywords_reaction_config = parent["keywords_reaction"]
|
|
||||||
if keywords_reaction_config.get("enable", False):
|
|
||||||
config.keywords_reaction_rules = keywords_reaction_config.get("rules", config.keywords_reaction_rules)
|
|
||||||
for rule in config.keywords_reaction_rules:
|
|
||||||
if rule.get("enable", False) and "regex" in rule:
|
|
||||||
rule["regex"] = [re.compile(r) for r in rule.get("regex", [])]
|
|
||||||
|
|
||||||
def chinese_typo(parent: dict):
|
|
||||||
chinese_typo_config = parent["chinese_typo"]
|
|
||||||
config.chinese_typo_enable = chinese_typo_config.get("enable", config.chinese_typo_enable)
|
|
||||||
config.chinese_typo_error_rate = chinese_typo_config.get("error_rate", config.chinese_typo_error_rate)
|
|
||||||
config.chinese_typo_min_freq = chinese_typo_config.get("min_freq", config.chinese_typo_min_freq)
|
|
||||||
config.chinese_typo_tone_error_rate = chinese_typo_config.get(
|
|
||||||
"tone_error_rate", config.chinese_typo_tone_error_rate
|
|
||||||
)
|
|
||||||
config.chinese_typo_word_replace_rate = chinese_typo_config.get(
|
|
||||||
"word_replace_rate", config.chinese_typo_word_replace_rate
|
|
||||||
)
|
|
||||||
|
|
||||||
def response_splitter(parent: dict):
|
|
||||||
response_splitter_config = parent["response_splitter"]
|
|
||||||
config.enable_response_splitter = response_splitter_config.get(
|
|
||||||
"enable_response_splitter", config.enable_response_splitter
|
|
||||||
)
|
|
||||||
config.response_max_length = response_splitter_config.get("response_max_length", config.response_max_length)
|
|
||||||
config.response_max_sentence_num = response_splitter_config.get(
|
|
||||||
"response_max_sentence_num", config.response_max_sentence_num
|
|
||||||
)
|
|
||||||
if config.INNER_VERSION in SpecifierSet(">=1.4.2"):
|
|
||||||
config.enable_kaomoji_protection = response_splitter_config.get(
|
|
||||||
"enable_kaomoji_protection", config.enable_kaomoji_protection
|
|
||||||
)
|
|
||||||
if config.INNER_VERSION in SpecifierSet(">=1.6.0"):
|
|
||||||
config.model_max_output_length = response_splitter_config.get(
|
|
||||||
"model_max_output_length", config.model_max_output_length
|
|
||||||
)
|
|
||||||
|
|
||||||
def groups(parent: dict):
|
|
||||||
groups_config = parent["groups"]
|
|
||||||
# config.talk_allowed_groups = set(groups_config.get("talk_allowed", []))
|
|
||||||
config.talk_allowed_groups = set(str(group) for group in groups_config.get("talk_allowed", []))
|
|
||||||
# config.talk_frequency_down_groups = set(groups_config.get("talk_frequency_down", []))
|
|
||||||
config.talk_frequency_down_groups = set(
|
|
||||||
str(group) for group in groups_config.get("talk_frequency_down", [])
|
|
||||||
)
|
|
||||||
# config.ban_user_id = set(groups_config.get("ban_user_id", []))
|
|
||||||
config.ban_user_id = set(str(user) for user in groups_config.get("ban_user_id", []))
|
|
||||||
|
|
||||||
def experimental(parent: dict):
|
|
||||||
experimental_config = parent["experimental"]
|
|
||||||
config.enable_friend_chat = experimental_config.get("enable_friend_chat", config.enable_friend_chat)
|
|
||||||
# config.enable_think_flow = experimental_config.get("enable_think_flow", config.enable_think_flow)
|
|
||||||
config.talk_allowed_private = set(str(user) for user in experimental_config.get("talk_allowed_private", []))
|
|
||||||
if config.INNER_VERSION in SpecifierSet(">=1.1.0"):
|
|
||||||
config.enable_pfc_chatting = experimental_config.get("pfc_chatting", config.enable_pfc_chatting)
|
|
||||||
|
|
||||||
# 版本表达式:>=1.0.0,<2.0.0
|
|
||||||
# 允许字段:func: method, support: str, notice: str, necessary: bool
|
|
||||||
# 如果使用 notice 字段,在该组配置加载时,会展示该字段对用户的警示
|
|
||||||
# 例如:"notice": "personality 将在 1.3.2 后被移除",那么在有效版本中的用户就会虽然可以
|
|
||||||
# 正常执行程序,但是会看到这条自定义提示
|
|
||||||
|
|
||||||
# 版本格式:主版本号.次版本号.修订号,版本号递增规则如下:
|
|
||||||
# 主版本号:当你做了不兼容的 API 修改,
|
|
||||||
# 次版本号:当你做了向下兼容的功能性新增,
|
|
||||||
# 修订号:当你做了向下兼容的问题修正。
|
|
||||||
# 先行版本号及版本编译信息可以加到"主版本号.次版本号.修订号"的后面,作为延伸。
|
|
||||||
|
|
||||||
# 如果你做了break的修改,就应该改动主版本号
|
|
||||||
# 如果做了一个兼容修改,就不应该要求这个选项是必须的!
|
|
||||||
include_configs = {
|
|
||||||
"bot": {"func": bot, "support": ">=0.0.0"},
|
|
||||||
"groups": {"func": groups, "support": ">=0.0.0"},
|
|
||||||
"personality": {"func": personality, "support": ">=0.0.0"},
|
|
||||||
"identity": {"func": identity, "support": ">=1.2.4"},
|
|
||||||
"emoji": {"func": emoji, "support": ">=0.0.0"},
|
|
||||||
"model": {"func": model, "support": ">=0.0.0"},
|
|
||||||
"memory": {"func": memory, "support": ">=0.0.0", "necessary": False},
|
|
||||||
"mood": {"func": mood, "support": ">=0.0.0"},
|
|
||||||
"remote": {"func": remote, "support": ">=0.0.10", "necessary": False},
|
|
||||||
"keywords_reaction": {"func": keywords_reaction, "support": ">=0.0.2", "necessary": False},
|
|
||||||
"chinese_typo": {"func": chinese_typo, "support": ">=0.0.3", "necessary": False},
|
|
||||||
"response_splitter": {"func": response_splitter, "support": ">=0.0.11", "necessary": False},
|
|
||||||
"experimental": {"func": experimental, "support": ">=0.0.11", "necessary": False},
|
|
||||||
"chat": {"func": chat, "support": ">=1.6.0", "necessary": False},
|
|
||||||
"normal_chat": {"func": normal_chat, "support": ">=1.6.0", "necessary": False},
|
|
||||||
"focus_chat": {"func": focus_chat, "support": ">=1.6.0", "necessary": False},
|
|
||||||
}
|
|
||||||
|
|
||||||
# 原地修改,将 字符串版本表达式 转换成 版本对象
|
|
||||||
for key in include_configs:
|
|
||||||
item_support = include_configs[key]["support"]
|
|
||||||
include_configs[key]["support"] = cls.convert_to_specifierset(item_support)
|
|
||||||
|
|
||||||
if os.path.exists(config_path):
|
|
||||||
with open(config_path, "rb") as f:
|
|
||||||
try:
|
|
||||||
toml_dict = tomli.load(f)
|
|
||||||
except tomli.TOMLDecodeError as e:
|
|
||||||
logger.critical(f"配置文件bot_config.toml填写有误,请检查第{e.lineno}行第{e.colno}处:{e.msg}")
|
|
||||||
exit(1)
|
|
||||||
|
|
||||||
# 获取配置文件版本
|
|
||||||
config.INNER_VERSION = cls.get_config_version(toml_dict)
|
|
||||||
|
|
||||||
# 如果在配置中找到了需要的项,调用对应项的闭包函数处理
|
|
||||||
for key in include_configs:
|
|
||||||
if key in toml_dict:
|
|
||||||
group_specifierset: SpecifierSet = include_configs[key]["support"]
|
|
||||||
|
|
||||||
# 检查配置文件版本是否在支持范围内
|
|
||||||
if config.INNER_VERSION in group_specifierset:
|
|
||||||
# 如果版本在支持范围内,检查是否存在通知
|
|
||||||
if "notice" in include_configs[key]:
|
|
||||||
logger.warning(include_configs[key]["notice"])
|
|
||||||
|
|
||||||
include_configs[key]["func"](toml_dict)
|
|
||||||
|
|
||||||
else:
|
|
||||||
# 如果版本不在支持范围内,崩溃并提示用户
|
|
||||||
logger.error(
|
|
||||||
f"配置文件中的 '{key}' 字段的版本 ({config.INNER_VERSION}) 不在支持范围内。\n"
|
|
||||||
f"当前程序仅支持以下版本范围: {group_specifierset}"
|
|
||||||
)
|
|
||||||
raise InvalidVersion(f"当前程序仅支持以下版本范围: {group_specifierset}")
|
|
||||||
|
|
||||||
# 如果 necessary 项目存在,而且显式声明是 False,进入特殊处理
|
|
||||||
elif "necessary" in include_configs[key] and include_configs[key].get("necessary") is False:
|
|
||||||
# 通过 pass 处理的项虽然直接忽略也是可以的,但是为了不增加理解困难,依然需要在这里显式处理
|
|
||||||
if key == "keywords_reaction":
|
|
||||||
pass
|
|
||||||
|
|
||||||
else:
|
|
||||||
# 如果用户根本没有需要的配置项,提示缺少配置
|
|
||||||
logger.error(f"配置文件中缺少必需的字段: '{key}'")
|
|
||||||
raise KeyError(f"配置文件中缺少必需的字段: '{key}'")
|
|
||||||
|
|
||||||
# identity_detail字段非空检查
|
|
||||||
if not config.identity_detail:
|
|
||||||
logger.error("配置文件错误:[identity] 部分的 identity_detail 不能为空字符串")
|
|
||||||
raise ValueError("配置文件错误:[identity] 部分的 identity_detail 不能为空字符串")
|
|
||||||
|
|
||||||
logger.success(f"成功加载配置文件: {config_path}")
|
|
||||||
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
# 获取配置文件路径
-logger.info(f"MaiCore当前版本: {mai_version}")
+logger.info(f"MaiCore当前版本: {MMC_VERSION}")
update_config()

-bot_config_floder_path = BotConfig.get_config_dir()
-logger.info(f"正在品鉴配置文件目录: {bot_config_floder_path}")
-
-bot_config_path = os.path.join(bot_config_floder_path, "bot_config.toml")
-
-if os.path.exists(bot_config_path):
-    # 如果开发环境配置文件不存在,则使用默认配置文件
-    logger.info(f"异常的新鲜,异常的美味: {bot_config_path}")
-else:
-    # 配置文件不存在
-    logger.error("配置文件不存在,请检查路径: {bot_config_path}")
-    raise FileNotFoundError(f"配置文件不存在: {bot_config_path}")
-
-global_config = BotConfig.load_config(config_path=bot_config_path)
+logger.info("正在品鉴配置文件...")
+global_config = load_config(config_path=f"{CONFIG_DIR}/bot_config.toml")
+logger.info("非常的新鲜,非常的美味!")
116 src/config/config_base.py Normal file
@@ -0,0 +1,116 @@
from dataclasses import dataclass, fields, MISSING
from typing import TypeVar, Type, Any, get_origin, get_args

T = TypeVar("T", bound="ConfigBase")

TOML_DICT_TYPE = {
    int,
    float,
    str,
    bool,
    list,
    dict,
}


@dataclass
class ConfigBase:
    """配置类的基类"""

    @classmethod
    def from_dict(cls: Type[T], data: dict[str, Any]) -> T:
        """从字典加载配置字段"""
        if not isinstance(data, dict):
            raise TypeError(f"Expected a dictionary, got {type(data).__name__}")

        init_args: dict[str, Any] = {}

        for f in fields(cls):
            field_name = f.name

            if field_name.startswith("_"):
                # 跳过以 _ 开头的字段
                continue

            if field_name not in data:
                if f.default is not MISSING or f.default_factory is not MISSING:
                    # 跳过未提供且有默认值/默认构造方法的字段
                    continue
                else:
                    raise ValueError(f"Missing required field: '{field_name}'")

            value = data[field_name]
            field_type = f.type

            try:
                init_args[field_name] = cls._convert_field(value, field_type)
            except TypeError as e:
                raise TypeError(f"Field '{field_name}' has a type error: {e}") from e
            except Exception as e:
                raise RuntimeError(f"Failed to convert field '{field_name}' to target type: {e}") from e

        return cls(**init_args)

    @classmethod
    def _convert_field(cls, value: Any, field_type: Type[Any]) -> Any:
        """
        转换字段值为指定类型

        1. 对于嵌套的 dataclass,递归调用相应的 from_dict 方法
        2. 对于泛型集合类型(list, set, tuple),递归转换每个元素
        3. 对于基础类型(int, str, float, bool),直接转换
        4. 对于其他类型,尝试直接转换,如果失败则抛出异常
        """

        # 如果是嵌套的 dataclass,递归调用 from_dict 方法
        if isinstance(field_type, type) and issubclass(field_type, ConfigBase):
            if not isinstance(value, dict):
                raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}")
            return field_type.from_dict(value)

        # 处理泛型集合类型(list, set, tuple)
        field_origin_type = get_origin(field_type)
        field_type_args = get_args(field_type)

        if field_origin_type in {list, set, tuple}:
            # 检查提供的value是否为list
            if not isinstance(value, list):
                raise TypeError(f"Expected an list for {field_type.__name__}, got {type(value).__name__}")

            if field_origin_type is list:
                return [cls._convert_field(item, field_type_args[0]) for item in value]
            elif field_origin_type is set:
                return {cls._convert_field(item, field_type_args[0]) for item in value}
            elif field_origin_type is tuple:
                # 检查提供的value长度是否与类型参数一致
                if len(value) != len(field_type_args):
                    raise TypeError(
                        f"Expected {len(field_type_args)} items for {field_type.__name__}, got {len(value)}"
                    )
                return tuple(cls._convert_field(item, arg) for item, arg in zip(value, field_type_args))

        if field_origin_type is dict:
            # 检查提供的value是否为dict
            if not isinstance(value, dict):
                raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}")

            # 检查字典的键值类型
            if len(field_type_args) != 2:
                raise TypeError(f"Expected a dictionary with two type arguments for {field_type.__name__}")
            key_type, value_type = field_type_args

            return {cls._convert_field(k, key_type): cls._convert_field(v, value_type) for k, v in value.items()}

        # 处理基础类型,例如 int, str 等
        if field_type is Any or isinstance(value, field_type):
            return value

        # 其他类型,尝试直接转换
        try:
            return field_type(value)
        except (ValueError, TypeError) as e:
            raise TypeError(f"Cannot convert {type(value).__name__} to {field_type.__name__}") from e

    def __str__(self):
        """返回配置类的字符串表示"""
        return f"{self.__class__.__name__}({', '.join(f'{f.name}={getattr(self, f.name)}' for f in fields(self))})"
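A minimal usage sketch for `ConfigBase.from_dict`, using two hypothetical nested config classes (the class and field names below are illustrative, not part of this commit):

```python
from dataclasses import dataclass, field

from src.config.config_base import ConfigBase


@dataclass
class DatabaseConfig(ConfigBase):
    host: str
    port: int = 5432


@dataclass
class AppConfig(ConfigBase):
    database: DatabaseConfig
    debug: bool = False
    admins: list[str] = field(default_factory=list)


cfg = AppConfig.from_dict({"database": {"host": "localhost", "port": 3306}, "admins": ["10001"]})
# cfg.database is a DatabaseConfig instance; omitted optional fields keep their
# defaults, and omitting a required field (e.g. "database") raises ValueError.
print(cfg)
```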
399 src/config/official_configs.py Normal file
@@ -0,0 +1,399 @@
from dataclasses import dataclass, field
from typing import Any

from src.config.config_base import ConfigBase

"""
须知:
1. 本文件中记录了所有的配置项
2. 所有新增的class都需要继承自ConfigBase
3. 所有新增的class都应在config.py中的Config类中添加字段
4. 对于新增的字段,若为可选项,则应在其后添加field()并设置default_factory或default
"""


@dataclass
class BotConfig(ConfigBase):
    """QQ机器人配置类"""

    qq_account: str
    """QQ账号"""

    nickname: str
    """昵称"""

    alias_names: list[str] = field(default_factory=lambda: [])
    """别名列表"""


@dataclass
class ChatTargetConfig(ConfigBase):
    """
    聊天目标配置类
    此类中有聊天的群组和用户配置
    """

    talk_allowed_groups: set[str] = field(default_factory=lambda: set())
    """允许聊天的群组列表"""

    talk_frequency_down_groups: set[str] = field(default_factory=lambda: set())
    """降低聊天频率的群组列表"""

    ban_user_id: set[str] = field(default_factory=lambda: set())
    """禁止聊天的用户列表"""


@dataclass
class PersonalityConfig(ConfigBase):
    """人格配置类"""

    personality_core: str
    """核心人格"""

    expression_style: str
    """表达风格"""

    personality_sides: list[str] = field(default_factory=lambda: [])
    """人格侧写"""


@dataclass
class IdentityConfig(ConfigBase):
    """个体特征配置类"""

    height: int = 170
    """身高(单位:厘米)"""

    weight: float = 50
    """体重(单位:千克)"""

    age: int = 18
    """年龄(单位:岁)"""

    gender: str = "女"
    """性别(男/女)"""

    appearance: str = "可爱"
    """外貌描述"""

    identity_detail: list[str] = field(default_factory=lambda: [])
    """身份特征"""


@dataclass
class PlatformsConfig(ConfigBase):
    """平台配置类"""

    qq: str
    """QQ适配器连接URL配置"""


@dataclass
class ChatConfig(ConfigBase):
    """聊天配置类"""

    allow_focus_mode: bool = True
    """是否允许专注聊天状态"""

    base_normal_chat_num: int = 3
    """最多允许多少个群进行普通聊天"""

    base_focused_chat_num: int = 2
    """最多允许多少个群进行专注聊天"""

    observation_context_size: int = 12
    """可观察到的最长上下文大小,超过这个值的上下文会被压缩"""

    message_buffer: bool = True
    """消息缓冲器"""

    ban_words: set[str] = field(default_factory=lambda: set())
    """过滤词列表"""

    ban_msgs_regex: set[str] = field(default_factory=lambda: set())
    """过滤正则表达式列表"""


@dataclass
class NormalChatConfig(ConfigBase):
    """普通聊天配置类"""

    reasoning_model_probability: float = 0.3
    """
    发言时选择推理模型的概率(0-1之间)
    选择普通模型的概率为 1 - reasoning_normal_model_probability
    """

    emoji_chance: float = 0.2
    """发送表情包的基础概率"""

    thinking_timeout: int = 120
    """最长思考时间"""

    willing_mode: str = "classical"
    """意愿模式"""

    response_willing_amplifier: float = 1.0
    """回复意愿放大系数"""

    response_interested_rate_amplifier: float = 1.0
    """回复兴趣度放大系数"""

    down_frequency_rate: float = 3.0
    """降低回复频率的群组回复意愿降低系数"""

    emoji_response_penalty: float = 0.0
    """表情包回复惩罚系数"""

    mentioned_bot_inevitable_reply: bool = False
    """提及 bot 必然回复"""

    at_bot_inevitable_reply: bool = False
    """@bot 必然回复"""


@dataclass
class FocusChatConfig(ConfigBase):
    """专注聊天配置类"""

    reply_trigger_threshold: float = 3.0
    """心流聊天触发阈值,越低越容易触发"""

    default_decay_rate_per_second: float = 0.98
    """默认衰减率,越大衰减越快"""

    consecutive_no_reply_threshold: int = 3
    """连续不回复的次数阈值"""

    compressed_length: int = 5
    """心流上下文压缩的最短压缩长度,超过心流观察到的上下文长度,会压缩,最短压缩长度为5"""

    compress_length_limit: int = 5
    """最多压缩份数,超过该数值的压缩上下文会被删除"""


@dataclass
class EmojiConfig(ConfigBase):
    """表情包配置类"""

    max_reg_num: int = 200
    """表情包最大注册数量"""

    do_replace: bool = True
    """达到最大注册数量时替换旧表情包"""

    check_interval: int = 120
    """表情包检查间隔(分钟)"""

    save_pic: bool = False
    """是否保存图片"""

    cache_emoji: bool = True
    """是否缓存表情包"""

    steal_emoji: bool = True
    """是否偷取表情包,让麦麦可以发送她保存的这些表情包"""

    content_filtration: bool = False
    """是否开启表情包过滤"""

    filtration_prompt: str = "符合公序良俗"
    """表情包过滤要求"""


@dataclass
class MemoryConfig(ConfigBase):
    """记忆配置类"""

    memory_build_interval: int = 600
    """记忆构建间隔(秒)"""

    memory_build_distribution: tuple[
        float,
        float,
        float,
        float,
        float,
        float,
    ] = field(default_factory=lambda: (6.0, 3.0, 0.6, 32.0, 12.0, 0.4))
    """记忆构建分布,参数:分布1均值,标准差,权重,分布2均值,标准差,权重"""

    memory_build_sample_num: int = 8
    """记忆构建采样数量"""

    memory_build_sample_length: int = 40
    """记忆构建采样长度"""

    memory_compress_rate: float = 0.1
    """记忆压缩率"""

    forget_memory_interval: int = 1000
    """记忆遗忘间隔(秒)"""

    memory_forget_time: int = 24
    """记忆遗忘时间(小时)"""

    memory_forget_percentage: float = 0.01
    """记忆遗忘比例"""

    consolidate_memory_interval: int = 1000
    """记忆整合间隔(秒)"""

    consolidation_similarity_threshold: float = 0.7
    """整合相似度阈值"""

    consolidate_memory_percentage: float = 0.01
    """整合检查节点比例"""

    memory_ban_words: list[str] = field(default_factory=lambda: ["表情包", "图片", "回复", "聊天记录"])
    """不允许记忆的词列表"""


@dataclass
class MoodConfig(ConfigBase):
    """情绪配置类"""

    mood_update_interval: int = 1
    """情绪更新间隔(秒)"""

    mood_decay_rate: float = 0.95
    """情绪衰减率"""

    mood_intensity_factor: float = 0.7
    """情绪强度因子"""


@dataclass
class KeywordRuleConfig(ConfigBase):
    """关键词规则配置类"""

    enable: bool = True
    """是否启用关键词规则"""

    keywords: list[str] = field(default_factory=lambda: [])
    """关键词列表"""

    regex: list[str] = field(default_factory=lambda: [])
    """正则表达式列表"""

    reaction: str = ""
    """关键词触发的反应"""


@dataclass
class KeywordReactionConfig(ConfigBase):
    """关键词配置类"""

    enable: bool = True
    """是否启用关键词反应"""

    rules: list[KeywordRuleConfig] = field(default_factory=lambda: [])
    """关键词反应规则列表"""


@dataclass
class ChineseTypoConfig(ConfigBase):
    """中文错别字配置类"""

    enable: bool = True
    """是否启用中文错别字生成器"""

    error_rate: float = 0.01
    """单字替换概率"""

    min_freq: int = 9
    """最小字频阈值"""

    tone_error_rate: float = 0.1
    """声调错误概率"""

    word_replace_rate: float = 0.006
    """整词替换概率"""


@dataclass
class ResponseSplitterConfig(ConfigBase):
    """回复分割器配置类"""

    enable: bool = True
    """是否启用回复分割器"""

    max_length: int = 256
    """回复允许的最大长度"""

    max_sentence_num: int = 3
    """回复允许的最大句子数"""

    enable_kaomoji_protection: bool = False
    """是否启用颜文字保护"""


@dataclass
class TelemetryConfig(ConfigBase):
    """遥测配置类"""

    enable: bool = True
    """是否启用遥测"""


@dataclass
class ExperimentalConfig(ConfigBase):
    """实验功能配置类"""

    enable_friend_chat: bool = False
    """是否启用好友聊天"""

    talk_allowed_private: set[str] = field(default_factory=lambda: set())
    """允许聊天的私聊列表"""

    pfc_chatting: bool = False
    """是否启用PFC"""


@dataclass
class ModelConfig(ConfigBase):
    """模型配置类"""

    model_max_output_length: int = 800  # 最大回复长度

    reasoning: dict[str, Any] = field(default_factory=lambda: {})
    """推理模型配置"""

    normal: dict[str, Any] = field(default_factory=lambda: {})
    """普通模型配置"""

    topic_judge: dict[str, Any] = field(default_factory=lambda: {})
    """主题判断模型配置"""

    summary: dict[str, Any] = field(default_factory=lambda: {})
    """摘要模型配置"""

    vlm: dict[str, Any] = field(default_factory=lambda: {})
    """视觉语言模型配置"""

    heartflow: dict[str, Any] = field(default_factory=lambda: {})
    """心流模型配置"""

    observation: dict[str, Any] = field(default_factory=lambda: {})
    """观察模型配置"""

    sub_heartflow: dict[str, Any] = field(default_factory=lambda: {})
    """子心流模型配置"""

    plan: dict[str, Any] = field(default_factory=lambda: {})
    """计划模型配置"""

    embedding: dict[str, Any] = field(default_factory=lambda: {})
    """嵌入模型配置"""

    pfc_action_planner: dict[str, Any] = field(default_factory=lambda: {})
    """PFC动作规划模型配置"""

    pfc_chat: dict[str, Any] = field(default_factory=lambda: {})
    """PFC聊天模型配置"""

    pfc_reply_checker: dict[str, Any] = field(default_factory=lambda: {})
    """PFC回复检查模型配置"""

    tool_use: dict[str, Any] = field(default_factory=lambda: {})
    """工具使用模型配置"""
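For orientation, a hedged sketch of how call sites change once the flat attributes of the old `BotConfig` are replaced by the nested sections above; the import path is an assumption that mirrors how `global_config` is used elsewhere in the project.

```python
# Assumed import path, matching existing usage of global_config in this repo.
from src.config.config import global_config

# Old flat access (removed in this commit):
#   global_config.BOT_NICKNAME, global_config.BOT_QQ, global_config.llm_normal, global_config.remote_enable
# New nested access through the dataclasses defined above:
nickname = global_config.bot.nickname
qq_account = global_config.bot.qq_account
normal_model = global_config.model.normal           # dict[str, Any], e.g. "name", "temp"
telemetry_enabled = global_config.telemetry.enable  # replaces remote_enable
```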
@@ -114,7 +114,7 @@ class ActionPlanner:
            request_type="action_planning",
        )
        self.personality_info = Individuality.get_instance().get_prompt(x_person=2, level=3)
-        self.name = global_config.BOT_NICKNAME
+        self.name = global_config.bot.nickname
        self.private_name = private_name
        self.chat_observer = ChatObserver.get_instance(stream_id, private_name)
        # self.action_planner_info = ActionPlannerInfo()  # 移除未使用的变量
@@ -140,7 +140,7 @@ class ActionPlanner:
        # (这部分逻辑不变)
        time_since_last_bot_message_info = ""
        try:
-            bot_id = str(global_config.BOT_QQ)
+            bot_id = str(global_config.bot.qq_account)
            if hasattr(observation_info, "chat_history") and observation_info.chat_history:
                for i in range(len(observation_info.chat_history) - 1, -1, -1):
                    msg = observation_info.chat_history[i]
@@ -10,7 +10,7 @@ from src.experimental.PFC.chat_states import (
    create_new_message_notification,
    create_cold_chat_notification,
)
-from src.experimental.PFC.message_storage import MongoDBMessageStorage
+from src.experimental.PFC.message_storage import PeeweeMessageStorage
from rich.traceback import install

install(extra_lines=3)
@@ -53,7 +53,7 @@ class ChatObserver:

        self.stream_id = stream_id
        self.private_name = private_name
-        self.message_storage = MongoDBMessageStorage()
+        self.message_storage = PeeweeMessageStorage()

        # self.last_user_speak_time: Optional[float] = None  # 对方上次发言时间
        # self.last_bot_speak_time: Optional[float] = None  # 机器人上次发言时间
@@ -323,7 +323,7 @@ class ChatObserver:
        for msg in messages:
            try:
                user_info = UserInfo.from_dict(msg.get("user_info", {}))
-                if user_info.user_id == global_config.BOT_QQ:
+                if user_info.user_id == global_config.bot.qq_account:
                    self.update_bot_speak_time(msg["time"])
                else:
                    self.update_user_speak_time(msg["time"])
@@ -42,8 +42,8 @@ class DirectMessageSender:

        # 获取麦麦的信息
        bot_user_info = UserInfo(
-            user_id=global_config.BOT_QQ,
-            user_nickname=global_config.BOT_NICKNAME,
+            user_id=global_config.bot.qq_account,
+            user_nickname=global_config.bot.nickname,
            platform=chat_stream.platform,
        )
@@ -1,6 +1,9 @@
from abc import ABC, abstractmethod
from typing import List, Dict, Any
-from src.common.database import db
+
+# from src.common.database.database import db  # Peewee db 导入
+from src.common.database.database_model import Messages  # Peewee Messages 模型导入
+from playhouse.shortcuts import model_to_dict  # 用于将模型实例转换为字典


class MessageStorage(ABC):
@@ -47,28 +50,35 @@ class MessageStorage(ABC):
        pass


-class MongoDBMessageStorage(MessageStorage):
-    """MongoDB消息存储实现"""
+class PeeweeMessageStorage(MessageStorage):
+    """Peewee消息存储实现"""

    async def get_messages_after(self, chat_id: str, message_time: float) -> List[Dict[str, Any]]:
-        query = {"chat_id": chat_id, "time": {"$gt": message_time}}
-        # print(f"storage_check_message: {message_time}")
-        return list(db.messages.find(query).sort("time", 1))
+        query = (
+            Messages.select()
+            .where((Messages.chat_id == chat_id) & (Messages.time > message_time))
+            .order_by(Messages.time.asc())
+        )
+
+        # print(f"storage_check_message: {message_time}")
+        messages_models = list(query)
+        return [model_to_dict(msg) for msg in messages_models]

    async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]:
-        query = {"chat_id": chat_id, "time": {"$lt": time_point}}
-        messages = list(db.messages.find(query).sort("time", -1).limit(limit))
+        query = (
+            Messages.select()
+            .where((Messages.chat_id == chat_id) & (Messages.time < time_point))
+            .order_by(Messages.time.desc())
+            .limit(limit)
+        )
+
+        messages_models = list(query)
        # 将消息按时间正序排列
-        messages.reverse()
-        return messages
+        messages_models.reverse()
+        return [model_to_dict(msg) for msg in messages_models]

    async def has_new_messages(self, chat_id: str, after_time: float) -> bool:
-        query = {"chat_id": chat_id, "time": {"$gt": after_time}}
-
-        return db.messages.find_one(query) is not None
+        return Messages.select().where((Messages.chat_id == chat_id) & (Messages.time > after_time)).exists()


# # 创建一个内存消息存储实现,用于测试
|||||||
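The queries above assume a Peewee `Messages` model with at least `chat_id` and `time` columns. A minimal, self-contained sketch of that pattern (field names inferred from the `.where()` clauses in this hunk; the SQLite binding and the extra text field are illustrative, not the project's actual `database_model.py`):

# Illustrative sketch only: the real Messages model lives in
# src/common/database/database_model.py; field names here are inferred
# from the queries in the hunk above.
import asyncio
from peewee import SqliteDatabase, Model, CharField, FloatField, TextField
from playhouse.shortcuts import model_to_dict

db = SqliteDatabase(":memory:")  # assumption: any Peewee database binding works here

class Messages(Model):
    chat_id = CharField(index=True)
    time = FloatField(index=True)
    processed_plain_text = TextField(null=True)  # hypothetical extra field

    class Meta:
        database = db

async def get_messages_after(chat_id: str, message_time: float):
    # Same shape as PeeweeMessageStorage.get_messages_after above.
    query = (
        Messages.select()
        .where((Messages.chat_id == chat_id) & (Messages.time > message_time))
        .order_by(Messages.time.asc())
    )
    return [model_to_dict(m) for m in query]

if __name__ == "__main__":
    db.connect()
    db.create_tables([Messages])
    Messages.create(chat_id="g1", time=1.0, processed_plain_text="hi")
    Messages.create(chat_id="g1", time=2.0, processed_plain_text="hello")
    print(asyncio.run(get_messages_after("g1", 1.5)))  # one dict, time == 2.0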
@@ -42,13 +42,14 @@ class GoalAnalyzer:
     """对话目标分析器"""

     def __init__(self, stream_id: str, private_name: str):
+        # TODO: API-Adapter修改标记
         self.llm = LLMRequest(
-            model=global_config.llm_normal, temperature=0.7, max_tokens=1000, request_type="conversation_goal"
+            model=global_config.model.normal, temperature=0.7, max_tokens=1000, request_type="conversation_goal"
         )

         self.personality_info = Individuality.get_instance().get_prompt(x_person=2, level=3)
-        self.name = global_config.BOT_NICKNAME
-        self.nick_name = global_config.BOT_ALIAS_NAMES
+        self.name = global_config.bot.nickname
+        self.nick_name = global_config.bot.alias_names
         self.private_name = private_name
         self.chat_observer = ChatObserver.get_instance(stream_id, private_name)

@@ -14,9 +14,10 @@ class KnowledgeFetcher:
     """知识调取器"""

     def __init__(self, private_name: str):
+        # TODO: API-Adapter修改标记
         self.llm = LLMRequest(
-            model=global_config.llm_normal,
-            temperature=global_config.llm_normal["temp"],
+            model=global_config.model.normal,
+            temperature=global_config.model.normal["temp"],
             max_tokens=1000,
             request_type="knowledge_fetch",
         )
@@ -16,7 +16,7 @@ class ReplyChecker:
         self.llm = LLMRequest(
             model=global_config.llm_PFC_reply_checker, temperature=0.50, max_tokens=1000, request_type="reply_check"
         )
-        self.name = global_config.BOT_NICKNAME
+        self.name = global_config.bot.nickname
         self.private_name = private_name
         self.chat_observer = ChatObserver.get_instance(stream_id, private_name)
         self.max_retries = 3  # 最大重试次数
@@ -43,7 +43,7 @@ class ReplyChecker:
         bot_messages = []
         for msg in reversed(chat_history):
             user_info = UserInfo.from_dict(msg.get("user_info", {}))
-            if str(user_info.user_id) == str(global_config.BOT_QQ):  # 确保比较的是字符串
+            if str(user_info.user_id) == str(global_config.bot.qq_account):  # 确保比较的是字符串
                 bot_messages.append(msg.get("processed_plain_text", ""))
                 if len(bot_messages) >= 2:  # 只和最近的两条比较
                     break
@@ -93,7 +93,7 @@ class ReplyGenerator:
             request_type="reply_generation",
         )
         self.personality_info = Individuality.get_instance().get_prompt(x_person=2, level=3)
-        self.name = global_config.BOT_NICKNAME
+        self.name = global_config.bot.nickname
         self.private_name = private_name
         self.chat_observer = ChatObserver.get_instance(stream_id, private_name)
         self.reply_checker = ReplyChecker(stream_id, private_name)
@@ -19,7 +19,7 @@ class Waiter:

     def __init__(self, stream_id: str, private_name: str):
         self.chat_observer = ChatObserver.get_instance(stream_id, private_name)
-        self.name = global_config.BOT_NICKNAME
+        self.name = global_config.bot.nickname
         self.private_name = private_name
         # self.wait_accumulated_time = 0  # 不再需要累加计时

@@ -16,7 +16,7 @@ class MessageProcessor:
     @staticmethod
     def _check_ban_words(text: str, chat, userinfo) -> bool:
         """检查消息中是否包含过滤词"""
-        for word in global_config.ban_words:
+        for word in global_config.chat.ban_words:
             if word in text:
                 logger.info(
                     f"[{chat.group_info.group_name if chat.group_info else '私聊'}]{userinfo.user_nickname}:{text}"
@@ -28,7 +28,7 @@ class MessageProcessor:
     @staticmethod
     def _check_ban_regex(text: str, chat, userinfo) -> bool:
         """检查消息是否匹配过滤正则表达式"""
-        for pattern in global_config.ban_msgs_regex:
+        for pattern in global_config.chat.ban_msgs_regex:
             if pattern.search(text):
                 logger.info(
                     f"[{chat.group_info.group_name if chat.group_info else '私聊'}]{userinfo.user_nickname}:{text}"
src/main.py (32 changed lines)
@@ -40,7 +40,7 @@ class MainSystem:

     async def initialize(self):
         """初始化系统组件"""
-        logger.debug(f"正在唤醒{global_config.BOT_NICKNAME}......")
+        logger.debug(f"正在唤醒{global_config.bot.nickname}......")

         # 其他初始化任务
         await asyncio.gather(self._init_components())
@@ -84,7 +84,7 @@ class MainSystem:
         asyncio.create_task(chat_manager._auto_save_task())

         # 使用HippocampusManager初始化海马体
-        self.hippocampus_manager.initialize(global_config=global_config)
+        self.hippocampus_manager.initialize()
         # await asyncio.sleep(0.5) #防止logger输出飞了

         # 将bot.py中的chat_bot.message_process消息处理函数注册到api.py的消息处理基类中
@@ -92,15 +92,15 @@ class MainSystem:

         # 初始化个体特征
         self.individuality.initialize(
-            bot_nickname=global_config.BOT_NICKNAME,
-            personality_core=global_config.personality_core,
-            personality_sides=global_config.personality_sides,
-            identity_detail=global_config.identity_detail,
-            height=global_config.height,
-            weight=global_config.weight,
-            age=global_config.age,
-            gender=global_config.gender,
-            appearance=global_config.appearance,
+            bot_nickname=global_config.bot.nickname,
+            personality_core=global_config.personality.personality_core,
+            personality_sides=global_config.personality.personality_sides,
+            identity_detail=global_config.identity.identity_detail,
+            height=global_config.identity.height,
+            weight=global_config.identity.weight,
+            age=global_config.identity.age,
+            gender=global_config.identity.gender,
+            appearance=global_config.identity.appearance,
         )
         logger.success("个体特征初始化成功")

@@ -141,7 +141,7 @@ class MainSystem:
     async def build_memory_task():
         """记忆构建任务"""
         while True:
-            await asyncio.sleep(global_config.build_memory_interval)
+            await asyncio.sleep(global_config.memory.memory_build_interval)
             logger.info("正在进行记忆构建")
             await HippocampusManager.get_instance().build_memory()

@@ -149,16 +149,18 @@ class MainSystem:
     async def forget_memory_task():
         """记忆遗忘任务"""
         while True:
-            await asyncio.sleep(global_config.forget_memory_interval)
+            await asyncio.sleep(global_config.memory.forget_memory_interval)
             print("\033[1;32m[记忆遗忘]\033[0m 开始遗忘记忆...")
-            await HippocampusManager.get_instance().forget_memory(percentage=global_config.memory_forget_percentage)
+            await HippocampusManager.get_instance().forget_memory(
+                percentage=global_config.memory.memory_forget_percentage
+            )
             print("\033[1;32m[记忆遗忘]\033[0m 记忆遗忘完成")

     @staticmethod
     async def consolidate_memory_task():
         """记忆整合任务"""
         while True:
-            await asyncio.sleep(global_config.consolidate_memory_interval)
+            await asyncio.sleep(global_config.memory.consolidate_memory_interval)
             print("\033[1;32m[记忆整合]\033[0m 开始整合记忆...")
             await HippocampusManager.get_instance().consolidate_memory()
             print("\033[1;32m[记忆整合]\033[0m 记忆整合完成")
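All three loops follow the same shape: sleep for a configured interval, then call the matching HippocampusManager method. A minimal sketch of that sleep-then-run scheduling pattern, with hard-coded stand-ins for the global_config.memory.* intervals and placeholder job bodies:

# Sketch of the sleep-then-run loop used by the build/forget/consolidate tasks.
# Intervals and job bodies are placeholders, not the project's real values.
import asyncio

async def periodic(name: str, interval: float, job):
    while True:
        await asyncio.sleep(interval)
        print(f"[{name}] running")
        await job()

async def build_memory():   # stand-in for HippocampusManager.build_memory()
    await asyncio.sleep(0)

async def forget_memory():  # stand-in for forget_memory(percentage=...)
    await asyncio.sleep(0)

async def main():
    tasks = [
        asyncio.create_task(periodic("memory-build", 2000, build_memory)),
        asyncio.create_task(periodic("memory-forget", 1000, forget_memory)),
    ]
    await asyncio.gather(*tasks)

# asyncio.run(main())  # runs forever; shown for illustration only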
@@ -34,14 +34,14 @@ class MoodUpdateTask(AsyncTask):
     def __init__(self):
         super().__init__(
             task_name="Mood Update Task",
-            wait_before_start=global_config.mood_update_interval,
-            run_interval=global_config.mood_update_interval,
+            wait_before_start=global_config.mood.mood_update_interval,
+            run_interval=global_config.mood.mood_update_interval,
         )

         # 从配置文件获取衰减率
-        self.decay_rate_valence: float = 1 - global_config.mood_decay_rate
+        self.decay_rate_valence: float = 1 - global_config.mood.mood_decay_rate
         """愉悦度衰减率"""
-        self.decay_rate_arousal: float = 1 - global_config.mood_decay_rate
+        self.decay_rate_arousal: float = 1 - global_config.mood.mood_decay_rate
         """唤醒度衰减率"""

         self.last_update = time.time()
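The task stores 1 - mood_decay_rate for both valence and arousal. The exact update rule lives in MoodManager and is not part of this hunk; the following is only a hypothetical illustration of how an exponential decay toward a neutral value could use the configured rate:

# Hypothetical decay step; the real MoodManager update rule may differ.
mood_decay_rate = 0.95                      # from [mood] in bot_config.toml
decay_rate_valence = 1 - mood_decay_rate    # 0.05, as computed in MoodUpdateTask

def decay_toward_neutral(valence: float, seconds: float, neutral: float = 0.0) -> float:
    """Pull valence toward neutral, keeping mood_decay_rate of the offset per second."""
    factor = (1 - decay_rate_valence) ** seconds   # 0.95 ** seconds
    return neutral + (valence - neutral) * factor

print(round(decay_toward_neutral(0.8, 10), 2))  # ~0.48 after 10 seconds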
@@ -44,7 +44,7 @@ class ChangeMoodTool(BaseTool):
             _ori_response = ",".join(response_set)
             # _stance, emotion = await gpt._get_emotion_tags(ori_response, message_processed_plain_text)
             emotion = "平静"
-            mood_manager.update_mood_from_emotion(emotion, global_config.mood_intensity_factor)
+            mood_manager.update_mood_from_emotion(emotion, global_config.mood.mood_intensity_factor)
             return {"name": "change_mood", "content": f"你的心情刚刚变化了,现在的心情是: {emotion}"}
         except Exception as e:
             logger.error(f"心情改变工具执行失败: {str(e)}")
@@ -1,8 +1,10 @@
 from src.tools.tool_can_use.base_tool import BaseTool
 from src.chat.utils.utils import get_embedding
-from src.common.database import db
+from src.common.database.database_model import Knowledges  # Updated import
 from src.common.logger_manager import get_logger
-from typing import Any, Union
+from typing import Any, Union, List  # Added List
+import json  # Added for parsing embedding
+import math  # Added for cosine similarity

 logger = get_logger("get_knowledge_tool")

@@ -30,6 +32,7 @@ class SearchKnowledgeTool(BaseTool):
         Returns:
             dict: 工具执行结果
         """
+        query = ""  # Initialize query to ensure it's defined in except block
         try:
             query = function_args.get("query")
             threshold = function_args.get("threshold", 0.4)
@@ -48,9 +51,19 @@ class SearchKnowledgeTool(BaseTool):
             logger.error(f"知识库搜索工具执行失败: {str(e)}")
             return {"type": "info", "id": query, "content": f"知识库搜索失败,炸了: {str(e)}"}

+    @staticmethod
+    def _cosine_similarity(vec1: List[float], vec2: List[float]) -> float:
+        """计算两个向量之间的余弦相似度"""
+        dot_product = sum(p * q for p, q in zip(vec1, vec2))
+        magnitude1 = math.sqrt(sum(p * p for p in vec1))
+        magnitude2 = math.sqrt(sum(q * q for q in vec2))
+        if magnitude1 == 0 or magnitude2 == 0:
+            return 0.0
+        return dot_product / (magnitude1 * magnitude2)
+
     @staticmethod
     def get_info_from_db(
-        query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False
+        query_embedding: list[float], limit: int = 1, threshold: float = 0.5, return_raw: bool = False
     ) -> Union[str, list]:
         """从数据库中获取相关信息

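A quick standalone check of the cosine-similarity helper added above (same arithmetic, outside the class):

# Verifies the cosine formula on three easy cases.
import math

def cosine_similarity(vec1, vec2):
    dot_product = sum(p * q for p, q in zip(vec1, vec2))
    magnitude1 = math.sqrt(sum(p * p for p in vec1))
    magnitude2 = math.sqrt(sum(q * q for q in vec2))
    if magnitude1 == 0 or magnitude2 == 0:
        return 0.0
    return dot_product / (magnitude1 * magnitude2)

print(cosine_similarity([1.0, 0.0], [1.0, 0.0]))             # 1.0 (same direction)
print(cosine_similarity([1.0, 0.0], [0.0, 1.0]))             # 0.0 (orthogonal)
print(round(cosine_similarity([1.0, 1.0], [1.0, 0.0]), 3))   # 0.707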
@@ -66,66 +79,51 @@ class SearchKnowledgeTool(BaseTool):
         if not query_embedding:
             return "" if not return_raw else []

-        # 使用余弦相似度计算
-        pipeline = [
-            {
-                "$addFields": {
-                    "dotProduct": {
-                        "$reduce": {
-                            "input": {"$range": [0, {"$size": "$embedding"}]},
-                            "initialValue": 0,
-                            "in": {
-                                "$add": [
-                                    "$$value",
-                                    {
-                                        "$multiply": [
-                                            {"$arrayElemAt": ["$embedding", "$$this"]},
-                                            {"$arrayElemAt": [query_embedding, "$$this"]},
-                                        ]
-                                    },
-                                ]
-                            },
-                        }
-                    },
-                    "magnitude1": {
-                        "$sqrt": {
-                            "$reduce": {
-                                "input": "$embedding",
-                                "initialValue": 0,
-                                "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
-                            }
-                        }
-                    },
-                    "magnitude2": {
-                        "$sqrt": {
-                            "$reduce": {
-                                "input": query_embedding,
-                                "initialValue": 0,
-                                "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
-                            }
-                        }
-                    },
-                }
-            },
-            {"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}},
-            {
-                "$match": {
-                    "similarity": {"$gte": threshold}  # 只保留相似度大于等于阈值的结果
-                }
-            },
-            {"$sort": {"similarity": -1}},
-            {"$limit": limit},
-            {"$project": {"content": 1, "similarity": 1}},
-        ]
-
-        results = list(db.knowledges.aggregate(pipeline))
-        logger.debug(f"知识库查询结果数量: {len(results)}")
+        similar_items = []
+        try:
+            all_knowledges = Knowledges.select()
+            for item in all_knowledges:
+                try:
+                    item_embedding_str = item.embedding
+                    if not item_embedding_str:
+                        logger.warning(f"Knowledge item ID {item.id} has empty embedding string.")
+                        continue
+                    item_embedding = json.loads(item_embedding_str)
+                    if not isinstance(item_embedding, list) or not all(
+                        isinstance(x, (int, float)) for x in item_embedding
+                    ):
+                        logger.warning(f"Knowledge item ID {item.id} has invalid embedding format after JSON parsing.")
+                        continue
+                except json.JSONDecodeError:
+                    logger.warning(f"Failed to parse embedding for knowledge item ID {item.id}")
+                    continue
+                except AttributeError:
+                    logger.warning(f"Knowledge item ID {item.id} missing 'embedding' attribute or it's not a string.")
+                    continue
+
+                similarity = SearchKnowledgeTool._cosine_similarity(query_embedding, item_embedding)
+                if similarity >= threshold:
+                    similar_items.append({"content": item.content, "similarity": similarity, "raw_item": item})
+
+            # 按相似度降序排序
+            similar_items.sort(key=lambda x: x["similarity"], reverse=True)
+
+            # 应用限制
+            results = similar_items[:limit]
+            logger.debug(f"知识库查询后,符合条件的结果数量: {len(results)}")
+
+        except Exception as e:
+            logger.error(f"从 Peewee 数据库获取知识信息失败: {str(e)}")
+            return "" if not return_raw else []

         if not results:
             return "" if not return_raw else []

         if return_raw:
-            return results
+            # Peewee 模型实例不能直接序列化为 JSON,如果需要原始模型,调用者需要处理
+            # 这里返回包含内容和相似度的字典列表
+            return [{"content": r["content"], "similarity": r["similarity"]} for r in results]
         else:
             # 返回所有找到的内容,用换行分隔
             return "\n".join(str(result["content"]) for result in results)
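The rewritten get_info_from_db trades the MongoDB aggregation (similarity computed inside the database) for a full scan in Python: every Knowledges row is loaded, its JSON-encoded embedding parsed, scored, filtered by threshold, sorted, and truncated to limit. A compact sketch of that ranking flow with inline data instead of Peewee rows:

# Inline stand-ins for Knowledges rows; only the ranking flow is shown.
import json

stored = [
    {"content": "苹果是一种水果", "embedding": json.dumps([0.9, 0.1])},
    {"content": "Python是编程语言", "embedding": json.dumps([0.1, 0.9])},
]
query_embedding = [1.0, 0.0]
threshold, limit = 0.4, 1

def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    na = sum(x * x for x in a) ** 0.5
    nb = sum(y * y for y in b) ** 0.5
    return 0.0 if na == 0 or nb == 0 else dot / (na * nb)

scored = []
for item in stored:
    emb = json.loads(item["embedding"])          # embeddings are stored as JSON strings
    sim = cosine(query_embedding, emb)
    if sim >= threshold:                         # keep only items above the threshold
        scored.append({"content": item["content"], "similarity": sim})

scored.sort(key=lambda x: x["similarity"], reverse=True)
print(scored[:limit])  # [{'content': '苹果是一种水果', 'similarity': ~0.994}]

This keeps the storage layer simple, but each query is O(N) over the whole knowledge table with no index, which is acceptable only while that table stays small.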
@@ -15,7 +15,7 @@ logger = get_logger("tool_use")
 class ToolUser:
     def __init__(self):
         self.llm_model_tool = LLMRequest(
-            model=global_config.llm_tool_use, temperature=0.2, max_tokens=1000, request_type="tool_use"
+            model=global_config.model.tool_use, temperature=0.2, max_tokens=1000, request_type="tool_use"
         )

     @staticmethod
@@ -37,7 +37,7 @@ class ToolUser:
         # print(f"intol111111111111111111111111111111111222222222222mid_memory_info:{mid_memory_info}")

         # 这些信息应该从调用者传入,而不是从self获取
-        bot_name = global_config.BOT_NICKNAME
+        bot_name = global_config.bot.nickname
         prompt = ""
         prompt += mid_memory_info
         prompt += "你正在思考如何回复群里的消息。\n"
@@ -1,104 +0,0 @@
-[inner.version]
-describe = "版本号"
-important = true
-can_edit = false
-
-[bot.qq]
-describe = "机器人的QQ号"
-important = true
-can_edit = true
-
-[bot.nickname]
-describe = "机器人的昵称"
-important = true
-can_edit = true
-
-[bot.alias_names]
-describe = "机器人的别名列表,该选项还在调试中,暂时未生效"
-important = false
-can_edit = true
-
-[groups.talk_allowed]
-describe = "可以回复消息的群号码列表"
-important = true
-can_edit = true
-
-[groups.talk_frequency_down]
-describe = "降低回复频率的群号码列表"
-important = false
-can_edit = true
-
-[groups.ban_user_id]
-describe = "禁止回复和读取消息的QQ号列表"
-important = false
-can_edit = true
-
-[personality.personality_core]
-describe = "用一句话或几句话描述人格的核心特点,建议20字以内"
-important = true
-can_edit = true
-
-[personality.personality_sides]
-describe = "用一句话或几句话描述人格的一些细节,条数任意,不能为0,该选项还在调试中"
-important = false
-can_edit = true
-
-[identity.identity_detail]
-describe = "身份特点列表,条数任意,不能为0,该选项还在调试中"
-important = false
-can_edit = true
-
-[identity.age]
-describe = "年龄,单位岁"
-important = false
-can_edit = true
-
-[identity.gender]
-describe = "性别"
-important = false
-can_edit = true
-
-[identity.appearance]
-describe = "外貌特征描述,该选项还在调试中,暂时未生效"
-important = false
-can_edit = true
-
-[platforms.nonebot-qq]
-describe = "nonebot-qq适配器提供的链接"
-important = true
-can_edit = true
-
-[chat.allow_focus_mode]
-describe = "是否允许专注聊天状态"
-important = false
-can_edit = true
-
-[chat.base_normal_chat_num]
-describe = "最多允许多少个群进行普通聊天"
-important = false
-can_edit = true
-
-[chat.base_focused_chat_num]
-describe = "最多允许多少个群进行专注聊天"
-important = false
-can_edit = true
-
-[chat.observation_context_size]
-describe = "观察到的最长上下文大小,建议15,太短太长都会导致脑袋尖尖"
-important = false
-can_edit = true
-
-[chat.message_buffer]
-describe = "启用消息缓冲器,启用此项以解决消息的拆分问题,但会使麦麦的回复延迟"
-important = false
-can_edit = true
-
-[chat.ban_words]
-describe = "需要过滤的消息列表"
-important = false
-can_edit = true
-
-[chat.ban_msgs_regex]
-describe = "需要过滤的消息(原始消息)匹配的正则表达式,匹配到的消息将被过滤(支持CQ码)"
-important = false
-can_edit = true
@@ -1,18 +1,10 @@
 [inner]
-version = "1.7.0"
+version = "2.0.0"

 #----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读----
 #如果你想要修改配置文件,请在修改后将version的值进行变更
-#如果新增项目,请在BotConfig类下新增相应的变量
-#1.如果你修改的是[]层级项目,例如你新增了 [memory],那么请在config.py的 load_config函数中的include_configs字典中新增"内容":{
-#"func":memory,
-#"support":">=0.0.0", #新的版本号
-#"necessary":False #是否必须
-#}
-#2.如果你修改的是[]下的项目,例如你新增了[memory]下的 memory_ban_words ,那么请在config.py的 load_config函数中的 memory函数下新增版本判断:
-# if config.INNER_VERSION in SpecifierSet(">=0.0.2"):
-# config.memory_ban_words = set(memory_config.get("memory_ban_words", []))
+#如果新增项目,请阅读src/config/official_configs.py中的说明
+#

 # 版本格式:主版本号.次版本号.修订号,版本号递增规则如下:
 # 主版本号:当你做了不兼容的 API 修改,
 # 次版本号:当你做了向下兼容的功能性新增,
@@ -21,11 +13,11 @@ version = "1.7.0"
 #----以上是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读----

 [bot]
-qq = 1145141919810
+qq_account = 1145141919810
 nickname = "麦麦"
 alias_names = ["麦叠", "牢麦"] #该选项还在调试中,暂时未生效

-[groups]
+[chat_target]
 talk_allowed = [
     123,
     123,
@@ -53,10 +45,13 @@ identity_detail = [
     "身份特点",
     "身份特点",
 ]# 条数任意,不能为0, 该选项还在调试中

 #外貌特征
-age = 20 # 年龄 单位岁
-gender = "男" # 性别
-appearance = "用几句话描述外貌特征" # 外貌特征 该选项还在调试中,暂时未生效
+age = 18 # 年龄 单位岁
+gender = "女" # 性别
+height = "170" # 身高(单位cm)
+weight = "50" # 体重(单位kg)
+appearance = "用一句或几句话描述外貌特征" # 外貌特征 该选项还在调试中,暂时未生效

 [platforms] # 必填项目,填写每个平台适配器提供的链接
 qq="http://127.0.0.1:18002/api/message"
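The code changes earlier in this commit read these keys through nested sections (global_config.bot.qq_account, global_config.identity.height, ...). The real classes live in src/config/official_configs.py; the dataclasses below are only a rough sketch of that nested layout, using the template values above as defaults:

# Illustrative sketch of the nested config layout; not the project's real classes.
from dataclasses import dataclass, field
from typing import List

@dataclass
class BotSection:
    qq_account: int = 1145141919810
    nickname: str = "麦麦"
    alias_names: List[str] = field(default_factory=lambda: ["麦叠", "牢麦"])

@dataclass
class IdentitySection:
    age: int = 18
    gender: str = "女"
    height: str = "170"
    weight: str = "50"
    appearance: str = "用一句或几句话描述外貌特征"

@dataclass
class GlobalConfigSketch:
    bot: BotSection = field(default_factory=BotSection)
    identity: IdentitySection = field(default_factory=IdentitySection)

global_config = GlobalConfigSketch()
print(global_config.bot.qq_account, global_config.identity.age)  # 1145141919810 18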
@@ -85,11 +80,10 @@ ban_msgs_regex = [

 [normal_chat] #普通聊天
 #一般回复参数
-model_reasoning_probability = 0.7 # 麦麦回答时选择推理模型 模型的概率
-model_normal_probability = 0.3 # 麦麦回答时选择一般模型 模型的概率
+reasoning_model_probability = 0.3 # 麦麦回答时选择推理模型的概率(与之相对的,普通模型的概率为1 - reasoning_model_probability)

 emoji_chance = 0.2 # 麦麦一般回复时使用表情包的概率,设置为1让麦麦自己决定发不发
-thinking_timeout = 100 # 麦麦最长思考时间,超过这个时间的思考会放弃(往往是api反应太慢)
+thinking_timeout = 120 # 麦麦最长思考时间,超过这个时间的思考会放弃(往往是api反应太慢)

 willing_mode = "classical" # 回复意愿模式 —— 经典模式:classical,mxp模式:mxp,自定义模式:custom(需要你自己实现)
 response_willing_amplifier = 1 # 麦麦回复意愿放大系数,一般为1
@@ -100,8 +94,8 @@ mentioned_bot_inevitable_reply = false # 提及 bot 必然回复
 at_bot_inevitable_reply = false # @bot 必然回复

 [focus_chat] #专注聊天
-reply_trigger_threshold = 3.6 # 专注聊天触发阈值,越低越容易进入专注聊天
-default_decay_rate_per_second = 0.95 # 默认衰减率,越大衰减越快,越高越难进入专注聊天
+reply_trigger_threshold = 3.0 # 专注聊天触发阈值,越低越容易进入专注聊天
+default_decay_rate_per_second = 0.98 # 默认衰减率,越大衰减越快,越高越难进入专注聊天
 consecutive_no_reply_threshold = 3 # 连续不回复的阈值,越低越容易结束专注聊天

 # 以下选项暂时无效
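The two old probabilities are collapsed into a single reasoning_model_probability; the normal model implicitly gets the remainder. The actual selection code is elsewhere in the chat pipeline; this sketch only illustrates the intended semantics of the knob:

# Demonstrates the single-knob selection semantics; not the project's real code.
import random

reasoning_model_probability = 0.3  # value from [normal_chat] above

def pick_model() -> str:
    return "reasoning" if random.random() < reasoning_model_probability else "normal"

counts = {"reasoning": 0, "normal": 0}
for _ in range(10_000):
    counts[pick_model()] += 1
print(counts)  # roughly 3000 reasoning / 7000 normal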
@@ -110,20 +104,20 @@ compress_length_limit = 5 #最多压缩份数,超过该数值的压缩上下


 [emoji]
-max_emoji_num = 40 # 表情包最大数量
-max_reach_deletion = true # 开启则在达到最大数量时删除表情包,关闭则达到最大数量时不删除,只是不会继续收集表情包
-check_interval = 10 # 检查表情包(注册,破损,删除)的时间间隔(分钟)
+max_reg_num = 40 # 表情包最大注册数量
+do_replace = true # 开启则在达到最大数量时删除(替换)表情包,关闭则达到最大数量时不会继续收集表情包
+check_interval = 120 # 检查表情包(注册,破损,删除)的时间间隔(分钟)
 save_pic = false # 是否保存图片
-save_emoji = false # 是否保存表情包
+cache_emoji = true # 是否缓存表情包
 steal_emoji = true # 是否偷取表情包,让麦麦可以发送她保存的这些表情包
-enable_check = false # 是否启用表情包过滤,只有符合该要求的表情包才会被保存
-check_prompt = "符合公序良俗" # 表情包过滤要求,只有符合该要求的表情包才会被保存
+content_filtration = false # 是否启用表情包过滤,只有符合该要求的表情包才会被保存
+filtration_prompt = "符合公序良俗" # 表情包过滤要求,只有符合该要求的表情包才会被保存

 [memory]
-build_memory_interval = 2000 # 记忆构建间隔 单位秒 间隔越低,麦麦学习越多,但是冗余信息也会增多
-build_memory_distribution = [6.0,3.0,0.6,32.0,12.0,0.4] # 记忆构建分布,参数:分布1均值,标准差,权重,分布2均值,标准差,权重
-build_memory_sample_num = 8 # 采样数量,数值越高记忆采样次数越多
-build_memory_sample_length = 40 # 采样长度,数值越高一段记忆内容越丰富
+memory_build_interval = 2000 # 记忆构建间隔 单位秒 间隔越低,麦麦学习越多,但是冗余信息也会增多
+memory_build_distribution = [6.0, 3.0, 0.6, 32.0, 12.0, 0.4] # 记忆构建分布,参数:分布1均值,标准差,权重,分布2均值,标准差,权重
+memory_build_sample_num = 8 # 采样数量,数值越高记忆采样次数越多
+memory_build_sample_length = 40 # 采样长度,数值越高一段记忆内容越丰富
 memory_compress_rate = 0.1 # 记忆压缩率 控制记忆精简程度 建议保持默认,调高可以获得更多信息,但是冗余信息也会增多

 forget_memory_interval = 1000 # 记忆遗忘间隔 单位秒 间隔越低,麦麦遗忘越频繁,记忆更精简,但更难学习
@@ -135,49 +129,45 @@ consolidation_similarity_threshold = 0.7 # 相似度阈值
 consolidation_check_percentage = 0.01 # 检查节点比例

 #不希望记忆的词,已经记忆的不会受到影响
-memory_ban_words = [
-    # "403","张三"
-]
+memory_ban_words = [ "表情包", "图片", "回复", "聊天记录" ]

 [mood]
 mood_update_interval = 1.0 # 情绪更新间隔 单位秒
 mood_decay_rate = 0.95 # 情绪衰减率
 mood_intensity_factor = 1.0 # 情绪强度因子

-[keywords_reaction] # 针对某个关键词作出反应
+[keyword_reaction] # 针对某个关键词作出反应
 enable = true # 关键词反应功能的总开关

-[[keywords_reaction.rules]] # 如果想要新增多个关键词,直接复制本条,修改keywords和reaction即可
+[[keyword_reaction.rules]] # 如果想要新增多个关键词,直接复制本条,修改keywords和reaction即可
 enable = true # 是否启用此条(为了人类在未来AI战争能更好地识别AI(bushi),默认开启)
 keywords = ["人机", "bot", "机器", "入机", "robot", "机器人","ai","AI"] # 会触发反应的关键词
 reaction = "有人说你是人机或者机器人,否定这一事实,攻击或反问他。注意,只有认为你是机器人才回复,否则不要否认" # 触发之后添加的提示词

-[[keywords_reaction.rules]] # 就像这样复制
+[[keyword_reaction.rules]] # 就像这样复制
 enable = false # 仅作示例,不会触发
 keywords = ["测试关键词回复","test",""]
 reaction = "回答“测试成功”" # 修复错误的引号

-[[keywords_reaction.rules]] # 使用正则表达式匹配句式
+[[keyword_reaction.rules]] # 使用正则表达式匹配句式
 enable = false # 仅作示例,不会触发
 regex = ["^(?P<n>\\S{1,20})是这样的$"] # 将匹配到的词汇命名为n,反应中对应的[n]会被替换为匹配到的内容,若不了解正则表达式请勿编写
 reaction = "请按照以下模板造句:[n]是这样的,xx只要xx就可以,可是[n]要考虑的事情就很多了,比如什么时候xx,什么时候xx,什么时候xx。(请自由发挥替换xx部分,只需保持句式结构,同时表达一种将[n]过度重视的反讽意味)"

 [chinese_typo]
 enable = true # 是否启用中文错别字生成器
-error_rate=0.001 # 单字替换概率
+error_rate=0.01 # 单字替换概率
 min_freq=9 # 最小字频阈值
 tone_error_rate=0.1 # 声调错误概率
 word_replace_rate=0.006 # 整词替换概率

 [response_splitter]
-enable_response_splitter = true # 是否启用回复分割器
-response_max_length = 256 # 回复允许的最大长度
-response_max_sentence_num = 4 # 回复允许的最大句子数
+enable = true # 是否启用回复分割器
+max_length = 256 # 回复允许的最大长度
+max_sentence_num = 4 # 回复允许的最大句子数
 enable_kaomoji_protection = false # 是否启用颜文字保护

-model_max_output_length = 256 # 模型单次返回的最大token数
+[telemetry] #发送统计信息,主要是看全球有多少只麦麦

-[remote] #发送统计信息,主要是看全球有多少只麦麦
 enable = true

 [experimental] #实验性功能
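For the regex-based [[keyword_reaction.rules]] entry, the named group n is captured from the message and substituted into the [n] placeholder of the reaction prompt. The real matching code is in the chat pipeline; this is a minimal demonstration of the configured pattern (with a shortened reaction string):

# Demonstrates the named-group capture and [n] substitution for the rule above.
import re

rule = {
    "regex": r"^(?P<n>\S{1,20})是这样的$",
    "reaction": "请按照以下模板造句:[n]是这样的……",  # shortened stand-in for the configured reaction
}

message = "麦麦是这样的"
m = re.match(rule["regex"], message)
if m:
    prompt = rule["reaction"].replace("[n]", m.group("n"))
    print(prompt)  # 请按照以下模板造句:麦麦是这样的……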
@@ -194,14 +184,17 @@ pfc_chatting = false # 是否启用PFC聊天,该功能仅作用于私聊,与
 # stream = <true|false> : 用于指定模型是否是使用流式输出
 # 如果不指定,则该项是 False

+[model]
+model_max_output_length = 800 # 模型单次返回的最大token数
+
 #这个模型必须是推理模型
-[model.llm_reasoning] # 一般聊天模式的推理回复模型
+[model.reasoning] # 一般聊天模式的推理回复模型
 name = "Pro/deepseek-ai/DeepSeek-R1"
 provider = "SILICONFLOW"
 pri_in = 1.0 #模型的输入价格(非必填,可以记录消耗)
 pri_out = 4.0 #模型的输出价格(非必填,可以记录消耗)

-[model.llm_normal] #V3 回复模型 专注和一般聊天模式共用的回复模型
+[model.normal] #V3 回复模型 专注和一般聊天模式共用的回复模型
 name = "Pro/deepseek-ai/DeepSeek-V3"
 provider = "SILICONFLOW"
 pri_in = 2 #模型的输入价格(非必填,可以记录消耗)
@@ -209,13 +202,13 @@ pri_out = 8 #模型的输出价格(非必填,可以记录消耗)
 #默认temp 0.2 如果你使用的是老V3或者其他模型,请自己修改temp参数
 temp = 0.2 #模型的温度,新V3建议0.1-0.3

-[model.llm_topic_judge] #主题判断模型:建议使用qwen2.5 7b
+[model.topic_judge] #主题判断模型:建议使用qwen2.5 7b
 name = "Pro/Qwen/Qwen2.5-7B-Instruct"
 provider = "SILICONFLOW"
 pri_in = 0.35
 pri_out = 0.35

-[model.llm_summary] #概括模型,建议使用qwen2.5 32b 及以上
+[model.summary] #概括模型,建议使用qwen2.5 32b 及以上
 name = "Qwen/Qwen2.5-32B-Instruct"
 provider = "SILICONFLOW"
 pri_in = 1.26
@@ -227,27 +220,27 @@ provider = "SILICONFLOW"
 pri_in = 0.35
 pri_out = 0.35

-[model.llm_heartflow] # 用于控制麦麦是否参与聊天的模型
+[model.heartflow] # 用于控制麦麦是否参与聊天的模型
 name = "Qwen/Qwen2.5-32B-Instruct"
 provider = "SILICONFLOW"
 pri_in = 1.26
 pri_out = 1.26

-[model.llm_observation] #观察模型,压缩聊天内容,建议用免费的
+[model.observation] #观察模型,压缩聊天内容,建议用免费的
 # name = "Pro/Qwen/Qwen2.5-7B-Instruct"
 name = "Qwen/Qwen2.5-7B-Instruct"
 provider = "SILICONFLOW"
 pri_in = 0
 pri_out = 0

-[model.llm_sub_heartflow] #心流:认真水群时,生成麦麦的内心想法,必须使用具有工具调用能力的模型
+[model.sub_heartflow] #心流:认真水群时,生成麦麦的内心想法,必须使用具有工具调用能力的模型
 name = "Pro/deepseek-ai/DeepSeek-V3"
 provider = "SILICONFLOW"
 pri_in = 2
 pri_out = 8
 temp = 0.3 #模型的温度,新V3建议0.1-0.3

-[model.llm_plan] #决策:认真水群时,负责决定麦麦该做什么
+[model.plan] #决策:认真水群时,负责决定麦麦该做什么
 name = "Pro/deepseek-ai/DeepSeek-V3"
 provider = "SILICONFLOW"
 pri_in = 2
@@ -265,7 +258,7 @@ pri_out = 0
 #私聊PFC:需要开启PFC功能,默认三个模型均为硅基流动v3,如果需要支持多人同时私聊或频繁调用,建议把其中的一个或两个换成官方v3或其它模型,以免撞到429

 #PFC决策模型
-[model.llm_PFC_action_planner]
+[model.pfc_action_planner]
 name = "Pro/deepseek-ai/DeepSeek-V3"
 provider = "SILICONFLOW"
 temp = 0.3
@@ -273,7 +266,7 @@ pri_in = 2
 pri_out = 8

 #PFC聊天模型
-[model.llm_PFC_chat]
+[model.pfc_chat]
 name = "Pro/deepseek-ai/DeepSeek-V3"
 provider = "SILICONFLOW"
 temp = 0.3
@@ -281,7 +274,7 @@ pri_in = 2
 pri_out = 8

 #PFC检查模型
-[model.llm_PFC_reply_checker]
+[model.pfc_reply_checker]
 name = "Pro/deepseek-ai/DeepSeek-V3"
 provider = "SILICONFLOW"
 pri_in = 2
@@ -294,7 +287,7 @@ pri_out = 8
 #以下模型暂时没有使用!!
 #以下模型暂时没有使用!!

-[model.llm_tool_use] #工具调用模型,需要使用支持工具调用的模型,建议使用qwen2.5 32b
+[model.tool_use] #工具调用模型,需要使用支持工具调用的模型,建议使用qwen2.5 32b
 name = "Qwen/Qwen2.5-32B-Instruct"
 provider = "SILICONFLOW"
 pri_in = 1.26
tests/test_config.py (new file, 7 lines)
@@ -0,0 +1,7 @@
+from src.config.config import global_config
+
+
+class TestConfig:
+    def test_load(self):
+        config = global_config
+        print(config)
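The new test only loads and prints the global config. A slightly stricter, assertion-based variant could look like the sketch below; the nested attribute names are assumed from the template above and may not match the real Config class exactly:

# Hypothetical extension of tests/test_config.py; attribute names such as
# normal_chat.emoji_chance are assumptions drawn from the bot_config template.
from src.config.config import global_config


class TestConfigNested:
    def test_nested_sections_exist(self):
        assert global_config.bot.nickname  # e.g. "麦麦"
        assert 0.0 <= global_config.normal_chat.emoji_chance <= 1.0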