This commit is contained in:
SengokuCola
2025-03-27 00:23:13 +08:00
parent 5886b1c849
commit 8da4729c17
5 changed files with 49 additions and 24 deletions

View File

@@ -84,7 +84,11 @@ async def start_background_tasks():
@driver.on_startup
async def init_schedule():
"""在 NoneBot2 启动时初始化日程系统"""
bot_schedule.initialize(name=global_config.BOT_NICKNAME, personality=global_config.PROMPT_PERSONALITY, behavior=global_config.PROMPT_SCHEDULE_GEN, interval=global_config.SCHEDULE_DOING_UPDATE_INTERVAL)
bot_schedule.initialize(
name=global_config.BOT_NICKNAME,
personality=global_config.PROMPT_PERSONALITY,
behavior=global_config.PROMPT_SCHEDULE_GEN,
interval=global_config.SCHEDULE_DOING_UPDATE_INTERVAL)
asyncio.create_task(bot_schedule.mai_schedule_start())
@driver.on_startup

View File

@@ -57,7 +57,7 @@ class PromptBuilder:
mood_prompt = mood_manager.get_prompt()
# 日程构建
schedule_prompt = f'''你现在正在做的事情是:{bot_schedule.get_current_num_task(num = 1,time_info = False)}'''
# schedule_prompt = f'''你现在正在做的事情是:{bot_schedule.get_current_num_task(num = 1,time_info = False)}'''
# 获取聊天上下文
chat_in_group = True

View File

@@ -1,10 +1,7 @@
import asyncio
import os
import time
from typing import Tuple, Union
import aiohttp
import requests
from src.common.logger import get_module_logger
logger = get_module_logger("offline_llm")

View File

@@ -24,7 +24,7 @@ logger = get_module_logger("scheduler", config=schedule_config)
class ScheduleGenerator:
# enable_output: bool = True
def __init__(self, ):
def __init__(self):
# 使用离线LLM模型
self.llm_scheduler_all = LLM_request(
model= global_config.llm_reasoning, temperature=0.9, max_tokens=7000,request_type="schedule")
@@ -45,7 +45,11 @@ class ScheduleGenerator:
self.schedule_doing_update_interval = 300 #最好大于60
def initialize(self,name: str = "bot_name", personality: str = "你是一个爱国爱党的新时代青年", behavior: str = "你非常外向,喜欢尝试新事物和人交流",interval: int = 60):
def initialize(
self,name: str = "bot_name",
personality: str = "你是一个爱国爱党的新时代青年",
behavior: str = "你非常外向,喜欢尝试新事物和人交流",
interval: int = 60):
"""初始化日程系统"""
self.name = name
self.behavior = behavior
@@ -117,7 +121,7 @@ class ScheduleGenerator:
prompt = f"你是{self.name}{self.personality}{self.behavior}"
prompt += f"你昨天的日程是:{self.yesterday_schedule_text}\n"
prompt += f"请为你生成{date_str}{weekday})的日程安排,结合你的个人特点和行为习惯\n"
prompt += "推测你的日程安排包括你一天都在做什么从起床到睡眠有什么发现和思考具体一些详细一些需要1500字以上精确到每半个小时记得写明时间\n"
prompt += "推测你的日程安排包括你一天都在做什么从起床到睡眠有什么发现和思考具体一些详细一些需要1500字以上精确到每半个小时记得写明时间\n" #noqa: E501
prompt += "直接返回你的日程,从起床到睡觉,不要输出其他内容:"
return prompt
@@ -132,7 +136,7 @@ class ScheduleGenerator:
prompt = f"你是{self.name}{self.personality}{self.behavior}"
prompt += f"你今天的日程是:{self.today_schedule_text}\n"
prompt += f"你之前做了的事情是:{previous_doings},从之前到现在已经过去了{self.schedule_doing_update_interval/60}分钟了\n"
prompt += f"你之前做了的事情是:{previous_doings},从之前到现在已经过去了{self.schedule_doing_update_interval/60}分钟了\n" #noqa: E501
if mind_thinking:
prompt += f"你脑子里在想:{mind_thinking}\n"
prompt += f"现在是{now_time},结合你的个人特点和行为习惯,"
@@ -284,7 +288,11 @@ class ScheduleGenerator:
async def main():
# 使用示例
scheduler = ScheduleGenerator()
scheduler.initialize(name="麦麦", personality="你叫麦麦你19岁是一个大二的女大学生你有一头黑色短发你会刷贴吧你现在在学习心理学", behavior="你比较内向一般熬夜比较晚然后第二天早上10点起床吃早午饭",interval=60)
scheduler.initialize(
name="麦麦",
personality="你叫麦麦你19岁是一个大二的女大学生你有一头黑色短发你会刷贴吧你现在在学习心理学",
behavior="你比较内向一般熬夜比较晚然后第二天早上10点起床吃早午饭",
interval=60)
await scheduler.mai_schedule_start()

View File

@@ -1,8 +1,7 @@
import tomli
import sys
import re
from pathlib import Path
from typing import Dict, Any, List, Set, Tuple
from typing import Dict, Any, List, Tuple
def load_toml_file(file_path: str) -> Dict[str, Any]:
"""加载TOML文件"""
@@ -184,10 +183,15 @@ def check_model_configurations(config: Dict[str, Any], env_vars: Dict[str, str])
provider = model_config["provider"].upper()
# 检查拼写错误
for known_provider, correct_provider in reverse_mapping.items():
for known_provider, _correct_provider in reverse_mapping.items():
# 使用模糊匹配检测拼写错误
if provider != known_provider and _similar_strings(provider, known_provider) and provider not in reverse_mapping:
errors.append(f"[model.{model_name}]的provider '{model_config['provider']}' 可能拼写错误,应为 '{known_provider}'")
if (provider != known_provider and
_similar_strings(provider, known_provider) and
provider not in reverse_mapping):
errors.append(
f"[model.{model_name}]的provider '{model_config['provider']}' "
f"可能拼写错误,应为 '{known_provider}'"
)
break
return errors
@@ -223,7 +227,7 @@ def check_api_providers(config: Dict[str, Any], env_vars: Dict[str, str]) -> Lis
# 检查配置文件中使用的所有提供商
used_providers = set()
for model_category, model_config in config["model"].items():
for _model_category, model_config in config["model"].items():
if "provider" in model_config:
provider = model_config["provider"]
used_providers.add(provider)
@@ -247,7 +251,7 @@ def check_api_providers(config: Dict[str, Any], env_vars: Dict[str, str]) -> Lis
# 特别检查常见的拼写错误
for provider in used_providers:
if provider.upper() == "SILICONFOLW":
errors.append(f"提供商 'SILICONFOLW' 存在拼写错误,应为 'SILICONFLOW'")
errors.append("提供商 'SILICONFOLW' 存在拼写错误,应为 'SILICONFLOW'")
return errors
@@ -272,7 +276,7 @@ def check_groups_configuration(config: Dict[str, Any]) -> List[str]:
"main": "groups.talk_allowed中存在默认示例值'123',请修改为真实的群号",
"details": [
f" 当前值: {groups['talk_allowed']}",
f" '123'为示例值,需要替换为真实群号"
" '123'为示例值,需要替换为真实群号"
]
})
@@ -371,7 +375,8 @@ def check_memory_config(config: Dict[str, Any]) -> List[str]:
if "memory_compress_rate" in memory and (memory["memory_compress_rate"] <= 0 or memory["memory_compress_rate"] > 1):
errors.append(f"memory.memory_compress_rate值无效: {memory['memory_compress_rate']}, 应在0-1之间")
if "memory_forget_percentage" in memory and (memory["memory_forget_percentage"] <= 0 or memory["memory_forget_percentage"] > 1):
if ("memory_forget_percentage" in memory
and (memory["memory_forget_percentage"] <= 0 or memory["memory_forget_percentage"] > 1)):
errors.append(f"memory.memory_forget_percentage值无效: {memory['memory_forget_percentage']}, 应在0-1之间")
return errors
@@ -393,7 +398,10 @@ def check_personality_config(config: Dict[str, Any]) -> List[str]:
else:
# 检查数组长度
if len(personality["prompt_personality"]) < 1:
errors.append(f"personality.prompt_personality数组长度不足当前长度: {len(personality['prompt_personality'])}, 需要至少1项")
errors.append(
f"personality.prompt_personality至少需要1项"
f"当前长度: {len(personality['prompt_personality'])}"
)
else:
# 模板默认值
template_values = [
@@ -452,10 +460,13 @@ def check_bot_config(config: Dict[str, Any]) -> List[str]:
def format_results(all_errors):
"""格式化检查结果"""
sections_errors, prob_sum_errors, prob_range_errors, model_errors, api_errors, groups_errors, kr_errors, willing_errors, memory_errors, personality_errors, bot_results = all_errors
sections_errors, prob_sum_errors, prob_range_errors, model_errors, api_errors, groups_errors, kr_errors, willing_errors, memory_errors, personality_errors, bot_results = all_errors # noqa: E501, F821
bot_errors, bot_infos = bot_results
if not any([sections_errors, prob_sum_errors, prob_range_errors, model_errors, api_errors, groups_errors, kr_errors, willing_errors, memory_errors, personality_errors, bot_errors]):
if not any([
sections_errors, prob_sum_errors,
prob_range_errors, model_errors, api_errors, groups_errors,
kr_errors, willing_errors, memory_errors, personality_errors, bot_errors]):
result = "✅ 配置文件检查通过,未发现问题。"
# 添加机器人信息
@@ -574,7 +585,10 @@ def main():
bot_results = check_bot_config(config)
# 格式化并打印结果
all_errors = (sections_errors, prob_sum_errors, prob_range_errors, model_errors, api_errors, groups_errors, kr_errors, willing_errors, memory_errors, personality_errors, bot_results)
all_errors = (
sections_errors, prob_sum_errors,
prob_range_errors, model_errors, api_errors, groups_errors,
kr_errors, willing_errors, memory_errors, personality_errors, bot_results)
result = format_results(all_errors)
print("📋 机器人配置检查结果:")
print(result)
@@ -586,7 +600,9 @@ def main():
bot_errors, _ = bot_results
# 计算普通错误列表的长度
for errors in [sections_errors, model_errors, api_errors, groups_errors, kr_errors, willing_errors, memory_errors, bot_errors]:
for errors in [
sections_errors, model_errors, api_errors,
groups_errors, kr_errors, willing_errors, memory_errors, bot_errors]:
total_errors += len(errors)
# 计算元组列表的长度(概率相关错误)