diff --git a/.gitignore b/.gitignore index df3ab670f..ac400b137 100644 --- a/.gitignore +++ b/.gitignore @@ -301,3 +301,5 @@ $RECYCLE.BIN/ # Windows shortcuts *.lnk +src/chat/focus_chat/working_memory/test/test1.txt +src/chat/focus_chat/working_memory/test/test4.txt diff --git a/requirements.txt b/requirements.txt index 7abdffb48..1e374f4eb 100644 Binary files a/requirements.txt and b/requirements.txt differ diff --git a/src/api/config_api.py b/src/api/config_api.py index 0b23fb993..8b99fb93e 100644 --- a/src/api/config_api.py +++ b/src/api/config_api.py @@ -41,7 +41,7 @@ class APIBotConfig: allow_focus_mode: bool # 是否允许专注聊天状态 base_normal_chat_num: int # 最多允许多少个群进行普通聊天 base_focused_chat_num: int # 最多允许多少个群进行专注聊天 - observation_context_size: int # 观察到的最长上下文大小 + chat.observation_context_size: int # 观察到的最长上下文大小 message_buffer: bool # 是否启用消息缓冲 ban_words: List[str] # 禁止词列表 ban_msgs_regex: List[str] # 禁止消息的正则表达式列表 @@ -128,7 +128,7 @@ class APIBotConfig: llm_reasoning: Dict[str, Any] # 推理模型配置 llm_normal: Dict[str, Any] # 普通模型配置 llm_topic_judge: Dict[str, Any] # 主题判断模型配置 - llm_summary: Dict[str, Any] # 总结模型配置 + model.summary: Dict[str, Any] # 总结模型配置 vlm: Dict[str, Any] # VLM模型配置 llm_heartflow: Dict[str, Any] # 心流模型配置 llm_observation: Dict[str, Any] # 观察模型配置 @@ -203,7 +203,7 @@ class APIBotConfig: "llm_reasoning", "llm_normal", "llm_topic_judge", - "llm_summary", + "model.summary", "vlm", "llm_heartflow", "llm_observation", diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py index 52a7288ec..fda0a63fd 100644 --- a/src/chat/emoji_system/emoji_manager.py +++ b/src/chat/emoji_system/emoji_manager.py @@ -5,12 +5,15 @@ import os import random import time import traceback -from typing import Optional, Tuple +from typing import Optional, Tuple, List, Any from PIL import Image import io import re -from ...common.database import db +# from gradio_client import file + +from ...common.database.database_model import Emoji +from 
...common.database.database import db as peewee_db from ...config.config import global_config from ..utils.utils_image import image_path_to_base64, image_manager from ..models.utils_model import LLMRequest @@ -51,7 +54,7 @@ class MaiEmoji: self.is_deleted = False # 标记是否已被删除 self.format = "" - async def initialize_hash_format(self): + async def initialize_hash_format(self) -> Optional[bool]: """从文件创建表情包实例, 计算哈希值和格式""" try: # 使用 full_path 检查文件是否存在 @@ -104,7 +107,7 @@ class MaiEmoji: self.is_deleted = True return None - async def register_to_db(self): + async def register_to_db(self) -> bool: """ 注册表情包 将表情包对应的文件,从当前路径移动到EMOJI_REGISTED_DIR目录下 @@ -143,22 +146,22 @@ class MaiEmoji: # --- 数据库操作 --- try: # 准备数据库记录 for emoji collection - emoji_record = { - "filename": self.filename, - "path": self.path, # 存储目录路径 - "full_path": self.full_path, # 存储完整文件路径 - "embedding": self.embedding, - "description": self.description, - "emotion": self.emotion, - "hash": self.hash, - "format": self.format, - "timestamp": int(self.register_time), - "usage_count": self.usage_count, - "last_used_time": self.last_used_time, - } + emotion_str = ",".join(self.emotion) if self.emotion else "" - # 使用upsert确保记录存在或被更新 - db["emoji"].update_one({"hash": self.hash}, {"$set": emoji_record}, upsert=True) + Emoji.create( + hash=self.hash, + full_path=self.full_path, + format=self.format, + description=self.description, + emotion=emotion_str, # Store as comma-separated string + query_count=0, # Default value + is_registered=True, + is_banned=False, # Default value + record_time=self.register_time, # Use MaiEmoji's register_time for DB record_time + register_time=self.register_time, + usage_count=self.usage_count, + last_used_time=self.last_used_time, + ) logger.success(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})") @@ -166,14 +169,6 @@ class MaiEmoji: except Exception as db_error: logger.error(f"[错误] 保存数据库失败 ({self.filename}): {str(db_error)}") - # 数据库保存失败,是否需要将文件移回?为了简化,暂时只记录错误 - # 
可以考虑在这里尝试删除已移动的文件,避免残留 - try: - if os.path.exists(self.full_path): # full_path 此时是目标路径 - os.remove(self.full_path) - logger.warning(f"[回滚] 已删除移动失败后残留的文件: {self.full_path}") - except Exception as remove_error: - logger.error(f"[错误] 回滚删除文件失败: {remove_error}") return False except Exception as e: @@ -181,7 +176,7 @@ class MaiEmoji: logger.error(traceback.format_exc()) return False - async def delete(self): + async def delete(self) -> bool: """删除表情包 删除表情包的文件和数据库记录 @@ -201,10 +196,14 @@ class MaiEmoji: # 文件删除失败,但仍然尝试删除数据库记录 # 2. 删除数据库记录 - result = db.emoji.delete_one({"hash": self.hash}) - deleted_in_db = result.deleted_count > 0 + try: + will_delete_emoji = Emoji.get(Emoji.emoji_hash == self.hash) + result = will_delete_emoji.delete_instance() # Returns the number of rows deleted. + except Emoji.DoesNotExist: + logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。") + result = 0 # Indicate no DB record was deleted - if deleted_in_db: + if result > 0: logger.info(f"[删除] 表情包数据库记录 {self.filename} (Hash: {self.hash})") # 3. 
标记对象已被删除 self.is_deleted = True @@ -224,7 +223,7 @@ class MaiEmoji: return False -def _emoji_objects_to_readable_list(emoji_objects): +def _emoji_objects_to_readable_list(emoji_objects: List["MaiEmoji"]) -> List[str]: """将表情包对象列表转换为可读的字符串列表 参数: @@ -243,47 +242,48 @@ def _emoji_objects_to_readable_list(emoji_objects): return emoji_info_list -def _to_emoji_objects(data): +def _to_emoji_objects(data: Any) -> Tuple[List["MaiEmoji"], int]: emoji_objects = [] load_errors = 0 + # data is now an iterable of Peewee Emoji model instances emoji_data_list = list(data) - for emoji_data in emoji_data_list: - full_path = emoji_data.get("full_path") + for emoji_data in emoji_data_list: # emoji_data is an Emoji model instance + full_path = emoji_data.full_path if not full_path: - logger.warning(f"[加载错误] 数据库记录缺少 'full_path' 字段: {emoji_data.get('_id')}") + logger.warning( + f"[加载错误] 数据库记录缺少 'full_path' 字段: ID {emoji_data.id if hasattr(emoji_data, 'id') else 'Unknown'}" + ) load_errors += 1 - continue # 跳过缺少 full_path 的记录 + continue try: - # 使用 full_path 初始化 MaiEmoji 对象 emoji = MaiEmoji(full_path=full_path) - # 设置从数据库加载的属性 - emoji.hash = emoji_data.get("hash", "") - # 如果 hash 为空,也跳过?取决于业务逻辑 + emoji.hash = emoji_data.emoji_hash if not emoji.hash: logger.warning(f"[加载错误] 数据库记录缺少 'hash' 字段: {full_path}") load_errors += 1 continue - emoji.description = emoji_data.get("description", "") - emoji.emotion = emoji_data.get("emotion", []) - emoji.usage_count = emoji_data.get("usage_count", 0) - # 优先使用 last_used_time,否则用 timestamp,最后用当前时间 - last_used = emoji_data.get("last_used_time") - timestamp = emoji_data.get("timestamp") - emoji.last_used_time = ( - last_used if last_used is not None else (timestamp if timestamp is not None else time.time()) - ) - emoji.register_time = timestamp if timestamp is not None else time.time() - emoji.format = emoji_data.get("format", "") # 加载格式 + emoji.description = emoji_data.description + # Deserialize emotion string from DB to list + emoji.emotion = 
emoji_data.emotion.split(",") if emoji_data.emotion else [] + emoji.usage_count = emoji_data.usage_count - # 不需要再手动设置 path 和 filename,__init__ 会自动处理 + db_last_used_time = emoji_data.last_used_time + db_register_time = emoji_data.register_time + + # If last_used_time from DB is None, use MaiEmoji's initialized register_time or current time + emoji.last_used_time = db_last_used_time if db_last_used_time is not None else emoji.register_time + # If register_time from DB is None, use MaiEmoji's initialized register_time (which is time.time()) + emoji.register_time = db_register_time if db_register_time is not None else emoji.register_time + + emoji.format = emoji_data.format emoji_objects.append(emoji) - except ValueError as ve: # 捕获 __init__ 可能的错误 + except ValueError as ve: logger.error(f"[加载错误] 初始化 MaiEmoji 失败 ({full_path}): {ve}") load_errors += 1 except Exception as e: @@ -292,13 +292,13 @@ def _to_emoji_objects(data): return emoji_objects, load_errors -def _ensure_emoji_dir(): +def _ensure_emoji_dir() -> None: """确保表情存储目录存在""" os.makedirs(EMOJI_DIR, exist_ok=True) os.makedirs(EMOJI_REGISTED_DIR, exist_ok=True) -async def clear_temp_emoji(): +async def clear_temp_emoji() -> None: """清理临时表情包 清理/data/emoji和/data/image目录下的所有文件 当目录中文件数超过100时,会全部删除 @@ -320,7 +320,7 @@ async def clear_temp_emoji(): logger.success("[清理] 完成") -async def clean_unused_emojis(emoji_dir, emoji_objects): +async def clean_unused_emojis(emoji_dir: str, emoji_objects: List["MaiEmoji"]) -> None: """清理指定目录中未被 emoji_objects 追踪的表情包文件""" if not os.path.exists(emoji_dir): logger.warning(f"[清理] 目标目录不存在,跳过清理: {emoji_dir}") @@ -360,13 +360,13 @@ async def clean_unused_emojis(emoji_dir, emoji_objects): class EmojiManager: _instance = None - def __new__(cls): + def __new__(cls) -> "EmojiManager": if cls._instance is None: cls._instance = super().__new__(cls) cls._instance._initialized = False return cls._instance - def __init__(self): + def __init__(self) -> None: self._initialized = None self._scan_task = 
None @@ -382,53 +382,30 @@ class EmojiManager: logger.info("启动表情包管理器") - def initialize(self): + def initialize(self) -> None: """初始化数据库连接和表情目录""" - if not self._initialized: - try: - self._ensure_emoji_collection() - _ensure_emoji_dir() - self._initialized = True - # 更新表情包数量 - # 启动时执行一次完整性检查 - # await self.check_emoji_file_integrity() - except Exception as e: - logger.exception(f"初始化表情管理器失败: {e}") + peewee_db.connect(reuse_if_open=True) + if peewee_db.is_closed(): + raise RuntimeError("数据库连接失败") + _ensure_emoji_dir() + Emoji.create_table(safe=True) # Ensures table exists - def _ensure_db(self): + def _ensure_db(self) -> None: """确保数据库已初始化""" if not self._initialized: self.initialize() if not self._initialized: raise RuntimeError("EmojiManager not initialized") - @staticmethod - def _ensure_emoji_collection(): - """确保emoji集合存在并创建索引 - - 这个函数用于确保MongoDB数据库中存在emoji集合,并创建必要的索引。 - - 索引的作用是加快数据库查询速度: - - embedding字段的2dsphere索引: 用于加速向量相似度搜索,帮助快速找到相似的表情包 - - tags字段的普通索引: 加快按标签搜索表情包的速度 - - filename字段的唯一索引: 确保文件名不重复,同时加快按文件名查找的速度 - - 没有索引的话,数据库每次查询都需要扫描全部数据,建立索引后可以大大提高查询效率。 - """ - if "emoji" not in db.list_collection_names(): - db.create_collection("emoji") - db.emoji.create_index([("embedding", "2dsphere")]) - db.emoji.create_index([("filename", 1)], unique=True) - - def record_usage(self, emoji_hash: str): + def record_usage(self, emoji_hash: str) -> None: """记录表情使用次数""" try: - db.emoji.update_one({"hash": emoji_hash}, {"$inc": {"usage_count": 1}}) - for emoji in self.emoji_objects: - if emoji.hash == emoji_hash: - emoji.usage_count += 1 - break - + emoji_update = Emoji.get(Emoji.emoji_hash == emoji_hash) + emoji_update.usage_count += 1 + emoji_update.last_used_time = time.time() # Update last used time + emoji_update.save() # Persist changes to DB + except Emoji.DoesNotExist: + logger.error(f"记录表情使用失败: 未找到 hash 为 {emoji_hash} 的表情包") except Exception as e: logger.error(f"记录表情使用失败: {str(e)}") @@ -448,7 +425,6 @@ class EmojiManager: if not all_emojis: 
logger.warning("内存中没有任何表情包对象") - # 可以考虑再查一次数据库?或者依赖定期任务更新 return None # 计算每个表情包与输入文本的最大情感相似度 @@ -464,40 +440,38 @@ class EmojiManager: # 计算与每个emotion标签的相似度,取最大值 max_similarity = 0 - best_matching_emotion = "" # 记录最匹配的 emotion 喵~ + best_matching_emotion = "" for emotion in emotions: # 使用编辑距离计算相似度 distance = self._levenshtein_distance(text_emotion, emotion) max_len = max(len(text_emotion), len(emotion)) similarity = 1 - (distance / max_len if max_len > 0 else 0) - if similarity > max_similarity: # 如果找到更相似的喵~ + if similarity > max_similarity: max_similarity = similarity - best_matching_emotion = emotion # 就记下这个 emotion 喵~ + best_matching_emotion = emotion - if best_matching_emotion: # 确保有匹配的情感才添加喵~ - emoji_similarities.append((emoji, max_similarity, best_matching_emotion)) # 把 emotion 也存起来喵~ + if best_matching_emotion: + emoji_similarities.append((emoji, max_similarity, best_matching_emotion)) # 按相似度降序排序 emoji_similarities.sort(key=lambda x: x[1], reverse=True) # 获取前10个最相似的表情包 - top_emojis = ( - emoji_similarities[:10] if len(emoji_similarities) > 10 else emoji_similarities - ) # 改个名字,更清晰喵~ + top_emojis = emoji_similarities[:10] if len(emoji_similarities) > 10 else emoji_similarities if not top_emojis: logger.warning("未找到匹配的表情包") return None # 从前几个中随机选择一个 - selected_emoji, similarity, matched_emotion = random.choice(top_emojis) # 把匹配的 emotion 也拿出来喵~ + selected_emoji, similarity, matched_emotion = random.choice(top_emojis) # 更新使用次数 - self.record_usage(selected_emoji.hash) + self.record_usage(selected_emoji.emoji_hash) _time_end = time.time() - logger.info( # 使用匹配到的 emotion 记录日志喵~ + logger.info( f"为[{text_emotion}]找到表情包: {matched_emotion} ({selected_emoji.filename}), Similarity: {similarity:.4f}" ) # 返回完整文件路径和描述 @@ -535,7 +509,7 @@ class EmojiManager: return previous_row[-1] - async def check_emoji_file_integrity(self): + async def check_emoji_file_integrity(self) -> None: """检查表情包文件完整性 遍历self.emoji_objects中的所有对象,检查文件是否存在 如果文件已被删除,则执行对象的删除方法并从列表中移除 @@ -600,7 +574,7 @@ 
class EmojiManager: logger.error(f"[错误] 检查表情包完整性失败: {str(e)}") logger.error(traceback.format_exc()) - async def start_periodic_check_register(self): + async def start_periodic_check_register(self) -> None: """定期检查表情包完整性和数量""" await self.get_all_emoji_from_db() while True: @@ -654,13 +628,14 @@ class EmojiManager: await asyncio.sleep(global_config.emoji.check_interval * 60) - async def get_all_emoji_from_db(self): + async def get_all_emoji_from_db(self) -> None: """获取所有表情包并初始化为MaiEmoji类对象,更新 self.emoji_objects""" try: self._ensure_db() - logger.info("[数据库] 开始加载所有表情包记录...") + logger.info("[数据库] 开始加载所有表情包记录 (Peewee)...") - emoji_objects, load_errors = _to_emoji_objects(db.emoji.find()) + emoji_peewee_instances = Emoji.select() + emoji_objects, load_errors = _to_emoji_objects(emoji_peewee_instances) # 更新内存中的列表和数量 self.emoji_objects = emoji_objects @@ -675,7 +650,7 @@ class EmojiManager: self.emoji_objects = [] # 加载失败则清空列表 self.emoji_num = 0 - async def get_emoji_from_db(self, emoji_hash=None): + async def get_emoji_from_db(self, emoji_hash: Optional[str] = None) -> List["MaiEmoji"]: """获取指定哈希值的表情包并初始化为MaiEmoji类对象列表 (主要用于调试或特定查找) 参数: @@ -687,15 +662,16 @@ class EmojiManager: try: self._ensure_db() - query = {} if emoji_hash: - query = {"hash": emoji_hash} + query = Emoji.select().where(Emoji.emoji_hash == emoji_hash) else: logger.warning( "[查询] 未提供 hash,将尝试加载所有表情包,建议使用 get_all_emoji_from_db 更新管理器状态。" ) + query = Emoji.select() - emoji_objects, load_errors = _to_emoji_objects(db.emoji.find(query)) + emoji_peewee_instances = query + emoji_objects, load_errors = _to_emoji_objects(emoji_peewee_instances) if load_errors > 0: logger.warning(f"[查询] 加载过程中出现 {load_errors} 个错误。") @@ -706,7 +682,7 @@ class EmojiManager: logger.error(f"[错误] 从数据库获取表情包对象失败: {str(e)}") return [] - async def get_emoji_from_manager(self, emoji_hash) -> Optional[MaiEmoji]: + async def get_emoji_from_manager(self, emoji_hash: str) -> Optional["MaiEmoji"]: """从内存中的 emoji_objects 列表获取表情包 参数: @@ -759,7 
+735,7 @@ class EmojiManager: logger.error(traceback.format_exc()) return False - async def replace_a_emoji(self, new_emoji: MaiEmoji): + async def replace_a_emoji(self, new_emoji: "MaiEmoji") -> bool: """替换一个表情包 Args: @@ -820,7 +796,7 @@ class EmojiManager: # 删除选定的表情包 logger.info(f"[决策] 删除表情包: {emoji_to_delete.description}") - delete_success = await self.delete_emoji(emoji_to_delete.hash) + delete_success = await self.delete_emoji(emoji_to_delete.emoji_hash) if delete_success: # 修复:等待异步注册完成 @@ -848,7 +824,7 @@ class EmojiManager: logger.error(traceback.format_exc()) return False - async def build_emoji_description(self, image_base64: str) -> Tuple[str, list]: + async def build_emoji_description(self, image_base64: str) -> Tuple[str, List[str]]: """获取表情包描述和情感列表 Args: diff --git a/src/chat/focus_chat/expressors/default_expressor.py b/src/chat/focus_chat/expressors/default_expressor.py index c5aa5f9a4..d3d21e074 100644 --- a/src/chat/focus_chat/expressors/default_expressor.py +++ b/src/chat/focus_chat/expressors/default_expressor.py @@ -10,7 +10,6 @@ from src.config.config import global_config from src.chat.utils.utils_image import image_path_to_base64 # Local import needed after move from src.chat.utils.timer_calculator import Timer # <--- Import Timer from src.chat.emoji_system.emoji_manager import emoji_manager -from src.chat.focus_chat.heartflow_prompt_builder import prompt_builder from src.chat.focus_chat.heartFC_sender import HeartFCSender from src.chat.utils.utils import process_llm_response from src.chat.utils.info_catcher import info_catcher_manager @@ -18,10 +17,62 @@ from src.manager.mood_manager import mood_manager from src.chat.heart_flow.utils_chat import get_chat_type_and_target_info from src.chat.message_receive.chat_stream import ChatStream from src.chat.focus_chat.hfc_utils import parse_thinking_id_to_timestamp +from src.individuality.individuality import Individuality +from src.chat.utils.prompt_builder import Prompt, global_prompt_manager +from 
src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat +import time +from src.chat.focus_chat.expressors.exprssion_learner import expression_learner +import random logger = get_logger("expressor") +def init_prompt(): + Prompt( + """ +你可以参考以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中: +{style_habbits} + +你现在正在群里聊天,以下是群里正在进行的聊天内容: +{chat_info} + +以上是聊天内容,你需要了解聊天记录中的内容 + +{chat_target} +你的名字是{bot_name},{prompt_personality},在这聊天中,"{target_message}"引起了你的注意,对这句话,你想表达:{in_mind_reply},原因是:{reason}。你现在要思考怎么回复 +你需要使用合适的语法和句法,参考聊天内容,组织一条日常且口语化的回复。 +请你根据情景使用以下句法: +{grammar_habbits} +回复尽量简短一些。可以参考贴吧,知乎和微博的回复风格,你可以完全重组回复,保留最基本的表达含义就好,但注意回复要简短,但重组后保持语意通顺。 +回复不要浮夸,不要用夸张修辞,平淡一些。不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 ),只输出一条回复就好。 +现在,你说: +""", + "default_expressor_prompt", + ) + + Prompt( + """ +你可以参考以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中: +{style_habbits} + +你现在正在群里聊天,以下是群里正在进行的聊天内容: +{chat_info} + +以上是聊天内容,你需要了解聊天记录中的内容 + +{chat_target} +你的名字是{bot_name},{prompt_personality},在这聊天中,"{target_message}"引起了你的注意,对这句话,你想表达:{in_mind_reply},原因是:{reason}。你现在要思考怎么回复 +你需要使用合适的语法和句法,参考聊天内容,组织一条日常且口语化的回复。 +请你根据情景使用以下句法: +{grammar_habbits} +回复尽量简短一些。可以参考贴吧,知乎和微博的回复风格,你可以完全重组回复,保留最基本的表达含义就好,但注意回复要简短,但重组后保持语意通顺。 +回复不要浮夸,不要用夸张修辞,平淡一些。不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 ),只输出一条回复就好。 +现在,你说: +""", + "default_expressor_private_prompt", # New template for private FOCUSED chat + ) + + class DefaultExpressor: def __init__(self, chat_id: str): self.log_prefix = "expressor" @@ -67,7 +118,7 @@ class DefaultExpressor: reply=anchor_message, # 回复的是锚点消息 thinking_start_time=thinking_time_point, ) - logger.debug(f"创建思考消息thinking_message:{thinking_message}") + # logger.debug(f"创建思考消息thinking_message:{thinking_message}") await self.heart_fc_sender.register_thinking(thinking_message) @@ -107,7 +158,7 @@ class DefaultExpressor: if reply: with Timer("发送消息", cycle_timers): - sent_msg_list = await self._send_response_messages( + sent_msg_list = await 
self.send_response_messages( anchor_message=anchor_message, thinking_id=thinking_id, response_set=reply, @@ -163,13 +214,10 @@ class DefaultExpressor: # 3. 构建 Prompt with Timer("构建Prompt", {}): # 内部计时器,可选保留 - prompt = await prompt_builder.build_prompt( - build_mode="focus", + prompt = await self.build_prompt_focus( chat_stream=self.chat_stream, # Pass the stream object in_mind_reply=in_mind_reply, reason=reason, - current_mind_info="", - structured_info="", sender_name=sender_name_for_prompt, # Pass determined name target_message=target_message, ) @@ -188,7 +236,7 @@ class DefaultExpressor: # logger.info(f"{self.log_prefix}[Replier-{thinking_id}]\nPrompt:\n{prompt}\n") content, reasoning_content, model_name = await self.express_model.generate_response(prompt) - logger.info(f"{self.log_prefix}\nPrompt:\n{prompt}\n---------------------------\n") + # logger.info(f"{self.log_prefix}\nPrompt:\n{prompt}\n---------------------------\n") logger.info(f"想要表达:{in_mind_reply}") logger.info(f"理由:{reason}") @@ -225,10 +273,108 @@ class DefaultExpressor: traceback.print_exc() return None + async def build_prompt_focus( + self, + reason, + chat_stream, + sender_name, + in_mind_reply, + target_message, + ) -> str: + individuality = Individuality.get_instance() + prompt_personality = individuality.get_prompt(x_person=0, level=2) + + # Determine if it's a group chat + is_group_chat = bool(chat_stream.group_info) + + # Use sender_name passed from caller for private chat, otherwise use a default for group + # Default sender_name for group chat isn't used in the group prompt template, but set for consistency + effective_sender_name = sender_name if not is_group_chat else "某人" + + message_list_before_now = get_raw_msg_before_timestamp_with_chat( + chat_id=chat_stream.stream_id, + timestamp=time.time(), + limit=global_config.chat.observation_context_size, + ) + chat_talking_prompt = await build_readable_messages( + message_list_before_now, + replace_bot_name=True, + merge_messages=True, + 
timestamp_mode="relative", + read_mark=0.0, + truncate=True, + ) + + ( + learnt_style_expressions, + learnt_grammar_expressions, + personality_expressions, + ) = await expression_learner.get_expression_by_chat_id(chat_stream.stream_id) + + style_habbits = [] + grammar_habbits = [] + # 1. learnt_expressions加权随机选3条 + if learnt_style_expressions: + weights = [expr["count"] for expr in learnt_style_expressions] + selected_learnt = weighted_sample_no_replacement(learnt_style_expressions, weights, 3) + for expr in selected_learnt: + if isinstance(expr, dict) and "situation" in expr and "style" in expr: + style_habbits.append(f"当{expr['situation']}时,使用 {expr['style']}") + # 2. learnt_grammar_expressions加权随机选3条 + if learnt_grammar_expressions: + weights = [expr["count"] for expr in learnt_grammar_expressions] + selected_learnt = weighted_sample_no_replacement(learnt_grammar_expressions, weights, 3) + for expr in selected_learnt: + if isinstance(expr, dict) and "situation" in expr and "style" in expr: + grammar_habbits.append(f"当{expr['situation']}时,使用 {expr['style']}") + # 3. 
personality_expressions随机选1条 + if personality_expressions: + expr = random.choice(personality_expressions) + if isinstance(expr, dict) and "situation" in expr and "style" in expr: + style_habbits.append(f"当{expr['situation']}时,使用 {expr['style']}") + + style_habbits_str = "\n".join(style_habbits) + grammar_habbits_str = "\n".join(grammar_habbits) + + logger.debug("开始构建 focus prompt") + + # --- Choose template based on chat type --- + if is_group_chat: + template_name = "default_expressor_prompt" + # Group specific formatting variables (already fetched or default) + chat_target_1 = await global_prompt_manager.get_prompt_async("chat_target_group1") + # chat_target_2 = await global_prompt_manager.get_prompt_async("chat_target_group2") + + prompt = await global_prompt_manager.format_prompt( + template_name, + style_habbits=style_habbits_str, + grammar_habbits=grammar_habbits_str, + chat_target=chat_target_1, + chat_info=chat_talking_prompt, + bot_name=global_config.bot.nickname, + prompt_personality="", + reason=reason, + in_mind_reply=in_mind_reply, + target_message=target_message, + ) + else: # Private chat + template_name = "default_expressor_private_prompt" + prompt = await global_prompt_manager.format_prompt( + template_name, + sender_name=effective_sender_name, # Used in private template + chat_talking_prompt=chat_talking_prompt, + bot_name=global_config.bot.nickname, + prompt_personality=prompt_personality, + reason=reason, + moderation_prompt=await global_prompt_manager.get_prompt_async("moderation_prompt"), + ) + + return prompt + # --- 发送器 (Sender) --- # - async def _send_response_messages( - self, anchor_message: Optional[MessageRecv], response_set: List[Tuple[str, str]], thinking_id: str + async def send_response_messages( + self, anchor_message: Optional[MessageRecv], response_set: List[Tuple[str, str]], thinking_id: str = "" ) -> Optional[MessageSending]: """发送回复消息 (尝试锚定到 anchor_message),使用 HeartFCSender""" chat = self.chat_stream @@ -243,7 +389,11 @@ 
class DefaultExpressor: stream_name = chat_manager.get_stream_name(chat_id) or chat_id # 获取流名称用于日志 # 检查思考过程是否仍在进行,并获取开始时间 - thinking_start_time = await self.heart_fc_sender.get_thinking_start_time(chat_id, thinking_id) + if thinking_id: + thinking_start_time = await self.heart_fc_sender.get_thinking_start_time(chat_id, thinking_id) + else: + thinking_id = "ds" + str(round(time.time(), 2)) + thinking_start_time = time.time() if thinking_start_time is None: logger.error(f"[{stream_name}]思考过程未找到或已结束,无法发送回复。") @@ -276,6 +426,7 @@ class DefaultExpressor: reply_to=reply_to, is_emoji=is_emoji, thinking_id=thinking_id, + thinking_start_time=thinking_start_time, ) try: @@ -297,6 +448,7 @@ class DefaultExpressor: except Exception as e: logger.error(f"{self.log_prefix}发送回复片段 {i} ({part_message_id}) 时失败: {e}") + traceback.print_exc() # 这里可以选择是继续发送下一个片段还是中止 # 在尝试发送完所有片段后,完成原始的 thinking_id 状态 @@ -327,10 +479,10 @@ class DefaultExpressor: reply_to: bool, is_emoji: bool, thinking_id: str, + thinking_start_time: float, ) -> MessageSending: """构建单个发送消息""" - thinking_start_time = await self.heart_fc_sender.get_thinking_start_time(self.chat_id, thinking_id) bot_user_info = UserInfo( user_id=global_config.bot.qq_account, user_nickname=global_config.bot.nickname, @@ -350,3 +502,40 @@ class DefaultExpressor: ) return bot_message + + +def weighted_sample_no_replacement(items, weights, k) -> list: + """ + 加权且不放回地随机抽取k个元素。 + + 参数: + items: 待抽取的元素列表 + weights: 每个元素对应的权重(与items等长,且为正数) + k: 需要抽取的元素个数 + 返回: + selected: 按权重加权且不重复抽取的k个元素组成的列表 + + 如果 items 中的元素不足 k 个,就只会返回所有可用的元素 + + 实现思路: + 每次从当前池中按权重加权随机选出一个元素,选中后将其从池中移除,重复k次。 + 这样保证了: + 1. count越大被选中概率越高 + 2. 
不会重复选中同一个元素 + """ + selected = [] + pool = list(zip(items, weights)) + for _ in range(min(k, len(pool))): + total = sum(w for _, w in pool) + r = random.uniform(0, total) + upto = 0 + for idx, (item, weight) in enumerate(pool): + upto += weight + if upto >= r: + selected.append(item) + pool.pop(idx) + break + return selected + + +init_prompt() diff --git a/src/chat/focus_chat/heartFC_chat.py b/src/chat/focus_chat/heartFC_chat.py index ff4f7fdb0..0f5371a36 100644 --- a/src/chat/focus_chat/heartFC_chat.py +++ b/src/chat/focus_chat/heartFC_chat.py @@ -14,15 +14,17 @@ from src.chat.focus_chat.heartFC_Cycleinfo import CycleDetail from src.chat.focus_chat.info.info_base import InfoBase from src.chat.focus_chat.info_processors.chattinginfo_processor import ChattingInfoProcessor from src.chat.focus_chat.info_processors.mind_processor import MindProcessor -from src.chat.heart_flow.observation.memory_observation import MemoryObservation +from src.chat.focus_chat.info_processors.working_memory_processor import WorkingMemoryProcessor from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation -from src.chat.heart_flow.observation.working_observation import WorkingObservation +from src.chat.heart_flow.observation.working_observation import WorkingMemoryObservation from src.chat.focus_chat.info_processors.tool_processor import ToolProcessor from src.chat.focus_chat.expressors.default_expressor import DefaultExpressor from src.chat.focus_chat.memory_activator import MemoryActivator from src.chat.focus_chat.info_processors.base_processor import BaseProcessor +from src.chat.focus_chat.info_processors.self_processor import SelfProcessor from src.chat.focus_chat.planners.planner import ActionPlanner -from src.chat.focus_chat.planners.action_factory import ActionManager +from src.chat.focus_chat.planners.action_manager import ActionManager +from src.chat.focus_chat.working_memory.working_memory import WorkingMemory install(extra_lines=3) @@ -57,7 +59,7 @@ async 
def _handle_cycle_delay(action_taken_this_cycle: bool, cycle_start_time: f class HeartFChatting: """ - 管理一个连续的Plan-Replier-Sender循环 + 管理一个连续的Focus Chat循环 用于在特定聊天流中生成回复。 其生命周期现在由其关联的 SubHeartflow 的 FOCUSED 状态控制。 """ @@ -79,18 +81,24 @@ class HeartFChatting: # 基础属性 self.stream_id: str = chat_id # 聊天流ID self.chat_stream: Optional[ChatStream] = None # 关联的聊天流 - self.observations: List[Observation] = observations # 关联的观察列表,用于监控聊天流状态 self.on_consecutive_no_reply_callback = on_consecutive_no_reply_callback self.log_prefix: str = str(chat_id) # Initial default, will be updated - - self.memory_observation = MemoryObservation(observe_id=self.stream_id) self.hfcloop_observation = HFCloopObservation(observe_id=self.stream_id) - self.working_observation = WorkingObservation(observe_id=self.stream_id) + self.chatting_observation = observations[0] + self.memory_activator = MemoryActivator() + self.working_memory = WorkingMemory(chat_id=self.stream_id) + self.working_observation = WorkingMemoryObservation( + observe_id=self.stream_id, working_memory=self.working_memory + ) + self.expressor = DefaultExpressor(chat_id=self.stream_id) self.action_manager = ActionManager() self.action_planner = ActionPlanner(log_prefix=self.log_prefix, action_manager=self.action_manager) + self.hfcloop_observation.set_action_manager(self.action_manager) + + self.all_observations = observations # --- 处理器列表 --- self.processors: List[BaseProcessor] = [] self._register_default_processors() @@ -107,9 +115,7 @@ class HeartFChatting: self._cycle_counter = 0 self._cycle_history: Deque[CycleDetail] = deque(maxlen=10) # 保留最近10个循环的信息 self._current_cycle: Optional[CycleDetail] = None - self.total_no_reply_count: int = 0 # 连续不回复计数器 self._shutting_down: bool = False # 关闭标志位 - self.total_waiting_time: float = 0.0 # 累计等待时间 async def _initialize(self) -> bool: """ @@ -150,6 +156,8 @@ class HeartFChatting: self.processors.append(ChattingInfoProcessor()) 
self.processors.append(MindProcessor(subheartflow_id=self.stream_id)) self.processors.append(ToolProcessor(subheartflow_id=self.stream_id)) + self.processors.append(WorkingMemoryProcessor(subheartflow_id=self.stream_id)) + self.processors.append(SelfProcessor(subheartflow_id=self.stream_id)) logger.info(f"{self.log_prefix} 已注册默认处理器: {[p.__class__.__name__ for p in self.processors]}") async def start(self): @@ -327,6 +335,7 @@ class HeartFChatting: f"{self.log_prefix} 处理器 {processor_name} 执行失败,耗时 (自并行开始): {duration_since_parallel_start:.2f}秒. 错误: {e}", exc_info=True, ) + traceback.print_exc() # 即使出错,也认为该任务结束了,已从 pending_tasks 中移除 if pending_tasks: @@ -348,13 +357,12 @@ class HeartFChatting: async def _observe_process_plan_action_loop(self, cycle_timers: dict, thinking_id: str) -> tuple[bool, str]: try: with Timer("观察", cycle_timers): - await self.observations[0].observe() - await self.memory_observation.observe() + # await self.observations[0].observe() + await self.chatting_observation.observe() await self.working_observation.observe() await self.hfcloop_observation.observe() observations: List[Observation] = [] - observations.append(self.observations[0]) - observations.append(self.memory_observation) + observations.append(self.chatting_observation) observations.append(self.working_observation) observations.append(self.hfcloop_observation) @@ -362,6 +370,8 @@ class HeartFChatting: "observations": observations, } + self.all_observations = observations + with Timer("回忆", cycle_timers): running_memorys = await self.memory_activator.activate_memory(observations) @@ -394,8 +404,7 @@ class HeartFChatting: elif action_type == "no_reply": action_str = "不回复" else: - action_type = "unknown" - action_str = "未知动作" + action_str = action_type logger.info(f"{self.log_prefix} 麦麦决定'{action_str}', 原因'{reasoning}'") @@ -451,14 +460,14 @@ class HeartFChatting: reasoning=reasoning, cycle_timers=cycle_timers, thinking_id=thinking_id, - observations=self.observations, + 
observations=self.all_observations, expressor=self.expressor, chat_stream=self.chat_stream, current_cycle=self._current_cycle, log_prefix=self.log_prefix, on_consecutive_no_reply_callback=self.on_consecutive_no_reply_callback, - total_no_reply_count=self.total_no_reply_count, - total_waiting_time=self.total_waiting_time, + # total_no_reply_count=self.total_no_reply_count, + # total_waiting_time=self.total_waiting_time, shutting_down=self._shutting_down, ) @@ -469,14 +478,6 @@ class HeartFChatting: # 处理动作并获取结果 success, reply_text = await action_handler.handle_action() - # 更新状态计数器 - if action == "no_reply": - self.total_no_reply_count = getattr(action_handler, "total_no_reply_count", self.total_no_reply_count) - self.total_waiting_time = getattr(action_handler, "total_waiting_time", self.total_waiting_time) - elif action == "reply": - self.total_no_reply_count = 0 - self.total_waiting_time = 0.0 - return success, reply_text except Exception as e: diff --git a/src/chat/focus_chat/heartFC_sender.py b/src/chat/focus_chat/heartFC_sender.py index 057668579..81d463b02 100644 --- a/src/chat/focus_chat/heartFC_sender.py +++ b/src/chat/focus_chat/heartFC_sender.py @@ -106,6 +106,7 @@ class HeartFCSender: and not message.is_private_message() and message.reply.processed_plain_text != "[System Trigger Context]" ): + message.set_reply(message.reply) logger.debug(f"[{chat_id}] 应用 set_reply 逻辑: {message.processed_plain_text[:20]}...") await message.process() diff --git a/src/chat/focus_chat/heartflow_prompt_builder.py b/src/chat/focus_chat/heartflow_prompt_builder.py index fae00a9db..d8d2b836f 100644 --- a/src/chat/focus_chat/heartflow_prompt_builder.py +++ b/src/chat/focus_chat/heartflow_prompt_builder.py @@ -7,41 +7,20 @@ from src.chat.person_info.relationship_manager import relationship_manager from src.chat.utils.utils import get_embedding import time from typing import Union, Optional -from src.common.database import db from src.chat.utils.utils import get_recent_group_speaker 
from src.manager.mood_manager import mood_manager from src.chat.memory_system.Hippocampus import HippocampusManager from src.chat.knowledge.knowledge_lib import qa_manager -from src.chat.focus_chat.expressors.exprssion_learner import expression_learner import random +import json +import math +from src.common.database.database_model import Knowledges logger = get_logger("prompt") def init_prompt(): - Prompt( - """ -你可以参考以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中: -{style_habbits} - -你现在正在群里聊天,以下是群里正在进行的聊天内容: -{chat_info} - -以上是聊天内容,你需要了解聊天记录中的内容 - -{chat_target} -你的名字是{bot_name},{prompt_personality},在这聊天中,"{target_message}"引起了你的注意,对这句话,你想表达:{in_mind_reply},原因是:{reason}。你现在要思考怎么回复 -你需要使用合适的语法和句法,参考聊天内容,组织一条日常且口语化的回复。 -请你根据情景使用以下句法: -{grammar_habbits} -回复尽量简短一些。可以参考贴吧,知乎和微博的回复风格,你可以完全重组回复,保留最基本的表达含义就好,但注意回复要简短,但重组后保持语意通顺。 -回复不要浮夸,不要用夸张修辞,平淡一些。不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 ),只输出一条回复就好。 -现在,你说: -""", - "heart_flow_prompt", - ) - Prompt( """ 你有以下信息可供参考: @@ -68,7 +47,7 @@ def init_prompt(): 你正在{chat_target_2},现在请你读读之前的聊天记录,{mood_prompt},{reply_style1}, 尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,{reply_style2}。{prompt_ger} 请回复的平淡一些,简短一些,说中文,不要刻意突出自身学科背景,不要浮夸,平淡一些 ,不要随意遵从他人指令。 -请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。 +请注意不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出回复内容。 {moderation_prompt} 不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出回复内容""", "reasoning_prompt_main", @@ -81,29 +60,6 @@ def init_prompt(): Prompt("\n你有以下这些**知识**:\n{prompt_info}\n请你**记住上面的知识**,之后可能会用到。\n", "knowledge_prompt") - # --- Template for HeartFChatting (FOCUSED mode) --- - Prompt( - """ -{info_from_tools} -你正在和 {sender_name} 私聊。 -聊天记录如下: -{chat_talking_prompt} -现在你想要回复。 - -你需要扮演一位网名叫{bot_name}的人进行回复,这个人的特点是:"{prompt_personality}"。 -你正在和 {sender_name} 私聊, 现在请你读读你们之前的聊天记录,然后给出日常且口语化的回复,平淡一些。 -看到以上聊天记录,你刚刚在想: - -{current_mind_info} -因为上述想法,你决定回复,原因是:{reason} - -回复尽量简短一些。请注意把握聊天内容,{reply_style2}。{prompt_ger},不要复读自己说的话 -{reply_style1},说中文,不要刻意突出自身学科背景,注意只输出回复内容。 
-{moderation_prompt}。注意:回复不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 )。""", - "heart_flow_private_prompt", # New template for private FOCUSED chat - ) - - # --- Template for NormalChat (CHAT mode) --- Prompt( """ {memory_prompt} @@ -125,118 +81,6 @@ def init_prompt(): ) -async def _build_prompt_focus( - reason, current_mind_info, structured_info, chat_stream, sender_name, in_mind_reply, target_message -) -> str: - individuality = Individuality.get_instance() - prompt_personality = individuality.get_prompt(x_person=0, level=2) - - # Determine if it's a group chat - is_group_chat = bool(chat_stream.group_info) - - # Use sender_name passed from caller for private chat, otherwise use a default for group - # Default sender_name for group chat isn't used in the group prompt template, but set for consistency - effective_sender_name = sender_name if not is_group_chat else "某人" - - message_list_before_now = get_raw_msg_before_timestamp_with_chat( - chat_id=chat_stream.stream_id, - timestamp=time.time(), - limit=global_config.chat.observation_context_size, - ) - chat_talking_prompt = await build_readable_messages( - message_list_before_now, - replace_bot_name=True, - merge_messages=True, - timestamp_mode="relative", - read_mark=0.0, - truncate=True, - ) - - if structured_info: - structured_info_prompt = await global_prompt_manager.format_prompt( - "info_from_tools", structured_info=structured_info - ) - else: - structured_info_prompt = "" - - # 从/data/expression/对应chat_id/expressions.json中读取表达方式 - ( - learnt_style_expressions, - learnt_grammar_expressions, - personality_expressions, - ) = await expression_learner.get_expression_by_chat_id(chat_stream.stream_id) - - style_habbits = [] - grammar_habbits = [] - # 1. 
learnt_expressions加权随机选3条 - if learnt_style_expressions: - weights = [expr["count"] for expr in learnt_style_expressions] - selected_learnt = weighted_sample_no_replacement(learnt_style_expressions, weights, 3) - for expr in selected_learnt: - if isinstance(expr, dict) and "situation" in expr and "style" in expr: - style_habbits.append(f"当{expr['situation']}时,使用 {expr['style']}") - # 2. learnt_grammar_expressions加权随机选3条 - if learnt_grammar_expressions: - weights = [expr["count"] for expr in learnt_grammar_expressions] - selected_learnt = weighted_sample_no_replacement(learnt_grammar_expressions, weights, 3) - for expr in selected_learnt: - if isinstance(expr, dict) and "situation" in expr and "style" in expr: - grammar_habbits.append(f"当{expr['situation']}时,使用 {expr['style']}") - # 3. personality_expressions随机选1条 - if personality_expressions: - expr = random.choice(personality_expressions) - if isinstance(expr, dict) and "situation" in expr and "style" in expr: - style_habbits.append(f"当{expr['situation']}时,使用 {expr['style']}") - - style_habbits_str = "\n".join(style_habbits) - grammar_habbits_str = "\n".join(grammar_habbits) - - logger.debug("开始构建 focus prompt") - - # --- Choose template based on chat type --- - if is_group_chat: - template_name = "heart_flow_prompt" - # Group specific formatting variables (already fetched or default) - chat_target_1 = await global_prompt_manager.get_prompt_async("chat_target_group1") - # chat_target_2 = await global_prompt_manager.get_prompt_async("chat_target_group2") - - prompt = await global_prompt_manager.format_prompt( - template_name, - # info_from_tools=structured_info_prompt, - style_habbits=style_habbits_str, - grammar_habbits=grammar_habbits_str, - chat_target=chat_target_1, # Used in group template - # chat_talking_prompt=chat_talking_prompt, - chat_info=chat_talking_prompt, - bot_name=global_config.bot.nickname, - # prompt_personality=prompt_personality, - prompt_personality="", - reason=reason, - 
in_mind_reply=in_mind_reply, - target_message=target_message, - # moderation_prompt=await global_prompt_manager.get_prompt_async("moderation_prompt"), - # sender_name is not used in the group template - ) - else: # Private chat - template_name = "heart_flow_private_prompt" - prompt = await global_prompt_manager.format_prompt( - template_name, - info_from_tools=structured_info_prompt, - sender_name=effective_sender_name, # Used in private template - chat_talking_prompt=chat_talking_prompt, - bot_name=global_config.bot.nickname, - prompt_personality=prompt_personality, - # chat_target and chat_target_2 are not used in private template - current_mind_info=current_mind_info, - reason=reason, - moderation_prompt=await global_prompt_manager.get_prompt_async("moderation_prompt"), - ) - # --- End choosing template --- - - # logger.debug(f"focus_chat_prompt (is_group={is_group_chat}): \n{prompt}") - return prompt - - class PromptBuilder: def __init__(self): self.prompt_built = "" @@ -256,17 +100,6 @@ class PromptBuilder: ) -> Optional[str]: if build_mode == "normal": return await self._build_prompt_normal(chat_stream, message_txt or "", sender_name) - - elif build_mode == "focus": - return await _build_prompt_focus( - reason, - current_mind_info, - structured_info, - chat_stream, - sender_name, - in_mind_reply, - target_message, - ) return None async def _build_prompt_normal(self, chat_stream, message_txt: str, sender_name: str = "某人") -> str: @@ -435,30 +268,6 @@ class PromptBuilder: logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}") # 1. 
先从LLM获取主题,类似于记忆系统的做法 topics = [] - # try: - # # 先尝试使用记忆系统的方法获取主题 - # hippocampus = HippocampusManager.get_instance()._hippocampus - # topic_num = min(5, max(1, int(len(message) * 0.1))) - # topics_response = await hippocampus.llm_topic_judge.generate_response(hippocampus.find_topic_llm(message, topic_num)) - - # # 提取关键词 - # topics = re.findall(r"<([^>]+)>", topics_response[0]) - # if not topics: - # topics = [] - # else: - # topics = [ - # topic.strip() - # for topic in ",".join(topics).replace(",", ",").replace("、", ",").replace(" ", ",").split(",") - # if topic.strip() - # ] - - # logger.info(f"从LLM提取的主题: {', '.join(topics)}") - # except Exception as e: - # logger.error(f"从LLM提取主题失败: {str(e)}") - # # 如果LLM提取失败,使用jieba分词提取关键词作为备选 - # words = jieba.cut(message) - # topics = [word for word in words if len(word) > 1][:5] - # logger.info(f"使用jieba提取的主题: {', '.join(topics)}") # 如果无法提取到主题,直接使用整个消息 if not topics: @@ -568,8 +377,6 @@ class PromptBuilder: for _i, result in enumerate(results, 1): _similarity = result["similarity"] content = result["content"].strip() - # 调试:为内容添加序号和相似度信息 - # related_info += f"{i}. 
[{similarity:.2f}] {content}\n" related_info += f"{content}\n" related_info += "\n" @@ -598,14 +405,14 @@ class PromptBuilder: return related_info else: logger.debug("从LPMM知识库获取知识失败,使用旧版数据库进行检索") - knowledge_from_old = await self.get_prompt_info_old(message, threshold=0.38) + knowledge_from_old = await self.get_prompt_info_old(message, threshold=threshold) related_info += knowledge_from_old logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}") return related_info except Exception as e: logger.error(f"获取知识库内容时发生异常: {str(e)}") try: - knowledge_from_old = await self.get_prompt_info_old(message, threshold=0.38) + knowledge_from_old = await self.get_prompt_info_old(message, threshold=threshold) related_info += knowledge_from_old logger.debug( f"异常后使用旧版数据库获取知识,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}" @@ -621,104 +428,70 @@ class PromptBuilder: ) -> Union[str, list]: if not query_embedding: return "" if not return_raw else [] - # 使用余弦相似度计算 - pipeline = [ - { - "$addFields": { - "dotProduct": { - "$reduce": { - "input": {"$range": [0, {"$size": "$embedding"}]}, - "initialValue": 0, - "in": { - "$add": [ - "$$value", - { - "$multiply": [ - {"$arrayElemAt": ["$embedding", "$$this"]}, - {"$arrayElemAt": [query_embedding, "$$this"]}, - ] - }, - ] - }, - } - }, - "magnitude1": { - "$sqrt": { - "$reduce": { - "input": "$embedding", - "initialValue": 0, - "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}, - } - } - }, - "magnitude2": { - "$sqrt": { - "$reduce": { - "input": query_embedding, - "initialValue": 0, - "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}, - } - } - }, - } - }, - {"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}}, - { - "$match": { - "similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果 - } - }, - {"$sort": {"similarity": -1}}, - {"$limit": limit}, - {"$project": {"content": 1, "similarity": 1}}, - ] - results = 
list(db.knowledges.aggregate(pipeline)) - logger.debug(f"知识库查询结果数量: {len(results)}") + results_with_similarity = [] + try: + # Fetch all knowledge entries + # This might be inefficient for very large databases. + # Consider strategies like FAISS or other vector search libraries if performance becomes an issue. + all_knowledges = Knowledges.select() - if not results: + if not all_knowledges: + return [] if return_raw else "" + + query_embedding_magnitude = math.sqrt(sum(x * x for x in query_embedding)) + if query_embedding_magnitude == 0: # Avoid division by zero + return "" if not return_raw else [] + + for knowledge_item in all_knowledges: + try: + db_embedding_str = knowledge_item.embedding + db_embedding = json.loads(db_embedding_str) + + if len(db_embedding) != len(query_embedding): + logger.warning( + f"Embedding length mismatch for knowledge ID {knowledge_item.id if hasattr(knowledge_item, 'id') else 'N/A'}. Skipping." + ) + continue + + # Calculate Cosine Similarity + dot_product = sum(q * d for q, d in zip(query_embedding, db_embedding)) + db_embedding_magnitude = math.sqrt(sum(x * x for x in db_embedding)) + + if db_embedding_magnitude == 0: # Avoid division by zero + similarity = 0.0 + else: + similarity = dot_product / (query_embedding_magnitude * db_embedding_magnitude) + + if similarity >= threshold: + results_with_similarity.append({"content": knowledge_item.content, "similarity": similarity}) + except json.JSONDecodeError: + logger.error( + f"Failed to parse embedding for knowledge ID {knowledge_item.id if hasattr(knowledge_item, 'id') else 'N/A'}" + ) + except Exception as e: + logger.error(f"Error processing knowledge item: {e}") + + # Sort by similarity in descending order + results_with_similarity.sort(key=lambda x: x["similarity"], reverse=True) + + # Limit results + limited_results = results_with_similarity[:limit] + + logger.debug(f"知识库查询结果数量 (after Peewee processing): {len(limited_results)}") + + if not limited_results: + return "" if not 
return_raw else [] + + if return_raw: + return limited_results + else: + return "\n".join(str(result["content"]) for result in limited_results) + + except Exception as e: + logger.error(f"Error querying Knowledges with Peewee: {e}") return "" if not return_raw else [] - if return_raw: - return results - else: - # 返回所有找到的内容,用换行分隔 - return "\n".join(str(result["content"]) for result in results) - - -def weighted_sample_no_replacement(items, weights, k) -> list: - """ - 加权且不放回地随机抽取k个元素。 - - 参数: - items: 待抽取的元素列表 - weights: 每个元素对应的权重(与items等长,且为正数) - k: 需要抽取的元素个数 - 返回: - selected: 按权重加权且不重复抽取的k个元素组成的列表 - - 如果 items 中的元素不足 k 个,就只会返回所有可用的元素 - - 实现思路: - 每次从当前池中按权重加权随机选出一个元素,选中后将其从池中移除,重复k次。 - 这样保证了: - 1. count越大被选中概率越高 - 2. 不会重复选中同一个元素 - """ - selected = [] - pool = list(zip(items, weights)) - for _ in range(min(k, len(pool))): - total = sum(w for _, w in pool) - r = random.uniform(0, total) - upto = 0 - for idx, (item, weight) in enumerate(pool): - upto += weight - if upto >= r: - selected.append(item) - pool.pop(idx) - break - return selected - init_prompt() prompt_builder = PromptBuilder() diff --git a/src/chat/focus_chat/info/info_base.py b/src/chat/focus_chat/info/info_base.py index 7779d913a..53ad30230 100644 --- a/src/chat/focus_chat/info/info_base.py +++ b/src/chat/focus_chat/info/info_base.py @@ -17,6 +17,7 @@ class InfoBase: type: str = "base" data: Dict[str, Any] = field(default_factory=dict) + processed_info: str = "" def get_type(self) -> str: """获取信息类型 @@ -58,3 +59,11 @@ class InfoBase: if isinstance(value, list): return value return [] + + def get_processed_info(self) -> str: + """获取处理后的信息 + + Returns: + str: 处理后的信息字符串 + """ + return self.processed_info diff --git a/src/chat/focus_chat/info/self_info.py b/src/chat/focus_chat/info/self_info.py new file mode 100644 index 000000000..866457956 --- /dev/null +++ b/src/chat/focus_chat/info/self_info.py @@ -0,0 +1,40 @@ +from dataclasses import dataclass +from .info_base import InfoBase + + +@dataclass +class 
SelfInfo(InfoBase): + """思维信息类 + + 用于存储和管理当前思维状态的信息。 + + Attributes: + type (str): 信息类型标识符,默认为 "mind" + data (Dict[str, Any]): 包含 current_mind 的数据字典 + """ + + type: str = "self" + + def get_self_info(self) -> str: + """获取当前思维状态 + + Returns: + str: 当前思维状态 + """ + return self.get_info("self_info") or "" + + def set_self_info(self, self_info: str) -> None: + """设置当前思维状态 + + Args: + self_info: 要设置的思维状态 + """ + self.data["self_info"] = self_info + + def get_processed_info(self) -> str: + """获取处理后的信息 + + Returns: + str: 处理后的信息 + """ + return self.get_self_info() diff --git a/src/chat/focus_chat/info/workingmemory_info.py b/src/chat/focus_chat/info/workingmemory_info.py new file mode 100644 index 000000000..0edce8944 --- /dev/null +++ b/src/chat/focus_chat/info/workingmemory_info.py @@ -0,0 +1,89 @@ +from typing import Dict, Optional, List +from dataclasses import dataclass +from .info_base import InfoBase + + +@dataclass +class WorkingMemoryInfo(InfoBase): + type: str = "workingmemory" + + processed_info: str = "" + + def set_talking_message(self, message: str) -> None: + """设置说话消息 + + Args: + message (str): 说话消息内容 + """ + self.data["talking_message"] = message + + def set_working_memory(self, working_memory: List[str]) -> None: + """设置工作记忆 + + Args: + working_memory (str): 工作记忆内容 + """ + self.data["working_memory"] = working_memory + + def add_working_memory(self, working_memory: str) -> None: + """添加工作记忆 + + Args: + working_memory (str): 工作记忆内容 + """ + working_memory_list = self.data.get("working_memory", []) + # print(f"working_memory_list: {working_memory_list}") + working_memory_list.append(working_memory) + # print(f"working_memory_list: {working_memory_list}") + self.data["working_memory"] = working_memory_list + + def get_working_memory(self) -> List[str]: + """获取工作记忆 + + Returns: + List[str]: 工作记忆内容 + """ + return self.data.get("working_memory", []) + + def get_type(self) -> str: + """获取信息类型 + + Returns: + str: 当前信息对象的类型标识符 + """ + return self.type + + def 
get_data(self) -> Dict[str, str]: + """获取所有信息数据 + + Returns: + Dict[str, str]: 包含所有信息数据的字典 + """ + return self.data + + def get_info(self, key: str) -> Optional[str]: + """获取特定属性的信息 + + Args: + key: 要获取的属性键名 + + Returns: + Optional[str]: 属性值,如果键不存在则返回 None + """ + return self.data.get(key) + + def get_processed_info(self) -> Dict[str, str]: + """获取处理后的信息 + + Returns: + Dict[str, str]: 处理后的信息数据 + """ + all_memory = self.get_working_memory() + # print(f"all_memory: {all_memory}") + memory_str = "" + for memory in all_memory: + memory_str += f"{memory}\n" + + self.processed_info = memory_str + + return self.processed_info diff --git a/src/chat/focus_chat/info_processors/chattinginfo_processor.py b/src/chat/focus_chat/info_processors/chattinginfo_processor.py index bb70c043a..c9641b9b7 100644 --- a/src/chat/focus_chat/info_processors/chattinginfo_processor.py +++ b/src/chat/focus_chat/info_processors/chattinginfo_processor.py @@ -27,7 +27,7 @@ class ChattingInfoProcessor(BaseProcessor): """初始化观察处理器""" super().__init__() # TODO: API-Adapter修改标记 - self.llm_summary = LLMRequest( + self.model_summary = LLMRequest( model=global_config.model.observation, temperature=0.7, max_tokens=300, request_type="chat_observation" ) @@ -55,6 +55,8 @@ class ChattingInfoProcessor(BaseProcessor): for obs in observations: # print(f"obs: {obs}") if isinstance(obs, ChattingObservation): + # print("1111111111111111111111读取111111111111111") + obs_info = ObsInfo() await self.chat_compress(obs) @@ -92,7 +94,7 @@ class ChattingInfoProcessor(BaseProcessor): async def chat_compress(self, obs: ChattingObservation): if obs.compressor_prompt: try: - summary_result, _, _ = await self.llm_summary.generate_response(obs.compressor_prompt) + summary_result, _, _ = await self.model_summary.generate_response(obs.compressor_prompt) summary = "没有主题的闲聊" # 默认值 if summary_result: # 确保结果不为空 summary = summary_result diff --git a/src/chat/focus_chat/info_processors/mind_processor.py 
b/src/chat/focus_chat/info_processors/mind_processor.py index 221935e3d..afd7921d4 100644 --- a/src/chat/focus_chat/info_processors/mind_processor.py +++ b/src/chat/focus_chat/info_processors/mind_processor.py @@ -6,21 +6,14 @@ import time import traceback from src.common.logger_manager import get_logger from src.individuality.individuality import Individuality -import random from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.json_utils import safe_json_dumps from src.chat.message_receive.chat_stream import chat_manager -import difflib from src.chat.person_info.relationship_manager import relationship_manager from .base_processor import BaseProcessor from src.chat.focus_chat.info.mind_info import MindInfo from typing import List, Optional from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation -from src.chat.focus_chat.info_processors.processor_utils import ( - calculate_similarity, - calculate_replacement_probability, - get_spark, -) from typing import Dict from src.chat.focus_chat.info.info_base import InfoBase @@ -28,7 +21,6 @@ logger = get_logger("processor") def init_prompt(): - # --- Group Chat Prompt --- group_prompt = """ 你的名字是{bot_name} {memory_str} @@ -44,31 +36,29 @@ def init_prompt(): 现在请你继续输出观察和规划,输出要求: 1. 先关注未读新消息的内容和近期回复历史 2. 根据新信息,修改和删除之前的观察和规划 -3. 根据聊天内容继续输出观察和规划,{hf_do_next} +3. 根据聊天内容继续输出观察和规划 4. 注意群聊的时间线索,话题由谁发起,进展状况如何,思考聊天的时间线。 6. 
语言简洁自然,不要分点,不要浮夸,不要修辞,仅输出思考内容就好""" Prompt(group_prompt, "sub_heartflow_prompt_before") - # --- Private Chat Prompt --- private_prompt = """ +你的名字是{bot_name} {memory_str} {extra_info} {relation_prompt} -你的名字是{bot_name},{prompt_personality},你现在{mood_info} {cycle_info_block} -现在是{time_now},你正在上网,和 {chat_target_name} 私聊,以下是你们的聊天内容: +现在是{time_now},你正在上网,和qq群里的网友们聊天,以下是正在进行的聊天内容: {chat_observe_info} -以下是你之前对聊天的观察和规划: + +以下是你之前对聊天的观察和规划,你的名字是{bot_name}: {last_mind} -请仔细阅读聊天内容,想想你和 {chat_target_name} 的关系,回顾你们刚刚的交流,你刚刚发言和对方的反应,思考聊天的主题。 -请思考你要不要回复以及如何回复对方。 -思考并输出你的内心想法 -输出要求: -1. 根据聊天内容生成你的想法,{hf_do_next} -2. 不要分点、不要使用表情符号 -3. 避免多余符号(冒号、引号、括号等) -4. 语言简洁自然,不要浮夸 -5. 如果你刚发言,对方没有回复你,请谨慎回复""" + +现在请你继续输出观察和规划,输出要求: +1. 先关注未读新消息的内容和近期回复历史 +2. 根据新信息,修改和删除之前的观察和规划 +3. 根据聊天内容继续输出观察和规划 +4. 注意群聊的时间线索,话题由谁发起,进展状况如何,思考聊天的时间线。 +6. 语言简洁自然,不要分点,不要浮夸,不要修辞,仅输出思考内容就好""" Prompt(private_prompt, "sub_heartflow_prompt_private_before") @@ -210,45 +200,26 @@ class MindProcessor(BaseProcessor): for person in person_list: relation_prompt += await relationship_manager.build_relationship_info(person, is_id=True) - # 构建个性部分 - # prompt_personality = individuality.get_prompt(x_person=2, level=2) - - # 获取当前时间 - time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) - - spark_prompt = get_spark() - - # ---------- 5. 
构建最终提示词 ---------- template_name = "sub_heartflow_prompt_before" if is_group_chat else "sub_heartflow_prompt_private_before" logger.debug(f"{self.log_prefix} 使用{'群聊' if is_group_chat else '私聊'}思考模板") prompt = (await global_prompt_manager.get_prompt_async(template_name)).format( + bot_name=individuality.name, memory_str=memory_str, extra_info=self.structured_info_str, - # prompt_personality=prompt_personality, relation_prompt=relation_prompt, - bot_name=individuality.name, - time_now=time_now, + time_now=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), chat_observe_info=chat_observe_info, - # mood_info="mood_info", - hf_do_next=spark_prompt, last_mind=previous_mind, cycle_info_block=hfcloop_observe_info, chat_target_name=chat_target_name, ) - # 在构建完提示词后,生成最终的prompt字符串 - final_prompt = prompt - - content = "" # 初始化内容变量 - + content = "(不知道该想些什么...)" try: - # 调用LLM生成响应 - response, _ = await self.llm_model.generate_response_async(prompt=final_prompt) - - # 直接使用LLM返回的文本响应作为 content - content = response if response else "" - + content, _ = await self.llm_model.generate_response_async(prompt=prompt) + if not content: + logger.warning(f"{self.log_prefix} LLM返回空结果,思考失败。") except Exception as e: # 处理总体异常 logger.error(f"{self.log_prefix} 执行LLM请求或处理响应时出错: {e}") @@ -256,16 +227,8 @@ class MindProcessor(BaseProcessor): content = "思考过程中出现错误" # 记录初步思考结果 - logger.debug(f"{self.log_prefix} 思考prompt: \n{final_prompt}\n") - - # 处理空响应情况 - if not content: - content = "(不知道该想些什么...)" - logger.warning(f"{self.log_prefix} LLM返回空结果,思考失败。") - - # ---------- 8. 
更新思考状态并返回结果 ---------- + logger.debug(f"{self.log_prefix} 思考prompt: \n{prompt}\n") logger.info(f"{self.log_prefix} 思考结果: {content}") - # 更新当前思考内容 self.update_current_mind(content) return content @@ -275,138 +238,5 @@ class MindProcessor(BaseProcessor): self.past_mind.append(self.current_mind) self.current_mind = response - def de_similar(self, previous_mind, new_content): - try: - similarity = calculate_similarity(previous_mind, new_content) - replacement_prob = calculate_replacement_probability(similarity) - logger.debug(f"{self.log_prefix} 新旧想法相似度: {similarity:.2f}, 替换概率: {replacement_prob:.2f}") - - # 定义词语列表 (移到判断之前) - yu_qi_ci_liebiao = ["嗯", "哦", "啊", "唉", "哈", "唔"] - zhuan_zhe_liebiao = ["但是", "不过", "然而", "可是", "只是"] - cheng_jie_liebiao = ["然后", "接着", "此外", "而且", "另外"] - zhuan_jie_ci_liebiao = zhuan_zhe_liebiao + cheng_jie_liebiao - - if random.random() < replacement_prob: - # 相似度非常高时,尝试去重或特殊处理 - if similarity == 1.0: - logger.debug(f"{self.log_prefix} 想法完全重复 (相似度 1.0),执行特殊处理...") - # 随机截取大约一半内容 - if len(new_content) > 1: # 避免内容过短无法截取 - split_point = max( - 1, len(new_content) // 2 + random.randint(-len(new_content) // 4, len(new_content) // 4) - ) - truncated_content = new_content[:split_point] - else: - truncated_content = new_content # 如果只有一个字符或者为空,就不截取了 - - # 添加语气词和转折/承接词 - yu_qi_ci = random.choice(yu_qi_ci_liebiao) - zhuan_jie_ci = random.choice(zhuan_jie_ci_liebiao) - content = f"{yu_qi_ci}{zhuan_jie_ci},{truncated_content}" - logger.debug(f"{self.log_prefix} 想法重复,特殊处理后: {content}") - - else: - # 相似度较高但非100%,执行标准去重逻辑 - logger.debug(f"{self.log_prefix} 执行概率性去重 (概率: {replacement_prob:.2f})...") - logger.debug( - f"{self.log_prefix} previous_mind类型: {type(previous_mind)}, new_content类型: {type(new_content)}" - ) - - matcher = difflib.SequenceMatcher(None, previous_mind, new_content) - logger.debug(f"{self.log_prefix} matcher类型: {type(matcher)}") - - deduplicated_parts = [] - last_match_end_in_b = 0 - - # 获取并记录所有匹配块 - matching_blocks = 
matcher.get_matching_blocks() - logger.debug(f"{self.log_prefix} 匹配块数量: {len(matching_blocks)}") - logger.debug( - f"{self.log_prefix} 匹配块示例(前3个): {matching_blocks[:3] if len(matching_blocks) > 3 else matching_blocks}" - ) - - # get_matching_blocks()返回形如[(i, j, n), ...]的列表,其中i是a中的索引,j是b中的索引,n是匹配的长度 - for idx, match in enumerate(matching_blocks): - if not isinstance(match, tuple): - logger.error(f"{self.log_prefix} 匹配块 {idx} 不是元组类型,而是 {type(match)}: {match}") - continue - - try: - _i, j, n = match # 解包元组为三个变量 - logger.debug(f"{self.log_prefix} 匹配块 {idx}: i={_i}, j={j}, n={n}") - - if last_match_end_in_b < j: - # 确保添加的是字符串,而不是元组 - try: - non_matching_part = new_content[last_match_end_in_b:j] - logger.debug( - f"{self.log_prefix} 添加非匹配部分: '{non_matching_part}', 类型: {type(non_matching_part)}" - ) - if not isinstance(non_matching_part, str): - logger.warning( - f"{self.log_prefix} 非匹配部分不是字符串类型: {type(non_matching_part)}" - ) - non_matching_part = str(non_matching_part) - deduplicated_parts.append(non_matching_part) - except Exception as e: - logger.error(f"{self.log_prefix} 处理非匹配部分时出错: {e}") - logger.error(traceback.format_exc()) - last_match_end_in_b = j + n - except Exception as e: - logger.error(f"{self.log_prefix} 处理匹配块时出错: {e}") - logger.error(traceback.format_exc()) - - logger.debug(f"{self.log_prefix} 去重前部分列表: {deduplicated_parts}") - logger.debug(f"{self.log_prefix} 列表元素类型: {[type(part) for part in deduplicated_parts]}") - - # 确保所有元素都是字符串 - deduplicated_parts = [str(part) for part in deduplicated_parts] - - # 防止列表为空 - if not deduplicated_parts: - logger.warning(f"{self.log_prefix} 去重后列表为空,添加空字符串") - deduplicated_parts = [""] - - logger.debug(f"{self.log_prefix} 处理后的部分列表: {deduplicated_parts}") - - try: - deduplicated_content = "".join(deduplicated_parts).strip() - logger.debug(f"{self.log_prefix} 拼接后的去重内容: '{deduplicated_content}'") - except Exception as e: - logger.error(f"{self.log_prefix} 拼接去重内容时出错: {e}") - logger.error(traceback.format_exc()) - 
deduplicated_content = "" - - if deduplicated_content: - # 根据概率决定是否添加词语 - prefix_str = "" - if random.random() < 0.3: # 30% 概率添加语气词 - prefix_str += random.choice(yu_qi_ci_liebiao) - if random.random() < 0.7: # 70% 概率添加转折/承接词 - prefix_str += random.choice(zhuan_jie_ci_liebiao) - - # 组合最终结果 - if prefix_str: - content = f"{prefix_str},{deduplicated_content}" # 更新 content - logger.debug(f"{self.log_prefix} 去重并添加引导词后: {content}") - else: - content = deduplicated_content # 更新 content - logger.debug(f"{self.log_prefix} 去重后 (未添加引导词): {content}") - else: - logger.warning(f"{self.log_prefix} 去重后内容为空,保留原始LLM输出: {new_content}") - content = new_content # 保留原始 content - else: - logger.debug(f"{self.log_prefix} 未执行概率性去重 (概率: {replacement_prob:.2f})") - # content 保持 new_content 不变 - - except Exception as e: - logger.error(f"{self.log_prefix} 应用概率性去重或特殊处理时出错: {e}") - logger.error(traceback.format_exc()) - # 出错时保留原始 content - content = new_content - - return content - init_prompt() diff --git a/src/chat/focus_chat/info_processors/processor_utils.py b/src/chat/focus_chat/info_processors/processor_utils.py deleted file mode 100644 index 77cdc7a6b..000000000 --- a/src/chat/focus_chat/info_processors/processor_utils.py +++ /dev/null @@ -1,56 +0,0 @@ -import difflib -import random -import time - - -def calculate_similarity(text_a: str, text_b: str) -> float: - """ - 计算两个文本字符串的相似度。 - """ - if not text_a or not text_b: - return 0.0 - matcher = difflib.SequenceMatcher(None, text_a, text_b) - return matcher.ratio() - - -def calculate_replacement_probability(similarity: float) -> float: - """ - 根据相似度计算替换的概率。 - 规则: - - 相似度 <= 0.4: 概率 = 0 - - 相似度 >= 0.9: 概率 = 1 - - 相似度 == 0.6: 概率 = 0.7 - - 0.4 < 相似度 <= 0.6: 线性插值 (0.4, 0) 到 (0.6, 0.7) - - 0.6 < 相似度 < 0.9: 线性插值 (0.6, 0.7) 到 (0.9, 1.0) - """ - if similarity <= 0.4: - return 0.0 - elif similarity >= 0.9: - return 1.0 - elif 0.4 < similarity <= 0.6: - # p = 3.5 * s - 1.4 - probability = 3.5 * similarity - 1.4 - return max(0.0, probability) - else: # 
0.6 < similarity < 0.9 - # p = s + 0.1 - probability = similarity + 0.1 - return min(1.0, max(0.0, probability)) - - -def get_spark(): - local_random = random.Random() - current_minute = int(time.strftime("%M")) - local_random.seed(current_minute) - - hf_options = [ - ("可以参考之前的想法,在原来想法的基础上继续思考", 0.2), - ("可以参考之前的想法,在原来的想法上尝试新的话题", 0.4), - ("不要太深入", 0.2), - ("进行深入思考", 0.2), - ] - # 加权随机选择思考指导 - hf_do_next = local_random.choices( - [option[0] for option in hf_options], weights=[option[1] for option in hf_options], k=1 - )[0] - - return hf_do_next diff --git a/src/chat/focus_chat/info_processors/self_processor.py b/src/chat/focus_chat/info_processors/self_processor.py new file mode 100644 index 000000000..5114e49b6 --- /dev/null +++ b/src/chat/focus_chat/info_processors/self_processor.py @@ -0,0 +1,161 @@ +from src.chat.heart_flow.observation.chatting_observation import ChattingObservation +from src.chat.heart_flow.observation.observation import Observation +from src.chat.models.utils_model import LLMRequest +from src.config.config import global_config +import time +import traceback +from src.common.logger_manager import get_logger +from src.individuality.individuality import Individuality +from src.chat.utils.prompt_builder import Prompt, global_prompt_manager +from src.chat.message_receive.chat_stream import chat_manager +from src.chat.person_info.relationship_manager import relationship_manager +from .base_processor import BaseProcessor +from typing import List, Optional +from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation +from typing import Dict +from src.chat.focus_chat.info.info_base import InfoBase +from src.chat.focus_chat.info.self_info import SelfInfo + +logger = get_logger("processor") + + +def init_prompt(): + indentify_prompt = """ +你的名字是{bot_name},你是一个AI,但是你伪装成了一个人类,你的人格是,{prompt_personality}。 +你对外的形象是一只橙色的鱼,头上有绿色的树叶,你用的头像也是这个。 + +{relation_prompt} +{memory_str} + +现在是{time_now},你正在上网,和qq群里的网友们聊天,以下是正在进行的聊天内容: 
+{chat_observe_info} + +现在请你根据现有的信息,思考自我认同 +1. 你是一个什么样的人,你和群里的人关系如何 +2. 思考有没有人提到你,或者图片与你有关 +3. 你的自我认同是否有助于你的回答,如果你需要自我相关的信息来帮你参与聊天,请输出,否则请输出十个字以内的简短自我认同 +4. 一般情况下不用输出自我认同,只需要输出十几个字的简短自我认同就好,除非有明显需要自我认同的场景 + +""" + Prompt(indentify_prompt, "indentify_prompt") + + +class SelfProcessor(BaseProcessor): + log_prefix = "自我认同" + + def __init__(self, subheartflow_id: str): + super().__init__() + + self.subheartflow_id = subheartflow_id + + self.llm_model = LLMRequest( + model=global_config.model.sub_heartflow, + temperature=global_config.model.sub_heartflow["temp"], + max_tokens=800, + request_type="self_identify", + ) + + name = chat_manager.get_stream_name(self.subheartflow_id) + self.log_prefix = f"[{name}] " + + async def process_info( + self, observations: Optional[List[Observation]] = None, running_memorys: Optional[List[Dict]] = None, *infos + ) -> List[InfoBase]: + """处理信息对象 + + Args: + *infos: 可变数量的InfoBase类型的信息对象 + + Returns: + List[InfoBase]: 处理后的结构化信息列表 + """ + self_info_str = await self.self_indentify(observations, running_memorys) + + if self_info_str: + self_info = SelfInfo() + self_info.set_self_info(self_info_str) + else: + self_info = None + return None + + return [self_info] + + async def self_indentify( + self, observations: Optional[List[Observation]] = None, running_memorys: Optional[List[Dict]] = None + ): + """ + 在回复前进行思考,生成内心想法并收集工具调用结果 + + 参数: + observations: 观察信息 + + 返回: + 如果return_prompt为False: + tuple: (current_mind, past_mind) 当前想法和过去的想法列表 + 如果return_prompt为True: + tuple: (current_mind, past_mind, prompt) 当前想法、过去的想法列表和使用的prompt + """ + + memory_str = "" + if running_memorys: + memory_str = "以下是当前在聊天中,你回忆起的记忆:\n" + for running_memory in running_memorys: + memory_str += f"{running_memory['topic']}: {running_memory['content']}\n" + + if observations is None: + observations = [] + for observation in observations: + if isinstance(observation, ChattingObservation): + # 获取聊天元信息 + is_group_chat = observation.is_group_chat + chat_target_info = 
observation.chat_target_info + chat_target_name = "对方" # 私聊默认名称 + if not is_group_chat and chat_target_info: + # 优先使用person_name,其次user_nickname,最后回退到默认值 + chat_target_name = ( + chat_target_info.get("person_name") or chat_target_info.get("user_nickname") or chat_target_name + ) + # 获取聊天内容 + chat_observe_info = observation.get_observe_info() + person_list = observation.person_list + if isinstance(observation, HFCloopObservation): + # hfcloop_observe_info = observation.get_observe_info() + pass + + individuality = Individuality.get_instance() + personality_block = individuality.get_prompt(x_person=2, level=2) + + relation_prompt = "" + for person in person_list: + relation_prompt += await relationship_manager.build_relationship_info(person, is_id=True) + + prompt = (await global_prompt_manager.get_prompt_async("indentify_prompt")).format( + bot_name=individuality.name, + prompt_personality=personality_block, + memory_str=memory_str, + relation_prompt=relation_prompt, + time_now=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), + chat_observe_info=chat_observe_info, + ) + + content = "" + try: + content, _ = await self.llm_model.generate_response_async(prompt=prompt) + if not content: + logger.warning(f"{self.log_prefix} LLM返回空结果,自我识别失败。") + except Exception as e: + # 处理总体异常 + logger.error(f"{self.log_prefix} 执行LLM请求或处理响应时出错: {e}") + logger.error(traceback.format_exc()) + content = "自我识别过程中出现错误" + + if content == "None": + content = "" + # 记录初步思考结果 + logger.debug(f"{self.log_prefix} 自我识别prompt: \n{prompt}\n") + logger.info(f"{self.log_prefix} 自我识别结果: {content}") + + return content + + +init_prompt() diff --git a/src/chat/focus_chat/info_processors/tool_processor.py b/src/chat/focus_chat/info_processors/tool_processor.py index 57bac5f79..de9a9a216 100644 --- a/src/chat/focus_chat/info_processors/tool_processor.py +++ b/src/chat/focus_chat/info_processors/tool_processor.py @@ -11,8 +11,8 @@ from src.chat.person_info.relationship_manager import relationship_manager 
from .base_processor import BaseProcessor from typing import List, Optional, Dict from src.chat.heart_flow.observation.observation import Observation -from src.chat.heart_flow.observation.working_observation import WorkingObservation from src.chat.focus_chat.info.structured_info import StructuredInfo +from src.chat.heart_flow.observation.structure_observation import StructureObservation logger = get_logger("processor") @@ -24,9 +24,6 @@ def init_prompt(): tool_executor_prompt = """ 你是一个专门执行工具的助手。你的名字是{bot_name}。现在是{time_now}。 -你要在群聊中扮演以下角色: -{prompt_personality} - 你当前的额外信息: {memory_str} @@ -70,6 +67,8 @@ class ToolProcessor(BaseProcessor): list: 处理后的结构化信息列表 """ + working_infos = [] + if observations: for observation in observations: if isinstance(observation, ChattingObservation): @@ -77,7 +76,7 @@ class ToolProcessor(BaseProcessor): # 更新WorkingObservation中的结构化信息 for observation in observations: - if isinstance(observation, WorkingObservation): + if isinstance(observation, StructureObservation): for structured_info in result: logger.debug(f"{self.log_prefix} 更新WorkingObservation中的结构化信息: {structured_info}") observation.add_structured_info(structured_info) @@ -86,8 +85,9 @@ class ToolProcessor(BaseProcessor): logger.debug(f"{self.log_prefix} 获取更新后WorkingObservation中的结构化信息: {working_infos}") structured_info = StructuredInfo() - for working_info in working_infos: - structured_info.set_info(working_info.get("type"), working_info.get("content")) + if working_infos: + for working_info in working_infos: + structured_info.set_info(working_info.get("type"), working_info.get("content")) return [structured_info] @@ -134,7 +134,7 @@ class ToolProcessor(BaseProcessor): # 获取个性信息 individuality = Individuality.get_instance() - prompt_personality = individuality.get_prompt(x_person=2, level=2) + # prompt_personality = individuality.get_prompt(x_person=2, level=2) # 获取时间信息 time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) @@ -148,14 +148,14 @@ class 
ToolProcessor(BaseProcessor): # chat_target_name=chat_target_name, is_group_chat=is_group_chat, # relation_prompt=relation_prompt, - prompt_personality=prompt_personality, + # prompt_personality=prompt_personality, # mood_info=mood_info, bot_name=individuality.name, time_now=time_now, ) # 调用LLM,专注于工具使用 - logger.debug(f"开始执行工具调用{prompt}") + # logger.debug(f"开始执行工具调用{prompt}") response, _, tool_calls = await self.llm_model.generate_response_tool_async(prompt=prompt, tools=tools) logger.debug(f"获取到工具原始输出:\n{tool_calls}") diff --git a/src/chat/focus_chat/info_processors/working_memory_processor.py b/src/chat/focus_chat/info_processors/working_memory_processor.py new file mode 100644 index 000000000..c79c8363d --- /dev/null +++ b/src/chat/focus_chat/info_processors/working_memory_processor.py @@ -0,0 +1,236 @@ +from src.chat.heart_flow.observation.chatting_observation import ChattingObservation +from src.chat.heart_flow.observation.observation import Observation +from src.chat.models.utils_model import LLMRequest +from src.config.config import global_config +import time +import traceback +from src.common.logger_manager import get_logger +from src.chat.utils.prompt_builder import Prompt, global_prompt_manager +from src.chat.message_receive.chat_stream import chat_manager +from .base_processor import BaseProcessor +from src.chat.focus_chat.info.mind_info import MindInfo +from typing import List, Optional +from src.chat.heart_flow.observation.working_observation import WorkingMemoryObservation +from src.chat.focus_chat.working_memory.working_memory import WorkingMemory +from typing import Dict +from src.chat.focus_chat.info.info_base import InfoBase +from json_repair import repair_json +from src.chat.focus_chat.info.workingmemory_info import WorkingMemoryInfo +import asyncio +import json + +logger = get_logger("processor") + + +def init_prompt(): + memory_proces_prompt = """ +你的名字是{bot_name} + +现在是{time_now},你正在上网,和qq群里的网友们聊天,以下是正在进行的聊天内容: +{chat_observe_info} + 
+以下是你已经总结的记忆摘要,你可以调取这些记忆查看内容来帮助你聊天,不要一次调取太多记忆,最多调取3个左右记忆: +{memory_str} + +观察聊天内容和已经总结的记忆,思考是否有新内容需要总结成记忆,如果有,就输出 true,否则输出 false +如果当前聊天记录的内容已经被总结,千万不要总结新记忆,输出false +如果已经总结的记忆包含了当前聊天记录的内容,千万不要总结新记忆,输出false +如果已经总结的记忆摘要,包含了当前聊天记录的内容,千万不要总结新记忆,输出false + +如果有相近的记忆,请合并记忆,输出merge_memory,格式为[["id1", "id2"], ["id3", "id4"],...],你可以进行多组合并,但是每组合并只能有两个记忆id,不要输出其他内容 + +请根据聊天内容选择你需要调取的记忆并考虑是否添加新记忆,以JSON格式输出,格式如下: +```json +{{ + "selected_memory_ids": ["id1", "id2", ...], + "new_memory": "true" or "false", + "merge_memory": [["id1", "id2"], ["id3", "id4"],...] + +}} +``` +""" + Prompt(memory_proces_prompt, "prompt_memory_proces") + + +class WorkingMemoryProcessor(BaseProcessor): + log_prefix = "工作记忆" + + def __init__(self, subheartflow_id: str): + super().__init__() + + self.subheartflow_id = subheartflow_id + + self.llm_model = LLMRequest( + model=global_config.model.sub_heartflow, + temperature=global_config.model.sub_heartflow["temp"], + max_tokens=800, + request_type="working_memory", + ) + + name = chat_manager.get_stream_name(self.subheartflow_id) + self.log_prefix = f"[{name}] " + + async def process_info( + self, observations: Optional[List[Observation]] = None, running_memorys: Optional[List[Dict]] = None, *infos + ) -> List[InfoBase]: + """处理信息对象 + + Args: + *infos: 可变数量的InfoBase类型的信息对象 + + Returns: + List[InfoBase]: 处理后的结构化信息列表 + """ + working_memory = None + chat_info = "" + try: + for observation in observations: + if isinstance(observation, WorkingMemoryObservation): + working_memory = observation.get_observe_info() + # working_memory_obs = observation + if isinstance(observation, ChattingObservation): + chat_info = observation.get_observe_info() + # chat_info_truncate = observation.talking_message_str_truncate + + if not working_memory: + logger.warning(f"{self.log_prefix} 没有找到工作记忆对象") + mind_info = MindInfo() + return [mind_info] + except Exception as e: + logger.error(f"{self.log_prefix} 处理观察时出错: {e}") + logger.error(traceback.format_exc()) + return [] + + 
all_memory = working_memory.get_all_memories() + memory_prompts = [] + for memory in all_memory: + # memory_content = memory.data + memory_summary = memory.summary + memory_id = memory.id + memory_brief = memory_summary.get("brief") + # memory_detailed = memory_summary.get("detailed") + memory_keypoints = memory_summary.get("keypoints") + memory_events = memory_summary.get("events") + memory_single_prompt = f"记忆id:{memory_id},记忆摘要:{memory_brief}\n" + memory_prompts.append(memory_single_prompt) + + memory_choose_str = "".join(memory_prompts) + + # 使用提示模板进行处理 + prompt = (await global_prompt_manager.get_prompt_async("prompt_memory_proces")).format( + bot_name=global_config.bot.nickname, + time_now=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), + chat_observe_info=chat_info, + memory_str=memory_choose_str, + ) + + # 调用LLM处理记忆 + content = "" + try: + logger.debug(f"{self.log_prefix} 处理工作记忆的prompt: {prompt}") + + content, _ = await self.llm_model.generate_response_async(prompt=prompt) + if not content: + logger.warning(f"{self.log_prefix} LLM返回空结果,处理工作记忆失败。") + except Exception as e: + logger.error(f"{self.log_prefix} 执行LLM请求或处理响应时出错: {e}") + logger.error(traceback.format_exc()) + + # 解析LLM返回的JSON + try: + result = repair_json(content) + if isinstance(result, str): + result = json.loads(result) + if not isinstance(result, dict): + logger.error(f"{self.log_prefix} 解析LLM返回的JSON失败,结果不是字典类型: {type(result)}") + return [] + + selected_memory_ids = result.get("selected_memory_ids", []) + new_memory = result.get("new_memory", "") + merge_memory = result.get("merge_memory", []) + except Exception as e: + logger.error(f"{self.log_prefix} 解析LLM返回的JSON失败: {e}") + logger.error(traceback.format_exc()) + return [] + + logger.debug(f"{self.log_prefix} 解析LLM返回的JSON成功: {result}") + + # 根据selected_memory_ids,调取记忆 + memory_str = "" + if selected_memory_ids: + for memory_id in selected_memory_ids: + memory = await working_memory.retrieve_memory(memory_id) + if memory: + # 
memory_content = memory.data + memory_summary = memory.summary + memory_id = memory.id + memory_brief = memory_summary.get("brief") + # memory_detailed = memory_summary.get("detailed") + memory_keypoints = memory_summary.get("keypoints") + memory_events = memory_summary.get("events") + for keypoint in memory_keypoints: + memory_str += f"记忆要点:{keypoint}\n" + for event in memory_events: + memory_str += f"记忆事件:{event}\n" + # memory_str += f"记忆摘要:{memory_detailed}\n" + # memory_str += f"记忆主题:{memory_brief}\n" + + working_memory_info = WorkingMemoryInfo() + if memory_str: + working_memory_info.add_working_memory(memory_str) + logger.debug(f"{self.log_prefix} 取得工作记忆: {memory_str}") + else: + logger.warning(f"{self.log_prefix} 没有找到工作记忆") + + # 根据聊天内容添加新记忆 + if new_memory: + # 使用异步方式添加新记忆,不阻塞主流程 + logger.debug(f"{self.log_prefix} {new_memory}新记忆: ") + asyncio.create_task(self.add_memory_async(working_memory, chat_info)) + + if merge_memory: + for merge_pairs in merge_memory: + memory1 = await working_memory.retrieve_memory(merge_pairs[0]) + memory2 = await working_memory.retrieve_memory(merge_pairs[1]) + if memory1 and memory2: + memory_str = f"记忆id:{memory1.id},记忆摘要:{memory1.summary.get('brief')}\n" + memory_str += f"记忆id:{memory2.id},记忆摘要:{memory2.summary.get('brief')}\n" + asyncio.create_task(self.merge_memory_async(working_memory, merge_pairs[0], merge_pairs[1])) + + return [working_memory_info] + + async def add_memory_async(self, working_memory: WorkingMemory, content: str): + """异步添加记忆,不阻塞主流程 + + Args: + working_memory: 工作记忆对象 + content: 记忆内容 + """ + try: + await working_memory.add_memory(content=content, from_source="chat_text") + logger.debug(f"{self.log_prefix} 异步添加新记忆成功: {content[:30]}...") + except Exception as e: + logger.error(f"{self.log_prefix} 异步添加新记忆失败: {e}") + logger.error(traceback.format_exc()) + + async def merge_memory_async(self, working_memory: WorkingMemory, memory_id1: str, memory_id2: str): + """异步合并记忆,不阻塞主流程 + + Args: + working_memory: 工作记忆对象 + 
memory_str: 记忆内容 + """ + try: + merged_memory = await working_memory.merge_memory(memory_id1, memory_id2) + logger.debug(f"{self.log_prefix} 异步合并记忆成功: {memory_id1} 和 {memory_id2}...") + logger.debug(f"{self.log_prefix} 合并后的记忆梗概: {merged_memory.summary.get('brief')}") + logger.debug(f"{self.log_prefix} 合并后的记忆详情: {merged_memory.summary.get('detailed')}") + logger.debug(f"{self.log_prefix} 合并后的记忆要点: {merged_memory.summary.get('keypoints')}") + logger.debug(f"{self.log_prefix} 合并后的记忆事件: {merged_memory.summary.get('events')}") + + except Exception as e: + logger.error(f"{self.log_prefix} 异步合并记忆失败: {e}") + logger.error(traceback.format_exc()) + + +init_prompt() diff --git a/src/chat/focus_chat/memory_activator.py b/src/chat/focus_chat/memory_activator.py index 4faf43747..4fcd37302 100644 --- a/src/chat/focus_chat/memory_activator.py +++ b/src/chat/focus_chat/memory_activator.py @@ -1,5 +1,5 @@ from src.chat.heart_flow.observation.chatting_observation import ChattingObservation -from src.chat.heart_flow.observation.working_observation import WorkingObservation +from src.chat.heart_flow.observation.structure_observation import StructureObservation from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation from src.chat.models.utils_model import LLMRequest from src.config.config import global_config @@ -54,7 +54,7 @@ class MemoryActivator: for observation in observations: if isinstance(observation, ChattingObservation): obs_info_text += observation.get_observe_info() - elif isinstance(observation, WorkingObservation): + elif isinstance(observation, StructureObservation): working_info = observation.get_observe_info() for working_info_item in working_info: obs_info_text += f"{working_info_item['type']}: {working_info_item['content']}\n" diff --git a/src/chat/focus_chat/planners/action_factory.py b/src/chat/focus_chat/planners/action_manager.py similarity index 79% rename from src/chat/focus_chat/planners/action_factory.py rename to 
src/chat/focus_chat/planners/action_manager.py index bca49c496..2ee7f349d 100644 --- a/src/chat/focus_chat/planners/action_factory.py +++ b/src/chat/focus_chat/planners/action_manager.py @@ -5,10 +5,14 @@ from src.chat.focus_chat.expressors.default_expressor import DefaultExpressor from src.chat.message_receive.chat_stream import ChatStream from src.chat.focus_chat.heartFC_Cycleinfo import CycleDetail from src.common.logger_manager import get_logger +import importlib +import pkgutil +import os # 导入动作类,确保装饰器被执行 +import src.chat.focus_chat.planners.actions # noqa -logger = get_logger("action_factory") +logger = get_logger("action_manager") # 定义动作信息类型 ActionInfo = Dict[str, Any] @@ -34,13 +38,12 @@ class ActionManager: # 加载所有已注册动作 self._load_registered_actions() + # 加载插件动作 + self._load_plugin_actions() + # 初始化时将默认动作加载到使用中的动作 self._using_actions = self._default_actions.copy() - # logger.info(f"当前可用动作: {list(self._using_actions.keys())}") - # for action_name, action_info in self._using_actions.items(): - # logger.info(f"动作名称: {action_name}, 动作信息: {action_info}") - def _load_registered_actions(self) -> None: """ 加载所有通过装饰器注册的动作 @@ -49,6 +52,11 @@ class ActionManager: # 从_ACTION_REGISTRY获取所有已注册动作 for action_name, action_class in _ACTION_REGISTRY.items(): # 获取动作相关信息 + + # 不读取插件动作和基类 + if action_name == "base_action" or action_name == "plugin_action": + continue + action_description: str = getattr(action_class, "action_description", "") action_parameters: dict[str:str] = getattr(action_class, "action_parameters", {}) action_require: list[str] = getattr(action_class, "action_require", []) @@ -62,10 +70,6 @@ class ActionManager: "require": action_require, } - # 注册2 - print("注册2") - print(action_info) - # 添加到所有已注册的动作 self._registered_actions[action_name] = action_info @@ -73,14 +77,56 @@ class ActionManager: if is_default: self._default_actions[action_name] = action_info - logger.info(f"所有注册动作: {list(self._registered_actions.keys())}") - logger.info(f"默认动作: 
{list(self._default_actions.keys())}") + # logger.info(f"所有注册动作: {list(self._registered_actions.keys())}") + # logger.info(f"默认动作: {list(self._default_actions.keys())}") # for action_name, action_info in self._default_actions.items(): # logger.info(f"动作名称: {action_name}, 动作信息: {action_info}") except Exception as e: logger.error(f"加载已注册动作失败: {e}") + def _load_plugin_actions(self) -> None: + """ + 加载所有插件目录中的动作 + """ + try: + # 检查插件目录是否存在 + plugin_path = "src.plugins" + plugin_dir = plugin_path.replace(".", os.path.sep) + if not os.path.exists(plugin_dir): + logger.info(f"插件目录 {plugin_dir} 不存在,跳过插件动作加载") + return + + # 导入插件包 + try: + plugins_package = importlib.import_module(plugin_path) + except ImportError as e: + logger.error(f"导入插件包失败: {e}") + return + + # 遍历插件包中的所有子包 + for _, plugin_name, is_pkg in pkgutil.iter_modules( + plugins_package.__path__, plugins_package.__name__ + "." + ): + if not is_pkg: + continue + + # 检查插件是否有actions子包 + plugin_actions_path = f"{plugin_name}.actions" + try: + # 尝试导入插件的actions包 + importlib.import_module(plugin_actions_path) + logger.info(f"成功加载插件动作模块: {plugin_actions_path}") + except ImportError as e: + logger.debug(f"插件 {plugin_name} 没有actions子包或导入失败: {e}") + continue + + # 再次从_ACTION_REGISTRY获取所有动作(包括刚刚从插件加载的) + self._load_registered_actions() + + except Exception as e: + logger.error(f"加载插件动作失败: {e}") + def create_action( self, action_name: str, @@ -94,8 +140,8 @@ class ActionManager: current_cycle: CycleDetail, log_prefix: str, on_consecutive_no_reply_callback: Callable[[], Coroutine[None, None, None]], - total_no_reply_count: int = 0, - total_waiting_time: float = 0.0, + # total_no_reply_count: int = 0, + # total_waiting_time: float = 0.0, shutting_down: bool = False, ) -> Optional[BaseAction]: """ @@ -131,7 +177,7 @@ class ActionManager: return None try: - # 创建动作实例并传递所有必要参数 + # 创建动作实例 instance = handler_class( action_name=action_name, action_data=action_data, @@ -139,14 +185,14 @@ class ActionManager: cycle_timers=cycle_timers, 
thinking_id=thinking_id, observations=observations, - on_consecutive_no_reply_callback=on_consecutive_no_reply_callback, - current_cycle=current_cycle, - log_prefix=log_prefix, - total_no_reply_count=total_no_reply_count, - total_waiting_time=total_waiting_time, - shutting_down=shutting_down, expressor=expressor, chat_stream=chat_stream, + current_cycle=current_cycle, + log_prefix=log_prefix, + on_consecutive_no_reply_callback=on_consecutive_no_reply_callback, + # total_no_reply_count=total_no_reply_count, + # total_waiting_time=total_waiting_time, + shutting_down=shutting_down, ) return instance @@ -272,7 +318,3 @@ class ActionManager: Optional[Type[BaseAction]]: 动作处理器类,如果不存在则返回None """ return _ACTION_REGISTRY.get(action_name) - - -# 创建全局实例 -ActionFactory = ActionManager() diff --git a/src/chat/focus_chat/planners/actions/__init__.py b/src/chat/focus_chat/planners/actions/__init__.py new file mode 100644 index 000000000..3f2baf665 --- /dev/null +++ b/src/chat/focus_chat/planners/actions/__init__.py @@ -0,0 +1,5 @@ +# 导入所有动作模块以确保装饰器被执行 +from . import reply_action # noqa +from . 
import no_reply_action # noqa + +# 在此处添加更多动作模块导入 diff --git a/src/chat/focus_chat/planners/actions/no_reply_action.py b/src/chat/focus_chat/planners/actions/no_reply_action.py index 71f1cb3f3..c6852fbe1 100644 --- a/src/chat/focus_chat/planners/actions/no_reply_action.py +++ b/src/chat/focus_chat/planners/actions/no_reply_action.py @@ -43,8 +43,8 @@ class NoReplyAction(BaseAction): on_consecutive_no_reply_callback: Callable[[], Coroutine[None, None, None]], current_cycle: CycleDetail, log_prefix: str, - total_no_reply_count: int = 0, - total_waiting_time: float = 0.0, + # total_no_reply_count: int = 0, + # total_waiting_time: float = 0.0, shutting_down: bool = False, **kwargs, ): @@ -69,8 +69,8 @@ class NoReplyAction(BaseAction): self.on_consecutive_no_reply_callback = on_consecutive_no_reply_callback self._current_cycle = current_cycle self.log_prefix = log_prefix - self.total_no_reply_count = total_no_reply_count - self.total_waiting_time = total_waiting_time + # self.total_no_reply_count = total_no_reply_count + # self.total_waiting_time = total_waiting_time self._shutting_down = shutting_down async def handle_action(self) -> Tuple[bool, str]: @@ -94,36 +94,7 @@ class NoReplyAction(BaseAction): # 等待新消息、超时或关闭信号,并获取结果 await self._wait_for_new_message(observation, self.thinking_id, self.log_prefix) # 从计时器获取实际等待时间 - current_waiting = self.cycle_timers.get("等待新消息", 0.0) - - if not self._shutting_down: - self.total_no_reply_count += 1 - self.total_waiting_time += current_waiting # 累加等待时间 - logger.debug( - f"{self.log_prefix} 连续不回复计数增加: {self.total_no_reply_count}/{CONSECUTIVE_NO_REPLY_THRESHOLD}, " - f"本次等待: {current_waiting:.2f}秒, 累计等待: {self.total_waiting_time:.2f}秒" - ) - - # 检查是否同时达到次数和时间阈值 - time_threshold = 0.66 * WAITING_TIME_THRESHOLD * CONSECUTIVE_NO_REPLY_THRESHOLD - if ( - self.total_no_reply_count >= CONSECUTIVE_NO_REPLY_THRESHOLD - and self.total_waiting_time >= time_threshold - ): - logger.info( - f"{self.log_prefix} 连续不回复达到阈值 
({self.total_no_reply_count}次) " - f"且累计等待时间达到 {self.total_waiting_time:.2f}秒 (阈值 {time_threshold}秒)," - f"调用回调请求状态转换" - ) - # 调用回调。注意:这里不重置计数器和时间,依赖回调函数成功改变状态来隐式重置上下文。 - await self.on_consecutive_no_reply_callback() - elif self.total_no_reply_count >= CONSECUTIVE_NO_REPLY_THRESHOLD: - # 仅次数达到阈值,但时间未达到 - logger.debug( - f"{self.log_prefix} 连续不回复次数达到阈值 ({self.total_no_reply_count}次) " - f"但累计等待时间 {self.total_waiting_time:.2f}秒 未达到时间阈值 ({time_threshold}秒),暂不调用回调" - ) - # else: 次数和时间都未达到阈值,不做处理 + _current_waiting = self.cycle_timers.get("等待新消息", 0.0) return True, "" # 不回复动作没有回复文本 diff --git a/src/chat/focus_chat/planners/actions/plugin_action.py b/src/chat/focus_chat/planners/actions/plugin_action.py new file mode 100644 index 000000000..5e8ddd998 --- /dev/null +++ b/src/chat/focus_chat/planners/actions/plugin_action.py @@ -0,0 +1,205 @@ +import traceback +from typing import Tuple, Dict, List, Any, Optional +from src.chat.focus_chat.planners.actions.base_action import BaseAction +from src.chat.heart_flow.observation.chatting_observation import ChattingObservation +from src.chat.focus_chat.hfc_utils import create_empty_anchor_message +from src.common.logger_manager import get_logger +from src.chat.person_info.person_info import person_info_manager +from abc import abstractmethod + +logger = get_logger("plugin_action") + + +class PluginAction(BaseAction): + """插件动作基类 + + 封装了主程序内部依赖,提供简化的API接口给插件开发者 + """ + + def __init__(self, action_data: dict, reasoning: str, cycle_timers: dict, thinking_id: str, **kwargs): + """初始化插件动作基类""" + super().__init__(action_data, reasoning, cycle_timers, thinking_id) + + # 存储内部服务和对象引用 + self._services = {} + + # 从kwargs提取必要的内部服务 + if "observations" in kwargs: + self._services["observations"] = kwargs["observations"] + if "expressor" in kwargs: + self._services["expressor"] = kwargs["expressor"] + if "chat_stream" in kwargs: + self._services["chat_stream"] = kwargs["chat_stream"] + if "current_cycle" in kwargs: + 
self._services["current_cycle"] = kwargs["current_cycle"] + + self.log_prefix = kwargs.get("log_prefix", "") + + async def get_user_id_by_person_name(self, person_name: str) -> Tuple[str, str]: + """根据用户名获取用户ID""" + person_id = person_info_manager.get_person_id_by_person_name(person_name) + user_id = await person_info_manager.get_value(person_id, "user_id") + platform = await person_info_manager.get_value(person_id, "platform") + return platform, user_id + + # 提供简化的API方法 + async def send_message(self, text: str, target: Optional[str] = None) -> bool: + """发送消息的简化方法 + + Args: + text: 要发送的消息文本 + target: 目标消息(可选) + + Returns: + bool: 是否发送成功 + """ + try: + expressor = self._services.get("expressor") + chat_stream = self._services.get("chat_stream") + + if not expressor or not chat_stream: + logger.error(f"{self.log_prefix} 无法发送消息:缺少必要的内部服务") + return False + + # 构造简化的动作数据 + reply_data = {"text": text, "target": target or "", "emojis": []} + + # 获取锚定消息(如果有) + observations = self._services.get("observations", []) + + chatting_observation: ChattingObservation = next( + obs for obs in observations if isinstance(obs, ChattingObservation) + ) + anchor_message = chatting_observation.search_message_by_text(reply_data["target"]) + + # 如果没有找到锚点消息,创建一个占位符 + if not anchor_message: + logger.info(f"{self.log_prefix} 未找到锚点消息,创建占位符") + anchor_message = await create_empty_anchor_message( + chat_stream.platform, chat_stream.group_info, chat_stream + ) + else: + anchor_message.update_chat_stream(chat_stream) + + response_set = [ + ("text", text), + ] + + # 调用内部方法发送消息 + success = await expressor.send_response_messages( + anchor_message=anchor_message, + response_set=response_set, + ) + + return success + except Exception as e: + logger.error(f"{self.log_prefix} 发送消息时出错: {e}") + traceback.print_exc() + return False + + async def send_message_by_expressor(self, text: str, target: Optional[str] = None) -> bool: + """发送消息的简化方法 + + Args: + text: 要发送的消息文本 + target: 目标消息(可选) + + Returns: + bool: 
是否发送成功 + """ + try: + expressor = self._services.get("expressor") + chat_stream = self._services.get("chat_stream") + + if not expressor or not chat_stream: + logger.error(f"{self.log_prefix} 无法发送消息:缺少必要的内部服务") + return False + + # 构造简化的动作数据 + reply_data = {"text": text, "target": target or "", "emojis": []} + + # 获取锚定消息(如果有) + observations = self._services.get("observations", []) + + chatting_observation: ChattingObservation = next( + obs for obs in observations if isinstance(obs, ChattingObservation) + ) + anchor_message = chatting_observation.search_message_by_text(reply_data["target"]) + + # 如果没有找到锚点消息,创建一个占位符 + if not anchor_message: + logger.info(f"{self.log_prefix} 未找到锚点消息,创建占位符") + anchor_message = await create_empty_anchor_message( + chat_stream.platform, chat_stream.group_info, chat_stream + ) + else: + anchor_message.update_chat_stream(chat_stream) + + # 调用内部方法发送消息 + success, _ = await expressor.deal_reply( + cycle_timers=self.cycle_timers, + action_data=reply_data, + anchor_message=anchor_message, + reasoning=self.reasoning, + thinking_id=self.thinking_id, + ) + + return success + except Exception as e: + logger.error(f"{self.log_prefix} 发送消息时出错: {e}") + return False + + def get_chat_type(self) -> str: + """获取当前聊天类型 + + Returns: + str: 聊天类型 ("group" 或 "private") + """ + chat_stream = self._services.get("chat_stream") + if chat_stream and hasattr(chat_stream, "group_info"): + return "group" if chat_stream.group_info else "private" + return "unknown" + + def get_recent_messages(self, count: int = 5) -> List[Dict[str, Any]]: + """获取最近的消息 + + Args: + count: 要获取的消息数量 + + Returns: + List[Dict]: 消息列表,每个消息包含发送者、内容等信息 + """ + messages = [] + observations = self._services.get("observations", []) + + if observations and len(observations) > 0: + obs = observations[0] + if hasattr(obs, "get_talking_message"): + raw_messages = obs.get_talking_message() + # 转换为简化格式 + for msg in raw_messages[-count:]: + simple_msg = { + "sender": msg.get("sender", "未知"), + "content": 
msg.get("content", ""), + "timestamp": msg.get("timestamp", 0), + } + messages.append(simple_msg) + + return messages + + @abstractmethod + async def process(self) -> Tuple[bool, str]: + """插件处理逻辑,子类必须实现此方法 + + Returns: + Tuple[bool, str]: (是否执行成功, 回复文本) + """ + pass + + async def handle_action(self) -> Tuple[bool, str]: + """实现BaseAction的抽象方法,调用子类的process方法 + + Returns: + Tuple[bool, str]: (是否执行成功, 回复文本) + """ + return await self.process() diff --git a/src/chat/focus_chat/planners/actions/reply_action.py b/src/chat/focus_chat/planners/actions/reply_action.py index 6e4f41d4d..07e35b458 100644 --- a/src/chat/focus_chat/planners/actions/reply_action.py +++ b/src/chat/focus_chat/planners/actions/reply_action.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- - from src.common.logger_manager import get_logger from src.chat.focus_chat.planners.actions.base_action import BaseAction, register_action from typing import Tuple, List @@ -25,19 +24,18 @@ class ReplyAction(BaseAction): action_description: str = "表达想法,可以只包含文本、表情或两者都有" action_parameters: dict[str:str] = { "text": "你想要表达的内容(可选)", - "emojis": "描述当前使用表情包的场景(可选)", + "emojis": "描述当前使用表情包的场景,一段话描述(可选)", "target": "你想要回复的原始文本内容(非必须,仅文本,不包含发送者)(可选)", } action_require: list[str] = [ "有实质性内容需要表达", "有人提到你,但你还没有回应他", - "在合适的时候添加表情(不要总是添加)", - "如果你要回复特定某人的某句话,或者你想回复较早的消息,请在target中指定那句话的原始文本", - "除非有明确的回复目标,如果选择了target,不用特别提到某个人的人名", + "在合适的时候添加表情(不要总是添加),表情描述要详细,描述当前场景,一段话描述", + "如果你有明确的,要回复特定某人的某句话,或者你想回复较早的消息,请在target中指定那句话的原始文本", "一次只回复一个人,一次只回复一个话题,突出重点", "如果是自己发的消息想继续,需自然衔接", "避免重复或评价自己的发言,不要和自己聊天", - "注意:回复尽量简短一些。可以参考贴吧,知乎和微博的回复风格,回复不要浮夸,不要用夸张修辞,平淡一些。", + "注意:回复尽量简短一些。可以参考贴吧,知乎和微博的回复风格,回复不要浮夸,不要用夸张修辞,平淡一些。不要有额外的符号,尽量简单简短", ] default = True @@ -104,13 +102,15 @@ class ReplyAction(BaseAction): "emojis": "微笑" # 表情关键词列表(可选) } """ - # 重置连续不回复计数器 - self.total_no_reply_count = 0 - self.total_waiting_time = 0.0 # 从聊天观察获取锚定消息 - observations: ChattingObservation = self.observations[0] - anchor_message = 
observations.serch_message_by_text(reply_data["target"]) + chatting_observation: ChattingObservation = next( + obs for obs in self.observations if isinstance(obs, ChattingObservation) + ) + if reply_data.get("target"): + anchor_message = chatting_observation.search_message_by_text(reply_data["target"]) + else: + anchor_message = None # 如果没有找到锚点消息,创建一个占位符 if not anchor_message: diff --git a/src/chat/focus_chat/planners/planner.py b/src/chat/focus_chat/planners/planner.py index 83c8b6791..116419ee1 100644 --- a/src/chat/focus_chat/planners/planner.py +++ b/src/chat/focus_chat/planners/planner.py @@ -12,8 +12,8 @@ from src.chat.focus_chat.info.structured_info import StructuredInfo from src.common.logger_manager import get_logger from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.individuality.individuality import Individuality -from src.chat.focus_chat.planners.action_factory import ActionManager -from src.chat.focus_chat.planners.action_factory import ActionInfo +from src.chat.focus_chat.planners.action_manager import ActionManager +from src.chat.focus_chat.planners.action_manager import ActionInfo logger = get_logger("planner") @@ -22,8 +22,12 @@ install(extra_lines=3) def init_prompt(): Prompt( - """你的名字是{bot_name},{prompt_personality},{chat_context_description}。需要基于以下信息决定如何参与对话: + """{extra_info_block} + +你需要基于以下信息决定如何参与对话 +这些信息可能会有冲突,请你整合这些信息,并选择一个最合适的action: {chat_content_block} + {mind_info_block} {cycle_info_block} @@ -53,10 +57,9 @@ def init_prompt(): action_name: {action_name} 描述:{action_description} 参数: - {action_parameters} +{action_parameters} 动作要求: - {action_require} - """, +{action_require}""", "action_prompt", ) @@ -66,7 +69,7 @@ class ActionPlanner: self.log_prefix = log_prefix # LLM规划器配置 self.planner_llm = LLMRequest( - model=global_config.llm_plan, + model=global_config.model.plan, max_tokens=1000, request_type="action_planning", # 用于动作规划 ) @@ -87,9 +90,10 @@ class ActionPlanner: try: # 获取观察信息 + extra_info: list[str] = 
[] for info in all_plan_info: if isinstance(info, ObsInfo): - logger.debug(f"{self.log_prefix} 观察信息: {info}") + # logger.debug(f"{self.log_prefix} 观察信息: {info}") observed_messages = info.get_talking_message() observed_messages_str = info.get_talking_message_str_truncate() chat_type = info.get_chat_type() @@ -98,14 +102,17 @@ class ActionPlanner: else: is_group_chat = False elif isinstance(info, MindInfo): - logger.debug(f"{self.log_prefix} 思维信息: {info}") + # logger.debug(f"{self.log_prefix} 思维信息: {info}") current_mind = info.get_current_mind() elif isinstance(info, CycleInfo): - logger.debug(f"{self.log_prefix} 循环信息: {info}") + # logger.debug(f"{self.log_prefix} 循环信息: {info}") cycle_info = info.get_observe_info() elif isinstance(info, StructuredInfo): - logger.debug(f"{self.log_prefix} 结构化信息: {info}") + # logger.debug(f"{self.log_prefix} 结构化信息: {info}") _structured_info = info.get_data() + else: + logger.debug(f"{self.log_prefix} 其他信息: {info}") + extra_info.append(info.get_processed_info()) current_available_actions = self.action_manager.get_using_actions() @@ -118,6 +125,7 @@ class ActionPlanner: # structured_info=structured_info, # <-- Pass SubMind info current_available_actions=current_available_actions, # <-- Pass determined actions cycle_info=cycle_info, # <-- Pass cycle info + extra_info=extra_info, ) # --- 调用 LLM (普通文本生成) --- @@ -144,15 +152,13 @@ class ActionPlanner: extracted_action = parsed_json.get("action", "no_reply") extracted_reasoning = parsed_json.get("reasoning", "LLM未提供理由") - # 新的reply格式 - if extracted_action == "reply": - action_data = { - "text": parsed_json.get("text", []), - "emojis": parsed_json.get("emojis", []), - "target": parsed_json.get("target", ""), - } - else: - action_data = {} # 其他动作可能不需要额外数据 + # 将所有其他属性添加到action_data + action_data = {} + for key, value in parsed_json.items(): + if key not in ["action", "reasoning"]: + action_data[key] = value + + # 对于reply动作不需要额外处理,因为相关字段已经在上面的循环中添加到action_data if extracted_action not in 
class MemoryItem:
    """A single working-memory entry plus the bookkeeping attached to it.

    Besides the payload, an item tracks where it came from, its tags, an
    optional LLM-produced summary, how often it was compressed or retrieved,
    a strength value that is clamped to the range [0.1, 10.0] by the
    strengthen/weaken helpers, and an append-only history of operations.
    """

    def __init__(self, data: Any, from_source: str = "", tags: Optional[List[str]] = None):
        """Create a memory item.

        Args:
            data: The payload to remember; its concrete type is recorded in ``data_type``.
            from_source: Label describing where the data came from.
            tags: Optional tags; stored internally as a set.
        """
        # Read the clock once so the ID prefix and ``timestamp`` can never disagree
        # (two separate reads may straddle a second boundary).
        created_at = time.time()

        # Readable ID: "<unix seconds>_<two random characters>".
        # NOTE(review): only 36**2 suffixes exist per second, so items created within
        # the same second can collide - confirm this is acceptable for ID-keyed stores.
        random_str = "".join(random.choices(string.ascii_lowercase + string.digits, k=2))
        self.id = f"{int(created_at)}_{random_str}"

        self.data = data
        self.data_type = type(data)
        self.from_source = from_source
        self.tags = set(tags) if tags else set()
        self.timestamp = created_at

        # Optional summary, expected shape:
        # {
        #     "brief": "topic of the memory",
        #     "detailed": "overview of the memory",
        #     "keypoints": ["concept 1", "concept 2"],
        #     "events": ["event 1", "event 2"]
        # }
        self.summary = None

        # Number of times this memory has been compressed/refined
        self.compress_count = 0

        # Number of times this memory has been retrieved
        self.retrieval_count = 0

        # Memory strength (starts at the maximum of 10)
        self.memory_strength = 10.0

        # Operation history: [(operation, timestamp, compress_count, strength), ...]
        self.history = [("create", self.timestamp, self.compress_count, self.memory_strength)]

    def add_tag(self, tag: str) -> None:
        """Add a tag to the item."""
        self.tags.add(tag)

    def remove_tag(self, tag: str) -> None:
        """Remove a tag; a tag that is not present is ignored."""
        self.tags.discard(tag)

    def has_tag(self, tag: str) -> bool:
        """Return True when the item carries ``tag``."""
        return tag in self.tags

    def has_all_tags(self, tags: List[str]) -> bool:
        """Return True when the item carries every tag in ``tags``."""
        return all(tag in self.tags for tag in tags)

    def matches_source(self, source: str) -> bool:
        """Return True when the item's source equals ``source`` exactly."""
        return self.from_source == source

    def set_summary(self, summary: Dict[str, Any]) -> None:
        """Attach (or replace) the summary dictionary."""
        self.summary = summary

    def increase_strength(self, amount: float) -> None:
        """Raise the strength by ``amount``, capped at 10.0, and log a "strengthen" entry."""
        self.memory_strength = min(10.0, self.memory_strength + amount)
        self.record_operation("strengthen")

    def decrease_strength(self, amount: float) -> None:
        """Lower the strength by ``amount``, floored at 0.1, and log a "weaken" entry."""
        self.memory_strength = max(0.1, self.memory_strength - amount)
        self.record_operation("weaken")

    def increase_compress_count(self) -> None:
        """Increment the compression counter and log a "compress" entry.

        The strength is NOT modified here; weakening is a separate operation.
        """
        self.compress_count += 1
        self.record_operation("compress")

    def record_retrieval(self) -> None:
        """Register a retrieval: bump the counter and double the strength (capped at 10.0)."""
        self.retrieval_count += 1
        self.memory_strength = min(10.0, self.memory_strength * 2)
        self.record_operation("retrieval")

    def record_operation(self, operation_type: str) -> None:
        """Append ``(operation_type, now, compress_count, strength)`` to the history."""
        current_time = time.time()
        self.history.append((operation_type, current_time, self.compress_count, self.memory_strength))

    def to_tuple(self) -> Tuple[Any, str, Set[str], float, str]:
        """Return ``(data, from_source, tags, timestamp, id)`` for tuple-based callers.

        The returned ``tags`` is the item's live set, not a copy.
        """
        return (self.data, self.from_source, self.tags, self.timestamp, self.id)

    def is_memory_valid(self) -> bool:
        """Return True while the strength is at least 1.0."""
        return self.memory_strength >= 1.0
def find_items(
    self,
    data_type: Optional[Type] = None,
    source: Optional[str] = None,
    tags: Optional[List[str]] = None,
    start_time: Optional[float] = None,
    end_time: Optional[float] = None,
    memory_id: Optional[str] = None,
    limit: Optional[int] = None,
    newest_first: bool = False,
    min_strength: float = 0.0,
) -> "List[MemoryItem]":
    """Look up memory items matching every supplied filter.

    Args:
        data_type: Restrict the search to items stored under this data type.
        source: Required ``from_source`` value.
        tags: Tags that must all be present on the item.
        start_time: Inclusive lower bound on the item's timestamp.
        end_time: Inclusive upper bound on the item's timestamp.
        memory_id: When given, only this ID is looked up (via ``get_by_id``)
            and all other filters, including ``limit``, are ignored.
        limit: Maximum number of items to return; ``0`` or a negative value
            yields an empty list.
        newest_first: Order the result by timestamp, newest first, across all
            searched data types. When False the storage order is kept
            (type by type, insertion order within a type).
        min_strength: Minimum memory strength; only applied when greater than 0.

    Returns:
        The matching items, possibly empty.
    """
    # An explicit ID short-circuits every other filter.
    if memory_id:
        item = self.get_by_id(memory_id)
        return [item] if item else []

    # Previously the limit was only checked after the first append, so
    # limit=0 still returned one item.
    if limit is not None and limit <= 0:
        return []

    def _matches(item) -> bool:
        """Return True when the item passes every active filter."""
        if source is not None and not item.matches_source(source):
            return False
        if tags is not None and not item.has_all_tags(tags):
            return False
        if start_time is not None and item.timestamp < start_time:
            return False
        if end_time is not None and item.timestamp > end_time:
            return False
        if min_strength > 0 and item.memory_strength < min_strength:
            return False
        return True

    types_to_search = [data_type] if data_type else list(self._memory.keys())

    matched = [
        item
        for typ in types_to_search
        for item in self._memory.get(typ, ())
        if _matches(item)
    ]

    if newest_first:
        # Sort globally rather than per type bucket so that `limit` keeps the
        # truly newest items. Feeding the reversed list into the stable sort
        # makes items with equal timestamps come out latest-inserted first.
        matched = sorted(reversed(matched), key=lambda item: item.timestamp, reverse=True)

    return matched if limit is None else matched[:limit]
+ ] +}} +``` +请确保输出是有效的JSON格式,不要添加任何额外的说明或解释。 +""" + default_summary = { + "brief": "主题未知的记忆", + "detailed": "大致内容未知的记忆", + "keypoints": ["未知的概念"], + "events": ["未知的事件"], + } + + try: + # 调用LLM生成总结 + response, _ = await self.llm_summarizer.generate_response_async(prompt) + + # 使用repair_json解析响应 + try: + # 使用repair_json修复JSON格式 + fixed_json_string = repair_json(response) + + # 如果repair_json返回的是字符串,需要解析为Python对象 + if isinstance(fixed_json_string, str): + try: + json_result = json.loads(fixed_json_string) + except json.JSONDecodeError as decode_error: + logger.error(f"JSON解析错误: {str(decode_error)}") + return default_summary + else: + # 如果repair_json直接返回了字典对象,直接使用 + json_result = fixed_json_string + + # 进行额外的类型检查 + if not isinstance(json_result, dict): + logger.error(f"修复后的JSON不是字典类型: {type(json_result)}") + return default_summary + + # 确保所有必要字段都存在且类型正确 + if "brief" not in json_result or not isinstance(json_result["brief"], str): + json_result["brief"] = "主题未知的记忆" + + if "detailed" not in json_result or not isinstance(json_result["detailed"], str): + json_result["detailed"] = "大致内容未知的记忆" + + # 处理关键概念 + if "keypoints" not in json_result or not isinstance(json_result["keypoints"], list): + json_result["keypoints"] = ["未知的概念"] + else: + # 确保keypoints中的每个项目都是字符串 + json_result["keypoints"] = [str(point) for point in json_result["keypoints"] if point is not None] + if not json_result["keypoints"]: + json_result["keypoints"] = ["未知的概念"] + + # 处理事件 + if "events" not in json_result or not isinstance(json_result["events"], list): + json_result["events"] = ["未知的事件"] + else: + # 确保events中的每个项目都是字符串 + json_result["events"] = [str(event) for event in json_result["events"] if event is not None] + if not json_result["events"]: + json_result["events"] = ["未知的事件"] + + # 兼容旧版,将keypoints和events合并到key_points中 + json_result["key_points"] = json_result["keypoints"] + json_result["events"] + + return json_result + + except Exception as json_error: + logger.error(f"JSON处理失败: 
{str(json_error)},将使用默认摘要") + # 返回默认结构 + return default_summary + + except Exception as e: + # 出错时返回简单的结构 + logger.error(f"生成总结时出错: {str(e)}") + return default_summary + + async def refine_memory(self, memory_id: str, requirements: str = "") -> Dict[str, Any]: + """ + 对记忆进行精简操作,根据要求修改要点、总结和概括 + + Args: + memory_id: 记忆ID + requirements: 精简要求,描述如何修改记忆,包括可能需要移除的要点 + + Returns: + 修改后的记忆总结字典 + """ + # 获取指定ID的记忆项 + logger.info(f"精简记忆: {memory_id}") + memory_item = self.get_by_id(memory_id) + if not memory_item: + raise ValueError(f"未找到ID为{memory_id}的记忆项") + + # 增加精简次数 + memory_item.increase_compress_count() + + summary = memory_item.summary + + # 使用LLM根据要求对总结、概括和要点进行精简修改 + prompt = f""" +请根据以下要求,对记忆内容的主题、概括、关键概念和事件进行精简,模拟记忆的遗忘过程: +要求:{requirements} +你可以随机对关键概念和事件进行压缩,模糊或者丢弃,修改后,同样修改主题和概括 + +目前主题:{summary["brief"]} + +目前概括:{summary["detailed"]} + +目前关键概念: +{chr(10).join([f"- {point}" for point in summary.get("keypoints", [])])} + +目前事件: +{chr(10).join([f"- {point}" for point in summary.get("events", [])])} + +请生成修改后的主题、概括、关键概念和事件,遵循以下格式: +```json +{{ + "brief": "修改后的主题(20字以内)", + "detailed": "修改后的概括(200字以内)", + "keypoints": [ + "修改后的概念1:解释", + "修改后的概念2:解释" + ], + "events": [ + "修改后的事件1:谁在什么时候做了什么", + "修改后的事件2:谁在什么时候做了什么" + ] +}} +``` +请确保输出是有效的JSON格式,不要添加任何额外的说明或解释。 +""" + # 检查summary中是否有旧版结构,转换为新版结构 + if "keypoints" not in summary and "events" not in summary and "key_points" in summary: + # 尝试区分key_points中的keypoints和events + # 简单地将前半部分视为keypoints,后半部分视为events + key_points = summary.get("key_points", []) + halfway = len(key_points) // 2 + summary["keypoints"] = key_points[:halfway] or ["未知的概念"] + summary["events"] = key_points[halfway:] or ["未知的事件"] + + # 定义默认的精简结果 + default_refined = { + "brief": summary["brief"], + "detailed": summary["detailed"], + "keypoints": summary.get("keypoints", ["未知的概念"])[:1], # 默认只保留第一个关键概念 + "events": summary.get("events", ["未知的事件"])[:1], # 默认只保留第一个事件 + } + + try: + # 调用LLM修改总结、概括和要点 + response, _ = await 
def decay_memory(self, memory_id: str, decay_factor: float = 0.8) -> bool:
    """Decay the strength of a single memory.

    The strength is multiplied by ``decay_factor``, i.e. the amount removed is
    ``strength * (1 - decay_factor)``. The previous implementation stored that
    removed amount as the new strength, so the default factor of 0.8 kept only
    20% of the strength instead of 80%.

    Args:
        memory_id: ID of the memory to decay.
        decay_factor: Fraction of the strength that survives (between 0 and 1).

    Returns:
        True when the memory was found and decayed, False when ``get_by_id``
        returned nothing for ``memory_id``.
    """
    memory_item = self.get_by_id(memory_id)
    if not memory_item:
        return False

    # Amount lost in this step: current strength * (1 - decay factor)
    old_strength = memory_item.memory_strength
    decay_amount = old_strength * (1 - decay_factor)

    # Keep what is left after removing the decayed amount
    memory_item.memory_strength = old_strength - decay_amount

    return True
prompt += f"记忆1主题:{summary1['brief']}\n" + prompt += f"记忆1概括:{summary1['detailed']}\n" + + if "keypoints" in summary1: + prompt += "记忆1关键概念:\n" + "\n".join([f"- {point}" for point in summary1["keypoints"]]) + "\n\n" + + if "events" in summary1: + prompt += "记忆1事件:\n" + "\n".join([f"- {point}" for point in summary1["events"]]) + "\n\n" + elif "key_points" in summary1: + prompt += "记忆1要点:\n" + "\n".join([f"- {point}" for point in summary1["key_points"]]) + "\n\n" + + if summary2: + prompt += f"记忆2主题:{summary2['brief']}\n" + prompt += f"记忆2概括:{summary2['detailed']}\n" + + if "keypoints" in summary2: + prompt += "记忆2关键概念:\n" + "\n".join([f"- {point}" for point in summary2["keypoints"]]) + "\n\n" + + if "events" in summary2: + prompt += "记忆2事件:\n" + "\n".join([f"- {point}" for point in summary2["events"]]) + "\n\n" + elif "key_points" in summary2: + prompt += "记忆2要点:\n" + "\n".join([f"- {point}" for point in summary2["key_points"]]) + "\n\n" + + # 添加记忆原始内容 + prompt += f""" +记忆1原始内容: +{content1} + +记忆2原始内容: +{content2} + +请按以下JSON格式输出合并结果: +```json +{{ + "content": "合并后的记忆内容文本(尽可能保留原信息,但去除重复)", + "brief": "合并后的主题(20字以内)", + "detailed": "合并后的概括(200字以内)", + "keypoints": [ + "合并后的概念1:解释", + "合并后的概念2:解释", + "合并后的概念3:解释" + ], + "events": [ + "合并后的事件1:谁在什么时候做了什么", + "合并后的事件2:谁在什么时候做了什么" + ] +}} +``` +请确保输出是有效的JSON格式,不要添加任何额外的说明或解释。 +""" + + # 默认合并结果 + default_merged = { + "content": f"{content1}\n\n{content2}", + "brief": f"合并:{summary1['brief']} + {summary2['brief']}", + "detailed": f"合并了两个记忆:{summary1['detailed']} 以及 {summary2['detailed']}", + "keypoints": [], + "events": [], + } + + # 合并旧版key_points + if "key_points" in summary1: + default_merged["keypoints"].extend(summary1.get("keypoints", [])) + default_merged["events"].extend(summary1.get("events", [])) + # 如果没有新的结构,尝试从旧结构分离 + if not default_merged["keypoints"] and not default_merged["events"] and "key_points" in summary1: + key_points = summary1["key_points"] + halfway = len(key_points) // 2 + 
default_merged["keypoints"].extend(key_points[:halfway]) + default_merged["events"].extend(key_points[halfway:]) + + if "key_points" in summary2: + default_merged["keypoints"].extend(summary2.get("keypoints", [])) + default_merged["events"].extend(summary2.get("events", [])) + # 如果没有新的结构,尝试从旧结构分离 + if not default_merged["keypoints"] and not default_merged["events"] and "key_points" in summary2: + key_points = summary2["key_points"] + halfway = len(key_points) // 2 + default_merged["keypoints"].extend(key_points[:halfway]) + default_merged["events"].extend(key_points[halfway:]) + + # 确保列表不为空 + if not default_merged["keypoints"]: + default_merged["keypoints"] = ["合并的关键概念"] + if not default_merged["events"]: + default_merged["events"] = ["合并的事件"] + + # 添加key_points兼容 + default_merged["key_points"] = default_merged["keypoints"] + default_merged["events"] + + try: + # 调用LLM合并记忆 + response, _ = await self.llm_summarizer.generate_response_async(prompt) + + # 处理LLM返回的合并结果 + try: + # 修复JSON格式 + fixed_json_string = repair_json(response) + + # 将修复后的字符串解析为Python对象 + if isinstance(fixed_json_string, str): + try: + merged_data = json.loads(fixed_json_string) + except json.JSONDecodeError as decode_error: + logger.error(f"JSON解析错误: {str(decode_error)}") + merged_data = default_merged + else: + # 如果repair_json直接返回了字典对象,直接使用 + merged_data = fixed_json_string + + # 确保是字典类型 + if not isinstance(merged_data, dict): + logger.error(f"修复后的JSON不是字典类型: {type(merged_data)}") + merged_data = default_merged + + # 确保所有必要字段都存在且类型正确 + if "content" not in merged_data or not isinstance(merged_data["content"], str): + merged_data["content"] = default_merged["content"] + + if "brief" not in merged_data or not isinstance(merged_data["brief"], str): + merged_data["brief"] = default_merged["brief"] + + if "detailed" not in merged_data or not isinstance(merged_data["detailed"], str): + merged_data["detailed"] = default_merged["detailed"] + + # 处理关键概念 + if "keypoints" not in merged_data or not 
isinstance(merged_data["keypoints"], list): + merged_data["keypoints"] = default_merged["keypoints"] + else: + # 确保keypoints中的每个项目都是字符串 + merged_data["keypoints"] = [str(point) for point in merged_data["keypoints"] if point is not None] + if not merged_data["keypoints"]: + merged_data["keypoints"] = ["合并的关键概念"] + + # 处理事件 + if "events" not in merged_data or not isinstance(merged_data["events"], list): + merged_data["events"] = default_merged["events"] + else: + # 确保events中的每个项目都是字符串 + merged_data["events"] = [str(event) for event in merged_data["events"] if event is not None] + if not merged_data["events"]: + merged_data["events"] = ["合并的事件"] + + # 添加key_points兼容 + merged_data["key_points"] = merged_data["keypoints"] + merged_data["events"] + + except Exception as e: + logger.error(f"合并记忆时处理JSON出错: {str(e)}") + traceback.print_exc() + merged_data = default_merged + except Exception as e: + logger.error(f"合并记忆调用LLM出错: {str(e)}") + traceback.print_exc() + merged_data = default_merged + + # 创建新的记忆项 + # 合并记忆项的标签 + merged_tags = memory_item1.tags.union(memory_item2.tags) + + # 取两个记忆项中更强的来源 + merged_source = ( + memory_item1.from_source + if memory_item1.memory_strength >= memory_item2.memory_strength + else memory_item2.from_source + ) + + # 创建新的记忆项 + merged_memory = MemoryItem(data=merged_data["content"], from_source=merged_source, tags=list(merged_tags)) + + # 设置合并后的摘要 + summary = { + "brief": merged_data["brief"], + "detailed": merged_data["detailed"], + "keypoints": merged_data["keypoints"], + "events": merged_data["events"], + "key_points": merged_data["key_points"], + } + merged_memory.set_summary(summary) + + # 记忆强度取两者最大值 + merged_memory.memory_strength = max(memory_item1.memory_strength, memory_item2.memory_strength) + + # 添加到存储中 + self.push_item(merged_memory) + + # 如果需要,删除原始记忆 + if delete_originals: + self.delete(memory_id1) + self.delete(memory_id2) + + return merged_memory + + def delete_earliest_memory(self) -> bool: + """ + 删除最早的记忆项 + + Returns: + 是否成功删除 + 
class WorkingMemory:
    """Working memory of one chat stream.

    Wraps a ``MemoryManager`` bound to ``chat_id``, enforces a capacity limit
    and runs a background task that periodically decays every stored memory.
    """

    def __init__(self, chat_id: str, max_memories_per_chat: int = 10, auto_decay_interval: int = 60):
        """Create the working memory and start the periodic decay task.

        The decay task is created with ``asyncio.create_task``, so the
        constructor has to be called while an event loop is running.

        Args:
            chat_id: ID of the chat this working memory belongs to.
            max_memories_per_chat: Capacity; the earliest memory is dropped when exceeded.
            auto_decay_interval: Seconds between two automatic decay passes.
        """
        self.memory_manager = MemoryManager(chat_id)

        # Capacity limit
        self.max_memories_per_chat = max_memories_per_chat

        # Seconds between automatic decay passes
        self.auto_decay_interval = auto_decay_interval

        # Handle of the background decay task
        self.decay_task = None

        self._start_auto_decay()

    def _start_auto_decay(self):
        """Start the background decay task unless one was already created."""
        if self.decay_task is None:
            self.decay_task = asyncio.create_task(self._auto_decay_loop())

    async def _auto_decay_loop(self):
        """Sleep for the configured interval, decay everything, repeat forever."""
        while True:
            await asyncio.sleep(self.auto_decay_interval)
            try:
                await self.decay_all_memories()
            except Exception as e:
                # Keep the loop alive; a failed pass is retried on the next tick.
                logger.error(f"自动衰减记忆时出错: {str(e)}")

    async def add_memory(self, content: Any, from_source: str = "", tags: Optional[List[str]] = None):
        """Store ``content`` and evict the earliest memory when over capacity.

        Args:
            content: Payload to remember.
            from_source: Label describing where the data came from.
            tags: Optional tags for the new memory.

        Returns:
            Whatever ``MemoryManager.push_with_summary`` returns for the new memory.
        """
        memory = await self.memory_manager.push_with_summary(content, from_source, tags)
        if len(self.memory_manager.get_all_items()) > self.max_memories_per_chat:
            self.remove_earliest_memory()

        return memory

    def remove_earliest_memory(self):
        """Delete the earliest memory via the manager and return its result."""
        return self.memory_manager.delete_earliest_memory()

    async def retrieve_memory(self, memory_id: str) -> "Optional[MemoryItem]":
        """Fetch a memory by ID and reinforce it.

        A found memory gets its retrieval counter incremented and its strength
        increased by 5.

        Args:
            memory_id: ID of the memory to fetch.

        Returns:
            The memory item, or None when the manager has no such item.
        """
        memory_item = self.memory_manager.get_by_id(memory_id)
        if memory_item:
            memory_item.retrieval_count += 1
            memory_item.increase_strength(5)
            return memory_item
        return None

    async def decay_all_memories(self, decay_factor: float = 0.5):
        """Decay every memory, delete the ones that became too weak, refine the rest.

        After decaying, a memory with strength below 1 is deleted and a memory
        with strength below 5 is refined (compressed) by the manager.

        Args:
            decay_factor: Decay factor forwarded to ``MemoryManager.decay_memory``.
        """
        logger.debug(f"开始对所有记忆进行衰减,衰减因子: {decay_factor}")

        all_memories = self.memory_manager.get_all_items()

        for memory_item in all_memories:
            memory_id = memory_item.id
            self.memory_manager.decay_memory(memory_id, decay_factor)

            # Too weak after decaying: drop it
            if memory_item.memory_strength < 1:
                self.memory_manager.delete(memory_id)
                continue

            # Refinement rewrites the summary, so it only applies to memories that
            # have one. Non-string memories are stored without a summary and used
            # to make refine_memory raise, aborting the pass for all later items.
            if memory_item.memory_strength < 5 and memory_item.summary:
                await self.memory_manager.refine_memory(
                    memory_id, f"由于时间过去了{self.auto_decay_interval}秒,记忆变的模糊,所以需要压缩"
                )

    async def merge_memory(self, memory_id1: str, memory_id2: str) -> "MemoryItem":
        """Merge two memories through the manager and return the merged item.

        Args:
            memory_id1: ID of the first memory.
            memory_id2: ID of the second memory.
        """
        return await self.memory_manager.merge_memories(
            memory_id1=memory_id1, memory_id2=memory_id2, reason="两端记忆有重复的内容"
        )

    # Currently unused; kept for later
    async def simulate_memory_blur(self, chat_id: str, blur_rate: float = 0.2):
        """Simulate forgetting by refining a random share of the summarized text memories.

        Args:
            chat_id: Unused; kept for interface compatibility, since this object
                already serves exactly one chat.
            blur_rate: Share (0-1) of eligible memories to refine; at least one
                memory is refined whenever any is eligible.
        """
        # Previously this called a nonexistent self.get_memory(chat_id) and raised
        # AttributeError; the manager owned by this instance is the right store.
        all_summarized_memories = [
            item
            for item in self.memory_manager.get_all_items()
            if isinstance(item.data, str) and getattr(item, "summary", None)
        ]

        if not all_summarized_memories:
            return

        # Number of memories to blur
        blur_count = max(1, int(len(all_summarized_memories) * blur_rate))

        # Random selection without replacement
        memories_to_blur = random.sample(all_summarized_memories, min(blur_count, len(all_summarized_memories)))

        for memory_item in memories_to_blur:
            try:
                # The weaker the memory, the more aggressive the compression
                if memory_item.memory_strength > 7:
                    requirement = "保留所有重要信息,仅略微精简"
                elif memory_item.memory_strength > 4:
                    requirement = "保留核心要点,适度精简细节"
                else:
                    requirement = "只保留最关键的1-2个要点,大幅精简内容"

                await self.memory_manager.refine_memory(memory_item.id, requirement)
                logger.info(f"已模糊记忆 {memory_item.id},强度: {memory_item.memory_strength}, 要求: {requirement}")

            except Exception as e:
                logger.error(f"模糊记忆 {memory_item.id} 时出错: {str(e)}")

    async def shutdown(self) -> None:
        """Cancel the background decay task and wait for it to finish."""
        if self.decay_task and not self.decay_task.done():
            self.decay_task.cancel()
            try:
                await self.decay_task
            except asyncio.CancelledError:
                pass

    def get_all_memories(self) -> "List[MemoryItem]":
        """Return every memory item currently held by the manager."""
        return self.memory_manager.get_all_items()
--- Initialize attributes (defaults) --- self.is_group_chat: bool = False @@ -65,7 +67,7 @@ class ChattingObservation(Observation): self.oldest_messages_str = "" self.compressor_prompt = "" # TODO: API-Adapter修改标记 - self.llm_summary = LLMRequest( + self.model_summary = LLMRequest( model=global_config.model.observation, temperature=0.7, max_tokens=300, request_type="chat_observation" ) @@ -106,7 +108,7 @@ class ChattingObservation(Observation): mid_memory_str += f"{mid_memory['theme']}\n" return mid_memory_str + "现在群里正在聊:\n" + self.talking_message_str - def serch_message_by_text(self, text: str) -> Optional[MessageRecv]: + def search_message_by_text(self, text: str) -> Optional[MessageRecv]: """ 根据回复的纯文本 1. 在talking_message中查找最新的,最匹配的消息 @@ -119,12 +121,12 @@ class ChattingObservation(Observation): for message in reverse_talking_message: if message["processed_plain_text"] == text: find_msg = message - logger.debug(f"找到的锚定消息:find_msg: {find_msg}") + # logger.debug(f"找到的锚定消息:find_msg: {find_msg}") break else: similarity = difflib.SequenceMatcher(None, text, message["processed_plain_text"]).ratio() msg_list.append({"message": message, "similarity": similarity}) - logger.debug(f"对锚定消息检查:message: {message['processed_plain_text']},similarity: {similarity}") + # logger.debug(f"对锚定消息检查:message: {message['processed_plain_text']},similarity: {similarity}") if not find_msg: if msg_list: msg_list.sort(key=lambda x: x["similarity"], reverse=True) @@ -151,7 +153,7 @@ class ChattingObservation(Observation): } message_info = { - "platform": find_msg.get("platform"), + "platform": self.platform, "message_id": find_msg.get("message_id"), "time": find_msg.get("time"), "group_info": group_info, diff --git a/src/chat/heart_flow/observation/hfcloop_observation.py b/src/chat/heart_flow/observation/hfcloop_observation.py index 470671e28..82c9c879a 100644 --- a/src/chat/heart_flow/observation/hfcloop_observation.py +++ b/src/chat/heart_flow/observation/hfcloop_observation.py @@ -3,6 +3,7 @@ 
from datetime import datetime from src.common.logger_manager import get_logger from src.chat.focus_chat.heartFC_Cycleinfo import CycleDetail +from src.chat.focus_chat.planners.action_manager import ActionManager from typing import List # Import the new utility function @@ -16,15 +17,17 @@ class HFCloopObservation: self.observe_id = observe_id self.last_observe_time = datetime.now().timestamp() # 初始化为当前时间 self.history_loop: List[CycleDetail] = [] + self.action_manager = ActionManager() def get_observe_info(self): return self.observe_info def add_loop_info(self, loop_info: CycleDetail): - # logger.debug(f"添加循环信息111111111111111111111111111111111111: {loop_info}") - # print(f"添加循环信息111111111111111111111111111111111111: {loop_info}") self.history_loop.append(loop_info) + def set_action_manager(self, action_manager: ActionManager): + self.action_manager = action_manager + async def observe(self): recent_active_cycles: List[CycleDetail] = [] for cycle in reversed(self.history_loop): @@ -62,7 +65,6 @@ class HFCloopObservation: if cycle_info_block: cycle_info_block = f"\n你最近的回复\n{cycle_info_block}\n" else: - # 如果最近的活动循环不是文本回复,或者没有活动循环 cycle_info_block = "\n" # 获取history_loop中最新添加的 @@ -72,8 +74,16 @@ class HFCloopObservation: end_time = last_loop.end_time if start_time is not None and end_time is not None: time_diff = int(end_time - start_time) - cycle_info_block += f"\n距离你上一次阅读消息已经过去了{time_diff}分钟\n" + if time_diff > 60: + cycle_info_block += f"\n距离你上一次阅读消息已经过去了{time_diff / 60}分钟\n" + else: + cycle_info_block += f"\n距离你上一次阅读消息已经过去了{time_diff}秒\n" else: - cycle_info_block += "\n无法获取上一次阅读消息的时间\n" + cycle_info_block += "\n你还没看过消息\n" + + using_actions = self.action_manager.get_using_actions() + for action_name, action_info in using_actions.items(): + action_description = action_info["description"] + cycle_info_block += f"\n你在聊天中可以使用{action_name},这个动作的描述是{action_description}\n" self.observe_info = cycle_info_block diff --git 
a/src/chat/heart_flow/observation/memory_observation.py b/src/chat/heart_flow/observation/memory_observation.py deleted file mode 100644 index 1938a47d3..000000000 --- a/src/chat/heart_flow/observation/memory_observation.py +++ /dev/null @@ -1,55 +0,0 @@ -from src.chat.heart_flow.observation.observation import Observation -from datetime import datetime -from src.common.logger_manager import get_logger -import traceback - -# Import the new utility function -from src.chat.memory_system.Hippocampus import HippocampusManager -import jieba -from typing import List - -logger = get_logger("memory") - - -class MemoryObservation(Observation): - def __init__(self, observe_id): - super().__init__(observe_id) - self.observe_info: str = "" - self.context: str = "" - self.running_memory: List[dict] = [] - - def get_observe_info(self): - for memory in self.running_memory: - self.observe_info += f"{memory['topic']}:{memory['content']}\n" - return self.observe_info - - async def observe(self): - # ---------- 2. 
获取记忆 ---------- - try: - # 从聊天内容中提取关键词 - chat_words = set(jieba.cut(self.context)) - # 过滤掉停用词和单字词 - keywords = [word for word in chat_words if len(word) > 1] - # 去重并限制数量 - keywords = list(set(keywords))[:5] - - logger.debug(f"取的关键词: {keywords}") - - # 调用记忆系统获取相关记忆 - related_memory = await HippocampusManager.get_instance().get_memory_from_topic( - valid_keywords=keywords, max_memory_num=3, max_memory_length=2, max_depth=3 - ) - - logger.debug(f"获取到的记忆: {related_memory}") - - if related_memory: - for topic, memory in related_memory: - # 将记忆添加到 running_memory - self.running_memory.append( - {"topic": topic, "content": memory, "timestamp": datetime.now().isoformat()} - ) - logger.debug(f"添加新记忆: {topic} - {memory}") - - except Exception as e: - logger.error(f"观察 记忆时出错: {e}") - logger.error(traceback.format_exc()) diff --git a/src/chat/heart_flow/observation/structure_observation.py b/src/chat/heart_flow/observation/structure_observation.py new file mode 100644 index 000000000..2732ef0b1 --- /dev/null +++ b/src/chat/heart_flow/observation/structure_observation.py @@ -0,0 +1,32 @@ +from datetime import datetime +from src.common.logger_manager import get_logger + +# Import the new utility function + +logger = get_logger("observation") + + +# 所有观察的基类 +class StructureObservation: + def __init__(self, observe_id): + self.observe_info = "" + self.observe_id = observe_id + self.last_observe_time = datetime.now().timestamp() # 初始化为当前时间 + self.history_loop = [] + self.structured_info = [] + + def get_observe_info(self): + return self.structured_info + + def add_structured_info(self, structured_info: dict): + self.structured_info.append(structured_info) + + async def observe(self): + observed_structured_infos = [] + for structured_info in self.structured_info: + if structured_info.get("ttl") > 0: + structured_info["ttl"] -= 1 + observed_structured_infos.append(structured_info) + logger.debug(f"观察到结构化信息仍旧在: {structured_info}") + + self.structured_info = observed_structured_infos 
diff --git a/src/chat/heart_flow/observation/working_observation.py b/src/chat/heart_flow/observation/working_observation.py index 27b6ab92d..7013c3a2b 100644 --- a/src/chat/heart_flow/observation/working_observation.py +++ b/src/chat/heart_flow/observation/working_observation.py @@ -2,33 +2,33 @@ # 外部世界可以是某个聊天 不同平台的聊天 也可以是任意媒体 from datetime import datetime from src.common.logger_manager import get_logger - +from src.chat.focus_chat.working_memory.working_memory import WorkingMemory +from src.chat.focus_chat.working_memory.memory_item import MemoryItem +from typing import List # Import the new utility function logger = get_logger("observation") # 所有观察的基类 -class WorkingObservation: - def __init__(self, observe_id): +class WorkingMemoryObservation: + def __init__(self, observe_id, working_memory: WorkingMemory): self.observe_info = "" self.observe_id = observe_id - self.last_observe_time = datetime.now().timestamp() # 初始化为当前时间 - self.history_loop = [] - self.structured_info = [] + self.last_observe_time = datetime.now().timestamp() + + self.working_memory = working_memory + + self.retrieved_working_memory = [] def get_observe_info(self): - return self.structured_info + return self.working_memory - def add_structured_info(self, structured_info: dict): - self.structured_info.append(structured_info) + def add_retrieved_working_memory(self, retrieved_working_memory: List[MemoryItem]): + self.retrieved_working_memory.append(retrieved_working_memory) + + def get_retrieved_working_memory(self): + return self.retrieved_working_memory async def observe(self): - observed_structured_infos = [] - for structured_info in self.structured_info: - if structured_info.get("ttl") > 0: - structured_info["ttl"] -= 1 - observed_structured_infos.append(structured_info) - logger.debug(f"观察到结构化信息仍旧在: {structured_info}") - - self.structured_info = observed_structured_infos + pass diff --git a/src/chat/memory_system/Hippocampus.py b/src/chat/memory_system/Hippocampus.py index 
d8c7c50e6..aae1721c2 100644 --- a/src/chat/memory_system/Hippocampus.py +++ b/src/chat/memory_system/Hippocampus.py @@ -10,7 +10,7 @@ import jieba import networkx as nx import numpy as np from collections import Counter -from ...common.database import db +from ...common.database.database import memory_db as db from ...chat.models.utils_model import LLMRequest from src.common.logger_manager import get_logger from src.chat.memory_system.sample_distribution import MemoryBuildScheduler # 分布生成器 @@ -193,7 +193,7 @@ class Hippocampus: def __init__(self): self.memory_graph = MemoryGraph() self.llm_topic_judge = None - self.llm_summary = None + self.model_summary = None self.entorhinal_cortex = None self.parahippocampal_gyrus = None @@ -205,7 +205,7 @@ class Hippocampus: self.entorhinal_cortex.sync_memory_from_db() # TODO: API-Adapter修改标记 self.llm_topic_judge = LLMRequest(global_config.model.topic_judge, request_type="memory") - self.llm_summary = LLMRequest(global_config.model.summary, request_type="memory") + self.model_summary = LLMRequest(global_config.model.summary, request_type="memory") def get_all_node_names(self) -> list: """获取记忆图中所有节点的名字列表""" @@ -1167,7 +1167,7 @@ class ParahippocampalGyrus: # 调用修改后的 topic_what,不再需要 time_info topic_what_prompt = self.hippocampus.topic_what(input_text, topic) try: - task = self.hippocampus.llm_summary.generate_response_async(topic_what_prompt) + task = self.hippocampus.model_summary.generate_response_async(topic_what_prompt) tasks.append((topic.strip(), task)) except Exception as e: logger.error(f"生成话题 '{topic}' 的摘要时发生错误: {e}") diff --git a/src/chat/memory_system/manually_alter_memory.py b/src/chat/memory_system/manually_alter_memory.py index ce5abbba7..9bbf59f5b 100644 --- a/src/chat/memory_system/manually_alter_memory.py +++ b/src/chat/memory_system/manually_alter_memory.py @@ -34,7 +34,7 @@ root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")) sys.path.append(root_path) from src.common.logger import 
get_module_logger # noqa E402 -from src.common.database import db # noqa E402 +from common.database.database import db # noqa E402 logger = get_module_logger("mem_alter") console = Console() diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index 0e35f6f6e..cea791de4 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -72,6 +72,7 @@ class ChatBot: message_data["message_info"]["user_info"]["user_id"] = str( message_data["message_info"]["user_info"]["user_id"] ) + # print(message_data) logger.trace(f"处理消息:{str(message_data)[:120]}...") message = MessageRecv(message_data) groupinfo = message.message_info.group_info @@ -86,12 +87,14 @@ class ChatBot: logger.trace("检测到私聊消息,检查") # 好友黑名单拦截 if userinfo.user_id not in global_config.experimental.talk_allowed_private: - logger.debug(f"用户{userinfo.user_id}没有私聊权限") + # logger.debug(f"用户{userinfo.user_id}没有私聊权限") return # 群聊黑名单拦截 + # print(groupinfo.group_id) + # print(global_config.chat_target.talk_allowed_groups) if groupinfo is not None and groupinfo.group_id not in global_config.chat_target.talk_allowed_groups: - logger.trace(f"群{groupinfo.group_id}被禁止回复") + logger.debug(f"群{groupinfo.group_id}被禁止回复") return # 确认从接口发来的message是否有自定义的prompt模板信息 diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index 53ebd5026..723d6da47 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -5,7 +5,8 @@ import copy from typing import Dict, Optional -from ...common.database import db +from ...common.database.database import db +from ...common.database.database_model import ChatStreams # 新增导入 from maim_message import GroupInfo, UserInfo from src.common.logger_manager import get_logger @@ -82,7 +83,13 @@ class ChatManager: def __init__(self): if not self._initialized: self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream - self._ensure_collection() + try: + 
db.connect(reuse_if_open=True) + # 确保 ChatStreams 表存在 + db.create_tables([ChatStreams], safe=True) + except Exception as e: + logger.error(f"数据库连接或 ChatStreams 表创建失败: {e}") + self._initialized = True # 在事件循环中启动初始化 # asyncio.create_task(self._initialize()) @@ -107,15 +114,6 @@ class ChatManager: except Exception as e: logger.error(f"聊天流自动保存失败: {str(e)}") - @staticmethod - def _ensure_collection(): - """确保数据库集合存在并创建索引""" - if "chat_streams" not in db.list_collection_names(): - db.create_collection("chat_streams") - # 创建索引 - db.chat_streams.create_index([("stream_id", 1)], unique=True) - db.chat_streams.create_index([("platform", 1), ("user_info.user_id", 1), ("group_info.group_id", 1)]) - @staticmethod def _generate_stream_id(platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None) -> str: """生成聊天流唯一ID""" @@ -151,16 +149,43 @@ class ChatManager: stream = self.streams[stream_id] # 更新用户信息和群组信息 stream.update_active_time() - stream = copy.deepcopy(stream) + stream = copy.deepcopy(stream) # 返回副本以避免外部修改影响缓存 stream.user_info = user_info if group_info: stream.group_info = group_info return stream # 检查数据库中是否存在 - data = db.chat_streams.find_one({"stream_id": stream_id}) - if data: - stream = ChatStream.from_dict(data) + def _db_find_stream_sync(s_id: str): + return ChatStreams.get_or_none(ChatStreams.stream_id == s_id) + + model_instance = await asyncio.to_thread(_db_find_stream_sync, stream_id) + + if model_instance: + # 从 Peewee 模型转换回 ChatStream.from_dict 期望的格式 + user_info_data = { + "platform": model_instance.user_platform, + "user_id": model_instance.user_id, + "user_nickname": model_instance.user_nickname, + "user_cardname": model_instance.user_cardname or "", + } + group_info_data = None + if model_instance.group_id: # 假设 group_id 为空字符串表示没有群组信息 + group_info_data = { + "platform": model_instance.group_platform, + "group_id": model_instance.group_id, + "group_name": model_instance.group_name, + } + + data_for_from_dict = { + "stream_id": 
model_instance.stream_id, + "platform": model_instance.platform, + "user_info": user_info_data, + "group_info": group_info_data, + "create_time": model_instance.create_time, + "last_active_time": model_instance.last_active_time, + } + stream = ChatStream.from_dict(data_for_from_dict) # 更新用户信息和群组信息 stream.user_info = user_info if group_info: @@ -175,7 +200,7 @@ class ChatManager: group_info=group_info, ) except Exception as e: - logger.error(f"创建聊天流失败: {e}") + logger.error(f"获取或创建聊天流失败: {e}", exc_info=True) raise e # 保存到内存和数据库 @@ -205,15 +230,38 @@ class ChatManager: elif stream.user_info and stream.user_info.user_nickname: return f"{stream.user_info.user_nickname}的私聊" else: - # 如果没有群名或用户昵称,返回 None 或其他默认值 return None @staticmethod async def _save_stream(stream: ChatStream): """保存聊天流到数据库""" if not stream.saved: - db.chat_streams.update_one({"stream_id": stream.stream_id}, {"$set": stream.to_dict()}, upsert=True) - stream.saved = True + stream_data_dict = stream.to_dict() + + def _db_save_stream_sync(s_data_dict: dict): + user_info_d = s_data_dict.get("user_info") + group_info_d = s_data_dict.get("group_info") + + fields_to_save = { + "platform": s_data_dict["platform"], + "create_time": s_data_dict["create_time"], + "last_active_time": s_data_dict["last_active_time"], + "user_platform": user_info_d["platform"] if user_info_d else "", + "user_id": user_info_d["user_id"] if user_info_d else "", + "user_nickname": user_info_d["user_nickname"] if user_info_d else "", + "user_cardname": user_info_d.get("user_cardname", "") if user_info_d else None, + "group_platform": group_info_d["platform"] if group_info_d else "", + "group_id": group_info_d["group_id"] if group_info_d else "", + "group_name": group_info_d["group_name"] if group_info_d else "", + } + + ChatStreams.replace(stream_id=s_data_dict["stream_id"], **fields_to_save).execute() + + try: + await asyncio.to_thread(_db_save_stream_sync, stream_data_dict) + stream.saved = True + except Exception as e: + 
logger.error(f"保存聊天流 {stream.stream_id} 到数据库失败 (Peewee): {e}", exc_info=True) async def _save_all_streams(self): """保存所有聊天流""" @@ -222,10 +270,44 @@ class ChatManager: async def load_all_streams(self): """从数据库加载所有聊天流""" - all_streams = db.chat_streams.find({}) - for data in all_streams: - stream = ChatStream.from_dict(data) - self.streams[stream.stream_id] = stream + + def _db_load_all_streams_sync(): + loaded_streams_data = [] + for model_instance in ChatStreams.select(): + user_info_data = { + "platform": model_instance.user_platform, + "user_id": model_instance.user_id, + "user_nickname": model_instance.user_nickname, + "user_cardname": model_instance.user_cardname or "", + } + group_info_data = None + if model_instance.group_id: + group_info_data = { + "platform": model_instance.group_platform, + "group_id": model_instance.group_id, + "group_name": model_instance.group_name, + } + + data_for_from_dict = { + "stream_id": model_instance.stream_id, + "platform": model_instance.platform, + "user_info": user_info_data, + "group_info": group_info_data, + "create_time": model_instance.create_time, + "last_active_time": model_instance.last_active_time, + } + loaded_streams_data.append(data_for_from_dict) + return loaded_streams_data + + try: + all_streams_data_list = await asyncio.to_thread(_db_load_all_streams_sync) + self.streams.clear() + for data in all_streams_data_list: + stream = ChatStream.from_dict(data) + stream.saved = True + self.streams[stream.stream_id] = stream + except Exception as e: + logger.error(f"从数据库加载所有聊天流失败 (Peewee): {e}", exc_info=True) # 创建全局单例 diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py index cae029a11..d0041cd51 100644 --- a/src/chat/message_receive/storage.py +++ b/src/chat/message_receive/storage.py @@ -1,9 +1,10 @@ import re from typing import Union -from ...common.database import db +# from ...common.database.database import db # db is now Peewee's SqliteDatabase instance from .message import 
MessageSending, MessageRecv from .chat_stream import ChatStream +from ...common.database.database_model import Messages, RecalledMessages # Import Peewee models from src.common.logger import get_module_logger logger = get_module_logger("message_storage") @@ -29,42 +30,66 @@ class MessageStorage: else: filtered_detailed_plain_text = "" - message_data = { - "message_id": message.message_info.message_id, - "time": message.message_info.time, - "chat_id": chat_stream.stream_id, - "chat_info": chat_stream.to_dict(), - "user_info": message.message_info.user_info.to_dict(), - # 使用过滤后的文本 - "processed_plain_text": filtered_processed_plain_text, - "detailed_plain_text": filtered_detailed_plain_text, - "memorized_times": message.memorized_times, - } - db.messages.insert_one(message_data) + chat_info_dict = chat_stream.to_dict() + user_info_dict = message.message_info.user_info.to_dict() + + # message_id 现在是 TextField,直接使用字符串值 + msg_id = message.message_info.message_id + + # 安全地获取 group_info, 如果为 None 则视为空字典 + group_info_from_chat = chat_info_dict.get("group_info") or {} + # 安全地获取 user_info, 如果为 None 则视为空字典 (以防万一) + user_info_from_chat = chat_info_dict.get("user_info") or {} + + Messages.create( + message_id=msg_id, + time=float(message.message_info.time), + chat_id=chat_stream.stream_id, + # Flattened chat_info + chat_info_stream_id=chat_info_dict.get("stream_id"), + chat_info_platform=chat_info_dict.get("platform"), + chat_info_user_platform=user_info_from_chat.get("platform"), + chat_info_user_id=user_info_from_chat.get("user_id"), + chat_info_user_nickname=user_info_from_chat.get("user_nickname"), + chat_info_user_cardname=user_info_from_chat.get("user_cardname"), + chat_info_group_platform=group_info_from_chat.get("platform"), + chat_info_group_id=group_info_from_chat.get("group_id"), + chat_info_group_name=group_info_from_chat.get("group_name"), + chat_info_create_time=float(chat_info_dict.get("create_time", 0.0)), + 
chat_info_last_active_time=float(chat_info_dict.get("last_active_time", 0.0)), + # Flattened user_info (message sender) + user_platform=user_info_dict.get("platform"), + user_id=user_info_dict.get("user_id"), + user_nickname=user_info_dict.get("user_nickname"), + user_cardname=user_info_dict.get("user_cardname"), + # Text content + processed_plain_text=filtered_processed_plain_text, + detailed_plain_text=filtered_detailed_plain_text, + memorized_times=message.memorized_times, + ) except Exception: logger.exception("存储消息失败") @staticmethod async def store_recalled_message(message_id: str, time: str, chat_stream: ChatStream) -> None: """存储撤回消息到数据库""" - if "recalled_messages" not in db.list_collection_names(): - db.create_collection("recalled_messages") - else: - try: - message_data = { - "message_id": message_id, - "time": time, - "stream_id": chat_stream.stream_id, - } - db.recalled_messages.insert_one(message_data) - except Exception: - logger.exception("存储撤回消息失败") + # Table creation is handled by initialize_database in database_model.py + try: + RecalledMessages.create( + message_id=message_id, + time=float(time), # Assuming time is a string representing a float timestamp + stream_id=chat_stream.stream_id, + ) + except Exception: + logger.exception("存储撤回消息失败") @staticmethod async def remove_recalled_message(time: str) -> None: """删除撤回消息""" try: - db.recalled_messages.delete_many({"time": {"$lt": time - 300}}) + # Assuming input 'time' is a string timestamp that can be converted to float + current_time_float = float(time) + RecalledMessages.delete().where(RecalledMessages.time < (current_time_float - 300)).execute() except Exception: logger.exception("删除撤回消息失败") diff --git a/src/chat/models/utils_model.py b/src/chat/models/utils_model.py index a161ae4d9..f6528856d 100644 --- a/src/chat/models/utils_model.py +++ b/src/chat/models/utils_model.py @@ -12,7 +12,8 @@ import base64 from PIL import Image import io import os -from ...common.database import db +from 
src.common.database.database import db # 确保 db 被导入用于 create_tables +from src.common.database.database_model import LLMUsage # 导入 LLMUsage 模型 from ...config.config import global_config from rich.traceback import install @@ -85,8 +86,6 @@ async def _safely_record(request_content: Dict[str, Any], payload: Dict[str, Any f"data:image/{image_format.lower() if image_format else 'jpeg'};base64," f"{image_base64[:10]}...{image_base64[-10:]}" ) - # if isinstance(content, str) and len(content) > 100: - # payload["messages"][0]["content"] = content[:100] return payload @@ -134,13 +133,11 @@ class LLMRequest: def _init_database(): """初始化数据库集合""" try: - # 创建llm_usage集合的索引 - db.llm_usage.create_index([("timestamp", 1)]) - db.llm_usage.create_index([("model_name", 1)]) - db.llm_usage.create_index([("user_id", 1)]) - db.llm_usage.create_index([("request_type", 1)]) + # 使用 Peewee 创建表,safe=True 表示如果表已存在则不会抛出错误 + db.create_tables([LLMUsage], safe=True) + logger.debug("LLMUsage 表已初始化/确保存在。") except Exception as e: - logger.error(f"创建数据库索引失败: {str(e)}") + logger.error(f"创建 LLMUsage 表失败: {str(e)}") def _record_usage( self, @@ -165,19 +162,19 @@ class LLMRequest: request_type = self.request_type try: - usage_data = { - "model_name": self.model_name, - "user_id": user_id, - "request_type": request_type, - "endpoint": endpoint, - "prompt_tokens": prompt_tokens, - "completion_tokens": completion_tokens, - "total_tokens": total_tokens, - "cost": self._calculate_cost(prompt_tokens, completion_tokens), - "status": "success", - "timestamp": datetime.now(), - } - db.llm_usage.insert_one(usage_data) + # 使用 Peewee 模型创建记录 + LLMUsage.create( + model_name=self.model_name, + user_id=user_id, + request_type=request_type, + endpoint=endpoint, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + cost=self._calculate_cost(prompt_tokens, completion_tokens), + status="success", + timestamp=datetime.now(), # Peewee 会处理 DateTimeField + ) logger.trace( f"Token使用情况 
- 模型: {self.model_name}, " f"用户: {user_id}, 类型: {request_type}, " diff --git a/src/chat/person_info/person_info.py b/src/chat/person_info/person_info.py index aadbb1d2e..562cdc235 100644 --- a/src/chat/person_info/person_info.py +++ b/src/chat/person_info/person_info.py @@ -1,5 +1,6 @@ from src.common.logger_manager import get_logger -from ...common.database import db +from ...common.database.database import db +from ...common.database.database_model import PersonInfo # 新增导入 import copy import hashlib from typing import Any, Callable, Dict @@ -16,7 +17,7 @@ matplotlib.use("Agg") import matplotlib.pyplot as plt from pathlib import Path import pandas as pd -import json +import json # 新增导入 import re @@ -38,22 +39,18 @@ logger = get_logger("person_info") person_info_default = { "person_id": None, - "person_name": None, + "person_name": None, # 模型中已设为 null=True,此默认值OK "name_reason": None, - "platform": None, - "user_id": None, - "nickname": None, - # "age" : 0, + "platform": "unknown", # 提供非None的默认值 + "user_id": "unknown", # 提供非None的默认值 + "nickname": "Unknown", # 提供非None的默认值 "relationship_value": 0, - # "saved" : True, - # "impression" : None, - # "gender" : Unkown, - "konw_time": 0, + "know_time": 0, # 修正拼写:konw_time -> know_time "msg_interval": 2000, - "msg_interval_list": [], - "user_cardname": None, # 添加群名片 - "user_avatar": None, # 添加头像信息(例如URL或标识符) -} # 个人信息的各项与默认值在此定义,以下处理会自动创建/补全每一项 + "msg_interval_list": [], # 将作为 JSON 字符串存储在 Peewee 的 TextField + "user_cardname": None, # 注意:此字段不在 PersonInfo Peewee 模型中 + "user_avatar": None, # 注意:此字段不在 PersonInfo Peewee 模型中 +} class PersonInfoManager: @@ -65,21 +62,26 @@ class PersonInfoManager: max_tokens=256, request_type="qv_name", ) - if "person_info" not in db.list_collection_names(): - db.create_collection("person_info") - db.person_info.create_index("person_id", unique=True) + try: + db.connect(reuse_if_open=True) + db.create_tables([PersonInfo], safe=True) + except Exception as e: + logger.error(f"数据库连接或 PersonInfo 表创建失败: 
{e}") # 初始化时读取所有person_name - cursor = db.person_info.find({"person_name": {"$exists": True}}, {"person_id": 1, "person_name": 1, "_id": 0}) - for doc in cursor: - if doc.get("person_name"): - self.person_name_list[doc["person_id"]] = doc["person_name"] - logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称") + try: + for record in PersonInfo.select(PersonInfo.person_id, PersonInfo.person_name).where( + PersonInfo.person_name.is_null(False) + ): + if record.person_name: + self.person_name_list[record.person_id] = record.person_name + logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称 (Peewee)") + except Exception as e: + logger.error(f"从 Peewee 加载 person_name_list 失败: {e}") @staticmethod def get_person_id(platform: str, user_id: int): """获取唯一id""" - # 如果platform中存在-,就截取-后面的部分 if "-" in platform: platform = platform.split("-")[1] @@ -87,15 +89,27 @@ class PersonInfoManager: key = "_".join(components) return hashlib.md5(key.encode()).hexdigest() - def is_person_known(self, platform: str, user_id: int): + async def is_person_known(self, platform: str, user_id: int): """判断是否认识某人""" person_id = self.get_person_id(platform, user_id) - document = db.person_info.find_one({"person_id": person_id}) - if document: - return True - else: + + def _db_check_known_sync(p_id: str): + return PersonInfo.get_or_none(PersonInfo.person_id == p_id) is not None + + try: + return await asyncio.to_thread(_db_check_known_sync, person_id) + except Exception as e: + logger.error(f"检查用户 {person_id} 是否已知时出错 (Peewee): {e}") return False + def get_person_id_by_person_name(self, person_name: str): + """根据用户名获取用户ID""" + document = db.person_info.find_one({"person_name": person_name}) + if document: + return document["person_id"] + else: + return "" + @staticmethod async def create_person_info(person_id: str, data: dict = None): """创建一个项""" @@ -104,73 +118,111 @@ class PersonInfoManager: return _person_info_default = copy.deepcopy(person_info_default) - _person_info_default["person_id"] = 
person_id + model_fields = PersonInfo._meta.fields.keys() + + final_data = {"person_id": person_id} if data: - for key in _person_info_default: - if key != "person_id" and key in data: - _person_info_default[key] = data[key] + for key, value in data.items(): + if key in model_fields: + final_data[key] = value - db.person_info.insert_one(_person_info_default) + for key, default_value in _person_info_default.items(): + if key in model_fields and key not in final_data: + final_data[key] = default_value + + if "msg_interval_list" in final_data and isinstance(final_data["msg_interval_list"], list): + final_data["msg_interval_list"] = json.dumps(final_data["msg_interval_list"]) + elif "msg_interval_list" not in final_data and "msg_interval_list" in model_fields: + final_data["msg_interval_list"] = json.dumps([]) + + def _db_create_sync(p_data: dict): + try: + PersonInfo.create(**p_data) + return True + except Exception as e: + logger.error(f"创建 PersonInfo 记录 {p_data.get('person_id')} 失败 (Peewee): {e}") + return False + + await asyncio.to_thread(_db_create_sync, final_data) async def update_one_field(self, person_id: str, field_name: str, value, data: dict = None): """更新某一个字段,会补全""" - if field_name not in person_info_default.keys(): - logger.debug(f"更新'{field_name}'失败,未定义的字段") + if field_name not in PersonInfo._meta.fields: + if field_name in person_info_default: + logger.debug(f"更新'{field_name}'跳过,字段存在于默认配置但不在 PersonInfo Peewee 模型中。") + return + logger.debug(f"更新'{field_name}'失败,未在 PersonInfo Peewee 模型中定义的字段。") return - document = db.person_info.find_one({"person_id": person_id}) + def _db_update_sync(p_id: str, f_name: str, val): + record = PersonInfo.get_or_none(PersonInfo.person_id == p_id) + if record: + if f_name == "msg_interval_list" and isinstance(val, list): + setattr(record, f_name, json.dumps(val)) + else: + setattr(record, f_name, val) + record.save() + return True, False + return False, True - if document: - db.person_info.update_one({"person_id": 
person_id}, {"$set": {field_name: value}}) - else: - data[field_name] = value - logger.debug(f"更新时{person_id}不存在,已新建") - await self.create_person_info(person_id, data) + found, needs_creation = await asyncio.to_thread(_db_update_sync, person_id, field_name, value) + + if needs_creation: + logger.debug(f"更新时 {person_id} 不存在,将新建。") + creation_data = data if data is not None else {} + creation_data[field_name] = value + if "platform" not in creation_data or "user_id" not in creation_data: + logger.warning(f"为 {person_id} 创建记录时,platform/user_id 可能缺失。") + + await self.create_person_info(person_id, creation_data) @staticmethod async def has_one_field(person_id: str, field_name: str): """判断是否存在某一个字段""" - document = db.person_info.find_one({"person_id": person_id}, {field_name: 1}) - if document: - return True - else: + if field_name not in PersonInfo._meta.fields: + logger.debug(f"检查字段'{field_name}'失败,未在 PersonInfo Peewee 模型中定义。") + return False + + def _db_has_field_sync(p_id: str, f_name: str): + record = PersonInfo.get_or_none(PersonInfo.person_id == p_id) + if record: + return True + return False + + try: + return await asyncio.to_thread(_db_has_field_sync, person_id, field_name) + except Exception as e: + logger.error(f"检查字段 {field_name} for {person_id} 时出错 (Peewee): {e}") return False @staticmethod def _extract_json_from_text(text: str) -> dict: """从文本中提取JSON数据的高容错方法""" try: - # 尝试直接解析 parsed_json = json.loads(text) - # 如果解析结果是列表,尝试取第一个元素 if isinstance(parsed_json, list): - if parsed_json: # 检查列表是否为空 + if parsed_json: parsed_json = parsed_json[0] - else: # 如果列表为空,重置为 None,走后续逻辑 + else: parsed_json = None - # 确保解析结果是字典 if isinstance(parsed_json, dict): return parsed_json except json.JSONDecodeError: - # 解析失败,继续尝试其他方法 pass except Exception as e: logger.warning(f"尝试直接解析JSON时发生意外错误: {e}") - pass # 继续尝试其他方法 + pass - # 如果直接解析失败或结果不是字典 try: - # 尝试找到JSON对象格式的部分 json_pattern = r"\{[^{}]*\}" matches = re.findall(json_pattern, text) if matches: parsed_obj = 
json.loads(matches[0]) - if isinstance(parsed_obj, dict): # 确保是字典 + if isinstance(parsed_obj, dict): return parsed_obj - # 如果上面都失败了,尝试提取键值对 nickname_pattern = r'"nickname"[:\s]+"([^"]+)"' reason_pattern = r'"reason"[:\s]+"([^"]+)"' @@ -185,7 +237,6 @@ class PersonInfoManager: except Exception as e: logger.error(f"后备JSON提取失败: {str(e)}") - # 如果所有方法都失败了,返回默认字典 logger.warning(f"无法从文本中提取有效的JSON字典: {text}") return {"nickname": "", "reason": ""} @@ -200,9 +251,11 @@ class PersonInfoManager: old_name = await self.get_value(person_id, "person_name") old_reason = await self.get_value(person_id, "name_reason") - max_retries = 5 # 最大重试次数 + max_retries = 5 current_try = 0 - existing_names = "" + existing_names_str = "" + current_name_set = set(self.person_name_list.values()) + while current_try < max_retries: individuality = Individuality.get_instance() prompt_personality = individuality.get_prompt(x_person=2, level=1) @@ -217,45 +270,58 @@ class PersonInfoManager: qv_name_prompt += f"你之前叫他{old_name},是因为{old_reason}," qv_name_prompt += f"\n其他取名的要求是:{request},不要太浮夸" - qv_name_prompt += ( "\n请根据以上用户信息,想想你叫他什么比较好,不要太浮夸,请最好使用用户的qq昵称,可以稍作修改" ) - if existing_names: - qv_name_prompt += f"\n请注意,以下名称已被使用,不要使用以下昵称:{existing_names}。\n" + + if existing_names_str: + qv_name_prompt += f"\n请注意,以下名称已被你尝试过或已知存在,请避免:{existing_names_str}。\n" + + if len(current_name_set) < 50 and current_name_set: + qv_name_prompt += f"已知的其他昵称有: {', '.join(list(current_name_set)[:10])}等。\n" + qv_name_prompt += "请用json给出你的想法,并给出理由,示例如下:" qv_name_prompt += """{ "nickname": "昵称", "reason": "理由" }""" - # logger.debug(f"取名提示词:{qv_name_prompt}") response = await self.qv_name_llm.generate_response(qv_name_prompt) logger.trace(f"取名提示词:{qv_name_prompt}\n取名回复:{response}") result = self._extract_json_from_text(response[0]) - if not result["nickname"]: - logger.error("生成的昵称为空,重试中...") + if not result or not result.get("nickname"): + logger.error("生成的昵称为空或结果格式不正确,重试中...") current_try += 1 continue - # 检查生成的昵称是否已存在 - if 
result["nickname"] not in self.person_name_list.values(): - # 更新数据库和内存中的列表 - await self.update_one_field(person_id, "person_name", result["nickname"]) - # await self.update_one_field(person_id, "nickname", user_nickname) - # await self.update_one_field(person_id, "avatar", user_avatar) - await self.update_one_field(person_id, "name_reason", result["reason"]) + generated_nickname = result["nickname"] - self.person_name_list[person_id] = result["nickname"] - # logger.debug(f"用户 {person_id} 的名称已更新为 {result['nickname']},原因:{result['reason']}") + is_duplicate = False + if generated_nickname in current_name_set: + is_duplicate = True + else: + + def _db_check_name_exists_sync(name_to_check): + return PersonInfo.select().where(PersonInfo.person_name == name_to_check).exists() + + if await asyncio.to_thread(_db_check_name_exists_sync, generated_nickname): + is_duplicate = True + current_name_set.add(generated_nickname) + + if not is_duplicate: + await self.update_one_field(person_id, "person_name", generated_nickname) + await self.update_one_field(person_id, "name_reason", result.get("reason", "未提供理由")) + + self.person_name_list[person_id] = generated_nickname return result else: - existing_names += f"{result['nickname']}、" + if existing_names_str: + existing_names_str += "、" + existing_names_str += generated_nickname + logger.debug(f"生成的昵称 {generated_nickname} 已存在,重试中...") + current_try += 1 - logger.debug(f"生成的昵称 {result['nickname']} 已存在,重试中...") - current_try += 1 - - logger.error(f"在{max_retries}次尝试后仍未能生成唯一昵称") + logger.error(f"在{max_retries}次尝试后仍未能生成唯一昵称 for {person_id}") return None @staticmethod @@ -265,30 +331,56 @@ class PersonInfoManager: logger.debug("删除失败:person_id 不能为空") return - result = db.person_info.delete_one({"person_id": person_id}) - if result.deleted_count > 0: - logger.debug(f"删除成功:person_id={person_id}") + def _db_delete_sync(p_id: str): + try: + query = PersonInfo.delete().where(PersonInfo.person_id == p_id) + deleted_count = query.execute() + 
return deleted_count + except Exception as e: + logger.error(f"删除 PersonInfo {p_id} 失败 (Peewee): {e}") + return 0 + + deleted_count = await asyncio.to_thread(_db_delete_sync, person_id) + + if deleted_count > 0: + logger.debug(f"删除成功:person_id={person_id} (Peewee)") else: - logger.debug(f"删除失败:未找到 person_id={person_id}") + logger.debug(f"删除失败:未找到 person_id={person_id} 或删除未影响行 (Peewee)") @staticmethod async def get_value(person_id: str, field_name: str): """获取指定person_id文档的字段值,若不存在该字段,则返回该字段的全局默认值""" if not person_id: logger.debug("get_value获取失败:person_id不能为空") + return person_info_default.get(field_name) + + if field_name not in PersonInfo._meta.fields: + if field_name in person_info_default: + logger.trace(f"字段'{field_name}'不在Peewee模型中,但存在于默认配置中。返回配置默认值。") + return copy.deepcopy(person_info_default[field_name]) + logger.debug(f"get_value获取失败:字段'{field_name}'未在Peewee模型和默认配置中定义。") return None - if field_name not in person_info_default: - logger.debug(f"get_value获取失败:字段'{field_name}'未定义") + def _db_get_value_sync(p_id: str, f_name: str): + record = PersonInfo.get_or_none(PersonInfo.person_id == p_id) + if record: + val = getattr(record, f_name) + if f_name == "msg_interval_list" and isinstance(val, str): + try: + return json.loads(val) + except json.JSONDecodeError: + logger.warning(f"无法解析 {p_id} 的 msg_interval_list JSON: {val}") + return copy.deepcopy(person_info_default.get(f_name, [])) + return val return None - document = db.person_info.find_one({"person_id": person_id}, {field_name: 1}) + value = await asyncio.to_thread(_db_get_value_sync, person_id, field_name) - if document and field_name in document: - return document[field_name] + if value is not None: + return value else: - default_value = copy.deepcopy(person_info_default[field_name]) - logger.trace(f"获取{person_id}的{field_name}失败,已返回默认值{default_value}") + default_value = copy.deepcopy(person_info_default.get(field_name)) + logger.trace(f"获取{person_id}的{field_name}失败或值为None,已返回默认值{default_value} (Peewee)") 
return default_value @staticmethod @@ -298,93 +390,84 @@ class PersonInfoManager: logger.debug("get_values获取失败:person_id不能为空") return {} - # 检查所有字段是否有效 - for field in field_names: - if field not in person_info_default: - logger.debug(f"get_values获取失败:字段'{field}'未定义") - return {} - - # 构建查询投影(所有字段都有效才会执行到这里) - projection = {field: 1 for field in field_names} - - document = db.person_info.find_one({"person_id": person_id}, projection) - result = {} - for field in field_names: - result[field] = copy.deepcopy( - document.get(field, person_info_default[field]) if document else person_info_default[field] - ) + + def _db_get_record_sync(p_id: str): + return PersonInfo.get_or_none(PersonInfo.person_id == p_id) + + record = await asyncio.to_thread(_db_get_record_sync, person_id) + + for field_name in field_names: + if field_name not in PersonInfo._meta.fields: + if field_name in person_info_default: + result[field_name] = copy.deepcopy(person_info_default[field_name]) + logger.trace(f"字段'{field_name}'不在Peewee模型中,使用默认配置值。") + else: + logger.debug(f"get_values查询失败:字段'{field_name}'未在Peewee模型和默认配置中定义。") + result[field_name] = None + continue + + if record: + value = getattr(record, field_name) + if field_name == "msg_interval_list" and isinstance(value, str): + try: + result[field_name] = json.loads(value) + except json.JSONDecodeError: + logger.warning(f"无法解析 {person_id} 的 msg_interval_list JSON: {value}") + result[field_name] = copy.deepcopy(person_info_default.get(field_name, [])) + elif value is not None: + result[field_name] = value + else: + result[field_name] = copy.deepcopy(person_info_default.get(field_name)) + else: + result[field_name] = copy.deepcopy(person_info_default.get(field_name)) return result @staticmethod async def del_all_undefined_field(): - """删除所有项里的未定义字段""" - # 获取所有已定义的字段名 - defined_fields = set(person_info_default.keys()) - - try: - # 遍历集合中的所有文档 - for document in db.person_info.find({}): - # 找出文档中未定义的字段 - undefined_fields = set(document.keys()) - 
defined_fields - {"_id"} - - if undefined_fields: - # 构建更新操作,使用$unset删除未定义字段 - update_result = db.person_info.update_one( - {"_id": document["_id"]}, {"$unset": {field: 1 for field in undefined_fields}} - ) - - if update_result.modified_count > 0: - logger.debug(f"已清理文档 {document['_id']} 的未定义字段: {undefined_fields}") - - return - - except Exception as e: - logger.error(f"清理未定义字段时出错: {e}") - return + """删除所有项里的未定义字段 - 对于Peewee (SQL),此操作通常不适用,因为模式是固定的。""" + logger.info( + "del_all_undefined_field: 对于使用Peewee的SQL数据库,此操作通常不适用或不需要,因为表结构是预定义的。" + ) + return @staticmethod async def get_specific_value_list( field_name: str, - way: Callable[[Any], bool], # 接受任意类型值 + way: Callable[[Any], bool], ) -> Dict[str, Any]: """ 获取满足条件的字段值字典 - - Args: - field_name: 目标字段名 - way: 判断函数 (value: Any) -> bool - - Returns: - {person_id: value} | {} - - Example: - # 查找所有nickname包含"admin"的用户 - result = manager.specific_value_list( - "nickname", - lambda x: "admin" in x.lower() - ) """ - if field_name not in person_info_default: - logger.error(f"字段检查失败:'{field_name}'未定义") + if field_name not in PersonInfo._meta.fields: + logger.error(f"字段检查失败:'{field_name}'未在 PersonInfo Peewee 模型中定义") return {} + def _db_get_specific_sync(f_name: str): + found_results = {} + try: + for record in PersonInfo.select(PersonInfo.person_id, getattr(PersonInfo, f_name)): + value = getattr(record, f_name) + if f_name == "msg_interval_list" and isinstance(value, str): + try: + processed_value = json.loads(value) + except json.JSONDecodeError: + logger.warning(f"跳过记录 {record.person_id},无法解析 msg_interval_list: {value}") + continue + else: + processed_value = value + + if way(processed_value): + found_results[record.person_id] = processed_value + except Exception as e_query: + logger.error(f"数据库查询失败 (Peewee specific_value_list for {f_name}): {str(e_query)}", exc_info=True) + return found_results + try: - result = {} - for doc in db.person_info.find({field_name: {"$exists": True}}, {"person_id": 1, field_name: 1, "_id": 0}): 
- try: - value = doc[field_name] - if way(value): - result[doc["person_id"]] = value - except (KeyError, TypeError, ValueError) as e: - logger.debug(f"记录{doc.get('person_id')}处理失败: {str(e)}") - continue - - return result - + return await asyncio.to_thread(_db_get_specific_sync, field_name) except Exception as e: - logger.error(f"数据库查询失败: {str(e)}", exc_info=True) + logger.error(f"执行 get_specific_value_list 线程时出错: {str(e)}", exc_info=True) return {} async def personal_habit_deduction(self): @@ -392,35 +475,31 @@ class PersonInfoManager: try: while 1: await asyncio.sleep(600) - current_time = datetime.datetime.now() - logger.info(f"个人信息推断启动: {current_time.strftime('%Y-%m-%d %H:%M:%S')}") + current_time_dt = datetime.datetime.now() + logger.info(f"个人信息推断启动: {current_time_dt.strftime('%Y-%m-%d %H:%M:%S')}") - # "msg_interval"推断 - msg_interval_map = False - msg_interval_lists = await self.get_specific_value_list( + msg_interval_map_generated = False + msg_interval_lists_map = await self.get_specific_value_list( "msg_interval_list", lambda x: isinstance(x, list) and len(x) >= 100 ) - for person_id, msg_interval_list_ in msg_interval_lists.items(): + + for person_id, actual_msg_interval_list in msg_interval_lists_map.items(): await asyncio.sleep(0.3) try: time_interval = [] - for t1, t2 in zip(msg_interval_list_, msg_interval_list_[1:]): + for t1, t2 in zip(actual_msg_interval_list, actual_msg_interval_list[1:]): delta = t2 - t1 if delta > 0: time_interval.append(delta) time_interval = [t for t in time_interval if 200 <= t <= 8000] - # --- 修改后的逻辑 --- - # 数据量检查 (至少需要 30 条有效间隔,并且足够进行头尾截断) - if len(time_interval) >= 30 + 10: # 至少30条有效+头尾各5条 - time_interval.sort() - # 画图(log) - 这部分保留 - msg_interval_map = True + if len(time_interval) >= 30 + 10: + time_interval.sort() + msg_interval_map_generated = True log_dir = Path("logs/person_info") log_dir.mkdir(parents=True, exist_ok=True) plt.figure(figsize=(10, 6)) - # 使用截断前的数据画图,更能反映原始分布 time_series_original = 
pd.Series(time_interval) plt.hist( time_series_original, @@ -442,34 +521,29 @@ class PersonInfoManager: img_path = log_dir / f"interval_distribution_{person_id[:8]}.png" plt.savefig(img_path) plt.close() - # 画图结束 - # 去掉头尾各 5 个数据点 trimmed_interval = time_interval[5:-5] - - # 计算截断后数据的 37% 分位数 - if trimmed_interval: # 确保截断后列表不为空 - msg_interval = int(round(np.percentile(trimmed_interval, 37))) - # 更新数据库 - await self.update_one_field(person_id, "msg_interval", msg_interval) - logger.trace(f"用户{person_id}的msg_interval通过头尾截断和37分位数更新为{msg_interval}") + if trimmed_interval: + msg_interval_val = int(round(np.percentile(trimmed_interval, 37))) + await self.update_one_field(person_id, "msg_interval", msg_interval_val) + logger.trace( + f"用户{person_id}的msg_interval通过头尾截断和37分位数更新为{msg_interval_val}" + ) else: logger.trace(f"用户{person_id}截断后数据为空,无法计算msg_interval") else: logger.trace( f"用户{person_id}有效消息间隔数量 ({len(time_interval)}) 不足进行推断 (需要至少 {30 + 10} 条)" ) - # --- 修改结束 --- - except Exception as e: - logger.trace(f"用户{person_id}消息间隔计算失败: {type(e).__name__}: {str(e)}") + except Exception as e_inner: + logger.trace(f"用户{person_id}消息间隔计算失败: {type(e_inner).__name__}: {str(e_inner)}") continue - # 其他... 
- - if msg_interval_map: + if msg_interval_map_generated: logger.trace("已保存分布图到: logs/person_info") - current_time = datetime.datetime.now() - logger.trace(f"个人信息推断结束: {current_time.strftime('%Y-%m-%d %H:%M:%S')}") + + current_time_dt_end = datetime.datetime.now() + logger.trace(f"个人信息推断结束: {current_time_dt_end.strftime('%Y-%m-%d %H:%M:%S')}") await asyncio.sleep(86400) except Exception as e: @@ -482,41 +556,27 @@ class PersonInfoManager: """ 根据 platform 和 user_id 获取 person_id。 如果对应的用户不存在,则使用提供的可选信息创建新用户。 - - Args: - platform: 平台标识 - user_id: 用户在该平台上的ID - nickname: 用户的昵称 (可选,用于创建新用户) - user_cardname: 用户的群名片 (可选,用于创建新用户) - user_avatar: 用户的头像信息 (可选,用于创建新用户) - - Returns: - 对应的 person_id。 """ person_id = self.get_person_id(platform, user_id) - # 检查用户是否已存在 - # 使用静态方法 get_person_id,因此可以直接调用 db - document = db.person_info.find_one({"person_id": person_id}) + def _db_check_exists_sync(p_id: str): + return PersonInfo.get_or_none(PersonInfo.person_id == p_id) - if document is None: - logger.info(f"用户 {platform}:{user_id} (person_id: {person_id}) 不存在,将创建新记录。") + record = await asyncio.to_thread(_db_check_exists_sync, person_id) + + if record is None: + logger.info(f"用户 {platform}:{user_id} (person_id: {person_id}) 不存在,将创建新记录 (Peewee)。") initial_data = { "platform": platform, - "user_id": user_id, + "user_id": str(user_id), "nickname": nickname, - "konw_time": int(datetime.datetime.now().timestamp()), # 添加初次认识时间 - # 注意:这里没有添加 user_cardname 和 user_avatar,因为它们不在 person_info_default 中 - # 如果需要存储它们,需要先在 person_info_default 中定义 + "know_time": int(datetime.datetime.now().timestamp()), # 修正拼写:konw_time -> know_time } - # 过滤掉值为 None 的初始数据 - initial_data = {k: v for k, v in initial_data.items() if v is not None} + model_fields = PersonInfo._meta.fields.keys() + filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields} - # 注意:create_person_info 是静态方法 - await PersonInfoManager.create_person_info(person_id, data=initial_data) - # 
创建后,可以考虑立即为其取名,但这可能会增加延迟 - # await self.qv_person_name(person_id, nickname, user_cardname, user_avatar) - logger.debug(f"已为 {person_id} 创建新记录,初始数据: {initial_data}") + await self.create_person_info(person_id, data=filtered_initial_data) + logger.debug(f"已为 {person_id} 创建新记录,初始数据 (filtered for model): {filtered_initial_data}") return person_id @@ -526,35 +586,55 @@ class PersonInfoManager: logger.debug("get_person_info_by_name 获取失败:person_name 不能为空") return None - # 优先从内存缓存查找 person_id found_person_id = None - for pid, name in self.person_name_list.items(): - if name == person_name: + for pid, name_in_cache in self.person_name_list.items(): + if name_in_cache == person_name: found_person_id = pid - break # 找到第一个匹配就停止 + break if not found_person_id: - # 如果内存没有,尝试数据库查询(可能内存未及时更新或启动时未加载) - document = db.person_info.find_one({"person_name": person_name}) - if document: - found_person_id = document.get("person_id") - else: - logger.debug(f"数据库中也未找到名为 '{person_name}' 的用户") - return None # 数据库也找不到 - # 根据找到的 person_id 获取所需信息 - if found_person_id: - required_fields = ["person_id", "platform", "user_id", "nickname", "user_cardname", "user_avatar"] - person_data = await self.get_values(found_person_id, required_fields) - if person_data: # 确保 get_values 成功返回 - return person_data + def _db_find_by_name_sync(p_name_to_find: str): + return PersonInfo.get_or_none(PersonInfo.person_name == p_name_to_find) + + record = await asyncio.to_thread(_db_find_by_name_sync, person_name) + if record: + found_person_id = record.person_id + if ( + found_person_id not in self.person_name_list + or self.person_name_list[found_person_id] != person_name + ): + self.person_name_list[found_person_id] = person_name else: - logger.warning(f"找到了 person_id '{found_person_id}' 但获取详细信息失败") + logger.debug(f"数据库中也未找到名为 '{person_name}' 的用户 (Peewee)") return None - else: - # 这理论上不应该发生,因为上面已经处理了找不到的情况 - logger.error(f"逻辑错误:未能为 '{person_name}' 确定 person_id") - return None + + if found_person_id: + required_fields 
= [ + "person_id", + "platform", + "user_id", + "nickname", + "user_cardname", + "user_avatar", + "person_name", + "name_reason", + ] + valid_fields_to_get = [ + f for f in required_fields if f in PersonInfo._meta.fields or f in person_info_default + ] + + person_data = await self.get_values(found_person_id, valid_fields_to_get) + + if person_data: + final_result = {key: person_data.get(key) for key in required_fields} + return final_result + else: + logger.warning(f"找到了 person_id '{found_person_id}' 但 get_values 返回空 (Peewee)") + return None + + logger.error(f"逻辑错误:未能为 '{person_name}' 确定 person_id (Peewee)") + return None person_info_manager = PersonInfoManager() diff --git a/src/chat/person_info/relationship_manager.py b/src/chat/person_info/relationship_manager.py index c8a443857..a23780c0e 100644 --- a/src/chat/person_info/relationship_manager.py +++ b/src/chat/person_info/relationship_manager.py @@ -77,7 +77,7 @@ class RelationshipManager: @staticmethod async def is_known_some_one(platform, user_id): """判断是否认识某人""" - is_known = person_info_manager.is_person_known(platform, user_id) + is_known = await person_info_manager.is_person_known(platform, user_id) return is_known @staticmethod diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index de018bdb8..d3a062680 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -451,10 +451,10 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str: # 处理 回复 reply_pattern = r"回复<([^:<>]+):([^:<>]+)>" - def reply_replacer(match): + def reply_replacer(match, platform=platform): # aaa = match.group(1) bbb = match.group(2) - anon_reply = get_anon_name(platform, bbb) + anon_reply = get_anon_name(platform, bbb) # noqa return f"回复 {anon_reply}" content = re.sub(reply_pattern, reply_replacer, content, count=1) @@ -462,10 +462,10 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str: # 处理 @ at_pattern = 
r"@<([^:<>]+):([^:<>]+)>" - def at_replacer(match): + def at_replacer(match, platform=platform): # aaa = match.group(1) bbb = match.group(2) - anon_at = get_anon_name(platform, bbb) + anon_at = get_anon_name(platform, bbb) # noqa return f"@{anon_at}" content = re.sub(at_pattern, at_replacer, content) diff --git a/src/chat/utils/info_catcher.py b/src/chat/utils/info_catcher.py index a5b04d704..93cda5113 100644 --- a/src/chat/utils/info_catcher.py +++ b/src/chat/utils/info_catcher.py @@ -1,9 +1,10 @@ from src.config.config import global_config from src.chat.message_receive.message import MessageRecv, MessageSending, Message -from src.common.database import db +from src.common.database.database_model import Messages, ThinkingLog import time import traceback from typing import List +import json class InfoCatcher: @@ -59,8 +60,6 @@ class InfoCatcher: def catch_after_observe(self, obs_duration: float): # 这里可以有更多信息 self.timing_results["sub_heartflow_observe_time"] = obs_duration - # def catch_shf - def catch_afer_shf_step(self, step_duration: float, past_mind: str, current_mind: str): self.timing_results["sub_heartflow_step_time"] = step_duration if len(past_mind) > 1: @@ -71,25 +70,10 @@ class InfoCatcher: self.heartflow_data["sub_heartflow_now"] = current_mind def catch_after_llm_generated(self, prompt: str, response: str, reasoning_content: str = "", model_name: str = ""): - # if self.response_mode == "heart_flow": # 条件判断不需要了喵~ - # self.heartflow_data["prompt"] = prompt - # self.heartflow_data["response"] = response - # self.heartflow_data["model"] = model_name - # elif self.response_mode == "reasoning": # 条件判断不需要了喵~ - # self.reasoning_data["thinking_log"] = reasoning_content - # self.reasoning_data["prompt"] = prompt - # self.reasoning_data["response"] = response - # self.reasoning_data["model"] = model_name - - # 直接记录信息喵~ self.reasoning_data["thinking_log"] = reasoning_content self.reasoning_data["prompt"] = prompt self.reasoning_data["response"] = response 
self.reasoning_data["model"] = model_name - # 如果 heartflow 数据也需要通用字段,可以取消下面的注释喵~ - # self.heartflow_data["prompt"] = prompt - # self.heartflow_data["response"] = response - # self.heartflow_data["model"] = model_name self.response_text = response @@ -101,6 +85,7 @@ class InfoCatcher: ): self.timing_results["make_response_time"] = response_duration self.response_time = time.time() + self.response_messages = [] for msg in response_message: self.response_messages.append(msg) @@ -111,107 +96,112 @@ class InfoCatcher: @staticmethod def get_message_from_db_between_msgs(message_start: Message, message_end: Message): try: - # 从数据库中获取消息的时间戳 time_start = message_start.message_info.time time_end = message_end.message_info.time chat_id = message_start.chat_stream.stream_id print(f"查询参数: time_start={time_start}, time_end={time_end}, chat_id={chat_id}") - # 查询数据库,获取 chat_id 相同且时间在 start 和 end 之间的数据 - messages_between = db.messages.find( - {"chat_id": chat_id, "time": {"$gt": time_start, "$lt": time_end}} - ).sort("time", -1) + messages_between_query = ( + Messages.select() + .where((Messages.chat_id == chat_id) & (Messages.time > time_start) & (Messages.time < time_end)) + .order_by(Messages.time.desc()) + ) - result = list(messages_between) + result = list(messages_between_query) print(f"查询结果数量: {len(result)}") if result: - print(f"第一条消息时间: {result[0]['time']}") - print(f"最后一条消息时间: {result[-1]['time']}") + print(f"第一条消息时间: {result[0].time}") + print(f"最后一条消息时间: {result[-1].time}") return result except Exception as e: print(f"获取消息时出错: {str(e)}") + print(traceback.format_exc()) return [] def get_message_from_db_before_msg(self, message: MessageRecv): - # 从数据库中获取消息 - message_id = message.message_info.message_id - chat_id = message.chat_stream.stream_id + message_id_val = message.message_info.message_id + chat_id_val = message.chat_stream.stream_id - # 查询数据库,获取 chat_id 相同且 message_id 小于当前消息的 30 条数据 - messages_before = ( - db.messages.find({"chat_id": chat_id, "message_id": {"$lt": 
message_id}}) - .sort("time", -1) + messages_before_query = ( + Messages.select() + .where((Messages.chat_id == chat_id_val) & (Messages.message_id < message_id_val)) + .order_by(Messages.time.desc()) .limit(global_config.chat.observation_context_size * 3) - ) # 获取更多历史信息 + ) - return list(messages_before) + return list(messages_before_query) def message_list_to_dict(self, message_list): - # 存储简化的聊天记录 result = [] - for message in message_list: - if not isinstance(message, dict): - message = self.message_to_dict(message) - # print(message) + for msg_item in message_list: + processed_msg_item = msg_item + if not isinstance(msg_item, dict): + processed_msg_item = self.message_to_dict(msg_item) + + if not processed_msg_item: + continue lite_message = { - "time": message["time"], - "user_nickname": message["user_info"]["user_nickname"], - "processed_plain_text": message["processed_plain_text"], + "time": processed_msg_item.get("time"), + "user_nickname": processed_msg_item.get("user_nickname"), + "processed_plain_text": processed_msg_item.get("processed_plain_text"), } result.append(lite_message) - return result @staticmethod - def message_to_dict(message): - if not message: + def message_to_dict(msg_obj): + if not msg_obj: return None - if isinstance(message, dict): - return message - return { - # "message_id": message.message_info.message_id, - "time": message.message_info.time, - "user_id": message.message_info.user_info.user_id, - "user_nickname": message.message_info.user_info.user_nickname, - "processed_plain_text": message.processed_plain_text, - # "detailed_plain_text": message.detailed_plain_text - } + if isinstance(msg_obj, dict): + return msg_obj - def done_catch(self): - """将收集到的信息存储到数据库的 thinking_log 集合中喵~""" - try: - # 将消息对象转换为可序列化的字典喵~ - - thinking_log_data = { - "chat_id": self.chat_id, - "trigger_text": self.trigger_response_text, - "response_text": self.response_text, - "trigger_info": { - "time": self.trigger_response_time, - "message": 
self.message_to_dict(self.trigger_response_message), - }, - "response_info": { - "time": self.response_time, - "message": self.response_messages, - }, - "timing_results": self.timing_results, - "chat_history": self.message_list_to_dict(self.chat_history), - "chat_history_in_thinking": self.message_list_to_dict(self.chat_history_in_thinking), - "chat_history_after_response": self.message_list_to_dict(self.chat_history_after_response), - "heartflow_data": self.heartflow_data, - "reasoning_data": self.reasoning_data, + if isinstance(msg_obj, Messages): + return { + "time": msg_obj.time, + "user_id": msg_obj.user_id, + "user_nickname": msg_obj.user_nickname, + "processed_plain_text": msg_obj.processed_plain_text, } - # 根据不同的响应模式添加相应的数据喵~ # 现在直接都加上去好了喵~ - # if self.response_mode == "heart_flow": - # thinking_log_data["mode_specific_data"] = self.heartflow_data - # elif self.response_mode == "reasoning": - # thinking_log_data["mode_specific_data"] = self.reasoning_data + if hasattr(msg_obj, "message_info") and hasattr(msg_obj.message_info, "user_info"): + return { + "time": msg_obj.message_info.time, + "user_id": msg_obj.message_info.user_info.user_id, + "user_nickname": msg_obj.message_info.user_info.user_nickname, + "processed_plain_text": msg_obj.processed_plain_text, + } - # 将数据插入到 thinking_log 集合中喵~ - db.thinking_log.insert_one(thinking_log_data) + print(f"Warning: message_to_dict received an unhandled type: {type(msg_obj)}") + return {} + + def done_catch(self): + """将收集到的信息存储到数据库的 thinking_log 表中喵~""" + try: + trigger_info_dict = self.message_to_dict(self.trigger_response_message) + response_info_dict = { + "time": self.response_time, + "message": self.response_messages, + } + chat_history_list = self.message_list_to_dict(self.chat_history) + chat_history_in_thinking_list = self.message_list_to_dict(self.chat_history_in_thinking) + chat_history_after_response_list = self.message_list_to_dict(self.chat_history_after_response) + + log_entry = ThinkingLog( + 
chat_id=self.chat_id, + trigger_text=self.trigger_response_text, + response_text=self.response_text, + trigger_info_json=json.dumps(trigger_info_dict) if trigger_info_dict else None, + response_info_json=json.dumps(response_info_dict), + timing_results_json=json.dumps(self.timing_results), + chat_history_json=json.dumps(chat_history_list), + chat_history_in_thinking_json=json.dumps(chat_history_in_thinking_list), + chat_history_after_response_json=json.dumps(chat_history_after_response_list), + heartflow_data_json=json.dumps(self.heartflow_data), + reasoning_data_json=json.dumps(self.reasoning_data), + ) + log_entry.save() return True except Exception as e: diff --git a/src/chat/utils/statistic.py b/src/chat/utils/statistic.py index 3f9832926..a657ae85b 100644 --- a/src/chat/utils/statistic.py +++ b/src/chat/utils/statistic.py @@ -2,10 +2,12 @@ from collections import defaultdict from datetime import datetime, timedelta from typing import Any, Dict, Tuple, List + from src.common.logger import get_module_logger from src.manager.async_task_manager import AsyncTask -from ...common.database import db +from ...common.database.database import db # This db is the Peewee database instance +from ...common.database.database_model import OnlineTime, LLMUsage, Messages # Import the Peewee model from src.manager.local_store_manager import local_storage logger = get_module_logger("maibot_statistic") @@ -39,7 +41,7 @@ class OnlineTimeRecordTask(AsyncTask): def __init__(self): super().__init__(task_name="Online Time Record Task", run_interval=60) - self.record_id: str | None = None + self.record_id: int | None = None # Changed to int for Peewee's default ID """记录ID""" self._init_database() # 初始化数据库 @@ -47,49 +49,46 @@ class OnlineTimeRecordTask(AsyncTask): @staticmethod def _init_database(): """初始化数据库""" - if "online_time" not in db.list_collection_names(): - # 初始化数据库(在线时长) - db.create_collection("online_time") - # 创建索引 - if ("end_timestamp", 1) not in 
db.online_time.list_indexes(): - db.online_time.create_index([("end_timestamp", 1)]) + with db.atomic(): # Use atomic operations for schema changes + OnlineTime.create_table(safe=True) # Creates table if it doesn't exist, Peewee handles indexes from model async def run(self): try: + current_time = datetime.now() + extended_end_time = current_time + timedelta(minutes=1) + if self.record_id: # 如果有记录,则更新结束时间 - db.online_time.update_one( - {"_id": self.record_id}, - { - "$set": { - "end_timestamp": datetime.now() + timedelta(minutes=1), - } - }, - ) - else: + query = OnlineTime.update(end_timestamp=extended_end_time).where(OnlineTime.id == self.record_id) + updated_rows = query.execute() + if updated_rows == 0: + # Record might have been deleted or ID is stale, try to find/create + self.record_id = None # Reset record_id to trigger find/create logic below + + if not self.record_id: # Check again if record_id was reset or initially None # 如果没有记录,检查一分钟以内是否已有记录 - current_time = datetime.now() - if recent_record := db.online_time.find_one( - {"end_timestamp": {"$gte": current_time - timedelta(minutes=1)}} - ): + # Look for a record whose end_timestamp is recent enough to be considered ongoing + recent_record = ( + OnlineTime.select() + .where(OnlineTime.end_timestamp >= (current_time - timedelta(minutes=1))) + .order_by(OnlineTime.end_timestamp.desc()) + .first() + ) + + if recent_record: # 如果有记录,则更新结束时间 - self.record_id = recent_record["_id"] - db.online_time.update_one( - {"_id": self.record_id}, - { - "$set": { - "end_timestamp": current_time + timedelta(minutes=1), - } - }, - ) + self.record_id = recent_record.id + recent_record.end_timestamp = extended_end_time + recent_record.save() else: # 若没有记录,则插入新的在线时间记录 - self.record_id = db.online_time.insert_one( - { - "start_timestamp": current_time, - "end_timestamp": current_time + timedelta(minutes=1), - } - ).inserted_id + new_record = OnlineTime.create( + timestamp=current_time.timestamp(), # 添加此行 + 
start_timestamp=current_time, + end_timestamp=extended_end_time, + duration=5, # 初始时长为5分钟 + ) + self.record_id = new_record.id except Exception as e: logger.error(f"在线时间记录失败,错误信息:{e}") @@ -201,35 +200,28 @@ class StatisticOutputTask(AsyncTask): :param collect_period: 统计时间段 """ - if len(collect_period) <= 0: + if not collect_period: return {} - else: - # 排序-按照时间段开始时间降序排列(最晚的时间段在前) - collect_period.sort(key=lambda x: x[1], reverse=True) + + # 排序-按照时间段开始时间降序排列(最晚的时间段在前) + collect_period.sort(key=lambda x: x[1], reverse=True) stats = { period_key: { - # 总LLM请求数 TOTAL_REQ_CNT: 0, - # 请求次数统计 REQ_CNT_BY_TYPE: defaultdict(int), REQ_CNT_BY_USER: defaultdict(int), REQ_CNT_BY_MODEL: defaultdict(int), - # 输入Token数 IN_TOK_BY_TYPE: defaultdict(int), IN_TOK_BY_USER: defaultdict(int), IN_TOK_BY_MODEL: defaultdict(int), - # 输出Token数 OUT_TOK_BY_TYPE: defaultdict(int), OUT_TOK_BY_USER: defaultdict(int), OUT_TOK_BY_MODEL: defaultdict(int), - # 总Token数 TOTAL_TOK_BY_TYPE: defaultdict(int), TOTAL_TOK_BY_USER: defaultdict(int), TOTAL_TOK_BY_MODEL: defaultdict(int), - # 总开销 TOTAL_COST: 0.0, - # 请求开销统计 COST_BY_TYPE: defaultdict(float), COST_BY_USER: defaultdict(float), COST_BY_MODEL: defaultdict(float), @@ -238,26 +230,26 @@ class StatisticOutputTask(AsyncTask): } # 以最早的时间戳为起始时间获取记录 - for record in db.llm_usage.find({"timestamp": {"$gte": collect_period[-1][1]}}): - record_timestamp = record.get("timestamp") + # Assuming LLMUsage.timestamp is a DateTimeField + query_start_time = collect_period[-1][1] + for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time): + record_timestamp = record.timestamp # This is already a datetime object for idx, (_, period_start) in enumerate(collect_period): if record_timestamp >= period_start: - # 如果记录时间在当前时间段内,则它一定在更早的时间段内 - # 因此,我们可以直接跳过更早的时间段的判断,直接更新当前以及更早时间段的统计数据 for period_key, _ in collect_period[idx:]: stats[period_key][TOTAL_REQ_CNT] += 1 - request_type = record.get("request_type", "unknown") # 请求类型 - user_id = 
str(record.get("user_id", "unknown")) # 用户ID - model_name = record.get("model_name", "unknown") # 模型名称 + request_type = record.request_type or "unknown" + user_id = record.user_id or "unknown" # user_id is TextField, already string + model_name = record.model_name or "unknown" stats[period_key][REQ_CNT_BY_TYPE][request_type] += 1 stats[period_key][REQ_CNT_BY_USER][user_id] += 1 stats[period_key][REQ_CNT_BY_MODEL][model_name] += 1 - prompt_tokens = record.get("prompt_tokens", 0) # 输入Token数 - completion_tokens = record.get("completion_tokens", 0) # 输出Token数 - total_tokens = prompt_tokens + completion_tokens # Token总数 = 输入Token数 + 输出Token数 + prompt_tokens = record.prompt_tokens or 0 + completion_tokens = record.completion_tokens or 0 + total_tokens = prompt_tokens + completion_tokens stats[period_key][IN_TOK_BY_TYPE][request_type] += prompt_tokens stats[period_key][IN_TOK_BY_USER][user_id] += prompt_tokens @@ -271,13 +263,12 @@ class StatisticOutputTask(AsyncTask): stats[period_key][TOTAL_TOK_BY_USER][user_id] += total_tokens stats[period_key][TOTAL_TOK_BY_MODEL][model_name] += total_tokens - cost = record.get("cost", 0.0) + cost = record.cost or 0.0 stats[period_key][TOTAL_COST] += cost stats[period_key][COST_BY_TYPE][request_type] += cost stats[period_key][COST_BY_USER][user_id] += cost stats[period_key][COST_BY_MODEL][model_name] += cost - break # 取消更早时间段的判断 - + break return stats @staticmethod @@ -287,39 +278,38 @@ class StatisticOutputTask(AsyncTask): :param collect_period: 统计时间段 """ - if len(collect_period) <= 0: + if not collect_period: return {} - else: - # 排序-按照时间段开始时间降序排列(最晚的时间段在前) - collect_period.sort(key=lambda x: x[1], reverse=True) + + collect_period.sort(key=lambda x: x[1], reverse=True) stats = { period_key: { - # 在线时间统计 ONLINE_TIME: 0.0, } for period_key, _ in collect_period } - # 统计在线时间 - for record in db.online_time.find({"end_timestamp": {"$gte": collect_period[-1][1]}}): - end_timestamp: datetime = record.get("end_timestamp") - for idx, (_, 
period_start) in enumerate(collect_period): - if end_timestamp >= period_start: - # 由于end_timestamp会超前标记时间,所以我们需要判断是否晚于当前时间,如果是,则使用当前时间作为结束时间 - end_timestamp = min(end_timestamp, now) - # 如果记录时间在当前时间段内,则它一定在更早的时间段内 - # 因此,我们可以直接跳过更早的时间段的判断,直接更新当前以及更早时间段的统计数据 - for period_key, _period_start in collect_period[idx:]: - start_timestamp: datetime = record.get("start_timestamp") - if start_timestamp < _period_start: - # 如果开始时间在查询边界之前,则使用开始时间 - stats[period_key][ONLINE_TIME] += (end_timestamp - _period_start).total_seconds() - else: - # 否则,使用开始时间 - stats[period_key][ONLINE_TIME] += (end_timestamp - start_timestamp).total_seconds() - break # 取消更早时间段的判断 + query_start_time = collect_period[-1][1] + # Assuming OnlineTime.end_timestamp is a DateTimeField + for record in OnlineTime.select().where(OnlineTime.end_timestamp >= query_start_time): + # record.end_timestamp and record.start_timestamp are datetime objects + record_end_timestamp = record.end_timestamp + record_start_timestamp = record.start_timestamp + for idx, (_, period_boundary_start) in enumerate(collect_period): + if record_end_timestamp >= period_boundary_start: + # Calculate effective end time for this record in relation to 'now' + effective_end_time = min(record_end_timestamp, now) + + for period_key, current_period_start_time in collect_period[idx:]: + # Determine the portion of the record that falls within this specific statistical period + overlap_start = max(record_start_timestamp, current_period_start_time) + overlap_end = effective_end_time # Already capped by 'now' and record's own end + + if overlap_end > overlap_start: + stats[period_key][ONLINE_TIME] += (overlap_end - overlap_start).total_seconds() + break return stats def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]: @@ -328,55 +318,57 @@ class StatisticOutputTask(AsyncTask): :param collect_period: 统计时间段 """ - if len(collect_period) <= 0: + if not collect_period: return {} - else: - # 
排序-按照时间段开始时间降序排列(最晚的时间段在前) - collect_period.sort(key=lambda x: x[1], reverse=True) + + collect_period.sort(key=lambda x: x[1], reverse=True) stats = { period_key: { - # 消息统计 TOTAL_MSG_CNT: 0, MSG_CNT_BY_CHAT: defaultdict(int), } for period_key, _ in collect_period } - # 统计消息量 - for message in db.messages.find({"time": {"$gte": collect_period[-1][1].timestamp()}}): - chat_info = message.get("chat_info", None) # 聊天信息 - user_info = message.get("user_info", None) # 用户信息(消息发送人) - message_time = message.get("time", 0) # 消息时间 + query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp) + for message in Messages.select().where(Messages.time >= query_start_timestamp): + message_time_ts = message.time # This is a float timestamp - group_info = chat_info.get("group_info") if chat_info else None # 尝试获取群聊信息 - if group_info is not None: - # 若有群聊信息 - chat_id = f"g{group_info.get('group_id')}" - chat_name = group_info.get("group_name", f"群{group_info.get('group_id')}") - elif user_info: - # 若没有群聊信息,则尝试获取用户信息 - chat_id = f"u{user_info['user_id']}" - chat_name = user_info["user_nickname"] + chat_id = None + chat_name = None + + # Logic based on Peewee model structure, aiming to replicate original intent + if message.chat_info_group_id: + chat_id = f"g{message.chat_info_group_id}" + chat_name = message.chat_info_group_name or f"群{message.chat_info_group_id}" + elif message.user_id: # Fallback to sender's info for chat_id if not a group_info based chat + # This uses the message SENDER's ID as per original logic's fallback + chat_id = f"u{message.user_id}" # SENDER's user_id + chat_name = message.user_nickname # SENDER's nickname else: - continue # 如果没有群组信息也没有用户信息,则跳过 + # If neither group_id nor sender_id is available for chat identification + logger.warning( + f"Message (PK: {message.id if hasattr(message, 'id') else 'N/A'}) lacks group_id and user_id for chat stats." 
+ ) + continue + if not chat_id: # Should not happen if above logic is correct + continue + + # Update name_mapping if chat_id in self.name_mapping: - if chat_name != self.name_mapping[chat_id][0] and message_time > self.name_mapping[chat_id][1]: - # 如果用户名称不同,且新消息时间晚于之前记录的时间,则更新用户名称 - self.name_mapping[chat_id] = (chat_name, message_time) + if chat_name != self.name_mapping[chat_id][0] and message_time_ts > self.name_mapping[chat_id][1]: + self.name_mapping[chat_id] = (chat_name, message_time_ts) else: - self.name_mapping[chat_id] = (chat_name, message_time) + self.name_mapping[chat_id] = (chat_name, message_time_ts) - for idx, (_, period_start) in enumerate(collect_period): - if message_time >= period_start.timestamp(): - # 如果记录时间在当前时间段内,则它一定在更早的时间段内 - # 因此,我们可以直接跳过更早的时间段的判断,直接更新当前以及更早时间段的统计数据 + for idx, (_, period_start_dt) in enumerate(collect_period): + if message_time_ts >= period_start_dt.timestamp(): for period_key, _ in collect_period[idx:]: stats[period_key][TOTAL_MSG_CNT] += 1 stats[period_key][MSG_CNT_BY_CHAT][chat_id] += 1 break - return stats def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]: diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py index 58eb49de8..c400a9948 100644 --- a/src/chat/utils/utils.py +++ b/src/chat/utils/utils.py @@ -13,7 +13,7 @@ from src.manager.mood_manager import mood_manager from ..message_receive.message import MessageRecv from ..models.utils_model import LLMRequest from .typo_generator import ChineseTypoGenerator -from ...common.database import db +from ...common.database.database import db from ...config.config import global_config logger = get_module_logger("chat_utils") diff --git a/src/chat/utils/utils_image.py b/src/chat/utils/utils_image.py index 6958bc26b..c317fbbd6 100644 --- a/src/chat/utils/utils_image.py +++ b/src/chat/utils/utils_image.py @@ -8,7 +8,8 @@ import io import numpy as np -from ...common.database import db +from ...common.database.database import db +from 
...common.database.database_model import Images, ImageDescriptions from ...config.config import global_config from ..models.utils_model import LLMRequest @@ -32,40 +33,23 @@ class ImageManager: def __init__(self): if not self._initialized: - self._ensure_image_collection() - self._ensure_description_collection() self._ensure_image_dir() + self._initialized = True self._llm = LLMRequest(model=global_config.model.vlm, temperature=0.4, max_tokens=300, request_type="image") + try: + db.connect(reuse_if_open=True) + db.create_tables([Images, ImageDescriptions], safe=True) + except Exception as e: + logger.error(f"数据库连接或表创建失败: {e}") + + self._initialized = True + def _ensure_image_dir(self): """确保图像存储目录存在""" os.makedirs(self.IMAGE_DIR, exist_ok=True) - @staticmethod - def _ensure_image_collection(): - """确保images集合存在并创建索引""" - if "images" not in db.list_collection_names(): - db.create_collection("images") - - # 删除旧索引 - db.images.drop_indexes() - # 创建新的复合索引 - db.images.create_index([("hash", 1), ("type", 1)], unique=True) - db.images.create_index([("url", 1)]) - db.images.create_index([("path", 1)]) - - @staticmethod - def _ensure_description_collection(): - """确保image_descriptions集合存在并创建索引""" - if "image_descriptions" not in db.list_collection_names(): - db.create_collection("image_descriptions") - - # 删除旧索引 - db.image_descriptions.drop_indexes() - # 创建新的复合索引 - db.image_descriptions.create_index([("hash", 1), ("type", 1)], unique=True) - @staticmethod def _get_description_from_db(image_hash: str, description_type: str) -> Optional[str]: """从数据库获取图片描述 @@ -77,8 +61,14 @@ class ImageManager: Returns: Optional[str]: 描述文本,如果不存在则返回None """ - result = db.image_descriptions.find_one({"hash": image_hash, "type": description_type}) - return result["description"] if result else None + try: + record = ImageDescriptions.get_or_none( + (ImageDescriptions.image_description_hash == image_hash) & (ImageDescriptions.type == description_type) + ) + return record.description if record else 
None + except Exception as e: + logger.error(f"从数据库获取描述失败 (Peewee): {str(e)}") + return None @staticmethod def _save_description_to_db(image_hash: str, description: str, description_type: str) -> None: @@ -90,20 +80,17 @@ class ImageManager: description_type: 描述类型 ('emoji' 或 'image') """ try: - db.image_descriptions.update_one( - {"hash": image_hash, "type": description_type}, - { - "$set": { - "description": description, - "timestamp": int(time.time()), - "hash": image_hash, # 确保hash字段存在 - "type": description_type, # 确保type字段存在 - } - }, - upsert=True, + current_timestamp = time.time() + defaults = {"description": description, "timestamp": current_timestamp} + desc_obj, created = ImageDescriptions.get_or_create( + hash=image_hash, type=description_type, defaults=defaults ) + if not created: # 如果记录已存在,则更新 + desc_obj.description = description + desc_obj.timestamp = current_timestamp + desc_obj.save() except Exception as e: - logger.error(f"保存描述到数据库失败: {str(e)}") + logger.error(f"保存描述到数据库失败 (Peewee): {str(e)}") async def get_emoji_description(self, image_base64: str) -> str: """获取表情包描述,带查重和保存功能""" @@ -116,18 +103,25 @@ class ImageManager: # 查询缓存的描述 cached_description = self._get_description_from_db(image_hash, "emoji") if cached_description: - # logger.debug(f"缓存表情包描述: {cached_description}") return f"[表情包,含义看起来是:{cached_description}]" # 调用AI获取描述 if image_format == "gif" or image_format == "GIF": - image_base64 = self.transform_gif(image_base64) + image_base64_processed = self.transform_gif(image_base64) + if image_base64_processed is None: + logger.warning("GIF转换失败,无法获取描述") + return "[表情包(GIF处理失败)]" prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,使用1-2个词描述一下表情包表达的情感和内容,简短一些" - description, _ = await self._llm.generate_response_for_image(prompt, image_base64, "jpg") + description, _ = await self._llm.generate_response_for_image(prompt, image_base64_processed, "jpg") else: prompt = "这是一个表情包,请用使用几个词描述一下表情包所表达的情感和内容,简短一些" description, _ = await 
self._llm.generate_response_for_image(prompt, image_base64, image_format) + if description is None: + logger.warning("AI未能生成表情包描述") + return "[表情包(描述生成失败)]" + + # 再次检查缓存,防止并发写入时重复生成 cached_description = self._get_description_from_db(image_hash, "emoji") if cached_description: logger.warning(f"虽然生成了描述,但是找到缓存表情包描述: {cached_description}") @@ -136,31 +130,37 @@ class ImageManager: # 根据配置决定是否保存图片 if global_config.emoji.save_emoji: # 生成文件名和路径 - timestamp = int(time.time()) - filename = f"{timestamp}_{image_hash[:8]}.{image_format}" - if not os.path.exists(os.path.join(self.IMAGE_DIR, "emoji")): - os.makedirs(os.path.join(self.IMAGE_DIR, "emoji")) - file_path = os.path.join(self.IMAGE_DIR, "emoji", filename) + current_timestamp = time.time() + filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}" + emoji_dir = os.path.join(self.IMAGE_DIR, "emoji") + os.makedirs(emoji_dir, exist_ok=True) + file_path = os.path.join(emoji_dir, filename) try: # 保存文件 with open(file_path, "wb") as f: f.write(image_bytes) - # 保存到数据库 - image_doc = { - "hash": image_hash, - "path": file_path, - "type": "emoji", - "description": description, - "timestamp": timestamp, - } - db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True) - logger.trace(f"保存表情包: {file_path}") + # 保存到数据库 (Images表) + try: + img_obj = Images.get((Images.emoji_hash == image_hash) & (Images.type == "emoji")) + img_obj.path = file_path + img_obj.description = description + img_obj.timestamp = current_timestamp + img_obj.save() + except Images.DoesNotExist: + Images.create( + hash=image_hash, + path=file_path, + type="emoji", + description=description, + timestamp=current_timestamp, + ) + logger.trace(f"保存表情包元数据: {file_path}") except Exception as e: - logger.error(f"保存表情包文件失败: {str(e)}") + logger.error(f"保存表情包文件或元数据失败: {str(e)}") - # 保存描述到数据库 + # 保存描述到数据库 (ImageDescriptions表) self._save_description_to_db(image_hash, description, "emoji") return f"[表情包:{description}]" @@ -188,6 +188,11 @@ class 
ImageManager: ) description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format) + if description is None: + logger.warning("AI未能生成图片描述") + return "[图片(描述生成失败)]" + + # 再次检查缓存 cached_description = self._get_description_from_db(image_hash, "image") if cached_description: logger.warning(f"虽然生成了描述,但是找到缓存图片描述 {cached_description}") @@ -195,38 +200,40 @@ class ImageManager: logger.debug(f"描述是{description}") - if description is None: - logger.warning("AI未能生成图片描述") - return "[图片]" - # 根据配置决定是否保存图片 if global_config.emoji.save_pic: # 生成文件名和路径 - timestamp = int(time.time()) - filename = f"{timestamp}_{image_hash[:8]}.{image_format}" - if not os.path.exists(os.path.join(self.IMAGE_DIR, "image")): - os.makedirs(os.path.join(self.IMAGE_DIR, "image")) - file_path = os.path.join(self.IMAGE_DIR, "image", filename) + current_timestamp = time.time() + filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}" + image_dir = os.path.join(self.IMAGE_DIR, "image") + os.makedirs(image_dir, exist_ok=True) + file_path = os.path.join(image_dir, filename) try: # 保存文件 with open(file_path, "wb") as f: f.write(image_bytes) - # 保存到数据库 - image_doc = { - "hash": image_hash, - "path": file_path, - "type": "image", - "description": description, - "timestamp": timestamp, - } - db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True) - logger.trace(f"保存图片: {file_path}") + # 保存到数据库 (Images表) + try: + img_obj = Images.get((Images.emoji_hash == image_hash) & (Images.type == "image")) + img_obj.path = file_path + img_obj.description = description + img_obj.timestamp = current_timestamp + img_obj.save() + except Images.DoesNotExist: + Images.create( + hash=image_hash, + path=file_path, + type="image", + description=description, + timestamp=current_timestamp, + ) + logger.trace(f"保存图片元数据: {file_path}") except Exception as e: - logger.error(f"保存图片文件失败: {str(e)}") + logger.error(f"保存图片文件或元数据失败: {str(e)}") - # 保存描述到数据库 + # 保存描述到数据库 
(ImageDescriptions表) self._save_description_to_db(image_hash, description, "image") return f"[图片:{description}]" diff --git a/src/chat/zhishi/knowledge_library.py b/src/chat/zhishi/knowledge_library.py index 6fa1d3e1a..0068a153c 100644 --- a/src/chat/zhishi/knowledge_library.py +++ b/src/chat/zhishi/knowledge_library.py @@ -16,7 +16,7 @@ root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")) sys.path.append(root_path) # 现在可以导入src模块 -from src.common.database import db # noqa E402 +from common.database.database import db # noqa E402 # 加载根目录下的env.edv文件 diff --git a/src/common/database.py b/src/common/database/database.py similarity index 81% rename from src/common/database.py rename to src/common/database/database.py index 752f746db..a2dab739d 100644 --- a/src/common/database.py +++ b/src/common/database/database.py @@ -1,5 +1,6 @@ import os from pymongo import MongoClient +from peewee import SqliteDatabase from pymongo.database import Database from rich.traceback import install @@ -57,4 +58,15 @@ class DBWrapper: # 全局数据库访问点 -db: Database = DBWrapper() +memory_db: Database = DBWrapper() + +# 定义数据库文件路径 +ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) +_DB_DIR = os.path.join(ROOT_PATH, "data") +_DB_FILE = os.path.join(_DB_DIR, "MaiBot.db") + +# 确保数据库目录存在 +os.makedirs(_DB_DIR, exist_ok=True) + +# 全局 Peewee SQLite 数据库访问点 +db = SqliteDatabase(_DB_FILE) diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py new file mode 100644 index 000000000..bd7a2d319 --- /dev/null +++ b/src/common/database/database_model.py @@ -0,0 +1,358 @@ +from peewee import Model, DoubleField, IntegerField, BooleanField, TextField, FloatField, DateTimeField +from .database import db +import datetime +from ..logger_manager import get_logger + +logger = get_logger("database_model") +# 请在此处定义您的数据库实例。 +# 您需要取消注释并配置适合您的数据库的部分。 +# 例如,对于 SQLite: +# db = SqliteDatabase('MaiBot.db') +# +# 对于 PostgreSQL: +# db 
= PostgresqlDatabase('your_db_name', user='your_user', password='your_password', +# host='localhost', port=5432) +# +# 对于 MySQL: +# db = MySQLDatabase('your_db_name', user='your_user', password='your_password', +# host='localhost', port=3306) + + +# 定义一个基础模型是一个好习惯,所有其他模型都应继承自它。 +# 这允许您在一个地方为所有模型指定数据库。 +class BaseModel(Model): + class Meta: + # 将下面的 'db' 替换为您实际的数据库实例变量名。 + database = db # 例如: database = my_actual_db_instance + pass # 在用户定义数据库实例之前,此处为占位符 + + +class ChatStreams(BaseModel): + """ + 用于存储流式记录数据的模型,类似于提供的 MongoDB 结构。 + """ + + # stream_id: "a544edeb1a9b73e3e1d77dff36e41264" + # 假设 stream_id 是唯一的,并为其创建索引以提高查询性能。 + stream_id = TextField(unique=True, index=True) + + # create_time: 1746096761.4490178 (时间戳,精确到小数点后7位) + # DoubleField 用于存储浮点数,适合此类时间戳。 + create_time = DoubleField() + + # group_info 字段: + # platform: "qq" + # group_id: "941657197" + # group_name: "测试" + group_platform = TextField() + group_id = TextField() + group_name = TextField() + + # last_active_time: 1746623771.4825106 (时间戳,精确到小数点后7位) + last_active_time = DoubleField() + + # platform: "qq" (顶层平台字段) + platform = TextField() + + # user_info 字段: + # platform: "qq" + # user_id: "1787882683" + # user_nickname: "墨梓柒(IceSakurary)" + # user_cardname: "" + user_platform = TextField() + user_id = TextField() + user_nickname = TextField() + # user_cardname 可能为空字符串或不存在,设置 null=True 更具灵活性。 + user_cardname = TextField(null=True) + + class Meta: + # 如果 BaseModel.Meta.database 已设置,则此模型将继承该数据库配置。 + # 如果不使用带有数据库实例的 BaseModel,或者想覆盖它, + # 请取消注释并在下面设置数据库实例: + # database = db + table_name = "chat_streams" # 可选:明确指定数据库中的表名 + + +class LLMUsage(BaseModel): + """ + 用于存储 API 使用日志数据的模型。 + """ + + model_name = TextField(index=True) # 添加索引 + user_id = TextField(index=True) # 添加索引 + request_type = TextField(index=True) # 添加索引 + endpoint = TextField() + prompt_tokens = IntegerField() + completion_tokens = IntegerField() + total_tokens = IntegerField() + cost = DoubleField() + status = TextField() + timestamp = 
DateTimeField(index=True) # 更改为 DateTimeField 并添加索引 + + class Meta: + # 如果 BaseModel.Meta.database 已设置,则此模型将继承该数据库配置。 + # database = db + table_name = "llm_usage" + + +class Emoji(BaseModel): + """表情包""" + + full_path = TextField(unique=True, index=True) # 文件的完整路径 (包括文件名) + format = TextField() # 图片格式 + emoji_hash = TextField(index=True) # 表情包的哈希值 + description = TextField() # 表情包的描述 + query_count = IntegerField(default=0) # 查询次数(用于统计表情包被查询描述的次数) + is_registered = BooleanField(default=False) # 是否已注册 + is_banned = BooleanField(default=False) # 是否被禁止注册 + # emotion: list[str] # 表情包的情感标签 - 存储为文本,应用层处理序列化/反序列化 + emotion = TextField(null=True) + record_time = FloatField() # 记录时间(被创建的时间) + register_time = FloatField(null=True) # 注册时间(被注册为可用表情包的时间) + usage_count = IntegerField(default=0) # 使用次数(被使用的次数) + last_used_time = FloatField(null=True) # 上次使用时间 + + class Meta: + # database = db # 继承自 BaseModel + table_name = "emoji" + + +class Messages(BaseModel): + """ + 用于存储消息数据的模型。 + """ + + message_id = TextField(index=True) # 消息 ID (更改自 IntegerField) + time = DoubleField() # 消息时间戳 + + chat_id = TextField(index=True) # 对应的 ChatStreams stream_id + + # 从 chat_info 扁平化而来的字段 + chat_info_stream_id = TextField() + chat_info_platform = TextField() + chat_info_user_platform = TextField() + chat_info_user_id = TextField() + chat_info_user_nickname = TextField() + chat_info_user_cardname = TextField(null=True) + chat_info_group_platform = TextField(null=True) # 群聊信息可能不存在 + chat_info_group_id = TextField(null=True) + chat_info_group_name = TextField(null=True) + chat_info_create_time = DoubleField() + chat_info_last_active_time = DoubleField() + + # 从顶层 user_info 扁平化而来的字段 (消息发送者信息) + user_platform = TextField() + user_id = TextField() + user_nickname = TextField() + user_cardname = TextField(null=True) + + processed_plain_text = TextField(null=True) # 处理后的纯文本消息 + detailed_plain_text = TextField(null=True) # 详细的纯文本消息 + memorized_times = IntegerField(default=0) # 被记忆的次数 + + class Meta: + # 
database = db # 继承自 BaseModel + table_name = "messages" + + +class Images(BaseModel): + """ + 用于存储图像信息的模型。 + """ + + emoji_hash = TextField(index=True) # 图像的哈希值 + description = TextField(null=True) # 图像的描述 + path = TextField(unique=True) # 图像文件的路径 + timestamp = FloatField() # 时间戳 + type = TextField() # 图像类型,例如 "emoji" + + class Meta: + # database = db # 继承自 BaseModel + table_name = "images" + + +class ImageDescriptions(BaseModel): + """ + 用于存储图像描述信息的模型。 + """ + + type = TextField() # 类型,例如 "emoji" + image_description_hash = TextField(index=True) # 图像的哈希值 + description = TextField() # 图像的描述 + timestamp = FloatField() # 时间戳 + + class Meta: + # database = db # 继承自 BaseModel + table_name = "image_descriptions" + + +class OnlineTime(BaseModel): + """ + 用于存储在线时长记录的模型。 + """ + + # timestamp: "$date": "2025-05-01T18:52:18.191Z" (存储为字符串) + timestamp = TextField(default=datetime.datetime.now) # 时间戳 + duration = IntegerField() # 时长,单位分钟 + start_timestamp = DateTimeField(default=datetime.datetime.now) + end_timestamp = DateTimeField(index=True) + + class Meta: + # database = db # 继承自 BaseModel + table_name = "online_time" + + +class PersonInfo(BaseModel): + """ + 用于存储个人信息数据的模型。 + """ + + person_id = TextField(unique=True, index=True) # 个人唯一ID + person_name = TextField(null=True) # 个人名称 (允许为空) + name_reason = TextField(null=True) # 名称设定的原因 + platform = TextField() # 平台 + user_id = TextField(index=True) # 用户ID + nickname = TextField() # 用户昵称 + relationship_value = IntegerField(default=0) # 关系值 + know_time = FloatField() # 认识时间 (时间戳) + msg_interval = IntegerField() # 消息间隔 + # msg_interval_list: 存储为 JSON 字符串的列表 + msg_interval_list = TextField(null=True) + + class Meta: + # database = db # 继承自 BaseModel + table_name = "person_info" + + +class Knowledges(BaseModel): + """ + 用于存储知识库条目的模型。 + """ + + content = TextField() # 知识内容的文本 + embedding = TextField() # 知识内容的嵌入向量,存储为 JSON 字符串的浮点数列表 + # 可以添加其他元数据字段,如 source, create_time 等 + + class Meta: + # database = db # 继承自 BaseModel + 
table_name = "knowledges" + + +class ThinkingLog(BaseModel): + chat_id = TextField(index=True) + trigger_text = TextField(null=True) + response_text = TextField(null=True) + + # Store complex dicts/lists as JSON strings + trigger_info_json = TextField(null=True) + response_info_json = TextField(null=True) + timing_results_json = TextField(null=True) + chat_history_json = TextField(null=True) + chat_history_in_thinking_json = TextField(null=True) + chat_history_after_response_json = TextField(null=True) + heartflow_data_json = TextField(null=True) + reasoning_data_json = TextField(null=True) + + # Add a timestamp for the log entry itself + # Ensure you have: from peewee import DateTimeField + # And: import datetime + created_at = DateTimeField(default=datetime.datetime.now) + + class Meta: + table_name = "thinking_logs" + + +class RecalledMessages(BaseModel): + """ + 用于存储撤回消息记录的模型。 + """ + + message_id = TextField(index=True) # 被撤回的消息 ID + time = DoubleField() # 撤回操作发生的时间戳 + stream_id = TextField() # 对应的 ChatStreams stream_id + + class Meta: + table_name = "recalled_messages" + + +def create_tables(): + """ + 创建所有在模型中定义的数据库表。 + """ + with db: + db.create_tables( + [ + ChatStreams, + LLMUsage, + Emoji, + Messages, + Images, + ImageDescriptions, + OnlineTime, + PersonInfo, + Knowledges, + ThinkingLog, + RecalledMessages, # 添加新模型 + ] + ) + + +def initialize_database(): + """ + 检查所有定义的表是否存在,如果不存在则创建它们。 + 检查所有表的所有字段是否存在,如果缺失则警告用户并退出程序。 + """ + import sys + + models = [ + ChatStreams, + LLMUsage, + Emoji, + Messages, + Images, + ImageDescriptions, + OnlineTime, + PersonInfo, + Knowledges, + ThinkingLog, + RecalledMessages, # 添加新模型 + ] + + needs_creation = False + try: + with db: # 管理 table_exists 检查的连接 + for model in models: + table_name = model._meta.table_name + if not db.table_exists(model): + logger.warning(f"表 '{table_name}' 未找到。") + needs_creation = True + break # 一个表丢失,无需进一步检查。 + if not needs_creation: + # 检查字段 + for model in models: + table_name = 
model._meta.table_name + cursor = db.execute_sql(f"PRAGMA table_info('{table_name}')") + existing_columns = {row[1] for row in cursor.fetchall()} + model_fields = model._meta.fields + for field_name in model_fields: + if field_name not in existing_columns: + logger.error(f"表 '{table_name}' 缺失字段 '{field_name}',请手动迁移数据库结构后重启程序。") + sys.exit(1) + except Exception as e: + logger.exception(f"检查表或字段是否存在时出错: {e}") + # 如果检查失败(例如数据库不可用),则退出 + return + + if needs_creation: + logger.info("正在初始化数据库:一个或多个表丢失。正在尝试创建所有定义的表...") + try: + create_tables() # 此函数有其自己的 'with db:' 上下文管理。 + logger.info("数据库表创建过程完成。") + except Exception as e: + logger.exception(f"创建表期间出错: {e}") + else: + logger.info("所有数据库表及字段均已存在。") + + +# 模块加载时调用初始化函数 +initialize_database() diff --git a/src/common/logger.py b/src/common/logger.py index 9f2dee455..adc15fe71 100644 --- a/src/common/logger.py +++ b/src/common/logger.py @@ -629,22 +629,22 @@ PROCESSOR_STYLE_CONFIG = { PLANNER_STYLE_CONFIG = { "advanced": { - "console_format": "{time:HH:mm:ss} | 规划器 | {message}", + "console_format": "{time:HH:mm:ss} | 规划器 | {message}", "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 规划器 | {message}", }, "simple": { - "console_format": "{time:HH:mm:ss} | 规划器 | {message}", + "console_format": "{time:HH:mm:ss} | 规划器 | {message}", "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 规划器 | {message}", }, } ACTION_TAKEN_STYLE_CONFIG = { "advanced": { - "console_format": "{time:HH:mm:ss} | 动作 | {message}", + "console_format": "{time:HH:mm:ss} | 动作 | {message}", "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 动作 | {message}", }, "simple": { - "console_format": "{time:HH:mm:ss} | 动作 | {message}", + "console_format": "{time:HH:mm:ss} | 动作 | {message}", "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 动作 | {message}", }, } diff --git a/src/common/message_repository.py b/src/common/message_repository.py index 
03f192cea..ee69b22b0 100644 --- a/src/common/message_repository.py +++ b/src/common/message_repository.py @@ -1,11 +1,19 @@ -from src.common.database import db +from src.common.database.database_model import Messages # 更改导入 from src.common.logger import get_module_logger import traceback from typing import List, Any, Optional +from peewee import Model # 添加 Peewee Model 导入 logger = get_module_logger(__name__) +def _model_to_dict(model_instance: Model) -> dict[str, Any]: + """ + 将 Peewee 模型实例转换为字典。 + """ + return model_instance.__data__ + + def find_messages( message_filter: dict[str, Any], sort: Optional[List[tuple[str, int]]] = None, @@ -16,39 +24,84 @@ def find_messages( 根据提供的过滤器、排序和限制条件查找消息。 Args: - message_filter: MongoDB 查询过滤器。 - sort: MongoDB 排序条件列表,例如 [('time', 1)]。仅在 limit 为 0 时生效。 + message_filter: 查询过滤器字典,键为模型字段名,值为期望值或包含操作符的字典 (例如 {'$gt': value}). + sort: 排序条件列表,例如 [('time', 1)] (1 for asc, -1 for desc)。仅在 limit 为 0 时生效。 limit: 返回的最大文档数,0表示不限制。 limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录(结果仍按时间正序排列)。默认为 'latest'。 Returns: - 消息文档列表,如果出错则返回空列表。 + 消息字典列表,如果出错则返回空列表。 """ try: - query = db.messages.find(message_filter) + query = Messages.select() + + # 应用过滤器 + if message_filter: + conditions = [] + for key, value in message_filter.items(): + if hasattr(Messages, key): + field = getattr(Messages, key) + if isinstance(value, dict): + # 处理 MongoDB 风格的操作符 + for op, op_value in value.items(): + if op == "$gt": + conditions.append(field > op_value) + elif op == "$lt": + conditions.append(field < op_value) + elif op == "$gte": + conditions.append(field >= op_value) + elif op == "$lte": + conditions.append(field <= op_value) + elif op == "$ne": + conditions.append(field != op_value) + elif op == "$in": + conditions.append(field.in_(op_value)) + elif op == "$nin": + conditions.append(field.not_in(op_value)) + else: + logger.warning(f"过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。") + else: + # 直接相等比较 + conditions.append(field == value) + else: + 
logger.warning(f"过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。") + if conditions: + query = query.where(*conditions) if limit > 0: if limit_mode == "earliest": # 获取时间最早的 limit 条记录,已经是正序 - query = query.sort([("time", 1)]).limit(limit) - results = list(query) + query = query.order_by(Messages.time.asc()).limit(limit) + peewee_results = list(query) else: # 默认为 'latest' # 获取时间最晚的 limit 条记录 - query = query.sort([("time", -1)]).limit(limit) - latest_results = list(query) + query = query.order_by(Messages.time.desc()).limit(limit) + latest_results_peewee = list(query) # 将结果按时间正序排列 - # 假设消息文档中总是有 'time' 字段且可排序 - results = sorted(latest_results, key=lambda msg: msg.get("time")) + peewee_results = sorted(latest_results_peewee, key=lambda msg: msg.time) else: # limit 为 0 时,应用传入的 sort 参数 if sort: - query = query.sort(sort) - results = list(query) + peewee_sort_terms = [] + for field_name, direction in sort: + if hasattr(Messages, field_name): + field = getattr(Messages, field_name) + if direction == 1: # ASC + peewee_sort_terms.append(field.asc()) + elif direction == -1: # DESC + peewee_sort_terms.append(field.desc()) + else: + logger.warning(f"字段 '{field_name}' 的排序方向 '{direction}' 无效。将跳过此排序条件。") + else: + logger.warning(f"排序字段 '{field_name}' 在 Messages 模型中未找到。将跳过此排序条件。") + if peewee_sort_terms: + query = query.order_by(*peewee_sort_terms) + peewee_results = list(query) - return results + return [_model_to_dict(msg) for msg in peewee_results] except Exception as e: log_message = ( - f"查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n" + f"使用 Peewee 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n" + traceback.format_exc() ) logger.error(log_message) @@ -60,18 +113,57 @@ def count_messages(message_filter: dict[str, Any]) -> int: 根据提供的过滤器计算消息数量。 Args: - message_filter: MongoDB 查询过滤器。 + message_filter: 查询过滤器字典,键为模型字段名,值为期望值或包含操作符的字典 (例如 {'$gt': value}). 
Returns: 符合条件的消息数量,如果出错则返回 0。 """ try: - count = db.messages.count_documents(message_filter) + query = Messages.select() + + # 应用过滤器 + if message_filter: + conditions = [] + for key, value in message_filter.items(): + if hasattr(Messages, key): + field = getattr(Messages, key) + if isinstance(value, dict): + # 处理 MongoDB 风格的操作符 + for op, op_value in value.items(): + if op == "$gt": + conditions.append(field > op_value) + elif op == "$lt": + conditions.append(field < op_value) + elif op == "$gte": + conditions.append(field >= op_value) + elif op == "$lte": + conditions.append(field <= op_value) + elif op == "$ne": + conditions.append(field != op_value) + elif op == "$in": + conditions.append(field.in_(op_value)) + elif op == "$nin": + conditions.append(field.not_in(op_value)) + else: + logger.warning( + f"计数时,过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。" + ) + else: + # 直接相等比较 + conditions.append(field == value) + else: + logger.warning(f"计数时,过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。") + if conditions: + query = query.where(*conditions) + + count = query.count() return count except Exception as e: - log_message = f"计数消息失败 (message_filter={message_filter}): {e}\n" + traceback.format_exc() + log_message = f"使用 Peewee 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}" logger.error(log_message) return 0 # 你可以在这里添加更多与 messages 集合相关的数据库操作函数,例如 find_one_message, insert_message 等。 +# 注意:对于 Peewee,插入操作通常是 Messages.create(...) 或 instance.save()。 +# 查找单个消息可以是 Messages.get_or_none(...) 
或 query.first()。 diff --git a/src/experimental/PFC/chat_observer.py b/src/experimental/PFC/chat_observer.py index 6135bd0f7..55914d800 100644 --- a/src/experimental/PFC/chat_observer.py +++ b/src/experimental/PFC/chat_observer.py @@ -10,7 +10,7 @@ from src.experimental.PFC.chat_states import ( create_new_message_notification, create_cold_chat_notification, ) -from src.experimental.PFC.message_storage import MongoDBMessageStorage +from src.experimental.PFC.message_storage import PeeweeMessageStorage from rich.traceback import install install(extra_lines=3) @@ -53,7 +53,7 @@ class ChatObserver: self.stream_id = stream_id self.private_name = private_name - self.message_storage = MongoDBMessageStorage() + self.message_storage = PeeweeMessageStorage() # self.last_user_speak_time: Optional[float] = None # 对方上次发言时间 # self.last_bot_speak_time: Optional[float] = None # 机器人上次发言时间 diff --git a/src/experimental/PFC/message_storage.py b/src/experimental/PFC/message_storage.py index cd6a01e34..e2e1dd052 100644 --- a/src/experimental/PFC/message_storage.py +++ b/src/experimental/PFC/message_storage.py @@ -1,6 +1,9 @@ from abc import ABC, abstractmethod from typing import List, Dict, Any -from src.common.database import db + +# from src.common.database.database import db # Peewee db 导入 +from src.common.database.database_model import Messages # Peewee Messages 模型导入 +from playhouse.shortcuts import model_to_dict # 用于将模型实例转换为字典 class MessageStorage(ABC): @@ -47,28 +50,35 @@ class MessageStorage(ABC): pass -class MongoDBMessageStorage(MessageStorage): - """MongoDB消息存储实现""" +class PeeweeMessageStorage(MessageStorage): + """Peewee消息存储实现""" async def get_messages_after(self, chat_id: str, message_time: float) -> List[Dict[str, Any]]: - query = {"chat_id": chat_id, "time": {"$gt": message_time}} - # print(f"storage_check_message: {message_time}") + query = ( + Messages.select() + .where((Messages.chat_id == chat_id) & (Messages.time > message_time)) + .order_by(Messages.time.asc()) + ) - 
return list(db.messages.find(query).sort("time", 1)) + # print(f"storage_check_message: {message_time}") + messages_models = list(query) + return [model_to_dict(msg) for msg in messages_models] async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]: - query = {"chat_id": chat_id, "time": {"$lt": time_point}} - - messages = list(db.messages.find(query).sort("time", -1).limit(limit)) + query = ( + Messages.select() + .where((Messages.chat_id == chat_id) & (Messages.time < time_point)) + .order_by(Messages.time.desc()) + .limit(limit) + ) + messages_models = list(query) # 将消息按时间正序排列 - messages.reverse() - return messages + messages_models.reverse() + return [model_to_dict(msg) for msg in messages_models] async def has_new_messages(self, chat_id: str, after_time: float) -> bool: - query = {"chat_id": chat_id, "time": {"$gt": after_time}} - - return db.messages.find_one(query) is not None + return Messages.select().where((Messages.chat_id == chat_id) & (Messages.time > after_time)).exists() # # 创建一个内存消息存储实现,用于测试 diff --git a/src/experimental/PFC/pfc.py b/src/experimental/PFC/pfc.py index 686d4af49..80e75c5bf 100644 --- a/src/experimental/PFC/pfc.py +++ b/src/experimental/PFC/pfc.py @@ -316,7 +316,7 @@ class GoalAnalyzer: # message_segment = Seg(type="text", data=content) # bot_user_info = UserInfo( # user_id=global_config.BOT_QQ, -# user_nickname=global_config.BOT_NICKNAME, +# user_nickname=global_config.bot.nickname, # platform=chat_stream.platform, # ) diff --git a/src/plugins.md b/src/plugins.md new file mode 100644 index 000000000..71ca741a6 --- /dev/null +++ b/src/plugins.md @@ -0,0 +1,101 @@ +# 如何编写MaiBot插件 + +## 基本步骤 + +1. 在`src/plugins/你的插件名/actions/`目录下创建插件文件 +2. 继承`PluginAction`基类 +3. 
实现`process`方法 + +## 插件结构示例 + +```python +from src.common.logger_manager import get_logger +from src.chat.focus_chat.planners.actions.plugin_action import PluginAction, register_action +from typing import Tuple + +logger = get_logger("your_action_name") + +@register_action +class YourAction(PluginAction): + """你的动作描述""" + + action_name = "your_action_name" # 动作名称,必须唯一 + action_description = "这个动作的详细描述,会展示给用户" + action_parameters = { + "param1": "参数1的说明(可选)", + "param2": "参数2的说明(可选)" + } + action_require = [ + "使用场景1", + "使用场景2" + ] + default = False # 是否默认启用 + + async def process(self) -> Tuple[bool, str]: + """插件核心逻辑""" + # 你的代码逻辑... + return True, "执行结果" +``` + +## 可用的API方法 + +插件可以使用`PluginAction`基类提供的以下API: + +### 1. 发送消息 + +```python +await self.send_message("要发送的文本", target="可选的回复目标") +``` + +### 2. 获取聊天类型 + +```python +chat_type = self.get_chat_type() # 返回 "group" 或 "private" 或 "unknown" +``` + +### 3. 获取最近消息 + +```python +messages = self.get_recent_messages(count=5) # 获取最近5条消息 +# 返回格式: [{"sender": "发送者", "content": "内容", "timestamp": 时间戳}, ...] +``` + +### 4. 获取动作参数 + +```python +param_value = self.action_data.get("param_name", "默认值") +``` + +### 5. 日志记录 + +```python +logger.info(f"{self.log_prefix} 你的日志信息") +logger.warning("警告信息") +logger.error("错误信息") +``` + +## 返回值说明 + +`process`方法必须返回一个元组,包含两个元素: +- 第一个元素(bool): 表示动作是否执行成功 +- 第二个元素(str): 执行结果的文本描述 + +```python +return True, "执行成功的消息" +# 或 +return False, "执行失败的原因" +``` + +## 最佳实践 + +1. 使用`action_parameters`清晰定义你的动作需要的参数 +2. 使用`action_require`描述何时应该使用你的动作 +3. 使用`action_description`准确描述你的动作功能 +4. 使用`logger`记录重要信息,方便调试 +5. 
避免操作底层系统,尽量使用`PluginAction`提供的API + +## 注册与加载 + +插件会在系统启动时自动加载,只要放在正确的目录并添加了`@register_action`装饰器。 + +若设置`default = True`,插件会自动添加到默认动作集;否则需要在系统中手动启用。 diff --git a/src/plugins/__init__.py b/src/plugins/__init__.py new file mode 100644 index 000000000..0b0692d42 --- /dev/null +++ b/src/plugins/__init__.py @@ -0,0 +1 @@ +"""插件系统包""" diff --git a/src/plugins/test_plugin/__init__.py b/src/plugins/test_plugin/__init__.py new file mode 100644 index 000000000..b5fefb97e --- /dev/null +++ b/src/plugins/test_plugin/__init__.py @@ -0,0 +1,5 @@ +"""测试插件包""" + +""" +这是一个测试插件 +""" diff --git a/src/plugins/test_plugin/actions/__init__.py b/src/plugins/test_plugin/actions/__init__.py new file mode 100644 index 000000000..7d96ea8a4 --- /dev/null +++ b/src/plugins/test_plugin/actions/__init__.py @@ -0,0 +1,7 @@ +"""测试插件动作模块""" + +# 导入所有动作模块以确保装饰器被执行 +from . import test_action # noqa + +# from . import online_action # noqa +from . import mute_action # noqa diff --git a/src/plugins/test_plugin/actions/mute_action.py b/src/plugins/test_plugin/actions/mute_action.py new file mode 100644 index 000000000..c96204172 --- /dev/null +++ b/src/plugins/test_plugin/actions/mute_action.py @@ -0,0 +1,51 @@ +from src.common.logger_manager import get_logger +from src.chat.focus_chat.planners.actions.plugin_action import PluginAction, register_action +from typing import Tuple + +logger = get_logger("mute_action") + + +@register_action +class MuteAction(PluginAction): + """测试动作处理类""" + + action_name = "mute_action" + action_description = ( + "如果某人违反了公序良俗,或者别人戳你太多,,或者某人刷屏,一定要禁言某人,如果你很生气,可以禁言某人" + ) + action_parameters = { + "target": "禁言对象,输入你要禁言的对象的名字,必填,", + "duration": "禁言时长,输入你要禁言的时长,单位为秒,必填", + } + action_require = [ + "当有人违反了公序良俗时使用", + "当有人刷屏时使用", + "当有人要求禁言自己时使用", + "当有人戳你两次以上时,防止刷屏,禁言他,必须牢记", + "当千石可乐或可乐酱要求你禁言时使用", + "当你想回避某个话题时使用", + ] + default = True # 不是默认动作,需要手动添加到使用集 + + async def process(self) -> Tuple[bool, str]: + """处理测试动作""" + logger.info(f"{self.log_prefix} 执行online动作: 
{self.reasoning}") + + # 发送测试消息 + target = self.action_data.get("target") + duration = self.action_data.get("duration") + reason = self.action_data.get("reason") + platform, user_id = await self.get_user_id_by_person_name(target) + + await self.send_message_by_expressor(f"我要禁言{target},{platform},时长{duration}秒,理由{reason},表达情绪") + + try: + await self.send_message(f"[command]mute,{user_id},{duration}") + + except Exception as e: + logger.error(f"{self.log_prefix} 执行mute动作时出错: {e}") + await self.send_message_by_expressor(f"执行mute动作时出错: {e}") + + return False, "执行mute动作时出错" + + return True, "测试动作执行成功" diff --git a/src/plugins/test_plugin/actions/online_action.py b/src/plugins/test_plugin/actions/online_action.py new file mode 100644 index 000000000..4f49045f2 --- /dev/null +++ b/src/plugins/test_plugin/actions/online_action.py @@ -0,0 +1,43 @@ +from src.common.logger_manager import get_logger +from src.chat.focus_chat.planners.actions.plugin_action import PluginAction, register_action +from typing import Tuple + +logger = get_logger("check_online_action") + + +@register_action +class CheckOnlineAction(PluginAction): + """测试动作处理类""" + + action_name = "check_online_action" + action_description = "这是一个检查在线状态的动作,当有人要求你检查Maibot(麦麦 机器人)在线状态时使用" + action_parameters = {"mode": "查看模式"} + action_require = [ + "当有人要求你检查Maibot(麦麦 机器人)在线状态时使用", + "mode参数为version时查看在线版本状态,默认用这种", + "mode参数为type时查看在线系统类型分布", + ] + default = True # 不是默认动作,需要手动添加到使用集 + + async def process(self) -> Tuple[bool, str]: + """处理测试动作""" + logger.info(f"{self.log_prefix} 执行online动作: {self.reasoning}") + + # 发送测试消息 + mode = self.action_data.get("mode", "type") + + await self.send_message_by_expressor("我看看") + + try: + if mode == "type": + await self.send_message("#online detail") + elif mode == "version": + await self.send_message("#online") + + except Exception as e: + logger.error(f"{self.log_prefix} 执行online动作时出错: {e}") + await self.send_message_by_expressor("执行online动作时出错: {e}") + + return False, 
"执行online动作时出错" + + return True, "测试动作执行成功" diff --git a/src/plugins/test_plugin/actions/test_action.py b/src/plugins/test_plugin/actions/test_action.py new file mode 100644 index 000000000..995dd918a --- /dev/null +++ b/src/plugins/test_plugin/actions/test_action.py @@ -0,0 +1,37 @@ +from src.common.logger_manager import get_logger +from src.chat.focus_chat.planners.actions.plugin_action import PluginAction, register_action +from typing import Tuple + +logger = get_logger("test_action") + + +@register_action +class TestAction(PluginAction): + """测试动作处理类""" + + action_name = "test_action" + action_description = "这是一个测试动作,当有人要求你测试插件系统时使用" + action_parameters = {"test_param": "测试参数(可选)"} + action_require = [ + "测试情况下使用", + "想测试插件动作加载时使用", + ] + default = False # 不是默认动作,需要手动添加到使用集 + + async def process(self) -> Tuple[bool, str]: + """处理测试动作""" + logger.info(f"{self.log_prefix} 执行测试动作: {self.reasoning}") + + # 获取聊天类型 + chat_type = self.get_chat_type() + logger.info(f"{self.log_prefix} 当前聊天类型: {chat_type}") + + # 获取最近消息 + recent_messages = self.get_recent_messages(3) + logger.info(f"{self.log_prefix} 最近3条消息: {recent_messages}") + + # 发送测试消息 + test_param = self.action_data.get("test_param", "默认参数") + await self.send_message_by_expressor(f"测试动作执行成功,参数: {test_param}") + + return True, "测试动作执行成功" diff --git a/src/tools/tool_can_use/get_knowledge.py b/src/tools/tool_can_use/get_knowledge.py index 65acd55c0..fd37f11e7 100644 --- a/src/tools/tool_can_use/get_knowledge.py +++ b/src/tools/tool_can_use/get_knowledge.py @@ -1,8 +1,10 @@ from src.tools.tool_can_use.base_tool import BaseTool from src.chat.utils.utils import get_embedding -from src.common.database import db +from src.common.database.database_model import Knowledges # Updated import from src.common.logger_manager import get_logger -from typing import Any, Union +from typing import Any, Union, List # Added List +import json # Added for parsing embedding +import math # Added for cosine similarity logger = 
get_logger("get_knowledge_tool") @@ -30,6 +32,7 @@ class SearchKnowledgeTool(BaseTool): Returns: dict: 工具执行结果 """ + query = "" # Initialize query to ensure it's defined in except block try: query = function_args.get("query") threshold = function_args.get("threshold", 0.4) @@ -48,9 +51,19 @@ class SearchKnowledgeTool(BaseTool): logger.error(f"知识库搜索工具执行失败: {str(e)}") return {"type": "info", "id": query, "content": f"知识库搜索失败,炸了: {str(e)}"} + @staticmethod + def _cosine_similarity(vec1: List[float], vec2: List[float]) -> float: + """计算两个向量之间的余弦相似度""" + dot_product = sum(p * q for p, q in zip(vec1, vec2)) + magnitude1 = math.sqrt(sum(p * p for p in vec1)) + magnitude2 = math.sqrt(sum(q * q for q in vec2)) + if magnitude1 == 0 or magnitude2 == 0: + return 0.0 + return dot_product / (magnitude1 * magnitude2) + @staticmethod def get_info_from_db( - query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False + query_embedding: list[float], limit: int = 1, threshold: float = 0.5, return_raw: bool = False ) -> Union[str, list]: """从数据库中获取相关信息 @@ -66,66 +79,51 @@ class SearchKnowledgeTool(BaseTool): if not query_embedding: return "" if not return_raw else [] - # 使用余弦相似度计算 - pipeline = [ - { - "$addFields": { - "dotProduct": { - "$reduce": { - "input": {"$range": [0, {"$size": "$embedding"}]}, - "initialValue": 0, - "in": { - "$add": [ - "$$value", - { - "$multiply": [ - {"$arrayElemAt": ["$embedding", "$$this"]}, - {"$arrayElemAt": [query_embedding, "$$this"]}, - ] - }, - ] - }, - } - }, - "magnitude1": { - "$sqrt": { - "$reduce": { - "input": "$embedding", - "initialValue": 0, - "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}, - } - } - }, - "magnitude2": { - "$sqrt": { - "$reduce": { - "input": query_embedding, - "initialValue": 0, - "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}, - } - } - }, - } - }, - {"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}}, - 
{ - "$match": { - "similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果 - } - }, - {"$sort": {"similarity": -1}}, - {"$limit": limit}, - {"$project": {"content": 1, "similarity": 1}}, - ] + similar_items = [] + try: + all_knowledges = Knowledges.select() + for item in all_knowledges: + try: + item_embedding_str = item.embedding + if not item_embedding_str: + logger.warning(f"Knowledge item ID {item.id} has empty embedding string.") + continue + item_embedding = json.loads(item_embedding_str) + if not isinstance(item_embedding, list) or not all( + isinstance(x, (int, float)) for x in item_embedding + ): + logger.warning(f"Knowledge item ID {item.id} has invalid embedding format after JSON parsing.") + continue + except json.JSONDecodeError: + logger.warning(f"Failed to parse embedding for knowledge item ID {item.id}") + continue + except AttributeError: + logger.warning(f"Knowledge item ID {item.id} missing 'embedding' attribute or it's not a string.") + continue - results = list(db.knowledges.aggregate(pipeline)) - logger.debug(f"知识库查询结果数量: {len(results)}") + similarity = SearchKnowledgeTool._cosine_similarity(query_embedding, item_embedding) + + if similarity >= threshold: + similar_items.append({"content": item.content, "similarity": similarity, "raw_item": item}) + + # 按相似度降序排序 + similar_items.sort(key=lambda x: x["similarity"], reverse=True) + + # 应用限制 + results = similar_items[:limit] + logger.debug(f"知识库查询后,符合条件的结果数量: {len(results)}") + + except Exception as e: + logger.error(f"从 Peewee 数据库获取知识信息失败: {str(e)}") + return "" if not return_raw else [] if not results: return "" if not return_raw else [] if return_raw: - return results + # Peewee 模型实例不能直接序列化为 JSON,如果需要原始模型,调用者需要处理 + # 这里返回包含内容和相似度的字典列表 + return [{"content": r["content"], "similarity": r["similarity"]} for r in results] else: # 返回所有找到的内容,用换行分隔 return "\n".join(str(result["content"]) for result in results) diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index 
64e51da77..8ffbcfa92 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -18,11 +18,11 @@ nickname = "麦麦" alias_names = ["麦叠", "牢麦"] #该选项还在调试中,暂时未生效 [chat_target] -talk_allowed = [ +talk_allowed_groups = [ 123, 123, ] #可以回复消息的群号码 -talk_frequency_down = [] #降低回复频率的群号码 +talk_frequency_down_groups = [] #降低回复频率的群号码 ban_user_id = [] #禁止回复和读取消息的QQ号 [personality] #未完善 @@ -63,7 +63,7 @@ allow_focus_mode = false # 是否允许专注聊天状态 base_normal_chat_num = 999 # 最多允许多少个群进行普通聊天 base_focused_chat_num = 4 # 最多允许多少个群进行专注聊天 -observation_context_size = 15 # 观察到的最长上下文大小,建议15,太短太长都会导致脑袋尖尖 +chat.observation_context_size = 15 # 观察到的最长上下文大小,建议15,太短太长都会导致脑袋尖尖 message_buffer = true # 启用消息缓冲器?启用此项以解决消息的拆分问题,但会使麦麦的回复延迟 # 以下是消息过滤,可以根据规则过滤特定消息,将不会读取这些消息 @@ -99,7 +99,7 @@ default_decay_rate_per_second = 0.98 # 默认衰减率,越大衰减越快, consecutive_no_reply_threshold = 3 # 连续不回复的阈值,越低越容易结束专注聊天 # 以下选项暂时无效 -compressed_length = 5 # 不能大于observation_context_size,心流上下文压缩的最短压缩长度,超过心流观察到的上下文长度,会压缩,最短压缩长度为5 +compressed_length = 5 # 不能大于chat.observation_context_size,心流上下文压缩的最短压缩长度,超过心流观察到的上下文长度,会压缩,最短压缩长度为5 compress_length_limit = 5 #最多压缩份数,超过该数值的压缩上下文会被删除 @@ -121,7 +121,7 @@ memory_build_sample_length = 40 # 采样长度,数值越高一段记忆内容 memory_compress_rate = 0.1 # 记忆压缩率 控制记忆精简程度 建议保持默认,调高可以获得更多信息,但是冗余信息也会增多 forget_memory_interval = 1000 # 记忆遗忘间隔 单位秒 间隔越低,麦麦遗忘越频繁,记忆更精简,但更难学习 -memory_forget_time = 24 #多长时间后的记忆会被遗忘 单位小时 +memory_forget_time = 24 #多长时间后的记忆会被遗忘 单位小时 memory_forget_percentage = 0.01 # 记忆遗忘比例 控制记忆遗忘程度 越大遗忘越多 建议保持默认 consolidate_memory_interval = 1000 # 记忆整合间隔 单位秒 间隔越低,麦麦整合越频繁,记忆更精简