Merge remote-tracking branch 'upstream/debug' into debug
This commit is contained in:
@@ -104,6 +104,12 @@ class BotConfig:
|
||||
memory_ban_words: list = field(
|
||||
default_factory=lambda: ["表情包", "图片", "回复", "聊天记录"]
|
||||
) # 添加新的配置项默认值
|
||||
|
||||
# 是否优先使用当前群组的记忆
|
||||
memory_group_priority: bool = True # 默认开启群组记忆优先
|
||||
|
||||
# 群组记忆私有化
|
||||
memory_private_groups: dict = field(default_factory=dict) # 群组私有记忆配置
|
||||
|
||||
@staticmethod
|
||||
def get_config_dir() -> str:
|
||||
@@ -304,6 +310,12 @@ class BotConfig:
|
||||
config.memory_forget_time = memory_config.get("memory_forget_time", config.memory_forget_time)
|
||||
config.memory_forget_percentage = memory_config.get("memory_forget_percentage", config.memory_forget_percentage)
|
||||
config.memory_compress_rate = memory_config.get("memory_compress_rate", config.memory_compress_rate)
|
||||
# 添加对memory_group_priority配置项的加载
|
||||
config.memory_group_priority = memory_config.get("memory_group_priority", config.memory_group_priority)
|
||||
|
||||
if config.INNER_VERSION in SpecifierSet(">=0.0.9"):
|
||||
# 添加群组记忆私有化配置项的加载
|
||||
config.memory_private_groups = memory_config.get("memory_private_groups", {})
|
||||
|
||||
def mood(parent: dict):
|
||||
mood_config = parent["mood"]
|
||||
|
||||
@@ -6,6 +6,8 @@ import random
|
||||
import time
|
||||
import traceback
|
||||
from typing import Optional, Tuple
|
||||
from PIL import Image
|
||||
import io
|
||||
|
||||
from loguru import logger
|
||||
from nonebot import get_driver
|
||||
@@ -192,11 +194,11 @@ class EmojiManager:
|
||||
logger.error(f"获取标签失败: {str(e)}")
|
||||
return None
|
||||
|
||||
async def _check_emoji(self, image_base64: str) -> str:
|
||||
async def _check_emoji(self, image_base64: str, image_format: str) -> str:
|
||||
try:
|
||||
prompt = f'这是一个表情包,请回答这个表情包是否满足\"{global_config.EMOJI_CHECK_PROMPT}\"的要求,是则回答是,否则回答否,不要出现任何其他内容'
|
||||
|
||||
content, _ = await self.vlm.generate_response_for_image(prompt, image_base64)
|
||||
content, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format)
|
||||
logger.debug(f"输出描述: {content}")
|
||||
return content
|
||||
|
||||
@@ -237,7 +239,7 @@ class EmojiManager:
|
||||
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
|
||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
|
||||
# 检查是否已经注册过
|
||||
existing_emoji = self.db['emoji'].find_one({'filename': filename})
|
||||
description = None
|
||||
@@ -278,7 +280,7 @@ class EmojiManager:
|
||||
|
||||
|
||||
if global_config.EMOJI_CHECK:
|
||||
check = await self._check_emoji(image_base64)
|
||||
check = await self._check_emoji(image_base64, image_format)
|
||||
if '是' not in check:
|
||||
os.remove(image_path)
|
||||
logger.info(f"描述: {description}")
|
||||
|
||||
@@ -91,12 +91,20 @@ class PromptBuilder:
|
||||
memory_prompt = ''
|
||||
start_time = time.time()
|
||||
|
||||
# 获取群组ID
|
||||
group_id = None
|
||||
if stream_id:
|
||||
chat_stream = chat_manager.get_stream(stream_id)
|
||||
if chat_stream and chat_stream.group_info:
|
||||
group_id = chat_stream.group_info.group_id
|
||||
|
||||
# 调用 hippocampus 的 get_relevant_memories 方法
|
||||
relevant_memories = await hippocampus.get_relevant_memories(
|
||||
text=message_txt,
|
||||
max_topics=5,
|
||||
similarity_threshold=0.4,
|
||||
max_memory_num=5
|
||||
max_memory_num=5,
|
||||
group_id=group_id # 传递群组ID
|
||||
)
|
||||
|
||||
if relevant_memories:
|
||||
|
||||
@@ -4,6 +4,8 @@ import time
|
||||
import aiohttp
|
||||
import hashlib
|
||||
from typing import Optional, Union
|
||||
from PIL import Image
|
||||
import io
|
||||
|
||||
from loguru import logger
|
||||
from nonebot import get_driver
|
||||
@@ -119,6 +121,7 @@ class ImageManager:
|
||||
|
||||
# 计算哈希值
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
|
||||
|
||||
# 查重
|
||||
existing = self.db.images.find_one({'hash': image_hash})
|
||||
@@ -127,7 +130,7 @@ class ImageManager:
|
||||
|
||||
# 生成文件名和路径
|
||||
timestamp = int(time.time())
|
||||
filename = f"{timestamp}_{image_hash[:8]}.jpg"
|
||||
filename = f"{timestamp}_{image_hash[:8]}.{image_format}"
|
||||
file_path = os.path.join(self.IMAGE_DIR, filename)
|
||||
|
||||
# 保存文件
|
||||
@@ -238,7 +241,8 @@ class ImageManager:
|
||||
# 计算图片哈希
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
|
||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
|
||||
|
||||
# 查询缓存的描述
|
||||
cached_description = self._get_description_from_db(image_hash, 'emoji')
|
||||
if cached_description:
|
||||
@@ -247,13 +251,13 @@ class ImageManager:
|
||||
|
||||
# 调用AI获取描述
|
||||
prompt = "这是一个表情包,使用中文简洁的描述一下表情包的内容和表情包所表达的情感"
|
||||
description, _ = await self._llm.generate_response_for_image(prompt, image_base64)
|
||||
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
|
||||
|
||||
# 根据配置决定是否保存图片
|
||||
if global_config.EMOJI_SAVE:
|
||||
# 生成文件名和路径
|
||||
timestamp = int(time.time())
|
||||
filename = f"{timestamp}_{image_hash[:8]}.jpg"
|
||||
filename = f"{timestamp}_{image_hash[:8]}.{image_format}"
|
||||
file_path = os.path.join(self.IMAGE_DIR, 'emoji',filename)
|
||||
|
||||
try:
|
||||
@@ -293,7 +297,8 @@ class ImageManager:
|
||||
# 计算图片哈希
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
|
||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
|
||||
|
||||
# 查询缓存的描述
|
||||
cached_description = self._get_description_from_db(image_hash, 'image')
|
||||
if cached_description:
|
||||
@@ -302,7 +307,7 @@ class ImageManager:
|
||||
|
||||
# 调用AI获取描述
|
||||
prompt = "请用中文描述这张图片的内容。如果有文字,请把文字都描述出来。并尝试猜测这个图片的含义。最多200个字。"
|
||||
description, _ = await self._llm.generate_response_for_image(prompt, image_base64)
|
||||
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
|
||||
|
||||
print(f"描述是{description}")
|
||||
|
||||
@@ -314,7 +319,7 @@ class ImageManager:
|
||||
if global_config.EMOJI_SAVE:
|
||||
# 生成文件名和路径
|
||||
timestamp = int(time.time())
|
||||
filename = f"{timestamp}_{image_hash[:8]}.jpg"
|
||||
filename = f"{timestamp}_{image_hash[:8]}.{image_format}"
|
||||
file_path = os.path.join(self.IMAGE_DIR,'image', filename)
|
||||
|
||||
try:
|
||||
|
||||
@@ -4,6 +4,7 @@ import math
|
||||
import random
|
||||
import time
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import jieba
|
||||
import networkx as nx
|
||||
@@ -209,15 +210,31 @@ class Hippocampus:
|
||||
|
||||
return chat_samples
|
||||
|
||||
async def memory_compress(self, messages: list, compress_rate=0.1):
|
||||
async def memory_compress(self, messages: list, compress_rate=0.1, group_id=None):
|
||||
"""压缩消息记录为记忆
|
||||
|
||||
Args:
|
||||
messages: 消息记录列表
|
||||
compress_rate: 压缩率
|
||||
group_id: 群组ID,用于标记记忆来源
|
||||
|
||||
Returns:
|
||||
tuple: (压缩记忆集合, 相似主题字典)
|
||||
"""
|
||||
from ..chat.config import global_config # 导入配置
|
||||
|
||||
if not messages:
|
||||
return set(), {}
|
||||
|
||||
# 确定记忆所属的群组
|
||||
memory_group_tag = None
|
||||
if group_id is not None:
|
||||
# 查找群聊所属的群组
|
||||
for group_name, group_ids in global_config.memory_private_groups.items():
|
||||
if str(group_id) in group_ids:
|
||||
memory_group_tag = f"[群组:{group_name}]"
|
||||
break
|
||||
|
||||
# 合并消息文本,同时保留时间信息
|
||||
input_text = ""
|
||||
time_info = ""
|
||||
@@ -267,7 +284,17 @@ class Hippocampus:
|
||||
for topic, task in tasks:
|
||||
response = await task
|
||||
if response:
|
||||
compressed_memory.add((topic, response[0]))
|
||||
memory_content = response[0]
|
||||
|
||||
# 添加标记
|
||||
# 优先使用群组标记
|
||||
if memory_group_tag:
|
||||
memory_content = f"{memory_group_tag}{memory_content}"
|
||||
# 如果没有群组标记但有群组ID,添加简单的群组ID标记
|
||||
elif group_id is not None:
|
||||
memory_content = f"[群组:{group_id}]{memory_content}"
|
||||
|
||||
compressed_memory.add((topic, memory_content))
|
||||
# 为每个话题查找相似的已存在主题
|
||||
existing_topics = list(self.memory_graph.G.nodes())
|
||||
similar_topics = []
|
||||
@@ -316,7 +343,17 @@ class Hippocampus:
|
||||
logger.debug(f"进度: [{bar}] {progress:.1f}% ({i}/{len(memory_samples)})")
|
||||
|
||||
compress_rate = global_config.memory_compress_rate
|
||||
compressed_memory, similar_topics_dict = await self.memory_compress(messages, compress_rate)
|
||||
|
||||
# 尝试从消息中提取群组ID
|
||||
group_id = None
|
||||
if messages and len(messages) > 0:
|
||||
first_msg = messages[0]
|
||||
if 'group_id' in first_msg:
|
||||
group_id = first_msg['group_id']
|
||||
logger.info(f"检测到消息来自群组: {group_id}")
|
||||
|
||||
# 传递群组ID到memory_compress
|
||||
compressed_memory, similar_topics_dict = await self.memory_compress(messages, compress_rate, group_id)
|
||||
logger.info(f"压缩后记忆数量: {len(compressed_memory)},似曾相识的话题: {len(similar_topics_dict)}")
|
||||
|
||||
current_time = datetime.datetime.now().timestamp()
|
||||
@@ -841,8 +878,21 @@ class Hippocampus:
|
||||
return activation
|
||||
|
||||
async def get_relevant_memories(self, text: str, max_topics: int = 5, similarity_threshold: float = 0.4,
|
||||
max_memory_num: int = 5) -> list:
|
||||
"""根据输入文本获取相关的记忆内容"""
|
||||
max_memory_num: int = 5, group_id: Optional[int] = None) -> list:
|
||||
"""根据输入文本获取相关的记忆内容
|
||||
|
||||
Args:
|
||||
text: 输入文本
|
||||
max_topics: 最大主题数
|
||||
similarity_threshold: 相似度阈值
|
||||
max_memory_num: 最大记忆数量
|
||||
group_id: 群组ID,用于优先匹配当前群组的记忆
|
||||
|
||||
Returns:
|
||||
list: 相关记忆列表
|
||||
"""
|
||||
from ..chat.config import global_config # 导入配置
|
||||
|
||||
# 识别主题
|
||||
identified_topics = await self._identify_topics(text)
|
||||
|
||||
@@ -855,30 +905,134 @@ class Hippocampus:
|
||||
|
||||
# 获取最相关的主题
|
||||
relevant_topics = self._get_top_topics(all_similar_topics, max_topics)
|
||||
|
||||
|
||||
# 确定记忆所属的群组
|
||||
current_group_name = None
|
||||
if group_id is not None:
|
||||
# 查找群聊所属的群组
|
||||
for group_name, group_ids in global_config.memory_private_groups.items():
|
||||
if str(group_id) in group_ids:
|
||||
current_group_name = group_name
|
||||
break
|
||||
|
||||
has_private_groups = len(global_config.memory_private_groups) > 0
|
||||
|
||||
# 获取相关记忆内容
|
||||
relevant_memories = []
|
||||
group_related_memories = [] # 当前群聊的记忆
|
||||
group_definition_memories = [] # 当前群组的记忆
|
||||
public_memories = [] # 公共记忆
|
||||
|
||||
for topic, score in relevant_topics:
|
||||
# 获取该主题的记忆内容
|
||||
first_layer, _ = self.memory_graph.get_related_item(topic, depth=1)
|
||||
if first_layer:
|
||||
# 如果记忆条数超过限制,随机选择指定数量的记忆
|
||||
if len(first_layer) > max_memory_num / 2:
|
||||
first_layer = random.sample(first_layer, max_memory_num // 2)
|
||||
if len(first_layer) > max_memory_num:
|
||||
first_layer = random.sample(first_layer, max_memory_num)
|
||||
|
||||
# 为每条记忆添加来源主题和相似度信息
|
||||
for memory in first_layer:
|
||||
relevant_memories.append({
|
||||
memory_info = {
|
||||
'topic': topic,
|
||||
'similarity': score,
|
||||
'content': memory
|
||||
})
|
||||
}
|
||||
|
||||
memory_text = str(memory)
|
||||
|
||||
# 分类处理记忆
|
||||
if has_private_groups and group_id is not None:
|
||||
# 如果配置了私有群组且当前在群聊中
|
||||
if current_group_name:
|
||||
# 当前群聊属于某个群组
|
||||
if f"[群组:{current_group_name}]" in memory_text:
|
||||
# 当前群组的记忆
|
||||
group_definition_memories.append(memory_info)
|
||||
elif not any(f"[群组:" in memory_text for _ in range(1)):
|
||||
# 公共记忆
|
||||
public_memories.append(memory_info)
|
||||
else:
|
||||
# 当前群聊不属于任何群组
|
||||
if f"[群组:{group_id}]" in memory_text:
|
||||
# 当前群聊的特定记忆
|
||||
group_related_memories.append(memory_info)
|
||||
elif not any(f"[群组:" in memory_text for _ in range(1)):
|
||||
# 公共记忆
|
||||
public_memories.append(memory_info)
|
||||
elif global_config.memory_group_priority and group_id is not None:
|
||||
# 如果只启用了群组记忆优先
|
||||
if f"[群组:{group_id}]" in memory_text:
|
||||
# 当前群聊的记忆,放入群组相关记忆列表
|
||||
group_related_memories.append(memory_info)
|
||||
else:
|
||||
# 其他记忆,放入公共记忆列表
|
||||
public_memories.append(memory_info)
|
||||
else:
|
||||
# 如果没有特殊配置,所有记忆都放入相关记忆列表
|
||||
relevant_memories.append(memory_info)
|
||||
|
||||
# 如果记忆数量超过5个,随机选择5个
|
||||
# 按相似度排序
|
||||
relevant_memories.sort(key=lambda x: x['similarity'], reverse=True)
|
||||
|
||||
if len(relevant_memories) > max_memory_num:
|
||||
relevant_memories = random.sample(relevant_memories, max_memory_num)
|
||||
# 根据配置决定如何组合记忆
|
||||
if has_private_groups and group_id is not None:
|
||||
# 配置了私有群组且当前在群聊中
|
||||
if current_group_name:
|
||||
# 当前群聊属于某个群组
|
||||
# 优先使用当前群组的记忆,如果不足再使用公共记忆
|
||||
if len(group_definition_memories) >= max_memory_num:
|
||||
# 如果群组记忆足够,只使用群组记忆
|
||||
group_definition_memories.sort(key=lambda x: x['similarity'], reverse=True)
|
||||
relevant_memories = group_definition_memories[:max_memory_num]
|
||||
else:
|
||||
# 如果群组记忆不足,添加公共记忆
|
||||
group_definition_memories.sort(key=lambda x: x['similarity'], reverse=True)
|
||||
public_memories.sort(key=lambda x: x['similarity'], reverse=True)
|
||||
|
||||
relevant_memories = group_definition_memories.copy()
|
||||
remaining_count = max_memory_num - len(relevant_memories)
|
||||
if remaining_count > 0 and public_memories:
|
||||
selected_other = public_memories[:remaining_count]
|
||||
relevant_memories.extend(selected_other)
|
||||
else:
|
||||
# 当前群聊不属于任何群组
|
||||
# 优先使用当前群聊的记忆,然后使用公共记忆
|
||||
if len(group_related_memories) >= max_memory_num:
|
||||
# 如果当前群聊记忆足够,只使用当前群聊记忆
|
||||
group_related_memories.sort(key=lambda x: x['similarity'], reverse=True)
|
||||
relevant_memories = group_related_memories[:max_memory_num]
|
||||
else:
|
||||
# 如果当前群聊记忆不足,添加公共记忆
|
||||
group_related_memories.sort(key=lambda x: x['similarity'], reverse=True)
|
||||
public_memories.sort(key=lambda x: x['similarity'], reverse=True)
|
||||
|
||||
relevant_memories = group_related_memories.copy()
|
||||
remaining_count = max_memory_num - len(relevant_memories)
|
||||
if remaining_count > 0 and public_memories:
|
||||
selected_other = public_memories[:remaining_count]
|
||||
relevant_memories.extend(selected_other)
|
||||
elif global_config.memory_group_priority and group_id is not None:
|
||||
# 如果只启用了群组记忆优先
|
||||
# 按相似度排序
|
||||
group_related_memories.sort(key=lambda x: x['similarity'], reverse=True)
|
||||
public_memories.sort(key=lambda x: x['similarity'], reverse=True)
|
||||
|
||||
# 优先使用群组相关记忆,如果不足再使用其他记忆
|
||||
if len(group_related_memories) >= max_memory_num:
|
||||
# 如果群组相关记忆足够,只使用群组相关记忆
|
||||
relevant_memories = group_related_memories[:max_memory_num]
|
||||
else:
|
||||
# 使用所有群组相关记忆
|
||||
relevant_memories = group_related_memories.copy()
|
||||
# 如果群组相关记忆不足,添加其他记忆
|
||||
remaining_count = max_memory_num - len(relevant_memories)
|
||||
if remaining_count > 0 and public_memories:
|
||||
# 从其他记忆中选择剩余需要的数量
|
||||
selected_other = public_memories[:remaining_count]
|
||||
relevant_memories.extend(selected_other)
|
||||
else:
|
||||
# 如果没有特殊配置,按相似度排序
|
||||
relevant_memories.sort(key=lambda x: x['similarity'], reverse=True)
|
||||
if len(relevant_memories) > max_memory_num:
|
||||
relevant_memories = relevant_memories[:max_memory_num]
|
||||
|
||||
return relevant_memories
|
||||
|
||||
|
||||
@@ -104,6 +104,7 @@ class LLM_request:
|
||||
endpoint: str,
|
||||
prompt: str = None,
|
||||
image_base64: str = None,
|
||||
image_format: str = None,
|
||||
payload: dict = None,
|
||||
retry_policy: dict = None,
|
||||
response_handler: callable = None,
|
||||
@@ -115,6 +116,7 @@ class LLM_request:
|
||||
endpoint: API端点路径 (如 "chat/completions")
|
||||
prompt: prompt文本
|
||||
image_base64: 图片的base64编码
|
||||
image_format: 图片格式
|
||||
payload: 请求体数据
|
||||
retry_policy: 自定义重试策略
|
||||
response_handler: 自定义响应处理器
|
||||
@@ -151,7 +153,7 @@ class LLM_request:
|
||||
|
||||
# 构建请求体
|
||||
if image_base64:
|
||||
payload = await self._build_payload(prompt, image_base64)
|
||||
payload = await self._build_payload(prompt, image_base64, image_format)
|
||||
elif payload is None:
|
||||
payload = await self._build_payload(prompt)
|
||||
|
||||
@@ -172,7 +174,7 @@ class LLM_request:
|
||||
if response.status == 413:
|
||||
logger.warning("请求体过大,尝试压缩...")
|
||||
image_base64 = compress_base64_image_by_scale(image_base64)
|
||||
payload = await self._build_payload(prompt, image_base64)
|
||||
payload = await self._build_payload(prompt, image_base64, image_format)
|
||||
elif response.status in [500, 503]:
|
||||
logger.error(f"错误码: {response.status} - {error_code_mapping.get(response.status)}")
|
||||
raise RuntimeError("服务器负载过高,模型恢复失败QAQ")
|
||||
@@ -294,7 +296,7 @@ class LLM_request:
|
||||
new_params["max_completion_tokens"] = new_params.pop("max_tokens")
|
||||
return new_params
|
||||
|
||||
async def _build_payload(self, prompt: str, image_base64: str = None) -> dict:
|
||||
async def _build_payload(self, prompt: str, image_base64: str = None, image_format: str = None) -> dict:
|
||||
"""构建请求体"""
|
||||
# 复制一份参数,避免直接修改 self.params
|
||||
params_copy = await self._transform_parameters(self.params)
|
||||
@@ -306,7 +308,7 @@ class LLM_request:
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": prompt},
|
||||
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}}
|
||||
{"type": "image_url", "image_url": {"url": f"data:image/{image_format.lower()};base64,{image_base64}"}}
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -391,13 +393,14 @@ class LLM_request:
|
||||
)
|
||||
return content, reasoning_content
|
||||
|
||||
async def generate_response_for_image(self, prompt: str, image_base64: str, image_format: str = "jpeg") -> Tuple[str, str]:
    """Asynchronously generate a model response for a prompt plus an image.

    Args:
        prompt: Text prompt sent alongside the image.
        image_base64: Base64-encoded image data.
        image_format: Image format name forwarded to the request builder,
            which uses it in the ``data:image/<format>;base64,...`` URL.
            Defaults to ``"jpeg"``, the subtype that was hard-coded before
            this parameter existed, so callers that still pass only
            ``(prompt, image_base64)`` keep working.

    Returns:
        The ``(content, reasoning_content)`` pair returned by
        ``self._execute_request``.
    """
    # The payload builder calls ``image_format.lower()``; a None/empty value
    # would raise there, so fall back to the historical default instead.
    content, reasoning_content = await self._execute_request(
        endpoint="/chat/completions",
        prompt=prompt,
        image_base64=image_base64,
        image_format=image_format or "jpeg",
    )
    return content, reasoning_content
|
||||
|
||||
|
||||
Reference in New Issue
Block a user