fix:修复表达提取无法提高count的问题

This commit is contained in:
SengokuCola
2025-06-26 00:44:51 +08:00
parent 040ebf18d9
commit 29a3183ba7
5 changed files with 115 additions and 489 deletions

View File

@@ -108,55 +108,64 @@ class ExpressionSelector:
return selected_style, selected_grammar, selected_personality
def update_expression_count(self, chat_id: str, expression: Dict[str, str], increment: float = 0.1):
"""更新表达方式count值
Args:
chat_id: 聊天ID
expression: 表达方式字典
increment: 增量值默认0.1
"""
if expression.get("type") == "style_personality":
# personality表达方式存储在全局文件中
file_path = os.path.join("data", "expression", "personality", "expressions.json")
else:
# style和grammar表达方式存储在对应chat_id目录中
expr_type = expression.get("type", "style")
if expr_type == "style":
file_path = os.path.join("data", "expression", "learnt_style", str(chat_id), "expressions.json")
elif expr_type == "grammar":
file_path = os.path.join("data", "expression", "learnt_grammar", str(chat_id), "expressions.json")
else:
return
if not os.path.exists(file_path):
def update_expressions_count_batch(self, expressions_to_update: List[Dict[str, str]], increment: float = 0.1):
"""对一批表达方式更新count值,按文件分组后一次性写入"""
if not expressions_to_update:
return
try:
with open(file_path, "r", encoding="utf-8") as f:
expressions = json.load(f)
updates_by_file = {}
for expr in expressions_to_update:
source_id = expr.get("source_id")
if not source_id:
logger.warning(f"表达方式缺少source_id无法更新: {expr}")
continue
# 找到匹配的表达方式并更新count
for expr in expressions:
if expr.get("situation") == expression.get("situation") and expr.get("style") == expression.get(
"style"
):
current_count = expr.get("count", 1)
file_path = ""
if source_id == "personality":
file_path = os.path.join("data", "expression", "personality", "expressions.json")
else:
chat_id = source_id
expr_type = expr.get("type", "style")
if expr_type == "style":
file_path = os.path.join("data", "expression", "learnt_style", str(chat_id), "expressions.json")
elif expr_type == "grammar":
file_path = os.path.join("data", "expression", "learnt_grammar", str(chat_id), "expressions.json")
# 简单加0.1但限制最高为5
new_count = min(current_count + increment, 5.0)
expr["count"] = new_count
expr["last_active_time"] = time.time()
if file_path:
if file_path not in updates_by_file:
updates_by_file[file_path] = []
updates_by_file[file_path].append(expr)
logger.info(f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f}")
break
for file_path, updates in updates_by_file.items():
if not os.path.exists(file_path):
continue
# 保存更新后的文件
with open(file_path, "w", encoding="utf-8") as f:
json.dump(expressions, f, ensure_ascii=False, indent=2)
try:
with open(file_path, "r", encoding="utf-8") as f:
all_expressions = json.load(f)
except Exception as e:
logger.error(f"更新表达方式count失败: {e}")
# Create a dictionary for quick lookup
expr_map = {(e.get("situation"), e.get("style")): e for e in all_expressions}
# Update counts in memory
for expr_to_update in updates:
key = (expr_to_update.get("situation"), expr_to_update.get("style"))
if key in expr_map:
expr_in_map = expr_map[key]
current_count = expr_in_map.get("count", 1)
new_count = min(current_count + increment, 5.0)
expr_in_map["count"] = new_count
expr_in_map["last_active_time"] = time.time()
logger.info(
f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in {file_path}"
)
# Save the updated list once for this file
with open(file_path, "w", encoding="utf-8") as f:
json.dump(all_expressions, f, ensure_ascii=False, indent=2)
except Exception as e:
logger.error(f"批量更新表达方式count失败 for {file_path}: {e}")
async def select_suitable_expressions_llm(
self, chat_id: str, chat_info: str, max_num: int = 10, min_num: int = 5
@@ -237,8 +246,9 @@ class ExpressionSelector:
expression = all_expressions[idx - 1] # 索引从1开始
valid_expressions.append(expression)
# 对选中的表达方式count数+0.1
self.update_expression_count(chat_id, expression, 0.001)
# 对选中的所有表达方式,一次性更新count数
if valid_expressions:
self.update_expressions_count_batch(valid_expressions, 0.003)
# logger.info(f"LLM从{len(all_expressions)}个情境中选择了{len(valid_expressions)}个")
return valid_expressions

View File

@@ -72,77 +72,57 @@ class ExpressionLearner:
temperature=0.2,
request_type="expressor.learner",
)
self.llm_model = None
def get_expression_by_chat_id(self, chat_id: str) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]:
def get_expression_by_chat_id(
self, chat_id: str
) -> Tuple[List[Dict[str, str]], List[Dict[str, str]], List[Dict[str, str]]]:
"""
读取/data/expression/learnt/{chat_id}/expressions.json和/data/expression/personality/expressions.json
返回(learnt_expressions, personality_expressions)
获取指定chat_id的style和grammar表达方式, 同时获取全局的personality表达方式
返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作
"""
expression_groups = global_config.expression.expression_groups
chat_ids_to_load = [chat_id]
# 获取当前chat_id的类型
chat_stream = get_chat_manager().get_stream(chat_id)
if chat_stream is None:
# 如果聊天流不在内存中跳过互通组查找直接使用当前chat_id
logger.warning(f"聊天流 {chat_id} 不在内存中,跳过互通组查找")
chat_ids_to_load = [chat_id]
else:
platform = chat_stream.platform
if chat_stream.group_info:
current_chat_type = "group"
typed_chat_id = f"{platform}:{chat_stream.group_info.group_id}:{current_chat_type}"
else:
current_chat_type = "private"
typed_chat_id = f"{platform}:{chat_stream.user_info.user_id}:{current_chat_type}"
logger.debug(f"正在为 {typed_chat_id} 查找互通组...")
found_group = None
for group in expression_groups:
# logger.info(f"正在检查互通组: {group}")
# logger.info(f"当前chat_id: {typed_chat_id}")
if typed_chat_id in group:
found_group = group
# logger.info(f"找到互通组: {group}")
break
if not found_group:
logger.debug(f"未找到互通组,仅加载 {chat_id} 的表达方式")
if found_group:
# 从带类型的id中解析出原始id
parsed_ids = []
for item in found_group:
try:
platform, id, type = item.split(":")
chat_id = get_chat_manager().get_stream_id(platform, id, type == "group")
parsed_ids.append(chat_id)
except Exception:
logger.warning(f"无法解析互通组中的ID: {item}")
chat_ids_to_load = parsed_ids
logger.debug(f"将要加载以下id的表达方式: {chat_ids_to_load}")
learnt_style_expressions = []
learnt_grammar_expressions = []
for id_to_load in chat_ids_to_load:
learnt_style_file = os.path.join("data", "expression", "learnt_style", str(id_to_load), "expressions.json")
learnt_grammar_file = os.path.join(
"data", "expression", "learnt_grammar", str(id_to_load), "expressions.json"
)
if os.path.exists(learnt_style_file):
with open(learnt_style_file, "r", encoding="utf-8") as f:
learnt_style_expressions.extend(json.load(f))
if os.path.exists(learnt_grammar_file):
with open(learnt_grammar_file, "r", encoding="utf-8") as f:
learnt_grammar_expressions.extend(json.load(f))
personality_file = os.path.join("data", "expression", "personality", "expressions.json")
personality_expressions = []
# 获取style表达方式
style_dir = os.path.join("data", "expression", "learnt_style", str(chat_id))
style_file = os.path.join(style_dir, "expressions.json")
if os.path.exists(style_file):
try:
with open(style_file, "r", encoding="utf-8") as f:
expressions = json.load(f)
for expr in expressions:
expr["source_id"] = chat_id # 添加来源ID
learnt_style_expressions.append(expr)
except Exception as e:
logger.error(f"读取style表达方式失败: {e}")
# 获取grammar表达方式
grammar_dir = os.path.join("data", "expression", "learnt_grammar", str(chat_id))
grammar_file = os.path.join(grammar_dir, "expressions.json")
if os.path.exists(grammar_file):
try:
with open(grammar_file, "r", encoding="utf-8") as f:
expressions = json.load(f)
for expr in expressions:
expr["source_id"] = chat_id # 添加来源ID
learnt_grammar_expressions.append(expr)
except Exception as e:
logger.error(f"读取grammar表达方式失败: {e}")
# 获取personality表达方式
personality_file = os.path.join("data", "expression", "personality", "expressions.json")
if os.path.exists(personality_file):
with open(personality_file, "r", encoding="utf-8") as f:
personality_expressions = json.load(f)
try:
with open(personality_file, "r", encoding="utf-8") as f:
expressions = json.load(f)
for expr in expressions:
expr["source_id"] = "personality" # 添加来源ID
personality_expressions.append(expr)
except Exception as e:
logger.error(f"读取personality表达方式失败: {e}")
return learnt_style_expressions, learnt_grammar_expressions, personality_expressions
def is_similar(self, s1: str, s2: str) -> bool:
@@ -205,28 +185,24 @@ class ExpressionLearner:
def calculate_decay_factor(self, time_diff_days: float) -> float:
"""
计算衰减值
当时间差为0天时衰减值为0.001
当时间差为7天时衰减值为0
当时间差为30天时衰减值为0.001
当时间差为0天时衰减值为0(最近活跃的不衰减)
当时间差为7天时衰减值为0.002(中等衰减)
当时间差为30天或更长衰减值为0.01(高衰减)
使用二次函数进行曲线插值
"""
if time_diff_days <= 0 or time_diff_days >= DECAY_DAYS:
return 0.001
# 使用二次函数进行插值
# 将7天作为顶点0天和30天作为两个端点
# 使用顶点式y = a(x-h)^2 + k其中(h,k)为顶点
h = 7.0 # 顶点x坐标
k = 0.001 # 顶点y坐标
# 计算a值使得x=0和x=30时y=0.001
# 0.001 = a(0-7)^2 + 0.001
# 解得a = 0
a = 0
# 计算衰减值
decay = a * (time_diff_days - h) ** 2 + k
return min(0.001, decay)
if time_diff_days <= 0:
return 0.0 # 刚激活的表达式不衰减
if time_diff_days >= DECAY_DAYS:
return 0.01 # 长时间未活跃的表达式大幅衰减
# 使用二次函数插值在0-30天之间从0衰减到0.01
# 使用简单的二次函数y = a * x^2
# 当x=30时y=0.01,所以 a = 0.01 / (30^2) = 0.01 / 900
a = 0.01 / (DECAY_DAYS ** 2)
decay = a * (time_diff_days ** 2)
return min(0.01, decay)
def apply_decay_to_expressions(
self, expressions: List[Dict[str, Any]], current_time: float