Merge branch 'SengokuCola:debug' into debug
This commit is contained in:
@@ -104,6 +104,12 @@ class BotConfig:
|
||||
memory_ban_words: list = field(
|
||||
default_factory=lambda: ["表情包", "图片", "回复", "聊天记录"]
|
||||
) # 添加新的配置项默认值
|
||||
|
||||
# 是否优先使用当前群组的记忆
|
||||
memory_group_priority: bool = True # 默认开启群组记忆优先
|
||||
|
||||
# 群组记忆私有化
|
||||
memory_private_groups: dict = field(default_factory=dict) # 群组私有记忆配置
|
||||
|
||||
@staticmethod
|
||||
def get_config_dir() -> str:
|
||||
@@ -304,6 +310,12 @@ class BotConfig:
|
||||
config.memory_forget_time = memory_config.get("memory_forget_time", config.memory_forget_time)
|
||||
config.memory_forget_percentage = memory_config.get("memory_forget_percentage", config.memory_forget_percentage)
|
||||
config.memory_compress_rate = memory_config.get("memory_compress_rate", config.memory_compress_rate)
|
||||
# 添加对memory_group_priority配置项的加载
|
||||
config.memory_group_priority = memory_config.get("memory_group_priority", config.memory_group_priority)
|
||||
|
||||
if config.INNER_VERSION in SpecifierSet(">=0.0.9"):
|
||||
# 添加群组记忆私有化配置项的加载
|
||||
config.memory_private_groups = memory_config.get("memory_private_groups", {})
|
||||
|
||||
def mood(parent: dict):
|
||||
mood_config = parent["mood"]
|
||||
|
||||
@@ -6,6 +6,8 @@ import random
|
||||
import time
|
||||
import traceback
|
||||
from typing import Optional, Tuple
|
||||
from PIL import Image
|
||||
import io
|
||||
|
||||
from loguru import logger
|
||||
from nonebot import get_driver
|
||||
@@ -192,11 +194,11 @@ class EmojiManager:
|
||||
logger.error(f"获取标签失败: {str(e)}")
|
||||
return None
|
||||
|
||||
async def _check_emoji(self, image_base64: str) -> str:
|
||||
async def _check_emoji(self, image_base64: str, image_format: str) -> str:
|
||||
try:
|
||||
prompt = f'这是一个表情包,请回答这个表情包是否满足\"{global_config.EMOJI_CHECK_PROMPT}\"的要求,是则回答是,否则回答否,不要出现任何其他内容'
|
||||
|
||||
content, _ = await self.vlm.generate_response_for_image(prompt, image_base64)
|
||||
content, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format)
|
||||
logger.debug(f"输出描述: {content}")
|
||||
return content
|
||||
|
||||
@@ -237,7 +239,7 @@ class EmojiManager:
|
||||
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
|
||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
|
||||
# 检查是否已经注册过
|
||||
existing_emoji = self.db['emoji'].find_one({'filename': filename})
|
||||
description = None
|
||||
@@ -278,7 +280,7 @@ class EmojiManager:
|
||||
|
||||
|
||||
if global_config.EMOJI_CHECK:
|
||||
check = await self._check_emoji(image_base64)
|
||||
check = await self._check_emoji(image_base64, image_format)
|
||||
if '是' not in check:
|
||||
os.remove(image_path)
|
||||
logger.info(f"描述: {description}")
|
||||
|
||||
@@ -91,12 +91,20 @@ class PromptBuilder:
|
||||
memory_prompt = ''
|
||||
start_time = time.time()
|
||||
|
||||
# 获取群组ID
|
||||
group_id = None
|
||||
if stream_id:
|
||||
chat_stream = chat_manager.get_stream(stream_id)
|
||||
if chat_stream and chat_stream.group_info:
|
||||
group_id = chat_stream.group_info.group_id
|
||||
|
||||
# 调用 hippocampus 的 get_relevant_memories 方法
|
||||
relevant_memories = await hippocampus.get_relevant_memories(
|
||||
text=message_txt,
|
||||
max_topics=5,
|
||||
similarity_threshold=0.4,
|
||||
max_memory_num=5
|
||||
max_memory_num=5,
|
||||
group_id=group_id # 传递群组ID
|
||||
)
|
||||
|
||||
if relevant_memories:
|
||||
|
||||
@@ -4,6 +4,8 @@ import time
|
||||
import aiohttp
|
||||
import hashlib
|
||||
from typing import Optional, Union
|
||||
from PIL import Image
|
||||
import io
|
||||
|
||||
from loguru import logger
|
||||
from nonebot import get_driver
|
||||
@@ -119,6 +121,7 @@ class ImageManager:
|
||||
|
||||
# 计算哈希值
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
|
||||
|
||||
# 查重
|
||||
existing = self.db.images.find_one({'hash': image_hash})
|
||||
@@ -127,7 +130,7 @@ class ImageManager:
|
||||
|
||||
# 生成文件名和路径
|
||||
timestamp = int(time.time())
|
||||
filename = f"{timestamp}_{image_hash[:8]}.jpg"
|
||||
filename = f"{timestamp}_{image_hash[:8]}.{image_format}"
|
||||
file_path = os.path.join(self.IMAGE_DIR, filename)
|
||||
|
||||
# 保存文件
|
||||
@@ -238,7 +241,8 @@ class ImageManager:
|
||||
# 计算图片哈希
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
|
||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
|
||||
|
||||
# 查询缓存的描述
|
||||
cached_description = self._get_description_from_db(image_hash, 'emoji')
|
||||
if cached_description:
|
||||
@@ -247,13 +251,13 @@ class ImageManager:
|
||||
|
||||
# 调用AI获取描述
|
||||
prompt = "这是一个表情包,使用中文简洁的描述一下表情包的内容和表情包所表达的情感"
|
||||
description, _ = await self._llm.generate_response_for_image(prompt, image_base64)
|
||||
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
|
||||
|
||||
# 根据配置决定是否保存图片
|
||||
if global_config.EMOJI_SAVE:
|
||||
# 生成文件名和路径
|
||||
timestamp = int(time.time())
|
||||
filename = f"{timestamp}_{image_hash[:8]}.jpg"
|
||||
filename = f"{timestamp}_{image_hash[:8]}.{image_format}"
|
||||
file_path = os.path.join(self.IMAGE_DIR, 'emoji',filename)
|
||||
|
||||
try:
|
||||
@@ -293,7 +297,8 @@ class ImageManager:
|
||||
# 计算图片哈希
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
|
||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
|
||||
|
||||
# 查询缓存的描述
|
||||
cached_description = self._get_description_from_db(image_hash, 'image')
|
||||
if cached_description:
|
||||
@@ -302,7 +307,7 @@ class ImageManager:
|
||||
|
||||
# 调用AI获取描述
|
||||
prompt = "请用中文描述这张图片的内容。如果有文字,请把文字都描述出来。并尝试猜测这个图片的含义。最多200个字。"
|
||||
description, _ = await self._llm.generate_response_for_image(prompt, image_base64)
|
||||
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
|
||||
|
||||
print(f"描述是{description}")
|
||||
|
||||
@@ -314,7 +319,7 @@ class ImageManager:
|
||||
if global_config.EMOJI_SAVE:
|
||||
# 生成文件名和路径
|
||||
timestamp = int(time.time())
|
||||
filename = f"{timestamp}_{image_hash[:8]}.jpg"
|
||||
filename = f"{timestamp}_{image_hash[:8]}.{image_format}"
|
||||
file_path = os.path.join(self.IMAGE_DIR,'image', filename)
|
||||
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user