Merge pull request #956 from MaiM-with-u/plugin

Plugin插件和工作记忆
This commit is contained in:
SengokuCola
2025-05-16 23:47:18 +08:00
committed by GitHub
62 changed files with 4048 additions and 1596 deletions

2
.gitignore vendored
View File

@@ -301,3 +301,5 @@ $RECYCLE.BIN/
# Windows shortcuts # Windows shortcuts
*.lnk *.lnk
src/chat/focus_chat/working_memory/test/test1.txt
src/chat/focus_chat/working_memory/test/test4.txt

Binary file not shown.

View File

@@ -41,7 +41,7 @@ class APIBotConfig:
allow_focus_mode: bool # 是否允许专注聊天状态 allow_focus_mode: bool # 是否允许专注聊天状态
base_normal_chat_num: int # 最多允许多少个群进行普通聊天 base_normal_chat_num: int # 最多允许多少个群进行普通聊天
base_focused_chat_num: int # 最多允许多少个群进行专注聊天 base_focused_chat_num: int # 最多允许多少个群进行专注聊天
observation_context_size: int # 观察到的最长上下文大小 chat.observation_context_size: int # 观察到的最长上下文大小
message_buffer: bool # 是否启用消息缓冲 message_buffer: bool # 是否启用消息缓冲
ban_words: List[str] # 禁止词列表 ban_words: List[str] # 禁止词列表
ban_msgs_regex: List[str] # 禁止消息的正则表达式列表 ban_msgs_regex: List[str] # 禁止消息的正则表达式列表
@@ -128,7 +128,7 @@ class APIBotConfig:
llm_reasoning: Dict[str, Any] # 推理模型配置 llm_reasoning: Dict[str, Any] # 推理模型配置
llm_normal: Dict[str, Any] # 普通模型配置 llm_normal: Dict[str, Any] # 普通模型配置
llm_topic_judge: Dict[str, Any] # 主题判断模型配置 llm_topic_judge: Dict[str, Any] # 主题判断模型配置
llm_summary: Dict[str, Any] # 总结模型配置 model.summary: Dict[str, Any] # 总结模型配置
vlm: Dict[str, Any] # VLM模型配置 vlm: Dict[str, Any] # VLM模型配置
llm_heartflow: Dict[str, Any] # 心流模型配置 llm_heartflow: Dict[str, Any] # 心流模型配置
llm_observation: Dict[str, Any] # 观察模型配置 llm_observation: Dict[str, Any] # 观察模型配置
@@ -203,7 +203,7 @@ class APIBotConfig:
"llm_reasoning", "llm_reasoning",
"llm_normal", "llm_normal",
"llm_topic_judge", "llm_topic_judge",
"llm_summary", "model.summary",
"vlm", "vlm",
"llm_heartflow", "llm_heartflow",
"llm_observation", "llm_observation",

View File

@@ -5,12 +5,15 @@ import os
import random import random
import time import time
import traceback import traceback
from typing import Optional, Tuple from typing import Optional, Tuple, List, Any
from PIL import Image from PIL import Image
import io import io
import re import re
from ...common.database import db # from gradio_client import file
from ...common.database.database_model import Emoji
from ...common.database.database import db as peewee_db
from ...config.config import global_config from ...config.config import global_config
from ..utils.utils_image import image_path_to_base64, image_manager from ..utils.utils_image import image_path_to_base64, image_manager
from ..models.utils_model import LLMRequest from ..models.utils_model import LLMRequest
@@ -51,7 +54,7 @@ class MaiEmoji:
self.is_deleted = False # 标记是否已被删除 self.is_deleted = False # 标记是否已被删除
self.format = "" self.format = ""
async def initialize_hash_format(self): async def initialize_hash_format(self) -> Optional[bool]:
"""从文件创建表情包实例, 计算哈希值和格式""" """从文件创建表情包实例, 计算哈希值和格式"""
try: try:
# 使用 full_path 检查文件是否存在 # 使用 full_path 检查文件是否存在
@@ -104,7 +107,7 @@ class MaiEmoji:
self.is_deleted = True self.is_deleted = True
return None return None
async def register_to_db(self): async def register_to_db(self) -> bool:
""" """
注册表情包 注册表情包
将表情包对应的文件从当前路径移动到EMOJI_REGISTED_DIR目录下 将表情包对应的文件从当前路径移动到EMOJI_REGISTED_DIR目录下
@@ -143,22 +146,22 @@ class MaiEmoji:
# --- 数据库操作 --- # --- 数据库操作 ---
try: try:
# 准备数据库记录 for emoji collection # 准备数据库记录 for emoji collection
emoji_record = { emotion_str = ",".join(self.emotion) if self.emotion else ""
"filename": self.filename,
"path": self.path, # 存储目录路径
"full_path": self.full_path, # 存储完整文件路径
"embedding": self.embedding,
"description": self.description,
"emotion": self.emotion,
"hash": self.hash,
"format": self.format,
"timestamp": int(self.register_time),
"usage_count": self.usage_count,
"last_used_time": self.last_used_time,
}
# 使用upsert确保记录存在或被更新 Emoji.create(
db["emoji"].update_one({"hash": self.hash}, {"$set": emoji_record}, upsert=True) hash=self.hash,
full_path=self.full_path,
format=self.format,
description=self.description,
emotion=emotion_str, # Store as comma-separated string
query_count=0, # Default value
is_registered=True,
is_banned=False, # Default value
record_time=self.register_time, # Use MaiEmoji's register_time for DB record_time
register_time=self.register_time,
usage_count=self.usage_count,
last_used_time=self.last_used_time,
)
logger.success(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})") logger.success(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})")
@@ -166,14 +169,6 @@ class MaiEmoji:
except Exception as db_error: except Exception as db_error:
logger.error(f"[错误] 保存数据库失败 ({self.filename}): {str(db_error)}") logger.error(f"[错误] 保存数据库失败 ({self.filename}): {str(db_error)}")
# 数据库保存失败,是否需要将文件移回?为了简化,暂时只记录错误
# 可以考虑在这里尝试删除已移动的文件,避免残留
try:
if os.path.exists(self.full_path): # full_path 此时是目标路径
os.remove(self.full_path)
logger.warning(f"[回滚] 已删除移动失败后残留的文件: {self.full_path}")
except Exception as remove_error:
logger.error(f"[错误] 回滚删除文件失败: {remove_error}")
return False return False
except Exception as e: except Exception as e:
@@ -181,7 +176,7 @@ class MaiEmoji:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
return False return False
async def delete(self): async def delete(self) -> bool:
"""删除表情包 """删除表情包
删除表情包的文件和数据库记录 删除表情包的文件和数据库记录
@@ -201,10 +196,14 @@ class MaiEmoji:
# 文件删除失败,但仍然尝试删除数据库记录 # 文件删除失败,但仍然尝试删除数据库记录
# 2. 删除数据库记录 # 2. 删除数据库记录
result = db.emoji.delete_one({"hash": self.hash}) try:
deleted_in_db = result.deleted_count > 0 will_delete_emoji = Emoji.get(Emoji.emoji_hash == self.hash)
result = will_delete_emoji.delete_instance() # Returns the number of rows deleted.
except Emoji.DoesNotExist:
logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。")
result = 0 # Indicate no DB record was deleted
if deleted_in_db: if result > 0:
logger.info(f"[删除] 表情包数据库记录 {self.filename} (Hash: {self.hash})") logger.info(f"[删除] 表情包数据库记录 {self.filename} (Hash: {self.hash})")
# 3. 标记对象已被删除 # 3. 标记对象已被删除
self.is_deleted = True self.is_deleted = True
@@ -224,7 +223,7 @@ class MaiEmoji:
return False return False
def _emoji_objects_to_readable_list(emoji_objects): def _emoji_objects_to_readable_list(emoji_objects: List["MaiEmoji"]) -> List[str]:
"""将表情包对象列表转换为可读的字符串列表 """将表情包对象列表转换为可读的字符串列表
参数: 参数:
@@ -243,47 +242,48 @@ def _emoji_objects_to_readable_list(emoji_objects):
return emoji_info_list return emoji_info_list
def _to_emoji_objects(data): def _to_emoji_objects(data: Any) -> Tuple[List["MaiEmoji"], int]:
emoji_objects = [] emoji_objects = []
load_errors = 0 load_errors = 0
# data is now an iterable of Peewee Emoji model instances
emoji_data_list = list(data) emoji_data_list = list(data)
for emoji_data in emoji_data_list: for emoji_data in emoji_data_list: # emoji_data is an Emoji model instance
full_path = emoji_data.get("full_path") full_path = emoji_data.full_path
if not full_path: if not full_path:
logger.warning(f"[加载错误] 数据库记录缺少 'full_path' 字段: {emoji_data.get('_id')}") logger.warning(
f"[加载错误] 数据库记录缺少 'full_path' 字段: ID {emoji_data.id if hasattr(emoji_data, 'id') else 'Unknown'}"
)
load_errors += 1 load_errors += 1
continue # 跳过缺少 full_path 的记录 continue
try: try:
# 使用 full_path 初始化 MaiEmoji 对象
emoji = MaiEmoji(full_path=full_path) emoji = MaiEmoji(full_path=full_path)
# 设置从数据库加载的属性 emoji.hash = emoji_data.emoji_hash
emoji.hash = emoji_data.get("hash", "")
# 如果 hash 为空,也跳过?取决于业务逻辑
if not emoji.hash: if not emoji.hash:
logger.warning(f"[加载错误] 数据库记录缺少 'hash' 字段: {full_path}") logger.warning(f"[加载错误] 数据库记录缺少 'hash' 字段: {full_path}")
load_errors += 1 load_errors += 1
continue continue
emoji.description = emoji_data.get("description", "") emoji.description = emoji_data.description
emoji.emotion = emoji_data.get("emotion", []) # Deserialize emotion string from DB to list
emoji.usage_count = emoji_data.get("usage_count", 0) emoji.emotion = emoji_data.emotion.split(",") if emoji_data.emotion else []
# 优先使用 last_used_time否则用 timestamp最后用当前时间 emoji.usage_count = emoji_data.usage_count
last_used = emoji_data.get("last_used_time")
timestamp = emoji_data.get("timestamp")
emoji.last_used_time = (
last_used if last_used is not None else (timestamp if timestamp is not None else time.time())
)
emoji.register_time = timestamp if timestamp is not None else time.time()
emoji.format = emoji_data.get("format", "") # 加载格式
# 不需要再手动设置 path 和 filename__init__ 会自动处理 db_last_used_time = emoji_data.last_used_time
db_register_time = emoji_data.register_time
# If last_used_time from DB is None, use MaiEmoji's initialized register_time or current time
emoji.last_used_time = db_last_used_time if db_last_used_time is not None else emoji.register_time
# If register_time from DB is None, use MaiEmoji's initialized register_time (which is time.time())
emoji.register_time = db_register_time if db_register_time is not None else emoji.register_time
emoji.format = emoji_data.format
emoji_objects.append(emoji) emoji_objects.append(emoji)
except ValueError as ve: # 捕获 __init__ 可能的错误 except ValueError as ve:
logger.error(f"[加载错误] 初始化 MaiEmoji 失败 ({full_path}): {ve}") logger.error(f"[加载错误] 初始化 MaiEmoji 失败 ({full_path}): {ve}")
load_errors += 1 load_errors += 1
except Exception as e: except Exception as e:
@@ -292,13 +292,13 @@ def _to_emoji_objects(data):
return emoji_objects, load_errors return emoji_objects, load_errors
def _ensure_emoji_dir(): def _ensure_emoji_dir() -> None:
"""确保表情存储目录存在""" """确保表情存储目录存在"""
os.makedirs(EMOJI_DIR, exist_ok=True) os.makedirs(EMOJI_DIR, exist_ok=True)
os.makedirs(EMOJI_REGISTED_DIR, exist_ok=True) os.makedirs(EMOJI_REGISTED_DIR, exist_ok=True)
async def clear_temp_emoji(): async def clear_temp_emoji() -> None:
"""清理临时表情包 """清理临时表情包
清理/data/emoji和/data/image目录下的所有文件 清理/data/emoji和/data/image目录下的所有文件
当目录中文件数超过100时会全部删除 当目录中文件数超过100时会全部删除
@@ -320,7 +320,7 @@ async def clear_temp_emoji():
logger.success("[清理] 完成") logger.success("[清理] 完成")
async def clean_unused_emojis(emoji_dir, emoji_objects): async def clean_unused_emojis(emoji_dir: str, emoji_objects: List["MaiEmoji"]) -> None:
"""清理指定目录中未被 emoji_objects 追踪的表情包文件""" """清理指定目录中未被 emoji_objects 追踪的表情包文件"""
if not os.path.exists(emoji_dir): if not os.path.exists(emoji_dir):
logger.warning(f"[清理] 目标目录不存在,跳过清理: {emoji_dir}") logger.warning(f"[清理] 目标目录不存在,跳过清理: {emoji_dir}")
@@ -360,13 +360,13 @@ async def clean_unused_emojis(emoji_dir, emoji_objects):
class EmojiManager: class EmojiManager:
_instance = None _instance = None
def __new__(cls): def __new__(cls) -> "EmojiManager":
if cls._instance is None: if cls._instance is None:
cls._instance = super().__new__(cls) cls._instance = super().__new__(cls)
cls._instance._initialized = False cls._instance._initialized = False
return cls._instance return cls._instance
def __init__(self): def __init__(self) -> None:
self._initialized = None self._initialized = None
self._scan_task = None self._scan_task = None
@@ -382,53 +382,30 @@ class EmojiManager:
logger.info("启动表情包管理器") logger.info("启动表情包管理器")
def initialize(self): def initialize(self) -> None:
"""初始化数据库连接和表情目录""" """初始化数据库连接和表情目录"""
if not self._initialized: peewee_db.connect(reuse_if_open=True)
try: if peewee_db.is_closed():
self._ensure_emoji_collection() raise RuntimeError("数据库连接失败")
_ensure_emoji_dir() _ensure_emoji_dir()
self._initialized = True Emoji.create_table(safe=True) # Ensures table exists
# 更新表情包数量
# 启动时执行一次完整性检查
# await self.check_emoji_file_integrity()
except Exception as e:
logger.exception(f"初始化表情管理器失败: {e}")
def _ensure_db(self): def _ensure_db(self) -> None:
"""确保数据库已初始化""" """确保数据库已初始化"""
if not self._initialized: if not self._initialized:
self.initialize() self.initialize()
if not self._initialized: if not self._initialized:
raise RuntimeError("EmojiManager not initialized") raise RuntimeError("EmojiManager not initialized")
@staticmethod def record_usage(self, emoji_hash: str) -> None:
def _ensure_emoji_collection():
"""确保emoji集合存在并创建索引
这个函数用于确保MongoDB数据库中存在emoji集合,并创建必要的索引。
索引的作用是加快数据库查询速度:
- embedding字段的2dsphere索引: 用于加速向量相似度搜索,帮助快速找到相似的表情包
- tags字段的普通索引: 加快按标签搜索表情包的速度
- filename字段的唯一索引: 确保文件名不重复,同时加快按文件名查找的速度
没有索引的话,数据库每次查询都需要扫描全部数据,建立索引后可以大大提高查询效率。
"""
if "emoji" not in db.list_collection_names():
db.create_collection("emoji")
db.emoji.create_index([("embedding", "2dsphere")])
db.emoji.create_index([("filename", 1)], unique=True)
def record_usage(self, emoji_hash: str):
"""记录表情使用次数""" """记录表情使用次数"""
try: try:
db.emoji.update_one({"hash": emoji_hash}, {"$inc": {"usage_count": 1}}) emoji_update = Emoji.get(Emoji.emoji_hash == emoji_hash)
for emoji in self.emoji_objects: emoji_update.usage_count += 1
if emoji.hash == emoji_hash: emoji_update.last_used_time = time.time() # Update last used time
emoji.usage_count += 1 emoji_update.save() # Persist changes to DB
break except Emoji.DoesNotExist:
logger.error(f"记录表情使用失败: 未找到 hash 为 {emoji_hash} 的表情包")
except Exception as e: except Exception as e:
logger.error(f"记录表情使用失败: {str(e)}") logger.error(f"记录表情使用失败: {str(e)}")
@@ -448,7 +425,6 @@ class EmojiManager:
if not all_emojis: if not all_emojis:
logger.warning("内存中没有任何表情包对象") logger.warning("内存中没有任何表情包对象")
# 可以考虑再查一次数据库?或者依赖定期任务更新
return None return None
# 计算每个表情包与输入文本的最大情感相似度 # 计算每个表情包与输入文本的最大情感相似度
@@ -464,40 +440,38 @@ class EmojiManager:
# 计算与每个emotion标签的相似度取最大值 # 计算与每个emotion标签的相似度取最大值
max_similarity = 0 max_similarity = 0
best_matching_emotion = "" # 记录最匹配的 emotion 喵~ best_matching_emotion = ""
for emotion in emotions: for emotion in emotions:
# 使用编辑距离计算相似度 # 使用编辑距离计算相似度
distance = self._levenshtein_distance(text_emotion, emotion) distance = self._levenshtein_distance(text_emotion, emotion)
max_len = max(len(text_emotion), len(emotion)) max_len = max(len(text_emotion), len(emotion))
similarity = 1 - (distance / max_len if max_len > 0 else 0) similarity = 1 - (distance / max_len if max_len > 0 else 0)
if similarity > max_similarity: # 如果找到更相似的喵~ if similarity > max_similarity:
max_similarity = similarity max_similarity = similarity
best_matching_emotion = emotion # 就记下这个 emotion 喵~ best_matching_emotion = emotion
if best_matching_emotion: # 确保有匹配的情感才添加喵~ if best_matching_emotion:
emoji_similarities.append((emoji, max_similarity, best_matching_emotion)) # 把 emotion 也存起来喵~ emoji_similarities.append((emoji, max_similarity, best_matching_emotion))
# 按相似度降序排序 # 按相似度降序排序
emoji_similarities.sort(key=lambda x: x[1], reverse=True) emoji_similarities.sort(key=lambda x: x[1], reverse=True)
# 获取前10个最相似的表情包 # 获取前10个最相似的表情包
top_emojis = ( top_emojis = emoji_similarities[:10] if len(emoji_similarities) > 10 else emoji_similarities
emoji_similarities[:10] if len(emoji_similarities) > 10 else emoji_similarities
) # 改个名字,更清晰喵~
if not top_emojis: if not top_emojis:
logger.warning("未找到匹配的表情包") logger.warning("未找到匹配的表情包")
return None return None
# 从前几个中随机选择一个 # 从前几个中随机选择一个
selected_emoji, similarity, matched_emotion = random.choice(top_emojis) # 把匹配的 emotion 也拿出来喵~ selected_emoji, similarity, matched_emotion = random.choice(top_emojis)
# 更新使用次数 # 更新使用次数
self.record_usage(selected_emoji.hash) self.record_usage(selected_emoji.emoji_hash)
_time_end = time.time() _time_end = time.time()
logger.info( # 使用匹配到的 emotion 记录日志喵~ logger.info(
f"为[{text_emotion}]找到表情包: {matched_emotion} ({selected_emoji.filename}), Similarity: {similarity:.4f}" f"为[{text_emotion}]找到表情包: {matched_emotion} ({selected_emoji.filename}), Similarity: {similarity:.4f}"
) )
# 返回完整文件路径和描述 # 返回完整文件路径和描述
@@ -535,7 +509,7 @@ class EmojiManager:
return previous_row[-1] return previous_row[-1]
async def check_emoji_file_integrity(self): async def check_emoji_file_integrity(self) -> None:
"""检查表情包文件完整性 """检查表情包文件完整性
遍历self.emoji_objects中的所有对象检查文件是否存在 遍历self.emoji_objects中的所有对象检查文件是否存在
如果文件已被删除,则执行对象的删除方法并从列表中移除 如果文件已被删除,则执行对象的删除方法并从列表中移除
@@ -600,7 +574,7 @@ class EmojiManager:
logger.error(f"[错误] 检查表情包完整性失败: {str(e)}") logger.error(f"[错误] 检查表情包完整性失败: {str(e)}")
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
async def start_periodic_check_register(self): async def start_periodic_check_register(self) -> None:
"""定期检查表情包完整性和数量""" """定期检查表情包完整性和数量"""
await self.get_all_emoji_from_db() await self.get_all_emoji_from_db()
while True: while True:
@@ -654,13 +628,14 @@ class EmojiManager:
await asyncio.sleep(global_config.emoji.check_interval * 60) await asyncio.sleep(global_config.emoji.check_interval * 60)
async def get_all_emoji_from_db(self): async def get_all_emoji_from_db(self) -> None:
"""获取所有表情包并初始化为MaiEmoji类对象更新 self.emoji_objects""" """获取所有表情包并初始化为MaiEmoji类对象更新 self.emoji_objects"""
try: try:
self._ensure_db() self._ensure_db()
logger.info("[数据库] 开始加载所有表情包记录...") logger.info("[数据库] 开始加载所有表情包记录 (Peewee)...")
emoji_objects, load_errors = _to_emoji_objects(db.emoji.find()) emoji_peewee_instances = Emoji.select()
emoji_objects, load_errors = _to_emoji_objects(emoji_peewee_instances)
# 更新内存中的列表和数量 # 更新内存中的列表和数量
self.emoji_objects = emoji_objects self.emoji_objects = emoji_objects
@@ -675,7 +650,7 @@ class EmojiManager:
self.emoji_objects = [] # 加载失败则清空列表 self.emoji_objects = [] # 加载失败则清空列表
self.emoji_num = 0 self.emoji_num = 0
async def get_emoji_from_db(self, emoji_hash=None): async def get_emoji_from_db(self, emoji_hash: Optional[str] = None) -> List["MaiEmoji"]:
"""获取指定哈希值的表情包并初始化为MaiEmoji类对象列表 (主要用于调试或特定查找) """获取指定哈希值的表情包并初始化为MaiEmoji类对象列表 (主要用于调试或特定查找)
参数: 参数:
@@ -687,15 +662,16 @@ class EmojiManager:
try: try:
self._ensure_db() self._ensure_db()
query = {}
if emoji_hash: if emoji_hash:
query = {"hash": emoji_hash} query = Emoji.select().where(Emoji.emoji_hash == emoji_hash)
else: else:
logger.warning( logger.warning(
"[查询] 未提供 hash将尝试加载所有表情包建议使用 get_all_emoji_from_db 更新管理器状态。" "[查询] 未提供 hash将尝试加载所有表情包建议使用 get_all_emoji_from_db 更新管理器状态。"
) )
query = Emoji.select()
emoji_objects, load_errors = _to_emoji_objects(db.emoji.find(query)) emoji_peewee_instances = query
emoji_objects, load_errors = _to_emoji_objects(emoji_peewee_instances)
if load_errors > 0: if load_errors > 0:
logger.warning(f"[查询] 加载过程中出现 {load_errors} 个错误。") logger.warning(f"[查询] 加载过程中出现 {load_errors} 个错误。")
@@ -706,7 +682,7 @@ class EmojiManager:
logger.error(f"[错误] 从数据库获取表情包对象失败: {str(e)}") logger.error(f"[错误] 从数据库获取表情包对象失败: {str(e)}")
return [] return []
async def get_emoji_from_manager(self, emoji_hash) -> Optional[MaiEmoji]: async def get_emoji_from_manager(self, emoji_hash: str) -> Optional["MaiEmoji"]:
"""从内存中的 emoji_objects 列表获取表情包 """从内存中的 emoji_objects 列表获取表情包
参数: 参数:
@@ -759,7 +735,7 @@ class EmojiManager:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
return False return False
async def replace_a_emoji(self, new_emoji: MaiEmoji): async def replace_a_emoji(self, new_emoji: "MaiEmoji") -> bool:
"""替换一个表情包 """替换一个表情包
Args: Args:
@@ -820,7 +796,7 @@ class EmojiManager:
# 删除选定的表情包 # 删除选定的表情包
logger.info(f"[决策] 删除表情包: {emoji_to_delete.description}") logger.info(f"[决策] 删除表情包: {emoji_to_delete.description}")
delete_success = await self.delete_emoji(emoji_to_delete.hash) delete_success = await self.delete_emoji(emoji_to_delete.emoji_hash)
if delete_success: if delete_success:
# 修复:等待异步注册完成 # 修复:等待异步注册完成
@@ -848,7 +824,7 @@ class EmojiManager:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
return False return False
async def build_emoji_description(self, image_base64: str) -> Tuple[str, list]: async def build_emoji_description(self, image_base64: str) -> Tuple[str, List[str]]:
"""获取表情包描述和情感列表 """获取表情包描述和情感列表
Args: Args:

View File

@@ -10,7 +10,6 @@ from src.config.config import global_config
from src.chat.utils.utils_image import image_path_to_base64 # Local import needed after move from src.chat.utils.utils_image import image_path_to_base64 # Local import needed after move
from src.chat.utils.timer_calculator import Timer # <--- Import Timer from src.chat.utils.timer_calculator import Timer # <--- Import Timer
from src.chat.emoji_system.emoji_manager import emoji_manager from src.chat.emoji_system.emoji_manager import emoji_manager
from src.chat.focus_chat.heartflow_prompt_builder import prompt_builder
from src.chat.focus_chat.heartFC_sender import HeartFCSender from src.chat.focus_chat.heartFC_sender import HeartFCSender
from src.chat.utils.utils import process_llm_response from src.chat.utils.utils import process_llm_response
from src.chat.utils.info_catcher import info_catcher_manager from src.chat.utils.info_catcher import info_catcher_manager
@@ -18,10 +17,62 @@ from src.manager.mood_manager import mood_manager
from src.chat.heart_flow.utils_chat import get_chat_type_and_target_info from src.chat.heart_flow.utils_chat import get_chat_type_and_target_info
from src.chat.message_receive.chat_stream import ChatStream from src.chat.message_receive.chat_stream import ChatStream
from src.chat.focus_chat.hfc_utils import parse_thinking_id_to_timestamp from src.chat.focus_chat.hfc_utils import parse_thinking_id_to_timestamp
from src.individuality.individuality import Individuality
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
import time
from src.chat.focus_chat.expressors.exprssion_learner import expression_learner
import random
logger = get_logger("expressor") logger = get_logger("expressor")
def init_prompt():
Prompt(
"""
你可以参考以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中:
{style_habbits}
你现在正在群里聊天,以下是群里正在进行的聊天内容:
{chat_info}
以上是聊天内容,你需要了解聊天记录中的内容
{chat_target}
你的名字是{bot_name}{prompt_personality},在这聊天中,"{target_message}"引起了你的注意,对这句话,你想表达:{in_mind_reply},原因是:{reason}。你现在要思考怎么回复
你需要使用合适的语法和句法,参考聊天内容,组织一条日常且口语化的回复。
请你根据情景使用以下句法:
{grammar_habbits}
回复尽量简短一些。可以参考贴吧,知乎和微博的回复风格,你可以完全重组回复,保留最基本的表达含义就好,但注意回复要简短,但重组后保持语意通顺。
回复不要浮夸,不要用夸张修辞,平淡一些。不要输出多余内容(包括前后缀冒号和引号括号表情包at或 @等 ),只输出一条回复就好。
现在,你说:
""",
"default_expressor_prompt",
)
Prompt(
"""
你可以参考以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中:
{style_habbits}
你现在正在群里聊天,以下是群里正在进行的聊天内容:
{chat_info}
以上是聊天内容,你需要了解聊天记录中的内容
{chat_target}
你的名字是{bot_name}{prompt_personality},在这聊天中,"{target_message}"引起了你的注意,对这句话,你想表达:{in_mind_reply},原因是:{reason}。你现在要思考怎么回复
你需要使用合适的语法和句法,参考聊天内容,组织一条日常且口语化的回复。
请你根据情景使用以下句法:
{grammar_habbits}
回复尽量简短一些。可以参考贴吧,知乎和微博的回复风格,你可以完全重组回复,保留最基本的表达含义就好,但注意回复要简短,但重组后保持语意通顺。
回复不要浮夸,不要用夸张修辞,平淡一些。不要输出多余内容(包括前后缀冒号和引号括号表情包at或 @等 ),只输出一条回复就好。
现在,你说:
""",
"default_expressor_private_prompt", # New template for private FOCUSED chat
)
class DefaultExpressor: class DefaultExpressor:
def __init__(self, chat_id: str): def __init__(self, chat_id: str):
self.log_prefix = "expressor" self.log_prefix = "expressor"
@@ -67,7 +118,7 @@ class DefaultExpressor:
reply=anchor_message, # 回复的是锚点消息 reply=anchor_message, # 回复的是锚点消息
thinking_start_time=thinking_time_point, thinking_start_time=thinking_time_point,
) )
logger.debug(f"创建思考消息thinking_message{thinking_message}") # logger.debug(f"创建思考消息thinking_message{thinking_message}")
await self.heart_fc_sender.register_thinking(thinking_message) await self.heart_fc_sender.register_thinking(thinking_message)
@@ -107,7 +158,7 @@ class DefaultExpressor:
if reply: if reply:
with Timer("发送消息", cycle_timers): with Timer("发送消息", cycle_timers):
sent_msg_list = await self._send_response_messages( sent_msg_list = await self.send_response_messages(
anchor_message=anchor_message, anchor_message=anchor_message,
thinking_id=thinking_id, thinking_id=thinking_id,
response_set=reply, response_set=reply,
@@ -163,13 +214,10 @@ class DefaultExpressor:
# 3. 构建 Prompt # 3. 构建 Prompt
with Timer("构建Prompt", {}): # 内部计时器,可选保留 with Timer("构建Prompt", {}): # 内部计时器,可选保留
prompt = await prompt_builder.build_prompt( prompt = await self.build_prompt_focus(
build_mode="focus",
chat_stream=self.chat_stream, # Pass the stream object chat_stream=self.chat_stream, # Pass the stream object
in_mind_reply=in_mind_reply, in_mind_reply=in_mind_reply,
reason=reason, reason=reason,
current_mind_info="",
structured_info="",
sender_name=sender_name_for_prompt, # Pass determined name sender_name=sender_name_for_prompt, # Pass determined name
target_message=target_message, target_message=target_message,
) )
@@ -188,7 +236,7 @@ class DefaultExpressor:
# logger.info(f"{self.log_prefix}[Replier-{thinking_id}]\nPrompt:\n{prompt}\n") # logger.info(f"{self.log_prefix}[Replier-{thinking_id}]\nPrompt:\n{prompt}\n")
content, reasoning_content, model_name = await self.express_model.generate_response(prompt) content, reasoning_content, model_name = await self.express_model.generate_response(prompt)
logger.info(f"{self.log_prefix}\nPrompt:\n{prompt}\n---------------------------\n") # logger.info(f"{self.log_prefix}\nPrompt:\n{prompt}\n---------------------------\n")
logger.info(f"想要表达:{in_mind_reply}") logger.info(f"想要表达:{in_mind_reply}")
logger.info(f"理由:{reason}") logger.info(f"理由:{reason}")
@@ -225,10 +273,108 @@ class DefaultExpressor:
traceback.print_exc() traceback.print_exc()
return None return None
async def build_prompt_focus(
self,
reason,
chat_stream,
sender_name,
in_mind_reply,
target_message,
) -> str:
individuality = Individuality.get_instance()
prompt_personality = individuality.get_prompt(x_person=0, level=2)
# Determine if it's a group chat
is_group_chat = bool(chat_stream.group_info)
# Use sender_name passed from caller for private chat, otherwise use a default for group
# Default sender_name for group chat isn't used in the group prompt template, but set for consistency
effective_sender_name = sender_name if not is_group_chat else "某人"
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_stream.stream_id,
timestamp=time.time(),
limit=global_config.chat.observation_context_size,
)
chat_talking_prompt = await build_readable_messages(
message_list_before_now,
replace_bot_name=True,
merge_messages=True,
timestamp_mode="relative",
read_mark=0.0,
truncate=True,
)
(
learnt_style_expressions,
learnt_grammar_expressions,
personality_expressions,
) = await expression_learner.get_expression_by_chat_id(chat_stream.stream_id)
style_habbits = []
grammar_habbits = []
# 1. learnt_expressions加权随机选3条
if learnt_style_expressions:
weights = [expr["count"] for expr in learnt_style_expressions]
selected_learnt = weighted_sample_no_replacement(learnt_style_expressions, weights, 3)
for expr in selected_learnt:
if isinstance(expr, dict) and "situation" in expr and "style" in expr:
style_habbits.append(f"{expr['situation']}时,使用 {expr['style']}")
# 2. learnt_grammar_expressions加权随机选3条
if learnt_grammar_expressions:
weights = [expr["count"] for expr in learnt_grammar_expressions]
selected_learnt = weighted_sample_no_replacement(learnt_grammar_expressions, weights, 3)
for expr in selected_learnt:
if isinstance(expr, dict) and "situation" in expr and "style" in expr:
grammar_habbits.append(f"{expr['situation']}时,使用 {expr['style']}")
# 3. personality_expressions随机选1条
if personality_expressions:
expr = random.choice(personality_expressions)
if isinstance(expr, dict) and "situation" in expr and "style" in expr:
style_habbits.append(f"{expr['situation']}时,使用 {expr['style']}")
style_habbits_str = "\n".join(style_habbits)
grammar_habbits_str = "\n".join(grammar_habbits)
logger.debug("开始构建 focus prompt")
# --- Choose template based on chat type ---
if is_group_chat:
template_name = "default_expressor_prompt"
# Group specific formatting variables (already fetched or default)
chat_target_1 = await global_prompt_manager.get_prompt_async("chat_target_group1")
# chat_target_2 = await global_prompt_manager.get_prompt_async("chat_target_group2")
prompt = await global_prompt_manager.format_prompt(
template_name,
style_habbits=style_habbits_str,
grammar_habbits=grammar_habbits_str,
chat_target=chat_target_1,
chat_info=chat_talking_prompt,
bot_name=global_config.bot.nickname,
prompt_personality="",
reason=reason,
in_mind_reply=in_mind_reply,
target_message=target_message,
)
else: # Private chat
template_name = "default_expressor_private_prompt"
prompt = await global_prompt_manager.format_prompt(
template_name,
sender_name=effective_sender_name, # Used in private template
chat_talking_prompt=chat_talking_prompt,
bot_name=global_config.bot.nickname,
prompt_personality=prompt_personality,
reason=reason,
moderation_prompt=await global_prompt_manager.get_prompt_async("moderation_prompt"),
)
return prompt
# --- 发送器 (Sender) --- # # --- 发送器 (Sender) --- #
async def _send_response_messages( async def send_response_messages(
self, anchor_message: Optional[MessageRecv], response_set: List[Tuple[str, str]], thinking_id: str self, anchor_message: Optional[MessageRecv], response_set: List[Tuple[str, str]], thinking_id: str = ""
) -> Optional[MessageSending]: ) -> Optional[MessageSending]:
"""发送回复消息 (尝试锚定到 anchor_message),使用 HeartFCSender""" """发送回复消息 (尝试锚定到 anchor_message),使用 HeartFCSender"""
chat = self.chat_stream chat = self.chat_stream
@@ -243,7 +389,11 @@ class DefaultExpressor:
stream_name = chat_manager.get_stream_name(chat_id) or chat_id # 获取流名称用于日志 stream_name = chat_manager.get_stream_name(chat_id) or chat_id # 获取流名称用于日志
# 检查思考过程是否仍在进行,并获取开始时间 # 检查思考过程是否仍在进行,并获取开始时间
if thinking_id:
thinking_start_time = await self.heart_fc_sender.get_thinking_start_time(chat_id, thinking_id) thinking_start_time = await self.heart_fc_sender.get_thinking_start_time(chat_id, thinking_id)
else:
thinking_id = "ds" + str(round(time.time(), 2))
thinking_start_time = time.time()
if thinking_start_time is None: if thinking_start_time is None:
logger.error(f"[{stream_name}]思考过程未找到或已结束,无法发送回复。") logger.error(f"[{stream_name}]思考过程未找到或已结束,无法发送回复。")
@@ -276,6 +426,7 @@ class DefaultExpressor:
reply_to=reply_to, reply_to=reply_to,
is_emoji=is_emoji, is_emoji=is_emoji,
thinking_id=thinking_id, thinking_id=thinking_id,
thinking_start_time=thinking_start_time,
) )
try: try:
@@ -297,6 +448,7 @@ class DefaultExpressor:
except Exception as e: except Exception as e:
logger.error(f"{self.log_prefix}发送回复片段 {i} ({part_message_id}) 时失败: {e}") logger.error(f"{self.log_prefix}发送回复片段 {i} ({part_message_id}) 时失败: {e}")
traceback.print_exc()
# 这里可以选择是继续发送下一个片段还是中止 # 这里可以选择是继续发送下一个片段还是中止
# 在尝试发送完所有片段后,完成原始的 thinking_id 状态 # 在尝试发送完所有片段后,完成原始的 thinking_id 状态
@@ -327,10 +479,10 @@ class DefaultExpressor:
reply_to: bool, reply_to: bool,
is_emoji: bool, is_emoji: bool,
thinking_id: str, thinking_id: str,
thinking_start_time: float,
) -> MessageSending: ) -> MessageSending:
"""构建单个发送消息""" """构建单个发送消息"""
thinking_start_time = await self.heart_fc_sender.get_thinking_start_time(self.chat_id, thinking_id)
bot_user_info = UserInfo( bot_user_info = UserInfo(
user_id=global_config.bot.qq_account, user_id=global_config.bot.qq_account,
user_nickname=global_config.bot.nickname, user_nickname=global_config.bot.nickname,
@@ -350,3 +502,40 @@ class DefaultExpressor:
) )
return bot_message return bot_message
def weighted_sample_no_replacement(items, weights, k) -> list:
"""
加权且不放回地随机抽取k个元素。
参数:
items: 待抽取的元素列表
weights: 每个元素对应的权重与items等长且为正数
k: 需要抽取的元素个数
返回:
selected: 按权重加权且不重复抽取的k个元素组成的列表
如果items中的元素不足k就只会返回所有可用的元素
实现思路:
每次从当前池中按权重加权随机选出一个元素选中后将其从池中移除重复k次。
这样保证了:
1. count越大被选中概率越高
2. 不会重复选中同一个元素
"""
selected = []
pool = list(zip(items, weights))
for _ in range(min(k, len(pool))):
total = sum(w for _, w in pool)
r = random.uniform(0, total)
upto = 0
for idx, (item, weight) in enumerate(pool):
upto += weight
if upto >= r:
selected.append(item)
pool.pop(idx)
break
return selected
init_prompt()

View File

@@ -14,15 +14,17 @@ from src.chat.focus_chat.heartFC_Cycleinfo import CycleDetail
from src.chat.focus_chat.info.info_base import InfoBase from src.chat.focus_chat.info.info_base import InfoBase
from src.chat.focus_chat.info_processors.chattinginfo_processor import ChattingInfoProcessor from src.chat.focus_chat.info_processors.chattinginfo_processor import ChattingInfoProcessor
from src.chat.focus_chat.info_processors.mind_processor import MindProcessor from src.chat.focus_chat.info_processors.mind_processor import MindProcessor
from src.chat.heart_flow.observation.memory_observation import MemoryObservation from src.chat.focus_chat.info_processors.working_memory_processor import WorkingMemoryProcessor
from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation
from src.chat.heart_flow.observation.working_observation import WorkingObservation from src.chat.heart_flow.observation.working_observation import WorkingMemoryObservation
from src.chat.focus_chat.info_processors.tool_processor import ToolProcessor from src.chat.focus_chat.info_processors.tool_processor import ToolProcessor
from src.chat.focus_chat.expressors.default_expressor import DefaultExpressor from src.chat.focus_chat.expressors.default_expressor import DefaultExpressor
from src.chat.focus_chat.memory_activator import MemoryActivator from src.chat.focus_chat.memory_activator import MemoryActivator
from src.chat.focus_chat.info_processors.base_processor import BaseProcessor from src.chat.focus_chat.info_processors.base_processor import BaseProcessor
from src.chat.focus_chat.info_processors.self_processor import SelfProcessor
from src.chat.focus_chat.planners.planner import ActionPlanner from src.chat.focus_chat.planners.planner import ActionPlanner
from src.chat.focus_chat.planners.action_factory import ActionManager from src.chat.focus_chat.planners.action_manager import ActionManager
from src.chat.focus_chat.working_memory.working_memory import WorkingMemory
install(extra_lines=3) install(extra_lines=3)
@@ -57,7 +59,7 @@ async def _handle_cycle_delay(action_taken_this_cycle: bool, cycle_start_time: f
class HeartFChatting: class HeartFChatting:
""" """
管理一个连续的Plan-Replier-Sender循环 管理一个连续的Focus Chat循环
用于在特定聊天流中生成回复。 用于在特定聊天流中生成回复。
其生命周期现在由其关联的 SubHeartflow 的 FOCUSED 状态控制。 其生命周期现在由其关联的 SubHeartflow 的 FOCUSED 状态控制。
""" """
@@ -79,18 +81,24 @@ class HeartFChatting:
# 基础属性 # 基础属性
self.stream_id: str = chat_id # 聊天流ID self.stream_id: str = chat_id # 聊天流ID
self.chat_stream: Optional[ChatStream] = None # 关联的聊天流 self.chat_stream: Optional[ChatStream] = None # 关联的聊天流
self.observations: List[Observation] = observations # 关联的观察列表,用于监控聊天流状态
self.on_consecutive_no_reply_callback = on_consecutive_no_reply_callback self.on_consecutive_no_reply_callback = on_consecutive_no_reply_callback
self.log_prefix: str = str(chat_id) # Initial default, will be updated self.log_prefix: str = str(chat_id) # Initial default, will be updated
self.memory_observation = MemoryObservation(observe_id=self.stream_id)
self.hfcloop_observation = HFCloopObservation(observe_id=self.stream_id) self.hfcloop_observation = HFCloopObservation(observe_id=self.stream_id)
self.working_observation = WorkingObservation(observe_id=self.stream_id) self.chatting_observation = observations[0]
self.memory_activator = MemoryActivator() self.memory_activator = MemoryActivator()
self.working_memory = WorkingMemory(chat_id=self.stream_id)
self.working_observation = WorkingMemoryObservation(
observe_id=self.stream_id, working_memory=self.working_memory
)
self.expressor = DefaultExpressor(chat_id=self.stream_id) self.expressor = DefaultExpressor(chat_id=self.stream_id)
self.action_manager = ActionManager() self.action_manager = ActionManager()
self.action_planner = ActionPlanner(log_prefix=self.log_prefix, action_manager=self.action_manager) self.action_planner = ActionPlanner(log_prefix=self.log_prefix, action_manager=self.action_manager)
self.hfcloop_observation.set_action_manager(self.action_manager)
self.all_observations = observations
# --- 处理器列表 --- # --- 处理器列表 ---
self.processors: List[BaseProcessor] = [] self.processors: List[BaseProcessor] = []
self._register_default_processors() self._register_default_processors()
@@ -107,9 +115,7 @@ class HeartFChatting:
self._cycle_counter = 0 self._cycle_counter = 0
self._cycle_history: Deque[CycleDetail] = deque(maxlen=10) # 保留最近10个循环的信息 self._cycle_history: Deque[CycleDetail] = deque(maxlen=10) # 保留最近10个循环的信息
self._current_cycle: Optional[CycleDetail] = None self._current_cycle: Optional[CycleDetail] = None
self.total_no_reply_count: int = 0 # 连续不回复计数器
self._shutting_down: bool = False # 关闭标志位 self._shutting_down: bool = False # 关闭标志位
self.total_waiting_time: float = 0.0 # 累计等待时间
async def _initialize(self) -> bool: async def _initialize(self) -> bool:
""" """
@@ -150,6 +156,8 @@ class HeartFChatting:
self.processors.append(ChattingInfoProcessor()) self.processors.append(ChattingInfoProcessor())
self.processors.append(MindProcessor(subheartflow_id=self.stream_id)) self.processors.append(MindProcessor(subheartflow_id=self.stream_id))
self.processors.append(ToolProcessor(subheartflow_id=self.stream_id)) self.processors.append(ToolProcessor(subheartflow_id=self.stream_id))
self.processors.append(WorkingMemoryProcessor(subheartflow_id=self.stream_id))
self.processors.append(SelfProcessor(subheartflow_id=self.stream_id))
logger.info(f"{self.log_prefix} 已注册默认处理器: {[p.__class__.__name__ for p in self.processors]}") logger.info(f"{self.log_prefix} 已注册默认处理器: {[p.__class__.__name__ for p in self.processors]}")
async def start(self): async def start(self):
@@ -327,6 +335,7 @@ class HeartFChatting:
f"{self.log_prefix} 处理器 {processor_name} 执行失败,耗时 (自并行开始): {duration_since_parallel_start:.2f}秒. 错误: {e}", f"{self.log_prefix} 处理器 {processor_name} 执行失败,耗时 (自并行开始): {duration_since_parallel_start:.2f}秒. 错误: {e}",
exc_info=True, exc_info=True,
) )
traceback.print_exc()
# 即使出错,也认为该任务结束了,已从 pending_tasks 中移除 # 即使出错,也认为该任务结束了,已从 pending_tasks 中移除
if pending_tasks: if pending_tasks:
@@ -348,13 +357,12 @@ class HeartFChatting:
async def _observe_process_plan_action_loop(self, cycle_timers: dict, thinking_id: str) -> tuple[bool, str]: async def _observe_process_plan_action_loop(self, cycle_timers: dict, thinking_id: str) -> tuple[bool, str]:
try: try:
with Timer("观察", cycle_timers): with Timer("观察", cycle_timers):
await self.observations[0].observe() # await self.observations[0].observe()
await self.memory_observation.observe() await self.chatting_observation.observe()
await self.working_observation.observe() await self.working_observation.observe()
await self.hfcloop_observation.observe() await self.hfcloop_observation.observe()
observations: List[Observation] = [] observations: List[Observation] = []
observations.append(self.observations[0]) observations.append(self.chatting_observation)
observations.append(self.memory_observation)
observations.append(self.working_observation) observations.append(self.working_observation)
observations.append(self.hfcloop_observation) observations.append(self.hfcloop_observation)
@@ -362,6 +370,8 @@ class HeartFChatting:
"observations": observations, "observations": observations,
} }
self.all_observations = observations
with Timer("回忆", cycle_timers): with Timer("回忆", cycle_timers):
running_memorys = await self.memory_activator.activate_memory(observations) running_memorys = await self.memory_activator.activate_memory(observations)
@@ -394,8 +404,7 @@ class HeartFChatting:
elif action_type == "no_reply": elif action_type == "no_reply":
action_str = "不回复" action_str = "不回复"
else: else:
action_type = "unknown" action_str = action_type
action_str = "未知动作"
logger.info(f"{self.log_prefix} 麦麦决定'{action_str}', 原因'{reasoning}'") logger.info(f"{self.log_prefix} 麦麦决定'{action_str}', 原因'{reasoning}'")
@@ -451,14 +460,14 @@ class HeartFChatting:
reasoning=reasoning, reasoning=reasoning,
cycle_timers=cycle_timers, cycle_timers=cycle_timers,
thinking_id=thinking_id, thinking_id=thinking_id,
observations=self.observations, observations=self.all_observations,
expressor=self.expressor, expressor=self.expressor,
chat_stream=self.chat_stream, chat_stream=self.chat_stream,
current_cycle=self._current_cycle, current_cycle=self._current_cycle,
log_prefix=self.log_prefix, log_prefix=self.log_prefix,
on_consecutive_no_reply_callback=self.on_consecutive_no_reply_callback, on_consecutive_no_reply_callback=self.on_consecutive_no_reply_callback,
total_no_reply_count=self.total_no_reply_count, # total_no_reply_count=self.total_no_reply_count,
total_waiting_time=self.total_waiting_time, # total_waiting_time=self.total_waiting_time,
shutting_down=self._shutting_down, shutting_down=self._shutting_down,
) )
@@ -469,14 +478,6 @@ class HeartFChatting:
# 处理动作并获取结果 # 处理动作并获取结果
success, reply_text = await action_handler.handle_action() success, reply_text = await action_handler.handle_action()
# 更新状态计数器
if action == "no_reply":
self.total_no_reply_count = getattr(action_handler, "total_no_reply_count", self.total_no_reply_count)
self.total_waiting_time = getattr(action_handler, "total_waiting_time", self.total_waiting_time)
elif action == "reply":
self.total_no_reply_count = 0
self.total_waiting_time = 0.0
return success, reply_text return success, reply_text
except Exception as e: except Exception as e:

View File

@@ -106,6 +106,7 @@ class HeartFCSender:
and not message.is_private_message() and not message.is_private_message()
and message.reply.processed_plain_text != "[System Trigger Context]" and message.reply.processed_plain_text != "[System Trigger Context]"
): ):
message.set_reply(message.reply)
logger.debug(f"[{chat_id}] 应用 set_reply 逻辑: {message.processed_plain_text[:20]}...") logger.debug(f"[{chat_id}] 应用 set_reply 逻辑: {message.processed_plain_text[:20]}...")
await message.process() await message.process()

View File

@@ -7,41 +7,20 @@ from src.chat.person_info.relationship_manager import relationship_manager
from src.chat.utils.utils import get_embedding from src.chat.utils.utils import get_embedding
import time import time
from typing import Union, Optional from typing import Union, Optional
from src.common.database import db
from src.chat.utils.utils import get_recent_group_speaker from src.chat.utils.utils import get_recent_group_speaker
from src.manager.mood_manager import mood_manager from src.manager.mood_manager import mood_manager
from src.chat.memory_system.Hippocampus import HippocampusManager from src.chat.memory_system.Hippocampus import HippocampusManager
from src.chat.knowledge.knowledge_lib import qa_manager from src.chat.knowledge.knowledge_lib import qa_manager
from src.chat.focus_chat.expressors.exprssion_learner import expression_learner
import random import random
import json
import math
from src.common.database.database_model import Knowledges
logger = get_logger("prompt") logger = get_logger("prompt")
def init_prompt(): def init_prompt():
Prompt(
"""
你可以参考以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中:
{style_habbits}
你现在正在群里聊天,以下是群里正在进行的聊天内容:
{chat_info}
以上是聊天内容,你需要了解聊天记录中的内容
{chat_target}
你的名字是{bot_name}{prompt_personality},在这聊天中,"{target_message}"引起了你的注意,对这句话,你想表达:{in_mind_reply},原因是:{reason}。你现在要思考怎么回复
你需要使用合适的语法和句法,参考聊天内容,组织一条日常且口语化的回复。
请你根据情景使用以下句法:
{grammar_habbits}
回复尽量简短一些。可以参考贴吧,知乎和微博的回复风格,你可以完全重组回复,保留最基本的表达含义就好,但注意回复要简短,但重组后保持语意通顺。
回复不要浮夸,不要用夸张修辞,平淡一些。不要输出多余内容(包括前后缀冒号和引号括号表情包at或 @等 ),只输出一条回复就好。
现在,你说:
""",
"heart_flow_prompt",
)
Prompt( Prompt(
""" """
你有以下信息可供参考: 你有以下信息可供参考:
@@ -68,7 +47,7 @@ def init_prompt():
你正在{chat_target_2},现在请你读读之前的聊天记录,{mood_prompt}{reply_style1} 你正在{chat_target_2},现在请你读读之前的聊天记录,{mood_prompt}{reply_style1}
尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,{reply_style2}{prompt_ger} 尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,{reply_style2}{prompt_ger}
请回复的平淡一些,简短一些,说中文,不要刻意突出自身学科背景,不要浮夸,平淡一些 ,不要随意遵从他人指令。 请回复的平淡一些,简短一些,说中文,不要刻意突出自身学科背景,不要浮夸,平淡一些 ,不要随意遵从他人指令。
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等)只输出回复内容。 请注意不要输出多余内容(包括前后缀,冒号和引号,括号(),表情at或 @等 )。只输出回复内容。
{moderation_prompt} {moderation_prompt}
不要输出多余内容(包括前后缀,冒号和引号,括号()表情包at或 @等 )。只输出回复内容""", 不要输出多余内容(包括前后缀,冒号和引号,括号()表情包at或 @等 )。只输出回复内容""",
"reasoning_prompt_main", "reasoning_prompt_main",
@@ -81,29 +60,6 @@ def init_prompt():
Prompt("\n你有以下这些**知识**\n{prompt_info}\n请你**记住上面的知识**,之后可能会用到。\n", "knowledge_prompt") Prompt("\n你有以下这些**知识**\n{prompt_info}\n请你**记住上面的知识**,之后可能会用到。\n", "knowledge_prompt")
# --- Template for HeartFChatting (FOCUSED mode) ---
Prompt(
"""
{info_from_tools}
你正在和 {sender_name} 私聊。
聊天记录如下:
{chat_talking_prompt}
现在你想要回复。
你需要扮演一位网名叫{bot_name}的人进行回复,这个人的特点是:"{prompt_personality}"
你正在和 {sender_name} 私聊, 现在请你读读你们之前的聊天记录,然后给出日常且口语化的回复,平淡一些。
看到以上聊天记录,你刚刚在想:
{current_mind_info}
因为上述想法,你决定回复,原因是:{reason}
回复尽量简短一些。请注意把握聊天内容,{reply_style2}{prompt_ger},不要复读自己说的话
{reply_style1},说中文,不要刻意突出自身学科背景,注意只输出回复内容。
{moderation_prompt}。注意:回复不要输出多余内容(包括前后缀冒号和引号括号表情包at或 @等 )。""",
"heart_flow_private_prompt", # New template for private FOCUSED chat
)
# --- Template for NormalChat (CHAT mode) ---
Prompt( Prompt(
""" """
{memory_prompt} {memory_prompt}
@@ -125,118 +81,6 @@ def init_prompt():
) )
async def _build_prompt_focus(
    reason, current_mind_info, structured_info, chat_stream, sender_name, in_mind_reply, target_message
) -> str:
    """Build the FOCUSED-mode reply prompt for a group or private chat.

    Gathers recent chat history, optional tool output, and learnt
    style/grammar expression habits, then formats the matching prompt
    template ("heart_flow_prompt" for groups, "heart_flow_private_prompt"
    for private chats).

    Args:
        reason: why the bot decided to reply.
        current_mind_info: the bot's current inner-thought text (private template only).
        structured_info: structured info from tools; may be empty/None.
        chat_stream: the chat stream object (provides stream_id and group_info).
        sender_name: display name of the other party (used for private chat).
        in_mind_reply: the draft reply the bot has in mind (group template only).
        target_message: the message that caught the bot's attention (group template only).

    Returns:
        str: the fully formatted prompt text.
    """
    individuality = Individuality.get_instance()
    prompt_personality = individuality.get_prompt(x_person=0, level=2)
    # Determine if it's a group chat
    is_group_chat = bool(chat_stream.group_info)
    # Use sender_name passed from caller for private chat, otherwise use a default for group
    # Default sender_name for group chat isn't used in the group prompt template, but set for consistency
    effective_sender_name = sender_name if not is_group_chat else "某人"
    message_list_before_now = get_raw_msg_before_timestamp_with_chat(
        chat_id=chat_stream.stream_id,
        timestamp=time.time(),
        limit=global_config.chat.observation_context_size,
    )
    chat_talking_prompt = await build_readable_messages(
        message_list_before_now,
        replace_bot_name=True,
        merge_messages=True,
        timestamp_mode="relative",
        read_mark=0.0,
        truncate=True,
    )
    if structured_info:
        structured_info_prompt = await global_prompt_manager.format_prompt(
            "info_from_tools", structured_info=structured_info
        )
    else:
        structured_info_prompt = ""
    # Load learnt expressions from /data/expression/<chat_id>/expressions.json
    (
        learnt_style_expressions,
        learnt_grammar_expressions,
        personality_expressions,
    ) = await expression_learner.get_expression_by_chat_id(chat_stream.stream_id)
    style_habbits = []
    grammar_habbits = []
    # 1. learnt style expressions: weighted random sample of 3
    if learnt_style_expressions:
        weights = [expr["count"] for expr in learnt_style_expressions]
        selected_learnt = weighted_sample_no_replacement(learnt_style_expressions, weights, 3)
        for expr in selected_learnt:
            if isinstance(expr, dict) and "situation" in expr and "style" in expr:
                style_habbits.append(f"当{expr['situation']}时,使用 {expr['style']}")
    # 2. learnt grammar expressions: weighted random sample of 3
    if learnt_grammar_expressions:
        weights = [expr["count"] for expr in learnt_grammar_expressions]
        selected_learnt = weighted_sample_no_replacement(learnt_grammar_expressions, weights, 3)
        for expr in selected_learnt:
            if isinstance(expr, dict) and "situation" in expr and "style" in expr:
                grammar_habbits.append(f"当{expr['situation']}时,使用 {expr['style']}")
    # 3. personality expressions: pick 1 uniformly at random
    if personality_expressions:
        expr = random.choice(personality_expressions)
        if isinstance(expr, dict) and "situation" in expr and "style" in expr:
            style_habbits.append(f"当{expr['situation']}时,使用 {expr['style']}")
    style_habbits_str = "\n".join(style_habbits)
    grammar_habbits_str = "\n".join(grammar_habbits)
    logger.debug("开始构建 focus prompt")
    # --- Choose template based on chat type ---
    if is_group_chat:
        template_name = "heart_flow_prompt"
        # Group specific formatting variables (already fetched or default)
        chat_target_1 = await global_prompt_manager.get_prompt_async("chat_target_group1")
        # chat_target_2 = await global_prompt_manager.get_prompt_async("chat_target_group2")
        prompt = await global_prompt_manager.format_prompt(
            template_name,
            # info_from_tools=structured_info_prompt,
            style_habbits=style_habbits_str,
            grammar_habbits=grammar_habbits_str,
            chat_target=chat_target_1,  # Used in group template
            # chat_talking_prompt=chat_talking_prompt,
            chat_info=chat_talking_prompt,
            bot_name=global_config.bot.nickname,
            # prompt_personality=prompt_personality,
            prompt_personality="",
            reason=reason,
            in_mind_reply=in_mind_reply,
            target_message=target_message,
            # moderation_prompt=await global_prompt_manager.get_prompt_async("moderation_prompt"),
            # sender_name is not used in the group template
        )
    else:  # Private chat
        template_name = "heart_flow_private_prompt"
        prompt = await global_prompt_manager.format_prompt(
            template_name,
            info_from_tools=structured_info_prompt,
            sender_name=effective_sender_name,  # Used in private template
            chat_talking_prompt=chat_talking_prompt,
            bot_name=global_config.bot.nickname,
            prompt_personality=prompt_personality,
            # chat_target and chat_target_2 are not used in private template
            current_mind_info=current_mind_info,
            reason=reason,
            moderation_prompt=await global_prompt_manager.get_prompt_async("moderation_prompt"),
        )
    # --- End choosing template ---
    # logger.debug(f"focus_chat_prompt (is_group={is_group_chat}): \n{prompt}")
    return prompt
class PromptBuilder: class PromptBuilder:
def __init__(self): def __init__(self):
self.prompt_built = "" self.prompt_built = ""
@@ -256,17 +100,6 @@ class PromptBuilder:
) -> Optional[str]: ) -> Optional[str]:
if build_mode == "normal": if build_mode == "normal":
return await self._build_prompt_normal(chat_stream, message_txt or "", sender_name) return await self._build_prompt_normal(chat_stream, message_txt or "", sender_name)
elif build_mode == "focus":
return await _build_prompt_focus(
reason,
current_mind_info,
structured_info,
chat_stream,
sender_name,
in_mind_reply,
target_message,
)
return None return None
async def _build_prompt_normal(self, chat_stream, message_txt: str, sender_name: str = "某人") -> str: async def _build_prompt_normal(self, chat_stream, message_txt: str, sender_name: str = "某人") -> str:
@@ -435,30 +268,6 @@ class PromptBuilder:
logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}") logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
# 1. 先从LLM获取主题类似于记忆系统的做法 # 1. 先从LLM获取主题类似于记忆系统的做法
topics = [] topics = []
# try:
# # 先尝试使用记忆系统的方法获取主题
# hippocampus = HippocampusManager.get_instance()._hippocampus
# topic_num = min(5, max(1, int(len(message) * 0.1)))
# topics_response = await hippocampus.llm_topic_judge.generate_response(hippocampus.find_topic_llm(message, topic_num))
# # 提取关键词
# topics = re.findall(r"<([^>]+)>", topics_response[0])
# if not topics:
# topics = []
# else:
# topics = [
# topic.strip()
# for topic in ",".join(topics).replace("", ",").replace("、", ",").replace(" ", ",").split(",")
# if topic.strip()
# ]
# logger.info(f"从LLM提取的主题: {', '.join(topics)}")
# except Exception as e:
# logger.error(f"从LLM提取主题失败: {str(e)}")
# # 如果LLM提取失败使用jieba分词提取关键词作为备选
# words = jieba.cut(message)
# topics = [word for word in words if len(word) > 1][:5]
# logger.info(f"使用jieba提取的主题: {', '.join(topics)}")
# 如果无法提取到主题,直接使用整个消息 # 如果无法提取到主题,直接使用整个消息
if not topics: if not topics:
@@ -568,8 +377,6 @@ class PromptBuilder:
for _i, result in enumerate(results, 1): for _i, result in enumerate(results, 1):
_similarity = result["similarity"] _similarity = result["similarity"]
content = result["content"].strip() content = result["content"].strip()
# 调试:为内容添加序号和相似度信息
# related_info += f"{i}. [{similarity:.2f}] {content}\n"
related_info += f"{content}\n" related_info += f"{content}\n"
related_info += "\n" related_info += "\n"
@@ -598,14 +405,14 @@ class PromptBuilder:
return related_info return related_info
else: else:
logger.debug("从LPMM知识库获取知识失败使用旧版数据库进行检索") logger.debug("从LPMM知识库获取知识失败使用旧版数据库进行检索")
knowledge_from_old = await self.get_prompt_info_old(message, threshold=0.38) knowledge_from_old = await self.get_prompt_info_old(message, threshold=threshold)
related_info += knowledge_from_old related_info += knowledge_from_old
logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}") logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}")
return related_info return related_info
except Exception as e: except Exception as e:
logger.error(f"获取知识库内容时发生异常: {str(e)}") logger.error(f"获取知识库内容时发生异常: {str(e)}")
try: try:
knowledge_from_old = await self.get_prompt_info_old(message, threshold=0.38) knowledge_from_old = await self.get_prompt_info_old(message, threshold=threshold)
related_info += knowledge_from_old related_info += knowledge_from_old
logger.debug( logger.debug(
f"异常后使用旧版数据库获取知识,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}" f"异常后使用旧版数据库获取知识,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}"
@@ -621,103 +428,69 @@ class PromptBuilder:
) -> Union[str, list]: ) -> Union[str, list]:
if not query_embedding: if not query_embedding:
return "" if not return_raw else [] return "" if not return_raw else []
# 使用余弦相似度计算
pipeline = [
{
"$addFields": {
"dotProduct": {
"$reduce": {
"input": {"$range": [0, {"$size": "$embedding"}]},
"initialValue": 0,
"in": {
"$add": [
"$$value",
{
"$multiply": [
{"$arrayElemAt": ["$embedding", "$$this"]},
{"$arrayElemAt": [query_embedding, "$$this"]},
]
},
]
},
}
},
"magnitude1": {
"$sqrt": {
"$reduce": {
"input": "$embedding",
"initialValue": 0,
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
}
}
},
"magnitude2": {
"$sqrt": {
"$reduce": {
"input": query_embedding,
"initialValue": 0,
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
}
}
},
}
},
{"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}},
{
"$match": {
"similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果
}
},
{"$sort": {"similarity": -1}},
{"$limit": limit},
{"$project": {"content": 1, "similarity": 1}},
]
results = list(db.knowledges.aggregate(pipeline)) results_with_similarity = []
logger.debug(f"知识库查询结果数量: {len(results)}") try:
# Fetch all knowledge entries
# This might be inefficient for very large databases.
# Consider strategies like FAISS or other vector search libraries if performance becomes an issue.
all_knowledges = Knowledges.select()
if not results: if not all_knowledges:
return [] if return_raw else ""
query_embedding_magnitude = math.sqrt(sum(x * x for x in query_embedding))
if query_embedding_magnitude == 0: # Avoid division by zero
return "" if not return_raw else []
for knowledge_item in all_knowledges:
try:
db_embedding_str = knowledge_item.embedding
db_embedding = json.loads(db_embedding_str)
if len(db_embedding) != len(query_embedding):
logger.warning(
f"Embedding length mismatch for knowledge ID {knowledge_item.id if hasattr(knowledge_item, 'id') else 'N/A'}. Skipping."
)
continue
# Calculate Cosine Similarity
dot_product = sum(q * d for q, d in zip(query_embedding, db_embedding))
db_embedding_magnitude = math.sqrt(sum(x * x for x in db_embedding))
if db_embedding_magnitude == 0: # Avoid division by zero
similarity = 0.0
else:
similarity = dot_product / (query_embedding_magnitude * db_embedding_magnitude)
if similarity >= threshold:
results_with_similarity.append({"content": knowledge_item.content, "similarity": similarity})
except json.JSONDecodeError:
logger.error(
f"Failed to parse embedding for knowledge ID {knowledge_item.id if hasattr(knowledge_item, 'id') else 'N/A'}"
)
except Exception as e:
logger.error(f"Error processing knowledge item: {e}")
# Sort by similarity in descending order
results_with_similarity.sort(key=lambda x: x["similarity"], reverse=True)
# Limit results
limited_results = results_with_similarity[:limit]
logger.debug(f"知识库查询结果数量 (after Peewee processing): {len(limited_results)}")
if not limited_results:
return "" if not return_raw else [] return "" if not return_raw else []
if return_raw: if return_raw:
return results return limited_results
else: else:
# 返回所有找到的内容,用换行分隔 return "\n".join(str(result["content"]) for result in limited_results)
return "\n".join(str(result["content"]) for result in results)
except Exception as e:
def weighted_sample_no_replacement(items, weights, k) -> list: logger.error(f"Error querying Knowledges with Peewee: {e}")
""" return "" if not return_raw else []
加权且不放回地随机抽取k个元素。
参数:
items: 待抽取的元素列表
weights: 每个元素对应的权重与items等长且为正数
k: 需要抽取的元素个数
返回:
selected: 按权重加权且不重复抽取的k个元素组成的列表
如果items中的元素不足k就只会返回所有可用的元素
实现思路:
每次从当前池中按权重加权随机选出一个元素选中后将其从池中移除重复k次。
这样保证了:
1. count越大被选中概率越高
2. 不会重复选中同一个元素
"""
selected = []
pool = list(zip(items, weights))
for _ in range(min(k, len(pool))):
total = sum(w for _, w in pool)
r = random.uniform(0, total)
upto = 0
for idx, (item, weight) in enumerate(pool):
upto += weight
if upto >= r:
selected.append(item)
pool.pop(idx)
break
return selected
init_prompt() init_prompt()

View File

@@ -17,6 +17,7 @@ class InfoBase:
type: str = "base" type: str = "base"
data: Dict[str, Any] = field(default_factory=dict) data: Dict[str, Any] = field(default_factory=dict)
processed_info: str = ""
def get_type(self) -> str: def get_type(self) -> str:
"""获取信息类型 """获取信息类型
@@ -58,3 +59,11 @@ class InfoBase:
if isinstance(value, list): if isinstance(value, list):
return value return value
return [] return []
def get_processed_info(self) -> str:
"""获取处理后的信息
Returns:
str: 处理后的信息字符串
"""
return self.processed_info

View File

@@ -0,0 +1,40 @@
from dataclasses import dataclass
from .info_base import InfoBase
@dataclass
class SelfInfo(InfoBase):
    """Info container for the bot's self-state description.

    Attributes:
        type (str): info type identifier, defaults to "self".
    """

    type: str = "self"

    def get_self_info(self) -> str:
        """Return the stored self-state text, or "" when unset."""
        return self.get_info("self_info") or ""

    def set_self_info(self, self_info: str) -> None:
        """Store the self-state text.

        Args:
            self_info: the self-state text to store.
        """
        self.data["self_info"] = self_info

    def get_processed_info(self) -> str:
        """Return the processed info (the stored self-state text)."""
        return self.get_self_info()

View File

@@ -0,0 +1,89 @@
from typing import Dict, Optional, List
from dataclasses import dataclass
from .info_base import InfoBase
@dataclass
class WorkingMemoryInfo(InfoBase):
    """Info container for working-memory entries gathered during focus chat.

    Stores a list of memory strings under data["working_memory"] plus an
    optional "talking_message", and renders the memories into a
    newline-terminated string for downstream prompt building.

    Attributes:
        type (str): info type identifier, defaults to "workingmemory".
        processed_info (str): cache of the last rendered memory string.
    """

    type: str = "workingmemory"
    processed_info: str = ""

    def set_talking_message(self, message: str) -> None:
        """Set the talking message.

        Args:
            message (str): talking message content.
        """
        self.data["talking_message"] = message

    def set_working_memory(self, working_memory: List[str]) -> None:
        """Replace the whole working-memory list.

        Args:
            working_memory (List[str]): working-memory entries.
        """
        self.data["working_memory"] = working_memory

    def add_working_memory(self, working_memory: str) -> None:
        """Append one entry to the working-memory list, creating it if needed.

        Args:
            working_memory (str): working-memory content.
        """
        self.data.setdefault("working_memory", []).append(working_memory)

    def get_working_memory(self) -> List[str]:
        """Return the working-memory entries (empty list when unset)."""
        return self.data.get("working_memory", [])

    def get_type(self) -> str:
        """Return the info type identifier."""
        return self.type

    def get_data(self) -> Dict[str, str]:
        """Return the underlying data dict."""
        return self.data

    def get_info(self, key: str) -> Optional[str]:
        """Return the value stored under *key*, or None if absent.

        Args:
            key: attribute key to look up.
        """
        return self.data.get(key)

    def get_processed_info(self) -> str:
        """Render all working-memory entries as a newline-terminated string.

        Returns:
            str: one entry per line (trailing newline per entry); the
            result is also cached on self.processed_info.
        """
        # Fix: the previous annotation claimed Dict[str, str], but a str
        # has always been returned; "".join avoids quadratic += growth.
        all_memory = self.get_working_memory()
        self.processed_info = "".join(f"{memory}\n" for memory in all_memory)
        return self.processed_info

View File

@@ -27,7 +27,7 @@ class ChattingInfoProcessor(BaseProcessor):
"""初始化观察处理器""" """初始化观察处理器"""
super().__init__() super().__init__()
# TODO: API-Adapter修改标记 # TODO: API-Adapter修改标记
self.llm_summary = LLMRequest( self.model_summary = LLMRequest(
model=global_config.model.observation, temperature=0.7, max_tokens=300, request_type="chat_observation" model=global_config.model.observation, temperature=0.7, max_tokens=300, request_type="chat_observation"
) )
@@ -55,6 +55,8 @@ class ChattingInfoProcessor(BaseProcessor):
for obs in observations: for obs in observations:
# print(f"obs: {obs}") # print(f"obs: {obs}")
if isinstance(obs, ChattingObservation): if isinstance(obs, ChattingObservation):
# print("1111111111111111111111读取111111111111111")
obs_info = ObsInfo() obs_info = ObsInfo()
await self.chat_compress(obs) await self.chat_compress(obs)
@@ -92,7 +94,7 @@ class ChattingInfoProcessor(BaseProcessor):
async def chat_compress(self, obs: ChattingObservation): async def chat_compress(self, obs: ChattingObservation):
if obs.compressor_prompt: if obs.compressor_prompt:
try: try:
summary_result, _, _ = await self.llm_summary.generate_response(obs.compressor_prompt) summary_result, _, _ = await self.model_summary.generate_response(obs.compressor_prompt)
summary = "没有主题的闲聊" # 默认值 summary = "没有主题的闲聊" # 默认值
if summary_result: # 确保结果不为空 if summary_result: # 确保结果不为空
summary = summary_result summary = summary_result

View File

@@ -6,21 +6,14 @@ import time
import traceback import traceback
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
from src.individuality.individuality import Individuality from src.individuality.individuality import Individuality
import random
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.json_utils import safe_json_dumps from src.chat.utils.json_utils import safe_json_dumps
from src.chat.message_receive.chat_stream import chat_manager from src.chat.message_receive.chat_stream import chat_manager
import difflib
from src.chat.person_info.relationship_manager import relationship_manager from src.chat.person_info.relationship_manager import relationship_manager
from .base_processor import BaseProcessor from .base_processor import BaseProcessor
from src.chat.focus_chat.info.mind_info import MindInfo from src.chat.focus_chat.info.mind_info import MindInfo
from typing import List, Optional from typing import List, Optional
from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation
from src.chat.focus_chat.info_processors.processor_utils import (
calculate_similarity,
calculate_replacement_probability,
get_spark,
)
from typing import Dict from typing import Dict
from src.chat.focus_chat.info.info_base import InfoBase from src.chat.focus_chat.info.info_base import InfoBase
@@ -28,7 +21,6 @@ logger = get_logger("processor")
def init_prompt(): def init_prompt():
# --- Group Chat Prompt ---
group_prompt = """ group_prompt = """
你的名字是{bot_name} 你的名字是{bot_name}
{memory_str} {memory_str}
@@ -44,31 +36,29 @@ def init_prompt():
现在请你继续输出观察和规划,输出要求: 现在请你继续输出观察和规划,输出要求:
1. 先关注未读新消息的内容和近期回复历史 1. 先关注未读新消息的内容和近期回复历史
2. 根据新信息,修改和删除之前的观察和规划 2. 根据新信息,修改和删除之前的观察和规划
3. 根据聊天内容继续输出观察和规划{hf_do_next} 3. 根据聊天内容继续输出观察和规划
4. 注意群聊的时间线索,话题由谁发起,进展状况如何,思考聊天的时间线。 4. 注意群聊的时间线索,话题由谁发起,进展状况如何,思考聊天的时间线。
6. 语言简洁自然,不要分点,不要浮夸,不要修辞,仅输出思考内容就好""" 6. 语言简洁自然,不要分点,不要浮夸,不要修辞,仅输出思考内容就好"""
Prompt(group_prompt, "sub_heartflow_prompt_before") Prompt(group_prompt, "sub_heartflow_prompt_before")
# --- Private Chat Prompt ---
private_prompt = """ private_prompt = """
你的名字是{bot_name}
{memory_str} {memory_str}
{extra_info} {extra_info}
{relation_prompt} {relation_prompt}
你的名字是{bot_name},{prompt_personality},你现在{mood_info}
{cycle_info_block} {cycle_info_block}
现在是{time_now},你正在上网,和 {chat_target_name} 私聊,以下是你们的聊天内容: 现在是{time_now},你正在上网,和qq群里的网友们聊天以下是正在进行的聊天内容:
{chat_observe_info} {chat_observe_info}
以下是你之前对聊天的观察和规划:
以下是你之前对聊天的观察和规划,你的名字是{bot_name}
{last_mind} {last_mind}
请仔细阅读聊天内容,想想你和 {chat_target_name} 的关系,回顾你们刚刚的交流,你刚刚发言和对方的反应,思考聊天的主题。
请思考你要不要回复以及如何回复对方。 现在请你继续输出观察和规划,输出要求:
思考并输出你的内心想法 1. 先关注未读新消息的内容和近期回复历史
输出要求: 2. 根据新信息,修改和删除之前的观察和规划
1. 根据聊天内容生成你的想法,{hf_do_next} 3. 根据聊天内容继续输出观察和规划
2. 不要分点、不要使用表情符号 4. 注意群聊的时间线索,话题由谁发起,进展状况如何,思考聊天的时间线。
3. 避免多余符号(冒号、引号、括号等) 6. 语言简洁自然,不要分点,不要浮夸,不要修辞,仅输出思考内容就好"""
4. 语言简洁自然,不要浮夸
5. 如果你刚发言,对方没有回复你,请谨慎回复"""
Prompt(private_prompt, "sub_heartflow_prompt_private_before") Prompt(private_prompt, "sub_heartflow_prompt_private_before")
@@ -210,45 +200,26 @@ class MindProcessor(BaseProcessor):
for person in person_list: for person in person_list:
relation_prompt += await relationship_manager.build_relationship_info(person, is_id=True) relation_prompt += await relationship_manager.build_relationship_info(person, is_id=True)
# 构建个性部分
# prompt_personality = individuality.get_prompt(x_person=2, level=2)
# 获取当前时间
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
spark_prompt = get_spark()
# ---------- 5. 构建最终提示词 ----------
template_name = "sub_heartflow_prompt_before" if is_group_chat else "sub_heartflow_prompt_private_before" template_name = "sub_heartflow_prompt_before" if is_group_chat else "sub_heartflow_prompt_private_before"
logger.debug(f"{self.log_prefix} 使用{'群聊' if is_group_chat else '私聊'}思考模板") logger.debug(f"{self.log_prefix} 使用{'群聊' if is_group_chat else '私聊'}思考模板")
prompt = (await global_prompt_manager.get_prompt_async(template_name)).format( prompt = (await global_prompt_manager.get_prompt_async(template_name)).format(
bot_name=individuality.name,
memory_str=memory_str, memory_str=memory_str,
extra_info=self.structured_info_str, extra_info=self.structured_info_str,
# prompt_personality=prompt_personality,
relation_prompt=relation_prompt, relation_prompt=relation_prompt,
bot_name=individuality.name, time_now=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
time_now=time_now,
chat_observe_info=chat_observe_info, chat_observe_info=chat_observe_info,
# mood_info="mood_info",
hf_do_next=spark_prompt,
last_mind=previous_mind, last_mind=previous_mind,
cycle_info_block=hfcloop_observe_info, cycle_info_block=hfcloop_observe_info,
chat_target_name=chat_target_name, chat_target_name=chat_target_name,
) )
# 在构建完提示词后生成最终的prompt字符串 content = "(不知道该想些什么...)"
final_prompt = prompt
content = "" # 初始化内容变量
try: try:
# 调用LLM生成响应 content, _ = await self.llm_model.generate_response_async(prompt=prompt)
response, _ = await self.llm_model.generate_response_async(prompt=final_prompt) if not content:
logger.warning(f"{self.log_prefix} LLM返回空结果思考失败。")
# 直接使用LLM返回的文本响应作为 content
content = response if response else ""
except Exception as e: except Exception as e:
# 处理总体异常 # 处理总体异常
logger.error(f"{self.log_prefix} 执行LLM请求或处理响应时出错: {e}") logger.error(f"{self.log_prefix} 执行LLM请求或处理响应时出错: {e}")
@@ -256,16 +227,8 @@ class MindProcessor(BaseProcessor):
content = "思考过程中出现错误" content = "思考过程中出现错误"
# 记录初步思考结果 # 记录初步思考结果
logger.debug(f"{self.log_prefix} 思考prompt: \n{final_prompt}\n") logger.debug(f"{self.log_prefix} 思考prompt: \n{prompt}\n")
# 处理空响应情况
if not content:
content = "(不知道该想些什么...)"
logger.warning(f"{self.log_prefix} LLM返回空结果思考失败。")
# ---------- 8. 更新思考状态并返回结果 ----------
logger.info(f"{self.log_prefix} 思考结果: {content}") logger.info(f"{self.log_prefix} 思考结果: {content}")
# 更新当前思考内容
self.update_current_mind(content) self.update_current_mind(content)
return content return content
@@ -275,138 +238,5 @@ class MindProcessor(BaseProcessor):
self.past_mind.append(self.current_mind) self.past_mind.append(self.current_mind)
self.current_mind = response self.current_mind = response
def de_similar(self, previous_mind, new_content):
    """Probabilistically de-duplicate ``new_content`` against ``previous_mind``.

    Computes the similarity between the previous and the new thought; with a
    probability derived from that similarity, either (a) truncates and prefixes
    the text when the two are identical, or (b) strips the overlapping spans
    via difflib and optionally prepends an interjection / transition word.

    Args:
        previous_mind: The previous thought text.
        new_content: The newly generated thought text.

    Returns:
        str: The (possibly rewritten) thought text. On any internal error the
        original ``new_content`` is returned unchanged.
    """
    # Default to the untouched new content. This also fixes a latent bug:
    # when the probabilistic branch was skipped, `content` was never assigned
    # and `return content` raised NameError outside the try block.
    content = new_content
    try:
        similarity = calculate_similarity(previous_mind, new_content)
        replacement_prob = calculate_replacement_probability(similarity)
        logger.debug(f"{self.log_prefix} 新旧想法相似度: {similarity:.2f}, 替换概率: {replacement_prob:.2f}")

        # Word lists used to glue the rewritten text together.
        # NOTE(review): the interjection list renders as six empty strings here —
        # this looks like single CJK particles were stripped during extraction;
        # confirm the original characters against the source file.
        yu_qi_ci_liebiao = ["", "", "", "", "", ""]
        zhuan_zhe_liebiao = ["但是", "不过", "然而", "可是", "只是"]
        cheng_jie_liebiao = ["然后", "接着", "此外", "而且", "另外"]
        zhuan_jie_ci_liebiao = zhuan_zhe_liebiao + cheng_jie_liebiao

        if random.random() < replacement_prob:
            if similarity == 1.0:
                # Exact repeat: keep roughly half of the text and prefix it
                # with an interjection plus a transition word.
                logger.debug(f"{self.log_prefix} 想法完全重复 (相似度 1.0),执行特殊处理...")
                if len(new_content) > 1:
                    split_point = max(
                        1, len(new_content) // 2 + random.randint(-len(new_content) // 4, len(new_content) // 4)
                    )
                    truncated_content = new_content[:split_point]
                else:
                    # Too short to truncate meaningfully.
                    truncated_content = new_content
                yu_qi_ci = random.choice(yu_qi_ci_liebiao)
                zhuan_jie_ci = random.choice(zhuan_jie_ci_liebiao)
                content = f"{yu_qi_ci}{zhuan_jie_ci}{truncated_content}"
                logger.debug(f"{self.log_prefix} 想法重复,特殊处理后: {content}")
            else:
                # High-but-not-exact similarity: remove the spans of new_content
                # that match previous_mind, keeping only the novel parts.
                logger.debug(f"{self.log_prefix} 执行概率性去重 (概率: {replacement_prob:.2f})...")
                logger.debug(
                    f"{self.log_prefix} previous_mind类型: {type(previous_mind)}, new_content类型: {type(new_content)}"
                )
                matcher = difflib.SequenceMatcher(None, previous_mind, new_content)
                logger.debug(f"{self.log_prefix} matcher类型: {type(matcher)}")
                deduplicated_parts = []
                last_match_end_in_b = 0

                # get_matching_blocks() yields (i, j, n): i indexes previous_mind,
                # j indexes new_content, n is the match length.
                matching_blocks = matcher.get_matching_blocks()
                logger.debug(f"{self.log_prefix} 匹配块数量: {len(matching_blocks)}")
                logger.debug(
                    f"{self.log_prefix} 匹配块示例(前3个): {matching_blocks[:3] if len(matching_blocks) > 3 else matching_blocks}"
                )
                for idx, match in enumerate(matching_blocks):
                    if not isinstance(match, tuple):
                        logger.error(f"{self.log_prefix} 匹配块 {idx} 不是元组类型,而是 {type(match)}: {match}")
                        continue
                    try:
                        _i, j, n = match
                        logger.debug(f"{self.log_prefix} 匹配块 {idx}: i={_i}, j={j}, n={n}")
                        if last_match_end_in_b < j:
                            # The gap before this match is novel text — keep it.
                            try:
                                non_matching_part = new_content[last_match_end_in_b:j]
                                logger.debug(
                                    f"{self.log_prefix} 添加非匹配部分: '{non_matching_part}', 类型: {type(non_matching_part)}"
                                )
                                if not isinstance(non_matching_part, str):
                                    logger.warning(
                                        f"{self.log_prefix} 非匹配部分不是字符串类型: {type(non_matching_part)}"
                                    )
                                    non_matching_part = str(non_matching_part)
                                deduplicated_parts.append(non_matching_part)
                            except Exception as e:
                                logger.error(f"{self.log_prefix} 处理非匹配部分时出错: {e}")
                                logger.error(traceback.format_exc())
                        last_match_end_in_b = j + n
                    except Exception as e:
                        logger.error(f"{self.log_prefix} 处理匹配块时出错: {e}")
                        logger.error(traceback.format_exc())

                logger.debug(f"{self.log_prefix} 去重前部分列表: {deduplicated_parts}")
                logger.debug(f"{self.log_prefix} 列表元素类型: {[type(part) for part in deduplicated_parts]}")

                # Defensive normalization: everything must be a string before join.
                deduplicated_parts = [str(part) for part in deduplicated_parts]
                if not deduplicated_parts:
                    logger.warning(f"{self.log_prefix} 去重后列表为空,添加空字符串")
                    deduplicated_parts = [""]
                logger.debug(f"{self.log_prefix} 处理后的部分列表: {deduplicated_parts}")

                try:
                    deduplicated_content = "".join(deduplicated_parts).strip()
                    logger.debug(f"{self.log_prefix} 拼接后的去重内容: '{deduplicated_content}'")
                except Exception as e:
                    logger.error(f"{self.log_prefix} 拼接去重内容时出错: {e}")
                    logger.error(traceback.format_exc())
                    deduplicated_content = ""

                if deduplicated_content:
                    # Optionally prepend an interjection (30%) and/or a
                    # transition word (70%) to make the残句 read naturally.
                    prefix_str = ""
                    if random.random() < 0.3:
                        prefix_str += random.choice(yu_qi_ci_liebiao)
                    if random.random() < 0.7:
                        prefix_str += random.choice(zhuan_jie_ci_liebiao)
                    if prefix_str:
                        content = f"{prefix_str}{deduplicated_content}"
                        logger.debug(f"{self.log_prefix} 去重并添加引导词后: {content}")
                    else:
                        content = deduplicated_content
                        logger.debug(f"{self.log_prefix} 去重后 (未添加引导词): {content}")
                else:
                    logger.warning(f"{self.log_prefix} 去重后内容为空保留原始LLM输出: {new_content}")
                    content = new_content
        else:
            # Dice roll failed — keep new_content as-is (content already set).
            logger.debug(f"{self.log_prefix} 未执行概率性去重 (概率: {replacement_prob:.2f})")
    except Exception as e:
        logger.error(f"{self.log_prefix} 应用概率性去重或特殊处理时出错: {e}")
        logger.error(traceback.format_exc())
        # Fall back to the original content on any failure.
        content = new_content
    return content
init_prompt() init_prompt()

View File

@@ -1,56 +0,0 @@
import difflib
import random
import time
def calculate_similarity(text_a: str, text_b: str) -> float:
    """Return the difflib ratio of the two texts; 0.0 if either is empty."""
    if text_a and text_b:
        return difflib.SequenceMatcher(None, text_a, text_b).ratio()
    return 0.0
def calculate_replacement_probability(similarity: float) -> float:
    """Map a similarity score onto a replacement probability.

    Piecewise rules:
      * similarity <= 0.4        -> 0.0
      * similarity >= 0.9        -> 1.0
      * 0.4 < similarity <= 0.6  -> linear from (0.4, 0) to (0.6, 0.7)
      * 0.6 < similarity < 0.9   -> linear from (0.6, 0.7) to (0.9, 1.0)
    """
    if similarity >= 0.9:
        return 1.0
    if similarity <= 0.4:
        return 0.0
    if similarity <= 0.6:
        # Segment (0.4, 0) -> (0.6, 0.7): p = 3.5 * s - 1.4
        return max(0.0, 3.5 * similarity - 1.4)
    # Segment (0.6, 0.7) -> (0.9, 1.0): p = s + 0.1
    return min(1.0, max(0.0, similarity + 0.1))
def get_spark():
    """Pick a thinking-guidance phrase, deterministic within one clock minute.

    The RNG is seeded with the current minute, so repeated calls inside the
    same minute return the same weighted choice.
    """
    rng = random.Random(int(time.strftime("%M")))
    hf_options = [
        ("可以参考之前的想法,在原来想法的基础上继续思考", 0.2),
        ("可以参考之前的想法,在原来的想法上尝试新的话题", 0.4),
        ("不要太深入", 0.2),
        ("进行深入思考", 0.2),
    ]
    phrases = [phrase for phrase, _ in hf_options]
    weights = [weight for _, weight in hf_options]
    return rng.choices(phrases, weights=weights, k=1)[0]

View File

@@ -0,0 +1,161 @@
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
from src.chat.heart_flow.observation.observation import Observation
from src.chat.models.utils_model import LLMRequest
from src.config.config import global_config
import time
import traceback
from src.common.logger_manager import get_logger
from src.individuality.individuality import Individuality
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import chat_manager
from src.chat.person_info.relationship_manager import relationship_manager
from .base_processor import BaseProcessor
from typing import List, Optional
from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation
from typing import Dict
from src.chat.focus_chat.info.info_base import InfoBase
from src.chat.focus_chat.info.self_info import SelfInfo
logger = get_logger("processor")
def init_prompt():
    """Register the self-identity prompt template with the global prompt manager.

    The template is stored under the key "indentify_prompt" (sic — the key is
    misspelled, and SelfProcessor.self_indentify looks it up by this exact
    string, so it must not be "fixed" in isolation). Placeholders are filled
    via str.format() at request time.
    """
    # Runtime string literal — must stay byte-identical; {bot_name} etc. are
    # format placeholders resolved by the caller.
    indentify_prompt = """
你的名字是{bot_name},你是一个AI但是你伪装成了一个人类你的人格是{prompt_personality}
你对外的形象是一只橙色的鱼,头上有绿色的树叶,你用的头像也是这个。
{relation_prompt}
{memory_str}
现在是{time_now}你正在上网和qq群里的网友们聊天以下是正在进行的聊天内容
{chat_observe_info}
现在请你根据现有的信息,思考自我认同
1. 你是一个什么样的人,你和群里的人关系如何
2. 思考有没有人提到你,或者图片与你有关
3. 你的自我认同是否有助于你的回答,如果你需要自我相关的信息来帮你参与聊天,请输出,否则请输出十个字以内的简短自我认同
4. 一般情况下不用输出自我认同,只需要输出十几个字的简短自我认同就好,除非有明显需要自我认同的场景
"""
    Prompt(indentify_prompt, "indentify_prompt")
class SelfProcessor(BaseProcessor):
    """Processor that reflects on the bot's self-identity for a sub-heartflow.

    Builds an identity prompt from the current chat observation, recalled
    working memories and relationship info, queries the LLM, and wraps the
    resulting text in a SelfInfo for downstream consumers.
    """

    log_prefix = "自我认同"

    def __init__(self, subheartflow_id: str):
        super().__init__()
        self.subheartflow_id = subheartflow_id
        # LLM used for identity reflection; model/temperature come from global config.
        self.llm_model = LLMRequest(
            model=global_config.model.sub_heartflow,
            temperature=global_config.model.sub_heartflow["temp"],
            max_tokens=800,
            request_type="self_identify",
        )
        name = chat_manager.get_stream_name(self.subheartflow_id)
        self.log_prefix = f"[{name}] "

    async def process_info(
        self, observations: Optional[List[Observation]] = None, running_memorys: Optional[List[Dict]] = None, *infos
    ) -> List[InfoBase]:
        """Produce a SelfInfo from the current observations.

        Args:
            observations: Observation objects for this cycle.
            running_memorys: Recalled working-memory dicts ({"topic", "content"}).
            *infos: Unused; kept for interface compatibility with BaseProcessor.

        Returns:
            List[InfoBase]: [SelfInfo] when identity text was produced, else []
            (previously returned None, which broke callers iterating the result).
        """
        self_info_str = await self.self_indentify(observations, running_memorys)
        if not self_info_str:
            return []
        self_info = SelfInfo()
        self_info.set_self_info(self_info_str)
        return [self_info]

    async def self_indentify(
        self, observations: Optional[List[Observation]] = None, running_memorys: Optional[List[Dict]] = None
    ):
        """Build the identity prompt and query the LLM.

        Args:
            observations: Observation objects; only ChattingObservation is used.
            running_memorys: Recalled working-memory dicts to surface in the prompt.

        Returns:
            str: The LLM's identity text; "" when the model returned nothing
            (or the literal string "None"), an error message string on failure.
        """
        memory_str = ""
        if running_memorys:
            memory_str = "以下是当前在聊天中,你回忆起的记忆:\n"
            for running_memory in running_memorys:
                memory_str += f"{running_memory['topic']}: {running_memory['content']}\n"

        if observations is None:
            observations = []

        # Defaults so the prompt can still be formatted when no
        # ChattingObservation is present (these were previously unbound).
        chat_observe_info = ""
        person_list = []
        for observation in observations:
            if isinstance(observation, ChattingObservation):
                chat_observe_info = observation.get_observe_info()
                person_list = observation.person_list

        individuality = Individuality.get_instance()
        personality_block = individuality.get_prompt(x_person=2, level=2)

        relation_prompt = ""
        for person in person_list:
            relation_prompt += await relationship_manager.build_relationship_info(person, is_id=True)

        prompt = (await global_prompt_manager.get_prompt_async("indentify_prompt")).format(
            bot_name=individuality.name,
            prompt_personality=personality_block,
            memory_str=memory_str,
            relation_prompt=relation_prompt,
            time_now=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
            chat_observe_info=chat_observe_info,
        )

        content = ""
        try:
            content, _ = await self.llm_model.generate_response_async(prompt=prompt)
            if not content:
                logger.warning(f"{self.log_prefix} LLM返回空结果自我识别失败。")
        except Exception as e:
            logger.error(f"{self.log_prefix} 执行LLM请求或处理响应时出错: {e}")
            logger.error(traceback.format_exc())
            content = "自我识别过程中出现错误"

        # The prompt invites a "None"-style answer; normalize it to empty.
        if content == "None":
            content = ""

        logger.debug(f"{self.log_prefix} 自我识别prompt: \n{prompt}\n")
        logger.info(f"{self.log_prefix} 自我识别结果: {content}")
        return content
init_prompt()

View File

@@ -11,8 +11,8 @@ from src.chat.person_info.relationship_manager import relationship_manager
from .base_processor import BaseProcessor from .base_processor import BaseProcessor
from typing import List, Optional, Dict from typing import List, Optional, Dict
from src.chat.heart_flow.observation.observation import Observation from src.chat.heart_flow.observation.observation import Observation
from src.chat.heart_flow.observation.working_observation import WorkingObservation
from src.chat.focus_chat.info.structured_info import StructuredInfo from src.chat.focus_chat.info.structured_info import StructuredInfo
from src.chat.heart_flow.observation.structure_observation import StructureObservation
logger = get_logger("processor") logger = get_logger("processor")
@@ -24,9 +24,6 @@ def init_prompt():
tool_executor_prompt = """ tool_executor_prompt = """
你是一个专门执行工具的助手。你的名字是{bot_name}。现在是{time_now} 你是一个专门执行工具的助手。你的名字是{bot_name}。现在是{time_now}
你要在群聊中扮演以下角色:
{prompt_personality}
你当前的额外信息: 你当前的额外信息:
{memory_str} {memory_str}
@@ -70,6 +67,8 @@ class ToolProcessor(BaseProcessor):
list: 处理后的结构化信息列表 list: 处理后的结构化信息列表
""" """
working_infos = []
if observations: if observations:
for observation in observations: for observation in observations:
if isinstance(observation, ChattingObservation): if isinstance(observation, ChattingObservation):
@@ -77,7 +76,7 @@ class ToolProcessor(BaseProcessor):
# 更新WorkingObservation中的结构化信息 # 更新WorkingObservation中的结构化信息
for observation in observations: for observation in observations:
if isinstance(observation, WorkingObservation): if isinstance(observation, StructureObservation):
for structured_info in result: for structured_info in result:
logger.debug(f"{self.log_prefix} 更新WorkingObservation中的结构化信息: {structured_info}") logger.debug(f"{self.log_prefix} 更新WorkingObservation中的结构化信息: {structured_info}")
observation.add_structured_info(structured_info) observation.add_structured_info(structured_info)
@@ -86,6 +85,7 @@ class ToolProcessor(BaseProcessor):
logger.debug(f"{self.log_prefix} 获取更新后WorkingObservation中的结构化信息: {working_infos}") logger.debug(f"{self.log_prefix} 获取更新后WorkingObservation中的结构化信息: {working_infos}")
structured_info = StructuredInfo() structured_info = StructuredInfo()
if working_infos:
for working_info in working_infos: for working_info in working_infos:
structured_info.set_info(working_info.get("type"), working_info.get("content")) structured_info.set_info(working_info.get("type"), working_info.get("content"))
@@ -134,7 +134,7 @@ class ToolProcessor(BaseProcessor):
# 获取个性信息 # 获取个性信息
individuality = Individuality.get_instance() individuality = Individuality.get_instance()
prompt_personality = individuality.get_prompt(x_person=2, level=2) # prompt_personality = individuality.get_prompt(x_person=2, level=2)
# 获取时间信息 # 获取时间信息
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
@@ -148,14 +148,14 @@ class ToolProcessor(BaseProcessor):
# chat_target_name=chat_target_name, # chat_target_name=chat_target_name,
is_group_chat=is_group_chat, is_group_chat=is_group_chat,
# relation_prompt=relation_prompt, # relation_prompt=relation_prompt,
prompt_personality=prompt_personality, # prompt_personality=prompt_personality,
# mood_info=mood_info, # mood_info=mood_info,
bot_name=individuality.name, bot_name=individuality.name,
time_now=time_now, time_now=time_now,
) )
# 调用LLM专注于工具使用 # 调用LLM专注于工具使用
logger.debug(f"开始执行工具调用{prompt}") # logger.debug(f"开始执行工具调用{prompt}")
response, _, tool_calls = await self.llm_model.generate_response_tool_async(prompt=prompt, tools=tools) response, _, tool_calls = await self.llm_model.generate_response_tool_async(prompt=prompt, tools=tools)
logger.debug(f"获取到工具原始输出:\n{tool_calls}") logger.debug(f"获取到工具原始输出:\n{tool_calls}")

View File

@@ -0,0 +1,236 @@
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
from src.chat.heart_flow.observation.observation import Observation
from src.chat.models.utils_model import LLMRequest
from src.config.config import global_config
import time
import traceback
from src.common.logger_manager import get_logger
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import chat_manager
from .base_processor import BaseProcessor
from src.chat.focus_chat.info.mind_info import MindInfo
from typing import List, Optional
from src.chat.heart_flow.observation.working_observation import WorkingMemoryObservation
from src.chat.focus_chat.working_memory.working_memory import WorkingMemory
from typing import Dict
from src.chat.focus_chat.info.info_base import InfoBase
from json_repair import repair_json
from src.chat.focus_chat.info.workingmemory_info import WorkingMemoryInfo
import asyncio
import json
logger = get_logger("processor")
def init_prompt():
    """Register the working-memory management prompt template.

    Registered under the key "prompt_memory_proces" (sic — callers look it up
    by this exact string). The template asks the LLM to return JSON with
    selected memory ids, a new-memory flag ("true"/"false") and merge pairs.
    """
    # Runtime string literal — must stay byte-identical. Double braces {{ }}
    # escape literal JSON braces for str.format().
    memory_proces_prompt = """
你的名字是{bot_name}
现在是{time_now}你正在上网和qq群里的网友们聊天以下是正在进行的聊天内容
{chat_observe_info}
以下是你已经总结的记忆摘要你可以调取这些记忆查看内容来帮助你聊天不要一次调取太多记忆最多调取3个左右记忆
{memory_str}
观察聊天内容和已经总结的记忆,思考是否有新内容需要总结成记忆,如果有,就输出 true否则输出 false
如果当前聊天记录的内容已经被总结千万不要总结新记忆输出false
如果已经总结的记忆包含了当前聊天记录的内容千万不要总结新记忆输出false
如果已经总结的记忆摘要,包含了当前聊天记录的内容千万不要总结新记忆输出false
如果有相近的记忆请合并记忆输出merge_memory格式为[["id1", "id2"], ["id3", "id4"],...]你可以进行多组合并但是每组合并只能有两个记忆id不要输出其他内容
请根据聊天内容选择你需要调取的记忆并考虑是否添加新记忆以JSON格式输出格式如下
```json
{{
"selected_memory_ids": ["id1", "id2", ...],
"new_memory": "true" or "false",
"merge_memory": [["id1", "id2"], ["id3", "id4"],...]
}}
```
"""
    Prompt(memory_proces_prompt, "prompt_memory_proces")
class WorkingMemoryProcessor(BaseProcessor):
    """Processor that manages per-chat working memory via an LLM.

    Each cycle it shows the LLM the current chat plus summaries of existing
    memories, then acts on the LLM's JSON decision: retrieve selected
    memories, create a new memory from the chat, and/or merge memory pairs
    (creation and merging run as background tasks).
    """

    log_prefix = "工作记忆"

    def __init__(self, subheartflow_id: str):
        super().__init__()
        self.subheartflow_id = subheartflow_id
        # LLM used for memory selection/summarization decisions.
        self.llm_model = LLMRequest(
            model=global_config.model.sub_heartflow,
            temperature=global_config.model.sub_heartflow["temp"],
            max_tokens=800,
            request_type="working_memory",
        )
        name = chat_manager.get_stream_name(self.subheartflow_id)
        self.log_prefix = f"[{name}] "

    async def process_info(
        self, observations: Optional[List[Observation]] = None, running_memorys: Optional[List[Dict]] = None, *infos
    ) -> List[InfoBase]:
        """Run one working-memory cycle over the given observations.

        Args:
            observations: Observation objects; WorkingMemoryObservation supplies
                the memory store, ChattingObservation supplies the chat text.
            running_memorys: Unused here; kept for interface compatibility.
            *infos: Unused; kept for interface compatibility.

        Returns:
            List[InfoBase]: [WorkingMemoryInfo] on success, [] on parse/LLM
            failure. When no memory store is found, returns [MindInfo()].
        """
        working_memory = None
        chat_info = ""
        try:
            for observation in observations:
                if isinstance(observation, WorkingMemoryObservation):
                    working_memory = observation.get_observe_info()
                if isinstance(observation, ChattingObservation):
                    chat_info = observation.get_observe_info()
            if not working_memory:
                logger.warning(f"{self.log_prefix} 没有找到工作记忆对象")
                # NOTE(review): returning an empty MindInfo mirrors the original
                # behavior, but an empty list may be more appropriate — confirm
                # what downstream consumers expect from this processor.
                mind_info = MindInfo()
                return [mind_info]
        except Exception as e:
            logger.error(f"{self.log_prefix} 处理观察时出错: {e}")
            logger.error(traceback.format_exc())
            return []

        # Offer the LLM a menu of existing memory summaries (id + brief).
        memory_prompts = []
        for memory in working_memory.get_all_memories():
            memory_prompts.append(f"记忆id:{memory.id},记忆摘要:{memory.summary.get('brief')}\n")
        memory_choose_str = "".join(memory_prompts)

        prompt = (await global_prompt_manager.get_prompt_async("prompt_memory_proces")).format(
            bot_name=global_config.bot.nickname,
            time_now=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
            chat_observe_info=chat_info,
            memory_str=memory_choose_str,
        )

        content = ""
        try:
            logger.debug(f"{self.log_prefix} 处理工作记忆的prompt: {prompt}")
            content, _ = await self.llm_model.generate_response_async(prompt=prompt)
            if not content:
                logger.warning(f"{self.log_prefix} LLM返回空结果处理工作记忆失败。")
        except Exception as e:
            logger.error(f"{self.log_prefix} 执行LLM请求或处理响应时出错: {e}")
            logger.error(traceback.format_exc())

        # Parse the LLM's JSON decision (repair_json tolerates malformed output).
        try:
            result = repair_json(content)
            if isinstance(result, str):
                result = json.loads(result)
            if not isinstance(result, dict):
                logger.error(f"{self.log_prefix} 解析LLM返回的JSON失败结果不是字典类型: {type(result)}")
                return []
            selected_memory_ids = result.get("selected_memory_ids", [])
            new_memory = result.get("new_memory", "")
            merge_memory = result.get("merge_memory", [])
        except Exception as e:
            logger.error(f"{self.log_prefix} 解析LLM返回的JSON失败: {e}")
            logger.error(traceback.format_exc())
            return []

        logger.debug(f"{self.log_prefix} 解析LLM返回的JSON成功: {result}")

        # Retrieve the selected memories and flatten keypoints/events into text.
        memory_str = ""
        for memory_id in selected_memory_ids:
            memory = await working_memory.retrieve_memory(memory_id)
            if memory:
                # `or []` guards against a summary missing these keys.
                for keypoint in memory.summary.get("keypoints") or []:
                    memory_str += f"记忆要点:{keypoint}\n"
                for event in memory.summary.get("events") or []:
                    memory_str += f"记忆事件:{event}\n"

        working_memory_info = WorkingMemoryInfo()
        if memory_str:
            working_memory_info.add_working_memory(memory_str)
            logger.debug(f"{self.log_prefix} 取得工作记忆: {memory_str}")
        else:
            logger.warning(f"{self.log_prefix} 没有找到工作记忆")

        # FIX: new_memory arrives as the literal string "true"/"false" (or a
        # bool after JSON repair). The old `if new_memory:` was truthy for
        # "false" and created a memory on every cycle.
        if new_memory is True or str(new_memory).strip().lower() == "true":
            logger.debug(f"{self.log_prefix} {new_memory}新记忆: ")
            # Fire-and-forget so summarization does not block the cycle.
            asyncio.create_task(self.add_memory_async(working_memory, chat_info))

        if merge_memory:
            for merge_pairs in merge_memory:
                memory1 = await working_memory.retrieve_memory(merge_pairs[0])
                memory2 = await working_memory.retrieve_memory(merge_pairs[1])
                # Only merge when both ids resolve; merging runs in background.
                if memory1 and memory2:
                    asyncio.create_task(self.merge_memory_async(working_memory, merge_pairs[0], merge_pairs[1]))

        return [working_memory_info]

    async def add_memory_async(self, working_memory: WorkingMemory, content: str):
        """Add a new memory in the background without blocking the main flow.

        Args:
            working_memory: The working-memory store.
            content: The chat text to summarize into a memory.
        """
        try:
            await working_memory.add_memory(content=content, from_source="chat_text")
            logger.debug(f"{self.log_prefix} 异步添加新记忆成功: {content[:30]}...")
        except Exception as e:
            logger.error(f"{self.log_prefix} 异步添加新记忆失败: {e}")
            logger.error(traceback.format_exc())

    async def merge_memory_async(self, working_memory: WorkingMemory, memory_id1: str, memory_id2: str):
        """Merge two memories in the background without blocking the main flow.

        Args:
            working_memory: The working-memory store.
            memory_id1: Id of the first memory to merge.
            memory_id2: Id of the second memory to merge.
        """
        try:
            merged_memory = await working_memory.merge_memory(memory_id1, memory_id2)
            logger.debug(f"{self.log_prefix} 异步合并记忆成功: {memory_id1}{memory_id2}...")
            logger.debug(f"{self.log_prefix} 合并后的记忆梗概: {merged_memory.summary.get('brief')}")
            logger.debug(f"{self.log_prefix} 合并后的记忆详情: {merged_memory.summary.get('detailed')}")
            logger.debug(f"{self.log_prefix} 合并后的记忆要点: {merged_memory.summary.get('keypoints')}")
            logger.debug(f"{self.log_prefix} 合并后的记忆事件: {merged_memory.summary.get('events')}")
        except Exception as e:
            logger.error(f"{self.log_prefix} 异步合并记忆失败: {e}")
            logger.error(traceback.format_exc())
init_prompt()

View File

@@ -1,5 +1,5 @@
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
from src.chat.heart_flow.observation.working_observation import WorkingObservation from src.chat.heart_flow.observation.structure_observation import StructureObservation
from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation
from src.chat.models.utils_model import LLMRequest from src.chat.models.utils_model import LLMRequest
from src.config.config import global_config from src.config.config import global_config
@@ -54,7 +54,7 @@ class MemoryActivator:
for observation in observations: for observation in observations:
if isinstance(observation, ChattingObservation): if isinstance(observation, ChattingObservation):
obs_info_text += observation.get_observe_info() obs_info_text += observation.get_observe_info()
elif isinstance(observation, WorkingObservation): elif isinstance(observation, StructureObservation):
working_info = observation.get_observe_info() working_info = observation.get_observe_info()
for working_info_item in working_info: for working_info_item in working_info:
obs_info_text += f"{working_info_item['type']}: {working_info_item['content']}\n" obs_info_text += f"{working_info_item['type']}: {working_info_item['content']}\n"

View File

@@ -5,10 +5,14 @@ from src.chat.focus_chat.expressors.default_expressor import DefaultExpressor
from src.chat.message_receive.chat_stream import ChatStream from src.chat.message_receive.chat_stream import ChatStream
from src.chat.focus_chat.heartFC_Cycleinfo import CycleDetail from src.chat.focus_chat.heartFC_Cycleinfo import CycleDetail
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
import importlib
import pkgutil
import os
# 导入动作类,确保装饰器被执行 # 导入动作类,确保装饰器被执行
import src.chat.focus_chat.planners.actions # noqa
logger = get_logger("action_factory") logger = get_logger("action_manager")
# 定义动作信息类型 # 定义动作信息类型
ActionInfo = Dict[str, Any] ActionInfo = Dict[str, Any]
@@ -34,13 +38,12 @@ class ActionManager:
# 加载所有已注册动作 # 加载所有已注册动作
self._load_registered_actions() self._load_registered_actions()
# 加载插件动作
self._load_plugin_actions()
# 初始化时将默认动作加载到使用中的动作 # 初始化时将默认动作加载到使用中的动作
self._using_actions = self._default_actions.copy() self._using_actions = self._default_actions.copy()
# logger.info(f"当前可用动作: {list(self._using_actions.keys())}")
# for action_name, action_info in self._using_actions.items():
# logger.info(f"动作名称: {action_name}, 动作信息: {action_info}")
def _load_registered_actions(self) -> None: def _load_registered_actions(self) -> None:
""" """
加载所有通过装饰器注册的动作 加载所有通过装饰器注册的动作
@@ -49,6 +52,11 @@ class ActionManager:
# 从_ACTION_REGISTRY获取所有已注册动作 # 从_ACTION_REGISTRY获取所有已注册动作
for action_name, action_class in _ACTION_REGISTRY.items(): for action_name, action_class in _ACTION_REGISTRY.items():
# 获取动作相关信息 # 获取动作相关信息
# 不读取插件动作和基类
if action_name == "base_action" or action_name == "plugin_action":
continue
action_description: str = getattr(action_class, "action_description", "") action_description: str = getattr(action_class, "action_description", "")
action_parameters: dict[str:str] = getattr(action_class, "action_parameters", {}) action_parameters: dict[str:str] = getattr(action_class, "action_parameters", {})
action_require: list[str] = getattr(action_class, "action_require", []) action_require: list[str] = getattr(action_class, "action_require", [])
@@ -62,10 +70,6 @@ class ActionManager:
"require": action_require, "require": action_require,
} }
# 注册2
print("注册2")
print(action_info)
# 添加到所有已注册的动作 # 添加到所有已注册的动作
self._registered_actions[action_name] = action_info self._registered_actions[action_name] = action_info
@@ -73,14 +77,56 @@ class ActionManager:
if is_default: if is_default:
self._default_actions[action_name] = action_info self._default_actions[action_name] = action_info
logger.info(f"所有注册动作: {list(self._registered_actions.keys())}") # logger.info(f"所有注册动作: {list(self._registered_actions.keys())}")
logger.info(f"默认动作: {list(self._default_actions.keys())}") # logger.info(f"默认动作: {list(self._default_actions.keys())}")
# for action_name, action_info in self._default_actions.items(): # for action_name, action_info in self._default_actions.items():
# logger.info(f"动作名称: {action_name}, 动作信息: {action_info}") # logger.info(f"动作名称: {action_name}, 动作信息: {action_info}")
except Exception as e: except Exception as e:
logger.error(f"加载已注册动作失败: {e}") logger.error(f"加载已注册动作失败: {e}")
def _load_plugin_actions(self) -> None:
"""
加载所有插件目录中的动作
"""
try:
# 检查插件目录是否存在
plugin_path = "src.plugins"
plugin_dir = plugin_path.replace(".", os.path.sep)
if not os.path.exists(plugin_dir):
logger.info(f"插件目录 {plugin_dir} 不存在,跳过插件动作加载")
return
# 导入插件包
try:
plugins_package = importlib.import_module(plugin_path)
except ImportError as e:
logger.error(f"导入插件包失败: {e}")
return
# 遍历插件包中的所有子包
for _, plugin_name, is_pkg in pkgutil.iter_modules(
plugins_package.__path__, plugins_package.__name__ + "."
):
if not is_pkg:
continue
# 检查插件是否有actions子包
plugin_actions_path = f"{plugin_name}.actions"
try:
# 尝试导入插件的actions包
importlib.import_module(plugin_actions_path)
logger.info(f"成功加载插件动作模块: {plugin_actions_path}")
except ImportError as e:
logger.debug(f"插件 {plugin_name} 没有actions子包或导入失败: {e}")
continue
# 再次从_ACTION_REGISTRY获取所有动作包括刚刚从插件加载的
self._load_registered_actions()
except Exception as e:
logger.error(f"加载插件动作失败: {e}")
def create_action( def create_action(
self, self,
action_name: str, action_name: str,
@@ -94,8 +140,8 @@ class ActionManager:
current_cycle: CycleDetail, current_cycle: CycleDetail,
log_prefix: str, log_prefix: str,
on_consecutive_no_reply_callback: Callable[[], Coroutine[None, None, None]], on_consecutive_no_reply_callback: Callable[[], Coroutine[None, None, None]],
total_no_reply_count: int = 0, # total_no_reply_count: int = 0,
total_waiting_time: float = 0.0, # total_waiting_time: float = 0.0,
shutting_down: bool = False, shutting_down: bool = False,
) -> Optional[BaseAction]: ) -> Optional[BaseAction]:
""" """
@@ -131,7 +177,7 @@ class ActionManager:
return None return None
try: try:
# 创建动作实例并传递所有必要参数 # 创建动作实例
instance = handler_class( instance = handler_class(
action_name=action_name, action_name=action_name,
action_data=action_data, action_data=action_data,
@@ -139,14 +185,14 @@ class ActionManager:
cycle_timers=cycle_timers, cycle_timers=cycle_timers,
thinking_id=thinking_id, thinking_id=thinking_id,
observations=observations, observations=observations,
on_consecutive_no_reply_callback=on_consecutive_no_reply_callback,
current_cycle=current_cycle,
log_prefix=log_prefix,
total_no_reply_count=total_no_reply_count,
total_waiting_time=total_waiting_time,
shutting_down=shutting_down,
expressor=expressor, expressor=expressor,
chat_stream=chat_stream, chat_stream=chat_stream,
current_cycle=current_cycle,
log_prefix=log_prefix,
on_consecutive_no_reply_callback=on_consecutive_no_reply_callback,
# total_no_reply_count=total_no_reply_count,
# total_waiting_time=total_waiting_time,
shutting_down=shutting_down,
) )
return instance return instance
@@ -272,7 +318,3 @@ class ActionManager:
Optional[Type[BaseAction]]: 动作处理器类如果不存在则返回None Optional[Type[BaseAction]]: 动作处理器类如果不存在则返回None
""" """
return _ACTION_REGISTRY.get(action_name) return _ACTION_REGISTRY.get(action_name)
# 创建全局实例
ActionFactory = ActionManager()

View File

@@ -0,0 +1,5 @@
# 导入所有动作模块以确保装饰器被执行
from . import reply_action # noqa
from . import no_reply_action # noqa
# 在此处添加更多动作模块导入

View File

@@ -43,8 +43,8 @@ class NoReplyAction(BaseAction):
on_consecutive_no_reply_callback: Callable[[], Coroutine[None, None, None]], on_consecutive_no_reply_callback: Callable[[], Coroutine[None, None, None]],
current_cycle: CycleDetail, current_cycle: CycleDetail,
log_prefix: str, log_prefix: str,
total_no_reply_count: int = 0, # total_no_reply_count: int = 0,
total_waiting_time: float = 0.0, # total_waiting_time: float = 0.0,
shutting_down: bool = False, shutting_down: bool = False,
**kwargs, **kwargs,
): ):
@@ -69,8 +69,8 @@ class NoReplyAction(BaseAction):
self.on_consecutive_no_reply_callback = on_consecutive_no_reply_callback self.on_consecutive_no_reply_callback = on_consecutive_no_reply_callback
self._current_cycle = current_cycle self._current_cycle = current_cycle
self.log_prefix = log_prefix self.log_prefix = log_prefix
self.total_no_reply_count = total_no_reply_count # self.total_no_reply_count = total_no_reply_count
self.total_waiting_time = total_waiting_time # self.total_waiting_time = total_waiting_time
self._shutting_down = shutting_down self._shutting_down = shutting_down
async def handle_action(self) -> Tuple[bool, str]: async def handle_action(self) -> Tuple[bool, str]:
@@ -94,36 +94,7 @@ class NoReplyAction(BaseAction):
# 等待新消息、超时或关闭信号,并获取结果 # 等待新消息、超时或关闭信号,并获取结果
await self._wait_for_new_message(observation, self.thinking_id, self.log_prefix) await self._wait_for_new_message(observation, self.thinking_id, self.log_prefix)
# 从计时器获取实际等待时间 # 从计时器获取实际等待时间
current_waiting = self.cycle_timers.get("等待新消息", 0.0) _current_waiting = self.cycle_timers.get("等待新消息", 0.0)
if not self._shutting_down:
self.total_no_reply_count += 1
self.total_waiting_time += current_waiting # 累加等待时间
logger.debug(
f"{self.log_prefix} 连续不回复计数增加: {self.total_no_reply_count}/{CONSECUTIVE_NO_REPLY_THRESHOLD}, "
f"本次等待: {current_waiting:.2f}秒, 累计等待: {self.total_waiting_time:.2f}"
)
# 检查是否同时达到次数和时间阈值
time_threshold = 0.66 * WAITING_TIME_THRESHOLD * CONSECUTIVE_NO_REPLY_THRESHOLD
if (
self.total_no_reply_count >= CONSECUTIVE_NO_REPLY_THRESHOLD
and self.total_waiting_time >= time_threshold
):
logger.info(
f"{self.log_prefix} 连续不回复达到阈值 ({self.total_no_reply_count}次) "
f"且累计等待时间达到 {self.total_waiting_time:.2f}秒 (阈值 {time_threshold}秒)"
f"调用回调请求状态转换"
)
# 调用回调。注意:这里不重置计数器和时间,依赖回调函数成功改变状态来隐式重置上下文。
await self.on_consecutive_no_reply_callback()
elif self.total_no_reply_count >= CONSECUTIVE_NO_REPLY_THRESHOLD:
# 仅次数达到阈值,但时间未达到
logger.debug(
f"{self.log_prefix} 连续不回复次数达到阈值 ({self.total_no_reply_count}次) "
f"但累计等待时间 {self.total_waiting_time:.2f}秒 未达到时间阈值 ({time_threshold}秒),暂不调用回调"
)
# else: 次数和时间都未达到阈值,不做处理
return True, "" # 不回复动作没有回复文本 return True, "" # 不回复动作没有回复文本

View File

@@ -0,0 +1,205 @@
import traceback
from typing import Tuple, Dict, List, Any, Optional
from src.chat.focus_chat.planners.actions.base_action import BaseAction
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
from src.chat.focus_chat.hfc_utils import create_empty_anchor_message
from src.common.logger_manager import get_logger
from src.chat.person_info.person_info import person_info_manager
from abc import abstractmethod
logger = get_logger("plugin_action")
class PluginAction(BaseAction):
    """Base class for plugin actions.

    Wraps the host program's internal services (expressor, chat stream,
    observations, ...) and exposes a simplified API surface to plugin
    developers. Subclasses implement :meth:`process`.
    """

    def __init__(self, action_data: dict, reasoning: str, cycle_timers: dict, thinking_id: str, **kwargs):
        """Initialize the plugin action base class."""
        super().__init__(action_data, reasoning, cycle_timers, thinking_id)

        # References to internal services injected by the action manager.
        self._services = {}

        # Pull the optional internal services out of kwargs.
        for service_key in ("observations", "expressor", "chat_stream", "current_cycle"):
            if service_key in kwargs:
                self._services[service_key] = kwargs[service_key]

        self.log_prefix = kwargs.get("log_prefix", "")

    async def get_user_id_by_person_name(self, person_name: str) -> Tuple[str, str]:
        """Resolve a person name to (platform, user_id) via the person info manager."""
        person_id = person_info_manager.get_person_id_by_person_name(person_name)
        user_id = await person_info_manager.get_value(person_id, "user_id")
        platform = await person_info_manager.get_value(person_id, "platform")
        return platform, user_id

    async def _get_anchor_message(self, chat_stream, target: str):
        """Locate the anchor message matching *target*, or build a placeholder.

        Bug fix: the previous inline code called ``next(...)`` without a
        default, which raised StopIteration when no ChattingObservation was
        available (only masked by the callers' broad ``except``).

        Returns:
            The anchor message, or None when no ChattingObservation service exists.
        """
        observations = self._services.get("observations", [])
        chatting_observation: Optional[ChattingObservation] = next(
            (obs for obs in observations if isinstance(obs, ChattingObservation)), None
        )
        if chatting_observation is None:
            logger.error(f"{self.log_prefix} 无法发送消息:缺少必要的内部服务")
            return None

        anchor_message = chatting_observation.search_message_by_text(target)
        # If no anchor message was found, create a placeholder anchored to the stream.
        if not anchor_message:
            logger.info(f"{self.log_prefix} 未找到锚点消息,创建占位符")
            anchor_message = await create_empty_anchor_message(
                chat_stream.platform, chat_stream.group_info, chat_stream
            )
        else:
            anchor_message.update_chat_stream(chat_stream)
        return anchor_message

    # Simplified API methods

    async def send_message(self, text: str, target: Optional[str] = None) -> bool:
        """Send *text* directly via the expressor's plain send path.

        Args:
            text: message text to send
            target: original text of the message to reply to (optional)

        Returns:
            bool: whether sending succeeded
        """
        try:
            expressor = self._services.get("expressor")
            chat_stream = self._services.get("chat_stream")

            if not expressor or not chat_stream:
                logger.error(f"{self.log_prefix} 无法发送消息:缺少必要的内部服务")
                return False

            anchor_message = await self._get_anchor_message(chat_stream, target or "")
            if anchor_message is None:
                return False

            response_set = [
                ("text", text),
            ]

            # Delegate the actual send to the expressor.
            success = await expressor.send_response_messages(
                anchor_message=anchor_message,
                response_set=response_set,
            )

            return success
        except Exception as e:
            logger.error(f"{self.log_prefix} 发送消息时出错: {e}")
            traceback.print_exc()
            return False

    async def send_message_by_expressor(self, text: str, target: Optional[str] = None) -> bool:
        """Send *text* through the expressor's full reply pipeline.

        Args:
            text: message text to send
            target: original text of the message to reply to (optional)

        Returns:
            bool: whether sending succeeded
        """
        try:
            expressor = self._services.get("expressor")
            chat_stream = self._services.get("chat_stream")

            if not expressor or not chat_stream:
                logger.error(f"{self.log_prefix} 无法发送消息:缺少必要的内部服务")
                return False

            # Simplified action payload understood by the expressor.
            reply_data = {"text": text, "target": target or "", "emojis": []}

            anchor_message = await self._get_anchor_message(chat_stream, reply_data["target"])
            if anchor_message is None:
                return False

            success, _ = await expressor.deal_reply(
                cycle_timers=self.cycle_timers,
                action_data=reply_data,
                anchor_message=anchor_message,
                reasoning=self.reasoning,
                thinking_id=self.thinking_id,
            )

            return success
        except Exception as e:
            logger.error(f"{self.log_prefix} 发送消息时出错: {e}")
            return False

    def get_chat_type(self) -> str:
        """Return the current chat type.

        Returns:
            str: "group", "private", or "unknown" when no chat stream is available.
        """
        chat_stream = self._services.get("chat_stream")
        if chat_stream and hasattr(chat_stream, "group_info"):
            return "group" if chat_stream.group_info else "private"
        return "unknown"

    def get_recent_messages(self, count: int = 5) -> List[Dict[str, Any]]:
        """Return up to *count* recent messages in a simplified format.

        Args:
            count: number of messages to return

        Returns:
            List[Dict]: each entry has "sender", "content" and "timestamp" keys.
        """
        messages = []
        observations = self._services.get("observations", [])
        if observations and len(observations) > 0:
            obs = observations[0]
            if hasattr(obs, "get_talking_message"):
                raw_messages = obs.get_talking_message()
                # Convert the raw observation entries to the simplified format.
                for msg in raw_messages[-count:]:
                    simple_msg = {
                        "sender": msg.get("sender", "未知"),
                        "content": msg.get("content", ""),
                        "timestamp": msg.get("timestamp", 0),
                    }
                    messages.append(simple_msg)
        return messages

    @abstractmethod
    async def process(self) -> Tuple[bool, str]:
        """Plugin logic; subclasses must implement this.

        Returns:
            Tuple[bool, str]: (success, reply text)
        """
        pass

    async def handle_action(self) -> Tuple[bool, str]:
        """BaseAction entry point; delegates to :meth:`process`.

        Returns:
            Tuple[bool, str]: (success, reply text)
        """
        return await self.process()

View File

@@ -1,6 +1,5 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
from src.chat.focus_chat.planners.actions.base_action import BaseAction, register_action from src.chat.focus_chat.planners.actions.base_action import BaseAction, register_action
from typing import Tuple, List from typing import Tuple, List
@@ -25,19 +24,18 @@ class ReplyAction(BaseAction):
action_description: str = "表达想法,可以只包含文本、表情或两者都有" action_description: str = "表达想法,可以只包含文本、表情或两者都有"
action_parameters: dict[str:str] = { action_parameters: dict[str:str] = {
"text": "你想要表达的内容(可选)", "text": "你想要表达的内容(可选)",
"emojis": "描述当前使用表情包的场景(可选)", "emojis": "描述当前使用表情包的场景,一段话描述(可选)",
"target": "你想要回复的原始文本内容(非必须,仅文本,不包含发送者)(可选)", "target": "你想要回复的原始文本内容(非必须,仅文本,不包含发送者)(可选)",
} }
action_require: list[str] = [ action_require: list[str] = [
"有实质性内容需要表达", "有实质性内容需要表达",
"有人提到你,但你还没有回应他", "有人提到你,但你还没有回应他",
"在合适的时候添加表情(不要总是添加)", "在合适的时候添加表情(不要总是添加),表情描述要详细,描述当前场景,一段话描述",
"如果你要回复特定某人的某句话或者你想回复较早的消息请在target中指定那句话的原始文本", "如果你有明确的,要回复特定某人的某句话或者你想回复较早的消息请在target中指定那句话的原始文本",
"除非有明确的回复目标如果选择了target不用特别提到某个人的人名",
"一次只回复一个人,一次只回复一个话题,突出重点", "一次只回复一个人,一次只回复一个话题,突出重点",
"如果是自己发的消息想继续,需自然衔接", "如果是自己发的消息想继续,需自然衔接",
"避免重复或评价自己的发言,不要和自己聊天", "避免重复或评价自己的发言,不要和自己聊天",
"注意:回复尽量简短一些。可以参考贴吧,知乎和微博的回复风格,回复不要浮夸,不要用夸张修辞,平淡一些。", "注意:回复尽量简短一些。可以参考贴吧,知乎和微博的回复风格,回复不要浮夸,不要用夸张修辞,平淡一些。不要有额外的符号,尽量简单简短",
] ]
default = True default = True
@@ -104,13 +102,15 @@ class ReplyAction(BaseAction):
"emojis": "微笑" # 表情关键词列表(可选) "emojis": "微笑" # 表情关键词列表(可选)
} }
""" """
# 重置连续不回复计数器
self.total_no_reply_count = 0
self.total_waiting_time = 0.0
# 从聊天观察获取锚定消息 # 从聊天观察获取锚定消息
observations: ChattingObservation = self.observations[0] chatting_observation: ChattingObservation = next(
anchor_message = observations.serch_message_by_text(reply_data["target"]) obs for obs in self.observations if isinstance(obs, ChattingObservation)
)
if reply_data.get("target"):
anchor_message = chatting_observation.search_message_by_text(reply_data["target"])
else:
anchor_message = None
# 如果没有找到锚点消息,创建一个占位符 # 如果没有找到锚点消息,创建一个占位符
if not anchor_message: if not anchor_message:

View File

@@ -12,8 +12,8 @@ from src.chat.focus_chat.info.structured_info import StructuredInfo
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.individuality.individuality import Individuality from src.individuality.individuality import Individuality
from src.chat.focus_chat.planners.action_factory import ActionManager from src.chat.focus_chat.planners.action_manager import ActionManager
from src.chat.focus_chat.planners.action_factory import ActionInfo from src.chat.focus_chat.planners.action_manager import ActionInfo
logger = get_logger("planner") logger = get_logger("planner")
@@ -22,8 +22,12 @@ install(extra_lines=3)
def init_prompt(): def init_prompt():
Prompt( Prompt(
"""你的名字是{bot_name},{prompt_personality}{chat_context_description}。需要基于以下信息决定如何参与对话: """{extra_info_block}
你需要基于以下信息决定如何参与对话
这些信息可能会有冲突请你整合这些信息并选择一个最合适的action
{chat_content_block} {chat_content_block}
{mind_info_block} {mind_info_block}
{cycle_info_block} {cycle_info_block}
@@ -55,8 +59,7 @@ action_name: {action_name}
参数: 参数:
{action_parameters} {action_parameters}
动作要求: 动作要求:
{action_require} {action_require}""",
""",
"action_prompt", "action_prompt",
) )
@@ -66,7 +69,7 @@ class ActionPlanner:
self.log_prefix = log_prefix self.log_prefix = log_prefix
# LLM规划器配置 # LLM规划器配置
self.planner_llm = LLMRequest( self.planner_llm = LLMRequest(
model=global_config.llm_plan, model=global_config.model.plan,
max_tokens=1000, max_tokens=1000,
request_type="action_planning", # 用于动作规划 request_type="action_planning", # 用于动作规划
) )
@@ -87,9 +90,10 @@ class ActionPlanner:
try: try:
# 获取观察信息 # 获取观察信息
extra_info: list[str] = []
for info in all_plan_info: for info in all_plan_info:
if isinstance(info, ObsInfo): if isinstance(info, ObsInfo):
logger.debug(f"{self.log_prefix} 观察信息: {info}") # logger.debug(f"{self.log_prefix} 观察信息: {info}")
observed_messages = info.get_talking_message() observed_messages = info.get_talking_message()
observed_messages_str = info.get_talking_message_str_truncate() observed_messages_str = info.get_talking_message_str_truncate()
chat_type = info.get_chat_type() chat_type = info.get_chat_type()
@@ -98,14 +102,17 @@ class ActionPlanner:
else: else:
is_group_chat = False is_group_chat = False
elif isinstance(info, MindInfo): elif isinstance(info, MindInfo):
logger.debug(f"{self.log_prefix} 思维信息: {info}") # logger.debug(f"{self.log_prefix} 思维信息: {info}")
current_mind = info.get_current_mind() current_mind = info.get_current_mind()
elif isinstance(info, CycleInfo): elif isinstance(info, CycleInfo):
logger.debug(f"{self.log_prefix} 循环信息: {info}") # logger.debug(f"{self.log_prefix} 循环信息: {info}")
cycle_info = info.get_observe_info() cycle_info = info.get_observe_info()
elif isinstance(info, StructuredInfo): elif isinstance(info, StructuredInfo):
logger.debug(f"{self.log_prefix} 结构化信息: {info}") # logger.debug(f"{self.log_prefix} 结构化信息: {info}")
_structured_info = info.get_data() _structured_info = info.get_data()
else:
logger.debug(f"{self.log_prefix} 其他信息: {info}")
extra_info.append(info.get_processed_info())
current_available_actions = self.action_manager.get_using_actions() current_available_actions = self.action_manager.get_using_actions()
@@ -118,6 +125,7 @@ class ActionPlanner:
# structured_info=structured_info, # <-- Pass SubMind info # structured_info=structured_info, # <-- Pass SubMind info
current_available_actions=current_available_actions, # <-- Pass determined actions current_available_actions=current_available_actions, # <-- Pass determined actions
cycle_info=cycle_info, # <-- Pass cycle info cycle_info=cycle_info, # <-- Pass cycle info
extra_info=extra_info,
) )
# --- 调用 LLM (普通文本生成) --- # --- 调用 LLM (普通文本生成) ---
@@ -144,15 +152,13 @@ class ActionPlanner:
extracted_action = parsed_json.get("action", "no_reply") extracted_action = parsed_json.get("action", "no_reply")
extracted_reasoning = parsed_json.get("reasoning", "LLM未提供理由") extracted_reasoning = parsed_json.get("reasoning", "LLM未提供理由")
# 新的reply格式 # 将所有其他属性添加到action_data
if extracted_action == "reply": action_data = {}
action_data = { for key, value in parsed_json.items():
"text": parsed_json.get("text", []), if key not in ["action", "reasoning"]:
"emojis": parsed_json.get("emojis", []), action_data[key] = value
"target": parsed_json.get("target", ""),
} # 对于reply动作不需要额外处理因为相关字段已经在上面的循环中添加到action_data
else:
action_data = {} # 其他动作可能不需要额外数据
if extracted_action not in current_available_actions: if extracted_action not in current_available_actions:
logger.warning( logger.warning(
@@ -207,6 +213,7 @@ class ActionPlanner:
current_mind: Optional[str], current_mind: Optional[str],
current_available_actions: Dict[str, ActionInfo], current_available_actions: Dict[str, ActionInfo],
cycle_info: Optional[str], cycle_info: Optional[str],
extra_info: list[str],
) -> str: ) -> str:
"""构建 Planner LLM 的提示词 (获取模板并填充数据)""" """构建 Planner LLM 的提示词 (获取模板并填充数据)"""
try: try:
@@ -261,15 +268,19 @@ class ActionPlanner:
action_options_block += using_action_prompt action_options_block += using_action_prompt
extra_info_block = "\n".join(extra_info)
extra_info_block = f"以下是一些额外的信息,现在请你阅读以下内容,进行决策\n{extra_info_block}\n以上是一些额外的信息,现在请你阅读以下内容,进行决策"
planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt") planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt")
prompt = planner_prompt_template.format( prompt = planner_prompt_template.format(
bot_name=global_config.BOT_NICKNAME, bot_name=global_config.bot.nickname,
prompt_personality=personality_block, prompt_personality=personality_block,
chat_context_description=chat_context_description, chat_context_description=chat_context_description,
chat_content_block=chat_content_block, chat_content_block=chat_content_block,
mind_info_block=mind_info_block, mind_info_block=mind_info_block,
cycle_info_block=cycle_info, cycle_info_block=cycle_info,
action_options_text=action_options_block, action_options_text=action_options_block,
extra_info_block=extra_info_block,
) )
return prompt return prompt

View File

@@ -0,0 +1,112 @@
from typing import Dict, Any, List, Optional, Set, Tuple
import time
import random
import string
class MemoryItem:
    """A single unit of working memory together with all of its metadata."""

    def __init__(self, data: Any, from_source: str = "", tags: Optional[List[str]] = None):
        """Create a memory item.

        Args:
            data: payload of the memory
            from_source: where the payload came from
            tags: optional list of labels attached to the memory
        """
        now = time.time()
        suffix = "".join(random.choices(string.ascii_lowercase + string.digits, k=2))
        # Human-readable id: <unix-timestamp>_<two random chars>
        self.id = f"{int(now)}_{suffix}"

        self.data = data
        self.data_type = type(data)
        self.from_source = from_source
        self.tags = set(tags) if tags else set()
        self.timestamp = now

        # Optional summary, set later via set_summary(). Expected structure:
        # {
        #   "brief": "topic of the memory",
        #   "detailed": "short recap of the memory",
        #   "keypoints": ["concept 1", "concept 2"],
        #   "events": ["event 1", "event 2"]
        # }
        self.summary = None
        # How many times this memory has been compressed (refined).
        self.compress_count = 0
        # How many times this memory has been retrieved.
        self.retrieval_count = 0
        # Memory strength, starts at 10.
        self.memory_strength = 10.0
        # Operation log: (operation, timestamp, compress_count, strength) tuples.
        self.history = [("create", self.timestamp, self.compress_count, self.memory_strength)]

    def add_tag(self, tag: str) -> None:
        """Attach *tag* to this memory."""
        self.tags.add(tag)

    def remove_tag(self, tag: str) -> None:
        """Detach *tag* if present (no-op otherwise)."""
        self.tags.discard(tag)

    def has_tag(self, tag: str) -> bool:
        """Return True when *tag* is attached to this memory."""
        return tag in self.tags

    def has_all_tags(self, tags: List[str]) -> bool:
        """Return True when every tag in *tags* is attached."""
        return self.tags.issuperset(tags)

    def matches_source(self, source: str) -> bool:
        """Return True when the memory came from *source*."""
        return self.from_source == source

    def set_summary(self, summary: Dict[str, Any]) -> None:
        """Attach summary information to this memory."""
        self.summary = summary

    def increase_strength(self, amount: float) -> None:
        """Raise memory strength by *amount*, capped at 10."""
        self.memory_strength = min(10.0, self.memory_strength + amount)
        self.record_operation("strengthen")

    def decrease_strength(self, amount: float) -> None:
        """Lower memory strength by *amount*, floored at 0.1."""
        self.memory_strength = max(0.1, self.memory_strength - amount)
        self.record_operation("weaken")

    def increase_compress_count(self) -> None:
        """Bump the compression counter and log the operation."""
        self.compress_count += 1
        self.record_operation("compress")

    def record_retrieval(self) -> None:
        """Register a retrieval: bump the counter and double strength (cap 10)."""
        self.retrieval_count += 1
        self.memory_strength = min(10.0, self.memory_strength * 2)
        self.record_operation("retrieval")

    def record_operation(self, operation_type: str) -> None:
        """Append an entry for *operation_type* to the operation log."""
        self.history.append((operation_type, time.time(), self.compress_count, self.memory_strength))

    def to_tuple(self) -> Tuple[Any, str, Set[str], float, str]:
        """Return the legacy tuple form (data, source, tags, timestamp, id)."""
        return (self.data, self.from_source, self.tags, self.timestamp, self.id)

    def is_memory_valid(self) -> bool:
        """Return True while strength is at least 1.0."""
        return self.memory_strength >= 1.0

View File

@@ -0,0 +1,781 @@
from typing import Dict, Any, Type, TypeVar, List, Optional
import traceback
from json_repair import repair_json
from rich.traceback import install
from src.common.logger_manager import get_logger
from src.chat.models.utils_model import LLMRequest
from src.config.config import global_config
from src.chat.focus_chat.working_memory.memory_item import MemoryItem
import json # 添加json模块导入
install(extra_lines=3)
logger = get_logger("working_memory")
T = TypeVar("T")
class MemoryManager:
def __init__(self, chat_id: str):
    """
    Initialize the working memory for one chat.

    Args:
        chat_id: id of the chat this working memory belongs to
    """
    # Associated chat id.
    self._chat_id = chat_id
    # Main store: data type -> list of memory items of that type.
    self._memory: Dict[Type, List[MemoryItem]] = {}
    # Map from item id to memory item, for O(1) lookup.
    self._id_map: Dict[str, MemoryItem] = {}
    # LLM used to summarize / refine memories.
    self.llm_summarizer = LLMRequest(
        model=global_config.model.summary, temperature=0.3, max_tokens=512, request_type="memory_summarization"
    )
@property
def chat_id(self) -> str:
    """Chat id this working memory is associated with."""
    return self._chat_id

@chat_id.setter
def chat_id(self, value: str):
    """Re-associate this working memory with another chat."""
    self._chat_id = value
def push_item(self, memory_item: MemoryItem) -> str:
    """Store an already-constructed memory item.

    Args:
        memory_item: the item to store

    Returns:
        The id of the stored item.
    """
    # Register in the per-type bucket (creating it on first use) and the id map.
    bucket = self._memory.setdefault(memory_item.data_type, [])
    bucket.append(memory_item)
    self._id_map[memory_item.id] = memory_item
    return memory_item.id
async def push_with_summary(self, data: T, from_source: str = "", tags: Optional[List[str]] = None) -> MemoryItem:
    """Store *data* in working memory; string payloads get an LLM summary first.

    Args:
        data: payload to store
        from_source: origin of the payload
        tags: optional list of labels

    Returns:
        The stored MemoryItem (with its summary attached for string payloads).
    """
    if not isinstance(data, str):
        # Non-text payloads are stored as-is, without summarization.
        item = MemoryItem(data, from_source, tags)
        self.push_item(item)
        return item

    # Text payloads: summarize first, then store the item with its summary.
    summary = await self.summarize_memory_item(data)
    item = MemoryItem(data, from_source, list(tags) if tags else [])
    item.set_summary(summary)
    self.push_item(item)
    return item
def get_by_id(self, memory_id: str) -> Optional[MemoryItem]:
    """
    Look up a memory item by id.

    Items whose strength has dropped below the validity threshold are
    purged lazily on access and treated as missing.

    Args:
        memory_id: id of the memory item

    Returns:
        The memory item, or None when it is absent or no longer valid.
    """
    memory_item = self._id_map.get(memory_id)
    if memory_item:
        # Purge memories whose strength has decayed below 1.
        if not memory_item.is_memory_valid():
            # Bug fix: use the module logger instead of a stray print().
            logger.info(f"记忆 {memory_id} 强度过低 ({memory_item.memory_strength}),已自动移除")
            self.delete(memory_id)
            return None
    return memory_item
def get_all_items(self) -> List[MemoryItem]:
    """Return every stored memory item (in id-map insertion order)."""
    return list(self._id_map.values())
def find_items(
    self,
    data_type: Optional[Type] = None,
    source: Optional[str] = None,
    tags: Optional[List[str]] = None,
    start_time: Optional[float] = None,
    end_time: Optional[float] = None,
    memory_id: Optional[str] = None,
    limit: Optional[int] = None,
    newest_first: bool = False,
    min_strength: float = 0.0,
) -> List[MemoryItem]:
    """Search memory items by a combination of filters.

    Args:
        data_type: restrict to items of this data type
        source: restrict to items from this source
        tags: tags that every returned item must carry
        start_time: earliest accepted timestamp
        end_time: latest accepted timestamp
        memory_id: shortcut — look up exactly this id
        limit: maximum number of results
        newest_first: iterate newest items first
        min_strength: minimum memory strength

    Returns:
        All memory items matching every provided filter.
    """
    # Direct id lookup bypasses the scan entirely.
    if memory_id:
        found = self.get_by_id(memory_id)
        return [found] if found else []

    def _accepts(candidate: MemoryItem) -> bool:
        # Every provided filter must hold for the candidate to match.
        if source is not None and not candidate.matches_source(source):
            return False
        if tags is not None and not candidate.has_all_tags(tags):
            return False
        if start_time is not None and candidate.timestamp < start_time:
            return False
        if end_time is not None and candidate.timestamp > end_time:
            return False
        if min_strength > 0 and candidate.memory_strength < min_strength:
            return False
        return True

    matches: List[MemoryItem] = []
    # Scan either the requested type's bucket or every bucket.
    type_keys = [data_type] if data_type else list(self._memory.keys())
    for type_key in type_keys:
        bucket = self._memory.get(type_key)
        if not bucket:
            continue
        ordered = reversed(bucket) if newest_first else bucket
        for candidate in ordered:
            if not _accepts(candidate):
                continue
            matches.append(candidate)
            # Stop as soon as the requested number of results is reached.
            if limit is not None and len(matches) >= limit:
                return matches
    return matches
async def summarize_memory_item(self, content: str) -> Dict[str, Any]:
    """
    Summarize a memory item with the LLM.

    Args:
        content: the text to summarize

    Returns:
        A dict with "brief", "detailed", "keypoints" and "events" keys
        (plus a legacy "key_points" key = keypoints + events). Falls back
        to a default placeholder summary on any parsing or LLM failure.
    """
    prompt = f"""请对以下内容进行总结,总结成记忆,输出四部分:
1. 记忆内容主题精简20字以内让用户可以一眼看出记忆内容是什么
2. 记忆内容概括200字以内让用户可以了解记忆内容的大致内容
3. 关键概念和知识keypoints多条提取关键的概念、知识点和关键词要包含对概念的解释
4. 事件描述events多条描述谁人物在什么时候时间做了什么事件
内容:
{content}
请按以下JSON格式输出
```json
{{
    "brief": "记忆内容主题20字以内",
    "detailed": "记忆内容概括200字以内",
    "keypoints": [
        "概念1解释",
        "概念2解释",
        ...
    ],
    "events": [
        "事件1谁在什么时候做了什么",
        "事件2谁在什么时候做了什么",
        ...
    ]
}}
```
请确保输出是有效的JSON格式不要添加任何额外的说明或解释。
"""
    # Placeholder summary returned on any failure below.
    default_summary = {
        "brief": "主题未知的记忆",
        "detailed": "大致内容未知的记忆",
        "keypoints": ["未知的概念"],
        "events": ["未知的事件"],
    }
    try:
        # Ask the LLM for the summary.
        response, _ = await self.llm_summarizer.generate_response_async(prompt)
        # Parse the response with repair_json.
        try:
            # repair_json fixes malformed JSON emitted by the LLM.
            fixed_json_string = repair_json(response)
            # repair_json may return either a JSON string or an already-parsed object.
            if isinstance(fixed_json_string, str):
                try:
                    json_result = json.loads(fixed_json_string)
                except json.JSONDecodeError as decode_error:
                    logger.error(f"JSON解析错误: {str(decode_error)}")
                    return default_summary
            else:
                # repair_json returned a dict directly; use it as-is.
                json_result = fixed_json_string
            # Extra type check on the parsed result.
            if not isinstance(json_result, dict):
                logger.error(f"修复后的JSON不是字典类型: {type(json_result)}")
                return default_summary
            # Ensure all required fields exist with the correct types.
            if "brief" not in json_result or not isinstance(json_result["brief"], str):
                json_result["brief"] = "主题未知的记忆"
            if "detailed" not in json_result or not isinstance(json_result["detailed"], str):
                json_result["detailed"] = "大致内容未知的记忆"
            # Key concepts.
            if "keypoints" not in json_result or not isinstance(json_result["keypoints"], list):
                json_result["keypoints"] = ["未知的概念"]
            else:
                # Every entry in keypoints must be a string.
                json_result["keypoints"] = [str(point) for point in json_result["keypoints"] if point is not None]
                if not json_result["keypoints"]:
                    json_result["keypoints"] = ["未知的概念"]
            # Events.
            if "events" not in json_result or not isinstance(json_result["events"], list):
                json_result["events"] = ["未知的事件"]
            else:
                # Every entry in events must be a string.
                json_result["events"] = [str(event) for event in json_result["events"] if event is not None]
                if not json_result["events"]:
                    json_result["events"] = ["未知的事件"]
            # Legacy compatibility: merge keypoints and events into key_points.
            json_result["key_points"] = json_result["keypoints"] + json_result["events"]
            return json_result
        except Exception as json_error:
            logger.error(f"JSON处理失败: {str(json_error)},将使用默认摘要")
            # Fall back to the default structure.
            return default_summary
    except Exception as e:
        # On any other failure return the simple default structure.
        logger.error(f"生成总结时出错: {str(e)}")
        return default_summary
async def refine_memory(self, memory_id: str, requirements: str = "") -> "MemoryItem":
    """
    Compress ("forget") a memory according to *requirements*, rewriting its
    brief, recap, key concepts and events via the LLM.

    Args:
        memory_id: id of the memory to refine
        requirements: instructions describing how to compress the memory

    Returns:
        The updated MemoryItem (its summary was modified in place).
        Bug fix: the annotation previously claimed Dict[str, Any] while the
        code returned the MemoryItem.

    Raises:
        ValueError: when no memory with *memory_id* exists, or when the
            memory has no summary to refine (non-text memories).
    """
    logger.info(f"精简记忆: {memory_id}")
    memory_item = self.get_by_id(memory_id)
    if not memory_item:
        raise ValueError(f"未找到ID为{memory_id}的记忆项")

    # Count this refinement.
    memory_item.increase_compress_count()

    summary = memory_item.summary
    if summary is None:
        # Bug fix: non-text memories never get a summary; previously this
        # crashed with a TypeError when the prompt below subscripted None.
        raise ValueError(f"记忆项{memory_id}没有可精简的总结")

    # Bug fix: convert legacy "key_points"-only summaries BEFORE building the
    # prompt, which reads "keypoints"/"events". Previously this conversion ran
    # after the prompt was built, so legacy summaries produced empty sections.
    if "keypoints" not in summary and "events" not in summary and "key_points" in summary:
        # Heuristic split: first half as keypoints, second half as events.
        key_points = summary.get("key_points", [])
        halfway = len(key_points) // 2
        summary["keypoints"] = key_points[:halfway] or ["未知的概念"]
        summary["events"] = key_points[halfway:] or ["未知的事件"]

    # Ask the LLM to shrink the brief/recap/keypoints/events, simulating forgetting.
    prompt = f"""
请根据以下要求,对记忆内容的主题、概括、关键概念和事件进行精简,模拟记忆的遗忘过程:
要求:{requirements}
你可以随机对关键概念和事件进行压缩,模糊或者丢弃,修改后,同样修改主题和概括
目前主题:{summary["brief"]}
目前概括:{summary["detailed"]}
目前关键概念:
{chr(10).join([f"- {point}" for point in summary.get("keypoints", [])])}
目前事件:
{chr(10).join([f"- {point}" for point in summary.get("events", [])])}
请生成修改后的主题、概括、关键概念和事件,遵循以下格式:
```json
{{
    "brief": "修改后的主题20字以内",
    "detailed": "修改后的概括200字以内",
    "keypoints": [
        "修改后的概念1解释",
        "修改后的概念2解释"
    ],
    "events": [
        "修改后的事件1谁在什么时候做了什么",
        "修改后的事件2谁在什么时候做了什么"
    ]
}}
```
请确保输出是有效的JSON格式不要添加任何额外的说明或解释。
"""

    # Default refinement used when the LLM output cannot be parsed: keep the
    # brief/recap, drop all but the first keypoint and event.
    default_refined = {
        "brief": summary["brief"],
        "detailed": summary["detailed"],
        "keypoints": summary.get("keypoints", ["未知的概念"])[:1],
        "events": summary.get("events", ["未知的事件"])[:1],
    }

    try:
        response, _ = await self.llm_summarizer.generate_response_async(prompt)
        logger.info(f"精简记忆响应: {response}")
        try:
            # repair_json fixes malformed JSON emitted by the LLM; it may
            # return either a JSON string or an already-parsed object.
            fixed_json_string = repair_json(response)
            if isinstance(fixed_json_string, str):
                try:
                    refined_data = json.loads(fixed_json_string)
                except json.JSONDecodeError as decode_error:
                    logger.error(f"JSON解析错误: {str(decode_error)}")
                    refined_data = default_refined
            else:
                refined_data = fixed_json_string
            if not isinstance(refined_data, dict):
                logger.error(f"修复后的JSON不是字典类型: {type(refined_data)}")
                refined_data = default_refined

            # Update brief and recap.
            summary["brief"] = refined_data.get("brief", "主题未知的记忆")
            summary["detailed"] = refined_data.get("detailed", "大致内容未知的记忆")

            # Update key concepts, coercing every entry to a string.
            keypoints = refined_data.get("keypoints", [])
            if isinstance(keypoints, list) and keypoints:
                summary["keypoints"] = [str(point) for point in keypoints if point is not None]
            else:
                summary["keypoints"] = ["主要概念已遗忘"]

            # Update events, coercing every entry to a string.
            events = refined_data.get("events", [])
            if isinstance(events, list) and events:
                summary["events"] = [str(event) for event in events if event is not None]
            else:
                summary["events"] = ["事件细节已遗忘"]

            # Legacy compatibility: maintain key_points = keypoints + events.
            summary["key_points"] = summary["keypoints"] + summary["events"]
        except Exception as e:
            logger.error(f"精简记忆出错: {str(e)}")
            traceback.print_exc()
            # On parse failure apply the simplified default refinement.
            summary["brief"] = summary["brief"] + " (已简化)"
            summary["keypoints"] = summary.get("keypoints", ["未知的概念"])[:1]
            summary["events"] = summary.get("events", ["未知的事件"])[:1]
            summary["key_points"] = summary["keypoints"] + summary["events"]
    except Exception as e:
        logger.error(f"精简记忆调用LLM出错: {str(e)}")
        traceback.print_exc()

    # Write the (possibly unchanged) summary back onto the memory item.
    memory_item.set_summary(summary)
    return memory_item
def decay_memory(self, memory_id: str, decay_factor: float = 0.8) -> bool:
    """Weaken a single memory.

    The new strength is ``old_strength * (1 - decay_factor)``, i.e. a
    decay_factor of 0.8 keeps 20% of the previous strength.

    Args:
        memory_id: id of the memory to decay
        decay_factor: fraction of strength removed (0-1)

    Returns:
        True when the memory existed and was decayed, False otherwise.
    """
    item = self.get_by_id(memory_id)
    if item is None:
        return False
    # Scale the strength down in place.
    item.memory_strength *= 1 - decay_factor
    return True
def delete(self, memory_id: str) -> bool:
    """Remove the memory item with *memory_id* from the store.

    Args:
        memory_id: id of the item to remove

    Returns:
        True when an item was removed, False when the id was unknown.
    """
    # Drop from the id map first; a miss means there is nothing to delete.
    item = self._id_map.pop(memory_id, None)
    if item is None:
        return False
    # Rebuild the per-type bucket without the removed item.
    bucket = self._memory.get(item.data_type)
    if bucket is not None:
        self._memory[item.data_type] = [existing for existing in bucket if existing.id != memory_id]
    return True
def clear(self, data_type: Optional[Type] = None) -> None:
    """Drop memories from the store.

    Args:
        data_type: when given, only memories of this type are removed;
            otherwise the whole store is emptied.
    """
    if data_type is None:
        # Wipe everything.
        self._memory.clear()
        self._id_map.clear()
        return
    # Remove the bucket (if any) and unregister its items from the id map.
    for item in self._memory.pop(data_type, []):
        self._id_map.pop(item.id, None)
async def merge_memories(
    self, memory_id1: str, memory_id2: str, reason: str, delete_originals: bool = True
) -> MemoryItem:
    """
    合并两个记忆项,生成一条新的记忆。

    流程:构造合并提示词 -> 调用LLM生成合并结果 -> 修复并校验JSON
    -> 创建新的MemoryItem并入库。任一环节失败都会退回到兜底的
    default_merged结果而不是让合并中断。

    Args:
        memory_id1: 第一个记忆项ID
        memory_id2: 第二个记忆项ID
        reason: 合并原因(会写入提示词)
        delete_originals: 是否删除原始记忆默认为True

    Returns:
        MemoryItem: 合并后的新记忆项(已加入存储)

    Raises:
        ValueError: 任一记忆ID不存在时
    """
    # 获取两个记忆项
    memory_item1 = self.get_by_id(memory_id1)
    memory_item2 = self.get_by_id(memory_id2)
    if not memory_item1 or not memory_item2:
        raise ValueError("无法找到指定的记忆项")

    content1 = memory_item1.data
    content2 = memory_item2.data
    # 摘要可能不存在None/空);后续所有访问都必须判空,
    # 否则构造兜底结果时会直接抛TypeError/KeyError原实现的bug
    summary1 = memory_item1.summary
    summary2 = memory_item2.summary

    # 构建合并提示
    prompt = f"""
请根据以下原因,将两段记忆内容有机合并成一段新的记忆内容。
合并时保留两段记忆的重要信息,避免重复,确保生成的内容连贯、自然。
合并原因:{reason}
"""
    # 如果有摘要信息,添加到提示中
    if summary1:
        prompt += f"记忆1主题{summary1['brief']}\n"
        prompt += f"记忆1概括{summary1['detailed']}\n"
        if "keypoints" in summary1:
            prompt += "记忆1关键概念\n" + "\n".join([f"- {point}" for point in summary1["keypoints"]]) + "\n\n"
        if "events" in summary1:
            prompt += "记忆1事件\n" + "\n".join([f"- {point}" for point in summary1["events"]]) + "\n\n"
        elif "key_points" in summary1:
            prompt += "记忆1要点\n" + "\n".join([f"- {point}" for point in summary1["key_points"]]) + "\n\n"
    if summary2:
        prompt += f"记忆2主题{summary2['brief']}\n"
        prompt += f"记忆2概括{summary2['detailed']}\n"
        if "keypoints" in summary2:
            prompt += "记忆2关键概念\n" + "\n".join([f"- {point}" for point in summary2["keypoints"]]) + "\n\n"
        if "events" in summary2:
            prompt += "记忆2事件\n" + "\n".join([f"- {point}" for point in summary2["events"]]) + "\n\n"
        elif "key_points" in summary2:
            prompt += "记忆2要点\n" + "\n".join([f"- {point}" for point in summary2["key_points"]]) + "\n\n"

    # 添加记忆原始内容
    prompt += f"""
记忆1原始内容
{content1}
记忆2原始内容
{content2}
请按以下JSON格式输出合并结果
```json
{{
    "content": "合并后的记忆内容文本(尽可能保留原信息,但去除重复)",
    "brief": "合并后的主题20字以内",
    "detailed": "合并后的概括200字以内",
    "keypoints": [
        "合并后的概念1解释",
        "合并后的概念2解释",
        "合并后的概念3解释"
    ],
    "events": [
        "合并后的事件1谁在什么时候做了什么",
        "合并后的事件2谁在什么时候做了什么"
    ]
}}
```
请确保输出是有效的JSON格式不要添加任何额外的说明或解释。
"""

    # 兜底合并结果LLM调用或解析失败时使用摘要缺失时相应字段退化为空串
    brief1 = summary1.get("brief", "") if summary1 else ""
    brief2 = summary2.get("brief", "") if summary2 else ""
    detailed1 = summary1.get("detailed", "") if summary1 else ""
    detailed2 = summary2.get("detailed", "") if summary2 else ""
    default_merged = {
        "content": f"{content1}\n\n{content2}",
        "brief": f"合并:{brief1} + {brief2}",
        "detailed": f"合并了两个记忆:{detailed1} 以及 {detailed2}",
        "keypoints": [],
        "events": [],
    }
    # 合并两个摘要中的关键概念/事件兼容只有旧版key_points的摘要
    for s in (summary1, summary2):
        if not s:
            continue
        default_merged["keypoints"].extend(s.get("keypoints", []))
        default_merged["events"].extend(s.get("events", []))
        # 新结构为空时尝试把旧版key_points一分为二前半当概念后半当事件
        if not default_merged["keypoints"] and not default_merged["events"] and "key_points" in s:
            key_points = s["key_points"]
            halfway = len(key_points) // 2
            default_merged["keypoints"].extend(key_points[:halfway])
            default_merged["events"].extend(key_points[halfway:])
    # 确保列表不为空
    if not default_merged["keypoints"]:
        default_merged["keypoints"] = ["合并的关键概念"]
    if not default_merged["events"]:
        default_merged["events"] = ["合并的事件"]
    # 旧版key_points兼容字段
    default_merged["key_points"] = default_merged["keypoints"] + default_merged["events"]

    try:
        # 调用LLM合并记忆
        response, _ = await self.llm_summarizer.generate_response_async(prompt)
        try:
            # repair_json可能返回修复后的字符串也可能直接返回已解析对象
            fixed_json_string = repair_json(response)
            if isinstance(fixed_json_string, str):
                try:
                    merged_data = json.loads(fixed_json_string)
                except json.JSONDecodeError as decode_error:
                    logger.error(f"JSON解析错误: {str(decode_error)}")
                    merged_data = default_merged
            else:
                merged_data = fixed_json_string
            # 必须是字典,否则整体回退兜底
            if not isinstance(merged_data, dict):
                logger.error(f"修复后的JSON不是字典类型: {type(merged_data)}")
                merged_data = default_merged
            # 校验字符串字段,缺失或类型不对时逐项回退到兜底值
            for field in ("content", "brief", "detailed"):
                if field not in merged_data or not isinstance(merged_data[field], str):
                    merged_data[field] = default_merged[field]
            # 校验关键概念列表:逐项转字符串并剔除 None
            if "keypoints" not in merged_data or not isinstance(merged_data["keypoints"], list):
                merged_data["keypoints"] = default_merged["keypoints"]
            else:
                merged_data["keypoints"] = [str(point) for point in merged_data["keypoints"] if point is not None]
                if not merged_data["keypoints"]:
                    merged_data["keypoints"] = ["合并的关键概念"]
            # 校验事件列表:同上
            if "events" not in merged_data or not isinstance(merged_data["events"], list):
                merged_data["events"] = default_merged["events"]
            else:
                merged_data["events"] = [str(event) for event in merged_data["events"] if event is not None]
                if not merged_data["events"]:
                    merged_data["events"] = ["合并的事件"]
            # 旧版key_points兼容
            merged_data["key_points"] = merged_data["keypoints"] + merged_data["events"]
        except Exception as e:
            logger.error(f"合并记忆时处理JSON出错: {str(e)}")
            traceback.print_exc()
            merged_data = default_merged
    except Exception as e:
        logger.error(f"合并记忆调用LLM出错: {str(e)}")
        traceback.print_exc()
        merged_data = default_merged

    # 合并记忆项的标签
    merged_tags = memory_item1.tags.union(memory_item2.tags)
    # 取两个记忆项中强度更高者的来源
    merged_source = (
        memory_item1.from_source
        if memory_item1.memory_strength >= memory_item2.memory_strength
        else memory_item2.from_source
    )
    # 创建新的记忆项并写入合并后的摘要
    merged_memory = MemoryItem(data=merged_data["content"], from_source=merged_source, tags=list(merged_tags))
    merged_memory.set_summary(
        {
            "brief": merged_data["brief"],
            "detailed": merged_data["detailed"],
            "keypoints": merged_data["keypoints"],
            "events": merged_data["events"],
            "key_points": merged_data["key_points"],
        }
    )
    # 记忆强度取两者最大值
    merged_memory.memory_strength = max(memory_item1.memory_strength, memory_item2.memory_strength)
    # 添加到存储中
    self.push_item(merged_memory)
    # 如果需要,删除原始记忆
    if delete_originals:
        self.delete(memory_id1)
        self.delete(memory_id2)
    return merged_memory
def delete_earliest_memory(self) -> bool:
    """
    删除时间戳最早的记忆项。

    Returns:
        bool: 是否成功删除没有任何记忆时返回False
    """
    candidates = self.get_all_items()
    if not candidates:
        return False
    # 取时间戳最小(最早)的那一条,交给delete统一处理
    earliest_id = min(candidates, key=lambda m: m.timestamp).id
    return self.delete(earliest_id)

View File

@@ -0,0 +1,192 @@
from typing import List, Any, Optional
import asyncio
import random
from src.common.logger_manager import get_logger
from src.chat.focus_chat.working_memory.memory_manager import MemoryManager, MemoryItem
logger = get_logger(__name__)
# 问题是我不知道这个manager是不是需要和其他manager统一管理因为这个manager是从属于每一个聊天流都有自己的定时任务
class WorkingMemory:
    """
    工作记忆,负责协调和运作记忆。

    从属于特定的聊天流用chat_id来标识。每个实例持有自己的
    MemoryManager以及一个周期性的自动衰减后台任务。
    """

    def __init__(self, chat_id: str, max_memories_per_chat: int = 10, auto_decay_interval: int = 60):
        """
        初始化工作记忆管理器

        Args:
            chat_id: 聊天流ID
            max_memories_per_chat: 每个聊天的最大记忆数量
            auto_decay_interval: 自动衰减记忆的时间间隔(秒)
        """
        self.memory_manager = MemoryManager(chat_id)
        # 记忆容量上限,超出时淘汰最早的记忆
        self.max_memories_per_chat = max_memories_per_chat
        # 自动衰减间隔(秒)
        self.auto_decay_interval = auto_decay_interval
        # 衰减后台任务asyncio.Task未启动时为None
        self.decay_task = None
        # 启动自动衰减任务
        self._start_auto_decay()

    def _start_auto_decay(self):
        """启动自动衰减任务(已存在时不重复启动)"""
        if self.decay_task is None:
            try:
                self.decay_task = asyncio.create_task(self._auto_decay_loop())
            except RuntimeError:
                # 没有正在运行的事件循环时create_task会抛RuntimeError
                # 原实现会直接崩溃;这里降级为告警,衰减任务不启动
                logger.warning("没有运行中的事件循环,自动衰减任务未启动")

    async def _auto_decay_loop(self):
        """自动衰减循环每隔auto_decay_interval秒对所有记忆做一次衰减"""
        while True:
            await asyncio.sleep(self.auto_decay_interval)
            try:
                await self.decay_all_memories()
            except Exception as e:
                # 统一使用logger输出原实现用的是print
                logger.error(f"自动衰减记忆时出错: {str(e)}")

    async def add_memory(self, content: Any, from_source: str = "", tags: Optional[List[str]] = None):
        """
        添加一段记忆

        Args:
            content: 记忆内容
            from_source: 数据来源
            tags: 数据标签列表

        Returns:
            新创建的记忆项(带摘要)
        """
        memory = await self.memory_manager.push_with_summary(content, from_source, tags)
        # 超出容量上限时,淘汰最早的记忆
        if len(self.memory_manager.get_all_items()) > self.max_memories_per_chat:
            self.remove_earliest_memory()
        return memory

    def remove_earliest_memory(self):
        """删除最早的记忆"""
        return self.memory_manager.delete_earliest_memory()

    async def retrieve_memory(self, memory_id: str) -> Optional[MemoryItem]:
        """
        检索记忆;命中时记一次检索并强化该记忆

        Args:
            memory_id: 记忆ID

        Returns:
            检索到的记忆项如果不存在则返回None
        """
        memory_item = self.memory_manager.get_by_id(memory_id)
        if memory_item:
            memory_item.retrieval_count += 1
            # 被检索过的记忆会被强化
            memory_item.increase_strength(5)
            return memory_item
        return None

    async def decay_all_memories(self, decay_factor: float = 0.5):
        """
        对所有记忆进行衰减

        衰减后强度低于1的记忆直接删除低于5的记忆会被refine压缩
        模拟记忆随时间变模糊

        Args:
            decay_factor: 衰减因子(0-1之间)
        """
        logger.debug(f"开始对所有记忆进行衰减,衰减因子: {decay_factor}")
        all_memories = self.memory_manager.get_all_items()
        for memory_item in all_memories:
            memory_id = memory_item.id
            self.memory_manager.decay_memory(memory_id, decay_factor)
            # 衰减后强度小于1的记忆被遗忘删除
            if memory_item.memory_strength < 1:
                self.memory_manager.delete(memory_id)
                continue
            # 强度较低的记忆进行压缩
            if memory_item.memory_strength < 5:
                await self.memory_manager.refine_memory(
                    memory_id, f"由于时间过去了{self.auto_decay_interval}秒,记忆变的模糊,所以需要压缩"
                )

    async def merge_memory(self, memory_id1: str, memory_id2: str) -> MemoryItem:
        """
        合并两段有重复内容的记忆

        Args:
            memory_id1: 第一段记忆的ID
            memory_id2: 第二段记忆的ID

        Returns:
            合并后的新记忆项
        """
        # 修正原实现reason中的错别字"两端" -> "两段"
        return await self.memory_manager.merge_memories(
            memory_id1=memory_id1, memory_id2=memory_id2, reason="两段记忆有重复的内容"
        )

    # 暂时没用,先留着
    async def simulate_memory_blur(self, chat_id: str, blur_rate: float = 0.2):
        """
        模拟记忆模糊过程,随机选择一部分记忆进行精简

        Args:
            chat_id: 聊天ID保留参数当前实例自身已绑定chat_id实际未使用
            blur_rate: 模糊比率(0-1之间),表示有多少比例的记忆会被精简
        """
        # 原实现调用了不存在的self.get_memory(chat_id)会抛AttributeError
        # 这里改为使用实例自己的memory_manager
        memory = self.memory_manager
        # 获取所有字符串类型且有总结的记忆
        all_summarized_memories = []
        for type_items in memory._memory.values():
            for item in type_items:
                if isinstance(item.data, str) and hasattr(item, "summary") and item.summary:
                    all_summarized_memories.append(item)
        if not all_summarized_memories:
            return
        # 计算要模糊的记忆数量至少1条
        blur_count = max(1, int(len(all_summarized_memories) * blur_rate))
        # 随机选择要模糊的记忆
        memories_to_blur = random.sample(all_summarized_memories, min(blur_count, len(all_summarized_memories)))
        for memory_item in memories_to_blur:
            try:
                # 根据记忆强度决定模糊程度:越弱精简得越狠
                if memory_item.memory_strength > 7:
                    requirement = "保留所有重要信息,仅略微精简"
                elif memory_item.memory_strength > 4:
                    requirement = "保留核心要点,适度精简细节"
                else:
                    requirement = "只保留最关键的1-2个要点大幅精简内容"
                await memory.refine_memory(memory_item.id, requirement)
                # 统一使用logger输出原实现用的是print
                logger.debug(f"已模糊记忆 {memory_item.id},强度: {memory_item.memory_strength}, 要求: {requirement}")
            except Exception as e:
                logger.error(f"模糊记忆 {memory_item.id} 时出错: {str(e)}")

    async def shutdown(self) -> None:
        """关闭管理器,取消并等待衰减任务结束"""
        if self.decay_task and not self.decay_task.done():
            self.decay_task.cancel()
            try:
                await self.decay_task
            except asyncio.CancelledError:
                # 任务被取消属预期,忽略
                pass

    def get_all_memories(self) -> List[MemoryItem]:
        """
        获取所有记忆项目

        Returns:
            List[MemoryItem]: 当前工作记忆中的所有记忆项目列表
        """
        return self.memory_manager.get_all_items()

View File

@@ -14,6 +14,7 @@ from typing import Optional
import difflib import difflib
from src.chat.message_receive.message import MessageRecv # 添加 MessageRecv 导入 from src.chat.message_receive.message import MessageRecv # 添加 MessageRecv 导入
from src.chat.heart_flow.observation.observation import Observation from src.chat.heart_flow.observation.observation import Observation
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
from src.chat.heart_flow.utils_chat import get_chat_type_and_target_info from src.chat.heart_flow.utils_chat import get_chat_type_and_target_info
from src.chat.utils.prompt_builder import Prompt from src.chat.utils.prompt_builder import Prompt
@@ -43,6 +44,7 @@ class ChattingObservation(Observation):
def __init__(self, chat_id): def __init__(self, chat_id):
super().__init__(chat_id) super().__init__(chat_id)
self.chat_id = chat_id self.chat_id = chat_id
self.platform = "qq"
# --- Initialize attributes (defaults) --- # --- Initialize attributes (defaults) ---
self.is_group_chat: bool = False self.is_group_chat: bool = False
@@ -65,7 +67,7 @@ class ChattingObservation(Observation):
self.oldest_messages_str = "" self.oldest_messages_str = ""
self.compressor_prompt = "" self.compressor_prompt = ""
# TODO: API-Adapter修改标记 # TODO: API-Adapter修改标记
self.llm_summary = LLMRequest( self.model_summary = LLMRequest(
model=global_config.model.observation, temperature=0.7, max_tokens=300, request_type="chat_observation" model=global_config.model.observation, temperature=0.7, max_tokens=300, request_type="chat_observation"
) )
@@ -106,7 +108,7 @@ class ChattingObservation(Observation):
mid_memory_str += f"{mid_memory['theme']}\n" mid_memory_str += f"{mid_memory['theme']}\n"
return mid_memory_str + "现在群里正在聊:\n" + self.talking_message_str return mid_memory_str + "现在群里正在聊:\n" + self.talking_message_str
def serch_message_by_text(self, text: str) -> Optional[MessageRecv]: def search_message_by_text(self, text: str) -> Optional[MessageRecv]:
""" """
根据回复的纯文本 根据回复的纯文本
1. 在talking_message中查找最新的最匹配的消息 1. 在talking_message中查找最新的最匹配的消息
@@ -119,12 +121,12 @@ class ChattingObservation(Observation):
for message in reverse_talking_message: for message in reverse_talking_message:
if message["processed_plain_text"] == text: if message["processed_plain_text"] == text:
find_msg = message find_msg = message
logger.debug(f"找到的锚定消息find_msg: {find_msg}") # logger.debug(f"找到的锚定消息find_msg: {find_msg}")
break break
else: else:
similarity = difflib.SequenceMatcher(None, text, message["processed_plain_text"]).ratio() similarity = difflib.SequenceMatcher(None, text, message["processed_plain_text"]).ratio()
msg_list.append({"message": message, "similarity": similarity}) msg_list.append({"message": message, "similarity": similarity})
logger.debug(f"对锚定消息检查message: {message['processed_plain_text']},similarity: {similarity}") # logger.debug(f"对锚定消息检查message: {message['processed_plain_text']},similarity: {similarity}")
if not find_msg: if not find_msg:
if msg_list: if msg_list:
msg_list.sort(key=lambda x: x["similarity"], reverse=True) msg_list.sort(key=lambda x: x["similarity"], reverse=True)
@@ -151,7 +153,7 @@ class ChattingObservation(Observation):
} }
message_info = { message_info = {
"platform": find_msg.get("platform"), "platform": self.platform,
"message_id": find_msg.get("message_id"), "message_id": find_msg.get("message_id"),
"time": find_msg.get("time"), "time": find_msg.get("time"),
"group_info": group_info, "group_info": group_info,

View File

@@ -3,6 +3,7 @@
from datetime import datetime from datetime import datetime
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
from src.chat.focus_chat.heartFC_Cycleinfo import CycleDetail from src.chat.focus_chat.heartFC_Cycleinfo import CycleDetail
from src.chat.focus_chat.planners.action_manager import ActionManager
from typing import List from typing import List
# Import the new utility function # Import the new utility function
@@ -16,15 +17,17 @@ class HFCloopObservation:
self.observe_id = observe_id self.observe_id = observe_id
self.last_observe_time = datetime.now().timestamp() # 初始化为当前时间 self.last_observe_time = datetime.now().timestamp() # 初始化为当前时间
self.history_loop: List[CycleDetail] = [] self.history_loop: List[CycleDetail] = []
self.action_manager = ActionManager()
def get_observe_info(self): def get_observe_info(self):
return self.observe_info return self.observe_info
def add_loop_info(self, loop_info: CycleDetail): def add_loop_info(self, loop_info: CycleDetail):
# logger.debug(f"添加循环信息111111111111111111111111111111111111: {loop_info}")
# print(f"添加循环信息111111111111111111111111111111111111: {loop_info}")
self.history_loop.append(loop_info) self.history_loop.append(loop_info)
def set_action_manager(self, action_manager: ActionManager):
self.action_manager = action_manager
async def observe(self): async def observe(self):
recent_active_cycles: List[CycleDetail] = [] recent_active_cycles: List[CycleDetail] = []
for cycle in reversed(self.history_loop): for cycle in reversed(self.history_loop):
@@ -62,7 +65,6 @@ class HFCloopObservation:
if cycle_info_block: if cycle_info_block:
cycle_info_block = f"\n你最近的回复\n{cycle_info_block}\n" cycle_info_block = f"\n你最近的回复\n{cycle_info_block}\n"
else: else:
# 如果最近的活动循环不是文本回复,或者没有活动循环
cycle_info_block = "\n" cycle_info_block = "\n"
# 获取history_loop中最新添加的 # 获取history_loop中最新添加的
@@ -72,8 +74,16 @@ class HFCloopObservation:
end_time = last_loop.end_time end_time = last_loop.end_time
if start_time is not None and end_time is not None: if start_time is not None and end_time is not None:
time_diff = int(end_time - start_time) time_diff = int(end_time - start_time)
cycle_info_block += f"\n距离你上一次阅读消息已经过去了{time_diff}分钟\n" if time_diff > 60:
cycle_info_block += f"\n距离你上一次阅读消息已经过去了{time_diff / 60}分钟\n"
else: else:
cycle_info_block += "\n无法获取上一次阅读消息的时间\n" cycle_info_block += f"\n距离你上一次阅读消息已经过去了{time_diff}\n"
else:
cycle_info_block += "\n你还没看过消息\n"
using_actions = self.action_manager.get_using_actions()
for action_name, action_info in using_actions.items():
action_description = action_info["description"]
cycle_info_block += f"\n你在聊天中可以使用{action_name},这个动作的描述是{action_description}\n"
self.observe_info = cycle_info_block self.observe_info = cycle_info_block

View File

@@ -1,55 +0,0 @@
from src.chat.heart_flow.observation.observation import Observation
from datetime import datetime
from src.common.logger_manager import get_logger
import traceback
# Import the new utility function
from src.chat.memory_system.Hippocampus import HippocampusManager
import jieba
from typing import List
logger = get_logger("memory")
class MemoryObservation(Observation):
def __init__(self, observe_id):
super().__init__(observe_id)
self.observe_info: str = ""
self.context: str = ""
self.running_memory: List[dict] = []
def get_observe_info(self):
for memory in self.running_memory:
self.observe_info += f"{memory['topic']}:{memory['content']}\n"
return self.observe_info
async def observe(self):
# ---------- 2. 获取记忆 ----------
try:
# 从聊天内容中提取关键词
chat_words = set(jieba.cut(self.context))
# 过滤掉停用词和单字词
keywords = [word for word in chat_words if len(word) > 1]
# 去重并限制数量
keywords = list(set(keywords))[:5]
logger.debug(f"取的关键词: {keywords}")
# 调用记忆系统获取相关记忆
related_memory = await HippocampusManager.get_instance().get_memory_from_topic(
valid_keywords=keywords, max_memory_num=3, max_memory_length=2, max_depth=3
)
logger.debug(f"获取到的记忆: {related_memory}")
if related_memory:
for topic, memory in related_memory:
# 将记忆添加到 running_memory
self.running_memory.append(
{"topic": topic, "content": memory, "timestamp": datetime.now().isoformat()}
)
logger.debug(f"添加新记忆: {topic} - {memory}")
except Exception as e:
logger.error(f"观察 记忆时出错: {e}")
logger.error(traceback.format_exc())

View File

@@ -0,0 +1,32 @@
from datetime import datetime
from src.common.logger_manager import get_logger
# Import the new utility function
logger = get_logger("observation")
# 所有观察的基类
class StructureObservation:
def __init__(self, observe_id):
self.observe_info = ""
self.observe_id = observe_id
self.last_observe_time = datetime.now().timestamp() # 初始化为当前时间
self.history_loop = []
self.structured_info = []
def get_observe_info(self):
return self.structured_info
def add_structured_info(self, structured_info: dict):
self.structured_info.append(structured_info)
async def observe(self):
observed_structured_infos = []
for structured_info in self.structured_info:
if structured_info.get("ttl") > 0:
structured_info["ttl"] -= 1
observed_structured_infos.append(structured_info)
logger.debug(f"观察到结构化信息仍旧在: {structured_info}")
self.structured_info = observed_structured_infos

View File

@@ -2,33 +2,33 @@
# 外部世界可以是某个聊天 不同平台的聊天 也可以是任意媒体 # 外部世界可以是某个聊天 不同平台的聊天 也可以是任意媒体
from datetime import datetime from datetime import datetime
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
from src.chat.focus_chat.working_memory.working_memory import WorkingMemory
from src.chat.focus_chat.working_memory.memory_item import MemoryItem
from typing import List
# Import the new utility function # Import the new utility function
logger = get_logger("observation") logger = get_logger("observation")
# 所有观察的基类 # 所有观察的基类
class WorkingObservation: class WorkingMemoryObservation:
def __init__(self, observe_id): def __init__(self, observe_id, working_memory: WorkingMemory):
self.observe_info = "" self.observe_info = ""
self.observe_id = observe_id self.observe_id = observe_id
self.last_observe_time = datetime.now().timestamp() # 初始化为当前时间 self.last_observe_time = datetime.now().timestamp()
self.history_loop = []
self.structured_info = [] self.working_memory = working_memory
self.retrieved_working_memory = []
def get_observe_info(self): def get_observe_info(self):
return self.structured_info return self.working_memory
def add_structured_info(self, structured_info: dict): def add_retrieved_working_memory(self, retrieved_working_memory: List[MemoryItem]):
self.structured_info.append(structured_info) self.retrieved_working_memory.append(retrieved_working_memory)
def get_retrieved_working_memory(self):
return self.retrieved_working_memory
async def observe(self): async def observe(self):
observed_structured_infos = [] pass
for structured_info in self.structured_info:
if structured_info.get("ttl") > 0:
structured_info["ttl"] -= 1
observed_structured_infos.append(structured_info)
logger.debug(f"观察到结构化信息仍旧在: {structured_info}")
self.structured_info = observed_structured_infos

View File

@@ -10,7 +10,7 @@ import jieba
import networkx as nx import networkx as nx
import numpy as np import numpy as np
from collections import Counter from collections import Counter
from ...common.database import db from ...common.database.database import memory_db as db
from ...chat.models.utils_model import LLMRequest from ...chat.models.utils_model import LLMRequest
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
from src.chat.memory_system.sample_distribution import MemoryBuildScheduler # 分布生成器 from src.chat.memory_system.sample_distribution import MemoryBuildScheduler # 分布生成器
@@ -193,7 +193,7 @@ class Hippocampus:
def __init__(self): def __init__(self):
self.memory_graph = MemoryGraph() self.memory_graph = MemoryGraph()
self.llm_topic_judge = None self.llm_topic_judge = None
self.llm_summary = None self.model_summary = None
self.entorhinal_cortex = None self.entorhinal_cortex = None
self.parahippocampal_gyrus = None self.parahippocampal_gyrus = None
@@ -205,7 +205,7 @@ class Hippocampus:
self.entorhinal_cortex.sync_memory_from_db() self.entorhinal_cortex.sync_memory_from_db()
# TODO: API-Adapter修改标记 # TODO: API-Adapter修改标记
self.llm_topic_judge = LLMRequest(global_config.model.topic_judge, request_type="memory") self.llm_topic_judge = LLMRequest(global_config.model.topic_judge, request_type="memory")
self.llm_summary = LLMRequest(global_config.model.summary, request_type="memory") self.model_summary = LLMRequest(global_config.model.summary, request_type="memory")
def get_all_node_names(self) -> list: def get_all_node_names(self) -> list:
"""获取记忆图中所有节点的名字列表""" """获取记忆图中所有节点的名字列表"""
@@ -1167,7 +1167,7 @@ class ParahippocampalGyrus:
# 调用修改后的 topic_what不再需要 time_info # 调用修改后的 topic_what不再需要 time_info
topic_what_prompt = self.hippocampus.topic_what(input_text, topic) topic_what_prompt = self.hippocampus.topic_what(input_text, topic)
try: try:
task = self.hippocampus.llm_summary.generate_response_async(topic_what_prompt) task = self.hippocampus.model_summary.generate_response_async(topic_what_prompt)
tasks.append((topic.strip(), task)) tasks.append((topic.strip(), task))
except Exception as e: except Exception as e:
logger.error(f"生成话题 '{topic}' 的摘要时发生错误: {e}") logger.error(f"生成话题 '{topic}' 的摘要时发生错误: {e}")

View File

@@ -34,7 +34,7 @@ root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
sys.path.append(root_path) sys.path.append(root_path)
from src.common.logger import get_module_logger # noqa E402 from src.common.logger import get_module_logger # noqa E402
from src.common.database import db # noqa E402 from common.database.database import db # noqa E402
logger = get_module_logger("mem_alter") logger = get_module_logger("mem_alter")
console = Console() console = Console()

View File

@@ -72,6 +72,7 @@ class ChatBot:
message_data["message_info"]["user_info"]["user_id"] = str( message_data["message_info"]["user_info"]["user_id"] = str(
message_data["message_info"]["user_info"]["user_id"] message_data["message_info"]["user_info"]["user_id"]
) )
# print(message_data)
logger.trace(f"处理消息:{str(message_data)[:120]}...") logger.trace(f"处理消息:{str(message_data)[:120]}...")
message = MessageRecv(message_data) message = MessageRecv(message_data)
groupinfo = message.message_info.group_info groupinfo = message.message_info.group_info
@@ -86,12 +87,14 @@ class ChatBot:
logger.trace("检测到私聊消息,检查") logger.trace("检测到私聊消息,检查")
# 好友黑名单拦截 # 好友黑名单拦截
if userinfo.user_id not in global_config.experimental.talk_allowed_private: if userinfo.user_id not in global_config.experimental.talk_allowed_private:
logger.debug(f"用户{userinfo.user_id}没有私聊权限") # logger.debug(f"用户{userinfo.user_id}没有私聊权限")
return return
# 群聊黑名单拦截 # 群聊黑名单拦截
# print(groupinfo.group_id)
# print(global_config.chat_target.talk_allowed_groups)
if groupinfo is not None and groupinfo.group_id not in global_config.chat_target.talk_allowed_groups: if groupinfo is not None and groupinfo.group_id not in global_config.chat_target.talk_allowed_groups:
logger.trace(f"{groupinfo.group_id}被禁止回复") logger.debug(f"{groupinfo.group_id}被禁止回复")
return return
# 确认从接口发来的message是否有自定义的prompt模板信息 # 确认从接口发来的message是否有自定义的prompt模板信息

View File

@@ -5,7 +5,8 @@ import copy
from typing import Dict, Optional from typing import Dict, Optional
from ...common.database import db from ...common.database.database import db
from ...common.database.database_model import ChatStreams # 新增导入
from maim_message import GroupInfo, UserInfo from maim_message import GroupInfo, UserInfo
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
@@ -82,7 +83,13 @@ class ChatManager:
def __init__(self): def __init__(self):
if not self._initialized: if not self._initialized:
self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream
self._ensure_collection() try:
db.connect(reuse_if_open=True)
# 确保 ChatStreams 表存在
db.create_tables([ChatStreams], safe=True)
except Exception as e:
logger.error(f"数据库连接或 ChatStreams 表创建失败: {e}")
self._initialized = True self._initialized = True
# 在事件循环中启动初始化 # 在事件循环中启动初始化
# asyncio.create_task(self._initialize()) # asyncio.create_task(self._initialize())
@@ -107,15 +114,6 @@ class ChatManager:
except Exception as e: except Exception as e:
logger.error(f"聊天流自动保存失败: {str(e)}") logger.error(f"聊天流自动保存失败: {str(e)}")
@staticmethod
def _ensure_collection():
"""确保数据库集合存在并创建索引"""
if "chat_streams" not in db.list_collection_names():
db.create_collection("chat_streams")
# 创建索引
db.chat_streams.create_index([("stream_id", 1)], unique=True)
db.chat_streams.create_index([("platform", 1), ("user_info.user_id", 1), ("group_info.group_id", 1)])
@staticmethod @staticmethod
def _generate_stream_id(platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None) -> str: def _generate_stream_id(platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None) -> str:
"""生成聊天流唯一ID""" """生成聊天流唯一ID"""
@@ -151,16 +149,43 @@ class ChatManager:
stream = self.streams[stream_id] stream = self.streams[stream_id]
# 更新用户信息和群组信息 # 更新用户信息和群组信息
stream.update_active_time() stream.update_active_time()
stream = copy.deepcopy(stream) stream = copy.deepcopy(stream) # 返回副本以避免外部修改影响缓存
stream.user_info = user_info stream.user_info = user_info
if group_info: if group_info:
stream.group_info = group_info stream.group_info = group_info
return stream return stream
# 检查数据库中是否存在 # 检查数据库中是否存在
data = db.chat_streams.find_one({"stream_id": stream_id}) def _db_find_stream_sync(s_id: str):
if data: return ChatStreams.get_or_none(ChatStreams.stream_id == s_id)
stream = ChatStream.from_dict(data)
model_instance = await asyncio.to_thread(_db_find_stream_sync, stream_id)
if model_instance:
# 从 Peewee 模型转换回 ChatStream.from_dict 期望的格式
user_info_data = {
"platform": model_instance.user_platform,
"user_id": model_instance.user_id,
"user_nickname": model_instance.user_nickname,
"user_cardname": model_instance.user_cardname or "",
}
group_info_data = None
if model_instance.group_id: # 假设 group_id 为空字符串表示没有群组信息
group_info_data = {
"platform": model_instance.group_platform,
"group_id": model_instance.group_id,
"group_name": model_instance.group_name,
}
data_for_from_dict = {
"stream_id": model_instance.stream_id,
"platform": model_instance.platform,
"user_info": user_info_data,
"group_info": group_info_data,
"create_time": model_instance.create_time,
"last_active_time": model_instance.last_active_time,
}
stream = ChatStream.from_dict(data_for_from_dict)
# 更新用户信息和群组信息 # 更新用户信息和群组信息
stream.user_info = user_info stream.user_info = user_info
if group_info: if group_info:
@@ -175,7 +200,7 @@ class ChatManager:
group_info=group_info, group_info=group_info,
) )
except Exception as e: except Exception as e:
logger.error(f"创建聊天流失败: {e}") logger.error(f"获取或创建聊天流失败: {e}", exc_info=True)
raise e raise e
# 保存到内存和数据库 # 保存到内存和数据库
@@ -205,15 +230,38 @@ class ChatManager:
elif stream.user_info and stream.user_info.user_nickname: elif stream.user_info and stream.user_info.user_nickname:
return f"{stream.user_info.user_nickname}的私聊" return f"{stream.user_info.user_nickname}的私聊"
else: else:
# 如果没有群名或用户昵称,返回 None 或其他默认值
return None return None
@staticmethod @staticmethod
async def _save_stream(stream: ChatStream): async def _save_stream(stream: ChatStream):
"""保存聊天流到数据库""" """保存聊天流到数据库"""
if not stream.saved: if not stream.saved:
db.chat_streams.update_one({"stream_id": stream.stream_id}, {"$set": stream.to_dict()}, upsert=True) stream_data_dict = stream.to_dict()
def _db_save_stream_sync(s_data_dict: dict):
user_info_d = s_data_dict.get("user_info")
group_info_d = s_data_dict.get("group_info")
fields_to_save = {
"platform": s_data_dict["platform"],
"create_time": s_data_dict["create_time"],
"last_active_time": s_data_dict["last_active_time"],
"user_platform": user_info_d["platform"] if user_info_d else "",
"user_id": user_info_d["user_id"] if user_info_d else "",
"user_nickname": user_info_d["user_nickname"] if user_info_d else "",
"user_cardname": user_info_d.get("user_cardname", "") if user_info_d else None,
"group_platform": group_info_d["platform"] if group_info_d else "",
"group_id": group_info_d["group_id"] if group_info_d else "",
"group_name": group_info_d["group_name"] if group_info_d else "",
}
ChatStreams.replace(stream_id=s_data_dict["stream_id"], **fields_to_save).execute()
try:
await asyncio.to_thread(_db_save_stream_sync, stream_data_dict)
stream.saved = True stream.saved = True
except Exception as e:
logger.error(f"保存聊天流 {stream.stream_id} 到数据库失败 (Peewee): {e}", exc_info=True)
async def _save_all_streams(self): async def _save_all_streams(self):
"""保存所有聊天流""" """保存所有聊天流"""
@@ -222,10 +270,44 @@ class ChatManager:
async def load_all_streams(self): async def load_all_streams(self):
"""从数据库加载所有聊天流""" """从数据库加载所有聊天流"""
all_streams = db.chat_streams.find({})
for data in all_streams: def _db_load_all_streams_sync():
loaded_streams_data = []
for model_instance in ChatStreams.select():
user_info_data = {
"platform": model_instance.user_platform,
"user_id": model_instance.user_id,
"user_nickname": model_instance.user_nickname,
"user_cardname": model_instance.user_cardname or "",
}
group_info_data = None
if model_instance.group_id:
group_info_data = {
"platform": model_instance.group_platform,
"group_id": model_instance.group_id,
"group_name": model_instance.group_name,
}
data_for_from_dict = {
"stream_id": model_instance.stream_id,
"platform": model_instance.platform,
"user_info": user_info_data,
"group_info": group_info_data,
"create_time": model_instance.create_time,
"last_active_time": model_instance.last_active_time,
}
loaded_streams_data.append(data_for_from_dict)
return loaded_streams_data
try:
all_streams_data_list = await asyncio.to_thread(_db_load_all_streams_sync)
self.streams.clear()
for data in all_streams_data_list:
stream = ChatStream.from_dict(data) stream = ChatStream.from_dict(data)
stream.saved = True
self.streams[stream.stream_id] = stream self.streams[stream.stream_id] = stream
except Exception as e:
logger.error(f"从数据库加载所有聊天流失败 (Peewee): {e}", exc_info=True)
# 创建全局单例 # 创建全局单例

View File

@@ -1,9 +1,10 @@
import re import re
from typing import Union from typing import Union
from ...common.database import db # from ...common.database.database import db # db is now Peewee's SqliteDatabase instance
from .message import MessageSending, MessageRecv from .message import MessageSending, MessageRecv
from .chat_stream import ChatStream from .chat_stream import ChatStream
from ...common.database.database_model import Messages, RecalledMessages # Import Peewee models
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
logger = get_module_logger("message_storage") logger = get_module_logger("message_storage")
@@ -29,34 +30,56 @@ class MessageStorage:
else: else:
filtered_detailed_plain_text = "" filtered_detailed_plain_text = ""
message_data = { chat_info_dict = chat_stream.to_dict()
"message_id": message.message_info.message_id, user_info_dict = message.message_info.user_info.to_dict()
"time": message.message_info.time,
"chat_id": chat_stream.stream_id, # message_id 现在是 TextField直接使用字符串值
"chat_info": chat_stream.to_dict(), msg_id = message.message_info.message_id
"user_info": message.message_info.user_info.to_dict(),
# 使用过滤后的文本 # 安全地获取 group_info, 如果为 None 则视为空字典
"processed_plain_text": filtered_processed_plain_text, group_info_from_chat = chat_info_dict.get("group_info") or {}
"detailed_plain_text": filtered_detailed_plain_text, # 安全地获取 user_info, 如果为 None 则视为空字典 (以防万一)
"memorized_times": message.memorized_times, user_info_from_chat = chat_info_dict.get("user_info") or {}
}
db.messages.insert_one(message_data) Messages.create(
message_id=msg_id,
time=float(message.message_info.time),
chat_id=chat_stream.stream_id,
# Flattened chat_info
chat_info_stream_id=chat_info_dict.get("stream_id"),
chat_info_platform=chat_info_dict.get("platform"),
chat_info_user_platform=user_info_from_chat.get("platform"),
chat_info_user_id=user_info_from_chat.get("user_id"),
chat_info_user_nickname=user_info_from_chat.get("user_nickname"),
chat_info_user_cardname=user_info_from_chat.get("user_cardname"),
chat_info_group_platform=group_info_from_chat.get("platform"),
chat_info_group_id=group_info_from_chat.get("group_id"),
chat_info_group_name=group_info_from_chat.get("group_name"),
chat_info_create_time=float(chat_info_dict.get("create_time", 0.0)),
chat_info_last_active_time=float(chat_info_dict.get("last_active_time", 0.0)),
# Flattened user_info (message sender)
user_platform=user_info_dict.get("platform"),
user_id=user_info_dict.get("user_id"),
user_nickname=user_info_dict.get("user_nickname"),
user_cardname=user_info_dict.get("user_cardname"),
# Text content
processed_plain_text=filtered_processed_plain_text,
detailed_plain_text=filtered_detailed_plain_text,
memorized_times=message.memorized_times,
)
except Exception: except Exception:
logger.exception("存储消息失败") logger.exception("存储消息失败")
@staticmethod @staticmethod
async def store_recalled_message(message_id: str, time: str, chat_stream: ChatStream) -> None: async def store_recalled_message(message_id: str, time: str, chat_stream: ChatStream) -> None:
"""存储撤回消息到数据库""" """存储撤回消息到数据库"""
if "recalled_messages" not in db.list_collection_names(): # Table creation is handled by initialize_database in database_model.py
db.create_collection("recalled_messages")
else:
try: try:
message_data = { RecalledMessages.create(
"message_id": message_id, message_id=message_id,
"time": time, time=float(time), # Assuming time is a string representing a float timestamp
"stream_id": chat_stream.stream_id, stream_id=chat_stream.stream_id,
} )
db.recalled_messages.insert_one(message_data)
except Exception: except Exception:
logger.exception("存储撤回消息失败") logger.exception("存储撤回消息失败")
@@ -64,7 +87,9 @@ class MessageStorage:
async def remove_recalled_message(time: str) -> None: async def remove_recalled_message(time: str) -> None:
"""删除撤回消息""" """删除撤回消息"""
try: try:
db.recalled_messages.delete_many({"time": {"$lt": time - 300}}) # Assuming input 'time' is a string timestamp that can be converted to float
current_time_float = float(time)
RecalledMessages.delete().where(RecalledMessages.time < (current_time_float - 300)).execute()
except Exception: except Exception:
logger.exception("删除撤回消息失败") logger.exception("删除撤回消息失败")

View File

@@ -12,7 +12,8 @@ import base64
from PIL import Image from PIL import Image
import io import io
import os import os
from ...common.database import db from src.common.database.database import db # 确保 db 被导入用于 create_tables
from src.common.database.database_model import LLMUsage # 导入 LLMUsage 模型
from ...config.config import global_config from ...config.config import global_config
from rich.traceback import install from rich.traceback import install
@@ -85,8 +86,6 @@ async def _safely_record(request_content: Dict[str, Any], payload: Dict[str, Any
f"data:image/{image_format.lower() if image_format else 'jpeg'};base64," f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,"
f"{image_base64[:10]}...{image_base64[-10:]}" f"{image_base64[:10]}...{image_base64[-10:]}"
) )
# if isinstance(content, str) and len(content) > 100:
# payload["messages"][0]["content"] = content[:100]
return payload return payload
@@ -134,13 +133,11 @@ class LLMRequest:
def _init_database(): def _init_database():
"""初始化数据库集合""" """初始化数据库集合"""
try: try:
# 创建llm_usage集合的索引 # 使用 Peewee 创建表safe=True 表示如果表已存在则不会抛出错误
db.llm_usage.create_index([("timestamp", 1)]) db.create_tables([LLMUsage], safe=True)
db.llm_usage.create_index([("model_name", 1)]) logger.debug("LLMUsage 表已初始化/确保存在。")
db.llm_usage.create_index([("user_id", 1)])
db.llm_usage.create_index([("request_type", 1)])
except Exception as e: except Exception as e:
logger.error(f"创建数据库索引失败: {str(e)}") logger.error(f"创建 LLMUsage 表失败: {str(e)}")
def _record_usage( def _record_usage(
self, self,
@@ -165,19 +162,19 @@ class LLMRequest:
request_type = self.request_type request_type = self.request_type
try: try:
usage_data = { # 使用 Peewee 模型创建记录
"model_name": self.model_name, LLMUsage.create(
"user_id": user_id, model_name=self.model_name,
"request_type": request_type, user_id=user_id,
"endpoint": endpoint, request_type=request_type,
"prompt_tokens": prompt_tokens, endpoint=endpoint,
"completion_tokens": completion_tokens, prompt_tokens=prompt_tokens,
"total_tokens": total_tokens, completion_tokens=completion_tokens,
"cost": self._calculate_cost(prompt_tokens, completion_tokens), total_tokens=total_tokens,
"status": "success", cost=self._calculate_cost(prompt_tokens, completion_tokens),
"timestamp": datetime.now(), status="success",
} timestamp=datetime.now(), # Peewee 会处理 DateTimeField
db.llm_usage.insert_one(usage_data) )
logger.trace( logger.trace(
f"Token使用情况 - 模型: {self.model_name}, " f"Token使用情况 - 模型: {self.model_name}, "
f"用户: {user_id}, 类型: {request_type}, " f"用户: {user_id}, 类型: {request_type}, "

View File

@@ -1,5 +1,6 @@
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
from ...common.database import db from ...common.database.database import db
from ...common.database.database_model import PersonInfo # 新增导入
import copy import copy
import hashlib import hashlib
from typing import Any, Callable, Dict from typing import Any, Callable, Dict
@@ -16,7 +17,7 @@ matplotlib.use("Agg")
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from pathlib import Path from pathlib import Path
import pandas as pd import pandas as pd
import json import json # 新增导入
import re import re
@@ -38,22 +39,18 @@ logger = get_logger("person_info")
person_info_default = { person_info_default = {
"person_id": None, "person_id": None,
"person_name": None, "person_name": None, # 模型中已设为 null=True此默认值OK
"name_reason": None, "name_reason": None,
"platform": None, "platform": "unknown", # 提供非None的默认值
"user_id": None, "user_id": "unknown", # 提供非None的默认值
"nickname": None, "nickname": "Unknown", # 提供非None的默认值
# "age" : 0,
"relationship_value": 0, "relationship_value": 0,
# "saved" : True, "know_time": 0, # 修正拼写konw_time -> know_time
# "impression" : None,
# "gender" : Unkown,
"konw_time": 0,
"msg_interval": 2000, "msg_interval": 2000,
"msg_interval_list": [], "msg_interval_list": [], # 将作为 JSON 字符串存储在 Peewee 的 TextField
"user_cardname": None, # 添加群名片 "user_cardname": None, # 注意:此字段不在 PersonInfo Peewee 模型中
"user_avatar": None, # 添加头像信息例如URL或标识符 "user_avatar": None, # 注意:此字段不在 PersonInfo Peewee 模型中
} # 个人信息的各项与默认值在此定义,以下处理会自动创建/补全每一项 }
class PersonInfoManager: class PersonInfoManager:
@@ -65,21 +62,26 @@ class PersonInfoManager:
max_tokens=256, max_tokens=256,
request_type="qv_name", request_type="qv_name",
) )
if "person_info" not in db.list_collection_names(): try:
db.create_collection("person_info") db.connect(reuse_if_open=True)
db.person_info.create_index("person_id", unique=True) db.create_tables([PersonInfo], safe=True)
except Exception as e:
logger.error(f"数据库连接或 PersonInfo 表创建失败: {e}")
# 初始化时读取所有person_name # 初始化时读取所有person_name
cursor = db.person_info.find({"person_name": {"$exists": True}}, {"person_id": 1, "person_name": 1, "_id": 0}) try:
for doc in cursor: for record in PersonInfo.select(PersonInfo.person_id, PersonInfo.person_name).where(
if doc.get("person_name"): PersonInfo.person_name.is_null(False)
self.person_name_list[doc["person_id"]] = doc["person_name"] ):
logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称") if record.person_name:
self.person_name_list[record.person_id] = record.person_name
logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称 (Peewee)")
except Exception as e:
logger.error(f"从 Peewee 加载 person_name_list 失败: {e}")
@staticmethod @staticmethod
def get_person_id(platform: str, user_id: int): def get_person_id(platform: str, user_id: int):
"""获取唯一id""" """获取唯一id"""
# 如果platform中存在-,就截取-后面的部分
if "-" in platform: if "-" in platform:
platform = platform.split("-")[1] platform = platform.split("-")[1]
@@ -87,15 +89,27 @@ class PersonInfoManager:
key = "_".join(components) key = "_".join(components)
return hashlib.md5(key.encode()).hexdigest() return hashlib.md5(key.encode()).hexdigest()
def is_person_known(self, platform: str, user_id: int): async def is_person_known(self, platform: str, user_id: int):
"""判断是否认识某人""" """判断是否认识某人"""
person_id = self.get_person_id(platform, user_id) person_id = self.get_person_id(platform, user_id)
document = db.person_info.find_one({"person_id": person_id})
if document: def _db_check_known_sync(p_id: str):
return True return PersonInfo.get_or_none(PersonInfo.person_id == p_id) is not None
else:
try:
return await asyncio.to_thread(_db_check_known_sync, person_id)
except Exception as e:
logger.error(f"检查用户 {person_id} 是否已知时出错 (Peewee): {e}")
return False return False
def get_person_id_by_person_name(self, person_name: str):
"""根据用户名获取用户ID"""
document = db.person_info.find_one({"person_name": person_name})
if document:
return document["person_id"]
else:
return ""
@staticmethod @staticmethod
async def create_person_info(person_id: str, data: dict = None): async def create_person_info(person_id: str, data: dict = None):
"""创建一个项""" """创建一个项"""
@@ -104,73 +118,111 @@ class PersonInfoManager:
return return
_person_info_default = copy.deepcopy(person_info_default) _person_info_default = copy.deepcopy(person_info_default)
_person_info_default["person_id"] = person_id model_fields = PersonInfo._meta.fields.keys()
final_data = {"person_id": person_id}
if data: if data:
for key in _person_info_default: for key, value in data.items():
if key != "person_id" and key in data: if key in model_fields:
_person_info_default[key] = data[key] final_data[key] = value
db.person_info.insert_one(_person_info_default) for key, default_value in _person_info_default.items():
if key in model_fields and key not in final_data:
final_data[key] = default_value
if "msg_interval_list" in final_data and isinstance(final_data["msg_interval_list"], list):
final_data["msg_interval_list"] = json.dumps(final_data["msg_interval_list"])
elif "msg_interval_list" not in final_data and "msg_interval_list" in model_fields:
final_data["msg_interval_list"] = json.dumps([])
def _db_create_sync(p_data: dict):
try:
PersonInfo.create(**p_data)
return True
except Exception as e:
logger.error(f"创建 PersonInfo 记录 {p_data.get('person_id')} 失败 (Peewee): {e}")
return False
await asyncio.to_thread(_db_create_sync, final_data)
async def update_one_field(self, person_id: str, field_name: str, value, data: dict = None): async def update_one_field(self, person_id: str, field_name: str, value, data: dict = None):
"""更新某一个字段,会补全""" """更新某一个字段,会补全"""
if field_name not in person_info_default.keys(): if field_name not in PersonInfo._meta.fields:
logger.debug(f"更新'{field_name}'失败,未定义的字段") if field_name in person_info_default:
logger.debug(f"更新'{field_name}'跳过,字段存在于默认配置但不在 PersonInfo Peewee 模型中。")
return
logger.debug(f"更新'{field_name}'失败,未在 PersonInfo Peewee 模型中定义的字段。")
return return
document = db.person_info.find_one({"person_id": person_id}) def _db_update_sync(p_id: str, f_name: str, val):
record = PersonInfo.get_or_none(PersonInfo.person_id == p_id)
if document: if record:
db.person_info.update_one({"person_id": person_id}, {"$set": {field_name: value}}) if f_name == "msg_interval_list" and isinstance(val, list):
setattr(record, f_name, json.dumps(val))
else: else:
data[field_name] = value setattr(record, f_name, val)
logger.debug(f"更新时{person_id}不存在,已新建") record.save()
await self.create_person_info(person_id, data) return True, False
return False, True
found, needs_creation = await asyncio.to_thread(_db_update_sync, person_id, field_name, value)
if needs_creation:
logger.debug(f"更新时 {person_id} 不存在,将新建。")
creation_data = data if data is not None else {}
creation_data[field_name] = value
if "platform" not in creation_data or "user_id" not in creation_data:
logger.warning(f"{person_id} 创建记录时platform/user_id 可能缺失。")
await self.create_person_info(person_id, creation_data)
@staticmethod @staticmethod
async def has_one_field(person_id: str, field_name: str): async def has_one_field(person_id: str, field_name: str):
"""判断是否存在某一个字段""" """判断是否存在某一个字段"""
document = db.person_info.find_one({"person_id": person_id}, {field_name: 1}) if field_name not in PersonInfo._meta.fields:
if document: logger.debug(f"检查字段'{field_name}'失败,未在 PersonInfo Peewee 模型中定义。")
return False
def _db_has_field_sync(p_id: str, f_name: str):
record = PersonInfo.get_or_none(PersonInfo.person_id == p_id)
if record:
return True return True
else: return False
try:
return await asyncio.to_thread(_db_has_field_sync, person_id, field_name)
except Exception as e:
logger.error(f"检查字段 {field_name} for {person_id} 时出错 (Peewee): {e}")
return False return False
@staticmethod @staticmethod
def _extract_json_from_text(text: str) -> dict: def _extract_json_from_text(text: str) -> dict:
"""从文本中提取JSON数据的高容错方法""" """从文本中提取JSON数据的高容错方法"""
try: try:
# 尝试直接解析
parsed_json = json.loads(text) parsed_json = json.loads(text)
# 如果解析结果是列表,尝试取第一个元素
if isinstance(parsed_json, list): if isinstance(parsed_json, list):
if parsed_json: # 检查列表是否为空 if parsed_json:
parsed_json = parsed_json[0] parsed_json = parsed_json[0]
else: # 如果列表为空,重置为 None走后续逻辑 else:
parsed_json = None parsed_json = None
# 确保解析结果是字典
if isinstance(parsed_json, dict): if isinstance(parsed_json, dict):
return parsed_json return parsed_json
except json.JSONDecodeError: except json.JSONDecodeError:
# 解析失败,继续尝试其他方法
pass pass
except Exception as e: except Exception as e:
logger.warning(f"尝试直接解析JSON时发生意外错误: {e}") logger.warning(f"尝试直接解析JSON时发生意外错误: {e}")
pass # 继续尝试其他方法 pass
# 如果直接解析失败或结果不是字典
try: try:
# 尝试找到JSON对象格式的部分
json_pattern = r"\{[^{}]*\}" json_pattern = r"\{[^{}]*\}"
matches = re.findall(json_pattern, text) matches = re.findall(json_pattern, text)
if matches: if matches:
parsed_obj = json.loads(matches[0]) parsed_obj = json.loads(matches[0])
if isinstance(parsed_obj, dict): # 确保是字典 if isinstance(parsed_obj, dict):
return parsed_obj return parsed_obj
# 如果上面都失败了,尝试提取键值对
nickname_pattern = r'"nickname"[:\s]+"([^"]+)"' nickname_pattern = r'"nickname"[:\s]+"([^"]+)"'
reason_pattern = r'"reason"[:\s]+"([^"]+)"' reason_pattern = r'"reason"[:\s]+"([^"]+)"'
@@ -185,7 +237,6 @@ class PersonInfoManager:
except Exception as e: except Exception as e:
logger.error(f"后备JSON提取失败: {str(e)}") logger.error(f"后备JSON提取失败: {str(e)}")
# 如果所有方法都失败了,返回默认字典
logger.warning(f"无法从文本中提取有效的JSON字典: {text}") logger.warning(f"无法从文本中提取有效的JSON字典: {text}")
return {"nickname": "", "reason": ""} return {"nickname": "", "reason": ""}
@@ -200,9 +251,11 @@ class PersonInfoManager:
old_name = await self.get_value(person_id, "person_name") old_name = await self.get_value(person_id, "person_name")
old_reason = await self.get_value(person_id, "name_reason") old_reason = await self.get_value(person_id, "name_reason")
max_retries = 5 # 最大重试次数 max_retries = 5
current_try = 0 current_try = 0
existing_names = "" existing_names_str = ""
current_name_set = set(self.person_name_list.values())
while current_try < max_retries: while current_try < max_retries:
individuality = Individuality.get_instance() individuality = Individuality.get_instance()
prompt_personality = individuality.get_prompt(x_person=2, level=1) prompt_personality = individuality.get_prompt(x_person=2, level=1)
@@ -217,45 +270,58 @@ class PersonInfoManager:
qv_name_prompt += f"你之前叫他{old_name},是因为{old_reason}" qv_name_prompt += f"你之前叫他{old_name},是因为{old_reason}"
qv_name_prompt += f"\n其他取名的要求是:{request},不要太浮夸" qv_name_prompt += f"\n其他取名的要求是:{request},不要太浮夸"
qv_name_prompt += ( qv_name_prompt += (
"\n请根据以上用户信息想想你叫他什么比较好不要太浮夸请最好使用用户的qq昵称可以稍作修改" "\n请根据以上用户信息想想你叫他什么比较好不要太浮夸请最好使用用户的qq昵称可以稍作修改"
) )
if existing_names:
qv_name_prompt += f"\n请注意,以下名称已被使用,不要使用以下昵称:{existing_names}\n" if existing_names_str:
qv_name_prompt += f"\n请注意,以下名称已被你尝试过或已知存在,请避免:{existing_names_str}\n"
if len(current_name_set) < 50 and current_name_set:
qv_name_prompt += f"已知的其他昵称有: {', '.join(list(current_name_set)[:10])}等。\n"
qv_name_prompt += "请用json给出你的想法并给出理由示例如下" qv_name_prompt += "请用json给出你的想法并给出理由示例如下"
qv_name_prompt += """{ qv_name_prompt += """{
"nickname": "昵称", "nickname": "昵称",
"reason": "理由" "reason": "理由"
}""" }"""
# logger.debug(f"取名提示词:{qv_name_prompt}")
response = await self.qv_name_llm.generate_response(qv_name_prompt) response = await self.qv_name_llm.generate_response(qv_name_prompt)
logger.trace(f"取名提示词:{qv_name_prompt}\n取名回复:{response}") logger.trace(f"取名提示词:{qv_name_prompt}\n取名回复:{response}")
result = self._extract_json_from_text(response[0]) result = self._extract_json_from_text(response[0])
if not result["nickname"]: if not result or not result.get("nickname"):
logger.error("生成的昵称为空,重试中...") logger.error("生成的昵称为空或结果格式不正确,重试中...")
current_try += 1 current_try += 1
continue continue
# 检查生成的昵称是否已存在 generated_nickname = result["nickname"]
if result["nickname"] not in self.person_name_list.values():
# 更新数据库和内存中的列表
await self.update_one_field(person_id, "person_name", result["nickname"])
# await self.update_one_field(person_id, "nickname", user_nickname)
# await self.update_one_field(person_id, "avatar", user_avatar)
await self.update_one_field(person_id, "name_reason", result["reason"])
self.person_name_list[person_id] = result["nickname"] is_duplicate = False
# logger.debug(f"用户 {person_id} 的名称已更新为 {result['nickname']},原因:{result['reason']}") if generated_nickname in current_name_set:
is_duplicate = True
else:
def _db_check_name_exists_sync(name_to_check):
return PersonInfo.select().where(PersonInfo.person_name == name_to_check).exists()
if await asyncio.to_thread(_db_check_name_exists_sync, generated_nickname):
is_duplicate = True
current_name_set.add(generated_nickname)
if not is_duplicate:
await self.update_one_field(person_id, "person_name", generated_nickname)
await self.update_one_field(person_id, "name_reason", result.get("reason", "未提供理由"))
self.person_name_list[person_id] = generated_nickname
return result return result
else: else:
existing_names += f"{result['nickname']}" if existing_names_str:
existing_names_str += ""
logger.debug(f"生成的昵称 {result['nickname']} 已存在,重试中...") existing_names_str += generated_nickname
logger.debug(f"生成的昵称 {generated_nickname} 已存在,重试中...")
current_try += 1 current_try += 1
logger.error(f"{max_retries}次尝试后仍未能生成唯一昵称") logger.error(f"{max_retries}次尝试后仍未能生成唯一昵称 for {person_id}")
return None return None
@staticmethod @staticmethod
@@ -265,30 +331,56 @@ class PersonInfoManager:
logger.debug("删除失败person_id 不能为空") logger.debug("删除失败person_id 不能为空")
return return
result = db.person_info.delete_one({"person_id": person_id}) def _db_delete_sync(p_id: str):
if result.deleted_count > 0: try:
logger.debug(f"删除成功person_id={person_id}") query = PersonInfo.delete().where(PersonInfo.person_id == p_id)
deleted_count = query.execute()
return deleted_count
except Exception as e:
logger.error(f"删除 PersonInfo {p_id} 失败 (Peewee): {e}")
return 0
deleted_count = await asyncio.to_thread(_db_delete_sync, person_id)
if deleted_count > 0:
logger.debug(f"删除成功person_id={person_id} (Peewee)")
else: else:
logger.debug(f"删除失败:未找到 person_id={person_id}") logger.debug(f"删除失败:未找到 person_id={person_id} 或删除未影响行 (Peewee)")
@staticmethod @staticmethod
async def get_value(person_id: str, field_name: str): async def get_value(person_id: str, field_name: str):
"""获取指定person_id文档的字段值若不存在该字段则返回该字段的全局默认值""" """获取指定person_id文档的字段值若不存在该字段则返回该字段的全局默认值"""
if not person_id: if not person_id:
logger.debug("get_value获取失败person_id不能为空") logger.debug("get_value获取失败person_id不能为空")
return person_info_default.get(field_name)
if field_name not in PersonInfo._meta.fields:
if field_name in person_info_default:
logger.trace(f"字段'{field_name}'不在Peewee模型中但存在于默认配置中。返回配置默认值。")
return copy.deepcopy(person_info_default[field_name])
logger.debug(f"get_value获取失败字段'{field_name}'未在Peewee模型和默认配置中定义。")
return None return None
if field_name not in person_info_default: def _db_get_value_sync(p_id: str, f_name: str):
logger.debug(f"get_value获取失败字段'{field_name}'未定义") record = PersonInfo.get_or_none(PersonInfo.person_id == p_id)
if record:
val = getattr(record, f_name)
if f_name == "msg_interval_list" and isinstance(val, str):
try:
return json.loads(val)
except json.JSONDecodeError:
logger.warning(f"无法解析 {p_id} 的 msg_interval_list JSON: {val}")
return copy.deepcopy(person_info_default.get(f_name, []))
return val
return None return None
document = db.person_info.find_one({"person_id": person_id}, {field_name: 1}) value = await asyncio.to_thread(_db_get_value_sync, person_id, field_name)
if document and field_name in document: if value is not None:
return document[field_name] return value
else: else:
default_value = copy.deepcopy(person_info_default[field_name]) default_value = copy.deepcopy(person_info_default.get(field_name))
logger.trace(f"获取{person_id}{field_name}失败,已返回默认值{default_value}") logger.trace(f"获取{person_id}{field_name}失败或值为None,已返回默认值{default_value} (Peewee)")
return default_value return default_value
@staticmethod @staticmethod
@@ -298,93 +390,84 @@ class PersonInfoManager:
logger.debug("get_values获取失败person_id不能为空") logger.debug("get_values获取失败person_id不能为空")
return {} return {}
# 检查所有字段是否有效
for field in field_names:
if field not in person_info_default:
logger.debug(f"get_values获取失败字段'{field}'未定义")
return {}
# 构建查询投影(所有字段都有效才会执行到这里)
projection = {field: 1 for field in field_names}
document = db.person_info.find_one({"person_id": person_id}, projection)
result = {} result = {}
for field in field_names:
result[field] = copy.deepcopy( def _db_get_record_sync(p_id: str):
document.get(field, person_info_default[field]) if document else person_info_default[field] return PersonInfo.get_or_none(PersonInfo.person_id == p_id)
)
record = await asyncio.to_thread(_db_get_record_sync, person_id)
for field_name in field_names:
if field_name not in PersonInfo._meta.fields:
if field_name in person_info_default:
result[field_name] = copy.deepcopy(person_info_default[field_name])
logger.trace(f"字段'{field_name}'不在Peewee模型中使用默认配置值。")
else:
logger.debug(f"get_values查询失败字段'{field_name}'未在Peewee模型和默认配置中定义。")
result[field_name] = None
continue
if record:
value = getattr(record, field_name)
if field_name == "msg_interval_list" and isinstance(value, str):
try:
result[field_name] = json.loads(value)
except json.JSONDecodeError:
logger.warning(f"无法解析 {person_id} 的 msg_interval_list JSON: {value}")
result[field_name] = copy.deepcopy(person_info_default.get(field_name, []))
elif value is not None:
result[field_name] = value
else:
result[field_name] = copy.deepcopy(person_info_default.get(field_name))
else:
result[field_name] = copy.deepcopy(person_info_default.get(field_name))
return result return result
@staticmethod @staticmethod
async def del_all_undefined_field(): async def del_all_undefined_field():
"""删除所有项里的未定义字段""" """删除所有项里的未定义字段 - 对于Peewee (SQL),此操作通常不适用,因为模式是固定的。"""
# 获取所有已定义的字段名 logger.info(
defined_fields = set(person_info_default.keys()) "del_all_undefined_field: 对于使用Peewee的SQL数据库此操作通常不适用或不需要因为表结构是预定义的。"
try:
# 遍历集合中的所有文档
for document in db.person_info.find({}):
# 找出文档中未定义的字段
undefined_fields = set(document.keys()) - defined_fields - {"_id"}
if undefined_fields:
# 构建更新操作,使用$unset删除未定义字段
update_result = db.person_info.update_one(
{"_id": document["_id"]}, {"$unset": {field: 1 for field in undefined_fields}}
) )
if update_result.modified_count > 0:
logger.debug(f"已清理文档 {document['_id']} 的未定义字段: {undefined_fields}")
return
except Exception as e:
logger.error(f"清理未定义字段时出错: {e}")
return return
@staticmethod @staticmethod
async def get_specific_value_list( async def get_specific_value_list(
field_name: str, field_name: str,
way: Callable[[Any], bool], # 接受任意类型值 way: Callable[[Any], bool],
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
获取满足条件的字段值字典 获取满足条件的字段值字典
Args:
field_name: 目标字段名
way: 判断函数 (value: Any) -> bool
Returns:
{person_id: value} | {}
Example:
# 查找所有nickname包含"admin"的用户
result = manager.specific_value_list(
"nickname",
lambda x: "admin" in x.lower()
)
""" """
if field_name not in person_info_default: if field_name not in PersonInfo._meta.fields:
logger.error(f"字段检查失败:'{field_name}'未定义") logger.error(f"字段检查失败:'{field_name}'在 PersonInfo Peewee 模型中定义")
return {} return {}
def _db_get_specific_sync(f_name: str):
found_results = {}
try: try:
result = {} for record in PersonInfo.select(PersonInfo.person_id, getattr(PersonInfo, f_name)):
for doc in db.person_info.find({field_name: {"$exists": True}}, {"person_id": 1, field_name: 1, "_id": 0}): value = getattr(record, f_name)
if f_name == "msg_interval_list" and isinstance(value, str):
try: try:
value = doc[field_name] processed_value = json.loads(value)
if way(value): except json.JSONDecodeError:
result[doc["person_id"]] = value logger.warning(f"跳过记录 {record.person_id},无法解析 msg_interval_list: {value}")
except (KeyError, TypeError, ValueError) as e:
logger.debug(f"记录{doc.get('person_id')}处理失败: {str(e)}")
continue continue
else:
processed_value = value
return result if way(processed_value):
found_results[record.person_id] = processed_value
except Exception as e_query:
logger.error(f"数据库查询失败 (Peewee specific_value_list for {f_name}): {str(e_query)}", exc_info=True)
return found_results
try:
return await asyncio.to_thread(_db_get_specific_sync, field_name)
except Exception as e: except Exception as e:
logger.error(f"数据库查询失败: {str(e)}", exc_info=True) logger.error(f"执行 get_specific_value_list 线程时出错: {str(e)}", exc_info=True)
return {} return {}
async def personal_habit_deduction(self): async def personal_habit_deduction(self):
@@ -392,35 +475,31 @@ class PersonInfoManager:
try: try:
while 1: while 1:
await asyncio.sleep(600) await asyncio.sleep(600)
current_time = datetime.datetime.now() current_time_dt = datetime.datetime.now()
logger.info(f"个人信息推断启动: {current_time.strftime('%Y-%m-%d %H:%M:%S')}") logger.info(f"个人信息推断启动: {current_time_dt.strftime('%Y-%m-%d %H:%M:%S')}")
# "msg_interval"推断 msg_interval_map_generated = False
msg_interval_map = False msg_interval_lists_map = await self.get_specific_value_list(
msg_interval_lists = await self.get_specific_value_list(
"msg_interval_list", lambda x: isinstance(x, list) and len(x) >= 100 "msg_interval_list", lambda x: isinstance(x, list) and len(x) >= 100
) )
for person_id, msg_interval_list_ in msg_interval_lists.items():
for person_id, actual_msg_interval_list in msg_interval_lists_map.items():
await asyncio.sleep(0.3) await asyncio.sleep(0.3)
try: try:
time_interval = [] time_interval = []
for t1, t2 in zip(msg_interval_list_, msg_interval_list_[1:]): for t1, t2 in zip(actual_msg_interval_list, actual_msg_interval_list[1:]):
delta = t2 - t1 delta = t2 - t1
if delta > 0: if delta > 0:
time_interval.append(delta) time_interval.append(delta)
time_interval = [t for t in time_interval if 200 <= t <= 8000] time_interval = [t for t in time_interval if 200 <= t <= 8000]
# --- 修改后的逻辑 ---
# 数据量检查 (至少需要 30 条有效间隔,并且足够进行头尾截断)
if len(time_interval) >= 30 + 10: # 至少30条有效+头尾各5条
time_interval.sort()
# 画图(log) - 这部分保留 if len(time_interval) >= 30 + 10:
msg_interval_map = True time_interval.sort()
msg_interval_map_generated = True
log_dir = Path("logs/person_info") log_dir = Path("logs/person_info")
log_dir.mkdir(parents=True, exist_ok=True) log_dir.mkdir(parents=True, exist_ok=True)
plt.figure(figsize=(10, 6)) plt.figure(figsize=(10, 6))
# 使用截断前的数据画图,更能反映原始分布
time_series_original = pd.Series(time_interval) time_series_original = pd.Series(time_interval)
plt.hist( plt.hist(
time_series_original, time_series_original,
@@ -442,34 +521,29 @@ class PersonInfoManager:
img_path = log_dir / f"interval_distribution_{person_id[:8]}.png" img_path = log_dir / f"interval_distribution_{person_id[:8]}.png"
plt.savefig(img_path) plt.savefig(img_path)
plt.close() plt.close()
# 画图结束
# 去掉头尾各 5 个数据点
trimmed_interval = time_interval[5:-5] trimmed_interval = time_interval[5:-5]
if trimmed_interval:
# 计算截断后数据的 37% 分位数 msg_interval_val = int(round(np.percentile(trimmed_interval, 37)))
if trimmed_interval: # 确保截断后列表不为空 await self.update_one_field(person_id, "msg_interval", msg_interval_val)
msg_interval = int(round(np.percentile(trimmed_interval, 37))) logger.trace(
# 更新数据库 f"用户{person_id}的msg_interval通过头尾截断和37分位数更新为{msg_interval_val}"
await self.update_one_field(person_id, "msg_interval", msg_interval) )
logger.trace(f"用户{person_id}的msg_interval通过头尾截断和37分位数更新为{msg_interval}")
else: else:
logger.trace(f"用户{person_id}截断后数据为空无法计算msg_interval") logger.trace(f"用户{person_id}截断后数据为空无法计算msg_interval")
else: else:
logger.trace( logger.trace(
f"用户{person_id}有效消息间隔数量 ({len(time_interval)}) 不足进行推断 (需要至少 {30 + 10} 条)" f"用户{person_id}有效消息间隔数量 ({len(time_interval)}) 不足进行推断 (需要至少 {30 + 10} 条)"
) )
# --- 修改结束 --- except Exception as e_inner:
except Exception as e: logger.trace(f"用户{person_id}消息间隔计算失败: {type(e_inner).__name__}: {str(e_inner)}")
logger.trace(f"用户{person_id}消息间隔计算失败: {type(e).__name__}: {str(e)}")
continue continue
# 其他... if msg_interval_map_generated:
if msg_interval_map:
logger.trace("已保存分布图到: logs/person_info") logger.trace("已保存分布图到: logs/person_info")
current_time = datetime.datetime.now()
logger.trace(f"个人信息推断结束: {current_time.strftime('%Y-%m-%d %H:%M:%S')}") current_time_dt_end = datetime.datetime.now()
logger.trace(f"个人信息推断结束: {current_time_dt_end.strftime('%Y-%m-%d %H:%M:%S')}")
await asyncio.sleep(86400) await asyncio.sleep(86400)
except Exception as e: except Exception as e:
@@ -482,41 +556,27 @@ class PersonInfoManager:
""" """
根据 platform 和 user_id 获取 person_id。 根据 platform 和 user_id 获取 person_id。
如果对应的用户不存在,则使用提供的可选信息创建新用户。 如果对应的用户不存在,则使用提供的可选信息创建新用户。
Args:
platform: 平台标识
user_id: 用户在该平台上的ID
nickname: 用户的昵称 (可选,用于创建新用户)
user_cardname: 用户的群名片 (可选,用于创建新用户)
user_avatar: 用户的头像信息 (可选,用于创建新用户)
Returns:
对应的 person_id。
""" """
person_id = self.get_person_id(platform, user_id) person_id = self.get_person_id(platform, user_id)
# 检查用户是否已存在 def _db_check_exists_sync(p_id: str):
# 使用静态方法 get_person_id因此可以直接调用 db return PersonInfo.get_or_none(PersonInfo.person_id == p_id)
document = db.person_info.find_one({"person_id": person_id})
if document is None: record = await asyncio.to_thread(_db_check_exists_sync, person_id)
logger.info(f"用户 {platform}:{user_id} (person_id: {person_id}) 不存在,将创建新记录。")
if record is None:
logger.info(f"用户 {platform}:{user_id} (person_id: {person_id}) 不存在,将创建新记录 (Peewee)。")
initial_data = { initial_data = {
"platform": platform, "platform": platform,
"user_id": user_id, "user_id": str(user_id),
"nickname": nickname, "nickname": nickname,
"konw_time": int(datetime.datetime.now().timestamp()), # 添加初次认识时间 "know_time": int(datetime.datetime.now().timestamp()), # 修正拼写konw_time -> know_time
# 注意:这里没有添加 user_cardname 和 user_avatar因为它们不在 person_info_default 中
# 如果需要存储它们,需要先在 person_info_default 中定义
} }
# 过滤掉值为 None 的初始数据 model_fields = PersonInfo._meta.fields.keys()
initial_data = {k: v for k, v in initial_data.items() if v is not None} filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields}
# 注意create_person_info 是静态方法 await self.create_person_info(person_id, data=filtered_initial_data)
await PersonInfoManager.create_person_info(person_id, data=initial_data) logger.debug(f"已为 {person_id} 创建新记录,初始数据 (filtered for model): {filtered_initial_data}")
# 创建后,可以考虑立即为其取名,但这可能会增加延迟
# await self.qv_person_name(person_id, nickname, user_cardname, user_avatar)
logger.debug(f"已为 {person_id} 创建新记录,初始数据: {initial_data}")
return person_id return person_id
@@ -526,34 +586,54 @@ class PersonInfoManager:
logger.debug("get_person_info_by_name 获取失败person_name 不能为空") logger.debug("get_person_info_by_name 获取失败person_name 不能为空")
return None return None
# 优先从内存缓存查找 person_id
found_person_id = None found_person_id = None
for pid, name in self.person_name_list.items(): for pid, name_in_cache in self.person_name_list.items():
if name == person_name: if name_in_cache == person_name:
found_person_id = pid found_person_id = pid
break # 找到第一个匹配就停止 break
if not found_person_id: if not found_person_id:
# 如果内存没有,尝试数据库查询(可能内存未及时更新或启动时未加载)
document = db.person_info.find_one({"person_name": person_name})
if document:
found_person_id = document.get("person_id")
else:
logger.debug(f"数据库中也未找到名为 '{person_name}' 的用户")
return None # 数据库也找不到
# 根据找到的 person_id 获取所需信息 def _db_find_by_name_sync(p_name_to_find: str):
if found_person_id: return PersonInfo.get_or_none(PersonInfo.person_name == p_name_to_find)
required_fields = ["person_id", "platform", "user_id", "nickname", "user_cardname", "user_avatar"]
person_data = await self.get_values(found_person_id, required_fields) record = await asyncio.to_thread(_db_find_by_name_sync, person_name)
if person_data: # 确保 get_values 成功返回 if record:
return person_data found_person_id = record.person_id
if (
found_person_id not in self.person_name_list
or self.person_name_list[found_person_id] != person_name
):
self.person_name_list[found_person_id] = person_name
else: else:
logger.warning(f"找到了 person_id '{found_person_id}' 但获取详细信息失败") logger.debug(f"数据库中也未找到名为 '{person_name}' 的用户 (Peewee)")
return None return None
if found_person_id:
required_fields = [
"person_id",
"platform",
"user_id",
"nickname",
"user_cardname",
"user_avatar",
"person_name",
"name_reason",
]
valid_fields_to_get = [
f for f in required_fields if f in PersonInfo._meta.fields or f in person_info_default
]
person_data = await self.get_values(found_person_id, valid_fields_to_get)
if person_data:
final_result = {key: person_data.get(key) for key in required_fields}
return final_result
else: else:
# 这理论上不应该发生,因为上面已经处理了找不到的情况 logger.warning(f"找到了 person_id '{found_person_id}' 但 get_values 返回空 (Peewee)")
logger.error(f"逻辑错误:未能为 '{person_name}' 确定 person_id") return None
logger.error(f"逻辑错误:未能为 '{person_name}' 确定 person_id (Peewee)")
return None return None

View File

@@ -77,7 +77,7 @@ class RelationshipManager:
@staticmethod @staticmethod
async def is_known_some_one(platform, user_id): async def is_known_some_one(platform, user_id):
"""判断是否认识某人""" """判断是否认识某人"""
is_known = person_info_manager.is_person_known(platform, user_id) is_known = await person_info_manager.is_person_known(platform, user_id)
return is_known return is_known
@staticmethod @staticmethod

View File

@@ -451,10 +451,10 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
# 处理 回复<aaa:bbb> # 处理 回复<aaa:bbb>
reply_pattern = r"回复<([^:<>]+):([^:<>]+)>" reply_pattern = r"回复<([^:<>]+):([^:<>]+)>"
def reply_replacer(match): def reply_replacer(match, platform=platform):
# aaa = match.group(1) # aaa = match.group(1)
bbb = match.group(2) bbb = match.group(2)
anon_reply = get_anon_name(platform, bbb) anon_reply = get_anon_name(platform, bbb) # noqa
return f"回复 {anon_reply}" return f"回复 {anon_reply}"
content = re.sub(reply_pattern, reply_replacer, content, count=1) content = re.sub(reply_pattern, reply_replacer, content, count=1)
@@ -462,10 +462,10 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
# 处理 @<aaa:bbb> # 处理 @<aaa:bbb>
at_pattern = r"@<([^:<>]+):([^:<>]+)>" at_pattern = r"@<([^:<>]+):([^:<>]+)>"
def at_replacer(match): def at_replacer(match, platform=platform):
# aaa = match.group(1) # aaa = match.group(1)
bbb = match.group(2) bbb = match.group(2)
anon_at = get_anon_name(platform, bbb) anon_at = get_anon_name(platform, bbb) # noqa
return f"@{anon_at}" return f"@{anon_at}"
content = re.sub(at_pattern, at_replacer, content) content = re.sub(at_pattern, at_replacer, content)

View File

@@ -1,9 +1,10 @@
from src.config.config import global_config from src.config.config import global_config
from src.chat.message_receive.message import MessageRecv, MessageSending, Message from src.chat.message_receive.message import MessageRecv, MessageSending, Message
from src.common.database import db from src.common.database.database_model import Messages, ThinkingLog
import time import time
import traceback import traceback
from typing import List from typing import List
import json
class InfoCatcher: class InfoCatcher:
@@ -59,8 +60,6 @@ class InfoCatcher:
def catch_after_observe(self, obs_duration: float): # 这里可以有更多信息 def catch_after_observe(self, obs_duration: float): # 这里可以有更多信息
self.timing_results["sub_heartflow_observe_time"] = obs_duration self.timing_results["sub_heartflow_observe_time"] = obs_duration
# def catch_shf
def catch_afer_shf_step(self, step_duration: float, past_mind: str, current_mind: str): def catch_afer_shf_step(self, step_duration: float, past_mind: str, current_mind: str):
self.timing_results["sub_heartflow_step_time"] = step_duration self.timing_results["sub_heartflow_step_time"] = step_duration
if len(past_mind) > 1: if len(past_mind) > 1:
@@ -71,25 +70,10 @@ class InfoCatcher:
self.heartflow_data["sub_heartflow_now"] = current_mind self.heartflow_data["sub_heartflow_now"] = current_mind
def catch_after_llm_generated(self, prompt: str, response: str, reasoning_content: str = "", model_name: str = ""): def catch_after_llm_generated(self, prompt: str, response: str, reasoning_content: str = "", model_name: str = ""):
# if self.response_mode == "heart_flow": # 条件判断不需要了喵~
# self.heartflow_data["prompt"] = prompt
# self.heartflow_data["response"] = response
# self.heartflow_data["model"] = model_name
# elif self.response_mode == "reasoning": # 条件判断不需要了喵~
# self.reasoning_data["thinking_log"] = reasoning_content
# self.reasoning_data["prompt"] = prompt
# self.reasoning_data["response"] = response
# self.reasoning_data["model"] = model_name
# 直接记录信息喵~
self.reasoning_data["thinking_log"] = reasoning_content self.reasoning_data["thinking_log"] = reasoning_content
self.reasoning_data["prompt"] = prompt self.reasoning_data["prompt"] = prompt
self.reasoning_data["response"] = response self.reasoning_data["response"] = response
self.reasoning_data["model"] = model_name self.reasoning_data["model"] = model_name
# 如果 heartflow 数据也需要通用字段,可以取消下面的注释喵~
# self.heartflow_data["prompt"] = prompt
# self.heartflow_data["response"] = response
# self.heartflow_data["model"] = model_name
self.response_text = response self.response_text = response
@@ -101,6 +85,7 @@ class InfoCatcher:
): ):
self.timing_results["make_response_time"] = response_duration self.timing_results["make_response_time"] = response_duration
self.response_time = time.time() self.response_time = time.time()
self.response_messages = []
for msg in response_message: for msg in response_message:
self.response_messages.append(msg) self.response_messages.append(msg)
@@ -111,107 +96,112 @@ class InfoCatcher:
@staticmethod @staticmethod
def get_message_from_db_between_msgs(message_start: Message, message_end: Message): def get_message_from_db_between_msgs(message_start: Message, message_end: Message):
try: try:
# 从数据库中获取消息的时间戳
time_start = message_start.message_info.time time_start = message_start.message_info.time
time_end = message_end.message_info.time time_end = message_end.message_info.time
chat_id = message_start.chat_stream.stream_id chat_id = message_start.chat_stream.stream_id
print(f"查询参数: time_start={time_start}, time_end={time_end}, chat_id={chat_id}") print(f"查询参数: time_start={time_start}, time_end={time_end}, chat_id={chat_id}")
# 查询数据库,获取 chat_id 相同且时间在 start 和 end 之间的数据 messages_between_query = (
messages_between = db.messages.find( Messages.select()
{"chat_id": chat_id, "time": {"$gt": time_start, "$lt": time_end}} .where((Messages.chat_id == chat_id) & (Messages.time > time_start) & (Messages.time < time_end))
).sort("time", -1) .order_by(Messages.time.desc())
)
result = list(messages_between) result = list(messages_between_query)
print(f"查询结果数量: {len(result)}") print(f"查询结果数量: {len(result)}")
if result: if result:
print(f"第一条消息时间: {result[0]['time']}") print(f"第一条消息时间: {result[0].time}")
print(f"最后一条消息时间: {result[-1]['time']}") print(f"最后一条消息时间: {result[-1].time}")
return result return result
except Exception as e: except Exception as e:
print(f"获取消息时出错: {str(e)}") print(f"获取消息时出错: {str(e)}")
print(traceback.format_exc())
return [] return []
def get_message_from_db_before_msg(self, message: MessageRecv): def get_message_from_db_before_msg(self, message: MessageRecv):
# 从数据库中获取消息 message_id_val = message.message_info.message_id
message_id = message.message_info.message_id chat_id_val = message.chat_stream.stream_id
chat_id = message.chat_stream.stream_id
# 查询数据库,获取 chat_id 相同且 message_id 小于当前消息的 30 条数据 messages_before_query = (
messages_before = ( Messages.select()
db.messages.find({"chat_id": chat_id, "message_id": {"$lt": message_id}}) .where((Messages.chat_id == chat_id_val) & (Messages.message_id < message_id_val))
.sort("time", -1) .order_by(Messages.time.desc())
.limit(global_config.chat.observation_context_size * 3) .limit(global_config.chat.observation_context_size * 3)
) # 获取更多历史信息 )
return list(messages_before) return list(messages_before_query)
def message_list_to_dict(self, message_list): def message_list_to_dict(self, message_list):
# 存储简化的聊天记录
result = [] result = []
for message in message_list: for msg_item in message_list:
if not isinstance(message, dict): processed_msg_item = msg_item
message = self.message_to_dict(message) if not isinstance(msg_item, dict):
# print(message) processed_msg_item = self.message_to_dict(msg_item)
if not processed_msg_item:
continue
lite_message = { lite_message = {
"time": message["time"], "time": processed_msg_item.get("time"),
"user_nickname": message["user_info"]["user_nickname"], "user_nickname": processed_msg_item.get("user_nickname"),
"processed_plain_text": message["processed_plain_text"], "processed_plain_text": processed_msg_item.get("processed_plain_text"),
} }
result.append(lite_message) result.append(lite_message)
return result return result
@staticmethod @staticmethod
def message_to_dict(message): def message_to_dict(msg_obj):
if not message: if not msg_obj:
return None return None
if isinstance(message, dict): if isinstance(msg_obj, dict):
return message return msg_obj
if isinstance(msg_obj, Messages):
return { return {
# "message_id": message.message_info.message_id, "time": msg_obj.time,
"time": message.message_info.time, "user_id": msg_obj.user_id,
"user_id": message.message_info.user_info.user_id, "user_nickname": msg_obj.user_nickname,
"user_nickname": message.message_info.user_info.user_nickname, "processed_plain_text": msg_obj.processed_plain_text,
"processed_plain_text": message.processed_plain_text,
# "detailed_plain_text": message.detailed_plain_text
} }
if hasattr(msg_obj, "message_info") and hasattr(msg_obj.message_info, "user_info"):
return {
"time": msg_obj.message_info.time,
"user_id": msg_obj.message_info.user_info.user_id,
"user_nickname": msg_obj.message_info.user_info.user_nickname,
"processed_plain_text": msg_obj.processed_plain_text,
}
print(f"Warning: message_to_dict received an unhandled type: {type(msg_obj)}")
return {}
def done_catch(self): def done_catch(self):
"""将收集到的信息存储到数据库的 thinking_log 集合中喵~""" """将收集到的信息存储到数据库的 thinking_log 中喵~"""
try: try:
# 将消息对象转换为可序列化的字典喵~ trigger_info_dict = self.message_to_dict(self.trigger_response_message)
response_info_dict = {
thinking_log_data = {
"chat_id": self.chat_id,
"trigger_text": self.trigger_response_text,
"response_text": self.response_text,
"trigger_info": {
"time": self.trigger_response_time,
"message": self.message_to_dict(self.trigger_response_message),
},
"response_info": {
"time": self.response_time, "time": self.response_time,
"message": self.response_messages, "message": self.response_messages,
},
"timing_results": self.timing_results,
"chat_history": self.message_list_to_dict(self.chat_history),
"chat_history_in_thinking": self.message_list_to_dict(self.chat_history_in_thinking),
"chat_history_after_response": self.message_list_to_dict(self.chat_history_after_response),
"heartflow_data": self.heartflow_data,
"reasoning_data": self.reasoning_data,
} }
chat_history_list = self.message_list_to_dict(self.chat_history)
chat_history_in_thinking_list = self.message_list_to_dict(self.chat_history_in_thinking)
chat_history_after_response_list = self.message_list_to_dict(self.chat_history_after_response)
# 根据不同的响应模式添加相应的数据喵~ # 现在直接都加上去好了喵~ log_entry = ThinkingLog(
# if self.response_mode == "heart_flow": chat_id=self.chat_id,
# thinking_log_data["mode_specific_data"] = self.heartflow_data trigger_text=self.trigger_response_text,
# elif self.response_mode == "reasoning": response_text=self.response_text,
# thinking_log_data["mode_specific_data"] = self.reasoning_data trigger_info_json=json.dumps(trigger_info_dict) if trigger_info_dict else None,
response_info_json=json.dumps(response_info_dict),
# 将数据插入到 thinking_log 集合中喵~ timing_results_json=json.dumps(self.timing_results),
db.thinking_log.insert_one(thinking_log_data) chat_history_json=json.dumps(chat_history_list),
chat_history_in_thinking_json=json.dumps(chat_history_in_thinking_list),
chat_history_after_response_json=json.dumps(chat_history_after_response_list),
heartflow_data_json=json.dumps(self.heartflow_data),
reasoning_data_json=json.dumps(self.reasoning_data),
)
log_entry.save()
return True return True
except Exception as e: except Exception as e:

View File

@@ -2,10 +2,12 @@ from collections import defaultdict
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Any, Dict, Tuple, List from typing import Any, Dict, Tuple, List
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
from src.manager.async_task_manager import AsyncTask from src.manager.async_task_manager import AsyncTask
from ...common.database import db from ...common.database.database import db # This db is the Peewee database instance
from ...common.database.database_model import OnlineTime, LLMUsage, Messages # Import the Peewee model
from src.manager.local_store_manager import local_storage from src.manager.local_store_manager import local_storage
logger = get_module_logger("maibot_statistic") logger = get_module_logger("maibot_statistic")
@@ -39,7 +41,7 @@ class OnlineTimeRecordTask(AsyncTask):
def __init__(self): def __init__(self):
super().__init__(task_name="Online Time Record Task", run_interval=60) super().__init__(task_name="Online Time Record Task", run_interval=60)
self.record_id: str | None = None self.record_id: int | None = None # Changed to int for Peewee's default ID
"""记录ID""" """记录ID"""
self._init_database() # 初始化数据库 self._init_database() # 初始化数据库
@@ -47,49 +49,46 @@ class OnlineTimeRecordTask(AsyncTask):
@staticmethod @staticmethod
def _init_database(): def _init_database():
"""初始化数据库""" """初始化数据库"""
if "online_time" not in db.list_collection_names(): with db.atomic(): # Use atomic operations for schema changes
# 初始化数据库(在线时长) OnlineTime.create_table(safe=True) # Creates table if it doesn't exist, Peewee handles indexes from model
db.create_collection("online_time")
# 创建索引
if ("end_timestamp", 1) not in db.online_time.list_indexes():
db.online_time.create_index([("end_timestamp", 1)])
async def run(self): async def run(self):
try: try:
current_time = datetime.now()
extended_end_time = current_time + timedelta(minutes=1)
if self.record_id: if self.record_id:
# 如果有记录,则更新结束时间 # 如果有记录,则更新结束时间
db.online_time.update_one( query = OnlineTime.update(end_timestamp=extended_end_time).where(OnlineTime.id == self.record_id)
{"_id": self.record_id}, updated_rows = query.execute()
{ if updated_rows == 0:
"$set": { # Record might have been deleted or ID is stale, try to find/create
"end_timestamp": datetime.now() + timedelta(minutes=1), self.record_id = None # Reset record_id to trigger find/create logic below
}
}, if not self.record_id: # Check again if record_id was reset or initially None
)
else:
# 如果没有记录,检查一分钟以内是否已有记录 # 如果没有记录,检查一分钟以内是否已有记录
current_time = datetime.now() # Look for a record whose end_timestamp is recent enough to be considered ongoing
if recent_record := db.online_time.find_one( recent_record = (
{"end_timestamp": {"$gte": current_time - timedelta(minutes=1)}} OnlineTime.select()
): .where(OnlineTime.end_timestamp >= (current_time - timedelta(minutes=1)))
# 如果有记录,则更新结束时间 .order_by(OnlineTime.end_timestamp.desc())
self.record_id = recent_record["_id"] .first()
db.online_time.update_one(
{"_id": self.record_id},
{
"$set": {
"end_timestamp": current_time + timedelta(minutes=1),
}
},
) )
if recent_record:
# 如果有记录,则更新结束时间
self.record_id = recent_record.id
recent_record.end_timestamp = extended_end_time
recent_record.save()
else: else:
# 若没有记录,则插入新的在线时间记录 # 若没有记录,则插入新的在线时间记录
self.record_id = db.online_time.insert_one( new_record = OnlineTime.create(
{ timestamp=current_time.timestamp(), # 添加此行
"start_timestamp": current_time, start_timestamp=current_time,
"end_timestamp": current_time + timedelta(minutes=1), end_timestamp=extended_end_time,
} duration=5, # 初始时长为5分钟
).inserted_id )
self.record_id = new_record.id
except Exception as e: except Exception as e:
logger.error(f"在线时间记录失败,错误信息:{e}") logger.error(f"在线时间记录失败,错误信息:{e}")
@@ -201,35 +200,28 @@ class StatisticOutputTask(AsyncTask):
:param collect_period: 统计时间段 :param collect_period: 统计时间段
""" """
if len(collect_period) <= 0: if not collect_period:
return {} return {}
else:
# 排序-按照时间段开始时间降序排列(最晚的时间段在前) # 排序-按照时间段开始时间降序排列(最晚的时间段在前)
collect_period.sort(key=lambda x: x[1], reverse=True) collect_period.sort(key=lambda x: x[1], reverse=True)
stats = { stats = {
period_key: { period_key: {
# 总LLM请求数
TOTAL_REQ_CNT: 0, TOTAL_REQ_CNT: 0,
# 请求次数统计
REQ_CNT_BY_TYPE: defaultdict(int), REQ_CNT_BY_TYPE: defaultdict(int),
REQ_CNT_BY_USER: defaultdict(int), REQ_CNT_BY_USER: defaultdict(int),
REQ_CNT_BY_MODEL: defaultdict(int), REQ_CNT_BY_MODEL: defaultdict(int),
# 输入Token数
IN_TOK_BY_TYPE: defaultdict(int), IN_TOK_BY_TYPE: defaultdict(int),
IN_TOK_BY_USER: defaultdict(int), IN_TOK_BY_USER: defaultdict(int),
IN_TOK_BY_MODEL: defaultdict(int), IN_TOK_BY_MODEL: defaultdict(int),
# 输出Token数
OUT_TOK_BY_TYPE: defaultdict(int), OUT_TOK_BY_TYPE: defaultdict(int),
OUT_TOK_BY_USER: defaultdict(int), OUT_TOK_BY_USER: defaultdict(int),
OUT_TOK_BY_MODEL: defaultdict(int), OUT_TOK_BY_MODEL: defaultdict(int),
# 总Token数
TOTAL_TOK_BY_TYPE: defaultdict(int), TOTAL_TOK_BY_TYPE: defaultdict(int),
TOTAL_TOK_BY_USER: defaultdict(int), TOTAL_TOK_BY_USER: defaultdict(int),
TOTAL_TOK_BY_MODEL: defaultdict(int), TOTAL_TOK_BY_MODEL: defaultdict(int),
# 总开销
TOTAL_COST: 0.0, TOTAL_COST: 0.0,
# 请求开销统计
COST_BY_TYPE: defaultdict(float), COST_BY_TYPE: defaultdict(float),
COST_BY_USER: defaultdict(float), COST_BY_USER: defaultdict(float),
COST_BY_MODEL: defaultdict(float), COST_BY_MODEL: defaultdict(float),
@@ -238,26 +230,26 @@ class StatisticOutputTask(AsyncTask):
} }
# 以最早的时间戳为起始时间获取记录 # 以最早的时间戳为起始时间获取记录
for record in db.llm_usage.find({"timestamp": {"$gte": collect_period[-1][1]}}): # Assuming LLMUsage.timestamp is a DateTimeField
record_timestamp = record.get("timestamp") query_start_time = collect_period[-1][1]
for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time):
record_timestamp = record.timestamp # This is already a datetime object
for idx, (_, period_start) in enumerate(collect_period): for idx, (_, period_start) in enumerate(collect_period):
if record_timestamp >= period_start: if record_timestamp >= period_start:
# 如果记录时间在当前时间段内,则它一定在更早的时间段内
# 因此,我们可以直接跳过更早的时间段的判断,直接更新当前以及更早时间段的统计数据
for period_key, _ in collect_period[idx:]: for period_key, _ in collect_period[idx:]:
stats[period_key][TOTAL_REQ_CNT] += 1 stats[period_key][TOTAL_REQ_CNT] += 1
request_type = record.get("request_type", "unknown") # 请求类型 request_type = record.request_type or "unknown"
user_id = str(record.get("user_id", "unknown")) # 用户ID user_id = record.user_id or "unknown" # user_id is TextField, already string
model_name = record.get("model_name", "unknown") # 模型名称 model_name = record.model_name or "unknown"
stats[period_key][REQ_CNT_BY_TYPE][request_type] += 1 stats[period_key][REQ_CNT_BY_TYPE][request_type] += 1
stats[period_key][REQ_CNT_BY_USER][user_id] += 1 stats[period_key][REQ_CNT_BY_USER][user_id] += 1
stats[period_key][REQ_CNT_BY_MODEL][model_name] += 1 stats[period_key][REQ_CNT_BY_MODEL][model_name] += 1
prompt_tokens = record.get("prompt_tokens", 0) # 输入Token数 prompt_tokens = record.prompt_tokens or 0
completion_tokens = record.get("completion_tokens", 0) # 输出Token数 completion_tokens = record.completion_tokens or 0
total_tokens = prompt_tokens + completion_tokens # Token总数 = 输入Token数 + 输出Token数 total_tokens = prompt_tokens + completion_tokens
stats[period_key][IN_TOK_BY_TYPE][request_type] += prompt_tokens stats[period_key][IN_TOK_BY_TYPE][request_type] += prompt_tokens
stats[period_key][IN_TOK_BY_USER][user_id] += prompt_tokens stats[period_key][IN_TOK_BY_USER][user_id] += prompt_tokens
@@ -271,13 +263,12 @@ class StatisticOutputTask(AsyncTask):
stats[period_key][TOTAL_TOK_BY_USER][user_id] += total_tokens stats[period_key][TOTAL_TOK_BY_USER][user_id] += total_tokens
stats[period_key][TOTAL_TOK_BY_MODEL][model_name] += total_tokens stats[period_key][TOTAL_TOK_BY_MODEL][model_name] += total_tokens
cost = record.get("cost", 0.0) cost = record.cost or 0.0
stats[period_key][TOTAL_COST] += cost stats[period_key][TOTAL_COST] += cost
stats[period_key][COST_BY_TYPE][request_type] += cost stats[period_key][COST_BY_TYPE][request_type] += cost
stats[period_key][COST_BY_USER][user_id] += cost stats[period_key][COST_BY_USER][user_id] += cost
stats[period_key][COST_BY_MODEL][model_name] += cost stats[period_key][COST_BY_MODEL][model_name] += cost
break # 取消更早时间段的判断 break
return stats return stats
@staticmethod @staticmethod
@@ -287,39 +278,38 @@ class StatisticOutputTask(AsyncTask):
:param collect_period: 统计时间段 :param collect_period: 统计时间段
""" """
if len(collect_period) <= 0: if not collect_period:
return {} return {}
else:
# 排序-按照时间段开始时间降序排列(最晚的时间段在前)
collect_period.sort(key=lambda x: x[1], reverse=True) collect_period.sort(key=lambda x: x[1], reverse=True)
stats = { stats = {
period_key: { period_key: {
# 在线时间统计
ONLINE_TIME: 0.0, ONLINE_TIME: 0.0,
} }
for period_key, _ in collect_period for period_key, _ in collect_period
} }
# 统计在线时间 query_start_time = collect_period[-1][1]
for record in db.online_time.find({"end_timestamp": {"$gte": collect_period[-1][1]}}): # Assuming OnlineTime.end_timestamp is a DateTimeField
end_timestamp: datetime = record.get("end_timestamp") for record in OnlineTime.select().where(OnlineTime.end_timestamp >= query_start_time):
for idx, (_, period_start) in enumerate(collect_period): # record.end_timestamp and record.start_timestamp are datetime objects
if end_timestamp >= period_start: record_end_timestamp = record.end_timestamp
# 由于end_timestamp会超前标记时间所以我们需要判断是否晚于当前时间如果是则使用当前时间作为结束时间 record_start_timestamp = record.start_timestamp
end_timestamp = min(end_timestamp, now)
# 如果记录时间在当前时间段内,则它一定在更早的时间段内
# 因此,我们可以直接跳过更早的时间段的判断,直接更新当前以及更早时间段的统计数据
for period_key, _period_start in collect_period[idx:]:
start_timestamp: datetime = record.get("start_timestamp")
if start_timestamp < _period_start:
# 如果开始时间在查询边界之前,则使用开始时间
stats[period_key][ONLINE_TIME] += (end_timestamp - _period_start).total_seconds()
else:
# 否则,使用开始时间
stats[period_key][ONLINE_TIME] += (end_timestamp - start_timestamp).total_seconds()
break # 取消更早时间段的判断
for idx, (_, period_boundary_start) in enumerate(collect_period):
if record_end_timestamp >= period_boundary_start:
# Calculate effective end time for this record in relation to 'now'
effective_end_time = min(record_end_timestamp, now)
for period_key, current_period_start_time in collect_period[idx:]:
# Determine the portion of the record that falls within this specific statistical period
overlap_start = max(record_start_timestamp, current_period_start_time)
overlap_end = effective_end_time # Already capped by 'now' and record's own end
if overlap_end > overlap_start:
stats[period_key][ONLINE_TIME] += (overlap_end - overlap_start).total_seconds()
break
return stats return stats
def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]: def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
@@ -328,55 +318,57 @@ class StatisticOutputTask(AsyncTask):
:param collect_period: 统计时间段 :param collect_period: 统计时间段
""" """
if len(collect_period) <= 0: if not collect_period:
return {} return {}
else:
# 排序-按照时间段开始时间降序排列(最晚的时间段在前)
collect_period.sort(key=lambda x: x[1], reverse=True) collect_period.sort(key=lambda x: x[1], reverse=True)
stats = { stats = {
period_key: { period_key: {
# 消息统计
TOTAL_MSG_CNT: 0, TOTAL_MSG_CNT: 0,
MSG_CNT_BY_CHAT: defaultdict(int), MSG_CNT_BY_CHAT: defaultdict(int),
} }
for period_key, _ in collect_period for period_key, _ in collect_period
} }
# 统计消息量 query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp)
for message in db.messages.find({"time": {"$gte": collect_period[-1][1].timestamp()}}): for message in Messages.select().where(Messages.time >= query_start_timestamp):
chat_info = message.get("chat_info", None) # 聊天信息 message_time_ts = message.time # This is a float timestamp
user_info = message.get("user_info", None) # 用户信息(消息发送人)
message_time = message.get("time", 0) # 消息时间
group_info = chat_info.get("group_info") if chat_info else None # 尝试获取群聊信息 chat_id = None
if group_info is not None: chat_name = None
# 若有群聊信息
chat_id = f"g{group_info.get('group_id')}" # Logic based on Peewee model structure, aiming to replicate original intent
chat_name = group_info.get("group_name", f"{group_info.get('group_id')}") if message.chat_info_group_id:
elif user_info: chat_id = f"g{message.chat_info_group_id}"
# 若没有群聊信息,则尝试获取用户信息 chat_name = message.chat_info_group_name or f"{message.chat_info_group_id}"
chat_id = f"u{user_info['user_id']}" elif message.user_id: # Fallback to sender's info for chat_id if not a group_info based chat
chat_name = user_info["user_nickname"] # This uses the message SENDER's ID as per original logic's fallback
chat_id = f"u{message.user_id}" # SENDER's user_id
chat_name = message.user_nickname # SENDER's nickname
else: else:
continue # 如果没有群组信息也没有用户信息,则跳过 # If neither group_id nor sender_id is available for chat identification
logger.warning(
f"Message (PK: {message.id if hasattr(message, 'id') else 'N/A'}) lacks group_id and user_id for chat stats."
)
continue
if not chat_id: # Should not happen if above logic is correct
continue
# Update name_mapping
if chat_id in self.name_mapping: if chat_id in self.name_mapping:
if chat_name != self.name_mapping[chat_id][0] and message_time > self.name_mapping[chat_id][1]: if chat_name != self.name_mapping[chat_id][0] and message_time_ts > self.name_mapping[chat_id][1]:
# 如果用户名称不同,且新消息时间晚于之前记录的时间,则更新用户名称 self.name_mapping[chat_id] = (chat_name, message_time_ts)
self.name_mapping[chat_id] = (chat_name, message_time)
else: else:
self.name_mapping[chat_id] = (chat_name, message_time) self.name_mapping[chat_id] = (chat_name, message_time_ts)
for idx, (_, period_start) in enumerate(collect_period): for idx, (_, period_start_dt) in enumerate(collect_period):
if message_time >= period_start.timestamp(): if message_time_ts >= period_start_dt.timestamp():
# 如果记录时间在当前时间段内,则它一定在更早的时间段内
# 因此,我们可以直接跳过更早的时间段的判断,直接更新当前以及更早时间段的统计数据
for period_key, _ in collect_period[idx:]: for period_key, _ in collect_period[idx:]:
stats[period_key][TOTAL_MSG_CNT] += 1 stats[period_key][TOTAL_MSG_CNT] += 1
stats[period_key][MSG_CNT_BY_CHAT][chat_id] += 1 stats[period_key][MSG_CNT_BY_CHAT][chat_id] += 1
break break
return stats return stats
def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]: def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]:

View File

@@ -13,7 +13,7 @@ from src.manager.mood_manager import mood_manager
from ..message_receive.message import MessageRecv from ..message_receive.message import MessageRecv
from ..models.utils_model import LLMRequest from ..models.utils_model import LLMRequest
from .typo_generator import ChineseTypoGenerator from .typo_generator import ChineseTypoGenerator
from ...common.database import db from ...common.database.database import db
from ...config.config import global_config from ...config.config import global_config
logger = get_module_logger("chat_utils") logger = get_module_logger("chat_utils")

View File

@@ -8,7 +8,8 @@ import io
import numpy as np import numpy as np
from ...common.database import db from ...common.database.database import db
from ...common.database.database_model import Images, ImageDescriptions
from ...config.config import global_config from ...config.config import global_config
from ..models.utils_model import LLMRequest from ..models.utils_model import LLMRequest
@@ -32,40 +33,23 @@ class ImageManager:
def __init__(self): def __init__(self):
if not self._initialized: if not self._initialized:
self._ensure_image_collection()
self._ensure_description_collection()
self._ensure_image_dir() self._ensure_image_dir()
self._initialized = True self._initialized = True
self._llm = LLMRequest(model=global_config.model.vlm, temperature=0.4, max_tokens=300, request_type="image") self._llm = LLMRequest(model=global_config.model.vlm, temperature=0.4, max_tokens=300, request_type="image")
try:
db.connect(reuse_if_open=True)
db.create_tables([Images, ImageDescriptions], safe=True)
except Exception as e:
logger.error(f"数据库连接或表创建失败: {e}")
self._initialized = True
def _ensure_image_dir(self): def _ensure_image_dir(self):
"""确保图像存储目录存在""" """确保图像存储目录存在"""
os.makedirs(self.IMAGE_DIR, exist_ok=True) os.makedirs(self.IMAGE_DIR, exist_ok=True)
@staticmethod
def _ensure_image_collection():
"""确保images集合存在并创建索引"""
if "images" not in db.list_collection_names():
db.create_collection("images")
# 删除旧索引
db.images.drop_indexes()
# 创建新的复合索引
db.images.create_index([("hash", 1), ("type", 1)], unique=True)
db.images.create_index([("url", 1)])
db.images.create_index([("path", 1)])
@staticmethod
def _ensure_description_collection():
"""确保image_descriptions集合存在并创建索引"""
if "image_descriptions" not in db.list_collection_names():
db.create_collection("image_descriptions")
# 删除旧索引
db.image_descriptions.drop_indexes()
# 创建新的复合索引
db.image_descriptions.create_index([("hash", 1), ("type", 1)], unique=True)
@staticmethod @staticmethod
def _get_description_from_db(image_hash: str, description_type: str) -> Optional[str]: def _get_description_from_db(image_hash: str, description_type: str) -> Optional[str]:
"""从数据库获取图片描述 """从数据库获取图片描述
@@ -77,8 +61,14 @@ class ImageManager:
Returns: Returns:
Optional[str]: 描述文本如果不存在则返回None Optional[str]: 描述文本如果不存在则返回None
""" """
result = db.image_descriptions.find_one({"hash": image_hash, "type": description_type}) try:
return result["description"] if result else None record = ImageDescriptions.get_or_none(
(ImageDescriptions.image_description_hash == image_hash) & (ImageDescriptions.type == description_type)
)
return record.description if record else None
except Exception as e:
logger.error(f"从数据库获取描述失败 (Peewee): {str(e)}")
return None
@staticmethod @staticmethod
def _save_description_to_db(image_hash: str, description: str, description_type: str) -> None: def _save_description_to_db(image_hash: str, description: str, description_type: str) -> None:
@@ -90,20 +80,17 @@ class ImageManager:
description_type: 描述类型 ('emoji''image') description_type: 描述类型 ('emoji''image')
""" """
try: try:
db.image_descriptions.update_one( current_timestamp = time.time()
{"hash": image_hash, "type": description_type}, defaults = {"description": description, "timestamp": current_timestamp}
{ desc_obj, created = ImageDescriptions.get_or_create(
"$set": { hash=image_hash, type=description_type, defaults=defaults
"description": description,
"timestamp": int(time.time()),
"hash": image_hash, # 确保hash字段存在
"type": description_type, # 确保type字段存在
}
},
upsert=True,
) )
if not created: # 如果记录已存在,则更新
desc_obj.description = description
desc_obj.timestamp = current_timestamp
desc_obj.save()
except Exception as e: except Exception as e:
logger.error(f"保存描述到数据库失败: {str(e)}") logger.error(f"保存描述到数据库失败 (Peewee): {str(e)}")
async def get_emoji_description(self, image_base64: str) -> str: async def get_emoji_description(self, image_base64: str) -> str:
"""获取表情包描述,带查重和保存功能""" """获取表情包描述,带查重和保存功能"""
@@ -116,18 +103,25 @@ class ImageManager:
# 查询缓存的描述 # 查询缓存的描述
cached_description = self._get_description_from_db(image_hash, "emoji") cached_description = self._get_description_from_db(image_hash, "emoji")
if cached_description: if cached_description:
# logger.debug(f"缓存表情包描述: {cached_description}")
return f"[表情包,含义看起来是:{cached_description}]" return f"[表情包,含义看起来是:{cached_description}]"
# 调用AI获取描述 # 调用AI获取描述
if image_format == "gif" or image_format == "GIF": if image_format == "gif" or image_format == "GIF":
image_base64 = self.transform_gif(image_base64) image_base64_processed = self.transform_gif(image_base64)
if image_base64_processed is None:
logger.warning("GIF转换失败无法获取描述")
return "[表情包(GIF处理失败)]"
prompt = "这是一个动态图表情包每一张图代表了动态图的某一帧黑色背景代表透明使用1-2个词描述一下表情包表达的情感和内容简短一些" prompt = "这是一个动态图表情包每一张图代表了动态图的某一帧黑色背景代表透明使用1-2个词描述一下表情包表达的情感和内容简短一些"
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, "jpg") description, _ = await self._llm.generate_response_for_image(prompt, image_base64_processed, "jpg")
else: else:
prompt = "这是一个表情包,请用使用几个词描述一下表情包所表达的情感和内容,简短一些" prompt = "这是一个表情包,请用使用几个词描述一下表情包所表达的情感和内容,简短一些"
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format) description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
if description is None:
logger.warning("AI未能生成表情包描述")
return "[表情包(描述生成失败)]"
# 再次检查缓存,防止并发写入时重复生成
cached_description = self._get_description_from_db(image_hash, "emoji") cached_description = self._get_description_from_db(image_hash, "emoji")
if cached_description: if cached_description:
logger.warning(f"虽然生成了描述,但是找到缓存表情包描述: {cached_description}") logger.warning(f"虽然生成了描述,但是找到缓存表情包描述: {cached_description}")
@@ -136,31 +130,37 @@ class ImageManager:
# 根据配置决定是否保存图片 # 根据配置决定是否保存图片
if global_config.emoji.save_emoji: if global_config.emoji.save_emoji:
# 生成文件名和路径 # 生成文件名和路径
timestamp = int(time.time()) current_timestamp = time.time()
filename = f"{timestamp}_{image_hash[:8]}.{image_format}" filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}"
if not os.path.exists(os.path.join(self.IMAGE_DIR, "emoji")): emoji_dir = os.path.join(self.IMAGE_DIR, "emoji")
os.makedirs(os.path.join(self.IMAGE_DIR, "emoji")) os.makedirs(emoji_dir, exist_ok=True)
file_path = os.path.join(self.IMAGE_DIR, "emoji", filename) file_path = os.path.join(emoji_dir, filename)
try: try:
# 保存文件 # 保存文件
with open(file_path, "wb") as f: with open(file_path, "wb") as f:
f.write(image_bytes) f.write(image_bytes)
# 保存到数据库 # 保存到数据库 (Images表)
image_doc = { try:
"hash": image_hash, img_obj = Images.get((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
"path": file_path, img_obj.path = file_path
"type": "emoji", img_obj.description = description
"description": description, img_obj.timestamp = current_timestamp
"timestamp": timestamp, img_obj.save()
} except Images.DoesNotExist:
db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True) Images.create(
logger.trace(f"保存表情包: {file_path}") hash=image_hash,
path=file_path,
type="emoji",
description=description,
timestamp=current_timestamp,
)
logger.trace(f"保存表情包元数据: {file_path}")
except Exception as e: except Exception as e:
logger.error(f"保存表情包文件失败: {str(e)}") logger.error(f"保存表情包文件或元数据失败: {str(e)}")
# 保存描述到数据库 # 保存描述到数据库 (ImageDescriptions表)
self._save_description_to_db(image_hash, description, "emoji") self._save_description_to_db(image_hash, description, "emoji")
return f"[表情包:{description}]" return f"[表情包:{description}]"
@@ -188,6 +188,11 @@ class ImageManager:
) )
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format) description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
if description is None:
logger.warning("AI未能生成图片描述")
return "[图片(描述生成失败)]"
# 再次检查缓存
cached_description = self._get_description_from_db(image_hash, "image") cached_description = self._get_description_from_db(image_hash, "image")
if cached_description: if cached_description:
logger.warning(f"虽然生成了描述,但是找到缓存图片描述 {cached_description}") logger.warning(f"虽然生成了描述,但是找到缓存图片描述 {cached_description}")
@@ -195,38 +200,40 @@ class ImageManager:
logger.debug(f"描述是{description}") logger.debug(f"描述是{description}")
if description is None:
logger.warning("AI未能生成图片描述")
return "[图片]"
# 根据配置决定是否保存图片 # 根据配置决定是否保存图片
if global_config.emoji.save_pic: if global_config.emoji.save_pic:
# 生成文件名和路径 # 生成文件名和路径
timestamp = int(time.time()) current_timestamp = time.time()
filename = f"{timestamp}_{image_hash[:8]}.{image_format}" filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}"
if not os.path.exists(os.path.join(self.IMAGE_DIR, "image")): image_dir = os.path.join(self.IMAGE_DIR, "image")
os.makedirs(os.path.join(self.IMAGE_DIR, "image")) os.makedirs(image_dir, exist_ok=True)
file_path = os.path.join(self.IMAGE_DIR, "image", filename) file_path = os.path.join(image_dir, filename)
try: try:
# 保存文件 # 保存文件
with open(file_path, "wb") as f: with open(file_path, "wb") as f:
f.write(image_bytes) f.write(image_bytes)
# 保存到数据库 # 保存到数据库 (Images表)
image_doc = { try:
"hash": image_hash, img_obj = Images.get((Images.emoji_hash == image_hash) & (Images.type == "image"))
"path": file_path, img_obj.path = file_path
"type": "image", img_obj.description = description
"description": description, img_obj.timestamp = current_timestamp
"timestamp": timestamp, img_obj.save()
} except Images.DoesNotExist:
db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True) Images.create(
logger.trace(f"保存图片: {file_path}") hash=image_hash,
path=file_path,
type="image",
description=description,
timestamp=current_timestamp,
)
logger.trace(f"保存图片元数据: {file_path}")
except Exception as e: except Exception as e:
logger.error(f"保存图片文件失败: {str(e)}") logger.error(f"保存图片文件或元数据失败: {str(e)}")
# 保存描述到数据库 # 保存描述到数据库 (ImageDescriptions表)
self._save_description_to_db(image_hash, description, "image") self._save_description_to_db(image_hash, description, "image")
return f"[图片:{description}]" return f"[图片:{description}]"

View File

@@ -16,7 +16,7 @@ root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
sys.path.append(root_path) sys.path.append(root_path)
# 现在可以导入src模块 # 现在可以导入src模块
from src.common.database import db # noqa E402 from common.database.database import db # noqa E402
# 加载根目录下的env.edv文件 # 加载根目录下的env.edv文件

View File

@@ -1,5 +1,6 @@
import os import os
from pymongo import MongoClient from pymongo import MongoClient
from peewee import SqliteDatabase
from pymongo.database import Database from pymongo.database import Database
from rich.traceback import install from rich.traceback import install
@@ -57,4 +58,15 @@ class DBWrapper:
# 全局数据库访问点 # 全局数据库访问点
db: Database = DBWrapper() memory_db: Database = DBWrapper()
# 定义数据库文件路径
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
_DB_DIR = os.path.join(ROOT_PATH, "data")
_DB_FILE = os.path.join(_DB_DIR, "MaiBot.db")
# 确保数据库目录存在
os.makedirs(_DB_DIR, exist_ok=True)
# 全局 Peewee SQLite 数据库访问点
db = SqliteDatabase(_DB_FILE)

View File

@@ -0,0 +1,358 @@
from peewee import Model, DoubleField, IntegerField, BooleanField, TextField, FloatField, DateTimeField
from .database import db
import datetime
from ..logger_manager import get_logger
logger = get_logger("database_model")
# The shared database instance is imported from .database above (a
# SqliteDatabase pointing at data/MaiBot.db). To switch backends
# (e.g. PostgreSQL or MySQL), change the instance in database.py;
# every model below picks it up through BaseModel.Meta.database.

# All models inherit from BaseModel so the database binding is declared
# in exactly one place.
class BaseModel(Model):
    """Base model: binds every subclass to the shared Peewee database instance."""

    class Meta:
        # All concrete models inherit this binding; override in a subclass's
        # Meta only if a different database instance is ever needed.
        database = db

    pass
class ChatStreams(BaseModel):
    """
    Stores chat-stream records, mirroring the structure of the previous
    MongoDB documents (one row per stream).
    """

    # Unique stream identifier, e.g. "a544edeb1a9b73e3e1d77dff36e41264";
    # indexed for fast lookup.
    stream_id = TextField(unique=True, index=True)
    # Creation time as a float UNIX timestamp (sub-second precision),
    # hence DoubleField.
    create_time = DoubleField()

    # Fields flattened from the embedded group_info document
    # (platform / group_id / group_name).
    group_platform = TextField()
    group_id = TextField()
    group_name = TextField()

    # Last activity time as a float UNIX timestamp.
    last_active_time = DoubleField()

    # Top-level platform of the stream, e.g. "qq".
    platform = TextField()

    # Fields flattened from the embedded user_info document.
    user_platform = TextField()
    user_id = TextField()
    user_nickname = TextField()
    # Card name may be an empty string or absent, so allow NULL.
    user_cardname = TextField(null=True)

    class Meta:
        # Database binding is inherited from BaseModel.Meta.
        table_name = "chat_streams"
class LLMUsage(BaseModel):
    """Stores per-request LLM API usage logs (tokens, cost, status)."""

    model_name = TextField(index=True)  # indexed: queried per model
    user_id = TextField(index=True)  # indexed: queried per user
    request_type = TextField(index=True)  # indexed: queried per request type
    endpoint = TextField()
    prompt_tokens = IntegerField()
    completion_tokens = IntegerField()
    total_tokens = IntegerField()
    cost = DoubleField()
    status = TextField()
    # Request time; DateTimeField (not a float timestamp) and indexed for
    # time-range queries.
    timestamp = DateTimeField(index=True)

    class Meta:
        # Database binding is inherited from BaseModel.Meta.
        table_name = "llm_usage"
class Emoji(BaseModel):
    """Sticker/emoji registry: file location, description, and usage stats."""

    full_path = TextField(unique=True, index=True)  # full file path, including file name
    format = TextField()  # image format
    emoji_hash = TextField(index=True)  # hash of the emoji content
    description = TextField()  # generated description of the emoji
    query_count = IntegerField(default=0)  # times its description was queried
    is_registered = BooleanField(default=False)  # registered as a usable emoji
    is_banned = BooleanField(default=False)  # banned from registration
    # Emotion tags; conceptually list[str], serialized to/from text by the
    # application layer.
    emotion = TextField(null=True)
    record_time = FloatField()  # creation timestamp
    register_time = FloatField(null=True)  # registration timestamp
    usage_count = IntegerField(default=0)  # times this emoji was used
    last_used_time = FloatField(null=True)  # last usage timestamp

    class Meta:
        # Database binding is inherited from BaseModel.Meta.
        table_name = "emoji"
class Messages(BaseModel):
    """Stores chat messages with flattened chat/user metadata."""

    message_id = TextField(index=True)  # message id (text, not integer)
    time = DoubleField()  # message timestamp (float UNIX seconds)
    chat_id = TextField(index=True)  # matches ChatStreams.stream_id

    # Fields flattened from the embedded chat_info document.
    chat_info_stream_id = TextField()
    chat_info_platform = TextField()
    chat_info_user_platform = TextField()
    chat_info_user_id = TextField()
    chat_info_user_nickname = TextField()
    chat_info_user_cardname = TextField(null=True)
    chat_info_group_platform = TextField(null=True)  # absent for non-group chats
    chat_info_group_id = TextField(null=True)
    chat_info_group_name = TextField(null=True)
    chat_info_create_time = DoubleField()
    chat_info_last_active_time = DoubleField()

    # Flattened top-level user_info (the message sender).
    user_platform = TextField()
    user_id = TextField()
    user_nickname = TextField()
    user_cardname = TextField(null=True)

    processed_plain_text = TextField(null=True)  # processed plain-text content
    detailed_plain_text = TextField(null=True)  # detailed plain-text content
    memorized_times = IntegerField(default=0)  # times this message was memorized

    class Meta:
        # Database binding is inherited from BaseModel.Meta.
        table_name = "messages"
class Images(BaseModel):
    """Stores image metadata (both emojis and ordinary pictures)."""

    # NOTE(review): despite the name, this hash appears to be used for all
    # image types (see the "type" field), not only emojis — confirm before
    # renaming or relying on emoji-only semantics.
    emoji_hash = TextField(index=True)
    description = TextField(null=True)  # generated description of the image
    path = TextField(unique=True)  # file path of the stored image
    timestamp = FloatField()  # record timestamp
    type = TextField()  # image kind, e.g. "emoji"

    class Meta:
        # Database binding is inherited from BaseModel.Meta.
        table_name = "images"
class ImageDescriptions(BaseModel):
    """Caches generated image descriptions, keyed by image hash and type."""

    type = TextField()  # description kind, e.g. "emoji"
    image_description_hash = TextField(index=True)  # hash of the source image
    description = TextField()  # the cached description text
    timestamp = FloatField()  # when the description was stored

    class Meta:
        # Database binding is inherited from BaseModel.Meta.
        table_name = "image_descriptions"
class OnlineTime(BaseModel):
    """Records online-time intervals."""

    # NOTE(review): declared as TextField but defaulted with a datetime
    # callable, so the stored value is the stringified datetime — consider
    # DateTimeField (would require a manual schema migration).
    timestamp = TextField(default=datetime.datetime.now)
    duration = IntegerField()  # duration in minutes
    start_timestamp = DateTimeField(default=datetime.datetime.now)
    end_timestamp = DateTimeField(index=True)  # indexed for range queries

    class Meta:
        # Database binding is inherited from BaseModel.Meta.
        table_name = "online_time"
class PersonInfo(BaseModel):
    """Per-person profile and relationship data."""

    person_id = TextField(unique=True, index=True)  # unique person identifier
    person_name = TextField(null=True)  # assigned name (may be unset)
    name_reason = TextField(null=True)  # why that name was chosen
    platform = TextField()  # origin platform
    user_id = TextField(index=True)  # platform user id
    nickname = TextField()  # user nickname
    relationship_value = IntegerField(default=0)  # relationship score
    know_time = FloatField()  # timestamp of first acquaintance
    msg_interval = IntegerField()  # message interval
    # List of intervals serialized as a JSON string.
    msg_interval_list = TextField(null=True)

    class Meta:
        # Database binding is inherited from BaseModel.Meta.
        table_name = "person_info"
class Knowledges(BaseModel):
    """Knowledge-base entries with their embedding vectors."""

    content = TextField()  # the knowledge text itself
    # Embedding vector, stored as a JSON string of floats.
    embedding = TextField()
    # Other metadata (source, create_time, ...) could be added here later.

    class Meta:
        # Database binding is inherited from BaseModel.Meta.
        table_name = "knowledges"
class ThinkingLog(BaseModel):
    """Diagnostic log of one reply-generation cycle for a chat."""

    chat_id = TextField(index=True)  # matches ChatStreams.stream_id
    trigger_text = TextField(null=True)  # text that triggered the cycle
    response_text = TextField(null=True)  # generated response text
    # Complex dicts/lists are stored as JSON strings.
    trigger_info_json = TextField(null=True)
    response_info_json = TextField(null=True)
    timing_results_json = TextField(null=True)
    chat_history_json = TextField(null=True)
    chat_history_in_thinking_json = TextField(null=True)
    chat_history_after_response_json = TextField(null=True)
    heartflow_data_json = TextField(null=True)
    reasoning_data_json = TextField(null=True)
    # Creation time of the log entry itself.
    created_at = DateTimeField(default=datetime.datetime.now)

    class Meta:
        table_name = "thinking_logs"
class RecalledMessages(BaseModel):
    """Records message recalls (withdrawn messages)."""

    message_id = TextField(index=True)  # id of the recalled message
    time = DoubleField()  # timestamp of the recall action
    stream_id = TextField()  # matches ChatStreams.stream_id

    class Meta:
        table_name = "recalled_messages"
def create_tables():
    """Create all database tables declared by the models in this module."""
    all_models = [
        ChatStreams,
        LLMUsage,
        Emoji,
        Messages,
        Images,
        ImageDescriptions,
        OnlineTime,
        PersonInfo,
        Knowledges,
        ThinkingLog,
        RecalledMessages,
    ]
    # The database context manager opens and closes the connection around
    # the DDL statements.
    with db:
        db.create_tables(all_models)
def initialize_database():
    """
    Ensure the database schema matches the models defined in this module.

    - Creates any tables that do not yet exist (via :func:`create_tables`).
    - For tables that already exist, verifies that every model field has a
      matching column; on a missing column, logs an error and exits, since
      an automatic migration could lose data.

    The original implementation stopped at the first missing table and then
    skipped the column check entirely; this version checks all tables in one
    pass so existing tables are always column-verified.
    """
    import sys

    models = [
        ChatStreams,
        LLMUsage,
        Emoji,
        Messages,
        Images,
        ImageDescriptions,
        OnlineTime,
        PersonInfo,
        Knowledges,
        ThinkingLog,
        RecalledMessages,
    ]

    missing_tables = []
    try:
        with db:  # manage the connection for existence and column checks
            for model in models:
                table_name = model._meta.table_name
                if not db.table_exists(model):
                    logger.warning(f"表 '{table_name}' 未找到。")
                    missing_tables.append(model)
                    continue
                # Column check only makes sense for tables that already exist.
                cursor = db.execute_sql(f"PRAGMA table_info('{table_name}')")
                existing_columns = {row[1] for row in cursor.fetchall()}
                for field_name in model._meta.fields:
                    if field_name not in existing_columns:
                        logger.error(f"表 '{table_name}' 缺失字段 '{field_name}',请手动迁移数据库结构后重启程序。")
                        sys.exit(1)
    except Exception as e:
        logger.exception(f"检查表或字段是否存在时出错: {e}")
        # If the check itself failed (e.g. database unavailable), bail out.
        return

    if missing_tables:
        logger.info("正在初始化数据库:一个或多个表丢失。正在尝试创建所有定义的表...")
        try:
            create_tables()  # manages its own 'with db:' context
            logger.info("数据库表创建过程完成。")
        except Exception as e:
            logger.exception(f"创建表期间出错: {e}")
    else:
        logger.info("所有数据库表及字段均已存在。")


# Run the schema check/creation at import time so callers can rely on the tables.
initialize_database()

View File

@@ -629,22 +629,22 @@ PROCESSOR_STYLE_CONFIG = {
PLANNER_STYLE_CONFIG = { PLANNER_STYLE_CONFIG = {
"advanced": { "advanced": {
"console_format": "<level>{time:HH:mm:ss}</level> | <fg #36DEFF>规划器</fg #36DEFF> | <fg #36DEFF>{message}</fg #36DEFF>", "console_format": "<level>{time:HH:mm:ss}</level> | <fg #4DCDFF>规划器</fg #4DCDFF> | <fg #4DCDFF>{message}</fg #4DCDFF>",
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 规划器 | {message}", "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 规划器 | {message}",
}, },
"simple": { "simple": {
"console_format": "<level>{time:HH:mm:ss}</level> | <fg #36DEFF>规划器</fg #36DEFF> | <fg #36DEFF>{message}</fg #36DEFF>", "console_format": "<level>{time:HH:mm:ss}</level> | <fg #4DCDFF>规划器</fg #4DCDFF> | <fg #4DCDFF>{message}</fg #4DCDFF>",
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 规划器 | {message}", "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 规划器 | {message}",
}, },
} }
ACTION_TAKEN_STYLE_CONFIG = { ACTION_TAKEN_STYLE_CONFIG = {
"advanced": { "advanced": {
"console_format": "<level>{time:HH:mm:ss}</level> | <fg #22DAFF>动作</fg #22DAFF> | <fg #22DAFF>{message}</fg #22DAFF>", "console_format": "<level>{time:HH:mm:ss}</level> | <fg #FFA01F>动作</fg #FFA01F> | <fg #FFA01F>{message}</fg #FFA01F>",
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 动作 | {message}", "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 动作 | {message}",
}, },
"simple": { "simple": {
"console_format": "<level>{time:HH:mm:ss}</level> | <fg #22DAFF>动作</fg #22DAFF> | <fg #22DAFF>{message}</fg #22DAFF>", "console_format": "<level>{time:HH:mm:ss}</level> | <fg #FFA01F>动作</fg #FFA01F> | <fg #FFA01F>{message}</fg #FFA01F>",
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 动作 | {message}", "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 动作 | {message}",
}, },
} }

View File

@@ -1,11 +1,19 @@
from src.common.database import db from src.common.database.database_model import Messages # 更改导入
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
import traceback import traceback
from typing import List, Any, Optional from typing import List, Any, Optional
from peewee import Model # 添加 Peewee Model 导入
logger = get_module_logger(__name__) logger = get_module_logger(__name__)
def _model_to_dict(model_instance: Model) -> dict[str, Any]:
"""
将 Peewee 模型实例转换为字典。
"""
return model_instance.__data__
def find_messages( def find_messages(
message_filter: dict[str, Any], message_filter: dict[str, Any],
sort: Optional[List[tuple[str, int]]] = None, sort: Optional[List[tuple[str, int]]] = None,
@@ -16,39 +24,84 @@ def find_messages(
根据提供的过滤器、排序和限制条件查找消息。 根据提供的过滤器、排序和限制条件查找消息。
Args: Args:
message_filter: MongoDB 查询过滤器。 message_filter: 查询过滤器字典,键为模型字段名,值为期望值或包含操作符的字典 (例如 {'$gt': value}).
sort: MongoDB 排序条件列表,例如 [('time', 1)]。仅在 limit 为 0 时生效。 sort: 排序条件列表,例如 [('time', 1)] (1 for asc, -1 for desc)。仅在 limit 为 0 时生效。
limit: 返回的最大文档数0表示不限制。 limit: 返回的最大文档数0表示不限制。
limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录(结果仍按时间正序排列)。默认为 'latest' limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录(结果仍按时间正序排列)。默认为 'latest'
Returns: Returns:
消息文档列表,如果出错则返回空列表。 消息字典列表,如果出错则返回空列表。
""" """
try: try:
query = db.messages.find(message_filter) query = Messages.select()
# 应用过滤器
if message_filter:
conditions = []
for key, value in message_filter.items():
if hasattr(Messages, key):
field = getattr(Messages, key)
if isinstance(value, dict):
# 处理 MongoDB 风格的操作符
for op, op_value in value.items():
if op == "$gt":
conditions.append(field > op_value)
elif op == "$lt":
conditions.append(field < op_value)
elif op == "$gte":
conditions.append(field >= op_value)
elif op == "$lte":
conditions.append(field <= op_value)
elif op == "$ne":
conditions.append(field != op_value)
elif op == "$in":
conditions.append(field.in_(op_value))
elif op == "$nin":
conditions.append(field.not_in(op_value))
else:
logger.warning(f"过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。")
else:
# 直接相等比较
conditions.append(field == value)
else:
logger.warning(f"过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。")
if conditions:
query = query.where(*conditions)
if limit > 0: if limit > 0:
if limit_mode == "earliest": if limit_mode == "earliest":
# 获取时间最早的 limit 条记录,已经是正序 # 获取时间最早的 limit 条记录,已经是正序
query = query.sort([("time", 1)]).limit(limit) query = query.order_by(Messages.time.asc()).limit(limit)
results = list(query) peewee_results = list(query)
else: # 默认为 'latest' else: # 默认为 'latest'
# 获取时间最晚的 limit 条记录 # 获取时间最晚的 limit 条记录
query = query.sort([("time", -1)]).limit(limit) query = query.order_by(Messages.time.desc()).limit(limit)
latest_results = list(query) latest_results_peewee = list(query)
# 将结果按时间正序排列 # 将结果按时间正序排列
# 假设消息文档中总是有 'time' 字段且可排序 peewee_results = sorted(latest_results_peewee, key=lambda msg: msg.time)
results = sorted(latest_results, key=lambda msg: msg.get("time"))
else: else:
# limit 为 0 时,应用传入的 sort 参数 # limit 为 0 时,应用传入的 sort 参数
if sort: if sort:
query = query.sort(sort) peewee_sort_terms = []
results = list(query) for field_name, direction in sort:
if hasattr(Messages, field_name):
field = getattr(Messages, field_name)
if direction == 1: # ASC
peewee_sort_terms.append(field.asc())
elif direction == -1: # DESC
peewee_sort_terms.append(field.desc())
else:
logger.warning(f"字段 '{field_name}' 的排序方向 '{direction}' 无效。将跳过此排序条件。")
else:
logger.warning(f"排序字段 '{field_name}' 在 Messages 模型中未找到。将跳过此排序条件。")
if peewee_sort_terms:
query = query.order_by(*peewee_sort_terms)
peewee_results = list(query)
return results return [_model_to_dict(msg) for msg in peewee_results]
except Exception as e: except Exception as e:
log_message = ( log_message = (
f"查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n" f"使用 Peewee 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n"
+ traceback.format_exc() + traceback.format_exc()
) )
logger.error(log_message) logger.error(log_message)
@@ -60,18 +113,57 @@ def count_messages(message_filter: dict[str, Any]) -> int:
根据提供的过滤器计算消息数量。 根据提供的过滤器计算消息数量。
Args: Args:
message_filter: MongoDB 查询过滤器。 message_filter: 查询过滤器字典,键为模型字段名,值为期望值或包含操作符的字典 (例如 {'$gt': value}).
Returns: Returns:
符合条件的消息数量,如果出错则返回 0。 符合条件的消息数量,如果出错则返回 0。
""" """
try: try:
count = db.messages.count_documents(message_filter) query = Messages.select()
# 应用过滤器
if message_filter:
conditions = []
for key, value in message_filter.items():
if hasattr(Messages, key):
field = getattr(Messages, key)
if isinstance(value, dict):
# 处理 MongoDB 风格的操作符
for op, op_value in value.items():
if op == "$gt":
conditions.append(field > op_value)
elif op == "$lt":
conditions.append(field < op_value)
elif op == "$gte":
conditions.append(field >= op_value)
elif op == "$lte":
conditions.append(field <= op_value)
elif op == "$ne":
conditions.append(field != op_value)
elif op == "$in":
conditions.append(field.in_(op_value))
elif op == "$nin":
conditions.append(field.not_in(op_value))
else:
logger.warning(
f"计数时,过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。"
)
else:
# 直接相等比较
conditions.append(field == value)
else:
logger.warning(f"计数时,过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。")
if conditions:
query = query.where(*conditions)
count = query.count()
return count return count
except Exception as e: except Exception as e:
log_message = f"计数消息失败 (message_filter={message_filter}): {e}\n" + traceback.format_exc() log_message = f"使用 Peewee 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}"
logger.error(log_message) logger.error(log_message)
return 0 return 0
# 你可以在这里添加更多与 messages 集合相关的数据库操作函数,例如 find_one_message, insert_message 等。 # 你可以在这里添加更多与 messages 集合相关的数据库操作函数,例如 find_one_message, insert_message 等。
# 注意:对于 Peewee插入操作通常是 Messages.create(...) 或 instance.save()。
# 查找单个消息可以是 Messages.get_or_none(...) 或 query.first()。

View File

@@ -10,7 +10,7 @@ from src.experimental.PFC.chat_states import (
create_new_message_notification, create_new_message_notification,
create_cold_chat_notification, create_cold_chat_notification,
) )
from src.experimental.PFC.message_storage import MongoDBMessageStorage from src.experimental.PFC.message_storage import PeeweeMessageStorage
from rich.traceback import install from rich.traceback import install
install(extra_lines=3) install(extra_lines=3)
@@ -53,7 +53,7 @@ class ChatObserver:
self.stream_id = stream_id self.stream_id = stream_id
self.private_name = private_name self.private_name = private_name
self.message_storage = MongoDBMessageStorage() self.message_storage = PeeweeMessageStorage()
# self.last_user_speak_time: Optional[float] = None # 对方上次发言时间 # self.last_user_speak_time: Optional[float] = None # 对方上次发言时间
# self.last_bot_speak_time: Optional[float] = None # 机器人上次发言时间 # self.last_bot_speak_time: Optional[float] = None # 机器人上次发言时间

View File

@@ -1,6 +1,9 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Dict, Any from typing import List, Dict, Any
from src.common.database import db
# from src.common.database.database import db # Peewee db 导入
from src.common.database.database_model import Messages # Peewee Messages 模型导入
from playhouse.shortcuts import model_to_dict # 用于将模型实例转换为字典
class MessageStorage(ABC): class MessageStorage(ABC):
@@ -47,28 +50,35 @@ class MessageStorage(ABC):
pass pass
class MongoDBMessageStorage(MessageStorage): class PeeweeMessageStorage(MessageStorage):
"""MongoDB消息存储实现""" """Peewee消息存储实现"""
async def get_messages_after(self, chat_id: str, message_time: float) -> List[Dict[str, Any]]: async def get_messages_after(self, chat_id: str, message_time: float) -> List[Dict[str, Any]]:
query = {"chat_id": chat_id, "time": {"$gt": message_time}} query = (
# print(f"storage_check_message: {message_time}") Messages.select()
.where((Messages.chat_id == chat_id) & (Messages.time > message_time))
.order_by(Messages.time.asc())
)
return list(db.messages.find(query).sort("time", 1)) # print(f"storage_check_message: {message_time}")
messages_models = list(query)
return [model_to_dict(msg) for msg in messages_models]
async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]: async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]:
query = {"chat_id": chat_id, "time": {"$lt": time_point}} query = (
Messages.select()
messages = list(db.messages.find(query).sort("time", -1).limit(limit)) .where((Messages.chat_id == chat_id) & (Messages.time < time_point))
.order_by(Messages.time.desc())
.limit(limit)
)
messages_models = list(query)
# 将消息按时间正序排列 # 将消息按时间正序排列
messages.reverse() messages_models.reverse()
return messages return [model_to_dict(msg) for msg in messages_models]
async def has_new_messages(self, chat_id: str, after_time: float) -> bool: async def has_new_messages(self, chat_id: str, after_time: float) -> bool:
query = {"chat_id": chat_id, "time": {"$gt": after_time}} return Messages.select().where((Messages.chat_id == chat_id) & (Messages.time > after_time)).exists()
return db.messages.find_one(query) is not None
# # 创建一个内存消息存储实现,用于测试 # # 创建一个内存消息存储实现,用于测试

View File

@@ -316,7 +316,7 @@ class GoalAnalyzer:
# message_segment = Seg(type="text", data=content) # message_segment = Seg(type="text", data=content)
# bot_user_info = UserInfo( # bot_user_info = UserInfo(
# user_id=global_config.BOT_QQ, # user_id=global_config.BOT_QQ,
# user_nickname=global_config.BOT_NICKNAME, # user_nickname=global_config.bot.nickname,
# platform=chat_stream.platform, # platform=chat_stream.platform,
# ) # )

101
src/plugins.md Normal file
View File

@@ -0,0 +1,101 @@
# 如何编写MaiBot插件
## 基本步骤
1.`src/plugins/你的插件名/actions/`目录下创建插件文件
2. 继承`PluginAction`基类
3. 实现`process`方法
## 插件结构示例
```python
from src.common.logger_manager import get_logger
from src.chat.focus_chat.planners.actions.plugin_action import PluginAction, register_action
from typing import Tuple
logger = get_logger("your_action_name")
@register_action
class YourAction(PluginAction):
"""你的动作描述"""
action_name = "your_action_name" # 动作名称,必须唯一
action_description = "这个动作的详细描述,会展示给用户"
action_parameters = {
"param1": "参数1的说明可选",
"param2": "参数2的说明可选"
}
action_require = [
"使用场景1",
"使用场景2"
]
default = False # 是否默认启用
async def process(self) -> Tuple[bool, str]:
"""插件核心逻辑"""
# 你的代码逻辑...
return True, "执行结果"
```
## 可用的API方法
插件可以使用`PluginAction`基类提供的以下API
### 1. 发送消息
```python
await self.send_message("要发送的文本", target="可选的回复目标")
```
### 2. 获取聊天类型
```python
chat_type = self.get_chat_type() # 返回 "group" 或 "private" 或 "unknown"
```
### 3. 获取最近消息
```python
messages = self.get_recent_messages(count=5) # 获取最近5条消息
# 返回格式: [{"sender": "发送者", "content": "内容", "timestamp": 时间戳}, ...]
```
### 4. 获取动作参数
```python
param_value = self.action_data.get("param_name", "默认值")
```
### 5. 日志记录
```python
logger.info(f"{self.log_prefix} 你的日志信息")
logger.warning("警告信息")
logger.error("错误信息")
```
## 返回值说明
`process`方法必须返回一个元组,包含两个元素:
- 第一个元素(bool): 表示动作是否执行成功
- 第二个元素(str): 执行结果的文本描述
```python
return True, "执行成功的消息"
# 或
return False, "执行失败的原因"
```
## 最佳实践
1. 使用`action_parameters`清晰定义你的动作需要的参数
2. 使用`action_require`描述何时应该使用你的动作
3. 使用`action_description`准确描述你的动作功能
4. 使用`logger`记录重要信息,方便调试
5. 避免操作底层系统,尽量使用`PluginAction`提供的API
## 注册与加载
插件会在系统启动时自动加载,只要放在正确的目录并添加了`@register_action`装饰器。
若设置`default = True`,插件会自动添加到默认动作集;否则需要在系统中手动启用。

1
src/plugins/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""插件系统包"""

View File

@@ -0,0 +1,5 @@
"""测试插件包"""
"""
这是一个测试插件
"""

View File

@@ -0,0 +1,7 @@
"""测试插件动作模块"""
# 导入所有动作模块以确保装饰器被执行
from . import test_action # noqa
# from . import online_action # noqa
from . import mute_action # noqa

View File

@@ -0,0 +1,51 @@
from src.common.logger_manager import get_logger
from src.chat.focus_chat.planners.actions.plugin_action import PluginAction, register_action
from typing import Tuple
logger = get_logger("mute_action")
@register_action
class MuteAction(PluginAction):
    """Plugin action that mutes a group member by sending a bot mute command."""

    action_name = "mute_action"
    action_description = (
        "如果某人违反了公序良俗,或者别人戳你太多,或者某人刷屏,一定要禁言某人,如果你很生气,可以禁言某人"
    )
    action_parameters = {
        "target": "禁言对象,输入你要禁言的对象的名字,必填,",
        "duration": "禁言时长,输入你要禁言的时长,单位为秒,必填",
        # Fix: "reason" is read in process() but was undocumented here.
        "reason": "禁言理由,可选",
    }
    action_require = [
        "当有人违反了公序良俗时使用",
        "当有人刷屏时使用",
        "当有人要求禁言自己时使用",
        "当有人戳你两次以上时,防止刷屏,禁言他,必须牢记",
        "当千石可乐或可乐酱要求你禁言时使用",
        "当你想回避某个话题时使用",
    ]
    default = True  # part of the default action set

    async def process(self) -> Tuple[bool, str]:
        """Mute the requested user.

        Returns:
            Tuple[bool, str]: (success flag, human-readable result).
        """
        # Fix: the original log line said "online" although this is the mute action.
        logger.info(f"{self.log_prefix} 执行mute动作: {self.reasoning}")

        target = self.action_data.get("target")
        duration = self.action_data.get("duration")
        reason = self.action_data.get("reason")

        # Resolve the person name to a (platform, user_id) pair.
        platform, user_id = await self.get_user_id_by_person_name(target)

        await self.send_message_by_expressor(f"我要禁言{target}{platform},时长{duration}秒,理由{reason},表达情绪")
        try:
            # "[command]mute,<user_id>,<duration>" is interpreted downstream as a mute command.
            await self.send_message(f"[command]mute,{user_id},{duration}")
        except Exception as e:
            logger.error(f"{self.log_prefix} 执行mute动作时出错: {e}")
            await self.send_message_by_expressor(f"执行mute动作时出错: {e}")
            return False, "执行mute动作时出错"
        # Fix: the original success message said "测试动作执行成功".
        return True, "禁言动作执行成功"

View File

@@ -0,0 +1,43 @@
from src.common.logger_manager import get_logger
from src.chat.focus_chat.planners.actions.plugin_action import PluginAction, register_action
from typing import Tuple
logger = get_logger("check_online_action")
@register_action
class CheckOnlineAction(PluginAction):
    """Plugin action that queries MaiBot's online status."""

    action_name = "check_online_action"
    action_description = "这是一个检查在线状态的动作当有人要求你检查Maibot麦麦 机器人)在线状态时使用"
    action_parameters = {"mode": "查看模式"}
    action_require = [
        "当有人要求你检查Maibot麦麦 机器人)在线状态时使用",
        "mode参数为version时查看在线版本状态默认用这种",
        "mode参数为type时查看在线系统类型分布",
    ]
    default = True  # part of the default action set

    async def process(self) -> Tuple[bool, str]:
        """Send the online-status query command matching the requested mode.

        Returns:
            Tuple[bool, str]: (success flag, human-readable result).
        """
        logger.info(f"{self.log_prefix} 执行online动作: {self.reasoning}")

        # NOTE(review): action_require says "version" is the default mode,
        # but the fallback here is "type" — confirm which is intended.
        mode = self.action_data.get("mode", "type")

        await self.send_message_by_expressor("我看看")
        try:
            if mode == "type":
                await self.send_message("#online detail")
            elif mode == "version":
                await self.send_message("#online")
        except Exception as e:
            logger.error(f"{self.log_prefix} 执行online动作时出错: {e}")
            # Fix: was a plain string literal, so the text "{e}" was sent verbatim.
            await self.send_message_by_expressor(f"执行online动作时出错: {e}")
            return False, "执行online动作时出错"
        # Fix: the original success message said "测试动作执行成功".
        return True, "在线状态检查动作执行成功"

View File

@@ -0,0 +1,37 @@
from src.common.logger_manager import get_logger
from src.chat.focus_chat.planners.actions.plugin_action import PluginAction, register_action
from typing import Tuple

logger = get_logger("test_action")


@register_action
class TestAction(PluginAction):
    """Minimal smoke-test action used to verify plugin loading end to end."""

    action_name = "test_action"
    action_description = "这是一个测试动作,当有人要求你测试插件系统时使用"
    action_parameters = {"test_param": "测试参数(可选)"}
    action_require = [
        "测试情况下使用",
        "想测试插件动作加载时使用",
    ]
    default = False  # Not in the default action set; must be enabled explicitly.

    async def process(self) -> Tuple[bool, str]:
        """Probe a few PluginAction helpers, then echo a confirmation message.

        Returns:
            Tuple[bool, str]: (success flag, human-readable result message).
        """
        logger.info(f"{self.log_prefix} 执行测试动作: {self.reasoning}")

        # Exercise the context helpers to confirm they are wired up.
        chat_kind = self.get_chat_type()
        logger.info(f"{self.log_prefix} 当前聊天类型: {chat_kind}")

        latest_msgs = self.get_recent_messages(3)
        logger.info(f"{self.log_prefix} 最近3条消息: {latest_msgs}")

        # Echo the optional parameter back through the expressor channel.
        param_value = self.action_data.get("test_param", "默认参数")
        await self.send_message_by_expressor(f"测试动作执行成功,参数: {param_value}")

        return True, "测试动作执行成功"

View File

@@ -1,8 +1,10 @@
from src.tools.tool_can_use.base_tool import BaseTool from src.tools.tool_can_use.base_tool import BaseTool
from src.chat.utils.utils import get_embedding from src.chat.utils.utils import get_embedding
from src.common.database import db from src.common.database.database_model import Knowledges # Updated import
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
from typing import Any, Union from typing import Any, Union, List # Added List
import json # Added for parsing embedding
import math # Added for cosine similarity
logger = get_logger("get_knowledge_tool") logger = get_logger("get_knowledge_tool")
@@ -30,6 +32,7 @@ class SearchKnowledgeTool(BaseTool):
Returns: Returns:
dict: 工具执行结果 dict: 工具执行结果
""" """
query = "" # Initialize query to ensure it's defined in except block
try: try:
query = function_args.get("query") query = function_args.get("query")
threshold = function_args.get("threshold", 0.4) threshold = function_args.get("threshold", 0.4)
@@ -48,9 +51,19 @@ class SearchKnowledgeTool(BaseTool):
logger.error(f"知识库搜索工具执行失败: {str(e)}") logger.error(f"知识库搜索工具执行失败: {str(e)}")
return {"type": "info", "id": query, "content": f"知识库搜索失败,炸了: {str(e)}"} return {"type": "info", "id": query, "content": f"知识库搜索失败,炸了: {str(e)}"}
@staticmethod
def _cosine_similarity(vec1: List[float], vec2: List[float]) -> float:
    """Return the cosine similarity between *vec1* and *vec2*.

    Yields 0.0 whenever either vector has zero magnitude, so callers never
    see a ZeroDivisionError for degenerate embeddings. Pairs are consumed
    lockstep via zip for the dot product, while each norm is taken over its
    full vector.
    """
    dot = 0
    for a, b in zip(vec1, vec2):
        dot += a * b
    norm_a = math.sqrt(sum(x * x for x in vec1))
    norm_b = math.sqrt(sum(y * y for y in vec2))
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)
@staticmethod @staticmethod
def get_info_from_db( def get_info_from_db(
query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False query_embedding: list[float], limit: int = 1, threshold: float = 0.5, return_raw: bool = False
) -> Union[str, list]: ) -> Union[str, list]:
"""从数据库中获取相关信息 """从数据库中获取相关信息
@@ -66,66 +79,51 @@ class SearchKnowledgeTool(BaseTool):
if not query_embedding: if not query_embedding:
return "" if not return_raw else [] return "" if not return_raw else []
# 使用余弦相似度计算 similar_items = []
pipeline = [ try:
{ all_knowledges = Knowledges.select()
"$addFields": { for item in all_knowledges:
"dotProduct": { try:
"$reduce": { item_embedding_str = item.embedding
"input": {"$range": [0, {"$size": "$embedding"}]}, if not item_embedding_str:
"initialValue": 0, logger.warning(f"Knowledge item ID {item.id} has empty embedding string.")
"in": { continue
"$add": [ item_embedding = json.loads(item_embedding_str)
"$$value", if not isinstance(item_embedding, list) or not all(
{ isinstance(x, (int, float)) for x in item_embedding
"$multiply": [ ):
{"$arrayElemAt": ["$embedding", "$$this"]}, logger.warning(f"Knowledge item ID {item.id} has invalid embedding format after JSON parsing.")
{"$arrayElemAt": [query_embedding, "$$this"]}, continue
] except json.JSONDecodeError:
}, logger.warning(f"Failed to parse embedding for knowledge item ID {item.id}")
] continue
}, except AttributeError:
} logger.warning(f"Knowledge item ID {item.id} missing 'embedding' attribute or it's not a string.")
}, continue
"magnitude1": {
"$sqrt": {
"$reduce": {
"input": "$embedding",
"initialValue": 0,
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
}
}
},
"magnitude2": {
"$sqrt": {
"$reduce": {
"input": query_embedding,
"initialValue": 0,
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
}
}
},
}
},
{"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}},
{
"$match": {
"similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果
}
},
{"$sort": {"similarity": -1}},
{"$limit": limit},
{"$project": {"content": 1, "similarity": 1}},
]
results = list(db.knowledges.aggregate(pipeline)) similarity = SearchKnowledgeTool._cosine_similarity(query_embedding, item_embedding)
logger.debug(f"知识库查询结果数量: {len(results)}")
if similarity >= threshold:
similar_items.append({"content": item.content, "similarity": similarity, "raw_item": item})
# 按相似度降序排序
similar_items.sort(key=lambda x: x["similarity"], reverse=True)
# 应用限制
results = similar_items[:limit]
logger.debug(f"知识库查询后,符合条件的结果数量: {len(results)}")
except Exception as e:
logger.error(f"从 Peewee 数据库获取知识信息失败: {str(e)}")
return "" if not return_raw else []
if not results: if not results:
return "" if not return_raw else [] return "" if not return_raw else []
if return_raw: if return_raw:
return results # Peewee 模型实例不能直接序列化为 JSON如果需要原始模型调用者需要处理
# 这里返回包含内容和相似度的字典列表
return [{"content": r["content"], "similarity": r["similarity"]} for r in results]
else: else:
# 返回所有找到的内容,用换行分隔 # 返回所有找到的内容,用换行分隔
return "\n".join(str(result["content"]) for result in results) return "\n".join(str(result["content"]) for result in results)

View File

@@ -18,11 +18,11 @@ nickname = "麦麦"
alias_names = ["麦叠", "牢麦"] #该选项还在调试中,暂时未生效 alias_names = ["麦叠", "牢麦"] #该选项还在调试中,暂时未生效
[chat_target] [chat_target]
talk_allowed = [ talk_allowed_groups = [
123, 123,
123, 123,
] #可以回复消息的群号码 ] #可以回复消息的群号码
talk_frequency_down = [] #降低回复频率的群号码 talk_frequency_down_groups = [] #降低回复频率的群号码
ban_user_id = [] #禁止回复和读取消息的QQ号 ban_user_id = [] #禁止回复和读取消息的QQ号
[personality] #未完善 [personality] #未完善
@@ -63,7 +63,7 @@ allow_focus_mode = false # 是否允许专注聊天状态
base_normal_chat_num = 999 # 最多允许多少个群进行普通聊天 base_normal_chat_num = 999 # 最多允许多少个群进行普通聊天
base_focused_chat_num = 4 # 最多允许多少个群进行专注聊天 base_focused_chat_num = 4 # 最多允许多少个群进行专注聊天
observation_context_size = 15 # 观察到的最长上下文大小,建议15太短太长都会导致脑袋尖尖 chat.observation_context_size = 15 # 观察到的最长上下文大小,建议15太短太长都会导致脑袋尖尖
message_buffer = true # 启用消息缓冲器?启用此项以解决消息的拆分问题,但会使麦麦的回复延迟 message_buffer = true # 启用消息缓冲器?启用此项以解决消息的拆分问题,但会使麦麦的回复延迟
# 以下是消息过滤,可以根据规则过滤特定消息,将不会读取这些消息 # 以下是消息过滤,可以根据规则过滤特定消息,将不会读取这些消息
@@ -99,7 +99,7 @@ default_decay_rate_per_second = 0.98 # 默认衰减率,越大衰减越快,
consecutive_no_reply_threshold = 3 # 连续不回复的阈值,越低越容易结束专注聊天 consecutive_no_reply_threshold = 3 # 连续不回复的阈值,越低越容易结束专注聊天
# 以下选项暂时无效 # 以下选项暂时无效
compressed_length = 5 # 不能大于observation_context_size,心流上下文压缩的最短压缩长度超过心流观察到的上下文长度会压缩最短压缩长度为5 compressed_length = 5 # 不能大于chat.observation_context_size,心流上下文压缩的最短压缩长度超过心流观察到的上下文长度会压缩最短压缩长度为5
compress_length_limit = 5 #最多压缩份数,超过该数值的压缩上下文会被删除 compress_length_limit = 5 #最多压缩份数,超过该数值的压缩上下文会被删除