diff --git a/scripts/analyze_expression_similarity.py b/scripts/analyze_expression_similarity.py
new file mode 100644
index 000000000..1cdda3dd5
--- /dev/null
+++ b/scripts/analyze_expression_similarity.py
@@ -0,0 +1,181 @@
+import os
+import json
+from typing import List, Dict, Tuple
+import numpy as np
+from sklearn.feature_extraction.text import TfidfVectorizer
+from sklearn.metrics.pairwise import cosine_similarity
+import glob
+import sqlite3
+import re
+from datetime import datetime
+
+def clean_group_name(name: str) -> str:
+    """Clean a group name, keeping only Chinese and English characters."""
+    cleaned = re.sub(r'[^\u4e00-\u9fa5a-zA-Z]', '', name)
+    if not cleaned:
+        cleaned = datetime.now().strftime("%Y%m%d")
+    return cleaned
+
+def get_group_name(stream_id: str) -> str:
+    """Look up a readable group name for a stream_id from the database."""
+    conn = sqlite3.connect("data/maibot.db")
+    cursor = conn.cursor()
+
+    cursor.execute(
+        """
+        SELECT group_name, user_nickname, platform
+        FROM chat_streams
+        WHERE stream_id = ?
+        """,
+        (stream_id,),
+    )
+
+    result = cursor.fetchone()
+    conn.close()
+
+    if result:
+        group_name, user_nickname, platform = result
+        if group_name:
+            return clean_group_name(group_name)
+        if user_nickname:
+            return clean_group_name(user_nickname)
+        if platform:
+            return clean_group_name(f"{platform}{stream_id[:8]}")
+    return stream_id
+
+def format_timestamp(timestamp: float) -> str:
+    """Convert a timestamp into a human-readable datetime string."""
+    if not timestamp:
+        return "未知"
+    try:
+        dt = datetime.fromtimestamp(timestamp)
+        return dt.strftime("%Y-%m-%d %H:%M:%S")
+    except (ValueError, OSError, OverflowError):
+        return "未知"
+
+def load_expressions(chat_id: str) -> List[Dict]:
+    """Load the learnt style expressions of the given chat."""
+    style_file = os.path.join("data", "expression", "learnt_style", str(chat_id), "expressions.json")
+
+    style_exprs = []
+
+    if os.path.exists(style_file):
+        with open(style_file, "r", encoding="utf-8") as f:
+            style_exprs = json.load(f)
+
+    return style_exprs
+
+def find_similar_expressions(expressions: List[Dict], top_k: int = 5) -> Dict[int, Dict[str, List[Tuple[str, str, float]]]]:
+    """For every expression, find its top_k most similar expressions."""
+    if not expressions:
+        return {}
+
+    # Prepare the situation texts and the style texts separately
+    situations = [expr['situation'] for expr in expressions]
+    styles = [expr['style'] for expr in expressions]
+
+    # Vectorize with TF-IDF
+    vectorizer = TfidfVectorizer()
+    situation_matrix = vectorizer.fit_transform(situations)
+    style_matrix = vectorizer.fit_transform(styles)
+
+    # Compute cosine similarity
+    situation_similarity = cosine_similarity(situation_matrix)
+    style_similarity = cosine_similarity(style_matrix)
+
+    # For each expression, pick the top_k most similar ones
+    similar_expressions = {}
+    for i, expr in enumerate(expressions):
+        # Similarity scores for this expression
+        situation_scores = situation_similarity[i]
+        style_scores = style_similarity[i]
+
+        # Indices of the top_k matches (excluding the expression itself)
+        situation_indices = np.argsort(situation_scores)[::-1][1:top_k+1]
+        style_indices = np.argsort(style_scores)[::-1][1:top_k+1]
+
+        similar_situations = []
+        similar_styles = []
+
+        # Collect similar situations
+        for idx in situation_indices:
+            if situation_scores[idx] > 0:  # keep only entries with non-zero similarity
+                similar_situations.append((
+                    expressions[idx]['situation'],
+                    expressions[idx]['style'],  # also keep the corresponding style
+                    situation_scores[idx]
+                ))
+
+        # Collect similar styles
+        for idx in style_indices:
+            if style_scores[idx] > 0:  # keep only entries with non-zero similarity
+                similar_styles.append((
+                    expressions[idx]['style'],
+                    expressions[idx]['situation'],  # also keep the corresponding situation
+                    style_scores[idx]
+                ))
+
+        if similar_situations or similar_styles:
+            similar_expressions[i] = {
+                'situations': similar_situations,
+                'styles': similar_styles
+            }
+
+    return similar_expressions
+
+def main():
+    # Collect all chat IDs
+    style_dirs = glob.glob(os.path.join("data", "expression", "learnt_style", "*"))
+    chat_ids = [os.path.basename(d) for d in style_dirs]
+
+    if not chat_ids:
+        print("没有找到任何群聊的表达方式数据")
+        return
+
+    print("可用的群聊:")
+    for i, chat_id in enumerate(chat_ids, 1):
+        group_name = get_group_name(chat_id)
+        print(f"{i}. {group_name}")
+
+    while True:
+        try:
+            choice = int(input("\n请选择要分析的群聊编号 (输入0退出): "))
+            if choice == 0:
+                break
+            if 1 <= choice <= len(chat_ids):
+                chat_id = chat_ids[choice-1]
+                break
+            print("无效的选择,请重试")
+        except ValueError:
+            print("请输入有效的数字")
+
+    if choice == 0:
+        return
+
+    # Load the expressions of the selected chat
+    style_exprs = load_expressions(chat_id)
+
+    group_name = get_group_name(chat_id)
+    print(f"\n分析群聊 {group_name} 的表达方式:")
+
+    similar_styles = find_similar_expressions(style_exprs)
+    for i, expr in enumerate(style_exprs):
+        if i in similar_styles:
+            print("\n" + "-" * 20)
+            print(f"表达方式:{expr['style']} <---> 情景:{expr['situation']}")
+
+            if similar_styles[i]['styles']:
+                print("\n\033[33m相似表达:\033[0m")
+                for similar_style, original_situation, score in similar_styles[i]['styles']:
+                    print(f"\033[33m{similar_style},score:{score:.3f},对应情景:{original_situation}\033[0m")
+
+            if similar_styles[i]['situations']:
+                print("\n\033[32m相似情景:\033[0m")
+                for similar_situation, original_style, score in similar_styles[i]['situations']:
+                    print(f"\033[32m{similar_situation},score:{score:.3f},对应表达:{original_style}\033[0m")
+
+            print(f"\n激活值:{expr.get('count', 1):.3f},上次激活时间:{format_timestamp(expr.get('last_active_time'))}")
+            print("-" * 20)
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/scripts/cleanup_expressions.py b/scripts/cleanup_expressions.py
new file mode 100644
index 000000000..c5e66133a
--- /dev/null
+++ b/scripts/cleanup_expressions.py
@@ -0,0 +1,119 @@
+import os
+import json
+import random
+from typing import List, Dict, Tuple
+import glob
+from datetime import datetime
+
+MAX_EXPRESSION_COUNT = 300  # Maximum number of expressions to keep per chat
+MIN_COUNT_THRESHOLD = 0.01  # Minimum usage-count threshold
+
+def load_expressions(chat_id: str) -> Tuple[List[Dict], List[Dict]]:
+    """Load the learnt style and grammar expressions of the given chat."""
+    style_file = os.path.join("data", "expression", "learnt_style", str(chat_id), "expressions.json")
+    grammar_file = os.path.join("data", "expression", "learnt_grammar", str(chat_id), "expressions.json")
+
+    style_exprs = []
+    grammar_exprs = []
+
+    if os.path.exists(style_file):
+        with open(style_file, "r", encoding="utf-8") as f:
+            style_exprs = json.load(f)
+
+    if os.path.exists(grammar_file):
+        with open(grammar_file, "r", encoding="utf-8") as f:
+            grammar_exprs = json.load(f)
+
+    return style_exprs, grammar_exprs
+
+def save_expressions(chat_id: str, style_exprs: List[Dict], grammar_exprs: List[Dict]) -> None:
+    """Write the expressions back to their JSON files."""
+    style_file = os.path.join("data", "expression", "learnt_style", str(chat_id), "expressions.json")
+    grammar_file = os.path.join("data", "expression", "learnt_grammar", str(chat_id), "expressions.json")
+
+    os.makedirs(os.path.dirname(style_file), exist_ok=True)
+    os.makedirs(os.path.dirname(grammar_file), exist_ok=True)
+
+    with open(style_file, "w", encoding="utf-8") as f:
+        json.dump(style_exprs, f, ensure_ascii=False, indent=2)
+
+    with open(grammar_file, "w", encoding="utf-8") as f:
+        json.dump(grammar_exprs, f, ensure_ascii=False, indent=2)
+
+def cleanup_expressions(expressions: List[Dict]) -> List[Dict]:
+    """Prune an expression list."""
+    if not expressions:
+        return []
+
+    # 1. Drop expressions whose usage count is below the threshold
+    expressions = [expr for expr in expressions if expr.get("count", 0) > MIN_COUNT_THRESHOLD]
+
+    # 2. If still over the limit, keep the most-used half of the cap and randomly sample the rest
+    if len(expressions) > MAX_EXPRESSION_COUNT:
+        # Sort by usage count, most used first
+        expressions.sort(key=lambda x: x.get("count", 0), reverse=True)
+
+        # Keep the top half of the cap as high-frequency expressions
+        keep_count = MAX_EXPRESSION_COUNT // 2
+        keep_exprs = expressions[:keep_count]
+
+        # Fill the remaining slots with a random sample of the rest
+        remaining_exprs = expressions[keep_count:]
+        random.shuffle(remaining_exprs)
+        keep_exprs.extend(remaining_exprs[:MAX_EXPRESSION_COUNT - keep_count])
+
+        expressions = keep_exprs
+
+    return expressions
+
+def main():
+    # Collect all chat IDs
+    style_dirs = glob.glob(os.path.join("data", "expression", "learnt_style", "*"))
+    chat_ids = [os.path.basename(d) for d in style_dirs]
+
+    if not chat_ids:
+        print("没有找到任何群聊的表达方式数据")
+        return
+
+    print(f"开始清理 {len(chat_ids)} 个群聊的表达方式数据...")
+
+    total_style_before = 0
+    total_style_after = 0
+    total_grammar_before = 0
+    total_grammar_after = 0
+
+    for chat_id in chat_ids:
+        print(f"\n处理群聊 {chat_id}:")
+
+        # Load the expressions
+        style_exprs, grammar_exprs = load_expressions(chat_id)
+
+        # Record the counts before cleanup
+        style_count_before = len(style_exprs)
+        grammar_count_before = len(grammar_exprs)
+        total_style_before += style_count_before
+        total_grammar_before += grammar_count_before
+
+        # Prune both lists
+        style_exprs = cleanup_expressions(style_exprs)
+        grammar_exprs = cleanup_expressions(grammar_exprs)
+
+        # Record the counts after cleanup
+        style_count_after = len(style_exprs)
+        grammar_count_after = len(grammar_exprs)
+        total_style_after += style_count_after
+        total_grammar_after += grammar_count_after
+
+        # Save the pruned expressions
+        save_expressions(chat_id, style_exprs, grammar_exprs)
+
+        print(f"语言风格: {style_count_before} -> {style_count_after}")
+        print(f"句法特点: {grammar_count_before} -> {grammar_count_after}")
+
+    print("\n清理完成!")
+    print(f"语言风格总数: {total_style_before} -> {total_style_after}")
+    print(f"句法特点总数: {total_grammar_before} -> {total_grammar_after}")
+    print(f"总共清理了 {total_style_before + total_grammar_before - total_style_after - total_grammar_after} 条表达方式")
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/scripts/find_similar_expression.py b/scripts/find_similar_expression.py
new file mode 100644
index 000000000..21d34e1a8
--- /dev/null
+++ b/scripts/find_similar_expression.py
@@ -0,0 +1,251 @@
+import os
+import sys
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+import json
+from typing import List, Dict, Tuple
+import numpy as np
+from sklearn.feature_extraction.text import TfidfVectorizer
+from sklearn.metrics.pairwise import cosine_similarity
+import glob
+import sqlite3
+import re
+from datetime import datetime
+import random
+from src.llm_models.utils_model import LLMRequest
+from src.config.config import global_config
+
+def clean_group_name(name: str) -> str:
+    """Clean a group name, keeping only Chinese and English characters."""
+    cleaned = re.sub(r'[^\u4e00-\u9fa5a-zA-Z]', '', name)
+    if not cleaned:
+        cleaned = datetime.now().strftime("%Y%m%d")
+    return cleaned
+
+def get_group_name(stream_id: str) -> str:
+    """Look up a readable group name for a stream_id from the database."""
+    conn = sqlite3.connect("data/maibot.db")
+    cursor = conn.cursor()
+
+    cursor.execute(
+        """
+        SELECT group_name, user_nickname, platform
+        FROM chat_streams
+        WHERE stream_id = ?
+        """,
+        (stream_id,),
+    )
+
+    result = cursor.fetchone()
+    conn.close()
+
+    if result:
+        group_name, user_nickname, platform = result
+        if group_name:
+            return clean_group_name(group_name)
+        if user_nickname:
+            return clean_group_name(user_nickname)
+        if platform:
+            return clean_group_name(f"{platform}{stream_id[:8]}")
+    return stream_id
+
+def load_expressions(chat_id: str) -> List[Dict]:
+    """Load the learnt style expressions of the given chat."""
+    style_file = os.path.join("data", "expression", "learnt_style", str(chat_id), "expressions.json")
+
+    style_exprs = []
+
+    if os.path.exists(style_file):
+        with open(style_file, "r", encoding="utf-8") as f:
+            style_exprs = json.load(f)
+
+    # If there are more than 50 expressions, randomly sample 50 of them for matching
+    if len(style_exprs) > 50:
+        print(f"\n从 {len(style_exprs)} 个表达方式中随机选择了 50 个进行匹配")
+        style_exprs = random.sample(style_exprs, 50)
+
+    return style_exprs
+
+def find_similar_expressions_tfidf(input_text: str, expressions: List[Dict], mode: str = "both", top_k: int = 10) -> List[Tuple[str, str, float]]:
+    """Use TF-IDF to find the top_k expressions most similar to the input text."""
+    if not expressions:
+        return []
+
+    # Prepare the text data
+    if mode == "style":
+        texts = [expr['style'] for expr in expressions]
+    elif mode == "situation":
+        texts = [expr['situation'] for expr in expressions]
+    else:  # both
+        texts = [f"{expr['situation']} {expr['style']}" for expr in expressions]
+
+    texts.append(input_text)  # append the input text
+
+    # Vectorize with TF-IDF
+    vectorizer = TfidfVectorizer()
+    tfidf_matrix = vectorizer.fit_transform(texts)
+
+    # Compute cosine similarity
+    similarity_matrix = cosine_similarity(tfidf_matrix)
+
+    # Similarity scores of the input text (last row), excluding itself
+    scores = similarity_matrix[-1][:-1]
+
+    # Indices of the top_k matches
+    top_indices = np.argsort(scores)[::-1][:top_k]
+
+    # Collect the similar expressions
+    similar_exprs = []
+    for idx in top_indices:
+        if scores[idx] > 0:  # keep only entries with non-zero similarity
+            similar_exprs.append((
+                expressions[idx]['style'],
+                expressions[idx]['situation'],
+                scores[idx]
+            ))
+
+    return similar_exprs
+
+async def find_similar_expressions_embedding(input_text: str, expressions: List[Dict], mode: str = "both", top_k: int = 5) -> List[Tuple[str, str, float]]:
+    """Use the embedding model to find the top_k expressions most similar to the input text."""
+    if not expressions:
+        return []
+
+    # Prepare the text data
+    if mode == "style":
+        texts = [expr['style'] for expr in expressions]
+    elif mode == "situation":
+        texts = [expr['situation'] for expr in expressions]
+    else:  # both
+        texts = [f"{expr['situation']} {expr['style']}" for expr in expressions]
+
+    # Fetch embeddings, keeping each expression paired with its embedding
+    llm_request = LLMRequest(global_config.model.embedding)
+    embedded_exprs = []
+    for expr, text in zip(expressions, texts):
+        embedding = await llm_request.get_embedding(text)
+        if embedding:
+            embedded_exprs.append((expr, embedding))
+
+    input_embedding = await llm_request.get_embedding(input_text)
+    if not input_embedding or not embedded_exprs:
+        return []
+
+    # Compute cosine similarity against the input embedding
+    text_embeddings = np.array([emb for _, emb in embedded_exprs])
+    similarities = np.dot(text_embeddings, input_embedding) / (
+        np.linalg.norm(text_embeddings, axis=1) * np.linalg.norm(input_embedding)
+    )
+
+    # Indices of the top_k matches
+    top_indices = np.argsort(similarities)[::-1][:top_k]
+
+    # Collect the similar expressions
+    similar_exprs = []
+    for idx in top_indices:
+        if similarities[idx] > 0:  # keep only entries with non-zero similarity
+            similar_exprs.append((
+                embedded_exprs[idx][0]['style'],
+                embedded_exprs[idx][0]['situation'],
+                similarities[idx]
+            ))
+
+    return similar_exprs
+
+async def main():
+    # Collect all chat IDs
+    style_dirs = glob.glob(os.path.join("data", "expression", "learnt_style", "*"))
+    chat_ids = [os.path.basename(d) for d in style_dirs]
+
+    if not chat_ids:
+        print("没有找到任何群聊的表达方式数据")
+        return
+
+    print("可用的群聊:")
+    for i, chat_id in enumerate(chat_ids, 1):
+        group_name = get_group_name(chat_id)
+        print(f"{i}. {group_name}")
+
+    while True:
+        try:
+            choice = int(input("\n请选择要分析的群聊编号 (输入0退出): "))
+            if choice == 0:
+                break
+            if 1 <= choice <= len(chat_ids):
+                chat_id = chat_ids[choice-1]
+                break
+            print("无效的选择,请重试")
+        except ValueError:
+            print("请输入有效的数字")
+
+    if choice == 0:
+        return
+
+    # Load the expressions of the selected chat
+    style_exprs = load_expressions(chat_id)
+
+    group_name = get_group_name(chat_id)
+    print(f"\n已选择群聊:{group_name}")
+
+    # Choose the matching mode
+    print("\n请选择匹配模式:")
+    print("1. 匹配表达方式")
+    print("2. 匹配情景")
+    print("3. 两者都考虑")
+
+    while True:
+        try:
+            mode_choice = int(input("\n请选择匹配模式 (1-3): "))
+            if 1 <= mode_choice <= 3:
+                break
+            print("无效的选择,请重试")
+        except ValueError:
+            print("请输入有效的数字")
+
+    mode_map = {
+        1: "style",
+        2: "situation",
+        3: "both"
+    }
+    mode = mode_map[mode_choice]
+
+    # Choose the matching method
+    print("\n请选择匹配方法:")
+    print("1. TF-IDF方法")
+    print("2. 嵌入模型方法")
+
+    while True:
+        try:
+            method_choice = int(input("\n请选择匹配方法 (1-2): "))
+            if 1 <= method_choice <= 2:
+                break
+            print("无效的选择,请重试")
+        except ValueError:
+            print("请输入有效的数字")
+
+    while True:
+        input_text = input("\n请输入要匹配的文本(输入q退出): ")
+        if input_text.lower() == 'q':
+            break
+
+        if not input_text.strip():
+            continue
+
+        if method_choice == 1:
+            similar_exprs = find_similar_expressions_tfidf(input_text, style_exprs, mode)
+        else:
+            similar_exprs = await find_similar_expressions_embedding(input_text, style_exprs, mode)
+
+        if similar_exprs:
+            print("\n找到以下相似表达:")
+            for style, situation, score in similar_exprs:
+                print(f"\n\033[33m表达方式:{style}\033[0m")
+                print(f"\033[32m对应情景:{situation}\033[0m")
+                print(f"相似度:{score:.3f}")
+                print("-" * 20)
+        else:
+            print("\n没有找到相似的表达方式")
+
+if __name__ == "__main__":
+    import asyncio
+    asyncio.run(main())
\ No newline at end of file
diff --git a/src/chat/focus_chat/planners/actions/reply_action.py b/src/chat/focus_chat/planners/actions/reply_action.py
index b6ed69be0..20444ebc7 100644
--- a/src/chat/focus_chat/planners/actions/reply_action.py
+++ b/src/chat/focus_chat/planners/actions/reply_action.py
@@ -25,7 +25,8 @@ class ReplyAction(BaseAction):
     action_name: str = "reply"
     action_description: str = "当你想要参与回复或者聊天"
     action_parameters: dict[str:str] = {
-        "target": "如果你要明确回复特定某人的某句话,请在target参数中中指定那句话的原始文本(非必须,仅文本,不包含发送者)(可选)",
+        "reply_to": "如果是明确回复某个人的发言,请在reply_to参数中指定,格式:(用户名:发言内容),如果不是,reply_to的值设为none",
+        "emoji": "如果你想用表情包辅助你的回答,请在emoji参数中用文字描述你想要发送的表情包内容,如果没有,值设为空",
     }
     action_require: list[str] = [
         "你想要闲聊或者随便附和",
diff --git a/src/chat/focus_chat/planners/planner_simple.py b/src/chat/focus_chat/planners/planner_simple.py
index 518d21126..9fea4cebe 100644
--- a/src/chat/focus_chat/planners/planner_simple.py
+++ b/src/chat/focus_chat/planners/planner_simple.py
@@ -181,7 +181,7 @@ class ActionPlanner(BasePlanner):
         prompt = f"{prompt}"
         llm_content, (reasoning_content, _) = await self.planner_llm.generate_response_async(prompt=prompt)
 
-        logger.debug(f"{self.log_prefix}规划器原始提示词: {prompt}")
+        logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
         logger.info(f"{self.log_prefix}规划器原始响应: {llm_content}")
         logger.info(f"{self.log_prefix}规划器推理: {reasoning_content}")
 
diff --git a/src/chat/focus_chat/replyer/default_replyer.py b/src/chat/focus_chat/replyer/default_replyer.py
index a5b4592ad..8c477bed4 100644
--- a/src/chat/focus_chat/replyer/default_replyer.py
+++ b/src/chat/focus_chat/replyer/default_replyer.py
@@ -23,6 +23,9 @@ from src.chat.focus_chat.expressors.exprssion_learner import expression_learner
 import random
 from datetime import datetime
 import re
+from sklearn.feature_extraction.text import TfidfVectorizer
+from sklearn.metrics.pairwise import cosine_similarity
+import numpy as np
 
 logger = get_logger("replyer")
 
@@ -32,6 +35,7 @@ def init_prompt():
         """
 你可以参考以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中:
 {style_habbits}
+
 请你根据情景使用以下句法:
 {grammar_habbits}
 
@@ -40,15 +44,10 @@
 {relation_info_block}
 {time_block}
 
-你现在正在群里聊天,以下是群里正在进行的聊天内容:
-{chat_info}
-
-
-
-以上是聊天内容,你需要了解聊天记录中的内容
-
 {chat_target}
-{identity},在这聊天中,"{target_message}"引起了你的注意,你想要在群里发言或者回复这条消息。
+{chat_info}
+{reply_target_block}
+{identity}
 你需要使用合适的语言习惯和句法,参考聊天内容,组织一条日常且口语化的回复。注意不要复读你说过的话。
 {config_expression_style},请注意不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出回复内容。
 {keywords_reaction_prompt}
@@ -61,20 +60,17 @@
 
     Prompt(
         """
-{extra_info_block}
-
-{time_block}
-你现在正在聊天,以下是你和对方正在进行的聊天内容:
-{chat_info}
-
-以上是聊天内容,你需要了解聊天记录中的内容
-
-{chat_target}
-{identity},在这聊天中,"{target_message}"引起了你的注意,你想要发言或者回复这条消息。
-你需要使用合适的语法和句法,参考聊天内容,组织一条日常且口语化的回复。注意不要复读你说过的话。
-你可以参考以下的语言习惯和句法,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中:
 {style_habbits}
 {grammar_habbits}
+{extra_info_block}
+{time_block}
+{chat_target}
+{chat_info}
+现在"{sender_name}"说的:{target_message}。引起了你的注意,你想要发言或者回复这条消息。
+{identity},
+你需要使用合适的语法和句法,参考聊天内容,组织一条日常且口语化的回复。注意不要复读你说过的话。
+你可以参考以下的语言习惯和句法,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中:
+
 {config_expression_style},请注意不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出回复内容。
 {keywords_reaction_prompt}
 
@@ -155,11 +151,12 @@ class DefaultReplyer:
                 action_data=action_data,
             )
 
-        # with Timer("选择表情", cycle_timers):
-        #     emoji_keyword = action_data.get("emojis", [])
-        #     emoji_base64 = await self._choose_emoji(emoji_keyword)
-        #     if emoji_base64:
-        #         reply.append(("emoji", emoji_base64))
+        with Timer("选择表情", cycle_timers):
+            emoji_keyword = action_data.get("emoji", "")
+            if emoji_keyword:
+                emoji_base64 = await self._choose_emoji(emoji_keyword)
+                if emoji_base64:
+                    reply.append(("emoji", emoji_base64))
 
         if reply:
             with Timer("发送消息", cycle_timers):
@@ -251,23 +248,22 @@ class DefaultReplyer:
 
         # 2. 获取信息捕捉器
        info_catcher = info_catcher_manager.get_info_catcher(thinking_id)
-
-        # --- Determine sender_name for private chat ---
-        sender_name_for_prompt = "某人"  # Default for group or if info unavailable
-        if not self.is_group_chat and self.chat_target_info:
-            # Prioritize person_name, then nickname
-            sender_name_for_prompt = (
-                self.chat_target_info.get("person_name")
-                or self.chat_target_info.get("user_nickname")
-                or sender_name_for_prompt
-            )
-        # --- End determining sender_name ---
-
-        target_message = action_data.get("target", "")
+
+        reply_to = action_data.get("reply_to", "none")
+
+        sender = ""
+        target = ""
+        if ":" in reply_to or ":" in reply_to:
+            # 使用正则表达式匹配中文或英文冒号
+            parts = re.split(pattern=r'[::]', string=reply_to, maxsplit=1)
+            if len(parts) == 2:
+                sender = parts[0].strip()
+                target = parts[1].strip()
+
         identity = action_data.get("identity", "")
         extra_info_block = action_data.get("extra_info_block", "")
         relation_info_block = action_data.get("relation_info_block", "")
-
+
         # 3. 构建 Prompt
         with Timer("构建Prompt", {}):  # 内部计时器,可选保留
             prompt = await self.build_prompt_focus(
@@ -277,8 +273,8 @@ class DefaultReplyer:
                 extra_info_block=extra_info_block,
                 relation_info_block=relation_info_block,
                 reason=reason,
-                sender_name=sender_name_for_prompt,  # Pass determined name
-                target_message=target_message,
+                sender_name=sender,  # Pass determined name
+                target_message=target,
                 config_expression_style=global_config.expression.expression_style,
             )
 
@@ -340,6 +336,7 @@ class DefaultReplyer:
         identity,
         target_message,
         config_expression_style,
+        # situation,
     ) -> str:
         is_group_chat = bool(chat_stream.group_info)
 
@@ -368,15 +365,16 @@ class DefaultReplyer:
         grammar_habbits = []
         # 1. learnt_expressions加权随机选3条
         if learnt_style_expressions:
-            weights = [expr["count"] for expr in learnt_style_expressions]
-            selected_learnt = weighted_sample_no_replacement(learnt_style_expressions, weights, 4)
-            for expr in selected_learnt:
+            # 使用相似度匹配选择最相似的表达
+            similar_exprs = find_similar_expressions(target_message, learnt_style_expressions, 3)
+            for expr in similar_exprs:
+                logger.debug(f"expr: {expr}")
                 if isinstance(expr, dict) and "situation" in expr and "style" in expr:
                     style_habbits.append(f"当{expr['situation']}时,使用 {expr['style']}")
-        # 2. learnt_grammar_expressions加权随机选3条
+        # 2. learnt_grammar_expressions加权随机选2条
         if learnt_grammar_expressions:
             weights = [expr["count"] for expr in learnt_grammar_expressions]
-            selected_learnt = weighted_sample_no_replacement(learnt_grammar_expressions, weights, 4)
+            selected_learnt = weighted_sample_no_replacement(learnt_grammar_expressions, weights, 2)
             for expr in selected_learnt:
                 if isinstance(expr, dict) and "situation" in expr and "style" in expr:
                     grammar_habbits.append(f"当{expr['situation']}时,使用 {expr['style']}")
@@ -419,6 +417,16 @@ class DefaultReplyer:
         time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
 
         # logger.debug("开始构建 focus prompt")
+
+        if sender_name:
+            reply_target_block = f"现在{sender_name}说的:{target_message}。引起了你的注意,你想要在群里发言或者回复这条消息。"
+        elif target_message:
+            reply_target_block = f"现在{target_message}引起了你的注意,你想要在群里发言或者回复这条消息。"
+        else:
+            reply_target_block = "现在,你想要在群里发言或者回复消息。"
+
+
+
 
         # --- Choose template based on chat type ---
         if is_group_chat:
@@ -436,6 +444,7 @@ class DefaultReplyer:
                 extra_info_block=extra_info_block,
                 relation_info_block=relation_info_block,
                 time_block=time_block,
+                reply_target_block=reply_target_block,
                 # bot_name=global_config.bot.nickname,
                 # prompt_personality="",
                 # reason=reason,
@@ -443,6 +452,7 @@ class DefaultReplyer:
                 keywords_reaction_prompt=keywords_reaction_prompt,
                 identity=identity,
                 target_message=target_message,
+                sender_name=sender_name,
                 config_expression_style=config_expression_style,
             )
         else:  # Private chat
@@ -457,6 +467,7 @@ class DefaultReplyer:
                 extra_info_block=extra_info_block,
                 relation_info_block=relation_info_block,
                 time_block=time_block,
+                reply_target_block=reply_target_block,
                 # bot_name=global_config.bot.nickname,
                 # prompt_personality="",
                 # reason=reason,
@@ -464,6 +475,7 @@ class DefaultReplyer:
                 keywords_reaction_prompt=keywords_reaction_prompt,
                 identity=identity,
                 target_message=target_message,
+                sender_name=sender_name,
                 config_expression_style=config_expression_style,
             )
 
@@ -659,4 +671,35 @@ def weighted_sample_no_replacement(items, weights, k) -> list:
     return selected
 
 
+def find_similar_expressions(input_text: str, expressions: List[Dict], top_k: int = 3) -> List[Dict]:
+    """使用TF-IDF和余弦相似度找出与输入文本最相似的top_k个表达方式"""
+    if not expressions:
+        return []
+
+    # 准备文本数据
+    texts = [expr['situation'] for expr in expressions]
+    texts.append(input_text)  # 添加输入文本
+
+    # 使用TF-IDF向量化
+    vectorizer = TfidfVectorizer()
+    tfidf_matrix = vectorizer.fit_transform(texts)
+
+    # 计算余弦相似度
+    similarity_matrix = cosine_similarity(tfidf_matrix)
+
+    # 获取输入文本的相似度分数(最后一行)
+    scores = similarity_matrix[-1][:-1]  # 排除与自身的相似度
+
+    # 获取top_k的索引
+    top_indices = np.argsort(scores)[::-1][:top_k]
+
+    # 获取相似表达
+    similar_exprs = []
+    for idx in top_indices:
+        if scores[idx] > 0:  # 只保留有相似度的
+            similar_exprs.append(expressions[idx])
+
+    return similar_exprs
+
+
 init_prompt()