feat(expression): 重构表达学习配置,引入基于规则的结构化定义

本次更新旨在提高表达学习配置的可读性和灵活性。旧的基于列表的 `expression_learning` 和 `expression_groups` 配置方式难以理解且容易出错。

通过引入新的 `ExpressionRule` Pydantic模型,我们将所有相关设置(如聊天流ID、是否学习、学习强度、共享组等)整合到一个统一的、自描述的结构中。现在,所有规则都在一个新的 `[[expression.rules]]` 表中进行配置,使得逻辑更加清晰和易于维护。

相关模块,如 `ExpressionSelector`,已更新以适配新的配置结构。同时,数据库中的 `Expression` 模型也已更新为现代的 SQLAlchemy 2.0 风格。

BREAKING CHANGE: 表达学习的配置文件格式已完全改变。旧的 `expression_learning` 和 `expression_groups` 配置不再受支持,用户需要根据新的 `bot_config_template.toml` 文件迁移到 `[[expression.rules]]` 格式。
This commit is contained in:
minecraft1024a
2025-08-27 21:24:12 +08:00
committed by Windpicker-owo
parent d0bb520869
commit ef630cd6c3
4 changed files with 88 additions and 117 deletions

View File

@@ -114,16 +114,27 @@ class ExpressionSelector:
return None return None
def get_related_chat_ids(self, chat_id: str) -> List[str]: def get_related_chat_ids(self, chat_id: str) -> List[str]:
"""根据expression_groups配置获取与当前chat_id相关的所有chat_id包括自身""" """根据expression.rules配置获取与当前chat_id相关的所有chat_id包括自身"""
groups = global_config.expression.expression_groups rules = global_config.expression.rules
for group in groups: current_group = None
group_chat_ids = []
for stream_config_str in group: # 找到当前chat_id所在的组
if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str): for rule in rules:
group_chat_ids.append(chat_id_candidate) if rule.chat_stream_id and self._parse_stream_config_to_chat_id(rule.chat_stream_id) == chat_id:
if chat_id in group_chat_ids: current_group = rule.group
return group_chat_ids break
return [chat_id]
if not current_group:
return [chat_id]
# 找出同一组的所有chat_id
related_chat_ids = []
for rule in rules:
if rule.group == current_group and rule.chat_stream_id:
if chat_id_candidate := self._parse_stream_config_to_chat_id(rule.chat_stream_id):
related_chat_ids.append(chat_id_candidate)
return related_chat_ids if related_chat_ids else [chat_id]
def get_random_expressions( def get_random_expressions(
self, chat_id: str, total_num: int self, chat_id: str, total_num: int

View File

@@ -5,12 +5,12 @@
from sqlalchemy import Column, String, Float, Integer, Boolean, Text, Index, create_engine, DateTime from sqlalchemy import Column, String, Float, Integer, Boolean, Text, Index, create_engine, DateTime
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session from sqlalchemy.orm import sessionmaker, Session, Mapped, mapped_column
from sqlalchemy.pool import QueuePool from sqlalchemy.pool import QueuePool
import os import os
import datetime import datetime
import time import time
from typing import Iterator, Optional from typing import Iterator, Optional, Any, Dict
from src.common.logger import get_logger from src.common.logger import get_logger
from contextlib import contextmanager from contextlib import contextmanager
@@ -306,14 +306,14 @@ class Expression(Base):
"""表达风格模型""" """表达风格模型"""
__tablename__ = 'expression' __tablename__ = 'expression'
id = Column(Integer, primary_key=True, autoincrement=True) id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
situation = Column(Text, nullable=False) situation: Mapped[str] = mapped_column(Text, nullable=False)
style = Column(Text, nullable=False) style: Mapped[str] = mapped_column(Text, nullable=False)
count = Column(Float, nullable=False) count: Mapped[float] = mapped_column(Float, nullable=False)
last_active_time = Column(Float, nullable=False) last_active_time: Mapped[float] = mapped_column(Float, nullable=False)
chat_id = Column(get_string_field(64), nullable=False, index=True) chat_id: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True)
type = Column(Text, nullable=False) type: Mapped[str] = mapped_column(Text, nullable=False)
create_date = Column(Float, nullable=True) create_date: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
__table_args__ = ( __table_args__ = (
Index('idx_expression_chat_id', 'chat_id'), Index('idx_expression_chat_id', 'chat_id'),
@@ -589,7 +589,7 @@ def initialize_database():
config = global_config.database config = global_config.database
# 配置引擎参数 # 配置引擎参数
engine_kwargs = { engine_kwargs: Dict[str, Any] = {
'echo': False, # 生产环境关闭SQL日志 'echo': False, # 生产环境关闭SQL日志
'future': True, 'future': True,
} }
@@ -642,7 +642,9 @@ def get_db_session() -> Iterator[Session]:
"""数据库会话上下文管理器 - 推荐使用这个而不是get_session()""" """数据库会话上下文管理器 - 推荐使用这个而不是get_session()"""
session: Optional[Session] = None session: Optional[Session] = None
try: try:
_, SessionLocal = initialize_database() engine, SessionLocal = initialize_database()
if not SessionLocal:
raise RuntimeError("Database session not initialized")
session = SessionLocal() session = SessionLocal()
yield session yield session
#session.commit() #session.commit()

View File

@@ -263,11 +263,20 @@ class NormalChatConfig(ValidatedConfigBase):
class ExpressionRule(ValidatedConfigBase):
"""表达学习规则"""
chat_stream_id: str = Field(..., description="聊天流ID空字符串表示全局")
use_expression: bool = Field(default=True, description="是否使用学到的表达")
learn_expression: bool = Field(default=True, description="是否学习表达")
learning_strength: float = Field(default=1.0, description="学习强度")
group: Optional[str] = Field(default=None, description="表达共享组")
class ExpressionConfig(ValidatedConfigBase): class ExpressionConfig(ValidatedConfigBase):
"""表达配置类""" """表达配置类"""
expression_learning: list[list] = Field(default_factory=lambda: [], description="表达学习") rules: List[ExpressionRule] = Field(default_factory=list, description="表达学习规则")
expression_groups: list[list[str]] = Field(default_factory=list, description="表达组")
def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> Optional[str]: def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> Optional[str]:
""" """
@@ -314,86 +323,23 @@ class ExpressionConfig(ValidatedConfigBase):
Returns: Returns:
tuple: (是否使用表达, 是否学习表达, 学习间隔) tuple: (是否使用表达, 是否学习表达, 学习间隔)
""" """
if not self.expression_learning: if not self.rules:
# 如果没有配置,使用默认值:启用表达,启用学习,300秒间隔 # 如果没有配置,使用默认值:启用表达,启用学习,强度1.0
return True, True, 300 return True, True, 1.0
# 优先检查聊天流特定的配置 # 优先检查聊天流特定的配置
if chat_stream_id: if chat_stream_id:
specific_config = self._get_stream_specific_config(chat_stream_id) for rule in self.rules:
if specific_config is not None: if rule.chat_stream_id and self._parse_stream_config_to_chat_id(rule.chat_stream_id) == chat_stream_id:
return specific_config return rule.use_expression, rule.learn_expression, rule.learning_strength
# 检查全局配置(第一个元素为空字符串的配置) # 检查全局配置(chat_stream_id为空字符串的配置)
global_config = self._get_global_config() for rule in self.rules:
if global_config is not None: if rule.chat_stream_id == "":
return global_config return rule.use_expression, rule.learn_expression, rule.learning_strength
# 如果都没有匹配,返回默认值 # 如果都没有匹配,返回默认值
return True, True, 300 return True, True, 1.0
def _get_stream_specific_config(self, chat_stream_id: str) -> Optional[tuple[bool, bool, float]]:
"""
获取特定聊天流的表达配置
Args:
chat_stream_id: 聊天流ID哈希值
Returns:
tuple: (是否使用表达, 是否学习表达, 学习间隔),如果没有配置则返回 None
"""
for config_item in self.expression_learning:
if not config_item or len(config_item) < 4:
continue
stream_config_str = config_item[0] # 例如 "qq:1026294844:group"
# 如果是空字符串,跳过(这是全局配置)
if stream_config_str == "":
continue
# 解析配置字符串并生成对应的 chat_id
config_chat_id = self._parse_stream_config_to_chat_id(stream_config_str)
if config_chat_id is None:
continue
# 比较生成的 chat_id
if config_chat_id != chat_stream_id:
continue
# 解析配置
try:
use_expression = config_item[1].lower() == "enable"
enable_learning = config_item[2].lower() == "enable"
learning_intensity = float(config_item[3])
return use_expression, enable_learning, learning_intensity
except (ValueError, IndexError):
continue
return None
def _get_global_config(self) -> Optional[tuple[bool, bool, float]]:
"""
获取全局表达配置
Returns:
tuple: (是否使用表达, 是否学习表达, 学习间隔),如果没有配置则返回 None
"""
for config_item in self.expression_learning:
if not config_item or len(config_item) < 4:
continue
# 检查是否为全局配置(第一个元素为空字符串)
if config_item[0] == "":
try:
use_expression = config_item[1].lower() == "enable"
enable_learning = config_item[2].lower() == "enable"
learning_intensity = float(config_item[3])
return use_expression, enable_learning, learning_intensity
except (ValueError, IndexError):
continue
return None
class ToolHistoryConfig(ValidatedConfigBase): class ToolHistoryConfig(ValidatedConfigBase):

View File

@@ -44,9 +44,7 @@ connection_timeout = 10 # 连接超时时间(秒)
# Master用户配置拥有最高权限无视所有权限节点 # Master用户配置拥有最高权限无视所有权限节点
# 格式:[[platform, user_id], ...] # 格式:[[platform, user_id], ...]
# 示例:[["qq", "123456"], ["telegram", "user789"]] # 示例:[["qq", "123456"], ["telegram", "user789"]]
master_users = [ master_users = []# ["qq", "123456789"], # 示例QQ平台的Master用户
# ["qq", "123456789"], # 示例QQ平台的Master用户
]
[bot] [bot]
platform = "qq" platform = "qq"
@@ -74,23 +72,37 @@ compress_identity = true # 是否压缩身份,压缩后会精简身份信息
[expression] [expression]
# 表达学习配置 # 表达学习配置
expression_learning = [ # 表达学习配置列表,支持按聊天流配置 # rules是一个列表每个元素都是一个学习规则
["", "enable", "enable", 1.0], # 全局配置使用表达启用学习学习强度1.0 # chat_stream_id: 聊天流ID格式为 "platform:id:type",例如 "qq:123456:private"。空字符串""表示全局配置
["qq:1919810:group", "enable", "enable", 1.5], # 特定群聊配置使用表达启用学习学习强度1.5 # use_expression: 是否使用学到的表达 (true/false)
["qq:114514:private", "enable", "disable", 0.5], # 特定私聊配置使用表达禁用学习学习强度0.5 # learn_expression: 是否学习新的表达 (true/false)
# 格式说明: # learning_strength: 学习强度(浮点数),影响学习频率
# 第一位: chat_stream_id空字符串表示全局配置 # group: 表达共享组的名称(字符串),相同组的聊天会共享学习到的表达方式
# 第二位: 是否使用学到的表达 ("enable"/"disable") [[expression.rules]]
# 第三位: 是否学习表达 ("enable"/"disable") chat_stream_id = ""
# 第四位: 学习强度(浮点数),影响学习频率,最短学习时间间隔 = 300/学习强度(秒) use_expression = true
# 学习强度越高,学习越频繁;学习强度越低,学习越少 learn_expression = true
] learning_strength = 1.0
expression_groups = [ [[expression.rules]]
["qq:1919810:private","qq:114514:private","qq:1111111:group"], # 在这里设置互通组相同组的chat_id会共享学习到的表达方式 chat_stream_id = "qq:1919810:group"
# 格式:["qq:123456:private","qq:654321:group"] use_expression = true
# 注意如果为群聊则需要设置为group如果设置为私聊则需要设置为private learn_expression = true
] learning_strength = 1.5
[[expression.rules]]
chat_stream_id = "qq:114514:private"
group = "group_A"
use_expression = true
learn_expression = false
learning_strength = 0.5
[[expression.rules]]
chat_stream_id = "qq:1919810:private"
group = "group_A"
use_expression = true
learn_expression = true
learning_strength = 1.0