From aba4f1a947fcf29283d54487b2f475914860743f Mon Sep 17 00:00:00 2001 From: tt-P607 <68868379+tt-P607@users.noreply.github.com> Date: Sat, 20 Sep 2025 02:21:53 +0800 Subject: [PATCH] =?UTF-8?q?=E4=B8=89=E6=AC=A1=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/chat_loop/heartFC_chat.py | 2 +- src/chat/emoji_system/emoji_manager.py | 61 +-- src/chat/express/expression_selector.py | 51 +-- src/chat/memory_system/Hippocampus.py | 79 ++-- src/chat/message_receive/chat_stream.py | 34 +- src/chat/message_receive/storage.py | 48 +- src/chat/planner_actions/action_modifier.py | 4 +- src/chat/planner_actions/plan_filter.py | 4 +- src/chat/planner_actions/plan_generator.py | 2 +- src/chat/replyer/default_generator.py | 17 +- src/chat/utils/chat_message_builder.py | 134 +++--- src/common/message_repository.py | 18 +- src/main.py | 10 +- src/person_info/person_info.py | 257 ++++------- src/person_info/relationship_fetcher.py | 457 ++++++++++++++++++++ src/plugin_system/apis/message_api.py | 18 +- src/schedule/database.py | 147 ++++--- src/schedule/monthly_plan_manager.py | 5 + src/schedule/plan_manager.py | 18 +- src/schedule/schedule_manager.py | 36 +- 20 files changed, 923 insertions(+), 479 deletions(-) create mode 100644 src/person_info/relationship_fetcher.py diff --git a/src/chat/chat_loop/heartFC_chat.py b/src/chat/chat_loop/heartFC_chat.py index 6f63cff1b..05edb3ee0 100644 --- a/src/chat/chat_loop/heartFC_chat.py +++ b/src/chat/chat_loop/heartFC_chat.py @@ -361,7 +361,7 @@ class HeartFChatting: # 核心修复:在睡眠模式(包括失眠)下获取消息时,不过滤命令消息,以确保@消息能被接收 filter_command_flag = not (is_sleeping or is_in_insomnia) - recent_messages = message_api.get_messages_by_time_in_chat( + recent_messages = await message_api.get_messages_by_time_in_chat( chat_id=self.context.stream_id, start_time=self.context.last_read_time, end_time=time.time(), diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py index 8e6079897..ce7b0d074 100644 --- a/src/chat/emoji_system/emoji_manager.py +++ b/src/chat/emoji_system/emoji_manager.py @@ -149,7 +149,7 @@ class MaiEmoji: # --- 数据库操作 --- try: # 准备数据库记录 for emoji collection - with get_db_session() as session: + async with get_db_session() as session: emotion_str = ",".join(self.emotion) if self.emotion else "" emoji = Emoji( @@ -167,7 +167,7 @@ class MaiEmoji: last_used_time=self.last_used_time, ) session.add(emoji) - session.commit() + await session.commit() logger.info(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})") @@ -203,17 +203,17 @@ class MaiEmoji: # 2. 删除数据库记录 try: - with get_db_session() as session: - will_delete_emoji = session.execute( - select(Emoji).where(Emoji.emoji_hash == self.hash) + async with get_db_session() as session: + will_delete_emoji = ( + await session.execute(select(Emoji).where(Emoji.emoji_hash == self.hash)) ).scalar_one_or_none() if will_delete_emoji is None: logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。") - result = 0 # Indicate no DB record was deleted + result = 0 else: - session.delete(will_delete_emoji) - result = 1 # Successfully deleted one record - session.commit() + await session.delete(will_delete_emoji) + result = 1 + await session.commit() except Exception as e: logger.error(f"[错误] 删除数据库记录时出错: {str(e)}") result = 0 @@ -424,17 +424,19 @@ class EmojiManager: # if not self._initialized: # raise RuntimeError("EmojiManager not initialized") - def record_usage(self, emoji_hash: str) -> None: + async def record_usage(self, emoji_hash: str) -> None: """记录表情使用次数""" try: - with get_db_session() as session: - emoji_update = session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalar_one_or_none() + async with get_db_session() as session: + emoji_update = ( + await session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)) + ).scalar_one_or_none() if emoji_update is None: logger.error(f"记录表情使用失败: 未找到 hash 为 {emoji_hash} 的表情包") else: emoji_update.usage_count += 1 - emoji_update.last_used_time = time.time() # Update last used time - session.commit() + emoji_update.last_used_time = time.time() + await session.commit() except Exception as e: logger.error(f"记录表情使用失败: {str(e)}") @@ -658,10 +660,11 @@ class EmojiManager: async def get_all_emoji_from_db(self) -> None: """获取所有表情包并初始化为MaiEmoji类对象,更新 self.emoji_objects""" try: - with get_db_session() as session: + async with get_db_session() as session: logger.debug("[数据库] 开始加载所有表情包记录 ...") - emoji_instances = session.execute(select(Emoji)).scalars().all() + result = await session.execute(select(Emoji)) + emoji_instances = result.scalars().all() emoji_objects, load_errors = _to_emoji_objects(emoji_instances) # 更新内存中的列表和数量 @@ -687,14 +690,16 @@ class EmojiManager: list[MaiEmoji]: 表情包对象列表 """ try: - with get_db_session() as session: + async with get_db_session() as session: if emoji_hash: - query = session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalars().all() + result = await session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)) + query = result.scalars().all() else: logger.warning( "[查询] 未提供 hash,将尝试加载所有表情包,建议使用 get_all_emoji_from_db 更新管理器状态。" ) - query = session.execute(select(Emoji)).scalars().all() + result = await session.execute(select(Emoji)) + query = result.scalars().all() emoji_instances = query emoji_objects, load_errors = _to_emoji_objects(emoji_instances) @@ -771,10 +776,11 @@ class EmojiManager: # 如果内存中没有,从数据库查找 try: - with get_db_session() as session: - emoji_record = session.execute( + async with get_db_session() as session: + result = await session.execute( select(Emoji).where(Emoji.emoji_hash == emoji_hash) - ).scalar_one_or_none() + ) + emoji_record = result.scalar_one_or_none() if emoji_record and emoji_record.description: logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.description[:50]}...") return emoji_record.description @@ -937,10 +943,13 @@ class EmojiManager: # 2. 检查数据库中是否已存在该表情包的描述,实现复用 existing_description = None try: - with get_db_session() as session: - existing_image = session.query(Images).filter( - (Images.emoji_hash == image_hash) & (Images.type == "emoji") - ).one_or_none() + async with get_db_session() as session: + result = await session.execute( + select(Images).filter( + (Images.emoji_hash == image_hash) & (Images.type == "emoji") + ) + ) + existing_image = result.scalar_one_or_none() if existing_image and existing_image.description: existing_description = existing_image.description logger.info(f"[复用描述] 找到已有详细描述: {existing_description[:50]}...") diff --git a/src/chat/express/expression_selector.py b/src/chat/express/expression_selector.py index aa3528185..810c3326c 100644 --- a/src/chat/express/expression_selector.py +++ b/src/chat/express/expression_selector.py @@ -136,18 +136,18 @@ class ExpressionSelector: return related_chat_ids if related_chat_ids else [chat_id] - def get_random_expressions( - self, chat_id: str, total_num: int - ) -> List[Dict[str, Any]]: + async def get_random_expressions( + self, chat_id: str, total_num: int, style_percentage: float, grammar_percentage: float + ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: # sourcery skip: extract-duplicate-method, move-assign # 支持多chat_id合并抽选 related_chat_ids = self.get_related_chat_ids(chat_id) - with get_db_session() as session: + async with get_db_session() as session: # 优化:一次性查询所有相关chat_id的表达方式 - style_query = session.execute( + style_query = await session.execute( select(Expression).where((Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style")) ) - grammar_query = session.execute( + grammar_query = await session.execute( select(Expression).where((Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "grammar")) ) @@ -193,7 +193,7 @@ class ExpressionSelector: return selected_style, selected_grammar - def update_expressions_count_batch(self, expressions_to_update: List[Dict[str, Any]], increment: float = 0.1): + async def update_expressions_count_batch(self, expressions_to_update: List[Dict[str, Any]], increment: float = 0.1): """对一批表达方式更新count值,按chat_id+type分组后一次性写入数据库""" if not expressions_to_update: return @@ -210,26 +210,27 @@ class ExpressionSelector: if key not in updates_by_key: updates_by_key[key] = expr for chat_id, expr_type, situation, style in updates_by_key: - with get_db_session() as session: - query = session.execute( + async with get_db_session() as session: + query = await session.execute( select(Expression).where( (Expression.chat_id == chat_id) & (Expression.type == expr_type) & (Expression.situation == situation) & (Expression.style == style) ) - ).scalar() - if query: - expr_obj = query - current_count = expr_obj.count - new_count = min(current_count + increment, 5.0) - expr_obj.count = new_count - expr_obj.last_active_time = time.time() - - logger.debug( - f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in db" ) - session.commit() + query = query.scalar() + if query: + expr_obj = query + current_count = expr_obj.count + new_count = min(current_count + increment, 5.0) + expr_obj.count = new_count + expr_obj.last_active_time = time.time() + + logger.debug( + f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in db" + ) + await session.commit() async def select_suitable_expressions_llm( self, @@ -246,12 +247,8 @@ class ExpressionSelector: logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表") return [], [] - # 1. 获取20个随机表达方式(现在按权重抽取) - style_exprs = self.get_random_expressions(chat_id, 10) - - if len(style_exprs) < 10: - logger.info(f"聊天流 {chat_id} 表达方式正在积累中") - return [], [] + # 1. 获取35个随机表达方式(现在按权重抽取) + style_exprs, grammar_exprs = await self.get_random_expressions(chat_id, 30, 0.5, 0.5) # 2. 构建所有表达方式的索引和情境列表 all_expressions: List[Dict[str, Any]] = [] @@ -326,7 +323,7 @@ class ExpressionSelector: # 对选中的所有表达方式,一次性更新count数 if valid_expressions: - self.update_expressions_count_batch(valid_expressions, 0.006) + await self.update_expressions_count_batch(valid_expressions, 0.006) # logger.info(f"LLM从{len(all_expressions)}个情境中选择了{len(valid_expressions)}个") return valid_expressions, selected_ids diff --git a/src/chat/memory_system/Hippocampus.py b/src/chat/memory_system/Hippocampus.py index 98fd4e1c7..2b33e59be 100644 --- a/src/chat/memory_system/Hippocampus.py +++ b/src/chat/memory_system/Hippocampus.py @@ -267,8 +267,8 @@ class Hippocampus: self.entorhinal_cortex = EntorhinalCortex(self) self.parahippocampal_gyrus = ParahippocampalGyrus(self) # 从数据库加载记忆图 - self.entorhinal_cortex.sync_memory_from_db() - self.model_small = LLMRequest(model_set=model_config.model_task_config.utils_small, request_type="memory.modify") + # self.entorhinal_cortex.sync_memory_from_db() # 改为异步启动 + self.model_small = LLMRequest(model_set=model_config.model_task_config.utils_small, request_type="memory.small") def get_all_node_names(self) -> list: """获取记忆图中所有节点的名字列表""" @@ -877,7 +877,7 @@ class EntorhinalCortex: self.hippocampus = hippocampus self.memory_graph = hippocampus.memory_graph - def get_memory_sample(self): + async def get_memory_sample(self): """从数据库获取记忆样本""" # 硬编码:每条消息最大记忆次数 max_memorized_time_per_msg = 2 @@ -900,7 +900,7 @@ class EntorhinalCortex: logger.debug(f"回忆往事: {readable_timestamp}") chat_samples = [] for timestamp in timestamps: - if messages := self.random_get_msg_snippet( + if messages := await self.random_get_msg_snippet( timestamp, global_config.memory.memory_build_sample_length, max_memorized_time_per_msg, @@ -914,7 +914,9 @@ class EntorhinalCortex: return chat_samples @staticmethod - def random_get_msg_snippet(target_timestamp: float, chat_size: int, max_memorized_time_per_msg: int) -> list | None: + async def random_get_msg_snippet( + target_timestamp: float, chat_size: int, max_memorized_time_per_msg: int + ) -> list | None: # sourcery skip: invert-any-all, use-any, use-named-expression, use-next """从数据库中随机获取指定时间戳附近的消息片段 (使用 chat_message_builder)""" time_window_seconds = random.randint(300, 1800) # 随机时间窗口,5到30分钟 @@ -952,13 +954,13 @@ class EntorhinalCortex: for message in messages: # 确保在更新前获取最新的 memorized_times current_memorized_times = message.get("memorized_times", 0) - with get_db_session() as session: - session.execute( + async with get_db_session() as session: + await session.execute( update(Messages) .where(Messages.message_id == message["message_id"]) .values(memorized_times=current_memorized_times + 1) ) - session.commit() + await session.commit() return messages # 直接返回原始的消息列表 target_timestamp -= 120 # 如果第一次尝试失败,稍微向前调整时间戳再试 @@ -972,8 +974,8 @@ class EntorhinalCortex: current_time = datetime.datetime.now().timestamp() # 获取数据库中所有节点和内存中所有节点 - with get_db_session() as session: - db_nodes = {node.concept: node for node in session.execute(select(GraphNodes)).scalars()} + async with get_db_session() as session: + db_nodes = {node.concept: node for node in (await session.execute(select(GraphNodes))).scalars()} memory_nodes = list(self.memory_graph.G.nodes(data=True)) # 批量准备节点数据 @@ -1043,24 +1045,24 @@ class EntorhinalCortex: batch_size = 100 for i in range(0, len(nodes_to_create), batch_size): batch = nodes_to_create[i : i + batch_size] - session.execute(insert(GraphNodes), batch) + await session.execute(insert(GraphNodes), batch) if nodes_to_update: batch_size = 100 for i in range(0, len(nodes_to_update), batch_size): batch = nodes_to_update[i : i + batch_size] for node_data in batch: - session.execute( + await session.execute( update(GraphNodes) .where(GraphNodes.concept == node_data["concept"]) .values(**{k: v for k, v in node_data.items() if k != "concept"}) ) if nodes_to_delete: - session.execute(delete(GraphNodes).where(GraphNodes.concept.in_(nodes_to_delete))) + await session.execute(delete(GraphNodes).where(GraphNodes.concept.in_(nodes_to_delete))) # 处理边的信息 - db_edges = list(session.execute(select(GraphEdges)).scalars()) + db_edges = list((await session.execute(select(GraphEdges))).scalars()) memory_edges = list(self.memory_graph.G.edges(data=True)) # 创建边的哈希值字典 @@ -1112,14 +1114,14 @@ class EntorhinalCortex: batch_size = 100 for i in range(0, len(edges_to_create), batch_size): batch = edges_to_create[i : i + batch_size] - session.execute(insert(GraphEdges), batch) + await session.execute(insert(GraphEdges), batch) if edges_to_update: batch_size = 100 for i in range(0, len(edges_to_update), batch_size): batch = edges_to_update[i : i + batch_size] for edge_data in batch: - session.execute( + await session.execute( update(GraphEdges) .where( (GraphEdges.source == edge_data["source"]) & (GraphEdges.target == edge_data["target"]) @@ -1129,12 +1131,12 @@ class EntorhinalCortex: if edges_to_delete: for source, target in edges_to_delete: - session.execute( + await session.execute( delete(GraphEdges).where((GraphEdges.source == source) & (GraphEdges.target == target)) ) # 提交事务 - session.commit() + await session.commit() end_time = time.time() logger.info(f"[同步] 总耗时: {end_time - start_time:.2f}秒") @@ -1146,10 +1148,10 @@ class EntorhinalCortex: logger.info("[数据库] 开始重新同步所有记忆数据...") # 清空数据库 - with get_db_session() as session: + async with get_db_session() as session: clear_start = time.time() - session.execute(delete(GraphNodes)) - session.execute(delete(GraphEdges)) + await session.execute(delete(GraphNodes)) + await session.execute(delete(GraphEdges)) clear_end = time.time() logger.info(f"[数据库] 清空数据库耗时: {clear_end - clear_start:.2f}秒") @@ -1208,7 +1210,7 @@ class EntorhinalCortex: batch_size = 500 # 增加批量大小 for i in range(0, len(nodes_data), batch_size): batch = nodes_data[i : i + batch_size] - session.execute(insert(GraphNodes), batch) + await session.execute(insert(GraphNodes), batch) node_end = time.time() logger.info(f"[数据库] 写入 {len(nodes_data)} 个节点耗时: {node_end - node_start:.2f}秒") @@ -1219,8 +1221,8 @@ class EntorhinalCortex: batch_size = 500 # 增加批量大小 for i in range(0, len(edges_data), batch_size): batch = edges_data[i : i + batch_size] - session.execute(insert(GraphEdges), batch) - session.commit() + await session.execute(insert(GraphEdges), batch) + await session.commit() edge_end = time.time() logger.info(f"[数据库] 写入 {len(edges_data)} 条边耗时: {edge_end - edge_start:.2f}秒") @@ -1229,7 +1231,7 @@ class EntorhinalCortex: logger.info(f"[数据库] 重新同步完成,总耗时: {end_time - start_time:.2f}秒") logger.info(f"[数据库] 同步了 {len(nodes_data)} 个节点和 {len(edges_data)} 条边") - def sync_memory_from_db(self): + async def sync_memory_from_db(self): """从数据库同步数据到内存中的图结构""" current_time = datetime.datetime.now().timestamp() need_update = False @@ -1243,8 +1245,8 @@ class EntorhinalCortex: skipped_nodes = 0 # 从数据库加载所有节点 - with get_db_session() as session: - nodes = list(session.execute(select(GraphNodes)).scalars()) + async with get_db_session() as session: + nodes = list((await session.execute(select(GraphNodes))).scalars()) for node in nodes: concept = node.concept try: @@ -1262,7 +1264,9 @@ class EntorhinalCortex: if not node.last_modified: update_data["last_modified"] = current_time - session.execute(update(GraphNodes).where(GraphNodes.concept == concept).values(**update_data)) + await session.execute( + update(GraphNodes).where(GraphNodes.concept == concept).values(**update_data) + ) # 获取时间信息(如果不存在则使用当前时间) created_time = node.created_time or current_time @@ -1277,7 +1281,7 @@ class EntorhinalCortex: continue # 从数据库加载所有边 - edges = list(session.execute(select(GraphEdges)).scalars()) + edges = list((await session.execute(select(GraphEdges))).scalars()) for edge in edges: source = edge.source target = edge.target @@ -1293,7 +1297,7 @@ class EntorhinalCortex: if not edge.last_modified: update_data["last_modified"] = current_time - session.execute( + await session.execute( update(GraphEdges) .where((GraphEdges.source == source) & (GraphEdges.target == target)) .values(**update_data) @@ -1308,7 +1312,7 @@ class EntorhinalCortex: self.memory_graph.G.add_edge( source, target, strength=strength, created_time=created_time, last_modified=last_modified ) - session.commit() + await session.commit() if need_update: logger.info("[数据库] 已为缺失的时间字段进行补充") @@ -1348,7 +1352,7 @@ class ParahippocampalGyrus: # 1. 使用 build_readable_messages 生成格式化文本 # build_readable_messages 只返回一个字符串,不需要解包 - input_text = build_readable_messages( + input_text = await build_readable_messages( messages, merge_messages=True, # 合并连续消息 timestamp_mode="normal_no_YMD", # 使用 'YYYY-MM-DD HH:MM:SS' 格式 @@ -1436,7 +1440,7 @@ class ParahippocampalGyrus: # sourcery skip: merge-list-appends-into-extend logger.info("------------------------------------开始构建记忆--------------------------------------") start_time = time.time() - memory_samples = self.hippocampus.entorhinal_cortex.get_memory_sample() + memory_samples = await self.hippocampus.entorhinal_cortex.get_memory_sample() all_added_nodes = [] all_connected_nodes = [] all_added_edges = [] @@ -1697,7 +1701,7 @@ class HippocampusManager: return self._hippocampus self._hippocampus = Hippocampus() - self._hippocampus.initialize() + # self._hippocampus.initialize() # 改为异步启动 self._initialized = True # 输出记忆图统计信息 @@ -1716,6 +1720,13 @@ class HippocampusManager: return self._hippocampus + async def initialize_async(self): + """异步初始化海马体实例""" + if not self._initialized: + self.initialize() # 先进行同步部分的初始化 + self._hippocampus.initialize() + await self._hippocampus.entorhinal_cortex.sync_memory_from_db() + def get_hippocampus(self): if not self._initialized: raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法") diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index 77d1abb60..b0de6f596 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -246,11 +246,11 @@ class ChatManager: return stream # 检查数据库中是否存在 - def _db_find_stream_sync(s_id: str): - with get_db_session() as session: - return session.execute(select(ChatStreams).where(ChatStreams.stream_id == s_id)).scalar() + async def _db_find_stream_async(s_id: str): + async with get_db_session() as session: + return (await session.execute(select(ChatStreams).where(ChatStreams.stream_id == s_id))).scalar() - model_instance = await asyncio.to_thread(_db_find_stream_sync, stream_id) + model_instance = await _db_find_stream_async(stream_id) if model_instance: # 从 Peewee 模型转换回 ChatStream.from_dict 期望的格式 @@ -344,11 +344,10 @@ class ChatManager: return stream_data_dict = stream.to_dict() - def _db_save_stream_sync(s_data_dict: dict): - with get_db_session() as session: + async def _db_save_stream_async(s_data_dict: dict): + async with get_db_session() as session: user_info_d = s_data_dict.get("user_info") group_info_d = s_data_dict.get("group_info") - fields_to_save = { "platform": s_data_dict["platform"], "create_time": s_data_dict["create_time"], @@ -364,8 +363,6 @@ class ChatManager: "sleep_pressure": s_data_dict.get("sleep_pressure", 0.0), "focus_energy": s_data_dict.get("focus_energy", global_config.chat.focus_value), } - - # 根据数据库类型选择插入语句 if global_config.database.database_type == "sqlite": stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save) stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=fields_to_save) @@ -375,15 +372,13 @@ class ChatManager: **{key: value for key, value in fields_to_save.items() if key != "stream_id"} ) else: - # 默认使用通用插入,尝试SQLite语法 stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save) stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=fields_to_save) - - session.execute(stmt) - session.commit() + await session.execute(stmt) + await session.commit() try: - await asyncio.to_thread(_db_save_stream_sync, stream_data_dict) + await _db_save_stream_async(stream_data_dict) stream.saved = True except Exception as e: logger.error(f"保存聊天流 {stream.stream_id} 到数据库失败 (Peewee): {e}", exc_info=True) @@ -397,10 +392,10 @@ class ChatManager: """从数据库加载所有聊天流""" logger.info("正在从数据库加载所有聊天流") - def _db_load_all_streams_sync(): + async def _db_load_all_streams_async(): loaded_streams_data = [] - with get_db_session() as session: - for model_instance in session.execute(select(ChatStreams)).scalars(): + async with get_db_session() as session: + for model_instance in (await session.execute(select(ChatStreams))).scalars(): user_info_data = { "platform": model_instance.user_platform, "user_id": model_instance.user_id, @@ -414,7 +409,6 @@ class ChatManager: "group_id": model_instance.group_id, "group_name": model_instance.group_name, } - data_for_from_dict = { "stream_id": model_instance.stream_id, "platform": model_instance.platform, @@ -427,11 +421,11 @@ class ChatManager: "focus_energy": getattr(model_instance, "focus_energy", global_config.chat.focus_value), } loaded_streams_data.append(data_for_from_dict) - session.commit() + await session.commit() return loaded_streams_data try: - all_streams_data_list = await asyncio.to_thread(_db_load_all_streams_sync) + all_streams_data_list = await _db_load_all_streams_async() self.streams.clear() for data in all_streams_data_list: stream = ChatStream.from_dict(data) diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py index 2edcbe92f..9b6c69d40 100644 --- a/src/chat/message_receive/storage.py +++ b/src/chat/message_receive/storage.py @@ -42,7 +42,7 @@ class MessageStorage: processed_plain_text = message.processed_plain_text if processed_plain_text: - processed_plain_text = MessageStorage.replace_image_descriptions(processed_plain_text) + processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text) filtered_processed_plain_text = re.sub(pattern, "", processed_plain_text, flags=re.DOTALL) else: filtered_processed_plain_text = "" @@ -129,9 +129,9 @@ class MessageStorage: key_words=key_words, key_words_lite=key_words_lite, ) - with get_db_session() as session: + async with get_db_session() as session: session.add(new_message) - session.commit() + await session.commit() except Exception: logger.exception("存储消息失败") @@ -174,16 +174,18 @@ class MessageStorage: # 使用上下文管理器确保session正确管理 from src.common.database.sqlalchemy_models import get_db_session - with get_db_session() as session: - matched_message = session.execute( - select(Messages).where(Messages.message_id == mmc_message_id).order_by(desc(Messages.time)) + async with get_db_session() as session: + matched_message = ( + await session.execute( + select(Messages).where(Messages.message_id == mmc_message_id).order_by(desc(Messages.time)) + ) ).scalar() if matched_message: - session.execute( + await session.execute( update(Messages).where(Messages.id == matched_message.id).values(message_id=qq_message_id) ) - session.commit() + await session.commit() # 会在上下文管理器中自动调用 logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}") else: @@ -197,28 +199,36 @@ class MessageStorage: ) @staticmethod - def replace_image_descriptions(text: str) -> str: + async def replace_image_descriptions(text: str) -> str: """将[图片:描述]替换为[picid:image_id]""" # 先检查文本中是否有图片标记 pattern = r"\[图片:([^\]]+)\]" - matches = re.findall(pattern, text) + matches = list(re.finditer(pattern, text)) if not matches: logger.debug("文本中没有图片标记,直接返回原文本") return text - def replace_match(match): + new_text = "" + last_end = 0 + for match in matches: + new_text += text[last_end : match.start()] description = match.group(1).strip() try: from src.common.database.sqlalchemy_models import get_db_session - with get_db_session() as session: - image_record = session.execute( - select(Images).where(Images.description == description).order_by(desc(Images.timestamp)) + async with get_db_session() as session: + image_record = ( + await session.execute( + select(Images).where(Images.description == description).order_by(desc(Images.timestamp)) + ) ).scalar() - session.commit() - return f"[picid:{image_record.image_id}]" if image_record else match.group(0) + if image_record: + new_text += f"[picid:{image_record.image_id}]" + else: + new_text += match.group(0) except Exception: - return match.group(0) - - return re.sub(r"\[图片:([^\]]+)\]", replace_match, text) + new_text += match.group(0) + last_end = match.end() + new_text += text[last_end:] + return new_text diff --git a/src/chat/planner_actions/action_modifier.py b/src/chat/planner_actions/action_modifier.py index e9cc1d106..a061c15ae 100644 --- a/src/chat/planner_actions/action_modifier.py +++ b/src/chat/planner_actions/action_modifier.py @@ -97,12 +97,12 @@ class ActionModifier: for action_name, reason in chat_type_removals: logger.debug(f"{self.log_prefix} - 移除 {action_name}: {reason}") - message_list_before_now_half = get_raw_msg_before_timestamp_with_chat( + message_list_before_now_half = await get_raw_msg_before_timestamp_with_chat( chat_id=self.chat_stream.stream_id, timestamp=time.time(), limit=min(int(global_config.chat.max_context_size * 0.33), 10), ) - chat_content = build_readable_messages( + chat_content = await build_readable_messages( message_list_before_now_half, replace_bot_name=True, merge_messages=False, diff --git a/src/chat/planner_actions/plan_filter.py b/src/chat/planner_actions/plan_filter.py index 53e1e4a80..2d05f0511 100644 --- a/src/chat/planner_actions/plan_filter.py +++ b/src/chat/planner_actions/plan_filter.py @@ -152,7 +152,7 @@ class PlanFilter: ) return prompt, message_id_list - chat_content_block, message_id_list = build_readable_messages_with_id( + chat_content_block, message_id_list = await build_readable_messages_with_id( messages=[msg.flatten() for msg in plan.chat_history], timestamp_mode="normal", read_mark=self.last_obs_time_mark, @@ -167,7 +167,7 @@ class PlanFilter: limit=5, ) - actions_before_now_block = build_readable_actions(actions=actions_before_now) + actions_before_now_block = build_readable_actions(actions=await actions_before_now) actions_before_now_block = f"你刚刚选择并执行过的action是:\n{actions_before_now_block}" self.last_obs_time_mark = time.time() diff --git a/src/chat/planner_actions/plan_generator.py b/src/chat/planner_actions/plan_generator.py index 5dd1b680c..96af31c4b 100644 --- a/src/chat/planner_actions/plan_generator.py +++ b/src/chat/planner_actions/plan_generator.py @@ -63,7 +63,7 @@ class PlanGenerator: timestamp=time.time(), limit=int(global_config.chat.max_context_size), ) - chat_history = [DatabaseMessages(**msg) for msg in chat_history_raw] + chat_history = [DatabaseMessages(**msg) for msg in await chat_history_raw] plan = Plan( diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index 818bceb00..b372dad4b 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -873,7 +873,8 @@ class DefaultReplyer: platform, # type: ignore reply_message.get("user_id"), # type: ignore ) - person_name = await person_info_manager.get_value(person_id, "person_name") + person_info = await person_info_manager.get_values(person_id, ["person_name", "user_id"]) + person_name = person_info.get("person_name") # 如果person_name为None,使用fallback值 if person_name is None: @@ -884,7 +885,7 @@ class DefaultReplyer: # 检查是否是bot自己的名字,如果是则替换为"(你)" bot_user_id = str(global_config.bot.qq_account) - current_user_id = person_info_manager.get_value_sync(person_id, "user_id") + current_user_id = person_info.get("user_id") current_platform = reply_message.get("chat_info_platform") if current_user_id == bot_user_id and current_platform == global_config.bot.platform: @@ -909,18 +910,18 @@ class DefaultReplyer: target = replace_user_references_sync(target, chat_stream.platform, replace_bot_name=True) - message_list_before_now_long = get_raw_msg_before_timestamp_with_chat( + message_list_before_now_long = await get_raw_msg_before_timestamp_with_chat( chat_id=chat_id, timestamp=time.time(), limit=global_config.chat.max_context_size * 1, ) - message_list_before_short = get_raw_msg_before_timestamp_with_chat( + message_list_before_short = await get_raw_msg_before_timestamp_with_chat( chat_id=chat_id, timestamp=time.time(), limit=int(global_config.chat.max_context_size * 0.33), ) - chat_talking_prompt_short = build_readable_messages( + chat_talking_prompt_short = await build_readable_messages( message_list_before_short, replace_bot_name=True, merge_messages=False, @@ -932,7 +933,7 @@ class DefaultReplyer: # 获取目标用户信息,用于s4u模式 target_user_info = None if sender: - target_user_info = await person_info_manager.get_person_info_by_name(sender) + target_user_info = person_info_manager.get_person_info_by_name(sender) from src.chat.utils.prompt import Prompt # 并行执行六个构建任务 @@ -1150,12 +1151,12 @@ class DefaultReplyer: else: mood_prompt = "" - message_list_before_now_half = get_raw_msg_before_timestamp_with_chat( + message_list_before_now_half = await get_raw_msg_before_timestamp_with_chat( chat_id=chat_id, timestamp=time.time(), limit=min(int(global_config.chat.max_context_size * 0.33), 15), ) - chat_talking_prompt_half = build_readable_messages( + chat_talking_prompt_half = await build_readable_messages( message_list_before_now_half, replace_bot_name=True, merge_messages=False, diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index 5bdee8cd2..3924f1ee3 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -116,8 +116,9 @@ async def replace_user_references_async( # 检查是否是机器人自己 if replace_bot_name and user_id == global_config.bot.qq_account: return f"{global_config.bot.nickname}(你)" - person = Person(platform=platform, user_id=user_id) - return person.person_name or user_id # type: ignore + person_id = PersonInfoManager.get_person_id(platform, user_id) + person_info = await person_info_manager.get_values(person_id, ["person_name"]) + return person_info.get("person_name") or user_id name_resolver = default_resolver @@ -165,7 +166,7 @@ async def replace_user_references_async( return content -def get_raw_msg_by_timestamp( +async def get_raw_msg_by_timestamp( timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest" ) -> List[Dict[str, Any]]: """ @@ -176,10 +177,10 @@ def get_raw_msg_by_timestamp( filter_query = {"time": {"$gt": timestamp_start, "$lt": timestamp_end}} # 只有当 limit 为 0 时才应用外部 sort sort_order = [("time", 1)] if limit == 0 else None - return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode) + return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode) -def get_raw_msg_by_timestamp_with_chat( +async def get_raw_msg_by_timestamp_with_chat( chat_id: str, timestamp_start: float, timestamp_end: float, @@ -196,7 +197,7 @@ def get_raw_msg_by_timestamp_with_chat( # 只有当 limit 为 0 时才应用外部 sort sort_order = [("time", 1)] if limit == 0 else None # 直接将 limit_mode 传递给 find_messages - return find_messages( + return await find_messages( message_filter=filter_query, sort=sort_order, limit=limit, @@ -206,7 +207,7 @@ def get_raw_msg_by_timestamp_with_chat( ) -def get_raw_msg_by_timestamp_with_chat_inclusive( +async def get_raw_msg_by_timestamp_with_chat_inclusive( chat_id: str, timestamp_start: float, timestamp_end: float, @@ -223,12 +224,12 @@ def get_raw_msg_by_timestamp_with_chat_inclusive( sort_order = [("time", 1)] if limit == 0 else None # 直接将 limit_mode 传递给 find_messages - return find_messages( + return await find_messages( message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode, filter_bot=filter_bot ) -def get_raw_msg_by_timestamp_with_chat_users( +async def get_raw_msg_by_timestamp_with_chat_users( chat_id: str, timestamp_start: float, timestamp_end: float, @@ -247,10 +248,10 @@ def get_raw_msg_by_timestamp_with_chat_users( } # 只有当 limit 为 0 时才应用外部 sort sort_order = [("time", 1)] if limit == 0 else None - return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode) + return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode) -def get_actions_by_timestamp_with_chat( +async def get_actions_by_timestamp_with_chat( chat_id: str, timestamp_start: float = 0, timestamp_end: float = time.time(), @@ -269,10 +270,10 @@ def get_actions_by_timestamp_with_chat( f"limit={limit}, limit_mode={limit_mode}" ) - with get_db_session() as session: + async with get_db_session() as session: if limit > 0: if limit_mode == "latest": - query = session.execute( + query = await session.execute( select(ActionRecords) .where( and_( @@ -302,7 +303,7 @@ def get_actions_by_timestamp_with_chat( } actions_result.append(action_dict) else: # earliest - query = session.execute( + query = await session.execute( select(ActionRecords) .where( and_( @@ -332,7 +333,7 @@ def get_actions_by_timestamp_with_chat( } actions_result.append(action_dict) else: - query = session.execute( + query = await session.execute( select(ActionRecords) .where( and_( @@ -363,14 +364,14 @@ def get_actions_by_timestamp_with_chat( return actions_result -def get_actions_by_timestamp_with_chat_inclusive( +async def get_actions_by_timestamp_with_chat_inclusive( chat_id: str, timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest" ) -> List[Dict[str, Any]]: """获取在特定聊天从指定时间戳到指定时间戳的动作记录(包含边界),按时间升序排序,返回动作记录列表""" - with get_db_session() as session: + async with get_db_session() as session: if limit > 0: if limit_mode == "latest": - query = session.execute( + query = await session.execute( select(ActionRecords) .where( and_( @@ -385,7 +386,7 @@ def get_actions_by_timestamp_with_chat_inclusive( actions = list(query.scalars()) return [action.__dict__ for action in reversed(actions)] else: # earliest - query = session.execute( + query = await session.execute( select(ActionRecords) .where( and_( @@ -398,7 +399,7 @@ def get_actions_by_timestamp_with_chat_inclusive( .limit(limit) ) else: - query = session.execute( + query = await session.execute( select(ActionRecords) .where( and_( @@ -414,14 +415,14 @@ def get_actions_by_timestamp_with_chat_inclusive( return [action.__dict__ for action in actions] -def get_raw_msg_by_timestamp_random( +async def get_raw_msg_by_timestamp_random( timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest" ) -> List[Dict[str, Any]]: """ 先在范围时间戳内随机选择一条消息,取得消息的chat_id,然后根据chat_id获取该聊天在指定时间戳范围内的消息 """ # 获取所有消息,只取chat_id字段 - all_msgs = get_raw_msg_by_timestamp(timestamp_start, timestamp_end) + all_msgs = await get_raw_msg_by_timestamp(timestamp_start, timestamp_end) if not all_msgs: return [] # 随机选一条 @@ -429,10 +430,10 @@ def get_raw_msg_by_timestamp_random( chat_id = msg["chat_id"] timestamp_start = msg["time"] # 用 chat_id 获取该聊天在指定时间戳范围内的消息 - return get_raw_msg_by_timestamp_with_chat(chat_id, timestamp_start, timestamp_end, limit, "earliest") + return await get_raw_msg_by_timestamp_with_chat(chat_id, timestamp_start, timestamp_end, limit, "earliest") -def get_raw_msg_by_timestamp_with_users( +async def get_raw_msg_by_timestamp_with_users( timestamp_start: float, timestamp_end: float, person_ids: list, limit: int = 0, limit_mode: str = "latest" ) -> List[Dict[str, Any]]: """获取某些特定用户在 *所有聊天* 中从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表 @@ -442,37 +443,39 @@ def get_raw_msg_by_timestamp_with_users( filter_query = {"time": {"$gt": timestamp_start, "$lt": timestamp_end}, "user_id": {"$in": person_ids}} # 只有当 limit 为 0 时才应用外部 sort sort_order = [("time", 1)] if limit == 0 else None - return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode) + return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode) -def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List[Dict[str, Any]]: +async def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List[Dict[str, Any]]: """获取指定时间戳之前的消息,按时间升序排序,返回消息列表 limit: 限制返回的消息数量,0为不限制 """ filter_query = {"time": {"$lt": timestamp}} sort_order = [("time", 1)] - return find_messages(message_filter=filter_query, sort=sort_order, limit=limit) + return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit) -def get_raw_msg_before_timestamp_with_chat(chat_id: str, timestamp: float, limit: int = 0) -> List[Dict[str, Any]]: +async def get_raw_msg_before_timestamp_with_chat(chat_id: str, timestamp: float, limit: int = 0) -> List[Dict[str, Any]]: """获取指定时间戳之前的消息,按时间升序排序,返回消息列表 limit: 限制返回的消息数量,0为不限制 """ filter_query = {"chat_id": chat_id, "time": {"$lt": timestamp}} sort_order = [("time", 1)] - return find_messages(message_filter=filter_query, sort=sort_order, limit=limit) + return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit) -def get_raw_msg_before_timestamp_with_users(timestamp: float, person_ids: list, limit: int = 0) -> List[Dict[str, Any]]: +async def get_raw_msg_before_timestamp_with_users( + timestamp: float, person_ids: list, limit: int = 0 +) -> List[Dict[str, Any]]: """获取指定时间戳之前的消息,按时间升序排序,返回消息列表 limit: 限制返回的消息数量,0为不限制 """ filter_query = {"time": {"$lt": timestamp}, "user_id": {"$in": person_ids}} sort_order = [("time", 1)] - return find_messages(message_filter=filter_query, sort=sort_order, limit=limit) + return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit) -def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: Optional[float] = None) -> int: +async def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: Optional[float] = None) -> int: """ 检查特定聊天从 timestamp_start (不含) 到 timestamp_end (不含) 之间有多少新消息。 如果 timestamp_end 为 None,则检查从 timestamp_start (不含) 到当前时间的消息。 @@ -486,10 +489,10 @@ def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp return 0 # 起始时间大于等于结束时间,没有新消息 filter_query = {"chat_id": chat_id, "time": {"$gt": timestamp_start, "$lt": _timestamp_end}} - return count_messages(message_filter=filter_query) + return await count_messages(message_filter=filter_query) -def num_new_messages_since_with_users( +async def num_new_messages_since_with_users( chat_id: str, timestamp_start: float, timestamp_end: float, person_ids: list ) -> int: """检查某些特定用户在特定聊天在指定时间戳之间有多少新消息""" @@ -500,10 +503,10 @@ def num_new_messages_since_with_users( "time": {"$gt": timestamp_start, "$lt": timestamp_end}, "user_id": {"$in": person_ids}, } - return count_messages(message_filter=filter_query) + return await count_messages(message_filter=filter_query) -def _build_readable_messages_internal( +async def _build_readable_messages_internal( messages: List[Dict[str, Any]], replace_bot_name: bool = True, merge_messages: bool = False, @@ -622,7 +625,8 @@ def _build_readable_messages_internal( if replace_bot_name and user_id == global_config.bot.qq_account: person_name = f"{global_config.bot.nickname}(你)" else: - person_name = person.person_name or user_id # type: ignore + person_info = await person_info_manager.get_values(person_id, ["person_name"]) + person_name = person_info.get("person_name") # type: ignore # 如果 person_name 未设置,则使用消息中的 nickname 或默认名称 if not person_name: @@ -791,7 +795,7 @@ def _build_readable_messages_internal( ) -def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str: +async def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str: # sourcery skip: use-contextlib-suppress """ 构建图片映射信息字符串,显示图片的具体描述内容 @@ -814,9 +818,9 @@ def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str: # 从数据库中获取图片描述 description = "[图片内容未知]" # 默认描述 try: - with get_db_session() as session: - image = session.execute(select(Images).where(Images.image_id == pic_id)).scalar_one_or_none() - if image and image.description: # type: ignore + async with get_db_session() as session: + image = (await session.execute(select(Images).where(Images.image_id == pic_id))).scalar_one_or_none() + if image and image.description: # type: ignore description = image.description except Exception: # 如果查询失败,保持默认描述 @@ -912,17 +916,17 @@ async def build_readable_messages_with_list( 将消息列表转换为可读的文本格式,并返回原始(时间戳, 昵称, 内容)列表。 允许通过参数控制格式化行为。 """ - formatted_string, details_list, pic_id_mapping, _ = _build_readable_messages_internal( + formatted_string, details_list, pic_id_mapping, _ = await _build_readable_messages_internal( messages, replace_bot_name, merge_messages, timestamp_mode, truncate ) - if pic_mapping_info := build_pic_mapping_info(pic_id_mapping): + if pic_mapping_info := await build_pic_mapping_info(pic_id_mapping): formatted_string = f"{pic_mapping_info}\n\n{formatted_string}" return formatted_string, details_list -def build_readable_messages_with_id( +async def build_readable_messages_with_id( messages: List[Dict[str, Any]], replace_bot_name: bool = True, merge_messages: bool = False, @@ -938,7 +942,7 @@ def build_readable_messages_with_id( """ message_id_list = assign_message_ids(messages) - formatted_string = build_readable_messages( + formatted_string = await build_readable_messages( messages=messages, replace_bot_name=replace_bot_name, merge_messages=merge_messages, @@ -953,7 +957,7 @@ def build_readable_messages_with_id( return formatted_string, message_id_list -def build_readable_messages( +async def build_readable_messages( messages: List[Dict[str, Any]], replace_bot_name: bool = True, merge_messages: bool = False, @@ -994,24 +998,28 @@ def build_readable_messages( from src.common.database.sqlalchemy_database_api import get_db_session - with get_db_session() as session: + async with get_db_session() as session: # 获取这个时间范围内的动作记录,并匹配chat_id - actions_in_range = session.execute( - select(ActionRecords) - .where( - and_( - ActionRecords.time >= min_time, ActionRecords.time <= max_time, ActionRecords.chat_id == chat_id + actions_in_range = ( + await session.execute( + select(ActionRecords) + .where( + and_( + ActionRecords.time >= min_time, ActionRecords.time <= max_time, ActionRecords.chat_id == chat_id + ) ) + .order_by(ActionRecords.time) ) - .order_by(ActionRecords.time) ).scalars() # 获取最新消息之后的第一个动作记录 - action_after_latest = session.execute( - select(ActionRecords) - .where(and_(ActionRecords.time > max_time, ActionRecords.chat_id == chat_id)) - .order_by(ActionRecords.time) - .limit(1) + action_after_latest = ( + await session.execute( + select(ActionRecords) + .where(and_(ActionRecords.time > max_time, ActionRecords.chat_id == chat_id)) + .order_by(ActionRecords.time) + .limit(1) + ) ).scalars() # 合并两部分动作记录,并转为 dict,避免 DetachedInstanceError @@ -1043,7 +1051,7 @@ def build_readable_messages( if read_mark <= 0: # 没有有效的 read_mark,直接格式化所有消息 - formatted_string, _, pic_id_mapping, _ = _build_readable_messages_internal( + formatted_string, _, pic_id_mapping, _ = await _build_readable_messages_internal( copy_messages, replace_bot_name, merge_messages, @@ -1054,7 +1062,7 @@ def build_readable_messages( ) # 生成图片映射信息并添加到最前面 - pic_mapping_info = build_pic_mapping_info(pic_id_mapping) + pic_mapping_info = await build_pic_mapping_info(pic_id_mapping) if pic_mapping_info: return f"{pic_mapping_info}\n\n{formatted_string}" else: @@ -1069,7 +1077,7 @@ def build_readable_messages( pic_counter = 1 # 分别格式化,但使用共享的图片映射 - formatted_before, _, pic_id_mapping, pic_counter = _build_readable_messages_internal( + formatted_before, _, pic_id_mapping, pic_counter = await _build_readable_messages_internal( messages_before_mark, replace_bot_name, merge_messages, @@ -1080,7 +1088,7 @@ def build_readable_messages( show_pic=show_pic, message_id_list=message_id_list, ) - formatted_after, _, pic_id_mapping, _ = _build_readable_messages_internal( + formatted_after, _, pic_id_mapping, _ = await _build_readable_messages_internal( messages_after_mark, replace_bot_name, merge_messages, @@ -1096,7 +1104,7 @@ def build_readable_messages( # 生成图片映射信息 if pic_id_mapping: - pic_mapping_info = f"图片信息:\n{build_pic_mapping_info(pic_id_mapping)}\n聊天记录信息:\n" + pic_mapping_info = f"图片信息:\n{await build_pic_mapping_info(pic_id_mapping)}\n聊天记录信息:\n" else: pic_mapping_info = "聊天记录信息:\n" diff --git a/src/common/message_repository.py b/src/common/message_repository.py index 78e856f39..63d4c000d 100644 --- a/src/common/message_repository.py +++ b/src/common/message_repository.py @@ -25,7 +25,7 @@ def _model_to_dict(instance: Base) -> Dict[str, Any]: return {col.name: getattr(instance, col.name) for col in instance.__table__.columns} -def find_messages( +async def find_messages( message_filter: dict[str, Any], sort: Optional[List[tuple[str, int]]] = None, limit: int = 0, @@ -46,7 +46,7 @@ def find_messages( 消息字典列表,如果出错则返回空列表。 """ try: - with get_db_session() as session: + async with get_db_session() as session: query = select(Messages) # 应用过滤器 @@ -96,7 +96,7 @@ def find_messages( # 获取时间最早的 limit 条记录,已经是正序 query = query.order_by(Messages.time.asc()).limit(limit) try: - results = session.execute(query).scalars().all() + results = (await session.execute(query)).scalars().all() except Exception as e: logger.error(f"执行earliest查询失败: {e}") results = [] @@ -104,7 +104,7 @@ def find_messages( # 获取时间最晚的 limit 条记录 query = query.order_by(Messages.time.desc()).limit(limit) try: - latest_results = session.execute(query).scalars().all() + latest_results = (await session.execute(query)).scalars().all() # 将结果按时间正序排列 results = sorted(latest_results, key=lambda msg: msg.time) except Exception as e: @@ -128,12 +128,12 @@ def find_messages( if sort_terms: query = query.order_by(*sort_terms) try: - results = session.execute(query).scalars().all() + results = (await session.execute(query)).scalars().all() except Exception as e: logger.error(f"执行无限制查询失败: {e}") results = [] - return [_model_to_dict(msg) for msg in results] + return [_model_to_dict(msg) for msg in results] except Exception as e: log_message = ( f"使用 SQLAlchemy 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n" @@ -143,7 +143,7 @@ def find_messages( return [] -def count_messages(message_filter: dict[str, Any]) -> int: +async def count_messages(message_filter: dict[str, Any]) -> int: """ 根据提供的过滤器计算消息数量。 @@ -154,7 +154,7 @@ def count_messages(message_filter: dict[str, Any]) -> int: 符合条件的消息数量,如果出错则返回 0。 """ try: - with get_db_session() as session: + async with get_db_session() as session: query = select(func.count(Messages.id)) # 应用过滤器 @@ -192,7 +192,7 @@ def count_messages(message_filter: dict[str, Any]) -> int: if conditions: query = query.where(*conditions) - count = session.execute(query).scalar() + count = (await session.execute(query)).scalar() return count or 0 except Exception as e: log_message = f"使用 SQLAlchemy 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}" diff --git a/src/main.py b/src/main.py index 36aa4a7db..ee7a67875 100644 --- a/src/main.py +++ b/src/main.py @@ -40,6 +40,9 @@ if not global_config.memory.enable_memory: def initialize(self): pass + async def initialize_async(self): + pass + def get_hippocampus(self): return None @@ -248,7 +251,7 @@ MoFox_Bot(第三方修改版) logger.info("聊天管理器初始化成功") # 初始化记忆系统 - self.hippocampus_manager.initialize() + await self.hippocampus_manager.initialize_async() logger.info("记忆系统初始化成功") # 初始化LPMM知识库 @@ -283,7 +286,7 @@ MoFox_Bot(第三方修改版) if global_config.planning_system.monthly_plan_enable: logger.info("正在初始化月度计划管理器...") try: - await monthly_plan_manager.start_monthly_plan_generation() + await monthly_plan_manager.initialize() logger.info("月度计划管理器初始化成功") except Exception as e: logger.error(f"月度计划管理器初始化失败: {e}") @@ -291,8 +294,7 @@ MoFox_Bot(第三方修改版) # 初始化日程管理器 if global_config.planning_system.schedule_enable: logger.info("日程表功能已启用,正在初始化管理器...") - await schedule_manager.load_or_generate_today_schedule() - await schedule_manager.start_daily_schedule_generation() + await schedule_manager.initialize() logger.info("日程表管理器初始化成功。") try: diff --git a/src/person_info/person_info.py b/src/person_info/person_info.py index 68917de0f..e1d0d23b1 100644 --- a/src/person_info/person_info.py +++ b/src/person_info/person_info.py @@ -401,14 +401,15 @@ class PersonInfoManager: # # 初始化时读取所有person_name try: + pass # 在这里获取会话 - with get_db_session() as session: - for record in session.execute( - select(PersonInfo.person_id, PersonInfo.person_name).where(PersonInfo.person_name.is_not(None)) - ).fetchall(): - if record.person_name: - self.person_name_list[record.person_id] = record.person_name - logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称 (SQLAlchemy)") + # with get_db_session() as session: + # for record in session.execute( + # select(PersonInfo.person_id, PersonInfo.person_name).where(PersonInfo.person_name.is_not(None)) + # ).fetchall(): + # if record.person_name: + # self.person_name_list[record.person_id] = record.person_name + # logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称 (SQLAlchemy)") except Exception as e: logger.error(f"从 SQLAlchemy 加载 person_name_list 失败: {e}") @@ -430,23 +431,25 @@ class PersonInfoManager: """判断是否认识某人""" person_id = self.get_person_id(platform, user_id) - def _db_check_known_sync(p_id: str): + async def _db_check_known_async(p_id: str): # 在需要时获取会话 - with get_db_session() as session: - return session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar() is not None + async with get_db_session() as session: + return ( + await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)) + ).scalar() is not None try: - return await asyncio.to_thread(_db_check_known_sync, person_id) + return await _db_check_known_async(person_id) except Exception as e: logger.error(f"检查用户 {person_id} 是否已知时出错 (SQLAlchemy): {e}") return False - def get_person_id_by_person_name(self, person_name: str) -> str: + async def get_person_id_by_person_name(self, person_name: str) -> str: """根据用户名获取用户ID""" try: # 在需要时获取会话 - with get_db_session() as session: - record = session.execute(select(PersonInfo).where(PersonInfo.person_name == person_name)).scalar() + async with get_db_session() as session: + record = (await session.execute(select(PersonInfo).where(PersonInfo.person_name == person_name))).scalar() return record.person_id if record else "" except Exception as e: logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (SQLAlchemy): {e}") @@ -500,19 +503,18 @@ class PersonInfoManager: final_data[key] = orjson.dumps([]).decode("utf-8") # If it's already a string, assume it's valid JSON or a non-JSON string field - def _db_create_sync(p_data: dict): - with get_db_session() as session: + async def _db_create_async(p_data: dict): + async with get_db_session() as session: try: new_person = PersonInfo(**p_data) session.add(new_person) - session.commit() - + await session.commit() return True except Exception as e: logger.error(f"创建 PersonInfo 记录 {p_data.get('person_id')} 失败 (SQLAlchemy): {e}") return False - await asyncio.to_thread(_db_create_sync, final_data) + await _db_create_async(final_data) async def _safe_create_person_info(self, person_id: str, data: Optional[dict] = None): """安全地创建用户信息,处理竞态条件""" @@ -557,11 +559,11 @@ class PersonInfoManager: elif final_data[key] is None: # Default for lists is [], store as "[]" final_data[key] = orjson.dumps([]).decode("utf-8") - def _db_safe_create_sync(p_data: dict): - with get_db_session() as session: + async def _db_safe_create_async(p_data: dict): + async with get_db_session() as session: try: - existing = session.execute( - select(PersonInfo).where(PersonInfo.person_id == p_data["person_id"]) + existing = ( + await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_data["person_id"])) ).scalar() if existing: logger.debug(f"用户 {p_data['person_id']} 已存在,跳过创建") @@ -570,18 +572,17 @@ class PersonInfoManager: # 尝试创建 new_person = PersonInfo(**p_data) session.add(new_person) - session.commit() - + await session.commit() return True except Exception as e: if "UNIQUE constraint failed" in str(e): logger.debug(f"检测到并发创建用户 {p_data.get('person_id')},跳过错误") - return True # 其他协程已创建,视为成功 + return True else: logger.error(f"创建 PersonInfo 记录 {p_data.get('person_id')} 失败 (SQLAlchemy): {e}") return False - await asyncio.to_thread(_db_safe_create_sync, final_data) + await _db_safe_create_async(final_data) async def update_one_field(self, person_id: str, field_name: str, value, data: Optional[Dict] = None): """更新某一个字段,会补全""" @@ -598,37 +599,33 @@ class PersonInfoManager: elif value is None: # Store None as "[]" for JSON list fields processed_value = orjson.dumps([]).decode("utf-8") - def _db_update_sync(p_id: str, f_name: str, val_to_set): + async def _db_update_async(p_id: str, f_name: str, val_to_set): start_time = time.time() - with get_db_session() as session: + async with get_db_session() as session: try: - record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar() + record = (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar() query_time = time.time() - if record: setattr(record, f_name, val_to_set) - save_time = time.time() - total_time = save_time - start_time - if total_time > 0.5: # 如果超过500ms就记录日志 + if total_time > 0.5: logger.warning( f"数据库更新操作耗时 {total_time:.3f}秒 (查询: {query_time - start_time:.3f}s, 保存: {save_time - query_time:.3f}s) person_id={p_id}, field={f_name}" ) - session.commit() - - return True, False # Found and updated, no creation needed + await session.commit() + return True, False else: total_time = time.time() - start_time if total_time > 0.5: logger.warning(f"数据库查询操作耗时 {total_time:.3f}秒 person_id={p_id}, field={f_name}") - return False, True # Not found, needs creation + return False, True except Exception as e: total_time = time.time() - start_time logger.error(f"数据库操作异常,耗时 {total_time:.3f}秒: {e}") raise - found, needs_creation = await asyncio.to_thread(_db_update_sync, person_id, field_name, processed_value) + found, needs_creation = await _db_update_async(person_id, field_name, processed_value) if needs_creation: logger.info(f"{person_id} 不存在,将新建。") @@ -666,13 +663,13 @@ class PersonInfoManager: logger.debug(f"检查字段'{field_name}'失败,未在 PersonInfo SQLAlchemy 模型中定义。") return False - def _db_has_field_sync(p_id: str, f_name: str): - with get_db_session() as session: - record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar() + async def _db_has_field_async(p_id: str, f_name: str): + async with get_db_session() as session: + record = (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar() return bool(record) try: - return await asyncio.to_thread(_db_has_field_sync, person_id, field_name) + return await _db_has_field_async(person_id, field_name) except Exception as e: logger.error(f"检查字段 {field_name} for {person_id} 时出错 (SQLAlchemy): {e}") return False @@ -778,14 +775,14 @@ class PersonInfoManager: logger.info(f"尝试给用户{user_nickname} {person_id} 取名,但是 {generated_nickname} 已存在,重试中...") else: - def _db_check_name_exists_sync(name_to_check): - with get_db_session() as session: + async def _db_check_name_exists_async(name_to_check): + async with get_db_session() as session: return ( - session.execute(select(PersonInfo).where(PersonInfo.person_name == name_to_check)).scalar() + (await session.execute(select(PersonInfo).where(PersonInfo.person_name == name_to_check))).scalar() is not None ) - if await asyncio.to_thread(_db_check_name_exists_sync, generated_nickname): + if await _db_check_name_exists_async(generated_nickname): is_duplicate = True current_name_set.add(generated_nickname) @@ -824,91 +821,26 @@ class PersonInfoManager: logger.debug("删除失败:person_id 不能为空") return - def _db_delete_sync(p_id: str): + async def _db_delete_async(p_id: str): try: - with get_db_session() as session: - record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar() + async with get_db_session() as session: + record = (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar() if record: - session.delete(record) - session.commit() - return 1 + await session.delete(record) + await session.commit() + return 1 return 0 except Exception as e: logger.error(f"删除 PersonInfo {p_id} 失败 (SQLAlchemy): {e}") return 0 - deleted_count = await asyncio.to_thread(_db_delete_sync, person_id) + deleted_count = await _db_delete_async(person_id) if deleted_count > 0: - logger.debug(f"删除成功:person_id={person_id} (Peewee)") + logger.debug(f"删除成功:person_id={person_id}") else: - logger.debug(f"删除失败:未找到 person_id={person_id} 或删除未影响行 (Peewee)") + logger.debug(f"删除失败:未找到 person_id={person_id} 或删除未影响行") - @staticmethod - async def get_value(person_id: str, field_name: str): - """获取指定用户指定字段的值""" - default_value_for_field = person_info_default.get(field_name) - if field_name in JSON_SERIALIZED_FIELDS and default_value_for_field is None: - default_value_for_field = [] # Ensure JSON fields default to [] if not in DB - - def _db_get_value_sync(p_id: str, f_name: str): - with get_db_session() as session: - record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar() - if record: - val = getattr(record, f_name, None) - if f_name in JSON_SERIALIZED_FIELDS: - if isinstance(val, str): - try: - return orjson.loads(val) - except orjson.JSONDecodeError: - logger.warning(f"字段 {f_name} for {p_id} 包含无效JSON: {val}. 返回默认值.") - return [] # Default for JSON fields on error - elif val is None: # Field exists in DB but is None - return [] # Default for JSON fields - # If val is already a list/dict (e.g. if somehow set without serialization) - return val # Should ideally not happen if update_one_field is always used - return val - return None # Record not found - - try: - value_from_db = await asyncio.to_thread(_db_get_value_sync, person_id, field_name) - if value_from_db is not None: - return value_from_db - if field_name in person_info_default: - return default_value_for_field - logger.warning(f"字段 {field_name} 在 person_info_default 中未定义,且在数据库中未找到。") - return None # Ultimate fallback - except Exception as e: - logger.error(f"获取字段 {field_name} for {person_id} 时出错 (Peewee): {e}") - # Fallback to default in case of any error during DB access - return default_value_for_field if field_name in person_info_default else None - - @staticmethod - def get_value_sync(person_id: str, field_name: str): - """同步获取指定用户指定字段的值""" - default_value_for_field = person_info_default.get(field_name) - with get_db_session() as session: - if field_name in JSON_SERIALIZED_FIELDS and default_value_for_field is None: - default_value_for_field = [] - - if record := session.execute(select(PersonInfo).where(PersonInfo.person_id == person_id)).scalar(): - val = getattr(record, field_name, None) - if field_name in JSON_SERIALIZED_FIELDS: - if isinstance(val, str): - try: - return orjson.loads(val) - except orjson.JSONDecodeError: - logger.warning(f"字段 {field_name} for {person_id} 包含无效JSON: {val}. 返回默认值.") - return [] - elif val is None: - return [] - return val - return val - - if field_name in person_info_default: - return default_value_for_field - logger.warning(f"字段 {field_name} 在 person_info_default 中未定义,且在数据库中未找到。") - return None @staticmethod async def get_values(person_id: str, field_names: list) -> dict: @@ -919,11 +851,11 @@ class PersonInfoManager: result = {} - def _db_get_record_sync(p_id: str): - with get_db_session() as session: - return session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar() + async def _db_get_record_async(p_id: str): + async with get_db_session() as session: + return (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar() - record = await asyncio.to_thread(_db_get_record_sync, person_id) + record = await _db_get_record_async(person_id) # 获取 SQLAlchemy 模型的所有字段名 model_fields = [column.name for column in PersonInfo.__table__.columns] @@ -960,14 +892,15 @@ class PersonInfoManager: # 获取 SQLAlchemy 模型的所有字段名 model_fields = [column.name for column in PersonInfo.__table__.columns] if field_name not in model_fields: - logger.error(f"字段检查失败:'{field_name}'未在 PersonInfo SQLAlchemy 模 modelo中定义") + logger.error(f"字段检查失败:'{field_name}'未在 PersonInfo SQLAlchemy 模型中定义") return {} - def _db_get_specific_sync(f_name: str): + async def _db_get_specific_async(f_name: str): found_results = {} try: - with get_db_session() as session: - for record in session.execute(select(PersonInfo.person_id, getattr(PersonInfo, f_name))).fetchall(): + async with get_db_session() as session: + result = await session.execute(select(PersonInfo.person_id, getattr(PersonInfo, f_name))) + for record in result.fetchall(): value = getattr(record, f_name) if way(value): found_results[record.person_id] = value @@ -978,9 +911,9 @@ class PersonInfoManager: return found_results try: - return await asyncio.to_thread(_db_get_specific_sync, field_name) + return await _db_get_specific_async(field_name) except Exception as e: - logger.error(f"执行 get_specific_value_list 线程时出错: {str(e)}", exc_info=True) + logger.error(f"执行 get_specific_value_list 时出错: {str(e)}", exc_info=True) return {} async def get_or_create_person( @@ -993,40 +926,38 @@ class PersonInfoManager: """ person_id = self.get_person_id(platform, user_id) - def _db_get_or_create_sync(p_id: str, init_data: dict): + async def _db_get_or_create_async(p_id: str, init_data: dict): """原子性的获取或创建操作""" - with get_db_session() as session: + async with get_db_session() as session: # 首先尝试获取现有记录 - record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar() + record = (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar() if record: return record, False # 记录存在,未创建 - # 记录不存在,尝试创建 - try: - new_person = PersonInfo(**init_data) - session.add(new_person) - session.commit() - - return session.execute( - select(PersonInfo).where(PersonInfo.person_id == p_id) - ).scalar(), True # 创建成功 - except Exception as e: - # 如果创建失败(可能是因为竞态条件),再次尝试获取 - if "UNIQUE constraint failed" in str(e): - logger.debug(f"检测到并发创建用户 {p_id},获取现有记录") - record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar() - if record: - return record, False # 其他协程已创建,返回现有记录 - # 如果仍然失败,重新抛出异常 - raise e - + # 记录不存在,尝试创建 + try: + new_person = PersonInfo(**init_data) + session.add(new_person) + await session.commit() + await session.refresh(new_person) + return new_person, True # 创建成功 + except Exception as e: + # 如果创建失败(可能是因为竞态条件),再次尝试获取 + if "UNIQUE constraint failed" in str(e): + logger.debug(f"检测到并发创建用户 {p_id},获取现有记录") + record = (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar() + if record: + return record, False # 其他协程已创建,返回现有记录 + # 如果仍然失败,重新抛出异常 + raise e + unique_nickname = await self._generate_unique_person_name(nickname) initial_data = { "person_id": person_id, "platform": platform, "user_id": str(user_id), "nickname": nickname, - "person_name": unique_nickname, # 使用群昵称作为person_name + "person_name": unique_nickname, "name_reason": "从群昵称获取", "know_times": 0, "know_since": int(datetime.datetime.now().timestamp()), @@ -1036,7 +967,6 @@ class PersonInfoManager: "forgotten_points": [], } - # 序列化JSON字段 for key in JSON_SERIALIZED_FIELDS: if key in initial_data: if isinstance(initial_data[key], (list, dict)): @@ -1044,15 +974,14 @@ class PersonInfoManager: elif initial_data[key] is None: initial_data[key] = orjson.dumps([]).decode("utf-8") - # 获取 SQLAlchemy 模odel的所有字段名 model_fields = [column.name for column in PersonInfo.__table__.columns] filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields} - record, was_created = await asyncio.to_thread(_db_get_or_create_sync, person_id, filtered_initial_data) + record, was_created = await _db_get_or_create_async(person_id, filtered_initial_data) if was_created: - logger.info(f"用户 {platform}:{user_id} (person_id: {person_id}) 不存在,将创建新记录 (Peewee)。") - logger.info(f"已为 {person_id} 创建新记录,初始数据 (filtered for model): {filtered_initial_data}") + logger.info(f"用户 {platform}:{user_id} (person_id: {person_id}) 不存在,将创建新记录。") + logger.info(f"已为 {person_id} 创建新记录,初始数据: {filtered_initial_data}") else: logger.debug(f"用户 {platform}:{user_id} (person_id: {person_id}) 已存在,返回现有记录。") @@ -1072,11 +1001,13 @@ class PersonInfoManager: if not found_person_id: - def _db_find_by_name_sync(p_name_to_find: str): - with get_db_session() as session: - return session.execute(select(PersonInfo).where(PersonInfo.person_name == p_name_to_find)).scalar() + async def _db_find_by_name_async(p_name_to_find: str): + async with get_db_session() as session: + return ( + await session.execute(select(PersonInfo).where(PersonInfo.person_name == p_name_to_find)) + ).scalar() - record = await asyncio.to_thread(_db_find_by_name_sync, person_name) + record = await _db_find_by_name_async(person_name) if record: found_person_id = record.person_id if ( diff --git a/src/person_info/relationship_fetcher.py b/src/person_info/relationship_fetcher.py new file mode 100644 index 000000000..e903915a7 --- /dev/null +++ b/src/person_info/relationship_fetcher.py @@ -0,0 +1,457 @@ +import time +import traceback +import orjson +import random + +from typing import List, Dict, Any +from json_repair import repair_json + +from src.common.logger import get_logger +from src.config.config import global_config, model_config +from src.llm_models.utils_model import LLMRequest +from src.chat.utils.prompt import Prompt, global_prompt_manager +from src.chat.message_receive.chat_stream import get_chat_manager +from src.person_info.person_info import get_person_info_manager + + +logger = get_logger("relationship_fetcher") + + +def init_real_time_info_prompts(): + """初始化实时信息提取相关的提示词""" + relationship_prompt = """ +<聊天记录> +{chat_observe_info} + + +{name_block} +现在,你想要回复{person_name}的消息,消息内容是:{target_message}。请根据聊天记录和你要回复的消息,从你对{person_name}的了解中提取有关的信息: +1.你需要提供你想要提取的信息具体是哪方面的信息,例如:年龄,性别,你们之间的交流方式,最近发生的事等等。 +2.请注意,请不要重复调取相同的信息,已经调取的信息如下: +{info_cache_block} +3.如果当前聊天记录中没有需要查询的信息,或者现有信息已经足够回复,请返回{{"none": "不需要查询"}} + +请以json格式输出,例如: + +{{ + "info_type": "信息类型", +}} + +请严格按照json输出格式,不要输出多余内容: +""" + Prompt(relationship_prompt, "real_time_info_identify_prompt") + + fetch_info_prompt = """ + +{name_block} +以下是你在之前与{person_name}的交流中,产生的对{person_name}的了解: +{person_impression_block} +{points_text_block} + +请从中提取用户"{person_name}"的有关"{info_type}"信息 +请以json格式输出,例如: + +{{ + {info_json_str} +}} + +请严格按照json输出格式,不要输出多余内容: +""" + Prompt(fetch_info_prompt, "real_time_fetch_person_info_prompt") + + +class RelationshipFetcher: + def __init__(self, chat_id): + self.chat_id = chat_id + + # 信息获取缓存:记录正在获取的信息请求 + self.info_fetching_cache: List[Dict[str, Any]] = [] + + # 信息结果缓存:存储已获取的信息结果,带TTL + self.info_fetched_cache: Dict[str, Dict[str, Any]] = {} + # 结构:{person_id: {info_type: {"info": str, "ttl": int, "start_time": float, "person_name": str, "unknown": bool}}} + + # LLM模型配置 + self.llm_model = LLMRequest( + model_set=model_config.model_task_config.utils_small, request_type="relation.fetcher" + ) + + # 小模型用于即时信息提取 + self.instant_llm_model = LLMRequest( + model_set=model_config.model_task_config.utils_small, request_type="relation.fetch" + ) + + name = get_chat_manager().get_stream_name(self.chat_id) + self.log_prefix = f"[{name}] 实时信息" + + def _cleanup_expired_cache(self): + """清理过期的信息缓存""" + for person_id in list(self.info_fetched_cache.keys()): + for info_type in list(self.info_fetched_cache[person_id].keys()): + self.info_fetched_cache[person_id][info_type]["ttl"] -= 1 + if self.info_fetched_cache[person_id][info_type]["ttl"] <= 0: + del self.info_fetched_cache[person_id][info_type] + if not self.info_fetched_cache[person_id]: + del self.info_fetched_cache[person_id] + + async def build_relation_info(self, person_id, points_num=3): + # 清理过期的信息缓存 + self._cleanup_expired_cache() + + person_info_manager = get_person_info_manager() + person_info = await person_info_manager.get_values( + person_id, ["person_name", "short_impression", "nickname", "platform", "points"] + ) + person_name = person_info.get("person_name") + short_impression = person_info.get("short_impression") + nickname_str = person_info.get("nickname") + platform = person_info.get("platform") + + if person_name == nickname_str and not short_impression: + return "" + + current_points = person_info.get("points") or [] + + # 按时间排序forgotten_points + current_points.sort(key=lambda x: x[2]) + # 按权重加权随机抽取最多3个不重复的points,point[1]的值在1-10之间,权重越高被抽到概率越大 + if len(current_points) > points_num: + # point[1] 取值范围1-10,直接作为权重 + weights = [max(1, min(10, int(point[1]))) for point in current_points] + # 使用加权采样不放回,保证不重复 + indices = list(range(len(current_points))) + points = [] + for _ in range(points_num): + if not indices: + break + sub_weights = [weights[i] for i in indices] + chosen_idx = random.choices(indices, weights=sub_weights, k=1)[0] + points.append(current_points[chosen_idx]) + indices.remove(chosen_idx) + else: + points = current_points + + # 构建points文本 + points_text = "\n".join([f"{point[2]}:{point[0]}" for point in points]) + + nickname_str = "" + if person_name != nickname_str: + nickname_str = f"(ta在{platform}上的昵称是{nickname_str})" + + relation_info = "" + + if short_impression and relation_info: + if points_text: + relation_info = f"你对{person_name}的印象是{nickname_str}:{short_impression}。具体来说:{relation_info}。你还记得ta最近做的事:{points_text}" + else: + relation_info = ( + f"你对{person_name}的印象是{nickname_str}:{short_impression}。具体来说:{relation_info}" + ) + elif short_impression: + if points_text: + relation_info = ( + f"你对{person_name}的印象是{nickname_str}:{short_impression}。你还记得ta最近做的事:{points_text}" + ) + else: + relation_info = f"你对{person_name}的印象是{nickname_str}:{short_impression}" + elif relation_info: + if points_text: + relation_info = ( + f"你对{person_name}的了解{nickname_str}:{relation_info}。你还记得ta最近做的事:{points_text}" + ) + else: + relation_info = f"你对{person_name}的了解{nickname_str}:{relation_info}" + elif points_text: + relation_info = f"你记得{person_name}{nickname_str}最近做的事:{points_text}" + else: + relation_info = "" + + return relation_info + + async def _build_fetch_query(self, person_id, target_message, chat_history): + nickname_str = ",".join(global_config.bot.alias_names) + name_block = f"你的名字是{global_config.bot.nickname},你的昵称有{nickname_str},有人也会用这些昵称称呼你。" + person_info_manager = get_person_info_manager() + person_info = await person_info_manager.get_values(person_id, ["person_name"]) + person_name: str = person_info.get("person_name") # type: ignore + + info_cache_block = self._build_info_cache_block() + + prompt = (await global_prompt_manager.get_prompt_async("real_time_info_identify_prompt")).format( + chat_observe_info=chat_history, + name_block=name_block, + info_cache_block=info_cache_block, + person_name=person_name, + target_message=target_message, + ) + + try: + logger.debug(f"{self.log_prefix} 信息识别prompt: \n{prompt}\n") + content, _ = await self.llm_model.generate_response_async(prompt=prompt) + + if content: + content_json = orjson.loads(repair_json(content)) + + # 检查是否返回了不需要查询的标志 + if "none" in content_json: + logger.debug(f"{self.log_prefix} LLM判断当前不需要查询任何信息:{content_json.get('none', '')}") + return None + + if info_type := content_json.get("info_type"): + # 记录信息获取请求 + self.info_fetching_cache.append( + { + "person_id": get_person_info_manager().get_person_id_by_person_name(person_name), + "person_name": person_name, + "info_type": info_type, + "start_time": time.time(), + "forget": False, + } + ) + + # 限制缓存大小 + if len(self.info_fetching_cache) > 10: + self.info_fetching_cache.pop(0) + + logger.info(f"{self.log_prefix} 识别到需要调取用户 {person_name} 的[{info_type}]信息") + return info_type + else: + logger.warning(f"{self.log_prefix} LLM未返回有效的info_type。响应: {content}") + + except Exception as e: + logger.error(f"{self.log_prefix} 执行信息识别LLM请求时出错: {e}") + logger.error(traceback.format_exc()) + + return None + + def _build_info_cache_block(self) -> str: + """构建已获取信息的缓存块""" + info_cache_block = "" + if self.info_fetching_cache: + # 对于每个(person_id, info_type)组合,只保留最新的记录 + latest_records = {} + for info_fetching in self.info_fetching_cache: + key = (info_fetching["person_id"], info_fetching["info_type"]) + if key not in latest_records or info_fetching["start_time"] > latest_records[key]["start_time"]: + latest_records[key] = info_fetching + + # 按时间排序并生成显示文本 + sorted_records = sorted(latest_records.values(), key=lambda x: x["start_time"]) + for info_fetching in sorted_records: + info_cache_block += ( + f"你已经调取了[{info_fetching['person_name']}]的[{info_fetching['info_type']}]信息\n" + ) + return info_cache_block + + async def _extract_single_info(self, person_id: str, info_type: str, person_name: str): + """提取单个信息类型 + + Args: + person_id: 用户ID + info_type: 信息类型 + person_name: 用户名 + """ + start_time = time.time() + person_info_manager = get_person_info_manager() + + # 首先检查 info_list 缓存 + person_info = await person_info_manager.get_values(person_id, ["info_list"]) + info_list = person_info.get("info_list") or [] + cached_info = None + + # 查找对应的 info_type + for info_item in info_list: + if info_item.get("info_type") == info_type: + cached_info = info_item.get("info_content") + logger.debug(f"{self.log_prefix} 在info_list中找到 {person_name} 的 {info_type} 信息: {cached_info}") + break + + # 如果缓存中有信息,直接使用 + if cached_info: + if person_id not in self.info_fetched_cache: + self.info_fetched_cache[person_id] = {} + + self.info_fetched_cache[person_id][info_type] = { + "info": cached_info, + "ttl": 2, + "start_time": start_time, + "person_name": person_name, + "unknown": cached_info == "none", + } + logger.info(f"{self.log_prefix} 记得 {person_name} 的 {info_type}: {cached_info}") + return + + # 如果缓存中没有,尝试从用户档案中提取 + try: + person_info = await person_info_manager.get_values(person_id, ["impression", "points"]) + person_impression = person_info.get("impression") + points = person_info.get("points") + + # 构建印象信息块 + if person_impression: + person_impression_block = ( + f"<对{person_name}的总体了解>\n{person_impression}\n" + ) + else: + person_impression_block = "" + + # 构建要点信息块 + if points: + points_text = "\n".join([f"{point[2]}:{point[0]}" for point in points]) + points_text_block = f"<对{person_name}的近期了解>\n{points_text}\n" + else: + points_text_block = "" + + # 如果完全没有用户信息 + if not points_text_block and not person_impression_block: + if person_id not in self.info_fetched_cache: + self.info_fetched_cache[person_id] = {} + self.info_fetched_cache[person_id][info_type] = { + "info": "none", + "ttl": 2, + "start_time": start_time, + "person_name": person_name, + "unknown": True, + } + logger.info(f"{self.log_prefix} 完全不认识 {person_name}") + await self._save_info_to_cache(person_id, info_type, "none") + return + + # 使用LLM提取信息 + nickname_str = ",".join(global_config.bot.alias_names) + name_block = f"你的名字是{global_config.bot.nickname},你的昵称有{nickname_str},有人也会用这些昵称称呼你。" + + prompt = (await global_prompt_manager.get_prompt_async("real_time_fetch_person_info_prompt")).format( + name_block=name_block, + info_type=info_type, + person_impression_block=person_impression_block, + person_name=person_name, + info_json_str=f'"{info_type}": "有关{info_type}的信息内容"', + points_text_block=points_text_block, + ) + + # 使用小模型进行即时提取 + content, _ = await self.instant_llm_model.generate_response_async(prompt=prompt) + + if content: + content_json = orjson.loads(repair_json(content)) + if info_type in content_json: + info_content = content_json[info_type] + is_unknown = info_content == "none" or not info_content + + # 保存到运行时缓存 + if person_id not in self.info_fetched_cache: + self.info_fetched_cache[person_id] = {} + self.info_fetched_cache[person_id][info_type] = { + "info": "unknown" if is_unknown else info_content, + "ttl": 3, + "start_time": start_time, + "person_name": person_name, + "unknown": is_unknown, + } + + # 保存到持久化缓存 (info_list) + await self._save_info_to_cache(person_id, info_type, "none" if is_unknown else info_content) + + if not is_unknown: + logger.info(f"{self.log_prefix} 思考得到,{person_name} 的 {info_type}: {info_content}") + else: + logger.info(f"{self.log_prefix} 思考了也不知道{person_name} 的 {info_type} 信息") + else: + logger.warning(f"{self.log_prefix} 小模型返回空结果,获取 {person_name} 的 {info_type} 信息失败。") + + except Exception as e: + logger.error(f"{self.log_prefix} 执行信息提取时出错: {e}") + logger.error(traceback.format_exc()) + + async def _save_info_to_cache(self, person_id: str, info_type: str, info_content: str): + # sourcery skip: use-next + """将提取到的信息保存到 person_info 的 info_list 字段中 + + Args: + person_id: 用户ID + info_type: 信息类型 + info_content: 信息内容 + """ + try: + person_info_manager = get_person_info_manager() + + # 获取现有的 info_list + person_info = await person_info_manager.get_values(person_id, ["info_list"]) + info_list = person_info.get("info_list") or [] + + # 查找是否已存在相同 info_type 的记录 + found_index = -1 + for i, info_item in enumerate(info_list): + if isinstance(info_item, dict) and info_item.get("info_type") == info_type: + found_index = i + break + + # 创建新的信息记录 + new_info_item = { + "info_type": info_type, + "info_content": info_content, + } + + if found_index >= 0: + # 更新现有记录 + info_list[found_index] = new_info_item + logger.info(f"{self.log_prefix} [缓存更新] 更新 {person_id} 的 {info_type} 信息缓存") + else: + # 添加新记录 + info_list.append(new_info_item) + logger.info(f"{self.log_prefix} [缓存保存] 新增 {person_id} 的 {info_type} 信息缓存") + + # 保存更新后的 info_list + await person_info_manager.update_one_field(person_id, "info_list", info_list) + + except Exception as e: + logger.error(f"{self.log_prefix} [缓存保存] 保存信息到缓存失败: {e}") + logger.error(traceback.format_exc()) + + +class RelationshipFetcherManager: + """关系提取器管理器 + + 管理不同 chat_id 的 RelationshipFetcher 实例 + """ + + def __init__(self): + self._fetchers: Dict[str, RelationshipFetcher] = {} + + def get_fetcher(self, chat_id: str) -> RelationshipFetcher: + """获取或创建指定 chat_id 的 RelationshipFetcher + + Args: + chat_id: 聊天ID + + Returns: + RelationshipFetcher: 关系提取器实例 + """ + if chat_id not in self._fetchers: + self._fetchers[chat_id] = RelationshipFetcher(chat_id) + return self._fetchers[chat_id] + + def remove_fetcher(self, chat_id: str): + """移除指定 chat_id 的 RelationshipFetcher + + Args: + chat_id: 聊天ID + """ + if chat_id in self._fetchers: + del self._fetchers[chat_id] + + def clear_all(self): + """清空所有 RelationshipFetcher""" + self._fetchers.clear() + + def get_active_chat_ids(self) -> List[str]: + """获取所有活跃的 chat_id 列表""" + return list(self._fetchers.keys()) + + +# 全局管理器实例 +relationship_fetcher_manager = RelationshipFetcherManager() + + +init_real_time_info_prompts() diff --git a/src/plugin_system/apis/message_api.py b/src/plugin_system/apis/message_api.py index 98fab2342..3d161b847 100644 --- a/src/plugin_system/apis/message_api.py +++ b/src/plugin_system/apis/message_api.py @@ -62,7 +62,7 @@ def get_messages_by_time( return get_raw_msg_by_timestamp(start_time, end_time, limit, limit_mode) -def get_messages_by_time_in_chat( +async def get_messages_by_time_in_chat( chat_id: str, start_time: float, end_time: float, @@ -97,13 +97,13 @@ def get_messages_by_time_in_chat( if not isinstance(chat_id, str): raise ValueError("chat_id 必须是字符串类型") if filter_mai: - return filter_mai_messages( - get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time, limit, limit_mode, filter_command) + return await filter_mai_messages( + await get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time, limit, limit_mode, filter_command) ) - return get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time, limit, limit_mode, filter_command) + return await get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time, limit, limit_mode, filter_command) -def get_messages_by_time_in_chat_inclusive( +async def get_messages_by_time_in_chat_inclusive( chat_id: str, start_time: float, end_time: float, @@ -138,12 +138,12 @@ def get_messages_by_time_in_chat_inclusive( if not isinstance(chat_id, str): raise ValueError("chat_id 必须是字符串类型") if filter_mai: - return filter_mai_messages( - get_raw_msg_by_timestamp_with_chat_inclusive( + return await filter_mai_messages( + await get_raw_msg_by_timestamp_with_chat_inclusive( chat_id, start_time, end_time, limit, limit_mode, filter_command ) ) - return get_raw_msg_by_timestamp_with_chat_inclusive( + return await get_raw_msg_by_timestamp_with_chat_inclusive( chat_id, start_time, end_time, limit, limit_mode, filter_command ) @@ -478,7 +478,7 @@ async def get_person_ids_from_messages(messages: List[Dict[str, Any]]) -> List[s # ============================================================================= -def filter_mai_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: +async def filter_mai_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """ 从消息列表中移除麦麦的消息 Args: diff --git a/src/schedule/database.py b/src/schedule/database.py index 88337f4df..5025c1fa3 100644 --- a/src/schedule/database.py +++ b/src/schedule/database.py @@ -1,6 +1,7 @@ # mmc/src/schedule/database.py from typing import List +from sqlalchemy import select, func, update, delete from src.common.database.sqlalchemy_models import MonthlyPlan, get_db_session from src.common.logger import get_logger from src.config.config import global_config @@ -8,21 +9,22 @@ from src.config.config import global_config logger = get_logger("schedule_database") -def add_new_plans(plans: List[str], month: str): +async def add_new_plans(plans: List[str], month: str): """ 批量添加新生成的月度计划到数据库,并确保不超过上限。 :param plans: 计划内容列表。 :param month: 目标月份,格式为 "YYYY-MM"。 """ - with get_db_session() as session: + async with get_db_session() as session: try: # 1. 获取当前有效计划数量(状态为 'active') - current_plan_count = ( - session.query(MonthlyPlan) - .filter(MonthlyPlan.target_month == month, MonthlyPlan.status == "active") - .count() + result = await session.execute( + select(func.count(MonthlyPlan.id)).where( + MonthlyPlan.target_month == month, MonthlyPlan.status == "active" + ) ) + current_plan_count = result.scalar_one() # 2. 从配置获取上限 max_plans = global_config.planning_system.max_plans_per_month @@ -41,7 +43,7 @@ def add_new_plans(plans: List[str], month: str): MonthlyPlan(plan_text=plan, target_month=month, status="active") for plan in plans_to_add ] session.add_all(new_plan_objects) - session.commit() + await session.commit() logger.info(f"成功向数据库添加了 {len(new_plan_objects)} 条 {month} 的月度计划。") if len(plans) > len(plans_to_add): @@ -49,32 +51,31 @@ def add_new_plans(plans: List[str], month: str): except Exception as e: logger.error(f"添加月度计划时发生错误: {e}") - session.rollback() + await session.rollback() raise -def get_active_plans_for_month(month: str) -> List[MonthlyPlan]: +async def get_active_plans_for_month(month: str) -> List[MonthlyPlan]: """ 获取指定月份所有状态为 'active' 的计划。 :param month: 目标月份,格式为 "YYYY-MM"。 :return: MonthlyPlan 对象列表。 """ - with get_db_session() as session: + async with get_db_session() as session: try: - plans = ( - session.query(MonthlyPlan) - .filter(MonthlyPlan.target_month == month, MonthlyPlan.status == "active") + result = await session.execute( + select(MonthlyPlan) + .where(MonthlyPlan.target_month == month, MonthlyPlan.status == "active") .order_by(MonthlyPlan.created_at.desc()) - .all() ) - return plans + return result.scalars().all() except Exception as e: logger.error(f"查询 {month} 的有效月度计划时发生错误: {e}") return [] -def mark_plans_completed(plan_ids: List[int]): +async def mark_plans_completed(plan_ids: List[int]): """ 将指定ID的计划标记为已完成。 @@ -83,9 +84,10 @@ def mark_plans_completed(plan_ids: List[int]): if not plan_ids: return - with get_db_session() as session: + async with get_db_session() as session: try: - plans_to_mark = session.query(MonthlyPlan).filter(MonthlyPlan.id.in_(plan_ids)).all() + result = await session.execute(select(MonthlyPlan).where(MonthlyPlan.id.in_(plan_ids))) + plans_to_mark = result.scalars().all() if not plans_to_mark: logger.info("没有需要标记为完成的月度计划。") return @@ -93,17 +95,17 @@ def mark_plans_completed(plan_ids: List[int]): plan_details = "\n".join([f" {i + 1}. {plan.plan_text}" for i, plan in enumerate(plans_to_mark)]) logger.info(f"以下 {len(plans_to_mark)} 条月度计划将被标记为已完成:\n{plan_details}") - session.query(MonthlyPlan).filter(MonthlyPlan.id.in_(plan_ids)).update( - {"status": "completed"}, synchronize_session=False + await session.execute( + update(MonthlyPlan).where(MonthlyPlan.id.in_(plan_ids)).values(status="completed") ) - session.commit() + await session.commit() except Exception as e: logger.error(f"标记月度计划为完成时发生错误: {e}") - session.rollback() + await session.rollback() raise -def delete_plans_by_ids(plan_ids: List[int]): +async def delete_plans_by_ids(plan_ids: List[int]): """ 根据ID列表从数据库中物理删除月度计划。 @@ -112,10 +114,11 @@ def delete_plans_by_ids(plan_ids: List[int]): if not plan_ids: return - with get_db_session() as session: + async with get_db_session() as session: try: # 先查询要删除的计划,用于日志记录 - plans_to_delete = session.query(MonthlyPlan).filter(MonthlyPlan.id.in_(plan_ids)).all() + result = await session.execute(select(MonthlyPlan).where(MonthlyPlan.id.in_(plan_ids))) + plans_to_delete = result.scalars().all() if not plans_to_delete: logger.info("没有找到需要删除的月度计划。") return @@ -124,16 +127,16 @@ def delete_plans_by_ids(plan_ids: List[int]): logger.info(f"检测到月度计划超额,将删除以下 {len(plans_to_delete)} 条计划:\n{plan_details}") # 执行删除 - session.query(MonthlyPlan).filter(MonthlyPlan.id.in_(plan_ids)).delete(synchronize_session=False) - session.commit() + await session.execute(delete(MonthlyPlan).where(MonthlyPlan.id.in_(plan_ids))) + await session.commit() except Exception as e: logger.error(f"删除月度计划时发生错误: {e}") - session.rollback() + await session.rollback() raise -def update_plan_usage(plan_ids: List[int], used_date: str): +async def update_plan_usage(plan_ids: List[int], used_date: str): """ 更新计划的使用统计信息。 @@ -143,44 +146,47 @@ def update_plan_usage(plan_ids: List[int], used_date: str): if not plan_ids: return - with get_db_session() as session: + async with get_db_session() as session: try: # 获取完成阈值配置,如果不存在则使用默认值 completion_threshold = getattr(global_config.planning_system, "completion_threshold", 3) # 批量更新使用次数和最后使用日期 - session.query(MonthlyPlan).filter(MonthlyPlan.id.in_(plan_ids)).update( - {"usage_count": MonthlyPlan.usage_count + 1, "last_used_date": used_date}, synchronize_session=False + await session.execute( + update(MonthlyPlan) + .where(MonthlyPlan.id.in_(plan_ids)) + .values(usage_count=MonthlyPlan.usage_count + 1, last_used_date=used_date) ) # 检查是否有计划达到完成阈值 - plans_to_complete = ( - session.query(MonthlyPlan) - .filter( + result = await session.execute( + select(MonthlyPlan).where( MonthlyPlan.id.in_(plan_ids), MonthlyPlan.usage_count >= completion_threshold, MonthlyPlan.status == "active", ) - .all() ) + plans_to_complete = result.scalars().all() if plans_to_complete: completed_ids = [plan.id for plan in plans_to_complete] - session.query(MonthlyPlan).filter(MonthlyPlan.id.in_(completed_ids)).update( - {"status": "completed"}, synchronize_session=False + await session.execute( + update(MonthlyPlan).where(MonthlyPlan.id.in_(completed_ids)).values(status="completed") ) logger.info(f"计划 {completed_ids} 已达到使用阈值 ({completion_threshold}),标记为已完成。") - session.commit() + await session.commit() logger.info(f"成功更新了 {len(plan_ids)} 条月度计划的使用统计。") except Exception as e: logger.error(f"更新月度计划使用统计时发生错误: {e}") - session.rollback() + await session.rollback() raise -def get_smart_plans_for_daily_schedule(month: str, max_count: int = 3, avoid_days: int = 7) -> List[MonthlyPlan]: +async def get_smart_plans_for_daily_schedule( + month: str, max_count: int = 3, avoid_days: int = 7 +) -> List[MonthlyPlan]: """ 智能抽取月度计划用于每日日程生成。 @@ -196,19 +202,24 @@ def get_smart_plans_for_daily_schedule(month: str, max_count: int = 3, avoid_day """ from datetime import datetime, timedelta - with get_db_session() as session: + async with get_db_session() as session: try: # 计算避免重复的日期阈值 avoid_date = (datetime.now() - timedelta(days=avoid_days)).strftime("%Y-%m-%d") # 查询符合条件的计划 - query = session.query(MonthlyPlan).filter(MonthlyPlan.target_month == month, MonthlyPlan.status == "active") + query = select(MonthlyPlan).where( + MonthlyPlan.target_month == month, MonthlyPlan.status == "active" + ) # 排除最近使用过的计划 - query = query.filter((MonthlyPlan.last_used_date.is_(None)) | (MonthlyPlan.last_used_date < avoid_date)) + query = query.where( + (MonthlyPlan.last_used_date.is_(None)) | (MonthlyPlan.last_used_date < avoid_date) + ) # 按使用次数升序排列,优先选择使用次数少的 - plans = query.order_by(MonthlyPlan.usage_count.asc()).all() + result = await session.execute(query.order_by(MonthlyPlan.usage_count.asc())) + plans = result.scalars().all() if not plans: logger.info(f"没有找到符合条件的 {month} 月度计划。") @@ -228,31 +239,31 @@ def get_smart_plans_for_daily_schedule(month: str, max_count: int = 3, avoid_day return [] -def archive_active_plans_for_month(month: str): +async def archive_active_plans_for_month(month: str): """ 将指定月份所有状态为 'active' 的计划归档为 'archived'。 通常在月底调用。 :param month: 目标月份,格式为 "YYYY-MM"。 """ - with get_db_session() as session: + async with get_db_session() as session: try: - updated_count = ( - session.query(MonthlyPlan) - .filter(MonthlyPlan.target_month == month, MonthlyPlan.status == "active") - .update({"status": "archived"}, synchronize_session=False) + result = await session.execute( + update(MonthlyPlan) + .where(MonthlyPlan.target_month == month, MonthlyPlan.status == "active") + .values(status="archived") ) - - session.commit() + updated_count = result.rowcount + await session.commit() logger.info(f"成功将 {updated_count} 条 {month} 的活跃月度计划归档。") return updated_count except Exception as e: logger.error(f"归档 {month} 的月度计划时发生错误: {e}") - session.rollback() + await session.rollback() raise -def get_archived_plans_for_month(month: str) -> List[MonthlyPlan]: +async def get_archived_plans_for_month(month: str) -> List[MonthlyPlan]: """ 获取指定月份所有状态为 'archived' 的计划。 用于生成下个月计划时的参考。 @@ -260,34 +271,34 @@ def get_archived_plans_for_month(month: str) -> List[MonthlyPlan]: :param month: 目标月份,格式为 "YYYY-MM"。 :return: MonthlyPlan 对象列表。 """ - with get_db_session() as session: + async with get_db_session() as session: try: - plans = ( - session.query(MonthlyPlan) - .filter(MonthlyPlan.target_month == month, MonthlyPlan.status == "archived") - .all() + result = await session.execute( + select(MonthlyPlan).where( + MonthlyPlan.target_month == month, MonthlyPlan.status == "archived" + ) ) - return plans + return result.scalars().all() except Exception as e: logger.error(f"查询 {month} 的归档月度计划时发生错误: {e}") return [] -def has_active_plans(month: str) -> bool: +async def has_active_plans(month: str) -> bool: """ 检查指定月份是否存在任何状态为 'active' 的计划。 :param month: 目标月份,格式为 "YYYY-MM"。 :return: 如果存在则返回 True,否则返回 False。 """ - with get_db_session() as session: + async with get_db_session() as session: try: - count = ( - session.query(MonthlyPlan) - .filter(MonthlyPlan.target_month == month, MonthlyPlan.status == "active") - .count() + result = await session.execute( + select(func.count(MonthlyPlan.id)).where( + MonthlyPlan.target_month == month, MonthlyPlan.status == "active" + ) ) - return count > 0 + return result.scalar_one() > 0 except Exception as e: logger.error(f"检查 {month} 的有效月度计划时发生错误: {e}") return False \ No newline at end of file diff --git a/src/schedule/monthly_plan_manager.py b/src/schedule/monthly_plan_manager.py index 1d5984ea3..7deaaf77d 100644 --- a/src/schedule/monthly_plan_manager.py +++ b/src/schedule/monthly_plan_manager.py @@ -14,6 +14,11 @@ class MonthlyPlanManager: self.plan_manager = PlanManager() self.monthly_task_started = False + async def initialize(self): + logger.info("正在初始化月度计划管理器...") + await self.start_monthly_plan_generation() + logger.info("月度计划管理器初始化成功") + async def start_monthly_plan_generation(self): if not self.monthly_task_started: logger.info(" 正在启动每月月度计划生成任务...") diff --git a/src/schedule/plan_manager.py b/src/schedule/plan_manager.py index 0fae5c381..b84a37b72 100644 --- a/src/schedule/plan_manager.py +++ b/src/schedule/plan_manager.py @@ -28,20 +28,20 @@ class PlanManager: if target_month is None: target_month = datetime.now().strftime("%Y-%m") - if not has_active_plans(target_month): + if not await has_active_plans(target_month): logger.info(f" {target_month} 没有任何有效的月度计划,将触发同步生成。") generation_successful = await self._generate_monthly_plans_logic(target_month) return generation_successful else: logger.info(f"{target_month} 已存在有效的月度计划。") - plans = get_active_plans_for_month(target_month) + plans = await get_active_plans_for_month(target_month) max_plans = global_config.planning_system.max_plans_per_month if len(plans) > max_plans: logger.warning(f"当前月度计划数量 ({len(plans)}) 超出上限 ({max_plans}),将自动删除多余的计划。") plans_to_delete = plans[: len(plans) - max_plans] delete_ids = [p.id for p in plans_to_delete] - delete_plans_by_ids(delete_ids) # type: ignore - plans = get_active_plans_for_month(target_month) + await delete_plans_by_ids(delete_ids) # type: ignore + plans = await get_active_plans_for_month(target_month) if plans: plan_texts = "\n".join([f" {i + 1}. {plan.plan_text}" for i, plan in enumerate(plans)]) @@ -64,11 +64,11 @@ class PlanManager: return False last_month = self._get_previous_month(target_month) - archived_plans = get_archived_plans_for_month(last_month) + archived_plans = await get_archived_plans_for_month(last_month) plans = await self.llm_generator.generate_plans_with_llm(target_month, archived_plans) if plans: - add_new_plans(plans, target_month) + await add_new_plans(plans, target_month) logger.info(f"成功为 {target_month} 生成并保存了 {len(plans)} 条月度计划。") return True else: @@ -95,11 +95,11 @@ class PlanManager: if target_month is None: target_month = datetime.now().strftime("%Y-%m") logger.info(f" 开始归档 {target_month} 的活跃月度计划...") - archived_count = archive_active_plans_for_month(target_month) + archived_count = await archive_active_plans_for_month(target_month) logger.info(f" 成功归档了 {archived_count} 条 {target_month} 的月度计划。") except Exception as e: logger.error(f" 归档 {target_month} 月度计划时发生错误: {e}") - def get_plans_for_schedule(self, month: str, max_count: int) -> List: + async def get_plans_for_schedule(self, month: str, max_count: int) -> List: avoid_days = global_config.planning_system.avoid_repetition_days - return get_smart_plans_for_daily_schedule(month, max_count=max_count, avoid_days=avoid_days) \ No newline at end of file + return await get_smart_plans_for_daily_schedule(month, max_count=max_count, avoid_days=avoid_days) \ No newline at end of file diff --git a/src/schedule/schedule_manager.py b/src/schedule/schedule_manager.py index f97d7c03c..822131dec 100644 --- a/src/schedule/schedule_manager.py +++ b/src/schedule/schedule_manager.py @@ -23,6 +23,13 @@ class ScheduleManager: self.daily_task_started = False self.schedule_generation_running = False + async def initialize(self): + if global_config.planning_system.schedule_enable: + logger.info("日程表功能已启用,正在初始化管理器...") + await self.load_or_generate_today_schedule() + await self.start_daily_schedule_generation() + logger.info("日程表管理器初始化成功。") + async def start_daily_schedule_generation(self): if not self.daily_task_started: logger.info("正在启动每日日程生成任务...") @@ -40,7 +47,7 @@ class ScheduleManager: today_str = datetime.now().strftime("%Y-%m-%d") try: - schedule_data = self._load_schedule_from_db(today_str) + schedule_data = await self._load_schedule_from_db(today_str) if schedule_data: self.today_schedule = schedule_data self._log_loaded_schedule(today_str) @@ -54,9 +61,10 @@ class ScheduleManager: logger.info("尝试生成日程作为备用方案...") await self.generate_and_save_schedule() - def _load_schedule_from_db(self, date_str: str) -> Optional[List[Dict[str, Any]]]: - with get_db_session() as session: - schedule_record = session.query(Schedule).filter(Schedule.date == date_str).first() + async def _load_schedule_from_db(self, date_str: str) -> Optional[List[Dict[str, Any]]]: + async with get_db_session() as session: + result = await session.execute(select(Schedule).filter(Schedule.date == date_str)) + schedule_record = result.scalars().first() if schedule_record: logger.info(f"从数据库加载今天的日程 ({date_str})。") schedule_data = orjson.loads(str(schedule_record.schedule_data)) @@ -90,35 +98,35 @@ class ScheduleManager: sampled_plans = [] if global_config.planning_system.monthly_plan_enable: await self.plan_manager.ensure_and_generate_plans_if_needed(current_month_str) - sampled_plans = self.plan_manager.get_plans_for_schedule(current_month_str, max_count=3) + sampled_plans = await self.plan_manager.get_plans_for_schedule(current_month_str, max_count=3) schedule_data = await self.llm_generator.generate_schedule_with_llm(sampled_plans) if schedule_data: - self._save_schedule_to_db(today_str, schedule_data) + await self._save_schedule_to_db(today_str, schedule_data) self.today_schedule = schedule_data self._log_generated_schedule(today_str, schedule_data) if sampled_plans: used_plan_ids = [plan.id for plan in sampled_plans] logger.info(f"更新使用过的月度计划 {used_plan_ids} 的统计信息。") - update_plan_usage(used_plan_ids, today_str) + await update_plan_usage(used_plan_ids, today_str) finally: self.schedule_generation_running = False logger.info("日程生成任务结束") - def _save_schedule_to_db(self, date_str: str, schedule_data: List[Dict[str, Any]]): - with get_db_session() as session: + async def _save_schedule_to_db(self, date_str: str, schedule_data: List[Dict[str, Any]]): + async with get_db_session() as session: schedule_json = orjson.dumps(schedule_data).decode("utf-8") - existing_schedule = session.query(Schedule).filter(Schedule.date == date_str).first() + result = await session.execute(select(Schedule).filter(Schedule.date == date_str)) + existing_schedule = result.scalars().first() if existing_schedule: - session.query(Schedule).filter(Schedule.date == date_str).update( - {Schedule.schedule_data: schedule_json, Schedule.updated_at: datetime.now()} - ) + existing_schedule.schedule_data = schedule_json + existing_schedule.updated_at = datetime.now() else: new_schedule = Schedule(date=date_str, schedule_data=schedule_json) session.add(new_schedule) - session.commit() + await session.commit() def _log_generated_schedule(self, date_str: str, schedule_data: List[Dict[str, Any]]): schedule_str = f"✅ 成功生成并保存今天的日程 ({date_str}):\n"