Merge branch 'new-storage' into plugin

This commit is contained in:
SengokuCola
2025-05-16 21:14:16 +08:00
63 changed files with 2397 additions and 2008 deletions

View File

@@ -5,12 +5,15 @@ import os
import random
import time
import traceback
from typing import Optional, Tuple
from typing import Optional, Tuple, List, Any
from PIL import Image
import io
import re
from ...common.database import db
# from gradio_client import file
from ...common.database.database_model import Emoji
from ...common.database.database import db as peewee_db
from ...config.config import global_config
from ..utils.utils_image import image_path_to_base64, image_manager
from ..models.utils_model import LLMRequest
@@ -51,7 +54,7 @@ class MaiEmoji:
self.is_deleted = False # 标记是否已被删除
self.format = ""
async def initialize_hash_format(self):
async def initialize_hash_format(self) -> Optional[bool]:
"""从文件创建表情包实例, 计算哈希值和格式"""
try:
# 使用 full_path 检查文件是否存在
@@ -104,7 +107,7 @@ class MaiEmoji:
self.is_deleted = True
return None
async def register_to_db(self):
async def register_to_db(self) -> bool:
"""
注册表情包
将表情包对应的文件从当前路径移动到EMOJI_REGISTED_DIR目录下
@@ -143,22 +146,22 @@ class MaiEmoji:
# --- 数据库操作 ---
try:
# 准备数据库记录 for emoji collection
emoji_record = {
"filename": self.filename,
"path": self.path, # 存储目录路径
"full_path": self.full_path, # 存储完整文件路径
"embedding": self.embedding,
"description": self.description,
"emotion": self.emotion,
"hash": self.hash,
"format": self.format,
"timestamp": int(self.register_time),
"usage_count": self.usage_count,
"last_used_time": self.last_used_time,
}
emotion_str = ",".join(self.emotion) if self.emotion else ""
# 使用upsert确保记录存在或被更新
db["emoji"].update_one({"hash": self.hash}, {"$set": emoji_record}, upsert=True)
Emoji.create(
hash=self.hash,
full_path=self.full_path,
format=self.format,
description=self.description,
emotion=emotion_str, # Store as comma-separated string
query_count=0, # Default value
is_registered=True,
is_banned=False, # Default value
record_time=self.register_time, # Use MaiEmoji's register_time for DB record_time
register_time=self.register_time,
usage_count=self.usage_count,
last_used_time=self.last_used_time,
)
logger.success(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})")
@@ -166,14 +169,6 @@ class MaiEmoji:
except Exception as db_error:
logger.error(f"[错误] 保存数据库失败 ({self.filename}): {str(db_error)}")
# 数据库保存失败,是否需要将文件移回?为了简化,暂时只记录错误
# 可以考虑在这里尝试删除已移动的文件,避免残留
try:
if os.path.exists(self.full_path): # full_path 此时是目标路径
os.remove(self.full_path)
logger.warning(f"[回滚] 已删除移动失败后残留的文件: {self.full_path}")
except Exception as remove_error:
logger.error(f"[错误] 回滚删除文件失败: {remove_error}")
return False
except Exception as e:
@@ -181,7 +176,7 @@ class MaiEmoji:
logger.error(traceback.format_exc())
return False
async def delete(self):
async def delete(self) -> bool:
"""删除表情包
删除表情包的文件和数据库记录
@@ -201,10 +196,14 @@ class MaiEmoji:
# 文件删除失败,但仍然尝试删除数据库记录
# 2. 删除数据库记录
result = db.emoji.delete_one({"hash": self.hash})
deleted_in_db = result.deleted_count > 0
try:
will_delete_emoji = Emoji.get(Emoji.emoji_hash == self.hash)
result = will_delete_emoji.delete_instance() # Returns the number of rows deleted.
except Emoji.DoesNotExist:
logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。")
result = 0 # Indicate no DB record was deleted
if deleted_in_db:
if result > 0:
logger.info(f"[删除] 表情包数据库记录 {self.filename} (Hash: {self.hash})")
# 3. 标记对象已被删除
self.is_deleted = True
@@ -224,7 +223,7 @@ class MaiEmoji:
return False
def _emoji_objects_to_readable_list(emoji_objects):
def _emoji_objects_to_readable_list(emoji_objects: List["MaiEmoji"]) -> List[str]:
"""将表情包对象列表转换为可读的字符串列表
参数:
@@ -243,47 +242,48 @@ def _emoji_objects_to_readable_list(emoji_objects):
return emoji_info_list
def _to_emoji_objects(data):
def _to_emoji_objects(data: Any) -> Tuple[List["MaiEmoji"], int]:
emoji_objects = []
load_errors = 0
# data is now an iterable of Peewee Emoji model instances
emoji_data_list = list(data)
for emoji_data in emoji_data_list:
full_path = emoji_data.get("full_path")
for emoji_data in emoji_data_list: # emoji_data is an Emoji model instance
full_path = emoji_data.full_path
if not full_path:
logger.warning(f"[加载错误] 数据库记录缺少 'full_path' 字段: {emoji_data.get('_id')}")
logger.warning(
f"[加载错误] 数据库记录缺少 'full_path' 字段: ID {emoji_data.id if hasattr(emoji_data, 'id') else 'Unknown'}"
)
load_errors += 1
continue # 跳过缺少 full_path 的记录
continue
try:
# 使用 full_path 初始化 MaiEmoji 对象
emoji = MaiEmoji(full_path=full_path)
# 设置从数据库加载的属性
emoji.hash = emoji_data.get("hash", "")
# 如果 hash 为空,也跳过?取决于业务逻辑
emoji.hash = emoji_data.emoji_hash
if not emoji.hash:
logger.warning(f"[加载错误] 数据库记录缺少 'hash' 字段: {full_path}")
load_errors += 1
continue
emoji.description = emoji_data.get("description", "")
emoji.emotion = emoji_data.get("emotion", [])
emoji.usage_count = emoji_data.get("usage_count", 0)
# 优先使用 last_used_time否则用 timestamp最后用当前时间
last_used = emoji_data.get("last_used_time")
timestamp = emoji_data.get("timestamp")
emoji.last_used_time = (
last_used if last_used is not None else (timestamp if timestamp is not None else time.time())
)
emoji.register_time = timestamp if timestamp is not None else time.time()
emoji.format = emoji_data.get("format", "") # 加载格式
emoji.description = emoji_data.description
# Deserialize emotion string from DB to list
emoji.emotion = emoji_data.emotion.split(",") if emoji_data.emotion else []
emoji.usage_count = emoji_data.usage_count
# 不需要再手动设置 path 和 filename__init__ 会自动处理
db_last_used_time = emoji_data.last_used_time
db_register_time = emoji_data.register_time
# If last_used_time from DB is None, use MaiEmoji's initialized register_time or current time
emoji.last_used_time = db_last_used_time if db_last_used_time is not None else emoji.register_time
# If register_time from DB is None, use MaiEmoji's initialized register_time (which is time.time())
emoji.register_time = db_register_time if db_register_time is not None else emoji.register_time
emoji.format = emoji_data.format
emoji_objects.append(emoji)
except ValueError as ve: # 捕获 __init__ 可能的错误
except ValueError as ve:
logger.error(f"[加载错误] 初始化 MaiEmoji 失败 ({full_path}): {ve}")
load_errors += 1
except Exception as e:
@@ -292,13 +292,13 @@ def _to_emoji_objects(data):
return emoji_objects, load_errors
def _ensure_emoji_dir():
def _ensure_emoji_dir() -> None:
"""确保表情存储目录存在"""
os.makedirs(EMOJI_DIR, exist_ok=True)
os.makedirs(EMOJI_REGISTED_DIR, exist_ok=True)
async def clear_temp_emoji():
async def clear_temp_emoji() -> None:
"""清理临时表情包
清理/data/emoji和/data/image目录下的所有文件
当目录中文件数超过100时会全部删除
@@ -320,7 +320,7 @@ async def clear_temp_emoji():
logger.success("[清理] 完成")
async def clean_unused_emojis(emoji_dir, emoji_objects):
async def clean_unused_emojis(emoji_dir: str, emoji_objects: List["MaiEmoji"]) -> None:
"""清理指定目录中未被 emoji_objects 追踪的表情包文件"""
if not os.path.exists(emoji_dir):
logger.warning(f"[清理] 目标目录不存在,跳过清理: {emoji_dir}")
@@ -360,74 +360,52 @@ async def clean_unused_emojis(emoji_dir, emoji_objects):
class EmojiManager:
_instance = None
def __new__(cls):
def __new__(cls) -> "EmojiManager":
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self):
def __init__(self) -> None:
self._initialized = None
self._scan_task = None
self.vlm = LLMRequest(model=global_config.vlm, temperature=0.3, max_tokens=1000, request_type="emoji")
self.vlm = LLMRequest(model=global_config.model.vlm, temperature=0.3, max_tokens=1000, request_type="emoji")
self.llm_emotion_judge = LLMRequest(
model=global_config.llm_normal, max_tokens=600, request_type="emoji"
model=global_config.model.normal, max_tokens=600, request_type="emoji"
) # 更高的温度更少的token后续可以根据情绪来调整温度
self.emoji_num = 0
self.emoji_num_max = global_config.max_emoji_num
self.emoji_num_max_reach_deletion = global_config.max_reach_deletion
self.emoji_num_max = global_config.emoji.max_reg_num
self.emoji_num_max_reach_deletion = global_config.emoji.do_replace
self.emoji_objects: list[MaiEmoji] = [] # 存储MaiEmoji对象的列表使用类型注解明确列表元素类型
logger.info("启动表情包管理器")
def initialize(self):
def initialize(self) -> None:
"""初始化数据库连接和表情目录"""
if not self._initialized:
try:
self._ensure_emoji_collection()
_ensure_emoji_dir()
self._initialized = True
# 更新表情包数量
# 启动时执行一次完整性检查
# await self.check_emoji_file_integrity()
except Exception as e:
logger.exception(f"初始化表情管理器失败: {e}")
peewee_db.connect(reuse_if_open=True)
if peewee_db.is_closed():
raise RuntimeError("数据库连接失败")
_ensure_emoji_dir()
Emoji.create_table(safe=True) # Ensures table exists
def _ensure_db(self):
def _ensure_db(self) -> None:
"""确保数据库已初始化"""
if not self._initialized:
self.initialize()
if not self._initialized:
raise RuntimeError("EmojiManager not initialized")
@staticmethod
def _ensure_emoji_collection():
"""确保emoji集合存在并创建索引
这个函数用于确保MongoDB数据库中存在emoji集合,并创建必要的索引。
索引的作用是加快数据库查询速度:
- embedding字段的2dsphere索引: 用于加速向量相似度搜索,帮助快速找到相似的表情包
- tags字段的普通索引: 加快按标签搜索表情包的速度
- filename字段的唯一索引: 确保文件名不重复,同时加快按文件名查找的速度
没有索引的话,数据库每次查询都需要扫描全部数据,建立索引后可以大大提高查询效率。
"""
if "emoji" not in db.list_collection_names():
db.create_collection("emoji")
db.emoji.create_index([("embedding", "2dsphere")])
db.emoji.create_index([("filename", 1)], unique=True)
def record_usage(self, emoji_hash: str):
def record_usage(self, emoji_hash: str) -> None:
"""记录表情使用次数"""
try:
db.emoji.update_one({"hash": emoji_hash}, {"$inc": {"usage_count": 1}})
for emoji in self.emoji_objects:
if emoji.hash == emoji_hash:
emoji.usage_count += 1
break
emoji_update = Emoji.get(Emoji.emoji_hash == emoji_hash)
emoji_update.usage_count += 1
emoji_update.last_used_time = time.time() # Update last used time
emoji_update.save() # Persist changes to DB
except Emoji.DoesNotExist:
logger.error(f"记录表情使用失败: 未找到 hash 为 {emoji_hash} 的表情包")
except Exception as e:
logger.error(f"记录表情使用失败: {str(e)}")
@@ -447,7 +425,6 @@ class EmojiManager:
if not all_emojis:
logger.warning("内存中没有任何表情包对象")
# 可以考虑再查一次数据库?或者依赖定期任务更新
return None
# 计算每个表情包与输入文本的最大情感相似度
@@ -463,40 +440,38 @@ class EmojiManager:
# 计算与每个emotion标签的相似度取最大值
max_similarity = 0
best_matching_emotion = "" # 记录最匹配的 emotion 喵~
best_matching_emotion = ""
for emotion in emotions:
# 使用编辑距离计算相似度
distance = self._levenshtein_distance(text_emotion, emotion)
max_len = max(len(text_emotion), len(emotion))
similarity = 1 - (distance / max_len if max_len > 0 else 0)
if similarity > max_similarity: # 如果找到更相似的喵~
if similarity > max_similarity:
max_similarity = similarity
best_matching_emotion = emotion # 就记下这个 emotion 喵~
best_matching_emotion = emotion
if best_matching_emotion: # 确保有匹配的情感才添加喵~
emoji_similarities.append((emoji, max_similarity, best_matching_emotion)) # 把 emotion 也存起来喵~
if best_matching_emotion:
emoji_similarities.append((emoji, max_similarity, best_matching_emotion))
# 按相似度降序排序
emoji_similarities.sort(key=lambda x: x[1], reverse=True)
# 获取前10个最相似的表情包
top_emojis = (
emoji_similarities[:10] if len(emoji_similarities) > 10 else emoji_similarities
) # 改个名字,更清晰喵~
top_emojis = emoji_similarities[:10] if len(emoji_similarities) > 10 else emoji_similarities
if not top_emojis:
logger.warning("未找到匹配的表情包")
return None
# 从前几个中随机选择一个
selected_emoji, similarity, matched_emotion = random.choice(top_emojis) # 把匹配的 emotion 也拿出来喵~
selected_emoji, similarity, matched_emotion = random.choice(top_emojis)
# 更新使用次数
self.record_usage(selected_emoji.hash)
self.record_usage(selected_emoji.emoji_hash)
_time_end = time.time()
logger.info( # 使用匹配到的 emotion 记录日志喵~
logger.info(
f"为[{text_emotion}]找到表情包: {matched_emotion} ({selected_emoji.filename}), Similarity: {similarity:.4f}"
)
# 返回完整文件路径和描述
@@ -534,7 +509,7 @@ class EmojiManager:
return previous_row[-1]
async def check_emoji_file_integrity(self):
async def check_emoji_file_integrity(self) -> None:
"""检查表情包文件完整性
遍历self.emoji_objects中的所有对象检查文件是否存在
如果文件已被删除,则执行对象的删除方法并从列表中移除
@@ -599,7 +574,7 @@ class EmojiManager:
logger.error(f"[错误] 检查表情包完整性失败: {str(e)}")
logger.error(traceback.format_exc())
async def start_periodic_check_register(self):
async def start_periodic_check_register(self) -> None:
"""定期检查表情包完整性和数量"""
await self.get_all_emoji_from_db()
while True:
@@ -613,18 +588,18 @@ class EmojiManager:
logger.warning(f"[警告] 表情包目录不存在: {EMOJI_DIR}")
os.makedirs(EMOJI_DIR, exist_ok=True)
logger.info(f"[创建] 已创建表情包目录: {EMOJI_DIR}")
await asyncio.sleep(global_config.EMOJI_CHECK_INTERVAL * 60)
await asyncio.sleep(global_config.emoji.check_interval * 60)
continue
# 检查目录是否为空
files = os.listdir(EMOJI_DIR)
if not files:
logger.warning(f"[警告] 表情包目录为空: {EMOJI_DIR}")
await asyncio.sleep(global_config.EMOJI_CHECK_INTERVAL * 60)
await asyncio.sleep(global_config.emoji.check_interval * 60)
continue
# 检查是否需要处理表情包(数量超过最大值或不足)
if (self.emoji_num > self.emoji_num_max and global_config.max_reach_deletion) or (
if (self.emoji_num > self.emoji_num_max and global_config.emoji.do_replace) or (
self.emoji_num < self.emoji_num_max
):
try:
@@ -651,15 +626,16 @@ class EmojiManager:
except Exception as e:
logger.error(f"[错误] 扫描表情包目录失败: {str(e)}")
await asyncio.sleep(global_config.EMOJI_CHECK_INTERVAL * 60)
await asyncio.sleep(global_config.emoji.check_interval * 60)
async def get_all_emoji_from_db(self):
async def get_all_emoji_from_db(self) -> None:
"""获取所有表情包并初始化为MaiEmoji类对象更新 self.emoji_objects"""
try:
self._ensure_db()
logger.info("[数据库] 开始加载所有表情包记录...")
logger.info("[数据库] 开始加载所有表情包记录 (Peewee)...")
emoji_objects, load_errors = _to_emoji_objects(db.emoji.find())
emoji_peewee_instances = Emoji.select()
emoji_objects, load_errors = _to_emoji_objects(emoji_peewee_instances)
# 更新内存中的列表和数量
self.emoji_objects = emoji_objects
@@ -674,7 +650,7 @@ class EmojiManager:
self.emoji_objects = [] # 加载失败则清空列表
self.emoji_num = 0
async def get_emoji_from_db(self, emoji_hash=None):
async def get_emoji_from_db(self, emoji_hash: Optional[str] = None) -> List["MaiEmoji"]:
"""获取指定哈希值的表情包并初始化为MaiEmoji类对象列表 (主要用于调试或特定查找)
参数:
@@ -686,15 +662,16 @@ class EmojiManager:
try:
self._ensure_db()
query = {}
if emoji_hash:
query = {"hash": emoji_hash}
query = Emoji.select().where(Emoji.emoji_hash == emoji_hash)
else:
logger.warning(
"[查询] 未提供 hash将尝试加载所有表情包建议使用 get_all_emoji_from_db 更新管理器状态。"
)
query = Emoji.select()
emoji_objects, load_errors = _to_emoji_objects(db.emoji.find(query))
emoji_peewee_instances = query
emoji_objects, load_errors = _to_emoji_objects(emoji_peewee_instances)
if load_errors > 0:
logger.warning(f"[查询] 加载过程中出现 {load_errors} 个错误。")
@@ -705,7 +682,7 @@ class EmojiManager:
logger.error(f"[错误] 从数据库获取表情包对象失败: {str(e)}")
return []
async def get_emoji_from_manager(self, emoji_hash) -> Optional[MaiEmoji]:
async def get_emoji_from_manager(self, emoji_hash: str) -> Optional["MaiEmoji"]:
"""从内存中的 emoji_objects 列表获取表情包
参数:
@@ -758,7 +735,7 @@ class EmojiManager:
logger.error(traceback.format_exc())
return False
async def replace_a_emoji(self, new_emoji: MaiEmoji):
async def replace_a_emoji(self, new_emoji: "MaiEmoji") -> bool:
"""替换一个表情包
Args:
@@ -788,7 +765,7 @@ class EmojiManager:
# 构建提示词
prompt = (
f"{global_config.BOT_NICKNAME}的表情包存储已满({self.emoji_num}/{self.emoji_num_max})"
f"{global_config.bot.nickname}的表情包存储已满({self.emoji_num}/{self.emoji_num_max})"
f"需要决定是否删除一个旧表情包来为新表情包腾出空间。\n\n"
f"新表情包信息:\n"
f"描述: {new_emoji.description}\n\n"
@@ -819,7 +796,7 @@ class EmojiManager:
# 删除选定的表情包
logger.info(f"[决策] 删除表情包: {emoji_to_delete.description}")
delete_success = await self.delete_emoji(emoji_to_delete.hash)
delete_success = await self.delete_emoji(emoji_to_delete.emoji_hash)
if delete_success:
# 修复:等待异步注册完成
@@ -847,7 +824,7 @@ class EmojiManager:
logger.error(traceback.format_exc())
return False
async def build_emoji_description(self, image_base64: str) -> Tuple[str, list]:
async def build_emoji_description(self, image_base64: str) -> Tuple[str, List[str]]:
"""获取表情包描述和情感列表
Args:
@@ -871,10 +848,10 @@ class EmojiManager:
description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format)
# 审核表情包
if global_config.EMOJI_CHECK:
if global_config.emoji.content_filtration:
prompt = f'''
这是一个表情包,请对这个表情包进行审核,标准如下:
1. 必须符合"{global_config.EMOJI_CHECK_PROMPT}"的要求
1. 必须符合"{global_config.emoji.filtration_prompt}"的要求
2. 不能是色情、暴力、等违法违规内容,必须符合公序良俗
3. 不能是任何形式的截图,聊天记录或视频截图
4. 不要出现5个以上文字

View File

@@ -76,9 +76,10 @@ def init_prompt():
class DefaultExpressor:
def __init__(self, chat_id: str):
self.log_prefix = "expressor"
# TODO: API-Adapter修改标记
self.express_model = LLMRequest(
model=global_config.llm_normal,
temperature=global_config.llm_normal["temp"],
model=global_config.model.normal,
temperature=global_config.model.normal["temp"],
max_tokens=256,
request_type="response_heartflow",
)
@@ -102,8 +103,8 @@ class DefaultExpressor:
messageinfo = anchor_message.message_info
thinking_time_point = parse_thinking_id_to_timestamp(thinking_id)
bot_user_info = UserInfo(
user_id=global_config.BOT_QQ,
user_nickname=global_config.BOT_NICKNAME,
user_id=global_config.bot.qq_account,
user_nickname=global_config.bot.nickname,
platform=messageinfo.platform,
)
# logger.debug(f"创建思考消息:{anchor_message}")
@@ -192,7 +193,7 @@ class DefaultExpressor:
try:
# 1. 获取情绪影响因子并调整模型温度
arousal_multiplier = mood_manager.get_arousal_multiplier()
current_temp = float(global_config.llm_normal["temp"]) * arousal_multiplier
current_temp = float(global_config.model.normal["temp"]) * arousal_multiplier
self.express_model.params["temperature"] = current_temp # 动态调整温度
# 2. 获取信息捕捉器
@@ -231,6 +232,7 @@ class DefaultExpressor:
try:
with Timer("LLM生成", {}): # 内部计时器,可选保留
# TODO: API-Adapter修改标记
# logger.info(f"{self.log_prefix}[Replier-{thinking_id}]\nPrompt:\n{prompt}\n")
content, reasoning_content, model_name = await self.express_model.generate_response(prompt)
@@ -482,8 +484,8 @@ class DefaultExpressor:
"""构建单个发送消息"""
bot_user_info = UserInfo(
user_id=global_config.BOT_QQ,
user_nickname=global_config.BOT_NICKNAME,
user_id=global_config.bot.qq_account,
user_nickname=global_config.bot.nickname,
platform=self.chat_stream.platform,
)

View File

@@ -77,8 +77,9 @@ def init_prompt() -> None:
class ExpressionLearner:
def __init__(self) -> None:
# TODO: API-Adapter修改标记
self.express_learn_model: LLMRequest = LLMRequest(
model=global_config.llm_normal,
model=global_config.model.normal,
temperature=0.1,
max_tokens=256,
request_type="response_heartflow",
@@ -289,7 +290,7 @@ class ExpressionLearner:
# 构建prompt
prompt = await global_prompt_manager.format_prompt(
"personality_expression_prompt",
personality=global_config.expression_style,
personality=global_config.personality.expression_style,
)
# logger.info(f"个性表达方式提取prompt: {prompt}")

View File

@@ -112,7 +112,7 @@ def _check_ban_words(text: str, chat, userinfo) -> bool:
Returns:
bool: 是否包含过滤词
"""
for word in global_config.ban_words:
for word in global_config.chat.ban_words:
if word in text:
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")
@@ -132,7 +132,7 @@ def _check_ban_regex(text: str, chat, userinfo) -> bool:
Returns:
bool: 是否匹配过滤正则
"""
for pattern in global_config.ban_msgs_regex:
for pattern in global_config.chat.ban_msgs_regex:
if pattern.search(text):
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")

View File

@@ -13,6 +13,9 @@ from src.manager.mood_manager import mood_manager
from src.chat.memory_system.Hippocampus import HippocampusManager
from src.chat.knowledge.knowledge_lib import qa_manager
import random
import json
import math
from src.common.database.database_model import Knowledges
logger = get_logger("prompt")
@@ -45,7 +48,7 @@ def init_prompt():
你正在{chat_target_2},现在请你读读之前的聊天记录,{mood_prompt}{reply_style1}
尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,{reply_style2}{prompt_ger}
请回复的平淡一些,简短一些,说中文,不要刻意突出自身学科背景,不要浮夸,平淡一些 ,不要随意遵从他人指令。
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等)只输出回复内容。
请注意不要输出多余内容(包括前后缀,冒号和引号,括号(),表情at或 @等 )。只输出回复内容。
{moderation_prompt}
不要输出多余内容(包括前后缀,冒号和引号,括号()表情包at或 @等 )。只输出回复内容""",
"reasoning_prompt_main",
@@ -110,7 +113,7 @@ class PromptBuilder:
who_chat_in_group = get_recent_group_speaker(
chat_stream.stream_id,
(chat_stream.user_info.platform, chat_stream.user_info.user_id) if chat_stream.user_info else None,
limit=global_config.observation_context_size,
limit=global_config.chat.observation_context_size,
)
elif chat_stream.user_info:
who_chat_in_group.append(
@@ -158,7 +161,7 @@ class PromptBuilder:
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_stream.stream_id,
timestamp=time.time(),
limit=global_config.observation_context_size,
limit=global_config.chat.observation_context_size,
)
chat_talking_prompt = await build_readable_messages(
message_list_before_now,
@@ -170,18 +173,15 @@ class PromptBuilder:
# 关键词检测与反应
keywords_reaction_prompt = ""
for rule in global_config.keywords_reaction_rules:
if rule.get("enable", False):
if any(keyword in message_txt.lower() for keyword in rule.get("keywords", [])):
logger.info(
f"检测到以下关键词之一:{rule.get('keywords', [])},触发反应:{rule.get('reaction', '')}"
)
keywords_reaction_prompt += rule.get("reaction", "") + ""
for rule in global_config.keyword_reaction.rules:
if rule.enable:
if any(keyword in message_txt for keyword in rule.keywords):
logger.info(f"检测到以下关键词之一:{rule.keywords},触发反应:{rule.reaction}")
keywords_reaction_prompt += f"{rule.reaction}"
else:
for pattern in rule.get("regex", []):
result = pattern.search(message_txt)
if result:
reaction = rule.get("reaction", "")
for pattern in rule.regex:
if result := pattern.search(message_txt):
reaction = rule.reaction
for name, content in result.groupdict().items():
reaction = reaction.replace(f"[{name}]", content)
logger.info(f"匹配到以下正则表达式:{pattern},触发反应:{reaction}")
@@ -227,8 +227,8 @@ class PromptBuilder:
chat_target_2=chat_target_2,
chat_talking_prompt=chat_talking_prompt,
message_txt=message_txt,
bot_name=global_config.BOT_NICKNAME,
bot_other_names="/".join(global_config.BOT_ALIAS_NAMES),
bot_name=global_config.bot.nickname,
bot_other_names="/".join(global_config.bot.alias_names),
prompt_personality=prompt_personality,
mood_prompt=mood_prompt,
reply_style1=reply_style1_chosen,
@@ -249,8 +249,8 @@ class PromptBuilder:
prompt_info=prompt_info,
chat_talking_prompt=chat_talking_prompt,
message_txt=message_txt,
bot_name=global_config.BOT_NICKNAME,
bot_other_names="/".join(global_config.BOT_ALIAS_NAMES),
bot_name=global_config.bot.nickname,
bot_other_names="/".join(global_config.bot.alias_names),
prompt_personality=prompt_personality,
mood_prompt=mood_prompt,
reply_style1=reply_style1_chosen,
@@ -269,30 +269,6 @@ class PromptBuilder:
logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
# 1. 先从LLM获取主题类似于记忆系统的做法
topics = []
# try:
# # 先尝试使用记忆系统的方法获取主题
# hippocampus = HippocampusManager.get_instance()._hippocampus
# topic_num = min(5, max(1, int(len(message) * 0.1)))
# topics_response = await hippocampus.llm_topic_judge.generate_response(hippocampus.find_topic_llm(message, topic_num))
# # 提取关键词
# topics = re.findall(r"<([^>]+)>", topics_response[0])
# if not topics:
# topics = []
# else:
# topics = [
# topic.strip()
# for topic in ",".join(topics).replace("", ",").replace("、", ",").replace(" ", ",").split(",")
# if topic.strip()
# ]
# logger.info(f"从LLM提取的主题: {', '.join(topics)}")
# except Exception as e:
# logger.error(f"从LLM提取主题失败: {str(e)}")
# # 如果LLM提取失败使用jieba分词提取关键词作为备选
# words = jieba.cut(message)
# topics = [word for word in words if len(word) > 1][:5]
# logger.info(f"使用jieba提取的主题: {', '.join(topics)}")
# 如果无法提取到主题,直接使用整个消息
if not topics:
@@ -402,8 +378,6 @@ class PromptBuilder:
for _i, result in enumerate(results, 1):
_similarity = result["similarity"]
content = result["content"].strip()
# 调试:为内容添加序号和相似度信息
# related_info += f"{i}. [{similarity:.2f}] {content}\n"
related_info += f"{content}\n"
related_info += "\n"
@@ -432,14 +406,14 @@ class PromptBuilder:
return related_info
else:
logger.debug("从LPMM知识库获取知识失败使用旧版数据库进行检索")
knowledge_from_old = await self.get_prompt_info_old(message, threshold=0.38)
knowledge_from_old = await self.get_prompt_info_old(message, threshold=threshold)
related_info += knowledge_from_old
logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}")
return related_info
except Exception as e:
logger.error(f"获取知识库内容时发生异常: {str(e)}")
try:
knowledge_from_old = await self.get_prompt_info_old(message, threshold=0.38)
knowledge_from_old = await self.get_prompt_info_old(message, threshold=threshold)
related_info += knowledge_from_old
logger.debug(
f"异常后使用旧版数据库获取知识,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}"
@@ -455,70 +429,70 @@ class PromptBuilder:
) -> Union[str, list]:
if not query_embedding:
return "" if not return_raw else []
# 使用余弦相似度计算
pipeline = [
{
"$addFields": {
"dotProduct": {
"$reduce": {
"input": {"$range": [0, {"$size": "$embedding"}]},
"initialValue": 0,
"in": {
"$add": [
"$$value",
{
"$multiply": [
{"$arrayElemAt": ["$embedding", "$$this"]},
{"$arrayElemAt": [query_embedding, "$$this"]},
]
},
]
},
}
},
"magnitude1": {
"$sqrt": {
"$reduce": {
"input": "$embedding",
"initialValue": 0,
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
}
}
},
"magnitude2": {
"$sqrt": {
"$reduce": {
"input": query_embedding,
"initialValue": 0,
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
}
}
},
}
},
{"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}},
{
"$match": {
"similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果
}
},
{"$sort": {"similarity": -1}},
{"$limit": limit},
{"$project": {"content": 1, "similarity": 1}},
]
results = list(db.knowledges.aggregate(pipeline))
logger.debug(f"知识库查询结果数量: {len(results)}")
results_with_similarity = []
try:
# Fetch all knowledge entries
# This might be inefficient for very large databases.
# Consider strategies like FAISS or other vector search libraries if performance becomes an issue.
all_knowledges = Knowledges.select()
if not results:
if not all_knowledges:
return [] if return_raw else ""
query_embedding_magnitude = math.sqrt(sum(x * x for x in query_embedding))
if query_embedding_magnitude == 0: # Avoid division by zero
return "" if not return_raw else []
for knowledge_item in all_knowledges:
try:
db_embedding_str = knowledge_item.embedding
db_embedding = json.loads(db_embedding_str)
if len(db_embedding) != len(query_embedding):
logger.warning(
f"Embedding length mismatch for knowledge ID {knowledge_item.id if hasattr(knowledge_item, 'id') else 'N/A'}. Skipping."
)
continue
# Calculate Cosine Similarity
dot_product = sum(q * d for q, d in zip(query_embedding, db_embedding))
db_embedding_magnitude = math.sqrt(sum(x * x for x in db_embedding))
if db_embedding_magnitude == 0: # Avoid division by zero
similarity = 0.0
else:
similarity = dot_product / (query_embedding_magnitude * db_embedding_magnitude)
if similarity >= threshold:
results_with_similarity.append({"content": knowledge_item.content, "similarity": similarity})
except json.JSONDecodeError:
logger.error(
f"Failed to parse embedding for knowledge ID {knowledge_item.id if hasattr(knowledge_item, 'id') else 'N/A'}"
)
except Exception as e:
logger.error(f"Error processing knowledge item: {e}")
# Sort by similarity in descending order
results_with_similarity.sort(key=lambda x: x["similarity"], reverse=True)
# Limit results
limited_results = results_with_similarity[:limit]
logger.debug(f"知识库查询结果数量 (after Peewee processing): {len(limited_results)}")
if not limited_results:
return "" if not return_raw else []
if return_raw:
return limited_results
else:
return "\n".join(str(result["content"]) for result in limited_results)
except Exception as e:
logger.error(f"Error querying Knowledges with Peewee: {e}")
return "" if not return_raw else []
if return_raw:
return results
else:
# 返回所有找到的内容,用换行分隔
return "\n".join(str(result["content"]) for result in results)
init_prompt()
prompt_builder = PromptBuilder()

View File

@@ -26,8 +26,9 @@ class ChattingInfoProcessor(BaseProcessor):
def __init__(self):
"""初始化观察处理器"""
super().__init__()
# TODO: API-Adapter修改标记
self.llm_summary = LLMRequest(
model=global_config.llm_observation, temperature=0.7, max_tokens=300, request_type="chat_observation"
model=global_config.model.observation, temperature=0.7, max_tokens=300, request_type="chat_observation"
)
async def process_info(
@@ -110,12 +111,12 @@ class ChattingInfoProcessor(BaseProcessor):
"created_at": datetime.now().timestamp(),
}
obs.mid_memorys.append(mid_memory)
if len(obs.mid_memorys) > obs.max_mid_memory_len:
obs.mid_memorys.pop(0) # 移除最旧的
obs.mid_memories.append(mid_memory)
if len(obs.mid_memories) > obs.max_mid_memory_len:
obs.mid_memories.pop(0) # 移除最旧的
mid_memory_str = "之前聊天的内容概述是:\n"
for mid_memory_item in obs.mid_memorys: # 重命名循环变量以示区分
for mid_memory_item in obs.mid_memories: # 重命名循环变量以示区分
time_diff = int((datetime.now().timestamp() - mid_memory_item["created_at"]) / 60)
mid_memory_str += (
f"距离现在{time_diff}分钟前(聊天记录id:{mid_memory_item['id']}){mid_memory_item['theme']}\n"

View File

@@ -71,8 +71,8 @@ class MindProcessor(BaseProcessor):
self.subheartflow_id = subheartflow_id
self.llm_model = LLMRequest(
model=global_config.llm_sub_heartflow,
temperature=global_config.llm_sub_heartflow["temp"],
model=global_config.model.sub_heartflow,
temperature=global_config.model.sub_heartflow["temp"],
max_tokens=800,
request_type="sub_heart_flow",
)

View File

@@ -49,7 +49,7 @@ class ToolProcessor(BaseProcessor):
self.subheartflow_id = subheartflow_id
self.log_prefix = f"[{subheartflow_id}:ToolExecutor] "
self.llm_model = LLMRequest(
model=global_config.llm_tool_use,
model=global_config.model.tool_use,
max_tokens=500,
request_type="tool_execution",
)

View File

@@ -34,8 +34,9 @@ def init_prompt():
class MemoryActivator:
def __init__(self):
# TODO: API-Adapter修改标记
self.summary_model = LLMRequest(
model=global_config.llm_summary, temperature=0.7, max_tokens=50, request_type="chat_observation"
model=global_config.model.summary, temperature=0.7, max_tokens=50, request_type="chat_observation"
)
self.running_memory = []

View File

@@ -35,8 +35,9 @@ class Heartflow:
self.subheartflow_manager: SubHeartflowManager = SubHeartflowManager(self.current_state)
# LLM模型配置
# TODO: API-Adapter修改标记
self.llm_model = LLMRequest(
model=global_config.llm_heartflow, temperature=0.6, max_tokens=1000, request_type="heart_flow"
model=global_config.model.heartflow, temperature=0.6, max_tokens=1000, request_type="heart_flow"
)
# 外部依赖模块

View File

@@ -20,9 +20,9 @@ MAX_REPLY_PROBABILITY = 1
class InterestChatting:
def __init__(
self,
decay_rate=global_config.default_decay_rate_per_second,
decay_rate=global_config.focus_chat.default_decay_rate_per_second,
max_interest=MAX_INTEREST,
trigger_threshold=global_config.reply_trigger_threshold,
trigger_threshold=global_config.focus_chat.reply_trigger_threshold,
max_probability=MAX_REPLY_PROBABILITY,
):
# 基础属性初始化

View File

@@ -18,19 +18,14 @@ enable_unlimited_hfc_chat = True # 调试用:无限专注聊天
prevent_offline_state = True
# 目前默认不启用OFFLINE状态
# 不同状态下普通聊天的最大消息数
base_normal_chat_num = global_config.base_normal_chat_num
base_focused_chat_num = global_config.base_focused_chat_num
MAX_NORMAL_CHAT_NUM_PEEKING = int(base_normal_chat_num / 2)
MAX_NORMAL_CHAT_NUM_NORMAL = base_normal_chat_num
MAX_NORMAL_CHAT_NUM_FOCUSED = base_normal_chat_num + 1
MAX_NORMAL_CHAT_NUM_PEEKING = int(global_config.chat.base_normal_chat_num / 2)
MAX_NORMAL_CHAT_NUM_NORMAL = global_config.chat.base_normal_chat_num
MAX_NORMAL_CHAT_NUM_FOCUSED = global_config.chat.base_normal_chat_num + 1
# 不同状态下专注聊天的最大消息数
MAX_FOCUSED_CHAT_NUM_PEEKING = int(base_focused_chat_num / 2)
MAX_FOCUSED_CHAT_NUM_NORMAL = base_focused_chat_num
MAX_FOCUSED_CHAT_NUM_FOCUSED = base_focused_chat_num + 2
MAX_FOCUSED_CHAT_NUM_PEEKING = int(global_config.chat.base_focused_chat_num / 2)
MAX_FOCUSED_CHAT_NUM_NORMAL = global_config.chat.base_focused_chat_num
MAX_FOCUSED_CHAT_NUM_FOCUSED = global_config.chat.base_focused_chat_num + 2
# -- 状态定义 --

View File

@@ -55,19 +55,20 @@ class ChattingObservation(Observation):
self.talking_message = []
self.talking_message_str = ""
self.talking_message_str_truncate = ""
self.name = global_config.BOT_NICKNAME
self.nick_name = global_config.BOT_ALIAS_NAMES
self.max_now_obs_len = global_config.observation_context_size
self.overlap_len = global_config.compressed_length
self.mid_memorys = []
self.max_mid_memory_len = global_config.compress_length_limit
self.name = global_config.bot.nickname
self.nick_name = global_config.bot.alias_names
self.max_now_obs_len = global_config.chat.observation_context_size
self.overlap_len = global_config.focus_chat.compressed_length
self.mid_memories = []
self.max_mid_memory_len = global_config.focus_chat.compress_length_limit
self.mid_memory_info = ""
self.person_list = []
self.oldest_messages = []
self.oldest_messages_str = ""
self.compressor_prompt = ""
# TODO: API-Adapter修改标记
self.llm_summary = LLMRequest(
model=global_config.llm_observation, temperature=0.7, max_tokens=300, request_type="chat_observation"
model=global_config.model.observation, temperature=0.7, max_tokens=300, request_type="chat_observation"
)
async def initialize(self):
@@ -85,7 +86,7 @@ class ChattingObservation(Observation):
for id in ids:
print(f"id{id}")
try:
for mid_memory in self.mid_memorys:
for mid_memory in self.mid_memories:
if mid_memory["id"] == id:
mid_memory_by_id = mid_memory
msg_str = ""
@@ -103,7 +104,7 @@ class ChattingObservation(Observation):
else:
mid_memory_str = "之前的聊天内容:\n"
for mid_memory in self.mid_memorys:
for mid_memory in self.mid_memories:
mid_memory_str += f"{mid_memory['theme']}\n"
return mid_memory_str + "现在群里正在聊:\n" + self.talking_message_str

View File

@@ -76,8 +76,9 @@ class SubHeartflowManager:
# 为 LLM 状态评估创建一个 LLMRequest 实例
# 使用与 Heartflow 相同的模型和参数
# TODO: API-Adapter修改标记
self.llm_state_evaluator = LLMRequest(
model=global_config.llm_heartflow, # 与 Heartflow 一致
model=global_config.model.heartflow, # 与 Heartflow 一致
temperature=0.6, # 与 Heartflow 一致
max_tokens=1000, # 与 Heartflow 一致 (虽然可能不需要这么多)
request_type="subheartflow_state_eval", # 保留特定的请求类型
@@ -278,7 +279,7 @@ class SubHeartflowManager:
focused_limit = current_state.get_focused_chat_max_num()
# --- 新增:检查是否允许进入 FOCUS 模式 --- #
if not global_config.allow_focus_mode:
if not global_config.chat.allow_focus_mode:
if int(time.time()) % 60 == 0: # 每60秒输出一次日志避免刷屏
logger.trace("未开启 FOCUSED 状态 (allow_focus_mode=False)")
return # 如果不允许,直接返回
@@ -766,7 +767,7 @@ class SubHeartflowManager:
focused_limit = current_mai_state.get_focused_chat_max_num()
# --- 检查是否允许 FOCUS 模式 --- #
if not global_config.allow_focus_mode:
if not global_config.chat.allow_focus_mode:
# Log less frequently to avoid spam
# if int(time.time()) % 60 == 0:
# logger.debug(f"{log_prefix_task} 配置不允许进入 FOCUSED 状态")

View File

@@ -10,7 +10,7 @@ import jieba
import networkx as nx
import numpy as np
from collections import Counter
from ...common.database import db
from ...common.database.database import memory_db as db
from ...chat.models.utils_model import LLMRequest
from src.common.logger_manager import get_logger
from src.chat.memory_system.sample_distribution import MemoryBuildScheduler # 分布生成器
@@ -19,9 +19,10 @@ from ..utils.chat_message_builder import (
build_readable_messages,
) # 导入 build_readable_messages
from ..utils.utils import translate_timestamp_to_human_readable
from .memory_config import MemoryConfig
from rich.traceback import install
from ...config.config import global_config
install(extra_lines=3)
@@ -195,18 +196,16 @@ class Hippocampus:
self.llm_summary = None
self.entorhinal_cortex = None
self.parahippocampal_gyrus = None
self.config = None
def initialize(self, global_config):
# 使用导入的 MemoryConfig dataclass 和其 from_global_config 方法
self.config = MemoryConfig.from_global_config(global_config)
def initialize(self):
# 初始化子组件
self.entorhinal_cortex = EntorhinalCortex(self)
self.parahippocampal_gyrus = ParahippocampalGyrus(self)
# 从数据库加载记忆图
self.entorhinal_cortex.sync_memory_from_db()
self.llm_topic_judge = LLMRequest(self.config.llm_topic_judge, request_type="memory")
self.llm_summary = LLMRequest(self.config.llm_summary, request_type="memory")
# TODO: API-Adapter修改标记
self.llm_topic_judge = LLMRequest(global_config.model.topic_judge, request_type="memory")
self.llm_summary = LLMRequest(global_config.model.summary, request_type="memory")
def get_all_node_names(self) -> list:
"""获取记忆图中所有节点的名字列表"""
@@ -792,7 +791,6 @@ class EntorhinalCortex:
def __init__(self, hippocampus: Hippocampus):
self.hippocampus = hippocampus
self.memory_graph = hippocampus.memory_graph
self.config = hippocampus.config
def get_memory_sample(self):
"""从数据库获取记忆样本"""
@@ -801,13 +799,13 @@ class EntorhinalCortex:
# 创建双峰分布的记忆调度器
sample_scheduler = MemoryBuildScheduler(
n_hours1=self.config.memory_build_distribution[0],
std_hours1=self.config.memory_build_distribution[1],
weight1=self.config.memory_build_distribution[2],
n_hours2=self.config.memory_build_distribution[3],
std_hours2=self.config.memory_build_distribution[4],
weight2=self.config.memory_build_distribution[5],
total_samples=self.config.build_memory_sample_num,
n_hours1=global_config.memory.memory_build_distribution[0],
std_hours1=global_config.memory.memory_build_distribution[1],
weight1=global_config.memory.memory_build_distribution[2],
n_hours2=global_config.memory.memory_build_distribution[3],
std_hours2=global_config.memory.memory_build_distribution[4],
weight2=global_config.memory.memory_build_distribution[5],
total_samples=global_config.memory.memory_build_sample_num,
)
timestamps = sample_scheduler.get_timestamp_array()
@@ -818,7 +816,7 @@ class EntorhinalCortex:
for timestamp in timestamps:
# 调用修改后的 random_get_msg_snippet
messages = self.random_get_msg_snippet(
timestamp, self.config.build_memory_sample_length, max_memorized_time_per_msg
timestamp, global_config.memory.memory_build_sample_length, max_memorized_time_per_msg
)
if messages:
time_diff = (datetime.datetime.now().timestamp() - timestamp) / 3600
@@ -1099,7 +1097,6 @@ class ParahippocampalGyrus:
def __init__(self, hippocampus: Hippocampus):
self.hippocampus = hippocampus
self.memory_graph = hippocampus.memory_graph
self.config = hippocampus.config
async def memory_compress(self, messages: list, compress_rate=0.1):
"""压缩和总结消息内容,生成记忆主题和摘要。
@@ -1159,7 +1156,7 @@ class ParahippocampalGyrus:
# 3. 过滤掉包含禁用关键词的topic
filtered_topics = [
topic for topic in topics if not any(keyword in topic for keyword in self.config.memory_ban_words)
topic for topic in topics if not any(keyword in topic for keyword in global_config.memory.memory_ban_words)
]
logger.debug(f"过滤后话题: {filtered_topics}")
@@ -1222,7 +1219,7 @@ class ParahippocampalGyrus:
bar = "" * filled_length + "-" * (bar_length - filled_length)
logger.debug(f"进度: [{bar}] {progress:.1f}% ({i}/{len(memory_samples)})")
compress_rate = self.config.memory_compress_rate
compress_rate = global_config.memory.memory_compress_rate
try:
compressed_memory, similar_topics_dict = await self.memory_compress(messages, compress_rate)
except Exception as e:
@@ -1322,7 +1319,7 @@ class ParahippocampalGyrus:
edge_data = self.memory_graph.G[source][target]
last_modified = edge_data.get("last_modified")
if current_time - last_modified > 3600 * self.config.memory_forget_time:
if current_time - last_modified > 3600 * global_config.memory.memory_forget_time:
current_strength = edge_data.get("strength", 1)
new_strength = current_strength - 1
@@ -1430,8 +1427,8 @@ class ParahippocampalGyrus:
async def operation_consolidate_memory(self):
"""整合记忆:合并节点内相似的记忆项"""
start_time = time.time()
percentage = self.config.consolidate_memory_percentage
similarity_threshold = self.config.consolidation_similarity_threshold
percentage = global_config.memory.consolidate_memory_percentage
similarity_threshold = global_config.memory.consolidation_similarity_threshold
logger.info(f"[整合] 开始检查记忆节点... 检查比例: {percentage:.2%}, 合并阈值: {similarity_threshold}")
# 获取所有至少有2条记忆项的节点
@@ -1544,7 +1541,6 @@ class ParahippocampalGyrus:
class HippocampusManager:
_instance = None
_hippocampus = None
_global_config = None
_initialized = False
@classmethod
@@ -1559,19 +1555,15 @@ class HippocampusManager:
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
return cls._hippocampus
def initialize(self, global_config):
def initialize(self):
"""初始化海马体实例"""
if self._initialized:
return self._hippocampus
self._global_config = global_config
self._hippocampus = Hippocampus()
self._hippocampus.initialize(global_config)
self._hippocampus.initialize()
self._initialized = True
# 输出记忆系统参数信息
config = self._hippocampus.config
# 输出记忆图统计信息
memory_graph = self._hippocampus.memory_graph.G
node_count = len(memory_graph.nodes())
@@ -1579,9 +1571,9 @@ class HippocampusManager:
logger.success(f"""--------------------------------
记忆系统参数配置:
构建间隔: {global_config.build_memory_interval}秒|样本数: {config.build_memory_sample_num},长度: {config.build_memory_sample_length}|压缩率: {config.memory_compress_rate}
记忆构建分布: {config.memory_build_distribution}
遗忘间隔: {global_config.forget_memory_interval}秒|遗忘比例: {global_config.memory_forget_percentage}|遗忘: {config.memory_forget_time}小时之后
构建间隔: {global_config.memory.memory_build_interval}秒|样本数: {global_config.memory.memory_build_sample_num},长度: {global_config.memory.memory_build_sample_length}|压缩率: {global_config.memory.memory_compress_rate}
记忆构建分布: {global_config.memory.memory_build_distribution}
遗忘间隔: {global_config.memory.forget_memory_interval}秒|遗忘比例: {global_config.memory.memory_forget_percentage}|遗忘: {global_config.memory.memory_forget_time}小时之后
记忆图统计信息: 节点数量: {node_count}, 连接数量: {edge_count}
--------------------------------""") # noqa: E501

View File

@@ -7,7 +7,6 @@ import os
# 添加项目根目录到系统路径
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))))
from src.chat.memory_system.Hippocampus import HippocampusManager
from src.config.config import global_config
from rich.traceback import install
install(extra_lines=3)
@@ -19,7 +18,7 @@ async def test_memory_system():
# 初始化记忆系统
print("开始初始化记忆系统...")
hippocampus_manager = HippocampusManager.get_instance()
hippocampus_manager.initialize(global_config=global_config)
hippocampus_manager.initialize()
print("记忆系统初始化完成")
# 测试记忆构建

View File

@@ -34,7 +34,7 @@ root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
sys.path.append(root_path)
from src.common.logger import get_module_logger # noqa E402
from src.common.database import db # noqa E402
from common.database.database import db # noqa E402
logger = get_module_logger("mem_alter")
console = Console()

View File

@@ -1,48 +0,0 @@
from dataclasses import dataclass
from typing import List
@dataclass
class MemoryConfig:
"""记忆系统配置类"""
# 记忆构建相关配置
memory_build_distribution: List[float] # 记忆构建的时间分布参数
build_memory_sample_num: int # 每次构建记忆的样本数量
build_memory_sample_length: int # 每个样本的消息长度
memory_compress_rate: float # 记忆压缩率
# 记忆遗忘相关配置
memory_forget_time: int # 记忆遗忘时间(小时)
# 记忆过滤相关配置
memory_ban_words: List[str] # 记忆过滤词列表
# 新增:记忆整合相关配置
consolidation_similarity_threshold: float # 相似度阈值
consolidate_memory_percentage: float # 检查节点比例
consolidate_memory_interval: int # 记忆整合间隔
llm_topic_judge: str # 话题判断模型
llm_summary: str # 话题总结模型
@classmethod
def from_global_config(cls, global_config):
"""从全局配置创建记忆系统配置"""
# 使用 getattr 提供默认值,防止全局配置缺少这些项
return cls(
memory_build_distribution=getattr(
global_config, "memory_build_distribution", (24, 12, 0.5, 168, 72, 0.5)
), # 添加默认值
build_memory_sample_num=getattr(global_config, "build_memory_sample_num", 5),
build_memory_sample_length=getattr(global_config, "build_memory_sample_length", 30),
memory_compress_rate=getattr(global_config, "memory_compress_rate", 0.1),
memory_forget_time=getattr(global_config, "memory_forget_time", 24 * 7),
memory_ban_words=getattr(global_config, "memory_ban_words", []),
# 新增加载整合配置,并提供默认值
consolidation_similarity_threshold=getattr(global_config, "consolidation_similarity_threshold", 0.7),
consolidate_memory_percentage=getattr(global_config, "consolidate_memory_percentage", 0.01),
consolidate_memory_interval=getattr(global_config, "consolidate_memory_interval", 1000),
llm_topic_judge=getattr(global_config, "llm_topic_judge", "default_judge_model"), # 添加默认模型名
llm_summary=getattr(global_config, "llm_summary", "default_summary_model"), # 添加默认模型名
)

View File

@@ -41,7 +41,7 @@ class ChatBot:
chat_id = str(message.chat_stream.stream_id)
private_name = str(message.message_info.user_info.user_nickname)
if global_config.enable_pfc_chatting:
if global_config.experimental.enable_pfc_chatting:
await self.pfc_manager.get_or_create_conversation(chat_id, private_name)
except Exception as e:
@@ -78,19 +78,19 @@ class ChatBot:
userinfo = message.message_info.user_info
# 用户黑名单拦截
if userinfo.user_id in global_config.ban_user_id:
if userinfo.user_id in global_config.chat_target.ban_user_id:
logger.debug(f"用户{userinfo.user_id}被禁止回复")
return
if groupinfo is None:
logger.trace("检测到私聊消息,检查")
# 好友黑名单拦截
if userinfo.user_id not in global_config.talk_allowed_private:
if userinfo.user_id not in global_config.experimental.talk_allowed_private:
logger.debug(f"用户{userinfo.user_id}没有私聊权限")
return
# 群聊黑名单拦截
if groupinfo is not None and groupinfo.group_id not in global_config.talk_allowed_groups:
if groupinfo is not None and groupinfo.group_id not in global_config.chat_target.talk_allowed_groups:
logger.trace(f"{groupinfo.group_id}被禁止回复")
return
@@ -112,7 +112,7 @@ class ChatBot:
if groupinfo is None:
logger.trace("检测到私聊消息")
# 是否在配置信息中开启私聊模式
if global_config.enable_friend_chat:
if global_config.experimental.enable_friend_chat:
logger.trace("私聊模式已启用")
# 是否进入PFC
if global_config.enable_pfc_chatting:

View File

@@ -5,7 +5,8 @@ import copy
from typing import Dict, Optional
from ...common.database import db
from ...common.database.database import db
from ...common.database.database_model import ChatStreams # 新增导入
from maim_message import GroupInfo, UserInfo
from src.common.logger_manager import get_logger
@@ -82,7 +83,13 @@ class ChatManager:
def __init__(self):
if not self._initialized:
self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream
self._ensure_collection()
try:
db.connect(reuse_if_open=True)
# 确保 ChatStreams 表存在
db.create_tables([ChatStreams], safe=True)
except Exception as e:
logger.error(f"数据库连接或 ChatStreams 表创建失败: {e}")
self._initialized = True
# 在事件循环中启动初始化
# asyncio.create_task(self._initialize())
@@ -107,15 +114,6 @@ class ChatManager:
except Exception as e:
logger.error(f"聊天流自动保存失败: {str(e)}")
@staticmethod
def _ensure_collection():
"""确保数据库集合存在并创建索引"""
if "chat_streams" not in db.list_collection_names():
db.create_collection("chat_streams")
# 创建索引
db.chat_streams.create_index([("stream_id", 1)], unique=True)
db.chat_streams.create_index([("platform", 1), ("user_info.user_id", 1), ("group_info.group_id", 1)])
@staticmethod
def _generate_stream_id(platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None) -> str:
"""生成聊天流唯一ID"""
@@ -151,16 +149,43 @@ class ChatManager:
stream = self.streams[stream_id]
# 更新用户信息和群组信息
stream.update_active_time()
stream = copy.deepcopy(stream)
stream = copy.deepcopy(stream) # 返回副本以避免外部修改影响缓存
stream.user_info = user_info
if group_info:
stream.group_info = group_info
return stream
# 检查数据库中是否存在
data = db.chat_streams.find_one({"stream_id": stream_id})
if data:
stream = ChatStream.from_dict(data)
def _db_find_stream_sync(s_id: str):
return ChatStreams.get_or_none(ChatStreams.stream_id == s_id)
model_instance = await asyncio.to_thread(_db_find_stream_sync, stream_id)
if model_instance:
# 从 Peewee 模型转换回 ChatStream.from_dict 期望的格式
user_info_data = {
"platform": model_instance.user_platform,
"user_id": model_instance.user_id,
"user_nickname": model_instance.user_nickname,
"user_cardname": model_instance.user_cardname or "",
}
group_info_data = None
if model_instance.group_id: # 假设 group_id 为空字符串表示没有群组信息
group_info_data = {
"platform": model_instance.group_platform,
"group_id": model_instance.group_id,
"group_name": model_instance.group_name,
}
data_for_from_dict = {
"stream_id": model_instance.stream_id,
"platform": model_instance.platform,
"user_info": user_info_data,
"group_info": group_info_data,
"create_time": model_instance.create_time,
"last_active_time": model_instance.last_active_time,
}
stream = ChatStream.from_dict(data_for_from_dict)
# 更新用户信息和群组信息
stream.user_info = user_info
if group_info:
@@ -175,7 +200,7 @@ class ChatManager:
group_info=group_info,
)
except Exception as e:
logger.error(f"创建聊天流失败: {e}")
logger.error(f"获取或创建聊天流失败: {e}", exc_info=True)
raise e
# 保存到内存和数据库
@@ -205,15 +230,38 @@ class ChatManager:
elif stream.user_info and stream.user_info.user_nickname:
return f"{stream.user_info.user_nickname}的私聊"
else:
# 如果没有群名或用户昵称,返回 None 或其他默认值
return None
@staticmethod
async def _save_stream(stream: ChatStream):
"""保存聊天流到数据库"""
if not stream.saved:
db.chat_streams.update_one({"stream_id": stream.stream_id}, {"$set": stream.to_dict()}, upsert=True)
stream.saved = True
stream_data_dict = stream.to_dict()
def _db_save_stream_sync(s_data_dict: dict):
user_info_d = s_data_dict.get("user_info")
group_info_d = s_data_dict.get("group_info")
fields_to_save = {
"platform": s_data_dict["platform"],
"create_time": s_data_dict["create_time"],
"last_active_time": s_data_dict["last_active_time"],
"user_platform": user_info_d["platform"] if user_info_d else "",
"user_id": user_info_d["user_id"] if user_info_d else "",
"user_nickname": user_info_d["user_nickname"] if user_info_d else "",
"user_cardname": user_info_d.get("user_cardname", "") if user_info_d else None,
"group_platform": group_info_d["platform"] if group_info_d else "",
"group_id": group_info_d["group_id"] if group_info_d else "",
"group_name": group_info_d["group_name"] if group_info_d else "",
}
ChatStreams.replace(stream_id=s_data_dict["stream_id"], **fields_to_save).execute()
try:
await asyncio.to_thread(_db_save_stream_sync, stream_data_dict)
stream.saved = True
except Exception as e:
logger.error(f"保存聊天流 {stream.stream_id} 到数据库失败 (Peewee): {e}", exc_info=True)
async def _save_all_streams(self):
"""保存所有聊天流"""
@@ -222,10 +270,44 @@ class ChatManager:
async def load_all_streams(self):
"""从数据库加载所有聊天流"""
all_streams = db.chat_streams.find({})
for data in all_streams:
stream = ChatStream.from_dict(data)
self.streams[stream.stream_id] = stream
def _db_load_all_streams_sync():
loaded_streams_data = []
for model_instance in ChatStreams.select():
user_info_data = {
"platform": model_instance.user_platform,
"user_id": model_instance.user_id,
"user_nickname": model_instance.user_nickname,
"user_cardname": model_instance.user_cardname or "",
}
group_info_data = None
if model_instance.group_id:
group_info_data = {
"platform": model_instance.group_platform,
"group_id": model_instance.group_id,
"group_name": model_instance.group_name,
}
data_for_from_dict = {
"stream_id": model_instance.stream_id,
"platform": model_instance.platform,
"user_info": user_info_data,
"group_info": group_info_data,
"create_time": model_instance.create_time,
"last_active_time": model_instance.last_active_time,
}
loaded_streams_data.append(data_for_from_dict)
return loaded_streams_data
try:
all_streams_data_list = await asyncio.to_thread(_db_load_all_streams_sync)
self.streams.clear()
for data in all_streams_data_list:
stream = ChatStream.from_dict(data)
stream.saved = True
self.streams[stream.stream_id] = stream
except Exception as e:
logger.error(f"从数据库加载所有聊天流失败 (Peewee): {e}", exc_info=True)
# 创建全局单例

View File

@@ -38,7 +38,7 @@ class MessageBuffer:
async def start_caching_messages(self, message: MessageRecv):
"""添加消息,启动缓冲"""
if not global_config.message_buffer:
if not global_config.chat.message_buffer:
person_id = person_info_manager.get_person_id(
message.message_info.user_info.platform, message.message_info.user_info.user_id
)
@@ -107,7 +107,7 @@ class MessageBuffer:
async def query_buffer_result(self, message: MessageRecv) -> bool:
"""查询缓冲结果,并清理"""
if not global_config.message_buffer:
if not global_config.chat.message_buffer:
return True
person_id_ = self.get_person_id_(
message.message_info.platform, message.message_info.user_info.user_id, message.message_info.group_info

View File

@@ -279,7 +279,7 @@ class MessageManager:
)
# 检查是否超时
if thinking_time > global_config.thinking_timeout:
if thinking_time > global_config.normal_chat.thinking_timeout:
logger.warning(
f"[{chat_id}] 消息思考超时 ({thinking_time:.1f}秒),移除消息 {message_earliest.message_info.message_id}"
)

View File

@@ -1,9 +1,10 @@
import re
from typing import Union
from ...common.database import db
# from ...common.database.database import db # db is now Peewee's SqliteDatabase instance
from .message import MessageSending, MessageRecv
from .chat_stream import ChatStream
from ...common.database.database_model import Messages, RecalledMessages # Import Peewee models
from src.common.logger import get_module_logger
logger = get_module_logger("message_storage")
@@ -29,42 +30,66 @@ class MessageStorage:
else:
filtered_detailed_plain_text = ""
message_data = {
"message_id": message.message_info.message_id,
"time": message.message_info.time,
"chat_id": chat_stream.stream_id,
"chat_info": chat_stream.to_dict(),
"user_info": message.message_info.user_info.to_dict(),
# 使用过滤后的文本
"processed_plain_text": filtered_processed_plain_text,
"detailed_plain_text": filtered_detailed_plain_text,
"memorized_times": message.memorized_times,
}
db.messages.insert_one(message_data)
chat_info_dict = chat_stream.to_dict()
user_info_dict = message.message_info.user_info.to_dict()
# message_id 现在是 TextField直接使用字符串值
msg_id = message.message_info.message_id
# 安全地获取 group_info, 如果为 None 则视为空字典
group_info_from_chat = chat_info_dict.get("group_info") or {}
# 安全地获取 user_info, 如果为 None 则视为空字典 (以防万一)
user_info_from_chat = chat_info_dict.get("user_info") or {}
Messages.create(
message_id=msg_id,
time=float(message.message_info.time),
chat_id=chat_stream.stream_id,
# Flattened chat_info
chat_info_stream_id=chat_info_dict.get("stream_id"),
chat_info_platform=chat_info_dict.get("platform"),
chat_info_user_platform=user_info_from_chat.get("platform"),
chat_info_user_id=user_info_from_chat.get("user_id"),
chat_info_user_nickname=user_info_from_chat.get("user_nickname"),
chat_info_user_cardname=user_info_from_chat.get("user_cardname"),
chat_info_group_platform=group_info_from_chat.get("platform"),
chat_info_group_id=group_info_from_chat.get("group_id"),
chat_info_group_name=group_info_from_chat.get("group_name"),
chat_info_create_time=float(chat_info_dict.get("create_time", 0.0)),
chat_info_last_active_time=float(chat_info_dict.get("last_active_time", 0.0)),
# Flattened user_info (message sender)
user_platform=user_info_dict.get("platform"),
user_id=user_info_dict.get("user_id"),
user_nickname=user_info_dict.get("user_nickname"),
user_cardname=user_info_dict.get("user_cardname"),
# Text content
processed_plain_text=filtered_processed_plain_text,
detailed_plain_text=filtered_detailed_plain_text,
memorized_times=message.memorized_times,
)
except Exception:
logger.exception("存储消息失败")
@staticmethod
async def store_recalled_message(message_id: str, time: str, chat_stream: ChatStream) -> None:
"""存储撤回消息到数据库"""
if "recalled_messages" not in db.list_collection_names():
db.create_collection("recalled_messages")
else:
try:
message_data = {
"message_id": message_id,
"time": time,
"stream_id": chat_stream.stream_id,
}
db.recalled_messages.insert_one(message_data)
except Exception:
logger.exception("存储撤回消息失败")
# Table creation is handled by initialize_database in database_model.py
try:
RecalledMessages.create(
message_id=message_id,
time=float(time), # Assuming time is a string representing a float timestamp
stream_id=chat_stream.stream_id,
)
except Exception:
logger.exception("存储撤回消息失败")
@staticmethod
async def remove_recalled_message(time: str) -> None:
"""删除撤回消息"""
try:
db.recalled_messages.delete_many({"time": {"$lt": time - 300}})
# Assuming input 'time' is a string timestamp that can be converted to float
current_time_float = float(time)
RecalledMessages.delete().where(RecalledMessages.time < (current_time_float - 300)).execute()
except Exception:
logger.exception("删除撤回消息失败")

View File

@@ -12,7 +12,8 @@ import base64
from PIL import Image
import io
import os
from ...common.database import db
from src.common.database.database import db # 确保 db 被导入用于 create_tables
from src.common.database.database_model import LLMUsage # 导入 LLMUsage 模型
from ...config.config import global_config
from rich.traceback import install
@@ -85,8 +86,6 @@ async def _safely_record(request_content: Dict[str, Any], payload: Dict[str, Any
f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,"
f"{image_base64[:10]}...{image_base64[-10:]}"
)
# if isinstance(content, str) and len(content) > 100:
# payload["messages"][0]["content"] = content[:100]
return payload
@@ -111,8 +110,8 @@ class LLMRequest:
def __init__(self, model: dict, **kwargs):
# 将大写的配置键转换为小写并从config中获取实际值
try:
self.api_key = os.environ[model["key"]]
self.base_url = os.environ[model["base_url"]]
self.api_key = os.environ[f"{model['provider']}_KEY"]
self.base_url = os.environ[f"{model['provider']}_BASE_URL"]
except AttributeError as e:
logger.error(f"原始 model dict 信息:{model}")
logger.error(f"配置错误:找不到对应的配置项 - {str(e)}")
@@ -134,13 +133,11 @@ class LLMRequest:
def _init_database():
"""初始化数据库集合"""
try:
# 创建llm_usage集合的索引
db.llm_usage.create_index([("timestamp", 1)])
db.llm_usage.create_index([("model_name", 1)])
db.llm_usage.create_index([("user_id", 1)])
db.llm_usage.create_index([("request_type", 1)])
# 使用 Peewee 创建表safe=True 表示如果表已存在则不会抛出错误
db.create_tables([LLMUsage], safe=True)
logger.debug("LLMUsage 表已初始化/确保存在。")
except Exception as e:
logger.error(f"创建数据库索引失败: {str(e)}")
logger.error(f"创建 LLMUsage 表失败: {str(e)}")
def _record_usage(
self,
@@ -165,19 +162,19 @@ class LLMRequest:
request_type = self.request_type
try:
usage_data = {
"model_name": self.model_name,
"user_id": user_id,
"request_type": request_type,
"endpoint": endpoint,
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": total_tokens,
"cost": self._calculate_cost(prompt_tokens, completion_tokens),
"status": "success",
"timestamp": datetime.now(),
}
db.llm_usage.insert_one(usage_data)
# 使用 Peewee 模型创建记录
LLMUsage.create(
model_name=self.model_name,
user_id=user_id,
request_type=request_type,
endpoint=endpoint,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
cost=self._calculate_cost(prompt_tokens, completion_tokens),
status="success",
timestamp=datetime.now(), # Peewee 会处理 DateTimeField
)
logger.trace(
f"Token使用情况 - 模型: {self.model_name}, "
f"用户: {user_id}, 类型: {request_type}, "
@@ -500,11 +497,11 @@ class LLMRequest:
logger.warning(f"检测到403错误模型从 {old_model_name} 降级为 {self.model_name}")
# 对全局配置进行更新
if global_config.llm_normal.get("name") == old_model_name:
global_config.llm_normal["name"] = self.model_name
if global_config.model.normal.get("name") == old_model_name:
global_config.model.normal["name"] = self.model_name
logger.warning(f"将全局配置中的 llm_normal 模型临时降级至{self.model_name}")
if global_config.llm_reasoning.get("name") == old_model_name:
global_config.llm_reasoning["name"] = self.model_name
if global_config.model.reasoning.get("name") == old_model_name:
global_config.model.reasoning["name"] = self.model_name
logger.warning(f"将全局配置中的 llm_reasoning 模型临时降级至{self.model_name}")
if payload and "model" in payload:
@@ -636,7 +633,7 @@ class LLMRequest:
**params_copy,
}
if "max_tokens" not in payload and "max_completion_tokens" not in payload:
payload["max_tokens"] = global_config.model_max_output_length
payload["max_tokens"] = global_config.model.model_max_output_length
# 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查
if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload:
payload["max_completion_tokens"] = payload.pop("max_tokens")

View File

@@ -73,8 +73,8 @@ class NormalChat:
messageinfo = message.message_info
bot_user_info = UserInfo(
user_id=global_config.BOT_QQ,
user_nickname=global_config.BOT_NICKNAME,
user_id=global_config.bot.qq_account,
user_nickname=global_config.bot.nickname,
platform=messageinfo.platform,
)
@@ -121,8 +121,8 @@ class NormalChat:
message_id=thinking_id,
chat_stream=self.chat_stream, # 使用 self.chat_stream
bot_user_info=UserInfo(
user_id=global_config.BOT_QQ,
user_nickname=global_config.BOT_NICKNAME,
user_id=global_config.bot.qq_account,
user_nickname=global_config.bot.nickname,
platform=message.message_info.platform,
),
sender_info=message.message_info.user_info,
@@ -147,7 +147,7 @@ class NormalChat:
# 改为实例方法
async def _handle_emoji(self, message: MessageRecv, response: str):
"""处理表情包"""
if random() < global_config.emoji_chance:
if random() < global_config.normal_chat.emoji_chance:
emoji_raw = await emoji_manager.get_emoji_for_text(response)
if emoji_raw:
emoji_path, description = emoji_raw
@@ -160,8 +160,8 @@ class NormalChat:
message_id="mt" + str(thinking_time_point),
chat_stream=self.chat_stream, # 使用 self.chat_stream
bot_user_info=UserInfo(
user_id=global_config.BOT_QQ,
user_nickname=global_config.BOT_NICKNAME,
user_id=global_config.bot.qq_account,
user_nickname=global_config.bot.nickname,
platform=message.message_info.platform,
),
sender_info=message.message_info.user_info,
@@ -186,7 +186,7 @@ class NormalChat:
label=emotion,
stance=stance, # 使用 self.chat_stream
)
self.mood_manager.update_mood_from_emotion(emotion, global_config.mood_intensity_factor)
self.mood_manager.update_mood_from_emotion(emotion, global_config.mood.mood_intensity_factor)
async def _reply_interested_message(self) -> None:
"""
@@ -430,7 +430,7 @@ class NormalChat:
def _check_ban_words(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
"""检查消息中是否包含过滤词"""
stream_name = chat_manager.get_stream_name(chat.stream_id) or chat.stream_id
for word in global_config.ban_words:
for word in global_config.chat.ban_words:
if word in text:
logger.info(
f"[{stream_name}][{chat.group_info.group_name if chat.group_info else '私聊'}]"
@@ -445,7 +445,7 @@ class NormalChat:
def _check_ban_regex(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
"""检查消息是否匹配过滤正则表达式"""
stream_name = chat_manager.get_stream_name(chat.stream_id) or chat.stream_id
for pattern in global_config.ban_msgs_regex:
for pattern in global_config.chat.ban_msgs_regex:
if pattern.search(text):
logger.info(
f"[{stream_name}][{chat.group_info.group_name if chat.group_info else '私聊'}]"

View File

@@ -15,21 +15,22 @@ logger = get_logger("llm")
class NormalChatGenerator:
def __init__(self):
# TODO: API-Adapter修改标记
self.model_reasoning = LLMRequest(
model=global_config.llm_reasoning,
model=global_config.model.reasoning,
temperature=0.7,
max_tokens=3000,
request_type="response_reasoning",
)
self.model_normal = LLMRequest(
model=global_config.llm_normal,
temperature=global_config.llm_normal["temp"],
model=global_config.model.normal,
temperature=global_config.model.normal["temp"],
max_tokens=256,
request_type="response_reasoning",
)
self.model_sum = LLMRequest(
model=global_config.llm_summary, temperature=0.7, max_tokens=3000, request_type="relation"
model=global_config.model.summary, temperature=0.7, max_tokens=3000, request_type="relation"
)
self.current_model_type = "r1" # 默认使用 R1
self.current_model_name = "unknown model"
@@ -37,7 +38,7 @@ class NormalChatGenerator:
async def generate_response(self, message: MessageThinking, thinking_id: str) -> Optional[Union[str, List[str]]]:
"""根据当前模型类型选择对应的生成函数"""
# 从global_config中获取模型概率值并选择模型
if random.random() < global_config.model_reasoning_probability:
if random.random() < global_config.normal_chat.reasoning_model_probability:
self.current_model_type = "深深地"
current_model = self.model_reasoning
else:
@@ -51,7 +52,7 @@ class NormalChatGenerator:
model_response = await self._generate_response_with_model(message, current_model, thinking_id)
if model_response:
logger.info(f"{global_config.BOT_NICKNAME}的回复是:{model_response}")
logger.info(f"{global_config.bot.nickname}的回复是:{model_response}")
model_response = await self._process_response(model_response)
return model_response
@@ -113,7 +114,7 @@ class NormalChatGenerator:
- "中立":不表达明确立场或无关回应
2. 从"开心,愤怒,悲伤,惊讶,平静,害羞,恐惧,厌恶,困惑"中选出最匹配的1个情感标签
3. 按照"立场-情绪"的格式直接输出结果,例如:"反对-愤怒"
4. 考虑回复者的人格设定为{global_config.personality_core}
4. 考虑回复者的人格设定为{global_config.personality.personality_core}
对话示例:
被回复「A就是笨」

View File

@@ -1,18 +1,20 @@
import asyncio
from src.config.config import global_config
from .willing_manager import BaseWillingManager
class ClassicalWillingManager(BaseWillingManager):
def __init__(self):
super().__init__()
self._decay_task: asyncio.Task = None
self._decay_task: asyncio.Task | None = None
async def _decay_reply_willing(self):
"""定期衰减回复意愿"""
while True:
await asyncio.sleep(1)
for chat_id in self.chat_reply_willing:
self.chat_reply_willing[chat_id] = max(0, self.chat_reply_willing[chat_id] * 0.9)
self.chat_reply_willing[chat_id] = max(0.0, self.chat_reply_willing[chat_id] * 0.9)
async def async_task_starter(self):
if self._decay_task is None:
@@ -23,35 +25,33 @@ class ClassicalWillingManager(BaseWillingManager):
chat_id = willing_info.chat_id
current_willing = self.chat_reply_willing.get(chat_id, 0)
interested_rate = willing_info.interested_rate * self.global_config.response_interested_rate_amplifier
interested_rate = willing_info.interested_rate * global_config.normal_chat.response_interested_rate_amplifier
if interested_rate > 0.4:
current_willing += interested_rate - 0.3
if willing_info.is_mentioned_bot and current_willing < 1.0:
current_willing += 1
elif willing_info.is_mentioned_bot:
current_willing += 0.05
if willing_info.is_mentioned_bot:
current_willing += 1 if current_willing < 1.0 else 0.05
is_emoji_not_reply = False
if willing_info.is_emoji:
if self.global_config.emoji_response_penalty != 0:
current_willing *= self.global_config.emoji_response_penalty
if global_config.normal_chat.emoji_response_penalty != 0:
current_willing *= global_config.normal_chat.emoji_response_penalty
else:
is_emoji_not_reply = True
self.chat_reply_willing[chat_id] = min(current_willing, 3.0)
reply_probability = min(
max((current_willing - 0.5), 0.01) * self.global_config.response_willing_amplifier * 2, 1
max((current_willing - 0.5), 0.01) * global_config.normal_chat.response_willing_amplifier * 2, 1
)
# 检查群组权限(如果是群聊)
if (
willing_info.group_info
and willing_info.group_info.group_id in self.global_config.talk_frequency_down_groups
and willing_info.group_info.group_id in global_config.chat_target.talk_frequency_down_groups
):
reply_probability = reply_probability / self.global_config.down_frequency_rate
reply_probability = reply_probability / global_config.normal_chat.down_frequency_rate
if is_emoji_not_reply:
reply_probability = 0
@@ -61,7 +61,7 @@ class ClassicalWillingManager(BaseWillingManager):
async def before_generate_reply_handle(self, message_id):
chat_id = self.ongoing_messages[message_id].chat_id
current_willing = self.chat_reply_willing.get(chat_id, 0)
self.chat_reply_willing[chat_id] = max(0, current_willing - 1.8)
self.chat_reply_willing[chat_id] = max(0.0, current_willing - 1.8)
async def after_generate_reply_handle(self, message_id):
if message_id not in self.ongoing_messages:
@@ -70,7 +70,7 @@ class ClassicalWillingManager(BaseWillingManager):
chat_id = self.ongoing_messages[message_id].chat_id
current_willing = self.chat_reply_willing.get(chat_id, 0)
if current_willing < 1:
self.chat_reply_willing[chat_id] = min(1, current_willing + 0.4)
self.chat_reply_willing[chat_id] = min(1.0, current_willing + 0.4)
async def bombing_buffer_message_handle(self, message_id):
return await super().bombing_buffer_message_handle(message_id)

View File

@@ -19,6 +19,7 @@ Mxp 模式:梦溪畔独家赞助
下下策是询问一个菜鸟(@梦溪畔)
"""
from src.config.config import global_config
from .willing_manager import BaseWillingManager
from typing import Dict
import asyncio
@@ -50,8 +51,6 @@ class MxpWillingManager(BaseWillingManager):
self.mention_willing_gain = 0.6 # 提及意愿增益
self.interest_willing_gain = 0.3 # 兴趣意愿增益
self.emoji_response_penalty = self.global_config.emoji_response_penalty # 表情包回复惩罚
self.down_frequency_rate = self.global_config.down_frequency_rate # 降低回复频率的群组惩罚系数
self.single_chat_gain = 0.12 # 单聊增益
self.fatigue_messages_triggered_num = self.expected_replies_per_min # 疲劳消息触发数量(int)
@@ -179,10 +178,10 @@ class MxpWillingManager(BaseWillingManager):
probability = self._willing_to_probability(current_willing)
if w_info.is_emoji:
probability *= self.emoji_response_penalty
probability *= global_config.normal_chat.emoji_response_penalty
if w_info.group_info and w_info.group_info.group_id in self.global_config.talk_frequency_down_groups:
probability /= self.down_frequency_rate
if w_info.group_info and w_info.group_info.group_id in global_config.chat_target.talk_frequency_down_groups:
probability /= global_config.normal_chat.down_frequency_rate
self.temporary_willing = current_willing

View File

@@ -1,6 +1,6 @@
from src.common.logger import LogConfig, WILLING_STYLE_CONFIG, LoguruLogger, get_module_logger
from dataclasses import dataclass
from src.config.config import global_config, BotConfig
from src.config.config import global_config
from src.chat.message_receive.chat_stream import ChatStream, GroupInfo
from src.chat.message_receive.message import MessageRecv
from src.chat.person_info.person_info import person_info_manager, PersonInfoManager
@@ -93,7 +93,6 @@ class BaseWillingManager(ABC):
self.chat_reply_willing: Dict[str, float] = {} # 存储每个聊天流的回复意愿(chat_id)
self.ongoing_messages: Dict[str, WillingInfo] = {} # 当前正在进行的消息(message_id)
self.lock = asyncio.Lock()
self.global_config: BotConfig = global_config
self.logger: LoguruLogger = logger
def setup(self, message: MessageRecv, chat: ChatStream, is_mentioned_bot: bool, interested_rate: float):
@@ -173,7 +172,7 @@ def init_willing_manager() -> BaseWillingManager:
Returns:
对应mode的WillingManager实例
"""
mode = global_config.willing_mode.lower()
mode = global_config.normal_chat.willing_mode.lower()
return BaseWillingManager.create(mode)

View File

@@ -1,5 +1,6 @@
from src.common.logger_manager import get_logger
from ...common.database import db
from ...common.database.database import db
from ...common.database.database_model import PersonInfo # 新增导入
import copy
import hashlib
from typing import Any, Callable, Dict
@@ -16,7 +17,7 @@ matplotlib.use("Agg")
import matplotlib.pyplot as plt
from pathlib import Path
import pandas as pd
import json
import json # 新增导入
import re
@@ -38,47 +39,49 @@ logger = get_logger("person_info")
person_info_default = {
"person_id": None,
"person_name": None,
"person_name": None, # 模型中已设为 null=True此默认值OK
"name_reason": None,
"platform": None,
"user_id": None,
"nickname": None,
# "age" : 0,
"platform": "unknown", # 提供非None的默认值
"user_id": "unknown", # 提供非None的默认值
"nickname": "Unknown", # 提供非None的默认值
"relationship_value": 0,
# "saved" : True,
# "impression" : None,
# "gender" : Unkown,
"konw_time": 0,
"know_time": 0, # 修正拼写konw_time -> know_time
"msg_interval": 2000,
"msg_interval_list": [],
"user_cardname": None, # 添加群名片
"user_avatar": None, # 添加头像信息例如URL或标识符
} # 个人信息的各项与默认值在此定义,以下处理会自动创建/补全每一项
"msg_interval_list": [], # 将作为 JSON 字符串存储在 Peewee 的 TextField
"user_cardname": None, # 注意:此字段不在 PersonInfo Peewee 模型中
"user_avatar": None, # 注意:此字段不在 PersonInfo Peewee 模型中
}
class PersonInfoManager:
def __init__(self):
self.person_name_list = {}
# TODO: API-Adapter修改标记
self.qv_name_llm = LLMRequest(
model=global_config.llm_normal,
model=global_config.model.normal,
max_tokens=256,
request_type="qv_name",
)
if "person_info" not in db.list_collection_names():
db.create_collection("person_info")
db.person_info.create_index("person_id", unique=True)
try:
db.connect(reuse_if_open=True)
db.create_tables([PersonInfo], safe=True)
except Exception as e:
logger.error(f"数据库连接或 PersonInfo 表创建失败: {e}")
# 初始化时读取所有person_name
cursor = db.person_info.find({"person_name": {"$exists": True}}, {"person_id": 1, "person_name": 1, "_id": 0})
for doc in cursor:
if doc.get("person_name"):
self.person_name_list[doc["person_id"]] = doc["person_name"]
logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称")
try:
for record in PersonInfo.select(PersonInfo.person_id, PersonInfo.person_name).where(
PersonInfo.person_name.is_null(False)
):
if record.person_name:
self.person_name_list[record.person_id] = record.person_name
logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称 (Peewee)")
except Exception as e:
logger.error(f"从 Peewee 加载 person_name_list 失败: {e}")
@staticmethod
def get_person_id(platform: str, user_id: int):
"""获取唯一id"""
# 如果platform中存在-,就截取-后面的部分
if "-" in platform:
platform = platform.split("-")[1]
@@ -86,13 +89,17 @@ class PersonInfoManager:
key = "_".join(components)
return hashlib.md5(key.encode()).hexdigest()
def is_person_known(self, platform: str, user_id: int):
async def is_person_known(self, platform: str, user_id: int):
"""判断是否认识某人"""
person_id = self.get_person_id(platform, user_id)
document = db.person_info.find_one({"person_id": person_id})
if document:
return True
else:
def _db_check_known_sync(p_id: str):
return PersonInfo.get_or_none(PersonInfo.person_id == p_id) is not None
try:
return await asyncio.to_thread(_db_check_known_sync, person_id)
except Exception as e:
logger.error(f"检查用户 {person_id} 是否已知时出错 (Peewee): {e}")
return False
def get_person_id_by_person_name(self, person_name: str):
@@ -111,73 +118,111 @@ class PersonInfoManager:
return
_person_info_default = copy.deepcopy(person_info_default)
_person_info_default["person_id"] = person_id
model_fields = PersonInfo._meta.fields.keys()
final_data = {"person_id": person_id}
if data:
for key in _person_info_default:
if key != "person_id" and key in data:
_person_info_default[key] = data[key]
for key, value in data.items():
if key in model_fields:
final_data[key] = value
db.person_info.insert_one(_person_info_default)
for key, default_value in _person_info_default.items():
if key in model_fields and key not in final_data:
final_data[key] = default_value
if "msg_interval_list" in final_data and isinstance(final_data["msg_interval_list"], list):
final_data["msg_interval_list"] = json.dumps(final_data["msg_interval_list"])
elif "msg_interval_list" not in final_data and "msg_interval_list" in model_fields:
final_data["msg_interval_list"] = json.dumps([])
def _db_create_sync(p_data: dict):
try:
PersonInfo.create(**p_data)
return True
except Exception as e:
logger.error(f"创建 PersonInfo 记录 {p_data.get('person_id')} 失败 (Peewee): {e}")
return False
await asyncio.to_thread(_db_create_sync, final_data)
async def update_one_field(self, person_id: str, field_name: str, value, data: dict = None):
"""更新某一个字段,会补全"""
if field_name not in person_info_default.keys():
logger.debug(f"更新'{field_name}'失败,未定义的字段")
if field_name not in PersonInfo._meta.fields:
if field_name in person_info_default:
logger.debug(f"更新'{field_name}'跳过,字段存在于默认配置但不在 PersonInfo Peewee 模型中。")
return
logger.debug(f"更新'{field_name}'失败,未在 PersonInfo Peewee 模型中定义的字段。")
return
document = db.person_info.find_one({"person_id": person_id})
def _db_update_sync(p_id: str, f_name: str, val):
record = PersonInfo.get_or_none(PersonInfo.person_id == p_id)
if record:
if f_name == "msg_interval_list" and isinstance(val, list):
setattr(record, f_name, json.dumps(val))
else:
setattr(record, f_name, val)
record.save()
return True, False
return False, True
if document:
db.person_info.update_one({"person_id": person_id}, {"$set": {field_name: value}})
else:
data[field_name] = value
logger.debug(f"更新时{person_id}不存在,已新建")
await self.create_person_info(person_id, data)
found, needs_creation = await asyncio.to_thread(_db_update_sync, person_id, field_name, value)
if needs_creation:
logger.debug(f"更新时 {person_id} 不存在,将新建。")
creation_data = data if data is not None else {}
creation_data[field_name] = value
if "platform" not in creation_data or "user_id" not in creation_data:
logger.warning(f"{person_id} 创建记录时platform/user_id 可能缺失。")
await self.create_person_info(person_id, creation_data)
@staticmethod
async def has_one_field(person_id: str, field_name: str):
"""判断是否存在某一个字段"""
document = db.person_info.find_one({"person_id": person_id}, {field_name: 1})
if document:
return True
else:
if field_name not in PersonInfo._meta.fields:
logger.debug(f"检查字段'{field_name}'失败,未在 PersonInfo Peewee 模型中定义。")
return False
def _db_has_field_sync(p_id: str, f_name: str):
record = PersonInfo.get_or_none(PersonInfo.person_id == p_id)
if record:
return True
return False
try:
return await asyncio.to_thread(_db_has_field_sync, person_id, field_name)
except Exception as e:
logger.error(f"检查字段 {field_name} for {person_id} 时出错 (Peewee): {e}")
return False
@staticmethod
def _extract_json_from_text(text: str) -> dict:
"""从文本中提取JSON数据的高容错方法"""
try:
# 尝试直接解析
parsed_json = json.loads(text)
# 如果解析结果是列表,尝试取第一个元素
if isinstance(parsed_json, list):
if parsed_json: # 检查列表是否为空
if parsed_json:
parsed_json = parsed_json[0]
else: # 如果列表为空,重置为 None走后续逻辑
else:
parsed_json = None
# 确保解析结果是字典
if isinstance(parsed_json, dict):
return parsed_json
except json.JSONDecodeError:
# 解析失败,继续尝试其他方法
pass
except Exception as e:
logger.warning(f"尝试直接解析JSON时发生意外错误: {e}")
pass # 继续尝试其他方法
pass
# 如果直接解析失败或结果不是字典
try:
# 尝试找到JSON对象格式的部分
json_pattern = r"\{[^{}]*\}"
matches = re.findall(json_pattern, text)
if matches:
parsed_obj = json.loads(matches[0])
if isinstance(parsed_obj, dict): # 确保是字典
if isinstance(parsed_obj, dict):
return parsed_obj
# 如果上面都失败了,尝试提取键值对
nickname_pattern = r'"nickname"[:\s]+"([^"]+)"'
reason_pattern = r'"reason"[:\s]+"([^"]+)"'
@@ -192,7 +237,6 @@ class PersonInfoManager:
except Exception as e:
logger.error(f"后备JSON提取失败: {str(e)}")
# 如果所有方法都失败了,返回默认字典
logger.warning(f"无法从文本中提取有效的JSON字典: {text}")
return {"nickname": "", "reason": ""}
@@ -207,9 +251,11 @@ class PersonInfoManager:
old_name = await self.get_value(person_id, "person_name")
old_reason = await self.get_value(person_id, "name_reason")
max_retries = 5 # 最大重试次数
max_retries = 5
current_try = 0
existing_names = ""
existing_names_str = ""
current_name_set = set(self.person_name_list.values())
while current_try < max_retries:
individuality = Individuality.get_instance()
prompt_personality = individuality.get_prompt(x_person=2, level=1)
@@ -224,45 +270,58 @@ class PersonInfoManager:
qv_name_prompt += f"你之前叫他{old_name},是因为{old_reason}"
qv_name_prompt += f"\n其他取名的要求是:{request},不要太浮夸"
qv_name_prompt += (
"\n请根据以上用户信息想想你叫他什么比较好不要太浮夸请最好使用用户的qq昵称可以稍作修改"
)
if existing_names:
qv_name_prompt += f"\n请注意,以下名称已被使用,不要使用以下昵称:{existing_names}\n"
if existing_names_str:
qv_name_prompt += f"\n请注意,以下名称已被你尝试过或已知存在,请避免:{existing_names_str}\n"
if len(current_name_set) < 50 and current_name_set:
qv_name_prompt += f"已知的其他昵称有: {', '.join(list(current_name_set)[:10])}等。\n"
qv_name_prompt += "请用json给出你的想法并给出理由示例如下"
qv_name_prompt += """{
"nickname": "昵称",
"reason": "理由"
}"""
# logger.debug(f"取名提示词:{qv_name_prompt}")
response = await self.qv_name_llm.generate_response(qv_name_prompt)
logger.trace(f"取名提示词:{qv_name_prompt}\n取名回复:{response}")
result = self._extract_json_from_text(response[0])
if not result["nickname"]:
logger.error("生成的昵称为空,重试中...")
if not result or not result.get("nickname"):
logger.error("生成的昵称为空或结果格式不正确,重试中...")
current_try += 1
continue
# 检查生成的昵称是否已存在
if result["nickname"] not in self.person_name_list.values():
# 更新数据库和内存中的列表
await self.update_one_field(person_id, "person_name", result["nickname"])
# await self.update_one_field(person_id, "nickname", user_nickname)
# await self.update_one_field(person_id, "avatar", user_avatar)
await self.update_one_field(person_id, "name_reason", result["reason"])
generated_nickname = result["nickname"]
self.person_name_list[person_id] = result["nickname"]
# logger.debug(f"用户 {person_id} 的名称已更新为 {result['nickname']},原因:{result['reason']}")
is_duplicate = False
if generated_nickname in current_name_set:
is_duplicate = True
else:
def _db_check_name_exists_sync(name_to_check):
return PersonInfo.select().where(PersonInfo.person_name == name_to_check).exists()
if await asyncio.to_thread(_db_check_name_exists_sync, generated_nickname):
is_duplicate = True
current_name_set.add(generated_nickname)
if not is_duplicate:
await self.update_one_field(person_id, "person_name", generated_nickname)
await self.update_one_field(person_id, "name_reason", result.get("reason", "未提供理由"))
self.person_name_list[person_id] = generated_nickname
return result
else:
existing_names += f"{result['nickname']}"
if existing_names_str:
existing_names_str += ""
existing_names_str += generated_nickname
logger.debug(f"生成的昵称 {generated_nickname} 已存在,重试中...")
current_try += 1
logger.debug(f"生成的昵称 {result['nickname']} 已存在,重试中...")
current_try += 1
logger.error(f"{max_retries}次尝试后仍未能生成唯一昵称")
logger.error(f"{max_retries}次尝试后仍未能生成唯一昵称 for {person_id}")
return None
@staticmethod
@@ -272,30 +331,56 @@ class PersonInfoManager:
logger.debug("删除失败person_id 不能为空")
return
result = db.person_info.delete_one({"person_id": person_id})
if result.deleted_count > 0:
logger.debug(f"删除成功person_id={person_id}")
def _db_delete_sync(p_id: str):
try:
query = PersonInfo.delete().where(PersonInfo.person_id == p_id)
deleted_count = query.execute()
return deleted_count
except Exception as e:
logger.error(f"删除 PersonInfo {p_id} 失败 (Peewee): {e}")
return 0
deleted_count = await asyncio.to_thread(_db_delete_sync, person_id)
if deleted_count > 0:
logger.debug(f"删除成功person_id={person_id} (Peewee)")
else:
logger.debug(f"删除失败:未找到 person_id={person_id}")
logger.debug(f"删除失败:未找到 person_id={person_id} 或删除未影响行 (Peewee)")
@staticmethod
async def get_value(person_id: str, field_name: str):
"""获取指定person_id文档的字段值若不存在该字段则返回该字段的全局默认值"""
if not person_id:
logger.debug("get_value获取失败person_id不能为空")
return person_info_default.get(field_name)
if field_name not in PersonInfo._meta.fields:
if field_name in person_info_default:
logger.trace(f"字段'{field_name}'不在Peewee模型中但存在于默认配置中。返回配置默认值。")
return copy.deepcopy(person_info_default[field_name])
logger.debug(f"get_value获取失败字段'{field_name}'未在Peewee模型和默认配置中定义。")
return None
if field_name not in person_info_default:
logger.debug(f"get_value获取失败字段'{field_name}'未定义")
def _db_get_value_sync(p_id: str, f_name: str):
record = PersonInfo.get_or_none(PersonInfo.person_id == p_id)
if record:
val = getattr(record, f_name)
if f_name == "msg_interval_list" and isinstance(val, str):
try:
return json.loads(val)
except json.JSONDecodeError:
logger.warning(f"无法解析 {p_id} 的 msg_interval_list JSON: {val}")
return copy.deepcopy(person_info_default.get(f_name, []))
return val
return None
document = db.person_info.find_one({"person_id": person_id}, {field_name: 1})
value = await asyncio.to_thread(_db_get_value_sync, person_id, field_name)
if document and field_name in document:
return document[field_name]
if value is not None:
return value
else:
default_value = copy.deepcopy(person_info_default[field_name])
logger.trace(f"获取{person_id}{field_name}失败,已返回默认值{default_value}")
default_value = copy.deepcopy(person_info_default.get(field_name))
logger.trace(f"获取{person_id}{field_name}失败或值为None,已返回默认值{default_value} (Peewee)")
return default_value
@staticmethod
@@ -305,93 +390,84 @@ class PersonInfoManager:
logger.debug("get_values获取失败person_id不能为空")
return {}
# 检查所有字段是否有效
for field in field_names:
if field not in person_info_default:
logger.debug(f"get_values获取失败字段'{field}'未定义")
return {}
# 构建查询投影(所有字段都有效才会执行到这里)
projection = {field: 1 for field in field_names}
document = db.person_info.find_one({"person_id": person_id}, projection)
result = {}
for field in field_names:
result[field] = copy.deepcopy(
document.get(field, person_info_default[field]) if document else person_info_default[field]
)
def _db_get_record_sync(p_id: str):
return PersonInfo.get_or_none(PersonInfo.person_id == p_id)
record = await asyncio.to_thread(_db_get_record_sync, person_id)
for field_name in field_names:
if field_name not in PersonInfo._meta.fields:
if field_name in person_info_default:
result[field_name] = copy.deepcopy(person_info_default[field_name])
logger.trace(f"字段'{field_name}'不在Peewee模型中使用默认配置值。")
else:
logger.debug(f"get_values查询失败字段'{field_name}'未在Peewee模型和默认配置中定义。")
result[field_name] = None
continue
if record:
value = getattr(record, field_name)
if field_name == "msg_interval_list" and isinstance(value, str):
try:
result[field_name] = json.loads(value)
except json.JSONDecodeError:
logger.warning(f"无法解析 {person_id} 的 msg_interval_list JSON: {value}")
result[field_name] = copy.deepcopy(person_info_default.get(field_name, []))
elif value is not None:
result[field_name] = value
else:
result[field_name] = copy.deepcopy(person_info_default.get(field_name))
else:
result[field_name] = copy.deepcopy(person_info_default.get(field_name))
return result
@staticmethod
async def del_all_undefined_field():
"""删除所有项里的未定义字段"""
# 获取所有已定义的字段名
defined_fields = set(person_info_default.keys())
try:
# 遍历集合中的所有文档
for document in db.person_info.find({}):
# 找出文档中未定义的字段
undefined_fields = set(document.keys()) - defined_fields - {"_id"}
if undefined_fields:
# 构建更新操作,使用$unset删除未定义字段
update_result = db.person_info.update_one(
{"_id": document["_id"]}, {"$unset": {field: 1 for field in undefined_fields}}
)
if update_result.modified_count > 0:
logger.debug(f"已清理文档 {document['_id']} 的未定义字段: {undefined_fields}")
return
except Exception as e:
logger.error(f"清理未定义字段时出错: {e}")
return
"""删除所有项里的未定义字段 - 对于Peewee (SQL),此操作通常不适用,因为模式是固定的。"""
logger.info(
"del_all_undefined_field: 对于使用Peewee的SQL数据库此操作通常不适用或不需要因为表结构是预定义的。"
)
return
@staticmethod
async def get_specific_value_list(
field_name: str,
way: Callable[[Any], bool], # 接受任意类型值
way: Callable[[Any], bool],
) -> Dict[str, Any]:
"""
获取满足条件的字段值字典
Args:
field_name: 目标字段名
way: 判断函数 (value: Any) -> bool
Returns:
{person_id: value} | {}
Example:
# 查找所有nickname包含"admin"的用户
result = manager.specific_value_list(
"nickname",
lambda x: "admin" in x.lower()
)
"""
if field_name not in person_info_default:
logger.error(f"字段检查失败:'{field_name}'未定义")
if field_name not in PersonInfo._meta.fields:
logger.error(f"字段检查失败:'{field_name}'在 PersonInfo Peewee 模型中定义")
return {}
def _db_get_specific_sync(f_name: str):
found_results = {}
try:
for record in PersonInfo.select(PersonInfo.person_id, getattr(PersonInfo, f_name)):
value = getattr(record, f_name)
if f_name == "msg_interval_list" and isinstance(value, str):
try:
processed_value = json.loads(value)
except json.JSONDecodeError:
logger.warning(f"跳过记录 {record.person_id},无法解析 msg_interval_list: {value}")
continue
else:
processed_value = value
if way(processed_value):
found_results[record.person_id] = processed_value
except Exception as e_query:
logger.error(f"数据库查询失败 (Peewee specific_value_list for {f_name}): {str(e_query)}", exc_info=True)
return found_results
try:
result = {}
for doc in db.person_info.find({field_name: {"$exists": True}}, {"person_id": 1, field_name: 1, "_id": 0}):
try:
value = doc[field_name]
if way(value):
result[doc["person_id"]] = value
except (KeyError, TypeError, ValueError) as e:
logger.debug(f"记录{doc.get('person_id')}处理失败: {str(e)}")
continue
return result
return await asyncio.to_thread(_db_get_specific_sync, field_name)
except Exception as e:
logger.error(f"数据库查询失败: {str(e)}", exc_info=True)
logger.error(f"执行 get_specific_value_list 线程时出错: {str(e)}", exc_info=True)
return {}
async def personal_habit_deduction(self):
@@ -399,35 +475,31 @@ class PersonInfoManager:
try:
while 1:
await asyncio.sleep(600)
current_time = datetime.datetime.now()
logger.info(f"个人信息推断启动: {current_time.strftime('%Y-%m-%d %H:%M:%S')}")
current_time_dt = datetime.datetime.now()
logger.info(f"个人信息推断启动: {current_time_dt.strftime('%Y-%m-%d %H:%M:%S')}")
# "msg_interval"推断
msg_interval_map = False
msg_interval_lists = await self.get_specific_value_list(
msg_interval_map_generated = False
msg_interval_lists_map = await self.get_specific_value_list(
"msg_interval_list", lambda x: isinstance(x, list) and len(x) >= 100
)
for person_id, msg_interval_list_ in msg_interval_lists.items():
for person_id, actual_msg_interval_list in msg_interval_lists_map.items():
await asyncio.sleep(0.3)
try:
time_interval = []
for t1, t2 in zip(msg_interval_list_, msg_interval_list_[1:]):
for t1, t2 in zip(actual_msg_interval_list, actual_msg_interval_list[1:]):
delta = t2 - t1
if delta > 0:
time_interval.append(delta)
time_interval = [t for t in time_interval if 200 <= t <= 8000]
# --- 修改后的逻辑 ---
# 数据量检查 (至少需要 30 条有效间隔,并且足够进行头尾截断)
if len(time_interval) >= 30 + 10: # 至少30条有效+头尾各5条
time_interval.sort()
# 画图(log) - 这部分保留
msg_interval_map = True
if len(time_interval) >= 30 + 10:
time_interval.sort()
msg_interval_map_generated = True
log_dir = Path("logs/person_info")
log_dir.mkdir(parents=True, exist_ok=True)
plt.figure(figsize=(10, 6))
# 使用截断前的数据画图,更能反映原始分布
time_series_original = pd.Series(time_interval)
plt.hist(
time_series_original,
@@ -449,34 +521,29 @@ class PersonInfoManager:
img_path = log_dir / f"interval_distribution_{person_id[:8]}.png"
plt.savefig(img_path)
plt.close()
# 画图结束
# 去掉头尾各 5 个数据点
trimmed_interval = time_interval[5:-5]
# 计算截断后数据的 37% 分位数
if trimmed_interval: # 确保截断后列表不为空
msg_interval = int(round(np.percentile(trimmed_interval, 37)))
# 更新数据库
await self.update_one_field(person_id, "msg_interval", msg_interval)
logger.trace(f"用户{person_id}的msg_interval通过头尾截断和37分位数更新为{msg_interval}")
if trimmed_interval:
msg_interval_val = int(round(np.percentile(trimmed_interval, 37)))
await self.update_one_field(person_id, "msg_interval", msg_interval_val)
logger.trace(
f"用户{person_id}的msg_interval通过头尾截断和37分位数更新为{msg_interval_val}"
)
else:
logger.trace(f"用户{person_id}截断后数据为空无法计算msg_interval")
else:
logger.trace(
f"用户{person_id}有效消息间隔数量 ({len(time_interval)}) 不足进行推断 (需要至少 {30 + 10} 条)"
)
# --- 修改结束 ---
except Exception as e:
logger.trace(f"用户{person_id}消息间隔计算失败: {type(e).__name__}: {str(e)}")
except Exception as e_inner:
logger.trace(f"用户{person_id}消息间隔计算失败: {type(e_inner).__name__}: {str(e_inner)}")
continue
# 其他...
if msg_interval_map:
if msg_interval_map_generated:
logger.trace("已保存分布图到: logs/person_info")
current_time = datetime.datetime.now()
logger.trace(f"个人信息推断结束: {current_time.strftime('%Y-%m-%d %H:%M:%S')}")
current_time_dt_end = datetime.datetime.now()
logger.trace(f"个人信息推断结束: {current_time_dt_end.strftime('%Y-%m-%d %H:%M:%S')}")
await asyncio.sleep(86400)
except Exception as e:
@@ -489,41 +556,27 @@ class PersonInfoManager:
"""
根据 platform 和 user_id 获取 person_id。
如果对应的用户不存在,则使用提供的可选信息创建新用户。
Args:
platform: 平台标识
user_id: 用户在该平台上的ID
nickname: 用户的昵称 (可选,用于创建新用户)
user_cardname: 用户的群名片 (可选,用于创建新用户)
user_avatar: 用户的头像信息 (可选,用于创建新用户)
Returns:
对应的 person_id。
"""
person_id = self.get_person_id(platform, user_id)
# 检查用户是否已存在
# 使用静态方法 get_person_id因此可以直接调用 db
document = db.person_info.find_one({"person_id": person_id})
def _db_check_exists_sync(p_id: str):
return PersonInfo.get_or_none(PersonInfo.person_id == p_id)
if document is None:
logger.info(f"用户 {platform}:{user_id} (person_id: {person_id}) 不存在,将创建新记录。")
record = await asyncio.to_thread(_db_check_exists_sync, person_id)
if record is None:
logger.info(f"用户 {platform}:{user_id} (person_id: {person_id}) 不存在,将创建新记录 (Peewee)。")
initial_data = {
"platform": platform,
"user_id": user_id,
"user_id": str(user_id),
"nickname": nickname,
"konw_time": int(datetime.datetime.now().timestamp()), # 添加初次认识时间
# 注意:这里没有添加 user_cardname 和 user_avatar因为它们不在 person_info_default 中
# 如果需要存储它们,需要先在 person_info_default 中定义
"know_time": int(datetime.datetime.now().timestamp()), # 修正拼写konw_time -> know_time
}
# 过滤掉值为 None 的初始数据
initial_data = {k: v for k, v in initial_data.items() if v is not None}
model_fields = PersonInfo._meta.fields.keys()
filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields}
# 注意create_person_info 是静态方法
await PersonInfoManager.create_person_info(person_id, data=initial_data)
# 创建后,可以考虑立即为其取名,但这可能会增加延迟
# await self.qv_person_name(person_id, nickname, user_cardname, user_avatar)
logger.debug(f"已为 {person_id} 创建新记录,初始数据: {initial_data}")
await self.create_person_info(person_id, data=filtered_initial_data)
logger.debug(f"已为 {person_id} 创建新记录,初始数据 (filtered for model): {filtered_initial_data}")
return person_id
@@ -533,35 +586,55 @@ class PersonInfoManager:
logger.debug("get_person_info_by_name 获取失败person_name 不能为空")
return None
# 优先从内存缓存查找 person_id
found_person_id = None
for pid, name in self.person_name_list.items():
if name == person_name:
for pid, name_in_cache in self.person_name_list.items():
if name_in_cache == person_name:
found_person_id = pid
break # 找到第一个匹配就停止
break
if not found_person_id:
# 如果内存没有,尝试数据库查询(可能内存未及时更新或启动时未加载)
document = db.person_info.find_one({"person_name": person_name})
if document:
found_person_id = document.get("person_id")
else:
logger.debug(f"数据库中也未找到名为 '{person_name}' 的用户")
return None # 数据库也找不到
# 根据找到的 person_id 获取所需信息
if found_person_id:
required_fields = ["person_id", "platform", "user_id", "nickname", "user_cardname", "user_avatar"]
person_data = await self.get_values(found_person_id, required_fields)
if person_data: # 确保 get_values 成功返回
return person_data
def _db_find_by_name_sync(p_name_to_find: str):
return PersonInfo.get_or_none(PersonInfo.person_name == p_name_to_find)
record = await asyncio.to_thread(_db_find_by_name_sync, person_name)
if record:
found_person_id = record.person_id
if (
found_person_id not in self.person_name_list
or self.person_name_list[found_person_id] != person_name
):
self.person_name_list[found_person_id] = person_name
else:
logger.warning(f"找到了 person_id '{found_person_id}' 但获取详细信息失败")
logger.debug(f"数据库中也未找到名为 '{person_name}' 的用户 (Peewee)")
return None
else:
# 这理论上不应该发生,因为上面已经处理了找不到的情况
logger.error(f"逻辑错误:未能为 '{person_name}' 确定 person_id")
return None
if found_person_id:
required_fields = [
"person_id",
"platform",
"user_id",
"nickname",
"user_cardname",
"user_avatar",
"person_name",
"name_reason",
]
valid_fields_to_get = [
f for f in required_fields if f in PersonInfo._meta.fields or f in person_info_default
]
person_data = await self.get_values(found_person_id, valid_fields_to_get)
if person_data:
final_result = {key: person_data.get(key) for key in required_fields}
return final_result
else:
logger.warning(f"找到了 person_id '{found_person_id}' 但 get_values 返回空 (Peewee)")
return None
logger.error(f"逻辑错误:未能为 '{person_name}' 确定 person_id (Peewee)")
return None
person_info_manager = PersonInfoManager()

View File

@@ -190,8 +190,8 @@ async def _build_readable_messages_internal(
person_id = person_info_manager.get_person_id(platform, user_id)
# 根据 replace_bot_name 参数决定是否替换机器人名称
if replace_bot_name and user_id == global_config.BOT_QQ:
person_name = f"{global_config.BOT_NICKNAME}(你)"
if replace_bot_name and user_id == global_config.bot.qq_account:
person_name = f"{global_config.bot.nickname}(你)"
else:
person_name = await person_info_manager.get_value(person_id, "person_name")
@@ -427,7 +427,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
output_lines = []
def get_anon_name(platform, user_id):
if user_id == global_config.BOT_QQ:
if user_id == global_config.bot.qq_account:
return "SELF"
person_id = person_info_manager.get_person_id(platform, user_id)
if person_id not in person_map:
@@ -454,7 +454,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
def reply_replacer(match, platform=platform):
# aaa = match.group(1)
bbb = match.group(2)
anon_reply = get_anon_name(platform, bbb)
anon_reply = get_anon_name(platform, bbb) # noqa
return f"回复 {anon_reply}"
content = re.sub(reply_pattern, reply_replacer, content, count=1)
@@ -465,7 +465,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
def at_replacer(match, platform=platform):
# aaa = match.group(1)
bbb = match.group(2)
anon_at = get_anon_name(platform, bbb)
anon_at = get_anon_name(platform, bbb) # noqa
return f"@{anon_at}"
content = re.sub(at_pattern, at_replacer, content)
@@ -501,7 +501,7 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
user_id = user_info.get("user_id")
# 检查必要信息是否存在 且 不是机器人自己
if not all([platform, user_id]) or user_id == global_config.BOT_QQ:
if not all([platform, user_id]) or user_id == global_config.bot.qq_account:
continue
person_id = person_info_manager.get_person_id(platform, user_id)

View File

@@ -1,15 +1,15 @@
from src.config.config import global_config
from src.chat.message_receive.message import MessageRecv, MessageSending, Message
from src.common.database import db
from src.common.database.database_model import Messages, ThinkingLog
import time
import traceback
from typing import List
import json
class InfoCatcher:
def __init__(self):
self.chat_history = [] # 聊天历史,长度为三倍使用的上下文喵~
self.context_length = global_config.observation_context_size
self.chat_history_in_thinking = [] # 思考期间的聊天内容喵~
self.chat_history_after_response = [] # 回复后的聊天内容,长度为一倍上下文喵~
@@ -60,8 +60,6 @@ class InfoCatcher:
def catch_after_observe(self, obs_duration: float): # 这里可以有更多信息
self.timing_results["sub_heartflow_observe_time"] = obs_duration
# def catch_shf
def catch_afer_shf_step(self, step_duration: float, past_mind: str, current_mind: str):
self.timing_results["sub_heartflow_step_time"] = step_duration
if len(past_mind) > 1:
@@ -72,25 +70,10 @@ class InfoCatcher:
self.heartflow_data["sub_heartflow_now"] = current_mind
def catch_after_llm_generated(self, prompt: str, response: str, reasoning_content: str = "", model_name: str = ""):
# if self.response_mode == "heart_flow": # 条件判断不需要了喵~
# self.heartflow_data["prompt"] = prompt
# self.heartflow_data["response"] = response
# self.heartflow_data["model"] = model_name
# elif self.response_mode == "reasoning": # 条件判断不需要了喵~
# self.reasoning_data["thinking_log"] = reasoning_content
# self.reasoning_data["prompt"] = prompt
# self.reasoning_data["response"] = response
# self.reasoning_data["model"] = model_name
# 直接记录信息喵~
self.reasoning_data["thinking_log"] = reasoning_content
self.reasoning_data["prompt"] = prompt
self.reasoning_data["response"] = response
self.reasoning_data["model"] = model_name
# 如果 heartflow 数据也需要通用字段,可以取消下面的注释喵~
# self.heartflow_data["prompt"] = prompt
# self.heartflow_data["response"] = response
# self.heartflow_data["model"] = model_name
self.response_text = response
@@ -102,6 +85,7 @@ class InfoCatcher:
):
self.timing_results["make_response_time"] = response_duration
self.response_time = time.time()
self.response_messages = []
for msg in response_message:
self.response_messages.append(msg)
@@ -112,107 +96,112 @@ class InfoCatcher:
@staticmethod
def get_message_from_db_between_msgs(message_start: Message, message_end: Message):
try:
# 从数据库中获取消息的时间戳
time_start = message_start.message_info.time
time_end = message_end.message_info.time
chat_id = message_start.chat_stream.stream_id
print(f"查询参数: time_start={time_start}, time_end={time_end}, chat_id={chat_id}")
# 查询数据库,获取 chat_id 相同且时间在 start 和 end 之间的数据
messages_between = db.messages.find(
{"chat_id": chat_id, "time": {"$gt": time_start, "$lt": time_end}}
).sort("time", -1)
messages_between_query = (
Messages.select()
.where((Messages.chat_id == chat_id) & (Messages.time > time_start) & (Messages.time < time_end))
.order_by(Messages.time.desc())
)
result = list(messages_between)
result = list(messages_between_query)
print(f"查询结果数量: {len(result)}")
if result:
print(f"第一条消息时间: {result[0]['time']}")
print(f"最后一条消息时间: {result[-1]['time']}")
print(f"第一条消息时间: {result[0].time}")
print(f"最后一条消息时间: {result[-1].time}")
return result
except Exception as e:
print(f"获取消息时出错: {str(e)}")
print(traceback.format_exc())
return []
def get_message_from_db_before_msg(self, message: MessageRecv):
# 从数据库中获取消息
message_id = message.message_info.message_id
chat_id = message.chat_stream.stream_id
message_id_val = message.message_info.message_id
chat_id_val = message.chat_stream.stream_id
# 查询数据库,获取 chat_id 相同且 message_id 小于当前消息的 30 条数据
messages_before = (
db.messages.find({"chat_id": chat_id, "message_id": {"$lt": message_id}})
.sort("time", -1)
.limit(self.context_length * 3)
) # 获取更多历史信息
messages_before_query = (
Messages.select()
.where((Messages.chat_id == chat_id_val) & (Messages.message_id < message_id_val))
.order_by(Messages.time.desc())
.limit(global_config.chat.observation_context_size * 3)
)
return list(messages_before)
return list(messages_before_query)
def message_list_to_dict(self, message_list):
# 存储简化的聊天记录
result = []
for message in message_list:
if not isinstance(message, dict):
message = self.message_to_dict(message)
# print(message)
for msg_item in message_list:
processed_msg_item = msg_item
if not isinstance(msg_item, dict):
processed_msg_item = self.message_to_dict(msg_item)
if not processed_msg_item:
continue
lite_message = {
"time": message["time"],
"user_nickname": message["user_info"]["user_nickname"],
"processed_plain_text": message["processed_plain_text"],
"time": processed_msg_item.get("time"),
"user_nickname": processed_msg_item.get("user_nickname"),
"processed_plain_text": processed_msg_item.get("processed_plain_text"),
}
result.append(lite_message)
return result
@staticmethod
def message_to_dict(message):
if not message:
def message_to_dict(msg_obj):
if not msg_obj:
return None
if isinstance(message, dict):
return message
return {
# "message_id": message.message_info.message_id,
"time": message.message_info.time,
"user_id": message.message_info.user_info.user_id,
"user_nickname": message.message_info.user_info.user_nickname,
"processed_plain_text": message.processed_plain_text,
# "detailed_plain_text": message.detailed_plain_text
}
if isinstance(msg_obj, dict):
return msg_obj
def done_catch(self):
"""将收集到的信息存储到数据库的 thinking_log 集合中喵~"""
try:
# 将消息对象转换为可序列化的字典喵~
thinking_log_data = {
"chat_id": self.chat_id,
"trigger_text": self.trigger_response_text,
"response_text": self.response_text,
"trigger_info": {
"time": self.trigger_response_time,
"message": self.message_to_dict(self.trigger_response_message),
},
"response_info": {
"time": self.response_time,
"message": self.response_messages,
},
"timing_results": self.timing_results,
"chat_history": self.message_list_to_dict(self.chat_history),
"chat_history_in_thinking": self.message_list_to_dict(self.chat_history_in_thinking),
"chat_history_after_response": self.message_list_to_dict(self.chat_history_after_response),
"heartflow_data": self.heartflow_data,
"reasoning_data": self.reasoning_data,
if isinstance(msg_obj, Messages):
return {
"time": msg_obj.time,
"user_id": msg_obj.user_id,
"user_nickname": msg_obj.user_nickname,
"processed_plain_text": msg_obj.processed_plain_text,
}
# 根据不同的响应模式添加相应的数据喵~ # 现在直接都加上去好了喵~
# if self.response_mode == "heart_flow":
# thinking_log_data["mode_specific_data"] = self.heartflow_data
# elif self.response_mode == "reasoning":
# thinking_log_data["mode_specific_data"] = self.reasoning_data
if hasattr(msg_obj, "message_info") and hasattr(msg_obj.message_info, "user_info"):
return {
"time": msg_obj.message_info.time,
"user_id": msg_obj.message_info.user_info.user_id,
"user_nickname": msg_obj.message_info.user_info.user_nickname,
"processed_plain_text": msg_obj.processed_plain_text,
}
# 将数据插入到 thinking_log 集合中喵~
db.thinking_log.insert_one(thinking_log_data)
print(f"Warning: message_to_dict received an unhandled type: {type(msg_obj)}")
return {}
def done_catch(self):
"""将收集到的信息存储到数据库的 thinking_log 表中喵~"""
try:
trigger_info_dict = self.message_to_dict(self.trigger_response_message)
response_info_dict = {
"time": self.response_time,
"message": self.response_messages,
}
chat_history_list = self.message_list_to_dict(self.chat_history)
chat_history_in_thinking_list = self.message_list_to_dict(self.chat_history_in_thinking)
chat_history_after_response_list = self.message_list_to_dict(self.chat_history_after_response)
log_entry = ThinkingLog(
chat_id=self.chat_id,
trigger_text=self.trigger_response_text,
response_text=self.response_text,
trigger_info_json=json.dumps(trigger_info_dict) if trigger_info_dict else None,
response_info_json=json.dumps(response_info_dict),
timing_results_json=json.dumps(self.timing_results),
chat_history_json=json.dumps(chat_history_list),
chat_history_in_thinking_json=json.dumps(chat_history_in_thinking_list),
chat_history_after_response_json=json.dumps(chat_history_after_response_list),
heartflow_data_json=json.dumps(self.heartflow_data),
reasoning_data_json=json.dumps(self.reasoning_data),
)
log_entry.save()
return True
except Exception as e:

View File

@@ -2,10 +2,12 @@ from collections import defaultdict
from datetime import datetime, timedelta
from typing import Any, Dict, Tuple, List
from src.common.logger import get_module_logger
from src.manager.async_task_manager import AsyncTask
from ...common.database import db
from ...common.database.database import db # This db is the Peewee database instance
from ...common.database.database_model import OnlineTime, LLMUsage, Messages # Import the Peewee model
from src.manager.local_store_manager import local_storage
logger = get_module_logger("maibot_statistic")
@@ -39,7 +41,7 @@ class OnlineTimeRecordTask(AsyncTask):
def __init__(self):
super().__init__(task_name="Online Time Record Task", run_interval=60)
self.record_id: str | None = None
self.record_id: int | None = None # Changed to int for Peewee's default ID
"""记录ID"""
self._init_database() # 初始化数据库
@@ -47,49 +49,46 @@ class OnlineTimeRecordTask(AsyncTask):
@staticmethod
def _init_database():
"""初始化数据库"""
if "online_time" not in db.list_collection_names():
# 初始化数据库(在线时长)
db.create_collection("online_time")
# 创建索引
if ("end_timestamp", 1) not in db.online_time.list_indexes():
db.online_time.create_index([("end_timestamp", 1)])
with db.atomic(): # Use atomic operations for schema changes
OnlineTime.create_table(safe=True) # Creates table if it doesn't exist, Peewee handles indexes from model
async def run(self):
try:
current_time = datetime.now()
extended_end_time = current_time + timedelta(minutes=1)
if self.record_id:
# 如果有记录,则更新结束时间
db.online_time.update_one(
{"_id": self.record_id},
{
"$set": {
"end_timestamp": datetime.now() + timedelta(minutes=1),
}
},
)
else:
query = OnlineTime.update(end_timestamp=extended_end_time).where(OnlineTime.id == self.record_id)
updated_rows = query.execute()
if updated_rows == 0:
# Record might have been deleted or ID is stale, try to find/create
self.record_id = None # Reset record_id to trigger find/create logic below
if not self.record_id: # Check again if record_id was reset or initially None
# 如果没有记录,检查一分钟以内是否已有记录
current_time = datetime.now()
if recent_record := db.online_time.find_one(
{"end_timestamp": {"$gte": current_time - timedelta(minutes=1)}}
):
# Look for a record whose end_timestamp is recent enough to be considered ongoing
recent_record = (
OnlineTime.select()
.where(OnlineTime.end_timestamp >= (current_time - timedelta(minutes=1)))
.order_by(OnlineTime.end_timestamp.desc())
.first()
)
if recent_record:
# 如果有记录,则更新结束时间
self.record_id = recent_record["_id"]
db.online_time.update_one(
{"_id": self.record_id},
{
"$set": {
"end_timestamp": current_time + timedelta(minutes=1),
}
},
)
self.record_id = recent_record.id
recent_record.end_timestamp = extended_end_time
recent_record.save()
else:
# 若没有记录,则插入新的在线时间记录
self.record_id = db.online_time.insert_one(
{
"start_timestamp": current_time,
"end_timestamp": current_time + timedelta(minutes=1),
}
).inserted_id
new_record = OnlineTime.create(
timestamp=current_time.timestamp(), # 添加此行
start_timestamp=current_time,
end_timestamp=extended_end_time,
duration=5, # 初始时长为5分钟
)
self.record_id = new_record.id
except Exception as e:
logger.error(f"在线时间记录失败,错误信息:{e}")
@@ -201,35 +200,28 @@ class StatisticOutputTask(AsyncTask):
:param collect_period: 统计时间段
"""
if len(collect_period) <= 0:
if not collect_period:
return {}
else:
# 排序-按照时间段开始时间降序排列(最晚的时间段在前)
collect_period.sort(key=lambda x: x[1], reverse=True)
# 排序-按照时间段开始时间降序排列(最晚的时间段在前)
collect_period.sort(key=lambda x: x[1], reverse=True)
stats = {
period_key: {
# 总LLM请求数
TOTAL_REQ_CNT: 0,
# 请求次数统计
REQ_CNT_BY_TYPE: defaultdict(int),
REQ_CNT_BY_USER: defaultdict(int),
REQ_CNT_BY_MODEL: defaultdict(int),
# 输入Token数
IN_TOK_BY_TYPE: defaultdict(int),
IN_TOK_BY_USER: defaultdict(int),
IN_TOK_BY_MODEL: defaultdict(int),
# 输出Token数
OUT_TOK_BY_TYPE: defaultdict(int),
OUT_TOK_BY_USER: defaultdict(int),
OUT_TOK_BY_MODEL: defaultdict(int),
# 总Token数
TOTAL_TOK_BY_TYPE: defaultdict(int),
TOTAL_TOK_BY_USER: defaultdict(int),
TOTAL_TOK_BY_MODEL: defaultdict(int),
# 总开销
TOTAL_COST: 0.0,
# 请求开销统计
COST_BY_TYPE: defaultdict(float),
COST_BY_USER: defaultdict(float),
COST_BY_MODEL: defaultdict(float),
@@ -238,26 +230,26 @@ class StatisticOutputTask(AsyncTask):
}
# 以最早的时间戳为起始时间获取记录
for record in db.llm_usage.find({"timestamp": {"$gte": collect_period[-1][1]}}):
record_timestamp = record.get("timestamp")
# Assuming LLMUsage.timestamp is a DateTimeField
query_start_time = collect_period[-1][1]
for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time):
record_timestamp = record.timestamp # This is already a datetime object
for idx, (_, period_start) in enumerate(collect_period):
if record_timestamp >= period_start:
# 如果记录时间在当前时间段内,则它一定在更早的时间段内
# 因此,我们可以直接跳过更早的时间段的判断,直接更新当前以及更早时间段的统计数据
for period_key, _ in collect_period[idx:]:
stats[period_key][TOTAL_REQ_CNT] += 1
request_type = record.get("request_type", "unknown") # 请求类型
user_id = str(record.get("user_id", "unknown")) # 用户ID
model_name = record.get("model_name", "unknown") # 模型名称
request_type = record.request_type or "unknown"
user_id = record.user_id or "unknown" # user_id is TextField, already string
model_name = record.model_name or "unknown"
stats[period_key][REQ_CNT_BY_TYPE][request_type] += 1
stats[period_key][REQ_CNT_BY_USER][user_id] += 1
stats[period_key][REQ_CNT_BY_MODEL][model_name] += 1
prompt_tokens = record.get("prompt_tokens", 0) # 输入Token数
completion_tokens = record.get("completion_tokens", 0) # 输出Token数
total_tokens = prompt_tokens + completion_tokens # Token总数 = 输入Token数 + 输出Token数
prompt_tokens = record.prompt_tokens or 0
completion_tokens = record.completion_tokens or 0
total_tokens = prompt_tokens + completion_tokens
stats[period_key][IN_TOK_BY_TYPE][request_type] += prompt_tokens
stats[period_key][IN_TOK_BY_USER][user_id] += prompt_tokens
@@ -271,13 +263,12 @@ class StatisticOutputTask(AsyncTask):
stats[period_key][TOTAL_TOK_BY_USER][user_id] += total_tokens
stats[period_key][TOTAL_TOK_BY_MODEL][model_name] += total_tokens
cost = record.get("cost", 0.0)
cost = record.cost or 0.0
stats[period_key][TOTAL_COST] += cost
stats[period_key][COST_BY_TYPE][request_type] += cost
stats[period_key][COST_BY_USER][user_id] += cost
stats[period_key][COST_BY_MODEL][model_name] += cost
break # 取消更早时间段的判断
break
return stats
@staticmethod
@@ -287,39 +278,38 @@ class StatisticOutputTask(AsyncTask):
:param collect_period: 统计时间段
"""
if len(collect_period) <= 0:
if not collect_period:
return {}
else:
# 排序-按照时间段开始时间降序排列(最晚的时间段在前)
collect_period.sort(key=lambda x: x[1], reverse=True)
collect_period.sort(key=lambda x: x[1], reverse=True)
stats = {
period_key: {
# 在线时间统计
ONLINE_TIME: 0.0,
}
for period_key, _ in collect_period
}
# 统计在线时间
for record in db.online_time.find({"end_timestamp": {"$gte": collect_period[-1][1]}}):
end_timestamp: datetime = record.get("end_timestamp")
for idx, (_, period_start) in enumerate(collect_period):
if end_timestamp >= period_start:
# 由于end_timestamp会超前标记时间所以我们需要判断是否晚于当前时间如果是则使用当前时间作为结束时间
end_timestamp = min(end_timestamp, now)
# 如果记录时间在当前时间段内,则它一定在更早的时间段内
# 因此,我们可以直接跳过更早的时间段的判断,直接更新当前以及更早时间段的统计数据
for period_key, _period_start in collect_period[idx:]:
start_timestamp: datetime = record.get("start_timestamp")
if start_timestamp < _period_start:
# 如果开始时间在查询边界之前,则使用开始时间
stats[period_key][ONLINE_TIME] += (end_timestamp - _period_start).total_seconds()
else:
# 否则,使用开始时间
stats[period_key][ONLINE_TIME] += (end_timestamp - start_timestamp).total_seconds()
break # 取消更早时间段的判断
query_start_time = collect_period[-1][1]
# Assuming OnlineTime.end_timestamp is a DateTimeField
for record in OnlineTime.select().where(OnlineTime.end_timestamp >= query_start_time):
# record.end_timestamp and record.start_timestamp are datetime objects
record_end_timestamp = record.end_timestamp
record_start_timestamp = record.start_timestamp
for idx, (_, period_boundary_start) in enumerate(collect_period):
if record_end_timestamp >= period_boundary_start:
# Calculate effective end time for this record in relation to 'now'
effective_end_time = min(record_end_timestamp, now)
for period_key, current_period_start_time in collect_period[idx:]:
# Determine the portion of the record that falls within this specific statistical period
overlap_start = max(record_start_timestamp, current_period_start_time)
overlap_end = effective_end_time # Already capped by 'now' and record's own end
if overlap_end > overlap_start:
stats[period_key][ONLINE_TIME] += (overlap_end - overlap_start).total_seconds()
break
return stats
def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
@@ -328,55 +318,57 @@ class StatisticOutputTask(AsyncTask):
:param collect_period: 统计时间段
"""
if len(collect_period) <= 0:
if not collect_period:
return {}
else:
# 排序-按照时间段开始时间降序排列(最晚的时间段在前)
collect_period.sort(key=lambda x: x[1], reverse=True)
collect_period.sort(key=lambda x: x[1], reverse=True)
stats = {
period_key: {
# 消息统计
TOTAL_MSG_CNT: 0,
MSG_CNT_BY_CHAT: defaultdict(int),
}
for period_key, _ in collect_period
}
# 统计消息量
for message in db.messages.find({"time": {"$gte": collect_period[-1][1].timestamp()}}):
chat_info = message.get("chat_info", None) # 聊天信息
user_info = message.get("user_info", None) # 用户信息(消息发送人)
message_time = message.get("time", 0) # 消息时间
query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp)
for message in Messages.select().where(Messages.time >= query_start_timestamp):
message_time_ts = message.time # This is a float timestamp
group_info = chat_info.get("group_info") if chat_info else None # 尝试获取群聊信息
if group_info is not None:
# 若有群聊信息
chat_id = f"g{group_info.get('group_id')}"
chat_name = group_info.get("group_name", f"{group_info.get('group_id')}")
elif user_info:
# 若没有群聊信息,则尝试获取用户信息
chat_id = f"u{user_info['user_id']}"
chat_name = user_info["user_nickname"]
chat_id = None
chat_name = None
# Logic based on Peewee model structure, aiming to replicate original intent
if message.chat_info_group_id:
chat_id = f"g{message.chat_info_group_id}"
chat_name = message.chat_info_group_name or f"{message.chat_info_group_id}"
elif message.user_id: # Fallback to sender's info for chat_id if not a group_info based chat
# This uses the message SENDER's ID as per original logic's fallback
chat_id = f"u{message.user_id}" # SENDER's user_id
chat_name = message.user_nickname # SENDER's nickname
else:
continue # 如果没有群组信息也没有用户信息,则跳过
# If neither group_id nor sender_id is available for chat identification
logger.warning(
f"Message (PK: {message.id if hasattr(message, 'id') else 'N/A'}) lacks group_id and user_id for chat stats."
)
continue
if not chat_id: # Should not happen if above logic is correct
continue
# Update name_mapping
if chat_id in self.name_mapping:
if chat_name != self.name_mapping[chat_id][0] and message_time > self.name_mapping[chat_id][1]:
# 如果用户名称不同,且新消息时间晚于之前记录的时间,则更新用户名称
self.name_mapping[chat_id] = (chat_name, message_time)
if chat_name != self.name_mapping[chat_id][0] and message_time_ts > self.name_mapping[chat_id][1]:
self.name_mapping[chat_id] = (chat_name, message_time_ts)
else:
self.name_mapping[chat_id] = (chat_name, message_time)
self.name_mapping[chat_id] = (chat_name, message_time_ts)
for idx, (_, period_start) in enumerate(collect_period):
if message_time >= period_start.timestamp():
# 如果记录时间在当前时间段内,则它一定在更早的时间段内
# 因此,我们可以直接跳过更早的时间段的判断,直接更新当前以及更早时间段的统计数据
for idx, (_, period_start_dt) in enumerate(collect_period):
if message_time_ts >= period_start_dt.timestamp():
for period_key, _ in collect_period[idx:]:
stats[period_key][TOTAL_MSG_CNT] += 1
stats[period_key][MSG_CNT_BY_CHAT][chat_id] += 1
break
return stats
def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]:

View File

@@ -13,7 +13,7 @@ from src.manager.mood_manager import mood_manager
from ..message_receive.message import MessageRecv
from ..models.utils_model import LLMRequest
from .typo_generator import ChineseTypoGenerator
from ...common.database import db
from ...common.database.database import db
from ...config.config import global_config
logger = get_module_logger("chat_utils")
@@ -43,8 +43,8 @@ def db_message_to_str(message_dict: dict) -> str:
def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
"""检查消息是否提到了机器人"""
keywords = [global_config.BOT_NICKNAME]
nicknames = global_config.BOT_ALIAS_NAMES
keywords = [global_config.bot.nickname]
nicknames = global_config.bot.alias_names
reply_probability = 0.0
is_at = False
is_mentioned = False
@@ -64,18 +64,18 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
)
# 判断是否被@
if re.search(f"@[\s\S]*?id:{global_config.BOT_QQ}", message.processed_plain_text):
if re.search(f"@[\s\S]*?id:{global_config.bot.qq_account}", message.processed_plain_text):
is_at = True
is_mentioned = True
if is_at and global_config.at_bot_inevitable_reply:
if is_at and global_config.normal_chat.at_bot_inevitable_reply:
reply_probability = 1.0
logger.info("被@回复概率设置为100%")
else:
if not is_mentioned:
# 判断是否被回复
if re.match(
f"\[回复 [\s\S]*?\({str(global_config.BOT_QQ)}\)[\s\S]*?],说:", message.processed_plain_text
f"\[回复 [\s\S]*?\({str(global_config.bot.qq_account)}\)[\s\S]*?],说:", message.processed_plain_text
):
is_mentioned = True
else:
@@ -88,7 +88,7 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
for nickname in nicknames:
if nickname in message_content:
is_mentioned = True
if is_mentioned and global_config.mentioned_bot_inevitable_reply:
if is_mentioned and global_config.normal_chat.mentioned_bot_inevitable_reply:
reply_probability = 1.0
logger.info("被提及回复概率设置为100%")
return is_mentioned, reply_probability
@@ -96,7 +96,8 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
async def get_embedding(text, request_type="embedding"):
"""获取文本的embedding向量"""
llm = LLMRequest(model=global_config.embedding, request_type=request_type)
# TODO: API-Adapter修改标记
llm = LLMRequest(model=global_config.model.embedding, request_type=request_type)
# return llm.get_embedding_sync(text)
try:
embedding = await llm.get_embedding(text)
@@ -163,7 +164,7 @@ def get_recent_group_speaker(chat_stream_id: int, sender, limit: int = 12) -> li
user_info = UserInfo.from_dict(msg_db_data["user_info"])
if (
(user_info.platform, user_info.user_id) != sender
and user_info.user_id != global_config.BOT_QQ
and user_info.user_id != global_config.bot.qq_account
and (user_info.platform, user_info.user_id, user_info.user_nickname) not in who_chat_in_group
and len(who_chat_in_group) < 5
): # 排除重复排除消息发送者排除bot限制加载的关系数目
@@ -321,7 +322,7 @@ def random_remove_punctuation(text: str) -> str:
def process_llm_response(text: str) -> list[str]:
# 先保护颜文字
if global_config.enable_kaomoji_protection:
if global_config.response_splitter.enable_kaomoji_protection:
protected_text, kaomoji_mapping = protect_kaomoji(text)
logger.trace(f"保护颜文字后的文本: {protected_text}")
else:
@@ -340,8 +341,8 @@ def process_llm_response(text: str) -> list[str]:
logger.debug(f"{text}去除括号处理后的文本: {cleaned_text}")
# 对清理后的文本进行进一步处理
max_length = global_config.response_max_length * 2
max_sentence_num = global_config.response_max_sentence_num
max_length = global_config.response_splitter.max_length * 2
max_sentence_num = global_config.response_splitter.max_sentence_num
# 如果基本上是中文,则进行长度过滤
if get_western_ratio(cleaned_text) < 0.1:
if len(cleaned_text) > max_length:
@@ -349,20 +350,20 @@ def process_llm_response(text: str) -> list[str]:
return ["懒得说"]
typo_generator = ChineseTypoGenerator(
error_rate=global_config.chinese_typo_error_rate,
min_freq=global_config.chinese_typo_min_freq,
tone_error_rate=global_config.chinese_typo_tone_error_rate,
word_replace_rate=global_config.chinese_typo_word_replace_rate,
error_rate=global_config.chinese_typo.error_rate,
min_freq=global_config.chinese_typo.min_freq,
tone_error_rate=global_config.chinese_typo.tone_error_rate,
word_replace_rate=global_config.chinese_typo.word_replace_rate,
)
if global_config.enable_response_splitter:
if global_config.response_splitter.enable:
split_sentences = split_into_sentences_w_remove_punctuation(cleaned_text)
else:
split_sentences = [cleaned_text]
sentences = []
for sentence in split_sentences:
if global_config.chinese_typo_enable:
if global_config.chinese_typo.enable:
typoed_text, typo_corrections = typo_generator.create_typo_sentence(sentence)
sentences.append(typoed_text)
if typo_corrections:
@@ -372,7 +373,7 @@ def process_llm_response(text: str) -> list[str]:
if len(sentences) > max_sentence_num:
logger.warning(f"分割后消息数量过多 ({len(sentences)} 条),返回默认回复")
return [f"{global_config.BOT_NICKNAME}不知道哦"]
return [f"{global_config.bot.nickname}不知道哦"]
# if extracted_contents:
# for content in extracted_contents:

View File

@@ -8,7 +8,8 @@ import io
import numpy as np
from ...common.database import db
from ...common.database.database import db
from ...common.database.database_model import Images, ImageDescriptions
from ...config.config import global_config
from ..models.utils_model import LLMRequest
@@ -32,40 +33,23 @@ class ImageManager:
def __init__(self):
if not self._initialized:
self._ensure_image_collection()
self._ensure_description_collection()
self._ensure_image_dir()
self._initialized = True
self._llm = LLMRequest(model=global_config.model.vlm, temperature=0.4, max_tokens=300, request_type="image")
try:
db.connect(reuse_if_open=True)
db.create_tables([Images, ImageDescriptions], safe=True)
except Exception as e:
logger.error(f"数据库连接或表创建失败: {e}")
self._initialized = True
self._llm = LLMRequest(model=global_config.vlm, temperature=0.4, max_tokens=300, request_type="image")
def _ensure_image_dir(self):
"""确保图像存储目录存在"""
os.makedirs(self.IMAGE_DIR, exist_ok=True)
@staticmethod
def _ensure_image_collection():
"""确保images集合存在并创建索引"""
if "images" not in db.list_collection_names():
db.create_collection("images")
# 删除旧索引
db.images.drop_indexes()
# 创建新的复合索引
db.images.create_index([("hash", 1), ("type", 1)], unique=True)
db.images.create_index([("url", 1)])
db.images.create_index([("path", 1)])
@staticmethod
def _ensure_description_collection():
"""确保image_descriptions集合存在并创建索引"""
if "image_descriptions" not in db.list_collection_names():
db.create_collection("image_descriptions")
# 删除旧索引
db.image_descriptions.drop_indexes()
# 创建新的复合索引
db.image_descriptions.create_index([("hash", 1), ("type", 1)], unique=True)
@staticmethod
def _get_description_from_db(image_hash: str, description_type: str) -> Optional[str]:
"""从数据库获取图片描述
@@ -77,8 +61,14 @@ class ImageManager:
Returns:
Optional[str]: 描述文本如果不存在则返回None
"""
result = db.image_descriptions.find_one({"hash": image_hash, "type": description_type})
return result["description"] if result else None
try:
record = ImageDescriptions.get_or_none(
(ImageDescriptions.image_description_hash == image_hash) & (ImageDescriptions.type == description_type)
)
return record.description if record else None
except Exception as e:
logger.error(f"从数据库获取描述失败 (Peewee): {str(e)}")
return None
@staticmethod
def _save_description_to_db(image_hash: str, description: str, description_type: str) -> None:
@@ -90,20 +80,17 @@ class ImageManager:
description_type: 描述类型 ('emoji''image')
"""
try:
db.image_descriptions.update_one(
{"hash": image_hash, "type": description_type},
{
"$set": {
"description": description,
"timestamp": int(time.time()),
"hash": image_hash, # 确保hash字段存在
"type": description_type, # 确保type字段存在
}
},
upsert=True,
current_timestamp = time.time()
defaults = {"description": description, "timestamp": current_timestamp}
desc_obj, created = ImageDescriptions.get_or_create(
hash=image_hash, type=description_type, defaults=defaults
)
if not created: # 如果记录已存在,则更新
desc_obj.description = description
desc_obj.timestamp = current_timestamp
desc_obj.save()
except Exception as e:
logger.error(f"保存描述到数据库失败: {str(e)}")
logger.error(f"保存描述到数据库失败 (Peewee): {str(e)}")
async def get_emoji_description(self, image_base64: str) -> str:
"""获取表情包描述,带查重和保存功能"""
@@ -116,51 +103,64 @@ class ImageManager:
# 查询缓存的描述
cached_description = self._get_description_from_db(image_hash, "emoji")
if cached_description:
# logger.debug(f"缓存表情包描述: {cached_description}")
return f"[表情包,含义看起来是:{cached_description}]"
# 调用AI获取描述
if image_format == "gif" or image_format == "GIF":
image_base64 = self.transform_gif(image_base64)
image_base64_processed = self.transform_gif(image_base64)
if image_base64_processed is None:
logger.warning("GIF转换失败无法获取描述")
return "[表情包(GIF处理失败)]"
prompt = "这是一个动态图表情包每一张图代表了动态图的某一帧黑色背景代表透明使用1-2个词描述一下表情包表达的情感和内容简短一些"
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, "jpg")
description, _ = await self._llm.generate_response_for_image(prompt, image_base64_processed, "jpg")
else:
prompt = "这是一个表情包,请用使用几个词描述一下表情包所表达的情感和内容,简短一些"
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
if description is None:
logger.warning("AI未能生成表情包描述")
return "[表情包(描述生成失败)]"
# 再次检查缓存,防止并发写入时重复生成
cached_description = self._get_description_from_db(image_hash, "emoji")
if cached_description:
logger.warning(f"虽然生成了描述,但是找到缓存表情包描述: {cached_description}")
return f"[表情包,含义看起来是:{cached_description}]"
# 根据配置决定是否保存图片
if global_config.save_emoji:
if global_config.emoji.save_emoji:
# 生成文件名和路径
timestamp = int(time.time())
filename = f"{timestamp}_{image_hash[:8]}.{image_format}"
if not os.path.exists(os.path.join(self.IMAGE_DIR, "emoji")):
os.makedirs(os.path.join(self.IMAGE_DIR, "emoji"))
file_path = os.path.join(self.IMAGE_DIR, "emoji", filename)
current_timestamp = time.time()
filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}"
emoji_dir = os.path.join(self.IMAGE_DIR, "emoji")
os.makedirs(emoji_dir, exist_ok=True)
file_path = os.path.join(emoji_dir, filename)
try:
# 保存文件
with open(file_path, "wb") as f:
f.write(image_bytes)
# 保存到数据库
image_doc = {
"hash": image_hash,
"path": file_path,
"type": "emoji",
"description": description,
"timestamp": timestamp,
}
db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True)
logger.trace(f"保存表情包: {file_path}")
# 保存到数据库 (Images表)
try:
img_obj = Images.get((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
img_obj.path = file_path
img_obj.description = description
img_obj.timestamp = current_timestamp
img_obj.save()
except Images.DoesNotExist:
Images.create(
hash=image_hash,
path=file_path,
type="emoji",
description=description,
timestamp=current_timestamp,
)
logger.trace(f"保存表情包元数据: {file_path}")
except Exception as e:
logger.error(f"保存表情包文件失败: {str(e)}")
logger.error(f"保存表情包文件或元数据失败: {str(e)}")
# 保存描述到数据库
# 保存描述到数据库 (ImageDescriptions表)
self._save_description_to_db(image_hash, description, "emoji")
return f"[表情包:{description}]"
@@ -188,6 +188,11 @@ class ImageManager:
)
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
if description is None:
logger.warning("AI未能生成图片描述")
return "[图片(描述生成失败)]"
# 再次检查缓存
cached_description = self._get_description_from_db(image_hash, "image")
if cached_description:
logger.warning(f"虽然生成了描述,但是找到缓存图片描述 {cached_description}")
@@ -195,38 +200,40 @@ class ImageManager:
logger.debug(f"描述是{description}")
if description is None:
logger.warning("AI未能生成图片描述")
return "[图片]"
# 根据配置决定是否保存图片
if global_config.save_pic:
if global_config.emoji.save_pic:
# 生成文件名和路径
timestamp = int(time.time())
filename = f"{timestamp}_{image_hash[:8]}.{image_format}"
if not os.path.exists(os.path.join(self.IMAGE_DIR, "image")):
os.makedirs(os.path.join(self.IMAGE_DIR, "image"))
file_path = os.path.join(self.IMAGE_DIR, "image", filename)
current_timestamp = time.time()
filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}"
image_dir = os.path.join(self.IMAGE_DIR, "image")
os.makedirs(image_dir, exist_ok=True)
file_path = os.path.join(image_dir, filename)
try:
# 保存文件
with open(file_path, "wb") as f:
f.write(image_bytes)
# 保存到数据库
image_doc = {
"hash": image_hash,
"path": file_path,
"type": "image",
"description": description,
"timestamp": timestamp,
}
db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True)
logger.trace(f"保存图片: {file_path}")
# 保存到数据库 (Images表)
try:
img_obj = Images.get((Images.emoji_hash == image_hash) & (Images.type == "image"))
img_obj.path = file_path
img_obj.description = description
img_obj.timestamp = current_timestamp
img_obj.save()
except Images.DoesNotExist:
Images.create(
hash=image_hash,
path=file_path,
type="image",
description=description,
timestamp=current_timestamp,
)
logger.trace(f"保存图片元数据: {file_path}")
except Exception as e:
logger.error(f"保存图片文件失败: {str(e)}")
logger.error(f"保存图片文件或元数据失败: {str(e)}")
# 保存描述到数据库
# 保存描述到数据库 (ImageDescriptions表)
self._save_description_to_db(image_hash, description, "image")
return f"[图片:{description}]"

View File

@@ -16,7 +16,7 @@ root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
sys.path.append(root_path)
# 现在可以导入src模块
from src.common.database import db # noqa E402
from common.database.database import db # noqa E402
# 加载根目录下的env.edv文件