diff --git a/src/chat/focus_chat/expressors/exprssion_learner.py b/src/chat/focus_chat/expressors/exprssion_learner.py index 573801719..57f441a42 100644 --- a/src/chat/focus_chat/expressors/exprssion_learner.py +++ b/src/chat/focus_chat/expressors/exprssion_learner.py @@ -283,13 +283,31 @@ class ExpressionLearner: if len(old_data) > MAX_EXPRESSION_COUNT: # 计算每个表达方式的权重(count的倒数,这样count越小的越容易被选中) weights = [1 / (expr.get("count", 1) + 0.1) for expr in old_data] - # 归一化权重 - total_weight = sum(weights) - weights = [w / total_weight for w in weights] - # 随机选择要移除的表达方式 + # 随机选择要移除的表达方式,避免重复索引 remove_count = len(old_data) - MAX_EXPRESSION_COUNT - remove_indices = random.choices(range(len(old_data)), weights=weights, k=remove_count) + + # 使用一种不会选到重复索引的方法 + indices = list(range(len(old_data))) + + # 方法1:使用numpy.random.choice + # 把列表转成一个映射字典,保证不会有重复 + remove_set = set() + total_attempts = 0 + + # 尝试按权重随机选择,直到选够数量 + while len(remove_set) < remove_count and total_attempts < len(old_data) * 2: + idx = random.choices(indices, weights=weights, k=1)[0] + remove_set.add(idx) + total_attempts += 1 + + # 如果没选够,随机补充 + if len(remove_set) < remove_count: + remaining = set(indices) - remove_set + remove_set.update(random.sample(remaining, remove_count - len(remove_set))) + + remove_indices = list(remove_set) + # 从后往前删除,避免索引变化 for idx in sorted(remove_indices, reverse=True): old_data.pop(idx) diff --git a/src/chat/focus_chat/info_processors/relationship_processor.py b/src/chat/focus_chat/info_processors/relationship_processor.py index 656f01a0f..257594711 100644 --- a/src/chat/focus_chat/info_processors/relationship_processor.py +++ b/src/chat/focus_chat/info_processors/relationship_processor.py @@ -146,9 +146,9 @@ class RelationshipProcessor(BaseProcessor): time_elapsed = current_time - record["start_time"] message_count = len(get_raw_msg_by_timestamp_with_chat(self.subheartflow_id, record["start_time"], current_time)) - if (record["rounds"] > 20 or + if (record["rounds"] > 50 or time_elapsed > 1800 or # 30分钟 - message_count > 50): + message_count > 75): logger.info(f"{self.log_prefix} 用户 {record['person_id']} 满足关系构建条件,开始构建关系。") asyncio.create_task( self.update_impression_on_cache_expiry( diff --git a/src/chat/focus_chat/planners/actions/no_reply_complex_action.py b/src/chat/focus_chat/planners/actions/no_reply_complex_action.py deleted file mode 100644 index 120ebe981..000000000 --- a/src/chat/focus_chat/planners/actions/no_reply_complex_action.py +++ /dev/null @@ -1,134 +0,0 @@ -import asyncio -import traceback -from src.common.logger_manager import get_logger -from src.chat.utils.timer_calculator import Timer -from src.chat.focus_chat.planners.actions.base_action import BaseAction, register_action -from typing import Tuple, List -from src.chat.heart_flow.observation.observation import Observation -from src.chat.heart_flow.observation.chatting_observation import ChattingObservation -from src.chat.focus_chat.hfc_utils import parse_thinking_id_to_timestamp - -logger = get_logger("action_taken") - -# 常量定义 -WAITING_TIME_THRESHOLD = 1200 # 等待新消息时间阈值,单位秒 - - -@register_action -class NoReplyAction(BaseAction): - """不回复动作处理类 - - 处理决定不回复的动作。 - """ - - action_name = "no_reply" - action_description = "不回复" - action_parameters = {} - action_require = [ - "话题无关/无聊/不感兴趣/不懂", - "聊天记录中最新一条消息是你自己发的且无人回应你", - "你连续发送了太多消息,且无人回复", - ] - default = True - - def __init__( - self, - action_data: dict, - reasoning: str, - cycle_timers: dict, - thinking_id: str, - observations: List[Observation], - log_prefix: str, - shutting_down: bool = False, - **kwargs, - ): - """初始化不回复动作处理器 - - Args: - action_name: 动作名称 - action_data: 动作数据 - reasoning: 执行该动作的理由 - cycle_timers: 计时器字典 - thinking_id: 思考ID - observations: 观察列表 - log_prefix: 日志前缀 - shutting_down: 是否正在关闭 - """ - super().__init__(action_data, reasoning, cycle_timers, thinking_id) - self.observations = observations - self.log_prefix = log_prefix - self._shutting_down = shutting_down - - async def handle_action(self) -> Tuple[bool, str]: - """ - 处理不回复的情况 - - 工作流程: - 1. 等待新消息、超时或关闭信号 - 2. 根据等待结果更新连续不回复计数 - 3. 如果达到阈值,触发回调 - - Returns: - Tuple[bool, str]: (是否执行成功, 空字符串) - """ - logger.info(f"{self.log_prefix} 决定不回复: {self.reasoning}") - - observation = self.observations[0] if self.observations else None - - try: - with Timer("等待新消息", self.cycle_timers): - # 等待新消息、超时或关闭信号,并获取结果 - await self._wait_for_new_message(observation, self.thinking_id, self.log_prefix) - - return True, "" # 不回复动作没有回复文本 - - except asyncio.CancelledError: - logger.info(f"{self.log_prefix} 处理 'no_reply' 时等待被中断 (CancelledError)") - raise - except Exception as e: # 捕获调用管理器或其他地方可能发生的错误 - logger.error(f"{self.log_prefix} 处理 'no_reply' 时发生错误: {e}") - logger.error(traceback.format_exc()) - return False, "" - - async def _wait_for_new_message(self, observation: ChattingObservation, thinking_id: str, log_prefix: str) -> bool: - """ - 等待新消息 或 检测到关闭信号 - - 参数: - observation: 观察实例 - thinking_id: 思考ID - log_prefix: 日志前缀 - - 返回: - bool: 是否检测到新消息 (如果因关闭信号退出则返回 False) - """ - wait_start_time = asyncio.get_event_loop().time() - while True: - # --- 在每次循环开始时检查关闭标志 --- - if self._shutting_down: - logger.info(f"{log_prefix} 等待新消息时检测到关闭信号,中断等待。") - return False # 表示因为关闭而退出 - # ----------------------------------- - - thinking_id_timestamp = parse_thinking_id_to_timestamp(thinking_id) - - # 检查新消息 - if await observation.has_new_messages_since(thinking_id_timestamp): - logger.info(f"{log_prefix} 检测到新消息") - return True - - # 检查超时 (放在检查新消息和关闭之后) - if asyncio.get_event_loop().time() - wait_start_time > WAITING_TIME_THRESHOLD: - logger.warning(f"{log_prefix} 等待新消息超时({WAITING_TIME_THRESHOLD}秒)") - return False - - try: - # 短暂休眠,让其他任务有机会运行,并能更快响应取消或关闭 - await asyncio.sleep(0.5) # 缩短休眠时间 - except asyncio.CancelledError: - # 如果在休眠时被取消,再次检查关闭标志 - # 如果是正常关闭,则不需要警告 - if not self._shutting_down: - logger.warning(f"{log_prefix} _wait_for_new_message 的休眠被意外取消") - # 无论如何,重新抛出异常,让上层处理 - raise diff --git a/src/chat/focus_chat/planners/actions/plugin_action.py b/src/chat/focus_chat/planners/actions/plugin_action.py index d0c345718..bacd143d4 100644 --- a/src/chat/focus_chat/planners/actions/plugin_action.py +++ b/src/chat/focus_chat/planners/actions/plugin_action.py @@ -1,5 +1,5 @@ import traceback -from typing import Tuple, Dict, List, Any, Optional +from typing import Tuple, Dict, List, Any, Optional, Union, Type from src.chat.focus_chat.planners.actions.base_action import BaseAction, register_action # noqa F401 from src.chat.heart_flow.observation.chatting_observation import ChattingObservation from src.chat.focus_chat.hfc_utils import create_empty_anchor_message @@ -12,6 +12,9 @@ import os import inspect import toml # 导入 toml 库 from src.common.database.database_model import ActionRecords +from src.common.database.database import db +from peewee import Model, DoesNotExist +import json import time # 以下为类型注解需要 @@ -434,3 +437,332 @@ class PluginAction(BaseAction): except Exception as e: logger.error(f"{self.log_prefix} 存储action信息时出错: {e}") traceback.print_exc() + + async def db_query( + self, + model_class: Type[Model], + query_type: str = "get", + filters: Dict[str, Any] = None, + data: Dict[str, Any] = None, + limit: int = None, + order_by: List[str] = None, + single_result: bool = False + ) -> Union[List[Dict[str, Any]], Dict[str, Any], None]: + """执行数据库查询操作 + + 这个方法提供了一个通用接口来执行数据库操作,包括查询、创建、更新和删除记录。 + + Args: + model_class: Peewee 模型类,例如 ActionRecords, Messages 等 + query_type: 查询类型,可选值: "get", "create", "update", "delete", "count" + filters: 过滤条件字典,键为字段名,值为要匹配的值 + data: 用于创建或更新的数据字典 + limit: 限制结果数量 + order_by: 排序字段列表,使用字段名,前缀'-'表示降序 + single_result: 是否只返回单个结果 + + Returns: + 根据查询类型返回不同的结果: + - "get": 返回查询结果列表或单个结果(如果 single_result=True) + - "create": 返回创建的记录 + - "update": 返回受影响的行数 + - "delete": 返回受影响的行数 + - "count": 返回记录数量 + + 示例: + # 查询最近10条消息 + messages = await self.db_query( + Messages, + query_type="get", + filters={"chat_id": chat_stream.stream_id}, + limit=10, + order_by=["-time"] + ) + + # 创建一条记录 + new_record = await self.db_query( + ActionRecords, + query_type="create", + data={"action_id": "123", "time": time.time(), "action_name": "TestAction"} + ) + + # 更新记录 + updated_count = await self.db_query( + ActionRecords, + query_type="update", + filters={"action_id": "123"}, + data={"action_done": True} + ) + + # 删除记录 + deleted_count = await self.db_query( + ActionRecords, + query_type="delete", + filters={"action_id": "123"} + ) + + # 计数 + count = await self.db_query( + Messages, + query_type="count", + filters={"chat_id": chat_stream.stream_id} + ) + """ + try: + # 构建基本查询 + if query_type in ["get", "update", "delete", "count"]: + query = model_class.select() + + # 应用过滤条件 + if filters: + for field, value in filters.items(): + query = query.where(getattr(model_class, field) == value) + + # 执行查询 + if query_type == "get": + # 应用排序 + if order_by: + for field in order_by: + if field.startswith("-"): + query = query.order_by(getattr(model_class, field[1:]).desc()) + else: + query = query.order_by(getattr(model_class, field)) + + # 应用限制 + if limit: + query = query.limit(limit) + + # 执行查询 + results = list(query.dicts()) + + # 返回结果 + if single_result: + return results[0] if results else None + return results + + elif query_type == "create": + if not data: + raise ValueError("创建记录需要提供data参数") + + # 创建记录 + record = model_class.create(**data) + # 返回创建的记录 + return model_class.select().where(model_class.id == record.id).dicts().get() + + elif query_type == "update": + if not data: + raise ValueError("更新记录需要提供data参数") + + # 更新记录 + return query.update(**data).execute() + + elif query_type == "delete": + # 删除记录 + return query.delete().execute() + + elif query_type == "count": + # 计数 + return query.count() + + else: + raise ValueError(f"不支持的查询类型: {query_type}") + + except DoesNotExist: + # 记录不存在 + if query_type == "get" and single_result: + return None + return [] + + except Exception as e: + logger.error(f"{self.log_prefix} 数据库操作出错: {e}") + traceback.print_exc() + + # 根据查询类型返回合适的默认值 + if query_type == "get": + return None if single_result else [] + elif query_type in ["create", "update", "delete", "count"]: + return None + + async def db_raw_query( + self, + sql: str, + params: List[Any] = None, + fetch_results: bool = True + ) -> Union[List[Dict[str, Any]], int, None]: + """执行原始SQL查询 + + 警告: 使用此方法需要小心,确保SQL语句已正确构造以避免SQL注入风险。 + + Args: + sql: 原始SQL查询字符串 + params: 查询参数列表,用于替换SQL中的占位符 + fetch_results: 是否获取查询结果,对于SELECT查询设为True,对于 + UPDATE/INSERT/DELETE等操作设为False + + Returns: + 如果fetch_results为True,返回查询结果列表; + 如果fetch_results为False,返回受影响的行数; + 如果出错,返回None + """ + try: + cursor = db.execute_sql(sql, params or []) + + if fetch_results: + # 获取列名 + columns = [col[0] for col in cursor.description] + + # 构建结果字典列表 + results = [] + for row in cursor.fetchall(): + results.append(dict(zip(columns, row))) + + return results + else: + # 返回受影响的行数 + return cursor.rowcount + + except Exception as e: + logger.error(f"{self.log_prefix} 执行原始SQL查询出错: {e}") + traceback.print_exc() + return None + + async def db_save( + self, + model_class: Type[Model], + data: Dict[str, Any], + key_field: str = None, + key_value: Any = None + ) -> Union[Dict[str, Any], None]: + """保存数据到数据库(创建或更新) + + 如果提供了key_field和key_value,会先尝试查找匹配的记录进行更新; + 如果没有找到匹配记录,或未提供key_field和key_value,则创建新记录。 + + Args: + model_class: Peewee模型类,如ActionRecords, Messages等 + data: 要保存的数据字典 + key_field: 用于查找现有记录的字段名,例如"action_id" + key_value: 用于查找现有记录的字段值 + + Returns: + Dict[str, Any]: 保存后的记录数据 + None: 如果操作失败 + + 示例: + # 创建或更新一条记录 + record = await self.db_save( + ActionRecords, + { + "action_id": "123", + "time": time.time(), + "action_name": "TestAction", + "action_done": True + }, + key_field="action_id", + key_value="123" + ) + """ + try: + # 如果提供了key_field和key_value,尝试更新现有记录 + if key_field and key_value is not None: + # 查找现有记录 + existing_records = list(model_class.select().where( + getattr(model_class, key_field) == key_value + ).limit(1)) + + if existing_records: + # 更新现有记录 + existing_record = existing_records[0] + for field, value in data.items(): + setattr(existing_record, field, value) + existing_record.save() + + # 返回更新后的记录 + updated_record = model_class.select().where( + model_class.id == existing_record.id + ).dicts().get() + return updated_record + + # 如果没有找到现有记录或未提供key_field和key_value,创建新记录 + new_record = model_class.create(**data) + + # 返回创建的记录 + created_record = model_class.select().where( + model_class.id == new_record.id + ).dicts().get() + return created_record + + except Exception as e: + logger.error(f"{self.log_prefix} 保存数据库记录出错: {e}") + traceback.print_exc() + return None + + async def db_get( + self, + model_class: Type[Model], + filters: Dict[str, Any] = None, + order_by: str = None, + limit: int = None + ) -> Union[List[Dict[str, Any]], Dict[str, Any], None]: + """从数据库获取记录 + + 这是db_query方法的简化版本,专注于数据检索操作。 + + Args: + model_class: Peewee模型类 + filters: 过滤条件,字段名和值的字典 + order_by: 排序字段,前缀'-'表示降序,例如'-time'表示按时间降序 + limit: 结果数量限制,如果为1则返回单个记录而不是列表 + + Returns: + 如果limit=1,返回单个记录字典或None; + 否则返回记录字典列表或空列表。 + + 示例: + # 获取单个记录 + record = await self.db_get( + ActionRecords, + filters={"action_id": "123"}, + limit=1 + ) + + # 获取最近10条记录 + records = await self.db_get( + Messages, + filters={"chat_id": chat_stream.stream_id}, + order_by="-time", + limit=10 + ) + """ + try: + # 构建查询 + query = model_class.select() + + # 应用过滤条件 + if filters: + for field, value in filters.items(): + query = query.where(getattr(model_class, field) == value) + + # 应用排序 + if order_by: + if order_by.startswith("-"): + query = query.order_by(getattr(model_class, order_by[1:]).desc()) + else: + query = query.order_by(getattr(model_class, order_by)) + + # 应用限制 + if limit: + query = query.limit(limit) + + # 执行查询 + results = list(query.dicts()) + + # 返回结果 + if limit == 1: + return results[0] if results else None + return results + + except Exception as e: + logger.error(f"{self.log_prefix} 获取数据库记录出错: {e}") + traceback.print_exc() + return None if limit == 1 else [] diff --git a/src/chat/focus_chat/planners/actions/reply_action.py b/src/chat/focus_chat/planners/actions/reply_action.py index 1045902a5..dafbca42d 100644 --- a/src/chat/focus_chat/planners/actions/reply_action.py +++ b/src/chat/focus_chat/planners/actions/reply_action.py @@ -32,6 +32,7 @@ class ReplyAction(BaseAction): action_require: list[str] = [ "你想要闲聊或者随便附和", "有人提到你", + "如果你刚刚回复,不要对同一个话题重复回应" ] associated_types: list[str] = ["text", "emoji"] diff --git a/src/person_info/relationship_manager.py b/src/person_info/relationship_manager.py index 8d6e95730..a3958b95e 100644 --- a/src/person_info/relationship_manager.py +++ b/src/person_info/relationship_manager.py @@ -241,7 +241,8 @@ class RelationshipManager: readable_messages = readable_messages.replace(f"{original_name}", f"{mapped_name}") prompt = f""" -你的名字是{global_config.bot.nickname},别名是{alias_str}。 +你的名字是{global_config.bot.nickname},{global_config.bot.nickname}的别名是{alias_str}。 +请不要混淆你自己和{global_config.bot.nickname}和{person_name}。 请你基于用户 {person_name}(昵称:{nickname}) 的最近发言,总结出其中是否有有关{person_name}的内容引起了你的兴趣,或者有什么需要你记忆的点。 如果没有,就输出none @@ -432,8 +433,10 @@ class RelationshipManager: impression = await person_info_manager.get_value(person_id, "impression") or "" compress_prompt = f""" -你的名字是{global_config.bot.nickname},别名是{alias_str}。 -请根据以下历史记录,添加,修改,整合,原有的印象和关系,总结出对{person_name}(昵称:{nickname})的信息。 +你的名字是{global_config.bot.nickname},{global_config.bot.nickname}的别名是{alias_str}。 +请不要混淆你自己和{global_config.bot.nickname}和{person_name}。 + +请根据以下历史记录,添加,修改,整合,原有的印象和关系,总结出对用户 {person_name}(昵称:{nickname})的信息。 你之前对他的印象和关系是: 印象impression:{impression}