fix:修复一处关系构建错误,修复一处表达方式错误
This commit is contained in:
@@ -330,48 +330,8 @@ class ExpressionLearner:
|
|||||||
"""
|
"""
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
|
|
||||||
# 全局衰减所有已存储的表达方式
|
# 全局衰减所有已存储的表达方式(直接操作数据库)
|
||||||
for type in ["style", "grammar"]:
|
self._apply_global_decay_to_database(current_time)
|
||||||
base_dir = os.path.join("data", "expression", f"learnt_{type}")
|
|
||||||
if not os.path.exists(base_dir):
|
|
||||||
logger.debug(f"目录不存在,跳过衰减: {base_dir}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
|
||||||
chat_ids = os.listdir(base_dir)
|
|
||||||
logger.debug(f"在 {base_dir} 中找到 {len(chat_ids)} 个聊天ID目录进行衰减")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"读取目录失败 {base_dir}: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
for chat_id in chat_ids:
|
|
||||||
file_path = os.path.join(base_dir, chat_id, "expressions.json")
|
|
||||||
if not os.path.exists(file_path):
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(file_path, "r", encoding="utf-8") as f:
|
|
||||||
expressions = json.load(f)
|
|
||||||
|
|
||||||
if not isinstance(expressions, list):
|
|
||||||
logger.warning(f"表达方式文件格式错误,跳过衰减: {file_path}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 应用全局衰减
|
|
||||||
decayed_expressions = self.apply_decay_to_expressions(expressions, current_time)
|
|
||||||
|
|
||||||
# 保存衰减后的结果
|
|
||||||
with open(file_path, "w", encoding="utf-8") as f:
|
|
||||||
json.dump(decayed_expressions, f, ensure_ascii=False, indent=2)
|
|
||||||
|
|
||||||
logger.debug(f"已对 {file_path} 应用衰减,剩余 {len(decayed_expressions)} 个表达方式")
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
logger.error(f"JSON解析失败,跳过衰减 {file_path}: {e}")
|
|
||||||
except PermissionError as e:
|
|
||||||
logger.error(f"权限不足,无法更新 {file_path}: {e}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"全局衰减{type}表达方式失败 {file_path}: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
learnt_style: Optional[List[Tuple[str, str, str]]] = []
|
learnt_style: Optional[List[Tuple[str, str, str]]] = []
|
||||||
learnt_grammar: Optional[List[Tuple[str, str, str]]] = []
|
learnt_grammar: Optional[List[Tuple[str, str, str]]] = []
|
||||||
@@ -388,6 +348,42 @@ class ExpressionLearner:
|
|||||||
|
|
||||||
return learnt_style, learnt_grammar
|
return learnt_style, learnt_grammar
|
||||||
|
|
||||||
|
def _apply_global_decay_to_database(self, current_time: float) -> None:
|
||||||
|
"""
|
||||||
|
对数据库中的所有表达方式应用全局衰减
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 获取所有表达方式
|
||||||
|
all_expressions = Expression.select()
|
||||||
|
|
||||||
|
updated_count = 0
|
||||||
|
deleted_count = 0
|
||||||
|
|
||||||
|
for expr in all_expressions:
|
||||||
|
# 计算时间差
|
||||||
|
last_active = expr.last_active_time
|
||||||
|
time_diff_days = (current_time - last_active) / (24 * 3600) # 转换为天
|
||||||
|
|
||||||
|
# 计算衰减值
|
||||||
|
decay_value = self.calculate_decay_factor(time_diff_days)
|
||||||
|
new_count = max(0.01, expr.count - decay_value)
|
||||||
|
|
||||||
|
if new_count <= 0.01:
|
||||||
|
# 如果count太小,删除这个表达方式
|
||||||
|
expr.delete_instance()
|
||||||
|
deleted_count += 1
|
||||||
|
else:
|
||||||
|
# 更新count
|
||||||
|
expr.count = new_count
|
||||||
|
expr.save()
|
||||||
|
updated_count += 1
|
||||||
|
|
||||||
|
if updated_count > 0 or deleted_count > 0:
|
||||||
|
logger.info(f"全局衰减完成:更新了 {updated_count} 个表达方式,删除了 {deleted_count} 个表达方式")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"数据库全局衰减失败: {e}")
|
||||||
|
|
||||||
def calculate_decay_factor(self, time_diff_days: float) -> float:
|
def calculate_decay_factor(self, time_diff_days: float) -> float:
|
||||||
"""
|
"""
|
||||||
计算衰减值
|
计算衰减值
|
||||||
@@ -410,30 +406,6 @@ class ExpressionLearner:
|
|||||||
|
|
||||||
return min(0.01, decay)
|
return min(0.01, decay)
|
||||||
|
|
||||||
def apply_decay_to_expressions(
|
|
||||||
self, expressions: List[Dict[str, Any]], current_time: float
|
|
||||||
) -> List[Dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
对表达式列表应用衰减
|
|
||||||
返回衰减后的表达式列表,移除count小于0的项
|
|
||||||
"""
|
|
||||||
result = []
|
|
||||||
for expr in expressions:
|
|
||||||
# 确保last_active_time存在,如果不存在则使用current_time
|
|
||||||
if "last_active_time" not in expr:
|
|
||||||
expr["last_active_time"] = current_time
|
|
||||||
|
|
||||||
last_active = expr["last_active_time"]
|
|
||||||
time_diff_days = (current_time - last_active) / (24 * 3600) # 转换为天
|
|
||||||
|
|
||||||
decay_value = self.calculate_decay_factor(time_diff_days)
|
|
||||||
expr["count"] = max(0.01, expr.get("count", 1) - decay_value)
|
|
||||||
|
|
||||||
if expr["count"] > 0:
|
|
||||||
result.append(expr)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
async def learn_and_store(self, type: str, num: int = 10) -> List[Tuple[str, str, str]]:
|
async def learn_and_store(self, type: str, num: int = 10) -> List[Tuple[str, str, str]]:
|
||||||
# sourcery skip: use-join
|
# sourcery skip: use-join
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import json
|
|||||||
import time
|
import time
|
||||||
import random
|
import random
|
||||||
|
|
||||||
from typing import List, Dict, Tuple, Optional
|
from typing import List, Dict, Tuple, Optional, Any
|
||||||
from json_repair import repair_json
|
from json_repair import repair_json
|
||||||
|
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
@@ -117,36 +117,42 @@ class ExpressionSelector:
|
|||||||
|
|
||||||
def get_random_expressions(
|
def get_random_expressions(
|
||||||
self, chat_id: str, total_num: int, style_percentage: float, grammar_percentage: float
|
self, chat_id: str, total_num: int, style_percentage: float, grammar_percentage: float
|
||||||
) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]:
|
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
|
||||||
# 支持多chat_id合并抽选
|
# 支持多chat_id合并抽选
|
||||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||||
style_exprs = []
|
|
||||||
grammar_exprs = []
|
# 优化:一次性查询所有相关chat_id的表达方式
|
||||||
for cid in related_chat_ids:
|
style_query = Expression.select().where(
|
||||||
style_query = Expression.select().where((Expression.chat_id == cid) & (Expression.type == "style"))
|
(Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style")
|
||||||
grammar_query = Expression.select().where((Expression.chat_id == cid) & (Expression.type == "grammar"))
|
)
|
||||||
style_exprs.extend([
|
grammar_query = Expression.select().where(
|
||||||
|
(Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "grammar")
|
||||||
|
)
|
||||||
|
|
||||||
|
style_exprs = [
|
||||||
{
|
{
|
||||||
"situation": expr.situation,
|
"situation": expr.situation,
|
||||||
"style": expr.style,
|
"style": expr.style,
|
||||||
"count": expr.count,
|
"count": expr.count,
|
||||||
"last_active_time": expr.last_active_time,
|
"last_active_time": expr.last_active_time,
|
||||||
"source_id": cid,
|
"source_id": expr.chat_id,
|
||||||
"type": "style",
|
"type": "style",
|
||||||
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
|
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
|
||||||
} for expr in style_query
|
} for expr in style_query
|
||||||
])
|
]
|
||||||
grammar_exprs.extend([
|
|
||||||
|
grammar_exprs = [
|
||||||
{
|
{
|
||||||
"situation": expr.situation,
|
"situation": expr.situation,
|
||||||
"style": expr.style,
|
"style": expr.style,
|
||||||
"count": expr.count,
|
"count": expr.count,
|
||||||
"last_active_time": expr.last_active_time,
|
"last_active_time": expr.last_active_time,
|
||||||
"source_id": cid,
|
"source_id": expr.chat_id,
|
||||||
"type": "grammar",
|
"type": "grammar",
|
||||||
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
|
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
|
||||||
} for expr in grammar_query
|
} for expr in grammar_query
|
||||||
])
|
]
|
||||||
|
|
||||||
style_num = int(total_num * style_percentage)
|
style_num = int(total_num * style_percentage)
|
||||||
grammar_num = int(total_num * grammar_percentage)
|
grammar_num = int(total_num * grammar_percentage)
|
||||||
# 按权重抽样(使用count作为权重)
|
# 按权重抽样(使用count作为权重)
|
||||||
@@ -162,7 +168,7 @@ class ExpressionSelector:
|
|||||||
selected_grammar = []
|
selected_grammar = []
|
||||||
return selected_style, selected_grammar
|
return selected_style, selected_grammar
|
||||||
|
|
||||||
def update_expressions_count_batch(self, expressions_to_update: List[Dict[str, str]], increment: float = 0.1):
|
def update_expressions_count_batch(self, expressions_to_update: List[Dict[str, Any]], increment: float = 0.1):
|
||||||
"""对一批表达方式更新count值,按chat_id+type分组后一次性写入数据库"""
|
"""对一批表达方式更新count值,按chat_id+type分组后一次性写入数据库"""
|
||||||
if not expressions_to_update:
|
if not expressions_to_update:
|
||||||
return
|
return
|
||||||
@@ -203,7 +209,7 @@ class ExpressionSelector:
|
|||||||
max_num: int = 10,
|
max_num: int = 10,
|
||||||
min_num: int = 5,
|
min_num: int = 5,
|
||||||
target_message: Optional[str] = None,
|
target_message: Optional[str] = None,
|
||||||
) -> List[Dict[str, str]]:
|
) -> List[Dict[str, Any]]:
|
||||||
# sourcery skip: inline-variable, list-comprehension
|
# sourcery skip: inline-variable, list-comprehension
|
||||||
"""使用LLM选择适合的表达方式"""
|
"""使用LLM选择适合的表达方式"""
|
||||||
|
|
||||||
|
|||||||
@@ -139,7 +139,7 @@ class RelationshipManager:
|
|||||||
请用json格式输出,引起了你的兴趣,或者有什么需要你记忆的点。
|
请用json格式输出,引起了你的兴趣,或者有什么需要你记忆的点。
|
||||||
并为每个点赋予1-10的权重,权重越高,表示越重要。
|
并为每个点赋予1-10的权重,权重越高,表示越重要。
|
||||||
格式如下:
|
格式如下:
|
||||||
{{
|
[
|
||||||
{{
|
{{
|
||||||
"point": "{person_name}想让我记住他的生日,我回答确认了,他的生日是11月23日",
|
"point": "{person_name}想让我记住他的生日,我回答确认了,他的生日是11月23日",
|
||||||
"weight": 10
|
"weight": 10
|
||||||
@@ -156,13 +156,10 @@ class RelationshipManager:
|
|||||||
"point": "{person_name}喜欢吃辣,具体来说,没有辣的食物ta都不喜欢吃,可能是因为ta是湖南人。",
|
"point": "{person_name}喜欢吃辣,具体来说,没有辣的食物ta都不喜欢吃,可能是因为ta是湖南人。",
|
||||||
"weight": 7
|
"weight": 7
|
||||||
}}
|
}}
|
||||||
}}
|
]
|
||||||
|
|
||||||
如果没有,就输出none,或points为空:
|
如果没有,就输出none,或返回空数组:
|
||||||
{{
|
[]
|
||||||
"point": "none",
|
|
||||||
"weight": 0
|
|
||||||
}}
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# 调用LLM生成印象
|
# 调用LLM生成印象
|
||||||
@@ -184,17 +181,25 @@ class RelationshipManager:
|
|||||||
try:
|
try:
|
||||||
points = repair_json(points)
|
points = repair_json(points)
|
||||||
points_data = json.loads(points)
|
points_data = json.loads(points)
|
||||||
if points_data == "none" or not points_data or points_data.get("point") == "none":
|
|
||||||
|
# 只处理正确的格式,错误格式直接跳过
|
||||||
|
if points_data == "none" or not points_data:
|
||||||
|
points_list = []
|
||||||
|
elif isinstance(points_data, str) and points_data.lower() == "none":
|
||||||
|
points_list = []
|
||||||
|
elif isinstance(points_data, list):
|
||||||
|
# 正确格式:数组格式 [{"point": "...", "weight": 10}, ...]
|
||||||
|
if not points_data: # 空数组
|
||||||
points_list = []
|
points_list = []
|
||||||
else:
|
else:
|
||||||
# logger.info(f"points_data: {points_data}")
|
|
||||||
if isinstance(points_data, dict) and "points" in points_data:
|
|
||||||
points_data = points_data["points"]
|
|
||||||
if not isinstance(points_data, list):
|
|
||||||
points_data = [points_data]
|
|
||||||
# 添加可读时间到每个point
|
|
||||||
points_list = [(item["point"], float(item["weight"]), current_time) for item in points_data]
|
points_list = [(item["point"], float(item["weight"]), current_time) for item in points_data]
|
||||||
|
else:
|
||||||
|
# 错误格式,直接跳过不解析
|
||||||
|
logger.warning(f"LLM返回了错误的JSON格式,跳过解析: {type(points_data)}, 内容: {points_data}")
|
||||||
|
points_list = []
|
||||||
|
|
||||||
|
# 权重过滤逻辑
|
||||||
|
if points_list:
|
||||||
original_points_list = list(points_list)
|
original_points_list = list(points_list)
|
||||||
points_list.clear()
|
points_list.clear()
|
||||||
discarded_count = 0
|
discarded_count = 0
|
||||||
|
|||||||
Reference in New Issue
Block a user