fix:修复表达提取无法提高count的问题
This commit is contained in:
@@ -88,7 +88,7 @@ def write_expressions(f, expressions: List[Dict[str, Any]], title: str):
|
||||
last_active = expr.get("last_active_time", time.time())
|
||||
f.write(f"场景: {expr['situation']}\n")
|
||||
f.write(f"表达: {expr['style']}\n")
|
||||
f.write(f"计数: {count:.2f}\n")
|
||||
f.write(f"计数: {count:.4f}\n")
|
||||
f.write(f"最后活跃: {format_time(last_active)}\n")
|
||||
f.write("-" * 40 + "\n")
|
||||
|
||||
|
||||
@@ -1,146 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
HFC性能记录功能测试脚本
|
||||
"""
|
||||
|
||||
import sys
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
# 添加项目根目录到Python路径
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from src.chat.focus_chat.hfc_performance_logger import HFCPerformanceLogger
|
||||
from src.chat.focus_chat.hfc_version_manager import set_hfc_version, get_hfc_version, auto_generate_hfc_version
|
||||
|
||||
|
||||
def test_performance_logger():
|
||||
"""测试性能记录器功能"""
|
||||
|
||||
# 设置测试版本号
|
||||
test_version = "v1.2.3_test"
|
||||
set_hfc_version(test_version)
|
||||
print(f"设置测试版本号: {test_version}")
|
||||
print(f"当前版本号: {get_hfc_version()}")
|
||||
|
||||
# 创建测试用的性能记录器
|
||||
test_chat_id = "test_chat_123"
|
||||
logger = HFCPerformanceLogger(test_chat_id, test_version)
|
||||
|
||||
print(f"测试 HFC 性能记录器 - Chat ID: {test_chat_id}, Version: {logger.version}")
|
||||
|
||||
# 模拟记录几个循环的数据
|
||||
test_cycles = [
|
||||
{
|
||||
"cycle_id": 1,
|
||||
"action_type": "reply",
|
||||
"total_time": 2.5,
|
||||
"step_times": {"观察": 0.1, "并行调整动作、处理": 1.2, "规划器": 0.8, "执行动作": 0.4},
|
||||
"reasoning": "用户询问天气,需要回复",
|
||||
"success": True,
|
||||
},
|
||||
{
|
||||
"cycle_id": 2,
|
||||
"action_type": "no_reply",
|
||||
"total_time": 1.8,
|
||||
"step_times": {"观察": 0.08, "并行调整动作、处理": 0.9, "规划器": 0.6, "执行动作": 0.22},
|
||||
"reasoning": "无需回复的日常对话",
|
||||
"success": True,
|
||||
},
|
||||
{
|
||||
"cycle_id": 3,
|
||||
"action_type": "reply",
|
||||
"total_time": 3.2,
|
||||
"step_times": {"观察": 0.12, "并行调整动作、处理": 1.5, "规划器": 1.1, "执行动作": 0.48},
|
||||
"reasoning": "用户提出复杂问题,需要详细回复",
|
||||
"success": True,
|
||||
},
|
||||
{
|
||||
"cycle_id": 4,
|
||||
"action_type": "no_reply",
|
||||
"total_time": 1.5,
|
||||
"step_times": {"观察": 0.07, "并行调整动作、处理": 0.8, "规划器": 0.5, "执行动作": 0.13},
|
||||
"reasoning": "群聊中的无关对话",
|
||||
"success": True,
|
||||
},
|
||||
{
|
||||
"cycle_id": 5,
|
||||
"action_type": "error",
|
||||
"total_time": 0.5,
|
||||
"step_times": {"观察": 0.05, "并行调整动作、处理": 0.2, "规划器": 0.15, "执行动作": 0.1},
|
||||
"reasoning": "处理过程中出现错误",
|
||||
"success": False,
|
||||
},
|
||||
]
|
||||
|
||||
# 记录测试数据
|
||||
for cycle_data in test_cycles:
|
||||
logger.record_cycle(cycle_data)
|
||||
print(f"已记录循环 {cycle_data['cycle_id']}: {cycle_data['action_type']} ({cycle_data['total_time']:.1f}s)")
|
||||
|
||||
# 获取当前会话统计
|
||||
current_stats = logger.get_current_session_stats()
|
||||
print("\n=== 当前会话统计 ===")
|
||||
print(json.dumps(current_stats, ensure_ascii=False, indent=2))
|
||||
|
||||
# 完成会话
|
||||
logger.finalize_session()
|
||||
print("\n=== 会话已完成 ===")
|
||||
print(f"日志文件: {logger.session_file}")
|
||||
print(f"统计文件: {logger.stats_file}")
|
||||
|
||||
# 检查生成的文件
|
||||
if logger.session_file.exists():
|
||||
print(f"\n会话文件大小: {logger.session_file.stat().st_size} 字节")
|
||||
|
||||
if logger.stats_file.exists():
|
||||
print(f"统计文件大小: {logger.stats_file.stat().st_size} 字节")
|
||||
|
||||
# 读取并显示统计数据
|
||||
with open(logger.stats_file, "r", encoding="utf-8") as f:
|
||||
stats_data = json.load(f)
|
||||
|
||||
print("\n=== 最终统计数据 ===")
|
||||
if test_chat_id in stats_data:
|
||||
chat_stats = stats_data[test_chat_id]
|
||||
print(f"Chat ID: {test_chat_id}")
|
||||
print(f"最后更新: {chat_stats['last_updated']}")
|
||||
print(f"总记录数: {chat_stats['overall']['total_records']}")
|
||||
print(f"平均总时间: {chat_stats['overall']['avg_total_time']:.2f}秒")
|
||||
|
||||
print("\n各步骤平均时间:")
|
||||
for step, avg_time in chat_stats["overall"]["avg_step_times"].items():
|
||||
print(f" {step}: {avg_time:.3f}秒")
|
||||
|
||||
print("\n按动作类型统计:")
|
||||
for action, action_stats in chat_stats["by_action"].items():
|
||||
print(
|
||||
f" {action}: {action_stats['count']}次 ({action_stats['percentage']:.1f}%), 平均{action_stats['avg_total_time']:.2f}秒"
|
||||
)
|
||||
|
||||
|
||||
def test_version_manager():
|
||||
"""测试版本号管理功能"""
|
||||
print("\n=== 测试版本号管理器 ===")
|
||||
|
||||
# 测试默认版本
|
||||
print(f"默认版本: {get_hfc_version()}")
|
||||
|
||||
# 测试设置版本
|
||||
test_versions = ["v2.0.0", "1.5.0", "v1.0.0.beta", "v1.0.build123"]
|
||||
for version in test_versions:
|
||||
success = set_hfc_version(version)
|
||||
print(f"设置版本 '{version}': {'成功' if success else '失败'} -> {get_hfc_version()}")
|
||||
|
||||
# 测试自动生成版本
|
||||
auto_version = auto_generate_hfc_version()
|
||||
print(f"自动生成版本: {auto_version}")
|
||||
|
||||
# 测试基于现有版本的自动生成
|
||||
auto_version2 = auto_generate_hfc_version("v2.1.0")
|
||||
print(f"基于v2.1.0自动生成: {auto_version2}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_version_manager()
|
||||
test_performance_logger()
|
||||
@@ -1,214 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
版本兼容性检查测试脚本
|
||||
|
||||
测试版本号标准化、比较和兼容性检查功能
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
from src.plugin_system.utils.manifest_utils import VersionComparator
|
||||
|
||||
# 添加项目根目录到Python路径
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
|
||||
def test_version_normalization():
|
||||
"""测试版本号标准化功能"""
|
||||
print("🧪 测试版本号标准化...")
|
||||
|
||||
test_cases = [
|
||||
("0.8.0-snapshot.1", "0.8.0"),
|
||||
("0.8.0-snapshot.2", "0.8.0"),
|
||||
("0.8.0", "0.8.0"),
|
||||
("0.9.0-snapshot.1", "0.9.0"),
|
||||
("1.0.0", "1.0.0"),
|
||||
("2.1", "2.1.0"),
|
||||
("3", "3.0.0"),
|
||||
("", "0.0.0"),
|
||||
("invalid", "0.0.0"),
|
||||
]
|
||||
|
||||
for input_version, expected in test_cases:
|
||||
result = VersionComparator.normalize_version(input_version)
|
||||
status = "✅" if result == expected else "❌"
|
||||
print(f" {status} {input_version} -> {result} (期望: {expected})")
|
||||
|
||||
|
||||
def test_version_comparison():
|
||||
"""测试版本号比较功能"""
|
||||
print("\n🧪 测试版本号比较...")
|
||||
|
||||
test_cases = [
|
||||
("0.8.0", "0.9.0", -1), # 0.8.0 < 0.9.0
|
||||
("0.9.0", "0.8.0", 1), # 0.9.0 > 0.8.0
|
||||
("1.0.0", "1.0.0", 0), # 1.0.0 == 1.0.0
|
||||
("0.8.0-snapshot.1", "0.8.0", 0), # 标准化后相等
|
||||
("1.2.3", "1.2.4", -1), # 1.2.3 < 1.2.4
|
||||
("2.0.0", "1.9.9", 1), # 2.0.0 > 1.9.9
|
||||
]
|
||||
|
||||
for v1, v2, expected in test_cases:
|
||||
result = VersionComparator.compare_versions(v1, v2)
|
||||
status = "✅" if result == expected else "❌"
|
||||
comparison = "<" if expected == -1 else ">" if expected == 1 else "=="
|
||||
print(f" {status} {v1} {comparison} {v2} (结果: {result})")
|
||||
|
||||
|
||||
def test_version_range_check():
|
||||
"""测试版本范围检查功能"""
|
||||
print("\n🧪 测试版本范围检查...")
|
||||
|
||||
test_cases = [
|
||||
("0.8.0", "0.7.0", "0.9.0", True), # 在范围内
|
||||
("0.6.0", "0.7.0", "0.9.0", False), # 低于最小版本
|
||||
("1.0.0", "0.7.0", "0.9.0", False), # 高于最大版本
|
||||
("0.8.0", "0.8.0", "0.8.0", True), # 等于边界
|
||||
("0.8.0", "", "0.9.0", True), # 只有最大版本限制
|
||||
("0.8.0", "0.7.0", "", True), # 只有最小版本限制
|
||||
("0.8.0", "", "", True), # 无版本限制
|
||||
]
|
||||
|
||||
for version, min_ver, max_ver, expected in test_cases:
|
||||
is_compatible, error_msg = VersionComparator.is_version_in_range(version, min_ver, max_ver)
|
||||
status = "✅" if is_compatible == expected else "❌"
|
||||
range_str = f"[{min_ver or '无限制'}, {max_ver or '无限制'}]"
|
||||
print(f" {status} {version} 在范围 {range_str}: {is_compatible}")
|
||||
if error_msg:
|
||||
print(f" 错误信息: {error_msg}")
|
||||
|
||||
|
||||
def test_current_version():
|
||||
"""测试获取当前版本功能"""
|
||||
print("\n🧪 测试获取当前主机版本...")
|
||||
|
||||
try:
|
||||
current_version = VersionComparator.get_current_host_version()
|
||||
print(f" ✅ 当前主机版本: {current_version}")
|
||||
|
||||
# 验证版本号格式
|
||||
parts = current_version.split(".")
|
||||
if len(parts) == 3 and all(part.isdigit() for part in parts):
|
||||
print(" ✅ 版本号格式正确")
|
||||
else:
|
||||
print(" ❌ 版本号格式错误")
|
||||
|
||||
except Exception as e:
|
||||
print(f" ❌ 获取当前版本失败: {e}")
|
||||
|
||||
|
||||
def test_manifest_compatibility():
|
||||
"""测试manifest兼容性检查"""
|
||||
print("\n🧪 测试manifest兼容性检查...")
|
||||
|
||||
# 模拟manifest数据
|
||||
test_manifests = [
|
||||
{"name": "兼容插件", "host_application": {"min_version": "0.1.0", "max_version": "2.0.0"}},
|
||||
{"name": "版本过高插件", "host_application": {"min_version": "10.0.0", "max_version": "20.0.0"}},
|
||||
{"name": "版本过低插件", "host_application": {"min_version": "0.1.0", "max_version": "0.2.0"}},
|
||||
{
|
||||
"name": "无版本要求插件",
|
||||
# 没有host_application字段
|
||||
},
|
||||
]
|
||||
|
||||
# 这里需要导入PluginManager来测试,但可能会有依赖问题
|
||||
# 所以我们直接使用VersionComparator进行测试
|
||||
current_version = VersionComparator.get_current_host_version()
|
||||
|
||||
for manifest in test_manifests:
|
||||
plugin_name = manifest["name"]
|
||||
|
||||
if "host_application" in manifest:
|
||||
host_app = manifest["host_application"]
|
||||
min_version = host_app.get("min_version", "")
|
||||
max_version = host_app.get("max_version", "")
|
||||
|
||||
is_compatible, error_msg = VersionComparator.is_version_in_range(current_version, min_version, max_version)
|
||||
|
||||
status = "✅" if is_compatible else "❌"
|
||||
print(f" {status} {plugin_name}: {is_compatible}")
|
||||
if error_msg:
|
||||
print(f" {error_msg}")
|
||||
else:
|
||||
print(f" ✅ {plugin_name}: True (无版本要求)")
|
||||
|
||||
|
||||
def test_additional_snapshot_formats():
|
||||
"""测试额外的snapshot版本格式"""
|
||||
print("\n🧪 测试额外的snapshot版本格式...")
|
||||
|
||||
test_cases = [
|
||||
# 用户提到的版本格式
|
||||
("0.8.0-snapshot.1", "0.8.0"),
|
||||
("0.8.0-snapshot.2", "0.8.0"),
|
||||
("0.8.0", "0.8.0"),
|
||||
("0.9.0-snapshot.1", "0.9.0"),
|
||||
# 边界情况
|
||||
("1.0.0-snapshot.999", "1.0.0"),
|
||||
("2.15.3-snapshot.42", "2.15.3"),
|
||||
("10.5.0-snapshot.1", "10.5.0"),
|
||||
# 不正确的snapshot格式(应该被忽略或正确处理)
|
||||
("0.8.0-snapshot", "0.0.0"), # 无数字后缀,应该标准化为0.0.0
|
||||
("0.8.0-snapshot.abc", "0.0.0"), # 非数字后缀,应该标准化为0.0.0
|
||||
("0.8.0-beta.1", "0.0.0"), # 其他预发布版本,应该标准化为0.0.0
|
||||
]
|
||||
|
||||
for input_version, expected in test_cases:
|
||||
result = VersionComparator.normalize_version(input_version)
|
||||
status = "✅" if result == expected else "❌"
|
||||
print(f" {status} {input_version} -> {result} (期望: {expected})")
|
||||
|
||||
|
||||
def test_snapshot_version_comparison():
|
||||
"""测试snapshot版本的比较功能"""
|
||||
print("\n🧪 测试snapshot版本比较...")
|
||||
|
||||
test_cases = [
|
||||
# snapshot版本与正式版本比较
|
||||
("0.8.0-snapshot.1", "0.8.0", 0), # 应该相等
|
||||
("0.8.0-snapshot.2", "0.8.0", 0), # 应该相等
|
||||
("0.9.0-snapshot.1", "0.8.0", 1), # 应该大于
|
||||
("0.7.0-snapshot.1", "0.8.0", -1), # 应该小于
|
||||
# snapshot版本之间比较
|
||||
("0.8.0-snapshot.1", "0.8.0-snapshot.2", 0), # 都标准化为0.8.0,相等
|
||||
("0.9.0-snapshot.1", "0.8.0-snapshot.1", 1), # 0.9.0 > 0.8.0
|
||||
# 边界情况
|
||||
("1.0.0-snapshot.1", "0.9.9", 1), # 主版本更高
|
||||
("0.9.0-snapshot.1", "0.8.99", 1), # 次版本更高
|
||||
]
|
||||
|
||||
for version1, version2, expected in test_cases:
|
||||
result = VersionComparator.compare_versions(version1, version2)
|
||||
status = "✅" if result == expected else "❌"
|
||||
comparison = "<" if expected < 0 else "==" if expected == 0 else ">"
|
||||
print(f" {status} {version1} {comparison} {version2} (结果: {result})")
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
print("🔧 MaiBot插件版本兼容性检查测试")
|
||||
print("=" * 50)
|
||||
|
||||
try:
|
||||
test_version_normalization()
|
||||
test_version_comparison()
|
||||
test_version_range_check()
|
||||
test_current_version()
|
||||
test_manifest_compatibility()
|
||||
test_additional_snapshot_formats()
|
||||
test_snapshot_version_comparison()
|
||||
|
||||
print("\n🎉 所有测试完成!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n❌ 测试过程中发生错误: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -108,55 +108,64 @@ class ExpressionSelector:
|
||||
|
||||
return selected_style, selected_grammar, selected_personality
|
||||
|
||||
def update_expression_count(self, chat_id: str, expression: Dict[str, str], increment: float = 0.1):
|
||||
"""更新表达方式的count值
|
||||
def update_expressions_count_batch(self, expressions_to_update: List[Dict[str, str]], increment: float = 0.1):
|
||||
"""对一批表达方式更新count值,按文件分组后一次性写入"""
|
||||
if not expressions_to_update:
|
||||
return
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
expression: 表达方式字典
|
||||
increment: 增量值,默认0.1
|
||||
"""
|
||||
if expression.get("type") == "style_personality":
|
||||
# personality表达方式存储在全局文件中
|
||||
updates_by_file = {}
|
||||
for expr in expressions_to_update:
|
||||
source_id = expr.get("source_id")
|
||||
if not source_id:
|
||||
logger.warning(f"表达方式缺少source_id,无法更新: {expr}")
|
||||
continue
|
||||
|
||||
file_path = ""
|
||||
if source_id == "personality":
|
||||
file_path = os.path.join("data", "expression", "personality", "expressions.json")
|
||||
else:
|
||||
# style和grammar表达方式存储在对应chat_id目录中
|
||||
expr_type = expression.get("type", "style")
|
||||
chat_id = source_id
|
||||
expr_type = expr.get("type", "style")
|
||||
if expr_type == "style":
|
||||
file_path = os.path.join("data", "expression", "learnt_style", str(chat_id), "expressions.json")
|
||||
elif expr_type == "grammar":
|
||||
file_path = os.path.join("data", "expression", "learnt_grammar", str(chat_id), "expressions.json")
|
||||
else:
|
||||
return
|
||||
|
||||
if file_path:
|
||||
if file_path not in updates_by_file:
|
||||
updates_by_file[file_path] = []
|
||||
updates_by_file[file_path].append(expr)
|
||||
|
||||
for file_path, updates in updates_by_file.items():
|
||||
if not os.path.exists(file_path):
|
||||
return
|
||||
continue
|
||||
|
||||
try:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
expressions = json.load(f)
|
||||
all_expressions = json.load(f)
|
||||
|
||||
# 找到匹配的表达方式并更新count
|
||||
for expr in expressions:
|
||||
if expr.get("situation") == expression.get("situation") and expr.get("style") == expression.get(
|
||||
"style"
|
||||
):
|
||||
current_count = expr.get("count", 1)
|
||||
# Create a dictionary for quick lookup
|
||||
expr_map = {(e.get("situation"), e.get("style")): e for e in all_expressions}
|
||||
|
||||
# 简单加0.1,但限制最高为5
|
||||
# Update counts in memory
|
||||
for expr_to_update in updates:
|
||||
key = (expr_to_update.get("situation"), expr_to_update.get("style"))
|
||||
if key in expr_map:
|
||||
expr_in_map = expr_map[key]
|
||||
current_count = expr_in_map.get("count", 1)
|
||||
new_count = min(current_count + increment, 5.0)
|
||||
expr["count"] = new_count
|
||||
expr["last_active_time"] = time.time()
|
||||
expr_in_map["count"] = new_count
|
||||
expr_in_map["last_active_time"] = time.time()
|
||||
logger.info(
|
||||
f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in {file_path}"
|
||||
)
|
||||
|
||||
logger.info(f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f}")
|
||||
break
|
||||
|
||||
# 保存更新后的文件
|
||||
# Save the updated list once for this file
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
json.dump(expressions, f, ensure_ascii=False, indent=2)
|
||||
json.dump(all_expressions, f, ensure_ascii=False, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新表达方式count失败: {e}")
|
||||
logger.error(f"批量更新表达方式count失败 for {file_path}: {e}")
|
||||
|
||||
async def select_suitable_expressions_llm(
|
||||
self, chat_id: str, chat_info: str, max_num: int = 10, min_num: int = 5
|
||||
@@ -237,8 +246,9 @@ class ExpressionSelector:
|
||||
expression = all_expressions[idx - 1] # 索引从1开始
|
||||
valid_expressions.append(expression)
|
||||
|
||||
# 对选中的表达方式count数+0.1
|
||||
self.update_expression_count(chat_id, expression, 0.001)
|
||||
# 对选中的所有表达方式,一次性更新count数
|
||||
if valid_expressions:
|
||||
self.update_expressions_count_batch(valid_expressions, 0.003)
|
||||
|
||||
# logger.info(f"LLM从{len(all_expressions)}个情境中选择了{len(valid_expressions)}个")
|
||||
return valid_expressions
|
||||
|
||||
@@ -72,77 +72,57 @@ class ExpressionLearner:
|
||||
temperature=0.2,
|
||||
request_type="expressor.learner",
|
||||
)
|
||||
self.llm_model = None
|
||||
|
||||
def get_expression_by_chat_id(self, chat_id: str) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]:
|
||||
def get_expression_by_chat_id(
|
||||
self, chat_id: str
|
||||
) -> Tuple[List[Dict[str, str]], List[Dict[str, str]], List[Dict[str, str]]]:
|
||||
"""
|
||||
读取/data/expression/learnt/{chat_id}/expressions.json和/data/expression/personality/expressions.json
|
||||
返回(learnt_expressions, personality_expressions)
|
||||
获取指定chat_id的style和grammar表达方式, 同时获取全局的personality表达方式
|
||||
返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作
|
||||
"""
|
||||
expression_groups = global_config.expression.expression_groups
|
||||
chat_ids_to_load = [chat_id]
|
||||
|
||||
# 获取当前chat_id的类型
|
||||
chat_stream = get_chat_manager().get_stream(chat_id)
|
||||
if chat_stream is None:
|
||||
# 如果聊天流不在内存中,跳过互通组查找,直接使用当前chat_id
|
||||
logger.warning(f"聊天流 {chat_id} 不在内存中,跳过互通组查找")
|
||||
chat_ids_to_load = [chat_id]
|
||||
else:
|
||||
platform = chat_stream.platform
|
||||
if chat_stream.group_info:
|
||||
current_chat_type = "group"
|
||||
typed_chat_id = f"{platform}:{chat_stream.group_info.group_id}:{current_chat_type}"
|
||||
else:
|
||||
current_chat_type = "private"
|
||||
typed_chat_id = f"{platform}:{chat_stream.user_info.user_id}:{current_chat_type}"
|
||||
|
||||
logger.debug(f"正在为 {typed_chat_id} 查找互通组...")
|
||||
|
||||
found_group = None
|
||||
for group in expression_groups:
|
||||
# logger.info(f"正在检查互通组: {group}")
|
||||
# logger.info(f"当前chat_id: {typed_chat_id}")
|
||||
if typed_chat_id in group:
|
||||
found_group = group
|
||||
# logger.info(f"找到互通组: {group}")
|
||||
break
|
||||
|
||||
if not found_group:
|
||||
logger.debug(f"未找到互通组,仅加载 {chat_id} 的表达方式")
|
||||
|
||||
if found_group:
|
||||
# 从带类型的id中解析出原始id
|
||||
parsed_ids = []
|
||||
for item in found_group:
|
||||
try:
|
||||
platform, id, type = item.split(":")
|
||||
chat_id = get_chat_manager().get_stream_id(platform, id, type == "group")
|
||||
parsed_ids.append(chat_id)
|
||||
except Exception:
|
||||
logger.warning(f"无法解析互通组中的ID: {item}")
|
||||
chat_ids_to_load = parsed_ids
|
||||
logger.debug(f"将要加载以下id的表达方式: {chat_ids_to_load}")
|
||||
|
||||
learnt_style_expressions = []
|
||||
learnt_grammar_expressions = []
|
||||
|
||||
for id_to_load in chat_ids_to_load:
|
||||
learnt_style_file = os.path.join("data", "expression", "learnt_style", str(id_to_load), "expressions.json")
|
||||
learnt_grammar_file = os.path.join(
|
||||
"data", "expression", "learnt_grammar", str(id_to_load), "expressions.json"
|
||||
)
|
||||
if os.path.exists(learnt_style_file):
|
||||
with open(learnt_style_file, "r", encoding="utf-8") as f:
|
||||
learnt_style_expressions.extend(json.load(f))
|
||||
if os.path.exists(learnt_grammar_file):
|
||||
with open(learnt_grammar_file, "r", encoding="utf-8") as f:
|
||||
learnt_grammar_expressions.extend(json.load(f))
|
||||
|
||||
personality_file = os.path.join("data", "expression", "personality", "expressions.json")
|
||||
personality_expressions = []
|
||||
|
||||
# 获取style表达方式
|
||||
style_dir = os.path.join("data", "expression", "learnt_style", str(chat_id))
|
||||
style_file = os.path.join(style_dir, "expressions.json")
|
||||
if os.path.exists(style_file):
|
||||
try:
|
||||
with open(style_file, "r", encoding="utf-8") as f:
|
||||
expressions = json.load(f)
|
||||
for expr in expressions:
|
||||
expr["source_id"] = chat_id # 添加来源ID
|
||||
learnt_style_expressions.append(expr)
|
||||
except Exception as e:
|
||||
logger.error(f"读取style表达方式失败: {e}")
|
||||
|
||||
# 获取grammar表达方式
|
||||
grammar_dir = os.path.join("data", "expression", "learnt_grammar", str(chat_id))
|
||||
grammar_file = os.path.join(grammar_dir, "expressions.json")
|
||||
if os.path.exists(grammar_file):
|
||||
try:
|
||||
with open(grammar_file, "r", encoding="utf-8") as f:
|
||||
expressions = json.load(f)
|
||||
for expr in expressions:
|
||||
expr["source_id"] = chat_id # 添加来源ID
|
||||
learnt_grammar_expressions.append(expr)
|
||||
except Exception as e:
|
||||
logger.error(f"读取grammar表达方式失败: {e}")
|
||||
|
||||
# 获取personality表达方式
|
||||
personality_file = os.path.join("data", "expression", "personality", "expressions.json")
|
||||
if os.path.exists(personality_file):
|
||||
try:
|
||||
with open(personality_file, "r", encoding="utf-8") as f:
|
||||
personality_expressions = json.load(f)
|
||||
expressions = json.load(f)
|
||||
for expr in expressions:
|
||||
expr["source_id"] = "personality" # 添加来源ID
|
||||
personality_expressions.append(expr)
|
||||
except Exception as e:
|
||||
logger.error(f"读取personality表达方式失败: {e}")
|
||||
|
||||
return learnt_style_expressions, learnt_grammar_expressions, personality_expressions
|
||||
|
||||
def is_similar(self, s1: str, s2: str) -> bool:
|
||||
@@ -205,28 +185,24 @@ class ExpressionLearner:
|
||||
def calculate_decay_factor(self, time_diff_days: float) -> float:
|
||||
"""
|
||||
计算衰减值
|
||||
当时间差为0天时,衰减值为0.001
|
||||
当时间差为7天时,衰减值为0
|
||||
当时间差为30天时,衰减值为0.001
|
||||
当时间差为0天时,衰减值为0(最近活跃的不衰减)
|
||||
当时间差为7天时,衰减值为0.002(中等衰减)
|
||||
当时间差为30天或更长时,衰减值为0.01(高衰减)
|
||||
使用二次函数进行曲线插值
|
||||
"""
|
||||
if time_diff_days <= 0 or time_diff_days >= DECAY_DAYS:
|
||||
return 0.001
|
||||
if time_diff_days <= 0:
|
||||
return 0.0 # 刚激活的表达式不衰减
|
||||
|
||||
# 使用二次函数进行插值
|
||||
# 将7天作为顶点,0天和30天作为两个端点
|
||||
# 使用顶点式:y = a(x-h)^2 + k,其中(h,k)为顶点
|
||||
h = 7.0 # 顶点x坐标
|
||||
k = 0.001 # 顶点y坐标
|
||||
if time_diff_days >= DECAY_DAYS:
|
||||
return 0.01 # 长时间未活跃的表达式大幅衰减
|
||||
|
||||
# 计算a值,使得x=0和x=30时y=0.001
|
||||
# 0.001 = a(0-7)^2 + 0.001
|
||||
# 解得a = 0
|
||||
a = 0
|
||||
# 使用二次函数插值:在0-30天之间从0衰减到0.01
|
||||
# 使用简单的二次函数:y = a * x^2
|
||||
# 当x=30时,y=0.01,所以 a = 0.01 / (30^2) = 0.01 / 900
|
||||
a = 0.01 / (DECAY_DAYS ** 2)
|
||||
decay = a * (time_diff_days ** 2)
|
||||
|
||||
# 计算衰减值
|
||||
decay = a * (time_diff_days - h) ** 2 + k
|
||||
return min(0.001, decay)
|
||||
return min(0.01, decay)
|
||||
|
||||
def apply_decay_to_expressions(
|
||||
self, expressions: List[Dict[str, Any]], current_time: float
|
||||
|
||||
Reference in New Issue
Block a user