From 307d5a73a69f83b854b7bdc87433c2e50a36df9d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Wed, 14 May 2025 19:42:53 +0800 Subject: [PATCH 01/57] =?UTF-8?q?feat:=20=E9=87=8D=E6=9E=84=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=BA=93=E5=AF=BC=E5=85=A5=E8=B7=AF=E5=BE=84=EF=BC=8C?= =?UTF-8?q?=E7=A7=BB=E9=99=A4=E6=97=A7=E7=9A=84=E6=95=B0=E6=8D=AE=E5=BA=93?= =?UTF-8?q?=E6=A8=A1=E5=9D=97=E5=B9=B6=E6=B7=BB=E5=8A=A0=E6=96=B0=E7=9A=84?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E5=BA=93=E5=AE=9E=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/emoji_system/emoji_manager.py | 2 +- src/chat/focus_chat/heartflow_prompt_builder.py | 2 +- src/chat/memory_system/Hippocampus.py | 2 +- src/chat/memory_system/manually_alter_memory.py | 2 +- src/chat/message_receive/chat_stream.py | 2 +- src/chat/message_receive/storage.py | 2 +- src/chat/models/utils_model.py | 2 +- src/chat/person_info/person_info.py | 2 +- src/chat/utils/info_catcher.py | 2 +- src/chat/utils/statistic.py | 2 +- src/chat/utils/utils.py | 2 +- src/chat/utils/utils_image.py | 2 +- src/chat/zhishi/knowledge_library.py | 2 +- src/common/{ => database}/database.py | 0 src/common/database/database_model.py | 2 ++ src/common/message_repository.py | 2 +- src/experimental/PFC/message_storage.py | 2 +- src/tools/tool_can_use/get_knowledge.py | 2 +- 18 files changed, 18 insertions(+), 16 deletions(-) rename src/common/{ => database}/database.py (100%) create mode 100644 src/common/database/database_model.py diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py index 5d800866f..076dbf5a4 100644 --- a/src/chat/emoji_system/emoji_manager.py +++ b/src/chat/emoji_system/emoji_manager.py @@ -10,7 +10,7 @@ from PIL import Image import io import re -from ...common.database import db +from ...common.database.database import db from ...config.config import global_config from ..utils.utils_image import image_path_to_base64, image_manager from ..models.utils_model import LLMRequest diff --git a/src/chat/focus_chat/heartflow_prompt_builder.py b/src/chat/focus_chat/heartflow_prompt_builder.py index 55fb79b46..d8babe2e5 100644 --- a/src/chat/focus_chat/heartflow_prompt_builder.py +++ b/src/chat/focus_chat/heartflow_prompt_builder.py @@ -7,7 +7,7 @@ from src.chat.person_info.relationship_manager import relationship_manager from src.chat.utils.utils import get_embedding import time from typing import Union, Optional, Dict, Any -from src.common.database import db +from common.database.database import db from src.chat.utils.utils import get_recent_group_speaker from src.manager.mood_manager import mood_manager from src.chat.memory_system.Hippocampus import HippocampusManager diff --git a/src/chat/memory_system/Hippocampus.py b/src/chat/memory_system/Hippocampus.py index 70eb679c9..e64475126 100644 --- a/src/chat/memory_system/Hippocampus.py +++ b/src/chat/memory_system/Hippocampus.py @@ -10,7 +10,7 @@ import jieba import networkx as nx import numpy as np from collections import Counter -from ...common.database import db +from ...common.database.database import db from ...chat.models.utils_model import LLMRequest from src.common.logger_manager import get_logger from src.chat.memory_system.sample_distribution import MemoryBuildScheduler # 分布生成器 diff --git a/src/chat/memory_system/manually_alter_memory.py b/src/chat/memory_system/manually_alter_memory.py index ce5abbba7..9bbf59f5b 100644 --- a/src/chat/memory_system/manually_alter_memory.py +++ b/src/chat/memory_system/manually_alter_memory.py @@ -34,7 +34,7 @@ root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")) sys.path.append(root_path) from src.common.logger import get_module_logger # noqa E402 -from src.common.database import db # noqa E402 +from common.database.database import db # noqa E402 logger = get_module_logger("mem_alter") console = Console() diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index 53ebd5026..7f41ac96b 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -5,7 +5,7 @@ import copy from typing import Dict, Optional -from ...common.database import db +from ...common.database.database import db from maim_message import GroupInfo, UserInfo from src.common.logger_manager import get_logger diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py index cae029a11..eb6ea73df 100644 --- a/src/chat/message_receive/storage.py +++ b/src/chat/message_receive/storage.py @@ -1,7 +1,7 @@ import re from typing import Union -from ...common.database import db +from ...common.database.database import db from .message import MessageSending, MessageRecv from .chat_stream import ChatStream from src.common.logger import get_module_logger diff --git a/src/chat/models/utils_model.py b/src/chat/models/utils_model.py index e662a8e33..9ca4e56d0 100644 --- a/src/chat/models/utils_model.py +++ b/src/chat/models/utils_model.py @@ -12,7 +12,7 @@ import base64 from PIL import Image import io import os -from ...common.database import db +from ...common.database.database import db from ...config.config import global_config from rich.traceback import install diff --git a/src/chat/person_info/person_info.py b/src/chat/person_info/person_info.py index 605b86b23..00cbe86f1 100644 --- a/src/chat/person_info/person_info.py +++ b/src/chat/person_info/person_info.py @@ -1,5 +1,5 @@ from src.common.logger_manager import get_logger -from ...common.database import db +from ...common.database.database import db import copy import hashlib from typing import Any, Callable, Dict diff --git a/src/chat/utils/info_catcher.py b/src/chat/utils/info_catcher.py index 174bb5b49..b7f59c661 100644 --- a/src/chat/utils/info_catcher.py +++ b/src/chat/utils/info_catcher.py @@ -1,6 +1,6 @@ from src.config.config import global_config from src.chat.message_receive.message import MessageRecv, MessageSending, Message -from src.common.database import db +from common.database.database import db import time import traceback from typing import List diff --git a/src/chat/utils/statistic.py b/src/chat/utils/statistic.py index 3f9832926..4bcf6fea0 100644 --- a/src/chat/utils/statistic.py +++ b/src/chat/utils/statistic.py @@ -5,7 +5,7 @@ from typing import Any, Dict, Tuple, List from src.common.logger import get_module_logger from src.manager.async_task_manager import AsyncTask -from ...common.database import db +from ...common.database.database import db from src.manager.local_store_manager import local_storage logger = get_module_logger("maibot_statistic") diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py index 8fe8334b8..f78a0c114 100644 --- a/src/chat/utils/utils.py +++ b/src/chat/utils/utils.py @@ -13,7 +13,7 @@ from src.manager.mood_manager import mood_manager from ..message_receive.message import MessageRecv from ..models.utils_model import LLMRequest from .typo_generator import ChineseTypoGenerator -from ...common.database import db +from ...common.database.database import db from ...config.config import global_config logger = get_module_logger("chat_utils") diff --git a/src/chat/utils/utils_image.py b/src/chat/utils/utils_image.py index 455038246..6fbafc905 100644 --- a/src/chat/utils/utils_image.py +++ b/src/chat/utils/utils_image.py @@ -8,7 +8,7 @@ import io import numpy as np -from ...common.database import db +from ...common.database.database import db from ...config.config import global_config from ..models.utils_model import LLMRequest diff --git a/src/chat/zhishi/knowledge_library.py b/src/chat/zhishi/knowledge_library.py index 6fa1d3e1a..0068a153c 100644 --- a/src/chat/zhishi/knowledge_library.py +++ b/src/chat/zhishi/knowledge_library.py @@ -16,7 +16,7 @@ root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")) sys.path.append(root_path) # 现在可以导入src模块 -from src.common.database import db # noqa E402 +from common.database.database import db # noqa E402 # 加载根目录下的env.edv文件 diff --git a/src/common/database.py b/src/common/database/database.py similarity index 100% rename from src/common/database.py rename to src/common/database/database.py diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py new file mode 100644 index 000000000..45cecfed6 --- /dev/null +++ b/src/common/database/database_model.py @@ -0,0 +1,2 @@ +from peewee import * + diff --git a/src/common/message_repository.py b/src/common/message_repository.py index 03f192cea..03eaba332 100644 --- a/src/common/message_repository.py +++ b/src/common/message_repository.py @@ -1,4 +1,4 @@ -from src.common.database import db +from common.database.database import db from src.common.logger import get_module_logger import traceback from typing import List, Any, Optional diff --git a/src/experimental/PFC/message_storage.py b/src/experimental/PFC/message_storage.py index cd6a01e34..24866e38c 100644 --- a/src/experimental/PFC/message_storage.py +++ b/src/experimental/PFC/message_storage.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from typing import List, Dict, Any -from src.common.database import db +from common.database.database import db class MessageStorage(ABC): diff --git a/src/tools/tool_can_use/get_knowledge.py b/src/tools/tool_can_use/get_knowledge.py index 65acd55c0..2a4922f9f 100644 --- a/src/tools/tool_can_use/get_knowledge.py +++ b/src/tools/tool_can_use/get_knowledge.py @@ -1,6 +1,6 @@ from src.tools.tool_can_use.base_tool import BaseTool from src.chat.utils.utils import get_embedding -from src.common.database import db +from common.database.database import db from src.common.logger_manager import get_logger from typing import Any, Union From 88ab2bcaf49add0a52d468be452f251a135ca9ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Wed, 14 May 2025 20:16:46 +0800 Subject: [PATCH 02/57] =?UTF-8?q?feat:=20=E9=87=8D=E6=9E=84=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=BA=93=E6=A8=A1=E5=9E=8B=EF=BC=8C=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?=E5=9F=BA=E7=A1=80=E6=A8=A1=E5=9E=8B=E5=92=8C=E5=A4=9A=E4=B8=AA?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E8=A1=A8=E5=AE=9A=E4=B9=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/common/database/database_model.py | 145 +++++++++++++++++++++++++- 1 file changed, 144 insertions(+), 1 deletion(-) diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py index 45cecfed6..bb00abbaa 100644 --- a/src/common/database/database_model.py +++ b/src/common/database/database_model.py @@ -1,2 +1,145 @@ -from peewee import * +from peewee import Model, DoubleField, IntegerField, SqliteDatabase, BooleanField, TextField, FloatField + +# 请在此处定义您的数据库实例。 +# 您需要取消注释并配置适合您的数据库的部分。 +# 例如,对于 SQLite: +db = SqliteDatabase('my_application.db') +# +# 对于 PostgreSQL: +# db = PostgresqlDatabase('your_db_name', user='your_user', password='your_password', +# host='localhost', port=5432) +# +# 对于 MySQL: +# db = MySQLDatabase('your_db_name', user='your_user', password='your_password', +# host='localhost', port=3306) + +# 定义一个基础模型是一个好习惯,所有其他模型都应继承自它。 +# 这允许您在一个地方为所有模型指定数据库。 +class BaseModel(Model): + class Meta: + # 将下面的 'db' 替换为您实际的数据库实例变量名。 + database = db # 例如: database = my_actual_db_instance + pass # 在用户定义数据库实例之前,此处为占位符 + +class ChatStreams(BaseModel): + """ + 用于存储流式记录数据的模型,类似于提供的 MongoDB 结构。 + """ + # stream_id: "a544edeb1a9b73e3e1d77dff36e41264" + # 假设 stream_id 是唯一的,并为其创建索引以提高查询性能。 + stream_id = TextField(unique=True, index=True) + + # create_time: 1746096761.4490178 (时间戳,精确到小数点后7位) + # DoubleField 用于存储浮点数,适合此类时间戳。 + create_time = DoubleField() + + # group_info 字段: + # platform: "qq" + # group_id: "941657197" + # group_name: "测试" + group_platform = TextField() + group_id = TextField() + group_name = TextField() + + # last_active_time: 1746623771.4825106 (时间戳,精确到小数点后7位) + last_active_time = DoubleField() + + # platform: "qq" (顶层平台字段) + platform = TextField() + + # user_info 字段: + # platform: "qq" + # user_id: "1787882683" + # user_nickname: "墨梓柒(IceSakurary)" + # user_cardname: "" + user_platform = TextField() + user_id = TextField() + user_nickname = TextField() + # user_cardname 可能为空字符串或不存在,设置 null=True 更具灵活性。 + user_cardname = TextField(null=True) + + class Meta: + # 如果 BaseModel.Meta.database 已设置,则此模型将继承该数据库配置。 + # 如果不使用带有数据库实例的 BaseModel,或者想覆盖它, + # 请取消注释并在下面设置数据库实例: + # database = db + table_name = 'chat_streams' # 可选:明确指定数据库中的表名 + +class LLMUsage(BaseModel): + """ + 用于存储 API 使用日志数据的模型。 + """ + model_name = TextField() + user_id = TextField() + request_type = TextField() + endpoint = TextField() + prompt_tokens = IntegerField() + completion_tokens = IntegerField() + total_tokens = IntegerField() + cost = DoubleField() + status = TextField() + # timestamp: "$date": "2025-05-01T18:52:50.870Z" (存储为字符串) + timestamp = TextField() + + class Meta: + # 如果 BaseModel.Meta.database 已设置,则此模型将继承该数据库配置。 + # database = db + table_name = 'llm_usage' + +class Emoji(BaseModel): + """表情包""" + + full_path = TextField(unique=True, index=True) # 文件的完整路径 (包括文件名) + format = TextField() # 图片格式 + hash = TextField(index=True) # 表情包的哈希值 + description = TextField() # 表情包的描述 + query_count = IntegerField(default=0) # 查询次数(用于统计表情包被查询描述的次数) + is_registered = BooleanField(default=False) # 是否已注册 + is_banned = BooleanField(default=False) # 是否被禁止注册 + # emotion: list[str] # 表情包的情感标签 - 存储为文本,应用层处理序列化/反序列化 + emotion = TextField(null=True) + record_time = FloatField() # 记录时间(被创建的时间) + register_time = FloatField(null=True) # 注册时间(被注册为可用表情包的时间) + usage_count = IntegerField(default=0) # 使用次数(被使用的次数) + last_used_time = FloatField(null=True) # 上次使用时间 + + class Meta: + # database = db # 继承自 BaseModel + table_name = 'emoji' + +class Messages(BaseModel): + """ + 用于存储消息数据的模型。 + """ + message_id = IntegerField(index=True) # 消息 ID + time = DoubleField() # 消息时间戳 + + chat_id = TextField(index=True) # 对应的 ChatStreams stream_id + + # 从 chat_info 扁平化而来的字段 + chat_info_stream_id = TextField() + chat_info_platform = TextField() + chat_info_user_platform = TextField() + chat_info_user_id = TextField() + chat_info_user_nickname = TextField() + chat_info_user_cardname = TextField(null=True) + chat_info_group_platform = TextField(null=True) # 群聊信息可能不存在 + chat_info_group_id = TextField(null=True) + chat_info_group_name = TextField(null=True) + chat_info_create_time = DoubleField() + chat_info_last_active_time = DoubleField() + + # 从顶层 user_info 扁平化而来的字段 (消息发送者信息) + user_platform = TextField() + user_id = TextField() + user_nickname = TextField() + user_cardname = TextField(null=True) + + processed_plain_text = TextField(null=True) # 处理后的纯文本消息 + detailed_plain_text = TextField(null=True) # 详细的纯文本消息 + memorized_times = IntegerField(default=0) # 被记忆的次数 + + class Meta: + # database = db # 继承自 BaseModel + table_name = 'messages' From b66534120f416bf166df3febf1aef5ac40666a6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Wed, 14 May 2025 20:24:45 +0800 Subject: [PATCH 03/57] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E5=9B=BE?= =?UTF-8?q?=E5=83=8F=E5=92=8C=E5=9C=A8=E7=BA=BF=E6=97=B6=E9=95=BF=E8=AE=B0?= =?UTF-8?q?=E5=BD=95=E6=A8=A1=E5=9E=8B=EF=BC=8C=E6=89=A9=E5=B1=95=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=BA=93=E7=BB=93=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/common/database/database_model.py | 39 +++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py index bb00abbaa..b3bf3f629 100644 --- a/src/common/database/database_model.py +++ b/src/common/database/database_model.py @@ -143,3 +143,42 @@ class Messages(BaseModel): # database = db # 继承自 BaseModel table_name = 'messages' +class Images(BaseModel): + """ + 用于存储图像信息的模型。 + """ + hash = TextField(index=True) # 图像的哈希值 + description = TextField(null=True) # 图像的描述 + path = TextField(unique=True) # 图像文件的路径 + timestamp = FloatField() # 时间戳 + type = TextField() # 图像类型,例如 "emoji" + + class Meta: + # database = db # 继承自 BaseModel + table_name = 'images' + +class ImageDescriptions(BaseModel): + """ + 用于存储图像描述信息的模型。 + """ + type = TextField() # 类型,例如 "emoji" + hash = TextField(index=True) # 图像的哈希值 + description = TextField() # 图像的描述 + timestamp = FloatField() # 时间戳 + + class Meta: + # database = db # 继承自 BaseModel + table_name = 'image_descriptions' + +class OnlineTime(BaseModel): + """ + 用于存储在线时长记录的模型。 + """ + # timestamp: "$date": "2025-05-01T18:52:18.191Z" (存储为字符串) + timestamp = TextField() + duration = IntegerField() # 时长,单位分钟 + + class Meta: + # database = db # 继承自 BaseModel + table_name = 'online_time' + From df897a0f4220d384d0508363d49c844e981f134f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Wed, 14 May 2025 20:40:57 +0800 Subject: [PATCH 04/57] =?UTF-8?q?feat:=20=E6=9B=B4=E6=96=B0=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=BA=93=E6=A8=A1=E5=9E=8B=EF=BC=8C=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?=E4=B8=AA=E4=BA=BA=E4=BF=A1=E6=81=AF=E5=AD=98=E5=82=A8=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../focus_chat/heartflow_prompt_builder.py | 4 ++-- src/common/database/database_model.py | 20 +++++++++++++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/src/chat/focus_chat/heartflow_prompt_builder.py b/src/chat/focus_chat/heartflow_prompt_builder.py index d8babe2e5..1acef540e 100644 --- a/src/chat/focus_chat/heartflow_prompt_builder.py +++ b/src/chat/focus_chat/heartflow_prompt_builder.py @@ -6,14 +6,14 @@ from src.chat.utils.chat_message_builder import build_readable_messages, get_raw from src.chat.person_info.relationship_manager import relationship_manager from src.chat.utils.utils import get_embedding import time -from typing import Union, Optional, Dict, Any +from typing import Union, Optional from common.database.database import db from src.chat.utils.utils import get_recent_group_speaker from src.manager.mood_manager import mood_manager from src.chat.memory_system.Hippocampus import HippocampusManager from src.chat.knowledge.knowledge_lib import qa_manager from src.chat.focus_chat.expressors.exprssion_learner import expression_learner -import traceback +# import traceback import random diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py index b3bf3f629..c1135a33d 100644 --- a/src/common/database/database_model.py +++ b/src/common/database/database_model.py @@ -182,3 +182,23 @@ class OnlineTime(BaseModel): # database = db # 继承自 BaseModel table_name = 'online_time' +class PersonInfo(BaseModel): + """ + 用于存储个人信息数据的模型。 + """ + person_id = TextField(unique=True, index=True) # 个人唯一ID + person_name = TextField() # 个人名称 + name_reason = TextField(null=True) # 名称设定的原因 + platform = TextField() # 平台 + user_id = TextField(index=True) # 用户ID + nickname = TextField() # 用户昵称 + relationship_value = IntegerField(default=0) # 关系值 + konw_time = FloatField() # 认识时间 (时间戳) + msg_interval = IntegerField() # 消息间隔 + # msg_interval_list: 存储为 JSON 字符串的列表 + msg_interval_list = TextField(null=True) + + class Meta: + # database = db # 继承自 BaseModel + table_name = 'person_info' + From b84cc9240aab39e9ff5862f3ac8e773dd6b7b703 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Wed, 14 May 2025 22:53:21 +0800 Subject: [PATCH 05/57] =?UTF-8?q?=E9=87=8D=E6=9E=84=E6=95=B0=E6=8D=AE?= =?UTF-8?q?=E5=BA=93=E4=BA=A4=E4=BA=92=E4=BB=A5=E4=BD=BF=E7=94=A8=20Peewee?= =?UTF-8?q?=20ORM?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 更新数据库连接和模型定义,以便使用 Peewee for SQLite。 - 在消息存储和检索功能中,用 Peewee ORM 查询替换 MongoDB 查询。 - 为 Messages、ThinkingLog 和 OnlineTime 引入了新的模型,以方便结构化数据存储。 - 增强了数据库操作的错误处理和日志记录。 - 删除了过时的 MongoDB 集合管理代码。 - 通过利用 Peewee 内置的查询和数据操作方法来提升性能。 --- src/chat/emoji_system/emoji_manager.py | 188 ++++--- .../focus_chat/heartflow_prompt_builder.py | 157 +++--- src/chat/memory_system/Hippocampus.py | 2 +- src/chat/message_receive/chat_stream.py | 126 ++++- src/chat/models/utils_model.py | 41 +- src/chat/person_info/person_info.py | 531 ++++++++++-------- src/chat/utils/info_catcher.py | 160 +++--- src/chat/utils/statistic.py | 66 +-- src/chat/utils/utils_image.py | 175 +++--- src/common/database/database.py | 14 +- src/common/database/database_model.py | 54 +- src/common/message_repository.py | 92 ++- src/experimental/PFC/chat_observer.py | 4 +- src/experimental/PFC/message_storage.py | 37 +- src/tools/tool_can_use/get_knowledge.py | 110 ++-- 15 files changed, 999 insertions(+), 758 deletions(-) diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py index 076dbf5a4..68fa5de44 100644 --- a/src/chat/emoji_system/emoji_manager.py +++ b/src/chat/emoji_system/emoji_manager.py @@ -10,7 +10,10 @@ from PIL import Image import io import re -from ...common.database.database import db +# from gradio_client import file + +from ...common.database.database_model import Emoji +from ...common.database.database import db as peewee_db from ...config.config import global_config from ..utils.utils_image import image_path_to_base64, image_manager from ..models.utils_model import LLMRequest @@ -143,37 +146,28 @@ class MaiEmoji: # --- 数据库操作 --- try: # 准备数据库记录 for emoji collection - emoji_record = { - "filename": self.filename, - "path": self.path, # 存储目录路径 - "full_path": self.full_path, # 存储完整文件路径 - "embedding": self.embedding, - "description": self.description, - "emotion": self.emotion, - "hash": self.hash, - "format": self.format, - "timestamp": int(self.register_time), - "usage_count": self.usage_count, - "last_used_time": self.last_used_time, - } - - # 使用upsert确保记录存在或被更新 - db["emoji"].update_one({"hash": self.hash}, {"$set": emoji_record}, upsert=True) + emotion_str = ",".join(self.emotion) if self.emotion else "" + Emoji.create(hash=self.hash, + full_path=self.full_path, + format=self.format, + description=self.description, + emotion=emotion_str, # Store as comma-separated string + query_count=0, # Default value + is_registered=True, + is_banned=False, # Default value + record_time=self.register_time, # Use MaiEmoji's register_time for DB record_time + register_time=self.register_time, + usage_count=self.usage_count, + last_used_time=self.last_used_time, + ) + logger.success(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})") return True except Exception as db_error: logger.error(f"[错误] 保存数据库失败 ({self.filename}): {str(db_error)}") - # 数据库保存失败,是否需要将文件移回?为了简化,暂时只记录错误 - # 可以考虑在这里尝试删除已移动的文件,避免残留 - try: - if os.path.exists(self.full_path): # full_path 此时是目标路径 - os.remove(self.full_path) - logger.warning(f"[回滚] 已删除移动失败后残留的文件: {self.full_path}") - except Exception as remove_error: - logger.error(f"[错误] 回滚删除文件失败: {remove_error}") return False except Exception as e: @@ -201,10 +195,14 @@ class MaiEmoji: # 文件删除失败,但仍然尝试删除数据库记录 # 2. 删除数据库记录 - result = db.emoji.delete_one({"hash": self.hash}) - deleted_in_db = result.deleted_count > 0 + try: + will_delete_emoji = Emoji.get(Emoji.hash == self.hash) + result = will_delete_emoji.delete_instance() # Returns the number of rows deleted. + except Emoji.DoesNotExist: + logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。") + result = 0 # Indicate no DB record was deleted - if deleted_in_db: + if result > 0: logger.info(f"[删除] 表情包数据库记录 {self.filename} (Hash: {self.hash})") # 3. 标记对象已被删除 self.is_deleted = True @@ -246,44 +244,43 @@ def _emoji_objects_to_readable_list(emoji_objects): def _to_emoji_objects(data): emoji_objects = [] load_errors = 0 - emoji_data_list = list(data) + # data is now an iterable of Peewee Emoji model instances + emoji_data_list = list(data) - for emoji_data in emoji_data_list: - full_path = emoji_data.get("full_path") + for emoji_data in emoji_data_list: # emoji_data is an Emoji model instance + full_path = emoji_data.full_path if not full_path: - logger.warning(f"[加载错误] 数据库记录缺少 'full_path' 字段: {emoji_data.get('_id')}") + logger.warning(f"[加载错误] 数据库记录缺少 'full_path' 字段: ID {emoji_data.id if hasattr(emoji_data, 'id') else 'Unknown'}") load_errors += 1 - continue # 跳过缺少 full_path 的记录 + continue try: - # 使用 full_path 初始化 MaiEmoji 对象 emoji = MaiEmoji(full_path=full_path) - # 设置从数据库加载的属性 - emoji.hash = emoji_data.get("hash", "") - # 如果 hash 为空,也跳过?取决于业务逻辑 + emoji.hash = emoji_data.hash if not emoji.hash: logger.warning(f"[加载错误] 数据库记录缺少 'hash' 字段: {full_path}") load_errors += 1 continue - emoji.description = emoji_data.get("description", "") - emoji.emotion = emoji_data.get("emotion", []) - emoji.usage_count = emoji_data.get("usage_count", 0) - # 优先使用 last_used_time,否则用 timestamp,最后用当前时间 - last_used = emoji_data.get("last_used_time") - timestamp = emoji_data.get("timestamp") - emoji.last_used_time = ( - last_used if last_used is not None else (timestamp if timestamp is not None else time.time()) - ) - emoji.register_time = timestamp if timestamp is not None else time.time() - emoji.format = emoji_data.get("format", "") # 加载格式 + emoji.description = emoji_data.description + # Deserialize emotion string from DB to list + emoji.emotion = emoji_data.emotion.split(',') if emoji_data.emotion else [] + emoji.usage_count = emoji_data.usage_count + + db_last_used_time = emoji_data.last_used_time + db_register_time = emoji_data.register_time - # 不需要再手动设置 path 和 filename,__init__ 会自动处理 + # If last_used_time from DB is None, use MaiEmoji's initialized register_time or current time + emoji.last_used_time = db_last_used_time if db_last_used_time is not None else emoji.register_time + # If register_time from DB is None, use MaiEmoji's initialized register_time (which is time.time()) + emoji.register_time = db_register_time if db_register_time is not None else emoji.register_time + + emoji.format = emoji_data.format emoji_objects.append(emoji) - except ValueError as ve: # 捕获 __init__ 可能的错误 + except ValueError as ve: logger.error(f"[加载错误] 初始化 MaiEmoji 失败 ({full_path}): {ve}") load_errors += 1 except Exception as e: @@ -385,12 +382,13 @@ class EmojiManager: """初始化数据库连接和表情目录""" if not self._initialized: try: - self._ensure_emoji_collection() + # Ensure Peewee database connection is up and tables are created + if not peewee_db.is_closed(): + peewee_db.connect(reuse_if_open=True) + Emoji.create_table(safe=True) # Ensures table exists + _ensure_emoji_dir() self._initialized = True - # 更新表情包数量 - # 启动时执行一次完整性检查 - # await self.check_emoji_file_integrity() except Exception as e: logger.exception(f"初始化表情管理器失败: {e}") @@ -401,33 +399,15 @@ class EmojiManager: if not self._initialized: raise RuntimeError("EmojiManager not initialized") - @staticmethod - def _ensure_emoji_collection(): - """确保emoji集合存在并创建索引 - - 这个函数用于确保MongoDB数据库中存在emoji集合,并创建必要的索引。 - - 索引的作用是加快数据库查询速度: - - embedding字段的2dsphere索引: 用于加速向量相似度搜索,帮助快速找到相似的表情包 - - tags字段的普通索引: 加快按标签搜索表情包的速度 - - filename字段的唯一索引: 确保文件名不重复,同时加快按文件名查找的速度 - - 没有索引的话,数据库每次查询都需要扫描全部数据,建立索引后可以大大提高查询效率。 - """ - if "emoji" not in db.list_collection_names(): - db.create_collection("emoji") - db.emoji.create_index([("embedding", "2dsphere")]) - db.emoji.create_index([("filename", 1)], unique=True) - def record_usage(self, emoji_hash: str): """记录表情使用次数""" try: - db.emoji.update_one({"hash": emoji_hash}, {"$inc": {"usage_count": 1}}) - for emoji in self.emoji_objects: - if emoji.hash == emoji_hash: - emoji.usage_count += 1 - break - + emoji_update = Emoji.get(Emoji.hash == emoji_hash) + emoji_update.usage_count += 1 + emoji_update.last_used_time = time.time() # Update last used time + emoji_update.save() # Persist changes to DB + except Emoji.DoesNotExist: + logger.error(f"记录表情使用失败: 未找到 hash 为 {emoji_hash} 的表情包") except Exception as e: logger.error(f"记录表情使用失败: {str(e)}") @@ -657,9 +637,10 @@ class EmojiManager: """获取所有表情包并初始化为MaiEmoji类对象,更新 self.emoji_objects""" try: self._ensure_db() - logger.info("[数据库] 开始加载所有表情包记录...") + logger.info("[数据库] 开始加载所有表情包记录 (Peewee)...") - emoji_objects, load_errors = _to_emoji_objects(db.emoji.find()) + emoji_peewee_instances = Emoji.select() + emoji_objects, load_errors = _to_emoji_objects(emoji_peewee_instances) # 更新内存中的列表和数量 self.emoji_objects = emoji_objects @@ -686,15 +667,16 @@ class EmojiManager: try: self._ensure_db() - query = {} if emoji_hash: - query = {"hash": emoji_hash} + query = Emoji.select().where(Emoji.hash == emoji_hash) else: logger.warning( "[查询] 未提供 hash,将尝试加载所有表情包,建议使用 get_all_emoji_from_db 更新管理器状态。" ) - - emoji_objects, load_errors = _to_emoji_objects(db.emoji.find(query)) + query = Emoji.select() + + emoji_peewee_instances = query + emoji_objects, load_errors = _to_emoji_objects(emoji_peewee_instances) if load_errors > 0: logger.warning(f"[查询] 加载过程中出现 {load_errors} 个错误。") @@ -908,6 +890,44 @@ class EmojiManager: logger.error(f"获取表情包描述失败: {str(e)}") return "", [] + # async def register_emoji_by_filename(self, filename: str) -> bool: + # if global_config.EMOJI_CHECK: + # prompt = f''' + # 这是一个表情包,请对这个表情包进行审核,标准如下: + # 1. 必须符合"{global_config.EMOJI_CHECK_PROMPT}"的要求 + # 2. 不能是色情、暴力、等违法违规内容,必须符合公序良俗 + # 3. 不能是任何形式的截图,聊天记录或视频截图 + # 4. 不要出现5个以上文字 + # 请回答这个表情包是否满足上述要求,是则回答是,否则回答否,不要出现任何其他内容 + # ''' + # content, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format) + # if content == "否": + # return "", [] + + # # 分析情感含义 + # emotion_prompt = f""" + # 请你识别这个表情包的含义和适用场景,给我简短的描述,每个描述不要超过15个字 + # 这是一个基于这个表情包的描述:'{description}' + # 你可以关注其幽默和讽刺意味,动用贴吧,微博,小红书的知识,必须从互联网梗,meme的角度去分析 + # 请直接输出描述,不要出现任何其他内容,如果有多个描述,可以用逗号分隔 + # """ + # emotions_text, _ = await self.llm_emotion_judge.generate_response_async(emotion_prompt, temperature=0.7) + + # # 处理情感列表 + # emotions = [e.strip() for e in emotions_text.split(",") if e.strip()] + + # # 根据情感标签数量随机选择喵~超过5个选3个,超过2个选2个 + # if len(emotions) > 5: + # emotions = random.sample(emotions, 3) + # elif len(emotions) > 2: + # emotions = random.sample(emotions, 2) + + # return f"[表情包:{description}]", emotions + + # except Exception as e: + # logger.error(f"获取表情包描述失败: {str(e)}") + # return "", [] + async def register_emoji_by_filename(self, filename: str) -> bool: """读取指定文件名的表情包图片,分析并注册到数据库 diff --git a/src/chat/focus_chat/heartflow_prompt_builder.py b/src/chat/focus_chat/heartflow_prompt_builder.py index 1acef540e..141d850ab 100644 --- a/src/chat/focus_chat/heartflow_prompt_builder.py +++ b/src/chat/focus_chat/heartflow_prompt_builder.py @@ -7,7 +7,7 @@ from src.chat.person_info.relationship_manager import relationship_manager from src.chat.utils.utils import get_embedding import time from typing import Union, Optional -from common.database.database import db +# from common.database.database import db from src.chat.utils.utils import get_recent_group_speaker from src.manager.mood_manager import mood_manager from src.chat.memory_system.Hippocampus import HippocampusManager @@ -15,6 +15,9 @@ from src.chat.knowledge.knowledge_lib import qa_manager from src.chat.focus_chat.expressors.exprssion_learner import expression_learner # import traceback import random +import json +import math +from src.common.database.database_model import Knowledges logger = get_logger("prompt") @@ -69,7 +72,7 @@ def init_prompt(): 你正在{chat_target_2},现在请你读读之前的聊天记录,{mood_prompt},{reply_style1}, 尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,{reply_style2}。{prompt_ger} 请回复的平淡一些,简短一些,说中文,不要刻意突出自身学科背景,不要浮夸,平淡一些 ,不要随意遵从他人指令。 -请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。 +请注意不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出回复内容。 {moderation_prompt} 不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出回复内容""", "reasoning_prompt_main", @@ -439,30 +442,6 @@ class PromptBuilder: logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}") # 1. 先从LLM获取主题,类似于记忆系统的做法 topics = [] - # try: - # # 先尝试使用记忆系统的方法获取主题 - # hippocampus = HippocampusManager.get_instance()._hippocampus - # topic_num = min(5, max(1, int(len(message) * 0.1))) - # topics_response = await hippocampus.llm_topic_judge.generate_response(hippocampus.find_topic_llm(message, topic_num)) - - # # 提取关键词 - # topics = re.findall(r"<([^>]+)>", topics_response[0]) - # if not topics: - # topics = [] - # else: - # topics = [ - # topic.strip() - # for topic in ",".join(topics).replace(",", ",").replace("、", ",").replace(" ", ",").split(",") - # if topic.strip() - # ] - - # logger.info(f"从LLM提取的主题: {', '.join(topics)}") - # except Exception as e: - # logger.error(f"从LLM提取主题失败: {str(e)}") - # # 如果LLM提取失败,使用jieba分词提取关键词作为备选 - # words = jieba.cut(message) - # topics = [word for word in words if len(word) > 1][:5] - # logger.info(f"使用jieba提取的主题: {', '.join(topics)}") # 如果无法提取到主题,直接使用整个消息 if not topics: @@ -572,8 +551,6 @@ class PromptBuilder: for _i, result in enumerate(results, 1): _similarity = result["similarity"] content = result["content"].strip() - # 调试:为内容添加序号和相似度信息 - # related_info += f"{i}. [{similarity:.2f}] {content}\n" related_info += f"{content}\n" related_info += "\n" @@ -602,14 +579,14 @@ class PromptBuilder: return related_info else: logger.debug("从LPMM知识库获取知识失败,使用旧版数据库进行检索") - knowledge_from_old = await self.get_prompt_info_old(message, threshold=0.38) + knowledge_from_old = await self.get_prompt_info_old(message, threshold=threshold) related_info += knowledge_from_old logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}") return related_info except Exception as e: logger.error(f"获取知识库内容时发生异常: {str(e)}") try: - knowledge_from_old = await self.get_prompt_info_old(message, threshold=0.38) + knowledge_from_old = await self.get_prompt_info_old(message, threshold=threshold) related_info += knowledge_from_old logger.debug( f"异常后使用旧版数据库获取知识,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}" @@ -625,70 +602,70 @@ class PromptBuilder: ) -> Union[str, list]: if not query_embedding: return "" if not return_raw else [] - # 使用余弦相似度计算 - pipeline = [ - { - "$addFields": { - "dotProduct": { - "$reduce": { - "input": {"$range": [0, {"$size": "$embedding"}]}, - "initialValue": 0, - "in": { - "$add": [ - "$$value", - { - "$multiply": [ - {"$arrayElemAt": ["$embedding", "$$this"]}, - {"$arrayElemAt": [query_embedding, "$$this"]}, - ] - }, - ] - }, - } - }, - "magnitude1": { - "$sqrt": { - "$reduce": { - "input": "$embedding", - "initialValue": 0, - "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}, - } - } - }, - "magnitude2": { - "$sqrt": { - "$reduce": { - "input": query_embedding, - "initialValue": 0, - "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}, - } - } - }, - } - }, - {"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}}, - { - "$match": { - "similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果 - } - }, - {"$sort": {"similarity": -1}}, - {"$limit": limit}, - {"$project": {"content": 1, "similarity": 1}}, - ] - results = list(db.knowledges.aggregate(pipeline)) - logger.debug(f"知识库查询结果数量: {len(results)}") + results_with_similarity = [] + try: + # Fetch all knowledge entries + # This might be inefficient for very large databases. + # Consider strategies like FAISS or other vector search libraries if performance becomes an issue. + all_knowledges = Knowledges.select() - if not results: + if not all_knowledges: + return "" if not return_raw else [] + + query_embedding_magnitude = math.sqrt(sum(x * x for x in query_embedding)) + if query_embedding_magnitude == 0: # Avoid division by zero + return "" if not return_raw else [] + + for knowledge_item in all_knowledges: + try: + db_embedding_str = knowledge_item.embedding + db_embedding = json.loads(db_embedding_str) + + if len(db_embedding) != len(query_embedding): + logger.warning(f"Embedding length mismatch for knowledge ID {knowledge_item.id if hasattr(knowledge_item, 'id') else 'N/A'}. Skipping.") + continue + + # Calculate Cosine Similarity + dot_product = sum(q * d for q, d in zip(query_embedding, db_embedding)) + db_embedding_magnitude = math.sqrt(sum(x * x for x in db_embedding)) + + if db_embedding_magnitude == 0: # Avoid division by zero + similarity = 0.0 + else: + similarity = dot_product / (query_embedding_magnitude * db_embedding_magnitude) + + if similarity >= threshold: + results_with_similarity.append({ + "content": knowledge_item.content, + "similarity": similarity + }) + except json.JSONDecodeError: + logger.error(f"Failed to parse embedding for knowledge ID {knowledge_item.id if hasattr(knowledge_item, 'id') else 'N/A'}") + except Exception as e: + logger.error(f"Error processing knowledge item: {e}") + + + # Sort by similarity in descending order + results_with_similarity.sort(key=lambda x: x["similarity"], reverse=True) + + # Limit results + limited_results = results_with_similarity[:limit] + + logger.debug(f"知识库查询结果数量 (after Peewee processing): {len(limited_results)}") + + if not limited_results: + return "" if not return_raw else [] + + if return_raw: + return limited_results + else: + return "\n".join(str(result["content"]) for result in limited_results) + + except Exception as e: + logger.error(f"Error querying Knowledges with Peewee: {e}") return "" if not return_raw else [] - if return_raw: - return results - else: - # 返回所有找到的内容,用换行分隔 - return "\n".join(str(result["content"]) for result in results) - def weighted_sample_no_replacement(items, weights, k) -> list: """ diff --git a/src/chat/memory_system/Hippocampus.py b/src/chat/memory_system/Hippocampus.py index e64475126..78616d824 100644 --- a/src/chat/memory_system/Hippocampus.py +++ b/src/chat/memory_system/Hippocampus.py @@ -10,7 +10,7 @@ import jieba import networkx as nx import numpy as np from collections import Counter -from ...common.database.database import db +from ...common.database.database import memory_db as db from ...chat.models.utils_model import LLMRequest from src.common.logger_manager import get_logger from src.chat.memory_system.sample_distribution import MemoryBuildScheduler # 分布生成器 diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index 7f41ac96b..723d6da47 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -6,6 +6,7 @@ from typing import Dict, Optional from ...common.database.database import db +from ...common.database.database_model import ChatStreams # 新增导入 from maim_message import GroupInfo, UserInfo from src.common.logger_manager import get_logger @@ -82,7 +83,13 @@ class ChatManager: def __init__(self): if not self._initialized: self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream - self._ensure_collection() + try: + db.connect(reuse_if_open=True) + # 确保 ChatStreams 表存在 + db.create_tables([ChatStreams], safe=True) + except Exception as e: + logger.error(f"数据库连接或 ChatStreams 表创建失败: {e}") + self._initialized = True # 在事件循环中启动初始化 # asyncio.create_task(self._initialize()) @@ -107,15 +114,6 @@ class ChatManager: except Exception as e: logger.error(f"聊天流自动保存失败: {str(e)}") - @staticmethod - def _ensure_collection(): - """确保数据库集合存在并创建索引""" - if "chat_streams" not in db.list_collection_names(): - db.create_collection("chat_streams") - # 创建索引 - db.chat_streams.create_index([("stream_id", 1)], unique=True) - db.chat_streams.create_index([("platform", 1), ("user_info.user_id", 1), ("group_info.group_id", 1)]) - @staticmethod def _generate_stream_id(platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None) -> str: """生成聊天流唯一ID""" @@ -151,16 +149,43 @@ class ChatManager: stream = self.streams[stream_id] # 更新用户信息和群组信息 stream.update_active_time() - stream = copy.deepcopy(stream) + stream = copy.deepcopy(stream) # 返回副本以避免外部修改影响缓存 stream.user_info = user_info if group_info: stream.group_info = group_info return stream # 检查数据库中是否存在 - data = db.chat_streams.find_one({"stream_id": stream_id}) - if data: - stream = ChatStream.from_dict(data) + def _db_find_stream_sync(s_id: str): + return ChatStreams.get_or_none(ChatStreams.stream_id == s_id) + + model_instance = await asyncio.to_thread(_db_find_stream_sync, stream_id) + + if model_instance: + # 从 Peewee 模型转换回 ChatStream.from_dict 期望的格式 + user_info_data = { + "platform": model_instance.user_platform, + "user_id": model_instance.user_id, + "user_nickname": model_instance.user_nickname, + "user_cardname": model_instance.user_cardname or "", + } + group_info_data = None + if model_instance.group_id: # 假设 group_id 为空字符串表示没有群组信息 + group_info_data = { + "platform": model_instance.group_platform, + "group_id": model_instance.group_id, + "group_name": model_instance.group_name, + } + + data_for_from_dict = { + "stream_id": model_instance.stream_id, + "platform": model_instance.platform, + "user_info": user_info_data, + "group_info": group_info_data, + "create_time": model_instance.create_time, + "last_active_time": model_instance.last_active_time, + } + stream = ChatStream.from_dict(data_for_from_dict) # 更新用户信息和群组信息 stream.user_info = user_info if group_info: @@ -175,7 +200,7 @@ class ChatManager: group_info=group_info, ) except Exception as e: - logger.error(f"创建聊天流失败: {e}") + logger.error(f"获取或创建聊天流失败: {e}", exc_info=True) raise e # 保存到内存和数据库 @@ -205,15 +230,38 @@ class ChatManager: elif stream.user_info and stream.user_info.user_nickname: return f"{stream.user_info.user_nickname}的私聊" else: - # 如果没有群名或用户昵称,返回 None 或其他默认值 return None @staticmethod async def _save_stream(stream: ChatStream): """保存聊天流到数据库""" if not stream.saved: - db.chat_streams.update_one({"stream_id": stream.stream_id}, {"$set": stream.to_dict()}, upsert=True) - stream.saved = True + stream_data_dict = stream.to_dict() + + def _db_save_stream_sync(s_data_dict: dict): + user_info_d = s_data_dict.get("user_info") + group_info_d = s_data_dict.get("group_info") + + fields_to_save = { + "platform": s_data_dict["platform"], + "create_time": s_data_dict["create_time"], + "last_active_time": s_data_dict["last_active_time"], + "user_platform": user_info_d["platform"] if user_info_d else "", + "user_id": user_info_d["user_id"] if user_info_d else "", + "user_nickname": user_info_d["user_nickname"] if user_info_d else "", + "user_cardname": user_info_d.get("user_cardname", "") if user_info_d else None, + "group_platform": group_info_d["platform"] if group_info_d else "", + "group_id": group_info_d["group_id"] if group_info_d else "", + "group_name": group_info_d["group_name"] if group_info_d else "", + } + + ChatStreams.replace(stream_id=s_data_dict["stream_id"], **fields_to_save).execute() + + try: + await asyncio.to_thread(_db_save_stream_sync, stream_data_dict) + stream.saved = True + except Exception as e: + logger.error(f"保存聊天流 {stream.stream_id} 到数据库失败 (Peewee): {e}", exc_info=True) async def _save_all_streams(self): """保存所有聊天流""" @@ -222,10 +270,44 @@ class ChatManager: async def load_all_streams(self): """从数据库加载所有聊天流""" - all_streams = db.chat_streams.find({}) - for data in all_streams: - stream = ChatStream.from_dict(data) - self.streams[stream.stream_id] = stream + + def _db_load_all_streams_sync(): + loaded_streams_data = [] + for model_instance in ChatStreams.select(): + user_info_data = { + "platform": model_instance.user_platform, + "user_id": model_instance.user_id, + "user_nickname": model_instance.user_nickname, + "user_cardname": model_instance.user_cardname or "", + } + group_info_data = None + if model_instance.group_id: + group_info_data = { + "platform": model_instance.group_platform, + "group_id": model_instance.group_id, + "group_name": model_instance.group_name, + } + + data_for_from_dict = { + "stream_id": model_instance.stream_id, + "platform": model_instance.platform, + "user_info": user_info_data, + "group_info": group_info_data, + "create_time": model_instance.create_time, + "last_active_time": model_instance.last_active_time, + } + loaded_streams_data.append(data_for_from_dict) + return loaded_streams_data + + try: + all_streams_data_list = await asyncio.to_thread(_db_load_all_streams_sync) + self.streams.clear() + for data in all_streams_data_list: + stream = ChatStream.from_dict(data) + stream.saved = True + self.streams[stream.stream_id] = stream + except Exception as e: + logger.error(f"从数据库加载所有聊天流失败 (Peewee): {e}", exc_info=True) # 创建全局单例 diff --git a/src/chat/models/utils_model.py b/src/chat/models/utils_model.py index 9ca4e56d0..986036e86 100644 --- a/src/chat/models/utils_model.py +++ b/src/chat/models/utils_model.py @@ -12,7 +12,8 @@ import base64 from PIL import Image import io import os -from ...common.database.database import db +from src.common.database.database import db # 确保 db 被导入用于 create_tables +from src.common.database.database_model import LLMUsage # 导入 LLMUsage 模型 from ...config.config import global_config from rich.traceback import install @@ -85,8 +86,6 @@ async def _safely_record(request_content: Dict[str, Any], payload: Dict[str, Any f"data:image/{image_format.lower() if image_format else 'jpeg'};base64," f"{image_base64[:10]}...{image_base64[-10:]}" ) - # if isinstance(content, str) and len(content) > 100: - # payload["messages"][0]["content"] = content[:100] return payload @@ -134,13 +133,11 @@ class LLMRequest: def _init_database(): """初始化数据库集合""" try: - # 创建llm_usage集合的索引 - db.llm_usage.create_index([("timestamp", 1)]) - db.llm_usage.create_index([("model_name", 1)]) - db.llm_usage.create_index([("user_id", 1)]) - db.llm_usage.create_index([("request_type", 1)]) + # 使用 Peewee 创建表,safe=True 表示如果表已存在则不会抛出错误 + db.create_tables([LLMUsage], safe=True) + logger.info("LLMUsage 表已初始化/确保存在。") except Exception as e: - logger.error(f"创建数据库索引失败: {str(e)}") + logger.error(f"创建 LLMUsage 表失败: {str(e)}") def _record_usage( self, @@ -165,19 +162,19 @@ class LLMRequest: request_type = self.request_type try: - usage_data = { - "model_name": self.model_name, - "user_id": user_id, - "request_type": request_type, - "endpoint": endpoint, - "prompt_tokens": prompt_tokens, - "completion_tokens": completion_tokens, - "total_tokens": total_tokens, - "cost": self._calculate_cost(prompt_tokens, completion_tokens), - "status": "success", - "timestamp": datetime.now(), - } - db.llm_usage.insert_one(usage_data) + # 使用 Peewee 模型创建记录 + LLMUsage.create( + model_name=self.model_name, + user_id=user_id, + request_type=request_type, + endpoint=endpoint, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + cost=self._calculate_cost(prompt_tokens, completion_tokens), + status="success", + timestamp=datetime.now(), # Peewee 会处理 DateTimeField + ) logger.trace( f"Token使用情况 - 模型: {self.model_name}, " f"用户: {user_id}, 类型: {request_type}, " diff --git a/src/chat/person_info/person_info.py b/src/chat/person_info/person_info.py index 00cbe86f1..cd9034d6f 100644 --- a/src/chat/person_info/person_info.py +++ b/src/chat/person_info/person_info.py @@ -1,5 +1,6 @@ from src.common.logger_manager import get_logger from ...common.database.database import db +from ...common.database.database_model import PersonInfo # 新增导入 import copy import hashlib from typing import Any, Callable, Dict @@ -16,7 +17,7 @@ matplotlib.use("Agg") import matplotlib.pyplot as plt from pathlib import Path import pandas as pd -import json +import json # 新增导入 import re @@ -43,17 +44,13 @@ person_info_default = { "platform": None, "user_id": None, "nickname": None, - # "age" : 0, "relationship_value": 0, - # "saved" : True, - # "impression" : None, - # "gender" : Unkown, "konw_time": 0, "msg_interval": 2000, - "msg_interval_list": [], - "user_cardname": None, # 添加群名片 - "user_avatar": None, # 添加头像信息(例如URL或标识符) -} # 个人信息的各项与默认值在此定义,以下处理会自动创建/补全每一项 + "msg_interval_list": [], # 将作为 JSON 字符串存储在 Peewee 的 TextField + "user_cardname": None, # 注意:此字段不在 PersonInfo Peewee 模型中 + "user_avatar": None, # 注意:此字段不在 PersonInfo Peewee 模型中 +} class PersonInfoManager: @@ -64,21 +61,26 @@ class PersonInfoManager: max_tokens=256, request_type="qv_name", ) - if "person_info" not in db.list_collection_names(): - db.create_collection("person_info") - db.person_info.create_index("person_id", unique=True) + try: + db.connect(reuse_if_open=True) + db.create_tables([PersonInfo], safe=True) + except Exception as e: + logger.error(f"数据库连接或 PersonInfo 表创建失败: {e}") # 初始化时读取所有person_name - cursor = db.person_info.find({"person_name": {"$exists": True}}, {"person_id": 1, "person_name": 1, "_id": 0}) - for doc in cursor: - if doc.get("person_name"): - self.person_name_list[doc["person_id"]] = doc["person_name"] - logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称") + try: + for record in PersonInfo.select(PersonInfo.person_id, PersonInfo.person_name).where( + PersonInfo.person_name.is_null(False) + ): + if record.person_name: + self.person_name_list[record.person_id] = record.person_name + logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称 (Peewee)") + except Exception as e: + logger.error(f"从 Peewee 加载 person_name_list 失败: {e}") @staticmethod def get_person_id(platform: str, user_id: int): """获取唯一id""" - # 如果platform中存在-,就截取-后面的部分 if "-" in platform: platform = platform.split("-")[1] @@ -86,13 +88,17 @@ class PersonInfoManager: key = "_".join(components) return hashlib.md5(key.encode()).hexdigest() - def is_person_known(self, platform: str, user_id: int): + async def is_person_known(self, platform: str, user_id: int): """判断是否认识某人""" person_id = self.get_person_id(platform, user_id) - document = db.person_info.find_one({"person_id": person_id}) - if document: - return True - else: + + def _db_check_known_sync(p_id: str): + return PersonInfo.get_or_none(PersonInfo.person_id == p_id) is not None + + try: + return await asyncio.to_thread(_db_check_known_sync, person_id) + except Exception as e: + logger.error(f"检查用户 {person_id} 是否已知时出错 (Peewee): {e}") return False @staticmethod @@ -103,73 +109,111 @@ class PersonInfoManager: return _person_info_default = copy.deepcopy(person_info_default) - _person_info_default["person_id"] = person_id + model_fields = PersonInfo._meta.fields.keys() + + final_data = {"person_id": person_id} if data: - for key in _person_info_default: - if key != "person_id" and key in data: - _person_info_default[key] = data[key] + for key, value in data.items(): + if key in model_fields: + final_data[key] = value - db.person_info.insert_one(_person_info_default) + for key, default_value in _person_info_default.items(): + if key in model_fields and key not in final_data: + final_data[key] = default_value + + if "msg_interval_list" in final_data and isinstance(final_data["msg_interval_list"], list): + final_data["msg_interval_list"] = json.dumps(final_data["msg_interval_list"]) + elif "msg_interval_list" not in final_data and "msg_interval_list" in model_fields: + final_data["msg_interval_list"] = json.dumps([]) + + def _db_create_sync(p_data: dict): + try: + PersonInfo.create(**p_data) + return True + except Exception as e: + logger.error(f"创建 PersonInfo 记录 {p_data.get('person_id')} 失败 (Peewee): {e}") + return False + + await asyncio.to_thread(_db_create_sync, final_data) async def update_one_field(self, person_id: str, field_name: str, value, data: dict = None): """更新某一个字段,会补全""" - if field_name not in person_info_default.keys(): - logger.debug(f"更新'{field_name}'失败,未定义的字段") + if field_name not in PersonInfo._meta.fields: + if field_name in person_info_default: + logger.debug(f"更新'{field_name}'跳过,字段存在于默认配置但不在 PersonInfo Peewee 模型中。") + return + logger.debug(f"更新'{field_name}'失败,未在 PersonInfo Peewee 模型中定义的字段。") return - document = db.person_info.find_one({"person_id": person_id}) + def _db_update_sync(p_id: str, f_name: str, val): + record = PersonInfo.get_or_none(PersonInfo.person_id == p_id) + if record: + if f_name == "msg_interval_list" and isinstance(val, list): + setattr(record, f_name, json.dumps(val)) + else: + setattr(record, f_name, val) + record.save() + return True, False + return False, True - if document: - db.person_info.update_one({"person_id": person_id}, {"$set": {field_name: value}}) - else: - data[field_name] = value - logger.debug(f"更新时{person_id}不存在,已新建") - await self.create_person_info(person_id, data) + found, needs_creation = await asyncio.to_thread(_db_update_sync, person_id, field_name, value) + + if needs_creation: + logger.debug(f"更新时 {person_id} 不存在,将新建。") + creation_data = data if data is not None else {} + creation_data[field_name] = value + if "platform" not in creation_data or "user_id" not in creation_data: + logger.warning(f"为 {person_id} 创建记录时,platform/user_id 可能缺失。") + + await self.create_person_info(person_id, creation_data) @staticmethod async def has_one_field(person_id: str, field_name: str): """判断是否存在某一个字段""" - document = db.person_info.find_one({"person_id": person_id}, {field_name: 1}) - if document: - return True - else: + if field_name not in PersonInfo._meta.fields: + logger.debug(f"检查字段'{field_name}'失败,未在 PersonInfo Peewee 模型中定义。") + return False + + def _db_has_field_sync(p_id: str, f_name: str): + record = PersonInfo.get_or_none(PersonInfo.person_id == p_id) + if record: + return True + return False + + try: + return await asyncio.to_thread(_db_has_field_sync, person_id, field_name) + except Exception as e: + logger.error(f"检查字段 {field_name} for {person_id} 时出错 (Peewee): {e}") return False @staticmethod def _extract_json_from_text(text: str) -> dict: """从文本中提取JSON数据的高容错方法""" try: - # 尝试直接解析 parsed_json = json.loads(text) - # 如果解析结果是列表,尝试取第一个元素 if isinstance(parsed_json, list): - if parsed_json: # 检查列表是否为空 + if parsed_json: parsed_json = parsed_json[0] - else: # 如果列表为空,重置为 None,走后续逻辑 + else: parsed_json = None - # 确保解析结果是字典 if isinstance(parsed_json, dict): return parsed_json except json.JSONDecodeError: - # 解析失败,继续尝试其他方法 pass except Exception as e: logger.warning(f"尝试直接解析JSON时发生意外错误: {e}") - pass # 继续尝试其他方法 + pass - # 如果直接解析失败或结果不是字典 try: - # 尝试找到JSON对象格式的部分 json_pattern = r"\{[^{}]*\}" matches = re.findall(json_pattern, text) if matches: parsed_obj = json.loads(matches[0]) - if isinstance(parsed_obj, dict): # 确保是字典 + if isinstance(parsed_obj, dict): return parsed_obj - # 如果上面都失败了,尝试提取键值对 nickname_pattern = r'"nickname"[:\s]+"([^"]+)"' reason_pattern = r'"reason"[:\s]+"([^"]+)"' @@ -184,7 +228,6 @@ class PersonInfoManager: except Exception as e: logger.error(f"后备JSON提取失败: {str(e)}") - # 如果所有方法都失败了,返回默认字典 logger.warning(f"无法从文本中提取有效的JSON字典: {text}") return {"nickname": "", "reason": ""} @@ -199,9 +242,11 @@ class PersonInfoManager: old_name = await self.get_value(person_id, "person_name") old_reason = await self.get_value(person_id, "name_reason") - max_retries = 5 # 最大重试次数 + max_retries = 5 current_try = 0 - existing_names = "" + existing_names_str = "" + current_name_set = set(self.person_name_list.values()) + while current_try < max_retries: individuality = Individuality.get_instance() prompt_personality = individuality.get_prompt(x_person=2, level=1) @@ -216,45 +261,55 @@ class PersonInfoManager: qv_name_prompt += f"你之前叫他{old_name},是因为{old_reason}," qv_name_prompt += f"\n其他取名的要求是:{request},不要太浮夸" + qv_name_prompt += "\n请根据以上用户信息,想想你叫他什么比较好,不要太浮夸,请最好使用用户的qq昵称,可以稍作修改" + + if existing_names_str: + qv_name_prompt += f"\n请注意,以下名称已被你尝试过或已知存在,请避免:{existing_names_str}。\n" + + if len(current_name_set) < 50 and current_name_set: + qv_name_prompt += f"已知的其他昵称有: {', '.join(list(current_name_set)[:10])}等。\n" - qv_name_prompt += ( - "\n请根据以上用户信息,想想你叫他什么比较好,不要太浮夸,请最好使用用户的qq昵称,可以稍作修改" - ) - if existing_names: - qv_name_prompt += f"\n请注意,以下名称已被使用,不要使用以下昵称:{existing_names}。\n" qv_name_prompt += "请用json给出你的想法,并给出理由,示例如下:" qv_name_prompt += """{ "nickname": "昵称", "reason": "理由" }""" - # logger.debug(f"取名提示词:{qv_name_prompt}") response = await self.qv_name_llm.generate_response(qv_name_prompt) logger.trace(f"取名提示词:{qv_name_prompt}\n取名回复:{response}") result = self._extract_json_from_text(response[0]) - if not result["nickname"]: - logger.error("生成的昵称为空,重试中...") + if not result or not result.get("nickname"): + logger.error("生成的昵称为空或结果格式不正确,重试中...") current_try += 1 continue - # 检查生成的昵称是否已存在 - if result["nickname"] not in self.person_name_list.values(): - # 更新数据库和内存中的列表 - await self.update_one_field(person_id, "person_name", result["nickname"]) - # await self.update_one_field(person_id, "nickname", user_nickname) - # await self.update_one_field(person_id, "avatar", user_avatar) - await self.update_one_field(person_id, "name_reason", result["reason"]) + generated_nickname = result["nickname"] - self.person_name_list[person_id] = result["nickname"] - # logger.debug(f"用户 {person_id} 的名称已更新为 {result['nickname']},原因:{result['reason']}") + is_duplicate = False + if generated_nickname in current_name_set: + is_duplicate = True + else: + def _db_check_name_exists_sync(name_to_check): + return PersonInfo.select().where(PersonInfo.person_name == name_to_check).exists() + + if await asyncio.to_thread(_db_check_name_exists_sync, generated_nickname): + is_duplicate = True + current_name_set.add(generated_nickname) + + if not is_duplicate: + await self.update_one_field(person_id, "person_name", generated_nickname) + await self.update_one_field(person_id, "name_reason", result.get("reason", "未提供理由")) + + self.person_name_list[person_id] = generated_nickname return result else: - existing_names += f"{result['nickname']}、" + if existing_names_str: + existing_names_str += "、" + existing_names_str += generated_nickname + logger.debug(f"生成的昵称 {generated_nickname} 已存在,重试中...") + current_try += 1 - logger.debug(f"生成的昵称 {result['nickname']} 已存在,重试中...") - current_try += 1 - - logger.error(f"在{max_retries}次尝试后仍未能生成唯一昵称") + logger.error(f"在{max_retries}次尝试后仍未能生成唯一昵称 for {person_id}") return None @staticmethod @@ -264,30 +319,56 @@ class PersonInfoManager: logger.debug("删除失败:person_id 不能为空") return - result = db.person_info.delete_one({"person_id": person_id}) - if result.deleted_count > 0: - logger.debug(f"删除成功:person_id={person_id}") + def _db_delete_sync(p_id: str): + try: + query = PersonInfo.delete().where(PersonInfo.person_id == p_id) + deleted_count = query.execute() + return deleted_count + except Exception as e: + logger.error(f"删除 PersonInfo {p_id} 失败 (Peewee): {e}") + return 0 + + deleted_count = await asyncio.to_thread(_db_delete_sync, person_id) + + if deleted_count > 0: + logger.debug(f"删除成功:person_id={person_id} (Peewee)") else: - logger.debug(f"删除失败:未找到 person_id={person_id}") + logger.debug(f"删除失败:未找到 person_id={person_id} 或删除未影响行 (Peewee)") @staticmethod async def get_value(person_id: str, field_name: str): """获取指定person_id文档的字段值,若不存在该字段,则返回该字段的全局默认值""" if not person_id: logger.debug("get_value获取失败:person_id不能为空") + return person_info_default.get(field_name) + + if field_name not in PersonInfo._meta.fields: + if field_name in person_info_default: + logger.trace(f"字段'{field_name}'不在Peewee模型中,但存在于默认配置中。返回配置默认值。") + return copy.deepcopy(person_info_default[field_name]) + logger.debug(f"get_value获取失败:字段'{field_name}'未在Peewee模型和默认配置中定义。") return None - if field_name not in person_info_default: - logger.debug(f"get_value获取失败:字段'{field_name}'未定义") + def _db_get_value_sync(p_id: str, f_name: str): + record = PersonInfo.get_or_none(PersonInfo.person_id == p_id) + if record: + val = getattr(record, f_name) + if f_name == "msg_interval_list" and isinstance(val, str): + try: + return json.loads(val) + except json.JSONDecodeError: + logger.warning(f"无法解析 {p_id} 的 msg_interval_list JSON: {val}") + return copy.deepcopy(person_info_default.get(f_name, [])) + return val return None - document = db.person_info.find_one({"person_id": person_id}, {field_name: 1}) + value = await asyncio.to_thread(_db_get_value_sync, person_id, field_name) - if document and field_name in document: - return document[field_name] + if value is not None: + return value else: - default_value = copy.deepcopy(person_info_default[field_name]) - logger.trace(f"获取{person_id}的{field_name}失败,已返回默认值{default_value}") + default_value = copy.deepcopy(person_info_default.get(field_name)) + logger.trace(f"获取{person_id}的{field_name}失败或值为None,已返回默认值{default_value} (Peewee)") return default_value @staticmethod @@ -297,93 +378,82 @@ class PersonInfoManager: logger.debug("get_values获取失败:person_id不能为空") return {} - # 检查所有字段是否有效 - for field in field_names: - if field not in person_info_default: - logger.debug(f"get_values获取失败:字段'{field}'未定义") - return {} - - # 构建查询投影(所有字段都有效才会执行到这里) - projection = {field: 1 for field in field_names} - - document = db.person_info.find_one({"person_id": person_id}, projection) - result = {} - for field in field_names: - result[field] = copy.deepcopy( - document.get(field, person_info_default[field]) if document else person_info_default[field] - ) + + def _db_get_record_sync(p_id: str): + return PersonInfo.get_or_none(PersonInfo.person_id == p_id) + + record = await asyncio.to_thread(_db_get_record_sync, person_id) + + for field_name in field_names: + if field_name not in PersonInfo._meta.fields: + if field_name in person_info_default: + result[field_name] = copy.deepcopy(person_info_default[field_name]) + logger.trace(f"字段'{field_name}'不在Peewee模型中,使用默认配置值。") + else: + logger.debug(f"get_values查询失败:字段'{field_name}'未在Peewee模型和默认配置中定义。") + result[field_name] = None + continue + + if record: + value = getattr(record, field_name) + if field_name == "msg_interval_list" and isinstance(value, str): + try: + result[field_name] = json.loads(value) + except json.JSONDecodeError: + logger.warning(f"无法解析 {person_id} 的 msg_interval_list JSON: {value}") + result[field_name] = copy.deepcopy(person_info_default.get(field_name, [])) + elif value is not None: + result[field_name] = value + else: + result[field_name] = copy.deepcopy(person_info_default.get(field_name)) + else: + result[field_name] = copy.deepcopy(person_info_default.get(field_name)) return result @staticmethod async def del_all_undefined_field(): - """删除所有项里的未定义字段""" - # 获取所有已定义的字段名 - defined_fields = set(person_info_default.keys()) - - try: - # 遍历集合中的所有文档 - for document in db.person_info.find({}): - # 找出文档中未定义的字段 - undefined_fields = set(document.keys()) - defined_fields - {"_id"} - - if undefined_fields: - # 构建更新操作,使用$unset删除未定义字段 - update_result = db.person_info.update_one( - {"_id": document["_id"]}, {"$unset": {field: 1 for field in undefined_fields}} - ) - - if update_result.modified_count > 0: - logger.debug(f"已清理文档 {document['_id']} 的未定义字段: {undefined_fields}") - - return - - except Exception as e: - logger.error(f"清理未定义字段时出错: {e}") - return + """删除所有项里的未定义字段 - 对于Peewee (SQL),此操作通常不适用,因为模式是固定的。""" + logger.info("del_all_undefined_field: 对于使用Peewee的SQL数据库,此操作通常不适用或不需要,因为表结构是预定义的。") + return @staticmethod async def get_specific_value_list( field_name: str, - way: Callable[[Any], bool], # 接受任意类型值 + way: Callable[[Any], bool], ) -> Dict[str, Any]: """ 获取满足条件的字段值字典 - - Args: - field_name: 目标字段名 - way: 判断函数 (value: Any) -> bool - - Returns: - {person_id: value} | {} - - Example: - # 查找所有nickname包含"admin"的用户 - result = manager.specific_value_list( - "nickname", - lambda x: "admin" in x.lower() - ) """ - if field_name not in person_info_default: - logger.error(f"字段检查失败:'{field_name}'未定义") + if field_name not in PersonInfo._meta.fields: + logger.error(f"字段检查失败:'{field_name}'未在 PersonInfo Peewee 模型中定义") return {} + def _db_get_specific_sync(f_name: str): + found_results = {} + try: + for record in PersonInfo.select(PersonInfo.person_id, getattr(PersonInfo, f_name)): + value = getattr(record, f_name) + if f_name == "msg_interval_list" and isinstance(value, str): + try: + processed_value = json.loads(value) + except json.JSONDecodeError: + logger.warning(f"跳过记录 {record.person_id},无法解析 msg_interval_list: {value}") + continue + else: + processed_value = value + + if way(processed_value): + found_results[record.person_id] = processed_value + except Exception as e_query: + logger.error(f"数据库查询失败 (Peewee specific_value_list for {f_name}): {str(e_query)}", exc_info=True) + return found_results + try: - result = {} - for doc in db.person_info.find({field_name: {"$exists": True}}, {"person_id": 1, field_name: 1, "_id": 0}): - try: - value = doc[field_name] - if way(value): - result[doc["person_id"]] = value - except (KeyError, TypeError, ValueError) as e: - logger.debug(f"记录{doc.get('person_id')}处理失败: {str(e)}") - continue - - return result - + return await asyncio.to_thread(_db_get_specific_sync, field_name) except Exception as e: - logger.error(f"数据库查询失败: {str(e)}", exc_info=True) + logger.error(f"执行 get_specific_value_list 线程时出错: {str(e)}", exc_info=True) return {} async def personal_habit_deduction(self): @@ -391,35 +461,31 @@ class PersonInfoManager: try: while 1: await asyncio.sleep(600) - current_time = datetime.datetime.now() - logger.info(f"个人信息推断启动: {current_time.strftime('%Y-%m-%d %H:%M:%S')}") + current_time_dt = datetime.datetime.now() + logger.info(f"个人信息推断启动: {current_time_dt.strftime('%Y-%m-%d %H:%M:%S')}") - # "msg_interval"推断 - msg_interval_map = False - msg_interval_lists = await self.get_specific_value_list( + msg_interval_map_generated = False + msg_interval_lists_map = await self.get_specific_value_list( "msg_interval_list", lambda x: isinstance(x, list) and len(x) >= 100 ) - for person_id, msg_interval_list_ in msg_interval_lists.items(): + + for person_id, actual_msg_interval_list in msg_interval_lists_map.items(): await asyncio.sleep(0.3) try: time_interval = [] - for t1, t2 in zip(msg_interval_list_, msg_interval_list_[1:]): + for t1, t2 in zip(actual_msg_interval_list, actual_msg_interval_list[1:]): delta = t2 - t1 if delta > 0: time_interval.append(delta) time_interval = [t for t in time_interval if 200 <= t <= 8000] - # --- 修改后的逻辑 --- - # 数据量检查 (至少需要 30 条有效间隔,并且足够进行头尾截断) - if len(time_interval) >= 30 + 10: # 至少30条有效+头尾各5条 - time_interval.sort() - # 画图(log) - 这部分保留 - msg_interval_map = True + if len(time_interval) >= 30 + 10: + time_interval.sort() + msg_interval_map_generated = True log_dir = Path("logs/person_info") log_dir.mkdir(parents=True, exist_ok=True) plt.figure(figsize=(10, 6)) - # 使用截断前的数据画图,更能反映原始分布 time_series_original = pd.Series(time_interval) plt.hist( time_series_original, @@ -441,34 +507,27 @@ class PersonInfoManager: img_path = log_dir / f"interval_distribution_{person_id[:8]}.png" plt.savefig(img_path) plt.close() - # 画图结束 - # 去掉头尾各 5 个数据点 trimmed_interval = time_interval[5:-5] - - # 计算截断后数据的 37% 分位数 - if trimmed_interval: # 确保截断后列表不为空 - msg_interval = int(round(np.percentile(trimmed_interval, 37))) - # 更新数据库 - await self.update_one_field(person_id, "msg_interval", msg_interval) - logger.trace(f"用户{person_id}的msg_interval通过头尾截断和37分位数更新为{msg_interval}") + if trimmed_interval: + msg_interval_val = int(round(np.percentile(trimmed_interval, 37))) + await self.update_one_field(person_id, "msg_interval", msg_interval_val) + logger.trace(f"用户{person_id}的msg_interval通过头尾截断和37分位数更新为{msg_interval_val}") else: logger.trace(f"用户{person_id}截断后数据为空,无法计算msg_interval") else: logger.trace( f"用户{person_id}有效消息间隔数量 ({len(time_interval)}) 不足进行推断 (需要至少 {30 + 10} 条)" ) - # --- 修改结束 --- - except Exception as e: - logger.trace(f"用户{person_id}消息间隔计算失败: {type(e).__name__}: {str(e)}") + except Exception as e_inner: + logger.trace(f"用户{person_id}消息间隔计算失败: {type(e_inner).__name__}: {str(e_inner)}") continue - # 其他... - - if msg_interval_map: + if msg_interval_map_generated: logger.trace("已保存分布图到: logs/person_info") - current_time = datetime.datetime.now() - logger.trace(f"个人信息推断结束: {current_time.strftime('%Y-%m-%d %H:%M:%S')}") + + current_time_dt_end = datetime.datetime.now() + logger.trace(f"个人信息推断结束: {current_time_dt_end.strftime('%Y-%m-%d %H:%M:%S')}") await asyncio.sleep(86400) except Exception as e: @@ -481,41 +540,27 @@ class PersonInfoManager: """ 根据 platform 和 user_id 获取 person_id。 如果对应的用户不存在,则使用提供的可选信息创建新用户。 - - Args: - platform: 平台标识 - user_id: 用户在该平台上的ID - nickname: 用户的昵称 (可选,用于创建新用户) - user_cardname: 用户的群名片 (可选,用于创建新用户) - user_avatar: 用户的头像信息 (可选,用于创建新用户) - - Returns: - 对应的 person_id。 """ person_id = self.get_person_id(platform, user_id) - # 检查用户是否已存在 - # 使用静态方法 get_person_id,因此可以直接调用 db - document = db.person_info.find_one({"person_id": person_id}) + def _db_check_exists_sync(p_id: str): + return PersonInfo.get_or_none(PersonInfo.person_id == p_id) - if document is None: - logger.info(f"用户 {platform}:{user_id} (person_id: {person_id}) 不存在,将创建新记录。") + record = await asyncio.to_thread(_db_check_exists_sync, person_id) + + if record is None: + logger.info(f"用户 {platform}:{user_id} (person_id: {person_id}) 不存在,将创建新记录 (Peewee)。") initial_data = { "platform": platform, - "user_id": user_id, + "user_id": str(user_id), "nickname": nickname, - "konw_time": int(datetime.datetime.now().timestamp()), # 添加初次认识时间 - # 注意:这里没有添加 user_cardname 和 user_avatar,因为它们不在 person_info_default 中 - # 如果需要存储它们,需要先在 person_info_default 中定义 + "konw_time": int(datetime.datetime.now().timestamp()), } - # 过滤掉值为 None 的初始数据 - initial_data = {k: v for k, v in initial_data.items() if v is not None} + model_fields = PersonInfo._meta.fields.keys() + filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields} - # 注意:create_person_info 是静态方法 - await PersonInfoManager.create_person_info(person_id, data=initial_data) - # 创建后,可以考虑立即为其取名,但这可能会增加延迟 - # await self.qv_person_name(person_id, nickname, user_cardname, user_avatar) - logger.debug(f"已为 {person_id} 创建新记录,初始数据: {initial_data}") + await self.create_person_info(person_id, data=filtered_initial_data) + logger.debug(f"已为 {person_id} 创建新记录,初始数据 (filtered for model): {filtered_initial_data}") return person_id @@ -525,35 +570,49 @@ class PersonInfoManager: logger.debug("get_person_info_by_name 获取失败:person_name 不能为空") return None - # 优先从内存缓存查找 person_id found_person_id = None - for pid, name in self.person_name_list.items(): - if name == person_name: + for pid, name_in_cache in self.person_name_list.items(): + if name_in_cache == person_name: found_person_id = pid - break # 找到第一个匹配就停止 + break if not found_person_id: - # 如果内存没有,尝试数据库查询(可能内存未及时更新或启动时未加载) - document = db.person_info.find_one({"person_name": person_name}) - if document: - found_person_id = document.get("person_id") - else: - logger.debug(f"数据库中也未找到名为 '{person_name}' 的用户") - return None # 数据库也找不到 + def _db_find_by_name_sync(p_name_to_find: str): + return PersonInfo.get_or_none(PersonInfo.person_name == p_name_to_find) - # 根据找到的 person_id 获取所需信息 - if found_person_id: - required_fields = ["person_id", "platform", "user_id", "nickname", "user_cardname", "user_avatar"] - person_data = await self.get_values(found_person_id, required_fields) - if person_data: # 确保 get_values 成功返回 - return person_data + record = await asyncio.to_thread(_db_find_by_name_sync, person_name) + if record: + found_person_id = record.person_id + if found_person_id not in self.person_name_list or self.person_name_list[found_person_id] != person_name: + self.person_name_list[found_person_id] = person_name else: - logger.warning(f"找到了 person_id '{found_person_id}' 但获取详细信息失败") + logger.debug(f"数据库中也未找到名为 '{person_name}' 的用户 (Peewee)") return None - else: - # 这理论上不应该发生,因为上面已经处理了找不到的情况 - logger.error(f"逻辑错误:未能为 '{person_name}' 确定 person_id") - return None + + if found_person_id: + required_fields = [ + "person_id", + "platform", + "user_id", + "nickname", + "user_cardname", + "user_avatar", + "person_name", + "name_reason", + ] + valid_fields_to_get = [f for f in required_fields if f in PersonInfo._meta.fields or f in person_info_default] + + person_data = await self.get_values(found_person_id, valid_fields_to_get) + + if person_data: + final_result = {key: person_data.get(key) for key in required_fields} + return final_result + else: + logger.warning(f"找到了 person_id '{found_person_id}' 但 get_values 返回空 (Peewee)") + return None + + logger.error(f"逻辑错误:未能为 '{person_name}' 确定 person_id (Peewee)") + return None person_info_manager = PersonInfoManager() diff --git a/src/chat/utils/info_catcher.py b/src/chat/utils/info_catcher.py index b7f59c661..fb8182973 100644 --- a/src/chat/utils/info_catcher.py +++ b/src/chat/utils/info_catcher.py @@ -1,9 +1,10 @@ from src.config.config import global_config from src.chat.message_receive.message import MessageRecv, MessageSending, Message -from common.database.database import db +from src.common.database.database_model import Messages, ThinkingLog import time import traceback from typing import List +import json class InfoCatcher: @@ -60,8 +61,6 @@ class InfoCatcher: def catch_after_observe(self, obs_duration: float): # 这里可以有更多信息 self.timing_results["sub_heartflow_observe_time"] = obs_duration - # def catch_shf - def catch_afer_shf_step(self, step_duration: float, past_mind: str, current_mind: str): self.timing_results["sub_heartflow_step_time"] = step_duration if len(past_mind) > 1: @@ -72,25 +71,10 @@ class InfoCatcher: self.heartflow_data["sub_heartflow_now"] = current_mind def catch_after_llm_generated(self, prompt: str, response: str, reasoning_content: str = "", model_name: str = ""): - # if self.response_mode == "heart_flow": # 条件判断不需要了喵~ - # self.heartflow_data["prompt"] = prompt - # self.heartflow_data["response"] = response - # self.heartflow_data["model"] = model_name - # elif self.response_mode == "reasoning": # 条件判断不需要了喵~ - # self.reasoning_data["thinking_log"] = reasoning_content - # self.reasoning_data["prompt"] = prompt - # self.reasoning_data["response"] = response - # self.reasoning_data["model"] = model_name - - # 直接记录信息喵~ self.reasoning_data["thinking_log"] = reasoning_content self.reasoning_data["prompt"] = prompt self.reasoning_data["response"] = response self.reasoning_data["model"] = model_name - # 如果 heartflow 数据也需要通用字段,可以取消下面的注释喵~ - # self.heartflow_data["prompt"] = prompt - # self.heartflow_data["response"] = response - # self.heartflow_data["model"] = model_name self.response_text = response @@ -102,6 +86,7 @@ class InfoCatcher: ): self.timing_results["make_response_time"] = response_duration self.response_time = time.time() + self.response_messages = [] for msg in response_message: self.response_messages.append(msg) @@ -112,107 +97,110 @@ class InfoCatcher: @staticmethod def get_message_from_db_between_msgs(message_start: Message, message_end: Message): try: - # 从数据库中获取消息的时间戳 time_start = message_start.message_info.time time_end = message_end.message_info.time chat_id = message_start.chat_stream.stream_id print(f"查询参数: time_start={time_start}, time_end={time_end}, chat_id={chat_id}") - # 查询数据库,获取 chat_id 相同且时间在 start 和 end 之间的数据 - messages_between = db.messages.find( - {"chat_id": chat_id, "time": {"$gt": time_start, "$lt": time_end}} - ).sort("time", -1) + messages_between_query = Messages.select().where( + (Messages.chat_id == chat_id) & + (Messages.time > time_start) & + (Messages.time < time_end) + ).order_by(Messages.time.desc()) - result = list(messages_between) + result = list(messages_between_query) print(f"查询结果数量: {len(result)}") if result: - print(f"第一条消息时间: {result[0]['time']}") - print(f"最后一条消息时间: {result[-1]['time']}") + print(f"第一条消息时间: {result[0].time}") + print(f"最后一条消息时间: {result[-1].time}") return result except Exception as e: print(f"获取消息时出错: {str(e)}") + print(traceback.format_exc()) return [] def get_message_from_db_before_msg(self, message: MessageRecv): - # 从数据库中获取消息 - message_id = message.message_info.message_id - chat_id = message.chat_stream.stream_id + message_id_val = message.message_info.message_id + chat_id_val = message.chat_stream.stream_id - # 查询数据库,获取 chat_id 相同且 message_id 小于当前消息的 30 条数据 - messages_before = ( - db.messages.find({"chat_id": chat_id, "message_id": {"$lt": message_id}}) - .sort("time", -1) - .limit(self.context_length * 3) - ) # 获取更多历史信息 + messages_before_query = Messages.select().where( + (Messages.chat_id == chat_id_val) & + (Messages.message_id < message_id_val) + ).order_by(Messages.time.desc()).limit(self.context_length * 3) - return list(messages_before) + return list(messages_before_query) def message_list_to_dict(self, message_list): - # 存储简化的聊天记录 result = [] - for message in message_list: - if not isinstance(message, dict): - message = self.message_to_dict(message) - # print(message) + for msg_item in message_list: + processed_msg_item = msg_item + if not isinstance(msg_item, dict): + processed_msg_item = self.message_to_dict(msg_item) + + if not processed_msg_item: + continue lite_message = { - "time": message["time"], - "user_nickname": message["user_info"]["user_nickname"], - "processed_plain_text": message["processed_plain_text"], + "time": processed_msg_item.get("time"), + "user_nickname": processed_msg_item.get("user_nickname"), + "processed_plain_text": processed_msg_item.get("processed_plain_text"), } result.append(lite_message) - return result @staticmethod - def message_to_dict(message): - if not message: + def message_to_dict(msg_obj): + if not msg_obj: return None - if isinstance(message, dict): - return message - return { - # "message_id": message.message_info.message_id, - "time": message.message_info.time, - "user_id": message.message_info.user_info.user_id, - "user_nickname": message.message_info.user_info.user_nickname, - "processed_plain_text": message.processed_plain_text, - # "detailed_plain_text": message.detailed_plain_text - } + if isinstance(msg_obj, dict): + return msg_obj + + if isinstance(msg_obj, Messages): + return { + "time": msg_obj.time, + "user_id": msg_obj.user_id, + "user_nickname": msg_obj.user_nickname, + "processed_plain_text": msg_obj.processed_plain_text, + } + + if hasattr(msg_obj, 'message_info') and hasattr(msg_obj.message_info, 'user_info'): + return { + "time": msg_obj.message_info.time, + "user_id": msg_obj.message_info.user_info.user_id, + "user_nickname": msg_obj.message_info.user_info.user_nickname, + "processed_plain_text": msg_obj.processed_plain_text, + } + + print(f"Warning: message_to_dict received an unhandled type: {type(msg_obj)}") + return {} def done_catch(self): - """将收集到的信息存储到数据库的 thinking_log 集合中喵~""" + """将收集到的信息存储到数据库的 thinking_log 表中喵~""" try: - # 将消息对象转换为可序列化的字典喵~ - - thinking_log_data = { - "chat_id": self.chat_id, - "trigger_text": self.trigger_response_text, - "response_text": self.response_text, - "trigger_info": { - "time": self.trigger_response_time, - "message": self.message_to_dict(self.trigger_response_message), - }, - "response_info": { - "time": self.response_time, - "message": self.response_messages, - }, - "timing_results": self.timing_results, - "chat_history": self.message_list_to_dict(self.chat_history), - "chat_history_in_thinking": self.message_list_to_dict(self.chat_history_in_thinking), - "chat_history_after_response": self.message_list_to_dict(self.chat_history_after_response), - "heartflow_data": self.heartflow_data, - "reasoning_data": self.reasoning_data, + trigger_info_dict = self.message_to_dict(self.trigger_response_message) + response_info_dict = { + "time": self.response_time, + "message": self.response_messages, } + chat_history_list = self.message_list_to_dict(self.chat_history) + chat_history_in_thinking_list = self.message_list_to_dict(self.chat_history_in_thinking) + chat_history_after_response_list = self.message_list_to_dict(self.chat_history_after_response) - # 根据不同的响应模式添加相应的数据喵~ # 现在直接都加上去好了喵~ - # if self.response_mode == "heart_flow": - # thinking_log_data["mode_specific_data"] = self.heartflow_data - # elif self.response_mode == "reasoning": - # thinking_log_data["mode_specific_data"] = self.reasoning_data - - # 将数据插入到 thinking_log 集合中喵~ - db.thinking_log.insert_one(thinking_log_data) + log_entry = ThinkingLog( + chat_id=self.chat_id, + trigger_text=self.trigger_response_text, + response_text=self.response_text, + trigger_info_json=json.dumps(trigger_info_dict) if trigger_info_dict else None, + response_info_json=json.dumps(response_info_dict), + timing_results_json=json.dumps(self.timing_results), + chat_history_json=json.dumps(chat_history_list), + chat_history_in_thinking_json=json.dumps(chat_history_in_thinking_list), + chat_history_after_response_json=json.dumps(chat_history_after_response_list), + heartflow_data_json=json.dumps(self.heartflow_data), + reasoning_data_json=json.dumps(self.reasoning_data) + ) + log_entry.save() return True except Exception as e: diff --git a/src/chat/utils/statistic.py b/src/chat/utils/statistic.py index 4bcf6fea0..9a0131f74 100644 --- a/src/chat/utils/statistic.py +++ b/src/chat/utils/statistic.py @@ -5,7 +5,8 @@ from typing import Any, Dict, Tuple, List from src.common.logger import get_module_logger from src.manager.async_task_manager import AsyncTask -from ...common.database.database import db +from ...common.database.database import db # This db is the Peewee database instance +from ...common.database.database_model import OnlineTime # Import the Peewee model from src.manager.local_store_manager import local_storage logger = get_module_logger("maibot_statistic") @@ -39,7 +40,7 @@ class OnlineTimeRecordTask(AsyncTask): def __init__(self): super().__init__(task_name="Online Time Record Task", run_interval=60) - self.record_id: str | None = None + self.record_id: int | None = None # Changed to int for Peewee's default ID """记录ID""" self._init_database() # 初始化数据库 @@ -47,53 +48,46 @@ class OnlineTimeRecordTask(AsyncTask): @staticmethod def _init_database(): """初始化数据库""" - if "online_time" not in db.list_collection_names(): - # 初始化数据库(在线时长) - db.create_collection("online_time") - # 创建索引 - if ("end_timestamp", 1) not in db.online_time.list_indexes(): - db.online_time.create_index([("end_timestamp", 1)]) + with db.atomic(): # Use atomic operations for schema changes + OnlineTime.create_table(safe=True) # Creates table if it doesn't exist, Peewee handles indexes from model async def run(self): try: + current_time = datetime.now() + extended_end_time = current_time + timedelta(minutes=1) + if self.record_id: # 如果有记录,则更新结束时间 - db.online_time.update_one( - {"_id": self.record_id}, - { - "$set": { - "end_timestamp": datetime.now() + timedelta(minutes=1), - } - }, - ) - else: + query = OnlineTime.update(end_timestamp=extended_end_time).where(OnlineTime.id == self.record_id) + updated_rows = query.execute() + if updated_rows == 0: + # Record might have been deleted or ID is stale, try to find/create + self.record_id = None # Reset record_id to trigger find/create logic below + + if not self.record_id: # Check again if record_id was reset or initially None # 如果没有记录,检查一分钟以内是否已有记录 - current_time = datetime.now() - if recent_record := db.online_time.find_one( - {"end_timestamp": {"$gte": current_time - timedelta(minutes=1)}} - ): + # Look for a record whose end_timestamp is recent enough to be considered ongoing + recent_record = OnlineTime.select().where( + OnlineTime.end_timestamp >= (current_time - timedelta(minutes=1)) + ).order_by(OnlineTime.end_timestamp.desc()).first() + + if recent_record: # 如果有记录,则更新结束时间 - self.record_id = recent_record["_id"] - db.online_time.update_one( - {"_id": self.record_id}, - { - "$set": { - "end_timestamp": current_time + timedelta(minutes=1), - } - }, - ) + self.record_id = recent_record.id + recent_record.end_timestamp = extended_end_time + recent_record.save() else: # 若没有记录,则插入新的在线时间记录 - self.record_id = db.online_time.insert_one( - { - "start_timestamp": current_time, - "end_timestamp": current_time + timedelta(minutes=1), - } - ).inserted_id + new_record = OnlineTime.create( + start_timestamp=current_time, + end_timestamp=extended_end_time, + ) + self.record_id = new_record.id except Exception as e: logger.error(f"在线时间记录失败,错误信息:{e}") + def _format_online_time(online_seconds: int) -> str: """ 格式化在线时间 diff --git a/src/chat/utils/utils_image.py b/src/chat/utils/utils_image.py index 6fbafc905..11e7bed06 100644 --- a/src/chat/utils/utils_image.py +++ b/src/chat/utils/utils_image.py @@ -9,6 +9,7 @@ import numpy as np from ...common.database.database import db +from ...common.database.database_model import Images, ImageDescriptions from ...config.config import global_config from ..models.utils_model import LLMRequest @@ -32,40 +33,21 @@ class ImageManager: def __init__(self): if not self._initialized: - self._ensure_image_collection() - self._ensure_description_collection() self._ensure_image_dir() - self._initialized = True self._llm = LLMRequest(model=global_config.vlm, temperature=0.4, max_tokens=300, request_type="image") + + try: + db.connect(reuse_if_open=True) + db.create_tables([Images, ImageDescriptions], safe=True) + except Exception as e: + logger.error(f"数据库连接或表创建失败: {e}") + + self._initialized = True def _ensure_image_dir(self): """确保图像存储目录存在""" os.makedirs(self.IMAGE_DIR, exist_ok=True) - @staticmethod - def _ensure_image_collection(): - """确保images集合存在并创建索引""" - if "images" not in db.list_collection_names(): - db.create_collection("images") - - # 删除旧索引 - db.images.drop_indexes() - # 创建新的复合索引 - db.images.create_index([("hash", 1), ("type", 1)], unique=True) - db.images.create_index([("url", 1)]) - db.images.create_index([("path", 1)]) - - @staticmethod - def _ensure_description_collection(): - """确保image_descriptions集合存在并创建索引""" - if "image_descriptions" not in db.list_collection_names(): - db.create_collection("image_descriptions") - - # 删除旧索引 - db.image_descriptions.drop_indexes() - # 创建新的复合索引 - db.image_descriptions.create_index([("hash", 1), ("type", 1)], unique=True) - @staticmethod def _get_description_from_db(image_hash: str, description_type: str) -> Optional[str]: """从数据库获取图片描述 @@ -77,8 +59,15 @@ class ImageManager: Returns: Optional[str]: 描述文本,如果不存在则返回None """ - result = db.image_descriptions.find_one({"hash": image_hash, "type": description_type}) - return result["description"] if result else None + try: + record = ImageDescriptions.get_or_none( + (ImageDescriptions.hash == image_hash) & + (ImageDescriptions.type == description_type) + ) + return record.description if record else None + except Exception as e: + logger.error(f"从数据库获取描述失败 (Peewee): {str(e)}") + return None @staticmethod def _save_description_to_db(image_hash: str, description: str, description_type: str) -> None: @@ -90,20 +79,22 @@ class ImageManager: description_type: 描述类型 ('emoji' 或 'image') """ try: - db.image_descriptions.update_one( - {"hash": image_hash, "type": description_type}, - { - "$set": { - "description": description, - "timestamp": int(time.time()), - "hash": image_hash, # 确保hash字段存在 - "type": description_type, # 确保type字段存在 - } - }, - upsert=True, + current_timestamp = time.time() + defaults = { + 'description': description, + 'timestamp': current_timestamp + } + desc_obj, created = ImageDescriptions.get_or_create( + hash=image_hash, + type=description_type, + defaults=defaults ) + if not created: # 如果记录已存在,则更新 + desc_obj.description = description + desc_obj.timestamp = current_timestamp + desc_obj.save() except Exception as e: - logger.error(f"保存描述到数据库失败: {str(e)}") + logger.error(f"保存描述到数据库失败 (Peewee): {str(e)}") async def get_emoji_description(self, image_base64: str) -> str: """获取表情包描述,带查重和保存功能""" @@ -116,18 +107,25 @@ class ImageManager: # 查询缓存的描述 cached_description = self._get_description_from_db(image_hash, "emoji") if cached_description: - # logger.debug(f"缓存表情包描述: {cached_description}") return f"[表情包,含义看起来是:{cached_description}]" # 调用AI获取描述 if image_format == "gif" or image_format == "GIF": - image_base64 = self.transform_gif(image_base64) + image_base64_processed = self.transform_gif(image_base64) + if image_base64_processed is None: + logger.warning("GIF转换失败,无法获取描述") + return "[表情包(GIF处理失败)]" prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,使用1-2个词描述一下表情包表达的情感和内容,简短一些" - description, _ = await self._llm.generate_response_for_image(prompt, image_base64, "jpg") + description, _ = await self._llm.generate_response_for_image(prompt, image_base64_processed, "jpg") else: prompt = "这是一个表情包,请用使用几个词描述一下表情包所表达的情感和内容,简短一些" description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format) + + if description is None: + logger.warning("AI未能生成表情包描述") + return "[表情包(描述生成失败)]" + # 再次检查缓存,防止并发写入时重复生成 cached_description = self._get_description_from_db(image_hash, "emoji") if cached_description: logger.warning(f"虽然生成了描述,但是找到缓存表情包描述: {cached_description}") @@ -136,31 +134,37 @@ class ImageManager: # 根据配置决定是否保存图片 if global_config.save_emoji: # 生成文件名和路径 - timestamp = int(time.time()) - filename = f"{timestamp}_{image_hash[:8]}.{image_format}" - if not os.path.exists(os.path.join(self.IMAGE_DIR, "emoji")): - os.makedirs(os.path.join(self.IMAGE_DIR, "emoji")) - file_path = os.path.join(self.IMAGE_DIR, "emoji", filename) + current_timestamp = time.time() + filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}" + emoji_dir = os.path.join(self.IMAGE_DIR, "emoji") + os.makedirs(emoji_dir, exist_ok=True) + file_path = os.path.join(emoji_dir, filename) try: # 保存文件 with open(file_path, "wb") as f: f.write(image_bytes) - # 保存到数据库 - image_doc = { - "hash": image_hash, - "path": file_path, - "type": "emoji", - "description": description, - "timestamp": timestamp, - } - db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True) - logger.trace(f"保存表情包: {file_path}") + # 保存到数据库 (Images表) + try: + img_obj = Images.get((Images.hash == image_hash) & (Images.type == "emoji")) + img_obj.path = file_path + img_obj.description = description + img_obj.timestamp = current_timestamp + img_obj.save() + except Images.DoesNotExist: + Images.create( + hash=image_hash, + path=file_path, + type="emoji", + description=description, + timestamp=current_timestamp, + ) + logger.trace(f"保存表情包元数据: {file_path}") except Exception as e: - logger.error(f"保存表情包文件失败: {str(e)}") + logger.error(f"保存表情包文件或元数据失败: {str(e)}") - # 保存描述到数据库 + # 保存描述到数据库 (ImageDescriptions表) self._save_description_to_db(image_hash, description, "emoji") return f"[表情包:{description}]" @@ -187,7 +191,12 @@ class ImageManager: "请用中文描述这张图片的内容。如果有文字,请把文字都描述出来。并尝试猜测这个图片的含义。最多100个字。" ) description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format) + + if description is None: + logger.warning("AI未能生成图片描述") + return "[图片(描述生成失败)]" + # 再次检查缓存 cached_description = self._get_description_from_db(image_hash, "image") if cached_description: logger.warning(f"虽然生成了描述,但是找到缓存图片描述 {cached_description}") @@ -195,38 +204,40 @@ class ImageManager: logger.debug(f"描述是{description}") - if description is None: - logger.warning("AI未能生成图片描述") - return "[图片]" - # 根据配置决定是否保存图片 if global_config.save_pic: # 生成文件名和路径 - timestamp = int(time.time()) - filename = f"{timestamp}_{image_hash[:8]}.{image_format}" - if not os.path.exists(os.path.join(self.IMAGE_DIR, "image")): - os.makedirs(os.path.join(self.IMAGE_DIR, "image")) - file_path = os.path.join(self.IMAGE_DIR, "image", filename) + current_timestamp = time.time() + filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}" + image_dir = os.path.join(self.IMAGE_DIR, "image") + os.makedirs(image_dir, exist_ok=True) + file_path = os.path.join(image_dir, filename) try: # 保存文件 with open(file_path, "wb") as f: f.write(image_bytes) - # 保存到数据库 - image_doc = { - "hash": image_hash, - "path": file_path, - "type": "image", - "description": description, - "timestamp": timestamp, - } - db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True) - logger.trace(f"保存图片: {file_path}") + # 保存到数据库 (Images表) + try: + img_obj = Images.get((Images.hash == image_hash) & (Images.type == "image")) + img_obj.path = file_path + img_obj.description = description + img_obj.timestamp = current_timestamp + img_obj.save() + except Images.DoesNotExist: + Images.create( + hash=image_hash, + path=file_path, + type="image", + description=description, + timestamp=current_timestamp, + ) + logger.trace(f"保存图片元数据: {file_path}") except Exception as e: - logger.error(f"保存图片文件失败: {str(e)}") + logger.error(f"保存图片文件或元数据失败: {str(e)}") - # 保存描述到数据库 + # 保存描述到数据库 (ImageDescriptions表) self._save_description_to_db(image_hash, description, "image") return f"[图片:{description}]" diff --git a/src/common/database/database.py b/src/common/database/database.py index 752f746db..a2dab739d 100644 --- a/src/common/database/database.py +++ b/src/common/database/database.py @@ -1,5 +1,6 @@ import os from pymongo import MongoClient +from peewee import SqliteDatabase from pymongo.database import Database from rich.traceback import install @@ -57,4 +58,15 @@ class DBWrapper: # 全局数据库访问点 -db: Database = DBWrapper() +memory_db: Database = DBWrapper() + +# 定义数据库文件路径 +ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) +_DB_DIR = os.path.join(ROOT_PATH, "data") +_DB_FILE = os.path.join(_DB_DIR, "MaiBot.db") + +# 确保数据库目录存在 +os.makedirs(_DB_DIR, exist_ok=True) + +# 全局 Peewee SQLite 数据库访问点 +db = SqliteDatabase(_DB_FILE) diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py index c1135a33d..b46cace9f 100644 --- a/src/common/database/database_model.py +++ b/src/common/database/database_model.py @@ -1,9 +1,10 @@ -from peewee import Model, DoubleField, IntegerField, SqliteDatabase, BooleanField, TextField, FloatField - +from peewee import Model, DoubleField, IntegerField, BooleanField, TextField, FloatField, DateTimeField +from .database import db +import datetime # 请在此处定义您的数据库实例。 # 您需要取消注释并配置适合您的数据库的部分。 # 例如,对于 SQLite: -db = SqliteDatabase('my_application.db') +# db = SqliteDatabase('MaiBot.db') # # 对于 PostgreSQL: # db = PostgresqlDatabase('your_db_name', user='your_user', password='your_password', @@ -69,17 +70,16 @@ class LLMUsage(BaseModel): """ 用于存储 API 使用日志数据的模型。 """ - model_name = TextField() - user_id = TextField() - request_type = TextField() + model_name = TextField(index=True) # 添加索引 + user_id = TextField(index=True) # 添加索引 + request_type = TextField(index=True) # 添加索引 endpoint = TextField() prompt_tokens = IntegerField() completion_tokens = IntegerField() total_tokens = IntegerField() cost = DoubleField() status = TextField() - # timestamp: "$date": "2025-05-01T18:52:50.870Z" (存储为字符串) - timestamp = TextField() + timestamp = DateTimeField(index=True) # 更改为 DateTimeField 并添加索引 class Meta: # 如果 BaseModel.Meta.database 已设置,则此模型将继承该数据库配置。 @@ -177,6 +177,8 @@ class OnlineTime(BaseModel): # timestamp: "$date": "2025-05-01T18:52:18.191Z" (存储为字符串) timestamp = TextField() duration = IntegerField() # 时长,单位分钟 + start_timestamp = DateTimeField(default=datetime.datetime.now) + end_timestamp = DateTimeField(index=True) class Meta: # database = db # 继承自 BaseModel @@ -202,3 +204,39 @@ class PersonInfo(BaseModel): # database = db # 继承自 BaseModel table_name = 'person_info' +class Knowledges(BaseModel): + """ + 用于存储知识库条目的模型。 + """ + content = TextField() # 知识内容的文本 + embedding = TextField() # 知识内容的嵌入向量,存储为 JSON 字符串的浮点数列表 + # 可以添加其他元数据字段,如 source, create_time 等 + + class Meta: + # database = db # 继承自 BaseModel + table_name = 'knowledges' + + +class ThinkingLog(BaseModel): + chat_id = TextField(index=True) + trigger_text = TextField(null=True) + response_text = TextField(null=True) + + # Store complex dicts/lists as JSON strings + trigger_info_json = TextField(null=True) + response_info_json = TextField(null=True) + timing_results_json = TextField(null=True) + chat_history_json = TextField(null=True) + chat_history_in_thinking_json = TextField(null=True) + chat_history_after_response_json = TextField(null=True) + heartflow_data_json = TextField(null=True) + reasoning_data_json = TextField(null=True) + + # Add a timestamp for the log entry itself + # Ensure you have: from peewee import DateTimeField + # And: import datetime + created_at = DateTimeField(default=datetime.datetime.now) + + class Meta: + table_name = 'thinking_logs' + diff --git a/src/common/message_repository.py b/src/common/message_repository.py index 03eaba332..7d987ace9 100644 --- a/src/common/message_repository.py +++ b/src/common/message_repository.py @@ -1,11 +1,19 @@ -from common.database.database import db +from src.common.database.database_model import Messages # 更改导入 from src.common.logger import get_module_logger import traceback from typing import List, Any, Optional +from peewee import Model # 添加 Peewee Model 导入 logger = get_module_logger(__name__) +def _model_to_dict(model_instance: Model) -> dict[str, Any]: + """ + 将 Peewee 模型实例转换为字典。 + """ + return model_instance.__data__ + + def find_messages( message_filter: dict[str, Any], sort: Optional[List[tuple[str, int]]] = None, @@ -16,39 +24,72 @@ def find_messages( 根据提供的过滤器、排序和限制条件查找消息。 Args: - message_filter: MongoDB 查询过滤器。 - sort: MongoDB 排序条件列表,例如 [('time', 1)]。仅在 limit 为 0 时生效。 + message_filter: 查询过滤器字典,键为模型字段名,值为期望值。 + sort: 排序条件列表,例如 [('time', 1)] (1 for asc, -1 for desc)。仅在 limit 为 0 时生效。 limit: 返回的最大文档数,0表示不限制。 limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录(结果仍按时间正序排列)。默认为 'latest'。 Returns: - 消息文档列表,如果出错则返回空列表。 + 消息字典列表,如果出错则返回空列表。 """ try: - query = db.messages.find(message_filter) + query = Messages.select() + + # 应用过滤器 + if message_filter: + conditions = [] + for key, value in message_filter.items(): + if hasattr(Messages, key): + conditions.append(getattr(Messages, key) == value) + else: + logger.warning( + f"过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。" + ) + if conditions: + # 使用 *conditions 将所有条件以 AND 连接 + query = query.where(*conditions) if limit > 0: if limit_mode == "earliest": # 获取时间最早的 limit 条记录,已经是正序 - query = query.sort([("time", 1)]).limit(limit) - results = list(query) + query = query.order_by(Messages.time.asc()).limit(limit) + peewee_results = list(query) else: # 默认为 'latest' # 获取时间最晚的 limit 条记录 - query = query.sort([("time", -1)]).limit(limit) - latest_results = list(query) + query = query.order_by(Messages.time.desc()).limit(limit) + latest_results_peewee = list(query) # 将结果按时间正序排列 - # 假设消息文档中总是有 'time' 字段且可排序 - results = sorted(latest_results, key=lambda msg: msg.get("time")) + peewee_results = sorted( + latest_results_peewee, key=lambda msg: msg.time + ) else: # limit 为 0 时,应用传入的 sort 参数 if sort: - query = query.sort(sort) - results = list(query) + peewee_sort_terms = [] + for field_name, direction in sort: + if hasattr(Messages, field_name): + field = getattr(Messages, field_name) + if direction == 1: # ASC + peewee_sort_terms.append(field.asc()) + elif direction == -1: # DESC + peewee_sort_terms.append(field.desc()) + else: + logger.warning( + f"字段 '{field_name}' 的排序方向 '{direction}' 无效。将跳过此排序条件。" + ) + else: + logger.warning( + f"排序字段 '{field_name}' 在 Messages 模型中未找到。将跳过此排序条件。" + ) + if peewee_sort_terms: + query = query.order_by(*peewee_sort_terms) + peewee_results = list(query) + results = [_model_to_dict(msg) for msg in peewee_results] return results except Exception as e: log_message = ( - f"查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n" + f"使用 Peewee 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n" + traceback.format_exc() ) logger.error(log_message) @@ -60,18 +101,35 @@ def count_messages(message_filter: dict[str, Any]) -> int: 根据提供的过滤器计算消息数量。 Args: - message_filter: MongoDB 查询过滤器。 + message_filter: 查询过滤器字典,键为模型字段名,值为期望值。 Returns: 符合条件的消息数量,如果出错则返回 0。 """ try: - count = db.messages.count_documents(message_filter) + query = Messages.select() + + # 应用过滤器 + if message_filter: + conditions = [] + for key, value in message_filter.items(): + if hasattr(Messages, key): + conditions.append(getattr(Messages, key) == value) + else: + logger.warning( + f"计数时,过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。" + ) + if conditions: + query = query.where(*conditions) + + count = query.count() return count except Exception as e: - log_message = f"计数消息失败 (message_filter={message_filter}): {e}\n" + traceback.format_exc() + log_message = f"使用 Peewee 计数消息失败 (message_filter={message_filter}): {e}\n" + traceback.format_exc() logger.error(log_message) return 0 # 你可以在这里添加更多与 messages 集合相关的数据库操作函数,例如 find_one_message, insert_message 等。 +# 注意:对于 Peewee,插入操作通常是 Messages.create(...) 或 instance.save()。 +# 查找单个消息可以是 Messages.get_or_none(...) 或 query.first()。 diff --git a/src/experimental/PFC/chat_observer.py b/src/experimental/PFC/chat_observer.py index 704eeb330..e9e64053f 100644 --- a/src/experimental/PFC/chat_observer.py +++ b/src/experimental/PFC/chat_observer.py @@ -10,7 +10,7 @@ from src.experimental.PFC.chat_states import ( create_new_message_notification, create_cold_chat_notification, ) -from src.experimental.PFC.message_storage import MongoDBMessageStorage +from src.experimental.PFC.message_storage import PeeweeMessageStorage from rich.traceback import install install(extra_lines=3) @@ -53,7 +53,7 @@ class ChatObserver: self.stream_id = stream_id self.private_name = private_name - self.message_storage = MongoDBMessageStorage() + self.message_storage = PeeweeMessageStorage() # self.last_user_speak_time: Optional[float] = None # 对方上次发言时间 # self.last_bot_speak_time: Optional[float] = None # 机器人上次发言时间 diff --git a/src/experimental/PFC/message_storage.py b/src/experimental/PFC/message_storage.py index 24866e38c..6e109fac3 100644 --- a/src/experimental/PFC/message_storage.py +++ b/src/experimental/PFC/message_storage.py @@ -1,6 +1,8 @@ from abc import ABC, abstractmethod from typing import List, Dict, Any -from common.database.database import db +# from src.common.database.database import db # Peewee db 导入 +from src.common.database.database_model import Messages # Peewee Messages 模型导入 +from playhouse.shortcuts import model_to_dict # 用于将模型实例转换为字典 class MessageStorage(ABC): @@ -47,28 +49,35 @@ class MessageStorage(ABC): pass -class MongoDBMessageStorage(MessageStorage): - """MongoDB消息存储实现""" +class PeeweeMessageStorage(MessageStorage): + """Peewee消息存储实现""" async def get_messages_after(self, chat_id: str, message_time: float) -> List[Dict[str, Any]]: - query = {"chat_id": chat_id, "time": {"$gt": message_time}} - # print(f"storage_check_message: {message_time}") + query = Messages.select().where( + (Messages.chat_id == chat_id) & + (Messages.time > message_time) + ).order_by(Messages.time.asc()) - return list(db.messages.find(query).sort("time", 1)) + # print(f"storage_check_message: {message_time}") + messages_models = list(query) + return [model_to_dict(msg) for msg in messages_models] async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]: - query = {"chat_id": chat_id, "time": {"$lt": time_point}} - - messages = list(db.messages.find(query).sort("time", -1).limit(limit)) + query = Messages.select().where( + (Messages.chat_id == chat_id) & + (Messages.time < time_point) + ).order_by(Messages.time.desc()).limit(limit) + messages_models = list(query) # 将消息按时间正序排列 - messages.reverse() - return messages + messages_models.reverse() + return [model_to_dict(msg) for msg in messages_models] async def has_new_messages(self, chat_id: str, after_time: float) -> bool: - query = {"chat_id": chat_id, "time": {"$gt": after_time}} - - return db.messages.find_one(query) is not None + return Messages.select().where( + (Messages.chat_id == chat_id) & + (Messages.time > after_time) + ).exists() # # 创建一个内存消息存储实现,用于测试 diff --git a/src/tools/tool_can_use/get_knowledge.py b/src/tools/tool_can_use/get_knowledge.py index 2a4922f9f..4ff62b7c2 100644 --- a/src/tools/tool_can_use/get_knowledge.py +++ b/src/tools/tool_can_use/get_knowledge.py @@ -1,8 +1,10 @@ from src.tools.tool_can_use.base_tool import BaseTool from src.chat.utils.utils import get_embedding -from common.database.database import db +from src.common.database.database_model import Knowledges # Updated import from src.common.logger_manager import get_logger -from typing import Any, Union +from typing import Any, Union, List # Added List +import json # Added for parsing embedding +import math # Added for cosine similarity logger = get_logger("get_knowledge_tool") @@ -30,6 +32,7 @@ class SearchKnowledgeTool(BaseTool): Returns: dict: 工具执行结果 """ + query = "" # Initialize query to ensure it's defined in except block try: query = function_args.get("query") threshold = function_args.get("threshold", 0.4) @@ -48,9 +51,19 @@ class SearchKnowledgeTool(BaseTool): logger.error(f"知识库搜索工具执行失败: {str(e)}") return {"type": "info", "id": query, "content": f"知识库搜索失败,炸了: {str(e)}"} + @staticmethod + def _cosine_similarity(vec1: List[float], vec2: List[float]) -> float: + """计算两个向量之间的余弦相似度""" + dot_product = sum(p * q for p, q in zip(vec1, vec2)) + magnitude1 = math.sqrt(sum(p * p for p in vec1)) + magnitude2 = math.sqrt(sum(q * q for q in vec2)) + if magnitude1 == 0 or magnitude2 == 0: + return 0.0 + return dot_product / (magnitude1 * magnitude2) + @staticmethod def get_info_from_db( - query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False + query_embedding: list[float], limit: int = 1, threshold: float = 0.5, return_raw: bool = False ) -> Union[str, list]: """从数据库中获取相关信息 @@ -66,66 +79,49 @@ class SearchKnowledgeTool(BaseTool): if not query_embedding: return "" if not return_raw else [] - # 使用余弦相似度计算 - pipeline = [ - { - "$addFields": { - "dotProduct": { - "$reduce": { - "input": {"$range": [0, {"$size": "$embedding"}]}, - "initialValue": 0, - "in": { - "$add": [ - "$$value", - { - "$multiply": [ - {"$arrayElemAt": ["$embedding", "$$this"]}, - {"$arrayElemAt": [query_embedding, "$$this"]}, - ] - }, - ] - }, - } - }, - "magnitude1": { - "$sqrt": { - "$reduce": { - "input": "$embedding", - "initialValue": 0, - "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}, - } - } - }, - "magnitude2": { - "$sqrt": { - "$reduce": { - "input": query_embedding, - "initialValue": 0, - "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}, - } - } - }, - } - }, - {"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}}, - { - "$match": { - "similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果 - } - }, - {"$sort": {"similarity": -1}}, - {"$limit": limit}, - {"$project": {"content": 1, "similarity": 1}}, - ] + similar_items = [] + try: + all_knowledges = Knowledges.select() + for item in all_knowledges: + try: + item_embedding_str = item.embedding + if not item_embedding_str: + logger.warning(f"Knowledge item ID {item.id} has empty embedding string.") + continue + item_embedding = json.loads(item_embedding_str) + if not isinstance(item_embedding, list) or not all(isinstance(x, (int, float)) for x in item_embedding): + logger.warning(f"Knowledge item ID {item.id} has invalid embedding format after JSON parsing.") + continue + except json.JSONDecodeError: + logger.warning(f"Failed to parse embedding for knowledge item ID {item.id}") + continue + except AttributeError: + logger.warning(f"Knowledge item ID {item.id} missing 'embedding' attribute or it's not a string.") + continue - results = list(db.knowledges.aggregate(pipeline)) - logger.debug(f"知识库查询结果数量: {len(results)}") + similarity = SearchKnowledgeTool._cosine_similarity(query_embedding, item_embedding) + + if similarity >= threshold: + similar_items.append({"content": item.content, "similarity": similarity, "raw_item": item}) + + # 按相似度降序排序 + similar_items.sort(key=lambda x: x["similarity"], reverse=True) + + # 应用限制 + results = similar_items[:limit] + logger.debug(f"知识库查询后,符合条件的结果数量: {len(results)}") + + except Exception as e: + logger.error(f"从 Peewee 数据库获取知识信息失败: {str(e)}") + return "" if not return_raw else [] if not results: return "" if not return_raw else [] if return_raw: - return results + # Peewee 模型实例不能直接序列化为 JSON,如果需要原始模型,调用者需要处理 + # 这里返回包含内容和相似度的字典列表 + return [{"content": r["content"], "similarity": r["similarity"]} for r in results] else: # 返回所有找到的内容,用换行分隔 return "\n".join(str(result["content"]) for result in results) From 2051b011b12090ae9765b865a0f30131be608134 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Wed, 14 May 2025 23:04:22 +0800 Subject: [PATCH 06/57] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=BA=93=E8=A1=A8=E5=88=9B=E5=BB=BA=E5=92=8C=E5=88=9D?= =?UTF-8?q?=E5=A7=8B=E5=8C=96=E5=8A=9F=E8=83=BD=EF=BC=8C=E7=A1=AE=E4=BF=9D?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E8=A1=A8=E5=AD=98=E5=9C=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/utils/statistic.py | 145 ++++++++++++-------------- src/common/database/database_model.py | 60 +++++++++++ 2 files changed, 128 insertions(+), 77 deletions(-) diff --git a/src/chat/utils/statistic.py b/src/chat/utils/statistic.py index 9a0131f74..88329c3f4 100644 --- a/src/chat/utils/statistic.py +++ b/src/chat/utils/statistic.py @@ -6,7 +6,7 @@ from src.common.logger import get_module_logger from src.manager.async_task_manager import AsyncTask from ...common.database.database import db # This db is the Peewee database instance -from ...common.database.database_model import OnlineTime # Import the Peewee model +from ...common.database.database_model import OnlineTime, LLMUsage, Messages # Import the Peewee model from src.manager.local_store_manager import local_storage logger = get_module_logger("maibot_statistic") @@ -195,35 +195,28 @@ class StatisticOutputTask(AsyncTask): :param collect_period: 统计时间段 """ - if len(collect_period) <= 0: + if not collect_period: return {} - else: - # 排序-按照时间段开始时间降序排列(最晚的时间段在前) - collect_period.sort(key=lambda x: x[1], reverse=True) + + # 排序-按照时间段开始时间降序排列(最晚的时间段在前) + collect_period.sort(key=lambda x: x[1], reverse=True) stats = { period_key: { - # 总LLM请求数 TOTAL_REQ_CNT: 0, - # 请求次数统计 REQ_CNT_BY_TYPE: defaultdict(int), REQ_CNT_BY_USER: defaultdict(int), REQ_CNT_BY_MODEL: defaultdict(int), - # 输入Token数 IN_TOK_BY_TYPE: defaultdict(int), IN_TOK_BY_USER: defaultdict(int), IN_TOK_BY_MODEL: defaultdict(int), - # 输出Token数 OUT_TOK_BY_TYPE: defaultdict(int), OUT_TOK_BY_USER: defaultdict(int), OUT_TOK_BY_MODEL: defaultdict(int), - # 总Token数 TOTAL_TOK_BY_TYPE: defaultdict(int), TOTAL_TOK_BY_USER: defaultdict(int), TOTAL_TOK_BY_MODEL: defaultdict(int), - # 总开销 TOTAL_COST: 0.0, - # 请求开销统计 COST_BY_TYPE: defaultdict(float), COST_BY_USER: defaultdict(float), COST_BY_MODEL: defaultdict(float), @@ -232,26 +225,26 @@ class StatisticOutputTask(AsyncTask): } # 以最早的时间戳为起始时间获取记录 - for record in db.llm_usage.find({"timestamp": {"$gte": collect_period[-1][1]}}): - record_timestamp = record.get("timestamp") + # Assuming LLMUsage.timestamp is a DateTimeField + query_start_time = collect_period[-1][1] + for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time): + record_timestamp = record.timestamp # This is already a datetime object for idx, (_, period_start) in enumerate(collect_period): if record_timestamp >= period_start: - # 如果记录时间在当前时间段内,则它一定在更早的时间段内 - # 因此,我们可以直接跳过更早的时间段的判断,直接更新当前以及更早时间段的统计数据 for period_key, _ in collect_period[idx:]: stats[period_key][TOTAL_REQ_CNT] += 1 - request_type = record.get("request_type", "unknown") # 请求类型 - user_id = str(record.get("user_id", "unknown")) # 用户ID - model_name = record.get("model_name", "unknown") # 模型名称 + request_type = record.request_type or "unknown" + user_id = record.user_id or "unknown" # user_id is TextField, already string + model_name = record.model_name or "unknown" stats[period_key][REQ_CNT_BY_TYPE][request_type] += 1 stats[period_key][REQ_CNT_BY_USER][user_id] += 1 stats[period_key][REQ_CNT_BY_MODEL][model_name] += 1 - prompt_tokens = record.get("prompt_tokens", 0) # 输入Token数 - completion_tokens = record.get("completion_tokens", 0) # 输出Token数 - total_tokens = prompt_tokens + completion_tokens # Token总数 = 输入Token数 + 输出Token数 + prompt_tokens = record.prompt_tokens or 0 + completion_tokens = record.completion_tokens or 0 + total_tokens = prompt_tokens + completion_tokens stats[period_key][IN_TOK_BY_TYPE][request_type] += prompt_tokens stats[period_key][IN_TOK_BY_USER][user_id] += prompt_tokens @@ -265,13 +258,12 @@ class StatisticOutputTask(AsyncTask): stats[period_key][TOTAL_TOK_BY_USER][user_id] += total_tokens stats[period_key][TOTAL_TOK_BY_MODEL][model_name] += total_tokens - cost = record.get("cost", 0.0) + cost = record.cost or 0.0 stats[period_key][TOTAL_COST] += cost stats[period_key][COST_BY_TYPE][request_type] += cost stats[period_key][COST_BY_USER][user_id] += cost stats[period_key][COST_BY_MODEL][model_name] += cost - break # 取消更早时间段的判断 - + break return stats @staticmethod @@ -281,39 +273,38 @@ class StatisticOutputTask(AsyncTask): :param collect_period: 统计时间段 """ - if len(collect_period) <= 0: + if not collect_period: return {} - else: - # 排序-按照时间段开始时间降序排列(最晚的时间段在前) - collect_period.sort(key=lambda x: x[1], reverse=True) + + collect_period.sort(key=lambda x: x[1], reverse=True) stats = { period_key: { - # 在线时间统计 ONLINE_TIME: 0.0, } for period_key, _ in collect_period } - # 统计在线时间 - for record in db.online_time.find({"end_timestamp": {"$gte": collect_period[-1][1]}}): - end_timestamp: datetime = record.get("end_timestamp") - for idx, (_, period_start) in enumerate(collect_period): - if end_timestamp >= period_start: - # 由于end_timestamp会超前标记时间,所以我们需要判断是否晚于当前时间,如果是,则使用当前时间作为结束时间 - end_timestamp = min(end_timestamp, now) - # 如果记录时间在当前时间段内,则它一定在更早的时间段内 - # 因此,我们可以直接跳过更早的时间段的判断,直接更新当前以及更早时间段的统计数据 - for period_key, _period_start in collect_period[idx:]: - start_timestamp: datetime = record.get("start_timestamp") - if start_timestamp < _period_start: - # 如果开始时间在查询边界之前,则使用开始时间 - stats[period_key][ONLINE_TIME] += (end_timestamp - _period_start).total_seconds() - else: - # 否则,使用开始时间 - stats[period_key][ONLINE_TIME] += (end_timestamp - start_timestamp).total_seconds() - break # 取消更早时间段的判断 + query_start_time = collect_period[-1][1] + # Assuming OnlineTime.end_timestamp is a DateTimeField + for record in OnlineTime.select().where(OnlineTime.end_timestamp >= query_start_time): + # record.end_timestamp and record.start_timestamp are datetime objects + record_end_timestamp = record.end_timestamp + record_start_timestamp = record.start_timestamp + for idx, (_, period_boundary_start) in enumerate(collect_period): + if record_end_timestamp >= period_boundary_start: + # Calculate effective end time for this record in relation to 'now' + effective_end_time = min(record_end_timestamp, now) + + for period_key, current_period_start_time in collect_period[idx:]: + # Determine the portion of the record that falls within this specific statistical period + overlap_start = max(record_start_timestamp, current_period_start_time) + overlap_end = effective_end_time # Already capped by 'now' and record's own end + + if overlap_end > overlap_start: + stats[period_key][ONLINE_TIME] += (overlap_end - overlap_start).total_seconds() + break return stats def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]: @@ -322,55 +313,55 @@ class StatisticOutputTask(AsyncTask): :param collect_period: 统计时间段 """ - if len(collect_period) <= 0: + if not collect_period: return {} - else: - # 排序-按照时间段开始时间降序排列(最晚的时间段在前) - collect_period.sort(key=lambda x: x[1], reverse=True) + + collect_period.sort(key=lambda x: x[1], reverse=True) stats = { period_key: { - # 消息统计 TOTAL_MSG_CNT: 0, MSG_CNT_BY_CHAT: defaultdict(int), } for period_key, _ in collect_period } - # 统计消息量 - for message in db.messages.find({"time": {"$gte": collect_period[-1][1].timestamp()}}): - chat_info = message.get("chat_info", None) # 聊天信息 - user_info = message.get("user_info", None) # 用户信息(消息发送人) - message_time = message.get("time", 0) # 消息时间 + query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp) + for message in Messages.select().where(Messages.time >= query_start_timestamp): + message_time_ts = message.time # This is a float timestamp - group_info = chat_info.get("group_info") if chat_info else None # 尝试获取群聊信息 - if group_info is not None: - # 若有群聊信息 - chat_id = f"g{group_info.get('group_id')}" - chat_name = group_info.get("group_name", f"群{group_info.get('group_id')}") - elif user_info: - # 若没有群聊信息,则尝试获取用户信息 - chat_id = f"u{user_info['user_id']}" - chat_name = user_info["user_nickname"] + chat_id = None + chat_name = None + + # Logic based on Peewee model structure, aiming to replicate original intent + if message.chat_info_group_id: + chat_id = f"g{message.chat_info_group_id}" + chat_name = message.chat_info_group_name or f"群{message.chat_info_group_id}" + elif message.user_id: # Fallback to sender's info for chat_id if not a group_info based chat + # This uses the message SENDER's ID as per original logic's fallback + chat_id = f"u{message.user_id}" # SENDER's user_id + chat_name = message.user_nickname # SENDER's nickname else: - continue # 如果没有群组信息也没有用户信息,则跳过 + # If neither group_id nor sender_id is available for chat identification + logger.warning(f"Message (PK: {message.id if hasattr(message, 'id') else 'N/A'}) lacks group_id and user_id for chat stats.") + continue + + if not chat_id: # Should not happen if above logic is correct + continue + # Update name_mapping if chat_id in self.name_mapping: - if chat_name != self.name_mapping[chat_id][0] and message_time > self.name_mapping[chat_id][1]: - # 如果用户名称不同,且新消息时间晚于之前记录的时间,则更新用户名称 - self.name_mapping[chat_id] = (chat_name, message_time) + if chat_name != self.name_mapping[chat_id][0] and message_time_ts > self.name_mapping[chat_id][1]: + self.name_mapping[chat_id] = (chat_name, message_time_ts) else: - self.name_mapping[chat_id] = (chat_name, message_time) + self.name_mapping[chat_id] = (chat_name, message_time_ts) - for idx, (_, period_start) in enumerate(collect_period): - if message_time >= period_start.timestamp(): - # 如果记录时间在当前时间段内,则它一定在更早的时间段内 - # 因此,我们可以直接跳过更早的时间段的判断,直接更新当前以及更早时间段的统计数据 + for idx, (_, period_start_dt) in enumerate(collect_period): + if message_time_ts >= period_start_dt.timestamp(): for period_key, _ in collect_period[idx:]: stats[period_key][TOTAL_MSG_CNT] += 1 stats[period_key][MSG_CNT_BY_CHAT][chat_id] += 1 break - return stats def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]: diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py index b46cace9f..89e047414 100644 --- a/src/common/database/database_model.py +++ b/src/common/database/database_model.py @@ -240,3 +240,63 @@ class ThinkingLog(BaseModel): class Meta: table_name = 'thinking_logs' +def create_tables(): + """ + 创建所有在模型中定义的数据库表。 + """ + with db: + db.create_tables([ + ChatStreams, + LLMUsage, + Emoji, + Messages, + Images, + ImageDescriptions, + OnlineTime, + PersonInfo, + Knowledges, + ThinkingLog + ]) + +def initialize_database(): + """ + 检查所有定义的表是否存在,如果不存在则创建它们。 + """ + models = [ + ChatStreams, + LLMUsage, + Emoji, + Messages, + Images, + ImageDescriptions, + OnlineTime, + PersonInfo, + Knowledges, + ThinkingLog + ] + + needs_creation = False + try: + with db: # 管理 table_exists 检查的连接 + for model in models: + if not db.table_exists(model): + print(f"表 '{model._meta.table_name}' 未找到。") + needs_creation = True + break # 一个表丢失,无需进一步检查。 + except Exception as e: + print(f"检查表是否存在时出错: {e}") + # 如果检查失败(例如数据库不可用),则退出 + return + + if needs_creation: + print("正在初始化数据库:一个或多个表丢失。正在尝试创建所有定义的表...") + try: + create_tables() # 此函数有其自己的 'with db:' 上下文管理。 + print("数据库表创建过程完成。") + except Exception as e: + print(f"创建表期间出错: {e}") + else: + print("所有数据库表均已存在。") + +# 模块加载时调用初始化函数 +initialize_database() From 0352340b72dd5d621811bcb881921fb883d84d45 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Wed, 14 May 2025 23:09:14 +0800 Subject: [PATCH 07/57] fix: Ruff --- src/chat/focus_chat/planners/action_factory.py | 6 +----- src/chat/focus_chat/planners/actions/base_action.py | 4 ++-- src/chat/focus_chat/planners/actions/reply_action.py | 4 ++-- src/chat/focus_chat/planners/planner.py | 2 +- src/chat/utils/chat_message_builder.py | 4 ++-- 5 files changed, 8 insertions(+), 12 deletions(-) diff --git a/src/chat/focus_chat/planners/action_factory.py b/src/chat/focus_chat/planners/action_factory.py index 257156a25..2aedecb22 100644 --- a/src/chat/focus_chat/planners/action_factory.py +++ b/src/chat/focus_chat/planners/action_factory.py @@ -1,6 +1,4 @@ -from typing import Dict, List, Optional, Callable, Coroutine, Type, Any, Union -import os -import importlib +from typing import Dict, List, Optional, Callable, Coroutine, Type, Any from src.chat.focus_chat.planners.actions.base_action import BaseAction, _ACTION_REGISTRY, _DEFAULT_ACTIONS from src.chat.heart_flow.observation.observation import Observation from src.chat.focus_chat.expressors.default_expressor import DefaultExpressor @@ -9,8 +7,6 @@ from src.chat.focus_chat.heartFC_Cycleinfo import CycleDetail from src.common.logger_manager import get_logger # 导入动作类,确保装饰器被执行 -from src.chat.focus_chat.planners.actions.reply_action import ReplyAction -from src.chat.focus_chat.planners.actions.no_reply_action import NoReplyAction logger = get_logger("action_factory") diff --git a/src/chat/focus_chat/planners/actions/base_action.py b/src/chat/focus_chat/planners/actions/base_action.py index 7c77c300c..629bdbd43 100644 --- a/src/chat/focus_chat/planners/actions/base_action.py +++ b/src/chat/focus_chat/planners/actions/base_action.py @@ -25,8 +25,8 @@ def register_action(cls): logger.error(f"动作类 {cls.__name__} 缺少必要的属性: action_name 或 action_description") return cls - action_name = getattr(cls, "action_name") - action_description = getattr(cls, "action_description") + action_name = getattr(cls, "action_name") #noqa + action_description = getattr(cls, "action_description") #noqa is_default = getattr(cls, "default", False) if not action_name or not action_description: diff --git a/src/chat/focus_chat/planners/actions/reply_action.py b/src/chat/focus_chat/planners/actions/reply_action.py index 7b2e88fa0..80624d487 100644 --- a/src/chat/focus_chat/planners/actions/reply_action.py +++ b/src/chat/focus_chat/planners/actions/reply_action.py @@ -2,9 +2,9 @@ # -*- coding: utf-8 -*- from src.common.logger_manager import get_logger -from src.chat.utils.timer_calculator import Timer +# from src.chat.utils.timer_calculator import Timer from src.chat.focus_chat.planners.actions.base_action import BaseAction, register_action -from typing import Tuple, List, Optional +from typing import Tuple, List from src.chat.heart_flow.observation.observation import Observation from src.chat.focus_chat.expressors.default_expressor import DefaultExpressor from src.chat.message_receive.chat_stream import ChatStream diff --git a/src/chat/focus_chat/planners/planner.py b/src/chat/focus_chat/planners/planner.py index bb87e1da7..cc79635d8 100644 --- a/src/chat/focus_chat/planners/planner.py +++ b/src/chat/focus_chat/planners/planner.py @@ -4,7 +4,7 @@ from typing import List, Dict, Any, Optional from rich.traceback import install from src.chat.models.utils_model import LLMRequest from src.config.config import global_config -from src.chat.focus_chat.heartflow_prompt_builder import prompt_builder +# from src.chat.focus_chat.heartflow_prompt_builder import prompt_builder from src.chat.focus_chat.info.info_base import InfoBase from src.chat.focus_chat.info.obs_info import ObsInfo from src.chat.focus_chat.info.cycle_info import CycleInfo diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index 15b1e4fc6..5a442615d 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -454,7 +454,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str: def reply_replacer(match): # aaa = match.group(1) bbb = match.group(2) - anon_reply = get_anon_name(platform, bbb) + anon_reply = get_anon_name(platform, bbb) #noqa return f"回复 {anon_reply}" content = re.sub(reply_pattern, reply_replacer, content, count=1) @@ -465,7 +465,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str: def at_replacer(match): # aaa = match.group(1) bbb = match.group(2) - anon_at = get_anon_name(platform, bbb) + anon_at = get_anon_name(platform, bbb) #noqa return f"@{anon_at}" content = re.sub(at_pattern, at_replacer, content) From 17d19e7cacc6b37538e867ed32026384104bd0fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Wed, 14 May 2025 23:11:19 +0800 Subject: [PATCH 08/57] fix: Ruff x2 --- src/chat/focus_chat/planners/action_factory.py | 2 +- src/chat/focus_chat/planners/planner.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/chat/focus_chat/planners/action_factory.py b/src/chat/focus_chat/planners/action_factory.py index 2aedecb22..bf77140cb 100644 --- a/src/chat/focus_chat/planners/action_factory.py +++ b/src/chat/focus_chat/planners/action_factory.py @@ -1,5 +1,5 @@ from typing import Dict, List, Optional, Callable, Coroutine, Type, Any -from src.chat.focus_chat.planners.actions.base_action import BaseAction, _ACTION_REGISTRY, _DEFAULT_ACTIONS +from src.chat.focus_chat.planners.actions.base_action import BaseAction, _ACTION_REGISTRY from src.chat.heart_flow.observation.observation import Observation from src.chat.focus_chat.expressors.default_expressor import DefaultExpressor from src.chat.message_receive.chat_stream import ChatStream diff --git a/src/chat/focus_chat/planners/planner.py b/src/chat/focus_chat/planners/planner.py index cc79635d8..05c9276c6 100644 --- a/src/chat/focus_chat/planners/planner.py +++ b/src/chat/focus_chat/planners/planner.py @@ -103,7 +103,7 @@ class ActionPlanner: cycle_info = info.get_observe_info() elif isinstance(info, StructuredInfo): logger.debug(f"{self.log_prefix} 结构化信息: {info}") - structured_info = info.get_data() + _structured_info = info.get_data() current_available_actions = self.action_manager.get_using_actions() From fb6094d269e3b1d096261de8d6aec75ca765cdec Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 14 May 2025 15:11:33 +0000 Subject: [PATCH 09/57] =?UTF-8?q?=F0=9F=A4=96=20=E8=87=AA=E5=8A=A8?= =?UTF-8?q?=E6=A0=BC=E5=BC=8F=E5=8C=96=E4=BB=A3=E7=A0=81=20[skip=20ci]?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/emoji_system/emoji_manager.py | 55 +++++----- src/chat/focus_chat/heartFC_chat.py | 3 - .../focus_chat/heartflow_prompt_builder.py | 26 +++-- .../focus_chat/planners/action_factory.py | 77 ++++++------- .../planners/actions/base_action.py | 33 +++--- .../planners/actions/no_reply_action.py | 4 +- .../planners/actions/reply_action.py | 17 +-- src/chat/focus_chat/planners/planner.py | 30 +++-- src/chat/person_info/person_info.py | 23 +++- src/chat/utils/chat_message_builder.py | 4 +- src/chat/utils/info_catcher.py | 30 ++--- src/chat/utils/statistic.py | 56 +++++----- src/chat/utils/utils_image.py | 20 ++-- src/common/database/database_model.py | 103 +++++++++++------- src/common/message_repository.py | 22 +--- src/experimental/PFC/message_storage.py | 25 +++-- src/tools/tool_can_use/get_knowledge.py | 4 +- 17 files changed, 278 insertions(+), 254 deletions(-) diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py index 68fa5de44..77835d1fb 100644 --- a/src/chat/emoji_system/emoji_manager.py +++ b/src/chat/emoji_system/emoji_manager.py @@ -148,20 +148,21 @@ class MaiEmoji: # 准备数据库记录 for emoji collection emotion_str = ",".join(self.emotion) if self.emotion else "" - Emoji.create(hash=self.hash, - full_path=self.full_path, - format=self.format, - description=self.description, - emotion=emotion_str, # Store as comma-separated string - query_count=0, # Default value - is_registered=True, - is_banned=False, # Default value - record_time=self.register_time, # Use MaiEmoji's register_time for DB record_time - register_time=self.register_time, - usage_count=self.usage_count, - last_used_time=self.last_used_time, - ) - + Emoji.create( + hash=self.hash, + full_path=self.full_path, + format=self.format, + description=self.description, + emotion=emotion_str, # Store as comma-separated string + query_count=0, # Default value + is_registered=True, + is_banned=False, # Default value + record_time=self.register_time, # Use MaiEmoji's register_time for DB record_time + register_time=self.register_time, + usage_count=self.usage_count, + last_used_time=self.last_used_time, + ) + logger.success(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})") return True @@ -197,10 +198,10 @@ class MaiEmoji: # 2. 删除数据库记录 try: will_delete_emoji = Emoji.get(Emoji.hash == self.hash) - result = will_delete_emoji.delete_instance() # Returns the number of rows deleted. + result = will_delete_emoji.delete_instance() # Returns the number of rows deleted. except Emoji.DoesNotExist: logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。") - result = 0 # Indicate no DB record was deleted + result = 0 # Indicate no DB record was deleted if result > 0: logger.info(f"[删除] 表情包数据库记录 {self.filename} (Hash: {self.hash})") @@ -245,12 +246,14 @@ def _to_emoji_objects(data): emoji_objects = [] load_errors = 0 # data is now an iterable of Peewee Emoji model instances - emoji_data_list = list(data) + emoji_data_list = list(data) - for emoji_data in emoji_data_list: # emoji_data is an Emoji model instance + for emoji_data in emoji_data_list: # emoji_data is an Emoji model instance full_path = emoji_data.full_path if not full_path: - logger.warning(f"[加载错误] 数据库记录缺少 'full_path' 字段: ID {emoji_data.id if hasattr(emoji_data, 'id') else 'Unknown'}") + logger.warning( + f"[加载错误] 数据库记录缺少 'full_path' 字段: ID {emoji_data.id if hasattr(emoji_data, 'id') else 'Unknown'}" + ) load_errors += 1 continue @@ -265,9 +268,9 @@ def _to_emoji_objects(data): emoji.description = emoji_data.description # Deserialize emotion string from DB to list - emoji.emotion = emoji_data.emotion.split(',') if emoji_data.emotion else [] + emoji.emotion = emoji_data.emotion.split(",") if emoji_data.emotion else [] emoji.usage_count = emoji_data.usage_count - + db_last_used_time = emoji_data.last_used_time db_register_time = emoji_data.register_time @@ -275,7 +278,7 @@ def _to_emoji_objects(data): emoji.last_used_time = db_last_used_time if db_last_used_time is not None else emoji.register_time # If register_time from DB is None, use MaiEmoji's initialized register_time (which is time.time()) emoji.register_time = db_register_time if db_register_time is not None else emoji.register_time - + emoji.format = emoji_data.format emoji_objects.append(emoji) @@ -385,7 +388,7 @@ class EmojiManager: # Ensure Peewee database connection is up and tables are created if not peewee_db.is_closed(): peewee_db.connect(reuse_if_open=True) - Emoji.create_table(safe=True) # Ensures table exists + Emoji.create_table(safe=True) # Ensures table exists _ensure_emoji_dir() self._initialized = True @@ -404,8 +407,8 @@ class EmojiManager: try: emoji_update = Emoji.get(Emoji.hash == emoji_hash) emoji_update.usage_count += 1 - emoji_update.last_used_time = time.time() # Update last used time - emoji_update.save() # Persist changes to DB + emoji_update.last_used_time = time.time() # Update last used time + emoji_update.save() # Persist changes to DB except Emoji.DoesNotExist: logger.error(f"记录表情使用失败: 未找到 hash 为 {emoji_hash} 的表情包") except Exception as e: @@ -674,7 +677,7 @@ class EmojiManager: "[查询] 未提供 hash,将尝试加载所有表情包,建议使用 get_all_emoji_from_db 更新管理器状态。" ) query = Emoji.select() - + emoji_peewee_instances = query emoji_objects, load_errors = _to_emoji_objects(emoji_peewee_instances) diff --git a/src/chat/focus_chat/heartFC_chat.py b/src/chat/focus_chat/heartFC_chat.py index 4a28652d1..ff4f7fdb0 100644 --- a/src/chat/focus_chat/heartFC_chat.py +++ b/src/chat/focus_chat/heartFC_chat.py @@ -91,7 +91,6 @@ class HeartFChatting: self.action_manager = ActionManager() self.action_planner = ActionPlanner(log_prefix=self.log_prefix, action_manager=self.action_manager) - # --- 处理器列表 --- self.processors: List[BaseProcessor] = [] self._register_default_processors() @@ -526,5 +525,3 @@ class HeartFChatting: if last_n is not None: history = history[-last_n:] return [cycle.to_dict() for cycle in history] - - diff --git a/src/chat/focus_chat/heartflow_prompt_builder.py b/src/chat/focus_chat/heartflow_prompt_builder.py index 141d850ab..945f587bb 100644 --- a/src/chat/focus_chat/heartflow_prompt_builder.py +++ b/src/chat/focus_chat/heartflow_prompt_builder.py @@ -7,12 +7,14 @@ from src.chat.person_info.relationship_manager import relationship_manager from src.chat.utils.utils import get_embedding import time from typing import Union, Optional + # from common.database.database import db from src.chat.utils.utils import get_recent_group_speaker from src.manager.mood_manager import mood_manager from src.chat.memory_system.Hippocampus import HippocampusManager from src.chat.knowledge.knowledge_lib import qa_manager from src.chat.focus_chat.expressors.exprssion_learner import expression_learner + # import traceback import random import json @@ -614,7 +616,7 @@ class PromptBuilder: return "" if not return_raw else [] query_embedding_magnitude = math.sqrt(sum(x * x for x in query_embedding)) - if query_embedding_magnitude == 0: # Avoid division by zero + if query_embedding_magnitude == 0: # Avoid division by zero return "" if not return_raw else [] for knowledge_item in all_knowledges: @@ -623,35 +625,35 @@ class PromptBuilder: db_embedding = json.loads(db_embedding_str) if len(db_embedding) != len(query_embedding): - logger.warning(f"Embedding length mismatch for knowledge ID {knowledge_item.id if hasattr(knowledge_item, 'id') else 'N/A'}. Skipping.") + logger.warning( + f"Embedding length mismatch for knowledge ID {knowledge_item.id if hasattr(knowledge_item, 'id') else 'N/A'}. Skipping." + ) continue - + # Calculate Cosine Similarity dot_product = sum(q * d for q, d in zip(query_embedding, db_embedding)) db_embedding_magnitude = math.sqrt(sum(x * x for x in db_embedding)) - if db_embedding_magnitude == 0: # Avoid division by zero + if db_embedding_magnitude == 0: # Avoid division by zero similarity = 0.0 else: similarity = dot_product / (query_embedding_magnitude * db_embedding_magnitude) - + if similarity >= threshold: - results_with_similarity.append({ - "content": knowledge_item.content, - "similarity": similarity - }) + results_with_similarity.append({"content": knowledge_item.content, "similarity": similarity}) except json.JSONDecodeError: - logger.error(f"Failed to parse embedding for knowledge ID {knowledge_item.id if hasattr(knowledge_item, 'id') else 'N/A'}") + logger.error( + f"Failed to parse embedding for knowledge ID {knowledge_item.id if hasattr(knowledge_item, 'id') else 'N/A'}" + ) except Exception as e: logger.error(f"Error processing knowledge item: {e}") - # Sort by similarity in descending order results_with_similarity.sort(key=lambda x: x["similarity"], reverse=True) # Limit results limited_results = results_with_similarity[:limit] - + logger.debug(f"知识库查询结果数量 (after Peewee processing): {len(limited_results)}") if not limited_results: diff --git a/src/chat/focus_chat/planners/action_factory.py b/src/chat/focus_chat/planners/action_factory.py index bf77140cb..bca49c496 100644 --- a/src/chat/focus_chat/planners/action_factory.py +++ b/src/chat/focus_chat/planners/action_factory.py @@ -27,20 +27,19 @@ class ActionManager: self._using_actions: Dict[str, ActionInfo] = {} # 临时备份原始使用中的动作 self._original_actions_backup: Optional[Dict[str, ActionInfo]] = None - + # 默认动作集,仅作为快照,用于恢复默认 self._default_actions: Dict[str, ActionInfo] = {} - + # 加载所有已注册动作 self._load_registered_actions() - + # 初始化时将默认动作加载到使用中的动作 self._using_actions = self._default_actions.copy() - + # logger.info(f"当前可用动作: {list(self._using_actions.keys())}") # for action_name, action_info in self._using_actions.items(): - # logger.info(f"动作名称: {action_name}, 动作信息: {action_info}") - + # logger.info(f"动作名称: {action_name}, 动作信息: {action_info}") def _load_registered_actions(self) -> None: """ @@ -50,35 +49,35 @@ class ActionManager: # 从_ACTION_REGISTRY获取所有已注册动作 for action_name, action_class in _ACTION_REGISTRY.items(): # 获取动作相关信息 - action_description:str = getattr(action_class, "action_description", "") - action_parameters:dict[str:str] = getattr(action_class, "action_parameters", {}) - action_require:list[str] = getattr(action_class, "action_require", []) - is_default:bool = getattr(action_class, "default", False) - + action_description: str = getattr(action_class, "action_description", "") + action_parameters: dict[str:str] = getattr(action_class, "action_parameters", {}) + action_require: list[str] = getattr(action_class, "action_require", []) + is_default: bool = getattr(action_class, "default", False) + if action_name and action_description: # 创建动作信息字典 action_info = { "description": action_description, "parameters": action_parameters, - "require": action_require + "require": action_require, } - + # 注册2 print("注册2") print(action_info) - + # 添加到所有已注册的动作 self._registered_actions[action_name] = action_info - + # 添加到默认动作(如果是默认动作) if is_default: self._default_actions[action_name] = action_info - + logger.info(f"所有注册动作: {list(self._registered_actions.keys())}") logger.info(f"默认动作: {list(self._default_actions.keys())}") # for action_name, action_info in self._default_actions.items(): - # logger.info(f"动作名称: {action_name}, 动作信息: {action_info}") - + # logger.info(f"动作名称: {action_name}, 动作信息: {action_info}") + except Exception as e: logger.error(f"加载已注册动作失败: {e}") @@ -125,7 +124,7 @@ class ActionManager: if action_name not in self._using_actions: logger.warning(f"当前不可用的动作类型: {action_name}") return None - + handler_class = _ACTION_REGISTRY.get(action_name) if not handler_class: logger.warning(f"未注册的动作类型: {action_name}") @@ -149,7 +148,7 @@ class ActionManager: expressor=expressor, chat_stream=chat_stream, ) - + return instance except Exception as e: @@ -163,7 +162,7 @@ class ActionManager: def get_default_actions(self) -> Dict[str, ActionInfo]: """获取默认动作集""" return self._default_actions.copy() - + def get_using_actions(self) -> Dict[str, ActionInfo]: """获取当前正在使用的动作集""" return self._using_actions.copy() @@ -171,21 +170,21 @@ class ActionManager: def add_action_to_using(self, action_name: str) -> bool: """ 添加已注册的动作到当前使用的动作集 - + Args: action_name: 动作名称 - + Returns: bool: 添加是否成功 """ if action_name not in self._registered_actions: logger.warning(f"添加失败: 动作 {action_name} 未注册") return False - + if action_name in self._using_actions: logger.info(f"动作 {action_name} 已经在使用中") return True - + self._using_actions[action_name] = self._registered_actions[action_name] logger.info(f"添加动作 {action_name} 到使用集") return True @@ -193,17 +192,17 @@ class ActionManager: def remove_action_from_using(self, action_name: str) -> bool: """ 从当前使用的动作集中移除指定动作 - + Args: action_name: 动作名称 - + Returns: bool: 移除是否成功 """ if action_name not in self._using_actions: logger.warning(f"移除失败: 动作 {action_name} 不在当前使用的动作集中") return False - + del self._using_actions[action_name] logger.info(f"已从使用集中移除动作 {action_name}") return True @@ -211,30 +210,26 @@ class ActionManager: def add_action(self, action_name: str, description: str, parameters: Dict = None, require: List = None) -> bool: """ 添加新的动作到注册集 - + Args: action_name: 动作名称 description: 动作描述 parameters: 动作参数定义,默认为空字典 require: 动作依赖项,默认为空列表 - + Returns: bool: 添加是否成功 """ if action_name in self._registered_actions: return False - + if parameters is None: parameters = {} if require is None: require = [] - - action_info = { - "description": description, - "parameters": parameters, - "require": require - } - + + action_info = {"description": description, "parameters": parameters, "require": require} + self._registered_actions[action_name] = action_info return True @@ -260,7 +255,7 @@ class ActionManager: if self._original_actions_backup is not None: self._using_actions = self._original_actions_backup.copy() self._original_actions_backup = None - + def restore_default_actions(self) -> None: """恢复默认动作集到使用集""" self._using_actions = self._default_actions.copy() @@ -269,10 +264,10 @@ class ActionManager: def get_action(self, action_name: str) -> Optional[Type[BaseAction]]: """ 获取指定动作的处理器类 - + Args: action_name: 动作名称 - + Returns: Optional[Type[BaseAction]]: 动作处理器类,如果不存在则返回None """ diff --git a/src/chat/focus_chat/planners/actions/base_action.py b/src/chat/focus_chat/planners/actions/base_action.py index 629bdbd43..d9e619d13 100644 --- a/src/chat/focus_chat/planners/actions/base_action.py +++ b/src/chat/focus_chat/planners/actions/base_action.py @@ -12,7 +12,7 @@ _DEFAULT_ACTIONS: Dict[str, str] = {} def register_action(cls): """ 动作注册装饰器 - + 用法: @register_action class MyAction(BaseAction): @@ -24,22 +24,22 @@ def register_action(cls): if not hasattr(cls, "action_name") or not hasattr(cls, "action_description"): logger.error(f"动作类 {cls.__name__} 缺少必要的属性: action_name 或 action_description") return cls - - action_name = getattr(cls, "action_name") #noqa - action_description = getattr(cls, "action_description") #noqa + + action_name = getattr(cls, "action_name") # noqa + action_description = getattr(cls, "action_description") # noqa is_default = getattr(cls, "default", False) - + if not action_name or not action_description: logger.error(f"动作类 {cls.__name__} 的 action_name 或 action_description 为空") return cls - + # 将动作类注册到全局注册表 _ACTION_REGISTRY[action_name] = cls - + # 如果是默认动作,添加到默认动作集 if is_default: _DEFAULT_ACTIONS[action_name] = action_description - + logger.info(f"已注册动作: {action_name} -> {cls.__name__},默认: {is_default}") return cls @@ -60,15 +60,14 @@ class BaseAction(ABC): cycle_timers: 计时器字典 thinking_id: 思考ID """ - #每个动作必须实现 - self.action_name:str = "base_action" - self.action_description:str = "基础动作" - self.action_parameters:dict = {} - self.action_require:list[str] = [] - - self.default:bool = False - - + # 每个动作必须实现 + self.action_name: str = "base_action" + self.action_description: str = "基础动作" + self.action_parameters: dict = {} + self.action_require: list[str] = [] + + self.default: bool = False + self.action_data = action_data self.reasoning = reasoning self.cycle_timers = cycle_timers diff --git a/src/chat/focus_chat/planners/actions/no_reply_action.py b/src/chat/focus_chat/planners/actions/no_reply_action.py index a29812c7a..71f1cb3f3 100644 --- a/src/chat/focus_chat/planners/actions/no_reply_action.py +++ b/src/chat/focus_chat/planners/actions/no_reply_action.py @@ -29,7 +29,7 @@ class NoReplyAction(BaseAction): action_require = [ "话题无关/无聊/不感兴趣/不懂", "最后一条消息是你自己发的且无人回应你", - "你发送了太多消息,且无人回复" + "你发送了太多消息,且无人回复", ] default = True @@ -46,7 +46,7 @@ class NoReplyAction(BaseAction): total_no_reply_count: int = 0, total_waiting_time: float = 0.0, shutting_down: bool = False, - **kwargs + **kwargs, ): """初始化不回复动作处理器 diff --git a/src/chat/focus_chat/planners/actions/reply_action.py b/src/chat/focus_chat/planners/actions/reply_action.py index 80624d487..3f8ca49a0 100644 --- a/src/chat/focus_chat/planners/actions/reply_action.py +++ b/src/chat/focus_chat/planners/actions/reply_action.py @@ -2,6 +2,7 @@ # -*- coding: utf-8 -*- from src.common.logger_manager import get_logger + # from src.chat.utils.timer_calculator import Timer from src.chat.focus_chat.planners.actions.base_action import BaseAction, register_action from typing import Tuple, List @@ -22,14 +23,14 @@ class ReplyAction(BaseAction): 处理构建和发送消息回复的动作。 """ - action_name:str = "reply" - action_description:str = "表达想法,可以只包含文本、表情或两者都有" - action_parameters:dict[str:str] = { + action_name: str = "reply" + action_description: str = "表达想法,可以只包含文本、表情或两者都有" + action_parameters: dict[str:str] = { "text": "你想要表达的内容(可选)", "emojis": "描述当前使用表情包的场景(可选)", "target": "你想要回复的原始文本内容(非必须,仅文本,不包含发送者)(可选)", } - action_require:list[str] = [ + action_require: list[str] = [ "有实质性内容需要表达", "有人提到你,但你还没有回应他", "在合适的时候添加表情(不要总是添加)", @@ -38,7 +39,7 @@ class ReplyAction(BaseAction): "一次只回复一个人,一次只回复一个话题,突出重点", "如果是自己发的消息想继续,需自然衔接", "避免重复或评价自己的发言,不要和自己聊天", - "注意:回复尽量简短一些。可以参考贴吧,知乎和微博的回复风格,回复不要浮夸,不要用夸张修辞,平淡一些。" + "注意:回复尽量简短一些。可以参考贴吧,知乎和微博的回复风格,回复不要浮夸,不要用夸张修辞,平淡一些。", ] default = True @@ -54,7 +55,7 @@ class ReplyAction(BaseAction): chat_stream: ChatStream, current_cycle: CycleDetail, log_prefix: str, - **kwargs + **kwargs, ): """初始化回复动作处理器 @@ -89,9 +90,9 @@ class ReplyAction(BaseAction): reasoning=self.reasoning, reply_data=self.action_data, cycle_timers=self.cycle_timers, - thinking_id=self.thinking_id + thinking_id=self.thinking_id, ) - + async def _handle_reply( self, reasoning: str, reply_data: dict, cycle_timers: dict, thinking_id: str ) -> tuple[bool, str]: diff --git a/src/chat/focus_chat/planners/planner.py b/src/chat/focus_chat/planners/planner.py index 05c9276c6..c87732d23 100644 --- a/src/chat/focus_chat/planners/planner.py +++ b/src/chat/focus_chat/planners/planner.py @@ -4,6 +4,7 @@ from typing import List, Dict, Any, Optional from rich.traceback import install from src.chat.models.utils_model import LLMRequest from src.config.config import global_config + # from src.chat.focus_chat.heartflow_prompt_builder import prompt_builder from src.chat.focus_chat.info.info_base import InfoBase from src.chat.focus_chat.info.obs_info import ObsInfo @@ -15,10 +16,12 @@ from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.individuality.individuality import Individuality from src.chat.focus_chat.planners.action_factory import ActionManager from src.chat.focus_chat.planners.action_factory import ActionInfo + logger = get_logger("planner") install(extra_lines=3) + def init_prompt(): Prompt( """你的名字是{bot_name},{prompt_personality},{chat_context_description}。需要基于以下信息决定如何参与对话: @@ -44,8 +47,9 @@ def init_prompt(): }} 请输出你的决策 JSON:""", -"planner_prompt",) - + "planner_prompt", + ) + Prompt( """ action_name: {action_name} @@ -57,7 +61,7 @@ action_name: {action_name} """, "action_prompt", ) - + class ActionPlanner: def __init__(self, log_prefix: str, action_manager: ActionManager): @@ -68,7 +72,7 @@ class ActionPlanner: max_tokens=1000, request_type="action_planning", # 用于动作规划 ) - + self.action_manager = action_manager async def plan(self, all_plan_info: List[InfoBase], cycle_timers: dict) -> Dict[str, Any]: @@ -106,7 +110,7 @@ class ActionPlanner: _structured_info = info.get_data() current_available_actions = self.action_manager.get_using_actions() - + # --- 构建提示词 (调用修改后的 PromptBuilder 方法) --- prompt = await self.build_planner_prompt( is_group_chat=is_group_chat, # <-- Pass HFC state @@ -197,7 +201,6 @@ class ActionPlanner: # 返回结果字典 return plan_result - async def build_planner_prompt( self, is_group_chat: bool, # Now passed as argument @@ -218,7 +221,6 @@ class ActionPlanner: ) chat_context_description = f"你正在和 {chat_target_name} 私聊" - chat_content_block = "" if observed_messages_str: chat_content_block = f"聊天记录:\n{observed_messages_str}" @@ -234,7 +236,6 @@ class ActionPlanner: individuality = Individuality.get_instance() personality_block = individuality.get_prompt(x_person=2, level=2) - action_options_block = "" for using_actions_name, using_actions_info in current_available_actions.items(): # print(using_actions_name) @@ -242,29 +243,26 @@ class ActionPlanner: # print(using_actions_info["parameters"]) # print(using_actions_info["require"]) # print(using_actions_info["description"]) - + using_action_prompt = await global_prompt_manager.get_prompt_async("action_prompt") - + param_text = "" for param_name, param_description in using_actions_info["parameters"].items(): param_text += f"{param_name}: {param_description}\n" - + require_text = "" for require_item in using_actions_info["require"]: require_text += f"- {require_item}\n" - + using_action_prompt = using_action_prompt.format( action_name=using_actions_name, action_description=using_actions_info["description"], action_parameters=param_text, action_require=require_text, ) - + action_options_block += using_action_prompt - - - planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt") prompt = planner_prompt_template.format( bot_name=global_config.BOT_NICKNAME, diff --git a/src/chat/person_info/person_info.py b/src/chat/person_info/person_info.py index cd9034d6f..1ad2f358f 100644 --- a/src/chat/person_info/person_info.py +++ b/src/chat/person_info/person_info.py @@ -261,7 +261,9 @@ class PersonInfoManager: qv_name_prompt += f"你之前叫他{old_name},是因为{old_reason}," qv_name_prompt += f"\n其他取名的要求是:{request},不要太浮夸" - qv_name_prompt += "\n请根据以上用户信息,想想你叫他什么比较好,不要太浮夸,请最好使用用户的qq昵称,可以稍作修改" + qv_name_prompt += ( + "\n请根据以上用户信息,想想你叫他什么比较好,不要太浮夸,请最好使用用户的qq昵称,可以稍作修改" + ) if existing_names_str: qv_name_prompt += f"\n请注意,以下名称已被你尝试过或已知存在,请避免:{existing_names_str}。\n" @@ -289,6 +291,7 @@ class PersonInfoManager: if generated_nickname in current_name_set: is_duplicate = True else: + def _db_check_name_exists_sync(name_to_check): return PersonInfo.select().where(PersonInfo.person_name == name_to_check).exists() @@ -415,7 +418,9 @@ class PersonInfoManager: @staticmethod async def del_all_undefined_field(): """删除所有项里的未定义字段 - 对于Peewee (SQL),此操作通常不适用,因为模式是固定的。""" - logger.info("del_all_undefined_field: 对于使用Peewee的SQL数据库,此操作通常不适用或不需要,因为表结构是预定义的。") + logger.info( + "del_all_undefined_field: 对于使用Peewee的SQL数据库,此操作通常不适用或不需要,因为表结构是预定义的。" + ) return @staticmethod @@ -512,7 +517,9 @@ class PersonInfoManager: if trimmed_interval: msg_interval_val = int(round(np.percentile(trimmed_interval, 37))) await self.update_one_field(person_id, "msg_interval", msg_interval_val) - logger.trace(f"用户{person_id}的msg_interval通过头尾截断和37分位数更新为{msg_interval_val}") + logger.trace( + f"用户{person_id}的msg_interval通过头尾截断和37分位数更新为{msg_interval_val}" + ) else: logger.trace(f"用户{person_id}截断后数据为空,无法计算msg_interval") else: @@ -577,13 +584,17 @@ class PersonInfoManager: break if not found_person_id: + def _db_find_by_name_sync(p_name_to_find: str): return PersonInfo.get_or_none(PersonInfo.person_name == p_name_to_find) record = await asyncio.to_thread(_db_find_by_name_sync, person_name) if record: found_person_id = record.person_id - if found_person_id not in self.person_name_list or self.person_name_list[found_person_id] != person_name: + if ( + found_person_id not in self.person_name_list + or self.person_name_list[found_person_id] != person_name + ): self.person_name_list[found_person_id] = person_name else: logger.debug(f"数据库中也未找到名为 '{person_name}' 的用户 (Peewee)") @@ -600,7 +611,9 @@ class PersonInfoManager: "person_name", "name_reason", ] - valid_fields_to_get = [f for f in required_fields if f in PersonInfo._meta.fields or f in person_info_default] + valid_fields_to_get = [ + f for f in required_fields if f in PersonInfo._meta.fields or f in person_info_default + ] person_data = await self.get_values(found_person_id, valid_fields_to_get) diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index 5a442615d..42fed2f80 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -454,7 +454,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str: def reply_replacer(match): # aaa = match.group(1) bbb = match.group(2) - anon_reply = get_anon_name(platform, bbb) #noqa + anon_reply = get_anon_name(platform, bbb) # noqa return f"回复 {anon_reply}" content = re.sub(reply_pattern, reply_replacer, content, count=1) @@ -465,7 +465,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str: def at_replacer(match): # aaa = match.group(1) bbb = match.group(2) - anon_at = get_anon_name(platform, bbb) #noqa + anon_at = get_anon_name(platform, bbb) # noqa return f"@{anon_at}" content = re.sub(at_pattern, at_replacer, content) diff --git a/src/chat/utils/info_catcher.py b/src/chat/utils/info_catcher.py index fb8182973..07caa06c0 100644 --- a/src/chat/utils/info_catcher.py +++ b/src/chat/utils/info_catcher.py @@ -103,11 +103,11 @@ class InfoCatcher: print(f"查询参数: time_start={time_start}, time_end={time_end}, chat_id={chat_id}") - messages_between_query = Messages.select().where( - (Messages.chat_id == chat_id) & - (Messages.time > time_start) & - (Messages.time < time_end) - ).order_by(Messages.time.desc()) + messages_between_query = ( + Messages.select() + .where((Messages.chat_id == chat_id) & (Messages.time > time_start) & (Messages.time < time_end)) + .order_by(Messages.time.desc()) + ) result = list(messages_between_query) print(f"查询结果数量: {len(result)}") @@ -124,10 +124,12 @@ class InfoCatcher: message_id_val = message.message_info.message_id chat_id_val = message.chat_stream.stream_id - messages_before_query = Messages.select().where( - (Messages.chat_id == chat_id_val) & - (Messages.message_id < message_id_val) - ).order_by(Messages.time.desc()).limit(self.context_length * 3) + messages_before_query = ( + Messages.select() + .where((Messages.chat_id == chat_id_val) & (Messages.message_id < message_id_val)) + .order_by(Messages.time.desc()) + .limit(self.context_length * 3) + ) return list(messages_before_query) @@ -137,7 +139,7 @@ class InfoCatcher: processed_msg_item = msg_item if not isinstance(msg_item, dict): processed_msg_item = self.message_to_dict(msg_item) - + if not processed_msg_item: continue @@ -163,15 +165,15 @@ class InfoCatcher: "user_nickname": msg_obj.user_nickname, "processed_plain_text": msg_obj.processed_plain_text, } - - if hasattr(msg_obj, 'message_info') and hasattr(msg_obj.message_info, 'user_info'): + + if hasattr(msg_obj, "message_info") and hasattr(msg_obj.message_info, "user_info"): return { "time": msg_obj.message_info.time, "user_id": msg_obj.message_info.user_info.user_id, "user_nickname": msg_obj.message_info.user_info.user_nickname, "processed_plain_text": msg_obj.processed_plain_text, } - + print(f"Warning: message_to_dict received an unhandled type: {type(msg_obj)}") return {} @@ -198,7 +200,7 @@ class InfoCatcher: chat_history_in_thinking_json=json.dumps(chat_history_in_thinking_list), chat_history_after_response_json=json.dumps(chat_history_after_response_list), heartflow_data_json=json.dumps(self.heartflow_data), - reasoning_data_json=json.dumps(self.reasoning_data) + reasoning_data_json=json.dumps(self.reasoning_data), ) log_entry.save() diff --git a/src/chat/utils/statistic.py b/src/chat/utils/statistic.py index 88329c3f4..cb202c520 100644 --- a/src/chat/utils/statistic.py +++ b/src/chat/utils/statistic.py @@ -5,8 +5,8 @@ from typing import Any, Dict, Tuple, List from src.common.logger import get_module_logger from src.manager.async_task_manager import AsyncTask -from ...common.database.database import db # This db is the Peewee database instance -from ...common.database.database_model import OnlineTime, LLMUsage, Messages # Import the Peewee model +from ...common.database.database import db # This db is the Peewee database instance +from ...common.database.database_model import OnlineTime, LLMUsage, Messages # Import the Peewee model from src.manager.local_store_manager import local_storage logger = get_module_logger("maibot_statistic") @@ -48,8 +48,8 @@ class OnlineTimeRecordTask(AsyncTask): @staticmethod def _init_database(): """初始化数据库""" - with db.atomic(): # Use atomic operations for schema changes - OnlineTime.create_table(safe=True) # Creates table if it doesn't exist, Peewee handles indexes from model + with db.atomic(): # Use atomic operations for schema changes + OnlineTime.create_table(safe=True) # Creates table if it doesn't exist, Peewee handles indexes from model async def run(self): try: @@ -62,14 +62,17 @@ class OnlineTimeRecordTask(AsyncTask): updated_rows = query.execute() if updated_rows == 0: # Record might have been deleted or ID is stale, try to find/create - self.record_id = None # Reset record_id to trigger find/create logic below - - if not self.record_id: # Check again if record_id was reset or initially None + self.record_id = None # Reset record_id to trigger find/create logic below + + if not self.record_id: # Check again if record_id was reset or initially None # 如果没有记录,检查一分钟以内是否已有记录 # Look for a record whose end_timestamp is recent enough to be considered ongoing - recent_record = OnlineTime.select().where( - OnlineTime.end_timestamp >= (current_time - timedelta(minutes=1)) - ).order_by(OnlineTime.end_timestamp.desc()).first() + recent_record = ( + OnlineTime.select() + .where(OnlineTime.end_timestamp >= (current_time - timedelta(minutes=1))) + .order_by(OnlineTime.end_timestamp.desc()) + .first() + ) if recent_record: # 如果有记录,则更新结束时间 @@ -87,7 +90,6 @@ class OnlineTimeRecordTask(AsyncTask): logger.error(f"在线时间记录失败,错误信息:{e}") - def _format_online_time(online_seconds: int) -> str: """ 格式化在线时间 @@ -197,7 +199,7 @@ class StatisticOutputTask(AsyncTask): """ if not collect_period: return {} - + # 排序-按照时间段开始时间降序排列(最晚的时间段在前) collect_period.sort(key=lambda x: x[1], reverse=True) @@ -228,14 +230,14 @@ class StatisticOutputTask(AsyncTask): # Assuming LLMUsage.timestamp is a DateTimeField query_start_time = collect_period[-1][1] for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time): - record_timestamp = record.timestamp # This is already a datetime object + record_timestamp = record.timestamp # This is already a datetime object for idx, (_, period_start) in enumerate(collect_period): if record_timestamp >= period_start: for period_key, _ in collect_period[idx:]: stats[period_key][TOTAL_REQ_CNT] += 1 request_type = record.request_type or "unknown" - user_id = record.user_id or "unknown" # user_id is TextField, already string + user_id = record.user_id or "unknown" # user_id is TextField, already string model_name = record.model_name or "unknown" stats[period_key][REQ_CNT_BY_TYPE][request_type] += 1 @@ -275,7 +277,7 @@ class StatisticOutputTask(AsyncTask): """ if not collect_period: return {} - + collect_period.sort(key=lambda x: x[1], reverse=True) stats = { @@ -300,7 +302,7 @@ class StatisticOutputTask(AsyncTask): for period_key, current_period_start_time in collect_period[idx:]: # Determine the portion of the record that falls within this specific statistical period overlap_start = max(record_start_timestamp, current_period_start_time) - overlap_end = effective_end_time # Already capped by 'now' and record's own end + overlap_end = effective_end_time # Already capped by 'now' and record's own end if overlap_end > overlap_start: stats[period_key][ONLINE_TIME] += (overlap_end - overlap_start).total_seconds() @@ -315,7 +317,7 @@ class StatisticOutputTask(AsyncTask): """ if not collect_period: return {} - + collect_period.sort(key=lambda x: x[1], reverse=True) stats = { @@ -326,9 +328,9 @@ class StatisticOutputTask(AsyncTask): for period_key, _ in collect_period } - query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp) + query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp) for message in Messages.select().where(Messages.time >= query_start_timestamp): - message_time_ts = message.time # This is a float timestamp + message_time_ts = message.time # This is a float timestamp chat_id = None chat_name = None @@ -337,16 +339,18 @@ class StatisticOutputTask(AsyncTask): if message.chat_info_group_id: chat_id = f"g{message.chat_info_group_id}" chat_name = message.chat_info_group_name or f"群{message.chat_info_group_id}" - elif message.user_id: # Fallback to sender's info for chat_id if not a group_info based chat - # This uses the message SENDER's ID as per original logic's fallback - chat_id = f"u{message.user_id}" # SENDER's user_id - chat_name = message.user_nickname # SENDER's nickname + elif message.user_id: # Fallback to sender's info for chat_id if not a group_info based chat + # This uses the message SENDER's ID as per original logic's fallback + chat_id = f"u{message.user_id}" # SENDER's user_id + chat_name = message.user_nickname # SENDER's nickname else: # If neither group_id nor sender_id is available for chat identification - logger.warning(f"Message (PK: {message.id if hasattr(message, 'id') else 'N/A'}) lacks group_id and user_id for chat stats.") + logger.warning( + f"Message (PK: {message.id if hasattr(message, 'id') else 'N/A'}) lacks group_id and user_id for chat stats." + ) continue - - if not chat_id: # Should not happen if above logic is correct + + if not chat_id: # Should not happen if above logic is correct continue # Update name_mapping diff --git a/src/chat/utils/utils_image.py b/src/chat/utils/utils_image.py index 11e7bed06..ee5846031 100644 --- a/src/chat/utils/utils_image.py +++ b/src/chat/utils/utils_image.py @@ -35,13 +35,13 @@ class ImageManager: if not self._initialized: self._ensure_image_dir() self._llm = LLMRequest(model=global_config.vlm, temperature=0.4, max_tokens=300, request_type="image") - + try: db.connect(reuse_if_open=True) db.create_tables([Images, ImageDescriptions], safe=True) except Exception as e: logger.error(f"数据库连接或表创建失败: {e}") - + self._initialized = True def _ensure_image_dir(self): @@ -61,8 +61,7 @@ class ImageManager: """ try: record = ImageDescriptions.get_or_none( - (ImageDescriptions.hash == image_hash) & - (ImageDescriptions.type == description_type) + (ImageDescriptions.hash == image_hash) & (ImageDescriptions.type == description_type) ) return record.description if record else None except Exception as e: @@ -80,14 +79,9 @@ class ImageManager: """ try: current_timestamp = time.time() - defaults = { - 'description': description, - 'timestamp': current_timestamp - } + defaults = {"description": description, "timestamp": current_timestamp} desc_obj, created = ImageDescriptions.get_or_create( - hash=image_hash, - type=description_type, - defaults=defaults + hash=image_hash, type=description_type, defaults=defaults ) if not created: # 如果记录已存在,则更新 desc_obj.description = description @@ -120,7 +114,7 @@ class ImageManager: else: prompt = "这是一个表情包,请用使用几个词描述一下表情包所表达的情感和内容,简短一些" description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format) - + if description is None: logger.warning("AI未能生成表情包描述") return "[表情包(描述生成失败)]" @@ -191,7 +185,7 @@ class ImageManager: "请用中文描述这张图片的内容。如果有文字,请把文字都描述出来。并尝试猜测这个图片的含义。最多100个字。" ) description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format) - + if description is None: logger.warning("AI未能生成图片描述") return "[图片(描述生成失败)]" diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py index 89e047414..2f1406bbd 100644 --- a/src/common/database/database_model.py +++ b/src/common/database/database_model.py @@ -14,18 +14,21 @@ import datetime # db = MySQLDatabase('your_db_name', user='your_user', password='your_password', # host='localhost', port=3306) + # 定义一个基础模型是一个好习惯,所有其他模型都应继承自它。 # 这允许您在一个地方为所有模型指定数据库。 class BaseModel(Model): class Meta: # 将下面的 'db' 替换为您实际的数据库实例变量名。 database = db # 例如: database = my_actual_db_instance - pass # 在用户定义数据库实例之前,此处为占位符 + pass # 在用户定义数据库实例之前,此处为占位符 + class ChatStreams(BaseModel): """ 用于存储流式记录数据的模型,类似于提供的 MongoDB 结构。 """ + # stream_id: "a544edeb1a9b73e3e1d77dff36e41264" # 假设 stream_id 是唯一的,并为其创建索引以提高查询性能。 stream_id = TextField(unique=True, index=True) @@ -63,28 +66,31 @@ class ChatStreams(BaseModel): # 如果 BaseModel.Meta.database 已设置,则此模型将继承该数据库配置。 # 如果不使用带有数据库实例的 BaseModel,或者想覆盖它, # 请取消注释并在下面设置数据库实例: - # database = db - table_name = 'chat_streams' # 可选:明确指定数据库中的表名 + # database = db + table_name = "chat_streams" # 可选:明确指定数据库中的表名 + class LLMUsage(BaseModel): """ 用于存储 API 使用日志数据的模型。 """ - model_name = TextField(index=True) # 添加索引 - user_id = TextField(index=True) # 添加索引 - request_type = TextField(index=True) # 添加索引 + + model_name = TextField(index=True) # 添加索引 + user_id = TextField(index=True) # 添加索引 + request_type = TextField(index=True) # 添加索引 endpoint = TextField() prompt_tokens = IntegerField() completion_tokens = IntegerField() total_tokens = IntegerField() cost = DoubleField() status = TextField() - timestamp = DateTimeField(index=True) # 更改为 DateTimeField 并添加索引 + timestamp = DateTimeField(index=True) # 更改为 DateTimeField 并添加索引 class Meta: # 如果 BaseModel.Meta.database 已设置,则此模型将继承该数据库配置。 - # database = db - table_name = 'llm_usage' + # database = db + table_name = "llm_usage" + class Emoji(BaseModel): """表情包""" @@ -105,16 +111,18 @@ class Emoji(BaseModel): class Meta: # database = db # 继承自 BaseModel - table_name = 'emoji' + table_name = "emoji" + class Messages(BaseModel): """ 用于存储消息数据的模型。 """ - message_id = IntegerField(index=True) # 消息 ID - time = DoubleField() # 消息时间戳 - chat_id = TextField(index=True) # 对应的 ChatStreams stream_id + message_id = IntegerField(index=True) # 消息 ID + time = DoubleField() # 消息时间戳 + + chat_id = TextField(index=True) # 对应的 ChatStreams stream_id # 从 chat_info 扁平化而来的字段 chat_info_stream_id = TextField() @@ -123,7 +131,7 @@ class Messages(BaseModel): chat_info_user_id = TextField() chat_info_user_nickname = TextField() chat_info_user_cardname = TextField(null=True) - chat_info_group_platform = TextField(null=True) # 群聊信息可能不存在 + chat_info_group_platform = TextField(null=True) # 群聊信息可能不存在 chat_info_group_id = TextField(null=True) chat_info_group_name = TextField(null=True) chat_info_create_time = DoubleField() @@ -135,18 +143,20 @@ class Messages(BaseModel): user_nickname = TextField() user_cardname = TextField(null=True) - processed_plain_text = TextField(null=True) # 处理后的纯文本消息 - detailed_plain_text = TextField(null=True) # 详细的纯文本消息 - memorized_times = IntegerField(default=0) # 被记忆的次数 + processed_plain_text = TextField(null=True) # 处理后的纯文本消息 + detailed_plain_text = TextField(null=True) # 详细的纯文本消息 + memorized_times = IntegerField(default=0) # 被记忆的次数 class Meta: # database = db # 继承自 BaseModel - table_name = 'messages' + table_name = "messages" + class Images(BaseModel): """ 用于存储图像信息的模型。 """ + hash = TextField(index=True) # 图像的哈希值 description = TextField(null=True) # 图像的描述 path = TextField(unique=True) # 图像文件的路径 @@ -155,12 +165,14 @@ class Images(BaseModel): class Meta: # database = db # 继承自 BaseModel - table_name = 'images' + table_name = "images" + class ImageDescriptions(BaseModel): """ 用于存储图像描述信息的模型。 """ + type = TextField() # 类型,例如 "emoji" hash = TextField(index=True) # 图像的哈希值 description = TextField() # 图像的描述 @@ -168,12 +180,14 @@ class ImageDescriptions(BaseModel): class Meta: # database = db # 继承自 BaseModel - table_name = 'image_descriptions' + table_name = "image_descriptions" + class OnlineTime(BaseModel): """ 用于存储在线时长记录的模型。 """ + # timestamp: "$date": "2025-05-01T18:52:18.191Z" (存储为字符串) timestamp = TextField() duration = IntegerField() # 时长,单位分钟 @@ -182,12 +196,14 @@ class OnlineTime(BaseModel): class Meta: # database = db # 继承自 BaseModel - table_name = 'online_time' + table_name = "online_time" + class PersonInfo(BaseModel): """ 用于存储个人信息数据的模型。 """ + person_id = TextField(unique=True, index=True) # 个人唯一ID person_name = TextField() # 个人名称 name_reason = TextField(null=True) # 名称设定的原因 @@ -202,26 +218,28 @@ class PersonInfo(BaseModel): class Meta: # database = db # 继承自 BaseModel - table_name = 'person_info' + table_name = "person_info" + class Knowledges(BaseModel): """ 用于存储知识库条目的模型。 """ + content = TextField() # 知识内容的文本 embedding = TextField() # 知识内容的嵌入向量,存储为 JSON 字符串的浮点数列表 # 可以添加其他元数据字段,如 source, create_time 等 class Meta: # database = db # 继承自 BaseModel - table_name = 'knowledges' + table_name = "knowledges" class ThinkingLog(BaseModel): chat_id = TextField(index=True) trigger_text = TextField(null=True) response_text = TextField(null=True) - + # Store complex dicts/lists as JSON strings trigger_info_json = TextField(null=True) response_info_json = TextField(null=True) @@ -235,28 +253,32 @@ class ThinkingLog(BaseModel): # Add a timestamp for the log entry itself # Ensure you have: from peewee import DateTimeField # And: import datetime - created_at = DateTimeField(default=datetime.datetime.now) + created_at = DateTimeField(default=datetime.datetime.now) class Meta: - table_name = 'thinking_logs' + table_name = "thinking_logs" + def create_tables(): """ 创建所有在模型中定义的数据库表。 """ with db: - db.create_tables([ - ChatStreams, - LLMUsage, - Emoji, - Messages, - Images, - ImageDescriptions, - OnlineTime, - PersonInfo, - Knowledges, - ThinkingLog - ]) + db.create_tables( + [ + ChatStreams, + LLMUsage, + Emoji, + Messages, + Images, + ImageDescriptions, + OnlineTime, + PersonInfo, + Knowledges, + ThinkingLog, + ] + ) + def initialize_database(): """ @@ -272,9 +294,9 @@ def initialize_database(): OnlineTime, PersonInfo, Knowledges, - ThinkingLog + ThinkingLog, ] - + needs_creation = False try: with db: # 管理 table_exists 检查的连接 @@ -298,5 +320,6 @@ def initialize_database(): else: print("所有数据库表均已存在。") + # 模块加载时调用初始化函数 initialize_database() diff --git a/src/common/message_repository.py b/src/common/message_repository.py index 7d987ace9..fab9ab8b0 100644 --- a/src/common/message_repository.py +++ b/src/common/message_repository.py @@ -1,4 +1,4 @@ -from src.common.database.database_model import Messages # 更改导入 +from src.common.database.database_model import Messages # 更改导入 from src.common.logger import get_module_logger import traceback from typing import List, Any, Optional @@ -42,9 +42,7 @@ def find_messages( if hasattr(Messages, key): conditions.append(getattr(Messages, key) == value) else: - logger.warning( - f"过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。" - ) + logger.warning(f"过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。") if conditions: # 使用 *conditions 将所有条件以 AND 连接 query = query.where(*conditions) @@ -59,9 +57,7 @@ def find_messages( query = query.order_by(Messages.time.desc()).limit(limit) latest_results_peewee = list(query) # 将结果按时间正序排列 - peewee_results = sorted( - latest_results_peewee, key=lambda msg: msg.time - ) + peewee_results = sorted(latest_results_peewee, key=lambda msg: msg.time) else: # limit 为 0 时,应用传入的 sort 参数 if sort: @@ -74,13 +70,9 @@ def find_messages( elif direction == -1: # DESC peewee_sort_terms.append(field.desc()) else: - logger.warning( - f"字段 '{field_name}' 的排序方向 '{direction}' 无效。将跳过此排序条件。" - ) + logger.warning(f"字段 '{field_name}' 的排序方向 '{direction}' 无效。将跳过此排序条件。") else: - logger.warning( - f"排序字段 '{field_name}' 在 Messages 模型中未找到。将跳过此排序条件。" - ) + logger.warning(f"排序字段 '{field_name}' 在 Messages 模型中未找到。将跳过此排序条件。") if peewee_sort_terms: query = query.order_by(*peewee_sort_terms) peewee_results = list(query) @@ -116,9 +108,7 @@ def count_messages(message_filter: dict[str, Any]) -> int: if hasattr(Messages, key): conditions.append(getattr(Messages, key) == value) else: - logger.warning( - f"计数时,过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。" - ) + logger.warning(f"计数时,过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。") if conditions: query = query.where(*conditions) diff --git a/src/experimental/PFC/message_storage.py b/src/experimental/PFC/message_storage.py index 6e109fac3..e2e1dd052 100644 --- a/src/experimental/PFC/message_storage.py +++ b/src/experimental/PFC/message_storage.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod from typing import List, Dict, Any + # from src.common.database.database import db # Peewee db 导入 from src.common.database.database_model import Messages # Peewee Messages 模型导入 from playhouse.shortcuts import model_to_dict # 用于将模型实例转换为字典 @@ -53,20 +54,23 @@ class PeeweeMessageStorage(MessageStorage): """Peewee消息存储实现""" async def get_messages_after(self, chat_id: str, message_time: float) -> List[Dict[str, Any]]: - query = Messages.select().where( - (Messages.chat_id == chat_id) & - (Messages.time > message_time) - ).order_by(Messages.time.asc()) + query = ( + Messages.select() + .where((Messages.chat_id == chat_id) & (Messages.time > message_time)) + .order_by(Messages.time.asc()) + ) # print(f"storage_check_message: {message_time}") messages_models = list(query) return [model_to_dict(msg) for msg in messages_models] async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]: - query = Messages.select().where( - (Messages.chat_id == chat_id) & - (Messages.time < time_point) - ).order_by(Messages.time.desc()).limit(limit) + query = ( + Messages.select() + .where((Messages.chat_id == chat_id) & (Messages.time < time_point)) + .order_by(Messages.time.desc()) + .limit(limit) + ) messages_models = list(query) # 将消息按时间正序排列 @@ -74,10 +78,7 @@ class PeeweeMessageStorage(MessageStorage): return [model_to_dict(msg) for msg in messages_models] async def has_new_messages(self, chat_id: str, after_time: float) -> bool: - return Messages.select().where( - (Messages.chat_id == chat_id) & - (Messages.time > after_time) - ).exists() + return Messages.select().where((Messages.chat_id == chat_id) & (Messages.time > after_time)).exists() # # 创建一个内存消息存储实现,用于测试 diff --git a/src/tools/tool_can_use/get_knowledge.py b/src/tools/tool_can_use/get_knowledge.py index 4ff62b7c2..fd37f11e7 100644 --- a/src/tools/tool_can_use/get_knowledge.py +++ b/src/tools/tool_can_use/get_knowledge.py @@ -89,7 +89,9 @@ class SearchKnowledgeTool(BaseTool): logger.warning(f"Knowledge item ID {item.id} has empty embedding string.") continue item_embedding = json.loads(item_embedding_str) - if not isinstance(item_embedding, list) or not all(isinstance(x, (int, float)) for x in item_embedding): + if not isinstance(item_embedding, list) or not all( + isinstance(x, (int, float)) for x in item_embedding + ): logger.warning(f"Knowledge item ID {item.id} has invalid embedding format after JSON parsing.") continue except json.JSONDecodeError: From a242a4cb7ade3079aafac45e35b2eb230218bafb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Thu, 15 May 2025 10:09:18 +0800 Subject: [PATCH 10/57] Update src/common/database/database_model.py Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> --- src/common/database/database_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py index 2f1406bbd..a671aa58a 100644 --- a/src/common/database/database_model.py +++ b/src/common/database/database_model.py @@ -211,7 +211,7 @@ class PersonInfo(BaseModel): user_id = TextField(index=True) # 用户ID nickname = TextField() # 用户昵称 relationship_value = IntegerField(default=0) # 关系值 - konw_time = FloatField() # 认识时间 (时间戳) + know_time = FloatField() # 认识时间 (时间戳) msg_interval = IntegerField() # 消息间隔 # msg_interval_list: 存储为 JSON 字符串的列表 msg_interval_list = TextField(null=True) From 8775c6645410a32db9c71bc0809c044db84bbcc9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Thu, 15 May 2025 10:10:12 +0800 Subject: [PATCH 11/57] Update src/chat/focus_chat/heartflow_prompt_builder.py Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> --- src/chat/focus_chat/heartflow_prompt_builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/chat/focus_chat/heartflow_prompt_builder.py b/src/chat/focus_chat/heartflow_prompt_builder.py index 945f587bb..edf671232 100644 --- a/src/chat/focus_chat/heartflow_prompt_builder.py +++ b/src/chat/focus_chat/heartflow_prompt_builder.py @@ -613,7 +613,7 @@ class PromptBuilder: all_knowledges = Knowledges.select() if not all_knowledges: - return "" if not return_raw else [] + return [] if return_raw else "" query_embedding_magnitude = math.sqrt(sum(x * x for x in query_embedding)) if query_embedding_magnitude == 0: # Avoid division by zero From a88b04ab0b3b18c2b89b2aabf5e2bd049c80fa46 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Thu, 15 May 2025 10:10:30 +0800 Subject: [PATCH 12/57] Update src/common/message_repository.py Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> --- src/common/message_repository.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/common/message_repository.py b/src/common/message_repository.py index fab9ab8b0..522e36ccf 100644 --- a/src/common/message_repository.py +++ b/src/common/message_repository.py @@ -115,7 +115,7 @@ def count_messages(message_filter: dict[str, Any]) -> int: count = query.count() return count except Exception as e: - log_message = f"使用 Peewee 计数消息失败 (message_filter={message_filter}): {e}\n" + traceback.format_exc() + log_message = f"使用 Peewee 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}" logger.error(log_message) return 0 From 224c1e3fb798e971d82760654db2f93f2b1b3ec0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Thu, 15 May 2025 10:10:46 +0800 Subject: [PATCH 13/57] Update src/common/message_repository.py Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> --- src/common/message_repository.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/common/message_repository.py b/src/common/message_repository.py index 522e36ccf..5bf77c1a9 100644 --- a/src/common/message_repository.py +++ b/src/common/message_repository.py @@ -77,8 +77,7 @@ def find_messages( query = query.order_by(*peewee_sort_terms) peewee_results = list(query) - results = [_model_to_dict(msg) for msg in peewee_results] - return results + return [_model_to_dict(msg) for msg in peewee_results] except Exception as e: log_message = ( f"使用 Peewee 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n" From 2be0130d23412e6b78f84350b47016e8a700a68b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Thu, 15 May 2025 10:24:25 +0800 Subject: [PATCH 14/57] =?UTF-8?q?=E9=87=8D=E6=9E=84=E6=B6=88=E6=81=AF?= =?UTF-8?q?=E5=AD=98=E5=82=A8=E9=80=BB=E8=BE=91=EF=BC=8C=E4=BD=BF=E7=94=A8?= =?UTF-8?q?Peewee=E6=A8=A1=E5=9E=8B=E5=AD=98=E5=82=A8=E6=B6=88=E6=81=AF?= =?UTF-8?q?=E5=92=8C=E6=92=A4=E5=9B=9E=E6=B6=88=E6=81=AF=EF=BC=8C=E6=B7=BB?= =?UTF-8?q?=E5=8A=A0=E6=97=B6=E9=97=B4=E6=88=B3=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/message_receive/storage.py | 76 ++++++++++++++++++--------- src/chat/utils/statistic.py | 3 ++ src/common/database/database_model.py | 31 ++++++++--- 3 files changed, 77 insertions(+), 33 deletions(-) diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py index eb6ea73df..094b640d7 100644 --- a/src/chat/message_receive/storage.py +++ b/src/chat/message_receive/storage.py @@ -1,9 +1,10 @@ import re from typing import Union -from ...common.database.database import db +# from ...common.database.database import db # db is now Peewee's SqliteDatabase instance from .message import MessageSending, MessageRecv from .chat_stream import ChatStream +from ...common.database.database_model import Messages, RecalledMessages # Import Peewee models from src.common.logger import get_module_logger logger = get_module_logger("message_storage") @@ -29,42 +30,65 @@ class MessageStorage: else: filtered_detailed_plain_text = "" - message_data = { - "message_id": message.message_info.message_id, - "time": message.message_info.time, - "chat_id": chat_stream.stream_id, - "chat_info": chat_stream.to_dict(), - "user_info": message.message_info.user_info.to_dict(), - # 使用过滤后的文本 - "processed_plain_text": filtered_processed_plain_text, - "detailed_plain_text": filtered_detailed_plain_text, - "memorized_times": message.memorized_times, - } - db.messages.insert_one(message_data) + chat_info_dict = chat_stream.to_dict() + user_info_dict = message.message_info.user_info.to_dict() + + # Ensure message_id is an int if the model field is IntegerField + try: + msg_id = int(message.message_info.message_id) + except ValueError: + logger.error(f"Message ID {message.message_info.message_id} is not a valid integer. Storing as 0 or consider changing model field type.") + msg_id = 0 # Or handle as appropriate, e.g. skip storing, or change model field to TextField + + Messages.create( + message_id=msg_id, + time=float(message.message_info.time), + chat_id=chat_stream.stream_id, + # Flattened chat_info + chat_info_stream_id=chat_info_dict.get("stream_id"), + chat_info_platform=chat_info_dict.get("platform"), + chat_info_user_platform=chat_info_dict.get("user_info", {}).get("platform"), + chat_info_user_id=chat_info_dict.get("user_info", {}).get("user_id"), + chat_info_user_nickname=chat_info_dict.get("user_info", {}).get("user_nickname"), + chat_info_user_cardname=chat_info_dict.get("user_info", {}).get("user_cardname"), + chat_info_group_platform=chat_info_dict.get("group_info", {}).get("platform"), + chat_info_group_id=chat_info_dict.get("group_info", {}).get("group_id"), + chat_info_group_name=chat_info_dict.get("group_info", {}).get("group_name"), + chat_info_create_time=float(chat_info_dict.get("create_time", 0.0)), + chat_info_last_active_time=float(chat_info_dict.get("last_active_time", 0.0)), + # Flattened user_info (message sender) + user_platform=user_info_dict.get("platform"), + user_id=user_info_dict.get("user_id"), + user_nickname=user_info_dict.get("user_nickname"), + user_cardname=user_info_dict.get("user_cardname"), + # Text content + processed_plain_text=filtered_processed_plain_text, + detailed_plain_text=filtered_detailed_plain_text, + memorized_times=message.memorized_times, + ) except Exception: logger.exception("存储消息失败") @staticmethod async def store_recalled_message(message_id: str, time: str, chat_stream: ChatStream) -> None: """存储撤回消息到数据库""" - if "recalled_messages" not in db.list_collection_names(): - db.create_collection("recalled_messages") - else: - try: - message_data = { - "message_id": message_id, - "time": time, - "stream_id": chat_stream.stream_id, - } - db.recalled_messages.insert_one(message_data) - except Exception: - logger.exception("存储撤回消息失败") + # Table creation is handled by initialize_database in database_model.py + try: + RecalledMessages.create( + message_id=message_id, + time=float(time), # Assuming time is a string representing a float timestamp + stream_id=chat_stream.stream_id, + ) + except Exception: + logger.exception("存储撤回消息失败") @staticmethod async def remove_recalled_message(time: str) -> None: """删除撤回消息""" try: - db.recalled_messages.delete_many({"time": {"$lt": time - 300}}) + # Assuming input 'time' is a string timestamp that can be converted to float + current_time_float = float(time) + RecalledMessages.delete().where(RecalledMessages.time < (current_time_float - 300)).execute() except Exception: logger.exception("删除撤回消息失败") diff --git a/src/chat/utils/statistic.py b/src/chat/utils/statistic.py index cb202c520..a657ae85b 100644 --- a/src/chat/utils/statistic.py +++ b/src/chat/utils/statistic.py @@ -2,6 +2,7 @@ from collections import defaultdict from datetime import datetime, timedelta from typing import Any, Dict, Tuple, List + from src.common.logger import get_module_logger from src.manager.async_task_manager import AsyncTask @@ -82,8 +83,10 @@ class OnlineTimeRecordTask(AsyncTask): else: # 若没有记录,则插入新的在线时间记录 new_record = OnlineTime.create( + timestamp=current_time.timestamp(), # 添加此行 start_timestamp=current_time, end_timestamp=extended_end_time, + duration=5, # 初始时长为5分钟 ) self.record_id = new_record.id except Exception as e: diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py index a671aa58a..b959c4e51 100644 --- a/src/common/database/database_model.py +++ b/src/common/database/database_model.py @@ -1,6 +1,8 @@ from peewee import Model, DoubleField, IntegerField, BooleanField, TextField, FloatField, DateTimeField from .database import db import datetime +from ..logger_manager import get_logger +logger = get_logger("database_model") # 请在此处定义您的数据库实例。 # 您需要取消注释并配置适合您的数据库的部分。 # 例如,对于 SQLite: @@ -189,7 +191,7 @@ class OnlineTime(BaseModel): """ # timestamp: "$date": "2025-05-01T18:52:18.191Z" (存储为字符串) - timestamp = TextField() + timestamp = TextField(default=datetime.datetime.now) # 时间戳 duration = IntegerField() # 时长,单位分钟 start_timestamp = DateTimeField(default=datetime.datetime.now) end_timestamp = DateTimeField(index=True) @@ -259,6 +261,19 @@ class ThinkingLog(BaseModel): table_name = "thinking_logs" +class RecalledMessages(BaseModel): + """ + 用于存储撤回消息记录的模型。 + """ + + message_id = TextField(index=True) # 被撤回的消息 ID + time = DoubleField() # 撤回操作发生的时间戳 + stream_id = TextField() # 对应的 ChatStreams stream_id + + class Meta: + table_name = "recalled_messages" + + def create_tables(): """ 创建所有在模型中定义的数据库表。 @@ -276,6 +291,7 @@ def create_tables(): PersonInfo, Knowledges, ThinkingLog, + RecalledMessages, # 添加新模型 ] ) @@ -295,6 +311,7 @@ def initialize_database(): PersonInfo, Knowledges, ThinkingLog, + RecalledMessages, # 添加新模型 ] needs_creation = False @@ -302,23 +319,23 @@ def initialize_database(): with db: # 管理 table_exists 检查的连接 for model in models: if not db.table_exists(model): - print(f"表 '{model._meta.table_name}' 未找到。") + logger.warning(f"表 '{model._meta.table_name}' 未找到。") needs_creation = True break # 一个表丢失,无需进一步检查。 except Exception as e: - print(f"检查表是否存在时出错: {e}") + logger.exception(f"检查表是否存在时出错: {e}") # 如果检查失败(例如数据库不可用),则退出 return if needs_creation: - print("正在初始化数据库:一个或多个表丢失。正在尝试创建所有定义的表...") + logger.info("正在初始化数据库:一个或多个表丢失。正在尝试创建所有定义的表...") try: create_tables() # 此函数有其自己的 'with db:' 上下文管理。 - print("数据库表创建过程完成。") + logger.info("数据库表创建过程完成。") except Exception as e: - print(f"创建表期间出错: {e}") + logger.exception(f"创建表期间出错: {e}") else: - print("所有数据库表均已存在。") + logger.info("所有数据库表均已存在。") # 模块加载时调用初始化函数 From 4c2cfd5c739f48d635d326bfa028b7593cec44e3 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Thu, 15 May 2025 02:24:38 +0000 Subject: [PATCH 15/57] =?UTF-8?q?=F0=9F=A4=96=20=E8=87=AA=E5=8A=A8?= =?UTF-8?q?=E6=A0=BC=E5=BC=8F=E5=8C=96=E4=BB=A3=E7=A0=81=20[skip=20ci]?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/message_receive/storage.py | 4 +++- src/common/database/database_model.py | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py index 094b640d7..e81913549 100644 --- a/src/chat/message_receive/storage.py +++ b/src/chat/message_receive/storage.py @@ -37,7 +37,9 @@ class MessageStorage: try: msg_id = int(message.message_info.message_id) except ValueError: - logger.error(f"Message ID {message.message_info.message_id} is not a valid integer. Storing as 0 or consider changing model field type.") + logger.error( + f"Message ID {message.message_info.message_id} is not a valid integer. Storing as 0 or consider changing model field type." + ) msg_id = 0 # Or handle as appropriate, e.g. skip storing, or change model field to TextField Messages.create( diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py index b959c4e51..d77184b74 100644 --- a/src/common/database/database_model.py +++ b/src/common/database/database_model.py @@ -2,6 +2,7 @@ from peewee import Model, DoubleField, IntegerField, BooleanField, TextField, Fl from .database import db import datetime from ..logger_manager import get_logger + logger = get_logger("database_model") # 请在此处定义您的数据库实例。 # 您需要取消注释并配置适合您的数据库的部分。 From a18524ce61faed9257fe9d3b0e7de3c8d98bc002 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Thu, 15 May 2025 19:03:47 +0800 Subject: [PATCH 16/57] =?UTF-8?q?=E6=9B=B4=E6=96=B0=E6=B6=88=E6=81=AF?= =?UTF-8?q?=E5=AD=98=E5=82=A8=E5=92=8C=E6=9F=A5=E8=AF=A2=E9=80=BB=E8=BE=91?= =?UTF-8?q?=EF=BC=8C=E6=94=AF=E6=8C=81MongoDB=E9=A3=8E=E6=A0=BC=E7=9A=84?= =?UTF-8?q?=E6=93=8D=E4=BD=9C=E7=AC=A6=EF=BC=8C=E4=BF=AE=E6=94=B9=E6=B6=88?= =?UTF-8?q?=E6=81=AFID=E5=AD=97=E6=AE=B5=E7=B1=BB=E5=9E=8B=E4=B8=BATextFie?= =?UTF-8?q?ld?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/message_receive/storage.py | 29 +++++++------- src/chat/person_info/person_info.py | 12 +++--- src/common/database/database_model.py | 4 +- src/common/message_repository.py | 57 ++++++++++++++++++++++++--- 4 files changed, 74 insertions(+), 28 deletions(-) diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py index e81913549..d0041cd51 100644 --- a/src/chat/message_receive/storage.py +++ b/src/chat/message_receive/storage.py @@ -33,14 +33,13 @@ class MessageStorage: chat_info_dict = chat_stream.to_dict() user_info_dict = message.message_info.user_info.to_dict() - # Ensure message_id is an int if the model field is IntegerField - try: - msg_id = int(message.message_info.message_id) - except ValueError: - logger.error( - f"Message ID {message.message_info.message_id} is not a valid integer. Storing as 0 or consider changing model field type." - ) - msg_id = 0 # Or handle as appropriate, e.g. skip storing, or change model field to TextField + # message_id 现在是 TextField,直接使用字符串值 + msg_id = message.message_info.message_id + + # 安全地获取 group_info, 如果为 None 则视为空字典 + group_info_from_chat = chat_info_dict.get("group_info") or {} + # 安全地获取 user_info, 如果为 None 则视为空字典 (以防万一) + user_info_from_chat = chat_info_dict.get("user_info") or {} Messages.create( message_id=msg_id, @@ -49,13 +48,13 @@ class MessageStorage: # Flattened chat_info chat_info_stream_id=chat_info_dict.get("stream_id"), chat_info_platform=chat_info_dict.get("platform"), - chat_info_user_platform=chat_info_dict.get("user_info", {}).get("platform"), - chat_info_user_id=chat_info_dict.get("user_info", {}).get("user_id"), - chat_info_user_nickname=chat_info_dict.get("user_info", {}).get("user_nickname"), - chat_info_user_cardname=chat_info_dict.get("user_info", {}).get("user_cardname"), - chat_info_group_platform=chat_info_dict.get("group_info", {}).get("platform"), - chat_info_group_id=chat_info_dict.get("group_info", {}).get("group_id"), - chat_info_group_name=chat_info_dict.get("group_info", {}).get("group_name"), + chat_info_user_platform=user_info_from_chat.get("platform"), + chat_info_user_id=user_info_from_chat.get("user_id"), + chat_info_user_nickname=user_info_from_chat.get("user_nickname"), + chat_info_user_cardname=user_info_from_chat.get("user_cardname"), + chat_info_group_platform=group_info_from_chat.get("platform"), + chat_info_group_id=group_info_from_chat.get("group_id"), + chat_info_group_name=group_info_from_chat.get("group_name"), chat_info_create_time=float(chat_info_dict.get("create_time", 0.0)), chat_info_last_active_time=float(chat_info_dict.get("last_active_time", 0.0)), # Flattened user_info (message sender) diff --git a/src/chat/person_info/person_info.py b/src/chat/person_info/person_info.py index 1ad2f358f..9d0841c0e 100644 --- a/src/chat/person_info/person_info.py +++ b/src/chat/person_info/person_info.py @@ -39,13 +39,13 @@ logger = get_logger("person_info") person_info_default = { "person_id": None, - "person_name": None, + "person_name": None, # 模型中已设为 null=True,此默认值OK "name_reason": None, - "platform": None, - "user_id": None, - "nickname": None, + "platform": "unknown", # 提供非None的默认值 + "user_id": "unknown", # 提供非None的默认值 + "nickname": "Unknown", # 提供非None的默认值 "relationship_value": 0, - "konw_time": 0, + "know_time": 0, # 修正拼写:konw_time -> know_time "msg_interval": 2000, "msg_interval_list": [], # 将作为 JSON 字符串存储在 Peewee 的 TextField "user_cardname": None, # 注意:此字段不在 PersonInfo Peewee 模型中 @@ -561,7 +561,7 @@ class PersonInfoManager: "platform": platform, "user_id": str(user_id), "nickname": nickname, - "konw_time": int(datetime.datetime.now().timestamp()), + "know_time": int(datetime.datetime.now().timestamp()), # 修正拼写:konw_time -> know_time } model_fields = PersonInfo._meta.fields.keys() filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields} diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py index d77184b74..35f464b5f 100644 --- a/src/common/database/database_model.py +++ b/src/common/database/database_model.py @@ -122,7 +122,7 @@ class Messages(BaseModel): 用于存储消息数据的模型。 """ - message_id = IntegerField(index=True) # 消息 ID + message_id = TextField(index=True) # 消息 ID (更改自 IntegerField) time = DoubleField() # 消息时间戳 chat_id = TextField(index=True) # 对应的 ChatStreams stream_id @@ -208,7 +208,7 @@ class PersonInfo(BaseModel): """ person_id = TextField(unique=True, index=True) # 个人唯一ID - person_name = TextField() # 个人名称 + person_name = TextField(null=True) # 个人名称 (允许为空) name_reason = TextField(null=True) # 名称设定的原因 platform = TextField() # 平台 user_id = TextField(index=True) # 用户ID diff --git a/src/common/message_repository.py b/src/common/message_repository.py index 5bf77c1a9..b1fb5dc46 100644 --- a/src/common/message_repository.py +++ b/src/common/message_repository.py @@ -24,7 +24,7 @@ def find_messages( 根据提供的过滤器、排序和限制条件查找消息。 Args: - message_filter: 查询过滤器字典,键为模型字段名,值为期望值。 + message_filter: 查询过滤器字典,键为模型字段名,值为期望值或包含操作符的字典 (例如 {'$gt': value}). sort: 排序条件列表,例如 [('time', 1)] (1 for asc, -1 for desc)。仅在 limit 为 0 时生效。 limit: 返回的最大文档数,0表示不限制。 limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录(结果仍按时间正序排列)。默认为 'latest'。 @@ -40,11 +40,34 @@ def find_messages( conditions = [] for key, value in message_filter.items(): if hasattr(Messages, key): - conditions.append(getattr(Messages, key) == value) + field = getattr(Messages, key) + if isinstance(value, dict): + # 处理 MongoDB 风格的操作符 + for op, op_value in value.items(): + if op == "$gt": + conditions.append(field > op_value) + elif op == "$lt": + conditions.append(field < op_value) + elif op == "$gte": + conditions.append(field >= op_value) + elif op == "$lte": + conditions.append(field <= op_value) + elif op == "$ne": + conditions.append(field != op_value) + elif op == "$in": + conditions.append(field.in_(op_value)) + elif op == "$nin": + conditions.append(field.not_in(op_value)) + else: + logger.warning( + f"过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。" + ) + else: + # 直接相等比较 + conditions.append(field == value) else: logger.warning(f"过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。") if conditions: - # 使用 *conditions 将所有条件以 AND 连接 query = query.where(*conditions) if limit > 0: @@ -92,7 +115,7 @@ def count_messages(message_filter: dict[str, Any]) -> int: 根据提供的过滤器计算消息数量。 Args: - message_filter: 查询过滤器字典,键为模型字段名,值为期望值。 + message_filter: 查询过滤器字典,键为模型字段名,值为期望值或包含操作符的字典 (例如 {'$gt': value}). Returns: 符合条件的消息数量,如果出错则返回 0。 @@ -105,7 +128,31 @@ def count_messages(message_filter: dict[str, Any]) -> int: conditions = [] for key, value in message_filter.items(): if hasattr(Messages, key): - conditions.append(getattr(Messages, key) == value) + field = getattr(Messages, key) + if isinstance(value, dict): + # 处理 MongoDB 风格的操作符 + for op, op_value in value.items(): + if op == "$gt": + conditions.append(field > op_value) + elif op == "$lt": + conditions.append(field < op_value) + elif op == "$gte": + conditions.append(field >= op_value) + elif op == "$lte": + conditions.append(field <= op_value) + elif op == "$ne": + conditions.append(field != op_value) + elif op == "$in": + conditions.append(field.in_(op_value)) + elif op == "$nin": + conditions.append(field.not_in(op_value)) + else: + logger.warning( + f"计数时,过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。" + ) + else: + # 直接相等比较 + conditions.append(field == value) else: logger.warning(f"计数时,过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。") if conditions: From 9965997139f31bdf539de5a4d5c172b1a9a1229b Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Thu, 15 May 2025 11:04:04 +0000 Subject: [PATCH 17/57] =?UTF-8?q?=F0=9F=A4=96=20=E8=87=AA=E5=8A=A8?= =?UTF-8?q?=E6=A0=BC=E5=BC=8F=E5=8C=96=E4=BB=A3=E7=A0=81=20[skip=20ci]?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/common/message_repository.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/common/message_repository.py b/src/common/message_repository.py index b1fb5dc46..ee69b22b0 100644 --- a/src/common/message_repository.py +++ b/src/common/message_repository.py @@ -59,9 +59,7 @@ def find_messages( elif op == "$nin": conditions.append(field.not_in(op_value)) else: - logger.warning( - f"过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。" - ) + logger.warning(f"过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。") else: # 直接相等比较 conditions.append(field == value) From cda9879bb2afb1e6f7e9d5fbed2f3993dc53a826 Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Fri, 16 May 2025 00:43:46 +0800 Subject: [PATCH 18/57] =?UTF-8?q?Feat=EF=BC=9A=E6=B7=BB=E5=8A=A0=E5=AF=B9A?= =?UTF-8?q?ction=E6=8F=92=E4=BB=B6=E7=9A=84=E6=94=AF=E6=8C=81=EF=BC=8C?= =?UTF-8?q?=E7=8E=B0=E5=9C=A8=E5=8F=AF=E4=BB=A5=E7=BC=96=E5=86=99=E6=8F=92?= =?UTF-8?q?=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- requirements.txt | Bin 800 -> 824 bytes .../expressors/default_expressor.py | 221 +++++++++++++++++- src/chat/focus_chat/heartFC_chat.py | 57 ++--- src/chat/focus_chat/heartFC_sender.py | 1 + .../focus_chat/heartflow_prompt_builder.py | 205 +--------------- src/chat/focus_chat/info/info_base.py | 9 + .../info_processors/chattinginfo_processor.py | 2 + .../info_processors/mind_processor.py | 204 ++-------------- .../info_processors/processor_utils.py | 56 ----- .../info_processors/tool_processor.py | 2 +- .../{action_factory.py => action_manager.py} | 160 ++++++++----- .../focus_chat/planners/actions/__init__.py | 5 + .../planners/actions/base_action.py | 33 ++- .../planners/actions/no_reply_action.py | 40 +--- .../planners/actions/plugin_action.py | 215 +++++++++++++++++ .../planners/actions/reply_action.py | 36 +-- src/chat/focus_chat/planners/planner.py | 71 +++--- .../observation/chatting_observation.py | 6 +- .../observation/hfcloop_observation.py | 21 +- .../heart_flow/observation/observation.py | 1 - src/chat/person_info/person_info.py | 9 + src/plugins.md | 101 ++++++++ src/plugins/__init__.py | 1 + src/plugins/test_plugin/__init__.py | 4 + src/plugins/test_plugin/actions/__init__.py | 6 + .../test_plugin/actions/mute_action.py | 48 ++++ .../test_plugin/actions/online_action.py | 44 ++++ .../test_plugin/actions/test_action.py | 38 +++ 28 files changed, 934 insertions(+), 662 deletions(-) delete mode 100644 src/chat/focus_chat/info_processors/processor_utils.py rename src/chat/focus_chat/planners/{action_factory.py => action_manager.py} (75%) create mode 100644 src/chat/focus_chat/planners/actions/__init__.py create mode 100644 src/chat/focus_chat/planners/actions/plugin_action.py create mode 100644 src/plugins.md create mode 100644 src/plugins/__init__.py create mode 100644 src/plugins/test_plugin/__init__.py create mode 100644 src/plugins/test_plugin/actions/__init__.py create mode 100644 src/plugins/test_plugin/actions/mute_action.py create mode 100644 src/plugins/test_plugin/actions/online_action.py create mode 100644 src/plugins/test_plugin/actions/test_action.py diff --git a/requirements.txt b/requirements.txt index 7abdffb486a90fb45bc11dd3bcc1cf12b15e4afe..1e374f4eb463443ac5696797afd2e0760bde01a5 100644 GIT binary patch delta 32 mcmZ3$wu5bh0<#1cLl#3ZLq0akOcsGA_l|& delta 7 OcmdnNwt#Je0y6*#z5+M^ diff --git a/src/chat/focus_chat/expressors/default_expressor.py b/src/chat/focus_chat/expressors/default_expressor.py index 37c50c0dc..411b08a05 100644 --- a/src/chat/focus_chat/expressors/default_expressor.py +++ b/src/chat/focus_chat/expressors/default_expressor.py @@ -10,7 +10,7 @@ from src.config.config import global_config from src.chat.utils.utils_image import image_path_to_base64 # Local import needed after move from src.chat.utils.timer_calculator import Timer # <--- Import Timer from src.chat.emoji_system.emoji_manager import emoji_manager -from src.chat.focus_chat.heartflow_prompt_builder import prompt_builder +from src.chat.focus_chat.heartflow_prompt_builder import prompt_builder,Prompt from src.chat.focus_chat.heartFC_sender import HeartFCSender from src.chat.utils.utils import process_llm_response from src.chat.utils.info_catcher import info_catcher_manager @@ -18,9 +18,70 @@ from src.manager.mood_manager import mood_manager from src.chat.heart_flow.utils_chat import get_chat_type_and_target_info from src.chat.message_receive.chat_stream import ChatStream from src.chat.focus_chat.hfc_utils import parse_thinking_id_to_timestamp +from src.config.config import global_config +from src.common.logger_manager import get_logger +from src.individuality.individuality import Individuality +from src.chat.utils.prompt_builder import Prompt, global_prompt_manager +from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat +from src.chat.person_info.relationship_manager import relationship_manager +from src.chat.utils.utils import get_embedding +import time +from typing import Union, Optional +from src.common.database import db +from src.chat.utils.utils import get_recent_group_speaker +from src.manager.mood_manager import mood_manager +from src.chat.memory_system.Hippocampus import HippocampusManager +from src.chat.knowledge.knowledge_lib import qa_manager +from src.chat.focus_chat.expressors.exprssion_learner import expression_learner +import random logger = get_logger("expressor") +def init_prompt(): + Prompt( + """ +你可以参考以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中: +{style_habbits} + +你现在正在群里聊天,以下是群里正在进行的聊天内容: +{chat_info} + +以上是聊天内容,你需要了解聊天记录中的内容 + +{chat_target} +你的名字是{bot_name},{prompt_personality},在这聊天中,"{target_message}"引起了你的注意,对这句话,你想表达:{in_mind_reply},原因是:{reason}。你现在要思考怎么回复 +你需要使用合适的语法和句法,参考聊天内容,组织一条日常且口语化的回复。 +请你根据情景使用以下句法: +{grammar_habbits} +回复尽量简短一些。可以参考贴吧,知乎和微博的回复风格,你可以完全重组回复,保留最基本的表达含义就好,但注意回复要简短,但重组后保持语意通顺。 +回复不要浮夸,不要用夸张修辞,平淡一些。不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 ),只输出一条回复就好。 +现在,你说: +""", + "default_expressor_prompt", + ) + + Prompt( + """ +你可以参考以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中: +{style_habbits} + +你现在正在群里聊天,以下是群里正在进行的聊天内容: +{chat_info} + +以上是聊天内容,你需要了解聊天记录中的内容 + +{chat_target} +你的名字是{bot_name},{prompt_personality},在这聊天中,"{target_message}"引起了你的注意,对这句话,你想表达:{in_mind_reply},原因是:{reason}。你现在要思考怎么回复 +你需要使用合适的语法和句法,参考聊天内容,组织一条日常且口语化的回复。 +请你根据情景使用以下句法: +{grammar_habbits} +回复尽量简短一些。可以参考贴吧,知乎和微博的回复风格,你可以完全重组回复,保留最基本的表达含义就好,但注意回复要简短,但重组后保持语意通顺。 +回复不要浮夸,不要用夸张修辞,平淡一些。不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 ),只输出一条回复就好。 +现在,你说: +""", + "default_expressor_private_prompt", # New template for private FOCUSED chat + ) + class DefaultExpressor: def __init__(self, chat_id: str): @@ -106,7 +167,7 @@ class DefaultExpressor: if reply: with Timer("发送消息", cycle_timers): - sent_msg_list = await self._send_response_messages( + sent_msg_list = await self.send_response_messages( anchor_message=anchor_message, thinking_id=thinking_id, response_set=reply, @@ -162,13 +223,10 @@ class DefaultExpressor: # 3. 构建 Prompt with Timer("构建Prompt", {}): # 内部计时器,可选保留 - prompt = await prompt_builder.build_prompt( - build_mode="focus", + prompt = await self.build_prompt_focus( chat_stream=self.chat_stream, # Pass the stream object in_mind_reply=in_mind_reply, reason=reason, - current_mind_info="", - structured_info="", sender_name=sender_name_for_prompt, # Pass determined name target_message=target_message, ) @@ -222,11 +280,111 @@ class DefaultExpressor: logger.error(f"{self.log_prefix}回复生成意外失败: {e}") traceback.print_exc() return None + + async def build_prompt_focus( + self, + reason, + chat_stream, + sender_name, + in_mind_reply, + target_message, + ) -> str: + individuality = Individuality.get_instance() + prompt_personality = individuality.get_prompt(x_person=0, level=2) + + # Determine if it's a group chat + is_group_chat = bool(chat_stream.group_info) + + # Use sender_name passed from caller for private chat, otherwise use a default for group + # Default sender_name for group chat isn't used in the group prompt template, but set for consistency + effective_sender_name = sender_name if not is_group_chat else "某人" + + message_list_before_now = get_raw_msg_before_timestamp_with_chat( + chat_id=chat_stream.stream_id, + timestamp=time.time(), + limit=global_config.observation_context_size, + ) + chat_talking_prompt = await build_readable_messages( + message_list_before_now, + replace_bot_name=True, + merge_messages=True, + timestamp_mode="relative", + read_mark=0.0, + truncate=True, + ) + + ( + learnt_style_expressions, + learnt_grammar_expressions, + personality_expressions, + ) = await expression_learner.get_expression_by_chat_id(chat_stream.stream_id) + + style_habbits = [] + grammar_habbits = [] + # 1. learnt_expressions加权随机选3条 + if learnt_style_expressions: + weights = [expr["count"] for expr in learnt_style_expressions] + selected_learnt = weighted_sample_no_replacement(learnt_style_expressions, weights, 3) + for expr in selected_learnt: + if isinstance(expr, dict) and "situation" in expr and "style" in expr: + style_habbits.append(f"当{expr['situation']}时,使用 {expr['style']}") + # 2. learnt_grammar_expressions加权随机选3条 + if learnt_grammar_expressions: + weights = [expr["count"] for expr in learnt_grammar_expressions] + selected_learnt = weighted_sample_no_replacement(learnt_grammar_expressions, weights, 3) + for expr in selected_learnt: + if isinstance(expr, dict) and "situation" in expr and "style" in expr: + grammar_habbits.append(f"当{expr['situation']}时,使用 {expr['style']}") + # 3. personality_expressions随机选1条 + if personality_expressions: + expr = random.choice(personality_expressions) + if isinstance(expr, dict) and "situation" in expr and "style" in expr: + style_habbits.append(f"当{expr['situation']}时,使用 {expr['style']}") + + style_habbits_str = "\n".join(style_habbits) + grammar_habbits_str = "\n".join(grammar_habbits) + + logger.debug("开始构建 focus prompt") + + # --- Choose template based on chat type --- + if is_group_chat: + template_name = "default_expressor_prompt" + # Group specific formatting variables (already fetched or default) + chat_target_1 = await global_prompt_manager.get_prompt_async("chat_target_group1") + # chat_target_2 = await global_prompt_manager.get_prompt_async("chat_target_group2") + + prompt = await global_prompt_manager.format_prompt( + template_name, + style_habbits=style_habbits_str, + grammar_habbits=grammar_habbits_str, + chat_target=chat_target_1, + chat_info=chat_talking_prompt, + bot_name=global_config.BOT_NICKNAME, + prompt_personality="", + reason=reason, + in_mind_reply=in_mind_reply, + target_message=target_message, + ) + else: # Private chat + template_name = "default_expressor_private_prompt" + prompt = await global_prompt_manager.format_prompt( + template_name, + sender_name=effective_sender_name, # Used in private template + chat_talking_prompt=chat_talking_prompt, + bot_name=global_config.BOT_NICKNAME, + prompt_personality=prompt_personality, + reason=reason, + moderation_prompt=await global_prompt_manager.get_prompt_async("moderation_prompt"), + ) + + + return prompt + # --- 发送器 (Sender) --- # - async def _send_response_messages( - self, anchor_message: Optional[MessageRecv], response_set: List[Tuple[str, str]], thinking_id: str + async def send_response_messages( + self, anchor_message: Optional[MessageRecv], response_set: List[Tuple[str, str]], thinking_id: str = "" ) -> Optional[MessageSending]: """发送回复消息 (尝试锚定到 anchor_message),使用 HeartFCSender""" chat = self.chat_stream @@ -241,7 +399,11 @@ class DefaultExpressor: stream_name = chat_manager.get_stream_name(chat_id) or chat_id # 获取流名称用于日志 # 检查思考过程是否仍在进行,并获取开始时间 - thinking_start_time = await self.heart_fc_sender.get_thinking_start_time(chat_id, thinking_id) + if thinking_id: + thinking_start_time = await self.heart_fc_sender.get_thinking_start_time(chat_id, thinking_id) + else: + thinking_id = "ds"+ str(round(time.time(),2)) + thinking_start_time = time.time() if thinking_start_time is None: logger.error(f"[{stream_name}]思考过程未找到或已结束,无法发送回复。") @@ -274,6 +436,7 @@ class DefaultExpressor: reply_to=reply_to, is_emoji=is_emoji, thinking_id=thinking_id, + thinking_start_time=thinking_start_time, ) try: @@ -295,6 +458,7 @@ class DefaultExpressor: except Exception as e: logger.error(f"{self.log_prefix}发送回复片段 {i} ({part_message_id}) 时失败: {e}") + traceback.print_exc() # 这里可以选择是继续发送下一个片段还是中止 # 在尝试发送完所有片段后,完成原始的 thinking_id 状态 @@ -325,10 +489,10 @@ class DefaultExpressor: reply_to: bool, is_emoji: bool, thinking_id: str, + thinking_start_time: float, ) -> MessageSending: """构建单个发送消息""" - thinking_start_time = await self.heart_fc_sender.get_thinking_start_time(self.chat_id, thinking_id) bot_user_info = UserInfo( user_id=global_config.BOT_QQ, user_nickname=global_config.BOT_NICKNAME, @@ -348,3 +512,40 @@ class DefaultExpressor: ) return bot_message + + + +def weighted_sample_no_replacement(items, weights, k) -> list: + """ + 加权且不放回地随机抽取k个元素。 + + 参数: + items: 待抽取的元素列表 + weights: 每个元素对应的权重(与items等长,且为正数) + k: 需要抽取的元素个数 + 返回: + selected: 按权重加权且不重复抽取的k个元素组成的列表 + + 如果 items 中的元素不足 k 个,就只会返回所有可用的元素 + + 实现思路: + 每次从当前池中按权重加权随机选出一个元素,选中后将其从池中移除,重复k次。 + 这样保证了: + 1. count越大被选中概率越高 + 2. 不会重复选中同一个元素 + """ + selected = [] + pool = list(zip(items, weights)) + for _ in range(min(k, len(pool))): + total = sum(w for _, w in pool) + r = random.uniform(0, total) + upto = 0 + for idx, (item, weight) in enumerate(pool): + upto += weight + if upto >= r: + selected.append(item) + pool.pop(idx) + break + return selected + +init_prompt() \ No newline at end of file diff --git a/src/chat/focus_chat/heartFC_chat.py b/src/chat/focus_chat/heartFC_chat.py index 4a28652d1..7a1671897 100644 --- a/src/chat/focus_chat/heartFC_chat.py +++ b/src/chat/focus_chat/heartFC_chat.py @@ -14,16 +14,17 @@ from src.chat.focus_chat.heartFC_Cycleinfo import CycleDetail from src.chat.focus_chat.info.info_base import InfoBase from src.chat.focus_chat.info_processors.chattinginfo_processor import ChattingInfoProcessor from src.chat.focus_chat.info_processors.mind_processor import MindProcessor -from src.chat.heart_flow.observation.memory_observation import MemoryObservation +from src.chat.focus_chat.info_processors.working_memory_processor import WorkingMemoryProcessor from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation -from src.chat.heart_flow.observation.working_observation import WorkingObservation +from src.chat.heart_flow.observation.working_observation import WorkingMemoryObservation +from src.chat.heart_flow.observation.chatting_observation import ChattingObservation from src.chat.focus_chat.info_processors.tool_processor import ToolProcessor from src.chat.focus_chat.expressors.default_expressor import DefaultExpressor from src.chat.focus_chat.memory_activator import MemoryActivator from src.chat.focus_chat.info_processors.base_processor import BaseProcessor from src.chat.focus_chat.planners.planner import ActionPlanner -from src.chat.focus_chat.planners.action_factory import ActionManager - +from src.chat.focus_chat.planners.action_manager import ActionManager +from src.chat.focus_chat.working_memory.working_memory import WorkingMemory install(extra_lines=3) @@ -57,7 +58,7 @@ async def _handle_cycle_delay(action_taken_this_cycle: bool, cycle_start_time: f class HeartFChatting: """ - 管理一个连续的Plan-Replier-Sender循环 + 管理一个连续的Focus Chat循环 用于在特定聊天流中生成回复。 其生命周期现在由其关联的 SubHeartflow 的 FOCUSED 状态控制。 """ @@ -79,19 +80,22 @@ class HeartFChatting: # 基础属性 self.stream_id: str = chat_id # 聊天流ID self.chat_stream: Optional[ChatStream] = None # 关联的聊天流 - self.observations: List[Observation] = observations # 关联的观察列表,用于监控聊天流状态 self.on_consecutive_no_reply_callback = on_consecutive_no_reply_callback self.log_prefix: str = str(chat_id) # Initial default, will be updated - - self.memory_observation = MemoryObservation(observe_id=self.stream_id) self.hfcloop_observation = HFCloopObservation(observe_id=self.stream_id) - self.working_observation = WorkingObservation(observe_id=self.stream_id) + self.chatting_observation = observations[0] + self.memory_activator = MemoryActivator() + self.working_memory = WorkingMemory(chat_id=self.stream_id) + self.working_observation = WorkingMemoryObservation(observe_id=self.stream_id, working_memory=self.working_memory) + self.expressor = DefaultExpressor(chat_id=self.stream_id) self.action_manager = ActionManager() self.action_planner = ActionPlanner(log_prefix=self.log_prefix, action_manager=self.action_manager) - - + + self.hfcloop_observation.set_action_manager(self.action_manager) + + self.all_observations = observations # --- 处理器列表 --- self.processors: List[BaseProcessor] = [] self._register_default_processors() @@ -108,9 +112,7 @@ class HeartFChatting: self._cycle_counter = 0 self._cycle_history: Deque[CycleDetail] = deque(maxlen=10) # 保留最近10个循环的信息 self._current_cycle: Optional[CycleDetail] = None - self.total_no_reply_count: int = 0 # 连续不回复计数器 self._shutting_down: bool = False # 关闭标志位 - self.total_waiting_time: float = 0.0 # 累计等待时间 async def _initialize(self) -> bool: """ @@ -151,6 +153,7 @@ class HeartFChatting: self.processors.append(ChattingInfoProcessor()) self.processors.append(MindProcessor(subheartflow_id=self.stream_id)) self.processors.append(ToolProcessor(subheartflow_id=self.stream_id)) + self.processors.append(WorkingMemoryProcessor(subheartflow_id=self.stream_id)) logger.info(f"{self.log_prefix} 已注册默认处理器: {[p.__class__.__name__ for p in self.processors]}") async def start(self): @@ -349,13 +352,12 @@ class HeartFChatting: async def _observe_process_plan_action_loop(self, cycle_timers: dict, thinking_id: str) -> tuple[bool, str]: try: with Timer("观察", cycle_timers): - await self.observations[0].observe() - await self.memory_observation.observe() + # await self.observations[0].observe() + await self.chatting_observation.observe() await self.working_observation.observe() await self.hfcloop_observation.observe() observations: List[Observation] = [] - observations.append(self.observations[0]) - observations.append(self.memory_observation) + observations.append(self.chatting_observation) observations.append(self.working_observation) observations.append(self.hfcloop_observation) @@ -363,6 +365,8 @@ class HeartFChatting: "observations": observations, } + self.all_observations = observations + with Timer("回忆", cycle_timers): running_memorys = await self.memory_activator.activate_memory(observations) @@ -395,8 +399,7 @@ class HeartFChatting: elif action_type == "no_reply": action_str = "不回复" else: - action_type = "unknown" - action_str = "未知动作" + action_str = action_type logger.info(f"{self.log_prefix} 麦麦决定'{action_str}', 原因'{reasoning}'") @@ -452,14 +455,14 @@ class HeartFChatting: reasoning=reasoning, cycle_timers=cycle_timers, thinking_id=thinking_id, - observations=self.observations, + observations=self.all_observations, expressor=self.expressor, chat_stream=self.chat_stream, current_cycle=self._current_cycle, log_prefix=self.log_prefix, on_consecutive_no_reply_callback=self.on_consecutive_no_reply_callback, - total_no_reply_count=self.total_no_reply_count, - total_waiting_time=self.total_waiting_time, + # total_no_reply_count=self.total_no_reply_count, + # total_waiting_time=self.total_waiting_time, shutting_down=self._shutting_down, ) @@ -470,14 +473,6 @@ class HeartFChatting: # 处理动作并获取结果 success, reply_text = await action_handler.handle_action() - # 更新状态计数器 - if action == "no_reply": - self.total_no_reply_count = getattr(action_handler, "total_no_reply_count", self.total_no_reply_count) - self.total_waiting_time = getattr(action_handler, "total_waiting_time", self.total_waiting_time) - elif action == "reply": - self.total_no_reply_count = 0 - self.total_waiting_time = 0.0 - return success, reply_text except Exception as e: @@ -526,5 +521,3 @@ class HeartFChatting: if last_n is not None: history = history[-last_n:] return [cycle.to_dict() for cycle in history] - - diff --git a/src/chat/focus_chat/heartFC_sender.py b/src/chat/focus_chat/heartFC_sender.py index 057668579..81d463b02 100644 --- a/src/chat/focus_chat/heartFC_sender.py +++ b/src/chat/focus_chat/heartFC_sender.py @@ -106,6 +106,7 @@ class HeartFCSender: and not message.is_private_message() and message.reply.processed_plain_text != "[System Trigger Context]" ): + message.set_reply(message.reply) logger.debug(f"[{chat_id}] 应用 set_reply 逻辑: {message.processed_plain_text[:20]}...") await message.process() diff --git a/src/chat/focus_chat/heartflow_prompt_builder.py b/src/chat/focus_chat/heartflow_prompt_builder.py index 55fb79b46..830a1cfad 100644 --- a/src/chat/focus_chat/heartflow_prompt_builder.py +++ b/src/chat/focus_chat/heartflow_prompt_builder.py @@ -6,14 +6,13 @@ from src.chat.utils.chat_message_builder import build_readable_messages, get_raw from src.chat.person_info.relationship_manager import relationship_manager from src.chat.utils.utils import get_embedding import time -from typing import Union, Optional, Dict, Any +from typing import Union, Optional from src.common.database import db from src.chat.utils.utils import get_recent_group_speaker from src.manager.mood_manager import mood_manager from src.chat.memory_system.Hippocampus import HippocampusManager from src.chat.knowledge.knowledge_lib import qa_manager from src.chat.focus_chat.expressors.exprssion_learner import expression_learner -import traceback import random @@ -21,27 +20,6 @@ logger = get_logger("prompt") def init_prompt(): - Prompt( - """ -你可以参考以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中: -{style_habbits} - -你现在正在群里聊天,以下是群里正在进行的聊天内容: -{chat_info} - -以上是聊天内容,你需要了解聊天记录中的内容 - -{chat_target} -你的名字是{bot_name},{prompt_personality},在这聊天中,"{target_message}"引起了你的注意,对这句话,你想表达:{in_mind_reply},原因是:{reason}。你现在要思考怎么回复 -你需要使用合适的语法和句法,参考聊天内容,组织一条日常且口语化的回复。 -请你根据情景使用以下句法: -{grammar_habbits} -回复尽量简短一些。可以参考贴吧,知乎和微博的回复风格,你可以完全重组回复,保留最基本的表达含义就好,但注意回复要简短,但重组后保持语意通顺。 -回复不要浮夸,不要用夸张修辞,平淡一些。不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 ),只输出一条回复就好。 -现在,你说: -""", - "heart_flow_prompt", - ) Prompt( """ @@ -82,29 +60,6 @@ def init_prompt(): Prompt("\n你有以下这些**知识**:\n{prompt_info}\n请你**记住上面的知识**,之后可能会用到。\n", "knowledge_prompt") - # --- Template for HeartFChatting (FOCUSED mode) --- - Prompt( - """ -{info_from_tools} -你正在和 {sender_name} 私聊。 -聊天记录如下: -{chat_talking_prompt} -现在你想要回复。 - -你需要扮演一位网名叫{bot_name}的人进行回复,这个人的特点是:"{prompt_personality}"。 -你正在和 {sender_name} 私聊, 现在请你读读你们之前的聊天记录,然后给出日常且口语化的回复,平淡一些。 -看到以上聊天记录,你刚刚在想: - -{current_mind_info} -因为上述想法,你决定回复,原因是:{reason} - -回复尽量简短一些。请注意把握聊天内容,{reply_style2}。{prompt_ger},不要复读自己说的话 -{reply_style1},说中文,不要刻意突出自身学科背景,注意只输出回复内容。 -{moderation_prompt}。注意:回复不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 )。""", - "heart_flow_private_prompt", # New template for private FOCUSED chat - ) - - # --- Template for NormalChat (CHAT mode) --- Prompt( """ {memory_prompt} @@ -126,118 +81,6 @@ def init_prompt(): ) -async def _build_prompt_focus( - reason, current_mind_info, structured_info, chat_stream, sender_name, in_mind_reply, target_message -) -> str: - individuality = Individuality.get_instance() - prompt_personality = individuality.get_prompt(x_person=0, level=2) - - # Determine if it's a group chat - is_group_chat = bool(chat_stream.group_info) - - # Use sender_name passed from caller for private chat, otherwise use a default for group - # Default sender_name for group chat isn't used in the group prompt template, but set for consistency - effective_sender_name = sender_name if not is_group_chat else "某人" - - message_list_before_now = get_raw_msg_before_timestamp_with_chat( - chat_id=chat_stream.stream_id, - timestamp=time.time(), - limit=global_config.observation_context_size, - ) - chat_talking_prompt = await build_readable_messages( - message_list_before_now, - replace_bot_name=True, - merge_messages=True, - timestamp_mode="relative", - read_mark=0.0, - truncate=True, - ) - - if structured_info: - structured_info_prompt = await global_prompt_manager.format_prompt( - "info_from_tools", structured_info=structured_info - ) - else: - structured_info_prompt = "" - - # 从/data/expression/对应chat_id/expressions.json中读取表达方式 - ( - learnt_style_expressions, - learnt_grammar_expressions, - personality_expressions, - ) = await expression_learner.get_expression_by_chat_id(chat_stream.stream_id) - - style_habbits = [] - grammar_habbits = [] - # 1. learnt_expressions加权随机选3条 - if learnt_style_expressions: - weights = [expr["count"] for expr in learnt_style_expressions] - selected_learnt = weighted_sample_no_replacement(learnt_style_expressions, weights, 3) - for expr in selected_learnt: - if isinstance(expr, dict) and "situation" in expr and "style" in expr: - style_habbits.append(f"当{expr['situation']}时,使用 {expr['style']}") - # 2. learnt_grammar_expressions加权随机选3条 - if learnt_grammar_expressions: - weights = [expr["count"] for expr in learnt_grammar_expressions] - selected_learnt = weighted_sample_no_replacement(learnt_grammar_expressions, weights, 3) - for expr in selected_learnt: - if isinstance(expr, dict) and "situation" in expr and "style" in expr: - grammar_habbits.append(f"当{expr['situation']}时,使用 {expr['style']}") - # 3. personality_expressions随机选1条 - if personality_expressions: - expr = random.choice(personality_expressions) - if isinstance(expr, dict) and "situation" in expr and "style" in expr: - style_habbits.append(f"当{expr['situation']}时,使用 {expr['style']}") - - style_habbits_str = "\n".join(style_habbits) - grammar_habbits_str = "\n".join(grammar_habbits) - - logger.debug("开始构建 focus prompt") - - # --- Choose template based on chat type --- - if is_group_chat: - template_name = "heart_flow_prompt" - # Group specific formatting variables (already fetched or default) - chat_target_1 = await global_prompt_manager.get_prompt_async("chat_target_group1") - # chat_target_2 = await global_prompt_manager.get_prompt_async("chat_target_group2") - - prompt = await global_prompt_manager.format_prompt( - template_name, - # info_from_tools=structured_info_prompt, - style_habbits=style_habbits_str, - grammar_habbits=grammar_habbits_str, - chat_target=chat_target_1, # Used in group template - # chat_talking_prompt=chat_talking_prompt, - chat_info=chat_talking_prompt, - bot_name=global_config.BOT_NICKNAME, - # prompt_personality=prompt_personality, - prompt_personality="", - reason=reason, - in_mind_reply=in_mind_reply, - target_message=target_message, - # moderation_prompt=await global_prompt_manager.get_prompt_async("moderation_prompt"), - # sender_name is not used in the group template - ) - else: # Private chat - template_name = "heart_flow_private_prompt" - prompt = await global_prompt_manager.format_prompt( - template_name, - info_from_tools=structured_info_prompt, - sender_name=effective_sender_name, # Used in private template - chat_talking_prompt=chat_talking_prompt, - bot_name=global_config.BOT_NICKNAME, - prompt_personality=prompt_personality, - # chat_target and chat_target_2 are not used in private template - current_mind_info=current_mind_info, - reason=reason, - moderation_prompt=await global_prompt_manager.get_prompt_async("moderation_prompt"), - ) - # --- End choosing template --- - - # logger.debug(f"focus_chat_prompt (is_group={is_group_chat}): \n{prompt}") - return prompt - - class PromptBuilder: def __init__(self): self.prompt_built = "" @@ -257,17 +100,6 @@ class PromptBuilder: ) -> Optional[str]: if build_mode == "normal": return await self._build_prompt_normal(chat_stream, message_txt or "", sender_name) - - elif build_mode == "focus": - return await _build_prompt_focus( - reason, - current_mind_info, - structured_info, - chat_stream, - sender_name, - in_mind_reply, - target_message, - ) return None async def _build_prompt_normal(self, chat_stream, message_txt: str, sender_name: str = "某人") -> str: @@ -689,40 +521,5 @@ class PromptBuilder: # 返回所有找到的内容,用换行分隔 return "\n".join(str(result["content"]) for result in results) - -def weighted_sample_no_replacement(items, weights, k) -> list: - """ - 加权且不放回地随机抽取k个元素。 - - 参数: - items: 待抽取的元素列表 - weights: 每个元素对应的权重(与items等长,且为正数) - k: 需要抽取的元素个数 - 返回: - selected: 按权重加权且不重复抽取的k个元素组成的列表 - - 如果 items 中的元素不足 k 个,就只会返回所有可用的元素 - - 实现思路: - 每次从当前池中按权重加权随机选出一个元素,选中后将其从池中移除,重复k次。 - 这样保证了: - 1. count越大被选中概率越高 - 2. 不会重复选中同一个元素 - """ - selected = [] - pool = list(zip(items, weights)) - for _ in range(min(k, len(pool))): - total = sum(w for _, w in pool) - r = random.uniform(0, total) - upto = 0 - for idx, (item, weight) in enumerate(pool): - upto += weight - if upto >= r: - selected.append(item) - pool.pop(idx) - break - return selected - - init_prompt() prompt_builder = PromptBuilder() diff --git a/src/chat/focus_chat/info/info_base.py b/src/chat/focus_chat/info/info_base.py index 7779d913a..fbf060ba6 100644 --- a/src/chat/focus_chat/info/info_base.py +++ b/src/chat/focus_chat/info/info_base.py @@ -17,6 +17,7 @@ class InfoBase: type: str = "base" data: Dict[str, Any] = field(default_factory=dict) + processed_info:str = "" def get_type(self) -> str: """获取信息类型 @@ -58,3 +59,11 @@ class InfoBase: if isinstance(value, list): return value return [] + + def get_processed_info(self) -> str: + """获取处理后的信息 + + Returns: + str: 处理后的信息字符串 + """ + return self.processed_info diff --git a/src/chat/focus_chat/info_processors/chattinginfo_processor.py b/src/chat/focus_chat/info_processors/chattinginfo_processor.py index 12bc8560a..0accc2a34 100644 --- a/src/chat/focus_chat/info_processors/chattinginfo_processor.py +++ b/src/chat/focus_chat/info_processors/chattinginfo_processor.py @@ -54,6 +54,8 @@ class ChattingInfoProcessor(BaseProcessor): for obs in observations: # print(f"obs: {obs}") if isinstance(obs, ChattingObservation): + # print("1111111111111111111111读取111111111111111") + obs_info = ObsInfo() await self.chat_compress(obs) diff --git a/src/chat/focus_chat/info_processors/mind_processor.py b/src/chat/focus_chat/info_processors/mind_processor.py index 1a104e123..95233a9f7 100644 --- a/src/chat/focus_chat/info_processors/mind_processor.py +++ b/src/chat/focus_chat/info_processors/mind_processor.py @@ -16,11 +16,6 @@ from .base_processor import BaseProcessor from src.chat.focus_chat.info.mind_info import MindInfo from typing import List, Optional from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation -from src.chat.focus_chat.info_processors.processor_utils import ( - calculate_similarity, - calculate_replacement_probability, - get_spark, -) from typing import Dict from src.chat.focus_chat.info.info_base import InfoBase @@ -28,7 +23,6 @@ logger = get_logger("processor") def init_prompt(): - # --- Group Chat Prompt --- group_prompt = """ 你的名字是{bot_name} {memory_str} @@ -44,31 +38,29 @@ def init_prompt(): 现在请你继续输出观察和规划,输出要求: 1. 先关注未读新消息的内容和近期回复历史 2. 根据新信息,修改和删除之前的观察和规划 -3. 根据聊天内容继续输出观察和规划,{hf_do_next} +3. 根据聊天内容继续输出观察和规划 4. 注意群聊的时间线索,话题由谁发起,进展状况如何,思考聊天的时间线。 6. 语言简洁自然,不要分点,不要浮夸,不要修辞,仅输出思考内容就好""" Prompt(group_prompt, "sub_heartflow_prompt_before") - # --- Private Chat Prompt --- private_prompt = """ +你的名字是{bot_name} {memory_str} {extra_info} {relation_prompt} -你的名字是{bot_name},{prompt_personality},你现在{mood_info} {cycle_info_block} -现在是{time_now},你正在上网,和 {chat_target_name} 私聊,以下是你们的聊天内容: +现在是{time_now},你正在上网,和qq群里的网友们聊天,以下是正在进行的聊天内容: {chat_observe_info} -以下是你之前对聊天的观察和规划: + +以下是你之前对聊天的观察和规划,你的名字是{bot_name}: {last_mind} -请仔细阅读聊天内容,想想你和 {chat_target_name} 的关系,回顾你们刚刚的交流,你刚刚发言和对方的反应,思考聊天的主题。 -请思考你要不要回复以及如何回复对方。 -思考并输出你的内心想法 -输出要求: -1. 根据聊天内容生成你的想法,{hf_do_next} -2. 不要分点、不要使用表情符号 -3. 避免多余符号(冒号、引号、括号等) -4. 语言简洁自然,不要浮夸 -5. 如果你刚发言,对方没有回复你,请谨慎回复""" + +现在请你继续输出观察和规划,输出要求: +1. 先关注未读新消息的内容和近期回复历史 +2. 根据新信息,修改和删除之前的观察和规划 +3. 根据聊天内容继续输出观察和规划 +4. 注意群聊的时间线索,话题由谁发起,进展状况如何,思考聊天的时间线。 +6. 语言简洁自然,不要分点,不要浮夸,不要修辞,仅输出思考内容就好""" Prompt(private_prompt, "sub_heartflow_prompt_private_before") @@ -210,45 +202,28 @@ class MindProcessor(BaseProcessor): for person in person_list: relation_prompt += await relationship_manager.build_relationship_info(person, is_id=True) - # 构建个性部分 - # prompt_personality = individuality.get_prompt(x_person=2, level=2) - # 获取当前时间 - time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) - - spark_prompt = get_spark() - - # ---------- 5. 构建最终提示词 ---------- template_name = "sub_heartflow_prompt_before" if is_group_chat else "sub_heartflow_prompt_private_before" logger.debug(f"{self.log_prefix} 使用{'群聊' if is_group_chat else '私聊'}思考模板") prompt = (await global_prompt_manager.get_prompt_async(template_name)).format( + bot_name=individuality.name, memory_str=memory_str, extra_info=self.structured_info_str, - # prompt_personality=prompt_personality, relation_prompt=relation_prompt, - bot_name=individuality.name, - time_now=time_now, + time_now=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), chat_observe_info=chat_observe_info, - # mood_info="mood_info", - hf_do_next=spark_prompt, last_mind=previous_mind, cycle_info_block=hfcloop_observe_info, chat_target_name=chat_target_name, ) - # 在构建完提示词后,生成最终的prompt字符串 - final_prompt = prompt - - content = "" # 初始化内容变量 + content = "(不知道该想些什么...)" try: - # 调用LLM生成响应 - response, _ = await self.llm_model.generate_response_async(prompt=final_prompt) - - # 直接使用LLM返回的文本响应作为 content - content = response if response else "" - + content, _ = await self.llm_model.generate_response_async(prompt=prompt) + if not content: + logger.warning(f"{self.log_prefix} LLM返回空结果,思考失败。") except Exception as e: # 处理总体异常 logger.error(f"{self.log_prefix} 执行LLM请求或处理响应时出错: {e}") @@ -256,16 +231,8 @@ class MindProcessor(BaseProcessor): content = "思考过程中出现错误" # 记录初步思考结果 - logger.debug(f"{self.log_prefix} 思考prompt: \n{final_prompt}\n") - - # 处理空响应情况 - if not content: - content = "(不知道该想些什么...)" - logger.warning(f"{self.log_prefix} LLM返回空结果,思考失败。") - - # ---------- 8. 更新思考状态并返回结果 ---------- + logger.debug(f"{self.log_prefix} 思考prompt: \n{prompt}\n") logger.info(f"{self.log_prefix} 思考结果: {content}") - # 更新当前思考内容 self.update_current_mind(content) return content @@ -275,138 +242,5 @@ class MindProcessor(BaseProcessor): self.past_mind.append(self.current_mind) self.current_mind = response - def de_similar(self, previous_mind, new_content): - try: - similarity = calculate_similarity(previous_mind, new_content) - replacement_prob = calculate_replacement_probability(similarity) - logger.debug(f"{self.log_prefix} 新旧想法相似度: {similarity:.2f}, 替换概率: {replacement_prob:.2f}") - - # 定义词语列表 (移到判断之前) - yu_qi_ci_liebiao = ["嗯", "哦", "啊", "唉", "哈", "唔"] - zhuan_zhe_liebiao = ["但是", "不过", "然而", "可是", "只是"] - cheng_jie_liebiao = ["然后", "接着", "此外", "而且", "另外"] - zhuan_jie_ci_liebiao = zhuan_zhe_liebiao + cheng_jie_liebiao - - if random.random() < replacement_prob: - # 相似度非常高时,尝试去重或特殊处理 - if similarity == 1.0: - logger.debug(f"{self.log_prefix} 想法完全重复 (相似度 1.0),执行特殊处理...") - # 随机截取大约一半内容 - if len(new_content) > 1: # 避免内容过短无法截取 - split_point = max( - 1, len(new_content) // 2 + random.randint(-len(new_content) // 4, len(new_content) // 4) - ) - truncated_content = new_content[:split_point] - else: - truncated_content = new_content # 如果只有一个字符或者为空,就不截取了 - - # 添加语气词和转折/承接词 - yu_qi_ci = random.choice(yu_qi_ci_liebiao) - zhuan_jie_ci = random.choice(zhuan_jie_ci_liebiao) - content = f"{yu_qi_ci}{zhuan_jie_ci},{truncated_content}" - logger.debug(f"{self.log_prefix} 想法重复,特殊处理后: {content}") - - else: - # 相似度较高但非100%,执行标准去重逻辑 - logger.debug(f"{self.log_prefix} 执行概率性去重 (概率: {replacement_prob:.2f})...") - logger.debug( - f"{self.log_prefix} previous_mind类型: {type(previous_mind)}, new_content类型: {type(new_content)}" - ) - - matcher = difflib.SequenceMatcher(None, previous_mind, new_content) - logger.debug(f"{self.log_prefix} matcher类型: {type(matcher)}") - - deduplicated_parts = [] - last_match_end_in_b = 0 - - # 获取并记录所有匹配块 - matching_blocks = matcher.get_matching_blocks() - logger.debug(f"{self.log_prefix} 匹配块数量: {len(matching_blocks)}") - logger.debug( - f"{self.log_prefix} 匹配块示例(前3个): {matching_blocks[:3] if len(matching_blocks) > 3 else matching_blocks}" - ) - - # get_matching_blocks()返回形如[(i, j, n), ...]的列表,其中i是a中的索引,j是b中的索引,n是匹配的长度 - for idx, match in enumerate(matching_blocks): - if not isinstance(match, tuple): - logger.error(f"{self.log_prefix} 匹配块 {idx} 不是元组类型,而是 {type(match)}: {match}") - continue - - try: - _i, j, n = match # 解包元组为三个变量 - logger.debug(f"{self.log_prefix} 匹配块 {idx}: i={_i}, j={j}, n={n}") - - if last_match_end_in_b < j: - # 确保添加的是字符串,而不是元组 - try: - non_matching_part = new_content[last_match_end_in_b:j] - logger.debug( - f"{self.log_prefix} 添加非匹配部分: '{non_matching_part}', 类型: {type(non_matching_part)}" - ) - if not isinstance(non_matching_part, str): - logger.warning( - f"{self.log_prefix} 非匹配部分不是字符串类型: {type(non_matching_part)}" - ) - non_matching_part = str(non_matching_part) - deduplicated_parts.append(non_matching_part) - except Exception as e: - logger.error(f"{self.log_prefix} 处理非匹配部分时出错: {e}") - logger.error(traceback.format_exc()) - last_match_end_in_b = j + n - except Exception as e: - logger.error(f"{self.log_prefix} 处理匹配块时出错: {e}") - logger.error(traceback.format_exc()) - - logger.debug(f"{self.log_prefix} 去重前部分列表: {deduplicated_parts}") - logger.debug(f"{self.log_prefix} 列表元素类型: {[type(part) for part in deduplicated_parts]}") - - # 确保所有元素都是字符串 - deduplicated_parts = [str(part) for part in deduplicated_parts] - - # 防止列表为空 - if not deduplicated_parts: - logger.warning(f"{self.log_prefix} 去重后列表为空,添加空字符串") - deduplicated_parts = [""] - - logger.debug(f"{self.log_prefix} 处理后的部分列表: {deduplicated_parts}") - - try: - deduplicated_content = "".join(deduplicated_parts).strip() - logger.debug(f"{self.log_prefix} 拼接后的去重内容: '{deduplicated_content}'") - except Exception as e: - logger.error(f"{self.log_prefix} 拼接去重内容时出错: {e}") - logger.error(traceback.format_exc()) - deduplicated_content = "" - - if deduplicated_content: - # 根据概率决定是否添加词语 - prefix_str = "" - if random.random() < 0.3: # 30% 概率添加语气词 - prefix_str += random.choice(yu_qi_ci_liebiao) - if random.random() < 0.7: # 70% 概率添加转折/承接词 - prefix_str += random.choice(zhuan_jie_ci_liebiao) - - # 组合最终结果 - if prefix_str: - content = f"{prefix_str},{deduplicated_content}" # 更新 content - logger.debug(f"{self.log_prefix} 去重并添加引导词后: {content}") - else: - content = deduplicated_content # 更新 content - logger.debug(f"{self.log_prefix} 去重后 (未添加引导词): {content}") - else: - logger.warning(f"{self.log_prefix} 去重后内容为空,保留原始LLM输出: {new_content}") - content = new_content # 保留原始 content - else: - logger.debug(f"{self.log_prefix} 未执行概率性去重 (概率: {replacement_prob:.2f})") - # content 保持 new_content 不变 - - except Exception as e: - logger.error(f"{self.log_prefix} 应用概率性去重或特殊处理时出错: {e}") - logger.error(traceback.format_exc()) - # 出错时保留原始 content - content = new_content - - return content - init_prompt() diff --git a/src/chat/focus_chat/info_processors/processor_utils.py b/src/chat/focus_chat/info_processors/processor_utils.py deleted file mode 100644 index 77cdc7a6b..000000000 --- a/src/chat/focus_chat/info_processors/processor_utils.py +++ /dev/null @@ -1,56 +0,0 @@ -import difflib -import random -import time - - -def calculate_similarity(text_a: str, text_b: str) -> float: - """ - 计算两个文本字符串的相似度。 - """ - if not text_a or not text_b: - return 0.0 - matcher = difflib.SequenceMatcher(None, text_a, text_b) - return matcher.ratio() - - -def calculate_replacement_probability(similarity: float) -> float: - """ - 根据相似度计算替换的概率。 - 规则: - - 相似度 <= 0.4: 概率 = 0 - - 相似度 >= 0.9: 概率 = 1 - - 相似度 == 0.6: 概率 = 0.7 - - 0.4 < 相似度 <= 0.6: 线性插值 (0.4, 0) 到 (0.6, 0.7) - - 0.6 < 相似度 < 0.9: 线性插值 (0.6, 0.7) 到 (0.9, 1.0) - """ - if similarity <= 0.4: - return 0.0 - elif similarity >= 0.9: - return 1.0 - elif 0.4 < similarity <= 0.6: - # p = 3.5 * s - 1.4 - probability = 3.5 * similarity - 1.4 - return max(0.0, probability) - else: # 0.6 < similarity < 0.9 - # p = s + 0.1 - probability = similarity + 0.1 - return min(1.0, max(0.0, probability)) - - -def get_spark(): - local_random = random.Random() - current_minute = int(time.strftime("%M")) - local_random.seed(current_minute) - - hf_options = [ - ("可以参考之前的想法,在原来想法的基础上继续思考", 0.2), - ("可以参考之前的想法,在原来的想法上尝试新的话题", 0.4), - ("不要太深入", 0.2), - ("进行深入思考", 0.2), - ] - # 加权随机选择思考指导 - hf_do_next = local_random.choices( - [option[0] for option in hf_options], weights=[option[1] for option in hf_options], k=1 - )[0] - - return hf_do_next diff --git a/src/chat/focus_chat/info_processors/tool_processor.py b/src/chat/focus_chat/info_processors/tool_processor.py index 8840c1ae4..39e0c293c 100644 --- a/src/chat/focus_chat/info_processors/tool_processor.py +++ b/src/chat/focus_chat/info_processors/tool_processor.py @@ -155,7 +155,7 @@ class ToolProcessor(BaseProcessor): ) # 调用LLM,专注于工具使用 - logger.debug(f"开始执行工具调用{prompt}") + # logger.debug(f"开始执行工具调用{prompt}") response, _, tool_calls = await self.llm_model.generate_response_tool_async(prompt=prompt, tools=tools) logger.debug(f"获取到工具原始输出:\n{tool_calls}") diff --git a/src/chat/focus_chat/planners/action_factory.py b/src/chat/focus_chat/planners/action_manager.py similarity index 75% rename from src/chat/focus_chat/planners/action_factory.py rename to src/chat/focus_chat/planners/action_manager.py index 257156a25..72ff4a73e 100644 --- a/src/chat/focus_chat/planners/action_factory.py +++ b/src/chat/focus_chat/planners/action_manager.py @@ -1,18 +1,18 @@ -from typing import Dict, List, Optional, Callable, Coroutine, Type, Any, Union -import os -import importlib -from src.chat.focus_chat.planners.actions.base_action import BaseAction, _ACTION_REGISTRY, _DEFAULT_ACTIONS +from typing import Dict, List, Optional, Callable, Coroutine, Type, Any +from src.chat.focus_chat.planners.actions.base_action import BaseAction, _ACTION_REGISTRY from src.chat.heart_flow.observation.observation import Observation from src.chat.focus_chat.expressors.default_expressor import DefaultExpressor from src.chat.message_receive.chat_stream import ChatStream from src.chat.focus_chat.heartFC_Cycleinfo import CycleDetail from src.common.logger_manager import get_logger +import importlib +import pkgutil +import os # 导入动作类,确保装饰器被执行 -from src.chat.focus_chat.planners.actions.reply_action import ReplyAction -from src.chat.focus_chat.planners.actions.no_reply_action import NoReplyAction +import src.chat.focus_chat.planners.actions # noqa -logger = get_logger("action_factory") +logger = get_logger("action_manager") # 定义动作信息类型 ActionInfo = Dict[str, Any] @@ -31,20 +31,18 @@ class ActionManager: self._using_actions: Dict[str, ActionInfo] = {} # 临时备份原始使用中的动作 self._original_actions_backup: Optional[Dict[str, ActionInfo]] = None - + # 默认动作集,仅作为快照,用于恢复默认 self._default_actions: Dict[str, ActionInfo] = {} - + # 加载所有已注册动作 self._load_registered_actions() + # 加载插件动作 + self._load_plugin_actions() + # 初始化时将默认动作加载到使用中的动作 self._using_actions = self._default_actions.copy() - - # logger.info(f"当前可用动作: {list(self._using_actions.keys())}") - # for action_name, action_info in self._using_actions.items(): - # logger.info(f"动作名称: {action_name}, 动作信息: {action_info}") - def _load_registered_actions(self) -> None: """ @@ -54,37 +52,78 @@ class ActionManager: # 从_ACTION_REGISTRY获取所有已注册动作 for action_name, action_class in _ACTION_REGISTRY.items(): # 获取动作相关信息 - action_description:str = getattr(action_class, "action_description", "") - action_parameters:dict[str:str] = getattr(action_class, "action_parameters", {}) - action_require:list[str] = getattr(action_class, "action_require", []) - is_default:bool = getattr(action_class, "default", False) + # 不读取插件动作和基类 + if action_name == "base_action" or action_name == "plugin_action": + continue + + action_description: str = getattr(action_class, "action_description", "") + action_parameters: dict[str:str] = getattr(action_class, "action_parameters", {}) + action_require: list[str] = getattr(action_class, "action_require", []) + is_default: bool = getattr(action_class, "default", False) + if action_name and action_description: # 创建动作信息字典 action_info = { "description": action_description, "parameters": action_parameters, - "require": action_require + "require": action_require, } - - # 注册2 - print("注册2") - print(action_info) - + # 添加到所有已注册的动作 self._registered_actions[action_name] = action_info - + # 添加到默认动作(如果是默认动作) if is_default: self._default_actions[action_name] = action_info - + logger.info(f"所有注册动作: {list(self._registered_actions.keys())}") logger.info(f"默认动作: {list(self._default_actions.keys())}") - # for action_name, action_info in self._default_actions.items(): - # logger.info(f"动作名称: {action_name}, 动作信息: {action_info}") - + for action_name, action_info in self._default_actions.items(): + logger.info(f"动作名称: {action_name}, 动作信息: {action_info}") + except Exception as e: logger.error(f"加载已注册动作失败: {e}") + + def _load_plugin_actions(self) -> None: + """ + 加载所有插件目录中的动作 + """ + try: + # 检查插件目录是否存在 + plugin_path = "src.plugins" + plugin_dir = plugin_path.replace('.', os.path.sep) + if not os.path.exists(plugin_dir): + logger.info(f"插件目录 {plugin_dir} 不存在,跳过插件动作加载") + return + + # 导入插件包 + try: + plugins_package = importlib.import_module(plugin_path) + except ImportError as e: + logger.error(f"导入插件包失败: {e}") + return + + # 遍历插件包中的所有子包 + for _, plugin_name, is_pkg in pkgutil.iter_modules(plugins_package.__path__, plugins_package.__name__ + '.'): + if not is_pkg: + continue + + # 检查插件是否有actions子包 + plugin_actions_path = f"{plugin_name}.actions" + try: + # 尝试导入插件的actions包 + importlib.import_module(plugin_actions_path) + logger.info(f"成功加载插件动作模块: {plugin_actions_path}") + except ImportError as e: + logger.debug(f"插件 {plugin_name} 没有actions子包或导入失败: {e}") + continue + + # 再次从_ACTION_REGISTRY获取所有动作(包括刚刚从插件加载的) + self._load_registered_actions() + + except Exception as e: + logger.error(f"加载插件动作失败: {e}") def create_action( self, @@ -99,8 +138,8 @@ class ActionManager: current_cycle: CycleDetail, log_prefix: str, on_consecutive_no_reply_callback: Callable[[], Coroutine[None, None, None]], - total_no_reply_count: int = 0, - total_waiting_time: float = 0.0, + # total_no_reply_count: int = 0, + # total_waiting_time: float = 0.0, shutting_down: bool = False, ) -> Optional[BaseAction]: """ @@ -129,14 +168,14 @@ class ActionManager: if action_name not in self._using_actions: logger.warning(f"当前不可用的动作类型: {action_name}") return None - + handler_class = _ACTION_REGISTRY.get(action_name) if not handler_class: logger.warning(f"未注册的动作类型: {action_name}") return None try: - # 创建动作实例并传递所有必要参数 + # 创建动作实例 instance = handler_class( action_name=action_name, action_data=action_data, @@ -144,16 +183,16 @@ class ActionManager: cycle_timers=cycle_timers, thinking_id=thinking_id, observations=observations, - on_consecutive_no_reply_callback=on_consecutive_no_reply_callback, - current_cycle=current_cycle, - log_prefix=log_prefix, - total_no_reply_count=total_no_reply_count, - total_waiting_time=total_waiting_time, - shutting_down=shutting_down, expressor=expressor, chat_stream=chat_stream, + current_cycle=current_cycle, + log_prefix=log_prefix, + on_consecutive_no_reply_callback=on_consecutive_no_reply_callback, + # total_no_reply_count=total_no_reply_count, + # total_waiting_time=total_waiting_time, + shutting_down=shutting_down, ) - + return instance except Exception as e: @@ -167,7 +206,7 @@ class ActionManager: def get_default_actions(self) -> Dict[str, ActionInfo]: """获取默认动作集""" return self._default_actions.copy() - + def get_using_actions(self) -> Dict[str, ActionInfo]: """获取当前正在使用的动作集""" return self._using_actions.copy() @@ -175,21 +214,21 @@ class ActionManager: def add_action_to_using(self, action_name: str) -> bool: """ 添加已注册的动作到当前使用的动作集 - + Args: action_name: 动作名称 - + Returns: bool: 添加是否成功 """ if action_name not in self._registered_actions: logger.warning(f"添加失败: 动作 {action_name} 未注册") return False - + if action_name in self._using_actions: logger.info(f"动作 {action_name} 已经在使用中") return True - + self._using_actions[action_name] = self._registered_actions[action_name] logger.info(f"添加动作 {action_name} 到使用集") return True @@ -197,17 +236,17 @@ class ActionManager: def remove_action_from_using(self, action_name: str) -> bool: """ 从当前使用的动作集中移除指定动作 - + Args: action_name: 动作名称 - + Returns: bool: 移除是否成功 """ if action_name not in self._using_actions: logger.warning(f"移除失败: 动作 {action_name} 不在当前使用的动作集中") return False - + del self._using_actions[action_name] logger.info(f"已从使用集中移除动作 {action_name}") return True @@ -215,30 +254,26 @@ class ActionManager: def add_action(self, action_name: str, description: str, parameters: Dict = None, require: List = None) -> bool: """ 添加新的动作到注册集 - + Args: action_name: 动作名称 description: 动作描述 parameters: 动作参数定义,默认为空字典 require: 动作依赖项,默认为空列表 - + Returns: bool: 添加是否成功 """ if action_name in self._registered_actions: return False - + if parameters is None: parameters = {} if require is None: require = [] - - action_info = { - "description": description, - "parameters": parameters, - "require": require - } - + + action_info = {"description": description, "parameters": parameters, "require": require} + self._registered_actions[action_name] = action_info return True @@ -264,7 +299,7 @@ class ActionManager: if self._original_actions_backup is not None: self._using_actions = self._original_actions_backup.copy() self._original_actions_backup = None - + def restore_default_actions(self) -> None: """恢复默认动作集到使用集""" self._using_actions = self._default_actions.copy() @@ -273,15 +308,12 @@ class ActionManager: def get_action(self, action_name: str) -> Optional[Type[BaseAction]]: """ 获取指定动作的处理器类 - + Args: action_name: 动作名称 - + Returns: Optional[Type[BaseAction]]: 动作处理器类,如果不存在则返回None """ return _ACTION_REGISTRY.get(action_name) - -# 创建全局实例 -ActionFactory = ActionManager() diff --git a/src/chat/focus_chat/planners/actions/__init__.py b/src/chat/focus_chat/planners/actions/__init__.py new file mode 100644 index 000000000..435d0d4b4 --- /dev/null +++ b/src/chat/focus_chat/planners/actions/__init__.py @@ -0,0 +1,5 @@ +# 导入所有动作模块以确保装饰器被执行 +from . import reply_action # noqa +from . import no_reply_action # noqa + +# 在此处添加更多动作模块导入 \ No newline at end of file diff --git a/src/chat/focus_chat/planners/actions/base_action.py b/src/chat/focus_chat/planners/actions/base_action.py index 7c77c300c..82d259677 100644 --- a/src/chat/focus_chat/planners/actions/base_action.py +++ b/src/chat/focus_chat/planners/actions/base_action.py @@ -12,7 +12,7 @@ _DEFAULT_ACTIONS: Dict[str, str] = {} def register_action(cls): """ 动作注册装饰器 - + 用法: @register_action class MyAction(BaseAction): @@ -24,22 +24,22 @@ def register_action(cls): if not hasattr(cls, "action_name") or not hasattr(cls, "action_description"): logger.error(f"动作类 {cls.__name__} 缺少必要的属性: action_name 或 action_description") return cls - - action_name = getattr(cls, "action_name") - action_description = getattr(cls, "action_description") + + action_name = cls.action_name + action_description = cls.action_description is_default = getattr(cls, "default", False) - + if not action_name or not action_description: logger.error(f"动作类 {cls.__name__} 的 action_name 或 action_description 为空") return cls - + # 将动作类注册到全局注册表 _ACTION_REGISTRY[action_name] = cls - + # 如果是默认动作,添加到默认动作集 if is_default: _DEFAULT_ACTIONS[action_name] = action_description - + logger.info(f"已注册动作: {action_name} -> {cls.__name__},默认: {is_default}") return cls @@ -60,15 +60,14 @@ class BaseAction(ABC): cycle_timers: 计时器字典 thinking_id: 思考ID """ - #每个动作必须实现 - self.action_name:str = "base_action" - self.action_description:str = "基础动作" - self.action_parameters:dict = {} - self.action_require:list[str] = [] - - self.default:bool = False - - + # 每个动作必须实现 + self.action_name: str = "base_action" + self.action_description: str = "基础动作" + self.action_parameters: dict = {} + self.action_require: list[str] = [] + + self.default: bool = False + self.action_data = action_data self.reasoning = reasoning self.cycle_timers = cycle_timers diff --git a/src/chat/focus_chat/planners/actions/no_reply_action.py b/src/chat/focus_chat/planners/actions/no_reply_action.py index a29812c7a..406ddbdc2 100644 --- a/src/chat/focus_chat/planners/actions/no_reply_action.py +++ b/src/chat/focus_chat/planners/actions/no_reply_action.py @@ -29,7 +29,7 @@ class NoReplyAction(BaseAction): action_require = [ "话题无关/无聊/不感兴趣/不懂", "最后一条消息是你自己发的且无人回应你", - "你发送了太多消息,且无人回复" + "你发送了太多消息,且无人回复", ] default = True @@ -43,10 +43,10 @@ class NoReplyAction(BaseAction): on_consecutive_no_reply_callback: Callable[[], Coroutine[None, None, None]], current_cycle: CycleDetail, log_prefix: str, - total_no_reply_count: int = 0, - total_waiting_time: float = 0.0, + # total_no_reply_count: int = 0, + # total_waiting_time: float = 0.0, shutting_down: bool = False, - **kwargs + **kwargs, ): """初始化不回复动作处理器 @@ -69,8 +69,8 @@ class NoReplyAction(BaseAction): self.on_consecutive_no_reply_callback = on_consecutive_no_reply_callback self._current_cycle = current_cycle self.log_prefix = log_prefix - self.total_no_reply_count = total_no_reply_count - self.total_waiting_time = total_waiting_time + # self.total_no_reply_count = total_no_reply_count + # self.total_waiting_time = total_waiting_time self._shutting_down = shutting_down async def handle_action(self) -> Tuple[bool, str]: @@ -96,34 +96,6 @@ class NoReplyAction(BaseAction): # 从计时器获取实际等待时间 current_waiting = self.cycle_timers.get("等待新消息", 0.0) - if not self._shutting_down: - self.total_no_reply_count += 1 - self.total_waiting_time += current_waiting # 累加等待时间 - logger.debug( - f"{self.log_prefix} 连续不回复计数增加: {self.total_no_reply_count}/{CONSECUTIVE_NO_REPLY_THRESHOLD}, " - f"本次等待: {current_waiting:.2f}秒, 累计等待: {self.total_waiting_time:.2f}秒" - ) - - # 检查是否同时达到次数和时间阈值 - time_threshold = 0.66 * WAITING_TIME_THRESHOLD * CONSECUTIVE_NO_REPLY_THRESHOLD - if ( - self.total_no_reply_count >= CONSECUTIVE_NO_REPLY_THRESHOLD - and self.total_waiting_time >= time_threshold - ): - logger.info( - f"{self.log_prefix} 连续不回复达到阈值 ({self.total_no_reply_count}次) " - f"且累计等待时间达到 {self.total_waiting_time:.2f}秒 (阈值 {time_threshold}秒)," - f"调用回调请求状态转换" - ) - # 调用回调。注意:这里不重置计数器和时间,依赖回调函数成功改变状态来隐式重置上下文。 - await self.on_consecutive_no_reply_callback() - elif self.total_no_reply_count >= CONSECUTIVE_NO_REPLY_THRESHOLD: - # 仅次数达到阈值,但时间未达到 - logger.debug( - f"{self.log_prefix} 连续不回复次数达到阈值 ({self.total_no_reply_count}次) " - f"但累计等待时间 {self.total_waiting_time:.2f}秒 未达到时间阈值 ({time_threshold}秒),暂不调用回调" - ) - # else: 次数和时间都未达到阈值,不做处理 return True, "" # 不回复动作没有回复文本 diff --git a/src/chat/focus_chat/planners/actions/plugin_action.py b/src/chat/focus_chat/planners/actions/plugin_action.py new file mode 100644 index 000000000..aec879e97 --- /dev/null +++ b/src/chat/focus_chat/planners/actions/plugin_action.py @@ -0,0 +1,215 @@ +import traceback +from typing import Tuple, Dict, List, Any, Optional +from src.chat.focus_chat.planners.actions.base_action import BaseAction, register_action +from src.chat.heart_flow.observation.chatting_observation import ChattingObservation +from src.chat.focus_chat.hfc_utils import create_empty_anchor_message +from src.common.logger_manager import get_logger +from src.chat.person_info.person_info import person_info_manager +from abc import abstractmethod + +logger = get_logger("plugin_action") + +class PluginAction(BaseAction): + """插件动作基类 + + 封装了主程序内部依赖,提供简化的API接口给插件开发者 + """ + + def __init__(self, action_data: dict, reasoning: str, cycle_timers: dict, thinking_id: str, **kwargs): + """初始化插件动作基类""" + super().__init__(action_data, reasoning, cycle_timers, thinking_id) + + # 存储内部服务和对象引用 + self._services = {} + + # 从kwargs提取必要的内部服务 + if "observations" in kwargs: + self._services["observations"] = kwargs["observations"] + if "expressor" in kwargs: + self._services["expressor"] = kwargs["expressor"] + if "chat_stream" in kwargs: + self._services["chat_stream"] = kwargs["chat_stream"] + if "current_cycle" in kwargs: + self._services["current_cycle"] = kwargs["current_cycle"] + + self.log_prefix = kwargs.get("log_prefix", "") + + async def get_user_id_by_person_name(self, person_name: str) -> Tuple[str, str]: + """根据用户名获取用户ID""" + person_id = person_info_manager.get_person_id_by_person_name(person_name) + user_id = await person_info_manager.get_value(person_id, "user_id") + platform = await person_info_manager.get_value(person_id, "platform") + return platform, user_id + + # 提供简化的API方法 + async def send_message(self, text: str, target: Optional[str] = None) -> bool: + """发送消息的简化方法 + + Args: + text: 要发送的消息文本 + target: 目标消息(可选) + + Returns: + bool: 是否发送成功 + """ + try: + expressor = self._services.get("expressor") + chat_stream = self._services.get("chat_stream") + + if not expressor or not chat_stream: + logger.error(f"{self.log_prefix} 无法发送消息:缺少必要的内部服务") + return False + + # 构造简化的动作数据 + reply_data = { + "text": text, + "target": target or "", + "emojis": [] + } + + # 获取锚定消息(如果有) + observations = self._services.get("observations", []) + + chatting_observation: ChattingObservation = next( + obs for obs in observations + if isinstance(obs, ChattingObservation) + ) + anchor_message = chatting_observation.search_message_by_text(reply_data["target"]) + + # 如果没有找到锚点消息,创建一个占位符 + if not anchor_message: + logger.info(f"{self.log_prefix} 未找到锚点消息,创建占位符") + anchor_message = await create_empty_anchor_message( + chat_stream.platform, chat_stream.group_info, chat_stream + ) + else: + anchor_message.update_chat_stream(chat_stream) + + response_set = [ + ("text", text), + ] + + # 调用内部方法发送消息 + success = await expressor.send_response_messages( + anchor_message=anchor_message, + response_set=response_set, + ) + + return success + except Exception as e: + logger.error(f"{self.log_prefix} 发送消息时出错: {e}") + traceback.print_exc() + return False + + + async def send_message_by_expressor(self, text: str, target: Optional[str] = None) -> bool: + """发送消息的简化方法 + + Args: + text: 要发送的消息文本 + target: 目标消息(可选) + + Returns: + bool: 是否发送成功 + """ + try: + expressor = self._services.get("expressor") + chat_stream = self._services.get("chat_stream") + + if not expressor or not chat_stream: + logger.error(f"{self.log_prefix} 无法发送消息:缺少必要的内部服务") + return False + + # 构造简化的动作数据 + reply_data = { + "text": text, + "target": target or "", + "emojis": [] + } + + # 获取锚定消息(如果有) + observations = self._services.get("observations", []) + + chatting_observation: ChattingObservation = next( + obs for obs in observations + if isinstance(obs, ChattingObservation) + ) + anchor_message = chatting_observation.search_message_by_text(reply_data["target"]) + + # 如果没有找到锚点消息,创建一个占位符 + if not anchor_message: + logger.info(f"{self.log_prefix} 未找到锚点消息,创建占位符") + anchor_message = await create_empty_anchor_message( + chat_stream.platform, chat_stream.group_info, chat_stream + ) + else: + anchor_message.update_chat_stream(chat_stream) + + # 调用内部方法发送消息 + success, _ = await expressor.deal_reply( + cycle_timers=self.cycle_timers, + action_data=reply_data, + anchor_message=anchor_message, + reasoning=self.reasoning, + thinking_id=self.thinking_id + ) + + return success + except Exception as e: + logger.error(f"{self.log_prefix} 发送消息时出错: {e}") + return False + + def get_chat_type(self) -> str: + """获取当前聊天类型 + + Returns: + str: 聊天类型 ("group" 或 "private") + """ + chat_stream = self._services.get("chat_stream") + if chat_stream and hasattr(chat_stream, "group_info"): + return "group" if chat_stream.group_info else "private" + return "unknown" + + def get_recent_messages(self, count: int = 5) -> List[Dict[str, Any]]: + """获取最近的消息 + + Args: + count: 要获取的消息数量 + + Returns: + List[Dict]: 消息列表,每个消息包含发送者、内容等信息 + """ + messages = [] + observations = self._services.get("observations", []) + + if observations and len(observations) > 0: + obs = observations[0] + if hasattr(obs, "get_talking_message"): + raw_messages = obs.get_talking_message() + # 转换为简化格式 + for msg in raw_messages[-count:]: + simple_msg = { + "sender": msg.get("sender", "未知"), + "content": msg.get("content", ""), + "timestamp": msg.get("timestamp", 0) + } + messages.append(simple_msg) + + return messages + + @abstractmethod + async def process(self) -> Tuple[bool, str]: + """插件处理逻辑,子类必须实现此方法 + + Returns: + Tuple[bool, str]: (是否执行成功, 回复文本) + """ + pass + + async def handle_action(self) -> Tuple[bool, str]: + """实现BaseAction的抽象方法,调用子类的process方法 + + Returns: + Tuple[bool, str]: (是否执行成功, 回复文本) + """ + return await self.process() diff --git a/src/chat/focus_chat/planners/actions/reply_action.py b/src/chat/focus_chat/planners/actions/reply_action.py index 7b2e88fa0..51e3b8eaa 100644 --- a/src/chat/focus_chat/planners/actions/reply_action.py +++ b/src/chat/focus_chat/planners/actions/reply_action.py @@ -1,10 +1,8 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- - from src.common.logger_manager import get_logger -from src.chat.utils.timer_calculator import Timer from src.chat.focus_chat.planners.actions.base_action import BaseAction, register_action -from typing import Tuple, List, Optional +from typing import Tuple, List from src.chat.heart_flow.observation.observation import Observation from src.chat.focus_chat.expressors.default_expressor import DefaultExpressor from src.chat.message_receive.chat_stream import ChatStream @@ -22,23 +20,22 @@ class ReplyAction(BaseAction): 处理构建和发送消息回复的动作。 """ - action_name:str = "reply" - action_description:str = "表达想法,可以只包含文本、表情或两者都有" - action_parameters:dict[str:str] = { + action_name: str = "reply" + action_description: str = "表达想法,可以只包含文本、表情或两者都有" + action_parameters: dict[str:str] = { "text": "你想要表达的内容(可选)", "emojis": "描述当前使用表情包的场景(可选)", "target": "你想要回复的原始文本内容(非必须,仅文本,不包含发送者)(可选)", } - action_require:list[str] = [ + action_require: list[str] = [ "有实质性内容需要表达", "有人提到你,但你还没有回应他", "在合适的时候添加表情(不要总是添加)", - "如果你要回复特定某人的某句话,或者你想回复较早的消息,请在target中指定那句话的原始文本", - "除非有明确的回复目标,如果选择了target,不用特别提到某个人的人名", + "如果你有明确的,要回复特定某人的某句话,或者你想回复较早的消息,请在target中指定那句话的原始文本", "一次只回复一个人,一次只回复一个话题,突出重点", "如果是自己发的消息想继续,需自然衔接", "避免重复或评价自己的发言,不要和自己聊天", - "注意:回复尽量简短一些。可以参考贴吧,知乎和微博的回复风格,回复不要浮夸,不要用夸张修辞,平淡一些。" + "注意:回复尽量简短一些。可以参考贴吧,知乎和微博的回复风格,回复不要浮夸,不要用夸张修辞,平淡一些。不要有额外的符号,尽量简单简短", ] default = True @@ -54,7 +51,7 @@ class ReplyAction(BaseAction): chat_stream: ChatStream, current_cycle: CycleDetail, log_prefix: str, - **kwargs + **kwargs, ): """初始化回复动作处理器 @@ -89,9 +86,9 @@ class ReplyAction(BaseAction): reasoning=self.reasoning, reply_data=self.action_data, cycle_timers=self.cycle_timers, - thinking_id=self.thinking_id + thinking_id=self.thinking_id, ) - + async def _handle_reply( self, reasoning: str, reply_data: dict, cycle_timers: dict, thinking_id: str ) -> tuple[bool, str]: @@ -105,13 +102,16 @@ class ReplyAction(BaseAction): "emojis": "微笑" # 表情关键词列表(可选) } """ - # 重置连续不回复计数器 - self.total_no_reply_count = 0 - self.total_waiting_time = 0.0 # 从聊天观察获取锚定消息 - observations: ChattingObservation = self.observations[0] - anchor_message = observations.serch_message_by_text(reply_data["target"]) + chatting_observation: ChattingObservation = next( + obs for obs in self.observations + if isinstance(obs, ChattingObservation) + ) + if reply_data.get("target"): + anchor_message = chatting_observation.search_message_by_text(reply_data["target"]) + else: + anchor_message = None # 如果没有找到锚点消息,创建一个占位符 if not anchor_message: diff --git a/src/chat/focus_chat/planners/planner.py b/src/chat/focus_chat/planners/planner.py index bb87e1da7..79044a5a6 100644 --- a/src/chat/focus_chat/planners/planner.py +++ b/src/chat/focus_chat/planners/planner.py @@ -4,7 +4,6 @@ from typing import List, Dict, Any, Optional from rich.traceback import install from src.chat.models.utils_model import LLMRequest from src.config.config import global_config -from src.chat.focus_chat.heartflow_prompt_builder import prompt_builder from src.chat.focus_chat.info.info_base import InfoBase from src.chat.focus_chat.info.obs_info import ObsInfo from src.chat.focus_chat.info.cycle_info import CycleInfo @@ -13,16 +12,21 @@ from src.chat.focus_chat.info.structured_info import StructuredInfo from src.common.logger_manager import get_logger from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.individuality.individuality import Individuality -from src.chat.focus_chat.planners.action_factory import ActionManager -from src.chat.focus_chat.planners.action_factory import ActionInfo +from src.chat.focus_chat.planners.action_manager import ActionManager +from src.chat.focus_chat.planners.action_manager import ActionInfo + logger = get_logger("planner") install(extra_lines=3) + def init_prompt(): Prompt( - """你的名字是{bot_name},{prompt_personality},{chat_context_description}。需要基于以下信息决定如何参与对话: + """{extra_info_block} + +你的名字是{bot_name},{prompt_personality},{chat_context_description}。需要基于以下信息决定如何参与对话: {chat_content_block} + {mind_info_block} {cycle_info_block} @@ -44,20 +48,20 @@ def init_prompt(): }} 请输出你的决策 JSON:""", -"planner_prompt",) - + "planner_prompt", + ) + Prompt( """ action_name: {action_name} 描述:{action_description} 参数: - {action_parameters} +{action_parameters} 动作要求: - {action_require} - """, +{action_require}""", "action_prompt", ) - + class ActionPlanner: def __init__(self, log_prefix: str, action_manager: ActionManager): @@ -68,7 +72,7 @@ class ActionPlanner: max_tokens=1000, request_type="action_planning", # 用于动作规划 ) - + self.action_manager = action_manager async def plan(self, all_plan_info: List[InfoBase], cycle_timers: dict) -> Dict[str, Any]: @@ -85,6 +89,7 @@ class ActionPlanner: try: # 获取观察信息 + extra_info: list[str] = [] for info in all_plan_info: if isinstance(info, ObsInfo): logger.debug(f"{self.log_prefix} 观察信息: {info}") @@ -104,9 +109,11 @@ class ActionPlanner: elif isinstance(info, StructuredInfo): logger.debug(f"{self.log_prefix} 结构化信息: {info}") structured_info = info.get_data() + else: + extra_info.append(info.get_processed_info()) current_available_actions = self.action_manager.get_using_actions() - + # --- 构建提示词 (调用修改后的 PromptBuilder 方法) --- prompt = await self.build_planner_prompt( is_group_chat=is_group_chat, # <-- Pass HFC state @@ -116,6 +123,7 @@ class ActionPlanner: # structured_info=structured_info, # <-- Pass SubMind info current_available_actions=current_available_actions, # <-- Pass determined actions cycle_info=cycle_info, # <-- Pass cycle info + extra_info=extra_info, ) # --- 调用 LLM (普通文本生成) --- @@ -142,15 +150,13 @@ class ActionPlanner: extracted_action = parsed_json.get("action", "no_reply") extracted_reasoning = parsed_json.get("reasoning", "LLM未提供理由") - # 新的reply格式 - if extracted_action == "reply": - action_data = { - "text": parsed_json.get("text", []), - "emojis": parsed_json.get("emojis", []), - "target": parsed_json.get("target", ""), - } - else: - action_data = {} # 其他动作可能不需要额外数据 + # 将所有其他属性添加到action_data + action_data = {} + for key, value in parsed_json.items(): + if key not in ["action", "reasoning"]: + action_data[key] = value + + # 对于reply动作不需要额外处理,因为相关字段已经在上面的循环中添加到action_data if extracted_action not in current_available_actions: logger.warning( @@ -197,7 +203,6 @@ class ActionPlanner: # 返回结果字典 return plan_result - async def build_planner_prompt( self, is_group_chat: bool, # Now passed as argument @@ -206,6 +211,7 @@ class ActionPlanner: current_mind: Optional[str], current_available_actions: Dict[str, ActionInfo], cycle_info: Optional[str], + extra_info: list[str], ) -> str: """构建 Planner LLM 的提示词 (获取模板并填充数据)""" try: @@ -218,7 +224,6 @@ class ActionPlanner: ) chat_context_description = f"你正在和 {chat_target_name} 私聊" - chat_content_block = "" if observed_messages_str: chat_content_block = f"聊天记录:\n{observed_messages_str}" @@ -234,7 +239,6 @@ class ActionPlanner: individuality = Individuality.get_instance() personality_block = individuality.get_prompt(x_person=2, level=2) - action_options_block = "" for using_actions_name, using_actions_info in current_available_actions.items(): # print(using_actions_name) @@ -242,29 +246,29 @@ class ActionPlanner: # print(using_actions_info["parameters"]) # print(using_actions_info["require"]) # print(using_actions_info["description"]) - + using_action_prompt = await global_prompt_manager.get_prompt_async("action_prompt") - + param_text = "" for param_name, param_description in using_actions_info["parameters"].items(): - param_text += f"{param_name}: {param_description}\n" - + param_text += f" {param_name}: {param_description}\n" + require_text = "" for require_item in using_actions_info["require"]: - require_text += f"- {require_item}\n" - + require_text += f" - {require_item}\n" + using_action_prompt = using_action_prompt.format( action_name=using_actions_name, action_description=using_actions_info["description"], action_parameters=param_text, action_require=require_text, ) - + action_options_block += using_action_prompt - + extra_info_block = "\n".join(extra_info) + extra_info_block = f"以下是一些额外的信息,现在请你阅读以下内容,进行决策\n{extra_info_block}\n以上是一些额外的信息,现在请你阅读以下内容,进行决策" - planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt") prompt = planner_prompt_template.format( bot_name=global_config.BOT_NICKNAME, @@ -274,6 +278,7 @@ class ActionPlanner: mind_info_block=mind_info_block, cycle_info_block=cycle_info, action_options_text=action_options_block, + extra_info_block=extra_info_block, ) return prompt diff --git a/src/chat/heart_flow/observation/chatting_observation.py b/src/chat/heart_flow/observation/chatting_observation.py index a51eba5e2..017f24da9 100644 --- a/src/chat/heart_flow/observation/chatting_observation.py +++ b/src/chat/heart_flow/observation/chatting_observation.py @@ -14,6 +14,7 @@ from typing import Optional import difflib from src.chat.message_receive.message import MessageRecv # 添加 MessageRecv 导入 from src.chat.heart_flow.observation.observation import Observation + from src.common.logger_manager import get_logger from src.chat.heart_flow.utils_chat import get_chat_type_and_target_info from src.chat.utils.prompt_builder import Prompt @@ -43,6 +44,7 @@ class ChattingObservation(Observation): def __init__(self, chat_id): super().__init__(chat_id) self.chat_id = chat_id + self.platform = "qq" # --- Initialize attributes (defaults) --- self.is_group_chat: bool = False @@ -105,7 +107,7 @@ class ChattingObservation(Observation): mid_memory_str += f"{mid_memory['theme']}\n" return mid_memory_str + "现在群里正在聊:\n" + self.talking_message_str - def serch_message_by_text(self, text: str) -> Optional[MessageRecv]: + def search_message_by_text(self, text: str) -> Optional[MessageRecv]: """ 根据回复的纯文本 1. 在talking_message中查找最新的,最匹配的消息 @@ -150,7 +152,7 @@ class ChattingObservation(Observation): } message_info = { - "platform": find_msg.get("platform"), + "platform": self.platform, "message_id": find_msg.get("message_id"), "time": find_msg.get("time"), "group_info": group_info, diff --git a/src/chat/heart_flow/observation/hfcloop_observation.py b/src/chat/heart_flow/observation/hfcloop_observation.py index 470671e28..d950e3512 100644 --- a/src/chat/heart_flow/observation/hfcloop_observation.py +++ b/src/chat/heart_flow/observation/hfcloop_observation.py @@ -3,6 +3,7 @@ from datetime import datetime from src.common.logger_manager import get_logger from src.chat.focus_chat.heartFC_Cycleinfo import CycleDetail +from src.chat.focus_chat.planners.action_manager import ActionManager from typing import List # Import the new utility function @@ -16,14 +17,16 @@ class HFCloopObservation: self.observe_id = observe_id self.last_observe_time = datetime.now().timestamp() # 初始化为当前时间 self.history_loop: List[CycleDetail] = [] + self.action_manager = ActionManager() def get_observe_info(self): return self.observe_info def add_loop_info(self, loop_info: CycleDetail): - # logger.debug(f"添加循环信息111111111111111111111111111111111111: {loop_info}") - # print(f"添加循环信息111111111111111111111111111111111111: {loop_info}") self.history_loop.append(loop_info) + + def set_action_manager(self, action_manager: ActionManager): + self.action_manager = action_manager async def observe(self): recent_active_cycles: List[CycleDetail] = [] @@ -62,7 +65,6 @@ class HFCloopObservation: if cycle_info_block: cycle_info_block = f"\n你最近的回复\n{cycle_info_block}\n" else: - # 如果最近的活动循环不是文本回复,或者没有活动循环 cycle_info_block = "\n" # 获取history_loop中最新添加的 @@ -72,8 +74,17 @@ class HFCloopObservation: end_time = last_loop.end_time if start_time is not None and end_time is not None: time_diff = int(end_time - start_time) - cycle_info_block += f"\n距离你上一次阅读消息已经过去了{time_diff}分钟\n" + if time_diff > 60: + cycle_info_block += f"\n距离你上一次阅读消息已经过去了{time_diff/60}分钟\n" + else: + cycle_info_block += f"\n距离你上一次阅读消息已经过去了{time_diff}秒\n" else: - cycle_info_block += "\n无法获取上一次阅读消息的时间\n" + cycle_info_block += "\n你还没看过消息\n" + + using_actions = self.action_manager.get_using_actions() + for action_name, action_info in using_actions.items(): + action_description = action_info["description"] + cycle_info_block += f"\n你在聊天中可以使用{action_name},这个动作的描述是{action_description}\n" + self.observe_info = cycle_info_block diff --git a/src/chat/heart_flow/observation/observation.py b/src/chat/heart_flow/observation/observation.py index 97e254fc0..8ab9ab9a4 100644 --- a/src/chat/heart_flow/observation/observation.py +++ b/src/chat/heart_flow/observation/observation.py @@ -5,7 +5,6 @@ from src.common.logger_manager import get_logger logger = get_logger("observation") - # 所有观察的基类 class Observation: def __init__(self, observe_id): diff --git a/src/chat/person_info/person_info.py b/src/chat/person_info/person_info.py index 605b86b23..2460ab4ff 100644 --- a/src/chat/person_info/person_info.py +++ b/src/chat/person_info/person_info.py @@ -94,6 +94,15 @@ class PersonInfoManager: return True else: return False + + def get_person_id_by_person_name(self, person_name: str): + """根据用户名获取用户ID""" + document = db.person_info.find_one({"person_name": person_name}) + if document: + return document["person_id"] + else: + return "" + @staticmethod async def create_person_info(person_id: str, data: dict = None): diff --git a/src/plugins.md b/src/plugins.md new file mode 100644 index 000000000..71ca741a6 --- /dev/null +++ b/src/plugins.md @@ -0,0 +1,101 @@ +# 如何编写MaiBot插件 + +## 基本步骤 + +1. 在`src/plugins/你的插件名/actions/`目录下创建插件文件 +2. 继承`PluginAction`基类 +3. 实现`process`方法 + +## 插件结构示例 + +```python +from src.common.logger_manager import get_logger +from src.chat.focus_chat.planners.actions.plugin_action import PluginAction, register_action +from typing import Tuple + +logger = get_logger("your_action_name") + +@register_action +class YourAction(PluginAction): + """你的动作描述""" + + action_name = "your_action_name" # 动作名称,必须唯一 + action_description = "这个动作的详细描述,会展示给用户" + action_parameters = { + "param1": "参数1的说明(可选)", + "param2": "参数2的说明(可选)" + } + action_require = [ + "使用场景1", + "使用场景2" + ] + default = False # 是否默认启用 + + async def process(self) -> Tuple[bool, str]: + """插件核心逻辑""" + # 你的代码逻辑... + return True, "执行结果" +``` + +## 可用的API方法 + +插件可以使用`PluginAction`基类提供的以下API: + +### 1. 发送消息 + +```python +await self.send_message("要发送的文本", target="可选的回复目标") +``` + +### 2. 获取聊天类型 + +```python +chat_type = self.get_chat_type() # 返回 "group" 或 "private" 或 "unknown" +``` + +### 3. 获取最近消息 + +```python +messages = self.get_recent_messages(count=5) # 获取最近5条消息 +# 返回格式: [{"sender": "发送者", "content": "内容", "timestamp": 时间戳}, ...] +``` + +### 4. 获取动作参数 + +```python +param_value = self.action_data.get("param_name", "默认值") +``` + +### 5. 日志记录 + +```python +logger.info(f"{self.log_prefix} 你的日志信息") +logger.warning("警告信息") +logger.error("错误信息") +``` + +## 返回值说明 + +`process`方法必须返回一个元组,包含两个元素: +- 第一个元素(bool): 表示动作是否执行成功 +- 第二个元素(str): 执行结果的文本描述 + +```python +return True, "执行成功的消息" +# 或 +return False, "执行失败的原因" +``` + +## 最佳实践 + +1. 使用`action_parameters`清晰定义你的动作需要的参数 +2. 使用`action_require`描述何时应该使用你的动作 +3. 使用`action_description`准确描述你的动作功能 +4. 使用`logger`记录重要信息,方便调试 +5. 避免操作底层系统,尽量使用`PluginAction`提供的API + +## 注册与加载 + +插件会在系统启动时自动加载,只要放在正确的目录并添加了`@register_action`装饰器。 + +若设置`default = True`,插件会自动添加到默认动作集;否则需要在系统中手动启用。 diff --git a/src/plugins/__init__.py b/src/plugins/__init__.py new file mode 100644 index 000000000..b557a4258 --- /dev/null +++ b/src/plugins/__init__.py @@ -0,0 +1 @@ +"""插件系统包""" \ No newline at end of file diff --git a/src/plugins/test_plugin/__init__.py b/src/plugins/test_plugin/__init__.py new file mode 100644 index 000000000..867ef417c --- /dev/null +++ b/src/plugins/test_plugin/__init__.py @@ -0,0 +1,4 @@ +"""测试插件包""" +""" +这是一个测试插件 +""" \ No newline at end of file diff --git a/src/plugins/test_plugin/actions/__init__.py b/src/plugins/test_plugin/actions/__init__.py new file mode 100644 index 000000000..8599d2326 --- /dev/null +++ b/src/plugins/test_plugin/actions/__init__.py @@ -0,0 +1,6 @@ +"""测试插件动作模块""" + +# 导入所有动作模块以确保装饰器被执行 +from . import test_action # noqa +# from . import online_action # noqa +from . import mute_action # noqa \ No newline at end of file diff --git a/src/plugins/test_plugin/actions/mute_action.py b/src/plugins/test_plugin/actions/mute_action.py new file mode 100644 index 000000000..723571806 --- /dev/null +++ b/src/plugins/test_plugin/actions/mute_action.py @@ -0,0 +1,48 @@ +from src.common.logger_manager import get_logger +from src.chat.focus_chat.planners.actions.plugin_action import PluginAction, register_action +from typing import Tuple + +logger = get_logger("mute_action") + +@register_action +class MuteAction(PluginAction): + """测试动作处理类""" + + action_name = "mute_action" + action_description = "如果某人违反了公序良俗,或者别人戳你太多,,或者某人刷屏,一定要禁言某人,如果你很生气,可以禁言某人" + action_parameters = { + "target": "禁言对象,输入你要禁言的对象的名字,必填,", + "duration": "禁言时长,输入你要禁言的时长,单位为秒,必填", + } + action_require = [ + "当有人违反了公序良俗时使用", + "当有人刷屏时使用", + "当有人要求禁言自己时使用", + "当有人戳你两次以上时,防止刷屏,禁言他,必须牢记", + "当千石可乐或可乐酱要求你禁言时使用", + "当你想回避某个话题时使用", + ] + default = True # 不是默认动作,需要手动添加到使用集 + + async def process(self) -> Tuple[bool, str]: + """处理测试动作""" + logger.info(f"{self.log_prefix} 执行online动作: {self.reasoning}") + + # 发送测试消息 + target = self.action_data.get("target") + duration = self.action_data.get("duration") + reason = self.action_data.get("reason") + platform, user_id = await self.get_user_id_by_person_name(target) + + await self.send_message_by_expressor(f"我要禁言{target},{platform},时长{duration}秒,理由{reason},表达情绪") + + try: + await self.send_message(f"[command]mute,{user_id},{duration}") + + except Exception as e: + logger.error(f"{self.log_prefix} 执行mute动作时出错: {e}") + await self.send_message_by_expressor(f"执行mute动作时出错: {e}") + + return False, "执行mute动作时出错" + + return True, "测试动作执行成功" \ No newline at end of file diff --git a/src/plugins/test_plugin/actions/online_action.py b/src/plugins/test_plugin/actions/online_action.py new file mode 100644 index 000000000..67e2d2cc9 --- /dev/null +++ b/src/plugins/test_plugin/actions/online_action.py @@ -0,0 +1,44 @@ +from src.common.logger_manager import get_logger +from src.chat.focus_chat.planners.actions.plugin_action import PluginAction, register_action +from typing import Tuple + +logger = get_logger("check_online_action") + +@register_action +class CheckOnlineAction(PluginAction): + """测试动作处理类""" + + action_name = "check_online_action" + action_description = "这是一个检查在线状态的动作,当有人要求你检查Maibot(麦麦 机器人)在线状态时使用" + action_parameters = { + "mode": "查看模式" + } + action_require = [ + "当有人要求你检查Maibot(麦麦 机器人)在线状态时使用", + "mode参数为version时查看在线版本状态,默认用这种", + "mode参数为type时查看在线系统类型分布", + ] + default = True # 不是默认动作,需要手动添加到使用集 + + async def process(self) -> Tuple[bool, str]: + """处理测试动作""" + logger.info(f"{self.log_prefix} 执行online动作: {self.reasoning}") + + # 发送测试消息 + mode = self.action_data.get("mode", "type") + + await self.send_message_by_expressor("我看看") + + try: + if mode == "type": + await self.send_message(f"#online detail") + elif mode == "version": + await self.send_message(f"#online") + + except Exception as e: + logger.error(f"{self.log_prefix} 执行online动作时出错: {e}") + await self.send_message_by_expressor("执行online动作时出错: {e}") + + return False, "执行online动作时出错" + + return True, "测试动作执行成功" \ No newline at end of file diff --git a/src/plugins/test_plugin/actions/test_action.py b/src/plugins/test_plugin/actions/test_action.py new file mode 100644 index 000000000..3634dbe78 --- /dev/null +++ b/src/plugins/test_plugin/actions/test_action.py @@ -0,0 +1,38 @@ +from src.common.logger_manager import get_logger +from src.chat.focus_chat.planners.actions.plugin_action import PluginAction, register_action +from typing import Tuple + +logger = get_logger("test_action") + +@register_action +class TestAction(PluginAction): + """测试动作处理类""" + + action_name = "test_action" + action_description = "这是一个测试动作,当有人要求你测试插件系统时使用" + action_parameters = { + "test_param": "测试参数(可选)" + } + action_require = [ + "测试情况下使用", + "想测试插件动作加载时使用", + ] + default = False # 不是默认动作,需要手动添加到使用集 + + async def process(self) -> Tuple[bool, str]: + """处理测试动作""" + logger.info(f"{self.log_prefix} 执行测试动作: {self.reasoning}") + + # 获取聊天类型 + chat_type = self.get_chat_type() + logger.info(f"{self.log_prefix} 当前聊天类型: {chat_type}") + + # 获取最近消息 + recent_messages = self.get_recent_messages(3) + logger.info(f"{self.log_prefix} 最近3条消息: {recent_messages}") + + # 发送测试消息 + test_param = self.action_data.get("test_param", "默认参数") + await self.send_message_by_expressor(f"测试动作执行成功,参数: {test_param}") + + return True, "测试动作执行成功" \ No newline at end of file From a60bb38158c3f05601bb7d68c12e4bf9a784c6f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Fri, 16 May 2025 10:47:17 +0800 Subject: [PATCH 19/57] Add DeepWiki Badge to use auto refresh --- README.md | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index f07e7d57f..bc0140d62 100644 --- a/README.md +++ b/README.md @@ -2,14 +2,14 @@
- ![Python Version](https://img.shields.io/badge/Python-3.10+-blue) - ![License](https://img.shields.io/github/license/SengokuCola/MaiMBot?label=协议) - ![Status](https://img.shields.io/badge/状态-开发中-yellow) - ![Contributors](https://img.shields.io/github/contributors/MaiM-with-u/MaiBot.svg?style=flat&label=贡献者) - ![forks](https://img.shields.io/github/forks/MaiM-with-u/MaiBot.svg?style=flat&label=分支数) - ![stars](https://img.shields.io/github/stars/MaiM-with-u/MaiBot?style=flat&label=星标数) - ![issues](https://img.shields.io/github/issues/MaiM-with-u/MaiBot) - +![Python Version](https://img.shields.io/badge/Python-3.10+-blue) +![License](https://img.shields.io/github/license/SengokuCola/MaiMBot?label=协议) +![Status](https://img.shields.io/badge/状态-开发中-yellow) +![Contributors](https://img.shields.io/github/contributors/MaiM-with-u/MaiBot.svg?style=flat&label=贡献者) +![forks](https://img.shields.io/github/forks/MaiM-with-u/MaiBot.svg?style=flat&label=分支数) +![stars](https://img.shields.io/github/stars/MaiM-with-u/MaiBot?style=flat&label=星标数) +![issues](https://img.shields.io/github/issues/MaiM-with-u/MaiBot) +[![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/DrSmoothl/MaiBot)

From 5d5033452dc233ca412a70ab21cafce346271164 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Fri, 16 May 2025 10:48:39 +0800 Subject: [PATCH 20/57] Update README.md --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index bc0140d62..17a8da37b 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # 麦麦!MaiCore-MaiMBot (编辑中)
-

+
![Python Version](https://img.shields.io/badge/Python-3.10+-blue) ![License](https://img.shields.io/github/license/SengokuCola/MaiMBot?label=协议) @@ -12,7 +12,7 @@ [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/DrSmoothl/MaiBot)
-

+

Logo @@ -21,8 +21,8 @@ 画师:略nd -

MaiBot(麦麦)

-

+

MaiBot(麦麦)

+

一款专注于 群组聊天 的赛博网友
探索本项目的文档 » From 7f3178c96cc51c44ebce5d4edffd6905b6787085 Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Fri, 16 May 2025 00:43:46 +0800 Subject: [PATCH 21/57] =?UTF-8?q?Feat=EF=BC=9A=E6=B7=BB=E5=8A=A0=E5=AF=B9A?= =?UTF-8?q?ction=E6=8F=92=E4=BB=B6=E7=9A=84=E6=94=AF=E6=8C=81=EF=BC=8C?= =?UTF-8?q?=E7=8E=B0=E5=9C=A8=E5=8F=AF=E4=BB=A5=E7=BC=96=E5=86=99=E6=8F=92?= =?UTF-8?q?=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit From 456def4f9c37e9644b2c9f0742627bf4aad27592 Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Fri, 16 May 2025 16:13:12 +0800 Subject: [PATCH 22/57] =?UTF-8?q?feat=EF=BC=9A=E5=A2=9E=E5=8A=A0=E4=BA=86?= =?UTF-8?q?=E5=B7=A5=E4=BD=9C=E8=AE=B0=E5=BF=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../expressors/default_expressor.py | 4 +- src/chat/focus_chat/heartFC_chat.py | 3 + src/chat/focus_chat/info/self_info.py | 41 + .../focus_chat/info/workingmemory_info.py | 90 ++ .../info_processors/self_processor.py | 172 ++++ .../info_processors/tool_processor.py | 18 +- .../working_memory_processor.py | 247 ++++++ src/chat/focus_chat/memory_activator.py | 4 +- .../focus_chat/planners/action_manager.py | 8 +- .../planners/actions/reply_action.py | 4 +- src/chat/focus_chat/planners/planner.py | 12 +- .../focus_chat/working_memory/memory_item.py | 119 +++ .../working_memory/memory_manager.py | 798 ++++++++++++++++++ .../working_memory/test/memory_file_loader.py | 169 ++++ .../working_memory/test/run_memory_tests.py | 92 ++ .../test/simulate_real_usage.py | 197 +++++ .../working_memory/test/test_decay_removal.py | 323 +++++++ .../test/test_working_memory.py | 121 +++ .../working_memory/working_memory.py | 197 +++++ .../observation/chatting_observation.py | 4 +- .../observation/memory_observation.py | 55 -- .../observation/structure_observation.py | 32 + .../observation/working_observation.py | 34 +- src/common/logger.py | 8 +- 24 files changed, 2650 insertions(+), 102 deletions(-) create mode 100644 src/chat/focus_chat/info/self_info.py create mode 100644 src/chat/focus_chat/info/workingmemory_info.py create mode 100644 src/chat/focus_chat/info_processors/self_processor.py create mode 100644 src/chat/focus_chat/info_processors/working_memory_processor.py create mode 100644 src/chat/focus_chat/working_memory/memory_item.py create mode 100644 src/chat/focus_chat/working_memory/memory_manager.py create mode 100644 src/chat/focus_chat/working_memory/test/memory_file_loader.py create mode 100644 src/chat/focus_chat/working_memory/test/run_memory_tests.py create mode 100644 src/chat/focus_chat/working_memory/test/simulate_real_usage.py create mode 100644 src/chat/focus_chat/working_memory/test/test_decay_removal.py create mode 100644 src/chat/focus_chat/working_memory/test/test_working_memory.py create mode 100644 src/chat/focus_chat/working_memory/working_memory.py delete mode 100644 src/chat/heart_flow/observation/memory_observation.py create mode 100644 src/chat/heart_flow/observation/structure_observation.py diff --git a/src/chat/focus_chat/expressors/default_expressor.py b/src/chat/focus_chat/expressors/default_expressor.py index 411b08a05..6da4f52b8 100644 --- a/src/chat/focus_chat/expressors/default_expressor.py +++ b/src/chat/focus_chat/expressors/default_expressor.py @@ -127,7 +127,7 @@ class DefaultExpressor: reply=anchor_message, # 回复的是锚点消息 thinking_start_time=thinking_time_point, ) - logger.debug(f"创建思考消息thinking_message:{thinking_message}") + # logger.debug(f"创建思考消息thinking_message:{thinking_message}") await self.heart_fc_sender.register_thinking(thinking_message) @@ -244,7 +244,7 @@ class DefaultExpressor: # logger.info(f"{self.log_prefix}[Replier-{thinking_id}]\nPrompt:\n{prompt}\n") content, reasoning_content, model_name = await self.express_model.generate_response(prompt) - logger.info(f"{self.log_prefix}\nPrompt:\n{prompt}\n---------------------------\n") + # logger.info(f"{self.log_prefix}\nPrompt:\n{prompt}\n---------------------------\n") logger.info(f"想要表达:{in_mind_reply}") logger.info(f"理由:{reason}") diff --git a/src/chat/focus_chat/heartFC_chat.py b/src/chat/focus_chat/heartFC_chat.py index 7a1671897..9fab88410 100644 --- a/src/chat/focus_chat/heartFC_chat.py +++ b/src/chat/focus_chat/heartFC_chat.py @@ -22,6 +22,7 @@ from src.chat.focus_chat.info_processors.tool_processor import ToolProcessor from src.chat.focus_chat.expressors.default_expressor import DefaultExpressor from src.chat.focus_chat.memory_activator import MemoryActivator from src.chat.focus_chat.info_processors.base_processor import BaseProcessor +from src.chat.focus_chat.info_processors.self_processor import SelfProcessor from src.chat.focus_chat.planners.planner import ActionPlanner from src.chat.focus_chat.planners.action_manager import ActionManager from src.chat.focus_chat.working_memory.working_memory import WorkingMemory @@ -154,6 +155,7 @@ class HeartFChatting: self.processors.append(MindProcessor(subheartflow_id=self.stream_id)) self.processors.append(ToolProcessor(subheartflow_id=self.stream_id)) self.processors.append(WorkingMemoryProcessor(subheartflow_id=self.stream_id)) + self.processors.append(SelfProcessor(subheartflow_id=self.stream_id)) logger.info(f"{self.log_prefix} 已注册默认处理器: {[p.__class__.__name__ for p in self.processors]}") async def start(self): @@ -331,6 +333,7 @@ class HeartFChatting: f"{self.log_prefix} 处理器 {processor_name} 执行失败,耗时 (自并行开始): {duration_since_parallel_start:.2f}秒. 错误: {e}", exc_info=True, ) + traceback.print_exc() # 即使出错,也认为该任务结束了,已从 pending_tasks 中移除 if pending_tasks: diff --git a/src/chat/focus_chat/info/self_info.py b/src/chat/focus_chat/info/self_info.py new file mode 100644 index 000000000..82edd2655 --- /dev/null +++ b/src/chat/focus_chat/info/self_info.py @@ -0,0 +1,41 @@ +from typing import Dict, Any +from dataclasses import dataclass, field +from .info_base import InfoBase + + +@dataclass +class SelfInfo(InfoBase): + """思维信息类 + + 用于存储和管理当前思维状态的信息。 + + Attributes: + type (str): 信息类型标识符,默认为 "mind" + data (Dict[str, Any]): 包含 current_mind 的数据字典 + """ + + type: str = "self" + + def get_self_info(self) -> str: + """获取当前思维状态 + + Returns: + str: 当前思维状态 + """ + return self.get_info("self_info") or "" + + def set_self_info(self, self_info: str) -> None: + """设置当前思维状态 + + Args: + self_info: 要设置的思维状态 + """ + self.data["self_info"] = self_info + + def get_processed_info(self) -> str: + """获取处理后的信息 + + Returns: + str: 处理后的信息 + """ + return self.get_self_info() diff --git a/src/chat/focus_chat/info/workingmemory_info.py b/src/chat/focus_chat/info/workingmemory_info.py new file mode 100644 index 000000000..8c94f6fbc --- /dev/null +++ b/src/chat/focus_chat/info/workingmemory_info.py @@ -0,0 +1,90 @@ +from typing import Dict, Optional, List +from dataclasses import dataclass +from .info_base import InfoBase + + +@dataclass +class WorkingMemoryInfo(InfoBase): + + type: str = "workingmemory" + + processed_info:str = "" + + def set_talking_message(self, message: str) -> None: + """设置说话消息 + + Args: + message (str): 说话消息内容 + """ + self.data["talking_message"] = message + + def set_working_memory(self, working_memory: List[str]) -> None: + """设置工作记忆 + + Args: + working_memory (str): 工作记忆内容 + """ + self.data["working_memory"] = working_memory + + def add_working_memory(self, working_memory: str) -> None: + """添加工作记忆 + + Args: + working_memory (str): 工作记忆内容 + """ + working_memory_list = self.data.get("working_memory", []) + # print(f"working_memory_list: {working_memory_list}") + working_memory_list.append(working_memory) + # print(f"working_memory_list: {working_memory_list}") + self.data["working_memory"] = working_memory_list + + def get_working_memory(self) -> List[str]: + """获取工作记忆 + + Returns: + List[str]: 工作记忆内容 + """ + return self.data.get("working_memory", []) + + def get_type(self) -> str: + """获取信息类型 + + Returns: + str: 当前信息对象的类型标识符 + """ + return self.type + + def get_data(self) -> Dict[str, str]: + """获取所有信息数据 + + Returns: + Dict[str, str]: 包含所有信息数据的字典 + """ + return self.data + + def get_info(self, key: str) -> Optional[str]: + """获取特定属性的信息 + + Args: + key: 要获取的属性键名 + + Returns: + Optional[str]: 属性值,如果键不存在则返回 None + """ + return self.data.get(key) + + def get_processed_info(self) -> Dict[str, str]: + """获取处理后的信息 + + Returns: + Dict[str, str]: 处理后的信息数据 + """ + all_memory = self.get_working_memory() + # print(f"all_memory: {all_memory}") + memory_str = "" + for memory in all_memory: + memory_str += f"{memory}\n" + + self.processed_info = memory_str + + return self.processed_info diff --git a/src/chat/focus_chat/info_processors/self_processor.py b/src/chat/focus_chat/info_processors/self_processor.py new file mode 100644 index 000000000..923c38c35 --- /dev/null +++ b/src/chat/focus_chat/info_processors/self_processor.py @@ -0,0 +1,172 @@ +from src.chat.heart_flow.observation.chatting_observation import ChattingObservation +from src.chat.heart_flow.observation.observation import Observation +from src.chat.models.utils_model import LLMRequest +from src.config.config import global_config +import time +import traceback +from src.common.logger_manager import get_logger +from src.individuality.individuality import Individuality +import random +from src.chat.utils.prompt_builder import Prompt, global_prompt_manager +from src.chat.utils.json_utils import safe_json_dumps +from src.chat.message_receive.chat_stream import chat_manager +import difflib +from src.chat.person_info.relationship_manager import relationship_manager +from .base_processor import BaseProcessor +from src.chat.focus_chat.info.mind_info import MindInfo +from typing import List, Optional +from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation +from typing import Dict +from src.chat.focus_chat.info.info_base import InfoBase +from src.chat.focus_chat.info.self_info import SelfInfo + +logger = get_logger("processor") + + +def init_prompt(): + indentify_prompt = """ +你的名字是{bot_name},你是一个AI,但是你伪装成了一个人类,你的人格是,{prompt_personality}。 +你对外的形象是一只橙色的鱼,头上有绿色的树叶,你用的头像也是这个。 + +{relation_prompt} +{memory_str} + +现在是{time_now},你正在上网,和qq群里的网友们聊天,以下是正在进行的聊天内容: +{chat_observe_info} + +现在请你根据现有的信息,思考自我认同 +1. 你是一个什么样的人,你和群里的人关系如何 +2. 思考有没有人提到你,或者图片与你有关 +3. 你的自我认同是否有助于你的回答,如果你需要自我相关的信息来帮你参与聊天,请输出,否则请输出十个字以内的简短自我认同 +4. 一般情况下不用输出自我认同,只需要输出十几个字的简短自我认同就好,除非有明显需要自我认同的场景 + +""" + Prompt(indentify_prompt, "indentify_prompt") + + + +class SelfProcessor(BaseProcessor): + log_prefix = "自我认同" + + def __init__(self, subheartflow_id: str): + super().__init__() + + self.subheartflow_id = subheartflow_id + + self.llm_model = LLMRequest( + model=global_config.llm_sub_heartflow, + temperature=global_config.llm_sub_heartflow["temp"], + max_tokens=800, + request_type="self_identify", + ) + + name = chat_manager.get_stream_name(self.subheartflow_id) + self.log_prefix = f"[{name}] " + + + async def process_info( + self, observations: Optional[List[Observation]] = None, running_memorys: Optional[List[Dict]] = None, *infos + ) -> List[InfoBase]: + """处理信息对象 + + Args: + *infos: 可变数量的InfoBase类型的信息对象 + + Returns: + List[InfoBase]: 处理后的结构化信息列表 + """ + self_info_str = await self.self_indentify(observations, running_memorys) + + if self_info_str: + self_info = SelfInfo() + self_info.set_self_info(self_info_str) + else: + self_info = None + return None + + return [self_info] + + async def self_indentify( + self, observations: Optional[List[Observation]] = None, running_memorys: Optional[List[Dict]] = None + ): + """ + 在回复前进行思考,生成内心想法并收集工具调用结果 + + 参数: + observations: 观察信息 + + 返回: + 如果return_prompt为False: + tuple: (current_mind, past_mind) 当前想法和过去的想法列表 + 如果return_prompt为True: + tuple: (current_mind, past_mind, prompt) 当前想法、过去的想法列表和使用的prompt + """ + + + memory_str = "" + if running_memorys: + memory_str = "以下是当前在聊天中,你回忆起的记忆:\n" + for running_memory in running_memorys: + memory_str += f"{running_memory['topic']}: {running_memory['content']}\n" + + + if observations is None: + observations = [] + for observation in observations: + if isinstance(observation, ChattingObservation): + # 获取聊天元信息 + is_group_chat = observation.is_group_chat + chat_target_info = observation.chat_target_info + chat_target_name = "对方" # 私聊默认名称 + if not is_group_chat and chat_target_info: + # 优先使用person_name,其次user_nickname,最后回退到默认值 + chat_target_name = ( + chat_target_info.get("person_name") or chat_target_info.get("user_nickname") or chat_target_name + ) + # 获取聊天内容 + chat_observe_info = observation.get_observe_info() + person_list = observation.person_list + if isinstance(observation, HFCloopObservation): + hfcloop_observe_info = observation.get_observe_info() + + + individuality = Individuality.get_instance() + personality_block = individuality.get_prompt(x_person=2, level=2) + + relation_prompt = "" + for person in person_list: + relation_prompt += await relationship_manager.build_relationship_info(person, is_id=True) + + + prompt = (await global_prompt_manager.get_prompt_async("indentify_prompt")).format( + bot_name=individuality.name, + prompt_personality=personality_block, + memory_str=memory_str, + relation_prompt=relation_prompt, + time_now=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), + chat_observe_info=chat_observe_info, + ) + + + content = "" + try: + content, _ = await self.llm_model.generate_response_async(prompt=prompt) + if not content: + logger.warning(f"{self.log_prefix} LLM返回空结果,自我识别失败。") + except Exception as e: + # 处理总体异常 + logger.error(f"{self.log_prefix} 执行LLM请求或处理响应时出错: {e}") + logger.error(traceback.format_exc()) + content = "自我识别过程中出现错误" + + if content == 'None': + content = "" + # 记录初步思考结果 + logger.debug(f"{self.log_prefix} 自我识别prompt: \n{prompt}\n") + logger.info(f"{self.log_prefix} 自我识别结果: {content}") + + return content + + + +init_prompt() diff --git a/src/chat/focus_chat/info_processors/tool_processor.py b/src/chat/focus_chat/info_processors/tool_processor.py index 39e0c293c..563621e03 100644 --- a/src/chat/focus_chat/info_processors/tool_processor.py +++ b/src/chat/focus_chat/info_processors/tool_processor.py @@ -4,15 +4,15 @@ from src.config.config import global_config import time from src.common.logger_manager import get_logger from src.individuality.individuality import Individuality -from src.chat.utils.prompt_builder import Prompt, global_prompt_manager +from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.tools.tool_use import ToolUser from src.chat.utils.json_utils import process_llm_tool_calls from src.chat.person_info.relationship_manager import relationship_manager from .base_processor import BaseProcessor from typing import List, Optional, Dict from src.chat.heart_flow.observation.observation import Observation -from src.chat.heart_flow.observation.working_observation import WorkingObservation from src.chat.focus_chat.info.structured_info import StructuredInfo +from src.chat.heart_flow.observation.structure_observation import StructureObservation logger = get_logger("processor") @@ -24,9 +24,6 @@ def init_prompt(): tool_executor_prompt = """ 你是一个专门执行工具的助手。你的名字是{bot_name}。现在是{time_now}。 -你要在群聊中扮演以下角色: -{prompt_personality} - 你当前的额外信息: {memory_str} @@ -70,6 +67,8 @@ class ToolProcessor(BaseProcessor): list: 处理后的结构化信息列表 """ + working_infos = [] + if observations: for observation in observations: if isinstance(observation, ChattingObservation): @@ -77,7 +76,7 @@ class ToolProcessor(BaseProcessor): # 更新WorkingObservation中的结构化信息 for observation in observations: - if isinstance(observation, WorkingObservation): + if isinstance(observation, StructureObservation): for structured_info in result: logger.debug(f"{self.log_prefix} 更新WorkingObservation中的结构化信息: {structured_info}") observation.add_structured_info(structured_info) @@ -86,8 +85,9 @@ class ToolProcessor(BaseProcessor): logger.debug(f"{self.log_prefix} 获取更新后WorkingObservation中的结构化信息: {working_infos}") structured_info = StructuredInfo() - for working_info in working_infos: - structured_info.set_info(working_info.get("type"), working_info.get("content")) + if working_infos: + for working_info in working_infos: + structured_info.set_info(working_info.get("type"), working_info.get("content")) return [structured_info] @@ -148,7 +148,7 @@ class ToolProcessor(BaseProcessor): # chat_target_name=chat_target_name, is_group_chat=is_group_chat, # relation_prompt=relation_prompt, - prompt_personality=prompt_personality, + # prompt_personality=prompt_personality, # mood_info=mood_info, bot_name=individuality.name, time_now=time_now, diff --git a/src/chat/focus_chat/info_processors/working_memory_processor.py b/src/chat/focus_chat/info_processors/working_memory_processor.py new file mode 100644 index 000000000..b3feedcf6 --- /dev/null +++ b/src/chat/focus_chat/info_processors/working_memory_processor.py @@ -0,0 +1,247 @@ +from src.chat.heart_flow.observation.chatting_observation import ChattingObservation +from src.chat.heart_flow.observation.observation import Observation +from src.chat.models.utils_model import LLMRequest +from src.config.config import global_config +import time +import traceback +from src.common.logger_manager import get_logger +from src.individuality.individuality import Individuality +import random +from src.chat.utils.prompt_builder import Prompt, global_prompt_manager +from src.chat.utils.json_utils import safe_json_dumps +from src.chat.message_receive.chat_stream import chat_manager +import difflib +from src.chat.person_info.relationship_manager import relationship_manager +from .base_processor import BaseProcessor +from src.chat.focus_chat.info.mind_info import MindInfo +from typing import List, Optional +from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation +from src.chat.heart_flow.observation.working_observation import WorkingMemoryObservation +from src.chat.focus_chat.working_memory.working_memory import WorkingMemory +from typing import Dict +from src.chat.focus_chat.info.info_base import InfoBase +from json_repair import repair_json +from src.chat.focus_chat.info.workingmemory_info import WorkingMemoryInfo +import asyncio +import json + +logger = get_logger("processor") + + +def init_prompt(): + memory_proces_prompt = """ +你的名字是{bot_name} + +现在是{time_now},你正在上网,和qq群里的网友们聊天,以下是正在进行的聊天内容: +{chat_observe_info} + +以下是你已经总结的记忆,你可以调取这些记忆来帮助你聊天,不要一次调取太多记忆,最多调取3个左右记忆: +{memory_str} + +观察聊天内容和已经总结的记忆,思考是否有新内容需要总结成记忆,如果有,就输出 true,否则输出 false +如果当前聊天记录的内容已经被总结,千万不要总结新记忆,输出false +如果已经总结的记忆包含了当前聊天记录的内容,千万不要总结新记忆,输出false +如果已经总结的记忆摘要,包含了当前聊天记录的内容,千万不要总结新记忆,输出false + +如果有相近的记忆,请合并记忆,输出merge_memory,格式为[["id1", "id2"], ["id3", "id4"],...],你可以进行多组合并,但是每组合并只能有两个记忆id,不要输出其他内容 + +请根据聊天内容选择你需要调取的记忆并考虑是否添加新记忆,以JSON格式输出,格式如下: +```json +{{ + "selected_memory_ids": ["id1", "id2", ...], + "new_memory": "true" or "false", + "merge_memory": [["id1", "id2"], ["id3", "id4"],...] + +}} +``` +""" + Prompt(memory_proces_prompt, "prompt_memory_proces") + + +class WorkingMemoryProcessor(BaseProcessor): + log_prefix = "工作记忆" + + def __init__(self, subheartflow_id: str): + super().__init__() + + self.subheartflow_id = subheartflow_id + + self.llm_model = LLMRequest( + model=global_config.llm_sub_heartflow, + temperature=global_config.llm_sub_heartflow["temp"], + max_tokens=800, + request_type="working_memory", + ) + + name = chat_manager.get_stream_name(self.subheartflow_id) + self.log_prefix = f"[{name}] " + + + + async def process_info( + self, observations: Optional[List[Observation]] = None, running_memorys: Optional[List[Dict]] = None, *infos + ) -> List[InfoBase]: + """处理信息对象 + + Args: + *infos: 可变数量的InfoBase类型的信息对象 + + Returns: + List[InfoBase]: 处理后的结构化信息列表 + """ + working_memory = None + chat_info = "" + try: + for observation in observations: + if isinstance(observation, WorkingMemoryObservation): + working_memory = observation.get_observe_info() + working_memory_obs = observation + if isinstance(observation, ChattingObservation): + chat_info = observation.get_observe_info() + # chat_info_truncate = observation.talking_message_str_truncate + + if not working_memory: + logger.warning(f"{self.log_prefix} 没有找到工作记忆对象") + mind_info = MindInfo() + return [mind_info] + except Exception as e: + logger.error(f"{self.log_prefix} 处理观察时出错: {e}") + logger.error(traceback.format_exc()) + return [] + + all_memory = working_memory.get_all_memories() + memory_prompts = [] + for memory in all_memory: + memory_content = memory.data + memory_summary = memory.summary + memory_id = memory.id + memory_brief = memory_summary.get("brief") + memory_detailed = memory_summary.get("detailed") + memory_keypoints = memory_summary.get("keypoints") + memory_events = memory_summary.get("events") + memory_single_prompt = f"记忆id:{memory_id},记忆摘要:{memory_brief}\n" + memory_prompts.append(memory_single_prompt) + + memory_choose_str = "".join(memory_prompts) + + # 使用提示模板进行处理 + prompt = (await global_prompt_manager.get_prompt_async("prompt_memory_proces")).format( + bot_name=global_config.BOT_NICKNAME, + time_now=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), + chat_observe_info=chat_info, + memory_str=memory_choose_str + ) + + # 调用LLM处理记忆 + content = "" + try: + + logger.debug(f"{self.log_prefix} 处理工作记忆的prompt: {prompt}") + + + content, _ = await self.llm_model.generate_response_async(prompt=prompt) + if not content: + logger.warning(f"{self.log_prefix} LLM返回空结果,处理工作记忆失败。") + except Exception as e: + logger.error(f"{self.log_prefix} 执行LLM请求或处理响应时出错: {e}") + logger.error(traceback.format_exc()) + + # 解析LLM返回的JSON + try: + result = repair_json(content) + if isinstance(result, str): + result = json.loads(result) + if not isinstance(result, dict): + logger.error(f"{self.log_prefix} 解析LLM返回的JSON失败,结果不是字典类型: {type(result)}") + return [] + + selected_memory_ids = result.get("selected_memory_ids", []) + new_memory = result.get("new_memory", "") + merge_memory = result.get("merge_memory", []) + except Exception as e: + logger.error(f"{self.log_prefix} 解析LLM返回的JSON失败: {e}") + logger.error(traceback.format_exc()) + return [] + + logger.debug(f"{self.log_prefix} 解析LLM返回的JSON成功: {result}") + + # 根据selected_memory_ids,调取记忆 + memory_str = "" + if selected_memory_ids: + for memory_id in selected_memory_ids: + memory = await working_memory.retrieve_memory(memory_id) + if memory: + memory_content = memory.data + memory_summary = memory.summary + memory_id = memory.id + memory_brief = memory_summary.get("brief") + memory_detailed = memory_summary.get("detailed") + memory_keypoints = memory_summary.get("keypoints") + memory_events = memory_summary.get("events") + for keypoint in memory_keypoints: + memory_str += f"记忆要点:{keypoint}\n" + for event in memory_events: + memory_str += f"记忆事件:{event}\n" + # memory_str += f"记忆摘要:{memory_detailed}\n" + # memory_str += f"记忆主题:{memory_brief}\n" + + + working_memory_info = WorkingMemoryInfo() + if memory_str: + working_memory_info.add_working_memory(memory_str) + logger.debug(f"{self.log_prefix} 取得工作记忆: {memory_str}") + else: + logger.warning(f"{self.log_prefix} 没有找到工作记忆") + + # 根据聊天内容添加新记忆 + if new_memory: + # 使用异步方式添加新记忆,不阻塞主流程 + logger.debug(f"{self.log_prefix} {new_memory}新记忆: ") + asyncio.create_task(self.add_memory_async(working_memory, chat_info)) + + if merge_memory: + for merge_pairs in merge_memory: + memory1 = await working_memory.retrieve_memory(merge_pairs[0]) + memory2 = await working_memory.retrieve_memory(merge_pairs[1]) + if memory1 and memory2: + memory_str = f"记忆id:{memory1.id},记忆摘要:{memory1.summary.get('brief')}\n" + memory_str += f"记忆id:{memory2.id},记忆摘要:{memory2.summary.get('brief')}\n" + asyncio.create_task(self.merge_memory_async(working_memory, merge_pairs[0], merge_pairs[1])) + + return [working_memory_info] + + async def add_memory_async(self, working_memory: WorkingMemory, content: str): + """异步添加记忆,不阻塞主流程 + + Args: + working_memory: 工作记忆对象 + content: 记忆内容 + """ + try: + await working_memory.add_memory(content=content, from_source="chat_text") + logger.debug(f"{self.log_prefix} 异步添加新记忆成功: {content[:30]}...") + except Exception as e: + logger.error(f"{self.log_prefix} 异步添加新记忆失败: {e}") + logger.error(traceback.format_exc()) + + async def merge_memory_async(self, working_memory: WorkingMemory, memory_id1: str, memory_id2: str): + """异步合并记忆,不阻塞主流程 + + Args: + working_memory: 工作记忆对象 + memory_str: 记忆内容 + """ + try: + merged_memory = await working_memory.merge_memory(memory_id1, memory_id2) + logger.debug(f"{self.log_prefix} 异步合并记忆成功: {memory_id1} 和 {memory_id2}...") + logger.debug(f"{self.log_prefix} 合并后的记忆梗概: {merged_memory.summary.get('brief')}") + logger.debug(f"{self.log_prefix} 合并后的记忆详情: {merged_memory.summary.get('detailed')}") + logger.debug(f"{self.log_prefix} 合并后的记忆要点: {merged_memory.summary.get('keypoints')}") + logger.debug(f"{self.log_prefix} 合并后的记忆事件: {merged_memory.summary.get('events')}") + + except Exception as e: + logger.error(f"{self.log_prefix} 异步合并记忆失败: {e}") + logger.error(traceback.format_exc()) + + +init_prompt() diff --git a/src/chat/focus_chat/memory_activator.py b/src/chat/focus_chat/memory_activator.py index 2d7fea034..dae310c06 100644 --- a/src/chat/focus_chat/memory_activator.py +++ b/src/chat/focus_chat/memory_activator.py @@ -1,5 +1,5 @@ from src.chat.heart_flow.observation.chatting_observation import ChattingObservation -from src.chat.heart_flow.observation.working_observation import WorkingObservation +from src.chat.heart_flow.observation.structure_observation import StructureObservation from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation from src.chat.models.utils_model import LLMRequest from src.config.config import global_config @@ -53,7 +53,7 @@ class MemoryActivator: for observation in observations: if isinstance(observation, ChattingObservation): obs_info_text += observation.get_observe_info() - elif isinstance(observation, WorkingObservation): + elif isinstance(observation, StructureObservation): working_info = observation.get_observe_info() for working_info_item in working_info: obs_info_text += f"{working_info_item['type']}: {working_info_item['content']}\n" diff --git a/src/chat/focus_chat/planners/action_manager.py b/src/chat/focus_chat/planners/action_manager.py index 72ff4a73e..02c77c2b6 100644 --- a/src/chat/focus_chat/planners/action_manager.py +++ b/src/chat/focus_chat/planners/action_manager.py @@ -77,10 +77,10 @@ class ActionManager: if is_default: self._default_actions[action_name] = action_info - logger.info(f"所有注册动作: {list(self._registered_actions.keys())}") - logger.info(f"默认动作: {list(self._default_actions.keys())}") - for action_name, action_info in self._default_actions.items(): - logger.info(f"动作名称: {action_name}, 动作信息: {action_info}") + # logger.info(f"所有注册动作: {list(self._registered_actions.keys())}") + # logger.info(f"默认动作: {list(self._default_actions.keys())}") + # for action_name, action_info in self._default_actions.items(): + # logger.info(f"动作名称: {action_name}, 动作信息: {action_info}") except Exception as e: logger.error(f"加载已注册动作失败: {e}") diff --git a/src/chat/focus_chat/planners/actions/reply_action.py b/src/chat/focus_chat/planners/actions/reply_action.py index 51e3b8eaa..6452ecb0f 100644 --- a/src/chat/focus_chat/planners/actions/reply_action.py +++ b/src/chat/focus_chat/planners/actions/reply_action.py @@ -24,13 +24,13 @@ class ReplyAction(BaseAction): action_description: str = "表达想法,可以只包含文本、表情或两者都有" action_parameters: dict[str:str] = { "text": "你想要表达的内容(可选)", - "emojis": "描述当前使用表情包的场景(可选)", + "emojis": "描述当前使用表情包的场景,一段话描述(可选)", "target": "你想要回复的原始文本内容(非必须,仅文本,不包含发送者)(可选)", } action_require: list[str] = [ "有实质性内容需要表达", "有人提到你,但你还没有回应他", - "在合适的时候添加表情(不要总是添加)", + "在合适的时候添加表情(不要总是添加),表情描述要详细,描述当前场景,一段话描述", "如果你有明确的,要回复特定某人的某句话,或者你想回复较早的消息,请在target中指定那句话的原始文本", "一次只回复一个人,一次只回复一个话题,突出重点", "如果是自己发的消息想继续,需自然衔接", diff --git a/src/chat/focus_chat/planners/planner.py b/src/chat/focus_chat/planners/planner.py index 79044a5a6..dba9d4b1a 100644 --- a/src/chat/focus_chat/planners/planner.py +++ b/src/chat/focus_chat/planners/planner.py @@ -24,7 +24,8 @@ def init_prompt(): Prompt( """{extra_info_block} -你的名字是{bot_name},{prompt_personality},{chat_context_description}。需要基于以下信息决定如何参与对话: +你需要基于以下信息决定如何参与对话 +这些信息可能会有冲突,请你整合这些信息,并选择一个最合适的action: {chat_content_block} {mind_info_block} @@ -92,7 +93,7 @@ class ActionPlanner: extra_info: list[str] = [] for info in all_plan_info: if isinstance(info, ObsInfo): - logger.debug(f"{self.log_prefix} 观察信息: {info}") + # logger.debug(f"{self.log_prefix} 观察信息: {info}") observed_messages = info.get_talking_message() observed_messages_str = info.get_talking_message_str_truncate() chat_type = info.get_chat_type() @@ -101,15 +102,16 @@ class ActionPlanner: else: is_group_chat = False elif isinstance(info, MindInfo): - logger.debug(f"{self.log_prefix} 思维信息: {info}") + # logger.debug(f"{self.log_prefix} 思维信息: {info}") current_mind = info.get_current_mind() elif isinstance(info, CycleInfo): - logger.debug(f"{self.log_prefix} 循环信息: {info}") + # logger.debug(f"{self.log_prefix} 循环信息: {info}") cycle_info = info.get_observe_info() elif isinstance(info, StructuredInfo): - logger.debug(f"{self.log_prefix} 结构化信息: {info}") + # logger.debug(f"{self.log_prefix} 结构化信息: {info}") structured_info = info.get_data() else: + logger.debug(f"{self.log_prefix} 其他信息: {info}") extra_info.append(info.get_processed_info()) current_available_actions = self.action_manager.get_using_actions() diff --git a/src/chat/focus_chat/working_memory/memory_item.py b/src/chat/focus_chat/working_memory/memory_item.py new file mode 100644 index 000000000..f922eff8f --- /dev/null +++ b/src/chat/focus_chat/working_memory/memory_item.py @@ -0,0 +1,119 @@ +from typing import Dict, Any, Type, TypeVar, Generic, List, Optional, Callable, Set, Tuple +import time +import uuid +import traceback +import random +import string +from json_repair import repair_json +from rich.traceback import install +from src.common.logger_manager import get_logger +from src.chat.models.utils_model import LLMRequest +from src.config.config import global_config + + +class MemoryItem: + """记忆项类,用于存储单个记忆的所有相关信息""" + + def __init__(self, data: Any, from_source: str = "", tags: Optional[List[str]] = None): + """ + 初始化记忆项 + + Args: + data: 记忆数据 + from_source: 数据来源 + tags: 数据标签列表 + """ + # 生成可读ID:时间戳_随机字符串 + timestamp = int(time.time()) + random_str = ''.join(random.choices(string.ascii_lowercase + string.digits, k=2)) + self.id = f"{timestamp}_{random_str}" + self.data = data + self.data_type = type(data) + self.from_source = from_source + self.tags = set(tags) if tags else set() + self.timestamp = time.time() + # 修改summary的结构说明,用于存储可能的总结信息 + # summary结构:{ + # "brief": "记忆内容主题", + # "detailed": "记忆内容概括", + # "keypoints": ["关键概念1", "关键概念2"], + # "events": ["事件1", "事件2"] + # } + self.summary = None + + # 记忆精简次数 + self.compress_count = 0 + + # 记忆提取次数 + self.retrieval_count = 0 + + # 记忆强度 (初始为10) + self.memory_strength = 10.0 + + # 记忆操作历史记录 + # 格式: [(操作类型, 时间戳, 当时精简次数, 当时强度), ...] + self.history = [("create", self.timestamp, self.compress_count, self.memory_strength)] + + def add_tag(self, tag: str) -> None: + """添加标签""" + self.tags.add(tag) + + def remove_tag(self, tag: str) -> None: + """移除标签""" + if tag in self.tags: + self.tags.remove(tag) + + def has_tag(self, tag: str) -> bool: + """检查是否有特定标签""" + return tag in self.tags + + def has_all_tags(self, tags: List[str]) -> bool: + """检查是否有所有指定的标签""" + return all(tag in self.tags for tag in tags) + + def matches_source(self, source: str) -> bool: + """检查来源是否匹配""" + return self.from_source == source + + def set_summary(self, summary: Dict[str, Any]) -> None: + """设置总结信息""" + self.summary = summary + + def increase_strength(self, amount: float) -> None: + """增加记忆强度""" + self.memory_strength = min(10.0, self.memory_strength + amount) + # 记录操作历史 + self.record_operation("strengthen") + + def decrease_strength(self, amount: float) -> None: + """减少记忆强度""" + self.memory_strength = max(0.1, self.memory_strength - amount) + # 记录操作历史 + self.record_operation("weaken") + + def increase_compress_count(self) -> None: + """增加精简次数并减弱记忆强度""" + self.compress_count += 1 + # 记录操作历史 + self.record_operation("compress") + + def record_retrieval(self) -> None: + """记录记忆被提取的情况""" + self.retrieval_count += 1 + # 提取后强度翻倍 + self.memory_strength = min(10.0, self.memory_strength * 2) + # 记录操作历史 + self.record_operation("retrieval") + + def record_operation(self, operation_type: str) -> None: + """记录操作历史""" + current_time = time.time() + self.history.append((operation_type, current_time, self.compress_count, self.memory_strength)) + + def to_tuple(self) -> Tuple[Any, str, Set[str], float, str]: + """转换为元组格式(为了兼容性)""" + return (self.data, self.from_source, self.tags, self.timestamp, self.id) + + def is_memory_valid(self) -> bool: + """检查记忆是否有效(强度是否大于等于1)""" + return self.memory_strength >= 1.0 \ No newline at end of file diff --git a/src/chat/focus_chat/working_memory/memory_manager.py b/src/chat/focus_chat/working_memory/memory_manager.py new file mode 100644 index 000000000..d99488378 --- /dev/null +++ b/src/chat/focus_chat/working_memory/memory_manager.py @@ -0,0 +1,798 @@ +from typing import Dict, Any, Type, TypeVar, Generic, List, Optional, Callable, Set, Tuple +import time +import uuid +import traceback +from json_repair import repair_json +from rich.traceback import install +from src.common.logger_manager import get_logger +from src.chat.models.utils_model import LLMRequest +from src.config.config import global_config +from src.chat.focus_chat.working_memory.memory_item import MemoryItem +import json # 添加json模块导入 + + +install(extra_lines=3) +logger = get_logger("working_memory") + +T = TypeVar('T') + + +class MemoryManager: + def __init__(self, chat_id: str): + """ + 初始化工作记忆 + + Args: + chat_id: 关联的聊天ID,用于标识该工作记忆属于哪个聊天 + """ + # 关联的聊天ID + self._chat_id = chat_id + + # 主存储: 数据类型 -> 记忆项列表 + self._memory: Dict[Type, List[MemoryItem]] = {} + + # ID到记忆项的映射 + self._id_map: Dict[str, MemoryItem] = {} + + self.llm_summarizer = LLMRequest( + model=global_config.llm_summary, + temperature=0.3, + max_tokens=512, + request_type="memory_summarization" + ) + + @property + def chat_id(self) -> str: + """获取关联的聊天ID""" + return self._chat_id + + @chat_id.setter + def chat_id(self, value: str): + """设置关联的聊天ID""" + self._chat_id = value + + def push_item(self, memory_item: MemoryItem) -> str: + """ + 推送一个已创建的记忆项到工作记忆中 + + Args: + memory_item: 要存储的记忆项 + + Returns: + 记忆项的ID + """ + data_type = memory_item.data_type + + # 确保存在该类型的存储列表 + if data_type not in self._memory: + self._memory[data_type] = [] + + # 添加到内存和ID映射 + self._memory[data_type].append(memory_item) + self._id_map[memory_item.id] = memory_item + + return memory_item.id + + async def push_with_summary(self, data: T, from_source: str = "", tags: Optional[List[str]] = None) -> MemoryItem: + """ + 推送一段有类型的信息到工作记忆中,并自动生成总结 + + Args: + data: 要存储的数据 + from_source: 数据来源 + tags: 数据标签列表 + + Returns: + 包含原始数据和总结信息的字典 + """ + # 如果数据是字符串类型,则先进行总结 + if isinstance(data, str): + # 先生成总结 + summary = await self.summarize_memory_item(data) + + # 准备标签 + memory_tags = list(tags) if tags else [] + + # 创建记忆项 + memory_item = MemoryItem(data, from_source, memory_tags) + + # 将总结信息保存到记忆项中 + memory_item.set_summary(summary) + + # 推送记忆项 + self.push_item(memory_item) + + return memory_item + else: + # 非字符串类型,直接创建并推送记忆项 + memory_item = MemoryItem(data, from_source, tags) + self.push_item(memory_item) + + return memory_item + + def get_by_id(self, memory_id: str) -> Optional[MemoryItem]: + """ + 通过ID获取记忆项 + + Args: + memory_id: 记忆项ID + + Returns: + 找到的记忆项,如果不存在则返回None + """ + memory_item = self._id_map.get(memory_id) + if memory_item: + + # 检查记忆强度,如果小于1则删除 + if not memory_item.is_memory_valid(): + print(f"记忆 {memory_id} 强度过低 ({memory_item.memory_strength}),已自动移除") + self.delete(memory_id) + return None + + return memory_item + + def get_all_items(self) -> List[MemoryItem]: + """获取所有记忆项""" + return list(self._id_map.values()) + + def find_items(self, + data_type: Optional[Type] = None, + source: Optional[str] = None, + tags: Optional[List[str]] = None, + start_time: Optional[float] = None, + end_time: Optional[float] = None, + memory_id: Optional[str] = None, + limit: Optional[int] = None, + newest_first: bool = False, + min_strength: float = 0.0) -> List[MemoryItem]: + """ + 按条件查找记忆项 + + Args: + data_type: 要查找的数据类型 + source: 数据来源 + tags: 必须包含的标签列表 + start_time: 开始时间戳 + end_time: 结束时间戳 + memory_id: 特定记忆项ID + limit: 返回结果的最大数量 + newest_first: 是否按最新优先排序 + min_strength: 最小记忆强度 + + Returns: + 符合条件的记忆项列表 + """ + # 如果提供了特定ID,直接查找 + if memory_id: + item = self.get_by_id(memory_id) + return [item] if item else [] + + results = [] + + # 确定要搜索的类型列表 + types_to_search = [data_type] if data_type else list(self._memory.keys()) + + # 对每个类型进行搜索 + for typ in types_to_search: + if typ not in self._memory: + continue + + # 获取该类型的所有项目 + items = self._memory[typ] + + # 如果需要最新优先,则反转遍历顺序 + if newest_first: + items_to_check = list(reversed(items)) + else: + items_to_check = items + + # 遍历项目 + for item in items_to_check: + # 检查来源是否匹配 + if source is not None and not item.matches_source(source): + continue + + # 检查标签是否匹配 + if tags is not None and not item.has_all_tags(tags): + continue + + # 检查时间范围 + if start_time is not None and item.timestamp < start_time: + continue + if end_time is not None and item.timestamp > end_time: + continue + + # 检查记忆强度 + if min_strength > 0 and item.memory_strength < min_strength: + continue + + # 所有条件都满足,添加到结果中 + results.append(item) + + # 如果达到限制数量,提前返回 + if limit is not None and len(results) >= limit: + return results + + return results + + async def summarize_memory_item(self, content: str) -> Dict[str, Any]: + """ + 使用LLM总结记忆项 + + Args: + content: 需要总结的内容 + + Returns: + 包含总结、概括、关键概念和事件的字典 + """ + prompt = f"""请对以下内容进行总结,总结成记忆,输出四部分: +1. 记忆内容主题(精简,20字以内):让用户可以一眼看出记忆内容是什么 +2. 记忆内容概括(200字以内):让用户可以了解记忆内容的大致内容 +3. 关键概念和知识(keypoints):多条,提取关键的概念、知识点和关键词,要包含对概念的解释 +4. 事件描述(events):多条,描述谁(人物)在什么时候(时间)做了什么(事件) + +内容: +{content} + +请按以下JSON格式输出: +```json +{{ + "brief": "记忆内容主题(20字以内)", + "detailed": "记忆内容概括(200字以内)", + "keypoints": [ + "概念1:解释", + "概念2:解释", + ... + ], + "events": [ + "事件1:谁在什么时候做了什么", + "事件2:谁在什么时候做了什么", + ... + ] +}} +``` +请确保输出是有效的JSON格式,不要添加任何额外的说明或解释。 +""" + default_summary = { + "brief": "主题未知的记忆", + "detailed": "大致内容未知的记忆", + "keypoints": ["未知的概念"], + "events": ["未知的事件"] + } + + try: + # 调用LLM生成总结 + response, _ = await self.llm_summarizer.generate_response_async(prompt) + + # 使用repair_json解析响应 + try: + # 使用repair_json修复JSON格式 + fixed_json_string = repair_json(response) + + # 如果repair_json返回的是字符串,需要解析为Python对象 + if isinstance(fixed_json_string, str): + try: + json_result = json.loads(fixed_json_string) + except json.JSONDecodeError as decode_error: + logger.error(f"JSON解析错误: {str(decode_error)}") + return default_summary + else: + # 如果repair_json直接返回了字典对象,直接使用 + json_result = fixed_json_string + + # 进行额外的类型检查 + if not isinstance(json_result, dict): + logger.error(f"修复后的JSON不是字典类型: {type(json_result)}") + return default_summary + + # 确保所有必要字段都存在且类型正确 + if "brief" not in json_result or not isinstance(json_result["brief"], str): + json_result["brief"] = "主题未知的记忆" + + if "detailed" not in json_result or not isinstance(json_result["detailed"], str): + json_result["detailed"] = "大致内容未知的记忆" + + # 处理关键概念 + if "keypoints" not in json_result or not isinstance(json_result["keypoints"], list): + json_result["keypoints"] = ["未知的概念"] + else: + # 确保keypoints中的每个项目都是字符串 + json_result["keypoints"] = [ + str(point) for point in json_result["keypoints"] + if point is not None + ] + if not json_result["keypoints"]: + json_result["keypoints"] = ["未知的概念"] + + # 处理事件 + if "events" not in json_result or not isinstance(json_result["events"], list): + json_result["events"] = ["未知的事件"] + else: + # 确保events中的每个项目都是字符串 + json_result["events"] = [ + str(event) for event in json_result["events"] + if event is not None + ] + if not json_result["events"]: + json_result["events"] = ["未知的事件"] + + # 兼容旧版,将keypoints和events合并到key_points中 + json_result["key_points"] = json_result["keypoints"] + json_result["events"] + + return json_result + + except Exception as json_error: + logger.error(f"JSON处理失败: {str(json_error)},将使用默认摘要") + # 返回默认结构 + return default_summary + + except Exception as e: + # 出错时返回简单的结构 + logger.error(f"生成总结时出错: {str(e)}") + return default_summary + + async def refine_memory(self, + memory_id: str, + requirements: str = "") -> Dict[str, Any]: + """ + 对记忆进行精简操作,根据要求修改要点、总结和概括 + + Args: + memory_id: 记忆ID + requirements: 精简要求,描述如何修改记忆,包括可能需要移除的要点 + + Returns: + 修改后的记忆总结字典 + """ + # 获取指定ID的记忆项 + logger.info(f"精简记忆: {memory_id}") + memory_item = self.get_by_id(memory_id) + if not memory_item: + raise ValueError(f"未找到ID为{memory_id}的记忆项") + + # 增加精简次数 + memory_item.increase_compress_count() + + summary = memory_item.summary + + # 使用LLM根据要求对总结、概括和要点进行精简修改 + prompt = f""" +请根据以下要求,对记忆内容的主题、概括、关键概念和事件进行精简,模拟记忆的遗忘过程: +要求:{requirements} +你可以随机对关键概念和事件进行压缩,模糊或者丢弃,修改后,同样修改主题和概括 + +目前主题:{summary["brief"]} + +目前概括:{summary["detailed"]} + +目前关键概念: +{chr(10).join([f"- {point}" for point in summary.get("keypoints", [])])} + +目前事件: +{chr(10).join([f"- {point}" for point in summary.get("events", [])])} + +请生成修改后的主题、概括、关键概念和事件,遵循以下格式: +```json +{{ + "brief": "修改后的主题(20字以内)", + "detailed": "修改后的概括(200字以内)", + "keypoints": [ + "修改后的概念1:解释", + "修改后的概念2:解释" + ], + "events": [ + "修改后的事件1:谁在什么时候做了什么", + "修改后的事件2:谁在什么时候做了什么" + ] +}} +``` +请确保输出是有效的JSON格式,不要添加任何额外的说明或解释。 +""" + # 检查summary中是否有旧版结构,转换为新版结构 + if "keypoints" not in summary and "events" not in summary and "key_points" in summary: + # 尝试区分key_points中的keypoints和events + # 简单地将前半部分视为keypoints,后半部分视为events + key_points = summary.get("key_points", []) + halfway = len(key_points) // 2 + summary["keypoints"] = key_points[:halfway] or ["未知的概念"] + summary["events"] = key_points[halfway:] or ["未知的事件"] + + # 定义默认的精简结果 + default_refined = { + "brief": summary["brief"], + "detailed": summary["detailed"], + "keypoints": summary.get("keypoints", ["未知的概念"])[:1], # 默认只保留第一个关键概念 + "events": summary.get("events", ["未知的事件"])[:1] # 默认只保留第一个事件 + } + + try: + # 调用LLM修改总结、概括和要点 + response, _ = await self.llm_summarizer.generate_response_async(prompt) + logger.info(f"精简记忆响应: {response}") + # 使用repair_json处理响应 + try: + # 修复JSON格式 + fixed_json_string = repair_json(response) + + # 将修复后的字符串解析为Python对象 + if isinstance(fixed_json_string, str): + try: + refined_data = json.loads(fixed_json_string) + except json.JSONDecodeError as decode_error: + logger.error(f"JSON解析错误: {str(decode_error)}") + refined_data = default_refined + else: + # 如果repair_json直接返回了字典对象,直接使用 + refined_data = fixed_json_string + + # 确保是字典类型 + if not isinstance(refined_data, dict): + logger.error(f"修复后的JSON不是字典类型: {type(refined_data)}") + refined_data = default_refined + + # 更新总结、概括 + summary["brief"] = refined_data.get("brief", "主题未知的记忆") + summary["detailed"] = refined_data.get("detailed", "大致内容未知的记忆") + + # 更新关键概念 + keypoints = refined_data.get("keypoints", []) + if isinstance(keypoints, list) and keypoints: + # 确保所有关键概念都是字符串 + summary["keypoints"] = [str(point) for point in keypoints if point is not None] + else: + # 如果keypoints不是列表或为空,使用默认值 + summary["keypoints"] = ["主要概念已遗忘"] + + # 更新事件 + events = refined_data.get("events", []) + if isinstance(events, list) and events: + # 确保所有事件都是字符串 + summary["events"] = [str(event) for event in events if event is not None] + else: + # 如果events不是列表或为空,使用默认值 + summary["events"] = ["事件细节已遗忘"] + + # 兼容旧版,维护key_points + summary["key_points"] = summary["keypoints"] + summary["events"] + + except Exception as e: + logger.error(f"精简记忆出错: {str(e)}") + traceback.print_exc() + + # 出错时使用简化的默认精简 + summary["brief"] = summary["brief"] + " (已简化)" + summary["keypoints"] = summary.get("keypoints", ["未知的概念"])[:1] + summary["events"] = summary.get("events", ["未知的事件"])[:1] + summary["key_points"] = summary["keypoints"] + summary["events"] + + except Exception as e: + logger.error(f"精简记忆调用LLM出错: {str(e)}") + traceback.print_exc() + + # 更新原记忆项的总结 + memory_item.set_summary(summary) + + return memory_item + + def decay_memory(self, memory_id: str, decay_factor: float = 0.8) -> bool: + """ + 使单个记忆衰减 + + Args: + memory_id: 记忆ID + decay_factor: 衰减因子(0-1之间) + + Returns: + 是否成功衰减 + """ + memory_item = self.get_by_id(memory_id) + if not memory_item: + return False + + # 计算衰减量(当前强度 * (1-衰减因子)) + old_strength = memory_item.memory_strength + decay_amount = old_strength * (1 - decay_factor) + + # 更新强度 + memory_item.memory_strength = decay_amount + + return True + + + def delete(self, memory_id: str) -> bool: + """ + 删除指定ID的记忆项 + + Args: + memory_id: 要删除的记忆项ID + + Returns: + 是否成功删除 + """ + if memory_id not in self._id_map: + return False + + # 获取要删除的项 + item = self._id_map[memory_id] + + # 从内存中删除 + data_type = item.data_type + if data_type in self._memory: + self._memory[data_type] = [i for i in self._memory[data_type] if i.id != memory_id] + + # 从ID映射中删除 + del self._id_map[memory_id] + + return True + + def clear(self, data_type: Optional[Type] = None) -> None: + """ + 清除记忆中的数据 + + Args: + data_type: 要清除的数据类型,如果为None则清除所有数据 + """ + if data_type is None: + # 清除所有数据 + self._memory.clear() + self._id_map.clear() + elif data_type in self._memory: + # 清除指定类型的数据 + for item in self._memory[data_type]: + if item.id in self._id_map: + del self._id_map[item.id] + del self._memory[data_type] + + async def merge_memories(self, memory_id1: str, memory_id2: str, reason: str, delete_originals: bool = True) -> MemoryItem: + """ + 合并两个记忆项 + + Args: + memory_id1: 第一个记忆项ID + memory_id2: 第二个记忆项ID + reason: 合并原因 + delete_originals: 是否删除原始记忆,默认为True + + Returns: + 包含合并后的记忆信息的字典 + """ + # 获取两个记忆项 + memory_item1 = self.get_by_id(memory_id1) + memory_item2 = self.get_by_id(memory_id2) + + if not memory_item1 or not memory_item2: + raise ValueError("无法找到指定的记忆项") + + content1 = memory_item1.data + content2 = memory_item2.data + + # 获取记忆的摘要信息(如果有) + summary1 = memory_item1.summary + summary2 = memory_item2.summary + + # 构建合并提示 + prompt = f""" +请根据以下原因,将两段记忆内容有机合并成一段新的记忆内容。 +合并时保留两段记忆的重要信息,避免重复,确保生成的内容连贯、自然。 + +合并原因:{reason} +""" + + # 如果有摘要信息,添加到提示中 + if summary1: + prompt += f"记忆1主题:{summary1['brief']}\n" + prompt += f"记忆1概括:{summary1['detailed']}\n" + + if "keypoints" in summary1: + prompt += f"记忆1关键概念:\n" + "\n".join([f"- {point}" for point in summary1['keypoints']]) + "\n\n" + + if "events" in summary1: + prompt += f"记忆1事件:\n" + "\n".join([f"- {point}" for point in summary1['events']]) + "\n\n" + elif "key_points" in summary1: + prompt += f"记忆1要点:\n" + "\n".join([f"- {point}" for point in summary1['key_points']]) + "\n\n" + + if summary2: + prompt += f"记忆2主题:{summary2['brief']}\n" + prompt += f"记忆2概括:{summary2['detailed']}\n" + + if "keypoints" in summary2: + prompt += f"记忆2关键概念:\n" + "\n".join([f"- {point}" for point in summary2['keypoints']]) + "\n\n" + + if "events" in summary2: + prompt += f"记忆2事件:\n" + "\n".join([f"- {point}" for point in summary2['events']]) + "\n\n" + elif "key_points" in summary2: + prompt += f"记忆2要点:\n" + "\n".join([f"- {point}" for point in summary2['key_points']]) + "\n\n" + + # 添加记忆原始内容 + prompt += f""" +记忆1原始内容: +{content1} + +记忆2原始内容: +{content2} + +请按以下JSON格式输出合并结果: +```json +{{ + "content": "合并后的记忆内容文本(尽可能保留原信息,但去除重复)", + "brief": "合并后的主题(20字以内)", + "detailed": "合并后的概括(200字以内)", + "keypoints": [ + "合并后的概念1:解释", + "合并后的概念2:解释", + "合并后的概念3:解释" + ], + "events": [ + "合并后的事件1:谁在什么时候做了什么", + "合并后的事件2:谁在什么时候做了什么" + ] +}} +``` +请确保输出是有效的JSON格式,不要添加任何额外的说明或解释。 +""" + + # 默认合并结果 + default_merged = { + "content": f"{content1}\n\n{content2}", + "brief": f"合并:{summary1['brief']} + {summary2['brief']}", + "detailed": f"合并了两个记忆:{summary1['detailed']} 以及 {summary2['detailed']}", + "keypoints": [], + "events": [] + } + + # 合并旧版key_points + if "key_points" in summary1: + default_merged["keypoints"].extend(summary1.get("keypoints", [])) + default_merged["events"].extend(summary1.get("events", [])) + # 如果没有新的结构,尝试从旧结构分离 + if not default_merged["keypoints"] and not default_merged["events"] and "key_points" in summary1: + key_points = summary1["key_points"] + halfway = len(key_points) // 2 + default_merged["keypoints"].extend(key_points[:halfway]) + default_merged["events"].extend(key_points[halfway:]) + + if "key_points" in summary2: + default_merged["keypoints"].extend(summary2.get("keypoints", [])) + default_merged["events"].extend(summary2.get("events", [])) + # 如果没有新的结构,尝试从旧结构分离 + if not default_merged["keypoints"] and not default_merged["events"] and "key_points" in summary2: + key_points = summary2["key_points"] + halfway = len(key_points) // 2 + default_merged["keypoints"].extend(key_points[:halfway]) + default_merged["events"].extend(key_points[halfway:]) + + # 确保列表不为空 + if not default_merged["keypoints"]: + default_merged["keypoints"] = ["合并的关键概念"] + if not default_merged["events"]: + default_merged["events"] = ["合并的事件"] + + # 添加key_points兼容 + default_merged["key_points"] = default_merged["keypoints"] + default_merged["events"] + + try: + # 调用LLM合并记忆 + response, _ = await self.llm_summarizer.generate_response_async(prompt) + + # 处理LLM返回的合并结果 + try: + # 修复JSON格式 + fixed_json_string = repair_json(response) + + # 将修复后的字符串解析为Python对象 + if isinstance(fixed_json_string, str): + try: + merged_data = json.loads(fixed_json_string) + except json.JSONDecodeError as decode_error: + logger.error(f"JSON解析错误: {str(decode_error)}") + merged_data = default_merged + else: + # 如果repair_json直接返回了字典对象,直接使用 + merged_data = fixed_json_string + + # 确保是字典类型 + if not isinstance(merged_data, dict): + logger.error(f"修复后的JSON不是字典类型: {type(merged_data)}") + merged_data = default_merged + + # 确保所有必要字段都存在且类型正确 + if "content" not in merged_data or not isinstance(merged_data["content"], str): + merged_data["content"] = default_merged["content"] + + if "brief" not in merged_data or not isinstance(merged_data["brief"], str): + merged_data["brief"] = default_merged["brief"] + + if "detailed" not in merged_data or not isinstance(merged_data["detailed"], str): + merged_data["detailed"] = default_merged["detailed"] + + # 处理关键概念 + if "keypoints" not in merged_data or not isinstance(merged_data["keypoints"], list): + merged_data["keypoints"] = default_merged["keypoints"] + else: + # 确保keypoints中的每个项目都是字符串 + merged_data["keypoints"] = [ + str(point) for point in merged_data["keypoints"] + if point is not None + ] + if not merged_data["keypoints"]: + merged_data["keypoints"] = ["合并的关键概念"] + + # 处理事件 + if "events" not in merged_data or not isinstance(merged_data["events"], list): + merged_data["events"] = default_merged["events"] + else: + # 确保events中的每个项目都是字符串 + merged_data["events"] = [ + str(event) for event in merged_data["events"] + if event is not None + ] + if not merged_data["events"]: + merged_data["events"] = ["合并的事件"] + + # 添加key_points兼容 + merged_data["key_points"] = merged_data["keypoints"] + merged_data["events"] + + except Exception as e: + logger.error(f"合并记忆时处理JSON出错: {str(e)}") + traceback.print_exc() + merged_data = default_merged + except Exception as e: + logger.error(f"合并记忆调用LLM出错: {str(e)}") + traceback.print_exc() + merged_data = default_merged + + # 创建新的记忆项 + # 合并记忆项的标签 + merged_tags = memory_item1.tags.union(memory_item2.tags) + + # 取两个记忆项中更强的来源 + merged_source = memory_item1.from_source if memory_item1.memory_strength >= memory_item2.memory_strength else memory_item2.from_source + + # 创建新的记忆项 + merged_memory = MemoryItem( + data=merged_data["content"], + from_source=merged_source, + tags=list(merged_tags) + ) + + # 设置合并后的摘要 + summary = { + "brief": merged_data["brief"], + "detailed": merged_data["detailed"], + "keypoints": merged_data["keypoints"], + "events": merged_data["events"], + "key_points": merged_data["key_points"] + } + merged_memory.set_summary(summary) + + # 记忆强度取两者最大值 + merged_memory.memory_strength = max(memory_item1.memory_strength, memory_item2.memory_strength) + + # 添加到存储中 + self.push_item(merged_memory) + + # 如果需要,删除原始记忆 + if delete_originals: + self.delete(memory_id1) + self.delete(memory_id2) + + return merged_memory + + def delete_earliest_memory(self) -> bool: + """ + 删除最早的记忆项 + + Returns: + 是否成功删除 + """ + # 获取所有记忆项 + all_memories = self.get_all_items() + + if not all_memories: + return False + + # 按时间戳排序,找到最早的记忆项 + earliest_memory = min(all_memories, key=lambda item: item.timestamp) + + # 删除最早的记忆项 + return self.delete(earliest_memory.id) \ No newline at end of file diff --git a/src/chat/focus_chat/working_memory/test/memory_file_loader.py b/src/chat/focus_chat/working_memory/test/memory_file_loader.py new file mode 100644 index 000000000..3aa997b82 --- /dev/null +++ b/src/chat/focus_chat/working_memory/test/memory_file_loader.py @@ -0,0 +1,169 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import os +import asyncio +from typing import List, Dict, Any, Optional +from pathlib import Path + +from src.chat.focus_chat.working_memory.working_memory import WorkingMemory +from src.chat.focus_chat.working_memory.memory_item import MemoryItem +from src.common.logger_manager import get_logger + +logger = get_logger("memory_loader") + +class MemoryFileLoader: + """从文件加载记忆内容的工具类""" + + def __init__(self, working_memory: WorkingMemory): + """ + 初始化记忆文件加载器 + + Args: + working_memory: 工作记忆实例 + """ + self.working_memory = working_memory + + async def load_from_directory(self, + directory_path: str, + file_pattern: str = "*.txt", + common_tags: List[str] = None, + source_prefix: str = "文件") -> List[MemoryItem]: + """ + 从指定目录加载符合模式的文件作为记忆 + + Args: + directory_path: 目录路径 + file_pattern: 文件模式(默认为*.txt) + common_tags: 所有记忆共有的标签 + source_prefix: 来源前缀 + + Returns: + 加载的记忆项列表 + """ + directory = Path(directory_path) + if not directory.exists() or not directory.is_dir(): + logger.error(f"目录不存在或不是有效目录: {directory_path}") + return [] + + # 获取文件列表 + files = list(directory.glob(file_pattern)) + if not files: + logger.warning(f"在目录 {directory_path} 中没有找到符合 {file_pattern} 的文件") + return [] + + logger.info(f"在目录 {directory_path} 中找到 {len(files)} 个符合条件的文件") + + # 加载文件内容为记忆 + loaded_memories = [] + for file_path in files: + try: + memory_item = await self._load_single_file( + file_path=str(file_path), + common_tags=common_tags, + source_prefix=source_prefix + ) + if memory_item: + loaded_memories.append(memory_item) + logger.info(f"成功加载记忆: {file_path.name}") + + except Exception as e: + logger.error(f"加载文件 {file_path} 失败: {str(e)}") + + logger.info(f"完成加载,共加载了 {len(loaded_memories)} 个记忆") + return loaded_memories + + async def _load_single_file(self, + file_path: str, + common_tags: Optional[List[str]] = None, + source_prefix: str = "文件") -> Optional[MemoryItem]: + """ + 加载单个文件作为记忆 + + Args: + file_path: 文件路径 + common_tags: 记忆共有的标签 + source_prefix: 来源前缀 + + Returns: + 记忆项,加载失败则返回None + """ + try: + # 读取文件内容 + with open(file_path, "r", encoding="utf-8") as f: + content = f.read() + + if not content.strip(): + logger.warning(f"文件 {file_path} 内容为空") + return None + + # 准备标签和来源 + file_name = os.path.basename(file_path) + tags = list(common_tags) if common_tags else [] + tags.append(file_name) # 添加文件名作为标签 + + source = f"{source_prefix}_{file_name}" + + # 添加到工作记忆 + memory = await self.working_memory.add_memory( + content=content, + from_source=source, + tags=tags + ) + + return memory + + except Exception as e: + logger.error(f"加载文件 {file_path} 失败: {str(e)}") + return None + + +async def main(): + """示例使用""" + # 初始化工作记忆 + chat_id = "demo_chat" + working_memory = WorkingMemory(chat_id=chat_id) + + try: + # 初始化加载器 + loader = MemoryFileLoader(working_memory) + + # 加载当前目录中的txt文件 + current_dir = Path(__file__).parent + memories = await loader.load_from_directory( + directory_path=str(current_dir), + file_pattern="*.txt", + common_tags=["测试数据", "自动加载"], + source_prefix="测试文件" + ) + + # 显示加载结果 + print(f"共加载了 {len(memories)} 个记忆") + + # 获取并显示所有记忆的概要 + all_memories = working_memory.memory_manager.get_all_items() + for memory in all_memories: + print("\n" + "=" * 40) + print(f"记忆ID: {memory.id}") + print(f"来源: {memory.from_source}") + print(f"标签: {', '.join(memory.tags)}") + + if memory.summary: + print(f"\n主题: {memory.summary.get('brief', '无主题')}") + print(f"概述: {memory.summary.get('detailed', '无概述')}") + print("\n要点:") + for point in memory.summary.get('key_points', []): + print(f"- {point}") + else: + print("\n无摘要信息") + + print("=" * 40) + + finally: + # 关闭工作记忆 + await working_memory.shutdown() + + +if __name__ == "__main__": + # 运行示例 + asyncio.run(main()) \ No newline at end of file diff --git a/src/chat/focus_chat/working_memory/test/run_memory_tests.py b/src/chat/focus_chat/working_memory/test/run_memory_tests.py new file mode 100644 index 000000000..d9299cf40 --- /dev/null +++ b/src/chat/focus_chat/working_memory/test/run_memory_tests.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import asyncio +import os +import sys +from pathlib import Path + +# 添加项目根目录到系统路径 +current_dir = Path(__file__).parent +project_root = current_dir.parent.parent.parent.parent.parent +sys.path.insert(0, str(project_root)) + +from src.chat.focus_chat.working_memory.working_memory import WorkingMemory + +async def test_load_memories_from_files(): + """测试从文件加载记忆的功能""" + print("开始测试从文件加载记忆...") + + # 初始化工作记忆 + chat_id = "test_memory_load" + working_memory = WorkingMemory(chat_id=chat_id, max_memories_per_chat=10, auto_decay_interval=60) + + try: + # 获取测试文件列表 + test_dir = Path(__file__).parent + test_files = [ + os.path.join(test_dir, f) + for f in os.listdir(test_dir) + if f.endswith(".txt") + ] + + print(f"找到 {len(test_files)} 个测试文件") + + # 从每个文件加载记忆 + for file_path in test_files: + file_name = os.path.basename(file_path) + print(f"从文件 {file_name} 加载记忆...") + + # 读取文件内容 + with open(file_path, "r", encoding="utf-8") as f: + content = f.read() + + # 添加记忆 + memory = await working_memory.add_memory( + content=content, + from_source=f"文件_{file_name}", + tags=["测试文件", file_name] + ) + + print(f"已添加记忆: ID={memory.id}") + if memory.summary: + print(f"记忆概要: {memory.summary.get('brief', '无概要')}") + print(f"记忆要点: {', '.join(memory.summary.get('key_points', ['无要点']))}") + print("-" * 50) + + # 获取所有记忆 + all_memories = working_memory.memory_manager.get_all_items() + print(f"\n成功加载 {len(all_memories)} 个记忆") + + # 测试检索记忆 + if all_memories: + print("\n测试检索第一个记忆...") + first_memory = all_memories[0] + retrieved = await working_memory.retrieve_memory(first_memory.id) + + if retrieved: + print(f"成功检索记忆: ID={retrieved.id}") + print(f"检索后强度: {retrieved.memory_strength} (初始为10.0)") + print(f"检索次数: {retrieved.retrieval_count}") + else: + print("检索失败") + + # 测试记忆衰减 + print("\n测试记忆衰减...") + for memory in all_memories: + print(f"记忆 {memory.id} 衰减前强度: {memory.memory_strength}") + + await working_memory.decay_all_memories(decay_factor=0.5) + + all_memories_after = working_memory.memory_manager.get_all_items() + for memory in all_memories_after: + print(f"记忆 {memory.id} 衰减后强度: {memory.memory_strength}") + + finally: + # 关闭工作记忆 + await working_memory.shutdown() + print("\n测试完成,已关闭工作记忆") + +if __name__ == "__main__": + # 运行测试 + asyncio.run(test_load_memories_from_files()) \ No newline at end of file diff --git a/src/chat/focus_chat/working_memory/test/simulate_real_usage.py b/src/chat/focus_chat/working_memory/test/simulate_real_usage.py new file mode 100644 index 000000000..24cf5c70a --- /dev/null +++ b/src/chat/focus_chat/working_memory/test/simulate_real_usage.py @@ -0,0 +1,197 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import asyncio +import os +import sys +import time +import random +from pathlib import Path +from datetime import datetime + +# 添加项目根目录到系统路径 +current_dir = Path(__file__).parent +project_root = current_dir.parent.parent.parent.parent.parent +sys.path.insert(0, str(project_root)) + +from src.chat.focus_chat.working_memory.working_memory import WorkingMemory +from src.chat.focus_chat.working_memory.memory_item import MemoryItem +from src.common.logger_manager import get_logger + +logger = get_logger("real_usage_simulation") + +class WorkingMemorySimulator: + """模拟工作记忆的真实使用场景""" + + def __init__(self, chat_id="real_usage_test", cycle_interval=20): + """ + 初始化模拟器 + + Args: + chat_id: 聊天ID + cycle_interval: 循环间隔时间(秒) + """ + self.chat_id = chat_id + self.cycle_interval = cycle_interval + self.working_memory = WorkingMemory(chat_id=chat_id, max_memories_per_chat=20, auto_decay_interval=60) + self.cycle_count = 0 + self.running = False + + # 获取测试文件路径 + self.test_files = self._get_test_files() + if not self.test_files: + raise FileNotFoundError("找不到测试文件,请确保test目录中有.txt文件") + + # 存储所有添加的记忆ID + self.memory_ids = [] + + async def start(self, total_cycles=5): + """ + 开始模拟循环 + + Args: + total_cycles: 总循环次数,设为None表示无限循环 + """ + self.running = True + logger.info(f"开始模拟真实使用场景,循环间隔: {self.cycle_interval}秒") + + try: + while self.running and (total_cycles is None or self.cycle_count < total_cycles): + self.cycle_count += 1 + logger.info(f"\n===== 开始第 {self.cycle_count} 次循环 =====") + + # 执行一次循环 + await self._run_one_cycle() + + # 如果还有更多循环,则等待 + if self.running and (total_cycles is None or self.cycle_count < total_cycles): + wait_time = self.cycle_interval + logger.info(f"等待 {wait_time} 秒后开始下一循环...") + await asyncio.sleep(wait_time) + + logger.info(f"模拟完成,共执行了 {self.cycle_count} 次循环") + + except KeyboardInterrupt: + logger.info("接收到中断信号,停止模拟") + except Exception as e: + logger.error(f"模拟过程中出错: {str(e)}", exc_info=True) + finally: + # 关闭工作记忆 + await self.working_memory.shutdown() + + def stop(self): + """停止模拟循环""" + self.running = False + logger.info("正在停止模拟...") + + async def _run_one_cycle(self): + """运行一次完整循环:先检索记忆,再添加新记忆""" + start_time = time.time() + + # 1. 先检索已有记忆(如果有) + await self._retrieve_memories() + + # 2. 添加新记忆 + await self._add_new_memory() + + # 3. 显示工作记忆状态 + await self._show_memory_status() + + # 计算循环耗时 + cycle_duration = time.time() - start_time + logger.info(f"第 {self.cycle_count} 次循环完成,耗时: {cycle_duration:.2f}秒") + + async def _retrieve_memories(self): + """检索现有记忆""" + # 如果有已保存的记忆ID,随机选择1-3个进行检索 + if self.memory_ids: + num_to_retrieve = min(len(self.memory_ids), random.randint(1, 3)) + retrieval_ids = random.sample(self.memory_ids, num_to_retrieve) + + logger.info(f"正在检索 {num_to_retrieve} 条记忆...") + + for memory_id in retrieval_ids: + memory = await self.working_memory.retrieve_memory(memory_id) + if memory: + logger.info(f"成功检索记忆 ID: {memory_id}") + logger.info(f" - 强度: {memory.memory_strength:.2f},检索次数: {memory.retrieval_count}") + if memory.summary: + logger.info(f" - 主题: {memory.summary.get('brief', '无主题')}") + else: + logger.warning(f"记忆 ID: {memory_id} 不存在或已被移除") + # 从ID列表中移除 + if memory_id in self.memory_ids: + self.memory_ids.remove(memory_id) + else: + logger.info("当前没有可检索的记忆") + + async def _add_new_memory(self): + """添加新记忆""" + # 随机选择一个测试文件作为记忆内容 + file_path = random.choice(self.test_files) + file_name = os.path.basename(file_path) + + try: + # 读取文件内容 + with open(file_path, "r", encoding="utf-8") as f: + content = f.read() + + # 添加时间戳,模拟不同内容 + timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + content_with_timestamp = f"[{timestamp}] {content}" + + # 添加记忆 + logger.info(f"正在添加新记忆,来源: {file_name}") + memory = await self.working_memory.add_memory( + content=content_with_timestamp, + from_source=f"模拟_{file_name}", + tags=["模拟测试", f"循环{self.cycle_count}", file_name] + ) + + # 保存记忆ID + self.memory_ids.append(memory.id) + + # 显示记忆信息 + logger.info(f"已添加新记忆 ID: {memory.id}") + if memory.summary: + logger.info(f"记忆主题: {memory.summary.get('brief', '无主题')}") + logger.info(f"记忆要点: {', '.join(memory.summary.get('key_points', ['无要点'])[:2])}...") + + except Exception as e: + logger.error(f"添加记忆失败: {str(e)}") + + async def _show_memory_status(self): + """显示当前工作记忆状态""" + all_memories = self.working_memory.memory_manager.get_all_items() + + logger.info(f"\n当前工作记忆状态:") + logger.info(f"记忆总数: {len(all_memories)}") + + # 按强度排序 + sorted_memories = sorted(all_memories, key=lambda x: x.memory_strength, reverse=True) + + logger.info("记忆强度排名 (前5项):") + for i, memory in enumerate(sorted_memories[:5], 1): + logger.info(f"{i}. ID: {memory.id}, 强度: {memory.memory_strength:.2f}, " + f"检索次数: {memory.retrieval_count}, " + f"主题: {memory.summary.get('brief', '无主题') if memory.summary else '无摘要'}") + + def _get_test_files(self): + """获取测试文件列表""" + test_dir = Path(__file__).parent + return [ + os.path.join(test_dir, f) + for f in os.listdir(test_dir) + if f.endswith(".txt") + ] + +async def main(): + """主函数""" + # 创建模拟器 + simulator = WorkingMemorySimulator(cycle_interval=20) # 设置20秒的循环间隔 + + # 设置运行5个循环 + await simulator.start(total_cycles=5) + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/src/chat/focus_chat/working_memory/test/test_decay_removal.py b/src/chat/focus_chat/working_memory/test/test_decay_removal.py new file mode 100644 index 000000000..c114bc495 --- /dev/null +++ b/src/chat/focus_chat/working_memory/test/test_decay_removal.py @@ -0,0 +1,323 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import asyncio +import os +import sys +import time +from pathlib import Path + +# 添加项目根目录到系统路径 +current_dir = Path(__file__).parent +project_root = current_dir.parent.parent.parent.parent.parent +sys.path.insert(0, str(project_root)) + +from src.chat.focus_chat.working_memory.working_memory import WorkingMemory +from src.chat.focus_chat.working_memory.test.memory_file_loader import MemoryFileLoader +from src.common.logger_manager import get_logger + +logger = get_logger("memory_decay_test") + +async def test_manual_decay_until_removal(): + """测试手动衰减直到记忆被自动移除""" + print("\n===== 测试手动衰减直到记忆被自动移除 =====") + + # 初始化工作记忆,设置较大的衰减间隔,避免自动衰减影响测试 + chat_id = "decay_test_manual" + working_memory = WorkingMemory(chat_id=chat_id, max_memories_per_chat=10, auto_decay_interval=3600) + + try: + # 创建加载器并加载测试文件 + loader = MemoryFileLoader(working_memory) + test_dir = current_dir + + # 加载第一个测试文件作为记忆 + memories = await loader.load_from_directory( + directory_path=str(test_dir), + file_pattern="test1.txt", # 只加载test1.txt + common_tags=["测试", "衰减", "自动移除"], + source_prefix="衰减测试" + ) + + if not memories: + print("未能加载记忆文件,测试结束") + return + + # 获取加载的记忆 + memory = memories[0] + memory_id = memory.id + print(f"已加载测试记忆,ID: {memory_id}") + print(f"初始强度: {memory.memory_strength}") + if memory.summary: + print(f"记忆主题: {memory.summary.get('brief', '无主题')}") + + # 执行多次衰减,直到记忆被移除 + decay_count = 0 + decay_factor = 0.5 # 每次衰减为原来的一半 + + while True: + # 获取当前记忆 + current_memory = working_memory.memory_manager.get_by_id(memory_id) + + # 如果记忆已被移除,退出循环 + if current_memory is None: + print(f"记忆已在第 {decay_count} 次衰减后被自动移除!") + break + + # 输出当前强度 + print(f"衰减 {decay_count} 次后强度: {current_memory.memory_strength}") + + # 执行衰减 + await working_memory.decay_all_memories(decay_factor=decay_factor) + decay_count += 1 + + # 输出衰减后的详细信息 + after_memory = working_memory.memory_manager.get_by_id(memory_id) + if after_memory: + print(f"第 {decay_count} 次衰减结果: 强度={after_memory.memory_strength},压缩次数={after_memory.compress_count}") + if after_memory.summary: + print(f"记忆概要: {after_memory.summary.get('brief', '无概要')}") + print(f"记忆要点数量: {len(after_memory.summary.get('key_points', []))}") + else: + print(f"第 {decay_count} 次衰减结果: 记忆已被移除") + + # 防止无限循环 + if decay_count > 20: + print("达到最大衰减次数(20),退出测试。") + break + + # 短暂等待 + await asyncio.sleep(0.5) + + # 验证记忆是否真的被移除 + all_memories = working_memory.memory_manager.get_all_items() + print(f"剩余记忆数量: {len(all_memories)}") + if len(all_memories) == 0: + print("测试通过: 记忆在强度低于阈值后被成功移除。") + else: + print("测试失败: 记忆应该被移除但仍然存在。") + + finally: + await working_memory.shutdown() + +async def test_auto_decay(): + """测试自动衰减功能""" + print("\n===== 测试自动衰减功能 =====") + + # 初始化工作记忆,设置短的衰减间隔,便于测试 + chat_id = "decay_test_auto" + decay_interval = 3 # 3秒 + working_memory = WorkingMemory(chat_id=chat_id, max_memories_per_chat=10, auto_decay_interval=decay_interval) + + try: + # 创建加载器并加载测试文件 + loader = MemoryFileLoader(working_memory) + test_dir = current_dir + + # 加载第二个测试文件作为记忆 + memories = await loader.load_from_directory( + directory_path=str(test_dir), + file_pattern="test1.txt", # 只加载test2.txt + common_tags=["测试", "自动衰减"], + source_prefix="自动衰减测试" + ) + + if not memories: + print("未能加载记忆文件,测试结束") + return + + # 获取加载的记忆 + memory = memories[0] + memory_id = memory.id + print(f"已加载测试记忆,ID: {memory_id}") + print(f"初始强度: {memory.memory_strength}") + if memory.summary: + print(f"记忆主题: {memory.summary.get('brief', '无主题')}") + print(f"记忆概要: {memory.summary.get('detailed', '无概要')}") + print(f"记忆要点: {memory.summary.get('keypoints', '无要点')}") + print(f"记忆事件: {memory.summary.get('events', '无事件')}") + # 观察自动衰减 + print(f"等待自动衰减任务执行 (间隔 {decay_interval} 秒)...") + + for i in range(3): # 观察3次自动衰减 + # 等待自动衰减发生 + await asyncio.sleep(decay_interval + 1) # 多等1秒确保任务执行 + + # 获取当前记忆 + current_memory = working_memory.memory_manager.get_by_id(memory_id) + + # 如果记忆已被移除,退出循环 + if current_memory is None: + print(f"记忆已在第 {i+1} 次自动衰减后被移除!") + break + + # 输出当前强度和详细信息 + print(f"第 {i+1} 次自动衰减后强度: {current_memory.memory_strength}") + print(f"自动衰减详细结果: 压缩次数={current_memory.compress_count}, 提取次数={current_memory.retrieval_count}") + if current_memory.summary: + print(f"记忆概要: {current_memory.summary.get('brief', '无概要')}") + + print(f"\n自动衰减测试结束。") + + # 验证自动衰减是否发生 + final_memory = working_memory.memory_manager.get_by_id(memory_id) + if final_memory is None: + print("记忆已被自动衰减移除。") + elif final_memory.memory_strength < memory.memory_strength: + print(f"自动衰减有效:初始强度 {memory.memory_strength} -> 最终强度 {final_memory.memory_strength}") + print(f"衰减历史记录: {final_memory.history}") + else: + print("测试失败:记忆强度未减少,自动衰减可能未生效。") + + finally: + await working_memory.shutdown() + +async def test_decay_and_retrieval_balance(): + """测试记忆衰减和检索的平衡""" + print("\n===== 测试记忆衰减和检索的平衡 =====") + + # 初始化工作记忆 + chat_id = "decay_retrieval_balance" + working_memory = WorkingMemory(chat_id=chat_id, max_memories_per_chat=10, auto_decay_interval=60) + + try: + # 创建加载器并加载测试文件 + loader = MemoryFileLoader(working_memory) + test_dir = current_dir + + # 加载第三个测试文件作为记忆 + memories = await loader.load_from_directory( + directory_path=str(test_dir), + file_pattern="test3.txt", # 只加载test3.txt + common_tags=["测试", "衰减", "检索"], + source_prefix="平衡测试" + ) + + if not memories: + print("未能加载记忆文件,测试结束") + return + + # 获取加载的记忆 + memory = memories[0] + memory_id = memory.id + print(f"已加载测试记忆,ID: {memory_id}") + print(f"初始强度: {memory.memory_strength}") + if memory.summary: + print(f"记忆主题: {memory.summary.get('brief', '无主题')}") + + # 先衰减几次 + print("\n开始衰减:") + for i in range(3): + await working_memory.decay_all_memories(decay_factor=0.5) + current = working_memory.memory_manager.get_by_id(memory_id) + if current: + print(f"衰减 {i+1} 次后强度: {current.memory_strength}") + print(f"衰减详细信息: 压缩次数={current.compress_count}, 历史操作数={len(current.history)}") + if current.summary: + print(f"记忆概要: {current.summary.get('brief', '无概要')}") + else: + print(f"记忆已在第 {i+1} 次衰减后被移除。") + break + + # 如果记忆还存在,则检索几次增强它 + current = working_memory.memory_manager.get_by_id(memory_id) + if current: + print("\n开始检索增强:") + for i in range(2): + retrieved = await working_memory.retrieve_memory(memory_id) + print(f"检索 {i+1} 次后强度: {retrieved.memory_strength}") + print(f"检索后详细信息: 提取次数={retrieved.retrieval_count}, 历史记录长度={len(retrieved.history)}") + + # 再次衰减几次,测试是否会被移除 + print("\n再次衰减:") + for i in range(5): + await working_memory.decay_all_memories(decay_factor=0.5) + current = working_memory.memory_manager.get_by_id(memory_id) + if current: + print(f"最终衰减 {i+1} 次后强度: {current.memory_strength}") + print(f"衰减详细结果: 压缩次数={current.compress_count}") + else: + print(f"记忆已在最终衰减第 {i+1} 次后被移除。") + break + + print("\n测试结束。") + + finally: + await working_memory.shutdown() + +async def test_multi_memories_decay(): + """测试多条记忆同时衰减""" + print("\n===== 测试多条记忆同时衰减 =====") + + # 初始化工作记忆 + chat_id = "multi_decay_test" + working_memory = WorkingMemory(chat_id=chat_id, max_memories_per_chat=10, auto_decay_interval=60) + + try: + # 创建加载器并加载所有测试文件 + loader = MemoryFileLoader(working_memory) + test_dir = current_dir + + # 加载所有测试文件作为记忆 + memories = await loader.load_from_directory( + directory_path=str(test_dir), + file_pattern="*.txt", + common_tags=["测试", "多记忆衰减"], + source_prefix="多记忆测试" + ) + + if not memories or len(memories) < 2: + print("未能加载足够的记忆文件,测试结束") + return + + # 显示已加载的记忆 + print(f"已加载 {len(memories)} 条记忆:") + for idx, mem in enumerate(memories): + print(f"{idx+1}. ID: {mem.id}, 强度: {mem.memory_strength}, 来源: {mem.from_source}") + if mem.summary: + print(f" 主题: {mem.summary.get('brief', '无主题')}") + + # 进行多次衰减测试 + print("\n开始多记忆衰减测试:") + for decay_round in range(5): + # 执行衰减 + await working_memory.decay_all_memories(decay_factor=0.5) + + # 获取并显示所有记忆 + all_memories = working_memory.memory_manager.get_all_items() + print(f"\n第 {decay_round+1} 次衰减后,剩余记忆数量: {len(all_memories)}") + + for idx, mem in enumerate(all_memories): + print(f"{idx+1}. ID: {mem.id}, 强度: {mem.memory_strength}, 压缩次数: {mem.compress_count}") + if mem.summary: + print(f" 概要: {mem.summary.get('brief', '无概要')[:30]}...") + + # 如果所有记忆都被移除,退出循环 + if not all_memories: + print("所有记忆已被移除,测试结束。") + break + + # 等待一下 + await asyncio.sleep(0.5) + + print("\n多记忆衰减测试结束。") + + finally: + await working_memory.shutdown() + +async def main(): + """运行所有测试""" + # 测试手动衰减直到移除 + await test_manual_decay_until_removal() + + # 测试自动衰减 + await test_auto_decay() + + # 测试衰减和检索的平衡 + await test_decay_and_retrieval_balance() + + # 测试多条记忆同时衰减 + await test_multi_memories_decay() + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/src/chat/focus_chat/working_memory/test/test_working_memory.py b/src/chat/focus_chat/working_memory/test/test_working_memory.py new file mode 100644 index 000000000..b9440db17 --- /dev/null +++ b/src/chat/focus_chat/working_memory/test/test_working_memory.py @@ -0,0 +1,121 @@ +import asyncio +import os +import unittest +from typing import List, Dict, Any +from pathlib import Path + +from src.chat.focus_chat.working_memory.working_memory import WorkingMemory +from src.chat.focus_chat.working_memory.memory_item import MemoryItem + +class TestWorkingMemory(unittest.TestCase): + """工作记忆测试类""" + + def setUp(self): + """测试前准备""" + self.chat_id = "test_chat_123" + self.working_memory = WorkingMemory(chat_id=self.chat_id, max_memories_per_chat=10, auto_decay_interval=60) + self.test_dir = Path(__file__).parent + + def tearDown(self): + """测试后清理""" + loop = asyncio.get_event_loop() + loop.run_until_complete(self.working_memory.shutdown()) + + def test_init(self): + """测试初始化""" + self.assertEqual(self.working_memory.max_memories_per_chat, 10) + self.assertEqual(self.working_memory.auto_decay_interval, 60) + + def test_add_memory_from_files(self): + """从文件添加记忆""" + loop = asyncio.get_event_loop() + test_files = self._get_test_files() + + # 添加记忆 + memories = [] + for file_path in test_files: + content = self._read_file_content(file_path) + file_name = os.path.basename(file_path) + source = f"test_file_{file_name}" + tags = ["测试", f"文件_{file_name}"] + + memory = loop.run_until_complete( + self.working_memory.add_memory( + content=content, + from_source=source, + tags=tags + ) + ) + memories.append(memory) + + # 验证记忆数量 + all_items = self.working_memory.memory_manager.get_all_items() + self.assertEqual(len(all_items), len(test_files)) + + # 验证每个记忆的内容和标签 + for i, memory in enumerate(memories): + file_name = os.path.basename(test_files[i]) + retrieved_memory = loop.run_until_complete( + self.working_memory.retrieve_memory(memory.id) + ) + + self.assertIsNotNone(retrieved_memory) + self.assertTrue(retrieved_memory.has_tag("测试")) + self.assertTrue(retrieved_memory.has_tag(f"文件_{file_name}")) + self.assertEqual(retrieved_memory.from_source, f"test_file_{file_name}") + + # 验证检索后强度增加 + self.assertGreater(retrieved_memory.memory_strength, 10.0) # 原始强度为10.0,检索后增加1.5倍 + self.assertEqual(retrieved_memory.retrieval_count, 1) + + def test_decay_memories(self): + """测试记忆衰减""" + loop = asyncio.get_event_loop() + test_files = self._get_test_files()[:1] # 只使用一个文件测试衰减 + + # 添加记忆 + for file_path in test_files: + content = self._read_file_content(file_path) + loop.run_until_complete( + self.working_memory.add_memory( + content=content, + from_source="decay_test", + tags=["衰减测试"] + ) + ) + + # 获取添加后的记忆项 + all_items_before = self.working_memory.memory_manager.get_all_items() + self.assertEqual(len(all_items_before), 1) + + # 记录原始强度 + original_strength = all_items_before[0].memory_strength + + # 执行衰减 + loop.run_until_complete( + self.working_memory.decay_all_memories(decay_factor=0.5) + ) + + # 获取衰减后的记忆项 + all_items_after = self.working_memory.memory_manager.get_all_items() + + # 验证强度衰减 + self.assertEqual(len(all_items_after), 1) + self.assertLess(all_items_after[0].memory_strength, original_strength) + + def _get_test_files(self) -> List[str]: + """获取测试文件列表""" + test_dir = self.test_dir + return [ + os.path.join(test_dir, f) + for f in os.listdir(test_dir) + if f.endswith(".txt") + ] + + def _read_file_content(self, file_path: str) -> str: + """读取文件内容""" + with open(file_path, "r", encoding="utf-8") as f: + return f.read() + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/src/chat/focus_chat/working_memory/working_memory.py b/src/chat/focus_chat/working_memory/working_memory.py new file mode 100644 index 000000000..9fd0e8586 --- /dev/null +++ b/src/chat/focus_chat/working_memory/working_memory.py @@ -0,0 +1,197 @@ +from typing import Dict, List, Any, Optional +import asyncio +import random +from datetime import datetime +from src.common.logger_manager import get_logger +from src.chat.focus_chat.working_memory.memory_manager import MemoryManager, MemoryItem + +logger = get_logger(__name__) + +# 问题是我不知道这个manager是不是需要和其他manager统一管理,因为这个manager是从属于每一个聊天流,都有自己的定时任务 + +class WorkingMemory: + """ + 工作记忆,负责协调和运作记忆 + 从属于特定的流,用chat_id来标识 + """ + + def __init__(self,chat_id:str , max_memories_per_chat: int = 10, auto_decay_interval: int = 60): + """ + 初始化工作记忆管理器 + + Args: + max_memories_per_chat: 每个聊天的最大记忆数量 + auto_decay_interval: 自动衰减记忆的时间间隔(秒) + """ + self.memory_manager = MemoryManager(chat_id) + + # 记忆容量上限 + self.max_memories_per_chat = max_memories_per_chat + + # 自动衰减间隔 + self.auto_decay_interval = auto_decay_interval + + # 衰减任务 + self.decay_task = None + + # 启动自动衰减任务 + self._start_auto_decay() + + def _start_auto_decay(self): + """启动自动衰减任务""" + if self.decay_task is None: + self.decay_task = asyncio.create_task(self._auto_decay_loop()) + + async def _auto_decay_loop(self): + """自动衰减循环""" + while True: + await asyncio.sleep(self.auto_decay_interval) + try: + await self.decay_all_memories() + except Exception as e: + print(f"自动衰减记忆时出错: {str(e)}") + + + async def add_memory(self, + content: Any, + from_source: str = "", + tags: Optional[List[str]] = None): + """ + 添加一段记忆到指定聊天 + + Args: + content: 记忆内容 + from_source: 数据来源 + tags: 数据标签列表 + + Returns: + 包含记忆信息的字典 + """ + memory = await self.memory_manager.push_with_summary(content, from_source, tags) + if len(self.memory_manager.get_all_items()) > self.max_memories_per_chat: + self.remove_earliest_memory() + + return memory + + def remove_earliest_memory(self): + """ + 删除最早的记忆 + """ + return self.memory_manager.delete_earliest_memory() + + async def retrieve_memory(self, memory_id: str) -> Optional[MemoryItem]: + """ + 检索记忆 + + Args: + chat_id: 聊天ID + memory_id: 记忆ID + + Returns: + 检索到的记忆项,如果不存在则返回None + """ + memory_item = self.memory_manager.get_by_id(memory_id) + if memory_item: + memory_item.retrieval_count += 1 + memory_item.increase_strength(5) + return memory_item + return None + + + async def decay_all_memories(self, decay_factor: float = 0.5): + """ + 对所有聊天的所有记忆进行衰减 + 衰减:对记忆进行refine压缩,强度会变为原先的0.5 + + Args: + decay_factor: 衰减因子(0-1之间) + """ + logger.debug(f"开始对所有记忆进行衰减,衰减因子: {decay_factor}") + + all_memories = self.memory_manager.get_all_items() + + for memory_item in all_memories: + # 如果压缩完小于1会被删除 + memory_id = memory_item.id + self.memory_manager.decay_memory(memory_id, decay_factor) + if memory_item.memory_strength < 1: + self.memory_manager.delete(memory_id) + continue + # 计算衰减量 + if memory_item.memory_strength < 5: + await self.memory_manager.refine_memory(memory_id, f"由于时间过去了{self.auto_decay_interval}秒,记忆变的模糊,所以需要压缩") + + async def merge_memory(self, memory_id1: str, memory_id2: str) -> MemoryItem: + """合并记忆 + + Args: + memory_str: 记忆内容 + """ + return await self.memory_manager.merge_memories(memory_id1 = memory_id1, memory_id2 = memory_id2,reason = "两端记忆有重复的内容") + + + + # 暂时没用,先留着 + async def simulate_memory_blur(self, chat_id: str, blur_rate: float = 0.2): + """ + 模拟记忆模糊过程,随机选择一部分记忆进行精简 + + Args: + chat_id: 聊天ID + blur_rate: 模糊比率(0-1之间),表示有多少比例的记忆会被精简 + """ + memory = self.get_memory(chat_id) + + # 获取所有字符串类型且有总结的记忆 + all_summarized_memories = [] + for type_items in memory._memory.values(): + for item in type_items: + if isinstance(item.data, str) and hasattr(item, 'summary') and item.summary: + all_summarized_memories.append(item) + + if not all_summarized_memories: + return + + # 计算要模糊的记忆数量 + blur_count = max(1, int(len(all_summarized_memories) * blur_rate)) + + # 随机选择要模糊的记忆 + memories_to_blur = random.sample(all_summarized_memories, min(blur_count, len(all_summarized_memories))) + + # 对选中的记忆进行精简 + for memory_item in memories_to_blur: + try: + # 根据记忆强度决定模糊程度 + if memory_item.memory_strength > 7: + requirement = "保留所有重要信息,仅略微精简" + elif memory_item.memory_strength > 4: + requirement = "保留核心要点,适度精简细节" + else: + requirement = "只保留最关键的1-2个要点,大幅精简内容" + + # 进行精简 + await memory.refine_memory(memory_item.id, requirement) + print(f"已模糊记忆 {memory_item.id},强度: {memory_item.memory_strength}, 要求: {requirement}") + + except Exception as e: + print(f"模糊记忆 {memory_item.id} 时出错: {str(e)}") + + + + async def shutdown(self) -> None: + """关闭管理器,停止所有任务""" + if self.decay_task and not self.decay_task.done(): + self.decay_task.cancel() + try: + await self.decay_task + except asyncio.CancelledError: + pass + + def get_all_memories(self) -> List[MemoryItem]: + """ + 获取所有记忆项目 + + Returns: + List[MemoryItem]: 当前工作记忆中的所有记忆项目列表 + """ + return self.memory_manager.get_all_items() \ No newline at end of file diff --git a/src/chat/heart_flow/observation/chatting_observation.py b/src/chat/heart_flow/observation/chatting_observation.py index 017f24da9..6bb72bca0 100644 --- a/src/chat/heart_flow/observation/chatting_observation.py +++ b/src/chat/heart_flow/observation/chatting_observation.py @@ -120,12 +120,12 @@ class ChattingObservation(Observation): for message in reverse_talking_message: if message["processed_plain_text"] == text: find_msg = message - logger.debug(f"找到的锚定消息:find_msg: {find_msg}") + # logger.debug(f"找到的锚定消息:find_msg: {find_msg}") break else: similarity = difflib.SequenceMatcher(None, text, message["processed_plain_text"]).ratio() msg_list.append({"message": message, "similarity": similarity}) - logger.debug(f"对锚定消息检查:message: {message['processed_plain_text']},similarity: {similarity}") + # logger.debug(f"对锚定消息检查:message: {message['processed_plain_text']},similarity: {similarity}") if not find_msg: if msg_list: msg_list.sort(key=lambda x: x["similarity"], reverse=True) diff --git a/src/chat/heart_flow/observation/memory_observation.py b/src/chat/heart_flow/observation/memory_observation.py deleted file mode 100644 index 1938a47d3..000000000 --- a/src/chat/heart_flow/observation/memory_observation.py +++ /dev/null @@ -1,55 +0,0 @@ -from src.chat.heart_flow.observation.observation import Observation -from datetime import datetime -from src.common.logger_manager import get_logger -import traceback - -# Import the new utility function -from src.chat.memory_system.Hippocampus import HippocampusManager -import jieba -from typing import List - -logger = get_logger("memory") - - -class MemoryObservation(Observation): - def __init__(self, observe_id): - super().__init__(observe_id) - self.observe_info: str = "" - self.context: str = "" - self.running_memory: List[dict] = [] - - def get_observe_info(self): - for memory in self.running_memory: - self.observe_info += f"{memory['topic']}:{memory['content']}\n" - return self.observe_info - - async def observe(self): - # ---------- 2. 获取记忆 ---------- - try: - # 从聊天内容中提取关键词 - chat_words = set(jieba.cut(self.context)) - # 过滤掉停用词和单字词 - keywords = [word for word in chat_words if len(word) > 1] - # 去重并限制数量 - keywords = list(set(keywords))[:5] - - logger.debug(f"取的关键词: {keywords}") - - # 调用记忆系统获取相关记忆 - related_memory = await HippocampusManager.get_instance().get_memory_from_topic( - valid_keywords=keywords, max_memory_num=3, max_memory_length=2, max_depth=3 - ) - - logger.debug(f"获取到的记忆: {related_memory}") - - if related_memory: - for topic, memory in related_memory: - # 将记忆添加到 running_memory - self.running_memory.append( - {"topic": topic, "content": memory, "timestamp": datetime.now().isoformat()} - ) - logger.debug(f"添加新记忆: {topic} - {memory}") - - except Exception as e: - logger.error(f"观察 记忆时出错: {e}") - logger.error(traceback.format_exc()) diff --git a/src/chat/heart_flow/observation/structure_observation.py b/src/chat/heart_flow/observation/structure_observation.py new file mode 100644 index 000000000..5c5c0a362 --- /dev/null +++ b/src/chat/heart_flow/observation/structure_observation.py @@ -0,0 +1,32 @@ +from datetime import datetime +from src.common.logger_manager import get_logger + +# Import the new utility function + +logger = get_logger("observation") + + +# 所有观察的基类 +class StructureObservation: + def __init__(self, observe_id): + self.observe_info = "" + self.observe_id = observe_id + self.last_observe_time = datetime.now().timestamp() # 初始化为当前时间 + self.history_loop = [] + self.structured_info = [] + + def get_observe_info(self): + return self.structured_info + + def add_structured_info(self, structured_info: dict): + self.structured_info.append(structured_info) + + async def observe(self): + observed_structured_infos = [] + for structured_info in self.structured_info: + if structured_info.get("ttl") > 0: + structured_info["ttl"] -= 1 + observed_structured_infos.append(structured_info) + logger.debug(f"观察到结构化信息仍旧在: {structured_info}") + + self.structured_info = observed_structured_infos \ No newline at end of file diff --git a/src/chat/heart_flow/observation/working_observation.py b/src/chat/heart_flow/observation/working_observation.py index 27b6ab92d..2e32f84d5 100644 --- a/src/chat/heart_flow/observation/working_observation.py +++ b/src/chat/heart_flow/observation/working_observation.py @@ -2,33 +2,33 @@ # 外部世界可以是某个聊天 不同平台的聊天 也可以是任意媒体 from datetime import datetime from src.common.logger_manager import get_logger - +from src.chat.focus_chat.working_memory.working_memory import WorkingMemory +from src.chat.focus_chat.working_memory.memory_item import MemoryItem +from typing import List # Import the new utility function logger = get_logger("observation") # 所有观察的基类 -class WorkingObservation: - def __init__(self, observe_id): +class WorkingMemoryObservation: + def __init__(self, observe_id, working_memory: WorkingMemory): self.observe_info = "" self.observe_id = observe_id - self.last_observe_time = datetime.now().timestamp() # 初始化为当前时间 - self.history_loop = [] - self.structured_info = [] + self.last_observe_time = datetime.now().timestamp() + + self.working_memory = working_memory + + self.retrieved_working_memory = [] def get_observe_info(self): - return self.structured_info + return self.working_memory - def add_structured_info(self, structured_info: dict): - self.structured_info.append(structured_info) + def add_retrieved_working_memory(self, retrieved_working_memory: List[MemoryItem]): + self.retrieved_working_memory.append(retrieved_working_memory) + + def get_retrieved_working_memory(self): + return self.retrieved_working_memory async def observe(self): - observed_structured_infos = [] - for structured_info in self.structured_info: - if structured_info.get("ttl") > 0: - structured_info["ttl"] -= 1 - observed_structured_infos.append(structured_info) - logger.debug(f"观察到结构化信息仍旧在: {structured_info}") - - self.structured_info = observed_structured_infos + pass diff --git a/src/common/logger.py b/src/common/logger.py index 9f2dee455..adc15fe71 100644 --- a/src/common/logger.py +++ b/src/common/logger.py @@ -629,22 +629,22 @@ PROCESSOR_STYLE_CONFIG = { PLANNER_STYLE_CONFIG = { "advanced": { - "console_format": "{time:HH:mm:ss} | 规划器 | {message}", + "console_format": "{time:HH:mm:ss} | 规划器 | {message}", "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 规划器 | {message}", }, "simple": { - "console_format": "{time:HH:mm:ss} | 规划器 | {message}", + "console_format": "{time:HH:mm:ss} | 规划器 | {message}", "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 规划器 | {message}", }, } ACTION_TAKEN_STYLE_CONFIG = { "advanced": { - "console_format": "{time:HH:mm:ss} | 动作 | {message}", + "console_format": "{time:HH:mm:ss} | 动作 | {message}", "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 动作 | {message}", }, "simple": { - "console_format": "{time:HH:mm:ss} | 动作 | {message}", + "console_format": "{time:HH:mm:ss} | 动作 | {message}", "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 动作 | {message}", }, } From f126b6b1577c2cd1781c9851cc1b21154d2731f8 Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Fri, 16 May 2025 16:13:25 +0800 Subject: [PATCH 23/57] Update .gitignore --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index df3ab670f..ac400b137 100644 --- a/.gitignore +++ b/.gitignore @@ -301,3 +301,5 @@ $RECYCLE.BIN/ # Windows shortcuts *.lnk +src/chat/focus_chat/working_memory/test/test1.txt +src/chat/focus_chat/working_memory/test/test4.txt From 021e7f1a971d677a5c2498efef64cb94350dccd8 Mon Sep 17 00:00:00 2001 From: Oct-autumn Date: Fri, 16 May 2025 16:50:53 +0800 Subject: [PATCH 24/57] =?UTF-8?q?refactor:=20=E9=87=8D=E6=9E=84=E9=85=8D?= =?UTF-8?q?=E7=BD=AE=E6=A8=A1=E5=9D=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/api/reload_config.py | 6 +- src/chat/emoji_system/emoji_manager.py | 23 +- .../expressors/default_expressor.py | 16 +- .../expressors/exprssion_learner.py | 5 +- src/chat/focus_chat/heartflow_processor.py | 4 +- .../focus_chat/heartflow_prompt_builder.py | 40 +- .../info_processors/chattinginfo_processor.py | 11 +- .../info_processors/mind_processor.py | 4 +- .../info_processors/tool_processor.py | 2 +- src/chat/focus_chat/memory_activator.py | 3 +- src/chat/heart_flow/heartflow.py | 3 +- src/chat/heart_flow/interest_chatting.py | 4 +- src/chat/heart_flow/mai_state_manager.py | 17 +- .../observation/chatting_observation.py | 19 +- src/chat/heart_flow/subheartflow_manager.py | 7 +- src/chat/memory_system/Hippocampus.py | 56 +- src/chat/memory_system/debug_memory.py | 3 +- src/chat/memory_system/memory_config.py | 48 -- src/chat/message_receive/bot.py | 10 +- src/chat/message_receive/message_buffer.py | 4 +- src/chat/message_receive/message_sender.py | 2 +- src/chat/models/utils_model.py | 14 +- src/chat/normal_chat/normal_chat.py | 20 +- src/chat/normal_chat/normal_chat_generator.py | 15 +- .../normal_chat/willing/mode_classical.py | 28 +- src/chat/normal_chat/willing/mode_mxp.py | 9 +- .../normal_chat/willing/willing_manager.py | 5 +- src/chat/person_info/person_info.py | 3 +- src/chat/utils/chat_message_builder.py | 8 +- src/chat/utils/info_catcher.py | 3 +- src/chat/utils/utils.py | 37 +- src/chat/utils/utils_image.py | 6 +- src/common/remote.py | 11 +- src/config/config.py | 745 +++--------------- src/config/config_base.py | 116 +++ src/config/official_configs.py | 399 ++++++++++ src/experimental/PFC/action_planner.py | 4 +- src/experimental/PFC/chat_observer.py | 2 +- src/experimental/PFC/message_sender.py | 4 +- src/experimental/PFC/pfc.py | 7 +- src/experimental/PFC/pfc_KnowledgeFetcher.py | 5 +- src/experimental/PFC/reply_checker.py | 4 +- src/experimental/PFC/reply_generator.py | 2 +- src/experimental/PFC/waiter.py | 2 +- src/experimental/only_message_process.py | 4 +- src/main.py | 32 +- src/manager/mood_manager.py | 8 +- src/tools/not_used/change_mood.py | 2 +- src/tools/tool_use.py | 4 +- template/bot_config_meta.toml | 104 --- template/bot_config_template.toml | 107 ++- tests/test_config.py | 7 + 52 files changed, 902 insertions(+), 1102 deletions(-) delete mode 100644 src/chat/memory_system/memory_config.py create mode 100644 src/config/config_base.py create mode 100644 src/config/official_configs.py delete mode 100644 template/bot_config_meta.toml create mode 100644 tests/test_config.py diff --git a/src/api/reload_config.py b/src/api/reload_config.py index a5f36e3db..1772800b6 100644 --- a/src/api/reload_config.py +++ b/src/api/reload_config.py @@ -1,6 +1,6 @@ from fastapi import HTTPException from rich.traceback import install -from src.config.config import BotConfig +from src.config.config import Config from src.common.logger_manager import get_logger import os @@ -14,8 +14,8 @@ async def reload_config(): from src.config import config as config_module logger.debug("正在重载配置文件...") - bot_config_path = os.path.join(BotConfig.get_config_dir(), "bot_config.toml") - config_module.global_config = BotConfig.load_config(config_path=bot_config_path) + bot_config_path = os.path.join(Config.get_config_dir(), "bot_config.toml") + config_module.global_config = Config.load_config(config_path=bot_config_path) logger.debug("配置文件重载成功") return {"status": "reloaded"} except FileNotFoundError as e: diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py index 5d800866f..52a7288ec 100644 --- a/src/chat/emoji_system/emoji_manager.py +++ b/src/chat/emoji_system/emoji_manager.py @@ -369,14 +369,15 @@ class EmojiManager: def __init__(self): self._initialized = None self._scan_task = None - self.vlm = LLMRequest(model=global_config.vlm, temperature=0.3, max_tokens=1000, request_type="emoji") + + self.vlm = LLMRequest(model=global_config.model.vlm, temperature=0.3, max_tokens=1000, request_type="emoji") self.llm_emotion_judge = LLMRequest( - model=global_config.llm_normal, max_tokens=600, request_type="emoji" + model=global_config.model.normal, max_tokens=600, request_type="emoji" ) # 更高的温度,更少的token(后续可以根据情绪来调整温度) self.emoji_num = 0 - self.emoji_num_max = global_config.max_emoji_num - self.emoji_num_max_reach_deletion = global_config.max_reach_deletion + self.emoji_num_max = global_config.emoji.max_reg_num + self.emoji_num_max_reach_deletion = global_config.emoji.do_replace self.emoji_objects: list[MaiEmoji] = [] # 存储MaiEmoji对象的列表,使用类型注解明确列表元素类型 logger.info("启动表情包管理器") @@ -613,18 +614,18 @@ class EmojiManager: logger.warning(f"[警告] 表情包目录不存在: {EMOJI_DIR}") os.makedirs(EMOJI_DIR, exist_ok=True) logger.info(f"[创建] 已创建表情包目录: {EMOJI_DIR}") - await asyncio.sleep(global_config.EMOJI_CHECK_INTERVAL * 60) + await asyncio.sleep(global_config.emoji.check_interval * 60) continue # 检查目录是否为空 files = os.listdir(EMOJI_DIR) if not files: logger.warning(f"[警告] 表情包目录为空: {EMOJI_DIR}") - await asyncio.sleep(global_config.EMOJI_CHECK_INTERVAL * 60) + await asyncio.sleep(global_config.emoji.check_interval * 60) continue # 检查是否需要处理表情包(数量超过最大值或不足) - if (self.emoji_num > self.emoji_num_max and global_config.max_reach_deletion) or ( + if (self.emoji_num > self.emoji_num_max and global_config.emoji.do_replace) or ( self.emoji_num < self.emoji_num_max ): try: @@ -651,7 +652,7 @@ class EmojiManager: except Exception as e: logger.error(f"[错误] 扫描表情包目录失败: {str(e)}") - await asyncio.sleep(global_config.EMOJI_CHECK_INTERVAL * 60) + await asyncio.sleep(global_config.emoji.check_interval * 60) async def get_all_emoji_from_db(self): """获取所有表情包并初始化为MaiEmoji类对象,更新 self.emoji_objects""" @@ -788,7 +789,7 @@ class EmojiManager: # 构建提示词 prompt = ( - f"{global_config.BOT_NICKNAME}的表情包存储已满({self.emoji_num}/{self.emoji_num_max})," + f"{global_config.bot.nickname}的表情包存储已满({self.emoji_num}/{self.emoji_num_max})," f"需要决定是否删除一个旧表情包来为新表情包腾出空间。\n\n" f"新表情包信息:\n" f"描述: {new_emoji.description}\n\n" @@ -871,10 +872,10 @@ class EmojiManager: description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format) # 审核表情包 - if global_config.EMOJI_CHECK: + if global_config.emoji.content_filtration: prompt = f''' 这是一个表情包,请对这个表情包进行审核,标准如下: - 1. 必须符合"{global_config.EMOJI_CHECK_PROMPT}"的要求 + 1. 必须符合"{global_config.emoji.filtration_prompt}"的要求 2. 不能是色情、暴力、等违法违规内容,必须符合公序良俗 3. 不能是任何形式的截图,聊天记录或视频截图 4. 不要出现5个以上文字 diff --git a/src/chat/focus_chat/expressors/default_expressor.py b/src/chat/focus_chat/expressors/default_expressor.py index 37c50c0dc..c5aa5f9a4 100644 --- a/src/chat/focus_chat/expressors/default_expressor.py +++ b/src/chat/focus_chat/expressors/default_expressor.py @@ -25,9 +25,10 @@ logger = get_logger("expressor") class DefaultExpressor: def __init__(self, chat_id: str): self.log_prefix = "expressor" + # TODO: API-Adapter修改标记 self.express_model = LLMRequest( - model=global_config.llm_normal, - temperature=global_config.llm_normal["temp"], + model=global_config.model.normal, + temperature=global_config.model.normal["temp"], max_tokens=256, request_type="response_heartflow", ) @@ -51,8 +52,8 @@ class DefaultExpressor: messageinfo = anchor_message.message_info thinking_time_point = parse_thinking_id_to_timestamp(thinking_id) bot_user_info = UserInfo( - user_id=global_config.BOT_QQ, - user_nickname=global_config.BOT_NICKNAME, + user_id=global_config.bot.qq_account, + user_nickname=global_config.bot.nickname, platform=messageinfo.platform, ) # logger.debug(f"创建思考消息:{anchor_message}") @@ -141,7 +142,7 @@ class DefaultExpressor: try: # 1. 获取情绪影响因子并调整模型温度 arousal_multiplier = mood_manager.get_arousal_multiplier() - current_temp = float(global_config.llm_normal["temp"]) * arousal_multiplier + current_temp = float(global_config.model.normal["temp"]) * arousal_multiplier self.express_model.params["temperature"] = current_temp # 动态调整温度 # 2. 获取信息捕捉器 @@ -183,6 +184,7 @@ class DefaultExpressor: try: with Timer("LLM生成", {}): # 内部计时器,可选保留 + # TODO: API-Adapter修改标记 # logger.info(f"{self.log_prefix}[Replier-{thinking_id}]\nPrompt:\n{prompt}\n") content, reasoning_content, model_name = await self.express_model.generate_response(prompt) @@ -330,8 +332,8 @@ class DefaultExpressor: thinking_start_time = await self.heart_fc_sender.get_thinking_start_time(self.chat_id, thinking_id) bot_user_info = UserInfo( - user_id=global_config.BOT_QQ, - user_nickname=global_config.BOT_NICKNAME, + user_id=global_config.bot.qq_account, + user_nickname=global_config.bot.nickname, platform=self.chat_stream.platform, ) diff --git a/src/chat/focus_chat/expressors/exprssion_learner.py b/src/chat/focus_chat/expressors/exprssion_learner.py index 942162bc8..7766fde56 100644 --- a/src/chat/focus_chat/expressors/exprssion_learner.py +++ b/src/chat/focus_chat/expressors/exprssion_learner.py @@ -77,8 +77,9 @@ def init_prompt() -> None: class ExpressionLearner: def __init__(self) -> None: + # TODO: API-Adapter修改标记 self.express_learn_model: LLMRequest = LLMRequest( - model=global_config.llm_normal, + model=global_config.model.normal, temperature=0.1, max_tokens=256, request_type="response_heartflow", @@ -289,7 +290,7 @@ class ExpressionLearner: # 构建prompt prompt = await global_prompt_manager.format_prompt( "personality_expression_prompt", - personality=global_config.expression_style, + personality=global_config.personality.expression_style, ) # logger.info(f"个性表达方式提取prompt: {prompt}") diff --git a/src/chat/focus_chat/heartflow_processor.py b/src/chat/focus_chat/heartflow_processor.py index bbfa4ce46..a4cf360a5 100644 --- a/src/chat/focus_chat/heartflow_processor.py +++ b/src/chat/focus_chat/heartflow_processor.py @@ -112,7 +112,7 @@ def _check_ban_words(text: str, chat, userinfo) -> bool: Returns: bool: 是否包含过滤词 """ - for word in global_config.ban_words: + for word in global_config.chat.ban_words: if word in text: chat_name = chat.group_info.group_name if chat.group_info else "私聊" logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}") @@ -132,7 +132,7 @@ def _check_ban_regex(text: str, chat, userinfo) -> bool: Returns: bool: 是否匹配过滤正则 """ - for pattern in global_config.ban_msgs_regex: + for pattern in global_config.chat.ban_msgs_regex: if pattern.search(text): chat_name = chat.group_info.group_name if chat.group_info else "私聊" logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}") diff --git a/src/chat/focus_chat/heartflow_prompt_builder.py b/src/chat/focus_chat/heartflow_prompt_builder.py index 55fb79b46..fae00a9db 100644 --- a/src/chat/focus_chat/heartflow_prompt_builder.py +++ b/src/chat/focus_chat/heartflow_prompt_builder.py @@ -6,14 +6,13 @@ from src.chat.utils.chat_message_builder import build_readable_messages, get_raw from src.chat.person_info.relationship_manager import relationship_manager from src.chat.utils.utils import get_embedding import time -from typing import Union, Optional, Dict, Any +from typing import Union, Optional from src.common.database import db from src.chat.utils.utils import get_recent_group_speaker from src.manager.mood_manager import mood_manager from src.chat.memory_system.Hippocampus import HippocampusManager from src.chat.knowledge.knowledge_lib import qa_manager from src.chat.focus_chat.expressors.exprssion_learner import expression_learner -import traceback import random @@ -142,7 +141,7 @@ async def _build_prompt_focus( message_list_before_now = get_raw_msg_before_timestamp_with_chat( chat_id=chat_stream.stream_id, timestamp=time.time(), - limit=global_config.observation_context_size, + limit=global_config.chat.observation_context_size, ) chat_talking_prompt = await build_readable_messages( message_list_before_now, @@ -209,7 +208,7 @@ async def _build_prompt_focus( chat_target=chat_target_1, # Used in group template # chat_talking_prompt=chat_talking_prompt, chat_info=chat_talking_prompt, - bot_name=global_config.BOT_NICKNAME, + bot_name=global_config.bot.nickname, # prompt_personality=prompt_personality, prompt_personality="", reason=reason, @@ -225,7 +224,7 @@ async def _build_prompt_focus( info_from_tools=structured_info_prompt, sender_name=effective_sender_name, # Used in private template chat_talking_prompt=chat_talking_prompt, - bot_name=global_config.BOT_NICKNAME, + bot_name=global_config.bot.nickname, prompt_personality=prompt_personality, # chat_target and chat_target_2 are not used in private template current_mind_info=current_mind_info, @@ -280,7 +279,7 @@ class PromptBuilder: who_chat_in_group = get_recent_group_speaker( chat_stream.stream_id, (chat_stream.user_info.platform, chat_stream.user_info.user_id) if chat_stream.user_info else None, - limit=global_config.observation_context_size, + limit=global_config.chat.observation_context_size, ) elif chat_stream.user_info: who_chat_in_group.append( @@ -328,7 +327,7 @@ class PromptBuilder: message_list_before_now = get_raw_msg_before_timestamp_with_chat( chat_id=chat_stream.stream_id, timestamp=time.time(), - limit=global_config.observation_context_size, + limit=global_config.chat.observation_context_size, ) chat_talking_prompt = await build_readable_messages( message_list_before_now, @@ -340,18 +339,15 @@ class PromptBuilder: # 关键词检测与反应 keywords_reaction_prompt = "" - for rule in global_config.keywords_reaction_rules: - if rule.get("enable", False): - if any(keyword in message_txt.lower() for keyword in rule.get("keywords", [])): - logger.info( - f"检测到以下关键词之一:{rule.get('keywords', [])},触发反应:{rule.get('reaction', '')}" - ) - keywords_reaction_prompt += rule.get("reaction", "") + "," + for rule in global_config.keyword_reaction.rules: + if rule.enable: + if any(keyword in message_txt for keyword in rule.keywords): + logger.info(f"检测到以下关键词之一:{rule.keywords},触发反应:{rule.reaction}") + keywords_reaction_prompt += f"{rule.reaction}," else: - for pattern in rule.get("regex", []): - result = pattern.search(message_txt) - if result: - reaction = rule.get("reaction", "") + for pattern in rule.regex: + if result := pattern.search(message_txt): + reaction = rule.reaction for name, content in result.groupdict().items(): reaction = reaction.replace(f"[{name}]", content) logger.info(f"匹配到以下正则表达式:{pattern},触发反应:{reaction}") @@ -397,8 +393,8 @@ class PromptBuilder: chat_target_2=chat_target_2, chat_talking_prompt=chat_talking_prompt, message_txt=message_txt, - bot_name=global_config.BOT_NICKNAME, - bot_other_names="/".join(global_config.BOT_ALIAS_NAMES), + bot_name=global_config.bot.nickname, + bot_other_names="/".join(global_config.bot.alias_names), prompt_personality=prompt_personality, mood_prompt=mood_prompt, reply_style1=reply_style1_chosen, @@ -419,8 +415,8 @@ class PromptBuilder: prompt_info=prompt_info, chat_talking_prompt=chat_talking_prompt, message_txt=message_txt, - bot_name=global_config.BOT_NICKNAME, - bot_other_names="/".join(global_config.BOT_ALIAS_NAMES), + bot_name=global_config.bot.nickname, + bot_other_names="/".join(global_config.bot.alias_names), prompt_personality=prompt_personality, mood_prompt=mood_prompt, reply_style1=reply_style1_chosen, diff --git a/src/chat/focus_chat/info_processors/chattinginfo_processor.py b/src/chat/focus_chat/info_processors/chattinginfo_processor.py index 12bc8560a..bb70c043a 100644 --- a/src/chat/focus_chat/info_processors/chattinginfo_processor.py +++ b/src/chat/focus_chat/info_processors/chattinginfo_processor.py @@ -26,8 +26,9 @@ class ChattingInfoProcessor(BaseProcessor): def __init__(self): """初始化观察处理器""" super().__init__() + # TODO: API-Adapter修改标记 self.llm_summary = LLMRequest( - model=global_config.llm_observation, temperature=0.7, max_tokens=300, request_type="chat_observation" + model=global_config.model.observation, temperature=0.7, max_tokens=300, request_type="chat_observation" ) async def process_info( @@ -108,12 +109,12 @@ class ChattingInfoProcessor(BaseProcessor): "created_at": datetime.now().timestamp(), } - obs.mid_memorys.append(mid_memory) - if len(obs.mid_memorys) > obs.max_mid_memory_len: - obs.mid_memorys.pop(0) # 移除最旧的 + obs.mid_memories.append(mid_memory) + if len(obs.mid_memories) > obs.max_mid_memory_len: + obs.mid_memories.pop(0) # 移除最旧的 mid_memory_str = "之前聊天的内容概述是:\n" - for mid_memory_item in obs.mid_memorys: # 重命名循环变量以示区分 + for mid_memory_item in obs.mid_memories: # 重命名循环变量以示区分 time_diff = int((datetime.now().timestamp() - mid_memory_item["created_at"]) / 60) mid_memory_str += ( f"距离现在{time_diff}分钟前(聊天记录id:{mid_memory_item['id']}):{mid_memory_item['theme']}\n" diff --git a/src/chat/focus_chat/info_processors/mind_processor.py b/src/chat/focus_chat/info_processors/mind_processor.py index 1a104e123..221935e3d 100644 --- a/src/chat/focus_chat/info_processors/mind_processor.py +++ b/src/chat/focus_chat/info_processors/mind_processor.py @@ -81,8 +81,8 @@ class MindProcessor(BaseProcessor): self.subheartflow_id = subheartflow_id self.llm_model = LLMRequest( - model=global_config.llm_sub_heartflow, - temperature=global_config.llm_sub_heartflow["temp"], + model=global_config.model.sub_heartflow, + temperature=global_config.model.sub_heartflow["temp"], max_tokens=800, request_type="sub_heart_flow", ) diff --git a/src/chat/focus_chat/info_processors/tool_processor.py b/src/chat/focus_chat/info_processors/tool_processor.py index 8840c1ae4..57bac5f79 100644 --- a/src/chat/focus_chat/info_processors/tool_processor.py +++ b/src/chat/focus_chat/info_processors/tool_processor.py @@ -52,7 +52,7 @@ class ToolProcessor(BaseProcessor): self.subheartflow_id = subheartflow_id self.log_prefix = f"[{subheartflow_id}:ToolExecutor] " self.llm_model = LLMRequest( - model=global_config.llm_tool_use, + model=global_config.model.tool_use, max_tokens=500, request_type="tool_execution", ) diff --git a/src/chat/focus_chat/memory_activator.py b/src/chat/focus_chat/memory_activator.py index 2d7fea034..4faf43747 100644 --- a/src/chat/focus_chat/memory_activator.py +++ b/src/chat/focus_chat/memory_activator.py @@ -34,8 +34,9 @@ def init_prompt(): class MemoryActivator: def __init__(self): + # TODO: API-Adapter修改标记 self.summary_model = LLMRequest( - model=global_config.llm_summary, temperature=0.7, max_tokens=50, request_type="chat_observation" + model=global_config.model.summary, temperature=0.7, max_tokens=50, request_type="chat_observation" ) self.running_memory = [] diff --git a/src/chat/heart_flow/heartflow.py b/src/chat/heart_flow/heartflow.py index ad876bcf0..748c8331e 100644 --- a/src/chat/heart_flow/heartflow.py +++ b/src/chat/heart_flow/heartflow.py @@ -35,8 +35,9 @@ class Heartflow: self.subheartflow_manager: SubHeartflowManager = SubHeartflowManager(self.current_state) # LLM模型配置 + # TODO: API-Adapter修改标记 self.llm_model = LLMRequest( - model=global_config.llm_heartflow, temperature=0.6, max_tokens=1000, request_type="heart_flow" + model=global_config.model.heartflow, temperature=0.6, max_tokens=1000, request_type="heart_flow" ) # 外部依赖模块 diff --git a/src/chat/heart_flow/interest_chatting.py b/src/chat/heart_flow/interest_chatting.py index 45f7fe952..bce372b5c 100644 --- a/src/chat/heart_flow/interest_chatting.py +++ b/src/chat/heart_flow/interest_chatting.py @@ -20,9 +20,9 @@ MAX_REPLY_PROBABILITY = 1 class InterestChatting: def __init__( self, - decay_rate=global_config.default_decay_rate_per_second, + decay_rate=global_config.focus_chat.default_decay_rate_per_second, max_interest=MAX_INTEREST, - trigger_threshold=global_config.reply_trigger_threshold, + trigger_threshold=global_config.focus_chat.reply_trigger_threshold, max_probability=MAX_REPLY_PROBABILITY, ): # 基础属性初始化 diff --git a/src/chat/heart_flow/mai_state_manager.py b/src/chat/heart_flow/mai_state_manager.py index 7dea910e9..017656ad2 100644 --- a/src/chat/heart_flow/mai_state_manager.py +++ b/src/chat/heart_flow/mai_state_manager.py @@ -18,19 +18,14 @@ enable_unlimited_hfc_chat = True # 调试用:无限专注聊天 prevent_offline_state = True # 目前默认不启用OFFLINE状态 -# 不同状态下普通聊天的最大消息数 -base_normal_chat_num = global_config.base_normal_chat_num -base_focused_chat_num = global_config.base_focused_chat_num - - -MAX_NORMAL_CHAT_NUM_PEEKING = int(base_normal_chat_num / 2) -MAX_NORMAL_CHAT_NUM_NORMAL = base_normal_chat_num -MAX_NORMAL_CHAT_NUM_FOCUSED = base_normal_chat_num + 1 +MAX_NORMAL_CHAT_NUM_PEEKING = int(global_config.chat.base_normal_chat_num / 2) +MAX_NORMAL_CHAT_NUM_NORMAL = global_config.chat.base_normal_chat_num +MAX_NORMAL_CHAT_NUM_FOCUSED = global_config.chat.base_normal_chat_num + 1 # 不同状态下专注聊天的最大消息数 -MAX_FOCUSED_CHAT_NUM_PEEKING = int(base_focused_chat_num / 2) -MAX_FOCUSED_CHAT_NUM_NORMAL = base_focused_chat_num -MAX_FOCUSED_CHAT_NUM_FOCUSED = base_focused_chat_num + 2 +MAX_FOCUSED_CHAT_NUM_PEEKING = int(global_config.chat.base_focused_chat_num / 2) +MAX_FOCUSED_CHAT_NUM_NORMAL = global_config.chat.base_focused_chat_num +MAX_FOCUSED_CHAT_NUM_FOCUSED = global_config.chat.base_focused_chat_num + 2 # -- 状态定义 -- diff --git a/src/chat/heart_flow/observation/chatting_observation.py b/src/chat/heart_flow/observation/chatting_observation.py index a51eba5e2..c30bc8e43 100644 --- a/src/chat/heart_flow/observation/chatting_observation.py +++ b/src/chat/heart_flow/observation/chatting_observation.py @@ -53,19 +53,20 @@ class ChattingObservation(Observation): self.talking_message = [] self.talking_message_str = "" self.talking_message_str_truncate = "" - self.name = global_config.BOT_NICKNAME - self.nick_name = global_config.BOT_ALIAS_NAMES - self.max_now_obs_len = global_config.observation_context_size - self.overlap_len = global_config.compressed_length - self.mid_memorys = [] - self.max_mid_memory_len = global_config.compress_length_limit + self.name = global_config.bot.nickname + self.nick_name = global_config.bot.alias_names + self.max_now_obs_len = global_config.chat.observation_context_size + self.overlap_len = global_config.focus_chat.compressed_length + self.mid_memories = [] + self.max_mid_memory_len = global_config.focus_chat.compress_length_limit self.mid_memory_info = "" self.person_list = [] self.oldest_messages = [] self.oldest_messages_str = "" self.compressor_prompt = "" + # TODO: API-Adapter修改标记 self.llm_summary = LLMRequest( - model=global_config.llm_observation, temperature=0.7, max_tokens=300, request_type="chat_observation" + model=global_config.model.observation, temperature=0.7, max_tokens=300, request_type="chat_observation" ) async def initialize(self): @@ -83,7 +84,7 @@ class ChattingObservation(Observation): for id in ids: print(f"id:{id}") try: - for mid_memory in self.mid_memorys: + for mid_memory in self.mid_memories: if mid_memory["id"] == id: mid_memory_by_id = mid_memory msg_str = "" @@ -101,7 +102,7 @@ class ChattingObservation(Observation): else: mid_memory_str = "之前的聊天内容:\n" - for mid_memory in self.mid_memorys: + for mid_memory in self.mid_memories: mid_memory_str += f"{mid_memory['theme']}\n" return mid_memory_str + "现在群里正在聊:\n" + self.talking_message_str diff --git a/src/chat/heart_flow/subheartflow_manager.py b/src/chat/heart_flow/subheartflow_manager.py index a4bff8338..bf4ddf7e1 100644 --- a/src/chat/heart_flow/subheartflow_manager.py +++ b/src/chat/heart_flow/subheartflow_manager.py @@ -76,8 +76,9 @@ class SubHeartflowManager: # 为 LLM 状态评估创建一个 LLMRequest 实例 # 使用与 Heartflow 相同的模型和参数 + # TODO: API-Adapter修改标记 self.llm_state_evaluator = LLMRequest( - model=global_config.llm_heartflow, # 与 Heartflow 一致 + model=global_config.model.heartflow, # 与 Heartflow 一致 temperature=0.6, # 与 Heartflow 一致 max_tokens=1000, # 与 Heartflow 一致 (虽然可能不需要这么多) request_type="subheartflow_state_eval", # 保留特定的请求类型 @@ -278,7 +279,7 @@ class SubHeartflowManager: focused_limit = current_state.get_focused_chat_max_num() # --- 新增:检查是否允许进入 FOCUS 模式 --- # - if not global_config.allow_focus_mode: + if not global_config.chat.allow_focus_mode: if int(time.time()) % 60 == 0: # 每60秒输出一次日志避免刷屏 logger.trace("未开启 FOCUSED 状态 (allow_focus_mode=False)") return # 如果不允许,直接返回 @@ -766,7 +767,7 @@ class SubHeartflowManager: focused_limit = current_mai_state.get_focused_chat_max_num() # --- 检查是否允许 FOCUS 模式 --- # - if not global_config.allow_focus_mode: + if not global_config.chat.allow_focus_mode: # Log less frequently to avoid spam # if int(time.time()) % 60 == 0: # logger.debug(f"{log_prefix_task} 配置不允许进入 FOCUSED 状态") diff --git a/src/chat/memory_system/Hippocampus.py b/src/chat/memory_system/Hippocampus.py index 70eb679c9..d8c7c50e6 100644 --- a/src/chat/memory_system/Hippocampus.py +++ b/src/chat/memory_system/Hippocampus.py @@ -19,9 +19,10 @@ from ..utils.chat_message_builder import ( build_readable_messages, ) # 导入 build_readable_messages from ..utils.utils import translate_timestamp_to_human_readable -from .memory_config import MemoryConfig from rich.traceback import install +from ...config.config import global_config + install(extra_lines=3) @@ -195,18 +196,16 @@ class Hippocampus: self.llm_summary = None self.entorhinal_cortex = None self.parahippocampal_gyrus = None - self.config = None - def initialize(self, global_config): - # 使用导入的 MemoryConfig dataclass 和其 from_global_config 方法 - self.config = MemoryConfig.from_global_config(global_config) + def initialize(self): # 初始化子组件 self.entorhinal_cortex = EntorhinalCortex(self) self.parahippocampal_gyrus = ParahippocampalGyrus(self) # 从数据库加载记忆图 self.entorhinal_cortex.sync_memory_from_db() - self.llm_topic_judge = LLMRequest(self.config.llm_topic_judge, request_type="memory") - self.llm_summary = LLMRequest(self.config.llm_summary, request_type="memory") + # TODO: API-Adapter修改标记 + self.llm_topic_judge = LLMRequest(global_config.model.topic_judge, request_type="memory") + self.llm_summary = LLMRequest(global_config.model.summary, request_type="memory") def get_all_node_names(self) -> list: """获取记忆图中所有节点的名字列表""" @@ -792,7 +791,6 @@ class EntorhinalCortex: def __init__(self, hippocampus: Hippocampus): self.hippocampus = hippocampus self.memory_graph = hippocampus.memory_graph - self.config = hippocampus.config def get_memory_sample(self): """从数据库获取记忆样本""" @@ -801,13 +799,13 @@ class EntorhinalCortex: # 创建双峰分布的记忆调度器 sample_scheduler = MemoryBuildScheduler( - n_hours1=self.config.memory_build_distribution[0], - std_hours1=self.config.memory_build_distribution[1], - weight1=self.config.memory_build_distribution[2], - n_hours2=self.config.memory_build_distribution[3], - std_hours2=self.config.memory_build_distribution[4], - weight2=self.config.memory_build_distribution[5], - total_samples=self.config.build_memory_sample_num, + n_hours1=global_config.memory.memory_build_distribution[0], + std_hours1=global_config.memory.memory_build_distribution[1], + weight1=global_config.memory.memory_build_distribution[2], + n_hours2=global_config.memory.memory_build_distribution[3], + std_hours2=global_config.memory.memory_build_distribution[4], + weight2=global_config.memory.memory_build_distribution[5], + total_samples=global_config.memory.memory_build_sample_num, ) timestamps = sample_scheduler.get_timestamp_array() @@ -818,7 +816,7 @@ class EntorhinalCortex: for timestamp in timestamps: # 调用修改后的 random_get_msg_snippet messages = self.random_get_msg_snippet( - timestamp, self.config.build_memory_sample_length, max_memorized_time_per_msg + timestamp, global_config.memory.memory_build_sample_length, max_memorized_time_per_msg ) if messages: time_diff = (datetime.datetime.now().timestamp() - timestamp) / 3600 @@ -1099,7 +1097,6 @@ class ParahippocampalGyrus: def __init__(self, hippocampus: Hippocampus): self.hippocampus = hippocampus self.memory_graph = hippocampus.memory_graph - self.config = hippocampus.config async def memory_compress(self, messages: list, compress_rate=0.1): """压缩和总结消息内容,生成记忆主题和摘要。 @@ -1159,7 +1156,7 @@ class ParahippocampalGyrus: # 3. 过滤掉包含禁用关键词的topic filtered_topics = [ - topic for topic in topics if not any(keyword in topic for keyword in self.config.memory_ban_words) + topic for topic in topics if not any(keyword in topic for keyword in global_config.memory.memory_ban_words) ] logger.debug(f"过滤后话题: {filtered_topics}") @@ -1222,7 +1219,7 @@ class ParahippocampalGyrus: bar = "█" * filled_length + "-" * (bar_length - filled_length) logger.debug(f"进度: [{bar}] {progress:.1f}% ({i}/{len(memory_samples)})") - compress_rate = self.config.memory_compress_rate + compress_rate = global_config.memory.memory_compress_rate try: compressed_memory, similar_topics_dict = await self.memory_compress(messages, compress_rate) except Exception as e: @@ -1322,7 +1319,7 @@ class ParahippocampalGyrus: edge_data = self.memory_graph.G[source][target] last_modified = edge_data.get("last_modified") - if current_time - last_modified > 3600 * self.config.memory_forget_time: + if current_time - last_modified > 3600 * global_config.memory.memory_forget_time: current_strength = edge_data.get("strength", 1) new_strength = current_strength - 1 @@ -1430,8 +1427,8 @@ class ParahippocampalGyrus: async def operation_consolidate_memory(self): """整合记忆:合并节点内相似的记忆项""" start_time = time.time() - percentage = self.config.consolidate_memory_percentage - similarity_threshold = self.config.consolidation_similarity_threshold + percentage = global_config.memory.consolidate_memory_percentage + similarity_threshold = global_config.memory.consolidation_similarity_threshold logger.info(f"[整合] 开始检查记忆节点... 检查比例: {percentage:.2%}, 合并阈值: {similarity_threshold}") # 获取所有至少有2条记忆项的节点 @@ -1544,7 +1541,6 @@ class ParahippocampalGyrus: class HippocampusManager: _instance = None _hippocampus = None - _global_config = None _initialized = False @classmethod @@ -1559,19 +1555,15 @@ class HippocampusManager: raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法") return cls._hippocampus - def initialize(self, global_config): + def initialize(self): """初始化海马体实例""" if self._initialized: return self._hippocampus - self._global_config = global_config self._hippocampus = Hippocampus() - self._hippocampus.initialize(global_config) + self._hippocampus.initialize() self._initialized = True - # 输出记忆系统参数信息 - config = self._hippocampus.config - # 输出记忆图统计信息 memory_graph = self._hippocampus.memory_graph.G node_count = len(memory_graph.nodes()) @@ -1579,9 +1571,9 @@ class HippocampusManager: logger.success(f"""-------------------------------- 记忆系统参数配置: - 构建间隔: {global_config.build_memory_interval}秒|样本数: {config.build_memory_sample_num},长度: {config.build_memory_sample_length}|压缩率: {config.memory_compress_rate} - 记忆构建分布: {config.memory_build_distribution} - 遗忘间隔: {global_config.forget_memory_interval}秒|遗忘比例: {global_config.memory_forget_percentage}|遗忘: {config.memory_forget_time}小时之后 + 构建间隔: {global_config.memory.memory_build_interval}秒|样本数: {global_config.memory.memory_build_sample_num},长度: {global_config.memory.memory_build_sample_length}|压缩率: {global_config.memory.memory_compress_rate} + 记忆构建分布: {global_config.memory.memory_build_distribution} + 遗忘间隔: {global_config.memory.forget_memory_interval}秒|遗忘比例: {global_config.memory.memory_forget_percentage}|遗忘: {global_config.memory.memory_forget_time}小时之后 记忆图统计信息: 节点数量: {node_count}, 连接数量: {edge_count} --------------------------------""") # noqa: E501 diff --git a/src/chat/memory_system/debug_memory.py b/src/chat/memory_system/debug_memory.py index baf745409..b09e703a1 100644 --- a/src/chat/memory_system/debug_memory.py +++ b/src/chat/memory_system/debug_memory.py @@ -7,7 +7,6 @@ import os # 添加项目根目录到系统路径 sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))) from src.chat.memory_system.Hippocampus import HippocampusManager -from src.config.config import global_config from rich.traceback import install install(extra_lines=3) @@ -19,7 +18,7 @@ async def test_memory_system(): # 初始化记忆系统 print("开始初始化记忆系统...") hippocampus_manager = HippocampusManager.get_instance() - hippocampus_manager.initialize(global_config=global_config) + hippocampus_manager.initialize() print("记忆系统初始化完成") # 测试记忆构建 diff --git a/src/chat/memory_system/memory_config.py b/src/chat/memory_system/memory_config.py deleted file mode 100644 index b82e54ec1..000000000 --- a/src/chat/memory_system/memory_config.py +++ /dev/null @@ -1,48 +0,0 @@ -from dataclasses import dataclass -from typing import List - - -@dataclass -class MemoryConfig: - """记忆系统配置类""" - - # 记忆构建相关配置 - memory_build_distribution: List[float] # 记忆构建的时间分布参数 - build_memory_sample_num: int # 每次构建记忆的样本数量 - build_memory_sample_length: int # 每个样本的消息长度 - memory_compress_rate: float # 记忆压缩率 - - # 记忆遗忘相关配置 - memory_forget_time: int # 记忆遗忘时间(小时) - - # 记忆过滤相关配置 - memory_ban_words: List[str] # 记忆过滤词列表 - - # 新增:记忆整合相关配置 - consolidation_similarity_threshold: float # 相似度阈值 - consolidate_memory_percentage: float # 检查节点比例 - consolidate_memory_interval: int # 记忆整合间隔 - - llm_topic_judge: str # 话题判断模型 - llm_summary: str # 话题总结模型 - - @classmethod - def from_global_config(cls, global_config): - """从全局配置创建记忆系统配置""" - # 使用 getattr 提供默认值,防止全局配置缺少这些项 - return cls( - memory_build_distribution=getattr( - global_config, "memory_build_distribution", (24, 12, 0.5, 168, 72, 0.5) - ), # 添加默认值 - build_memory_sample_num=getattr(global_config, "build_memory_sample_num", 5), - build_memory_sample_length=getattr(global_config, "build_memory_sample_length", 30), - memory_compress_rate=getattr(global_config, "memory_compress_rate", 0.1), - memory_forget_time=getattr(global_config, "memory_forget_time", 24 * 7), - memory_ban_words=getattr(global_config, "memory_ban_words", []), - # 新增加载整合配置,并提供默认值 - consolidation_similarity_threshold=getattr(global_config, "consolidation_similarity_threshold", 0.7), - consolidate_memory_percentage=getattr(global_config, "consolidate_memory_percentage", 0.01), - consolidate_memory_interval=getattr(global_config, "consolidate_memory_interval", 1000), - llm_topic_judge=getattr(global_config, "llm_topic_judge", "default_judge_model"), # 添加默认模型名 - llm_summary=getattr(global_config, "llm_summary", "default_summary_model"), # 添加默认模型名 - ) diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index 3c9e4420c..0e35f6f6e 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -41,7 +41,7 @@ class ChatBot: chat_id = str(message.chat_stream.stream_id) private_name = str(message.message_info.user_info.user_nickname) - if global_config.enable_pfc_chatting: + if global_config.experimental.enable_pfc_chatting: await self.pfc_manager.get_or_create_conversation(chat_id, private_name) except Exception as e: @@ -78,19 +78,19 @@ class ChatBot: userinfo = message.message_info.user_info # 用户黑名单拦截 - if userinfo.user_id in global_config.ban_user_id: + if userinfo.user_id in global_config.chat_target.ban_user_id: logger.debug(f"用户{userinfo.user_id}被禁止回复") return if groupinfo is None: logger.trace("检测到私聊消息,检查") # 好友黑名单拦截 - if userinfo.user_id not in global_config.talk_allowed_private: + if userinfo.user_id not in global_config.experimental.talk_allowed_private: logger.debug(f"用户{userinfo.user_id}没有私聊权限") return # 群聊黑名单拦截 - if groupinfo is not None and groupinfo.group_id not in global_config.talk_allowed_groups: + if groupinfo is not None and groupinfo.group_id not in global_config.chat_target.talk_allowed_groups: logger.trace(f"群{groupinfo.group_id}被禁止回复") return @@ -112,7 +112,7 @@ class ChatBot: if groupinfo is None: logger.trace("检测到私聊消息") # 是否在配置信息中开启私聊模式 - if global_config.enable_friend_chat: + if global_config.experimental.enable_friend_chat: logger.trace("私聊模式已启用") # 是否进入PFC if global_config.enable_pfc_chatting: diff --git a/src/chat/message_receive/message_buffer.py b/src/chat/message_receive/message_buffer.py index f3cf63d0a..2df256ce5 100644 --- a/src/chat/message_receive/message_buffer.py +++ b/src/chat/message_receive/message_buffer.py @@ -38,7 +38,7 @@ class MessageBuffer: async def start_caching_messages(self, message: MessageRecv): """添加消息,启动缓冲""" - if not global_config.message_buffer: + if not global_config.chat.message_buffer: person_id = person_info_manager.get_person_id( message.message_info.user_info.platform, message.message_info.user_info.user_id ) @@ -107,7 +107,7 @@ class MessageBuffer: async def query_buffer_result(self, message: MessageRecv) -> bool: """查询缓冲结果,并清理""" - if not global_config.message_buffer: + if not global_config.chat.message_buffer: return True person_id_ = self.get_person_id_( message.message_info.platform, message.message_info.user_info.user_id, message.message_info.group_info diff --git a/src/chat/message_receive/message_sender.py b/src/chat/message_receive/message_sender.py index 5db34fdea..cf5877989 100644 --- a/src/chat/message_receive/message_sender.py +++ b/src/chat/message_receive/message_sender.py @@ -279,7 +279,7 @@ class MessageManager: ) # 检查是否超时 - if thinking_time > global_config.thinking_timeout: + if thinking_time > global_config.normal_chat.thinking_timeout: logger.warning( f"[{chat_id}] 消息思考超时 ({thinking_time:.1f}秒),移除消息 {message_earliest.message_info.message_id}" ) diff --git a/src/chat/models/utils_model.py b/src/chat/models/utils_model.py index e662a8e33..a161ae4d9 100644 --- a/src/chat/models/utils_model.py +++ b/src/chat/models/utils_model.py @@ -111,8 +111,8 @@ class LLMRequest: def __init__(self, model: dict, **kwargs): # 将大写的配置键转换为小写并从config中获取实际值 try: - self.api_key = os.environ[model["key"]] - self.base_url = os.environ[model["base_url"]] + self.api_key = os.environ[f"{model['provider']}_KEY"] + self.base_url = os.environ[f"{model['provider']}_BASE_URL"] except AttributeError as e: logger.error(f"原始 model dict 信息:{model}") logger.error(f"配置错误:找不到对应的配置项 - {str(e)}") @@ -500,11 +500,11 @@ class LLMRequest: logger.warning(f"检测到403错误,模型从 {old_model_name} 降级为 {self.model_name}") # 对全局配置进行更新 - if global_config.llm_normal.get("name") == old_model_name: - global_config.llm_normal["name"] = self.model_name + if global_config.model.normal.get("name") == old_model_name: + global_config.model.normal["name"] = self.model_name logger.warning(f"将全局配置中的 llm_normal 模型临时降级至{self.model_name}") - if global_config.llm_reasoning.get("name") == old_model_name: - global_config.llm_reasoning["name"] = self.model_name + if global_config.model.reasoning.get("name") == old_model_name: + global_config.model.reasoning["name"] = self.model_name logger.warning(f"将全局配置中的 llm_reasoning 模型临时降级至{self.model_name}") if payload and "model" in payload: @@ -636,7 +636,7 @@ class LLMRequest: **params_copy, } if "max_tokens" not in payload and "max_completion_tokens" not in payload: - payload["max_tokens"] = global_config.model_max_output_length + payload["max_tokens"] = global_config.model.model_max_output_length # 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查 if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload: payload["max_completion_tokens"] = payload.pop("max_tokens") diff --git a/src/chat/normal_chat/normal_chat.py b/src/chat/normal_chat/normal_chat.py index 9dc2454ff..96cc2b8cb 100644 --- a/src/chat/normal_chat/normal_chat.py +++ b/src/chat/normal_chat/normal_chat.py @@ -73,8 +73,8 @@ class NormalChat: messageinfo = message.message_info bot_user_info = UserInfo( - user_id=global_config.BOT_QQ, - user_nickname=global_config.BOT_NICKNAME, + user_id=global_config.bot.qq_account, + user_nickname=global_config.bot.nickname, platform=messageinfo.platform, ) @@ -121,8 +121,8 @@ class NormalChat: message_id=thinking_id, chat_stream=self.chat_stream, # 使用 self.chat_stream bot_user_info=UserInfo( - user_id=global_config.BOT_QQ, - user_nickname=global_config.BOT_NICKNAME, + user_id=global_config.bot.qq_account, + user_nickname=global_config.bot.nickname, platform=message.message_info.platform, ), sender_info=message.message_info.user_info, @@ -147,7 +147,7 @@ class NormalChat: # 改为实例方法 async def _handle_emoji(self, message: MessageRecv, response: str): """处理表情包""" - if random() < global_config.emoji_chance: + if random() < global_config.normal_chat.emoji_chance: emoji_raw = await emoji_manager.get_emoji_for_text(response) if emoji_raw: emoji_path, description = emoji_raw @@ -160,8 +160,8 @@ class NormalChat: message_id="mt" + str(thinking_time_point), chat_stream=self.chat_stream, # 使用 self.chat_stream bot_user_info=UserInfo( - user_id=global_config.BOT_QQ, - user_nickname=global_config.BOT_NICKNAME, + user_id=global_config.bot.qq_account, + user_nickname=global_config.bot.nickname, platform=message.message_info.platform, ), sender_info=message.message_info.user_info, @@ -186,7 +186,7 @@ class NormalChat: label=emotion, stance=stance, # 使用 self.chat_stream ) - self.mood_manager.update_mood_from_emotion(emotion, global_config.mood_intensity_factor) + self.mood_manager.update_mood_from_emotion(emotion, global_config.mood.mood_intensity_factor) async def _reply_interested_message(self) -> None: """ @@ -430,7 +430,7 @@ class NormalChat: def _check_ban_words(text: str, chat: ChatStream, userinfo: UserInfo) -> bool: """检查消息中是否包含过滤词""" stream_name = chat_manager.get_stream_name(chat.stream_id) or chat.stream_id - for word in global_config.ban_words: + for word in global_config.chat.ban_words: if word in text: logger.info( f"[{stream_name}][{chat.group_info.group_name if chat.group_info else '私聊'}]" @@ -445,7 +445,7 @@ class NormalChat: def _check_ban_regex(text: str, chat: ChatStream, userinfo: UserInfo) -> bool: """检查消息是否匹配过滤正则表达式""" stream_name = chat_manager.get_stream_name(chat.stream_id) or chat.stream_id - for pattern in global_config.ban_msgs_regex: + for pattern in global_config.chat.ban_msgs_regex: if pattern.search(text): logger.info( f"[{stream_name}][{chat.group_info.group_name if chat.group_info else '私聊'}]" diff --git a/src/chat/normal_chat/normal_chat_generator.py b/src/chat/normal_chat/normal_chat_generator.py index aec65ed1d..631f7baa5 100644 --- a/src/chat/normal_chat/normal_chat_generator.py +++ b/src/chat/normal_chat/normal_chat_generator.py @@ -15,21 +15,22 @@ logger = get_logger("llm") class NormalChatGenerator: def __init__(self): + # TODO: API-Adapter修改标记 self.model_reasoning = LLMRequest( - model=global_config.llm_reasoning, + model=global_config.model.reasoning, temperature=0.7, max_tokens=3000, request_type="response_reasoning", ) self.model_normal = LLMRequest( - model=global_config.llm_normal, - temperature=global_config.llm_normal["temp"], + model=global_config.model.normal, + temperature=global_config.model.normal["temp"], max_tokens=256, request_type="response_reasoning", ) self.model_sum = LLMRequest( - model=global_config.llm_summary, temperature=0.7, max_tokens=3000, request_type="relation" + model=global_config.model.summary, temperature=0.7, max_tokens=3000, request_type="relation" ) self.current_model_type = "r1" # 默认使用 R1 self.current_model_name = "unknown model" @@ -37,7 +38,7 @@ class NormalChatGenerator: async def generate_response(self, message: MessageThinking, thinking_id: str) -> Optional[Union[str, List[str]]]: """根据当前模型类型选择对应的生成函数""" # 从global_config中获取模型概率值并选择模型 - if random.random() < global_config.model_reasoning_probability: + if random.random() < global_config.normal_chat.reasoning_model_probability: self.current_model_type = "深深地" current_model = self.model_reasoning else: @@ -51,7 +52,7 @@ class NormalChatGenerator: model_response = await self._generate_response_with_model(message, current_model, thinking_id) if model_response: - logger.info(f"{global_config.BOT_NICKNAME}的回复是:{model_response}") + logger.info(f"{global_config.bot.nickname}的回复是:{model_response}") model_response = await self._process_response(model_response) return model_response @@ -113,7 +114,7 @@ class NormalChatGenerator: - "中立":不表达明确立场或无关回应 2. 从"开心,愤怒,悲伤,惊讶,平静,害羞,恐惧,厌恶,困惑"中选出最匹配的1个情感标签 3. 按照"立场-情绪"的格式直接输出结果,例如:"反对-愤怒" - 4. 考虑回复者的人格设定为{global_config.personality_core} + 4. 考虑回复者的人格设定为{global_config.personality.personality_core} 对话示例: 被回复:「A就是笨」 diff --git a/src/chat/normal_chat/willing/mode_classical.py b/src/chat/normal_chat/willing/mode_classical.py index e96aa77a7..a9f04273a 100644 --- a/src/chat/normal_chat/willing/mode_classical.py +++ b/src/chat/normal_chat/willing/mode_classical.py @@ -1,18 +1,20 @@ import asyncio + +from src.config.config import global_config from .willing_manager import BaseWillingManager class ClassicalWillingManager(BaseWillingManager): def __init__(self): super().__init__() - self._decay_task: asyncio.Task = None + self._decay_task: asyncio.Task | None = None async def _decay_reply_willing(self): """定期衰减回复意愿""" while True: await asyncio.sleep(1) for chat_id in self.chat_reply_willing: - self.chat_reply_willing[chat_id] = max(0, self.chat_reply_willing[chat_id] * 0.9) + self.chat_reply_willing[chat_id] = max(0.0, self.chat_reply_willing[chat_id] * 0.9) async def async_task_starter(self): if self._decay_task is None: @@ -23,35 +25,33 @@ class ClassicalWillingManager(BaseWillingManager): chat_id = willing_info.chat_id current_willing = self.chat_reply_willing.get(chat_id, 0) - interested_rate = willing_info.interested_rate * self.global_config.response_interested_rate_amplifier + interested_rate = willing_info.interested_rate * global_config.normal_chat.response_interested_rate_amplifier if interested_rate > 0.4: current_willing += interested_rate - 0.3 - if willing_info.is_mentioned_bot and current_willing < 1.0: - current_willing += 1 - elif willing_info.is_mentioned_bot: - current_willing += 0.05 + if willing_info.is_mentioned_bot: + current_willing += 1 if current_willing < 1.0 else 0.05 is_emoji_not_reply = False if willing_info.is_emoji: - if self.global_config.emoji_response_penalty != 0: - current_willing *= self.global_config.emoji_response_penalty + if global_config.normal_chat.emoji_response_penalty != 0: + current_willing *= global_config.normal_chat.emoji_response_penalty else: is_emoji_not_reply = True self.chat_reply_willing[chat_id] = min(current_willing, 3.0) reply_probability = min( - max((current_willing - 0.5), 0.01) * self.global_config.response_willing_amplifier * 2, 1 + max((current_willing - 0.5), 0.01) * global_config.normal_chat.response_willing_amplifier * 2, 1 ) # 检查群组权限(如果是群聊) if ( willing_info.group_info - and willing_info.group_info.group_id in self.global_config.talk_frequency_down_groups + and willing_info.group_info.group_id in global_config.chat_target.talk_frequency_down_groups ): - reply_probability = reply_probability / self.global_config.down_frequency_rate + reply_probability = reply_probability / global_config.normal_chat.down_frequency_rate if is_emoji_not_reply: reply_probability = 0 @@ -61,7 +61,7 @@ class ClassicalWillingManager(BaseWillingManager): async def before_generate_reply_handle(self, message_id): chat_id = self.ongoing_messages[message_id].chat_id current_willing = self.chat_reply_willing.get(chat_id, 0) - self.chat_reply_willing[chat_id] = max(0, current_willing - 1.8) + self.chat_reply_willing[chat_id] = max(0.0, current_willing - 1.8) async def after_generate_reply_handle(self, message_id): if message_id not in self.ongoing_messages: @@ -70,7 +70,7 @@ class ClassicalWillingManager(BaseWillingManager): chat_id = self.ongoing_messages[message_id].chat_id current_willing = self.chat_reply_willing.get(chat_id, 0) if current_willing < 1: - self.chat_reply_willing[chat_id] = min(1, current_willing + 0.4) + self.chat_reply_willing[chat_id] = min(1.0, current_willing + 0.4) async def bombing_buffer_message_handle(self, message_id): return await super().bombing_buffer_message_handle(message_id) diff --git a/src/chat/normal_chat/willing/mode_mxp.py b/src/chat/normal_chat/willing/mode_mxp.py index 78120ac53..1e7d5856d 100644 --- a/src/chat/normal_chat/willing/mode_mxp.py +++ b/src/chat/normal_chat/willing/mode_mxp.py @@ -19,6 +19,7 @@ Mxp 模式:梦溪畔独家赞助 下下策是询问一个菜鸟(@梦溪畔) """ +from src.config.config import global_config from .willing_manager import BaseWillingManager from typing import Dict import asyncio @@ -50,8 +51,6 @@ class MxpWillingManager(BaseWillingManager): self.mention_willing_gain = 0.6 # 提及意愿增益 self.interest_willing_gain = 0.3 # 兴趣意愿增益 - self.emoji_response_penalty = self.global_config.emoji_response_penalty # 表情包回复惩罚 - self.down_frequency_rate = self.global_config.down_frequency_rate # 降低回复频率的群组惩罚系数 self.single_chat_gain = 0.12 # 单聊增益 self.fatigue_messages_triggered_num = self.expected_replies_per_min # 疲劳消息触发数量(int) @@ -179,10 +178,10 @@ class MxpWillingManager(BaseWillingManager): probability = self._willing_to_probability(current_willing) if w_info.is_emoji: - probability *= self.emoji_response_penalty + probability *= global_config.normal_chat.emoji_response_penalty - if w_info.group_info and w_info.group_info.group_id in self.global_config.talk_frequency_down_groups: - probability /= self.down_frequency_rate + if w_info.group_info and w_info.group_info.group_id in global_config.chat_target.talk_frequency_down_groups: + probability /= global_config.normal_chat.down_frequency_rate self.temporary_willing = current_willing diff --git a/src/chat/normal_chat/willing/willing_manager.py b/src/chat/normal_chat/willing/willing_manager.py index 37e623d11..bbc5dcc0a 100644 --- a/src/chat/normal_chat/willing/willing_manager.py +++ b/src/chat/normal_chat/willing/willing_manager.py @@ -1,6 +1,6 @@ from src.common.logger import LogConfig, WILLING_STYLE_CONFIG, LoguruLogger, get_module_logger from dataclasses import dataclass -from src.config.config import global_config, BotConfig +from src.config.config import global_config from src.chat.message_receive.chat_stream import ChatStream, GroupInfo from src.chat.message_receive.message import MessageRecv from src.chat.person_info.person_info import person_info_manager, PersonInfoManager @@ -93,7 +93,6 @@ class BaseWillingManager(ABC): self.chat_reply_willing: Dict[str, float] = {} # 存储每个聊天流的回复意愿(chat_id) self.ongoing_messages: Dict[str, WillingInfo] = {} # 当前正在进行的消息(message_id) self.lock = asyncio.Lock() - self.global_config: BotConfig = global_config self.logger: LoguruLogger = logger def setup(self, message: MessageRecv, chat: ChatStream, is_mentioned_bot: bool, interested_rate: float): @@ -173,7 +172,7 @@ def init_willing_manager() -> BaseWillingManager: Returns: 对应mode的WillingManager实例 """ - mode = global_config.willing_mode.lower() + mode = global_config.normal_chat.willing_mode.lower() return BaseWillingManager.create(mode) diff --git a/src/chat/person_info/person_info.py b/src/chat/person_info/person_info.py index 605b86b23..aadbb1d2e 100644 --- a/src/chat/person_info/person_info.py +++ b/src/chat/person_info/person_info.py @@ -59,8 +59,9 @@ person_info_default = { class PersonInfoManager: def __init__(self): self.person_name_list = {} + # TODO: API-Adapter修改标记 self.qv_name_llm = LLMRequest( - model=global_config.llm_normal, + model=global_config.model.normal, max_tokens=256, request_type="qv_name", ) diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index 15b1e4fc6..de018bdb8 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -190,8 +190,8 @@ async def _build_readable_messages_internal( person_id = person_info_manager.get_person_id(platform, user_id) # 根据 replace_bot_name 参数决定是否替换机器人名称 - if replace_bot_name and user_id == global_config.BOT_QQ: - person_name = f"{global_config.BOT_NICKNAME}(你)" + if replace_bot_name and user_id == global_config.bot.qq_account: + person_name = f"{global_config.bot.nickname}(你)" else: person_name = await person_info_manager.get_value(person_id, "person_name") @@ -427,7 +427,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str: output_lines = [] def get_anon_name(platform, user_id): - if user_id == global_config.BOT_QQ: + if user_id == global_config.bot.qq_account: return "SELF" person_id = person_info_manager.get_person_id(platform, user_id) if person_id not in person_map: @@ -501,7 +501,7 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]: user_id = user_info.get("user_id") # 检查必要信息是否存在 且 不是机器人自己 - if not all([platform, user_id]) or user_id == global_config.BOT_QQ: + if not all([platform, user_id]) or user_id == global_config.bot.qq_account: continue person_id = person_info_manager.get_person_id(platform, user_id) diff --git a/src/chat/utils/info_catcher.py b/src/chat/utils/info_catcher.py index 174bb5b49..a5b04d704 100644 --- a/src/chat/utils/info_catcher.py +++ b/src/chat/utils/info_catcher.py @@ -9,7 +9,6 @@ from typing import List class InfoCatcher: def __init__(self): self.chat_history = [] # 聊天历史,长度为三倍使用的上下文喵~ - self.context_length = global_config.observation_context_size self.chat_history_in_thinking = [] # 思考期间的聊天内容喵~ self.chat_history_after_response = [] # 回复后的聊天内容,长度为一倍上下文喵~ @@ -143,7 +142,7 @@ class InfoCatcher: messages_before = ( db.messages.find({"chat_id": chat_id, "message_id": {"$lt": message_id}}) .sort("time", -1) - .limit(self.context_length * 3) + .limit(global_config.chat.observation_context_size * 3) ) # 获取更多历史信息 return list(messages_before) diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py index 8fe8334b8..58eb49de8 100644 --- a/src/chat/utils/utils.py +++ b/src/chat/utils/utils.py @@ -43,8 +43,8 @@ def db_message_to_str(message_dict: dict) -> str: def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]: """检查消息是否提到了机器人""" - keywords = [global_config.BOT_NICKNAME] - nicknames = global_config.BOT_ALIAS_NAMES + keywords = [global_config.bot.nickname] + nicknames = global_config.bot.alias_names reply_probability = 0.0 is_at = False is_mentioned = False @@ -64,18 +64,18 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]: ) # 判断是否被@ - if re.search(f"@[\s\S]*?(id:{global_config.BOT_QQ})", message.processed_plain_text): + if re.search(f"@[\s\S]*?(id:{global_config.bot.qq_account})", message.processed_plain_text): is_at = True is_mentioned = True - if is_at and global_config.at_bot_inevitable_reply: + if is_at and global_config.normal_chat.at_bot_inevitable_reply: reply_probability = 1.0 logger.info("被@,回复概率设置为100%") else: if not is_mentioned: # 判断是否被回复 if re.match( - f"\[回复 [\s\S]*?\({str(global_config.BOT_QQ)}\):[\s\S]*?],说:", message.processed_plain_text + f"\[回复 [\s\S]*?\({str(global_config.bot.qq_account)}\):[\s\S]*?],说:", message.processed_plain_text ): is_mentioned = True else: @@ -88,7 +88,7 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]: for nickname in nicknames: if nickname in message_content: is_mentioned = True - if is_mentioned and global_config.mentioned_bot_inevitable_reply: + if is_mentioned and global_config.normal_chat.mentioned_bot_inevitable_reply: reply_probability = 1.0 logger.info("被提及,回复概率设置为100%") return is_mentioned, reply_probability @@ -96,7 +96,8 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]: async def get_embedding(text, request_type="embedding"): """获取文本的embedding向量""" - llm = LLMRequest(model=global_config.embedding, request_type=request_type) + # TODO: API-Adapter修改标记 + llm = LLMRequest(model=global_config.model.embedding, request_type=request_type) # return llm.get_embedding_sync(text) try: embedding = await llm.get_embedding(text) @@ -163,7 +164,7 @@ def get_recent_group_speaker(chat_stream_id: int, sender, limit: int = 12) -> li user_info = UserInfo.from_dict(msg_db_data["user_info"]) if ( (user_info.platform, user_info.user_id) != sender - and user_info.user_id != global_config.BOT_QQ + and user_info.user_id != global_config.bot.qq_account and (user_info.platform, user_info.user_id, user_info.user_nickname) not in who_chat_in_group and len(who_chat_in_group) < 5 ): # 排除重复,排除消息发送者,排除bot,限制加载的关系数目 @@ -321,7 +322,7 @@ def random_remove_punctuation(text: str) -> str: def process_llm_response(text: str) -> list[str]: # 先保护颜文字 - if global_config.enable_kaomoji_protection: + if global_config.response_splitter.enable_kaomoji_protection: protected_text, kaomoji_mapping = protect_kaomoji(text) logger.trace(f"保护颜文字后的文本: {protected_text}") else: @@ -340,8 +341,8 @@ def process_llm_response(text: str) -> list[str]: logger.debug(f"{text}去除括号处理后的文本: {cleaned_text}") # 对清理后的文本进行进一步处理 - max_length = global_config.response_max_length * 2 - max_sentence_num = global_config.response_max_sentence_num + max_length = global_config.response_splitter.max_length * 2 + max_sentence_num = global_config.response_splitter.max_sentence_num # 如果基本上是中文,则进行长度过滤 if get_western_ratio(cleaned_text) < 0.1: if len(cleaned_text) > max_length: @@ -349,20 +350,20 @@ def process_llm_response(text: str) -> list[str]: return ["懒得说"] typo_generator = ChineseTypoGenerator( - error_rate=global_config.chinese_typo_error_rate, - min_freq=global_config.chinese_typo_min_freq, - tone_error_rate=global_config.chinese_typo_tone_error_rate, - word_replace_rate=global_config.chinese_typo_word_replace_rate, + error_rate=global_config.chinese_typo.error_rate, + min_freq=global_config.chinese_typo.min_freq, + tone_error_rate=global_config.chinese_typo.tone_error_rate, + word_replace_rate=global_config.chinese_typo.word_replace_rate, ) - if global_config.enable_response_splitter: + if global_config.response_splitter.enable: split_sentences = split_into_sentences_w_remove_punctuation(cleaned_text) else: split_sentences = [cleaned_text] sentences = [] for sentence in split_sentences: - if global_config.chinese_typo_enable: + if global_config.chinese_typo.enable: typoed_text, typo_corrections = typo_generator.create_typo_sentence(sentence) sentences.append(typoed_text) if typo_corrections: @@ -372,7 +373,7 @@ def process_llm_response(text: str) -> list[str]: if len(sentences) > max_sentence_num: logger.warning(f"分割后消息数量过多 ({len(sentences)} 条),返回默认回复") - return [f"{global_config.BOT_NICKNAME}不知道哦"] + return [f"{global_config.bot.nickname}不知道哦"] # if extracted_contents: # for content in extracted_contents: diff --git a/src/chat/utils/utils_image.py b/src/chat/utils/utils_image.py index 455038246..6958bc26b 100644 --- a/src/chat/utils/utils_image.py +++ b/src/chat/utils/utils_image.py @@ -36,7 +36,7 @@ class ImageManager: self._ensure_description_collection() self._ensure_image_dir() self._initialized = True - self._llm = LLMRequest(model=global_config.vlm, temperature=0.4, max_tokens=300, request_type="image") + self._llm = LLMRequest(model=global_config.model.vlm, temperature=0.4, max_tokens=300, request_type="image") def _ensure_image_dir(self): """确保图像存储目录存在""" @@ -134,7 +134,7 @@ class ImageManager: return f"[表情包,含义看起来是:{cached_description}]" # 根据配置决定是否保存图片 - if global_config.save_emoji: + if global_config.emoji.save_emoji: # 生成文件名和路径 timestamp = int(time.time()) filename = f"{timestamp}_{image_hash[:8]}.{image_format}" @@ -200,7 +200,7 @@ class ImageManager: return "[图片]" # 根据配置决定是否保存图片 - if global_config.save_pic: + if global_config.emoji.save_pic: # 生成文件名和路径 timestamp = int(time.time()) filename = f"{timestamp}_{image_hash[:8]}.{image_format}" diff --git a/src/common/remote.py b/src/common/remote.py index 1d26df01b..b1108be9c 100644 --- a/src/common/remote.py +++ b/src/common/remote.py @@ -35,7 +35,7 @@ class TelemetryHeartBeatTask(AsyncTask): info_dict = { "os_type": "Unknown", "py_version": platform.python_version(), - "mmc_version": global_config.MAI_VERSION, + "mmc_version": global_config.MMC_VERSION, } match platform.system(): @@ -133,10 +133,9 @@ class TelemetryHeartBeatTask(AsyncTask): async def run(self): # 发送心跳 - if global_config.remote_enable: - if self.client_uuid is None: - if not await self._req_uuid(): - logger.error("获取UUID失败,跳过此次心跳") - return + if global_config.telemetry.enable: + if self.client_uuid is None and not await self._req_uuid(): + logger.error("获取UUID失败,跳过此次心跳") + return await self._send_heartbeat() diff --git a/src/config/config.py b/src/config/config.py index b186f3b83..e6b7c5326 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -1,64 +1,68 @@ import os -import re -from dataclasses import dataclass, field -from typing import Dict, List, Optional +from dataclasses import field, dataclass -import tomli import tomlkit import shutil from datetime import datetime -from pathlib import Path -from packaging import version -from packaging.version import Version, InvalidVersion -from packaging.specifiers import SpecifierSet, InvalidSpecifier + +from tomlkit import TOMLDocument +from tomlkit.items import Table from src.common.logger_manager import get_logger from rich.traceback import install +from src.config.config_base import ConfigBase +from src.config.official_configs import ( + BotConfig, + ChatTargetConfig, + PersonalityConfig, + IdentityConfig, + PlatformsConfig, + ChatConfig, + NormalChatConfig, + FocusChatConfig, + EmojiConfig, + MemoryConfig, + MoodConfig, + KeywordReactionConfig, + ChineseTypoConfig, + ResponseSplitterConfig, + TelemetryConfig, + ExperimentalConfig, + ModelConfig, +) + install(extra_lines=3) # 配置主程序日志格式 logger = get_logger("config") -# 考虑到,实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码 -is_test = True -mai_version_main = "0.6.4" -mai_version_fix = "snapshot-1" +CONFIG_DIR = "config" +TEMPLATE_DIR = "template" -if mai_version_fix: - if is_test: - mai_version = f"test-{mai_version_main}-{mai_version_fix}" - else: - mai_version = f"{mai_version_main}-{mai_version_fix}" -else: - if is_test: - mai_version = f"test-{mai_version_main}" - else: - mai_version = mai_version_main +# 考虑到,实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码 +# 对该字段的更新,请严格参照语义化版本规范:https://semver.org/lang/zh-CN/ +MMC_VERSION = "0.7.0-snapshot.1" def update_config(): # 获取根目录路径 - root_dir = Path(__file__).parent.parent.parent - template_dir = root_dir / "template" - config_dir = root_dir / "config" - old_config_dir = config_dir / "old" + old_config_dir = f"{CONFIG_DIR}/old" # 定义文件路径 - template_path = template_dir / "bot_config_template.toml" - old_config_path = config_dir / "bot_config.toml" - new_config_path = config_dir / "bot_config.toml" + template_path = f"{TEMPLATE_DIR}/bot_config_template.toml" + old_config_path = f"{CONFIG_DIR}/bot_config.toml" + new_config_path = f"{CONFIG_DIR}/bot_config.toml" # 检查配置文件是否存在 - if not old_config_path.exists(): + if not os.path.exists(old_config_path): logger.info("配置文件不存在,从模板创建新配置") - # 创建文件夹 - old_config_dir.mkdir(parents=True, exist_ok=True) - shutil.copy2(template_path, old_config_path) + os.makedirs(CONFIG_DIR, exist_ok=True) # 创建文件夹 + shutil.copy2(template_path, old_config_path) # 复制模板文件 logger.info(f"已创建新配置文件,请填写后重新运行: {old_config_path}") # 如果是新创建的配置文件,直接返回 - return quit() + quit() # 读取旧配置文件和模板文件 with open(old_config_path, "r", encoding="utf-8") as f: @@ -75,13 +79,15 @@ def update_config(): return else: logger.info(f"检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}") + else: + logger.info("已有配置文件未检测到版本号,可能是旧版本。将进行更新") # 创建old目录(如果不存在) - old_config_dir.mkdir(exist_ok=True) + os.makedirs(old_config_dir, exist_ok=True) # 生成带时间戳的新文件名 timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - old_backup_path = old_config_dir / f"bot_config_{timestamp}.toml" + old_backup_path = f"{old_config_dir}/bot_config_{timestamp}.toml" # 移动旧配置文件到old目录 shutil.move(old_config_path, old_backup_path) @@ -91,24 +97,23 @@ def update_config(): shutil.copy2(template_path, new_config_path) logger.info(f"已创建新配置文件: {new_config_path}") - # 递归更新配置 - def update_dict(target, source): + def update_dict(target: TOMLDocument | dict, source: TOMLDocument | dict): + """ + 将source字典的值更新到target字典中(如果target中存在相同的键) + """ for key, value in source.items(): # 跳过version字段的更新 if key == "version": continue if key in target: - if isinstance(value, dict) and isinstance(target[key], (dict, tomlkit.items.Table)): + if isinstance(value, dict) and isinstance(target[key], (dict, Table)): update_dict(target[key], value) else: try: # 对数组类型进行特殊处理 if isinstance(value, list): # 如果是空数组,确保它保持为空数组 - if not value: - target[key] = tomlkit.array() - else: - target[key] = tomlkit.array(value) + target[key] = tomlkit.array(str(value)) if value else tomlkit.array() else: # 其他类型使用item方法创建新值 target[key] = tomlkit.item(value) @@ -123,619 +128,57 @@ def update_config(): # 保存更新后的配置(保留注释和格式) with open(new_config_path, "w", encoding="utf-8") as f: f.write(tomlkit.dumps(new_config)) - logger.info("配置文件更新完成") + logger.info("配置文件更新完成,建议检查新配置文件中的内容,以免丢失重要信息") + quit() @dataclass -class BotConfig: - """机器人配置类""" - - INNER_VERSION: Version = None - MAI_VERSION: str = mai_version # 硬编码的版本信息 - - # bot - BOT_QQ: Optional[str] = "114514" - BOT_NICKNAME: Optional[str] = None - BOT_ALIAS_NAMES: List[str] = field(default_factory=list) # 别名,可以通过这个叫它 - - # group - talk_allowed_groups = set() - talk_frequency_down_groups = set() - ban_user_id = set() - - # personality - personality_core = "用一句话或几句话描述人格的核心特点" # 建议20字以内,谁再写3000字小作文敲谁脑袋 - personality_sides: List[str] = field( - default_factory=lambda: [ - "用一句话或几句话描述人格的一些侧面", - "用一句话或几句话描述人格的一些侧面", - "用一句话或几句话描述人格的一些侧面", - ] - ) - expression_style = "描述麦麦说话的表达风格,表达习惯" - # identity - identity_detail: List[str] = field( - default_factory=lambda: [ - "身份特点", - "身份特点", - ] - ) - height: int = 170 # 身高 单位厘米 - weight: int = 50 # 体重 单位千克 - age: int = 20 # 年龄 单位岁 - gender: str = "男" # 性别 - appearance: str = "用几句话描述外貌特征" # 外貌特征 - - # chat - allow_focus_mode: bool = True # 是否允许专注聊天状态 - - base_normal_chat_num: int = 3 # 最多允许多少个群进行普通聊天 - base_focused_chat_num: int = 2 # 最多允许多少个群进行专注聊天 - - observation_context_size: int = 12 # 心流观察到的最长上下文大小,超过这个值的上下文会被压缩 - - message_buffer: bool = True # 消息缓冲器 - - ban_words = set() - ban_msgs_regex = set() - - # focus_chat - reply_trigger_threshold: float = 3.0 # 心流聊天触发阈值,越低越容易触发 - default_decay_rate_per_second: float = 0.98 # 默认衰减率,越大衰减越慢 - consecutive_no_reply_threshold = 3 - - compressed_length: int = 5 # 不能大于observation_context_size,心流上下文压缩的最短压缩长度,超过心流观察到的上下文长度,会压缩,最短压缩长度为5 - compress_length_limit: int = 5 # 最多压缩份数,超过该数值的压缩上下文会被删除 - - # normal_chat - model_reasoning_probability: float = 0.7 # 麦麦回答时选择推理模型(主要)模型概率 - model_normal_probability: float = 0.3 # 麦麦回答时选择一般模型(次要)模型概率 - - emoji_chance: float = 0.2 # 发送表情包的基础概率 - thinking_timeout: int = 120 # 思考时间 - - willing_mode: str = "classical" # 意愿模式 - response_willing_amplifier: float = 1.0 # 回复意愿放大系数 - response_interested_rate_amplifier: float = 1.0 # 回复兴趣度放大系数 - down_frequency_rate: float = 3 # 降低回复频率的群组回复意愿降低系数 - emoji_response_penalty: float = 0.0 # 表情包回复惩罚 - mentioned_bot_inevitable_reply: bool = False # 提及 bot 必然回复 - at_bot_inevitable_reply: bool = False # @bot 必然回复 - - # emoji - max_emoji_num: int = 200 # 表情包最大数量 - max_reach_deletion: bool = True # 开启则在达到最大数量时删除表情包,关闭则不会继续收集表情包 - EMOJI_CHECK_INTERVAL: int = 120 # 表情包检查间隔(分钟) - - save_pic: bool = False # 是否保存图片 - save_emoji: bool = False # 是否保存表情包 - steal_emoji: bool = True # 是否偷取表情包,让麦麦可以发送她保存的这些表情包 - - EMOJI_CHECK: bool = False # 是否开启过滤 - EMOJI_CHECK_PROMPT: str = "符合公序良俗" # 表情包过滤要求 - - # memory - build_memory_interval: int = 600 # 记忆构建间隔(秒) - memory_build_distribution: list = field( - default_factory=lambda: [4, 2, 0.6, 24, 8, 0.4] - ) # 记忆构建分布,参数:分布1均值,标准差,权重,分布2均值,标准差,权重 - build_memory_sample_num: int = 10 # 记忆构建采样数量 - build_memory_sample_length: int = 20 # 记忆构建采样长度 - memory_compress_rate: float = 0.1 # 记忆压缩率 - - forget_memory_interval: int = 600 # 记忆遗忘间隔(秒) - memory_forget_time: int = 24 # 记忆遗忘时间(小时) - memory_forget_percentage: float = 0.01 # 记忆遗忘比例 - - consolidate_memory_interval: int = 1000 # 记忆整合间隔(秒) - consolidation_similarity_threshold: float = 0.7 # 相似度阈值 - consolidate_memory_percentage: float = 0.01 # 检查节点比例 - - memory_ban_words: list = field( - default_factory=lambda: ["表情包", "图片", "回复", "聊天记录"] - ) # 添加新的配置项默认值 - - # mood - mood_update_interval: float = 1.0 # 情绪更新间隔 单位秒 - mood_decay_rate: float = 0.95 # 情绪衰减率 - mood_intensity_factor: float = 0.7 # 情绪强度因子 - - # keywords - keywords_reaction_rules = [] # 关键词回复规则 - - # chinese_typo - chinese_typo_enable = True # 是否启用中文错别字生成器 - chinese_typo_error_rate = 0.03 # 单字替换概率 - chinese_typo_min_freq = 7 # 最小字频阈值 - chinese_typo_tone_error_rate = 0.2 # 声调错误概率 - chinese_typo_word_replace_rate = 0.02 # 整词替换概率 - - # response_splitter - enable_kaomoji_protection = False # 是否启用颜文字保护 - enable_response_splitter = True # 是否启用回复分割器 - response_max_length = 100 # 回复允许的最大长度 - response_max_sentence_num = 3 # 回复允许的最大句子数 - - model_max_output_length: int = 800 # 最大回复长度 - - # remote - remote_enable: bool = True # 是否启用远程控制 - - # experimental - enable_friend_chat: bool = False # 是否启用好友聊天 - # enable_think_flow: bool = False # 是否启用思考流程 - talk_allowed_private = set() - enable_pfc_chatting: bool = False # 是否启用PFC聊天 - - # 模型配置 - llm_reasoning: dict[str, str] = field(default_factory=lambda: {}) - # llm_reasoning_minor: dict[str, str] = field(default_factory=lambda: {}) - llm_normal: Dict[str, str] = field(default_factory=lambda: {}) - llm_topic_judge: Dict[str, str] = field(default_factory=lambda: {}) - llm_summary: Dict[str, str] = field(default_factory=lambda: {}) - embedding: Dict[str, str] = field(default_factory=lambda: {}) - vlm: Dict[str, str] = field(default_factory=lambda: {}) - moderation: Dict[str, str] = field(default_factory=lambda: {}) - - llm_observation: Dict[str, str] = field(default_factory=lambda: {}) - llm_sub_heartflow: Dict[str, str] = field(default_factory=lambda: {}) - llm_heartflow: Dict[str, str] = field(default_factory=lambda: {}) - llm_tool_use: Dict[str, str] = field(default_factory=lambda: {}) - llm_plan: Dict[str, str] = field(default_factory=lambda: {}) - - api_urls: Dict[str, str] = field(default_factory=lambda: {}) - - @staticmethod - def get_config_dir() -> str: - """获取配置文件目录""" - current_dir = os.path.dirname(os.path.abspath(__file__)) - root_dir = os.path.abspath(os.path.join(current_dir, "..", "..")) - config_dir = os.path.join(root_dir, "config") - if not os.path.exists(config_dir): - os.makedirs(config_dir) - return config_dir - - @classmethod - def convert_to_specifierset(cls, value: str) -> SpecifierSet: - """将 字符串 版本表达式转换成 SpecifierSet - Args: - value[str]: 版本表达式(字符串) - Returns: - SpecifierSet - """ - - try: - converted = SpecifierSet(value) - except InvalidSpecifier: - logger.error(f"{value} 分类使用了错误的版本约束表达式\n", "请阅读 https://semver.org/lang/zh-CN/ 修改代码") - exit(1) - - return converted - - @classmethod - def get_config_version(cls, toml: dict) -> Version: - """提取配置文件的 SpecifierSet 版本数据 - Args: - toml[dict]: 输入的配置文件字典 - Returns: - Version - """ - - if "inner" in toml: - try: - config_version: str = toml["inner"]["version"] - except KeyError as e: - logger.error("配置文件中 inner 段 不存在, 这是错误的配置文件") - raise KeyError(f"配置文件中 inner 段 不存在 {e}, 这是错误的配置文件") from e - else: - toml["inner"] = {"version": "0.0.0"} - config_version = toml["inner"]["version"] - - try: - ver = version.parse(config_version) - except InvalidVersion as e: - logger.error( - "配置文件中 inner段 的 version 键是错误的版本描述\n" - "请阅读 https://semver.org/lang/zh-CN/ 修改配置,并参考本项目指定的模板进行修改\n" - "本项目在不同的版本下有不同的模板,请注意识别" - ) - raise InvalidVersion("配置文件中 inner段 的 version 键是错误的版本描述\n") from e - - return ver - - @classmethod - def load_config(cls, config_path: str = None) -> "BotConfig": - """从TOML配置文件加载配置""" - config = cls() - - def personality(parent: dict): - personality_config = parent["personality"] - if config.INNER_VERSION in SpecifierSet(">=1.2.4"): - config.personality_core = personality_config.get("personality_core", config.personality_core) - config.personality_sides = personality_config.get("personality_sides", config.personality_sides) - if config.INNER_VERSION in SpecifierSet(">=1.7.0"): - config.expression_style = personality_config.get("expression_style", config.expression_style) - - def identity(parent: dict): - identity_config = parent["identity"] - if config.INNER_VERSION in SpecifierSet(">=1.2.4"): - config.identity_detail = identity_config.get("identity_detail", config.identity_detail) - config.height = identity_config.get("height", config.height) - config.weight = identity_config.get("weight", config.weight) - config.age = identity_config.get("age", config.age) - config.gender = identity_config.get("gender", config.gender) - config.appearance = identity_config.get("appearance", config.appearance) - - def emoji(parent: dict): - emoji_config = parent["emoji"] - config.EMOJI_CHECK_INTERVAL = emoji_config.get("check_interval", config.EMOJI_CHECK_INTERVAL) - config.EMOJI_CHECK_PROMPT = emoji_config.get("check_prompt", config.EMOJI_CHECK_PROMPT) - config.EMOJI_CHECK = emoji_config.get("enable_check", config.EMOJI_CHECK) - if config.INNER_VERSION in SpecifierSet(">=1.1.1"): - config.max_emoji_num = emoji_config.get("max_emoji_num", config.max_emoji_num) - config.max_reach_deletion = emoji_config.get("max_reach_deletion", config.max_reach_deletion) - if config.INNER_VERSION in SpecifierSet(">=1.4.2"): - config.save_pic = emoji_config.get("save_pic", config.save_pic) - config.save_emoji = emoji_config.get("save_emoji", config.save_emoji) - config.steal_emoji = emoji_config.get("steal_emoji", config.steal_emoji) - - def bot(parent: dict): - # 机器人基础配置 - bot_config = parent["bot"] - bot_qq = bot_config.get("qq") - config.BOT_QQ = str(bot_qq) - config.BOT_NICKNAME = bot_config.get("nickname", config.BOT_NICKNAME) - config.BOT_ALIAS_NAMES = bot_config.get("alias_names", config.BOT_ALIAS_NAMES) - - def chat(parent: dict): - chat_config = parent["chat"] - config.allow_focus_mode = chat_config.get("allow_focus_mode", config.allow_focus_mode) - config.base_normal_chat_num = chat_config.get("base_normal_chat_num", config.base_normal_chat_num) - config.base_focused_chat_num = chat_config.get("base_focused_chat_num", config.base_focused_chat_num) - config.observation_context_size = chat_config.get( - "observation_context_size", config.observation_context_size - ) - config.message_buffer = chat_config.get("message_buffer", config.message_buffer) - config.ban_words = chat_config.get("ban_words", config.ban_words) - for r in chat_config.get("ban_msgs_regex", config.ban_msgs_regex): - config.ban_msgs_regex.add(re.compile(r)) - - def normal_chat(parent: dict): - normal_chat_config = parent["normal_chat"] - config.model_reasoning_probability = normal_chat_config.get( - "model_reasoning_probability", config.model_reasoning_probability - ) - config.model_normal_probability = normal_chat_config.get( - "model_normal_probability", config.model_normal_probability - ) - config.emoji_chance = normal_chat_config.get("emoji_chance", config.emoji_chance) - config.thinking_timeout = normal_chat_config.get("thinking_timeout", config.thinking_timeout) - - config.willing_mode = normal_chat_config.get("willing_mode", config.willing_mode) - config.response_willing_amplifier = normal_chat_config.get( - "response_willing_amplifier", config.response_willing_amplifier - ) - config.response_interested_rate_amplifier = normal_chat_config.get( - "response_interested_rate_amplifier", config.response_interested_rate_amplifier - ) - config.down_frequency_rate = normal_chat_config.get("down_frequency_rate", config.down_frequency_rate) - config.emoji_response_penalty = normal_chat_config.get( - "emoji_response_penalty", config.emoji_response_penalty - ) - - config.mentioned_bot_inevitable_reply = normal_chat_config.get( - "mentioned_bot_inevitable_reply", config.mentioned_bot_inevitable_reply - ) - config.at_bot_inevitable_reply = normal_chat_config.get( - "at_bot_inevitable_reply", config.at_bot_inevitable_reply - ) - - def focus_chat(parent: dict): - focus_chat_config = parent["focus_chat"] - config.compressed_length = focus_chat_config.get("compressed_length", config.compressed_length) - config.compress_length_limit = focus_chat_config.get("compress_length_limit", config.compress_length_limit) - config.reply_trigger_threshold = focus_chat_config.get( - "reply_trigger_threshold", config.reply_trigger_threshold - ) - config.default_decay_rate_per_second = focus_chat_config.get( - "default_decay_rate_per_second", config.default_decay_rate_per_second - ) - config.consecutive_no_reply_threshold = focus_chat_config.get( - "consecutive_no_reply_threshold", config.consecutive_no_reply_threshold - ) - - def model(parent: dict): - # 加载模型配置 - model_config: dict = parent["model"] - - config_list = [ - "llm_reasoning", - # "llm_reasoning_minor", - "llm_normal", - "llm_topic_judge", - "llm_summary", - "vlm", - "embedding", - "llm_tool_use", - "llm_observation", - "llm_sub_heartflow", - "llm_plan", - "llm_heartflow", - "llm_PFC_action_planner", - "llm_PFC_chat", - "llm_PFC_reply_checker", - ] - - for item in config_list: - if item in model_config: - cfg_item: dict = model_config[item] - - # base_url 的例子: SILICONFLOW_BASE_URL - # key 的例子: SILICONFLOW_KEY - cfg_target = { - "name": "", - "base_url": "", - "key": "", - "stream": False, - "pri_in": 0, - "pri_out": 0, - "temp": 0.7, - } - - if config.INNER_VERSION in SpecifierSet("<=0.0.0"): - cfg_target = cfg_item - - elif config.INNER_VERSION in SpecifierSet(">=0.0.1"): - stable_item = ["name", "pri_in", "pri_out"] - - stream_item = ["stream"] - if config.INNER_VERSION in SpecifierSet(">=1.0.1"): - stable_item.append("stream") - - pricing_item = ["pri_in", "pri_out"] - - # 从配置中原始拷贝稳定字段 - for i in stable_item: - # 如果 字段 属于计费项 且获取不到,那默认值是 0 - if i in pricing_item and i not in cfg_item: - cfg_target[i] = 0 - - if i in stream_item and i not in cfg_item: - cfg_target[i] = False - - else: - # 没有特殊情况则原样复制 - try: - cfg_target[i] = cfg_item[i] - except KeyError as e: - logger.error(f"{item} 中的必要字段不存在,请检查") - raise KeyError(f"{item} 中的必要字段 {e} 不存在,请检查") from e - - # 如果配置中有temp参数,就使用配置中的值 - if "temp" in cfg_item: - cfg_target["temp"] = cfg_item["temp"] - else: - # 如果没有temp参数,就删除默认值 - cfg_target.pop("temp", None) - - provider = cfg_item.get("provider") - if provider is None: - logger.error(f"provider 字段在模型配置 {item} 中不存在,请检查") - raise KeyError(f"provider 字段在模型配置 {item} 中不存在,请检查") - - cfg_target["base_url"] = f"{provider}_BASE_URL" - cfg_target["key"] = f"{provider}_KEY" - - # 如果 列表中的项目在 model_config 中,利用反射来设置对应项目 - setattr(config, item, cfg_target) - else: - logger.error(f"模型 {item} 在config中不存在,请检查,或尝试更新配置文件") - raise KeyError(f"模型 {item} 在config中不存在,请检查,或尝试更新配置文件") - - def memory(parent: dict): - memory_config = parent["memory"] - config.build_memory_interval = memory_config.get("build_memory_interval", config.build_memory_interval) - config.forget_memory_interval = memory_config.get("forget_memory_interval", config.forget_memory_interval) - config.memory_ban_words = set(memory_config.get("memory_ban_words", [])) - config.memory_forget_time = memory_config.get("memory_forget_time", config.memory_forget_time) - config.memory_forget_percentage = memory_config.get( - "memory_forget_percentage", config.memory_forget_percentage - ) - config.memory_compress_rate = memory_config.get("memory_compress_rate", config.memory_compress_rate) - if config.INNER_VERSION in SpecifierSet(">=0.0.11"): - config.memory_build_distribution = memory_config.get( - "memory_build_distribution", config.memory_build_distribution - ) - config.build_memory_sample_num = memory_config.get( - "build_memory_sample_num", config.build_memory_sample_num - ) - config.build_memory_sample_length = memory_config.get( - "build_memory_sample_length", config.build_memory_sample_length - ) - if config.INNER_VERSION in SpecifierSet(">=1.5.1"): - config.consolidate_memory_interval = memory_config.get( - "consolidate_memory_interval", config.consolidate_memory_interval - ) - config.consolidation_similarity_threshold = memory_config.get( - "consolidation_similarity_threshold", config.consolidation_similarity_threshold - ) - config.consolidate_memory_percentage = memory_config.get( - "consolidate_memory_percentage", config.consolidate_memory_percentage - ) - - def remote(parent: dict): - remote_config = parent["remote"] - config.remote_enable = remote_config.get("enable", config.remote_enable) - - def mood(parent: dict): - mood_config = parent["mood"] - config.mood_update_interval = mood_config.get("mood_update_interval", config.mood_update_interval) - config.mood_decay_rate = mood_config.get("mood_decay_rate", config.mood_decay_rate) - config.mood_intensity_factor = mood_config.get("mood_intensity_factor", config.mood_intensity_factor) - - def keywords_reaction(parent: dict): - keywords_reaction_config = parent["keywords_reaction"] - if keywords_reaction_config.get("enable", False): - config.keywords_reaction_rules = keywords_reaction_config.get("rules", config.keywords_reaction_rules) - for rule in config.keywords_reaction_rules: - if rule.get("enable", False) and "regex" in rule: - rule["regex"] = [re.compile(r) for r in rule.get("regex", [])] - - def chinese_typo(parent: dict): - chinese_typo_config = parent["chinese_typo"] - config.chinese_typo_enable = chinese_typo_config.get("enable", config.chinese_typo_enable) - config.chinese_typo_error_rate = chinese_typo_config.get("error_rate", config.chinese_typo_error_rate) - config.chinese_typo_min_freq = chinese_typo_config.get("min_freq", config.chinese_typo_min_freq) - config.chinese_typo_tone_error_rate = chinese_typo_config.get( - "tone_error_rate", config.chinese_typo_tone_error_rate - ) - config.chinese_typo_word_replace_rate = chinese_typo_config.get( - "word_replace_rate", config.chinese_typo_word_replace_rate - ) - - def response_splitter(parent: dict): - response_splitter_config = parent["response_splitter"] - config.enable_response_splitter = response_splitter_config.get( - "enable_response_splitter", config.enable_response_splitter - ) - config.response_max_length = response_splitter_config.get("response_max_length", config.response_max_length) - config.response_max_sentence_num = response_splitter_config.get( - "response_max_sentence_num", config.response_max_sentence_num - ) - if config.INNER_VERSION in SpecifierSet(">=1.4.2"): - config.enable_kaomoji_protection = response_splitter_config.get( - "enable_kaomoji_protection", config.enable_kaomoji_protection - ) - if config.INNER_VERSION in SpecifierSet(">=1.6.0"): - config.model_max_output_length = response_splitter_config.get( - "model_max_output_length", config.model_max_output_length - ) - - def groups(parent: dict): - groups_config = parent["groups"] - # config.talk_allowed_groups = set(groups_config.get("talk_allowed", [])) - config.talk_allowed_groups = set(str(group) for group in groups_config.get("talk_allowed", [])) - # config.talk_frequency_down_groups = set(groups_config.get("talk_frequency_down", [])) - config.talk_frequency_down_groups = set( - str(group) for group in groups_config.get("talk_frequency_down", []) - ) - # config.ban_user_id = set(groups_config.get("ban_user_id", [])) - config.ban_user_id = set(str(user) for user in groups_config.get("ban_user_id", [])) - - def experimental(parent: dict): - experimental_config = parent["experimental"] - config.enable_friend_chat = experimental_config.get("enable_friend_chat", config.enable_friend_chat) - # config.enable_think_flow = experimental_config.get("enable_think_flow", config.enable_think_flow) - config.talk_allowed_private = set(str(user) for user in experimental_config.get("talk_allowed_private", [])) - if config.INNER_VERSION in SpecifierSet(">=1.1.0"): - config.enable_pfc_chatting = experimental_config.get("pfc_chatting", config.enable_pfc_chatting) - - # 版本表达式:>=1.0.0,<2.0.0 - # 允许字段:func: method, support: str, notice: str, necessary: bool - # 如果使用 notice 字段,在该组配置加载时,会展示该字段对用户的警示 - # 例如:"notice": "personality 将在 1.3.2 后被移除",那么在有效版本中的用户就会虽然可以 - # 正常执行程序,但是会看到这条自定义提示 - - # 版本格式:主版本号.次版本号.修订号,版本号递增规则如下: - # 主版本号:当你做了不兼容的 API 修改, - # 次版本号:当你做了向下兼容的功能性新增, - # 修订号:当你做了向下兼容的问题修正。 - # 先行版本号及版本编译信息可以加到"主版本号.次版本号.修订号"的后面,作为延伸。 - - # 如果你做了break的修改,就应该改动主版本号 - # 如果做了一个兼容修改,就不应该要求这个选项是必须的! - include_configs = { - "bot": {"func": bot, "support": ">=0.0.0"}, - "groups": {"func": groups, "support": ">=0.0.0"}, - "personality": {"func": personality, "support": ">=0.0.0"}, - "identity": {"func": identity, "support": ">=1.2.4"}, - "emoji": {"func": emoji, "support": ">=0.0.0"}, - "model": {"func": model, "support": ">=0.0.0"}, - "memory": {"func": memory, "support": ">=0.0.0", "necessary": False}, - "mood": {"func": mood, "support": ">=0.0.0"}, - "remote": {"func": remote, "support": ">=0.0.10", "necessary": False}, - "keywords_reaction": {"func": keywords_reaction, "support": ">=0.0.2", "necessary": False}, - "chinese_typo": {"func": chinese_typo, "support": ">=0.0.3", "necessary": False}, - "response_splitter": {"func": response_splitter, "support": ">=0.0.11", "necessary": False}, - "experimental": {"func": experimental, "support": ">=0.0.11", "necessary": False}, - "chat": {"func": chat, "support": ">=1.6.0", "necessary": False}, - "normal_chat": {"func": normal_chat, "support": ">=1.6.0", "necessary": False}, - "focus_chat": {"func": focus_chat, "support": ">=1.6.0", "necessary": False}, - } - - # 原地修改,将 字符串版本表达式 转换成 版本对象 - for key in include_configs: - item_support = include_configs[key]["support"] - include_configs[key]["support"] = cls.convert_to_specifierset(item_support) - - if os.path.exists(config_path): - with open(config_path, "rb") as f: - try: - toml_dict = tomli.load(f) - except tomli.TOMLDecodeError as e: - logger.critical(f"配置文件bot_config.toml填写有误,请检查第{e.lineno}行第{e.colno}处:{e.msg}") - exit(1) - - # 获取配置文件版本 - config.INNER_VERSION = cls.get_config_version(toml_dict) - - # 如果在配置中找到了需要的项,调用对应项的闭包函数处理 - for key in include_configs: - if key in toml_dict: - group_specifierset: SpecifierSet = include_configs[key]["support"] - - # 检查配置文件版本是否在支持范围内 - if config.INNER_VERSION in group_specifierset: - # 如果版本在支持范围内,检查是否存在通知 - if "notice" in include_configs[key]: - logger.warning(include_configs[key]["notice"]) - - include_configs[key]["func"](toml_dict) - - else: - # 如果版本不在支持范围内,崩溃并提示用户 - logger.error( - f"配置文件中的 '{key}' 字段的版本 ({config.INNER_VERSION}) 不在支持范围内。\n" - f"当前程序仅支持以下版本范围: {group_specifierset}" - ) - raise InvalidVersion(f"当前程序仅支持以下版本范围: {group_specifierset}") - - # 如果 necessary 项目存在,而且显式声明是 False,进入特殊处理 - elif "necessary" in include_configs[key] and include_configs[key].get("necessary") is False: - # 通过 pass 处理的项虽然直接忽略也是可以的,但是为了不增加理解困难,依然需要在这里显式处理 - if key == "keywords_reaction": - pass - - else: - # 如果用户根本没有需要的配置项,提示缺少配置 - logger.error(f"配置文件中缺少必需的字段: '{key}'") - raise KeyError(f"配置文件中缺少必需的字段: '{key}'") - - # identity_detail字段非空检查 - if not config.identity_detail: - logger.error("配置文件错误:[identity] 部分的 identity_detail 不能为空字符串") - raise ValueError("配置文件错误:[identity] 部分的 identity_detail 不能为空字符串") - - logger.success(f"成功加载配置文件: {config_path}") - - return config +class Config(ConfigBase): + """总配置类""" + + MMC_VERSION: str = field(default=MMC_VERSION, repr=False, init=False) # 硬编码的版本信息 + + bot: BotConfig + chat_target: ChatTargetConfig + personality: PersonalityConfig + identity: IdentityConfig + platforms: PlatformsConfig + chat: ChatConfig + normal_chat: NormalChatConfig + focus_chat: FocusChatConfig + emoji: EmojiConfig + memory: MemoryConfig + mood: MoodConfig + keyword_reaction: KeywordReactionConfig + chinese_typo: ChineseTypoConfig + response_splitter: ResponseSplitterConfig + telemetry: TelemetryConfig + experimental: ExperimentalConfig + model: ModelConfig + + +def load_config(config_path: str) -> Config: + """ + 加载配置文件 + :param config_path: 配置文件路径 + :return: Config对象 + """ + # 读取配置文件 + with open(config_path, "r", encoding="utf-8") as f: + config_data = tomlkit.load(f) + + # 创建Config对象 + try: + return Config.from_dict(config_data) + except Exception as e: + logger.critical("配置文件解析失败") + raise e # 获取配置文件路径 -logger.info(f"MaiCore当前版本: {mai_version}") +logger.info(f"MaiCore当前版本: {MMC_VERSION}") update_config() -bot_config_floder_path = BotConfig.get_config_dir() -logger.info(f"正在品鉴配置文件目录: {bot_config_floder_path}") - -bot_config_path = os.path.join(bot_config_floder_path, "bot_config.toml") - -if os.path.exists(bot_config_path): - # 如果开发环境配置文件不存在,则使用默认配置文件 - logger.info(f"异常的新鲜,异常的美味: {bot_config_path}") -else: - # 配置文件不存在 - logger.error("配置文件不存在,请检查路径: {bot_config_path}") - raise FileNotFoundError(f"配置文件不存在: {bot_config_path}") - -global_config = BotConfig.load_config(config_path=bot_config_path) +logger.info("正在品鉴配置文件...") +global_config = load_config(config_path=f"{CONFIG_DIR}/bot_config.toml") +logger.info("非常的新鲜,非常的美味!") diff --git a/src/config/config_base.py b/src/config/config_base.py new file mode 100644 index 000000000..92f6cf9d4 --- /dev/null +++ b/src/config/config_base.py @@ -0,0 +1,116 @@ +from dataclasses import dataclass, fields, MISSING +from typing import TypeVar, Type, Any, get_origin, get_args + +T = TypeVar("T", bound="ConfigBase") + +TOML_DICT_TYPE = { + int, + float, + str, + bool, + list, + dict, +} + + +@dataclass +class ConfigBase: + """配置类的基类""" + + @classmethod + def from_dict(cls: Type[T], data: dict[str, Any]) -> T: + """从字典加载配置字段""" + if not isinstance(data, dict): + raise TypeError(f"Expected a dictionary, got {type(data).__name__}") + + init_args: dict[str, Any] = {} + + for f in fields(cls): + field_name = f.name + + if field_name.startswith("_"): + # 跳过以 _ 开头的字段 + continue + + if field_name not in data: + if f.default is not MISSING or f.default_factory is not MISSING: + # 跳过未提供且有默认值/默认构造方法的字段 + continue + else: + raise ValueError(f"Missing required field: '{field_name}'") + + value = data[field_name] + field_type = f.type + + try: + init_args[field_name] = cls._convert_field(value, field_type) + except TypeError as e: + raise TypeError(f"Field '{field_name}' has a type error: {e}") from e + except Exception as e: + raise RuntimeError(f"Failed to convert field '{field_name}' to target type: {e}") from e + + return cls(**init_args) + + @classmethod + def _convert_field(cls, value: Any, field_type: Type[Any]) -> Any: + """ + 转换字段值为指定类型 + + 1. 对于嵌套的 dataclass,递归调用相应的 from_dict 方法 + 2. 对于泛型集合类型(list, set, tuple),递归转换每个元素 + 3. 对于基础类型(int, str, float, bool),直接转换 + 4. 对于其他类型,尝试直接转换,如果失败则抛出异常 + """ + + # 如果是嵌套的 dataclass,递归调用 from_dict 方法 + if isinstance(field_type, type) and issubclass(field_type, ConfigBase): + if not isinstance(value, dict): + raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}") + return field_type.from_dict(value) + + # 处理泛型集合类型(list, set, tuple) + field_origin_type = get_origin(field_type) + field_type_args = get_args(field_type) + + if field_origin_type in {list, set, tuple}: + # 检查提供的value是否为list + if not isinstance(value, list): + raise TypeError(f"Expected an list for {field_type.__name__}, got {type(value).__name__}") + + if field_origin_type is list: + return [cls._convert_field(item, field_type_args[0]) for item in value] + elif field_origin_type is set: + return {cls._convert_field(item, field_type_args[0]) for item in value} + elif field_origin_type is tuple: + # 检查提供的value长度是否与类型参数一致 + if len(value) != len(field_type_args): + raise TypeError( + f"Expected {len(field_type_args)} items for {field_type.__name__}, got {len(value)}" + ) + return tuple(cls._convert_field(item, arg) for item, arg in zip(value, field_type_args)) + + if field_origin_type is dict: + # 检查提供的value是否为dict + if not isinstance(value, dict): + raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}") + + # 检查字典的键值类型 + if len(field_type_args) != 2: + raise TypeError(f"Expected a dictionary with two type arguments for {field_type.__name__}") + key_type, value_type = field_type_args + + return {cls._convert_field(k, key_type): cls._convert_field(v, value_type) for k, v in value.items()} + + # 处理基础类型,例如 int, str 等 + if field_type is Any or isinstance(value, field_type): + return value + + # 其他类型,尝试直接转换 + try: + return field_type(value) + except (ValueError, TypeError) as e: + raise TypeError(f"Cannot convert {type(value).__name__} to {field_type.__name__}") from e + + def __str__(self): + """返回配置类的字符串表示""" + return f"{self.__class__.__name__}({', '.join(f'{f.name}={getattr(self, f.name)}' for f in fields(self))})" diff --git a/src/config/official_configs.py b/src/config/official_configs.py new file mode 100644 index 000000000..d92d925d6 --- /dev/null +++ b/src/config/official_configs.py @@ -0,0 +1,399 @@ +from dataclasses import dataclass, field +from typing import Any + +from src.config.config_base import ConfigBase + +""" +须知: +1. 本文件中记录了所有的配置项 +2. 所有新增的class都需要继承自ConfigBase +3. 所有新增的class都应在config.py中的Config类中添加字段 +4. 对于新增的字段,若为可选项,则应在其后添加field()并设置default_factory或default +""" + + +@dataclass +class BotConfig(ConfigBase): + """QQ机器人配置类""" + + qq_account: str + """QQ账号""" + + nickname: str + """昵称""" + + alias_names: list[str] = field(default_factory=lambda: []) + """别名列表""" + + +@dataclass +class ChatTargetConfig(ConfigBase): + """ + 聊天目标配置类 + 此类中有聊天的群组和用户配置 + """ + + talk_allowed_groups: set[str] = field(default_factory=lambda: set()) + """允许聊天的群组列表""" + + talk_frequency_down_groups: set[str] = field(default_factory=lambda: set()) + """降低聊天频率的群组列表""" + + ban_user_id: set[str] = field(default_factory=lambda: set()) + """禁止聊天的用户列表""" + + +@dataclass +class PersonalityConfig(ConfigBase): + """人格配置类""" + + personality_core: str + """核心人格""" + + expression_style: str + """表达风格""" + + personality_sides: list[str] = field(default_factory=lambda: []) + """人格侧写""" + + +@dataclass +class IdentityConfig(ConfigBase): + """个体特征配置类""" + + height: int = 170 + """身高(单位:厘米)""" + + weight: float = 50 + """体重(单位:千克)""" + + age: int = 18 + """年龄(单位:岁)""" + + gender: str = "女" + """性别(男/女)""" + + appearance: str = "可爱" + """外貌描述""" + + identity_detail: list[str] = field(default_factory=lambda: []) + """身份特征""" + + +@dataclass +class PlatformsConfig(ConfigBase): + """平台配置类""" + + qq: str + """QQ适配器连接URL配置""" + + +@dataclass +class ChatConfig(ConfigBase): + """聊天配置类""" + + allow_focus_mode: bool = True + """是否允许专注聊天状态""" + + base_normal_chat_num: int = 3 + """最多允许多少个群进行普通聊天""" + + base_focused_chat_num: int = 2 + """最多允许多少个群进行专注聊天""" + + observation_context_size: int = 12 + """可观察到的最长上下文大小,超过这个值的上下文会被压缩""" + + message_buffer: bool = True + """消息缓冲器""" + + ban_words: set[str] = field(default_factory=lambda: set()) + """过滤词列表""" + + ban_msgs_regex: set[str] = field(default_factory=lambda: set()) + """过滤正则表达式列表""" + + +@dataclass +class NormalChatConfig(ConfigBase): + """普通聊天配置类""" + + reasoning_model_probability: float = 0.3 + """ + 发言时选择推理模型的概率(0-1之间) + 选择普通模型的概率为 1 - reasoning_normal_model_probability + """ + + emoji_chance: float = 0.2 + """发送表情包的基础概率""" + + thinking_timeout: int = 120 + """最长思考时间""" + + willing_mode: str = "classical" + """意愿模式""" + + response_willing_amplifier: float = 1.0 + """回复意愿放大系数""" + + response_interested_rate_amplifier: float = 1.0 + """回复兴趣度放大系数""" + + down_frequency_rate: float = 3.0 + """降低回复频率的群组回复意愿降低系数""" + + emoji_response_penalty: float = 0.0 + """表情包回复惩罚系数""" + + mentioned_bot_inevitable_reply: bool = False + """提及 bot 必然回复""" + + at_bot_inevitable_reply: bool = False + """@bot 必然回复""" + + +@dataclass +class FocusChatConfig(ConfigBase): + """专注聊天配置类""" + + reply_trigger_threshold: float = 3.0 + """心流聊天触发阈值,越低越容易触发""" + + default_decay_rate_per_second: float = 0.98 + """默认衰减率,越大衰减越快""" + + consecutive_no_reply_threshold: int = 3 + """连续不回复的次数阈值""" + + compressed_length: int = 5 + """心流上下文压缩的最短压缩长度,超过心流观察到的上下文长度,会压缩,最短压缩长度为5""" + + compress_length_limit: int = 5 + """最多压缩份数,超过该数值的压缩上下文会被删除""" + + +@dataclass +class EmojiConfig(ConfigBase): + """表情包配置类""" + + max_reg_num: int = 200 + """表情包最大注册数量""" + + do_replace: bool = True + """达到最大注册数量时替换旧表情包""" + + check_interval: int = 120 + """表情包检查间隔(分钟)""" + + save_pic: bool = False + """是否保存图片""" + + cache_emoji: bool = True + """是否缓存表情包""" + + steal_emoji: bool = True + """是否偷取表情包,让麦麦可以发送她保存的这些表情包""" + + content_filtration: bool = False + """是否开启表情包过滤""" + + filtration_prompt: str = "符合公序良俗" + """表情包过滤要求""" + + +@dataclass +class MemoryConfig(ConfigBase): + """记忆配置类""" + + memory_build_interval: int = 600 + """记忆构建间隔(秒)""" + + memory_build_distribution: tuple[ + float, + float, + float, + float, + float, + float, + ] = field(default_factory=lambda: (6.0, 3.0, 0.6, 32.0, 12.0, 0.4)) + """记忆构建分布,参数:分布1均值,标准差,权重,分布2均值,标准差,权重""" + + memory_build_sample_num: int = 8 + """记忆构建采样数量""" + + memory_build_sample_length: int = 40 + """记忆构建采样长度""" + + memory_compress_rate: float = 0.1 + """记忆压缩率""" + + forget_memory_interval: int = 1000 + """记忆遗忘间隔(秒)""" + + memory_forget_time: int = 24 + """记忆遗忘时间(小时)""" + + memory_forget_percentage: float = 0.01 + """记忆遗忘比例""" + + consolidate_memory_interval: int = 1000 + """记忆整合间隔(秒)""" + + consolidation_similarity_threshold: float = 0.7 + """整合相似度阈值""" + + consolidate_memory_percentage: float = 0.01 + """整合检查节点比例""" + + memory_ban_words: list[str] = field(default_factory=lambda: ["表情包", "图片", "回复", "聊天记录"]) + """不允许记忆的词列表""" + + +@dataclass +class MoodConfig(ConfigBase): + """情绪配置类""" + + mood_update_interval: int = 1 + """情绪更新间隔(秒)""" + + mood_decay_rate: float = 0.95 + """情绪衰减率""" + + mood_intensity_factor: float = 0.7 + """情绪强度因子""" + + +@dataclass +class KeywordRuleConfig(ConfigBase): + """关键词规则配置类""" + + enable: bool = True + """是否启用关键词规则""" + + keywords: list[str] = field(default_factory=lambda: []) + """关键词列表""" + + regex: list[str] = field(default_factory=lambda: []) + """正则表达式列表""" + + reaction: str = "" + """关键词触发的反应""" + + +@dataclass +class KeywordReactionConfig(ConfigBase): + """关键词配置类""" + + enable: bool = True + """是否启用关键词反应""" + + rules: list[KeywordRuleConfig] = field(default_factory=lambda: []) + """关键词反应规则列表""" + + +@dataclass +class ChineseTypoConfig(ConfigBase): + """中文错别字配置类""" + + enable: bool = True + """是否启用中文错别字生成器""" + + error_rate: float = 0.01 + """单字替换概率""" + + min_freq: int = 9 + """最小字频阈值""" + + tone_error_rate: float = 0.1 + """声调错误概率""" + + word_replace_rate: float = 0.006 + """整词替换概率""" + + +@dataclass +class ResponseSplitterConfig(ConfigBase): + """回复分割器配置类""" + + enable: bool = True + """是否启用回复分割器""" + + max_length: int = 256 + """回复允许的最大长度""" + + max_sentence_num: int = 3 + """回复允许的最大句子数""" + + enable_kaomoji_protection: bool = False + """是否启用颜文字保护""" + + +@dataclass +class TelemetryConfig(ConfigBase): + """遥测配置类""" + + enable: bool = True + """是否启用遥测""" + + +@dataclass +class ExperimentalConfig(ConfigBase): + """实验功能配置类""" + + enable_friend_chat: bool = False + """是否启用好友聊天""" + + talk_allowed_private: set[str] = field(default_factory=lambda: set()) + """允许聊天的私聊列表""" + + pfc_chatting: bool = False + """是否启用PFC""" + + +@dataclass +class ModelConfig(ConfigBase): + """模型配置类""" + + model_max_output_length: int = 800 # 最大回复长度 + + reasoning: dict[str, Any] = field(default_factory=lambda: {}) + """推理模型配置""" + + normal: dict[str, Any] = field(default_factory=lambda: {}) + """普通模型配置""" + + topic_judge: dict[str, Any] = field(default_factory=lambda: {}) + """主题判断模型配置""" + + summary: dict[str, Any] = field(default_factory=lambda: {}) + """摘要模型配置""" + + vlm: dict[str, Any] = field(default_factory=lambda: {}) + """视觉语言模型配置""" + + heartflow: dict[str, Any] = field(default_factory=lambda: {}) + """心流模型配置""" + + observation: dict[str, Any] = field(default_factory=lambda: {}) + """观察模型配置""" + + sub_heartflow: dict[str, Any] = field(default_factory=lambda: {}) + """子心流模型配置""" + + plan: dict[str, Any] = field(default_factory=lambda: {}) + """计划模型配置""" + + embedding: dict[str, Any] = field(default_factory=lambda: {}) + """嵌入模型配置""" + + pfc_action_planner: dict[str, Any] = field(default_factory=lambda: {}) + """PFC动作规划模型配置""" + + pfc_chat: dict[str, Any] = field(default_factory=lambda: {}) + """PFC聊天模型配置""" + + pfc_reply_checker: dict[str, Any] = field(default_factory=lambda: {}) + """PFC回复检查模型配置""" + + tool_use: dict[str, Any] = field(default_factory=lambda: {}) + """工具使用模型配置""" diff --git a/src/experimental/PFC/action_planner.py b/src/experimental/PFC/action_planner.py index b4182c9aa..c0bff5887 100644 --- a/src/experimental/PFC/action_planner.py +++ b/src/experimental/PFC/action_planner.py @@ -114,7 +114,7 @@ class ActionPlanner: request_type="action_planning", ) self.personality_info = Individuality.get_instance().get_prompt(x_person=2, level=3) - self.name = global_config.BOT_NICKNAME + self.name = global_config.bot.nickname self.private_name = private_name self.chat_observer = ChatObserver.get_instance(stream_id, private_name) # self.action_planner_info = ActionPlannerInfo() # 移除未使用的变量 @@ -140,7 +140,7 @@ class ActionPlanner: # (这部分逻辑不变) time_since_last_bot_message_info = "" try: - bot_id = str(global_config.BOT_QQ) + bot_id = str(global_config.bot.qq_account) if hasattr(observation_info, "chat_history") and observation_info.chat_history: for i in range(len(observation_info.chat_history) - 1, -1, -1): msg = observation_info.chat_history[i] diff --git a/src/experimental/PFC/chat_observer.py b/src/experimental/PFC/chat_observer.py index 704eeb330..6135bd0f7 100644 --- a/src/experimental/PFC/chat_observer.py +++ b/src/experimental/PFC/chat_observer.py @@ -323,7 +323,7 @@ class ChatObserver: for msg in messages: try: user_info = UserInfo.from_dict(msg.get("user_info", {})) - if user_info.user_id == global_config.BOT_QQ: + if user_info.user_id == global_config.bot.qq_account: self.update_bot_speak_time(msg["time"]) else: self.update_user_speak_time(msg["time"]) diff --git a/src/experimental/PFC/message_sender.py b/src/experimental/PFC/message_sender.py index 181bf171b..4b193a41d 100644 --- a/src/experimental/PFC/message_sender.py +++ b/src/experimental/PFC/message_sender.py @@ -42,8 +42,8 @@ class DirectMessageSender: # 获取麦麦的信息 bot_user_info = UserInfo( - user_id=global_config.BOT_QQ, - user_nickname=global_config.BOT_NICKNAME, + user_id=global_config.bot.qq_account, + user_nickname=global_config.bot.nickname, platform=chat_stream.platform, ) diff --git a/src/experimental/PFC/pfc.py b/src/experimental/PFC/pfc.py index 84fb9f8dc..686d4af49 100644 --- a/src/experimental/PFC/pfc.py +++ b/src/experimental/PFC/pfc.py @@ -42,13 +42,14 @@ class GoalAnalyzer: """对话目标分析器""" def __init__(self, stream_id: str, private_name: str): + # TODO: API-Adapter修改标记 self.llm = LLMRequest( - model=global_config.llm_normal, temperature=0.7, max_tokens=1000, request_type="conversation_goal" + model=global_config.model.normal, temperature=0.7, max_tokens=1000, request_type="conversation_goal" ) self.personality_info = Individuality.get_instance().get_prompt(x_person=2, level=3) - self.name = global_config.BOT_NICKNAME - self.nick_name = global_config.BOT_ALIAS_NAMES + self.name = global_config.bot.nickname + self.nick_name = global_config.bot.alias_names self.private_name = private_name self.chat_observer = ChatObserver.get_instance(stream_id, private_name) diff --git a/src/experimental/PFC/pfc_KnowledgeFetcher.py b/src/experimental/PFC/pfc_KnowledgeFetcher.py index 8ebc307e2..4c1d8c759 100644 --- a/src/experimental/PFC/pfc_KnowledgeFetcher.py +++ b/src/experimental/PFC/pfc_KnowledgeFetcher.py @@ -14,9 +14,10 @@ class KnowledgeFetcher: """知识调取器""" def __init__(self, private_name: str): + # TODO: API-Adapter修改标记 self.llm = LLMRequest( - model=global_config.llm_normal, - temperature=global_config.llm_normal["temp"], + model=global_config.model.normal, + temperature=global_config.model.normal["temp"], max_tokens=1000, request_type="knowledge_fetch", ) diff --git a/src/experimental/PFC/reply_checker.py b/src/experimental/PFC/reply_checker.py index a76e8a0da..5bca9d601 100644 --- a/src/experimental/PFC/reply_checker.py +++ b/src/experimental/PFC/reply_checker.py @@ -16,7 +16,7 @@ class ReplyChecker: self.llm = LLMRequest( model=global_config.llm_PFC_reply_checker, temperature=0.50, max_tokens=1000, request_type="reply_check" ) - self.name = global_config.BOT_NICKNAME + self.name = global_config.bot.nickname self.private_name = private_name self.chat_observer = ChatObserver.get_instance(stream_id, private_name) self.max_retries = 3 # 最大重试次数 @@ -43,7 +43,7 @@ class ReplyChecker: bot_messages = [] for msg in reversed(chat_history): user_info = UserInfo.from_dict(msg.get("user_info", {})) - if str(user_info.user_id) == str(global_config.BOT_QQ): # 确保比较的是字符串 + if str(user_info.user_id) == str(global_config.bot.qq_account): # 确保比较的是字符串 bot_messages.append(msg.get("processed_plain_text", "")) if len(bot_messages) >= 2: # 只和最近的两条比较 break diff --git a/src/experimental/PFC/reply_generator.py b/src/experimental/PFC/reply_generator.py index 6dcda69af..bac8a769f 100644 --- a/src/experimental/PFC/reply_generator.py +++ b/src/experimental/PFC/reply_generator.py @@ -93,7 +93,7 @@ class ReplyGenerator: request_type="reply_generation", ) self.personality_info = Individuality.get_instance().get_prompt(x_person=2, level=3) - self.name = global_config.BOT_NICKNAME + self.name = global_config.bot.nickname self.private_name = private_name self.chat_observer = ChatObserver.get_instance(stream_id, private_name) self.reply_checker = ReplyChecker(stream_id, private_name) diff --git a/src/experimental/PFC/waiter.py b/src/experimental/PFC/waiter.py index af5cf7ad0..452446589 100644 --- a/src/experimental/PFC/waiter.py +++ b/src/experimental/PFC/waiter.py @@ -19,7 +19,7 @@ class Waiter: def __init__(self, stream_id: str, private_name: str): self.chat_observer = ChatObserver.get_instance(stream_id, private_name) - self.name = global_config.BOT_NICKNAME + self.name = global_config.bot.nickname self.private_name = private_name # self.wait_accumulated_time = 0 # 不再需要累加计时 diff --git a/src/experimental/only_message_process.py b/src/experimental/only_message_process.py index 3d1432703..62f73c700 100644 --- a/src/experimental/only_message_process.py +++ b/src/experimental/only_message_process.py @@ -16,7 +16,7 @@ class MessageProcessor: @staticmethod def _check_ban_words(text: str, chat, userinfo) -> bool: """检查消息中是否包含过滤词""" - for word in global_config.ban_words: + for word in global_config.chat.ban_words: if word in text: logger.info( f"[{chat.group_info.group_name if chat.group_info else '私聊'}]{userinfo.user_nickname}:{text}" @@ -28,7 +28,7 @@ class MessageProcessor: @staticmethod def _check_ban_regex(text: str, chat, userinfo) -> bool: """检查消息是否匹配过滤正则表达式""" - for pattern in global_config.ban_msgs_regex: + for pattern in global_config.chat.ban_msgs_regex: if pattern.search(text): logger.info( f"[{chat.group_info.group_name if chat.group_info else '私聊'}]{userinfo.user_nickname}:{text}" diff --git a/src/main.py b/src/main.py index 34b7eda3d..4f8af28ef 100644 --- a/src/main.py +++ b/src/main.py @@ -40,7 +40,7 @@ class MainSystem: async def initialize(self): """初始化系统组件""" - logger.debug(f"正在唤醒{global_config.BOT_NICKNAME}......") + logger.debug(f"正在唤醒{global_config.bot.nickname}......") # 其他初始化任务 await asyncio.gather(self._init_components()) @@ -84,7 +84,7 @@ class MainSystem: asyncio.create_task(chat_manager._auto_save_task()) # 使用HippocampusManager初始化海马体 - self.hippocampus_manager.initialize(global_config=global_config) + self.hippocampus_manager.initialize() # await asyncio.sleep(0.5) #防止logger输出飞了 # 将bot.py中的chat_bot.message_process消息处理函数注册到api.py的消息处理基类中 @@ -92,15 +92,15 @@ class MainSystem: # 初始化个体特征 self.individuality.initialize( - bot_nickname=global_config.BOT_NICKNAME, - personality_core=global_config.personality_core, - personality_sides=global_config.personality_sides, - identity_detail=global_config.identity_detail, - height=global_config.height, - weight=global_config.weight, - age=global_config.age, - gender=global_config.gender, - appearance=global_config.appearance, + bot_nickname=global_config.bot.nickname, + personality_core=global_config.personality.personality_core, + personality_sides=global_config.personality.personality_sides, + identity_detail=global_config.identity.identity_detail, + height=global_config.identity.height, + weight=global_config.identity.weight, + age=global_config.identity.age, + gender=global_config.identity.gender, + appearance=global_config.identity.appearance, ) logger.success("个体特征初始化成功") @@ -141,7 +141,7 @@ class MainSystem: async def build_memory_task(): """记忆构建任务""" while True: - await asyncio.sleep(global_config.build_memory_interval) + await asyncio.sleep(global_config.memory.memory_build_interval) logger.info("正在进行记忆构建") await HippocampusManager.get_instance().build_memory() @@ -149,16 +149,18 @@ class MainSystem: async def forget_memory_task(): """记忆遗忘任务""" while True: - await asyncio.sleep(global_config.forget_memory_interval) + await asyncio.sleep(global_config.memory.forget_memory_interval) print("\033[1;32m[记忆遗忘]\033[0m 开始遗忘记忆...") - await HippocampusManager.get_instance().forget_memory(percentage=global_config.memory_forget_percentage) + await HippocampusManager.get_instance().forget_memory( + percentage=global_config.memory.memory_forget_percentage + ) print("\033[1;32m[记忆遗忘]\033[0m 记忆遗忘完成") @staticmethod async def consolidate_memory_task(): """记忆整合任务""" while True: - await asyncio.sleep(global_config.consolidate_memory_interval) + await asyncio.sleep(global_config.memory.consolidate_memory_interval) print("\033[1;32m[记忆整合]\033[0m 开始整合记忆...") await HippocampusManager.get_instance().consolidate_memory() print("\033[1;32m[记忆整合]\033[0m 记忆整合完成") diff --git a/src/manager/mood_manager.py b/src/manager/mood_manager.py index 42677d4e1..c83fbeb7c 100644 --- a/src/manager/mood_manager.py +++ b/src/manager/mood_manager.py @@ -34,14 +34,14 @@ class MoodUpdateTask(AsyncTask): def __init__(self): super().__init__( task_name="Mood Update Task", - wait_before_start=global_config.mood_update_interval, - run_interval=global_config.mood_update_interval, + wait_before_start=global_config.mood.mood_update_interval, + run_interval=global_config.mood.mood_update_interval, ) # 从配置文件获取衰减率 - self.decay_rate_valence: float = 1 - global_config.mood_decay_rate + self.decay_rate_valence: float = 1 - global_config.mood.mood_decay_rate """愉悦度衰减率""" - self.decay_rate_arousal: float = 1 - global_config.mood_decay_rate + self.decay_rate_arousal: float = 1 - global_config.mood.mood_decay_rate """唤醒度衰减率""" self.last_update = time.time() diff --git a/src/tools/not_used/change_mood.py b/src/tools/not_used/change_mood.py index c34bebb93..69fc3bb78 100644 --- a/src/tools/not_used/change_mood.py +++ b/src/tools/not_used/change_mood.py @@ -44,7 +44,7 @@ class ChangeMoodTool(BaseTool): _ori_response = ",".join(response_set) # _stance, emotion = await gpt._get_emotion_tags(ori_response, message_processed_plain_text) emotion = "平静" - mood_manager.update_mood_from_emotion(emotion, global_config.mood_intensity_factor) + mood_manager.update_mood_from_emotion(emotion, global_config.mood.mood_intensity_factor) return {"name": "change_mood", "content": f"你的心情刚刚变化了,现在的心情是: {emotion}"} except Exception as e: logger.error(f"心情改变工具执行失败: {str(e)}") diff --git a/src/tools/tool_use.py b/src/tools/tool_use.py index c55170b88..ff36085d6 100644 --- a/src/tools/tool_use.py +++ b/src/tools/tool_use.py @@ -15,7 +15,7 @@ logger = get_logger("tool_use") class ToolUser: def __init__(self): self.llm_model_tool = LLMRequest( - model=global_config.llm_tool_use, temperature=0.2, max_tokens=1000, request_type="tool_use" + model=global_config.model.tool_use, temperature=0.2, max_tokens=1000, request_type="tool_use" ) @staticmethod @@ -37,7 +37,7 @@ class ToolUser: # print(f"intol111111111111111111111111111111111222222222222mid_memory_info:{mid_memory_info}") # 这些信息应该从调用者传入,而不是从self获取 - bot_name = global_config.BOT_NICKNAME + bot_name = global_config.bot.nickname prompt = "" prompt += mid_memory_info prompt += "你正在思考如何回复群里的消息。\n" diff --git a/template/bot_config_meta.toml b/template/bot_config_meta.toml deleted file mode 100644 index c3541baad..000000000 --- a/template/bot_config_meta.toml +++ /dev/null @@ -1,104 +0,0 @@ -[inner.version] -describe = "版本号" -important = true -can_edit = false - -[bot.qq] -describe = "机器人的QQ号" -important = true -can_edit = true - -[bot.nickname] -describe = "机器人的昵称" -important = true -can_edit = true - -[bot.alias_names] -describe = "机器人的别名列表,该选项还在调试中,暂时未生效" -important = false -can_edit = true - -[groups.talk_allowed] -describe = "可以回复消息的群号码列表" -important = true -can_edit = true - -[groups.talk_frequency_down] -describe = "降低回复频率的群号码列表" -important = false -can_edit = true - -[groups.ban_user_id] -describe = "禁止回复和读取消息的QQ号列表" -important = false -can_edit = true - -[personality.personality_core] -describe = "用一句话或几句话描述人格的核心特点,建议20字以内" -important = true -can_edit = true - -[personality.personality_sides] -describe = "用一句话或几句话描述人格的一些细节,条数任意,不能为0,该选项还在调试中" -important = false -can_edit = true - -[identity.identity_detail] -describe = "身份特点列表,条数任意,不能为0,该选项还在调试中" -important = false -can_edit = true - -[identity.age] -describe = "年龄,单位岁" -important = false -can_edit = true - -[identity.gender] -describe = "性别" -important = false -can_edit = true - -[identity.appearance] -describe = "外貌特征描述,该选项还在调试中,暂时未生效" -important = false -can_edit = true - -[platforms.nonebot-qq] -describe = "nonebot-qq适配器提供的链接" -important = true -can_edit = true - -[chat.allow_focus_mode] -describe = "是否允许专注聊天状态" -important = false -can_edit = true - -[chat.base_normal_chat_num] -describe = "最多允许多少个群进行普通聊天" -important = false -can_edit = true - -[chat.base_focused_chat_num] -describe = "最多允许多少个群进行专注聊天" -important = false -can_edit = true - -[chat.observation_context_size] -describe = "观察到的最长上下文大小,建议15,太短太长都会导致脑袋尖尖" -important = false -can_edit = true - -[chat.message_buffer] -describe = "启用消息缓冲器,启用此项以解决消息的拆分问题,但会使麦麦的回复延迟" -important = false -can_edit = true - -[chat.ban_words] -describe = "需要过滤的消息列表" -important = false -can_edit = true - -[chat.ban_msgs_regex] -describe = "需要过滤的消息(原始消息)匹配的正则表达式,匹配到的消息将被过滤(支持CQ码)" -important = false -can_edit = true \ No newline at end of file diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index 931afe2ed..64e51da77 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -1,18 +1,10 @@ [inner] -version = "1.7.0" +version = "2.0.0" #----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读---- #如果你想要修改配置文件,请在修改后将version的值进行变更 -#如果新增项目,请在BotConfig类下新增相应的变量 -#1.如果你修改的是[]层级项目,例如你新增了 [memory],那么请在config.py的 load_config函数中的include_configs字典中新增"内容":{ -#"func":memory, -#"support":">=0.0.0", #新的版本号 -#"necessary":False #是否必须 -#} -#2.如果你修改的是[]下的项目,例如你新增了[memory]下的 memory_ban_words ,那么请在config.py的 load_config函数中的 memory函数下新增版本判断: - # if config.INNER_VERSION in SpecifierSet(">=0.0.2"): - # config.memory_ban_words = set(memory_config.get("memory_ban_words", [])) - +#如果新增项目,请阅读src/config/official_configs.py中的说明 +# # 版本格式:主版本号.次版本号.修订号,版本号递增规则如下: # 主版本号:当你做了不兼容的 API 修改, # 次版本号:当你做了向下兼容的功能性新增, @@ -21,11 +13,11 @@ version = "1.7.0" #----以上是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读---- [bot] -qq = 1145141919810 +qq_account = 1145141919810 nickname = "麦麦" alias_names = ["麦叠", "牢麦"] #该选项还在调试中,暂时未生效 -[groups] +[chat_target] talk_allowed = [ 123, 123, @@ -53,10 +45,13 @@ identity_detail = [ "身份特点", "身份特点", ]# 条数任意,不能为0, 该选项还在调试中 + #外貌特征 -age = 20 # 年龄 单位岁 -gender = "男" # 性别 -appearance = "用几句话描述外貌特征" # 外貌特征 该选项还在调试中,暂时未生效 +age = 18 # 年龄 单位岁 +gender = "女" # 性别 +height = "170" # 身高(单位cm) +weight = "50" # 体重(单位kg) +appearance = "用一句或几句话描述外貌特征" # 外貌特征 该选项还在调试中,暂时未生效 [platforms] # 必填项目,填写每个平台适配器提供的链接 qq="http://127.0.0.1:18002/api/message" @@ -85,11 +80,10 @@ ban_msgs_regex = [ [normal_chat] #普通聊天 #一般回复参数 -model_reasoning_probability = 0.7 # 麦麦回答时选择推理模型 模型的概率 -model_normal_probability = 0.3 # 麦麦回答时选择一般模型 模型的概率 +reasoning_model_probability = 0.3 # 麦麦回答时选择推理模型的概率(与之相对的,普通模型的概率为1 - reasoning_model_probability) emoji_chance = 0.2 # 麦麦一般回复时使用表情包的概率,设置为1让麦麦自己决定发不发 -thinking_timeout = 100 # 麦麦最长思考时间,超过这个时间的思考会放弃(往往是api反应太慢) +thinking_timeout = 120 # 麦麦最长思考时间,超过这个时间的思考会放弃(往往是api反应太慢) willing_mode = "classical" # 回复意愿模式 —— 经典模式:classical,mxp模式:mxp,自定义模式:custom(需要你自己实现) response_willing_amplifier = 1 # 麦麦回复意愿放大系数,一般为1 @@ -100,8 +94,8 @@ mentioned_bot_inevitable_reply = false # 提及 bot 必然回复 at_bot_inevitable_reply = false # @bot 必然回复 [focus_chat] #专注聊天 -reply_trigger_threshold = 3.6 # 专注聊天触发阈值,越低越容易进入专注聊天 -default_decay_rate_per_second = 0.95 # 默认衰减率,越大衰减越快,越高越难进入专注聊天 +reply_trigger_threshold = 3.0 # 专注聊天触发阈值,越低越容易进入专注聊天 +default_decay_rate_per_second = 0.98 # 默认衰减率,越大衰减越快,越高越难进入专注聊天 consecutive_no_reply_threshold = 3 # 连续不回复的阈值,越低越容易结束专注聊天 # 以下选项暂时无效 @@ -110,20 +104,20 @@ compress_length_limit = 5 #最多压缩份数,超过该数值的压缩上下 [emoji] -max_emoji_num = 40 # 表情包最大数量 -max_reach_deletion = true # 开启则在达到最大数量时删除表情包,关闭则达到最大数量时不删除,只是不会继续收集表情包 -check_interval = 10 # 检查表情包(注册,破损,删除)的时间间隔(分钟) +max_reg_num = 40 # 表情包最大注册数量 +do_replace = true # 开启则在达到最大数量时删除(替换)表情包,关闭则达到最大数量时不会继续收集表情包 +check_interval = 120 # 检查表情包(注册,破损,删除)的时间间隔(分钟) save_pic = false # 是否保存图片 -save_emoji = false # 是否保存表情包 +cache_emoji = true # 是否缓存表情包 steal_emoji = true # 是否偷取表情包,让麦麦可以发送她保存的这些表情包 -enable_check = false # 是否启用表情包过滤,只有符合该要求的表情包才会被保存 -check_prompt = "符合公序良俗" # 表情包过滤要求,只有符合该要求的表情包才会被保存 +content_filtration = false # 是否启用表情包过滤,只有符合该要求的表情包才会被保存 +filtration_prompt = "符合公序良俗" # 表情包过滤要求,只有符合该要求的表情包才会被保存 [memory] -build_memory_interval = 2000 # 记忆构建间隔 单位秒 间隔越低,麦麦学习越多,但是冗余信息也会增多 -build_memory_distribution = [6.0,3.0,0.6,32.0,12.0,0.4] # 记忆构建分布,参数:分布1均值,标准差,权重,分布2均值,标准差,权重 -build_memory_sample_num = 8 # 采样数量,数值越高记忆采样次数越多 -build_memory_sample_length = 40 # 采样长度,数值越高一段记忆内容越丰富 +memory_build_interval = 2000 # 记忆构建间隔 单位秒 间隔越低,麦麦学习越多,但是冗余信息也会增多 +memory_build_distribution = [6.0, 3.0, 0.6, 32.0, 12.0, 0.4] # 记忆构建分布,参数:分布1均值,标准差,权重,分布2均值,标准差,权重 +memory_build_sample_num = 8 # 采样数量,数值越高记忆采样次数越多 +memory_build_sample_length = 40 # 采样长度,数值越高一段记忆内容越丰富 memory_compress_rate = 0.1 # 记忆压缩率 控制记忆精简程度 建议保持默认,调高可以获得更多信息,但是冗余信息也会增多 forget_memory_interval = 1000 # 记忆遗忘间隔 单位秒 间隔越低,麦麦遗忘越频繁,记忆更精简,但更难学习 @@ -135,49 +129,45 @@ consolidation_similarity_threshold = 0.7 # 相似度阈值 consolidation_check_percentage = 0.01 # 检查节点比例 #不希望记忆的词,已经记忆的不会受到影响 -memory_ban_words = [ - # "403","张三" -] +memory_ban_words = [ "表情包", "图片", "回复", "聊天记录" ] [mood] mood_update_interval = 1.0 # 情绪更新间隔 单位秒 mood_decay_rate = 0.95 # 情绪衰减率 mood_intensity_factor = 1.0 # 情绪强度因子 -[keywords_reaction] # 针对某个关键词作出反应 +[keyword_reaction] # 针对某个关键词作出反应 enable = true # 关键词反应功能的总开关 -[[keywords_reaction.rules]] # 如果想要新增多个关键词,直接复制本条,修改keywords和reaction即可 +[[keyword_reaction.rules]] # 如果想要新增多个关键词,直接复制本条,修改keywords和reaction即可 enable = true # 是否启用此条(为了人类在未来AI战争能更好地识别AI(bushi),默认开启) keywords = ["人机", "bot", "机器", "入机", "robot", "机器人","ai","AI"] # 会触发反应的关键词 reaction = "有人说你是人机或者机器人,否定这一事实,攻击或反问他。注意,只有认为你是机器人才回复,否则不要否认" # 触发之后添加的提示词 -[[keywords_reaction.rules]] # 就像这样复制 +[[keyword_reaction.rules]] # 就像这样复制 enable = false # 仅作示例,不会触发 keywords = ["测试关键词回复","test",""] reaction = "回答“测试成功”" # 修复错误的引号 -[[keywords_reaction.rules]] # 使用正则表达式匹配句式 +[[keyword_reaction.rules]] # 使用正则表达式匹配句式 enable = false # 仅作示例,不会触发 regex = ["^(?P\\S{1,20})是这样的$"] # 将匹配到的词汇命名为n,反应中对应的[n]会被替换为匹配到的内容,若不了解正则表达式请勿编写 reaction = "请按照以下模板造句:[n]是这样的,xx只要xx就可以,可是[n]要考虑的事情就很多了,比如什么时候xx,什么时候xx,什么时候xx。(请自由发挥替换xx部分,只需保持句式结构,同时表达一种将[n]过度重视的反讽意味)" [chinese_typo] enable = true # 是否启用中文错别字生成器 -error_rate=0.001 # 单字替换概率 +error_rate=0.01 # 单字替换概率 min_freq=9 # 最小字频阈值 tone_error_rate=0.1 # 声调错误概率 word_replace_rate=0.006 # 整词替换概率 [response_splitter] -enable_response_splitter = true # 是否启用回复分割器 -response_max_length = 256 # 回复允许的最大长度 -response_max_sentence_num = 4 # 回复允许的最大句子数 +enable = true # 是否启用回复分割器 +max_length = 256 # 回复允许的最大长度 +max_sentence_num = 4 # 回复允许的最大句子数 enable_kaomoji_protection = false # 是否启用颜文字保护 -model_max_output_length = 256 # 模型单次返回的最大token数 - -[remote] #发送统计信息,主要是看全球有多少只麦麦 +[telemetry] #发送统计信息,主要是看全球有多少只麦麦 enable = true [experimental] #实验性功能 @@ -194,14 +184,17 @@ pfc_chatting = false # 是否启用PFC聊天,该功能仅作用于私聊,与 # stream = : 用于指定模型是否是使用流式输出 # 如果不指定,则该项是 False +[model] +model_max_output_length = 800 # 模型单次返回的最大token数 + #这个模型必须是推理模型 -[model.llm_reasoning] # 一般聊天模式的推理回复模型 +[model.reasoning] # 一般聊天模式的推理回复模型 name = "Pro/deepseek-ai/DeepSeek-R1" provider = "SILICONFLOW" pri_in = 1.0 #模型的输入价格(非必填,可以记录消耗) pri_out = 4.0 #模型的输出价格(非必填,可以记录消耗) -[model.llm_normal] #V3 回复模型 专注和一般聊天模式共用的回复模型 +[model.normal] #V3 回复模型 专注和一般聊天模式共用的回复模型 name = "Pro/deepseek-ai/DeepSeek-V3" provider = "SILICONFLOW" pri_in = 2 #模型的输入价格(非必填,可以记录消耗) @@ -209,13 +202,13 @@ pri_out = 8 #模型的输出价格(非必填,可以记录消耗) #默认temp 0.2 如果你使用的是老V3或者其他模型,请自己修改temp参数 temp = 0.2 #模型的温度,新V3建议0.1-0.3 -[model.llm_topic_judge] #主题判断模型:建议使用qwen2.5 7b +[model.topic_judge] #主题判断模型:建议使用qwen2.5 7b name = "Pro/Qwen/Qwen2.5-7B-Instruct" provider = "SILICONFLOW" pri_in = 0.35 pri_out = 0.35 -[model.llm_summary] #概括模型,建议使用qwen2.5 32b 及以上 +[model.summary] #概括模型,建议使用qwen2.5 32b 及以上 name = "Qwen/Qwen2.5-32B-Instruct" provider = "SILICONFLOW" pri_in = 1.26 @@ -227,27 +220,27 @@ provider = "SILICONFLOW" pri_in = 0.35 pri_out = 0.35 -[model.llm_heartflow] # 用于控制麦麦是否参与聊天的模型 +[model.heartflow] # 用于控制麦麦是否参与聊天的模型 name = "Qwen/Qwen2.5-32B-Instruct" provider = "SILICONFLOW" pri_in = 1.26 pri_out = 1.26 -[model.llm_observation] #观察模型,压缩聊天内容,建议用免费的 +[model.observation] #观察模型,压缩聊天内容,建议用免费的 # name = "Pro/Qwen/Qwen2.5-7B-Instruct" name = "Qwen/Qwen2.5-7B-Instruct" provider = "SILICONFLOW" pri_in = 0 pri_out = 0 -[model.llm_sub_heartflow] #心流:认真水群时,生成麦麦的内心想法,必须使用具有工具调用能力的模型 +[model.sub_heartflow] #心流:认真水群时,生成麦麦的内心想法,必须使用具有工具调用能力的模型 name = "Pro/deepseek-ai/DeepSeek-V3" provider = "SILICONFLOW" pri_in = 2 pri_out = 8 temp = 0.3 #模型的温度,新V3建议0.1-0.3 -[model.llm_plan] #决策:认真水群时,负责决定麦麦该做什么 +[model.plan] #决策:认真水群时,负责决定麦麦该做什么 name = "Pro/deepseek-ai/DeepSeek-V3" provider = "SILICONFLOW" pri_in = 2 @@ -265,7 +258,7 @@ pri_out = 0 #私聊PFC:需要开启PFC功能,默认三个模型均为硅基流动v3,如果需要支持多人同时私聊或频繁调用,建议把其中的一个或两个换成官方v3或其它模型,以免撞到429 #PFC决策模型 -[model.llm_PFC_action_planner] +[model.pfc_action_planner] name = "Pro/deepseek-ai/DeepSeek-V3" provider = "SILICONFLOW" temp = 0.3 @@ -273,7 +266,7 @@ pri_in = 2 pri_out = 8 #PFC聊天模型 -[model.llm_PFC_chat] +[model.pfc_chat] name = "Pro/deepseek-ai/DeepSeek-V3" provider = "SILICONFLOW" temp = 0.3 @@ -281,7 +274,7 @@ pri_in = 2 pri_out = 8 #PFC检查模型 -[model.llm_PFC_reply_checker] +[model.pfc_reply_checker] name = "Pro/deepseek-ai/DeepSeek-V3" provider = "SILICONFLOW" pri_in = 2 @@ -294,7 +287,7 @@ pri_out = 8 #以下模型暂时没有使用!! #以下模型暂时没有使用!! -[model.llm_tool_use] #工具调用模型,需要使用支持工具调用的模型,建议使用qwen2.5 32b +[model.tool_use] #工具调用模型,需要使用支持工具调用的模型,建议使用qwen2.5 32b name = "Qwen/Qwen2.5-32B-Instruct" provider = "SILICONFLOW" pri_in = 1.26 diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 000000000..1a1239601 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,7 @@ +from src.config.config import global_config + + +class TestConfig: + def test_load(self): + config = global_config + print(config) From 97134648e3aa92662f40a90378c6055545314750 Mon Sep 17 00:00:00 2001 From: Oct-autumn Date: Fri, 16 May 2025 17:00:12 +0800 Subject: [PATCH 25/57] fix: ruff format & check --- src/chat/focus_chat/heartFC_chat.py | 3 - .../focus_chat/planners/action_factory.py | 85 +++++++++---------- .../planners/actions/base_action.py | 33 ++++--- .../planners/actions/no_reply_action.py | 4 +- .../planners/actions/reply_action.py | 19 ++--- src/chat/focus_chat/planners/planner.py | 32 +++---- 6 files changed, 79 insertions(+), 97 deletions(-) diff --git a/src/chat/focus_chat/heartFC_chat.py b/src/chat/focus_chat/heartFC_chat.py index 4a28652d1..ff4f7fdb0 100644 --- a/src/chat/focus_chat/heartFC_chat.py +++ b/src/chat/focus_chat/heartFC_chat.py @@ -91,7 +91,6 @@ class HeartFChatting: self.action_manager = ActionManager() self.action_planner = ActionPlanner(log_prefix=self.log_prefix, action_manager=self.action_manager) - # --- 处理器列表 --- self.processors: List[BaseProcessor] = [] self._register_default_processors() @@ -526,5 +525,3 @@ class HeartFChatting: if last_n is not None: history = history[-last_n:] return [cycle.to_dict() for cycle in history] - - diff --git a/src/chat/focus_chat/planners/action_factory.py b/src/chat/focus_chat/planners/action_factory.py index 257156a25..bca49c496 100644 --- a/src/chat/focus_chat/planners/action_factory.py +++ b/src/chat/focus_chat/planners/action_factory.py @@ -1,7 +1,5 @@ -from typing import Dict, List, Optional, Callable, Coroutine, Type, Any, Union -import os -import importlib -from src.chat.focus_chat.planners.actions.base_action import BaseAction, _ACTION_REGISTRY, _DEFAULT_ACTIONS +from typing import Dict, List, Optional, Callable, Coroutine, Type, Any +from src.chat.focus_chat.planners.actions.base_action import BaseAction, _ACTION_REGISTRY from src.chat.heart_flow.observation.observation import Observation from src.chat.focus_chat.expressors.default_expressor import DefaultExpressor from src.chat.message_receive.chat_stream import ChatStream @@ -9,8 +7,6 @@ from src.chat.focus_chat.heartFC_Cycleinfo import CycleDetail from src.common.logger_manager import get_logger # 导入动作类,确保装饰器被执行 -from src.chat.focus_chat.planners.actions.reply_action import ReplyAction -from src.chat.focus_chat.planners.actions.no_reply_action import NoReplyAction logger = get_logger("action_factory") @@ -31,20 +27,19 @@ class ActionManager: self._using_actions: Dict[str, ActionInfo] = {} # 临时备份原始使用中的动作 self._original_actions_backup: Optional[Dict[str, ActionInfo]] = None - + # 默认动作集,仅作为快照,用于恢复默认 self._default_actions: Dict[str, ActionInfo] = {} - + # 加载所有已注册动作 self._load_registered_actions() - + # 初始化时将默认动作加载到使用中的动作 self._using_actions = self._default_actions.copy() - + # logger.info(f"当前可用动作: {list(self._using_actions.keys())}") # for action_name, action_info in self._using_actions.items(): - # logger.info(f"动作名称: {action_name}, 动作信息: {action_info}") - + # logger.info(f"动作名称: {action_name}, 动作信息: {action_info}") def _load_registered_actions(self) -> None: """ @@ -54,35 +49,35 @@ class ActionManager: # 从_ACTION_REGISTRY获取所有已注册动作 for action_name, action_class in _ACTION_REGISTRY.items(): # 获取动作相关信息 - action_description:str = getattr(action_class, "action_description", "") - action_parameters:dict[str:str] = getattr(action_class, "action_parameters", {}) - action_require:list[str] = getattr(action_class, "action_require", []) - is_default:bool = getattr(action_class, "default", False) - + action_description: str = getattr(action_class, "action_description", "") + action_parameters: dict[str:str] = getattr(action_class, "action_parameters", {}) + action_require: list[str] = getattr(action_class, "action_require", []) + is_default: bool = getattr(action_class, "default", False) + if action_name and action_description: # 创建动作信息字典 action_info = { "description": action_description, "parameters": action_parameters, - "require": action_require + "require": action_require, } - + # 注册2 print("注册2") print(action_info) - + # 添加到所有已注册的动作 self._registered_actions[action_name] = action_info - + # 添加到默认动作(如果是默认动作) if is_default: self._default_actions[action_name] = action_info - + logger.info(f"所有注册动作: {list(self._registered_actions.keys())}") logger.info(f"默认动作: {list(self._default_actions.keys())}") # for action_name, action_info in self._default_actions.items(): - # logger.info(f"动作名称: {action_name}, 动作信息: {action_info}") - + # logger.info(f"动作名称: {action_name}, 动作信息: {action_info}") + except Exception as e: logger.error(f"加载已注册动作失败: {e}") @@ -129,7 +124,7 @@ class ActionManager: if action_name not in self._using_actions: logger.warning(f"当前不可用的动作类型: {action_name}") return None - + handler_class = _ACTION_REGISTRY.get(action_name) if not handler_class: logger.warning(f"未注册的动作类型: {action_name}") @@ -153,7 +148,7 @@ class ActionManager: expressor=expressor, chat_stream=chat_stream, ) - + return instance except Exception as e: @@ -167,7 +162,7 @@ class ActionManager: def get_default_actions(self) -> Dict[str, ActionInfo]: """获取默认动作集""" return self._default_actions.copy() - + def get_using_actions(self) -> Dict[str, ActionInfo]: """获取当前正在使用的动作集""" return self._using_actions.copy() @@ -175,21 +170,21 @@ class ActionManager: def add_action_to_using(self, action_name: str) -> bool: """ 添加已注册的动作到当前使用的动作集 - + Args: action_name: 动作名称 - + Returns: bool: 添加是否成功 """ if action_name not in self._registered_actions: logger.warning(f"添加失败: 动作 {action_name} 未注册") return False - + if action_name in self._using_actions: logger.info(f"动作 {action_name} 已经在使用中") return True - + self._using_actions[action_name] = self._registered_actions[action_name] logger.info(f"添加动作 {action_name} 到使用集") return True @@ -197,17 +192,17 @@ class ActionManager: def remove_action_from_using(self, action_name: str) -> bool: """ 从当前使用的动作集中移除指定动作 - + Args: action_name: 动作名称 - + Returns: bool: 移除是否成功 """ if action_name not in self._using_actions: logger.warning(f"移除失败: 动作 {action_name} 不在当前使用的动作集中") return False - + del self._using_actions[action_name] logger.info(f"已从使用集中移除动作 {action_name}") return True @@ -215,30 +210,26 @@ class ActionManager: def add_action(self, action_name: str, description: str, parameters: Dict = None, require: List = None) -> bool: """ 添加新的动作到注册集 - + Args: action_name: 动作名称 description: 动作描述 parameters: 动作参数定义,默认为空字典 require: 动作依赖项,默认为空列表 - + Returns: bool: 添加是否成功 """ if action_name in self._registered_actions: return False - + if parameters is None: parameters = {} if require is None: require = [] - - action_info = { - "description": description, - "parameters": parameters, - "require": require - } - + + action_info = {"description": description, "parameters": parameters, "require": require} + self._registered_actions[action_name] = action_info return True @@ -264,7 +255,7 @@ class ActionManager: if self._original_actions_backup is not None: self._using_actions = self._original_actions_backup.copy() self._original_actions_backup = None - + def restore_default_actions(self) -> None: """恢复默认动作集到使用集""" self._using_actions = self._default_actions.copy() @@ -273,10 +264,10 @@ class ActionManager: def get_action(self, action_name: str) -> Optional[Type[BaseAction]]: """ 获取指定动作的处理器类 - + Args: action_name: 动作名称 - + Returns: Optional[Type[BaseAction]]: 动作处理器类,如果不存在则返回None """ diff --git a/src/chat/focus_chat/planners/actions/base_action.py b/src/chat/focus_chat/planners/actions/base_action.py index 7c77c300c..82d259677 100644 --- a/src/chat/focus_chat/planners/actions/base_action.py +++ b/src/chat/focus_chat/planners/actions/base_action.py @@ -12,7 +12,7 @@ _DEFAULT_ACTIONS: Dict[str, str] = {} def register_action(cls): """ 动作注册装饰器 - + 用法: @register_action class MyAction(BaseAction): @@ -24,22 +24,22 @@ def register_action(cls): if not hasattr(cls, "action_name") or not hasattr(cls, "action_description"): logger.error(f"动作类 {cls.__name__} 缺少必要的属性: action_name 或 action_description") return cls - - action_name = getattr(cls, "action_name") - action_description = getattr(cls, "action_description") + + action_name = cls.action_name + action_description = cls.action_description is_default = getattr(cls, "default", False) - + if not action_name or not action_description: logger.error(f"动作类 {cls.__name__} 的 action_name 或 action_description 为空") return cls - + # 将动作类注册到全局注册表 _ACTION_REGISTRY[action_name] = cls - + # 如果是默认动作,添加到默认动作集 if is_default: _DEFAULT_ACTIONS[action_name] = action_description - + logger.info(f"已注册动作: {action_name} -> {cls.__name__},默认: {is_default}") return cls @@ -60,15 +60,14 @@ class BaseAction(ABC): cycle_timers: 计时器字典 thinking_id: 思考ID """ - #每个动作必须实现 - self.action_name:str = "base_action" - self.action_description:str = "基础动作" - self.action_parameters:dict = {} - self.action_require:list[str] = [] - - self.default:bool = False - - + # 每个动作必须实现 + self.action_name: str = "base_action" + self.action_description: str = "基础动作" + self.action_parameters: dict = {} + self.action_require: list[str] = [] + + self.default: bool = False + self.action_data = action_data self.reasoning = reasoning self.cycle_timers = cycle_timers diff --git a/src/chat/focus_chat/planners/actions/no_reply_action.py b/src/chat/focus_chat/planners/actions/no_reply_action.py index a29812c7a..71f1cb3f3 100644 --- a/src/chat/focus_chat/planners/actions/no_reply_action.py +++ b/src/chat/focus_chat/planners/actions/no_reply_action.py @@ -29,7 +29,7 @@ class NoReplyAction(BaseAction): action_require = [ "话题无关/无聊/不感兴趣/不懂", "最后一条消息是你自己发的且无人回应你", - "你发送了太多消息,且无人回复" + "你发送了太多消息,且无人回复", ] default = True @@ -46,7 +46,7 @@ class NoReplyAction(BaseAction): total_no_reply_count: int = 0, total_waiting_time: float = 0.0, shutting_down: bool = False, - **kwargs + **kwargs, ): """初始化不回复动作处理器 diff --git a/src/chat/focus_chat/planners/actions/reply_action.py b/src/chat/focus_chat/planners/actions/reply_action.py index 7b2e88fa0..6e4f41d4d 100644 --- a/src/chat/focus_chat/planners/actions/reply_action.py +++ b/src/chat/focus_chat/planners/actions/reply_action.py @@ -2,9 +2,8 @@ # -*- coding: utf-8 -*- from src.common.logger_manager import get_logger -from src.chat.utils.timer_calculator import Timer from src.chat.focus_chat.planners.actions.base_action import BaseAction, register_action -from typing import Tuple, List, Optional +from typing import Tuple, List from src.chat.heart_flow.observation.observation import Observation from src.chat.focus_chat.expressors.default_expressor import DefaultExpressor from src.chat.message_receive.chat_stream import ChatStream @@ -22,14 +21,14 @@ class ReplyAction(BaseAction): 处理构建和发送消息回复的动作。 """ - action_name:str = "reply" - action_description:str = "表达想法,可以只包含文本、表情或两者都有" - action_parameters:dict[str:str] = { + action_name: str = "reply" + action_description: str = "表达想法,可以只包含文本、表情或两者都有" + action_parameters: dict[str:str] = { "text": "你想要表达的内容(可选)", "emojis": "描述当前使用表情包的场景(可选)", "target": "你想要回复的原始文本内容(非必须,仅文本,不包含发送者)(可选)", } - action_require:list[str] = [ + action_require: list[str] = [ "有实质性内容需要表达", "有人提到你,但你还没有回应他", "在合适的时候添加表情(不要总是添加)", @@ -38,7 +37,7 @@ class ReplyAction(BaseAction): "一次只回复一个人,一次只回复一个话题,突出重点", "如果是自己发的消息想继续,需自然衔接", "避免重复或评价自己的发言,不要和自己聊天", - "注意:回复尽量简短一些。可以参考贴吧,知乎和微博的回复风格,回复不要浮夸,不要用夸张修辞,平淡一些。" + "注意:回复尽量简短一些。可以参考贴吧,知乎和微博的回复风格,回复不要浮夸,不要用夸张修辞,平淡一些。", ] default = True @@ -54,7 +53,7 @@ class ReplyAction(BaseAction): chat_stream: ChatStream, current_cycle: CycleDetail, log_prefix: str, - **kwargs + **kwargs, ): """初始化回复动作处理器 @@ -89,9 +88,9 @@ class ReplyAction(BaseAction): reasoning=self.reasoning, reply_data=self.action_data, cycle_timers=self.cycle_timers, - thinking_id=self.thinking_id + thinking_id=self.thinking_id, ) - + async def _handle_reply( self, reasoning: str, reply_data: dict, cycle_timers: dict, thinking_id: str ) -> tuple[bool, str]: diff --git a/src/chat/focus_chat/planners/planner.py b/src/chat/focus_chat/planners/planner.py index bb87e1da7..83c8b6791 100644 --- a/src/chat/focus_chat/planners/planner.py +++ b/src/chat/focus_chat/planners/planner.py @@ -4,7 +4,6 @@ from typing import List, Dict, Any, Optional from rich.traceback import install from src.chat.models.utils_model import LLMRequest from src.config.config import global_config -from src.chat.focus_chat.heartflow_prompt_builder import prompt_builder from src.chat.focus_chat.info.info_base import InfoBase from src.chat.focus_chat.info.obs_info import ObsInfo from src.chat.focus_chat.info.cycle_info import CycleInfo @@ -15,10 +14,12 @@ from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.individuality.individuality import Individuality from src.chat.focus_chat.planners.action_factory import ActionManager from src.chat.focus_chat.planners.action_factory import ActionInfo + logger = get_logger("planner") install(extra_lines=3) + def init_prompt(): Prompt( """你的名字是{bot_name},{prompt_personality},{chat_context_description}。需要基于以下信息决定如何参与对话: @@ -44,8 +45,9 @@ def init_prompt(): }} 请输出你的决策 JSON:""", -"planner_prompt",) - + "planner_prompt", + ) + Prompt( """ action_name: {action_name} @@ -57,7 +59,7 @@ action_name: {action_name} """, "action_prompt", ) - + class ActionPlanner: def __init__(self, log_prefix: str, action_manager: ActionManager): @@ -68,7 +70,7 @@ class ActionPlanner: max_tokens=1000, request_type="action_planning", # 用于动作规划 ) - + self.action_manager = action_manager async def plan(self, all_plan_info: List[InfoBase], cycle_timers: dict) -> Dict[str, Any]: @@ -103,10 +105,10 @@ class ActionPlanner: cycle_info = info.get_observe_info() elif isinstance(info, StructuredInfo): logger.debug(f"{self.log_prefix} 结构化信息: {info}") - structured_info = info.get_data() + _structured_info = info.get_data() current_available_actions = self.action_manager.get_using_actions() - + # --- 构建提示词 (调用修改后的 PromptBuilder 方法) --- prompt = await self.build_planner_prompt( is_group_chat=is_group_chat, # <-- Pass HFC state @@ -197,7 +199,6 @@ class ActionPlanner: # 返回结果字典 return plan_result - async def build_planner_prompt( self, is_group_chat: bool, # Now passed as argument @@ -218,7 +219,6 @@ class ActionPlanner: ) chat_context_description = f"你正在和 {chat_target_name} 私聊" - chat_content_block = "" if observed_messages_str: chat_content_block = f"聊天记录:\n{observed_messages_str}" @@ -234,7 +234,6 @@ class ActionPlanner: individuality = Individuality.get_instance() personality_block = individuality.get_prompt(x_person=2, level=2) - action_options_block = "" for using_actions_name, using_actions_info in current_available_actions.items(): # print(using_actions_name) @@ -242,29 +241,26 @@ class ActionPlanner: # print(using_actions_info["parameters"]) # print(using_actions_info["require"]) # print(using_actions_info["description"]) - + using_action_prompt = await global_prompt_manager.get_prompt_async("action_prompt") - + param_text = "" for param_name, param_description in using_actions_info["parameters"].items(): param_text += f"{param_name}: {param_description}\n" - + require_text = "" for require_item in using_actions_info["require"]: require_text += f"- {require_item}\n" - + using_action_prompt = using_action_prompt.format( action_name=using_actions_name, action_description=using_actions_info["description"], action_parameters=param_text, action_require=require_text, ) - + action_options_block += using_action_prompt - - - planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt") prompt = planner_prompt_template.format( bot_name=global_config.BOT_NICKNAME, From b698d17a76280f80abf3b1aa316cb35037645e25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Fri, 16 May 2025 17:08:30 +0800 Subject: [PATCH 26/57] =?UTF-8?q?=E9=87=8D=E5=91=BD=E5=90=8D=E8=A1=A8?= =?UTF-8?q?=E6=83=85=E5=8C=85=E5=92=8C=E5=9B=BE=E5=83=8F=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E4=B8=AD=E7=9A=84=E5=93=88=E5=B8=8C=E5=AD=97=E6=AE=B5=EF=BC=8C?= =?UTF-8?q?=E7=BB=9F=E4=B8=80=E4=B8=BA=20emoji=5Fhash=20=E5=92=8C=20image?= =?UTF-8?q?=5Fdescription=5Fhash=EF=BC=8C=E4=BB=A5=E6=8F=90=E9=AB=98?= =?UTF-8?q?=E4=BB=A3=E7=A0=81=E4=B8=80=E8=87=B4=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/emoji_system/emoji_manager.py | 12 ++++++------ src/chat/utils/utils_image.py | 6 +++--- src/common/database/database_model.py | 6 +++--- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py index 77835d1fb..7b5574691 100644 --- a/src/chat/emoji_system/emoji_manager.py +++ b/src/chat/emoji_system/emoji_manager.py @@ -197,7 +197,7 @@ class MaiEmoji: # 2. 删除数据库记录 try: - will_delete_emoji = Emoji.get(Emoji.hash == self.hash) + will_delete_emoji = Emoji.get(Emoji.emoji_hash == self.hash) result = will_delete_emoji.delete_instance() # Returns the number of rows deleted. except Emoji.DoesNotExist: logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。") @@ -260,7 +260,7 @@ def _to_emoji_objects(data): try: emoji = MaiEmoji(full_path=full_path) - emoji.hash = emoji_data.hash + emoji.hash = emoji_data.emoji_hash if not emoji.hash: logger.warning(f"[加载错误] 数据库记录缺少 'hash' 字段: {full_path}") load_errors += 1 @@ -405,7 +405,7 @@ class EmojiManager: def record_usage(self, emoji_hash: str): """记录表情使用次数""" try: - emoji_update = Emoji.get(Emoji.hash == emoji_hash) + emoji_update = Emoji.get(Emoji.emoji_hash == emoji_hash) emoji_update.usage_count += 1 emoji_update.last_used_time = time.time() # Update last used time emoji_update.save() # Persist changes to DB @@ -475,7 +475,7 @@ class EmojiManager: selected_emoji, similarity, matched_emotion = random.choice(top_emojis) # 把匹配的 emotion 也拿出来喵~ # 更新使用次数 - self.record_usage(selected_emoji.hash) + self.record_usage(selected_emoji.emoji_hash) _time_end = time.time() @@ -671,7 +671,7 @@ class EmojiManager: self._ensure_db() if emoji_hash: - query = Emoji.select().where(Emoji.hash == emoji_hash) + query = Emoji.select().where(Emoji.emoji_hash == emoji_hash) else: logger.warning( "[查询] 未提供 hash,将尝试加载所有表情包,建议使用 get_all_emoji_from_db 更新管理器状态。" @@ -804,7 +804,7 @@ class EmojiManager: # 删除选定的表情包 logger.info(f"[决策] 删除表情包: {emoji_to_delete.description}") - delete_success = await self.delete_emoji(emoji_to_delete.hash) + delete_success = await self.delete_emoji(emoji_to_delete.emoji_hash) if delete_success: # 修复:等待异步注册完成 diff --git a/src/chat/utils/utils_image.py b/src/chat/utils/utils_image.py index ee5846031..14cb24922 100644 --- a/src/chat/utils/utils_image.py +++ b/src/chat/utils/utils_image.py @@ -61,7 +61,7 @@ class ImageManager: """ try: record = ImageDescriptions.get_or_none( - (ImageDescriptions.hash == image_hash) & (ImageDescriptions.type == description_type) + (ImageDescriptions.image_description_hash == image_hash) & (ImageDescriptions.type == description_type) ) return record.description if record else None except Exception as e: @@ -141,7 +141,7 @@ class ImageManager: # 保存到数据库 (Images表) try: - img_obj = Images.get((Images.hash == image_hash) & (Images.type == "emoji")) + img_obj = Images.get((Images.emoji_hash == image_hash) & (Images.type == "emoji")) img_obj.path = file_path img_obj.description = description img_obj.timestamp = current_timestamp @@ -214,7 +214,7 @@ class ImageManager: # 保存到数据库 (Images表) try: - img_obj = Images.get((Images.hash == image_hash) & (Images.type == "image")) + img_obj = Images.get((Images.emoji_hash == image_hash) & (Images.type == "image")) img_obj.path = file_path img_obj.description = description img_obj.timestamp = current_timestamp diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py index 35f464b5f..d885312b0 100644 --- a/src/common/database/database_model.py +++ b/src/common/database/database_model.py @@ -100,7 +100,7 @@ class Emoji(BaseModel): full_path = TextField(unique=True, index=True) # 文件的完整路径 (包括文件名) format = TextField() # 图片格式 - hash = TextField(index=True) # 表情包的哈希值 + emoji_hash = TextField(index=True) # 表情包的哈希值 description = TextField() # 表情包的描述 query_count = IntegerField(default=0) # 查询次数(用于统计表情包被查询描述的次数) is_registered = BooleanField(default=False) # 是否已注册 @@ -160,7 +160,7 @@ class Images(BaseModel): 用于存储图像信息的模型。 """ - hash = TextField(index=True) # 图像的哈希值 + emoji_hash = TextField(index=True) # 图像的哈希值 description = TextField(null=True) # 图像的描述 path = TextField(unique=True) # 图像文件的路径 timestamp = FloatField() # 时间戳 @@ -177,7 +177,7 @@ class ImageDescriptions(BaseModel): """ type = TextField() # 类型,例如 "emoji" - hash = TextField(index=True) # 图像的哈希值 + image_description_hash = TextField(index=True) # 图像的哈希值 description = TextField() # 图像的描述 timestamp = FloatField() # 时间戳 From fdd4ac8b4f383f392e82e8b88dd71a46768136cb Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Fri, 16 May 2025 17:13:18 +0800 Subject: [PATCH 27/57] Merge branch '063fix3' of https://github.com/MaiM-with-u/MaiBot into 063fix3 --- src/chat/focus_chat/info_processors/working_memory_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/chat/focus_chat/info_processors/working_memory_processor.py b/src/chat/focus_chat/info_processors/working_memory_processor.py index b3feedcf6..25041af8a 100644 --- a/src/chat/focus_chat/info_processors/working_memory_processor.py +++ b/src/chat/focus_chat/info_processors/working_memory_processor.py @@ -35,7 +35,7 @@ def init_prompt(): 现在是{time_now},你正在上网,和qq群里的网友们聊天,以下是正在进行的聊天内容: {chat_observe_info} -以下是你已经总结的记忆,你可以调取这些记忆来帮助你聊天,不要一次调取太多记忆,最多调取3个左右记忆: +以下是你已经总结的记忆摘要,你可以调取这些记忆查看内容来帮助你聊天,不要一次调取太多记忆,最多调取3个左右记忆: {memory_str} 观察聊天内容和已经总结的记忆,思考是否有新内容需要总结成记忆,如果有,就输出 true,否则输出 false From d19d5fe885b455ff010ffc05a3b1f1b26baaf305 Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Fri, 16 May 2025 17:15:55 +0800 Subject: [PATCH 28/57] fix:ruff --- .../expressors/default_expressor.py | 28 +- src/chat/focus_chat/heartFC_chat.py | 16 +- .../focus_chat/heartflow_prompt_builder.py | 3 +- src/chat/focus_chat/info/info_base.py | 2 +- src/chat/focus_chat/info/self_info.py | 5 +- .../focus_chat/info/workingmemory_info.py | 15 +- .../info_processors/chattinginfo_processor.py | 2 +- .../info_processors/mind_processor.py | 4 - .../info_processors/self_processor.py | 19 +- .../info_processors/tool_processor.py | 6 +- .../working_memory_processor.py | 59 ++- .../focus_chat/planners/action_manager.py | 27 +- .../focus_chat/planners/actions/__init__.py | 2 +- .../planners/actions/no_reply_action.py | 3 +- .../planners/actions/plugin_action.py | 94 ++--- .../planners/actions/reply_action.py | 3 +- src/chat/focus_chat/planners/planner.py | 4 +- .../focus_chat/working_memory/memory_item.py | 51 ++- .../working_memory/memory_manager.py | 363 +++++++++--------- .../working_memory/test/memory_file_loader.py | 169 -------- .../working_memory/test/run_memory_tests.py | 92 ----- .../test/simulate_real_usage.py | 197 ---------- .../working_memory/test/test_decay_removal.py | 323 ---------------- .../test/test_working_memory.py | 121 ------ .../working_memory/working_memory.py | 95 +++-- .../observation/hfcloop_observation.py | 9 +- .../heart_flow/observation/observation.py | 1 + .../observation/structure_observation.py | 2 +- .../observation/working_observation.py | 6 +- src/chat/person_info/person_info.py | 3 +- src/chat/utils/chat_message_builder.py | 4 +- src/plugins/__init__.py | 2 +- src/plugins/test_plugin/__init__.py | 3 +- src/plugins/test_plugin/actions/__init__.py | 7 +- .../test_plugin/actions/mute_action.py | 19 +- .../test_plugin/actions/online_action.py | 23 +- .../test_plugin/actions/test_action.py | 15 +- 37 files changed, 409 insertions(+), 1388 deletions(-) delete mode 100644 src/chat/focus_chat/working_memory/test/memory_file_loader.py delete mode 100644 src/chat/focus_chat/working_memory/test/run_memory_tests.py delete mode 100644 src/chat/focus_chat/working_memory/test/simulate_real_usage.py delete mode 100644 src/chat/focus_chat/working_memory/test/test_decay_removal.py delete mode 100644 src/chat/focus_chat/working_memory/test/test_working_memory.py diff --git a/src/chat/focus_chat/expressors/default_expressor.py b/src/chat/focus_chat/expressors/default_expressor.py index 6da4f52b8..37b634b37 100644 --- a/src/chat/focus_chat/expressors/default_expressor.py +++ b/src/chat/focus_chat/expressors/default_expressor.py @@ -10,7 +10,6 @@ from src.config.config import global_config from src.chat.utils.utils_image import image_path_to_base64 # Local import needed after move from src.chat.utils.timer_calculator import Timer # <--- Import Timer from src.chat.emoji_system.emoji_manager import emoji_manager -from src.chat.focus_chat.heartflow_prompt_builder import prompt_builder,Prompt from src.chat.focus_chat.heartFC_sender import HeartFCSender from src.chat.utils.utils import process_llm_response from src.chat.utils.info_catcher import info_catcher_manager @@ -18,25 +17,16 @@ from src.manager.mood_manager import mood_manager from src.chat.heart_flow.utils_chat import get_chat_type_and_target_info from src.chat.message_receive.chat_stream import ChatStream from src.chat.focus_chat.hfc_utils import parse_thinking_id_to_timestamp -from src.config.config import global_config -from src.common.logger_manager import get_logger from src.individuality.individuality import Individuality from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat -from src.chat.person_info.relationship_manager import relationship_manager -from src.chat.utils.utils import get_embedding import time -from typing import Union, Optional -from src.common.database import db -from src.chat.utils.utils import get_recent_group_speaker -from src.manager.mood_manager import mood_manager -from src.chat.memory_system.Hippocampus import HippocampusManager -from src.chat.knowledge.knowledge_lib import qa_manager from src.chat.focus_chat.expressors.exprssion_learner import expression_learner import random logger = get_logger("expressor") + def init_prompt(): Prompt( """ @@ -59,7 +49,7 @@ def init_prompt(): """, "default_expressor_prompt", ) - + Prompt( """ 你可以参考以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中: @@ -280,7 +270,7 @@ class DefaultExpressor: logger.error(f"{self.log_prefix}回复生成意外失败: {e}") traceback.print_exc() return None - + async def build_prompt_focus( self, reason, @@ -357,7 +347,7 @@ class DefaultExpressor: template_name, style_habbits=style_habbits_str, grammar_habbits=grammar_habbits_str, - chat_target=chat_target_1, + chat_target=chat_target_1, chat_info=chat_talking_prompt, bot_name=global_config.BOT_NICKNAME, prompt_personality="", @@ -377,9 +367,7 @@ class DefaultExpressor: moderation_prompt=await global_prompt_manager.get_prompt_async("moderation_prompt"), ) - - return prompt - + return prompt # --- 发送器 (Sender) --- # @@ -402,7 +390,7 @@ class DefaultExpressor: if thinking_id: thinking_start_time = await self.heart_fc_sender.get_thinking_start_time(chat_id, thinking_id) else: - thinking_id = "ds"+ str(round(time.time(),2)) + thinking_id = "ds" + str(round(time.time(), 2)) thinking_start_time = time.time() if thinking_start_time is None: @@ -514,7 +502,6 @@ class DefaultExpressor: return bot_message - def weighted_sample_no_replacement(items, weights, k) -> list: """ 加权且不放回地随机抽取k个元素。 @@ -548,4 +535,5 @@ def weighted_sample_no_replacement(items, weights, k) -> list: break return selected -init_prompt() \ No newline at end of file + +init_prompt() diff --git a/src/chat/focus_chat/heartFC_chat.py b/src/chat/focus_chat/heartFC_chat.py index 9fab88410..0f5371a36 100644 --- a/src/chat/focus_chat/heartFC_chat.py +++ b/src/chat/focus_chat/heartFC_chat.py @@ -17,7 +17,6 @@ from src.chat.focus_chat.info_processors.mind_processor import MindProcessor from src.chat.focus_chat.info_processors.working_memory_processor import WorkingMemoryProcessor from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation from src.chat.heart_flow.observation.working_observation import WorkingMemoryObservation -from src.chat.heart_flow.observation.chatting_observation import ChattingObservation from src.chat.focus_chat.info_processors.tool_processor import ToolProcessor from src.chat.focus_chat.expressors.default_expressor import DefaultExpressor from src.chat.focus_chat.memory_activator import MemoryActivator @@ -26,6 +25,7 @@ from src.chat.focus_chat.info_processors.self_processor import SelfProcessor from src.chat.focus_chat.planners.planner import ActionPlanner from src.chat.focus_chat.planners.action_manager import ActionManager from src.chat.focus_chat.working_memory.working_memory import WorkingMemory + install(extra_lines=3) @@ -85,17 +85,19 @@ class HeartFChatting: self.log_prefix: str = str(chat_id) # Initial default, will be updated self.hfcloop_observation = HFCloopObservation(observe_id=self.stream_id) self.chatting_observation = observations[0] - + self.memory_activator = MemoryActivator() self.working_memory = WorkingMemory(chat_id=self.stream_id) - self.working_observation = WorkingMemoryObservation(observe_id=self.stream_id, working_memory=self.working_memory) - + self.working_observation = WorkingMemoryObservation( + observe_id=self.stream_id, working_memory=self.working_memory + ) + self.expressor = DefaultExpressor(chat_id=self.stream_id) self.action_manager = ActionManager() self.action_planner = ActionPlanner(log_prefix=self.log_prefix, action_manager=self.action_manager) - + self.hfcloop_observation.set_action_manager(self.action_manager) - + self.all_observations = observations # --- 处理器列表 --- self.processors: List[BaseProcessor] = [] @@ -369,7 +371,7 @@ class HeartFChatting: } self.all_observations = observations - + with Timer("回忆", cycle_timers): running_memorys = await self.memory_activator.activate_memory(observations) diff --git a/src/chat/focus_chat/heartflow_prompt_builder.py b/src/chat/focus_chat/heartflow_prompt_builder.py index 830a1cfad..74bac0a1f 100644 --- a/src/chat/focus_chat/heartflow_prompt_builder.py +++ b/src/chat/focus_chat/heartflow_prompt_builder.py @@ -12,7 +12,6 @@ from src.chat.utils.utils import get_recent_group_speaker from src.manager.mood_manager import mood_manager from src.chat.memory_system.Hippocampus import HippocampusManager from src.chat.knowledge.knowledge_lib import qa_manager -from src.chat.focus_chat.expressors.exprssion_learner import expression_learner import random @@ -20,7 +19,6 @@ logger = get_logger("prompt") def init_prompt(): - Prompt( """ 你有以下信息可供参考: @@ -521,5 +519,6 @@ class PromptBuilder: # 返回所有找到的内容,用换行分隔 return "\n".join(str(result["content"]) for result in results) + init_prompt() prompt_builder = PromptBuilder() diff --git a/src/chat/focus_chat/info/info_base.py b/src/chat/focus_chat/info/info_base.py index fbf060ba6..53ad30230 100644 --- a/src/chat/focus_chat/info/info_base.py +++ b/src/chat/focus_chat/info/info_base.py @@ -17,7 +17,7 @@ class InfoBase: type: str = "base" data: Dict[str, Any] = field(default_factory=dict) - processed_info:str = "" + processed_info: str = "" def get_type(self) -> str: """获取信息类型 diff --git a/src/chat/focus_chat/info/self_info.py b/src/chat/focus_chat/info/self_info.py index 82edd2655..866457956 100644 --- a/src/chat/focus_chat/info/self_info.py +++ b/src/chat/focus_chat/info/self_info.py @@ -1,5 +1,4 @@ -from typing import Dict, Any -from dataclasses import dataclass, field +from dataclasses import dataclass from .info_base import InfoBase @@ -31,7 +30,7 @@ class SelfInfo(InfoBase): self_info: 要设置的思维状态 """ self.data["self_info"] = self_info - + def get_processed_info(self) -> str: """获取处理后的信息 diff --git a/src/chat/focus_chat/info/workingmemory_info.py b/src/chat/focus_chat/info/workingmemory_info.py index 8c94f6fbc..0edce8944 100644 --- a/src/chat/focus_chat/info/workingmemory_info.py +++ b/src/chat/focus_chat/info/workingmemory_info.py @@ -5,10 +5,9 @@ from .info_base import InfoBase @dataclass class WorkingMemoryInfo(InfoBase): - type: str = "workingmemory" - - processed_info:str = "" + + processed_info: str = "" def set_talking_message(self, message: str) -> None: """设置说话消息 @@ -25,7 +24,7 @@ class WorkingMemoryInfo(InfoBase): working_memory (str): 工作记忆内容 """ self.data["working_memory"] = working_memory - + def add_working_memory(self, working_memory: str) -> None: """添加工作记忆 @@ -37,7 +36,7 @@ class WorkingMemoryInfo(InfoBase): working_memory_list.append(working_memory) # print(f"working_memory_list: {working_memory_list}") self.data["working_memory"] = working_memory_list - + def get_working_memory(self) -> List[str]: """获取工作记忆 @@ -72,7 +71,7 @@ class WorkingMemoryInfo(InfoBase): Optional[str]: 属性值,如果键不存在则返回 None """ return self.data.get(key) - + def get_processed_info(self) -> Dict[str, str]: """获取处理后的信息 @@ -84,7 +83,7 @@ class WorkingMemoryInfo(InfoBase): memory_str = "" for memory in all_memory: memory_str += f"{memory}\n" - + self.processed_info = memory_str - + return self.processed_info diff --git a/src/chat/focus_chat/info_processors/chattinginfo_processor.py b/src/chat/focus_chat/info_processors/chattinginfo_processor.py index 0accc2a34..bb565ee7e 100644 --- a/src/chat/focus_chat/info_processors/chattinginfo_processor.py +++ b/src/chat/focus_chat/info_processors/chattinginfo_processor.py @@ -55,7 +55,7 @@ class ChattingInfoProcessor(BaseProcessor): # print(f"obs: {obs}") if isinstance(obs, ChattingObservation): # print("1111111111111111111111读取111111111111111") - + obs_info = ObsInfo() await self.chat_compress(obs) diff --git a/src/chat/focus_chat/info_processors/mind_processor.py b/src/chat/focus_chat/info_processors/mind_processor.py index 95233a9f7..09228174c 100644 --- a/src/chat/focus_chat/info_processors/mind_processor.py +++ b/src/chat/focus_chat/info_processors/mind_processor.py @@ -6,11 +6,9 @@ import time import traceback from src.common.logger_manager import get_logger from src.individuality.individuality import Individuality -import random from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.json_utils import safe_json_dumps from src.chat.message_receive.chat_stream import chat_manager -import difflib from src.chat.person_info.relationship_manager import relationship_manager from .base_processor import BaseProcessor from src.chat.focus_chat.info.mind_info import MindInfo @@ -202,7 +200,6 @@ class MindProcessor(BaseProcessor): for person in person_list: relation_prompt += await relationship_manager.build_relationship_info(person, is_id=True) - template_name = "sub_heartflow_prompt_before" if is_group_chat else "sub_heartflow_prompt_private_before" logger.debug(f"{self.log_prefix} 使用{'群聊' if is_group_chat else '私聊'}思考模板") @@ -218,7 +215,6 @@ class MindProcessor(BaseProcessor): chat_target_name=chat_target_name, ) - content = "(不知道该想些什么...)" try: content, _ = await self.llm_model.generate_response_async(prompt=prompt) diff --git a/src/chat/focus_chat/info_processors/self_processor.py b/src/chat/focus_chat/info_processors/self_processor.py index 923c38c35..19876c93c 100644 --- a/src/chat/focus_chat/info_processors/self_processor.py +++ b/src/chat/focus_chat/info_processors/self_processor.py @@ -6,14 +6,10 @@ import time import traceback from src.common.logger_manager import get_logger from src.individuality.individuality import Individuality -import random from src.chat.utils.prompt_builder import Prompt, global_prompt_manager -from src.chat.utils.json_utils import safe_json_dumps from src.chat.message_receive.chat_stream import chat_manager -import difflib from src.chat.person_info.relationship_manager import relationship_manager from .base_processor import BaseProcessor -from src.chat.focus_chat.info.mind_info import MindInfo from typing import List, Optional from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation from typing import Dict @@ -44,7 +40,6 @@ def init_prompt(): Prompt(indentify_prompt, "indentify_prompt") - class SelfProcessor(BaseProcessor): log_prefix = "自我认同" @@ -63,7 +58,6 @@ class SelfProcessor(BaseProcessor): name = chat_manager.get_stream_name(self.subheartflow_id) self.log_prefix = f"[{name}] " - async def process_info( self, observations: Optional[List[Observation]] = None, running_memorys: Optional[List[Dict]] = None, *infos ) -> List[InfoBase]: @@ -76,7 +70,7 @@ class SelfProcessor(BaseProcessor): List[InfoBase]: 处理后的结构化信息列表 """ self_info_str = await self.self_indentify(observations, running_memorys) - + if self_info_str: self_info = SelfInfo() self_info.set_self_info(self_info_str) @@ -102,14 +96,12 @@ class SelfProcessor(BaseProcessor): tuple: (current_mind, past_mind, prompt) 当前想法、过去的想法列表和使用的prompt """ - memory_str = "" if running_memorys: memory_str = "以下是当前在聊天中,你回忆起的记忆:\n" for running_memory in running_memorys: memory_str += f"{running_memory['topic']}: {running_memory['content']}\n" - if observations is None: observations = [] for observation in observations: @@ -127,8 +119,8 @@ class SelfProcessor(BaseProcessor): chat_observe_info = observation.get_observe_info() person_list = observation.person_list if isinstance(observation, HFCloopObservation): - hfcloop_observe_info = observation.get_observe_info() - + # hfcloop_observe_info = observation.get_observe_info() + pass individuality = Individuality.get_instance() personality_block = individuality.get_prompt(x_person=2, level=2) @@ -137,7 +129,6 @@ class SelfProcessor(BaseProcessor): for person in person_list: relation_prompt += await relationship_manager.build_relationship_info(person, is_id=True) - prompt = (await global_prompt_manager.get_prompt_async("indentify_prompt")).format( bot_name=individuality.name, prompt_personality=personality_block, @@ -147,7 +138,6 @@ class SelfProcessor(BaseProcessor): chat_observe_info=chat_observe_info, ) - content = "" try: content, _ = await self.llm_model.generate_response_async(prompt=prompt) @@ -159,7 +149,7 @@ class SelfProcessor(BaseProcessor): logger.error(traceback.format_exc()) content = "自我识别过程中出现错误" - if content == 'None': + if content == "None": content = "" # 记录初步思考结果 logger.debug(f"{self.log_prefix} 自我识别prompt: \n{prompt}\n") @@ -168,5 +158,4 @@ class SelfProcessor(BaseProcessor): return content - init_prompt() diff --git a/src/chat/focus_chat/info_processors/tool_processor.py b/src/chat/focus_chat/info_processors/tool_processor.py index 563621e03..92c1b607a 100644 --- a/src/chat/focus_chat/info_processors/tool_processor.py +++ b/src/chat/focus_chat/info_processors/tool_processor.py @@ -4,7 +4,7 @@ from src.config.config import global_config import time from src.common.logger_manager import get_logger from src.individuality.individuality import Individuality -from src.chat.utils.prompt_builder import Prompt, global_prompt_manager +from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.tools.tool_use import ToolUser from src.chat.utils.json_utils import process_llm_tool_calls from src.chat.person_info.relationship_manager import relationship_manager @@ -68,7 +68,7 @@ class ToolProcessor(BaseProcessor): """ working_infos = [] - + if observations: for observation in observations: if isinstance(observation, ChattingObservation): @@ -134,7 +134,7 @@ class ToolProcessor(BaseProcessor): # 获取个性信息 individuality = Individuality.get_instance() - prompt_personality = individuality.get_prompt(x_person=2, level=2) + # prompt_personality = individuality.get_prompt(x_person=2, level=2) # 获取时间信息 time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) diff --git a/src/chat/focus_chat/info_processors/working_memory_processor.py b/src/chat/focus_chat/info_processors/working_memory_processor.py index 25041af8a..c682da699 100644 --- a/src/chat/focus_chat/info_processors/working_memory_processor.py +++ b/src/chat/focus_chat/info_processors/working_memory_processor.py @@ -5,17 +5,11 @@ from src.config.config import global_config import time import traceback from src.common.logger_manager import get_logger -from src.individuality.individuality import Individuality -import random from src.chat.utils.prompt_builder import Prompt, global_prompt_manager -from src.chat.utils.json_utils import safe_json_dumps from src.chat.message_receive.chat_stream import chat_manager -import difflib -from src.chat.person_info.relationship_manager import relationship_manager from .base_processor import BaseProcessor from src.chat.focus_chat.info.mind_info import MindInfo from typing import List, Optional -from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation from src.chat.heart_flow.observation.working_observation import WorkingMemoryObservation from src.chat.focus_chat.working_memory.working_memory import WorkingMemory from typing import Dict @@ -76,8 +70,6 @@ class WorkingMemoryProcessor(BaseProcessor): name = chat_manager.get_stream_name(self.subheartflow_id) self.log_prefix = f"[{name}] " - - async def process_info( self, observations: Optional[List[Observation]] = None, running_memorys: Optional[List[Dict]] = None, *infos ) -> List[InfoBase]: @@ -95,11 +87,11 @@ class WorkingMemoryProcessor(BaseProcessor): for observation in observations: if isinstance(observation, WorkingMemoryObservation): working_memory = observation.get_observe_info() - working_memory_obs = observation + # working_memory_obs = observation if isinstance(observation, ChattingObservation): chat_info = observation.get_observe_info() # chat_info_truncate = observation.talking_message_str_truncate - + if not working_memory: logger.warning(f"{self.log_prefix} 没有找到工作记忆对象") mind_info = MindInfo() @@ -108,44 +100,42 @@ class WorkingMemoryProcessor(BaseProcessor): logger.error(f"{self.log_prefix} 处理观察时出错: {e}") logger.error(traceback.format_exc()) return [] - + all_memory = working_memory.get_all_memories() memory_prompts = [] for memory in all_memory: - memory_content = memory.data + # memory_content = memory.data memory_summary = memory.summary memory_id = memory.id memory_brief = memory_summary.get("brief") - memory_detailed = memory_summary.get("detailed") + # memory_detailed = memory_summary.get("detailed") memory_keypoints = memory_summary.get("keypoints") memory_events = memory_summary.get("events") memory_single_prompt = f"记忆id:{memory_id},记忆摘要:{memory_brief}\n" memory_prompts.append(memory_single_prompt) - + memory_choose_str = "".join(memory_prompts) - + # 使用提示模板进行处理 prompt = (await global_prompt_manager.get_prompt_async("prompt_memory_proces")).format( bot_name=global_config.BOT_NICKNAME, time_now=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), chat_observe_info=chat_info, - memory_str=memory_choose_str + memory_str=memory_choose_str, ) - + # 调用LLM处理记忆 content = "" try: - logger.debug(f"{self.log_prefix} 处理工作记忆的prompt: {prompt}") - - + content, _ = await self.llm_model.generate_response_async(prompt=prompt) if not content: logger.warning(f"{self.log_prefix} LLM返回空结果,处理工作记忆失败。") except Exception as e: logger.error(f"{self.log_prefix} 执行LLM请求或处理响应时出错: {e}") logger.error(traceback.format_exc()) - + # 解析LLM返回的JSON try: result = repair_json(content) @@ -154,7 +144,7 @@ class WorkingMemoryProcessor(BaseProcessor): if not isinstance(result, dict): logger.error(f"{self.log_prefix} 解析LLM返回的JSON失败,结果不是字典类型: {type(result)}") return [] - + selected_memory_ids = result.get("selected_memory_ids", []) new_memory = result.get("new_memory", "") merge_memory = result.get("merge_memory", []) @@ -162,20 +152,20 @@ class WorkingMemoryProcessor(BaseProcessor): logger.error(f"{self.log_prefix} 解析LLM返回的JSON失败: {e}") logger.error(traceback.format_exc()) return [] - + logger.debug(f"{self.log_prefix} 解析LLM返回的JSON成功: {result}") - + # 根据selected_memory_ids,调取记忆 memory_str = "" if selected_memory_ids: for memory_id in selected_memory_ids: memory = await working_memory.retrieve_memory(memory_id) if memory: - memory_content = memory.data + # memory_content = memory.data memory_summary = memory.summary memory_id = memory.id memory_brief = memory_summary.get("brief") - memory_detailed = memory_summary.get("detailed") + # memory_detailed = memory_summary.get("detailed") memory_keypoints = memory_summary.get("keypoints") memory_events = memory_summary.get("events") for keypoint in memory_keypoints: @@ -184,21 +174,20 @@ class WorkingMemoryProcessor(BaseProcessor): memory_str += f"记忆事件:{event}\n" # memory_str += f"记忆摘要:{memory_detailed}\n" # memory_str += f"记忆主题:{memory_brief}\n" - - + working_memory_info = WorkingMemoryInfo() if memory_str: working_memory_info.add_working_memory(memory_str) logger.debug(f"{self.log_prefix} 取得工作记忆: {memory_str}") else: logger.warning(f"{self.log_prefix} 没有找到工作记忆") - + # 根据聊天内容添加新记忆 if new_memory: # 使用异步方式添加新记忆,不阻塞主流程 logger.debug(f"{self.log_prefix} {new_memory}新记忆: ") asyncio.create_task(self.add_memory_async(working_memory, chat_info)) - + if merge_memory: for merge_pairs in merge_memory: memory1 = await working_memory.retrieve_memory(merge_pairs[0]) @@ -207,12 +196,12 @@ class WorkingMemoryProcessor(BaseProcessor): memory_str = f"记忆id:{memory1.id},记忆摘要:{memory1.summary.get('brief')}\n" memory_str += f"记忆id:{memory2.id},记忆摘要:{memory2.summary.get('brief')}\n" asyncio.create_task(self.merge_memory_async(working_memory, merge_pairs[0], merge_pairs[1])) - + return [working_memory_info] async def add_memory_async(self, working_memory: WorkingMemory, content: str): """异步添加记忆,不阻塞主流程 - + Args: working_memory: 工作记忆对象 content: 记忆内容 @@ -223,10 +212,10 @@ class WorkingMemoryProcessor(BaseProcessor): except Exception as e: logger.error(f"{self.log_prefix} 异步添加新记忆失败: {e}") logger.error(traceback.format_exc()) - + async def merge_memory_async(self, working_memory: WorkingMemory, memory_id1: str, memory_id2: str): """异步合并记忆,不阻塞主流程 - + Args: working_memory: 工作记忆对象 memory_str: 记忆内容 @@ -238,7 +227,7 @@ class WorkingMemoryProcessor(BaseProcessor): logger.debug(f"{self.log_prefix} 合并后的记忆详情: {merged_memory.summary.get('detailed')}") logger.debug(f"{self.log_prefix} 合并后的记忆要点: {merged_memory.summary.get('keypoints')}") logger.debug(f"{self.log_prefix} 合并后的记忆事件: {merged_memory.summary.get('events')}") - + except Exception as e: logger.error(f"{self.log_prefix} 异步合并记忆失败: {e}") logger.error(traceback.format_exc()) diff --git a/src/chat/focus_chat/planners/action_manager.py b/src/chat/focus_chat/planners/action_manager.py index 02c77c2b6..2ee7f349d 100644 --- a/src/chat/focus_chat/planners/action_manager.py +++ b/src/chat/focus_chat/planners/action_manager.py @@ -37,7 +37,7 @@ class ActionManager: # 加载所有已注册动作 self._load_registered_actions() - + # 加载插件动作 self._load_plugin_actions() @@ -52,11 +52,11 @@ class ActionManager: # 从_ACTION_REGISTRY获取所有已注册动作 for action_name, action_class in _ACTION_REGISTRY.items(): # 获取动作相关信息 - + # 不读取插件动作和基类 if action_name == "base_action" or action_name == "plugin_action": continue - + action_description: str = getattr(action_class, "action_description", "") action_parameters: dict[str:str] = getattr(action_class, "action_parameters", {}) action_require: list[str] = getattr(action_class, "action_require", []) @@ -80,11 +80,11 @@ class ActionManager: # logger.info(f"所有注册动作: {list(self._registered_actions.keys())}") # logger.info(f"默认动作: {list(self._default_actions.keys())}") # for action_name, action_info in self._default_actions.items(): - # logger.info(f"动作名称: {action_name}, 动作信息: {action_info}") + # logger.info(f"动作名称: {action_name}, 动作信息: {action_info}") except Exception as e: logger.error(f"加载已注册动作失败: {e}") - + def _load_plugin_actions(self) -> None: """ 加载所有插件目录中的动作 @@ -92,23 +92,25 @@ class ActionManager: try: # 检查插件目录是否存在 plugin_path = "src.plugins" - plugin_dir = plugin_path.replace('.', os.path.sep) + plugin_dir = plugin_path.replace(".", os.path.sep) if not os.path.exists(plugin_dir): logger.info(f"插件目录 {plugin_dir} 不存在,跳过插件动作加载") return - + # 导入插件包 try: plugins_package = importlib.import_module(plugin_path) except ImportError as e: logger.error(f"导入插件包失败: {e}") return - + # 遍历插件包中的所有子包 - for _, plugin_name, is_pkg in pkgutil.iter_modules(plugins_package.__path__, plugins_package.__name__ + '.'): + for _, plugin_name, is_pkg in pkgutil.iter_modules( + plugins_package.__path__, plugins_package.__name__ + "." + ): if not is_pkg: continue - + # 检查插件是否有actions子包 plugin_actions_path = f"{plugin_name}.actions" try: @@ -118,10 +120,10 @@ class ActionManager: except ImportError as e: logger.debug(f"插件 {plugin_name} 没有actions子包或导入失败: {e}") continue - + # 再次从_ACTION_REGISTRY获取所有动作(包括刚刚从插件加载的) self._load_registered_actions() - + except Exception as e: logger.error(f"加载插件动作失败: {e}") @@ -316,4 +318,3 @@ class ActionManager: Optional[Type[BaseAction]]: 动作处理器类,如果不存在则返回None """ return _ACTION_REGISTRY.get(action_name) - diff --git a/src/chat/focus_chat/planners/actions/__init__.py b/src/chat/focus_chat/planners/actions/__init__.py index 435d0d4b4..3f2baf665 100644 --- a/src/chat/focus_chat/planners/actions/__init__.py +++ b/src/chat/focus_chat/planners/actions/__init__.py @@ -2,4 +2,4 @@ from . import reply_action # noqa from . import no_reply_action # noqa -# 在此处添加更多动作模块导入 \ No newline at end of file +# 在此处添加更多动作模块导入 diff --git a/src/chat/focus_chat/planners/actions/no_reply_action.py b/src/chat/focus_chat/planners/actions/no_reply_action.py index 406ddbdc2..c6852fbe1 100644 --- a/src/chat/focus_chat/planners/actions/no_reply_action.py +++ b/src/chat/focus_chat/planners/actions/no_reply_action.py @@ -94,8 +94,7 @@ class NoReplyAction(BaseAction): # 等待新消息、超时或关闭信号,并获取结果 await self._wait_for_new_message(observation, self.thinking_id, self.log_prefix) # 从计时器获取实际等待时间 - current_waiting = self.cycle_timers.get("等待新消息", 0.0) - + _current_waiting = self.cycle_timers.get("等待新消息", 0.0) return True, "" # 不回复动作没有回复文本 diff --git a/src/chat/focus_chat/planners/actions/plugin_action.py b/src/chat/focus_chat/planners/actions/plugin_action.py index aec879e97..5e8ddd998 100644 --- a/src/chat/focus_chat/planners/actions/plugin_action.py +++ b/src/chat/focus_chat/planners/actions/plugin_action.py @@ -1,6 +1,6 @@ import traceback from typing import Tuple, Dict, List, Any, Optional -from src.chat.focus_chat.planners.actions.base_action import BaseAction, register_action +from src.chat.focus_chat.planners.actions.base_action import BaseAction from src.chat.heart_flow.observation.chatting_observation import ChattingObservation from src.chat.focus_chat.hfc_utils import create_empty_anchor_message from src.common.logger_manager import get_logger @@ -9,19 +9,20 @@ from abc import abstractmethod logger = get_logger("plugin_action") + class PluginAction(BaseAction): """插件动作基类 - + 封装了主程序内部依赖,提供简化的API接口给插件开发者 """ - + def __init__(self, action_data: dict, reasoning: str, cycle_timers: dict, thinking_id: str, **kwargs): """初始化插件动作基类""" super().__init__(action_data, reasoning, cycle_timers, thinking_id) - + # 存储内部服务和对象引用 self._services = {} - + # 从kwargs提取必要的内部服务 if "observations" in kwargs: self._services["observations"] = kwargs["observations"] @@ -31,48 +32,43 @@ class PluginAction(BaseAction): self._services["chat_stream"] = kwargs["chat_stream"] if "current_cycle" in kwargs: self._services["current_cycle"] = kwargs["current_cycle"] - + self.log_prefix = kwargs.get("log_prefix", "") - + async def get_user_id_by_person_name(self, person_name: str) -> Tuple[str, str]: """根据用户名获取用户ID""" person_id = person_info_manager.get_person_id_by_person_name(person_name) user_id = await person_info_manager.get_value(person_id, "user_id") platform = await person_info_manager.get_value(person_id, "platform") return platform, user_id - + # 提供简化的API方法 async def send_message(self, text: str, target: Optional[str] = None) -> bool: """发送消息的简化方法 - + Args: text: 要发送的消息文本 target: 目标消息(可选) - + Returns: bool: 是否发送成功 """ try: expressor = self._services.get("expressor") chat_stream = self._services.get("chat_stream") - + if not expressor or not chat_stream: logger.error(f"{self.log_prefix} 无法发送消息:缺少必要的内部服务") return False - + # 构造简化的动作数据 - reply_data = { - "text": text, - "target": target or "", - "emojis": [] - } - + reply_data = {"text": text, "target": target or "", "emojis": []} + # 获取锚定消息(如果有) observations = self._services.get("observations", []) chatting_observation: ChattingObservation = next( - obs for obs in observations - if isinstance(obs, ChattingObservation) + obs for obs in observations if isinstance(obs, ChattingObservation) ) anchor_message = chatting_observation.search_message_by_text(reply_data["target"]) @@ -84,55 +80,49 @@ class PluginAction(BaseAction): ) else: anchor_message.update_chat_stream(chat_stream) - + response_set = [ ("text", text), ] - + # 调用内部方法发送消息 success = await expressor.send_response_messages( anchor_message=anchor_message, response_set=response_set, ) - + return success except Exception as e: logger.error(f"{self.log_prefix} 发送消息时出错: {e}") traceback.print_exc() return False - - + async def send_message_by_expressor(self, text: str, target: Optional[str] = None) -> bool: """发送消息的简化方法 - + Args: text: 要发送的消息文本 target: 目标消息(可选) - + Returns: bool: 是否发送成功 """ try: expressor = self._services.get("expressor") chat_stream = self._services.get("chat_stream") - + if not expressor or not chat_stream: logger.error(f"{self.log_prefix} 无法发送消息:缺少必要的内部服务") return False - + # 构造简化的动作数据 - reply_data = { - "text": text, - "target": target or "", - "emojis": [] - } - + reply_data = {"text": text, "target": target or "", "emojis": []} + # 获取锚定消息(如果有) observations = self._services.get("observations", []) chatting_observation: ChattingObservation = next( - obs for obs in observations - if isinstance(obs, ChattingObservation) + obs for obs in observations if isinstance(obs, ChattingObservation) ) anchor_message = chatting_observation.search_message_by_text(reply_data["target"]) @@ -144,24 +134,24 @@ class PluginAction(BaseAction): ) else: anchor_message.update_chat_stream(chat_stream) - + # 调用内部方法发送消息 success, _ = await expressor.deal_reply( cycle_timers=self.cycle_timers, action_data=reply_data, anchor_message=anchor_message, reasoning=self.reasoning, - thinking_id=self.thinking_id + thinking_id=self.thinking_id, ) - + return success except Exception as e: logger.error(f"{self.log_prefix} 发送消息时出错: {e}") return False - + def get_chat_type(self) -> str: """获取当前聊天类型 - + Returns: str: 聊天类型 ("group" 或 "private") """ @@ -169,19 +159,19 @@ class PluginAction(BaseAction): if chat_stream and hasattr(chat_stream, "group_info"): return "group" if chat_stream.group_info else "private" return "unknown" - + def get_recent_messages(self, count: int = 5) -> List[Dict[str, Any]]: """获取最近的消息 - + Args: count: 要获取的消息数量 - + Returns: List[Dict]: 消息列表,每个消息包含发送者、内容等信息 """ messages = [] observations = self._services.get("observations", []) - + if observations and len(observations) > 0: obs = observations[0] if hasattr(obs, "get_talking_message"): @@ -191,24 +181,24 @@ class PluginAction(BaseAction): simple_msg = { "sender": msg.get("sender", "未知"), "content": msg.get("content", ""), - "timestamp": msg.get("timestamp", 0) + "timestamp": msg.get("timestamp", 0), } messages.append(simple_msg) - + return messages - + @abstractmethod async def process(self) -> Tuple[bool, str]: """插件处理逻辑,子类必须实现此方法 - + Returns: Tuple[bool, str]: (是否执行成功, 回复文本) """ pass - + async def handle_action(self) -> Tuple[bool, str]: """实现BaseAction的抽象方法,调用子类的process方法 - + Returns: Tuple[bool, str]: (是否执行成功, 回复文本) """ diff --git a/src/chat/focus_chat/planners/actions/reply_action.py b/src/chat/focus_chat/planners/actions/reply_action.py index 6452ecb0f..07e35b458 100644 --- a/src/chat/focus_chat/planners/actions/reply_action.py +++ b/src/chat/focus_chat/planners/actions/reply_action.py @@ -105,8 +105,7 @@ class ReplyAction(BaseAction): # 从聊天观察获取锚定消息 chatting_observation: ChattingObservation = next( - obs for obs in self.observations - if isinstance(obs, ChattingObservation) + obs for obs in self.observations if isinstance(obs, ChattingObservation) ) if reply_data.get("target"): anchor_message = chatting_observation.search_message_by_text(reply_data["target"]) diff --git a/src/chat/focus_chat/planners/planner.py b/src/chat/focus_chat/planners/planner.py index dba9d4b1a..21ca157f9 100644 --- a/src/chat/focus_chat/planners/planner.py +++ b/src/chat/focus_chat/planners/planner.py @@ -109,7 +109,7 @@ class ActionPlanner: cycle_info = info.get_observe_info() elif isinstance(info, StructuredInfo): # logger.debug(f"{self.log_prefix} 结构化信息: {info}") - structured_info = info.get_data() + _structured_info = info.get_data() else: logger.debug(f"{self.log_prefix} 其他信息: {info}") extra_info.append(info.get_processed_info()) @@ -157,7 +157,7 @@ class ActionPlanner: for key, value in parsed_json.items(): if key not in ["action", "reasoning"]: action_data[key] = value - + # 对于reply动作不需要额外处理,因为相关字段已经在上面的循环中添加到action_data if extracted_action not in current_available_actions: diff --git a/src/chat/focus_chat/working_memory/memory_item.py b/src/chat/focus_chat/working_memory/memory_item.py index f922eff8f..15724a387 100644 --- a/src/chat/focus_chat/working_memory/memory_item.py +++ b/src/chat/focus_chat/working_memory/memory_item.py @@ -1,23 +1,16 @@ -from typing import Dict, Any, Type, TypeVar, Generic, List, Optional, Callable, Set, Tuple +from typing import Dict, Any, List, Optional, Set, Tuple import time -import uuid -import traceback import random import string -from json_repair import repair_json -from rich.traceback import install -from src.common.logger_manager import get_logger -from src.chat.models.utils_model import LLMRequest -from src.config.config import global_config class MemoryItem: """记忆项类,用于存储单个记忆的所有相关信息""" - + def __init__(self, data: Any, from_source: str = "", tags: Optional[List[str]] = None): """ 初始化记忆项 - + Args: data: 记忆数据 from_source: 数据来源 @@ -25,7 +18,7 @@ class MemoryItem: """ # 生成可读ID:时间戳_随机字符串 timestamp = int(time.time()) - random_str = ''.join(random.choices(string.ascii_lowercase + string.digits, k=2)) + random_str = "".join(random.choices(string.ascii_lowercase + string.digits, k=2)) self.id = f"{timestamp}_{random_str}" self.data = data self.data_type = type(data) @@ -40,63 +33,63 @@ class MemoryItem: # "events": ["事件1", "事件2"] # } self.summary = None - + # 记忆精简次数 self.compress_count = 0 - + # 记忆提取次数 self.retrieval_count = 0 - + # 记忆强度 (初始为10) self.memory_strength = 10.0 - + # 记忆操作历史记录 # 格式: [(操作类型, 时间戳, 当时精简次数, 当时强度), ...] self.history = [("create", self.timestamp, self.compress_count, self.memory_strength)] - + def add_tag(self, tag: str) -> None: """添加标签""" self.tags.add(tag) - + def remove_tag(self, tag: str) -> None: """移除标签""" if tag in self.tags: self.tags.remove(tag) - + def has_tag(self, tag: str) -> bool: """检查是否有特定标签""" return tag in self.tags - + def has_all_tags(self, tags: List[str]) -> bool: """检查是否有所有指定的标签""" return all(tag in self.tags for tag in tags) - + def matches_source(self, source: str) -> bool: """检查来源是否匹配""" return self.from_source == source - + def set_summary(self, summary: Dict[str, Any]) -> None: """设置总结信息""" self.summary = summary - + def increase_strength(self, amount: float) -> None: """增加记忆强度""" self.memory_strength = min(10.0, self.memory_strength + amount) # 记录操作历史 self.record_operation("strengthen") - + def decrease_strength(self, amount: float) -> None: """减少记忆强度""" self.memory_strength = max(0.1, self.memory_strength - amount) # 记录操作历史 self.record_operation("weaken") - + def increase_compress_count(self) -> None: """增加精简次数并减弱记忆强度""" self.compress_count += 1 # 记录操作历史 self.record_operation("compress") - + def record_retrieval(self) -> None: """记录记忆被提取的情况""" self.retrieval_count += 1 @@ -104,16 +97,16 @@ class MemoryItem: self.memory_strength = min(10.0, self.memory_strength * 2) # 记录操作历史 self.record_operation("retrieval") - + def record_operation(self, operation_type: str) -> None: """记录操作历史""" current_time = time.time() self.history.append((operation_type, current_time, self.compress_count, self.memory_strength)) - + def to_tuple(self) -> Tuple[Any, str, Set[str], float, str]: """转换为元组格式(为了兼容性)""" return (self.data, self.from_source, self.tags, self.timestamp, self.id) - + def is_memory_valid(self) -> bool: """检查记忆是否有效(强度是否大于等于1)""" - return self.memory_strength >= 1.0 \ No newline at end of file + return self.memory_strength >= 1.0 diff --git a/src/chat/focus_chat/working_memory/memory_manager.py b/src/chat/focus_chat/working_memory/memory_manager.py index d99488378..7154fe48c 100644 --- a/src/chat/focus_chat/working_memory/memory_manager.py +++ b/src/chat/focus_chat/working_memory/memory_manager.py @@ -1,6 +1,4 @@ -from typing import Dict, Any, Type, TypeVar, Generic, List, Optional, Callable, Set, Tuple -import time -import uuid +from typing import Dict, Any, Type, TypeVar, List, Optional import traceback from json_repair import repair_json from rich.traceback import install @@ -14,74 +12,71 @@ import json # 添加json模块导入 install(extra_lines=3) logger = get_logger("working_memory") -T = TypeVar('T') +T = TypeVar("T") class MemoryManager: def __init__(self, chat_id: str): """ 初始化工作记忆 - + Args: chat_id: 关联的聊天ID,用于标识该工作记忆属于哪个聊天 """ # 关联的聊天ID self._chat_id = chat_id - + # 主存储: 数据类型 -> 记忆项列表 self._memory: Dict[Type, List[MemoryItem]] = {} - + # ID到记忆项的映射 self._id_map: Dict[str, MemoryItem] = {} - + self.llm_summarizer = LLMRequest( - model=global_config.llm_summary, - temperature=0.3, - max_tokens=512, - request_type="memory_summarization" + model=global_config.llm_summary, temperature=0.3, max_tokens=512, request_type="memory_summarization" ) - + @property def chat_id(self) -> str: """获取关联的聊天ID""" return self._chat_id - + @chat_id.setter def chat_id(self, value: str): """设置关联的聊天ID""" self._chat_id = value - + def push_item(self, memory_item: MemoryItem) -> str: """ 推送一个已创建的记忆项到工作记忆中 - + Args: memory_item: 要存储的记忆项 - + Returns: 记忆项的ID """ data_type = memory_item.data_type - + # 确保存在该类型的存储列表 if data_type not in self._memory: self._memory[data_type] = [] - + # 添加到内存和ID映射 self._memory[data_type].append(memory_item) self._id_map[memory_item.id] = memory_item - + return memory_item.id - + async def push_with_summary(self, data: T, from_source: str = "", tags: Optional[List[str]] = None) -> MemoryItem: """ 推送一段有类型的信息到工作记忆中,并自动生成总结 - + Args: data: 要存储的数据 from_source: 数据来源 tags: 数据标签列表 - + Returns: 包含原始数据和总结信息的字典 """ @@ -89,65 +84,66 @@ class MemoryManager: if isinstance(data, str): # 先生成总结 summary = await self.summarize_memory_item(data) - + # 准备标签 memory_tags = list(tags) if tags else [] - + # 创建记忆项 memory_item = MemoryItem(data, from_source, memory_tags) - + # 将总结信息保存到记忆项中 memory_item.set_summary(summary) - + # 推送记忆项 self.push_item(memory_item) - + return memory_item else: # 非字符串类型,直接创建并推送记忆项 memory_item = MemoryItem(data, from_source, tags) self.push_item(memory_item) - + return memory_item - + def get_by_id(self, memory_id: str) -> Optional[MemoryItem]: """ 通过ID获取记忆项 - + Args: memory_id: 记忆项ID - + Returns: 找到的记忆项,如果不存在则返回None """ memory_item = self._id_map.get(memory_id) if memory_item: - # 检查记忆强度,如果小于1则删除 if not memory_item.is_memory_valid(): print(f"记忆 {memory_id} 强度过低 ({memory_item.memory_strength}),已自动移除") self.delete(memory_id) return None - + return memory_item - + def get_all_items(self) -> List[MemoryItem]: """获取所有记忆项""" return list(self._id_map.values()) - - def find_items(self, - data_type: Optional[Type] = None, - source: Optional[str] = None, - tags: Optional[List[str]] = None, - start_time: Optional[float] = None, - end_time: Optional[float] = None, - memory_id: Optional[str] = None, - limit: Optional[int] = None, - newest_first: bool = False, - min_strength: float = 0.0) -> List[MemoryItem]: + + def find_items( + self, + data_type: Optional[Type] = None, + source: Optional[str] = None, + tags: Optional[List[str]] = None, + start_time: Optional[float] = None, + end_time: Optional[float] = None, + memory_id: Optional[str] = None, + limit: Optional[int] = None, + newest_first: bool = False, + min_strength: float = 0.0, + ) -> List[MemoryItem]: """ 按条件查找记忆项 - + Args: data_type: 要查找的数据类型 source: 数据来源 @@ -158,7 +154,7 @@ class MemoryManager: limit: 返回结果的最大数量 newest_first: 是否按最新优先排序 min_strength: 最小记忆强度 - + Returns: 符合条件的记忆项列表 """ @@ -166,62 +162,62 @@ class MemoryManager: if memory_id: item = self.get_by_id(memory_id) return [item] if item else [] - + results = [] - + # 确定要搜索的类型列表 types_to_search = [data_type] if data_type else list(self._memory.keys()) - + # 对每个类型进行搜索 for typ in types_to_search: if typ not in self._memory: continue - + # 获取该类型的所有项目 items = self._memory[typ] - + # 如果需要最新优先,则反转遍历顺序 if newest_first: items_to_check = list(reversed(items)) else: items_to_check = items - + # 遍历项目 for item in items_to_check: # 检查来源是否匹配 if source is not None and not item.matches_source(source): continue - + # 检查标签是否匹配 if tags is not None and not item.has_all_tags(tags): continue - + # 检查时间范围 if start_time is not None and item.timestamp < start_time: continue if end_time is not None and item.timestamp > end_time: continue - + # 检查记忆强度 if min_strength > 0 and item.memory_strength < min_strength: continue - + # 所有条件都满足,添加到结果中 results.append(item) - + # 如果达到限制数量,提前返回 if limit is not None and len(results) >= limit: return results - + return results - + async def summarize_memory_item(self, content: str) -> Dict[str, Any]: """ 使用LLM总结记忆项 - + Args: content: 需要总结的内容 - + Returns: 包含总结、概括、关键概念和事件的字典 """ @@ -257,18 +253,18 @@ class MemoryManager: "brief": "主题未知的记忆", "detailed": "大致内容未知的记忆", "keypoints": ["未知的概念"], - "events": ["未知的事件"] + "events": ["未知的事件"], } - + try: # 调用LLM生成总结 response, _ = await self.llm_summarizer.generate_response_async(prompt) - + # 使用repair_json解析响应 try: # 使用repair_json修复JSON格式 fixed_json_string = repair_json(response) - + # 如果repair_json返回的是字符串,需要解析为Python对象 if isinstance(fixed_json_string, str): try: @@ -279,68 +275,60 @@ class MemoryManager: else: # 如果repair_json直接返回了字典对象,直接使用 json_result = fixed_json_string - + # 进行额外的类型检查 if not isinstance(json_result, dict): logger.error(f"修复后的JSON不是字典类型: {type(json_result)}") return default_summary - + # 确保所有必要字段都存在且类型正确 if "brief" not in json_result or not isinstance(json_result["brief"], str): json_result["brief"] = "主题未知的记忆" - + if "detailed" not in json_result or not isinstance(json_result["detailed"], str): json_result["detailed"] = "大致内容未知的记忆" - + # 处理关键概念 if "keypoints" not in json_result or not isinstance(json_result["keypoints"], list): json_result["keypoints"] = ["未知的概念"] else: # 确保keypoints中的每个项目都是字符串 - json_result["keypoints"] = [ - str(point) for point in json_result["keypoints"] - if point is not None - ] + json_result["keypoints"] = [str(point) for point in json_result["keypoints"] if point is not None] if not json_result["keypoints"]: json_result["keypoints"] = ["未知的概念"] - + # 处理事件 if "events" not in json_result or not isinstance(json_result["events"], list): json_result["events"] = ["未知的事件"] else: # 确保events中的每个项目都是字符串 - json_result["events"] = [ - str(event) for event in json_result["events"] - if event is not None - ] + json_result["events"] = [str(event) for event in json_result["events"] if event is not None] if not json_result["events"]: json_result["events"] = ["未知的事件"] - + # 兼容旧版,将keypoints和events合并到key_points中 json_result["key_points"] = json_result["keypoints"] + json_result["events"] - + return json_result - + except Exception as json_error: logger.error(f"JSON处理失败: {str(json_error)},将使用默认摘要") # 返回默认结构 return default_summary - + except Exception as e: # 出错时返回简单的结构 logger.error(f"生成总结时出错: {str(e)}") return default_summary - - async def refine_memory(self, - memory_id: str, - requirements: str = "") -> Dict[str, Any]: + + async def refine_memory(self, memory_id: str, requirements: str = "") -> Dict[str, Any]: """ 对记忆进行精简操作,根据要求修改要点、总结和概括 - + Args: memory_id: 记忆ID requirements: 精简要求,描述如何修改记忆,包括可能需要移除的要点 - + Returns: 修改后的记忆总结字典 """ @@ -349,12 +337,12 @@ class MemoryManager: memory_item = self.get_by_id(memory_id) if not memory_item: raise ValueError(f"未找到ID为{memory_id}的记忆项") - + # 增加精简次数 memory_item.increase_compress_count() - + summary = memory_item.summary - + # 使用LLM根据要求对总结、概括和要点进行精简修改 prompt = f""" 请根据以下要求,对记忆内容的主题、概括、关键概念和事件进行精简,模拟记忆的遗忘过程: @@ -396,15 +384,15 @@ class MemoryManager: halfway = len(key_points) // 2 summary["keypoints"] = key_points[:halfway] or ["未知的概念"] summary["events"] = key_points[halfway:] or ["未知的事件"] - + # 定义默认的精简结果 default_refined = { "brief": summary["brief"], "detailed": summary["detailed"], "keypoints": summary.get("keypoints", ["未知的概念"])[:1], # 默认只保留第一个关键概念 - "events": summary.get("events", ["未知的事件"])[:1] # 默认只保留第一个事件 + "events": summary.get("events", ["未知的事件"])[:1], # 默认只保留第一个事件 } - + try: # 调用LLM修改总结、概括和要点 response, _ = await self.llm_summarizer.generate_response_async(prompt) @@ -413,7 +401,7 @@ class MemoryManager: try: # 修复JSON格式 fixed_json_string = repair_json(response) - + # 将修复后的字符串解析为Python对象 if isinstance(fixed_json_string, str): try: @@ -424,16 +412,16 @@ class MemoryManager: else: # 如果repair_json直接返回了字典对象,直接使用 refined_data = fixed_json_string - + # 确保是字典类型 if not isinstance(refined_data, dict): logger.error(f"修复后的JSON不是字典类型: {type(refined_data)}") refined_data = default_refined - + # 更新总结、概括 summary["brief"] = refined_data.get("brief", "主题未知的记忆") summary["detailed"] = refined_data.get("detailed", "大致内容未知的记忆") - + # 更新关键概念 keypoints = refined_data.get("keypoints", []) if isinstance(keypoints, list) and keypoints: @@ -442,7 +430,7 @@ class MemoryManager: else: # 如果keypoints不是列表或为空,使用默认值 summary["keypoints"] = ["主要概念已遗忘"] - + # 更新事件 events = refined_data.get("events", []) if isinstance(events, list) and events: @@ -451,84 +439,83 @@ class MemoryManager: else: # 如果events不是列表或为空,使用默认值 summary["events"] = ["事件细节已遗忘"] - + # 兼容旧版,维护key_points summary["key_points"] = summary["keypoints"] + summary["events"] - + except Exception as e: logger.error(f"精简记忆出错: {str(e)}") traceback.print_exc() - + # 出错时使用简化的默认精简 summary["brief"] = summary["brief"] + " (已简化)" summary["keypoints"] = summary.get("keypoints", ["未知的概念"])[:1] summary["events"] = summary.get("events", ["未知的事件"])[:1] summary["key_points"] = summary["keypoints"] + summary["events"] - + except Exception as e: logger.error(f"精简记忆调用LLM出错: {str(e)}") traceback.print_exc() - + # 更新原记忆项的总结 memory_item.set_summary(summary) - + return memory_item - + def decay_memory(self, memory_id: str, decay_factor: float = 0.8) -> bool: """ 使单个记忆衰减 - + Args: memory_id: 记忆ID decay_factor: 衰减因子(0-1之间) - + Returns: 是否成功衰减 """ memory_item = self.get_by_id(memory_id) if not memory_item: return False - + # 计算衰减量(当前强度 * (1-衰减因子)) old_strength = memory_item.memory_strength decay_amount = old_strength * (1 - decay_factor) - + # 更新强度 memory_item.memory_strength = decay_amount - + return True - - + def delete(self, memory_id: str) -> bool: """ 删除指定ID的记忆项 - + Args: memory_id: 要删除的记忆项ID - + Returns: 是否成功删除 """ if memory_id not in self._id_map: return False - + # 获取要删除的项 item = self._id_map[memory_id] - + # 从内存中删除 data_type = item.data_type if data_type in self._memory: self._memory[data_type] = [i for i in self._memory[data_type] if i.id != memory_id] - + # 从ID映射中删除 del self._id_map[memory_id] - + return True - + def clear(self, data_type: Optional[Type] = None) -> None: """ 清除记忆中的数据 - + Args: data_type: 要清除的数据类型,如果为None则清除所有数据 """ @@ -542,34 +529,36 @@ class MemoryManager: if item.id in self._id_map: del self._id_map[item.id] del self._memory[data_type] - - async def merge_memories(self, memory_id1: str, memory_id2: str, reason: str, delete_originals: bool = True) -> MemoryItem: + + async def merge_memories( + self, memory_id1: str, memory_id2: str, reason: str, delete_originals: bool = True + ) -> MemoryItem: """ 合并两个记忆项 - + Args: memory_id1: 第一个记忆项ID memory_id2: 第二个记忆项ID reason: 合并原因 delete_originals: 是否删除原始记忆,默认为True - + Returns: 包含合并后的记忆信息的字典 """ # 获取两个记忆项 memory_item1 = self.get_by_id(memory_id1) memory_item2 = self.get_by_id(memory_id2) - + if not memory_item1 or not memory_item2: raise ValueError("无法找到指定的记忆项") - + content1 = memory_item1.data content2 = memory_item2.data - + # 获取记忆的摘要信息(如果有) summary1 = memory_item1.summary summary2 = memory_item2.summary - + # 构建合并提示 prompt = f""" 请根据以下原因,将两段记忆内容有机合并成一段新的记忆内容。 @@ -577,32 +566,32 @@ class MemoryManager: 合并原因:{reason} """ - + # 如果有摘要信息,添加到提示中 if summary1: prompt += f"记忆1主题:{summary1['brief']}\n" prompt += f"记忆1概括:{summary1['detailed']}\n" - + if "keypoints" in summary1: - prompt += f"记忆1关键概念:\n" + "\n".join([f"- {point}" for point in summary1['keypoints']]) + "\n\n" - + prompt += "记忆1关键概念:\n" + "\n".join([f"- {point}" for point in summary1["keypoints"]]) + "\n\n" + if "events" in summary1: - prompt += f"记忆1事件:\n" + "\n".join([f"- {point}" for point in summary1['events']]) + "\n\n" + prompt += "记忆1事件:\n" + "\n".join([f"- {point}" for point in summary1["events"]]) + "\n\n" elif "key_points" in summary1: - prompt += f"记忆1要点:\n" + "\n".join([f"- {point}" for point in summary1['key_points']]) + "\n\n" - + prompt += "记忆1要点:\n" + "\n".join([f"- {point}" for point in summary1["key_points"]]) + "\n\n" + if summary2: prompt += f"记忆2主题:{summary2['brief']}\n" prompt += f"记忆2概括:{summary2['detailed']}\n" - + if "keypoints" in summary2: - prompt += f"记忆2关键概念:\n" + "\n".join([f"- {point}" for point in summary2['keypoints']]) + "\n\n" - + prompt += "记忆2关键概念:\n" + "\n".join([f"- {point}" for point in summary2["keypoints"]]) + "\n\n" + if "events" in summary2: - prompt += f"记忆2事件:\n" + "\n".join([f"- {point}" for point in summary2['events']]) + "\n\n" + prompt += "记忆2事件:\n" + "\n".join([f"- {point}" for point in summary2["events"]]) + "\n\n" elif "key_points" in summary2: - prompt += f"记忆2要点:\n" + "\n".join([f"- {point}" for point in summary2['key_points']]) + "\n\n" - + prompt += "记忆2要点:\n" + "\n".join([f"- {point}" for point in summary2["key_points"]]) + "\n\n" + # 添加记忆原始内容 prompt += f""" 记忆1原始内容: @@ -630,16 +619,16 @@ class MemoryManager: ``` 请确保输出是有效的JSON格式,不要添加任何额外的说明或解释。 """ - + # 默认合并结果 default_merged = { "content": f"{content1}\n\n{content2}", "brief": f"合并:{summary1['brief']} + {summary2['brief']}", "detailed": f"合并了两个记忆:{summary1['detailed']} 以及 {summary2['detailed']}", "keypoints": [], - "events": [] + "events": [], } - + # 合并旧版key_points if "key_points" in summary1: default_merged["keypoints"].extend(summary1.get("keypoints", [])) @@ -650,7 +639,7 @@ class MemoryManager: halfway = len(key_points) // 2 default_merged["keypoints"].extend(key_points[:halfway]) default_merged["events"].extend(key_points[halfway:]) - + if "key_points" in summary2: default_merged["keypoints"].extend(summary2.get("keypoints", [])) default_merged["events"].extend(summary2.get("events", [])) @@ -660,25 +649,25 @@ class MemoryManager: halfway = len(key_points) // 2 default_merged["keypoints"].extend(key_points[:halfway]) default_merged["events"].extend(key_points[halfway:]) - + # 确保列表不为空 if not default_merged["keypoints"]: default_merged["keypoints"] = ["合并的关键概念"] if not default_merged["events"]: default_merged["events"] = ["合并的事件"] - + # 添加key_points兼容 default_merged["key_points"] = default_merged["keypoints"] + default_merged["events"] - + try: # 调用LLM合并记忆 response, _ = await self.llm_summarizer.generate_response_async(prompt) - + # 处理LLM返回的合并结果 try: # 修复JSON格式 fixed_json_string = repair_json(response) - + # 将修复后的字符串解析为Python对象 if isinstance(fixed_json_string, str): try: @@ -689,49 +678,43 @@ class MemoryManager: else: # 如果repair_json直接返回了字典对象,直接使用 merged_data = fixed_json_string - + # 确保是字典类型 if not isinstance(merged_data, dict): logger.error(f"修复后的JSON不是字典类型: {type(merged_data)}") merged_data = default_merged - + # 确保所有必要字段都存在且类型正确 if "content" not in merged_data or not isinstance(merged_data["content"], str): merged_data["content"] = default_merged["content"] - + if "brief" not in merged_data or not isinstance(merged_data["brief"], str): merged_data["brief"] = default_merged["brief"] - + if "detailed" not in merged_data or not isinstance(merged_data["detailed"], str): merged_data["detailed"] = default_merged["detailed"] - + # 处理关键概念 if "keypoints" not in merged_data or not isinstance(merged_data["keypoints"], list): merged_data["keypoints"] = default_merged["keypoints"] else: # 确保keypoints中的每个项目都是字符串 - merged_data["keypoints"] = [ - str(point) for point in merged_data["keypoints"] - if point is not None - ] + merged_data["keypoints"] = [str(point) for point in merged_data["keypoints"] if point is not None] if not merged_data["keypoints"]: merged_data["keypoints"] = ["合并的关键概念"] - + # 处理事件 if "events" not in merged_data or not isinstance(merged_data["events"], list): merged_data["events"] = default_merged["events"] else: # 确保events中的每个项目都是字符串 - merged_data["events"] = [ - str(event) for event in merged_data["events"] - if event is not None - ] + merged_data["events"] = [str(event) for event in merged_data["events"] if event is not None] if not merged_data["events"]: merged_data["events"] = ["合并的事件"] - + # 添加key_points兼容 merged_data["key_points"] = merged_data["keypoints"] + merged_data["events"] - + except Exception as e: logger.error(f"合并记忆时处理JSON出错: {str(e)}") traceback.print_exc() @@ -740,59 +723,59 @@ class MemoryManager: logger.error(f"合并记忆调用LLM出错: {str(e)}") traceback.print_exc() merged_data = default_merged - + # 创建新的记忆项 # 合并记忆项的标签 merged_tags = memory_item1.tags.union(memory_item2.tags) - + # 取两个记忆项中更强的来源 - merged_source = memory_item1.from_source if memory_item1.memory_strength >= memory_item2.memory_strength else memory_item2.from_source - - # 创建新的记忆项 - merged_memory = MemoryItem( - data=merged_data["content"], - from_source=merged_source, - tags=list(merged_tags) + merged_source = ( + memory_item1.from_source + if memory_item1.memory_strength >= memory_item2.memory_strength + else memory_item2.from_source ) - + + # 创建新的记忆项 + merged_memory = MemoryItem(data=merged_data["content"], from_source=merged_source, tags=list(merged_tags)) + # 设置合并后的摘要 summary = { "brief": merged_data["brief"], "detailed": merged_data["detailed"], "keypoints": merged_data["keypoints"], "events": merged_data["events"], - "key_points": merged_data["key_points"] + "key_points": merged_data["key_points"], } merged_memory.set_summary(summary) - + # 记忆强度取两者最大值 merged_memory.memory_strength = max(memory_item1.memory_strength, memory_item2.memory_strength) - + # 添加到存储中 self.push_item(merged_memory) - + # 如果需要,删除原始记忆 if delete_originals: self.delete(memory_id1) self.delete(memory_id2) - + return merged_memory - + def delete_earliest_memory(self) -> bool: """ 删除最早的记忆项 - + Returns: 是否成功删除 """ # 获取所有记忆项 all_memories = self.get_all_items() - + if not all_memories: return False - + # 按时间戳排序,找到最早的记忆项 earliest_memory = min(all_memories, key=lambda item: item.timestamp) - + # 删除最早的记忆项 - return self.delete(earliest_memory.id) \ No newline at end of file + return self.delete(earliest_memory.id) diff --git a/src/chat/focus_chat/working_memory/test/memory_file_loader.py b/src/chat/focus_chat/working_memory/test/memory_file_loader.py deleted file mode 100644 index 3aa997b82..000000000 --- a/src/chat/focus_chat/working_memory/test/memory_file_loader.py +++ /dev/null @@ -1,169 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -import os -import asyncio -from typing import List, Dict, Any, Optional -from pathlib import Path - -from src.chat.focus_chat.working_memory.working_memory import WorkingMemory -from src.chat.focus_chat.working_memory.memory_item import MemoryItem -from src.common.logger_manager import get_logger - -logger = get_logger("memory_loader") - -class MemoryFileLoader: - """从文件加载记忆内容的工具类""" - - def __init__(self, working_memory: WorkingMemory): - """ - 初始化记忆文件加载器 - - Args: - working_memory: 工作记忆实例 - """ - self.working_memory = working_memory - - async def load_from_directory(self, - directory_path: str, - file_pattern: str = "*.txt", - common_tags: List[str] = None, - source_prefix: str = "文件") -> List[MemoryItem]: - """ - 从指定目录加载符合模式的文件作为记忆 - - Args: - directory_path: 目录路径 - file_pattern: 文件模式(默认为*.txt) - common_tags: 所有记忆共有的标签 - source_prefix: 来源前缀 - - Returns: - 加载的记忆项列表 - """ - directory = Path(directory_path) - if not directory.exists() or not directory.is_dir(): - logger.error(f"目录不存在或不是有效目录: {directory_path}") - return [] - - # 获取文件列表 - files = list(directory.glob(file_pattern)) - if not files: - logger.warning(f"在目录 {directory_path} 中没有找到符合 {file_pattern} 的文件") - return [] - - logger.info(f"在目录 {directory_path} 中找到 {len(files)} 个符合条件的文件") - - # 加载文件内容为记忆 - loaded_memories = [] - for file_path in files: - try: - memory_item = await self._load_single_file( - file_path=str(file_path), - common_tags=common_tags, - source_prefix=source_prefix - ) - if memory_item: - loaded_memories.append(memory_item) - logger.info(f"成功加载记忆: {file_path.name}") - - except Exception as e: - logger.error(f"加载文件 {file_path} 失败: {str(e)}") - - logger.info(f"完成加载,共加载了 {len(loaded_memories)} 个记忆") - return loaded_memories - - async def _load_single_file(self, - file_path: str, - common_tags: Optional[List[str]] = None, - source_prefix: str = "文件") -> Optional[MemoryItem]: - """ - 加载单个文件作为记忆 - - Args: - file_path: 文件路径 - common_tags: 记忆共有的标签 - source_prefix: 来源前缀 - - Returns: - 记忆项,加载失败则返回None - """ - try: - # 读取文件内容 - with open(file_path, "r", encoding="utf-8") as f: - content = f.read() - - if not content.strip(): - logger.warning(f"文件 {file_path} 内容为空") - return None - - # 准备标签和来源 - file_name = os.path.basename(file_path) - tags = list(common_tags) if common_tags else [] - tags.append(file_name) # 添加文件名作为标签 - - source = f"{source_prefix}_{file_name}" - - # 添加到工作记忆 - memory = await self.working_memory.add_memory( - content=content, - from_source=source, - tags=tags - ) - - return memory - - except Exception as e: - logger.error(f"加载文件 {file_path} 失败: {str(e)}") - return None - - -async def main(): - """示例使用""" - # 初始化工作记忆 - chat_id = "demo_chat" - working_memory = WorkingMemory(chat_id=chat_id) - - try: - # 初始化加载器 - loader = MemoryFileLoader(working_memory) - - # 加载当前目录中的txt文件 - current_dir = Path(__file__).parent - memories = await loader.load_from_directory( - directory_path=str(current_dir), - file_pattern="*.txt", - common_tags=["测试数据", "自动加载"], - source_prefix="测试文件" - ) - - # 显示加载结果 - print(f"共加载了 {len(memories)} 个记忆") - - # 获取并显示所有记忆的概要 - all_memories = working_memory.memory_manager.get_all_items() - for memory in all_memories: - print("\n" + "=" * 40) - print(f"记忆ID: {memory.id}") - print(f"来源: {memory.from_source}") - print(f"标签: {', '.join(memory.tags)}") - - if memory.summary: - print(f"\n主题: {memory.summary.get('brief', '无主题')}") - print(f"概述: {memory.summary.get('detailed', '无概述')}") - print("\n要点:") - for point in memory.summary.get('key_points', []): - print(f"- {point}") - else: - print("\n无摘要信息") - - print("=" * 40) - - finally: - # 关闭工作记忆 - await working_memory.shutdown() - - -if __name__ == "__main__": - # 运行示例 - asyncio.run(main()) \ No newline at end of file diff --git a/src/chat/focus_chat/working_memory/test/run_memory_tests.py b/src/chat/focus_chat/working_memory/test/run_memory_tests.py deleted file mode 100644 index d9299cf40..000000000 --- a/src/chat/focus_chat/working_memory/test/run_memory_tests.py +++ /dev/null @@ -1,92 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -import asyncio -import os -import sys -from pathlib import Path - -# 添加项目根目录到系统路径 -current_dir = Path(__file__).parent -project_root = current_dir.parent.parent.parent.parent.parent -sys.path.insert(0, str(project_root)) - -from src.chat.focus_chat.working_memory.working_memory import WorkingMemory - -async def test_load_memories_from_files(): - """测试从文件加载记忆的功能""" - print("开始测试从文件加载记忆...") - - # 初始化工作记忆 - chat_id = "test_memory_load" - working_memory = WorkingMemory(chat_id=chat_id, max_memories_per_chat=10, auto_decay_interval=60) - - try: - # 获取测试文件列表 - test_dir = Path(__file__).parent - test_files = [ - os.path.join(test_dir, f) - for f in os.listdir(test_dir) - if f.endswith(".txt") - ] - - print(f"找到 {len(test_files)} 个测试文件") - - # 从每个文件加载记忆 - for file_path in test_files: - file_name = os.path.basename(file_path) - print(f"从文件 {file_name} 加载记忆...") - - # 读取文件内容 - with open(file_path, "r", encoding="utf-8") as f: - content = f.read() - - # 添加记忆 - memory = await working_memory.add_memory( - content=content, - from_source=f"文件_{file_name}", - tags=["测试文件", file_name] - ) - - print(f"已添加记忆: ID={memory.id}") - if memory.summary: - print(f"记忆概要: {memory.summary.get('brief', '无概要')}") - print(f"记忆要点: {', '.join(memory.summary.get('key_points', ['无要点']))}") - print("-" * 50) - - # 获取所有记忆 - all_memories = working_memory.memory_manager.get_all_items() - print(f"\n成功加载 {len(all_memories)} 个记忆") - - # 测试检索记忆 - if all_memories: - print("\n测试检索第一个记忆...") - first_memory = all_memories[0] - retrieved = await working_memory.retrieve_memory(first_memory.id) - - if retrieved: - print(f"成功检索记忆: ID={retrieved.id}") - print(f"检索后强度: {retrieved.memory_strength} (初始为10.0)") - print(f"检索次数: {retrieved.retrieval_count}") - else: - print("检索失败") - - # 测试记忆衰减 - print("\n测试记忆衰减...") - for memory in all_memories: - print(f"记忆 {memory.id} 衰减前强度: {memory.memory_strength}") - - await working_memory.decay_all_memories(decay_factor=0.5) - - all_memories_after = working_memory.memory_manager.get_all_items() - for memory in all_memories_after: - print(f"记忆 {memory.id} 衰减后强度: {memory.memory_strength}") - - finally: - # 关闭工作记忆 - await working_memory.shutdown() - print("\n测试完成,已关闭工作记忆") - -if __name__ == "__main__": - # 运行测试 - asyncio.run(test_load_memories_from_files()) \ No newline at end of file diff --git a/src/chat/focus_chat/working_memory/test/simulate_real_usage.py b/src/chat/focus_chat/working_memory/test/simulate_real_usage.py deleted file mode 100644 index 24cf5c70a..000000000 --- a/src/chat/focus_chat/working_memory/test/simulate_real_usage.py +++ /dev/null @@ -1,197 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -import asyncio -import os -import sys -import time -import random -from pathlib import Path -from datetime import datetime - -# 添加项目根目录到系统路径 -current_dir = Path(__file__).parent -project_root = current_dir.parent.parent.parent.parent.parent -sys.path.insert(0, str(project_root)) - -from src.chat.focus_chat.working_memory.working_memory import WorkingMemory -from src.chat.focus_chat.working_memory.memory_item import MemoryItem -from src.common.logger_manager import get_logger - -logger = get_logger("real_usage_simulation") - -class WorkingMemorySimulator: - """模拟工作记忆的真实使用场景""" - - def __init__(self, chat_id="real_usage_test", cycle_interval=20): - """ - 初始化模拟器 - - Args: - chat_id: 聊天ID - cycle_interval: 循环间隔时间(秒) - """ - self.chat_id = chat_id - self.cycle_interval = cycle_interval - self.working_memory = WorkingMemory(chat_id=chat_id, max_memories_per_chat=20, auto_decay_interval=60) - self.cycle_count = 0 - self.running = False - - # 获取测试文件路径 - self.test_files = self._get_test_files() - if not self.test_files: - raise FileNotFoundError("找不到测试文件,请确保test目录中有.txt文件") - - # 存储所有添加的记忆ID - self.memory_ids = [] - - async def start(self, total_cycles=5): - """ - 开始模拟循环 - - Args: - total_cycles: 总循环次数,设为None表示无限循环 - """ - self.running = True - logger.info(f"开始模拟真实使用场景,循环间隔: {self.cycle_interval}秒") - - try: - while self.running and (total_cycles is None or self.cycle_count < total_cycles): - self.cycle_count += 1 - logger.info(f"\n===== 开始第 {self.cycle_count} 次循环 =====") - - # 执行一次循环 - await self._run_one_cycle() - - # 如果还有更多循环,则等待 - if self.running and (total_cycles is None or self.cycle_count < total_cycles): - wait_time = self.cycle_interval - logger.info(f"等待 {wait_time} 秒后开始下一循环...") - await asyncio.sleep(wait_time) - - logger.info(f"模拟完成,共执行了 {self.cycle_count} 次循环") - - except KeyboardInterrupt: - logger.info("接收到中断信号,停止模拟") - except Exception as e: - logger.error(f"模拟过程中出错: {str(e)}", exc_info=True) - finally: - # 关闭工作记忆 - await self.working_memory.shutdown() - - def stop(self): - """停止模拟循环""" - self.running = False - logger.info("正在停止模拟...") - - async def _run_one_cycle(self): - """运行一次完整循环:先检索记忆,再添加新记忆""" - start_time = time.time() - - # 1. 先检索已有记忆(如果有) - await self._retrieve_memories() - - # 2. 添加新记忆 - await self._add_new_memory() - - # 3. 显示工作记忆状态 - await self._show_memory_status() - - # 计算循环耗时 - cycle_duration = time.time() - start_time - logger.info(f"第 {self.cycle_count} 次循环完成,耗时: {cycle_duration:.2f}秒") - - async def _retrieve_memories(self): - """检索现有记忆""" - # 如果有已保存的记忆ID,随机选择1-3个进行检索 - if self.memory_ids: - num_to_retrieve = min(len(self.memory_ids), random.randint(1, 3)) - retrieval_ids = random.sample(self.memory_ids, num_to_retrieve) - - logger.info(f"正在检索 {num_to_retrieve} 条记忆...") - - for memory_id in retrieval_ids: - memory = await self.working_memory.retrieve_memory(memory_id) - if memory: - logger.info(f"成功检索记忆 ID: {memory_id}") - logger.info(f" - 强度: {memory.memory_strength:.2f},检索次数: {memory.retrieval_count}") - if memory.summary: - logger.info(f" - 主题: {memory.summary.get('brief', '无主题')}") - else: - logger.warning(f"记忆 ID: {memory_id} 不存在或已被移除") - # 从ID列表中移除 - if memory_id in self.memory_ids: - self.memory_ids.remove(memory_id) - else: - logger.info("当前没有可检索的记忆") - - async def _add_new_memory(self): - """添加新记忆""" - # 随机选择一个测试文件作为记忆内容 - file_path = random.choice(self.test_files) - file_name = os.path.basename(file_path) - - try: - # 读取文件内容 - with open(file_path, "r", encoding="utf-8") as f: - content = f.read() - - # 添加时间戳,模拟不同内容 - timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - content_with_timestamp = f"[{timestamp}] {content}" - - # 添加记忆 - logger.info(f"正在添加新记忆,来源: {file_name}") - memory = await self.working_memory.add_memory( - content=content_with_timestamp, - from_source=f"模拟_{file_name}", - tags=["模拟测试", f"循环{self.cycle_count}", file_name] - ) - - # 保存记忆ID - self.memory_ids.append(memory.id) - - # 显示记忆信息 - logger.info(f"已添加新记忆 ID: {memory.id}") - if memory.summary: - logger.info(f"记忆主题: {memory.summary.get('brief', '无主题')}") - logger.info(f"记忆要点: {', '.join(memory.summary.get('key_points', ['无要点'])[:2])}...") - - except Exception as e: - logger.error(f"添加记忆失败: {str(e)}") - - async def _show_memory_status(self): - """显示当前工作记忆状态""" - all_memories = self.working_memory.memory_manager.get_all_items() - - logger.info(f"\n当前工作记忆状态:") - logger.info(f"记忆总数: {len(all_memories)}") - - # 按强度排序 - sorted_memories = sorted(all_memories, key=lambda x: x.memory_strength, reverse=True) - - logger.info("记忆强度排名 (前5项):") - for i, memory in enumerate(sorted_memories[:5], 1): - logger.info(f"{i}. ID: {memory.id}, 强度: {memory.memory_strength:.2f}, " - f"检索次数: {memory.retrieval_count}, " - f"主题: {memory.summary.get('brief', '无主题') if memory.summary else '无摘要'}") - - def _get_test_files(self): - """获取测试文件列表""" - test_dir = Path(__file__).parent - return [ - os.path.join(test_dir, f) - for f in os.listdir(test_dir) - if f.endswith(".txt") - ] - -async def main(): - """主函数""" - # 创建模拟器 - simulator = WorkingMemorySimulator(cycle_interval=20) # 设置20秒的循环间隔 - - # 设置运行5个循环 - await simulator.start(total_cycles=5) - -if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file diff --git a/src/chat/focus_chat/working_memory/test/test_decay_removal.py b/src/chat/focus_chat/working_memory/test/test_decay_removal.py deleted file mode 100644 index c114bc495..000000000 --- a/src/chat/focus_chat/working_memory/test/test_decay_removal.py +++ /dev/null @@ -1,323 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -import asyncio -import os -import sys -import time -from pathlib import Path - -# 添加项目根目录到系统路径 -current_dir = Path(__file__).parent -project_root = current_dir.parent.parent.parent.parent.parent -sys.path.insert(0, str(project_root)) - -from src.chat.focus_chat.working_memory.working_memory import WorkingMemory -from src.chat.focus_chat.working_memory.test.memory_file_loader import MemoryFileLoader -from src.common.logger_manager import get_logger - -logger = get_logger("memory_decay_test") - -async def test_manual_decay_until_removal(): - """测试手动衰减直到记忆被自动移除""" - print("\n===== 测试手动衰减直到记忆被自动移除 =====") - - # 初始化工作记忆,设置较大的衰减间隔,避免自动衰减影响测试 - chat_id = "decay_test_manual" - working_memory = WorkingMemory(chat_id=chat_id, max_memories_per_chat=10, auto_decay_interval=3600) - - try: - # 创建加载器并加载测试文件 - loader = MemoryFileLoader(working_memory) - test_dir = current_dir - - # 加载第一个测试文件作为记忆 - memories = await loader.load_from_directory( - directory_path=str(test_dir), - file_pattern="test1.txt", # 只加载test1.txt - common_tags=["测试", "衰减", "自动移除"], - source_prefix="衰减测试" - ) - - if not memories: - print("未能加载记忆文件,测试结束") - return - - # 获取加载的记忆 - memory = memories[0] - memory_id = memory.id - print(f"已加载测试记忆,ID: {memory_id}") - print(f"初始强度: {memory.memory_strength}") - if memory.summary: - print(f"记忆主题: {memory.summary.get('brief', '无主题')}") - - # 执行多次衰减,直到记忆被移除 - decay_count = 0 - decay_factor = 0.5 # 每次衰减为原来的一半 - - while True: - # 获取当前记忆 - current_memory = working_memory.memory_manager.get_by_id(memory_id) - - # 如果记忆已被移除,退出循环 - if current_memory is None: - print(f"记忆已在第 {decay_count} 次衰减后被自动移除!") - break - - # 输出当前强度 - print(f"衰减 {decay_count} 次后强度: {current_memory.memory_strength}") - - # 执行衰减 - await working_memory.decay_all_memories(decay_factor=decay_factor) - decay_count += 1 - - # 输出衰减后的详细信息 - after_memory = working_memory.memory_manager.get_by_id(memory_id) - if after_memory: - print(f"第 {decay_count} 次衰减结果: 强度={after_memory.memory_strength},压缩次数={after_memory.compress_count}") - if after_memory.summary: - print(f"记忆概要: {after_memory.summary.get('brief', '无概要')}") - print(f"记忆要点数量: {len(after_memory.summary.get('key_points', []))}") - else: - print(f"第 {decay_count} 次衰减结果: 记忆已被移除") - - # 防止无限循环 - if decay_count > 20: - print("达到最大衰减次数(20),退出测试。") - break - - # 短暂等待 - await asyncio.sleep(0.5) - - # 验证记忆是否真的被移除 - all_memories = working_memory.memory_manager.get_all_items() - print(f"剩余记忆数量: {len(all_memories)}") - if len(all_memories) == 0: - print("测试通过: 记忆在强度低于阈值后被成功移除。") - else: - print("测试失败: 记忆应该被移除但仍然存在。") - - finally: - await working_memory.shutdown() - -async def test_auto_decay(): - """测试自动衰减功能""" - print("\n===== 测试自动衰减功能 =====") - - # 初始化工作记忆,设置短的衰减间隔,便于测试 - chat_id = "decay_test_auto" - decay_interval = 3 # 3秒 - working_memory = WorkingMemory(chat_id=chat_id, max_memories_per_chat=10, auto_decay_interval=decay_interval) - - try: - # 创建加载器并加载测试文件 - loader = MemoryFileLoader(working_memory) - test_dir = current_dir - - # 加载第二个测试文件作为记忆 - memories = await loader.load_from_directory( - directory_path=str(test_dir), - file_pattern="test1.txt", # 只加载test2.txt - common_tags=["测试", "自动衰减"], - source_prefix="自动衰减测试" - ) - - if not memories: - print("未能加载记忆文件,测试结束") - return - - # 获取加载的记忆 - memory = memories[0] - memory_id = memory.id - print(f"已加载测试记忆,ID: {memory_id}") - print(f"初始强度: {memory.memory_strength}") - if memory.summary: - print(f"记忆主题: {memory.summary.get('brief', '无主题')}") - print(f"记忆概要: {memory.summary.get('detailed', '无概要')}") - print(f"记忆要点: {memory.summary.get('keypoints', '无要点')}") - print(f"记忆事件: {memory.summary.get('events', '无事件')}") - # 观察自动衰减 - print(f"等待自动衰减任务执行 (间隔 {decay_interval} 秒)...") - - for i in range(3): # 观察3次自动衰减 - # 等待自动衰减发生 - await asyncio.sleep(decay_interval + 1) # 多等1秒确保任务执行 - - # 获取当前记忆 - current_memory = working_memory.memory_manager.get_by_id(memory_id) - - # 如果记忆已被移除,退出循环 - if current_memory is None: - print(f"记忆已在第 {i+1} 次自动衰减后被移除!") - break - - # 输出当前强度和详细信息 - print(f"第 {i+1} 次自动衰减后强度: {current_memory.memory_strength}") - print(f"自动衰减详细结果: 压缩次数={current_memory.compress_count}, 提取次数={current_memory.retrieval_count}") - if current_memory.summary: - print(f"记忆概要: {current_memory.summary.get('brief', '无概要')}") - - print(f"\n自动衰减测试结束。") - - # 验证自动衰减是否发生 - final_memory = working_memory.memory_manager.get_by_id(memory_id) - if final_memory is None: - print("记忆已被自动衰减移除。") - elif final_memory.memory_strength < memory.memory_strength: - print(f"自动衰减有效:初始强度 {memory.memory_strength} -> 最终强度 {final_memory.memory_strength}") - print(f"衰减历史记录: {final_memory.history}") - else: - print("测试失败:记忆强度未减少,自动衰减可能未生效。") - - finally: - await working_memory.shutdown() - -async def test_decay_and_retrieval_balance(): - """测试记忆衰减和检索的平衡""" - print("\n===== 测试记忆衰减和检索的平衡 =====") - - # 初始化工作记忆 - chat_id = "decay_retrieval_balance" - working_memory = WorkingMemory(chat_id=chat_id, max_memories_per_chat=10, auto_decay_interval=60) - - try: - # 创建加载器并加载测试文件 - loader = MemoryFileLoader(working_memory) - test_dir = current_dir - - # 加载第三个测试文件作为记忆 - memories = await loader.load_from_directory( - directory_path=str(test_dir), - file_pattern="test3.txt", # 只加载test3.txt - common_tags=["测试", "衰减", "检索"], - source_prefix="平衡测试" - ) - - if not memories: - print("未能加载记忆文件,测试结束") - return - - # 获取加载的记忆 - memory = memories[0] - memory_id = memory.id - print(f"已加载测试记忆,ID: {memory_id}") - print(f"初始强度: {memory.memory_strength}") - if memory.summary: - print(f"记忆主题: {memory.summary.get('brief', '无主题')}") - - # 先衰减几次 - print("\n开始衰减:") - for i in range(3): - await working_memory.decay_all_memories(decay_factor=0.5) - current = working_memory.memory_manager.get_by_id(memory_id) - if current: - print(f"衰减 {i+1} 次后强度: {current.memory_strength}") - print(f"衰减详细信息: 压缩次数={current.compress_count}, 历史操作数={len(current.history)}") - if current.summary: - print(f"记忆概要: {current.summary.get('brief', '无概要')}") - else: - print(f"记忆已在第 {i+1} 次衰减后被移除。") - break - - # 如果记忆还存在,则检索几次增强它 - current = working_memory.memory_manager.get_by_id(memory_id) - if current: - print("\n开始检索增强:") - for i in range(2): - retrieved = await working_memory.retrieve_memory(memory_id) - print(f"检索 {i+1} 次后强度: {retrieved.memory_strength}") - print(f"检索后详细信息: 提取次数={retrieved.retrieval_count}, 历史记录长度={len(retrieved.history)}") - - # 再次衰减几次,测试是否会被移除 - print("\n再次衰减:") - for i in range(5): - await working_memory.decay_all_memories(decay_factor=0.5) - current = working_memory.memory_manager.get_by_id(memory_id) - if current: - print(f"最终衰减 {i+1} 次后强度: {current.memory_strength}") - print(f"衰减详细结果: 压缩次数={current.compress_count}") - else: - print(f"记忆已在最终衰减第 {i+1} 次后被移除。") - break - - print("\n测试结束。") - - finally: - await working_memory.shutdown() - -async def test_multi_memories_decay(): - """测试多条记忆同时衰减""" - print("\n===== 测试多条记忆同时衰减 =====") - - # 初始化工作记忆 - chat_id = "multi_decay_test" - working_memory = WorkingMemory(chat_id=chat_id, max_memories_per_chat=10, auto_decay_interval=60) - - try: - # 创建加载器并加载所有测试文件 - loader = MemoryFileLoader(working_memory) - test_dir = current_dir - - # 加载所有测试文件作为记忆 - memories = await loader.load_from_directory( - directory_path=str(test_dir), - file_pattern="*.txt", - common_tags=["测试", "多记忆衰减"], - source_prefix="多记忆测试" - ) - - if not memories or len(memories) < 2: - print("未能加载足够的记忆文件,测试结束") - return - - # 显示已加载的记忆 - print(f"已加载 {len(memories)} 条记忆:") - for idx, mem in enumerate(memories): - print(f"{idx+1}. ID: {mem.id}, 强度: {mem.memory_strength}, 来源: {mem.from_source}") - if mem.summary: - print(f" 主题: {mem.summary.get('brief', '无主题')}") - - # 进行多次衰减测试 - print("\n开始多记忆衰减测试:") - for decay_round in range(5): - # 执行衰减 - await working_memory.decay_all_memories(decay_factor=0.5) - - # 获取并显示所有记忆 - all_memories = working_memory.memory_manager.get_all_items() - print(f"\n第 {decay_round+1} 次衰减后,剩余记忆数量: {len(all_memories)}") - - for idx, mem in enumerate(all_memories): - print(f"{idx+1}. ID: {mem.id}, 强度: {mem.memory_strength}, 压缩次数: {mem.compress_count}") - if mem.summary: - print(f" 概要: {mem.summary.get('brief', '无概要')[:30]}...") - - # 如果所有记忆都被移除,退出循环 - if not all_memories: - print("所有记忆已被移除,测试结束。") - break - - # 等待一下 - await asyncio.sleep(0.5) - - print("\n多记忆衰减测试结束。") - - finally: - await working_memory.shutdown() - -async def main(): - """运行所有测试""" - # 测试手动衰减直到移除 - await test_manual_decay_until_removal() - - # 测试自动衰减 - await test_auto_decay() - - # 测试衰减和检索的平衡 - await test_decay_and_retrieval_balance() - - # 测试多条记忆同时衰减 - await test_multi_memories_decay() - -if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file diff --git a/src/chat/focus_chat/working_memory/test/test_working_memory.py b/src/chat/focus_chat/working_memory/test/test_working_memory.py deleted file mode 100644 index b9440db17..000000000 --- a/src/chat/focus_chat/working_memory/test/test_working_memory.py +++ /dev/null @@ -1,121 +0,0 @@ -import asyncio -import os -import unittest -from typing import List, Dict, Any -from pathlib import Path - -from src.chat.focus_chat.working_memory.working_memory import WorkingMemory -from src.chat.focus_chat.working_memory.memory_item import MemoryItem - -class TestWorkingMemory(unittest.TestCase): - """工作记忆测试类""" - - def setUp(self): - """测试前准备""" - self.chat_id = "test_chat_123" - self.working_memory = WorkingMemory(chat_id=self.chat_id, max_memories_per_chat=10, auto_decay_interval=60) - self.test_dir = Path(__file__).parent - - def tearDown(self): - """测试后清理""" - loop = asyncio.get_event_loop() - loop.run_until_complete(self.working_memory.shutdown()) - - def test_init(self): - """测试初始化""" - self.assertEqual(self.working_memory.max_memories_per_chat, 10) - self.assertEqual(self.working_memory.auto_decay_interval, 60) - - def test_add_memory_from_files(self): - """从文件添加记忆""" - loop = asyncio.get_event_loop() - test_files = self._get_test_files() - - # 添加记忆 - memories = [] - for file_path in test_files: - content = self._read_file_content(file_path) - file_name = os.path.basename(file_path) - source = f"test_file_{file_name}" - tags = ["测试", f"文件_{file_name}"] - - memory = loop.run_until_complete( - self.working_memory.add_memory( - content=content, - from_source=source, - tags=tags - ) - ) - memories.append(memory) - - # 验证记忆数量 - all_items = self.working_memory.memory_manager.get_all_items() - self.assertEqual(len(all_items), len(test_files)) - - # 验证每个记忆的内容和标签 - for i, memory in enumerate(memories): - file_name = os.path.basename(test_files[i]) - retrieved_memory = loop.run_until_complete( - self.working_memory.retrieve_memory(memory.id) - ) - - self.assertIsNotNone(retrieved_memory) - self.assertTrue(retrieved_memory.has_tag("测试")) - self.assertTrue(retrieved_memory.has_tag(f"文件_{file_name}")) - self.assertEqual(retrieved_memory.from_source, f"test_file_{file_name}") - - # 验证检索后强度增加 - self.assertGreater(retrieved_memory.memory_strength, 10.0) # 原始强度为10.0,检索后增加1.5倍 - self.assertEqual(retrieved_memory.retrieval_count, 1) - - def test_decay_memories(self): - """测试记忆衰减""" - loop = asyncio.get_event_loop() - test_files = self._get_test_files()[:1] # 只使用一个文件测试衰减 - - # 添加记忆 - for file_path in test_files: - content = self._read_file_content(file_path) - loop.run_until_complete( - self.working_memory.add_memory( - content=content, - from_source="decay_test", - tags=["衰减测试"] - ) - ) - - # 获取添加后的记忆项 - all_items_before = self.working_memory.memory_manager.get_all_items() - self.assertEqual(len(all_items_before), 1) - - # 记录原始强度 - original_strength = all_items_before[0].memory_strength - - # 执行衰减 - loop.run_until_complete( - self.working_memory.decay_all_memories(decay_factor=0.5) - ) - - # 获取衰减后的记忆项 - all_items_after = self.working_memory.memory_manager.get_all_items() - - # 验证强度衰减 - self.assertEqual(len(all_items_after), 1) - self.assertLess(all_items_after[0].memory_strength, original_strength) - - def _get_test_files(self) -> List[str]: - """获取测试文件列表""" - test_dir = self.test_dir - return [ - os.path.join(test_dir, f) - for f in os.listdir(test_dir) - if f.endswith(".txt") - ] - - def _read_file_content(self, file_path: str) -> str: - """读取文件内容""" - with open(file_path, "r", encoding="utf-8") as f: - return f.read() - -if __name__ == "__main__": - unittest.main() \ No newline at end of file diff --git a/src/chat/focus_chat/working_memory/working_memory.py b/src/chat/focus_chat/working_memory/working_memory.py index 9fd0e8586..db9824150 100644 --- a/src/chat/focus_chat/working_memory/working_memory.py +++ b/src/chat/focus_chat/working_memory/working_memory.py @@ -1,7 +1,6 @@ -from typing import Dict, List, Any, Optional +from typing import List, Any, Optional import asyncio import random -from datetime import datetime from src.common.logger_manager import get_logger from src.chat.focus_chat.working_memory.memory_manager import MemoryManager, MemoryItem @@ -9,39 +8,40 @@ logger = get_logger(__name__) # 问题是我不知道这个manager是不是需要和其他manager统一管理,因为这个manager是从属于每一个聊天流,都有自己的定时任务 + class WorkingMemory: """ 工作记忆,负责协调和运作记忆 从属于特定的流,用chat_id来标识 """ - - def __init__(self,chat_id:str , max_memories_per_chat: int = 10, auto_decay_interval: int = 60): + + def __init__(self, chat_id: str, max_memories_per_chat: int = 10, auto_decay_interval: int = 60): """ 初始化工作记忆管理器 - + Args: max_memories_per_chat: 每个聊天的最大记忆数量 auto_decay_interval: 自动衰减记忆的时间间隔(秒) """ self.memory_manager = MemoryManager(chat_id) - + # 记忆容量上限 self.max_memories_per_chat = max_memories_per_chat - + # 自动衰减间隔 self.auto_decay_interval = auto_decay_interval - + # 衰减任务 self.decay_task = None - + # 启动自动衰减任务 self._start_auto_decay() - + def _start_auto_decay(self): """启动自动衰减任务""" if self.decay_task is None: self.decay_task = asyncio.create_task(self._auto_decay_loop()) - + async def _auto_decay_loop(self): """自动衰减循环""" while True: @@ -50,43 +50,39 @@ class WorkingMemory: await self.decay_all_memories() except Exception as e: print(f"自动衰减记忆时出错: {str(e)}") - - - async def add_memory(self, - content: Any, - from_source: str = "", - tags: Optional[List[str]] = None): + + async def add_memory(self, content: Any, from_source: str = "", tags: Optional[List[str]] = None): """ 添加一段记忆到指定聊天 - + Args: content: 记忆内容 from_source: 数据来源 tags: 数据标签列表 - + Returns: 包含记忆信息的字典 """ memory = await self.memory_manager.push_with_summary(content, from_source, tags) if len(self.memory_manager.get_all_items()) > self.max_memories_per_chat: self.remove_earliest_memory() - + return memory - + def remove_earliest_memory(self): """ 删除最早的记忆 """ return self.memory_manager.delete_earliest_memory() - + async def retrieve_memory(self, memory_id: str) -> Optional[MemoryItem]: """ 检索记忆 - + Args: chat_id: 聊天ID memory_id: 记忆ID - + Returns: 检索到的记忆项,如果不存在则返回None """ @@ -97,19 +93,18 @@ class WorkingMemory: return memory_item return None - async def decay_all_memories(self, decay_factor: float = 0.5): """ 对所有聊天的所有记忆进行衰减 衰减:对记忆进行refine压缩,强度会变为原先的0.5 - + Args: decay_factor: 衰减因子(0-1之间) """ logger.debug(f"开始对所有记忆进行衰减,衰减因子: {decay_factor}") - + all_memories = self.memory_manager.get_all_items() - + for memory_item in all_memories: # 如果压缩完小于1会被删除 memory_id = memory_item.id @@ -119,45 +114,47 @@ class WorkingMemory: continue # 计算衰减量 if memory_item.memory_strength < 5: - await self.memory_manager.refine_memory(memory_id, f"由于时间过去了{self.auto_decay_interval}秒,记忆变的模糊,所以需要压缩") - + await self.memory_manager.refine_memory( + memory_id, f"由于时间过去了{self.auto_decay_interval}秒,记忆变的模糊,所以需要压缩" + ) + async def merge_memory(self, memory_id1: str, memory_id2: str) -> MemoryItem: """合并记忆 - + Args: memory_str: 记忆内容 """ - return await self.memory_manager.merge_memories(memory_id1 = memory_id1, memory_id2 = memory_id2,reason = "两端记忆有重复的内容") - - - + return await self.memory_manager.merge_memories( + memory_id1=memory_id1, memory_id2=memory_id2, reason="两端记忆有重复的内容" + ) + # 暂时没用,先留着 async def simulate_memory_blur(self, chat_id: str, blur_rate: float = 0.2): """ 模拟记忆模糊过程,随机选择一部分记忆进行精简 - + Args: chat_id: 聊天ID blur_rate: 模糊比率(0-1之间),表示有多少比例的记忆会被精简 """ memory = self.get_memory(chat_id) - + # 获取所有字符串类型且有总结的记忆 all_summarized_memories = [] for type_items in memory._memory.values(): for item in type_items: - if isinstance(item.data, str) and hasattr(item, 'summary') and item.summary: + if isinstance(item.data, str) and hasattr(item, "summary") and item.summary: all_summarized_memories.append(item) - + if not all_summarized_memories: return - + # 计算要模糊的记忆数量 blur_count = max(1, int(len(all_summarized_memories) * blur_rate)) - + # 随机选择要模糊的记忆 memories_to_blur = random.sample(all_summarized_memories, min(blur_count, len(all_summarized_memories))) - + # 对选中的记忆进行精简 for memory_item in memories_to_blur: try: @@ -168,16 +165,14 @@ class WorkingMemory: requirement = "保留核心要点,适度精简细节" else: requirement = "只保留最关键的1-2个要点,大幅精简内容" - + # 进行精简 await memory.refine_memory(memory_item.id, requirement) print(f"已模糊记忆 {memory_item.id},强度: {memory_item.memory_strength}, 要求: {requirement}") - + except Exception as e: print(f"模糊记忆 {memory_item.id} 时出错: {str(e)}") - - async def shutdown(self) -> None: """关闭管理器,停止所有任务""" if self.decay_task and not self.decay_task.done(): @@ -185,13 +180,13 @@ class WorkingMemory: try: await self.decay_task except asyncio.CancelledError: - pass - + pass + def get_all_memories(self) -> List[MemoryItem]: """ 获取所有记忆项目 - + Returns: List[MemoryItem]: 当前工作记忆中的所有记忆项目列表 """ - return self.memory_manager.get_all_items() \ No newline at end of file + return self.memory_manager.get_all_items() diff --git a/src/chat/heart_flow/observation/hfcloop_observation.py b/src/chat/heart_flow/observation/hfcloop_observation.py index d950e3512..82c9c879a 100644 --- a/src/chat/heart_flow/observation/hfcloop_observation.py +++ b/src/chat/heart_flow/observation/hfcloop_observation.py @@ -17,14 +17,14 @@ class HFCloopObservation: self.observe_id = observe_id self.last_observe_time = datetime.now().timestamp() # 初始化为当前时间 self.history_loop: List[CycleDetail] = [] - self.action_manager = ActionManager() + self.action_manager = ActionManager() def get_observe_info(self): return self.observe_info def add_loop_info(self, loop_info: CycleDetail): self.history_loop.append(loop_info) - + def set_action_manager(self, action_manager: ActionManager): self.action_manager = action_manager @@ -75,16 +75,15 @@ class HFCloopObservation: if start_time is not None and end_time is not None: time_diff = int(end_time - start_time) if time_diff > 60: - cycle_info_block += f"\n距离你上一次阅读消息已经过去了{time_diff/60}分钟\n" + cycle_info_block += f"\n距离你上一次阅读消息已经过去了{time_diff / 60}分钟\n" else: cycle_info_block += f"\n距离你上一次阅读消息已经过去了{time_diff}秒\n" else: cycle_info_block += "\n你还没看过消息\n" - + using_actions = self.action_manager.get_using_actions() for action_name, action_info in using_actions.items(): action_description = action_info["description"] cycle_info_block += f"\n你在聊天中可以使用{action_name},这个动作的描述是{action_description}\n" - self.observe_info = cycle_info_block diff --git a/src/chat/heart_flow/observation/observation.py b/src/chat/heart_flow/observation/observation.py index 8ab9ab9a4..97e254fc0 100644 --- a/src/chat/heart_flow/observation/observation.py +++ b/src/chat/heart_flow/observation/observation.py @@ -5,6 +5,7 @@ from src.common.logger_manager import get_logger logger = get_logger("observation") + # 所有观察的基类 class Observation: def __init__(self, observe_id): diff --git a/src/chat/heart_flow/observation/structure_observation.py b/src/chat/heart_flow/observation/structure_observation.py index 5c5c0a362..2732ef0b1 100644 --- a/src/chat/heart_flow/observation/structure_observation.py +++ b/src/chat/heart_flow/observation/structure_observation.py @@ -29,4 +29,4 @@ class StructureObservation: observed_structured_infos.append(structured_info) logger.debug(f"观察到结构化信息仍旧在: {structured_info}") - self.structured_info = observed_structured_infos \ No newline at end of file + self.structured_info = observed_structured_infos diff --git a/src/chat/heart_flow/observation/working_observation.py b/src/chat/heart_flow/observation/working_observation.py index 2e32f84d5..7013c3a2b 100644 --- a/src/chat/heart_flow/observation/working_observation.py +++ b/src/chat/heart_flow/observation/working_observation.py @@ -16,9 +16,9 @@ class WorkingMemoryObservation: self.observe_info = "" self.observe_id = observe_id self.last_observe_time = datetime.now().timestamp() - + self.working_memory = working_memory - + self.retrieved_working_memory = [] def get_observe_info(self): @@ -26,7 +26,7 @@ class WorkingMemoryObservation: def add_retrieved_working_memory(self, retrieved_working_memory: List[MemoryItem]): self.retrieved_working_memory.append(retrieved_working_memory) - + def get_retrieved_working_memory(self): return self.retrieved_working_memory diff --git a/src/chat/person_info/person_info.py b/src/chat/person_info/person_info.py index 2460ab4ff..c8394a195 100644 --- a/src/chat/person_info/person_info.py +++ b/src/chat/person_info/person_info.py @@ -94,7 +94,7 @@ class PersonInfoManager: return True else: return False - + def get_person_id_by_person_name(self, person_name: str): """根据用户名获取用户ID""" document = db.person_info.find_one({"person_name": person_name}) @@ -102,7 +102,6 @@ class PersonInfoManager: return document["person_id"] else: return "" - @staticmethod async def create_person_info(person_id: str, data: dict = None): diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index 15b1e4fc6..e5ccd82a7 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -451,7 +451,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str: # 处理 回复 reply_pattern = r"回复<([^:<>]+):([^:<>]+)>" - def reply_replacer(match): + def reply_replacer(match, platform=platform): # aaa = match.group(1) bbb = match.group(2) anon_reply = get_anon_name(platform, bbb) @@ -462,7 +462,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str: # 处理 @ at_pattern = r"@<([^:<>]+):([^:<>]+)>" - def at_replacer(match): + def at_replacer(match, platform=platform): # aaa = match.group(1) bbb = match.group(2) anon_at = get_anon_name(platform, bbb) diff --git a/src/plugins/__init__.py b/src/plugins/__init__.py index b557a4258..0b0692d42 100644 --- a/src/plugins/__init__.py +++ b/src/plugins/__init__.py @@ -1 +1 @@ -"""插件系统包""" \ No newline at end of file +"""插件系统包""" diff --git a/src/plugins/test_plugin/__init__.py b/src/plugins/test_plugin/__init__.py index 867ef417c..b5fefb97e 100644 --- a/src/plugins/test_plugin/__init__.py +++ b/src/plugins/test_plugin/__init__.py @@ -1,4 +1,5 @@ """测试插件包""" + """ 这是一个测试插件 -""" \ No newline at end of file +""" diff --git a/src/plugins/test_plugin/actions/__init__.py b/src/plugins/test_plugin/actions/__init__.py index 8599d2326..7d96ea8a4 100644 --- a/src/plugins/test_plugin/actions/__init__.py +++ b/src/plugins/test_plugin/actions/__init__.py @@ -1,6 +1,7 @@ """测试插件动作模块""" # 导入所有动作模块以确保装饰器被执行 -from . import test_action # noqa -# from . import online_action # noqa -from . import mute_action # noqa \ No newline at end of file +from . import test_action # noqa + +# from . import online_action # noqa +from . import mute_action # noqa diff --git a/src/plugins/test_plugin/actions/mute_action.py b/src/plugins/test_plugin/actions/mute_action.py index 723571806..c96204172 100644 --- a/src/plugins/test_plugin/actions/mute_action.py +++ b/src/plugins/test_plugin/actions/mute_action.py @@ -4,12 +4,15 @@ from typing import Tuple logger = get_logger("mute_action") + @register_action class MuteAction(PluginAction): """测试动作处理类""" action_name = "mute_action" - action_description = "如果某人违反了公序良俗,或者别人戳你太多,,或者某人刷屏,一定要禁言某人,如果你很生气,可以禁言某人" + action_description = ( + "如果某人违反了公序良俗,或者别人戳你太多,,或者某人刷屏,一定要禁言某人,如果你很生气,可以禁言某人" + ) action_parameters = { "target": "禁言对象,输入你要禁言的对象的名字,必填,", "duration": "禁言时长,输入你要禁言的时长,单位为秒,必填", @@ -27,22 +30,22 @@ class MuteAction(PluginAction): async def process(self) -> Tuple[bool, str]: """处理测试动作""" logger.info(f"{self.log_prefix} 执行online动作: {self.reasoning}") - + # 发送测试消息 target = self.action_data.get("target") duration = self.action_data.get("duration") reason = self.action_data.get("reason") platform, user_id = await self.get_user_id_by_person_name(target) - + await self.send_message_by_expressor(f"我要禁言{target},{platform},时长{duration}秒,理由{reason},表达情绪") - + try: await self.send_message(f"[command]mute,{user_id},{duration}") - + except Exception as e: logger.error(f"{self.log_prefix} 执行mute动作时出错: {e}") await self.send_message_by_expressor(f"执行mute动作时出错: {e}") - + return False, "执行mute动作时出错" - - return True, "测试动作执行成功" \ No newline at end of file + + return True, "测试动作执行成功" diff --git a/src/plugins/test_plugin/actions/online_action.py b/src/plugins/test_plugin/actions/online_action.py index 67e2d2cc9..4f49045f2 100644 --- a/src/plugins/test_plugin/actions/online_action.py +++ b/src/plugins/test_plugin/actions/online_action.py @@ -4,15 +4,14 @@ from typing import Tuple logger = get_logger("check_online_action") + @register_action class CheckOnlineAction(PluginAction): """测试动作处理类""" action_name = "check_online_action" action_description = "这是一个检查在线状态的动作,当有人要求你检查Maibot(麦麦 机器人)在线状态时使用" - action_parameters = { - "mode": "查看模式" - } + action_parameters = {"mode": "查看模式"} action_require = [ "当有人要求你检查Maibot(麦麦 机器人)在线状态时使用", "mode参数为version时查看在线版本状态,默认用这种", @@ -23,22 +22,22 @@ class CheckOnlineAction(PluginAction): async def process(self) -> Tuple[bool, str]: """处理测试动作""" logger.info(f"{self.log_prefix} 执行online动作: {self.reasoning}") - + # 发送测试消息 mode = self.action_data.get("mode", "type") - + await self.send_message_by_expressor("我看看") - + try: if mode == "type": - await self.send_message(f"#online detail") + await self.send_message("#online detail") elif mode == "version": - await self.send_message(f"#online") - + await self.send_message("#online") + except Exception as e: logger.error(f"{self.log_prefix} 执行online动作时出错: {e}") await self.send_message_by_expressor("执行online动作时出错: {e}") - + return False, "执行online动作时出错" - - return True, "测试动作执行成功" \ No newline at end of file + + return True, "测试动作执行成功" diff --git a/src/plugins/test_plugin/actions/test_action.py b/src/plugins/test_plugin/actions/test_action.py index 3634dbe78..995dd918a 100644 --- a/src/plugins/test_plugin/actions/test_action.py +++ b/src/plugins/test_plugin/actions/test_action.py @@ -4,15 +4,14 @@ from typing import Tuple logger = get_logger("test_action") + @register_action class TestAction(PluginAction): """测试动作处理类""" action_name = "test_action" action_description = "这是一个测试动作,当有人要求你测试插件系统时使用" - action_parameters = { - "test_param": "测试参数(可选)" - } + action_parameters = {"test_param": "测试参数(可选)"} action_require = [ "测试情况下使用", "想测试插件动作加载时使用", @@ -22,17 +21,17 @@ class TestAction(PluginAction): async def process(self) -> Tuple[bool, str]: """处理测试动作""" logger.info(f"{self.log_prefix} 执行测试动作: {self.reasoning}") - + # 获取聊天类型 chat_type = self.get_chat_type() logger.info(f"{self.log_prefix} 当前聊天类型: {chat_type}") - + # 获取最近消息 recent_messages = self.get_recent_messages(3) logger.info(f"{self.log_prefix} 最近3条消息: {recent_messages}") - + # 发送测试消息 test_param = self.action_data.get("test_param", "默认参数") await self.send_message_by_expressor(f"测试动作执行成功,参数: {test_param}") - - return True, "测试动作执行成功" \ No newline at end of file + + return True, "测试动作执行成功" From e4f7c1fe6274d5f5ecb55d82ad453bcd017932db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Fri, 16 May 2025 17:34:43 +0800 Subject: [PATCH 29/57] =?UTF-8?q?=E4=BC=98=E5=8C=96=E6=95=B0=E6=8D=AE?= =?UTF-8?q?=E5=BA=93=E5=88=9D=E5=A7=8B=E5=8C=96=E9=80=BB=E8=BE=91=EF=BC=8C?= =?UTF-8?q?=E7=A1=AE=E4=BF=9D=E6=95=B0=E6=8D=AE=E5=BA=93=E8=BF=9E=E6=8E=A5?= =?UTF-8?q?=E6=9C=89=E6=95=88=E5=B9=B6=E6=A3=80=E6=9F=A5=E8=A1=A8=E5=8F=8A?= =?UTF-8?q?=E5=AD=97=E6=AE=B5=E7=9A=84=E5=AD=98=E5=9C=A8=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/emoji_system/emoji_manager.py | 25 +++++++++++++++---------- src/common/database/database_model.py | 23 ++++++++++++++++++++--- 2 files changed, 35 insertions(+), 13 deletions(-) diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py index 7b5574691..cb70dba05 100644 --- a/src/chat/emoji_system/emoji_manager.py +++ b/src/chat/emoji_system/emoji_manager.py @@ -383,17 +383,22 @@ class EmojiManager: def initialize(self): """初始化数据库连接和表情目录""" - if not self._initialized: - try: - # Ensure Peewee database connection is up and tables are created - if not peewee_db.is_closed(): - peewee_db.connect(reuse_if_open=True) - Emoji.create_table(safe=True) # Ensures table exists + peewee_db.connect(reuse_if_open=True) + if peewee_db.is_closed(): + raise RuntimeError("数据库连接失败") + _ensure_emoji_dir() + Emoji.create_table(safe=True) # Ensures table exists + # if not self._initialized: + # try: + # # Ensure Peewee database connection is up and tables are created + + + - _ensure_emoji_dir() - self._initialized = True - except Exception as e: - logger.exception(f"初始化表情管理器失败: {e}") + + # self._initialized = True + # except Exception as e: + # logger.exception(f"初始化表情管理器失败: {e}") def _ensure_db(self): """确保数据库已初始化""" diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py index d885312b0..68b73a7b4 100644 --- a/src/common/database/database_model.py +++ b/src/common/database/database_model.py @@ -300,7 +300,10 @@ def create_tables(): def initialize_database(): """ 检查所有定义的表是否存在,如果不存在则创建它们。 + 检查所有表的所有字段是否存在,如果缺失则警告用户并退出程序。 """ + import sys + models = [ ChatStreams, LLMUsage, @@ -319,12 +322,26 @@ def initialize_database(): try: with db: # 管理 table_exists 检查的连接 for model in models: + table_name = model._meta.table_name if not db.table_exists(model): - logger.warning(f"表 '{model._meta.table_name}' 未找到。") + logger.warning(f"表 '{table_name}' 未找到。") needs_creation = True break # 一个表丢失,无需进一步检查。 + if not needs_creation: + # 检查字段 + for model in models: + table_name = model._meta.table_name + cursor = db.execute_sql(f"PRAGMA table_info('{table_name}')") + existing_columns = {row[1] for row in cursor.fetchall()} + model_fields = model._meta.fields + for field_name in model_fields: + if field_name not in existing_columns: + logger.error( + f"表 '{table_name}' 缺失字段 '{field_name}',请手动迁移数据库结构后重启程序。" + ) + sys.exit(1) except Exception as e: - logger.exception(f"检查表是否存在时出错: {e}") + logger.exception(f"检查表或字段是否存在时出错: {e}") # 如果检查失败(例如数据库不可用),则退出 return @@ -336,7 +353,7 @@ def initialize_database(): except Exception as e: logger.exception(f"创建表期间出错: {e}") else: - logger.info("所有数据库表均已存在。") + logger.info("所有数据库表及字段均已存在。") # 模块加载时调用初始化函数 From 335c62c50f42b46df55205de984b5f1114d70bc5 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Fri, 16 May 2025 09:35:00 +0000 Subject: [PATCH 30/57] =?UTF-8?q?=F0=9F=A4=96=20=E8=87=AA=E5=8A=A8?= =?UTF-8?q?=E6=A0=BC=E5=BC=8F=E5=8C=96=E4=BB=A3=E7=A0=81=20[skip=20ci]?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/emoji_system/emoji_manager.py | 4 ---- src/common/database/database_model.py | 4 +--- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py index cb70dba05..ea5a0c2f4 100644 --- a/src/chat/emoji_system/emoji_manager.py +++ b/src/chat/emoji_system/emoji_manager.py @@ -391,11 +391,7 @@ class EmojiManager: # if not self._initialized: # try: # # Ensure Peewee database connection is up and tables are created - - - - # self._initialized = True # except Exception as e: # logger.exception(f"初始化表情管理器失败: {e}") diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py index 68b73a7b4..bd7a2d319 100644 --- a/src/common/database/database_model.py +++ b/src/common/database/database_model.py @@ -336,9 +336,7 @@ def initialize_database(): model_fields = model._meta.fields for field_name in model_fields: if field_name not in existing_columns: - logger.error( - f"表 '{table_name}' 缺失字段 '{field_name}',请手动迁移数据库结构后重启程序。" - ) + logger.error(f"表 '{table_name}' 缺失字段 '{field_name}',请手动迁移数据库结构后重启程序。") sys.exit(1) except Exception as e: logger.exception(f"检查表或字段是否存在时出错: {e}") From 13ae323a1cccb048054928b152fc70fc50fb4e66 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Fri, 16 May 2025 17:45:50 +0800 Subject: [PATCH 31/57] =?UTF-8?q?=E6=9B=B4=E6=96=B0=E8=A1=A8=E6=83=85?= =?UTF-8?q?=E7=AE=A1=E7=90=86=E5=99=A8=E4=B8=AD=E7=9A=84=E7=B1=BB=E5=9E=8B?= =?UTF-8?q?=E6=B3=A8=E8=A7=A3=EF=BC=8C=E7=A1=AE=E4=BF=9D=E5=87=BD=E6=95=B0?= =?UTF-8?q?=E8=BF=94=E5=9B=9E=E7=B1=BB=E5=9E=8B=E6=98=8E=E7=A1=AE=EF=BC=8C?= =?UTF-8?q?=E5=B9=B6=E8=B0=83=E6=95=B4LLMUsage=E8=A1=A8=E5=88=9D=E5=A7=8B?= =?UTF-8?q?=E5=8C=96=E6=97=A5=E5=BF=97=E7=BA=A7=E5=88=AB=E4=B8=BA=E8=B0=83?= =?UTF-8?q?=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/emoji_system/emoji_manager.py | 106 +++++++------------------ src/chat/models/utils_model.py | 2 +- 2 files changed, 31 insertions(+), 77 deletions(-) diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py index ea5a0c2f4..f8e36da15 100644 --- a/src/chat/emoji_system/emoji_manager.py +++ b/src/chat/emoji_system/emoji_manager.py @@ -5,7 +5,7 @@ import os import random import time import traceback -from typing import Optional, Tuple +from typing import Optional, Tuple, List, Any from PIL import Image import io import re @@ -54,7 +54,7 @@ class MaiEmoji: self.is_deleted = False # 标记是否已被删除 self.format = "" - async def initialize_hash_format(self): + async def initialize_hash_format(self) -> Optional[bool]: """从文件创建表情包实例, 计算哈希值和格式""" try: # 使用 full_path 检查文件是否存在 @@ -107,7 +107,7 @@ class MaiEmoji: self.is_deleted = True return None - async def register_to_db(self): + async def register_to_db(self) -> bool: """ 注册表情包 将表情包对应的文件,从当前路径移动到EMOJI_REGISTED_DIR目录下 @@ -176,7 +176,7 @@ class MaiEmoji: logger.error(traceback.format_exc()) return False - async def delete(self): + async def delete(self) -> bool: """删除表情包 删除表情包的文件和数据库记录 @@ -223,7 +223,7 @@ class MaiEmoji: return False -def _emoji_objects_to_readable_list(emoji_objects): +def _emoji_objects_to_readable_list(emoji_objects: List['MaiEmoji']) -> List[str]: """将表情包对象列表转换为可读的字符串列表 参数: @@ -242,7 +242,7 @@ def _emoji_objects_to_readable_list(emoji_objects): return emoji_info_list -def _to_emoji_objects(data): +def _to_emoji_objects(data: Any) -> Tuple[List['MaiEmoji'], int]: emoji_objects = [] load_errors = 0 # data is now an iterable of Peewee Emoji model instances @@ -292,13 +292,13 @@ def _to_emoji_objects(data): return emoji_objects, load_errors -def _ensure_emoji_dir(): +def _ensure_emoji_dir() -> None: """确保表情存储目录存在""" os.makedirs(EMOJI_DIR, exist_ok=True) os.makedirs(EMOJI_REGISTED_DIR, exist_ok=True) -async def clear_temp_emoji(): +async def clear_temp_emoji() -> None: """清理临时表情包 清理/data/emoji和/data/image目录下的所有文件 当目录中文件数超过100时,会全部删除 @@ -320,7 +320,7 @@ async def clear_temp_emoji(): logger.success("[清理] 完成") -async def clean_unused_emojis(emoji_dir, emoji_objects): +async def clean_unused_emojis(emoji_dir: str, emoji_objects: List['MaiEmoji']) -> None: """清理指定目录中未被 emoji_objects 追踪的表情包文件""" if not os.path.exists(emoji_dir): logger.warning(f"[清理] 目标目录不存在,跳过清理: {emoji_dir}") @@ -360,13 +360,13 @@ async def clean_unused_emojis(emoji_dir, emoji_objects): class EmojiManager: _instance = None - def __new__(cls): + def __new__(cls) -> 'EmojiManager': if cls._instance is None: cls._instance = super().__new__(cls) cls._instance._initialized = False return cls._instance - def __init__(self): + def __init__(self) -> None: self._initialized = None self._scan_task = None self.vlm = LLMRequest(model=global_config.vlm, temperature=0.3, max_tokens=1000, request_type="emoji") @@ -377,33 +377,26 @@ class EmojiManager: self.emoji_num = 0 self.emoji_num_max = global_config.max_emoji_num self.emoji_num_max_reach_deletion = global_config.max_reach_deletion - self.emoji_objects: list[MaiEmoji] = [] # 存储MaiEmoji对象的列表,使用类型注解明确列表元素类型 + self.emoji_objects: List[MaiEmoji] = [] # 存储MaiEmoji对象的列表,使用类型注解明确列表元素类型 logger.info("启动表情包管理器") - def initialize(self): + def initialize(self) -> None: """初始化数据库连接和表情目录""" peewee_db.connect(reuse_if_open=True) if peewee_db.is_closed(): raise RuntimeError("数据库连接失败") _ensure_emoji_dir() Emoji.create_table(safe=True) # Ensures table exists - # if not self._initialized: - # try: - # # Ensure Peewee database connection is up and tables are created - # self._initialized = True - # except Exception as e: - # logger.exception(f"初始化表情管理器失败: {e}") - - def _ensure_db(self): + def _ensure_db(self) -> None: """确保数据库已初始化""" if not self._initialized: self.initialize() if not self._initialized: raise RuntimeError("EmojiManager not initialized") - def record_usage(self, emoji_hash: str): + def record_usage(self, emoji_hash: str) -> None: """记录表情使用次数""" try: emoji_update = Emoji.get(Emoji.emoji_hash == emoji_hash) @@ -431,7 +424,6 @@ class EmojiManager: if not all_emojis: logger.warning("内存中没有任何表情包对象") - # 可以考虑再查一次数据库?或者依赖定期任务更新 return None # 计算每个表情包与输入文本的最大情感相似度 @@ -447,18 +439,18 @@ class EmojiManager: # 计算与每个emotion标签的相似度,取最大值 max_similarity = 0 - best_matching_emotion = "" # 记录最匹配的 emotion 喵~ + best_matching_emotion = "" for emotion in emotions: # 使用编辑距离计算相似度 distance = self._levenshtein_distance(text_emotion, emotion) max_len = max(len(text_emotion), len(emotion)) similarity = 1 - (distance / max_len if max_len > 0 else 0) - if similarity > max_similarity: # 如果找到更相似的喵~ + if similarity > max_similarity: max_similarity = similarity - best_matching_emotion = emotion # 就记下这个 emotion 喵~ + best_matching_emotion = emotion - if best_matching_emotion: # 确保有匹配的情感才添加喵~ - emoji_similarities.append((emoji, max_similarity, best_matching_emotion)) # 把 emotion 也存起来喵~ + if best_matching_emotion: + emoji_similarities.append((emoji, max_similarity, best_matching_emotion)) # 按相似度降序排序 emoji_similarities.sort(key=lambda x: x[1], reverse=True) @@ -466,21 +458,21 @@ class EmojiManager: # 获取前10个最相似的表情包 top_emojis = ( emoji_similarities[:10] if len(emoji_similarities) > 10 else emoji_similarities - ) # 改个名字,更清晰喵~ + ) if not top_emojis: logger.warning("未找到匹配的表情包") return None # 从前几个中随机选择一个 - selected_emoji, similarity, matched_emotion = random.choice(top_emojis) # 把匹配的 emotion 也拿出来喵~ + selected_emoji, similarity, matched_emotion = random.choice(top_emojis) # 更新使用次数 self.record_usage(selected_emoji.emoji_hash) _time_end = time.time() - logger.info( # 使用匹配到的 emotion 记录日志喵~ + logger.info( f"为[{text_emotion}]找到表情包: {matched_emotion} ({selected_emoji.filename}), Similarity: {similarity:.4f}" ) # 返回完整文件路径和描述 @@ -518,7 +510,7 @@ class EmojiManager: return previous_row[-1] - async def check_emoji_file_integrity(self): + async def check_emoji_file_integrity(self) -> None: """检查表情包文件完整性 遍历self.emoji_objects中的所有对象,检查文件是否存在 如果文件已被删除,则执行对象的删除方法并从列表中移除 @@ -583,7 +575,7 @@ class EmojiManager: logger.error(f"[错误] 检查表情包完整性失败: {str(e)}") logger.error(traceback.format_exc()) - async def start_periodic_check_register(self): + async def start_periodic_check_register(self) -> None: """定期检查表情包完整性和数量""" await self.get_all_emoji_from_db() while True: @@ -637,7 +629,7 @@ class EmojiManager: await asyncio.sleep(global_config.EMOJI_CHECK_INTERVAL * 60) - async def get_all_emoji_from_db(self): + async def get_all_emoji_from_db(self) -> None: """获取所有表情包并初始化为MaiEmoji类对象,更新 self.emoji_objects""" try: self._ensure_db() @@ -659,7 +651,7 @@ class EmojiManager: self.emoji_objects = [] # 加载失败则清空列表 self.emoji_num = 0 - async def get_emoji_from_db(self, emoji_hash=None): + async def get_emoji_from_db(self, emoji_hash: Optional[str] = None) -> List['MaiEmoji']: """获取指定哈希值的表情包并初始化为MaiEmoji类对象列表 (主要用于调试或特定查找) 参数: @@ -691,7 +683,7 @@ class EmojiManager: logger.error(f"[错误] 从数据库获取表情包对象失败: {str(e)}") return [] - async def get_emoji_from_manager(self, emoji_hash) -> Optional[MaiEmoji]: + async def get_emoji_from_manager(self, emoji_hash: str) -> Optional['MaiEmoji']: """从内存中的 emoji_objects 列表获取表情包 参数: @@ -744,7 +736,7 @@ class EmojiManager: logger.error(traceback.format_exc()) return False - async def replace_a_emoji(self, new_emoji: MaiEmoji): + async def replace_a_emoji(self, new_emoji: 'MaiEmoji') -> bool: """替换一个表情包 Args: @@ -833,7 +825,7 @@ class EmojiManager: logger.error(traceback.format_exc()) return False - async def build_emoji_description(self, image_base64: str) -> Tuple[str, list]: + async def build_emoji_description(self, image_base64: str) -> Tuple[str, List[str]]: """获取表情包描述和情感列表 Args: @@ -894,44 +886,6 @@ class EmojiManager: logger.error(f"获取表情包描述失败: {str(e)}") return "", [] - # async def register_emoji_by_filename(self, filename: str) -> bool: - # if global_config.EMOJI_CHECK: - # prompt = f''' - # 这是一个表情包,请对这个表情包进行审核,标准如下: - # 1. 必须符合"{global_config.EMOJI_CHECK_PROMPT}"的要求 - # 2. 不能是色情、暴力、等违法违规内容,必须符合公序良俗 - # 3. 不能是任何形式的截图,聊天记录或视频截图 - # 4. 不要出现5个以上文字 - # 请回答这个表情包是否满足上述要求,是则回答是,否则回答否,不要出现任何其他内容 - # ''' - # content, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format) - # if content == "否": - # return "", [] - - # # 分析情感含义 - # emotion_prompt = f""" - # 请你识别这个表情包的含义和适用场景,给我简短的描述,每个描述不要超过15个字 - # 这是一个基于这个表情包的描述:'{description}' - # 你可以关注其幽默和讽刺意味,动用贴吧,微博,小红书的知识,必须从互联网梗,meme的角度去分析 - # 请直接输出描述,不要出现任何其他内容,如果有多个描述,可以用逗号分隔 - # """ - # emotions_text, _ = await self.llm_emotion_judge.generate_response_async(emotion_prompt, temperature=0.7) - - # # 处理情感列表 - # emotions = [e.strip() for e in emotions_text.split(",") if e.strip()] - - # # 根据情感标签数量随机选择喵~超过5个选3个,超过2个选2个 - # if len(emotions) > 5: - # emotions = random.sample(emotions, 3) - # elif len(emotions) > 2: - # emotions = random.sample(emotions, 2) - - # return f"[表情包:{description}]", emotions - - # except Exception as e: - # logger.error(f"获取表情包描述失败: {str(e)}") - # return "", [] - async def register_emoji_by_filename(self, filename: str) -> bool: """读取指定文件名的表情包图片,分析并注册到数据库 diff --git a/src/chat/models/utils_model.py b/src/chat/models/utils_model.py index 986036e86..eae0ae01a 100644 --- a/src/chat/models/utils_model.py +++ b/src/chat/models/utils_model.py @@ -135,7 +135,7 @@ class LLMRequest: try: # 使用 Peewee 创建表,safe=True 表示如果表已存在则不会抛出错误 db.create_tables([LLMUsage], safe=True) - logger.info("LLMUsage 表已初始化/确保存在。") + logger.debug("LLMUsage 表已初始化/确保存在。") except Exception as e: logger.error(f"创建 LLMUsage 表失败: {str(e)}") From f5132db6f1b35c95f7f66b45ad6e8b580e32fd20 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Fri, 16 May 2025 09:46:02 +0000 Subject: [PATCH 32/57] =?UTF-8?q?=F0=9F=A4=96=20=E8=87=AA=E5=8A=A8?= =?UTF-8?q?=E6=A0=BC=E5=BC=8F=E5=8C=96=E4=BB=A3=E7=A0=81=20[skip=20ci]?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/emoji_system/emoji_manager.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py index f8e36da15..2cdead064 100644 --- a/src/chat/emoji_system/emoji_manager.py +++ b/src/chat/emoji_system/emoji_manager.py @@ -223,7 +223,7 @@ class MaiEmoji: return False -def _emoji_objects_to_readable_list(emoji_objects: List['MaiEmoji']) -> List[str]: +def _emoji_objects_to_readable_list(emoji_objects: List["MaiEmoji"]) -> List[str]: """将表情包对象列表转换为可读的字符串列表 参数: @@ -242,7 +242,7 @@ def _emoji_objects_to_readable_list(emoji_objects: List['MaiEmoji']) -> List[str return emoji_info_list -def _to_emoji_objects(data: Any) -> Tuple[List['MaiEmoji'], int]: +def _to_emoji_objects(data: Any) -> Tuple[List["MaiEmoji"], int]: emoji_objects = [] load_errors = 0 # data is now an iterable of Peewee Emoji model instances @@ -320,7 +320,7 @@ async def clear_temp_emoji() -> None: logger.success("[清理] 完成") -async def clean_unused_emojis(emoji_dir: str, emoji_objects: List['MaiEmoji']) -> None: +async def clean_unused_emojis(emoji_dir: str, emoji_objects: List["MaiEmoji"]) -> None: """清理指定目录中未被 emoji_objects 追踪的表情包文件""" if not os.path.exists(emoji_dir): logger.warning(f"[清理] 目标目录不存在,跳过清理: {emoji_dir}") @@ -360,7 +360,7 @@ async def clean_unused_emojis(emoji_dir: str, emoji_objects: List['MaiEmoji']) - class EmojiManager: _instance = None - def __new__(cls) -> 'EmojiManager': + def __new__(cls) -> "EmojiManager": if cls._instance is None: cls._instance = super().__new__(cls) cls._instance._initialized = False @@ -456,9 +456,7 @@ class EmojiManager: emoji_similarities.sort(key=lambda x: x[1], reverse=True) # 获取前10个最相似的表情包 - top_emojis = ( - emoji_similarities[:10] if len(emoji_similarities) > 10 else emoji_similarities - ) + top_emojis = emoji_similarities[:10] if len(emoji_similarities) > 10 else emoji_similarities if not top_emojis: logger.warning("未找到匹配的表情包") @@ -651,7 +649,7 @@ class EmojiManager: self.emoji_objects = [] # 加载失败则清空列表 self.emoji_num = 0 - async def get_emoji_from_db(self, emoji_hash: Optional[str] = None) -> List['MaiEmoji']: + async def get_emoji_from_db(self, emoji_hash: Optional[str] = None) -> List["MaiEmoji"]: """获取指定哈希值的表情包并初始化为MaiEmoji类对象列表 (主要用于调试或特定查找) 参数: @@ -683,7 +681,7 @@ class EmojiManager: logger.error(f"[错误] 从数据库获取表情包对象失败: {str(e)}") return [] - async def get_emoji_from_manager(self, emoji_hash: str) -> Optional['MaiEmoji']: + async def get_emoji_from_manager(self, emoji_hash: str) -> Optional["MaiEmoji"]: """从内存中的 emoji_objects 列表获取表情包 参数: @@ -736,7 +734,7 @@ class EmojiManager: logger.error(traceback.format_exc()) return False - async def replace_a_emoji(self, new_emoji: 'MaiEmoji') -> bool: + async def replace_a_emoji(self, new_emoji: "MaiEmoji") -> bool: """替换一个表情包 Args: From e067384985dc64dda2d5fcabf4ed9fbd4f510892 Mon Sep 17 00:00:00 2001 From: Dreamwxz Date: Fri, 16 May 2025 17:52:33 +0800 Subject: [PATCH 33/57] =?UTF-8?q?=E5=B0=86=E6=96=87=E4=BB=B6=E5=85=A8?= =?UTF-8?q?=E9=83=A8=E5=BD=92=E8=BF=9Bdocs?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 146 ++++-------------- changelogs/changelog_dev.md | 27 ---- {src => docs}/0.6Bing.md | 0 {depends-data => docs}/CONTRIBUTE.md | 0 .../HeartFC_chatting_logic.md | 0 .../HeartFC_readme.md | 0 src/README.md => docs/HeartFC_system.md | 0 .../README.md => docs/use_tool.md | 0 8 files changed, 34 insertions(+), 139 deletions(-) delete mode 100644 changelogs/changelog_dev.md rename {src => docs}/0.6Bing.md (100%) rename {depends-data => docs}/CONTRIBUTE.md (100%) rename src/heartFC_chatting_logic.md => docs/HeartFC_chatting_logic.md (100%) rename src/heartFC_readme.md => docs/HeartFC_readme.md (100%) rename src/README.md => docs/HeartFC_system.md (100%) rename src/tools/tool_can_use/README.md => docs/use_tool.md (100%) diff --git a/README.md b/README.md index 17a8da37b..98a8076ac 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,8 @@ -# 麦麦!MaiCore-MaiMBot (编辑中) -
-

+# 麦麦!MaiCore-MaiBot (编辑中) + + + MaiBot + ![Python Version](https://img.shields.io/badge/Python-3.10+-blue) ![License](https://img.shields.io/github/license/SengokuCola/MaiMBot?label=协议) @@ -10,40 +12,14 @@ ![stars](https://img.shields.io/github/stars/MaiM-with-u/MaiBot?style=flat&label=星标数) ![issues](https://img.shields.io/github/issues/MaiM-with-u/MaiBot) [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/DrSmoothl/MaiBot) + + 🌟 案例展示 | 🚀 快速入门 | 📃 教程 | 💬 讨论 | 🙋 常见问题
-

- - Logo - -
- - 画师:略nd - - -

MaiBot(麦麦)

-

- 一款专注于 群组聊天 的赛博网友 -
- 探索本项目的文档 » -
-
- - 报告Bug - · - 提出新特性 -

-

- -## 新版0.6.x部署前先阅读:https://docs.mai-mai.org/faq/maibot/backup_update.html - - ## 📝 项目简介 **🍔MaiCore是一个基于大语言模型的可交互智能体** - - 💭 **智能对话系统**:基于LLM的自然语言交互 - 🤔 **实时思维系统**:模拟人类思考过程 - 💝 **情感表达系统**:丰富的表情包和情绪表达 @@ -58,49 +34,29 @@
- ### 📢 版本信息 -**最新版本: v0.6.3** ([查看更新日志](changelogs/changelog.md)) -> [!WARNING] -> 请阅读教程后更新!!!!!!! -> 请阅读教程后更新!!!!!!! -> 请阅读教程后更新!!!!!!! -> 次版本MaiBot将基于MaiCore运行,不再依赖于nonebot相关组件运行。 -> MaiBot将通过nonebot的插件与nonebot建立联系,然后nonebot与QQ建立联系,实现MaiBot与QQ的交互 +**最新版本: v0.6.3** ([更新日志](changelogs/changelog.md)) -**分支说明:** +**GitHub分支说明:** - `main`: 稳定发布版本 -- `dev`: 开发测试版本(不知道什么意思就别下) -- `classical`: 0.6.0之前的版本 - - - -> [!WARNING] -> - 项目处于活跃开发阶段,代码可能随时更改 -> - 文档未完善,有问题可以提交 Issue 或者 Discussion -> - QQ机器人存在被限制风险,请自行了解,谨慎使用 -> - 由于持续迭代,可能存在一些已知或未知的bug -> - 由于开发中,可能消耗较多token +- `dev`: 开发测试版本(不稳定) +- `classical`: 0.6.0之前的版本(停止维护) ### ⚠️ 重要提示 -- 升级到v0.6.x版本前请务必阅读:[升级指南](https://docs.mai-mai.org/faq/maibot/backup_update.html) -- 本版本基于MaiCore重构,通过nonebot插件与QQ平台交互 -- 项目处于活跃开发阶段,功能和API可能随时调整 - -### 💬交流群(开发和建议相关讨论)不一定有空回复,会优先写文档和代码 -- [一群](https://qm.qq.com/q/VQ3XZrWgMs) 766798517 -- [二群](https://qm.qq.com/q/RzmCiRtHEW) 571780722 -- [五群](https://qm.qq.com/q/JxvHZnxyec) 1022489779 -- [三群](https://qm.qq.com/q/wlH5eT8OmQ) 1035228475【已满】 -- [四群](https://qm.qq.com/q/wGePTl1UyY) 729957033【已满】 - +- 从0.5.x旧版本升级前请请务必阅读:[升级指南](https://docs.mai-mai.org/faq/maibot/backup_update.html) +- 项目处于活跃开发阶段,功能和API可能随时调整。 +- 文档未完善,有问题可以提交 Issue 或者 Discussion。 +- QQ机器人存在被限制风险,请自行了解,谨慎使用。 +- 由于持续迭代,可能存在一些已知或未知的bug。 +- 由于程序处于开发中,可能消耗较多token。 +### 💬交流群 +- [一群](https://qm.qq.com/q/VQ3XZrWgMs) | [二群](https://qm.qq.com/q/RzmCiRtHEW) | [五群](https://qm.qq.com/q/JxvHZnxyec) | [三群](https://qm.qq.com/q/wlH5eT8OmQ)(已满)| [四群](https://qm.qq.com/q/wGePTl1UyY)(已满) ## 📚 文档 - ### (部分内容可能过时,请注意版本对应) ### 核心文档 @@ -109,45 +65,9 @@ ### 最新版本部署教程(MaiCore版本) - [🚀 最新版本部署教程](https://docs.mai-mai.org/manual/deployment/mmc_deploy_windows.html) - 基于MaiCore的新版本部署方式(与旧版本不兼容) - -## 🎯 0.6.3 功能介绍 - -| 模块 | 主要功能 | 特点 | -|----------|------------------------------------------------------------------|-------| -| 💬 聊天系统 | • **统一调控不同回复逻辑**
• 智能交互模式 (普通聊天/专注聊天)
• 关键词主动发言
• 多模型支持
• 动态prompt构建
• 私聊功能(PFC)增强 | 拟人化交互 | -| 🧠 心流系统 | • 实时思考生成
• **智能状态管理**
• **概率回复机制**
• 自动启停机制
• 日程系统联动
• **上下文感知工具调用** | 智能化决策 | -| 🧠 记忆系统 | • **记忆整合与提取**
• 海马体记忆机制
• 聊天记录概括 | 持久化记忆 | -| 😊 表情系统 | • **全新表情包系统**
• **优化选择逻辑**
• 情绪匹配发送
• GIF支持
• 自动收集与审查 | 丰富表达 | -| 📅 日程系统 | • 动态日程生成
• 自定义想象力
• 思维流联动 | 智能规划 | -| 👥 关系系统 | • **工具调用动态更新**
• 关系管理优化
• 丰富接口支持
• 个性化交互 | 深度社交 | -| 📊 统计系统 | • 使用数据统计
• LLM调用记录
• 实时控制台显示 | 数据可视 | -| 🛠️ 工具系统 | • **LPMM知识库集成**
• **上下文感知调用**
• 知识获取工具
• 自动注册机制
• 多工具支持 | 扩展功能 | -| 📚 **知识库(LPMM)** | • **全新LPMM系统**
• **强大的信息检索能力** | 知识增强 | -| ✨ **昵称系统** | • **自动为群友取昵称**
• **降低认错人概率** (早期阶段) | 身份识别 | - -## 📐 项目架构 - -```mermaid -graph TD - A[MaiCore] --> B[对话系统] - A --> C[心流系统] - A --> D[记忆系统] - A --> E[情感系统] - B --> F[多模型支持] - B --> G[动态Prompt] - C --> H[实时思考] - C --> I[日程联动] - D --> J[记忆存储] - D --> K[记忆检索] - E --> L[表情管理] - E --> M[情绪识别] -``` - ## ✍️如何给本项目报告BUG/提交建议/做贡献 -MaiCore是一个开源项目,我们非常欢迎你的参与。你的贡献,无论是提交bug报告、功能需求还是代码pr,都对项目非常宝贵。我们非常感谢你的支持!🎉 但无序的讨论会降低沟通效率,进而影响问题的解决速度,因此在提交任何贡献前,请务必先阅读本项目的[贡献指南](depends-data/CONTRIBUTE.md)(待补完) - - +MaiCore是一个开源项目,我们非常欢迎你的参与。你的贡献,无论是提交bug报告、功能需求还是代码pr,都对项目非常宝贵。我们非常感谢你的支持!🎉 但无序的讨论会降低沟通效率,进而影响问题的解决速度,因此在提交任何贡献前,请务必先阅读本项目的[贡献指南](docs/CONTRIBUTE.md)(待补完) ## 设计理念(原始时代的火花) @@ -162,27 +82,29 @@ MaiCore是一个开源项目,我们非常欢迎你的参与。你的贡献, ## 📌 注意事项 > [!WARNING] -> 使用本项目前必须阅读和同意用户协议和隐私协议 +> 使用本项目前必须阅读和同意[用户协议](https://docs.mai-mai.org/manual/other/EULA.html)和隐私协议。 > 本应用生成内容来自人工智能模型,由 AI 生成,请仔细甄别,请勿用于违反法律的用途,AI生成内容不代表本人观点和立场。 -## 致谢 - -- [NapCat](https://github.com/NapNeko/NapCatQQ): 现代化的基于 NTQQ 的 Bot 协议端实现 - -## 麦麦仓库状态 - -![Alt](https://repobeats.axiom.co/api/embed/9faca9fccfc467931b87dd357b60c6362b5cfae0.svg "Repobeats analytics image") - -### 贡献者 +## 贡献者 感谢各位大佬! contributors + + 画师:略nd + +### 致谢 + +- [NapCat](https://github.com/NapNeko/NapCatQQ): 现代化的基于 NTQQ 的 Bot 协议端实现 **也感谢每一位给麦麦发展提出宝贵意见与建议的用户,感谢陪伴麦麦走到现在的你们** -## Stargazers over time +## 麦麦仓库状态 -[![Stargazers over time](https://starchart.cc/MaiM-with-u/MaiBot.svg?variant=adaptive)](https://starchart.cc/MaiM-with-u/MaiBot) +![Alt](https://repobeats.axiom.co/api/embed/9faca9fccfc467931b87dd357b60c6362b5cfae0.svg "Repobeats analytics image") + +## Star 趋势 + +[![Star 趋势](https://starchart.cc/MaiM-with-u/MaiBot.svg?variant=adaptive)](https://starchart.cc/MaiM-with-u/MaiBot) diff --git a/changelogs/changelog_dev.md b/changelogs/changelog_dev.md deleted file mode 100644 index 663ad9629..000000000 --- a/changelogs/changelog_dev.md +++ /dev/null @@ -1,27 +0,0 @@ -这里放置了测试版本的细节更新 - -## [test-0.6.1-snapshot-1] - 2025-4-5 -- 修复pfc回复出错bug -- 修复表情包打字时间,不会卡表情包 -- 改进了知识库的提取 -- 提供了新的数据库连接方式 -- 修复了ban_user无效的问题 - -## [test-0.6.0-snapshot-9] - 2025-4-4 -- 可以识别gif表情包 - -## [test-0.6.0-snapshot-8] - 2025-4-3 -- 修复了表情包的注册,获取和发送逻辑 -- 表情包增加存储上限 -- 更改了回复引用的逻辑,从基于时间改为基于新消息 -- 增加了调试信息 -- 自动清理缓存图片 -- 修复并重启了关系系统 - -## [test-0.6.0-snapshot-7] - 2025-4-2 -- 修改版本号命名:test-前缀为测试版,无前缀为正式版 -- 提供私聊的PFC模式,可以进行有目的,自由多轮对话 - -## [0.6.0-mmc-4] - 2025-4-1 -- 提供两种聊天逻辑,思维流聊天(ThinkFlowChat 和 推理聊天(ReasoningChat) -- 从结构上可支持多种回复消息逻辑 \ No newline at end of file diff --git a/src/0.6Bing.md b/docs/0.6Bing.md similarity index 100% rename from src/0.6Bing.md rename to docs/0.6Bing.md diff --git a/depends-data/CONTRIBUTE.md b/docs/CONTRIBUTE.md similarity index 100% rename from depends-data/CONTRIBUTE.md rename to docs/CONTRIBUTE.md diff --git a/src/heartFC_chatting_logic.md b/docs/HeartFC_chatting_logic.md similarity index 100% rename from src/heartFC_chatting_logic.md rename to docs/HeartFC_chatting_logic.md diff --git a/src/heartFC_readme.md b/docs/HeartFC_readme.md similarity index 100% rename from src/heartFC_readme.md rename to docs/HeartFC_readme.md diff --git a/src/README.md b/docs/HeartFC_system.md similarity index 100% rename from src/README.md rename to docs/HeartFC_system.md diff --git a/src/tools/tool_can_use/README.md b/docs/use_tool.md similarity index 100% rename from src/tools/tool_can_use/README.md rename to docs/use_tool.md From 174a47e1df4c69716bc892b3831f31e1f9366e76 Mon Sep 17 00:00:00 2001 From: Dreamwxz Date: Fri, 16 May 2025 18:27:34 +0800 Subject: [PATCH 34/57] =?UTF-8?q?=E8=B0=83=E6=95=B4=E7=BB=93=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 42 +++++++++++++++++------------------------- 1 file changed, 17 insertions(+), 25 deletions(-) diff --git a/README.md b/README.md index 98a8076ac..a0113e435 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ 🌟 案例展示 | 🚀 快速入门 | 📃 教程 | 💬 讨论 | 🙋 常见问题 -## 📝 项目简介 +## 🎉 介绍 **🍔MaiCore是一个基于大语言模型的可交互智能体** @@ -34,15 +34,16 @@ -### 📢 版本信息 - +## 🔥 更新和安装 **最新版本: v0.6.3** ([更新日志](changelogs/changelog.md)) - **GitHub分支说明:** - `main`: 稳定发布版本 - `dev`: 开发测试版本(不稳定) - `classical`: 0.6.0之前的版本(停止维护) +### 最新版本部署教程(MaiCore版本) +- [🚀 最新版本部署教程](https://docs.mai-mai.org/manual/deployment/mmc_deploy_windows.html) - 基于MaiCore的新版本部署方式(与旧版本不兼容) + ### ⚠️ 重要提示 - 从0.5.x旧版本升级前请请务必阅读:[升级指南](https://docs.mai-mai.org/faq/maibot/backup_update.html) @@ -52,24 +53,15 @@ - 由于持续迭代,可能存在一些已知或未知的bug。 - 由于程序处于开发中,可能消耗较多token。 -### 💬交流群 +## 💬 讨论 - [一群](https://qm.qq.com/q/VQ3XZrWgMs) | [二群](https://qm.qq.com/q/RzmCiRtHEW) | [五群](https://qm.qq.com/q/JxvHZnxyec) | [三群](https://qm.qq.com/q/wlH5eT8OmQ)(已满)| [四群](https://qm.qq.com/q/wGePTl1UyY)(已满) ## 📚 文档 +**部分内容可能更新不够及时,请注意版本对应** -### (部分内容可能过时,请注意版本对应) - -### 核心文档 - [📚 核心Wiki文档](https://docs.mai-mai.org) - 项目最全面的文档中心,你可以了解麦麦有关的一切 -### 最新版本部署教程(MaiCore版本) -- [🚀 最新版本部署教程](https://docs.mai-mai.org/manual/deployment/mmc_deploy_windows.html) - 基于MaiCore的新版本部署方式(与旧版本不兼容) - -## ✍️如何给本项目报告BUG/提交建议/做贡献 - -MaiCore是一个开源项目,我们非常欢迎你的参与。你的贡献,无论是提交bug报告、功能需求还是代码pr,都对项目非常宝贵。我们非常感谢你的支持!🎉 但无序的讨论会降低沟通效率,进而影响问题的解决速度,因此在提交任何贡献前,请务必先阅读本项目的[贡献指南](docs/CONTRIBUTE.md)(待补完) - -## 设计理念(原始时代的火花) +### 设计理念(原始时代的火花) > **千石可乐说:** > - 这个项目最初只是为了给牛牛bot添加一点额外的功能,但是功能越写越多,最后决定重写。其目的是为了创造一个活跃在QQ群聊的"生命体"。可以目的并不是为了写一个功能齐全的机器人,而是一个尽可能让人感知到真实的类人存在。 @@ -82,21 +74,21 @@ MaiCore是一个开源项目,我们非常欢迎你的参与。你的贡献, ## 📌 注意事项 > [!WARNING] -> 使用本项目前必须阅读和同意[用户协议](https://docs.mai-mai.org/manual/other/EULA.html)和隐私协议。 -> 本应用生成内容来自人工智能模型,由 AI 生成,请仔细甄别,请勿用于违反法律的用途,AI生成内容不代表本人观点和立场。 +> 使用本项目前必须阅读和同意[用户协议](https://docs.mai-mai.org/manual/other/EULA.html)和[隐私协议](PRIVACY.md)。 +> 本应用生成内容来自人工智能模型,由 AI 生成,请仔细甄别,请勿用于违反法律的用途,AI生成内容不代表本项目团队的观点和立场。 -## 贡献者 +## 贡献和致谢 +MaiCore是一个开源项目,我们非常欢迎你的参与。你的贡献,无论是提交bug报告、功能需求还是代码pr,都对项目非常宝贵。我们非常感谢你的支持!🎉 但无序的讨论会降低沟通效率,进而影响问题的解决速度,因此在提交任何贡献前,请务必先阅读本项目的[贡献指南](docs/CONTRIBUTE.md)(待补完) -感谢各位大佬! +### 贡献 +感谢各位大佬! contributors - - 画师:略nd - -### 致谢 +### 致谢 +- [略nd](https://space.bilibili.com/1344099355): 麦麦人设 - [NapCat](https://github.com/NapNeko/NapCatQQ): 现代化的基于 NTQQ 的 Bot 协议端实现 **也感谢每一位给麦麦发展提出宝贵意见与建议的用户,感谢陪伴麦麦走到现在的你们** @@ -105,6 +97,6 @@ MaiCore是一个开源项目,我们非常欢迎你的参与。你的贡献, ![Alt](https://repobeats.axiom.co/api/embed/9faca9fccfc467931b87dd357b60c6362b5cfae0.svg "Repobeats analytics image") -## Star 趋势 +### Star 趋势 [![Star 趋势](https://starchart.cc/MaiM-with-u/MaiBot.svg?variant=adaptive)](https://starchart.cc/MaiM-with-u/MaiBot) From c3ffe5b6ad02518e64d17fa3095354bf61e25ec6 Mon Sep 17 00:00:00 2001 From: Dreamwxz Date: Fri, 16 May 2025 21:25:39 +0800 Subject: [PATCH 35/57] Update README.md MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 格式化,改图片为固定高度 --- README.md | 67 ++++++++++++++++++++++++++++++++++--------------------- 1 file changed, 41 insertions(+), 26 deletions(-) diff --git a/README.md b/README.md index a0113e435..a6dec73ed 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # 麦麦!MaiCore-MaiBot (编辑中) - MaiBot + MaiBot ![Python Version](https://img.shields.io/badge/Python-3.10+-blue) @@ -13,74 +13,88 @@ ![issues](https://img.shields.io/github/issues/MaiM-with-u/MaiBot) [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/DrSmoothl/MaiBot) - 🌟 案例展示 | 🚀 快速入门 | 📃 教程 | 💬 讨论 | 🙋 常见问题 - + +🌟 案例展示 | +🚀 快速入门 | +📃 教程 | +💬 讨论 | +🙋 常见问题 + ## 🎉 介绍 -**🍔MaiCore是一个基于大语言模型的可交互智能体** +**🍔MaiCore 是一个基于大语言模型的可交互智能体** -- 💭 **智能对话系统**:基于LLM的自然语言交互 +- 💭 **智能对话系统**:基于 LLM 的自然语言交互 - 🤔 **实时思维系统**:模拟人类思考过程 - 💝 **情感表达系统**:丰富的表情包和情绪表达 -- 🧠 **持久记忆系统**:基于MongoDB的长期记忆存储 +- 🧠 **持久记忆系统**:基于 MongoDB 的长期记忆存储 - 🔄 **动态人格系统**:自适应的性格特征 ## 🔥 更新和安装 + **最新版本: v0.6.3** ([更新日志](changelogs/changelog.md)) -**GitHub分支说明:** + +**GitHub 分支说明:** - `main`: 稳定发布版本 - `dev`: 开发测试版本(不稳定) -- `classical`: 0.6.0之前的版本(停止维护) +- `classical`: 0.6.0 之前的版本 (停止维护) -### 最新版本部署教程(MaiCore版本) -- [🚀 最新版本部署教程](https://docs.mai-mai.org/manual/deployment/mmc_deploy_windows.html) - 基于MaiCore的新版本部署方式(与旧版本不兼容) +### 最新版本部署教程 (MaiCore 版本) +- [🚀 最新版本部署教程](https://docs.mai-mai.org/manual/deployment/mmc_deploy_windows.html) - 基于 MaiCore 的新版本部署方式(与旧版本不兼容) ### ⚠️ 重要提示 -- 从0.5.x旧版本升级前请请务必阅读:[升级指南](https://docs.mai-mai.org/faq/maibot/backup_update.html) -- 项目处于活跃开发阶段,功能和API可能随时调整。 +- 从 0.5.x 旧版本升级前请务必阅读:[升级指南](https://docs.mai-mai.org/faq/maibot/backup_update.html) +- 项目处于活跃开发阶段,功能和 API 可能随时调整。 - 文档未完善,有问题可以提交 Issue 或者 Discussion。 -- QQ机器人存在被限制风险,请自行了解,谨慎使用。 -- 由于持续迭代,可能存在一些已知或未知的bug。 -- 由于程序处于开发中,可能消耗较多token。 +- QQ 机器人存在被限制风险,请自行了解,谨慎使用。 +- 由于持续迭代,可能存在一些已知或未知的 bug。 +- 由于程序处于开发中,可能消耗较多 token。 ## 💬 讨论 -- [一群](https://qm.qq.com/q/VQ3XZrWgMs) | [二群](https://qm.qq.com/q/RzmCiRtHEW) | [五群](https://qm.qq.com/q/JxvHZnxyec) | [三群](https://qm.qq.com/q/wlH5eT8OmQ)(已满)| [四群](https://qm.qq.com/q/wGePTl1UyY)(已满) + +- [一群](https://qm.qq.com/q/VQ3XZrWgMs) | + [二群](https://qm.qq.com/q/RzmCiRtHEW) | + [五群](https://qm.qq.com/q/JxvHZnxyec) | + [三群](https://qm.qq.com/q/wlH5eT8OmQ)(已满)| + [四群](https://qm.qq.com/q/wGePTl1UyY)(已满) ## 📚 文档 + **部分内容可能更新不够及时,请注意版本对应** -- [📚 核心Wiki文档](https://docs.mai-mai.org) - 项目最全面的文档中心,你可以了解麦麦有关的一切 +- [📚 核心 Wiki 文档](https://docs.mai-mai.org) - 项目最全面的文档中心,你可以了解麦麦有关的一切 ### 设计理念(原始时代的火花) > **千石可乐说:** -> - 这个项目最初只是为了给牛牛bot添加一点额外的功能,但是功能越写越多,最后决定重写。其目的是为了创造一个活跃在QQ群聊的"生命体"。可以目的并不是为了写一个功能齐全的机器人,而是一个尽可能让人感知到真实的类人存在。 +> - 这个项目最初只是为了给牛牛 bot 添加一点额外的功能,但是功能越写越多,最后决定重写。其目的是为了创造一个活跃在 QQ 群聊的"生命体"。目的并不是为了写一个功能齐全的机器人,而是一个尽可能让人感知到真实的类人存在。 > - 程序的功能设计理念基于一个核心的原则:"最像而不是好" -> - 如果人类真的需要一个AI来陪伴自己,并不是所有人都需要一个完美的,能解决所有问题的"helpful assistant",而是一个会犯错的,拥有自己感知和想法的"生命形式"。 -> - 代码会保持开源和开放,但个人希望MaiMbot的运行时数据保持封闭,尽量避免以显式命令来对其进行控制和调试.我认为一个你无法完全掌控的个体才更能让你感觉到它的自主性,而视其成为一个对话机器. -> - SengokuCola~~纯编程外行,面向cursor编程,很多代码写得不好多多包涵~~已得到大脑升级 - +> - 如果人类真的需要一个 AI 来陪伴自己,并不是所有人都需要一个完美的,能解决所有问题的"helpful assistant",而是一个会犯错的,拥有自己感知和想法的"生命形式"。 +> - 代码会保持开源和开放,但个人希望 MaiMbot 的运行时数据保持封闭,尽量避免以显式命令来对其进行控制和调试。我认为一个你无法完全掌控的个体才更能让你感觉到它的自主性,而视其成为一个对话机器。 +> - SengokuCola~~纯编程外行,面向 cursor 编程,很多代码写得不好多多包涵~~已得到大脑升级 ## 📌 注意事项 > [!WARNING] > 使用本项目前必须阅读和同意[用户协议](https://docs.mai-mai.org/manual/other/EULA.html)和[隐私协议](PRIVACY.md)。 -> 本应用生成内容来自人工智能模型,由 AI 生成,请仔细甄别,请勿用于违反法律的用途,AI生成内容不代表本项目团队的观点和立场。 +> 本应用生成内容来自人工智能模型,由 AI 生成,请仔细甄别,请勿用于违反法律的用途,AI 生成内容不代表本项目团队的观点和立场。 ## 贡献和致谢 -MaiCore是一个开源项目,我们非常欢迎你的参与。你的贡献,无论是提交bug报告、功能需求还是代码pr,都对项目非常宝贵。我们非常感谢你的支持!🎉 但无序的讨论会降低沟通效率,进而影响问题的解决速度,因此在提交任何贡献前,请务必先阅读本项目的[贡献指南](docs/CONTRIBUTE.md)(待补完) + +MaiCore 是一个开源项目,我们非常欢迎你的参与。你的贡献,无论是提交 bug 报告、功能需求还是代码 pr,都对项目非常宝贵。我们非常感谢你的支持!🎉 但无序的讨论会降低沟通效率,进而影响问题的解决速度,因此在提交任何贡献前,请务必先阅读本项目的[贡献指南](docs/CONTRIBUTE.md)(待补完) ### 贡献 + 感谢各位大佬! @@ -88,6 +102,7 @@ MaiCore是一个开源项目,我们非常欢迎你的参与。你的贡献, ### 致谢 + - [略nd](https://space.bilibili.com/1344099355): 麦麦人设 - [NapCat](https://github.com/NapNeko/NapCatQQ): 现代化的基于 NTQQ 的 Bot 协议端实现 From 20ef1cbf029fbcfdc3ea7e68e84d2dcb3df8759d Mon Sep 17 00:00:00 2001 From: Dreamwxz Date: Fri, 16 May 2025 21:28:31 +0800 Subject: [PATCH 36/57] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index a6dec73ed..1882e6277 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # 麦麦!MaiCore-MaiBot (编辑中) - MaiBot + MaiBot ![Python Version](https://img.shields.io/badge/Python-3.10+-blue) From 55784153d3b7014168bba3992341e9dc49b84f8e Mon Sep 17 00:00:00 2001 From: Dreamwxz Date: Fri, 16 May 2025 21:54:01 +0800 Subject: [PATCH 37/57] Update README.md MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 尝试修正图片大小,更正链接, --- README.md | 53 ++++++++++++++++++++++++++--------------------------- 1 file changed, 26 insertions(+), 27 deletions(-) diff --git a/README.md b/README.md index 1882e6277..f79ff4a80 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # 麦麦!MaiCore-MaiBot (编辑中) - MaiBot + MaiBot ![Python Version](https://img.shields.io/badge/Python-3.10+-blue) @@ -14,22 +14,22 @@ [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/DrSmoothl/MaiBot) -🌟 案例展示 | -🚀 快速入门 | -📃 教程 | -💬 讨论 | -🙋 常见问题 +🌟 演示视频 | +🚀 快速入门 | +📃 教程 | +💬 讨论 | +🙋 贡献指南 ## 🎉 介绍 **🍔MaiCore 是一个基于大语言模型的可交互智能体** -- 💭 **智能对话系统**:基于 LLM 的自然语言交互 -- 🤔 **实时思维系统**:模拟人类思考过程 -- 💝 **情感表达系统**:丰富的表情包和情绪表达 -- 🧠 **持久记忆系统**:基于 MongoDB 的长期记忆存储 -- 🔄 **动态人格系统**:自适应的性格特征 +- 💭 **智能对话系统**:基于 LLM 的自然语言交互。 +- 🤔 **实时思维系统**:模拟人类思考过程。 +- 💝 **情感表达系统**:丰富的表情包和情绪表达。 +- 🧠 **持久记忆系统**:基于 MongoDB 的长期记忆存储。 +- 🔄 **动态人格系统**:自适应的性格特征。
@@ -51,14 +51,13 @@ ### 最新版本部署教程 (MaiCore 版本) - [🚀 最新版本部署教程](https://docs.mai-mai.org/manual/deployment/mmc_deploy_windows.html) - 基于 MaiCore 的新版本部署方式(与旧版本不兼容) -### ⚠️ 重要提示 - -- 从 0.5.x 旧版本升级前请务必阅读:[升级指南](https://docs.mai-mai.org/faq/maibot/backup_update.html) -- 项目处于活跃开发阶段,功能和 API 可能随时调整。 -- 文档未完善,有问题可以提交 Issue 或者 Discussion。 -- QQ 机器人存在被限制风险,请自行了解,谨慎使用。 -- 由于持续迭代,可能存在一些已知或未知的 bug。 -- 由于程序处于开发中,可能消耗较多 token。 +> [!WARNING]重要提示 +> - 从 0.5.x 旧版本升级前请务必阅读:[升级指南](https://docs.mai-mai.org/faq/maibot/backup_update.html) +> - 项目处于活跃开发阶段,功能和 API 可能随时调整。 +> - 文档未完善,有问题可以提交 Issue 或者 Discussion。 +> - QQ 机器人存在被限制风险,请自行了解,谨慎使用。 +> - 由于持续迭代,可能存在一些已知或未知的 bug。 +> - 由于程序处于开发中,可能消耗较多 token。 ## 💬 讨论 @@ -72,24 +71,24 @@ **部分内容可能更新不够及时,请注意版本对应** -- [📚 核心 Wiki 文档](https://docs.mai-mai.org) - 项目最全面的文档中心,你可以了解麦麦有关的一切 +- [📚 核心 Wiki 文档](https://docs.mai-mai.org) - 项目最全面的文档中心,你可以了解麦麦有关的一切。 ### 设计理念(原始时代的火花) > **千石可乐说:** > - 这个项目最初只是为了给牛牛 bot 添加一点额外的功能,但是功能越写越多,最后决定重写。其目的是为了创造一个活跃在 QQ 群聊的"生命体"。目的并不是为了写一个功能齐全的机器人,而是一个尽可能让人感知到真实的类人存在。 -> - 程序的功能设计理念基于一个核心的原则:"最像而不是好" +> - 程序的功能设计理念基于一个核心的原则:"最像而不是好"。 > - 如果人类真的需要一个 AI 来陪伴自己,并不是所有人都需要一个完美的,能解决所有问题的"helpful assistant",而是一个会犯错的,拥有自己感知和想法的"生命形式"。 > - 代码会保持开源和开放,但个人希望 MaiMbot 的运行时数据保持封闭,尽量避免以显式命令来对其进行控制和调试。我认为一个你无法完全掌控的个体才更能让你感觉到它的自主性,而视其成为一个对话机器。 -> - SengokuCola~~纯编程外行,面向 cursor 编程,很多代码写得不好多多包涵~~已得到大脑升级 +> - SengokuCola~~纯编程外行,面向 cursor 编程,很多代码写得不好多多包涵~~已得到大脑升级。 ## 📌 注意事项 > [!WARNING] -> 使用本项目前必须阅读和同意[用户协议](https://docs.mai-mai.org/manual/other/EULA.html)和[隐私协议](PRIVACY.md)。 +> 使用本项目前必须阅读和同意[用户协议](EULA.md)和[隐私协议](PRIVACY.md)。 > 本应用生成内容来自人工智能模型,由 AI 生成,请仔细甄别,请勿用于违反法律的用途,AI 生成内容不代表本项目团队的观点和立场。 -## 贡献和致谢 +## 🙋 贡献和致谢 MaiCore 是一个开源项目,我们非常欢迎你的参与。你的贡献,无论是提交 bug 报告、功能需求还是代码 pr,都对项目非常宝贵。我们非常感谢你的支持!🎉 但无序的讨论会降低沟通效率,进而影响问题的解决速度,因此在提交任何贡献前,请务必先阅读本项目的[贡献指南](docs/CONTRIBUTE.md)(待补完) @@ -103,14 +102,14 @@ MaiCore 是一个开源项目,我们非常欢迎你的参与。你的贡献, ### 致谢 -- [略nd](https://space.bilibili.com/1344099355): 麦麦人设 -- [NapCat](https://github.com/NapNeko/NapCatQQ): 现代化的基于 NTQQ 的 Bot 协议端实现 +- [略nd](https://space.bilibili.com/1344099355): 感谢为麦麦绘制人设。 +- [NapCat](https://github.com/NapNeko/NapCatQQ): 现代化的基于 NTQQ 的 Bot 协议端实现。 **也感谢每一位给麦麦发展提出宝贵意见与建议的用户,感谢陪伴麦麦走到现在的你们** ## 麦麦仓库状态 -![Alt](https://repobeats.axiom.co/api/embed/9faca9fccfc467931b87dd357b60c6362b5cfae0.svg "Repobeats analytics image") +![Alt](https://repobeats.axiom.co/api/embed/9faca9fccfc467931b87dd357b60c6362b5cfae0.svg "麦麦仓库状态") ### Star 趋势 From fd3411a6eb1675d60e7f0dea87bab8b64ef1493e Mon Sep 17 00:00:00 2001 From: Dreamwxz Date: Fri, 16 May 2025 21:58:46 +0800 Subject: [PATCH 38/57] Update README.md --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index f79ff4a80..7c722b67c 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # 麦麦!MaiCore-MaiBot (编辑中) - MaiBot + MaiBot ![Python Version](https://img.shields.io/badge/Python-3.10+-blue) @@ -33,7 +33,7 @@
- 麦麦演示视频 + 麦麦演示视频
👆 点击观看麦麦演示视频 👆
From fee34b23e565dd097c06cfaa61c72417a2aa66ff Mon Sep 17 00:00:00 2001 From: Dreamwxz Date: Fri, 16 May 2025 22:00:44 +0800 Subject: [PATCH 39/57] Update README.md --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 7c722b67c..a61de9f9f 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # 麦麦!MaiCore-MaiBot (编辑中) - MaiBot + MaiBot ![Python Version](https://img.shields.io/badge/Python-3.10+-blue) @@ -33,7 +33,7 @@
- 麦麦演示视频 + 麦麦演示视频
👆 点击观看麦麦演示视频 👆
From e787958ab955158213fead1dae8ba949b6d4396e Mon Sep 17 00:00:00 2001 From: Dreamwxz Date: Fri, 16 May 2025 22:09:36 +0800 Subject: [PATCH 40/57] Update README.md --- README.md | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index a61de9f9f..5fa145a85 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,8 @@ # 麦麦!MaiCore-MaiBot (编辑中) - MaiBot + + MaiBot ![Python Version](https://img.shields.io/badge/Python-3.10+-blue) @@ -33,9 +34,12 @@ From b9c1d19ed5fc3afd92549fc861a74869f68e37ed Mon Sep 17 00:00:00 2001 From: Oct-autumn Date: Fri, 16 May 2025 22:29:24 +0800 Subject: [PATCH 41/57] =?UTF-8?q?fix:=20=E4=BF=AE=E6=AD=A3template?= =?UTF-8?q?=E4=B8=AD=E7=9A=84=E5=AD=97=E6=AE=B5=E5=90=8D=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- template/bot_config_template.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index 64e51da77..36cfd4372 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -18,11 +18,11 @@ nickname = "麦麦" alias_names = ["麦叠", "牢麦"] #该选项还在调试中,暂时未生效 [chat_target] -talk_allowed = [ +talk_allowed_groups = [ 123, 123, ] #可以回复消息的群号码 -talk_frequency_down = [] #降低回复频率的群号码 +talk_frequency_down_groups = [] #降低回复频率的群号码 ban_user_id = [] #禁止回复和读取消息的QQ号 [personality] #未完善 @@ -121,7 +121,7 @@ memory_build_sample_length = 40 # 采样长度,数值越高一段记忆内容 memory_compress_rate = 0.1 # 记忆压缩率 控制记忆精简程度 建议保持默认,调高可以获得更多信息,但是冗余信息也会增多 forget_memory_interval = 1000 # 记忆遗忘间隔 单位秒 间隔越低,麦麦遗忘越频繁,记忆更精简,但更难学习 -memory_forget_time = 24 #多长时间后的记忆会被遗忘 单位小时 +memory_forget_time = 24 #多长时间后的记忆会被遗忘 单位小时 memory_forget_percentage = 0.01 # 记忆遗忘比例 控制记忆遗忘程度 越大遗忘越多 建议保持默认 consolidate_memory_interval = 1000 # 记忆整合间隔 单位秒 间隔越低,麦麦整合越频繁,记忆更精简 From 1dfc3e533269894d3fd63e956a274e2d8aa89580 Mon Sep 17 00:00:00 2001 From: Dreamwxz Date: Fri, 16 May 2025 22:31:44 +0800 Subject: [PATCH 42/57] Update README.md MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 更正部分错误,添加License --- README.md | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 5fa145a85..e5cc57c33 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,10 @@ -# 麦麦!MaiCore-MaiBot (编辑中) - MaiBot +# 麦麦!MaiCore-MaiBot (编辑中) + ![Python Version](https://img.shields.io/badge/Python-3.10+-blue) ![License](https://img.shields.io/github/license/SengokuCola/MaiMBot?label=协议) ![Status](https://img.shields.io/badge/状态-开发中-yellow) @@ -55,7 +55,7 @@ ### 最新版本部署教程 (MaiCore 版本) - [🚀 最新版本部署教程](https://docs.mai-mai.org/manual/deployment/mmc_deploy_windows.html) - 基于 MaiCore 的新版本部署方式(与旧版本不兼容) -> [!WARNING]重要提示 +> [!WARNING] > - 从 0.5.x 旧版本升级前请务必阅读:[升级指南](https://docs.mai-mai.org/faq/maibot/backup_update.html) > - 项目处于活跃开发阶段,功能和 API 可能随时调整。 > - 文档未完善,有问题可以提交 Issue 或者 Discussion。 @@ -76,7 +76,7 @@ **部分内容可能更新不够及时,请注意版本对应** - [📚 核心 Wiki 文档](https://docs.mai-mai.org) - 项目最全面的文档中心,你可以了解麦麦有关的一切。 - +- [📚 开发文档](https://docs.mai-mai.org/develop/) - 为开发者提供了有关MaiBot架构、API和扩展开发的全面指南。 ### 设计理念(原始时代的火花) > **千石可乐说:** @@ -86,17 +86,13 @@ > - 代码会保持开源和开放,但个人希望 MaiMbot 的运行时数据保持封闭,尽量避免以显式命令来对其进行控制和调试。我认为一个你无法完全掌控的个体才更能让你感觉到它的自主性,而视其成为一个对话机器。 > - SengokuCola~~纯编程外行,面向 cursor 编程,很多代码写得不好多多包涵~~已得到大脑升级。 -## 📌 注意事项 - -> [!WARNING] -> 使用本项目前必须阅读和同意[用户协议](EULA.md)和[隐私协议](PRIVACY.md)。 -> 本应用生成内容来自人工智能模型,由 AI 生成,请仔细甄别,请勿用于违反法律的用途,AI 生成内容不代表本项目团队的观点和立场。 - ## 🙋 贡献和致谢 -MaiCore 是一个开源项目,我们非常欢迎你的参与。你的贡献,无论是提交 bug 报告、功能需求还是代码 pr,都对项目非常宝贵。我们非常感谢你的支持!🎉 但无序的讨论会降低沟通效率,进而影响问题的解决速度,因此在提交任何贡献前,请务必先阅读本项目的[贡献指南](docs/CONTRIBUTE.md)(待补完) +MaiCore 是一个开源项目,我们非常欢迎你的参与。你的贡献,无论是提交 bug 报告、功能需求还是代码 pr,都对项目非常宝贵。我们非常感谢你的支持! 🎉 +你可以阅读[开发文档](https://docs.mai-mai.org/develop/)来更好的了解麦麦! +但无序的讨论会降低沟通效率,进而影响问题的解决速度,因此在提交任何贡献前,请务必先阅读本项目的[贡献指南](docs/CONTRIBUTE.md)(待补完) -### 贡献 +### 贡献者 感谢各位大佬! @@ -106,10 +102,16 @@ MaiCore 是一个开源项目,我们非常欢迎你的参与。你的贡献, ### 致谢 -- [略nd](https://space.bilibili.com/1344099355): 感谢为麦麦绘制人设。 +- [略nd](https://space.bilibili.com/1344099355): 为麦麦绘制人设。 - [NapCat](https://github.com/NapNeko/NapCatQQ): 现代化的基于 NTQQ 的 Bot 协议端实现。 -**也感谢每一位给麦麦发展提出宝贵意见与建议的用户,感谢陪伴麦麦走到现在的你们** +**也感谢每一位给麦麦发展提出宝贵意见与建议的用户,感谢陪伴麦麦走到现在的你们!** + +## 📌 注意事项 + +> [!WARNING] +> 使用本项目前必须阅读和同意[用户协议](EULA.md)和[隐私协议](PRIVACY.md)。 +> 本应用生成内容来自人工智能模型,由 AI 生成,请仔细甄别,请勿用于违反法律的用途,AI 生成内容不代表本项目团队的观点和立场。 ## 麦麦仓库状态 @@ -118,3 +120,7 @@ MaiCore 是一个开源项目,我们非常欢迎你的参与。你的贡献, ### Star 趋势 [![Star 趋势](https://starchart.cc/MaiM-with-u/MaiBot.svg?variant=adaptive)](https://starchart.cc/MaiM-with-u/MaiBot) + +## License + +GPL-3.0 From b5d5864a9e1bed31260d5e13187843844b93a8b3 Mon Sep 17 00:00:00 2001 From: Dreamwxz Date: Fri, 16 May 2025 22:55:42 +0800 Subject: [PATCH 43/57] Update README.md MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 调整语句顺序 --- README.md | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index e5cc57c33..b1c271245 100644 --- a/README.md +++ b/README.md @@ -46,14 +46,14 @@ ## 🔥 更新和安装 **最新版本: v0.6.3** ([更新日志](changelogs/changelog.md)) - +可前往 [Release](https://github.com/MaiM-with-u/MaiBot/releases/) 页面下载最新版本 **GitHub 分支说明:** -- `main`: 稳定发布版本 -- `dev`: 开发测试版本(不稳定) -- `classical`: 0.6.0 之前的版本 (停止维护) +- `main`: 稳定发布版本(推荐) +- `dev`: 开发测试版本(不稳定) +- `classical`: 旧版本(停止维护) ### 最新版本部署教程 (MaiCore 版本) -- [🚀 最新版本部署教程](https://docs.mai-mai.org/manual/deployment/mmc_deploy_windows.html) - 基于 MaiCore 的新版本部署方式(与旧版本不兼容) +- [🚀 最新版本部署教程](https://docs.mai-mai.org/manual/deployment/mmc_deploy_windows.html) - 基于 MaiCore 的新版本部署方式(与旧版本不兼容) > [!WARNING] > - 从 0.5.x 旧版本升级前请务必阅读:[升级指南](https://docs.mai-mai.org/faq/maibot/backup_update.html) @@ -76,8 +76,8 @@ **部分内容可能更新不够及时,请注意版本对应** - [📚 核心 Wiki 文档](https://docs.mai-mai.org) - 项目最全面的文档中心,你可以了解麦麦有关的一切。 -- [📚 开发文档](https://docs.mai-mai.org/develop/) - 为开发者提供了有关MaiBot架构、API和扩展开发的全面指南。 -### 设计理念(原始时代的火花) + +### 设计理念(原始时代的火花) > **千石可乐说:** > - 这个项目最初只是为了给牛牛 bot 添加一点额外的功能,但是功能越写越多,最后决定重写。其目的是为了创造一个活跃在 QQ 群聊的"生命体"。目的并不是为了写一个功能齐全的机器人,而是一个尽可能让人感知到真实的类人存在。 @@ -87,10 +87,9 @@ > - SengokuCola~~纯编程外行,面向 cursor 编程,很多代码写得不好多多包涵~~已得到大脑升级。 ## 🙋 贡献和致谢 - -MaiCore 是一个开源项目,我们非常欢迎你的参与。你的贡献,无论是提交 bug 报告、功能需求还是代码 pr,都对项目非常宝贵。我们非常感谢你的支持! 🎉 你可以阅读[开发文档](https://docs.mai-mai.org/develop/)来更好的了解麦麦! -但无序的讨论会降低沟通效率,进而影响问题的解决速度,因此在提交任何贡献前,请务必先阅读本项目的[贡献指南](docs/CONTRIBUTE.md)(待补完) +MaiCore 是一个开源项目,我们非常欢迎你的参与。你的贡献,无论是提交 bug 报告、功能需求还是代码 pr,都对项目非常宝贵。我们非常感谢你的支持!🎉 +但无序的讨论会降低沟通效率,进而影响问题的解决速度,因此在提交任何贡献前,请务必先阅读本项目的[贡献指南](docs/CONTRIBUTE.md)。(待补完) ### 贡献者 From 61e0dbe372eb897903683405ff5e95f19a1cdbac Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Fri, 16 May 2025 23:16:47 +0800 Subject: [PATCH 44/57] =?UTF-8?q?fix=EF=BC=9A=E4=BF=AE=E5=A4=8D=E5=90=88?= =?UTF-8?q?=E5=B9=B6=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/api/config_api.py | 4 ++-- src/chat/focus_chat/expressors/default_expressor.py | 4 ++-- src/chat/focus_chat/heartflow_prompt_builder.py | 1 - .../focus_chat/info_processors/chattinginfo_processor.py | 4 ++-- src/chat/focus_chat/info_processors/self_processor.py | 4 ++-- .../focus_chat/info_processors/working_memory_processor.py | 6 +++--- src/chat/focus_chat/planners/planner.py | 4 ++-- src/chat/focus_chat/working_memory/memory_manager.py | 2 +- src/chat/heart_flow/observation/chatting_observation.py | 2 +- src/chat/memory_system/Hippocampus.py | 6 +++--- src/chat/message_receive/bot.py | 7 +++++-- src/chat/person_info/relationship_manager.py | 2 +- src/experimental/PFC/pfc.py | 2 +- 13 files changed, 25 insertions(+), 23 deletions(-) diff --git a/src/api/config_api.py b/src/api/config_api.py index 0b23fb993..b81d28da9 100644 --- a/src/api/config_api.py +++ b/src/api/config_api.py @@ -128,7 +128,7 @@ class APIBotConfig: llm_reasoning: Dict[str, Any] # 推理模型配置 llm_normal: Dict[str, Any] # 普通模型配置 llm_topic_judge: Dict[str, Any] # 主题判断模型配置 - llm_summary: Dict[str, Any] # 总结模型配置 + model.summary: Dict[str, Any] # 总结模型配置 vlm: Dict[str, Any] # VLM模型配置 llm_heartflow: Dict[str, Any] # 心流模型配置 llm_observation: Dict[str, Any] # 观察模型配置 @@ -203,7 +203,7 @@ class APIBotConfig: "llm_reasoning", "llm_normal", "llm_topic_judge", - "llm_summary", + "model.summary", "vlm", "llm_heartflow", "llm_observation", diff --git a/src/chat/focus_chat/expressors/default_expressor.py b/src/chat/focus_chat/expressors/default_expressor.py index ccbc1ca56..81f577b61 100644 --- a/src/chat/focus_chat/expressors/default_expressor.py +++ b/src/chat/focus_chat/expressors/default_expressor.py @@ -351,7 +351,7 @@ class DefaultExpressor: grammar_habbits=grammar_habbits_str, chat_target=chat_target_1, chat_info=chat_talking_prompt, - bot_name=global_config.BOT_NICKNAME, + bot_name=global_config.bot.nickname, prompt_personality="", reason=reason, in_mind_reply=in_mind_reply, @@ -363,7 +363,7 @@ class DefaultExpressor: template_name, sender_name=effective_sender_name, # Used in private template chat_talking_prompt=chat_talking_prompt, - bot_name=global_config.BOT_NICKNAME, + bot_name=global_config.bot.nickname, prompt_personality=prompt_personality, reason=reason, moderation_prompt=await global_prompt_manager.get_prompt_async("moderation_prompt"), diff --git a/src/chat/focus_chat/heartflow_prompt_builder.py b/src/chat/focus_chat/heartflow_prompt_builder.py index af526eb88..d8d2b836f 100644 --- a/src/chat/focus_chat/heartflow_prompt_builder.py +++ b/src/chat/focus_chat/heartflow_prompt_builder.py @@ -7,7 +7,6 @@ from src.chat.person_info.relationship_manager import relationship_manager from src.chat.utils.utils import get_embedding import time from typing import Union, Optional -from src.common.database import db from src.chat.utils.utils import get_recent_group_speaker from src.manager.mood_manager import mood_manager from src.chat.memory_system.Hippocampus import HippocampusManager diff --git a/src/chat/focus_chat/info_processors/chattinginfo_processor.py b/src/chat/focus_chat/info_processors/chattinginfo_processor.py index 8d1eb9793..c9641b9b7 100644 --- a/src/chat/focus_chat/info_processors/chattinginfo_processor.py +++ b/src/chat/focus_chat/info_processors/chattinginfo_processor.py @@ -27,7 +27,7 @@ class ChattingInfoProcessor(BaseProcessor): """初始化观察处理器""" super().__init__() # TODO: API-Adapter修改标记 - self.llm_summary = LLMRequest( + self.model_summary = LLMRequest( model=global_config.model.observation, temperature=0.7, max_tokens=300, request_type="chat_observation" ) @@ -94,7 +94,7 @@ class ChattingInfoProcessor(BaseProcessor): async def chat_compress(self, obs: ChattingObservation): if obs.compressor_prompt: try: - summary_result, _, _ = await self.llm_summary.generate_response(obs.compressor_prompt) + summary_result, _, _ = await self.model_summary.generate_response(obs.compressor_prompt) summary = "没有主题的闲聊" # 默认值 if summary_result: # 确保结果不为空 summary = summary_result diff --git a/src/chat/focus_chat/info_processors/self_processor.py b/src/chat/focus_chat/info_processors/self_processor.py index 19876c93c..5114e49b6 100644 --- a/src/chat/focus_chat/info_processors/self_processor.py +++ b/src/chat/focus_chat/info_processors/self_processor.py @@ -49,8 +49,8 @@ class SelfProcessor(BaseProcessor): self.subheartflow_id = subheartflow_id self.llm_model = LLMRequest( - model=global_config.llm_sub_heartflow, - temperature=global_config.llm_sub_heartflow["temp"], + model=global_config.model.sub_heartflow, + temperature=global_config.model.sub_heartflow["temp"], max_tokens=800, request_type="self_identify", ) diff --git a/src/chat/focus_chat/info_processors/working_memory_processor.py b/src/chat/focus_chat/info_processors/working_memory_processor.py index c682da699..c79c8363d 100644 --- a/src/chat/focus_chat/info_processors/working_memory_processor.py +++ b/src/chat/focus_chat/info_processors/working_memory_processor.py @@ -61,8 +61,8 @@ class WorkingMemoryProcessor(BaseProcessor): self.subheartflow_id = subheartflow_id self.llm_model = LLMRequest( - model=global_config.llm_sub_heartflow, - temperature=global_config.llm_sub_heartflow["temp"], + model=global_config.model.sub_heartflow, + temperature=global_config.model.sub_heartflow["temp"], max_tokens=800, request_type="working_memory", ) @@ -118,7 +118,7 @@ class WorkingMemoryProcessor(BaseProcessor): # 使用提示模板进行处理 prompt = (await global_prompt_manager.get_prompt_async("prompt_memory_proces")).format( - bot_name=global_config.BOT_NICKNAME, + bot_name=global_config.bot.nickname, time_now=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), chat_observe_info=chat_info, memory_str=memory_choose_str, diff --git a/src/chat/focus_chat/planners/planner.py b/src/chat/focus_chat/planners/planner.py index 21ca157f9..116419ee1 100644 --- a/src/chat/focus_chat/planners/planner.py +++ b/src/chat/focus_chat/planners/planner.py @@ -69,7 +69,7 @@ class ActionPlanner: self.log_prefix = log_prefix # LLM规划器配置 self.planner_llm = LLMRequest( - model=global_config.llm_plan, + model=global_config.model.plan, max_tokens=1000, request_type="action_planning", # 用于动作规划 ) @@ -273,7 +273,7 @@ class ActionPlanner: planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt") prompt = planner_prompt_template.format( - bot_name=global_config.BOT_NICKNAME, + bot_name=global_config.bot.nickname, prompt_personality=personality_block, chat_context_description=chat_context_description, chat_content_block=chat_content_block, diff --git a/src/chat/focus_chat/working_memory/memory_manager.py b/src/chat/focus_chat/working_memory/memory_manager.py index 7154fe48c..7fda40239 100644 --- a/src/chat/focus_chat/working_memory/memory_manager.py +++ b/src/chat/focus_chat/working_memory/memory_manager.py @@ -33,7 +33,7 @@ class MemoryManager: self._id_map: Dict[str, MemoryItem] = {} self.llm_summarizer = LLMRequest( - model=global_config.llm_summary, temperature=0.3, max_tokens=512, request_type="memory_summarization" + model=global_config.model.summary, temperature=0.3, max_tokens=512, request_type="memory_summarization" ) @property diff --git a/src/chat/heart_flow/observation/chatting_observation.py b/src/chat/heart_flow/observation/chatting_observation.py index 9ea18b471..7e4872014 100644 --- a/src/chat/heart_flow/observation/chatting_observation.py +++ b/src/chat/heart_flow/observation/chatting_observation.py @@ -67,7 +67,7 @@ class ChattingObservation(Observation): self.oldest_messages_str = "" self.compressor_prompt = "" # TODO: API-Adapter修改标记 - self.llm_summary = LLMRequest( + self.model_summary = LLMRequest( model=global_config.model.observation, temperature=0.7, max_tokens=300, request_type="chat_observation" ) diff --git a/src/chat/memory_system/Hippocampus.py b/src/chat/memory_system/Hippocampus.py index 2de769205..aae1721c2 100644 --- a/src/chat/memory_system/Hippocampus.py +++ b/src/chat/memory_system/Hippocampus.py @@ -193,7 +193,7 @@ class Hippocampus: def __init__(self): self.memory_graph = MemoryGraph() self.llm_topic_judge = None - self.llm_summary = None + self.model_summary = None self.entorhinal_cortex = None self.parahippocampal_gyrus = None @@ -205,7 +205,7 @@ class Hippocampus: self.entorhinal_cortex.sync_memory_from_db() # TODO: API-Adapter修改标记 self.llm_topic_judge = LLMRequest(global_config.model.topic_judge, request_type="memory") - self.llm_summary = LLMRequest(global_config.model.summary, request_type="memory") + self.model_summary = LLMRequest(global_config.model.summary, request_type="memory") def get_all_node_names(self) -> list: """获取记忆图中所有节点的名字列表""" @@ -1167,7 +1167,7 @@ class ParahippocampalGyrus: # 调用修改后的 topic_what,不再需要 time_info topic_what_prompt = self.hippocampus.topic_what(input_text, topic) try: - task = self.hippocampus.llm_summary.generate_response_async(topic_what_prompt) + task = self.hippocampus.model_summary.generate_response_async(topic_what_prompt) tasks.append((topic.strip(), task)) except Exception as e: logger.error(f"生成话题 '{topic}' 的摘要时发生错误: {e}") diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index 0e35f6f6e..cea791de4 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -72,6 +72,7 @@ class ChatBot: message_data["message_info"]["user_info"]["user_id"] = str( message_data["message_info"]["user_info"]["user_id"] ) + # print(message_data) logger.trace(f"处理消息:{str(message_data)[:120]}...") message = MessageRecv(message_data) groupinfo = message.message_info.group_info @@ -86,12 +87,14 @@ class ChatBot: logger.trace("检测到私聊消息,检查") # 好友黑名单拦截 if userinfo.user_id not in global_config.experimental.talk_allowed_private: - logger.debug(f"用户{userinfo.user_id}没有私聊权限") + # logger.debug(f"用户{userinfo.user_id}没有私聊权限") return # 群聊黑名单拦截 + # print(groupinfo.group_id) + # print(global_config.chat_target.talk_allowed_groups) if groupinfo is not None and groupinfo.group_id not in global_config.chat_target.talk_allowed_groups: - logger.trace(f"群{groupinfo.group_id}被禁止回复") + logger.debug(f"群{groupinfo.group_id}被禁止回复") return # 确认从接口发来的message是否有自定义的prompt模板信息 diff --git a/src/chat/person_info/relationship_manager.py b/src/chat/person_info/relationship_manager.py index c8a443857..a23780c0e 100644 --- a/src/chat/person_info/relationship_manager.py +++ b/src/chat/person_info/relationship_manager.py @@ -77,7 +77,7 @@ class RelationshipManager: @staticmethod async def is_known_some_one(platform, user_id): """判断是否认识某人""" - is_known = person_info_manager.is_person_known(platform, user_id) + is_known = await person_info_manager.is_person_known(platform, user_id) return is_known @staticmethod diff --git a/src/experimental/PFC/pfc.py b/src/experimental/PFC/pfc.py index 686d4af49..80e75c5bf 100644 --- a/src/experimental/PFC/pfc.py +++ b/src/experimental/PFC/pfc.py @@ -316,7 +316,7 @@ class GoalAnalyzer: # message_segment = Seg(type="text", data=content) # bot_user_info = UserInfo( # user_id=global_config.BOT_QQ, -# user_nickname=global_config.BOT_NICKNAME, +# user_nickname=global_config.bot.nickname, # platform=chat_stream.platform, # ) From 3520cebc269337854c977a2b80e50be403e7bb1f Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Fri, 16 May 2025 23:43:48 +0800 Subject: [PATCH 45/57] =?UTF-8?q?fix:=E5=B0=8F=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/api/config_api.py | 2 +- src/chat/focus_chat/expressors/default_expressor.py | 2 +- template/bot_config_template.toml | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/api/config_api.py b/src/api/config_api.py index b81d28da9..8b99fb93e 100644 --- a/src/api/config_api.py +++ b/src/api/config_api.py @@ -41,7 +41,7 @@ class APIBotConfig: allow_focus_mode: bool # 是否允许专注聊天状态 base_normal_chat_num: int # 最多允许多少个群进行普通聊天 base_focused_chat_num: int # 最多允许多少个群进行专注聊天 - observation_context_size: int # 观察到的最长上下文大小 + chat.observation_context_size: int # 观察到的最长上下文大小 message_buffer: bool # 是否启用消息缓冲 ban_words: List[str] # 禁止词列表 ban_msgs_regex: List[str] # 禁止消息的正则表达式列表 diff --git a/src/chat/focus_chat/expressors/default_expressor.py b/src/chat/focus_chat/expressors/default_expressor.py index 81f577b61..d3d21e074 100644 --- a/src/chat/focus_chat/expressors/default_expressor.py +++ b/src/chat/focus_chat/expressors/default_expressor.py @@ -294,7 +294,7 @@ class DefaultExpressor: message_list_before_now = get_raw_msg_before_timestamp_with_chat( chat_id=chat_stream.stream_id, timestamp=time.time(), - limit=global_config.observation_context_size, + limit=global_config.chat.observation_context_size, ) chat_talking_prompt = await build_readable_messages( message_list_before_now, diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index 64e51da77..a778ed09c 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -63,7 +63,7 @@ allow_focus_mode = false # 是否允许专注聊天状态 base_normal_chat_num = 999 # 最多允许多少个群进行普通聊天 base_focused_chat_num = 4 # 最多允许多少个群进行专注聊天 -observation_context_size = 15 # 观察到的最长上下文大小,建议15,太短太长都会导致脑袋尖尖 +chat.observation_context_size = 15 # 观察到的最长上下文大小,建议15,太短太长都会导致脑袋尖尖 message_buffer = true # 启用消息缓冲器?启用此项以解决消息的拆分问题,但会使麦麦的回复延迟 # 以下是消息过滤,可以根据规则过滤特定消息,将不会读取这些消息 @@ -99,7 +99,7 @@ default_decay_rate_per_second = 0.98 # 默认衰减率,越大衰减越快, consecutive_no_reply_threshold = 3 # 连续不回复的阈值,越低越容易结束专注聊天 # 以下选项暂时无效 -compressed_length = 5 # 不能大于observation_context_size,心流上下文压缩的最短压缩长度,超过心流观察到的上下文长度,会压缩,最短压缩长度为5 +compressed_length = 5 # 不能大于chat.observation_context_size,心流上下文压缩的最短压缩长度,超过心流观察到的上下文长度,会压缩,最短压缩长度为5 compress_length_limit = 5 #最多压缩份数,超过该数值的压缩上下文会被删除 From d26d69de60b92b020894dfe8a57aad02f4206bf0 Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Sat, 17 May 2025 01:03:20 +0800 Subject: [PATCH 46/57] =?UTF-8?q?fix=EF=BC=9B=E4=BF=AE=E5=A4=8D=E6=8F=90?= =?UTF-8?q?=E5=8F=96=E6=B6=88=E6=81=AF=E5=92=8C=E8=BF=90=E8=A1=8Cbug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../info_processors/chattinginfo_processor.py | 3 + .../observation/chatting_observation.py | 29 ++- src/chat/utils/chat_message_builder.py | 10 + src/chat/utils/utils.py | 2 +- tests/common/test_message_repository.py | 174 ++++++++++++++++++ tests/test_build_readable_messages.py | 171 +++++++++++++++++ tests/test_extract_messages.py | 88 +++++++++ 7 files changed, 472 insertions(+), 5 deletions(-) create mode 100644 tests/common/test_message_repository.py create mode 100644 tests/test_build_readable_messages.py create mode 100644 tests/test_extract_messages.py diff --git a/src/chat/focus_chat/info_processors/chattinginfo_processor.py b/src/chat/focus_chat/info_processors/chattinginfo_processor.py index c9641b9b7..5a72bcd9e 100644 --- a/src/chat/focus_chat/info_processors/chattinginfo_processor.py +++ b/src/chat/focus_chat/info_processors/chattinginfo_processor.py @@ -63,13 +63,16 @@ class ChattingInfoProcessor(BaseProcessor): # 设置说话消息 if hasattr(obs, "talking_message_str"): + print(f"设置说话消息:obs.talking_message_str: {obs.talking_message_str}") obs_info.set_talking_message(obs.talking_message_str) # 设置截断后的说话消息 if hasattr(obs, "talking_message_str_truncate"): + print(f"设置截断后的说话消息:obs.talking_message_str_truncate: {obs.talking_message_str_truncate}") obs_info.set_talking_message_str_truncate(obs.talking_message_str_truncate) if hasattr(obs, "mid_memory_info"): + print(f"设置之前聊天信息:obs.mid_memory_info: {obs.mid_memory_info}") obs_info.set_previous_chat_info(obs.mid_memory_info) # 设置聊天类型 diff --git a/src/chat/heart_flow/observation/chatting_observation.py b/src/chat/heart_flow/observation/chatting_observation.py index 7e4872014..415e4b100 100644 --- a/src/chat/heart_flow/observation/chatting_observation.py +++ b/src/chat/heart_flow/observation/chatting_observation.py @@ -140,8 +140,23 @@ class ChattingObservation(Observation): return None # logger.debug(f"找到的锚定消息:find_msg: {find_msg}") - group_info = find_msg.get("chat_info", {}).get("group_info") - user_info = find_msg.get("chat_info", {}).get("user_info") + + # 创建所需的user_info字段 + user_info = { + "platform": find_msg.get("user_platform", ""), + "user_id": find_msg.get("user_id", ""), + "user_nickname": find_msg.get("user_nickname", ""), + "user_cardname": find_msg.get("user_cardname", "") + } + + # 创建所需的group_info字段,如果是群聊的话 + group_info = {} + if find_msg.get("chat_info_group_id"): + group_info = { + "platform": find_msg.get("chat_info_group_platform", ""), + "group_id": find_msg.get("chat_info_group_id", ""), + "group_name": find_msg.get("chat_info_group_name", "") + } content_format = "" accept_format = "" @@ -181,6 +196,8 @@ class ChattingObservation(Observation): limit=self.max_now_obs_len, limit_mode="latest", ) + + # print(f"new_messages_list: {new_messages_list}") last_obs_time_mark = self.last_observe_time if new_messages_list: @@ -193,6 +210,7 @@ class ChattingObservation(Observation): oldest_messages = self.talking_message[:messages_to_remove_count] self.talking_message = self.talking_message[messages_to_remove_count:] # 保留后半部分,即最新的 + # print(f"压缩中:oldest_messages: {oldest_messages}") oldest_messages_str = await build_readable_messages( messages=oldest_messages, timestamp_mode="normal", read_mark=0 ) @@ -235,21 +253,24 @@ class ChattingObservation(Observation): self.oldest_messages = oldest_messages self.oldest_messages_str = oldest_messages_str + # 构建中 + # print(f"构建中:self.talking_message: {self.talking_message}") self.talking_message_str = await build_readable_messages( messages=self.talking_message, timestamp_mode="lite", read_mark=last_obs_time_mark, ) + # print(f"构建中:self.talking_message_str: {self.talking_message_str}") self.talking_message_str_truncate = await build_readable_messages( messages=self.talking_message, timestamp_mode="normal", read_mark=last_obs_time_mark, truncate=True, ) + # print(f"构建中:self.talking_message_str_truncate: {self.talking_message_str_truncate}") self.person_list = await get_person_id_list(self.talking_message) - - # print(f"self.11111person_list: {self.person_list}") + # print(f"构建中:self.person_list: {self.person_list}") logger.trace( f"Chat {self.chat_id} - 压缩早期记忆:{self.mid_memory_info}\n现在聊天内容:{self.talking_message_str}" diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index d3a062680..f81603e13 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -174,6 +174,16 @@ async def _build_readable_messages_internal( # 1 & 2: 获取发送者信息并提取消息组件 for msg in messages: + # 检查并修复缺少的user_info字段 + if 'user_info' not in msg: + # 创建user_info字段 + msg['user_info'] = { + 'platform': msg.get('user_platform', ''), + 'user_id': msg.get('user_id', ''), + 'user_nickname': msg.get('user_nickname', ''), + 'user_cardname': msg.get('user_cardname', '') + } + user_info = msg.get("user_info", {}) platform = user_info.get("platform") user_id = user_info.get("user_id") diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py index c400a9948..a5b601c43 100644 --- a/src/chat/utils/utils.py +++ b/src/chat/utils/utils.py @@ -380,7 +380,7 @@ def process_llm_response(text: str) -> list[str]: # sentences.append(content) # 在所有句子处理完毕后,对包含占位符的列表进行恢复 - if global_config.enable_kaomoji_protection: + if global_config.response_splitter.enable_kaomoji_protection: sentences = recover_kaomoji(sentences, kaomoji_mapping) return sentences diff --git a/tests/common/test_message_repository.py b/tests/common/test_message_repository.py new file mode 100644 index 000000000..43d629761 --- /dev/null +++ b/tests/common/test_message_repository.py @@ -0,0 +1,174 @@ +import unittest +from unittest.mock import patch, MagicMock +import datetime +import sys +import os + +# 添加项目根目录到Python路径 +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))) + +from peewee import SqliteDatabase +from src.common.database.database_model import Messages, BaseModel +from src.common.message_repository import find_messages + + +class TestMessageRepository(unittest.TestCase): + def setUp(self): + # 创建内存中的SQLite数据库用于测试 + self.test_db = SqliteDatabase(':memory:') + + # 覆盖原有数据库连接 + BaseModel._meta.database = self.test_db + Messages._meta.database = self.test_db + + # 创建表 + self.test_db.create_tables([Messages]) + + # 添加测试数据 + current_time = datetime.datetime.now().timestamp() + self.test_messages = [ + { + 'message_id': 'msg1', + 'time': current_time - 3600, # 1小时前 + 'chat_id': '5ed68437e28644da51f314f37df68d18', + 'chat_info_stream_id': 'stream1', + 'chat_info_platform': 'qq', + 'chat_info_user_platform': 'qq', + 'chat_info_user_id': 'user1', + 'chat_info_user_nickname': '用户1', + 'chat_info_user_cardname': '卡片名1', + 'chat_info_group_platform': 'qq', + 'chat_info_group_id': 'group1', + 'chat_info_group_name': '群组1', + 'chat_info_create_time': current_time - 7200, # 2小时前 + 'chat_info_last_active_time': current_time - 1800, # 30分钟前 + 'user_platform': 'qq', + 'user_id': 'user1', + 'user_nickname': '用户1', + 'user_cardname': '卡片名1', + 'processed_plain_text': '你好', + 'detailed_plain_text': '你好', + 'memorized_times': 1 + }, + { + 'message_id': 'msg2', + 'time': current_time - 1800, # 30分钟前 + 'chat_id': 'chat1', + 'chat_info_stream_id': 'stream1', + 'chat_info_platform': 'qq', + 'chat_info_user_platform': 'qq', + 'chat_info_user_id': 'user1', + 'chat_info_user_nickname': '用户1', + 'chat_info_user_cardname': '卡片名1', + 'chat_info_group_platform': 'qq', + 'chat_info_group_id': 'group1', + 'chat_info_group_name': '群组1', + 'chat_info_create_time': current_time - 7200, + 'chat_info_last_active_time': current_time - 900, # 15分钟前 + 'user_platform': 'qq', + 'user_id': 'user1', + 'user_nickname': '用户1', + 'user_cardname': '卡片名1', + 'processed_plain_text': '世界', + 'detailed_plain_text': '世界', + 'memorized_times': 2 + }, + { + 'message_id': 'msg3', + 'time': current_time - 900, # 15分钟前 + 'chat_id': 'chat2', + 'chat_info_stream_id': 'stream2', + 'chat_info_platform': 'wechat', + 'chat_info_user_platform': 'wechat', + 'chat_info_user_id': 'user2', + 'chat_info_user_nickname': '用户2', + 'chat_info_user_cardname': '卡片名2', + 'chat_info_group_platform': 'wechat', + 'chat_info_group_id': 'group2', + 'chat_info_group_name': '群组2', + 'chat_info_create_time': current_time - 3600, + 'chat_info_last_active_time': current_time - 600, # 10分钟前 + 'user_platform': 'wechat', + 'user_id': 'user2', + 'user_nickname': '用户2', + 'user_cardname': '卡片名2', + 'processed_plain_text': '测试', + 'detailed_plain_text': '测试', + 'memorized_times': 0 + } + ] + + for msg_data in self.test_messages: + Messages.create(**msg_data) + + def tearDown(self): + # 关闭测试数据库连接 + self.test_db.close() + + def test_find_messages_no_filter(self): + """测试不带过滤器的查询""" + results = find_messages({}) + self.assertEqual(len(results), 3) + # 验证结果是否按时间升序排列 + self.assertEqual(results[0]['message_id'], 'msg1') + self.assertEqual(results[1]['message_id'], 'msg2') + self.assertEqual(results[2]['message_id'], 'msg3') + + def test_find_messages_with_filter(self): + """测试带过滤器的查询""" + results = find_messages({'chat_id': 'chat1'}) + self.assertEqual(len(results), 2) + self.assertEqual(results[0]['message_id'], 'msg1') + self.assertEqual(results[1]['message_id'], 'msg2') + + results = find_messages({'user_id': 'user2'}) + self.assertEqual(len(results), 1) + self.assertEqual(results[0]['message_id'], 'msg3') + + def test_find_messages_with_operators(self): + """测试带操作符的查询""" + results = find_messages({'memorized_times': {'$gt': 0}}) + self.assertEqual(len(results), 2) + self.assertEqual(results[0]['message_id'], 'msg1') + self.assertEqual(results[1]['message_id'], 'msg2') + + results = find_messages({'memorized_times': {'$gte': 2}}) + self.assertEqual(len(results), 1) + self.assertEqual(results[0]['message_id'], 'msg2') + + def test_find_messages_with_sort(self): + """测试带排序的查询""" + results = find_messages({}, sort=[('memorized_times', -1)]) + self.assertEqual(len(results), 3) + # 验证结果是否按memorized_times降序排列 + self.assertEqual(results[0]['message_id'], 'msg2') # memorized_times = 2 + self.assertEqual(results[1]['message_id'], 'msg1') # memorized_times = 1 + self.assertEqual(results[2]['message_id'], 'msg3') # memorized_times = 0 + + def test_find_messages_with_limit(self): + """测试带限制的查询""" + # 默认limit_mode为latest,应返回最新的2条记录 + results = find_messages({}, limit=2) + self.assertEqual(len(results), 2) + self.assertEqual(results[0]['message_id'], 'msg2') + self.assertEqual(results[1]['message_id'], 'msg3') + + # 使用earliest模式,应返回最早的2条记录 + results = find_messages({}, limit=2, limit_mode='earliest') + self.assertEqual(len(results), 2) + self.assertEqual(results[0]['message_id'], 'msg1') + self.assertEqual(results[1]['message_id'], 'msg2') + + def test_find_messages_with_combined_criteria(self): + """测试组合查询条件""" + results = find_messages( + {'chat_info_platform': 'qq', 'memorized_times': {'$gt': 0}}, + sort=[('time', 1)], + limit=1 + ) + self.assertEqual(len(results), 1) + self.assertEqual(results[0]['message_id'], 'msg2') + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/test_build_readable_messages.py b/tests/test_build_readable_messages.py new file mode 100644 index 000000000..76caffb75 --- /dev/null +++ b/tests/test_build_readable_messages.py @@ -0,0 +1,171 @@ +import unittest +import sys +import os +import datetime +import time +import asyncio +import traceback +import json +import copy + +# 添加项目根目录到Python路径 +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat, build_readable_messages +from src.common.logger import get_module_logger + +# 创建测试日志记录器 +logger = get_module_logger("test_readable_msg") + +class TestBuildReadableMessages(unittest.TestCase): + def setUp(self): + # 准备测试数据:从真实数据库获取消息 + self.chat_id = '5ed68437e28644da51f314f37df68d18' + self.current_time = time.time() + self.thirty_days_ago = self.current_time - (30 * 24 * 60 * 60) # 30天前的时间戳 + + # 获取最新的10条消息 + try: + self.messages = get_raw_msg_by_timestamp_with_chat( + chat_id=self.chat_id, + timestamp_start=self.thirty_days_ago, + timestamp_end=self.current_time, + limit=10, + limit_mode="latest" + ) + logger.info(f"已获取 {len(self.messages)} 条测试消息") + + # 打印消息样例 + if self.messages: + sample_msg = self.messages[0] + logger.info(f"消息样例: {list(sample_msg.keys())}") + logger.info(f"消息内容: {sample_msg.get('processed_plain_text', '无文本内容')[:50]}...") + except Exception as e: + logger.error(f"获取消息失败: {e}") + logger.error(traceback.format_exc()) + self.messages = [] + + def test_manual_fix_messages(self): + """创建一个手动修复版本的消息进行测试""" + if not self.messages: + self.skipTest("没有测试消息,跳过测试") + return + + logger.info("开始手动修复消息...") + + # 创建修复版本的消息列表 + fixed_messages = [] + + for msg in self.messages: + # 深拷贝以避免修改原始数据 + fixed_msg = copy.deepcopy(msg) + + # 构建 user_info 对象 + if 'user_info' not in fixed_msg: + user_info = { + 'platform': fixed_msg.get('user_platform', 'qq'), + 'user_id': fixed_msg.get('user_id', '10000'), + 'user_nickname': fixed_msg.get('user_nickname', '测试用户'), + 'user_cardname': fixed_msg.get('user_cardname', '') + } + fixed_msg['user_info'] = user_info + logger.info(f"为消息 {fixed_msg.get('message_id')} 添加了 user_info") + + fixed_messages.append(fixed_msg) + + logger.info(f"已修复 {len(fixed_messages)} 条消息") + + try: + # 使用修复后的消息尝试格式化 + formatted_text = asyncio.run(build_readable_messages( + messages=fixed_messages, + replace_bot_name=True, + merge_messages=False, + timestamp_mode="absolute", + read_mark=0.0, + truncate=False + )) + + logger.info("使用修复后的消息格式化完成") + logger.info(f"格式化结果长度: {len(formatted_text)}") + if formatted_text: + logger.info(f"格式化结果预览: {formatted_text[:200]}...") + else: + logger.warning("格式化结果为空") + + # 断言 + self.assertNotEqual(formatted_text, "", "有消息时不应返回空字符串") + except Exception as e: + logger.error(f"使用修复后的消息格式化失败: {e}") + logger.error(traceback.format_exc()) + raise + + def test_debug_build_messages_internal(self): + """调试_build_readable_messages_internal函数""" + if not self.messages: + self.skipTest("没有测试消息,跳过测试") + return + + logger.info("开始调试内部构建函数...") + + try: + # 直接导入内部函数进行测试 + from src.chat.utils.chat_message_builder import _build_readable_messages_internal + + # 手动创建一个简单的测试消息列表 + test_msg = self.messages[0].copy() # 使用第一条消息作为模板 + + # 检查消息结构 + logger.info(f"测试消息keys: {list(test_msg.keys())}") + logger.info(f"user_info存在: {'user_info' in test_msg}") + + # 修复缺少的user_info字段 + if 'user_info' not in test_msg: + logger.warning("消息中缺少user_info字段,添加模拟数据") + test_msg['user_info'] = { + 'platform': test_msg.get('user_platform', 'qq'), + 'user_id': test_msg.get('user_id', '10000'), + 'user_nickname': test_msg.get('user_nickname', '测试用户'), + 'user_cardname': test_msg.get('user_cardname', '') + } + logger.info(f"添加的user_info: {test_msg['user_info']}") + + simple_msgs = [test_msg] + + # 运行内部函数 + result_text, result_details = asyncio.run(_build_readable_messages_internal( + simple_msgs, + replace_bot_name=True, + merge_messages=False, + timestamp_mode="absolute", + truncate=False + )) + + logger.info(f"内部函数返回结果: {result_text[:200] if result_text else '空'}") + logger.info(f"详情列表长度: {len(result_details)}") + + # 显示处理过程中的变量 + if not result_text and len(simple_msgs) > 0: + logger.warning("消息处理可能有问题,检查关键步骤") + msg = simple_msgs[0] + + # 打印关键变量的值 + user_info = msg.get("user_info", {}) + platform = user_info.get("platform") + user_id = user_info.get("user_id") + timestamp = msg.get("time") + content = msg.get("processed_plain_text", "") + + logger.warning(f"平台: {platform}, 用户ID: {user_id}, 时间戳: {timestamp}") + logger.warning(f"内容: {content[:50]}...") + + # 检查必要信息是否完整 + logger.warning(f"必要信息完整性检查: {all([platform, user_id, timestamp is not None])}") + + except Exception as e: + logger.error(f"调试内部函数失败: {e}") + logger.error(traceback.format_exc()) + raise + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/test_extract_messages.py b/tests/test_extract_messages.py new file mode 100644 index 000000000..d32e644b6 --- /dev/null +++ b/tests/test_extract_messages.py @@ -0,0 +1,88 @@ +import unittest +import sys +import os +import datetime +import time + +# 添加项目根目录到Python路径 +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from src.common.message_repository import find_messages +from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat +from peewee import SqliteDatabase +from src.common.database.database import db # 导入实际的数据库连接 + +class TestExtractMessages(unittest.TestCase): + def setUp(self): + # 这个测试使用真实的数据库,所以不需要创建测试数据 + pass + + def test_extract_latest_messages_direct(self): + """测试直接使用message_repository.find_messages函数""" + chat_id = '5ed68437e28644da51f314f37df68d18' + + # 提取最新的10条消息 + results = find_messages( + {'chat_id': chat_id}, + limit=10 + ) + + # 打印结果数量 + print(f"\n直接使用find_messages,找到 {len(results)} 条消息") + + # 如果有结果,打印一些信息 + if results: + print("\n消息时间顺序:") + for idx, msg in enumerate(results): + msg_time = datetime.datetime.fromtimestamp(msg['time']).strftime('%Y-%m-%d %H:%M:%S') + print(f"{idx+1}. ID: {msg['message_id']}, 时间: {msg_time}") + print(f" 文本: {msg.get('processed_plain_text', '无文本内容')[:50]}...") + + # 验证结果按时间排序 + times = [msg['time'] for msg in results] + self.assertEqual(times, sorted(times), "消息应该按时间升序排列") + else: + print(f"未找到chat_id为 {chat_id} 的消息") + + # 最基本的断言,确保测试有效 + self.assertIsInstance(results, list, "结果应该是一个列表") + + def test_extract_latest_messages_via_builder(self): + """使用chat_message_builder中的函数测试从真实数据库提取消息""" + chat_id = '5ed68437e28644da51f314f37df68d18' + + # 设置时间范围为过去30天到现在 + current_time = time.time() + thirty_days_ago = current_time - (30 * 24 * 60 * 60) # 30天前的时间戳 + + # 使用chat_message_builder中的函数 + results = get_raw_msg_by_timestamp_with_chat( + chat_id=chat_id, + timestamp_start=thirty_days_ago, + timestamp_end=current_time, + limit=10, + limit_mode="latest" + ) + + # 打印结果数量 + print(f"\n使用get_raw_msg_by_timestamp_with_chat,找到 {len(results)} 条消息") + + # 如果有结果,打印一些信息 + if results: + print("\n消息时间顺序:") + for idx, msg in enumerate(results): + msg_time = datetime.datetime.fromtimestamp(msg['time']).strftime('%Y-%m-%d %H:%M:%S') + print(f"{idx+1}. ID: {msg['message_id']}, 时间: {msg_time}") + print(f" 文本: {msg.get('processed_plain_text', '无文本内容')[:50]}...") + + # 验证结果按时间排序 + times = [msg['time'] for msg in results] + self.assertEqual(times, sorted(times), "消息应该按时间升序排列") + else: + print(f"未找到chat_id为 {chat_id} 的消息") + + # 最基本的断言,确保测试有效 + self.assertIsInstance(results, list, "结果应该是一个列表") + +if __name__ == '__main__': + unittest.main() \ No newline at end of file From f1081dfe76510679a8c2c9c40532ea988649e4b0 Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Sat, 17 May 2025 01:04:12 +0800 Subject: [PATCH 47/57] Update chattinginfo_processor.py --- .../focus_chat/info_processors/chattinginfo_processor.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/chat/focus_chat/info_processors/chattinginfo_processor.py b/src/chat/focus_chat/info_processors/chattinginfo_processor.py index 5a72bcd9e..5b46d16bb 100644 --- a/src/chat/focus_chat/info_processors/chattinginfo_processor.py +++ b/src/chat/focus_chat/info_processors/chattinginfo_processor.py @@ -63,16 +63,16 @@ class ChattingInfoProcessor(BaseProcessor): # 设置说话消息 if hasattr(obs, "talking_message_str"): - print(f"设置说话消息:obs.talking_message_str: {obs.talking_message_str}") + # print(f"设置说话消息:obs.talking_message_str: {obs.talking_message_str}") obs_info.set_talking_message(obs.talking_message_str) # 设置截断后的说话消息 if hasattr(obs, "talking_message_str_truncate"): - print(f"设置截断后的说话消息:obs.talking_message_str_truncate: {obs.talking_message_str_truncate}") + # print(f"设置截断后的说话消息:obs.talking_message_str_truncate: {obs.talking_message_str_truncate}") obs_info.set_talking_message_str_truncate(obs.talking_message_str_truncate) if hasattr(obs, "mid_memory_info"): - print(f"设置之前聊天信息:obs.mid_memory_info: {obs.mid_memory_info}") + # print(f"设置之前聊天信息:obs.mid_memory_info: {obs.mid_memory_info}") obs_info.set_previous_chat_info(obs.mid_memory_info) # 设置聊天类型 From 06a3479c0f2d0586454fbdf2b3ec5c977429c1dd Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Sat, 17 May 2025 12:24:00 +0800 Subject: [PATCH 48/57] =?UTF-8?q?fix=EF=BC=9A=E8=AE=B0=E5=BF=86=E6=8A=A5?= =?UTF-8?q?=E9=94=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/memory_system/Hippocampus.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/src/chat/memory_system/Hippocampus.py b/src/chat/memory_system/Hippocampus.py index aae1721c2..67df2b817 100644 --- a/src/chat/memory_system/Hippocampus.py +++ b/src/chat/memory_system/Hippocampus.py @@ -22,6 +22,7 @@ from ..utils.utils import translate_timestamp_to_human_readable from rich.traceback import install from ...config.config import global_config +from src.common.database.database_model import Messages # Peewee Messages 模型导入 install(extra_lines=3) @@ -856,11 +857,14 @@ class EntorhinalCortex: if all_valid: # 更新数据库中的记忆次数 for message in messages: - # 确保在更新前获取最新的 memorized_times,以防万一 + # 确保在更新前获取最新的 memorized_times current_memorized_times = message.get("memorized_times", 0) - db.messages.update_one( - {"_id": message["_id"]}, {"$set": {"memorized_times": current_memorized_times + 1}} - ) + # 使用 Peewee 更新记录 + Messages.update( + memorized_times=current_memorized_times + 1 + ).where( + Messages.message_id == message["message_id"] + ).execute() return messages # 直接返回原始的消息列表 # 如果获取失败或消息无效,增加尝试次数 @@ -919,7 +923,7 @@ class EntorhinalCortex: "last_modified": last_modified, } }, - ) + ).execute() # 处理边的信息 db_edges = list(db.graph_data.edges.find()) @@ -965,7 +969,7 @@ class EntorhinalCortex: "last_modified": last_modified, } }, - ) + ).execute() def sync_memory_from_db(self): """从数据库同步数据到内存中的图结构""" @@ -993,7 +997,7 @@ class EntorhinalCortex: if "last_modified" not in node: update_data["last_modified"] = current_time - db.graph_data.nodes.update_one({"concept": concept}, {"$set": update_data}) + db.graph_data.nodes.update_one({"concept": concept}, {"$set": update_data}).execute() logger.info(f"[时间更新] 节点 {concept} 添加缺失的时间字段") # 获取时间信息(如果不存在则使用当前时间) @@ -1022,7 +1026,7 @@ class EntorhinalCortex: if "last_modified" not in edge: update_data["last_modified"] = current_time - db.graph_data.edges.update_one({"source": source, "target": target}, {"$set": update_data}) + db.graph_data.edges.update_one({"source": source, "target": target}, {"$set": update_data}).execute() logger.info(f"[时间更新] 边 {source} - {target} 添加缺失的时间字段") # 获取时间信息(如果不存在则使用当前时间) From e6cd2a8e8f222ad03e68cd3018c46b494700e4eb Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Sat, 17 May 2025 14:46:01 +0800 Subject: [PATCH 49/57] =?UTF-8?q?feat=EF=BC=9A=E6=B7=BB=E5=8A=A0=E6=B5=B7?= =?UTF-8?q?=E9=A9=AC=E4=BD=93=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/memory_system/Hippocampus.py | 166 ++++++++++++-------------- src/common/database/database_model.py | 35 +++++- 2 files changed, 112 insertions(+), 89 deletions(-) diff --git a/src/chat/memory_system/Hippocampus.py b/src/chat/memory_system/Hippocampus.py index 67df2b817..23a296c8d 100644 --- a/src/chat/memory_system/Hippocampus.py +++ b/src/chat/memory_system/Hippocampus.py @@ -4,6 +4,7 @@ import math import random import time import re +import json from itertools import combinations import jieba @@ -22,7 +23,7 @@ from ..utils.utils import translate_timestamp_to_human_readable from rich.traceback import install from ...config.config import global_config -from src.common.database.database_model import Messages # Peewee Messages 模型导入 +from src.common.database.database_model import Messages, GraphNodes, GraphEdges # Peewee Models导入 install(extra_lines=3) @@ -877,12 +878,9 @@ class EntorhinalCortex: async def sync_memory_to_db(self): """将记忆图同步到数据库""" # 获取数据库中所有节点和内存中所有节点 - db_nodes = list(db.graph_data.nodes.find()) + db_nodes = {node.concept: node for node in GraphNodes.select()} memory_nodes = list(self.memory_graph.G.nodes(data=True)) - # 转换数据库节点为字典格式,方便查找 - db_nodes_dict = {node["concept"]: node for node in db_nodes} - # 检查并更新节点 for concept, data in memory_nodes: memory_items = data.get("memory_items", []) @@ -896,44 +894,39 @@ class EntorhinalCortex: created_time = data.get("created_time", datetime.datetime.now().timestamp()) last_modified = data.get("last_modified", datetime.datetime.now().timestamp()) - if concept not in db_nodes_dict: + # 将memory_items转换为JSON字符串 + memory_items_json = json.dumps(memory_items, ensure_ascii=False) + + if concept not in db_nodes: # 数据库中缺少的节点,添加 - node_data = { - "concept": concept, - "memory_items": memory_items, - "hash": memory_hash, - "created_time": created_time, - "last_modified": last_modified, - } - db.graph_data.nodes.insert_one(node_data) + GraphNodes.create( + concept=concept, + memory_items=memory_items_json, + hash=memory_hash, + created_time=created_time, + last_modified=last_modified, + ) else: # 获取数据库中节点的特征值 - db_node = db_nodes_dict[concept] - db_hash = db_node.get("hash", None) + db_node = db_nodes[concept] + db_hash = db_node.hash # 如果特征值不同,则更新节点 if db_hash != memory_hash: - db.graph_data.nodes.update_one( - {"concept": concept}, - { - "$set": { - "memory_items": memory_items, - "hash": memory_hash, - "created_time": created_time, - "last_modified": last_modified, - } - }, - ).execute() + db_node.memory_items = memory_items_json + db_node.hash = memory_hash + db_node.last_modified = last_modified + db_node.save() # 处理边的信息 - db_edges = list(db.graph_data.edges.find()) + db_edges = list(GraphEdges.select()) memory_edges = list(self.memory_graph.G.edges(data=True)) # 创建边的哈希值字典 db_edge_dict = {} for edge in db_edges: - edge_hash = self.hippocampus.calculate_edge_hash(edge["source"], edge["target"]) - db_edge_dict[(edge["source"], edge["target"])] = {"hash": edge_hash, "strength": edge.get("strength", 1)} + edge_hash = self.hippocampus.calculate_edge_hash(edge.source, edge.target) + db_edge_dict[(edge.source, edge.target)] = {"hash": edge_hash, "strength": edge.strength} # 检查并更新边 for source, target, data in memory_edges: @@ -947,29 +940,22 @@ class EntorhinalCortex: if edge_key not in db_edge_dict: # 添加新边 - edge_data = { - "source": source, - "target": target, - "strength": strength, - "hash": edge_hash, - "created_time": created_time, - "last_modified": last_modified, - } - db.graph_data.edges.insert_one(edge_data) + GraphEdges.create( + source=source, + target=target, + strength=strength, + hash=edge_hash, + created_time=created_time, + last_modified=last_modified, + ) else: # 检查边的特征值是否变化 if db_edge_dict[edge_key]["hash"] != edge_hash: - db.graph_data.edges.update_one( - {"source": source, "target": target}, - { - "$set": { - "hash": edge_hash, - "strength": strength, - "created_time": created_time, - "last_modified": last_modified, - } - }, - ).execute() + edge = GraphEdges.get(GraphEdges.source == source, GraphEdges.target == target) + edge.hash = edge_hash + edge.strength = strength + edge.last_modified = last_modified + edge.save() def sync_memory_from_db(self): """从数据库同步数据到内存中的图结构""" @@ -980,29 +966,31 @@ class EntorhinalCortex: self.memory_graph.G.clear() # 从数据库加载所有节点 - nodes = list(db.graph_data.nodes.find()) + nodes = list(GraphNodes.select()) for node in nodes: - concept = node["concept"] - memory_items = node.get("memory_items", []) + concept = node.concept + memory_items = json.loads(node.memory_items) if not isinstance(memory_items, list): memory_items = [memory_items] if memory_items else [] # 检查时间字段是否存在 - if "created_time" not in node or "last_modified" not in node: + if not node.created_time or not node.last_modified: need_update = True # 更新数据库中的节点 update_data = {} - if "created_time" not in node: + if not node.created_time: update_data["created_time"] = current_time - if "last_modified" not in node: + if not node.last_modified: update_data["last_modified"] = current_time - db.graph_data.nodes.update_one({"concept": concept}, {"$set": update_data}).execute() + GraphNodes.update( + **update_data + ).where(GraphNodes.concept == concept).execute() logger.info(f"[时间更新] 节点 {concept} 添加缺失的时间字段") # 获取时间信息(如果不存在则使用当前时间) - created_time = node.get("created_time", current_time) - last_modified = node.get("last_modified", current_time) + created_time = node.created_time or current_time + last_modified = node.last_modified or current_time # 添加节点到图中 self.memory_graph.G.add_node( @@ -1010,28 +998,32 @@ class EntorhinalCortex: ) # 从数据库加载所有边 - edges = list(db.graph_data.edges.find()) + edges = list(GraphEdges.select()) for edge in edges: - source = edge["source"] - target = edge["target"] - strength = edge.get("strength", 1) + source = edge.source + target = edge.target + strength = edge.strength # 检查时间字段是否存在 - if "created_time" not in edge or "last_modified" not in edge: + if not edge.created_time or not edge.last_modified: need_update = True # 更新数据库中的边 update_data = {} - if "created_time" not in edge: + if not edge.created_time: update_data["created_time"] = current_time - if "last_modified" not in edge: + if not edge.last_modified: update_data["last_modified"] = current_time - db.graph_data.edges.update_one({"source": source, "target": target}, {"$set": update_data}).execute() + GraphEdges.update( + **update_data + ).where( + (GraphEdges.source == source) & (GraphEdges.target == target) + ).execute() logger.info(f"[时间更新] 边 {source} - {target} 添加缺失的时间字段") # 获取时间信息(如果不存在则使用当前时间) - created_time = edge.get("created_time", current_time) - last_modified = edge.get("last_modified", current_time) + created_time = edge.created_time or current_time + last_modified = edge.last_modified or current_time # 只有当源节点和目标节点都存在时才添加边 if source in self.memory_graph.G and target in self.memory_graph.G: @@ -1049,8 +1041,8 @@ class EntorhinalCortex: # 清空数据库 clear_start = time.time() - db.graph_data.nodes.delete_many({}) - db.graph_data.edges.delete_many({}) + GraphNodes.delete().execute() + GraphEdges.delete().execute() clear_end = time.time() logger.info(f"[数据库] 清空数据库耗时: {clear_end - clear_start:.2f}秒") @@ -1065,29 +1057,27 @@ class EntorhinalCortex: if not isinstance(memory_items, list): memory_items = [memory_items] if memory_items else [] - node_data = { - "concept": concept, - "memory_items": memory_items, - "hash": self.hippocampus.calculate_node_hash(concept, memory_items), - "created_time": data.get("created_time", datetime.datetime.now().timestamp()), - "last_modified": data.get("last_modified", datetime.datetime.now().timestamp()), - } - db.graph_data.nodes.insert_one(node_data) + GraphNodes.create( + concept=concept, + memory_items=json.dumps(memory_items), + hash=self.hippocampus.calculate_node_hash(concept, memory_items), + created_time=data.get("created_time", datetime.datetime.now().timestamp()), + last_modified=data.get("last_modified", datetime.datetime.now().timestamp()), + ) node_end = time.time() logger.info(f"[数据库] 写入 {len(memory_nodes)} 个节点耗时: {node_end - node_start:.2f}秒") # 重新写入边 edge_start = time.time() for source, target, data in memory_edges: - edge_data = { - "source": source, - "target": target, - "strength": data.get("strength", 1), - "hash": self.hippocampus.calculate_edge_hash(source, target), - "created_time": data.get("created_time", datetime.datetime.now().timestamp()), - "last_modified": data.get("last_modified", datetime.datetime.now().timestamp()), - } - db.graph_data.edges.insert_one(edge_data) + GraphEdges.create( + source=source, + target=target, + strength=data.get("strength", 1), + hash=self.hippocampus.calculate_edge_hash(source, target), + created_time=data.get("created_time", datetime.datetime.now().timestamp()), + last_modified=data.get("last_modified", datetime.datetime.now().timestamp()), + ) edge_end = time.time() logger.info(f"[数据库] 写入 {len(memory_edges)} 条边耗时: {edge_end - edge_start:.2f}秒") diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py index bd7a2d319..bf192ca6a 100644 --- a/src/common/database/database_model.py +++ b/src/common/database/database_model.py @@ -275,6 +275,35 @@ class RecalledMessages(BaseModel): table_name = "recalled_messages" +class GraphNodes(BaseModel): + """ + 用于存储记忆图节点的模型 + """ + concept = TextField(unique=True, index=True) # 节点概念 + memory_items = TextField() # JSON格式存储的记忆列表 + hash = TextField() # 节点哈希值 + created_time = FloatField() # 创建时间戳 + last_modified = FloatField() # 最后修改时间戳 + + class Meta: + table_name = "graph_nodes" + + +class GraphEdges(BaseModel): + """ + 用于存储记忆图边的模型 + """ + source = TextField(index=True) # 源节点 + target = TextField(index=True) # 目标节点 + strength = IntegerField() # 连接强度 + hash = TextField() # 边哈希值 + created_time = FloatField() # 创建时间戳 + last_modified = FloatField() # 最后修改时间戳 + + class Meta: + table_name = "graph_edges" + + def create_tables(): """ 创建所有在模型中定义的数据库表。 @@ -293,6 +322,8 @@ def create_tables(): Knowledges, ThinkingLog, RecalledMessages, # 添加新模型 + GraphNodes, # 添加图节点表 + GraphEdges, # 添加图边表 ] ) @@ -315,7 +346,9 @@ def initialize_database(): PersonInfo, Knowledges, ThinkingLog, - RecalledMessages, # 添加新模型 + RecalledMessages, + GraphNodes, # 添加图节点表 + GraphEdges, # 添加图边表 ] needs_creation = False From d70973af4c4d91590c5bf8be919ec0e9653f2ee5 Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Sat, 17 May 2025 16:33:19 +0800 Subject: [PATCH 50/57] requirements.txt update --- requirements.txt | Bin 824 -> 838 bytes 1 file changed, 0 insertions(+), 0 deletions(-) diff --git a/requirements.txt b/requirements.txt index 1e374f4eb463443ac5696797afd2e0760bde01a5..50c24228c5ae7dd75849f7e47df794c322d8e376 100644 GIT binary patch delta 22 ccmdnNc8qO97$avYLn=c#5T|YqWPHvD08Rb|lK=n! delta 12 TcmX@cwu5a$7~|#?#z%|*AF~9F From 877bd9e18880db20737833b35a72973bcc3e8776 Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Sat, 17 May 2025 16:39:46 +0800 Subject: [PATCH 51/57] Update __init__.py --- src/plugins/test_plugin/actions/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/plugins/test_plugin/actions/__init__.py b/src/plugins/test_plugin/actions/__init__.py index 7d96ea8a4..a87c0b523 100644 --- a/src/plugins/test_plugin/actions/__init__.py +++ b/src/plugins/test_plugin/actions/__init__.py @@ -1,7 +1,7 @@ """测试插件动作模块""" # 导入所有动作模块以确保装饰器被执行 -from . import test_action # noqa +# from . import test_action # noqa # from . import online_action # noqa -from . import mute_action # noqa +# from . import mute_action # noqa From 4d7b415589c2aa0ef5b498c6cee20b15100309e3 Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Sat, 17 May 2025 16:45:07 +0800 Subject: [PATCH 52/57] Update requirements.txt --- requirements.txt | Bin 838 -> 824 bytes 1 file changed, 0 insertions(+), 0 deletions(-) diff --git a/requirements.txt b/requirements.txt index 50c24228c5ae7dd75849f7e47df794c322d8e376..9baaf0bd58d46521800144044eeb1a4e2aa9c384 100644 GIT binary patch delta 17 ZcmX@cwu5cLjE&2B8EhU6oyQO{Ebt2nHaeyFJzhu0DhMUUjP6A From 061fcefeef3667288c83f587d34e3a851f3fc17f Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Sat, 17 May 2025 17:34:44 +0800 Subject: [PATCH 53/57] =?UTF-8?q?=E5=8E=BB=E9=99=A4mmc=E7=AB=AF=E7=9A=84?= =?UTF-8?q?=E7=99=BD=E5=90=8D=E5=8D=95=E6=9C=BA=E5=88=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/message_receive/bot.py | 96 ++++++++++++++----------- src/chat/message_receive/chat_stream.py | 50 ++++++------- src/config/official_configs.py | 8 +-- template/bot_config_template.toml | 9 +-- 4 files changed, 86 insertions(+), 77 deletions(-) diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index cea791de4..88bf141a1 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -38,10 +38,10 @@ class ChatBot: async def _create_pfc_chat(self, message: MessageRecv): try: - chat_id = str(message.chat_stream.stream_id) - private_name = str(message.message_info.user_info.user_nickname) + if global_config.experimental.pfc_chatting: + chat_id = str(message.chat_stream.stream_id) + private_name = str(message.message_info.user_info.user_nickname) - if global_config.experimental.enable_pfc_chatting: await self.pfc_manager.get_or_create_conversation(chat_id, private_name) except Exception as e: @@ -75,27 +75,27 @@ class ChatBot: # print(message_data) logger.trace(f"处理消息:{str(message_data)[:120]}...") message = MessageRecv(message_data) - groupinfo = message.message_info.group_info - userinfo = message.message_info.user_info + group_info = message.message_info.group_info + user_info = message.message_info.user_info # 用户黑名单拦截 - if userinfo.user_id in global_config.chat_target.ban_user_id: - logger.debug(f"用户{userinfo.user_id}被禁止回复") - return + # if userinfo.user_id in global_config.chat_target.ban_user_id: + # logger.debug(f"用户{userinfo.user_id}被禁止回复") + # return - if groupinfo is None: - logger.trace("检测到私聊消息,检查") - # 好友黑名单拦截 - if userinfo.user_id not in global_config.experimental.talk_allowed_private: - # logger.debug(f"用户{userinfo.user_id}没有私聊权限") - return + # if groupinfo is None: + # logger.trace("检测到私聊消息,检查") + # # 好友黑名单拦截 + # if userinfo.user_id not in global_config.experimental.talk_allowed_private: + # # logger.debug(f"用户{userinfo.user_id}没有私聊权限") + # return # 群聊黑名单拦截 # print(groupinfo.group_id) # print(global_config.chat_target.talk_allowed_groups) - if groupinfo is not None and groupinfo.group_id not in global_config.chat_target.talk_allowed_groups: - logger.debug(f"群{groupinfo.group_id}被禁止回复") - return + # if groupinfo is not None and groupinfo.group_id not in global_config.chat_target.talk_allowed_groups: + # logger.debug(f"群{groupinfo.group_id}被禁止回复") + # return # 确认从接口发来的message是否有自定义的prompt模板信息 if message.message_info.template_info and not message.message_info.template_info.template_default: @@ -112,33 +112,49 @@ class ChatBot: async def preprocess(): logger.trace("开始预处理消息...") # 如果在私聊中 - if groupinfo is None: + if group_info is None: logger.trace("检测到私聊消息") # 是否在配置信息中开启私聊模式 - if global_config.experimental.enable_friend_chat: - logger.trace("私聊模式已启用") - # 是否进入PFC - if global_config.enable_pfc_chatting: - logger.trace("进入PFC私聊处理流程") - userinfo = message.message_info.user_info - messageinfo = message.message_info - # 创建聊天流 - logger.trace(f"为{userinfo.user_id}创建/获取聊天流") - chat = await chat_manager.get_or_create_stream( - platform=messageinfo.platform, - user_info=userinfo, - group_info=groupinfo, - ) - message.update_chat_stream(chat) - await self.only_process_chat.process_message(message) - await self._create_pfc_chat(message) - # 禁止PFC,进入普通的心流消息处理逻辑 - else: - logger.trace("进入普通心流私聊处理") - await self.heartflow_processor.process_message(message_data) + # if global_config.experimental.enable_friend_chat: + # logger.trace("私聊模式已启用") + # # 是否进入PFC + # if global_config.enable_pfc_chatting: + # logger.trace("进入PFC私聊处理流程") + # userinfo = message.message_info.user_info + # messageinfo = message.message_info + # # 创建聊天流 + # logger.trace(f"为{userinfo.user_id}创建/获取聊天流") + # chat = await chat_manager.get_or_create_stream( + # platform=messageinfo.platform, + # user_info=userinfo, + # group_info=groupinfo, + # ) + # message.update_chat_stream(chat) + # await self.only_process_chat.process_message(message) + # await self._create_pfc_chat(message) + # # 禁止PFC,进入普通的心流消息处理逻辑 + # else: + # logger.trace("进入普通心流私聊处理") + # await self.heartflow_processor.process_message(message_data) + if global_config.experimental.pfc_chatting: + logger.trace("进入PFC私聊处理流程") + # 创建聊天流 + logger.trace(f"为{user_info.user_id}创建/获取聊天流") + chat = await chat_manager.get_or_create_stream( + platform=message.message_info.platform, + user_info=user_info, + group_info=group_info, + ) + message.update_chat_stream(chat) + await self.only_process_chat.process_message(message) + await self._create_pfc_chat(message) + # 禁止PFC,进入普通的心流消息处理逻辑 + else: + logger.trace("进入普通心流私聊处理") + await self.heartflow_processor.process_message(message_data) # 群聊默认进入心流消息处理逻辑 else: - logger.trace(f"检测到群聊消息,群ID: {groupinfo.group_id}") + logger.trace(f"检测到群聊消息,群ID: {group_info.group_id}") await self.heartflow_processor.process_message(message_data) if template_group_name: diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index 723d6da47..e00fc7370 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -39,7 +39,7 @@ class ChatStream: def to_dict(self) -> dict: """转换为字典格式""" - result = { + return { "stream_id": self.stream_id, "platform": self.platform, "user_info": self.user_info.to_dict() if self.user_info else None, @@ -47,7 +47,6 @@ class ChatStream: "create_time": self.create_time, "last_active_time": self.last_active_time, } - return result @classmethod def from_dict(cls, data: dict) -> "ChatStream": @@ -235,33 +234,34 @@ class ChatManager: @staticmethod async def _save_stream(stream: ChatStream): """保存聊天流到数据库""" - if not stream.saved: - stream_data_dict = stream.to_dict() + if stream.saved: + return + stream_data_dict = stream.to_dict() - def _db_save_stream_sync(s_data_dict: dict): - user_info_d = s_data_dict.get("user_info") - group_info_d = s_data_dict.get("group_info") + def _db_save_stream_sync(s_data_dict: dict): + user_info_d = s_data_dict.get("user_info") + group_info_d = s_data_dict.get("group_info") - fields_to_save = { - "platform": s_data_dict["platform"], - "create_time": s_data_dict["create_time"], - "last_active_time": s_data_dict["last_active_time"], - "user_platform": user_info_d["platform"] if user_info_d else "", - "user_id": user_info_d["user_id"] if user_info_d else "", - "user_nickname": user_info_d["user_nickname"] if user_info_d else "", - "user_cardname": user_info_d.get("user_cardname", "") if user_info_d else None, - "group_platform": group_info_d["platform"] if group_info_d else "", - "group_id": group_info_d["group_id"] if group_info_d else "", - "group_name": group_info_d["group_name"] if group_info_d else "", - } + fields_to_save = { + "platform": s_data_dict["platform"], + "create_time": s_data_dict["create_time"], + "last_active_time": s_data_dict["last_active_time"], + "user_platform": user_info_d["platform"] if user_info_d else "", + "user_id": user_info_d["user_id"] if user_info_d else "", + "user_nickname": user_info_d["user_nickname"] if user_info_d else "", + "user_cardname": user_info_d.get("user_cardname", "") if user_info_d else None, + "group_platform": group_info_d["platform"] if group_info_d else "", + "group_id": group_info_d["group_id"] if group_info_d else "", + "group_name": group_info_d["group_name"] if group_info_d else "", + } - ChatStreams.replace(stream_id=s_data_dict["stream_id"], **fields_to_save).execute() + ChatStreams.replace(stream_id=s_data_dict["stream_id"], **fields_to_save).execute() - try: - await asyncio.to_thread(_db_save_stream_sync, stream_data_dict) - stream.saved = True - except Exception as e: - logger.error(f"保存聊天流 {stream.stream_id} 到数据库失败 (Peewee): {e}", exc_info=True) + try: + await asyncio.to_thread(_db_save_stream_sync, stream_data_dict) + stream.saved = True + except Exception as e: + logger.error(f"保存聊天流 {stream.stream_id} 到数据库失败 (Peewee): {e}", exc_info=True) async def _save_all_streams(self): """保存所有聊天流""" diff --git a/src/config/official_configs.py b/src/config/official_configs.py index d92d925d6..6ad4648ba 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -340,11 +340,11 @@ class TelemetryConfig(ConfigBase): class ExperimentalConfig(ConfigBase): """实验功能配置类""" - enable_friend_chat: bool = False - """是否启用好友聊天""" + # enable_friend_chat: bool = False + # """是否启用好友聊天""" - talk_allowed_private: set[str] = field(default_factory=lambda: set()) - """允许聊天的私聊列表""" + # talk_allowed_private: set[str] = field(default_factory=lambda: set()) + # """允许聊天的私聊列表""" pfc_chatting: bool = False """是否启用PFC""" diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index 8ffbcfa92..b66c3b180 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -1,5 +1,5 @@ [inner] -version = "2.0.0" +version = "2.1.0" #----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读---- #如果你想要修改配置文件,请在修改后将version的值进行变更 @@ -18,12 +18,7 @@ nickname = "麦麦" alias_names = ["麦叠", "牢麦"] #该选项还在调试中,暂时未生效 [chat_target] -talk_allowed_groups = [ - 123, - 123, -] #可以回复消息的群号码 talk_frequency_down_groups = [] #降低回复频率的群号码 -ban_user_id = [] #禁止回复和读取消息的QQ号 [personality] #未完善 personality_core = "用一句话或几句话描述人格的核心特点" # 建议20字以内,谁再写3000字小作文敲谁脑袋 @@ -171,8 +166,6 @@ enable_kaomoji_protection = false # 是否启用颜文字保护 enable = true [experimental] #实验性功能 -enable_friend_chat = false # 是否启用好友聊天 -talk_allowed_private = [] # 可以回复消息的QQ号 pfc_chatting = false # 是否启用PFC聊天,该功能仅作用于私聊,与回复模式独立 #下面的模型若使用硅基流动则不需要更改,使用ds官方则改成.env自定义的宏,使用自定义模型则选择定位相似的模型自己填写 From 7973318f4ca5c6db69af94ac7d36ccc2dbce2e12 Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Sat, 17 May 2025 17:35:00 +0800 Subject: [PATCH 54/57] ruff --- .../observation/chatting_observation.py | 10 +- src/chat/memory_system/Hippocampus.py | 12 +- src/chat/utils/chat_message_builder.py | 14 +- src/common/database/database_model.py | 2 + tests/common/test_message_repository.py | 214 +++++++++--------- tests/test_build_readable_messages.py | 118 +++++----- tests/test_extract_messages.py | 101 ++++----- 7 files changed, 231 insertions(+), 240 deletions(-) diff --git a/src/chat/heart_flow/observation/chatting_observation.py b/src/chat/heart_flow/observation/chatting_observation.py index 415e4b100..9bd10e511 100644 --- a/src/chat/heart_flow/observation/chatting_observation.py +++ b/src/chat/heart_flow/observation/chatting_observation.py @@ -140,22 +140,22 @@ class ChattingObservation(Observation): return None # logger.debug(f"找到的锚定消息:find_msg: {find_msg}") - + # 创建所需的user_info字段 user_info = { "platform": find_msg.get("user_platform", ""), "user_id": find_msg.get("user_id", ""), "user_nickname": find_msg.get("user_nickname", ""), - "user_cardname": find_msg.get("user_cardname", "") + "user_cardname": find_msg.get("user_cardname", ""), } - + # 创建所需的group_info字段,如果是群聊的话 group_info = {} if find_msg.get("chat_info_group_id"): group_info = { "platform": find_msg.get("chat_info_group_platform", ""), "group_id": find_msg.get("chat_info_group_id", ""), - "group_name": find_msg.get("chat_info_group_name", "") + "group_name": find_msg.get("chat_info_group_name", ""), } content_format = "" @@ -196,7 +196,7 @@ class ChattingObservation(Observation): limit=self.max_now_obs_len, limit_mode="latest", ) - + # print(f"new_messages_list: {new_messages_list}") last_obs_time_mark = self.last_observe_time diff --git a/src/chat/memory_system/Hippocampus.py b/src/chat/memory_system/Hippocampus.py index 23a296c8d..1695a3948 100644 --- a/src/chat/memory_system/Hippocampus.py +++ b/src/chat/memory_system/Hippocampus.py @@ -861,9 +861,7 @@ class EntorhinalCortex: # 确保在更新前获取最新的 memorized_times current_memorized_times = message.get("memorized_times", 0) # 使用 Peewee 更新记录 - Messages.update( - memorized_times=current_memorized_times + 1 - ).where( + Messages.update(memorized_times=current_memorized_times + 1).where( Messages.message_id == message["message_id"] ).execute() return messages # 直接返回原始的消息列表 @@ -983,9 +981,7 @@ class EntorhinalCortex: if not node.last_modified: update_data["last_modified"] = current_time - GraphNodes.update( - **update_data - ).where(GraphNodes.concept == concept).execute() + GraphNodes.update(**update_data).where(GraphNodes.concept == concept).execute() logger.info(f"[时间更新] 节点 {concept} 添加缺失的时间字段") # 获取时间信息(如果不存在则使用当前时间) @@ -1014,9 +1010,7 @@ class EntorhinalCortex: if not edge.last_modified: update_data["last_modified"] = current_time - GraphEdges.update( - **update_data - ).where( + GraphEdges.update(**update_data).where( (GraphEdges.source == source) & (GraphEdges.target == target) ).execute() logger.info(f"[时间更新] 边 {source} - {target} 添加缺失的时间字段") diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index f81603e13..d662d8c0a 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -175,15 +175,15 @@ async def _build_readable_messages_internal( # 1 & 2: 获取发送者信息并提取消息组件 for msg in messages: # 检查并修复缺少的user_info字段 - if 'user_info' not in msg: + if "user_info" not in msg: # 创建user_info字段 - msg['user_info'] = { - 'platform': msg.get('user_platform', ''), - 'user_id': msg.get('user_id', ''), - 'user_nickname': msg.get('user_nickname', ''), - 'user_cardname': msg.get('user_cardname', '') + msg["user_info"] = { + "platform": msg.get("user_platform", ""), + "user_id": msg.get("user_id", ""), + "user_nickname": msg.get("user_nickname", ""), + "user_cardname": msg.get("user_cardname", ""), } - + user_info = msg.get("user_info", {}) platform = user_info.get("platform") user_id = user_info.get("user_id") diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py index bf192ca6a..3544a8be0 100644 --- a/src/common/database/database_model.py +++ b/src/common/database/database_model.py @@ -279,6 +279,7 @@ class GraphNodes(BaseModel): """ 用于存储记忆图节点的模型 """ + concept = TextField(unique=True, index=True) # 节点概念 memory_items = TextField() # JSON格式存储的记忆列表 hash = TextField() # 节点哈希值 @@ -293,6 +294,7 @@ class GraphEdges(BaseModel): """ 用于存储记忆图边的模型 """ + source = TextField(index=True) # 源节点 target = TextField(index=True) # 目标节点 strength = IntegerField() # 连接强度 diff --git a/tests/common/test_message_repository.py b/tests/common/test_message_repository.py index 43d629761..798fa16b1 100644 --- a/tests/common/test_message_repository.py +++ b/tests/common/test_message_repository.py @@ -5,7 +5,7 @@ import sys import os # 添加项目根目录到Python路径 -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))) +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) from peewee import SqliteDatabase from src.common.database.database_model import Messages, BaseModel @@ -15,160 +15,158 @@ from src.common.message_repository import find_messages class TestMessageRepository(unittest.TestCase): def setUp(self): # 创建内存中的SQLite数据库用于测试 - self.test_db = SqliteDatabase(':memory:') - + self.test_db = SqliteDatabase(":memory:") + # 覆盖原有数据库连接 BaseModel._meta.database = self.test_db Messages._meta.database = self.test_db - + # 创建表 self.test_db.create_tables([Messages]) - + # 添加测试数据 current_time = datetime.datetime.now().timestamp() self.test_messages = [ { - 'message_id': 'msg1', - 'time': current_time - 3600, # 1小时前 - 'chat_id': '5ed68437e28644da51f314f37df68d18', - 'chat_info_stream_id': 'stream1', - 'chat_info_platform': 'qq', - 'chat_info_user_platform': 'qq', - 'chat_info_user_id': 'user1', - 'chat_info_user_nickname': '用户1', - 'chat_info_user_cardname': '卡片名1', - 'chat_info_group_platform': 'qq', - 'chat_info_group_id': 'group1', - 'chat_info_group_name': '群组1', - 'chat_info_create_time': current_time - 7200, # 2小时前 - 'chat_info_last_active_time': current_time - 1800, # 30分钟前 - 'user_platform': 'qq', - 'user_id': 'user1', - 'user_nickname': '用户1', - 'user_cardname': '卡片名1', - 'processed_plain_text': '你好', - 'detailed_plain_text': '你好', - 'memorized_times': 1 + "message_id": "msg1", + "time": current_time - 3600, # 1小时前 + "chat_id": "5ed68437e28644da51f314f37df68d18", + "chat_info_stream_id": "stream1", + "chat_info_platform": "qq", + "chat_info_user_platform": "qq", + "chat_info_user_id": "user1", + "chat_info_user_nickname": "用户1", + "chat_info_user_cardname": "卡片名1", + "chat_info_group_platform": "qq", + "chat_info_group_id": "group1", + "chat_info_group_name": "群组1", + "chat_info_create_time": current_time - 7200, # 2小时前 + "chat_info_last_active_time": current_time - 1800, # 30分钟前 + "user_platform": "qq", + "user_id": "user1", + "user_nickname": "用户1", + "user_cardname": "卡片名1", + "processed_plain_text": "你好", + "detailed_plain_text": "你好", + "memorized_times": 1, }, { - 'message_id': 'msg2', - 'time': current_time - 1800, # 30分钟前 - 'chat_id': 'chat1', - 'chat_info_stream_id': 'stream1', - 'chat_info_platform': 'qq', - 'chat_info_user_platform': 'qq', - 'chat_info_user_id': 'user1', - 'chat_info_user_nickname': '用户1', - 'chat_info_user_cardname': '卡片名1', - 'chat_info_group_platform': 'qq', - 'chat_info_group_id': 'group1', - 'chat_info_group_name': '群组1', - 'chat_info_create_time': current_time - 7200, - 'chat_info_last_active_time': current_time - 900, # 15分钟前 - 'user_platform': 'qq', - 'user_id': 'user1', - 'user_nickname': '用户1', - 'user_cardname': '卡片名1', - 'processed_plain_text': '世界', - 'detailed_plain_text': '世界', - 'memorized_times': 2 + "message_id": "msg2", + "time": current_time - 1800, # 30分钟前 + "chat_id": "chat1", + "chat_info_stream_id": "stream1", + "chat_info_platform": "qq", + "chat_info_user_platform": "qq", + "chat_info_user_id": "user1", + "chat_info_user_nickname": "用户1", + "chat_info_user_cardname": "卡片名1", + "chat_info_group_platform": "qq", + "chat_info_group_id": "group1", + "chat_info_group_name": "群组1", + "chat_info_create_time": current_time - 7200, + "chat_info_last_active_time": current_time - 900, # 15分钟前 + "user_platform": "qq", + "user_id": "user1", + "user_nickname": "用户1", + "user_cardname": "卡片名1", + "processed_plain_text": "世界", + "detailed_plain_text": "世界", + "memorized_times": 2, }, { - 'message_id': 'msg3', - 'time': current_time - 900, # 15分钟前 - 'chat_id': 'chat2', - 'chat_info_stream_id': 'stream2', - 'chat_info_platform': 'wechat', - 'chat_info_user_platform': 'wechat', - 'chat_info_user_id': 'user2', - 'chat_info_user_nickname': '用户2', - 'chat_info_user_cardname': '卡片名2', - 'chat_info_group_platform': 'wechat', - 'chat_info_group_id': 'group2', - 'chat_info_group_name': '群组2', - 'chat_info_create_time': current_time - 3600, - 'chat_info_last_active_time': current_time - 600, # 10分钟前 - 'user_platform': 'wechat', - 'user_id': 'user2', - 'user_nickname': '用户2', - 'user_cardname': '卡片名2', - 'processed_plain_text': '测试', - 'detailed_plain_text': '测试', - 'memorized_times': 0 - } + "message_id": "msg3", + "time": current_time - 900, # 15分钟前 + "chat_id": "chat2", + "chat_info_stream_id": "stream2", + "chat_info_platform": "wechat", + "chat_info_user_platform": "wechat", + "chat_info_user_id": "user2", + "chat_info_user_nickname": "用户2", + "chat_info_user_cardname": "卡片名2", + "chat_info_group_platform": "wechat", + "chat_info_group_id": "group2", + "chat_info_group_name": "群组2", + "chat_info_create_time": current_time - 3600, + "chat_info_last_active_time": current_time - 600, # 10分钟前 + "user_platform": "wechat", + "user_id": "user2", + "user_nickname": "用户2", + "user_cardname": "卡片名2", + "processed_plain_text": "测试", + "detailed_plain_text": "测试", + "memorized_times": 0, + }, ] - + for msg_data in self.test_messages: Messages.create(**msg_data) - + def tearDown(self): # 关闭测试数据库连接 self.test_db.close() - + def test_find_messages_no_filter(self): """测试不带过滤器的查询""" results = find_messages({}) self.assertEqual(len(results), 3) # 验证结果是否按时间升序排列 - self.assertEqual(results[0]['message_id'], 'msg1') - self.assertEqual(results[1]['message_id'], 'msg2') - self.assertEqual(results[2]['message_id'], 'msg3') - + self.assertEqual(results[0]["message_id"], "msg1") + self.assertEqual(results[1]["message_id"], "msg2") + self.assertEqual(results[2]["message_id"], "msg3") + def test_find_messages_with_filter(self): """测试带过滤器的查询""" - results = find_messages({'chat_id': 'chat1'}) + results = find_messages({"chat_id": "chat1"}) self.assertEqual(len(results), 2) - self.assertEqual(results[0]['message_id'], 'msg1') - self.assertEqual(results[1]['message_id'], 'msg2') - - results = find_messages({'user_id': 'user2'}) + self.assertEqual(results[0]["message_id"], "msg1") + self.assertEqual(results[1]["message_id"], "msg2") + + results = find_messages({"user_id": "user2"}) self.assertEqual(len(results), 1) - self.assertEqual(results[0]['message_id'], 'msg3') - + self.assertEqual(results[0]["message_id"], "msg3") + def test_find_messages_with_operators(self): """测试带操作符的查询""" - results = find_messages({'memorized_times': {'$gt': 0}}) + results = find_messages({"memorized_times": {"$gt": 0}}) self.assertEqual(len(results), 2) - self.assertEqual(results[0]['message_id'], 'msg1') - self.assertEqual(results[1]['message_id'], 'msg2') - - results = find_messages({'memorized_times': {'$gte': 2}}) + self.assertEqual(results[0]["message_id"], "msg1") + self.assertEqual(results[1]["message_id"], "msg2") + + results = find_messages({"memorized_times": {"$gte": 2}}) self.assertEqual(len(results), 1) - self.assertEqual(results[0]['message_id'], 'msg2') - + self.assertEqual(results[0]["message_id"], "msg2") + def test_find_messages_with_sort(self): """测试带排序的查询""" - results = find_messages({}, sort=[('memorized_times', -1)]) + results = find_messages({}, sort=[("memorized_times", -1)]) self.assertEqual(len(results), 3) # 验证结果是否按memorized_times降序排列 - self.assertEqual(results[0]['message_id'], 'msg2') # memorized_times = 2 - self.assertEqual(results[1]['message_id'], 'msg1') # memorized_times = 1 - self.assertEqual(results[2]['message_id'], 'msg3') # memorized_times = 0 - + self.assertEqual(results[0]["message_id"], "msg2") # memorized_times = 2 + self.assertEqual(results[1]["message_id"], "msg1") # memorized_times = 1 + self.assertEqual(results[2]["message_id"], "msg3") # memorized_times = 0 + def test_find_messages_with_limit(self): """测试带限制的查询""" # 默认limit_mode为latest,应返回最新的2条记录 results = find_messages({}, limit=2) self.assertEqual(len(results), 2) - self.assertEqual(results[0]['message_id'], 'msg2') - self.assertEqual(results[1]['message_id'], 'msg3') - + self.assertEqual(results[0]["message_id"], "msg2") + self.assertEqual(results[1]["message_id"], "msg3") + # 使用earliest模式,应返回最早的2条记录 - results = find_messages({}, limit=2, limit_mode='earliest') + results = find_messages({}, limit=2, limit_mode="earliest") self.assertEqual(len(results), 2) - self.assertEqual(results[0]['message_id'], 'msg1') - self.assertEqual(results[1]['message_id'], 'msg2') - + self.assertEqual(results[0]["message_id"], "msg1") + self.assertEqual(results[1]["message_id"], "msg2") + def test_find_messages_with_combined_criteria(self): """测试组合查询条件""" results = find_messages( - {'chat_info_platform': 'qq', 'memorized_times': {'$gt': 0}}, - sort=[('time', 1)], - limit=1 + {"chat_info_platform": "qq", "memorized_times": {"$gt": 0}}, sort=[("time", 1)], limit=1 ) self.assertEqual(len(results), 1) - self.assertEqual(results[0]['message_id'], 'msg2') + self.assertEqual(results[0]["message_id"], "msg2") -if __name__ == '__main__': - unittest.main() \ No newline at end of file +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_build_readable_messages.py b/tests/test_build_readable_messages.py index 76caffb75..71d91a46d 100644 --- a/tests/test_build_readable_messages.py +++ b/tests/test_build_readable_messages.py @@ -9,7 +9,7 @@ import json import copy # 添加项目根目录到Python路径 -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat, build_readable_messages from src.common.logger import get_module_logger @@ -17,13 +17,14 @@ from src.common.logger import get_module_logger # 创建测试日志记录器 logger = get_module_logger("test_readable_msg") + class TestBuildReadableMessages(unittest.TestCase): def setUp(self): # 准备测试数据:从真实数据库获取消息 - self.chat_id = '5ed68437e28644da51f314f37df68d18' + self.chat_id = "5ed68437e28644da51f314f37df68d18" self.current_time = time.time() self.thirty_days_ago = self.current_time - (30 * 24 * 60 * 60) # 30天前的时间戳 - + # 获取最新的10条消息 try: self.messages = get_raw_msg_by_timestamp_with_chat( @@ -31,10 +32,10 @@ class TestBuildReadableMessages(unittest.TestCase): timestamp_start=self.thirty_days_ago, timestamp_end=self.current_time, limit=10, - limit_mode="latest" + limit_mode="latest", ) logger.info(f"已获取 {len(self.messages)} 条测试消息") - + # 打印消息样例 if self.messages: sample_msg = self.messages[0] @@ -44,128 +45,129 @@ class TestBuildReadableMessages(unittest.TestCase): logger.error(f"获取消息失败: {e}") logger.error(traceback.format_exc()) self.messages = [] - + def test_manual_fix_messages(self): """创建一个手动修复版本的消息进行测试""" if not self.messages: self.skipTest("没有测试消息,跳过测试") return - + logger.info("开始手动修复消息...") - + # 创建修复版本的消息列表 fixed_messages = [] - + for msg in self.messages: # 深拷贝以避免修改原始数据 fixed_msg = copy.deepcopy(msg) - + # 构建 user_info 对象 - if 'user_info' not in fixed_msg: + if "user_info" not in fixed_msg: user_info = { - 'platform': fixed_msg.get('user_platform', 'qq'), - 'user_id': fixed_msg.get('user_id', '10000'), - 'user_nickname': fixed_msg.get('user_nickname', '测试用户'), - 'user_cardname': fixed_msg.get('user_cardname', '') + "platform": fixed_msg.get("user_platform", "qq"), + "user_id": fixed_msg.get("user_id", "10000"), + "user_nickname": fixed_msg.get("user_nickname", "测试用户"), + "user_cardname": fixed_msg.get("user_cardname", ""), } - fixed_msg['user_info'] = user_info + fixed_msg["user_info"] = user_info logger.info(f"为消息 {fixed_msg.get('message_id')} 添加了 user_info") - + fixed_messages.append(fixed_msg) - + logger.info(f"已修复 {len(fixed_messages)} 条消息") - + try: # 使用修复后的消息尝试格式化 - formatted_text = asyncio.run(build_readable_messages( - messages=fixed_messages, - replace_bot_name=True, - merge_messages=False, - timestamp_mode="absolute", - read_mark=0.0, - truncate=False - )) - + formatted_text = asyncio.run( + build_readable_messages( + messages=fixed_messages, + replace_bot_name=True, + merge_messages=False, + timestamp_mode="absolute", + read_mark=0.0, + truncate=False, + ) + ) + logger.info("使用修复后的消息格式化完成") logger.info(f"格式化结果长度: {len(formatted_text)}") if formatted_text: logger.info(f"格式化结果预览: {formatted_text[:200]}...") else: logger.warning("格式化结果为空") - + # 断言 self.assertNotEqual(formatted_text, "", "有消息时不应返回空字符串") except Exception as e: logger.error(f"使用修复后的消息格式化失败: {e}") logger.error(traceback.format_exc()) raise - + def test_debug_build_messages_internal(self): """调试_build_readable_messages_internal函数""" if not self.messages: self.skipTest("没有测试消息,跳过测试") return - + logger.info("开始调试内部构建函数...") - + try: # 直接导入内部函数进行测试 from src.chat.utils.chat_message_builder import _build_readable_messages_internal - + # 手动创建一个简单的测试消息列表 test_msg = self.messages[0].copy() # 使用第一条消息作为模板 - + # 检查消息结构 logger.info(f"测试消息keys: {list(test_msg.keys())}") logger.info(f"user_info存在: {'user_info' in test_msg}") - + # 修复缺少的user_info字段 - if 'user_info' not in test_msg: + if "user_info" not in test_msg: logger.warning("消息中缺少user_info字段,添加模拟数据") - test_msg['user_info'] = { - 'platform': test_msg.get('user_platform', 'qq'), - 'user_id': test_msg.get('user_id', '10000'), - 'user_nickname': test_msg.get('user_nickname', '测试用户'), - 'user_cardname': test_msg.get('user_cardname', '') + test_msg["user_info"] = { + "platform": test_msg.get("user_platform", "qq"), + "user_id": test_msg.get("user_id", "10000"), + "user_nickname": test_msg.get("user_nickname", "测试用户"), + "user_cardname": test_msg.get("user_cardname", ""), } logger.info(f"添加的user_info: {test_msg['user_info']}") - + simple_msgs = [test_msg] - + # 运行内部函数 - result_text, result_details = asyncio.run(_build_readable_messages_internal( - simple_msgs, - replace_bot_name=True, - merge_messages=False, - timestamp_mode="absolute", - truncate=False - )) - + result_text, result_details = asyncio.run( + _build_readable_messages_internal( + simple_msgs, replace_bot_name=True, merge_messages=False, timestamp_mode="absolute", truncate=False + ) + ) + logger.info(f"内部函数返回结果: {result_text[:200] if result_text else '空'}") logger.info(f"详情列表长度: {len(result_details)}") - + # 显示处理过程中的变量 if not result_text and len(simple_msgs) > 0: logger.warning("消息处理可能有问题,检查关键步骤") msg = simple_msgs[0] - + # 打印关键变量的值 user_info = msg.get("user_info", {}) platform = user_info.get("platform") user_id = user_info.get("user_id") timestamp = msg.get("time") content = msg.get("processed_plain_text", "") - + logger.warning(f"平台: {platform}, 用户ID: {user_id}, 时间戳: {timestamp}") logger.warning(f"内容: {content[:50]}...") - + # 检查必要信息是否完整 logger.warning(f"必要信息完整性检查: {all([platform, user_id, timestamp is not None])}") - + except Exception as e: logger.error(f"调试内部函数失败: {e}") logger.error(traceback.format_exc()) raise -if __name__ == '__main__': - unittest.main() \ No newline at end of file + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_extract_messages.py b/tests/test_extract_messages.py index d32e644b6..95ddb523f 100644 --- a/tests/test_extract_messages.py +++ b/tests/test_extract_messages.py @@ -5,13 +5,14 @@ import datetime import time # 添加项目根目录到Python路径 -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from src.common.message_repository import find_messages from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat from peewee import SqliteDatabase from src.common.database.database import db # 导入实际的数据库连接 + class TestExtractMessages(unittest.TestCase): def setUp(self): # 这个测试使用真实的数据库,所以不需要创建测试数据 @@ -19,70 +20,64 @@ class TestExtractMessages(unittest.TestCase): def test_extract_latest_messages_direct(self): """测试直接使用message_repository.find_messages函数""" - chat_id = '5ed68437e28644da51f314f37df68d18' - + chat_id = "5ed68437e28644da51f314f37df68d18" + # 提取最新的10条消息 - results = find_messages( - {'chat_id': chat_id}, - limit=10 - ) - + results = find_messages({"chat_id": chat_id}, limit=10) + # 打印结果数量 print(f"\n直接使用find_messages,找到 {len(results)} 条消息") - + # 如果有结果,打印一些信息 if results: print("\n消息时间顺序:") for idx, msg in enumerate(results): - msg_time = datetime.datetime.fromtimestamp(msg['time']).strftime('%Y-%m-%d %H:%M:%S') - print(f"{idx+1}. ID: {msg['message_id']}, 时间: {msg_time}") + msg_time = datetime.datetime.fromtimestamp(msg["time"]).strftime("%Y-%m-%d %H:%M:%S") + print(f"{idx + 1}. ID: {msg['message_id']}, 时间: {msg_time}") print(f" 文本: {msg.get('processed_plain_text', '无文本内容')[:50]}...") - + # 验证结果按时间排序 - times = [msg['time'] for msg in results] + times = [msg["time"] for msg in results] self.assertEqual(times, sorted(times), "消息应该按时间升序排列") else: print(f"未找到chat_id为 {chat_id} 的消息") - - # 最基本的断言,确保测试有效 - self.assertIsInstance(results, list, "结果应该是一个列表") - - def test_extract_latest_messages_via_builder(self): - """使用chat_message_builder中的函数测试从真实数据库提取消息""" - chat_id = '5ed68437e28644da51f314f37df68d18' - - # 设置时间范围为过去30天到现在 - current_time = time.time() - thirty_days_ago = current_time - (30 * 24 * 60 * 60) # 30天前的时间戳 - - # 使用chat_message_builder中的函数 - results = get_raw_msg_by_timestamp_with_chat( - chat_id=chat_id, - timestamp_start=thirty_days_ago, - timestamp_end=current_time, - limit=10, - limit_mode="latest" - ) - - # 打印结果数量 - print(f"\n使用get_raw_msg_by_timestamp_with_chat,找到 {len(results)} 条消息") - - # 如果有结果,打印一些信息 - if results: - print("\n消息时间顺序:") - for idx, msg in enumerate(results): - msg_time = datetime.datetime.fromtimestamp(msg['time']).strftime('%Y-%m-%d %H:%M:%S') - print(f"{idx+1}. ID: {msg['message_id']}, 时间: {msg_time}") - print(f" 文本: {msg.get('processed_plain_text', '无文本内容')[:50]}...") - - # 验证结果按时间排序 - times = [msg['time'] for msg in results] - self.assertEqual(times, sorted(times), "消息应该按时间升序排列") - else: - print(f"未找到chat_id为 {chat_id} 的消息") - + # 最基本的断言,确保测试有效 self.assertIsInstance(results, list, "结果应该是一个列表") -if __name__ == '__main__': - unittest.main() \ No newline at end of file + def test_extract_latest_messages_via_builder(self): + """使用chat_message_builder中的函数测试从真实数据库提取消息""" + chat_id = "5ed68437e28644da51f314f37df68d18" + + # 设置时间范围为过去30天到现在 + current_time = time.time() + thirty_days_ago = current_time - (30 * 24 * 60 * 60) # 30天前的时间戳 + + # 使用chat_message_builder中的函数 + results = get_raw_msg_by_timestamp_with_chat( + chat_id=chat_id, timestamp_start=thirty_days_ago, timestamp_end=current_time, limit=10, limit_mode="latest" + ) + + # 打印结果数量 + print(f"\n使用get_raw_msg_by_timestamp_with_chat,找到 {len(results)} 条消息") + + # 如果有结果,打印一些信息 + if results: + print("\n消息时间顺序:") + for idx, msg in enumerate(results): + msg_time = datetime.datetime.fromtimestamp(msg["time"]).strftime("%Y-%m-%d %H:%M:%S") + print(f"{idx + 1}. ID: {msg['message_id']}, 时间: {msg_time}") + print(f" 文本: {msg.get('processed_plain_text', '无文本内容')[:50]}...") + + # 验证结果按时间排序 + times = [msg["time"] for msg in results] + self.assertEqual(times, sorted(times), "消息应该按时间升序排列") + else: + print(f"未找到chat_id为 {chat_id} 的消息") + + # 最基本的断言,确保测试有效 + self.assertIsInstance(results, list, "结果应该是一个列表") + + +if __name__ == "__main__": + unittest.main() From a1809d347b0f2ce73fdc344a34c6d752b0fe0c44 Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Sun, 18 May 2025 17:01:05 +0800 Subject: [PATCH 55/57] =?UTF-8?q?refactor=EF=BC=9A=E7=A7=BB=E9=99=A4?= =?UTF-8?q?=E4=B8=BB=E5=BF=83=E6=B5=81=E5=86=97=E4=BD=99=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/heart_flow/heartflow.py | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/src/chat/heart_flow/heartflow.py b/src/chat/heart_flow/heartflow.py index 748c8331e..bad0683ce 100644 --- a/src/chat/heart_flow/heartflow.py +++ b/src/chat/heart_flow/heartflow.py @@ -4,10 +4,8 @@ from src.config.config import global_config from src.common.logger_manager import get_logger from typing import Any, Optional from src.tools.tool_use import ToolUser -from src.chat.person_info.relationship_manager import relationship_manager # Module instance from src.chat.heart_flow.mai_state_manager import MaiStateInfo, MaiStateManager from src.chat.heart_flow.subheartflow_manager import SubHeartflowManager -from src.chat.heart_flow.interest_logger import InterestLogger # Import InterestLogger from src.chat.heart_flow.background_tasks import BackgroundTaskManager # Import BackgroundTaskManager logger = get_logger("heartflow") @@ -17,16 +15,10 @@ class Heartflow: """主心流协调器,负责初始化并协调各个子系统: - 状态管理 (MaiState) - 子心流管理 (SubHeartflow) - - 思考过程 (Mind) - - 日志记录 (InterestLogger) - 后台任务 (BackgroundTaskManager) """ def __init__(self): - # 核心状态 - self.current_mind = "什么也没想" # 当前主心流想法 - self.past_mind = [] # 历史想法记录 - # 状态管理相关 self.current_state: MaiStateInfo = MaiStateInfo() # 当前状态信息 self.mai_state_manager: MaiStateManager = MaiStateManager() # 状态决策管理器 @@ -34,24 +26,11 @@ class Heartflow: # 子心流管理 (在初始化时传入 current_state) self.subheartflow_manager: SubHeartflowManager = SubHeartflowManager(self.current_state) - # LLM模型配置 - # TODO: API-Adapter修改标记 - self.llm_model = LLMRequest( - model=global_config.model.heartflow, temperature=0.6, max_tokens=1000, request_type="heart_flow" - ) - - # 外部依赖模块 - self.tool_user_instance = ToolUser() # 工具使用模块 - self.relationship_manager_instance = relationship_manager # 关系管理模块 - - self.interest_logger: InterestLogger = InterestLogger(self.subheartflow_manager, self) # 兴趣日志记录器 - # 后台任务管理器 (整合所有定时任务) self.background_task_manager: BackgroundTaskManager = BackgroundTaskManager( mai_state_info=self.current_state, mai_state_manager=self.mai_state_manager, subheartflow_manager=self.subheartflow_manager, - interest_logger=self.interest_logger, ) async def get_or_create_subheartflow(self, subheartflow_id: Any) -> Optional["SubHeartflow"]: From 0f788c7abae72779c2df5b486305b2873ab128d0 Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Sun, 18 May 2025 17:59:47 +0800 Subject: [PATCH 56/57] requirements.txt fix --- requirements.txt | Bin 824 -> 826 bytes 1 file changed, 0 insertions(+), 0 deletions(-) diff --git a/requirements.txt b/requirements.txt index 9baaf0bd58d46521800144044eeb1a4e2aa9c384..0e60bc192a2a0d16e882738151b07046307b08ff 100644 GIT binary patch delta 13 VcmdnNwu^1UA|^(z$qSjL0{|j71gii5 delta 11 TcmdnRwu5cMBBsepn5F>$9AyNe From 49c2bc854c02871998b39c12e4d2ce25c65d34f8 Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Sun, 18 May 2025 18:15:38 +0800 Subject: [PATCH 57/57] =?UTF-8?q?refactor=EF=BC=9A=E9=87=8D=E6=9E=84?= =?UTF-8?q?=E8=81=8A=E5=A4=A9=E7=8A=B6=E6=80=81=E5=88=87=E6=8D=A2=E6=A8=A1?= =?UTF-8?q?=E5=BC=8F=EF=BC=8C=E7=A7=BB=E9=99=A4=E9=99=90=E9=A2=9D=EF=BC=8C?= =?UTF-8?q?=E7=B2=BE=E7=AE=80=E5=88=87=E6=8D=A2=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/HeartFC_system.md | 2 +- src/chat/focus_chat/heartFC_chat.py | 9 +- .../focus_chat/heartflow_prompt_builder.py | 6 +- src/chat/focus_chat/info/action_info.py | 83 +++ .../info_processors/action_processor.py | 126 +++++ .../info_processors/self_processor.py | 5 +- .../focus_chat/planners/action_manager.py | 13 - .../actions/exit_focus_chat_action.py | 108 ++++ .../planners/actions/no_reply_action.py | 16 - .../planners/actions/plugin_action.py | 2 - .../planners/actions/reply_action.py | 5 - src/chat/focus_chat/planners/planner.py | 59 +- src/chat/heart_flow/background_tasks.py | 73 +-- src/chat/heart_flow/interest_logger.py | 212 ------- src/chat/heart_flow/mai_state_manager.py | 90 +-- .../observation/hfcloop_observation.py | 5 +- src/chat/heart_flow/sub_heartflow.py | 60 +- src/chat/heart_flow/subheartflow_manager.py | 524 ++---------------- src/chat/message_receive/bot.py | 41 -- src/chat/normal_chat/normal_chat.py | 10 +- src/chat/utils/utils.py | 73 ++- src/common/logger.py | 36 ++ src/common/logger_manager.py | 4 + template/bot_config_template.toml | 8 +- 24 files changed, 541 insertions(+), 1029 deletions(-) create mode 100644 src/chat/focus_chat/info/action_info.py create mode 100644 src/chat/focus_chat/info_processors/action_processor.py create mode 100644 src/chat/focus_chat/planners/actions/exit_focus_chat_action.py delete mode 100644 src/chat/heart_flow/interest_logger.py diff --git a/docs/HeartFC_system.md b/docs/HeartFC_system.md index a55f1c973..e48a7b5d7 100644 --- a/docs/HeartFC_system.md +++ b/docs/HeartFC_system.md @@ -149,7 +149,7 @@ c HeartFChatting工作方式 - **状态及含义**: - `ChatState.ABSENT` (不参与/没在看): 初始或停用状态。子心流不观察新信息,不进行思考,也不回复。 - `ChatState.CHAT` (随便看看/水群): 普通聊天模式。激活 `NormalChatInstance`。 - * `ChatState.FOCUSED` (专注/认真水群): 专注聊天模式。激活 `HeartFlowChatInstance`。 + * `ChatState.FOCUSED` (专注/认真聊天): 专注聊天模式。激活 `HeartFlowChatInstance`。 - **选择**: 子心流可以根据外部指令(来自 `SubHeartflowManager`)或内部逻辑(未来的扩展)选择进入 `ABSENT` 状态(不回复不观察),或进入 `CHAT` / `FOCUSED` 中的一种回复模式。 - **状态转换机制** (由 `SubHeartflowManager` 驱动,更细致的说明): - **初始状态**: 新创建的 `SubHeartflow` 默认为 `ABSENT` 状态。 diff --git a/src/chat/focus_chat/heartFC_chat.py b/src/chat/focus_chat/heartFC_chat.py index 0f5371a36..4f17f9bdf 100644 --- a/src/chat/focus_chat/heartFC_chat.py +++ b/src/chat/focus_chat/heartFC_chat.py @@ -68,7 +68,6 @@ class HeartFChatting: self, chat_id: str, observations: list[Observation], - on_consecutive_no_reply_callback: Callable[[], Coroutine[None, None, None]], ): """ HeartFChatting 初始化函数 @@ -76,12 +75,10 @@ class HeartFChatting: 参数: chat_id: 聊天流唯一标识符(如stream_id) observations: 关联的观察列表 - on_consecutive_no_reply_callback: 连续不回复达到阈值时调用的异步回调函数 """ # 基础属性 self.stream_id: str = chat_id # 聊天流ID self.chat_stream: Optional[ChatStream] = None # 关联的聊天流 - self.on_consecutive_no_reply_callback = on_consecutive_no_reply_callback self.log_prefix: str = str(chat_id) # Initial default, will be updated self.hfcloop_observation = HFCloopObservation(observe_id=self.stream_id) self.chatting_observation = observations[0] @@ -165,7 +162,7 @@ class HeartFChatting: 启动 HeartFChatting 的主循环。 注意:调用此方法前必须确保已经成功初始化。 """ - logger.info(f"{self.log_prefix} 开始认真水群(HFC)...") + logger.info(f"{self.log_prefix} 开始认真聊天(HFC)...") await self._start_loop_if_needed() async def _start_loop_if_needed(self): @@ -463,11 +460,7 @@ class HeartFChatting: observations=self.all_observations, expressor=self.expressor, chat_stream=self.chat_stream, - current_cycle=self._current_cycle, log_prefix=self.log_prefix, - on_consecutive_no_reply_callback=self.on_consecutive_no_reply_callback, - # total_no_reply_count=self.total_no_reply_count, - # total_waiting_time=self.total_waiting_time, shutting_down=self._shutting_down, ) diff --git a/src/chat/focus_chat/heartflow_prompt_builder.py b/src/chat/focus_chat/heartflow_prompt_builder.py index d8d2b836f..532ceccd1 100644 --- a/src/chat/focus_chat/heartflow_prompt_builder.py +++ b/src/chat/focus_chat/heartflow_prompt_builder.py @@ -234,7 +234,8 @@ class PromptBuilder: reply_style2=reply_style2_chosen, keywords_reaction_prompt=keywords_reaction_prompt, prompt_ger=prompt_ger, - moderation_prompt=await global_prompt_manager.get_prompt_async("moderation_prompt"), + # moderation_prompt=await global_prompt_manager.get_prompt_async("moderation_prompt"), + moderation_prompt="", ) else: template_name = "reasoning_prompt_private_main" @@ -256,7 +257,8 @@ class PromptBuilder: reply_style2=reply_style2_chosen, keywords_reaction_prompt=keywords_reaction_prompt, prompt_ger=prompt_ger, - moderation_prompt=await global_prompt_manager.get_prompt_async("moderation_prompt"), + # moderation_prompt=await global_prompt_manager.get_prompt_async("moderation_prompt"), + moderation_prompt="", ) # --- End choosing template --- diff --git a/src/chat/focus_chat/info/action_info.py b/src/chat/focus_chat/info/action_info.py new file mode 100644 index 000000000..1bb6b96a6 --- /dev/null +++ b/src/chat/focus_chat/info/action_info.py @@ -0,0 +1,83 @@ +from typing import Dict, Optional, Any, List +from dataclasses import dataclass +from .info_base import InfoBase + + +@dataclass +class ActionInfo(InfoBase): + """动作信息类 + + 用于管理和记录动作的变更信息,包括需要添加或移除的动作。 + 继承自 InfoBase 类,使用字典存储具体数据。 + + Attributes: + type (str): 信息类型标识符,固定为 "action" + + Data Fields: + add_actions (List[str]): 需要添加的动作列表 + remove_actions (List[str]): 需要移除的动作列表 + reason (str): 变更原因说明 + """ + + type: str = "action" + + def get_type(self) -> str: + """获取信息类型""" + return self.type + + def get_data(self) -> Dict[str, Any]: + """获取信息数据""" + return self.data + + def set_action_changes(self, action_changes: Dict[str, List[str]]) -> None: + """设置动作变更信息 + + Args: + action_changes (Dict[str, List[str]]): 包含要增加和删除的动作列表 + { + "add": ["action1", "action2"], + "remove": ["action3"] + } + """ + self.data["add_actions"] = action_changes.get("add", []) + self.data["remove_actions"] = action_changes.get("remove", []) + + def set_reason(self, reason: str) -> None: + """设置变更原因 + + Args: + reason (str): 动作变更的原因说明 + """ + self.data["reason"] = reason + + def get_add_actions(self) -> List[str]: + """获取需要添加的动作列表 + + Returns: + List[str]: 需要添加的动作列表 + """ + return self.data.get("add_actions", []) + + def get_remove_actions(self) -> List[str]: + """获取需要移除的动作列表 + + Returns: + List[str]: 需要移除的动作列表 + """ + return self.data.get("remove_actions", []) + + def get_reason(self) -> Optional[str]: + """获取变更原因 + + Returns: + Optional[str]: 动作变更的原因说明,如果未设置则返回 None + """ + return self.data.get("reason") + + def has_changes(self) -> bool: + """检查是否有动作变更 + + Returns: + bool: 如果有任何动作需要添加或移除则返回True + """ + return bool(self.get_add_actions() or self.get_remove_actions()) \ No newline at end of file diff --git a/src/chat/focus_chat/info_processors/action_processor.py b/src/chat/focus_chat/info_processors/action_processor.py new file mode 100644 index 000000000..a952b38c8 --- /dev/null +++ b/src/chat/focus_chat/info_processors/action_processor.py @@ -0,0 +1,126 @@ +from typing import List, Optional, Any +from src.chat.focus_chat.info.obs_info import ObsInfo +from src.chat.heart_flow.observation.observation import Observation +from src.chat.focus_chat.info.info_base import InfoBase +from src.chat.focus_chat.info.action_info import ActionInfo +from .base_processor import BaseProcessor +from src.common.logger_manager import get_logger +from src.chat.heart_flow.observation.chatting_observation import ChattingObservation +from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation +from src.chat.focus_chat.info.cycle_info import CycleInfo +from datetime import datetime +from typing import Dict +from src.chat.models.utils_model import LLMRequest +from src.config.config import global_config +import random + +logger = get_logger("processor") + + +class ActionProcessor(BaseProcessor): + """动作处理器 + + 用于处理Observation对象,将其转换为ObsInfo对象。 + """ + + log_prefix = "聊天信息处理" + + def __init__(self): + """初始化观察处理器""" + super().__init__() + # TODO: API-Adapter修改标记 + self.model_summary = LLMRequest( + model=global_config.model.observation, temperature=0.7, max_tokens=300, request_type="chat_observation" + ) + + async def process_info( + self, + observations: Optional[List[Observation]] = None, + running_memorys: Optional[List[Dict]] = None, + **kwargs: Any, + ) -> List[InfoBase]: + """处理Observation对象 + + Args: + infos: InfoBase对象列表 + observations: 可选的Observation对象列表 + **kwargs: 其他可选参数 + + Returns: + List[InfoBase]: 处理后的ObsInfo实例列表 + """ + # print(f"observations: {observations}") + processed_infos = [] + + # 处理Observation对象 + if observations: + for obs in observations: + + if isinstance(obs, HFCloopObservation): + + + # 创建动作信息 + action_info = ActionInfo() + action_changes = await self.analyze_loop_actions(obs) + if action_changes["add"] or action_changes["remove"]: + action_info.set_action_changes(action_changes) + # 设置变更原因 + reasons = [] + if action_changes["add"]: + reasons.append(f"添加动作{action_changes['add']}因为检测到大量无回复") + if action_changes["remove"]: + reasons.append(f"移除动作{action_changes['remove']}因为检测到连续回复") + action_info.set_reason(" | ".join(reasons)) + processed_infos.append(action_info) + + return processed_infos + + + async def analyze_loop_actions(self, obs: HFCloopObservation) -> Dict[str, List[str]]: + """分析最近的循环内容并决定动作的增减 + + Returns: + Dict[str, List[str]]: 包含要增加和删除的动作 + { + "add": ["action1", "action2"], + "remove": ["action3"] + } + """ + result = {"add": [], "remove": []} + + # 获取最近10次循环 + recent_cycles = obs.history_loop[-10:] if len(obs.history_loop) > 10 else obs.history_loop + if not recent_cycles: + return result + + # 统计no_reply的数量 + no_reply_count = 0 + reply_sequence = [] # 记录最近的动作序列 + + for cycle in recent_cycles: + action_type = cycle.loop_plan_info["action_result"]["action_type"] + if action_type == "no_reply": + no_reply_count += 1 + reply_sequence.append(action_type == "reply") + + # 检查no_reply比例 + if len(recent_cycles) >= 5 and (no_reply_count / len(recent_cycles)) >= 0.8: + result["add"].append("exit_focus_chat") + + # 获取最近三次的reply状态 + last_three = reply_sequence[-3:] if len(reply_sequence) >= 3 else reply_sequence + + # 根据最近的reply情况决定是否移除reply动作 + if len(last_three) >= 3 and all(last_three): + # 如果最近三次都是reply,直接移除 + result["remove"].append("reply") + elif len(last_three) >= 2 and all(last_three[-2:]): + # 如果最近两次都是reply,40%概率移除 + if random.random() < 0.4: + result["remove"].append("reply") + elif last_three and last_three[-1]: + # 如果最近一次是reply,20%概率移除 + if random.random() < 0.2: + result["remove"].append("reply") + + return result diff --git a/src/chat/focus_chat/info_processors/self_processor.py b/src/chat/focus_chat/info_processors/self_processor.py index 5114e49b6..4e7e8544b 100644 --- a/src/chat/focus_chat/info_processors/self_processor.py +++ b/src/chat/focus_chat/info_processors/self_processor.py @@ -22,7 +22,7 @@ logger = get_logger("processor") def init_prompt(): indentify_prompt = """ 你的名字是{bot_name},你是一个AI,但是你伪装成了一个人类,你的人格是,{prompt_personality}。 -你对外的形象是一只橙色的鱼,头上有绿色的树叶,你用的头像也是这个。 +你的头像形象是一只橙色的鱼,头上有绿色的树叶。 {relation_prompt} {memory_str} @@ -36,6 +36,9 @@ def init_prompt(): 3. 你的自我认同是否有助于你的回答,如果你需要自我相关的信息来帮你参与聊天,请输出,否则请输出十个字以内的简短自我认同 4. 一般情况下不用输出自我认同,只需要输出十几个字的简短自我认同就好,除非有明显需要自我认同的场景 +请回复的平淡一些,简短一些,说中文,不要浮夸,平淡一些。 +请注意不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出内容。 + """ Prompt(indentify_prompt, "indentify_prompt") diff --git a/src/chat/focus_chat/planners/action_manager.py b/src/chat/focus_chat/planners/action_manager.py index 2ee7f349d..60ab0babf 100644 --- a/src/chat/focus_chat/planners/action_manager.py +++ b/src/chat/focus_chat/planners/action_manager.py @@ -137,11 +137,7 @@ class ActionManager: observations: List[Observation], expressor: DefaultExpressor, chat_stream: ChatStream, - current_cycle: CycleDetail, log_prefix: str, - on_consecutive_no_reply_callback: Callable[[], Coroutine[None, None, None]], - # total_no_reply_count: int = 0, - # total_waiting_time: float = 0.0, shutting_down: bool = False, ) -> Optional[BaseAction]: """ @@ -156,11 +152,7 @@ class ActionManager: observations: 观察列表 expressor: 表达器 chat_stream: 聊天流 - current_cycle: 当前循环信息 log_prefix: 日志前缀 - on_consecutive_no_reply_callback: 连续不回复回调 - total_no_reply_count: 连续不回复计数 - total_waiting_time: 累计等待时间 shutting_down: 是否正在关闭 Returns: @@ -179,7 +171,6 @@ class ActionManager: try: # 创建动作实例 instance = handler_class( - action_name=action_name, action_data=action_data, reasoning=reasoning, cycle_timers=cycle_timers, @@ -187,11 +178,7 @@ class ActionManager: observations=observations, expressor=expressor, chat_stream=chat_stream, - current_cycle=current_cycle, log_prefix=log_prefix, - on_consecutive_no_reply_callback=on_consecutive_no_reply_callback, - # total_no_reply_count=total_no_reply_count, - # total_waiting_time=total_waiting_time, shutting_down=shutting_down, ) diff --git a/src/chat/focus_chat/planners/actions/exit_focus_chat_action.py b/src/chat/focus_chat/planners/actions/exit_focus_chat_action.py new file mode 100644 index 000000000..6aeb68ccd --- /dev/null +++ b/src/chat/focus_chat/planners/actions/exit_focus_chat_action.py @@ -0,0 +1,108 @@ +import asyncio +import traceback +from src.common.logger_manager import get_logger +from src.chat.utils.timer_calculator import Timer +from src.chat.focus_chat.planners.actions.base_action import BaseAction, register_action +from typing import Tuple, List, Callable, Coroutine +from src.chat.heart_flow.observation.observation import Observation +from src.chat.heart_flow.observation.chatting_observation import ChattingObservation +from src.chat.heart_flow.sub_heartflow import SubHeartFlow +from src.chat.message_receive.chat_stream import ChatStream +from src.chat.heart_flow.heartflow import heartflow +from src.chat.heart_flow.sub_heartflow import ChatState + +logger = get_logger("action_taken") + + +@register_action +class ExitFocusChatAction(BaseAction): + """退出专注聊天动作处理类 + + 处理决定退出专注聊天的动作。 + 执行后会将所属的sub heartflow转变为normal_chat状态。 + """ + + action_name = "exit_focus_chat" + action_description = "退出专注聊天,转为普通聊天模式" + action_parameters = {} + action_require = [ + "很长时间没有回复,你决定退出专注聊天", + "当前内容不需要持续专注关注,你决定退出专注聊天", + "聊天内容已经完成,你决定退出专注聊天", + ] + default = True + + def __init__( + self, + action_data: dict, + reasoning: str, + cycle_timers: dict, + thinking_id: str, + observations: List[Observation], + log_prefix: str, + chat_stream: ChatStream, + shutting_down: bool = False, + **kwargs, + ): + """初始化退出专注聊天动作处理器 + + Args: + action_data: 动作数据 + reasoning: 执行该动作的理由 + cycle_timers: 计时器字典 + thinking_id: 思考ID + observations: 观察列表 + log_prefix: 日志前缀 + shutting_down: 是否正在关闭 + """ + super().__init__(action_data, reasoning, cycle_timers, thinking_id) + self.observations = observations + self.log_prefix = log_prefix + self._shutting_down = shutting_down + self.chat_id = chat_stream.stream_id + + + + async def handle_action(self) -> Tuple[bool, str]: + """ + 处理退出专注聊天的情况 + + 工作流程: + 1. 将sub heartflow转换为normal_chat状态 + 2. 等待新消息、超时或关闭信号 + 3. 根据等待结果更新连续不回复计数 + 4. 如果达到阈值,触发回调 + + Returns: + Tuple[bool, str]: (是否执行成功, 状态转换消息) + """ + try: + # 转换状态 + status_message = "" + self.sub_heartflow = await heartflow.get_or_create_subheartflow(self.chat_id) + if self.sub_heartflow: + try: + # 转换为normal_chat状态 + await self.sub_heartflow.change_chat_state(ChatState.NORMAL_CHAT) + status_message = "已成功切换到普通聊天模式" + logger.info(f"{self.log_prefix} {status_message}") + except Exception as e: + error_msg = f"切换到普通聊天模式失败: {str(e)}" + logger.error(f"{self.log_prefix} {error_msg}") + return False, error_msg + else: + warning_msg = "未找到有效的sub heartflow实例,无法切换状态" + logger.warning(f"{self.log_prefix} {warning_msg}") + return False, warning_msg + + + return True, status_message + + except asyncio.CancelledError: + logger.info(f"{self.log_prefix} 处理 'exit_focus_chat' 时等待被中断 (CancelledError)") + raise + except Exception as e: + error_msg = f"处理 'exit_focus_chat' 时发生错误: {str(e)}" + logger.error(f"{self.log_prefix} {error_msg}") + logger.error(traceback.format_exc()) + return False, error_msg \ No newline at end of file diff --git a/src/chat/focus_chat/planners/actions/no_reply_action.py b/src/chat/focus_chat/planners/actions/no_reply_action.py index c6852fbe1..6e31d5abb 100644 --- a/src/chat/focus_chat/planners/actions/no_reply_action.py +++ b/src/chat/focus_chat/planners/actions/no_reply_action.py @@ -6,14 +6,12 @@ from src.chat.focus_chat.planners.actions.base_action import BaseAction, registe from typing import Tuple, List, Callable, Coroutine from src.chat.heart_flow.observation.observation import Observation from src.chat.heart_flow.observation.chatting_observation import ChattingObservation -from src.chat.focus_chat.heartFC_Cycleinfo import CycleDetail from src.chat.focus_chat.hfc_utils import parse_thinking_id_to_timestamp logger = get_logger("action_taken") # 常量定义 WAITING_TIME_THRESHOLD = 300 # 等待新消息时间阈值,单位秒 -CONSECUTIVE_NO_REPLY_THRESHOLD = 3 # 连续不回复的阈值 @register_action @@ -40,11 +38,7 @@ class NoReplyAction(BaseAction): cycle_timers: dict, thinking_id: str, observations: List[Observation], - on_consecutive_no_reply_callback: Callable[[], Coroutine[None, None, None]], - current_cycle: CycleDetail, log_prefix: str, - # total_no_reply_count: int = 0, - # total_waiting_time: float = 0.0, shutting_down: bool = False, **kwargs, ): @@ -57,20 +51,12 @@ class NoReplyAction(BaseAction): cycle_timers: 计时器字典 thinking_id: 思考ID observations: 观察列表 - on_consecutive_no_reply_callback: 连续不回复达到阈值时调用的回调函数 - current_cycle: 当前循环信息 log_prefix: 日志前缀 - total_no_reply_count: 连续不回复计数 - total_waiting_time: 累计等待时间 shutting_down: 是否正在关闭 """ super().__init__(action_data, reasoning, cycle_timers, thinking_id) self.observations = observations - self.on_consecutive_no_reply_callback = on_consecutive_no_reply_callback - self._current_cycle = current_cycle self.log_prefix = log_prefix - # self.total_no_reply_count = total_no_reply_count - # self.total_waiting_time = total_waiting_time self._shutting_down = shutting_down async def handle_action(self) -> Tuple[bool, str]: @@ -93,8 +79,6 @@ class NoReplyAction(BaseAction): with Timer("等待新消息", self.cycle_timers): # 等待新消息、超时或关闭信号,并获取结果 await self._wait_for_new_message(observation, self.thinking_id, self.log_prefix) - # 从计时器获取实际等待时间 - _current_waiting = self.cycle_timers.get("等待新消息", 0.0) return True, "" # 不回复动作没有回复文本 diff --git a/src/chat/focus_chat/planners/actions/plugin_action.py b/src/chat/focus_chat/planners/actions/plugin_action.py index 5e8ddd998..94754d021 100644 --- a/src/chat/focus_chat/planners/actions/plugin_action.py +++ b/src/chat/focus_chat/planners/actions/plugin_action.py @@ -30,8 +30,6 @@ class PluginAction(BaseAction): self._services["expressor"] = kwargs["expressor"] if "chat_stream" in kwargs: self._services["chat_stream"] = kwargs["chat_stream"] - if "current_cycle" in kwargs: - self._services["current_cycle"] = kwargs["current_cycle"] self.log_prefix = kwargs.get("log_prefix", "") diff --git a/src/chat/focus_chat/planners/actions/reply_action.py b/src/chat/focus_chat/planners/actions/reply_action.py index 07e35b458..45a4340d5 100644 --- a/src/chat/focus_chat/planners/actions/reply_action.py +++ b/src/chat/focus_chat/planners/actions/reply_action.py @@ -6,7 +6,6 @@ from typing import Tuple, List from src.chat.heart_flow.observation.observation import Observation from src.chat.focus_chat.expressors.default_expressor import DefaultExpressor from src.chat.message_receive.chat_stream import ChatStream -from src.chat.focus_chat.heartFC_Cycleinfo import CycleDetail from src.chat.heart_flow.observation.chatting_observation import ChattingObservation from src.chat.focus_chat.hfc_utils import create_empty_anchor_message @@ -41,7 +40,6 @@ class ReplyAction(BaseAction): def __init__( self, - action_name: str, action_data: dict, reasoning: str, cycle_timers: dict, @@ -49,7 +47,6 @@ class ReplyAction(BaseAction): observations: List[Observation], expressor: DefaultExpressor, chat_stream: ChatStream, - current_cycle: CycleDetail, log_prefix: str, **kwargs, ): @@ -64,14 +61,12 @@ class ReplyAction(BaseAction): observations: 观察列表 expressor: 表达器 chat_stream: 聊天流 - current_cycle: 当前循环信息 log_prefix: 日志前缀 """ super().__init__(action_data, reasoning, cycle_timers, thinking_id) self.observations = observations self.expressor = expressor self.chat_stream = chat_stream - self._current_cycle = current_cycle self.log_prefix = log_prefix async def handle_action(self) -> Tuple[bool, str]: diff --git a/src/chat/focus_chat/planners/planner.py b/src/chat/focus_chat/planners/planner.py index 116419ee1..ca35d3096 100644 --- a/src/chat/focus_chat/planners/planner.py +++ b/src/chat/focus_chat/planners/planner.py @@ -8,12 +8,12 @@ from src.chat.focus_chat.info.info_base import InfoBase from src.chat.focus_chat.info.obs_info import ObsInfo from src.chat.focus_chat.info.cycle_info import CycleInfo from src.chat.focus_chat.info.mind_info import MindInfo +from src.chat.focus_chat.info.action_info import ActionInfo from src.chat.focus_chat.info.structured_info import StructuredInfo from src.common.logger_manager import get_logger from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.individuality.individuality import Individuality from src.chat.focus_chat.planners.action_manager import ActionManager -from src.chat.focus_chat.planners.action_manager import ActionInfo logger = get_logger("planner") @@ -87,34 +87,68 @@ class ActionPlanner: action = "no_reply" # 默认动作 reasoning = "规划器初始化默认" + action_data = {} try: # 获取观察信息 extra_info: list[str] = [] + + # 首先处理动作变更 + for info in all_plan_info: + if isinstance(info, ActionInfo) and info.has_changes(): + add_actions = info.get_add_actions() + remove_actions = info.get_remove_actions() + reason = info.get_reason() + + # 处理动作的增加 + for action_name in add_actions: + if action_name in self.action_manager.get_registered_actions(): + self.action_manager.add_action_to_using(action_name) + logger.debug(f"{self.log_prefix}添加动作: {action_name}, 原因: {reason}") + + # 处理动作的移除 + for action_name in remove_actions: + self.action_manager.remove_action_from_using(action_name) + logger.debug(f"{self.log_prefix}移除动作: {action_name}, 原因: {reason}") + + # 如果当前选择的动作被移除了,更新为no_reply + if action in remove_actions: + action = "no_reply" + reasoning = f"之前选择的动作{action}已被移除,原因: {reason}" + + # 继续处理其他信息 for info in all_plan_info: if isinstance(info, ObsInfo): - # logger.debug(f"{self.log_prefix} 观察信息: {info}") observed_messages = info.get_talking_message() observed_messages_str = info.get_talking_message_str_truncate() chat_type = info.get_chat_type() - if chat_type == "group": - is_group_chat = True - else: - is_group_chat = False + is_group_chat = (chat_type == "group") elif isinstance(info, MindInfo): - # logger.debug(f"{self.log_prefix} 思维信息: {info}") current_mind = info.get_current_mind() elif isinstance(info, CycleInfo): - # logger.debug(f"{self.log_prefix} 循环信息: {info}") cycle_info = info.get_observe_info() elif isinstance(info, StructuredInfo): - # logger.debug(f"{self.log_prefix} 结构化信息: {info}") _structured_info = info.get_data() - else: - logger.debug(f"{self.log_prefix} 其他信息: {info}") + elif not isinstance(info, ActionInfo): # 跳过已处理的ActionInfo extra_info.append(info.get_processed_info()) + # 获取当前可用的动作 current_available_actions = self.action_manager.get_using_actions() + + # 如果没有可用动作,直接返回no_reply + if not current_available_actions: + logger.warning(f"{self.log_prefix}没有可用的动作,将使用no_reply") + action = "no_reply" + reasoning = "没有可用的动作" + return { + "action_result": { + "action_type": action, + "action_data": action_data, + "reasoning": reasoning + }, + "current_mind": current_mind, + "observed_messages": observed_messages + } # --- 构建提示词 (调用修改后的 PromptBuilder 方法) --- prompt = await self.build_planner_prompt( @@ -181,7 +215,7 @@ class ActionPlanner: except Exception as outer_e: logger.error(f"{self.log_prefix}Planner 处理过程中发生意外错误,规划失败,将执行 no_reply: {outer_e}") traceback.print_exc() - action = "no_reply" # 发生未知错误,标记为 error 动作 + action = "no_reply" reasoning = f"Planner 内部处理错误: {outer_e}" logger.debug( @@ -202,7 +236,6 @@ class ActionPlanner: "observed_messages": observed_messages, } - # 返回结果字典 return plan_result async def build_planner_prompt( diff --git a/src/chat/heart_flow/background_tasks.py b/src/chat/heart_flow/background_tasks.py index d9fa1c9d3..28b248bdc 100644 --- a/src/chat/heart_flow/background_tasks.py +++ b/src/chat/heart_flow/background_tasks.py @@ -1,13 +1,9 @@ import asyncio import traceback from typing import Optional, Coroutine, Callable, Any, List - from src.common.logger_manager import get_logger - -# Need manager types for dependency injection from src.chat.heart_flow.mai_state_manager import MaiStateManager, MaiStateInfo from src.chat.heart_flow.subheartflow_manager import SubHeartflowManager -from src.chat.heart_flow.interest_logger import InterestLogger logger = get_logger("background_tasks") @@ -62,23 +58,18 @@ class BackgroundTaskManager: mai_state_info: MaiStateInfo, # Needs current state info mai_state_manager: MaiStateManager, subheartflow_manager: SubHeartflowManager, - interest_logger: InterestLogger, ): self.mai_state_info = mai_state_info self.mai_state_manager = mai_state_manager self.subheartflow_manager = subheartflow_manager - self.interest_logger = interest_logger # Task references self._state_update_task: Optional[asyncio.Task] = None self._cleanup_task: Optional[asyncio.Task] = None - self._logging_task: Optional[asyncio.Task] = None - self._normal_chat_timeout_check_task: Optional[asyncio.Task] = None self._hf_judge_state_update_task: Optional[asyncio.Task] = None self._into_focus_task: Optional[asyncio.Task] = None self._private_chat_activation_task: Optional[asyncio.Task] = None # 新增私聊激活任务引用 self._tasks: List[Optional[asyncio.Task]] = [] # Keep track of all tasks - self._detect_command_from_gui_task: Optional[asyncio.Task] = None # 新增GUI命令检测任务引用 async def start_tasks(self): """启动所有后台任务 @@ -97,30 +88,12 @@ class BackgroundTaskManager: f"聊天状态更新任务已启动 间隔:{STATE_UPDATE_INTERVAL_SECONDS}s", "_state_update_task", ), - ( - lambda: self._run_normal_chat_timeout_check_cycle(NORMAL_CHAT_TIMEOUT_CHECK_INTERVAL_SECONDS), - "debug", - f"聊天超时检查任务已启动 间隔:{NORMAL_CHAT_TIMEOUT_CHECK_INTERVAL_SECONDS}s", - "_normal_chat_timeout_check_task", - ), - ( - lambda: self._run_absent_into_chat(HF_JUDGE_STATE_UPDATE_INTERVAL_SECONDS), - "debug", - f"状态评估任务已启动 间隔:{HF_JUDGE_STATE_UPDATE_INTERVAL_SECONDS}s", - "_hf_judge_state_update_task", - ), ( self._run_cleanup_cycle, "info", f"清理任务已启动 间隔:{CLEANUP_INTERVAL_SECONDS}s", "_cleanup_task", ), - ( - self._run_logging_cycle, - "info", - f"日志任务已启动 间隔:{LOG_INTERVAL_SECONDS}s", - "_logging_task", - ), # 新增兴趣评估任务配置 ( self._run_into_focus_cycle, @@ -136,13 +109,6 @@ class BackgroundTaskManager: f"私聊激活检查任务已启动 间隔:{PRIVATE_CHAT_ACTIVATION_CHECK_INTERVAL_SECONDS}s", "_private_chat_activation_task", ), - # 新增GUI命令检测任务配置 - # ( - # lambda: self._run_detect_command_from_gui_cycle(3), - # "debug", - # f"GUI命令检测任务已启动 间隔:{3}s", - # "_detect_command_from_gui_task", - # ), ] # 统一启动所有任务 @@ -207,7 +173,6 @@ class BackgroundTaskManager: if state_changed: current_state = self.mai_state_info.get_current_state() - await self.subheartflow_manager.enforce_subheartflow_limits() # 状态转换处理 @@ -218,15 +183,6 @@ class BackgroundTaskManager: logger.info("检测到离线,停用所有子心流") await self.subheartflow_manager.deactivate_all_subflows() - async def _perform_absent_into_chat(self): - """调用llm检测是否转换ABSENT-CHAT状态""" - logger.debug("[状态评估任务] 开始基于LLM评估子心流状态...") - await self.subheartflow_manager.sbhf_absent_into_chat() - - async def _normal_chat_timeout_check_work(self): - """检查处于CHAT状态的子心流是否因长时间未发言而超时,并将其转为ABSENT""" - logger.debug("[聊天超时检查] 开始检查处于CHAT状态的子心流...") - await self.subheartflow_manager.sbhf_chat_into_absent() async def _perform_cleanup_work(self): """执行子心流清理任务 @@ -253,42 +209,23 @@ class BackgroundTaskManager: # 记录最终清理结果 logger.info(f"[清理任务] 清理完成, 共停止 {stopped_count}/{len(flows_to_stop)} 个子心流") - async def _perform_logging_work(self): - """执行一轮状态日志记录。""" - await self.interest_logger.log_all_states() # --- 新增兴趣评估工作函数 --- async def _perform_into_focus_work(self): """执行一轮子心流兴趣评估与提升检查。""" # 直接调用 subheartflow_manager 的方法,并传递当前状态信息 await self.subheartflow_manager.sbhf_absent_into_focus() - - # --- 结束新增 --- - - # --- 结束新增 --- - - # --- Specific Task Runners --- # + async def _run_state_update_cycle(self, interval: int): await _run_periodic_loop(task_name="State Update", interval=interval, task_func=self._perform_state_update_work) - async def _run_absent_into_chat(self, interval: int): - await _run_periodic_loop(task_name="Into Chat", interval=interval, task_func=self._perform_absent_into_chat) - async def _run_normal_chat_timeout_check_cycle(self, interval: int): - await _run_periodic_loop( - task_name="Normal Chat Timeout Check", interval=interval, task_func=self._normal_chat_timeout_check_work - ) async def _run_cleanup_cycle(self): await _run_periodic_loop( task_name="Subflow Cleanup", interval=CLEANUP_INTERVAL_SECONDS, task_func=self._perform_cleanup_work ) - async def _run_logging_cycle(self): - await _run_periodic_loop( - task_name="State Logging", interval=LOG_INTERVAL_SECONDS, task_func=self._perform_logging_work - ) - # --- 新增兴趣评估任务运行器 --- async def _run_into_focus_cycle(self): await _run_periodic_loop( @@ -304,11 +241,3 @@ class BackgroundTaskManager: interval=interval, task_func=self.subheartflow_manager.sbhf_absent_private_into_focus, ) - - # # 有api之后删除 - # async def _run_detect_command_from_gui_cycle(self, interval: int): - # await _run_periodic_loop( - # task_name="Detect Command from GUI", - # interval=interval, - # task_func=self.subheartflow_manager.detect_command_from_gui, - # ) diff --git a/src/chat/heart_flow/interest_logger.py b/src/chat/heart_flow/interest_logger.py deleted file mode 100644 index b33f449db..000000000 --- a/src/chat/heart_flow/interest_logger.py +++ /dev/null @@ -1,212 +0,0 @@ -import asyncio -import time -import json -import os -import traceback -from typing import TYPE_CHECKING, Dict, List - -from src.common.logger_manager import get_logger - -# Need chat_manager to get stream names -from src.chat.message_receive.chat_stream import chat_manager - -if TYPE_CHECKING: - from src.chat.heart_flow.subheartflow_manager import SubHeartflowManager - from src.chat.heart_flow.sub_heartflow import SubHeartflow - from src.chat.heart_flow.heartflow import Heartflow # 导入 Heartflow 类型 - - -logger = get_logger("interest") - -# Consider moving log directory/filename constants here -LOG_DIRECTORY = "logs/interest" -HISTORY_LOG_FILENAME = "interest_history.log" - - -def _ensure_log_directory(): - """确保日志目录存在。""" - os.makedirs(LOG_DIRECTORY, exist_ok=True) - logger.info(f"已确保日志目录 '{LOG_DIRECTORY}' 存在") - - -def _clear_and_create_log_file(): - """清除日志文件并创建新的日志文件。""" - if os.path.exists(os.path.join(LOG_DIRECTORY, HISTORY_LOG_FILENAME)): - os.remove(os.path.join(LOG_DIRECTORY, HISTORY_LOG_FILENAME)) - with open(os.path.join(LOG_DIRECTORY, HISTORY_LOG_FILENAME), "w", encoding="utf-8") as f: - f.write("") - - -class InterestLogger: - """负责定期记录主心流和所有子心流的状态到日志文件。""" - - def __init__(self, subheartflow_manager: "SubHeartflowManager", heartflow: "Heartflow"): - """ - 初始化 InterestLogger。 - - Args: - subheartflow_manager: 子心流管理器实例。 - heartflow: 主心流实例,用于获取主心流状态。 - """ - self.subheartflow_manager = subheartflow_manager - self.heartflow = heartflow # 存储 Heartflow 实例 - self._history_log_file_path = os.path.join(LOG_DIRECTORY, HISTORY_LOG_FILENAME) - _ensure_log_directory() - _clear_and_create_log_file() - - async def get_all_subflow_states(self) -> Dict[str, Dict]: - """并发获取所有活跃子心流的当前完整状态。""" - all_flows: List["SubHeartflow"] = self.subheartflow_manager.get_all_subheartflows() - tasks = [] - results = {} - - if not all_flows: - # logger.debug("未找到任何子心流状态") - return results - - for subheartflow in all_flows: - if await self.subheartflow_manager.get_or_create_subheartflow(subheartflow.subheartflow_id): - tasks.append( - asyncio.create_task(subheartflow.get_full_state(), name=f"get_state_{subheartflow.subheartflow_id}") - ) - else: - logger.warning(f"子心流 {subheartflow.subheartflow_id} 在创建任务前已消失") - - if tasks: - done, pending = await asyncio.wait(tasks, timeout=5.0) - - if pending: - logger.warning(f"获取子心流状态超时,有 {len(pending)} 个任务未完成") - for task in pending: - task.cancel() - - for task in done: - stream_id_str = task.get_name().split("get_state_")[-1] - stream_id = stream_id_str - - if task.cancelled(): - logger.warning(f"获取子心流 {stream_id} 状态的任务已取消(超时)", exc_info=False) - elif task.exception(): - exc = task.exception() - logger.warning(f"获取子心流 {stream_id} 状态出错: {exc}") - else: - result = task.result() - results[stream_id] = result - - logger.trace(f"成功获取 {len(results)} 个子心流的完整状态") - return results - - async def log_all_states(self): - """获取主心流状态和所有子心流的完整状态并写入日志文件。""" - try: - current_timestamp = time.time() - - # main_mind = self.heartflow.current_mind - # 获取 Mai 状态名称 - mai_state_name = self.heartflow.current_state.get_current_state().name - - all_subflow_states = await self.get_all_subflow_states() - - log_entry_base = { - "timestamp": round(current_timestamp, 2), - # "main_mind": main_mind, - "mai_state": mai_state_name, - "subflow_count": len(all_subflow_states), - "subflows": [], - } - - if not all_subflow_states: - # logger.debug("没有获取到任何子心流状态,仅记录主心流状态") - with open(self._history_log_file_path, "a", encoding="utf-8") as f: - f.write(json.dumps(log_entry_base, ensure_ascii=False) + "\n") - return - - subflow_details = [] - items_snapshot = list(all_subflow_states.items()) - for stream_id, state in items_snapshot: - group_name = stream_id - try: - chat_stream = chat_manager.get_stream(stream_id) - if chat_stream: - if chat_stream.group_info: - group_name = chat_stream.group_info.group_name - elif chat_stream.user_info: - group_name = f"私聊_{chat_stream.user_info.user_nickname}" - except Exception as e: - logger.trace(f"无法获取 stream_id {stream_id} 的群组名: {e}") - - interest_state = state.get("interest_state", {}) - - subflow_entry = { - "stream_id": stream_id, - "group_name": group_name, - "sub_mind": state.get("current_mind", "未知"), - "sub_chat_state": state.get("chat_state", "未知"), - "interest_level": interest_state.get("interest_level", 0.0), - "start_hfc_probability": interest_state.get("start_hfc_probability", 0.0), - # "is_above_threshold": interest_state.get("is_above_threshold", False), - } - subflow_details.append(subflow_entry) - - log_entry_base["subflows"] = subflow_details - - with open(self._history_log_file_path, "a", encoding="utf-8") as f: - f.write(json.dumps(log_entry_base, ensure_ascii=False) + "\n") - - except IOError as e: - logger.error(f"写入状态日志到 {self._history_log_file_path} 出错: {e}") - except Exception as e: - logger.error(f"记录状态时发生意外错误: {e}") - logger.error(traceback.format_exc()) - - async def api_get_all_states(self): - """获取主心流和所有子心流的状态。""" - try: - current_timestamp = time.time() - - # main_mind = self.heartflow.current_mind - # 获取 Mai 状态名称 - mai_state_name = self.heartflow.current_state.get_current_state().name - - all_subflow_states = await self.get_all_subflow_states() - - log_entry_base = { - "timestamp": round(current_timestamp, 2), - # "main_mind": main_mind, - "mai_state": mai_state_name, - "subflow_count": len(all_subflow_states), - "subflows": [], - } - - subflow_details = [] - items_snapshot = list(all_subflow_states.items()) - for stream_id, state in items_snapshot: - group_name = stream_id - try: - chat_stream = chat_manager.get_stream(stream_id) - if chat_stream: - if chat_stream.group_info: - group_name = chat_stream.group_info.group_name - elif chat_stream.user_info: - group_name = f"私聊_{chat_stream.user_info.user_nickname}" - except Exception as e: - logger.trace(f"无法获取 stream_id {stream_id} 的群组名: {e}") - - interest_state = state.get("interest_state", {}) - - subflow_entry = { - "stream_id": stream_id, - "group_name": group_name, - "sub_mind": state.get("current_mind", "未知"), - "sub_chat_state": state.get("chat_state", "未知"), - "interest_level": interest_state.get("interest_level", 0.0), - "start_hfc_probability": interest_state.get("start_hfc_probability", 0.0), - # "is_above_threshold": interest_state.get("is_above_threshold", False), - } - subflow_details.append(subflow_entry) - - log_entry_base["subflows"] = subflow_details - return subflow_details - except Exception as e: - logger.error(f"记录状态时发生意外错误: {e}") - logger.error(traceback.format_exc()) diff --git a/src/chat/heart_flow/mai_state_manager.py b/src/chat/heart_flow/mai_state_manager.py index 017656ad2..c5e272796 100644 --- a/src/chat/heart_flow/mai_state_manager.py +++ b/src/chat/heart_flow/mai_state_manager.py @@ -13,67 +13,24 @@ logger = get_logger("mai_state") # The line `enable_unlimited_hfc_chat = False` is setting a configuration parameter that controls # whether a specific debugging feature is enabled or not. When `enable_unlimited_hfc_chat` is set to # `False`, it means that the debugging feature for unlimited focused chatting is disabled. -enable_unlimited_hfc_chat = True # 调试用:无限专注聊天 -# enable_unlimited_hfc_chat = False +# enable_unlimited_hfc_chat = True # 调试用:无限专注聊天 +enable_unlimited_hfc_chat = False prevent_offline_state = True -# 目前默认不启用OFFLINE状态 - -MAX_NORMAL_CHAT_NUM_PEEKING = int(global_config.chat.base_normal_chat_num / 2) -MAX_NORMAL_CHAT_NUM_NORMAL = global_config.chat.base_normal_chat_num -MAX_NORMAL_CHAT_NUM_FOCUSED = global_config.chat.base_normal_chat_num + 1 - -# 不同状态下专注聊天的最大消息数 -MAX_FOCUSED_CHAT_NUM_PEEKING = int(global_config.chat.base_focused_chat_num / 2) -MAX_FOCUSED_CHAT_NUM_NORMAL = global_config.chat.base_focused_chat_num -MAX_FOCUSED_CHAT_NUM_FOCUSED = global_config.chat.base_focused_chat_num + 2 - -# -- 状态定义 -- +# 目前默认不启用OFFLINE状 class MaiState(enum.Enum): """ 聊天状态: OFFLINE: 不在线:回复概率极低,不会进行任何聊天 - PEEKING: 看一眼手机:回复概率较低,会进行一些普通聊天 NORMAL_CHAT: 正常看手机:回复概率较高,会进行一些普通聊天和少量的专注聊天 FOCUSED_CHAT: 专注聊天:回复概率极高,会进行专注聊天和少量的普通聊天 """ OFFLINE = "不在线" - PEEKING = "看一眼手机" NORMAL_CHAT = "正常看手机" FOCUSED_CHAT = "专心看手机" - def get_normal_chat_max_num(self): - # 调试用 - if enable_unlimited_hfc_chat: - return 1000 - - if self == MaiState.OFFLINE: - return 0 - elif self == MaiState.PEEKING: - return MAX_NORMAL_CHAT_NUM_PEEKING - elif self == MaiState.NORMAL_CHAT: - return MAX_NORMAL_CHAT_NUM_NORMAL - elif self == MaiState.FOCUSED_CHAT: - return MAX_NORMAL_CHAT_NUM_FOCUSED - return None - - def get_focused_chat_max_num(self): - # 调试用 - if enable_unlimited_hfc_chat: - return 1000 - - if self == MaiState.OFFLINE: - return 0 - elif self == MaiState.PEEKING: - return MAX_FOCUSED_CHAT_NUM_PEEKING - elif self == MaiState.NORMAL_CHAT: - return MAX_FOCUSED_CHAT_NUM_NORMAL - elif self == MaiState.FOCUSED_CHAT: - return MAX_FOCUSED_CHAT_NUM_FOCUSED - return None - class MaiStateInfo: def __init__(self): @@ -143,34 +100,18 @@ class MaiStateManager: _time_since_last_min_check = current_time - current_state_info.last_min_check_time next_state: Optional[MaiState] = None - # 辅助函数:根据 prevent_offline_state 标志调整目标状态 def _resolve_offline(candidate_state: MaiState) -> MaiState: - # 现在不再切换到OFFLINE,直接返回当前状态 if candidate_state == MaiState.OFFLINE: return current_status return candidate_state if current_status == MaiState.OFFLINE: logger.info("当前[离线],没看手机,思考要不要上线看看......") - elif current_status == MaiState.PEEKING: - logger.info("当前[看一眼手机],思考要不要继续聊下去......") elif current_status == MaiState.NORMAL_CHAT: logger.info("当前在[正常看手机]思考要不要继续聊下去......") elif current_status == MaiState.FOCUSED_CHAT: logger.info("当前在[专心看手机]思考要不要继续聊下去......") - # 1. 移除每分钟概率切换到OFFLINE的逻辑 - # if time_since_last_min_check >= 60: - # if current_status != MaiState.OFFLINE: - # if random.random() < 0.03: # 3% 概率切换到 OFFLINE - # potential_next = MaiState.OFFLINE - # resolved_next = _resolve_offline(potential_next) - # logger.debug(f"概率触发下线,resolve 为 {resolved_next.value}") - # # 只有当解析后的状态与当前状态不同时才设置 next_state - # if resolved_next != current_status: - # next_state = resolved_next - - # 2. 状态持续时间规则 (只有在规则1没有触发状态改变时才检查) if next_state is None: time_limit_exceeded = False choices_list = [] @@ -178,44 +119,33 @@ class MaiStateManager: rule_id = "" if current_status == MaiState.OFFLINE: - # OFFLINE 状态不再自动切换,直接返回 None return None - elif current_status == MaiState.PEEKING: - if time_in_current_status >= 600: # PEEKING 最多持续 600 秒 - time_limit_exceeded = True - rule_id = "2.2 (From PEEKING)" - weights = [50, 50] - choices_list = [MaiState.NORMAL_CHAT, MaiState.FOCUSED_CHAT] elif current_status == MaiState.NORMAL_CHAT: if time_in_current_status >= 300: # NORMAL_CHAT 最多持续 300 秒 time_limit_exceeded = True rule_id = "2.3 (From NORMAL_CHAT)" - weights = [50, 50] - choices_list = [MaiState.PEEKING, MaiState.FOCUSED_CHAT] + weights = [100] + choices_list = [MaiState.FOCUSED_CHAT] elif current_status == MaiState.FOCUSED_CHAT: if time_in_current_status >= 600: # FOCUSED_CHAT 最多持续 600 秒 time_limit_exceeded = True rule_id = "2.4 (From FOCUSED_CHAT)" - weights = [50, 50] - choices_list = [MaiState.NORMAL_CHAT, MaiState.PEEKING] + weights = [100] + choices_list = [MaiState.NORMAL_CHAT] if time_limit_exceeded: next_state_candidate = random.choices(choices_list, weights=weights, k=1)[0] resolved_candidate = _resolve_offline(next_state_candidate) logger.debug( - f"规则{rule_id}:时间到,随机选择 {next_state_candidate.value},resolve 为 {resolved_candidate.value}" + f"规则{rule_id}:时间到,切换到 {next_state_candidate.value},resolve 为 {resolved_candidate.value}" ) - next_state = resolved_candidate # 直接使用解析后的状态 + next_state = resolved_candidate - # 注意:enable_unlimited_hfc_chat 优先级高于 prevent_offline_state - # 如果触发了这个,它会覆盖上面规则2设置的 next_state if enable_unlimited_hfc_chat: logger.debug("调试用:开挂了,强制切换到专注聊天") next_state = MaiState.FOCUSED_CHAT - # --- 最终决策 --- # - # 如果决定了下一个状态,且这个状态与当前状态不同,则返回下一个状态 if next_state is not None and next_state != current_status: return next_state else: - return None # 没有状态转换发生或无需重置计时器 + return None diff --git a/src/chat/heart_flow/observation/hfcloop_observation.py b/src/chat/heart_flow/observation/hfcloop_observation.py index 82c9c879a..d712b83be 100644 --- a/src/chat/heart_flow/observation/hfcloop_observation.py +++ b/src/chat/heart_flow/observation/hfcloop_observation.py @@ -17,7 +17,9 @@ class HFCloopObservation: self.observe_id = observe_id self.last_observe_time = datetime.now().timestamp() # 初始化为当前时间 self.history_loop: List[CycleDetail] = [] - self.action_manager = ActionManager() + self.action_manager: ActionManager = None + + self.all_actions = {} def get_observe_info(self): return self.observe_info @@ -27,6 +29,7 @@ class HFCloopObservation: def set_action_manager(self, action_manager: ActionManager): self.action_manager = action_manager + self.all_actions = self.action_manager.get_registered_actions() async def observe(self): recent_active_cycles: List[CycleDetail] = [] diff --git a/src/chat/heart_flow/sub_heartflow.py b/src/chat/heart_flow/sub_heartflow.py index 157c1c957..c440f8cfd 100644 --- a/src/chat/heart_flow/sub_heartflow.py +++ b/src/chat/heart_flow/sub_heartflow.py @@ -89,6 +89,14 @@ class SubHeartflow: await self.interest_chatting.initialize() logger.debug(f"{self.log_prefix} InterestChatting 实例已初始化。") + # 创建并初始化 normal_chat_instance + chat_stream = chat_manager.get_stream(self.chat_id) + if chat_stream: + self.normal_chat_instance = NormalChat(chat_stream=chat_stream,interest_dict=self.get_interest_dict()) + await self.normal_chat_instance.initialize() + await self.normal_chat_instance.start_chat() + logger.info(f"{self.log_prefix} NormalChat 实例已创建并启动。") + def update_last_chat_state_time(self): self.chat_state_last_time = time.time() - self.chat_state_changed_time @@ -181,8 +189,7 @@ class SubHeartflow: # 创建 HeartFChatting 实例,并传递 从构造函数传入的 回调函数 self.heart_fc_instance = HeartFChatting( chat_id=self.subheartflow_id, - observations=self.observations, # 传递所有观察者 - on_consecutive_no_reply_callback=self.hfc_no_reply_callback, # <-- Use stored callback + observations=self.observations, ) # 初始化并启动 HeartFChatting @@ -200,55 +207,41 @@ class SubHeartflow: self.heart_fc_instance = None # 创建或初始化异常,清理实例 return False - async def change_chat_state(self, new_state: "ChatState"): - """更新sub_heartflow的聊天状态,并管理 HeartFChatting 和 NormalChat 实例及任务""" + async def change_chat_state(self, new_state: ChatState) -> None: + """ + 改变聊天状态。 + 如果转换到CHAT或FOCUSED状态时超过限制,会保持当前状态。 + """ current_state = self.chat_state.chat_status + state_changed = False + log_prefix = f"[{self.log_prefix}]" - if current_state == new_state: - return - - log_prefix = self.log_prefix - state_changed = False # 标记状态是否实际发生改变 - - # --- 状态转换逻辑 --- if new_state == ChatState.CHAT: - # 移除限额检查逻辑 - logger.debug(f"{log_prefix} 准备进入或保持 聊天 状态") - if current_state == ChatState.FOCUSED: - if await self._start_normal_chat(rewind=False): - # logger.info(f"{log_prefix} 成功进入或保持 NormalChat 状态。") - state_changed = True - else: - logger.error(f"{log_prefix} 从FOCUSED状态启动 NormalChat 失败,无法进入 CHAT 状态。") - # 考虑是否需要回滚状态或采取其他措施 - return # 启动失败,不改变状态 + logger.debug(f"{log_prefix} 准备进入或保持 普通聊天 状态") + if await self._start_normal_chat(): + logger.debug(f"{log_prefix} 成功进入或保持 NormalChat 状态。") + state_changed = True else: - if await self._start_normal_chat(rewind=True): - # logger.info(f"{log_prefix} 成功进入或保持 NormalChat 状态。") - state_changed = True - else: - logger.error(f"{log_prefix} 从ABSENT状态启动 NormalChat 失败,无法进入 CHAT 状态。") - # 考虑是否需要回滚状态或采取其他措施 - return # 启动失败,不改变状态 + logger.error(f"{log_prefix} 启动 NormalChat 失败,无法进入 CHAT 状态。") + # 启动失败时,保持当前状态 + return elif new_state == ChatState.FOCUSED: - # 移除限额检查逻辑 logger.debug(f"{log_prefix} 准备进入或保持 专注聊天 状态") if await self._start_heart_fc_chat(): logger.debug(f"{log_prefix} 成功进入或保持 HeartFChatting 状态。") state_changed = True else: logger.error(f"{log_prefix} 启动 HeartFChatting 失败,无法进入 FOCUSED 状态。") - # 启动失败,状态回滚到之前的状态或ABSENT?这里保持不改变 - return # 启动失败,不改变状态 + # 启动失败时,保持当前状态 + return elif new_state == ChatState.ABSENT: logger.info(f"{log_prefix} 进入 ABSENT 状态,停止所有聊天活动...") self.clear_interest_dict() - await self._stop_normal_chat() await self._stop_heart_fc_chat() - state_changed = True # 总是可以成功转换到 ABSENT + state_changed = True # --- 更新状态和最后活动时间 --- if state_changed: @@ -263,7 +256,6 @@ class SubHeartflow: self.chat_state_last_time = 0 self.chat_state_changed_time = time.time() else: - # 如果因为某些原因(如启动失败)没有成功改变状态,记录一下 logger.debug( f"{log_prefix} 尝试将状态从 {current_state.value} 变为 {new_state.value},但未成功或未执行更改。" ) diff --git a/src/chat/heart_flow/subheartflow_manager.py b/src/chat/heart_flow/subheartflow_manager.py index bf4ddf7e1..22bab6a40 100644 --- a/src/chat/heart_flow/subheartflow_manager.py +++ b/src/chat/heart_flow/subheartflow_manager.py @@ -1,26 +1,14 @@ import asyncio import time import random -from typing import Dict, Any, Optional, List, Tuple -import json # 导入 json 模块 -import functools # <-- 新增导入 - -# 导入日志模块 +from typing import Dict, Any, Optional, List +import functools from src.common.logger_manager import get_logger - -# 导入聊天流管理模块 from src.chat.message_receive.chat_stream import chat_manager - -# 导入心流相关类 from src.chat.heart_flow.sub_heartflow import SubHeartflow, ChatState from src.chat.heart_flow.mai_state_manager import MaiStateInfo from src.chat.heart_flow.observation.chatting_observation import ChattingObservation - -# 导入LLM请求工具 -from src.chat.models.utils_model import LLMRequest from src.config.config import global_config -from src.individuality.individuality import Individuality -import traceback # 初始化日志记录器 @@ -74,15 +62,6 @@ class SubHeartflowManager: self._lock = asyncio.Lock() # 用于保护 self.subheartflows 的访问 self.mai_state_info: MaiStateInfo = mai_state_info # 存储传入的 MaiStateInfo 实例 - # 为 LLM 状态评估创建一个 LLMRequest 实例 - # 使用与 Heartflow 相同的模型和参数 - # TODO: API-Adapter修改标记 - self.llm_state_evaluator = LLMRequest( - model=global_config.model.heartflow, # 与 Heartflow 一致 - temperature=0.6, # 与 Heartflow 一致 - max_tokens=1000, # 与 Heartflow 一致 (虽然可能不需要这么多) - request_type="subheartflow_state_eval", # 保留特定的请求类型 - ) async def force_change_state(self, subflow_id: Any, target_state: ChatState) -> bool: """强制改变指定子心流的状态""" @@ -156,10 +135,6 @@ class SubHeartflowManager: logger.error(f"创建子心流 {subheartflow_id} 失败: {e}", exc_info=True) return None - # --- 新增:内部方法,用于尝试将单个子心流设置为 ABSENT --- - - # --- 结束新增 --- - async def sleep_subheartflow(self, subheartflow_id: Any, reason: str) -> bool: """停止指定的子心流并将其状态设置为 ABSENT""" log_prefix = "[子心流管理]" @@ -190,54 +165,6 @@ class SubHeartflowManager: return flows_to_stop - async def enforce_subheartflow_limits(self): - """根据主状态限制停止超额子心流(优先停不活跃的)""" - # 使用 self.mai_state_info 获取当前状态和限制 - current_mai_state = self.mai_state_info.get_current_state() - normal_limit = current_mai_state.get_normal_chat_max_num() - focused_limit = current_mai_state.get_focused_chat_max_num() - logger.debug(f"[限制] 状态:{current_mai_state.value}, 普通限:{normal_limit}, 专注限:{focused_limit}") - - # 分类统计当前子心流 - normal_flows = [] - focused_flows = [] - for flow_id, flow in list(self.subheartflows.items()): - if flow.chat_state.chat_status == ChatState.CHAT: - normal_flows.append((flow_id, getattr(flow, "last_active_time", 0))) - elif flow.chat_state.chat_status == ChatState.FOCUSED: - focused_flows.append((flow_id, getattr(flow, "last_active_time", 0))) - - logger.debug(f"[限制] 当前数量 - 普通:{len(normal_flows)}, 专注:{len(focused_flows)}") - stopped = 0 - - # 处理普通聊天超额 - if len(normal_flows) > normal_limit: - excess = len(normal_flows) - normal_limit - logger.info(f"[限制] 普通聊天超额({len(normal_flows)}>{normal_limit}), 停止{excess}个") - normal_flows.sort(key=lambda x: x[1]) - for flow_id, _ in normal_flows[:excess]: - if await self.sleep_subheartflow(flow_id, f"普通聊天超额(限{normal_limit})"): - stopped += 1 - - # 处理专注聊天超额(需重新统计) - focused_flows = [ - (fid, t) - for fid, f in list(self.subheartflows.items()) - if (t := getattr(f, "last_active_time", 0)) and f.chat_state.chat_status == ChatState.FOCUSED - ] - if len(focused_flows) > focused_limit: - excess = len(focused_flows) - focused_limit - logger.info(f"[限制] 专注聊天超额({len(focused_flows)}>{focused_limit}), 停止{excess}个") - focused_flows.sort(key=lambda x: x[1]) - for flow_id, _ in focused_flows[:excess]: - if await self.sleep_subheartflow(flow_id, f"专注聊天超额(限{focused_limit})"): - stopped += 1 - - if stopped: - logger.info(f"[限制] 已停止{stopped}个子心流, 剩余:{len(self.subheartflows)}") - else: - logger.debug(f"[限制] 无需停止, 当前总数:{len(self.subheartflows)}") - async def deactivate_all_subflows(self): """将所有子心流的状态更改为 ABSENT (例如主状态变为OFFLINE时调用)""" log_prefix = "[停用]" @@ -273,27 +200,14 @@ class SubHeartflowManager: ) async def sbhf_absent_into_focus(self): - """评估子心流兴趣度,满足条件且未达上限则提升到FOCUSED状态(基于start_hfc_probability)""" + """评估子心流兴趣度,满足条件则提升到FOCUSED状态(基于start_hfc_probability)""" try: current_state = self.mai_state_info.get_current_state() - focused_limit = current_state.get_focused_chat_max_num() - # --- 新增:检查是否允许进入 FOCUS 模式 --- # + # 检查是否允许进入 FOCUS 模式 if not global_config.chat.allow_focus_mode: if int(time.time()) % 60 == 0: # 每60秒输出一次日志避免刷屏 logger.trace("未开启 FOCUSED 状态 (allow_focus_mode=False)") - return # 如果不允许,直接返回 - # --- 结束新增 --- - - logger.info(f"当前状态 ({current_state.value}) 可以在{focused_limit}个群 专注聊天") - - if focused_limit <= 0: - # logger.debug(f"{log_prefix} 当前状态 ({current_state.value}) 不允许 FOCUSED 子心流") - return - - current_focused_count = self.count_subflows_by_state(ChatState.FOCUSED) - if current_focused_count >= focused_limit: - logger.debug(f"已达专注上限 ({current_focused_count}/{focused_limit})") return for sub_hf in list(self.subheartflows.values()): @@ -321,11 +235,6 @@ class SubHeartflowManager: if random.random() >= sub_hf.interest_chatting.start_hfc_probability: continue - # 再次检查是否达到上限 - if current_focused_count >= focused_limit: - logger.debug(f"{stream_name} 已达专注上限") - break - # 获取最新状态并执行提升 current_subflow = self.subheartflows.get(flow_id) if not current_subflow: @@ -338,283 +247,57 @@ class SubHeartflowManager: # 执行状态提升 await current_subflow.change_chat_state(ChatState.FOCUSED) - # 验证提升结果 - if ( - final_subflow := self.subheartflows.get(flow_id) - ) and final_subflow.chat_state.chat_status == ChatState.FOCUSED: - current_focused_count += 1 except Exception as e: logger.error(f"启动HFC 兴趣评估失败: {e}", exc_info=True) - async def sbhf_absent_into_chat(self): + + async def sbhf_focus_into_absent_or_chat(self, subflow_id: Any): """ - 随机选一个 ABSENT 状态的 *群聊* 子心流,评估是否应转换为 CHAT 状态。 - 每次调用最多转换一个。 - 私聊会被忽略。 - """ - current_mai_state = self.mai_state_info.get_current_state() - chat_limit = current_mai_state.get_normal_chat_max_num() - - async with self._lock: - # 1. 筛选出所有 ABSENT 状态的 *群聊* 子心流 - absent_group_subflows = [ - hf - for hf in self.subheartflows.values() - if hf.chat_state.chat_status == ChatState.ABSENT and hf.is_group_chat - ] - - if not absent_group_subflows: - # logger.debug("没有摸鱼的群聊子心流可以评估。") # 日志太频繁 - return # 没有目标,直接返回 - - # 2. 随机选一个幸运儿 - sub_hf_to_evaluate = random.choice(absent_group_subflows) - flow_id = sub_hf_to_evaluate.subheartflow_id - stream_name = chat_manager.get_stream_name(flow_id) or flow_id - log_prefix = f"[{stream_name}]" - - # 3. 检查 CHAT 上限 - current_chat_count = self.count_subflows_by_state_nolock(ChatState.CHAT) - if current_chat_count >= chat_limit: - logger.info(f"{log_prefix} 想看看能不能聊,但是聊天太多了, ({current_chat_count}/{chat_limit}) 满了。") - return # 满了,这次就算了 - - # --- 获取 FOCUSED 计数 --- - current_focused_count = self.count_subflows_by_state_nolock(ChatState.FOCUSED) - focused_limit = current_mai_state.get_focused_chat_max_num() - - # --- 新增:获取聊天和专注群名 --- - chatting_group_names = [] - focused_group_names = [] - for flow_id, hf in self.subheartflows.items(): - stream_name = chat_manager.get_stream_name(flow_id) or str(flow_id) # 保证有名字 - if hf.chat_state.chat_status == ChatState.CHAT: - chatting_group_names.append(stream_name) - elif hf.chat_state.chat_status == ChatState.FOCUSED: - focused_group_names.append(stream_name) - # --- 结束新增 --- - - # --- 获取观察信息和构建 Prompt --- - first_observation = sub_hf_to_evaluate.observations[0] # 喵~第一个观察者肯定存在的说 - await first_observation.observe() - current_chat_log = first_observation.talking_message_str or "当前没啥聊天内容。" - _observation_summary = f"在[{stream_name}]这个群中,你最近看群友聊了这些:\n{current_chat_log}" - - _mai_state_description = f"你当前状态: {current_mai_state.value}。" - individuality = Individuality.get_instance() - personality_prompt = individuality.get_prompt(x_person=2, level=2) - prompt_personality = f"你正在扮演名为{individuality.name}的人类,{personality_prompt}" - - # --- 修改:在 prompt 中加入当前聊天计数和群名信息 (条件显示) --- - chat_status_lines = [] - if chatting_group_names: - chat_status_lines.append( - f"正在这些群闲聊 ({current_chat_count}/{chat_limit}): {', '.join(chatting_group_names)}" - ) - if focused_group_names: - chat_status_lines.append( - f"正在这些群专注的聊天 ({current_focused_count}/{focused_limit}): {', '.join(focused_group_names)}" - ) - - chat_status_prompt = "当前没有在任何群聊中。" # 默认消息喵~ - if chat_status_lines: - chat_status_prompt = "当前聊天情况,你已经参与了下面这几个群的聊天:\n" + "\n".join( - chat_status_lines - ) # 拼接状态信息 - - prompt = ( - f"{prompt_personality}\n" - f"{chat_status_prompt}\n" # <-- 喵!用了新的状态信息~ - f"你当前尚未加入 [{stream_name}] 群聊天。\n" - f"{_observation_summary}\n---\n" - f"基于以上信息,你想不想开始在这个群闲聊?\n" - f"请说明理由,并以 JSON 格式回答,包含 'decision' (布尔值) 和 'reason' (字符串)。\n" - f'例如:{{"decision": true, "reason": "看起来挺热闹的,插个话"}}\n' - f'例如:{{"decision": false, "reason": "已经聊了好多,休息一下"}}\n' - f"请只输出有效的 JSON 对象。" - ) - # --- 结束修改 --- - - # --- 4. LLM 评估是否想聊 --- - yao_kai_shi_liao_ma, reason = await self._llm_evaluate_state_transition(prompt) - - if reason: - if yao_kai_shi_liao_ma: - logger.info(f"{log_prefix} 打算开始聊,原因是: {reason}") - else: - logger.info(f"{log_prefix} 不打算聊,原因是: {reason}") - else: - logger.info(f"{log_prefix} 结果: {yao_kai_shi_liao_ma}") - - if yao_kai_shi_liao_ma is None: - logger.debug(f"{log_prefix} 问AI想不想聊失败了,这次算了。") - return # 评估失败,结束 - - if not yao_kai_shi_liao_ma: - # logger.info(f"{log_prefix} 现在不想聊这个群。") - return # 不想聊,结束 - - # --- 5. AI想聊,再次检查额度并尝试转换 --- - # 再次检查以防万一 - current_chat_count_before_change = self.count_subflows_by_state_nolock(ChatState.CHAT) - if current_chat_count_before_change < chat_limit: - logger.info( - f"{log_prefix} 想聊,而且还有精力 ({current_chat_count_before_change}/{chat_limit}),这就去聊!" - ) - await sub_hf_to_evaluate.change_chat_state(ChatState.CHAT) - # 确认转换成功 - if sub_hf_to_evaluate.chat_state.chat_status == ChatState.CHAT: - logger.debug(f"{log_prefix} 成功进入聊天状态!本次评估圆满结束。") - else: - logger.warning( - f"{log_prefix} 奇怪,尝试进入聊天状态失败了。当前状态: {sub_hf_to_evaluate.chat_state.chat_status.value}" - ) - else: - logger.warning( - f"{log_prefix} AI说想聊,但是刚问完就没空位了 ({current_chat_count_before_change}/{chat_limit})。真不巧,下次再说吧。" - ) - # 无论转换成功与否,本次评估都结束了 - - # 锁在这里自动释放 - - # --- 新增:单独检查 CHAT 状态超时的任务 --- - async def sbhf_chat_into_absent(self): - """定期检查处于 CHAT 状态的子心流是否因长时间未发言而超时,并将其转为 ABSENT。""" - log_prefix_task = "[聊天超时检查]" - transitioned_to_absent = 0 - checked_count = 0 - - async with self._lock: - subflows_snapshot = list(self.subheartflows.values()) - checked_count = len(subflows_snapshot) - - if not subflows_snapshot: - return - - for sub_hf in subflows_snapshot: - # 只检查 CHAT 状态的子心流 - if sub_hf.chat_state.chat_status != ChatState.CHAT: - continue - - flow_id = sub_hf.subheartflow_id - stream_name = chat_manager.get_stream_name(flow_id) or flow_id - log_prefix = f"[{stream_name}]({log_prefix_task})" - - should_deactivate = False - reason = "" - - try: - last_bot_dong_zuo_time = sub_hf.get_normal_chat_last_speak_time() - - if last_bot_dong_zuo_time > 0: - current_time = time.time() - time_since_last_bb = current_time - last_bot_dong_zuo_time - minutes_since_last_bb = time_since_last_bb / 60 - - # 60分钟强制退出 - if minutes_since_last_bb >= 60: - should_deactivate = True - reason = "超过60分钟未发言,强制退出" - else: - # 根据时间区间确定退出概率 - exit_probability = 0 - if minutes_since_last_bb < 5: - exit_probability = 0.01 # 1% - elif minutes_since_last_bb < 15: - exit_probability = 0.02 # 2% - elif minutes_since_last_bb < 30: - exit_probability = 0.04 # 4% - else: - exit_probability = 0.08 # 8% - - # 随机判断是否退出 - if random.random() < exit_probability: - should_deactivate = True - reason = f"已{minutes_since_last_bb:.1f}分钟未发言,触发{exit_probability * 100:.0f}%退出概率" - - except AttributeError: - logger.error( - f"{log_prefix} 无法获取 Bot 最后 BB 时间,请确保 SubHeartflow 相关实现正确。跳过超时检查。" - ) - except Exception as e: - logger.error(f"{log_prefix} 检查 Bot 超时状态时出错: {e}", exc_info=True) - - # 执行状态转换(如果超时) - if should_deactivate: - logger.debug(f"{log_prefix} 因超时 ({reason}),尝试转换为 ABSENT 状态。") - await sub_hf.change_chat_state(ChatState.ABSENT) - # 再次检查确保状态已改变 - if sub_hf.chat_state.chat_status == ChatState.ABSENT: - transitioned_to_absent += 1 - logger.info(f"{log_prefix} 不看了。") - else: - logger.warning(f"{log_prefix} 尝试因超时转换为 ABSENT 失败。") - - if transitioned_to_absent > 0: - logger.debug( - f"{log_prefix_task} 完成,共检查 {checked_count} 个子心流,{transitioned_to_absent} 个因超时转为 ABSENT。" - ) - - # --- 结束新增 --- - - async def _llm_evaluate_state_transition(self, prompt: str) -> Tuple[Optional[bool], Optional[str]]: - """ - 使用 LLM 评估是否应进行状态转换,期望 LLM 返回 JSON 格式。 + 接收来自 HeartFChatting 的请求,将特定子心流的状态转换为 CHAT。 + 通常在连续多次 "no_reply" 后被调用。 + 对于私聊和群聊,都转换为 CHAT。 Args: - prompt: 提供给 LLM 的提示信息,要求返回 {"decision": true/false}。 - - Returns: - Optional[bool]: 如果成功解析 LLM 的 JSON 响应并提取了 'decision' 键的值,则返回该布尔值。 - 如果 LLM 调用失败、返回无效 JSON 或 JSON 中缺少 'decision' 键或其值不是布尔型,则返回 None。 + subflow_id: 需要转换状态的子心流 ID。 """ - log_prefix = "[LLM状态评估]" - try: - # --- 真实的 LLM 调用 --- - response_text, _ = await self.llm_state_evaluator.generate_response_async(prompt) - # logger.debug(f"{log_prefix} 使用模型 {self.llm_state_evaluator.model_name} 评估") - logger.debug(f"{log_prefix} 原始输入: {prompt}") - logger.debug(f"{log_prefix} 原始评估结果: {response_text}") + async with self._lock: + subflow = self.subheartflows.get(subflow_id) + if not subflow: + logger.warning(f"[状态转换请求] 尝试转换不存在的子心流 {subflow_id} 到 CHAT") + return - # --- 解析 JSON 响应 --- - try: - # 尝试去除可能的Markdown代码块标记 - cleaned_response = response_text.strip().strip("`").strip() - if cleaned_response.startswith("json"): - cleaned_response = cleaned_response[4:].strip() + stream_name = chat_manager.get_stream_name(subflow_id) or subflow_id + current_state = subflow.chat_state.chat_status - data = json.loads(cleaned_response) - decision = data.get("decision") # 使用 .get() 避免 KeyError - reason = data.get("reason") + if current_state == ChatState.FOCUSED: + target_state = ChatState.CHAT + log_reason = "转为CHAT" - if isinstance(decision, bool): - logger.debug(f"{log_prefix} LLM评估结果 (来自JSON): {'建议转换' if decision else '建议不转换'}") - - return decision, reason - else: - logger.warning( - f"{log_prefix} LLM 返回的 JSON 中 'decision' 键的值不是布尔型: {decision}。响应: {response_text}" + logger.info( + f"[状态转换请求] 接收到请求,将 {stream_name} (当前: {current_state.value}) 尝试转换为 {target_state.value} ({log_reason})" + ) + try: + # 从HFC到CHAT时,清空兴趣字典 + subflow.clear_interest_dict() + await subflow.change_chat_state(target_state) + final_state = subflow.chat_state.chat_status + if final_state == target_state: + logger.debug(f"[状态转换请求] {stream_name} 状态已成功转换为 {final_state.value}") + else: + logger.warning( + f"[状态转换请求] 尝试将 {stream_name} 转换为 {target_state.value} 后,状态实际为 {final_state.value}" + ) + except Exception as e: + logger.error( + f"[状态转换请求] 转换 {stream_name} 到 {target_state.value} 时出错: {e}", exc_info=True ) - return None, None # 值类型不正确 - - except json.JSONDecodeError as json_err: - logger.warning(f"{log_prefix} LLM 返回的响应不是有效的 JSON: {json_err}。响应: {response_text}") - # 尝试在非JSON响应中查找关键词作为后备方案 (可选) - if "true" in response_text.lower(): - logger.debug(f"{log_prefix} 在非JSON响应中找到 'true',解释为建议转换") - return True, None - if "false" in response_text.lower(): - logger.debug(f"{log_prefix} 在非JSON响应中找到 'false',解释为建议不转换") - return False, None - return None, None # JSON 解析失败,也未找到关键词 - except Exception as parse_err: # 捕获其他可能的解析错误 - logger.warning(f"{log_prefix} 解析 LLM JSON 响应时发生意外错误: {parse_err}。响应: {response_text}") - return None, None - - except Exception as e: - logger.error(f"{log_prefix} 调用 LLM 或处理其响应时出错: {e}", exc_info=True) - traceback.print_exc() - return None, None # LLM 调用或处理失败 + elif current_state == ChatState.ABSENT: + logger.debug(f"[状态转换请求] {stream_name} 处于 ABSENT 状态,尝试转为 CHAT") + await subflow.change_chat_state(ChatState.CHAT) + else: + logger.debug( + f"[状态转换请求] {stream_name} 当前状态为 {current_state.value},无需转换" + ) def count_subflows_by_state(self, state: ChatState) -> int: """统计指定状态的子心流数量""" @@ -637,23 +320,6 @@ class SubHeartflowManager: count += 1 return count - def get_active_subflow_minds(self) -> List[str]: - """获取所有活跃(非ABSENT)子心流的当前想法""" - minds = [] - for subheartflow in self.subheartflows.values(): - # 检查子心流是否活跃(非ABSENT状态) - if subheartflow.chat_state.chat_status != ChatState.ABSENT: - minds.append(subheartflow.sub_mind.current_mind) - return minds - - def update_main_mind_in_subflows(self, main_mind: str): - """更新所有子心流的主心流想法""" - updated_count = sum( - 1 - for _, subheartflow in list(self.subheartflows.items()) - if subheartflow.subheartflow_id in self.subheartflows - ) - logger.debug(f"[子心流管理器] 更新了{updated_count}个子心流的主想法") async def delete_subflow(self, subheartflow_id: Any): """删除指定的子心流。""" @@ -670,91 +336,13 @@ class SubHeartflowManager: else: logger.warning(f"尝试删除不存在的 SubHeartflow: {subheartflow_id}") - # --- 新增:处理 HFC 无回复回调的专用方法 --- # + async def _handle_hfc_no_reply(self, subheartflow_id: Any): """处理来自 HeartFChatting 的连续无回复信号 (通过 partial 绑定 ID)""" - # 注意:这里不需要再获取锁,因为 sbhf_focus_into_absent 内部会处理锁 + # 注意:这里不需要再获取锁,因为 sbhf_focus_into_absent_or_chat 内部会处理锁 logger.debug(f"[管理器 HFC 处理器] 接收到来自 {subheartflow_id} 的 HFC 无回复信号") await self.sbhf_focus_into_absent_or_chat(subheartflow_id) - # --- 结束新增 --- # - - # --- 新增:处理来自 HeartFChatting 的状态转换请求 --- # - async def sbhf_focus_into_absent_or_chat(self, subflow_id: Any): - """ - 接收来自 HeartFChatting 的请求,将特定子心流的状态转换为 ABSENT 或 CHAT。 - 通常在连续多次 "no_reply" 后被调用。 - 对于私聊,总是转换为 ABSENT。 - 对于群聊,随机决定转换为 ABSENT 或 CHAT (如果 CHAT 未达上限)。 - - Args: - subflow_id: 需要转换状态的子心流 ID。 - """ - async with self._lock: - subflow = self.subheartflows.get(subflow_id) - if not subflow: - logger.warning(f"[状态转换请求] 尝试转换不存在的子心流 {subflow_id} 到 ABSENT/CHAT") - return - - stream_name = chat_manager.get_stream_name(subflow_id) or subflow_id - current_state = subflow.chat_state.chat_status - - if current_state == ChatState.FOCUSED: - target_state = ChatState.ABSENT # Default target - log_reason = "默认转换 (私聊或群聊)" - - # --- Modify logic based on chat type --- # - if subflow.is_group_chat: - # Group chat: Decide between ABSENT or CHAT - if random.random() < 0.5: # 50% chance to try CHAT - current_mai_state = self.mai_state_info.get_current_state() - chat_limit = current_mai_state.get_normal_chat_max_num() - current_chat_count = self.count_subflows_by_state_nolock(ChatState.CHAT) - - if current_chat_count < chat_limit: - target_state = ChatState.CHAT - log_reason = f"群聊随机选择 CHAT (当前 {current_chat_count}/{chat_limit})" - else: - target_state = ChatState.ABSENT # Fallback to ABSENT if CHAT limit reached - log_reason = ( - f"群聊随机选择 CHAT 但已达上限 ({current_chat_count}/{chat_limit}),转为 ABSENT" - ) - else: # 50% chance to go directly to ABSENT - target_state = ChatState.ABSENT - log_reason = "群聊随机选择 ABSENT" - else: - # Private chat: Always go to ABSENT - target_state = ChatState.ABSENT - log_reason = "私聊退出 FOCUSED,转为 ABSENT" - # --- End modification --- # - - logger.info( - f"[状态转换请求] 接收到请求,将 {stream_name} (当前: {current_state.value}) 尝试转换为 {target_state.value} ({log_reason})" - ) - try: - # 从HFC到CHAT时,清空兴趣字典 - subflow.clear_interest_dict() - await subflow.change_chat_state(target_state) - final_state = subflow.chat_state.chat_status - if final_state == target_state: - logger.debug(f"[状态转换请求] {stream_name} 状态已成功转换为 {final_state.value}") - else: - logger.warning( - f"[状态转换请求] 尝试将 {stream_name} 转换为 {target_state.value} 后,状态实际为 {final_state.value}" - ) - except Exception as e: - logger.error( - f"[状态转换请求] 转换 {stream_name} 到 {target_state.value} 时出错: {e}", exc_info=True - ) - elif current_state == ChatState.ABSENT: - logger.debug(f"[状态转换请求] {stream_name} 已处于 ABSENT 状态,无需转换") - else: - logger.warning( - f"[状态转换请求] 收到对 {stream_name} 的请求,但其状态为 {current_state.value} (非 FOCUSED),不执行转换" - ) - - # --- 结束新增 --- # - # --- 新增:处理私聊从 ABSENT 直接到 FOCUSED 的逻辑 --- # async def sbhf_absent_private_into_focus(self): """检查 ABSENT 状态的私聊子心流是否有新活动,若有且未达 FOCUSED 上限,则直接转换为 FOCUSED。""" @@ -762,19 +350,8 @@ class SubHeartflowManager: transitioned_count = 0 checked_count = 0 - # --- 获取当前状态和 FOCUSED 上限 --- # - current_mai_state = self.mai_state_info.get_current_state() - focused_limit = current_mai_state.get_focused_chat_max_num() - # --- 检查是否允许 FOCUS 模式 --- # if not global_config.chat.allow_focus_mode: - # Log less frequently to avoid spam - # if int(time.time()) % 60 == 0: - # logger.debug(f"{log_prefix_task} 配置不允许进入 FOCUSED 状态") - return - - if focused_limit <= 0: - # logger.debug(f"{log_prefix_task} 当前状态 ({current_mai_state.value}) 不允许 FOCUSED 子心流") return async with self._lock: @@ -795,12 +372,6 @@ class SubHeartflowManager: # --- 遍历评估每个符合条件的私聊 --- # for sub_hf in eligible_subflows: - # --- 再次检查 FOCUSED 上限,因为可能有多个同时激活 --- # - if current_focused_count >= focused_limit: - logger.debug( - f"{log_prefix_task} 已达专注上限 ({current_focused_count}/{focused_limit}),停止检查后续私聊。" - ) - break # 已满,无需再检查其他私聊 flow_id = sub_hf.subheartflow_id stream_name = chat_manager.get_stream_name(flow_id) or flow_id @@ -824,9 +395,6 @@ class SubHeartflowManager: # --- 如果活跃且未达上限,则尝试转换 --- # if is_active: - logger.info( - f"{log_prefix} 检测到活跃且未达专注上限 ({current_focused_count}/{focused_limit}),尝试转换为 FOCUSED。" - ) await sub_hf.change_chat_state(ChatState.FOCUSED) # 确认转换成功 if sub_hf.chat_state.chat_status == ChatState.FOCUSED: diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index 88bf141a1..3b9a6f929 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -78,25 +78,6 @@ class ChatBot: group_info = message.message_info.group_info user_info = message.message_info.user_info - # 用户黑名单拦截 - # if userinfo.user_id in global_config.chat_target.ban_user_id: - # logger.debug(f"用户{userinfo.user_id}被禁止回复") - # return - - # if groupinfo is None: - # logger.trace("检测到私聊消息,检查") - # # 好友黑名单拦截 - # if userinfo.user_id not in global_config.experimental.talk_allowed_private: - # # logger.debug(f"用户{userinfo.user_id}没有私聊权限") - # return - - # 群聊黑名单拦截 - # print(groupinfo.group_id) - # print(global_config.chat_target.talk_allowed_groups) - # if groupinfo is not None and groupinfo.group_id not in global_config.chat_target.talk_allowed_groups: - # logger.debug(f"群{groupinfo.group_id}被禁止回复") - # return - # 确认从接口发来的message是否有自定义的prompt模板信息 if message.message_info.template_info and not message.message_info.template_info.template_default: template_group_name = message.message_info.template_info.template_name @@ -114,28 +95,6 @@ class ChatBot: # 如果在私聊中 if group_info is None: logger.trace("检测到私聊消息") - # 是否在配置信息中开启私聊模式 - # if global_config.experimental.enable_friend_chat: - # logger.trace("私聊模式已启用") - # # 是否进入PFC - # if global_config.enable_pfc_chatting: - # logger.trace("进入PFC私聊处理流程") - # userinfo = message.message_info.user_info - # messageinfo = message.message_info - # # 创建聊天流 - # logger.trace(f"为{userinfo.user_id}创建/获取聊天流") - # chat = await chat_manager.get_or_create_stream( - # platform=messageinfo.platform, - # user_info=userinfo, - # group_info=groupinfo, - # ) - # message.update_chat_stream(chat) - # await self.only_process_chat.process_message(message) - # await self._create_pfc_chat(message) - # # 禁止PFC,进入普通的心流消息处理逻辑 - # else: - # logger.trace("进入普通心流私聊处理") - # await self.heartflow_processor.process_message(message_data) if global_config.experimental.pfc_chatting: logger.trace("进入PFC私聊处理流程") # 创建聊天流 diff --git a/src/chat/normal_chat/normal_chat.py b/src/chat/normal_chat/normal_chat.py index 96cc2b8cb..bd5322137 100644 --- a/src/chat/normal_chat/normal_chat.py +++ b/src/chat/normal_chat/normal_chat.py @@ -22,11 +22,11 @@ from src.chat.emoji_system.emoji_manager import emoji_manager from src.chat.normal_chat.willing.willing_manager import willing_manager from src.config.config import global_config -logger = get_logger("chat") +logger = get_logger("normal_chat") class NormalChat: - def __init__(self, chat_stream: ChatStream, interest_dict: dict = None): + def __init__(self, chat_stream: ChatStream, interest_dict: dict = {}): """初始化 NormalChat 实例。只进行同步操作。""" # Basic info from chat_stream (sync) @@ -200,7 +200,7 @@ class NormalChat: logger.info(f"[{self.stream_name}] 兴趣监控任务被取消或置空,退出") break - # 获取待处理消息列表 + items_to_process = list(self.interest_dict.items()) if not items_to_process: continue @@ -481,7 +481,7 @@ class NormalChat: try: if exc := task.exception(): logger.error(f"[{self.stream_name}] 任务异常: {exc}") - logger.error(traceback.format_exc()) + traceback.print_exc() except asyncio.CancelledError: logger.debug(f"[{self.stream_name}] 任务已取消") except Exception as e: @@ -522,4 +522,4 @@ class NormalChat: logger.info(f"[{self.stream_name}] 清理了 {len(thinking_messages)} 条未处理的思考消息。") except Exception as e: logger.error(f"[{self.stream_name}] 清理思考消息时出错: {e}") - logger.error(traceback.format_exc()) + traceback.print_exc() diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py index a5b601c43..6d9ce0719 100644 --- a/src/chat/utils/utils.py +++ b/src/chat/utils/utils.py @@ -15,6 +15,8 @@ from ..models.utils_model import LLMRequest from .typo_generator import ChineseTypoGenerator from ...common.database.database import db from ...config.config import global_config +from ...common.database.database_model import Messages +from ...common.message_repository import find_messages, count_messages logger = get_module_logger("chat_utils") @@ -108,20 +110,12 @@ async def get_embedding(text, request_type="embedding"): def get_recent_group_detailed_plain_text(chat_stream_id: str, limit: int = 12, combine=False): - recent_messages = list( - db.messages.find( - {"chat_id": chat_stream_id}, - { - "time": 1, # 返回时间字段 - "chat_id": 1, - "chat_info": 1, - "user_info": 1, - "message_id": 1, # 返回消息ID字段 - "detailed_plain_text": 1, # 返回处理后的文本字段 - }, - ) - .sort("time", -1) - .limit(limit) + filter_query = {"chat_id": chat_stream_id} + sort_order = [("time", -1)] + recent_messages = find_messages( + message_filter=filter_query, + sort=sort_order, + limit=limit ) if not recent_messages: @@ -143,17 +137,14 @@ def get_recent_group_detailed_plain_text(chat_stream_id: str, limit: int = 12, c return message_detailed_plain_text_list -def get_recent_group_speaker(chat_stream_id: int, sender, limit: int = 12) -> list: +def get_recent_group_speaker(chat_stream_id: str, sender, limit: int = 12) -> list: # 获取当前群聊记录内发言的人 - recent_messages = list( - db.messages.find( - {"chat_id": chat_stream_id}, - { - "user_info": 1, - }, - ) - .sort("time", -1) - .limit(limit) + filter_query = {"chat_id": chat_stream_id} + sort_order = [("time", -1)] + recent_messages = find_messages( + message_filter=filter_query, + sort=sort_order, + limit=limit ) if not recent_messages: @@ -161,7 +152,12 @@ def get_recent_group_speaker(chat_stream_id: int, sender, limit: int = 12) -> li who_chat_in_group = [] for msg_db_data in recent_messages: - user_info = UserInfo.from_dict(msg_db_data["user_info"]) + user_info = UserInfo.from_dict({ + "platform": msg_db_data["user_platform"], + "user_id": msg_db_data["user_id"], + "user_nickname": msg_db_data["user_nickname"], + "user_cardname": msg_db_data.get("user_cardname", "") + }) if ( (user_info.platform, user_info.user_id) != sender and user_info.user_id != global_config.bot.qq_account @@ -581,26 +577,23 @@ def count_messages_between(start_time: float, end_time: float, stream_id: str) - logger.error("stream_id 不能为空") return 0, 0 - # 直接查询时间范围内的消息 - # time > start_time AND time <= end_time - query = {"chat_id": stream_id, "time": {"$gt": start_time, "$lte": end_time}} + # 使用message_repository中的count_messages和find_messages函数 + + + # 构建查询条件 + filter_query = {"chat_id": stream_id, "time": {"$gt": start_time, "$lte": end_time}} try: - # 执行查询 - messages_cursor = db.messages.find(query) + # 先获取消息数量 + count = count_messages(filter_query) + + # 获取消息内容计算总长度 + messages = find_messages(message_filter=filter_query) + total_length = sum(len(msg.get("processed_plain_text", "")) for msg in messages) - # 遍历结果计算数量和长度 - for msg in messages_cursor: - count += 1 - total_length += len(msg.get("processed_plain_text", "")) - - # logger.debug(f"查询范围 ({start_time}, {end_time}] 内找到 {count} 条消息,总长度 {total_length}") return count, total_length - except PyMongoError as e: - logger.error(f"查询 stream_id={stream_id} 在 ({start_time}, {end_time}] 范围内的消息时出错: {e}") - return 0, 0 - except Exception as e: # 保留一个通用异常捕获以防万一 + except Exception as e: logger.error(f"计算消息数量时发生意外错误: {e}") return 0, 0 diff --git a/src/common/logger.py b/src/common/logger.py index adc15fe71..394d9de90 100644 --- a/src/common/logger.py +++ b/src/common/logger.py @@ -276,6 +276,40 @@ CHAT_STYLE_CONFIG = { }, } +# Topic日志样式配置 +NORMAL_CHAT_STYLE_CONFIG = { + "advanced": { + "console_format": ( + "{time:YYYY-MM-DD HH:mm:ss} | " + "{level: <8} | " + "一般水群 | " + "{message}" + ), + "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 一般水群 | {message}", + }, + "simple": { + "console_format": "{time:HH:mm:ss} | 一般水群 | {message}", # noqa: E501 + "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 一般水群 | {message}", + }, +} + +# Topic日志样式配置 +FOCUS_CHAT_STYLE_CONFIG = { + "advanced": { + "console_format": ( + "{time:YYYY-MM-DD HH:mm:ss} | " + "{level: <8} | " + "专注水群 | " + "{message}" + ), + "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 专注水群 | {message}", + }, + "simple": { + "console_format": "{time:HH:mm:ss} | 专注水群 | {message}", # noqa: E501 + "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 专注水群 | {message}", + }, +} + REMOTE_STYLE_CONFIG = { "advanced": { "console_format": ( @@ -915,6 +949,8 @@ API_SERVER_STYLE_CONFIG = API_SERVER_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT els INTEREST_CHAT_STYLE_CONFIG = ( INTEREST_CHAT_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else INTEREST_CHAT_STYLE_CONFIG["advanced"] ) +NORMAL_CHAT_STYLE_CONFIG = NORMAL_CHAT_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else NORMAL_CHAT_STYLE_CONFIG["advanced"] +FOCUS_CHAT_STYLE_CONFIG = FOCUS_CHAT_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else FOCUS_CHAT_STYLE_CONFIG["advanced"] def is_registered_module(record: dict) -> bool: diff --git a/src/common/logger_manager.py b/src/common/logger_manager.py index 48d415bd9..523059313 100644 --- a/src/common/logger_manager.py +++ b/src/common/logger_manager.py @@ -21,6 +21,8 @@ from src.common.logger import ( WILLING_STYLE_CONFIG, PFC_ACTION_PLANNER_STYLE_CONFIG, MAI_STATE_CONFIG, + NORMAL_CHAT_STYLE_CONFIG, + FOCUS_CHAT_STYLE_CONFIG, LPMM_STYLE_CONFIG, HFC_STYLE_CONFIG, OBSERVATION_STYLE_CONFIG, @@ -94,6 +96,8 @@ MODULE_LOGGER_CONFIGS = { "init": INIT_STYLE_CONFIG, # 初始化 "interest_chat": INTEREST_CHAT_STYLE_CONFIG, # 兴趣 "api": API_SERVER_STYLE_CONFIG, # API服务器 + "normal_chat": NORMAL_CHAT_STYLE_CONFIG, # 一般水群 + "focus_chat": FOCUS_CHAT_STYLE_CONFIG, # 专注水群 # ...如有更多模块,继续添加... } diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index b66c3b180..943422029 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -1,5 +1,5 @@ [inner] -version = "2.1.0" +version = "2.2.0" #----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读---- #如果你想要修改配置文件,请在修改后将version的值进行变更 @@ -55,8 +55,6 @@ qq="http://127.0.0.1:18002/api/message" allow_focus_mode = false # 是否允许专注聊天状态 # 是否启用heart_flowC(HFC)模式 # 启用后麦麦会自主选择进入heart_flowC模式(持续一段时间),进行主动的观察和回复,并给出回复,比较消耗token -base_normal_chat_num = 999 # 最多允许多少个群进行普通聊天 -base_focused_chat_num = 4 # 最多允许多少个群进行专注聊天 chat.observation_context_size = 15 # 观察到的最长上下文大小,建议15,太短太长都会导致脑袋尖尖 message_buffer = true # 启用消息缓冲器?启用此项以解决消息的拆分问题,但会使麦麦的回复延迟 @@ -226,14 +224,14 @@ provider = "SILICONFLOW" pri_in = 0 pri_out = 0 -[model.sub_heartflow] #心流:认真水群时,生成麦麦的内心想法,必须使用具有工具调用能力的模型 +[model.sub_heartflow] #心流:认真聊天时,生成麦麦的内心想法,必须使用具有工具调用能力的模型 name = "Pro/deepseek-ai/DeepSeek-V3" provider = "SILICONFLOW" pri_in = 2 pri_out = 8 temp = 0.3 #模型的温度,新V3建议0.1-0.3 -[model.plan] #决策:认真水群时,负责决定麦麦该做什么 +[model.plan] #决策:认真聊天时,负责决定麦麦该做什么 name = "Pro/deepseek-ai/DeepSeek-V3" provider = "SILICONFLOW" pri_in = 2