typing
This commit is contained in:
@@ -17,7 +17,7 @@ from src.config.config_base import ConfigBase
|
|||||||
@dataclass
|
@dataclass
|
||||||
class BotConfig(ConfigBase):
|
class BotConfig(ConfigBase):
|
||||||
"""QQ机器人配置类"""
|
"""QQ机器人配置类"""
|
||||||
|
|
||||||
platform: str
|
platform: str
|
||||||
"""平台"""
|
"""平台"""
|
||||||
|
|
||||||
@@ -43,7 +43,7 @@ class PersonalityConfig(ConfigBase):
|
|||||||
|
|
||||||
identity: str = ""
|
identity: str = ""
|
||||||
"""身份特征"""
|
"""身份特征"""
|
||||||
|
|
||||||
reply_style: str = ""
|
reply_style: str = ""
|
||||||
"""表达风格"""
|
"""表达风格"""
|
||||||
|
|
||||||
@@ -71,7 +71,6 @@ class ChatConfig(ConfigBase):
|
|||||||
|
|
||||||
max_context_size: int = 18
|
max_context_size: int = 18
|
||||||
"""上下文长度"""
|
"""上下文长度"""
|
||||||
|
|
||||||
|
|
||||||
replyer_random_probability: float = 0.5
|
replyer_random_probability: float = 0.5
|
||||||
"""
|
"""
|
||||||
@@ -129,7 +128,7 @@ class ChatConfig(ConfigBase):
|
|||||||
"""
|
"""
|
||||||
if not self.talk_frequency_adjust:
|
if not self.talk_frequency_adjust:
|
||||||
return self.talk_frequency
|
return self.talk_frequency
|
||||||
|
|
||||||
# 优先检查聊天流特定的配置
|
# 优先检查聊天流特定的配置
|
||||||
if chat_stream_id:
|
if chat_stream_id:
|
||||||
stream_frequency = self._get_stream_specific_frequency(chat_stream_id)
|
stream_frequency = self._get_stream_specific_frequency(chat_stream_id)
|
||||||
@@ -138,11 +137,7 @@ class ChatConfig(ConfigBase):
|
|||||||
|
|
||||||
# 检查全局时段配置(第一个元素为空字符串的配置)
|
# 检查全局时段配置(第一个元素为空字符串的配置)
|
||||||
global_frequency = self._get_global_frequency()
|
global_frequency = self._get_global_frequency()
|
||||||
if global_frequency is not None:
|
return self.talk_frequency if global_frequency is None else global_frequency
|
||||||
return global_frequency
|
|
||||||
|
|
||||||
# 如果都没有匹配,返回默认值
|
|
||||||
return self.talk_frequency
|
|
||||||
|
|
||||||
def _get_time_based_frequency(self, time_freq_list: list[str]) -> Optional[float]:
|
def _get_time_based_frequency(self, time_freq_list: list[str]) -> Optional[float]:
|
||||||
"""
|
"""
|
||||||
@@ -294,6 +289,7 @@ class NormalChatConfig(ConfigBase):
|
|||||||
willing_mode: str = "classical"
|
willing_mode: str = "classical"
|
||||||
"""意愿模式"""
|
"""意愿模式"""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ExpressionConfig(ConfigBase):
|
class ExpressionConfig(ConfigBase):
|
||||||
"""表达配置类"""
|
"""表达配置类"""
|
||||||
@@ -326,10 +322,10 @@ class ExpressionConfig(ConfigBase):
|
|||||||
def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> Optional[str]:
|
def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
解析流配置字符串并生成对应的 chat_id
|
解析流配置字符串并生成对应的 chat_id
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
stream_config_str: 格式为 "platform:id:type" 的字符串
|
stream_config_str: 格式为 "platform:id:type" 的字符串
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: 生成的 chat_id,如果解析失败则返回 None
|
str: 生成的 chat_id,如果解析失败则返回 None
|
||||||
"""
|
"""
|
||||||
@@ -337,116 +333,116 @@ class ExpressionConfig(ConfigBase):
|
|||||||
parts = stream_config_str.split(":")
|
parts = stream_config_str.split(":")
|
||||||
if len(parts) != 3:
|
if len(parts) != 3:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
platform = parts[0]
|
platform = parts[0]
|
||||||
id_str = parts[1]
|
id_str = parts[1]
|
||||||
stream_type = parts[2]
|
stream_type = parts[2]
|
||||||
|
|
||||||
# 判断是否为群聊
|
# 判断是否为群聊
|
||||||
is_group = stream_type == "group"
|
is_group = stream_type == "group"
|
||||||
|
|
||||||
# 使用与 ChatStream.get_stream_id 相同的逻辑生成 chat_id
|
# 使用与 ChatStream.get_stream_id 相同的逻辑生成 chat_id
|
||||||
import hashlib
|
import hashlib
|
||||||
|
|
||||||
if is_group:
|
if is_group:
|
||||||
components = [platform, str(id_str)]
|
components = [platform, str(id_str)]
|
||||||
else:
|
else:
|
||||||
components = [platform, str(id_str), "private"]
|
components = [platform, str(id_str), "private"]
|
||||||
key = "_".join(components)
|
key = "_".join(components)
|
||||||
return hashlib.md5(key.encode()).hexdigest()
|
return hashlib.md5(key.encode()).hexdigest()
|
||||||
|
|
||||||
except (ValueError, IndexError):
|
except (ValueError, IndexError):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_expression_config_for_chat(self, chat_stream_id: Optional[str] = None) -> tuple[bool, bool, int]:
|
def get_expression_config_for_chat(self, chat_stream_id: Optional[str] = None) -> tuple[bool, bool, int]:
|
||||||
"""
|
"""
|
||||||
根据聊天流ID获取表达配置
|
根据聊天流ID获取表达配置
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
chat_stream_id: 聊天流ID,格式为哈希值
|
chat_stream_id: 聊天流ID,格式为哈希值
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple: (是否使用表达, 是否学习表达, 学习间隔)
|
tuple: (是否使用表达, 是否学习表达, 学习间隔)
|
||||||
"""
|
"""
|
||||||
if not self.expression_learning:
|
if not self.expression_learning:
|
||||||
# 如果没有配置,使用默认值:启用表达,启用学习,300秒间隔
|
# 如果没有配置,使用默认值:启用表达,启用学习,300秒间隔
|
||||||
return True, True, 300
|
return True, True, 300
|
||||||
|
|
||||||
# 优先检查聊天流特定的配置
|
# 优先检查聊天流特定的配置
|
||||||
if chat_stream_id:
|
if chat_stream_id:
|
||||||
specific_config = self._get_stream_specific_config(chat_stream_id)
|
specific_expression_config = self._get_stream_specific_config(chat_stream_id)
|
||||||
if specific_config is not None:
|
if specific_expression_config is not None:
|
||||||
return specific_config
|
return specific_expression_config
|
||||||
|
|
||||||
# 检查全局配置(第一个元素为空字符串的配置)
|
# 检查全局配置(第一个元素为空字符串的配置)
|
||||||
global_config = self._get_global_config()
|
global_expression_config = self._get_global_config()
|
||||||
if global_config is not None:
|
if global_expression_config is not None:
|
||||||
return global_config
|
return global_expression_config
|
||||||
|
|
||||||
# 如果都没有匹配,返回默认值
|
# 如果都没有匹配,返回默认值
|
||||||
return True, True, 300
|
return True, True, 300
|
||||||
|
|
||||||
def _get_stream_specific_config(self, chat_stream_id: str) -> Optional[tuple[bool, bool, int]]:
|
def _get_stream_specific_config(self, chat_stream_id: str) -> Optional[tuple[bool, bool, int]]:
|
||||||
"""
|
"""
|
||||||
获取特定聊天流的表达配置
|
获取特定聊天流的表达配置
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
chat_stream_id: 聊天流ID(哈希值)
|
chat_stream_id: 聊天流ID(哈希值)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple: (是否使用表达, 是否学习表达, 学习间隔),如果没有配置则返回 None
|
tuple: (是否使用表达, 是否学习表达, 学习间隔),如果没有配置则返回 None
|
||||||
"""
|
"""
|
||||||
for config_item in self.expression_learning:
|
for config_item in self.expression_learning:
|
||||||
if not config_item or len(config_item) < 4:
|
if not config_item or len(config_item) < 4:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
stream_config_str = config_item[0] # 例如 "qq:1026294844:group"
|
stream_config_str = config_item[0] # 例如 "qq:1026294844:group"
|
||||||
|
|
||||||
# 如果是空字符串,跳过(这是全局配置)
|
# 如果是空字符串,跳过(这是全局配置)
|
||||||
if stream_config_str == "":
|
if stream_config_str == "":
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 解析配置字符串并生成对应的 chat_id
|
# 解析配置字符串并生成对应的 chat_id
|
||||||
config_chat_id = self._parse_stream_config_to_chat_id(stream_config_str)
|
config_chat_id = self._parse_stream_config_to_chat_id(stream_config_str)
|
||||||
if config_chat_id is None:
|
if config_chat_id is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 比较生成的 chat_id
|
# 比较生成的 chat_id
|
||||||
if config_chat_id != chat_stream_id:
|
if config_chat_id != chat_stream_id:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 解析配置
|
# 解析配置
|
||||||
try:
|
try:
|
||||||
use_expression = config_item[1].lower() == "enable"
|
use_expression: bool = config_item[1].lower() == "enable"
|
||||||
enable_learning = config_item[2].lower() == "enable"
|
enable_learning: bool = config_item[2].lower() == "enable"
|
||||||
learning_intensity = float(config_item[3])
|
learning_intensity: float = float(config_item[3])
|
||||||
return use_expression, enable_learning, learning_intensity
|
return use_expression, enable_learning, learning_intensity # type: ignore
|
||||||
except (ValueError, IndexError):
|
except (ValueError, IndexError):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _get_global_config(self) -> Optional[tuple[bool, bool, int]]:
|
def _get_global_config(self) -> Optional[tuple[bool, bool, int]]:
|
||||||
"""
|
"""
|
||||||
获取全局表达配置
|
获取全局表达配置
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple: (是否使用表达, 是否学习表达, 学习间隔),如果没有配置则返回 None
|
tuple: (是否使用表达, 是否学习表达, 学习间隔),如果没有配置则返回 None
|
||||||
"""
|
"""
|
||||||
for config_item in self.expression_learning:
|
for config_item in self.expression_learning:
|
||||||
if not config_item or len(config_item) < 4:
|
if not config_item or len(config_item) < 4:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 检查是否为全局配置(第一个元素为空字符串)
|
# 检查是否为全局配置(第一个元素为空字符串)
|
||||||
if config_item[0] == "":
|
if config_item[0] == "":
|
||||||
try:
|
try:
|
||||||
use_expression = config_item[1].lower() == "enable"
|
use_expression: bool = config_item[1].lower() == "enable"
|
||||||
enable_learning = config_item[2].lower() == "enable"
|
enable_learning: bool = config_item[2].lower() == "enable"
|
||||||
learning_intensity = float(config_item[3])
|
learning_intensity = float(config_item[3])
|
||||||
return use_expression, enable_learning, learning_intensity
|
return use_expression, enable_learning, learning_intensity # type: ignore
|
||||||
except (ValueError, IndexError):
|
except (ValueError, IndexError):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@@ -456,7 +452,8 @@ class ToolConfig(ConfigBase):
|
|||||||
|
|
||||||
enable_tool: bool = False
|
enable_tool: bool = False
|
||||||
"""是否在聊天中启用工具"""
|
"""是否在聊天中启用工具"""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class VoiceConfig(ConfigBase):
|
class VoiceConfig(ConfigBase):
|
||||||
"""语音识别配置类"""
|
"""语音识别配置类"""
|
||||||
@@ -542,7 +539,7 @@ class MemoryConfig(ConfigBase):
|
|||||||
|
|
||||||
memory_ban_words: list[str] = field(default_factory=lambda: ["表情包", "图片", "回复", "聊天记录"])
|
memory_ban_words: list[str] = field(default_factory=lambda: ["表情包", "图片", "回复", "聊天记录"])
|
||||||
"""不允许记忆的词列表"""
|
"""不允许记忆的词列表"""
|
||||||
|
|
||||||
enable_instant_memory: bool = True
|
enable_instant_memory: bool = True
|
||||||
"""是否启用即时记忆"""
|
"""是否启用即时记忆"""
|
||||||
|
|
||||||
@@ -553,7 +550,7 @@ class MoodConfig(ConfigBase):
|
|||||||
|
|
||||||
enable_mood: bool = False
|
enable_mood: bool = False
|
||||||
"""是否启用情绪系统"""
|
"""是否启用情绪系统"""
|
||||||
|
|
||||||
mood_update_threshold: float = 1.0
|
mood_update_threshold: float = 1.0
|
||||||
"""情绪更新阈值,越高,更新越慢"""
|
"""情绪更新阈值,越高,更新越慢"""
|
||||||
|
|
||||||
@@ -604,6 +601,7 @@ class KeywordReactionConfig(ConfigBase):
|
|||||||
if not isinstance(rule, KeywordRuleConfig):
|
if not isinstance(rule, KeywordRuleConfig):
|
||||||
raise ValueError(f"规则必须是KeywordRuleConfig类型,而不是{type(rule).__name__}")
|
raise ValueError(f"规则必须是KeywordRuleConfig类型,而不是{type(rule).__name__}")
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CustomPromptConfig(ConfigBase):
|
class CustomPromptConfig(ConfigBase):
|
||||||
"""自定义提示词配置类"""
|
"""自定义提示词配置类"""
|
||||||
@@ -752,4 +750,3 @@ class LPMMKnowledgeConfig(ConfigBase):
|
|||||||
|
|
||||||
embedding_dimension: int = 1024
|
embedding_dimension: int = 1024
|
||||||
"""嵌入向量维度,应该与模型的输出维度一致"""
|
"""嵌入向量维度,应该与模型的输出维度一致"""
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user