diff --git a/src/chat/express/expression_learner.py b/src/chat/express/expression_learner.py index ac41b12a3..4ee2f2cbb 100644 --- a/src/chat/express/expression_learner.py +++ b/src/chat/express/expression_learner.py @@ -330,48 +330,8 @@ class ExpressionLearner: """ current_time = time.time() - # 全局衰减所有已存储的表达方式 - for type in ["style", "grammar"]: - base_dir = os.path.join("data", "expression", f"learnt_{type}") - if not os.path.exists(base_dir): - logger.debug(f"目录不存在,跳过衰减: {base_dir}") - continue - - try: - chat_ids = os.listdir(base_dir) - logger.debug(f"在 {base_dir} 中找到 {len(chat_ids)} 个聊天ID目录进行衰减") - except Exception as e: - logger.error(f"读取目录失败 {base_dir}: {e}") - continue - - for chat_id in chat_ids: - file_path = os.path.join(base_dir, chat_id, "expressions.json") - if not os.path.exists(file_path): - continue - - try: - with open(file_path, "r", encoding="utf-8") as f: - expressions = json.load(f) - - if not isinstance(expressions, list): - logger.warning(f"表达方式文件格式错误,跳过衰减: {file_path}") - continue - - # 应用全局衰减 - decayed_expressions = self.apply_decay_to_expressions(expressions, current_time) - - # 保存衰减后的结果 - with open(file_path, "w", encoding="utf-8") as f: - json.dump(decayed_expressions, f, ensure_ascii=False, indent=2) - - logger.debug(f"已对 {file_path} 应用衰减,剩余 {len(decayed_expressions)} 个表达方式") - except json.JSONDecodeError as e: - logger.error(f"JSON解析失败,跳过衰减 {file_path}: {e}") - except PermissionError as e: - logger.error(f"权限不足,无法更新 {file_path}: {e}") - except Exception as e: - logger.error(f"全局衰减{type}表达方式失败 {file_path}: {e}") - continue + # 全局衰减所有已存储的表达方式(直接操作数据库) + self._apply_global_decay_to_database(current_time) learnt_style: Optional[List[Tuple[str, str, str]]] = [] learnt_grammar: Optional[List[Tuple[str, str, str]]] = [] @@ -388,6 +348,42 @@ class ExpressionLearner: return learnt_style, learnt_grammar + def _apply_global_decay_to_database(self, current_time: float) -> None: + """ + 对数据库中的所有表达方式应用全局衰减 + """ + try: + # 获取所有表达方式 + all_expressions = Expression.select() + + updated_count = 0 + deleted_count = 0 + + for expr in all_expressions: + # 计算时间差 + last_active = expr.last_active_time + time_diff_days = (current_time - last_active) / (24 * 3600) # 转换为天 + + # 计算衰减值 + decay_value = self.calculate_decay_factor(time_diff_days) + new_count = max(0.01, expr.count - decay_value) + + if new_count <= 0.01: + # 如果count太小,删除这个表达方式 + expr.delete_instance() + deleted_count += 1 + else: + # 更新count + expr.count = new_count + expr.save() + updated_count += 1 + + if updated_count > 0 or deleted_count > 0: + logger.info(f"全局衰减完成:更新了 {updated_count} 个表达方式,删除了 {deleted_count} 个表达方式") + + except Exception as e: + logger.error(f"数据库全局衰减失败: {e}") + def calculate_decay_factor(self, time_diff_days: float) -> float: """ 计算衰减值 @@ -410,30 +406,6 @@ class ExpressionLearner: return min(0.01, decay) - def apply_decay_to_expressions( - self, expressions: List[Dict[str, Any]], current_time: float - ) -> List[Dict[str, Any]]: - """ - 对表达式列表应用衰减 - 返回衰减后的表达式列表,移除count小于0的项 - """ - result = [] - for expr in expressions: - # 确保last_active_time存在,如果不存在则使用current_time - if "last_active_time" not in expr: - expr["last_active_time"] = current_time - - last_active = expr["last_active_time"] - time_diff_days = (current_time - last_active) / (24 * 3600) # 转换为天 - - decay_value = self.calculate_decay_factor(time_diff_days) - expr["count"] = max(0.01, expr.get("count", 1) - decay_value) - - if expr["count"] > 0: - result.append(expr) - - return result - async def learn_and_store(self, type: str, num: int = 10) -> List[Tuple[str, str, str]]: # sourcery skip: use-join """ diff --git a/src/chat/express/expression_selector.py b/src/chat/express/expression_selector.py index d83d3a472..8358c7a2f 100644 --- a/src/chat/express/expression_selector.py +++ b/src/chat/express/expression_selector.py @@ -2,7 +2,7 @@ import json import time import random -from typing import List, Dict, Tuple, Optional +from typing import List, Dict, Tuple, Optional, Any from json_repair import repair_json from src.llm_models.utils_model import LLMRequest @@ -117,36 +117,42 @@ class ExpressionSelector: def get_random_expressions( self, chat_id: str, total_num: int, style_percentage: float, grammar_percentage: float - ) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]: + ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: # 支持多chat_id合并抽选 related_chat_ids = self.get_related_chat_ids(chat_id) - style_exprs = [] - grammar_exprs = [] - for cid in related_chat_ids: - style_query = Expression.select().where((Expression.chat_id == cid) & (Expression.type == "style")) - grammar_query = Expression.select().where((Expression.chat_id == cid) & (Expression.type == "grammar")) - style_exprs.extend([ - { - "situation": expr.situation, - "style": expr.style, - "count": expr.count, - "last_active_time": expr.last_active_time, - "source_id": cid, - "type": "style", - "create_date": expr.create_date if expr.create_date is not None else expr.last_active_time, - } for expr in style_query - ]) - grammar_exprs.extend([ - { - "situation": expr.situation, - "style": expr.style, - "count": expr.count, - "last_active_time": expr.last_active_time, - "source_id": cid, - "type": "grammar", - "create_date": expr.create_date if expr.create_date is not None else expr.last_active_time, - } for expr in grammar_query - ]) + + # 优化:一次性查询所有相关chat_id的表达方式 + style_query = Expression.select().where( + (Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style") + ) + grammar_query = Expression.select().where( + (Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "grammar") + ) + + style_exprs = [ + { + "situation": expr.situation, + "style": expr.style, + "count": expr.count, + "last_active_time": expr.last_active_time, + "source_id": expr.chat_id, + "type": "style", + "create_date": expr.create_date if expr.create_date is not None else expr.last_active_time, + } for expr in style_query + ] + + grammar_exprs = [ + { + "situation": expr.situation, + "style": expr.style, + "count": expr.count, + "last_active_time": expr.last_active_time, + "source_id": expr.chat_id, + "type": "grammar", + "create_date": expr.create_date if expr.create_date is not None else expr.last_active_time, + } for expr in grammar_query + ] + style_num = int(total_num * style_percentage) grammar_num = int(total_num * grammar_percentage) # 按权重抽样(使用count作为权重) @@ -162,7 +168,7 @@ class ExpressionSelector: selected_grammar = [] return selected_style, selected_grammar - def update_expressions_count_batch(self, expressions_to_update: List[Dict[str, str]], increment: float = 0.1): + def update_expressions_count_batch(self, expressions_to_update: List[Dict[str, Any]], increment: float = 0.1): """对一批表达方式更新count值,按chat_id+type分组后一次性写入数据库""" if not expressions_to_update: return @@ -203,7 +209,7 @@ class ExpressionSelector: max_num: int = 10, min_num: int = 5, target_message: Optional[str] = None, - ) -> List[Dict[str, str]]: + ) -> List[Dict[str, Any]]: # sourcery skip: inline-variable, list-comprehension """使用LLM选择适合的表达方式""" diff --git a/src/person_info/relationship_manager.py b/src/person_info/relationship_manager.py index 01cc89e9a..6c2693572 100644 --- a/src/person_info/relationship_manager.py +++ b/src/person_info/relationship_manager.py @@ -139,7 +139,7 @@ class RelationshipManager: 请用json格式输出,引起了你的兴趣,或者有什么需要你记忆的点。 并为每个点赋予1-10的权重,权重越高,表示越重要。 格式如下: -{{ +[ {{ "point": "{person_name}想让我记住他的生日,我回答确认了,他的生日是11月23日", "weight": 10 @@ -156,13 +156,10 @@ class RelationshipManager: "point": "{person_name}喜欢吃辣,具体来说,没有辣的食物ta都不喜欢吃,可能是因为ta是湖南人。", "weight": 7 }} -}} +] -如果没有,就输出none,或points为空: -{{ - "point": "none", - "weight": 0 -}} +如果没有,就输出none,或返回空数组: +[] """ # 调用LLM生成印象 @@ -184,17 +181,25 @@ class RelationshipManager: try: points = repair_json(points) points_data = json.loads(points) - if points_data == "none" or not points_data or points_data.get("point") == "none": + + # 只处理正确的格式,错误格式直接跳过 + if points_data == "none" or not points_data: points_list = [] + elif isinstance(points_data, str) and points_data.lower() == "none": + points_list = [] + elif isinstance(points_data, list): + # 正确格式:数组格式 [{"point": "...", "weight": 10}, ...] + if not points_data: # 空数组 + points_list = [] + else: + points_list = [(item["point"], float(item["weight"]), current_time) for item in points_data] else: - # logger.info(f"points_data: {points_data}") - if isinstance(points_data, dict) and "points" in points_data: - points_data = points_data["points"] - if not isinstance(points_data, list): - points_data = [points_data] - # 添加可读时间到每个point - points_list = [(item["point"], float(item["weight"]), current_time) for item in points_data] + # 错误格式,直接跳过不解析 + logger.warning(f"LLM返回了错误的JSON格式,跳过解析: {type(points_data)}, 内容: {points_data}") + points_list = [] + # 权重过滤逻辑 + if points_list: original_points_list = list(points_list) points_list.clear() discarded_count = 0