diff --git a/bot.py b/bot.py index b8f154cd3..5342be7ce 100644 --- a/bot.py +++ b/bot.py @@ -20,11 +20,13 @@ from rich.traceback import install # 最早期初始化日志系统,确保所有后续模块都使用正确的日志格式 from src.common.logger import initialize_logging, get_logger, shutdown_logging -from src.main import MainSystem -from src.manager.async_task_manager import async_task_manager - initialize_logging() +from src.main import MainSystem #noqa +from src.manager.async_task_manager import async_task_manager #noqa + + + logger = get_logger("main") diff --git a/src/chat/chat_loop/heartFC_chat.py b/src/chat/chat_loop/heartFC_chat.py index bf3110336..7ef3894ad 100644 --- a/src/chat/chat_loop/heartFC_chat.py +++ b/src/chat/chat_loop/heartFC_chat.py @@ -16,6 +16,7 @@ from src.chat.planner_actions.action_modifier import ActionModifier from src.chat.planner_actions.action_manager import ActionManager from src.chat.chat_loop.hfc_utils import CycleDetail from src.person_info.relationship_builder_manager import relationship_builder_manager +from src.chat.express.expression_learner import expression_learner_manager from src.person_info.person_info import get_person_info_manager from src.plugin_system.base.component_types import ActionInfo, ChatMode, EventType from src.plugin_system.core import events_manager @@ -87,6 +88,7 @@ class HeartFChatting: self.log_prefix = f"[{get_chat_manager().get_stream_name(self.stream_id) or self.stream_id}]" self.relationship_builder = relationship_builder_manager.get_or_create_builder(self.stream_id) + self.expression_learner = expression_learner_manager.get_expression_learner(self.stream_id) self.loop_mode = ChatMode.NORMAL # 初始循环模式为普通模式 @@ -325,6 +327,7 @@ class HeartFChatting: async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()): loop_start_time = time.time() await self.relationship_builder.build_relation() + await self.expression_learner.trigger_learning_for_chat() available_actions = {} diff --git 
a/src/chat/express/expression_learner.py b/src/chat/express/expression_learner.py index a98085038..383279c70 100644 --- a/src/chat/express/expression_learner.py +++ b/src/chat/express/expression_learner.py @@ -9,7 +9,7 @@ from typing import List, Dict, Optional, Any, Tuple from src.common.logger import get_logger from src.common.database.database_model import Expression from src.llm_models.utils_model import LLMRequest -from src.config.config import model_config +from src.config.config import model_config, global_config from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_random, build_anonymous_messages from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.message_receive.chat_stream import get_chat_manager @@ -79,15 +79,410 @@ def init_prompt() -> None: class ExpressionLearner: - def __init__(self) -> None: + def __init__(self, chat_id: str) -> None: self.express_learn_model: LLMRequest = LLMRequest( model_set=model_config.model_task_config.replyer_1, request_type="expressor.learner" ) - self.llm_model = None + self.chat_id = chat_id + self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id + + + # 维护每个chat的上次学习时间 + self.last_learning_time: float = time.time() + + # 学习参数 + self.min_messages_for_learning = 25 # 触发学习所需的最少消息数 + self.min_learning_interval = 300 # 最短学习时间间隔(秒) + + + + + def can_learn_for_chat(self) -> bool: + """ + 检查指定聊天流是否允许学习表达 + + Args: + chat_id: 聊天流ID + + Returns: + bool: 是否允许学习 + """ + try: + use_expression, enable_learning, _ = global_config.expression.get_expression_config_for_chat(self.chat_id) + return enable_learning + except Exception as e: + logger.error(f"检查学习权限失败: {e}") + return False + + def should_trigger_learning(self) -> bool: + """ + 检查是否应该触发学习 + + Args: + chat_id: 聊天流ID + + Returns: + bool: 是否应该触发学习 + """ + current_time = time.time() + + # 获取该聊天流的学习强度 + try: + use_expression, enable_learning, learning_intensity = 
global_config.expression.get_expression_config_for_chat(self.chat_id) + except Exception as e: + logger.error(f"获取聊天流 {self.chat_id} 的学习配置失败: {e}") + return False + + # 检查是否允许学习 + if not enable_learning: + return False + + # 根据学习强度计算最短学习时间间隔 + min_interval = self.min_learning_interval / learning_intensity + + # 检查时间间隔 + time_diff = current_time - self.last_learning_time + if time_diff < min_interval: + return False + + # 检查消息数量(只检查指定聊天流的消息) + recent_messages = get_raw_msg_by_timestamp_random( + self.last_learning_time, current_time, limit=self.min_messages_for_learning + 1, chat_id=self.chat_id + ) + + if not recent_messages or len(recent_messages) < self.min_messages_for_learning: + return False + + return True + + async def trigger_learning_for_chat(self) -> bool: + """ + 为指定聊天流触发学习 + + Args: + chat_id: 聊天流ID + + Returns: + bool: 是否成功触发学习 + """ + if not self.should_trigger_learning(): + return False + + try: + logger.info(f"为聊天流 {self.chat_name} 触发表达学习") + + # 学习语言风格 + learnt_style = await self.learn_and_store(type="style", num=25) + + # 学习句法特点 + learnt_grammar = await self.learn_and_store(type="grammar", num=10) + + # 更新学习时间 + self.last_learning_time = time.time() + + if learnt_style or learnt_grammar: + logger.info(f"聊天流 {self.chat_name} 表达学习完成") + return True + else: + logger.warning(f"聊天流 {self.chat_name} 表达学习未获得有效结果") + return False + + except Exception as e: + logger.error(f"为聊天流 {self.chat_name} 触发学习失败: {e}") + return False + + def get_expression_by_chat_id(self) -> Tuple[List[Dict[str, float]], List[Dict[str, float]]]: + """ + 获取指定chat_id的style和grammar表达方式 + 返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作 + """ + learnt_style_expressions = [] + learnt_grammar_expressions = [] + + # 直接从数据库查询 + style_query = Expression.select().where((Expression.chat_id == self.chat_id) & (Expression.type == "style")) + for expr in style_query: + # 确保create_date存在,如果不存在则使用last_active_time + create_date = expr.create_date if expr.create_date is not None else expr.last_active_time + 
learnt_style_expressions.append( + { + "situation": expr.situation, + "style": expr.style, + "count": expr.count, + "last_active_time": expr.last_active_time, + "source_id": self.chat_id, + "type": "style", + "create_date": create_date, + } + ) + grammar_query = Expression.select().where((Expression.chat_id == self.chat_id) & (Expression.type == "grammar")) + for expr in grammar_query: + # 确保create_date存在,如果不存在则使用last_active_time + create_date = expr.create_date if expr.create_date is not None else expr.last_active_time + learnt_grammar_expressions.append( + { + "situation": expr.situation, + "style": expr.style, + "count": expr.count, + "last_active_time": expr.last_active_time, + "source_id": self.chat_id, + "type": "grammar", + "create_date": create_date, + } + ) + return learnt_style_expressions, learnt_grammar_expressions + + + + + + + + def _apply_global_decay_to_database(self, current_time: float) -> None: + """ + 对数据库中的所有表达方式应用全局衰减 + """ + try: + # 获取所有表达方式 + all_expressions = Expression.select() + + updated_count = 0 + deleted_count = 0 + + for expr in all_expressions: + # 计算时间差 + last_active = expr.last_active_time + time_diff_days = (current_time - last_active) / (24 * 3600) # 转换为天 + + # 计算衰减值 + decay_value = self.calculate_decay_factor(time_diff_days) + new_count = max(0.01, expr.count - decay_value) + + if new_count <= 0.01: + # 如果count太小,删除这个表达方式 + expr.delete_instance() + deleted_count += 1 + else: + # 更新count + expr.count = new_count + expr.save() + updated_count += 1 + + if updated_count > 0 or deleted_count > 0: + logger.info(f"全局衰减完成:更新了 {updated_count} 个表达方式,删除了 {deleted_count} 个表达方式") + + except Exception as e: + logger.error(f"数据库全局衰减失败: {e}") + + def calculate_decay_factor(self, time_diff_days: float) -> float: + """ + 计算衰减值 + 当时间差为0天时,衰减值为0(最近活跃的不衰减) + 当时间差为7天时,衰减值为0.002(中等衰减) + 当时间差为30天或更长时,衰减值为0.01(高衰减) + 使用二次函数进行曲线插值 + """ + if time_diff_days <= 0: + return 0.0 # 刚激活的表达式不衰减 + + if time_diff_days >= DECAY_DAYS: + return 0.01 # 
长时间未活跃的表达式大幅衰减 + + # 使用二次函数插值:在0-30天之间从0衰减到0.01 + # 使用简单的二次函数:y = a * x^2 + # 当x=30时,y=0.01,所以 a = 0.01 / (30^2) = 0.01 / 900 + a = 0.01 / (DECAY_DAYS**2) + decay = a * (time_diff_days**2) + + return min(0.01, decay) + + async def learn_and_store(self, type: str, num: int = 10) -> List[Tuple[str, str, str]]: + # sourcery skip: use-join + """ + 学习并存储表达方式 + type: "style" or "grammar" + """ + if type == "style": + type_str = "语言风格" + elif type == "grammar": + type_str = "句法特点" + else: + raise ValueError(f"Invalid type: {type}") + + # 检查是否允许在此聊天流中学习(在函数最前面检查) + if not self.can_learn_for_chat(): + logger.debug(f"聊天流 {self.chat_name} 不允许学习表达,跳过学习") + return [] + + res = await self.learn_expression(type, num) + + if res is None: + return [] + learnt_expressions, chat_id = res + + chat_stream = get_chat_manager().get_stream(chat_id) + if chat_stream is None: + group_name = f"聊天流 {chat_id}" + elif chat_stream.group_info: + group_name = chat_stream.group_info.group_name + else: + group_name = f"{chat_stream.user_info.user_nickname}的私聊" + learnt_expressions_str = "" + for _chat_id, situation, style in learnt_expressions: + learnt_expressions_str += f"{situation}->{style}\n" + logger.info(f"在 {group_name} 学习到{type_str}:\n{learnt_expressions_str}") + + if not learnt_expressions: + logger.info(f"没有学习到{type_str}") + return [] + + # 按chat_id分组 + chat_dict: Dict[str, List[Dict[str, Any]]] = {} + for chat_id, situation, style in learnt_expressions: + if chat_id not in chat_dict: + chat_dict[chat_id] = [] + chat_dict[chat_id].append({"situation": situation, "style": style}) + + current_time = time.time() + + # 存储到数据库 Expression 表 + for chat_id, expr_list in chat_dict.items(): + for new_expr in expr_list: + # 查找是否已存在相似表达方式 + query = Expression.select().where( + (Expression.chat_id == chat_id) + & (Expression.type == type) + & (Expression.situation == new_expr["situation"]) + & (Expression.style == new_expr["style"]) + ) + if query.exists(): + expr_obj = query.get() + # 50%概率替换内容 + if 
random.random() < 0.5: + expr_obj.situation = new_expr["situation"] + expr_obj.style = new_expr["style"] + expr_obj.count = expr_obj.count + 1 + expr_obj.last_active_time = current_time + expr_obj.save() + else: + Expression.create( + situation=new_expr["situation"], + style=new_expr["style"], + count=1, + last_active_time=current_time, + chat_id=chat_id, + type=type, + create_date=current_time, # 手动设置创建日期 + ) + # 限制最大数量 + exprs = list( + Expression.select() + .where((Expression.chat_id == chat_id) & (Expression.type == type)) + .order_by(Expression.count.asc()) + ) + if len(exprs) > MAX_EXPRESSION_COUNT: + # 删除count最小的多余表达方式 + for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]: + expr.delete_instance() + return learnt_expressions + + async def learn_expression(self, type: str, num: int = 10) -> Optional[Tuple[List[Tuple[str, str, str]], str]]: + """从指定聊天流学习表达方式 + + Args: + type: "style" or "grammar" + """ + if type == "style": + type_str = "语言风格" + prompt = "learn_style_prompt" + elif type == "grammar": + type_str = "句法特点" + prompt = "learn_grammar_prompt" + else: + raise ValueError(f"Invalid type: {type}") + + current_time = time.time() + + # 获取上次学习时间 + last_time = self.last_learning_time.get(self.chat_id, current_time - 3600 * 24) + random_msg: Optional[List[Dict[str, Any]]] = get_raw_msg_by_timestamp_random( + last_time, current_time, limit=num, chat_id=self.chat_id + ) + + # print(random_msg) + if not random_msg or random_msg == []: + return None + # 转化成str + chat_id: str = random_msg[0]["chat_id"] + # random_msg_str: str = build_readable_messages(random_msg, timestamp_mode="normal") + random_msg_str: str = await build_anonymous_messages(random_msg) + # print(f"random_msg_str:{random_msg_str}") + + prompt: str = await global_prompt_manager.format_prompt( + prompt, + chat_str=random_msg_str, + ) + + logger.debug(f"学习{type_str}的prompt: {prompt}") + + try: + response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3) + except 
Exception as e: + logger.error(f"学习{type_str}失败: {e}") + return None + + logger.debug(f"学习{type_str}的response: {response}") + + expressions: List[Tuple[str, str, str]] = self.parse_expression_response(response, chat_id) + + return expressions, chat_id + + def parse_expression_response(self, response: str, chat_id: str) -> List[Tuple[str, str, str]]: + """ + 解析LLM返回的表达风格总结,每一行提取"当"和"使用"之间的内容,存储为(situation, style)元组 + """ + expressions: List[Tuple[str, str, str]] = [] + for line in response.splitlines(): + line = line.strip() + if not line: + continue + # 查找"当"和下一个引号 + idx_when = line.find('当"') + if idx_when == -1: + continue + idx_quote1 = idx_when + 1 + idx_quote2 = line.find('"', idx_quote1 + 1) + if idx_quote2 == -1: + continue + situation = line[idx_quote1 + 1 : idx_quote2] + # 查找"使用" + idx_use = line.find('使用"', idx_quote2) + if idx_use == -1: + continue + idx_quote3 = idx_use + 2 + idx_quote4 = line.find('"', idx_quote3 + 1) + if idx_quote4 == -1: + continue + style = line[idx_quote3 + 1 : idx_quote4] + expressions.append((chat_id, situation, style)) + return expressions + + +init_prompt() + +class ExpressionLearnerManager: + def __init__(self): + self.expression_learners = {} + self._ensure_expression_directories() self._auto_migrate_json_to_db() self._migrate_old_data_create_date() - + + def get_expression_learner(self, chat_id: str) -> ExpressionLearner: + if chat_id not in self.expression_learners: + self.expression_learners[chat_id] = ExpressionLearner(chat_id) + return self.expression_learners[chat_id] + def _ensure_expression_directories(self): """ 确保表达方式相关的目录结构存在 @@ -106,6 +501,7 @@ class ExpressionLearner: except Exception as e: logger.error(f"创建目录失败 {directory}: {e}") + def _auto_migrate_json_to_db(self): """ 自动将/data/expression/learnt_style 和 learnt_grammar 下所有expressions.json迁移到数据库。 @@ -238,346 +634,5 @@ class ExpressionLearner: except Exception as e: logger.error(f"迁移老数据创建日期失败: {e}") - def get_expression_by_chat_id(self, chat_id: str) -> 
Tuple[List[Dict[str, float]], List[Dict[str, float]]]: - """ - 获取指定chat_id的style和grammar表达方式 - 返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作 - """ - learnt_style_expressions = [] - learnt_grammar_expressions = [] - # 直接从数据库查询 - style_query = Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == "style")) - for expr in style_query: - # 确保create_date存在,如果不存在则使用last_active_time - create_date = expr.create_date if expr.create_date is not None else expr.last_active_time - learnt_style_expressions.append( - { - "situation": expr.situation, - "style": expr.style, - "count": expr.count, - "last_active_time": expr.last_active_time, - "source_id": chat_id, - "type": "style", - "create_date": create_date, - } - ) - grammar_query = Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == "grammar")) - for expr in grammar_query: - # 确保create_date存在,如果不存在则使用last_active_time - create_date = expr.create_date if expr.create_date is not None else expr.last_active_time - learnt_grammar_expressions.append( - { - "situation": expr.situation, - "style": expr.style, - "count": expr.count, - "last_active_time": expr.last_active_time, - "source_id": chat_id, - "type": "grammar", - "create_date": create_date, - } - ) - return learnt_style_expressions, learnt_grammar_expressions - - def get_expression_create_info(self, chat_id: str, limit: int = 10) -> List[Dict[str, Any]]: - """ - 获取指定chat_id的表达方式创建信息,按创建日期排序 - """ - try: - expressions = ( - Expression.select() - .where(Expression.chat_id == chat_id) - .order_by(Expression.create_date.desc()) - .limit(limit) - ) - - result = [] - for expr in expressions: - create_date = expr.create_date if expr.create_date is not None else expr.last_active_time - result.append( - { - "situation": expr.situation, - "style": expr.style, - "type": expr.type, - "count": expr.count, - "create_date": create_date, - "create_date_formatted": format_create_date(create_date), - "last_active_time": expr.last_active_time, - 
"last_active_formatted": format_create_date(expr.last_active_time), - } - ) - - return result - except Exception as e: - logger.error(f"获取表达方式创建信息失败: {e}") - return [] - - def is_similar(self, s1: str, s2: str) -> bool: - """ - 判断两个字符串是否相似(只考虑长度大于5且有80%以上重合,不考虑子串) - """ - if not s1 or not s2: - return False - min_len = min(len(s1), len(s2)) - if min_len < 5: - return False - same = sum(a == b for a, b in zip(s1, s2, strict=False)) - return same / min_len > 0.8 - - async def learn_and_store_expression(self) -> Tuple[List[Tuple[str, str, str]], List[Tuple[str, str, str]]]: - """ - 学习并存储表达方式,分别学习语言风格和句法特点 - 同时对所有已存储的表达方式进行全局衰减 - """ - current_time = time.time() - - # 全局衰减所有已存储的表达方式(直接操作数据库) - self._apply_global_decay_to_database(current_time) - - learnt_style: Optional[List[Tuple[str, str, str]]] = [] - learnt_grammar: Optional[List[Tuple[str, str, str]]] = [] - # 学习新的表达方式(这里会进行局部衰减) - for _ in range(3): - learnt_style = await self.learn_and_store(type="style", num=25) - if not learnt_style: - return [], [] - - for _ in range(1): - learnt_grammar = await self.learn_and_store(type="grammar", num=10) - if not learnt_grammar: - return [], [] - - return learnt_style, learnt_grammar - - def _apply_global_decay_to_database(self, current_time: float) -> None: - """ - 对数据库中的所有表达方式应用全局衰减 - """ - try: - # 获取所有表达方式 - all_expressions = Expression.select() - - updated_count = 0 - deleted_count = 0 - - for expr in all_expressions: - # 计算时间差 - last_active = expr.last_active_time - time_diff_days = (current_time - last_active) / (24 * 3600) # 转换为天 - - # 计算衰减值 - decay_value = self.calculate_decay_factor(time_diff_days) - new_count = max(0.01, expr.count - decay_value) - - if new_count <= 0.01: - # 如果count太小,删除这个表达方式 - expr.delete_instance() - deleted_count += 1 - else: - # 更新count - expr.count = new_count - expr.save() - updated_count += 1 - - if updated_count > 0 or deleted_count > 0: - logger.info(f"全局衰减完成:更新了 {updated_count} 个表达方式,删除了 {deleted_count} 个表达方式") - - except Exception as 
e: - logger.error(f"数据库全局衰减失败: {e}") - - def calculate_decay_factor(self, time_diff_days: float) -> float: - """ - 计算衰减值 - 当时间差为0天时,衰减值为0(最近活跃的不衰减) - 当时间差为7天时,衰减值为0.002(中等衰减) - 当时间差为30天或更长时,衰减值为0.01(高衰减) - 使用二次函数进行曲线插值 - """ - if time_diff_days <= 0: - return 0.0 # 刚激活的表达式不衰减 - - if time_diff_days >= DECAY_DAYS: - return 0.01 # 长时间未活跃的表达式大幅衰减 - - # 使用二次函数插值:在0-30天之间从0衰减到0.01 - # 使用简单的二次函数:y = a * x^2 - # 当x=30时,y=0.01,所以 a = 0.01 / (30^2) = 0.01 / 900 - a = 0.01 / (DECAY_DAYS**2) - decay = a * (time_diff_days**2) - - return min(0.01, decay) - - async def learn_and_store(self, type: str, num: int = 10) -> List[Tuple[str, str, str]]: - # sourcery skip: use-join - """ - 选择从当前到最近1小时内的随机num条消息,然后学习这些消息的表达方式 - type: "style" or "grammar" - """ - if type == "style": - type_str = "语言风格" - elif type == "grammar": - type_str = "句法特点" - else: - raise ValueError(f"Invalid type: {type}") - - res = await self.learn_expression(type, num) - - if res is None: - return [] - learnt_expressions, chat_id = res - - chat_stream = get_chat_manager().get_stream(chat_id) - if chat_stream is None: - group_name = f"聊天流 {chat_id}" - elif chat_stream.group_info: - group_name = chat_stream.group_info.group_name - else: - group_name = f"{chat_stream.user_info.user_nickname}的私聊" - learnt_expressions_str = "" - for _chat_id, situation, style in learnt_expressions: - learnt_expressions_str += f"{situation}->{style}\n" - logger.info(f"在 {group_name} 学习到{type_str}:\n{learnt_expressions_str}") - - if not learnt_expressions: - logger.info(f"没有学习到{type_str}") - return [] - - # 按chat_id分组 - chat_dict: Dict[str, List[Dict[str, Any]]] = {} - for chat_id, situation, style in learnt_expressions: - if chat_id not in chat_dict: - chat_dict[chat_id] = [] - chat_dict[chat_id].append({"situation": situation, "style": style}) - - current_time = time.time() - - # 存储到数据库 Expression 表 - for chat_id, expr_list in chat_dict.items(): - for new_expr in expr_list: - # 查找是否已存在相似表达方式 - query = Expression.select().where( - 
(Expression.chat_id == chat_id) - & (Expression.type == type) - & (Expression.situation == new_expr["situation"]) - & (Expression.style == new_expr["style"]) - ) - if query.exists(): - expr_obj = query.get() - # 50%概率替换内容 - if random.random() < 0.5: - expr_obj.situation = new_expr["situation"] - expr_obj.style = new_expr["style"] - expr_obj.count = expr_obj.count + 1 - expr_obj.last_active_time = current_time - expr_obj.save() - else: - Expression.create( - situation=new_expr["situation"], - style=new_expr["style"], - count=1, - last_active_time=current_time, - chat_id=chat_id, - type=type, - create_date=current_time, # 手动设置创建日期 - ) - # 限制最大数量 - exprs = list( - Expression.select() - .where((Expression.chat_id == chat_id) & (Expression.type == type)) - .order_by(Expression.count.asc()) - ) - if len(exprs) > MAX_EXPRESSION_COUNT: - # 删除count最小的多余表达方式 - for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]: - expr.delete_instance() - return learnt_expressions - - async def learn_expression(self, type: str, num: int = 10) -> Optional[Tuple[List[Tuple[str, str, str]], str]]: - """选择从当前到最近1小时内的随机num条消息,然后学习这些消息的表达方式 - - Args: - type: "style" or "grammar" - """ - if type == "style": - type_str = "语言风格" - prompt = "learn_style_prompt" - elif type == "grammar": - type_str = "句法特点" - prompt = "learn_grammar_prompt" - else: - raise ValueError(f"Invalid type: {type}") - - current_time = time.time() - random_msg: Optional[List[Dict[str, Any]]] = get_raw_msg_by_timestamp_random( - current_time - 3600 * 24, current_time, limit=num - ) - # print(random_msg) - if not random_msg or random_msg == []: - return None - # 转化成str - chat_id: str = random_msg[0]["chat_id"] - # random_msg_str: str = build_readable_messages(random_msg, timestamp_mode="normal") - random_msg_str: str = await build_anonymous_messages(random_msg) - # print(f"random_msg_str:{random_msg_str}") - - prompt: str = await global_prompt_manager.format_prompt( - prompt, - chat_str=random_msg_str, - ) - - 
logger.debug(f"学习{type_str}的prompt: {prompt}") - - try: - response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3) - except Exception as e: - logger.error(f"学习{type_str}失败: {e}") - return None - - logger.debug(f"学习{type_str}的response: {response}") - - expressions: List[Tuple[str, str, str]] = self.parse_expression_response(response, chat_id) - - return expressions, chat_id - - def parse_expression_response(self, response: str, chat_id: str) -> List[Tuple[str, str, str]]: - """ - 解析LLM返回的表达风格总结,每一行提取"当"和"使用"之间的内容,存储为(situation, style)元组 - """ - expressions: List[Tuple[str, str, str]] = [] - for line in response.splitlines(): - line = line.strip() - if not line: - continue - # 查找"当"和下一个引号 - idx_when = line.find('当"') - if idx_when == -1: - continue - idx_quote1 = idx_when + 1 - idx_quote2 = line.find('"', idx_quote1 + 1) - if idx_quote2 == -1: - continue - situation = line[idx_quote1 + 1 : idx_quote2] - # 查找"使用" - idx_use = line.find('使用"', idx_quote2) - if idx_use == -1: - continue - idx_quote3 = idx_use + 2 - idx_quote4 = line.find('"', idx_quote3 + 1) - if idx_quote4 == -1: - continue - style = line[idx_quote3 + 1 : idx_quote4] - expressions.append((chat_id, situation, style)) - return expressions - - -init_prompt() - - -expression_learner = None - - -def get_expression_learner(): - global expression_learner - if expression_learner is None: - expression_learner = ExpressionLearner() - return expression_learner +expression_learner_manager = ExpressionLearnerManager() diff --git a/src/chat/express/expression_selector.py b/src/chat/express/expression_selector.py index d623ba876..652c3aa67 100644 --- a/src/chat/express/expression_selector.py +++ b/src/chat/express/expression_selector.py @@ -11,7 +11,6 @@ from src.config.config import global_config, model_config from src.common.logger import get_logger from src.common.database.database_model import Expression from src.chat.utils.prompt_builder import Prompt, global_prompt_manager 
-from .expression_learner import get_expression_learner logger = get_logger("expression_selector") @@ -71,11 +70,27 @@ def weighted_sample(population: List[Dict], weights: List[float], k: int) -> Lis class ExpressionSelector: def __init__(self): - self.expression_learner = get_expression_learner() self.llm_model = LLMRequest( model_set=model_config.model_task_config.utils_small, request_type="expression.selector" ) + def can_use_expression_for_chat(self, chat_id: str) -> bool: + """ + 检查指定聊天流是否允许使用表达 + + Args: + chat_id: 聊天流ID + + Returns: + bool: 是否允许使用表达 + """ + try: + use_expression, _, _ = global_config.expression.get_expression_config_for_chat(chat_id) + return use_expression + except Exception as e: + logger.error(f"检查表达使用权限失败: {e}") + return False + @staticmethod def _parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]: """解析'platform:id:type'为chat_id(与get_stream_id一致)""" @@ -208,6 +223,11 @@ class ExpressionSelector: ) -> List[Dict[str, Any]]: # sourcery skip: inline-variable, list-comprehension """使用LLM选择适合的表达方式""" + + # 检查是否允许在此聊天流中使用表达 + if not self.can_use_expression_for_chat(chat_id): + logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表") + return [] # 1. 
获取35个随机表达方式(现在按权重抽取) style_exprs, grammar_exprs = self.get_random_expressions(chat_id, 30, 0.5, 0.5) @@ -305,6 +325,7 @@ class ExpressionSelector: except Exception as e: logger.error(f"LLM处理表达方式选择时出错: {e}") return [] + init_prompt() diff --git a/src/chat/heart_flow/heartflow_message_processor.py b/src/chat/heart_flow/heartflow_message_processor.py index 406d0e6d0..934cc327a 100644 --- a/src/chat/heart_flow/heartflow_message_processor.py +++ b/src/chat/heart_flow/heartflow_message_processor.py @@ -42,25 +42,25 @@ async def _process_relationship(message: MessageRecv) -> None: await relationship_manager.first_knowing_some_one(platform, user_id, nickname, cardname) # type: ignore -async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]: +async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool, list[str]]: """计算消息的兴趣度 Args: message: 待处理的消息对象 Returns: - Tuple[float, bool]: (兴趣度, 是否被提及) + Tuple[float, bool, list[str]]: (兴趣度, 是否被提及, 关键词) """ is_mentioned, _ = is_mentioned_bot_in_message(message) interested_rate = 0.0 with Timer("记忆激活"): - interested_rate = await hippocampus_manager.get_activate_from_text( + interested_rate, keywords = await hippocampus_manager.get_activate_from_text( message.processed_plain_text, max_depth= 5, fast_retrieval=False, ) - logger.debug(f"记忆激活率: {interested_rate:.2f}") + logger.debug(f"记忆激活率: {interested_rate:.2f}, 关键词: {keywords}") text_len = len(message.processed_plain_text) # 根据文本长度分布调整兴趣度,采用分段函数实现更精确的兴趣度计算 @@ -99,7 +99,7 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]: interest_increase_on_mention = 1 interested_rate += interest_increase_on_mention - return interested_rate, is_mentioned + return interested_rate, is_mentioned, keywords class HeartFCMessageReceiver: @@ -128,7 +128,7 @@ class HeartFCMessageReceiver: chat = message.chat_stream # 2. 
兴趣度计算与更新 - interested_rate, is_mentioned = await _calculate_interest(message) + interested_rate, is_mentioned, keywords = await _calculate_interest(message) message.interest_value = interested_rate message.is_mentioned = is_mentioned @@ -157,7 +157,10 @@ class HeartFCMessageReceiver: replace_bot_name=True ) - logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}[兴趣度:{interested_rate:.2f}]") # type: ignore + if keywords: + logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}[兴趣度:{interested_rate:.2f}][关键词:{keywords}]") # type: ignore + else: + logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}[兴趣度:{interested_rate:.2f}]") # type: ignore logger.debug(f"[{mes_name}][当前时段回复频率: {current_talk_frequency}]") diff --git a/src/chat/memory_system/Hippocampus.py b/src/chat/memory_system/Hippocampus.py index 9e4005b97..d56686927 100644 --- a/src/chat/memory_system/Hippocampus.py +++ b/src/chat/memory_system/Hippocampus.py @@ -327,7 +327,7 @@ class Hippocampus: keywords = [word for word in words if len(word) > 1] keywords = list(set(keywords))[:3] # 限制最多3个关键词 if keywords: - logger.info(f"提取关键词: {keywords}") + logger.debug(f"提取关键词: {keywords}") return keywords elif text_length <= 10: topic_num = [1, 3] # 6-10字符: 1个关键词 (27.18%的文本) @@ -354,7 +354,7 @@ class Hippocampus: ] if keywords: - logger.info(f"提取关键词: {keywords}") + logger.debug(f"提取关键词: {keywords}") return keywords @@ -391,7 +391,7 @@ class Hippocampus: logger.debug("没有找到有效的关键词节点") return [] - logger.debug(f"有效的关键词: {', '.join(valid_keywords)}") + logger.info(f"有效的关键词: {', '.join(valid_keywords)}") # 从每个关键词获取记忆 activate_map = {} # 存储每个词的累计激活值 @@ -692,7 +692,7 @@ class Hippocampus: return result - async def get_activate_from_text(self, text: str, max_depth: int = 3, fast_retrieval: bool = False) -> float: + async def get_activate_from_text(self, text: str, max_depth: int = 3, fast_retrieval: bool = False) -> tuple[float, list[str]]: """从文本中提取关键词并获取相关记忆。 Args: @@ 
-704,6 +704,7 @@ class Hippocampus: Returns: float: 激活节点数与总节点数的比值 + list[str]: 有效的关键词 """ keywords = await self.get_keywords_from_text(text) @@ -711,7 +712,7 @@ class Hippocampus: valid_keywords = [keyword for keyword in keywords if keyword in self.memory_graph.G] if not valid_keywords: # logger.info("没有找到有效的关键词节点") - return 0 + return 0, [] logger.debug(f"有效的关键词: {', '.join(valid_keywords)}") @@ -778,7 +779,7 @@ class Hippocampus: activation_ratio = activation_ratio * 60 logger.debug(f"总激活值: {total_activation:.2f}, 总节点数: {total_nodes}, 激活: {activation_ratio}") - return activation_ratio + return activation_ratio, keywords # 负责海马体与其他部分的交互 @@ -1738,16 +1739,16 @@ class HippocampusManager: response = [] return response - async def get_activate_from_text(self, text: str, max_depth: int = 3, fast_retrieval: bool = False) -> float: + async def get_activate_from_text(self, text: str, max_depth: int = 3, fast_retrieval: bool = False) -> tuple[float, list[str]]: """从文本中获取激活值的公共接口""" if not self._initialized: raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法") try: - response = await self._hippocampus.get_activate_from_text(text, max_depth, fast_retrieval) + response, keywords = await self._hippocampus.get_activate_from_text(text, max_depth, fast_retrieval) except Exception as e: logger.error(f"文本产生激活值失败: {e}") response = 0.0 - return response + return response, keywords def get_memory_from_keyword(self, keyword: str, max_depth: int = 2) -> list: """从关键词获取相关记忆的公共接口""" diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index c2b6e1cb9..9ae9e5814 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -55,7 +55,7 @@ def init_prompt(): 对这句话,你想表达,原句:{raw_reply},原因是:{reason}。你现在要思考怎么组织回复 你现在的心情是:{mood_state} 你需要使用合适的语法和句法,参考聊天内容,组织一条日常且口语化的回复。请你修改你想表达的原句,符合你的表达风格和语言习惯 -{config_expression_style},你可以完全重组回复,保留最基本的表达含义就好,但重组后保持语意通顺。 +{reply_style},你可以完全重组回复,保留最基本的表达含义就好,但重组后保持语意通顺。 
def load_log_config():  # sourcery skip: use-contextlib-suppress
    """Read the logging settings from ``config/bot_config.toml``.

    Returns the ``log`` table of the TOML file when the file exists and
    parses. If the file has no ``log`` table, the file is missing, or loading
    fails for any reason, the built-in defaults below are returned instead
    (a failure is reported with ``print``).
    """
    fallback = {
        "date_style": "m-d H:i:s",
        "log_level_style": "lite",
        "color_text": "full",
        "log_level": "INFO",  # global log level (kept for backward compatibility)
        "console_log_level": "INFO",  # console log level
        "file_log_level": "DEBUG",  # file log level
        "suppress_libraries": ["faiss","httpx", "urllib3", "asyncio", "websockets", "httpcore", "requests", "peewee", "openai","uvicorn","jieba"],
        "library_log_levels": { "aiohttp": "WARNING"},
    }
    toml_file = Path("config/bot_config.toml")

    try:
        if not toml_file.exists():
            return fallback
        with toml_file.open("r", encoding="utf-8") as handle:
            return tomlkit.load(handle).get("log", fallback)
    except Exception as e:
        # Best effort: a broken config must not prevent logging from starting.
        print(f"[日志系统] 加载日志配置失败: {e}")
    return fallback
LOG_CONFIG - LOG_CONFIG = load_log_config() - - if file_handler := get_file_handler(): - file_level = LOG_CONFIG.get("file_log_level", LOG_CONFIG.get("log_level", "INFO")) - file_handler.setLevel(getattr(logging, file_level.upper(), logging.INFO)) - - if console_handler := get_console_handler(): - console_level = LOG_CONFIG.get("console_log_level", LOG_CONFIG.get("log_level", "INFO")) - console_handler.setLevel(getattr(logging, console_level.upper(), logging.INFO)) - - # 重新配置console渲染器 - root_logger = logging.getLogger() - for handler in root_logger.handlers: - if isinstance(handler, logging.StreamHandler): - # 这是控制台处理器,更新其格式化器 - handler.setFormatter( - structlog.stdlib.ProcessorFormatter( - processor=ModuleColoredConsoleRenderer(colors=True), - foreign_pre_chain=[ - structlog.stdlib.add_logger_name, - structlog.stdlib.add_log_level, - structlog.stdlib.PositionalArgumentsFormatter(), - structlog.processors.TimeStamper(fmt=get_timestamp_format(), utc=False), - structlog.processors.StackInfoRenderer(), - structlog.processors.format_exc_info, - ], - ) - ) - - # 重新配置第三方库日志 - configure_third_party_loggers() - - # 重新配置所有已存在的logger - reconfigure_existing_loggers() - - -def get_log_config(): - """获取当前日志配置""" - return LOG_CONFIG.copy() - - -def set_console_log_level(level: str): - """设置控制台日志级别 - - Args: - level: 日志级别 ("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL") - """ - global LOG_CONFIG - LOG_CONFIG["console_log_level"] = level.upper() - - if console_handler := get_console_handler(): - console_handler.setLevel(getattr(logging, level.upper(), logging.INFO)) - - # 重新设置root logger级别 - configure_third_party_loggers() - - logger = get_logger("logger") - logger.info(f"控制台日志级别已设置为: {level.upper()}") - - -def set_file_log_level(level: str): - """设置文件日志级别 - - Args: - level: 日志级别 ("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL") - """ - global LOG_CONFIG - LOG_CONFIG["file_log_level"] = level.upper() - - if file_handler := get_file_handler(): - 
file_handler.setLevel(getattr(logging, level.upper(), logging.INFO)) - - # 重新设置root logger级别 - configure_third_party_loggers() - - logger = get_logger("logger") - logger.info(f"文件日志级别已设置为: {level.upper()}") - - -def get_current_log_levels(): - """获取当前的日志级别设置""" - file_handler = get_file_handler() - console_handler = get_console_handler() - - file_level = logging.getLevelName(file_handler.level) if file_handler else "UNKNOWN" - console_level = logging.getLevelName(console_handler.level) if console_handler else "UNKNOWN" - - return { - "console_level": console_level, - "file_level": file_level, - "root_level": logging.getLevelName(logging.getLogger().level), - } - - -def force_reset_all_loggers(): - """强制重置所有logger,解决格式不一致问题""" - # 先关闭现有的handler - close_handlers() - - # 清除所有现有的logger配置 - logging.getLogger().manager.loggerDict.clear() - - # 重新配置根logger - root_logger = logging.getLogger() - root_logger.handlers.clear() - - # 使用单例handler避免重复创建 - file_handler = get_file_handler() - console_handler = get_console_handler() - - # 重新添加我们的handler - root_logger.addHandler(file_handler) - root_logger.addHandler(console_handler) - - # 设置格式化器 - file_handler.setFormatter(file_formatter) - console_handler.setFormatter(console_formatter) - - # 设置根logger级别为所有handler中最低的级别 - console_level = LOG_CONFIG.get("console_log_level", LOG_CONFIG.get("log_level", "INFO")) - file_level = LOG_CONFIG.get("file_log_level", LOG_CONFIG.get("log_level", "INFO")) - - console_level_num = getattr(logging, console_level.upper(), logging.INFO) - file_level_num = getattr(logging, file_level.upper(), logging.INFO) - min_level = min(console_level_num, file_level_num) - - root_logger.setLevel(min_level) - - def initialize_logging(): """手动初始化日志系统,确保所有logger都使用正确的配置 @@ -888,6 +714,7 @@ def initialize_logging(): """ global LOG_CONFIG LOG_CONFIG = load_log_config() + # print(LOG_CONFIG) configure_third_party_loggers() reconfigure_existing_loggers() @@ -899,77 +726,10 @@ def initialize_logging(): console_level = 
LOG_CONFIG.get("console_log_level", LOG_CONFIG.get("log_level", "INFO")) file_level = LOG_CONFIG.get("file_log_level", LOG_CONFIG.get("log_level", "INFO")) - logger.info("日志系统已重新初始化:") + logger.info("日志系统已初始化:") logger.info(f" - 控制台级别: {console_level}") logger.info(f" - 文件级别: {file_level}") - logger.info(" - 轮转份数: 30个文件") - logger.info(" - 自动清理: 30天前的日志") - - -def force_initialize_logging(): - """强制重新初始化整个日志系统,解决格式不一致问题""" - global LOG_CONFIG - LOG_CONFIG = load_log_config() - - # 强制重置所有logger - force_reset_all_loggers() - - # 重新配置structlog - configure_structlog() - - # 配置第三方库 - configure_third_party_loggers() - - # 输出初始化信息 - logger = get_logger("logger") - console_level = LOG_CONFIG.get("console_log_level", LOG_CONFIG.get("log_level", "INFO")) - file_level = LOG_CONFIG.get("file_log_level", LOG_CONFIG.get("log_level", "INFO")) - logger.info( - f"日志系统已强制重新初始化,控制台级别: {console_level},文件级别: {file_level},轮转份数: 30个文件,所有logger格式已统一" - ) - - -def show_module_colors(): - """显示所有模块的颜色效果""" - get_logger("demo") - print("\n=== 模块颜色展示 ===") - - for module_name, _color_code in MODULE_COLORS.items(): - # 临时创建一个该模块的logger来展示颜色 - demo_logger = structlog.get_logger(module_name).bind(logger_name=module_name) - alias = MODULE_ALIASES.get(module_name, module_name) - if alias != module_name: - demo_logger.info(f"这是 {module_name} 模块的颜色效果 (显示为: {alias})") - else: - demo_logger.info(f"这是 {module_name} 模块的颜色效果") - - print("=== 颜色展示结束 ===\n") - - # 显示别名映射表 - if MODULE_ALIASES: - print("=== 当前别名映射 ===") - for module_name, alias in MODULE_ALIASES.items(): - print(f" {module_name} -> {alias}") - print("=== 别名映射结束 ===\n") - - -def format_json_for_logging(data, indent=2, ensure_ascii=False): - """将JSON数据格式化为可读字符串 - - Args: - data: 要格式化的数据(字典、列表等) - indent: 缩进空格数 - ensure_ascii: 是否确保ASCII编码 - - Returns: - str: 格式化后的JSON字符串 - """ - if not isinstance(data, str): - # 如果是对象,直接格式化 - return json.dumps(data, indent=indent, ensure_ascii=ensure_ascii) - # 如果是JSON字符串,先解析再格式化 - parsed_data = 
json.loads(data) - return json.dumps(parsed_data, indent=indent, ensure_ascii=ensure_ascii) + logger.info(" - 轮转份数: 30个文件|自动清理: 30天前的日志") def cleanup_old_logs(): @@ -1017,35 +777,6 @@ def start_log_cleanup_task(): logger.info("已启动日志清理任务,将自动清理30天前的日志文件(轮转份数限制: 30个文件)") -def get_log_stats(): - """获取日志文件统计信息""" - stats = {"total_files": 0, "total_size": 0, "files": []} - - try: - if not LOG_DIR.exists(): - return stats - - for log_file in LOG_DIR.glob("*.log*"): - file_info = { - "name": log_file.name, - "size": log_file.stat().st_size, - "modified": datetime.fromtimestamp(log_file.stat().st_mtime).strftime("%Y-%m-%d %H:%M:%S"), - } - - stats["files"].append(file_info) - stats["total_files"] += 1 - stats["total_size"] += file_info["size"] - - # 按修改时间排序 - stats["files"].sort(key=lambda x: x["modified"], reverse=True) - - except Exception as e: - logger = get_logger("logger") - logger.error(f"获取日志统计信息时出错: {e}") - - return stats - - def shutdown_logging(): """优雅关闭日志系统,释放所有文件句柄""" logger = get_logger("logger") diff --git a/src/config/official_configs.py b/src/config/official_configs.py index dfad134cc..7c8786bea 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -43,6 +43,9 @@ class PersonalityConfig(ConfigBase): identity: str = "" """身份特征""" + + reply_style: str = "" + """表达风格""" compress_personality: bool = True """是否压缩人格,压缩后会精简人格信息,节省token消耗并提高回复性能,但是会丢失一些信息,如果人设不长,可以关闭""" @@ -295,17 +298,24 @@ class NormalChatConfig(ConfigBase): class ExpressionConfig(ConfigBase): """表达配置类""" - enable_expression: bool = True - """是否启用表达方式""" - - expression_style: str = "" - """表达风格""" - - learning_interval: int = 300 - """学习间隔(秒)""" - - enable_expression_learning: bool = True - """是否启用表达学习""" + expression_learning: list[list] = field(default_factory=lambda: []) + """ + 表达学习配置列表,支持按聊天流配置 + 格式: [["chat_stream_id", "use_expression", "enable_learning", learning_intensity], ...] 
def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> Optional[str]:
    """Turn a ``"platform:id:type"`` config string into a hashed chat id.

    The id is the MD5 hex digest of ``"platform_id"`` for ``type == "group"``
    and of ``"platform_id_private"`` for any other type. (The original comment
    states this mirrors ``ChatStream.get_stream_id``; that function is not
    visible here.)

    Args:
        stream_config_str: String of the form ``"platform:id:type"``.

    Returns:
        The generated chat id, or ``None`` if the value is not a string with
        exactly three ``:``-separated parts.
    """
    import hashlib

    try:
        parts = stream_config_str.split(":")
    except AttributeError:
        # Not a string (e.g. a bare number in the TOML list): treat as unparsable.
        return None
    if len(parts) != 3:
        return None

    platform, id_str, stream_type = parts
    if stream_type == "group":
        components = [platform, id_str]
    else:
        components = [platform, id_str, "private"]
    return hashlib.md5("_".join(components).encode()).hexdigest()


def _parse_expression_config_item(self, config_item) -> Optional[tuple[bool, bool, float]]:
    """Decode fields 1-3 of one ``expression_learning`` entry.

    Args:
        config_item: ``[stream, use_flag, learn_flag, intensity]``; a flag is
            true only when it equals ``"enable"`` (case-insensitive).

    Returns:
        ``(use_expression, enable_learning, learning_intensity)`` or ``None``
        when the entry is malformed (non-string flag, non-numeric intensity,
        too few fields).
    """
    try:
        use_expression = config_item[1].lower() == "enable"
        enable_learning = config_item[2].lower() == "enable"
        learning_intensity = float(config_item[3])
    except (AttributeError, TypeError, ValueError, IndexError):
        # AttributeError/TypeError cover non-string flags and None intensities,
        # which previously escaped the (ValueError, IndexError) handler.
        return None
    return use_expression, enable_learning, learning_intensity


def get_expression_config_for_chat(self, chat_stream_id: Optional[str] = None) -> tuple[bool, bool, float]:
    """Resolve the expression settings that apply to a chat stream.

    Lookup order: an entry matching ``chat_stream_id``, then the global entry
    (first field is ``""``), then the built-in default.

    Args:
        chat_stream_id: Hashed chat stream id; ``None``/empty skips the
            stream-specific lookup.

    Returns:
        ``(use_expression, enable_learning, learning_intensity)``. The third
        element is always a learning *intensity*: the default is ``1.0``, the
        same unit configured entries use. It used to be ``300``, an interval
        in seconds, which a caller dividing by intensity would read as an
        extreme learning rate.
    """
    default = (True, True, 1.0)

    if not self.expression_learning:
        return default

    # A stream-specific entry wins over the global one.
    if chat_stream_id:
        specific_config = self._get_stream_specific_config(chat_stream_id)
        if specific_config is not None:
            return specific_config

    global_entry = self._get_global_config()
    if global_entry is not None:
        return global_entry

    return default


def _get_stream_specific_config(self, chat_stream_id: str) -> Optional[tuple[bool, bool, float]]:
    """Find the first well-formed entry whose stream hashes to ``chat_stream_id``.

    Args:
        chat_stream_id: Hashed chat stream id to match.

    Returns:
        ``(use_expression, enable_learning, learning_intensity)`` or ``None``
        when no entry matches.
    """
    for config_item in self.expression_learning:
        if not config_item or len(config_item) < 4:
            continue

        stream_config_str = config_item[0]  # e.g. "qq:1026294844:group"
        if stream_config_str == "":
            continue  # the global entry is handled by _get_global_config

        config_chat_id = self._parse_stream_config_to_chat_id(stream_config_str)
        if config_chat_id is None or config_chat_id != chat_stream_id:
            continue

        parsed = self._parse_expression_config_item(config_item)
        if parsed is not None:
            return parsed

    return None


def _get_global_config(self) -> Optional[tuple[bool, bool, float]]:
    """Return the first well-formed global entry (first field is ``""``).

    Returns:
        ``(use_expression, enable_learning, learning_intensity)`` or ``None``
        when there is no usable global entry.
    """
    for config_item in self.expression_learning:
        if not config_item or len(config_item) < 4:
            continue
        if config_item[0] != "":
            continue

        parsed = self._parse_expression_config_item(config_item)
        if parsed is not None:
            return parsed

    return None
src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask @@ -142,8 +141,6 @@ class MainSystem: ] ) - tasks.append(self.learn_and_store_expression_task()) - await asyncio.gather(*tasks) async def build_memory_task(self): @@ -169,17 +166,6 @@ class MainSystem: await self.hippocampus_manager.consolidate_memory() # type: ignore logger.info("[记忆整合] 记忆整合完成") - @staticmethod - async def learn_and_store_expression_task(): - """学习并存储表达方式任务""" - expression_learner = get_expression_learner() - while True: - await asyncio.sleep(global_config.expression.learning_interval) - if global_config.expression.enable_expression_learning and global_config.expression.enable_expression: - logger.info("[表达方式学习] 开始学习表达方式...") - await expression_learner.learn_and_store_expression() - logger.info("[表达方式学习] 表达方式学习完成") - async def main(): """主函数""" diff --git a/src/mais4u/mais4u_chat/s4u_msg_processor.py b/src/mais4u/mais4u_chat/s4u_msg_processor.py index c5ad9ca1f..1bef53051 100644 --- a/src/mais4u/mais4u_chat/s4u_msg_processor.py +++ b/src/mais4u/mais4u_chat/s4u_msg_processor.py @@ -40,7 +40,7 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]: if global_config.memory.enable_memory: with Timer("记忆激活"): - interested_rate = await hippocampus_manager.get_activate_from_text( + interested_rate,_ = await hippocampus_manager.get_activate_from_text( message.processed_plain_text, fast_retrieval=True, ) diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index 8a285086f..574f23b29 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -1,5 +1,5 @@ [inner] -version = "6.1.0" +version = "6.2.1" #----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读---- #如果你想要修改配置文件,请递增version的值 @@ -26,22 +26,25 @@ personality_side = "用一句话或几句话描述人格的侧面特质" # 可以描述外貌,性别,身高,职业,属性等等描述 identity = "年龄为19岁,是女孩子,身高为160cm,有黑色的短发" +# 描述麦麦说话的表达风格,表达习惯,如要修改,可以酌情新增内容 +reply_style = "回复可以简短一些。可以参考贴吧,知乎和微博的回复风格,回复不要浮夸,不要用夸张修辞,平淡一些。" + 
compress_personality = false # 是否压缩人格,压缩后会精简人格信息,节省token消耗并提高回复性能,但是会丢失一些信息,如果人设不长,可以关闭 compress_identity = true # 是否压缩身份,压缩后会精简身份信息,节省token消耗并提高回复性能,但是会丢失一些信息,如果不长,可以关闭 [expression] -# 表达方式 -enable_expression = true # 是否启用表达方式 -# 描述麦麦说话的表达风格,表达习惯,例如:(请回复的平淡一些,简短一些,说中文,不要刻意突出自身学科背景。) -expression_style = "回复可以简短一些。可以参考贴吧,知乎和微博的回复风格,回复不要浮夸,不要用夸张修辞,平淡一些。" - -enable_expression_learning = false # 是否启用表达学习,麦麦会学习不同群里人类说话风格(群之间不互通) -expression_learning = [ # 允许表达学习的聊天流列表,留空为全部允许 - # "qq:1919810:private", - # "qq:114514:private", - # "qq:1111111:group", +# 表达学习配置 +expression_learning = [ # 表达学习配置列表,支持按聊天流配置 + ["", "enable", "enable", 1.0], # 全局配置:使用表达,启用学习,学习强度1.0 + ["qq:1919810:private", "enable", "enable", 1.5], # 特定私聊配置:使用表达,启用学习,学习强度1.5 + ["qq:114514:private", "enable", "disable", 0.5], # 特定私聊配置:使用表达,禁用学习,学习强度0.5 + # 格式说明: + # 第一位: chat_stream_id,空字符串表示全局配置 + # 第二位: 是否使用学到的表达 ("enable"/"disable") + # 第三位: 是否学习表达 ("enable"/"disable") + # 第四位: 学习强度(浮点数),影响学习频率,最短学习时间间隔 = 300/学习强度(秒) + # 学习强度越高,学习越频繁;学习强度越低,学习越少 ] -learning_interval = 350 # 学习间隔 单位秒 expression_groups = [ ["qq:1919810:private","qq:114514:private","qq:1111111:group"], # 在这里设置互通组,相同组的chat_id会共享学习到的表达方式 @@ -202,7 +205,7 @@ max_sentence_num = 8 # 回复允许的最大句子数 enable_kaomoji_protection = false # 是否启用颜文字保护 [log] -date_style = "Y-m-d H:i:s" # 日期格式 +date_style = "m-d H:i:s" # 日期格式 log_level_style = "lite" # 日志级别样式,可选FULL,compact,lite color_text = "full" # 日志文本颜色,可选none,title,full log_level = "INFO" # 全局日志级别(向下兼容,优先级低于下面的分别设置)