diff --git a/scripts/analyze_expressions.py b/scripts/analyze_expressions.py index 79490e6e5..87d91fa3b 100644 --- a/scripts/analyze_expressions.py +++ b/scripts/analyze_expressions.py @@ -198,4 +198,4 @@ def analyze_expressions(): print(f"各群组详细报告位于: {output_dir}") if __name__ == "__main__": - analyze_expressions() \ No newline at end of file + analyze_expressions() diff --git a/src/chat/focus_chat/expressors/exprssion_learner.py b/src/chat/focus_chat/expressors/exprssion_learner.py index f4107a459..8332fac09 100644 --- a/src/chat/focus_chat/expressors/exprssion_learner.py +++ b/src/chat/focus_chat/expressors/exprssion_learner.py @@ -29,13 +29,13 @@ def init_prompt() -> None: 4. 思考有没有特殊的梗,一并总结成语言风格 5. 例子仅供参考,请严格根据群聊内容总结!!! 注意:总结成如下格式的规律,总结的内容要详细,但具有概括性: -当"xxx"时,可以"xxx", xxx不超过10个字 +当"xxxxxx"时,可以"xxxxxx", xxxxxx不超过20个字 例如: -当"表示十分惊叹,有些意外"时,使用"我嘞个xxxx" +当"对某件事表示十分惊叹,有些意外"时,使用"我嘞个xxxx" 当"表示讽刺的赞同,不想讲道理"时,使用"对对对" -当"想说明某个观点,但懒得明说,或者不便明说",使用"懂的都懂" -当"表示意外的夸赞,略带戏谑意味"时,使用"这么强!" +当"想说明某个具体的事实观点,但懒得明说,或者不便明说,或表达一种默契",使用"懂的都懂" +当"当涉及游戏相关时,表示意外的夸赞,略带戏谑意味"时,使用"这么强!" 注意不要总结你自己(SELF)的发言 现在请你概括 diff --git a/src/chat/focus_chat/info_processors/relationship_processor.py b/src/chat/focus_chat/info_processors/relationship_processor.py index c96e5a645..e6eaf32b6 100644 --- a/src/chat/focus_chat/info_processors/relationship_processor.py +++ b/src/chat/focus_chat/info_processors/relationship_processor.py @@ -13,27 +13,39 @@ from typing import List, Optional from typing import Dict from src.chat.focus_chat.info.info_base import InfoBase from src.chat.focus_chat.info.relation_info import RelationInfo +from json_repair import repair_json +from src.person_info.person_info import person_info_manager +import json logger = get_logger("processor") def init_prompt(): relationship_prompt = """ -{name_block} - -你和别人的关系信息是,请从这些信息中提取出你和别人的关系的原文: -{relation_prompt} -请只从上面这些信息中提取出内容。 - +<聊天记录> {chat_observe_info} + -现在请你根据现有的信息,总结你和群里的人的关系 -1. 根据聊天记录的需要,精简你和其他人的关系并输出 -2. 根据聊天记录,如果需要提及你和某个人的关系,请输出你和这个人之间的关系 -3. 如果没有特别需要提及的关系,就不用输出这个人的关系 +<人物信息> +{relation_prompt} + -输出内容平淡一些,说中文。 -请注意不要输出多余内容(包括前后缀,括号(),表情包,at或 @等 )。只输出关系内容,记得明确说明这是你的关系。 +请区分聊天记录的内容和你之前对人的了解,聊天记录是现在发生的事情,人物信息是之前对某个人的持久的了解。 + +{name_block} +现在请你总结提取某人的信息,提取成一串文本 +1. 根据聊天记录的需求,如果需要你和某个人的信息,请输出你和这个人之间精简的信息 +2. 如果没有特别需要提及的信息,就不用输出这个人的信息 +3. 
如果有人问你对他的看法或者关系,请输出你和这个人之间的信息 + +请从这些信息中提取出你对某人的了解信息,信息提取成一串文本: + +请严格按照以下输出格式,不要输出多余内容,person_name可以有多个: +{{ + "person_name": "信息", + "person_name2": "信息", + "person_name3": "信息", +}} """ Prompt(relationship_prompt, "relationship_prompt") @@ -122,8 +134,10 @@ class RelationshipProcessor(BaseProcessor): relation_prompt_init = "你对对方的印象是:\n" relation_prompt = "" + person_name_list = [] for person in person_list: - relation_prompt += f"{await relationship_manager.build_relationship_info(person, is_id=True)}\n" + relation_prompt += f"{await relationship_manager.build_relationship_info(person, is_id=True)}\n\n" + person_name_list.append(await person_info_manager.get_value(person, "person_name")) if relation_prompt: relation_prompt = relation_prompt_init + relation_prompt @@ -141,22 +155,41 @@ class RelationshipProcessor(BaseProcessor): content = "" try: + logger.info(f"{self.log_prefix} 关系识别prompt: \n{prompt}\n") content, _ = await self.llm_model.generate_response_async(prompt=prompt) if not content: logger.warning(f"{self.log_prefix} LLM返回空结果,关系识别失败。") + + print(f"content: {content}") + + content = repair_json(content) + content = json.loads(content) + + person_info_str = "" + + for person_name, person_info in content.items(): + # print(f"person_name: {person_name}, person_info: {person_info}") + # print(f"person_list: {person_name_list}") + if person_name not in person_name_list: + continue + person_str = f"你对 {person_name} 的了解:{person_info}\n" + person_info_str += person_str + + except Exception as e: # 处理总体异常 logger.error(f"{self.log_prefix} 执行LLM请求或处理响应时出错: {e}") logger.error(traceback.format_exc()) - content = "关系识别过程中出现错误" + person_info_str = "关系识别过程中出现错误" - if content == "None": - content = "" + if person_info_str == "None": + person_info_str = "" + # 记录初步思考结果 - logger.info(f"{self.log_prefix} 关系识别prompt: \n{prompt}\n") - logger.info(f"{self.log_prefix} 关系识别: {content}") + + logger.info(f"{self.log_prefix} 关系识别: {person_info_str}") - return content + return person_info_str init_prompt() diff --git a/src/chat/focus_chat/planners/planner_simple.py b/src/chat/focus_chat/planners/planner_simple.py index 48bb21d3a..518d21126 100644 --- a/src/chat/focus_chat/planners/planner_simple.py +++ b/src/chat/focus_chat/planners/planner_simple.py @@ -31,8 +31,6 @@ def init_prompt(): {self_info_block} 请记住你的性格,身份和特点。 -{relation_info_block} - {extra_info_block} {memory_str} @@ -42,6 +40,8 @@ def init_prompt(): {chat_content_block} +{relation_info_block} + {cycle_info_block} {moderation_prompt} @@ -181,7 +181,7 @@ class ActionPlanner(BasePlanner): prompt = f"{prompt}" llm_content, (reasoning_content, _) = await self.planner_llm.generate_response_async(prompt=prompt) - logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}") + logger.debug(f"{self.log_prefix}规划器原始提示词: {prompt}") logger.info(f"{self.log_prefix}规划器原始响应: {llm_content}") logger.info(f"{self.log_prefix}规划器推理: {reasoning_content}") @@ -225,7 +225,10 @@ class ActionPlanner(BasePlanner): extra_info_block = "" action_data["extra_info_block"] = extra_info_block - + + if relation_info: + action_data["relation_info_block"] = relation_info + # 对于reply动作不需要额外处理,因为相关字段已经在上面的循环中添加到action_data if extracted_action not in current_available_actions: diff --git a/src/chat/focus_chat/replyer/default_replyer.py b/src/chat/focus_chat/replyer/default_replyer.py index 07c41070c..e205e1fd4 100644 --- a/src/chat/focus_chat/replyer/default_replyer.py +++ b/src/chat/focus_chat/replyer/default_replyer.py @@ -41,6 +41,8 @@ def init_prompt(): 你现在正在群里聊天,以下是群里正在进行的聊天内容: 
{chat_info} +{relation_info_block} + 以上是聊天内容,你需要了解聊天记录中的内容 {chat_target} @@ -262,6 +264,7 @@ class DefaultReplyer: target_message = action_data.get("target", "") identity = action_data.get("identity", "") extra_info_block = action_data.get("extra_info_block", "") + relation_info_block = action_data.get("relation_info_block", "") # 3. 构建 Prompt with Timer("构建Prompt", {}): # 内部计时器,可选保留 @@ -270,6 +273,7 @@ class DefaultReplyer: # in_mind_reply=in_mind_reply, identity=identity, extra_info_block=extra_info_block, + relation_info_block=relation_info_block, reason=reason, sender_name=sender_name_for_prompt, # Pass determined name target_message=target_message, @@ -286,8 +290,7 @@ class DefaultReplyer: try: with Timer("LLM生成", {}): # 内部计时器,可选保留 - # TODO: API-Adapter修改标记 - # logger.info(f"{self.log_prefix}[Replier-{thinking_id}]\nPrompt:\n{prompt}\n") + logger.info(f"{self.log_prefix}Prompt:\n{prompt}\n") content, (reasoning_content, model_name) = await self.express_model.generate_response_async(prompt) # logger.info(f"prompt: {prompt}") @@ -331,6 +334,7 @@ class DefaultReplyer: sender_name, # in_mind_reply, extra_info_block, + relation_info_block, identity, target_message, config_expression_style, @@ -428,6 +432,7 @@ class DefaultReplyer: chat_target=chat_target_1, chat_info=chat_talking_prompt, extra_info_block=extra_info_block, + relation_info_block=relation_info_block, time_block=time_block, # bot_name=global_config.bot.nickname, # prompt_personality="", @@ -448,6 +453,7 @@ class DefaultReplyer: chat_target=chat_target_1, chat_info=chat_talking_prompt, extra_info_block=extra_info_block, + relation_info_block=relation_info_block, time_block=time_block, # bot_name=global_config.bot.nickname, # prompt_personality="", diff --git a/src/person_info/fix_points_format.py b/src/person_info/fix_points_format.py new file mode 100644 index 000000000..96134555d --- /dev/null +++ b/src/person_info/fix_points_format.py @@ -0,0 +1,70 @@ +import os +import sys +# 添加项目根目录到Python路径 +current_dir = os.path.dirname(os.path.abspath(__file__)) +project_root = os.path.dirname(os.path.dirname(current_dir)) +sys.path.append(project_root) + +from loguru import logger +import json +from src.common.database.database_model import PersonInfo + +def fix_points_format(): + """修复数据库中的points和forgotten_points格式""" + fixed_count = 0 + error_count = 0 + + try: + # 获取所有用户 + all_persons = PersonInfo.select() + + for person in all_persons: + try: + # 修复points + if person.points: + try: + # 尝试解析JSON + points_data = json.loads(person.points) + except json.JSONDecodeError: + logger.error(f"无法解析points数据: {person.points}") + points_data = [] + + # 确保数据是列表格式 + if not isinstance(points_data, list): + points_data = [] + + # 直接更新数据库 + person.points = json.dumps(points_data, ensure_ascii=False) + person.save() + fixed_count += 1 + + # 修复forgotten_points + if person.forgotten_points: + try: + # 尝试解析JSON + forgotten_data = json.loads(person.forgotten_points) + except json.JSONDecodeError: + logger.error(f"无法解析forgotten_points数据: {person.forgotten_points}") + forgotten_data = [] + + # 确保数据是列表格式 + if not isinstance(forgotten_data, list): + forgotten_data = [] + + # 直接更新数据库 + person.forgotten_points = json.dumps(forgotten_data, ensure_ascii=False) + person.save() + fixed_count += 1 + + except Exception as e: + logger.error(f"处理用户 {person.person_id} 时出错: {str(e)}") + error_count += 1 + continue + + logger.info(f"修复完成!成功修复 {fixed_count} 条记录,失败 {error_count} 条记录") + + except Exception as e: + logger.error(f"数据库操作出错: {str(e)}") + +if __name__ == 
"__main__": + fix_points_format() \ No newline at end of file diff --git a/src/person_info/person_info.py b/src/person_info/person_info.py index 89ed7470e..8f5b6e2f7 100644 --- a/src/person_info/person_info.py +++ b/src/person_info/person_info.py @@ -8,7 +8,6 @@ import datetime import asyncio from src.llm_models.utils_model import LLMRequest from src.config.config import global_config -from src.individuality.individuality import individuality import json # 新增导入 from json_repair import repair_json @@ -29,7 +28,7 @@ PersonInfoManager 类方法功能摘要: logger = get_logger("person_info") -JSON_SERIALIZED_FIELDS = ["hobby", "hates", "meme", "relationship_others", "interaction"] +JSON_SERIALIZED_FIELDS = ["points", "forgotten_points"] person_info_default = { "person_id": None, @@ -169,9 +168,9 @@ class PersonInfoManager: processed_value = value if field_name in JSON_SERIALIZED_FIELDS: if isinstance(value, (list, dict)): - processed_value = json.dumps(value, ensure_ascii=False) + processed_value = json.dumps(value, ensure_ascii=False, indent=None) elif value is None: # Store None as "[]" for JSON list fields - processed_value = json.dumps([], ensure_ascii=False) + processed_value = json.dumps([], ensure_ascii=False, indent=None) # If value is already a string, assume it's pre-serialized or a non-JSON string. def _db_update_sync(p_id: str, f_name: str, val_to_set): @@ -280,7 +279,7 @@ class PersonInfoManager: while current_try < max_retries: # prompt_personality = individuality.get_prompt(x_person=2, level=1) - bot_name = individuality.personality.bot_nickname + bot_name = global_config.bot.nickname qv_name_prompt = f"你是{bot_name},一个伪装成人类的AI,你不能让别人发现这一点," qv_name_prompt += f"现在你想给一个用户取一个昵称,用户的qq昵称是{user_nickname}," diff --git a/src/person_info/relationship_manager.py b/src/person_info/relationship_manager.py index fce928231..54391d15b 100644 --- a/src/person_info/relationship_manager.py +++ b/src/person_info/relationship_manager.py @@ -134,7 +134,7 @@ class RelationshipManager: except (SyntaxError, ValueError): points = [] - random_points = random.sample(points, min(3, len(points))) if points else [] + random_points = random.sample(points, min(5, len(points))) if points else [] nickname_str = await person_info_manager.get_value(person_id, "nickname") platform = await person_info_manager.get_value(person_id, "platform") @@ -312,13 +312,14 @@ class RelationshipManager: current_points = await person_info_manager.get_value(person_id, "points") or [] if isinstance(current_points, str): try: - current_points = ast.literal_eval(current_points) - except (SyntaxError, ValueError): + current_points = json.loads(current_points) + except json.JSONDecodeError: + logger.error(f"解析points JSON失败: {current_points}") current_points = [] elif not isinstance(current_points, list): current_points = [] current_points.extend(points_list) - await person_info_manager.update_one_field(person_id, "points", str(current_points).replace("(", "[").replace(")", "]")) + await person_info_manager.update_one_field(person_id, "points", json.dumps(current_points, ensure_ascii=False, indent=None)) # 将新记录添加到现有记录中 if isinstance(current_points, list): @@ -365,8 +366,9 @@ class RelationshipManager: forgotten_points = await person_info_manager.get_value(person_id, "forgotten_points") or [] if isinstance(forgotten_points, str): try: - forgotten_points = ast.literal_eval(forgotten_points) - except (SyntaxError, ValueError): + forgotten_points = json.loads(forgotten_points) + except json.JSONDecodeError: + logger.error(f"解析forgotten_points JSON失败: 
{forgotten_points}") forgotten_points = [] elif not isinstance(forgotten_points, list): forgotten_points = [] @@ -487,10 +489,12 @@ class RelationshipManager: return # 更新数据库 - await person_info_manager.update_one_field(person_id, "forgotten_points", str(forgotten_points).replace("(", "[").replace(")", "]")) + await person_info_manager.update_one_field(person_id, "forgotten_points", json.dumps(forgotten_points, ensure_ascii=False, indent=None)) # 更新数据库 - await person_info_manager.update_one_field(person_id, "points", str(current_points).replace("(", "[").replace(")", "]")) + await person_info_manager.update_one_field(person_id, "points", json.dumps(current_points, ensure_ascii=False, indent=None)) + know_times = await person_info_manager.get_value(person_id, "know_times") or 0 + await person_info_manager.update_one_field(person_id, "know_times", know_times + 1) await person_info_manager.update_one_field(person_id, "last_know", timestamp) diff --git a/tests/test_relationship_processor.py b/tests/test_relationship_processor.py new file mode 100644 index 000000000..b87d7832e --- /dev/null +++ b/tests/test_relationship_processor.py @@ -0,0 +1,609 @@ +import os +import sys +import asyncio +import random +import time +import traceback +from typing import List, Dict, Any, Tuple, Optional +from datetime import datetime + +# 添加项目根目录到Python路径 +current_dir = os.path.dirname(os.path.abspath(__file__)) +project_root = os.path.dirname(current_dir) +sys.path.append(project_root) + +from src.common.message_repository import find_messages +from src.common.database.database_model import ActionRecords, ChatStreams +from src.config.config import global_config +from src.person_info.person_info import person_info_manager +from src.chat.utils.utils import translate_timestamp_to_human_readable +from src.chat.heart_flow.observation.observation import Observation +from src.llm_models.utils_model import LLMRequest +from src.chat.utils.prompt_builder import Prompt, global_prompt_manager +from src.person_info.relationship_manager import relationship_manager +from src.common.logger_manager import get_logger +from src.chat.focus_chat.info.info_base import InfoBase +from src.chat.focus_chat.info.relation_info import RelationInfo + +logger = get_logger("processor") + +async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]: + """ + 从消息列表中提取不重复的 person_id 列表 (忽略机器人自身)。 + + Args: + messages: 消息字典列表。 + + Returns: + 一个包含唯一 person_id 的列表。 + """ + person_ids_set = set() # 使用集合来自动去重 + + for msg in messages: + platform = msg.get("user_platform") + user_id = msg.get("user_id") + + # 检查必要信息是否存在 且 不是机器人自己 + if not all([platform, user_id]) or user_id == global_config.bot.qq_account: + continue + + person_id = person_info_manager.get_person_id(platform, user_id) + + # 只有当获取到有效 person_id 时才添加 + if person_id: + person_ids_set.add(person_id) + + return list(person_ids_set) # 将集合转换为列表返回 + +class ChattingObservation(Observation): + def __init__(self, chat_id): + super().__init__(chat_id) + self.chat_id = chat_id + self.platform = "qq" + + # 从数据库获取聊天类型和目标信息 + chat_info = ChatStreams.select().where(ChatStreams.stream_id == chat_id).first() + self.is_group_chat = True + self.chat_target_info = { + "person_name": chat_info.group_name if chat_info else None, + "user_nickname": chat_info.group_name if chat_info else None + } + + # 初始化其他属性 + self.talking_message = [] + self.talking_message_str = "" + self.talking_message_str_truncate = "" + self.name = global_config.bot.nickname + self.nick_name = global_config.bot.alias_names + 
self.max_now_obs_len = global_config.focus_chat.observation_context_size + self.overlap_len = global_config.focus_chat.compressed_length + self.mid_memories = [] + self.max_mid_memory_len = global_config.focus_chat.compress_length_limit + self.mid_memory_info = "" + self.person_list = [] + self.oldest_messages = [] + self.oldest_messages_str = "" + self.compressor_prompt = "" + self.last_observe_time = 0 + + def get_observe_info(self, ids=None): + """获取观察信息""" + return self.talking_message_str + +def init_prompt(): + relationship_prompt = """ +<聊天记录> +{chat_observe_info} + + +<人物信息> +{relation_prompt} + + +请区分聊天记录的内容和你之前对人的了解,聊天记录是现在发生的事情,人物信息是之前对某个人的持久的了解。 + +{name_block} +现在请你总结提取某人的信息,提取成一串文本 +1. 根据聊天记录的需求,如果需要你和某个人的信息,请输出你和这个人之间精简的信息 +2. 如果没有特别需要提及的信息,就不用输出这个人的信息 +3. 如果有人问你对他的看法或者关系,请输出你和这个人之间的信息 + +请从这些信息中提取出你对某人的了解信息,信息提取成一串文本: + +请严格按照以下输出格式,不要输出多余内容,person_name可以有多个: +{{ + "person_name": "信息", + "person_name2": "信息", + "person_name3": "信息", +}} + +""" + Prompt(relationship_prompt, "relationship_prompt") + +class RelationshipProcessor: + log_prefix = "关系" + + def __init__(self, subheartflow_id: str): + self.subheartflow_id = subheartflow_id + + self.llm_model = LLMRequest( + model=global_config.model.relation, + max_tokens=800, + request_type="relation", + ) + + # 直接从数据库获取名称 + chat_info = ChatStreams.select().where(ChatStreams.stream_id == subheartflow_id).first() + name = chat_info.group_name if chat_info else "未知" + self.log_prefix = f"[{name}] " + + async def process_info( + self, observations: Optional[List[Observation]] = None, running_memorys: Optional[List[Dict]] = None, *infos + ) -> List[InfoBase]: + """处理信息对象 + + Args: + *infos: 可变数量的InfoBase类型的信息对象 + + Returns: + List[InfoBase]: 处理后的结构化信息列表 + """ + relation_info_str = await self.relation_identify(observations) + + if relation_info_str: + relation_info = RelationInfo() + relation_info.set_relation_info(relation_info_str) + else: + relation_info = None + return None + + return [relation_info] + + async def relation_identify( + self, observations: Optional[List[Observation]] = None, + ): + """ + 在回复前进行思考,生成内心想法并收集工具调用结果 + + 参数: + observations: 观察信息 + + 返回: + 如果return_prompt为False: + tuple: (current_mind, past_mind) 当前想法和过去的想法列表 + 如果return_prompt为True: + tuple: (current_mind, past_mind, prompt) 当前想法、过去的想法列表和使用的prompt + """ + + if observations is None: + observations = [] + for observation in observations: + if isinstance(observation, ChattingObservation): + # 获取聊天元信息 + is_group_chat = observation.is_group_chat + chat_target_info = observation.chat_target_info + chat_target_name = "对方" # 私聊默认名称 + if not is_group_chat and chat_target_info: + # 优先使用person_name,其次user_nickname,最后回退到默认值 + chat_target_name = ( + chat_target_info.get("person_name") or chat_target_info.get("user_nickname") or chat_target_name + ) + # 获取聊天内容 + chat_observe_info = observation.get_observe_info() + person_list = observation.person_list + + nickname_str = "" + for nicknames in global_config.bot.alias_names: + nickname_str += f"{nicknames}," + name_block = f"你的名字是{global_config.bot.nickname},你的昵称有{nickname_str},有人也会用这些昵称称呼你。" + + if is_group_chat: + relation_prompt_init = "你对群聊里的人的印象是:\n" + else: + relation_prompt_init = "你对对方的印象是:\n" + + relation_prompt = "" + for person in person_list: + relation_prompt += f"{await relationship_manager.build_relationship_info(person, is_id=True)}\n" + + if relation_prompt: + relation_prompt = relation_prompt_init + relation_prompt + else: + relation_prompt = relation_prompt_init + "没有特别在意的人\n" + + prompt = (await 
global_prompt_manager.get_prompt_async("relationship_prompt")).format( + name_block=name_block, + relation_prompt=relation_prompt, + time_now=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), + chat_observe_info=chat_observe_info, + ) + # The above code is a Python script that is attempting to print the variable `prompt`. + # However, the code is not complete as the content of the `prompt` variable is missing. + # print(prompt) + + content = "" + try: + content, _ = await self.llm_model.generate_response_async(prompt=prompt) + if not content: + logger.warning(f"{self.log_prefix} LLM返回空结果,关系识别失败。") + except Exception as e: + # 处理总体异常 + logger.error(f"{self.log_prefix} 执行LLM请求或处理响应时出错: {e}") + logger.error(traceback.format_exc()) + content = "关系识别过程中出现错误" + + if content == "None": + content = "" + # 记录初步思考结果 + logger.info(f"{self.log_prefix} 关系识别prompt: \n{prompt}\n") + logger.info(f"{self.log_prefix} 关系识别: {content}") + + return content + +init_prompt() + +# ==== 只复制最小依赖的relationship_manager ==== +class SimpleRelationshipManager: + async def build_relationship_info(self, person, is_id: bool = False) -> str: + if is_id: + person_id = person + else: + person_id = person_info_manager.get_person_id(person[0], person[1]) + + person_name = await person_info_manager.get_value(person_id, "person_name") + if not person_name or person_name == "none": + return "" + impression = await person_info_manager.get_value(person_id, "impression") + interaction = await person_info_manager.get_value(person_id, "interaction") + points = await person_info_manager.get_value(person_id, "points") or [] + + if isinstance(points, str): + try: + import ast + points = ast.literal_eval(points) + except (SyntaxError, ValueError): + points = [] + + import random + random_points = random.sample(points, min(3, len(points))) if points else [] + + nickname_str = await person_info_manager.get_value(person_id, "nickname") + platform = await person_info_manager.get_value(person_id, "platform") + relation_prompt = f"'{person_name}' ,ta在{platform}上的昵称是{nickname_str}。" + + if impression: + relation_prompt += f"你对ta的印象是:{impression}。" + if interaction: + relation_prompt += f"你与ta的关系是:{interaction}。" + if random_points: + for point in random_points: + point_str = f"时间:{point[2]}。内容:{point[0]}" + relation_prompt += f"你记得{person_name}最近的点是:{point_str}。" + return relation_prompt + +# 用于替换原有的relationship_manager +relationship_manager = SimpleRelationshipManager() + +def get_raw_msg_by_timestamp_random( + timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest" +) -> List[Dict[str, Any]]: + """先在范围时间戳内随机选择一条消息,取得消息的chat_id,然后根据chat_id获取该聊天在指定时间戳范围内的消息""" + # 获取所有消息,只取chat_id字段 + filter_query = {"time": {"$gt": timestamp_start, "$lt": timestamp_end}} + all_msgs = find_messages(message_filter=filter_query) + if not all_msgs: + return [] + # 随机选一条 + msg = random.choice(all_msgs) + chat_id = msg["chat_id"] + timestamp_start = msg["time"] + # 用 chat_id 获取该聊天在指定时间戳范围内的消息 + filter_query = {"chat_id": chat_id, "time": {"$gt": timestamp_start, "$lt": timestamp_end}} + sort_order = [("time", 1)] if limit == 0 else None + return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode="earliest") + +def _build_readable_messages_internal( + messages: List[Dict[str, Any]], + replace_bot_name: bool = True, + merge_messages: bool = False, + timestamp_mode: str = "relative", + truncate: bool = False, +) -> Tuple[str, List[Tuple[float, str, str]]]: + """内部辅助函数,构建可读消息字符串和原始消息详情列表""" + if not 
messages: + return "", [] + + message_details_raw: List[Tuple[float, str, str]] = [] + + # 1 & 2: 获取发送者信息并提取消息组件 + for msg in messages: + # 检查是否是动作记录 + if msg.get("is_action_record", False): + is_action = True + timestamp = msg.get("time") + content = msg.get("display_message", "") + message_details_raw.append((timestamp, global_config.bot.nickname, content, is_action)) + continue + + # 检查并修复缺少的user_info字段 + if "user_info" not in msg: + msg["user_info"] = { + "platform": msg.get("user_platform", ""), + "user_id": msg.get("user_id", ""), + "user_nickname": msg.get("user_nickname", ""), + "user_cardname": msg.get("user_cardname", ""), + } + + user_info = msg.get("user_info", {}) + platform = user_info.get("platform") + user_id = user_info.get("user_id") + user_nickname = user_info.get("user_nickname") + user_cardname = user_info.get("user_cardname") + timestamp = msg.get("time") + + if msg.get("display_message"): + content = msg.get("display_message") + else: + content = msg.get("processed_plain_text", "") + + if "ᶠ" in content: + content = content.replace("ᶠ", "") + if "ⁿ" in content: + content = content.replace("ⁿ", "") + + if not all([platform, user_id, timestamp is not None]): + continue + + person_id = person_info_manager.get_person_id(platform, user_id) + if replace_bot_name and user_id == global_config.bot.qq_account: + person_name = f"{global_config.bot.nickname}(你)" + else: + person_name = person_info_manager.get_value_sync(person_id, "person_name") + + if not person_name: + if user_cardname: + person_name = f"昵称:{user_cardname}" + elif user_nickname: + person_name = f"{user_nickname}" + else: + person_name = "某人" + + if content != "": + message_details_raw.append((timestamp, person_name, content, False)) + + if not message_details_raw: + return "", [] + + message_details_raw.sort(key=lambda x: x[0]) + + # 为每条消息添加一个标记,指示它是否是动作记录 + message_details_with_flags = [] + for timestamp, name, content, is_action in message_details_raw: + message_details_with_flags.append((timestamp, name, content, is_action)) + + # 应用截断逻辑 + message_details: List[Tuple[float, str, str, bool]] = [] + n_messages = len(message_details_with_flags) + if truncate and n_messages > 0: + for i, (timestamp, name, content, is_action) in enumerate(message_details_with_flags): + if is_action: + message_details.append((timestamp, name, content, is_action)) + continue + + percentile = i / n_messages + original_len = len(content) + limit = -1 + + if percentile < 0.2: + limit = 50 + replace_content = "......(记不清了)" + elif percentile < 0.5: + limit = 100 + replace_content = "......(有点记不清了)" + elif percentile < 0.7: + limit = 200 + replace_content = "......(内容太长了)" + elif percentile < 1.0: + limit = 300 + replace_content = "......(太长了)" + + truncated_content = content + if 0 < limit < original_len: + truncated_content = f"{content[:limit]}{replace_content}" + + message_details.append((timestamp, name, truncated_content, is_action)) + else: + message_details = message_details_with_flags + + # 合并连续消息 + merged_messages = [] + if merge_messages and message_details: + current_merge = { + "name": message_details[0][1], + "start_time": message_details[0][0], + "end_time": message_details[0][0], + "content": [message_details[0][2]], + "is_action": message_details[0][3] + } + + for i in range(1, len(message_details)): + timestamp, name, content, is_action = message_details[i] + + if is_action or current_merge["is_action"]: + merged_messages.append(current_merge) + current_merge = { + "name": name, + "start_time": timestamp, + 
"end_time": timestamp, + "content": [content], + "is_action": is_action + } + continue + + if name == current_merge["name"] and (timestamp - current_merge["end_time"] <= 60): + current_merge["content"].append(content) + current_merge["end_time"] = timestamp + else: + merged_messages.append(current_merge) + current_merge = { + "name": name, + "start_time": timestamp, + "end_time": timestamp, + "content": [content], + "is_action": is_action + } + merged_messages.append(current_merge) + elif message_details: + for timestamp, name, content, is_action in message_details: + merged_messages.append( + { + "name": name, + "start_time": timestamp, + "end_time": timestamp, + "content": [content], + "is_action": is_action + } + ) + + # 格式化为字符串 + output_lines = [] + for merged in merged_messages: + readable_time = translate_timestamp_to_human_readable(merged["start_time"], mode=timestamp_mode) + + if merged["is_action"]: + output_lines.append(f"{readable_time}, {merged['content'][0]}") + else: + header = f"{readable_time}, {merged['name']} :" + output_lines.append(header) + for line in merged["content"]: + stripped_line = line.strip() + if stripped_line: + if stripped_line.endswith("。"): + stripped_line = stripped_line[:-1] + if not stripped_line.endswith("(内容太长)"): + output_lines.append(f"{stripped_line}") + else: + output_lines.append(stripped_line) + output_lines.append("\n") + + formatted_string = "".join(output_lines).strip() + return formatted_string, [(t, n, c) for t, n, c, is_action in message_details if not is_action] + +def build_readable_messages( + messages: List[Dict[str, Any]], + replace_bot_name: bool = True, + merge_messages: bool = False, + timestamp_mode: str = "relative", + read_mark: float = 0.0, + truncate: bool = False, + show_actions: bool = False, +) -> str: + """将消息列表转换为可读的文本格式""" + copy_messages = [msg.copy() for msg in messages] + + if show_actions and copy_messages: + min_time = min(msg.get("time", 0) for msg in copy_messages) + max_time = max(msg.get("time", 0) for msg in copy_messages) + chat_id = copy_messages[0].get("chat_id") if copy_messages else None + + actions = ActionRecords.select().where( + (ActionRecords.time >= min_time) & + (ActionRecords.time <= max_time) & + (ActionRecords.chat_id == chat_id) + ).order_by(ActionRecords.time) + + for action in actions: + if action.action_build_into_prompt: + action_msg = { + "time": action.time, + "user_id": global_config.bot.qq_account, + "user_nickname": global_config.bot.nickname, + "user_cardname": "", + "processed_plain_text": f"{action.action_prompt_display}", + "display_message": f"{action.action_prompt_display}", + "chat_info_platform": action.chat_info_platform, + "is_action_record": True, + "action_name": action.action_name, + } + copy_messages.append(action_msg) + + copy_messages.sort(key=lambda x: x.get("time", 0)) + + if read_mark <= 0: + formatted_string, _ = _build_readable_messages_internal( + copy_messages, replace_bot_name, merge_messages, timestamp_mode, truncate + ) + return formatted_string + else: + messages_before_mark = [msg for msg in copy_messages if msg.get("time", 0) <= read_mark] + messages_after_mark = [msg for msg in copy_messages if msg.get("time", 0) > read_mark] + + formatted_before, _ = _build_readable_messages_internal( + messages_before_mark, replace_bot_name, merge_messages, timestamp_mode, truncate + ) + formatted_after, _ = _build_readable_messages_internal( + messages_after_mark, + replace_bot_name, + merge_messages, + timestamp_mode, + ) + + read_mark_line = "\n--- 以上消息是你已经看过---\n--- 
请关注以下未读的新消息---\n" + + if formatted_before and formatted_after: + return f"{formatted_before}{read_mark_line}{formatted_after}" + elif formatted_before: + return f"{formatted_before}{read_mark_line}" + elif formatted_after: + return f"{read_mark_line}{formatted_after}" + else: + return read_mark_line.strip() + +async def test_relationship_processor(): + """测试关系处理器的功能""" + + # 测试10次 + for i in range(10): + print(f"\n=== 测试 {i+1} ===") + + # 获取随机消息 + current_time = time.time() + start_time = current_time - 864000 # 10天前 + messages = get_raw_msg_by_timestamp_random(start_time, current_time, limit=25) + + if not messages: + print("没有找到消息,跳过此次测试") + continue + + chat_id = messages[0]["chat_id"] + + # 构建可读消息 + chat_observe_info = build_readable_messages( + messages, + replace_bot_name=True, + timestamp_mode="normal_no_YMD", + truncate=True, + show_actions=True, + ) + # print(chat_observe_info) + # 创建观察对象 + processor = RelationshipProcessor(chat_id) + observation = ChattingObservation(chat_id) + observation.talking_message_str = chat_observe_info + observation.talking_message = messages # 设置消息列表 + observation.person_list = await get_person_id_list(messages) # 使用get_person_id_list获取person_list + + # 处理关系 + result = await processor.process_info([observation]) + + if result: + print("\n关系识别结果:") + print(result[0].get_processed_info()) + else: + print("关系识别失败") + + # 等待一下,避免请求过快 + await asyncio.sleep(1) + +if __name__ == "__main__": + asyncio.run(test_relationship_processor()) \ No newline at end of file diff --git a/tests/test_sentence_similarity.py b/tests/test_sentence_similarity.py deleted file mode 100644 index c5e46751f..000000000 --- a/tests/test_sentence_similarity.py +++ /dev/null @@ -1,156 +0,0 @@ -import time -import unittest -import jieba -from difflib import SequenceMatcher -from sklearn.feature_extraction.text import TfidfVectorizer -from sklearn.metrics.pairwise import cosine_similarity - -def tfidf_similarity(s1, s2): - """ - 使用 TF-IDF 和余弦相似度计算两个句子的相似性。 - """ - # 1. 使用 jieba 进行分词 - s1_words = " ".join(jieba.cut(s1)) - s2_words = " ".join(jieba.cut(s2)) - - # 2. 将两句话放入一个列表中 - corpus = [s1_words, s2_words] - - # 3. 创建 TF-IDF 向量化器并进行计算 - try: - vectorizer = TfidfVectorizer() - tfidf_matrix = vectorizer.fit_transform(corpus) - except ValueError: - # 如果句子完全由停用词组成,或者为空,可能会报错 - return 0.0 - - # 4. 
计算余弦相似度 - similarity_matrix = cosine_similarity(tfidf_matrix) - - # 返回 s1 和 s2 的相似度 - return similarity_matrix[0, 1] - -def sequence_similarity(s1, s2): - """ - 使用 SequenceMatcher 计算两个句子的相似性。 - """ - return SequenceMatcher(None, s1, s2).ratio() - -class TestSentenceSimilarity(unittest.TestCase): - def test_similarity_comparison(self): - """比较不同相似度计算方法的结果""" - test_cases = [ - { - "sentence1": "今天天气怎么样", - "sentence2": "今天气候如何", - "expected_similar": True - }, - { - "sentence1": "今天天气怎么样", - "sentence2": "我今天要去吃麦当劳", - "expected_similar": False - }, - { - "sentence1": "我今天要去吃麦当劳", - "sentence2": "肯德基和麦当劳哪家好吃", - "expected_similar": True - }, - { - "sentence1": "Vindemiatrix提到昨天三个无赖杀穿交界地", - "sentence2": "Vindemiatrix昨天用三个无赖角色杀穿了游戏中的交界地", - "expected_similar": True - }, - { - "sentence1": "tc_魔法士解释了之前templateinfo的with用法和现在的单独逻辑发送的区别", - "sentence2": "tc_魔法士解释了templateinfo的用法,包括它是一个字典,key是prompt的名字,value是prompt的内容,格式是只支持大括号的fstring", - "expected_similar": False - }, - { - "sentence1": "YXH_XianYu分享了一张舰娘街机游戏的图片,并提到'玩舰娘街机的董不懂'", - "sentence2": "YXH_XianYu对街机游戏表现出兴趣,并分享了玩舰娘街机的经历", - "expected_similar": True - }, - { - "sentence1": "YXH_XianYu在考虑入坑明日方舟,犹豫是否要从零开荒或使用初始号", - "sentence2": "YXH_XianYu考虑入坑明日方舟,倾向于从零开荒或初始号开荒", - "expected_similar": True - }, - { - "sentence1": "YXH_XianYu提到秋叶原好多人在玩maimai", - "sentence2": "YXH_XianYu对学园偶像的付费石头机制表示惊讶", - "expected_similar": False - } - ] - - print("\n相似度计算方法比较:") - for i, case in enumerate(test_cases, 1): - print(f"\n测试用例 {i}:") - print(f"句子1: {case['sentence1']}") - print(f"句子2: {case['sentence2']}") - - # TF-IDF 相似度 - start_time = time.time() - tfidf_sim = tfidf_similarity(case['sentence1'], case['sentence2']) - tfidf_time = time.time() - start_time - - # SequenceMatcher 相似度 - start_time = time.time() - seq_sim = sequence_similarity(case['sentence1'], case['sentence2']) - seq_time = time.time() - start_time - - print(f"TF-IDF相似度: {tfidf_sim:.4f} (耗时: {tfidf_time:.4f}秒)") - print(f"SequenceMatcher相似度: {seq_sim:.4f} (耗时: {seq_time:.4f}秒)") - - def test_batch_processing(self): - """测试批量处理性能""" - sentences = [ - "人工智能正在改变世界", - "AI技术发展迅速", - "机器学习是人工智能的一个分支", - "深度学习在图像识别领域取得了突破", - "自然语言处理技术越来越成熟" - ] - - print("\n批量处理测试:") - - # TF-IDF 批量处理 - start_time = time.time() - tfidf_matrix = [] - for i in range(len(sentences)): - row = [] - for j in range(len(sentences)): - similarity = tfidf_similarity(sentences[i], sentences[j]) - row.append(similarity) - tfidf_matrix.append(row) - tfidf_time = time.time() - start_time - - # SequenceMatcher 批量处理 - start_time = time.time() - seq_matrix = [] - for i in range(len(sentences)): - row = [] - for j in range(len(sentences)): - similarity = sequence_similarity(sentences[i], sentences[j]) - row.append(similarity) - seq_matrix.append(row) - seq_time = time.time() - start_time - - print(f"TF-IDF批量处理 {len(sentences)} 个句子耗时: {tfidf_time:.4f}秒") - print(f"SequenceMatcher批量处理 {len(sentences)} 个句子耗时: {seq_time:.4f}秒") - - # 打印TF-IDF相似度矩阵 - print("\nTF-IDF相似度矩阵:") - for row in tfidf_matrix: - for similarity in row: - print(f"{similarity:.4f}", end="\t") - print() - - # 打印SequenceMatcher相似度矩阵 - print("\nSequenceMatcher相似度矩阵:") - for row in seq_matrix: - for similarity in row: - print(f"{similarity:.4f}", end="\t") - print() - -if __name__ == '__main__': - unittest.main(verbosity=2) \ No newline at end of file
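For reference, the relationship identification flow changed above relies on json_repair to tolerate slightly malformed LLM output before json.loads, and then drops any keys that are not in the known person_name_list. Below is a minimal, self-contained sketch of that repair-parse-filter pattern; the helper name extract_person_info and the sample strings are illustrative assumptions, not code from this diff.

import json
from json_repair import repair_json

def extract_person_info(llm_output: str, known_names: list) -> str:
    """Repair and parse the LLM's JSON answer, keeping only names the processor knows about."""
    try:
        data = json.loads(repair_json(llm_output))
    except (json.JSONDecodeError, TypeError):
        return ""
    if not isinstance(data, dict):
        return ""
    lines = []
    for person_name, person_info in data.items():
        if person_name not in known_names:
            continue  # skip names that are not in the caller's person_name_list
        lines.append(f"你对 {person_name} 的了解:{person_info}")
    return "\n".join(lines)

# Example: a slightly malformed answer (trailing comma) still parses after repair_json.
raw = '{"小明": "喜欢聊游戏,最近在玩新出的DLC", "小红": "上周考试周比较忙",}'
print(extract_person_info(raw, ["小明"]))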