From 1cd1454289c95a650e4d81ea65987867a8bb0b8b Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Sat, 13 Dec 2025 12:05:33 +0800 Subject: [PATCH] =?UTF-8?q?feat(expression):=20=E6=B7=BB=E5=8A=A0=E8=81=8A?= =?UTF-8?q?=E5=A4=A9ID=E8=A7=A3=E6=9E=90=E5=8A=9F=E8=83=BD=EF=BC=8C?= =?UTF-8?q?=E6=94=AF=E6=8C=81=E5=93=88=E5=B8=8C=E5=80=BC=E5=92=8Cplatform:?= =?UTF-8?q?raw=5Fid:type=E6=A0=BC=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/plugin_system/apis/expression_api.py | 136 +++++++++++++++++++++-- 1 file changed, 124 insertions(+), 12 deletions(-) diff --git a/src/plugin_system/apis/expression_api.py b/src/plugin_system/apis/expression_api.py index c50cf5f1b..b95990629 100644 --- a/src/plugin_system/apis/expression_api.py +++ b/src/plugin_system/apis/expression_api.py @@ -5,6 +5,7 @@ """ import csv +import hashlib import io import math import time @@ -25,6 +26,51 @@ from src.config.config import global_config logger = get_logger("expression_api") +# ==================== 辅助函数 ==================== + + +def parse_chat_id_input(chat_id_input: str) -> str: + """ + 解析聊天ID输入,支持两种格式: + 1. 哈希值格式(直接返回) + 2. platform:raw_id:type 格式(转换为哈希值) + + Args: + chat_id_input: 输入的chat_id,可以是哈希值或 platform:raw_id:type 格式 + + Returns: + 哈希值格式的chat_id + + Examples: + >>> parse_chat_id_input("abc123def456") # 哈希值 + "abc123def456" + >>> parse_chat_id_input("QQ:12345:group") # platform:id:type + "..." (转换后的哈希值) + """ + # 如果包含冒号,认为是 platform:id:type 格式 + if ":" in chat_id_input: + parts = chat_id_input.split(":") + if len(parts) != 3: + raise ValueError( + f"无效的chat_id格式: {chat_id_input}," + "应为 'platform:raw_id:type' 格式,例如 'QQ:12345:group' 或 'QQ:67890:private'" + ) + + platform, raw_id, chat_type = parts + + if chat_type not in ["group", "private"]: + raise ValueError(f"无效的chat_type: {chat_type},只支持 'group' 或 'private'") + + # 使用与 ChatStream.get_stream_id 相同的逻辑生成哈希值 + is_group = chat_type == "group" + components = [platform, raw_id] if is_group else [platform, raw_id, "private"] + key = "_".join(components) + return hashlib.sha256(key.encode()).hexdigest() + + # 否则认为已经是哈希值 + return chat_id_input + + # ==================== 查询接口 ==================== @@ -97,8 +143,25 @@ async def get_expression_list( chat_manager = get_chat_manager() for expr in expressions: - # 获取聊天流名称 + # 获取聊天流名称和详细信息 chat_name = await chat_manager.get_stream_name(expr.chat_id) + chat_stream = await chat_manager.get_stream(expr.chat_id) + + # 构建格式化的chat_id信息 + chat_id_display = expr.chat_id # 默认使用哈希值 + platform = "未知" + raw_id = "未知" + chat_type = "未知" + + if chat_stream: + platform = chat_stream.platform + if chat_stream.group_info: + raw_id = chat_stream.group_info.group_id + chat_type = "group" + elif chat_stream.user_info: + raw_id = chat_stream.user_info.user_id + chat_type = "private" + chat_id_display = f"{platform}:{raw_id}:{chat_type}" expression_list.append( { @@ -107,7 +170,11 @@ async def get_expression_list( "style": expr.style, "count": expr.count, "last_active_time": expr.last_active_time, - "chat_id": expr.chat_id, + "chat_id": expr.chat_id, # 保留哈希值用于后端操作 + "chat_id_display": chat_id_display, # 显示用的格式化ID + "chat_platform": platform, + "chat_raw_id": raw_id, + "chat_type": chat_type, "chat_name": chat_name or expr.chat_id, "type": expr.type, "create_date": expr.create_date if expr.create_date else expr.last_active_time, @@ -155,9 +222,26 @@ async def get_expression_detail(expression_id: int) -> dict[str, Any] | None: if not expr: return None - # 获取聊天流名称 + # 获取聊天流名称和详细信息 chat_manager = get_chat_manager() chat_name = await chat_manager.get_stream_name(expr.chat_id) + chat_stream = await chat_manager.get_stream(expr.chat_id) + + # 构建格式化的chat_id信息 + chat_id_display = expr.chat_id + platform = "未知" + raw_id = "未知" + chat_type = "未知" + + if chat_stream: + platform = chat_stream.platform + if chat_stream.group_info: + raw_id = chat_stream.group_info.group_id + chat_type = "group" + elif chat_stream.user_info: + raw_id = chat_stream.user_info.user_id + chat_type = "private" + chat_id_display = f"{platform}:{raw_id}:{chat_type}" # 计算使用统计 days_since_create = (time.time() - (expr.create_date or expr.last_active_time)) / 86400 @@ -170,6 +254,10 @@ async def get_expression_detail(expression_id: int) -> dict[str, Any] | None: "count": expr.count, "last_active_time": expr.last_active_time, "chat_id": expr.chat_id, + "chat_id_display": chat_id_display, + "chat_platform": platform, + "chat_raw_id": raw_id, + "chat_type": chat_type, "chat_name": chat_name or expr.chat_id, "type": expr.type, "create_date": expr.create_date if expr.create_date else expr.last_active_time, @@ -360,10 +448,21 @@ async def create_expression( """ 手动创建表达方式 + Args: + situation: 情境描述 + style: 表达风格 + chat_id: 聊天流ID,支持两种格式: + - 哈希值格式(如: "abc123def456...") + - platform:raw_id:type 格式(如: "QQ:12345:group" 或 "QQ:67890:private") + type: 表达类型 + count: 权重 + Returns: 创建的表达方式详情 """ try: + # 解析并转换chat_id + chat_id_hash = parse_chat_id_input(chat_id) current_time = time.time() async with get_db_session() as session: @@ -371,7 +470,7 @@ async def create_expression( existing_query = await session.execute( select(Expression).where( and_( - Expression.chat_id == chat_id, + Expression.chat_id == chat_id_hash, Expression.type == type, Expression.situation == situation, Expression.style == style, @@ -389,7 +488,7 @@ async def create_expression( style=style, count=count, last_active_time=current_time, - chat_id=chat_id, + chat_id=chat_id_hash, type=type, create_date=current_time, ) @@ -400,9 +499,9 @@ async def create_expression( # 清除缓存 cache = await get_cache() - await cache.delete(generate_cache_key("chat_expressions", chat_id)) + await cache.delete(generate_cache_key("chat_expressions", chat_id_hash)) - logger.info(f"创建表达方式成功: {situation} -> {style}") + logger.info(f"创建表达方式成功: {situation} -> {style} (chat_id={chat_id_hash})") return await get_expression_detail(new_expression.id) # type: ignore @@ -564,7 +663,9 @@ async def trigger_learning( 手动触发学习 Args: - chat_id: 聊天流ID + chat_id: 聊天流ID,支持两种格式: + - 哈希值格式(如: "abc123def456...") + - platform:raw_id:type 格式(如: "QQ:12345:group" 或 "QQ:67890:private") type: 学习类型 force: 是否强制学习(忽略时间和消息数量限制) @@ -577,7 +678,10 @@ async def trigger_learning( } """ try: - learner = ExpressionLearner(chat_id) + # 解析并转换chat_id + chat_id_hash = parse_chat_id_input(chat_id) + + learner = ExpressionLearner(chat_id_hash) await learner._initialize_chat_name() # 检查是否允许学习 @@ -621,6 +725,11 @@ async def get_learning_status(chat_id: str) -> dict[str, Any]: """ 获取学习状态 + Args: + chat_id: 聊天流ID,支持两种格式: + - 哈希值格式(如: "abc123def456...") + - platform:raw_id:type 格式(如: "QQ:12345:group" 或 "QQ:67890:private") + Returns: { "can_learn": true, @@ -632,7 +741,10 @@ async def get_learning_status(chat_id: str) -> dict[str, Any]: } """ try: - learner = ExpressionLearner(chat_id) + # 解析并转换chat_id + chat_id_hash = parse_chat_id_input(chat_id) + + learner = ExpressionLearner(chat_id_hash) await learner._initialize_chat_name() # 获取配置 @@ -640,7 +752,7 @@ async def get_learning_status(chat_id: str) -> dict[str, Any]: raise RuntimeError("Global config is not initialized") _use_expression, enable_learning, learning_intensity = global_config.expression.get_expression_config_for_chat( - chat_id + chat_id_hash ) can_learn = learner.can_learn_for_chat() @@ -655,7 +767,7 @@ async def get_learning_status(chat_id: str) -> dict[str, Any]: from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive recent_messages = await get_raw_msg_by_timestamp_with_chat_inclusive( - chat_id=chat_id, + chat_id=chat_id_hash, timestamp_start=learner.last_learning_time, timestamp_end=time.time(), filter_bot=True,