feat(expression): 添加聊天ID解析功能,支持哈希值和platform:raw_id:type格式

This commit is contained in:
minecraft1024a
2025-12-13 12:05:33 +08:00
parent 7d8ce8b246
commit 1cd1454289

View File

@@ -5,6 +5,7 @@
""" """
import csv import csv
import hashlib
import io import io
import math import math
import time import time
@@ -25,6 +26,51 @@ from src.config.config import global_config
logger = get_logger("expression_api") logger = get_logger("expression_api")
# ==================== 辅助函数 ====================
def parse_chat_id_input(chat_id_input: str) -> str:
"""
解析聊天ID输入支持两种格式
1. 哈希值格式(直接返回)
2. platform:raw_id:type 格式(转换为哈希值)
Args:
chat_id_input: 输入的chat_id可以是哈希值或 platform:raw_id:type 格式
Returns:
哈希值格式的chat_id
Examples:
>>> parse_chat_id_input("abc123def456") # 哈希值
"abc123def456"
>>> parse_chat_id_input("QQ:12345:group") # platform:id:type
"..." (转换后的哈希值)
"""
# 如果包含冒号,认为是 platform:id:type 格式
if ":" in chat_id_input:
parts = chat_id_input.split(":")
if len(parts) != 3:
raise ValueError(
f"无效的chat_id格式: {chat_id_input}"
"应为 'platform:raw_id:type' 格式,例如 'QQ:12345:group''QQ:67890:private'"
)
platform, raw_id, chat_type = parts
if chat_type not in ["group", "private"]:
raise ValueError(f"无效的chat_type: {chat_type},只支持 'group''private'")
# 使用与 ChatStream.get_stream_id 相同的逻辑生成哈希值
is_group = chat_type == "group"
components = [platform, raw_id] if is_group else [platform, raw_id, "private"]
key = "_".join(components)
return hashlib.sha256(key.encode()).hexdigest()
# 否则认为已经是哈希值
return chat_id_input
# ==================== 查询接口 ==================== # ==================== 查询接口 ====================
@@ -97,8 +143,25 @@ async def get_expression_list(
chat_manager = get_chat_manager() chat_manager = get_chat_manager()
for expr in expressions: for expr in expressions:
# 获取聊天流名称 # 获取聊天流名称和详细信息
chat_name = await chat_manager.get_stream_name(expr.chat_id) chat_name = await chat_manager.get_stream_name(expr.chat_id)
chat_stream = await chat_manager.get_stream(expr.chat_id)
# 构建格式化的chat_id信息
chat_id_display = expr.chat_id # 默认使用哈希值
platform = "未知"
raw_id = "未知"
chat_type = "未知"
if chat_stream:
platform = chat_stream.platform
if chat_stream.group_info:
raw_id = chat_stream.group_info.group_id
chat_type = "group"
elif chat_stream.user_info:
raw_id = chat_stream.user_info.user_id
chat_type = "private"
chat_id_display = f"{platform}:{raw_id}:{chat_type}"
expression_list.append( expression_list.append(
{ {
@@ -107,7 +170,11 @@ async def get_expression_list(
"style": expr.style, "style": expr.style,
"count": expr.count, "count": expr.count,
"last_active_time": expr.last_active_time, "last_active_time": expr.last_active_time,
"chat_id": expr.chat_id, "chat_id": expr.chat_id, # 保留哈希值用于后端操作
"chat_id_display": chat_id_display, # 显示用的格式化ID
"chat_platform": platform,
"chat_raw_id": raw_id,
"chat_type": chat_type,
"chat_name": chat_name or expr.chat_id, "chat_name": chat_name or expr.chat_id,
"type": expr.type, "type": expr.type,
"create_date": expr.create_date if expr.create_date else expr.last_active_time, "create_date": expr.create_date if expr.create_date else expr.last_active_time,
@@ -155,9 +222,26 @@ async def get_expression_detail(expression_id: int) -> dict[str, Any] | None:
if not expr: if not expr:
return None return None
# 获取聊天流名称 # 获取聊天流名称和详细信息
chat_manager = get_chat_manager() chat_manager = get_chat_manager()
chat_name = await chat_manager.get_stream_name(expr.chat_id) chat_name = await chat_manager.get_stream_name(expr.chat_id)
chat_stream = await chat_manager.get_stream(expr.chat_id)
# 构建格式化的chat_id信息
chat_id_display = expr.chat_id
platform = "未知"
raw_id = "未知"
chat_type = "未知"
if chat_stream:
platform = chat_stream.platform
if chat_stream.group_info:
raw_id = chat_stream.group_info.group_id
chat_type = "group"
elif chat_stream.user_info:
raw_id = chat_stream.user_info.user_id
chat_type = "private"
chat_id_display = f"{platform}:{raw_id}:{chat_type}"
# 计算使用统计 # 计算使用统计
days_since_create = (time.time() - (expr.create_date or expr.last_active_time)) / 86400 days_since_create = (time.time() - (expr.create_date or expr.last_active_time)) / 86400
@@ -170,6 +254,10 @@ async def get_expression_detail(expression_id: int) -> dict[str, Any] | None:
"count": expr.count, "count": expr.count,
"last_active_time": expr.last_active_time, "last_active_time": expr.last_active_time,
"chat_id": expr.chat_id, "chat_id": expr.chat_id,
"chat_id_display": chat_id_display,
"chat_platform": platform,
"chat_raw_id": raw_id,
"chat_type": chat_type,
"chat_name": chat_name or expr.chat_id, "chat_name": chat_name or expr.chat_id,
"type": expr.type, "type": expr.type,
"create_date": expr.create_date if expr.create_date else expr.last_active_time, "create_date": expr.create_date if expr.create_date else expr.last_active_time,
@@ -360,10 +448,21 @@ async def create_expression(
""" """
手动创建表达方式 手动创建表达方式
Args:
situation: 情境描述
style: 表达风格
chat_id: 聊天流ID支持两种格式
- 哈希值格式(如: "abc123def456..."
- platform:raw_id:type 格式(如: "QQ:12345:group""QQ:67890:private"
type: 表达类型
count: 权重
Returns: Returns:
创建的表达方式详情 创建的表达方式详情
""" """
try: try:
# 解析并转换chat_id
chat_id_hash = parse_chat_id_input(chat_id)
current_time = time.time() current_time = time.time()
async with get_db_session() as session: async with get_db_session() as session:
@@ -371,7 +470,7 @@ async def create_expression(
existing_query = await session.execute( existing_query = await session.execute(
select(Expression).where( select(Expression).where(
and_( and_(
Expression.chat_id == chat_id, Expression.chat_id == chat_id_hash,
Expression.type == type, Expression.type == type,
Expression.situation == situation, Expression.situation == situation,
Expression.style == style, Expression.style == style,
@@ -389,7 +488,7 @@ async def create_expression(
style=style, style=style,
count=count, count=count,
last_active_time=current_time, last_active_time=current_time,
chat_id=chat_id, chat_id=chat_id_hash,
type=type, type=type,
create_date=current_time, create_date=current_time,
) )
@@ -400,9 +499,9 @@ async def create_expression(
# 清除缓存 # 清除缓存
cache = await get_cache() cache = await get_cache()
await cache.delete(generate_cache_key("chat_expressions", chat_id)) await cache.delete(generate_cache_key("chat_expressions", chat_id_hash))
logger.info(f"创建表达方式成功: {situation} -> {style}") logger.info(f"创建表达方式成功: {situation} -> {style} (chat_id={chat_id_hash})")
return await get_expression_detail(new_expression.id) # type: ignore return await get_expression_detail(new_expression.id) # type: ignore
@@ -564,7 +663,9 @@ async def trigger_learning(
手动触发学习 手动触发学习
Args: Args:
chat_id: 聊天流ID chat_id: 聊天流ID,支持两种格式:
- 哈希值格式(如: "abc123def456..."
- platform:raw_id:type 格式(如: "QQ:12345:group""QQ:67890:private"
type: 学习类型 type: 学习类型
force: 是否强制学习(忽略时间和消息数量限制) force: 是否强制学习(忽略时间和消息数量限制)
@@ -577,7 +678,10 @@ async def trigger_learning(
} }
""" """
try: try:
learner = ExpressionLearner(chat_id) # 解析并转换chat_id
chat_id_hash = parse_chat_id_input(chat_id)
learner = ExpressionLearner(chat_id_hash)
await learner._initialize_chat_name() await learner._initialize_chat_name()
# 检查是否允许学习 # 检查是否允许学习
@@ -621,6 +725,11 @@ async def get_learning_status(chat_id: str) -> dict[str, Any]:
""" """
获取学习状态 获取学习状态
Args:
chat_id: 聊天流ID支持两种格式
- 哈希值格式(如: "abc123def456..."
- platform:raw_id:type 格式(如: "QQ:12345:group""QQ:67890:private"
Returns: Returns:
{ {
"can_learn": true, "can_learn": true,
@@ -632,7 +741,10 @@ async def get_learning_status(chat_id: str) -> dict[str, Any]:
} }
""" """
try: try:
learner = ExpressionLearner(chat_id) # 解析并转换chat_id
chat_id_hash = parse_chat_id_input(chat_id)
learner = ExpressionLearner(chat_id_hash)
await learner._initialize_chat_name() await learner._initialize_chat_name()
# 获取配置 # 获取配置
@@ -640,7 +752,7 @@ async def get_learning_status(chat_id: str) -> dict[str, Any]:
raise RuntimeError("Global config is not initialized") raise RuntimeError("Global config is not initialized")
_use_expression, enable_learning, learning_intensity = global_config.expression.get_expression_config_for_chat( _use_expression, enable_learning, learning_intensity = global_config.expression.get_expression_config_for_chat(
chat_id chat_id_hash
) )
can_learn = learner.can_learn_for_chat() can_learn = learner.can_learn_for_chat()
@@ -655,7 +767,7 @@ async def get_learning_status(chat_id: str) -> dict[str, Any]:
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive
recent_messages = await get_raw_msg_by_timestamp_with_chat_inclusive( recent_messages = await get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=chat_id, chat_id=chat_id_hash,
timestamp_start=learner.last_learning_time, timestamp_start=learner.last_learning_time,
timestamp_end=time.time(), timestamp_end=time.time(),
filter_bot=True, filter_bot=True,