refactor(core): 优化类型提示与代码风格
本次提交对项目代码进行了广泛的重构,主要集中在以下几个方面:
1. **类型提示现代化**:
- 将 `typing` 模块中的 `Optional[T]`、`List[T]`、`Dict[K, V]` 等旧式类型提示更新为现代的 `T | None`、`list[T]`、`dict[K, V]` 语法。
- 这提高了代码的可读性,并与较新 Python 版本的风格保持一致。
2. **代码风格统一**:
- 移除了多余的空行和不必要的空格,使代码更加紧凑和规范。
- 统一了部分日志输出的格式,增强了日志的可读性。
3. **导入语句优化**:
- 调整了部分模块的 `import` 语句顺序,使其符合 PEP 8 规范。
这些更改不涉及任何功能性变动,旨在提升代码库的整体质量、可维护性和开发体验。
This commit is contained in:
committed by
Windpicker-owo
parent
6026682a03
commit
2ee6aa3951
@@ -9,7 +9,8 @@ from pathlib import Path
|
|||||||
project_root = Path(__file__).parent.parent
|
project_root = Path(__file__).parent.parent
|
||||||
sys.path.insert(0, str(project_root))
|
sys.path.insert(0, str(project_root))
|
||||||
|
|
||||||
from sqlalchemy import select, func
|
from sqlalchemy import func, select
|
||||||
|
|
||||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||||
from src.common.database.sqlalchemy_models import Expression
|
from src.common.database.sqlalchemy_models import Expression
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ project_root = Path(__file__).parent.parent
|
|||||||
sys.path.insert(0, str(project_root))
|
sys.path.insert(0, str(project_root))
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||||
from src.common.database.sqlalchemy_models import Expression
|
from src.common.database.sqlalchemy_models import Expression
|
||||||
|
|
||||||
@@ -47,9 +48,9 @@ async def analyze_style_fields():
|
|||||||
print(f" 长度: {ex['length']} 字符")
|
print(f" 长度: {ex['length']} 字符")
|
||||||
|
|
||||||
# 判断是具体表达还是风格描述
|
# 判断是具体表达还是风格描述
|
||||||
if ex['length'] <= 20 and any(word in ex['style'] for word in ['简洁', '短句', '陈述', '疑问', '感叹', '省略', '完整']):
|
if ex["length"] <= 20 and any(word in ex["style"] for word in ["简洁", "短句", "陈述", "疑问", "感叹", "省略", "完整"]):
|
||||||
style_type = "✓ 风格描述"
|
style_type = "✓ 风格描述"
|
||||||
elif ex['length'] <= 10:
|
elif ex["length"] <= 10:
|
||||||
style_type = "? 可能是具体表达(较短)"
|
style_type = "? 可能是具体表达(较短)"
|
||||||
else:
|
else:
|
||||||
style_type = "✗ 具体表达内容"
|
style_type = "✗ 具体表达内容"
|
||||||
|
|||||||
@@ -25,19 +25,19 @@ def check_style_learner_status(chat_id: str):
|
|||||||
learner = style_learner_manager.get_learner(chat_id)
|
learner = style_learner_manager.get_learner(chat_id)
|
||||||
|
|
||||||
# 1. 基本信息
|
# 1. 基本信息
|
||||||
print(f"\n📊 基本信息:")
|
print("\n📊 基本信息:")
|
||||||
print(f" Chat ID: {learner.chat_id}")
|
print(f" Chat ID: {learner.chat_id}")
|
||||||
print(f" 风格数量: {len(learner.style_to_id)}")
|
print(f" 风格数量: {len(learner.style_to_id)}")
|
||||||
print(f" 下一个ID: {learner.next_style_id}")
|
print(f" 下一个ID: {learner.next_style_id}")
|
||||||
print(f" 最大风格数: {learner.max_styles}")
|
print(f" 最大风格数: {learner.max_styles}")
|
||||||
|
|
||||||
# 2. 学习统计
|
# 2. 学习统计
|
||||||
print(f"\n📈 学习统计:")
|
print("\n📈 学习统计:")
|
||||||
print(f" 总样本数: {learner.learning_stats['total_samples']}")
|
print(f" 总样本数: {learner.learning_stats['total_samples']}")
|
||||||
print(f" 最后更新: {learner.learning_stats.get('last_update', 'N/A')}")
|
print(f" 最后更新: {learner.learning_stats.get('last_update', 'N/A')}")
|
||||||
|
|
||||||
# 3. 风格列表(前20个)
|
# 3. 风格列表(前20个)
|
||||||
print(f"\n📋 已学习的风格 (前20个):")
|
print("\n📋 已学习的风格 (前20个):")
|
||||||
all_styles = learner.get_all_styles()
|
all_styles = learner.get_all_styles()
|
||||||
if not all_styles:
|
if not all_styles:
|
||||||
print(" ⚠️ 没有任何风格!模型尚未训练")
|
print(" ⚠️ 没有任何风格!模型尚未训练")
|
||||||
@@ -49,7 +49,7 @@ def check_style_learner_status(chat_id: str):
|
|||||||
print(f" (ID: {style_id}, Situation: {situation})")
|
print(f" (ID: {style_id}, Situation: {situation})")
|
||||||
|
|
||||||
# 4. 测试预测
|
# 4. 测试预测
|
||||||
print(f"\n🔮 测试预测功能:")
|
print("\n🔮 测试预测功能:")
|
||||||
if not all_styles:
|
if not all_styles:
|
||||||
print(" ⚠️ 无法测试,模型没有训练数据")
|
print(" ⚠️ 无法测试,模型没有训练数据")
|
||||||
else:
|
else:
|
||||||
@@ -65,11 +65,11 @@ def check_style_learner_status(chat_id: str):
|
|||||||
|
|
||||||
if best_style:
|
if best_style:
|
||||||
print(f" ✓ 最佳匹配: {best_style}")
|
print(f" ✓ 最佳匹配: {best_style}")
|
||||||
print(f" Top 3:")
|
print(" Top 3:")
|
||||||
for style, score in list(scores.items())[:3]:
|
for style, score in list(scores.items())[:3]:
|
||||||
print(f" - {style}: {score:.4f}")
|
print(f" - {style}: {score:.4f}")
|
||||||
else:
|
else:
|
||||||
print(f" ✗ 预测失败")
|
print(" ✗ 预测失败")
|
||||||
|
|
||||||
print("\n" + "=" * 60)
|
print("\n" + "=" * 60)
|
||||||
print("诊断完成")
|
print("诊断完成")
|
||||||
|
|||||||
@@ -201,9 +201,10 @@ class RelationshipEnergyCalculator(EnergyCalculator):
|
|||||||
|
|
||||||
# 从数据库获取聊天流兴趣分数
|
# 从数据库获取聊天流兴趣分数
|
||||||
try:
|
try:
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||||
from src.common.database.sqlalchemy_models import ChatStreams
|
from src.common.database.sqlalchemy_models import ChatStreams
|
||||||
from sqlalchemy import select
|
|
||||||
|
|
||||||
async with get_db_session() as session:
|
async with get_db_session() as session:
|
||||||
stmt = select(ChatStreams).where(ChatStreams.stream_id == stream_id)
|
stmt = select(ChatStreams).where(ChatStreams.stream_id == stream_id)
|
||||||
|
|||||||
@@ -5,14 +5,14 @@
|
|||||||
import difflib
|
import difflib
|
||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
logger = get_logger("express_utils")
|
logger = get_logger("express_utils")
|
||||||
|
|
||||||
|
|
||||||
def filter_message_content(content: Optional[str]) -> str:
|
def filter_message_content(content: str | None) -> str:
|
||||||
"""
|
"""
|
||||||
过滤消息内容,移除回复、@、图片等格式
|
过滤消息内容,移除回复、@、图片等格式
|
||||||
|
|
||||||
@@ -51,7 +51,7 @@ def calculate_similarity(text1: str, text2: str) -> float:
|
|||||||
return difflib.SequenceMatcher(None, text1, text2).ratio()
|
return difflib.SequenceMatcher(None, text1, text2).ratio()
|
||||||
|
|
||||||
|
|
||||||
def weighted_sample(population: List[Dict], k: int, weight_key: Optional[str] = None) -> List[Dict]:
|
def weighted_sample(population: list[dict], k: int, weight_key: str | None = None) -> list[dict]:
|
||||||
"""
|
"""
|
||||||
加权随机抽样函数
|
加权随机抽样函数
|
||||||
|
|
||||||
@@ -108,7 +108,7 @@ def normalize_text(text: str) -> str:
|
|||||||
return text.strip()
|
return text.strip()
|
||||||
|
|
||||||
|
|
||||||
def extract_keywords(text: str, max_keywords: int = 10) -> List[str]:
|
def extract_keywords(text: str, max_keywords: int = 10) -> list[str]:
|
||||||
"""
|
"""
|
||||||
简单的关键词提取(基于词频)
|
简单的关键词提取(基于词频)
|
||||||
|
|
||||||
@@ -135,7 +135,7 @@ def extract_keywords(text: str, max_keywords: int = 10) -> List[str]:
|
|||||||
return words[:max_keywords]
|
return words[:max_keywords]
|
||||||
|
|
||||||
|
|
||||||
def format_expression_pair(situation: str, style: str, index: Optional[int] = None) -> str:
|
def format_expression_pair(situation: str, style: str, index: int | None = None) -> str:
|
||||||
"""
|
"""
|
||||||
格式化表达方式对
|
格式化表达方式对
|
||||||
|
|
||||||
@@ -153,7 +153,7 @@ def format_expression_pair(situation: str, style: str, index: Optional[int] = No
|
|||||||
return f'当"{situation}"时,使用"{style}"'
|
return f'当"{situation}"时,使用"{style}"'
|
||||||
|
|
||||||
|
|
||||||
def parse_expression_pair(text: str) -> Optional[tuple[str, str]]:
|
def parse_expression_pair(text: str) -> tuple[str, str] | None:
|
||||||
"""
|
"""
|
||||||
解析表达方式对文本
|
解析表达方式对文本
|
||||||
|
|
||||||
@@ -170,7 +170,7 @@ def parse_expression_pair(text: str) -> Optional[tuple[str, str]]:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def batch_filter_duplicates(expressions: List[Dict[str, Any]], key_fields: List[str]) -> List[Dict[str, Any]]:
|
def batch_filter_duplicates(expressions: list[dict[str, Any]], key_fields: list[str]) -> list[dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
批量去重表达方式
|
批量去重表达方式
|
||||||
|
|
||||||
@@ -219,8 +219,8 @@ def calculate_time_weight(last_active_time: float, current_time: float, half_lif
|
|||||||
|
|
||||||
|
|
||||||
def merge_expressions_from_multiple_chats(
|
def merge_expressions_from_multiple_chats(
|
||||||
expressions_dict: Dict[str, List[Dict[str, Any]]], max_total: int = 100
|
expressions_dict: dict[str, list[dict[str, Any]]], max_total: int = 100
|
||||||
) -> List[Dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
合并多个聊天室的表达方式
|
合并多个聊天室的表达方式
|
||||||
|
|
||||||
|
|||||||
@@ -541,13 +541,13 @@ class ExpressionLearner:
|
|||||||
idx_when = line_normalized.find('当"')
|
idx_when = line_normalized.find('当"')
|
||||||
if idx_when == -1:
|
if idx_when == -1:
|
||||||
# 尝试不带引号的格式: 当xxx时
|
# 尝试不带引号的格式: 当xxx时
|
||||||
idx_when = line_normalized.find('当')
|
idx_when = line_normalized.find("当")
|
||||||
if idx_when == -1:
|
if idx_when == -1:
|
||||||
failed_lines.append((line_num, line, "找不到'当'关键字"))
|
failed_lines.append((line_num, line, "找不到'当'关键字"))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 提取"当"和"时"之间的内容
|
# 提取"当"和"时"之间的内容
|
||||||
idx_shi = line_normalized.find('时', idx_when)
|
idx_shi = line_normalized.find("时", idx_when)
|
||||||
if idx_shi == -1:
|
if idx_shi == -1:
|
||||||
failed_lines.append((line_num, line, "找不到'时'关键字"))
|
failed_lines.append((line_num, line, "找不到'时'关键字"))
|
||||||
continue
|
continue
|
||||||
@@ -568,9 +568,9 @@ class ExpressionLearner:
|
|||||||
idx_use = line_normalized.find('可以"', search_start)
|
idx_use = line_normalized.find('可以"', search_start)
|
||||||
if idx_use == -1:
|
if idx_use == -1:
|
||||||
# 尝试不带引号的格式
|
# 尝试不带引号的格式
|
||||||
idx_use = line_normalized.find('使用', search_start)
|
idx_use = line_normalized.find("使用", search_start)
|
||||||
if idx_use == -1:
|
if idx_use == -1:
|
||||||
idx_use = line_normalized.find('可以', search_start)
|
idx_use = line_normalized.find("可以", search_start)
|
||||||
if idx_use == -1:
|
if idx_use == -1:
|
||||||
failed_lines.append((line_num, line, "找不到'使用'或'可以'关键字"))
|
failed_lines.append((line_num, line, "找不到'使用'或'可以'关键字"))
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -298,7 +298,7 @@ class ExpressionSelector:
|
|||||||
min_num: int = 5,
|
min_num: int = 5,
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
"""经典模式:随机抽样 + LLM评估"""
|
"""经典模式:随机抽样 + LLM评估"""
|
||||||
logger.debug(f"[Classic模式] 使用LLM评估表达方式")
|
logger.debug("[Classic模式] 使用LLM评估表达方式")
|
||||||
return await self.select_suitable_expressions_llm(
|
return await self.select_suitable_expressions_llm(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
chat_info=chat_info,
|
chat_info=chat_info,
|
||||||
@@ -316,7 +316,7 @@ class ExpressionSelector:
|
|||||||
min_num: int = 5,
|
min_num: int = 5,
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
"""模型预测模式:先提取情境,再使用StyleLearner预测表达风格"""
|
"""模型预测模式:先提取情境,再使用StyleLearner预测表达风格"""
|
||||||
logger.debug(f"[Exp_model模式] 使用情境提取 + StyleLearner预测表达方式")
|
logger.debug("[Exp_model模式] 使用情境提取 + StyleLearner预测表达方式")
|
||||||
|
|
||||||
# 检查是否允许在此聊天流中使用表达
|
# 检查是否允许在此聊天流中使用表达
|
||||||
if not self.can_use_expression_for_chat(chat_id):
|
if not self.can_use_expression_for_chat(chat_id):
|
||||||
@@ -331,7 +331,7 @@ class ExpressionSelector:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if not situations:
|
if not situations:
|
||||||
logger.warning(f"无法提取聊天情境,回退到经典模式")
|
logger.warning("无法提取聊天情境,回退到经典模式")
|
||||||
return await self._select_expressions_classic(
|
return await self._select_expressions_classic(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
chat_info=chat_info,
|
chat_info=chat_info,
|
||||||
@@ -357,10 +357,10 @@ class ExpressionSelector:
|
|||||||
if style not in all_predicted_styles or score > all_predicted_styles[style]:
|
if style not in all_predicted_styles or score > all_predicted_styles[style]:
|
||||||
all_predicted_styles[style] = score
|
all_predicted_styles[style] = score
|
||||||
else:
|
else:
|
||||||
logger.debug(f" 该情境未返回预测结果")
|
logger.debug(" 该情境未返回预测结果")
|
||||||
|
|
||||||
if not all_predicted_styles:
|
if not all_predicted_styles:
|
||||||
logger.warning(f"[Exp_model模式] StyleLearner未返回预测结果(可能模型未训练),回退到经典模式")
|
logger.warning("[Exp_model模式] StyleLearner未返回预测结果(可能模型未训练),回退到经典模式")
|
||||||
return await self._select_expressions_classic(
|
return await self._select_expressions_classic(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
chat_info=chat_info,
|
chat_info=chat_info,
|
||||||
@@ -375,7 +375,7 @@ class ExpressionSelector:
|
|||||||
logger.info(f"[Exp_model模式] 步骤2完成 - 预测到 {len(predicted_styles)} 个风格, Top3: {predicted_styles[:3]}")
|
logger.info(f"[Exp_model模式] 步骤2完成 - 预测到 {len(predicted_styles)} 个风格, Top3: {predicted_styles[:3]}")
|
||||||
|
|
||||||
# 步骤3: 根据预测的风格从数据库获取表达方式
|
# 步骤3: 根据预测的风格从数据库获取表达方式
|
||||||
logger.debug(f"[Exp_model模式] 步骤3 - 从数据库查询表达方式")
|
logger.debug("[Exp_model模式] 步骤3 - 从数据库查询表达方式")
|
||||||
expressions = await self.get_model_predicted_expressions(
|
expressions = await self.get_model_predicted_expressions(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
predicted_styles=predicted_styles,
|
predicted_styles=predicted_styles,
|
||||||
@@ -383,7 +383,7 @@ class ExpressionSelector:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if not expressions:
|
if not expressions:
|
||||||
logger.warning(f"[Exp_model模式] 未找到匹配预测风格的表达方式,回退到经典模式")
|
logger.warning("[Exp_model模式] 未找到匹配预测风格的表达方式,回退到经典模式")
|
||||||
return await self._select_expressions_classic(
|
return await self._select_expressions_classic(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
chat_info=chat_info,
|
chat_info=chat_info,
|
||||||
@@ -445,7 +445,7 @@ class ExpressionSelector:
|
|||||||
|
|
||||||
# 🔥 智能回退:如果相关 chat_id 没有数据,尝试查询所有 chat_id
|
# 🔥 智能回退:如果相关 chat_id 没有数据,尝试查询所有 chat_id
|
||||||
if not all_expressions:
|
if not all_expressions:
|
||||||
logger.info(f"相关chat_id没有数据,尝试从所有chat_id查询")
|
logger.info("相关chat_id没有数据,尝试从所有chat_id查询")
|
||||||
all_expressions_result = await session.execute(
|
all_expressions_result = await session.execute(
|
||||||
select(Expression)
|
select(Expression)
|
||||||
.where(Expression.type == "style")
|
.where(Expression.type == "style")
|
||||||
@@ -454,7 +454,7 @@ class ExpressionSelector:
|
|||||||
logger.debug(f"数据库中所有表达方式数量: {len(all_expressions)}")
|
logger.debug(f"数据库中所有表达方式数量: {len(all_expressions)}")
|
||||||
|
|
||||||
if not all_expressions:
|
if not all_expressions:
|
||||||
logger.warning(f"数据库中完全没有任何表达方式,需要先学习")
|
logger.warning("数据库中完全没有任何表达方式,需要先学习")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# 🔥 使用模糊匹配而不是精确匹配
|
# 🔥 使用模糊匹配而不是精确匹配
|
||||||
|
|||||||
@@ -5,7 +5,6 @@
|
|||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
from collections import Counter, defaultdict
|
from collections import Counter, defaultdict
|
||||||
from typing import Dict, Optional, Tuple
|
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
@@ -36,14 +35,14 @@ class ExpressorModel:
|
|||||||
self.nb = OnlineNaiveBayes(alpha=alpha, beta=beta, gamma=gamma, vocab_size=vocab_size)
|
self.nb = OnlineNaiveBayes(alpha=alpha, beta=beta, gamma=gamma, vocab_size=vocab_size)
|
||||||
|
|
||||||
# 候选表达管理
|
# 候选表达管理
|
||||||
self._candidates: Dict[str, str] = {} # cid -> text (style)
|
self._candidates: dict[str, str] = {} # cid -> text (style)
|
||||||
self._situations: Dict[str, str] = {} # cid -> situation (不参与计算)
|
self._situations: dict[str, str] = {} # cid -> situation (不参与计算)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"ExpressorModel初始化完成 (alpha={alpha}, beta={beta}, gamma={gamma}, vocab_size={vocab_size}, use_jieba={use_jieba})"
|
f"ExpressorModel初始化完成 (alpha={alpha}, beta={beta}, gamma={gamma}, vocab_size={vocab_size}, use_jieba={use_jieba})"
|
||||||
)
|
)
|
||||||
|
|
||||||
def add_candidate(self, cid: str, text: str, situation: Optional[str] = None):
|
def add_candidate(self, cid: str, text: str, situation: str | None = None):
|
||||||
"""
|
"""
|
||||||
添加候选文本和对应的situation
|
添加候选文本和对应的situation
|
||||||
|
|
||||||
@@ -62,7 +61,7 @@ class ExpressorModel:
|
|||||||
if cid not in self.nb.token_counts:
|
if cid not in self.nb.token_counts:
|
||||||
self.nb.token_counts[cid] = defaultdict(float)
|
self.nb.token_counts[cid] = defaultdict(float)
|
||||||
|
|
||||||
def predict(self, text: str, k: int = None) -> Tuple[Optional[str], Dict[str, float]]:
|
def predict(self, text: str, k: int = None) -> tuple[str | None, dict[str, float]]:
|
||||||
"""
|
"""
|
||||||
直接对所有候选进行朴素贝叶斯评分
|
直接对所有候选进行朴素贝叶斯评分
|
||||||
|
|
||||||
@@ -113,7 +112,7 @@ class ExpressorModel:
|
|||||||
tf = Counter(toks)
|
tf = Counter(toks)
|
||||||
self.nb.update_positive(tf, cid)
|
self.nb.update_positive(tf, cid)
|
||||||
|
|
||||||
def decay(self, factor: Optional[float] = None):
|
def decay(self, factor: float | None = None):
|
||||||
"""
|
"""
|
||||||
应用知识衰减
|
应用知识衰减
|
||||||
|
|
||||||
@@ -122,7 +121,7 @@ class ExpressorModel:
|
|||||||
"""
|
"""
|
||||||
self.nb.decay(factor)
|
self.nb.decay(factor)
|
||||||
|
|
||||||
def get_candidate_info(self, cid: str) -> Tuple[Optional[str], Optional[str]]:
|
def get_candidate_info(self, cid: str) -> tuple[str | None, str | None]:
|
||||||
"""
|
"""
|
||||||
获取候选信息
|
获取候选信息
|
||||||
|
|
||||||
@@ -136,7 +135,7 @@ class ExpressorModel:
|
|||||||
situation = self._situations.get(cid)
|
situation = self._situations.get(cid)
|
||||||
return style, situation
|
return style, situation
|
||||||
|
|
||||||
def get_all_candidates(self) -> Dict[str, Tuple[str, str]]:
|
def get_all_candidates(self) -> dict[str, tuple[str, str]]:
|
||||||
"""
|
"""
|
||||||
获取所有候选
|
获取所有候选
|
||||||
|
|
||||||
@@ -205,7 +204,7 @@ class ExpressorModel:
|
|||||||
|
|
||||||
logger.info(f"模型已从 {path} 加载")
|
logger.info(f"模型已从 {path} 加载")
|
||||||
|
|
||||||
def get_stats(self) -> Dict:
|
def get_stats(self) -> dict:
|
||||||
"""获取模型统计信息"""
|
"""获取模型统计信息"""
|
||||||
nb_stats = self.nb.get_stats()
|
nb_stats = self.nb.get_stats()
|
||||||
return {
|
return {
|
||||||
|
|||||||
@@ -4,7 +4,6 @@
|
|||||||
"""
|
"""
|
||||||
import math
|
import math
|
||||||
from collections import Counter, defaultdict
|
from collections import Counter, defaultdict
|
||||||
from typing import Dict, List, Optional
|
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
@@ -28,15 +27,15 @@ class OnlineNaiveBayes:
|
|||||||
self.V = vocab_size
|
self.V = vocab_size
|
||||||
|
|
||||||
# 类别统计
|
# 类别统计
|
||||||
self.cls_counts: Dict[str, float] = defaultdict(float) # cid -> total token count
|
self.cls_counts: dict[str, float] = defaultdict(float) # cid -> total token count
|
||||||
self.token_counts: Dict[str, Dict[str, float]] = defaultdict(
|
self.token_counts: dict[str, dict[str, float]] = defaultdict(
|
||||||
lambda: defaultdict(float)
|
lambda: defaultdict(float)
|
||||||
) # cid -> term -> count
|
) # cid -> term -> count
|
||||||
|
|
||||||
# 缓存
|
# 缓存
|
||||||
self._logZ: Dict[str, float] = {} # cache log(∑counts + Vα)
|
self._logZ: dict[str, float] = {} # cache log(∑counts + Vα)
|
||||||
|
|
||||||
def score_batch(self, tf: Counter, cids: List[str]) -> Dict[str, float]:
|
def score_batch(self, tf: Counter, cids: list[str]) -> dict[str, float]:
|
||||||
"""
|
"""
|
||||||
批量计算候选的贝叶斯分数
|
批量计算候选的贝叶斯分数
|
||||||
|
|
||||||
@@ -51,7 +50,7 @@ class OnlineNaiveBayes:
|
|||||||
n_cls = max(1, len(self.cls_counts))
|
n_cls = max(1, len(self.cls_counts))
|
||||||
denom_prior = math.log(total_cls + self.beta * n_cls)
|
denom_prior = math.log(total_cls + self.beta * n_cls)
|
||||||
|
|
||||||
out: Dict[str, float] = {}
|
out: dict[str, float] = {}
|
||||||
for cid in cids:
|
for cid in cids:
|
||||||
# 计算先验概率 log P(c)
|
# 计算先验概率 log P(c)
|
||||||
prior = math.log(self.cls_counts[cid] + self.beta) - denom_prior
|
prior = math.log(self.cls_counts[cid] + self.beta) - denom_prior
|
||||||
@@ -88,7 +87,7 @@ class OnlineNaiveBayes:
|
|||||||
self.cls_counts[cid] += inc
|
self.cls_counts[cid] += inc
|
||||||
self._invalidate(cid)
|
self._invalidate(cid)
|
||||||
|
|
||||||
def decay(self, factor: Optional[float] = None):
|
def decay(self, factor: float | None = None):
|
||||||
"""
|
"""
|
||||||
知识衰减(遗忘机制)
|
知识衰减(遗忘机制)
|
||||||
|
|
||||||
@@ -133,7 +132,7 @@ class OnlineNaiveBayes:
|
|||||||
if cid in self._logZ:
|
if cid in self._logZ:
|
||||||
del self._logZ[cid]
|
del self._logZ[cid]
|
||||||
|
|
||||||
def get_stats(self) -> Dict:
|
def get_stats(self) -> dict:
|
||||||
"""获取统计信息"""
|
"""获取统计信息"""
|
||||||
return {
|
return {
|
||||||
"n_classes": len(self.cls_counts),
|
"n_classes": len(self.cls_counts),
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
"""
|
"""
|
||||||
文本分词器,支持中文Jieba分词
|
文本分词器,支持中文Jieba分词
|
||||||
"""
|
"""
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
@@ -30,7 +29,7 @@ class Tokenizer:
|
|||||||
logger.warning("Jieba未安装,将使用字符级分词")
|
logger.warning("Jieba未安装,将使用字符级分词")
|
||||||
self.use_jieba = False
|
self.use_jieba = False
|
||||||
|
|
||||||
def tokenize(self, text: str) -> List[str]:
|
def tokenize(self, text: str) -> list[str]:
|
||||||
"""
|
"""
|
||||||
分词并返回token列表
|
分词并返回token列表
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
情境提取器
|
情境提取器
|
||||||
从聊天历史中提取当前的情境(situation),用于 StyleLearner 预测
|
从聊天历史中提取当前的情境(situation),用于 StyleLearner 预测
|
||||||
"""
|
"""
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
@@ -51,7 +50,7 @@ class SituationExtractor:
|
|||||||
async def extract_situations(
|
async def extract_situations(
|
||||||
self,
|
self,
|
||||||
chat_history: list | str,
|
chat_history: list | str,
|
||||||
target_message: Optional[str] = None,
|
target_message: str | None = None,
|
||||||
max_situations: int = 3
|
max_situations: int = 3
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
"""
|
"""
|
||||||
@@ -144,7 +143,7 @@ class SituationExtractor:
|
|||||||
continue
|
continue
|
||||||
if len(line) < 2: # 太短
|
if len(line) < 2: # 太短
|
||||||
continue
|
continue
|
||||||
if any(keyword in line.lower() for keyword in ['例如', '注意', '请', '分析', '总结']):
|
if any(keyword in line.lower() for keyword in ["例如", "注意", "请", "分析", "总结"]):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
situations.append(line)
|
situations.append(line)
|
||||||
|
|||||||
@@ -5,7 +5,6 @@
|
|||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from typing import Dict, List, Optional, Tuple
|
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
@@ -17,7 +16,7 @@ logger = get_logger("expressor.style_learner")
|
|||||||
class StyleLearner:
|
class StyleLearner:
|
||||||
"""单个聊天室的表达风格学习器"""
|
"""单个聊天室的表达风格学习器"""
|
||||||
|
|
||||||
def __init__(self, chat_id: str, model_config: Optional[Dict] = None):
|
def __init__(self, chat_id: str, model_config: dict | None = None):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
chat_id: 聊天室ID
|
chat_id: 聊天室ID
|
||||||
@@ -37,9 +36,9 @@ class StyleLearner:
|
|||||||
|
|
||||||
# 动态风格管理
|
# 动态风格管理
|
||||||
self.max_styles = 2000 # 每个chat_id最多2000个风格
|
self.max_styles = 2000 # 每个chat_id最多2000个风格
|
||||||
self.style_to_id: Dict[str, str] = {} # style文本 -> style_id
|
self.style_to_id: dict[str, str] = {} # style文本 -> style_id
|
||||||
self.id_to_style: Dict[str, str] = {} # style_id -> style文本
|
self.id_to_style: dict[str, str] = {} # style_id -> style文本
|
||||||
self.id_to_situation: Dict[str, str] = {} # style_id -> situation文本
|
self.id_to_situation: dict[str, str] = {} # style_id -> situation文本
|
||||||
self.next_style_id = 0
|
self.next_style_id = 0
|
||||||
|
|
||||||
# 学习统计
|
# 学习统计
|
||||||
@@ -51,7 +50,7 @@ class StyleLearner:
|
|||||||
|
|
||||||
logger.info(f"StyleLearner初始化成功: chat_id={chat_id}")
|
logger.info(f"StyleLearner初始化成功: chat_id={chat_id}")
|
||||||
|
|
||||||
def add_style(self, style: str, situation: Optional[str] = None) -> bool:
|
def add_style(self, style: str, situation: str | None = None) -> bool:
|
||||||
"""
|
"""
|
||||||
动态添加一个新的风格
|
动态添加一个新的风格
|
||||||
|
|
||||||
@@ -130,7 +129,7 @@ class StyleLearner:
|
|||||||
logger.error(f"学习映射失败: {e}")
|
logger.error(f"学习映射失败: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def predict_style(self, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]:
|
def predict_style(self, up_content: str, top_k: int = 5) -> tuple[str | None, dict[str, float]]:
|
||||||
"""
|
"""
|
||||||
根据up_content预测最合适的style
|
根据up_content预测最合适的style
|
||||||
|
|
||||||
@@ -183,7 +182,7 @@ class StyleLearner:
|
|||||||
logger.error(f"预测style失败: {e}", exc_info=True)
|
logger.error(f"预测style失败: {e}", exc_info=True)
|
||||||
return None, {}
|
return None, {}
|
||||||
|
|
||||||
def get_style_info(self, style: str) -> Tuple[Optional[str], Optional[str]]:
|
def get_style_info(self, style: str) -> tuple[str | None, str | None]:
|
||||||
"""
|
"""
|
||||||
获取style的完整信息
|
获取style的完整信息
|
||||||
|
|
||||||
@@ -200,7 +199,7 @@ class StyleLearner:
|
|||||||
situation = self.id_to_situation.get(style_id)
|
situation = self.id_to_situation.get(style_id)
|
||||||
return style_id, situation
|
return style_id, situation
|
||||||
|
|
||||||
def get_all_styles(self) -> List[str]:
|
def get_all_styles(self) -> list[str]:
|
||||||
"""
|
"""
|
||||||
获取所有风格列表
|
获取所有风格列表
|
||||||
|
|
||||||
@@ -209,7 +208,7 @@ class StyleLearner:
|
|||||||
"""
|
"""
|
||||||
return list(self.style_to_id.keys())
|
return list(self.style_to_id.keys())
|
||||||
|
|
||||||
def apply_decay(self, factor: Optional[float] = None):
|
def apply_decay(self, factor: float | None = None):
|
||||||
"""
|
"""
|
||||||
应用知识衰减
|
应用知识衰减
|
||||||
|
|
||||||
@@ -304,7 +303,7 @@ class StyleLearner:
|
|||||||
logger.error(f"加载StyleLearner失败: {e}")
|
logger.error(f"加载StyleLearner失败: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def get_stats(self) -> Dict:
|
def get_stats(self) -> dict:
|
||||||
"""获取统计信息"""
|
"""获取统计信息"""
|
||||||
model_stats = self.expressor.get_stats()
|
model_stats = self.expressor.get_stats()
|
||||||
return {
|
return {
|
||||||
@@ -324,7 +323,7 @@ class StyleLearnerManager:
|
|||||||
Args:
|
Args:
|
||||||
model_save_path: 模型保存路径
|
model_save_path: 模型保存路径
|
||||||
"""
|
"""
|
||||||
self.learners: Dict[str, StyleLearner] = {}
|
self.learners: dict[str, StyleLearner] = {}
|
||||||
self.model_save_path = model_save_path
|
self.model_save_path = model_save_path
|
||||||
|
|
||||||
# 确保保存目录存在
|
# 确保保存目录存在
|
||||||
@@ -332,7 +331,7 @@ class StyleLearnerManager:
|
|||||||
|
|
||||||
logger.info(f"StyleLearnerManager初始化成功, 模型保存路径: {model_save_path}")
|
logger.info(f"StyleLearnerManager初始化成功, 模型保存路径: {model_save_path}")
|
||||||
|
|
||||||
def get_learner(self, chat_id: str, model_config: Optional[Dict] = None) -> StyleLearner:
|
def get_learner(self, chat_id: str, model_config: dict | None = None) -> StyleLearner:
|
||||||
"""
|
"""
|
||||||
获取或创建指定chat_id的学习器
|
获取或创建指定chat_id的学习器
|
||||||
|
|
||||||
@@ -369,7 +368,7 @@ class StyleLearnerManager:
|
|||||||
learner = self.get_learner(chat_id)
|
learner = self.get_learner(chat_id)
|
||||||
return learner.learn_mapping(up_content, style)
|
return learner.learn_mapping(up_content, style)
|
||||||
|
|
||||||
def predict_style(self, chat_id: str, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]:
|
def predict_style(self, chat_id: str, up_content: str, top_k: int = 5) -> tuple[str | None, dict[str, float]]:
|
||||||
"""
|
"""
|
||||||
预测最合适的风格
|
预测最合适的风格
|
||||||
|
|
||||||
@@ -399,7 +398,7 @@ class StyleLearnerManager:
|
|||||||
logger.info(f"保存所有StyleLearner {'成功' if success else '部分失败'}")
|
logger.info(f"保存所有StyleLearner {'成功' if success else '部分失败'}")
|
||||||
return success
|
return success
|
||||||
|
|
||||||
def apply_decay_all(self, factor: Optional[float] = None):
|
def apply_decay_all(self, factor: float | None = None):
|
||||||
"""
|
"""
|
||||||
对所有学习器应用知识衰减
|
对所有学习器应用知识衰减
|
||||||
|
|
||||||
@@ -409,9 +408,9 @@ class StyleLearnerManager:
|
|||||||
for learner in self.learners.values():
|
for learner in self.learners.values():
|
||||||
learner.apply_decay(factor)
|
learner.apply_decay(factor)
|
||||||
|
|
||||||
logger.info(f"对所有StyleLearner应用知识衰减")
|
logger.info("对所有StyleLearner应用知识衰减")
|
||||||
|
|
||||||
def get_all_stats(self) -> Dict[str, Dict]:
|
def get_all_stats(self) -> dict[str, dict]:
|
||||||
"""
|
"""
|
||||||
获取所有学习器的统计信息
|
获取所有学习器的统计信息
|
||||||
|
|
||||||
|
|||||||
@@ -503,7 +503,7 @@ class MemorySystem:
|
|||||||
existing_id = self._memory_fingerprints.get(fingerprint_key)
|
existing_id = self._memory_fingerprints.get(fingerprint_key)
|
||||||
if existing_id and existing_id not in new_memory_ids:
|
if existing_id and existing_id not in new_memory_ids:
|
||||||
candidate_ids.add(existing_id)
|
candidate_ids.add(existing_id)
|
||||||
except Exception as exc: # noqa: PERF203
|
except Exception as exc:
|
||||||
logger.debug("构建记忆指纹失败,跳过候选收集: %s", exc)
|
logger.debug("构建记忆指纹失败,跳过候选收集: %s", exc)
|
||||||
|
|
||||||
# 基于主体索引的候选(使用统一存储)
|
# 基于主体索引的候选(使用统一存储)
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from src.chat.antipromptinjector import initialize_anti_injector
|
|||||||
from src.chat.message_manager import message_manager
|
from src.chat.message_manager import message_manager
|
||||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||||
from src.chat.message_receive.storage import MessageStorage
|
from src.chat.message_receive.storage import MessageStorage
|
||||||
from src.chat.utils.prompt import create_prompt_async, global_prompt_manager
|
from src.chat.utils.prompt import global_prompt_manager
|
||||||
from src.chat.utils.utils import is_mentioned_bot_in_message
|
from src.chat.utils.utils import is_mentioned_bot_in_message
|
||||||
from src.common.data_models.database_data_model import DatabaseMessages
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import copy
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import time
|
import time
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
from maim_message import GroupInfo, UserInfo
|
from maim_message import GroupInfo, UserInfo
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
@@ -10,13 +8,12 @@ from sqlalchemy import select
|
|||||||
from sqlalchemy.dialects.mysql import insert as mysql_insert
|
from sqlalchemy.dialects.mysql import insert as mysql_insert
|
||||||
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
||||||
|
|
||||||
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||||
from src.common.database.sqlalchemy_models import ChatStreams # 新增导入
|
from src.common.database.sqlalchemy_models import ChatStreams # 新增导入
|
||||||
from src.common.data_models.database_data_model import DatabaseMessages
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config # 新增导入
|
from src.config.config import global_config # 新增导入
|
||||||
|
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
|
|
||||||
|
|
||||||
@@ -360,7 +357,7 @@ class ChatManager:
|
|||||||
def register_message(self, message: DatabaseMessages):
|
def register_message(self, message: DatabaseMessages):
|
||||||
"""注册消息到聊天流"""
|
"""注册消息到聊天流"""
|
||||||
# 从 DatabaseMessages 提取平台和用户/群组信息
|
# 从 DatabaseMessages 提取平台和用户/群组信息
|
||||||
from maim_message import UserInfo, GroupInfo
|
from maim_message import GroupInfo, UserInfo
|
||||||
|
|
||||||
user_info = UserInfo(
|
user_info = UserInfo(
|
||||||
platform=message.user_info.platform,
|
platform=message.user_info.platform,
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
import base64
|
|
||||||
import time
|
import time
|
||||||
from abc import ABCMeta, abstractmethod
|
from abc import ABCMeta, abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Optional, Union
|
from typing import Optional
|
||||||
|
|
||||||
import urllib3
|
import urllib3
|
||||||
from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo
|
from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo
|
||||||
@@ -11,7 +10,6 @@ from rich.traceback import install
|
|||||||
from src.chat.message_receive.chat_stream import ChatStream
|
from src.chat.message_receive.chat_stream import ChatStream
|
||||||
from src.chat.utils.self_voice_cache import consume_self_voice_text
|
from src.chat.utils.self_voice_cache import consume_self_voice_text
|
||||||
from src.chat.utils.utils_image import get_image_manager
|
from src.chat.utils.utils_image import get_image_manager
|
||||||
from src.chat.utils.utils_video import get_video_analyzer, is_video_analysis_available
|
|
||||||
from src.chat.utils.utils_voice import get_voice_text
|
from src.chat.utils.utils_voice import get_voice_text
|
||||||
from src.common.data_models.database_data_model import DatabaseMessages
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|||||||
@@ -266,8 +266,8 @@ async def _process_single_segment(segment: Seg, state: dict, message_info: BaseM
|
|||||||
|
|
||||||
elif segment.type == "file":
|
elif segment.type == "file":
|
||||||
if isinstance(segment.data, dict):
|
if isinstance(segment.data, dict):
|
||||||
file_name = segment.data.get('name', '未知文件')
|
file_name = segment.data.get("name", "未知文件")
|
||||||
file_size = segment.data.get('size', '未知大小')
|
file_size = segment.data.get("size", "未知大小")
|
||||||
return f"[文件:{file_name} ({file_size}字节)]"
|
return f"[文件:{file_name} ({file_size}字节)]"
|
||||||
return "[收到一个文件]"
|
return "[收到一个文件]"
|
||||||
|
|
||||||
@@ -351,7 +351,7 @@ def _prepare_additional_config(message_info: BaseMessageInfo, is_notify: bool, i
|
|||||||
additional_config_data = {}
|
additional_config_data = {}
|
||||||
|
|
||||||
# 首先获取adapter传递的additional_config
|
# 首先获取adapter传递的additional_config
|
||||||
if hasattr(message_info, 'additional_config') and message_info.additional_config:
|
if hasattr(message_info, "additional_config") and message_info.additional_config:
|
||||||
if isinstance(message_info.additional_config, dict):
|
if isinstance(message_info.additional_config, dict):
|
||||||
additional_config_data = message_info.additional_config.copy()
|
additional_config_data = message_info.additional_config.copy()
|
||||||
elif isinstance(message_info.additional_config, str):
|
elif isinstance(message_info.additional_config, str):
|
||||||
@@ -368,7 +368,7 @@ def _prepare_additional_config(message_info: BaseMessageInfo, is_notify: bool, i
|
|||||||
additional_config_data["is_public_notice"] = bool(is_public_notice)
|
additional_config_data["is_public_notice"] = bool(is_public_notice)
|
||||||
|
|
||||||
# 添加format_info到additional_config中
|
# 添加format_info到additional_config中
|
||||||
if hasattr(message_info, 'format_info') and message_info.format_info:
|
if hasattr(message_info, "format_info") and message_info.format_info:
|
||||||
try:
|
try:
|
||||||
format_info_dict = message_info.format_info.to_dict()
|
format_info_dict = message_info.format_info.to_dict()
|
||||||
additional_config_data["format_info"] = format_info_dict
|
additional_config_data["format_info"] = format_info_dict
|
||||||
@@ -423,7 +423,7 @@ def get_message_info_from_db_message(db_message: DatabaseMessages) -> BaseMessag
|
|||||||
Returns:
|
Returns:
|
||||||
BaseMessageInfo: 重建的消息信息对象
|
BaseMessageInfo: 重建的消息信息对象
|
||||||
"""
|
"""
|
||||||
from maim_message import UserInfo, GroupInfo
|
from maim_message import GroupInfo, UserInfo
|
||||||
|
|
||||||
# 从 DatabaseMessages 的 user_info 转换为 maim_message.UserInfo
|
# 从 DatabaseMessages 的 user_info 转换为 maim_message.UserInfo
|
||||||
user_info = UserInfo(
|
user_info = UserInfo(
|
||||||
|
|||||||
@@ -5,12 +5,11 @@ import traceback
|
|||||||
import orjson
|
import orjson
|
||||||
from sqlalchemy import desc, select, update
|
from sqlalchemy import desc, select, update
|
||||||
|
|
||||||
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||||
from src.common.database.sqlalchemy_models import Images, Messages
|
from src.common.database.sqlalchemy_models import Images, Messages
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
from src.common.data_models.database_data_model import DatabaseMessages
|
|
||||||
|
|
||||||
from .chat_stream import ChatStream
|
from .chat_stream import ChatStream
|
||||||
from .message import MessageSending
|
from .message import MessageSending
|
||||||
|
|
||||||
|
|||||||
@@ -26,8 +26,8 @@ async def send_message(message: MessageSending, show_log=True) -> bool:
|
|||||||
|
|
||||||
# 触发 AFTER_SEND 事件
|
# 触发 AFTER_SEND 事件
|
||||||
try:
|
try:
|
||||||
from src.plugin_system.core.event_manager import event_manager
|
|
||||||
from src.plugin_system.base.component_types import EventType
|
from src.plugin_system.base.component_types import EventType
|
||||||
|
from src.plugin_system.core.event_manager import event_manager
|
||||||
|
|
||||||
if message.chat_stream:
|
if message.chat_stream:
|
||||||
logger.info(f"[发送完成] 准备触发 AFTER_SEND 事件,stream_id={message.chat_stream.stream_id}")
|
logger.info(f"[发送完成] 准备触发 AFTER_SEND 事件,stream_id={message.chat_stream.stream_id}")
|
||||||
@@ -35,20 +35,20 @@ async def send_message(message: MessageSending, show_log=True) -> bool:
|
|||||||
# 使用 asyncio.create_task 来异步触发事件,避免阻塞
|
# 使用 asyncio.create_task 来异步触发事件,避免阻塞
|
||||||
async def trigger_event_async():
|
async def trigger_event_async():
|
||||||
try:
|
try:
|
||||||
logger.info(f"[事件触发] 开始异步触发 AFTER_SEND 事件")
|
logger.info("[事件触发] 开始异步触发 AFTER_SEND 事件")
|
||||||
await event_manager.trigger_event(
|
await event_manager.trigger_event(
|
||||||
EventType.AFTER_SEND,
|
EventType.AFTER_SEND,
|
||||||
permission_group="SYSTEM",
|
permission_group="SYSTEM",
|
||||||
stream_id=message.chat_stream.stream_id,
|
stream_id=message.chat_stream.stream_id,
|
||||||
message=message,
|
message=message,
|
||||||
)
|
)
|
||||||
logger.info(f"[事件触发] AFTER_SEND 事件触发完成")
|
logger.info("[事件触发] AFTER_SEND 事件触发完成")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[事件触发] 异步触发事件失败: {e}", exc_info=True)
|
logger.error(f"[事件触发] 异步触发事件失败: {e}", exc_info=True)
|
||||||
|
|
||||||
# 创建异步任务,不等待完成
|
# 创建异步任务,不等待完成
|
||||||
asyncio.create_task(trigger_event_async())
|
asyncio.create_task(trigger_event_async())
|
||||||
logger.info(f"[发送完成] AFTER_SEND 事件已提交到异步任务")
|
logger.info("[发送完成] AFTER_SEND 事件已提交到异步任务")
|
||||||
except Exception as event_error:
|
except Exception as event_error:
|
||||||
logger.error(f"触发 AFTER_SEND 事件时出错: {event_error}", exc_info=True)
|
logger.error(f"触发 AFTER_SEND 事件时出错: {event_error}", exc_info=True)
|
||||||
|
|
||||||
|
|||||||
@@ -32,8 +32,6 @@ from src.common.logger import get_logger
|
|||||||
from src.config.config import global_config, model_config
|
from src.config.config import global_config, model_config
|
||||||
from src.individuality.individuality import get_individuality
|
from src.individuality.individuality import get_individuality
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
|
|
||||||
|
|
||||||
from src.mood.mood_manager import mood_manager
|
from src.mood.mood_manager import mood_manager
|
||||||
from src.person_info.person_info import get_person_info_manager
|
from src.person_info.person_info import get_person_info_manager
|
||||||
from src.plugin_system.apis import llm_api
|
from src.plugin_system.apis import llm_api
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import rjieba
|
|||||||
from maim_message import UserInfo
|
from maim_message import UserInfo
|
||||||
|
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
|
|
||||||
# MessageRecv 已被移除,现在使用 DatabaseMessages
|
# MessageRecv 已被移除,现在使用 DatabaseMessages
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.message_repository import count_messages, find_messages
|
from src.common.message_repository import count_messages, find_messages
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import asyncio
|
|||||||
import time
|
import time
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import TYPE_CHECKING, List, Optional
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.plugin_system.base.component_types import ChatMode, ChatType
|
from src.plugin_system.base.component_types import ChatMode, ChatType
|
||||||
@@ -64,7 +64,7 @@ class StreamContext(BaseDataModel):
|
|||||||
triggering_user_id: str | None = None # 触发当前聊天流的用户ID
|
triggering_user_id: str | None = None # 触发当前聊天流的用户ID
|
||||||
is_replying: bool = False # 是否正在生成回复
|
is_replying: bool = False # 是否正在生成回复
|
||||||
processing_message_id: str | None = None # 当前正在规划/处理的目标消息ID,用于防止重复回复
|
processing_message_id: str | None = None # 当前正在规划/处理的目标消息ID,用于防止重复回复
|
||||||
decision_history: List["DecisionRecord"] = field(default_factory=list) # 决策历史
|
decision_history: list["DecisionRecord"] = field(default_factory=list) # 决策历史
|
||||||
|
|
||||||
def add_action_to_message(self, message_id: str, action: str):
|
def add_action_to_message(self, message_id: str, action: str):
|
||||||
"""
|
"""
|
||||||
@@ -260,7 +260,7 @@ class StreamContext(BaseDataModel):
|
|||||||
if requested_type not in accept_format:
|
if requested_type not in accept_format:
|
||||||
logger.debug(f"[check_types] 消息不支持类型 '{requested_type}',支持的类型: {accept_format}")
|
logger.debug(f"[check_types] 消息不支持类型 '{requested_type}',支持的类型: {accept_format}")
|
||||||
return False
|
return False
|
||||||
logger.debug(f"[check_types] ✅ 消息支持所有请求的类型 (来自 accept_format)")
|
logger.debug("[check_types] ✅ 消息支持所有请求的类型 (来自 accept_format)")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# 方法2: 检查content_format字段(向后兼容)
|
# 方法2: 检查content_format字段(向后兼容)
|
||||||
@@ -279,7 +279,7 @@ class StreamContext(BaseDataModel):
|
|||||||
if requested_type not in content_format:
|
if requested_type not in content_format:
|
||||||
logger.debug(f"[check_types] 消息不支持类型 '{requested_type}',支持的内容格式: {content_format}")
|
logger.debug(f"[check_types] 消息不支持类型 '{requested_type}',支持的内容格式: {content_format}")
|
||||||
return False
|
return False
|
||||||
logger.debug(f"[check_types] ✅ 消息支持所有请求的类型 (来自 content_format)")
|
logger.debug("[check_types] ✅ 消息支持所有请求的类型 (来自 content_format)")
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
logger.warning("[check_types] [问题] additional_config 中没有 format_info 字段")
|
logger.warning("[check_types] [问题] additional_config 中没有 format_info 字段")
|
||||||
|
|||||||
@@ -26,7 +26,6 @@ from src.config.official_configs import (
|
|||||||
EmojiConfig,
|
EmojiConfig,
|
||||||
ExperimentalConfig,
|
ExperimentalConfig,
|
||||||
ExpressionConfig,
|
ExpressionConfig,
|
||||||
ReactionConfig,
|
|
||||||
LPMMKnowledgeConfig,
|
LPMMKnowledgeConfig,
|
||||||
MaimMessageConfig,
|
MaimMessageConfig,
|
||||||
MemoryConfig,
|
MemoryConfig,
|
||||||
@@ -38,6 +37,7 @@ from src.config.official_configs import (
|
|||||||
PersonalityConfig,
|
PersonalityConfig,
|
||||||
PlanningSystemConfig,
|
PlanningSystemConfig,
|
||||||
ProactiveThinkingConfig,
|
ProactiveThinkingConfig,
|
||||||
|
ReactionConfig,
|
||||||
ResponsePostProcessConfig,
|
ResponsePostProcessConfig,
|
||||||
ResponseSplitterConfig,
|
ResponseSplitterConfig,
|
||||||
ToolConfig,
|
ToolConfig,
|
||||||
|
|||||||
@@ -86,7 +86,7 @@ async def file_to_stream(
|
|||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Any, TYPE_CHECKING
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from maim_message import Seg, UserInfo
|
from maim_message import Seg, UserInfo
|
||||||
|
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ class EventManager:
|
|||||||
self._events: dict[str, BaseEvent] = {}
|
self._events: dict[str, BaseEvent] = {}
|
||||||
self._event_handlers: dict[str, type[BaseEventHandler]] = {}
|
self._event_handlers: dict[str, type[BaseEventHandler]] = {}
|
||||||
self._pending_subscriptions: dict[str, list[str]] = {} # 缓存失败的订阅
|
self._pending_subscriptions: dict[str, list[str]] = {} # 缓存失败的订阅
|
||||||
self._scheduler_callback: Optional[Any] = None # scheduler 回调函数
|
self._scheduler_callback: Any | None = None # scheduler 回调函数
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
logger.info("EventManager 单例初始化完成")
|
logger.info("EventManager 单例初始化完成")
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,6 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import time
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
@@ -31,10 +30,34 @@ class ChatStreamImpressionTool(BaseTool):
|
|||||||
name = "update_chat_stream_impression"
|
name = "update_chat_stream_impression"
|
||||||
description = "当你通过观察聊天记录对当前聊天环境(群聊或私聊)产生了整体印象或认识时使用此工具,更新对这个聊天流的看法。包括:环境氛围、聊天风格、常见话题、你的兴趣程度。调用时机:当你发现这个聊天环境有明显的氛围特点(如很活跃、很专业、很闲聊)、群成员经常讨论某类话题、或者你对这个环境的感受发生变化时。注意:这是对整个聊天环境的印象,而非对单个用户。"
|
description = "当你通过观察聊天记录对当前聊天环境(群聊或私聊)产生了整体印象或认识时使用此工具,更新对这个聊天流的看法。包括:环境氛围、聊天风格、常见话题、你的兴趣程度。调用时机:当你发现这个聊天环境有明显的氛围特点(如很活跃、很专业、很闲聊)、群成员经常讨论某类话题、或者你对这个环境的感受发生变化时。注意:这是对整个聊天环境的印象,而非对单个用户。"
|
||||||
parameters = [
|
parameters = [
|
||||||
("impression_description", ToolParamType.STRING, "你对这个聊天环境的整体感受和印象,例如'这是个技术氛围浓厚的群'、'大家都很友好热情'。当你通过聊天记录感受到环境特点时填写(可选)", False, None),
|
(
|
||||||
("chat_style", ToolParamType.STRING, "这个聊天环境的风格特征,如'活跃热闹,互帮互助'、'严肃专业,深度讨论'、'轻松闲聊,段子频出'等。当你发现聊天方式有明显特点时填写(可选)", False, None),
|
"impression_description",
|
||||||
("topic_keywords", ToolParamType.STRING, "这个聊天环境中经常出现的话题,如'编程,AI,技术分享'或'游戏,动漫,娱乐'。当你观察到群里反复讨论某些主题时填写,多个关键词用逗号分隔(可选)", False, None),
|
ToolParamType.STRING,
|
||||||
("interest_score", ToolParamType.FLOAT, "你对这个聊天环境的兴趣和喜欢程度,0.0(无聊/不喜欢)到1.0(很有趣/很喜欢)。当你对这个环境的感觉发生变化时更新(可选)", False, None),
|
"你对这个聊天环境的整体感受和印象,例如'这是个技术氛围浓厚的群'、'大家都很友好热情'。当你通过聊天记录感受到环境特点时填写(可选)",
|
||||||
|
False,
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"chat_style",
|
||||||
|
ToolParamType.STRING,
|
||||||
|
"这个聊天环境的风格特征,如'活跃热闹,互帮互助'、'严肃专业,深度讨论'、'轻松闲聊,段子频出'等。当你发现聊天方式有明显特点时填写(可选)",
|
||||||
|
False,
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"topic_keywords",
|
||||||
|
ToolParamType.STRING,
|
||||||
|
"这个聊天环境中经常出现的话题,如'编程,AI,技术分享'或'游戏,动漫,娱乐'。当你观察到群里反复讨论某些主题时填写,多个关键词用逗号分隔(可选)",
|
||||||
|
False,
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"interest_score",
|
||||||
|
ToolParamType.FLOAT,
|
||||||
|
"你对这个聊天环境的兴趣和喜欢程度,0.0(无聊/不喜欢)到1.0(很有趣/很喜欢)。当你对这个环境的感觉发生变化时更新(可选)",
|
||||||
|
False,
|
||||||
|
None,
|
||||||
|
),
|
||||||
]
|
]
|
||||||
available_for_llm = True
|
available_for_llm = True
|
||||||
history_ttl = 5
|
history_ttl = 5
|
||||||
@@ -46,12 +69,13 @@ class ChatStreamImpressionTool(BaseTool):
|
|||||||
try:
|
try:
|
||||||
self.impression_llm = LLMRequest(
|
self.impression_llm = LLMRequest(
|
||||||
model_set=model_config.model_task_config.relationship_tracker,
|
model_set=model_config.model_task_config.relationship_tracker,
|
||||||
request_type="chat_stream_impression_update"
|
request_type="chat_stream_impression_update",
|
||||||
)
|
)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
# 降级处理
|
# 降级处理
|
||||||
available_models = [
|
available_models = [
|
||||||
attr for attr in dir(model_config.model_task_config)
|
attr
|
||||||
|
for attr in dir(model_config.model_task_config)
|
||||||
if not attr.startswith("_") and attr != "model_dump"
|
if not attr.startswith("_") and attr != "model_dump"
|
||||||
]
|
]
|
||||||
if available_models:
|
if available_models:
|
||||||
@@ -59,7 +83,7 @@ class ChatStreamImpressionTool(BaseTool):
|
|||||||
logger.warning(f"relationship_tracker配置不存在,使用降级模型: {fallback_model}")
|
logger.warning(f"relationship_tracker配置不存在,使用降级模型: {fallback_model}")
|
||||||
self.impression_llm = LLMRequest(
|
self.impression_llm = LLMRequest(
|
||||||
model_set=getattr(model_config.model_task_config, fallback_model),
|
model_set=getattr(model_config.model_task_config, fallback_model),
|
||||||
request_type="chat_stream_impression_update"
|
request_type="chat_stream_impression_update",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.error("无可用的模型配置")
|
logger.error("无可用的模型配置")
|
||||||
@@ -89,11 +113,7 @@ class ChatStreamImpressionTool(BaseTool):
|
|||||||
# 如果还是没有,返回错误
|
# 如果还是没有,返回错误
|
||||||
if not stream_id:
|
if not stream_id:
|
||||||
logger.error("无法获取 stream_id:function_args 和 chat_stream 都没有提供")
|
logger.error("无法获取 stream_id:function_args 和 chat_stream 都没有提供")
|
||||||
return {
|
return {"type": "error", "id": "chat_stream_impression", "content": "错误:无法获取当前聊天流ID"}
|
||||||
"type": "error",
|
|
||||||
"id": "chat_stream_impression",
|
|
||||||
"content": "错误:无法获取当前聊天流ID"
|
|
||||||
}
|
|
||||||
|
|
||||||
# 从LLM传入的参数
|
# 从LLM传入的参数
|
||||||
new_impression = function_args.get("impression_description", "")
|
new_impression = function_args.get("impression_description", "")
|
||||||
@@ -109,17 +129,13 @@ class ChatStreamImpressionTool(BaseTool):
|
|||||||
return {
|
return {
|
||||||
"type": "info",
|
"type": "info",
|
||||||
"id": stream_id,
|
"id": stream_id,
|
||||||
"content": "提示:需要提供至少一项更新内容(印象描述、聊天风格、话题关键词或兴趣分数)"
|
"content": "提示:需要提供至少一项更新内容(印象描述、聊天风格、话题关键词或兴趣分数)",
|
||||||
}
|
}
|
||||||
|
|
||||||
# 调用LLM进行二步决策
|
# 调用LLM进行二步决策
|
||||||
if self.impression_llm is None:
|
if self.impression_llm is None:
|
||||||
logger.error("LLM未正确初始化,无法执行二步调用")
|
logger.error("LLM未正确初始化,无法执行二步调用")
|
||||||
return {
|
return {"type": "error", "id": stream_id, "content": "系统错误:LLM未正确初始化"}
|
||||||
"type": "error",
|
|
||||||
"id": stream_id,
|
|
||||||
"content": "系统错误:LLM未正确初始化"
|
|
||||||
}
|
|
||||||
|
|
||||||
final_impression = await self._llm_decide_final_impression(
|
final_impression = await self._llm_decide_final_impression(
|
||||||
stream_id=stream_id,
|
stream_id=stream_id,
|
||||||
@@ -127,15 +143,11 @@ class ChatStreamImpressionTool(BaseTool):
|
|||||||
new_impression=new_impression,
|
new_impression=new_impression,
|
||||||
new_style=new_style,
|
new_style=new_style,
|
||||||
new_topics=new_topics,
|
new_topics=new_topics,
|
||||||
new_score=new_score
|
new_score=new_score,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not final_impression:
|
if not final_impression:
|
||||||
return {
|
return {"type": "error", "id": stream_id, "content": "LLM决策失败,无法更新聊天流印象"}
|
||||||
"type": "error",
|
|
||||||
"id": stream_id,
|
|
||||||
"content": "LLM决策失败,无法更新聊天流印象"
|
|
||||||
}
|
|
||||||
|
|
||||||
# 更新数据库
|
# 更新数据库
|
||||||
await self._update_stream_impression_in_db(stream_id, final_impression)
|
await self._update_stream_impression_in_db(stream_id, final_impression)
|
||||||
@@ -154,18 +166,14 @@ class ChatStreamImpressionTool(BaseTool):
|
|||||||
result_text = f"已更新聊天流 {stream_id} 的印象:\n" + "\n".join(updates)
|
result_text = f"已更新聊天流 {stream_id} 的印象:\n" + "\n".join(updates)
|
||||||
logger.info(f"聊天流印象更新成功: {stream_id}")
|
logger.info(f"聊天流印象更新成功: {stream_id}")
|
||||||
|
|
||||||
return {
|
return {"type": "chat_stream_impression_update", "id": stream_id, "content": result_text}
|
||||||
"type": "chat_stream_impression_update",
|
|
||||||
"id": stream_id,
|
|
||||||
"content": result_text
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"聊天流印象更新失败: {e}", exc_info=True)
|
logger.error(f"聊天流印象更新失败: {e}", exc_info=True)
|
||||||
return {
|
return {
|
||||||
"type": "error",
|
"type": "error",
|
||||||
"id": function_args.get("stream_id", "unknown"),
|
"id": function_args.get("stream_id", "unknown"),
|
||||||
"content": f"聊天流印象更新失败: {str(e)}"
|
"content": f"聊天流印象更新失败: {e!s}",
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _get_stream_impression(self, stream_id: str) -> dict[str, Any]:
|
async def _get_stream_impression(self, stream_id: str) -> dict[str, Any]:
|
||||||
@@ -188,7 +196,9 @@ class ChatStreamImpressionTool(BaseTool):
|
|||||||
"stream_impression_text": stream.stream_impression_text or "",
|
"stream_impression_text": stream.stream_impression_text or "",
|
||||||
"stream_chat_style": stream.stream_chat_style or "",
|
"stream_chat_style": stream.stream_chat_style or "",
|
||||||
"stream_topic_keywords": stream.stream_topic_keywords or "",
|
"stream_topic_keywords": stream.stream_topic_keywords or "",
|
||||||
"stream_interest_score": float(stream.stream_interest_score) if stream.stream_interest_score is not None else 0.5,
|
"stream_interest_score": float(stream.stream_interest_score)
|
||||||
|
if stream.stream_interest_score is not None
|
||||||
|
else 0.5,
|
||||||
"group_name": stream.group_name or "私聊",
|
"group_name": stream.group_name or "私聊",
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
@@ -217,7 +227,7 @@ class ChatStreamImpressionTool(BaseTool):
|
|||||||
new_impression: str,
|
new_impression: str,
|
||||||
new_style: str,
|
new_style: str,
|
||||||
new_topics: str,
|
new_topics: str,
|
||||||
new_score: float | None
|
new_score: float | None,
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
"""使用LLM决策最终的聊天流印象内容
|
"""使用LLM决策最终的聊天流印象内容
|
||||||
|
|
||||||
@@ -235,6 +245,7 @@ class ChatStreamImpressionTool(BaseTool):
|
|||||||
try:
|
try:
|
||||||
# 获取bot人设
|
# 获取bot人设
|
||||||
from src.individuality.individuality import Individuality
|
from src.individuality.individuality import Individuality
|
||||||
|
|
||||||
individuality = Individuality()
|
individuality = Individuality()
|
||||||
bot_personality = await individuality.get_personality_block()
|
bot_personality = await individuality.get_personality_block()
|
||||||
|
|
||||||
@@ -244,17 +255,17 @@ class ChatStreamImpressionTool(BaseTool):
|
|||||||
你正在更新对聊天流 {stream_id} 的整体印象。
|
你正在更新对聊天流 {stream_id} 的整体印象。
|
||||||
|
|
||||||
【当前聊天流信息】
|
【当前聊天流信息】
|
||||||
- 聊天环境: {existing_impression.get('group_name', '未知')}
|
- 聊天环境: {existing_impression.get("group_name", "未知")}
|
||||||
- 当前印象: {existing_impression.get('stream_impression_text', '暂无印象')}
|
- 当前印象: {existing_impression.get("stream_impression_text", "暂无印象")}
|
||||||
- 聊天风格: {existing_impression.get('stream_chat_style', '未知')}
|
- 聊天风格: {existing_impression.get("stream_chat_style", "未知")}
|
||||||
- 常见话题: {existing_impression.get('stream_topic_keywords', '未知')}
|
- 常见话题: {existing_impression.get("stream_topic_keywords", "未知")}
|
||||||
- 当前兴趣分: {existing_impression.get('stream_interest_score', 0.5):.2f}
|
- 当前兴趣分: {existing_impression.get("stream_interest_score", 0.5):.2f}
|
||||||
|
|
||||||
【本次想要更新的内容】
|
【本次想要更新的内容】
|
||||||
- 新的印象描述: {new_impression if new_impression else '不更新'}
|
- 新的印象描述: {new_impression if new_impression else "不更新"}
|
||||||
- 新的聊天风格: {new_style if new_style else '不更新'}
|
- 新的聊天风格: {new_style if new_style else "不更新"}
|
||||||
- 新的话题关键词: {new_topics if new_topics else '不更新'}
|
- 新的话题关键词: {new_topics if new_topics else "不更新"}
|
||||||
- 新的兴趣分数: {new_score if new_score is not None else '不更新'}
|
- 新的兴趣分数: {new_score if new_score is not None else "不更新"}
|
||||||
|
|
||||||
请综合考虑现有信息和新信息,决定最终的聊天流印象内容。注意:
|
请综合考虑现有信息和新信息,决定最终的聊天流印象内容。注意:
|
||||||
1. 印象描述:如果提供了新印象,应该综合现有印象和新印象,形成对这个聊天环境的整体认知(100-200字)
|
1. 印象描述:如果提供了新印象,应该综合现有印象和新印象,形成对这个聊天环境的整体认知(100-200字)
|
||||||
@@ -285,10 +296,26 @@ class ChatStreamImpressionTool(BaseTool):
|
|||||||
|
|
||||||
# 提取最终决定的数据
|
# 提取最终决定的数据
|
||||||
final_impression = {
|
final_impression = {
|
||||||
"stream_impression_text": response_data.get("stream_impression_text", existing_impression.get("stream_impression_text", "")),
|
"stream_impression_text": response_data.get(
|
||||||
"stream_chat_style": response_data.get("stream_chat_style", existing_impression.get("stream_chat_style", "")),
|
"stream_impression_text", existing_impression.get("stream_impression_text", "")
|
||||||
"stream_topic_keywords": response_data.get("stream_topic_keywords", existing_impression.get("stream_topic_keywords", "")),
|
),
|
||||||
"stream_interest_score": max(0.0, min(1.0, float(response_data.get("stream_interest_score", existing_impression.get("stream_interest_score", 0.5))))),
|
"stream_chat_style": response_data.get(
|
||||||
|
"stream_chat_style", existing_impression.get("stream_chat_style", "")
|
||||||
|
),
|
||||||
|
"stream_topic_keywords": response_data.get(
|
||||||
|
"stream_topic_keywords", existing_impression.get("stream_topic_keywords", "")
|
||||||
|
),
|
||||||
|
"stream_interest_score": max(
|
||||||
|
0.0,
|
||||||
|
min(
|
||||||
|
1.0,
|
||||||
|
float(
|
||||||
|
response_data.get(
|
||||||
|
"stream_interest_score", existing_impression.get("stream_interest_score", 0.5)
|
||||||
|
)
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info(f"LLM决策完成: {stream_id}")
|
logger.info(f"LLM决策完成: {stream_id}")
|
||||||
|
|||||||
@@ -362,13 +362,11 @@ class ChatterPlanExecutor:
|
|||||||
is_picid=False,
|
is_picid=False,
|
||||||
is_command=False,
|
is_command=False,
|
||||||
is_notify=False,
|
is_notify=False,
|
||||||
|
|
||||||
# 用户信息
|
# 用户信息
|
||||||
user_id=bot_user_id,
|
user_id=bot_user_id,
|
||||||
user_nickname=bot_nickname,
|
user_nickname=bot_nickname,
|
||||||
user_cardname=bot_nickname,
|
user_cardname=bot_nickname,
|
||||||
user_platform="qq",
|
user_platform="qq",
|
||||||
|
|
||||||
# 聊天上下文信息
|
# 聊天上下文信息
|
||||||
chat_info_user_id=chat_stream.user_info.user_id if chat_stream.user_info else bot_user_id,
|
chat_info_user_id=chat_stream.user_info.user_id if chat_stream.user_info else bot_user_id,
|
||||||
chat_info_user_nickname=chat_stream.user_info.user_nickname if chat_stream.user_info else bot_nickname,
|
chat_info_user_nickname=chat_stream.user_info.user_nickname if chat_stream.user_info else bot_nickname,
|
||||||
@@ -378,23 +376,22 @@ class ChatterPlanExecutor:
|
|||||||
chat_info_platform=chat_stream.platform,
|
chat_info_platform=chat_stream.platform,
|
||||||
chat_info_create_time=chat_stream.create_time,
|
chat_info_create_time=chat_stream.create_time,
|
||||||
chat_info_last_active_time=chat_stream.last_active_time,
|
chat_info_last_active_time=chat_stream.last_active_time,
|
||||||
|
|
||||||
# 群组信息(如果是群聊)
|
# 群组信息(如果是群聊)
|
||||||
chat_info_group_id=chat_stream.group_info.group_id if chat_stream.group_info else None,
|
chat_info_group_id=chat_stream.group_info.group_id if chat_stream.group_info else None,
|
||||||
chat_info_group_name=chat_stream.group_info.group_name if chat_stream.group_info else None,
|
chat_info_group_name=chat_stream.group_info.group_name if chat_stream.group_info else None,
|
||||||
chat_info_group_platform=getattr(chat_stream.group_info, "platform", None) if chat_stream.group_info else None,
|
chat_info_group_platform=getattr(chat_stream.group_info, "platform", None)
|
||||||
|
if chat_stream.group_info
|
||||||
|
else None,
|
||||||
# 动作信息
|
# 动作信息
|
||||||
actions=["bot_reply"],
|
actions=["bot_reply"],
|
||||||
should_reply=False,
|
should_reply=False,
|
||||||
should_act=False
|
should_act=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 添加到chat_stream的已读消息中
|
# 添加到chat_stream的已读消息中
|
||||||
chat_stream.context_manager.context.history_messages.append(bot_message)
|
chat_stream.context_manager.context.history_messages.append(bot_message)
|
||||||
logger.debug(f"机器人回复已添加到已读消息: {reply_content[:50]}...")
|
logger.debug(f"机器人回复已添加到已读消息: {reply_content[:50]}...")
|
||||||
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"添加机器人回复到已读消息时出错: {e}")
|
logger.error(f"添加机器人回复到已读消息时出错: {e}")
|
||||||
logger.debug(f"plan.chat_id: {plan.chat_id}")
|
logger.debug(f"plan.chat_id: {plan.chat_id}")
|
||||||
|
|||||||
@@ -106,7 +106,10 @@ class ChatterPlanFilter:
|
|||||||
actions_obj = item.get("actions", {})
|
actions_obj = item.get("actions", {})
|
||||||
|
|
||||||
# 记录决策历史
|
# 记录决策历史
|
||||||
if hasattr(global_config.chat, "enable_decision_history") and global_config.chat.enable_decision_history:
|
if (
|
||||||
|
hasattr(global_config.chat, "enable_decision_history")
|
||||||
|
and global_config.chat.enable_decision_history
|
||||||
|
):
|
||||||
action_types_to_log = []
|
action_types_to_log = []
|
||||||
actions_to_process_for_log = []
|
actions_to_process_for_log = []
|
||||||
if isinstance(actions_obj, dict):
|
if isinstance(actions_obj, dict):
|
||||||
@@ -121,7 +124,6 @@ class ChatterPlanFilter:
|
|||||||
if thinking != "未提供思考过程" and action_types_to_log:
|
if thinking != "未提供思考过程" and action_types_to_log:
|
||||||
await self._add_decision_to_history(plan, thinking, ", ".join(action_types_to_log))
|
await self._add_decision_to_history(plan, thinking, ", ".join(action_types_to_log))
|
||||||
|
|
||||||
|
|
||||||
# 处理actions字段可能是字典或列表的情况
|
# 处理actions字段可能是字典或列表的情况
|
||||||
if isinstance(actions_obj, dict):
|
if isinstance(actions_obj, dict):
|
||||||
action_type = actions_obj.get("action_type", "no_action")
|
action_type = actions_obj.get("action_type", "no_action")
|
||||||
@@ -593,6 +595,15 @@ class ChatterPlanFilter:
|
|||||||
):
|
):
|
||||||
reasoning = f"LLM 返回了当前不可用的动作 '{action}'。原始理由: {reasoning}"
|
reasoning = f"LLM 返回了当前不可用的动作 '{action}'。原始理由: {reasoning}"
|
||||||
action = "no_action"
|
action = "no_action"
|
||||||
|
# TODO:把逻辑迁移到DatabaseMessages(如果没人做下个星期我自己来)
|
||||||
|
# from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
|
|
||||||
|
# action_message_obj = None
|
||||||
|
# if target_message_obj:
|
||||||
|
# try:
|
||||||
|
# action_message_obj = DatabaseMessages(**target_message_obj)
|
||||||
|
# except Exception:
|
||||||
|
# logger.warning("无法将目标消息转换为DatabaseMessages对象")
|
||||||
|
|
||||||
parsed_actions.append(
|
parsed_actions.append(
|
||||||
ActionPlannerInfo(
|
ActionPlannerInfo(
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ from src.plugins.built_in.affinity_flow_chatter.plan_generator import ChatterPla
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from src.chat.planner_actions.action_manager import ChatterActionManager
|
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||||
from src.common.data_models.database_data_model import DatabaseMessages
|
|
||||||
from src.common.data_models.info_data_model import Plan
|
from src.common.data_models.info_data_model import Plan
|
||||||
from src.common.data_models.message_manager_data_model import StreamContext
|
from src.common.data_models.message_manager_data_model import StreamContext
|
||||||
|
|
||||||
@@ -184,10 +183,10 @@ class ChatterActionPlanner:
|
|||||||
for action in filtered_plan.decided_actions:
|
for action in filtered_plan.decided_actions:
|
||||||
if action.action_type in ["reply", "proactive_reply"] and action.action_message:
|
if action.action_type in ["reply", "proactive_reply"] and action.action_message:
|
||||||
# 提取目标消息ID
|
# 提取目标消息ID
|
||||||
if hasattr(action.action_message, 'message_id'):
|
if hasattr(action.action_message, "message_id"):
|
||||||
target_message_id = action.action_message.message_id
|
target_message_id = action.action_message.message_id
|
||||||
elif isinstance(action.action_message, dict):
|
elif isinstance(action.action_message, dict):
|
||||||
target_message_id = action.action_message.get('message_id')
|
target_message_id = action.action_message.get("message_id")
|
||||||
break
|
break
|
||||||
|
|
||||||
# 如果找到目标消息ID,检查是否已经在处理中
|
# 如果找到目标消息ID,检查是否已经在处理中
|
||||||
@@ -233,7 +232,7 @@ class ChatterActionPlanner:
|
|||||||
# 8. 清理处理标记
|
# 8. 清理处理标记
|
||||||
if context:
|
if context:
|
||||||
context.processing_message_id = None
|
context.processing_message_id = None
|
||||||
logger.debug(f"已清理处理标记,完成规划流程")
|
logger.debug("已清理处理标记,完成规划流程")
|
||||||
|
|
||||||
# 9. 返回结果
|
# 9. 返回结果
|
||||||
return self._build_return_result(filtered_plan)
|
return self._build_return_result(filtered_plan)
|
||||||
@@ -340,7 +339,7 @@ class ChatterActionPlanner:
|
|||||||
# 清理处理标记
|
# 清理处理标记
|
||||||
if context:
|
if context:
|
||||||
context.processing_message_id = None
|
context.processing_message_id = None
|
||||||
logger.debug(f"Normal模式: 已清理处理标记")
|
logger.debug("Normal模式: 已清理处理标记")
|
||||||
|
|
||||||
# 无论是否回复,都进行退出normal模式的判定
|
# 无论是否回复,都进行退出normal模式的判定
|
||||||
await self._check_exit_normal_mode(context)
|
await self._check_exit_normal_mode(context)
|
||||||
@@ -348,7 +347,7 @@ class ChatterActionPlanner:
|
|||||||
return [asdict(reply_action)], target_message_dict
|
return [asdict(reply_action)], target_message_dict
|
||||||
else:
|
else:
|
||||||
# 未达到reply阈值
|
# 未达到reply阈值
|
||||||
logger.debug(f"Normal模式: 未达到reply阈值")
|
logger.debug("Normal模式: 未达到reply阈值")
|
||||||
from src.common.data_models.info_data_model import ActionPlannerInfo
|
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||||
no_action = ActionPlannerInfo(
|
no_action = ActionPlannerInfo(
|
||||||
action_type="no_action",
|
action_type="no_action",
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ class ProactiveThinkingReplyHandler(BaseEventHandler):
|
|||||||
|
|
||||||
stream_id = kwargs.get("stream_id")
|
stream_id = kwargs.get("stream_id")
|
||||||
if not stream_id:
|
if not stream_id:
|
||||||
logger.debug(f"[主动思考事件] Reply事件缺少stream_id参数")
|
logger.debug("[主动思考事件] Reply事件缺少stream_id参数")
|
||||||
return HandlerResult(success=True, continue_process=True, message=None)
|
return HandlerResult(success=True, continue_process=True, message=None)
|
||||||
|
|
||||||
logger.debug(f"[主动思考事件] 收到 AFTER_SEND 事件,stream_id={stream_id}")
|
logger.debug(f"[主动思考事件] 收到 AFTER_SEND 事件,stream_id={stream_id}")
|
||||||
@@ -53,7 +53,7 @@ class ProactiveThinkingReplyHandler(BaseEventHandler):
|
|||||||
|
|
||||||
# 检查是否启用reply重置
|
# 检查是否启用reply重置
|
||||||
if not global_config.proactive_thinking.reply_reset_enabled:
|
if not global_config.proactive_thinking.reply_reset_enabled:
|
||||||
logger.debug(f"[主动思考事件] reply_reset_enabled 为 False,跳过重置")
|
logger.debug("[主动思考事件] reply_reset_enabled 为 False,跳过重置")
|
||||||
return HandlerResult(success=True, continue_process=True, message=None)
|
return HandlerResult(success=True, continue_process=True, message=None)
|
||||||
|
|
||||||
# 检查是否被暂停
|
# 检查是否被暂停
|
||||||
|
|||||||
@@ -5,11 +5,10 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Literal, Optional
|
from typing import Any, Literal
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
from src.chat.express.expression_learner import expression_learner_manager
|
|
||||||
from src.chat.express.expression_selector import expression_selector
|
from src.chat.express.expression_selector import expression_selector
|
||||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||||
from src.common.database.sqlalchemy_models import ChatStreams
|
from src.common.database.sqlalchemy_models import ChatStreams
|
||||||
@@ -17,7 +16,7 @@ from src.common.logger import get_logger
|
|||||||
from src.config.config import global_config, model_config
|
from src.config.config import global_config, model_config
|
||||||
from src.individuality.individuality import Individuality
|
from src.individuality.individuality import Individuality
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.plugin_system.apis import chat_api, message_api, send_api
|
from src.plugin_system.apis import message_api, send_api
|
||||||
|
|
||||||
logger = get_logger("proactive_thinking_executor")
|
logger = get_logger("proactive_thinking_executor")
|
||||||
|
|
||||||
@@ -35,19 +34,17 @@ class ProactiveThinkingPlanner:
|
|||||||
"""初始化规划器"""
|
"""初始化规划器"""
|
||||||
try:
|
try:
|
||||||
self.decision_llm = LLMRequest(
|
self.decision_llm = LLMRequest(
|
||||||
model_set=model_config.model_task_config.utils,
|
model_set=model_config.model_task_config.utils, request_type="proactive_thinking_decision"
|
||||||
request_type="proactive_thinking_decision"
|
|
||||||
)
|
)
|
||||||
self.reply_llm = LLMRequest(
|
self.reply_llm = LLMRequest(
|
||||||
model_set=model_config.model_task_config.replyer,
|
model_set=model_config.model_task_config.replyer, request_type="proactive_thinking_reply"
|
||||||
request_type="proactive_thinking_reply"
|
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"初始化LLM失败: {e}")
|
logger.error(f"初始化LLM失败: {e}")
|
||||||
self.decision_llm = None
|
self.decision_llm = None
|
||||||
self.reply_llm = None
|
self.reply_llm = None
|
||||||
|
|
||||||
async def gather_context(self, stream_id: str) -> Optional[dict[str, Any]]:
|
async def gather_context(self, stream_id: str) -> dict[str, Any] | None:
|
||||||
"""搜集聊天流的上下文信息
|
"""搜集聊天流的上下文信息
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -65,10 +62,7 @@ class ProactiveThinkingPlanner:
|
|||||||
|
|
||||||
# 2. 获取最近的聊天记录
|
# 2. 获取最近的聊天记录
|
||||||
recent_messages = await message_api.get_recent_messages(
|
recent_messages = await message_api.get_recent_messages(
|
||||||
chat_id=stream_id,
|
chat_id=stream_id, limit=20, limit_mode="latest", hours=24
|
||||||
limit=20,
|
|
||||||
limit_mode="latest",
|
|
||||||
hours=24
|
|
||||||
)
|
)
|
||||||
|
|
||||||
recent_chat_history = ""
|
recent_chat_history = ""
|
||||||
@@ -83,6 +77,7 @@ class ProactiveThinkingPlanner:
|
|||||||
current_mood = "感觉很平静" # 默认心情
|
current_mood = "感觉很平静" # 默认心情
|
||||||
try:
|
try:
|
||||||
from src.mood.mood_manager import mood_manager
|
from src.mood.mood_manager import mood_manager
|
||||||
|
|
||||||
mood_obj = mood_manager.get_mood_by_chat_id(stream_id)
|
mood_obj = mood_manager.get_mood_by_chat_id(stream_id)
|
||||||
if mood_obj:
|
if mood_obj:
|
||||||
await mood_obj._initialize() # 确保已初始化
|
await mood_obj._initialize() # 确保已初始化
|
||||||
@@ -97,6 +92,7 @@ class ProactiveThinkingPlanner:
|
|||||||
from src.plugins.built_in.affinity_flow_chatter.proactive_thinking_scheduler import (
|
from src.plugins.built_in.affinity_flow_chatter.proactive_thinking_scheduler import (
|
||||||
proactive_thinking_scheduler,
|
proactive_thinking_scheduler,
|
||||||
)
|
)
|
||||||
|
|
||||||
last_decision = proactive_thinking_scheduler.get_last_decision(stream_id)
|
last_decision = proactive_thinking_scheduler.get_last_decision(stream_id)
|
||||||
if last_decision:
|
if last_decision:
|
||||||
logger.debug(f"获取到聊天流 {stream_id} 的上次决策: {last_decision.get('action')}")
|
logger.debug(f"获取到聊天流 {stream_id} 的上次决策: {last_decision.get('action')}")
|
||||||
@@ -125,7 +121,7 @@ class ProactiveThinkingPlanner:
|
|||||||
logger.error(f"搜集上下文信息失败: {e}", exc_info=True)
|
logger.error(f"搜集上下文信息失败: {e}", exc_info=True)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def _get_stream_impression(self, stream_id: str) -> Optional[dict[str, Any]]:
|
async def _get_stream_impression(self, stream_id: str) -> dict[str, Any] | None:
|
||||||
"""从数据库获取聊天流印象数据"""
|
"""从数据库获取聊天流印象数据"""
|
||||||
try:
|
try:
|
||||||
async with get_db_session() as session:
|
async with get_db_session() as session:
|
||||||
@@ -141,16 +137,16 @@ class ProactiveThinkingPlanner:
|
|||||||
"stream_impression_text": stream.stream_impression_text or "",
|
"stream_impression_text": stream.stream_impression_text or "",
|
||||||
"stream_chat_style": stream.stream_chat_style or "",
|
"stream_chat_style": stream.stream_chat_style or "",
|
||||||
"stream_topic_keywords": stream.stream_topic_keywords or "",
|
"stream_topic_keywords": stream.stream_topic_keywords or "",
|
||||||
"stream_interest_score": float(stream.stream_interest_score) if stream.stream_interest_score else 0.5,
|
"stream_interest_score": float(stream.stream_interest_score)
|
||||||
|
if stream.stream_interest_score
|
||||||
|
else 0.5,
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取聊天流印象失败: {e}")
|
logger.error(f"获取聊天流印象失败: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def make_decision(
|
async def make_decision(self, context: dict[str, Any]) -> dict[str, Any] | None:
|
||||||
self, context: dict[str, Any]
|
|
||||||
) -> Optional[dict[str, Any]]:
|
|
||||||
"""使用LLM进行决策
|
"""使用LLM进行决策
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -183,9 +179,7 @@ class ProactiveThinkingPlanner:
|
|||||||
cleaned_response = self._clean_json_response(response)
|
cleaned_response = self._clean_json_response(response)
|
||||||
decision = json.loads(cleaned_response)
|
decision = json.loads(cleaned_response)
|
||||||
|
|
||||||
logger.info(
|
logger.info(f"决策结果: {decision.get('action', 'unknown')} - {decision.get('reasoning', '无理由')}")
|
||||||
f"决策结果: {decision.get('action', 'unknown')} - {decision.get('reasoning', '无理由')}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return decision
|
return decision
|
||||||
|
|
||||||
@@ -202,12 +196,12 @@ class ProactiveThinkingPlanner:
|
|||||||
"""构建决策提示词"""
|
"""构建决策提示词"""
|
||||||
# 构建上次决策信息
|
# 构建上次决策信息
|
||||||
last_decision_text = ""
|
last_decision_text = ""
|
||||||
if context.get('last_decision'):
|
if context.get("last_decision"):
|
||||||
last_dec = context['last_decision']
|
last_dec = context["last_decision"]
|
||||||
last_action = last_dec.get('action', '未知')
|
last_action = last_dec.get("action", "未知")
|
||||||
last_reasoning = last_dec.get('reasoning', '无')
|
last_reasoning = last_dec.get("reasoning", "无")
|
||||||
last_topic = last_dec.get('topic')
|
last_topic = last_dec.get("topic")
|
||||||
last_time = last_dec.get('timestamp', '未知')
|
last_time = last_dec.get("timestamp", "未知")
|
||||||
|
|
||||||
last_decision_text = f"""
|
last_decision_text = f"""
|
||||||
【上次主动思考的决策】
|
【上次主动思考的决策】
|
||||||
@@ -218,22 +212,22 @@ class ProactiveThinkingPlanner:
|
|||||||
last_decision_text += f"\n- 话题: {last_topic}"
|
last_decision_text += f"\n- 话题: {last_topic}"
|
||||||
|
|
||||||
return f"""你是一个有着独特个性的AI助手。你的人设是:
|
return f"""你是一个有着独特个性的AI助手。你的人设是:
|
||||||
{context['bot_personality']}
|
{context["bot_personality"]}
|
||||||
|
|
||||||
现在是 {context['current_time']},你正在考虑是否要主动在 "{context['stream_name']}" 中说些什么。
|
现在是 {context["current_time"]},你正在考虑是否要主动在 "{context["stream_name"]}" 中说些什么。
|
||||||
|
|
||||||
【你当前的心情】
|
【你当前的心情】
|
||||||
{context.get('current_mood', '感觉很平静')}
|
{context.get("current_mood", "感觉很平静")}
|
||||||
|
|
||||||
【聊天环境信息】
|
【聊天环境信息】
|
||||||
- 整体印象: {context['stream_impression']}
|
- 整体印象: {context["stream_impression"]}
|
||||||
- 聊天风格: {context['chat_style']}
|
- 聊天风格: {context["chat_style"]}
|
||||||
- 常见话题: {context['topic_keywords'] or '暂无'}
|
- 常见话题: {context["topic_keywords"] or "暂无"}
|
||||||
- 你的兴趣程度: {context['interest_score']:.2f}/1.0
|
- 你的兴趣程度: {context["interest_score"]:.2f}/1.0
|
||||||
{last_decision_text}
|
{last_decision_text}
|
||||||
|
|
||||||
【最近的聊天记录】
|
【最近的聊天记录】
|
||||||
{context['recent_chat_history']}
|
{context["recent_chat_history"]}
|
||||||
|
|
||||||
请根据以上信息(包括你的心情和上次决策),决定你现在应该做什么:
|
请根据以上信息(包括你的心情和上次决策),决定你现在应该做什么:
|
||||||
|
|
||||||
@@ -269,11 +263,8 @@ class ProactiveThinkingPlanner:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
async def generate_reply(
|
async def generate_reply(
|
||||||
self,
|
self, context: dict[str, Any], action: Literal["simple_bubble", "throw_topic"], topic: str | None = None
|
||||||
context: dict[str, Any],
|
) -> str | None:
|
||||||
action: Literal["simple_bubble", "throw_topic"],
|
|
||||||
topic: Optional[str] = None
|
|
||||||
) -> Optional[str]:
|
|
||||||
"""生成回复内容
|
"""生成回复内容
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -324,7 +315,7 @@ class ProactiveThinkingPlanner:
|
|||||||
chat_history=chat_history,
|
chat_history=chat_history,
|
||||||
target_message=None, # 主动思考没有target message
|
target_message=None, # 主动思考没有target message
|
||||||
max_num=6, # 主动思考时使用较少的表达方式
|
max_num=6, # 主动思考时使用较少的表达方式
|
||||||
min_num=2
|
min_num=2,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not selected_expressions:
|
if not selected_expressions:
|
||||||
@@ -357,33 +348,29 @@ class ProactiveThinkingPlanner:
|
|||||||
return ""
|
return ""
|
||||||
|
|
||||||
async def _build_reply_prompt(
|
async def _build_reply_prompt(
|
||||||
self,
|
self, context: dict[str, Any], action: Literal["simple_bubble", "throw_topic"], topic: str | None
|
||||||
context: dict[str, Any],
|
|
||||||
action: Literal["simple_bubble", "throw_topic"],
|
|
||||||
topic: Optional[str]
|
|
||||||
) -> str:
|
) -> str:
|
||||||
"""构建回复提示词"""
|
"""构建回复提示词"""
|
||||||
# 获取表达方式参考
|
# 获取表达方式参考
|
||||||
expression_habits = await self._get_expression_habits(
|
expression_habits = await self._get_expression_habits(
|
||||||
stream_id=context.get('stream_id', ''),
|
stream_id=context.get("stream_id", ""), chat_history=context.get("recent_chat_history", "")
|
||||||
chat_history=context.get('recent_chat_history', '')
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if action == "simple_bubble":
|
if action == "simple_bubble":
|
||||||
return f"""你是一个有着独特个性的AI助手。你的人设是:
|
return f"""你是一个有着独特个性的AI助手。你的人设是:
|
||||||
{context['bot_personality']}
|
{context["bot_personality"]}
|
||||||
|
|
||||||
现在是 {context['current_time']},你决定在 "{context['stream_name']}" 中简单冒个泡。
|
现在是 {context["current_time"]},你决定在 "{context["stream_name"]}" 中简单冒个泡。
|
||||||
|
|
||||||
【你当前的心情】
|
【你当前的心情】
|
||||||
{context.get('current_mood', '感觉很平静')}
|
{context.get("current_mood", "感觉很平静")}
|
||||||
|
|
||||||
【聊天环境】
|
【聊天环境】
|
||||||
- 整体印象: {context['stream_impression']}
|
- 整体印象: {context["stream_impression"]}
|
||||||
- 聊天风格: {context['chat_style']}
|
- 聊天风格: {context["chat_style"]}
|
||||||
|
|
||||||
【最近的聊天记录】
|
【最近的聊天记录】
|
||||||
{context['recent_chat_history']}
|
{context["recent_chat_history"]}
|
||||||
{expression_habits}
|
{expression_habits}
|
||||||
请生成一条简短的消息,用于水群。要求:
|
请生成一条简短的消息,用于水群。要求:
|
||||||
1. 非常简短(5-15字)
|
1. 非常简短(5-15字)
|
||||||
@@ -397,20 +384,20 @@ class ProactiveThinkingPlanner:
|
|||||||
|
|
||||||
else: # throw_topic
|
else: # throw_topic
|
||||||
return f"""你是一个有着独特个性的AI助手。你的人设是:
|
return f"""你是一个有着独特个性的AI助手。你的人设是:
|
||||||
{context['bot_personality']}
|
{context["bot_personality"]}
|
||||||
|
|
||||||
现在是 {context['current_time']},你决定在 "{context['stream_name']}" 中抛出一个话题。
|
现在是 {context["current_time"]},你决定在 "{context["stream_name"]}" 中抛出一个话题。
|
||||||
|
|
||||||
【你当前的心情】
|
【你当前的心情】
|
||||||
{context.get('current_mood', '感觉很平静')}
|
{context.get("current_mood", "感觉很平静")}
|
||||||
|
|
||||||
【聊天环境】
|
【聊天环境】
|
||||||
- 整体印象: {context['stream_impression']}
|
- 整体印象: {context["stream_impression"]}
|
||||||
- 聊天风格: {context['chat_style']}
|
- 聊天风格: {context["chat_style"]}
|
||||||
- 常见话题: {context['topic_keywords'] or '暂无'}
|
- 常见话题: {context["topic_keywords"] or "暂无"}
|
||||||
|
|
||||||
【最近的聊天记录】
|
【最近的聊天记录】
|
||||||
{context['recent_chat_history']}
|
{context["recent_chat_history"]}
|
||||||
|
|
||||||
【你想抛出的话题】
|
【你想抛出的话题】
|
||||||
{topic}
|
{topic}
|
||||||
@@ -471,7 +458,7 @@ def _update_statistics(stream_id: str, action: str):
|
|||||||
_statistics[stream_id]["last_execution_time"] = datetime.now().isoformat()
|
_statistics[stream_id]["last_execution_time"] = datetime.now().isoformat()
|
||||||
|
|
||||||
|
|
||||||
def get_statistics(stream_id: Optional[str] = None) -> dict[str, Any]:
|
def get_statistics(stream_id: str | None = None) -> dict[str, Any]:
|
||||||
"""获取统计数据
|
"""获取统计数据
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -503,31 +490,31 @@ async def execute_proactive_thinking(stream_id: str):
|
|||||||
try:
|
try:
|
||||||
# 0. 前置检查
|
# 0. 前置检查
|
||||||
if proactive_thinking_scheduler._is_in_quiet_hours():
|
if proactive_thinking_scheduler._is_in_quiet_hours():
|
||||||
logger.debug(f"安静时段,跳过")
|
logger.debug("安静时段,跳过")
|
||||||
return
|
return
|
||||||
|
|
||||||
if not proactive_thinking_scheduler._check_daily_limit(stream_id):
|
if not proactive_thinking_scheduler._check_daily_limit(stream_id):
|
||||||
logger.debug(f"今日发言达上限")
|
logger.debug("今日发言达上限")
|
||||||
return
|
return
|
||||||
|
|
||||||
# 1. 搜集信息
|
# 1. 搜集信息
|
||||||
logger.debug(f"步骤1: 搜集上下文")
|
logger.debug("步骤1: 搜集上下文")
|
||||||
context = await _planner.gather_context(stream_id)
|
context = await _planner.gather_context(stream_id)
|
||||||
if not context:
|
if not context:
|
||||||
logger.warning(f"无法搜集上下文,跳过")
|
logger.warning("无法搜集上下文,跳过")
|
||||||
return
|
return
|
||||||
|
|
||||||
# 检查兴趣分数阈值
|
# 检查兴趣分数阈值
|
||||||
interest_score = context.get('interest_score', 0.5)
|
interest_score = context.get("interest_score", 0.5)
|
||||||
if not proactive_thinking_scheduler._check_interest_score_threshold(interest_score):
|
if not proactive_thinking_scheduler._check_interest_score_threshold(interest_score):
|
||||||
logger.debug(f"兴趣分数不在阈值范围内")
|
logger.debug("兴趣分数不在阈值范围内")
|
||||||
return
|
return
|
||||||
|
|
||||||
# 2. 进行决策
|
# 2. 进行决策
|
||||||
logger.debug(f"步骤2: LLM决策")
|
logger.debug("步骤2: LLM决策")
|
||||||
decision = await _planner.make_decision(context)
|
decision = await _planner.make_decision(context)
|
||||||
if not decision:
|
if not decision:
|
||||||
logger.warning(f"决策失败,跳过")
|
logger.warning("决策失败,跳过")
|
||||||
return
|
return
|
||||||
|
|
||||||
action = decision.get("action", "do_nothing")
|
action = decision.get("action", "do_nothing")
|
||||||
@@ -549,14 +536,14 @@ async def execute_proactive_thinking(stream_id: str):
|
|||||||
proactive_thinking_scheduler.record_decision(stream_id, action, reasoning, None)
|
proactive_thinking_scheduler.record_decision(stream_id, action, reasoning, None)
|
||||||
|
|
||||||
# 生成简单的消息
|
# 生成简单的消息
|
||||||
logger.debug(f"步骤3: 生成冒泡回复")
|
logger.debug("步骤3: 生成冒泡回复")
|
||||||
reply = await _planner.generate_reply(context, "simple_bubble")
|
reply = await _planner.generate_reply(context, "simple_bubble")
|
||||||
if reply:
|
if reply:
|
||||||
await send_api.text_to_stream(
|
await send_api.text_to_stream(
|
||||||
stream_id=stream_id,
|
stream_id=stream_id,
|
||||||
text=reply,
|
text=reply,
|
||||||
)
|
)
|
||||||
logger.info(f"✅ 已发送冒泡消息")
|
logger.info("✅ 已发送冒泡消息")
|
||||||
|
|
||||||
# 增加每日计数
|
# 增加每日计数
|
||||||
proactive_thinking_scheduler._increment_daily_count(stream_id)
|
proactive_thinking_scheduler._increment_daily_count(stream_id)
|
||||||
@@ -568,11 +555,11 @@ async def execute_proactive_thinking(stream_id: str):
|
|||||||
# 冒泡后暂停主动思考,等待用户回复
|
# 冒泡后暂停主动思考,等待用户回复
|
||||||
# 使用与 topic_throw 相同的冷却时间配置
|
# 使用与 topic_throw 相同的冷却时间配置
|
||||||
if config.topic_throw_cooldown > 0:
|
if config.topic_throw_cooldown > 0:
|
||||||
logger.info(f"[主动思考] 步骤5:暂停任务")
|
logger.info("[主动思考] 步骤5:暂停任务")
|
||||||
await proactive_thinking_scheduler.pause_proactive_thinking(stream_id, reason="已冒泡")
|
await proactive_thinking_scheduler.pause_proactive_thinking(stream_id, reason="已冒泡")
|
||||||
logger.info(f"[主动思考] 已暂停聊天流 {stream_id} 的主动思考,等待用户回复")
|
logger.info(f"[主动思考] 已暂停聊天流 {stream_id} 的主动思考,等待用户回复")
|
||||||
|
|
||||||
logger.info(f"[主动思考] simple_bubble 执行完成")
|
logger.info("[主动思考] simple_bubble 执行完成")
|
||||||
|
|
||||||
elif action == "throw_topic":
|
elif action == "throw_topic":
|
||||||
topic = decision.get("topic", "")
|
topic = decision.get("topic", "")
|
||||||
@@ -583,15 +570,15 @@ async def execute_proactive_thinking(stream_id: str):
|
|||||||
|
|
||||||
if not topic:
|
if not topic:
|
||||||
logger.warning("[主动思考] 选择了抛出话题但未提供话题内容,降级为冒泡")
|
logger.warning("[主动思考] 选择了抛出话题但未提供话题内容,降级为冒泡")
|
||||||
logger.info(f"[主动思考] 步骤3:生成降级冒泡回复")
|
logger.info("[主动思考] 步骤3:生成降级冒泡回复")
|
||||||
reply = await _planner.generate_reply(context, "simple_bubble")
|
reply = await _planner.generate_reply(context, "simple_bubble")
|
||||||
else:
|
else:
|
||||||
# 生成基于话题的消息
|
# 生成基于话题的消息
|
||||||
logger.info(f"[主动思考] 步骤3:生成话题回复")
|
logger.info("[主动思考] 步骤3:生成话题回复")
|
||||||
reply = await _planner.generate_reply(context, "throw_topic", topic)
|
reply = await _planner.generate_reply(context, "throw_topic", topic)
|
||||||
|
|
||||||
if reply:
|
if reply:
|
||||||
logger.info(f"[主动思考] 步骤4:发送消息")
|
logger.info("[主动思考] 步骤4:发送消息")
|
||||||
await send_api.text_to_stream(
|
await send_api.text_to_stream(
|
||||||
stream_id=stream_id,
|
stream_id=stream_id,
|
||||||
text=reply,
|
text=reply,
|
||||||
@@ -607,11 +594,11 @@ async def execute_proactive_thinking(stream_id: str):
|
|||||||
|
|
||||||
# 抛出话题后暂停主动思考(如果配置了冷却时间)
|
# 抛出话题后暂停主动思考(如果配置了冷却时间)
|
||||||
if config.topic_throw_cooldown > 0:
|
if config.topic_throw_cooldown > 0:
|
||||||
logger.info(f"[主动思考] 步骤5:暂停任务")
|
logger.info("[主动思考] 步骤5:暂停任务")
|
||||||
await proactive_thinking_scheduler.pause_proactive_thinking(stream_id, reason="已抛出话题")
|
await proactive_thinking_scheduler.pause_proactive_thinking(stream_id, reason="已抛出话题")
|
||||||
logger.info(f"[主动思考] 已暂停聊天流 {stream_id} 的主动思考,等待用户回复")
|
logger.info(f"[主动思考] 已暂停聊天流 {stream_id} 的主动思考,等待用户回复")
|
||||||
|
|
||||||
logger.info(f"[主动思考] throw_topic 执行完成")
|
logger.info("[主动思考] throw_topic 执行完成")
|
||||||
|
|
||||||
logger.info(f"[主动思考] 聊天流 {stream_id} 的主动思考执行完成")
|
logger.info(f"[主动思考] 聊天流 {stream_id} 的主动思考执行完成")
|
||||||
|
|
||||||
|
|||||||
@@ -6,13 +6,10 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Any, Optional
|
from typing import Any
|
||||||
|
|
||||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
|
||||||
from src.common.database.sqlalchemy_models import ChatStreams
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.schedule.unified_scheduler import TriggerType, unified_scheduler
|
from src.schedule.unified_scheduler import TriggerType, unified_scheduler
|
||||||
from sqlalchemy import select
|
|
||||||
|
|
||||||
logger = get_logger("proactive_thinking_scheduler")
|
logger = get_logger("proactive_thinking_scheduler")
|
||||||
|
|
||||||
@@ -42,6 +39,7 @@ class ProactiveThinkingScheduler:
|
|||||||
|
|
||||||
# 从全局配置加载(延迟导入避免循环依赖)
|
# 从全局配置加载(延迟导入避免循环依赖)
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
||||||
self.config = global_config.proactive_thinking
|
self.config = global_config.proactive_thinking
|
||||||
|
|
||||||
def _calculate_interval(self, focus_energy: float) -> int:
|
def _calculate_interval(self, focus_energy: float) -> int:
|
||||||
@@ -208,14 +206,14 @@ class ProactiveThinkingScheduler:
|
|||||||
# 从聊天管理器获取聊天流
|
# 从聊天管理器获取聊天流
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
|
|
||||||
logger.debug(f"[调度器] 获取聊天管理器")
|
logger.debug("[调度器] 获取聊天管理器")
|
||||||
chat_manager = get_chat_manager()
|
chat_manager = get_chat_manager()
|
||||||
logger.debug(f"[调度器] 从聊天管理器获取聊天流 {stream_id}")
|
logger.debug(f"[调度器] 从聊天管理器获取聊天流 {stream_id}")
|
||||||
chat_stream = await chat_manager.get_stream(stream_id)
|
chat_stream = await chat_manager.get_stream(stream_id)
|
||||||
|
|
||||||
if chat_stream:
|
if chat_stream:
|
||||||
# 计算并获取最新的 focus_energy
|
# 计算并获取最新的 focus_energy
|
||||||
logger.debug(f"[调度器] 找到聊天流,开始计算 focus_energy")
|
logger.debug("[调度器] 找到聊天流,开始计算 focus_energy")
|
||||||
focus_energy = await chat_stream.calculate_focus_energy()
|
focus_energy = await chat_stream.calculate_focus_energy()
|
||||||
logger.info(f"[调度器] 聊天流 {stream_id} 的 focus_energy: {focus_energy:.3f}")
|
logger.info(f"[调度器] 聊天流 {stream_id} 的 focus_energy: {focus_energy:.3f}")
|
||||||
return focus_energy
|
return focus_energy
|
||||||
@@ -319,7 +317,7 @@ class ProactiveThinkingScheduler:
|
|||||||
|
|
||||||
return success
|
return success
|
||||||
|
|
||||||
except Exception as e:
|
except Exception:
|
||||||
# 错误日志已在上面记录
|
# 错误日志已在上面记录
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -389,7 +387,7 @@ class ProactiveThinkingScheduler:
|
|||||||
async with self._lock:
|
async with self._lock:
|
||||||
return stream_id in self._paused_streams
|
return stream_id in self._paused_streams
|
||||||
|
|
||||||
async def get_task_info(self, stream_id: str) -> Optional[dict[str, Any]]:
|
async def get_task_info(self, stream_id: str) -> dict[str, Any] | None:
|
||||||
"""获取聊天流的主动思考任务信息
|
"""获取聊天流的主动思考任务信息
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -456,10 +454,7 @@ class ProactiveThinkingScheduler:
|
|||||||
return
|
return
|
||||||
|
|
||||||
# 按下次触发时间排序
|
# 按下次触发时间排序
|
||||||
tasks_sorted = sorted(
|
tasks_sorted = sorted(tasks, key=lambda x: x.get("next_run_time", datetime.max) or datetime.max)
|
||||||
tasks,
|
|
||||||
key=lambda x: x.get("next_run_time", datetime.max) or datetime.max
|
|
||||||
)
|
|
||||||
|
|
||||||
# 限制显示数量
|
# 限制显示数量
|
||||||
if max_streams > 0:
|
if max_streams > 0:
|
||||||
@@ -500,15 +495,12 @@ class ProactiveThinkingScheduler:
|
|||||||
f" 下次触发: {next_run.strftime('%Y-%m-%d %H:%M:%S')} ({time_str})"
|
f" 下次触发: {next_run.strftime('%Y-%m-%d %H:%M:%S')} ({time_str})"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.info(f"[{i:2d}] ⚠️ 未知 | {stream_name}\n 下次触发: 未设置")
|
||||||
f"[{i:2d}] ⚠️ 未知 | {stream_name}\n"
|
|
||||||
f" 下次触发: 未设置"
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info("")
|
logger.info("")
|
||||||
logger.info("=" * 60)
|
logger.info("=" * 60)
|
||||||
|
|
||||||
def get_last_decision(self, stream_id: str) -> Optional[dict[str, Any]]:
|
def get_last_decision(self, stream_id: str) -> dict[str, Any] | None:
|
||||||
"""获取聊天流的上次主动思考决策
|
"""获取聊天流的上次主动思考决策
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -524,13 +516,7 @@ class ProactiveThinkingScheduler:
|
|||||||
"""
|
"""
|
||||||
return self._last_decisions.get(stream_id)
|
return self._last_decisions.get(stream_id)
|
||||||
|
|
||||||
def record_decision(
|
def record_decision(self, stream_id: str, action: str, reasoning: str, topic: str | None = None) -> None:
|
||||||
self,
|
|
||||||
stream_id: str,
|
|
||||||
action: str,
|
|
||||||
reasoning: str,
|
|
||||||
topic: Optional[str] = None
|
|
||||||
) -> None:
|
|
||||||
"""记录聊天流的主动思考决策
|
"""记录聊天流的主动思考决策
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -4,10 +4,10 @@
|
|||||||
通过LLM二步调用机制更新用户画像信息,包括别名、主观印象、偏好关键词和好感分数
|
通过LLM二步调用机制更新用户画像信息,包括别名、主观印象、偏好关键词和好感分数
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import orjson
|
|
||||||
import time
|
import time
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
import orjson
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||||
@@ -99,7 +99,7 @@ class UserProfileTool(BaseTool):
|
|||||||
return {
|
return {
|
||||||
"type": "info",
|
"type": "info",
|
||||||
"id": target_user_id,
|
"id": target_user_id,
|
||||||
"content": f"提示:需要提供至少一项更新内容(别名、印象描述、偏好关键词或好感分数)"
|
"content": "提示:需要提供至少一项更新内容(别名、印象描述、偏好关键词或好感分数)"
|
||||||
}
|
}
|
||||||
|
|
||||||
# 调用LLM进行二步决策
|
# 调用LLM进行二步决策
|
||||||
@@ -155,7 +155,7 @@ class UserProfileTool(BaseTool):
|
|||||||
return {
|
return {
|
||||||
"type": "error",
|
"type": "error",
|
||||||
"id": function_args.get("target_user_id", "unknown"),
|
"id": function_args.get("target_user_id", "unknown"),
|
||||||
"content": f"用户画像更新失败: {str(e)}"
|
"content": f"用户画像更新失败: {e!s}"
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _get_user_profile(self, user_id: str) -> dict[str, Any]:
|
async def _get_user_profile(self, user_id: str) -> dict[str, Any]:
|
||||||
|
|||||||
@@ -5,9 +5,10 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime, timedelta
|
from collections.abc import Awaitable, Callable
|
||||||
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Awaitable, Callable, Optional
|
from typing import Any
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.plugin_system.base.component_types import EventType
|
from src.plugin_system.base.component_types import EventType
|
||||||
@@ -33,9 +34,9 @@ class ScheduleTask:
|
|||||||
trigger_type: TriggerType,
|
trigger_type: TriggerType,
|
||||||
trigger_config: dict[str, Any],
|
trigger_config: dict[str, Any],
|
||||||
is_recurring: bool = False,
|
is_recurring: bool = False,
|
||||||
task_name: Optional[str] = None,
|
task_name: str | None = None,
|
||||||
callback_args: Optional[tuple] = None,
|
callback_args: tuple | None = None,
|
||||||
callback_kwargs: Optional[dict] = None,
|
callback_kwargs: dict | None = None,
|
||||||
):
|
):
|
||||||
self.schedule_id = schedule_id
|
self.schedule_id = schedule_id
|
||||||
self.callback = callback
|
self.callback = callback
|
||||||
@@ -46,7 +47,7 @@ class ScheduleTask:
|
|||||||
self.callback_args = callback_args or ()
|
self.callback_args = callback_args or ()
|
||||||
self.callback_kwargs = callback_kwargs or {}
|
self.callback_kwargs = callback_kwargs or {}
|
||||||
self.created_at = datetime.now()
|
self.created_at = datetime.now()
|
||||||
self.last_triggered_at: Optional[datetime] = None
|
self.last_triggered_at: datetime | None = None
|
||||||
self.trigger_count = 0
|
self.trigger_count = 0
|
||||||
self.is_active = True
|
self.is_active = True
|
||||||
|
|
||||||
@@ -77,7 +78,7 @@ class UnifiedScheduler:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._tasks: dict[str, ScheduleTask] = {}
|
self._tasks: dict[str, ScheduleTask] = {}
|
||||||
self._running = False
|
self._running = False
|
||||||
self._check_task: Optional[asyncio.Task] = None
|
self._check_task: asyncio.Task | None = None
|
||||||
self._lock = asyncio.Lock()
|
self._lock = asyncio.Lock()
|
||||||
self._event_subscriptions: set[str] = set() # 追踪已订阅的事件
|
self._event_subscriptions: set[str] = set() # 追踪已订阅的事件
|
||||||
|
|
||||||
@@ -339,9 +340,9 @@ class UnifiedScheduler:
|
|||||||
trigger_type: TriggerType,
|
trigger_type: TriggerType,
|
||||||
trigger_config: dict[str, Any],
|
trigger_config: dict[str, Any],
|
||||||
is_recurring: bool = False,
|
is_recurring: bool = False,
|
||||||
task_name: Optional[str] = None,
|
task_name: str | None = None,
|
||||||
callback_args: Optional[tuple] = None,
|
callback_args: tuple | None = None,
|
||||||
callback_kwargs: Optional[dict] = None,
|
callback_kwargs: dict | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""创建调度任务(详细注释见文档)"""
|
"""创建调度任务(详细注释见文档)"""
|
||||||
schedule_id = str(uuid.uuid4())
|
schedule_id = str(uuid.uuid4())
|
||||||
@@ -430,7 +431,7 @@ class UnifiedScheduler:
|
|||||||
logger.info(f"恢复任务: {task.task_name} (ID: {schedule_id[:8]}...)")
|
logger.info(f"恢复任务: {task.task_name} (ID: {schedule_id[:8]}...)")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def get_task_info(self, schedule_id: str) -> Optional[dict[str, Any]]:
|
async def get_task_info(self, schedule_id: str) -> dict[str, Any] | None:
|
||||||
"""获取任务信息"""
|
"""获取任务信息"""
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
task = self._tasks.get(schedule_id)
|
task = self._tasks.get(schedule_id)
|
||||||
@@ -449,7 +450,7 @@ class UnifiedScheduler:
|
|||||||
"trigger_config": task.trigger_config.copy(),
|
"trigger_config": task.trigger_config.copy(),
|
||||||
}
|
}
|
||||||
|
|
||||||
async def list_tasks(self, trigger_type: Optional[TriggerType] = None) -> list[dict[str, Any]]:
|
async def list_tasks(self, trigger_type: TriggerType | None = None) -> list[dict[str, Any]]:
|
||||||
"""列出所有任务或指定类型的任务"""
|
"""列出所有任务或指定类型的任务"""
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
tasks = []
|
tasks = []
|
||||||
|
|||||||
Reference in New Issue
Block a user