This commit is contained in:
春河晴
2025-06-10 17:31:05 +09:00
parent b0c553703f
commit 3e854719ee
13 changed files with 686 additions and 646 deletions

Binary file not shown.

View File

@@ -9,13 +9,15 @@ import sqlite3
import re
from datetime import datetime
def clean_group_name(name: str) -> str:
"""清理群组名称,只保留中文和英文字符"""
cleaned = re.sub(r'[^\u4e00-\u9fa5a-zA-Z]', '', name)
cleaned = re.sub(r"[^\u4e00-\u9fa5a-zA-Z]", "", name)
if not cleaned:
cleaned = datetime.now().strftime("%Y%m%d")
return cleaned
def get_group_name(stream_id: str) -> str:
"""从数据库中获取群组名称"""
conn = sqlite3.connect("data/maibot.db")
@@ -43,6 +45,7 @@ def get_group_name(stream_id: str) -> str:
return clean_group_name(f"{platform}{stream_id[:8]}")
return stream_id
def format_timestamp(timestamp: float) -> str:
"""将时间戳转换为可读的时间格式"""
if not timestamp:
@@ -50,9 +53,11 @@ def format_timestamp(timestamp: float) -> str:
try:
dt = datetime.fromtimestamp(timestamp)
return dt.strftime("%Y-%m-%d %H:%M:%S")
except:
except Exception as e:
print(f"时间戳格式化错误: {e}")
return "未知"
def load_expressions(chat_id: str) -> List[Dict]:
"""加载指定群聊的表达方式"""
style_file = os.path.join("data", "expression", "learnt_style", str(chat_id), "expressions.json")
@@ -65,14 +70,15 @@ def load_expressions(chat_id: str) -> List[Dict]:
return style_exprs
def find_similar_expressions(expressions: List[Dict], top_k: int = 5) -> Dict[str, List[Tuple[str, float]]]:
"""找出每个表达方式最相似的top_k个表达方式"""
if not expressions:
return {}
# 分别准备情景和表达方式的文本数据
situations = [expr['situation'] for expr in expressions]
styles = [expr['style'] for expr in expressions]
situations = [expr["situation"] for expr in expressions]
styles = [expr["style"] for expr in expressions]
# 使用TF-IDF向量化
vectorizer = TfidfVectorizer()
@@ -85,7 +91,7 @@ def find_similar_expressions(expressions: List[Dict], top_k: int = 5) -> Dict[st
# 对每个表达方式找出最相似的top_k个
similar_expressions = {}
for i, expr in enumerate(expressions):
for i, _ in enumerate(expressions):
# 获取相似度分数
situation_scores = situation_similarity[i]
style_scores = style_similarity[i]
@@ -100,29 +106,31 @@ def find_similar_expressions(expressions: List[Dict], top_k: int = 5) -> Dict[st
# 处理相似情景
for idx in situation_indices:
if situation_scores[idx] > 0: # 只保留有相似度的
similar_situations.append((
expressions[idx]['situation'],
expressions[idx]['style'], # 添加对应的原始表达
situation_scores[idx]
))
similar_situations.append(
(
expressions[idx]["situation"],
expressions[idx]["style"], # 添加对应的原始表达
situation_scores[idx],
)
)
# 处理相似表达
for idx in style_indices:
if style_scores[idx] > 0: # 只保留有相似度的
similar_styles.append((
expressions[idx]['style'],
expressions[idx]['situation'], # 添加对应的原始情景
style_scores[idx]
))
similar_styles.append(
(
expressions[idx]["style"],
expressions[idx]["situation"], # 添加对应的原始情景
style_scores[idx],
)
)
if similar_situations or similar_styles:
similar_expressions[i] = {
'situations': similar_situations,
'styles': similar_styles
}
similar_expressions[i] = {"situations": similar_situations, "styles": similar_styles}
return similar_expressions
def main():
# 获取所有群聊ID
style_dirs = glob.glob(os.path.join("data", "expression", "learnt_style", "*"))
@@ -164,18 +172,21 @@ def main():
print("\n" + "-" * 20)
print(f"表达方式:{expr['style']} <---> 情景:{expr['situation']}")
if similar_styles[i]['styles']:
if similar_styles[i]["styles"]:
print("\n\033[33m相似表达\033[0m")
for similar_style, original_situation, score in similar_styles[i]['styles']:
for similar_style, original_situation, score in similar_styles[i]["styles"]:
print(f"\033[33m{similar_style},score:{score:.3f},对应情景:{original_situation}\033[0m")
if similar_styles[i]['situations']:
if similar_styles[i]["situations"]:
print("\n\033[32m相似情景\033[0m")
for similar_situation, original_style, score in similar_styles[i]['situations']:
for similar_situation, original_style, score in similar_styles[i]["situations"]:
print(f"\033[32m{similar_situation},score:{score:.3f},对应表达:{original_style}\033[0m")
print(f"\n激活值:{expr.get('count', 1):.3f},上次激活时间:{format_timestamp(expr.get('last_active_time'))}")
print(
f"\n激活值:{expr.get('count', 1):.3f},上次激活时间:{format_timestamp(expr.get('last_active_time'))}"
)
print("-" * 20)
if __name__ == "__main__":
main()

View File

@@ -6,15 +6,17 @@ from datetime import datetime
from typing import Dict, List, Any
import sqlite3
def clean_group_name(name: str) -> str:
"""清理群组名称,只保留中文和英文字符"""
# 提取中文和英文字符
cleaned = re.sub(r'[^\u4e00-\u9fa5a-zA-Z]', '', name)
cleaned = re.sub(r"[^\u4e00-\u9fa5a-zA-Z]", "", name)
# 如果清理后为空,使用当前日期
if not cleaned:
cleaned = datetime.now().strftime("%Y%m%d")
return cleaned
def get_group_name(stream_id: str) -> str:
"""从数据库中获取群组名称"""
conn = sqlite3.connect("data/maibot.db")
@@ -42,6 +44,7 @@ def get_group_name(stream_id: str) -> str:
return clean_group_name(f"{platform}{stream_id[:8]}")
return stream_id
def load_expressions(chat_id: str) -> tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]:
"""加载指定群组的表达方式"""
learnt_style_file = os.path.join("data", "expression", "learnt_style", str(chat_id), "expressions.json")
@@ -66,10 +69,12 @@ def load_expressions(chat_id: str) -> tuple[List[Dict[str, Any]], List[Dict[str,
return style_expressions, grammar_expressions, personality_expressions
def format_time(timestamp: float) -> str:
"""格式化时间戳为可读字符串"""
return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
def write_expressions(f, expressions: List[Dict[str, Any]], title: str):
"""写入表达方式列表"""
if not expressions:
@@ -87,7 +92,14 @@ def write_expressions(f, expressions: List[Dict[str, Any]], title: str):
f.write(f"最后活跃: {format_time(last_active)}\n")
f.write("-" * 40 + "\n")
def write_group_report(group_file: str, group_name: str, chat_id: str, style_exprs: List[Dict[str, Any]], grammar_exprs: List[Dict[str, Any]]):
def write_group_report(
group_file: str,
group_name: str,
chat_id: str,
style_exprs: List[Dict[str, Any]],
grammar_exprs: List[Dict[str, Any]],
):
"""写入群组详细报告"""
with open(group_file, "w", encoding="utf-8") as gf:
gf.write(f"群组: {group_name} (ID: {chat_id})\n")
@@ -104,6 +116,7 @@ def write_group_report(group_file: str, group_name: str, chat_id: str, style_exp
gf.write("=" * 40 + "\n")
write_expressions(gf, grammar_exprs, "句法特点")
def analyze_expressions():
"""分析所有群组的表达方式"""
# 获取所有群组ID
@@ -197,5 +210,6 @@ def analyze_expressions():
print(f"人格表达报告: {personality_report}")
print(f"各群组详细报告位于: {output_dir}")
if __name__ == "__main__":
analyze_expressions()

View File

@@ -3,11 +3,11 @@ import json
import random
from typing import List, Dict, Tuple
import glob
from datetime import datetime
MAX_EXPRESSION_COUNT = 300 # 每个群最多保留的表达方式数量
MIN_COUNT_THRESHOLD = 0.01 # 最小使用次数阈值
def load_expressions(chat_id: str) -> Tuple[List[Dict], List[Dict]]:
"""加载指定群聊的表达方式"""
style_file = os.path.join("data", "expression", "learnt_style", str(chat_id), "expressions.json")
@@ -26,6 +26,7 @@ def load_expressions(chat_id: str) -> Tuple[List[Dict], List[Dict]]:
return style_exprs, grammar_exprs
def save_expressions(chat_id: str, style_exprs: List[Dict], grammar_exprs: List[Dict]) -> None:
"""保存表达方式到文件"""
style_file = os.path.join("data", "expression", "learnt_style", str(chat_id), "expressions.json")
@@ -40,6 +41,7 @@ def save_expressions(chat_id: str, style_exprs: List[Dict], grammar_exprs: List[
with open(grammar_file, "w", encoding="utf-8") as f:
json.dump(grammar_exprs, f, ensure_ascii=False, indent=2)
def cleanup_expressions(expressions: List[Dict]) -> List[Dict]:
"""清理表达方式列表"""
if not expressions:
@@ -66,6 +68,7 @@ def cleanup_expressions(expressions: List[Dict]) -> List[Dict]:
return expressions
def main():
# 获取所有群聊ID
style_dirs = glob.glob(os.path.join("data", "expression", "learnt_style", "*"))
@@ -113,7 +116,10 @@ def main():
print("\n清理完成!")
print(f"语言风格总数: {total_style_before} -> {total_style_after}")
print(f"句法特点总数: {total_grammar_before} -> {total_grammar_after}")
print(f"总共清理了 {total_style_before + total_grammar_before - total_style_after - total_grammar_after} 条表达方式")
print(
f"总共清理了 {total_style_before + total_grammar_before - total_style_after - total_grammar_after} 条表达方式"
)
if __name__ == "__main__":
main()

View File

@@ -1,5 +1,6 @@
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import json
@@ -15,13 +16,15 @@ import random
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
def clean_group_name(name: str) -> str:
"""清理群组名称,只保留中文和英文字符"""
cleaned = re.sub(r'[^\u4e00-\u9fa5a-zA-Z]', '', name)
cleaned = re.sub(r"[^\u4e00-\u9fa5a-zA-Z]", "", name)
if not cleaned:
cleaned = datetime.now().strftime("%Y%m%d")
return cleaned
def get_group_name(stream_id: str) -> str:
"""从数据库中获取群组名称"""
conn = sqlite3.connect("data/maibot.db")
@@ -49,6 +52,7 @@ def get_group_name(stream_id: str) -> str:
return clean_group_name(f"{platform}{stream_id[:8]}")
return stream_id
def load_expressions(chat_id: str) -> List[Dict]:
"""加载指定群聊的表达方式"""
style_file = os.path.join("data", "expression", "learnt_style", str(chat_id), "expressions.json")
@@ -66,16 +70,19 @@ def load_expressions(chat_id: str) -> List[Dict]:
return style_exprs
def find_similar_expressions_tfidf(input_text: str, expressions: List[Dict], mode: str = "both", top_k: int = 10) -> List[Tuple[str, str, float]]:
def find_similar_expressions_tfidf(
input_text: str, expressions: List[Dict], mode: str = "both", top_k: int = 10
) -> List[Tuple[str, str, float]]:
"""使用TF-IDF方法找出与输入文本最相似的top_k个表达方式"""
if not expressions:
return []
# 准备文本数据
if mode == "style":
texts = [expr['style'] for expr in expressions]
texts = [expr["style"] for expr in expressions]
elif mode == "situation":
texts = [expr['situation'] for expr in expressions]
texts = [expr["situation"] for expr in expressions]
else: # both
texts = [f"{expr['situation']} {expr['style']}" for expr in expressions]
@@ -98,24 +105,23 @@ def find_similar_expressions_tfidf(input_text: str, expressions: List[Dict], mod
similar_exprs = []
for idx in top_indices:
if scores[idx] > 0: # 只保留有相似度的
similar_exprs.append((
expressions[idx]['style'],
expressions[idx]['situation'],
scores[idx]
))
similar_exprs.append((expressions[idx]["style"], expressions[idx]["situation"], scores[idx]))
return similar_exprs
async def find_similar_expressions_embedding(input_text: str, expressions: List[Dict], mode: str = "both", top_k: int = 5) -> List[Tuple[str, str, float]]:
async def find_similar_expressions_embedding(
input_text: str, expressions: List[Dict], mode: str = "both", top_k: int = 5
) -> List[Tuple[str, str, float]]:
"""使用嵌入模型找出与输入文本最相似的top_k个表达方式"""
if not expressions:
return []
# 准备文本数据
if mode == "style":
texts = [expr['style'] for expr in expressions]
texts = [expr["style"] for expr in expressions]
elif mode == "situation":
texts = [expr['situation'] for expr in expressions]
texts = [expr["situation"] for expr in expressions]
else: # both
texts = [f"{expr['situation']} {expr['style']}" for expr in expressions]
@@ -144,14 +150,11 @@ async def find_similar_expressions_embedding(input_text: str, expressions: List[
similar_exprs = []
for idx in top_indices:
if similarities[idx] > 0: # 只保留有相似度的
similar_exprs.append((
expressions[idx]['style'],
expressions[idx]['situation'],
similarities[idx]
))
similar_exprs.append((expressions[idx]["style"], expressions[idx]["situation"], similarities[idx]))
return similar_exprs
async def main():
# 获取所有群聊ID
style_dirs = glob.glob(os.path.join("data", "expression", "learnt_style", "*"))
@@ -202,11 +205,7 @@ async def main():
except ValueError:
print("请输入有效的数字")
mode_map = {
1: "style",
2: "situation",
3: "both"
}
mode_map = {1: "style", 2: "situation", 3: "both"}
mode = mode_map[mode_choice]
# 选择匹配方法
@@ -225,7 +224,7 @@ async def main():
while True:
input_text = input("\n请输入要匹配的文本输入q退出: ")
if input_text.lower() == 'q':
if input_text.lower() == "q":
break
if not input_text.strip():
@@ -246,6 +245,8 @@ async def main():
else:
print("\n没有找到相似的表达方式")
if __name__ == "__main__":
import asyncio
asyncio.run(main())

View File

@@ -1,4 +1,5 @@
#!/usr/bin/env python3
# ruff: noqa: E402
"""
消息检索脚本
@@ -10,57 +11,37 @@
5. 应用LLM分析将结果存储到数据库person_info中
"""
import sys
import os
import asyncio
import json
import re
import random
import time
import math
from datetime import datetime, timedelta
import sys
from collections import defaultdict
from typing import Dict, List, Any, Optional
from datetime import datetime, timedelta
from difflib import SequenceMatcher
from pathlib import Path
from typing import Dict, List, Any, Optional
import jieba
from json_repair import repair_json
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
# 添加项目根目录到Python路径
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from src.common.database.database_model import Messages
from src.person_info.person_info import PersonInfoManager
from src.config.config import global_config
from src.common.database.database import db
from src.chat.utils.chat_message_builder import build_readable_messages
from src.person_info.person_info import person_info_manager
from src.llm_models.utils_model import LLMRequest
from src.individuality.individuality import individuality
from json_repair import repair_json
from difflib import SequenceMatcher
import jieba
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from src.common.database.database_model import Messages
from src.common.logger_manager import get_logger
from src.common.database.database import db
from src.config.config import global_config
from src.llm_models.utils_model import LLMRequest
from src.person_info.person_info import PersonInfoManager, person_info_manager
logger = get_logger("message_retrieval")
class MessageRetrievalScript:
def __init__(self):
"""初始化脚本"""
self.person_info_manager = PersonInfoManager()
self.bot_qq = str(global_config.bot.qq_account)
# 初始化LLM请求器和relationship_manager一样
self.relationship_llm = LLMRequest(
model=global_config.model.relation,
request_type="relationship",
)
def get_person_id(self, platform: str, user_id: str) -> str:
"""根据platform和user_id计算person_id"""
return PersonInfoManager.get_person_id(platform, user_id)
def get_time_range(self, time_period: str) -> Optional[float]:
def get_time_range(time_period: str) -> Optional[float]:
"""根据时间段选择获取起始时间戳"""
now = datetime.now()
@@ -77,74 +58,21 @@ class MessageRetrievalScript:
return start_time.timestamp()
def retrieve_messages(self, user_qq: str, time_period: str) -> Dict[str, List[Dict[str, Any]]]:
"""检索消息"""
print(f"开始检索用户 {user_qq} 的消息...")
# 计算person_id
person_id = self.get_person_id("qq", user_qq)
print(f"用户person_id: {person_id}")
def get_person_id(platform: str, user_id: str) -> str:
"""根据platform和user_id计算person_id"""
return PersonInfoManager.get_person_id(platform, user_id)
# 获取时间范围
start_timestamp = self.get_time_range(time_period)
if start_timestamp:
print(f"时间范围: {datetime.fromtimestamp(start_timestamp).strftime('%Y-%m-%d %H:%M:%S')} 至今")
else:
print("时间范围: 全部时间")
# 构建查询条件
query = Messages.select()
# 添加用户条件包含bot消息或目标用户消息
user_condition = (
(Messages.user_id == self.bot_qq) | # bot的消息
(Messages.user_id == user_qq) # 目标用户的消息
)
query = query.where(user_condition)
# 添加时间条件
if start_timestamp:
query = query.where(Messages.time >= start_timestamp)
# 按时间排序
query = query.order_by(Messages.time.asc())
print("正在执行数据库查询...")
messages = list(query)
print(f"查询到 {len(messages)} 条消息")
# 按chat_id分组
grouped_messages = defaultdict(list)
for msg in messages:
msg_dict = {
'message_id': msg.message_id,
'time': msg.time,
'datetime': datetime.fromtimestamp(msg.time).strftime('%Y-%m-%d %H:%M:%S'),
'chat_id': msg.chat_id,
'user_id': msg.user_id,
'user_nickname': msg.user_nickname,
'user_platform': msg.user_platform,
'processed_plain_text': msg.processed_plain_text,
'display_message': msg.display_message,
'chat_info_group_id': msg.chat_info_group_id,
'chat_info_group_name': msg.chat_info_group_name,
'chat_info_platform': msg.chat_info_platform,
'user_cardname': msg.user_cardname,
'is_bot_message': msg.user_id == self.bot_qq
}
grouped_messages[msg.chat_id].append(msg_dict)
print(f"消息分布在 {len(grouped_messages)} 个聊天中")
return dict(grouped_messages)
def split_messages_by_count(self, messages: List[Dict[str, Any]], count: int = 50) -> List[List[Dict[str, Any]]]:
def split_messages_by_count(messages: List[Dict[str, Any]], count: int = 50) -> List[List[Dict[str, Any]]]:
"""将消息按指定数量分段"""
chunks = []
for i in range(0, len(messages), count):
chunks.append(messages[i : i + count])
return chunks
async def build_name_mapping(self, messages: List[Dict[str, Any]], target_person_id: str, target_person_name: str) -> Dict[str, str]:
async def build_name_mapping(messages: List[Dict[str, Any]], target_person_name: str) -> Dict[str, str]:
"""构建用户名称映射和relationship_manager中的逻辑一致"""
name_mapping = {}
current_user = "A"
@@ -160,7 +88,7 @@ class MessageRetrievalScript:
)
replace_user_id = msg.get("user_id")
replace_platform = msg.get("chat_info_platform")
replace_person_id = person_info_manager.get_person_id(replace_platform, replace_user_id)
replace_person_id = get_person_id(replace_platform, replace_user_id)
replace_person_name = await person_info_manager.get_value(replace_person_id, "person_name")
# 跳过机器人自己
@@ -175,22 +103,23 @@ class MessageRetrievalScript:
# 其他用户映射
if replace_person_name not in name_mapping:
if current_user > 'Z':
current_user = 'A'
if current_user > "Z":
current_user = "A"
user_count += 1
name_mapping[replace_person_name] = f"用户{current_user}{user_count if user_count > 1 else ''}"
current_user = chr(ord(current_user) + 1)
return name_mapping
def build_focus_readable_messages(self, messages: List[Dict[str, Any]], target_person_id: str = None) -> str:
def build_focus_readable_messages(messages: List[Dict[str, Any]], target_person_id: str = None) -> str:
"""格式化消息只保留目标用户和bot消息附近的内容和relationship_manager中的逻辑一致"""
# 找到目标用户和bot的消息索引
target_indices = []
for i, msg in enumerate(messages):
user_id = msg.get("user_id")
platform = msg.get("chat_info_platform")
person_id = person_info_manager.get_person_id(platform, user_id)
person_id = get_person_id(platform, user_id)
if person_id == target_person_id:
target_indices.append(i)
@@ -231,17 +160,14 @@ class MessageRetrievalScript:
if i > 0:
result.append("...")
group_text = build_readable_messages(
messages=group,
replace_bot_name=True,
timestamp_mode="normal_no_YMD",
truncate=False
messages=group, replace_bot_name=True, timestamp_mode="normal_no_YMD", truncate=False
)
result.append(group_text)
return "\n".join(result)
# 添加相似度检查方法和relationship_manager一致
def tfidf_similarity(self, s1, s2):
def tfidf_similarity(s1, s2):
"""使用 TF-IDF 和余弦相似度计算两个句子的相似性"""
# 确保输入是字符串类型
if isinstance(s1, list):
@@ -274,20 +200,13 @@ class MessageRetrievalScript:
# 返回 s1 和 s2 的相似度
return similarity_matrix[0, 1]
def sequence_similarity(self, s1, s2):
def sequence_similarity(s1, s2):
"""使用 SequenceMatcher 计算两个句子的相似性"""
return SequenceMatcher(None, s1, s2).ratio()
def check_similarity(self, text1, text2, tfidf_threshold=0.5, seq_threshold=0.6):
"""使用两种方法检查文本相似度,只要其中一种方法达到阈值就认为是相似的"""
# 计算两种相似度
tfidf_sim = self.tfidf_similarity(text1, text2)
seq_sim = self.sequence_similarity(text1, text2)
# 只要其中一种方法达到阈值就认为是相似的
return tfidf_sim > tfidf_threshold or seq_sim > seq_threshold
def calculate_time_weight(self, point_time: str, current_time: str) -> float:
def calculate_time_weight(point_time: str, current_time: str) -> float:
"""计算基于时间的权重系数"""
try:
point_timestamp = datetime.strptime(point_time, "%Y-%m-%d %H:%M:%S")
@@ -311,6 +230,184 @@ class MessageRetrievalScript:
logger.error(f"计算时间权重失败: {e}")
return 0.5 # 发生错误时返回中等权重
def filter_selected_chats(
grouped_messages: Dict[str, List[Dict[str, Any]]], selected_indices: List[int]
) -> Dict[str, List[Dict[str, Any]]]:
"""根据用户选择过滤群聊"""
chat_items = list(grouped_messages.items())
selected_chats = {}
for idx in selected_indices:
chat_id, messages = chat_items[idx - 1] # 转换为0基索引
selected_chats[chat_id] = messages
return selected_chats
def get_user_selection(total_count: int) -> List[int]:
"""获取用户选择的群聊编号"""
while True:
print(f"\n请选择要分析的群聊 (1-{total_count}):")
print("输入格式:")
print(" 单个: 1")
print(" 多个: 1,3,5")
print(" 范围: 1-3")
print(" 全部: all 或 a")
print(" 退出: quit 或 q")
user_input = input("请输入选择: ").strip().lower()
if user_input in ["quit", "q"]:
return []
if user_input in ["all", "a"]:
return list(range(1, total_count + 1))
try:
selected = []
# 处理逗号分隔的输入
parts = user_input.split(",")
for part in parts:
part = part.strip()
if "-" in part:
# 处理范围输入 (如: 1-3)
start, end = part.split("-")
start_num = int(start.strip())
end_num = int(end.strip())
if 1 <= start_num <= total_count and 1 <= end_num <= total_count and start_num <= end_num:
selected.extend(range(start_num, end_num + 1))
else:
raise ValueError("范围超出有效范围")
else:
# 处理单个数字
num = int(part)
if 1 <= num <= total_count:
selected.append(num)
else:
raise ValueError("数字超出有效范围")
# 去重并排序
selected = sorted(list(set(selected)))
if selected:
return selected
else:
print("错误: 请输入有效的选择")
except ValueError as e:
print(f"错误: 输入格式无效 - {e}")
print("请重新输入")
def display_chat_list(grouped_messages: Dict[str, List[Dict[str, Any]]]) -> None:
"""显示群聊列表"""
print("\n找到以下群聊:")
print("=" * 60)
for i, (chat_id, messages) in enumerate(grouped_messages.items(), 1):
first_msg = messages[0]
group_name = first_msg.get("chat_info_group_name", "私聊")
group_id = first_msg.get("chat_info_group_id", chat_id)
# 计算时间范围
start_time = datetime.fromtimestamp(messages[0]["time"]).strftime("%Y-%m-%d")
end_time = datetime.fromtimestamp(messages[-1]["time"]).strftime("%Y-%m-%d")
print(f"{i:2d}. {group_name}")
print(f" 群ID: {group_id}")
print(f" 消息数: {len(messages)}")
print(f" 时间范围: {start_time} ~ {end_time}")
print("-" * 60)
def check_similarity(text1, text2, tfidf_threshold=0.5, seq_threshold=0.6):
"""使用两种方法检查文本相似度,只要其中一种方法达到阈值就认为是相似的"""
# 计算两种相似度
tfidf_sim = tfidf_similarity(text1, text2)
seq_sim = sequence_similarity(text1, text2)
# 只要其中一种方法达到阈值就认为是相似的
return tfidf_sim > tfidf_threshold or seq_sim > seq_threshold
class MessageRetrievalScript:
def __init__(self):
"""初始化脚本"""
self.bot_qq = str(global_config.bot.qq_account)
# 初始化LLM请求器和relationship_manager一样
self.relationship_llm = LLMRequest(
model=global_config.model.relation,
request_type="relationship",
)
def retrieve_messages(self, user_qq: str, time_period: str) -> Dict[str, List[Dict[str, Any]]]:
"""检索消息"""
print(f"开始检索用户 {user_qq} 的消息...")
# 计算person_id
person_id = get_person_id("qq", user_qq)
print(f"用户person_id: {person_id}")
# 获取时间范围
start_timestamp = get_time_range(time_period)
if start_timestamp:
print(f"时间范围: {datetime.fromtimestamp(start_timestamp).strftime('%Y-%m-%d %H:%M:%S')} 至今")
else:
print("时间范围: 全部时间")
# 构建查询条件
query = Messages.select()
# 添加用户条件包含bot消息或目标用户消息
user_condition = (
(Messages.user_id == self.bot_qq) # bot的消息
| (Messages.user_id == user_qq) # 目标用户的消息
)
query = query.where(user_condition)
# 添加时间条件
if start_timestamp:
query = query.where(Messages.time >= start_timestamp)
# 按时间排序
query = query.order_by(Messages.time.asc())
print("正在执行数据库查询...")
messages = list(query)
print(f"查询到 {len(messages)} 条消息")
# 按chat_id分组
grouped_messages = defaultdict(list)
for msg in messages:
msg_dict = {
"message_id": msg.message_id,
"time": msg.time,
"datetime": datetime.fromtimestamp(msg.time).strftime("%Y-%m-%d %H:%M:%S"),
"chat_id": msg.chat_id,
"user_id": msg.user_id,
"user_nickname": msg.user_nickname,
"user_platform": msg.user_platform,
"processed_plain_text": msg.processed_plain_text,
"display_message": msg.display_message,
"chat_info_group_id": msg.chat_info_group_id,
"chat_info_group_name": msg.chat_info_group_name,
"chat_info_platform": msg.chat_info_platform,
"user_cardname": msg.user_cardname,
"is_bot_message": msg.user_id == self.bot_qq,
}
grouped_messages[msg.chat_id].append(msg_dict)
print(f"消息分布在 {len(grouped_messages)} 个聊天中")
return dict(grouped_messages)
# 添加相似度检查方法和relationship_manager一致
async def update_person_impression_from_segment(self, person_id: str, readable_messages: str, segment_time: float):
"""从消息段落更新用户印象使用和relationship_manager相同的流程"""
person_name = await person_info_manager.get_value(person_id, "person_name")
@@ -412,7 +509,7 @@ class MessageRetrievalScript:
# 在现有points中查找相似的点
for i, existing_point in enumerate(current_points):
# 使用组合的相似度检查方法
if self.check_similarity(new_point[0], existing_point[0]):
if check_similarity(new_point[0], existing_point[0]):
similar_points.append(existing_point)
similar_indices.append(i)
@@ -460,7 +557,7 @@ class MessageRetrievalScript:
# 计算每个点的最终权重(原始权重 * 时间权重)
weighted_points = []
for point in current_points:
time_weight = self.calculate_time_weight(point[2], current_time_str)
time_weight = calculate_time_weight(point[2], current_time_str)
final_weight = point[1] * time_weight
weighted_points.append((point, final_weight))
@@ -505,10 +602,9 @@ class MessageRetrievalScript:
forgotten_points.sort(key=lambda x: x[2])
# 构建points文本
points_text = "\n".join([
f"时间:{point[2]}\n权重:{point[1]}\n内容:{point[0]}"
for point in forgotten_points
])
points_text = "\n".join(
[f"时间:{point[2]}\n权重:{point[1]}\n内容:{point[0]}" for point in forgotten_points]
)
impression = await person_info_manager.get_value(person_id, "impression") or ""
@@ -543,27 +639,30 @@ class MessageRetrievalScript:
forgotten_points = []
# 更新数据库
await person_info_manager.update_one_field(person_id, "forgotten_points", json.dumps(forgotten_points, ensure_ascii=False, indent=None))
await person_info_manager.update_one_field(
person_id, "forgotten_points", json.dumps(forgotten_points, ensure_ascii=False, indent=None)
)
# 更新数据库
await person_info_manager.update_one_field(person_id, "points", json.dumps(current_points, ensure_ascii=False, indent=None))
await person_info_manager.update_one_field(
person_id, "points", json.dumps(current_points, ensure_ascii=False, indent=None)
)
know_times = await person_info_manager.get_value(person_id, "know_times") or 0
await person_info_manager.update_one_field(person_id, "know_times", know_times + 1)
await person_info_manager.update_one_field(person_id, "last_know", segment_time)
logger.info(f"印象更新完成 for {person_name},新增 {len(points_list)} 个记忆点")
async def process_segments_and_update_impression(self, user_qq: str, grouped_messages: Dict[str, List[Dict[str, Any]]]):
async def process_segments_and_update_impression(
self, user_qq: str, grouped_messages: Dict[str, List[Dict[str, Any]]]
):
"""处理分段消息并更新用户印象到数据库"""
# 获取目标用户信息
target_person_id = self.get_person_id("qq", user_qq)
target_person_id = get_person_id("qq", user_qq)
target_person_name = await person_info_manager.get_value(target_person_id, "person_name")
target_nickname = await person_info_manager.get_value(target_person_id, "nickname")
if not target_person_name:
target_person_name = f"用户{user_qq}"
if not target_nickname:
target_nickname = f"用户{user_qq}"
print(f"\n开始分析用户 {target_person_name} (QQ: {user_qq}) 的消息...")
@@ -575,52 +674,53 @@ class MessageRetrievalScript:
# 为每个chat_id处理消息收集所有分段
for chat_id, messages in grouped_messages.items():
first_msg = messages[0]
group_name = first_msg.get('chat_info_group_name', '私聊')
group_name = first_msg.get("chat_info_group_name", "私聊")
print(f"准备聊天: {group_name} (共{len(messages)}条消息)")
# 将消息按50条分段
message_chunks = self.split_messages_by_count(messages, 50)
message_chunks = split_messages_by_count(messages, 50)
for i, chunk in enumerate(message_chunks):
# 将分段信息添加到列表中,包含分段时间用于排序
segment_time = chunk[-1]['time']
all_segments.append({
'chunk': chunk,
'chat_id': chat_id,
'group_name': group_name,
'segment_index': i + 1,
'total_segments': len(message_chunks),
'segment_time': segment_time
})
segment_time = chunk[-1]["time"]
all_segments.append(
{
"chunk": chunk,
"chat_id": chat_id,
"group_name": group_name,
"segment_index": i + 1,
"total_segments": len(message_chunks),
"segment_time": segment_time,
}
)
# 按时间排序所有分段
all_segments.sort(key=lambda x: x['segment_time'])
all_segments.sort(key=lambda x: x["segment_time"])
print(f"\n按时间顺序处理 {len(all_segments)} 个分段:")
# 按时间顺序处理所有分段
for segment_idx, segment_info in enumerate(all_segments, 1):
chunk = segment_info['chunk']
group_name = segment_info['group_name']
segment_index = segment_info['segment_index']
total_segments = segment_info['total_segments']
segment_time = segment_info['segment_time']
chunk = segment_info["chunk"]
group_name = segment_info["group_name"]
segment_index = segment_info["segment_index"]
total_segments = segment_info["total_segments"]
segment_time = segment_info["segment_time"]
segment_time_str = datetime.fromtimestamp(segment_time).strftime('%Y-%m-%d %H:%M:%S')
print(f" [{segment_idx}/{len(all_segments)}] {group_name}{segment_index}/{total_segments}段 ({segment_time_str}) (共{len(chunk)}条)")
# 构建名称映射
name_mapping = await self.build_name_mapping(chunk, target_person_id, target_person_name)
# 构建可读消息
readable_messages = self.build_focus_readable_messages(
messages=chunk,
target_person_id=target_person_id
segment_time_str = datetime.fromtimestamp(segment_time).strftime("%Y-%m-%d %H:%M:%S")
print(
f" [{segment_idx}/{len(all_segments)}] {group_name}{segment_index}/{total_segments}段 ({segment_time_str}) (共{len(chunk)}条)"
)
# 构建名称映射
name_mapping = await build_name_mapping(chunk, target_person_name)
# 构建可读消息
readable_messages = build_focus_readable_messages(messages=chunk, target_person_id=target_person_id)
if not readable_messages:
print(f" 跳过:该段落没有目标用户的消息")
print(" 跳过:该段落没有目标用户的消息")
continue
# 应用名称映射
@@ -633,7 +733,7 @@ class MessageRetrievalScript:
total_segments_processed += 1
except Exception as e:
logger.error(f"处理段落时出错: {e}")
print(f" 错误:处理该段落时出现异常")
print(" 错误:处理该段落时出现异常")
# 获取最终统计
final_points = await person_info_manager.get_value(target_person_id, "points") or []
@@ -645,7 +745,7 @@ class MessageRetrievalScript:
final_impression = await person_info_manager.get_value(target_person_id, "impression") or ""
print(f"\n=== 处理完成 ===")
print("\n=== 处理完成 ===")
print(f"目标用户: {target_person_name} (QQ: {user_qq})")
print(f"处理段落数: {total_segments_processed}")
print(f"当前记忆点数: {len(final_points)}")
@@ -654,95 +754,6 @@ class MessageRetrievalScript:
if final_points:
print(f"最新记忆点: {final_points[-1][0][:50]}...")
def display_chat_list(self, grouped_messages: Dict[str, List[Dict[str, Any]]]) -> None:
"""显示群聊列表"""
print("\n找到以下群聊:")
print("=" * 60)
for i, (chat_id, messages) in enumerate(grouped_messages.items(), 1):
first_msg = messages[0]
group_name = first_msg.get('chat_info_group_name', '私聊')
group_id = first_msg.get('chat_info_group_id', chat_id)
# 计算时间范围
start_time = datetime.fromtimestamp(messages[0]['time']).strftime('%Y-%m-%d')
end_time = datetime.fromtimestamp(messages[-1]['time']).strftime('%Y-%m-%d')
print(f"{i:2d}. {group_name}")
print(f" 群ID: {group_id}")
print(f" 消息数: {len(messages)}")
print(f" 时间范围: {start_time} ~ {end_time}")
print("-" * 60)
def get_user_selection(self, total_count: int) -> List[int]:
"""获取用户选择的群聊编号"""
while True:
print(f"\n请选择要分析的群聊 (1-{total_count}):")
print("输入格式:")
print(" 单个: 1")
print(" 多个: 1,3,5")
print(" 范围: 1-3")
print(" 全部: all 或 a")
print(" 退出: quit 或 q")
user_input = input("请输入选择: ").strip().lower()
if user_input in ['quit', 'q']:
return []
if user_input in ['all', 'a']:
return list(range(1, total_count + 1))
try:
selected = []
# 处理逗号分隔的输入
parts = user_input.split(',')
for part in parts:
part = part.strip()
if '-' in part:
# 处理范围输入 (如: 1-3)
start, end = part.split('-')
start_num = int(start.strip())
end_num = int(end.strip())
if 1 <= start_num <= total_count and 1 <= end_num <= total_count and start_num <= end_num:
selected.extend(range(start_num, end_num + 1))
else:
raise ValueError("范围超出有效范围")
else:
# 处理单个数字
num = int(part)
if 1 <= num <= total_count:
selected.append(num)
else:
raise ValueError("数字超出有效范围")
# 去重并排序
selected = sorted(list(set(selected)))
if selected:
return selected
else:
print("错误: 请输入有效的选择")
except ValueError as e:
print(f"错误: 输入格式无效 - {e}")
print("请重新输入")
def filter_selected_chats(self, grouped_messages: Dict[str, List[Dict[str, Any]]], selected_indices: List[int]) -> Dict[str, List[Dict[str, Any]]]:
"""根据用户选择过滤群聊"""
chat_items = list(grouped_messages.items())
selected_chats = {}
for idx in selected_indices:
chat_id, messages = chat_items[idx - 1] # 转换为0基索引
selected_chats[chat_id] = messages
return selected_chats
async def run(self):
"""运行脚本"""
print("=== 消息检索分析脚本 ===")
@@ -760,12 +771,7 @@ class MessageRetrievalScript:
print("4. 最近1周 (1week)")
choice = input("请选择时间段 (1-4): ").strip()
time_periods = {
"1": "all",
"2": "3months",
"3": "1month",
"4": "1week"
}
time_periods = {"1": "all", "2": "3months", "3": "1month", "4": "1week"}
if choice not in time_periods:
print("选择无效")
@@ -792,28 +798,28 @@ class MessageRetrievalScript:
return
# 显示群聊列表
self.display_chat_list(grouped_messages)
display_chat_list(grouped_messages)
# 获取用户选择
selected_indices = self.get_user_selection(len(grouped_messages))
selected_indices = get_user_selection(len(grouped_messages))
if not selected_indices:
print("已取消操作")
return
# 过滤选中的群聊
selected_chats = self.filter_selected_chats(grouped_messages, selected_indices)
selected_chats = filter_selected_chats(grouped_messages, selected_indices)
# 显示选中的群聊
print(f"\n已选择 {len(selected_chats)} 个群聊进行分析:")
for i, (chat_id, messages) in enumerate(selected_chats.items(), 1):
for i, (_, messages) in enumerate(selected_chats.items(), 1):
first_msg = messages[0]
group_name = first_msg.get('chat_info_group_name', '私聊')
group_name = first_msg.get("chat_info_group_name", "私聊")
print(f" {i}. {group_name} ({len(messages)}条消息)")
# 确认处理
confirm = input(f"\n确认分析这些群聊吗? (y/n): ").strip().lower()
if confirm != 'y':
confirm = input("\n确认分析这些群聊吗? (y/n): ").strip().lower()
if confirm != "y":
print("已取消操作")
return
@@ -823,16 +829,18 @@ class MessageRetrievalScript:
except Exception as e:
print(f"处理过程中出现错误: {e}")
import traceback
traceback.print_exc()
finally:
db.close()
print("数据库连接已关闭")
def main():
"""主函数"""
script = MessageRetrievalScript()
asyncio.run(script.run())
if __name__ == "__main__":
main()

View File

@@ -1,7 +1,7 @@
from typing import Optional, Tuple, Dict
from src.common.logger_manager import get_logger
from src.chat.message_receive.chat_stream import chat_manager
from src.person_info.person_info import person_info_manager
from src.person_info.person_info import person_info_manager, PersonInfoManager
logger = get_logger("heartflow_utils")
@@ -47,7 +47,7 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
# Try to fetch person info
try:
# Assume get_person_id is sync (as per original code), keep using to_thread
person_id = person_info_manager.get_person_id(platform, user_id)
person_id = PersonInfoManager.get_person_id(platform, user_id)
person_name = None
if person_id:
# get_value is async, so await it directly

View File

@@ -6,7 +6,7 @@ from src.chat.message_receive.message import MessageThinking
from src.chat.normal_chat.normal_prompt import prompt_builder
from src.chat.utils.timer_calculator import Timer
from src.common.logger_manager import get_logger
from src.person_info.person_info import person_info_manager
from src.person_info.person_info import person_info_manager, PersonInfoManager
from src.chat.utils.utils import process_llm_response
@@ -66,7 +66,7 @@ class NormalChatGenerator:
enable_planner: bool = False,
available_actions=None,
):
person_id = person_info_manager.get_person_id(
person_id = PersonInfoManager.get_person_id(
message.chat_stream.user_info.platform, message.chat_stream.user_info.user_id
)

View File

@@ -96,7 +96,7 @@ class BaseWillingManager(ABC):
self.logger: LoguruLogger = logger
def setup(self, message: MessageRecv, chat: ChatStream, is_mentioned_bot: bool, interested_rate: float):
person_id = person_info_manager.get_person_id(chat.platform, chat.user_info.user_id)
person_id = PersonInfoManager.get_person_id(chat.platform, chat.user_info.user_id)
self.ongoing_messages[message.message_info.message_id] = WillingInfo(
message=message,
chat=chat,

View File

@@ -4,7 +4,7 @@ import time # 导入 time 模块以获取当前时间
import random
import re
from src.common.message_repository import find_messages, count_messages
from src.person_info.person_info import person_info_manager
from src.person_info.person_info import person_info_manager, PersonInfoManager
from src.chat.utils.utils import translate_timestamp_to_human_readable
from rich.traceback import install
from src.common.database.database_model import ActionRecords
@@ -219,7 +219,7 @@ def _build_readable_messages_internal(
if not all([platform, user_id, timestamp is not None]):
continue
person_id = person_info_manager.get_person_id(platform, user_id)
person_id = PersonInfoManager.get_person_id(platform, user_id)
# 根据 replace_bot_name 参数决定是否替换机器人名称
if replace_bot_name and user_id == global_config.bot.qq_account:
person_name = f"{global_config.bot.nickname}(你)"
@@ -241,7 +241,7 @@ def _build_readable_messages_internal(
if match:
aaa = match.group(1)
bbb = match.group(2)
reply_person_id = person_info_manager.get_person_id(platform, bbb)
reply_person_id = PersonInfoManager.get_person_id(platform, bbb)
reply_person_name = person_info_manager.get_value_sync(reply_person_id, "person_name")
if not reply_person_name:
reply_person_name = aaa
@@ -258,7 +258,7 @@ def _build_readable_messages_internal(
new_content += content[last_end : m.start()]
aaa = m.group(1)
bbb = m.group(2)
at_person_id = person_info_manager.get_person_id(platform, bbb)
at_person_id = PersonInfoManager.get_person_id(platform, bbb)
at_person_name = person_info_manager.get_value_sync(at_person_id, "person_name")
if not at_person_name:
at_person_name = aaa
@@ -572,7 +572,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
# print("SELF11111111111111")
return "SELF"
try:
person_id = person_info_manager.get_person_id(platform, user_id)
person_id = PersonInfoManager.get_person_id(platform, user_id)
except Exception as _e:
person_id = None
if not person_id:
@@ -673,7 +673,7 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
if not all([platform, user_id]) or user_id == global_config.bot.qq_account:
continue
person_id = person_info_manager.get_person_id(platform, user_id)
person_id = PersonInfoManager.get_person_id(platform, user_id)
# 只有当获取到有效 person_id 时才添加
if person_id:

View File

@@ -1,9 +1,9 @@
from src.manager.async_task_manager import AsyncTask
from src.common.logger_manager import get_logger
from src.person_info.person_info import PersonInfoManager
from src.person_info.relationship_manager import relationship_manager
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp
from src.config.config import global_config
from src.person_info.person_info import person_info_manager
from src.chat.message_receive.chat_stream import chat_manager
import time
import random
@@ -95,7 +95,7 @@ class ImpressionUpdateTask(AsyncTask):
if msg["user_nickname"] == global_config.bot.nickname:
continue
person_id = person_info_manager.get_person_id(msg["chat_info_platform"], msg["user_id"])
person_id = PersonInfoManager.get_person_id(msg["chat_info_platform"], msg["user_id"])
if not person_id:
logger.warning(f"未找到用户 {msg['user_nickname']} 的person_id")
continue

View File

@@ -1,6 +1,6 @@
from src.common.logger_manager import get_logger
import math
from src.person_info.person_info import person_info_manager
from src.person_info.person_info import person_info_manager, PersonInfoManager
import time
import random
from src.llm_models.utils_model import LLMRequest
@@ -91,7 +91,7 @@ class RelationshipManager:
@staticmethod
async def first_knowing_some_one(platform: str, user_id: str, user_nickname: str, user_cardname: str):
"""判断是否认识某人"""
person_id = person_info_manager.get_person_id(platform, user_id)
person_id = PersonInfoManager.get_person_id(platform, user_id)
# 生成唯一的 person_name
unique_nickname = await person_info_manager._generate_unique_person_name(user_nickname)
data = {
@@ -116,7 +116,7 @@ class RelationshipManager:
if is_id:
person_id = person
else:
person_id = person_info_manager.get_person_id(person[0], person[1])
person_id = PersonInfoManager.get_person_id(person[0], person[1])
person_name = await person_info_manager.get_value(person_id, "person_name")
if not person_name or person_name == "none":
@@ -198,7 +198,7 @@ class RelationshipManager:
)
replace_user_id = msg.get("user_id")
replace_platform = msg.get("chat_info_platform")
replace_person_id = person_info_manager.get_person_id(replace_platform, replace_user_id)
replace_person_id = PersonInfoManager.get_person_id(replace_platform, replace_user_id)
replace_person_name = await person_info_manager.get_value(replace_person_id, "person_name")
# 跳过机器人自己
@@ -467,7 +467,7 @@ class RelationshipManager:
for i, msg in enumerate(messages):
user_id = msg.get("user_id")
platform = msg.get("chat_info_platform")
person_id = person_info_manager.get_person_id(platform, user_id)
person_id = PersonInfoManager.get_person_id(platform, user_id)
if person_id == target_person_id:
target_indices.append(i)