diff --git a/scripts/text_length_analysis.py b/scripts/text_length_analysis.py
new file mode 100644
index 000000000..2ca596e2f
--- /dev/null
+++ b/scripts/text_length_analysis.py
@@ -0,0 +1,394 @@
+import time
+import sys
+import os
+import re
+from typing import Dict, List, Tuple, Optional
+from datetime import datetime
+
+# Add project root to Python path
+project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+sys.path.insert(0, project_root)
+
+from src.common.database.database_model import Messages, ChatStreams  # noqa
+
+
+def contains_emoji_or_image_tags(text: str) -> bool:
+    """Check if text contains [表情包xxxxx] or [图片xxxxx] tags"""
+    if not text:
+        return False
+
+    # Check for [表情包] or [图片] markers
+    emoji_pattern = r'\[表情包[^\]]*\]'
+    image_pattern = r'\[图片[^\]]*\]'
+
+    return bool(re.search(emoji_pattern, text) or re.search(image_pattern, text))
+
+
+def clean_reply_text(text: str) -> str:
+    """Remove reply references like [回复 xxxx...] from text"""
+    if not text:
+        return text
+
+    # Match [回复 xxxx...] fragments; the negated character class stops at the first ]
+    cleaned_text = re.sub(r'\[回复[^\]]*\]', '', text)
+
+    # Trim the whitespace left behind
+    cleaned_text = cleaned_text.strip()
+
+    return cleaned_text
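
For illustration (these lines are not part of the diff), the two filters above behave as follows:

    sample = "[回复 张三:你好] 今天天气不错"
    assert contains_emoji_or_image_tags("[表情包:doge]")   # tag detected
    assert not contains_emoji_or_image_tags(sample)        # a reply marker is not a tag
    assert clean_reply_text(sample) == "今天天气不错"       # reply reference stripped
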
+
+
+def get_chat_name(chat_id: str) -> str:
+    """Get chat name from chat_id by querying ChatStreams table directly"""
+    try:
+        chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
+        if chat_stream is None:
+            return f"未知聊天 ({chat_id})"
+
+        if chat_stream.group_name:
+            return f"{chat_stream.group_name} ({chat_id})"
+        elif chat_stream.user_nickname:
+            return f"{chat_stream.user_nickname}的私聊 ({chat_id})"
+        else:
+            return f"未知聊天 ({chat_id})"
+    except Exception:
+        return f"查询失败 ({chat_id})"
+
+
+def format_timestamp(timestamp: float) -> str:
+    """Format timestamp to readable date string"""
+    try:
+        return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
+    except (ValueError, OSError):
+        return "未知时间"
+
+
+def calculate_text_length_distribution(messages) -> Dict[str, int]:
+    """Calculate distribution of processed_plain_text length"""
+    distribution = {
+        '0': 0,         # empty text
+        '1-5': 0,       # tiny
+        '6-10': 0,      # very short
+        '11-20': 0,     # short
+        '21-30': 0,     # fairly short
+        '31-50': 0,     # medium-short
+        '51-70': 0,     # medium
+        '71-100': 0,    # fairly long
+        '101-150': 0,   # long
+        '151-200': 0,   # very long
+        '201-300': 0,   # extra long
+        '301-500': 0,   # extremely long
+        '501-1000': 0,  # huge
+        '1000+': 0      # enormous
+    }
+
+    for msg in messages:
+        if msg.processed_plain_text is None:
+            continue
+
+        # Skip messages that carry emoji/image tags
+        if contains_emoji_or_image_tags(msg.processed_plain_text):
+            continue
+
+        # Strip reply references before measuring
+        cleaned_text = clean_reply_text(msg.processed_plain_text)
+        length = len(cleaned_text)
+
+        if length == 0:
+            distribution['0'] += 1
+        elif length <= 5:
+            distribution['1-5'] += 1
+        elif length <= 10:
+            distribution['6-10'] += 1
+        elif length <= 20:
+            distribution['11-20'] += 1
+        elif length <= 30:
+            distribution['21-30'] += 1
+        elif length <= 50:
+            distribution['31-50'] += 1
+        elif length <= 70:
+            distribution['51-70'] += 1
+        elif length <= 100:
+            distribution['71-100'] += 1
+        elif length <= 150:
+            distribution['101-150'] += 1
+        elif length <= 200:
+            distribution['151-200'] += 1
+        elif length <= 300:
+            distribution['201-300'] += 1
+        elif length <= 500:
+            distribution['301-500'] += 1
+        elif length <= 1000:
+            distribution['501-1000'] += 1
+        else:
+            distribution['1000+'] += 1
+
+    return distribution
+
+
+def get_text_length_stats(messages) -> Dict[str, float]:
+    """Calculate basic statistics for processed_plain_text length"""
+    lengths = []
+    null_count = 0
+    excluded_count = 0  # messages excluded by the tag filter
+
+    for msg in messages:
+        if msg.processed_plain_text is None:
+            null_count += 1
+        elif contains_emoji_or_image_tags(msg.processed_plain_text):
+            # Skip messages that carry emoji/image tags
+            excluded_count += 1
+        else:
+            # Strip reply references before measuring
+            cleaned_text = clean_reply_text(msg.processed_plain_text)
+            lengths.append(len(cleaned_text))
+
+    if not lengths:
+        return {
+            'count': 0,
+            'null_count': null_count,
+            'excluded_count': excluded_count,
+            'min': 0,
+            'max': 0,
+            'avg': 0,
+            'median': 0
+        }
+
+    lengths.sort()
+    count = len(lengths)
+
+    return {
+        'count': count,
+        'null_count': null_count,
+        'excluded_count': excluded_count,
+        'min': min(lengths),
+        'max': max(lengths),
+        'avg': sum(lengths) / count,
+        'median': lengths[count // 2] if count % 2 == 1 else (lengths[count // 2 - 1] + lengths[count // 2]) / 2
+    }
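
And a sketch of what get_text_length_stats returns for a toy message set (hypothetical stand-in objects, not part of the diff):

    from types import SimpleNamespace

    msgs = [SimpleNamespace(processed_plain_text=t)
            for t in ["你好", "[图片abc]", None, "今天天气不错"]]
    stats = get_text_length_stats(msgs)
    # "你好" (2 chars) and "今天天气不错" (6 chars) are counted; the image tag is
    # excluded and None is null:
    # {'count': 2, 'null_count': 1, 'excluded_count': 1,
    #  'min': 2, 'max': 6, 'avg': 4.0, 'median': 4.0}
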
+
+
+def get_available_chats() -> List[Tuple[str, str, int]]:
+    """Get all available chats with message counts"""
+    try:
+        # Count messages per chat_id, excluding special message types
+        chat_counts = {}
+        for msg in Messages.select(Messages.chat_id).distinct():
+            chat_id = msg.chat_id
+            count = Messages.select().where(
+                (Messages.chat_id == chat_id) &
+                (Messages.is_emoji != 1) &
+                (Messages.is_picid != 1) &
+                (Messages.is_command != 1)
+            ).count()
+            if count > 0:
+                chat_counts[chat_id] = count
+
+        # Resolve chat names
+        result = []
+        for chat_id, count in chat_counts.items():
+            chat_name = get_chat_name(chat_id)
+            result.append((chat_id, chat_name, count))
+
+        # Sort by message count, descending
+        result.sort(key=lambda x: x[2], reverse=True)
+        return result
+    except Exception as e:
+        print(f"获取聊天列表失败: {e}")
+        return []
+
+
+def get_time_range_input() -> Tuple[Optional[float], Optional[float]]:
+    """Get time range input from user"""
+    print("\n时间范围选择:")
+    print("1. 最近1天")
+    print("2. 最近3天")
+    print("3. 最近7天")
+    print("4. 最近30天")
+    print("5. 自定义时间范围")
+    print("6. 不限制时间")
+
+    choice = input("请选择时间范围 (1-6): ").strip()
+
+    now = time.time()
+
+    if choice == "1":
+        return now - 24 * 3600, now
+    elif choice == "2":
+        return now - 3 * 24 * 3600, now
+    elif choice == "3":
+        return now - 7 * 24 * 3600, now
+    elif choice == "4":
+        return now - 30 * 24 * 3600, now
+    elif choice == "5":
+        print("请输入开始时间 (格式: YYYY-MM-DD HH:MM:SS):")
+        start_str = input().strip()
+        print("请输入结束时间 (格式: YYYY-MM-DD HH:MM:SS):")
+        end_str = input().strip()
+
+        try:
+            start_time = datetime.strptime(start_str, "%Y-%m-%d %H:%M:%S").timestamp()
+            end_time = datetime.strptime(end_str, "%Y-%m-%d %H:%M:%S").timestamp()
+            return start_time, end_time
+        except ValueError:
+            print("时间格式错误,将不限制时间范围")
+            return None, None
+    else:
+        return None, None
+
+
+def get_top_longest_messages(messages, top_n: int = 10) -> List[Tuple[str, int, str, str]]:
+    """Get top N longest messages"""
+    message_lengths = []
+
+    for msg in messages:
+        if msg.processed_plain_text is not None:
+            # Skip messages that carry emoji/image tags
+            if contains_emoji_or_image_tags(msg.processed_plain_text):
+                continue
+
+            # Strip reply references before measuring
+            cleaned_text = clean_reply_text(msg.processed_plain_text)
+            length = len(cleaned_text)
+            chat_name = get_chat_name(msg.chat_id)
+            time_str = format_timestamp(msg.time)
+            # Keep the first 100 characters as a preview
+            preview = cleaned_text[:100] + "..." if len(cleaned_text) > 100 else cleaned_text
+            message_lengths.append((chat_name, length, time_str, preview))
+
+    # Sort by length and keep the top N
+    message_lengths.sort(key=lambda x: x[1], reverse=True)
+    return message_lengths[:top_n]
+
+
+def analyze_text_lengths(chat_id: Optional[str] = None, start_time: Optional[float] = None, end_time: Optional[float] = None) -> None:
+    """Analyze processed_plain_text lengths with optional filters"""
+
+    # Build the query, excluding special message types
+    query = Messages.select().where(
+        (Messages.is_emoji != 1) &
+        (Messages.is_picid != 1) &
+        (Messages.is_command != 1)
+    )
+
+    if chat_id:
+        query = query.where(Messages.chat_id == chat_id)
+
+    if start_time:
+        query = query.where(Messages.time >= start_time)
+
+    if end_time:
+        query = query.where(Messages.time <= end_time)
+
+    messages = list(query)
+
+    if not messages:
+        print("没有找到符合条件的消息")
+        return
+
+    # Compute statistics
+    distribution = calculate_text_length_distribution(messages)
+    stats = get_text_length_stats(messages)
+    top_longest = get_top_longest_messages(messages, 10)
+
+    # Print results
+    print("\n=== Processed Plain Text 长度分析结果 ===")
+    print("(已排除表情、图片ID、命令类型消息,已排除[表情包]和[图片]标记消息,已清理回复引用)")
+    if chat_id:
+        print(f"聊天: {get_chat_name(chat_id)}")
+    else:
+        print("聊天: 全部聊天")
+
+    if start_time and end_time:
+        print(f"时间范围: {format_timestamp(start_time)} 到 {format_timestamp(end_time)}")
+    elif start_time:
+        print(f"时间范围: {format_timestamp(start_time)} 之后")
+    elif end_time:
+        print(f"时间范围: {format_timestamp(end_time)} 之前")
+    else:
+        print("时间范围: 不限制")
+
+    print("\n基本统计:")
+    print(f"总消息数量: {len(messages)}")
+    print(f"有文本消息数量: {stats['count']}")
+    print(f"空文本消息数量: {stats['null_count']}")
+    print(f"被排除的消息数量: {stats['excluded_count']}")
+    if stats['count'] > 0:
+        print(f"最短长度: {stats['min']} 字符")
+        print(f"最长长度: {stats['max']} 字符")
+        print(f"平均长度: {stats['avg']:.2f} 字符")
+        print(f"中位数长度: {stats['median']:.2f} 字符")
+
+    print("\n文本长度分布:")
+    total = stats['count']
+    if total > 0:
+        for range_name, count in distribution.items():
+            if count > 0:
+                percentage = count / total * 100
+                print(f"{range_name} 字符: {count} ({percentage:.2f}%)")
+
+    # Show the longest messages
+    if top_longest:
+        print(f"\n最长的 {len(top_longest)} 条消息:")
+        for i, (chat_name, length, time_str, preview) in enumerate(top_longest, 1):
+            print(f"{i}. [{chat_name}] {time_str}")
+            print(f"   长度: {length} 字符")
+            print(f"   预览: {preview}")
+            print()
+
+
+def interactive_menu() -> None:
+    """Interactive menu for text length analysis"""
+
+    while True:
+        print("\n" + "=" * 50)
+        print("Processed Plain Text 长度分析工具")
+        print("=" * 50)
+        print("1. 分析全部聊天")
+        print("2. 选择特定聊天分析")
+        print("q. 退出")
+
+        choice = input("\n请选择分析模式 (1-2, q): ").strip()
+
+        if choice.lower() == 'q':
+            print("再见!")
+            break
+
+        chat_id = None
+
+        if choice == "2":
+            # Show the available chats
+            chats = get_available_chats()
+            if not chats:
+                print("没有找到聊天数据")
+                continue
+
+            print(f"\n可用的聊天 (共{len(chats)}个):")
+            for i, (_cid, name, count) in enumerate(chats, 1):
+                print(f"{i}. {name} ({count}条消息)")
+
+            try:
+                chat_choice = int(input(f"\n请选择聊天 (1-{len(chats)}): ").strip())
+                if 1 <= chat_choice <= len(chats):
+                    chat_id = chats[chat_choice - 1][0]
+                else:
+                    print("无效选择")
+                    continue
+            except ValueError:
+                print("请输入有效数字")
+                continue
+
+        elif choice != "1":
+            print("无效选择")
+            continue
+
+        # Ask for the time range
+        start_time, end_time = get_time_range_input()
+
+        # Run the analysis
+        analyze_text_lengths(chat_id, start_time, end_time)
+
+        input("\n按回车键继续...")
+
+
+if __name__ == "__main__":
+    interactive_menu()
\ No newline at end of file
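
The interactive menu is only a front end; analyze_text_lengths takes plain arguments, so it can also be driven programmatically (a sketch, assuming the repository root is importable):

    import time
    from scripts.text_length_analysis import analyze_text_lengths

    # Last 7 days, across all chats
    analyze_text_lengths(chat_id=None, start_time=time.time() - 7 * 24 * 3600, end_time=time.time())
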
diff --git a/src/chat/heart_flow/heartflow_message_processor.py b/src/chat/heart_flow/heartflow_message_processor.py
index 32a365688..57b52ae69 100644
--- a/src/chat/heart_flow/heartflow_message_processor.py
+++ b/src/chat/heart_flow/heartflow_message_processor.py
@@ -61,11 +61,35 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]:
     logger.debug(f"记忆激活率: {interested_rate:.2f}")
 
     text_len = len(message.processed_plain_text)
-    # Interest grows with text length at a diminishing rate, min 0.01, max 0.05
-    # implemented with a logarithmic curve
-
-    base_interest = 0.01 + (0.05 - 0.01) * (math.log10(text_len + 1) / math.log10(1000 + 1))
-    base_interest = min(max(base_interest, 0.01), 0.05)
+    # Scale interest with text length via a piecewise function tuned to the
+    # measured length distribution:
+    # 0-5 chars (26.57%), 6-10 (27.18%), 11-20 (22.76%), 21-30 (10.33%), 31+ (13.86%)
+
+    if text_len == 0:
+        base_interest = 0.01  # empty message: minimum interest
+    elif text_len <= 5:
+        # 1-5 chars: linear 0.01 -> 0.03
+        base_interest = 0.01 + (text_len - 1) * (0.03 - 0.01) / 4
+    elif text_len <= 10:
+        # 6-10 chars: linear 0.03 -> 0.06
+        base_interest = 0.03 + (text_len - 5) * (0.06 - 0.03) / 5
+    elif text_len <= 20:
+        # 11-20 chars: linear 0.06 -> 0.12
+        base_interest = 0.06 + (text_len - 10) * (0.12 - 0.06) / 10
+    elif text_len <= 30:
+        # 21-30 chars: linear 0.12 -> 0.18
+        base_interest = 0.12 + (text_len - 20) * (0.18 - 0.12) / 10
+    elif text_len <= 50:
+        # 31-50 chars: linear 0.18 -> 0.22
+        base_interest = 0.18 + (text_len - 30) * (0.22 - 0.18) / 20
+    elif text_len <= 100:
+        # 51-100 chars: linear 0.22 -> 0.26
+        base_interest = 0.22 + (text_len - 50) * (0.26 - 0.22) / 50
+    else:
+        # 100+ chars: logarithmic 0.26 -> 0.3, growth rate tapering off
+        base_interest = 0.26 + (0.3 - 0.26) * (math.log10(text_len - 99) / math.log10(901))  # 1000 - 99 = 901
+
+    # Clamp to the valid range
+    base_interest = min(max(base_interest, 0.01), 0.3)
 
     interested_rate += base_interest
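
Spot-checking the branch boundaries above (illustrative; the curve is monotone, and the log branch picks up smoothly where the linear segments leave off):

    import math

    # Values the branches above produce at the segment edges:
    #   len:   1     5     10    20    30    50    100    1000
    #   val: 0.010 0.030 0.060 0.120 0.180 0.220 0.260  0.300
    # Handoff at 100 -> 101:
    print(0.26 + 0.04 * math.log10(101 - 99) / math.log10(901))  # ≈ 0.2641
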
diff --git a/src/chat/memory_system/Hippocampus.py b/src/chat/memory_system/Hippocampus.py
index ad0384160..c1cf61797 100644
--- a/src/chat/memory_system/Hippocampus.py
+++ b/src/chat/memory_system/Hippocampus.py
@@ -299,6 +299,63 @@ class Hippocampus:
         # Sort by similarity, descending
         memories.sort(key=lambda x: x[2], reverse=True)
         return memories
+
+    async def get_keywords_from_text(self, text: str, fast_retrieval: bool = False) -> list:
+        """Extract keywords from text.
+
+        Args:
+            text (str): input text
+            fast_retrieval (bool, optional): whether to use fast retrieval. Defaults to False.
+                If True, keywords are extracted with jieba segmentation: faster but less accurate.
+                If False, keywords are extracted with the LLM: slower but more accurate.
+        """
+        if not text:
+            return []
+
+        if fast_retrieval:
+            # Extract keywords with jieba segmentation
+            words = jieba.cut(text)
+            # Filter out stop words and single-character tokens
+            keywords = [word for word in words if len(word) > 1]
+            # Deduplicate
+            keywords = list(set(keywords))
+            logger.debug(f"提取关键词: {keywords}")
+
+        else:
+            # Extract keywords with the LLM; topic_num follows the measured text length distribution
+            text_length = len(text)
+            if text_length <= 5:
+                topic_num = 1  # 1-5 chars: 1 keyword (26.57% of texts)
+            elif text_length <= 10:
+                topic_num = 1  # 6-10 chars: 1 keyword (27.18% of texts)
+            elif text_length <= 20:
+                topic_num = 2  # 11-20 chars: 2 keywords (22.76% of texts)
+            elif text_length <= 30:
+                topic_num = 3  # 21-30 chars: 3 keywords (10.33% of texts)
+            elif text_length <= 50:
+                topic_num = 4  # 31-50 chars: 4 keywords (9.79% of texts)
+            else:
+                topic_num = 5  # 51+ chars: 5 keywords (remaining long texts)
+
+            # logger.info(f"提取关键词数量: {topic_num}")
+            topics_response, (reasoning_content, model_name) = await self.model_summary.generate_response_async(
+                self.find_topic_llm(text, topic_num)
+            )
+
+            # Pull the keywords out of the response
+            keywords = re.findall(r"<([^>]+)>", topics_response)
+            if not keywords:
+                keywords = []
+            else:
+                keywords = [
+                    keyword.strip()
+                    for keyword in ",".join(keywords).replace(",", ",").replace("、", ",").replace(" ", ",").split(",")
+                    if keyword.strip()
+                ]
+
+        return keywords
+
 
     async def get_memory_from_text(
         self,
@@ -325,39 +382,7 @@ class Hippocampus:
                 - memory_items: list, memory items under the topic
                 - similarity: float, similarity to the text
         """
-        if not text:
-            return []
-
-        if fast_retrieval:
-            # Extract keywords with jieba segmentation
-            words = jieba.cut(text)
-            # Filter out stop words and single-character tokens
-            keywords = [word for word in words if len(word) > 1]
-            # Deduplicate
-            keywords = list(set(keywords))
-            # Cap the number of keywords
-            logger.debug(f"提取关键词: {keywords}")
-
-        else:
-            # Extract keywords with the LLM
-            topic_num = min(5, max(1, int(len(text) * 0.1)))  # scale keyword count with text length
-            # logger.info(f"提取关键词数量: {topic_num}")
-            topics_response, (reasoning_content, model_name) = await self.model_summary.generate_response_async(
-                self.find_topic_llm(text, topic_num)
-            )
-
-            # Pull the keywords out of the response
-            keywords = re.findall(r"<([^>]+)>", topics_response)
-            if not keywords:
-                keywords = []
-            else:
-                keywords = [
-                    keyword.strip()
-                    for keyword in ",".join(keywords).replace(",", ",").replace("、", ",").replace(" ", ",").split(",")
-                    if keyword.strip()
-                ]
-
-        # logger.info(f"提取的关键词: {', '.join(keywords)}")
+        keywords = await self.get_keywords_from_text(text, fast_retrieval)
 
         # Filter out keywords that are absent from the memory graph
         valid_keywords = [keyword for keyword in keywords if keyword in self.memory_graph.G]
@@ -679,38 +704,7 @@ class Hippocampus:
         Returns:
             float: ratio of activated nodes to total nodes
         """
-        if not text:
-            return 0
-
-        if fast_retrieval:
-            # Extract keywords with jieba segmentation
-            words = jieba.cut(text)
-            # Filter out stop words and single-character tokens
-            keywords = [word for word in words if len(word) > 1]
-            # Deduplicate
-            keywords = list(set(keywords))
-            # Cap the number of keywords
-            keywords = keywords[:5]
-        else:
-            # Extract keywords with the LLM
-            topic_num = min(5, max(1, int(len(text) * 0.1)))  # scale keyword count with text length
-            # logger.info(f"提取关键词数量: {topic_num}")
-            topics_response, (reasoning_content, model_name) = await self.model_summary.generate_response_async(
-                self.find_topic_llm(text, topic_num)
-            )
-
-            # Pull the keywords out of the response
-            keywords = re.findall(r"<([^>]+)>", topics_response)
-            if not keywords:
-                keywords = []
-            else:
-                keywords = [
-                    keyword.strip()
-                    for keyword in ",".join(keywords).replace(",", ",").replace("、", ",").replace(" ", ",").split(",")
-                    if keyword.strip()
-                ]
-
-        # logger.info(f"提取的关键词: {', '.join(keywords)}")
+        keywords = await self.get_keywords_from_text(text, fast_retrieval)
 
         # Filter out keywords that are absent from the memory graph
         valid_keywords = [keyword for keyword in keywords if keyword in self.memory_graph.G]
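
For comparison, the old heuristic min(5, max(1, int(len(text) * 0.1))) requested a single keyword for anything under 20 characters; the new table asks for 2 in the 11-20 range and tops out at 5. One behavioral note: the unified helper's fast path no longer applies the keywords[:5] cap that the removed activation-path code had. A stand-alone spot check of the new mapping (illustration only):

    def topic_num_for(length: int) -> int:
        # Mirrors the thresholds in get_keywords_from_text above.
        for bound, num in [(10, 1), (20, 2), (30, 3), (50, 4)]:
            if length <= bound:
                return num
        return 5

    assert [topic_num_for(n) for n in (3, 8, 15, 25, 40, 80)] == [1, 1, 2, 3, 4, 5]
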
diff --git a/src/mais4u/mais4u_chat/s4u_msg_processor.py b/src/mais4u/mais4u_chat/s4u_msg_processor.py
index cbc7d3fac..c5ad9ca1f 100644
--- a/src/mais4u/mais4u_chat/s4u_msg_processor.py
+++ b/src/mais4u/mais4u_chat/s4u_msg_processor.py
@@ -47,11 +47,35 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]:
     logger.debug(f"记忆激活率: {interested_rate:.2f}")
 
     text_len = len(message.processed_plain_text)
-    # Interest grows with text length at a diminishing rate, min 0.01, max 0.05
-    # implemented with a logarithmic curve
-
-    base_interest = 0.01 + (0.05 - 0.01) * (math.log10(text_len + 1) / math.log10(1000 + 1))
-    base_interest = min(max(base_interest, 0.01), 0.05)
+    # Scale interest with text length via a piecewise function tuned to the
+    # measured length distribution:
+    # 0-5 chars (26.57%), 6-10 (27.18%), 11-20 (22.76%), 21-30 (10.33%), 31+ (13.86%)
+
+    if text_len == 0:
+        base_interest = 0.01  # empty message: minimum interest
+    elif text_len <= 5:
+        # 1-5 chars: linear 0.01 -> 0.03
+        base_interest = 0.01 + (text_len - 1) * (0.03 - 0.01) / 4
+    elif text_len <= 10:
+        # 6-10 chars: linear 0.03 -> 0.06
+        base_interest = 0.03 + (text_len - 5) * (0.06 - 0.03) / 5
+    elif text_len <= 20:
+        # 11-20 chars: linear 0.06 -> 0.12
+        base_interest = 0.06 + (text_len - 10) * (0.12 - 0.06) / 10
+    elif text_len <= 30:
+        # 21-30 chars: linear 0.12 -> 0.18
+        base_interest = 0.12 + (text_len - 20) * (0.18 - 0.12) / 10
+    elif text_len <= 50:
+        # 31-50 chars: linear 0.18 -> 0.22
+        base_interest = 0.18 + (text_len - 30) * (0.22 - 0.18) / 20
+    elif text_len <= 100:
+        # 51-100 chars: linear 0.22 -> 0.26
+        base_interest = 0.22 + (text_len - 50) * (0.26 - 0.22) / 50
+    else:
+        # 100+ chars: logarithmic 0.26 -> 0.3, growth rate tapering off
+        base_interest = 0.26 + (0.3 - 0.26) * (math.log10(text_len - 99) / math.log10(901))  # 1000 - 99 = 901
+
+    # Clamp to the valid range
+    base_interest = min(max(base_interest, 0.01), 0.3)
 
     interested_rate += base_interest
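
The piecewise block above is now duplicated verbatim between heartflow_message_processor.py and s4u_msg_processor.py. If the curve is tuned again, a shared helper would keep both call sites in sync; a minimal sketch (hypothetical module path, same numbers as the diff):

    import math

    # Each segment: (lo, hi, value_at_lo, value_at_hi); linear interpolation in between.
    _SEGMENTS = [
        (1, 5, 0.01, 0.03),
        (5, 10, 0.03, 0.06),
        (10, 20, 0.06, 0.12),
        (20, 30, 0.12, 0.18),
        (30, 50, 0.18, 0.22),
        (50, 100, 0.22, 0.26),
    ]

    def length_based_interest(text_len: int) -> float:
        """Same curve as the duplicated branches, expressed as a table."""
        if text_len <= 0:
            return 0.01
        for lo, hi, v_lo, v_hi in _SEGMENTS:
            if text_len <= hi:
                return v_lo + (text_len - lo) * (v_hi - v_lo) / (hi - lo)
        # 100+ chars: logarithmic approach toward 0.3, capped there.
        return min(0.26 + 0.04 * math.log10(text_len - 99) / math.log10(901), 0.3)

Both _calculate_interest implementations could then reduce to base_interest = length_based_interest(text_len).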