Merge branch 'dev' of https://github.com/MoFox-Studio/MoFox_Bot into dev
This commit is contained in:
254
src/chat/express/express_utils.py
Normal file
254
src/chat/express/express_utils.py
Normal file
@@ -0,0 +1,254 @@
|
||||
"""
|
||||
表达系统工具函数
|
||||
提供消息过滤、文本相似度计算、加权随机抽样等功能
|
||||
"""
|
||||
import difflib
|
||||
import random
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("express_utils")
|
||||
|
||||
|
||||
def filter_message_content(content: Optional[str]) -> str:
    """
    Strip chat markup (reply headers, @-mentions, image and sticker tags)
    from a message, leaving only the plain text.

    Args:
        content: raw message content, possibly None or empty

    Returns:
        The cleaned, whitespace-trimmed text ("" for empty/None input).
    """
    if not content:
        return ""

    # Each pattern removes one kind of non-text markup.
    markup_patterns = (
        r"\[回复.*?\],说:\s*",  # reply header: [回复...],说:
        r"@<[^>]*>",              # mention: @<...>
        r"\[图片:[^\]]*\]",      # image tag: [图片:...]
        r"\[表情包:[^\]]*\]",    # sticker tag: [表情包:...]
    )
    cleaned = content
    for pattern in markup_patterns:
        cleaned = re.sub(pattern, "", cleaned)

    return cleaned.strip()
|
||||
|
||||
|
||||
def calculate_similarity(text1: str, text2: str) -> float:
    """
    Compute a similarity ratio between two strings.

    Args:
        text1: first string
        text2: second string

    Returns:
        A float in [0, 1]; 1.0 means the strings match exactly.
    """
    matcher = difflib.SequenceMatcher(None, text1, text2)
    return matcher.ratio()
|
||||
|
||||
|
||||
def weighted_sample(population: List[Dict], k: int, weight_key: Optional[str] = None) -> List[Dict]:
    """
    Weighted random sampling WITHOUT replacement.

    Args:
        population: candidate items
        k: number of items to draw
        weight_key: dict key holding the weight; None (or a key missing
            from any item) falls back to uniform sampling

    Returns:
        Up to k distinct items from population (all of them if len <= k).
    """
    if not population or k <= 0:
        return []

    if len(population) <= k:
        return population.copy()

    # Weighted path, only when every item carries the weight field.
    if weight_key and all(weight_key in item for item in population):
        try:
            weights = [float(item.get(weight_key, 1.0)) for item in population]
            # Fix: the original used random.choices(), which samples WITH
            # replacement and could return the same item several times —
            # inconsistent with the uniform path below (which pops winners).
            # Draw k items one at a time, removing each winner, so the
            # result always contains k distinct entries.
            pool = list(population)
            selected: List[Dict] = []
            for _ in range(k):
                total = sum(weights)
                if total <= 0:
                    # All remaining weights are zero → pick uniformly.
                    idx = random.randint(0, len(pool) - 1)
                else:
                    threshold = random.uniform(0, total)
                    cumulative = 0.0
                    idx = len(pool) - 1  # guard against float round-off
                    for i, w in enumerate(weights):
                        cumulative += w
                        if threshold <= cumulative:
                            idx = i
                            break
                selected.append(pool.pop(idx))
                weights.pop(idx)
            return selected
        except (ValueError, TypeError) as e:
            logger.warning(f"加权抽样失败,使用等概率抽样: {e}")

    # Uniform sampling without replacement.
    selected = []
    population_copy = population.copy()
    for _ in range(k):
        if not population_copy:
            break
        idx = random.randint(0, len(population_copy) - 1)
        selected.append(population_copy.pop(idx))

    return selected
|
||||
|
||||
|
||||
def normalize_text(text: str) -> str:
    """
    Collapse every run of whitespace into a single space and trim the ends.

    Args:
        text: input text

    Returns:
        The normalized text.
    """
    collapsed = re.sub(r"\s+", " ", text)
    return collapsed.strip()
|
||||
|
||||
|
||||
def extract_keywords(text: str, max_keywords: int = 10) -> List[str]:
    """
    Naive keyword extraction (frequency based).

    Uses jieba's TF-IDF extractor when jieba is installed; otherwise
    degrades to a plain whitespace split.

    Args:
        text: input text
        max_keywords: upper bound on the number of keywords returned

    Returns:
        A list of keyword strings (possibly empty).
    """
    if not text:
        return []

    try:
        import jieba.analyse
    except ImportError:
        # jieba is optional; fall back to splitting on whitespace.
        logger.warning("jieba未安装,无法提取关键词")
        return text.split()[:max_keywords]

    # TF-IDF based keyword extraction.
    return jieba.analyse.extract_tags(text, topK=max_keywords)
|
||||
|
||||
|
||||
def format_expression_pair(situation: str, style: str, index: Optional[int] = None) -> str:
    """
    Render a (situation, style) pair as a human-readable rule line.

    Args:
        situation: scenario description
        style: expression style used in that scenario
        index: optional ordinal prepended as "N. "

    Returns:
        e.g. 当"situation"时,使用"style", optionally numbered.
    """
    rule = f'当"{situation}"时,使用"{style}"'
    return rule if index is None else f"{index}. {rule}"
|
||||
|
||||
|
||||
def parse_expression_pair(text: str) -> Optional[tuple[str, str]]:
    """
    Parse a formatted expression-pair line back into its parts.

    Accepts the canonical format 当"..."时,使用"..." and, additionally,
    Chinese curly quotes ("...") and the 可以 variant — both of which the
    learning prompts allow the LLM to emit, so the round-trip no longer
    silently fails on them.

    Args:
        text: a formatted expression-pair line

    Returns:
        (situation, style) on success, otherwise None.
    """
    if not text:
        return None

    # Normalize curly quotes so one pattern handles both quote styles.
    normalized = text.replace("\u201c", '"').replace("\u201d", '"')

    # Format: 当"..."时,使用"..." or 当"..."时,可以"..."
    match = re.search(r'当"(.+?)"时,(?:使用|可以)"(.+?)"', normalized)
    if match:
        return match.group(1), match.group(2)
    return None
|
||||
|
||||
|
||||
def batch_filter_duplicates(expressions: List[Dict[str, Any]], key_fields: List[str]) -> List[Dict[str, Any]]:
    """
    Drop duplicate expressions, keeping the first occurrence of each key.

    Two expressions are duplicates when they agree on every field named in
    *key_fields* (a missing field compares as "").

    Args:
        expressions: expression dicts to deduplicate
        key_fields: field names that form the identity key

    Returns:
        The expressions with later duplicates removed, order preserved.
    """
    seen: set = set()
    deduped: List[Dict[str, Any]] = []

    for item in expressions:
        identity = tuple(item.get(field, "") for field in key_fields)
        if identity in seen:
            continue
        seen.add(identity)
        deduped.append(item)

    return deduped
|
||||
|
||||
|
||||
def calculate_time_weight(last_active_time: float, current_time: float, half_life_days: int = 30) -> float:
    """
    Exponential time-decay weight for how long ago an item was active.

    The weight halves every *half_life_days* days and is clamped to
    [0.01, 1.0]. A last-active time in the future yields 1.0.

    Args:
        last_active_time: last-active Unix timestamp (seconds)
        current_time: current Unix timestamp (seconds)
        half_life_days: number of days after which the weight halves

    Returns:
        A weight in [0.01, 1.0].
    """
    time_diff_days = (current_time - last_active_time) / 86400  # seconds -> days
    if time_diff_days < 0:
        return 1.0

    # Fix: the original computed 2 ** (-(ln2 / half_life) * t), mixing a
    # natural-log rate with a base-2 exponent, which gives ~0.62 (not 0.5)
    # after exactly one half-life. Plain base-2 decay yields the intended
    # half-life semantics: weight == 0.5 at t == half_life_days.
    weight = 2 ** (-time_diff_days / half_life_days)

    return max(0.01, min(1.0, weight))
|
||||
|
||||
|
||||
def merge_expressions_from_multiple_chats(
    expressions_dict: Dict[str, List[Dict[str, Any]]], max_total: int = 100
) -> List[Dict[str, Any]]:
    """
    Merge the expressions of several chats into one ranked list.

    Each expression is tagged with its originating chat via "source_id",
    the combined list is sorted by "count" (or, failing that, by
    "last_active_time") descending, deduplicated on (situation, style),
    and truncated to *max_total* entries.

    Args:
        expressions_dict: mapping of chat_id -> expression dicts
        max_total: maximum number of merged expressions returned

    Returns:
        The merged, deduplicated, size-limited expression list.
    """
    # Flatten all chats, tagging each copied expression with its source.
    merged: List[Dict[str, Any]] = [
        {**expr, "source_id": chat_id}
        for chat_id, exprs in expressions_dict.items()
        for expr in exprs
    ]

    # Rank by count when present, otherwise by recency (probe the first
    # element, matching the original heuristic).
    if merged and "count" in merged[0]:
        merged.sort(key=lambda x: x.get("count", 0), reverse=True)
    elif merged and "last_active_time" in merged[0]:
        merged.sort(key=lambda x: x.get("last_active_time", 0), reverse=True)

    # Deduplicate on (situation, style); first (highest-ranked) copy wins.
    merged = batch_filter_duplicates(merged, ["situation", "style"])

    return merged[:max_total]
|
||||
@@ -16,6 +16,9 @@ from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
# 导入 StyleLearner 管理器
|
||||
from .style_learner import style_learner_manager
|
||||
|
||||
MAX_EXPRESSION_COUNT = 300
|
||||
DECAY_DAYS = 30 # 30天衰减到0.01
|
||||
DECAY_MIN = 0.01 # 最小衰减值
|
||||
@@ -43,17 +46,29 @@ def init_prompt() -> None:
|
||||
3. 语言风格包含特殊内容和情感
|
||||
4. 思考有没有特殊的梗,一并总结成语言风格
|
||||
5. 例子仅供参考,请严格根据群聊内容总结!!!
|
||||
注意:总结成如下格式的规律,总结的内容要详细,但具有概括性:
|
||||
例如:当"AAAAA"时,可以"BBBBB", AAAAA代表某个具体的场景,不超过20个字。BBBBB代表对应的语言风格,特定句式或表达方式,不超过20个字。
|
||||
|
||||
**重要:必须严格按照以下格式输出,每行一条规律:**
|
||||
当"xxx"时,使用"xxx"
|
||||
|
||||
格式说明:
|
||||
- 必须以"当"开头
|
||||
- 场景描述用双引号包裹,不超过20个字
|
||||
- 必须包含"使用"或"可以"
|
||||
- 表达风格用双引号包裹,不超过20个字
|
||||
- 每条规律独占一行
|
||||
|
||||
例如:
|
||||
当"对某件事表示十分惊叹,有些意外"时,使用"我嘞个xxxx"
|
||||
当"表示讽刺的赞同,不想讲道理"时,使用"对对对"
|
||||
当"想说明某个具体的事实观点,但懒得明说,或者不便明说,或表达一种默契",使用"懂的都懂"
|
||||
当"当涉及游戏相关时,表示意外的夸赞,略带戏谑意味"时,使用"这么强!"
|
||||
当"想说明某个具体的事实观点,但懒得明说,或者不便明说,或表达一种默契"时,使用"懂的都懂"
|
||||
当"涉及游戏相关时,表示意外的夸赞,略带戏谑意味"时,使用"这么强!"
|
||||
|
||||
请注意:不要总结你自己(SELF)的发言
|
||||
现在请你概括
|
||||
注意:
|
||||
1. 不要总结你自己(SELF)的发言
|
||||
2. 如果聊天内容中没有明显的特殊风格,请只输出1-2条最明显的特点
|
||||
3. 不要输出其他解释性文字,只输出符合格式的规律
|
||||
|
||||
现在请你概括:
|
||||
"""
|
||||
Prompt(learn_style_prompt, "learn_style_prompt")
|
||||
|
||||
@@ -65,16 +80,28 @@ def init_prompt() -> None:
|
||||
2.不要涉及具体的人名,只考虑语法和句法特点,
|
||||
3.语法和句法特点要包括,句子长短(具体字数),有何种语病,如何拆分句子。
|
||||
4. 例子仅供参考,请严格根据群聊内容总结!!!
|
||||
总结成如下格式的规律,总结的内容要简洁,不浮夸:
|
||||
当"xxx"时,可以"xxx"
|
||||
|
||||
**重要:必须严格按照以下格式输出,每行一条规律:**
|
||||
当"xxx"时,使用"xxx"
|
||||
|
||||
格式说明:
|
||||
- 必须以"当"开头
|
||||
- 场景描述用双引号包裹
|
||||
- 必须包含"使用"或"可以"
|
||||
- 句法特点用双引号包裹
|
||||
- 每条规律独占一行
|
||||
|
||||
例如:
|
||||
当"表达观点较复杂"时,使用"省略主语(3-6个字)"的句法
|
||||
当"不用详细说明的一般表达"时,使用"非常简洁的句子"的句法
|
||||
当"需要单纯简单的确认"时,使用"单字或几个字的肯定(1-2个字)"的句法
|
||||
|
||||
注意不要总结你自己(SELF)的发言
|
||||
现在请你概括
|
||||
注意:
|
||||
1. 不要总结你自己(SELF)的发言
|
||||
2. 如果聊天内容中没有明显的句法特点,请只输出1-2条最明显的特点
|
||||
3. 不要输出其他解释性文字,只输出符合格式的规律
|
||||
|
||||
现在请你概括:
|
||||
"""
|
||||
Prompt(learn_grammar_prompt, "learn_grammar_prompt")
|
||||
|
||||
@@ -405,6 +432,44 @@ class ExpressionLearner:
|
||||
for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]:
|
||||
await session.delete(expr)
|
||||
|
||||
# 🔥 训练 StyleLearner
|
||||
# 只对 style 类型的表达方式进行训练(grammar 不需要训练到模型)
|
||||
if type == "style":
|
||||
try:
|
||||
# 获取 StyleLearner 实例
|
||||
learner = style_learner_manager.get_learner(chat_id)
|
||||
|
||||
logger.info(f"开始训练 StyleLearner: chat_id={chat_id}, 样本数={len(expr_list)}")
|
||||
|
||||
# 为每个学习到的表达方式训练模型
|
||||
# 使用 situation 作为输入,style 作为目标
|
||||
# 这是最符合语义的方式:场景 -> 表达方式
|
||||
success_count = 0
|
||||
for expr in expr_list:
|
||||
situation = expr["situation"]
|
||||
style = expr["style"]
|
||||
|
||||
# 训练映射关系: situation -> style
|
||||
if learner.learn_mapping(situation, style):
|
||||
success_count += 1
|
||||
else:
|
||||
logger.warning(f"训练失败: {situation} -> {style}")
|
||||
|
||||
logger.info(
|
||||
f"StyleLearner 训练完成: {success_count}/{len(expr_list)} 成功, "
|
||||
f"当前风格总数={len(learner.get_all_styles())}, "
|
||||
f"总样本数={learner.learning_stats['total_samples']}"
|
||||
)
|
||||
|
||||
# 保存模型
|
||||
if learner.save(style_learner_manager.model_save_path):
|
||||
logger.info(f"StyleLearner 模型保存成功: {chat_id}")
|
||||
else:
|
||||
logger.error(f"StyleLearner 模型保存失败: {chat_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"训练 StyleLearner 失败: {e}", exc_info=True)
|
||||
|
||||
return learnt_expressions
|
||||
return None
|
||||
|
||||
@@ -455,9 +520,17 @@ class ExpressionLearner:
|
||||
logger.error(f"学习{type_str}失败: {e}")
|
||||
return None
|
||||
|
||||
if not response or not response.strip():
|
||||
logger.warning(f"LLM返回空响应,无法学习{type_str}")
|
||||
return None
|
||||
|
||||
logger.debug(f"学习{type_str}的response: {response}")
|
||||
|
||||
expressions: list[tuple[str, str, str]] = self.parse_expression_response(response, chat_id)
|
||||
|
||||
if not expressions:
|
||||
logger.warning(f"从LLM响应中未能解析出任何{type_str}。请检查LLM输出格式是否正确。")
|
||||
logger.info(f"LLM完整响应:\n{response}")
|
||||
|
||||
return expressions, chat_id
|
||||
|
||||
@@ -465,31 +538,100 @@ class ExpressionLearner:
|
||||
def parse_expression_response(response: str, chat_id: str) -> list[tuple[str, str, str]]:
|
||||
"""
|
||||
解析LLM返回的表达风格总结,每一行提取"当"和"使用"之间的内容,存储为(situation, style)元组
|
||||
支持多种引号格式:"" 和 ""
|
||||
"""
|
||||
expressions: list[tuple[str, str, str]] = []
|
||||
for line in response.splitlines():
|
||||
failed_lines = []
|
||||
|
||||
for line_num, line in enumerate(response.splitlines(), 1):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
# 替换中文引号为英文引号,便于统一处理
|
||||
line_normalized = line.replace('"', '"').replace('"', '"').replace("'", '"').replace("'", '"')
|
||||
|
||||
# 查找"当"和下一个引号
|
||||
idx_when = line.find('当"')
|
||||
idx_when = line_normalized.find('当"')
|
||||
if idx_when == -1:
|
||||
continue
|
||||
idx_quote1 = idx_when + 1
|
||||
idx_quote2 = line.find('"', idx_quote1 + 1)
|
||||
if idx_quote2 == -1:
|
||||
continue
|
||||
situation = line[idx_quote1 + 1 : idx_quote2]
|
||||
# 查找"使用"
|
||||
idx_use = line.find('使用"', idx_quote2)
|
||||
# 尝试不带引号的格式: 当xxx时
|
||||
idx_when = line_normalized.find('当')
|
||||
if idx_when == -1:
|
||||
failed_lines.append((line_num, line, "找不到'当'关键字"))
|
||||
continue
|
||||
|
||||
# 提取"当"和"时"之间的内容
|
||||
idx_shi = line_normalized.find('时', idx_when)
|
||||
if idx_shi == -1:
|
||||
failed_lines.append((line_num, line, "找不到'时'关键字"))
|
||||
continue
|
||||
situation = line_normalized[idx_when + 1:idx_shi].strip('"\'""')
|
||||
search_start = idx_shi
|
||||
else:
|
||||
idx_quote1 = idx_when + 1
|
||||
idx_quote2 = line_normalized.find('"', idx_quote1 + 1)
|
||||
if idx_quote2 == -1:
|
||||
failed_lines.append((line_num, line, "situation部分引号不匹配"))
|
||||
continue
|
||||
situation = line_normalized[idx_quote1 + 1 : idx_quote2]
|
||||
search_start = idx_quote2
|
||||
|
||||
# 查找"使用"或"可以"
|
||||
idx_use = line_normalized.find('使用"', search_start)
|
||||
if idx_use == -1:
|
||||
idx_use = line_normalized.find('可以"', search_start)
|
||||
if idx_use == -1:
|
||||
# 尝试不带引号的格式
|
||||
idx_use = line_normalized.find('使用', search_start)
|
||||
if idx_use == -1:
|
||||
idx_use = line_normalized.find('可以', search_start)
|
||||
if idx_use == -1:
|
||||
failed_lines.append((line_num, line, "找不到'使用'或'可以'关键字"))
|
||||
continue
|
||||
|
||||
# 提取剩余部分作为style
|
||||
style = line_normalized[idx_use + 2:].strip('"\'"",。')
|
||||
if not style:
|
||||
failed_lines.append((line_num, line, "style部分为空"))
|
||||
continue
|
||||
else:
|
||||
idx_quote3 = idx_use + 2
|
||||
idx_quote4 = line_normalized.find('"', idx_quote3 + 1)
|
||||
if idx_quote4 == -1:
|
||||
# 如果没有结束引号,取到行尾
|
||||
style = line_normalized[idx_quote3 + 1:].strip('"\'""')
|
||||
else:
|
||||
style = line_normalized[idx_quote3 + 1 : idx_quote4]
|
||||
else:
|
||||
idx_quote3 = idx_use + 2
|
||||
idx_quote4 = line_normalized.find('"', idx_quote3 + 1)
|
||||
if idx_quote4 == -1:
|
||||
# 如果没有结束引号,取到行尾
|
||||
style = line_normalized[idx_quote3 + 1:].strip('"\'""')
|
||||
else:
|
||||
style = line_normalized[idx_quote3 + 1 : idx_quote4]
|
||||
|
||||
# 清理并验证
|
||||
situation = situation.strip()
|
||||
style = style.strip()
|
||||
|
||||
if not situation or not style:
|
||||
failed_lines.append((line_num, line, f"situation或style为空: situation='{situation}', style='{style}'"))
|
||||
continue
|
||||
idx_quote3 = idx_use + 2
|
||||
idx_quote4 = line.find('"', idx_quote3 + 1)
|
||||
if idx_quote4 == -1:
|
||||
continue
|
||||
style = line[idx_quote3 + 1 : idx_quote4]
|
||||
|
||||
expressions.append((chat_id, situation, style))
|
||||
|
||||
# 记录解析失败的行
|
||||
if failed_lines:
|
||||
logger.warning(f"解析表达方式时有 {len(failed_lines)} 行失败:")
|
||||
for line_num, line, reason in failed_lines[:5]: # 只显示前5个
|
||||
logger.warning(f" 行{line_num}: {reason}")
|
||||
logger.debug(f" 原文: {line}")
|
||||
|
||||
if not expressions:
|
||||
logger.warning(f"LLM返回了内容但无法解析任何表达方式。响应预览:\n{response[:500]}")
|
||||
else:
|
||||
logger.debug(f"成功解析 {len(expressions)} 个表达方式")
|
||||
return expressions
|
||||
|
||||
|
||||
@@ -522,12 +664,12 @@ class ExpressionLearnerManager:
|
||||
os.path.join(base_dir, "learnt_grammar"),
|
||||
]
|
||||
|
||||
try:
|
||||
for directory in directories_to_create:
|
||||
for directory in directories_to_create:
|
||||
try:
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
logger.debug(f"确保目录存在: {directory}")
|
||||
except Exception as e:
|
||||
logger.error(f"创建目录失败 {directory}: {e}")
|
||||
logger.debug(f"确保目录存在: {directory}")
|
||||
except Exception as e:
|
||||
logger.error(f"创建目录失败 {directory}: {e}")
|
||||
|
||||
@staticmethod
|
||||
async def _auto_migrate_json_to_db():
|
||||
|
||||
@@ -15,6 +15,10 @@ from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
# 导入StyleLearner管理器和情境提取器
|
||||
from .situation_extractor import situation_extractor
|
||||
from .style_learner import style_learner_manager
|
||||
|
||||
logger = get_logger("expression_selector")
|
||||
|
||||
|
||||
@@ -127,17 +131,18 @@ class ExpressionSelector:
|
||||
current_group = rule.group
|
||||
break
|
||||
|
||||
if not current_group:
|
||||
return [chat_id]
|
||||
# 🔥 始终包含当前 chat_id(确保至少能查到自己的数据)
|
||||
related_chat_ids = [chat_id]
|
||||
|
||||
# 找出同一组的所有chat_id
|
||||
related_chat_ids = []
|
||||
for rule in rules:
|
||||
if rule.group == current_group and rule.chat_stream_id:
|
||||
if chat_id_candidate := self._parse_stream_config_to_chat_id(rule.chat_stream_id):
|
||||
related_chat_ids.append(chat_id_candidate)
|
||||
if current_group:
|
||||
# 找出同一组的所有chat_id
|
||||
for rule in rules:
|
||||
if rule.group == current_group and rule.chat_stream_id:
|
||||
if chat_id_candidate := self._parse_stream_config_to_chat_id(rule.chat_stream_id):
|
||||
if chat_id_candidate not in related_chat_ids:
|
||||
related_chat_ids.append(chat_id_candidate)
|
||||
|
||||
return related_chat_ids if related_chat_ids else [chat_id]
|
||||
return related_chat_ids
|
||||
|
||||
async def get_random_expressions(
|
||||
self, chat_id: str, total_num: int, style_percentage: float, grammar_percentage: float
|
||||
@@ -236,6 +241,287 @@ class ExpressionSelector:
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
async def select_suitable_expressions(
|
||||
self,
|
||||
chat_id: str,
|
||||
chat_history: list | str,
|
||||
target_message: str | None = None,
|
||||
max_num: int = 10,
|
||||
min_num: int = 5,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
统一的表达方式选择入口,根据配置自动选择模式
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
chat_history: 聊天历史(列表或字符串)
|
||||
target_message: 目标消息
|
||||
max_num: 最多返回数量
|
||||
min_num: 最少返回数量
|
||||
|
||||
Returns:
|
||||
选中的表达方式列表
|
||||
"""
|
||||
# 转换chat_history为字符串
|
||||
if isinstance(chat_history, list):
|
||||
chat_info = "\n".join([f"{msg.get('sender', 'Unknown')}: {msg.get('content', '')}" for msg in chat_history])
|
||||
else:
|
||||
chat_info = chat_history
|
||||
|
||||
# 根据配置选择模式
|
||||
mode = global_config.expression.mode
|
||||
logger.debug(f"[ExpressionSelector] 使用模式: {mode}")
|
||||
|
||||
if mode == "exp_model":
|
||||
return await self._select_expressions_model_only(
|
||||
chat_id=chat_id,
|
||||
chat_info=chat_info,
|
||||
target_message=target_message,
|
||||
max_num=max_num,
|
||||
min_num=min_num
|
||||
)
|
||||
else: # classic mode
|
||||
return await self._select_expressions_classic(
|
||||
chat_id=chat_id,
|
||||
chat_info=chat_info,
|
||||
target_message=target_message,
|
||||
max_num=max_num,
|
||||
min_num=min_num
|
||||
)
|
||||
|
||||
async def _select_expressions_classic(
|
||||
self,
|
||||
chat_id: str,
|
||||
chat_info: str,
|
||||
target_message: str | None = None,
|
||||
max_num: int = 10,
|
||||
min_num: int = 5,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""经典模式:随机抽样 + LLM评估"""
|
||||
logger.debug(f"[Classic模式] 使用LLM评估表达方式")
|
||||
return await self.select_suitable_expressions_llm(
|
||||
chat_id=chat_id,
|
||||
chat_info=chat_info,
|
||||
max_num=max_num,
|
||||
min_num=min_num,
|
||||
target_message=target_message
|
||||
)
|
||||
|
||||
async def _select_expressions_model_only(
|
||||
self,
|
||||
chat_id: str,
|
||||
chat_info: str,
|
||||
target_message: str | None = None,
|
||||
max_num: int = 10,
|
||||
min_num: int = 5,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""模型预测模式:先提取情境,再使用StyleLearner预测表达风格"""
|
||||
logger.debug(f"[Exp_model模式] 使用情境提取 + StyleLearner预测表达方式")
|
||||
|
||||
# 检查是否允许在此聊天流中使用表达
|
||||
if not self.can_use_expression_for_chat(chat_id):
|
||||
logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
|
||||
return []
|
||||
|
||||
# 步骤1: 提取聊天情境
|
||||
situations = await situation_extractor.extract_situations(
|
||||
chat_history=chat_info,
|
||||
target_message=target_message,
|
||||
max_situations=3
|
||||
)
|
||||
|
||||
if not situations:
|
||||
logger.warning(f"无法提取聊天情境,回退到经典模式")
|
||||
return await self._select_expressions_classic(
|
||||
chat_id=chat_id,
|
||||
chat_info=chat_info,
|
||||
target_message=target_message,
|
||||
max_num=max_num,
|
||||
min_num=min_num
|
||||
)
|
||||
|
||||
logger.info(f"[Exp_model模式] 步骤1完成 - 提取到 {len(situations)} 个情境: {situations}")
|
||||
|
||||
# 步骤2: 使用 StyleLearner 为每个情境预测合适的表达方式
|
||||
learner = style_learner_manager.get_learner(chat_id)
|
||||
|
||||
all_predicted_styles = {}
|
||||
for i, situation in enumerate(situations, 1):
|
||||
logger.debug(f"[Exp_model模式] 步骤2.{i} - 为情境预测风格: {situation}")
|
||||
best_style, scores = learner.predict_style(situation, top_k=max_num)
|
||||
|
||||
if best_style and scores:
|
||||
logger.debug(f" 预测结果: best={best_style}, scores数量={len(scores)}")
|
||||
# 合并分数(取最高分)
|
||||
for style, score in scores.items():
|
||||
if style not in all_predicted_styles or score > all_predicted_styles[style]:
|
||||
all_predicted_styles[style] = score
|
||||
else:
|
||||
logger.debug(f" 该情境未返回预测结果")
|
||||
|
||||
if not all_predicted_styles:
|
||||
logger.warning(f"[Exp_model模式] StyleLearner未返回预测结果(可能模型未训练),回退到经典模式")
|
||||
return await self._select_expressions_classic(
|
||||
chat_id=chat_id,
|
||||
chat_info=chat_info,
|
||||
target_message=target_message,
|
||||
max_num=max_num,
|
||||
min_num=min_num
|
||||
)
|
||||
|
||||
# 将分数字典转换为列表格式 [(style, score), ...]
|
||||
predicted_styles = sorted(all_predicted_styles.items(), key=lambda x: x[1], reverse=True)
|
||||
|
||||
logger.info(f"[Exp_model模式] 步骤2完成 - 预测到 {len(predicted_styles)} 个风格, Top3: {predicted_styles[:3]}")
|
||||
|
||||
# 步骤3: 根据预测的风格从数据库获取表达方式
|
||||
logger.debug(f"[Exp_model模式] 步骤3 - 从数据库查询表达方式")
|
||||
expressions = await self.get_model_predicted_expressions(
|
||||
chat_id=chat_id,
|
||||
predicted_styles=predicted_styles,
|
||||
max_num=max_num
|
||||
)
|
||||
|
||||
if not expressions:
|
||||
logger.warning(f"[Exp_model模式] 未找到匹配预测风格的表达方式,回退到经典模式")
|
||||
return await self._select_expressions_classic(
|
||||
chat_id=chat_id,
|
||||
chat_info=chat_info,
|
||||
target_message=target_message,
|
||||
max_num=max_num,
|
||||
min_num=min_num
|
||||
)
|
||||
|
||||
logger.info(f"[Exp_model模式] 成功! 返回 {len(expressions)} 个表达方式")
|
||||
return expressions
|
||||
|
||||
async def get_model_predicted_expressions(
|
||||
self,
|
||||
chat_id: str,
|
||||
predicted_styles: list[tuple[str, float]],
|
||||
max_num: int = 10
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
根据StyleLearner预测的风格获取表达方式
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
predicted_styles: 预测的风格列表,格式: [(style, score), ...]
|
||||
max_num: 最多返回数量
|
||||
|
||||
Returns:
|
||||
表达方式列表
|
||||
"""
|
||||
if not predicted_styles:
|
||||
return []
|
||||
|
||||
# 提取风格名称(前3个最佳匹配)
|
||||
style_names = [style for style, _ in predicted_styles[:min(3, len(predicted_styles))]]
|
||||
logger.debug(f"预测最佳风格: {style_names[0] if style_names else 'None'}, Top3分数: {predicted_styles[:3]}")
|
||||
|
||||
# 🔥 使用 get_related_chat_ids 获取所有相关的 chat_id(支持共享表达方式)
|
||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||
logger.info(f"查询相关的chat_ids ({len(related_chat_ids)}个): {related_chat_ids}")
|
||||
|
||||
async with get_db_session() as session:
|
||||
# 🔍 先检查数据库中实际有哪些 chat_id 的数据
|
||||
db_chat_ids_result = await session.execute(
|
||||
select(Expression.chat_id)
|
||||
.where(Expression.type == "style")
|
||||
.distinct()
|
||||
)
|
||||
db_chat_ids = [cid for cid in db_chat_ids_result.scalars()]
|
||||
logger.info(f"数据库中有表达方式的chat_ids ({len(db_chat_ids)}个): {db_chat_ids}")
|
||||
|
||||
# 获取所有相关 chat_id 的表达方式(用于模糊匹配)
|
||||
all_expressions_result = await session.execute(
|
||||
select(Expression)
|
||||
.where(Expression.chat_id.in_(related_chat_ids))
|
||||
.where(Expression.type == "style")
|
||||
)
|
||||
all_expressions = list(all_expressions_result.scalars())
|
||||
|
||||
logger.info(f"配置的相关chat_id的表达方式数量: {len(all_expressions)}")
|
||||
|
||||
# 🔥 智能回退:如果相关 chat_id 没有数据,尝试查询所有 chat_id
|
||||
if not all_expressions:
|
||||
logger.info(f"相关chat_id没有数据,尝试从所有chat_id查询")
|
||||
all_expressions_result = await session.execute(
|
||||
select(Expression)
|
||||
.where(Expression.type == "style")
|
||||
)
|
||||
all_expressions = list(all_expressions_result.scalars())
|
||||
logger.debug(f"数据库中所有表达方式数量: {len(all_expressions)}")
|
||||
|
||||
if not all_expressions:
|
||||
logger.warning(f"数据库中完全没有任何表达方式,需要先学习")
|
||||
return []
|
||||
|
||||
# 🔥 使用模糊匹配而不是精确匹配
|
||||
# 计算每个预测style与数据库style的相似度
|
||||
from difflib import SequenceMatcher
|
||||
|
||||
matched_expressions = []
|
||||
for expr in all_expressions:
|
||||
db_style = expr.style or ""
|
||||
max_similarity = 0.0
|
||||
best_predicted = ""
|
||||
|
||||
# 与每个预测的style计算相似度
|
||||
for predicted_style, pred_score in predicted_styles[:20]: # 考虑前20个预测
|
||||
# 计算字符串相似度
|
||||
similarity = SequenceMatcher(None, predicted_style, db_style).ratio()
|
||||
|
||||
# 也检查包含关系(如果一个是另一个的子串,给更高分)
|
||||
if len(predicted_style) >= 2 and len(db_style) >= 2:
|
||||
if predicted_style in db_style or db_style in predicted_style:
|
||||
similarity = max(similarity, 0.7)
|
||||
|
||||
if similarity > max_similarity:
|
||||
max_similarity = similarity
|
||||
best_predicted = predicted_style
|
||||
|
||||
# 🔥 降低阈值到30%,因为StyleLearner预测质量较差
|
||||
if max_similarity >= 0.3: # 30%相似度阈值
|
||||
matched_expressions.append((expr, max_similarity, expr.count, best_predicted))
|
||||
|
||||
if not matched_expressions:
|
||||
# 收集数据库中的style样例用于调试
|
||||
all_styles = [e.style for e in all_expressions[:10]]
|
||||
logger.warning(
|
||||
f"数据库中没有找到匹配的表达方式(相似度阈值30%):\n"
|
||||
f" 预测的style (前3个): {style_names}\n"
|
||||
f" 数据库中存在的style样例: {all_styles}\n"
|
||||
f" 提示: StyleLearner预测质量差,建议重新训练或使用classic模式"
|
||||
)
|
||||
return []
|
||||
|
||||
# 按照相似度*count排序,选择最佳匹配
|
||||
matched_expressions.sort(key=lambda x: x[1] * (x[2] ** 0.5), reverse=True)
|
||||
expressions_objs = [e[0] for e in matched_expressions[:max_num]]
|
||||
|
||||
# 显示最佳匹配的详细信息
|
||||
top_matches = [f"{e[3]}->{e[0].style}({e[1]:.2f})" for e in matched_expressions[:3]]
|
||||
logger.info(
|
||||
f"模糊匹配成功: 找到 {len(expressions_objs)} 个表达方式\n"
|
||||
f" 相似度范围: {matched_expressions[0][1]:.2f} ~ {matched_expressions[min(len(matched_expressions)-1, max_num-1)][1]:.2f}\n"
|
||||
f" Top3匹配: {top_matches}"
|
||||
)
|
||||
|
||||
# 转换为字典格式
|
||||
expressions = []
|
||||
for expr in expressions_objs:
|
||||
expressions.append({
|
||||
"situation": expr.situation or "",
|
||||
"style": expr.style or "",
|
||||
"type": expr.type or "style",
|
||||
"count": float(expr.count) if expr.count else 0.0,
|
||||
"last_active_time": expr.last_active_time or 0.0
|
||||
})
|
||||
|
||||
logger.debug(f"从数据库获取了 {len(expressions)} 个表达方式")
|
||||
return expressions
|
||||
|
||||
async def select_suitable_expressions_llm(
|
||||
self,
|
||||
chat_id: str,
|
||||
|
||||
9
src/chat/express/expressor_model/__init__.py
Normal file
9
src/chat/express/expressor_model/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""
|
||||
表达模型包
|
||||
包含基于Online Naive Bayes的机器学习模型
|
||||
"""
|
||||
from .model import ExpressorModel
|
||||
from .online_nb import OnlineNaiveBayes
|
||||
from .tokenizer import Tokenizer
|
||||
|
||||
__all__ = ["ExpressorModel", "OnlineNaiveBayes", "Tokenizer"]
|
||||
216
src/chat/express/expressor_model/model.py
Normal file
216
src/chat/express/expressor_model/model.py
Normal file
@@ -0,0 +1,216 @@
|
||||
"""
|
||||
基于Online Naive Bayes的表达模型
|
||||
支持候选表达的动态添加和在线学习
|
||||
"""
|
||||
import os
|
||||
import pickle
|
||||
from collections import Counter, defaultdict
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from .online_nb import OnlineNaiveBayes
|
||||
from .tokenizer import Tokenizer
|
||||
|
||||
logger = get_logger("expressor.model")
|
||||
|
||||
|
||||
class ExpressorModel:
|
||||
"""直接使用朴素贝叶斯精排(可在线学习)"""
|
||||
|
||||
def __init__(
|
||||
self, alpha: float = 0.5, beta: float = 0.5, gamma: float = 1.0, vocab_size: int = 200000, use_jieba: bool = True
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
alpha: 词频平滑参数
|
||||
beta: 类别先验平滑参数
|
||||
gamma: 衰减因子
|
||||
vocab_size: 词汇表大小
|
||||
use_jieba: 是否使用jieba分词
|
||||
"""
|
||||
# 初始化分词器
|
||||
self.tokenizer = Tokenizer(stopwords=set(), use_jieba=use_jieba)
|
||||
|
||||
# 初始化在线朴素贝叶斯模型
|
||||
self.nb = OnlineNaiveBayes(alpha=alpha, beta=beta, gamma=gamma, vocab_size=vocab_size)
|
||||
|
||||
# 候选表达管理
|
||||
self._candidates: Dict[str, str] = {} # cid -> text (style)
|
||||
self._situations: Dict[str, str] = {} # cid -> situation (不参与计算)
|
||||
|
||||
logger.info(
|
||||
f"ExpressorModel初始化完成 (alpha={alpha}, beta={beta}, gamma={gamma}, vocab_size={vocab_size}, use_jieba={use_jieba})"
|
||||
)
|
||||
|
||||
def add_candidate(self, cid: str, text: str, situation: Optional[str] = None):
|
||||
"""
|
||||
添加候选文本和对应的situation
|
||||
|
||||
Args:
|
||||
cid: 候选ID
|
||||
text: 表达文本 (style)
|
||||
situation: 情境文本
|
||||
"""
|
||||
self._candidates[cid] = text
|
||||
if situation is not None:
|
||||
self._situations[cid] = situation
|
||||
|
||||
# 确保在nb模型中初始化该候选的计数
|
||||
if cid not in self.nb.cls_counts:
|
||||
self.nb.cls_counts[cid] = 0.0
|
||||
if cid not in self.nb.token_counts:
|
||||
self.nb.token_counts[cid] = defaultdict(float)
|
||||
|
||||
def predict(self, text: str, k: int = None) -> Tuple[Optional[str], Dict[str, float]]:
|
||||
"""
|
||||
直接对所有候选进行朴素贝叶斯评分
|
||||
|
||||
Args:
|
||||
text: 查询文本
|
||||
k: 返回前k个候选,如果为None则返回所有
|
||||
|
||||
Returns:
|
||||
(最佳候选ID, 所有候选的分数字典)
|
||||
"""
|
||||
# 1. 分词
|
||||
toks = self.tokenizer.tokenize(text)
|
||||
if not toks or not self._candidates:
|
||||
return None, {}
|
||||
|
||||
# 2. 计算词频
|
||||
tf = Counter(toks)
|
||||
all_cids = list(self._candidates.keys())
|
||||
|
||||
# 3. 批量评分
|
||||
scores = self.nb.score_batch(tf, all_cids)
|
||||
|
||||
if not scores:
|
||||
return None, {}
|
||||
|
||||
# 4. 根据k参数限制返回的候选数量
|
||||
if k is not None and k > 0:
|
||||
sorted_scores = sorted(scores.items(), key=lambda x: x[1], reverse=True)
|
||||
limited_scores = dict(sorted_scores[:k])
|
||||
best = sorted_scores[0][0] if sorted_scores else None
|
||||
return best, limited_scores
|
||||
else:
|
||||
best = max(scores.items(), key=lambda x: x[1])[0]
|
||||
return best, scores
|
||||
|
||||
def update_positive(self, text: str, cid: str):
|
||||
"""
|
||||
更新正反馈学习
|
||||
|
||||
Args:
|
||||
text: 输入文本
|
||||
cid: 目标类别ID
|
||||
"""
|
||||
toks = self.tokenizer.tokenize(text)
|
||||
if not toks:
|
||||
return
|
||||
|
||||
tf = Counter(toks)
|
||||
self.nb.update_positive(tf, cid)
|
||||
|
||||
def decay(self, factor: Optional[float] = None):
|
||||
"""
|
||||
应用知识衰减
|
||||
|
||||
Args:
|
||||
factor: 衰减因子,如果为None则使用模型配置的gamma
|
||||
"""
|
||||
self.nb.decay(factor)
|
||||
|
||||
def get_candidate_info(self, cid: str) -> Tuple[Optional[str], Optional[str]]:
|
||||
"""
|
||||
获取候选信息
|
||||
|
||||
Args:
|
||||
cid: 候选ID
|
||||
|
||||
Returns:
|
||||
(style文本, situation文本)
|
||||
"""
|
||||
style = self._candidates.get(cid)
|
||||
situation = self._situations.get(cid)
|
||||
return style, situation
|
||||
|
||||
def get_all_candidates(self) -> Dict[str, Tuple[str, str]]:
|
||||
"""
|
||||
获取所有候选
|
||||
|
||||
Returns:
|
||||
{cid: (style, situation)}
|
||||
"""
|
||||
result = {}
|
||||
for cid in self._candidates.keys():
|
||||
style, situation = self.get_candidate_info(cid)
|
||||
result[cid] = (style, situation)
|
||||
return result
|
||||
|
||||
def save(self, path: str):
|
||||
"""
|
||||
保存模型到文件
|
||||
|
||||
Args:
|
||||
path: 保存路径
|
||||
"""
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
|
||||
data = {
|
||||
"candidates": self._candidates,
|
||||
"situations": self._situations,
|
||||
"nb_cls_counts": dict(self.nb.cls_counts),
|
||||
"nb_token_counts": {k: dict(v) for k, v in self.nb.token_counts.items()},
|
||||
"nb_alpha": self.nb.alpha,
|
||||
"nb_beta": self.nb.beta,
|
||||
"nb_gamma": self.nb.gamma,
|
||||
"nb_V": self.nb.V,
|
||||
}
|
||||
|
||||
with open(path, "wb") as f:
|
||||
pickle.dump(data, f)
|
||||
|
||||
logger.info(f"模型已保存到 {path}")
|
||||
|
||||
def load(self, path: str):
|
||||
"""
|
||||
从文件加载模型
|
||||
|
||||
Args:
|
||||
path: 加载路径
|
||||
"""
|
||||
if not os.path.exists(path):
|
||||
logger.warning(f"模型文件不存在: {path}")
|
||||
return
|
||||
|
||||
with open(path, "rb") as f:
|
||||
data = pickle.load(f)
|
||||
|
||||
self._candidates = data["candidates"]
|
||||
self._situations = data["situations"]
|
||||
|
||||
# 恢复nb模型的参数
|
||||
self.nb.alpha = data["nb_alpha"]
|
||||
self.nb.beta = data["nb_beta"]
|
||||
self.nb.gamma = data["nb_gamma"]
|
||||
self.nb.V = data["nb_V"]
|
||||
|
||||
# 恢复统计数据
|
||||
self.nb.cls_counts = defaultdict(float, data["nb_cls_counts"])
|
||||
self.nb.token_counts = defaultdict(lambda: defaultdict(float))
|
||||
for cid, tc in data["nb_token_counts"].items():
|
||||
self.nb.token_counts[cid] = defaultdict(float, tc)
|
||||
|
||||
logger.info(f"模型已从 {path} 加载")
|
||||
|
||||
    def get_stats(self) -> Dict:
        """Return summary statistics: candidate count plus NB model counters."""
        nb_stats = self.nb.get_stats()
        return {
            "n_candidates": len(self._candidates),
            "n_classes": nb_stats["n_classes"],
            "n_tokens": nb_stats["n_tokens"],
            "total_counts": nb_stats["total_counts"],
        }
|
||||
142
src/chat/express/expressor_model/online_nb.py
Normal file
142
src/chat/express/expressor_model/online_nb.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""
|
||||
在线朴素贝叶斯分类器
|
||||
支持增量学习和知识衰减
|
||||
"""
|
||||
import math
|
||||
from collections import Counter, defaultdict
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("expressor.online_nb")
|
||||
|
||||
|
||||
class OnlineNaiveBayes:
    """Online (incremental) multinomial naive Bayes with knowledge decay.

    Fix over the original: the scoring/read paths previously indexed the
    ``defaultdict`` statistics directly, which silently inserted empty
    entries for every unseen cid queried — growing memory and inflating
    ``n_classes`` (and thus the prior denominator) over time. All read
    paths now use non-mutating ``.get`` lookups; returned scores are
    unchanged.
    """

    def __init__(self, alpha: float = 0.5, beta: float = 0.5, gamma: float = 1.0, vocab_size: int = 200000):
        """
        Args:
            alpha: Additive smoothing for term frequencies.
            beta: Additive smoothing for class priors.
            gamma: Decay factor in (0, 1]; 1 disables forgetting.
            vocab_size: Assumed vocabulary size V used in likelihood smoothing.
        """
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.V = vocab_size

        # Per-class statistics
        self.cls_counts: Dict[str, float] = defaultdict(float)  # cid -> total token count
        self.token_counts: Dict[str, Dict[str, float]] = defaultdict(
            lambda: defaultdict(float)
        )  # cid -> term -> count

        # Cache of log(sum counts + V*alpha) per class
        self._logZ: Dict[str, float] = {}

    def score_batch(self, tf: Counter, cids: List[str]) -> Dict[str, float]:
        """
        Compute log-posterior scores for a batch of candidate classes.

        Args:
            tf: Term-frequency Counter of the query text.
            cids: Candidate class IDs to score.

        Returns:
            cid -> unnormalized log P(c) + sum_w tf[w] * log P(w|c).
        """
        total_cls = sum(self.cls_counts.values())
        n_cls = max(1, len(self.cls_counts))
        denom_prior = math.log(total_cls + self.beta * n_cls)

        out: Dict[str, float] = {}
        for cid in cids:
            # Smoothed log prior log P(c). .get avoids defaultdict insertion
            # for cids that were never trained.
            prior = math.log(self.cls_counts.get(cid, 0.0) + self.beta) - denom_prior
            s = prior

            # Smoothed log likelihood log P(w|c)
            logZ = self._logZ_c(cid)
            tc = self.token_counts.get(cid, {})

            for term, qtf in tf.items():
                num = tc.get(term, 0.0) + self.alpha
                s += qtf * (math.log(num) - logZ)

            out[cid] = s
        return out

    def update_positive(self, tf: Counter, cid: str):
        """
        Positive-feedback update: add *tf* to class *cid*'s statistics.

        Args:
            tf: Term-frequency Counter of the observed text.
            cid: Class ID to credit.
        """
        inc = 0.0
        tc = self.token_counts[cid]

        # Accumulate per-term counts
        for term, c in tf.items():
            tc[term] += float(c)
            inc += float(c)

        # Update the class total and drop its cached normalizer
        self.cls_counts[cid] += inc
        self._invalidate(cid)

    def decay(self, factor: Optional[float] = None):
        """
        Apply knowledge decay (forgetting) to all statistics.

        Args:
            factor: Decay multiplier; None uses self.gamma. Values >= 1.0
                are a no-op.
        """
        g = self.gamma if factor is None else factor
        if g >= 1.0:
            return

        # Scale every counter and invalidate the affected caches
        for cid in list(self.cls_counts.keys()):
            self.cls_counts[cid] *= g
            for term in list(self.token_counts[cid].keys()):
                self.token_counts[cid][term] *= g
            self._invalidate(cid)

        logger.debug(f"应用知识衰减,衰减因子: {g}")

    def _logZ_c(self, cid: str) -> float:
        """
        Return (and cache) the log normalizer log(Z_c) = log(total_c + V*alpha).

        Args:
            cid: Class ID.

        Returns:
            log(Z_c), floored at log(1e-12) for numerical safety.
        """
        if cid not in self._logZ:
            # .get keeps unseen cids out of cls_counts
            Z = self.cls_counts.get(cid, 0.0) + self.V * self.alpha
            self._logZ[cid] = math.log(max(Z, 1e-12))
        return self._logZ[cid]

    def _invalidate(self, cid: str):
        """
        Invalidate the cached normalizer for *cid* after its counts changed.

        Args:
            cid: Class ID.
        """
        if cid in self._logZ:
            del self._logZ[cid]

    def get_stats(self) -> Dict:
        """Return counter summary: class count, distinct tokens, total mass."""
        return {
            "n_classes": len(self.cls_counts),
            "n_tokens": sum(len(tc) for tc in self.token_counts.values()),
            "total_counts": sum(self.cls_counts.values()),
        }
|
||||
62
src/chat/express/expressor_model/tokenizer.py
Normal file
62
src/chat/express/expressor_model/tokenizer.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""
|
||||
文本分词器,支持中文Jieba分词
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("expressor.tokenizer")
|
||||
|
||||
|
||||
class Tokenizer:
    """Text tokenizer with optional Chinese Jieba segmentation.

    Fixes over the original: the ``stopwords`` parameter was annotated
    ``set`` while defaulting to ``None`` (invalid typing); and ``jieba``
    was re-imported on every ``tokenize`` call — the module is now cached
    once at construction time.
    """

    def __init__(self, stopwords: Optional[set] = None, use_jieba: bool = True):
        """
        Args:
            stopwords: Tokens to drop from the output; None means no stopwords.
            use_jieba: Whether to use jieba segmentation; automatically falls
                back to character-level splitting when jieba is unavailable.
        """
        self.stopwords = stopwords or set()
        self.use_jieba = use_jieba
        self._jieba = None  # cached jieba module, set on successful import

        if use_jieba:
            try:
                import jieba

                jieba.initialize()
                self._jieba = jieba
                logger.info("Jieba分词器初始化成功")
            except ImportError:
                logger.warning("Jieba未安装,将使用字符级分词")
                self.use_jieba = False

    def tokenize(self, text: str) -> List[str]:
        """
        Split *text* into tokens, dropping stopwords and blank tokens.

        Args:
            text: Input text.

        Returns:
            List of non-empty, non-stopword tokens (jieba words, or single
            characters when jieba is disabled/unavailable).
        """
        if not text:
            return []

        if self.use_jieba and self._jieba is not None:
            try:
                tokens = list(self._jieba.cut(text))
            except Exception as e:
                # Any segmentation failure degrades gracefully to characters
                logger.warning(f"Jieba分词失败,使用字符级分词: {e}")
                tokens = list(text)
        else:
            # Character-level fallback
            tokens = list(text)

        # Filter stopwords and whitespace-only tokens
        return [token.strip() for token in tokens if token.strip() and token not in self.stopwords]
|
||||
162
src/chat/express/situation_extractor.py
Normal file
162
src/chat/express/situation_extractor.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""
|
||||
情境提取器
|
||||
从聊天历史中提取当前的情境(situation),用于 StyleLearner 预测
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
logger = get_logger("situation_extractor")
|
||||
|
||||
|
||||
def init_prompt():
    """Register the situation-extraction prompt template with the global
    prompt manager under the name "situation_extraction_prompt".

    Template placeholders: {chat_history}, {bot_name}, {target_message_info}.
    """
    situation_extraction_prompt = """
以下是正在进行的聊天内容:
{chat_history}

你的名字是{bot_name}{target_message_info}

请分析当前聊天的情境特征,提取出最能描述当前情境的1-3个关键场景描述。

场景描述应该:
1. 简洁明了(每个不超过20个字)
2. 聚焦情绪、话题、氛围
3. 不涉及具体人名
4. 类似于"表示惊讶"、"讨论游戏"、"表达赞同"这样的格式

请以纯文本格式输出,每行一个场景描述,不要有序号、引号或其他格式:

例如:
表示惊讶和意外
讨论技术问题
表达友好的赞同

现在请提取当前聊天的情境:
"""
    # Registering the Prompt makes it retrievable via global_prompt_manager
    Prompt(situation_extraction_prompt, "situation_extraction_prompt")
|
||||
|
||||
|
||||
class SituationExtractor:
    """Extracts the current conversational situation from chat history.

    Delegates to a small LLM and parses its free-text answer into short
    situation descriptions (e.g. "表示惊讶", "讨论游戏") for StyleLearner
    prediction.
    """

    def __init__(self):
        # Small/cheap utility model; request_type tags the call for
        # accounting and logging.
        self.llm_model = LLMRequest(
            model_set=model_config.model_task_config.utils_small,
            request_type="expression.situation_extractor"
        )

    async def extract_situations(
        self,
        chat_history: list | str,
        target_message: Optional[str] = None,
        max_situations: int = 3
    ) -> list[str]:
        """
        Extract situation descriptions from the chat history.

        Args:
            chat_history: Chat history, either a pre-rendered string or a
                list of message dicts (read via 'sender'/'content' keys).
            target_message: Message the bot intends to reply to (optional).
            max_situations: Maximum number of situations to return.

        Returns:
            List of short situation descriptions; empty list on any failure
            (errors are logged, never raised to the caller).
        """
        # Render a list of message dicts into "sender: content" lines
        if isinstance(chat_history, list):
            chat_info = "\n".join([
                f"{msg.get('sender', 'Unknown')}: {msg.get('content', '')}"
                for msg in chat_history
            ])
        else:
            chat_info = chat_history

        # Optional hint about the message being replied to
        if target_message:
            target_message_info = f",现在你想要回复消息:{target_message}"
        else:
            target_message_info = ""

        # Build the prompt and query the LLM
        try:
            prompt = (await global_prompt_manager.get_prompt_async("situation_extraction_prompt")).format(
                bot_name=global_config.bot.nickname,
                chat_history=chat_info,
                target_message_info=target_message_info
            )

            # Low temperature for stable, extraction-style output
            response, _ = await self.llm_model.generate_response_async(
                prompt=prompt,
                temperature=0.3
            )

            if not response or not response.strip():
                logger.warning("LLM返回空响应,无法提取情境")
                return []

            # Parse the free-text response into clean descriptions
            situations = self._parse_situations(response, max_situations)

            if situations:
                logger.debug(f"提取到 {len(situations)} 个情境: {situations}")
            else:
                logger.warning(f"无法从LLM响应中解析出情境。响应:\n{response}")

            return situations

        except Exception as e:
            # Best-effort: extraction failures must not break the reply flow
            logger.error(f"提取情境失败: {e}")
            return []

    @staticmethod
    def _parse_situations(response: str, max_situations: int) -> list[str]:
        """
        Parse situation descriptions out of the raw LLM response.

        Strips list markers/quotes per line and drops lines that are
        clearly not situation descriptions (too long/short, or containing
        instruction keywords echoed from the prompt).

        Args:
            response: Raw LLM response text.
            max_situations: Maximum number of descriptions to return.

        Returns:
            Cleaned situation descriptions, at most max_situations.
        """
        situations = []

        for line in response.splitlines():
            line = line.strip()
            if not line:
                continue

            # Remove possible numbering, bullets and quote characters
            line = line.lstrip('0123456789.、-*>))】] \t"\'""''')
            line = line.rstrip('"\'""''')
            line = line.strip()

            if not line:
                continue

            # Filter out content that is clearly not a situation description
            if len(line) > 30:  # too long
                continue
            if len(line) < 2:  # too short
                continue
            # Likely echoed instructions from the prompt itself
            if any(keyword in line.lower() for keyword in ['例如', '注意', '请', '分析', '总结']):
                continue

            situations.append(line)

            if len(situations) >= max_situations:
                break

        return situations
|
||||
|
||||
|
||||
# Register the prompt template at import time
init_prompt()

# Module-level singleton shared by all callers
situation_extractor = SituationExtractor()
|
||||
425
src/chat/express/style_learner.py
Normal file
425
src/chat/express/style_learner.py
Normal file
@@ -0,0 +1,425 @@
|
||||
"""
|
||||
风格学习引擎
|
||||
基于ExpressorModel实现的表达风格学习和预测系统
|
||||
支持多聊天室独立建模和在线学习
|
||||
"""
|
||||
import os
|
||||
import time
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from .expressor_model import ExpressorModel
|
||||
|
||||
logger = get_logger("expressor.style_learner")
|
||||
|
||||
|
||||
class StyleLearner:
    """Expression-style learner for a single chat room.

    Wraps an ExpressorModel and maintains a dynamic registry mapping style
    texts to internal style IDs, plus learning statistics and persistence.
    """

    def __init__(self, chat_id: str, model_config: Optional[Dict] = None):
        """
        Args:
            chat_id: Chat room ID this learner is bound to.
            model_config: Keyword arguments for ExpressorModel; defaults
                are used when None.
        """
        self.chat_id = chat_id
        self.model_config = model_config or {
            "alpha": 0.5,
            "beta": 0.5,
            "gamma": 0.99,  # decay factor — enables forgetting
            "vocab_size": 200000,
            "use_jieba": True,
        }

        # Underlying expression model
        self.expressor = ExpressorModel(**self.model_config)

        # Dynamic style registry
        self.max_styles = 2000  # hard cap on styles per chat_id
        self.style_to_id: Dict[str, str] = {}  # style text -> style_id
        self.id_to_style: Dict[str, str] = {}  # style_id -> style text
        self.id_to_situation: Dict[str, str] = {}  # style_id -> situation text
        self.next_style_id = 0

        # Learning statistics
        self.learning_stats = {
            "total_samples": 0,
            "style_counts": {},
            "last_update": time.time(),
        }

        logger.info(f"StyleLearner初始化成功: chat_id={chat_id}")

    def add_style(self, style: str, situation: Optional[str] = None) -> bool:
        """
        Dynamically register a new style.

        Args:
            style: Style text.
            situation: Optional situation text associated with the style.

        Returns:
            True on success (including when the style already exists);
            False when the cap is reached or registration fails.
        """
        try:
            # Already registered — nothing to do
            if style in self.style_to_id:
                return True

            # Enforce the per-chat style cap
            if len(self.style_to_id) >= self.max_styles:
                logger.warning(f"已达到最大风格数量限制 ({self.max_styles})")
                return False

            # Allocate a fresh style_id
            style_id = f"style_{self.next_style_id}"
            self.next_style_id += 1

            # Record the bidirectional mapping
            self.style_to_id[style] = style_id
            self.id_to_style[style_id] = style
            if situation:
                self.id_to_situation[style_id] = situation

            # Register the candidate with the expressor model
            self.expressor.add_candidate(style_id, style, situation)

            # Initialize its sample counter
            self.learning_stats["style_counts"][style_id] = 0

            logger.debug(f"添加风格成功: {style_id} -> {style}")
            return True

        except Exception as e:
            logger.error(f"添加风格失败: {e}")
            return False

    def learn_mapping(self, up_content: str, style: str) -> bool:
        """
        Learn one up_content -> style mapping (positive feedback).

        Args:
            up_content: Preceding content (context).
            style: Target style text; registered automatically if new.

        Returns:
            True on success, False otherwise.
        """
        try:
            # Register the style first if it is unknown
            if style not in self.style_to_id:
                if not self.add_style(style):
                    return False

            # Resolve the internal ID
            style_id = self.style_to_id[style]

            # Positive-feedback update on the model
            self.expressor.update_positive(up_content, style_id)

            # Bookkeeping
            self.learning_stats["total_samples"] += 1
            self.learning_stats["style_counts"][style_id] += 1
            self.learning_stats["last_update"] = time.time()

            logger.debug(f"学习映射成功: {up_content[:20]}... -> {style}")
            return True

        except Exception as e:
            logger.error(f"学习映射失败: {e}")
            return False

    def predict_style(self, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]:
        """
        Predict the best-fitting style for the given context.

        Args:
            up_content: Preceding content (context).
            top_k: Number of top candidates to include in the score dict.

        Returns:
            (best style text or None, {style text: score}); (None, {}) when
            no prediction is possible.
        """
        try:
            # No training data yet — nothing to predict from
            if not self.style_to_id:
                logger.debug(f"StyleLearner还没有任何训练数据: chat_id={self.chat_id}")
                return None, {}

            best_style_id, scores = self.expressor.predict(up_content, k=top_k)

            if best_style_id is None:
                logger.debug(f"ExpressorModel未返回预测结果: chat_id={self.chat_id}, up_content={up_content[:50]}...")
                return None, {}

            # Map the internal ID back to the style text
            best_style = self.id_to_style.get(best_style_id)

            if best_style is None:
                # Registry and model are out of sync — fail safe
                logger.warning(
                    f"style_id无法转换为style文本: style_id={best_style_id}, "
                    f"已知的id_to_style数量={len(self.id_to_style)}"
                )
                return None, {}

            # Translate every scored ID to its style text
            style_scores = {}
            for sid, score in scores.items():
                style_text = self.id_to_style.get(sid)
                if style_text:
                    style_scores[style_text] = score
                else:
                    logger.warning(f"跳过无法转换的style_id: {sid}")

            logger.debug(
                f"预测成功: up_content={up_content[:30]}..., "
                f"best_style={best_style}, top3_scores={list(style_scores.items())[:3]}"
            )

            return best_style, style_scores

        except Exception as e:
            logger.error(f"预测style失败: {e}", exc_info=True)
            return None, {}

    def get_style_info(self, style: str) -> Tuple[Optional[str], Optional[str]]:
        """
        Get the full registry info for a style text.

        Args:
            style: Style text.

        Returns:
            (style_id, situation); (None, None) when the style is unknown.
        """
        style_id = self.style_to_id.get(style)
        if not style_id:
            return None, None

        situation = self.id_to_situation.get(style_id)
        return style_id, situation

    def get_all_styles(self) -> List[str]:
        """
        List all registered style texts.

        Returns:
            Style texts in registration order.
        """
        return list(self.style_to_id.keys())

    def apply_decay(self, factor: Optional[float] = None):
        """
        Apply knowledge decay to the underlying model.

        Args:
            factor: Decay multiplier; None uses the model's configured gamma.
        """
        self.expressor.decay(factor)
        logger.debug(f"应用知识衰减: chat_id={self.chat_id}")

    def save(self, base_path: str) -> bool:
        """
        Persist the learner under base_path/<chat_id>/.

        Writes the expressor model (expressor_model.pkl) and the registry +
        statistics (meta.pkl).

        Args:
            base_path: Base save directory.

        Returns:
            True on success, False otherwise.
        """
        try:
            # Per-chat save directory
            save_dir = os.path.join(base_path, self.chat_id)
            os.makedirs(save_dir, exist_ok=True)

            # Persist the expressor model
            model_path = os.path.join(save_dir, "expressor_model.pkl")
            self.expressor.save(model_path)

            # Persist registry mappings and statistics
            import pickle

            meta_path = os.path.join(save_dir, "meta.pkl")
            meta_data = {
                "style_to_id": self.style_to_id,
                "id_to_style": self.id_to_style,
                "id_to_situation": self.id_to_situation,
                "next_style_id": self.next_style_id,
                "learning_stats": self.learning_stats,
            }

            with open(meta_path, "wb") as f:
                pickle.dump(meta_data, f)

            logger.info(f"StyleLearner保存成功: {save_dir}")
            return True

        except Exception as e:
            logger.error(f"保存StyleLearner失败: {e}")
            return False

    def load(self, base_path: str) -> bool:
        """
        Restore the learner from base_path/<chat_id>/.

        Missing directory/files are tolerated (partial or no restore).
        NOTE(review): uses pickle — only load trusted files.

        Args:
            base_path: Base load directory.

        Returns:
            True on success, False when the directory is missing or an
            error occurred.
        """
        try:
            save_dir = os.path.join(base_path, self.chat_id)

            # Nothing saved yet for this chat
            if not os.path.exists(save_dir):
                logger.debug(f"StyleLearner保存目录不存在: {save_dir}")
                return False

            # Restore the expressor model if present
            model_path = os.path.join(save_dir, "expressor_model.pkl")
            if os.path.exists(model_path):
                self.expressor.load(model_path)

            # Restore registry mappings and statistics if present
            import pickle

            meta_path = os.path.join(save_dir, "meta.pkl")
            if os.path.exists(meta_path):
                with open(meta_path, "rb") as f:
                    meta_data = pickle.load(f)

                self.style_to_id = meta_data["style_to_id"]
                self.id_to_style = meta_data["id_to_style"]
                self.id_to_situation = meta_data["id_to_situation"]
                self.next_style_id = meta_data["next_style_id"]
                self.learning_stats = meta_data["learning_stats"]

            logger.info(f"StyleLearner加载成功: {save_dir}")
            return True

        except Exception as e:
            logger.error(f"加载StyleLearner失败: {e}")
            return False

    def get_stats(self) -> Dict:
        """Return learner statistics plus the underlying model's stats."""
        model_stats = self.expressor.get_stats()
        return {
            "chat_id": self.chat_id,
            "n_styles": len(self.style_to_id),
            "total_samples": self.learning_stats["total_samples"],
            "last_update": self.learning_stats["last_update"],
            "model_stats": model_stats,
        }
|
||||
|
||||
|
||||
class StyleLearnerManager:
    """Manages per-chat-room StyleLearner instances (create, route, persist).

    Minor fixes over the original: ``save_all`` iterated ``.items()`` while
    never using the key, and ``apply_decay_all`` logged an f-string with no
    placeholders.
    """

    def __init__(self, model_save_path: str = "data/expression/style_models"):
        """
        Args:
            model_save_path: Directory where per-chat models are persisted.
        """
        self.learners: Dict[str, StyleLearner] = {}
        self.model_save_path = model_save_path

        # Make sure the save directory exists up front
        os.makedirs(model_save_path, exist_ok=True)

        logger.info(f"StyleLearnerManager初始化成功, 模型保存路径: {model_save_path}")

    def get_learner(self, chat_id: str, model_config: Optional[Dict] = None) -> StyleLearner:
        """
        Get (or lazily create and restore) the learner for a chat room.

        Args:
            chat_id: Chat room ID.
            model_config: Model configuration, used only when creating a
                new learner.

        Returns:
            The StyleLearner instance for chat_id.
        """
        if chat_id not in self.learners:
            # Create a fresh learner for this chat
            learner = StyleLearner(chat_id, model_config)

            # Best-effort restore of previously saved state
            learner.load(self.model_save_path)

            self.learners[chat_id] = learner

        return self.learners[chat_id]

    def learn_mapping(self, chat_id: str, up_content: str, style: str) -> bool:
        """
        Learn one up_content -> style mapping for a chat room.

        Args:
            chat_id: Chat room ID.
            up_content: Preceding content (context).
            style: Target style text.

        Returns:
            True on success, False otherwise.
        """
        learner = self.get_learner(chat_id)
        return learner.learn_mapping(up_content, style)

    def predict_style(self, chat_id: str, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]:
        """
        Predict the best-fitting style for a chat room.

        Args:
            chat_id: Chat room ID.
            up_content: Preceding content (context).
            top_k: Number of top candidates to include.

        Returns:
            (best style or None, {style: score}).
        """
        learner = self.get_learner(chat_id)
        return learner.predict_style(up_content, top_k)

    def save_all(self) -> bool:
        """
        Save every managed learner.

        Returns:
            True only if all learners saved successfully.
        """
        success = True
        # The chat_id key is not needed here — iterate values only.
        for learner in self.learners.values():
            if not learner.save(self.model_save_path):
                success = False

        logger.info(f"保存所有StyleLearner {'成功' if success else '部分失败'}")
        return success

    def apply_decay_all(self, factor: Optional[float] = None):
        """
        Apply knowledge decay to all managed learners.

        Args:
            factor: Decay multiplier; None uses each model's configured gamma.
        """
        for learner in self.learners.values():
            learner.apply_decay(factor)

        logger.info("对所有StyleLearner应用知识衰减")

    def get_all_stats(self) -> Dict[str, Dict]:
        """
        Collect statistics from all managed learners.

        Returns:
            {chat_id: stats}
        """
        return {chat_id: learner.get_stats() for chat_id, learner in self.learners.items()}
|
||||
|
||||
|
||||
# Module-level singleton shared by all callers
style_learner_manager = StyleLearnerManager()
|
||||
@@ -46,6 +46,9 @@ class StreamLoopManager:
|
||||
# 状态控制
|
||||
self.is_running = False
|
||||
|
||||
# 每个流的上一次间隔值(用于日志去重)
|
||||
self._last_intervals: dict[str, float] = {}
|
||||
|
||||
logger.info(f"流循环管理器初始化完成 (最大并发流数: {self.max_concurrent_streams})")
|
||||
|
||||
async def start(self) -> None:
|
||||
@@ -285,7 +288,11 @@ class StreamLoopManager:
|
||||
interval = await self._calculate_interval(stream_id, has_messages)
|
||||
|
||||
# 6. sleep等待下次检查
|
||||
logger.info(f"流 {stream_id} 等待 {interval:.2f}s")
|
||||
# 只在间隔发生变化时输出日志,避免刷屏
|
||||
last_interval = self._last_intervals.get(stream_id)
|
||||
if last_interval is None or abs(interval - last_interval) > 0.01:
|
||||
logger.info(f"流 {stream_id} 等待周期变化: {interval:.2f}s")
|
||||
self._last_intervals[stream_id] = interval
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
@@ -316,6 +323,9 @@ class StreamLoopManager:
|
||||
except Exception as e:
|
||||
logger.debug(f"释放自适应流处理槽位失败: {e}")
|
||||
|
||||
# 清理间隔记录
|
||||
self._last_intervals.pop(stream_id, None)
|
||||
|
||||
logger.info(f"流循环结束: {stream_id}")
|
||||
|
||||
async def _get_stream_context(self, stream_id: str) -> Any | None:
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Any
|
||||
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
@@ -142,7 +143,7 @@ class ChatterActionManager:
|
||||
self,
|
||||
action_name: str,
|
||||
chat_id: str,
|
||||
target_message: dict | None = None,
|
||||
target_message: dict | DatabaseMessages | None = None,
|
||||
reasoning: str = "",
|
||||
action_data: dict | None = None,
|
||||
thinking_id: str | None = None,
|
||||
@@ -262,9 +263,15 @@ class ChatterActionManager:
|
||||
from_plugin=False,
|
||||
)
|
||||
if not success or not response_set:
|
||||
logger.info(
|
||||
f"对 {target_message.get('processed_plain_text') if target_message else '未知消息'} 的回复生成失败"
|
||||
)
|
||||
# 安全地获取 processed_plain_text
|
||||
if isinstance(target_message, DatabaseMessages):
|
||||
msg_text = target_message.processed_plain_text or "未知消息"
|
||||
elif target_message:
|
||||
msg_text = target_message.get("processed_plain_text", "未知消息")
|
||||
else:
|
||||
msg_text = "未知消息"
|
||||
|
||||
logger.info(f"对 {msg_text} 的回复生成失败")
|
||||
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
|
||||
except asyncio.CancelledError:
|
||||
logger.debug(f"{log_prefix} 并行执行:回复生成任务已被取消")
|
||||
@@ -322,8 +329,11 @@ class ChatterActionManager:
|
||||
|
||||
# 获取目标消息ID
|
||||
target_message_id = None
|
||||
if target_message and isinstance(target_message, dict):
|
||||
target_message_id = target_message.get("message_id")
|
||||
if target_message:
|
||||
if isinstance(target_message, DatabaseMessages):
|
||||
target_message_id = target_message.message_id
|
||||
elif isinstance(target_message, dict):
|
||||
target_message_id = target_message.get("message_id")
|
||||
elif action_data and isinstance(action_data, dict):
|
||||
target_message_id = action_data.get("target_message_id")
|
||||
|
||||
@@ -488,14 +498,19 @@ class ChatterActionManager:
|
||||
person_info_manager = get_person_info_manager()
|
||||
|
||||
# 获取 platform,如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值
|
||||
platform = action_message.get("chat_info_platform")
|
||||
if platform is None:
|
||||
platform = getattr(chat_stream, "platform", "unknown")
|
||||
if isinstance(action_message, DatabaseMessages):
|
||||
platform = action_message.chat_info.platform
|
||||
user_id = action_message.user_info.user_id
|
||||
else:
|
||||
platform = action_message.get("chat_info_platform")
|
||||
if platform is None:
|
||||
platform = getattr(chat_stream, "platform", "unknown")
|
||||
user_id = action_message.get("user_id", "")
|
||||
|
||||
# 获取用户信息并生成回复提示
|
||||
person_id = person_info_manager.get_person_id(
|
||||
platform,
|
||||
action_message.get("user_id", ""),
|
||||
user_id,
|
||||
)
|
||||
person_name = await person_info_manager.get_value(person_id, "person_name")
|
||||
action_prompt_display = f"你对{person_name}进行了回复:{reply_text}"
|
||||
@@ -565,7 +580,14 @@ class ChatterActionManager:
|
||||
|
||||
# 根据新消息数量决定是否需要引用回复
|
||||
reply_text = ""
|
||||
is_proactive_thinking = (message_data.get("message_type") == "proactive_thinking") if message_data else True
|
||||
# 检查是否为主动思考消息
|
||||
if isinstance(message_data, DatabaseMessages):
|
||||
# DatabaseMessages 对象没有 message_type 字段,默认为 False
|
||||
is_proactive_thinking = False
|
||||
elif message_data:
|
||||
is_proactive_thinking = message_data.get("message_type") == "proactive_thinking"
|
||||
else:
|
||||
is_proactive_thinking = True
|
||||
|
||||
logger.debug(f"[send_response] message_data: {message_data}")
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.chat.utils.prompt_params import PromptParameters
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.individuality.individuality import get_individuality
|
||||
@@ -474,10 +475,13 @@ class DefaultReplyer:
|
||||
style_habits = []
|
||||
grammar_habits = []
|
||||
|
||||
# 使用从处理器传来的选中表达方式
|
||||
# LLM模式:调用LLM选择5-10个,然后随机选5个
|
||||
selected_expressions = await expression_selector.select_suitable_expressions_llm(
|
||||
self.chat_stream.stream_id, chat_history, max_num=8, min_num=2, target_message=target
|
||||
# 使用统一的表达方式选择入口(支持classic和exp_model模式)
|
||||
selected_expressions = await expression_selector.select_suitable_expressions(
|
||||
chat_id=self.chat_stream.stream_id,
|
||||
chat_history=chat_history,
|
||||
target_message=target,
|
||||
max_num=8,
|
||||
min_num=2
|
||||
)
|
||||
|
||||
if selected_expressions:
|
||||
@@ -1208,7 +1212,7 @@ class DefaultReplyer:
|
||||
extra_info: str = "",
|
||||
available_actions: dict[str, ActionInfo] | None = None,
|
||||
enable_tool: bool = True,
|
||||
reply_message: dict[str, Any] | None = None,
|
||||
reply_message: dict[str, Any] | DatabaseMessages | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
构建回复器上下文
|
||||
@@ -1250,10 +1254,24 @@ class DefaultReplyer:
|
||||
if reply_message is None:
|
||||
logger.warning("reply_message 为 None,无法构建prompt")
|
||||
return ""
|
||||
platform = reply_message.get("chat_info_platform")
|
||||
|
||||
# 统一处理 DatabaseMessages 对象和字典
|
||||
if isinstance(reply_message, DatabaseMessages):
|
||||
platform = reply_message.chat_info.platform
|
||||
user_id = reply_message.user_info.user_id
|
||||
user_nickname = reply_message.user_info.user_nickname
|
||||
user_cardname = reply_message.user_info.user_cardname
|
||||
processed_plain_text = reply_message.processed_plain_text
|
||||
else:
|
||||
platform = reply_message.get("chat_info_platform")
|
||||
user_id = reply_message.get("user_id")
|
||||
user_nickname = reply_message.get("user_nickname")
|
||||
user_cardname = reply_message.get("user_cardname")
|
||||
processed_plain_text = reply_message.get("processed_plain_text")
|
||||
|
||||
person_id = person_info_manager.get_person_id(
|
||||
platform, # type: ignore
|
||||
reply_message.get("user_id"), # type: ignore
|
||||
user_id, # type: ignore
|
||||
)
|
||||
person_name = await person_info_manager.get_value(person_id, "person_name")
|
||||
|
||||
@@ -1262,22 +1280,22 @@ class DefaultReplyer:
|
||||
# 尝试从reply_message获取用户名
|
||||
await person_info_manager.first_knowing_some_one(
|
||||
platform, # type: ignore
|
||||
reply_message.get("user_id"), # type: ignore
|
||||
reply_message.get("user_nickname") or "",
|
||||
reply_message.get("user_cardname") or "",
|
||||
user_id, # type: ignore
|
||||
user_nickname or "",
|
||||
user_cardname or "",
|
||||
)
|
||||
|
||||
# 检查是否是bot自己的名字,如果是则替换为"(你)"
|
||||
bot_user_id = str(global_config.bot.qq_account)
|
||||
current_user_id = await person_info_manager.get_value(person_id, "user_id")
|
||||
current_platform = reply_message.get("chat_info_platform")
|
||||
current_platform = platform
|
||||
|
||||
if current_user_id == bot_user_id and current_platform == global_config.bot.platform:
|
||||
sender = f"{person_name}(你)"
|
||||
else:
|
||||
# 如果不是bot自己,直接使用person_name
|
||||
sender = person_name
|
||||
target = reply_message.get("processed_plain_text")
|
||||
target = processed_plain_text
|
||||
|
||||
# 最终的空值检查,确保sender和target不为None
|
||||
if sender is None:
|
||||
@@ -1611,15 +1629,22 @@ class DefaultReplyer:
|
||||
raw_reply: str,
|
||||
reason: str,
|
||||
reply_to: str,
|
||||
reply_message: dict[str, Any] | None = None,
|
||||
reply_message: dict[str, Any] | DatabaseMessages | None = None,
|
||||
) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
|
||||
chat_stream = self.chat_stream
|
||||
chat_id = chat_stream.stream_id
|
||||
is_group_chat = bool(chat_stream.group_info)
|
||||
|
||||
if reply_message:
|
||||
sender = reply_message.get("sender")
|
||||
target = reply_message.get("target")
|
||||
if isinstance(reply_message, DatabaseMessages):
|
||||
# 从 DatabaseMessages 对象获取 sender 和 target
|
||||
# 注意: DatabaseMessages 没有直接的 sender/target 字段
|
||||
# 需要根据实际情况构造
|
||||
sender = reply_message.user_info.user_nickname or reply_message.user_info.user_id
|
||||
target = reply_message.processed_plain_text or ""
|
||||
else:
|
||||
sender = reply_message.get("sender")
|
||||
target = reply_message.get("target")
|
||||
else:
|
||||
sender, target = self._parse_reply_target(reply_to)
|
||||
|
||||
@@ -1891,42 +1916,64 @@ class DefaultReplyer:
|
||||
logger.warning(f"未找到用户 {sender} 的ID,跳过信息提取")
|
||||
return f"你完全不认识{sender},不理解ta的相关信息。"
|
||||
|
||||
# 使用统一评分API获取关系信息
|
||||
# 使用 RelationshipFetcher 获取完整关系信息(包含新字段)
|
||||
try:
|
||||
from src.plugin_system.apis.scoring_api import scoring_api
|
||||
from src.person_info.relationship_fetcher import relationship_fetcher_manager
|
||||
|
||||
# 获取用户信息以获取真实的user_id
|
||||
user_info = await person_info_manager.get_values(person_id, ["user_id", "platform"])
|
||||
user_id = user_info.get("user_id", "unknown")
|
||||
# 获取 chat_id
|
||||
chat_id = self.chat_stream.stream_id
|
||||
|
||||
# 从统一API获取关系数据
|
||||
relationship_data = await scoring_api.get_user_relationship_data(user_id)
|
||||
if relationship_data:
|
||||
relationship_text = relationship_data.get("relationship_text", "")
|
||||
relationship_score = relationship_data.get("relationship_score", 0.3)
|
||||
# 获取 RelationshipFetcher 实例
|
||||
relationship_fetcher = relationship_fetcher_manager.get_fetcher(chat_id)
|
||||
|
||||
# 构建丰富的关系信息描述
|
||||
if relationship_text:
|
||||
# 转换关系分数为描述性文本
|
||||
if relationship_score >= 0.8:
|
||||
relationship_level = "非常亲密的朋友"
|
||||
elif relationship_score >= 0.6:
|
||||
relationship_level = "好朋友"
|
||||
elif relationship_score >= 0.4:
|
||||
relationship_level = "普通朋友"
|
||||
elif relationship_score >= 0.2:
|
||||
relationship_level = "认识的人"
|
||||
else:
|
||||
relationship_level = "陌生人"
|
||||
# 构建用户关系信息(包含别名、偏好关键词等新字段)
|
||||
user_relation_info = await relationship_fetcher.build_relation_info(person_id, points_num=5)
|
||||
|
||||
return f"你与{sender}的关系:{relationship_level}(关系分:{relationship_score:.2f}/1.0)。{relationship_text}"
|
||||
else:
|
||||
return f"你与{sender}是初次见面,关系分:{relationship_score:.2f}/1.0。"
|
||||
# 构建聊天流印象信息
|
||||
stream_impression = await relationship_fetcher.build_chat_stream_impression(chat_id)
|
||||
|
||||
# 组合两部分信息
|
||||
if user_relation_info and stream_impression:
|
||||
return "\n\n".join([user_relation_info, stream_impression])
|
||||
elif user_relation_info:
|
||||
return user_relation_info
|
||||
elif stream_impression:
|
||||
return stream_impression
|
||||
else:
|
||||
return f"你完全不认识{sender},这是第一次互动。"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取关系信息失败: {e}")
|
||||
# 降级到基本信息
|
||||
try:
|
||||
from src.plugin_system.apis.scoring_api import scoring_api
|
||||
|
||||
user_info = await person_info_manager.get_values(person_id, ["user_id", "platform"])
|
||||
user_id = user_info.get("user_id", "unknown")
|
||||
|
||||
relationship_data = await scoring_api.get_user_relationship_data(user_id)
|
||||
if relationship_data:
|
||||
relationship_text = relationship_data.get("relationship_text", "")
|
||||
relationship_score = relationship_data.get("relationship_score", 0.3)
|
||||
|
||||
if relationship_text:
|
||||
if relationship_score >= 0.8:
|
||||
relationship_level = "非常亲密的朋友"
|
||||
elif relationship_score >= 0.6:
|
||||
relationship_level = "好朋友"
|
||||
elif relationship_score >= 0.4:
|
||||
relationship_level = "普通朋友"
|
||||
elif relationship_score >= 0.2:
|
||||
relationship_level = "认识的人"
|
||||
else:
|
||||
relationship_level = "陌生人"
|
||||
|
||||
return f"你与{sender}的关系:{relationship_level}(关系分:{relationship_score:.2f}/1.0)。{relationship_text}"
|
||||
else:
|
||||
return f"你与{sender}是初次见面,关系分:{relationship_score:.2f}/1.0。"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return f"你与{sender}是普通朋友关系。"
|
||||
|
||||
async def _store_chat_memory_async(self, reply_to: str, reply_message: dict[str, Any] | None = None):
|
||||
|
||||
@@ -606,11 +606,11 @@ class Prompt:
|
||||
recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True
|
||||
)
|
||||
|
||||
# 使用LLM选择与当前情景匹配的表达习惯
|
||||
# 使用统一的表达方式选择入口(支持classic和exp_model模式)
|
||||
expression_selector = ExpressionSelector(self.parameters.chat_id)
|
||||
selected_expressions = await expression_selector.select_suitable_expressions_llm(
|
||||
selected_expressions = await expression_selector.select_suitable_expressions(
|
||||
chat_id=self.parameters.chat_id,
|
||||
chat_info=chat_history,
|
||||
chat_history=chat_history,
|
||||
target_message=self.parameters.target,
|
||||
)
|
||||
|
||||
@@ -1109,8 +1109,18 @@ class Prompt:
|
||||
logger.warning(f"未找到用户 {sender} 的ID,跳过信息提取")
|
||||
return f"你完全不认识{sender},不理解ta的相关信息。"
|
||||
|
||||
# 使用关系提取器构建关系信息
|
||||
return await relationship_fetcher.build_relation_info(person_id, points_num=5)
|
||||
# 使用关系提取器构建用户关系信息和聊天流印象
|
||||
user_relation_info = await relationship_fetcher.build_relation_info(person_id, points_num=5)
|
||||
stream_impression = await relationship_fetcher.build_chat_stream_impression(chat_id)
|
||||
|
||||
# 组合两部分信息
|
||||
info_parts = []
|
||||
if user_relation_info:
|
||||
info_parts.append(user_relation_info)
|
||||
if stream_impression:
|
||||
info_parts.append(stream_impression)
|
||||
|
||||
return "\n\n".join(info_parts) if info_parts else ""
|
||||
|
||||
def _get_default_result_for_task(self, task_name: str) -> dict[str, Any]:
|
||||
"""为超时或失败的异步构建任务提供一个安全的默认返回值.
|
||||
|
||||
Reference in New Issue
Block a user