refactor: 重构配置模块

This commit is contained in:
Oct-autumn
2025-05-16 16:50:53 +08:00
parent 5d5033452d
commit 021e7f1a97
52 changed files with 902 additions and 1102 deletions

View File

@@ -369,14 +369,15 @@ class EmojiManager:
def __init__(self):
self._initialized = None
self._scan_task = None
self.vlm = LLMRequest(model=global_config.vlm, temperature=0.3, max_tokens=1000, request_type="emoji")
self.vlm = LLMRequest(model=global_config.model.vlm, temperature=0.3, max_tokens=1000, request_type="emoji")
self.llm_emotion_judge = LLMRequest(
model=global_config.llm_normal, max_tokens=600, request_type="emoji"
model=global_config.model.normal, max_tokens=600, request_type="emoji"
) # 更高的温度更少的token后续可以根据情绪来调整温度
self.emoji_num = 0
self.emoji_num_max = global_config.max_emoji_num
self.emoji_num_max_reach_deletion = global_config.max_reach_deletion
self.emoji_num_max = global_config.emoji.max_reg_num
self.emoji_num_max_reach_deletion = global_config.emoji.do_replace
self.emoji_objects: list[MaiEmoji] = [] # 存储MaiEmoji对象的列表使用类型注解明确列表元素类型
logger.info("启动表情包管理器")
@@ -613,18 +614,18 @@ class EmojiManager:
logger.warning(f"[警告] 表情包目录不存在: {EMOJI_DIR}")
os.makedirs(EMOJI_DIR, exist_ok=True)
logger.info(f"[创建] 已创建表情包目录: {EMOJI_DIR}")
await asyncio.sleep(global_config.EMOJI_CHECK_INTERVAL * 60)
await asyncio.sleep(global_config.emoji.check_interval * 60)
continue
# 检查目录是否为空
files = os.listdir(EMOJI_DIR)
if not files:
logger.warning(f"[警告] 表情包目录为空: {EMOJI_DIR}")
await asyncio.sleep(global_config.EMOJI_CHECK_INTERVAL * 60)
await asyncio.sleep(global_config.emoji.check_interval * 60)
continue
# 检查是否需要处理表情包(数量超过最大值或不足)
if (self.emoji_num > self.emoji_num_max and global_config.max_reach_deletion) or (
if (self.emoji_num > self.emoji_num_max and global_config.emoji.do_replace) or (
self.emoji_num < self.emoji_num_max
):
try:
@@ -651,7 +652,7 @@ class EmojiManager:
except Exception as e:
logger.error(f"[错误] 扫描表情包目录失败: {str(e)}")
await asyncio.sleep(global_config.EMOJI_CHECK_INTERVAL * 60)
await asyncio.sleep(global_config.emoji.check_interval * 60)
async def get_all_emoji_from_db(self):
"""获取所有表情包并初始化为MaiEmoji类对象更新 self.emoji_objects"""
@@ -788,7 +789,7 @@ class EmojiManager:
# 构建提示词
prompt = (
f"{global_config.BOT_NICKNAME}的表情包存储已满({self.emoji_num}/{self.emoji_num_max})"
f"{global_config.bot.nickname}的表情包存储已满({self.emoji_num}/{self.emoji_num_max})"
f"需要决定是否删除一个旧表情包来为新表情包腾出空间。\n\n"
f"新表情包信息:\n"
f"描述: {new_emoji.description}\n\n"
@@ -871,10 +872,10 @@ class EmojiManager:
description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format)
# 审核表情包
if global_config.EMOJI_CHECK:
if global_config.emoji.content_filtration:
prompt = f'''
这是一个表情包,请对这个表情包进行审核,标准如下:
1. 必须符合"{global_config.EMOJI_CHECK_PROMPT}"的要求
1. 必须符合"{global_config.emoji.filtration_prompt}"的要求
2. 不能是色情、暴力、等违法违规内容,必须符合公序良俗
3. 不能是任何形式的截图,聊天记录或视频截图
4. 不要出现5个以上文字

View File

@@ -25,9 +25,10 @@ logger = get_logger("expressor")
class DefaultExpressor:
def __init__(self, chat_id: str):
self.log_prefix = "expressor"
# TODO: API-Adapter修改标记
self.express_model = LLMRequest(
model=global_config.llm_normal,
temperature=global_config.llm_normal["temp"],
model=global_config.model.normal,
temperature=global_config.model.normal["temp"],
max_tokens=256,
request_type="response_heartflow",
)
@@ -51,8 +52,8 @@ class DefaultExpressor:
messageinfo = anchor_message.message_info
thinking_time_point = parse_thinking_id_to_timestamp(thinking_id)
bot_user_info = UserInfo(
user_id=global_config.BOT_QQ,
user_nickname=global_config.BOT_NICKNAME,
user_id=global_config.bot.qq_account,
user_nickname=global_config.bot.nickname,
platform=messageinfo.platform,
)
# logger.debug(f"创建思考消息:{anchor_message}")
@@ -141,7 +142,7 @@ class DefaultExpressor:
try:
# 1. 获取情绪影响因子并调整模型温度
arousal_multiplier = mood_manager.get_arousal_multiplier()
current_temp = float(global_config.llm_normal["temp"]) * arousal_multiplier
current_temp = float(global_config.model.normal["temp"]) * arousal_multiplier
self.express_model.params["temperature"] = current_temp # 动态调整温度
# 2. 获取信息捕捉器
@@ -183,6 +184,7 @@ class DefaultExpressor:
try:
with Timer("LLM生成", {}): # 内部计时器,可选保留
# TODO: API-Adapter修改标记
# logger.info(f"{self.log_prefix}[Replier-{thinking_id}]\nPrompt:\n{prompt}\n")
content, reasoning_content, model_name = await self.express_model.generate_response(prompt)
@@ -330,8 +332,8 @@ class DefaultExpressor:
thinking_start_time = await self.heart_fc_sender.get_thinking_start_time(self.chat_id, thinking_id)
bot_user_info = UserInfo(
user_id=global_config.BOT_QQ,
user_nickname=global_config.BOT_NICKNAME,
user_id=global_config.bot.qq_account,
user_nickname=global_config.bot.nickname,
platform=self.chat_stream.platform,
)

View File

@@ -77,8 +77,9 @@ def init_prompt() -> None:
class ExpressionLearner:
def __init__(self) -> None:
# TODO: API-Adapter修改标记
self.express_learn_model: LLMRequest = LLMRequest(
model=global_config.llm_normal,
model=global_config.model.normal,
temperature=0.1,
max_tokens=256,
request_type="response_heartflow",
@@ -289,7 +290,7 @@ class ExpressionLearner:
# 构建prompt
prompt = await global_prompt_manager.format_prompt(
"personality_expression_prompt",
personality=global_config.expression_style,
personality=global_config.personality.expression_style,
)
# logger.info(f"个性表达方式提取prompt: {prompt}")

View File

@@ -112,7 +112,7 @@ def _check_ban_words(text: str, chat, userinfo) -> bool:
Returns:
bool: 是否包含过滤词
"""
for word in global_config.ban_words:
for word in global_config.chat.ban_words:
if word in text:
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")
@@ -132,7 +132,7 @@ def _check_ban_regex(text: str, chat, userinfo) -> bool:
Returns:
bool: 是否匹配过滤正则
"""
for pattern in global_config.ban_msgs_regex:
for pattern in global_config.chat.ban_msgs_regex:
if pattern.search(text):
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")

View File

@@ -6,14 +6,13 @@ from src.chat.utils.chat_message_builder import build_readable_messages, get_raw
from src.chat.person_info.relationship_manager import relationship_manager
from src.chat.utils.utils import get_embedding
import time
from typing import Union, Optional, Dict, Any
from typing import Union, Optional
from src.common.database import db
from src.chat.utils.utils import get_recent_group_speaker
from src.manager.mood_manager import mood_manager
from src.chat.memory_system.Hippocampus import HippocampusManager
from src.chat.knowledge.knowledge_lib import qa_manager
from src.chat.focus_chat.expressors.exprssion_learner import expression_learner
import traceback
import random
@@ -142,7 +141,7 @@ async def _build_prompt_focus(
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_stream.stream_id,
timestamp=time.time(),
limit=global_config.observation_context_size,
limit=global_config.chat.observation_context_size,
)
chat_talking_prompt = await build_readable_messages(
message_list_before_now,
@@ -209,7 +208,7 @@ async def _build_prompt_focus(
chat_target=chat_target_1, # Used in group template
# chat_talking_prompt=chat_talking_prompt,
chat_info=chat_talking_prompt,
bot_name=global_config.BOT_NICKNAME,
bot_name=global_config.bot.nickname,
# prompt_personality=prompt_personality,
prompt_personality="",
reason=reason,
@@ -225,7 +224,7 @@ async def _build_prompt_focus(
info_from_tools=structured_info_prompt,
sender_name=effective_sender_name, # Used in private template
chat_talking_prompt=chat_talking_prompt,
bot_name=global_config.BOT_NICKNAME,
bot_name=global_config.bot.nickname,
prompt_personality=prompt_personality,
# chat_target and chat_target_2 are not used in private template
current_mind_info=current_mind_info,
@@ -280,7 +279,7 @@ class PromptBuilder:
who_chat_in_group = get_recent_group_speaker(
chat_stream.stream_id,
(chat_stream.user_info.platform, chat_stream.user_info.user_id) if chat_stream.user_info else None,
limit=global_config.observation_context_size,
limit=global_config.chat.observation_context_size,
)
elif chat_stream.user_info:
who_chat_in_group.append(
@@ -328,7 +327,7 @@ class PromptBuilder:
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_stream.stream_id,
timestamp=time.time(),
limit=global_config.observation_context_size,
limit=global_config.chat.observation_context_size,
)
chat_talking_prompt = await build_readable_messages(
message_list_before_now,
@@ -340,18 +339,15 @@ class PromptBuilder:
# 关键词检测与反应
keywords_reaction_prompt = ""
for rule in global_config.keywords_reaction_rules:
if rule.get("enable", False):
if any(keyword in message_txt.lower() for keyword in rule.get("keywords", [])):
logger.info(
f"检测到以下关键词之一:{rule.get('keywords', [])},触发反应:{rule.get('reaction', '')}"
)
keywords_reaction_prompt += rule.get("reaction", "") + ""
for rule in global_config.keyword_reaction.rules:
if rule.enable:
if any(keyword in message_txt for keyword in rule.keywords):
logger.info(f"检测到以下关键词之一:{rule.keywords},触发反应:{rule.reaction}")
keywords_reaction_prompt += f"{rule.reaction}"
else:
for pattern in rule.get("regex", []):
result = pattern.search(message_txt)
if result:
reaction = rule.get("reaction", "")
for pattern in rule.regex:
if result := pattern.search(message_txt):
reaction = rule.reaction
for name, content in result.groupdict().items():
reaction = reaction.replace(f"[{name}]", content)
logger.info(f"匹配到以下正则表达式:{pattern},触发反应:{reaction}")
@@ -397,8 +393,8 @@ class PromptBuilder:
chat_target_2=chat_target_2,
chat_talking_prompt=chat_talking_prompt,
message_txt=message_txt,
bot_name=global_config.BOT_NICKNAME,
bot_other_names="/".join(global_config.BOT_ALIAS_NAMES),
bot_name=global_config.bot.nickname,
bot_other_names="/".join(global_config.bot.alias_names),
prompt_personality=prompt_personality,
mood_prompt=mood_prompt,
reply_style1=reply_style1_chosen,
@@ -419,8 +415,8 @@ class PromptBuilder:
prompt_info=prompt_info,
chat_talking_prompt=chat_talking_prompt,
message_txt=message_txt,
bot_name=global_config.BOT_NICKNAME,
bot_other_names="/".join(global_config.BOT_ALIAS_NAMES),
bot_name=global_config.bot.nickname,
bot_other_names="/".join(global_config.bot.alias_names),
prompt_personality=prompt_personality,
mood_prompt=mood_prompt,
reply_style1=reply_style1_chosen,

View File

@@ -26,8 +26,9 @@ class ChattingInfoProcessor(BaseProcessor):
def __init__(self):
"""初始化观察处理器"""
super().__init__()
# TODO: API-Adapter修改标记
self.llm_summary = LLMRequest(
model=global_config.llm_observation, temperature=0.7, max_tokens=300, request_type="chat_observation"
model=global_config.model.observation, temperature=0.7, max_tokens=300, request_type="chat_observation"
)
async def process_info(
@@ -108,12 +109,12 @@ class ChattingInfoProcessor(BaseProcessor):
"created_at": datetime.now().timestamp(),
}
obs.mid_memorys.append(mid_memory)
if len(obs.mid_memorys) > obs.max_mid_memory_len:
obs.mid_memorys.pop(0) # 移除最旧的
obs.mid_memories.append(mid_memory)
if len(obs.mid_memories) > obs.max_mid_memory_len:
obs.mid_memories.pop(0) # 移除最旧的
mid_memory_str = "之前聊天的内容概述是:\n"
for mid_memory_item in obs.mid_memorys: # 重命名循环变量以示区分
for mid_memory_item in obs.mid_memories: # 重命名循环变量以示区分
time_diff = int((datetime.now().timestamp() - mid_memory_item["created_at"]) / 60)
mid_memory_str += (
f"距离现在{time_diff}分钟前(聊天记录id:{mid_memory_item['id']}){mid_memory_item['theme']}\n"

View File

@@ -81,8 +81,8 @@ class MindProcessor(BaseProcessor):
self.subheartflow_id = subheartflow_id
self.llm_model = LLMRequest(
model=global_config.llm_sub_heartflow,
temperature=global_config.llm_sub_heartflow["temp"],
model=global_config.model.sub_heartflow,
temperature=global_config.model.sub_heartflow["temp"],
max_tokens=800,
request_type="sub_heart_flow",
)

View File

@@ -52,7 +52,7 @@ class ToolProcessor(BaseProcessor):
self.subheartflow_id = subheartflow_id
self.log_prefix = f"[{subheartflow_id}:ToolExecutor] "
self.llm_model = LLMRequest(
model=global_config.llm_tool_use,
model=global_config.model.tool_use,
max_tokens=500,
request_type="tool_execution",
)

View File

@@ -34,8 +34,9 @@ def init_prompt():
class MemoryActivator:
def __init__(self):
# TODO: API-Adapter修改标记
self.summary_model = LLMRequest(
model=global_config.llm_summary, temperature=0.7, max_tokens=50, request_type="chat_observation"
model=global_config.model.summary, temperature=0.7, max_tokens=50, request_type="chat_observation"
)
self.running_memory = []

View File

@@ -35,8 +35,9 @@ class Heartflow:
self.subheartflow_manager: SubHeartflowManager = SubHeartflowManager(self.current_state)
# LLM模型配置
# TODO: API-Adapter修改标记
self.llm_model = LLMRequest(
model=global_config.llm_heartflow, temperature=0.6, max_tokens=1000, request_type="heart_flow"
model=global_config.model.heartflow, temperature=0.6, max_tokens=1000, request_type="heart_flow"
)
# 外部依赖模块

View File

@@ -20,9 +20,9 @@ MAX_REPLY_PROBABILITY = 1
class InterestChatting:
def __init__(
self,
decay_rate=global_config.default_decay_rate_per_second,
decay_rate=global_config.focus_chat.default_decay_rate_per_second,
max_interest=MAX_INTEREST,
trigger_threshold=global_config.reply_trigger_threshold,
trigger_threshold=global_config.focus_chat.reply_trigger_threshold,
max_probability=MAX_REPLY_PROBABILITY,
):
# 基础属性初始化

View File

@@ -18,19 +18,14 @@ enable_unlimited_hfc_chat = True # 调试用:无限专注聊天
prevent_offline_state = True
# 目前默认不启用OFFLINE状态
# 不同状态下普通聊天的最大消息数
base_normal_chat_num = global_config.base_normal_chat_num
base_focused_chat_num = global_config.base_focused_chat_num
MAX_NORMAL_CHAT_NUM_PEEKING = int(base_normal_chat_num / 2)
MAX_NORMAL_CHAT_NUM_NORMAL = base_normal_chat_num
MAX_NORMAL_CHAT_NUM_FOCUSED = base_normal_chat_num + 1
MAX_NORMAL_CHAT_NUM_PEEKING = int(global_config.chat.base_normal_chat_num / 2)
MAX_NORMAL_CHAT_NUM_NORMAL = global_config.chat.base_normal_chat_num
MAX_NORMAL_CHAT_NUM_FOCUSED = global_config.chat.base_normal_chat_num + 1
# 不同状态下专注聊天的最大消息数
MAX_FOCUSED_CHAT_NUM_PEEKING = int(base_focused_chat_num / 2)
MAX_FOCUSED_CHAT_NUM_NORMAL = base_focused_chat_num
MAX_FOCUSED_CHAT_NUM_FOCUSED = base_focused_chat_num + 2
MAX_FOCUSED_CHAT_NUM_PEEKING = int(global_config.chat.base_focused_chat_num / 2)
MAX_FOCUSED_CHAT_NUM_NORMAL = global_config.chat.base_focused_chat_num
MAX_FOCUSED_CHAT_NUM_FOCUSED = global_config.chat.base_focused_chat_num + 2
# -- 状态定义 --

View File

@@ -53,19 +53,20 @@ class ChattingObservation(Observation):
self.talking_message = []
self.talking_message_str = ""
self.talking_message_str_truncate = ""
self.name = global_config.BOT_NICKNAME
self.nick_name = global_config.BOT_ALIAS_NAMES
self.max_now_obs_len = global_config.observation_context_size
self.overlap_len = global_config.compressed_length
self.mid_memorys = []
self.max_mid_memory_len = global_config.compress_length_limit
self.name = global_config.bot.nickname
self.nick_name = global_config.bot.alias_names
self.max_now_obs_len = global_config.chat.observation_context_size
self.overlap_len = global_config.focus_chat.compressed_length
self.mid_memories = []
self.max_mid_memory_len = global_config.focus_chat.compress_length_limit
self.mid_memory_info = ""
self.person_list = []
self.oldest_messages = []
self.oldest_messages_str = ""
self.compressor_prompt = ""
# TODO: API-Adapter修改标记
self.llm_summary = LLMRequest(
model=global_config.llm_observation, temperature=0.7, max_tokens=300, request_type="chat_observation"
model=global_config.model.observation, temperature=0.7, max_tokens=300, request_type="chat_observation"
)
async def initialize(self):
@@ -83,7 +84,7 @@ class ChattingObservation(Observation):
for id in ids:
print(f"id{id}")
try:
for mid_memory in self.mid_memorys:
for mid_memory in self.mid_memories:
if mid_memory["id"] == id:
mid_memory_by_id = mid_memory
msg_str = ""
@@ -101,7 +102,7 @@ class ChattingObservation(Observation):
else:
mid_memory_str = "之前的聊天内容:\n"
for mid_memory in self.mid_memorys:
for mid_memory in self.mid_memories:
mid_memory_str += f"{mid_memory['theme']}\n"
return mid_memory_str + "现在群里正在聊:\n" + self.talking_message_str

View File

@@ -76,8 +76,9 @@ class SubHeartflowManager:
# 为 LLM 状态评估创建一个 LLMRequest 实例
# 使用与 Heartflow 相同的模型和参数
# TODO: API-Adapter修改标记
self.llm_state_evaluator = LLMRequest(
model=global_config.llm_heartflow, # 与 Heartflow 一致
model=global_config.model.heartflow, # 与 Heartflow 一致
temperature=0.6, # 与 Heartflow 一致
max_tokens=1000, # 与 Heartflow 一致 (虽然可能不需要这么多)
request_type="subheartflow_state_eval", # 保留特定的请求类型
@@ -278,7 +279,7 @@ class SubHeartflowManager:
focused_limit = current_state.get_focused_chat_max_num()
# --- 新增:检查是否允许进入 FOCUS 模式 --- #
if not global_config.allow_focus_mode:
if not global_config.chat.allow_focus_mode:
if int(time.time()) % 60 == 0: # 每60秒输出一次日志避免刷屏
logger.trace("未开启 FOCUSED 状态 (allow_focus_mode=False)")
return # 如果不允许,直接返回
@@ -766,7 +767,7 @@ class SubHeartflowManager:
focused_limit = current_mai_state.get_focused_chat_max_num()
# --- 检查是否允许 FOCUS 模式 --- #
if not global_config.allow_focus_mode:
if not global_config.chat.allow_focus_mode:
# Log less frequently to avoid spam
# if int(time.time()) % 60 == 0:
# logger.debug(f"{log_prefix_task} 配置不允许进入 FOCUSED 状态")

View File

@@ -19,9 +19,10 @@ from ..utils.chat_message_builder import (
build_readable_messages,
) # 导入 build_readable_messages
from ..utils.utils import translate_timestamp_to_human_readable
from .memory_config import MemoryConfig
from rich.traceback import install
from ...config.config import global_config
install(extra_lines=3)
@@ -195,18 +196,16 @@ class Hippocampus:
self.llm_summary = None
self.entorhinal_cortex = None
self.parahippocampal_gyrus = None
self.config = None
def initialize(self, global_config):
# 使用导入的 MemoryConfig dataclass 和其 from_global_config 方法
self.config = MemoryConfig.from_global_config(global_config)
def initialize(self):
# 初始化子组件
self.entorhinal_cortex = EntorhinalCortex(self)
self.parahippocampal_gyrus = ParahippocampalGyrus(self)
# 从数据库加载记忆图
self.entorhinal_cortex.sync_memory_from_db()
self.llm_topic_judge = LLMRequest(self.config.llm_topic_judge, request_type="memory")
self.llm_summary = LLMRequest(self.config.llm_summary, request_type="memory")
# TODO: API-Adapter修改标记
self.llm_topic_judge = LLMRequest(global_config.model.topic_judge, request_type="memory")
self.llm_summary = LLMRequest(global_config.model.summary, request_type="memory")
def get_all_node_names(self) -> list:
"""获取记忆图中所有节点的名字列表"""
@@ -792,7 +791,6 @@ class EntorhinalCortex:
def __init__(self, hippocampus: Hippocampus):
self.hippocampus = hippocampus
self.memory_graph = hippocampus.memory_graph
self.config = hippocampus.config
def get_memory_sample(self):
"""从数据库获取记忆样本"""
@@ -801,13 +799,13 @@ class EntorhinalCortex:
# 创建双峰分布的记忆调度器
sample_scheduler = MemoryBuildScheduler(
n_hours1=self.config.memory_build_distribution[0],
std_hours1=self.config.memory_build_distribution[1],
weight1=self.config.memory_build_distribution[2],
n_hours2=self.config.memory_build_distribution[3],
std_hours2=self.config.memory_build_distribution[4],
weight2=self.config.memory_build_distribution[5],
total_samples=self.config.build_memory_sample_num,
n_hours1=global_config.memory.memory_build_distribution[0],
std_hours1=global_config.memory.memory_build_distribution[1],
weight1=global_config.memory.memory_build_distribution[2],
n_hours2=global_config.memory.memory_build_distribution[3],
std_hours2=global_config.memory.memory_build_distribution[4],
weight2=global_config.memory.memory_build_distribution[5],
total_samples=global_config.memory.memory_build_sample_num,
)
timestamps = sample_scheduler.get_timestamp_array()
@@ -818,7 +816,7 @@ class EntorhinalCortex:
for timestamp in timestamps:
# 调用修改后的 random_get_msg_snippet
messages = self.random_get_msg_snippet(
timestamp, self.config.build_memory_sample_length, max_memorized_time_per_msg
timestamp, global_config.memory.memory_build_sample_length, max_memorized_time_per_msg
)
if messages:
time_diff = (datetime.datetime.now().timestamp() - timestamp) / 3600
@@ -1099,7 +1097,6 @@ class ParahippocampalGyrus:
def __init__(self, hippocampus: Hippocampus):
self.hippocampus = hippocampus
self.memory_graph = hippocampus.memory_graph
self.config = hippocampus.config
async def memory_compress(self, messages: list, compress_rate=0.1):
"""压缩和总结消息内容,生成记忆主题和摘要。
@@ -1159,7 +1156,7 @@ class ParahippocampalGyrus:
# 3. 过滤掉包含禁用关键词的topic
filtered_topics = [
topic for topic in topics if not any(keyword in topic for keyword in self.config.memory_ban_words)
topic for topic in topics if not any(keyword in topic for keyword in global_config.memory.memory_ban_words)
]
logger.debug(f"过滤后话题: {filtered_topics}")
@@ -1222,7 +1219,7 @@ class ParahippocampalGyrus:
bar = "" * filled_length + "-" * (bar_length - filled_length)
logger.debug(f"进度: [{bar}] {progress:.1f}% ({i}/{len(memory_samples)})")
compress_rate = self.config.memory_compress_rate
compress_rate = global_config.memory.memory_compress_rate
try:
compressed_memory, similar_topics_dict = await self.memory_compress(messages, compress_rate)
except Exception as e:
@@ -1322,7 +1319,7 @@ class ParahippocampalGyrus:
edge_data = self.memory_graph.G[source][target]
last_modified = edge_data.get("last_modified")
if current_time - last_modified > 3600 * self.config.memory_forget_time:
if current_time - last_modified > 3600 * global_config.memory.memory_forget_time:
current_strength = edge_data.get("strength", 1)
new_strength = current_strength - 1
@@ -1430,8 +1427,8 @@ class ParahippocampalGyrus:
async def operation_consolidate_memory(self):
"""整合记忆:合并节点内相似的记忆项"""
start_time = time.time()
percentage = self.config.consolidate_memory_percentage
similarity_threshold = self.config.consolidation_similarity_threshold
percentage = global_config.memory.consolidate_memory_percentage
similarity_threshold = global_config.memory.consolidation_similarity_threshold
logger.info(f"[整合] 开始检查记忆节点... 检查比例: {percentage:.2%}, 合并阈值: {similarity_threshold}")
# 获取所有至少有2条记忆项的节点
@@ -1544,7 +1541,6 @@ class ParahippocampalGyrus:
class HippocampusManager:
_instance = None
_hippocampus = None
_global_config = None
_initialized = False
@classmethod
@@ -1559,19 +1555,15 @@ class HippocampusManager:
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
return cls._hippocampus
def initialize(self, global_config):
def initialize(self):
"""初始化海马体实例"""
if self._initialized:
return self._hippocampus
self._global_config = global_config
self._hippocampus = Hippocampus()
self._hippocampus.initialize(global_config)
self._hippocampus.initialize()
self._initialized = True
# 输出记忆系统参数信息
config = self._hippocampus.config
# 输出记忆图统计信息
memory_graph = self._hippocampus.memory_graph.G
node_count = len(memory_graph.nodes())
@@ -1579,9 +1571,9 @@ class HippocampusManager:
logger.success(f"""--------------------------------
记忆系统参数配置:
构建间隔: {global_config.build_memory_interval}秒|样本数: {config.build_memory_sample_num},长度: {config.build_memory_sample_length}|压缩率: {config.memory_compress_rate}
记忆构建分布: {config.memory_build_distribution}
遗忘间隔: {global_config.forget_memory_interval}秒|遗忘比例: {global_config.memory_forget_percentage}|遗忘: {config.memory_forget_time}小时之后
构建间隔: {global_config.memory.memory_build_interval}秒|样本数: {global_config.memory.memory_build_sample_num},长度: {global_config.memory.memory_build_sample_length}|压缩率: {global_config.memory.memory_compress_rate}
记忆构建分布: {global_config.memory.memory_build_distribution}
遗忘间隔: {global_config.memory.forget_memory_interval}秒|遗忘比例: {global_config.memory.memory_forget_percentage}|遗忘: {global_config.memory.memory_forget_time}小时之后
记忆图统计信息: 节点数量: {node_count}, 连接数量: {edge_count}
--------------------------------""") # noqa: E501

View File

@@ -7,7 +7,6 @@ import os
# 添加项目根目录到系统路径
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))))
from src.chat.memory_system.Hippocampus import HippocampusManager
from src.config.config import global_config
from rich.traceback import install
install(extra_lines=3)
@@ -19,7 +18,7 @@ async def test_memory_system():
# 初始化记忆系统
print("开始初始化记忆系统...")
hippocampus_manager = HippocampusManager.get_instance()
hippocampus_manager.initialize(global_config=global_config)
hippocampus_manager.initialize()
print("记忆系统初始化完成")
# 测试记忆构建

View File

@@ -1,48 +0,0 @@
from dataclasses import dataclass
from typing import List
@dataclass
class MemoryConfig:
"""记忆系统配置类"""
# 记忆构建相关配置
memory_build_distribution: List[float] # 记忆构建的时间分布参数
build_memory_sample_num: int # 每次构建记忆的样本数量
build_memory_sample_length: int # 每个样本的消息长度
memory_compress_rate: float # 记忆压缩率
# 记忆遗忘相关配置
memory_forget_time: int # 记忆遗忘时间(小时)
# 记忆过滤相关配置
memory_ban_words: List[str] # 记忆过滤词列表
# 新增:记忆整合相关配置
consolidation_similarity_threshold: float # 相似度阈值
consolidate_memory_percentage: float # 检查节点比例
consolidate_memory_interval: int # 记忆整合间隔
llm_topic_judge: str # 话题判断模型
llm_summary: str # 话题总结模型
@classmethod
def from_global_config(cls, global_config):
"""从全局配置创建记忆系统配置"""
# 使用 getattr 提供默认值,防止全局配置缺少这些项
return cls(
memory_build_distribution=getattr(
global_config, "memory_build_distribution", (24, 12, 0.5, 168, 72, 0.5)
), # 添加默认值
build_memory_sample_num=getattr(global_config, "build_memory_sample_num", 5),
build_memory_sample_length=getattr(global_config, "build_memory_sample_length", 30),
memory_compress_rate=getattr(global_config, "memory_compress_rate", 0.1),
memory_forget_time=getattr(global_config, "memory_forget_time", 24 * 7),
memory_ban_words=getattr(global_config, "memory_ban_words", []),
# 新增加载整合配置,并提供默认值
consolidation_similarity_threshold=getattr(global_config, "consolidation_similarity_threshold", 0.7),
consolidate_memory_percentage=getattr(global_config, "consolidate_memory_percentage", 0.01),
consolidate_memory_interval=getattr(global_config, "consolidate_memory_interval", 1000),
llm_topic_judge=getattr(global_config, "llm_topic_judge", "default_judge_model"), # 添加默认模型名
llm_summary=getattr(global_config, "llm_summary", "default_summary_model"), # 添加默认模型名
)

View File

@@ -41,7 +41,7 @@ class ChatBot:
chat_id = str(message.chat_stream.stream_id)
private_name = str(message.message_info.user_info.user_nickname)
if global_config.enable_pfc_chatting:
if global_config.experimental.enable_pfc_chatting:
await self.pfc_manager.get_or_create_conversation(chat_id, private_name)
except Exception as e:
@@ -78,19 +78,19 @@ class ChatBot:
userinfo = message.message_info.user_info
# 用户黑名单拦截
if userinfo.user_id in global_config.ban_user_id:
if userinfo.user_id in global_config.chat_target.ban_user_id:
logger.debug(f"用户{userinfo.user_id}被禁止回复")
return
if groupinfo is None:
logger.trace("检测到私聊消息,检查")
# 好友黑名单拦截
if userinfo.user_id not in global_config.talk_allowed_private:
if userinfo.user_id not in global_config.experimental.talk_allowed_private:
logger.debug(f"用户{userinfo.user_id}没有私聊权限")
return
# 群聊黑名单拦截
if groupinfo is not None and groupinfo.group_id not in global_config.talk_allowed_groups:
if groupinfo is not None and groupinfo.group_id not in global_config.chat_target.talk_allowed_groups:
logger.trace(f"{groupinfo.group_id}被禁止回复")
return
@@ -112,7 +112,7 @@ class ChatBot:
if groupinfo is None:
logger.trace("检测到私聊消息")
# 是否在配置信息中开启私聊模式
if global_config.enable_friend_chat:
if global_config.experimental.enable_friend_chat:
logger.trace("私聊模式已启用")
# 是否进入PFC
if global_config.enable_pfc_chatting:

View File

@@ -38,7 +38,7 @@ class MessageBuffer:
async def start_caching_messages(self, message: MessageRecv):
"""添加消息,启动缓冲"""
if not global_config.message_buffer:
if not global_config.chat.message_buffer:
person_id = person_info_manager.get_person_id(
message.message_info.user_info.platform, message.message_info.user_info.user_id
)
@@ -107,7 +107,7 @@ class MessageBuffer:
async def query_buffer_result(self, message: MessageRecv) -> bool:
"""查询缓冲结果,并清理"""
if not global_config.message_buffer:
if not global_config.chat.message_buffer:
return True
person_id_ = self.get_person_id_(
message.message_info.platform, message.message_info.user_info.user_id, message.message_info.group_info

View File

@@ -279,7 +279,7 @@ class MessageManager:
)
# 检查是否超时
if thinking_time > global_config.thinking_timeout:
if thinking_time > global_config.normal_chat.thinking_timeout:
logger.warning(
f"[{chat_id}] 消息思考超时 ({thinking_time:.1f}秒),移除消息 {message_earliest.message_info.message_id}"
)

View File

@@ -111,8 +111,8 @@ class LLMRequest:
def __init__(self, model: dict, **kwargs):
# 将大写的配置键转换为小写并从config中获取实际值
try:
self.api_key = os.environ[model["key"]]
self.base_url = os.environ[model["base_url"]]
self.api_key = os.environ[f"{model['provider']}_KEY"]
self.base_url = os.environ[f"{model['provider']}_BASE_URL"]
except AttributeError as e:
logger.error(f"原始 model dict 信息:{model}")
logger.error(f"配置错误:找不到对应的配置项 - {str(e)}")
@@ -500,11 +500,11 @@ class LLMRequest:
logger.warning(f"检测到403错误模型从 {old_model_name} 降级为 {self.model_name}")
# 对全局配置进行更新
if global_config.llm_normal.get("name") == old_model_name:
global_config.llm_normal["name"] = self.model_name
if global_config.model.normal.get("name") == old_model_name:
global_config.model.normal["name"] = self.model_name
logger.warning(f"将全局配置中的 llm_normal 模型临时降级至{self.model_name}")
if global_config.llm_reasoning.get("name") == old_model_name:
global_config.llm_reasoning["name"] = self.model_name
if global_config.model.reasoning.get("name") == old_model_name:
global_config.model.reasoning["name"] = self.model_name
logger.warning(f"将全局配置中的 llm_reasoning 模型临时降级至{self.model_name}")
if payload and "model" in payload:
@@ -636,7 +636,7 @@ class LLMRequest:
**params_copy,
}
if "max_tokens" not in payload and "max_completion_tokens" not in payload:
payload["max_tokens"] = global_config.model_max_output_length
payload["max_tokens"] = global_config.model.model_max_output_length
# 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查
if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload:
payload["max_completion_tokens"] = payload.pop("max_tokens")

View File

@@ -73,8 +73,8 @@ class NormalChat:
messageinfo = message.message_info
bot_user_info = UserInfo(
user_id=global_config.BOT_QQ,
user_nickname=global_config.BOT_NICKNAME,
user_id=global_config.bot.qq_account,
user_nickname=global_config.bot.nickname,
platform=messageinfo.platform,
)
@@ -121,8 +121,8 @@ class NormalChat:
message_id=thinking_id,
chat_stream=self.chat_stream, # 使用 self.chat_stream
bot_user_info=UserInfo(
user_id=global_config.BOT_QQ,
user_nickname=global_config.BOT_NICKNAME,
user_id=global_config.bot.qq_account,
user_nickname=global_config.bot.nickname,
platform=message.message_info.platform,
),
sender_info=message.message_info.user_info,
@@ -147,7 +147,7 @@ class NormalChat:
# 改为实例方法
async def _handle_emoji(self, message: MessageRecv, response: str):
"""处理表情包"""
if random() < global_config.emoji_chance:
if random() < global_config.normal_chat.emoji_chance:
emoji_raw = await emoji_manager.get_emoji_for_text(response)
if emoji_raw:
emoji_path, description = emoji_raw
@@ -160,8 +160,8 @@ class NormalChat:
message_id="mt" + str(thinking_time_point),
chat_stream=self.chat_stream, # 使用 self.chat_stream
bot_user_info=UserInfo(
user_id=global_config.BOT_QQ,
user_nickname=global_config.BOT_NICKNAME,
user_id=global_config.bot.qq_account,
user_nickname=global_config.bot.nickname,
platform=message.message_info.platform,
),
sender_info=message.message_info.user_info,
@@ -186,7 +186,7 @@ class NormalChat:
label=emotion,
stance=stance, # 使用 self.chat_stream
)
self.mood_manager.update_mood_from_emotion(emotion, global_config.mood_intensity_factor)
self.mood_manager.update_mood_from_emotion(emotion, global_config.mood.mood_intensity_factor)
async def _reply_interested_message(self) -> None:
"""
@@ -430,7 +430,7 @@ class NormalChat:
def _check_ban_words(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
"""检查消息中是否包含过滤词"""
stream_name = chat_manager.get_stream_name(chat.stream_id) or chat.stream_id
for word in global_config.ban_words:
for word in global_config.chat.ban_words:
if word in text:
logger.info(
f"[{stream_name}][{chat.group_info.group_name if chat.group_info else '私聊'}]"
@@ -445,7 +445,7 @@ class NormalChat:
def _check_ban_regex(text: str, chat: ChatStream, userinfo: UserInfo) -> bool:
"""检查消息是否匹配过滤正则表达式"""
stream_name = chat_manager.get_stream_name(chat.stream_id) or chat.stream_id
for pattern in global_config.ban_msgs_regex:
for pattern in global_config.chat.ban_msgs_regex:
if pattern.search(text):
logger.info(
f"[{stream_name}][{chat.group_info.group_name if chat.group_info else '私聊'}]"

View File

@@ -15,21 +15,22 @@ logger = get_logger("llm")
class NormalChatGenerator:
def __init__(self):
# TODO: API-Adapter修改标记
self.model_reasoning = LLMRequest(
model=global_config.llm_reasoning,
model=global_config.model.reasoning,
temperature=0.7,
max_tokens=3000,
request_type="response_reasoning",
)
self.model_normal = LLMRequest(
model=global_config.llm_normal,
temperature=global_config.llm_normal["temp"],
model=global_config.model.normal,
temperature=global_config.model.normal["temp"],
max_tokens=256,
request_type="response_reasoning",
)
self.model_sum = LLMRequest(
model=global_config.llm_summary, temperature=0.7, max_tokens=3000, request_type="relation"
model=global_config.model.summary, temperature=0.7, max_tokens=3000, request_type="relation"
)
self.current_model_type = "r1" # 默认使用 R1
self.current_model_name = "unknown model"
@@ -37,7 +38,7 @@ class NormalChatGenerator:
async def generate_response(self, message: MessageThinking, thinking_id: str) -> Optional[Union[str, List[str]]]:
"""根据当前模型类型选择对应的生成函数"""
# 从global_config中获取模型概率值并选择模型
if random.random() < global_config.model_reasoning_probability:
if random.random() < global_config.normal_chat.reasoning_model_probability:
self.current_model_type = "深深地"
current_model = self.model_reasoning
else:
@@ -51,7 +52,7 @@ class NormalChatGenerator:
model_response = await self._generate_response_with_model(message, current_model, thinking_id)
if model_response:
logger.info(f"{global_config.BOT_NICKNAME}的回复是:{model_response}")
logger.info(f"{global_config.bot.nickname}的回复是:{model_response}")
model_response = await self._process_response(model_response)
return model_response
@@ -113,7 +114,7 @@ class NormalChatGenerator:
- "中立":不表达明确立场或无关回应
2. 从"开心,愤怒,悲伤,惊讶,平静,害羞,恐惧,厌恶,困惑"中选出最匹配的1个情感标签
3. 按照"立场-情绪"的格式直接输出结果,例如:"反对-愤怒"
4. 考虑回复者的人格设定为{global_config.personality_core}
4. 考虑回复者的人格设定为{global_config.personality.personality_core}
对话示例:
被回复「A就是笨」

View File

@@ -1,18 +1,20 @@
import asyncio
from src.config.config import global_config
from .willing_manager import BaseWillingManager
class ClassicalWillingManager(BaseWillingManager):
def __init__(self):
super().__init__()
self._decay_task: asyncio.Task = None
self._decay_task: asyncio.Task | None = None
async def _decay_reply_willing(self):
"""定期衰减回复意愿"""
while True:
await asyncio.sleep(1)
for chat_id in self.chat_reply_willing:
self.chat_reply_willing[chat_id] = max(0, self.chat_reply_willing[chat_id] * 0.9)
self.chat_reply_willing[chat_id] = max(0.0, self.chat_reply_willing[chat_id] * 0.9)
async def async_task_starter(self):
if self._decay_task is None:
@@ -23,35 +25,33 @@ class ClassicalWillingManager(BaseWillingManager):
chat_id = willing_info.chat_id
current_willing = self.chat_reply_willing.get(chat_id, 0)
interested_rate = willing_info.interested_rate * self.global_config.response_interested_rate_amplifier
interested_rate = willing_info.interested_rate * global_config.normal_chat.response_interested_rate_amplifier
if interested_rate > 0.4:
current_willing += interested_rate - 0.3
if willing_info.is_mentioned_bot and current_willing < 1.0:
current_willing += 1
elif willing_info.is_mentioned_bot:
current_willing += 0.05
if willing_info.is_mentioned_bot:
current_willing += 1 if current_willing < 1.0 else 0.05
is_emoji_not_reply = False
if willing_info.is_emoji:
if self.global_config.emoji_response_penalty != 0:
current_willing *= self.global_config.emoji_response_penalty
if global_config.normal_chat.emoji_response_penalty != 0:
current_willing *= global_config.normal_chat.emoji_response_penalty
else:
is_emoji_not_reply = True
self.chat_reply_willing[chat_id] = min(current_willing, 3.0)
reply_probability = min(
max((current_willing - 0.5), 0.01) * self.global_config.response_willing_amplifier * 2, 1
max((current_willing - 0.5), 0.01) * global_config.normal_chat.response_willing_amplifier * 2, 1
)
# 检查群组权限(如果是群聊)
if (
willing_info.group_info
and willing_info.group_info.group_id in self.global_config.talk_frequency_down_groups
and willing_info.group_info.group_id in global_config.chat_target.talk_frequency_down_groups
):
reply_probability = reply_probability / self.global_config.down_frequency_rate
reply_probability = reply_probability / global_config.normal_chat.down_frequency_rate
if is_emoji_not_reply:
reply_probability = 0
@@ -61,7 +61,7 @@ class ClassicalWillingManager(BaseWillingManager):
async def before_generate_reply_handle(self, message_id):
chat_id = self.ongoing_messages[message_id].chat_id
current_willing = self.chat_reply_willing.get(chat_id, 0)
self.chat_reply_willing[chat_id] = max(0, current_willing - 1.8)
self.chat_reply_willing[chat_id] = max(0.0, current_willing - 1.8)
async def after_generate_reply_handle(self, message_id):
if message_id not in self.ongoing_messages:
@@ -70,7 +70,7 @@ class ClassicalWillingManager(BaseWillingManager):
chat_id = self.ongoing_messages[message_id].chat_id
current_willing = self.chat_reply_willing.get(chat_id, 0)
if current_willing < 1:
self.chat_reply_willing[chat_id] = min(1, current_willing + 0.4)
self.chat_reply_willing[chat_id] = min(1.0, current_willing + 0.4)
async def bombing_buffer_message_handle(self, message_id):
return await super().bombing_buffer_message_handle(message_id)

View File

@@ -19,6 +19,7 @@ Mxp 模式:梦溪畔独家赞助
下下策是询问一个菜鸟(@梦溪畔)
"""
from src.config.config import global_config
from .willing_manager import BaseWillingManager
from typing import Dict
import asyncio
@@ -50,8 +51,6 @@ class MxpWillingManager(BaseWillingManager):
self.mention_willing_gain = 0.6 # 提及意愿增益
self.interest_willing_gain = 0.3 # 兴趣意愿增益
self.emoji_response_penalty = self.global_config.emoji_response_penalty # 表情包回复惩罚
self.down_frequency_rate = self.global_config.down_frequency_rate # 降低回复频率的群组惩罚系数
self.single_chat_gain = 0.12 # 单聊增益
self.fatigue_messages_triggered_num = self.expected_replies_per_min # 疲劳消息触发数量(int)
@@ -179,10 +178,10 @@ class MxpWillingManager(BaseWillingManager):
probability = self._willing_to_probability(current_willing)
if w_info.is_emoji:
probability *= self.emoji_response_penalty
probability *= global_config.normal_chat.emoji_response_penalty
if w_info.group_info and w_info.group_info.group_id in self.global_config.talk_frequency_down_groups:
probability /= self.down_frequency_rate
if w_info.group_info and w_info.group_info.group_id in global_config.chat_target.talk_frequency_down_groups:
probability /= global_config.normal_chat.down_frequency_rate
self.temporary_willing = current_willing

View File

@@ -1,6 +1,6 @@
from src.common.logger import LogConfig, WILLING_STYLE_CONFIG, LoguruLogger, get_module_logger
from dataclasses import dataclass
from src.config.config import global_config, BotConfig
from src.config.config import global_config
from src.chat.message_receive.chat_stream import ChatStream, GroupInfo
from src.chat.message_receive.message import MessageRecv
from src.chat.person_info.person_info import person_info_manager, PersonInfoManager
@@ -93,7 +93,6 @@ class BaseWillingManager(ABC):
self.chat_reply_willing: Dict[str, float] = {} # 存储每个聊天流的回复意愿(chat_id)
self.ongoing_messages: Dict[str, WillingInfo] = {} # 当前正在进行的消息(message_id)
self.lock = asyncio.Lock()
self.global_config: BotConfig = global_config
self.logger: LoguruLogger = logger
def setup(self, message: MessageRecv, chat: ChatStream, is_mentioned_bot: bool, interested_rate: float):
@@ -173,7 +172,7 @@ def init_willing_manager() -> BaseWillingManager:
Returns:
对应mode的WillingManager实例
"""
mode = global_config.willing_mode.lower()
mode = global_config.normal_chat.willing_mode.lower()
return BaseWillingManager.create(mode)

View File

@@ -59,8 +59,9 @@ person_info_default = {
class PersonInfoManager:
def __init__(self):
self.person_name_list = {}
# TODO: API-Adapter修改标记
self.qv_name_llm = LLMRequest(
model=global_config.llm_normal,
model=global_config.model.normal,
max_tokens=256,
request_type="qv_name",
)

View File

@@ -190,8 +190,8 @@ async def _build_readable_messages_internal(
person_id = person_info_manager.get_person_id(platform, user_id)
# 根据 replace_bot_name 参数决定是否替换机器人名称
if replace_bot_name and user_id == global_config.BOT_QQ:
person_name = f"{global_config.BOT_NICKNAME}(你)"
if replace_bot_name and user_id == global_config.bot.qq_account:
person_name = f"{global_config.bot.nickname}(你)"
else:
person_name = await person_info_manager.get_value(person_id, "person_name")
@@ -427,7 +427,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
output_lines = []
def get_anon_name(platform, user_id):
if user_id == global_config.BOT_QQ:
if user_id == global_config.bot.qq_account:
return "SELF"
person_id = person_info_manager.get_person_id(platform, user_id)
if person_id not in person_map:
@@ -501,7 +501,7 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
user_id = user_info.get("user_id")
# 检查必要信息是否存在 且 不是机器人自己
if not all([platform, user_id]) or user_id == global_config.BOT_QQ:
if not all([platform, user_id]) or user_id == global_config.bot.qq_account:
continue
person_id = person_info_manager.get_person_id(platform, user_id)

View File

@@ -9,7 +9,6 @@ from typing import List
class InfoCatcher:
def __init__(self):
self.chat_history = [] # 聊天历史,长度为三倍使用的上下文喵~
self.context_length = global_config.observation_context_size
self.chat_history_in_thinking = [] # 思考期间的聊天内容喵~
self.chat_history_after_response = [] # 回复后的聊天内容,长度为一倍上下文喵~
@@ -143,7 +142,7 @@ class InfoCatcher:
messages_before = (
db.messages.find({"chat_id": chat_id, "message_id": {"$lt": message_id}})
.sort("time", -1)
.limit(self.context_length * 3)
.limit(global_config.chat.observation_context_size * 3)
) # 获取更多历史信息
return list(messages_before)

View File

@@ -43,8 +43,8 @@ def db_message_to_str(message_dict: dict) -> str:
def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
"""检查消息是否提到了机器人"""
keywords = [global_config.BOT_NICKNAME]
nicknames = global_config.BOT_ALIAS_NAMES
keywords = [global_config.bot.nickname]
nicknames = global_config.bot.alias_names
reply_probability = 0.0
is_at = False
is_mentioned = False
@@ -64,18 +64,18 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
)
# 判断是否被@
if re.search(f"@[\s\S]*?id:{global_config.BOT_QQ}", message.processed_plain_text):
if re.search(f"@[\s\S]*?id:{global_config.bot.qq_account}", message.processed_plain_text):
is_at = True
is_mentioned = True
if is_at and global_config.at_bot_inevitable_reply:
if is_at and global_config.normal_chat.at_bot_inevitable_reply:
reply_probability = 1.0
logger.info("被@回复概率设置为100%")
else:
if not is_mentioned:
# 判断是否被回复
if re.match(
f"\[回复 [\s\S]*?\({str(global_config.BOT_QQ)}\)[\s\S]*?],说:", message.processed_plain_text
f"\[回复 [\s\S]*?\({str(global_config.bot.qq_account)}\)[\s\S]*?],说:", message.processed_plain_text
):
is_mentioned = True
else:
@@ -88,7 +88,7 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
for nickname in nicknames:
if nickname in message_content:
is_mentioned = True
if is_mentioned and global_config.mentioned_bot_inevitable_reply:
if is_mentioned and global_config.normal_chat.mentioned_bot_inevitable_reply:
reply_probability = 1.0
logger.info("被提及回复概率设置为100%")
return is_mentioned, reply_probability
@@ -96,7 +96,8 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
async def get_embedding(text, request_type="embedding"):
"""获取文本的embedding向量"""
llm = LLMRequest(model=global_config.embedding, request_type=request_type)
# TODO: API-Adapter修改标记
llm = LLMRequest(model=global_config.model.embedding, request_type=request_type)
# return llm.get_embedding_sync(text)
try:
embedding = await llm.get_embedding(text)
@@ -163,7 +164,7 @@ def get_recent_group_speaker(chat_stream_id: int, sender, limit: int = 12) -> li
user_info = UserInfo.from_dict(msg_db_data["user_info"])
if (
(user_info.platform, user_info.user_id) != sender
and user_info.user_id != global_config.BOT_QQ
and user_info.user_id != global_config.bot.qq_account
and (user_info.platform, user_info.user_id, user_info.user_nickname) not in who_chat_in_group
and len(who_chat_in_group) < 5
): # 排除重复排除消息发送者排除bot限制加载的关系数目
@@ -321,7 +322,7 @@ def random_remove_punctuation(text: str) -> str:
def process_llm_response(text: str) -> list[str]:
# 先保护颜文字
if global_config.enable_kaomoji_protection:
if global_config.response_splitter.enable_kaomoji_protection:
protected_text, kaomoji_mapping = protect_kaomoji(text)
logger.trace(f"保护颜文字后的文本: {protected_text}")
else:
@@ -340,8 +341,8 @@ def process_llm_response(text: str) -> list[str]:
logger.debug(f"{text}去除括号处理后的文本: {cleaned_text}")
# 对清理后的文本进行进一步处理
max_length = global_config.response_max_length * 2
max_sentence_num = global_config.response_max_sentence_num
max_length = global_config.response_splitter.max_length * 2
max_sentence_num = global_config.response_splitter.max_sentence_num
# 如果基本上是中文,则进行长度过滤
if get_western_ratio(cleaned_text) < 0.1:
if len(cleaned_text) > max_length:
@@ -349,20 +350,20 @@ def process_llm_response(text: str) -> list[str]:
return ["懒得说"]
typo_generator = ChineseTypoGenerator(
error_rate=global_config.chinese_typo_error_rate,
min_freq=global_config.chinese_typo_min_freq,
tone_error_rate=global_config.chinese_typo_tone_error_rate,
word_replace_rate=global_config.chinese_typo_word_replace_rate,
error_rate=global_config.chinese_typo.error_rate,
min_freq=global_config.chinese_typo.min_freq,
tone_error_rate=global_config.chinese_typo.tone_error_rate,
word_replace_rate=global_config.chinese_typo.word_replace_rate,
)
if global_config.enable_response_splitter:
if global_config.response_splitter.enable:
split_sentences = split_into_sentences_w_remove_punctuation(cleaned_text)
else:
split_sentences = [cleaned_text]
sentences = []
for sentence in split_sentences:
if global_config.chinese_typo_enable:
if global_config.chinese_typo.enable:
typoed_text, typo_corrections = typo_generator.create_typo_sentence(sentence)
sentences.append(typoed_text)
if typo_corrections:
@@ -372,7 +373,7 @@ def process_llm_response(text: str) -> list[str]:
if len(sentences) > max_sentence_num:
logger.warning(f"分割后消息数量过多 ({len(sentences)} 条),返回默认回复")
return [f"{global_config.BOT_NICKNAME}不知道哦"]
return [f"{global_config.bot.nickname}不知道哦"]
# if extracted_contents:
# for content in extracted_contents:

View File

@@ -36,7 +36,7 @@ class ImageManager:
self._ensure_description_collection()
self._ensure_image_dir()
self._initialized = True
self._llm = LLMRequest(model=global_config.vlm, temperature=0.4, max_tokens=300, request_type="image")
self._llm = LLMRequest(model=global_config.model.vlm, temperature=0.4, max_tokens=300, request_type="image")
def _ensure_image_dir(self):
"""确保图像存储目录存在"""
@@ -134,7 +134,7 @@ class ImageManager:
return f"[表情包,含义看起来是:{cached_description}]"
# 根据配置决定是否保存图片
if global_config.save_emoji:
if global_config.emoji.save_emoji:
# 生成文件名和路径
timestamp = int(time.time())
filename = f"{timestamp}_{image_hash[:8]}.{image_format}"
@@ -200,7 +200,7 @@ class ImageManager:
return "[图片]"
# 根据配置决定是否保存图片
if global_config.save_pic:
if global_config.emoji.save_pic:
# 生成文件名和路径
timestamp = int(time.time())
filename = f"{timestamp}_{image_hash[:8]}.{image_format}"