diff --git a/src/common/logger.py b/src/common/logger.py index 7ef539fc3..0a8839d2f 100644 --- a/src/common/logger.py +++ b/src/common/logger.py @@ -283,17 +283,13 @@ WILLING_STYLE_CONFIG = { "file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 意愿 | {message}"), }, "simple": { - "console_format": ( - "{time:MM-DD HH:mm} | 意愿 | {message}" - ), # noqa: E501 + "console_format": ("{time:MM-DD HH:mm} | 意愿 | {message}"), # noqa: E501 "file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 意愿 | {message}"), }, } CONFIRM_STYLE_CONFIG = { - "console_format": ( - "{message}" - ), # noqa: E501 + "console_format": ("{message}"), # noqa: E501 "file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | EULA与PRIVACY确认 | {message}"), } diff --git a/src/do_tool/tool_can_use/__init__.py b/src/do_tool/tool_can_use/__init__.py index 3189d2897..a7ea17ab7 100644 --- a/src/do_tool/tool_can_use/__init__.py +++ b/src/do_tool/tool_can_use/__init__.py @@ -4,17 +4,17 @@ from src.do_tool.tool_can_use.base_tool import ( discover_tools, get_all_tool_definitions, get_tool_instance, - TOOL_REGISTRY + TOOL_REGISTRY, ) __all__ = [ - 'BaseTool', - 'register_tool', - 'discover_tools', - 'get_all_tool_definitions', - 'get_tool_instance', - 'TOOL_REGISTRY' + "BaseTool", + "register_tool", + "discover_tools", + "get_all_tool_definitions", + "get_tool_instance", + "TOOL_REGISTRY", ] # 自动发现并注册工具 -discover_tools() \ No newline at end of file +discover_tools() diff --git a/src/do_tool/tool_can_use/base_tool.py b/src/do_tool/tool_can_use/base_tool.py index c8c80ebe8..b1edf8055 100644 --- a/src/do_tool/tool_can_use/base_tool.py +++ b/src/do_tool/tool_can_use/base_tool.py @@ -10,41 +10,39 @@ logger = get_module_logger("base_tool") # 工具注册表 TOOL_REGISTRY = {} + class BaseTool: """所有工具的基类""" + # 工具名称,子类必须重写 name = None # 工具描述,子类必须重写 description = None # 工具参数定义,子类必须重写 parameters = None - + @classmethod def get_tool_definition(cls) -> Dict[str, Any]: """获取工具定义,用于LLM工具调用 - + Returns: Dict: 工具定义字典 """ if not cls.name or not cls.description or not cls.parameters: raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性") - + return { "type": "function", - "function": { - "name": cls.name, - "description": cls.description, - "parameters": cls.parameters - } + "function": {"name": cls.name, "description": cls.description, "parameters": cls.parameters}, } - + async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]: """执行工具函数 - + Args: function_args: 工具调用参数 message_txt: 原始消息文本 - + Returns: Dict: 工具执行结果 """ @@ -53,17 +51,17 @@ class BaseTool: def register_tool(tool_class: Type[BaseTool]): """注册工具到全局注册表 - + Args: tool_class: 工具类 """ if not issubclass(tool_class, BaseTool): raise TypeError(f"{tool_class.__name__} 不是 BaseTool 的子类") - + tool_name = tool_class.name if not tool_name: raise ValueError(f"工具类 {tool_class.__name__} 没有定义 name 属性") - + TOOL_REGISTRY[tool_name] = tool_class logger.info(f"已注册工具: {tool_name}") @@ -73,27 +71,27 @@ def discover_tools(): # 获取当前目录路径 current_dir = os.path.dirname(os.path.abspath(__file__)) package_name = os.path.basename(current_dir) - + # 遍历包中的所有模块 for _, module_name, _ in pkgutil.iter_modules([current_dir]): # 跳过当前模块和__pycache__ if module_name == "base_tool" or module_name.startswith("__"): continue - + # 导入模块 module = importlib.import_module(f"src.do_tool.{package_name}.{module_name}") - + # 查找模块中的工具类 for _, obj in inspect.getmembers(module): if inspect.isclass(obj) and issubclass(obj, BaseTool) and obj != BaseTool: register_tool(obj) - + logger.info(f"工具发现完成,共注册 {len(TOOL_REGISTRY)} 个工具") def get_all_tool_definitions() -> List[Dict[str, Any]]: """获取所有已注册工具的定义 - + Returns: List[Dict]: 工具定义列表 """ @@ -102,14 +100,14 @@ def get_all_tool_definitions() -> List[Dict[str, Any]]: def get_tool_instance(tool_name: str) -> Optional[BaseTool]: """获取指定名称的工具实例 - + Args: tool_name: 工具名称 - + Returns: Optional[BaseTool]: 工具实例,如果找不到则返回None """ tool_class = TOOL_REGISTRY.get(tool_name) if not tool_class: return None - return tool_class() \ No newline at end of file + return tool_class() diff --git a/src/do_tool/tool_can_use/fibonacci_sequence_tool.py b/src/do_tool/tool_can_use/fibonacci_sequence_tool.py index 31ca4d0a7..4609b18a0 100644 --- a/src/do_tool/tool_can_use/fibonacci_sequence_tool.py +++ b/src/do_tool/tool_can_use/fibonacci_sequence_tool.py @@ -4,29 +4,25 @@ from typing import Dict, Any logger = get_module_logger("fibonacci_sequence_tool") + class FibonacciSequenceTool(BaseTool): """生成斐波那契数列的工具""" + name = "fibonacci_sequence" description = "生成指定长度的斐波那契数列" parameters = { "type": "object", - "properties": { - "n": { - "type": "integer", - "description": "斐波那契数列的长度", - "minimum": 1 - } - }, - "required": ["n"] + "properties": {"n": {"type": "integer", "description": "斐波那契数列的长度", "minimum": 1}}, + "required": ["n"], } - + async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]: """执行工具功能 - + Args: function_args: 工具参数 message_txt: 原始消息文本 - + Returns: Dict: 工具执行结果 """ @@ -34,23 +30,18 @@ class FibonacciSequenceTool(BaseTool): n = function_args.get("n") if n <= 0: raise ValueError("参数n必须大于0") - + sequence = [] a, b = 0, 1 for _ in range(n): sequence.append(a) a, b = b, a + b - - return { - "name": self.name, - "content": sequence - } + + return {"name": self.name, "content": sequence} except Exception as e: logger.error(f"fibonacci_sequence工具执行失败: {str(e)}") - return { - "name": self.name, - "content": f"执行失败: {str(e)}" - } + return {"name": self.name, "content": f"执行失败: {str(e)}"} + # 注册工具 -register_tool(FibonacciSequenceTool) \ No newline at end of file +register_tool(FibonacciSequenceTool) diff --git a/src/do_tool/tool_can_use/generate_buddha_emoji_tool.py b/src/do_tool/tool_can_use/generate_buddha_emoji_tool.py index 559b6eadd..e704b6015 100644 --- a/src/do_tool/tool_can_use/generate_buddha_emoji_tool.py +++ b/src/do_tool/tool_can_use/generate_buddha_emoji_tool.py @@ -4,8 +4,10 @@ from typing import Dict, Any logger = get_module_logger("generate_buddha_emoji_tool") + class GenerateBuddhaEmojiTool(BaseTool): """生成佛祖颜文字的工具类""" + name = "generate_buddha_emoji" description = "生成一个佛祖的颜文字表情" parameters = { @@ -13,32 +15,27 @@ class GenerateBuddhaEmojiTool(BaseTool): "properties": { # 无参数 }, - "required": [] + "required": [], } - + async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]: """执行工具功能,生成佛祖颜文字 - + Args: function_args: 工具参数 message_txt: 原始消息文本 - + Returns: Dict: 工具执行结果 """ try: buddha_emoji = "这是一个佛祖emoji:༼ つ ◕_◕ ༽つ" - - return { - "name": self.name, - "content": buddha_emoji - } + + return {"name": self.name, "content": buddha_emoji} except Exception as e: logger.error(f"generate_buddha_emoji工具执行失败: {str(e)}") - return { - "name": self.name, - "content": f"执行失败: {str(e)}" - } + return {"name": self.name, "content": f"执行失败: {str(e)}"} + # 注册工具 -register_tool(GenerateBuddhaEmojiTool) \ No newline at end of file +register_tool(GenerateBuddhaEmojiTool) diff --git a/src/do_tool/tool_can_use/generate_cmd_tutorial_tool.py b/src/do_tool/tool_can_use/generate_cmd_tutorial_tool.py index 6a790adb6..3a9f9bba1 100644 --- a/src/do_tool/tool_can_use/generate_cmd_tutorial_tool.py +++ b/src/do_tool/tool_can_use/generate_cmd_tutorial_tool.py @@ -4,23 +4,21 @@ from typing import Dict, Any logger = get_module_logger("generate_cmd_tutorial_tool") + class GenerateCmdTutorialTool(BaseTool): """生成Windows CMD基本操作教程的工具""" + name = "generate_cmd_tutorial" description = "生成关于Windows命令提示符(CMD)的基本操作教程,包括常用命令和使用方法" - parameters = { - "type": "object", - "properties": {}, - "required": [] - } - + parameters = {"type": "object", "properties": {}, "required": []} + async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]: """执行工具功能 - + Args: function_args: 工具参数 message_txt: 原始消息文本 - + Returns: Dict: 工具执行结果 """ @@ -57,17 +55,12 @@ class GenerateCmdTutorialTool(BaseTool): 注意:使用命令时要小心,特别是删除操作。 """ - - return { - "name": self.name, - "content": tutorial_content - } + + return {"name": self.name, "content": tutorial_content} except Exception as e: logger.error(f"generate_cmd_tutorial工具执行失败: {str(e)}") - return { - "name": self.name, - "content": f"执行失败: {str(e)}" - } + return {"name": self.name, "content": f"执行失败: {str(e)}"} + # 注册工具 -register_tool(GenerateCmdTutorialTool) \ No newline at end of file +register_tool(GenerateCmdTutorialTool) diff --git a/src/do_tool/tool_can_use/get_current_task.py b/src/do_tool/tool_can_use/get_current_task.py index dd3402357..1975c40b0 100644 --- a/src/do_tool/tool_can_use/get_current_task.py +++ b/src/do_tool/tool_can_use/get_current_task.py @@ -5,32 +5,28 @@ from typing import Dict, Any logger = get_module_logger("get_current_task_tool") + class GetCurrentTaskTool(BaseTool): """获取当前正在做的事情/最近的任务工具""" + name = "get_current_task" description = "获取当前正在做的事情/最近的任务" parameters = { "type": "object", "properties": { - "num": { - "type": "integer", - "description": "要获取的任务数量" - }, - "time_info": { - "type": "boolean", - "description": "是否包含时间信息" - } + "num": {"type": "integer", "description": "要获取的任务数量"}, + "time_info": {"type": "boolean", "description": "是否包含时间信息"}, }, - "required": [] + "required": [], } - + async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]: """执行获取当前任务 - + Args: function_args: 工具参数 message_txt: 原始消息文本,此工具不使用 - + Returns: Dict: 工具执行结果 """ @@ -38,26 +34,21 @@ class GetCurrentTaskTool(BaseTool): # 获取参数,如果没有提供则使用默认值 num = function_args.get("num", 1) time_info = function_args.get("time_info", False) - + # 调用日程系统获取当前任务 current_task = bot_schedule.get_current_num_task(num=num, time_info=time_info) - + # 格式化返回结果 if current_task: task_info = current_task else: task_info = "当前没有正在进行的任务" - - return { - "name": "get_current_task", - "content": f"当前任务信息: {task_info}" - } + + return {"name": "get_current_task", "content": f"当前任务信息: {task_info}"} except Exception as e: logger.error(f"获取当前任务工具执行失败: {str(e)}") - return { - "name": "get_current_task", - "content": f"获取当前任务失败: {str(e)}" - } + return {"name": "get_current_task", "content": f"获取当前任务失败: {str(e)}"} + # 注册工具 -register_tool(GetCurrentTaskTool) \ No newline at end of file +register_tool(GetCurrentTaskTool) diff --git a/src/do_tool/tool_can_use/get_knowledge.py b/src/do_tool/tool_can_use/get_knowledge.py index fa17dfbf6..0b492f11a 100644 --- a/src/do_tool/tool_can_use/get_knowledge.py +++ b/src/do_tool/tool_can_use/get_knowledge.py @@ -6,39 +6,35 @@ from typing import Dict, Any, Union logger = get_module_logger("get_knowledge_tool") + class SearchKnowledgeTool(BaseTool): """从知识库中搜索相关信息的工具""" + name = "search_knowledge" description = "从知识库中搜索相关信息" parameters = { "type": "object", "properties": { - "query": { - "type": "string", - "description": "搜索查询关键词" - }, - "threshold": { - "type": "number", - "description": "相似度阈值,0.0到1.0之间" - } + "query": {"type": "string", "description": "搜索查询关键词"}, + "threshold": {"type": "number", "description": "相似度阈值,0.0到1.0之间"}, }, - "required": ["query"] + "required": ["query"], } - + async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]: """执行知识库搜索 - + Args: function_args: 工具参数 message_txt: 原始消息文本 - + Returns: Dict: 工具执行结果 """ try: query = function_args.get("query", message_txt) threshold = function_args.get("threshold", 0.4) - + # 调用知识库搜索 embedding = await get_embedding(query, request_type="info_retrieval") if embedding: @@ -47,38 +43,29 @@ class SearchKnowledgeTool(BaseTool): content = f"你知道这些知识: {knowledge_info}" else: content = f"你不太了解有关{query}的知识" - return { - "name": "search_knowledge", - "content": content - } - return { - "name": "search_knowledge", - "content": f"无法获取关于'{query}'的嵌入向量" - } + return {"name": "search_knowledge", "content": content} + return {"name": "search_knowledge", "content": f"无法获取关于'{query}'的嵌入向量"} except Exception as e: logger.error(f"知识库搜索工具执行失败: {str(e)}") - return { - "name": "search_knowledge", - "content": f"知识库搜索失败: {str(e)}" - } - + return {"name": "search_knowledge", "content": f"知识库搜索失败: {str(e)}"} + def get_info_from_db( self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False ) -> Union[str, list]: """从数据库中获取相关信息 - + Args: query_embedding: 查询的嵌入向量 limit: 最大返回结果数 threshold: 相似度阈值 return_raw: 是否返回原始结果 - + Returns: Union[str, list]: 格式化的信息字符串或原始结果列表 """ if not query_embedding: return "" if not return_raw else [] - + # 使用余弦相似度计算 pipeline = [ { @@ -143,5 +130,6 @@ class SearchKnowledgeTool(BaseTool): # 返回所有找到的内容,用换行分隔 return "\n".join(str(result["content"]) for result in results) + # 注册工具 register_tool(SearchKnowledgeTool) diff --git a/src/do_tool/tool_can_use/get_memory.py b/src/do_tool/tool_can_use/get_memory.py index 171e8486a..16af4c644 100644 --- a/src/do_tool/tool_can_use/get_memory.py +++ b/src/do_tool/tool_can_use/get_memory.py @@ -5,68 +5,55 @@ from typing import Dict, Any logger = get_module_logger("get_memory_tool") + class GetMemoryTool(BaseTool): """从记忆系统中获取相关记忆的工具""" + name = "get_memory" description = "从记忆系统中获取相关记忆" parameters = { "type": "object", "properties": { - "text": { - "type": "string", - "description": "要查询的相关文本" - }, - "max_memory_num": { - "type": "integer", - "description": "最大返回记忆数量" - } + "text": {"type": "string", "description": "要查询的相关文本"}, + "max_memory_num": {"type": "integer", "description": "最大返回记忆数量"}, }, - "required": ["text"] + "required": ["text"], } - + async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]: """执行记忆获取 - + Args: function_args: 工具参数 message_txt: 原始消息文本 - + Returns: Dict: 工具执行结果 """ try: text = function_args.get("text", message_txt) max_memory_num = function_args.get("max_memory_num", 2) - + # 调用记忆系统 related_memory = await HippocampusManager.get_instance().get_memory_from_text( - text=text, - max_memory_num=max_memory_num, - max_memory_length=2, - max_depth=3, - fast_retrieval=False + text=text, max_memory_num=max_memory_num, max_memory_length=2, max_depth=3, fast_retrieval=False ) - + memory_info = "" if related_memory: for memory in related_memory: memory_info += memory[1] + "\n" - + if memory_info: content = f"你记得这些事情: {memory_info}" else: content = f"你不太记得有关{text}的记忆,你对此不太了解" - - return { - "name": "get_memory", - "content": content - } + + return {"name": "get_memory", "content": content} except Exception as e: logger.error(f"记忆获取工具执行失败: {str(e)}") - return { - "name": "get_memory", - "content": f"记忆获取失败: {str(e)}" - } + return {"name": "get_memory", "content": f"记忆获取失败: {str(e)}"} + # 注册工具 -register_tool(GetMemoryTool) \ No newline at end of file +register_tool(GetMemoryTool) diff --git a/src/do_tool/tool_use.py b/src/do_tool/tool_use.py index 95118f79f..51bc37568 100644 --- a/src/do_tool/tool_use.py +++ b/src/do_tool/tool_use.py @@ -16,21 +16,19 @@ class ToolUser: model=global_config.llm_heartflow, temperature=0.2, max_tokens=1000, request_type="tool_use" ) - async def _build_tool_prompt(self, message_txt:str, sender_name:str, chat_stream:ChatStream): + async def _build_tool_prompt(self, message_txt: str, sender_name: str, chat_stream: ChatStream): """构建工具使用的提示词 - + Args: message_txt: 用户消息文本 sender_name: 发送者名称 chat_stream: 聊天流对象 - + Returns: str: 构建好的提示词 """ new_messages = list( - db.messages.find({"chat_id": chat_stream.stream_id, "time": {"$gt": time.time()}}) - .sort("time", 1) - .limit(15) + db.messages.find({"chat_id": chat_stream.stream_id, "time": {"$gt": time.time()}}).sort("time", 1).limit(15) ) new_messages_str = "" for msg in new_messages: @@ -44,37 +42,37 @@ class ToolUser: prompt += f"你注意到{sender_name}刚刚说:{message_txt}\n" prompt += f"注意你就是{bot_name},{bot_name}指的就是你。" prompt += "你现在需要对群里的聊天内容进行回复,现在请你思考,你是否需要额外的信息,或者一些工具来帮你回复,比如回忆或者搜寻已有的知识,或者了解你现在正在做什么,请输出你需要的工具,或者你需要的额外信息。" - + return prompt - + def _define_tools(self): """获取所有已注册工具的定义 - + Returns: list: 工具定义列表 """ return get_all_tool_definitions() - - async def _execute_tool_call(self, tool_call, message_txt:str): + + async def _execute_tool_call(self, tool_call, message_txt: str): """执行特定的工具调用 - + Args: tool_call: 工具调用对象 message_txt: 原始消息文本 - + Returns: dict: 工具调用结果 """ try: function_name = tool_call["function"]["name"] function_args = json.loads(tool_call["function"]["arguments"]) - + # 获取对应工具实例 tool_instance = get_tool_instance(function_name) if not tool_instance: logger.warning(f"未知工具名称: {function_name}") return None - + # 执行工具 result = await tool_instance.execute(function_args, message_txt) if result: @@ -82,62 +80,60 @@ class ToolUser: "tool_call_id": tool_call["id"], "role": "tool", "name": function_name, - "content": result["content"] + "content": result["content"], } return None except Exception as e: logger.error(f"执行工具调用时发生错误: {str(e)}") return None - - async def use_tool(self, message_txt:str, sender_name:str, chat_stream:ChatStream): + + async def use_tool(self, message_txt: str, sender_name: str, chat_stream: ChatStream): """使用工具辅助思考,判断是否需要额外信息 - + Args: message_txt: 用户消息文本 sender_name: 发送者名称 chat_stream: 聊天流对象 - + Returns: dict: 工具使用结果 """ try: # 构建提示词 prompt = await self._build_tool_prompt(message_txt, sender_name, chat_stream) - + # 定义可用工具 tools = self._define_tools() - + # 使用llm_model_tool发送带工具定义的请求 payload = { "model": self.llm_model_tool.model_name, "messages": [{"role": "user", "content": prompt}], "max_tokens": global_config.max_response_length, "tools": tools, - "temperature": 0.2 + "temperature": 0.2, } - + logger.debug(f"发送工具调用请求,模型: {self.llm_model_tool.model_name}") # 发送请求获取模型是否需要调用工具 response = await self.llm_model_tool._execute_request( - endpoint="/chat/completions", - payload=payload, - prompt=prompt + endpoint="/chat/completions", payload=payload, prompt=prompt ) - + # 根据返回值数量判断是否有工具调用 if len(response) == 3: content, reasoning_content, tool_calls = response logger.info(f"工具思考: {tool_calls}") - + # 检查响应中工具调用是否有效 if not tool_calls: logger.info("模型返回了空的tool_calls列表") return {"used_tools": False} - + logger.info(f"模型请求调用{len(tool_calls)}个工具") tool_results = [] collected_info = "" - + # 执行所有工具调用 for tool_call in tool_calls: result = await self._execute_tool_call(tool_call, message_txt) @@ -145,7 +141,7 @@ class ToolUser: tool_results.append(result) # 将工具结果添加到收集的信息中 collected_info += f"\n{result['name']}返回结果: {result['content']}\n" - + # 如果有工具结果,直接返回收集的信息 if collected_info: logger.info(f"工具调用收集到信息: {collected_info}") @@ -157,15 +153,15 @@ class ToolUser: # 没有工具调用 content, reasoning_content = response logger.info("模型没有请求调用任何工具") - + # 如果没有工具调用或处理失败,直接返回原始思考 return { "used_tools": False, } - + except Exception as e: logger.error(f"工具调用过程中出错: {str(e)}") return { "used_tools": False, "error": str(e), - } \ No newline at end of file + } diff --git a/src/heart_flow/heartflow.py b/src/heart_flow/heartflow.py index 74250a708..d6116d0d5 100644 --- a/src/heart_flow/heartflow.py +++ b/src/heart_flow/heartflow.py @@ -43,12 +43,11 @@ def init_prompt(): class CurrentState: def __init__(self): - self.current_state_info = "" self.mood_manager = MoodManager() self.mood = self.mood_manager.get_prompt() - + self.attendance_factor = 0 self.engagement_factor = 0 @@ -66,9 +65,6 @@ class Heartflow: ) self._subheartflows: Dict[Any, SubHeartflow] = {} - - - async def _cleanup_inactive_subheartflows(self): """定期清理不活跃的子心流""" @@ -90,7 +86,7 @@ class Heartflow: logger.info(f"已清理不活跃的子心流: {subheartflow_id}") await asyncio.sleep(30) # 每分钟检查一次 - + async def _sub_heartflow_update(self): while True: # 检查是否存在子心流 @@ -103,13 +99,12 @@ class Heartflow: await asyncio.sleep(global_config.heart_flow_update_interval) # 5分钟思考一次 async def heartflow_start_working(self): - # 启动清理任务 asyncio.create_task(self._cleanup_inactive_subheartflows()) # 启动子心流更新任务 asyncio.create_task(self._sub_heartflow_update()) - + async def _update_current_state(self): print("TODO") diff --git a/src/heart_flow/observation.py b/src/heart_flow/observation.py index 55ab9db11..aef23f964 100644 --- a/src/heart_flow/observation.py +++ b/src/heart_flow/observation.py @@ -150,7 +150,7 @@ class ChattingObservation(Observation): except Exception as e: print(f"获取总结失败: {e}") updated_observe_info = "" - + return updated_observe_info # print(f"prompt:{prompt}") # print(f"self.observe_info:{self.observe_info}") diff --git a/src/heart_flow/sub_heartflow.py b/src/heart_flow/sub_heartflow.py index 18f256d1d..ce1dd10a1 100644 --- a/src/heart_flow/sub_heartflow.py +++ b/src/heart_flow/sub_heartflow.py @@ -5,9 +5,11 @@ from src.plugins.models.utils_model import LLM_request from src.plugins.config.config import global_config import re import time + # from src.plugins.schedule.schedule_generator import bot_schedule # from src.plugins.memory_system.Hippocampus import HippocampusManager from src.common.logger import get_module_logger, LogConfig, SUB_HEARTFLOW_STYLE_CONFIG # noqa: E402 + # from src.plugins.chat.utils import get_embedding # from src.common.database import db # from typing import Union @@ -17,7 +19,7 @@ from src.plugins.chat.chat_stream import ChatStream from src.plugins.person_info.relationship_manager import relationship_manager from src.plugins.chat.utils import get_recent_group_speaker from src.do_tool.tool_use import ToolUser -from ..plugins.utils.prompt_builder import Prompt,global_prompt_manager +from ..plugins.utils.prompt_builder import Prompt, global_prompt_manager subheartflow_config = LogConfig( # 使用海马体专用样式 @@ -26,6 +28,7 @@ subheartflow_config = LogConfig( ) logger = get_module_logger("subheartflow", config=subheartflow_config) + def init_prompt(): prompt = "" # prompt += f"麦麦的总体想法是:{self.main_heartflow_info}\n\n" @@ -41,7 +44,7 @@ def init_prompt(): prompt += "思考时可以想想如何对群聊内容进行回复。回复的要求是:平淡一些,简短一些,说中文,尽量不要说你说过的话\n" prompt += "请注意不要输出多余内容(包括前后缀,冒号和引号,括号, 表情,等),不要带有括号和动作描写" prompt += "记得结合上述的消息,生成内心想法,文字不要浮夸,注意你就是{bot_name},{bot_name}指的就是你。" - Prompt(prompt,"sub_heartflow_prompt_before") + Prompt(prompt, "sub_heartflow_prompt_before") prompt = "" # prompt += f"你现在正在做的事情是:{schedule_info}\n" prompt += "{prompt_personality}\n" @@ -52,8 +55,7 @@ def init_prompt(): prompt += "你现在{mood_info}" prompt += "现在你接下去继续思考,产生新的想法,记得保留你刚刚的想法,不要分点输出,输出连贯的内心独白" prompt += "不要太长,但是记得结合上述的消息,要记得你的人设,关注聊天和新内容,关注你回复的内容,不要思考太多:" - Prompt(prompt,'sub_heartflow_prompt_after') - + Prompt(prompt, "sub_heartflow_prompt_after") class CurrentState: @@ -78,7 +80,6 @@ class SubHeartflow: self.llm_model = LLM_request( model=global_config.llm_sub_heartflow, temperature=0.2, max_tokens=600, request_type="sub_heart_flow" ) - self.main_heartflow_info = "" @@ -93,9 +94,9 @@ class SubHeartflow: self.observations: list[Observation] = [] self.running_knowledges = [] - + self.bot_name = global_config.BOT_NICKNAME - + self.tool_user = ToolUser() def add_observation(self, observation: Observation): @@ -145,12 +146,12 @@ class SubHeartflow: ): # 5分钟无回复/不在场,销毁 logger.info(f"子心流 {self.subheartflow_id} 已经5分钟没有激活,正在销毁...") break # 退出循环,销毁自己 + async def do_observe(self): observation = self.observations[0] await observation.observe() - - async def do_thinking_before_reply(self, message_txt:str, sender_name:str, chat_stream:ChatStream): + async def do_thinking_before_reply(self, message_txt: str, sender_name: str, chat_stream: ChatStream): current_thinking_info = self.current_mind mood_info = self.current_state.mood # mood_info = "你很生气,很愤怒" @@ -160,12 +161,12 @@ class SubHeartflow: # 首先尝试使用工具获取更多信息 tool_result = await self.tool_user.use_tool(message_txt, sender_name, chat_stream) - + # 如果工具被使用且获得了结果,将收集到的信息合并到思考中 collected_info = "" if tool_result.get("used_tools", False): logger.info("使用工具收集了信息") - + # 如果有收集到的信息,将其添加到当前思考中 if "collected_info" in tool_result: collected_info = tool_result["collected_info"] @@ -185,7 +186,7 @@ class SubHeartflow: identity_detail = individuality.identity.identity_detail random.shuffle(identity_detail) prompt_personality += f",{identity_detail[0]}" - + # 关系 who_chat_in_group = [ (chat_stream.user_info.platform, chat_stream.user_info.user_id, chat_stream.user_info.user_nickname) @@ -204,9 +205,9 @@ class SubHeartflow: # f"{relation_prompt}关系等级越大,关系越好,请分析聊天记录," # f"根据你和说话者{sender_name}的关系和态度进行回复,明确你的立场和情感。" # ) - relation_prompt_all = (await global_prompt_manager.get_prompt_async('relationship_prompt')).format( - relation_prompt,sender_name - ) + relation_prompt_all = (await global_prompt_manager.get_prompt_async("relationship_prompt")).format( + relation_prompt, sender_name + ) # prompt = "" # # prompt += f"麦麦的总体想法是:{self.main_heartflow_info}\n\n" @@ -224,9 +225,16 @@ class SubHeartflow: # prompt += "请注意不要输出多余内容(包括前后缀,冒号和引号,括号, 表情,等),不要带有括号和动作描写" # prompt += f"记得结合上述的消息,生成内心想法,文字不要浮夸,注意你就是{self.bot_name},{self.bot_name}指的就是你。" - prompt= (await global_prompt_manager.get_prompt_async("sub_heartflow_prompt_before")).format( - collected_info,relation_prompt_all,prompt_personality,current_thinking_info,chat_observe_info,mood_info,sender_name, - message_txt,self.bot_name + prompt = (await global_prompt_manager.get_prompt_async("sub_heartflow_prompt_before")).format( + collected_info, + relation_prompt_all, + prompt_personality, + current_thinking_info, + chat_observe_info, + mood_info, + sender_name, + message_txt, + self.bot_name, ) try: @@ -281,10 +289,10 @@ class SubHeartflow: # prompt += f"你现在{mood_info}" # prompt += "现在你接下去继续思考,产生新的想法,记得保留你刚刚的想法,不要分点输出,输出连贯的内心独白" # prompt += "不要太长,但是记得结合上述的消息,要记得你的人设,关注聊天和新内容,关注你回复的内容,不要思考太多:" - prompt=(await global_prompt_manager.get_prompt_async('sub_heartflow_prompt_after')).format( - prompt_personality,chat_observe_info,current_thinking_info,message_new_info,reply_info,mood_info + prompt = (await global_prompt_manager.get_prompt_async("sub_heartflow_prompt_after")).format( + prompt_personality, chat_observe_info, current_thinking_info, message_new_info, reply_info, mood_info ) - + try: response, reasoning_content = await self.llm_model.generate_response_async(prompt) except Exception as e: @@ -343,5 +351,6 @@ class SubHeartflow: self.past_mind.append(self.current_mind) self.current_mind = response + init_prompt() # subheartflow = SubHeartflow() diff --git a/src/plugins/PFC/action_planner.py b/src/plugins/PFC/action_planner.py index 53b95118b..61afc1bd3 100644 --- a/src/plugins/PFC/action_planner.py +++ b/src/plugins/PFC/action_planner.py @@ -53,13 +53,13 @@ class ActionPlanner: goal = goal_reason[0] reasoning = goal_reason[1] if len(goal_reason) > 1 else "没有明确原因" elif isinstance(goal_reason, dict): - goal = goal_reason.get('goal') - reasoning = goal_reason.get('reasoning', "没有明确原因") + goal = goal_reason.get("goal") + reasoning = goal_reason.get("reasoning", "没有明确原因") else: # 如果是其他类型,尝试转为字符串 goal = str(goal_reason) reasoning = "没有明确原因" - + goal_str = f"目标:{goal},产生该对话目标的原因:{reasoning}\n" goals_str += goal_str else: @@ -68,7 +68,11 @@ class ActionPlanner: goals_str = f"目标:{goal},产生该对话目标的原因:{reasoning}\n" # 获取聊天历史记录 - chat_history_list = observation_info.chat_history[-20:] if len(observation_info.chat_history) >= 20 else observation_info.chat_history + chat_history_list = ( + observation_info.chat_history[-20:] + if len(observation_info.chat_history) >= 20 + else observation_info.chat_history + ) chat_history_text = "" for msg in chat_history_list: chat_history_text += f"{msg.get('detailed_plain_text', '')}\n" @@ -85,15 +89,21 @@ class ActionPlanner: personality_text = f"你的名字是{self.name},{self.personality_info}" # 构建action历史文本 - action_history_list = conversation_info.done_action[-10:] if len(conversation_info.done_action) >= 10 else conversation_info.done_action + action_history_list = ( + conversation_info.done_action[-10:] + if len(conversation_info.done_action) >= 10 + else conversation_info.done_action + ) action_history_text = "你之前做的事情是:" for action in action_history_list: if isinstance(action, dict): - action_type = action.get('action') - action_reason = action.get('reason') - action_status = action.get('status') + action_type = action.get("action") + action_reason = action.get("reason") + action_status = action.get("status") if action_status == "recall": - action_history_text += f"原本打算:{action_type},但是因为有新消息,你发现这个行动不合适,所以你没做\n" + action_history_text += ( + f"原本打算:{action_type},但是因为有新消息,你发现这个行动不合适,所以你没做\n" + ) elif action_status == "done": action_history_text += f"你之前做了:{action_type},原因:{action_reason}\n" elif isinstance(action, tuple): @@ -102,7 +112,9 @@ class ActionPlanner: action_reason = action[1] if len(action) > 1 else "未知原因" action_status = action[2] if len(action) > 2 else "done" if action_status == "recall": - action_history_text += f"原本打算:{action_type},但是因为有新消息,你发现这个行动不合适,所以你没做\n" + action_history_text += ( + f"原本打算:{action_type},但是因为有新消息,你发现这个行动不合适,所以你没做\n" + ) elif action_status == "done": action_history_text += f"你之前做了:{action_type},原因:{action_reason}\n" @@ -147,7 +159,14 @@ end_conversation: 结束对话,长时间没回复或者当你觉得谈话暂 reason = result["reason"] # 验证action类型 - if action not in ["direct_reply", "fetch_knowledge", "wait", "listening", "rethink_goal", "end_conversation"]: + if action not in [ + "direct_reply", + "fetch_knowledge", + "wait", + "listening", + "rethink_goal", + "end_conversation", + ]: logger.warning(f"未知的行动类型: {action},默认使用listening") action = "listening" diff --git a/src/plugins/PFC/chat_observer.py b/src/plugins/PFC/chat_observer.py index 844f346f3..cc59d8247 100644 --- a/src/plugins/PFC/chat_observer.py +++ b/src/plugins/PFC/chat_observer.py @@ -1,12 +1,12 @@ import time import asyncio import traceback -from typing import Optional, Dict, Any, List +from typing import Optional, Dict, Any, List from src.common.logger import get_module_logger from ..message.message_base import UserInfo from ..config.config import global_config from .chat_states import NotificationManager, create_new_message_notification, create_cold_chat_notification -from .message_storage import MongoDBMessageStorage +from .message_storage import MongoDBMessageStorage logger = get_module_logger("chat_observer") @@ -51,7 +51,6 @@ class ChatObserver: self.waiting_start_time: float = time.time() # 等待开始时间,初始化为当前时间 - # 运行状态 self._running: bool = False self._task: Optional[asyncio.Task] = None @@ -94,10 +93,11 @@ class ChatObserver: message: 消息数据 """ try: - # 发送新消息通知 # logger.info(f"发送新ccchandleer消息通知: {message}") - notification = create_new_message_notification(sender="chat_observer", target="observation_info", message=message) + notification = create_new_message_notification( + sender="chat_observer", target="observation_info", message=message + ) # logger.info(f"发送新消ddddd息通知: {notification}") # print(self.notification_manager) await self.notification_manager.send_notification(notification) @@ -131,7 +131,6 @@ class ChatObserver: notification = create_cold_chat_notification(sender="chat_observer", target="pfc", is_cold=is_cold) await self.notification_manager.send_notification(notification) - def new_message_after(self, time_point: float) -> bool: """判断是否在指定时间点后有新消息 @@ -197,7 +196,7 @@ class ChatObserver: if new_messages: self.last_message_read = new_messages[-1] self.last_message_time = new_messages[-1]["time"] - + # print(f"获取数据库中找到的新消息: {new_messages}") return new_messages @@ -215,7 +214,7 @@ class ChatObserver: if new_messages: self.last_message_read = new_messages[-1]["message_id"] - + logger.debug(f"获取指定时间点111之前的消息: {new_messages}") return new_messages @@ -239,7 +238,7 @@ class ChatObserver: try: # print("等待事件") await asyncio.wait_for(self._update_event.wait(), timeout=1) - + except asyncio.TimeoutError: # print("超时") pass # 超时后也执行一次检查 @@ -347,7 +346,6 @@ class ChatObserver: return time_info - def get_cached_messages(self, limit: int = 50) -> List[Dict[str, Any]]: """获取缓存的消息历史 @@ -368,6 +366,6 @@ class ChatObserver: if not self.message_cache: return None return self.message_cache[0] - + def __str__(self): return f"ChatObserver for {self.stream_id}" diff --git a/src/plugins/PFC/chat_states.py b/src/plugins/PFC/chat_states.py index 373dfdb74..0253ea6dd 100644 --- a/src/plugins/PFC/chat_states.py +++ b/src/plugins/PFC/chat_states.py @@ -140,7 +140,6 @@ class NotificationManager: self._active_states.add(notification.type) else: self._active_states.discard(notification.type) - # 调用目标接收者的处理器 target = notification.target @@ -181,7 +180,7 @@ class NotificationManager: history = history[-limit:] return history - + def __str__(self): str = "" for target, handlers in self._handlers.items(): @@ -295,5 +294,3 @@ class ChatStateManager: current_time = datetime.now().timestamp() return (current_time - self.state_info.last_message_time) <= threshold - - diff --git a/src/plugins/PFC/conversation.py b/src/plugins/PFC/conversation.py index 599b1c453..7fcff895b 100644 --- a/src/plugins/PFC/conversation.py +++ b/src/plugins/PFC/conversation.py @@ -65,7 +65,6 @@ class Conversation: self.observation_info.bind_to_chat_observer(self.chat_observer) # print(self.chat_observer.get_cached_messages(limit=) - self.conversation_info = ConversationInfo() except Exception as e: logger.error(f"初始化对话实例:注册信息组件失败: {e}") @@ -96,7 +95,7 @@ class Conversation: # 执行行动 await self._handle_action(action, reason, self.observation_info, self.conversation_info) - + for goal in self.conversation_info.goal_list: # 检查goal是否为元组类型,如果是元组则使用索引访问,如果是字典则使用get方法 if isinstance(goal, tuple): @@ -151,7 +150,7 @@ class Conversation: if action == "direct_reply": self.waiter.wait_accumulated_time = 0 - + self.state = ConversationState.GENERATING self.generated_reply = await self.reply_generator.generate(observation_info, conversation_info) print(f"生成回复: {self.generated_reply}") @@ -174,7 +173,6 @@ class Conversation: await self._send_reply() - conversation_info.done_action[-1].update( { "status": "done", @@ -184,7 +182,7 @@ class Conversation: elif action == "fetch_knowledge": self.waiter.wait_accumulated_time = 0 - + self.state = ConversationState.FETCHING knowledge = "TODO:知识" topic = "TODO:关键词" @@ -199,7 +197,7 @@ class Conversation: elif action == "rethink_goal": self.waiter.wait_accumulated_time = 0 - + self.state = ConversationState.RETHINKING await self.goal_analyzer.analyze_goal(conversation_info, observation_info) @@ -208,7 +206,6 @@ class Conversation: logger.info("倾听对方发言...") await self.waiter.wait_listening(conversation_info) - elif action == "end_conversation": self.should_continue = False logger.info("决定结束对话...") @@ -239,9 +236,7 @@ class Conversation: return try: - await self.direct_sender.send_message( - chat_stream=self.chat_stream, content=self.generated_reply - ) + await self.direct_sender.send_message(chat_stream=self.chat_stream, content=self.generated_reply) self.chat_observer.trigger_update() # 触发立即更新 if not await self.chat_observer.wait_for_update(): logger.warning("等待消息更新超时") diff --git a/src/plugins/PFC/message_storage.py b/src/plugins/PFC/message_storage.py index fbab0b2b6..55bccb14e 100644 --- a/src/plugins/PFC/message_storage.py +++ b/src/plugins/PFC/message_storage.py @@ -2,6 +2,7 @@ from abc import ABC, abstractmethod from typing import List, Dict, Any from src.common.database import db + class MessageStorage(ABC): """消息存储接口""" diff --git a/src/plugins/PFC/observation_info.py b/src/plugins/PFC/observation_info.py index a8b804449..08ff3c046 100644 --- a/src/plugins/PFC/observation_info.py +++ b/src/plugins/PFC/observation_info.py @@ -26,24 +26,24 @@ class ObservationInfoHandler(NotificationHandler): # 获取通知类型和数据 notification_type = notification.type data = notification.data - + if notification_type == NotificationType.NEW_MESSAGE: # 处理新消息通知 logger.debug(f"收到新消息通知data: {data}") message_id = data.get("message_id") processed_plain_text = data.get("processed_plain_text") - detailed_plain_text = data.get("detailed_plain_text") + detailed_plain_text = data.get("detailed_plain_text") user_info = data.get("user_info") time_value = data.get("time") - + message = { "message_id": message_id, "processed_plain_text": processed_plain_text, "detailed_plain_text": detailed_plain_text, "user_info": user_info, - "time": time_value + "time": time_value, } - + self.observation_info.update_from_message(message) elif notification_type == NotificationType.COLD_CHAT: @@ -161,7 +161,7 @@ class ObservationInfo: # logger.debug(f"更新信息from_message: {message}") self.last_message_time = message["time"] self.last_message_id = message["message_id"] - + self.last_message_content = message.get("processed_plain_text", "") user_info = UserInfo.from_dict(message.get("user_info", {})) @@ -233,4 +233,3 @@ class ObservationInfo: self.unprocessed_messages.clear() self.chat_history_count = len(self.chat_history) self.new_messages_count = 0 - diff --git a/src/plugins/PFC/pfc.py b/src/plugins/PFC/pfc.py index f3c2aa344..727a8f1ba 100644 --- a/src/plugins/PFC/pfc.py +++ b/src/plugins/PFC/pfc.py @@ -1,6 +1,7 @@ # Programmable Friendly Conversationalist # Prefrontal cortex import datetime + # import asyncio from typing import List, Optional, Tuple, TYPE_CHECKING from src.common.logger import get_module_logger @@ -63,13 +64,13 @@ class GoalAnalyzer: goal = goal_reason[0] reasoning = goal_reason[1] if len(goal_reason) > 1 else "没有明确原因" elif isinstance(goal_reason, dict): - goal = goal_reason.get('goal') - reasoning = goal_reason.get('reasoning', "没有明确原因") + goal = goal_reason.get("goal") + reasoning = goal_reason.get("reasoning", "没有明确原因") else: # 如果是其他类型,尝试转为字符串 goal = str(goal_reason) reasoning = "没有明确原因" - + goal_str = f"目标:{goal},产生该对话目标的原因:{reasoning}\n" goals_str += goal_str else: @@ -140,14 +141,12 @@ class GoalAnalyzer: except Exception as e: logger.error(f"分析对话目标时出错: {str(e)}") content = "" - + # 使用改进后的get_items_from_json函数处理JSON数组 success, result = get_items_from_json( - content, "goal", "reasoning", - required_types={"goal": str, "reasoning": str}, - allow_array=True + content, "goal", "reasoning", required_types={"goal": str, "reasoning": str}, allow_array=True ) - + if success: # 判断结果是单个字典还是字典列表 if isinstance(result, list): @@ -157,7 +156,7 @@ class GoalAnalyzer: goal = item.get("goal", "") reasoning = item.get("reasoning", "") conversation_info.goal_list.append((goal, reasoning)) - + # 返回第一个目标作为当前主要目标(如果有) if result: first_goal = result[0] @@ -168,7 +167,7 @@ class GoalAnalyzer: reasoning = result.get("reasoning", "") conversation_info.goal_list.append((goal, reasoning)) return (goal, "", reasoning) - + # 如果解析失败,返回默认值 return ("", "", "") @@ -293,7 +292,6 @@ class GoalAnalyzer: return False, False, f"分析出错: {str(e)}" - class DirectMessageSender: """直接发送消息到平台的发送器""" diff --git a/src/plugins/PFC/pfc_utils.py b/src/plugins/PFC/pfc_utils.py index f99b32a3d..eae36e125 100644 --- a/src/plugins/PFC/pfc_utils.py +++ b/src/plugins/PFC/pfc_utils.py @@ -27,7 +27,7 @@ def get_items_from_json( """ content = content.strip() result = {} - + # 设置默认值 if default_values: result.update(default_values) @@ -41,7 +41,7 @@ def get_items_from_json( if array_match: array_content = array_match.group() json_array = json.loads(array_content) - + # 确认是数组类型 if isinstance(json_array, list): # 验证数组中的每个项目是否包含所有必需字段 @@ -49,7 +49,7 @@ def get_items_from_json( for item in json_array: if not isinstance(item, dict): continue - + # 检查是否有所有必需字段 if all(field in item for field in items): # 验证字段类型 @@ -59,22 +59,22 @@ def get_items_from_json( if field in item and not isinstance(item[field], expected_type): type_valid = False break - + if not type_valid: continue - + # 验证字符串字段不为空 string_valid = True for field in items: if isinstance(item[field], str) and not item[field].strip(): string_valid = False break - + if not string_valid: continue - + valid_items.append(item) - + if valid_items: return True, valid_items except json.JSONDecodeError: diff --git a/src/plugins/PFC/reply_generator.py b/src/plugins/PFC/reply_generator.py index 11edf25a4..e65b64014 100644 --- a/src/plugins/PFC/reply_generator.py +++ b/src/plugins/PFC/reply_generator.py @@ -49,22 +49,26 @@ class ReplyGenerator: goal = goal_reason[0] reasoning = goal_reason[1] if len(goal_reason) > 1 else "没有明确原因" elif isinstance(goal_reason, dict): - goal = goal_reason.get('goal') - reasoning = goal_reason.get('reasoning', "没有明确原因") + goal = goal_reason.get("goal") + reasoning = goal_reason.get("reasoning", "没有明确原因") else: # 如果是其他类型,尝试转为字符串 goal = str(goal_reason) reasoning = "没有明确原因" - + goal_str = f"目标:{goal},产生该对话目标的原因:{reasoning}\n" goals_str += goal_str else: goal = "目前没有明确对话目标" reasoning = "目前没有明确对话目标,最好思考一个对话目标" goals_str = f"目标:{goal},产生该对话目标的原因:{reasoning}\n" - + # 获取聊天历史记录 - chat_history_list = observation_info.chat_history[-20:] if len(observation_info.chat_history) >= 20 else observation_info.chat_history + chat_history_list = ( + observation_info.chat_history[-20:] + if len(observation_info.chat_history) >= 20 + else observation_info.chat_history + ) chat_history_text = "" for msg in chat_history_list: chat_history_text += f"{msg.get('detailed_plain_text', '')}\n" @@ -81,15 +85,21 @@ class ReplyGenerator: personality_text = f"你的名字是{self.name},{self.personality_info}" # 构建action历史文本 - action_history_list = conversation_info.done_action[-10:] if len(conversation_info.done_action) >= 10 else conversation_info.done_action + action_history_list = ( + conversation_info.done_action[-10:] + if len(conversation_info.done_action) >= 10 + else conversation_info.done_action + ) action_history_text = "你之前做的事情是:" for action in action_history_list: if isinstance(action, dict): - action_type = action.get('action') - action_reason = action.get('reason') - action_status = action.get('status') + action_type = action.get("action") + action_reason = action.get("reason") + action_status = action.get("status") if action_status == "recall": - action_history_text += f"原本打算:{action_type},但是因为有新消息,你发现这个行动不合适,所以你没做\n" + action_history_text += ( + f"原本打算:{action_type},但是因为有新消息,你发现这个行动不合适,所以你没做\n" + ) elif action_status == "done": action_history_text += f"你之前做了:{action_type},原因:{action_reason}\n" elif isinstance(action, tuple): @@ -98,7 +108,9 @@ class ReplyGenerator: action_reason = action[1] if len(action) > 1 else "未知原因" action_status = action[2] if len(action) > 2 else "done" if action_status == "recall": - action_history_text += f"原本打算:{action_type},但是因为有新消息,你发现这个行动不合适,所以你没做\n" + action_history_text += ( + f"原本打算:{action_type},但是因为有新消息,你发现这个行动不合适,所以你没做\n" + ) elif action_status == "done": action_history_text += f"你之前做了:{action_type},原因:{action_reason}\n" diff --git a/src/plugins/PFC/waiter.py b/src/plugins/PFC/waiter.py index 6c55c243e..042ad80cd 100644 --- a/src/plugins/PFC/waiter.py +++ b/src/plugins/PFC/waiter.py @@ -16,7 +16,7 @@ class Waiter: self.chat_observer = ChatObserver.get_instance(stream_id) self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2) self.name = global_config.BOT_NICKNAME - + self.wait_accumulated_time = 0 async def wait(self, conversation_info: ConversationInfo) -> bool: @@ -38,20 +38,20 @@ class Waiter: # 检查是否超时 if time.time() - wait_start_time > 300: self.wait_accumulated_time += 300 - + logger.info("等待超过300秒,结束对话") wait_goal = { - "goal": f"你等待了{self.wait_accumulated_time/60}分钟,思考接下来要做什么", - "reason": "对方很久没有回复你的消息了" + "goal": f"你等待了{self.wait_accumulated_time / 60}分钟,思考接下来要做什么", + "reason": "对方很久没有回复你的消息了", } conversation_info.goal_list.append(wait_goal) print(f"添加目标: {wait_goal}") - + return True await asyncio.sleep(1) logger.info("等待中...") - + async def wait_listening(self, conversation_info: ConversationInfo) -> bool: """等待倾听 @@ -73,14 +73,13 @@ class Waiter: self.wait_accumulated_time += 300 logger.info("等待超过300秒,结束对话") wait_goal = { - "goal": f"你等待了{self.wait_accumulated_time/60}分钟,思考接下来要做什么", - "reason": "对方话说一半消失了,很久没有回复" + "goal": f"你等待了{self.wait_accumulated_time / 60}分钟,思考接下来要做什么", + "reason": "对方话说一半消失了,很久没有回复", } conversation_info.goal_list.append(wait_goal) print(f"添加目标: {wait_goal}") - + return True await asyncio.sleep(1) logger.info("等待中...") - diff --git a/src/plugins/chat/bot.py b/src/plugins/chat/bot.py index 9000f4b24..c2126eee2 100644 --- a/src/plugins/chat/bot.py +++ b/src/plugins/chat/bot.py @@ -8,7 +8,7 @@ from ..chat_module.only_process.only_message_process import MessageProcessor from src.common.logger import get_module_logger, CHAT_STYLE_CONFIG, LogConfig from ..chat_module.think_flow_chat.think_flow_chat import ThinkFlowChat from ..chat_module.reasoning_chat.reasoning_chat import ReasoningChat -from ..utils.prompt_builder import Prompt,global_prompt_manager +from ..utils.prompt_builder import Prompt, global_prompt_manager import traceback # 定义日志配置 @@ -89,17 +89,17 @@ class ChatBot: if userinfo.user_id in global_config.ban_user_id: logger.debug(f"用户{userinfo.user_id}被禁止回复") return - + if message.message_info.template_info and not message.message_info.template_info.template_default: - template_group_name=message.message_info.template_info.template_name - template_items=message.message_info.template_info.template_items + template_group_name = message.message_info.template_info.template_name + template_items = message.message_info.template_info.template_items async with global_prompt_manager.async_message_scope(template_group_name): - if isinstance(template_items,dict): + if isinstance(template_items, dict): for k in template_items.keys(): - await Prompt.create_async(template_items[k],k) + await Prompt.create_async(template_items[k], k) print(f"注册{template_items[k]},{k}") else: - template_group_name=None + template_group_name = None async def preprocess(): if global_config.enable_pfc_chatting: diff --git a/src/plugins/chat/utils.py b/src/plugins/chat/utils.py index b7986ae3e..b07c33c39 100644 --- a/src/plugins/chat/utils.py +++ b/src/plugins/chat/utils.py @@ -87,7 +87,6 @@ async def get_embedding(text, request_type="embedding"): return embedding - async def get_recent_group_messages(chat_id: str, limit: int = 12) -> list: """从数据库获取群组最近的消息记录 diff --git a/src/plugins/chat_module/reasoning_chat/reasoning_generator.py b/src/plugins/chat_module/reasoning_chat/reasoning_generator.py index 8b81ca4b2..83abe71cf 100644 --- a/src/plugins/chat_module/reasoning_chat/reasoning_generator.py +++ b/src/plugins/chat_module/reasoning_chat/reasoning_generator.py @@ -38,7 +38,7 @@ class ResponseGenerator: self.current_model_type = "r1" # 默认使用 R1 self.current_model_name = "unknown model" - async def generate_response(self, message: MessageThinking,thinking_id:str) -> Optional[Union[str, List[str]]]: + async def generate_response(self, message: MessageThinking, thinking_id: str) -> Optional[Union[str, List[str]]]: """根据当前模型类型选择对应的生成函数""" # 从global_config中获取模型概率值并选择模型 if random.random() < global_config.MODEL_R1_PROBABILITY: @@ -52,7 +52,7 @@ class ResponseGenerator: f"{self.current_model_type}思考:{message.processed_plain_text[:30] + '...' if len(message.processed_plain_text) > 30 else message.processed_plain_text}" ) # noqa: E501 - model_response = await self._generate_response_with_model(message, current_model,thinking_id) + model_response = await self._generate_response_with_model(message, current_model, thinking_id) # print(f"raw_content: {model_response}") @@ -65,11 +65,11 @@ class ResponseGenerator: logger.info(f"{self.current_model_type}思考,失败") return None - async def _generate_response_with_model(self, message: MessageThinking, model: LLM_request,thinking_id:str): + async def _generate_response_with_model(self, message: MessageThinking, model: LLM_request, thinking_id: str): sender_name = "" - + info_catcher = info_catcher_manager.get_info_catcher(thinking_id) - + if message.chat_stream.user_info.user_cardname and message.chat_stream.user_info.user_nickname: sender_name = ( f"[({message.chat_stream.user_info.user_id}){message.chat_stream.user_info.user_nickname}]" @@ -94,14 +94,11 @@ class ResponseGenerator: try: content, reasoning_content, self.current_model_name = await model.generate_response(prompt) - + info_catcher.catch_after_llm_generated( - prompt=prompt, - response=content, - reasoning_content=reasoning_content, - model_name=self.current_model_name) - - + prompt=prompt, response=content, reasoning_content=reasoning_content, model_name=self.current_model_name + ) + except Exception: logger.exception("生成回复时出错") return None @@ -118,7 +115,6 @@ class ResponseGenerator: return content - # def _save_to_db( # self, # message: MessageRecv, diff --git a/src/plugins/chat_module/reasoning_chat/reasoning_prompt_builder.py b/src/plugins/chat_module/reasoning_chat/reasoning_prompt_builder.py index 2ab34db11..15f6424c1 100644 --- a/src/plugins/chat_module/reasoning_chat/reasoning_prompt_builder.py +++ b/src/plugins/chat_module/reasoning_chat/reasoning_prompt_builder.py @@ -144,12 +144,10 @@ class PromptBuilder: for pattern in rule.get("regex", []): result = pattern.search(message_txt) if result: - reaction = rule.get('reaction', '') + reaction = rule.get("reaction", "") for name, content in result.groupdict().items(): - reaction = reaction.replace(f'[{name}]', content) - logger.info( - f"匹配到以下正则表达式:{pattern},触发反应:{reaction}" - ) + reaction = reaction.replace(f"[{name}]", content) + logger.info(f"匹配到以下正则表达式:{pattern},触发反应:{reaction}") keywords_reaction_prompt += reaction + "," break diff --git a/src/plugins/chat_module/think_flow_chat/think_flow_chat.py b/src/plugins/chat_module/think_flow_chat/think_flow_chat.py index 964aca55c..34c9860f0 100644 --- a/src/plugins/chat_module/think_flow_chat/think_flow_chat.py +++ b/src/plugins/chat_module/think_flow_chat/think_flow_chat.py @@ -59,11 +59,7 @@ class ThinkFlowChat: return thinking_id - async def _send_response_messages(self, - message, - chat, - response_set:List[str], - thinking_id) -> MessageSending: + async def _send_response_messages(self, message, chat, response_set: List[str], thinking_id) -> MessageSending: """发送回复消息""" container = message_manager.get_container(chat.stream_id) thinking_message = None @@ -260,8 +256,6 @@ class ThinkFlowChat: if random() < reply_probability: try: do_reply = True - - # 回复前处理 await willing_manager.before_generate_reply_handle(message.message_info.message_id) @@ -274,9 +268,9 @@ class ThinkFlowChat: timing_results["创建思考消息"] = timer2 - timer1 except Exception as e: logger.error(f"心流创建思考消息失败: {e}") - + logger.debug(f"创建捕捉器,thinking_id:{thinking_id}") - + info_catcher = info_catcher_manager.get_info_catcher(thinking_id) info_catcher.catch_decide_to_response(message) @@ -288,32 +282,32 @@ class ThinkFlowChat: timing_results["观察"] = timer2 - timer1 except Exception as e: logger.error(f"心流观察失败: {e}") - + info_catcher.catch_after_observe(timing_results["观察"]) # 思考前脑内状态 try: timer1 = time.time() - current_mind,past_mind = await heartflow.get_subheartflow(chat.stream_id).do_thinking_before_reply( - message_txt = message.processed_plain_text, - sender_name = message.message_info.user_info.user_nickname, - chat_stream = chat + current_mind, past_mind = await heartflow.get_subheartflow(chat.stream_id).do_thinking_before_reply( + message_txt=message.processed_plain_text, + sender_name=message.message_info.user_info.user_nickname, + chat_stream=chat, ) timer2 = time.time() timing_results["思考前脑内状态"] = timer2 - timer1 except Exception as e: logger.error(f"心流思考前脑内状态失败: {e}") - - info_catcher.catch_afer_shf_step(timing_results["思考前脑内状态"],past_mind,current_mind) + + info_catcher.catch_afer_shf_step(timing_results["思考前脑内状态"], past_mind, current_mind) # 生成回复 timer1 = time.time() - response_set = await self.gpt.generate_response(message,thinking_id) + response_set = await self.gpt.generate_response(message, thinking_id) timer2 = time.time() timing_results["生成回复"] = timer2 - timer1 info_catcher.catch_after_generate_response(timing_results["生成回复"]) - + if not response_set: logger.info("回复生成失败,返回为空") return @@ -326,11 +320,9 @@ class ThinkFlowChat: timing_results["发送消息"] = timer2 - timer1 except Exception as e: logger.error(f"心流发送消息失败: {e}") - - - info_catcher.catch_after_response(timing_results["发送消息"],response_set,first_bot_msg) - - + + info_catcher.catch_after_response(timing_results["发送消息"], response_set, first_bot_msg) + info_catcher.done_catch() # 处理表情包 diff --git a/src/plugins/chat_module/think_flow_chat/think_flow_generator.py b/src/plugins/chat_module/think_flow_chat/think_flow_generator.py index 164e8ab7c..df55ad80b 100644 --- a/src/plugins/chat_module/think_flow_chat/think_flow_generator.py +++ b/src/plugins/chat_module/think_flow_chat/think_flow_generator.py @@ -35,44 +35,51 @@ class ResponseGenerator: self.current_model_type = "r1" # 默认使用 R1 self.current_model_name = "unknown model" - async def generate_response(self, message: MessageRecv,thinking_id:str) -> Optional[List[str]]: + async def generate_response(self, message: MessageRecv, thinking_id: str) -> Optional[List[str]]: """根据当前模型类型选择对应的生成函数""" - logger.info( f"思考:{message.processed_plain_text[:30] + '...' if len(message.processed_plain_text) > 30 else message.processed_plain_text}" ) - + arousal_multiplier = MoodManager.get_instance().get_arousal_multiplier() - + time1 = time.time() - + checked = False if random.random() > 0: checked = False current_model = self.model_normal - current_model.temperature = 0.3 * arousal_multiplier #激活度越高,温度越高 - model_response = await self._generate_response_with_model(message, current_model,thinking_id,mode="normal") - + current_model.temperature = 0.3 * arousal_multiplier # 激活度越高,温度越高 + model_response = await self._generate_response_with_model( + message, current_model, thinking_id, mode="normal" + ) + model_checked_response = model_response else: checked = True current_model = self.model_normal - current_model.temperature = 0.3 * arousal_multiplier #激活度越高,温度越高 + current_model.temperature = 0.3 * arousal_multiplier # 激活度越高,温度越高 print(f"生成{message.processed_plain_text}回复温度是:{current_model.temperature}") - model_response = await self._generate_response_with_model(message, current_model,thinking_id,mode="simple") - + model_response = await self._generate_response_with_model( + message, current_model, thinking_id, mode="simple" + ) + current_model.temperature = 0.3 - model_checked_response = await self._check_response_with_model(message, model_response, current_model,thinking_id) + model_checked_response = await self._check_response_with_model( + message, model_response, current_model, thinking_id + ) time2 = time.time() if model_response: if checked: - logger.info(f"{global_config.BOT_NICKNAME}的回复是:{model_response},思忖后,回复是:{model_checked_response},生成回复时间: {time2 - time1}秒") + logger.info( + f"{global_config.BOT_NICKNAME}的回复是:{model_response},思忖后,回复是:{model_checked_response},生成回复时间: {time2 - time1}秒" + ) else: logger.info(f"{global_config.BOT_NICKNAME}的回复是:{model_response},生成回复时间: {time2 - time1}秒") - + model_processed_response = await self._process_response(model_checked_response) return model_processed_response @@ -80,11 +87,13 @@ class ResponseGenerator: logger.info(f"{self.current_model_type}思考,失败") return None - async def _generate_response_with_model(self, message: MessageRecv, model: LLM_request,thinking_id:str,mode:str = "normal") -> str: + async def _generate_response_with_model( + self, message: MessageRecv, model: LLM_request, thinking_id: str, mode: str = "normal" + ) -> str: sender_name = "" - + info_catcher = info_catcher_manager.get_info_catcher(thinking_id) - + if message.chat_stream.user_info.user_cardname and message.chat_stream.user_info.user_nickname: sender_name = ( f"[({message.chat_stream.user_info.user_id}){message.chat_stream.user_info.user_nickname}]" @@ -116,25 +125,22 @@ class ResponseGenerator: try: content, reasoning_content, self.current_model_name = await model.generate_response(prompt) - - + info_catcher.catch_after_llm_generated( - prompt=prompt, - response=content, - reasoning_content=reasoning_content, - model_name=self.current_model_name) - + prompt=prompt, response=content, reasoning_content=reasoning_content, model_name=self.current_model_name + ) + except Exception: logger.exception("生成回复时出错") return None - return content - - async def _check_response_with_model(self, message: MessageRecv, content:str, model: LLM_request,thinking_id:str) -> str: - + + async def _check_response_with_model( + self, message: MessageRecv, content: str, model: LLM_request, thinking_id: str + ) -> str: _info_catcher = info_catcher_manager.get_info_catcher(thinking_id) - + sender_name = "" if message.chat_stream.user_info.user_cardname and message.chat_stream.user_info.user_nickname: sender_name = ( @@ -145,8 +151,7 @@ class ResponseGenerator: sender_name = f"({message.chat_stream.user_info.user_id}){message.chat_stream.user_info.user_nickname}" else: sender_name = f"用户({message.chat_stream.user_info.user_id})" - - + # 构建prompt timer1 = time.time() prompt = await prompt_builder._build_prompt_check_response( @@ -154,7 +159,7 @@ class ResponseGenerator: message_txt=message.processed_plain_text, sender_name=sender_name, stream_id=message.chat_stream.stream_id, - content=content + content=content, ) timer2 = time.time() logger.info(f"构建check_prompt: {prompt}") @@ -162,19 +167,17 @@ class ResponseGenerator: try: checked_content, reasoning_content, self.current_model_name = await model.generate_response(prompt) - - + # info_catcher.catch_after_llm_generated( # prompt=prompt, # response=content, # reasoning_content=reasoning_content, # model_name=self.current_model_name) - + except Exception: logger.exception("检查回复时出错") return None - return checked_content async def _get_emotion_tags(self, content: str, processed_plain_text: str): diff --git a/src/plugins/chat_module/think_flow_chat/think_flow_prompt_builder.py b/src/plugins/chat_module/think_flow_chat/think_flow_prompt_builder.py index 6ebbd43ae..cfc419738 100644 --- a/src/plugins/chat_module/think_flow_chat/think_flow_prompt_builder.py +++ b/src/plugins/chat_module/think_flow_chat/think_flow_prompt_builder.py @@ -110,12 +110,10 @@ class PromptBuilder: for pattern in rule.get("regex", []): result = pattern.search(message_txt) if result: - reaction = rule.get('reaction', '') + reaction = rule.get("reaction", "") for name, content in result.groupdict().items(): - reaction = reaction.replace(f'[{name}]', content) - logger.info( - f"匹配到以下正则表达式:{pattern},触发反应:{reaction}" - ) + reaction = reaction.replace(f"[{name}]", content) + logger.info(f"匹配到以下正则表达式:{pattern},触发反应:{reaction}") keywords_reaction_prompt += reaction + "," break diff --git a/src/plugins/memory_system/Hippocampus.py b/src/plugins/memory_system/Hippocampus.py index 0a738b312..4e52afeca 100644 --- a/src/plugins/memory_system/Hippocampus.py +++ b/src/plugins/memory_system/Hippocampus.py @@ -225,6 +225,7 @@ class Memory_graph: return None + # 海马体 class Hippocampus: def __init__(self): @@ -653,7 +654,6 @@ class Hippocampus: return activation_ratio - # 负责海马体与其他部分的交互 class EntorhinalCortex: def __init__(self, hippocampus: Hippocampus): diff --git a/src/plugins/memory_system/debug_memory.py b/src/plugins/memory_system/debug_memory.py index 3b98d0d42..eff5d7d0d 100644 --- a/src/plugins/memory_system/debug_memory.py +++ b/src/plugins/memory_system/debug_memory.py @@ -27,7 +27,6 @@ async def test_memory_system(): # 测试记忆检索 test_text = "千石可乐在群里聊天" - # test_text = '''千石可乐:分不清AI的陪伴和人类的陪伴,是这样吗?''' print(f"开始测试记忆检索,测试文本: {test_text}\n") memories = await hippocampus_manager.get_memory_from_text( diff --git a/src/plugins/models/utils_model.py b/src/plugins/models/utils_model.py index 1066453ff..a472b5bf7 100644 --- a/src/plugins/models/utils_model.py +++ b/src/plugins/models/utils_model.py @@ -574,7 +574,7 @@ class LLM_request: reasoning_content = message.get("reasoning_content", "") if not reasoning_content: reasoning_content = reasoning - + # 提取工具调用信息 tool_calls = message.get("tool_calls", None) @@ -592,7 +592,7 @@ class LLM_request: request_type=request_type if request_type is not None else self.request_type, endpoint=endpoint, ) - + # 只有当tool_calls存在且不为空时才返回 if tool_calls: return content, reasoning_content, tool_calls @@ -657,9 +657,7 @@ class LLM_request: **kwargs, } - response = await self._execute_request( - endpoint="/chat/completions", payload=data, prompt=prompt - ) + response = await self._execute_request(endpoint="/chat/completions", payload=data, prompt=prompt) # 原样返回响应,不做处理 return response diff --git a/src/plugins/moods/moods.py b/src/plugins/moods/moods.py index 9ce0fd93b..9aae3c7c2 100644 --- a/src/plugins/moods/moods.py +++ b/src/plugins/moods/moods.py @@ -238,14 +238,14 @@ class MoodManager: base_prompt += "情绪比较平静。" return base_prompt - + def get_arousal_multiplier(self) -> float: """根据当前情绪状态返回唤醒度乘数""" if self.current_mood.arousal > 0.4: - multiplier = 1 + min(0.15,(self.current_mood.arousal - 0.4)/3) + multiplier = 1 + min(0.15, (self.current_mood.arousal - 0.4) / 3) return multiplier elif self.current_mood.arousal < -0.4: - multiplier = 1 - min(0.15,((0 - self.current_mood.arousal) - 0.4)/3) + multiplier = 1 - min(0.15, ((0 - self.current_mood.arousal) - 0.4) / 3) return multiplier return 1.0 diff --git a/src/plugins/respon_info_catcher/info_catcher.py b/src/plugins/respon_info_catcher/info_catcher.py index 4e9943b8c..3fe5ab645 100644 --- a/src/plugins/respon_info_catcher/info_catcher.py +++ b/src/plugins/respon_info_catcher/info_catcher.py @@ -1,28 +1,29 @@ from src.plugins.config.config import global_config -from src.plugins.chat.message import MessageRecv,MessageSending,Message +from src.plugins.chat.message import MessageRecv, MessageSending, Message from src.common.database import db import time import traceback from typing import List + class InfoCatcher: def __init__(self): - self.chat_history = [] # 聊天历史,长度为三倍使用的上下文 + self.chat_history = [] # 聊天历史,长度为三倍使用的上下文 self.context_length = global_config.MAX_CONTEXT_SIZE - self.chat_history_in_thinking = [] # 思考期间的聊天内容 - self.chat_history_after_response = [] # 回复后的聊天内容,长度为一倍上下文 - + self.chat_history_in_thinking = [] # 思考期间的聊天内容 + self.chat_history_after_response = [] # 回复后的聊天内容,长度为一倍上下文 + self.chat_id = "" self.response_mode = global_config.response_mode self.trigger_response_text = "" self.response_text = "" - + self.trigger_response_time = 0 self.trigger_response_message = None - + self.response_time = 0 self.response_messages = [] - + # 使用字典来存储 heartflow 模式的数据 self.heartflow_data = { "heart_flow_prompt": "", @@ -32,17 +33,12 @@ class InfoCatcher: "sub_heartflow_model": "", "prompt": "", "response": "", - "model": "" + "model": "", } - + # 使用字典来存储 reasoning 模式的数据 - self.reasoning_data = { - "thinking_log": "", - "prompt": "", - "response": "", - "model": "" - } - + self.reasoning_data = {"thinking_log": "", "prompt": "", "response": "", "model": ""} + # 耗时 self.timing_results = { "interested_rate_time": 0, @@ -50,24 +46,24 @@ class InfoCatcher: "sub_heartflow_step_time": 0, "make_response_time": 0, } - - def catch_decide_to_response(self,message:MessageRecv): + + def catch_decide_to_response(self, message: MessageRecv): # 搜集决定回复时的信息 self.trigger_response_message = message self.trigger_response_text = message.detailed_plain_text - + self.trigger_response_time = time.time() - + self.chat_id = message.chat_stream.stream_id - + self.chat_history = self.get_message_from_db_before_msg(message) - - def catch_after_observe(self,obs_duration:float):#这里可以有更多信息 + + def catch_after_observe(self, obs_duration: float): # 这里可以有更多信息 self.timing_results["sub_heartflow_observe_time"] = obs_duration # def catch_shf - - def catch_afer_shf_step(self,step_duration:float,past_mind:str,current_mind:str): + + def catch_afer_shf_step(self, step_duration: float, past_mind: str, current_mind: str): self.timing_results["sub_heartflow_step_time"] = step_duration if len(past_mind) > 1: self.heartflow_data["sub_heartflow_before"] = past_mind[-1] @@ -75,11 +71,8 @@ class InfoCatcher: else: self.heartflow_data["sub_heartflow_before"] = past_mind[-1] self.heartflow_data["sub_heartflow_now"] = current_mind - - def catch_after_llm_generated(self,prompt:str, - response:str, - reasoning_content:str = "", - model_name:str = ""): + + def catch_after_llm_generated(self, prompt: str, response: str, reasoning_content: str = "", model_name: str = ""): if self.response_mode == "heart_flow": self.heartflow_data["prompt"] = prompt self.heartflow_data["response"] = response @@ -89,41 +82,38 @@ class InfoCatcher: self.reasoning_data["prompt"] = prompt self.reasoning_data["response"] = response self.reasoning_data["model"] = model_name - + self.response_text = response - - def catch_after_generate_response(self,response_duration:float): + + def catch_after_generate_response(self, response_duration: float): self.timing_results["make_response_time"] = response_duration - - - - def catch_after_response(self,response_duration:float, - response_message:List[str], - first_bot_msg:MessageSending): + + def catch_after_response( + self, response_duration: float, response_message: List[str], first_bot_msg: MessageSending + ): self.timing_results["make_response_time"] = response_duration self.response_time = time.time() for msg in response_message: self.response_messages.append(msg) - - self.chat_history_in_thinking = self.get_message_from_db_between_msgs(self.trigger_response_message,first_bot_msg) - + + self.chat_history_in_thinking = self.get_message_from_db_between_msgs( + self.trigger_response_message, first_bot_msg + ) + def get_message_from_db_between_msgs(self, message_start: Message, message_end: Message): try: # 从数据库中获取消息的时间戳 time_start = message_start.message_info.time time_end = message_end.message_info.time chat_id = message_start.chat_stream.stream_id - + print(f"查询参数: time_start={time_start}, time_end={time_end}, chat_id={chat_id}") - + # 查询数据库,获取 chat_id 相同且时间在 start 和 end 之间的数据 messages_between = db.messages.find( - { - "chat_id": chat_id, - "time": {"$gt": time_start, "$lt": time_end} - } + {"chat_id": chat_id, "time": {"$gt": time_start, "$lt": time_end}} ).sort("time", -1) - + result = list(messages_between) print(f"查询结果数量: {len(result)}") if result: @@ -133,21 +123,23 @@ class InfoCatcher: except Exception as e: print(f"获取消息时出错: {str(e)}") return [] - + def get_message_from_db_before_msg(self, message: MessageRecv): # 从数据库中获取消息 message_id = message.message_info.message_id chat_id = message.chat_stream.stream_id - + # 查询数据库,获取 chat_id 相同且 message_id 小于当前消息的 30 条数据 - messages_before = db.messages.find( - {"chat_id": chat_id, "message_id": {"$lt": message_id}} - ).sort("time", -1).limit(self.context_length*3) #获取更多历史信息 - + messages_before = ( + db.messages.find({"chat_id": chat_id, "message_id": {"$lt": message_id}}) + .sort("time", -1) + .limit(self.context_length * 3) + ) # 获取更多历史信息 + return list(messages_before) - + def message_list_to_dict(self, message_list): - #存储简化的聊天记录 + # 存储简化的聊天记录 result = [] for message in message_list: if not isinstance(message, dict): @@ -160,7 +152,7 @@ class InfoCatcher: "processed_plain_text": message["processed_plain_text"], } result.append(lite_message) - + return result def message_to_dict(self, message): @@ -176,12 +168,12 @@ class InfoCatcher: "processed_plain_text": message.processed_plain_text, # "detailed_plain_text": message.detailed_plain_text } - + def done_catch(self): """将收集到的信息存储到数据库的 thinking_log 集合中""" try: # 将消息对象转换为可序列化的字典 - + thinking_log_data = { "chat_id": self.chat_id, "response_mode": self.response_mode, @@ -198,7 +190,7 @@ class InfoCatcher: "timing_results": self.timing_results, "chat_history": self.message_list_to_dict(self.chat_history), "chat_history_in_thinking": self.message_list_to_dict(self.chat_history_in_thinking), - "chat_history_after_response": self.message_list_to_dict(self.chat_history_after_response) + "chat_history_after_response": self.message_list_to_dict(self.chat_history_after_response), } # 根据不同的响应模式添加相应的数据 @@ -209,20 +201,22 @@ class InfoCatcher: # 将数据插入到 thinking_log 集合中 db.thinking_log.insert_one(thinking_log_data) - + return True except Exception as e: print(f"存储思考日志时出错: {str(e)}") print(traceback.format_exc()) return False + class InfoCatcherManager: def __init__(self): self.info_catchers = {} - def get_info_catcher(self,thinking_id:str) -> InfoCatcher: + def get_info_catcher(self, thinking_id: str) -> InfoCatcher: if thinking_id not in self.info_catchers: self.info_catchers[thinking_id] = InfoCatcher() return self.info_catchers[thinking_id] -info_catcher_manager = InfoCatcherManager() \ No newline at end of file + +info_catcher_manager = InfoCatcherManager() diff --git a/src/plugins/schedule/schedule_generator.py b/src/plugins/schedule/schedule_generator.py index c1b5fdec6..f75065cf8 100644 --- a/src/plugins/schedule/schedule_generator.py +++ b/src/plugins/schedule/schedule_generator.py @@ -32,7 +32,7 @@ class ScheduleGenerator: # 使用离线LLM模型 self.llm_scheduler_all = LLM_request( model=global_config.llm_reasoning, - temperature=global_config.SCHEDULE_TEMPERATURE+0.3, + temperature=global_config.SCHEDULE_TEMPERATURE + 0.3, max_tokens=7000, request_type="schedule", ) diff --git a/src/plugins/storage/storage.py b/src/plugins/storage/storage.py index d07b02719..577b40340 100644 --- a/src/plugins/storage/storage.py +++ b/src/plugins/storage/storage.py @@ -8,6 +8,7 @@ from src.common.logger import get_module_logger logger = get_module_logger("message_storage") + class MessageStorage: async def store_message(self, message: Union[MessageSending, MessageRecv], chat_stream: ChatStream) -> None: """存储消息到数据库""" diff --git a/src/plugins/utils/prompt_builder.py b/src/plugins/utils/prompt_builder.py index 60c6a70c2..abf5fe392 100644 --- a/src/plugins/utils/prompt_builder.py +++ b/src/plugins/utils/prompt_builder.py @@ -5,6 +5,7 @@ from typing import Dict, Any, Optional, List, Union from contextlib import asynccontextmanager import asyncio + class PromptContext: def __init__(self): self._context_prompts: Dict[str, Dict[str, "Prompt"]] = {} @@ -129,7 +130,9 @@ class Prompt(str): return obj @classmethod - async def create_async(cls, fstr: str, name: Optional[str] = None, args: Union[List[Any], tuple[Any, ...]] = None, **kwargs): + async def create_async( + cls, fstr: str, name: Optional[str] = None, args: Union[List[Any], tuple[Any, ...]] = None, **kwargs + ): """异步创建Prompt实例""" prompt = cls(fstr, name, args, **kwargs) if global_prompt_manager._context._current_context: diff --git a/src/plugins/willing/mode_classical.py b/src/plugins/willing/mode_classical.py index 74f24350f..294539d08 100644 --- a/src/plugins/willing/mode_classical.py +++ b/src/plugins/willing/mode_classical.py @@ -1,6 +1,7 @@ import asyncio from .willing_manager import BaseWillingManager + class ClassicalWillingManager(BaseWillingManager): def __init__(self): super().__init__() @@ -41,17 +42,22 @@ class ClassicalWillingManager(BaseWillingManager): self.chat_reply_willing[chat_id] = min(current_willing, 3.0) - reply_probability = min(max((current_willing - 0.5), 0.01) * self.global_config.response_willing_amplifier * 2, 1) + reply_probability = min( + max((current_willing - 0.5), 0.01) * self.global_config.response_willing_amplifier * 2, 1 + ) # 检查群组权限(如果是群聊) - if willing_info.group_info and willing_info.group_info.group_id in self.global_config.talk_frequency_down_groups: + if ( + willing_info.group_info + and willing_info.group_info.group_id in self.global_config.talk_frequency_down_groups + ): reply_probability = reply_probability / self.global_config.down_frequency_rate if is_emoji_not_reply: reply_probability = 0 return reply_probability - + async def before_generate_reply_handle(self, message_id): chat_id = self.ongoing_messages[message_id].chat_id current_willing = self.chat_reply_willing.get(chat_id, 0) @@ -71,8 +77,6 @@ class ClassicalWillingManager(BaseWillingManager): async def get_variable_parameters(self): return await super().get_variable_parameters() - + async def set_variable_parameters(self, parameters): return await super().set_variable_parameters(parameters) - - diff --git a/src/plugins/willing/mode_custom.py b/src/plugins/willing/mode_custom.py index 786c779b4..c3a5c3078 100644 --- a/src/plugins/willing/mode_custom.py +++ b/src/plugins/willing/mode_custom.py @@ -4,4 +4,3 @@ from .willing_manager import BaseWillingManager class CustomWillingManager(BaseWillingManager): def __init__(self): super().__init__() - diff --git a/src/plugins/willing/mode_dynamic.py b/src/plugins/willing/mode_dynamic.py index 523c05244..0487a1a98 100644 --- a/src/plugins/willing/mode_dynamic.py +++ b/src/plugins/willing/mode_dynamic.py @@ -20,7 +20,6 @@ class DynamicWillingManager(BaseWillingManager): self._decay_task = None self._mode_switch_task = None - async def async_task_starter(self): if self._decay_task is None: self._decay_task = asyncio.create_task(self._decay_reply_willing()) @@ -84,7 +83,9 @@ class DynamicWillingManager(BaseWillingManager): self.chat_high_willing_mode[chat_id] = True self.chat_reply_willing[chat_id] = 1.0 # 设置为较高回复意愿 self.chat_high_willing_duration[chat_id] = random.randint(180, 240) # 3-4分钟 - self.logger.debug(f"聊天流 {chat_id} 切换到高回复意愿期,持续 {self.chat_high_willing_duration[chat_id]} 秒") + self.logger.debug( + f"聊天流 {chat_id} 切换到高回复意愿期,持续 {self.chat_high_willing_duration[chat_id]} 秒" + ) self.chat_last_mode_change[chat_id] = time.time() self.chat_msg_count[chat_id] = 0 # 重置消息计数 @@ -148,7 +149,9 @@ class DynamicWillingManager(BaseWillingManager): # 根据话题兴趣度适当调整 if willing_info.interested_rate > 0.5: - current_willing += (willing_info.interested_rate - 0.5) * 0.5 * self.global_config.response_interested_rate_amplifier + current_willing += ( + (willing_info.interested_rate - 0.5) * 0.5 * self.global_config.response_interested_rate_amplifier + ) # 根据当前模式计算回复概率 base_probability = 0.0 @@ -228,12 +231,12 @@ class DynamicWillingManager(BaseWillingManager): async def bombing_buffer_message_handle(self, message_id): return await super().bombing_buffer_message_handle(message_id) - + async def after_generate_reply_handle(self, message_id): return await super().after_generate_reply_handle(message_id) async def get_variable_parameters(self): return await super().get_variable_parameters() - + async def set_variable_parameters(self, parameters): - return await super().set_variable_parameters(parameters) \ No newline at end of file + return await super().set_variable_parameters(parameters) diff --git a/src/plugins/willing/mode_mxp.py b/src/plugins/willing/mode_mxp.py index 25627d707..b4fc1448c 100644 --- a/src/plugins/willing/mode_mxp.py +++ b/src/plugins/willing/mode_mxp.py @@ -17,19 +17,22 @@ Mxp 模式:梦溪畔独家赞助 中策是发issue 下下策是询问一个菜鸟(@梦溪畔) """ + from .willing_manager import BaseWillingManager from typing import Dict import asyncio import time import math + class MxpWillingManager(BaseWillingManager): """Mxp意愿管理器""" + def __init__(self): super().__init__() self.chat_person_reply_willing: Dict[str, Dict[str, float]] = {} # chat_id: {person_id: 意愿值} self.chat_new_message_time: Dict[str, list[float]] = {} # 聊天流ID: 消息时间 - self.last_response_person: Dict[str, tuple[str, int]] = {} # 上次回复的用户信息 + self.last_response_person: Dict[str, tuple[str, int]] = {} # 上次回复的用户信息 self.temporary_willing: float = 0 # 临时意愿值 # 可变参数 @@ -39,8 +42,8 @@ class MxpWillingManager(BaseWillingManager): self.basic_maximum_willing = 0.5 # 基础最大意愿值 self.mention_willing_gain = 0.6 # 提及意愿增益 self.interest_willing_gain = 0.3 # 兴趣意愿增益 - self.emoji_response_penalty = self.global_config.emoji_response_penalty # 表情包回复惩罚 - self.down_frequency_rate = self.global_config.down_frequency_rate # 降低回复频率的群组惩罚系数 + self.emoji_response_penalty = self.global_config.emoji_response_penalty # 表情包回复惩罚 + self.down_frequency_rate = self.global_config.down_frequency_rate # 降低回复频率的群组惩罚系数 self.single_chat_gain = 0.12 # 单聊增益 async def async_task_starter(self) -> None: @@ -73,9 +76,13 @@ class MxpWillingManager(BaseWillingManager): w_info = self.ongoing_messages[message_id] if w_info.is_mentioned_bot: self.chat_person_reply_willing[w_info.chat_id][w_info.person_id] += 0.2 - if w_info.chat_id in self.last_response_person and self.last_response_person[w_info.chat_id][0] == w_info.person_id: - self.chat_person_reply_willing[w_info.chat_id][w_info.person_id] +=\ - self.single_chat_gain * (2 * self.last_response_person[w_info.chat_id][1] + 1) + if ( + w_info.chat_id in self.last_response_person + and self.last_response_person[w_info.chat_id][0] == w_info.person_id + ): + self.chat_person_reply_willing[w_info.chat_id][w_info.person_id] += self.single_chat_gain * ( + 2 * self.last_response_person[w_info.chat_id][1] + 1 + ) now_chat_new_person = self.last_response_person.get(w_info.chat_id, ["", 0]) if now_chat_new_person[0] != w_info.person_id: self.last_response_person[w_info.chat_id] = [w_info.person_id, 0] @@ -98,7 +105,10 @@ class MxpWillingManager(BaseWillingManager): rel_level = self._get_relationship_level_num(rel_value) current_willing += rel_level * 0.1 - if w_info.chat_id in self.last_response_person and self.last_response_person[w_info.chat_id][0] == w_info.person_id: + if ( + w_info.chat_id in self.last_response_person + and self.last_response_person[w_info.chat_id][0] == w_info.person_id + ): current_willing += self.single_chat_gain * (2 * self.last_response_person[w_info.chat_id][1] + 1) chat_ongoing_messages = [msg for msg in self.ongoing_messages.values() if msg.chat_id == w_info.chat_id] @@ -141,16 +151,22 @@ class MxpWillingManager(BaseWillingManager): self.logger.debug(f"聊天流{chat_id}不存在,错误") continue basic_willing = self.chat_reply_willing[chat_id] - person_willing[person_id] = basic_willing + (willing - basic_willing) * self.intention_decay_rate + person_willing[person_id] = ( + basic_willing + (willing - basic_willing) * self.intention_decay_rate + ) def setup(self, message, chat, is_mentioned_bot, interested_rate): super().setup(message, chat, is_mentioned_bot, interested_rate) - self.chat_reply_willing[chat.stream_id] = self.chat_reply_willing.get(chat.stream_id, self.basic_maximum_willing) + self.chat_reply_willing[chat.stream_id] = self.chat_reply_willing.get( + chat.stream_id, self.basic_maximum_willing + ) self.chat_person_reply_willing[chat.stream_id] = self.chat_person_reply_willing.get(chat.stream_id, {}) - self.chat_person_reply_willing[chat.stream_id][self.ongoing_messages[message.message_info.message_id].person_id] = \ - self.chat_person_reply_willing[chat.stream_id].get(self.ongoing_messages[message.message_info.message_id].person_id, - self.chat_reply_willing[chat.stream_id]) + self.chat_person_reply_willing[chat.stream_id][ + self.ongoing_messages[message.message_info.message_id].person_id + ] = self.chat_person_reply_willing[chat.stream_id].get( + self.ongoing_messages[message.message_info.message_id].person_id, self.chat_reply_willing[chat.stream_id] + ) if chat.stream_id not in self.chat_new_message_time: self.chat_new_message_time[chat.stream_id] = [] @@ -166,7 +182,7 @@ class MxpWillingManager(BaseWillingManager): else: probability = math.atan(willing * 4) / math.pi * 2 return probability - + async def _chat_new_message_to_change_basic_willing(self): """聊天流新消息改变基础意愿""" while True: @@ -174,10 +190,11 @@ class MxpWillingManager(BaseWillingManager): await asyncio.sleep(update_time) async with self.lock: for chat_id, message_times in self.chat_new_message_time.items(): - # 清理过期消息 current_time = time.time() - message_times = [msg_time for msg_time in message_times if current_time - msg_time < self.message_expiration_time] + message_times = [ + msg_time for msg_time in message_times if current_time - msg_time < self.message_expiration_time + ] self.chat_new_message_time[chat_id] = message_times if len(message_times) < self.number_of_message_storage: @@ -185,7 +202,9 @@ class MxpWillingManager(BaseWillingManager): update_time = 20 elif len(message_times) == self.number_of_message_storage: time_interval = current_time - message_times[0] - basic_willing = self.basic_maximum_willing * math.sqrt(time_interval / self.message_expiration_time) + basic_willing = self.basic_maximum_willing * math.sqrt( + time_interval / self.message_expiration_time + ) self.chat_reply_willing[chat_id] = basic_willing update_time = 17 * math.sqrt(time_interval / self.message_expiration_time) + 3 else: @@ -203,7 +222,7 @@ class MxpWillingManager(BaseWillingManager): "interest_willing_gain": "兴趣意愿增益", "emoji_response_penalty": "表情包回复惩罚", "down_frequency_rate": "降低回复频率的群组惩罚系数", - "single_chat_gain": "单聊增益(不仅是私聊)" + "single_chat_gain": "单聊增益(不仅是私聊)", } async def set_variable_parameters(self, parameters: Dict[str, any]): @@ -215,7 +234,7 @@ class MxpWillingManager(BaseWillingManager): self.logger.debug(f"参数 {key} 已更新为 {value}") else: self.logger.debug(f"尝试设置未知参数 {key}") - + def _get_relationship_level_num(self, relationship_value) -> int: """关系等级计算""" if -1000 <= relationship_value < -227: @@ -235,4 +254,4 @@ class MxpWillingManager(BaseWillingManager): return level_num - 2 async def get_willing(self, chat_id): - return self.temporary_willing \ No newline at end of file + return self.temporary_willing diff --git a/src/plugins/willing/willing_manager.py b/src/plugins/willing/willing_manager.py index 07e02a29b..ada995120 100644 --- a/src/plugins/willing/willing_manager.py +++ b/src/plugins/willing/willing_manager.py @@ -1,4 +1,3 @@ - from src.common.logger import LogConfig, WILLING_STYLE_CONFIG, LoguruLogger, get_module_logger from dataclasses import dataclass from ..config.config import global_config, BotConfig @@ -38,10 +37,11 @@ willing_config = LogConfig( ) logger = get_module_logger("willing", config=willing_config) + @dataclass class WillingInfo: """此类保存意愿模块常用的参数 - + Attributes: message (MessageRecv): 原始消息对象 chat (ChatStream): 聊天流对象 @@ -53,6 +53,7 @@ class WillingInfo: is_emoji (bool): 是否为表情包 interested_rate (float): 兴趣度 """ + message: MessageRecv chat: ChatStream person_info_manager: PersonInfoManager @@ -60,22 +61,21 @@ class WillingInfo: person_id: str group_info: Optional[GroupInfo] is_mentioned_bot: bool - is_emoji: bool + is_emoji: bool interested_rate: float # current_mood: float 当前心情? + class BaseWillingManager(ABC): """回复意愿管理基类""" - + @classmethod - def create(cls, manager_type: str) -> 'BaseWillingManager': + def create(cls, manager_type: str) -> "BaseWillingManager": try: module = importlib.import_module(f".mode_{manager_type}", __package__) manager_class = getattr(module, f"{manager_type.capitalize()}WillingManager") if not issubclass(manager_class, cls): - raise TypeError( - f"Manager class {manager_class.__name__} is not a subclass of {cls.__name__}" - ) + raise TypeError(f"Manager class {manager_class.__name__} is not a subclass of {cls.__name__}") else: logger.info(f"成功载入willing模式:{manager_type}") return manager_class() @@ -85,7 +85,7 @@ class BaseWillingManager(ABC): logger.info(f"载入当前意愿模式{manager_type}失败,使用经典配方~~~~") logger.debug(f"加载willing模式{manager_type}失败,原因: {str(e)}。") return manager_class() - + def __init__(self): self.chat_reply_willing: Dict[str, float] = {} # 存储每个聊天流的回复意愿(chat_id) self.ongoing_messages: Dict[str, WillingInfo] = {} # 当前正在进行的消息(message_id) @@ -136,17 +136,17 @@ class BaseWillingManager(ABC): async def get_reply_probability(self, message_id: str): """抽象方法:获取回复概率""" raise NotImplementedError - + @abstractmethod async def bombing_buffer_message_handle(self, message_id: str): """抽象方法:炸飞消息处理""" pass - + async def get_willing(self, chat_id: str): """获取指定聊天流的回复意愿""" async with self.lock: return self.chat_reply_willing.get(chat_id, 0) - + async def set_willing(self, chat_id: str, willing: float): """设置指定聊天流的回复意愿""" async with self.lock: @@ -173,5 +173,6 @@ def init_willing_manager() -> BaseWillingManager: mode = global_config.willing_mode.lower() return BaseWillingManager.create(mode) + # 全局willing_manager对象 willing_manager = init_willing_manager()