Merge remote-tracking branch 'upstream/debug' into debug

This commit is contained in:
tcmofashi
2025-03-12 17:14:49 +08:00
7 changed files with 230 additions and 35 deletions

View File

@@ -104,6 +104,12 @@ class BotConfig:
memory_ban_words: list = field(
default_factory=lambda: ["表情包", "图片", "回复", "聊天记录"]
) # 添加新的配置项默认值
# 是否优先使用当前群组的记忆
memory_group_priority: bool = True # 默认开启群组记忆优先
# 群组记忆私有化
memory_private_groups: dict = field(default_factory=dict) # 群组私有记忆配置
@staticmethod
def get_config_dir() -> str:
@@ -304,6 +310,12 @@ class BotConfig:
config.memory_forget_time = memory_config.get("memory_forget_time", config.memory_forget_time)
config.memory_forget_percentage = memory_config.get("memory_forget_percentage", config.memory_forget_percentage)
config.memory_compress_rate = memory_config.get("memory_compress_rate", config.memory_compress_rate)
# 添加对memory_group_priority配置项的加载
config.memory_group_priority = memory_config.get("memory_group_priority", config.memory_group_priority)
if config.INNER_VERSION in SpecifierSet(">=0.0.9"):
# 添加群组记忆私有化配置项的加载
config.memory_private_groups = memory_config.get("memory_private_groups", {})
def mood(parent: dict):
mood_config = parent["mood"]

View File

@@ -6,6 +6,8 @@ import random
import time
import traceback
from typing import Optional, Tuple
from PIL import Image
import io
from loguru import logger
from nonebot import get_driver
@@ -192,11 +194,11 @@ class EmojiManager:
logger.error(f"获取标签失败: {str(e)}")
return None
async def _check_emoji(self, image_base64: str) -> str:
async def _check_emoji(self, image_base64: str, image_format: str) -> str:
try:
prompt = f'这是一个表情包,请回答这个表情包是否满足\"{global_config.EMOJI_CHECK_PROMPT}\"的要求,是则回答是,否则回答否,不要出现任何其他内容'
content, _ = await self.vlm.generate_response_for_image(prompt, image_base64)
content, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format)
logger.debug(f"输出描述: {content}")
return content
@@ -237,7 +239,7 @@ class EmojiManager:
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
# 检查是否已经注册过
existing_emoji = self.db['emoji'].find_one({'filename': filename})
description = None
@@ -278,7 +280,7 @@ class EmojiManager:
if global_config.EMOJI_CHECK:
check = await self._check_emoji(image_base64)
check = await self._check_emoji(image_base64, image_format)
if '' not in check:
os.remove(image_path)
logger.info(f"描述: {description}")

View File

@@ -91,12 +91,20 @@ class PromptBuilder:
memory_prompt = ''
start_time = time.time()
# 获取群组ID
group_id = None
if stream_id:
chat_stream = chat_manager.get_stream(stream_id)
if chat_stream and chat_stream.group_info:
group_id = chat_stream.group_info.group_id
# 调用 hippocampus 的 get_relevant_memories 方法
relevant_memories = await hippocampus.get_relevant_memories(
text=message_txt,
max_topics=5,
similarity_threshold=0.4,
max_memory_num=5
max_memory_num=5,
group_id=group_id # 传递群组ID
)
if relevant_memories:

View File

@@ -4,6 +4,8 @@ import time
import aiohttp
import hashlib
from typing import Optional, Union
from PIL import Image
import io
from loguru import logger
from nonebot import get_driver
@@ -119,6 +121,7 @@ class ImageManager:
# 计算哈希值
image_hash = hashlib.md5(image_bytes).hexdigest()
image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
# 查重
existing = self.db.images.find_one({'hash': image_hash})
@@ -127,7 +130,7 @@ class ImageManager:
# 生成文件名和路径
timestamp = int(time.time())
filename = f"{timestamp}_{image_hash[:8]}.jpg"
filename = f"{timestamp}_{image_hash[:8]}.{image_format}"
file_path = os.path.join(self.IMAGE_DIR, filename)
# 保存文件
@@ -238,7 +241,8 @@ class ImageManager:
# 计算图片哈希
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
# 查询缓存的描述
cached_description = self._get_description_from_db(image_hash, 'emoji')
if cached_description:
@@ -247,13 +251,13 @@ class ImageManager:
# 调用AI获取描述
prompt = "这是一个表情包,使用中文简洁的描述一下表情包的内容和表情包所表达的情感"
description, _ = await self._llm.generate_response_for_image(prompt, image_base64)
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
# 根据配置决定是否保存图片
if global_config.EMOJI_SAVE:
# 生成文件名和路径
timestamp = int(time.time())
filename = f"{timestamp}_{image_hash[:8]}.jpg"
filename = f"{timestamp}_{image_hash[:8]}.{image_format}"
file_path = os.path.join(self.IMAGE_DIR, 'emoji',filename)
try:
@@ -293,7 +297,8 @@ class ImageManager:
# 计算图片哈希
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
# 查询缓存的描述
cached_description = self._get_description_from_db(image_hash, 'image')
if cached_description:
@@ -302,7 +307,7 @@ class ImageManager:
# 调用AI获取描述
prompt = "请用中文描述这张图片的内容。如果有文字请把文字都描述出来。并尝试猜测这个图片的含义。最多200个字。"
description, _ = await self._llm.generate_response_for_image(prompt, image_base64)
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
print(f"描述是{description}")
@@ -314,7 +319,7 @@ class ImageManager:
if global_config.EMOJI_SAVE:
# 生成文件名和路径
timestamp = int(time.time())
filename = f"{timestamp}_{image_hash[:8]}.jpg"
filename = f"{timestamp}_{image_hash[:8]}.{image_format}"
file_path = os.path.join(self.IMAGE_DIR,'image', filename)
try:

View File

@@ -4,6 +4,7 @@ import math
import random
import time
import os
from typing import Optional
import jieba
import networkx as nx
@@ -209,15 +210,31 @@ class Hippocampus:
return chat_samples
async def memory_compress(self, messages: list, compress_rate=0.1):
async def memory_compress(self, messages: list, compress_rate=0.1, group_id=None):
"""压缩消息记录为记忆
Args:
messages: 消息记录列表
compress_rate: 压缩率
group_id: 群组ID用于标记记忆来源
Returns:
tuple: (压缩记忆集合, 相似主题字典)
"""
from ..chat.config import global_config # 导入配置
if not messages:
return set(), {}
# 确定记忆所属的群组
memory_group_tag = None
if group_id is not None:
# 查找群聊所属的群组
for group_name, group_ids in global_config.memory_private_groups.items():
if str(group_id) in group_ids:
memory_group_tag = f"[群组:{group_name}]"
break
# 合并消息文本,同时保留时间信息
input_text = ""
time_info = ""
@@ -267,7 +284,17 @@ class Hippocampus:
for topic, task in tasks:
response = await task
if response:
compressed_memory.add((topic, response[0]))
memory_content = response[0]
# 添加标记
# 优先使用群组标记
if memory_group_tag:
memory_content = f"{memory_group_tag}{memory_content}"
# 如果没有群组标记但有群组ID添加简单的群组ID标记
elif group_id is not None:
memory_content = f"[群组:{group_id}]{memory_content}"
compressed_memory.add((topic, memory_content))
# 为每个话题查找相似的已存在主题
existing_topics = list(self.memory_graph.G.nodes())
similar_topics = []
@@ -316,7 +343,17 @@ class Hippocampus:
logger.debug(f"进度: [{bar}] {progress:.1f}% ({i}/{len(memory_samples)})")
compress_rate = global_config.memory_compress_rate
compressed_memory, similar_topics_dict = await self.memory_compress(messages, compress_rate)
# 尝试从消息中提取群组ID
group_id = None
if messages and len(messages) > 0:
first_msg = messages[0]
if 'group_id' in first_msg:
group_id = first_msg['group_id']
logger.info(f"检测到消息来自群组: {group_id}")
# 传递群组ID到memory_compress
compressed_memory, similar_topics_dict = await self.memory_compress(messages, compress_rate, group_id)
logger.info(f"压缩后记忆数量: {len(compressed_memory)},似曾相识的话题: {len(similar_topics_dict)}")
current_time = datetime.datetime.now().timestamp()
@@ -841,8 +878,21 @@ class Hippocampus:
return activation
async def get_relevant_memories(self, text: str, max_topics: int = 5, similarity_threshold: float = 0.4,
max_memory_num: int = 5) -> list:
"""根据输入文本获取相关的记忆内容"""
max_memory_num: int = 5, group_id: Optional[int] = None) -> list:
"""根据输入文本获取相关的记忆内容
Args:
text: 输入文本
max_topics: 最大主题数
similarity_threshold: 相似度阈值
max_memory_num: 最大记忆数量
group_id: 群组ID用于优先匹配当前群组的记忆
Returns:
list: 相关记忆列表
"""
from ..chat.config import global_config # 导入配置
# 识别主题
identified_topics = await self._identify_topics(text)
@@ -855,30 +905,134 @@ class Hippocampus:
# 获取最相关的主题
relevant_topics = self._get_top_topics(all_similar_topics, max_topics)
# 确定记忆所属的群组
current_group_name = None
if group_id is not None:
# 查找群聊所属的群组
for group_name, group_ids in global_config.memory_private_groups.items():
if str(group_id) in group_ids:
current_group_name = group_name
break
has_private_groups = len(global_config.memory_private_groups) > 0
# 获取相关记忆内容
relevant_memories = []
group_related_memories = [] # 当前群聊的记忆
group_definition_memories = [] # 当前群组的记忆
public_memories = [] # 公共记忆
for topic, score in relevant_topics:
# 获取该主题的记忆内容
first_layer, _ = self.memory_graph.get_related_item(topic, depth=1)
if first_layer:
# 如果记忆条数超过限制,随机选择指定数量的记忆
if len(first_layer) > max_memory_num / 2:
first_layer = random.sample(first_layer, max_memory_num // 2)
if len(first_layer) > max_memory_num:
first_layer = random.sample(first_layer, max_memory_num)
# 为每条记忆添加来源主题和相似度信息
for memory in first_layer:
relevant_memories.append({
memory_info = {
'topic': topic,
'similarity': score,
'content': memory
})
}
memory_text = str(memory)
# 分类处理记忆
if has_private_groups and group_id is not None:
# 如果配置了私有群组且当前在群聊中
if current_group_name:
# 当前群聊属于某个群组
if f"[群组:{current_group_name}]" in memory_text:
# 当前群组的记忆
group_definition_memories.append(memory_info)
elif not any(f"[群组:" in memory_text for _ in range(1)):
# 公共记忆
public_memories.append(memory_info)
else:
# 当前群聊不属于任何群组
if f"[群组:{group_id}]" in memory_text:
# 当前群聊的特定记忆
group_related_memories.append(memory_info)
elif not any(f"[群组:" in memory_text for _ in range(1)):
# 公共记忆
public_memories.append(memory_info)
elif global_config.memory_group_priority and group_id is not None:
# 如果只启用了群组记忆优先
if f"[群组:{group_id}]" in memory_text:
# 当前群聊的记忆,放入群组相关记忆列表
group_related_memories.append(memory_info)
else:
# 其他记忆,放入公共记忆列表
public_memories.append(memory_info)
else:
# 如果没有特殊配置,所有记忆都放入相关记忆列表
relevant_memories.append(memory_info)
# 如果记忆数量超过5个,随机选择5个
# 按相似度排序
relevant_memories.sort(key=lambda x: x['similarity'], reverse=True)
if len(relevant_memories) > max_memory_num:
relevant_memories = random.sample(relevant_memories, max_memory_num)
# 根据配置决定如何组合记忆
if has_private_groups and group_id is not None:
# 配置了私有群组且当前在群聊中
if current_group_name:
# 当前群聊属于某个群组
# 优先使用当前群组的记忆,如果不足再使用公共记忆
if len(group_definition_memories) >= max_memory_num:
# 如果群组记忆足够,只使用群组记忆
group_definition_memories.sort(key=lambda x: x['similarity'], reverse=True)
relevant_memories = group_definition_memories[:max_memory_num]
else:
# 如果群组记忆不足,添加公共记忆
group_definition_memories.sort(key=lambda x: x['similarity'], reverse=True)
public_memories.sort(key=lambda x: x['similarity'], reverse=True)
relevant_memories = group_definition_memories.copy()
remaining_count = max_memory_num - len(relevant_memories)
if remaining_count > 0 and public_memories:
selected_other = public_memories[:remaining_count]
relevant_memories.extend(selected_other)
else:
# 当前群聊不属于任何群组
# 优先使用当前群聊的记忆,然后使用公共记忆
if len(group_related_memories) >= max_memory_num:
# 如果当前群聊记忆足够,只使用当前群聊记忆
group_related_memories.sort(key=lambda x: x['similarity'], reverse=True)
relevant_memories = group_related_memories[:max_memory_num]
else:
# 如果当前群聊记忆不足,添加公共记忆
group_related_memories.sort(key=lambda x: x['similarity'], reverse=True)
public_memories.sort(key=lambda x: x['similarity'], reverse=True)
relevant_memories = group_related_memories.copy()
remaining_count = max_memory_num - len(relevant_memories)
if remaining_count > 0 and public_memories:
selected_other = public_memories[:remaining_count]
relevant_memories.extend(selected_other)
elif global_config.memory_group_priority and group_id is not None:
# 如果只启用了群组记忆优先
# 按相似度排序
group_related_memories.sort(key=lambda x: x['similarity'], reverse=True)
public_memories.sort(key=lambda x: x['similarity'], reverse=True)
# 优先使用群组相关记忆,如果不足再使用其他记忆
if len(group_related_memories) >= max_memory_num:
# 如果群组相关记忆足够,只使用群组相关记忆
relevant_memories = group_related_memories[:max_memory_num]
else:
# 使用所有群组相关记忆
relevant_memories = group_related_memories.copy()
# 如果群组相关记忆不足,添加其他记忆
remaining_count = max_memory_num - len(relevant_memories)
if remaining_count > 0 and public_memories:
# 从其他记忆中选择剩余需要的数量
selected_other = public_memories[:remaining_count]
relevant_memories.extend(selected_other)
else:
# 如果没有特殊配置,按相似度排序
relevant_memories.sort(key=lambda x: x['similarity'], reverse=True)
if len(relevant_memories) > max_memory_num:
relevant_memories = relevant_memories[:max_memory_num]
return relevant_memories

View File

@@ -104,6 +104,7 @@ class LLM_request:
endpoint: str,
prompt: str = None,
image_base64: str = None,
image_format: str = None,
payload: dict = None,
retry_policy: dict = None,
response_handler: callable = None,
@@ -115,6 +116,7 @@ class LLM_request:
endpoint: API端点路径 (如 "chat/completions")
prompt: prompt文本
image_base64: 图片的base64编码
image_format: 图片格式
payload: 请求体数据
retry_policy: 自定义重试策略
response_handler: 自定义响应处理器
@@ -151,7 +153,7 @@ class LLM_request:
# 构建请求体
if image_base64:
payload = await self._build_payload(prompt, image_base64)
payload = await self._build_payload(prompt, image_base64, image_format)
elif payload is None:
payload = await self._build_payload(prompt)
@@ -172,7 +174,7 @@ class LLM_request:
if response.status == 413:
logger.warning("请求体过大,尝试压缩...")
image_base64 = compress_base64_image_by_scale(image_base64)
payload = await self._build_payload(prompt, image_base64)
payload = await self._build_payload(prompt, image_base64, image_format)
elif response.status in [500, 503]:
logger.error(f"错误码: {response.status} - {error_code_mapping.get(response.status)}")
raise RuntimeError("服务器负载过高模型恢复失败QAQ")
@@ -294,7 +296,7 @@ class LLM_request:
new_params["max_completion_tokens"] = new_params.pop("max_tokens")
return new_params
async def _build_payload(self, prompt: str, image_base64: str = None) -> dict:
async def _build_payload(self, prompt: str, image_base64: str = None, image_format: str = None) -> dict:
"""构建请求体"""
# 复制一份参数,避免直接修改 self.params
params_copy = await self._transform_parameters(self.params)
@@ -306,7 +308,7 @@ class LLM_request:
"role": "user",
"content": [
{"type": "text", "text": prompt},
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}}
{"type": "image_url", "image_url": {"url": f"data:image/{image_format.lower()};base64,{image_base64}"}}
]
}
],
@@ -391,13 +393,14 @@ class LLM_request:
)
return content, reasoning_content
async def generate_response_for_image(self, prompt: str, image_base64: str, image_format: str) -> Tuple[str, str]:
    """Asynchronously generate a model response for a prompt plus an image.

    Args:
        prompt: Text prompt sent alongside the image.
        image_base64: Base64-encoded image payload.
        image_format: Image format name (e.g. "jpeg", "png") — forwarded so the
            request builder can emit a correct ``data:image/<fmt>`` URL instead
            of hard-coding jpeg.

    Returns:
        Tuple[str, str]: ``(content, reasoning_content)`` as produced by the
        underlying request helper.
    """
    # Delegate to the shared request pipeline; it handles payload building,
    # retries, and response parsing.
    content, reasoning_content = await self._execute_request(
        endpoint="/chat/completions",
        prompt=prompt,
        image_base64=image_base64,
        image_format=image_format
    )
    return content, reasoning_content