From ef630cd6c36093846f5c55baf1c7bd2e80032451 Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Wed, 27 Aug 2025 21:24:12 +0800 Subject: [PATCH] =?UTF-8?q?feat(expression):=20=E9=87=8D=E6=9E=84=E8=A1=A8?= =?UTF-8?q?=E8=BE=BE=E5=AD=A6=E4=B9=A0=E9=85=8D=E7=BD=AE=EF=BC=8C=E5=BC=95?= =?UTF-8?q?=E5=85=A5=E5=9F=BA=E4=BA=8E=E8=A7=84=E5=88=99=E7=9A=84=E7=BB=93?= =?UTF-8?q?=E6=9E=84=E5=8C=96=E5=AE=9A=E4=B9=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 本次更新旨在提高表达学习配置的可读性和灵活性。旧的基于列表的 `expression_learning` 和 `expression_groups` 配置方式难以理解且容易出错。 通过引入新的 `ExpressionRule` Pydantic模型,我们将所有相关设置(如聊天流ID、是否学习、学习强度、共享组等)整合到一个统一的、自描述的结构中。现在,所有规则都在一个新的 `[[expression.rules]]` 表中进行配置,使得逻辑更加清晰和易于维护。 相关模块,如 `ExpressionSelector`,已更新以适配新的配置结构。同时,数据库中的 `Expression` 模型也已更新为现代的 SQLAlchemy 2.0 风格。 BREAKING CHANGE: 表达学习的配置文件格式已完全改变。旧的 `expression_learning` 和 `expression_groups` 配置不再受支持,用户需要根据新的 `bot_config_template.toml` 文件迁移到 `[[expression.rules]]` 格式。 --- src/chat/express/expression_selector.py | 31 +++++--- src/common/database/sqlalchemy_models.py | 26 ++++--- src/config/official_configs.py | 98 ++++++------------------ template/bot_config_template.toml | 50 +++++++----- 4 files changed, 88 insertions(+), 117 deletions(-) diff --git a/src/chat/express/expression_selector.py b/src/chat/express/expression_selector.py index 497f43eba..0b360a85b 100644 --- a/src/chat/express/expression_selector.py +++ b/src/chat/express/expression_selector.py @@ -114,16 +114,27 @@ class ExpressionSelector: return None def get_related_chat_ids(self, chat_id: str) -> List[str]: - """根据expression_groups配置,获取与当前chat_id相关的所有chat_id(包括自身)""" - groups = global_config.expression.expression_groups - for group in groups: - group_chat_ids = [] - for stream_config_str in group: - if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str): - group_chat_ids.append(chat_id_candidate) - if chat_id in group_chat_ids: - return group_chat_ids - return [chat_id] + """根据expression.rules配置,获取与当前chat_id相关的所有chat_id(包括自身)""" + rules = global_config.expression.rules + current_group = None + + # 找到当前chat_id所在的组 + for rule in rules: + if rule.chat_stream_id and self._parse_stream_config_to_chat_id(rule.chat_stream_id) == chat_id: + current_group = rule.group + break + + if not current_group: + return [chat_id] + + # 找出同一组的所有chat_id + related_chat_ids = [] + for rule in rules: + if rule.group == current_group and rule.chat_stream_id: + if chat_id_candidate := self._parse_stream_config_to_chat_id(rule.chat_stream_id): + related_chat_ids.append(chat_id_candidate) + + return related_chat_ids if related_chat_ids else [chat_id] def get_random_expressions( self, chat_id: str, total_num: int diff --git a/src/common/database/sqlalchemy_models.py b/src/common/database/sqlalchemy_models.py index b63b65661..779179ff9 100644 --- a/src/common/database/sqlalchemy_models.py +++ b/src/common/database/sqlalchemy_models.py @@ -5,12 +5,12 @@ from sqlalchemy import Column, String, Float, Integer, Boolean, Text, Index, create_engine, DateTime from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker, Session +from sqlalchemy.orm import sessionmaker, Session, Mapped, mapped_column from sqlalchemy.pool import QueuePool import os import datetime import time -from typing import Iterator, Optional +from typing import Iterator, Optional, Any, Dict from src.common.logger import get_logger from contextlib import contextmanager @@ -306,14 +306,14 @@ class Expression(Base): """表达风格模型""" __tablename__ = 'expression' - id = Column(Integer, primary_key=True, autoincrement=True) - situation = Column(Text, nullable=False) - style = Column(Text, nullable=False) - count = Column(Float, nullable=False) - last_active_time = Column(Float, nullable=False) - chat_id = Column(get_string_field(64), nullable=False, index=True) - type = Column(Text, nullable=False) - create_date = Column(Float, nullable=True) + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + situation: Mapped[str] = mapped_column(Text, nullable=False) + style: Mapped[str] = mapped_column(Text, nullable=False) + count: Mapped[float] = mapped_column(Float, nullable=False) + last_active_time: Mapped[float] = mapped_column(Float, nullable=False) + chat_id: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True) + type: Mapped[str] = mapped_column(Text, nullable=False) + create_date: Mapped[Optional[float]] = mapped_column(Float, nullable=True) __table_args__ = ( Index('idx_expression_chat_id', 'chat_id'), @@ -589,7 +589,7 @@ def initialize_database(): config = global_config.database # 配置引擎参数 - engine_kwargs = { + engine_kwargs: Dict[str, Any] = { 'echo': False, # 生产环境关闭SQL日志 'future': True, } @@ -642,7 +642,9 @@ def get_db_session() -> Iterator[Session]: """数据库会话上下文管理器 - 推荐使用这个而不是get_session()""" session: Optional[Session] = None try: - _, SessionLocal = initialize_database() + engine, SessionLocal = initialize_database() + if not SessionLocal: + raise RuntimeError("Database session not initialized") session = SessionLocal() yield session #session.commit() diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 7a46688dc..8c7aae355 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -263,11 +263,20 @@ class NormalChatConfig(ValidatedConfigBase): +class ExpressionRule(ValidatedConfigBase): + """表达学习规则""" + + chat_stream_id: str = Field(..., description="聊天流ID,空字符串表示全局") + use_expression: bool = Field(default=True, description="是否使用学到的表达") + learn_expression: bool = Field(default=True, description="是否学习表达") + learning_strength: float = Field(default=1.0, description="学习强度") + group: Optional[str] = Field(default=None, description="表达共享组") + + class ExpressionConfig(ValidatedConfigBase): """表达配置类""" - expression_learning: list[list] = Field(default_factory=lambda: [], description="表达学习") - expression_groups: list[list[str]] = Field(default_factory=list, description="表达组") + rules: List[ExpressionRule] = Field(default_factory=list, description="表达学习规则") def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> Optional[str]: """ @@ -314,86 +323,23 @@ class ExpressionConfig(ValidatedConfigBase): Returns: tuple: (是否使用表达, 是否学习表达, 学习间隔) """ - if not self.expression_learning: - # 如果没有配置,使用默认值:启用表达,启用学习,300秒间隔 - return True, True, 300 + if not self.rules: + # 如果没有配置,使用默认值:启用表达,启用学习,强度1.0 + return True, True, 1.0 # 优先检查聊天流特定的配置 if chat_stream_id: - specific_config = self._get_stream_specific_config(chat_stream_id) - if specific_config is not None: - return specific_config + for rule in self.rules: + if rule.chat_stream_id and self._parse_stream_config_to_chat_id(rule.chat_stream_id) == chat_stream_id: + return rule.use_expression, rule.learn_expression, rule.learning_strength - # 检查全局配置(第一个元素为空字符串的配置) - global_config = self._get_global_config() - if global_config is not None: - return global_config + # 检查全局配置(chat_stream_id为空字符串的配置) + for rule in self.rules: + if rule.chat_stream_id == "": + return rule.use_expression, rule.learn_expression, rule.learning_strength # 如果都没有匹配,返回默认值 - return True, True, 300 - - def _get_stream_specific_config(self, chat_stream_id: str) -> Optional[tuple[bool, bool, float]]: - """ - 获取特定聊天流的表达配置 - - Args: - chat_stream_id: 聊天流ID(哈希值) - - Returns: - tuple: (是否使用表达, 是否学习表达, 学习间隔),如果没有配置则返回 None - """ - for config_item in self.expression_learning: - if not config_item or len(config_item) < 4: - continue - - stream_config_str = config_item[0] # 例如 "qq:1026294844:group" - - # 如果是空字符串,跳过(这是全局配置) - if stream_config_str == "": - continue - - # 解析配置字符串并生成对应的 chat_id - config_chat_id = self._parse_stream_config_to_chat_id(stream_config_str) - if config_chat_id is None: - continue - - # 比较生成的 chat_id - if config_chat_id != chat_stream_id: - continue - - # 解析配置 - try: - use_expression = config_item[1].lower() == "enable" - enable_learning = config_item[2].lower() == "enable" - learning_intensity = float(config_item[3]) - return use_expression, enable_learning, learning_intensity - except (ValueError, IndexError): - continue - - return None - - def _get_global_config(self) -> Optional[tuple[bool, bool, float]]: - """ - 获取全局表达配置 - - Returns: - tuple: (是否使用表达, 是否学习表达, 学习间隔),如果没有配置则返回 None - """ - for config_item in self.expression_learning: - if not config_item or len(config_item) < 4: - continue - - # 检查是否为全局配置(第一个元素为空字符串) - if config_item[0] == "": - try: - use_expression = config_item[1].lower() == "enable" - enable_learning = config_item[2].lower() == "enable" - learning_intensity = float(config_item[3]) - return use_expression, enable_learning, learning_intensity - except (ValueError, IndexError): - continue - - return None + return True, True, 1.0 class ToolHistoryConfig(ValidatedConfigBase): diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index 04acd3f5d..a81872f12 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -44,9 +44,7 @@ connection_timeout = 10 # 连接超时时间(秒) # Master用户配置(拥有最高权限,无视所有权限节点) # 格式:[[platform, user_id], ...] # 示例:[["qq", "123456"], ["telegram", "user789"]] -master_users = [ - # ["qq", "123456789"], # 示例:QQ平台的Master用户 -] +master_users = []# ["qq", "123456789"], # 示例:QQ平台的Master用户 [bot] platform = "qq" @@ -74,23 +72,37 @@ compress_identity = true # 是否压缩身份,压缩后会精简身份信息 [expression] # 表达学习配置 -expression_learning = [ # 表达学习配置列表,支持按聊天流配置 - ["", "enable", "enable", 1.0], # 全局配置:使用表达,启用学习,学习强度1.0 - ["qq:1919810:group", "enable", "enable", 1.5], # 特定群聊配置:使用表达,启用学习,学习强度1.5 - ["qq:114514:private", "enable", "disable", 0.5], # 特定私聊配置:使用表达,禁用学习,学习强度0.5 - # 格式说明: - # 第一位: chat_stream_id,空字符串表示全局配置 - # 第二位: 是否使用学到的表达 ("enable"/"disable") - # 第三位: 是否学习表达 ("enable"/"disable") - # 第四位: 学习强度(浮点数),影响学习频率,最短学习时间间隔 = 300/学习强度(秒) - # 学习强度越高,学习越频繁;学习强度越低,学习越少 -] +# rules是一个列表,每个元素都是一个学习规则 +# chat_stream_id: 聊天流ID,格式为 "platform:id:type",例如 "qq:123456:private"。空字符串""表示全局配置 +# use_expression: 是否使用学到的表达 (true/false) +# learn_expression: 是否学习新的表达 (true/false) +# learning_strength: 学习强度(浮点数),影响学习频率 +# group: 表达共享组的名称(字符串),相同组的聊天会共享学习到的表达方式 +[[expression.rules]] +chat_stream_id = "" +use_expression = true +learn_expression = true +learning_strength = 1.0 -expression_groups = [ - ["qq:1919810:private","qq:114514:private","qq:1111111:group"], # 在这里设置互通组,相同组的chat_id会共享学习到的表达方式 - # 格式:["qq:123456:private","qq:654321:group"] - # 注意:如果为群聊,则需要设置为group,如果设置为私聊,则需要设置为private -] +[[expression.rules]] +chat_stream_id = "qq:1919810:group" +use_expression = true +learn_expression = true +learning_strength = 1.5 + +[[expression.rules]] +chat_stream_id = "qq:114514:private" +group = "group_A" +use_expression = true +learn_expression = false +learning_strength = 0.5 + +[[expression.rules]] +chat_stream_id = "qq:1919810:private" +group = "group_A" +use_expression = true +learn_expression = true +learning_strength = 1.0