feat(expression): 重构表达学习配置,引入基于规则的结构化定义

本次更新旨在提高表达学习配置的可读性和灵活性。旧的基于列表的 `expression_learning` 和 `expression_groups` 配置方式难以理解且容易出错。

通过引入新的 `ExpressionRule` Pydantic模型,我们将所有相关设置(如聊天流ID、是否学习、学习强度、共享组等)整合到一个统一的、自描述的结构中。现在,所有规则都在一个新的 `[[expression.rules]]` 表中进行配置,使得逻辑更加清晰和易于维护。

相关模块,如 `ExpressionSelector`,已更新以适配新的配置结构。同时,数据库中的 `Expression` 模型也已更新为现代的 SQLAlchemy 2.0 风格。

BREAKING CHANGE: 表达学习的配置文件格式已完全改变。旧的 `expression_learning` 和 `expression_groups` 配置不再受支持,用户需要根据新的 `bot_config_template.toml` 文件迁移到 `[[expression.rules]]` 格式。
This commit is contained in:
minecraft1024a
2025-08-27 21:24:12 +08:00
parent 17e8755e66
commit eb469240d4
4 changed files with 88 additions and 117 deletions

View File

@@ -114,16 +114,27 @@ class ExpressionSelector:
return None
def get_related_chat_ids(self, chat_id: str) -> List[str]:
"""根据expression_groups配置获取与当前chat_id相关的所有chat_id包括自身"""
groups = global_config.expression.expression_groups
for group in groups:
group_chat_ids = []
for stream_config_str in group:
if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
group_chat_ids.append(chat_id_candidate)
if chat_id in group_chat_ids:
return group_chat_ids
return [chat_id]
"""根据expression.rules配置获取与当前chat_id相关的所有chat_id包括自身"""
rules = global_config.expression.rules
current_group = None
# 找到当前chat_id所在的组
for rule in rules:
if rule.chat_stream_id and self._parse_stream_config_to_chat_id(rule.chat_stream_id) == chat_id:
current_group = rule.group
break
if not current_group:
return [chat_id]
# 找出同一组的所有chat_id
related_chat_ids = []
for rule in rules:
if rule.group == current_group and rule.chat_stream_id:
if chat_id_candidate := self._parse_stream_config_to_chat_id(rule.chat_stream_id):
related_chat_ids.append(chat_id_candidate)
return related_chat_ids if related_chat_ids else [chat_id]
def get_random_expressions(
self, chat_id: str, total_num: int, style_percentage: float, grammar_percentage: float

View File

@@ -5,12 +5,12 @@
from sqlalchemy import Column, String, Float, Integer, Boolean, Text, Index, create_engine, DateTime
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.orm import sessionmaker, Session, Mapped, mapped_column
from sqlalchemy.pool import QueuePool
import os
import datetime
import time
from typing import Iterator, Optional
from typing import Iterator, Optional, Any, Dict
from src.common.logger import get_logger
from contextlib import contextmanager
@@ -306,14 +306,14 @@ class Expression(Base):
"""表达风格模型"""
__tablename__ = 'expression'
id = Column(Integer, primary_key=True, autoincrement=True)
situation = Column(Text, nullable=False)
style = Column(Text, nullable=False)
count = Column(Float, nullable=False)
last_active_time = Column(Float, nullable=False)
chat_id = Column(get_string_field(64), nullable=False, index=True)
type = Column(Text, nullable=False)
create_date = Column(Float, nullable=True)
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
situation: Mapped[str] = mapped_column(Text, nullable=False)
style: Mapped[str] = mapped_column(Text, nullable=False)
count: Mapped[float] = mapped_column(Float, nullable=False)
last_active_time: Mapped[float] = mapped_column(Float, nullable=False)
chat_id: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True)
type: Mapped[str] = mapped_column(Text, nullable=False)
create_date: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
__table_args__ = (
Index('idx_expression_chat_id', 'chat_id'),
@@ -589,7 +589,7 @@ def initialize_database():
config = global_config.database
# 配置引擎参数
engine_kwargs = {
engine_kwargs: Dict[str, Any] = {
'echo': False, # 生产环境关闭SQL日志
'future': True,
}
@@ -642,7 +642,9 @@ def get_db_session() -> Iterator[Session]:
"""数据库会话上下文管理器 - 推荐使用这个而不是get_session()"""
session: Optional[Session] = None
try:
_, SessionLocal = initialize_database()
engine, SessionLocal = initialize_database()
if not SessionLocal:
raise RuntimeError("Database session not initialized")
session = SessionLocal()
yield session
#session.commit()

View File

@@ -263,11 +263,20 @@ class NormalChatConfig(ValidatedConfigBase):
class ExpressionRule(ValidatedConfigBase):
"""表达学习规则"""
chat_stream_id: str = Field(..., description="聊天流ID空字符串表示全局")
use_expression: bool = Field(default=True, description="是否使用学到的表达")
learn_expression: bool = Field(default=True, description="是否学习表达")
learning_strength: float = Field(default=1.0, description="学习强度")
group: Optional[str] = Field(default=None, description="表达共享组")
class ExpressionConfig(ValidatedConfigBase):
"""表达配置类"""
expression_learning: list[list] = Field(default_factory=lambda: [], description="表达学习")
expression_groups: list[list[str]] = Field(default_factory=list, description="表达组")
rules: List[ExpressionRule] = Field(default_factory=list, description="表达学习规则")
def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> Optional[str]:
"""
@@ -314,86 +323,23 @@ class ExpressionConfig(ValidatedConfigBase):
Returns:
tuple: (是否使用表达, 是否学习表达, 学习间隔)
"""
if not self.expression_learning:
# 如果没有配置,使用默认值:启用表达,启用学习,300秒间隔
return True, True, 300
if not self.rules:
# 如果没有配置,使用默认值:启用表达,启用学习,强度1.0
return True, True, 1.0
# 优先检查聊天流特定的配置
if chat_stream_id:
specific_config = self._get_stream_specific_config(chat_stream_id)
if specific_config is not None:
return specific_config
for rule in self.rules:
if rule.chat_stream_id and self._parse_stream_config_to_chat_id(rule.chat_stream_id) == chat_stream_id:
return rule.use_expression, rule.learn_expression, rule.learning_strength
# 检查全局配置(第一个元素为空字符串的配置)
global_config = self._get_global_config()
if global_config is not None:
return global_config
# 检查全局配置(chat_stream_id为空字符串的配置)
for rule in self.rules:
if rule.chat_stream_id == "":
return rule.use_expression, rule.learn_expression, rule.learning_strength
# 如果都没有匹配,返回默认值
return True, True, 300
def _get_stream_specific_config(self, chat_stream_id: str) -> Optional[tuple[bool, bool, float]]:
"""
获取特定聊天流的表达配置
Args:
chat_stream_id: 聊天流ID哈希值
Returns:
tuple: (是否使用表达, 是否学习表达, 学习间隔),如果没有配置则返回 None
"""
for config_item in self.expression_learning:
if not config_item or len(config_item) < 4:
continue
stream_config_str = config_item[0] # 例如 "qq:1026294844:group"
# 如果是空字符串,跳过(这是全局配置)
if stream_config_str == "":
continue
# 解析配置字符串并生成对应的 chat_id
config_chat_id = self._parse_stream_config_to_chat_id(stream_config_str)
if config_chat_id is None:
continue
# 比较生成的 chat_id
if config_chat_id != chat_stream_id:
continue
# 解析配置
try:
use_expression = config_item[1].lower() == "enable"
enable_learning = config_item[2].lower() == "enable"
learning_intensity = float(config_item[3])
return use_expression, enable_learning, learning_intensity
except (ValueError, IndexError):
continue
return None
def _get_global_config(self) -> Optional[tuple[bool, bool, float]]:
"""
获取全局表达配置
Returns:
tuple: (是否使用表达, 是否学习表达, 学习间隔),如果没有配置则返回 None
"""
for config_item in self.expression_learning:
if not config_item or len(config_item) < 4:
continue
# 检查是否为全局配置(第一个元素为空字符串)
if config_item[0] == "":
try:
use_expression = config_item[1].lower() == "enable"
enable_learning = config_item[2].lower() == "enable"
learning_intensity = float(config_item[3])
return use_expression, enable_learning, learning_intensity
except (ValueError, IndexError):
continue
return None
return True, True, 1.0
class ToolHistoryConfig(ValidatedConfigBase):

View File

@@ -44,9 +44,7 @@ connection_timeout = 10 # 连接超时时间(秒)
# Master用户配置拥有最高权限无视所有权限节点
# 格式:[[platform, user_id], ...]
# 示例:[["qq", "123456"], ["telegram", "user789"]]
master_users = [
# ["qq", "123456789"], # 示例QQ平台的Master用户
]
master_users = []# ["qq", "123456789"], # 示例QQ平台的Master用户
[bot]
platform = "qq"
@@ -74,23 +72,37 @@ compress_identity = true # 是否压缩身份,压缩后会精简身份信息
[expression]
# 表达学习配置
expression_learning = [ # 表达学习配置列表,支持按聊天流配置
["", "enable", "enable", 1.0], # 全局配置使用表达启用学习学习强度1.0
["qq:1919810:group", "enable", "enable", 1.5], # 特定群聊配置使用表达启用学习学习强度1.5
["qq:114514:private", "enable", "disable", 0.5], # 特定私聊配置使用表达禁用学习学习强度0.5
# 格式说明:
# 第一位: chat_stream_id空字符串表示全局配置
# 第二位: 是否使用学到的表达 ("enable"/"disable")
# 第三位: 是否学习表达 ("enable"/"disable")
# 第四位: 学习强度(浮点数),影响学习频率,最短学习时间间隔 = 300/学习强度(秒)
# 学习强度越高,学习越频繁;学习强度越低,学习越少
]
# rules是一个列表每个元素都是一个学习规则
# chat_stream_id: 聊天流ID格式为 "platform:id:type",例如 "qq:123456:private"。空字符串""表示全局配置
# use_expression: 是否使用学到的表达 (true/false)
# learn_expression: 是否学习新的表达 (true/false)
# learning_strength: 学习强度(浮点数),影响学习频率
# group: 表达共享组的名称(字符串),相同组的聊天会共享学习到的表达方式
[[expression.rules]]
chat_stream_id = ""
use_expression = true
learn_expression = true
learning_strength = 1.0
expression_groups = [
["qq:1919810:private","qq:114514:private","qq:1111111:group"], # 在这里设置互通组相同组的chat_id会共享学习到的表达方式
# 格式:["qq:123456:private","qq:654321:group"]
# 注意如果为群聊则需要设置为group如果设置为私聊则需要设置为private
]
[[expression.rules]]
chat_stream_id = "qq:1919810:group"
use_expression = true
learn_expression = true
learning_strength = 1.5
[[expression.rules]]
chat_stream_id = "qq:114514:private"
group = "group_A"
use_expression = true
learn_expression = false
learning_strength = 0.5
[[expression.rules]]
chat_stream_id = "qq:1919810:private"
group = "group_A"
use_expression = true
learn_expression = true
learning_strength = 1.0