ruff

requirements.txt: BIN (binary file not shown)
@@ -9,13 +9,15 @@ import sqlite3
 import re
 from datetime import datetime

+
 def clean_group_name(name: str) -> str:
     """清理群组名称,只保留中文和英文字符"""
-    cleaned = re.sub(r'[^\u4e00-\u9fa5a-zA-Z]', '', name)
+    cleaned = re.sub(r"[^\u4e00-\u9fa5a-zA-Z]", "", name)
     if not cleaned:
         cleaned = datetime.now().strftime("%Y%m%d")
     return cleaned

+
 def get_group_name(stream_id: str) -> str:
     """从数据库中获取群组名称"""
     conn = sqlite3.connect("data/maibot.db")
@@ -43,6 +45,7 @@ def get_group_name(stream_id: str) -> str:
         return clean_group_name(f"{platform}{stream_id[:8]}")
     return stream_id

+
 def format_timestamp(timestamp: float) -> str:
     """将时间戳转换为可读的时间格式"""
     if not timestamp:
@@ -50,9 +53,11 @@ def format_timestamp(timestamp: float) -> str:
     try:
         dt = datetime.fromtimestamp(timestamp)
         return dt.strftime("%Y-%m-%d %H:%M:%S")
-    except:
+    except Exception as e:
+        print(f"时间戳格式化错误: {e}")
         return "未知"

+
 def load_expressions(chat_id: str) -> List[Dict]:
     """加载指定群聊的表达方式"""
     style_file = os.path.join("data", "expression", "learnt_style", str(chat_id), "expressions.json")
@@ -65,14 +70,15 @@ def load_expressions(chat_id: str) -> List[Dict]:

     return style_exprs

+
 def find_similar_expressions(expressions: List[Dict], top_k: int = 5) -> Dict[str, List[Tuple[str, float]]]:
     """找出每个表达方式最相似的top_k个表达方式"""
     if not expressions:
         return {}

     # 分别准备情景和表达方式的文本数据
-    situations = [expr['situation'] for expr in expressions]
-    styles = [expr['style'] for expr in expressions]
+    situations = [expr["situation"] for expr in expressions]
+    styles = [expr["style"] for expr in expressions]

     # 使用TF-IDF向量化
     vectorizer = TfidfVectorizer()
@@ -85,14 +91,14 @@ def find_similar_expressions(expressions: List[Dict], top_k: int = 5) -> Dict[st

     # 对每个表达方式找出最相似的top_k个
     similar_expressions = {}
-    for i, expr in enumerate(expressions):
+    for i, _ in enumerate(expressions):
         # 获取相似度分数
         situation_scores = situation_similarity[i]
         style_scores = style_similarity[i]

         # 获取top_k的索引(排除自己)
-        situation_indices = np.argsort(situation_scores)[::-1][1:top_k+1]
-        style_indices = np.argsort(style_scores)[::-1][1:top_k+1]
+        situation_indices = np.argsort(situation_scores)[::-1][1 : top_k + 1]
+        style_indices = np.argsort(style_scores)[::-1][1 : top_k + 1]

         similar_situations = []
         similar_styles = []
@@ -100,29 +106,31 @@ def find_similar_expressions(expressions: List[Dict], top_k: int = 5) -> Dict[st
         # 处理相似情景
         for idx in situation_indices:
             if situation_scores[idx] > 0:  # 只保留有相似度的
-                similar_situations.append((
-                    expressions[idx]['situation'],
-                    expressions[idx]['style'],  # 添加对应的原始表达
-                    situation_scores[idx]
-                ))
+                similar_situations.append(
+                    (
+                        expressions[idx]["situation"],
+                        expressions[idx]["style"],  # 添加对应的原始表达
+                        situation_scores[idx],
+                    )
+                )

         # 处理相似表达
         for idx in style_indices:
             if style_scores[idx] > 0:  # 只保留有相似度的
-                similar_styles.append((
-                    expressions[idx]['style'],
-                    expressions[idx]['situation'],  # 添加对应的原始情景
-                    style_scores[idx]
-                ))
+                similar_styles.append(
+                    (
+                        expressions[idx]["style"],
+                        expressions[idx]["situation"],  # 添加对应的原始情景
+                        style_scores[idx],
+                    )
+                )

         if similar_situations or similar_styles:
-            similar_expressions[i] = {
-                'situations': similar_situations,
-                'styles': similar_styles
-            }
+            similar_expressions[i] = {"situations": similar_situations, "styles": similar_styles}

     return similar_expressions

+
 def main():
     # 获取所有群聊ID
     style_dirs = glob.glob(os.path.join("data", "expression", "learnt_style", "*"))
@@ -143,7 +151,7 @@ def main():
             if choice == 0:
                 break
             if 1 <= choice <= len(chat_ids):
-                chat_id = chat_ids[choice-1]
+                chat_id = chat_ids[choice - 1]
                 break
             print("无效的选择,请重试")
         except ValueError:
@@ -164,18 +172,21 @@ def main():
         print("\n" + "-" * 20)
         print(f"表达方式:{expr['style']} <---> 情景:{expr['situation']}")

-        if similar_styles[i]['styles']:
+        if similar_styles[i]["styles"]:
             print("\n\033[33m相似表达:\033[0m")
-            for similar_style, original_situation, score in similar_styles[i]['styles']:
+            for similar_style, original_situation, score in similar_styles[i]["styles"]:
                 print(f"\033[33m{similar_style},score:{score:.3f},对应情景:{original_situation}\033[0m")

-        if similar_styles[i]['situations']:
+        if similar_styles[i]["situations"]:
             print("\n\033[32m相似情景:\033[0m")
-            for similar_situation, original_style, score in similar_styles[i]['situations']:
+            for similar_situation, original_style, score in similar_styles[i]["situations"]:
                 print(f"\033[32m{similar_situation},score:{score:.3f},对应表达:{original_style}\033[0m")

-        print(f"\n激活值:{expr.get('count', 1):.3f},上次激活时间:{format_timestamp(expr.get('last_active_time'))}")
+        print(
+            f"\n激活值:{expr.get('count', 1):.3f},上次激活时间:{format_timestamp(expr.get('last_active_time'))}"
+        )
         print("-" * 20)

+
 if __name__ == "__main__":
     main()
@@ -6,15 +6,17 @@ from datetime import datetime
 from typing import Dict, List, Any
 import sqlite3

+
 def clean_group_name(name: str) -> str:
     """清理群组名称,只保留中文和英文字符"""
     # 提取中文和英文字符
-    cleaned = re.sub(r'[^\u4e00-\u9fa5a-zA-Z]', '', name)
+    cleaned = re.sub(r"[^\u4e00-\u9fa5a-zA-Z]", "", name)
     # 如果清理后为空,使用当前日期
     if not cleaned:
         cleaned = datetime.now().strftime("%Y%m%d")
     return cleaned

+
 def get_group_name(stream_id: str) -> str:
     """从数据库中获取群组名称"""
     conn = sqlite3.connect("data/maibot.db")
@@ -42,6 +44,7 @@ def get_group_name(stream_id: str) -> str:
         return clean_group_name(f"{platform}{stream_id[:8]}")
     return stream_id

+
 def load_expressions(chat_id: str) -> tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]:
     """加载指定群组的表达方式"""
     learnt_style_file = os.path.join("data", "expression", "learnt_style", str(chat_id), "expressions.json")
@@ -66,10 +69,12 @@ def load_expressions(chat_id: str) -> tuple[List[Dict[str, Any]], List[Dict[str,

     return style_expressions, grammar_expressions, personality_expressions

+
 def format_time(timestamp: float) -> str:
     """格式化时间戳为可读字符串"""
     return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")

+
 def write_expressions(f, expressions: List[Dict[str, Any]], title: str):
     """写入表达方式列表"""
     if not expressions:
@@ -87,7 +92,14 @@ def write_expressions(f, expressions: List[Dict[str, Any]], title: str):
         f.write(f"最后活跃: {format_time(last_active)}\n")
         f.write("-" * 40 + "\n")

-def write_group_report(group_file: str, group_name: str, chat_id: str, style_exprs: List[Dict[str, Any]], grammar_exprs: List[Dict[str, Any]]):
+
+def write_group_report(
+    group_file: str,
+    group_name: str,
+    chat_id: str,
+    style_exprs: List[Dict[str, Any]],
+    grammar_exprs: List[Dict[str, Any]],
+):
     """写入群组详细报告"""
     with open(group_file, "w", encoding="utf-8") as gf:
         gf.write(f"群组: {group_name} (ID: {chat_id})\n")
@@ -104,6 +116,7 @@ def write_group_report(group_file: str, group_name: str, chat_id: str, style_exp
         gf.write("=" * 40 + "\n")
         write_expressions(gf, grammar_exprs, "句法特点")

+
 def analyze_expressions():
     """分析所有群组的表达方式"""
     # 获取所有群组ID
@@ -197,5 +210,6 @@ def analyze_expressions():
     print(f"人格表达报告: {personality_report}")
     print(f"各群组详细报告位于: {output_dir}")

+
 if __name__ == "__main__":
     analyze_expressions()
@@ -3,11 +3,11 @@ import json
 import random
 from typing import List, Dict, Tuple
 import glob
-from datetime import datetime

 MAX_EXPRESSION_COUNT = 300  # 每个群最多保留的表达方式数量
 MIN_COUNT_THRESHOLD = 0.01  # 最小使用次数阈值

+
 def load_expressions(chat_id: str) -> Tuple[List[Dict], List[Dict]]:
     """加载指定群聊的表达方式"""
     style_file = os.path.join("data", "expression", "learnt_style", str(chat_id), "expressions.json")
@@ -26,6 +26,7 @@ def load_expressions(chat_id: str) -> Tuple[List[Dict], List[Dict]]:

     return style_exprs, grammar_exprs

+
 def save_expressions(chat_id: str, style_exprs: List[Dict], grammar_exprs: List[Dict]) -> None:
     """保存表达方式到文件"""
     style_file = os.path.join("data", "expression", "learnt_style", str(chat_id), "expressions.json")
@@ -40,6 +41,7 @@ def save_expressions(chat_id: str, style_exprs: List[Dict], grammar_exprs: List[
     with open(grammar_file, "w", encoding="utf-8") as f:
         json.dump(grammar_exprs, f, ensure_ascii=False, indent=2)

+
 def cleanup_expressions(expressions: List[Dict]) -> List[Dict]:
     """清理表达方式列表"""
     if not expressions:
@@ -60,12 +62,13 @@ def cleanup_expressions(expressions: List[Dict]) -> List[Dict]:
         # 从剩余的表达方式中随机选择
         remaining_exprs = expressions[keep_count:]
         random.shuffle(remaining_exprs)
-        keep_exprs.extend(remaining_exprs[:MAX_EXPRESSION_COUNT - keep_count])
+        keep_exprs.extend(remaining_exprs[: MAX_EXPRESSION_COUNT - keep_count])

         expressions = keep_exprs

     return expressions

+
 def main():
     # 获取所有群聊ID
     style_dirs = glob.glob(os.path.join("data", "expression", "learnt_style", "*"))
@@ -113,7 +116,10 @@ def main():
     print("\n清理完成!")
     print(f"语言风格总数: {total_style_before} -> {total_style_after}")
     print(f"句法特点总数: {total_grammar_before} -> {total_grammar_after}")
-    print(f"总共清理了 {total_style_before + total_grammar_before - total_style_after - total_grammar_after} 条表达方式")
+    print(
+        f"总共清理了 {total_style_before + total_grammar_before - total_style_after - total_grammar_after} 条表达方式"
+    )

+
 if __name__ == "__main__":
     main()
@@ -1,5 +1,6 @@
 import os
 import sys

 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
 import json
@@ -15,13 +16,15 @@ import random
 from src.llm_models.utils_model import LLMRequest
 from src.config.config import global_config

+
 def clean_group_name(name: str) -> str:
     """清理群组名称,只保留中文和英文字符"""
-    cleaned = re.sub(r'[^\u4e00-\u9fa5a-zA-Z]', '', name)
+    cleaned = re.sub(r"[^\u4e00-\u9fa5a-zA-Z]", "", name)
     if not cleaned:
         cleaned = datetime.now().strftime("%Y%m%d")
     return cleaned

+
 def get_group_name(stream_id: str) -> str:
     """从数据库中获取群组名称"""
     conn = sqlite3.connect("data/maibot.db")
@@ -49,6 +52,7 @@ def get_group_name(stream_id: str) -> str:
         return clean_group_name(f"{platform}{stream_id[:8]}")
     return stream_id

+
 def load_expressions(chat_id: str) -> List[Dict]:
     """加载指定群聊的表达方式"""
     style_file = os.path.join("data", "expression", "learnt_style", str(chat_id), "expressions.json")
@@ -66,16 +70,19 @@ def load_expressions(chat_id: str) -> List[Dict]:

     return style_exprs

-def find_similar_expressions_tfidf(input_text: str, expressions: List[Dict], mode: str = "both", top_k: int = 10) -> List[Tuple[str, str, float]]:
+
+def find_similar_expressions_tfidf(
+    input_text: str, expressions: List[Dict], mode: str = "both", top_k: int = 10
+) -> List[Tuple[str, str, float]]:
     """使用TF-IDF方法找出与输入文本最相似的top_k个表达方式"""
     if not expressions:
         return []

     # 准备文本数据
     if mode == "style":
-        texts = [expr['style'] for expr in expressions]
+        texts = [expr["style"] for expr in expressions]
     elif mode == "situation":
-        texts = [expr['situation'] for expr in expressions]
+        texts = [expr["situation"] for expr in expressions]
     else:  # both
         texts = [f"{expr['situation']} {expr['style']}" for expr in expressions]

@@ -98,24 +105,23 @@ def find_similar_expressions_tfidf(input_text: str, expressions: List[Dict], mod
     similar_exprs = []
     for idx in top_indices:
         if scores[idx] > 0:  # 只保留有相似度的
-            similar_exprs.append((
-                expressions[idx]['style'],
-                expressions[idx]['situation'],
-                scores[idx]
-            ))
+            similar_exprs.append((expressions[idx]["style"], expressions[idx]["situation"], scores[idx]))

     return similar_exprs

-async def find_similar_expressions_embedding(input_text: str, expressions: List[Dict], mode: str = "both", top_k: int = 5) -> List[Tuple[str, str, float]]:
+
+async def find_similar_expressions_embedding(
+    input_text: str, expressions: List[Dict], mode: str = "both", top_k: int = 5
+) -> List[Tuple[str, str, float]]:
     """使用嵌入模型找出与输入文本最相似的top_k个表达方式"""
     if not expressions:
         return []

     # 准备文本数据
     if mode == "style":
-        texts = [expr['style'] for expr in expressions]
+        texts = [expr["style"] for expr in expressions]
     elif mode == "situation":
-        texts = [expr['situation'] for expr in expressions]
+        texts = [expr["situation"] for expr in expressions]
     else:  # both
         texts = [f"{expr['situation']} {expr['style']}" for expr in expressions]

@@ -144,14 +150,11 @@ async def find_similar_expressions_embedding(input_text: str, expressions: List[
     similar_exprs = []
     for idx in top_indices:
         if similarities[idx] > 0:  # 只保留有相似度的
-            similar_exprs.append((
-                expressions[idx]['style'],
-                expressions[idx]['situation'],
-                similarities[idx]
-            ))
+            similar_exprs.append((expressions[idx]["style"], expressions[idx]["situation"], similarities[idx]))

     return similar_exprs

+
 async def main():
     # 获取所有群聊ID
     style_dirs = glob.glob(os.path.join("data", "expression", "learnt_style", "*"))
@@ -172,7 +175,7 @@ async def main():
             if choice == 0:
                 break
             if 1 <= choice <= len(chat_ids):
-                chat_id = chat_ids[choice-1]
+                chat_id = chat_ids[choice - 1]
                 break
             print("无效的选择,请重试")
         except ValueError:
@@ -202,11 +205,7 @@ async def main():
         except ValueError:
             print("请输入有效的数字")

-    mode_map = {
-        1: "style",
-        2: "situation",
-        3: "both"
-    }
+    mode_map = {1: "style", 2: "situation", 3: "both"}
     mode = mode_map[mode_choice]

     # 选择匹配方法
@@ -225,7 +224,7 @@ async def main():

     while True:
         input_text = input("\n请输入要匹配的文本(输入q退出): ")
-        if input_text.lower() == 'q':
+        if input_text.lower() == "q":
             break

         if not input_text.strip():
@@ -246,6 +245,8 @@ async def main():
         else:
             print("\n没有找到相似的表达方式")

+
 if __name__ == "__main__":
     import asyncio
+
     asyncio.run(main())
@@ -1,4 +1,5 @@
 #!/usr/bin/env python3
+# ruff: noqa: E402
 """
 消息检索脚本

@@ -10,44 +11,333 @@
|
|||||||
5. 应用LLM分析,将结果存储到数据库person_info中
|
5. 应用LLM分析,将结果存储到数据库person_info中
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import re
|
|
||||||
import random
|
import random
|
||||||
import time
|
import sys
|
||||||
import math
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Dict, List, Any, Optional
|
from datetime import datetime, timedelta
|
||||||
|
from difflib import SequenceMatcher
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Any, Optional
|
||||||
|
|
||||||
|
import jieba
|
||||||
|
from json_repair import repair_json
|
||||||
|
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||||
|
from sklearn.metrics.pairwise import cosine_similarity
|
||||||
|
|
||||||
# 添加项目根目录到Python路径
|
# 添加项目根目录到Python路径
|
||||||
project_root = Path(__file__).parent.parent
|
project_root = Path(__file__).parent.parent
|
||||||
sys.path.insert(0, str(project_root))
|
sys.path.insert(0, str(project_root))
|
||||||
|
|
||||||
from src.common.database.database_model import Messages
|
|
||||||
from src.person_info.person_info import PersonInfoManager
|
|
||||||
from src.config.config import global_config
|
|
||||||
from src.common.database.database import db
|
|
||||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||||
from src.person_info.person_info import person_info_manager
|
from src.common.database.database_model import Messages
|
||||||
from src.llm_models.utils_model import LLMRequest
|
|
||||||
from src.individuality.individuality import individuality
|
|
||||||
from json_repair import repair_json
|
|
||||||
from difflib import SequenceMatcher
|
|
||||||
import jieba
|
|
||||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
|
||||||
from sklearn.metrics.pairwise import cosine_similarity
|
|
||||||
from src.common.logger_manager import get_logger
|
from src.common.logger_manager import get_logger
|
||||||
|
from src.common.database.database import db
|
||||||
|
from src.config.config import global_config
|
||||||
|
from src.llm_models.utils_model import LLMRequest
|
||||||
|
from src.person_info.person_info import PersonInfoManager, person_info_manager
|
||||||
|
|
||||||
logger = get_logger("message_retrieval")
|
logger = get_logger("message_retrieval")
|
||||||
|
|
||||||
|
|
||||||
|
def get_time_range(time_period: str) -> Optional[float]:
|
||||||
|
"""根据时间段选择获取起始时间戳"""
|
||||||
|
now = datetime.now()
|
||||||
|
|
||||||
|
if time_period == "all":
|
||||||
|
return None
|
||||||
|
elif time_period == "3months":
|
||||||
|
start_time = now - timedelta(days=90)
|
||||||
|
elif time_period == "1month":
|
||||||
|
start_time = now - timedelta(days=30)
|
||||||
|
elif time_period == "1week":
|
||||||
|
start_time = now - timedelta(days=7)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"不支持的时间段: {time_period}")
|
||||||
|
|
||||||
|
return start_time.timestamp()
|
||||||
|
|
||||||
|
|
||||||
|
def get_person_id(platform: str, user_id: str) -> str:
|
||||||
|
"""根据platform和user_id计算person_id"""
|
||||||
|
return PersonInfoManager.get_person_id(platform, user_id)
|
||||||
|
|
||||||
|
|
||||||
|
def split_messages_by_count(messages: List[Dict[str, Any]], count: int = 50) -> List[List[Dict[str, Any]]]:
|
||||||
|
"""将消息按指定数量分段"""
|
||||||
|
chunks = []
|
||||||
|
for i in range(0, len(messages), count):
|
||||||
|
chunks.append(messages[i : i + count])
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
|
||||||
|
async def build_name_mapping(messages: List[Dict[str, Any]], target_person_name: str) -> Dict[str, str]:
|
||||||
|
"""构建用户名称映射,和relationship_manager中的逻辑一致"""
|
||||||
|
name_mapping = {}
|
||||||
|
current_user = "A"
|
||||||
|
user_count = 1
|
||||||
|
|
||||||
|
# 遍历消息,构建映射
|
||||||
|
for msg in messages:
|
||||||
|
await person_info_manager.get_or_create_person(
|
||||||
|
platform=msg.get("chat_info_platform"),
|
||||||
|
user_id=msg.get("user_id"),
|
||||||
|
nickname=msg.get("user_nickname"),
|
||||||
|
user_cardname=msg.get("user_cardname"),
|
||||||
|
)
|
||||||
|
replace_user_id = msg.get("user_id")
|
||||||
|
replace_platform = msg.get("chat_info_platform")
|
||||||
|
replace_person_id = get_person_id(replace_platform, replace_user_id)
|
||||||
|
replace_person_name = await person_info_manager.get_value(replace_person_id, "person_name")
|
||||||
|
|
||||||
|
# 跳过机器人自己
|
||||||
|
if replace_user_id == global_config.bot.qq_account:
|
||||||
|
name_mapping[f"{global_config.bot.nickname}"] = f"{global_config.bot.nickname}"
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 跳过目标用户
|
||||||
|
if replace_person_name == target_person_name:
|
||||||
|
name_mapping[replace_person_name] = f"{target_person_name}"
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 其他用户映射
|
||||||
|
if replace_person_name not in name_mapping:
|
||||||
|
if current_user > "Z":
|
||||||
|
current_user = "A"
|
||||||
|
user_count += 1
|
||||||
|
name_mapping[replace_person_name] = f"用户{current_user}{user_count if user_count > 1 else ''}"
|
||||||
|
current_user = chr(ord(current_user) + 1)
|
||||||
|
|
||||||
|
return name_mapping
|
||||||
|
|
||||||
|
|
||||||
|
def build_focus_readable_messages(messages: List[Dict[str, Any]], target_person_id: str = None) -> str:
|
||||||
|
"""格式化消息,只保留目标用户和bot消息附近的内容,和relationship_manager中的逻辑一致"""
|
||||||
|
# 找到目标用户和bot的消息索引
|
||||||
|
target_indices = []
|
||||||
|
for i, msg in enumerate(messages):
|
||||||
|
user_id = msg.get("user_id")
|
||||||
|
platform = msg.get("chat_info_platform")
|
||||||
|
person_id = get_person_id(platform, user_id)
|
||||||
|
if person_id == target_person_id:
|
||||||
|
target_indices.append(i)
|
||||||
|
|
||||||
|
if not target_indices:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# 获取需要保留的消息索引
|
||||||
|
keep_indices = set()
|
||||||
|
for idx in target_indices:
|
||||||
|
# 获取前后5条消息的索引
|
||||||
|
start_idx = max(0, idx - 5)
|
||||||
|
end_idx = min(len(messages), idx + 6)
|
||||||
|
keep_indices.update(range(start_idx, end_idx))
|
||||||
|
|
||||||
|
# 将索引排序
|
||||||
|
keep_indices = sorted(list(keep_indices))
|
||||||
|
|
||||||
|
# 按顺序构建消息组
|
||||||
|
message_groups = []
|
||||||
|
current_group = []
|
||||||
|
|
||||||
|
for i in range(len(messages)):
|
||||||
|
if i in keep_indices:
|
||||||
|
current_group.append(messages[i])
|
||||||
|
elif current_group:
|
||||||
|
# 如果当前组不为空,且遇到不保留的消息,则结束当前组
|
||||||
|
if current_group:
|
||||||
|
message_groups.append(current_group)
|
||||||
|
current_group = []
|
||||||
|
|
||||||
|
# 添加最后一组
|
||||||
|
if current_group:
|
||||||
|
message_groups.append(current_group)
|
||||||
|
|
||||||
|
# 构建最终的消息文本
|
||||||
|
result = []
|
||||||
|
for i, group in enumerate(message_groups):
|
||||||
|
if i > 0:
|
||||||
|
result.append("...")
|
||||||
|
group_text = build_readable_messages(
|
||||||
|
messages=group, replace_bot_name=True, timestamp_mode="normal_no_YMD", truncate=False
|
||||||
|
)
|
||||||
|
result.append(group_text)
|
||||||
|
|
||||||
|
return "\n".join(result)
|
||||||
|
|
||||||
|
|
||||||
|
def tfidf_similarity(s1, s2):
|
||||||
|
"""使用 TF-IDF 和余弦相似度计算两个句子的相似性"""
|
||||||
|
# 确保输入是字符串类型
|
||||||
|
if isinstance(s1, list):
|
||||||
|
s1 = " ".join(str(x) for x in s1)
|
||||||
|
if isinstance(s2, list):
|
||||||
|
s2 = " ".join(str(x) for x in s2)
|
||||||
|
|
||||||
|
# 转换为字符串类型
|
||||||
|
s1 = str(s1)
|
||||||
|
s2 = str(s2)
|
||||||
|
|
||||||
|
# 1. 使用 jieba 进行分词
|
||||||
|
s1_words = " ".join(jieba.cut(s1))
|
||||||
|
s2_words = " ".join(jieba.cut(s2))
|
||||||
|
|
||||||
|
# 2. 将两句话放入一个列表中
|
||||||
|
corpus = [s1_words, s2_words]
|
||||||
|
|
||||||
|
# 3. 创建 TF-IDF 向量化器并进行计算
|
||||||
|
try:
|
||||||
|
vectorizer = TfidfVectorizer()
|
||||||
|
tfidf_matrix = vectorizer.fit_transform(corpus)
|
||||||
|
except ValueError:
|
||||||
|
# 如果句子完全由停用词组成,或者为空,可能会报错
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
# 4. 计算余弦相似度
|
||||||
|
similarity_matrix = cosine_similarity(tfidf_matrix)
|
||||||
|
|
||||||
|
# 返回 s1 和 s2 的相似度
|
||||||
|
return similarity_matrix[0, 1]
|
||||||
|
|
||||||
|
|
||||||
|
def sequence_similarity(s1, s2):
|
||||||
|
"""使用 SequenceMatcher 计算两个句子的相似性"""
|
||||||
|
return SequenceMatcher(None, s1, s2).ratio()
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_time_weight(point_time: str, current_time: str) -> float:
|
||||||
|
"""计算基于时间的权重系数"""
|
||||||
|
try:
|
||||||
|
point_timestamp = datetime.strptime(point_time, "%Y-%m-%d %H:%M:%S")
|
||||||
|
current_timestamp = datetime.strptime(current_time, "%Y-%m-%d %H:%M:%S")
|
||||||
|
time_diff = current_timestamp - point_timestamp
|
||||||
|
hours_diff = time_diff.total_seconds() / 3600
|
||||||
|
|
||||||
|
if hours_diff <= 1: # 1小时内
|
||||||
|
return 1.0
|
||||||
|
elif hours_diff <= 24: # 1-24小时
|
||||||
|
# 从1.0快速递减到0.7
|
||||||
|
return 1.0 - (hours_diff - 1) * (0.3 / 23)
|
||||||
|
elif hours_diff <= 24 * 7: # 24小时-7天
|
||||||
|
# 从0.7缓慢回升到0.95
|
||||||
|
return 0.7 + (hours_diff - 24) * (0.25 / (24 * 6))
|
||||||
|
else: # 7-30天
|
||||||
|
# 从0.95缓慢递减到0.1
|
||||||
|
days_diff = hours_diff / 24 - 7
|
||||||
|
return max(0.1, 0.95 - days_diff * (0.85 / 23))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"计算时间权重失败: {e}")
|
||||||
|
return 0.5 # 发生错误时返回中等权重
|
||||||
|
|
||||||
|
|
||||||
|
def filter_selected_chats(
|
||||||
|
grouped_messages: Dict[str, List[Dict[str, Any]]], selected_indices: List[int]
|
||||||
|
) -> Dict[str, List[Dict[str, Any]]]:
|
||||||
|
"""根据用户选择过滤群聊"""
|
||||||
|
chat_items = list(grouped_messages.items())
|
||||||
|
selected_chats = {}
|
||||||
|
|
||||||
|
for idx in selected_indices:
|
||||||
|
chat_id, messages = chat_items[idx - 1] # 转换为0基索引
|
||||||
|
selected_chats[chat_id] = messages
|
||||||
|
|
||||||
|
return selected_chats
|
||||||
|
|
||||||
|
|
||||||
|
def get_user_selection(total_count: int) -> List[int]:
|
||||||
|
"""获取用户选择的群聊编号"""
|
||||||
|
while True:
|
||||||
|
print(f"\n请选择要分析的群聊 (1-{total_count}):")
|
||||||
|
print("输入格式:")
|
||||||
|
print(" 单个: 1")
|
||||||
|
print(" 多个: 1,3,5")
|
||||||
|
print(" 范围: 1-3")
|
||||||
|
print(" 全部: all 或 a")
|
||||||
|
print(" 退出: quit 或 q")
|
||||||
|
|
||||||
|
user_input = input("请输入选择: ").strip().lower()
|
||||||
|
|
||||||
|
if user_input in ["quit", "q"]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
if user_input in ["all", "a"]:
|
||||||
|
return list(range(1, total_count + 1))
|
||||||
|
|
||||||
|
try:
|
||||||
|
selected = []
|
||||||
|
|
||||||
|
# 处理逗号分隔的输入
|
||||||
|
parts = user_input.split(",")
|
||||||
|
|
||||||
|
for part in parts:
|
||||||
|
part = part.strip()
|
||||||
|
|
||||||
|
if "-" in part:
|
||||||
|
# 处理范围输入 (如: 1-3)
|
||||||
|
start, end = part.split("-")
|
||||||
|
start_num = int(start.strip())
|
||||||
|
end_num = int(end.strip())
|
||||||
|
|
||||||
|
if 1 <= start_num <= total_count and 1 <= end_num <= total_count and start_num <= end_num:
|
||||||
|
selected.extend(range(start_num, end_num + 1))
|
||||||
|
else:
|
||||||
|
raise ValueError("范围超出有效范围")
|
||||||
|
else:
|
||||||
|
# 处理单个数字
|
||||||
|
num = int(part)
|
||||||
|
if 1 <= num <= total_count:
|
||||||
|
selected.append(num)
|
||||||
|
else:
|
||||||
|
raise ValueError("数字超出有效范围")
|
||||||
|
|
||||||
|
# 去重并排序
|
||||||
|
selected = sorted(list(set(selected)))
|
||||||
|
|
||||||
|
if selected:
|
||||||
|
return selected
|
||||||
|
else:
|
||||||
|
print("错误: 请输入有效的选择")
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
print(f"错误: 输入格式无效 - {e}")
|
||||||
|
print("请重新输入")
|
||||||
|
|
||||||
|
|
||||||
|
def display_chat_list(grouped_messages: Dict[str, List[Dict[str, Any]]]) -> None:
|
||||||
|
"""显示群聊列表"""
|
||||||
|
print("\n找到以下群聊:")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
for i, (chat_id, messages) in enumerate(grouped_messages.items(), 1):
|
||||||
|
first_msg = messages[0]
|
||||||
|
group_name = first_msg.get("chat_info_group_name", "私聊")
|
||||||
|
group_id = first_msg.get("chat_info_group_id", chat_id)
|
||||||
|
|
||||||
|
# 计算时间范围
|
||||||
|
start_time = datetime.fromtimestamp(messages[0]["time"]).strftime("%Y-%m-%d")
|
||||||
|
end_time = datetime.fromtimestamp(messages[-1]["time"]).strftime("%Y-%m-%d")
|
||||||
|
|
||||||
|
print(f"{i:2d}. {group_name}")
|
||||||
|
print(f" 群ID: {group_id}")
|
||||||
|
print(f" 消息数: {len(messages)}")
|
||||||
|
print(f" 时间范围: {start_time} ~ {end_time}")
|
||||||
|
print("-" * 60)
|
||||||
|
|
||||||
|
|
||||||
|
def check_similarity(text1, text2, tfidf_threshold=0.5, seq_threshold=0.6):
|
||||||
|
"""使用两种方法检查文本相似度,只要其中一种方法达到阈值就认为是相似的"""
|
||||||
|
# 计算两种相似度
|
||||||
|
tfidf_sim = tfidf_similarity(text1, text2)
|
||||||
|
seq_sim = sequence_similarity(text1, text2)
|
||||||
|
|
||||||
|
# 只要其中一种方法达到阈值就认为是相似的
|
||||||
|
return tfidf_sim > tfidf_threshold or seq_sim > seq_threshold
|
||||||
|
|
||||||
|
|
||||||
class MessageRetrievalScript:
|
class MessageRetrievalScript:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""初始化脚本"""
|
"""初始化脚本"""
|
||||||
self.person_info_manager = PersonInfoManager()
|
|
||||||
self.bot_qq = str(global_config.bot.qq_account)
|
self.bot_qq = str(global_config.bot.qq_account)
|
||||||
|
|
||||||
# 初始化LLM请求器,和relationship_manager一样
|
# 初始化LLM请求器,和relationship_manager一样
|
||||||
@@ -56,37 +346,16 @@ class MessageRetrievalScript:
|
|||||||
request_type="relationship",
|
request_type="relationship",
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_person_id(self, platform: str, user_id: str) -> str:
|
|
||||||
"""根据platform和user_id计算person_id"""
|
|
||||||
return PersonInfoManager.get_person_id(platform, user_id)
|
|
||||||
|
|
||||||
def get_time_range(self, time_period: str) -> Optional[float]:
|
|
||||||
"""根据时间段选择获取起始时间戳"""
|
|
||||||
now = datetime.now()
|
|
||||||
|
|
||||||
if time_period == "all":
|
|
||||||
return None
|
|
||||||
elif time_period == "3months":
|
|
||||||
start_time = now - timedelta(days=90)
|
|
||||||
elif time_period == "1month":
|
|
||||||
start_time = now - timedelta(days=30)
|
|
||||||
elif time_period == "1week":
|
|
||||||
start_time = now - timedelta(days=7)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"不支持的时间段: {time_period}")
|
|
||||||
|
|
||||||
return start_time.timestamp()
|
|
||||||
|
|
||||||
def retrieve_messages(self, user_qq: str, time_period: str) -> Dict[str, List[Dict[str, Any]]]:
|
def retrieve_messages(self, user_qq: str, time_period: str) -> Dict[str, List[Dict[str, Any]]]:
|
||||||
"""检索消息"""
|
"""检索消息"""
|
||||||
print(f"开始检索用户 {user_qq} 的消息...")
|
print(f"开始检索用户 {user_qq} 的消息...")
|
||||||
|
|
||||||
# 计算person_id
|
# 计算person_id
|
||||||
person_id = self.get_person_id("qq", user_qq)
|
person_id = get_person_id("qq", user_qq)
|
||||||
print(f"用户person_id: {person_id}")
|
print(f"用户person_id: {person_id}")
|
||||||
|
|
||||||
# 获取时间范围
|
# 获取时间范围
|
||||||
start_timestamp = self.get_time_range(time_period)
|
start_timestamp = get_time_range(time_period)
|
||||||
if start_timestamp:
|
if start_timestamp:
|
||||||
print(f"时间范围: {datetime.fromtimestamp(start_timestamp).strftime('%Y-%m-%d %H:%M:%S')} 至今")
|
print(f"时间范围: {datetime.fromtimestamp(start_timestamp).strftime('%Y-%m-%d %H:%M:%S')} 至今")
|
||||||
else:
|
else:
|
||||||
@@ -97,8 +366,8 @@ class MessageRetrievalScript:
|
|||||||
|
|
||||||
# 添加用户条件:包含bot消息或目标用户消息
|
# 添加用户条件:包含bot消息或目标用户消息
|
||||||
user_condition = (
|
user_condition = (
|
||||||
(Messages.user_id == self.bot_qq) | # bot的消息
|
(Messages.user_id == self.bot_qq) # bot的消息
|
||||||
(Messages.user_id == user_qq) # 目标用户的消息
|
| (Messages.user_id == user_qq) # 目标用户的消息
|
||||||
)
|
)
|
||||||
query = query.where(user_condition)
|
query = query.where(user_condition)
|
||||||
|
|
||||||
@@ -117,199 +386,27 @@ class MessageRetrievalScript:
|
|||||||
grouped_messages = defaultdict(list)
|
grouped_messages = defaultdict(list)
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
msg_dict = {
|
msg_dict = {
|
||||||
'message_id': msg.message_id,
|
"message_id": msg.message_id,
|
||||||
'time': msg.time,
|
"time": msg.time,
|
||||||
'datetime': datetime.fromtimestamp(msg.time).strftime('%Y-%m-%d %H:%M:%S'),
|
"datetime": datetime.fromtimestamp(msg.time).strftime("%Y-%m-%d %H:%M:%S"),
|
||||||
'chat_id': msg.chat_id,
|
"chat_id": msg.chat_id,
|
||||||
'user_id': msg.user_id,
|
"user_id": msg.user_id,
|
||||||
'user_nickname': msg.user_nickname,
|
"user_nickname": msg.user_nickname,
|
||||||
'user_platform': msg.user_platform,
|
"user_platform": msg.user_platform,
|
||||||
'processed_plain_text': msg.processed_plain_text,
|
"processed_plain_text": msg.processed_plain_text,
|
||||||
'display_message': msg.display_message,
|
"display_message": msg.display_message,
|
||||||
'chat_info_group_id': msg.chat_info_group_id,
|
"chat_info_group_id": msg.chat_info_group_id,
|
||||||
'chat_info_group_name': msg.chat_info_group_name,
|
"chat_info_group_name": msg.chat_info_group_name,
|
||||||
'chat_info_platform': msg.chat_info_platform,
|
"chat_info_platform": msg.chat_info_platform,
|
||||||
'user_cardname': msg.user_cardname,
|
"user_cardname": msg.user_cardname,
|
||||||
'is_bot_message': msg.user_id == self.bot_qq
|
"is_bot_message": msg.user_id == self.bot_qq,
|
||||||
}
|
}
|
||||||
grouped_messages[msg.chat_id].append(msg_dict)
|
grouped_messages[msg.chat_id].append(msg_dict)
|
||||||
|
|
||||||
print(f"消息分布在 {len(grouped_messages)} 个聊天中")
|
print(f"消息分布在 {len(grouped_messages)} 个聊天中")
|
||||||
return dict(grouped_messages)
|
return dict(grouped_messages)
|
||||||
|
|
||||||
def split_messages_by_count(self, messages: List[Dict[str, Any]], count: int = 50) -> List[List[Dict[str, Any]]]:
|
|
||||||
"""将消息按指定数量分段"""
|
|
||||||
chunks = []
|
|
||||||
for i in range(0, len(messages), count):
|
|
||||||
chunks.append(messages[i:i + count])
|
|
||||||
return chunks
|
|
||||||
|
|
||||||
async def build_name_mapping(self, messages: List[Dict[str, Any]], target_person_id: str, target_person_name: str) -> Dict[str, str]:
|
|
||||||
"""构建用户名称映射,和relationship_manager中的逻辑一致"""
|
|
||||||
name_mapping = {}
|
|
||||||
current_user = "A"
|
|
||||||
user_count = 1
|
|
||||||
|
|
||||||
# 遍历消息,构建映射
|
|
||||||
for msg in messages:
|
|
||||||
await person_info_manager.get_or_create_person(
|
|
||||||
platform=msg.get("chat_info_platform"),
|
|
||||||
user_id=msg.get("user_id"),
|
|
||||||
nickname=msg.get("user_nickname"),
|
|
||||||
user_cardname=msg.get("user_cardname"),
|
|
||||||
)
|
|
||||||
replace_user_id = msg.get("user_id")
|
|
||||||
replace_platform = msg.get("chat_info_platform")
|
|
||||||
replace_person_id = person_info_manager.get_person_id(replace_platform, replace_user_id)
|
|
||||||
replace_person_name = await person_info_manager.get_value(replace_person_id, "person_name")
|
|
||||||
|
|
||||||
# 跳过机器人自己
|
|
||||||
if replace_user_id == global_config.bot.qq_account:
|
|
||||||
name_mapping[f"{global_config.bot.nickname}"] = f"{global_config.bot.nickname}"
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 跳过目标用户
|
|
||||||
if replace_person_name == target_person_name:
|
|
||||||
name_mapping[replace_person_name] = f"{target_person_name}"
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 其他用户映射
|
|
||||||
if replace_person_name not in name_mapping:
|
|
||||||
if current_user > 'Z':
|
|
||||||
current_user = 'A'
|
|
||||||
user_count += 1
|
|
||||||
name_mapping[replace_person_name] = f"用户{current_user}{user_count if user_count > 1 else ''}"
|
|
||||||
current_user = chr(ord(current_user) + 1)
|
|
||||||
|
|
||||||
return name_mapping
|
|
||||||
|
|
||||||
def build_focus_readable_messages(self, messages: List[Dict[str, Any]], target_person_id: str = None) -> str:
|
|
||||||
"""格式化消息,只保留目标用户和bot消息附近的内容,和relationship_manager中的逻辑一致"""
|
|
||||||
# 找到目标用户和bot的消息索引
|
|
||||||
target_indices = []
|
|
||||||
for i, msg in enumerate(messages):
|
|
||||||
user_id = msg.get("user_id")
|
|
||||||
platform = msg.get("chat_info_platform")
|
|
||||||
person_id = person_info_manager.get_person_id(platform, user_id)
|
|
||||||
if person_id == target_person_id:
|
|
||||||
target_indices.append(i)
|
|
||||||
|
|
||||||
if not target_indices:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
# 获取需要保留的消息索引
|
|
||||||
keep_indices = set()
|
|
||||||
for idx in target_indices:
|
|
||||||
# 获取前后5条消息的索引
|
|
||||||
start_idx = max(0, idx - 5)
|
|
||||||
end_idx = min(len(messages), idx + 6)
|
|
||||||
keep_indices.update(range(start_idx, end_idx))
|
|
||||||
|
|
||||||
# 将索引排序
|
|
||||||
keep_indices = sorted(list(keep_indices))
|
|
||||||
|
|
||||||
# 按顺序构建消息组
|
|
||||||
message_groups = []
|
|
||||||
current_group = []
|
|
||||||
|
|
||||||
for i in range(len(messages)):
|
|
||||||
if i in keep_indices:
|
|
||||||
current_group.append(messages[i])
|
|
||||||
elif current_group:
|
|
||||||
# 如果当前组不为空,且遇到不保留的消息,则结束当前组
|
|
||||||
if current_group:
|
|
||||||
message_groups.append(current_group)
|
|
||||||
current_group = []
|
|
||||||
|
|
||||||
# 添加最后一组
|
|
||||||
if current_group:
|
|
||||||
message_groups.append(current_group)
|
|
||||||
|
|
||||||
# 构建最终的消息文本
|
|
||||||
result = []
|
|
||||||
for i, group in enumerate(message_groups):
|
|
||||||
if i > 0:
|
|
||||||
result.append("...")
|
|
||||||
group_text = build_readable_messages(
|
|
||||||
messages=group,
|
|
||||||
replace_bot_name=True,
|
|
||||||
timestamp_mode="normal_no_YMD",
|
|
||||||
truncate=False
|
|
||||||
)
|
|
||||||
result.append(group_text)
|
|
||||||
|
|
||||||
return "\n".join(result)
|
|
||||||
|
|
||||||
# 添加相似度检查方法,和relationship_manager一致
|
# 添加相似度检查方法,和relationship_manager一致
|
||||||
def tfidf_similarity(self, s1, s2):
|
|
||||||
"""使用 TF-IDF 和余弦相似度计算两个句子的相似性"""
|
|
||||||
# 确保输入是字符串类型
|
|
||||||
if isinstance(s1, list):
|
|
||||||
s1 = " ".join(str(x) for x in s1)
|
|
||||||
if isinstance(s2, list):
|
|
||||||
s2 = " ".join(str(x) for x in s2)
|
|
||||||
|
|
||||||
# 转换为字符串类型
|
|
||||||
s1 = str(s1)
|
|
||||||
s2 = str(s2)
|
|
||||||
|
|
||||||
# 1. 使用 jieba 进行分词
|
|
||||||
s1_words = " ".join(jieba.cut(s1))
|
|
||||||
s2_words = " ".join(jieba.cut(s2))
|
|
||||||
|
|
||||||
# 2. 将两句话放入一个列表中
|
|
||||||
corpus = [s1_words, s2_words]
|
|
||||||
|
|
||||||
# 3. 创建 TF-IDF 向量化器并进行计算
|
|
||||||
try:
|
|
||||||
vectorizer = TfidfVectorizer()
|
|
||||||
tfidf_matrix = vectorizer.fit_transform(corpus)
|
|
||||||
except ValueError:
|
|
||||||
# 如果句子完全由停用词组成,或者为空,可能会报错
|
|
||||||
return 0.0
|
|
||||||
|
|
||||||
# 4. 计算余弦相似度
|
|
||||||
similarity_matrix = cosine_similarity(tfidf_matrix)
|
|
||||||
|
|
||||||
# 返回 s1 和 s2 的相似度
|
|
||||||
return similarity_matrix[0, 1]
|
|
||||||
|
|
||||||
def sequence_similarity(self, s1, s2):
|
|
||||||
"""使用 SequenceMatcher 计算两个句子的相似性"""
|
|
||||||
return SequenceMatcher(None, s1, s2).ratio()
|
|
||||||
|
|
||||||
def check_similarity(self, text1, text2, tfidf_threshold=0.5, seq_threshold=0.6):
|
|
||||||
"""使用两种方法检查文本相似度,只要其中一种方法达到阈值就认为是相似的"""
|
|
||||||
# 计算两种相似度
|
|
||||||
tfidf_sim = self.tfidf_similarity(text1, text2)
|
|
||||||
seq_sim = self.sequence_similarity(text1, text2)
|
|
||||||
|
|
||||||
# 只要其中一种方法达到阈值就认为是相似的
|
|
||||||
return tfidf_sim > tfidf_threshold or seq_sim > seq_threshold
|
|
||||||
|
|
||||||
def calculate_time_weight(self, point_time: str, current_time: str) -> float:
|
|
||||||
"""计算基于时间的权重系数"""
|
|
||||||
try:
|
|
||||||
point_timestamp = datetime.strptime(point_time, "%Y-%m-%d %H:%M:%S")
|
|
||||||
current_timestamp = datetime.strptime(current_time, "%Y-%m-%d %H:%M:%S")
|
|
||||||
time_diff = current_timestamp - point_timestamp
|
|
||||||
hours_diff = time_diff.total_seconds() / 3600
|
|
||||||
|
|
||||||
if hours_diff <= 1: # 1小时内
|
|
||||||
return 1.0
|
|
||||||
elif hours_diff <= 24: # 1-24小时
|
|
||||||
# 从1.0快速递减到0.7
|
|
||||||
return 1.0 - (hours_diff - 1) * (0.3 / 23)
|
|
||||||
elif hours_diff <= 24 * 7: # 24小时-7天
|
|
||||||
# 从0.7缓慢回升到0.95
|
|
||||||
return 0.7 + (hours_diff - 24) * (0.25 / (24 * 6))
|
|
||||||
else: # 7-30天
|
|
||||||
# 从0.95缓慢递减到0.1
|
|
||||||
days_diff = hours_diff / 24 - 7
|
|
||||||
return max(0.1, 0.95 - days_diff * (0.85 / 23))
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"计算时间权重失败: {e}")
|
|
||||||
return 0.5 # 发生错误时返回中等权重
|
|
||||||
|
|
||||||
async def update_person_impression_from_segment(self, person_id: str, readable_messages: str, segment_time: float):
|
async def update_person_impression_from_segment(self, person_id: str, readable_messages: str, segment_time: float):
|
||||||
"""从消息段落更新用户印象,使用和relationship_manager相同的流程"""
|
"""从消息段落更新用户印象,使用和relationship_manager相同的流程"""
|
||||||
@@ -412,7 +509,7 @@ class MessageRetrievalScript:
         # 在现有points中查找相似的点
         for i, existing_point in enumerate(current_points):
             # 使用组合的相似度检查方法
-            if self.check_similarity(new_point[0], existing_point[0]):
+            if check_similarity(new_point[0], existing_point[0]):
                 similar_points.append(existing_point)
                 similar_indices.append(i)

@@ -460,7 +557,7 @@ class MessageRetrievalScript:
         # 计算每个点的最终权重(原始权重 * 时间权重)
         weighted_points = []
         for point in current_points:
-            time_weight = self.calculate_time_weight(point[2], current_time_str)
+            time_weight = calculate_time_weight(point[2], current_time_str)
             final_weight = point[1] * time_weight
             weighted_points.append((point, final_weight))

@@ -505,10 +602,9 @@ class MessageRetrievalScript:
|
|||||||
forgotten_points.sort(key=lambda x: x[2])
|
forgotten_points.sort(key=lambda x: x[2])
|
||||||
|
|
||||||
# 构建points文本
|
# 构建points文本
|
||||||
points_text = "\n".join([
|
points_text = "\n".join(
|
||||||
f"时间:{point[2]}\n权重:{point[1]}\n内容:{point[0]}"
|
[f"时间:{point[2]}\n权重:{point[1]}\n内容:{point[0]}" for point in forgotten_points]
|
||||||
for point in forgotten_points
|
)
|
||||||
])
|
|
||||||
|
|
||||||
impression = await person_info_manager.get_value(person_id, "impression") or ""
|
impression = await person_info_manager.get_value(person_id, "impression") or ""
|
||||||
|
|
||||||
@@ -543,27 +639,30 @@ class MessageRetrievalScript:
|
|||||||
forgotten_points = []
|
forgotten_points = []
|
||||||
|
|
||||||
# 更新数据库
|
# 更新数据库
|
||||||
await person_info_manager.update_one_field(person_id, "forgotten_points", json.dumps(forgotten_points, ensure_ascii=False, indent=None))
|
await person_info_manager.update_one_field(
|
||||||
|
person_id, "forgotten_points", json.dumps(forgotten_points, ensure_ascii=False, indent=None)
|
||||||
|
)
|
||||||
|
|
||||||
# 更新数据库
|
# 更新数据库
|
||||||
await person_info_manager.update_one_field(person_id, "points", json.dumps(current_points, ensure_ascii=False, indent=None))
|
await person_info_manager.update_one_field(
|
||||||
|
person_id, "points", json.dumps(current_points, ensure_ascii=False, indent=None)
|
||||||
|
)
|
||||||
know_times = await person_info_manager.get_value(person_id, "know_times") or 0
|
know_times = await person_info_manager.get_value(person_id, "know_times") or 0
|
||||||
await person_info_manager.update_one_field(person_id, "know_times", know_times + 1)
|
await person_info_manager.update_one_field(person_id, "know_times", know_times + 1)
|
||||||
await person_info_manager.update_one_field(person_id, "last_know", segment_time)
|
await person_info_manager.update_one_field(person_id, "last_know", segment_time)
|
||||||
|
|
||||||
logger.info(f"印象更新完成 for {person_name},新增 {len(points_list)} 个记忆点")
|
logger.info(f"印象更新完成 for {person_name},新增 {len(points_list)} 个记忆点")
|
||||||
|
|
||||||
async def process_segments_and_update_impression(self, user_qq: str, grouped_messages: Dict[str, List[Dict[str, Any]]]):
|
async def process_segments_and_update_impression(
|
||||||
|
self, user_qq: str, grouped_messages: Dict[str, List[Dict[str, Any]]]
|
||||||
|
):
|
||||||
"""处理分段消息并更新用户印象到数据库"""
|
"""处理分段消息并更新用户印象到数据库"""
|
||||||
# 获取目标用户信息
|
# 获取目标用户信息
|
||||||
target_person_id = self.get_person_id("qq", user_qq)
|
target_person_id = get_person_id("qq", user_qq)
|
||||||
target_person_name = await person_info_manager.get_value(target_person_id, "person_name")
|
target_person_name = await person_info_manager.get_value(target_person_id, "person_name")
|
||||||
target_nickname = await person_info_manager.get_value(target_person_id, "nickname")
|
|
||||||
|
|
||||||
if not target_person_name:
|
if not target_person_name:
|
||||||
target_person_name = f"用户{user_qq}"
|
target_person_name = f"用户{user_qq}"
|
||||||
if not target_nickname:
|
|
||||||
target_nickname = f"用户{user_qq}"
|
|
||||||
|
|
||||||
print(f"\n开始分析用户 {target_person_name} (QQ: {user_qq}) 的消息...")
|
print(f"\n开始分析用户 {target_person_name} (QQ: {user_qq}) 的消息...")
|
||||||
|
|
||||||
@@ -575,52 +674,53 @@ class MessageRetrievalScript:
|
|||||||
# 为每个chat_id处理消息,收集所有分段
|
# 为每个chat_id处理消息,收集所有分段
|
||||||
for chat_id, messages in grouped_messages.items():
|
for chat_id, messages in grouped_messages.items():
|
||||||
first_msg = messages[0]
|
first_msg = messages[0]
|
||||||
group_name = first_msg.get('chat_info_group_name', '私聊')
|
group_name = first_msg.get("chat_info_group_name", "私聊")
|
||||||
|
|
||||||
print(f"准备聊天: {group_name} (共{len(messages)}条消息)")
|
print(f"准备聊天: {group_name} (共{len(messages)}条消息)")
|
||||||
|
|
||||||
# 将消息按50条分段
|
# 将消息按50条分段
|
||||||
message_chunks = self.split_messages_by_count(messages, 50)
|
message_chunks = split_messages_by_count(messages, 50)
|
||||||
|
|
||||||
for i, chunk in enumerate(message_chunks):
|
for i, chunk in enumerate(message_chunks):
|
||||||
# 将分段信息添加到列表中,包含分段时间用于排序
|
# 将分段信息添加到列表中,包含分段时间用于排序
|
||||||
segment_time = chunk[-1]['time']
|
segment_time = chunk[-1]["time"]
|
||||||
all_segments.append({
|
all_segments.append(
|
||||||
'chunk': chunk,
|
{
|
||||||
'chat_id': chat_id,
|
"chunk": chunk,
|
||||||
'group_name': group_name,
|
"chat_id": chat_id,
|
||||||
'segment_index': i + 1,
|
"group_name": group_name,
|
||||||
'total_segments': len(message_chunks),
|
"segment_index": i + 1,
|
||||||
'segment_time': segment_time
|
"total_segments": len(message_chunks),
|
||||||
})
|
"segment_time": segment_time,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# 按时间排序所有分段
|
# 按时间排序所有分段
|
||||||
all_segments.sort(key=lambda x: x['segment_time'])
|
all_segments.sort(key=lambda x: x["segment_time"])
|
||||||
|
|
||||||
print(f"\n按时间顺序处理 {len(all_segments)} 个分段:")
|
print(f"\n按时间顺序处理 {len(all_segments)} 个分段:")
|
||||||
|
|
||||||
# 按时间顺序处理所有分段
|
# 按时间顺序处理所有分段
|
||||||
for segment_idx, segment_info in enumerate(all_segments, 1):
|
for segment_idx, segment_info in enumerate(all_segments, 1):
|
||||||
chunk = segment_info['chunk']
|
chunk = segment_info["chunk"]
|
||||||
group_name = segment_info['group_name']
|
group_name = segment_info["group_name"]
|
||||||
segment_index = segment_info['segment_index']
|
segment_index = segment_info["segment_index"]
|
||||||
total_segments = segment_info['total_segments']
|
total_segments = segment_info["total_segments"]
|
||||||
segment_time = segment_info['segment_time']
|
segment_time = segment_info["segment_time"]
|
||||||
|
|
||||||
segment_time_str = datetime.fromtimestamp(segment_time).strftime('%Y-%m-%d %H:%M:%S')
|
segment_time_str = datetime.fromtimestamp(segment_time).strftime("%Y-%m-%d %H:%M:%S")
|
||||||
print(f" [{segment_idx}/{len(all_segments)}] {group_name} 第{segment_index}/{total_segments}段 ({segment_time_str}) (共{len(chunk)}条)")
|
print(
|
||||||
|
f" [{segment_idx}/{len(all_segments)}] {group_name} 第{segment_index}/{total_segments}段 ({segment_time_str}) (共{len(chunk)}条)"
|
||||||
# 构建名称映射
|
|
||||||
name_mapping = await self.build_name_mapping(chunk, target_person_id, target_person_name)
|
|
||||||
|
|
||||||
# 构建可读消息
|
|
||||||
readable_messages = self.build_focus_readable_messages(
|
|
||||||
messages=chunk,
|
|
||||||
target_person_id=target_person_id
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 构建名称映射
|
||||||
|
name_mapping = await build_name_mapping(chunk, target_person_name)
|
||||||
|
|
||||||
|
# 构建可读消息
|
||||||
|
readable_messages = build_focus_readable_messages(messages=chunk, target_person_id=target_person_id)
|
||||||
|
|
||||||
if not readable_messages:
|
if not readable_messages:
|
||||||
print(f" 跳过:该段落没有目标用户的消息")
|
print(" 跳过:该段落没有目标用户的消息")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 应用名称映射
|
# 应用名称映射
|
||||||
@@ -633,7 +733,7 @@ class MessageRetrievalScript:
                 total_segments_processed += 1
             except Exception as e:
                 logger.error(f"处理段落时出错: {e}")
-                print(f"  错误:处理该段落时出现异常")
+                print("  错误:处理该段落时出现异常")

         # 获取最终统计
         final_points = await person_info_manager.get_value(target_person_id, "points") or []
@@ -645,7 +745,7 @@ class MessageRetrievalScript:

         final_impression = await person_info_manager.get_value(target_person_id, "impression") or ""

-        print(f"\n=== 处理完成 ===")
+        print("\n=== 处理完成 ===")
         print(f"目标用户: {target_person_name} (QQ: {user_qq})")
         print(f"处理段落数: {total_segments_processed}")
         print(f"当前记忆点数: {len(final_points)}")
@@ -654,95 +754,6 @@ class MessageRetrievalScript:
|
|||||||
if final_points:
|
if final_points:
|
||||||
print(f"最新记忆点: {final_points[-1][0][:50]}...")
|
print(f"最新记忆点: {final_points[-1][0][:50]}...")
|
||||||
|
|
||||||
def display_chat_list(self, grouped_messages: Dict[str, List[Dict[str, Any]]]) -> None:
|
|
||||||
"""显示群聊列表"""
|
|
||||||
print("\n找到以下群聊:")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
for i, (chat_id, messages) in enumerate(grouped_messages.items(), 1):
|
|
||||||
first_msg = messages[0]
|
|
||||||
group_name = first_msg.get('chat_info_group_name', '私聊')
|
|
||||||
group_id = first_msg.get('chat_info_group_id', chat_id)
|
|
||||||
|
|
||||||
# 计算时间范围
|
|
||||||
start_time = datetime.fromtimestamp(messages[0]['time']).strftime('%Y-%m-%d')
|
|
||||||
end_time = datetime.fromtimestamp(messages[-1]['time']).strftime('%Y-%m-%d')
|
|
||||||
|
|
||||||
print(f"{i:2d}. {group_name}")
|
|
||||||
print(f" 群ID: {group_id}")
|
|
||||||
print(f" 消息数: {len(messages)}")
|
|
||||||
print(f" 时间范围: {start_time} ~ {end_time}")
|
|
||||||
print("-" * 60)
|
|
||||||
|
|
||||||
def get_user_selection(self, total_count: int) -> List[int]:
|
|
||||||
"""获取用户选择的群聊编号"""
|
|
||||||
while True:
|
|
||||||
print(f"\n请选择要分析的群聊 (1-{total_count}):")
|
|
||||||
print("输入格式:")
|
|
||||||
print(" 单个: 1")
|
|
||||||
print(" 多个: 1,3,5")
|
|
||||||
print(" 范围: 1-3")
|
|
||||||
print(" 全部: all 或 a")
|
|
||||||
print(" 退出: quit 或 q")
|
|
||||||
|
|
||||||
user_input = input("请输入选择: ").strip().lower()
|
|
||||||
|
|
||||||
if user_input in ['quit', 'q']:
|
|
||||||
return []
|
|
||||||
|
|
||||||
if user_input in ['all', 'a']:
|
|
||||||
return list(range(1, total_count + 1))
|
|
||||||
|
|
||||||
try:
|
|
||||||
selected = []
|
|
||||||
|
|
||||||
# 处理逗号分隔的输入
|
|
||||||
parts = user_input.split(',')
|
|
||||||
|
|
||||||
for part in parts:
|
|
||||||
part = part.strip()
|
|
||||||
|
|
||||||
if '-' in part:
|
|
||||||
# 处理范围输入 (如: 1-3)
|
|
||||||
start, end = part.split('-')
|
|
||||||
start_num = int(start.strip())
|
|
||||||
end_num = int(end.strip())
|
|
||||||
|
|
||||||
if 1 <= start_num <= total_count and 1 <= end_num <= total_count and start_num <= end_num:
|
|
||||||
selected.extend(range(start_num, end_num + 1))
|
|
||||||
else:
|
|
||||||
raise ValueError("范围超出有效范围")
|
|
||||||
else:
|
|
||||||
# 处理单个数字
|
|
||||||
num = int(part)
|
|
||||||
if 1 <= num <= total_count:
|
|
||||||
selected.append(num)
|
|
||||||
else:
|
|
||||||
raise ValueError("数字超出有效范围")
|
|
||||||
|
|
||||||
# 去重并排序
|
|
||||||
selected = sorted(list(set(selected)))
|
|
||||||
|
|
||||||
if selected:
|
|
||||||
return selected
|
|
||||||
else:
|
|
||||||
print("错误: 请输入有效的选择")
|
|
||||||
|
|
||||||
except ValueError as e:
|
|
||||||
print(f"错误: 输入格式无效 - {e}")
|
|
||||||
print("请重新输入")
|
|
||||||
|
|
||||||
def filter_selected_chats(self, grouped_messages: Dict[str, List[Dict[str, Any]]], selected_indices: List[int]) -> Dict[str, List[Dict[str, Any]]]:
|
|
||||||
"""根据用户选择过滤群聊"""
|
|
||||||
chat_items = list(grouped_messages.items())
|
|
||||||
selected_chats = {}
|
|
||||||
|
|
||||||
for idx in selected_indices:
|
|
||||||
chat_id, messages = chat_items[idx - 1] # 转换为0基索引
|
|
||||||
selected_chats[chat_id] = messages
|
|
||||||
|
|
||||||
return selected_chats
|
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
"""运行脚本"""
|
"""运行脚本"""
|
||||||
print("=== 消息检索分析脚本 ===")
|
print("=== 消息检索分析脚本 ===")
|
||||||
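This hunk removes display_chat_list, get_user_selection and filter_selected_chats from the class; a later hunk calls them as bare module-level names, so they presumably now live outside the class with the same behavior. For reference, a compact sketch of the selection-parsing logic that the removed get_user_selection implemented (a hypothetical standalone helper, not the project's actual function):

from typing import List


def parse_selection(user_input: str, total_count: int) -> List[int]:
    """Parse inputs like "1", "1,3,5", "1-3", "all"/"a", or "quit"/"q" into sorted 1-based indices."""
    text = user_input.strip().lower()
    if text in ("quit", "q"):
        return []
    if text in ("all", "a"):
        return list(range(1, total_count + 1))

    selected: List[int] = []
    for part in text.split(","):
        part = part.strip()
        if "-" in part:
            # Range input such as "1-3"
            start, end = part.split("-")
            start_num, end_num = int(start.strip()), int(end.strip())
            if not (1 <= start_num <= end_num <= total_count):
                raise ValueError("范围超出有效范围")
            selected.extend(range(start_num, end_num + 1))
        else:
            # Single number
            num = int(part)
            if not (1 <= num <= total_count):
                raise ValueError("数字超出有效范围")
            selected.append(num)
    # De-duplicate and sort, as the removed method did
    return sorted(set(selected))


# Example: parse_selection("1-3,5", 6) -> [1, 2, 3, 5]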
@@ -760,12 +771,7 @@ class MessageRetrievalScript:
         print("4. 最近1周 (1week)")
 
         choice = input("请选择时间段 (1-4): ").strip()
-        time_periods = {
-            "1": "all",
-            "2": "3months",
-            "3": "1month",
-            "4": "1week"
-        }
+        time_periods = {"1": "all", "2": "3months", "3": "1month", "4": "1week"}
 
         if choice not in time_periods:
             print("选择无效")
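The menu choice is mapped to a period key ("all", "3months", "1month", "1week"); how those keys become an actual time window is not shown in this diff. A typical conversion might look like the following hypothetical helper (the real script may derive the cutoff differently):

import time
from typing import Optional


def period_start_timestamp(period: str, now: Optional[float] = None) -> Optional[float]:
    """Map a period key to the earliest timestamp to include; None means no lower bound."""
    now = now if now is not None else time.time()
    seconds = {
        "1week": 7 * 24 * 3600,
        "1month": 30 * 24 * 3600,
        "3months": 90 * 24 * 3600,
    }
    if period == "all":
        return None
    return now - seconds[period]


# Example: keep only messages whose "time" field is >= period_start_timestamp("1week").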
@@ -792,28 +798,28 @@ class MessageRetrievalScript:
                 return
 
             # 显示群聊列表
-            self.display_chat_list(grouped_messages)
+            display_chat_list(grouped_messages)
 
             # 获取用户选择
-            selected_indices = self.get_user_selection(len(grouped_messages))
+            selected_indices = get_user_selection(len(grouped_messages))
 
             if not selected_indices:
                 print("已取消操作")
                 return
 
             # 过滤选中的群聊
-            selected_chats = self.filter_selected_chats(grouped_messages, selected_indices)
+            selected_chats = filter_selected_chats(grouped_messages, selected_indices)
 
             # 显示选中的群聊
             print(f"\n已选择 {len(selected_chats)} 个群聊进行分析:")
-            for i, (chat_id, messages) in enumerate(selected_chats.items(), 1):
+            for i, (_, messages) in enumerate(selected_chats.items(), 1):
                 first_msg = messages[0]
-                group_name = first_msg.get('chat_info_group_name', '私聊')
+                group_name = first_msg.get("chat_info_group_name", "私聊")
                 print(f" {i}. {group_name} ({len(messages)}条消息)")
 
             # 确认处理
-            confirm = input(f"\n确认分析这些群聊吗? (y/n): ").strip().lower()
-            if confirm != 'y':
+            confirm = input("\n确认分析这些群聊吗? (y/n): ").strip().lower()
+            if confirm != "y":
                 print("已取消操作")
                 return
 
@@ -823,16 +829,18 @@ class MessageRetrievalScript:
         except Exception as e:
             print(f"处理过程中出现错误: {e}")
             import traceback
 
             traceback.print_exc()
         finally:
             db.close()
             print("数据库连接已关闭")
 
 
 def main():
     """主函数"""
     script = MessageRetrievalScript()
     asyncio.run(script.run())
 
 
 if __name__ == "__main__":
     main()
@@ -1,7 +1,7 @@
 from typing import Optional, Tuple, Dict
 from src.common.logger_manager import get_logger
 from src.chat.message_receive.chat_stream import chat_manager
-from src.person_info.person_info import person_info_manager
+from src.person_info.person_info import person_info_manager, PersonInfoManager
 
 logger = get_logger("heartflow_utils")
 
@@ -47,7 +47,7 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
         # Try to fetch person info
         try:
             # Assume get_person_id is sync (as per original code), keep using to_thread
-            person_id = person_info_manager.get_person_id(platform, user_id)
+            person_id = PersonInfoManager.get_person_id(platform, user_id)
             person_name = None
             if person_id:
                 # get_value is async, so await it directly
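Across this file and the ones below, the commit switches get_person_id from being called on the person_info_manager singleton to being called on the PersonInfoManager class, which is consistent with it being a static helper that only needs (platform, user_id) and no instance state. A minimal sketch of that calling pattern; the md5-based ID shown here is an assumption for illustration, not the project's actual ID scheme:

import hashlib


class PersonInfoManager:
    @staticmethod
    def get_person_id(platform: str, user_id: str) -> str:
        # Deterministic ID derived from (platform, user_id); the real project may compute it differently.
        return hashlib.md5(f"{platform}_{user_id}".encode()).hexdigest()


# Singleton still used elsewhere for stateful, per-person data access (get_value, etc.).
person_info_manager = PersonInfoManager()

# Both calls return the same value, but calling via the class makes the lack of instance state explicit:
assert PersonInfoManager.get_person_id("qq", "12345") == person_info_manager.get_person_id("qq", "12345")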
@@ -6,7 +6,7 @@ from src.chat.message_receive.message import MessageThinking
 from src.chat.normal_chat.normal_prompt import prompt_builder
 from src.chat.utils.timer_calculator import Timer
 from src.common.logger_manager import get_logger
-from src.person_info.person_info import person_info_manager
+from src.person_info.person_info import person_info_manager, PersonInfoManager
 from src.chat.utils.utils import process_llm_response
 
 
@@ -66,7 +66,7 @@ class NormalChatGenerator:
         enable_planner: bool = False,
         available_actions=None,
     ):
-        person_id = person_info_manager.get_person_id(
+        person_id = PersonInfoManager.get_person_id(
             message.chat_stream.user_info.platform, message.chat_stream.user_info.user_id
         )
 
@@ -96,7 +96,7 @@ class BaseWillingManager(ABC):
         self.logger: LoguruLogger = logger
 
     def setup(self, message: MessageRecv, chat: ChatStream, is_mentioned_bot: bool, interested_rate: float):
-        person_id = person_info_manager.get_person_id(chat.platform, chat.user_info.user_id)
+        person_id = PersonInfoManager.get_person_id(chat.platform, chat.user_info.user_id)
         self.ongoing_messages[message.message_info.message_id] = WillingInfo(
             message=message,
             chat=chat,
@@ -4,7 +4,7 @@ import time  # 导入 time 模块以获取当前时间
 import random
 import re
 from src.common.message_repository import find_messages, count_messages
-from src.person_info.person_info import person_info_manager
+from src.person_info.person_info import person_info_manager, PersonInfoManager
 from src.chat.utils.utils import translate_timestamp_to_human_readable
 from rich.traceback import install
 from src.common.database.database_model import ActionRecords
@@ -219,7 +219,7 @@ def _build_readable_messages_internal(
         if not all([platform, user_id, timestamp is not None]):
             continue
 
-        person_id = person_info_manager.get_person_id(platform, user_id)
+        person_id = PersonInfoManager.get_person_id(platform, user_id)
         # 根据 replace_bot_name 参数决定是否替换机器人名称
         if replace_bot_name and user_id == global_config.bot.qq_account:
             person_name = f"{global_config.bot.nickname}(你)"
@@ -241,7 +241,7 @@ def _build_readable_messages_internal(
         if match:
             aaa = match.group(1)
             bbb = match.group(2)
-            reply_person_id = person_info_manager.get_person_id(platform, bbb)
+            reply_person_id = PersonInfoManager.get_person_id(platform, bbb)
             reply_person_name = person_info_manager.get_value_sync(reply_person_id, "person_name")
             if not reply_person_name:
                 reply_person_name = aaa
@@ -258,7 +258,7 @@ def _build_readable_messages_internal(
             new_content += content[last_end : m.start()]
             aaa = m.group(1)
             bbb = m.group(2)
-            at_person_id = person_info_manager.get_person_id(platform, bbb)
+            at_person_id = PersonInfoManager.get_person_id(platform, bbb)
             at_person_name = person_info_manager.get_value_sync(at_person_id, "person_name")
             if not at_person_name:
                 at_person_name = aaa
@@ -572,7 +572,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
             # print("SELF11111111111111")
             return "SELF"
         try:
-            person_id = person_info_manager.get_person_id(platform, user_id)
+            person_id = PersonInfoManager.get_person_id(platform, user_id)
         except Exception as _e:
             person_id = None
         if not person_id:
@@ -673,7 +673,7 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
         if not all([platform, user_id]) or user_id == global_config.bot.qq_account:
             continue
 
-        person_id = person_info_manager.get_person_id(platform, user_id)
+        person_id = PersonInfoManager.get_person_id(platform, user_id)
 
         # 只有当获取到有效 person_id 时才添加
         if person_id:
@@ -1,9 +1,9 @@
 from src.manager.async_task_manager import AsyncTask
 from src.common.logger_manager import get_logger
+from src.person_info.person_info import PersonInfoManager
 from src.person_info.relationship_manager import relationship_manager
 from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp
 from src.config.config import global_config
-from src.person_info.person_info import person_info_manager
 from src.chat.message_receive.chat_stream import chat_manager
 import time
 import random
@@ -95,7 +95,7 @@ class ImpressionUpdateTask(AsyncTask):
                 if msg["user_nickname"] == global_config.bot.nickname:
                     continue
 
-                person_id = person_info_manager.get_person_id(msg["chat_info_platform"], msg["user_id"])
+                person_id = PersonInfoManager.get_person_id(msg["chat_info_platform"], msg["user_id"])
                 if not person_id:
                     logger.warning(f"未找到用户 {msg['user_nickname']} 的person_id")
                     continue
@@ -1,6 +1,6 @@
 from src.common.logger_manager import get_logger
 import math
-from src.person_info.person_info import person_info_manager
+from src.person_info.person_info import person_info_manager, PersonInfoManager
 import time
 import random
 from src.llm_models.utils_model import LLMRequest
@@ -91,7 +91,7 @@ class RelationshipManager:
     @staticmethod
     async def first_knowing_some_one(platform: str, user_id: str, user_nickname: str, user_cardname: str):
         """判断是否认识某人"""
-        person_id = person_info_manager.get_person_id(platform, user_id)
+        person_id = PersonInfoManager.get_person_id(platform, user_id)
         # 生成唯一的 person_name
         unique_nickname = await person_info_manager._generate_unique_person_name(user_nickname)
         data = {
@@ -116,7 +116,7 @@ class RelationshipManager:
         if is_id:
             person_id = person
         else:
-            person_id = person_info_manager.get_person_id(person[0], person[1])
+            person_id = PersonInfoManager.get_person_id(person[0], person[1])
 
         person_name = await person_info_manager.get_value(person_id, "person_name")
         if not person_name or person_name == "none":
@@ -198,7 +198,7 @@ class RelationshipManager:
                 )
                 replace_user_id = msg.get("user_id")
                 replace_platform = msg.get("chat_info_platform")
-                replace_person_id = person_info_manager.get_person_id(replace_platform, replace_user_id)
+                replace_person_id = PersonInfoManager.get_person_id(replace_platform, replace_user_id)
                 replace_person_name = await person_info_manager.get_value(replace_person_id, "person_name")
 
                 # 跳过机器人自己
@@ -467,7 +467,7 @@ class RelationshipManager:
         for i, msg in enumerate(messages):
             user_id = msg.get("user_id")
             platform = msg.get("chat_info_platform")
-            person_id = person_info_manager.get_person_id(platform, user_id)
+            person_id = PersonInfoManager.get_person_id(platform, user_id)
             if person_id == target_person_id:
                 target_indices.append(i)