diff --git a/src/plugins/chat/config.py b/src/plugins/chat/config.py index 88cb31ed5..891c4e939 100644 --- a/src/plugins/chat/config.py +++ b/src/plugins/chat/config.py @@ -104,6 +104,12 @@ class BotConfig: memory_ban_words: list = field( default_factory=lambda: ["表情包", "图片", "回复", "聊天记录"] ) # 添加新的配置项默认值 + + # 是否优先使用当前群组的记忆 + memory_group_priority: bool = True # 默认开启群组记忆优先 + + # 群组记忆私有化 + memory_private_groups: dict = field(default_factory=dict) # 群组私有记忆配置 @staticmethod def get_config_dir() -> str: @@ -304,6 +310,12 @@ class BotConfig: config.memory_forget_time = memory_config.get("memory_forget_time", config.memory_forget_time) config.memory_forget_percentage = memory_config.get("memory_forget_percentage", config.memory_forget_percentage) config.memory_compress_rate = memory_config.get("memory_compress_rate", config.memory_compress_rate) + # 添加对memory_group_priority配置项的加载 + config.memory_group_priority = memory_config.get("memory_group_priority", config.memory_group_priority) + + if config.INNER_VERSION in SpecifierSet(">=0.0.9"): + # 添加群组记忆私有化配置项的加载 + config.memory_private_groups = memory_config.get("memory_private_groups", {}) def mood(parent: dict): mood_config = parent["mood"] diff --git a/src/plugins/chat/emoji_manager.py b/src/plugins/chat/emoji_manager.py index 9532db4f0..1c8a07699 100644 --- a/src/plugins/chat/emoji_manager.py +++ b/src/plugins/chat/emoji_manager.py @@ -6,6 +6,8 @@ import random import time import traceback from typing import Optional, Tuple +from PIL import Image +import io from loguru import logger from nonebot import get_driver @@ -192,11 +194,11 @@ class EmojiManager: logger.error(f"获取标签失败: {str(e)}") return None - async def _check_emoji(self, image_base64: str) -> str: + async def _check_emoji(self, image_base64: str, image_format: str) -> str: try: prompt = f'这是一个表情包,请回答这个表情包是否满足\"{global_config.EMOJI_CHECK_PROMPT}\"的要求,是则回答是,否则回答否,不要出现任何其他内容' - content, _ = await self.vlm.generate_response_for_image(prompt, image_base64) + content, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format) logger.debug(f"输出描述: {content}") return content @@ -237,7 +239,7 @@ class EmojiManager: image_bytes = base64.b64decode(image_base64) image_hash = hashlib.md5(image_bytes).hexdigest() - + image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # 检查是否已经注册过 existing_emoji = self.db['emoji'].find_one({'filename': filename}) description = None @@ -278,7 +280,7 @@ class EmojiManager: if global_config.EMOJI_CHECK: - check = await self._check_emoji(image_base64) + check = await self._check_emoji(image_base64, image_format) if '是' not in check: os.remove(image_path) logger.info(f"描述: {description}") diff --git a/src/plugins/chat/prompt_builder.py b/src/plugins/chat/prompt_builder.py index c89bf3e07..ac31234c8 100644 --- a/src/plugins/chat/prompt_builder.py +++ b/src/plugins/chat/prompt_builder.py @@ -91,12 +91,20 @@ class PromptBuilder: memory_prompt = '' start_time = time.time() + # 获取群组ID + group_id = None + if stream_id: + chat_stream = chat_manager.get_stream(stream_id) + if chat_stream and chat_stream.group_info: + group_id = chat_stream.group_info.group_id + # 调用 hippocampus 的 get_relevant_memories 方法 relevant_memories = await hippocampus.get_relevant_memories( text=message_txt, max_topics=5, similarity_threshold=0.4, - max_memory_num=5 + max_memory_num=5, + group_id=group_id # 传递群组ID ) if relevant_memories: diff --git a/src/plugins/chat/utils_image.py b/src/plugins/chat/utils_image.py index fb2428870..94014b5b4 100644 --- a/src/plugins/chat/utils_image.py +++ b/src/plugins/chat/utils_image.py @@ -4,6 +4,8 @@ import time import aiohttp import hashlib from typing import Optional, Union +from PIL import Image +import io from loguru import logger from nonebot import get_driver @@ -119,6 +121,7 @@ class ImageManager: # 计算哈希值 image_hash = hashlib.md5(image_bytes).hexdigest() + image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # 查重 existing = self.db.images.find_one({'hash': image_hash}) @@ -127,7 +130,7 @@ class ImageManager: # 生成文件名和路径 timestamp = int(time.time()) - filename = f"{timestamp}_{image_hash[:8]}.jpg" + filename = f"{timestamp}_{image_hash[:8]}.{image_format}" file_path = os.path.join(self.IMAGE_DIR, filename) # 保存文件 @@ -238,7 +241,8 @@ class ImageManager: # 计算图片哈希 image_bytes = base64.b64decode(image_base64) image_hash = hashlib.md5(image_bytes).hexdigest() - + image_format = Image.open(io.BytesIO(image_bytes)).format.lower() + # 查询缓存的描述 cached_description = self._get_description_from_db(image_hash, 'emoji') if cached_description: @@ -247,13 +251,13 @@ class ImageManager: # 调用AI获取描述 prompt = "这是一个表情包,使用中文简洁的描述一下表情包的内容和表情包所表达的情感" - description, _ = await self._llm.generate_response_for_image(prompt, image_base64) + description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format) # 根据配置决定是否保存图片 if global_config.EMOJI_SAVE: # 生成文件名和路径 timestamp = int(time.time()) - filename = f"{timestamp}_{image_hash[:8]}.jpg" + filename = f"{timestamp}_{image_hash[:8]}.{image_format}" file_path = os.path.join(self.IMAGE_DIR, 'emoji',filename) try: @@ -293,7 +297,8 @@ class ImageManager: # 计算图片哈希 image_bytes = base64.b64decode(image_base64) image_hash = hashlib.md5(image_bytes).hexdigest() - + image_format = Image.open(io.BytesIO(image_bytes)).format.lower() + # 查询缓存的描述 cached_description = self._get_description_from_db(image_hash, 'image') if cached_description: @@ -302,7 +307,7 @@ class ImageManager: # 调用AI获取描述 prompt = "请用中文描述这张图片的内容。如果有文字,请把文字都描述出来。并尝试猜测这个图片的含义。最多200个字。" - description, _ = await self._llm.generate_response_for_image(prompt, image_base64) + description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format) print(f"描述是{description}") @@ -314,7 +319,7 @@ class ImageManager: if global_config.EMOJI_SAVE: # 生成文件名和路径 timestamp = int(time.time()) - filename = f"{timestamp}_{image_hash[:8]}.jpg" + filename = f"{timestamp}_{image_hash[:8]}.{image_format}" file_path = os.path.join(self.IMAGE_DIR,'image', filename) try: diff --git a/src/plugins/memory_system/memory.py b/src/plugins/memory_system/memory.py index d9e867e63..a3a0d068a 100644 --- a/src/plugins/memory_system/memory.py +++ b/src/plugins/memory_system/memory.py @@ -4,6 +4,7 @@ import math import random import time import os +from typing import Optional import jieba import networkx as nx @@ -209,15 +210,31 @@ class Hippocampus: return chat_samples - async def memory_compress(self, messages: list, compress_rate=0.1): + async def memory_compress(self, messages: list, compress_rate=0.1, group_id=None): """压缩消息记录为记忆 + Args: + messages: 消息记录列表 + compress_rate: 压缩率 + group_id: 群组ID,用于标记记忆来源 + Returns: tuple: (压缩记忆集合, 相似主题字典) """ + from ..chat.config import global_config # 导入配置 + if not messages: return set(), {} + # 确定记忆所属的群组 + memory_group_tag = None + if group_id is not None: + # 查找群聊所属的群组 + for group_name, group_ids in global_config.memory_private_groups.items(): + if str(group_id) in group_ids: + memory_group_tag = f"[群组:{group_name}]" + break + # 合并消息文本,同时保留时间信息 input_text = "" time_info = "" @@ -267,7 +284,17 @@ class Hippocampus: for topic, task in tasks: response = await task if response: - compressed_memory.add((topic, response[0])) + memory_content = response[0] + + # 添加标记 + # 优先使用群组标记 + if memory_group_tag: + memory_content = f"{memory_group_tag}{memory_content}" + # 如果没有群组标记但有群组ID,添加简单的群组ID标记 + elif group_id is not None: + memory_content = f"[群组:{group_id}]{memory_content}" + + compressed_memory.add((topic, memory_content)) # 为每个话题查找相似的已存在主题 existing_topics = list(self.memory_graph.G.nodes()) similar_topics = [] @@ -316,7 +343,17 @@ class Hippocampus: logger.debug(f"进度: [{bar}] {progress:.1f}% ({i}/{len(memory_samples)})") compress_rate = global_config.memory_compress_rate - compressed_memory, similar_topics_dict = await self.memory_compress(messages, compress_rate) + + # 尝试从消息中提取群组ID + group_id = None + if messages and len(messages) > 0: + first_msg = messages[0] + if 'group_id' in first_msg: + group_id = first_msg['group_id'] + logger.info(f"检测到消息来自群组: {group_id}") + + # 传递群组ID到memory_compress + compressed_memory, similar_topics_dict = await self.memory_compress(messages, compress_rate, group_id) logger.info(f"压缩后记忆数量: {len(compressed_memory)},似曾相识的话题: {len(similar_topics_dict)}") current_time = datetime.datetime.now().timestamp() @@ -841,8 +878,21 @@ class Hippocampus: return activation async def get_relevant_memories(self, text: str, max_topics: int = 5, similarity_threshold: float = 0.4, - max_memory_num: int = 5) -> list: - """根据输入文本获取相关的记忆内容""" + max_memory_num: int = 5, group_id: Optional[int] = None) -> list: + """根据输入文本获取相关的记忆内容 + + Args: + text: 输入文本 + max_topics: 最大主题数 + similarity_threshold: 相似度阈值 + max_memory_num: 最大记忆数量 + group_id: 群组ID,用于优先匹配当前群组的记忆 + + Returns: + list: 相关记忆列表 + """ + from ..chat.config import global_config # 导入配置 + # 识别主题 identified_topics = await self._identify_topics(text) @@ -855,30 +905,134 @@ class Hippocampus: # 获取最相关的主题 relevant_topics = self._get_top_topics(all_similar_topics, max_topics) - + + # 确定记忆所属的群组 + current_group_name = None + if group_id is not None: + # 查找群聊所属的群组 + for group_name, group_ids in global_config.memory_private_groups.items(): + if str(group_id) in group_ids: + current_group_name = group_name + break + + has_private_groups = len(global_config.memory_private_groups) > 0 + # 获取相关记忆内容 relevant_memories = [] + group_related_memories = [] # 当前群聊的记忆 + group_definition_memories = [] # 当前群组的记忆 + public_memories = [] # 公共记忆 + for topic, score in relevant_topics: # 获取该主题的记忆内容 first_layer, _ = self.memory_graph.get_related_item(topic, depth=1) if first_layer: # 如果记忆条数超过限制,随机选择指定数量的记忆 - if len(first_layer) > max_memory_num / 2: - first_layer = random.sample(first_layer, max_memory_num // 2) + if len(first_layer) > max_memory_num: + first_layer = random.sample(first_layer, max_memory_num) + # 为每条记忆添加来源主题和相似度信息 for memory in first_layer: - relevant_memories.append({ + memory_info = { 'topic': topic, 'similarity': score, 'content': memory - }) + } + + memory_text = str(memory) + + # 分类处理记忆 + if has_private_groups and group_id is not None: + # 如果配置了私有群组且当前在群聊中 + if current_group_name: + # 当前群聊属于某个群组 + if f"[群组:{current_group_name}]" in memory_text: + # 当前群组的记忆 + group_definition_memories.append(memory_info) + elif not any(f"[群组:" in memory_text for _ in range(1)): + # 公共记忆 + public_memories.append(memory_info) + else: + # 当前群聊不属于任何群组 + if f"[群组:{group_id}]" in memory_text: + # 当前群聊的特定记忆 + group_related_memories.append(memory_info) + elif not any(f"[群组:" in memory_text for _ in range(1)): + # 公共记忆 + public_memories.append(memory_info) + elif global_config.memory_group_priority and group_id is not None: + # 如果只启用了群组记忆优先 + if f"[群组:{group_id}]" in memory_text: + # 当前群聊的记忆,放入群组相关记忆列表 + group_related_memories.append(memory_info) + else: + # 其他记忆,放入公共记忆列表 + public_memories.append(memory_info) + else: + # 如果没有特殊配置,所有记忆都放入相关记忆列表 + relevant_memories.append(memory_info) - # 如果记忆数量超过5个,随机选择5个 - # 按相似度排序 - relevant_memories.sort(key=lambda x: x['similarity'], reverse=True) - - if len(relevant_memories) > max_memory_num: - relevant_memories = random.sample(relevant_memories, max_memory_num) + # 根据配置决定如何组合记忆 + if has_private_groups and group_id is not None: + # 配置了私有群组且当前在群聊中 + if current_group_name: + # 当前群聊属于某个群组 + # 优先使用当前群组的记忆,如果不足再使用公共记忆 + if len(group_definition_memories) >= max_memory_num: + # 如果群组记忆足够,只使用群组记忆 + group_definition_memories.sort(key=lambda x: x['similarity'], reverse=True) + relevant_memories = group_definition_memories[:max_memory_num] + else: + # 如果群组记忆不足,添加公共记忆 + group_definition_memories.sort(key=lambda x: x['similarity'], reverse=True) + public_memories.sort(key=lambda x: x['similarity'], reverse=True) + + relevant_memories = group_definition_memories.copy() + remaining_count = max_memory_num - len(relevant_memories) + if remaining_count > 0 and public_memories: + selected_other = public_memories[:remaining_count] + relevant_memories.extend(selected_other) + else: + # 当前群聊不属于任何群组 + # 优先使用当前群聊的记忆,然后使用公共记忆 + if len(group_related_memories) >= max_memory_num: + # 如果当前群聊记忆足够,只使用当前群聊记忆 + group_related_memories.sort(key=lambda x: x['similarity'], reverse=True) + relevant_memories = group_related_memories[:max_memory_num] + else: + # 如果当前群聊记忆不足,添加公共记忆 + group_related_memories.sort(key=lambda x: x['similarity'], reverse=True) + public_memories.sort(key=lambda x: x['similarity'], reverse=True) + + relevant_memories = group_related_memories.copy() + remaining_count = max_memory_num - len(relevant_memories) + if remaining_count > 0 and public_memories: + selected_other = public_memories[:remaining_count] + relevant_memories.extend(selected_other) + elif global_config.memory_group_priority and group_id is not None: + # 如果只启用了群组记忆优先 + # 按相似度排序 + group_related_memories.sort(key=lambda x: x['similarity'], reverse=True) + public_memories.sort(key=lambda x: x['similarity'], reverse=True) + + # 优先使用群组相关记忆,如果不足再使用其他记忆 + if len(group_related_memories) >= max_memory_num: + # 如果群组相关记忆足够,只使用群组相关记忆 + relevant_memories = group_related_memories[:max_memory_num] + else: + # 使用所有群组相关记忆 + relevant_memories = group_related_memories.copy() + # 如果群组相关记忆不足,添加其他记忆 + remaining_count = max_memory_num - len(relevant_memories) + if remaining_count > 0 and public_memories: + # 从其他记忆中选择剩余需要的数量 + selected_other = public_memories[:remaining_count] + relevant_memories.extend(selected_other) + else: + # 如果没有特殊配置,按相似度排序 + relevant_memories.sort(key=lambda x: x['similarity'], reverse=True) + if len(relevant_memories) > max_memory_num: + relevant_memories = relevant_memories[:max_memory_num] return relevant_memories diff --git a/src/plugins/models/utils_model.py b/src/plugins/models/utils_model.py index 5335e3d65..aa07bb55d 100644 --- a/src/plugins/models/utils_model.py +++ b/src/plugins/models/utils_model.py @@ -104,6 +104,7 @@ class LLM_request: endpoint: str, prompt: str = None, image_base64: str = None, + image_format: str = None, payload: dict = None, retry_policy: dict = None, response_handler: callable = None, @@ -115,6 +116,7 @@ class LLM_request: endpoint: API端点路径 (如 "chat/completions") prompt: prompt文本 image_base64: 图片的base64编码 + image_format: 图片格式 payload: 请求体数据 retry_policy: 自定义重试策略 response_handler: 自定义响应处理器 @@ -151,7 +153,7 @@ class LLM_request: # 构建请求体 if image_base64: - payload = await self._build_payload(prompt, image_base64) + payload = await self._build_payload(prompt, image_base64, image_format) elif payload is None: payload = await self._build_payload(prompt) @@ -172,7 +174,7 @@ class LLM_request: if response.status == 413: logger.warning("请求体过大,尝试压缩...") image_base64 = compress_base64_image_by_scale(image_base64) - payload = await self._build_payload(prompt, image_base64) + payload = await self._build_payload(prompt, image_base64, image_format) elif response.status in [500, 503]: logger.error(f"错误码: {response.status} - {error_code_mapping.get(response.status)}") raise RuntimeError("服务器负载过高,模型恢复失败QAQ") @@ -294,7 +296,7 @@ class LLM_request: new_params["max_completion_tokens"] = new_params.pop("max_tokens") return new_params - async def _build_payload(self, prompt: str, image_base64: str = None) -> dict: + async def _build_payload(self, prompt: str, image_base64: str = None, image_format: str = None) -> dict: """构建请求体""" # 复制一份参数,避免直接修改 self.params params_copy = await self._transform_parameters(self.params) @@ -306,7 +308,7 @@ class LLM_request: "role": "user", "content": [ {"type": "text", "text": prompt}, - {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}} + {"type": "image_url", "image_url": {"url": f"data:image/{image_format.lower()};base64,{image_base64}"}} ] } ], @@ -391,13 +393,14 @@ class LLM_request: ) return content, reasoning_content - async def generate_response_for_image(self, prompt: str, image_base64: str) -> Tuple[str, str]: + async def generate_response_for_image(self, prompt: str, image_base64: str, image_format: str) -> Tuple[str, str]: """根据输入的提示和图片生成模型的异步响应""" content, reasoning_content = await self._execute_request( endpoint="/chat/completions", prompt=prompt, - image_base64=image_base64 + image_base64=image_base64, + image_format=image_format ) return content, reasoning_content diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index 089be69b0..49f3a1919 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -1,5 +1,5 @@ [inner] -version = "0.0.8" +version = "0.0.9" #如果你想要修改配置文件,请在修改后将version的值进行变更 #如果新增项目,请在BotConfig类下新增相应的变量 @@ -72,6 +72,17 @@ forget_memory_interval = 600 # 记忆遗忘间隔 单位秒 间隔越低,麦 memory_forget_time = 24 #多长时间后的记忆会被遗忘 单位小时 memory_forget_percentage = 0.01 # 记忆遗忘比例 控制记忆遗忘程度 越大遗忘越多 建议保持默认 +memory_group_priority = true # 是否优先使用当前群组的记忆,开启后将优先使用当前群组的记忆内容,避免不同群组讨论相同话题时的记忆混淆 + +# 群组私有记忆配置 - 同一群组内的群聊共享记忆,但不与其他群组共享 +# 格式为 { 群组名称 = [群聊ID列表] } +# 未配置在任何群组中的群聊记忆可以与所有群聊共享(群组内群数量过少 聊天记录过少的情况下 建议修改其他记忆参数 加强回复概率等) +# 例如: +# memory_private_groups = { +# "游戏群组" = ["123456", "234567"], +# "工作群组" = ["345678", "456789"] +# } +memory_private_groups = { } memory_ban_words = [ #不希望记忆的词 # "403","张三"