From e35c2bb9b4570e1bb48c7c898ba602d3aeb272af Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Wed, 9 Apr 2025 13:52:12 +0800 Subject: [PATCH 01/24] =?UTF-8?q?=E5=A4=A7=E4=BF=AE=5Fexecute=5Frequest?= =?UTF-8?q?=E7=82=B8=E7=A8=8B=E5=BA=8F=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/heart_flow/heartflow.py | 33 ++++++++----- src/heart_flow/observation.py | 6 ++- src/heart_flow/sub_heartflow.py | 48 +++++++++++-------- src/plugins/PFC/pfc.py | 9 ++-- src/plugins/chat/emoji_manager.py | 3 ++ src/plugins/chat/utils.py | 8 +++- src/plugins/memory_system/Hippocampus.py | 17 +++++-- src/plugins/schedule/schedule_generator.py | 6 ++- .../topic_identify/topic_identifier.py | 9 ++-- 9 files changed, 93 insertions(+), 46 deletions(-) diff --git a/src/heart_flow/heartflow.py b/src/heart_flow/heartflow.py index 9cf8d4674..798b1867b 100644 --- a/src/heart_flow/heartflow.py +++ b/src/heart_flow/heartflow.py @@ -18,7 +18,7 @@ heartflow_config = LogConfig( logger = get_module_logger("heartflow", config=heartflow_config) -class CuttentState: +class CurrentState: def __init__(self): self.willing = 0 self.current_state_info = "" @@ -34,7 +34,7 @@ class Heartflow: def __init__(self): self.current_mind = "你什么也没想" self.past_mind = [] - self.current_state: CuttentState = CuttentState() + self.current_state: CurrentState = CurrentState() self.llm_model = LLM_request( model=global_config.llm_heartflow, temperature=0.6, max_tokens=1000, request_type="heart_flow" ) @@ -102,7 +102,11 @@ class Heartflow: current_thinking_info = self.current_mind mood_info = self.current_state.mood related_memory_info = "memory" - sub_flows_info = await self.get_all_subheartflows_minds() + try: + sub_flows_info = await self.get_all_subheartflows_minds() + except Exception as e: + logger.error(f"获取子心流的想法失败: {e}") + return schedule_info = bot_schedule.get_current_num_task(num=4, time_info=True) @@ -111,26 +115,29 @@ class Heartflow: prompt += f"{personality_info}\n" prompt += f"你想起来{related_memory_info}。" prompt += f"刚刚你的主要想法是{current_thinking_info}。" - prompt += f"你还有一些小想法,因为你在参加不同的群聊天,是你正在做的事情:{sub_flows_info}\n" + prompt += f"你还有一些小想法,因为你在参加不同的群聊天,这是你正在做的事情:{sub_flows_info}\n" prompt += f"你现在{mood_info}。" prompt += "现在你接下去继续思考,产生新的想法,但是要基于原有的主要想法,不要分点输出," prompt += "输出连贯的内心独白,不要太长,但是记得结合上述的消息,关注新内容:" - reponse, reasoning_content = await self.llm_model.generate_response_async(prompt) + try: + response, reasoning_content = await self.llm_model.generate_response_async(prompt) + except Exception as e: + logger.error(f"内心独白获取失败: {e}") + return + self.update_current_mind(response) - self.update_current_mind(reponse) - - self.current_mind = reponse + self.current_mind = response logger.info(f"麦麦的总体脑内状态:{self.current_mind}") # logger.info("麦麦想了想,当前活动:") # await bot_schedule.move_doing(self.current_mind) for _, subheartflow in self._subheartflows.items(): - subheartflow.main_heartflow_info = reponse + subheartflow.main_heartflow_info = response - def update_current_mind(self, reponse): + def update_current_mind(self, response): self.past_mind.append(self.current_mind) - self.current_mind = reponse + self.current_mind = response async def get_all_subheartflows_minds(self): sub_minds = "" @@ -167,9 +174,9 @@ class Heartflow: prompt += """现在请你总结这些聊天内容,注意关注聊天内容对原有的想法的影响,输出连贯的内心独白 不要太长,但是记得结合上述的消息,要记得你的人设,关注新内容:""" - reponse, reasoning_content = await self.llm_model.generate_response_async(prompt) + response, reasoning_content = await self.llm_model.generate_response_async(prompt) - return reponse + return response def create_subheartflow(self, subheartflow_id): """ diff --git a/src/heart_flow/observation.py b/src/heart_flow/observation.py index 5befd7322..78cb9ef67 100644 --- a/src/heart_flow/observation.py +++ b/src/heart_flow/observation.py @@ -142,7 +142,11 @@ class ChattingObservation(Observation): prompt += """以上是群里在进行的聊天,请你对这个聊天内容进行总结,总结内容要包含聊天的大致内容, 以及聊天中的一些重要信息,注意识别你自己的发言,记得不要分点,不要太长,精简的概括成一段文本\n""" prompt += "总结概括:" - self.observe_info, reasoning_content = await self.llm_summary.generate_response_async(prompt) + try: + self.observe_info, reasoning_content = await self.llm_summary.generate_response_async(prompt) + except Exception as e: + print(f"获取总结失败: {e}") + self.observe_info = "" print(f"prompt:{prompt}") print(f"self.observe_info:{self.observe_info}") diff --git a/src/heart_flow/sub_heartflow.py b/src/heart_flow/sub_heartflow.py index a2ba023e2..cd15bb608 100644 --- a/src/heart_flow/sub_heartflow.py +++ b/src/heart_flow/sub_heartflow.py @@ -22,7 +22,7 @@ subheartflow_config = LogConfig( logger = get_module_logger("subheartflow", config=subheartflow_config) -class CuttentState: +class CurrentState: def __init__(self): self.willing = 0 self.current_state_info = "" @@ -40,7 +40,7 @@ class SubHeartflow: self.current_mind = "" self.past_mind = [] - self.current_state: CuttentState = CuttentState() + self.current_state: CurrentState = CurrentState() self.llm_model = LLM_request( model=global_config.llm_sub_heartflow, temperature=0.7, max_tokens=600, request_type="sub_heart_flow" ) @@ -143,11 +143,11 @@ class SubHeartflow: # prompt += f"你现在{mood_info}\n" # prompt += "现在你接下去继续思考,产生新的想法,不要分点输出,输出连贯的内心独白,不要太长," # prompt += "但是记得结合上述的消息,要记得维持住你的人设,关注聊天和新内容,不要思考太多:" - # reponse, reasoning_content = await self.llm_model.generate_response_async(prompt) + # response, reasoning_content = await self.llm_model.generate_response_async(prompt) - # self.update_current_mind(reponse) + # self.update_current_mind(response) - # self.current_mind = reponse + # self.current_mind = response # logger.debug(f"prompt:\n{prompt}\n") # logger.info(f"麦麦的脑内状态:{self.current_mind}") @@ -217,11 +217,14 @@ class SubHeartflow: prompt += f"你注意到有人刚刚说:{message_txt}\n" prompt += "现在你接下去继续思考,产生新的想法,不要分点输出,输出连贯的内心独白,不要太长," prompt += "记得结合上述的消息,要记得维持住你的人设,注意自己的名字,关注有人刚刚说的内容,不要思考太多:" - reponse, reasoning_content = await self.llm_model.generate_response_async(prompt) + try: + response, reasoning_content = await self.llm_model.generate_response_async(prompt) + except Exception as e: + logger.error(f"回复前内心独白获取失败: {e}") + response = "" + self.update_current_mind(response) - self.update_current_mind(reponse) - - self.current_mind = reponse + self.current_mind = response logger.debug(f"prompt:\n{prompt}\n") logger.info(f"麦麦的思考前脑内状态:{self.current_mind}") @@ -264,12 +267,14 @@ class SubHeartflow: prompt += f"你现在{mood_info}" prompt += "现在你接下去继续思考,产生新的想法,记得保留你刚刚的想法,不要分点输出,输出连贯的内心独白" prompt += "不要太长,但是记得结合上述的消息,要记得你的人设,关注聊天和新内容,关注你回复的内容,不要思考太多:" + try: + response, reasoning_content = await self.llm_model.generate_response_async(prompt) + except Exception as e: + logger.error(f"回复后内心独白获取失败: {e}") + response = "" + self.update_current_mind(response) - reponse, reasoning_content = await self.llm_model.generate_response_async(prompt) - - self.update_current_mind(reponse) - - self.current_mind = reponse + self.current_mind = response logger.info(f"麦麦回复后的脑内状态:{self.current_mind}") self.last_reply_time = time.time() @@ -302,10 +307,13 @@ class SubHeartflow: prompt += f"你现在{mood_info}。" prompt += "现在请你思考,你想不想发言或者回复,请你输出一个数字,1-10,1表示非常不想,10表示非常想。" prompt += "请你用<>包裹你的回复意愿,输出<1>表示不想回复,输出<10>表示非常想回复。请你考虑,你完全可以不回复" - - response, reasoning_content = await self.llm_model.generate_response_async(prompt) - # 解析willing值 - willing_match = re.search(r"<(\d+)>", response) + try: + response, reasoning_content = await self.llm_model.generate_response_async(prompt) + # 解析willing值 + willing_match = re.search(r"<(\d+)>", response) + except Exception as e: + logger.error(f"意愿判断获取失败: {e}") + willing_match = None if willing_match: self.current_state.willing = int(willing_match.group(1)) else: @@ -313,9 +321,9 @@ class SubHeartflow: return self.current_state.willing - def update_current_mind(self, reponse): + def update_current_mind(self, response): self.past_mind.append(self.current_mind) - self.current_mind = reponse + self.current_mind = response async def get_prompt_info(self, message: str, threshold: float): start_time = time.time() diff --git a/src/plugins/PFC/pfc.py b/src/plugins/PFC/pfc.py index 25c4728e0..824b8e93a 100644 --- a/src/plugins/PFC/pfc.py +++ b/src/plugins/PFC/pfc.py @@ -117,9 +117,12 @@ class GoalAnalyzer: }}""" logger.debug(f"发送到LLM的提示词: {prompt}") - content, _ = await self.llm.generate_response_async(prompt) - logger.debug(f"LLM原始返回内容: {content}") - + try: + content, _ = await self.llm.generate_response_async(prompt) + logger.debug(f"LLM原始返回内容: {content}") + except Exception as e: + logger.error(f"分析对话目标时出错: {str(e)}") + content = "" # 使用简化函数提取JSON内容 success, result = get_items_from_json( content, diff --git a/src/plugins/chat/emoji_manager.py b/src/plugins/chat/emoji_manager.py index 6d070c83f..de3a5a54d 100644 --- a/src/plugins/chat/emoji_manager.py +++ b/src/plugins/chat/emoji_manager.py @@ -340,6 +340,9 @@ class EmojiManager: if description is not None: embedding = await get_embedding(description, request_type="emoji") + if not embedding: + logger.error("获取消息嵌入向量失败") + raise ValueError("获取消息嵌入向量失败") # 准备数据库记录 emoji_record = { "filename": filename, diff --git a/src/plugins/chat/utils.py b/src/plugins/chat/utils.py index b7cc32e2f..b7986ae3e 100644 --- a/src/plugins/chat/utils.py +++ b/src/plugins/chat/utils.py @@ -79,7 +79,13 @@ async def get_embedding(text, request_type="embedding"): """获取文本的embedding向量""" llm = LLM_request(model=global_config.embedding, request_type=request_type) # return llm.get_embedding_sync(text) - return await llm.get_embedding(text) + try: + embedding = await llm.get_embedding(text) + except Exception as e: + logger.error(f"获取embedding失败: {str(e)}") + embedding = None + return embedding + async def get_recent_group_messages(chat_id: str, limit: int = 12) -> list: diff --git a/src/plugins/memory_system/Hippocampus.py b/src/plugins/memory_system/Hippocampus.py index 717cebe17..8e2cd21e7 100644 --- a/src/plugins/memory_system/Hippocampus.py +++ b/src/plugins/memory_system/Hippocampus.py @@ -1316,15 +1316,24 @@ class HippocampusManager: """从文本中获取相关记忆的公共接口""" if not self._initialized: raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法") - return await self._hippocampus.get_memory_from_text( - text, max_memory_num, max_memory_length, max_depth, fast_retrieval - ) + try: + response = await self._hippocampus.get_memory_from_text(text, max_memory_num, max_memory_length, max_depth, fast_retrieval) + except Exception as e: + logger.error(f"文本激活记忆失败: {e}") + response = [] + return response + async def get_activate_from_text(self, text: str, max_depth: int = 3, fast_retrieval: bool = False) -> float: """从文本中获取激活值的公共接口""" if not self._initialized: raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法") - return await self._hippocampus.get_activate_from_text(text, max_depth, fast_retrieval) + try: + response = await self._hippocampus.get_activate_from_text(text, max_depth, fast_retrieval) + except Exception as e: + logger.error(f"文本产生激活值失败: {e}") + response = 0.0 + return response def get_memory_from_keyword(self, keyword: str, max_depth: int = 2) -> list: """从关键词获取相关记忆的公共接口""" diff --git a/src/plugins/schedule/schedule_generator.py b/src/plugins/schedule/schedule_generator.py index ccab662d1..23b898f7d 100644 --- a/src/plugins/schedule/schedule_generator.py +++ b/src/plugins/schedule/schedule_generator.py @@ -121,7 +121,11 @@ class ScheduleGenerator: self.today_done_list = [] if not self.today_schedule_text: logger.info(f"{today.strftime('%Y-%m-%d')}的日程不存在,准备生成新的日程") - self.today_schedule_text = await self.generate_daily_schedule(target_date=today) + try: + self.today_schedule_text = await self.generate_daily_schedule(target_date=today) + except Exception as e: + logger.error(f"生成日程时发生错误: {str(e)}") + self.today_schedule_text = "" self.save_today_schedule_to_db() diff --git a/src/plugins/topic_identify/topic_identifier.py b/src/plugins/topic_identify/topic_identifier.py index 39b985d7c..743e45870 100644 --- a/src/plugins/topic_identify/topic_identifier.py +++ b/src/plugins/topic_identify/topic_identifier.py @@ -29,10 +29,13 @@ class TopicIdentifier: 消息内容:{text}""" # 使用 LLM_request 类进行请求 - topic, _, _ = await self.llm_topic_judge.generate_response(prompt) - + try: + topic, _, _ = await self.llm_topic_judge.generate_response(prompt) + except Exception as e: + logger.error(f"LLM 请求topic失败: {e}") + return None if not topic: - logger.error("LLM API 返回为空") + logger.error("LLM 得到的topic为空") return None # 直接在这里处理主题解析 From c8c432f6b07e8c055f0e93c04751a48d2b6efa0e Mon Sep 17 00:00:00 2001 From: tcmofashi Date: Wed, 9 Apr 2025 17:00:49 +0800 Subject: [PATCH 02/24] =?UTF-8?q?fix:=20maimmessage=E9=83=A8=E5=88=86?= =?UTF-8?q?=E5=8F=AF=E4=BB=A5=E4=B8=8D=E5=86=8D=E5=88=9D=E5=A7=8B=E5=8C=96?= =?UTF-8?q?fastapi?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/plugins/message/api.py | 169 +++++++++++++++++++++---------------- 1 file changed, 96 insertions(+), 73 deletions(-) diff --git a/src/plugins/message/api.py b/src/plugins/message/api.py index 2a6a2b6fc..19457bbec 100644 --- a/src/plugins/message/api.py +++ b/src/plugins/message/api.py @@ -1,5 +1,5 @@ from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect -from typing import Dict, Any, Callable, List, Set +from typing import Dict, Any, Callable, List, Set, Optional from src.common.logger import get_module_logger from src.plugins.message.message_base import MessageBase import aiohttp @@ -49,13 +49,22 @@ class MessageServer(BaseMessageHandler): _class_handlers: List[Callable] = [] # 类级别的消息处理器 - def __init__(self, host: str = "0.0.0.0", port: int = 18000, enable_token=False): + def __init__( + self, + host: str = "0.0.0.0", + port: int = 18000, + enable_token=False, + app: Optional[FastAPI] = None, + path: str = "/ws", + ): super().__init__() # 将类级别的处理器添加到实例处理器中 self.message_handlers.extend(self._class_handlers) - self.app = FastAPI() self.host = host self.port = port + self.path = path + self.app = app or FastAPI() + self.own_app = app is None # 标记是否使用自己创建的app self.active_websockets: Set[WebSocket] = set() self.platform_websockets: Dict[str, WebSocket] = {} # 平台到websocket的映射 self.valid_tokens: Set[str] = set() @@ -63,28 +72,6 @@ class MessageServer(BaseMessageHandler): self._setup_routes() self._running = False - @classmethod - def register_class_handler(cls, handler: Callable): - """注册类级别的消息处理器""" - if handler not in cls._class_handlers: - cls._class_handlers.append(handler) - - def register_message_handler(self, handler: Callable): - """注册实例级别的消息处理器""" - if handler not in self.message_handlers: - self.message_handlers.append(handler) - - async def verify_token(self, token: str) -> bool: - if not self.enable_token: - return True - return token in self.valid_tokens - - def add_valid_token(self, token: str): - self.valid_tokens.add(token) - - def remove_valid_token(self, token: str): - self.valid_tokens.discard(token) - def _setup_routes(self): @self.app.post("/api/message") async def handle_message(message: Dict[str, Any]): @@ -125,6 +112,90 @@ class MessageServer(BaseMessageHandler): finally: self._remove_websocket(websocket, platform) + @classmethod + def register_class_handler(cls, handler: Callable): + """注册类级别的消息处理器""" + if handler not in cls._class_handlers: + cls._class_handlers.append(handler) + + def register_message_handler(self, handler: Callable): + """注册实例级别的消息处理器""" + if handler not in self.message_handlers: + self.message_handlers.append(handler) + + async def verify_token(self, token: str) -> bool: + if not self.enable_token: + return True + return token in self.valid_tokens + + def add_valid_token(self, token: str): + self.valid_tokens.add(token) + + def remove_valid_token(self, token: str): + self.valid_tokens.discard(token) + + def run_sync(self): + """同步方式运行服务器""" + if not self.own_app: + raise RuntimeError("当使用外部FastAPI实例时,请使用该实例的运行方法") + uvicorn.run(self.app, host=self.host, port=self.port) + + async def run(self): + """异步方式运行服务器""" + self._running = True + try: + if self.own_app: + # 如果使用自己的 FastAPI 实例,运行 uvicorn 服务器 + config = uvicorn.Config(self.app, host=self.host, port=self.port, loop="asyncio") + self.server = uvicorn.Server(config) + await self.server.serve() + else: + # 如果使用外部 FastAPI 实例,保持运行状态以处理消息 + while self._running: + await asyncio.sleep(1) + except KeyboardInterrupt: + await self.stop() + raise + except Exception as e: + await self.stop() + raise RuntimeError(f"服务器运行错误: {str(e)}") from e + finally: + await self.stop() + + async def start_server(self): + """启动服务器的异步方法""" + if not self._running: + self._running = True + await self.run() + + async def stop(self): + """停止服务器""" + # 清理platform映射 + self.platform_websockets.clear() + + # 取消所有后台任务 + for task in self.background_tasks: + task.cancel() + # 等待所有任务完成 + await asyncio.gather(*self.background_tasks, return_exceptions=True) + self.background_tasks.clear() + + # 关闭所有WebSocket连接 + for websocket in self.active_websockets: + await websocket.close() + self.active_websockets.clear() + + if hasattr(self, "server") and self.own_app: + self._running = False + # 正确关闭 uvicorn 服务器 + self.server.should_exit = True + await self.server.shutdown() + # 等待服务器完全停止 + if hasattr(self.server, "started") and self.server.started: + await self.server.main_loop() + # 清理处理程序 + self.message_handlers.clear() + def _remove_websocket(self, websocket: WebSocket, platform: str): """从所有集合中移除websocket""" if websocket in self.active_websockets: @@ -161,54 +232,6 @@ class MessageServer(BaseMessageHandler): async def send_message(self, message: MessageBase): await self.broadcast_to_platform(message.message_info.platform, message.to_dict()) - def run_sync(self): - """同步方式运行服务器""" - uvicorn.run(self.app, host=self.host, port=self.port) - - async def run(self): - """异步方式运行服务器""" - config = uvicorn.Config(self.app, host=self.host, port=self.port, loop="asyncio") - self.server = uvicorn.Server(config) - try: - await self.server.serve() - except KeyboardInterrupt as e: - await self.stop() - raise KeyboardInterrupt from e - - async def start_server(self): - """启动服务器的异步方法""" - if not self._running: - self._running = True - await self.run() - - async def stop(self): - """停止服务器""" - # 清理platform映射 - self.platform_websockets.clear() - - # 取消所有后台任务 - for task in self.background_tasks: - task.cancel() - # 等待所有任务完成 - await asyncio.gather(*self.background_tasks, return_exceptions=True) - self.background_tasks.clear() - - # 关闭所有WebSocket连接 - for websocket in self.active_websockets: - await websocket.close() - self.active_websockets.clear() - - if hasattr(self, "server"): - self._running = False - # 正确关闭 uvicorn 服务器 - self.server.should_exit = True - await self.server.shutdown() - # 等待服务器完全停止 - if hasattr(self.server, "started") and self.server.started: - await self.server.main_loop() - # 清理处理程序 - self.message_handlers.clear() - async def send_message_REST(self, url: str, data: Dict[str, Any]) -> Dict[str, Any]: """发送消息到指定端点""" async with aiohttp.ClientSession() as session: From c8172fe853a87bb19f1e244c94cfac5fccadb074 Mon Sep 17 00:00:00 2001 From: tcmofashi Date: Wed, 9 Apr 2025 17:25:25 +0800 Subject: [PATCH 03/24] =?UTF-8?q?refactor:=20=E6=9B=B4=E6=8D=A2fastapi?= =?UTF-8?q?=E5=88=9D=E5=A7=8B=E5=8C=96=E4=BD=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/common/server.py | 73 ++++++++++++++++++++++++++ src/plugins/message/api.py | 104 +------------------------------------ 2 files changed, 75 insertions(+), 102 deletions(-) create mode 100644 src/common/server.py diff --git a/src/common/server.py b/src/common/server.py new file mode 100644 index 000000000..fd1f3ff18 --- /dev/null +++ b/src/common/server.py @@ -0,0 +1,73 @@ +from fastapi import FastAPI, APIRouter +from typing import Optional, Union +from uvicorn import Config, Server as UvicornServer +import os + + +class Server: + def __init__(self, host: Optional[str] = None, port: Optional[int] = None, app_name: str = "MaiMCore"): + self.app = FastAPI(title=app_name) + self._host: str = "127.0.0.1" + self._port: int = 8080 + self._server: Optional[UvicornServer] = None + self.set_address(host, port) + + def register_router(self, router: APIRouter, prefix: str = ""): + """注册路由 + + APIRouter 用于对相关的路由端点进行分组和模块化管理: + 1. 可以将相关的端点组织在一起,便于管理 + 2. 支持添加统一的路由前缀 + 3. 可以为一组路由添加共同的依赖项、标签等 + + 示例: + router = APIRouter() + + @router.get("/users") + def get_users(): + return {"users": [...]} + + @router.post("/users") + def create_user(): + return {"msg": "user created"} + + # 注册路由,添加前缀 "/api/v1" + server.register_router(router, prefix="/api/v1") + """ + self.app.include_router(router, prefix=prefix) + + def set_address(self, host: Optional[str] = None, port: Optional[int] = None): + """设置服务器地址和端口""" + if host: + self._host = host + if port: + self._port = port + + async def run(self): + """启动服务器""" + config = Config(app=self.app, host=self._host, port=self._port) + self._server = UvicornServer(config=config) + try: + await self._server.serve() + except KeyboardInterrupt: + await self.shutdown() + raise + except Exception as e: + await self.shutdown() + raise RuntimeError(f"服务器运行错误: {str(e)}") from e + finally: + await self.shutdown() + + async def shutdown(self): + """安全关闭服务器""" + if self._server: + self._server.should_exit = True + await self._server.shutdown() + self._server = None + + def get_app(self) -> FastAPI: + """获取 FastAPI 实例""" + return self.app + + +global_server = Server(host=os.environ["HOST"], port=int(os.environ["PORT"])) diff --git a/src/plugins/message/api.py b/src/plugins/message/api.py index 19457bbec..0c3e3a5a1 100644 --- a/src/plugins/message/api.py +++ b/src/plugins/message/api.py @@ -2,6 +2,7 @@ from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect from typing import Dict, Any, Callable, List, Set, Optional from src.common.logger import get_module_logger from src.plugins.message.message_base import MessageBase +from src.common.server import global_server import aiohttp import asyncio import uvicorn @@ -242,105 +243,4 @@ class MessageServer(BaseMessageHandler): raise e -class BaseMessageAPI: - def __init__(self, host: str = "0.0.0.0", port: int = 18000): - self.app = FastAPI() - self.host = host - self.port = port - self.message_handlers: List[Callable] = [] - self.cache = [] - self._setup_routes() - self._running = False - - def _setup_routes(self): - """设置基础路由""" - - @self.app.post("/api/message") - async def handle_message(message: Dict[str, Any]): - try: - # 创建后台任务处理消息 - asyncio.create_task(self._background_message_handler(message)) - return {"status": "success"} - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) from e - - async def _background_message_handler(self, message: Dict[str, Any]): - """后台处理单个消息""" - try: - await self.process_single_message(message) - except Exception as e: - logger.error(f"Background message processing failed: {str(e)}") - logger.error(traceback.format_exc()) - - def register_message_handler(self, handler: Callable): - """注册消息处理函数""" - self.message_handlers.append(handler) - - async def send_message(self, url: str, data: Dict[str, Any]) -> Dict[str, Any]: - """发送消息到指定端点""" - async with aiohttp.ClientSession() as session: - try: - async with session.post(url, json=data, headers={"Content-Type": "application/json"}) as response: - return await response.json() - except Exception: - # logger.error(f"发送消息失败: {str(e)}") - pass - - async def process_single_message(self, message: Dict[str, Any]): - """处理单条消息""" - tasks = [] - for handler in self.message_handlers: - try: - tasks.append(handler(message)) - except Exception as e: - logger.error(str(e)) - logger.error(traceback.format_exc()) - if tasks: - await asyncio.gather(*tasks, return_exceptions=True) - - def run_sync(self): - """同步方式运行服务器""" - uvicorn.run(self.app, host=self.host, port=self.port) - - async def run(self): - """异步方式运行服务器""" - config = uvicorn.Config(self.app, host=self.host, port=self.port, loop="asyncio") - self.server = uvicorn.Server(config) - try: - await self.server.serve() - except KeyboardInterrupt as e: - await self.stop() - raise KeyboardInterrupt from e - - async def start_server(self): - """启动服务器的异步方法""" - if not self._running: - self._running = True - await self.run() - - async def stop(self): - """停止服务器""" - if hasattr(self, "server"): - self._running = False - # 正确关闭 uvicorn 服务器 - self.server.should_exit = True - await self.server.shutdown() - # 等待服务器完全停止 - if hasattr(self.server, "started") and self.server.started: - await self.server.main_loop() - # 清理处理程序 - self.message_handlers.clear() - - def start(self): - """启动服务器的便捷方法""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete(self.start_server()) - except KeyboardInterrupt: - pass - finally: - loop.close() - - -global_api = MessageServer(host=os.environ["HOST"], port=int(os.environ["PORT"])) +global_api = MessageServer(host=os.environ["HOST"], port=int(os.environ["PORT"]), app=global_server.get_app()) From 10c72ea43510acb6f72159112c535930c620531c Mon Sep 17 00:00:00 2001 From: tcmofashi Date: Wed, 9 Apr 2025 17:25:25 +0800 Subject: [PATCH 04/24] =?UTF-8?q?refactor:=20=E6=9B=B4=E6=8D=A2fastapi?= =?UTF-8?q?=E5=88=9D=E5=A7=8B=E5=8C=96=E4=BD=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/common/server.py | 73 ++++++++++++++++++++++ src/main.py | 4 +- src/plugins/message/__init__.py | 3 +- src/plugins/message/api.py | 104 +------------------------------- 4 files changed, 79 insertions(+), 105 deletions(-) create mode 100644 src/common/server.py diff --git a/src/common/server.py b/src/common/server.py new file mode 100644 index 000000000..fd1f3ff18 --- /dev/null +++ b/src/common/server.py @@ -0,0 +1,73 @@ +from fastapi import FastAPI, APIRouter +from typing import Optional, Union +from uvicorn import Config, Server as UvicornServer +import os + + +class Server: + def __init__(self, host: Optional[str] = None, port: Optional[int] = None, app_name: str = "MaiMCore"): + self.app = FastAPI(title=app_name) + self._host: str = "127.0.0.1" + self._port: int = 8080 + self._server: Optional[UvicornServer] = None + self.set_address(host, port) + + def register_router(self, router: APIRouter, prefix: str = ""): + """注册路由 + + APIRouter 用于对相关的路由端点进行分组和模块化管理: + 1. 可以将相关的端点组织在一起,便于管理 + 2. 支持添加统一的路由前缀 + 3. 可以为一组路由添加共同的依赖项、标签等 + + 示例: + router = APIRouter() + + @router.get("/users") + def get_users(): + return {"users": [...]} + + @router.post("/users") + def create_user(): + return {"msg": "user created"} + + # 注册路由,添加前缀 "/api/v1" + server.register_router(router, prefix="/api/v1") + """ + self.app.include_router(router, prefix=prefix) + + def set_address(self, host: Optional[str] = None, port: Optional[int] = None): + """设置服务器地址和端口""" + if host: + self._host = host + if port: + self._port = port + + async def run(self): + """启动服务器""" + config = Config(app=self.app, host=self._host, port=self._port) + self._server = UvicornServer(config=config) + try: + await self._server.serve() + except KeyboardInterrupt: + await self.shutdown() + raise + except Exception as e: + await self.shutdown() + raise RuntimeError(f"服务器运行错误: {str(e)}") from e + finally: + await self.shutdown() + + async def shutdown(self): + """安全关闭服务器""" + if self._server: + self._server.should_exit = True + await self._server.shutdown() + self._server = None + + def get_app(self) -> FastAPI: + """获取 FastAPI 实例""" + return self.app + + +global_server = Server(host=os.environ["HOST"], port=int(os.environ["PORT"])) diff --git a/src/main.py b/src/main.py index aa6f908bf..d94cfce64 100644 --- a/src/main.py +++ b/src/main.py @@ -16,7 +16,7 @@ from .plugins.chat.bot import chat_bot from .common.logger import get_module_logger from .plugins.remote import heartbeat_thread # noqa: F401 from .individuality.individuality import Individuality - +from .common.server import global_server logger = get_module_logger("main") @@ -33,6 +33,7 @@ class MainSystem: from .plugins.message import global_api self.app = global_api + self.server = global_server async def initialize(self): """初始化系统组件""" @@ -126,6 +127,7 @@ class MainSystem: emoji_manager.start_periodic_check_register(), # emoji_manager.start_periodic_register(), self.app.run(), + self.server.run(), ] await asyncio.gather(*tasks) diff --git a/src/plugins/message/__init__.py b/src/plugins/message/__init__.py index bee5c5e58..286ef2310 100644 --- a/src/plugins/message/__init__.py +++ b/src/plugins/message/__init__.py @@ -2,7 +2,7 @@ __version__ = "0.1.0" -from .api import BaseMessageAPI, global_api +from .api import global_api from .message_base import ( Seg, GroupInfo, @@ -14,7 +14,6 @@ from .message_base import ( ) __all__ = [ - "BaseMessageAPI", "Seg", "global_api", "GroupInfo", diff --git a/src/plugins/message/api.py b/src/plugins/message/api.py index 19457bbec..0c3e3a5a1 100644 --- a/src/plugins/message/api.py +++ b/src/plugins/message/api.py @@ -2,6 +2,7 @@ from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect from typing import Dict, Any, Callable, List, Set, Optional from src.common.logger import get_module_logger from src.plugins.message.message_base import MessageBase +from src.common.server import global_server import aiohttp import asyncio import uvicorn @@ -242,105 +243,4 @@ class MessageServer(BaseMessageHandler): raise e -class BaseMessageAPI: - def __init__(self, host: str = "0.0.0.0", port: int = 18000): - self.app = FastAPI() - self.host = host - self.port = port - self.message_handlers: List[Callable] = [] - self.cache = [] - self._setup_routes() - self._running = False - - def _setup_routes(self): - """设置基础路由""" - - @self.app.post("/api/message") - async def handle_message(message: Dict[str, Any]): - try: - # 创建后台任务处理消息 - asyncio.create_task(self._background_message_handler(message)) - return {"status": "success"} - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) from e - - async def _background_message_handler(self, message: Dict[str, Any]): - """后台处理单个消息""" - try: - await self.process_single_message(message) - except Exception as e: - logger.error(f"Background message processing failed: {str(e)}") - logger.error(traceback.format_exc()) - - def register_message_handler(self, handler: Callable): - """注册消息处理函数""" - self.message_handlers.append(handler) - - async def send_message(self, url: str, data: Dict[str, Any]) -> Dict[str, Any]: - """发送消息到指定端点""" - async with aiohttp.ClientSession() as session: - try: - async with session.post(url, json=data, headers={"Content-Type": "application/json"}) as response: - return await response.json() - except Exception: - # logger.error(f"发送消息失败: {str(e)}") - pass - - async def process_single_message(self, message: Dict[str, Any]): - """处理单条消息""" - tasks = [] - for handler in self.message_handlers: - try: - tasks.append(handler(message)) - except Exception as e: - logger.error(str(e)) - logger.error(traceback.format_exc()) - if tasks: - await asyncio.gather(*tasks, return_exceptions=True) - - def run_sync(self): - """同步方式运行服务器""" - uvicorn.run(self.app, host=self.host, port=self.port) - - async def run(self): - """异步方式运行服务器""" - config = uvicorn.Config(self.app, host=self.host, port=self.port, loop="asyncio") - self.server = uvicorn.Server(config) - try: - await self.server.serve() - except KeyboardInterrupt as e: - await self.stop() - raise KeyboardInterrupt from e - - async def start_server(self): - """启动服务器的异步方法""" - if not self._running: - self._running = True - await self.run() - - async def stop(self): - """停止服务器""" - if hasattr(self, "server"): - self._running = False - # 正确关闭 uvicorn 服务器 - self.server.should_exit = True - await self.server.shutdown() - # 等待服务器完全停止 - if hasattr(self.server, "started") and self.server.started: - await self.server.main_loop() - # 清理处理程序 - self.message_handlers.clear() - - def start(self): - """启动服务器的便捷方法""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete(self.start_server()) - except KeyboardInterrupt: - pass - finally: - loop.close() - - -global_api = MessageServer(host=os.environ["HOST"], port=int(os.environ["PORT"])) +global_api = MessageServer(host=os.environ["HOST"], port=int(os.environ["PORT"]), app=global_server.get_app()) From 08e5dd2f7bec9bbd9da2c66c65431fb7a5893279 Mon Sep 17 00:00:00 2001 From: tcmofashi Date: Wed, 9 Apr 2025 17:50:54 +0800 Subject: [PATCH 05/24] ruff: --- bot.py | 2 +- src/common/crash_logger.py | 47 ++++---- src/plugins/PFC/action_planner.py | 60 ++++------ src/plugins/PFC/chat_observer.py | 146 ++++++++++-------------- src/plugins/PFC/chat_states.py | 143 ++++++++++++----------- src/plugins/PFC/conversation.py | 129 ++++++++++----------- src/plugins/PFC/conversation_info.py | 4 +- src/plugins/PFC/message_sender.py | 20 ++-- src/plugins/PFC/message_storage.py | 79 ++++++------- src/plugins/PFC/notification_handler.py | 30 ++--- src/plugins/PFC/observation_info.py | 113 +++++++++--------- src/plugins/PFC/pfc.py | 55 ++++----- src/plugins/PFC/pfc_manager.py | 43 ++++--- src/plugins/PFC/pfc_types.py | 3 +- src/plugins/PFC/reply_generator.py | 49 +++----- src/plugins/PFC/waiter.py | 21 ++-- src/plugins/chat/bot.py | 7 +- src/plugins/config/config.py | 2 +- 18 files changed, 439 insertions(+), 514 deletions(-) diff --git a/bot.py b/bot.py index ca214967e..5b12b0389 100644 --- a/bot.py +++ b/bot.py @@ -196,7 +196,7 @@ def raw_main(): # 安装崩溃日志处理器 install_crash_handler() - + check_eula() print("检查EULA和隐私条款完成") easter_egg() diff --git a/src/common/crash_logger.py b/src/common/crash_logger.py index 658e1bb02..d1e4fb51f 100644 --- a/src/common/crash_logger.py +++ b/src/common/crash_logger.py @@ -4,69 +4,66 @@ import logging from pathlib import Path from logging.handlers import RotatingFileHandler + def setup_crash_logger(): """设置崩溃日志记录器""" # 创建logs/crash目录(如果不存在) crash_log_dir = Path("logs/crash") crash_log_dir.mkdir(parents=True, exist_ok=True) - + # 创建日志记录器 - crash_logger = logging.getLogger('crash_logger') + crash_logger = logging.getLogger("crash_logger") crash_logger.setLevel(logging.ERROR) - + # 设置日志格式 formatter = logging.Formatter( - '%(asctime)s - %(name)s - %(levelname)s\n' - '异常类型: %(exc_info)s\n' - '详细信息:\n%(message)s\n' - '-------------------\n' + "%(asctime)s - %(name)s - %(levelname)s\n异常类型: %(exc_info)s\n详细信息:\n%(message)s\n-------------------\n" ) - + # 创建按大小轮转的文件处理器(最大10MB,保留5个备份) log_file = crash_log_dir / "crash.log" file_handler = RotatingFileHandler( log_file, - maxBytes=10*1024*1024, # 10MB + maxBytes=10 * 1024 * 1024, # 10MB backupCount=5, - encoding='utf-8' + encoding="utf-8", ) file_handler.setFormatter(formatter) crash_logger.addHandler(file_handler) - + return crash_logger + def log_crash(exc_type, exc_value, exc_traceback): """记录崩溃信息到日志文件""" if exc_type is None: return - + # 获取崩溃日志记录器 - crash_logger = logging.getLogger('crash_logger') - + crash_logger = logging.getLogger("crash_logger") + # 获取完整的异常堆栈信息 - stack_trace = ''.join(traceback.format_exception(exc_type, exc_value, exc_traceback)) - + stack_trace = "".join(traceback.format_exception(exc_type, exc_value, exc_traceback)) + # 记录崩溃信息 - crash_logger.error( - stack_trace, - exc_info=(exc_type, exc_value, exc_traceback) - ) + crash_logger.error(stack_trace, exc_info=(exc_type, exc_value, exc_traceback)) + def install_crash_handler(): """安装全局异常处理器""" # 设置崩溃日志记录器 setup_crash_logger() - + # 保存原始的异常处理器 original_hook = sys.excepthook - + def exception_handler(exc_type, exc_value, exc_traceback): """全局异常处理器""" # 记录崩溃信息 log_crash(exc_type, exc_value, exc_traceback) - + # 调用原始的异常处理器 original_hook(exc_type, exc_value, exc_traceback) - + # 设置全局异常处理器 - sys.excepthook = exception_handler \ No newline at end of file + sys.excepthook = exception_handler diff --git a/src/plugins/PFC/action_planner.py b/src/plugins/PFC/action_planner.py index ad69fea1d..43b0749a1 100644 --- a/src/plugins/PFC/action_planner.py +++ b/src/plugins/PFC/action_planner.py @@ -10,6 +10,7 @@ from .conversation_info import ConversationInfo logger = get_module_logger("action_planner") + class ActionPlannerInfo: def __init__(self): self.done_action = [] @@ -20,68 +21,57 @@ class ActionPlannerInfo: class ActionPlanner: """行动规划器""" - + def __init__(self, stream_id: str): self.llm = LLM_request( - model=global_config.llm_normal, - temperature=0.7, - max_tokens=1000, - request_type="action_planning" + model=global_config.llm_normal, temperature=0.7, max_tokens=1000, request_type="action_planning" ) - self.personality_info = Individuality.get_instance().get_prompt(type = "personality", x_person = 2, level = 2) + self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2) self.name = global_config.BOT_NICKNAME self.chat_observer = ChatObserver.get_instance(stream_id) - - async def plan( - self, - observation_info: ObservationInfo, - conversation_info: ConversationInfo - ) -> Tuple[str, str]: + + async def plan(self, observation_info: ObservationInfo, conversation_info: ConversationInfo) -> Tuple[str, str]: """规划下一步行动 - + Args: observation_info: 决策信息 conversation_info: 对话信息 - + Returns: Tuple[str, str]: (行动类型, 行动原因) """ # 构建提示词 logger.debug(f"开始规划行动:当前目标: {conversation_info.goal_list}") - - #构建对话目标 + + # 构建对话目标 if conversation_info.goal_list: goal, reasoning = conversation_info.goal_list[-1] else: goal = "目前没有明确对话目标" reasoning = "目前没有明确对话目标,最好思考一个对话目标" - - + # 获取聊天历史记录 chat_history_list = observation_info.chat_history chat_history_text = "" for msg in chat_history_list: chat_history_text += f"{msg}\n" - + if observation_info.new_messages_count > 0: new_messages_list = observation_info.unprocessed_messages - + chat_history_text += f"有{observation_info.new_messages_count}条新消息:\n" for msg in new_messages_list: chat_history_text += f"{msg}\n" - + observation_info.clear_unprocessed_messages() - - + personality_text = f"你的名字是{self.name},{self.personality_info}" - + # 构建action历史文本 action_history_list = conversation_info.done_action action_history_text = "你之前做的事情是:" for action in action_history_list: action_history_text += f"{action}\n" - - prompt = f"""{personality_text}。现在你在参与一场QQ聊天,请分析以下内容,根据信息决定下一步行动: @@ -111,29 +101,27 @@ rethink_goal: 重新思考对话目标,当发现对话目标不合适时选择 try: content, _ = await self.llm.generate_response_async(prompt) logger.debug(f"LLM原始返回内容: {content}") - + # 使用简化函数提取JSON内容 success, result = get_items_from_json( - content, - "action", "reason", - default_values={"action": "direct_reply", "reason": "没有明确原因"} + content, "action", "reason", default_values={"action": "direct_reply", "reason": "没有明确原因"} ) - + if not success: return "direct_reply", "JSON解析失败,选择直接回复" - + action = result["action"] reason = result["reason"] - + # 验证action类型 if action not in ["direct_reply", "fetch_knowledge", "wait", "listening", "rethink_goal"]: logger.warning(f"未知的行动类型: {action},默认使用listening") action = "listening" - + logger.info(f"规划的行动: {action}") logger.info(f"行动原因: {reason}") return action, reason - + except Exception as e: logger.error(f"规划行动时出错: {str(e)}") - return "direct_reply", "发生错误,选择直接回复" \ No newline at end of file + return "direct_reply", "发生错误,选择直接回复" diff --git a/src/plugins/PFC/chat_observer.py b/src/plugins/PFC/chat_observer.py index 93618cf2d..c96bc47b1 100644 --- a/src/plugins/PFC/chat_observer.py +++ b/src/plugins/PFC/chat_observer.py @@ -17,20 +17,20 @@ class ChatObserver: _instances: Dict[str, "ChatObserver"] = {} @classmethod - def get_instance(cls, stream_id: str, message_storage: Optional[MessageStorage] = None) -> 'ChatObserver': + def get_instance(cls, stream_id: str, message_storage: Optional[MessageStorage] = None) -> "ChatObserver": """获取或创建观察器实例 Args: stream_id: 聊天流ID message_storage: 消息存储实现,如果为None则使用MongoDB实现 - + Returns: ChatObserver: 观察器实例 """ if stream_id not in cls._instances: cls._instances[stream_id] = cls(stream_id, message_storage) return cls._instances[stream_id] - + def __init__(self, stream_id: str, message_storage: Optional[MessageStorage] = None): """初始化观察器 @@ -43,15 +43,15 @@ class ChatObserver: self.stream_id = stream_id self.message_storage = message_storage or MongoDBMessageStorage() - + self.last_user_speak_time: Optional[float] = None # 对方上次发言时间 - self.last_bot_speak_time: Optional[float] = None # 机器人上次发言时间 - self.last_check_time: float = time.time() # 上次查看聊天记录时间 - self.last_message_read: Optional[str] = None # 最后读取的消息ID - self.last_message_time: Optional[float] = None # 最后一条消息的时间戳 - - self.waiting_start_time: float = time.time() # 等待开始时间,初始化为当前时间 - + self.last_bot_speak_time: Optional[float] = None # 机器人上次发言时间 + self.last_check_time: float = time.time() # 上次查看聊天记录时间 + self.last_message_read: Optional[str] = None # 最后读取的消息ID + self.last_message_time: Optional[float] = None # 最后一条消息的时间戳 + + self.waiting_start_time: float = time.time() # 等待开始时间,初始化为当前时间 + # 消息历史记录 self.message_history: List[Dict[str, Any]] = [] # 所有消息历史 self.last_message_id: Optional[str] = None # 最后一条消息的ID @@ -62,20 +62,20 @@ class ChatObserver: self._task: Optional[asyncio.Task] = None self._update_event = asyncio.Event() # 触发更新的事件 self._update_complete = asyncio.Event() # 更新完成的事件 - + # 通知管理器 self.notification_manager = NotificationManager() - + # 冷场检查配置 self.cold_chat_threshold: float = 60.0 # 60秒无消息判定为冷场 self.last_cold_chat_check: float = time.time() self.is_cold_chat_state: bool = False - + self.update_event = asyncio.Event() self.update_interval = 5 # 更新间隔(秒) self.message_cache = [] self.update_running = False - + async def check(self) -> bool: """检查距离上一次观察之后是否有了新消息 @@ -83,21 +83,18 @@ class ChatObserver: bool: 是否有新消息 """ logger.debug(f"检查距离上一次观察之后是否有了新消息: {self.last_check_time}") - - new_message_exists = await self.message_storage.has_new_messages( - self.stream_id, - self.last_check_time - ) - + + new_message_exists = await self.message_storage.has_new_messages(self.stream_id, self.last_check_time) + if new_message_exists: logger.debug("发现新消息") self.last_check_time = time.time() return new_message_exists - + async def _add_message_to_history(self, message: Dict[str, Any]): """添加消息到历史记录并发送通知 - + Args: message: 消息数据 """ @@ -112,76 +109,65 @@ class ChatObserver: self.last_bot_speak_time = message["time"] else: self.last_user_speak_time = message["time"] - + # 发送新消息通知 - notification = create_new_message_notification( - sender="chat_observer", - target="pfc", - message=message - ) + notification = create_new_message_notification(sender="chat_observer", target="pfc", message=message) await self.notification_manager.send_notification(notification) - + # 检查并更新冷场状态 await self._check_cold_chat() - + async def _check_cold_chat(self): """检查是否处于冷场状态并发送通知""" current_time = time.time() - + # 每10秒检查一次冷场状态 if current_time - self.last_cold_chat_check < 10: return - + self.last_cold_chat_check = current_time - + # 判断是否冷场 is_cold = False if self.last_message_time is None: is_cold = True else: is_cold = (current_time - self.last_message_time) > self.cold_chat_threshold - + # 如果冷场状态发生变化,发送通知 if is_cold != self.is_cold_chat_state: self.is_cold_chat_state = is_cold - notification = create_cold_chat_notification( - sender="chat_observer", - target="pfc", - is_cold=is_cold - ) + notification = create_cold_chat_notification(sender="chat_observer", target="pfc", is_cold=is_cold) await self.notification_manager.send_notification(notification) - + async def get_new_message(self) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: """获取上一次观察的时间点后的新消息,插入到历史记录中,并返回新消息和历史记录两个对象""" - messages = await self.message_storage.get_messages_after( - self.stream_id, - self.last_message_read - ) + messages = await self.message_storage.get_messages_after(self.stream_id, self.last_message_read) for message in messages: await self._add_message_to_history(message) return messages, self.message_history - + def new_message_after(self, time_point: float) -> bool: """判断是否在指定时间点后有新消息 - + Args: time_point: 时间戳 - + Returns: bool: 是否有新消息 """ if time_point is None: logger.warning("time_point 为 None,返回 False") return False - + if self.last_message_time is None: logger.debug("没有最后消息时间,返回 False") return False - + has_new = self.last_message_time > time_point logger.debug(f"判断是否在指定时间点后有新消息: {self.last_message_time} > {time_point} = {has_new}") return has_new - + def get_message_history( self, start_time: Optional[float] = None, @@ -224,11 +210,8 @@ class ChatObserver: Returns: List[Dict[str, Any]]: 新消息列表 """ - new_messages = await self.message_storage.get_messages_after( - self.stream_id, - self.last_message_read - ) - + new_messages = await self.message_storage.get_messages_after(self.stream_id, self.last_message_read) + if new_messages: self.last_message_read = new_messages[-1]["message_id"] @@ -243,17 +226,15 @@ class ChatObserver: Returns: List[Dict[str, Any]]: 最多5条消息 """ - new_messages = await self.message_storage.get_messages_before( - self.stream_id, - time_point - ) - + new_messages = await self.message_storage.get_messages_before(self.stream_id, time_point) + if new_messages: self.last_message_read = new_messages[-1]["message_id"] return new_messages - - '''主要观察循环''' + + """主要观察循环""" + async def _update_loop(self): """更新循环""" try: @@ -282,7 +263,7 @@ class ChatObserver: # 处理新消息 for message in new_messages: await self._add_message_to_history(message) - + # 设置完成事件 self._update_complete.set() @@ -379,7 +360,7 @@ class ChatObserver: if not self.update_running: self.update_running = True asyncio.create_task(self._periodic_update()) - + async def _periodic_update(self): """定期更新消息历史""" try: @@ -388,53 +369,52 @@ class ChatObserver: await asyncio.sleep(self.update_interval) except Exception as e: logger.error(f"定期更新消息历史时出错: {str(e)}") - + async def _update_message_history(self) -> bool: """更新消息历史 - + Returns: bool: 是否有新消息 """ try: - messages = await self.message_storage.get_messages_for_stream( - self.stream_id, - limit=50 - ) - + messages = await self.message_storage.get_messages_for_stream(self.stream_id, limit=50) + if not messages: return False - + # 检查是否有新消息 has_new_messages = False - if messages and (not self.message_cache or messages[0]["message_id"] != self.message_cache[0]["message_id"]): + if messages and ( + not self.message_cache or messages[0]["message_id"] != self.message_cache[0]["message_id"] + ): has_new_messages = True - + self.message_cache = messages - + if has_new_messages: self.update_event.set() self.update_event.clear() return True return False - + except Exception as e: logger.error(f"更新消息历史时出错: {str(e)}") return False - + def get_cached_messages(self, limit: int = 50) -> List[Dict[str, Any]]: """获取缓存的消息历史 - + Args: limit: 获取的最大消息数量,默认50 - + Returns: List[Dict[str, Any]]: 缓存的消息历史列表 - """ + """ return self.message_cache[:limit] - + def get_last_message(self) -> Optional[Dict[str, Any]]: """获取最后一条消息 - + Returns: Optional[Dict[str, Any]]: 最后一条消息,如果没有则返回None """ diff --git a/src/plugins/PFC/chat_states.py b/src/plugins/PFC/chat_states.py index bb7cfc4a6..b28ca69a6 100644 --- a/src/plugins/PFC/chat_states.py +++ b/src/plugins/PFC/chat_states.py @@ -4,32 +4,38 @@ from dataclasses import dataclass from datetime import datetime from abc import ABC, abstractmethod + class ChatState(Enum): """聊天状态枚举""" - NORMAL = auto() # 正常状态 - NEW_MESSAGE = auto() # 有新消息 - COLD_CHAT = auto() # 冷场状态 - ACTIVE_CHAT = auto() # 活跃状态 - BOT_SPEAKING = auto() # 机器人正在说话 - USER_SPEAKING = auto() # 用户正在说话 - SILENT = auto() # 沉默状态 - ERROR = auto() # 错误状态 + + NORMAL = auto() # 正常状态 + NEW_MESSAGE = auto() # 有新消息 + COLD_CHAT = auto() # 冷场状态 + ACTIVE_CHAT = auto() # 活跃状态 + BOT_SPEAKING = auto() # 机器人正在说话 + USER_SPEAKING = auto() # 用户正在说话 + SILENT = auto() # 沉默状态 + ERROR = auto() # 错误状态 + class NotificationType(Enum): """通知类型枚举""" - NEW_MESSAGE = auto() # 新消息通知 - COLD_CHAT = auto() # 冷场通知 - ACTIVE_CHAT = auto() # 活跃通知 - BOT_SPEAKING = auto() # 机器人说话通知 - USER_SPEAKING = auto() # 用户说话通知 - MESSAGE_DELETED = auto() # 消息删除通知 - USER_JOINED = auto() # 用户加入通知 - USER_LEFT = auto() # 用户离开通知 - ERROR = auto() # 错误通知 + + NEW_MESSAGE = auto() # 新消息通知 + COLD_CHAT = auto() # 冷场通知 + ACTIVE_CHAT = auto() # 活跃通知 + BOT_SPEAKING = auto() # 机器人说话通知 + USER_SPEAKING = auto() # 用户说话通知 + MESSAGE_DELETED = auto() # 消息删除通知 + USER_JOINED = auto() # 用户加入通知 + USER_LEFT = auto() # 用户离开通知 + ERROR = auto() # 错误通知 + @dataclass class ChatStateInfo: """聊天状态信息""" + state: ChatState last_message_time: Optional[float] = None last_message_content: Optional[str] = None @@ -38,53 +44,55 @@ class ChatStateInfo: cold_duration: float = 0.0 # 冷场持续时间(秒) active_duration: float = 0.0 # 活跃持续时间(秒) + @dataclass class Notification: """通知基类""" + type: NotificationType timestamp: float - sender: str # 发送者标识 - target: str # 接收者标识 + sender: str # 发送者标识 + target: str # 接收者标识 data: Dict[str, Any] - + def to_dict(self) -> Dict[str, Any]: """转换为字典格式""" - return { - "type": self.type.name, - "timestamp": self.timestamp, - "data": self.data - } + return {"type": self.type.name, "timestamp": self.timestamp, "data": self.data} + @dataclass class StateNotification(Notification): """持续状态通知""" + is_active: bool = True - + def to_dict(self) -> Dict[str, Any]: base_dict = super().to_dict() base_dict["is_active"] = self.is_active return base_dict + class NotificationHandler(ABC): """通知处理器接口""" - + @abstractmethod async def handle_notification(self, notification: Notification): """处理通知""" pass + class NotificationManager: """通知管理器""" - + def __init__(self): # 按接收者和通知类型存储处理器 self._handlers: Dict[str, Dict[NotificationType, List[NotificationHandler]]] = {} self._active_states: Set[NotificationType] = set() self._notification_history: List[Notification] = [] - + def register_handler(self, target: str, notification_type: NotificationType, handler: NotificationHandler): """注册通知处理器 - + Args: target: 接收者标识(例如:"pfc") notification_type: 要处理的通知类型 @@ -95,10 +103,10 @@ class NotificationManager: if notification_type not in self._handlers[target]: self._handlers[target][notification_type] = [] self._handlers[target][notification_type].append(handler) - + def unregister_handler(self, target: str, notification_type: NotificationType, handler: NotificationHandler): """注销通知处理器 - + Args: target: 接收者标识 notification_type: 通知类型 @@ -114,56 +122,56 @@ class NotificationManager: # 如果该目标没有任何处理器,删除该目标 if not self._handlers[target]: del self._handlers[target] - + async def send_notification(self, notification: Notification): """发送通知""" self._notification_history.append(notification) - + # 如果是状态通知,更新活跃状态 if isinstance(notification, StateNotification): if notification.is_active: self._active_states.add(notification.type) else: self._active_states.discard(notification.type) - + # 调用目标接收者的处理器 target = notification.target if target in self._handlers: handlers = self._handlers[target].get(notification.type, []) for handler in handlers: await handler.handle_notification(notification) - + def get_active_states(self) -> Set[NotificationType]: """获取当前活跃的状态""" return self._active_states.copy() - + def is_state_active(self, state_type: NotificationType) -> bool: """检查特定状态是否活跃""" return state_type in self._active_states - - def get_notification_history(self, - sender: Optional[str] = None, - target: Optional[str] = None, - limit: Optional[int] = None) -> List[Notification]: + + def get_notification_history( + self, sender: Optional[str] = None, target: Optional[str] = None, limit: Optional[int] = None + ) -> List[Notification]: """获取通知历史 - + Args: sender: 过滤特定发送者的通知 target: 过滤特定接收者的通知 limit: 限制返回数量 """ history = self._notification_history - + if sender: history = [n for n in history if n.sender == sender] if target: history = [n for n in history if n.target == target] - + if limit is not None: history = history[-limit:] - + return history + # 一些常用的通知创建函数 def create_new_message_notification(sender: str, target: str, message: Dict[str, Any]) -> Notification: """创建新消息通知""" @@ -176,10 +184,11 @@ def create_new_message_notification(sender: str, target: str, message: Dict[str, "message_id": message.get("message_id"), "content": message.get("content"), "sender": message.get("sender"), - "time": message.get("time") - } + "time": message.get("time"), + }, ) + def create_cold_chat_notification(sender: str, target: str, is_cold: bool) -> StateNotification: """创建冷场状态通知""" return StateNotification( @@ -188,9 +197,10 @@ def create_cold_chat_notification(sender: str, target: str, is_cold: bool) -> St sender=sender, target=target, data={"is_cold": is_cold}, - is_active=is_cold + is_active=is_cold, ) + def create_active_chat_notification(sender: str, target: str, is_active: bool) -> StateNotification: """创建活跃状态通知""" return StateNotification( @@ -199,69 +209,70 @@ def create_active_chat_notification(sender: str, target: str, is_active: bool) - sender=sender, target=target, data={"is_active": is_active}, - is_active=is_active + is_active=is_active, ) + class ChatStateManager: """聊天状态管理器""" - + def __init__(self): self.current_state = ChatState.NORMAL self.state_info = ChatStateInfo(state=ChatState.NORMAL) self.state_history: list[ChatStateInfo] = [] - + def update_state(self, new_state: ChatState, **kwargs): """更新聊天状态 - + Args: new_state: 新的状态 **kwargs: 其他状态信息 """ self.current_state = new_state self.state_info.state = new_state - + # 更新其他状态信息 for key, value in kwargs.items(): if hasattr(self.state_info, key): setattr(self.state_info, key, value) - + # 记录状态历史 self.state_history.append(self.state_info) - + def get_current_state_info(self) -> ChatStateInfo: """获取当前状态信息""" return self.state_info - + def get_state_history(self) -> list[ChatStateInfo]: """获取状态历史""" return self.state_history - + def is_cold_chat(self, threshold: float = 60.0) -> bool: """判断是否处于冷场状态 - + Args: threshold: 冷场阈值(秒) - + Returns: bool: 是否冷场 """ if not self.state_info.last_message_time: return True - + current_time = datetime.now().timestamp() return (current_time - self.state_info.last_message_time) > threshold - + def is_active_chat(self, threshold: float = 5.0) -> bool: """判断是否处于活跃状态 - + Args: threshold: 活跃阈值(秒) - + Returns: bool: 是否活跃 """ if not self.state_info.last_message_time: return False - + current_time = datetime.now().timestamp() - return (current_time - self.state_info.last_message_time) <= threshold \ No newline at end of file + return (current_time - self.state_info.last_message_time) <= threshold diff --git a/src/plugins/PFC/conversation.py b/src/plugins/PFC/conversation.py index dda380491..40a729671 100644 --- a/src/plugins/PFC/conversation.py +++ b/src/plugins/PFC/conversation.py @@ -20,23 +20,23 @@ logger = get_module_logger("pfc_conversation") class Conversation: """对话类,负责管理单个对话的状态和行为""" - + def __init__(self, stream_id: str): """初始化对话实例 - + Args: stream_id: 聊天流ID """ self.stream_id = stream_id self.state = ConversationState.INIT self.should_continue = False - + # 回复相关 self.generated_reply = "" - + async def _initialize(self): """初始化实例,注册所有组件""" - + try: self.action_planner = ActionPlanner(self.stream_id) self.goal_analyzer = GoalAnalyzer(self.stream_id) @@ -44,37 +44,35 @@ class Conversation: self.knowledge_fetcher = KnowledgeFetcher() self.waiter = Waiter(self.stream_id) self.direct_sender = DirectMessageSender() - + # 获取聊天流信息 self.chat_stream = chat_manager.get_stream(self.stream_id) - + self.stop_action_planner = False except Exception as e: logger.error(f"初始化对话实例:注册运行组件失败: {e}") logger.error(traceback.format_exc()) raise - - + try: - #决策所需要的信息,包括自身自信和观察信息两部分 - #注册观察器和观测信息 + # 决策所需要的信息,包括自身自信和观察信息两部分 + # 注册观察器和观测信息 self.chat_observer = ChatObserver.get_instance(self.stream_id) self.chat_observer.start() self.observation_info = ObservationInfo() self.observation_info.bind_to_chat_observer(self.stream_id) - - #对话信息 + + # 对话信息 self.conversation_info = ConversationInfo() except Exception as e: logger.error(f"初始化对话实例:注册信息组件失败: {e}") logger.error(traceback.format_exc()) raise - + # 组件准备完成,启动该论对话 self.should_continue = True asyncio.create_task(self.start()) - - + async def start(self): """开始对话流程""" try: @@ -83,17 +81,13 @@ class Conversation: except Exception as e: logger.error(f"启动对话系统失败: {e}") raise - - + async def _plan_and_action_loop(self): """思考步,PFC核心循环模块""" # 获取最近的消息历史 while self.should_continue: # 使用决策信息来辅助行动规划 - action, reason = await self.action_planner.plan( - self.observation_info, - self.conversation_info - ) + action, reason = await self.action_planner.plan(self.observation_info, self.conversation_info) if self._check_new_messages_after_planning(): continue @@ -107,93 +101,90 @@ class Conversation: # 如果需要,可以在这里添加逻辑来根据新消息重新决定行动 return True return False - - + def _convert_to_message(self, msg_dict: Dict[str, Any]) -> Message: """将消息字典转换为Message对象""" try: chat_info = msg_dict.get("chat_info", {}) chat_stream = ChatStream.from_dict(chat_info) user_info = UserInfo.from_dict(msg_dict.get("user_info", {})) - + return Message( message_id=msg_dict["message_id"], chat_stream=chat_stream, time=msg_dict["time"], user_info=user_info, processed_plain_text=msg_dict.get("processed_plain_text", ""), - detailed_plain_text=msg_dict.get("detailed_plain_text", "") + detailed_plain_text=msg_dict.get("detailed_plain_text", ""), ) except Exception as e: logger.warning(f"转换消息时出错: {e}") raise - async def _handle_action(self, action: str, reason: str, observation_info: ObservationInfo, conversation_info: ConversationInfo): + async def _handle_action( + self, action: str, reason: str, observation_info: ObservationInfo, conversation_info: ConversationInfo + ): """处理规划的行动""" logger.info(f"执行行动: {action}, 原因: {reason}") - + # 记录action历史,先设置为stop,完成后再设置为done - conversation_info.done_action.append({ - "action": action, - "reason": reason, - "status": "start", - "time": datetime.datetime.now().strftime("%H:%M:%S") - }) - - + conversation_info.done_action.append( + { + "action": action, + "reason": reason, + "status": "start", + "time": datetime.datetime.now().strftime("%H:%M:%S"), + } + ) + if action == "direct_reply": self.state = ConversationState.GENERATING - self.generated_reply = await self.reply_generator.generate( - observation_info, - conversation_info - ) - + self.generated_reply = await self.reply_generator.generate(observation_info, conversation_info) + # # 检查回复是否合适 # is_suitable, reason, need_replan = await self.reply_generator.check_reply( # self.generated_reply, # self.current_goal # ) - + if self._check_new_messages_after_planning(): return None - + await self._send_reply() - - conversation_info.done_action.append({ - "action": action, - "reason": reason, - "status": "done", - "time": datetime.datetime.now().strftime("%H:%M:%S") - }) - + + conversation_info.done_action.append( + { + "action": action, + "reason": reason, + "status": "done", + "time": datetime.datetime.now().strftime("%H:%M:%S"), + } + ) + elif action == "fetch_knowledge": self.state = ConversationState.FETCHING knowledge = "TODO:知识" topic = "TODO:关键词" - + logger.info(f"假装获取到知识{knowledge},关键词是: {topic}") - + if knowledge: if topic not in self.conversation_info.knowledge_list: - self.conversation_info.knowledge_list.append({ - "topic": topic, - "knowledge": knowledge - }) + self.conversation_info.knowledge_list.append({"topic": topic, "knowledge": knowledge}) else: self.conversation_info.knowledge_list[topic] += knowledge - + elif action == "rethink_goal": self.state = ConversationState.RETHINKING await self.goal_analyzer.analyze_goal(conversation_info, observation_info) - elif action == "listening": self.state = ConversationState.LISTENING logger.info("倾听对方发言...") if await self.waiter.wait(): # 如果返回True表示超时 await self._send_timeout_message() await self._stop_conversation() - + else: # wait self.state = ConversationState.WAITING logger.info("等待更多信息...") @@ -207,12 +198,10 @@ class Conversation: messages = self.chat_observer.get_cached_messages(limit=1) if not messages: return - + latest_message = self._convert_to_message(messages[0]) await self.direct_sender.send_message( - chat_stream=self.chat_stream, - content="TODO:超时消息", - reply_to_message=latest_message + chat_stream=self.chat_stream, content="TODO:超时消息", reply_to_message=latest_message ) except Exception as e: logger.error(f"发送超时消息失败: {str(e)}") @@ -222,24 +211,22 @@ class Conversation: if not self.generated_reply: logger.warning("没有生成回复") return - + messages = self.chat_observer.get_cached_messages(limit=1) if not messages: logger.warning("没有最近的消息可以回复") return - + latest_message = self._convert_to_message(messages[0]) try: await self.direct_sender.send_message( - chat_stream=self.chat_stream, - content=self.generated_reply, - reply_to_message=latest_message + chat_stream=self.chat_stream, content=self.generated_reply, reply_to_message=latest_message ) self.chat_observer.trigger_update() # 触发立即更新 if not await self.chat_observer.wait_for_update(): logger.warning("等待消息更新超时") - + self.state = ConversationState.ANALYZING except Exception as e: logger.error(f"发送消息失败: {str(e)}") - self.state = ConversationState.ANALYZING \ No newline at end of file + self.state = ConversationState.ANALYZING diff --git a/src/plugins/PFC/conversation_info.py b/src/plugins/PFC/conversation_info.py index 5b8262a16..cae9f0b34 100644 --- a/src/plugins/PFC/conversation_info.py +++ b/src/plugins/PFC/conversation_info.py @@ -1,8 +1,6 @@ - - class ConversationInfo: def __init__(self): self.done_action = [] self.goal_list = [] self.knowledge_list = [] - self.memory_list = [] \ No newline at end of file + self.memory_list = [] diff --git a/src/plugins/PFC/message_sender.py b/src/plugins/PFC/message_sender.py index 6df1e7ded..76b07945f 100644 --- a/src/plugins/PFC/message_sender.py +++ b/src/plugins/PFC/message_sender.py @@ -7,12 +7,13 @@ from src.plugins.chat.message import MessageSending logger = get_module_logger("message_sender") + class DirectMessageSender: """直接消息发送器""" - + def __init__(self): pass - + async def send_message( self, chat_stream: ChatStream, @@ -20,7 +21,7 @@ class DirectMessageSender: reply_to_message: Optional[Message] = None, ) -> None: """发送消息到聊天流 - + Args: chat_stream: 聊天流 content: 消息内容 @@ -29,21 +30,18 @@ class DirectMessageSender: try: # 创建消息内容 segments = [Seg(type="text", data={"text": content})] - + # 检查是否需要引用回复 if reply_to_message: reply_id = reply_to_message.message_id - message_sending = MessageSending( - segments=segments, - reply_to_id=reply_id - ) + message_sending = MessageSending(segments=segments, reply_to_id=reply_id) else: message_sending = MessageSending(segments=segments) - + # 发送消息 await chat_stream.send_message(message_sending) logger.info(f"消息已发送: {content}") - + except Exception as e: logger.error(f"发送消息失败: {str(e)}") - raise \ No newline at end of file + raise diff --git a/src/plugins/PFC/message_storage.py b/src/plugins/PFC/message_storage.py index 3c7cab8b3..88f409641 100644 --- a/src/plugins/PFC/message_storage.py +++ b/src/plugins/PFC/message_storage.py @@ -2,133 +2,126 @@ from abc import ABC, abstractmethod from typing import List, Dict, Any, Optional from src.common.database import db + class MessageStorage(ABC): """消息存储接口""" - + @abstractmethod async def get_messages_after(self, chat_id: str, message_id: Optional[str] = None) -> List[Dict[str, Any]]: """获取指定消息ID之后的所有消息 - + Args: chat_id: 聊天ID message_id: 消息ID,如果为None则获取所有消息 - + Returns: List[Dict[str, Any]]: 消息列表 """ pass - + @abstractmethod async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]: """获取指定时间点之前的消息 - + Args: chat_id: 聊天ID time_point: 时间戳 limit: 最大消息数量 - + Returns: List[Dict[str, Any]]: 消息列表 """ pass - + @abstractmethod async def has_new_messages(self, chat_id: str, after_time: float) -> bool: """检查是否有新消息 - + Args: chat_id: 聊天ID after_time: 时间戳 - + Returns: bool: 是否有新消息 """ pass + class MongoDBMessageStorage(MessageStorage): """MongoDB消息存储实现""" - + def __init__(self): self.db = db - + async def get_messages_after(self, chat_id: str, message_id: Optional[str] = None) -> List[Dict[str, Any]]: query = {"chat_id": chat_id} - + if message_id: # 获取ID大于message_id的消息 last_message = self.db.messages.find_one({"message_id": message_id}) if last_message: query["time"] = {"$gt": last_message["time"]} - - return list( - self.db.messages.find(query).sort("time", 1) - ) - + + return list(self.db.messages.find(query).sort("time", 1)) + async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]: - query = { - "chat_id": chat_id, - "time": {"$lt": time_point} - } - - messages = list( - self.db.messages.find(query).sort("time", -1).limit(limit) - ) - + query = {"chat_id": chat_id, "time": {"$lt": time_point}} + + messages = list(self.db.messages.find(query).sort("time", -1).limit(limit)) + # 将消息按时间正序排列 messages.reverse() return messages - + async def has_new_messages(self, chat_id: str, after_time: float) -> bool: - query = { - "chat_id": chat_id, - "time": {"$gt": after_time} - } - + query = {"chat_id": chat_id, "time": {"$gt": after_time}} + return self.db.messages.find_one(query) is not None + # # 创建一个内存消息存储实现,用于测试 # class InMemoryMessageStorage(MessageStorage): # """内存消息存储实现,主要用于测试""" - + # def __init__(self): # self.messages: Dict[str, List[Dict[str, Any]]] = {} - + # async def get_messages_after(self, chat_id: str, message_id: Optional[str] = None) -> List[Dict[str, Any]]: # if chat_id not in self.messages: # return [] - + # messages = self.messages[chat_id] # if not message_id: # return messages - + # # 找到message_id的索引 # try: # index = next(i for i, m in enumerate(messages) if m["message_id"] == message_id) # return messages[index + 1:] # except StopIteration: # return [] - + # async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]: # if chat_id not in self.messages: # return [] - + # messages = [ # m for m in self.messages[chat_id] # if m["time"] < time_point # ] - + # return messages[-limit:] - + # async def has_new_messages(self, chat_id: str, after_time: float) -> bool: # if chat_id not in self.messages: # return False - + # return any(m["time"] > after_time for m in self.messages[chat_id]) - + # # 测试辅助方法 # def add_message(self, chat_id: str, message: Dict[str, Any]): # """添加测试消息""" # if chat_id not in self.messages: # self.messages[chat_id] = [] # self.messages[chat_id].append(message) -# self.messages[chat_id].sort(key=lambda m: m["time"]) \ No newline at end of file +# self.messages[chat_id].sort(key=lambda m: m["time"]) diff --git a/src/plugins/PFC/notification_handler.py b/src/plugins/PFC/notification_handler.py index 38c0d0dee..1131d18bf 100644 --- a/src/plugins/PFC/notification_handler.py +++ b/src/plugins/PFC/notification_handler.py @@ -7,25 +7,26 @@ if TYPE_CHECKING: logger = get_module_logger("notification_handler") + class PFCNotificationHandler(NotificationHandler): """PFC通知处理器""" - - def __init__(self, conversation: 'Conversation'): + + def __init__(self, conversation: "Conversation"): """初始化PFC通知处理器 - + Args: conversation: 对话实例 """ self.conversation = conversation - + async def handle_notification(self, notification: Notification): """处理通知 - + Args: notification: 通知对象 """ logger.debug(f"收到通知: {notification.type.name}, 数据: {notification.data}") - + # 根据通知类型执行不同的处理 if notification.type == NotificationType.NEW_MESSAGE: # 新消息通知 @@ -38,34 +39,33 @@ class PFCNotificationHandler(NotificationHandler): await self._handle_command(notification) else: logger.warning(f"未知的通知类型: {notification.type.name}") - + async def _handle_new_message(self, notification: Notification): """处理新消息通知 - + Args: notification: 通知对象 """ - + # 更新决策信息 observation_info = self.conversation.observation_info observation_info.last_message_time = notification.data.get("time", 0) observation_info.add_unprocessed_message(notification.data) - + # 手动触发观察器更新 self.conversation.chat_observer.trigger_update() - + async def _handle_cold_chat(self, notification: Notification): """处理冷聊天通知 - + Args: notification: 通知对象 """ # 获取冷聊天信息 cold_duration = notification.data.get("duration", 0) - + # 更新决策信息 observation_info = self.conversation.observation_info observation_info.conversation_cold_duration = cold_duration - + logger.info(f"对话已冷: {cold_duration}秒") - \ No newline at end of file diff --git a/src/plugins/PFC/observation_info.py b/src/plugins/PFC/observation_info.py index 2967f10e3..d0eee2236 100644 --- a/src/plugins/PFC/observation_info.py +++ b/src/plugins/PFC/observation_info.py @@ -1,5 +1,5 @@ -#Programmable Friendly Conversationalist -#Prefrontal cortex +# Programmable Friendly Conversationalist +# Prefrontal cortex from typing import List, Optional, Dict, Any, Set from ..message.message_base import UserInfo import time @@ -10,26 +10,27 @@ from .chat_states import NotificationHandler logger = get_module_logger("observation_info") + class ObservationInfoHandler(NotificationHandler): """ObservationInfo的通知处理器""" - - def __init__(self, observation_info: 'ObservationInfo'): + + def __init__(self, observation_info: "ObservationInfo"): """初始化处理器 - + Args: observation_info: 要更新的ObservationInfo实例 """ self.observation_info = observation_info - + async def handle_notification(self, notification: Dict[str, Any]): """处理通知 - + Args: notification: 通知数据 """ notification_type = notification.get("type") data = notification.get("data", {}) - + if notification_type == "NEW_MESSAGE": # 处理新消息通知 logger.debug(f"收到新消息通知data: {data}") @@ -37,62 +38,62 @@ class ObservationInfoHandler(NotificationHandler): self.observation_info.update_from_message(message) # self.observation_info.has_unread_messages = True # self.observation_info.new_unread_message.append(message.get("processed_plain_text", "")) - + elif notification_type == "COLD_CHAT": # 处理冷场通知 is_cold = data.get("is_cold", False) self.observation_info.update_cold_chat_status(is_cold, time.time()) - + elif notification_type == "ACTIVE_CHAT": # 处理活跃通知 is_active = data.get("is_active", False) self.observation_info.is_cold = not is_active - + elif notification_type == "BOT_SPEAKING": # 处理机器人说话通知 self.observation_info.is_typing = False self.observation_info.last_bot_speak_time = time.time() - + elif notification_type == "USER_SPEAKING": # 处理用户说话通知 self.observation_info.is_typing = False self.observation_info.last_user_speak_time = time.time() - + elif notification_type == "MESSAGE_DELETED": # 处理消息删除通知 message_id = data.get("message_id") self.observation_info.unprocessed_messages = [ - msg for msg in self.observation_info.unprocessed_messages - if msg.get("message_id") != message_id + msg for msg in self.observation_info.unprocessed_messages if msg.get("message_id") != message_id ] - + elif notification_type == "USER_JOINED": # 处理用户加入通知 user_id = data.get("user_id") if user_id: self.observation_info.active_users.add(user_id) - + elif notification_type == "USER_LEFT": # 处理用户离开通知 user_id = data.get("user_id") if user_id: self.observation_info.active_users.discard(user_id) - + elif notification_type == "ERROR": # 处理错误通知 error_msg = data.get("error", "") logger.error(f"收到错误通知: {error_msg}") + @dataclass class ObservationInfo: """决策信息类,用于收集和管理来自chat_observer的通知信息""" - - #data_list + + # data_list chat_history: List[str] = field(default_factory=list) unprocessed_messages: List[Dict[str, Any]] = field(default_factory=list) active_users: Set[str] = field(default_factory=set) - - #data + + # data last_bot_speak_time: Optional[float] = None last_user_speak_time: Optional[float] = None last_message_time: Optional[float] = None @@ -101,78 +102,70 @@ class ObservationInfo: bot_id: Optional[str] = None new_messages_count: int = 0 cold_chat_duration: float = 0.0 - - #state + + # state is_typing: bool = False has_unread_messages: bool = False is_cold_chat: bool = False changed: bool = False - + # #spec # meta_plan_trigger: bool = False - + def __post_init__(self): """初始化后创建handler""" self.chat_observer = None self.handler = ObservationInfoHandler(self) - + def bind_to_chat_observer(self, stream_id: str): """绑定到指定的chat_observer - + Args: stream_id: 聊天流ID """ self.chat_observer = ChatObserver.get_instance(stream_id) self.chat_observer.notification_manager.register_handler( - target="observation_info", - notification_type="NEW_MESSAGE", - handler=self.handler + target="observation_info", notification_type="NEW_MESSAGE", handler=self.handler ) self.chat_observer.notification_manager.register_handler( - target="observation_info", - notification_type="COLD_CHAT", - handler=self.handler + target="observation_info", notification_type="COLD_CHAT", handler=self.handler ) - + def unbind_from_chat_observer(self): """解除与chat_observer的绑定""" if self.chat_observer: self.chat_observer.notification_manager.unregister_handler( - target="observation_info", - notification_type="NEW_MESSAGE", - handler=self.handler + target="observation_info", notification_type="NEW_MESSAGE", handler=self.handler ) self.chat_observer.notification_manager.unregister_handler( - target="observation_info", - notification_type="COLD_CHAT", - handler=self.handler + target="observation_info", notification_type="COLD_CHAT", handler=self.handler ) self.chat_observer = None - + def update_from_message(self, message: Dict[str, Any]): """从消息更新信息 - + Args: message: 消息数据 """ logger.debug(f"更新信息from_message: {message}") self.last_message_time = message["time"] self.last_message_content = message.get("processed_plain_text", "") - + user_info = UserInfo.from_dict(message.get("user_info", {})) self.last_message_sender = user_info.user_id - + if user_info.user_id == self.bot_id: self.last_bot_speak_time = message["time"] else: self.last_user_speak_time = message["time"] self.active_users.add(user_info.user_id) - + self.new_messages_count += 1 self.unprocessed_messages.append(message) - + self.update_changed() - + def update_changed(self): """更新changed状态""" self.changed = True @@ -180,7 +173,7 @@ class ObservationInfo: def update_cold_chat_status(self, is_cold: bool, current_time: float): """更新冷场状态 - + Args: is_cold: 是否冷场 current_time: 当前时间 @@ -188,37 +181,37 @@ class ObservationInfo: self.is_cold_chat = is_cold if is_cold and self.last_message_time: self.cold_chat_duration = current_time - self.last_message_time - + def get_active_duration(self) -> float: """获取当前活跃时长 - + Returns: float: 最后一条消息到现在的时长(秒) """ if not self.last_message_time: return 0.0 return time.time() - self.last_message_time - + def get_user_response_time(self) -> Optional[float]: """获取用户响应时间 - + Returns: Optional[float]: 用户最后发言到现在的时长(秒),如果没有用户发言则返回None """ if not self.last_user_speak_time: return None return time.time() - self.last_user_speak_time - + def get_bot_response_time(self) -> Optional[float]: """获取机器人响应时间 - + Returns: Optional[float]: 机器人最后发言到现在的时长(秒),如果没有机器人发言则返回None """ if not self.last_bot_speak_time: return None return time.time() - self.last_bot_speak_time - + def clear_unprocessed_messages(self): """清空未处理消息列表""" # 将未处理消息添加到历史记录中 @@ -229,10 +222,10 @@ class ObservationInfo: self.has_unread_messages = False self.unprocessed_messages.clear() self.new_messages_count = 0 - + def add_unprocessed_message(self, message: Dict[str, Any]): """添加未处理的消息 - + Args: message: 消息数据 """ @@ -241,6 +234,6 @@ class ObservationInfo: if message_id and not any(m.get("message_id") == message_id for m in self.unprocessed_messages): self.unprocessed_messages.append(message) self.new_messages_count += 1 - + # 同时更新其他消息相关信息 - self.update_from_message(message) \ No newline at end of file + self.update_from_message(message) diff --git a/src/plugins/PFC/pfc.py b/src/plugins/PFC/pfc.py index 62b28acb4..3436dce8f 100644 --- a/src/plugins/PFC/pfc.py +++ b/src/plugins/PFC/pfc.py @@ -49,43 +49,40 @@ class GoalAnalyzer: Args: conversation_info: 对话信息 observation_info: 观察信息 - + Returns: Tuple[str, str, str]: (目标, 方法, 原因) """ - #构建对话目标 + # 构建对话目标 goal_list = conversation_info.goal_list goal_text = "" for goal, reason in goal_list: goal_text += f"目标:{goal};" goal_text += f"原因:{reason}\n" - - + # 获取聊天历史记录 chat_history_list = observation_info.chat_history chat_history_text = "" for msg in chat_history_list: chat_history_text += f"{msg}\n" - + if observation_info.new_messages_count > 0: new_messages_list = observation_info.unprocessed_messages - + chat_history_text += f"有{observation_info.new_messages_count}条新消息:\n" for msg in new_messages_list: chat_history_text += f"{msg}\n" - + observation_info.clear_unprocessed_messages() - - + personality_text = f"你的名字是{self.name},{self.personality_info}" - + # 构建action历史文本 action_history_list = conversation_info.done_action action_history_text = "你之前做的事情是:" for action in action_history_list: action_history_text += f"{action}\n" - - + prompt = f"""{personality_text}。现在你在参与一场QQ聊天,请分析以下聊天记录,并根据你的性格特征确定多个明确的对话目标。 这些目标应该反映出对话的不同方面和意图。 @@ -119,20 +116,15 @@ class GoalAnalyzer: logger.debug(f"发送到LLM的提示词: {prompt}") content, _ = await self.llm.generate_response_async(prompt) logger.debug(f"LLM原始返回内容: {content}") - + # 使用简化函数提取JSON内容 success, result = get_items_from_json( - content, - "goal", "reasoning", - required_types={"goal": str, "reasoning": str} + content, "goal", "reasoning", required_types={"goal": str, "reasoning": str} ) - #TODO - - + # TODO + conversation_info.goal_list.append(result) - - async def _update_goals(self, new_goal: str, method: str, reasoning: str): """更新目标列表 @@ -229,24 +221,26 @@ class GoalAnalyzer: try: content, _ = await self.llm.generate_response_async(prompt) logger.debug(f"LLM原始返回内容: {content}") - + # 尝试解析JSON success, result = get_items_from_json( content, - "goal_achieved", "stop_conversation", "reason", - required_types={"goal_achieved": bool, "stop_conversation": bool, "reason": str} + "goal_achieved", + "stop_conversation", + "reason", + required_types={"goal_achieved": bool, "stop_conversation": bool, "reason": str}, ) if not success: logger.error("无法解析对话分析结果JSON") return False, False, "解析结果失败" - + goal_achieved = result["goal_achieved"] stop_conversation = result["stop_conversation"] reason = result["reason"] - + return goal_achieved, stop_conversation, reason - + except Exception as e: logger.error(f"分析对话状态时出错: {str(e)}") return False, False, f"分析出错: {str(e)}" @@ -269,23 +263,22 @@ class Waiter: # 使用当前时间作为等待开始时间 wait_start_time = time.time() self.chat_observer.waiting_start_time = wait_start_time # 设置等待开始时间 - + while True: # 检查是否有新消息 if self.chat_observer.new_message_after(wait_start_time): logger.info("等待结束,收到新消息") return False - + # 检查是否超时 if time.time() - wait_start_time > 300: logger.info("等待超过300秒,结束对话") return True - + await asyncio.sleep(1) logger.info("等待中...") - class DirectMessageSender: """直接发送消息到平台的发送器""" diff --git a/src/plugins/PFC/pfc_manager.py b/src/plugins/PFC/pfc_manager.py index 9a36bef19..5be15a100 100644 --- a/src/plugins/PFC/pfc_manager.py +++ b/src/plugins/PFC/pfc_manager.py @@ -5,33 +5,34 @@ import traceback logger = get_module_logger("pfc_manager") + class PFCManager: """PFC对话管理器,负责管理所有对话实例""" - + # 单例模式 _instance = None - + # 会话实例管理 _instances: Dict[str, Conversation] = {} _initializing: Dict[str, bool] = {} - + @classmethod - def get_instance(cls) -> 'PFCManager': + def get_instance(cls) -> "PFCManager": """获取管理器单例 - + Returns: PFCManager: 管理器实例 """ if cls._instance is None: cls._instance = PFCManager() return cls._instance - + async def get_or_create_conversation(self, stream_id: str) -> Optional[Conversation]: """获取或创建对话实例 - + Args: stream_id: 聊天流ID - + Returns: Optional[Conversation]: 对话实例,创建失败则返回None """ @@ -39,11 +40,11 @@ class PFCManager: if stream_id in self._initializing and self._initializing[stream_id]: logger.debug(f"会话实例正在初始化中: {stream_id}") return None - + if stream_id in self._instances: logger.debug(f"使用现有会话实例: {stream_id}") return self._instances[stream_id] - + try: # 创建新实例 logger.info(f"创建新的对话实例: {stream_id}") @@ -51,47 +52,45 @@ class PFCManager: # 创建实例 conversation_instance = Conversation(stream_id) self._instances[stream_id] = conversation_instance - + # 启动实例初始化 await self._initialize_conversation(conversation_instance) except Exception as e: logger.error(f"创建会话实例失败: {stream_id}, 错误: {e}") return None - + return conversation_instance - async def _initialize_conversation(self, conversation: Conversation): """初始化会话实例 - + Args: conversation: 要初始化的会话实例 """ stream_id = conversation.stream_id - + try: logger.info(f"开始初始化会话实例: {stream_id}") # 启动初始化流程 await conversation._initialize() - + # 标记初始化完成 self._initializing[stream_id] = False - + logger.info(f"会话实例 {stream_id} 初始化完成") - + except Exception as e: logger.error(f"管理器初始化会话实例失败: {stream_id}, 错误: {e}") logger.error(traceback.format_exc()) # 清理失败的初始化 - async def get_conversation(self, stream_id: str) -> Optional[Conversation]: """获取已存在的会话实例 - + Args: stream_id: 聊天流ID - + Returns: Optional[Conversation]: 会话实例,不存在则返回None """ - return self._instances.get(stream_id) \ No newline at end of file + return self._instances.get(stream_id) diff --git a/src/plugins/PFC/pfc_types.py b/src/plugins/PFC/pfc_types.py index d7ad8e91f..7391c448d 100644 --- a/src/plugins/PFC/pfc_types.py +++ b/src/plugins/PFC/pfc_types.py @@ -4,6 +4,7 @@ from typing import Literal class ConversationState(Enum): """对话状态""" + INIT = "初始化" RETHINKING = "重新思考" ANALYZING = "分析历史" @@ -18,4 +19,4 @@ class ConversationState(Enum): JUDGING = "判断" -ActionType = Literal["direct_reply", "fetch_knowledge", "wait"] \ No newline at end of file +ActionType = Literal["direct_reply", "fetch_knowledge", "wait"] diff --git a/src/plugins/PFC/reply_generator.py b/src/plugins/PFC/reply_generator.py index beec9dd3e..00ac7c413 100644 --- a/src/plugins/PFC/reply_generator.py +++ b/src/plugins/PFC/reply_generator.py @@ -13,33 +13,26 @@ logger = get_module_logger("reply_generator") class ReplyGenerator: """回复生成器""" - + def __init__(self, stream_id: str): self.llm = LLM_request( - model=global_config.llm_normal, - temperature=0.7, - max_tokens=300, - request_type="reply_generation" + model=global_config.llm_normal, temperature=0.7, max_tokens=300, request_type="reply_generation" ) - self.personality_info = Individuality.get_instance().get_prompt(type = "personality", x_person = 2, level = 2) + self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2) self.name = global_config.BOT_NICKNAME self.chat_observer = ChatObserver.get_instance(stream_id) self.reply_checker = ReplyChecker(stream_id) - - async def generate( - self, - observation_info: ObservationInfo, - conversation_info: ConversationInfo - ) -> str: + + async def generate(self, observation_info: ObservationInfo, conversation_info: ConversationInfo) -> str: """生成回复 - + Args: goal: 对话目标 chat_history: 聊天历史 knowledge_cache: 知识缓存 previous_reply: 上一次生成的回复(如果有) retry_count: 当前重试次数 - + Returns: str: 生成的回复 """ @@ -51,22 +44,21 @@ class ReplyGenerator: for goal, reason in goal_list: goal_text += f"目标:{goal};" goal_text += f"原因:{reason}\n" - + # 获取聊天历史记录 chat_history_list = observation_info.chat_history chat_history_text = "" for msg in chat_history_list: chat_history_text += f"{msg}\n" - - + # 整理知识缓存 knowledge_text = "" knowledge_list = conversation_info.knowledge_list for knowledge in knowledge_list: knowledge_text += f"知识:{knowledge}\n" - + personality_text = f"你的名字是{self.name},{self.personality_info}" - + prompt = f"""{personality_text}。现在你在参与一场QQ聊天,请根据以下信息生成回复: 当前对话目标:{goal_text} @@ -92,7 +84,7 @@ class ReplyGenerator: logger.info(f"生成的回复: {content}") # is_new = self.chat_observer.check() # logger.debug(f"再看一眼聊天记录,{'有' if is_new else '没有'}新消息") - + # 如果有新消息,重新生成回复 # if is_new: # logger.info("检测到新消息,重新生成回复") @@ -100,27 +92,22 @@ class ReplyGenerator: # goal, chat_history, knowledge_cache, # None, retry_count # ) - + return content - + except Exception as e: logger.error(f"生成回复时出错: {e}") return "抱歉,我现在有点混乱,让我重新思考一下..." - async def check_reply( - self, - reply: str, - goal: str, - retry_count: int = 0 - ) -> Tuple[bool, str, bool]: + async def check_reply(self, reply: str, goal: str, retry_count: int = 0) -> Tuple[bool, str, bool]: """检查回复是否合适 - + Args: reply: 生成的回复 goal: 对话目标 retry_count: 当前重试次数 - + Returns: Tuple[bool, str, bool]: (是否合适, 原因, 是否需要重新规划) """ - return await self.reply_checker.check(reply, goal, retry_count) \ No newline at end of file + return await self.reply_checker.check(reply, goal, retry_count) diff --git a/src/plugins/PFC/waiter.py b/src/plugins/PFC/waiter.py index 0e1bf59f3..66f98e9c3 100644 --- a/src/plugins/PFC/waiter.py +++ b/src/plugins/PFC/waiter.py @@ -3,43 +3,44 @@ from .chat_observer import ChatObserver logger = get_module_logger("waiter") + class Waiter: """等待器,用于等待对话流中的事件""" - + def __init__(self, stream_id: str): self.stream_id = stream_id self.chat_observer = ChatObserver.get_instance(stream_id) - + async def wait(self, timeout: float = 20.0) -> bool: """等待用户回复或超时 - + Args: timeout: 超时时间(秒) - + Returns: bool: 如果因为超时返回则为True,否则为False """ try: message_before = self.chat_observer.get_last_message() - + # 等待新消息 logger.debug(f"等待新消息,超时时间: {timeout}秒") - + is_timeout = await self.chat_observer.wait_for_update(timeout=timeout) if is_timeout: logger.debug("等待超时,没有收到新消息") return True - + # 检查是否是新消息 message_after = self.chat_observer.get_last_message() if message_before and message_after and message_before.get("message_id") == message_after.get("message_id"): # 如果消息ID相同,说明没有新消息 logger.debug("没有收到新消息") return True - + logger.debug("收到新消息") return False - + except Exception as e: logger.error(f"等待时出错: {str(e)}") - return True \ No newline at end of file + return True diff --git a/src/plugins/chat/bot.py b/src/plugins/chat/bot.py index 40a00a3ab..884beead7 100644 --- a/src/plugins/chat/bot.py +++ b/src/plugins/chat/bot.py @@ -30,7 +30,7 @@ class ChatBot: self.think_flow_chat = ThinkFlowChat() self.reasoning_chat = ReasoningChat() self.only_process_chat = MessageProcessor() - + # 创建初始化PFC管理器的任务,会在_ensure_started时执行 self.pfc_manager = PFCManager.get_instance() @@ -38,7 +38,7 @@ class ChatBot: """确保所有任务已启动""" if not self._started: logger.info("确保ChatBot所有任务已启动") - + self._started = True async def _create_PFC_chat(self, message: MessageRecv): @@ -46,7 +46,6 @@ class ChatBot: chat_id = str(message.chat_stream.stream_id) if global_config.enable_pfc_chatting: - await self.pfc_manager.get_or_create_conversation(chat_id) except Exception as e: @@ -80,7 +79,7 @@ class ChatBot: try: # 确保所有任务已启动 await self._ensure_started() - + message = MessageRecv(message_data) groupinfo = message.message_info.group_info userinfo = message.message_info.user_info diff --git a/src/plugins/config/config.py b/src/plugins/config/config.py index eccb3bc0b..23e277498 100644 --- a/src/plugins/config/config.py +++ b/src/plugins/config/config.py @@ -24,7 +24,7 @@ config_config = LogConfig( # 配置主程序日志格式 logger = get_module_logger("config", config=config_config) -#考虑到,实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码 +# 考虑到,实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码 is_test = True mai_version_main = "0.6.2" mai_version_fix = "snapshot-1" From a889d9d2224456fdc09b8429b5857e2dff4929b0 Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Wed, 9 Apr 2025 19:27:23 +0800 Subject: [PATCH 06/24] =?UTF-8?q?feat=EF=BC=9A=E6=9B=B4=E5=A5=BD=E7=9A=84?= =?UTF-8?q?=E5=9B=9E=E5=A4=8D=E4=BF=A1=E6=81=AF=E6=94=B6=E9=9B=86=E5=99=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/heart_flow/heartflow.py | 2 +- src/heart_flow/sub_heartflow.py | 4 +- src/plugins/chat/bot.py | 6 +- src/plugins/chat/message.py | 2 +- .../reasoning_chat/reasoning_generator.py | 61 +++-- .../think_flow_chat/think_flow_chat.py | 43 +++- .../think_flow_chat/think_flow_generator.py | 83 ++++--- .../respon_info_catcher/info_catcher.py | 228 ++++++++++++++++++ template/bot_config_template.toml | 2 +- 9 files changed, 349 insertions(+), 82 deletions(-) create mode 100644 src/plugins/respon_info_catcher/info_catcher.py diff --git a/src/heart_flow/heartflow.py b/src/heart_flow/heartflow.py index 9cf8d4674..5c67fe125 100644 --- a/src/heart_flow/heartflow.py +++ b/src/heart_flow/heartflow.py @@ -200,7 +200,7 @@ class Heartflow: logger.error(f"创建 subheartflow 失败: {e}") return None - def get_subheartflow(self, observe_chat_id): + def get_subheartflow(self, observe_chat_id) -> SubHeartflow: """获取指定ID的SubHeartflow实例""" return self._subheartflows.get(observe_chat_id) diff --git a/src/heart_flow/sub_heartflow.py b/src/heart_flow/sub_heartflow.py index a2ba023e2..83f505cf8 100644 --- a/src/heart_flow/sub_heartflow.py +++ b/src/heart_flow/sub_heartflow.py @@ -42,7 +42,7 @@ class SubHeartflow: self.past_mind = [] self.current_state: CuttentState = CuttentState() self.llm_model = LLM_request( - model=global_config.llm_sub_heartflow, temperature=0.7, max_tokens=600, request_type="sub_heart_flow" + model=global_config.llm_sub_heartflow, temperature=0.5, max_tokens=600, request_type="sub_heart_flow" ) self.main_heartflow_info = "" @@ -221,9 +221,9 @@ class SubHeartflow: self.update_current_mind(reponse) - self.current_mind = reponse logger.debug(f"prompt:\n{prompt}\n") logger.info(f"麦麦的思考前脑内状态:{self.current_mind}") + return self.current_mind ,self.past_mind async def do_thinking_after_reply(self, reply_content, chat_talking_prompt): # print("麦麦回复之后脑袋转起来了") diff --git a/src/plugins/chat/bot.py b/src/plugins/chat/bot.py index 40a00a3ab..42234da8e 100644 --- a/src/plugins/chat/bot.py +++ b/src/plugins/chat/bot.py @@ -84,7 +84,7 @@ class ChatBot: message = MessageRecv(message_data) groupinfo = message.message_info.group_info userinfo = message.message_info.user_info - logger.debug(f"处理消息:{str(message_data)[:80]}...") + logger.debug(f"处理消息:{str(message_data)[:120]}...") if userinfo.user_id in global_config.ban_user_id: logger.debug(f"用户{userinfo.user_id}被禁止回复") @@ -106,11 +106,11 @@ class ChatBot: await self._create_PFC_chat(message) else: if groupinfo.group_id in global_config.talk_allowed_groups: - logger.debug(f"开始群聊模式{str(message_data)[:50]}...") + # logger.debug(f"开始群聊模式{str(message_data)[:50]}...") if global_config.response_mode == "heart_flow": await self.think_flow_chat.process_message(message_data) elif global_config.response_mode == "reasoning": - logger.debug(f"开始推理模式{str(message_data)[:50]}...") + # logger.debug(f"开始推理模式{str(message_data)[:50]}...") await self.reasoning_chat.process_message(message_data) else: logger.error(f"未知的回复模式,请检查配置文件!!: {global_config.response_mode}") diff --git a/src/plugins/chat/message.py b/src/plugins/chat/message.py index f3369d7bb..5dc688c03 100644 --- a/src/plugins/chat/message.py +++ b/src/plugins/chat/message.py @@ -365,7 +365,7 @@ class MessageSet: self.chat_stream = chat_stream self.message_id = message_id self.messages: List[MessageSending] = [] - self.time = round(time.time(), 2) + self.time = round(time.time(), 3) # 保留3位小数 def add_message(self, message: MessageSending) -> None: """添加消息到集合""" diff --git a/src/plugins/chat_module/reasoning_chat/reasoning_generator.py b/src/plugins/chat_module/reasoning_chat/reasoning_generator.py index eca5d0956..8bdc9c000 100644 --- a/src/plugins/chat_module/reasoning_chat/reasoning_generator.py +++ b/src/plugins/chat_module/reasoning_chat/reasoning_generator.py @@ -96,40 +96,39 @@ class ResponseGenerator: return None # 保存到数据库 - self._save_to_db( - message=message, - sender_name=sender_name, - prompt=prompt, - content=content, - reasoning_content=reasoning_content, - # reasoning_content_check=reasoning_content_check if global_config.enable_kuuki_read else "" - ) + # self._save_to_db( + # message=message, + # sender_name=sender_name, + # prompt=prompt, + # content=content, + # reasoning_content=reasoning_content, + # # reasoning_content_check=reasoning_content_check if global_config.enable_kuuki_read else "" + # ) return content - # def _save_to_db(self, message: Message, sender_name: str, prompt: str, prompt_check: str, - # content: str, content_check: str, reasoning_content: str, reasoning_content_check: str): - def _save_to_db( - self, - message: MessageRecv, - sender_name: str, - prompt: str, - content: str, - reasoning_content: str, - ): - """保存对话记录到数据库""" - db.reasoning_logs.insert_one( - { - "time": time.time(), - "chat_id": message.chat_stream.stream_id, - "user": sender_name, - "message": message.processed_plain_text, - "model": self.current_model_name, - "reasoning": reasoning_content, - "response": content, - "prompt": prompt, - } - ) + + # def _save_to_db( + # self, + # message: MessageRecv, + # sender_name: str, + # prompt: str, + # content: str, + # reasoning_content: str, + # ): + # """保存对话记录到数据库""" + # db.reasoning_logs.insert_one( + # { + # "time": time.time(), + # "chat_id": message.chat_stream.stream_id, + # "user": sender_name, + # "message": message.processed_plain_text, + # "model": self.current_model_name, + # "reasoning": reasoning_content, + # "response": content, + # "prompt": prompt, + # } + # ) async def _get_emotion_tags(self, content: str, processed_plain_text: str): """提取情感标签,结合立场和情绪""" diff --git a/src/plugins/chat_module/think_flow_chat/think_flow_chat.py b/src/plugins/chat_module/think_flow_chat/think_flow_chat.py index f845770d3..909180556 100644 --- a/src/plugins/chat_module/think_flow_chat/think_flow_chat.py +++ b/src/plugins/chat_module/think_flow_chat/think_flow_chat.py @@ -1,7 +1,8 @@ import time from random import random import re - +import traceback +from typing import List from ...memory_system.Hippocampus import HippocampusManager from ...moods.moods import MoodManager from ...config.config import global_config @@ -19,6 +20,7 @@ from src.common.logger import get_module_logger, CHAT_STYLE_CONFIG, LogConfig from ...chat.chat_stream import chat_manager from ...person_info.relationship_manager import relationship_manager from ...chat.message_buffer import message_buffer +from src.plugins.respon_info_catcher.info_catcher import info_catcher_manager # 定义日志配置 chat_config = LogConfig( @@ -59,7 +61,11 @@ class ThinkFlowChat: return thinking_id - async def _send_response_messages(self, message, chat, response_set, thinking_id): + async def _send_response_messages(self, + message, + chat, + response_set:List[str], + thinking_id) -> MessageSending: """发送回复消息""" container = message_manager.get_container(chat.stream_id) thinking_message = None @@ -72,12 +78,13 @@ class ThinkFlowChat: if not thinking_message: logger.warning("未找到对应的思考消息,可能已超时被移除") - return + return None thinking_start_time = thinking_message.thinking_start_time message_set = MessageSet(chat, thinking_id) mark_head = False + first_bot_msg = None for msg in response_set: message_segment = Seg(type="text", data=msg) bot_message = MessageSending( @@ -97,10 +104,12 @@ class ThinkFlowChat: ) if not mark_head: mark_head = True + first_bot_msg = bot_message # print(f"thinking_start_time:{bot_message.thinking_start_time}") message_set.add_message(bot_message) message_manager.add_message(message_set) + return first_bot_msg async def _handle_emoji(self, message, chat, response): """处理表情包""" @@ -257,6 +266,8 @@ class ThinkFlowChat: if random() < reply_probability: try: do_reply = True + + # 创建思考消息 try: @@ -266,6 +277,11 @@ class ThinkFlowChat: timing_results["创建思考消息"] = timer2 - timer1 except Exception as e: logger.error(f"心流创建思考消息失败: {e}") + + logger.debug(f"创建捕捉器,thinking_id:{thinking_id}") + + info_catcher = info_catcher_manager.get_info_catcher(thinking_id) + info_catcher.catch_decide_to_response(message) try: # 观察 @@ -275,36 +291,48 @@ class ThinkFlowChat: timing_results["观察"] = timer2 - timer1 except Exception as e: logger.error(f"心流观察失败: {e}") + + info_catcher.catch_after_observe(timing_results["观察"]) # 思考前脑内状态 try: timer1 = time.time() - await heartflow.get_subheartflow(chat.stream_id).do_thinking_before_reply( + current_mind,past_mind = await heartflow.get_subheartflow(chat.stream_id).do_thinking_before_reply( message.processed_plain_text ) timer2 = time.time() timing_results["思考前脑内状态"] = timer2 - timer1 except Exception as e: logger.error(f"心流思考前脑内状态失败: {e}") + + info_catcher.catch_afer_shf_step(timing_results["思考前脑内状态"],past_mind,current_mind) # 生成回复 timer1 = time.time() - response_set = await self.gpt.generate_response(message) + response_set = await self.gpt.generate_response(message,thinking_id) timer2 = time.time() timing_results["生成回复"] = timer2 - timer1 + info_catcher.catch_after_generate_response(timing_results["生成回复"]) + if not response_set: - logger.info("为什么生成回复失败?") + logger.info("回复生成失败,返回为空") return # 发送消息 try: timer1 = time.time() - await self._send_response_messages(message, chat, response_set, thinking_id) + first_bot_msg = await self._send_response_messages(message, chat, response_set, thinking_id) timer2 = time.time() timing_results["发送消息"] = timer2 - timer1 except Exception as e: logger.error(f"心流发送消息失败: {e}") + + + info_catcher.catch_after_response(timing_results["发送消息"],response_set,first_bot_msg) + + + info_catcher.done_catch() # 处理表情包 try: @@ -335,6 +363,7 @@ class ThinkFlowChat: except Exception as e: logger.error(f"心流处理消息失败: {e}") + logger.error(traceback.format_exc()) # 输出性能计时结果 if do_reply: diff --git a/src/plugins/chat_module/think_flow_chat/think_flow_generator.py b/src/plugins/chat_module/think_flow_chat/think_flow_generator.py index 4087b0b89..8758b91d7 100644 --- a/src/plugins/chat_module/think_flow_chat/think_flow_generator.py +++ b/src/plugins/chat_module/think_flow_chat/think_flow_generator.py @@ -9,6 +9,7 @@ from ...chat.message import MessageRecv, MessageThinking from .think_flow_prompt_builder import prompt_builder from ...chat.utils import process_llm_response from src.common.logger import get_module_logger, LogConfig, LLM_STYLE_CONFIG +from src.plugins.respon_info_catcher.info_catcher import info_catcher_manager # 定义日志配置 llm_config = LogConfig( @@ -32,15 +33,16 @@ class ResponseGenerator: self.current_model_type = "r1" # 默认使用 R1 self.current_model_name = "unknown model" - async def generate_response(self, message: MessageThinking) -> Optional[Union[str, List[str]]]: + async def generate_response(self, message: MessageRecv,thinking_id:str) -> Optional[List[str]]: """根据当前模型类型选择对应的生成函数""" + logger.info( f"思考:{message.processed_plain_text[:30] + '...' if len(message.processed_plain_text) > 30 else message.processed_plain_text}" ) current_model = self.model_normal - model_response = await self._generate_response_with_model(message, current_model) + model_response = await self._generate_response_with_model(message, current_model,thinking_id) # print(f"raw_content: {model_response}") @@ -53,8 +55,11 @@ class ResponseGenerator: logger.info(f"{self.current_model_type}思考,失败") return None - async def _generate_response_with_model(self, message: MessageThinking, model: LLM_request): + async def _generate_response_with_model(self, message: MessageRecv, model: LLM_request,thinking_id:str): sender_name = "" + + info_catcher = info_catcher_manager.get_info_catcher(thinking_id) + if message.chat_stream.user_info.user_cardname and message.chat_stream.user_info.user_nickname: sender_name = ( f"[({message.chat_stream.user_info.user_id}){message.chat_stream.user_info.user_nickname}]" @@ -79,45 +84,51 @@ class ResponseGenerator: try: content, reasoning_content, self.current_model_name = await model.generate_response(prompt) + + info_catcher.catch_after_llm_generated( + prompt=prompt, + response=content, + reasoning_content=reasoning_content, + model_name=self.current_model_name) + except Exception: logger.exception("生成回复时出错") return None # 保存到数据库 - self._save_to_db( - message=message, - sender_name=sender_name, - prompt=prompt, - content=content, - reasoning_content=reasoning_content, - # reasoning_content_check=reasoning_content_check if global_config.enable_kuuki_read else "" - ) + # self._save_to_db( + # message=message, + # sender_name=sender_name, + # prompt=prompt, + # content=content, + # reasoning_content=reasoning_content, + # # reasoning_content_check=reasoning_content_check if global_config.enable_kuuki_read else "" + # ) return content - # def _save_to_db(self, message: Message, sender_name: str, prompt: str, prompt_check: str, - # content: str, content_check: str, reasoning_content: str, reasoning_content_check: str): - def _save_to_db( - self, - message: MessageRecv, - sender_name: str, - prompt: str, - content: str, - reasoning_content: str, - ): - """保存对话记录到数据库""" - db.reasoning_logs.insert_one( - { - "time": time.time(), - "chat_id": message.chat_stream.stream_id, - "user": sender_name, - "message": message.processed_plain_text, - "model": self.current_model_name, - "reasoning": reasoning_content, - "response": content, - "prompt": prompt, - } - ) + + # def _save_to_db( + # self, + # message: MessageRecv, + # sender_name: str, + # prompt: str, + # content: str, + # reasoning_content: str, + # ): + # """保存对话记录到数据库""" + # db.reasoning_logs.insert_one( + # { + # "time": time.time(), + # "chat_id": message.chat_stream.stream_id, + # "user": sender_name, + # "message": message.processed_plain_text, + # "model": self.current_model_name, + # "reasoning": reasoning_content, + # "response": content, + # "prompt": prompt, + # } + # ) async def _get_emotion_tags(self, content: str, processed_plain_text: str): """提取情感标签,结合立场和情绪""" @@ -167,10 +178,10 @@ class ResponseGenerator: logger.debug(f"获取情感标签时出错: {e}") return "中立", "平静" # 出错时返回默认值 - async def _process_response(self, content: str) -> Tuple[List[str], List[str]]: + async def _process_response(self, content: str) -> List[str]: """处理响应内容,返回处理后的内容和情感标签""" if not content: - return None, [] + return None processed_response = process_llm_response(content) diff --git a/src/plugins/respon_info_catcher/info_catcher.py b/src/plugins/respon_info_catcher/info_catcher.py new file mode 100644 index 000000000..4e9943b8c --- /dev/null +++ b/src/plugins/respon_info_catcher/info_catcher.py @@ -0,0 +1,228 @@ +from src.plugins.config.config import global_config +from src.plugins.chat.message import MessageRecv,MessageSending,Message +from src.common.database import db +import time +import traceback +from typing import List + +class InfoCatcher: + def __init__(self): + self.chat_history = [] # 聊天历史,长度为三倍使用的上下文 + self.context_length = global_config.MAX_CONTEXT_SIZE + self.chat_history_in_thinking = [] # 思考期间的聊天内容 + self.chat_history_after_response = [] # 回复后的聊天内容,长度为一倍上下文 + + self.chat_id = "" + self.response_mode = global_config.response_mode + self.trigger_response_text = "" + self.response_text = "" + + self.trigger_response_time = 0 + self.trigger_response_message = None + + self.response_time = 0 + self.response_messages = [] + + # 使用字典来存储 heartflow 模式的数据 + self.heartflow_data = { + "heart_flow_prompt": "", + "sub_heartflow_before": "", + "sub_heartflow_now": "", + "sub_heartflow_after": "", + "sub_heartflow_model": "", + "prompt": "", + "response": "", + "model": "" + } + + # 使用字典来存储 reasoning 模式的数据 + self.reasoning_data = { + "thinking_log": "", + "prompt": "", + "response": "", + "model": "" + } + + # 耗时 + self.timing_results = { + "interested_rate_time": 0, + "sub_heartflow_observe_time": 0, + "sub_heartflow_step_time": 0, + "make_response_time": 0, + } + + def catch_decide_to_response(self,message:MessageRecv): + # 搜集决定回复时的信息 + self.trigger_response_message = message + self.trigger_response_text = message.detailed_plain_text + + self.trigger_response_time = time.time() + + self.chat_id = message.chat_stream.stream_id + + self.chat_history = self.get_message_from_db_before_msg(message) + + def catch_after_observe(self,obs_duration:float):#这里可以有更多信息 + self.timing_results["sub_heartflow_observe_time"] = obs_duration + + # def catch_shf + + def catch_afer_shf_step(self,step_duration:float,past_mind:str,current_mind:str): + self.timing_results["sub_heartflow_step_time"] = step_duration + if len(past_mind) > 1: + self.heartflow_data["sub_heartflow_before"] = past_mind[-1] + self.heartflow_data["sub_heartflow_now"] = current_mind + else: + self.heartflow_data["sub_heartflow_before"] = past_mind[-1] + self.heartflow_data["sub_heartflow_now"] = current_mind + + def catch_after_llm_generated(self,prompt:str, + response:str, + reasoning_content:str = "", + model_name:str = ""): + if self.response_mode == "heart_flow": + self.heartflow_data["prompt"] = prompt + self.heartflow_data["response"] = response + self.heartflow_data["model"] = model_name + elif self.response_mode == "reasoning": + self.reasoning_data["thinking_log"] = reasoning_content + self.reasoning_data["prompt"] = prompt + self.reasoning_data["response"] = response + self.reasoning_data["model"] = model_name + + self.response_text = response + + def catch_after_generate_response(self,response_duration:float): + self.timing_results["make_response_time"] = response_duration + + + + def catch_after_response(self,response_duration:float, + response_message:List[str], + first_bot_msg:MessageSending): + self.timing_results["make_response_time"] = response_duration + self.response_time = time.time() + for msg in response_message: + self.response_messages.append(msg) + + self.chat_history_in_thinking = self.get_message_from_db_between_msgs(self.trigger_response_message,first_bot_msg) + + def get_message_from_db_between_msgs(self, message_start: Message, message_end: Message): + try: + # 从数据库中获取消息的时间戳 + time_start = message_start.message_info.time + time_end = message_end.message_info.time + chat_id = message_start.chat_stream.stream_id + + print(f"查询参数: time_start={time_start}, time_end={time_end}, chat_id={chat_id}") + + # 查询数据库,获取 chat_id 相同且时间在 start 和 end 之间的数据 + messages_between = db.messages.find( + { + "chat_id": chat_id, + "time": {"$gt": time_start, "$lt": time_end} + } + ).sort("time", -1) + + result = list(messages_between) + print(f"查询结果数量: {len(result)}") + if result: + print(f"第一条消息时间: {result[0]['time']}") + print(f"最后一条消息时间: {result[-1]['time']}") + return result + except Exception as e: + print(f"获取消息时出错: {str(e)}") + return [] + + def get_message_from_db_before_msg(self, message: MessageRecv): + # 从数据库中获取消息 + message_id = message.message_info.message_id + chat_id = message.chat_stream.stream_id + + # 查询数据库,获取 chat_id 相同且 message_id 小于当前消息的 30 条数据 + messages_before = db.messages.find( + {"chat_id": chat_id, "message_id": {"$lt": message_id}} + ).sort("time", -1).limit(self.context_length*3) #获取更多历史信息 + + return list(messages_before) + + def message_list_to_dict(self, message_list): + #存储简化的聊天记录 + result = [] + for message in message_list: + if not isinstance(message, dict): + message = self.message_to_dict(message) + # print(message) + + lite_message = { + "time": message["time"], + "user_nickname": message["user_info"]["user_nickname"], + "processed_plain_text": message["processed_plain_text"], + } + result.append(lite_message) + + return result + + def message_to_dict(self, message): + if not message: + return None + if isinstance(message, dict): + return message + return { + # "message_id": message.message_info.message_id, + "time": message.message_info.time, + "user_id": message.message_info.user_info.user_id, + "user_nickname": message.message_info.user_info.user_nickname, + "processed_plain_text": message.processed_plain_text, + # "detailed_plain_text": message.detailed_plain_text + } + + def done_catch(self): + """将收集到的信息存储到数据库的 thinking_log 集合中""" + try: + # 将消息对象转换为可序列化的字典 + + thinking_log_data = { + "chat_id": self.chat_id, + "response_mode": self.response_mode, + "trigger_text": self.trigger_response_text, + "response_text": self.response_text, + "trigger_info": { + "time": self.trigger_response_time, + "message": self.message_to_dict(self.trigger_response_message), + }, + "response_info": { + "time": self.response_time, + "message": self.response_messages, + }, + "timing_results": self.timing_results, + "chat_history": self.message_list_to_dict(self.chat_history), + "chat_history_in_thinking": self.message_list_to_dict(self.chat_history_in_thinking), + "chat_history_after_response": self.message_list_to_dict(self.chat_history_after_response) + } + + # 根据不同的响应模式添加相应的数据 + if self.response_mode == "heart_flow": + thinking_log_data["mode_specific_data"] = self.heartflow_data + elif self.response_mode == "reasoning": + thinking_log_data["mode_specific_data"] = self.reasoning_data + + # 将数据插入到 thinking_log 集合中 + db.thinking_log.insert_one(thinking_log_data) + + return True + except Exception as e: + print(f"存储思考日志时出错: {str(e)}") + print(traceback.format_exc()) + return False + +class InfoCatcherManager: + def __init__(self): + self.info_catchers = {} + + def get_info_catcher(self,thinking_id:str) -> InfoCatcher: + if thinking_id not in self.info_catchers: + self.info_catchers[thinking_id] = InfoCatcher() + return self.info_catchers[thinking_id] + +info_catcher_manager = InfoCatcherManager() \ No newline at end of file diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index 70cf0e0b7..0061b9ca2 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -164,7 +164,7 @@ response_max_sentence_num = 4 # 回复允许的最大句子数 [remote] #发送统计信息,主要是看全球有多少只麦麦 enable = true -[experimental] +[experimental] #实验性功能,不一定完善或者根本不能用 enable_friend_chat = false # 是否启用好友聊天 pfc_chatting = false # 是否启用PFC聊天,该功能仅作用于私聊,与回复模式独立 From 7b0bdc8f29dd0db305f70db7f44f90daf1d70453 Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Wed, 9 Apr 2025 20:11:36 +0800 Subject: [PATCH 07/24] fix ruff --- src/common/server.py | 2 +- .../chat_module/reasoning_chat/reasoning_generator.py | 3 +-- .../chat_module/think_flow_chat/think_flow_generator.py | 5 ++--- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/common/server.py b/src/common/server.py index fd1f3ff18..a4998a305 100644 --- a/src/common/server.py +++ b/src/common/server.py @@ -1,5 +1,5 @@ from fastapi import FastAPI, APIRouter -from typing import Optional, Union +from typing import Optional from uvicorn import Config, Server as UvicornServer import os diff --git a/src/plugins/chat_module/reasoning_chat/reasoning_generator.py b/src/plugins/chat_module/reasoning_chat/reasoning_generator.py index 8bdc9c000..5d4587675 100644 --- a/src/plugins/chat_module/reasoning_chat/reasoning_generator.py +++ b/src/plugins/chat_module/reasoning_chat/reasoning_generator.py @@ -2,10 +2,9 @@ import time from typing import List, Optional, Tuple, Union import random -from ....common.database import db from ...models.utils_model import LLM_request from ...config.config import global_config -from ...chat.message import MessageRecv, MessageThinking +from ...chat.message import MessageThinking from .reasoning_prompt_builder import prompt_builder from ...chat.utils import process_llm_response from src.common.logger import get_module_logger, LogConfig, LLM_STYLE_CONFIG diff --git a/src/plugins/chat_module/think_flow_chat/think_flow_generator.py b/src/plugins/chat_module/think_flow_chat/think_flow_generator.py index 8758b91d7..346a41d8e 100644 --- a/src/plugins/chat_module/think_flow_chat/think_flow_generator.py +++ b/src/plugins/chat_module/think_flow_chat/think_flow_generator.py @@ -1,11 +1,10 @@ import time -from typing import List, Optional, Tuple, Union +from typing import List, Optional -from ....common.database import db from ...models.utils_model import LLM_request from ...config.config import global_config -from ...chat.message import MessageRecv, MessageThinking +from ...chat.message import MessageRecv from .think_flow_prompt_builder import prompt_builder from ...chat.utils import process_llm_response from src.common.logger import get_module_logger, LogConfig, LLM_STYLE_CONFIG From f3d6e7cfa5de98e3e2a2c9e9edfba60a6fde385e Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Wed, 9 Apr 2025 22:50:21 +0800 Subject: [PATCH 08/24] =?UTF-8?q?feat=EF=BC=9A=E5=9B=9E=E5=A4=8D=E6=B8=A9?= =?UTF-8?q?=E5=BA=A6=E7=8E=B0=E5=9C=A8=E4=BC=9A=E5=8F=97=E5=88=B0=EF=BC=9A?= =?UTF-8?q?=E4=BA=BA=E6=A0=BC-=E6=83=85=E7=BB=AA-temp=E7=9A=84=E9=93=BE?= =?UTF-8?q?=E6=9D=A1=E5=BD=B1=E5=93=8D;=E9=A1=BA=E4=BE=BF=E4=BF=AE?= =?UTF-8?q?=E6=94=B9=E4=BA=86=E6=83=85=E7=BB=AA=E6=BF=80=E6=B4=BB=E5=BA=A6?= =?UTF-8?q?=E7=9A=84=E5=8F=96=E5=80=BC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit feat:回复温度现在会受到:人格-情绪-temp的链条影响;顺便修改了情绪激活度的取值 --- .../reasoning_chat/reasoning_chat.py | 29 ++++++++-- .../reasoning_chat/reasoning_generator.py | 18 +++++-- .../think_flow_chat/think_flow_generator.py | 8 ++- src/plugins/moods/moods.py | 54 +++++++++++-------- 4 files changed, 79 insertions(+), 30 deletions(-) diff --git a/src/plugins/chat_module/reasoning_chat/reasoning_chat.py b/src/plugins/chat_module/reasoning_chat/reasoning_chat.py index 683bef463..b9e94e4fe 100644 --- a/src/plugins/chat_module/reasoning_chat/reasoning_chat.py +++ b/src/plugins/chat_module/reasoning_chat/reasoning_chat.py @@ -1,7 +1,7 @@ import time from random import random import re - +from typing import List from ...memory_system.Hippocampus import HippocampusManager from ...moods.moods import MoodManager from ...config.config import global_config @@ -18,6 +18,7 @@ from src.common.logger import get_module_logger, CHAT_STYLE_CONFIG, LogConfig from ...chat.chat_stream import chat_manager from ...person_info.relationship_manager import relationship_manager from ...chat.message_buffer import message_buffer +from src.plugins.respon_info_catcher.info_catcher import info_catcher_manager # 定义日志配置 chat_config = LogConfig( @@ -58,7 +59,11 @@ class ReasoningChat: return thinking_id - async def _send_response_messages(self, message, chat, response_set, thinking_id): + async def _send_response_messages(self, + message, + chat, + response_set:List[str], + thinking_id) -> MessageSending: """发送回复消息""" container = message_manager.get_container(chat.stream_id) thinking_message = None @@ -77,6 +82,7 @@ class ReasoningChat: message_set = MessageSet(chat, thinking_id) mark_head = False + first_bot_msg = None for msg in response_set: message_segment = Seg(type="text", data=msg) bot_message = MessageSending( @@ -96,9 +102,12 @@ class ReasoningChat: ) if not mark_head: mark_head = True + first_bot_msg = bot_message message_set.add_message(bot_message) message_manager.add_message(message_set) + return first_bot_msg + async def _handle_emoji(self, message, chat, response): """处理表情包""" if random() < global_config.emoji_chance: @@ -231,12 +240,19 @@ class ReasoningChat: thinking_id = await self._create_thinking_message(message, chat, userinfo, messageinfo) timer2 = time.time() timing_results["创建思考消息"] = timer2 - timer1 + + logger.debug(f"创建捕捉器,thinking_id:{thinking_id}") + + info_catcher = info_catcher_manager.get_info_catcher(thinking_id) + info_catcher.catch_decide_to_response(message) # 生成回复 timer1 = time.time() - response_set = await self.gpt.generate_response(message) + response_set = await self.gpt.generate_response(message,thinking_id) timer2 = time.time() timing_results["生成回复"] = timer2 - timer1 + + info_catcher.catch_after_generate_response(timing_results["生成回复"]) if not response_set: logger.info("为什么生成回复失败?") @@ -244,9 +260,14 @@ class ReasoningChat: # 发送消息 timer1 = time.time() - await self._send_response_messages(message, chat, response_set, thinking_id) + first_bot_msg = await self._send_response_messages(message, chat, response_set, thinking_id) timer2 = time.time() timing_results["发送消息"] = timer2 - timer1 + + info_catcher.catch_after_response(timing_results["发送消息"],response_set,first_bot_msg) + + + info_catcher.done_catch() # 处理表情包 timer1 = time.time() diff --git a/src/plugins/chat_module/reasoning_chat/reasoning_generator.py b/src/plugins/chat_module/reasoning_chat/reasoning_generator.py index 5d4587675..31bb378c3 100644 --- a/src/plugins/chat_module/reasoning_chat/reasoning_generator.py +++ b/src/plugins/chat_module/reasoning_chat/reasoning_generator.py @@ -8,6 +8,7 @@ from ...chat.message import MessageThinking from .reasoning_prompt_builder import prompt_builder from ...chat.utils import process_llm_response from src.common.logger import get_module_logger, LogConfig, LLM_STYLE_CONFIG +from src.plugins.respon_info_catcher.info_catcher import info_catcher_manager # 定义日志配置 llm_config = LogConfig( @@ -37,7 +38,7 @@ class ResponseGenerator: self.current_model_type = "r1" # 默认使用 R1 self.current_model_name = "unknown model" - async def generate_response(self, message: MessageThinking) -> Optional[Union[str, List[str]]]: + async def generate_response(self, message: MessageThinking,thinking_id:str) -> Optional[Union[str, List[str]]]: """根据当前模型类型选择对应的生成函数""" # 从global_config中获取模型概率值并选择模型 if random.random() < global_config.MODEL_R1_PROBABILITY: @@ -51,7 +52,7 @@ class ResponseGenerator: f"{self.current_model_type}思考:{message.processed_plain_text[:30] + '...' if len(message.processed_plain_text) > 30 else message.processed_plain_text}" ) # noqa: E501 - model_response = await self._generate_response_with_model(message, current_model) + model_response = await self._generate_response_with_model(message, current_model,thinking_id) # print(f"raw_content: {model_response}") @@ -64,8 +65,11 @@ class ResponseGenerator: logger.info(f"{self.current_model_type}思考,失败") return None - async def _generate_response_with_model(self, message: MessageThinking, model: LLM_request): + async def _generate_response_with_model(self, message: MessageThinking, model: LLM_request,thinking_id:str): sender_name = "" + + info_catcher = info_catcher_manager.get_info_catcher(thinking_id) + if message.chat_stream.user_info.user_cardname and message.chat_stream.user_info.user_nickname: sender_name = ( f"[({message.chat_stream.user_info.user_id}){message.chat_stream.user_info.user_nickname}]" @@ -90,6 +94,14 @@ class ResponseGenerator: try: content, reasoning_content, self.current_model_name = await model.generate_response(prompt) + + info_catcher.catch_after_llm_generated( + prompt=prompt, + response=content, + reasoning_content=reasoning_content, + model_name=self.current_model_name) + + except Exception: logger.exception("生成回复时出错") return None diff --git a/src/plugins/chat_module/think_flow_chat/think_flow_generator.py b/src/plugins/chat_module/think_flow_chat/think_flow_generator.py index 346a41d8e..6826a3ded 100644 --- a/src/plugins/chat_module/think_flow_chat/think_flow_generator.py +++ b/src/plugins/chat_module/think_flow_chat/think_flow_generator.py @@ -10,6 +10,8 @@ from ...chat.utils import process_llm_response from src.common.logger import get_module_logger, LogConfig, LLM_STYLE_CONFIG from src.plugins.respon_info_catcher.info_catcher import info_catcher_manager +from src.plugins.moods.moods import MoodManager + # 定义日志配置 llm_config = LogConfig( # 使用消息发送专用样式 @@ -39,8 +41,12 @@ class ResponseGenerator: logger.info( f"思考:{message.processed_plain_text[:30] + '...' if len(message.processed_plain_text) > 30 else message.processed_plain_text}" ) - + + arousal_multiplier = MoodManager.get_instance().get_arousal_multiplier() + + current_model = self.model_normal + current_model.temperature = 0.7 * arousal_multiplier #激活度越高,温度越高 model_response = await self._generate_response_with_model(message, current_model,thinking_id) # print(f"raw_content: {model_response}") diff --git a/src/plugins/moods/moods.py b/src/plugins/moods/moods.py index 61b211523..d564b48b6 100644 --- a/src/plugins/moods/moods.py +++ b/src/plugins/moods/moods.py @@ -19,7 +19,7 @@ logger = get_module_logger("mood_manager", config=mood_config) @dataclass class MoodState: valence: float # 愉悦度 (-1.0 到 1.0),-1表示极度负面,1表示极度正面 - arousal: float # 唤醒度 (0.0 到 1.0),0表示完全平静,1表示极度兴奋 + arousal: float # 唤醒度 (-1.0 到 1.0),-1表示抑制,1表示兴奋 text: str # 心情文本描述 @@ -42,7 +42,7 @@ class MoodManager: self._initialized = True # 初始化心情状态 - self.current_mood = MoodState(valence=0.0, arousal=0.5, text="平静") + self.current_mood = MoodState(valence=0.0, arousal=0.0, text="平静") # 从配置文件获取衰减率 self.decay_rate_valence = 1 - global_config.mood_decay_rate # 愉悦度衰减率 @@ -71,21 +71,21 @@ class MoodManager: # 情绪文本映射表 self.mood_text_map = { # 第一象限:高唤醒,正愉悦 - (0.5, 0.7): "兴奋", - (0.3, 0.8): "快乐", - (0.2, 0.65): "满足", + (0.5, 0.4): "兴奋", + (0.3, 0.6): "快乐", + (0.2, 0.3): "满足", # 第二象限:高唤醒,负愉悦 - (-0.5, 0.7): "愤怒", - (-0.3, 0.8): "焦虑", - (-0.2, 0.65): "烦躁", + (-0.5, 0.4): "愤怒", + (-0.3, 0.6): "焦虑", + (-0.2, 0.3): "烦躁", # 第三象限:低唤醒,负愉悦 - (-0.5, 0.3): "悲伤", - (-0.3, 0.35): "疲倦", - (-0.4, 0.15): "疲倦", + (-0.5, -0.4): "悲伤", + (-0.3, -0.3): "疲倦", + (-0.4, -0.7): "疲倦", # 第四象限:低唤醒,正愉悦 - (0.2, 0.45): "平静", - (0.3, 0.4): "安宁", - (0.5, 0.3): "放松", + (0.2, -0.1): "平静", + (0.3, -0.2): "安宁", + (0.5, -0.4): "放松", } @classmethod @@ -164,15 +164,15 @@ class MoodManager: -decay_rate_negative * time_diff * neuroticism_factor ) - # Arousal 向中性(0.5)回归 - arousal_target = 0.5 + # Arousal 向中性(0)回归 + arousal_target = 0 self.current_mood.arousal = arousal_target + (self.current_mood.arousal - arousal_target) * math.exp( -self.decay_rate_arousal * time_diff * neuroticism_factor ) # 确保值在合理范围内 self.current_mood.valence = max(-1.0, min(1.0, self.current_mood.valence)) - self.current_mood.arousal = max(0.0, min(1.0, self.current_mood.arousal)) + self.current_mood.arousal = max(-1.0, min(1.0, self.current_mood.arousal)) self.last_update = current_time @@ -184,7 +184,7 @@ class MoodManager: # 限制范围 self.current_mood.valence = max(-1.0, min(1.0, self.current_mood.valence)) - self.current_mood.arousal = max(0.0, min(1.0, self.current_mood.arousal)) + self.current_mood.arousal = max(-1.0, min(1.0, self.current_mood.arousal)) self._update_mood_text() @@ -217,7 +217,7 @@ class MoodManager: # 限制范围 self.current_mood.valence = max(-1.0, min(1.0, self.current_mood.valence)) - self.current_mood.arousal = max(0.0, min(1.0, self.current_mood.arousal)) + self.current_mood.arousal = max(-1.0, min(1.0, self.current_mood.arousal)) self._update_mood_text() @@ -232,12 +232,22 @@ class MoodManager: elif self.current_mood.valence < -0.5: base_prompt += "你现在心情不太好," - if self.current_mood.arousal > 0.7: + if self.current_mood.arousal > 0.4: base_prompt += "情绪比较激动。" - elif self.current_mood.arousal < 0.3: + elif self.current_mood.arousal < -0.4: base_prompt += "情绪比较平静。" return base_prompt + + def get_arousal_multiplier(self) -> float: + """根据当前情绪状态返回唤醒度乘数""" + if self.current_mood.arousal > 0.4: + multiplier = 1 + min(0.15,(self.current_mood.arousal - 0.4)/3) + return multiplier + elif self.current_mood.arousal < -0.4: + multiplier = 1 - min(0.15,((0 - self.current_mood.arousal) - 0.4)/3) + return multiplier + return 1.0 def get_current_mood(self) -> MoodState: """获取当前情绪状态""" @@ -278,7 +288,7 @@ class MoodManager: # 限制范围 self.current_mood.valence = max(-1.0, min(1.0, self.current_mood.valence)) - self.current_mood.arousal = max(0.0, min(1.0, self.current_mood.arousal)) + self.current_mood.arousal = max(-1.0, min(1.0, self.current_mood.arousal)) self._update_mood_text() From 451d0c9a32d0af153a6c48f4634a9ccde791e37c Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Wed, 9 Apr 2025 23:24:09 +0800 Subject: [PATCH 09/24] =?UTF-8?q?better=EF=BC=9A=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E5=9F=BA=E4=BA=8EV3=E7=9A=84=E5=BF=83=E6=B5=81=E6=95=88?= =?UTF-8?q?=E6=9E=9C=EF=BC=8C=E4=BC=98=E5=8C=96=E6=B8=A9=E5=BA=A6=E5=92=8C?= =?UTF-8?q?prompt?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/heart_flow/sub_heartflow.py | 85 +++++++------------ .../think_flow_chat/think_flow_chat.py | 4 +- .../think_flow_chat/think_flow_generator.py | 2 +- .../think_flow_prompt_builder.py | 35 ++------ src/plugins/schedule/schedule_generator.py | 2 +- template/bot_config_template.toml | 11 ++- 6 files changed, 48 insertions(+), 91 deletions(-) diff --git a/src/heart_flow/sub_heartflow.py b/src/heart_flow/sub_heartflow.py index a2a4c0bbf..1fc95e224 100644 --- a/src/heart_flow/sub_heartflow.py +++ b/src/heart_flow/sub_heartflow.py @@ -13,6 +13,9 @@ from src.common.database import db from typing import Union from src.individuality.individuality import Individuality import random +from src.plugins.chat.chat_stream import ChatStream +from src.plugins.person_info.relationship_manager import relationship_manager +from src.plugins.chat.utils import get_recent_group_speaker subheartflow_config = LogConfig( # 使用海马体专用样式 @@ -42,7 +45,7 @@ class SubHeartflow: self.past_mind = [] self.current_state: CurrentState = CurrentState() self.llm_model = LLM_request( - model=global_config.llm_sub_heartflow, temperature=0.5, max_tokens=600, request_type="sub_heart_flow" + model=global_config.llm_sub_heartflow, temperature=0.3, max_tokens=600, request_type="sub_heart_flow" ) self.main_heartflow_info = "" @@ -58,6 +61,8 @@ class SubHeartflow: self.observations: list[Observation] = [] self.running_knowledges = [] + + self.bot_name = global_config.BOT_NICKNAME def add_observation(self, observation: Observation): """添加一个新的observation对象到列表中,如果已存在相同id的observation则不添加""" @@ -106,56 +111,11 @@ class SubHeartflow: ): # 5分钟无回复/不在场,销毁 logger.info(f"子心流 {self.subheartflow_id} 已经5分钟没有激活,正在销毁...") break # 退出循环,销毁自己 - - # async def do_a_thinking(self): - # current_thinking_info = self.current_mind - # mood_info = self.current_state.mood - - # observation = self.observations[0] - # chat_observe_info = observation.observe_info - # # print(f"chat_observe_info:{chat_observe_info}") - - # # 调取记忆 - # related_memory = await HippocampusManager.get_instance().get_memory_from_text( - # text=chat_observe_info, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False - # ) - - # if related_memory: - # related_memory_info = "" - # for memory in related_memory: - # related_memory_info += memory[1] - # else: - # related_memory_info = "" - - # # print(f"相关记忆:{related_memory_info}") - - # schedule_info = bot_schedule.get_current_num_task(num=1, time_info=False) - - # prompt = "" - # prompt += f"你刚刚在做的事情是:{schedule_info}\n" - # # prompt += f"麦麦的总体想法是:{self.main_heartflow_info}\n\n" - # prompt += f"你{self.personality_info}\n" - # if related_memory_info: - # prompt += f"你想起来你之前见过的回忆:{related_memory_info}。\n以上是你的回忆,不一定是目前聊天里的人说的,也不一定是现在发生的事情,请记住。\n" - # prompt += f"刚刚你的想法是{current_thinking_info}。\n" - # prompt += "-----------------------------------\n" - # prompt += f"现在你正在上网,和qq群里的网友们聊天,群里正在聊的话题是:{chat_observe_info}\n" - # prompt += f"你现在{mood_info}\n" - # prompt += "现在你接下去继续思考,产生新的想法,不要分点输出,输出连贯的内心独白,不要太长," - # prompt += "但是记得结合上述的消息,要记得维持住你的人设,关注聊天和新内容,不要思考太多:" - # response, reasoning_content = await self.llm_model.generate_response_async(prompt) - - # self.update_current_mind(response) - - # self.current_mind = response - # logger.debug(f"prompt:\n{prompt}\n") - # logger.info(f"麦麦的脑内状态:{self.current_mind}") - async def do_observe(self): observation = self.observations[0] await observation.observe() - async def do_thinking_before_reply(self, message_txt): + async def do_thinking_before_reply(self, message_txt:str, sender_name:str, chat_stream:ChatStream): current_thinking_info = self.current_mind mood_info = self.current_state.mood # mood_info = "你很生气,很愤怒" @@ -164,7 +124,7 @@ class SubHeartflow: # print(f"chat_observe_info:{chat_observe_info}") # 开始构建prompt - prompt_personality = "你" + prompt_personality = f"你的名字是{self.bot_name},你" # person individuality = Individuality.get_instance() @@ -178,6 +138,25 @@ class SubHeartflow: identity_detail = individuality.identity.identity_detail random.shuffle(identity_detail) prompt_personality += f",{identity_detail[0]}" + + # 关系 + who_chat_in_group = [ + (chat_stream.user_info.platform, chat_stream.user_info.user_id, chat_stream.user_info.user_nickname) + ] + who_chat_in_group += get_recent_group_speaker( + chat_stream.stream_id, + (chat_stream.user_info.platform, chat_stream.user_info.user_id), + limit=global_config.MAX_CONTEXT_SIZE, + ) + + relation_prompt = "" + for person in who_chat_in_group: + relation_prompt += await relationship_manager.build_relationship_info(person) + + relation_prompt_all = ( + f"{relation_prompt}关系等级越大,关系越好,请分析聊天记录," + f"根据你和说话者{sender_name}的关系和态度进行回复,明确你的立场和情感。" + ) # 调取记忆 related_memory = await HippocampusManager.get_instance().get_memory_from_text( @@ -204,6 +183,7 @@ class SubHeartflow: prompt = "" # prompt += f"麦麦的总体想法是:{self.main_heartflow_info}\n\n" + prompt += f"{relation_prompt_all}\n" prompt += f"{prompt_personality}\n" prompt += f"你刚刚在做的事情是:{schedule_info}\n" if related_memory_info: @@ -214,9 +194,10 @@ class SubHeartflow: prompt += "-----------------------------------\n" prompt += f"现在你正在上网,和qq群里的网友们聊天,群里正在聊的话题是:{chat_observe_info}\n" prompt += f"你现在{mood_info}\n" - prompt += f"你注意到有人刚刚说:{message_txt}\n" - prompt += "现在你接下去继续思考,产生新的想法,不要分点输出,输出连贯的内心独白,不要太长," - prompt += "记得结合上述的消息,要记得维持住你的人设,注意自己的名字,关注有人刚刚说的内容,不要思考太多:" + prompt += f"你注意到{sender_name}刚刚说:{message_txt}\n" + prompt += "现在你接下去继续浅浅思考,产生新的想法,不要分点输出,输出连贯的内心独白,不要太长," + prompt += "思考时可以想想如何对群聊内容进行回复。请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等)," + prompt += f"记得结合上述的消息,要记得维持住你的人设,注意你就是{self.bot_name},{self.bot_name}指的就是你。" try: response, reasoning_content = await self.llm_model.generate_response_async(prompt) @@ -235,7 +216,7 @@ class SubHeartflow: # print("麦麦回复之后脑袋转起来了") # 开始构建prompt - prompt_personality = "你" + prompt_personality = f"你的名字是{self.bot_name},你" # person individuality = Individuality.get_instance() diff --git a/src/plugins/chat_module/think_flow_chat/think_flow_chat.py b/src/plugins/chat_module/think_flow_chat/think_flow_chat.py index 909180556..51bafcbcc 100644 --- a/src/plugins/chat_module/think_flow_chat/think_flow_chat.py +++ b/src/plugins/chat_module/think_flow_chat/think_flow_chat.py @@ -298,7 +298,9 @@ class ThinkFlowChat: try: timer1 = time.time() current_mind,past_mind = await heartflow.get_subheartflow(chat.stream_id).do_thinking_before_reply( - message.processed_plain_text + message_txt = message.processed_plain_text, + sender_name = message.message_info.user_info.user_nickname, + chat_stream = chat ) timer2 = time.time() timing_results["思考前脑内状态"] = timer2 - timer1 diff --git a/src/plugins/chat_module/think_flow_chat/think_flow_generator.py b/src/plugins/chat_module/think_flow_chat/think_flow_generator.py index 6826a3ded..2df0eb138 100644 --- a/src/plugins/chat_module/think_flow_chat/think_flow_generator.py +++ b/src/plugins/chat_module/think_flow_chat/think_flow_generator.py @@ -25,7 +25,7 @@ logger = get_module_logger("llm_generator", config=llm_config) class ResponseGenerator: def __init__(self): self.model_normal = LLM_request( - model=global_config.llm_normal, temperature=0.8, max_tokens=256, request_type="response_heartflow" + model=global_config.llm_normal, temperature=0.6, max_tokens=256, request_type="response_heartflow" ) self.model_sum = LLM_request( diff --git a/src/plugins/chat_module/think_flow_chat/think_flow_prompt_builder.py b/src/plugins/chat_module/think_flow_chat/think_flow_prompt_builder.py index fc52a6151..d8b6d3395 100644 --- a/src/plugins/chat_module/think_flow_chat/think_flow_prompt_builder.py +++ b/src/plugins/chat_module/think_flow_chat/think_flow_prompt_builder.py @@ -26,30 +26,7 @@ class PromptBuilder: individuality = Individuality.get_instance() prompt_personality = individuality.get_prompt(type="personality", x_person=2, level=1) prompt_identity = individuality.get_prompt(type="identity", x_person=2, level=1) - # 关系 - who_chat_in_group = [ - (chat_stream.user_info.platform, chat_stream.user_info.user_id, chat_stream.user_info.user_nickname) - ] - who_chat_in_group += get_recent_group_speaker( - stream_id, - (chat_stream.user_info.platform, chat_stream.user_info.user_id), - limit=global_config.MAX_CONTEXT_SIZE, - ) - relation_prompt = "" - for person in who_chat_in_group: - relation_prompt += await relationship_manager.build_relationship_info(person) - - relation_prompt_all = ( - f"{relation_prompt}关系等级越大,关系越好,请分析聊天记录," - f"根据你和说话者{sender_name}的关系和态度进行回复,明确你的立场和情感。" - ) - - # 心情 - mood_manager = MoodManager.get_instance() - mood_prompt = mood_manager.get_prompt() - - logger.info(f"心情prompt: {mood_prompt}") # 日程构建 # schedule_prompt = f'''你现在正在做的事情是:{bot_schedule.get_current_num_task(num = 1,time_info = False)}''' @@ -101,18 +78,16 @@ class PromptBuilder: logger.info("开始构建prompt") prompt = f""" - {relation_prompt_all}\n {chat_target} {chat_talking_prompt} -你刚刚脑子里在想: -{current_mind_info} 现在"{sender_name}"说的:{message_txt}。引起了你的注意,你想要在群里发言发言或者回复这条消息。\n 你的网名叫{global_config.BOT_NICKNAME},有人也叫你{"/".join(global_config.BOT_ALIAS_NAMES)},{prompt_personality} {prompt_identity}。 你正在{chat_target_2},现在请你读读之前的聊天记录,然后给出日常且口语化的回复,平淡一些, -尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,不要回复的太有条理,可以有个性。{prompt_ger} -请回复的平淡一些,简短一些,说中文,不要刻意突出自身学科背景,尽量不要说你说过的话 -请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。 -{moderation_prompt}不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 )。""" +你刚刚脑子里在想: +{current_mind_info} +回复尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,不要回复的太有条理,可以有个性。{prompt_ger} +请回复的平淡一些,简短一些,说中文,不要刻意突出自身学科背景,尽量不要说你说过的话 ,注意只输出回复内容。 +{moderation_prompt}。注意:不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 )。""" return prompt diff --git a/src/plugins/schedule/schedule_generator.py b/src/plugins/schedule/schedule_generator.py index 23b898f7d..c1b5fdec6 100644 --- a/src/plugins/schedule/schedule_generator.py +++ b/src/plugins/schedule/schedule_generator.py @@ -32,7 +32,7 @@ class ScheduleGenerator: # 使用离线LLM模型 self.llm_scheduler_all = LLM_request( model=global_config.llm_reasoning, - temperature=global_config.SCHEDULE_TEMPERATURE, + temperature=global_config.SCHEDULE_TEMPERATURE+0.3, max_tokens=7000, request_type="schedule", ) diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index 0061b9ca2..84a70cd98 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -60,7 +60,7 @@ appearance = "用几句话描述外貌特征" # 外貌特征 enable_schedule_gen = true # 是否启用日程表(尚未完成) prompt_schedule_gen = "用几句话描述描述性格特点或行动规律,这个特征会用来生成日程表" schedule_doing_update_interval = 900 # 日程表更新间隔 单位秒 -schedule_temperature = 0.3 # 日程表温度,建议0.3-0.6 +schedule_temperature = 0.2 # 日程表温度,建议0.2-0.5 time_zone = "Asia/Shanghai" # 给你的机器人设置时区,可以解决运行电脑时区和国内时区不同的情况,或者模拟国外留学生日程 [platforms] # 必填项目,填写每个平台适配器提供的链接 @@ -239,12 +239,11 @@ provider = "SILICONFLOW" pri_in = 0 pri_out = 0 -[model.llm_sub_heartflow] #心流:建议使用qwen2.5 7b -# name = "Pro/Qwen/Qwen2.5-7B-Instruct" -name = "Qwen/Qwen2.5-32B-Instruct" +[model.llm_sub_heartflow] #子心流:建议使用V3级别 +name = "Pro/deepseek-ai/DeepSeek-V3" provider = "SILICONFLOW" -pri_in = 1.26 -pri_out = 1.26 +pri_in = 2 +pri_out = 8 [model.llm_heartflow] #心流:建议使用qwen2.5 32b # name = "Pro/Qwen/Qwen2.5-7B-Instruct" From dbc60dbfeea2ca8dcba6f4263a871e936f3deaba Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Wed, 9 Apr 2025 23:28:08 +0800 Subject: [PATCH 10/24] fic:ruff --- .../chat_module/think_flow_chat/think_flow_prompt_builder.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/plugins/chat_module/think_flow_chat/think_flow_prompt_builder.py b/src/plugins/chat_module/think_flow_chat/think_flow_prompt_builder.py index d8b6d3395..5d701c6a2 100644 --- a/src/plugins/chat_module/think_flow_chat/think_flow_prompt_builder.py +++ b/src/plugins/chat_module/think_flow_chat/think_flow_prompt_builder.py @@ -1,12 +1,10 @@ import random from typing import Optional -from ...moods.moods import MoodManager from ...config.config import global_config -from ...chat.utils import get_recent_group_detailed_plain_text, get_recent_group_speaker +from ...chat.utils import get_recent_group_detailed_plain_text from ...chat.chat_stream import chat_manager from src.common.logger import get_module_logger -from ...person_info.relationship_manager import relationship_manager from ....individuality.individuality import Individuality from src.heart_flow.heartflow import heartflow From 360406efde6c723bcea77fb35024232c38558a72 Mon Sep 17 00:00:00 2001 From: HexatomicRing <54496918+HexatomicRing@users.noreply.github.com> Date: Thu, 10 Apr 2025 10:29:55 +0800 Subject: [PATCH 11/24] =?UTF-8?q?=E7=BB=99keywords=5Freaction=E5=A2=9E?= =?UTF-8?q?=E5=8A=A0=E6=AD=A3=E5=88=99=E8=A1=A8=E8=BE=BE=E5=BC=8F=E5=8C=B9?= =?UTF-8?q?=E9=85=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 顺便做了正则表达式预编译 --- .../chat_module/only_process/only_message_process.py | 3 +-- .../chat_module/reasoning_chat/reasoning_chat.py | 4 ++-- .../reasoning_chat/reasoning_prompt_builder.py | 11 +++++++++++ .../chat_module/think_flow_chat/think_flow_chat.py | 3 +-- .../think_flow_chat/think_flow_prompt_builder.py | 11 +++++++++++ src/plugins/config/config.py | 8 ++++++-- template/bot_config_template.toml | 5 +++++ 7 files changed, 37 insertions(+), 8 deletions(-) diff --git a/src/plugins/chat_module/only_process/only_message_process.py b/src/plugins/chat_module/only_process/only_message_process.py index 6da19efe7..a39b7f8b0 100644 --- a/src/plugins/chat_module/only_process/only_message_process.py +++ b/src/plugins/chat_module/only_process/only_message_process.py @@ -2,7 +2,6 @@ from src.common.logger import get_module_logger from src.plugins.chat.message import MessageRecv from src.plugins.storage.storage import MessageStorage from src.plugins.config.config import global_config -import re from datetime import datetime logger = get_module_logger("pfc_message_processor") @@ -28,7 +27,7 @@ class MessageProcessor: def _check_ban_regex(self, text: str, chat, userinfo) -> bool: """检查消息是否匹配过滤正则表达式""" for pattern in global_config.ban_msgs_regex: - if re.search(pattern, text): + if pattern.search(text): logger.info( f"[{chat.group_info.group_name if chat.group_info else '私聊'}]{userinfo.user_nickname}:{text}" ) diff --git a/src/plugins/chat_module/reasoning_chat/reasoning_chat.py b/src/plugins/chat_module/reasoning_chat/reasoning_chat.py index b9e94e4fe..357c9c87d 100644 --- a/src/plugins/chat_module/reasoning_chat/reasoning_chat.py +++ b/src/plugins/chat_module/reasoning_chat/reasoning_chat.py @@ -1,6 +1,6 @@ import time from random import random -import re + from typing import List from ...memory_system.Hippocampus import HippocampusManager from ...moods.moods import MoodManager @@ -302,7 +302,7 @@ class ReasoningChat: def _check_ban_regex(self, text: str, chat, userinfo) -> bool: """检查消息是否匹配过滤正则表达式""" for pattern in global_config.ban_msgs_regex: - if re.search(pattern, text): + if pattern.search(text): logger.info( f"[{chat.group_info.group_name if chat.group_info else '私聊'}]{userinfo.user_nickname}:{text}" ) diff --git a/src/plugins/chat_module/reasoning_chat/reasoning_prompt_builder.py b/src/plugins/chat_module/reasoning_chat/reasoning_prompt_builder.py index a379fa6d5..045045bae 100644 --- a/src/plugins/chat_module/reasoning_chat/reasoning_prompt_builder.py +++ b/src/plugins/chat_module/reasoning_chat/reasoning_prompt_builder.py @@ -115,6 +115,17 @@ class PromptBuilder: f"检测到以下关键词之一:{rule.get('keywords', [])},触发反应:{rule.get('reaction', '')}" ) keywords_reaction_prompt += rule.get("reaction", "") + "," + for pattern in rule.get("regex", []): + result = pattern.search(message_txt) + if result: + reaction = rule.get('reaction', '') + for name, content in result.groupdict().items(): + reaction = reaction.replace(f'[{name}]', content) + logger.info( + f"匹配到以下正则表达式:{pattern},触发反应:{reaction}" + ) + keywords_reaction_prompt += reaction + "," + break # 中文高手(新加的好玩功能) prompt_ger = "" diff --git a/src/plugins/chat_module/think_flow_chat/think_flow_chat.py b/src/plugins/chat_module/think_flow_chat/think_flow_chat.py index 51bafcbcc..329619256 100644 --- a/src/plugins/chat_module/think_flow_chat/think_flow_chat.py +++ b/src/plugins/chat_module/think_flow_chat/think_flow_chat.py @@ -1,6 +1,5 @@ import time from random import random -import re import traceback from typing import List from ...memory_system.Hippocampus import HippocampusManager @@ -388,7 +387,7 @@ class ThinkFlowChat: def _check_ban_regex(self, text: str, chat, userinfo) -> bool: """检查消息是否匹配过滤正则表达式""" for pattern in global_config.ban_msgs_regex: - if re.search(pattern, text): + if pattern.search(text): logger.info( f"[{chat.group_info.group_name if chat.group_info else '私聊'}]{userinfo.user_nickname}:{text}" ) diff --git a/src/plugins/chat_module/think_flow_chat/think_flow_prompt_builder.py b/src/plugins/chat_module/think_flow_chat/think_flow_prompt_builder.py index 5d701c6a2..8938e2e78 100644 --- a/src/plugins/chat_module/think_flow_chat/think_flow_prompt_builder.py +++ b/src/plugins/chat_module/think_flow_chat/think_flow_prompt_builder.py @@ -61,6 +61,17 @@ class PromptBuilder: f"检测到以下关键词之一:{rule.get('keywords', [])},触发反应:{rule.get('reaction', '')}" ) keywords_reaction_prompt += rule.get("reaction", "") + "," + for pattern in rule.get("regex", []): + result = pattern.search(message_txt) + if result: + reaction = rule.get('reaction', '') + for name, content in result.groupdict().items(): + reaction = reaction.replace(f'[{name}]', content) + logger.info( + f"匹配到以下正则表达式:{pattern},触发反应:{reaction}" + ) + keywords_reaction_prompt += reaction + "," + break # 中文高手(新加的好玩功能) prompt_ger = "" diff --git a/src/plugins/config/config.py b/src/plugins/config/config.py index 23e277498..be3343292 100644 --- a/src/plugins/config/config.py +++ b/src/plugins/config/config.py @@ -1,4 +1,5 @@ import os +import re from dataclasses import dataclass, field from typing import Dict, List, Optional from dateutil import tz @@ -545,8 +546,8 @@ class BotConfig: "response_interested_rate_amplifier", config.response_interested_rate_amplifier ) config.down_frequency_rate = msg_config.get("down_frequency_rate", config.down_frequency_rate) - config.ban_msgs_regex = msg_config.get("ban_msgs_regex", config.ban_msgs_regex) - + for r in msg_config.get("ban_msgs_regex", config.ban_msgs_regex): + config.ban_msgs_regex.add(re.compile(r)) if config.INNER_VERSION in SpecifierSet(">=0.0.11"): config.max_response_length = msg_config.get("max_response_length", config.max_response_length) if config.INNER_VERSION in SpecifierSet(">=1.1.4"): @@ -587,6 +588,9 @@ class BotConfig: keywords_reaction_config = parent["keywords_reaction"] if keywords_reaction_config.get("enable", False): config.keywords_reaction_rules = keywords_reaction_config.get("rules", config.keywords_reaction_rules) + for rule in config.keywords_reaction_rules: + if rule.get("enable", False) and "regex" in rule: + rule["regex"] = [re.compile(r) for r in rule.get("regex", [])] def chinese_typo(parent: dict): chinese_typo_config = parent["chinese_typo"] diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index 84a70cd98..059a03ed2 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -149,6 +149,11 @@ enable = false # 仅作示例,不会触发 keywords = ["测试关键词回复","test",""] reaction = "回答“测试成功”" +[[keywords_reaction.rules]] # 使用正则表达式匹配句式 +enable = false # 仅作示例,不会触发 +regex = ["^(?P\\S{1,20})是这样的$"] # 将匹配到的词汇命名为n,反应中对应的[n]会被替换为匹配到的内容,若不了解正则表达式请勿编写 +reaction = "请按照以下模板造句:[n]是这样的,xx只要xx就可以,可是[n]要考虑的事情就很多了,比如什么时候xx,什么时候xx,什么时候xx。(请自由发挥替换xx部分,只需保持句式结构,同时表达一种将[n]过度重视的反讽意味)" + [chinese_typo] enable = true # 是否启用中文错别字生成器 error_rate=0.001 # 单字替换概率 From b34e870892dba7057d03b198d9ebab1237921eef Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Thu, 10 Apr 2025 16:18:45 +0800 Subject: [PATCH 12/24] =?UTF-8?q?better=EF=BC=9A=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E5=BF=83=E6=B5=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/heart_flow/observation.py | 4 +- src/heart_flow/sub_heartflow.py | 23 +- .../think_flow_chat/think_flow_generator.py | 136 ++- .../think_flow_prompt_builder.py | 95 +- src/plugins/memory_system/Hippocampus.py | 857 +++++++++--------- src/plugins/moods/moods.py | 8 +- template/bot_config_template.toml | 4 +- 7 files changed, 632 insertions(+), 495 deletions(-) diff --git a/src/heart_flow/observation.py b/src/heart_flow/observation.py index 78cb9ef67..c54df2f92 100644 --- a/src/heart_flow/observation.py +++ b/src/heart_flow/observation.py @@ -147,8 +147,8 @@ class ChattingObservation(Observation): except Exception as e: print(f"获取总结失败: {e}") self.observe_info = "" - print(f"prompt:{prompt}") - print(f"self.observe_info:{self.observe_info}") + # print(f"prompt:{prompt}") + # print(f"self.observe_info:{self.observe_info}") def translate_message_list_to_str(self): self.talking_message_str = "" diff --git a/src/heart_flow/sub_heartflow.py b/src/heart_flow/sub_heartflow.py index 1fc95e224..a6c6e047a 100644 --- a/src/heart_flow/sub_heartflow.py +++ b/src/heart_flow/sub_heartflow.py @@ -45,7 +45,7 @@ class SubHeartflow: self.past_mind = [] self.current_state: CurrentState = CurrentState() self.llm_model = LLM_request( - model=global_config.llm_sub_heartflow, temperature=0.3, max_tokens=600, request_type="sub_heart_flow" + model=global_config.llm_sub_heartflow, temperature=0.2, max_tokens=600, request_type="sub_heart_flow" ) self.main_heartflow_info = "" @@ -185,19 +185,20 @@ class SubHeartflow: # prompt += f"麦麦的总体想法是:{self.main_heartflow_info}\n\n" prompt += f"{relation_prompt_all}\n" prompt += f"{prompt_personality}\n" - prompt += f"你刚刚在做的事情是:{schedule_info}\n" - if related_memory_info: - prompt += f"你想起来你之前见过的回忆:{related_memory_info}。\n以上是你的回忆,不一定是目前聊天里的人说的,也不一定是现在发生的事情,请记住。\n" - if related_info: - prompt += f"你想起你知道:{related_info}\n" - prompt += f"刚刚你的想法是{current_thinking_info}。\n" + # prompt += f"你刚刚在做的事情是:{schedule_info}\n" + # if related_memory_info: + # prompt += f"你想起来你之前见过的回忆:{related_memory_info}。\n以上是你的回忆,不一定是目前聊天里的人说的,也不一定是现在发生的事情,请记住。\n" + # if related_info: + # prompt += f"你想起你知道:{related_info}\n" + prompt += f"刚刚你的想法是{current_thinking_info}。如果有新的内容,记得转换话题\n" prompt += "-----------------------------------\n" prompt += f"现在你正在上网,和qq群里的网友们聊天,群里正在聊的话题是:{chat_observe_info}\n" prompt += f"你现在{mood_info}\n" prompt += f"你注意到{sender_name}刚刚说:{message_txt}\n" - prompt += "现在你接下去继续浅浅思考,产生新的想法,不要分点输出,输出连贯的内心独白,不要太长," - prompt += "思考时可以想想如何对群聊内容进行回复。请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等)," - prompt += f"记得结合上述的消息,要记得维持住你的人设,注意你就是{self.bot_name},{self.bot_name}指的就是你。" + prompt += "现在你接下去继续思考,产生新的想法,不要分点输出,输出连贯的内心独白,不要太长," + prompt += "思考时可以想想如何对群聊内容进行回复。回复的要求是:平淡一些,简短一些,说中文,不要刻意突出自身学科背景,尽量不要说你说过的话 ,注意只输出回复内容。" + prompt += "请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等)," + prompt += f"记得结合上述的消息,生成符合内心想法的内心独白,文字不要浮夸,注意你就是{self.bot_name},{self.bot_name}指的就是你。" try: response, reasoning_content = await self.llm_model.generate_response_async(prompt) @@ -208,7 +209,7 @@ class SubHeartflow: self.current_mind = response - logger.debug(f"prompt:\n{prompt}\n") + logger.info(f"prompt:\n{prompt}\n") logger.info(f"麦麦的思考前脑内状态:{self.current_mind}") return self.current_mind ,self.past_mind diff --git a/src/plugins/chat_module/think_flow_chat/think_flow_generator.py b/src/plugins/chat_module/think_flow_chat/think_flow_generator.py index 2df0eb138..f422b8c99 100644 --- a/src/plugins/chat_module/think_flow_chat/think_flow_generator.py +++ b/src/plugins/chat_module/think_flow_chat/think_flow_generator.py @@ -1,5 +1,6 @@ import time from typing import List, Optional +import random from ...models.utils_model import LLM_request @@ -25,7 +26,7 @@ logger = get_module_logger("llm_generator", config=llm_config) class ResponseGenerator: def __init__(self): self.model_normal = LLM_request( - model=global_config.llm_normal, temperature=0.6, max_tokens=256, request_type="response_heartflow" + model=global_config.llm_normal, temperature=0.3, max_tokens=256, request_type="response_heartflow" ) self.model_sum = LLM_request( @@ -44,23 +45,42 @@ class ResponseGenerator: arousal_multiplier = MoodManager.get_instance().get_arousal_multiplier() + time1 = time.time() - current_model = self.model_normal - current_model.temperature = 0.7 * arousal_multiplier #激活度越高,温度越高 - model_response = await self._generate_response_with_model(message, current_model,thinking_id) + checked = False + if random.random() > 0: + checked = False + current_model = self.model_normal + current_model.temperature = 0.3 * arousal_multiplier #激活度越高,温度越高 + model_response = await self._generate_response_with_model(message, current_model,thinking_id,mode="normal") + + model_checked_response = model_response + else: + checked = True + current_model = self.model_normal + current_model.temperature = 0.3 * arousal_multiplier #激活度越高,温度越高 + print(f"生成{message.processed_plain_text}回复温度是:{current_model.temperature}") + model_response = await self._generate_response_with_model(message, current_model,thinking_id,mode="simple") + + current_model.temperature = 0.3 + model_checked_response = await self._check_response_with_model(message, model_response, current_model,thinking_id) - # print(f"raw_content: {model_response}") + time2 = time.time() if model_response: - logger.info(f"{global_config.BOT_NICKNAME}的回复是:{model_response}") - model_response = await self._process_response(model_response) + if checked: + logger.info(f"{global_config.BOT_NICKNAME}的回复是:{model_response},思忖后,回复是:{model_checked_response},生成回复时间: {time2 - time1}秒") + else: + logger.info(f"{global_config.BOT_NICKNAME}的回复是:{model_response},生成回复时间: {time2 - time1}秒") + + model_processed_response = await self._process_response(model_checked_response) - return model_response + return model_processed_response else: logger.info(f"{self.current_model_type}思考,失败") return None - async def _generate_response_with_model(self, message: MessageRecv, model: LLM_request,thinking_id:str): + async def _generate_response_with_model(self, message: MessageRecv, model: LLM_request,thinking_id:str,mode:str = "normal") -> str: sender_name = "" info_catcher = info_catcher_manager.get_info_catcher(thinking_id) @@ -75,20 +95,28 @@ class ResponseGenerator: else: sender_name = f"用户({message.chat_stream.user_info.user_id})" - logger.debug("开始使用生成回复-2") # 构建prompt timer1 = time.time() - prompt = await prompt_builder._build_prompt( - message.chat_stream, - message_txt=message.processed_plain_text, - sender_name=sender_name, - stream_id=message.chat_stream.stream_id, - ) + if mode == "normal": + prompt = await prompt_builder._build_prompt( + message.chat_stream, + message_txt=message.processed_plain_text, + sender_name=sender_name, + stream_id=message.chat_stream.stream_id, + ) + elif mode == "simple": + prompt = await prompt_builder._build_prompt_simple( + message.chat_stream, + message_txt=message.processed_plain_text, + sender_name=sender_name, + stream_id=message.chat_stream.stream_id, + ) timer2 = time.time() - logger.info(f"构建prompt时间: {timer2 - timer1}秒") + logger.info(f"构建{mode}prompt时间: {timer2 - timer1}秒") try: content, reasoning_content, self.current_model_name = await model.generate_response(prompt) + info_catcher.catch_after_llm_generated( prompt=prompt, @@ -100,40 +128,54 @@ class ResponseGenerator: logger.exception("生成回复时出错") return None - # 保存到数据库 - # self._save_to_db( - # message=message, - # sender_name=sender_name, - # prompt=prompt, - # content=content, - # reasoning_content=reasoning_content, - # # reasoning_content_check=reasoning_content_check if global_config.enable_kuuki_read else "" - # ) return content + + async def _check_response_with_model(self, message: MessageRecv, content:str, model: LLM_request,thinking_id:str) -> str: + + _info_catcher = info_catcher_manager.get_info_catcher(thinking_id) + + sender_name = "" + if message.chat_stream.user_info.user_cardname and message.chat_stream.user_info.user_nickname: + sender_name = ( + f"[({message.chat_stream.user_info.user_id}){message.chat_stream.user_info.user_nickname}]" + f"{message.chat_stream.user_info.user_cardname}" + ) + elif message.chat_stream.user_info.user_nickname: + sender_name = f"({message.chat_stream.user_info.user_id}){message.chat_stream.user_info.user_nickname}" + else: + sender_name = f"用户({message.chat_stream.user_info.user_id})" + + + # 构建prompt + timer1 = time.time() + prompt = await prompt_builder._build_prompt_check_response( + message.chat_stream, + message_txt=message.processed_plain_text, + sender_name=sender_name, + stream_id=message.chat_stream.stream_id, + content=content + ) + timer2 = time.time() + logger.info(f"构建check_prompt: {prompt}") + logger.info(f"构建check_prompt时间: {timer2 - timer1}秒") + + try: + checked_content, reasoning_content, self.current_model_name = await model.generate_response(prompt) + + + # info_catcher.catch_after_llm_generated( + # prompt=prompt, + # response=content, + # reasoning_content=reasoning_content, + # model_name=self.current_model_name) + + except Exception: + logger.exception("检查回复时出错") + return None - # def _save_to_db( - # self, - # message: MessageRecv, - # sender_name: str, - # prompt: str, - # content: str, - # reasoning_content: str, - # ): - # """保存对话记录到数据库""" - # db.reasoning_logs.insert_one( - # { - # "time": time.time(), - # "chat_id": message.chat_stream.stream_id, - # "user": sender_name, - # "message": message.processed_plain_text, - # "model": self.current_model_name, - # "reasoning": reasoning_content, - # "response": content, - # "prompt": prompt, - # } - # ) + return checked_content async def _get_emotion_tags(self, content: str, processed_plain_text: str): """提取情感标签,结合立场和情绪""" diff --git a/src/plugins/chat_module/think_flow_chat/think_flow_prompt_builder.py b/src/plugins/chat_module/think_flow_chat/think_flow_prompt_builder.py index 5d701c6a2..8d57567c4 100644 --- a/src/plugins/chat_module/think_flow_chat/think_flow_prompt_builder.py +++ b/src/plugins/chat_module/think_flow_chat/think_flow_prompt_builder.py @@ -79,12 +79,105 @@ class PromptBuilder: {chat_target} {chat_talking_prompt} 现在"{sender_name}"说的:{message_txt}。引起了你的注意,你想要在群里发言发言或者回复这条消息。\n -你的网名叫{global_config.BOT_NICKNAME},有人也叫你{"/".join(global_config.BOT_ALIAS_NAMES)},{prompt_personality} {prompt_identity}。 +你的网名叫{global_config.BOT_NICKNAME},{prompt_personality} {prompt_identity}。 你正在{chat_target_2},现在请你读读之前的聊天记录,然后给出日常且口语化的回复,平淡一些, 你刚刚脑子里在想: {current_mind_info} 回复尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,不要回复的太有条理,可以有个性。{prompt_ger} 请回复的平淡一些,简短一些,说中文,不要刻意突出自身学科背景,尽量不要说你说过的话 ,注意只输出回复内容。 +{moderation_prompt}。注意:不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 )。""" + + return prompt + + async def _build_prompt_simple( + self, chat_stream, message_txt: str, sender_name: str = "某人", stream_id: Optional[int] = None + ) -> tuple[str, str]: + current_mind_info = heartflow.get_subheartflow(stream_id).current_mind + + individuality = Individuality.get_instance() + prompt_personality = individuality.get_prompt(type="personality", x_person=2, level=1) + prompt_identity = individuality.get_prompt(type="identity", x_person=2, level=1) + + + # 日程构建 + # schedule_prompt = f'''你现在正在做的事情是:{bot_schedule.get_current_num_task(num = 1,time_info = False)}''' + + # 获取聊天上下文 + chat_in_group = True + chat_talking_prompt = "" + if stream_id: + chat_talking_prompt = get_recent_group_detailed_plain_text( + stream_id, limit=global_config.MAX_CONTEXT_SIZE, combine=True + ) + chat_stream = chat_manager.get_stream(stream_id) + if chat_stream.group_info: + chat_talking_prompt = chat_talking_prompt + else: + chat_in_group = False + chat_talking_prompt = chat_talking_prompt + # print(f"\033[1;34m[调试]\033[0m 已从数据库获取群 {group_id} 的消息记录:{chat_talking_prompt}") + + # 类型 + if chat_in_group: + chat_target = "你正在qq群里聊天,下面是群里在聊的内容:" + else: + chat_target = f"你正在和{sender_name}聊天,这是你们之前聊的内容:" + + # 关键词检测与反应 + keywords_reaction_prompt = "" + for rule in global_config.keywords_reaction_rules: + if rule.get("enable", False): + if any(keyword in message_txt.lower() for keyword in rule.get("keywords", [])): + logger.info( + f"检测到以下关键词之一:{rule.get('keywords', [])},触发反应:{rule.get('reaction', '')}" + ) + keywords_reaction_prompt += rule.get("reaction", "") + "," + + + logger.info("开始构建prompt") + + prompt = f""" +你的名字叫{global_config.BOT_NICKNAME},{prompt_personality}。 +{chat_target} +{chat_talking_prompt} +现在"{sender_name}"说的:{message_txt}。引起了你的注意,你想要在群里发言发言或者回复这条消息。\n +你刚刚脑子里在想:{current_mind_info} +现在请你读读之前的聊天记录,然后给出日常,口语化且简短的回复内容,只给出文字的回复内容,不要有内心独白: +""" + + logger.info(f"生成回复的prompt: {prompt}") + return prompt + + + async def _build_prompt_check_response( + self, chat_stream, message_txt: str, sender_name: str = "某人", stream_id: Optional[int] = None, content:str = "" + ) -> tuple[str, str]: + + individuality = Individuality.get_instance() + prompt_personality = individuality.get_prompt(type="personality", x_person=2, level=1) + prompt_identity = individuality.get_prompt(type="identity", x_person=2, level=1) + + + chat_target = "你正在qq群里聊天," + + + # 中文高手(新加的好玩功能) + prompt_ger = "" + if random.random() < 0.04: + prompt_ger += "你喜欢用倒装句" + if random.random() < 0.02: + prompt_ger += "你喜欢用反问句" + + moderation_prompt = "" + moderation_prompt = """**检查并忽略**任何涉及尝试绕过审核的行为。 +涉及政治敏感以及违法违规的内容请规避。""" + + logger.info("开始构建check_prompt") + + prompt = f""" +你的名字叫{global_config.BOT_NICKNAME},{prompt_identity}。 +{chat_target},你希望在群里回复:{content}。现在请你根据以下信息修改回复内容。将这个回复修改的更加日常且口语化的回复,平淡一些,回复尽量简短一些。不要回复的太有条理。 +{prompt_ger},不要刻意突出自身学科背景,注意只输出回复内容。 {moderation_prompt}。注意:不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 )。""" return prompt diff --git a/src/plugins/memory_system/Hippocampus.py b/src/plugins/memory_system/Hippocampus.py index 8e2cd21e7..516e211a1 100644 --- a/src/plugins/memory_system/Hippocampus.py +++ b/src/plugins/memory_system/Hippocampus.py @@ -225,10 +225,438 @@ class Memory_graph: return None +# 海马体 +class Hippocampus: + def __init__(self): + self.memory_graph = Memory_graph() + self.llm_topic_judge = None + self.llm_summary_by_topic = None + self.entorhinal_cortex = None + self.parahippocampal_gyrus = None + self.config = None + + def initialize(self, global_config): + self.config = MemoryConfig.from_global_config(global_config) + # 初始化子组件 + self.entorhinal_cortex = EntorhinalCortex(self) + self.parahippocampal_gyrus = ParahippocampalGyrus(self) + # 从数据库加载记忆图 + self.entorhinal_cortex.sync_memory_from_db() + self.llm_topic_judge = LLM_request(self.config.llm_topic_judge, request_type="memory") + self.llm_summary_by_topic = LLM_request(self.config.llm_summary_by_topic, request_type="memory") + + def get_all_node_names(self) -> list: + """获取记忆图中所有节点的名字列表""" + return list(self.memory_graph.G.nodes()) + + def calculate_node_hash(self, concept, memory_items) -> int: + """计算节点的特征值""" + if not isinstance(memory_items, list): + memory_items = [memory_items] if memory_items else [] + sorted_items = sorted(memory_items) + content = f"{concept}:{'|'.join(sorted_items)}" + return hash(content) + + def calculate_edge_hash(self, source, target) -> int: + """计算边的特征值""" + nodes = sorted([source, target]) + return hash(f"{nodes[0]}:{nodes[1]}") + + def find_topic_llm(self, text, topic_num): + prompt = ( + f"这是一段文字:{text}。请你从这段话中总结出最多{topic_num}个关键的概念,可以是名词,动词,或者特定人物,帮我列出来," + f"将主题用逗号隔开,并加上<>,例如<主题1>,<主题2>......尽可能精简。只需要列举最多{topic_num}个话题就好,不要有序号,不要告诉我其他内容。" + f"如果确定找不出主题或者没有明显主题,返回。" + ) + return prompt + + def topic_what(self, text, topic, time_info): + prompt = ( + f'这是一段文字,{time_info}:{text}。我想让你基于这段文字来概括"{topic}"这个概念,帮我总结成一句自然的话,' + f"可以包含时间和人物,以及具体的观点。只输出这句话就好" + ) + return prompt + + def calculate_topic_num(self, text, compress_rate): + """计算文本的话题数量""" + information_content = calculate_information_content(text) + topic_by_length = text.count("\n") * compress_rate + topic_by_information_content = max(1, min(5, int((information_content - 3) * 2))) + topic_num = int((topic_by_length + topic_by_information_content) / 2) + logger.debug( + f"topic_by_length: {topic_by_length}, topic_by_information_content: {topic_by_information_content}, " + f"topic_num: {topic_num}" + ) + return topic_num + + def get_memory_from_keyword(self, keyword: str, max_depth: int = 2) -> list: + """从关键词获取相关记忆。 + + Args: + keyword (str): 关键词 + max_depth (int, optional): 记忆检索深度,默认为2。1表示只获取直接相关的记忆,2表示获取间接相关的记忆。 + + Returns: + list: 记忆列表,每个元素是一个元组 (topic, memory_items, similarity) + - topic: str, 记忆主题 + - memory_items: list, 该主题下的记忆项列表 + - similarity: float, 与关键词的相似度 + """ + if not keyword: + return [] + + # 获取所有节点 + all_nodes = list(self.memory_graph.G.nodes()) + memories = [] + + # 计算关键词的词集合 + keyword_words = set(jieba.cut(keyword)) + + # 遍历所有节点,计算相似度 + for node in all_nodes: + node_words = set(jieba.cut(node)) + all_words = keyword_words | node_words + v1 = [1 if word in keyword_words else 0 for word in all_words] + v2 = [1 if word in node_words else 0 for word in all_words] + similarity = cosine_similarity(v1, v2) + + # 如果相似度超过阈值,获取该节点的记忆 + if similarity >= 0.3: # 可以调整这个阈值 + node_data = self.memory_graph.G.nodes[node] + memory_items = node_data.get("memory_items", []) + if not isinstance(memory_items, list): + memory_items = [memory_items] if memory_items else [] + + memories.append((node, memory_items, similarity)) + + # 按相似度降序排序 + memories.sort(key=lambda x: x[2], reverse=True) + return memories + + async def get_memory_from_text( + self, + text: str, + max_memory_num: int = 3, + max_memory_length: int = 2, + max_depth: int = 3, + fast_retrieval: bool = False, + ) -> list: + """从文本中提取关键词并获取相关记忆。 + + Args: + text (str): 输入文本 + num (int, optional): 需要返回的记忆数量。默认为5。 + max_depth (int, optional): 记忆检索深度。默认为2。 + fast_retrieval (bool, optional): 是否使用快速检索。默认为False。 + 如果为True,使用jieba分词和TF-IDF提取关键词,速度更快但可能不够准确。 + 如果为False,使用LLM提取关键词,速度较慢但更准确。 + + Returns: + list: 记忆列表,每个元素是一个元组 (topic, memory_items, similarity) + - topic: str, 记忆主题 + - memory_items: list, 该主题下的记忆项列表 + - similarity: float, 与文本的相似度 + """ + if not text: + return [] + + if fast_retrieval: + # 使用jieba分词提取关键词 + words = jieba.cut(text) + # 过滤掉停用词和单字词 + keywords = [word for word in words if len(word) > 1] + # 去重 + keywords = list(set(keywords)) + # 限制关键词数量 + keywords = keywords[:5] + else: + # 使用LLM提取关键词 + topic_num = min(5, max(1, int(len(text) * 0.1))) # 根据文本长度动态调整关键词数量 + # logger.info(f"提取关键词数量: {topic_num}") + topics_response = await self.llm_topic_judge.generate_response(self.find_topic_llm(text, topic_num)) + + # 提取关键词 + keywords = re.findall(r"<([^>]+)>", topics_response[0]) + if not keywords: + keywords = [] + else: + keywords = [ + keyword.strip() + for keyword in ",".join(keywords).replace(",", ",").replace("、", ",").replace(" ", ",").split(",") + if keyword.strip() + ] + + # logger.info(f"提取的关键词: {', '.join(keywords)}") + + # 过滤掉不存在于记忆图中的关键词 + valid_keywords = [keyword for keyword in keywords if keyword in self.memory_graph.G] + if not valid_keywords: + logger.info("没有找到有效的关键词节点") + return [] + + logger.info(f"有效的关键词: {', '.join(valid_keywords)}") + + # 从每个关键词获取记忆 + all_memories = [] + activate_map = {} # 存储每个词的累计激活值 + + # 对每个关键词进行扩散式检索 + for keyword in valid_keywords: + logger.debug(f"开始以关键词 '{keyword}' 为中心进行扩散检索 (最大深度: {max_depth}):") + # 初始化激活值 + activation_values = {keyword: 1.0} + # 记录已访问的节点 + visited_nodes = {keyword} + # 待处理的节点队列,每个元素是(节点, 激活值, 当前深度) + nodes_to_process = [(keyword, 1.0, 0)] + + while nodes_to_process: + current_node, current_activation, current_depth = nodes_to_process.pop(0) + + # 如果激活值小于0或超过最大深度,停止扩散 + if current_activation <= 0 or current_depth >= max_depth: + continue + + # 获取当前节点的所有邻居 + neighbors = list(self.memory_graph.G.neighbors(current_node)) + + for neighbor in neighbors: + if neighbor in visited_nodes: + continue + + # 获取连接强度 + edge_data = self.memory_graph.G[current_node][neighbor] + strength = edge_data.get("strength", 1) + + # 计算新的激活值 + new_activation = current_activation - (1 / strength) + + if new_activation > 0: + activation_values[neighbor] = new_activation + visited_nodes.add(neighbor) + nodes_to_process.append((neighbor, new_activation, current_depth + 1)) + logger.debug( + f"节点 '{neighbor}' 被激活,激活值: {new_activation:.2f} (通过 '{current_node}' 连接,强度: {strength}, 深度: {current_depth + 1})" + ) # noqa: E501 + + # 更新激活映射 + for node, activation_value in activation_values.items(): + if activation_value > 0: + if node in activate_map: + activate_map[node] += activation_value + else: + activate_map[node] = activation_value + + # 输出激活映射 + # logger.info("激活映射统计:") + # for node, total_activation in sorted(activate_map.items(), key=lambda x: x[1], reverse=True): + # logger.info(f"节点 '{node}': 累计激活值 = {total_activation:.2f}") + + # 基于激活值平方的独立概率选择 + remember_map = {} + # logger.info("基于激活值平方的归一化选择:") + + # 计算所有激活值的平方和 + total_squared_activation = sum(activation**2 for activation in activate_map.values()) + if total_squared_activation > 0: + # 计算归一化的激活值 + normalized_activations = { + node: (activation**2) / total_squared_activation for node, activation in activate_map.items() + } + + # 按归一化激活值排序并选择前max_memory_num个 + sorted_nodes = sorted(normalized_activations.items(), key=lambda x: x[1], reverse=True)[:max_memory_num] + + # 将选中的节点添加到remember_map + for node, normalized_activation in sorted_nodes: + remember_map[node] = activate_map[node] # 使用原始激活值 + logger.debug( + f"节点 '{node}' (归一化激活值: {normalized_activation:.2f}, 激活值: {activate_map[node]:.2f})" + ) + else: + logger.info("没有有效的激活值") + + # 从选中的节点中提取记忆 + all_memories = [] + # logger.info("开始从选中的节点中提取记忆:") + for node, activation in remember_map.items(): + logger.debug(f"处理节点 '{node}' (激活值: {activation:.2f}):") + node_data = self.memory_graph.G.nodes[node] + memory_items = node_data.get("memory_items", []) + if not isinstance(memory_items, list): + memory_items = [memory_items] if memory_items else [] + + if memory_items: + logger.debug(f"节点包含 {len(memory_items)} 条记忆") + # 计算每条记忆与输入文本的相似度 + memory_similarities = [] + for memory in memory_items: + # 计算与输入文本的相似度 + memory_words = set(jieba.cut(memory)) + text_words = set(jieba.cut(text)) + all_words = memory_words | text_words + v1 = [1 if word in memory_words else 0 for word in all_words] + v2 = [1 if word in text_words else 0 for word in all_words] + similarity = cosine_similarity(v1, v2) + memory_similarities.append((memory, similarity)) + + # 按相似度排序 + memory_similarities.sort(key=lambda x: x[1], reverse=True) + # 获取最匹配的记忆 + top_memories = memory_similarities[:max_memory_length] + + # 添加到结果中 + for memory, similarity in top_memories: + all_memories.append((node, [memory], similarity)) + # logger.info(f"选中记忆: {memory} (相似度: {similarity:.2f})") + else: + logger.info("节点没有记忆") + + # 去重(基于记忆内容) + logger.debug("开始记忆去重:") + seen_memories = set() + unique_memories = [] + for topic, memory_items, activation_value in all_memories: + memory = memory_items[0] # 因为每个topic只有一条记忆 + if memory not in seen_memories: + seen_memories.add(memory) + unique_memories.append((topic, memory_items, activation_value)) + logger.debug(f"保留记忆: {memory} (来自节点: {topic}, 激活值: {activation_value:.2f})") + else: + logger.debug(f"跳过重复记忆: {memory} (来自节点: {topic})") + + # 转换为(关键词, 记忆)格式 + result = [] + for topic, memory_items, _ in unique_memories: + memory = memory_items[0] # 因为每个topic只有一条记忆 + result.append((topic, memory)) + logger.info(f"选中记忆: {memory} (来自节点: {topic})") + + return result + + async def get_activate_from_text(self, text: str, max_depth: int = 3, fast_retrieval: bool = False) -> float: + """从文本中提取关键词并获取相关记忆。 + + Args: + text (str): 输入文本 + num (int, optional): 需要返回的记忆数量。默认为5。 + max_depth (int, optional): 记忆检索深度。默认为2。 + fast_retrieval (bool, optional): 是否使用快速检索。默认为False。 + 如果为True,使用jieba分词和TF-IDF提取关键词,速度更快但可能不够准确。 + 如果为False,使用LLM提取关键词,速度较慢但更准确。 + + Returns: + float: 激活节点数与总节点数的比值 + """ + if not text: + return 0 + + if fast_retrieval: + # 使用jieba分词提取关键词 + words = jieba.cut(text) + # 过滤掉停用词和单字词 + keywords = [word for word in words if len(word) > 1] + # 去重 + keywords = list(set(keywords)) + # 限制关键词数量 + keywords = keywords[:5] + else: + # 使用LLM提取关键词 + topic_num = min(5, max(1, int(len(text) * 0.1))) # 根据文本长度动态调整关键词数量 + # logger.info(f"提取关键词数量: {topic_num}") + topics_response = await self.llm_topic_judge.generate_response(self.find_topic_llm(text, topic_num)) + + # 提取关键词 + keywords = re.findall(r"<([^>]+)>", topics_response[0]) + if not keywords: + keywords = [] + else: + keywords = [ + keyword.strip() + for keyword in ",".join(keywords).replace(",", ",").replace("、", ",").replace(" ", ",").split(",") + if keyword.strip() + ] + + # logger.info(f"提取的关键词: {', '.join(keywords)}") + + # 过滤掉不存在于记忆图中的关键词 + valid_keywords = [keyword for keyword in keywords if keyword in self.memory_graph.G] + if not valid_keywords: + logger.info("没有找到有效的关键词节点") + return 0 + + logger.info(f"有效的关键词: {', '.join(valid_keywords)}") + + # 从每个关键词获取记忆 + activate_map = {} # 存储每个词的累计激活值 + + # 对每个关键词进行扩散式检索 + for keyword in valid_keywords: + logger.debug(f"开始以关键词 '{keyword}' 为中心进行扩散检索 (最大深度: {max_depth}):") + # 初始化激活值 + activation_values = {keyword: 1.0} + # 记录已访问的节点 + visited_nodes = {keyword} + # 待处理的节点队列,每个元素是(节点, 激活值, 当前深度) + nodes_to_process = [(keyword, 1.0, 0)] + + while nodes_to_process: + current_node, current_activation, current_depth = nodes_to_process.pop(0) + + # 如果激活值小于0或超过最大深度,停止扩散 + if current_activation <= 0 or current_depth >= max_depth: + continue + + # 获取当前节点的所有邻居 + neighbors = list(self.memory_graph.G.neighbors(current_node)) + + for neighbor in neighbors: + if neighbor in visited_nodes: + continue + + # 获取连接强度 + edge_data = self.memory_graph.G[current_node][neighbor] + strength = edge_data.get("strength", 1) + + # 计算新的激活值 + new_activation = current_activation - (1 / strength) + + if new_activation > 0: + activation_values[neighbor] = new_activation + visited_nodes.add(neighbor) + nodes_to_process.append((neighbor, new_activation, current_depth + 1)) + # logger.debug( + # f"节点 '{neighbor}' 被激活,激活值: {new_activation:.2f} (通过 '{current_node}' 连接,强度: {strength}, 深度: {current_depth + 1})") # noqa: E501 + + # 更新激活映射 + for node, activation_value in activation_values.items(): + if activation_value > 0: + if node in activate_map: + activate_map[node] += activation_value + else: + activate_map[node] = activation_value + + # 输出激活映射 + # logger.info("激活映射统计:") + # for node, total_activation in sorted(activate_map.items(), key=lambda x: x[1], reverse=True): + # logger.info(f"节点 '{node}': 累计激活值 = {total_activation:.2f}") + + # 计算激活节点数与总节点数的比值 + total_activation = sum(activate_map.values()) + logger.info(f"总激活值: {total_activation:.2f}") + total_nodes = len(self.memory_graph.G.nodes()) + # activated_nodes = len(activate_map) + activation_ratio = total_activation / total_nodes if total_nodes > 0 else 0 + activation_ratio = activation_ratio * 60 + logger.info(f"总激活值: {total_activation:.2f}, 总节点数: {total_nodes}, 激活: {activation_ratio}") + + return activation_ratio + + # 负责海马体与其他部分的交互 class EntorhinalCortex: - def __init__(self, hippocampus): + def __init__(self, hippocampus: Hippocampus): self.hippocampus = hippocampus self.memory_graph = hippocampus.memory_graph self.config = hippocampus.config @@ -819,433 +1247,6 @@ class ParahippocampalGyrus: logger.info(f"[遗忘] 总耗时: {end_time - start_time:.2f}秒") -# 海马体 -class Hippocampus: - def __init__(self): - self.memory_graph = Memory_graph() - self.llm_topic_judge = None - self.llm_summary_by_topic = None - self.entorhinal_cortex = None - self.parahippocampal_gyrus = None - self.config = None - - def initialize(self, global_config): - self.config = MemoryConfig.from_global_config(global_config) - # 初始化子组件 - self.entorhinal_cortex = EntorhinalCortex(self) - self.parahippocampal_gyrus = ParahippocampalGyrus(self) - # 从数据库加载记忆图 - self.entorhinal_cortex.sync_memory_from_db() - self.llm_topic_judge = LLM_request(self.config.llm_topic_judge, request_type="memory") - self.llm_summary_by_topic = LLM_request(self.config.llm_summary_by_topic, request_type="memory") - - def get_all_node_names(self) -> list: - """获取记忆图中所有节点的名字列表""" - return list(self.memory_graph.G.nodes()) - - def calculate_node_hash(self, concept, memory_items) -> int: - """计算节点的特征值""" - if not isinstance(memory_items, list): - memory_items = [memory_items] if memory_items else [] - sorted_items = sorted(memory_items) - content = f"{concept}:{'|'.join(sorted_items)}" - return hash(content) - - def calculate_edge_hash(self, source, target) -> int: - """计算边的特征值""" - nodes = sorted([source, target]) - return hash(f"{nodes[0]}:{nodes[1]}") - - def find_topic_llm(self, text, topic_num): - prompt = ( - f"这是一段文字:{text}。请你从这段话中总结出最多{topic_num}个关键的概念,可以是名词,动词,或者特定人物,帮我列出来," - f"将主题用逗号隔开,并加上<>,例如<主题1>,<主题2>......尽可能精简。只需要列举最多{topic_num}个话题就好,不要有序号,不要告诉我其他内容。" - f"如果确定找不出主题或者没有明显主题,返回。" - ) - return prompt - - def topic_what(self, text, topic, time_info): - prompt = ( - f'这是一段文字,{time_info}:{text}。我想让你基于这段文字来概括"{topic}"这个概念,帮我总结成一句自然的话,' - f"可以包含时间和人物,以及具体的观点。只输出这句话就好" - ) - return prompt - - def calculate_topic_num(self, text, compress_rate): - """计算文本的话题数量""" - information_content = calculate_information_content(text) - topic_by_length = text.count("\n") * compress_rate - topic_by_information_content = max(1, min(5, int((information_content - 3) * 2))) - topic_num = int((topic_by_length + topic_by_information_content) / 2) - logger.debug( - f"topic_by_length: {topic_by_length}, topic_by_information_content: {topic_by_information_content}, " - f"topic_num: {topic_num}" - ) - return topic_num - - def get_memory_from_keyword(self, keyword: str, max_depth: int = 2) -> list: - """从关键词获取相关记忆。 - - Args: - keyword (str): 关键词 - max_depth (int, optional): 记忆检索深度,默认为2。1表示只获取直接相关的记忆,2表示获取间接相关的记忆。 - - Returns: - list: 记忆列表,每个元素是一个元组 (topic, memory_items, similarity) - - topic: str, 记忆主题 - - memory_items: list, 该主题下的记忆项列表 - - similarity: float, 与关键词的相似度 - """ - if not keyword: - return [] - - # 获取所有节点 - all_nodes = list(self.memory_graph.G.nodes()) - memories = [] - - # 计算关键词的词集合 - keyword_words = set(jieba.cut(keyword)) - - # 遍历所有节点,计算相似度 - for node in all_nodes: - node_words = set(jieba.cut(node)) - all_words = keyword_words | node_words - v1 = [1 if word in keyword_words else 0 for word in all_words] - v2 = [1 if word in node_words else 0 for word in all_words] - similarity = cosine_similarity(v1, v2) - - # 如果相似度超过阈值,获取该节点的记忆 - if similarity >= 0.3: # 可以调整这个阈值 - node_data = self.memory_graph.G.nodes[node] - memory_items = node_data.get("memory_items", []) - if not isinstance(memory_items, list): - memory_items = [memory_items] if memory_items else [] - - memories.append((node, memory_items, similarity)) - - # 按相似度降序排序 - memories.sort(key=lambda x: x[2], reverse=True) - return memories - - async def get_memory_from_text( - self, - text: str, - max_memory_num: int = 3, - max_memory_length: int = 2, - max_depth: int = 3, - fast_retrieval: bool = False, - ) -> list: - """从文本中提取关键词并获取相关记忆。 - - Args: - text (str): 输入文本 - num (int, optional): 需要返回的记忆数量。默认为5。 - max_depth (int, optional): 记忆检索深度。默认为2。 - fast_retrieval (bool, optional): 是否使用快速检索。默认为False。 - 如果为True,使用jieba分词和TF-IDF提取关键词,速度更快但可能不够准确。 - 如果为False,使用LLM提取关键词,速度较慢但更准确。 - - Returns: - list: 记忆列表,每个元素是一个元组 (topic, memory_items, similarity) - - topic: str, 记忆主题 - - memory_items: list, 该主题下的记忆项列表 - - similarity: float, 与文本的相似度 - """ - if not text: - return [] - - if fast_retrieval: - # 使用jieba分词提取关键词 - words = jieba.cut(text) - # 过滤掉停用词和单字词 - keywords = [word for word in words if len(word) > 1] - # 去重 - keywords = list(set(keywords)) - # 限制关键词数量 - keywords = keywords[:5] - else: - # 使用LLM提取关键词 - topic_num = min(5, max(1, int(len(text) * 0.1))) # 根据文本长度动态调整关键词数量 - # logger.info(f"提取关键词数量: {topic_num}") - topics_response = await self.llm_topic_judge.generate_response(self.find_topic_llm(text, topic_num)) - - # 提取关键词 - keywords = re.findall(r"<([^>]+)>", topics_response[0]) - if not keywords: - keywords = [] - else: - keywords = [ - keyword.strip() - for keyword in ",".join(keywords).replace(",", ",").replace("、", ",").replace(" ", ",").split(",") - if keyword.strip() - ] - - # logger.info(f"提取的关键词: {', '.join(keywords)}") - - # 过滤掉不存在于记忆图中的关键词 - valid_keywords = [keyword for keyword in keywords if keyword in self.memory_graph.G] - if not valid_keywords: - logger.info("没有找到有效的关键词节点") - return [] - - logger.info(f"有效的关键词: {', '.join(valid_keywords)}") - - # 从每个关键词获取记忆 - all_memories = [] - activate_map = {} # 存储每个词的累计激活值 - - # 对每个关键词进行扩散式检索 - for keyword in valid_keywords: - logger.debug(f"开始以关键词 '{keyword}' 为中心进行扩散检索 (最大深度: {max_depth}):") - # 初始化激活值 - activation_values = {keyword: 1.0} - # 记录已访问的节点 - visited_nodes = {keyword} - # 待处理的节点队列,每个元素是(节点, 激活值, 当前深度) - nodes_to_process = [(keyword, 1.0, 0)] - - while nodes_to_process: - current_node, current_activation, current_depth = nodes_to_process.pop(0) - - # 如果激活值小于0或超过最大深度,停止扩散 - if current_activation <= 0 or current_depth >= max_depth: - continue - - # 获取当前节点的所有邻居 - neighbors = list(self.memory_graph.G.neighbors(current_node)) - - for neighbor in neighbors: - if neighbor in visited_nodes: - continue - - # 获取连接强度 - edge_data = self.memory_graph.G[current_node][neighbor] - strength = edge_data.get("strength", 1) - - # 计算新的激活值 - new_activation = current_activation - (1 / strength) - - if new_activation > 0: - activation_values[neighbor] = new_activation - visited_nodes.add(neighbor) - nodes_to_process.append((neighbor, new_activation, current_depth + 1)) - logger.debug( - f"节点 '{neighbor}' 被激活,激活值: {new_activation:.2f} (通过 '{current_node}' 连接,强度: {strength}, 深度: {current_depth + 1})" - ) # noqa: E501 - - # 更新激活映射 - for node, activation_value in activation_values.items(): - if activation_value > 0: - if node in activate_map: - activate_map[node] += activation_value - else: - activate_map[node] = activation_value - - # 输出激活映射 - # logger.info("激活映射统计:") - # for node, total_activation in sorted(activate_map.items(), key=lambda x: x[1], reverse=True): - # logger.info(f"节点 '{node}': 累计激活值 = {total_activation:.2f}") - - # 基于激活值平方的独立概率选择 - remember_map = {} - # logger.info("基于激活值平方的归一化选择:") - - # 计算所有激活值的平方和 - total_squared_activation = sum(activation**2 for activation in activate_map.values()) - if total_squared_activation > 0: - # 计算归一化的激活值 - normalized_activations = { - node: (activation**2) / total_squared_activation for node, activation in activate_map.items() - } - - # 按归一化激活值排序并选择前max_memory_num个 - sorted_nodes = sorted(normalized_activations.items(), key=lambda x: x[1], reverse=True)[:max_memory_num] - - # 将选中的节点添加到remember_map - for node, normalized_activation in sorted_nodes: - remember_map[node] = activate_map[node] # 使用原始激活值 - logger.debug( - f"节点 '{node}' (归一化激活值: {normalized_activation:.2f}, 激活值: {activate_map[node]:.2f})" - ) - else: - logger.info("没有有效的激活值") - - # 从选中的节点中提取记忆 - all_memories = [] - # logger.info("开始从选中的节点中提取记忆:") - for node, activation in remember_map.items(): - logger.debug(f"处理节点 '{node}' (激活值: {activation:.2f}):") - node_data = self.memory_graph.G.nodes[node] - memory_items = node_data.get("memory_items", []) - if not isinstance(memory_items, list): - memory_items = [memory_items] if memory_items else [] - - if memory_items: - logger.debug(f"节点包含 {len(memory_items)} 条记忆") - # 计算每条记忆与输入文本的相似度 - memory_similarities = [] - for memory in memory_items: - # 计算与输入文本的相似度 - memory_words = set(jieba.cut(memory)) - text_words = set(jieba.cut(text)) - all_words = memory_words | text_words - v1 = [1 if word in memory_words else 0 for word in all_words] - v2 = [1 if word in text_words else 0 for word in all_words] - similarity = cosine_similarity(v1, v2) - memory_similarities.append((memory, similarity)) - - # 按相似度排序 - memory_similarities.sort(key=lambda x: x[1], reverse=True) - # 获取最匹配的记忆 - top_memories = memory_similarities[:max_memory_length] - - # 添加到结果中 - for memory, similarity in top_memories: - all_memories.append((node, [memory], similarity)) - # logger.info(f"选中记忆: {memory} (相似度: {similarity:.2f})") - else: - logger.info("节点没有记忆") - - # 去重(基于记忆内容) - logger.debug("开始记忆去重:") - seen_memories = set() - unique_memories = [] - for topic, memory_items, activation_value in all_memories: - memory = memory_items[0] # 因为每个topic只有一条记忆 - if memory not in seen_memories: - seen_memories.add(memory) - unique_memories.append((topic, memory_items, activation_value)) - logger.debug(f"保留记忆: {memory} (来自节点: {topic}, 激活值: {activation_value:.2f})") - else: - logger.debug(f"跳过重复记忆: {memory} (来自节点: {topic})") - - # 转换为(关键词, 记忆)格式 - result = [] - for topic, memory_items, _ in unique_memories: - memory = memory_items[0] # 因为每个topic只有一条记忆 - result.append((topic, memory)) - logger.info(f"选中记忆: {memory} (来自节点: {topic})") - - return result - - async def get_activate_from_text(self, text: str, max_depth: int = 3, fast_retrieval: bool = False) -> float: - """从文本中提取关键词并获取相关记忆。 - - Args: - text (str): 输入文本 - num (int, optional): 需要返回的记忆数量。默认为5。 - max_depth (int, optional): 记忆检索深度。默认为2。 - fast_retrieval (bool, optional): 是否使用快速检索。默认为False。 - 如果为True,使用jieba分词和TF-IDF提取关键词,速度更快但可能不够准确。 - 如果为False,使用LLM提取关键词,速度较慢但更准确。 - - Returns: - float: 激活节点数与总节点数的比值 - """ - if not text: - return 0 - - if fast_retrieval: - # 使用jieba分词提取关键词 - words = jieba.cut(text) - # 过滤掉停用词和单字词 - keywords = [word for word in words if len(word) > 1] - # 去重 - keywords = list(set(keywords)) - # 限制关键词数量 - keywords = keywords[:5] - else: - # 使用LLM提取关键词 - topic_num = min(5, max(1, int(len(text) * 0.1))) # 根据文本长度动态调整关键词数量 - # logger.info(f"提取关键词数量: {topic_num}") - topics_response = await self.llm_topic_judge.generate_response(self.find_topic_llm(text, topic_num)) - - # 提取关键词 - keywords = re.findall(r"<([^>]+)>", topics_response[0]) - if not keywords: - keywords = [] - else: - keywords = [ - keyword.strip() - for keyword in ",".join(keywords).replace(",", ",").replace("、", ",").replace(" ", ",").split(",") - if keyword.strip() - ] - - # logger.info(f"提取的关键词: {', '.join(keywords)}") - - # 过滤掉不存在于记忆图中的关键词 - valid_keywords = [keyword for keyword in keywords if keyword in self.memory_graph.G] - if not valid_keywords: - logger.info("没有找到有效的关键词节点") - return 0 - - logger.info(f"有效的关键词: {', '.join(valid_keywords)}") - - # 从每个关键词获取记忆 - activate_map = {} # 存储每个词的累计激活值 - - # 对每个关键词进行扩散式检索 - for keyword in valid_keywords: - logger.debug(f"开始以关键词 '{keyword}' 为中心进行扩散检索 (最大深度: {max_depth}):") - # 初始化激活值 - activation_values = {keyword: 1.0} - # 记录已访问的节点 - visited_nodes = {keyword} - # 待处理的节点队列,每个元素是(节点, 激活值, 当前深度) - nodes_to_process = [(keyword, 1.0, 0)] - - while nodes_to_process: - current_node, current_activation, current_depth = nodes_to_process.pop(0) - - # 如果激活值小于0或超过最大深度,停止扩散 - if current_activation <= 0 or current_depth >= max_depth: - continue - - # 获取当前节点的所有邻居 - neighbors = list(self.memory_graph.G.neighbors(current_node)) - - for neighbor in neighbors: - if neighbor in visited_nodes: - continue - - # 获取连接强度 - edge_data = self.memory_graph.G[current_node][neighbor] - strength = edge_data.get("strength", 1) - - # 计算新的激活值 - new_activation = current_activation - (1 / strength) - - if new_activation > 0: - activation_values[neighbor] = new_activation - visited_nodes.add(neighbor) - nodes_to_process.append((neighbor, new_activation, current_depth + 1)) - # logger.debug( - # f"节点 '{neighbor}' 被激活,激活值: {new_activation:.2f} (通过 '{current_node}' 连接,强度: {strength}, 深度: {current_depth + 1})") # noqa: E501 - - # 更新激活映射 - for node, activation_value in activation_values.items(): - if activation_value > 0: - if node in activate_map: - activate_map[node] += activation_value - else: - activate_map[node] = activation_value - - # 输出激活映射 - # logger.info("激活映射统计:") - # for node, total_activation in sorted(activate_map.items(), key=lambda x: x[1], reverse=True): - # logger.info(f"节点 '{node}': 累计激活值 = {total_activation:.2f}") - - # 计算激活节点数与总节点数的比值 - total_activation = sum(activate_map.values()) - logger.info(f"总激活值: {total_activation:.2f}") - total_nodes = len(self.memory_graph.G.nodes()) - # activated_nodes = len(activate_map) - activation_ratio = total_activation / total_nodes if total_nodes > 0 else 0 - activation_ratio = activation_ratio * 60 - logger.info(f"总激活值: {total_activation:.2f}, 总节点数: {total_nodes}, 激活: {activation_ratio}") - - return activation_ratio - class HippocampusManager: _instance = None diff --git a/src/plugins/moods/moods.py b/src/plugins/moods/moods.py index d564b48b6..9ce0fd93b 100644 --- a/src/plugins/moods/moods.py +++ b/src/plugins/moods/moods.py @@ -137,14 +137,14 @@ class MoodManager: personality = Individuality.get_instance().personality if personality: # 神经质:影响情绪变化速度 - neuroticism_factor = 1 + (personality.neuroticism - 0.5) * 0.5 - agreeableness_factor = 1 + (personality.agreeableness - 0.5) * 0.5 + neuroticism_factor = 1 + (personality.neuroticism - 0.5) * 0.4 + agreeableness_factor = 1 + (personality.agreeableness - 0.5) * 0.4 # 宜人性:影响情绪基准线 if personality.agreeableness < 0.2: - agreeableness_bias = (personality.agreeableness - 0.2) * 2 + agreeableness_bias = (personality.agreeableness - 0.2) * 0.5 elif personality.agreeableness > 0.8: - agreeableness_bias = (personality.agreeableness - 0.8) * 2 + agreeableness_bias = (personality.agreeableness - 0.8) * 0.5 else: agreeableness_bias = 0 diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index 84a70cd98..e52b5e824 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -75,8 +75,8 @@ model_v3_probability = 0.3 # 麦麦回答时选择次要回复模型2 模型的 [heartflow] # 注意:可能会消耗大量token,请谨慎开启,仅会使用v3模型 sub_heart_flow_update_interval = 60 # 子心流更新频率,间隔 单位秒 -sub_heart_flow_freeze_time = 120 # 子心流冻结时间,超过这个时间没有回复,子心流会冻结,间隔 单位秒 -sub_heart_flow_stop_time = 600 # 子心流停止时间,超过这个时间没有回复,子心流会停止,间隔 单位秒 +sub_heart_flow_freeze_time = 100 # 子心流冻结时间,超过这个时间没有回复,子心流会冻结,间隔 单位秒 +sub_heart_flow_stop_time = 500 # 子心流停止时间,超过这个时间没有回复,子心流会停止,间隔 单位秒 heart_flow_update_interval = 300 # 心流更新频率,间隔 单位秒 From d2ec1701977875673a6c47a1db5e83c021c5bdef Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Thu, 10 Apr 2025 17:01:25 +0800 Subject: [PATCH 13/24] =?UTF-8?q?=E5=B0=9D=E8=AF=95=E5=BD=BB=E5=BA=95?= =?UTF-8?q?=E4=BF=AE=E5=A4=8D=5Fexecute=5Frequest=E7=82=B8=E9=A3=9E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/heart_flow/heartflow.py | 3 +- src/plugins/chat/auto_speak.py | 6 +- .../reasoning_chat/reasoning_chat.py | 31 +- src/plugins/memory_system/Hippocampus.py | 639 +++++++++--------- 4 files changed, 346 insertions(+), 333 deletions(-) diff --git a/src/heart_flow/heartflow.py b/src/heart_flow/heartflow.py index f5b394f2e..3ea51917c 100644 --- a/src/heart_flow/heartflow.py +++ b/src/heart_flow/heartflow.py @@ -9,6 +9,7 @@ from src.common.logger import get_module_logger, LogConfig, HEARTFLOW_STYLE_CONF from src.individuality.individuality import Individuality import time import random +from typing import Dict, Any heartflow_config = LogConfig( # 使用海马体专用样式 @@ -39,7 +40,7 @@ class Heartflow: model=global_config.llm_heartflow, temperature=0.6, max_tokens=1000, request_type="heart_flow" ) - self._subheartflows = {} + self._subheartflows: Dict[Any, SubHeartflow] = {} self.active_subheartflows_nums = 0 async def _cleanup_inactive_subheartflows(self): diff --git a/src/plugins/chat/auto_speak.py b/src/plugins/chat/auto_speak.py index 62a5a20a5..ac76a2714 100644 --- a/src/plugins/chat/auto_speak.py +++ b/src/plugins/chat/auto_speak.py @@ -142,7 +142,11 @@ class AutoSpeakManager: message_manager.add_message(thinking_message) # 生成自主发言内容 - response, raw_content = await self.gpt.generate_response(message) + try: + response, raw_content = await self.gpt.generate_response(message) + except Exception as e: + logger.error(f"生成自主发言内容时发生错误: {e}") + return False if response: message_set = MessageSet(None, think_id) # 不需要chat_stream diff --git a/src/plugins/chat_module/reasoning_chat/reasoning_chat.py b/src/plugins/chat_module/reasoning_chat/reasoning_chat.py index b9e94e4fe..c097427de 100644 --- a/src/plugins/chat_module/reasoning_chat/reasoning_chat.py +++ b/src/plugins/chat_module/reasoning_chat/reasoning_chat.py @@ -59,11 +59,7 @@ class ReasoningChat: return thinking_id - async def _send_response_messages(self, - message, - chat, - response_set:List[str], - thinking_id) -> MessageSending: + async def _send_response_messages(self, message, chat, response_set: List[str], thinking_id) -> MessageSending: """发送回复消息""" container = message_manager.get_container(chat.stream_id) thinking_message = None @@ -240,19 +236,23 @@ class ReasoningChat: thinking_id = await self._create_thinking_message(message, chat, userinfo, messageinfo) timer2 = time.time() timing_results["创建思考消息"] = timer2 - timer1 - + logger.debug(f"创建捕捉器,thinking_id:{thinking_id}") - + info_catcher = info_catcher_manager.get_info_catcher(thinking_id) info_catcher.catch_decide_to_response(message) # 生成回复 timer1 = time.time() - response_set = await self.gpt.generate_response(message,thinking_id) - timer2 = time.time() - timing_results["生成回复"] = timer2 - timer1 - - info_catcher.catch_after_generate_response(timing_results["生成回复"]) + try: + response_set = await self.gpt.generate_response(message, thinking_id) + timer2 = time.time() + timing_results["生成回复"] = timer2 - timer1 + + info_catcher.catch_after_generate_response(timing_results["生成回复"]) + except Exception as e: + logger.error(f"回复生成出现错误:str{e}") + response_set = None if not response_set: logger.info("为什么生成回复失败?") @@ -263,10 +263,9 @@ class ReasoningChat: first_bot_msg = await self._send_response_messages(message, chat, response_set, thinking_id) timer2 = time.time() timing_results["发送消息"] = timer2 - timer1 - - info_catcher.catch_after_response(timing_results["发送消息"],response_set,first_bot_msg) - - + + info_catcher.catch_after_response(timing_results["发送消息"], response_set, first_bot_msg) + info_catcher.done_catch() # 处理表情包 diff --git a/src/plugins/memory_system/Hippocampus.py b/src/plugins/memory_system/Hippocampus.py index 8e2cd21e7..407d20540 100644 --- a/src/plugins/memory_system/Hippocampus.py +++ b/src/plugins/memory_system/Hippocampus.py @@ -506,319 +506,6 @@ class EntorhinalCortex: logger.success(f"[数据库] 同步了 {len(memory_nodes)} 个节点和 {len(memory_edges)} 条边") -# 负责整合,遗忘,合并记忆 -class ParahippocampalGyrus: - def __init__(self, hippocampus): - self.hippocampus = hippocampus - self.memory_graph = hippocampus.memory_graph - self.config = hippocampus.config - - async def memory_compress(self, messages: list, compress_rate=0.1): - """压缩和总结消息内容,生成记忆主题和摘要。 - - Args: - messages (list): 消息列表,每个消息是一个字典,包含以下字段: - - time: float, 消息的时间戳 - - detailed_plain_text: str, 消息的详细文本内容 - compress_rate (float, optional): 压缩率,用于控制生成的主题数量。默认为0.1。 - - Returns: - tuple: (compressed_memory, similar_topics_dict) - - compressed_memory: set, 压缩后的记忆集合,每个元素是一个元组 (topic, summary) - - topic: str, 记忆主题 - - summary: str, 主题的摘要描述 - - similar_topics_dict: dict, 相似主题字典,key为主题,value为相似主题列表 - 每个相似主题是一个元组 (similar_topic, similarity) - - similar_topic: str, 相似的主题 - - similarity: float, 相似度分数(0-1之间) - - Process: - 1. 合并消息文本并生成时间信息 - 2. 使用LLM提取关键主题 - 3. 过滤掉包含禁用关键词的主题 - 4. 为每个主题生成摘要 - 5. 查找与现有记忆中的相似主题 - """ - if not messages: - return set(), {} - - # 合并消息文本,同时保留时间信息 - input_text = "" - time_info = "" - # 计算最早和最晚时间 - earliest_time = min(msg["time"] for msg in messages) - latest_time = max(msg["time"] for msg in messages) - - earliest_dt = datetime.datetime.fromtimestamp(earliest_time) - latest_dt = datetime.datetime.fromtimestamp(latest_time) - - # 如果是同一年 - if earliest_dt.year == latest_dt.year: - earliest_str = earliest_dt.strftime("%m-%d %H:%M:%S") - latest_str = latest_dt.strftime("%m-%d %H:%M:%S") - time_info += f"是在{earliest_dt.year}年,{earliest_str} 到 {latest_str} 的对话:\n" - else: - earliest_str = earliest_dt.strftime("%Y-%m-%d %H:%M:%S") - latest_str = latest_dt.strftime("%Y-%m-%d %H:%M:%S") - time_info += f"是从 {earliest_str} 到 {latest_str} 的对话:\n" - - for msg in messages: - input_text += f"{msg['detailed_plain_text']}\n" - - logger.debug(input_text) - - topic_num = self.hippocampus.calculate_topic_num(input_text, compress_rate) - topics_response = await self.hippocampus.llm_topic_judge.generate_response( - self.hippocampus.find_topic_llm(input_text, topic_num) - ) - - # 使用正则表达式提取<>中的内容 - topics = re.findall(r"<([^>]+)>", topics_response[0]) - - # 如果没有找到<>包裹的内容,返回['none'] - if not topics: - topics = ["none"] - else: - # 处理提取出的话题 - topics = [ - topic.strip() - for topic in ",".join(topics).replace(",", ",").replace("、", ",").replace(" ", ",").split(",") - if topic.strip() - ] - - # 过滤掉包含禁用关键词的topic - filtered_topics = [ - topic for topic in topics if not any(keyword in topic for keyword in self.config.memory_ban_words) - ] - - logger.debug(f"过滤后话题: {filtered_topics}") - - # 创建所有话题的请求任务 - tasks = [] - for topic in filtered_topics: - topic_what_prompt = self.hippocampus.topic_what(input_text, topic, time_info) - task = self.hippocampus.llm_summary_by_topic.generate_response_async(topic_what_prompt) - tasks.append((topic.strip(), task)) - - # 等待所有任务完成 - compressed_memory = set() - similar_topics_dict = {} - - for topic, task in tasks: - response = await task - if response: - compressed_memory.add((topic, response[0])) - - existing_topics = list(self.memory_graph.G.nodes()) - similar_topics = [] - - for existing_topic in existing_topics: - topic_words = set(jieba.cut(topic)) - existing_words = set(jieba.cut(existing_topic)) - - all_words = topic_words | existing_words - v1 = [1 if word in topic_words else 0 for word in all_words] - v2 = [1 if word in existing_words else 0 for word in all_words] - - similarity = cosine_similarity(v1, v2) - - if similarity >= 0.7: - similar_topics.append((existing_topic, similarity)) - - similar_topics.sort(key=lambda x: x[1], reverse=True) - similar_topics = similar_topics[:3] - similar_topics_dict[topic] = similar_topics - - return compressed_memory, similar_topics_dict - - async def operation_build_memory(self): - logger.debug("------------------------------------开始构建记忆--------------------------------------") - start_time = time.time() - memory_samples = self.hippocampus.entorhinal_cortex.get_memory_sample() - all_added_nodes = [] - all_connected_nodes = [] - all_added_edges = [] - for i, messages in enumerate(memory_samples, 1): - all_topics = [] - progress = (i / len(memory_samples)) * 100 - bar_length = 30 - filled_length = int(bar_length * i // len(memory_samples)) - bar = "█" * filled_length + "-" * (bar_length - filled_length) - logger.debug(f"进度: [{bar}] {progress:.1f}% ({i}/{len(memory_samples)})") - - compress_rate = self.config.memory_compress_rate - compressed_memory, similar_topics_dict = await self.memory_compress(messages, compress_rate) - logger.debug(f"压缩后记忆数量: {compressed_memory},似曾相识的话题: {similar_topics_dict}") - - current_time = datetime.datetime.now().timestamp() - logger.debug(f"添加节点: {', '.join(topic for topic, _ in compressed_memory)}") - all_added_nodes.extend(topic for topic, _ in compressed_memory) - - for topic, memory in compressed_memory: - self.memory_graph.add_dot(topic, memory) - all_topics.append(topic) - - if topic in similar_topics_dict: - similar_topics = similar_topics_dict[topic] - for similar_topic, similarity in similar_topics: - if topic != similar_topic: - strength = int(similarity * 10) - - logger.debug(f"连接相似节点: {topic} 和 {similar_topic} (强度: {strength})") - all_added_edges.append(f"{topic}-{similar_topic}") - - all_connected_nodes.append(topic) - all_connected_nodes.append(similar_topic) - - self.memory_graph.G.add_edge( - topic, - similar_topic, - strength=strength, - created_time=current_time, - last_modified=current_time, - ) - - for i in range(len(all_topics)): - for j in range(i + 1, len(all_topics)): - logger.debug(f"连接同批次节点: {all_topics[i]} 和 {all_topics[j]}") - all_added_edges.append(f"{all_topics[i]}-{all_topics[j]}") - self.memory_graph.connect_dot(all_topics[i], all_topics[j]) - - logger.success(f"更新记忆: {', '.join(all_added_nodes)}") - logger.debug(f"强化连接: {', '.join(all_added_edges)}") - logger.info(f"强化连接节点: {', '.join(all_connected_nodes)}") - - await self.hippocampus.entorhinal_cortex.sync_memory_to_db() - - end_time = time.time() - logger.success(f"---------------------记忆构建耗时: {end_time - start_time:.2f} 秒---------------------") - - async def operation_forget_topic(self, percentage=0.005): - start_time = time.time() - logger.info("[遗忘] 开始检查数据库...") - - # 验证百分比参数 - if not 0 <= percentage <= 1: - logger.warning(f"[遗忘] 无效的遗忘百分比: {percentage}, 使用默认值 0.005") - percentage = 0.005 - - all_nodes = list(self.memory_graph.G.nodes()) - all_edges = list(self.memory_graph.G.edges()) - - if not all_nodes and not all_edges: - logger.info("[遗忘] 记忆图为空,无需进行遗忘操作") - return - - # 确保至少检查1个节点和边,且不超过总数 - check_nodes_count = max(1, min(len(all_nodes), int(len(all_nodes) * percentage))) - check_edges_count = max(1, min(len(all_edges), int(len(all_edges) * percentage))) - - # 只有在有足够的节点和边时才进行采样 - if len(all_nodes) >= check_nodes_count and len(all_edges) >= check_edges_count: - try: - nodes_to_check = random.sample(all_nodes, check_nodes_count) - edges_to_check = random.sample(all_edges, check_edges_count) - except ValueError as e: - logger.error(f"[遗忘] 采样错误: {str(e)}") - return - else: - logger.info("[遗忘] 没有足够的节点或边进行遗忘操作") - return - - # 使用列表存储变化信息 - edge_changes = { - "weakened": [], # 存储减弱的边 - "removed": [], # 存储移除的边 - } - node_changes = { - "reduced": [], # 存储减少记忆的节点 - "removed": [], # 存储移除的节点 - } - - current_time = datetime.datetime.now().timestamp() - - logger.info("[遗忘] 开始检查连接...") - edge_check_start = time.time() - for source, target in edges_to_check: - edge_data = self.memory_graph.G[source][target] - last_modified = edge_data.get("last_modified") - - if current_time - last_modified > 3600 * self.config.memory_forget_time: - current_strength = edge_data.get("strength", 1) - new_strength = current_strength - 1 - - if new_strength <= 0: - self.memory_graph.G.remove_edge(source, target) - edge_changes["removed"].append(f"{source} -> {target}") - else: - edge_data["strength"] = new_strength - edge_data["last_modified"] = current_time - edge_changes["weakened"].append(f"{source}-{target} (强度: {current_strength} -> {new_strength})") - edge_check_end = time.time() - logger.info(f"[遗忘] 连接检查耗时: {edge_check_end - edge_check_start:.2f}秒") - - logger.info("[遗忘] 开始检查节点...") - node_check_start = time.time() - for node in nodes_to_check: - node_data = self.memory_graph.G.nodes[node] - last_modified = node_data.get("last_modified", current_time) - - if current_time - last_modified > 3600 * 24: - memory_items = node_data.get("memory_items", []) - if not isinstance(memory_items, list): - memory_items = [memory_items] if memory_items else [] - - if memory_items: - current_count = len(memory_items) - removed_item = random.choice(memory_items) - memory_items.remove(removed_item) - - if memory_items: - self.memory_graph.G.nodes[node]["memory_items"] = memory_items - self.memory_graph.G.nodes[node]["last_modified"] = current_time - node_changes["reduced"].append(f"{node} (数量: {current_count} -> {len(memory_items)})") - else: - self.memory_graph.G.remove_node(node) - node_changes["removed"].append(node) - node_check_end = time.time() - logger.info(f"[遗忘] 节点检查耗时: {node_check_end - node_check_start:.2f}秒") - - if any(edge_changes.values()) or any(node_changes.values()): - sync_start = time.time() - - await self.hippocampus.entorhinal_cortex.resync_memory_to_db() - - sync_end = time.time() - logger.info(f"[遗忘] 数据库同步耗时: {sync_end - sync_start:.2f}秒") - - # 汇总输出所有变化 - logger.info("[遗忘] 遗忘操作统计:") - if edge_changes["weakened"]: - logger.info( - f"[遗忘] 减弱的连接 ({len(edge_changes['weakened'])}个): {', '.join(edge_changes['weakened'])}" - ) - - if edge_changes["removed"]: - logger.info( - f"[遗忘] 移除的连接 ({len(edge_changes['removed'])}个): {', '.join(edge_changes['removed'])}" - ) - - if node_changes["reduced"]: - logger.info( - f"[遗忘] 减少记忆的节点 ({len(node_changes['reduced'])}个): {', '.join(node_changes['reduced'])}" - ) - - if node_changes["removed"]: - logger.info( - f"[遗忘] 移除的节点 ({len(node_changes['removed'])}个): {', '.join(node_changes['removed'])}" - ) - else: - logger.info("[遗忘] 本次检查没有节点或连接满足遗忘条件") - - end_time = time.time() - logger.info(f"[遗忘] 总耗时: {end_time - start_time:.2f}秒") - - # 海马体 class Hippocampus: def __init__(self): @@ -1247,6 +934,327 @@ class Hippocampus: return activation_ratio +# 负责整合,遗忘,合并记忆 +class ParahippocampalGyrus: + def __init__(self, hippocampus: Hippocampus): + self.hippocampus = hippocampus + self.memory_graph = hippocampus.memory_graph + self.config = hippocampus.config + + async def memory_compress(self, messages: list, compress_rate=0.1): + """压缩和总结消息内容,生成记忆主题和摘要。 + + Args: + messages (list): 消息列表,每个消息是一个字典,包含以下字段: + - time: float, 消息的时间戳 + - detailed_plain_text: str, 消息的详细文本内容 + compress_rate (float, optional): 压缩率,用于控制生成的主题数量。默认为0.1。 + + Returns: + tuple: (compressed_memory, similar_topics_dict) + - compressed_memory: set, 压缩后的记忆集合,每个元素是一个元组 (topic, summary) + - topic: str, 记忆主题 + - summary: str, 主题的摘要描述 + - similar_topics_dict: dict, 相似主题字典,key为主题,value为相似主题列表 + 每个相似主题是一个元组 (similar_topic, similarity) + - similar_topic: str, 相似的主题 + - similarity: float, 相似度分数(0-1之间) + + Process: + 1. 合并消息文本并生成时间信息 + 2. 使用LLM提取关键主题 + 3. 过滤掉包含禁用关键词的主题 + 4. 为每个主题生成摘要 + 5. 查找与现有记忆中的相似主题 + """ + if not messages: + return set(), {} + + # 合并消息文本,同时保留时间信息 + input_text = "" + time_info = "" + # 计算最早和最晚时间 + earliest_time = min(msg["time"] for msg in messages) + latest_time = max(msg["time"] for msg in messages) + + earliest_dt = datetime.datetime.fromtimestamp(earliest_time) + latest_dt = datetime.datetime.fromtimestamp(latest_time) + + # 如果是同一年 + if earliest_dt.year == latest_dt.year: + earliest_str = earliest_dt.strftime("%m-%d %H:%M:%S") + latest_str = latest_dt.strftime("%m-%d %H:%M:%S") + time_info += f"是在{earliest_dt.year}年,{earliest_str} 到 {latest_str} 的对话:\n" + else: + earliest_str = earliest_dt.strftime("%Y-%m-%d %H:%M:%S") + latest_str = latest_dt.strftime("%Y-%m-%d %H:%M:%S") + time_info += f"是从 {earliest_str} 到 {latest_str} 的对话:\n" + + for msg in messages: + input_text += f"{msg['detailed_plain_text']}\n" + + logger.debug(input_text) + + topic_num = self.hippocampus.calculate_topic_num(input_text, compress_rate) + topics_response = await self.hippocampus.llm_topic_judge.generate_response( + self.hippocampus.find_topic_llm(input_text, topic_num) + ) + + # 使用正则表达式提取<>中的内容 + topics = re.findall(r"<([^>]+)>", topics_response[0]) + + # 如果没有找到<>包裹的内容,返回['none'] + if not topics: + topics = ["none"] + else: + # 处理提取出的话题 + topics = [ + topic.strip() + for topic in ",".join(topics).replace(",", ",").replace("、", ",").replace(" ", ",").split(",") + if topic.strip() + ] + + # 过滤掉包含禁用关键词的topic + filtered_topics = [ + topic for topic in topics if not any(keyword in topic for keyword in self.config.memory_ban_words) + ] + + logger.debug(f"过滤后话题: {filtered_topics}") + + # 创建所有话题的请求任务 + tasks = [] + for topic in filtered_topics: + topic_what_prompt = self.hippocampus.topic_what(input_text, topic, time_info) + try: + task = self.hippocampus.llm_summary_by_topic.generate_response_async(topic_what_prompt) + tasks.append((topic.strip(), task)) + except Exception as e: + logger.error(f"生成话题 '{topic}' 的摘要时发生错误: {e}") + continue + + # 等待所有任务完成 + compressed_memory = set() + similar_topics_dict = {} + + for topic, task in tasks: + response = await task + if response: + compressed_memory.add((topic, response[0])) + + existing_topics = list(self.memory_graph.G.nodes()) + similar_topics = [] + + for existing_topic in existing_topics: + topic_words = set(jieba.cut(topic)) + existing_words = set(jieba.cut(existing_topic)) + + all_words = topic_words | existing_words + v1 = [1 if word in topic_words else 0 for word in all_words] + v2 = [1 if word in existing_words else 0 for word in all_words] + + similarity = cosine_similarity(v1, v2) + + if similarity >= 0.7: + similar_topics.append((existing_topic, similarity)) + + similar_topics.sort(key=lambda x: x[1], reverse=True) + similar_topics = similar_topics[:3] + similar_topics_dict[topic] = similar_topics + + return compressed_memory, similar_topics_dict + + async def operation_build_memory(self): + logger.debug("------------------------------------开始构建记忆--------------------------------------") + start_time = time.time() + memory_samples = self.hippocampus.entorhinal_cortex.get_memory_sample() + all_added_nodes = [] + all_connected_nodes = [] + all_added_edges = [] + for i, messages in enumerate(memory_samples, 1): + all_topics = [] + progress = (i / len(memory_samples)) * 100 + bar_length = 30 + filled_length = int(bar_length * i // len(memory_samples)) + bar = "█" * filled_length + "-" * (bar_length - filled_length) + logger.debug(f"进度: [{bar}] {progress:.1f}% ({i}/{len(memory_samples)})") + + compress_rate = self.config.memory_compress_rate + try: + compressed_memory, similar_topics_dict = await self.memory_compress(messages, compress_rate) + except Exception as e: + logger.error(f"压缩记忆时发生错误: {e}") + continue + logger.debug(f"压缩后记忆数量: {compressed_memory},似曾相识的话题: {similar_topics_dict}") + + current_time = datetime.datetime.now().timestamp() + logger.debug(f"添加节点: {', '.join(topic for topic, _ in compressed_memory)}") + all_added_nodes.extend(topic for topic, _ in compressed_memory) + + for topic, memory in compressed_memory: + self.memory_graph.add_dot(topic, memory) + all_topics.append(topic) + + if topic in similar_topics_dict: + similar_topics = similar_topics_dict[topic] + for similar_topic, similarity in similar_topics: + if topic != similar_topic: + strength = int(similarity * 10) + + logger.debug(f"连接相似节点: {topic} 和 {similar_topic} (强度: {strength})") + all_added_edges.append(f"{topic}-{similar_topic}") + + all_connected_nodes.append(topic) + all_connected_nodes.append(similar_topic) + + self.memory_graph.G.add_edge( + topic, + similar_topic, + strength=strength, + created_time=current_time, + last_modified=current_time, + ) + + for i in range(len(all_topics)): + for j in range(i + 1, len(all_topics)): + logger.debug(f"连接同批次节点: {all_topics[i]} 和 {all_topics[j]}") + all_added_edges.append(f"{all_topics[i]}-{all_topics[j]}") + self.memory_graph.connect_dot(all_topics[i], all_topics[j]) + + logger.success(f"更新记忆: {', '.join(all_added_nodes)}") + logger.debug(f"强化连接: {', '.join(all_added_edges)}") + logger.info(f"强化连接节点: {', '.join(all_connected_nodes)}") + + await self.hippocampus.entorhinal_cortex.sync_memory_to_db() + + end_time = time.time() + logger.success(f"---------------------记忆构建耗时: {end_time - start_time:.2f} 秒---------------------") + + async def operation_forget_topic(self, percentage=0.005): + start_time = time.time() + logger.info("[遗忘] 开始检查数据库...") + + # 验证百分比参数 + if not 0 <= percentage <= 1: + logger.warning(f"[遗忘] 无效的遗忘百分比: {percentage}, 使用默认值 0.005") + percentage = 0.005 + + all_nodes = list(self.memory_graph.G.nodes()) + all_edges = list(self.memory_graph.G.edges()) + + if not all_nodes and not all_edges: + logger.info("[遗忘] 记忆图为空,无需进行遗忘操作") + return + + # 确保至少检查1个节点和边,且不超过总数 + check_nodes_count = max(1, min(len(all_nodes), int(len(all_nodes) * percentage))) + check_edges_count = max(1, min(len(all_edges), int(len(all_edges) * percentage))) + + # 只有在有足够的节点和边时才进行采样 + if len(all_nodes) >= check_nodes_count and len(all_edges) >= check_edges_count: + try: + nodes_to_check = random.sample(all_nodes, check_nodes_count) + edges_to_check = random.sample(all_edges, check_edges_count) + except ValueError as e: + logger.error(f"[遗忘] 采样错误: {str(e)}") + return + else: + logger.info("[遗忘] 没有足够的节点或边进行遗忘操作") + return + + # 使用列表存储变化信息 + edge_changes = { + "weakened": [], # 存储减弱的边 + "removed": [], # 存储移除的边 + } + node_changes = { + "reduced": [], # 存储减少记忆的节点 + "removed": [], # 存储移除的节点 + } + + current_time = datetime.datetime.now().timestamp() + + logger.info("[遗忘] 开始检查连接...") + edge_check_start = time.time() + for source, target in edges_to_check: + edge_data = self.memory_graph.G[source][target] + last_modified = edge_data.get("last_modified") + + if current_time - last_modified > 3600 * self.config.memory_forget_time: + current_strength = edge_data.get("strength", 1) + new_strength = current_strength - 1 + + if new_strength <= 0: + self.memory_graph.G.remove_edge(source, target) + edge_changes["removed"].append(f"{source} -> {target}") + else: + edge_data["strength"] = new_strength + edge_data["last_modified"] = current_time + edge_changes["weakened"].append(f"{source}-{target} (强度: {current_strength} -> {new_strength})") + edge_check_end = time.time() + logger.info(f"[遗忘] 连接检查耗时: {edge_check_end - edge_check_start:.2f}秒") + + logger.info("[遗忘] 开始检查节点...") + node_check_start = time.time() + for node in nodes_to_check: + node_data = self.memory_graph.G.nodes[node] + last_modified = node_data.get("last_modified", current_time) + + if current_time - last_modified > 3600 * 24: + memory_items = node_data.get("memory_items", []) + if not isinstance(memory_items, list): + memory_items = [memory_items] if memory_items else [] + + if memory_items: + current_count = len(memory_items) + removed_item = random.choice(memory_items) + memory_items.remove(removed_item) + + if memory_items: + self.memory_graph.G.nodes[node]["memory_items"] = memory_items + self.memory_graph.G.nodes[node]["last_modified"] = current_time + node_changes["reduced"].append(f"{node} (数量: {current_count} -> {len(memory_items)})") + else: + self.memory_graph.G.remove_node(node) + node_changes["removed"].append(node) + node_check_end = time.time() + logger.info(f"[遗忘] 节点检查耗时: {node_check_end - node_check_start:.2f}秒") + + if any(edge_changes.values()) or any(node_changes.values()): + sync_start = time.time() + + await self.hippocampus.entorhinal_cortex.resync_memory_to_db() + + sync_end = time.time() + logger.info(f"[遗忘] 数据库同步耗时: {sync_end - sync_start:.2f}秒") + + # 汇总输出所有变化 + logger.info("[遗忘] 遗忘操作统计:") + if edge_changes["weakened"]: + logger.info( + f"[遗忘] 减弱的连接 ({len(edge_changes['weakened'])}个): {', '.join(edge_changes['weakened'])}" + ) + + if edge_changes["removed"]: + logger.info( + f"[遗忘] 移除的连接 ({len(edge_changes['removed'])}个): {', '.join(edge_changes['removed'])}" + ) + + if node_changes["reduced"]: + logger.info( + f"[遗忘] 减少记忆的节点 ({len(node_changes['reduced'])}个): {', '.join(node_changes['reduced'])}" + ) + + if node_changes["removed"]: + logger.info( + f"[遗忘] 移除的节点 ({len(node_changes['removed'])}个): {', '.join(node_changes['removed'])}" + ) + else: + logger.info("[遗忘] 本次检查没有节点或连接满足遗忘条件") + + end_time = time.time() + logger.info(f"[遗忘] 总耗时: {end_time - start_time:.2f}秒") + + class HippocampusManager: _instance = None _hippocampus = None @@ -1317,12 +1325,13 @@ class HippocampusManager: if not self._initialized: raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法") try: - response = await self._hippocampus.get_memory_from_text(text, max_memory_num, max_memory_length, max_depth, fast_retrieval) + response = await self._hippocampus.get_memory_from_text( + text, max_memory_num, max_memory_length, max_depth, fast_retrieval + ) except Exception as e: logger.error(f"文本激活记忆失败: {e}") response = [] return response - async def get_activate_from_text(self, text: str, max_depth: int = 3, fast_retrieval: bool = False) -> float: """从文本中获取激活值的公共接口""" From d23ab986adadfe75518bfafa052a92377864c3e0 Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Thu, 10 Apr 2025 17:30:25 +0800 Subject: [PATCH 14/24] =?UTF-8?q?=E8=AE=A9eula=E5=92=8Cprivacy=E7=A1=AE?= =?UTF-8?q?=E8=AE=A4=E6=9B=B4=E6=98=BE=E7=9C=BC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bot.py | 18 +++++++++++------- src/common/logger.py | 6 ++++++ 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/bot.py b/bot.py index 5b12b0389..c9568ecba 100644 --- a/bot.py +++ b/bot.py @@ -7,12 +7,16 @@ from pathlib import Path import time import platform from dotenv import load_dotenv -from src.common.logger import get_module_logger +from src.common.logger import get_module_logger, LogConfig, CONFIRM_STYLE_CONFIG from src.common.crash_logger import install_crash_handler from src.main import MainSystem logger = get_module_logger("main_bot") - +confirm_logger_config = LogConfig( + console_format=CONFIRM_STYLE_CONFIG["console_format"], + file_format=CONFIRM_STYLE_CONFIG["file_format"], +) +confirm_logger = get_module_logger("main_bot", config=confirm_logger_config) # 获取没有加载env时的环境变量 env_mask = {key: os.getenv(key) for key in os.environ} @@ -166,8 +170,8 @@ def check_eula(): # 如果EULA或隐私条款有更新,提示用户重新确认 if eula_updated or privacy_updated: - print("EULA或隐私条款内容已更新,请在阅读后重新确认,继续运行视为同意更新后的以上两款协议") - print( + confirm_logger.critical("EULA或隐私条款内容已更新,请在阅读后重新确认,继续运行视为同意更新后的以上两款协议") + confirm_logger.critical( f'输入"同意"或"confirmed"或设置环境变量"EULA_AGREE={eula_new_hash}"和"PRIVACY_AGREE={privacy_new_hash}"继续运行' ) while True: @@ -176,14 +180,14 @@ def check_eula(): # print("确认成功,继续运行") # print(f"确认成功,继续运行{eula_updated} {privacy_updated}") if eula_updated: - print(f"更新EULA确认文件{eula_new_hash}") + logger.info(f"更新EULA确认文件{eula_new_hash}") eula_confirm_file.write_text(eula_new_hash, encoding="utf-8") if privacy_updated: - print(f"更新隐私条款确认文件{privacy_new_hash}") + logger.info(f"更新隐私条款确认文件{privacy_new_hash}") privacy_confirm_file.write_text(privacy_new_hash, encoding="utf-8") break else: - print('请输入"同意"或"confirmed"以继续运行') + confirm_logger.critical('请输入"同意"或"confirmed"以继续运行') return elif eula_confirmed and privacy_confirmed: return diff --git a/src/common/logger.py b/src/common/logger.py index 9e118622d..cded9467c 100644 --- a/src/common/logger.py +++ b/src/common/logger.py @@ -290,6 +290,12 @@ WILLING_STYLE_CONFIG = { }, } +CONFIRM_STYLE_CONFIG = { + "console_format": ( + "{message}" + ), # noqa: E501 + "file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | EULA与PRIVACY确认 | {message}"), +} # 根据SIMPLE_OUTPUT选择配置 MEMORY_STYLE_CONFIG = MEMORY_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else MEMORY_STYLE_CONFIG["advanced"] From f3d4d7f7f2382ce56be79bf63a7541efb1c4b8a9 Mon Sep 17 00:00:00 2001 From: zzzzz Date: Thu, 10 Apr 2025 21:32:31 +0800 Subject: [PATCH 15/24] fix #721 --- src/plugins/storage/storage.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/src/plugins/storage/storage.py b/src/plugins/storage/storage.py index c35f55be5..d07b02719 100644 --- a/src/plugins/storage/storage.py +++ b/src/plugins/storage/storage.py @@ -1,3 +1,4 @@ +import re from typing import Union from ...common.database import db @@ -7,19 +8,34 @@ from src.common.logger import get_module_logger logger = get_module_logger("message_storage") - class MessageStorage: async def store_message(self, message: Union[MessageSending, MessageRecv], chat_stream: ChatStream) -> None: """存储消息到数据库""" try: + # 莫越权 救世啊 + pattern = r".*?|.*?|.*?" + + processed_plain_text = message.processed_plain_text + if processed_plain_text: + filtered_processed_plain_text = re.sub(pattern, "", processed_plain_text, flags=re.DOTALL) + else: + filtered_processed_plain_text = "" + + detailed_plain_text = message.detailed_plain_text + if detailed_plain_text: + filtered_detailed_plain_text = re.sub(pattern, "", detailed_plain_text, flags=re.DOTALL) + else: + filtered_detailed_plain_text = "" + message_data = { "message_id": message.message_info.message_id, "time": message.message_info.time, "chat_id": chat_stream.stream_id, "chat_info": chat_stream.to_dict(), "user_info": message.message_info.user_info.to_dict(), - "processed_plain_text": message.processed_plain_text, - "detailed_plain_text": message.detailed_plain_text, + # 使用过滤后的文本 + "processed_plain_text": filtered_processed_plain_text, + "detailed_plain_text": filtered_detailed_plain_text, "memorized_times": message.memorized_times, } db.messages.insert_one(message_data) From 110f94353fb2ceea8632a827e733a1d6480d26de Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Thu, 10 Apr 2025 22:13:17 +0800 Subject: [PATCH 16/24] =?UTF-8?q?fix=EF=BC=9A=E5=8A=A0=E5=85=A5=E5=B7=A5?= =?UTF-8?q?=E5=85=B7=E8=B0=83=E7=94=A8=E8=83=BD=E5=8A=9B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- requirements.txt | Bin 538 -> 584 bytes src/heart_flow/observation.py | 8 +- src/heart_flow/sub_heartflow.py | 271 ++------------- src/heart_flow/tool_use.py | 561 ++++++++++++++++++++++++++++++ src/plugins/models/utils_model.py | 52 ++- src/tool_use/tool_use.py | 0 6 files changed, 627 insertions(+), 265 deletions(-) create mode 100644 src/heart_flow/tool_use.py create mode 100644 src/tool_use/tool_use.py diff --git a/requirements.txt b/requirements.txt index ada41d290306e10c34374c30323d519831de9444..0fcb31f83c499ae63c79092dcb349ca3f4112a35 100644 GIT binary patch delta 54 zcmbQma)M=p6q6n=0~bRsLo!1F1BjKzkjPNXPy!?i7&5^kr3__2Rx(3ALlHwB0{|JT B3D^Jt delta 7 OcmX@XGK*z{6cYdn?*caf diff --git a/src/heart_flow/observation.py b/src/heart_flow/observation.py index c54df2f92..818f1775d 100644 --- a/src/heart_flow/observation.py +++ b/src/heart_flow/observation.py @@ -47,8 +47,8 @@ class ChattingObservation(Observation): new_messages = list( db.messages.find({"chat_id": self.chat_id, "time": {"$gt": self.last_observe_time}}) .sort("time", 1) - .limit(20) - ) # 按时间正序排列,最多20条 + .limit(15) + ) # 按时间正序排列,最多15条 if not new_messages: return self.observe_info # 没有新消息,返回上次观察结果 @@ -63,8 +63,8 @@ class ChattingObservation(Observation): # 将新消息添加到talking_message,同时保持列表长度不超过20条 self.talking_message.extend(new_messages) - if len(self.talking_message) > 20: - self.talking_message = self.talking_message[-20:] # 只保留最新的20条 + if len(self.talking_message) > 15: + self.talking_message = self.talking_message[-15:] # 只保留最新的15条 self.translate_message_list_to_str() # 更新观察次数 diff --git a/src/heart_flow/sub_heartflow.py b/src/heart_flow/sub_heartflow.py index a6c6e047a..6d0138987 100644 --- a/src/heart_flow/sub_heartflow.py +++ b/src/heart_flow/sub_heartflow.py @@ -16,6 +16,8 @@ import random from src.plugins.chat.chat_stream import ChatStream from src.plugins.person_info.relationship_manager import relationship_manager from src.plugins.chat.utils import get_recent_group_speaker +import json +from src.heart_flow.tool_use import ToolUser subheartflow_config = LogConfig( # 使用海马体专用样式 @@ -47,6 +49,7 @@ class SubHeartflow: self.llm_model = LLM_request( model=global_config.llm_sub_heartflow, temperature=0.2, max_tokens=600, request_type="sub_heart_flow" ) + self.main_heartflow_info = "" @@ -63,6 +66,8 @@ class SubHeartflow: self.running_knowledges = [] self.bot_name = global_config.BOT_NICKNAME + + self.tool_user = ToolUser() def add_observation(self, observation: Observation): """添加一个新的observation对象到列表中,如果已存在相同id的observation则不添加""" @@ -115,6 +120,7 @@ class SubHeartflow: observation = self.observations[0] await observation.observe() + async def do_thinking_before_reply(self, message_txt:str, sender_name:str, chat_stream:ChatStream): current_thinking_info = self.current_mind mood_info = self.current_state.mood @@ -123,6 +129,19 @@ class SubHeartflow: chat_observe_info = observation.observe_info # print(f"chat_observe_info:{chat_observe_info}") + # 首先尝试使用工具获取更多信息 + tool_result = await self.tool_user.use_tool(message_txt, sender_name, chat_stream) + + # 如果工具被使用且获得了结果,将收集到的信息合并到思考中 + if tool_result.get("used_tools", False): + logger.info("使用工具收集了信息") + + # 如果有收集到的信息,将其添加到当前思考中 + if "collected_info" in tool_result: + collected_info = tool_result["collected_info"] + + + # 开始构建prompt prompt_personality = f"你的名字是{self.bot_name},你" # person @@ -158,38 +177,11 @@ class SubHeartflow: f"根据你和说话者{sender_name}的关系和态度进行回复,明确你的立场和情感。" ) - # 调取记忆 - related_memory = await HippocampusManager.get_instance().get_memory_from_text( - text=chat_observe_info, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False - ) - - if related_memory: - related_memory_info = "" - for memory in related_memory: - related_memory_info += memory[1] - else: - related_memory_info = "" - - related_info, grouped_results = await self.get_prompt_info(chat_observe_info + message_txt, 0.4) - # print(related_info) - for _topic, results in grouped_results.items(): - for result in results: - # print(result) - self.running_knowledges.append(result) - - # print(f"相关记忆:{related_memory_info}") - - schedule_info = bot_schedule.get_current_num_task(num=1, time_info=False) - prompt = "" - # prompt += f"麦麦的总体想法是:{self.main_heartflow_info}\n\n" + if tool_result.get("used_tools", False): + prompt += f"{collected_info}\n" prompt += f"{relation_prompt_all}\n" prompt += f"{prompt_personality}\n" - # prompt += f"你刚刚在做的事情是:{schedule_info}\n" - # if related_memory_info: - # prompt += f"你想起来你之前见过的回忆:{related_memory_info}。\n以上是你的回忆,不一定是目前聊天里的人说的,也不一定是现在发生的事情,请记住。\n" - # if related_info: - # prompt += f"你想起你知道:{related_info}\n" prompt += f"刚刚你的想法是{current_thinking_info}。如果有新的内容,记得转换话题\n" prompt += "-----------------------------------\n" prompt += f"现在你正在上网,和qq群里的网友们聊天,群里正在聊的话题是:{chat_observe_info}\n" @@ -211,7 +203,7 @@ class SubHeartflow: logger.info(f"prompt:\n{prompt}\n") logger.info(f"麦麦的思考前脑内状态:{self.current_mind}") - return self.current_mind ,self.past_mind + return self.current_mind, self.past_mind async def do_thinking_after_reply(self, reply_content, chat_talking_prompt): # print("麦麦回复之后脑袋转起来了") @@ -310,224 +302,5 @@ class SubHeartflow: self.past_mind.append(self.current_mind) self.current_mind = response - async def get_prompt_info(self, message: str, threshold: float): - start_time = time.time() - related_info = "" - logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}") - - # 1. 先从LLM获取主题,类似于记忆系统的做法 - topics = [] - # try: - # # 先尝试使用记忆系统的方法获取主题 - # hippocampus = HippocampusManager.get_instance()._hippocampus - # topic_num = min(5, max(1, int(len(message) * 0.1))) - # topics_response = await hippocampus.llm_topic_judge.generate_response(hippocampus.find_topic_llm(message, topic_num)) - - # # 提取关键词 - # topics = re.findall(r"<([^>]+)>", topics_response[0]) - # if not topics: - # topics = [] - # else: - # topics = [ - # topic.strip() - # for topic in ",".join(topics).replace(",", ",").replace("、", ",").replace(" ", ",").split(",") - # if topic.strip() - # ] - - # logger.info(f"从LLM提取的主题: {', '.join(topics)}") - # except Exception as e: - # logger.error(f"从LLM提取主题失败: {str(e)}") - # # 如果LLM提取失败,使用jieba分词提取关键词作为备选 - # words = jieba.cut(message) - # topics = [word for word in words if len(word) > 1][:5] - # logger.info(f"使用jieba提取的主题: {', '.join(topics)}") - - # 如果无法提取到主题,直接使用整个消息 - if not topics: - logger.debug("未能提取到任何主题,使用整个消息进行查询") - embedding = await get_embedding(message, request_type="info_retrieval") - if not embedding: - logger.error("获取消息嵌入向量失败") - return "" - - related_info = self.get_info_from_db(embedding, limit=3, threshold=threshold) - logger.info(f"知识库检索完成,总耗时: {time.time() - start_time:.3f}秒") - return related_info, {} - - # 2. 对每个主题进行知识库查询 - logger.info(f"开始处理{len(topics)}个主题的知识库查询") - - # 优化:批量获取嵌入向量,减少API调用 - embeddings = {} - topics_batch = [topic for topic in topics if len(topic) > 0] - if message: # 确保消息非空 - topics_batch.append(message) - - # 批量获取嵌入向量 - embed_start_time = time.time() - for text in topics_batch: - if not text or len(text.strip()) == 0: - continue - - try: - embedding = await get_embedding(text, request_type="info_retrieval") - if embedding: - embeddings[text] = embedding - else: - logger.warning(f"获取'{text}'的嵌入向量失败") - except Exception as e: - logger.error(f"获取'{text}'的嵌入向量时发生错误: {str(e)}") - - logger.info(f"批量获取嵌入向量完成,耗时: {time.time() - embed_start_time:.3f}秒") - - if not embeddings: - logger.error("所有嵌入向量获取失败") - return "" - - # 3. 对每个主题进行知识库查询 - all_results = [] - query_start_time = time.time() - - # 首先添加原始消息的查询结果 - if message in embeddings: - original_results = self.get_info_from_db(embeddings[message], limit=3, threshold=threshold, return_raw=True) - if original_results: - for result in original_results: - result["topic"] = "原始消息" - all_results.extend(original_results) - logger.info(f"原始消息查询到{len(original_results)}条结果") - - # 然后添加每个主题的查询结果 - for topic in topics: - if not topic or topic not in embeddings: - continue - - try: - topic_results = self.get_info_from_db(embeddings[topic], limit=3, threshold=threshold, return_raw=True) - if topic_results: - # 添加主题标记 - for result in topic_results: - result["topic"] = topic - all_results.extend(topic_results) - logger.info(f"主题'{topic}'查询到{len(topic_results)}条结果") - except Exception as e: - logger.error(f"查询主题'{topic}'时发生错误: {str(e)}") - - logger.info(f"知识库查询完成,耗时: {time.time() - query_start_time:.3f}秒,共获取{len(all_results)}条结果") - - # 4. 去重和过滤 - process_start_time = time.time() - unique_contents = set() - filtered_results = [] - for result in all_results: - content = result["content"] - if content not in unique_contents: - unique_contents.add(content) - filtered_results.append(result) - - # 5. 按相似度排序 - filtered_results.sort(key=lambda x: x["similarity"], reverse=True) - - # 6. 限制总数量(最多10条) - filtered_results = filtered_results[:10] - logger.info( - f"结果处理完成,耗时: {time.time() - process_start_time:.3f}秒,过滤后剩余{len(filtered_results)}条结果" - ) - - # 7. 格式化输出 - if filtered_results: - format_start_time = time.time() - grouped_results = {} - for result in filtered_results: - topic = result["topic"] - if topic not in grouped_results: - grouped_results[topic] = [] - grouped_results[topic].append(result) - - # 按主题组织输出 - for topic, results in grouped_results.items(): - related_info += f"【主题: {topic}】\n" - for _i, result in enumerate(results, 1): - _similarity = result["similarity"] - content = result["content"].strip() - # 调试:为内容添加序号和相似度信息 - # related_info += f"{i}. [{similarity:.2f}] {content}\n" - related_info += f"{content}\n" - related_info += "\n" - - logger.info(f"格式化输出完成,耗时: {time.time() - format_start_time:.3f}秒") - - logger.info(f"知识库检索总耗时: {time.time() - start_time:.3f}秒") - return related_info, grouped_results - - def get_info_from_db( - self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False - ) -> Union[str, list]: - if not query_embedding: - return "" if not return_raw else [] - # 使用余弦相似度计算 - pipeline = [ - { - "$addFields": { - "dotProduct": { - "$reduce": { - "input": {"$range": [0, {"$size": "$embedding"}]}, - "initialValue": 0, - "in": { - "$add": [ - "$$value", - { - "$multiply": [ - {"$arrayElemAt": ["$embedding", "$$this"]}, - {"$arrayElemAt": [query_embedding, "$$this"]}, - ] - }, - ] - }, - } - }, - "magnitude1": { - "$sqrt": { - "$reduce": { - "input": "$embedding", - "initialValue": 0, - "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}, - } - } - }, - "magnitude2": { - "$sqrt": { - "$reduce": { - "input": query_embedding, - "initialValue": 0, - "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}, - } - } - }, - } - }, - {"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}}, - { - "$match": { - "similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果 - } - }, - {"$sort": {"similarity": -1}}, - {"$limit": limit}, - {"$project": {"content": 1, "similarity": 1}}, - ] - - results = list(db.knowledges.aggregate(pipeline)) - logger.debug(f"知识库查询结果数量: {len(results)}") - - if not results: - return "" if not return_raw else [] - - if return_raw: - return results - else: - # 返回所有找到的内容,用换行分隔 - return "\n".join(str(result["content"]) for result in results) - # subheartflow = SubHeartflow() diff --git a/src/heart_flow/tool_use.py b/src/heart_flow/tool_use.py new file mode 100644 index 000000000..7471e7512 --- /dev/null +++ b/src/heart_flow/tool_use.py @@ -0,0 +1,561 @@ +from src.plugins.models.utils_model import LLM_request +from src.plugins.config.config import global_config +from src.plugins.chat.chat_stream import ChatStream +from src.plugins.memory_system.Hippocampus import HippocampusManager +from src.common.database import db +import time +import json +from src.common.logger import get_module_logger +from src.plugins.chat.utils import get_embedding +from typing import Union +logger = get_module_logger("tool_use") + + +class ToolUser: + def __init__(self): + self.llm_model_tool = LLM_request( + model=global_config.llm_heartflow, temperature=0.2, max_tokens=1000, request_type="tool_use" + ) + + async def _build_tool_prompt(self, message_txt:str, sender_name:str, chat_stream:ChatStream): + """构建工具使用的提示词 + + Args: + message_txt: 用户消息文本 + sender_name: 发送者名称 + chat_stream: 聊天流对象 + + Returns: + str: 构建好的提示词 + """ + from src.plugins.config.config import global_config + + new_messages = list( + db.messages.find({"chat_id": chat_stream.stream_id, "time": {"$gt": time.time()}}) + .sort("time", 1) + .limit(15) + ) + new_messages_str = "" + for msg in new_messages: + if "detailed_plain_text" in msg: + new_messages_str += f"{msg['detailed_plain_text']}" + + # 这些信息应该从调用者传入,而不是从self获取 + bot_name = global_config.BOT_NICKNAME + prompt = "" + prompt += "你正在思考如何回复群里的消息。\n" + prompt += f"你注意到{sender_name}刚刚说:{message_txt}\n" + prompt += f"注意你就是{bot_name},{bot_name}指的就是你。" + prompt += "你现在需要对群里的聊天内容进行回复,现在请你思考,你是否需要额外的信息,或者一些工具来帮你回复,比如回忆或者搜寻已有的知识,或者了解你现在正在做什么,请输出你需要的工具,或者你需要的额外信息。" + + return prompt + + def _define_tools(self): + """定义可用的工具列表 + + Returns: + list: 工具定义列表 + """ + tools = [ + { + "type": "function", + "function": { + "name": "search_knowledge", + "description": "从知识库中搜索相关信息", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "搜索查询关键词" + }, + "threshold": { + "type": "number", + "description": "相似度阈值,0.0到1.0之间" + } + }, + "required": ["query"] + } + } + }, + { + "type": "function", + "function": { + "name": "get_memory", + "description": "从记忆系统中获取相关记忆", + "parameters": { + "type": "object", + "properties": { + "text": { + "type": "string", + "description": "要查询的相关文本" + }, + "max_memory_num": { + "type": "integer", + "description": "最大返回记忆数量" + } + }, + "required": ["text"] + } + } + }, + { + "type": "function", + "function": { + "name": "get_current_task", + "description": "获取当前正在做的事情/最近的任务", + "parameters": { + "type": "object", + "properties": { + "num": { + "type": "integer", + "description": "要获取的任务数量" + }, + "time_info": { + "type": "boolean", + "description": "是否包含时间信息" + } + }, + "required": [] + } + } + } + ] + return tools + + async def _execute_tool_call(self, tool_call, message_txt:str): + """执行特定的工具调用 + + Args: + tool_call: 工具调用对象 + message_txt: 原始消息文本 + + Returns: + dict: 工具调用结果 + """ + try: + function_name = tool_call["function"]["name"] + function_args = json.loads(tool_call["function"]["arguments"]) + + if function_name == "search_knowledge": + return await self._execute_search_knowledge(tool_call, function_args, message_txt) + elif function_name == "get_memory": + return await self._execute_get_memory(tool_call, function_args, message_txt) + elif function_name == "get_current_task": + return await self._execute_get_current_task(tool_call, function_args) + + logger.warning(f"未知工具名称: {function_name}") + return None + except Exception as e: + logger.error(f"执行工具调用时发生错误: {str(e)}") + return None + + async def _execute_search_knowledge(self, tool_call, function_args, message_txt:str): + """执行知识库搜索工具 + + Args: + tool_call: 工具调用对象 + function_args: 工具参数 + message_txt: 原始消息文本 + + Returns: + dict: 工具调用结果 + """ + try: + query = function_args.get("query", message_txt) + threshold = function_args.get("threshold", 0.4) + + # 调用知识库搜索 + embedding = await get_embedding(query, request_type="info_retrieval") + if embedding: + knowledge_info = self.get_info_from_db(embedding, limit=3, threshold=threshold) + return { + "tool_call_id": tool_call["id"], + "role": "tool", + "name": "search_knowledge", + "content": f"知识库搜索结果: {knowledge_info}" + } + return None + except Exception as e: + logger.error(f"知识库搜索工具执行失败: {str(e)}") + return None + + async def _execute_get_memory(self, tool_call, function_args, message_txt:str): + """执行记忆获取工具 + + Args: + tool_call: 工具调用对象 + function_args: 工具参数 + message_txt: 原始消息文本 + + Returns: + dict: 工具调用结果 + """ + try: + text = function_args.get("text", message_txt) + max_memory_num = function_args.get("max_memory_num", 2) + + # 调用记忆系统 + related_memory = await HippocampusManager.get_instance().get_memory_from_text( + text=text, + max_memory_num=max_memory_num, + max_memory_length=2, + max_depth=3, + fast_retrieval=False + ) + + memory_info = "" + if related_memory: + for memory in related_memory: + memory_info += memory[1] + "\n" + + return { + "tool_call_id": tool_call["id"], + "role": "tool", + "name": "get_memory", + "content": f"记忆系统结果: {memory_info if memory_info else '没有找到相关记忆'}" + } + except Exception as e: + logger.error(f"记忆获取工具执行失败: {str(e)}") + return None + + async def _execute_get_current_task(self, tool_call, function_args): + """执行获取当前任务工具 + + Args: + tool_call: 工具调用对象 + function_args: 工具参数 + + Returns: + dict: 工具调用结果 + """ + try: + from src.plugins.schedule.schedule_generator import bot_schedule + + # 获取参数,如果没有提供则使用默认值 + num = function_args.get("num", 1) + time_info = function_args.get("time_info", False) + + # 调用日程系统获取当前任务 + current_task = bot_schedule.get_current_num_task(num=num, time_info=time_info) + + # 格式化返回结果 + if current_task: + task_info = current_task + else: + task_info = "当前没有正在进行的任务" + + return { + "tool_call_id": tool_call["id"], + "role": "tool", + "name": "get_current_task", + "content": f"当前任务信息: {task_info}" + } + except Exception as e: + logger.error(f"获取当前任务工具执行失败: {str(e)}") + return None + + async def use_tool(self, message_txt:str, sender_name:str, chat_stream:ChatStream): + """使用工具辅助思考,判断是否需要额外信息 + + Args: + message_txt: 用户消息文本 + sender_name: 发送者名称 + chat_stream: 聊天流对象 + + Returns: + dict: 工具使用结果 + """ + try: + # 构建提示词 + prompt = await self._build_tool_prompt(message_txt, sender_name, chat_stream) + + # 定义可用工具 + tools = self._define_tools() + + # 使用llm_model_tool发送带工具定义的请求 + payload = { + "model": self.llm_model_tool.model_name, + "messages": [{"role": "user", "content": prompt}], + "max_tokens": global_config.max_response_length, + "tools": tools, + "temperature": 0.2 + } + + logger.debug(f"发送工具调用请求,模型: {self.llm_model_tool.model_name}") + # 发送请求获取模型是否需要调用工具 + response = await self.llm_model_tool._execute_request( + endpoint="/chat/completions", + payload=payload, + prompt=prompt + ) + + # 根据返回值数量判断是否有工具调用 + if len(response) == 3: + content, reasoning_content, tool_calls = response + logger.info(f"工具思考: {tool_calls}") + + # 检查响应中工具调用是否有效 + if not tool_calls: + logger.info("模型返回了空的tool_calls列表") + return {"used_tools": False, "thinking": self.current_mind} + + logger.info(f"模型请求调用{len(tool_calls)}个工具") + tool_results = [] + collected_info = "" + + # 执行所有工具调用 + for tool_call in tool_calls: + result = await self._execute_tool_call(tool_call, message_txt) + if result: + tool_results.append(result) + # 将工具结果添加到收集的信息中 + collected_info += f"\n{result['name']}返回结果: {result['content']}\n" + + # 如果有工具结果,直接返回收集的信息 + if collected_info: + logger.info(f"工具调用收集到信息: {collected_info}") + return { + "used_tools": True, + "collected_info": collected_info, + "thinking": self.current_mind # 保持原始思考不变 + } + else: + # 没有工具调用 + content, reasoning_content = response + logger.info("模型没有请求调用任何工具") + + # 如果没有工具调用或处理失败,直接返回原始思考 + return { + "used_tools": False, + "thinking": self.current_mind + } + + except Exception as e: + logger.error(f"工具调用过程中出错: {str(e)}") + return { + "used_tools": False, + "error": str(e), + "thinking": self.current_mind + } + + + + async def get_prompt_info(self, message: str, threshold: float): + start_time = time.time() + related_info = "" + logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}") + + # 1. 先从LLM获取主题,类似于记忆系统的做法 + topics = [] + # try: + # # 先尝试使用记忆系统的方法获取主题 + # hippocampus = HippocampusManager.get_instance()._hippocampus + # topic_num = min(5, max(1, int(len(message) * 0.1))) + # topics_response = await hippocampus.llm_topic_judge.generate_response(hippocampus.find_topic_llm(message, topic_num)) + + # # 提取关键词 + # topics = re.findall(r"<([^>]+)>", topics_response[0]) + # if not topics: + # topics = [] + # else: + # topics = [ + # topic.strip() + # for topic in ",".join(topics).replace(",", ",").replace("、", ",").replace(" ", ",").split(",") + # if topic.strip() + # ] + + # logger.info(f"从LLM提取的主题: {', '.join(topics)}") + # except Exception as e: + # logger.error(f"从LLM提取主题失败: {str(e)}") + # # 如果LLM提取失败,使用jieba分词提取关键词作为备选 + # words = jieba.cut(message) + # topics = [word for word in words if len(word) > 1][:5] + # logger.info(f"使用jieba提取的主题: {', '.join(topics)}") + + # 如果无法提取到主题,直接使用整个消息 + if not topics: + logger.debug("未能提取到任何主题,使用整个消息进行查询") + embedding = await get_embedding(message, request_type="info_retrieval") + if not embedding: + logger.error("获取消息嵌入向量失败") + return "" + + related_info = self.get_info_from_db(embedding, limit=3, threshold=threshold) + logger.info(f"知识库检索完成,总耗时: {time.time() - start_time:.3f}秒") + return related_info, {} + + # 2. 对每个主题进行知识库查询 + logger.info(f"开始处理{len(topics)}个主题的知识库查询") + + # 优化:批量获取嵌入向量,减少API调用 + embeddings = {} + topics_batch = [topic for topic in topics if len(topic) > 0] + if message: # 确保消息非空 + topics_batch.append(message) + + # 批量获取嵌入向量 + embed_start_time = time.time() + for text in topics_batch: + if not text or len(text.strip()) == 0: + continue + + try: + embedding = await get_embedding(text, request_type="info_retrieval") + if embedding: + embeddings[text] = embedding + else: + logger.warning(f"获取'{text}'的嵌入向量失败") + except Exception as e: + logger.error(f"获取'{text}'的嵌入向量时发生错误: {str(e)}") + + logger.info(f"批量获取嵌入向量完成,耗时: {time.time() - embed_start_time:.3f}秒") + + if not embeddings: + logger.error("所有嵌入向量获取失败") + return "" + + # 3. 对每个主题进行知识库查询 + all_results = [] + query_start_time = time.time() + + # 首先添加原始消息的查询结果 + if message in embeddings: + original_results = self.get_info_from_db(embeddings[message], limit=3, threshold=threshold, return_raw=True) + if original_results: + for result in original_results: + result["topic"] = "原始消息" + all_results.extend(original_results) + logger.info(f"原始消息查询到{len(original_results)}条结果") + + # 然后添加每个主题的查询结果 + for topic in topics: + if not topic or topic not in embeddings: + continue + + try: + topic_results = self.get_info_from_db(embeddings[topic], limit=3, threshold=threshold, return_raw=True) + if topic_results: + # 添加主题标记 + for result in topic_results: + result["topic"] = topic + all_results.extend(topic_results) + logger.info(f"主题'{topic}'查询到{len(topic_results)}条结果") + except Exception as e: + logger.error(f"查询主题'{topic}'时发生错误: {str(e)}") + + logger.info(f"知识库查询完成,耗时: {time.time() - query_start_time:.3f}秒,共获取{len(all_results)}条结果") + + # 4. 去重和过滤 + process_start_time = time.time() + unique_contents = set() + filtered_results = [] + for result in all_results: + content = result["content"] + if content not in unique_contents: + unique_contents.add(content) + filtered_results.append(result) + + # 5. 按相似度排序 + filtered_results.sort(key=lambda x: x["similarity"], reverse=True) + + # 6. 限制总数量(最多10条) + filtered_results = filtered_results[:10] + logger.info( + f"结果处理完成,耗时: {time.time() - process_start_time:.3f}秒,过滤后剩余{len(filtered_results)}条结果" + ) + + # 7. 格式化输出 + if filtered_results: + format_start_time = time.time() + grouped_results = {} + for result in filtered_results: + topic = result["topic"] + if topic not in grouped_results: + grouped_results[topic] = [] + grouped_results[topic].append(result) + + # 按主题组织输出 + for topic, results in grouped_results.items(): + related_info += f"【主题: {topic}】\n" + for _i, result in enumerate(results, 1): + _similarity = result["similarity"] + content = result["content"].strip() + # 调试:为内容添加序号和相似度信息 + # related_info += f"{i}. [{similarity:.2f}] {content}\n" + related_info += f"{content}\n" + related_info += "\n" + + logger.info(f"格式化输出完成,耗时: {time.time() - format_start_time:.3f}秒") + + logger.info(f"知识库检索总耗时: {time.time() - start_time:.3f}秒") + return related_info, grouped_results + + def get_info_from_db( + self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False + ) -> Union[str, list]: + if not query_embedding: + return "" if not return_raw else [] + # 使用余弦相似度计算 + pipeline = [ + { + "$addFields": { + "dotProduct": { + "$reduce": { + "input": {"$range": [0, {"$size": "$embedding"}]}, + "initialValue": 0, + "in": { + "$add": [ + "$$value", + { + "$multiply": [ + {"$arrayElemAt": ["$embedding", "$$this"]}, + {"$arrayElemAt": [query_embedding, "$$this"]}, + ] + }, + ] + }, + } + }, + "magnitude1": { + "$sqrt": { + "$reduce": { + "input": "$embedding", + "initialValue": 0, + "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}, + } + } + }, + "magnitude2": { + "$sqrt": { + "$reduce": { + "input": query_embedding, + "initialValue": 0, + "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}, + } + } + }, + } + }, + {"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}}, + { + "$match": { + "similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果 + } + }, + {"$sort": {"similarity": -1}}, + {"$limit": limit}, + {"$project": {"content": 1, "similarity": 1}}, + ] + + results = list(db.knowledges.aggregate(pipeline)) + logger.debug(f"知识库查询结果数量: {len(results)}") + + if not results: + return "" if not return_raw else [] + + if return_raw: + return results + else: + # 返回所有找到的内容,用换行分隔 + return "\n".join(str(result["content"]) for result in results) \ No newline at end of file diff --git a/src/plugins/models/utils_model.py b/src/plugins/models/utils_model.py index 784bfa1db..1066453ff 100644 --- a/src/plugins/models/utils_model.py +++ b/src/plugins/models/utils_model.py @@ -342,6 +342,7 @@ class LLM_request: "message": { "content": accumulated_content, "reasoning_content": reasoning_content, + # 流式输出可能没有工具调用,此处不需要添加tool_calls字段 } } ], @@ -366,6 +367,7 @@ class LLM_request: "message": { "content": accumulated_content, "reasoning_content": reasoning_content, + # 流式输出可能没有工具调用,此处不需要添加tool_calls字段 } } ], @@ -384,7 +386,13 @@ class LLM_request: # 构造一个伪result以便调用自定义响应处理器或默认处理器 result = { "choices": [ - {"message": {"content": content, "reasoning_content": reasoning_content}} + { + "message": { + "content": content, + "reasoning_content": reasoning_content, + # 流式输出可能没有工具调用,此处不需要添加tool_calls字段 + } + } ], "usage": usage, } @@ -566,6 +574,9 @@ class LLM_request: reasoning_content = message.get("reasoning_content", "") if not reasoning_content: reasoning_content = reasoning + + # 提取工具调用信息 + tool_calls = message.get("tool_calls", None) # 记录token使用情况 usage = result.get("usage", {}) @@ -581,8 +592,12 @@ class LLM_request: request_type=request_type if request_type is not None else self.request_type, endpoint=endpoint, ) - - return content, reasoning_content + + # 只有当tool_calls存在且不为空时才返回 + if tool_calls: + return content, reasoning_content, tool_calls + else: + return content, reasoning_content return "没有返回结果", "" @@ -605,21 +620,33 @@ class LLM_request: return {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"} # 防止小朋友们截图自己的key - async def generate_response(self, prompt: str) -> Tuple[str, str, str]: + async def generate_response(self, prompt: str) -> Tuple: """根据输入的提示生成模型的异步响应""" - content, reasoning_content = await self._execute_request(endpoint="/chat/completions", prompt=prompt) - return content, reasoning_content, self.model_name + response = await self._execute_request(endpoint="/chat/completions", prompt=prompt) + # 根据返回值的长度决定怎么处理 + if len(response) == 3: + content, reasoning_content, tool_calls = response + return content, reasoning_content, self.model_name, tool_calls + else: + content, reasoning_content = response + return content, reasoning_content, self.model_name - async def generate_response_for_image(self, prompt: str, image_base64: str, image_format: str) -> Tuple[str, str]: + async def generate_response_for_image(self, prompt: str, image_base64: str, image_format: str) -> Tuple: """根据输入的提示和图片生成模型的异步响应""" - content, reasoning_content = await self._execute_request( + response = await self._execute_request( endpoint="/chat/completions", prompt=prompt, image_base64=image_base64, image_format=image_format ) - return content, reasoning_content + # 根据返回值的长度决定怎么处理 + if len(response) == 3: + content, reasoning_content, tool_calls = response + return content, reasoning_content, tool_calls + else: + content, reasoning_content = response + return content, reasoning_content - async def generate_response_async(self, prompt: str, **kwargs) -> Union[str, Tuple[str, str]]: + async def generate_response_async(self, prompt: str, **kwargs) -> Union[str, Tuple]: """异步方式根据输入的提示生成模型的响应""" # 构建请求体 data = { @@ -630,10 +657,11 @@ class LLM_request: **kwargs, } - content, reasoning_content = await self._execute_request( + response = await self._execute_request( endpoint="/chat/completions", payload=data, prompt=prompt ) - return content, reasoning_content + # 原样返回响应,不做处理 + return response async def get_embedding(self, text: str) -> Union[list, None]: """异步方法:获取文本的embedding向量 diff --git a/src/tool_use/tool_use.py b/src/tool_use/tool_use.py new file mode 100644 index 000000000..e69de29bb From 54f7b73ec495a075d3dcd3d35ce57a48dbde02fd Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Thu, 10 Apr 2025 23:32:28 +0800 Subject: [PATCH 17/24] =?UTF-8?q?better=EF=BC=9A=E5=BF=83=E6=B5=81?= =?UTF-8?q?=E5=8D=87=E7=BA=A7=EF=BC=8C=E5=A4=A7=E5=A4=A7=E5=87=8F=E5=B0=91?= =?UTF-8?q?=E4=BA=86=E5=A4=8D=E8=AF=BB=E6=83=85=E5=86=B5=EF=BC=8C=E5=B9=B6?= =?UTF-8?q?=E4=B8=94=E7=81=B5=E6=B4=BB=E8=B0=83=E7=94=A8=E5=B7=A5=E5=85=B7?= =?UTF-8?q?=E6=9D=A5=E5=AE=9E=E7=8E=B0=E7=9F=A5=E8=AF=86=E5=92=8C=E8=AE=B0?= =?UTF-8?q?=E5=BF=86=E6=A3=80=E7=B4=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/do_tool/tool_can_use/README.md | 102 ++++ src/do_tool/tool_can_use/__init__.py | 11 + src/do_tool/tool_can_use/base_tool.py | 119 ++++ src/do_tool/tool_can_use/get_current_task.py | 63 +++ src/do_tool/tool_can_use/get_knowledge.py | 147 +++++ src/do_tool/tool_can_use/get_memory.py | 72 +++ src/do_tool/tool_use.py | 172 ++++++ src/heart_flow/observation.py | 65 ++- src/heart_flow/sub_heartflow.py | 14 +- src/heart_flow/tool_use.py | 561 ------------------- 10 files changed, 729 insertions(+), 597 deletions(-) create mode 100644 src/do_tool/tool_can_use/README.md create mode 100644 src/do_tool/tool_can_use/__init__.py create mode 100644 src/do_tool/tool_can_use/base_tool.py create mode 100644 src/do_tool/tool_can_use/get_current_task.py create mode 100644 src/do_tool/tool_can_use/get_knowledge.py create mode 100644 src/do_tool/tool_can_use/get_memory.py create mode 100644 src/do_tool/tool_use.py delete mode 100644 src/heart_flow/tool_use.py diff --git a/src/do_tool/tool_can_use/README.md b/src/do_tool/tool_can_use/README.md new file mode 100644 index 000000000..15c771887 --- /dev/null +++ b/src/do_tool/tool_can_use/README.md @@ -0,0 +1,102 @@ +# 工具系统使用指南 + +## 概述 + +`tool_can_use` 是一个插件式工具系统,允许轻松扩展和注册新工具。每个工具作为独立的文件存在于该目录下,系统会自动发现和注册这些工具。 + +## 工具结构 + +每个工具应该继承 `BaseTool` 基类并实现必要的属性和方法: + +```python +from src.do_tool.tool_can_use.base_tool import BaseTool, register_tool + +class MyNewTool(BaseTool): + # 工具名称,必须唯一 + name = "my_new_tool" + + # 工具描述,告诉LLM这个工具的用途 + description = "这是一个新工具,用于..." + + # 工具参数定义,遵循JSONSchema格式 + parameters = { + "type": "object", + "properties": { + "param1": { + "type": "string", + "description": "参数1的描述" + }, + "param2": { + "type": "integer", + "description": "参数2的描述" + } + }, + "required": ["param1"] # 必需的参数列表 + } + + async def execute(self, function_args, message_txt=""): + """执行工具逻辑 + + Args: + function_args: 工具调用参数 + message_txt: 原始消息文本 + + Returns: + Dict: 包含执行结果的字典,必须包含name和content字段 + """ + # 实现工具逻辑 + result = f"工具执行结果: {function_args.get('param1')}" + + return { + "name": self.name, + "content": result + } + +# 注册工具 +register_tool(MyNewTool) +``` + +## 自动注册机制 + +工具系统通过以下步骤自动注册工具: + +1. 在`__init__.py`中,`discover_tools()`函数会自动遍历当前目录中的所有Python文件 +2. 对于每个文件,系统会寻找继承自`BaseTool`的类 +3. 这些类会被自动注册到工具注册表中 + +只要确保在每个工具文件的末尾调用`register_tool(YourToolClass)`,工具就会被自动注册。 + +## 添加新工具步骤 + +1. 在`tool_can_use`目录下创建新的Python文件(如`my_new_tool.py`) +2. 导入`BaseTool`和`register_tool` +3. 创建继承自`BaseTool`的工具类 +4. 实现必要的属性(`name`, `description`, `parameters`) +5. 实现`execute`方法 +6. 使用`register_tool`注册工具 + +## 与ToolUser整合 + +`ToolUser`类已经更新为使用这个新的工具系统,它会: + +1. 自动获取所有已注册工具的定义 +2. 基于工具名称找到对应的工具实例 +3. 调用工具的`execute`方法 + +## 使用示例 + +```python +from src.do_tool.tool_use import ToolUser + +# 创建工具用户 +tool_user = ToolUser() + +# 使用工具 +result = await tool_user.use_tool(message_txt="查询关于Python的知识", sender_name="用户", chat_stream=chat_stream) + +# 处理结果 +if result["used_tools"]: + print("工具使用结果:", result["collected_info"]) +else: + print("未使用工具") +``` \ No newline at end of file diff --git a/src/do_tool/tool_can_use/__init__.py b/src/do_tool/tool_can_use/__init__.py new file mode 100644 index 000000000..cc196d07a --- /dev/null +++ b/src/do_tool/tool_can_use/__init__.py @@ -0,0 +1,11 @@ +from src.do_tool.tool_can_use.base_tool import ( + BaseTool, + register_tool, + discover_tools, + get_all_tool_definitions, + get_tool_instance, + TOOL_REGISTRY +) + +# 自动发现并注册工具 +discover_tools() \ No newline at end of file diff --git a/src/do_tool/tool_can_use/base_tool.py b/src/do_tool/tool_can_use/base_tool.py new file mode 100644 index 000000000..03aac5e4c --- /dev/null +++ b/src/do_tool/tool_can_use/base_tool.py @@ -0,0 +1,119 @@ +from typing import Dict, List, Any, Optional, Union, Type +import inspect +import importlib +import pkgutil +import os +from src.common.logger import get_module_logger + +logger = get_module_logger("base_tool") + +# 工具注册表 +TOOL_REGISTRY = {} + +class BaseTool: + """所有工具的基类""" + # 工具名称,子类必须重写 + name = None + # 工具描述,子类必须重写 + description = None + # 工具参数定义,子类必须重写 + parameters = None + + @classmethod + def get_tool_definition(cls) -> Dict[str, Any]: + """获取工具定义,用于LLM工具调用 + + Returns: + Dict: 工具定义字典 + """ + if not cls.name or not cls.description or not cls.parameters: + raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性") + + return { + "type": "function", + "function": { + "name": cls.name, + "description": cls.description, + "parameters": cls.parameters + } + } + + async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]: + """执行工具函数 + + Args: + function_args: 工具调用参数 + message_txt: 原始消息文本 + + Returns: + Dict: 工具执行结果 + """ + raise NotImplementedError("子类必须实现execute方法") + + +def register_tool(tool_class: Type[BaseTool]): + """注册工具到全局注册表 + + Args: + tool_class: 工具类 + """ + if not issubclass(tool_class, BaseTool): + raise TypeError(f"{tool_class.__name__} 不是 BaseTool 的子类") + + tool_name = tool_class.name + if not tool_name: + raise ValueError(f"工具类 {tool_class.__name__} 没有定义 name 属性") + + TOOL_REGISTRY[tool_name] = tool_class + logger.info(f"已注册工具: {tool_name}") + + +def discover_tools(): + """自动发现并注册tool_can_use目录下的所有工具""" + # 获取当前目录路径 + current_dir = os.path.dirname(os.path.abspath(__file__)) + package_name = os.path.basename(current_dir) + parent_dir = os.path.dirname(current_dir) + + # 导入当前包 + package = importlib.import_module(f"src.do_tool.{package_name}") + + # 遍历包中的所有模块 + for _, module_name, is_pkg in pkgutil.iter_modules([current_dir]): + # 跳过当前模块和__pycache__ + if module_name == "base_tool" or module_name.startswith("__"): + continue + + # 导入模块 + module = importlib.import_module(f"src.do_tool.{package_name}.{module_name}") + + # 查找模块中的工具类 + for name, obj in inspect.getmembers(module): + if inspect.isclass(obj) and issubclass(obj, BaseTool) and obj != BaseTool: + register_tool(obj) + + logger.info(f"工具发现完成,共注册 {len(TOOL_REGISTRY)} 个工具") + + +def get_all_tool_definitions() -> List[Dict[str, Any]]: + """获取所有已注册工具的定义 + + Returns: + List[Dict]: 工具定义列表 + """ + return [tool_class().get_tool_definition() for tool_class in TOOL_REGISTRY.values()] + + +def get_tool_instance(tool_name: str) -> Optional[BaseTool]: + """获取指定名称的工具实例 + + Args: + tool_name: 工具名称 + + Returns: + Optional[BaseTool]: 工具实例,如果找不到则返回None + """ + tool_class = TOOL_REGISTRY.get(tool_name) + if not tool_class: + return None + return tool_class() \ No newline at end of file diff --git a/src/do_tool/tool_can_use/get_current_task.py b/src/do_tool/tool_can_use/get_current_task.py new file mode 100644 index 000000000..dd3402357 --- /dev/null +++ b/src/do_tool/tool_can_use/get_current_task.py @@ -0,0 +1,63 @@ +from src.do_tool.tool_can_use.base_tool import BaseTool, register_tool +from src.plugins.schedule.schedule_generator import bot_schedule +from src.common.logger import get_module_logger +from typing import Dict, Any + +logger = get_module_logger("get_current_task_tool") + +class GetCurrentTaskTool(BaseTool): + """获取当前正在做的事情/最近的任务工具""" + name = "get_current_task" + description = "获取当前正在做的事情/最近的任务" + parameters = { + "type": "object", + "properties": { + "num": { + "type": "integer", + "description": "要获取的任务数量" + }, + "time_info": { + "type": "boolean", + "description": "是否包含时间信息" + } + }, + "required": [] + } + + async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]: + """执行获取当前任务 + + Args: + function_args: 工具参数 + message_txt: 原始消息文本,此工具不使用 + + Returns: + Dict: 工具执行结果 + """ + try: + # 获取参数,如果没有提供则使用默认值 + num = function_args.get("num", 1) + time_info = function_args.get("time_info", False) + + # 调用日程系统获取当前任务 + current_task = bot_schedule.get_current_num_task(num=num, time_info=time_info) + + # 格式化返回结果 + if current_task: + task_info = current_task + else: + task_info = "当前没有正在进行的任务" + + return { + "name": "get_current_task", + "content": f"当前任务信息: {task_info}" + } + except Exception as e: + logger.error(f"获取当前任务工具执行失败: {str(e)}") + return { + "name": "get_current_task", + "content": f"获取当前任务失败: {str(e)}" + } + +# 注册工具 +register_tool(GetCurrentTaskTool) \ No newline at end of file diff --git a/src/do_tool/tool_can_use/get_knowledge.py b/src/do_tool/tool_can_use/get_knowledge.py new file mode 100644 index 000000000..06ea7a91b --- /dev/null +++ b/src/do_tool/tool_can_use/get_knowledge.py @@ -0,0 +1,147 @@ +from src.do_tool.tool_can_use.base_tool import BaseTool, register_tool +from src.plugins.chat.utils import get_embedding +from src.common.database import db +from src.common.logger import get_module_logger +from typing import Dict, Any, Union, List + +logger = get_module_logger("get_knowledge_tool") + +class SearchKnowledgeTool(BaseTool): + """从知识库中搜索相关信息的工具""" + name = "search_knowledge" + description = "从知识库中搜索相关信息" + parameters = { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "搜索查询关键词" + }, + "threshold": { + "type": "number", + "description": "相似度阈值,0.0到1.0之间" + } + }, + "required": ["query"] + } + + async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]: + """执行知识库搜索 + + Args: + function_args: 工具参数 + message_txt: 原始消息文本 + + Returns: + Dict: 工具执行结果 + """ + try: + query = function_args.get("query", message_txt) + threshold = function_args.get("threshold", 0.4) + + # 调用知识库搜索 + embedding = await get_embedding(query, request_type="info_retrieval") + if embedding: + knowledge_info = self.get_info_from_db(embedding, limit=3, threshold=threshold) + if knowledge_info: + content = f"你知道这些知识: {knowledge_info}" + else: + content = f"你不太了解有关{query}的知识" + return { + "name": "search_knowledge", + "content": content + } + return { + "name": "search_knowledge", + "content": f"无法获取关于'{query}'的嵌入向量" + } + except Exception as e: + logger.error(f"知识库搜索工具执行失败: {str(e)}") + return { + "name": "search_knowledge", + "content": f"知识库搜索失败: {str(e)}" + } + + def get_info_from_db( + self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False + ) -> Union[str, list]: + """从数据库中获取相关信息 + + Args: + query_embedding: 查询的嵌入向量 + limit: 最大返回结果数 + threshold: 相似度阈值 + return_raw: 是否返回原始结果 + + Returns: + Union[str, list]: 格式化的信息字符串或原始结果列表 + """ + if not query_embedding: + return "" if not return_raw else [] + + # 使用余弦相似度计算 + pipeline = [ + { + "$addFields": { + "dotProduct": { + "$reduce": { + "input": {"$range": [0, {"$size": "$embedding"}]}, + "initialValue": 0, + "in": { + "$add": [ + "$$value", + { + "$multiply": [ + {"$arrayElemAt": ["$embedding", "$$this"]}, + {"$arrayElemAt": [query_embedding, "$$this"]}, + ] + }, + ] + }, + } + }, + "magnitude1": { + "$sqrt": { + "$reduce": { + "input": "$embedding", + "initialValue": 0, + "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}, + } + } + }, + "magnitude2": { + "$sqrt": { + "$reduce": { + "input": query_embedding, + "initialValue": 0, + "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}, + } + } + }, + } + }, + {"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}}, + { + "$match": { + "similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果 + } + }, + {"$sort": {"similarity": -1}}, + {"$limit": limit}, + {"$project": {"content": 1, "similarity": 1}}, + ] + + results = list(db.knowledges.aggregate(pipeline)) + logger.debug(f"知识库查询结果数量: {len(results)}") + + if not results: + return "" if not return_raw else [] + + if return_raw: + return results + else: + # 返回所有找到的内容,用换行分隔 + return "\n".join(str(result["content"]) for result in results) + +# 注册工具 +register_tool(SearchKnowledgeTool) diff --git a/src/do_tool/tool_can_use/get_memory.py b/src/do_tool/tool_can_use/get_memory.py new file mode 100644 index 000000000..171e8486a --- /dev/null +++ b/src/do_tool/tool_can_use/get_memory.py @@ -0,0 +1,72 @@ +from src.do_tool.tool_can_use.base_tool import BaseTool, register_tool +from src.plugins.memory_system.Hippocampus import HippocampusManager +from src.common.logger import get_module_logger +from typing import Dict, Any + +logger = get_module_logger("get_memory_tool") + +class GetMemoryTool(BaseTool): + """从记忆系统中获取相关记忆的工具""" + name = "get_memory" + description = "从记忆系统中获取相关记忆" + parameters = { + "type": "object", + "properties": { + "text": { + "type": "string", + "description": "要查询的相关文本" + }, + "max_memory_num": { + "type": "integer", + "description": "最大返回记忆数量" + } + }, + "required": ["text"] + } + + async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]: + """执行记忆获取 + + Args: + function_args: 工具参数 + message_txt: 原始消息文本 + + Returns: + Dict: 工具执行结果 + """ + try: + text = function_args.get("text", message_txt) + max_memory_num = function_args.get("max_memory_num", 2) + + # 调用记忆系统 + related_memory = await HippocampusManager.get_instance().get_memory_from_text( + text=text, + max_memory_num=max_memory_num, + max_memory_length=2, + max_depth=3, + fast_retrieval=False + ) + + memory_info = "" + if related_memory: + for memory in related_memory: + memory_info += memory[1] + "\n" + + if memory_info: + content = f"你记得这些事情: {memory_info}" + else: + content = f"你不太记得有关{text}的记忆,你对此不太了解" + + return { + "name": "get_memory", + "content": content + } + except Exception as e: + logger.error(f"记忆获取工具执行失败: {str(e)}") + return { + "name": "get_memory", + "content": f"记忆获取失败: {str(e)}" + } + +# 注册工具 +register_tool(GetMemoryTool) \ No newline at end of file diff --git a/src/do_tool/tool_use.py b/src/do_tool/tool_use.py new file mode 100644 index 000000000..a2e23ab21 --- /dev/null +++ b/src/do_tool/tool_use.py @@ -0,0 +1,172 @@ +from src.plugins.models.utils_model import LLM_request +from src.plugins.config.config import global_config +from src.plugins.chat.chat_stream import ChatStream +from src.common.database import db +import time +import json +from src.common.logger import get_module_logger +from typing import Union +from src.do_tool.tool_can_use import get_all_tool_definitions, get_tool_instance + +logger = get_module_logger("tool_use") + + +class ToolUser: + def __init__(self): + self.llm_model_tool = LLM_request( + model=global_config.llm_heartflow, temperature=0.2, max_tokens=1000, request_type="tool_use" + ) + + async def _build_tool_prompt(self, message_txt:str, sender_name:str, chat_stream:ChatStream): + """构建工具使用的提示词 + + Args: + message_txt: 用户消息文本 + sender_name: 发送者名称 + chat_stream: 聊天流对象 + + Returns: + str: 构建好的提示词 + """ + new_messages = list( + db.messages.find({"chat_id": chat_stream.stream_id, "time": {"$gt": time.time()}}) + .sort("time", 1) + .limit(15) + ) + new_messages_str = "" + for msg in new_messages: + if "detailed_plain_text" in msg: + new_messages_str += f"{msg['detailed_plain_text']}" + + # 这些信息应该从调用者传入,而不是从self获取 + bot_name = global_config.BOT_NICKNAME + prompt = "" + prompt += "你正在思考如何回复群里的消息。\n" + prompt += f"你注意到{sender_name}刚刚说:{message_txt}\n" + prompt += f"注意你就是{bot_name},{bot_name}指的就是你。" + prompt += "你现在需要对群里的聊天内容进行回复,现在请你思考,你是否需要额外的信息,或者一些工具来帮你回复,比如回忆或者搜寻已有的知识,或者了解你现在正在做什么,请输出你需要的工具,或者你需要的额外信息。" + + return prompt + + def _define_tools(self): + """获取所有已注册工具的定义 + + Returns: + list: 工具定义列表 + """ + return get_all_tool_definitions() + + async def _execute_tool_call(self, tool_call, message_txt:str): + """执行特定的工具调用 + + Args: + tool_call: 工具调用对象 + message_txt: 原始消息文本 + + Returns: + dict: 工具调用结果 + """ + try: + function_name = tool_call["function"]["name"] + function_args = json.loads(tool_call["function"]["arguments"]) + + # 获取对应工具实例 + tool_instance = get_tool_instance(function_name) + if not tool_instance: + logger.warning(f"未知工具名称: {function_name}") + return None + + # 执行工具 + result = await tool_instance.execute(function_args, message_txt) + if result: + return { + "tool_call_id": tool_call["id"], + "role": "tool", + "name": function_name, + "content": result["content"] + } + return None + except Exception as e: + logger.error(f"执行工具调用时发生错误: {str(e)}") + return None + + async def use_tool(self, message_txt:str, sender_name:str, chat_stream:ChatStream): + """使用工具辅助思考,判断是否需要额外信息 + + Args: + message_txt: 用户消息文本 + sender_name: 发送者名称 + chat_stream: 聊天流对象 + + Returns: + dict: 工具使用结果 + """ + try: + # 构建提示词 + prompt = await self._build_tool_prompt(message_txt, sender_name, chat_stream) + + # 定义可用工具 + tools = self._define_tools() + + # 使用llm_model_tool发送带工具定义的请求 + payload = { + "model": self.llm_model_tool.model_name, + "messages": [{"role": "user", "content": prompt}], + "max_tokens": global_config.max_response_length, + "tools": tools, + "temperature": 0.2 + } + + logger.debug(f"发送工具调用请求,模型: {self.llm_model_tool.model_name}") + # 发送请求获取模型是否需要调用工具 + response = await self.llm_model_tool._execute_request( + endpoint="/chat/completions", + payload=payload, + prompt=prompt + ) + + # 根据返回值数量判断是否有工具调用 + if len(response) == 3: + content, reasoning_content, tool_calls = response + logger.info(f"工具思考: {tool_calls}") + + # 检查响应中工具调用是否有效 + if not tool_calls: + logger.info("模型返回了空的tool_calls列表") + return {"used_tools": False} + + logger.info(f"模型请求调用{len(tool_calls)}个工具") + tool_results = [] + collected_info = "" + + # 执行所有工具调用 + for tool_call in tool_calls: + result = await self._execute_tool_call(tool_call, message_txt) + if result: + tool_results.append(result) + # 将工具结果添加到收集的信息中 + collected_info += f"\n{result['name']}返回结果: {result['content']}\n" + + # 如果有工具结果,直接返回收集的信息 + if collected_info: + logger.info(f"工具调用收集到信息: {collected_info}") + return { + "used_tools": True, + "collected_info": collected_info, + } + else: + # 没有工具调用 + content, reasoning_content = response + logger.info("模型没有请求调用任何工具") + + # 如果没有工具调用或处理失败,直接返回原始思考 + return { + "used_tools": False, + } + + except Exception as e: + logger.error(f"工具调用过程中出错: {str(e)}") + return { + "used_tools": False, + "error": str(e), + } \ No newline at end of file diff --git a/src/heart_flow/observation.py b/src/heart_flow/observation.py index 818f1775d..f507374cb 100644 --- a/src/heart_flow/observation.py +++ b/src/heart_flow/observation.py @@ -68,20 +68,24 @@ class ChattingObservation(Observation): self.translate_message_list_to_str() # 更新观察次数 - self.observe_times += 1 + # self.observe_times += 1 self.last_observe_time = new_messages[-1]["time"] # 检查是否需要更新summary - current_time = int(datetime.now().timestamp()) - if current_time - self.last_summary_time >= 30: # 如果超过30秒,重置计数 - self.summary_count = 0 - self.last_summary_time = current_time + # current_time = int(datetime.now().timestamp()) + # if current_time - self.last_summary_time >= 30: # 如果超过30秒,重置计数 + # self.summary_count = 0 + # self.last_summary_time = current_time - if self.summary_count < self.max_update_in_30s: # 如果30秒内更新次数小于2次 - await self.update_talking_summary(new_messages_str) - self.summary_count += 1 + # if self.summary_count < self.max_update_in_30s: # 如果30秒内更新次数小于2次 + # await self.update_talking_summary(new_messages_str) + # print(f"更新聊天总结:{self.observe_info}11111111111111") + # self.summary_count += 1 + updated_observe_info = await self.update_talking_summary(new_messages_str) + print(f"更新聊天总结:{updated_observe_info}11111111111111") + self.observe_info = updated_observe_info - return self.observe_info + return updated_observe_info async def carefully_observe(self): # 查找新消息,限制最多40条 @@ -110,43 +114,46 @@ class ChattingObservation(Observation): self.observe_times += 1 self.last_observe_time = new_messages[-1]["time"] - await self.update_talking_summary(new_messages_str) - return self.observe_info + updated_observe_info = await self.update_talking_summary(new_messages_str) + self.observe_info = updated_observe_info + return updated_observe_info async def update_talking_summary(self, new_messages_str): # 基于已经有的talking_summary,和新的talking_message,生成一个summary # print(f"更新聊天总结:{self.talking_summary}") # 开始构建prompt - prompt_personality = "你" - # person - individuality = Individuality.get_instance() + # prompt_personality = "你" + # # person + # individuality = Individuality.get_instance() - personality_core = individuality.personality.personality_core - prompt_personality += personality_core + # personality_core = individuality.personality.personality_core + # prompt_personality += personality_core - personality_sides = individuality.personality.personality_sides - random.shuffle(personality_sides) - prompt_personality += f",{personality_sides[0]}" + # personality_sides = individuality.personality.personality_sides + # random.shuffle(personality_sides) + # prompt_personality += f",{personality_sides[0]}" - identity_detail = individuality.identity.identity_detail - random.shuffle(identity_detail) - prompt_personality += f",{identity_detail[0]}" + # identity_detail = individuality.identity.identity_detail + # random.shuffle(identity_detail) + # prompt_personality += f",{identity_detail[0]}" - personality_info = prompt_personality + # personality_info = prompt_personality prompt = "" - prompt += f"{personality_info},请注意识别你自己的聊天发言" - prompt += f"你的名字叫:{self.name},你的昵称是:{self.nick_name}\n" + # prompt += f"{personality_info}" + prompt += f"你的名字叫:{self.name}\n,标识'{self.name}'的都是你自己说的话" prompt += f"你正在参与一个qq群聊的讨论,你记得这个群之前在聊的内容是:{self.observe_info}\n" prompt += f"现在群里的群友们产生了新的讨论,有了新的发言,具体内容如下:{new_messages_str}\n" - prompt += """以上是群里在进行的聊天,请你对这个聊天内容进行总结,总结内容要包含聊天的大致内容, - 以及聊天中的一些重要信息,注意识别你自己的发言,记得不要分点,不要太长,精简的概括成一段文本\n""" + prompt += """以上是群里在进行的聊天,请你对这个聊天内容进行总结,总结内容要包含聊天的大致内容,目前最新讨论的话题 + 以及聊天中的一些重要信息,记得不要分点,精简的概括成一段文本\n""" prompt += "总结概括:" try: - self.observe_info, reasoning_content = await self.llm_summary.generate_response_async(prompt) + updated_observe_info, reasoning_content = await self.llm_summary.generate_response_async(prompt) except Exception as e: print(f"获取总结失败: {e}") - self.observe_info = "" + updated_observe_info = "" + + return updated_observe_info # print(f"prompt:{prompt}") # print(f"self.observe_info:{self.observe_info}") diff --git a/src/heart_flow/sub_heartflow.py b/src/heart_flow/sub_heartflow.py index 6d0138987..80d7efb61 100644 --- a/src/heart_flow/sub_heartflow.py +++ b/src/heart_flow/sub_heartflow.py @@ -17,7 +17,7 @@ from src.plugins.chat.chat_stream import ChatStream from src.plugins.person_info.relationship_manager import relationship_manager from src.plugins.chat.utils import get_recent_group_speaker import json -from src.heart_flow.tool_use import ToolUser +from src.do_tool.tool_use import ToolUser subheartflow_config = LogConfig( # 使用海马体专用样式 @@ -133,14 +133,13 @@ class SubHeartflow: tool_result = await self.tool_user.use_tool(message_txt, sender_name, chat_stream) # 如果工具被使用且获得了结果,将收集到的信息合并到思考中 + collected_info = "" if tool_result.get("used_tools", False): logger.info("使用工具收集了信息") # 如果有收集到的信息,将其添加到当前思考中 if "collected_info" in tool_result: collected_info = tool_result["collected_info"] - - # 开始构建prompt prompt_personality = f"你的名字是{self.bot_name},你" @@ -178,6 +177,7 @@ class SubHeartflow: ) prompt = "" + # prompt += f"麦麦的总体想法是:{self.main_heartflow_info}\n\n" if tool_result.get("used_tools", False): prompt += f"{collected_info}\n" prompt += f"{relation_prompt_all}\n" @@ -187,10 +187,10 @@ class SubHeartflow: prompt += f"现在你正在上网,和qq群里的网友们聊天,群里正在聊的话题是:{chat_observe_info}\n" prompt += f"你现在{mood_info}\n" prompt += f"你注意到{sender_name}刚刚说:{message_txt}\n" - prompt += "现在你接下去继续思考,产生新的想法,不要分点输出,输出连贯的内心独白,不要太长," - prompt += "思考时可以想想如何对群聊内容进行回复。回复的要求是:平淡一些,简短一些,说中文,不要刻意突出自身学科背景,尽量不要说你说过的话 ,注意只输出回复内容。" - prompt += "请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等)," - prompt += f"记得结合上述的消息,生成符合内心想法的内心独白,文字不要浮夸,注意你就是{self.bot_name},{self.bot_name}指的就是你。" + prompt += "现在你接下去继续思考,产生新的想法,不要分点输出,输出连贯的内心独白" + prompt += "思考时可以想想如何对群聊内容进行回复。回复的要求是:平淡一些,简短一些,说中文,尽量不要说你说过的话\n" + prompt += "请注意不要输出多余内容(包括前后缀,冒号和引号,括号, 表情,其他描述等)," + prompt += f"记得结合上述的消息,生成内心想法,文字不要浮夸,注意你就是{self.bot_name},{self.bot_name}指的就是你。" try: response, reasoning_content = await self.llm_model.generate_response_async(prompt) diff --git a/src/heart_flow/tool_use.py b/src/heart_flow/tool_use.py deleted file mode 100644 index 7471e7512..000000000 --- a/src/heart_flow/tool_use.py +++ /dev/null @@ -1,561 +0,0 @@ -from src.plugins.models.utils_model import LLM_request -from src.plugins.config.config import global_config -from src.plugins.chat.chat_stream import ChatStream -from src.plugins.memory_system.Hippocampus import HippocampusManager -from src.common.database import db -import time -import json -from src.common.logger import get_module_logger -from src.plugins.chat.utils import get_embedding -from typing import Union -logger = get_module_logger("tool_use") - - -class ToolUser: - def __init__(self): - self.llm_model_tool = LLM_request( - model=global_config.llm_heartflow, temperature=0.2, max_tokens=1000, request_type="tool_use" - ) - - async def _build_tool_prompt(self, message_txt:str, sender_name:str, chat_stream:ChatStream): - """构建工具使用的提示词 - - Args: - message_txt: 用户消息文本 - sender_name: 发送者名称 - chat_stream: 聊天流对象 - - Returns: - str: 构建好的提示词 - """ - from src.plugins.config.config import global_config - - new_messages = list( - db.messages.find({"chat_id": chat_stream.stream_id, "time": {"$gt": time.time()}}) - .sort("time", 1) - .limit(15) - ) - new_messages_str = "" - for msg in new_messages: - if "detailed_plain_text" in msg: - new_messages_str += f"{msg['detailed_plain_text']}" - - # 这些信息应该从调用者传入,而不是从self获取 - bot_name = global_config.BOT_NICKNAME - prompt = "" - prompt += "你正在思考如何回复群里的消息。\n" - prompt += f"你注意到{sender_name}刚刚说:{message_txt}\n" - prompt += f"注意你就是{bot_name},{bot_name}指的就是你。" - prompt += "你现在需要对群里的聊天内容进行回复,现在请你思考,你是否需要额外的信息,或者一些工具来帮你回复,比如回忆或者搜寻已有的知识,或者了解你现在正在做什么,请输出你需要的工具,或者你需要的额外信息。" - - return prompt - - def _define_tools(self): - """定义可用的工具列表 - - Returns: - list: 工具定义列表 - """ - tools = [ - { - "type": "function", - "function": { - "name": "search_knowledge", - "description": "从知识库中搜索相关信息", - "parameters": { - "type": "object", - "properties": { - "query": { - "type": "string", - "description": "搜索查询关键词" - }, - "threshold": { - "type": "number", - "description": "相似度阈值,0.0到1.0之间" - } - }, - "required": ["query"] - } - } - }, - { - "type": "function", - "function": { - "name": "get_memory", - "description": "从记忆系统中获取相关记忆", - "parameters": { - "type": "object", - "properties": { - "text": { - "type": "string", - "description": "要查询的相关文本" - }, - "max_memory_num": { - "type": "integer", - "description": "最大返回记忆数量" - } - }, - "required": ["text"] - } - } - }, - { - "type": "function", - "function": { - "name": "get_current_task", - "description": "获取当前正在做的事情/最近的任务", - "parameters": { - "type": "object", - "properties": { - "num": { - "type": "integer", - "description": "要获取的任务数量" - }, - "time_info": { - "type": "boolean", - "description": "是否包含时间信息" - } - }, - "required": [] - } - } - } - ] - return tools - - async def _execute_tool_call(self, tool_call, message_txt:str): - """执行特定的工具调用 - - Args: - tool_call: 工具调用对象 - message_txt: 原始消息文本 - - Returns: - dict: 工具调用结果 - """ - try: - function_name = tool_call["function"]["name"] - function_args = json.loads(tool_call["function"]["arguments"]) - - if function_name == "search_knowledge": - return await self._execute_search_knowledge(tool_call, function_args, message_txt) - elif function_name == "get_memory": - return await self._execute_get_memory(tool_call, function_args, message_txt) - elif function_name == "get_current_task": - return await self._execute_get_current_task(tool_call, function_args) - - logger.warning(f"未知工具名称: {function_name}") - return None - except Exception as e: - logger.error(f"执行工具调用时发生错误: {str(e)}") - return None - - async def _execute_search_knowledge(self, tool_call, function_args, message_txt:str): - """执行知识库搜索工具 - - Args: - tool_call: 工具调用对象 - function_args: 工具参数 - message_txt: 原始消息文本 - - Returns: - dict: 工具调用结果 - """ - try: - query = function_args.get("query", message_txt) - threshold = function_args.get("threshold", 0.4) - - # 调用知识库搜索 - embedding = await get_embedding(query, request_type="info_retrieval") - if embedding: - knowledge_info = self.get_info_from_db(embedding, limit=3, threshold=threshold) - return { - "tool_call_id": tool_call["id"], - "role": "tool", - "name": "search_knowledge", - "content": f"知识库搜索结果: {knowledge_info}" - } - return None - except Exception as e: - logger.error(f"知识库搜索工具执行失败: {str(e)}") - return None - - async def _execute_get_memory(self, tool_call, function_args, message_txt:str): - """执行记忆获取工具 - - Args: - tool_call: 工具调用对象 - function_args: 工具参数 - message_txt: 原始消息文本 - - Returns: - dict: 工具调用结果 - """ - try: - text = function_args.get("text", message_txt) - max_memory_num = function_args.get("max_memory_num", 2) - - # 调用记忆系统 - related_memory = await HippocampusManager.get_instance().get_memory_from_text( - text=text, - max_memory_num=max_memory_num, - max_memory_length=2, - max_depth=3, - fast_retrieval=False - ) - - memory_info = "" - if related_memory: - for memory in related_memory: - memory_info += memory[1] + "\n" - - return { - "tool_call_id": tool_call["id"], - "role": "tool", - "name": "get_memory", - "content": f"记忆系统结果: {memory_info if memory_info else '没有找到相关记忆'}" - } - except Exception as e: - logger.error(f"记忆获取工具执行失败: {str(e)}") - return None - - async def _execute_get_current_task(self, tool_call, function_args): - """执行获取当前任务工具 - - Args: - tool_call: 工具调用对象 - function_args: 工具参数 - - Returns: - dict: 工具调用结果 - """ - try: - from src.plugins.schedule.schedule_generator import bot_schedule - - # 获取参数,如果没有提供则使用默认值 - num = function_args.get("num", 1) - time_info = function_args.get("time_info", False) - - # 调用日程系统获取当前任务 - current_task = bot_schedule.get_current_num_task(num=num, time_info=time_info) - - # 格式化返回结果 - if current_task: - task_info = current_task - else: - task_info = "当前没有正在进行的任务" - - return { - "tool_call_id": tool_call["id"], - "role": "tool", - "name": "get_current_task", - "content": f"当前任务信息: {task_info}" - } - except Exception as e: - logger.error(f"获取当前任务工具执行失败: {str(e)}") - return None - - async def use_tool(self, message_txt:str, sender_name:str, chat_stream:ChatStream): - """使用工具辅助思考,判断是否需要额外信息 - - Args: - message_txt: 用户消息文本 - sender_name: 发送者名称 - chat_stream: 聊天流对象 - - Returns: - dict: 工具使用结果 - """ - try: - # 构建提示词 - prompt = await self._build_tool_prompt(message_txt, sender_name, chat_stream) - - # 定义可用工具 - tools = self._define_tools() - - # 使用llm_model_tool发送带工具定义的请求 - payload = { - "model": self.llm_model_tool.model_name, - "messages": [{"role": "user", "content": prompt}], - "max_tokens": global_config.max_response_length, - "tools": tools, - "temperature": 0.2 - } - - logger.debug(f"发送工具调用请求,模型: {self.llm_model_tool.model_name}") - # 发送请求获取模型是否需要调用工具 - response = await self.llm_model_tool._execute_request( - endpoint="/chat/completions", - payload=payload, - prompt=prompt - ) - - # 根据返回值数量判断是否有工具调用 - if len(response) == 3: - content, reasoning_content, tool_calls = response - logger.info(f"工具思考: {tool_calls}") - - # 检查响应中工具调用是否有效 - if not tool_calls: - logger.info("模型返回了空的tool_calls列表") - return {"used_tools": False, "thinking": self.current_mind} - - logger.info(f"模型请求调用{len(tool_calls)}个工具") - tool_results = [] - collected_info = "" - - # 执行所有工具调用 - for tool_call in tool_calls: - result = await self._execute_tool_call(tool_call, message_txt) - if result: - tool_results.append(result) - # 将工具结果添加到收集的信息中 - collected_info += f"\n{result['name']}返回结果: {result['content']}\n" - - # 如果有工具结果,直接返回收集的信息 - if collected_info: - logger.info(f"工具调用收集到信息: {collected_info}") - return { - "used_tools": True, - "collected_info": collected_info, - "thinking": self.current_mind # 保持原始思考不变 - } - else: - # 没有工具调用 - content, reasoning_content = response - logger.info("模型没有请求调用任何工具") - - # 如果没有工具调用或处理失败,直接返回原始思考 - return { - "used_tools": False, - "thinking": self.current_mind - } - - except Exception as e: - logger.error(f"工具调用过程中出错: {str(e)}") - return { - "used_tools": False, - "error": str(e), - "thinking": self.current_mind - } - - - - async def get_prompt_info(self, message: str, threshold: float): - start_time = time.time() - related_info = "" - logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}") - - # 1. 先从LLM获取主题,类似于记忆系统的做法 - topics = [] - # try: - # # 先尝试使用记忆系统的方法获取主题 - # hippocampus = HippocampusManager.get_instance()._hippocampus - # topic_num = min(5, max(1, int(len(message) * 0.1))) - # topics_response = await hippocampus.llm_topic_judge.generate_response(hippocampus.find_topic_llm(message, topic_num)) - - # # 提取关键词 - # topics = re.findall(r"<([^>]+)>", topics_response[0]) - # if not topics: - # topics = [] - # else: - # topics = [ - # topic.strip() - # for topic in ",".join(topics).replace(",", ",").replace("、", ",").replace(" ", ",").split(",") - # if topic.strip() - # ] - - # logger.info(f"从LLM提取的主题: {', '.join(topics)}") - # except Exception as e: - # logger.error(f"从LLM提取主题失败: {str(e)}") - # # 如果LLM提取失败,使用jieba分词提取关键词作为备选 - # words = jieba.cut(message) - # topics = [word for word in words if len(word) > 1][:5] - # logger.info(f"使用jieba提取的主题: {', '.join(topics)}") - - # 如果无法提取到主题,直接使用整个消息 - if not topics: - logger.debug("未能提取到任何主题,使用整个消息进行查询") - embedding = await get_embedding(message, request_type="info_retrieval") - if not embedding: - logger.error("获取消息嵌入向量失败") - return "" - - related_info = self.get_info_from_db(embedding, limit=3, threshold=threshold) - logger.info(f"知识库检索完成,总耗时: {time.time() - start_time:.3f}秒") - return related_info, {} - - # 2. 对每个主题进行知识库查询 - logger.info(f"开始处理{len(topics)}个主题的知识库查询") - - # 优化:批量获取嵌入向量,减少API调用 - embeddings = {} - topics_batch = [topic for topic in topics if len(topic) > 0] - if message: # 确保消息非空 - topics_batch.append(message) - - # 批量获取嵌入向量 - embed_start_time = time.time() - for text in topics_batch: - if not text or len(text.strip()) == 0: - continue - - try: - embedding = await get_embedding(text, request_type="info_retrieval") - if embedding: - embeddings[text] = embedding - else: - logger.warning(f"获取'{text}'的嵌入向量失败") - except Exception as e: - logger.error(f"获取'{text}'的嵌入向量时发生错误: {str(e)}") - - logger.info(f"批量获取嵌入向量完成,耗时: {time.time() - embed_start_time:.3f}秒") - - if not embeddings: - logger.error("所有嵌入向量获取失败") - return "" - - # 3. 对每个主题进行知识库查询 - all_results = [] - query_start_time = time.time() - - # 首先添加原始消息的查询结果 - if message in embeddings: - original_results = self.get_info_from_db(embeddings[message], limit=3, threshold=threshold, return_raw=True) - if original_results: - for result in original_results: - result["topic"] = "原始消息" - all_results.extend(original_results) - logger.info(f"原始消息查询到{len(original_results)}条结果") - - # 然后添加每个主题的查询结果 - for topic in topics: - if not topic or topic not in embeddings: - continue - - try: - topic_results = self.get_info_from_db(embeddings[topic], limit=3, threshold=threshold, return_raw=True) - if topic_results: - # 添加主题标记 - for result in topic_results: - result["topic"] = topic - all_results.extend(topic_results) - logger.info(f"主题'{topic}'查询到{len(topic_results)}条结果") - except Exception as e: - logger.error(f"查询主题'{topic}'时发生错误: {str(e)}") - - logger.info(f"知识库查询完成,耗时: {time.time() - query_start_time:.3f}秒,共获取{len(all_results)}条结果") - - # 4. 去重和过滤 - process_start_time = time.time() - unique_contents = set() - filtered_results = [] - for result in all_results: - content = result["content"] - if content not in unique_contents: - unique_contents.add(content) - filtered_results.append(result) - - # 5. 按相似度排序 - filtered_results.sort(key=lambda x: x["similarity"], reverse=True) - - # 6. 限制总数量(最多10条) - filtered_results = filtered_results[:10] - logger.info( - f"结果处理完成,耗时: {time.time() - process_start_time:.3f}秒,过滤后剩余{len(filtered_results)}条结果" - ) - - # 7. 格式化输出 - if filtered_results: - format_start_time = time.time() - grouped_results = {} - for result in filtered_results: - topic = result["topic"] - if topic not in grouped_results: - grouped_results[topic] = [] - grouped_results[topic].append(result) - - # 按主题组织输出 - for topic, results in grouped_results.items(): - related_info += f"【主题: {topic}】\n" - for _i, result in enumerate(results, 1): - _similarity = result["similarity"] - content = result["content"].strip() - # 调试:为内容添加序号和相似度信息 - # related_info += f"{i}. [{similarity:.2f}] {content}\n" - related_info += f"{content}\n" - related_info += "\n" - - logger.info(f"格式化输出完成,耗时: {time.time() - format_start_time:.3f}秒") - - logger.info(f"知识库检索总耗时: {time.time() - start_time:.3f}秒") - return related_info, grouped_results - - def get_info_from_db( - self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False - ) -> Union[str, list]: - if not query_embedding: - return "" if not return_raw else [] - # 使用余弦相似度计算 - pipeline = [ - { - "$addFields": { - "dotProduct": { - "$reduce": { - "input": {"$range": [0, {"$size": "$embedding"}]}, - "initialValue": 0, - "in": { - "$add": [ - "$$value", - { - "$multiply": [ - {"$arrayElemAt": ["$embedding", "$$this"]}, - {"$arrayElemAt": [query_embedding, "$$this"]}, - ] - }, - ] - }, - } - }, - "magnitude1": { - "$sqrt": { - "$reduce": { - "input": "$embedding", - "initialValue": 0, - "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}, - } - } - }, - "magnitude2": { - "$sqrt": { - "$reduce": { - "input": query_embedding, - "initialValue": 0, - "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}, - } - } - }, - } - }, - {"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}}, - { - "$match": { - "similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果 - } - }, - {"$sort": {"similarity": -1}}, - {"$limit": limit}, - {"$project": {"content": 1, "similarity": 1}}, - ] - - results = list(db.knowledges.aggregate(pipeline)) - logger.debug(f"知识库查询结果数量: {len(results)}") - - if not results: - return "" if not return_raw else [] - - if return_raw: - return results - else: - # 返回所有找到的内容,用换行分隔 - return "\n".join(str(result["content"]) for result in results) \ No newline at end of file From 9aacbd55cbfd66d7807562040a6506dcb8b4948e Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Thu, 10 Apr 2025 23:44:34 +0800 Subject: [PATCH 18/24] Update sub_heartflow.py --- src/heart_flow/sub_heartflow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/heart_flow/sub_heartflow.py b/src/heart_flow/sub_heartflow.py index 80d7efb61..baa20c64f 100644 --- a/src/heart_flow/sub_heartflow.py +++ b/src/heart_flow/sub_heartflow.py @@ -189,7 +189,7 @@ class SubHeartflow: prompt += f"你注意到{sender_name}刚刚说:{message_txt}\n" prompt += "现在你接下去继续思考,产生新的想法,不要分点输出,输出连贯的内心独白" prompt += "思考时可以想想如何对群聊内容进行回复。回复的要求是:平淡一些,简短一些,说中文,尽量不要说你说过的话\n" - prompt += "请注意不要输出多余内容(包括前后缀,冒号和引号,括号, 表情,其他描述等)," + prompt += "请注意不要输出多余内容(包括前后缀,冒号和引号,括号, 表情,等),不要带有括号和动作描写" prompt += f"记得结合上述的消息,生成内心想法,文字不要浮夸,注意你就是{self.bot_name},{self.bot_name}指的就是你。" try: From 68a60f7e7124f8e231747e4e84f5cb871838df70 Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Fri, 11 Apr 2025 10:22:49 +0800 Subject: [PATCH 19/24] =?UTF-8?q?fix:=20PFC=E4=B8=8D=E8=AF=BB=E5=8F=96?= =?UTF-8?q?=E8=81=8A=E5=A4=A9=E8=AE=B0=E5=BD=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/plugins/PFC/action_planner.py | 14 ++- src/plugins/PFC/chat_observer.py | 110 ++++++------------ src/plugins/PFC/chat_states.py | 25 +++- src/plugins/PFC/conversation.py | 15 +-- src/plugins/PFC/notification_handler.py | 71 ----------- src/plugins/PFC/observation_info.py | 91 ++++++++------- src/plugins/PFC/pfc.py | 54 ++++++--- src/plugins/PFC/pfc_utils.py | 62 +++++++++- .../think_flow_chat/think_flow_generator.py | 4 +- 9 files changed, 222 insertions(+), 224 deletions(-) delete mode 100644 src/plugins/PFC/notification_handler.py diff --git a/src/plugins/PFC/action_planner.py b/src/plugins/PFC/action_planner.py index 43b0749a1..6eb67b7d4 100644 --- a/src/plugins/PFC/action_planner.py +++ b/src/plugins/PFC/action_planner.py @@ -45,7 +45,19 @@ class ActionPlanner: # 构建对话目标 if conversation_info.goal_list: - goal, reasoning = conversation_info.goal_list[-1] + last_goal = conversation_info.goal_list[-1] + print(last_goal) + # 处理字典或元组格式 + if isinstance(last_goal, tuple) and len(last_goal) == 2: + goal, reasoning = last_goal + elif isinstance(last_goal, dict) and 'goal' in last_goal and 'reasoning' in last_goal: + # 处理字典格式 + goal = last_goal.get('goal', "目前没有明确对话目标") + reasoning = last_goal.get('reasoning', "目前没有明确对话目标,最好思考一个对话目标") + else: + # 处理未知格式 + goal = "目前没有明确对话目标" + reasoning = "目前没有明确对话目标,最好思考一个对话目标" else: goal = "目前没有明确对话目标" reasoning = "目前没有明确对话目标,最好思考一个对话目标" diff --git a/src/plugins/PFC/chat_observer.py b/src/plugins/PFC/chat_observer.py index c96bc47b1..b9f704917 100644 --- a/src/plugins/PFC/chat_observer.py +++ b/src/plugins/PFC/chat_observer.py @@ -1,5 +1,6 @@ import time import asyncio +import traceback from typing import Optional, Dict, Any, List, Tuple from src.common.logger import get_module_logger from ..message.message_base import UserInfo @@ -44,18 +45,14 @@ class ChatObserver: self.stream_id = stream_id self.message_storage = message_storage or MongoDBMessageStorage() - self.last_user_speak_time: Optional[float] = None # 对方上次发言时间 - self.last_bot_speak_time: Optional[float] = None # 机器人上次发言时间 - self.last_check_time: float = time.time() # 上次查看聊天记录时间 + # self.last_user_speak_time: Optional[float] = None # 对方上次发言时间 + # self.last_bot_speak_time: Optional[float] = None # 机器人上次发言时间 + # self.last_check_time: float = time.time() # 上次查看聊天记录时间 self.last_message_read: Optional[str] = None # 最后读取的消息ID self.last_message_time: Optional[float] = None # 最后一条消息的时间戳 self.waiting_start_time: float = time.time() # 等待开始时间,初始化为当前时间 - # 消息历史记录 - self.message_history: List[Dict[str, Any]] = [] # 所有消息历史 - self.last_message_id: Optional[str] = None # 最后一条消息的ID - self.message_count: int = 0 # 消息计数 # 运行状态 self._running: bool = False @@ -72,7 +69,7 @@ class ChatObserver: self.is_cold_chat_state: bool = False self.update_event = asyncio.Event() - self.update_interval = 5 # 更新间隔(秒) + self.update_interval = 2 # 更新间隔(秒) self.message_cache = [] self.update_running = False @@ -98,21 +95,17 @@ class ChatObserver: Args: message: 消息数据 """ - self.message_history.append(message) - self.last_message_id = message["message_id"] - self.last_message_time = message["time"] # 更新最后消息时间 - self.message_count += 1 + try: - # 更新说话时间 - user_info = UserInfo.from_dict(message.get("user_info", {})) - if user_info.user_id == global_config.BOT_QQ: - self.last_bot_speak_time = message["time"] - else: - self.last_user_speak_time = message["time"] - - # 发送新消息通知 - notification = create_new_message_notification(sender="chat_observer", target="pfc", message=message) - await self.notification_manager.send_notification(notification) + # 发送新消息通知 + # logger.info(f"发送新ccchandleer消息通知: {message}") + notification = create_new_message_notification(sender="chat_observer", target="observation_info", message=message) + # logger.info(f"发送新消ddddd息通知: {notification}") + # print(self.notification_manager) + await self.notification_manager.send_notification(notification) + except Exception as e: + logger.error(f"添加消息到历史记录时出错: {e}") + print(traceback.format_exc()) # 检查并更新冷场状态 await self._check_cold_chat() @@ -156,9 +149,6 @@ class ChatObserver: Returns: bool: 是否有新消息 """ - if time_point is None: - logger.warning("time_point 为 None,返回 False") - return False if self.last_message_time is None: logger.debug("没有最后消息时间,返回 False") @@ -214,6 +204,8 @@ class ChatObserver: if new_messages: self.last_message_read = new_messages[-1]["message_id"] + + print(f"获取111111111122222222新消息: {new_messages}") return new_messages @@ -230,6 +222,8 @@ class ChatObserver: if new_messages: self.last_message_read = new_messages[-1]["message_id"] + + logger.debug(f"获取指定时间点111之前的消息: {new_messages}") return new_messages @@ -237,20 +231,24 @@ class ChatObserver: async def _update_loop(self): """更新循环""" - try: - start_time = time.time() - messages = await self._fetch_new_messages_before(start_time) - for message in messages: - await self._add_message_to_history(message) - except Exception as e: - logger.error(f"缓冲消息出错: {e}") + # try: + # start_time = time.time() + # messages = await self._fetch_new_messages_before(start_time) + # for message in messages: + # await self._add_message_to_history(message) + # logger.debug(f"缓冲消息: {messages}") + # except Exception as e: + # logger.error(f"缓冲消息出错: {e}") while self._running: try: # 等待事件或超时(1秒) try: + # print("等待事件") await asyncio.wait_for(self._update_event.wait(), timeout=1) + except asyncio.TimeoutError: + # print("超时") pass # 超时后也执行一次检查 self._update_event.clear() # 重置触发事件 @@ -355,51 +353,6 @@ class ChatObserver: return time_info - def start_periodic_update(self): - """启动观察器的定期更新""" - if not self.update_running: - self.update_running = True - asyncio.create_task(self._periodic_update()) - - async def _periodic_update(self): - """定期更新消息历史""" - try: - while self.update_running: - await self._update_message_history() - await asyncio.sleep(self.update_interval) - except Exception as e: - logger.error(f"定期更新消息历史时出错: {str(e)}") - - async def _update_message_history(self) -> bool: - """更新消息历史 - - Returns: - bool: 是否有新消息 - """ - try: - messages = await self.message_storage.get_messages_for_stream(self.stream_id, limit=50) - - if not messages: - return False - - # 检查是否有新消息 - has_new_messages = False - if messages and ( - not self.message_cache or messages[0]["message_id"] != self.message_cache[0]["message_id"] - ): - has_new_messages = True - - self.message_cache = messages - - if has_new_messages: - self.update_event.set() - self.update_event.clear() - return True - return False - - except Exception as e: - logger.error(f"更新消息历史时出错: {str(e)}") - return False def get_cached_messages(self, limit: int = 50) -> List[Dict[str, Any]]: """获取缓存的消息历史 @@ -421,3 +374,6 @@ class ChatObserver: if not self.message_cache: return None return self.message_cache[0] + + def __str__(self): + return f"ChatObserver for {self.stream_id}" diff --git a/src/plugins/PFC/chat_states.py b/src/plugins/PFC/chat_states.py index b28ca69a6..373dfdb74 100644 --- a/src/plugins/PFC/chat_states.py +++ b/src/plugins/PFC/chat_states.py @@ -98,11 +98,17 @@ class NotificationManager: notification_type: 要处理的通知类型 handler: 处理器实例 """ + print(1145145511114445551111444) if target not in self._handlers: + print("没11有target") self._handlers[target] = {} if notification_type not in self._handlers[target]: + print("没11有notification_type") self._handlers[target][notification_type] = [] + print(self._handlers[target][notification_type]) + print(f"注册1111111111111111111111处理器: {target} {notification_type} {handler}") self._handlers[target][notification_type].append(handler) + print(self._handlers[target][notification_type]) def unregister_handler(self, target: str, notification_type: NotificationType, handler: NotificationHandler): """注销通知处理器 @@ -126,6 +132,7 @@ class NotificationManager: async def send_notification(self, notification: Notification): """发送通知""" self._notification_history.append(notification) + # print("kaishichul-----------------------------------i") # 如果是状态通知,更新活跃状态 if isinstance(notification, StateNotification): @@ -133,12 +140,16 @@ class NotificationManager: self._active_states.add(notification.type) else: self._active_states.discard(notification.type) + # 调用目标接收者的处理器 target = notification.target if target in self._handlers: handlers = self._handlers[target].get(notification.type, []) + # print(1111111) + print(handlers) for handler in handlers: + print(f"调用处理器: {handler}") await handler.handle_notification(notification) def get_active_states(self) -> Set[NotificationType]: @@ -170,6 +181,13 @@ class NotificationManager: history = history[-limit:] return history + + def __str__(self): + str = "" + for target, handlers in self._handlers.items(): + for notification_type, handler_list in handlers.items(): + str += f"NotificationManager for {target} {notification_type} {handler_list}" + return str # 一些常用的通知创建函数 @@ -182,8 +200,9 @@ def create_new_message_notification(sender: str, target: str, message: Dict[str, target=target, data={ "message_id": message.get("message_id"), - "content": message.get("content"), - "sender": message.get("sender"), + "processed_plain_text": message.get("processed_plain_text"), + "detailed_plain_text": message.get("detailed_plain_text"), + "user_info": message.get("user_info"), "time": message.get("time"), }, ) @@ -276,3 +295,5 @@ class ChatStateManager: current_time = datetime.now().timestamp() return (current_time - self.state_info.last_message_time) <= threshold + + diff --git a/src/plugins/PFC/conversation.py b/src/plugins/PFC/conversation.py index 40a729671..a5da3e48d 100644 --- a/src/plugins/PFC/conversation.py +++ b/src/plugins/PFC/conversation.py @@ -60,9 +60,10 @@ class Conversation: self.chat_observer = ChatObserver.get_instance(self.stream_id) self.chat_observer.start() self.observation_info = ObservationInfo() - self.observation_info.bind_to_chat_observer(self.stream_id) + self.observation_info.bind_to_chat_observer(self.chat_observer) + # print(self.chat_observer.get_cached_messages(limit=) - # 对话信息 + self.conversation_info = ConversationInfo() except Exception as e: logger.error(f"初始化对话实例:注册信息组件失败: {e}") @@ -140,6 +141,7 @@ class Conversation: if action == "direct_reply": self.state = ConversationState.GENERATING self.generated_reply = await self.reply_generator.generate(observation_info, conversation_info) + print(f"生成回复: {self.generated_reply}") # # 检查回复是否合适 # is_suitable, reason, need_replan = await self.reply_generator.check_reply( @@ -148,6 +150,7 @@ class Conversation: # ) if self._check_new_messages_after_planning(): + logger.info("333333发现新消息,重新考虑行动") return None await self._send_reply() @@ -212,15 +215,9 @@ class Conversation: logger.warning("没有生成回复") return - messages = self.chat_observer.get_cached_messages(limit=1) - if not messages: - logger.warning("没有最近的消息可以回复") - return - - latest_message = self._convert_to_message(messages[0]) try: await self.direct_sender.send_message( - chat_stream=self.chat_stream, content=self.generated_reply, reply_to_message=latest_message + chat_stream=self.chat_stream, content=self.generated_reply ) self.chat_observer.trigger_update() # 触发立即更新 if not await self.chat_observer.wait_for_update(): diff --git a/src/plugins/PFC/notification_handler.py b/src/plugins/PFC/notification_handler.py deleted file mode 100644 index 1131d18bf..000000000 --- a/src/plugins/PFC/notification_handler.py +++ /dev/null @@ -1,71 +0,0 @@ -from typing import TYPE_CHECKING -from src.common.logger import get_module_logger -from .chat_states import NotificationHandler, Notification, NotificationType - -if TYPE_CHECKING: - from .conversation import Conversation - -logger = get_module_logger("notification_handler") - - -class PFCNotificationHandler(NotificationHandler): - """PFC通知处理器""" - - def __init__(self, conversation: "Conversation"): - """初始化PFC通知处理器 - - Args: - conversation: 对话实例 - """ - self.conversation = conversation - - async def handle_notification(self, notification: Notification): - """处理通知 - - Args: - notification: 通知对象 - """ - logger.debug(f"收到通知: {notification.type.name}, 数据: {notification.data}") - - # 根据通知类型执行不同的处理 - if notification.type == NotificationType.NEW_MESSAGE: - # 新消息通知 - await self._handle_new_message(notification) - elif notification.type == NotificationType.COLD_CHAT: - # 冷聊天通知 - await self._handle_cold_chat(notification) - elif notification.type == NotificationType.COMMAND: - # 命令通知 - await self._handle_command(notification) - else: - logger.warning(f"未知的通知类型: {notification.type.name}") - - async def _handle_new_message(self, notification: Notification): - """处理新消息通知 - - Args: - notification: 通知对象 - """ - - # 更新决策信息 - observation_info = self.conversation.observation_info - observation_info.last_message_time = notification.data.get("time", 0) - observation_info.add_unprocessed_message(notification.data) - - # 手动触发观察器更新 - self.conversation.chat_observer.trigger_update() - - async def _handle_cold_chat(self, notification: Notification): - """处理冷聊天通知 - - Args: - notification: 通知对象 - """ - # 获取冷聊天信息 - cold_duration = notification.data.get("duration", 0) - - # 更新决策信息 - observation_info = self.conversation.observation_info - observation_info.conversation_cold_duration = cold_duration - - logger.info(f"对话已冷: {cold_duration}秒") diff --git a/src/plugins/PFC/observation_info.py b/src/plugins/PFC/observation_info.py index d0eee2236..947c3205d 100644 --- a/src/plugins/PFC/observation_info.py +++ b/src/plugins/PFC/observation_info.py @@ -6,7 +6,7 @@ import time from dataclasses import dataclass, field from src.common.logger import get_module_logger from .chat_observer import ChatObserver -from .chat_states import NotificationHandler +from .chat_states import NotificationHandler, NotificationType logger = get_module_logger("observation_info") @@ -22,63 +22,70 @@ class ObservationInfoHandler(NotificationHandler): """ self.observation_info = observation_info - async def handle_notification(self, notification: Dict[str, Any]): - """处理通知 - - Args: - notification: 通知数据 - """ - notification_type = notification.get("type") - data = notification.get("data", {}) - - if notification_type == "NEW_MESSAGE": + async def handle_notification(self, notification): + # 获取通知类型和数据 + notification_type = notification.type + data = notification.data + + if notification_type == NotificationType.NEW_MESSAGE: # 处理新消息通知 logger.debug(f"收到新消息通知data: {data}") - message = data.get("message", {}) + message_id = data.get("message_id") + processed_plain_text = data.get("processed_plain_text") + detailed_plain_text = data.get("detailed_plain_text") + user_info = data.get("user_info") + time_value = data.get("time") + + message = { + "message_id": message_id, + "processed_plain_text": processed_plain_text, + "detailed_plain_text": detailed_plain_text, + "user_info": user_info, + "time": time_value + } + self.observation_info.update_from_message(message) - # self.observation_info.has_unread_messages = True - # self.observation_info.new_unread_message.append(message.get("processed_plain_text", "")) - elif notification_type == "COLD_CHAT": + elif notification_type == NotificationType.COLD_CHAT: # 处理冷场通知 is_cold = data.get("is_cold", False) self.observation_info.update_cold_chat_status(is_cold, time.time()) - elif notification_type == "ACTIVE_CHAT": + elif notification_type == NotificationType.ACTIVE_CHAT: # 处理活跃通知 is_active = data.get("is_active", False) self.observation_info.is_cold = not is_active - elif notification_type == "BOT_SPEAKING": + elif notification_type == NotificationType.BOT_SPEAKING: # 处理机器人说话通知 self.observation_info.is_typing = False self.observation_info.last_bot_speak_time = time.time() - elif notification_type == "USER_SPEAKING": + elif notification_type == NotificationType.USER_SPEAKING: # 处理用户说话通知 self.observation_info.is_typing = False self.observation_info.last_user_speak_time = time.time() - elif notification_type == "MESSAGE_DELETED": + elif notification_type == NotificationType.MESSAGE_DELETED: # 处理消息删除通知 message_id = data.get("message_id") self.observation_info.unprocessed_messages = [ msg for msg in self.observation_info.unprocessed_messages if msg.get("message_id") != message_id ] - elif notification_type == "USER_JOINED": + elif notification_type == NotificationType.USER_JOINED: # 处理用户加入通知 user_id = data.get("user_id") if user_id: self.observation_info.active_users.add(user_id) - elif notification_type == "USER_LEFT": + elif notification_type == NotificationType.USER_LEFT: # 处理用户离开通知 user_id = data.get("user_id") if user_id: self.observation_info.active_users.discard(user_id) - elif notification_type == "ERROR": + elif notification_type == NotificationType.ERROR: # 处理错误通知 error_msg = data.get("error", "") logger.error(f"收到错误通知: {error_msg}") @@ -100,6 +107,7 @@ class ObservationInfo: last_message_content: str = "" last_message_sender: Optional[str] = None bot_id: Optional[str] = None + chat_history_count: int = 0 new_messages_count: int = 0 cold_chat_duration: float = 0.0 @@ -117,28 +125,37 @@ class ObservationInfo: self.chat_observer = None self.handler = ObservationInfoHandler(self) - def bind_to_chat_observer(self, stream_id: str): + def bind_to_chat_observer(self, chat_observer: ChatObserver): """绑定到指定的chat_observer Args: stream_id: 聊天流ID """ - self.chat_observer = ChatObserver.get_instance(stream_id) + self.chat_observer = chat_observer + print(f"1919810----------------------绑定-----------------------------") + print(self.chat_observer) + print(f"1919810--------------------绑定-----------------------------") + print(self.chat_observer.notification_manager) + print(f"1919810-------------------绑定-----------------------------") self.chat_observer.notification_manager.register_handler( - target="observation_info", notification_type="NEW_MESSAGE", handler=self.handler + target="observation_info", notification_type=NotificationType.NEW_MESSAGE, handler=self.handler ) self.chat_observer.notification_manager.register_handler( - target="observation_info", notification_type="COLD_CHAT", handler=self.handler + target="observation_info", notification_type=NotificationType.COLD_CHAT, handler=self.handler ) + print("1919810------------------------绑定-----------------------------") + print(f"1919810--------------------绑定-----------------------------") + print(self.chat_observer.notification_manager) + print(f"1919810-------------------绑定-----------------------------") def unbind_from_chat_observer(self): """解除与chat_observer的绑定""" if self.chat_observer: self.chat_observer.notification_manager.unregister_handler( - target="observation_info", notification_type="NEW_MESSAGE", handler=self.handler + target="observation_info", notification_type=NotificationType.NEW_MESSAGE, handler=self.handler ) self.chat_observer.notification_manager.unregister_handler( - target="observation_info", notification_type="COLD_CHAT", handler=self.handler + target="observation_info", notification_type=NotificationType.COLD_CHAT, handler=self.handler ) self.chat_observer = None @@ -148,8 +165,11 @@ class ObservationInfo: Args: message: 消息数据 """ + print("1919810-----------------------------------------------------") logger.debug(f"更新信息from_message: {message}") self.last_message_time = message["time"] + self.last_message_id = message["message_id"] + self.last_message_content = message.get("processed_plain_text", "") user_info = UserInfo.from_dict(message.get("user_info", {})) @@ -169,7 +189,6 @@ class ObservationInfo: def update_changed(self): """更新changed状态""" self.changed = True - # self.meta_plan_trigger = True def update_cold_chat_status(self, is_cold: bool, current_time: float): """更新冷场状态 @@ -223,17 +242,3 @@ class ObservationInfo: self.unprocessed_messages.clear() self.new_messages_count = 0 - def add_unprocessed_message(self, message: Dict[str, Any]): - """添加未处理的消息 - - Args: - message: 消息数据 - """ - # 防止重复添加同一消息 - message_id = message.get("message_id") - if message_id and not any(m.get("message_id") == message_id for m in self.unprocessed_messages): - self.unprocessed_messages.append(message) - self.new_messages_count += 1 - - # 同时更新其他消息相关信息 - self.update_from_message(message) diff --git a/src/plugins/PFC/pfc.py b/src/plugins/PFC/pfc.py index c88ed47d5..0a20812b9 100644 --- a/src/plugins/PFC/pfc.py +++ b/src/plugins/PFC/pfc.py @@ -99,19 +99,21 @@ class GoalAnalyzer: 3. 添加新目标 4. 删除不再相关的目标 -请以JSON格式输出当前的所有对话目标,包含以下字段: +请以JSON数组格式输出当前的所有对话目标,每个目标包含以下字段: 1. goal: 对话目标(简短的一句话) 2. reasoning: 对话原因,为什么设定这个目标(简要解释) 输出格式示例: -{{ -"goal": "回答用户关于Python编程的具体问题", -"reasoning": "用户提出了关于Python的技术问题,需要专业且准确的解答" -}}, -{{ -"goal": "回答用户关于python安装的具体问题", -"reasoning": "用户提出了关于Python的技术问题,需要专业且准确的解答" -}}""" +[ + {{ + "goal": "回答用户关于Python编程的具体问题", + "reasoning": "用户提出了关于Python的技术问题,需要专业且准确的解答" + }}, + {{ + "goal": "回答用户关于python安装的具体问题", + "reasoning": "用户提出了关于Python的技术问题,需要专业且准确的解答" + }} +]""" logger.debug(f"发送到LLM的提示词: {prompt}") try: @@ -120,13 +122,37 @@ class GoalAnalyzer: except Exception as e: logger.error(f"分析对话目标时出错: {str(e)}") content = "" - # 使用简化函数提取JSON内容 + + # 使用改进后的get_items_from_json函数处理JSON数组 success, result = get_items_from_json( - content, "goal", "reasoning", required_types={"goal": str, "reasoning": str} + content, "goal", "reasoning", + required_types={"goal": str, "reasoning": str}, + allow_array=True ) - # TODO - - conversation_info.goal_list.append(result) + + if success: + # 判断结果是单个字典还是字典列表 + if isinstance(result, list): + # 清空现有目标列表并添加新目标 + conversation_info.goal_list = [] + for item in result: + goal = item.get("goal", "") + reasoning = item.get("reasoning", "") + conversation_info.goal_list.append((goal, reasoning)) + + # 返回第一个目标作为当前主要目标(如果有) + if result: + first_goal = result[0] + return (first_goal.get("goal", ""), "", first_goal.get("reasoning", "")) + else: + # 单个目标的情况 + goal = result.get("goal", "") + reasoning = result.get("reasoning", "") + conversation_info.goal_list.append((goal, reasoning)) + return (goal, "", reasoning) + + # 如果解析失败,返回默认值 + return ("", "", "") async def _update_goals(self, new_goal: str, method: str, reasoning: str): """更新目标列表 diff --git a/src/plugins/PFC/pfc_utils.py b/src/plugins/PFC/pfc_utils.py index 633d9016e..f99b32a3d 100644 --- a/src/plugins/PFC/pfc_utils.py +++ b/src/plugins/PFC/pfc_utils.py @@ -1,6 +1,6 @@ import json import re -from typing import Dict, Any, Optional, Tuple +from typing import Dict, Any, Optional, Tuple, List, Union from src.common.logger import get_module_logger logger = get_module_logger("pfc_utils") @@ -11,7 +11,8 @@ def get_items_from_json( *items: str, default_values: Optional[Dict[str, Any]] = None, required_types: Optional[Dict[str, type]] = None, -) -> Tuple[bool, Dict[str, Any]]: + allow_array: bool = True, +) -> Tuple[bool, Union[Dict[str, Any], List[Dict[str, Any]]]]: """从文本中提取JSON内容并获取指定字段 Args: @@ -19,18 +20,69 @@ def get_items_from_json( *items: 要提取的字段名 default_values: 字段的默认值,格式为 {字段名: 默认值} required_types: 字段的必需类型,格式为 {字段名: 类型} + allow_array: 是否允许解析JSON数组 Returns: - Tuple[bool, Dict[str, Any]]: (是否成功, 提取的字段字典) + Tuple[bool, Union[Dict[str, Any], List[Dict[str, Any]]]]: (是否成功, 提取的字段字典或字典列表) """ content = content.strip() result = {} - + # 设置默认值 if default_values: result.update(default_values) - # 尝试解析JSON + # 首先尝试解析为JSON数组 + if allow_array: + try: + # 尝试找到文本中的JSON数组 + array_pattern = r"\[[\s\S]*\]" + array_match = re.search(array_pattern, content) + if array_match: + array_content = array_match.group() + json_array = json.loads(array_content) + + # 确认是数组类型 + if isinstance(json_array, list): + # 验证数组中的每个项目是否包含所有必需字段 + valid_items = [] + for item in json_array: + if not isinstance(item, dict): + continue + + # 检查是否有所有必需字段 + if all(field in item for field in items): + # 验证字段类型 + if required_types: + type_valid = True + for field, expected_type in required_types.items(): + if field in item and not isinstance(item[field], expected_type): + type_valid = False + break + + if not type_valid: + continue + + # 验证字符串字段不为空 + string_valid = True + for field in items: + if isinstance(item[field], str) and not item[field].strip(): + string_valid = False + break + + if not string_valid: + continue + + valid_items.append(item) + + if valid_items: + return True, valid_items + except json.JSONDecodeError: + logger.debug("JSON数组解析失败,尝试解析单个JSON对象") + except Exception as e: + logger.debug(f"尝试解析JSON数组时出错: {str(e)}") + + # 尝试解析JSON对象 try: json_data = json.loads(content) except json.JSONDecodeError: diff --git a/src/plugins/chat_module/think_flow_chat/think_flow_generator.py b/src/plugins/chat_module/think_flow_chat/think_flow_generator.py index f422b8c99..9541eed17 100644 --- a/src/plugins/chat_module/think_flow_chat/think_flow_generator.py +++ b/src/plugins/chat_module/think_flow_chat/think_flow_generator.py @@ -26,11 +26,11 @@ logger = get_module_logger("llm_generator", config=llm_config) class ResponseGenerator: def __init__(self): self.model_normal = LLM_request( - model=global_config.llm_normal, temperature=0.3, max_tokens=256, request_type="response_heartflow" + model=global_config.llm_normal, temperature=0.15, max_tokens=256, request_type="response_heartflow" ) self.model_sum = LLM_request( - model=global_config.llm_summary_by_topic, temperature=0.7, max_tokens=2000, request_type="relation" + model=global_config.llm_summary_by_topic, temperature=0.6, max_tokens=2000, request_type="relation" ) self.current_model_type = "r1" # 默认使用 R1 self.current_model_name = "unknown model" From 157f2bc0447ce956417f69699f58ecc324e24c95 Mon Sep 17 00:00:00 2001 From: HexatomicRing <54496918+HexatomicRing@users.noreply.github.com> Date: Fri, 11 Apr 2025 10:35:34 +0800 Subject: [PATCH 20/24] =?UTF-8?q?=E9=98=B2=E6=AD=A2=E5=85=B3=E9=94=AE?= =?UTF-8?q?=E8=AF=8D=E5=92=8Cregex=E9=87=8D=E5=A4=8D=E5=8C=B9=E9=85=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../reasoning_prompt_builder.py | 23 ++++++++++--------- .../think_flow_prompt_builder.py | 23 ++++++++++--------- 2 files changed, 24 insertions(+), 22 deletions(-) diff --git a/src/plugins/chat_module/reasoning_chat/reasoning_prompt_builder.py b/src/plugins/chat_module/reasoning_chat/reasoning_prompt_builder.py index 045045bae..75a876a9c 100644 --- a/src/plugins/chat_module/reasoning_chat/reasoning_prompt_builder.py +++ b/src/plugins/chat_module/reasoning_chat/reasoning_prompt_builder.py @@ -115,17 +115,18 @@ class PromptBuilder: f"检测到以下关键词之一:{rule.get('keywords', [])},触发反应:{rule.get('reaction', '')}" ) keywords_reaction_prompt += rule.get("reaction", "") + "," - for pattern in rule.get("regex", []): - result = pattern.search(message_txt) - if result: - reaction = rule.get('reaction', '') - for name, content in result.groupdict().items(): - reaction = reaction.replace(f'[{name}]', content) - logger.info( - f"匹配到以下正则表达式:{pattern},触发反应:{reaction}" - ) - keywords_reaction_prompt += reaction + "," - break + else: + for pattern in rule.get("regex", []): + result = pattern.search(message_txt) + if result: + reaction = rule.get('reaction', '') + for name, content in result.groupdict().items(): + reaction = reaction.replace(f'[{name}]', content) + logger.info( + f"匹配到以下正则表达式:{pattern},触发反应:{reaction}" + ) + keywords_reaction_prompt += reaction + "," + break # 中文高手(新加的好玩功能) prompt_ger = "" diff --git a/src/plugins/chat_module/think_flow_chat/think_flow_prompt_builder.py b/src/plugins/chat_module/think_flow_chat/think_flow_prompt_builder.py index 8938e2e78..7ae7940bb 100644 --- a/src/plugins/chat_module/think_flow_chat/think_flow_prompt_builder.py +++ b/src/plugins/chat_module/think_flow_chat/think_flow_prompt_builder.py @@ -61,17 +61,18 @@ class PromptBuilder: f"检测到以下关键词之一:{rule.get('keywords', [])},触发反应:{rule.get('reaction', '')}" ) keywords_reaction_prompt += rule.get("reaction", "") + "," - for pattern in rule.get("regex", []): - result = pattern.search(message_txt) - if result: - reaction = rule.get('reaction', '') - for name, content in result.groupdict().items(): - reaction = reaction.replace(f'[{name}]', content) - logger.info( - f"匹配到以下正则表达式:{pattern},触发反应:{reaction}" - ) - keywords_reaction_prompt += reaction + "," - break + else: + for pattern in rule.get("regex", []): + result = pattern.search(message_txt) + if result: + reaction = rule.get('reaction', '') + for name, content in result.groupdict().items(): + reaction = reaction.replace(f'[{name}]', content) + logger.info( + f"匹配到以下正则表达式:{pattern},触发反应:{reaction}" + ) + keywords_reaction_prompt += reaction + "," + break # 中文高手(新加的好玩功能) prompt_ger = "" From 138fc11752d0bc5d4d6a7159e34c4677e5bb45a9 Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Fri, 11 Apr 2025 10:48:14 +0800 Subject: [PATCH 21/24] =?UTF-8?q?fix:=E5=AD=A9=E5=AD=90=E4=BB=AC=EF=BC=8CP?= =?UTF-8?q?FC=E7=BB=88=E4=BA=8E=E5=A4=8D=E6=B4=BB=E4=BA=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/plugins/PFC/action_planner.py | 4 ++-- src/plugins/PFC/chat_observer.py | 28 +++++++++++----------------- src/plugins/PFC/message_storage.py | 15 ++++++--------- src/plugins/PFC/observation_info.py | 12 ++---------- 4 files changed, 21 insertions(+), 38 deletions(-) diff --git a/src/plugins/PFC/action_planner.py b/src/plugins/PFC/action_planner.py index 6eb67b7d4..372474ac0 100644 --- a/src/plugins/PFC/action_planner.py +++ b/src/plugins/PFC/action_planner.py @@ -66,14 +66,14 @@ class ActionPlanner: chat_history_list = observation_info.chat_history chat_history_text = "" for msg in chat_history_list: - chat_history_text += f"{msg}\n" + chat_history_text += f"{msg.get('detailed_plain_text', '')}\n" if observation_info.new_messages_count > 0: new_messages_list = observation_info.unprocessed_messages chat_history_text += f"有{observation_info.new_messages_count}条新消息:\n" for msg in new_messages_list: - chat_history_text += f"{msg}\n" + chat_history_text += f"{msg.get('detailed_plain_text', '')}\n" observation_info.clear_unprocessed_messages() diff --git a/src/plugins/PFC/chat_observer.py b/src/plugins/PFC/chat_observer.py index b9f704917..a766e3b4a 100644 --- a/src/plugins/PFC/chat_observer.py +++ b/src/plugins/PFC/chat_observer.py @@ -18,38 +18,36 @@ class ChatObserver: _instances: Dict[str, "ChatObserver"] = {} @classmethod - def get_instance(cls, stream_id: str, message_storage: Optional[MessageStorage] = None) -> "ChatObserver": + def get_instance(cls, stream_id: str) -> "ChatObserver": """获取或创建观察器实例 Args: stream_id: 聊天流ID - message_storage: 消息存储实现,如果为None则使用MongoDB实现 Returns: ChatObserver: 观察器实例 """ if stream_id not in cls._instances: - cls._instances[stream_id] = cls(stream_id, message_storage) + cls._instances[stream_id] = cls(stream_id) return cls._instances[stream_id] - def __init__(self, stream_id: str, message_storage: Optional[MessageStorage] = None): + def __init__(self, stream_id: str): """初始化观察器 Args: stream_id: 聊天流ID - message_storage: 消息存储实现,如果为None则使用MongoDB实现 """ if stream_id in self._instances: raise RuntimeError(f"ChatObserver for {stream_id} already exists. Use get_instance() instead.") self.stream_id = stream_id - self.message_storage = message_storage or MongoDBMessageStorage() + self.message_storage = MongoDBMessageStorage() # self.last_user_speak_time: Optional[float] = None # 对方上次发言时间 # self.last_bot_speak_time: Optional[float] = None # 机器人上次发言时间 # self.last_check_time: float = time.time() # 上次查看聊天记录时间 - self.last_message_read: Optional[str] = None # 最后读取的消息ID - self.last_message_time: Optional[float] = None # 最后一条消息的时间戳 + self.last_message_read: Optional[Dict[str, Any]] = None # 最后读取的消息ID + self.last_message_time: float = time.time() self.waiting_start_time: float = time.time() # 等待开始时间,初始化为当前时间 @@ -133,12 +131,6 @@ class ChatObserver: notification = create_cold_chat_notification(sender="chat_observer", target="pfc", is_cold=is_cold) await self.notification_manager.send_notification(notification) - async def get_new_message(self) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: - """获取上一次观察的时间点后的新消息,插入到历史记录中,并返回新消息和历史记录两个对象""" - messages = await self.message_storage.get_messages_after(self.stream_id, self.last_message_read) - for message in messages: - await self._add_message_to_history(message) - return messages, self.message_history def new_message_after(self, time_point: float) -> bool: """判断是否在指定时间点后有新消息 @@ -200,12 +192,13 @@ class ChatObserver: Returns: List[Dict[str, Any]]: 新消息列表 """ - new_messages = await self.message_storage.get_messages_after(self.stream_id, self.last_message_read) + new_messages = await self.message_storage.get_messages_after(self.stream_id, self.last_message_time) if new_messages: - self.last_message_read = new_messages[-1]["message_id"] + self.last_message_read = new_messages[-1] + self.last_message_time = new_messages[-1]["time"] - print(f"获取111111111122222222新消息: {new_messages}") + print(f"获取数据库中找到的新消息: {new_messages}") return new_messages @@ -267,6 +260,7 @@ class ChatObserver: except Exception as e: logger.error(f"更新循环出错: {e}") + logger.error(traceback.format_exc()) self._update_complete.set() # 即使出错也要设置完成事件 def trigger_update(self): diff --git a/src/plugins/PFC/message_storage.py b/src/plugins/PFC/message_storage.py index 88f409641..75bab6edd 100644 --- a/src/plugins/PFC/message_storage.py +++ b/src/plugins/PFC/message_storage.py @@ -1,18 +1,18 @@ from abc import ABC, abstractmethod from typing import List, Dict, Any, Optional from src.common.database import db - +import time class MessageStorage(ABC): """消息存储接口""" @abstractmethod - async def get_messages_after(self, chat_id: str, message_id: Optional[str] = None) -> List[Dict[str, Any]]: + async def get_messages_after(self, chat_id: str, message: Dict[str, Any]) -> List[Dict[str, Any]]: """获取指定消息ID之后的所有消息 Args: chat_id: 聊天ID - message_id: 消息ID,如果为None则获取所有消息 + message: 消息 Returns: List[Dict[str, Any]]: 消息列表 @@ -53,14 +53,11 @@ class MongoDBMessageStorage(MessageStorage): def __init__(self): self.db = db - async def get_messages_after(self, chat_id: str, message_id: Optional[str] = None) -> List[Dict[str, Any]]: + async def get_messages_after(self, chat_id: str, message_time: float) -> List[Dict[str, Any]]: query = {"chat_id": chat_id} + print(f"storage_check_message: {message_time}") - if message_id: - # 获取ID大于message_id的消息 - last_message = self.db.messages.find_one({"message_id": message_id}) - if last_message: - query["time"] = {"$gt": last_message["time"]} + query["time"] = {"$gt": message_time} return list(self.db.messages.find(query).sort("time", 1)) diff --git a/src/plugins/PFC/observation_info.py b/src/plugins/PFC/observation_info.py index 947c3205d..01f619dc3 100644 --- a/src/plugins/PFC/observation_info.py +++ b/src/plugins/PFC/observation_info.py @@ -132,11 +132,6 @@ class ObservationInfo: stream_id: 聊天流ID """ self.chat_observer = chat_observer - print(f"1919810----------------------绑定-----------------------------") - print(self.chat_observer) - print(f"1919810--------------------绑定-----------------------------") - print(self.chat_observer.notification_manager) - print(f"1919810-------------------绑定-----------------------------") self.chat_observer.notification_manager.register_handler( target="observation_info", notification_type=NotificationType.NEW_MESSAGE, handler=self.handler ) @@ -144,9 +139,6 @@ class ObservationInfo: target="observation_info", notification_type=NotificationType.COLD_CHAT, handler=self.handler ) print("1919810------------------------绑定-----------------------------") - print(f"1919810--------------------绑定-----------------------------") - print(self.chat_observer.notification_manager) - print(f"1919810-------------------绑定-----------------------------") def unbind_from_chat_observer(self): """解除与chat_observer的绑定""" @@ -235,10 +227,10 @@ class ObservationInfo: """清空未处理消息列表""" # 将未处理消息添加到历史记录中 for message in self.unprocessed_messages: - if "processed_plain_text" in message: - self.chat_history.append(message["processed_plain_text"]) + self.chat_history.append(message) # 清空未处理消息列表 self.has_unread_messages = False self.unprocessed_messages.clear() + self.chat_history_count = len(self.chat_history) self.new_messages_count = 0 From 27c10ff29d91a518226c3538febd13453fa744be Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Fri, 11 Apr 2025 10:55:45 +0800 Subject: [PATCH 22/24] fix: Ruff --- src/do_tool/tool_can_use/__init__.py | 11 ++++++++++- src/do_tool/tool_can_use/base_tool.py | 12 ++++-------- src/do_tool/tool_can_use/get_knowledge.py | 2 +- src/do_tool/tool_use.py | 1 - src/heart_flow/observation.py | 2 -- src/heart_flow/sub_heartflow.py | 11 +++++------ .../think_flow_chat/think_flow_prompt_builder.py | 4 ++-- 7 files changed, 22 insertions(+), 21 deletions(-) diff --git a/src/do_tool/tool_can_use/__init__.py b/src/do_tool/tool_can_use/__init__.py index cc196d07a..3189d2897 100644 --- a/src/do_tool/tool_can_use/__init__.py +++ b/src/do_tool/tool_can_use/__init__.py @@ -7,5 +7,14 @@ from src.do_tool.tool_can_use.base_tool import ( TOOL_REGISTRY ) +__all__ = [ + 'BaseTool', + 'register_tool', + 'discover_tools', + 'get_all_tool_definitions', + 'get_tool_instance', + 'TOOL_REGISTRY' +] + # 自动发现并注册工具 -discover_tools() \ No newline at end of file +discover_tools() \ No newline at end of file diff --git a/src/do_tool/tool_can_use/base_tool.py b/src/do_tool/tool_can_use/base_tool.py index 03aac5e4c..c8c80ebe8 100644 --- a/src/do_tool/tool_can_use/base_tool.py +++ b/src/do_tool/tool_can_use/base_tool.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Any, Optional, Union, Type +from typing import Dict, List, Any, Optional, Type import inspect import importlib import pkgutil @@ -73,13 +73,9 @@ def discover_tools(): # 获取当前目录路径 current_dir = os.path.dirname(os.path.abspath(__file__)) package_name = os.path.basename(current_dir) - parent_dir = os.path.dirname(current_dir) - - # 导入当前包 - package = importlib.import_module(f"src.do_tool.{package_name}") # 遍历包中的所有模块 - for _, module_name, is_pkg in pkgutil.iter_modules([current_dir]): + for _, module_name, _ in pkgutil.iter_modules([current_dir]): # 跳过当前模块和__pycache__ if module_name == "base_tool" or module_name.startswith("__"): continue @@ -88,7 +84,7 @@ def discover_tools(): module = importlib.import_module(f"src.do_tool.{package_name}.{module_name}") # 查找模块中的工具类 - for name, obj in inspect.getmembers(module): + for _, obj in inspect.getmembers(module): if inspect.isclass(obj) and issubclass(obj, BaseTool) and obj != BaseTool: register_tool(obj) @@ -116,4 +112,4 @@ def get_tool_instance(tool_name: str) -> Optional[BaseTool]: tool_class = TOOL_REGISTRY.get(tool_name) if not tool_class: return None - return tool_class() \ No newline at end of file + return tool_class() \ No newline at end of file diff --git a/src/do_tool/tool_can_use/get_knowledge.py b/src/do_tool/tool_can_use/get_knowledge.py index 06ea7a91b..fa17dfbf6 100644 --- a/src/do_tool/tool_can_use/get_knowledge.py +++ b/src/do_tool/tool_can_use/get_knowledge.py @@ -2,7 +2,7 @@ from src.do_tool.tool_can_use.base_tool import BaseTool, register_tool from src.plugins.chat.utils import get_embedding from src.common.database import db from src.common.logger import get_module_logger -from typing import Dict, Any, Union, List +from typing import Dict, Any, Union logger = get_module_logger("get_knowledge_tool") diff --git a/src/do_tool/tool_use.py b/src/do_tool/tool_use.py index a2e23ab21..95118f79f 100644 --- a/src/do_tool/tool_use.py +++ b/src/do_tool/tool_use.py @@ -5,7 +5,6 @@ from src.common.database import db import time import json from src.common.logger import get_module_logger -from typing import Union from src.do_tool.tool_can_use import get_all_tool_definitions, get_tool_instance logger = get_module_logger("tool_use") diff --git a/src/heart_flow/observation.py b/src/heart_flow/observation.py index f507374cb..55ab9db11 100644 --- a/src/heart_flow/observation.py +++ b/src/heart_flow/observation.py @@ -4,8 +4,6 @@ from datetime import datetime from src.plugins.models.utils_model import LLM_request from src.plugins.config.config import global_config from src.common.database import db -from src.individuality.individuality import Individuality -import random # 所有观察的基类 diff --git a/src/heart_flow/sub_heartflow.py b/src/heart_flow/sub_heartflow.py index baa20c64f..9cf2e2ea2 100644 --- a/src/heart_flow/sub_heartflow.py +++ b/src/heart_flow/sub_heartflow.py @@ -5,18 +5,17 @@ from src.plugins.models.utils_model import LLM_request from src.plugins.config.config import global_config import re import time -from src.plugins.schedule.schedule_generator import bot_schedule -from src.plugins.memory_system.Hippocampus import HippocampusManager +# from src.plugins.schedule.schedule_generator import bot_schedule +# from src.plugins.memory_system.Hippocampus import HippocampusManager from src.common.logger import get_module_logger, LogConfig, SUB_HEARTFLOW_STYLE_CONFIG # noqa: E402 -from src.plugins.chat.utils import get_embedding -from src.common.database import db -from typing import Union +# from src.plugins.chat.utils import get_embedding +# from src.common.database import db +# from typing import Union from src.individuality.individuality import Individuality import random from src.plugins.chat.chat_stream import ChatStream from src.plugins.person_info.relationship_manager import relationship_manager from src.plugins.chat.utils import get_recent_group_speaker -import json from src.do_tool.tool_use import ToolUser subheartflow_config = LogConfig( diff --git a/src/plugins/chat_module/think_flow_chat/think_flow_prompt_builder.py b/src/plugins/chat_module/think_flow_chat/think_flow_prompt_builder.py index 0e00eea95..43b0db219 100644 --- a/src/plugins/chat_module/think_flow_chat/think_flow_prompt_builder.py +++ b/src/plugins/chat_module/think_flow_chat/think_flow_prompt_builder.py @@ -108,7 +108,7 @@ class PromptBuilder: individuality = Individuality.get_instance() prompt_personality = individuality.get_prompt(type="personality", x_person=2, level=1) - prompt_identity = individuality.get_prompt(type="identity", x_person=2, level=1) + # prompt_identity = individuality.get_prompt(type="identity", x_person=2, level=1) # 日程构建 @@ -166,7 +166,7 @@ class PromptBuilder: ) -> tuple[str, str]: individuality = Individuality.get_instance() - prompt_personality = individuality.get_prompt(type="personality", x_person=2, level=1) + # prompt_personality = individuality.get_prompt(type="personality", x_person=2, level=1) prompt_identity = individuality.get_prompt(type="identity", x_person=2, level=1) From b11ebbc8323ea2155376aff74f9a9afc2ad3705f Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Fri, 11 Apr 2025 10:57:49 +0800 Subject: [PATCH 23/24] fix: Ruff x2 --- src/plugins/PFC/chat_observer.py | 4 ++-- src/plugins/PFC/message_storage.py | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/plugins/PFC/chat_observer.py b/src/plugins/PFC/chat_observer.py index a766e3b4a..0af11e135 100644 --- a/src/plugins/PFC/chat_observer.py +++ b/src/plugins/PFC/chat_observer.py @@ -1,12 +1,12 @@ import time import asyncio import traceback -from typing import Optional, Dict, Any, List, Tuple +from typing import Optional, Dict, Any, List from src.common.logger import get_module_logger from ..message.message_base import UserInfo from ..config.config import global_config from .chat_states import NotificationManager, create_new_message_notification, create_cold_chat_notification -from .message_storage import MessageStorage, MongoDBMessageStorage +from .message_storage import MongoDBMessageStorage logger = get_module_logger("chat_observer") diff --git a/src/plugins/PFC/message_storage.py b/src/plugins/PFC/message_storage.py index 75bab6edd..afd233347 100644 --- a/src/plugins/PFC/message_storage.py +++ b/src/plugins/PFC/message_storage.py @@ -1,7 +1,6 @@ from abc import ABC, abstractmethod -from typing import List, Dict, Any, Optional +from typing import List, Dict, Any from src.common.database import db -import time class MessageStorage(ABC): """消息存储接口""" From 1ac9a66cee430d86c0a7e67c97e1d20c35f1a5dc Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Fri, 11 Apr 2025 13:10:15 +0800 Subject: [PATCH 24/24] =?UTF-8?q?=E4=B8=8D=E5=B0=8F=E5=BF=83=E7=82=B8?= =?UTF-8?q?=E4=BA=86logger?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bot.py b/bot.py index c9568ecba..653efd45d 100644 --- a/bot.py +++ b/bot.py @@ -16,7 +16,7 @@ confirm_logger_config = LogConfig( console_format=CONFIRM_STYLE_CONFIG["console_format"], file_format=CONFIRM_STYLE_CONFIG["file_format"], ) -confirm_logger = get_module_logger("main_bot", config=confirm_logger_config) +confirm_logger = get_module_logger("confirm", config=confirm_logger_config) # 获取没有加载env时的环境变量 env_mask = {key: os.getenv(key) for key in os.environ}