SengokuCola
2025-03-27 00:23:13 +08:00
parent 5886b1c849
commit 8da4729c17
5 changed files with 49 additions and 24 deletions

View File

@@ -84,7 +84,11 @@ async def start_background_tasks():
 @driver.on_startup
 async def init_schedule():
     """在 NoneBot2 启动时初始化日程系统"""
-    bot_schedule.initialize(name=global_config.BOT_NICKNAME, personality=global_config.PROMPT_PERSONALITY, behavior=global_config.PROMPT_SCHEDULE_GEN, interval=global_config.SCHEDULE_DOING_UPDATE_INTERVAL)
+    bot_schedule.initialize(
+        name=global_config.BOT_NICKNAME,
+        personality=global_config.PROMPT_PERSONALITY,
+        behavior=global_config.PROMPT_SCHEDULE_GEN,
+        interval=global_config.SCHEDULE_DOING_UPDATE_INTERVAL)
     asyncio.create_task(bot_schedule.mai_schedule_start())

 @driver.on_startup

View File

@@ -57,7 +57,7 @@ class PromptBuilder:
         mood_prompt = mood_manager.get_prompt()

         # 日程构建
-        schedule_prompt = f'''你现在正在做的事情是:{bot_schedule.get_current_num_task(num = 1,time_info = False)}'''
+        # schedule_prompt = f'''你现在正在做的事情是:{bot_schedule.get_current_num_task(num = 1,time_info = False)}'''

         # 获取聊天上下文
         chat_in_group = True

View File

@@ -1,10 +1,7 @@
 import asyncio
 import os
-import time
-from typing import Tuple, Union

 import aiohttp
-import requests

 from src.common.logger import get_module_logger

 logger = get_module_logger("offline_llm")
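Note: with `time`, `requests`, and the unused `typing` imports dropped, this module's HTTP traffic presumably goes through `aiohttp` alone. A rough sketch of what such an async request can look like is below; the endpoint, key, payload shape, and function name are illustrative placeholders, not code from this repository.

import aiohttp

# Illustrative placeholders, not the project's real configuration.
API_URL = "https://example.com/v1/chat/completions"
API_KEY = "sk-placeholder"

async def post_chat(prompt: str) -> dict:
    """Send one chat-completion style request and return the parsed JSON body."""
    headers = {"Authorization": f"Bearer {API_KEY}"}
    payload = {"model": "placeholder-model", "messages": [{"role": "user", "content": prompt}]}
    async with aiohttp.ClientSession() as session:
        async with session.post(API_URL, json=payload, headers=headers) as resp:
            resp.raise_for_status()
            return await resp.json()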

View File

@@ -24,7 +24,7 @@ logger = get_module_logger("scheduler", config=schedule_config)
 class ScheduleGenerator:
     # enable_output: bool = True

-    def __init__(self, ):
+    def __init__(self):
         # 使用离线LLM模型
         self.llm_scheduler_all = LLM_request(
             model= global_config.llm_reasoning, temperature=0.9, max_tokens=7000,request_type="schedule")
@@ -45,7 +45,11 @@ class ScheduleGenerator:
         self.schedule_doing_update_interval = 300 #最好大于60

-    def initialize(self,name: str = "bot_name", personality: str = "你是一个爱国爱党的新时代青年", behavior: str = "你非常外向,喜欢尝试新事物和人交流",interval: int = 60):
+    def initialize(
+            self,name: str = "bot_name",
+            personality: str = "你是一个爱国爱党的新时代青年",
+            behavior: str = "你非常外向,喜欢尝试新事物和人交流",
+            interval: int = 60):
         """初始化日程系统"""
         self.name = name
         self.behavior = behavior
@@ -117,7 +121,7 @@ class ScheduleGenerator:
         prompt = f"你是{self.name}{self.personality}{self.behavior}"
         prompt += f"你昨天的日程是:{self.yesterday_schedule_text}\n"
         prompt += f"请为你生成{date_str}{weekday})的日程安排,结合你的个人特点和行为习惯\n"
-        prompt += "推测你的日程安排包括你一天都在做什么从起床到睡眠有什么发现和思考具体一些详细一些需要1500字以上精确到每半个小时记得写明时间\n"
+        prompt += "推测你的日程安排包括你一天都在做什么从起床到睡眠有什么发现和思考具体一些详细一些需要1500字以上精确到每半个小时记得写明时间\n" #noqa: E501
         prompt += "直接返回你的日程,从起床到睡觉,不要输出其他内容:"

         return prompt
@@ -132,7 +136,7 @@ class ScheduleGenerator:
         prompt = f"你是{self.name}{self.personality}{self.behavior}"
         prompt += f"你今天的日程是:{self.today_schedule_text}\n"
-        prompt += f"你之前做了的事情是:{previous_doings},从之前到现在已经过去了{self.schedule_doing_update_interval/60}分钟了\n"
+        prompt += f"你之前做了的事情是:{previous_doings},从之前到现在已经过去了{self.schedule_doing_update_interval/60}分钟了\n" #noqa: E501
         if mind_thinking:
             prompt += f"你脑子里在想:{mind_thinking}\n"
         prompt += f"现在是{now_time},结合你的个人特点和行为习惯,"
@@ -284,7 +288,11 @@ class ScheduleGenerator:
 async def main():
     # 使用示例
     scheduler = ScheduleGenerator()
-    scheduler.initialize(name="麦麦", personality="你叫麦麦你19岁是一个大二的女大学生你有一头黑色短发你会刷贴吧你现在在学习心理学", behavior="你比较内向一般熬夜比较晚然后第二天早上10点起床吃早午饭",interval=60)
+    scheduler.initialize(
+        name="麦麦",
+        personality="你叫麦麦你19岁是一个大二的女大学生你有一头黑色短发你会刷贴吧你现在在学习心理学",
+        behavior="你比较内向一般熬夜比较晚然后第二天早上10点起床吃早午饭",
+        interval=60)
     await scheduler.mai_schedule_start()
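Note: for completeness, an entry point like `main()` above is normally launched with `asyncio.run`; the guard below is a conventional sketch, not part of this diff.

if __name__ == "__main__":
    # Conventional launcher for the example above; assumes main() as defined in the hunk.
    asyncio.run(main())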

View File

@@ -1,8 +1,7 @@
 import tomli
 import sys
-import re
 from pathlib import Path
-from typing import Dict, Any, List, Set, Tuple
+from typing import Dict, Any, List, Tuple


 def load_toml_file(file_path: str) -> Dict[str, Any]:
     """加载TOML文件"""
@@ -184,10 +183,15 @@ def check_model_configurations(config: Dict[str, Any], env_vars: Dict[str, str])
             provider = model_config["provider"].upper()

             # 检查拼写错误
-            for known_provider, correct_provider in reverse_mapping.items():
+            for known_provider, _correct_provider in reverse_mapping.items():
                 # 使用模糊匹配检测拼写错误
-                if provider != known_provider and _similar_strings(provider, known_provider) and provider not in reverse_mapping:
-                    errors.append(f"[model.{model_name}]的provider '{model_config['provider']}' 可能拼写错误,应为 '{known_provider}'")
+                if (provider != known_provider and
+                        _similar_strings(provider, known_provider) and
+                        provider not in reverse_mapping):
+                    errors.append(
+                        f"[model.{model_name}]的provider '{model_config['provider']}' "
+                        f"可能拼写错误,应为 '{known_provider}'"
+                    )
                     break

     return errors
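Note: the rewritten condition still depends on a `_similar_strings` helper that this diff does not show. A plausible sketch using the standard library's `difflib` follows; the 0.8 cutoff and the case folding are assumptions, not the repository's actual implementation.

import difflib

def _similar_strings(a: str, b: str, threshold: float = 0.8) -> bool:
    """Return True when two provider names are close enough to look like a typo of each other."""
    # SequenceMatcher.ratio() is 1.0 for identical strings and near 0 for unrelated ones.
    return difflib.SequenceMatcher(None, a.upper(), b.upper()).ratio() >= threshold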
@@ -223,7 +227,7 @@ def check_api_providers(config: Dict[str, Any], env_vars: Dict[str, str]) -> Lis
     # 检查配置文件中使用的所有提供商
     used_providers = set()
-    for model_category, model_config in config["model"].items():
+    for _model_category, model_config in config["model"].items():
         if "provider" in model_config:
             provider = model_config["provider"]
             used_providers.add(provider)
@@ -247,7 +251,7 @@ def check_api_providers(config: Dict[str, Any], env_vars: Dict[str, str]) -> Lis
     # 特别检查常见的拼写错误
     for provider in used_providers:
         if provider.upper() == "SILICONFOLW":
-            errors.append(f"提供商 'SILICONFOLW' 存在拼写错误,应为 'SILICONFLOW'")
+            errors.append("提供商 'SILICONFOLW' 存在拼写错误,应为 'SILICONFLOW'")

     return errors
@@ -272,7 +276,7 @@ def check_groups_configuration(config: Dict[str, Any]) -> List[str]:
             "main": "groups.talk_allowed中存在默认示例值'123',请修改为真实的群号",
             "details": [
                 f" 当前值: {groups['talk_allowed']}",
-                f" '123'为示例值,需要替换为真实群号"
+                " '123'为示例值,需要替换为真实群号"
             ]
         })
@@ -371,7 +375,8 @@ def check_memory_config(config: Dict[str, Any]) -> List[str]:
     if "memory_compress_rate" in memory and (memory["memory_compress_rate"] <= 0 or memory["memory_compress_rate"] > 1):
         errors.append(f"memory.memory_compress_rate值无效: {memory['memory_compress_rate']}, 应在0-1之间")

-    if "memory_forget_percentage" in memory and (memory["memory_forget_percentage"] <= 0 or memory["memory_forget_percentage"] > 1):
+    if ("memory_forget_percentage" in memory
+            and (memory["memory_forget_percentage"] <= 0 or memory["memory_forget_percentage"] > 1)):
         errors.append(f"memory.memory_forget_percentage值无效: {memory['memory_forget_percentage']}, 应在0-1之间")

     return errors
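Note: both memory checks apply the same "value must lie in (0, 1]" rule, so a small helper could state it once. The sketch below is illustrative only and not something this commit introduces; the name `_check_unit_interval` is hypothetical.

def _check_unit_interval(section: str, key: str, table: dict, errors: list) -> None:
    # Hypothetical helper: flag config values outside the interval (0, 1].
    if key in table and not (0 < table[key] <= 1):
        errors.append(f"{section}.{key}值无效: {table[key]}, 应在0-1之间")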
@@ -393,7 +398,10 @@ def check_personality_config(config: Dict[str, Any]) -> List[str]:
     else:
         # 检查数组长度
         if len(personality["prompt_personality"]) < 1:
-            errors.append(f"personality.prompt_personality数组长度不足当前长度: {len(personality['prompt_personality'])}, 需要至少1项")
+            errors.append(
+                f"personality.prompt_personality至少需要1项"
+                f"当前长度: {len(personality['prompt_personality'])}"
+            )
         else:
             # 模板默认值
             template_values = [
@@ -452,10 +460,13 @@ def check_bot_config(config: Dict[str, Any]) -> List[str]:
 def format_results(all_errors):
     """格式化检查结果"""
-    sections_errors, prob_sum_errors, prob_range_errors, model_errors, api_errors, groups_errors, kr_errors, willing_errors, memory_errors, personality_errors, bot_results = all_errors
+    sections_errors, prob_sum_errors, prob_range_errors, model_errors, api_errors, groups_errors, kr_errors, willing_errors, memory_errors, personality_errors, bot_results = all_errors  # noqa: E501, F821
     bot_errors, bot_infos = bot_results

-    if not any([sections_errors, prob_sum_errors, prob_range_errors, model_errors, api_errors, groups_errors, kr_errors, willing_errors, memory_errors, personality_errors, bot_errors]):
+    if not any([
+            sections_errors, prob_sum_errors,
+            prob_range_errors, model_errors, api_errors, groups_errors,
+            kr_errors, willing_errors, memory_errors, personality_errors, bot_errors]):
         result = "✅ 配置文件检查通过,未发现问题。"

         # 添加机器人信息
@@ -574,7 +585,10 @@ def main():
     bot_results = check_bot_config(config)

     # 格式化并打印结果
-    all_errors = (sections_errors, prob_sum_errors, prob_range_errors, model_errors, api_errors, groups_errors, kr_errors, willing_errors, memory_errors, personality_errors, bot_results)
+    all_errors = (
+        sections_errors, prob_sum_errors,
+        prob_range_errors, model_errors, api_errors, groups_errors,
+        kr_errors, willing_errors, memory_errors, personality_errors, bot_results)
     result = format_results(all_errors)
     print("📋 机器人配置检查结果:")
     print(result)
@@ -586,7 +600,9 @@ def main():
     bot_errors, _ = bot_results

     # 计算普通错误列表的长度
-    for errors in [sections_errors, model_errors, api_errors, groups_errors, kr_errors, willing_errors, memory_errors, bot_errors]:
+    for errors in [
+            sections_errors, model_errors, api_errors,
+            groups_errors, kr_errors, willing_errors, memory_errors, bot_errors]:
         total_errors += len(errors)

     # 计算元组列表的长度(概率相关错误)