v0.3.2 更改了.env config的逻辑和memory优化
v0.3.2 更改了.env config的逻辑 memory优化 读空气优化
This commit is contained in:
@@ -11,16 +11,18 @@ from .relationship_manager import relationship_manager
|
||||
from ..schedule.schedule_generator import bot_schedule
|
||||
from .willing_manager import willing_manager
|
||||
|
||||
|
||||
# 获取驱动器
|
||||
driver = get_driver()
|
||||
config = driver.config
|
||||
|
||||
Database.initialize(
|
||||
host= os.getenv("MONGODB_HOST"),
|
||||
port= int(os.getenv("MONGODB_PORT")),
|
||||
db_name= os.getenv("DATABASE_NAME"),
|
||||
username= os.getenv("MONGODB_USERNAME"),
|
||||
password= os.getenv("MONGODB_PASSWORD"),
|
||||
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
|
||||
host= config.mongodb_host,
|
||||
port= int(config.mongodb_port),
|
||||
db_name= config.database_name,
|
||||
username= config.mongodb_username,
|
||||
password= config.mongodb_password,
|
||||
auth_source= config.mongodb_auth_source
|
||||
)
|
||||
print("\033[1;32m[初始化数据库完成]\033[0m")
|
||||
|
||||
@@ -37,7 +39,7 @@ emoji_manager.initialize()
|
||||
|
||||
print(f"\033[1;32m正在唤醒{global_config.BOT_NICKNAME}......\033[0m")
|
||||
# 创建机器人实例
|
||||
chat_bot = ChatBot(global_config)
|
||||
chat_bot = ChatBot()
|
||||
# 注册消息处理器
|
||||
group_msg = on_message()
|
||||
# 创建定时任务
|
||||
|
||||
@@ -18,10 +18,9 @@ from .utils import is_mentioned_bot_in_txt, calculate_typing_time
|
||||
from ..memory_system.memory import memory_graph
|
||||
|
||||
class ChatBot:
|
||||
def __init__(self, config: BotConfig):
|
||||
self.config = config
|
||||
def __init__(self):
|
||||
self.storage = MessageStorage()
|
||||
self.gpt = LLMResponseGenerator(config)
|
||||
self.gpt = LLMResponseGenerator()
|
||||
self.bot = None # bot 实例引用
|
||||
self._started = False
|
||||
|
||||
@@ -39,11 +38,11 @@ class ChatBot:
|
||||
async def handle_message(self, event: GroupMessageEvent, bot: Bot) -> None:
|
||||
"""处理收到的群消息"""
|
||||
|
||||
if event.group_id not in self.config.talk_allowed_groups:
|
||||
if event.group_id not in global_config.talk_allowed_groups:
|
||||
return
|
||||
self.bot = bot # 更新 bot 实例
|
||||
|
||||
if event.user_id in self.config.ban_user_id:
|
||||
if event.user_id in global_config.ban_user_id:
|
||||
return
|
||||
|
||||
# 打印原始消息内容
|
||||
@@ -120,7 +119,7 @@ class ChatBot:
|
||||
event.group_id,
|
||||
topic[0] if topic else None,
|
||||
is_mentioned,
|
||||
self.config,
|
||||
global_config,
|
||||
event.user_id,
|
||||
message.is_emoji,
|
||||
interested_rate
|
||||
@@ -143,10 +142,14 @@ class ChatBot:
|
||||
response, emotion = await self.gpt.generate_response(message)
|
||||
|
||||
# 如果生成了回复,发送并记录
|
||||
|
||||
|
||||
'''
|
||||
生成回复后的内容
|
||||
|
||||
'''
|
||||
|
||||
if response:
|
||||
message_set = MessageSet(event.group_id, self.config.BOT_QQ, think_id)
|
||||
message_set = MessageSet(event.group_id, global_config.BOT_QQ, think_id)
|
||||
accu_typing_time = 0
|
||||
for msg in response:
|
||||
print(f"当前消息: {msg}")
|
||||
@@ -157,7 +160,7 @@ class ChatBot:
|
||||
|
||||
bot_message = Message(
|
||||
group_id=event.group_id,
|
||||
user_id=self.config.BOT_QQ,
|
||||
user_id=global_config.BOT_QQ,
|
||||
message_id=think_id,
|
||||
message_based_id=event.message_id,
|
||||
raw_message=msg,
|
||||
@@ -174,7 +177,7 @@ class ChatBot:
|
||||
|
||||
|
||||
bot_response_time = tinking_time_point
|
||||
if random() < self.config.emoji_chance:
|
||||
if random() < global_config.emoji_chance:
|
||||
emoji_path = await emoji_manager.get_emoji_for_emotion(emotion)
|
||||
if emoji_path:
|
||||
emoji_cq = CQCode.create_emoji_cq(emoji_path)
|
||||
@@ -186,7 +189,7 @@ class ChatBot:
|
||||
|
||||
bot_message = Message(
|
||||
group_id=event.group_id,
|
||||
user_id=self.config.BOT_QQ,
|
||||
user_id=global_config.BOT_QQ,
|
||||
message_id=0,
|
||||
raw_message=emoji_cq,
|
||||
plain_text=emoji_cq,
|
||||
|
||||
@@ -7,6 +7,7 @@ import configparser
|
||||
import tomli
|
||||
import sys
|
||||
from loguru import logger
|
||||
from nonebot import get_driver
|
||||
|
||||
|
||||
|
||||
@@ -111,7 +112,6 @@ class BotConfig:
|
||||
# 获取配置文件路径
|
||||
bot_config_path = BotConfig.get_default_config_path()
|
||||
config_dir = os.path.dirname(bot_config_path)
|
||||
env_path = os.path.join(config_dir, '.env')
|
||||
|
||||
logger.info(f"尝试从 {bot_config_path} 加载机器人配置")
|
||||
global_config = BotConfig.load_config(config_path=bot_config_path)
|
||||
@@ -126,10 +126,11 @@ class LLMConfig:
|
||||
DEEP_SEEK_BASE_URL: str = None
|
||||
|
||||
llm_config = LLMConfig()
|
||||
llm_config.SILICONFLOW_API_KEY = os.getenv('SILICONFLOW_KEY')
|
||||
llm_config.SILICONFLOW_BASE_URL = os.getenv('SILICONFLOW_BASE_URL')
|
||||
llm_config.DEEP_SEEK_API_KEY = os.getenv('DEEP_SEEK_KEY')
|
||||
llm_config.DEEP_SEEK_BASE_URL = os.getenv('DEEP_SEEK_BASE_URL')
|
||||
config = get_driver().config
|
||||
llm_config.SILICONFLOW_API_KEY = config.siliconflow_key
|
||||
llm_config.SILICONFLOW_BASE_URL = config.siliconflow_base_url
|
||||
llm_config.DEEP_SEEK_API_KEY = config.deep_seek_key
|
||||
llm_config.DEEP_SEEK_BASE_URL = config.deep_seek_base_url
|
||||
|
||||
|
||||
if not global_config.enable_advance_output:
|
||||
|
||||
@@ -7,7 +7,7 @@ from PIL import Image
|
||||
import os
|
||||
from random import random
|
||||
from nonebot.adapters.onebot.v11 import Bot
|
||||
from .config import global_config, llm_config
|
||||
from .config import global_config
|
||||
import time
|
||||
import asyncio
|
||||
from .utils_image import storage_image,storage_emoji
|
||||
@@ -16,6 +16,10 @@ from .utils_user import get_user_nickname
|
||||
#包含CQ码类
|
||||
import urllib3
|
||||
from urllib3.util import create_urllib3_context
|
||||
from nonebot import get_driver
|
||||
|
||||
driver = get_driver()
|
||||
config = driver.config
|
||||
|
||||
# TLS1.3特殊处理 https://github.com/psf/requests/issues/6616
|
||||
ctx = create_urllib3_context()
|
||||
@@ -179,7 +183,7 @@ class CQCode:
|
||||
"""调用AI接口获取表情包描述"""
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}"
|
||||
"Authorization": f"Bearer {config.siliconflow_key}"
|
||||
}
|
||||
|
||||
payload = {
|
||||
@@ -206,7 +210,7 @@ class CQCode:
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
f"{llm_config.SILICONFLOW_BASE_URL}chat/completions",
|
||||
f"{config.siliconflow_base_url}chat/completions",
|
||||
headers=headers,
|
||||
json=payload,
|
||||
timeout=30
|
||||
@@ -224,7 +228,7 @@ class CQCode:
|
||||
"""调用AI接口获取普通图片描述"""
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}"
|
||||
"Authorization": f"Bearer {config.siliconflow_key}"
|
||||
}
|
||||
|
||||
payload = {
|
||||
@@ -251,7 +255,7 @@ class CQCode:
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
f"{llm_config.SILICONFLOW_BASE_URL}chat/completions",
|
||||
f"{config.siliconflow_base_url}chat/completions",
|
||||
headers=headers,
|
||||
json=payload,
|
||||
timeout=30
|
||||
|
||||
@@ -10,10 +10,14 @@ import hashlib
|
||||
from datetime import datetime
|
||||
import base64
|
||||
import shutil
|
||||
from .config import global_config, llm_config
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from nonebot import get_driver
|
||||
|
||||
driver = get_driver()
|
||||
config = driver.config
|
||||
|
||||
|
||||
class EmojiManager:
|
||||
_instance = None
|
||||
@@ -93,7 +97,7 @@ class EmojiManager:
|
||||
# 准备请求数据
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}"
|
||||
"Authorization": f"Bearer {config.siliconflow_key}"
|
||||
}
|
||||
|
||||
payload = {
|
||||
@@ -115,7 +119,7 @@ class EmojiManager:
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{llm_config.SILICONFLOW_BASE_URL}chat/completions",
|
||||
f"{config.siliconflow_base_url}chat/completions",
|
||||
headers=headers,
|
||||
json=payload
|
||||
) as response:
|
||||
@@ -249,7 +253,7 @@ class EmojiManager:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}"
|
||||
"Authorization": f"Bearer {config.siliconflow_key}"
|
||||
}
|
||||
|
||||
payload = {
|
||||
@@ -276,7 +280,7 @@ class EmojiManager:
|
||||
}
|
||||
|
||||
async with session.post(
|
||||
f"{llm_config.SILICONFLOW_BASE_URL}chat/completions",
|
||||
f"{config.siliconflow_base_url}chat/completions",
|
||||
headers=headers,
|
||||
json=payload
|
||||
) as response:
|
||||
|
||||
@@ -1,34 +1,34 @@
|
||||
from typing import Dict, Any, List, Optional, Union, Tuple
|
||||
from openai import OpenAI
|
||||
import asyncio
|
||||
import requests
|
||||
from functools import partial
|
||||
from .message import Message
|
||||
from .config import BotConfig, global_config
|
||||
from .config import global_config
|
||||
from ...common.database import Database
|
||||
import random
|
||||
import time
|
||||
import os
|
||||
import numpy as np
|
||||
from .relationship_manager import relationship_manager
|
||||
from ..schedule.schedule_generator import bot_schedule
|
||||
from .prompt_builder import prompt_builder
|
||||
from .config import llm_config, global_config
|
||||
from .config import global_config
|
||||
from .utils import process_llm_response
|
||||
from nonebot import get_driver
|
||||
|
||||
driver = get_driver()
|
||||
config = driver.config
|
||||
|
||||
|
||||
class LLMResponseGenerator:
|
||||
def __init__(self, config: BotConfig):
|
||||
self.config = config
|
||||
if self.config.API_USING == "siliconflow":
|
||||
def __init__(self):
|
||||
if global_config.API_USING == "siliconflow":
|
||||
self.client = OpenAI(
|
||||
api_key=llm_config.SILICONFLOW_API_KEY,
|
||||
base_url=llm_config.SILICONFLOW_BASE_URL
|
||||
api_key=config.siliconflow_key,
|
||||
base_url=config.siliconflow_base_url
|
||||
)
|
||||
elif self.config.API_USING == "deepseek":
|
||||
elif global_config.API_USING == "deepseek":
|
||||
self.client = OpenAI(
|
||||
api_key=llm_config.DEEP_SEEK_API_KEY,
|
||||
base_url=llm_config.DEEP_SEEK_BASE_URL
|
||||
api_key=config.deep_seek_key,
|
||||
base_url=config.deep_seek_base_url
|
||||
)
|
||||
|
||||
self.db = Database.get_instance()
|
||||
@@ -52,6 +52,7 @@ class LLMResponseGenerator:
|
||||
else:
|
||||
self.current_model_type = 'r1_distill' # 默认使用 R1-Distill
|
||||
|
||||
|
||||
print(f"+++++++++++++++++{global_config.BOT_NICKNAME}{self.current_model_type}思考中+++++++++++++++++")
|
||||
if self.current_model_type == 'r1':
|
||||
model_response = await self._generate_r1_response(message)
|
||||
@@ -83,8 +84,9 @@ class LLMResponseGenerator:
|
||||
print(f"\033[1;32m[关系管理]\033[0m 回复中_当前关系值: {relationship_value}")
|
||||
else:
|
||||
relationship_value = 0.0
|
||||
|
||||
|
||||
# 构建prompt
|
||||
''' 构建prompt '''
|
||||
prompt,prompt_check = prompt_builder._build_prompt(
|
||||
message_txt=message.processed_plain_text,
|
||||
sender_name=sender_name,
|
||||
@@ -92,6 +94,7 @@ class LLMResponseGenerator:
|
||||
group_id=message.group_id
|
||||
)
|
||||
|
||||
|
||||
# 设置默认参数
|
||||
default_params = {
|
||||
"model": model_name,
|
||||
@@ -113,6 +116,7 @@ class LLMResponseGenerator:
|
||||
if model_params:
|
||||
default_params.update(model_params)
|
||||
|
||||
|
||||
def create_completion():
|
||||
return self.client.chat.completions.create(**default_params)
|
||||
|
||||
@@ -122,6 +126,7 @@ class LLMResponseGenerator:
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
# 读空气模块
|
||||
air = 0
|
||||
reasoning_content_check=''
|
||||
content_check=''
|
||||
if global_config.enable_kuuki_read:
|
||||
@@ -135,21 +140,26 @@ class LLMResponseGenerator:
|
||||
content_check = response_check.choices[0].message.content
|
||||
print(f"\033[1;32m[读空气]\033[0m 读空气结果为{content_check}")
|
||||
if 'yes' not in content_check.lower():
|
||||
self.db.db.reasoning_logs.insert_one({
|
||||
'time': time.time(),
|
||||
'group_id': message.group_id,
|
||||
'user': sender_name,
|
||||
'message': message.processed_plain_text,
|
||||
'model': model_name,
|
||||
'reasoning_check': reasoning_content_check,
|
||||
'response_check': content_check,
|
||||
'reasoning': "",
|
||||
'response': "",
|
||||
'prompt': prompt,
|
||||
'prompt_check': prompt_check,
|
||||
'model_params': default_params
|
||||
})
|
||||
return None
|
||||
air = 1
|
||||
#稀释读空气的判定
|
||||
if air == 1 and random.random() < 0.3:
|
||||
self.db.db.reasoning_logs.insert_one({
|
||||
'time': time.time(),
|
||||
'group_id': message.group_id,
|
||||
'user': sender_name,
|
||||
'message': message.processed_plain_text,
|
||||
'model': model_name,
|
||||
'reasoning_check': reasoning_content_check,
|
||||
'response_check': content_check,
|
||||
'reasoning': "",
|
||||
'response': "",
|
||||
'prompt': prompt,
|
||||
'prompt_check': prompt_check,
|
||||
'model_params': default_params
|
||||
})
|
||||
return None
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -193,7 +203,7 @@ class LLMResponseGenerator:
|
||||
|
||||
async def _generate_r1_response(self, message: Message) -> Optional[str]:
|
||||
"""使用 DeepSeek-R1 模型生成回复"""
|
||||
if self.config.API_USING == "deepseek":
|
||||
if global_config.API_USING == "deepseek":
|
||||
return await self._generate_base_response(
|
||||
message,
|
||||
"deepseek-reasoner",
|
||||
@@ -208,7 +218,7 @@ class LLMResponseGenerator:
|
||||
|
||||
async def _generate_v3_response(self, message: Message) -> Optional[str]:
|
||||
"""使用 DeepSeek-V3 模型生成回复"""
|
||||
if self.config.API_USING == "deepseek":
|
||||
if global_config.API_USING == "deepseek":
|
||||
return await self._generate_base_response(
|
||||
message,
|
||||
"deepseek-chat",
|
||||
@@ -259,7 +269,7 @@ class LLMResponseGenerator:
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
if self.config.API_USING == "deepseek":
|
||||
if global_config.API_USING == "deepseek":
|
||||
model = "deepseek-chat"
|
||||
else:
|
||||
model = "Pro/deepseek-ai/DeepSeek-V3"
|
||||
@@ -296,4 +306,4 @@ class LLMResponseGenerator:
|
||||
return processed_response, emotion_tags
|
||||
|
||||
# 创建全局实例
|
||||
llm_response = LLMResponseGenerator(global_config)
|
||||
llm_response = LLMResponseGenerator()
|
||||
@@ -66,12 +66,15 @@ class PromptBuilder:
|
||||
overlapping_second_layer.update(overlap)
|
||||
|
||||
# 合并所有需要的记忆
|
||||
if all_first_layer_items:
|
||||
print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆1: {all_first_layer_items}")
|
||||
if overlapping_second_layer:
|
||||
print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆2: {list(overlapping_second_layer)}")
|
||||
# if all_first_layer_items:
|
||||
# print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆1: {all_first_layer_items}")
|
||||
# if overlapping_second_layer:
|
||||
# print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆2: {list(overlapping_second_layer)}")
|
||||
|
||||
all_memories = all_first_layer_items + list(overlapping_second_layer)
|
||||
# 使用集合去重
|
||||
all_memories = list(set(all_first_layer_items) | set(overlapping_second_layer))
|
||||
if all_memories:
|
||||
print(f"\033[1;32m[前额叶]\033[0m 合并所有需要的记忆: {all_memories}")
|
||||
|
||||
if all_memories: # 只在列表非空时选择随机项
|
||||
random_item = choice(all_memories)
|
||||
@@ -179,7 +182,11 @@ class PromptBuilder:
|
||||
# prompt += f"{activate_prompt}\n"
|
||||
prompt += f"{prompt_personality}\n"
|
||||
prompt += f"{prompt_ger}\n"
|
||||
prompt += f"{extra_info}\n"
|
||||
prompt += f"{extra_info}\n"
|
||||
|
||||
|
||||
|
||||
'''读空气prompt处理'''
|
||||
|
||||
activate_prompt_check=f"以上是群里正在进行的聊天,昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和他{relation_prompt},你想要{relation_prompt_2},但是这不一定是合适的时机,请你决定是否要回应这条消息。"
|
||||
prompt_personality_check = ''
|
||||
|
||||
@@ -1,14 +1,17 @@
|
||||
from typing import Optional, Dict, List
|
||||
from openai import OpenAI
|
||||
from .message import Message
|
||||
from .config import global_config, llm_config
|
||||
import jieba
|
||||
from nonebot import get_driver
|
||||
|
||||
driver = get_driver()
|
||||
config = driver.config
|
||||
|
||||
class TopicIdentifier:
|
||||
def __init__(self):
|
||||
self.client = OpenAI(
|
||||
api_key=llm_config.SILICONFLOW_API_KEY,
|
||||
base_url=llm_config.SILICONFLOW_BASE_URL
|
||||
api_key=config.siliconflow_key,
|
||||
base_url=config.siliconflow_base_url
|
||||
)
|
||||
|
||||
def identify_topic_llm(self, text: str) -> Optional[str]:
|
||||
|
||||
@@ -4,11 +4,15 @@ from typing import List
|
||||
from .message import Message
|
||||
import requests
|
||||
import numpy as np
|
||||
from .config import llm_config, global_config
|
||||
from .config import global_config
|
||||
import re
|
||||
from typing import Dict
|
||||
from collections import Counter
|
||||
import math
|
||||
from nonebot import get_driver
|
||||
|
||||
driver = get_driver()
|
||||
config = driver.config
|
||||
|
||||
|
||||
def combine_messages(messages: List[Message]) -> str:
|
||||
@@ -64,7 +68,7 @@ def get_embedding(text):
|
||||
"encoding_format": "float"
|
||||
}
|
||||
headers = {
|
||||
"Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}",
|
||||
"Authorization": f"Bearer {config.siliconflow_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
|
||||
@@ -7,6 +7,10 @@ from ...common.database import Database
|
||||
import zlib # 用于 CRC32
|
||||
import base64
|
||||
from .config import global_config
|
||||
from nonebot import get_driver
|
||||
|
||||
driver = get_driver()
|
||||
config = driver.config
|
||||
|
||||
|
||||
def storage_image(image_data: bytes,type: str, max_size: int = 200) -> bytes:
|
||||
@@ -37,12 +41,12 @@ def storage_compress_image(image_data: bytes, max_size: int = 200) -> bytes:
|
||||
|
||||
# 连接数据库
|
||||
db = Database(
|
||||
host= os.getenv("MONGODB_HOST"),
|
||||
port= int(os.getenv("MONGODB_PORT")),
|
||||
db_name= os.getenv("DATABASE_NAME"),
|
||||
username= os.getenv("MONGODB_USERNAME"),
|
||||
password= os.getenv("MONGODB_PASSWORD"),
|
||||
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
|
||||
host= config.mongodb_host,
|
||||
port= int(config.mongodb_port),
|
||||
db_name= config.database_name,
|
||||
username= config.mongodb_username,
|
||||
password= config.mongodb_password,
|
||||
auth_source=config.mongodb_auth_source
|
||||
)
|
||||
|
||||
# 检查是否已存在相同哈希值的图片
|
||||
|
||||
@@ -58,8 +58,8 @@ class WillingManager:
|
||||
if group_id in config.talk_frequency_down_groups:
|
||||
reply_probability = reply_probability / 3.5
|
||||
|
||||
if is_mentioned_bot and user_id == int(964959351):
|
||||
reply_probability = 1
|
||||
# if is_mentioned_bot and user_id == int(1026294844):
|
||||
# reply_probability = 1
|
||||
|
||||
return reply_probability
|
||||
|
||||
|
||||
Reference in New Issue
Block a user