refactor(core): 优化类型提示与代码风格
本次提交对项目代码进行了广泛的重构,主要集中在以下几个方面:
1. **类型提示现代化**:
- 将 `typing` 模块中的 `Optional[T]`、`List[T]`、`Dict[K, V]` 等旧式类型提示更新为现代的 `T | None`、`list[T]`、`dict[K, V]` 语法。
- 这提高了代码的可读性,并与较新 Python 版本的风格保持一致。
2. **代码风格统一**:
- 移除了多余的空行和不必要的空格,使代码更加紧凑和规范。
- 统一了部分日志输出的格式,增强了日志的可读性。
3. **导入语句优化**:
- 调整了部分模块的 `import` 语句顺序,使其符合 PEP 8 规范。
这些更改不涉及任何功能性变动,旨在提升代码库的整体质量、可维护性和开发体验。
This commit is contained in:
@@ -201,15 +201,16 @@ class RelationshipEnergyCalculator(EnergyCalculator):
|
||||
|
||||
# 从数据库获取聊天流兴趣分数
|
||||
try:
|
||||
from sqlalchemy import select
|
||||
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import ChatStreams
|
||||
from sqlalchemy import select
|
||||
|
||||
async with get_db_session() as session:
|
||||
stmt = select(ChatStreams).where(ChatStreams.stream_id == stream_id)
|
||||
result = await session.execute(stmt)
|
||||
stream = result.scalar_one_or_none()
|
||||
|
||||
|
||||
if stream and stream.stream_interest_score is not None:
|
||||
interest_score = float(stream.stream_interest_score)
|
||||
logger.debug(f"使用聊天流兴趣度计算关系能量: {interest_score:.3f}")
|
||||
|
||||
@@ -5,14 +5,14 @@
|
||||
import difflib
|
||||
import random
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("express_utils")
|
||||
|
||||
|
||||
def filter_message_content(content: Optional[str]) -> str:
|
||||
def filter_message_content(content: str | None) -> str:
|
||||
"""
|
||||
过滤消息内容,移除回复、@、图片等格式
|
||||
|
||||
@@ -51,7 +51,7 @@ def calculate_similarity(text1: str, text2: str) -> float:
|
||||
return difflib.SequenceMatcher(None, text1, text2).ratio()
|
||||
|
||||
|
||||
def weighted_sample(population: List[Dict], k: int, weight_key: Optional[str] = None) -> List[Dict]:
|
||||
def weighted_sample(population: list[dict], k: int, weight_key: str | None = None) -> list[dict]:
|
||||
"""
|
||||
加权随机抽样函数
|
||||
|
||||
@@ -108,7 +108,7 @@ def normalize_text(text: str) -> str:
|
||||
return text.strip()
|
||||
|
||||
|
||||
def extract_keywords(text: str, max_keywords: int = 10) -> List[str]:
|
||||
def extract_keywords(text: str, max_keywords: int = 10) -> list[str]:
|
||||
"""
|
||||
简单的关键词提取(基于词频)
|
||||
|
||||
@@ -135,7 +135,7 @@ def extract_keywords(text: str, max_keywords: int = 10) -> List[str]:
|
||||
return words[:max_keywords]
|
||||
|
||||
|
||||
def format_expression_pair(situation: str, style: str, index: Optional[int] = None) -> str:
|
||||
def format_expression_pair(situation: str, style: str, index: int | None = None) -> str:
|
||||
"""
|
||||
格式化表达方式对
|
||||
|
||||
@@ -153,7 +153,7 @@ def format_expression_pair(situation: str, style: str, index: Optional[int] = No
|
||||
return f'当"{situation}"时,使用"{style}"'
|
||||
|
||||
|
||||
def parse_expression_pair(text: str) -> Optional[tuple[str, str]]:
|
||||
def parse_expression_pair(text: str) -> tuple[str, str] | None:
|
||||
"""
|
||||
解析表达方式对文本
|
||||
|
||||
@@ -170,7 +170,7 @@ def parse_expression_pair(text: str) -> Optional[tuple[str, str]]:
|
||||
return None
|
||||
|
||||
|
||||
def batch_filter_duplicates(expressions: List[Dict[str, Any]], key_fields: List[str]) -> List[Dict[str, Any]]:
|
||||
def batch_filter_duplicates(expressions: list[dict[str, Any]], key_fields: list[str]) -> list[dict[str, Any]]:
|
||||
"""
|
||||
批量去重表达方式
|
||||
|
||||
@@ -219,8 +219,8 @@ def calculate_time_weight(last_active_time: float, current_time: float, half_lif
|
||||
|
||||
|
||||
def merge_expressions_from_multiple_chats(
|
||||
expressions_dict: Dict[str, List[Dict[str, Any]]], max_total: int = 100
|
||||
) -> List[Dict[str, Any]]:
|
||||
expressions_dict: dict[str, list[dict[str, Any]]], max_total: int = 100
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
合并多个聊天室的表达方式
|
||||
|
||||
|
||||
@@ -438,9 +438,9 @@ class ExpressionLearner:
|
||||
try:
|
||||
# 获取 StyleLearner 实例
|
||||
learner = style_learner_manager.get_learner(chat_id)
|
||||
|
||||
|
||||
logger.info(f"开始训练 StyleLearner: chat_id={chat_id}, 样本数={len(expr_list)}")
|
||||
|
||||
|
||||
# 为每个学习到的表达方式训练模型
|
||||
# 使用 situation 作为输入,style 作为目标
|
||||
# 这是最符合语义的方式:场景 -> 表达方式
|
||||
@@ -448,25 +448,25 @@ class ExpressionLearner:
|
||||
for expr in expr_list:
|
||||
situation = expr["situation"]
|
||||
style = expr["style"]
|
||||
|
||||
|
||||
# 训练映射关系: situation -> style
|
||||
if learner.learn_mapping(situation, style):
|
||||
success_count += 1
|
||||
else:
|
||||
logger.warning(f"训练失败: {situation} -> {style}")
|
||||
|
||||
|
||||
logger.info(
|
||||
f"StyleLearner 训练完成: {success_count}/{len(expr_list)} 成功, "
|
||||
f"当前风格总数={len(learner.get_all_styles())}, "
|
||||
f"总样本数={learner.learning_stats['total_samples']}"
|
||||
)
|
||||
|
||||
|
||||
# 保存模型
|
||||
if learner.save(style_learner_manager.model_save_path):
|
||||
logger.info(f"StyleLearner 模型保存成功: {chat_id}")
|
||||
else:
|
||||
logger.error(f"StyleLearner 模型保存失败: {chat_id}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"训练 StyleLearner 失败: {e}", exc_info=True)
|
||||
|
||||
@@ -527,7 +527,7 @@ class ExpressionLearner:
|
||||
logger.debug(f"学习{type_str}的response: {response}")
|
||||
|
||||
expressions: list[tuple[str, str, str]] = self.parse_expression_response(response, chat_id)
|
||||
|
||||
|
||||
if not expressions:
|
||||
logger.warning(f"从LLM响应中未能解析出任何{type_str}。请检查LLM输出格式是否正确。")
|
||||
logger.info(f"LLM完整响应:\n{response}")
|
||||
@@ -542,26 +542,26 @@ class ExpressionLearner:
|
||||
"""
|
||||
expressions: list[tuple[str, str, str]] = []
|
||||
failed_lines = []
|
||||
|
||||
|
||||
for line_num, line in enumerate(response.splitlines(), 1):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
|
||||
# 替换中文引号为英文引号,便于统一处理
|
||||
line_normalized = line.replace('"', '"').replace('"', '"').replace("'", '"').replace("'", '"')
|
||||
|
||||
|
||||
# 查找"当"和下一个引号
|
||||
idx_when = line_normalized.find('当"')
|
||||
if idx_when == -1:
|
||||
# 尝试不带引号的格式: 当xxx时
|
||||
idx_when = line_normalized.find('当')
|
||||
idx_when = line_normalized.find("当")
|
||||
if idx_when == -1:
|
||||
failed_lines.append((line_num, line, "找不到'当'关键字"))
|
||||
continue
|
||||
|
||||
|
||||
# 提取"当"和"时"之间的内容
|
||||
idx_shi = line_normalized.find('时', idx_when)
|
||||
idx_shi = line_normalized.find("时", idx_when)
|
||||
if idx_shi == -1:
|
||||
failed_lines.append((line_num, line, "找不到'时'关键字"))
|
||||
continue
|
||||
@@ -575,20 +575,20 @@ class ExpressionLearner:
|
||||
continue
|
||||
situation = line_normalized[idx_quote1 + 1 : idx_quote2]
|
||||
search_start = idx_quote2
|
||||
|
||||
|
||||
# 查找"使用"或"可以"
|
||||
idx_use = line_normalized.find('使用"', search_start)
|
||||
if idx_use == -1:
|
||||
idx_use = line_normalized.find('可以"', search_start)
|
||||
if idx_use == -1:
|
||||
# 尝试不带引号的格式
|
||||
idx_use = line_normalized.find('使用', search_start)
|
||||
idx_use = line_normalized.find("使用", search_start)
|
||||
if idx_use == -1:
|
||||
idx_use = line_normalized.find('可以', search_start)
|
||||
idx_use = line_normalized.find("可以", search_start)
|
||||
if idx_use == -1:
|
||||
failed_lines.append((line_num, line, "找不到'使用'或'可以'关键字"))
|
||||
continue
|
||||
|
||||
|
||||
# 提取剩余部分作为style
|
||||
style = line_normalized[idx_use + 2:].strip('"\'"",。')
|
||||
if not style:
|
||||
@@ -610,24 +610,24 @@ class ExpressionLearner:
|
||||
style = line_normalized[idx_quote3 + 1:].strip('"\'""')
|
||||
else:
|
||||
style = line_normalized[idx_quote3 + 1 : idx_quote4]
|
||||
|
||||
|
||||
# 清理并验证
|
||||
situation = situation.strip()
|
||||
style = style.strip()
|
||||
|
||||
|
||||
if not situation or not style:
|
||||
failed_lines.append((line_num, line, f"situation或style为空: situation='{situation}', style='{style}'"))
|
||||
continue
|
||||
|
||||
|
||||
expressions.append((chat_id, situation, style))
|
||||
|
||||
|
||||
# 记录解析失败的行
|
||||
if failed_lines:
|
||||
logger.warning(f"解析表达方式时有 {len(failed_lines)} 行失败:")
|
||||
for line_num, line, reason in failed_lines[:5]: # 只显示前5个
|
||||
logger.warning(f" 行{line_num}: {reason}")
|
||||
logger.debug(f" 原文: {line}")
|
||||
|
||||
|
||||
if not expressions:
|
||||
logger.warning(f"LLM返回了内容但无法解析任何表达方式。响应预览:\n{response[:500]}")
|
||||
else:
|
||||
|
||||
@@ -267,11 +267,11 @@ class ExpressionSelector:
|
||||
chat_info = "\n".join([f"{msg.get('sender', 'Unknown')}: {msg.get('content', '')}" for msg in chat_history])
|
||||
else:
|
||||
chat_info = chat_history
|
||||
|
||||
|
||||
# 根据配置选择模式
|
||||
mode = global_config.expression.mode
|
||||
logger.debug(f"[ExpressionSelector] 使用模式: {mode}")
|
||||
|
||||
|
||||
if mode == "exp_model":
|
||||
return await self._select_expressions_model_only(
|
||||
chat_id=chat_id,
|
||||
@@ -288,7 +288,7 @@ class ExpressionSelector:
|
||||
max_num=max_num,
|
||||
min_num=min_num
|
||||
)
|
||||
|
||||
|
||||
async def _select_expressions_classic(
|
||||
self,
|
||||
chat_id: str,
|
||||
@@ -298,7 +298,7 @@ class ExpressionSelector:
|
||||
min_num: int = 5,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""经典模式:随机抽样 + LLM评估"""
|
||||
logger.debug(f"[Classic模式] 使用LLM评估表达方式")
|
||||
logger.debug("[Classic模式] 使用LLM评估表达方式")
|
||||
return await self.select_suitable_expressions_llm(
|
||||
chat_id=chat_id,
|
||||
chat_info=chat_info,
|
||||
@@ -306,7 +306,7 @@ class ExpressionSelector:
|
||||
min_num=min_num,
|
||||
target_message=target_message
|
||||
)
|
||||
|
||||
|
||||
async def _select_expressions_model_only(
|
||||
self,
|
||||
chat_id: str,
|
||||
@@ -316,22 +316,22 @@ class ExpressionSelector:
|
||||
min_num: int = 5,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""模型预测模式:先提取情境,再使用StyleLearner预测表达风格"""
|
||||
logger.debug(f"[Exp_model模式] 使用情境提取 + StyleLearner预测表达方式")
|
||||
|
||||
logger.debug("[Exp_model模式] 使用情境提取 + StyleLearner预测表达方式")
|
||||
|
||||
# 检查是否允许在此聊天流中使用表达
|
||||
if not self.can_use_expression_for_chat(chat_id):
|
||||
logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
|
||||
return []
|
||||
|
||||
|
||||
# 步骤1: 提取聊天情境
|
||||
situations = await situation_extractor.extract_situations(
|
||||
chat_history=chat_info,
|
||||
target_message=target_message,
|
||||
max_situations=3
|
||||
)
|
||||
|
||||
|
||||
if not situations:
|
||||
logger.warning(f"无法提取聊天情境,回退到经典模式")
|
||||
logger.warning("无法提取聊天情境,回退到经典模式")
|
||||
return await self._select_expressions_classic(
|
||||
chat_id=chat_id,
|
||||
chat_info=chat_info,
|
||||
@@ -339,17 +339,17 @@ class ExpressionSelector:
|
||||
max_num=max_num,
|
||||
min_num=min_num
|
||||
)
|
||||
|
||||
|
||||
logger.info(f"[Exp_model模式] 步骤1完成 - 提取到 {len(situations)} 个情境: {situations}")
|
||||
|
||||
|
||||
# 步骤2: 使用 StyleLearner 为每个情境预测合适的表达方式
|
||||
learner = style_learner_manager.get_learner(chat_id)
|
||||
|
||||
|
||||
all_predicted_styles = {}
|
||||
for i, situation in enumerate(situations, 1):
|
||||
logger.debug(f"[Exp_model模式] 步骤2.{i} - 为情境预测风格: {situation}")
|
||||
best_style, scores = learner.predict_style(situation, top_k=max_num)
|
||||
|
||||
|
||||
if best_style and scores:
|
||||
logger.debug(f" 预测结果: best={best_style}, scores数量={len(scores)}")
|
||||
# 合并分数(取最高分)
|
||||
@@ -357,10 +357,10 @@ class ExpressionSelector:
|
||||
if style not in all_predicted_styles or score > all_predicted_styles[style]:
|
||||
all_predicted_styles[style] = score
|
||||
else:
|
||||
logger.debug(f" 该情境未返回预测结果")
|
||||
|
||||
logger.debug(" 该情境未返回预测结果")
|
||||
|
||||
if not all_predicted_styles:
|
||||
logger.warning(f"[Exp_model模式] StyleLearner未返回预测结果(可能模型未训练),回退到经典模式")
|
||||
logger.warning("[Exp_model模式] StyleLearner未返回预测结果(可能模型未训练),回退到经典模式")
|
||||
return await self._select_expressions_classic(
|
||||
chat_id=chat_id,
|
||||
chat_info=chat_info,
|
||||
@@ -368,22 +368,22 @@ class ExpressionSelector:
|
||||
max_num=max_num,
|
||||
min_num=min_num
|
||||
)
|
||||
|
||||
|
||||
# 将分数字典转换为列表格式 [(style, score), ...]
|
||||
predicted_styles = sorted(all_predicted_styles.items(), key=lambda x: x[1], reverse=True)
|
||||
|
||||
|
||||
logger.info(f"[Exp_model模式] 步骤2完成 - 预测到 {len(predicted_styles)} 个风格, Top3: {predicted_styles[:3]}")
|
||||
|
||||
|
||||
# 步骤3: 根据预测的风格从数据库获取表达方式
|
||||
logger.debug(f"[Exp_model模式] 步骤3 - 从数据库查询表达方式")
|
||||
logger.debug("[Exp_model模式] 步骤3 - 从数据库查询表达方式")
|
||||
expressions = await self.get_model_predicted_expressions(
|
||||
chat_id=chat_id,
|
||||
predicted_styles=predicted_styles,
|
||||
max_num=max_num
|
||||
)
|
||||
|
||||
|
||||
if not expressions:
|
||||
logger.warning(f"[Exp_model模式] 未找到匹配预测风格的表达方式,回退到经典模式")
|
||||
logger.warning("[Exp_model模式] 未找到匹配预测风格的表达方式,回退到经典模式")
|
||||
return await self._select_expressions_classic(
|
||||
chat_id=chat_id,
|
||||
chat_info=chat_info,
|
||||
@@ -391,10 +391,10 @@ class ExpressionSelector:
|
||||
max_num=max_num,
|
||||
min_num=min_num
|
||||
)
|
||||
|
||||
|
||||
logger.info(f"[Exp_model模式] 成功! 返回 {len(expressions)} 个表达方式")
|
||||
return expressions
|
||||
|
||||
|
||||
async def get_model_predicted_expressions(
|
||||
self,
|
||||
chat_id: str,
|
||||
@@ -414,15 +414,15 @@ class ExpressionSelector:
|
||||
"""
|
||||
if not predicted_styles:
|
||||
return []
|
||||
|
||||
|
||||
# 提取风格名称(前3个最佳匹配)
|
||||
style_names = [style for style, _ in predicted_styles[:min(3, len(predicted_styles))]]
|
||||
logger.debug(f"预测最佳风格: {style_names[0] if style_names else 'None'}, Top3分数: {predicted_styles[:3]}")
|
||||
|
||||
|
||||
# 🔥 使用 get_related_chat_ids 获取所有相关的 chat_id(支持共享表达方式)
|
||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||
logger.info(f"查询相关的chat_ids ({len(related_chat_ids)}个): {related_chat_ids}")
|
||||
|
||||
|
||||
async with get_db_session() as session:
|
||||
# 🔍 先检查数据库中实际有哪些 chat_id 的数据
|
||||
db_chat_ids_result = await session.execute(
|
||||
@@ -432,7 +432,7 @@ class ExpressionSelector:
|
||||
)
|
||||
db_chat_ids = [cid for cid in db_chat_ids_result.scalars()]
|
||||
logger.info(f"数据库中有表达方式的chat_ids ({len(db_chat_ids)}个): {db_chat_ids}")
|
||||
|
||||
|
||||
# 获取所有相关 chat_id 的表达方式(用于模糊匹配)
|
||||
all_expressions_result = await session.execute(
|
||||
select(Expression)
|
||||
@@ -440,51 +440,51 @@ class ExpressionSelector:
|
||||
.where(Expression.type == "style")
|
||||
)
|
||||
all_expressions = list(all_expressions_result.scalars())
|
||||
|
||||
|
||||
logger.info(f"配置的相关chat_id的表达方式数量: {len(all_expressions)}")
|
||||
|
||||
|
||||
# 🔥 智能回退:如果相关 chat_id 没有数据,尝试查询所有 chat_id
|
||||
if not all_expressions:
|
||||
logger.info(f"相关chat_id没有数据,尝试从所有chat_id查询")
|
||||
logger.info("相关chat_id没有数据,尝试从所有chat_id查询")
|
||||
all_expressions_result = await session.execute(
|
||||
select(Expression)
|
||||
.where(Expression.type == "style")
|
||||
)
|
||||
all_expressions = list(all_expressions_result.scalars())
|
||||
logger.debug(f"数据库中所有表达方式数量: {len(all_expressions)}")
|
||||
|
||||
|
||||
if not all_expressions:
|
||||
logger.warning(f"数据库中完全没有任何表达方式,需要先学习")
|
||||
logger.warning("数据库中完全没有任何表达方式,需要先学习")
|
||||
return []
|
||||
|
||||
|
||||
# 🔥 使用模糊匹配而不是精确匹配
|
||||
# 计算每个预测style与数据库style的相似度
|
||||
from difflib import SequenceMatcher
|
||||
|
||||
|
||||
matched_expressions = []
|
||||
for expr in all_expressions:
|
||||
db_style = expr.style or ""
|
||||
max_similarity = 0.0
|
||||
best_predicted = ""
|
||||
|
||||
|
||||
# 与每个预测的style计算相似度
|
||||
for predicted_style, pred_score in predicted_styles[:20]: # 考虑前20个预测
|
||||
# 计算字符串相似度
|
||||
similarity = SequenceMatcher(None, predicted_style, db_style).ratio()
|
||||
|
||||
|
||||
# 也检查包含关系(如果一个是另一个的子串,给更高分)
|
||||
if len(predicted_style) >= 2 and len(db_style) >= 2:
|
||||
if predicted_style in db_style or db_style in predicted_style:
|
||||
similarity = max(similarity, 0.7)
|
||||
|
||||
|
||||
if similarity > max_similarity:
|
||||
max_similarity = similarity
|
||||
best_predicted = predicted_style
|
||||
|
||||
|
||||
# 🔥 降低阈值到30%,因为StyleLearner预测质量较差
|
||||
if max_similarity >= 0.3: # 30%相似度阈值
|
||||
matched_expressions.append((expr, max_similarity, expr.count, best_predicted))
|
||||
|
||||
|
||||
if not matched_expressions:
|
||||
# 收集数据库中的style样例用于调试
|
||||
all_styles = [e.style for e in all_expressions[:10]]
|
||||
@@ -495,11 +495,11 @@ class ExpressionSelector:
|
||||
f" 提示: StyleLearner预测质量差,建议重新训练或使用classic模式"
|
||||
)
|
||||
return []
|
||||
|
||||
|
||||
# 按照相似度*count排序,选择最佳匹配
|
||||
matched_expressions.sort(key=lambda x: x[1] * (x[2] ** 0.5), reverse=True)
|
||||
expressions_objs = [e[0] for e in matched_expressions[:max_num]]
|
||||
|
||||
|
||||
# 显示最佳匹配的详细信息
|
||||
top_matches = [f"{e[3]}->{e[0].style}({e[1]:.2f})" for e in matched_expressions[:3]]
|
||||
logger.info(
|
||||
@@ -507,7 +507,7 @@ class ExpressionSelector:
|
||||
f" 相似度范围: {matched_expressions[0][1]:.2f} ~ {matched_expressions[min(len(matched_expressions)-1, max_num-1)][1]:.2f}\n"
|
||||
f" Top3匹配: {top_matches}"
|
||||
)
|
||||
|
||||
|
||||
# 转换为字典格式
|
||||
expressions = []
|
||||
for expr in expressions_objs:
|
||||
@@ -518,7 +518,7 @@ class ExpressionSelector:
|
||||
"count": float(expr.count) if expr.count else 0.0,
|
||||
"last_active_time": expr.last_active_time or 0.0
|
||||
})
|
||||
|
||||
|
||||
logger.debug(f"从数据库获取了 {len(expressions)} 个表达方式")
|
||||
return expressions
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
import os
|
||||
import pickle
|
||||
from collections import Counter, defaultdict
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -36,14 +35,14 @@ class ExpressorModel:
|
||||
self.nb = OnlineNaiveBayes(alpha=alpha, beta=beta, gamma=gamma, vocab_size=vocab_size)
|
||||
|
||||
# 候选表达管理
|
||||
self._candidates: Dict[str, str] = {} # cid -> text (style)
|
||||
self._situations: Dict[str, str] = {} # cid -> situation (不参与计算)
|
||||
self._candidates: dict[str, str] = {} # cid -> text (style)
|
||||
self._situations: dict[str, str] = {} # cid -> situation (不参与计算)
|
||||
|
||||
logger.info(
|
||||
f"ExpressorModel初始化完成 (alpha={alpha}, beta={beta}, gamma={gamma}, vocab_size={vocab_size}, use_jieba={use_jieba})"
|
||||
)
|
||||
|
||||
def add_candidate(self, cid: str, text: str, situation: Optional[str] = None):
|
||||
def add_candidate(self, cid: str, text: str, situation: str | None = None):
|
||||
"""
|
||||
添加候选文本和对应的situation
|
||||
|
||||
@@ -62,7 +61,7 @@ class ExpressorModel:
|
||||
if cid not in self.nb.token_counts:
|
||||
self.nb.token_counts[cid] = defaultdict(float)
|
||||
|
||||
def predict(self, text: str, k: int = None) -> Tuple[Optional[str], Dict[str, float]]:
|
||||
def predict(self, text: str, k: int = None) -> tuple[str | None, dict[str, float]]:
|
||||
"""
|
||||
直接对所有候选进行朴素贝叶斯评分
|
||||
|
||||
@@ -113,7 +112,7 @@ class ExpressorModel:
|
||||
tf = Counter(toks)
|
||||
self.nb.update_positive(tf, cid)
|
||||
|
||||
def decay(self, factor: Optional[float] = None):
|
||||
def decay(self, factor: float | None = None):
|
||||
"""
|
||||
应用知识衰减
|
||||
|
||||
@@ -122,7 +121,7 @@ class ExpressorModel:
|
||||
"""
|
||||
self.nb.decay(factor)
|
||||
|
||||
def get_candidate_info(self, cid: str) -> Tuple[Optional[str], Optional[str]]:
|
||||
def get_candidate_info(self, cid: str) -> tuple[str | None, str | None]:
|
||||
"""
|
||||
获取候选信息
|
||||
|
||||
@@ -136,7 +135,7 @@ class ExpressorModel:
|
||||
situation = self._situations.get(cid)
|
||||
return style, situation
|
||||
|
||||
def get_all_candidates(self) -> Dict[str, Tuple[str, str]]:
|
||||
def get_all_candidates(self) -> dict[str, tuple[str, str]]:
|
||||
"""
|
||||
获取所有候选
|
||||
|
||||
@@ -205,7 +204,7 @@ class ExpressorModel:
|
||||
|
||||
logger.info(f"模型已从 {path} 加载")
|
||||
|
||||
def get_stats(self) -> Dict:
|
||||
def get_stats(self) -> dict:
|
||||
"""获取模型统计信息"""
|
||||
nb_stats = self.nb.get_stats()
|
||||
return {
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
"""
|
||||
import math
|
||||
from collections import Counter, defaultdict
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -28,15 +27,15 @@ class OnlineNaiveBayes:
|
||||
self.V = vocab_size
|
||||
|
||||
# 类别统计
|
||||
self.cls_counts: Dict[str, float] = defaultdict(float) # cid -> total token count
|
||||
self.token_counts: Dict[str, Dict[str, float]] = defaultdict(
|
||||
self.cls_counts: dict[str, float] = defaultdict(float) # cid -> total token count
|
||||
self.token_counts: dict[str, dict[str, float]] = defaultdict(
|
||||
lambda: defaultdict(float)
|
||||
) # cid -> term -> count
|
||||
|
||||
# 缓存
|
||||
self._logZ: Dict[str, float] = {} # cache log(∑counts + Vα)
|
||||
self._logZ: dict[str, float] = {} # cache log(∑counts + Vα)
|
||||
|
||||
def score_batch(self, tf: Counter, cids: List[str]) -> Dict[str, float]:
|
||||
def score_batch(self, tf: Counter, cids: list[str]) -> dict[str, float]:
|
||||
"""
|
||||
批量计算候选的贝叶斯分数
|
||||
|
||||
@@ -51,7 +50,7 @@ class OnlineNaiveBayes:
|
||||
n_cls = max(1, len(self.cls_counts))
|
||||
denom_prior = math.log(total_cls + self.beta * n_cls)
|
||||
|
||||
out: Dict[str, float] = {}
|
||||
out: dict[str, float] = {}
|
||||
for cid in cids:
|
||||
# 计算先验概率 log P(c)
|
||||
prior = math.log(self.cls_counts[cid] + self.beta) - denom_prior
|
||||
@@ -88,7 +87,7 @@ class OnlineNaiveBayes:
|
||||
self.cls_counts[cid] += inc
|
||||
self._invalidate(cid)
|
||||
|
||||
def decay(self, factor: Optional[float] = None):
|
||||
def decay(self, factor: float | None = None):
|
||||
"""
|
||||
知识衰减(遗忘机制)
|
||||
|
||||
@@ -133,7 +132,7 @@ class OnlineNaiveBayes:
|
||||
if cid in self._logZ:
|
||||
del self._logZ[cid]
|
||||
|
||||
def get_stats(self) -> Dict:
|
||||
def get_stats(self) -> dict:
|
||||
"""获取统计信息"""
|
||||
return {
|
||||
"n_classes": len(self.cls_counts),
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"""
|
||||
文本分词器,支持中文Jieba分词
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -30,7 +29,7 @@ class Tokenizer:
|
||||
logger.warning("Jieba未安装,将使用字符级分词")
|
||||
self.use_jieba = False
|
||||
|
||||
def tokenize(self, text: str) -> List[str]:
|
||||
def tokenize(self, text: str) -> list[str]:
|
||||
"""
|
||||
分词并返回token列表
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
情境提取器
|
||||
从聊天历史中提取当前的情境(situation),用于 StyleLearner 预测
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.common.logger import get_logger
|
||||
@@ -41,17 +40,17 @@ def init_prompt():
|
||||
|
||||
class SituationExtractor:
|
||||
"""情境提取器,从聊天历史中提取当前情境"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.llm_model = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils_small,
|
||||
request_type="expression.situation_extractor"
|
||||
)
|
||||
|
||||
|
||||
async def extract_situations(
|
||||
self,
|
||||
chat_history: list | str,
|
||||
target_message: Optional[str] = None,
|
||||
target_message: str | None = None,
|
||||
max_situations: int = 3
|
||||
) -> list[str]:
|
||||
"""
|
||||
@@ -68,18 +67,18 @@ class SituationExtractor:
|
||||
# 转换chat_history为字符串
|
||||
if isinstance(chat_history, list):
|
||||
chat_info = "\n".join([
|
||||
f"{msg.get('sender', 'Unknown')}: {msg.get('content', '')}"
|
||||
f"{msg.get('sender', 'Unknown')}: {msg.get('content', '')}"
|
||||
for msg in chat_history
|
||||
])
|
||||
else:
|
||||
chat_info = chat_history
|
||||
|
||||
|
||||
# 构建目标消息信息
|
||||
if target_message:
|
||||
target_message_info = f",现在你想要回复消息:{target_message}"
|
||||
else:
|
||||
target_message_info = ""
|
||||
|
||||
|
||||
# 构建 prompt
|
||||
try:
|
||||
prompt = (await global_prompt_manager.get_prompt_async("situation_extraction_prompt")).format(
|
||||
@@ -87,31 +86,31 @@ class SituationExtractor:
|
||||
chat_history=chat_info,
|
||||
target_message_info=target_message_info
|
||||
)
|
||||
|
||||
|
||||
# 调用 LLM
|
||||
response, _ = await self.llm_model.generate_response_async(
|
||||
prompt=prompt,
|
||||
temperature=0.3
|
||||
)
|
||||
|
||||
|
||||
if not response or not response.strip():
|
||||
logger.warning("LLM返回空响应,无法提取情境")
|
||||
return []
|
||||
|
||||
|
||||
# 解析响应
|
||||
situations = self._parse_situations(response, max_situations)
|
||||
|
||||
|
||||
if situations:
|
||||
logger.debug(f"提取到 {len(situations)} 个情境: {situations}")
|
||||
else:
|
||||
logger.warning(f"无法从LLM响应中解析出情境。响应:\n{response}")
|
||||
|
||||
|
||||
return situations
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"提取情境失败: {e}")
|
||||
return []
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _parse_situations(response: str, max_situations: int) -> list[str]:
|
||||
"""
|
||||
@@ -125,33 +124,33 @@ class SituationExtractor:
|
||||
情境描述列表
|
||||
"""
|
||||
situations = []
|
||||
|
||||
|
||||
for line in response.splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
|
||||
# 移除可能的序号、引号等
|
||||
line = line.lstrip('0123456789.、-*>))】] \t"\'""''')
|
||||
line = line.rstrip('"\'""''')
|
||||
line = line.strip()
|
||||
|
||||
|
||||
if not line:
|
||||
continue
|
||||
|
||||
|
||||
# 过滤掉明显不是情境描述的内容
|
||||
if len(line) > 30: # 太长
|
||||
continue
|
||||
if len(line) < 2: # 太短
|
||||
continue
|
||||
if any(keyword in line.lower() for keyword in ['例如', '注意', '请', '分析', '总结']):
|
||||
if any(keyword in line.lower() for keyword in ["例如", "注意", "请", "分析", "总结"]):
|
||||
continue
|
||||
|
||||
|
||||
situations.append(line)
|
||||
|
||||
|
||||
if len(situations) >= max_situations:
|
||||
break
|
||||
|
||||
|
||||
return situations
|
||||
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
"""
|
||||
import os
|
||||
import time
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -17,7 +16,7 @@ logger = get_logger("expressor.style_learner")
|
||||
class StyleLearner:
|
||||
"""单个聊天室的表达风格学习器"""
|
||||
|
||||
def __init__(self, chat_id: str, model_config: Optional[Dict] = None):
|
||||
def __init__(self, chat_id: str, model_config: dict | None = None):
|
||||
"""
|
||||
Args:
|
||||
chat_id: 聊天室ID
|
||||
@@ -37,9 +36,9 @@ class StyleLearner:
|
||||
|
||||
# 动态风格管理
|
||||
self.max_styles = 2000 # 每个chat_id最多2000个风格
|
||||
self.style_to_id: Dict[str, str] = {} # style文本 -> style_id
|
||||
self.id_to_style: Dict[str, str] = {} # style_id -> style文本
|
||||
self.id_to_situation: Dict[str, str] = {} # style_id -> situation文本
|
||||
self.style_to_id: dict[str, str] = {} # style文本 -> style_id
|
||||
self.id_to_style: dict[str, str] = {} # style_id -> style文本
|
||||
self.id_to_situation: dict[str, str] = {} # style_id -> situation文本
|
||||
self.next_style_id = 0
|
||||
|
||||
# 学习统计
|
||||
@@ -51,7 +50,7 @@ class StyleLearner:
|
||||
|
||||
logger.info(f"StyleLearner初始化成功: chat_id={chat_id}")
|
||||
|
||||
def add_style(self, style: str, situation: Optional[str] = None) -> bool:
|
||||
def add_style(self, style: str, situation: str | None = None) -> bool:
|
||||
"""
|
||||
动态添加一个新的风格
|
||||
|
||||
@@ -130,7 +129,7 @@ class StyleLearner:
|
||||
logger.error(f"学习映射失败: {e}")
|
||||
return False
|
||||
|
||||
def predict_style(self, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]:
|
||||
def predict_style(self, up_content: str, top_k: int = 5) -> tuple[str | None, dict[str, float]]:
|
||||
"""
|
||||
根据up_content预测最合适的style
|
||||
|
||||
@@ -146,7 +145,7 @@ class StyleLearner:
|
||||
if not self.style_to_id:
|
||||
logger.debug(f"StyleLearner还没有任何训练数据: chat_id={self.chat_id}")
|
||||
return None, {}
|
||||
|
||||
|
||||
best_style_id, scores = self.expressor.predict(up_content, k=top_k)
|
||||
|
||||
if best_style_id is None:
|
||||
@@ -155,7 +154,7 @@ class StyleLearner:
|
||||
|
||||
# 将style_id转换为style文本
|
||||
best_style = self.id_to_style.get(best_style_id)
|
||||
|
||||
|
||||
if best_style is None:
|
||||
logger.warning(
|
||||
f"style_id无法转换为style文本: style_id={best_style_id}, "
|
||||
@@ -171,7 +170,7 @@ class StyleLearner:
|
||||
style_scores[style_text] = score
|
||||
else:
|
||||
logger.warning(f"跳过无法转换的style_id: {sid}")
|
||||
|
||||
|
||||
logger.debug(
|
||||
f"预测成功: up_content={up_content[:30]}..., "
|
||||
f"best_style={best_style}, top3_scores={list(style_scores.items())[:3]}"
|
||||
@@ -183,7 +182,7 @@ class StyleLearner:
|
||||
logger.error(f"预测style失败: {e}", exc_info=True)
|
||||
return None, {}
|
||||
|
||||
def get_style_info(self, style: str) -> Tuple[Optional[str], Optional[str]]:
|
||||
def get_style_info(self, style: str) -> tuple[str | None, str | None]:
|
||||
"""
|
||||
获取style的完整信息
|
||||
|
||||
@@ -200,7 +199,7 @@ class StyleLearner:
|
||||
situation = self.id_to_situation.get(style_id)
|
||||
return style_id, situation
|
||||
|
||||
def get_all_styles(self) -> List[str]:
|
||||
def get_all_styles(self) -> list[str]:
|
||||
"""
|
||||
获取所有风格列表
|
||||
|
||||
@@ -209,7 +208,7 @@ class StyleLearner:
|
||||
"""
|
||||
return list(self.style_to_id.keys())
|
||||
|
||||
def apply_decay(self, factor: Optional[float] = None):
|
||||
def apply_decay(self, factor: float | None = None):
|
||||
"""
|
||||
应用知识衰减
|
||||
|
||||
@@ -304,7 +303,7 @@ class StyleLearner:
|
||||
logger.error(f"加载StyleLearner失败: {e}")
|
||||
return False
|
||||
|
||||
def get_stats(self) -> Dict:
|
||||
def get_stats(self) -> dict:
|
||||
"""获取统计信息"""
|
||||
model_stats = self.expressor.get_stats()
|
||||
return {
|
||||
@@ -324,7 +323,7 @@ class StyleLearnerManager:
|
||||
Args:
|
||||
model_save_path: 模型保存路径
|
||||
"""
|
||||
self.learners: Dict[str, StyleLearner] = {}
|
||||
self.learners: dict[str, StyleLearner] = {}
|
||||
self.model_save_path = model_save_path
|
||||
|
||||
# 确保保存目录存在
|
||||
@@ -332,7 +331,7 @@ class StyleLearnerManager:
|
||||
|
||||
logger.info(f"StyleLearnerManager初始化成功, 模型保存路径: {model_save_path}")
|
||||
|
||||
def get_learner(self, chat_id: str, model_config: Optional[Dict] = None) -> StyleLearner:
|
||||
def get_learner(self, chat_id: str, model_config: dict | None = None) -> StyleLearner:
|
||||
"""
|
||||
获取或创建指定chat_id的学习器
|
||||
|
||||
@@ -369,7 +368,7 @@ class StyleLearnerManager:
|
||||
learner = self.get_learner(chat_id)
|
||||
return learner.learn_mapping(up_content, style)
|
||||
|
||||
def predict_style(self, chat_id: str, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]:
|
||||
def predict_style(self, chat_id: str, up_content: str, top_k: int = 5) -> tuple[str | None, dict[str, float]]:
|
||||
"""
|
||||
预测最合适的风格
|
||||
|
||||
@@ -399,7 +398,7 @@ class StyleLearnerManager:
|
||||
logger.info(f"保存所有StyleLearner {'成功' if success else '部分失败'}")
|
||||
return success
|
||||
|
||||
def apply_decay_all(self, factor: Optional[float] = None):
|
||||
def apply_decay_all(self, factor: float | None = None):
|
||||
"""
|
||||
对所有学习器应用知识衰减
|
||||
|
||||
@@ -409,9 +408,9 @@ class StyleLearnerManager:
|
||||
for learner in self.learners.values():
|
||||
learner.apply_decay(factor)
|
||||
|
||||
logger.info(f"对所有StyleLearner应用知识衰减")
|
||||
logger.info("对所有StyleLearner应用知识衰减")
|
||||
|
||||
def get_all_stats(self) -> Dict[str, Dict]:
|
||||
def get_all_stats(self) -> dict[str, dict]:
|
||||
"""
|
||||
获取所有学习器的统计信息
|
||||
|
||||
|
||||
@@ -503,7 +503,7 @@ class MemorySystem:
|
||||
existing_id = self._memory_fingerprints.get(fingerprint_key)
|
||||
if existing_id and existing_id not in new_memory_ids:
|
||||
candidate_ids.add(existing_id)
|
||||
except Exception as exc: # noqa: PERF203
|
||||
except Exception as exc:
|
||||
logger.debug("构建记忆指纹失败,跳过候选收集: %s", exc)
|
||||
|
||||
# 基于主体索引的候选(使用统一存储)
|
||||
|
||||
@@ -35,12 +35,12 @@ class SingleStreamContextManager:
|
||||
self.last_access_time = time.time()
|
||||
self.access_count = 0
|
||||
self.total_messages = 0
|
||||
|
||||
|
||||
# 标记是否已初始化历史消息
|
||||
self._history_initialized = False
|
||||
|
||||
logger.info(f"[新建] 单流上下文管理器初始化: {stream_id} (id={id(self)})")
|
||||
|
||||
|
||||
# 异步初始化历史消息(不阻塞构造函数)
|
||||
asyncio.create_task(self._initialize_history_from_db())
|
||||
|
||||
@@ -299,55 +299,55 @@ class SingleStreamContextManager:
|
||||
"""更新访问统计"""
|
||||
self.last_access_time = time.time()
|
||||
self.access_count += 1
|
||||
|
||||
|
||||
async def _initialize_history_from_db(self):
|
||||
"""从数据库初始化历史消息到context中"""
|
||||
if self._history_initialized:
|
||||
logger.info(f"历史消息已初始化,跳过: {self.stream_id}")
|
||||
return
|
||||
|
||||
|
||||
# 立即设置标志,防止并发重复加载
|
||||
logger.info(f"设置历史初始化标志: {self.stream_id}")
|
||||
self._history_initialized = True
|
||||
|
||||
|
||||
try:
|
||||
logger.info(f"开始从数据库加载历史消息: {self.stream_id}")
|
||||
|
||||
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat
|
||||
|
||||
|
||||
# 加载历史消息(限制数量为max_context_size的2倍,用于丰富上下文)
|
||||
db_messages = await get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=self.stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=self.max_context_size * 2,
|
||||
)
|
||||
|
||||
|
||||
if db_messages:
|
||||
# 将数据库消息转换为 DatabaseMessages 对象并添加到历史
|
||||
for msg_dict in db_messages:
|
||||
try:
|
||||
# 使用 ** 解包字典作为关键字参数
|
||||
db_msg = DatabaseMessages(**msg_dict)
|
||||
|
||||
|
||||
# 标记为已读
|
||||
db_msg.is_read = True
|
||||
|
||||
|
||||
# 添加到历史消息
|
||||
self.context.history_messages.append(db_msg)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"转换历史消息失败 (message_id={msg_dict.get('message_id', 'unknown')}): {e}")
|
||||
continue
|
||||
|
||||
|
||||
logger.info(f"成功从数据库加载 {len(self.context.history_messages)} 条历史消息到内存: {self.stream_id}")
|
||||
else:
|
||||
logger.debug(f"没有历史消息需要加载: {self.stream_id}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库初始化历史消息失败: {self.stream_id}, {e}", exc_info=True)
|
||||
# 加载失败时重置标志,允许重试
|
||||
self._history_initialized = False
|
||||
|
||||
|
||||
async def ensure_history_initialized(self):
|
||||
"""确保历史消息已初始化(供外部调用)"""
|
||||
if not self._history_initialized:
|
||||
|
||||
@@ -69,10 +69,10 @@ class StreamLoopManager:
|
||||
try:
|
||||
# 获取所有活跃的流
|
||||
from src.plugin_system.apis.chat_api import get_chat_manager
|
||||
|
||||
|
||||
chat_manager = get_chat_manager()
|
||||
all_streams = await chat_manager.get_all_streams()
|
||||
|
||||
|
||||
# 创建任务列表以便并发取消
|
||||
cancel_tasks = []
|
||||
for chat_stream in all_streams:
|
||||
@@ -119,10 +119,10 @@ class StreamLoopManager:
|
||||
# 创建流循环任务
|
||||
try:
|
||||
loop_task = asyncio.create_task(self._stream_loop_worker(stream_id), name=f"stream_loop_{stream_id}")
|
||||
|
||||
|
||||
# 将任务记录到 StreamContext 中
|
||||
context.stream_loop_task = loop_task
|
||||
|
||||
|
||||
# 更新统计信息
|
||||
self.stats["active_streams"] += 1
|
||||
self.stats["total_loops"] += 1
|
||||
@@ -169,7 +169,7 @@ class StreamLoopManager:
|
||||
|
||||
# 清空 StreamContext 中的任务记录
|
||||
context.stream_loop_task = None
|
||||
|
||||
|
||||
logger.info(f"停止流循环: {stream_id}")
|
||||
return True
|
||||
|
||||
@@ -200,13 +200,13 @@ class StreamLoopManager:
|
||||
if has_messages:
|
||||
if force_dispatch:
|
||||
logger.info("流 %s 未读消息 %d 条,触发强制分发", stream_id, unread_count)
|
||||
|
||||
|
||||
# 3. 在处理前更新能量值(用于下次间隔计算)
|
||||
try:
|
||||
await self._update_stream_energy(stream_id, context)
|
||||
except Exception as e:
|
||||
logger.debug(f"更新流能量失败 {stream_id}: {e}")
|
||||
|
||||
|
||||
# 4. 激活chatter处理
|
||||
success = await self._process_stream_messages(stream_id, context)
|
||||
|
||||
@@ -371,7 +371,7 @@ class StreamLoopManager:
|
||||
# 清除 Chatter 处理标志
|
||||
context.is_chatter_processing = False
|
||||
logger.debug(f"清除 Chatter 处理标志: {stream_id}")
|
||||
|
||||
|
||||
# 无论成功或失败,都要设置处理状态为未处理
|
||||
self._set_stream_processing_status(stream_id, False)
|
||||
|
||||
@@ -432,48 +432,48 @@ class StreamLoopManager:
|
||||
"""
|
||||
try:
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
|
||||
# 获取聊天流
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream = await chat_manager.get_stream(stream_id)
|
||||
|
||||
|
||||
if not chat_stream:
|
||||
logger.debug(f"无法找到聊天流 {stream_id},跳过能量更新")
|
||||
return
|
||||
|
||||
|
||||
# 从 context_manager 获取消息(包括未读和历史消息)
|
||||
# 合并未读消息和历史消息
|
||||
all_messages = []
|
||||
|
||||
|
||||
# 添加历史消息
|
||||
history_messages = context.get_history_messages(limit=global_config.chat.max_context_size)
|
||||
all_messages.extend(history_messages)
|
||||
|
||||
|
||||
# 添加未读消息
|
||||
unread_messages = context.get_unread_messages()
|
||||
all_messages.extend(unread_messages)
|
||||
|
||||
|
||||
# 按时间排序并限制数量
|
||||
all_messages.sort(key=lambda m: m.time)
|
||||
messages = all_messages[-global_config.chat.max_context_size:]
|
||||
|
||||
|
||||
# 获取用户ID
|
||||
user_id = None
|
||||
if context.triggering_user_id:
|
||||
user_id = context.triggering_user_id
|
||||
|
||||
|
||||
# 使用能量管理器计算并缓存能量值
|
||||
energy = await energy_manager.calculate_focus_energy(
|
||||
stream_id=stream_id,
|
||||
messages=messages,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
|
||||
# 同步更新到 ChatStream
|
||||
chat_stream._focus_energy = energy
|
||||
|
||||
|
||||
logger.debug(f"已更新流 {stream_id} 的能量值: {energy:.3f}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"更新流能量失败 {stream_id}: {e}", exc_info=False)
|
||||
|
||||
@@ -670,7 +670,7 @@ class StreamLoopManager:
|
||||
|
||||
# 使用 start_stream_loop 重新创建流循环任务
|
||||
success = await self.start_stream_loop(stream_id, force=True)
|
||||
|
||||
|
||||
if success:
|
||||
logger.info(f"已创建强制分发流循环: {stream_id}")
|
||||
else:
|
||||
|
||||
@@ -307,7 +307,7 @@ class MessageManager:
|
||||
|
||||
# 检查上下文
|
||||
context = chat_stream.context_manager.context
|
||||
|
||||
|
||||
# 只有当 Chatter 真正在处理时才检查打断
|
||||
if not context.is_chatter_processing:
|
||||
logger.debug(f"聊天流 {chat_stream.stream_id} Chatter 未在处理,跳过打断检查")
|
||||
@@ -315,7 +315,7 @@ class MessageManager:
|
||||
|
||||
# 检查是否有 stream_loop_task 在运行
|
||||
stream_loop_task = context.stream_loop_task
|
||||
|
||||
|
||||
if stream_loop_task and not stream_loop_task.done():
|
||||
# 检查触发用户ID
|
||||
triggering_user_id = context.triggering_user_id
|
||||
@@ -387,7 +387,7 @@ class MessageManager:
|
||||
|
||||
# 重新创建 stream_loop 任务
|
||||
success = await stream_loop_manager.start_stream_loop(stream_id, force=True)
|
||||
|
||||
|
||||
if success:
|
||||
logger.info(f"✅ 成功重新创建流循环任务: {stream_id}")
|
||||
else:
|
||||
|
||||
@@ -10,7 +10,7 @@ from src.chat.antipromptinjector import initialize_anti_injector
|
||||
from src.chat.message_manager import message_manager
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from src.chat.utils.prompt import create_prompt_async, global_prompt_manager
|
||||
from src.chat.utils.prompt import global_prompt_manager
|
||||
from src.chat.utils.utils import is_mentioned_bot_in_message
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.logger import get_logger
|
||||
@@ -181,7 +181,7 @@ class ChatBot:
|
||||
|
||||
# 创建PlusCommand实例
|
||||
plus_command_instance = plus_command_class(message, plugin_config)
|
||||
|
||||
|
||||
# 为插件实例设置 chat_stream 运行时属性
|
||||
setattr(plus_command_instance, "chat_stream", chat)
|
||||
|
||||
@@ -257,7 +257,7 @@ class ChatBot:
|
||||
# 创建命令实例
|
||||
command_instance: BaseCommand = command_class(message, plugin_config)
|
||||
command_instance.set_matched_groups(matched_groups)
|
||||
|
||||
|
||||
# 为插件实例设置 chat_stream 运行时属性
|
||||
setattr(command_instance, "chat_stream", chat)
|
||||
|
||||
@@ -340,7 +340,7 @@ class ChatBot:
|
||||
)
|
||||
# print(message_data)
|
||||
# logger.debug(str(message_data))
|
||||
|
||||
|
||||
# 先提取基础信息检查是否是自身消息上报
|
||||
from maim_message import BaseMessageInfo
|
||||
temp_message_info = BaseMessageInfo.from_dict(message_data.get("message_info", {}))
|
||||
@@ -350,7 +350,7 @@ class ChatBot:
|
||||
# 直接使用消息字典更新,不再需要创建 MessageRecv
|
||||
await MessageStorage.update_message(message_data)
|
||||
return
|
||||
|
||||
|
||||
group_info = temp_message_info.group_info
|
||||
user_info = temp_message_info.user_info
|
||||
|
||||
@@ -368,14 +368,14 @@ class ChatBot:
|
||||
stream_id=chat.stream_id,
|
||||
platform=chat.platform
|
||||
)
|
||||
|
||||
|
||||
# 填充聊天流时间信息
|
||||
message.chat_info.create_time = chat.create_time
|
||||
message.chat_info.last_active_time = chat.last_active_time
|
||||
|
||||
|
||||
# 注册消息到聊天管理器
|
||||
get_chat_manager().register_message(message)
|
||||
|
||||
|
||||
# 检测是否提及机器人
|
||||
message.is_mentioned, _ = is_mentioned_bot_in_message(message)
|
||||
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import hashlib
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from maim_message import GroupInfo, UserInfo
|
||||
from rich.traceback import install
|
||||
@@ -10,13 +8,12 @@ from sqlalchemy import select
|
||||
from sqlalchemy.dialects.mysql import insert as mysql_insert
|
||||
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
||||
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import ChatStreams # 新增导入
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config # 新增导入
|
||||
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
|
||||
@@ -134,7 +131,7 @@ class ChatStream:
|
||||
"""
|
||||
# 直接使用传入的 DatabaseMessages,设置到上下文中
|
||||
self.context_manager.context.set_current_message(message)
|
||||
|
||||
|
||||
# 设置优先级信息(如果存在)
|
||||
priority_mode = getattr(message, "priority_mode", None)
|
||||
priority_info = getattr(message, "priority_info", None)
|
||||
@@ -156,7 +153,7 @@ class ChatStream:
|
||||
def _safe_get_actions(self, message: DatabaseMessages) -> list | None:
|
||||
"""安全获取消息的actions字段"""
|
||||
import json
|
||||
|
||||
|
||||
try:
|
||||
actions = getattr(message, "actions", None)
|
||||
if actions is None:
|
||||
@@ -321,7 +318,7 @@ class ChatManager:
|
||||
def __init__(self):
|
||||
if not self._initialized:
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
|
||||
|
||||
self.streams: dict[str, ChatStream] = {} # stream_id -> ChatStream
|
||||
self.last_messages: dict[str, DatabaseMessages] = {} # stream_id -> last_message
|
||||
# try:
|
||||
@@ -360,15 +357,15 @@ class ChatManager:
|
||||
def register_message(self, message: DatabaseMessages):
|
||||
"""注册消息到聊天流"""
|
||||
# 从 DatabaseMessages 提取平台和用户/群组信息
|
||||
from maim_message import UserInfo, GroupInfo
|
||||
|
||||
from maim_message import GroupInfo, UserInfo
|
||||
|
||||
user_info = UserInfo(
|
||||
platform=message.user_info.platform,
|
||||
user_id=message.user_info.user_id,
|
||||
user_nickname=message.user_info.user_nickname,
|
||||
user_cardname=message.user_info.user_cardname or ""
|
||||
)
|
||||
|
||||
|
||||
group_info = None
|
||||
if message.group_info:
|
||||
group_info = GroupInfo(
|
||||
@@ -376,7 +373,7 @@ class ChatManager:
|
||||
group_id=message.group_info.group_id,
|
||||
group_name=message.group_info.group_name
|
||||
)
|
||||
|
||||
|
||||
stream_id = self._generate_stream_id(
|
||||
message.chat_info.platform,
|
||||
user_info,
|
||||
@@ -435,7 +432,7 @@ class ChatManager:
|
||||
stream.user_info = user_info
|
||||
if group_info:
|
||||
stream.group_info = group_info
|
||||
|
||||
|
||||
# 检查是否有最后一条消息(现在使用 DatabaseMessages)
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], DatabaseMessages):
|
||||
@@ -532,7 +529,7 @@ class ChatManager:
|
||||
async def get_stream(self, stream_id: str) -> ChatStream | None:
|
||||
"""通过stream_id获取聊天流"""
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
|
||||
|
||||
stream = self.streams.get(stream_id)
|
||||
if not stream:
|
||||
return None
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import base64
|
||||
import time
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
import urllib3
|
||||
from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo
|
||||
@@ -11,7 +10,6 @@ from rich.traceback import install
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.chat.utils.self_voice_cache import consume_self_voice_text
|
||||
from src.chat.utils.utils_image import get_image_manager
|
||||
from src.chat.utils.utils_video import get_video_analyzer, is_video_analysis_available
|
||||
from src.chat.utils.utils_voice import get_voice_text
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.logger import get_logger
|
||||
@@ -269,7 +267,7 @@ class MessageSending(MessageProcessBase):
|
||||
if self.reply:
|
||||
# 从 DatabaseMessages 获取 message_id
|
||||
message_id = self.reply.message_id
|
||||
|
||||
|
||||
if message_id:
|
||||
self.reply_to_message_id = message_id
|
||||
self.message_segment = Seg(
|
||||
|
||||
@@ -39,7 +39,7 @@ async def process_message_from_dict(message_dict: dict[str, Any], stream_id: str
|
||||
# 解析基础信息
|
||||
message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {}))
|
||||
message_segment = Seg.from_dict(message_dict.get("message_segment", {}))
|
||||
|
||||
|
||||
# 初始化处理状态
|
||||
processing_state = {
|
||||
"is_emoji": False,
|
||||
@@ -53,10 +53,10 @@ async def process_message_from_dict(message_dict: dict[str, Any], stream_id: str
|
||||
"priority_mode": "interest",
|
||||
"priority_info": None,
|
||||
}
|
||||
|
||||
|
||||
# 异步处理消息段,生成纯文本
|
||||
processed_plain_text = await _process_message_segments(message_segment, processing_state, message_info)
|
||||
|
||||
|
||||
# 解析 notice 信息
|
||||
is_notify = False
|
||||
is_public_notice = False
|
||||
@@ -65,34 +65,34 @@ async def process_message_from_dict(message_dict: dict[str, Any], stream_id: str
|
||||
is_notify = message_info.additional_config.get("is_notice", False)
|
||||
is_public_notice = message_info.additional_config.get("is_public_notice", False)
|
||||
notice_type = message_info.additional_config.get("notice_type")
|
||||
|
||||
|
||||
# 提取用户信息
|
||||
user_info = message_info.user_info
|
||||
user_id = str(user_info.user_id) if user_info and user_info.user_id else ""
|
||||
user_nickname = (user_info.user_nickname or "") if user_info else ""
|
||||
user_cardname = user_info.user_cardname if user_info else None
|
||||
user_platform = (user_info.platform or "") if user_info else ""
|
||||
|
||||
|
||||
# 提取群组信息
|
||||
group_info = message_info.group_info
|
||||
group_id = group_info.group_id if group_info else None
|
||||
group_name = group_info.group_name if group_info else None
|
||||
group_platform = group_info.platform if group_info else None
|
||||
|
||||
|
||||
# chat_id 应该直接使用 stream_id(与数据库存储格式一致)
|
||||
# stream_id 是通过 platform + user_id/group_id 的 SHA-256 哈希生成的
|
||||
chat_id = stream_id
|
||||
|
||||
|
||||
# 准备 additional_config
|
||||
additional_config_str = _prepare_additional_config(message_info, is_notify, is_public_notice, notice_type)
|
||||
|
||||
|
||||
# 提取 reply_to
|
||||
reply_to = _extract_reply_from_segment(message_segment)
|
||||
|
||||
|
||||
# 构造 DatabaseMessages
|
||||
message_time = message_info.time if hasattr(message_info, "time") and message_info.time is not None else time.time()
|
||||
message_id = message_info.message_id or ""
|
||||
|
||||
|
||||
# 处理 is_mentioned
|
||||
is_mentioned = None
|
||||
mentioned_value = processing_state.get("is_mentioned")
|
||||
@@ -100,7 +100,7 @@ async def process_message_from_dict(message_dict: dict[str, Any], stream_id: str
|
||||
is_mentioned = mentioned_value
|
||||
elif isinstance(mentioned_value, (int, float)):
|
||||
is_mentioned = mentioned_value != 0
|
||||
|
||||
|
||||
db_message = DatabaseMessages(
|
||||
message_id=message_id,
|
||||
time=float(message_time),
|
||||
@@ -133,19 +133,19 @@ async def process_message_from_dict(message_dict: dict[str, Any], stream_id: str
|
||||
chat_info_group_name=group_name,
|
||||
chat_info_group_platform=group_platform,
|
||||
)
|
||||
|
||||
|
||||
# 设置优先级信息
|
||||
if processing_state.get("priority_mode"):
|
||||
setattr(db_message, "priority_mode", processing_state["priority_mode"])
|
||||
if processing_state.get("priority_info"):
|
||||
setattr(db_message, "priority_info", processing_state["priority_info"])
|
||||
|
||||
|
||||
# 设置其他运行时属性
|
||||
setattr(db_message, "is_voice", bool(processing_state.get("is_voice", False)))
|
||||
setattr(db_message, "is_video", bool(processing_state.get("is_video", False)))
|
||||
setattr(db_message, "has_emoji", bool(processing_state.get("has_emoji", False)))
|
||||
setattr(db_message, "has_picid", bool(processing_state.get("has_picid", False)))
|
||||
|
||||
|
||||
return db_message
|
||||
|
||||
|
||||
@@ -190,7 +190,7 @@ async def _process_single_segment(segment: Seg, state: dict, message_info: BaseM
|
||||
state["is_emoji"] = False
|
||||
state["is_video"] = False
|
||||
return segment.data
|
||||
|
||||
|
||||
elif segment.type == "at":
|
||||
state["is_picid"] = False
|
||||
state["is_emoji"] = False
|
||||
@@ -201,7 +201,7 @@ async def _process_single_segment(segment: Seg, state: dict, message_info: BaseM
|
||||
nickname, qq_id = segment.data.split(":", 1)
|
||||
return f"@{nickname}"
|
||||
return f"@{segment.data}" if isinstance(segment.data, str) else "@未知用户"
|
||||
|
||||
|
||||
elif segment.type == "image":
|
||||
# 如果是base64图片数据
|
||||
if isinstance(segment.data, str):
|
||||
@@ -213,7 +213,7 @@ async def _process_single_segment(segment: Seg, state: dict, message_info: BaseM
|
||||
_, processed_text = await image_manager.process_image(segment.data)
|
||||
return processed_text
|
||||
return "[发了一张图片,网卡了加载不出来]"
|
||||
|
||||
|
||||
elif segment.type == "emoji":
|
||||
state["has_emoji"] = True
|
||||
state["is_emoji"] = True
|
||||
@@ -223,13 +223,13 @@ async def _process_single_segment(segment: Seg, state: dict, message_info: BaseM
|
||||
if isinstance(segment.data, str):
|
||||
return await get_image_manager().get_emoji_description(segment.data)
|
||||
return "[发了一个表情包,网卡了加载不出来]"
|
||||
|
||||
|
||||
elif segment.type == "voice":
|
||||
state["is_picid"] = False
|
||||
state["is_emoji"] = False
|
||||
state["is_voice"] = True
|
||||
state["is_video"] = False
|
||||
|
||||
|
||||
# 检查消息是否由机器人自己发送
|
||||
if message_info and message_info.user_info and str(message_info.user_info.user_id) == str(global_config.bot.qq_account):
|
||||
logger.info(f"检测到机器人自身发送的语音消息 (User ID: {message_info.user_info.user_id}),尝试从缓存获取文本。")
|
||||
@@ -240,12 +240,12 @@ async def _process_single_segment(segment: Seg, state: dict, message_info: BaseM
|
||||
return f"[语音:{cached_text}]"
|
||||
else:
|
||||
logger.warning("机器人自身语音消息缓存未命中,将回退到标准语音识别。")
|
||||
|
||||
|
||||
# 标准语音识别流程
|
||||
if isinstance(segment.data, str):
|
||||
return await get_voice_text(segment.data)
|
||||
return "[发了一段语音,网卡了加载不出来]"
|
||||
|
||||
|
||||
elif segment.type == "mention_bot":
|
||||
state["is_picid"] = False
|
||||
state["is_emoji"] = False
|
||||
@@ -253,7 +253,7 @@ async def _process_single_segment(segment: Seg, state: dict, message_info: BaseM
|
||||
state["is_video"] = False
|
||||
state["is_mentioned"] = float(segment.data)
|
||||
return ""
|
||||
|
||||
|
||||
elif segment.type == "priority_info":
|
||||
state["is_picid"] = False
|
||||
state["is_emoji"] = False
|
||||
@@ -263,26 +263,26 @@ async def _process_single_segment(segment: Seg, state: dict, message_info: BaseM
|
||||
state["priority_mode"] = "priority"
|
||||
state["priority_info"] = segment.data
|
||||
return ""
|
||||
|
||||
|
||||
elif segment.type == "file":
|
||||
if isinstance(segment.data, dict):
|
||||
file_name = segment.data.get('name', '未知文件')
|
||||
file_size = segment.data.get('size', '未知大小')
|
||||
file_name = segment.data.get("name", "未知文件")
|
||||
file_size = segment.data.get("size", "未知大小")
|
||||
return f"[文件:{file_name} ({file_size}字节)]"
|
||||
return "[收到一个文件]"
|
||||
|
||||
|
||||
elif segment.type == "video":
|
||||
state["is_picid"] = False
|
||||
state["is_emoji"] = False
|
||||
state["is_voice"] = False
|
||||
state["is_video"] = True
|
||||
logger.info(f"接收到视频消息,数据类型: {type(segment.data)}")
|
||||
|
||||
|
||||
# 检查视频分析功能是否可用
|
||||
if not is_video_analysis_available():
|
||||
logger.warning("⚠️ Rust视频处理模块不可用,跳过视频分析")
|
||||
return "[视频]"
|
||||
|
||||
|
||||
if global_config.video_analysis.enable:
|
||||
logger.info("已启用视频识别,开始识别")
|
||||
if isinstance(segment.data, dict):
|
||||
@@ -290,23 +290,23 @@ async def _process_single_segment(segment: Seg, state: dict, message_info: BaseM
|
||||
# 从Adapter接收的视频数据
|
||||
video_base64 = segment.data.get("base64")
|
||||
filename = segment.data.get("filename", "video.mp4")
|
||||
|
||||
|
||||
logger.info(f"视频文件名: {filename}")
|
||||
logger.info(f"Base64数据长度: {len(video_base64) if video_base64 else 0}")
|
||||
|
||||
|
||||
if video_base64:
|
||||
# 解码base64视频数据
|
||||
video_bytes = base64.b64decode(video_base64)
|
||||
logger.info(f"解码后视频大小: {len(video_bytes)} 字节")
|
||||
|
||||
|
||||
# 使用video analyzer分析视频
|
||||
video_analyzer = get_video_analyzer()
|
||||
result = await video_analyzer.analyze_video_from_bytes(
|
||||
video_bytes, filename, prompt=global_config.video_analysis.batch_analysis_prompt
|
||||
)
|
||||
|
||||
|
||||
logger.info(f"视频分析结果: {result}")
|
||||
|
||||
|
||||
# 返回视频分析结果
|
||||
summary = result.get("summary", "")
|
||||
if summary:
|
||||
@@ -329,7 +329,7 @@ async def _process_single_segment(segment: Seg, state: dict, message_info: BaseM
|
||||
else:
|
||||
logger.warning(f"未知的消息段类型: {segment.type}")
|
||||
return f"[{segment.type} 消息]"
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息段失败: {e!s}, 类型: {segment.type}, 数据: {segment.data}")
|
||||
return f"[处理失败的{segment.type}消息]"
|
||||
@@ -349,9 +349,9 @@ def _prepare_additional_config(message_info: BaseMessageInfo, is_notify: bool, i
|
||||
"""
|
||||
try:
|
||||
additional_config_data = {}
|
||||
|
||||
|
||||
# 首先获取adapter传递的additional_config
|
||||
if hasattr(message_info, 'additional_config') and message_info.additional_config:
|
||||
if hasattr(message_info, "additional_config") and message_info.additional_config:
|
||||
if isinstance(message_info.additional_config, dict):
|
||||
additional_config_data = message_info.additional_config.copy()
|
||||
elif isinstance(message_info.additional_config, str):
|
||||
@@ -360,28 +360,28 @@ def _prepare_additional_config(message_info: BaseMessageInfo, is_notify: bool, i
|
||||
except Exception as e:
|
||||
logger.warning(f"无法解析 additional_config JSON: {e}")
|
||||
additional_config_data = {}
|
||||
|
||||
|
||||
# 添加notice相关标志
|
||||
if is_notify:
|
||||
additional_config_data["is_notice"] = True
|
||||
additional_config_data["notice_type"] = notice_type or "unknown"
|
||||
additional_config_data["is_public_notice"] = bool(is_public_notice)
|
||||
|
||||
|
||||
# 添加format_info到additional_config中
|
||||
if hasattr(message_info, 'format_info') and message_info.format_info:
|
||||
if hasattr(message_info, "format_info") and message_info.format_info:
|
||||
try:
|
||||
format_info_dict = message_info.format_info.to_dict()
|
||||
additional_config_data["format_info"] = format_info_dict
|
||||
logger.debug(f"[message_processor] 嵌入 format_info 到 additional_config: {format_info_dict}")
|
||||
except Exception as e:
|
||||
logger.warning(f"将 format_info 转换为字典失败: {e}")
|
||||
|
||||
|
||||
# 序列化为JSON字符串
|
||||
if additional_config_data:
|
||||
return orjson.dumps(additional_config_data).decode("utf-8")
|
||||
except Exception as e:
|
||||
logger.error(f"准备 additional_config 失败: {e}")
|
||||
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@@ -423,8 +423,8 @@ def get_message_info_from_db_message(db_message: DatabaseMessages) -> BaseMessag
|
||||
Returns:
|
||||
BaseMessageInfo: 重建的消息信息对象
|
||||
"""
|
||||
from maim_message import UserInfo, GroupInfo
|
||||
|
||||
from maim_message import GroupInfo, UserInfo
|
||||
|
||||
# 从 DatabaseMessages 的 user_info 转换为 maim_message.UserInfo
|
||||
user_info = UserInfo(
|
||||
platform=db_message.user_info.platform,
|
||||
@@ -432,7 +432,7 @@ def get_message_info_from_db_message(db_message: DatabaseMessages) -> BaseMessag
|
||||
user_nickname=db_message.user_info.user_nickname,
|
||||
user_cardname=db_message.user_info.user_cardname or ""
|
||||
)
|
||||
|
||||
|
||||
# 从 DatabaseMessages 的 group_info 转换为 maim_message.GroupInfo(如果存在)
|
||||
group_info = None
|
||||
if db_message.group_info:
|
||||
@@ -441,7 +441,7 @@ def get_message_info_from_db_message(db_message: DatabaseMessages) -> BaseMessag
|
||||
group_id=db_message.group_info.group_id,
|
||||
group_name=db_message.group_info.group_name
|
||||
)
|
||||
|
||||
|
||||
# 解析 additional_config(从 JSON 字符串到字典)
|
||||
additional_config = None
|
||||
if db_message.additional_config:
|
||||
@@ -450,7 +450,7 @@ def get_message_info_from_db_message(db_message: DatabaseMessages) -> BaseMessag
|
||||
except Exception:
|
||||
# 如果解析失败,保持为字符串
|
||||
pass
|
||||
|
||||
|
||||
# 创建 BaseMessageInfo
|
||||
message_info = BaseMessageInfo(
|
||||
platform=db_message.chat_info.platform,
|
||||
@@ -460,7 +460,7 @@ def get_message_info_from_db_message(db_message: DatabaseMessages) -> BaseMessag
|
||||
group_info=group_info,
|
||||
additional_config=additional_config # type: ignore
|
||||
)
|
||||
|
||||
|
||||
return message_info
|
||||
|
||||
|
||||
|
||||
@@ -5,12 +5,11 @@ import traceback
|
||||
import orjson
|
||||
from sqlalchemy import desc, select, update
|
||||
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import Images, Messages
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
|
||||
from .chat_stream import ChatStream
|
||||
from .message import MessageSending
|
||||
|
||||
@@ -51,10 +50,10 @@ class MessageStorage:
|
||||
filtered_processed_plain_text = re.sub(pattern, "", safe_processed_plain_text, flags=re.DOTALL)
|
||||
else:
|
||||
filtered_processed_plain_text = ""
|
||||
|
||||
|
||||
display_message = message.display_message or message.processed_plain_text or ""
|
||||
filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL)
|
||||
|
||||
|
||||
# 直接从 DatabaseMessages 获取所有字段
|
||||
msg_id = message.message_id
|
||||
msg_time = message.time
|
||||
@@ -71,13 +70,13 @@ class MessageStorage:
|
||||
key_words = "" # DatabaseMessages 没有 key_words
|
||||
key_words_lite = ""
|
||||
memorized_times = 0 # DatabaseMessages 没有 memorized_times
|
||||
|
||||
|
||||
# 使用 DatabaseMessages 中的嵌套对象信息
|
||||
user_platform = message.user_info.platform if message.user_info else ""
|
||||
user_id = message.user_info.user_id if message.user_info else ""
|
||||
user_nickname = message.user_info.user_nickname if message.user_info else ""
|
||||
user_cardname = message.user_info.user_cardname if message.user_info else None
|
||||
|
||||
|
||||
chat_info_stream_id = message.chat_info.stream_id if message.chat_info else ""
|
||||
chat_info_platform = message.chat_info.platform if message.chat_info else ""
|
||||
chat_info_create_time = message.chat_info.create_time if message.chat_info else 0.0
|
||||
@@ -89,7 +88,7 @@ class MessageStorage:
|
||||
chat_info_group_platform = message.group_info.group_platform if message.group_info else None
|
||||
chat_info_group_id = message.group_info.group_id if message.group_info else None
|
||||
chat_info_group_name = message.group_info.group_name if message.group_info else None
|
||||
|
||||
|
||||
else:
|
||||
# MessageSending 处理逻辑
|
||||
processed_plain_text = message.processed_plain_text
|
||||
@@ -145,7 +144,7 @@ class MessageStorage:
|
||||
msg_time = float(message.message_info.time or time.time())
|
||||
chat_id = chat_stream.stream_id
|
||||
memorized_times = message.memorized_times
|
||||
|
||||
|
||||
# 安全地获取 group_info, 如果为 None 则视为空字典
|
||||
group_info_from_chat = chat_info_dict.get("group_info") or {}
|
||||
# 安全地获取 user_info, 如果为 None 则视为空字典 (以防万一)
|
||||
@@ -153,12 +152,12 @@ class MessageStorage:
|
||||
|
||||
# 将priority_info字典序列化为JSON字符串,以便存储到数据库的Text字段
|
||||
priority_info_json = orjson.dumps(priority_info).decode("utf-8") if priority_info else None
|
||||
|
||||
|
||||
user_platform = user_info_dict.get("platform")
|
||||
user_id = user_info_dict.get("user_id")
|
||||
user_nickname = user_info_dict.get("user_nickname")
|
||||
user_cardname = user_info_dict.get("user_cardname")
|
||||
|
||||
|
||||
chat_info_stream_id = chat_info_dict.get("stream_id")
|
||||
chat_info_platform = chat_info_dict.get("platform")
|
||||
chat_info_create_time = float(chat_info_dict.get("create_time", 0.0))
|
||||
@@ -222,11 +221,11 @@ class MessageStorage:
|
||||
# 从字典中提取信息
|
||||
message_info = message_data.get("message_info", {})
|
||||
mmc_message_id = message_info.get("message_id")
|
||||
|
||||
|
||||
message_segment = message_data.get("message_segment", {})
|
||||
segment_type = message_segment.get("type") if isinstance(message_segment, dict) else None
|
||||
segment_data = message_segment.get("data", {}) if isinstance(message_segment, dict) else {}
|
||||
|
||||
|
||||
qq_message_id = None
|
||||
|
||||
logger.debug(f"尝试更新消息ID: {mmc_message_id}, 消息段类型: {segment_type}")
|
||||
|
||||
@@ -23,35 +23,35 @@ async def send_message(message: MessageSending, show_log=True) -> bool:
|
||||
await get_global_api().send_message(message)
|
||||
if show_log:
|
||||
logger.info(f"已将消息 '{message_preview}' 发往平台'{message.message_info.platform}'")
|
||||
|
||||
|
||||
# 触发 AFTER_SEND 事件
|
||||
try:
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
from src.plugin_system.base.component_types import EventType
|
||||
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
|
||||
if message.chat_stream:
|
||||
logger.info(f"[发送完成] 准备触发 AFTER_SEND 事件,stream_id={message.chat_stream.stream_id}")
|
||||
|
||||
|
||||
# 使用 asyncio.create_task 来异步触发事件,避免阻塞
|
||||
async def trigger_event_async():
|
||||
try:
|
||||
logger.info(f"[事件触发] 开始异步触发 AFTER_SEND 事件")
|
||||
logger.info("[事件触发] 开始异步触发 AFTER_SEND 事件")
|
||||
await event_manager.trigger_event(
|
||||
EventType.AFTER_SEND,
|
||||
permission_group="SYSTEM",
|
||||
stream_id=message.chat_stream.stream_id,
|
||||
message=message,
|
||||
)
|
||||
logger.info(f"[事件触发] AFTER_SEND 事件触发完成")
|
||||
logger.info("[事件触发] AFTER_SEND 事件触发完成")
|
||||
except Exception as e:
|
||||
logger.error(f"[事件触发] 异步触发事件失败: {e}", exc_info=True)
|
||||
|
||||
|
||||
# 创建异步任务,不等待完成
|
||||
asyncio.create_task(trigger_event_async())
|
||||
logger.info(f"[发送完成] AFTER_SEND 事件已提交到异步任务")
|
||||
logger.info("[发送完成] AFTER_SEND 事件已提交到异步任务")
|
||||
except Exception as event_error:
|
||||
logger.error(f"触发 AFTER_SEND 事件时出错: {event_error}", exc_info=True)
|
||||
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -270,7 +270,7 @@ class ChatterActionManager:
|
||||
msg_text = target_message.get("processed_plain_text", "未知消息")
|
||||
else:
|
||||
msg_text = "未知消息"
|
||||
|
||||
|
||||
logger.info(f"对 {msg_text} 的回复生成失败")
|
||||
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
|
||||
except asyncio.CancelledError:
|
||||
|
||||
@@ -32,8 +32,6 @@ from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.individuality.individuality import get_individuality
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
|
||||
from src.mood.mood_manager import mood_manager
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
from src.plugin_system.apis import llm_api
|
||||
@@ -943,10 +941,10 @@ class DefaultReplyer:
|
||||
chat_stream = await chat_manager.get_stream(chat_id)
|
||||
if chat_stream:
|
||||
stream_context = chat_stream.context_manager
|
||||
|
||||
|
||||
# 确保历史消息已从数据库加载
|
||||
await stream_context.ensure_history_initialized()
|
||||
|
||||
|
||||
# 直接使用内存中的已读和未读消息,无需再查询数据库
|
||||
read_messages = stream_context.context.history_messages # 已读消息(已从数据库加载)
|
||||
unread_messages = stream_context.get_unread_messages() # 未读消息
|
||||
@@ -956,11 +954,11 @@ class DefaultReplyer:
|
||||
if read_messages:
|
||||
# 将 DatabaseMessages 对象转换为字典格式,以便使用 build_readable_messages
|
||||
read_messages_dicts = [msg.flatten() for msg in read_messages]
|
||||
|
||||
|
||||
# 按时间排序并限制数量
|
||||
sorted_messages = sorted(read_messages_dicts, key=lambda x: x.get("time", 0))
|
||||
final_history = sorted_messages[-50:] # 限制最多50条
|
||||
|
||||
|
||||
read_content = await build_readable_messages(
|
||||
final_history,
|
||||
replace_bot_name=True,
|
||||
@@ -1194,7 +1192,7 @@ class DefaultReplyer:
|
||||
if reply_message is None:
|
||||
logger.warning("reply_message 为 None,无法构建prompt")
|
||||
return ""
|
||||
|
||||
|
||||
# 统一处理 DatabaseMessages 对象和字典
|
||||
if isinstance(reply_message, DatabaseMessages):
|
||||
platform = reply_message.chat_info.platform
|
||||
@@ -1208,7 +1206,7 @@ class DefaultReplyer:
|
||||
user_nickname = reply_message.get("user_nickname")
|
||||
user_cardname = reply_message.get("user_cardname")
|
||||
processed_plain_text = reply_message.get("processed_plain_text")
|
||||
|
||||
|
||||
person_id = person_info_manager.get_person_id(
|
||||
platform, # type: ignore
|
||||
user_id, # type: ignore
|
||||
@@ -1262,24 +1260,24 @@ class DefaultReplyer:
|
||||
|
||||
# 从内存获取历史消息,避免重复查询数据库
|
||||
from src.plugin_system.apis.chat_api import get_chat_manager
|
||||
|
||||
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream_obj = await chat_manager.get_stream(chat_id)
|
||||
|
||||
|
||||
if chat_stream_obj:
|
||||
# 确保历史消息已初始化
|
||||
await chat_stream_obj.context_manager.ensure_history_initialized()
|
||||
|
||||
|
||||
# 获取所有消息(历史+未读)
|
||||
all_messages = (
|
||||
chat_stream_obj.context_manager.context.history_messages +
|
||||
chat_stream_obj.context_manager.get_unread_messages()
|
||||
)
|
||||
|
||||
|
||||
# 转换为字典格式
|
||||
message_list_before_now_long = [msg.flatten() for msg in all_messages[-(global_config.chat.max_context_size * 2):]]
|
||||
message_list_before_short = [msg.flatten() for msg in all_messages[-int(global_config.chat.max_context_size * 0.33):]]
|
||||
|
||||
|
||||
logger.debug(f"使用内存中的消息: long={len(message_list_before_now_long)}, short={len(message_list_before_short)}")
|
||||
else:
|
||||
# 回退到数据库查询
|
||||
@@ -1294,7 +1292,7 @@ class DefaultReplyer:
|
||||
timestamp=time.time(),
|
||||
limit=int(global_config.chat.max_context_size * 0.33),
|
||||
)
|
||||
|
||||
|
||||
chat_talking_prompt_short = await build_readable_messages(
|
||||
message_list_before_short,
|
||||
replace_bot_name=True,
|
||||
@@ -1634,24 +1632,24 @@ class DefaultReplyer:
|
||||
|
||||
# 从内存获取历史消息,避免重复查询数据库
|
||||
from src.plugin_system.apis.chat_api import get_chat_manager
|
||||
|
||||
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream_obj = await chat_manager.get_stream(chat_id)
|
||||
|
||||
|
||||
if chat_stream_obj:
|
||||
# 确保历史消息已初始化
|
||||
await chat_stream_obj.context_manager.ensure_history_initialized()
|
||||
|
||||
|
||||
# 获取所有消息(历史+未读)
|
||||
all_messages = (
|
||||
chat_stream_obj.context_manager.context.history_messages +
|
||||
chat_stream_obj.context_manager.get_unread_messages()
|
||||
)
|
||||
|
||||
|
||||
# 转换为字典格式,限制数量
|
||||
limit = min(int(global_config.chat.max_context_size * 0.33), 15)
|
||||
message_list_before_now_half = [msg.flatten() for msg in all_messages[-limit:]]
|
||||
|
||||
|
||||
logger.debug(f"Rewrite使用内存中的 {len(message_list_before_now_half)} 条消息")
|
||||
else:
|
||||
# 回退到数据库查询
|
||||
@@ -1661,7 +1659,7 @@ class DefaultReplyer:
|
||||
timestamp=time.time(),
|
||||
limit=min(int(global_config.chat.max_context_size * 0.33), 15),
|
||||
)
|
||||
|
||||
|
||||
chat_talking_prompt_half = await build_readable_messages(
|
||||
message_list_before_now_half,
|
||||
replace_bot_name=True,
|
||||
@@ -1818,7 +1816,7 @@ class DefaultReplyer:
|
||||
# 循环移除,以处理模型可能生成的嵌套回复头/尾
|
||||
# 使用更健壮的正则表达式,通过非贪婪匹配和向后查找来定位真正的消息内容
|
||||
pattern = re.compile(r"^\s*\[回复<.+?>\s*(?:的消息)?:(?P<content>.*)\](?:,?说:)?\s*$", re.DOTALL)
|
||||
|
||||
|
||||
temp_content = cleaned_content
|
||||
while True:
|
||||
match = pattern.match(temp_content)
|
||||
@@ -1830,7 +1828,7 @@ class DefaultReplyer:
|
||||
temp_content = new_content
|
||||
else:
|
||||
break # 没有匹配到,退出循环
|
||||
|
||||
|
||||
# 在循环处理后,再使用 rsplit 来处理日志中观察到的特殊情况
|
||||
# 这可以作为处理复杂嵌套的最后一道防线
|
||||
final_split = temp_content.rsplit("],说:", 1)
|
||||
@@ -1838,7 +1836,7 @@ class DefaultReplyer:
|
||||
final_content = final_split[1].strip()
|
||||
else:
|
||||
final_content = temp_content
|
||||
|
||||
|
||||
if final_content != content:
|
||||
logger.debug(f"清理了模型生成的多余内容,原始内容: '{content}', 清理后: '{final_content}'")
|
||||
content = final_content
|
||||
@@ -2077,24 +2075,24 @@ class DefaultReplyer:
|
||||
|
||||
# 从内存获取聊天历史用于存储,避免重复查询数据库
|
||||
from src.plugin_system.apis.chat_api import get_chat_manager
|
||||
|
||||
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream_obj = await chat_manager.get_stream(stream.stream_id)
|
||||
|
||||
|
||||
if chat_stream_obj:
|
||||
# 确保历史消息已初始化
|
||||
await chat_stream_obj.context_manager.ensure_history_initialized()
|
||||
|
||||
|
||||
# 获取所有消息(历史+未读)
|
||||
all_messages = (
|
||||
chat_stream_obj.context_manager.context.history_messages +
|
||||
chat_stream_obj.context_manager.get_unread_messages()
|
||||
)
|
||||
|
||||
|
||||
# 转换为字典格式,限制数量
|
||||
limit = int(global_config.chat.max_context_size * 0.33)
|
||||
message_list_before_short = [msg.flatten() for msg in all_messages[-limit:]]
|
||||
|
||||
|
||||
logger.debug(f"记忆存储使用内存中的 {len(message_list_before_short)} 条消息")
|
||||
else:
|
||||
# 回退到数据库查询
|
||||
|
||||
@@ -1112,14 +1112,14 @@ class Prompt:
|
||||
# 使用关系提取器构建用户关系信息和聊天流印象
|
||||
user_relation_info = await relationship_fetcher.build_relation_info(person_id, points_num=5)
|
||||
stream_impression = await relationship_fetcher.build_chat_stream_impression(chat_id)
|
||||
|
||||
|
||||
# 组合两部分信息
|
||||
info_parts = []
|
||||
if user_relation_info:
|
||||
info_parts.append(user_relation_info)
|
||||
if stream_impression:
|
||||
info_parts.append(stream_impression)
|
||||
|
||||
|
||||
return "\n\n".join(info_parts) if info_parts else ""
|
||||
|
||||
def _get_default_result_for_task(self, task_name: str) -> dict[str, Any]:
|
||||
|
||||
@@ -11,6 +11,7 @@ import rjieba
|
||||
from maim_message import UserInfo
|
||||
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
# MessageRecv 已被移除,现在使用 DatabaseMessages
|
||||
from src.common.logger import get_logger
|
||||
from src.common.message_repository import count_messages, find_messages
|
||||
@@ -49,13 +50,13 @@ def is_mentioned_bot_in_message(message) -> tuple[bool, float]:
|
||||
|
||||
Returns:
|
||||
tuple[bool, float]: (是否提及, 提及概率)
|
||||
"""
|
||||
"""
|
||||
keywords = [global_config.bot.nickname]
|
||||
nicknames = global_config.bot.alias_names
|
||||
reply_probability = 0.0
|
||||
is_at = False
|
||||
is_mentioned = False
|
||||
|
||||
|
||||
# 检查 is_mentioned 属性
|
||||
mentioned_attr = getattr(message, "is_mentioned", None)
|
||||
if mentioned_attr is not None:
|
||||
@@ -63,7 +64,7 @@ def is_mentioned_bot_in_message(message) -> tuple[bool, float]:
|
||||
return bool(mentioned_attr), float(mentioned_attr)
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
|
||||
# 检查 additional_config
|
||||
additional_config = None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user