feat(expression): 重构表达学习配置,引入基于规则的结构化定义
本次更新旨在提高表达学习配置的可读性和灵活性。旧的基于列表的 `expression_learning` 和 `expression_groups` 配置方式难以理解且容易出错。 通过引入新的 `ExpressionRule` Pydantic模型,我们将所有相关设置(如聊天流ID、是否学习、学习强度、共享组等)整合到一个统一的、自描述的结构中。现在,所有规则都在一个新的 `[[expression.rules]]` 表中进行配置,使得逻辑更加清晰和易于维护。 相关模块,如 `ExpressionSelector`,已更新以适配新的配置结构。同时,数据库中的 `Expression` 模型也已更新为现代的 SQLAlchemy 2.0 风格。 BREAKING CHANGE: 表达学习的配置文件格式已完全改变。旧的 `expression_learning` 和 `expression_groups` 配置不再受支持,用户需要根据新的 `bot_config_template.toml` 文件迁移到 `[[expression.rules]]` 格式。
This commit is contained in:
@@ -114,16 +114,27 @@ class ExpressionSelector:
|
||||
return None
|
||||
|
||||
def get_related_chat_ids(self, chat_id: str) -> List[str]:
|
||||
"""根据expression_groups配置,获取与当前chat_id相关的所有chat_id(包括自身)"""
|
||||
groups = global_config.expression.expression_groups
|
||||
for group in groups:
|
||||
group_chat_ids = []
|
||||
for stream_config_str in group:
|
||||
if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
|
||||
group_chat_ids.append(chat_id_candidate)
|
||||
if chat_id in group_chat_ids:
|
||||
return group_chat_ids
|
||||
return [chat_id]
|
||||
"""根据expression.rules配置,获取与当前chat_id相关的所有chat_id(包括自身)"""
|
||||
rules = global_config.expression.rules
|
||||
current_group = None
|
||||
|
||||
# 找到当前chat_id所在的组
|
||||
for rule in rules:
|
||||
if rule.chat_stream_id and self._parse_stream_config_to_chat_id(rule.chat_stream_id) == chat_id:
|
||||
current_group = rule.group
|
||||
break
|
||||
|
||||
if not current_group:
|
||||
return [chat_id]
|
||||
|
||||
# 找出同一组的所有chat_id
|
||||
related_chat_ids = []
|
||||
for rule in rules:
|
||||
if rule.group == current_group and rule.chat_stream_id:
|
||||
if chat_id_candidate := self._parse_stream_config_to_chat_id(rule.chat_stream_id):
|
||||
related_chat_ids.append(chat_id_candidate)
|
||||
|
||||
return related_chat_ids if related_chat_ids else [chat_id]
|
||||
|
||||
def get_random_expressions(
|
||||
self, chat_id: str, total_num: int, style_percentage: float, grammar_percentage: float
|
||||
|
||||
@@ -5,12 +5,12 @@
|
||||
|
||||
from sqlalchemy import Column, String, Float, Integer, Boolean, Text, Index, create_engine, DateTime
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker, Session
|
||||
from sqlalchemy.orm import sessionmaker, Session, Mapped, mapped_column
|
||||
from sqlalchemy.pool import QueuePool
|
||||
import os
|
||||
import datetime
|
||||
import time
|
||||
from typing import Iterator, Optional
|
||||
from typing import Iterator, Optional, Any, Dict
|
||||
from src.common.logger import get_logger
|
||||
from contextlib import contextmanager
|
||||
|
||||
@@ -306,14 +306,14 @@ class Expression(Base):
|
||||
"""表达风格模型"""
|
||||
__tablename__ = 'expression'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
situation = Column(Text, nullable=False)
|
||||
style = Column(Text, nullable=False)
|
||||
count = Column(Float, nullable=False)
|
||||
last_active_time = Column(Float, nullable=False)
|
||||
chat_id = Column(get_string_field(64), nullable=False, index=True)
|
||||
type = Column(Text, nullable=False)
|
||||
create_date = Column(Float, nullable=True)
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
situation: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
style: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
count: Mapped[float] = mapped_column(Float, nullable=False)
|
||||
last_active_time: Mapped[float] = mapped_column(Float, nullable=False)
|
||||
chat_id: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True)
|
||||
type: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
create_date: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_expression_chat_id', 'chat_id'),
|
||||
@@ -589,7 +589,7 @@ def initialize_database():
|
||||
config = global_config.database
|
||||
|
||||
# 配置引擎参数
|
||||
engine_kwargs = {
|
||||
engine_kwargs: Dict[str, Any] = {
|
||||
'echo': False, # 生产环境关闭SQL日志
|
||||
'future': True,
|
||||
}
|
||||
@@ -642,7 +642,9 @@ def get_db_session() -> Iterator[Session]:
|
||||
"""数据库会话上下文管理器 - 推荐使用这个而不是get_session()"""
|
||||
session: Optional[Session] = None
|
||||
try:
|
||||
_, SessionLocal = initialize_database()
|
||||
engine, SessionLocal = initialize_database()
|
||||
if not SessionLocal:
|
||||
raise RuntimeError("Database session not initialized")
|
||||
session = SessionLocal()
|
||||
yield session
|
||||
#session.commit()
|
||||
|
||||
@@ -263,11 +263,20 @@ class NormalChatConfig(ValidatedConfigBase):
|
||||
|
||||
|
||||
|
||||
class ExpressionRule(ValidatedConfigBase):
|
||||
"""表达学习规则"""
|
||||
|
||||
chat_stream_id: str = Field(..., description="聊天流ID,空字符串表示全局")
|
||||
use_expression: bool = Field(default=True, description="是否使用学到的表达")
|
||||
learn_expression: bool = Field(default=True, description="是否学习表达")
|
||||
learning_strength: float = Field(default=1.0, description="学习强度")
|
||||
group: Optional[str] = Field(default=None, description="表达共享组")
|
||||
|
||||
|
||||
class ExpressionConfig(ValidatedConfigBase):
|
||||
"""表达配置类"""
|
||||
|
||||
expression_learning: list[list] = Field(default_factory=lambda: [], description="表达学习")
|
||||
expression_groups: list[list[str]] = Field(default_factory=list, description="表达组")
|
||||
rules: List[ExpressionRule] = Field(default_factory=list, description="表达学习规则")
|
||||
|
||||
def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> Optional[str]:
|
||||
"""
|
||||
@@ -314,86 +323,23 @@ class ExpressionConfig(ValidatedConfigBase):
|
||||
Returns:
|
||||
tuple: (是否使用表达, 是否学习表达, 学习间隔)
|
||||
"""
|
||||
if not self.expression_learning:
|
||||
# 如果没有配置,使用默认值:启用表达,启用学习,300秒间隔
|
||||
return True, True, 300
|
||||
if not self.rules:
|
||||
# 如果没有配置,使用默认值:启用表达,启用学习,强度1.0
|
||||
return True, True, 1.0
|
||||
|
||||
# 优先检查聊天流特定的配置
|
||||
if chat_stream_id:
|
||||
specific_config = self._get_stream_specific_config(chat_stream_id)
|
||||
if specific_config is not None:
|
||||
return specific_config
|
||||
for rule in self.rules:
|
||||
if rule.chat_stream_id and self._parse_stream_config_to_chat_id(rule.chat_stream_id) == chat_stream_id:
|
||||
return rule.use_expression, rule.learn_expression, rule.learning_strength
|
||||
|
||||
# 检查全局配置(第一个元素为空字符串的配置)
|
||||
global_config = self._get_global_config()
|
||||
if global_config is not None:
|
||||
return global_config
|
||||
# 检查全局配置(chat_stream_id为空字符串的配置)
|
||||
for rule in self.rules:
|
||||
if rule.chat_stream_id == "":
|
||||
return rule.use_expression, rule.learn_expression, rule.learning_strength
|
||||
|
||||
# 如果都没有匹配,返回默认值
|
||||
return True, True, 300
|
||||
|
||||
def _get_stream_specific_config(self, chat_stream_id: str) -> Optional[tuple[bool, bool, float]]:
|
||||
"""
|
||||
获取特定聊天流的表达配置
|
||||
|
||||
Args:
|
||||
chat_stream_id: 聊天流ID(哈希值)
|
||||
|
||||
Returns:
|
||||
tuple: (是否使用表达, 是否学习表达, 学习间隔),如果没有配置则返回 None
|
||||
"""
|
||||
for config_item in self.expression_learning:
|
||||
if not config_item or len(config_item) < 4:
|
||||
continue
|
||||
|
||||
stream_config_str = config_item[0] # 例如 "qq:1026294844:group"
|
||||
|
||||
# 如果是空字符串,跳过(这是全局配置)
|
||||
if stream_config_str == "":
|
||||
continue
|
||||
|
||||
# 解析配置字符串并生成对应的 chat_id
|
||||
config_chat_id = self._parse_stream_config_to_chat_id(stream_config_str)
|
||||
if config_chat_id is None:
|
||||
continue
|
||||
|
||||
# 比较生成的 chat_id
|
||||
if config_chat_id != chat_stream_id:
|
||||
continue
|
||||
|
||||
# 解析配置
|
||||
try:
|
||||
use_expression = config_item[1].lower() == "enable"
|
||||
enable_learning = config_item[2].lower() == "enable"
|
||||
learning_intensity = float(config_item[3])
|
||||
return use_expression, enable_learning, learning_intensity
|
||||
except (ValueError, IndexError):
|
||||
continue
|
||||
|
||||
return None
|
||||
|
||||
def _get_global_config(self) -> Optional[tuple[bool, bool, float]]:
|
||||
"""
|
||||
获取全局表达配置
|
||||
|
||||
Returns:
|
||||
tuple: (是否使用表达, 是否学习表达, 学习间隔),如果没有配置则返回 None
|
||||
"""
|
||||
for config_item in self.expression_learning:
|
||||
if not config_item or len(config_item) < 4:
|
||||
continue
|
||||
|
||||
# 检查是否为全局配置(第一个元素为空字符串)
|
||||
if config_item[0] == "":
|
||||
try:
|
||||
use_expression = config_item[1].lower() == "enable"
|
||||
enable_learning = config_item[2].lower() == "enable"
|
||||
learning_intensity = float(config_item[3])
|
||||
return use_expression, enable_learning, learning_intensity
|
||||
except (ValueError, IndexError):
|
||||
continue
|
||||
|
||||
return None
|
||||
return True, True, 1.0
|
||||
|
||||
|
||||
class ToolHistoryConfig(ValidatedConfigBase):
|
||||
|
||||
@@ -44,9 +44,7 @@ connection_timeout = 10 # 连接超时时间(秒)
|
||||
# Master用户配置(拥有最高权限,无视所有权限节点)
|
||||
# 格式:[[platform, user_id], ...]
|
||||
# 示例:[["qq", "123456"], ["telegram", "user789"]]
|
||||
master_users = [
|
||||
# ["qq", "123456789"], # 示例:QQ平台的Master用户
|
||||
]
|
||||
master_users = []# ["qq", "123456789"], # 示例:QQ平台的Master用户
|
||||
|
||||
[bot]
|
||||
platform = "qq"
|
||||
@@ -74,23 +72,37 @@ compress_identity = true # 是否压缩身份,压缩后会精简身份信息
|
||||
|
||||
[expression]
|
||||
# 表达学习配置
|
||||
expression_learning = [ # 表达学习配置列表,支持按聊天流配置
|
||||
["", "enable", "enable", 1.0], # 全局配置:使用表达,启用学习,学习强度1.0
|
||||
["qq:1919810:group", "enable", "enable", 1.5], # 特定群聊配置:使用表达,启用学习,学习强度1.5
|
||||
["qq:114514:private", "enable", "disable", 0.5], # 特定私聊配置:使用表达,禁用学习,学习强度0.5
|
||||
# 格式说明:
|
||||
# 第一位: chat_stream_id,空字符串表示全局配置
|
||||
# 第二位: 是否使用学到的表达 ("enable"/"disable")
|
||||
# 第三位: 是否学习表达 ("enable"/"disable")
|
||||
# 第四位: 学习强度(浮点数),影响学习频率,最短学习时间间隔 = 300/学习强度(秒)
|
||||
# 学习强度越高,学习越频繁;学习强度越低,学习越少
|
||||
]
|
||||
# rules是一个列表,每个元素都是一个学习规则
|
||||
# chat_stream_id: 聊天流ID,格式为 "platform:id:type",例如 "qq:123456:private"。空字符串""表示全局配置
|
||||
# use_expression: 是否使用学到的表达 (true/false)
|
||||
# learn_expression: 是否学习新的表达 (true/false)
|
||||
# learning_strength: 学习强度(浮点数),影响学习频率
|
||||
# group: 表达共享组的名称(字符串),相同组的聊天会共享学习到的表达方式
|
||||
[[expression.rules]]
|
||||
chat_stream_id = ""
|
||||
use_expression = true
|
||||
learn_expression = true
|
||||
learning_strength = 1.0
|
||||
|
||||
expression_groups = [
|
||||
["qq:1919810:private","qq:114514:private","qq:1111111:group"], # 在这里设置互通组,相同组的chat_id会共享学习到的表达方式
|
||||
# 格式:["qq:123456:private","qq:654321:group"]
|
||||
# 注意:如果为群聊,则需要设置为group,如果设置为私聊,则需要设置为private
|
||||
]
|
||||
[[expression.rules]]
|
||||
chat_stream_id = "qq:1919810:group"
|
||||
use_expression = true
|
||||
learn_expression = true
|
||||
learning_strength = 1.5
|
||||
|
||||
[[expression.rules]]
|
||||
chat_stream_id = "qq:114514:private"
|
||||
group = "group_A"
|
||||
use_expression = true
|
||||
learn_expression = false
|
||||
learning_strength = 0.5
|
||||
|
||||
[[expression.rules]]
|
||||
chat_stream_id = "qq:1919810:private"
|
||||
group = "group_A"
|
||||
use_expression = true
|
||||
learn_expression = true
|
||||
learning_strength = 1.0
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user