This commit is contained in:
春河晴
2025-06-10 17:31:05 +09:00
parent b0c553703f
commit 3e854719ee
13 changed files with 686 additions and 646 deletions

Binary file not shown.

View File

@@ -9,13 +9,15 @@ import sqlite3
import re import re
from datetime import datetime from datetime import datetime
def clean_group_name(name: str) -> str:
    """Reduce a group name to its Chinese and ASCII-letter characters.

    Every character outside U+4E00-U+9FA5, ``a-z`` and ``A-Z`` is removed.
    If nothing is left, today's date formatted as ``YYYYMMDD`` is returned
    instead, so the result is never an empty string.
    """
    letters_only = re.sub(r"[^\u4e00-\u9fa5a-zA-Z]", "", name)
    return letters_only or datetime.now().strftime("%Y%m%d")
def get_group_name(stream_id: str) -> str: def get_group_name(stream_id: str) -> str:
"""从数据库中获取群组名称""" """从数据库中获取群组名称"""
conn = sqlite3.connect("data/maibot.db") conn = sqlite3.connect("data/maibot.db")
@@ -43,6 +45,7 @@ def get_group_name(stream_id: str) -> str:
return clean_group_name(f"{platform}{stream_id[:8]}") return clean_group_name(f"{platform}{stream_id[:8]}")
return stream_id return stream_id
def format_timestamp(timestamp: float) -> str: def format_timestamp(timestamp: float) -> str:
"""将时间戳转换为可读的时间格式""" """将时间戳转换为可读的时间格式"""
if not timestamp: if not timestamp:
@@ -50,9 +53,11 @@ def format_timestamp(timestamp: float) -> str:
try: try:
dt = datetime.fromtimestamp(timestamp) dt = datetime.fromtimestamp(timestamp)
return dt.strftime("%Y-%m-%d %H:%M:%S") return dt.strftime("%Y-%m-%d %H:%M:%S")
except: except Exception as e:
print(f"时间戳格式化错误: {e}")
return "未知" return "未知"
def load_expressions(chat_id: str) -> List[Dict]: def load_expressions(chat_id: str) -> List[Dict]:
"""加载指定群聊的表达方式""" """加载指定群聊的表达方式"""
style_file = os.path.join("data", "expression", "learnt_style", str(chat_id), "expressions.json") style_file = os.path.join("data", "expression", "learnt_style", str(chat_id), "expressions.json")
@@ -65,14 +70,15 @@ def load_expressions(chat_id: str) -> List[Dict]:
return style_exprs return style_exprs
def find_similar_expressions(expressions: List[Dict], top_k: int = 5) -> Dict[str, List[Tuple[str, float]]]: def find_similar_expressions(expressions: List[Dict], top_k: int = 5) -> Dict[str, List[Tuple[str, float]]]:
"""找出每个表达方式最相似的top_k个表达方式""" """找出每个表达方式最相似的top_k个表达方式"""
if not expressions: if not expressions:
return {} return {}
# 分别准备情景和表达方式的文本数据 # 分别准备情景和表达方式的文本数据
situations = [expr['situation'] for expr in expressions] situations = [expr["situation"] for expr in expressions]
styles = [expr['style'] for expr in expressions] styles = [expr["style"] for expr in expressions]
# 使用TF-IDF向量化 # 使用TF-IDF向量化
vectorizer = TfidfVectorizer() vectorizer = TfidfVectorizer()
@@ -85,14 +91,14 @@ def find_similar_expressions(expressions: List[Dict], top_k: int = 5) -> Dict[st
# 对每个表达方式找出最相似的top_k个 # 对每个表达方式找出最相似的top_k个
similar_expressions = {} similar_expressions = {}
for i, expr in enumerate(expressions): for i, _ in enumerate(expressions):
# 获取相似度分数 # 获取相似度分数
situation_scores = situation_similarity[i] situation_scores = situation_similarity[i]
style_scores = style_similarity[i] style_scores = style_similarity[i]
# 获取top_k的索引排除自己 # 获取top_k的索引排除自己
situation_indices = np.argsort(situation_scores)[::-1][1:top_k+1] situation_indices = np.argsort(situation_scores)[::-1][1 : top_k + 1]
style_indices = np.argsort(style_scores)[::-1][1:top_k+1] style_indices = np.argsort(style_scores)[::-1][1 : top_k + 1]
similar_situations = [] similar_situations = []
similar_styles = [] similar_styles = []
@@ -100,29 +106,31 @@ def find_similar_expressions(expressions: List[Dict], top_k: int = 5) -> Dict[st
# 处理相似情景 # 处理相似情景
for idx in situation_indices: for idx in situation_indices:
if situation_scores[idx] > 0: # 只保留有相似度的 if situation_scores[idx] > 0: # 只保留有相似度的
similar_situations.append(( similar_situations.append(
expressions[idx]['situation'], (
expressions[idx]['style'], # 添加对应的原始表达 expressions[idx]["situation"],
situation_scores[idx] expressions[idx]["style"], # 添加对应的原始表达
)) situation_scores[idx],
)
)
# 处理相似表达 # 处理相似表达
for idx in style_indices: for idx in style_indices:
if style_scores[idx] > 0: # 只保留有相似度的 if style_scores[idx] > 0: # 只保留有相似度的
similar_styles.append(( similar_styles.append(
expressions[idx]['style'], (
expressions[idx]['situation'], # 添加对应的原始情景 expressions[idx]["style"],
style_scores[idx] expressions[idx]["situation"], # 添加对应的原始情景
)) style_scores[idx],
)
)
if similar_situations or similar_styles: if similar_situations or similar_styles:
similar_expressions[i] = { similar_expressions[i] = {"situations": similar_situations, "styles": similar_styles}
'situations': similar_situations,
'styles': similar_styles
}
return similar_expressions return similar_expressions
def main(): def main():
# 获取所有群聊ID # 获取所有群聊ID
style_dirs = glob.glob(os.path.join("data", "expression", "learnt_style", "*")) style_dirs = glob.glob(os.path.join("data", "expression", "learnt_style", "*"))
@@ -143,7 +151,7 @@ def main():
if choice == 0: if choice == 0:
break break
if 1 <= choice <= len(chat_ids): if 1 <= choice <= len(chat_ids):
chat_id = chat_ids[choice-1] chat_id = chat_ids[choice - 1]
break break
print("无效的选择,请重试") print("无效的选择,请重试")
except ValueError: except ValueError:
@@ -164,18 +172,21 @@ def main():
print("\n" + "-" * 20) print("\n" + "-" * 20)
print(f"表达方式:{expr['style']} <---> 情景:{expr['situation']}") print(f"表达方式:{expr['style']} <---> 情景:{expr['situation']}")
if similar_styles[i]['styles']: if similar_styles[i]["styles"]:
print("\n\033[33m相似表达\033[0m") print("\n\033[33m相似表达\033[0m")
for similar_style, original_situation, score in similar_styles[i]['styles']: for similar_style, original_situation, score in similar_styles[i]["styles"]:
print(f"\033[33m{similar_style},score:{score:.3f},对应情景:{original_situation}\033[0m") print(f"\033[33m{similar_style},score:{score:.3f},对应情景:{original_situation}\033[0m")
if similar_styles[i]['situations']: if similar_styles[i]["situations"]:
print("\n\033[32m相似情景\033[0m") print("\n\033[32m相似情景\033[0m")
for similar_situation, original_style, score in similar_styles[i]['situations']: for similar_situation, original_style, score in similar_styles[i]["situations"]:
print(f"\033[32m{similar_situation},score:{score:.3f},对应表达:{original_style}\033[0m") print(f"\033[32m{similar_situation},score:{score:.3f},对应表达:{original_style}\033[0m")
print(f"\n激活值:{expr.get('count', 1):.3f},上次激活时间:{format_timestamp(expr.get('last_active_time'))}") print(
f"\n激活值:{expr.get('count', 1):.3f},上次激活时间:{format_timestamp(expr.get('last_active_time'))}"
)
print("-" * 20) print("-" * 20)
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -6,15 +6,17 @@ from datetime import datetime
from typing import Dict, List, Any from typing import Dict, List, Any
import sqlite3 import sqlite3
def clean_group_name(name: str) -> str:
    """Return ``name`` with everything but Chinese and English letters removed.

    Falls back to the current date (``YYYYMMDD``) when the cleaned name
    would otherwise be empty.
    """
    # Keep only characters in U+4E00-U+9FA5 and ASCII letters.
    kept = re.sub(r"[^\u4e00-\u9fa5a-zA-Z]", "", name)
    if kept:
        return kept
    # Nothing survived the filter: use today's date as a stand-in name.
    return datetime.now().strftime("%Y%m%d")
def get_group_name(stream_id: str) -> str: def get_group_name(stream_id: str) -> str:
"""从数据库中获取群组名称""" """从数据库中获取群组名称"""
conn = sqlite3.connect("data/maibot.db") conn = sqlite3.connect("data/maibot.db")
@@ -42,6 +44,7 @@ def get_group_name(stream_id: str) -> str:
return clean_group_name(f"{platform}{stream_id[:8]}") return clean_group_name(f"{platform}{stream_id[:8]}")
return stream_id return stream_id
def load_expressions(chat_id: str) -> tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]: def load_expressions(chat_id: str) -> tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]:
"""加载指定群组的表达方式""" """加载指定群组的表达方式"""
learnt_style_file = os.path.join("data", "expression", "learnt_style", str(chat_id), "expressions.json") learnt_style_file = os.path.join("data", "expression", "learnt_style", str(chat_id), "expressions.json")
@@ -66,10 +69,12 @@ def load_expressions(chat_id: str) -> tuple[List[Dict[str, Any]], List[Dict[str,
return style_expressions, grammar_expressions, personality_expressions return style_expressions, grammar_expressions, personality_expressions
def format_time(timestamp: float) -> str:
    """Render a Unix timestamp as a local-time ``YYYY-MM-DD HH:MM:SS`` string."""
    moment = datetime.fromtimestamp(timestamp)
    return moment.strftime("%Y-%m-%d %H:%M:%S")
def write_expressions(f, expressions: List[Dict[str, Any]], title: str): def write_expressions(f, expressions: List[Dict[str, Any]], title: str):
"""写入表达方式列表""" """写入表达方式列表"""
if not expressions: if not expressions:
@@ -87,7 +92,14 @@ def write_expressions(f, expressions: List[Dict[str, Any]], title: str):
f.write(f"最后活跃: {format_time(last_active)}\n") f.write(f"最后活跃: {format_time(last_active)}\n")
f.write("-" * 40 + "\n") f.write("-" * 40 + "\n")
def write_group_report(group_file: str, group_name: str, chat_id: str, style_exprs: List[Dict[str, Any]], grammar_exprs: List[Dict[str, Any]]):
def write_group_report(
group_file: str,
group_name: str,
chat_id: str,
style_exprs: List[Dict[str, Any]],
grammar_exprs: List[Dict[str, Any]],
):
"""写入群组详细报告""" """写入群组详细报告"""
with open(group_file, "w", encoding="utf-8") as gf: with open(group_file, "w", encoding="utf-8") as gf:
gf.write(f"群组: {group_name} (ID: {chat_id})\n") gf.write(f"群组: {group_name} (ID: {chat_id})\n")
@@ -104,6 +116,7 @@ def write_group_report(group_file: str, group_name: str, chat_id: str, style_exp
gf.write("=" * 40 + "\n") gf.write("=" * 40 + "\n")
write_expressions(gf, grammar_exprs, "句法特点") write_expressions(gf, grammar_exprs, "句法特点")
def analyze_expressions(): def analyze_expressions():
"""分析所有群组的表达方式""" """分析所有群组的表达方式"""
# 获取所有群组ID # 获取所有群组ID
@@ -197,5 +210,6 @@ def analyze_expressions():
print(f"人格表达报告: {personality_report}") print(f"人格表达报告: {personality_report}")
print(f"各群组详细报告位于: {output_dir}") print(f"各群组详细报告位于: {output_dir}")
if __name__ == "__main__": if __name__ == "__main__":
analyze_expressions() analyze_expressions()

View File

@@ -3,11 +3,11 @@ import json
import random import random
from typing import List, Dict, Tuple from typing import List, Dict, Tuple
import glob import glob
from datetime import datetime
MAX_EXPRESSION_COUNT = 300 # 每个群最多保留的表达方式数量 MAX_EXPRESSION_COUNT = 300 # 每个群最多保留的表达方式数量
MIN_COUNT_THRESHOLD = 0.01 # 最小使用次数阈值 MIN_COUNT_THRESHOLD = 0.01 # 最小使用次数阈值
def load_expressions(chat_id: str) -> Tuple[List[Dict], List[Dict]]: def load_expressions(chat_id: str) -> Tuple[List[Dict], List[Dict]]:
"""加载指定群聊的表达方式""" """加载指定群聊的表达方式"""
style_file = os.path.join("data", "expression", "learnt_style", str(chat_id), "expressions.json") style_file = os.path.join("data", "expression", "learnt_style", str(chat_id), "expressions.json")
@@ -26,6 +26,7 @@ def load_expressions(chat_id: str) -> Tuple[List[Dict], List[Dict]]:
return style_exprs, grammar_exprs return style_exprs, grammar_exprs
def save_expressions(chat_id: str, style_exprs: List[Dict], grammar_exprs: List[Dict]) -> None: def save_expressions(chat_id: str, style_exprs: List[Dict], grammar_exprs: List[Dict]) -> None:
"""保存表达方式到文件""" """保存表达方式到文件"""
style_file = os.path.join("data", "expression", "learnt_style", str(chat_id), "expressions.json") style_file = os.path.join("data", "expression", "learnt_style", str(chat_id), "expressions.json")
@@ -40,6 +41,7 @@ def save_expressions(chat_id: str, style_exprs: List[Dict], grammar_exprs: List[
with open(grammar_file, "w", encoding="utf-8") as f: with open(grammar_file, "w", encoding="utf-8") as f:
json.dump(grammar_exprs, f, ensure_ascii=False, indent=2) json.dump(grammar_exprs, f, ensure_ascii=False, indent=2)
def cleanup_expressions(expressions: List[Dict]) -> List[Dict]: def cleanup_expressions(expressions: List[Dict]) -> List[Dict]:
"""清理表达方式列表""" """清理表达方式列表"""
if not expressions: if not expressions:
@@ -60,12 +62,13 @@ def cleanup_expressions(expressions: List[Dict]) -> List[Dict]:
# 从剩余的表达方式中随机选择 # 从剩余的表达方式中随机选择
remaining_exprs = expressions[keep_count:] remaining_exprs = expressions[keep_count:]
random.shuffle(remaining_exprs) random.shuffle(remaining_exprs)
keep_exprs.extend(remaining_exprs[:MAX_EXPRESSION_COUNT - keep_count]) keep_exprs.extend(remaining_exprs[: MAX_EXPRESSION_COUNT - keep_count])
expressions = keep_exprs expressions = keep_exprs
return expressions return expressions
def main(): def main():
# 获取所有群聊ID # 获取所有群聊ID
style_dirs = glob.glob(os.path.join("data", "expression", "learnt_style", "*")) style_dirs = glob.glob(os.path.join("data", "expression", "learnt_style", "*"))
@@ -113,7 +116,10 @@ def main():
print("\n清理完成!") print("\n清理完成!")
print(f"语言风格总数: {total_style_before} -> {total_style_after}") print(f"语言风格总数: {total_style_before} -> {total_style_after}")
print(f"句法特点总数: {total_grammar_before} -> {total_grammar_after}") print(f"句法特点总数: {total_grammar_before} -> {total_grammar_after}")
print(f"总共清理了 {total_style_before + total_grammar_before - total_style_after - total_grammar_after} 条表达方式") print(
f"总共清理了 {total_style_before + total_grammar_before - total_style_after - total_grammar_after} 条表达方式"
)
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -1,5 +1,6 @@
import os import os
import sys import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import json import json
@@ -15,13 +16,15 @@ import random
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config from src.config.config import global_config
def clean_group_name(name: str) -> str:
    """Filter a group name so that only Chinese and English letters remain.

    Characters outside U+4E00-U+9FA5 and ``a-zA-Z`` are dropped. An empty
    result is replaced by the current date in ``YYYYMMDD`` form.
    """
    disallowed = r"[^\u4e00-\u9fa5a-zA-Z]"
    stripped = re.sub(disallowed, "", name)
    return stripped if stripped else datetime.now().strftime("%Y%m%d")
def get_group_name(stream_id: str) -> str: def get_group_name(stream_id: str) -> str:
"""从数据库中获取群组名称""" """从数据库中获取群组名称"""
conn = sqlite3.connect("data/maibot.db") conn = sqlite3.connect("data/maibot.db")
@@ -49,6 +52,7 @@ def get_group_name(stream_id: str) -> str:
return clean_group_name(f"{platform}{stream_id[:8]}") return clean_group_name(f"{platform}{stream_id[:8]}")
return stream_id return stream_id
def load_expressions(chat_id: str) -> List[Dict]: def load_expressions(chat_id: str) -> List[Dict]:
"""加载指定群聊的表达方式""" """加载指定群聊的表达方式"""
style_file = os.path.join("data", "expression", "learnt_style", str(chat_id), "expressions.json") style_file = os.path.join("data", "expression", "learnt_style", str(chat_id), "expressions.json")
@@ -66,16 +70,19 @@ def load_expressions(chat_id: str) -> List[Dict]:
return style_exprs return style_exprs
def find_similar_expressions_tfidf(input_text: str, expressions: List[Dict], mode: str = "both", top_k: int = 10) -> List[Tuple[str, str, float]]:
def find_similar_expressions_tfidf(
input_text: str, expressions: List[Dict], mode: str = "both", top_k: int = 10
) -> List[Tuple[str, str, float]]:
"""使用TF-IDF方法找出与输入文本最相似的top_k个表达方式""" """使用TF-IDF方法找出与输入文本最相似的top_k个表达方式"""
if not expressions: if not expressions:
return [] return []
# 准备文本数据 # 准备文本数据
if mode == "style": if mode == "style":
texts = [expr['style'] for expr in expressions] texts = [expr["style"] for expr in expressions]
elif mode == "situation": elif mode == "situation":
texts = [expr['situation'] for expr in expressions] texts = [expr["situation"] for expr in expressions]
else: # both else: # both
texts = [f"{expr['situation']} {expr['style']}" for expr in expressions] texts = [f"{expr['situation']} {expr['style']}" for expr in expressions]
@@ -98,24 +105,23 @@ def find_similar_expressions_tfidf(input_text: str, expressions: List[Dict], mod
similar_exprs = [] similar_exprs = []
for idx in top_indices: for idx in top_indices:
if scores[idx] > 0: # 只保留有相似度的 if scores[idx] > 0: # 只保留有相似度的
similar_exprs.append(( similar_exprs.append((expressions[idx]["style"], expressions[idx]["situation"], scores[idx]))
expressions[idx]['style'],
expressions[idx]['situation'],
scores[idx]
))
return similar_exprs return similar_exprs
async def find_similar_expressions_embedding(input_text: str, expressions: List[Dict], mode: str = "both", top_k: int = 5) -> List[Tuple[str, str, float]]:
async def find_similar_expressions_embedding(
input_text: str, expressions: List[Dict], mode: str = "both", top_k: int = 5
) -> List[Tuple[str, str, float]]:
"""使用嵌入模型找出与输入文本最相似的top_k个表达方式""" """使用嵌入模型找出与输入文本最相似的top_k个表达方式"""
if not expressions: if not expressions:
return [] return []
# 准备文本数据 # 准备文本数据
if mode == "style": if mode == "style":
texts = [expr['style'] for expr in expressions] texts = [expr["style"] for expr in expressions]
elif mode == "situation": elif mode == "situation":
texts = [expr['situation'] for expr in expressions] texts = [expr["situation"] for expr in expressions]
else: # both else: # both
texts = [f"{expr['situation']} {expr['style']}" for expr in expressions] texts = [f"{expr['situation']} {expr['style']}" for expr in expressions]
@@ -144,14 +150,11 @@ async def find_similar_expressions_embedding(input_text: str, expressions: List[
similar_exprs = [] similar_exprs = []
for idx in top_indices: for idx in top_indices:
if similarities[idx] > 0: # 只保留有相似度的 if similarities[idx] > 0: # 只保留有相似度的
similar_exprs.append(( similar_exprs.append((expressions[idx]["style"], expressions[idx]["situation"], similarities[idx]))
expressions[idx]['style'],
expressions[idx]['situation'],
similarities[idx]
))
return similar_exprs return similar_exprs
async def main(): async def main():
# 获取所有群聊ID # 获取所有群聊ID
style_dirs = glob.glob(os.path.join("data", "expression", "learnt_style", "*")) style_dirs = glob.glob(os.path.join("data", "expression", "learnt_style", "*"))
@@ -172,7 +175,7 @@ async def main():
if choice == 0: if choice == 0:
break break
if 1 <= choice <= len(chat_ids): if 1 <= choice <= len(chat_ids):
chat_id = chat_ids[choice-1] chat_id = chat_ids[choice - 1]
break break
print("无效的选择,请重试") print("无效的选择,请重试")
except ValueError: except ValueError:
@@ -202,11 +205,7 @@ async def main():
except ValueError: except ValueError:
print("请输入有效的数字") print("请输入有效的数字")
mode_map = { mode_map = {1: "style", 2: "situation", 3: "both"}
1: "style",
2: "situation",
3: "both"
}
mode = mode_map[mode_choice] mode = mode_map[mode_choice]
# 选择匹配方法 # 选择匹配方法
@@ -225,7 +224,7 @@ async def main():
while True: while True:
input_text = input("\n请输入要匹配的文本输入q退出: ") input_text = input("\n请输入要匹配的文本输入q退出: ")
if input_text.lower() == 'q': if input_text.lower() == "q":
break break
if not input_text.strip(): if not input_text.strip():
@@ -246,6 +245,8 @@ async def main():
else: else:
print("\n没有找到相似的表达方式") print("\n没有找到相似的表达方式")
if __name__ == "__main__": if __name__ == "__main__":
import asyncio import asyncio
asyncio.run(main()) asyncio.run(main())

View File

@@ -1,4 +1,5 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# ruff: noqa: E402
""" """
消息检索脚本 消息检索脚本
@@ -10,44 +11,333 @@
5. 应用LLM分析将结果存储到数据库person_info中 5. 应用LLM分析将结果存储到数据库person_info中
""" """
import sys
import os
import asyncio import asyncio
import json import json
import re
import random import random
import time import sys
import math
from datetime import datetime, timedelta
from collections import defaultdict from collections import defaultdict
from typing import Dict, List, Any, Optional from datetime import datetime, timedelta
from difflib import SequenceMatcher
from pathlib import Path from pathlib import Path
from typing import Dict, List, Any, Optional
import jieba
from json_repair import repair_json
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
# 添加项目根目录到Python路径 # 添加项目根目录到Python路径
project_root = Path(__file__).parent.parent project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root)) sys.path.insert(0, str(project_root))
from src.common.database.database_model import Messages
from src.person_info.person_info import PersonInfoManager
from src.config.config import global_config
from src.common.database.database import db
from src.chat.utils.chat_message_builder import build_readable_messages from src.chat.utils.chat_message_builder import build_readable_messages
from src.person_info.person_info import person_info_manager from src.common.database.database_model import Messages
from src.llm_models.utils_model import LLMRequest
from src.individuality.individuality import individuality
from json_repair import repair_json
from difflib import SequenceMatcher
import jieba
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
from src.common.database.database import db
from src.config.config import global_config
from src.llm_models.utils_model import LLMRequest
from src.person_info.person_info import PersonInfoManager, person_info_manager
logger = get_logger("message_retrieval") logger = get_logger("message_retrieval")
def get_time_range(time_period: str) -> Optional[float]:
    """Return the start timestamp of the requested look-back window.

    Args:
        time_period: One of ``"all"``, ``"3months"`` (90 days),
            ``"1month"`` (30 days) or ``"1week"`` (7 days).

    Returns:
        ``None`` for ``"all"`` (no lower bound); otherwise the timestamp of
        "now minus the window", based on the naive local ``datetime.now()``.

    Raises:
        ValueError: If ``time_period`` is not one of the supported labels.
    """
    now = datetime.now()
    if time_period == "all":
        return None

    # Supported labels and the number of days each one looks back.
    windows = (("3months", 90), ("1month", 30), ("1week", 7))
    for label, days in windows:
        if time_period == label:
            return (now - timedelta(days=days)).timestamp()

    raise ValueError(f"不支持的时间段: {time_period}")
def get_person_id(platform: str, user_id: str) -> str:
    """Compute the person_id for a ``platform`` / ``user_id`` pair.

    Thin module-level wrapper that delegates to
    ``PersonInfoManager.get_person_id`` and returns its result unchanged.
    """
    person_id = PersonInfoManager.get_person_id(platform, user_id)
    return person_id
def split_messages_by_count(messages: List[Dict[str, Any]], count: int = 50) -> List[List[Dict[str, Any]]]:
    """Split ``messages`` into consecutive chunks of at most ``count`` items.

    Order is preserved and only the final chunk may be shorter than
    ``count``. An empty input yields an empty list.
    """
    return [messages[start : start + count] for start in range(0, len(messages), count)]
async def build_name_mapping(messages: List[Dict[str, Any]], target_person_name: str) -> Dict[str, str]:
    """Build the display-name mapping used to anonymise a conversation.

    For every message the sender is registered through
    ``person_info_manager.get_or_create_person`` and their stored
    ``person_name`` is looked up. The resulting mapping contains:

    * the bot nickname mapped to itself,
    * the target person's name mapped to itself,
    * every other person name mapped to an alias ``用户A`` .. ``用户Z``,
      then ``用户A2`` .. ``用户Z2`` and so on, in order of first appearance.

    The original author notes that this mirrors the logic in
    relationship_manager; that cannot be confirmed from this file.
    """
    name_mapping = {}
    # Number of aliases handed out so far; the letter and the round suffix
    # are both derived from it.
    alias_count = 0

    for msg in messages:
        sender_id = msg.get("user_id")
        sender_platform = msg.get("chat_info_platform")

        await person_info_manager.get_or_create_person(
            platform=sender_platform,
            user_id=sender_id,
            nickname=msg.get("user_nickname"),
            user_cardname=msg.get("user_cardname"),
        )
        sender_person_id = get_person_id(sender_platform, sender_id)
        sender_name = await person_info_manager.get_value(sender_person_id, "person_name")

        # The bot keeps its own nickname.
        if sender_id == global_config.bot.qq_account:
            name_mapping[f"{global_config.bot.nickname}"] = f"{global_config.bot.nickname}"
            continue

        # The target person keeps their real person_name.
        if sender_name == target_person_name:
            name_mapping[sender_name] = f"{target_person_name}"
            continue

        # Everyone else gets the next free alias, once.
        if sender_name in name_mapping:
            continue
        letter = chr(ord("A") + alias_count % 26)
        round_number = alias_count // 26 + 1
        name_mapping[sender_name] = f"用户{letter}{round_number if round_number > 1 else ''}"
        alias_count += 1

    return name_mapping
def build_focus_readable_messages(messages: List[Dict[str, Any]], target_person_id: Optional[str] = None) -> str:
    """Format only the parts of a conversation surrounding the target person.

    A message counts as a "target" message when the person_id computed from
    its ``chat_info_platform`` and ``user_id`` equals ``target_person_id``.
    (Only the target person is matched here; bot messages are kept solely
    when they fall inside a target message's window.) Each target message
    keeps up to five messages before and five after it; overlapping or
    adjacent windows merge into one group.

    Args:
        messages: Message dicts in conversation order.
        target_person_id: person_id whose surroundings should be kept.

    Returns:
        The groups rendered with ``build_readable_messages`` and joined by
        newlines, with a ``...`` line between non-adjacent groups, or an
        empty string when no message belongs to the target person.
    """
    target_indices = [
        i
        for i, msg in enumerate(messages)
        if get_person_id(msg.get("chat_info_platform"), msg.get("user_id")) == target_person_id
    ]
    if not target_indices:
        return ""

    # A set keeps the membership test below O(1) per message.
    keep_indices = set()
    for idx in target_indices:
        keep_indices.update(range(max(0, idx - 5), min(len(messages), idx + 6)))

    # Walk the messages in order, cutting a new group at every gap.
    message_groups = []
    current_group = []
    for i, msg in enumerate(messages):
        if i in keep_indices:
            current_group.append(msg)
        elif current_group:
            message_groups.append(current_group)
            current_group = []
    if current_group:
        message_groups.append(current_group)

    # Render each group; a "..." line marks the messages omitted in between.
    result = []
    for i, group in enumerate(message_groups):
        if i > 0:
            result.append("...")
        group_text = build_readable_messages(
            messages=group, replace_bot_name=True, timestamp_mode="normal_no_YMD", truncate=False
        )
        result.append(group_text)

    return "\n".join(result)
def tfidf_similarity(s1, s2):
    """Score the similarity of two sentences with TF-IDF and cosine similarity.

    List inputs are joined with spaces and everything is coerced to ``str``.
    Both texts are segmented with jieba, vectorised together by a
    ``TfidfVectorizer`` and compared with ``cosine_similarity``.

    Returns:
        The cosine similarity between the two vectors, or ``0.0`` when
        vectorisation raises ``ValueError`` (per the original note, this can
        happen for empty or stop-word-only sentences).
    """

    def _as_text(value):
        # Lists are flattened to a space-separated string first.
        if isinstance(value, list):
            value = " ".join(str(item) for item in value)
        return str(value)

    # jieba yields tokens; re-join them with spaces so the vectoriser can
    # split on whitespace.
    corpus = [" ".join(jieba.cut(_as_text(sentence))) for sentence in (s1, s2)]

    try:
        tfidf_matrix = TfidfVectorizer().fit_transform(corpus)
    except ValueError:
        return 0.0

    # Entry [0, 1] of the pairwise matrix is the s1-vs-s2 similarity.
    return cosine_similarity(tfidf_matrix)[0, 1]
def sequence_similarity(s1, s2):
    """Return the ``difflib.SequenceMatcher`` ratio of ``s1`` and ``s2``.

    No junk filter is used. The ratio lies in ``[0, 1]``; 1.0 means the
    sequences are identical.
    """
    matcher = SequenceMatcher(None, s1, s2)
    return matcher.ratio()
def calculate_time_weight(point_time: str, current_time: str) -> float:
    """Compute a weight for a point based on its age.

    Both arguments are ``"%Y-%m-%d %H:%M:%S"`` strings. The weight is a
    piecewise-linear function of the age in hours:

    * up to 1 hour (including negative ages): 1.0
    * 1 to 24 hours: falls from 1.0 to 0.7
    * 24 hours to 7 days: rises from 0.7 to 0.95
    * beyond 7 days: falls from 0.95 by 0.85/23 per day, floored at 0.1

    Returns:
        The weight, or 0.5 if anything fails (the error is logged).
    """
    time_format = "%Y-%m-%d %H:%M:%S"
    try:
        recorded_at = datetime.strptime(point_time, time_format)
        evaluated_at = datetime.strptime(current_time, time_format)
        hours_diff = (evaluated_at - recorded_at).total_seconds() / 3600

        if hours_diff <= 1:
            return 1.0
        if hours_diff <= 24:
            return 1.0 - (hours_diff - 1) * (0.3 / 23)
        if hours_diff <= 24 * 7:
            return 0.7 + (hours_diff - 24) * (0.25 / (24 * 6))

        # Older than a week: decay by days past day 7, never below 0.1.
        days_diff = hours_diff / 24 - 7
        return max(0.1, 0.95 - days_diff * (0.85 / 23))
    except Exception as e:
        logger.error(f"计算时间权重失败: {e}")
        return 0.5  # neutral weight when the timestamps cannot be used
def filter_selected_chats(
    grouped_messages: Dict[str, List[Dict[str, Any]]], selected_indices: List[int]
) -> Dict[str, List[Dict[str, Any]]]:
    """Keep only the chats the user picked.

    Args:
        grouped_messages: Messages grouped by chat id; its iteration order
            defines the numbering.
        selected_indices: 1-based positions into ``grouped_messages``.

    Returns:
        A new dict with the selected chats, in the order given by
        ``selected_indices``.
    """
    chat_items = list(grouped_messages.items())
    # Selections are 1-based; shift to 0-based positions.
    return dict(chat_items[position - 1] for position in selected_indices)
def _parse_chat_selection(user_input: str, total_count: int) -> List[int]:
    """Parse ``"1,3,5-7"``-style input into sorted, de-duplicated chat numbers.

    Raises:
        ValueError: If a part is not a number or ``start-end`` range, or if
            any number falls outside ``1..total_count``.
    """
    selected = set()
    for part in user_input.split(","):
        part = part.strip()
        if "-" in part:
            # Range such as "1-3"; more than one "-" fails to unpack and
            # surfaces as ValueError like any other malformed input.
            start, end = part.split("-")
            start_num = int(start.strip())
            end_num = int(end.strip())
            if not 1 <= start_num <= end_num <= total_count:
                raise ValueError("范围超出有效范围")
            selected.update(range(start_num, end_num + 1))
        else:
            num = int(part)
            if not 1 <= num <= total_count:
                raise ValueError("数字超出有效范围")
            selected.add(num)
    return sorted(selected)


def get_user_selection(total_count: int) -> List[int]:
    """Ask the user which chats to analyse and return their 1-based numbers.

    The prompt repeats until the input is valid. Accepted forms are a single
    number, a comma-separated list, ``start-end`` ranges, ``all``/``a`` for
    everything, and ``quit``/``q`` to abort.

    Returns:
        Sorted unique chat numbers, or an empty list when the user quits.
    """
    while True:
        print(f"\n请选择要分析的群聊 (1-{total_count}):")
        print("输入格式:")
        print(" 单个: 1")
        print(" 多个: 1,3,5")
        print(" 范围: 1-3")
        print(" 全部: all 或 a")
        print(" 退出: quit 或 q")

        user_input = input("请输入选择: ").strip().lower()

        if user_input in ["quit", "q"]:
            return []

        if user_input in ["all", "a"]:
            return list(range(1, total_count + 1))

        try:
            # A successful parse always yields at least one number, so no
            # separate "empty selection" check is needed.
            return _parse_chat_selection(user_input, total_count)
        except ValueError as e:
            print(f"错误: 输入格式无效 - {e}")
            print("请重新输入")
def display_chat_list(grouped_messages: Dict[str, List[Dict[str, Any]]]) -> None:
    """Print a numbered overview of the chats found for the user.

    For each chat this shows the group name, group id, message count and the
    dates of its first and last message. The numbering starts at 1 and
    follows the dict's iteration order.

    Args:
        grouped_messages: Messages grouped by chat id. Every list is expected
            to be non-empty; the first and last entries supply the displayed
            date range, which presumes chronological order (to be confirmed
            against the query that builds the lists).
    """
    print("\n找到以下群聊:")
    print("=" * 60)

    for i, (chat_id, messages) in enumerate(grouped_messages.items(), 1):
        first_msg = messages[0]
        # The message dicts always carry these keys, with None for chats that
        # have no group, so a .get() default would never apply; use `or`.
        group_name = first_msg.get("chat_info_group_name") or "私聊"
        group_id = first_msg.get("chat_info_group_id") or chat_id

        # Date range covered by this chat's messages.
        start_time = datetime.fromtimestamp(messages[0]["time"]).strftime("%Y-%m-%d")
        end_time = datetime.fromtimestamp(messages[-1]["time"]).strftime("%Y-%m-%d")

        print(f"{i:2d}. {group_name}")
        print(f" 群ID: {group_id}")
        print(f" 消息数: {len(messages)}")
        print(f" 时间范围: {start_time} ~ {end_time}")
        print("-" * 60)
def check_similarity(text1, text2, tfidf_threshold=0.5, seq_threshold=0.6):
    """Decide whether two texts are similar by either of two measures.

    Both ``tfidf_similarity`` and ``sequence_similarity`` are evaluated; the
    texts count as similar when at least one score is strictly greater than
    its threshold.
    """
    tfidf_hit = tfidf_similarity(text1, text2) > tfidf_threshold
    seq_hit = sequence_similarity(text1, text2) > seq_threshold
    return tfidf_hit or seq_hit
class MessageRetrievalScript: class MessageRetrievalScript:
def __init__(self): def __init__(self):
"""初始化脚本""" """初始化脚本"""
self.person_info_manager = PersonInfoManager()
self.bot_qq = str(global_config.bot.qq_account) self.bot_qq = str(global_config.bot.qq_account)
# 初始化LLM请求器和relationship_manager一样 # 初始化LLM请求器和relationship_manager一样
@@ -56,37 +346,16 @@ class MessageRetrievalScript:
request_type="relationship", request_type="relationship",
) )
def get_person_id(self, platform: str, user_id: str) -> str:
    """Compute the person_id for a ``platform`` / ``user_id`` pair.

    Delegates to ``PersonInfoManager.get_person_id``; no instance state is
    used.
    """
    person_id = PersonInfoManager.get_person_id(platform, user_id)
    return person_id
def get_time_range(self, time_period: str) -> Optional[float]:
    """Return the start timestamp of the requested look-back window.

    ``"all"`` yields ``None`` (no lower bound). ``"3months"``, ``"1month"``
    and ``"1week"`` yield the timestamp of now minus 90, 30 and 7 days.

    Raises:
        ValueError: For any other ``time_period``.
    """
    now = datetime.now()
    if time_period == "all":
        return None

    # Supported labels and their look-back in days.
    for label, days in (("3months", 90), ("1month", 30), ("1week", 7)):
        if time_period == label:
            return (now - timedelta(days=days)).timestamp()

    raise ValueError(f"不支持的时间段: {time_period}")
def retrieve_messages(self, user_qq: str, time_period: str) -> Dict[str, List[Dict[str, Any]]]: def retrieve_messages(self, user_qq: str, time_period: str) -> Dict[str, List[Dict[str, Any]]]:
"""检索消息""" """检索消息"""
print(f"开始检索用户 {user_qq} 的消息...") print(f"开始检索用户 {user_qq} 的消息...")
# 计算person_id # 计算person_id
person_id = self.get_person_id("qq", user_qq) person_id = get_person_id("qq", user_qq)
print(f"用户person_id: {person_id}") print(f"用户person_id: {person_id}")
# 获取时间范围 # 获取时间范围
start_timestamp = self.get_time_range(time_period) start_timestamp = get_time_range(time_period)
if start_timestamp: if start_timestamp:
print(f"时间范围: {datetime.fromtimestamp(start_timestamp).strftime('%Y-%m-%d %H:%M:%S')} 至今") print(f"时间范围: {datetime.fromtimestamp(start_timestamp).strftime('%Y-%m-%d %H:%M:%S')} 至今")
else: else:
@@ -97,8 +366,8 @@ class MessageRetrievalScript:
# 添加用户条件包含bot消息或目标用户消息 # 添加用户条件包含bot消息或目标用户消息
user_condition = ( user_condition = (
(Messages.user_id == self.bot_qq) | # bot的消息 (Messages.user_id == self.bot_qq) # bot的消息
(Messages.user_id == user_qq) # 目标用户的消息 | (Messages.user_id == user_qq) # 目标用户的消息
) )
query = query.where(user_condition) query = query.where(user_condition)
@@ -117,199 +386,27 @@ class MessageRetrievalScript:
grouped_messages = defaultdict(list) grouped_messages = defaultdict(list)
for msg in messages: for msg in messages:
msg_dict = { msg_dict = {
'message_id': msg.message_id, "message_id": msg.message_id,
'time': msg.time, "time": msg.time,
'datetime': datetime.fromtimestamp(msg.time).strftime('%Y-%m-%d %H:%M:%S'), "datetime": datetime.fromtimestamp(msg.time).strftime("%Y-%m-%d %H:%M:%S"),
'chat_id': msg.chat_id, "chat_id": msg.chat_id,
'user_id': msg.user_id, "user_id": msg.user_id,
'user_nickname': msg.user_nickname, "user_nickname": msg.user_nickname,
'user_platform': msg.user_platform, "user_platform": msg.user_platform,
'processed_plain_text': msg.processed_plain_text, "processed_plain_text": msg.processed_plain_text,
'display_message': msg.display_message, "display_message": msg.display_message,
'chat_info_group_id': msg.chat_info_group_id, "chat_info_group_id": msg.chat_info_group_id,
'chat_info_group_name': msg.chat_info_group_name, "chat_info_group_name": msg.chat_info_group_name,
'chat_info_platform': msg.chat_info_platform, "chat_info_platform": msg.chat_info_platform,
'user_cardname': msg.user_cardname, "user_cardname": msg.user_cardname,
'is_bot_message': msg.user_id == self.bot_qq "is_bot_message": msg.user_id == self.bot_qq,
} }
grouped_messages[msg.chat_id].append(msg_dict) grouped_messages[msg.chat_id].append(msg_dict)
print(f"消息分布在 {len(grouped_messages)} 个聊天中") print(f"消息分布在 {len(grouped_messages)} 个聊天中")
return dict(grouped_messages) return dict(grouped_messages)
def split_messages_by_count(self, messages: List[Dict[str, Any]], count: int = 50) -> List[List[Dict[str, Any]]]:
    """Split ``messages`` into consecutive chunks of at most ``count`` items.

    Args:
        messages: Messages in the order they should be chunked.
        count: Maximum number of messages per chunk; must be positive.

    Returns:
        A list of chunks in the original order. Every chunk except possibly
        the last holds exactly ``count`` messages; an empty input yields an
        empty list.

    Raises:
        ValueError: If ``count`` is not positive.
    """
    if count <= 0:
        # A zero step makes range() fail with an unrelated-looking message and a
        # negative step silently yields no chunks at all (dropping every
        # message), so reject both up front.
        raise ValueError(f"count must be a positive integer, got {count}")
    return [messages[start : start + count] for start in range(0, len(messages), count)]
async def build_name_mapping(self, messages: List[Dict[str, Any]], target_person_id: str, target_person_name: str) -> Dict[str, str]:
    """Build the person-name substitution table for one batch of messages.

    The bot's nickname and the target person's name map to themselves. Every
    other person name gets a placeholder: ``用户A`` .. ``用户Z`` first, then
    ``用户A2`` .. ``用户Z2`` and so on once the alphabet is used up. The
    original note states this is meant to match the logic in
    relationship_manager.

    Args:
        messages: Message dicts; ``chat_info_platform``, ``user_id``,
            ``user_nickname`` and ``user_cardname`` are read from each.
        target_person_id: Accepted for interface compatibility; not used here.
        target_person_name: Person name that must be kept verbatim.

    Returns:
        Mapping from person name to the label it should be replaced with.
    """
    mapping = {}
    letter = "A"
    cycle = 1

    for message in messages:
        platform = message.get("chat_info_platform")
        user_id = message.get("user_id")

        # Make sure a person record exists before its name is looked up.
        await person_info_manager.get_or_create_person(
            platform=platform,
            user_id=user_id,
            nickname=message.get("user_nickname"),
            user_cardname=message.get("user_cardname"),
        )
        person_id = person_info_manager.get_person_id(platform, user_id)
        person_name = await person_info_manager.get_value(person_id, "person_name")

        # The bot keeps its own nickname.
        if user_id == global_config.bot.qq_account:
            mapping[f"{global_config.bot.nickname}"] = f"{global_config.bot.nickname}"
            continue

        # The target person keeps their real name.
        if person_name == target_person_name:
            mapping[person_name] = f"{target_person_name}"
            continue

        # Anyone already labelled keeps the label assigned earlier.
        if person_name in mapping:
            continue

        # Past 'Z': restart at 'A' and move on to the next numbered cycle.
        if letter > "Z":
            letter = "A"
            cycle += 1
        suffix = cycle if cycle > 1 else ""
        mapping[person_name] = f"用户{letter}{suffix}"
        letter = chr(ord(letter) + 1)

    return mapping
def build_focus_readable_messages(self, messages: List[Dict[str, Any]], target_person_id: str = None) -> str:
    """Render only the messages surrounding the target person's messages.

    A message is kept when it lies within five positions before or after a
    message whose person id equals ``target_person_id``. Kept messages that
    are adjacent in ``messages`` form one group; each group is rendered with
    ``build_readable_messages`` and groups are separated by a ``...`` line.

    Args:
        messages: Message dicts in display order; ``user_id`` and
            ``chat_info_platform`` are read to resolve each author.
        target_person_id: Person id whose messages anchor the kept windows.

    Returns:
        The rendered text, or ``""`` when no message belongs to the target.
    """
    # Positions of the messages written by the target person.
    target_indices = [
        i
        for i, msg in enumerate(messages)
        if person_info_manager.get_person_id(msg.get("chat_info_platform"), msg.get("user_id")) == target_person_id
    ]
    if not target_indices:
        return ""

    # Keep a window of 5 messages on each side of every target message. A set
    # gives constant-time membership tests in the grouping loop below.
    total = len(messages)
    keep_indices = set()
    for idx in target_indices:
        keep_indices.update(range(max(0, idx - 5), min(total, idx + 6)))

    # Cut the kept messages into runs of consecutive positions.
    message_groups = []
    current_group = []
    for i, msg in enumerate(messages):
        if i in keep_indices:
            current_group.append(msg)
        elif current_group:
            # A dropped message closes the run collected so far.
            message_groups.append(current_group)
            current_group = []
    if current_group:
        message_groups.append(current_group)

    # Render each run, marking the omitted gap between runs with "...".
    result = []
    for group_no, group in enumerate(message_groups):
        if group_no > 0:
            result.append("...")
        result.append(
            build_readable_messages(
                messages=group,
                replace_bot_name=True,
                timestamp_mode="normal_no_YMD",
                truncate=False,
            )
        )

    return "\n".join(result)
# 添加相似度检查方法和relationship_manager一致 # 添加相似度检查方法和relationship_manager一致
def tfidf_similarity(self, s1, s2):
    """Return the cosine similarity of the TF-IDF vectors of two texts.

    Each input is first turned into a string (list items are stringified and
    joined with spaces), segmented with jieba, and the two segmented texts
    are vectorised together as a two-document corpus.

    Args:
        s1: First text, or a list of items to be joined into one text.
        s2: Second text, or a list of items to be joined into one text.

    Returns:
        The similarity entry for (s1, s2) from the cosine-similarity matrix,
        or ``0.0`` when vectorisation raises ``ValueError``.
    """

    def _as_text(value):
        # Lists are flattened into a single space-separated string.
        if isinstance(value, list):
            value = " ".join(str(item) for item in value)
        return str(value)

    # Segment with jieba and re-join with spaces so the vectoriser sees
    # whitespace-delimited tokens.
    corpus = [" ".join(jieba.cut(_as_text(text))) for text in (s1, s2)]

    try:
        tfidf_matrix = TfidfVectorizer().fit_transform(corpus)
    except ValueError:
        # Per the original note: this may be raised when a sentence is empty
        # or consists only of stop words.
        return 0.0

    # Row 0 / column 1 of the pairwise matrix is the s1-vs-s2 similarity.
    return cosine_similarity(tfidf_matrix)[0, 1]
def sequence_similarity(self, s1, s2):
    """Return the ``difflib.SequenceMatcher`` ratio of two sequences.

    Args:
        s1: First sequence (typically a string).
        s2: Second sequence (typically a string).

    Returns:
        A float in ``[0, 1]``; ``1.0`` means the sequences are identical.
    """
    matcher = SequenceMatcher(isjunk=None, a=s1, b=s2)
    return matcher.ratio()
def check_similarity(self, text1, text2, tfidf_threshold=0.5, seq_threshold=0.6):
    """Decide whether two texts are similar by either of two measures.

    Args:
        text1: First text.
        text2: Second text.
        tfidf_threshold: The TF-IDF similarity must be strictly above this.
        seq_threshold: The sequence similarity must be strictly above this.

    Returns:
        Truthy when at least one of the two scores exceeds its threshold.
    """
    # Both scores are always computed; exceeding either threshold is enough.
    tfidf_hit = self.tfidf_similarity(text1, text2) > tfidf_threshold
    seq_hit = self.sequence_similarity(text1, text2) > seq_threshold
    return tfidf_hit or seq_hit
def calculate_time_weight(self, point_time: str, current_time: str) -> float:
    """Compute the time-based weight factor for a memory point.

    Both arguments are timestamps formatted as ``%Y-%m-%d %H:%M:%S``. The
    weight is a piecewise-linear function of the elapsed hours:

    * up to 1 hour: ``1.0``
    * 1 to 24 hours: falls linearly from ``1.0`` to ``0.7``
    * 24 hours to 7 days: rises linearly from ``0.7`` to ``0.95``
    * beyond 7 days: falls linearly from ``0.95``, never below ``0.1``

    Args:
        point_time: When the point was recorded.
        current_time: The moment the weight is evaluated for.

    Returns:
        The weight factor, or ``0.5`` if anything fails (the failure is
        logged).
    """
    time_format = "%Y-%m-%d %H:%M:%S"
    try:
        recorded_at = datetime.strptime(point_time, time_format)
        evaluated_at = datetime.strptime(current_time, time_format)
        hours_diff = (evaluated_at - recorded_at).total_seconds() / 3600

        # Within the first hour.
        if hours_diff <= 1:
            return 1.0

        # 1-24 hours: drop from 1.0 to 0.7.
        if hours_diff <= 24:
            return 1.0 - (hours_diff - 1) * (0.3 / 23)

        # 24 hours - 7 days: climb back from 0.7 to 0.95.
        if hours_diff <= 24 * 7:
            return 0.7 + (hours_diff - 24) * (0.25 / (24 * 6))

        # Past 7 days: decay from 0.95 with a floor of 0.1.
        days_diff = hours_diff / 24 - 7
        return max(0.1, 0.95 - days_diff * (0.85 / 23))
    except Exception as e:
        logger.error(f"计算时间权重失败: {e}")
        return 0.5  # Neutral weight when the computation fails.
async def update_person_impression_from_segment(self, person_id: str, readable_messages: str, segment_time: float): async def update_person_impression_from_segment(self, person_id: str, readable_messages: str, segment_time: float):
"""从消息段落更新用户印象使用和relationship_manager相同的流程""" """从消息段落更新用户印象使用和relationship_manager相同的流程"""
@@ -412,7 +509,7 @@ class MessageRetrievalScript:
# 在现有points中查找相似的点 # 在现有points中查找相似的点
for i, existing_point in enumerate(current_points): for i, existing_point in enumerate(current_points):
# 使用组合的相似度检查方法 # 使用组合的相似度检查方法
if self.check_similarity(new_point[0], existing_point[0]): if check_similarity(new_point[0], existing_point[0]):
similar_points.append(existing_point) similar_points.append(existing_point)
similar_indices.append(i) similar_indices.append(i)
@@ -460,7 +557,7 @@ class MessageRetrievalScript:
# 计算每个点的最终权重(原始权重 * 时间权重) # 计算每个点的最终权重(原始权重 * 时间权重)
weighted_points = [] weighted_points = []
for point in current_points: for point in current_points:
time_weight = self.calculate_time_weight(point[2], current_time_str) time_weight = calculate_time_weight(point[2], current_time_str)
final_weight = point[1] * time_weight final_weight = point[1] * time_weight
weighted_points.append((point, final_weight)) weighted_points.append((point, final_weight))
@@ -505,10 +602,9 @@ class MessageRetrievalScript:
forgotten_points.sort(key=lambda x: x[2]) forgotten_points.sort(key=lambda x: x[2])
# 构建points文本 # 构建points文本
points_text = "\n".join([ points_text = "\n".join(
f"时间:{point[2]}\n权重:{point[1]}\n内容:{point[0]}" [f"时间:{point[2]}\n权重:{point[1]}\n内容:{point[0]}" for point in forgotten_points]
for point in forgotten_points )
])
impression = await person_info_manager.get_value(person_id, "impression") or "" impression = await person_info_manager.get_value(person_id, "impression") or ""
@@ -543,27 +639,30 @@ class MessageRetrievalScript:
forgotten_points = [] forgotten_points = []
# 更新数据库 # 更新数据库
await person_info_manager.update_one_field(person_id, "forgotten_points", json.dumps(forgotten_points, ensure_ascii=False, indent=None)) await person_info_manager.update_one_field(
person_id, "forgotten_points", json.dumps(forgotten_points, ensure_ascii=False, indent=None)
)
# 更新数据库 # 更新数据库
await person_info_manager.update_one_field(person_id, "points", json.dumps(current_points, ensure_ascii=False, indent=None)) await person_info_manager.update_one_field(
person_id, "points", json.dumps(current_points, ensure_ascii=False, indent=None)
)
know_times = await person_info_manager.get_value(person_id, "know_times") or 0 know_times = await person_info_manager.get_value(person_id, "know_times") or 0
await person_info_manager.update_one_field(person_id, "know_times", know_times + 1) await person_info_manager.update_one_field(person_id, "know_times", know_times + 1)
await person_info_manager.update_one_field(person_id, "last_know", segment_time) await person_info_manager.update_one_field(person_id, "last_know", segment_time)
logger.info(f"印象更新完成 for {person_name},新增 {len(points_list)} 个记忆点") logger.info(f"印象更新完成 for {person_name},新增 {len(points_list)} 个记忆点")
async def process_segments_and_update_impression(self, user_qq: str, grouped_messages: Dict[str, List[Dict[str, Any]]]): async def process_segments_and_update_impression(
self, user_qq: str, grouped_messages: Dict[str, List[Dict[str, Any]]]
):
"""处理分段消息并更新用户印象到数据库""" """处理分段消息并更新用户印象到数据库"""
# 获取目标用户信息 # 获取目标用户信息
target_person_id = self.get_person_id("qq", user_qq) target_person_id = get_person_id("qq", user_qq)
target_person_name = await person_info_manager.get_value(target_person_id, "person_name") target_person_name = await person_info_manager.get_value(target_person_id, "person_name")
target_nickname = await person_info_manager.get_value(target_person_id, "nickname")
if not target_person_name: if not target_person_name:
target_person_name = f"用户{user_qq}" target_person_name = f"用户{user_qq}"
if not target_nickname:
target_nickname = f"用户{user_qq}"
print(f"\n开始分析用户 {target_person_name} (QQ: {user_qq}) 的消息...") print(f"\n开始分析用户 {target_person_name} (QQ: {user_qq}) 的消息...")
@@ -575,52 +674,53 @@ class MessageRetrievalScript:
# 为每个chat_id处理消息收集所有分段 # 为每个chat_id处理消息收集所有分段
for chat_id, messages in grouped_messages.items(): for chat_id, messages in grouped_messages.items():
first_msg = messages[0] first_msg = messages[0]
group_name = first_msg.get('chat_info_group_name', '私聊') group_name = first_msg.get("chat_info_group_name", "私聊")
print(f"准备聊天: {group_name} (共{len(messages)}条消息)") print(f"准备聊天: {group_name} (共{len(messages)}条消息)")
# 将消息按50条分段 # 将消息按50条分段
message_chunks = self.split_messages_by_count(messages, 50) message_chunks = split_messages_by_count(messages, 50)
for i, chunk in enumerate(message_chunks): for i, chunk in enumerate(message_chunks):
# 将分段信息添加到列表中,包含分段时间用于排序 # 将分段信息添加到列表中,包含分段时间用于排序
segment_time = chunk[-1]['time'] segment_time = chunk[-1]["time"]
all_segments.append({ all_segments.append(
'chunk': chunk, {
'chat_id': chat_id, "chunk": chunk,
'group_name': group_name, "chat_id": chat_id,
'segment_index': i + 1, "group_name": group_name,
'total_segments': len(message_chunks), "segment_index": i + 1,
'segment_time': segment_time "total_segments": len(message_chunks),
}) "segment_time": segment_time,
}
)
# 按时间排序所有分段 # 按时间排序所有分段
all_segments.sort(key=lambda x: x['segment_time']) all_segments.sort(key=lambda x: x["segment_time"])
print(f"\n按时间顺序处理 {len(all_segments)} 个分段:") print(f"\n按时间顺序处理 {len(all_segments)} 个分段:")
# 按时间顺序处理所有分段 # 按时间顺序处理所有分段
for segment_idx, segment_info in enumerate(all_segments, 1): for segment_idx, segment_info in enumerate(all_segments, 1):
chunk = segment_info['chunk'] chunk = segment_info["chunk"]
group_name = segment_info['group_name'] group_name = segment_info["group_name"]
segment_index = segment_info['segment_index'] segment_index = segment_info["segment_index"]
total_segments = segment_info['total_segments'] total_segments = segment_info["total_segments"]
segment_time = segment_info['segment_time'] segment_time = segment_info["segment_time"]
segment_time_str = datetime.fromtimestamp(segment_time).strftime('%Y-%m-%d %H:%M:%S') segment_time_str = datetime.fromtimestamp(segment_time).strftime("%Y-%m-%d %H:%M:%S")
print(f" [{segment_idx}/{len(all_segments)}] {group_name}{segment_index}/{total_segments}段 ({segment_time_str}) (共{len(chunk)}条)") print(
f" [{segment_idx}/{len(all_segments)}] {group_name}{segment_index}/{total_segments}段 ({segment_time_str}) (共{len(chunk)}条)"
# 构建名称映射
name_mapping = await self.build_name_mapping(chunk, target_person_id, target_person_name)
# 构建可读消息
readable_messages = self.build_focus_readable_messages(
messages=chunk,
target_person_id=target_person_id
) )
# 构建名称映射
name_mapping = await build_name_mapping(chunk, target_person_name)
# 构建可读消息
readable_messages = build_focus_readable_messages(messages=chunk, target_person_id=target_person_id)
if not readable_messages: if not readable_messages:
print(f" 跳过:该段落没有目标用户的消息") print(" 跳过:该段落没有目标用户的消息")
continue continue
# 应用名称映射 # 应用名称映射
@@ -633,7 +733,7 @@ class MessageRetrievalScript:
total_segments_processed += 1 total_segments_processed += 1
except Exception as e: except Exception as e:
logger.error(f"处理段落时出错: {e}") logger.error(f"处理段落时出错: {e}")
print(f" 错误:处理该段落时出现异常") print(" 错误:处理该段落时出现异常")
# 获取最终统计 # 获取最终统计
final_points = await person_info_manager.get_value(target_person_id, "points") or [] final_points = await person_info_manager.get_value(target_person_id, "points") or []
@@ -645,7 +745,7 @@ class MessageRetrievalScript:
final_impression = await person_info_manager.get_value(target_person_id, "impression") or "" final_impression = await person_info_manager.get_value(target_person_id, "impression") or ""
print(f"\n=== 处理完成 ===") print("\n=== 处理完成 ===")
print(f"目标用户: {target_person_name} (QQ: {user_qq})") print(f"目标用户: {target_person_name} (QQ: {user_qq})")
print(f"处理段落数: {total_segments_processed}") print(f"处理段落数: {total_segments_processed}")
print(f"当前记忆点数: {len(final_points)}") print(f"当前记忆点数: {len(final_points)}")
@@ -654,95 +754,6 @@ class MessageRetrievalScript:
if final_points: if final_points:
print(f"最新记忆点: {final_points[-1][0][:50]}...") print(f"最新记忆点: {final_points[-1][0][:50]}...")
def display_chat_list(self, grouped_messages: Dict[str, List[Dict[str, Any]]]) -> None:
    """Print a numbered overview of the chats that were found.

    For every chat the group name, group id, message count and the dates of
    the first and last message in its list are printed.

    Args:
        grouped_messages: Mapping of chat id to that chat's non-empty list of
            message dicts. Each message needs a ``time`` timestamp;
            ``chat_info_group_name`` and ``chat_info_group_id`` are read from
            the first message.
    """
    print("\n找到以下群聊:")
    print("=" * 60)

    for i, (chat_id, messages) in enumerate(grouped_messages.items(), 1):
        first_msg = messages[0]
        # The message dicts always carry these keys, possibly with a None
        # value, so a ``.get(key, default)`` default would never be applied
        # and "None" would be printed. Fall back on any empty value instead.
        group_name = first_msg.get("chat_info_group_name") or "私聊"
        group_id = first_msg.get("chat_info_group_id") or chat_id

        # Date range covered, taken from the first and last message.
        start_time = datetime.fromtimestamp(messages[0]["time"]).strftime("%Y-%m-%d")
        end_time = datetime.fromtimestamp(messages[-1]["time"]).strftime("%Y-%m-%d")

        print(f"{i:2d}. {group_name}")
        print(f" 群ID: {group_id}")
        print(f" 消息数: {len(messages)}")
        print(f" 时间范围: {start_time} ~ {end_time}")
        print("-" * 60)
def get_user_selection(self, total_count: int) -> List[int]:
    """Prompt until the user enters a valid selection of chat numbers.

    Accepted input (case-insensitive, surrounding whitespace ignored):
    single numbers, comma-separated numbers, ``a-b`` ranges, ``all``/``a``
    for everything, and ``quit``/``q`` to cancel.

    Args:
        total_count: Number of chats on offer; valid numbers are
            ``1..total_count``.

    Returns:
        The selected numbers, de-duplicated and sorted ascending, or an empty
        list when the user quits.
    """
    while True:
        print(f"\n请选择要分析的群聊 (1-{total_count}):")
        print("输入格式:")
        print(" 单个: 1")
        print(" 多个: 1,3,5")
        print(" 范围: 1-3")
        print(" 全部: all 或 a")
        print(" 退出: quit 或 q")

        user_input = input("请输入选择: ").strip().lower()

        if user_input in ["quit", "q"]:
            return []
        if user_input in ["all", "a"]:
            return list(range(1, total_count + 1))

        try:
            chosen = set()
            # Each comma-separated token is either a range or a single number.
            for token in user_input.split(","):
                token = token.strip()
                if "-" in token:
                    # Range such as "1-3"; both ends are inclusive.
                    low_text, high_text = token.split("-")
                    low = int(low_text.strip())
                    high = int(high_text.strip())
                    if not (1 <= low <= total_count and 1 <= high <= total_count and low <= high):
                        raise ValueError("范围超出有效范围")
                    chosen.update(range(low, high + 1))
                else:
                    number = int(token)
                    if not 1 <= number <= total_count:
                        raise ValueError("数字超出有效范围")
                    chosen.add(number)

            # Duplicates are already collapsed by the set; return in order.
            ordered = sorted(chosen)
            if ordered:
                return ordered
            print("错误: 请输入有效的选择")
        except ValueError as e:
            # Covers non-numeric text, malformed ranges and out-of-range values.
            print(f"错误: 输入格式无效 - {e}")
            print("请重新输入")
def filter_selected_chats(self, grouped_messages: Dict[str, List[Dict[str, Any]]], selected_indices: List[int]) -> Dict[str, List[Dict[str, Any]]]:
    """Keep only the chats whose menu numbers were selected.

    Args:
        grouped_messages: Mapping of chat id to its messages; its iteration
            order defines the menu numbering.
        selected_indices: 1-based menu numbers to keep.

    Returns:
        A new mapping holding the selected chats, in the order the numbers
        appear in ``selected_indices``.
    """
    chat_items = list(grouped_messages.items())
    # Menu numbers are 1-based, list positions are 0-based.
    return dict(chat_items[number - 1] for number in selected_indices)
async def run(self): async def run(self):
"""运行脚本""" """运行脚本"""
print("=== 消息检索分析脚本 ===") print("=== 消息检索分析脚本 ===")
@@ -760,12 +771,7 @@ class MessageRetrievalScript:
print("4. 最近1周 (1week)") print("4. 最近1周 (1week)")
choice = input("请选择时间段 (1-4): ").strip() choice = input("请选择时间段 (1-4): ").strip()
time_periods = { time_periods = {"1": "all", "2": "3months", "3": "1month", "4": "1week"}
"1": "all",
"2": "3months",
"3": "1month",
"4": "1week"
}
if choice not in time_periods: if choice not in time_periods:
print("选择无效") print("选择无效")
@@ -792,28 +798,28 @@ class MessageRetrievalScript:
return return
# 显示群聊列表 # 显示群聊列表
self.display_chat_list(grouped_messages) display_chat_list(grouped_messages)
# 获取用户选择 # 获取用户选择
selected_indices = self.get_user_selection(len(grouped_messages)) selected_indices = get_user_selection(len(grouped_messages))
if not selected_indices: if not selected_indices:
print("已取消操作") print("已取消操作")
return return
# 过滤选中的群聊 # 过滤选中的群聊
selected_chats = self.filter_selected_chats(grouped_messages, selected_indices) selected_chats = filter_selected_chats(grouped_messages, selected_indices)
# 显示选中的群聊 # 显示选中的群聊
print(f"\n已选择 {len(selected_chats)} 个群聊进行分析:") print(f"\n已选择 {len(selected_chats)} 个群聊进行分析:")
for i, (chat_id, messages) in enumerate(selected_chats.items(), 1): for i, (_, messages) in enumerate(selected_chats.items(), 1):
first_msg = messages[0] first_msg = messages[0]
group_name = first_msg.get('chat_info_group_name', '私聊') group_name = first_msg.get("chat_info_group_name", "私聊")
print(f" {i}. {group_name} ({len(messages)}条消息)") print(f" {i}. {group_name} ({len(messages)}条消息)")
# 确认处理 # 确认处理
confirm = input(f"\n确认分析这些群聊吗? (y/n): ").strip().lower() confirm = input("\n确认分析这些群聊吗? (y/n): ").strip().lower()
if confirm != 'y': if confirm != "y":
print("已取消操作") print("已取消操作")
return return
@@ -823,16 +829,18 @@ class MessageRetrievalScript:
except Exception as e: except Exception as e:
print(f"处理过程中出现错误: {e}") print(f"处理过程中出现错误: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
finally: finally:
db.close() db.close()
print("数据库连接已关闭") print("数据库连接已关闭")
def main(): def main():
"""主函数""" """主函数"""
script = MessageRetrievalScript() script = MessageRetrievalScript()
asyncio.run(script.run()) asyncio.run(script.run())
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -1,7 +1,7 @@
from typing import Optional, Tuple, Dict from typing import Optional, Tuple, Dict
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
from src.chat.message_receive.chat_stream import chat_manager from src.chat.message_receive.chat_stream import chat_manager
from src.person_info.person_info import person_info_manager from src.person_info.person_info import person_info_manager, PersonInfoManager
logger = get_logger("heartflow_utils") logger = get_logger("heartflow_utils")
@@ -47,7 +47,7 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
# Try to fetch person info # Try to fetch person info
try: try:
# Assume get_person_id is sync (as per original code), keep using to_thread # Assume get_person_id is sync (as per original code), keep using to_thread
person_id = person_info_manager.get_person_id(platform, user_id) person_id = PersonInfoManager.get_person_id(platform, user_id)
person_name = None person_name = None
if person_id: if person_id:
# get_value is async, so await it directly # get_value is async, so await it directly

View File

@@ -6,7 +6,7 @@ from src.chat.message_receive.message import MessageThinking
from src.chat.normal_chat.normal_prompt import prompt_builder from src.chat.normal_chat.normal_prompt import prompt_builder
from src.chat.utils.timer_calculator import Timer from src.chat.utils.timer_calculator import Timer
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
from src.person_info.person_info import person_info_manager from src.person_info.person_info import person_info_manager, PersonInfoManager
from src.chat.utils.utils import process_llm_response from src.chat.utils.utils import process_llm_response
@@ -66,7 +66,7 @@ class NormalChatGenerator:
enable_planner: bool = False, enable_planner: bool = False,
available_actions=None, available_actions=None,
): ):
person_id = person_info_manager.get_person_id( person_id = PersonInfoManager.get_person_id(
message.chat_stream.user_info.platform, message.chat_stream.user_info.user_id message.chat_stream.user_info.platform, message.chat_stream.user_info.user_id
) )

View File

@@ -96,7 +96,7 @@ class BaseWillingManager(ABC):
self.logger: LoguruLogger = logger self.logger: LoguruLogger = logger
def setup(self, message: MessageRecv, chat: ChatStream, is_mentioned_bot: bool, interested_rate: float): def setup(self, message: MessageRecv, chat: ChatStream, is_mentioned_bot: bool, interested_rate: float):
person_id = person_info_manager.get_person_id(chat.platform, chat.user_info.user_id) person_id = PersonInfoManager.get_person_id(chat.platform, chat.user_info.user_id)
self.ongoing_messages[message.message_info.message_id] = WillingInfo( self.ongoing_messages[message.message_info.message_id] = WillingInfo(
message=message, message=message,
chat=chat, chat=chat,

View File

@@ -4,7 +4,7 @@ import time # 导入 time 模块以获取当前时间
import random import random
import re import re
from src.common.message_repository import find_messages, count_messages from src.common.message_repository import find_messages, count_messages
from src.person_info.person_info import person_info_manager from src.person_info.person_info import person_info_manager, PersonInfoManager
from src.chat.utils.utils import translate_timestamp_to_human_readable from src.chat.utils.utils import translate_timestamp_to_human_readable
from rich.traceback import install from rich.traceback import install
from src.common.database.database_model import ActionRecords from src.common.database.database_model import ActionRecords
@@ -219,7 +219,7 @@ def _build_readable_messages_internal(
if not all([platform, user_id, timestamp is not None]): if not all([platform, user_id, timestamp is not None]):
continue continue
person_id = person_info_manager.get_person_id(platform, user_id) person_id = PersonInfoManager.get_person_id(platform, user_id)
# 根据 replace_bot_name 参数决定是否替换机器人名称 # 根据 replace_bot_name 参数决定是否替换机器人名称
if replace_bot_name and user_id == global_config.bot.qq_account: if replace_bot_name and user_id == global_config.bot.qq_account:
person_name = f"{global_config.bot.nickname}(你)" person_name = f"{global_config.bot.nickname}(你)"
@@ -241,7 +241,7 @@ def _build_readable_messages_internal(
if match: if match:
aaa = match.group(1) aaa = match.group(1)
bbb = match.group(2) bbb = match.group(2)
reply_person_id = person_info_manager.get_person_id(platform, bbb) reply_person_id = PersonInfoManager.get_person_id(platform, bbb)
reply_person_name = person_info_manager.get_value_sync(reply_person_id, "person_name") reply_person_name = person_info_manager.get_value_sync(reply_person_id, "person_name")
if not reply_person_name: if not reply_person_name:
reply_person_name = aaa reply_person_name = aaa
@@ -258,7 +258,7 @@ def _build_readable_messages_internal(
new_content += content[last_end : m.start()] new_content += content[last_end : m.start()]
aaa = m.group(1) aaa = m.group(1)
bbb = m.group(2) bbb = m.group(2)
at_person_id = person_info_manager.get_person_id(platform, bbb) at_person_id = PersonInfoManager.get_person_id(platform, bbb)
at_person_name = person_info_manager.get_value_sync(at_person_id, "person_name") at_person_name = person_info_manager.get_value_sync(at_person_id, "person_name")
if not at_person_name: if not at_person_name:
at_person_name = aaa at_person_name = aaa
@@ -572,7 +572,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
# print("SELF11111111111111") # print("SELF11111111111111")
return "SELF" return "SELF"
try: try:
person_id = person_info_manager.get_person_id(platform, user_id) person_id = PersonInfoManager.get_person_id(platform, user_id)
except Exception as _e: except Exception as _e:
person_id = None person_id = None
if not person_id: if not person_id:
@@ -673,7 +673,7 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
if not all([platform, user_id]) or user_id == global_config.bot.qq_account: if not all([platform, user_id]) or user_id == global_config.bot.qq_account:
continue continue
person_id = person_info_manager.get_person_id(platform, user_id) person_id = PersonInfoManager.get_person_id(platform, user_id)
# 只有当获取到有效 person_id 时才添加 # 只有当获取到有效 person_id 时才添加
if person_id: if person_id:

View File

@@ -1,9 +1,9 @@
from src.manager.async_task_manager import AsyncTask from src.manager.async_task_manager import AsyncTask
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
from src.person_info.person_info import PersonInfoManager
from src.person_info.relationship_manager import relationship_manager from src.person_info.relationship_manager import relationship_manager
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp
from src.config.config import global_config from src.config.config import global_config
from src.person_info.person_info import person_info_manager
from src.chat.message_receive.chat_stream import chat_manager from src.chat.message_receive.chat_stream import chat_manager
import time import time
import random import random
@@ -95,7 +95,7 @@ class ImpressionUpdateTask(AsyncTask):
if msg["user_nickname"] == global_config.bot.nickname: if msg["user_nickname"] == global_config.bot.nickname:
continue continue
person_id = person_info_manager.get_person_id(msg["chat_info_platform"], msg["user_id"]) person_id = PersonInfoManager.get_person_id(msg["chat_info_platform"], msg["user_id"])
if not person_id: if not person_id:
logger.warning(f"未找到用户 {msg['user_nickname']} 的person_id") logger.warning(f"未找到用户 {msg['user_nickname']} 的person_id")
continue continue

View File

@@ -1,6 +1,6 @@
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
import math import math
from src.person_info.person_info import person_info_manager from src.person_info.person_info import person_info_manager, PersonInfoManager
import time import time
import random import random
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
@@ -91,7 +91,7 @@ class RelationshipManager:
@staticmethod @staticmethod
async def first_knowing_some_one(platform: str, user_id: str, user_nickname: str, user_cardname: str): async def first_knowing_some_one(platform: str, user_id: str, user_nickname: str, user_cardname: str):
"""判断是否认识某人""" """判断是否认识某人"""
person_id = person_info_manager.get_person_id(platform, user_id) person_id = PersonInfoManager.get_person_id(platform, user_id)
# 生成唯一的 person_name # 生成唯一的 person_name
unique_nickname = await person_info_manager._generate_unique_person_name(user_nickname) unique_nickname = await person_info_manager._generate_unique_person_name(user_nickname)
data = { data = {
@@ -116,7 +116,7 @@ class RelationshipManager:
if is_id: if is_id:
person_id = person person_id = person
else: else:
person_id = person_info_manager.get_person_id(person[0], person[1]) person_id = PersonInfoManager.get_person_id(person[0], person[1])
person_name = await person_info_manager.get_value(person_id, "person_name") person_name = await person_info_manager.get_value(person_id, "person_name")
if not person_name or person_name == "none": if not person_name or person_name == "none":
@@ -198,7 +198,7 @@ class RelationshipManager:
) )
replace_user_id = msg.get("user_id") replace_user_id = msg.get("user_id")
replace_platform = msg.get("chat_info_platform") replace_platform = msg.get("chat_info_platform")
replace_person_id = person_info_manager.get_person_id(replace_platform, replace_user_id) replace_person_id = PersonInfoManager.get_person_id(replace_platform, replace_user_id)
replace_person_name = await person_info_manager.get_value(replace_person_id, "person_name") replace_person_name = await person_info_manager.get_value(replace_person_id, "person_name")
# 跳过机器人自己 # 跳过机器人自己
@@ -467,7 +467,7 @@ class RelationshipManager:
for i, msg in enumerate(messages): for i, msg in enumerate(messages):
user_id = msg.get("user_id") user_id = msg.get("user_id")
platform = msg.get("chat_info_platform") platform = msg.get("chat_info_platform")
person_id = person_info_manager.get_person_id(platform, user_id) person_id = PersonInfoManager.get_person_id(platform, user_id)
if person_id == target_person_id: if person_id == target_person_id:
target_indices.append(i) target_indices.append(i)