feat(expression): 添加聊天ID解析功能,支持哈希值和platform:raw_id:type格式
This commit is contained in:
@@ -5,6 +5,7 @@
|
||||
"""
|
||||
|
||||
import csv
|
||||
import hashlib
|
||||
import io
|
||||
import math
|
||||
import time
|
||||
@@ -25,6 +26,51 @@ from src.config.config import global_config
|
||||
logger = get_logger("expression_api")
|
||||
|
||||
|
||||
# ==================== 辅助函数 ====================
|
||||
|
||||
|
||||
def parse_chat_id_input(chat_id_input: str) -> str:
|
||||
"""
|
||||
解析聊天ID输入,支持两种格式:
|
||||
1. 哈希值格式(直接返回)
|
||||
2. platform:raw_id:type 格式(转换为哈希值)
|
||||
|
||||
Args:
|
||||
chat_id_input: 输入的chat_id,可以是哈希值或 platform:raw_id:type 格式
|
||||
|
||||
Returns:
|
||||
哈希值格式的chat_id
|
||||
|
||||
Examples:
|
||||
>>> parse_chat_id_input("abc123def456") # 哈希值
|
||||
"abc123def456"
|
||||
>>> parse_chat_id_input("QQ:12345:group") # platform:id:type
|
||||
"..." (转换后的哈希值)
|
||||
"""
|
||||
# 如果包含冒号,认为是 platform:id:type 格式
|
||||
if ":" in chat_id_input:
|
||||
parts = chat_id_input.split(":")
|
||||
if len(parts) != 3:
|
||||
raise ValueError(
|
||||
f"无效的chat_id格式: {chat_id_input},"
|
||||
"应为 'platform:raw_id:type' 格式,例如 'QQ:12345:group' 或 'QQ:67890:private'"
|
||||
)
|
||||
|
||||
platform, raw_id, chat_type = parts
|
||||
|
||||
if chat_type not in ["group", "private"]:
|
||||
raise ValueError(f"无效的chat_type: {chat_type},只支持 'group' 或 'private'")
|
||||
|
||||
# 使用与 ChatStream.get_stream_id 相同的逻辑生成哈希值
|
||||
is_group = chat_type == "group"
|
||||
components = [platform, raw_id] if is_group else [platform, raw_id, "private"]
|
||||
key = "_".join(components)
|
||||
return hashlib.sha256(key.encode()).hexdigest()
|
||||
|
||||
# 否则认为已经是哈希值
|
||||
return chat_id_input
|
||||
|
||||
|
||||
# ==================== 查询接口 ====================
|
||||
|
||||
|
||||
@@ -97,8 +143,25 @@ async def get_expression_list(
|
||||
chat_manager = get_chat_manager()
|
||||
|
||||
for expr in expressions:
|
||||
# 获取聊天流名称
|
||||
# 获取聊天流名称和详细信息
|
||||
chat_name = await chat_manager.get_stream_name(expr.chat_id)
|
||||
chat_stream = await chat_manager.get_stream(expr.chat_id)
|
||||
|
||||
# 构建格式化的chat_id信息
|
||||
chat_id_display = expr.chat_id # 默认使用哈希值
|
||||
platform = "未知"
|
||||
raw_id = "未知"
|
||||
chat_type = "未知"
|
||||
|
||||
if chat_stream:
|
||||
platform = chat_stream.platform
|
||||
if chat_stream.group_info:
|
||||
raw_id = chat_stream.group_info.group_id
|
||||
chat_type = "group"
|
||||
elif chat_stream.user_info:
|
||||
raw_id = chat_stream.user_info.user_id
|
||||
chat_type = "private"
|
||||
chat_id_display = f"{platform}:{raw_id}:{chat_type}"
|
||||
|
||||
expression_list.append(
|
||||
{
|
||||
@@ -107,7 +170,11 @@ async def get_expression_list(
|
||||
"style": expr.style,
|
||||
"count": expr.count,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"chat_id": expr.chat_id,
|
||||
"chat_id": expr.chat_id, # 保留哈希值用于后端操作
|
||||
"chat_id_display": chat_id_display, # 显示用的格式化ID
|
||||
"chat_platform": platform,
|
||||
"chat_raw_id": raw_id,
|
||||
"chat_type": chat_type,
|
||||
"chat_name": chat_name or expr.chat_id,
|
||||
"type": expr.type,
|
||||
"create_date": expr.create_date if expr.create_date else expr.last_active_time,
|
||||
@@ -155,9 +222,26 @@ async def get_expression_detail(expression_id: int) -> dict[str, Any] | None:
|
||||
if not expr:
|
||||
return None
|
||||
|
||||
# 获取聊天流名称
|
||||
# 获取聊天流名称和详细信息
|
||||
chat_manager = get_chat_manager()
|
||||
chat_name = await chat_manager.get_stream_name(expr.chat_id)
|
||||
chat_stream = await chat_manager.get_stream(expr.chat_id)
|
||||
|
||||
# 构建格式化的chat_id信息
|
||||
chat_id_display = expr.chat_id
|
||||
platform = "未知"
|
||||
raw_id = "未知"
|
||||
chat_type = "未知"
|
||||
|
||||
if chat_stream:
|
||||
platform = chat_stream.platform
|
||||
if chat_stream.group_info:
|
||||
raw_id = chat_stream.group_info.group_id
|
||||
chat_type = "group"
|
||||
elif chat_stream.user_info:
|
||||
raw_id = chat_stream.user_info.user_id
|
||||
chat_type = "private"
|
||||
chat_id_display = f"{platform}:{raw_id}:{chat_type}"
|
||||
|
||||
# 计算使用统计
|
||||
days_since_create = (time.time() - (expr.create_date or expr.last_active_time)) / 86400
|
||||
@@ -170,6 +254,10 @@ async def get_expression_detail(expression_id: int) -> dict[str, Any] | None:
|
||||
"count": expr.count,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"chat_id": expr.chat_id,
|
||||
"chat_id_display": chat_id_display,
|
||||
"chat_platform": platform,
|
||||
"chat_raw_id": raw_id,
|
||||
"chat_type": chat_type,
|
||||
"chat_name": chat_name or expr.chat_id,
|
||||
"type": expr.type,
|
||||
"create_date": expr.create_date if expr.create_date else expr.last_active_time,
|
||||
@@ -360,10 +448,21 @@ async def create_expression(
|
||||
"""
|
||||
手动创建表达方式
|
||||
|
||||
Args:
|
||||
situation: 情境描述
|
||||
style: 表达风格
|
||||
chat_id: 聊天流ID,支持两种格式:
|
||||
- 哈希值格式(如: "abc123def456...")
|
||||
- platform:raw_id:type 格式(如: "QQ:12345:group" 或 "QQ:67890:private")
|
||||
type: 表达类型
|
||||
count: 权重
|
||||
|
||||
Returns:
|
||||
创建的表达方式详情
|
||||
"""
|
||||
try:
|
||||
# 解析并转换chat_id
|
||||
chat_id_hash = parse_chat_id_input(chat_id)
|
||||
current_time = time.time()
|
||||
|
||||
async with get_db_session() as session:
|
||||
@@ -371,7 +470,7 @@ async def create_expression(
|
||||
existing_query = await session.execute(
|
||||
select(Expression).where(
|
||||
and_(
|
||||
Expression.chat_id == chat_id,
|
||||
Expression.chat_id == chat_id_hash,
|
||||
Expression.type == type,
|
||||
Expression.situation == situation,
|
||||
Expression.style == style,
|
||||
@@ -389,7 +488,7 @@ async def create_expression(
|
||||
style=style,
|
||||
count=count,
|
||||
last_active_time=current_time,
|
||||
chat_id=chat_id,
|
||||
chat_id=chat_id_hash,
|
||||
type=type,
|
||||
create_date=current_time,
|
||||
)
|
||||
@@ -400,9 +499,9 @@ async def create_expression(
|
||||
|
||||
# 清除缓存
|
||||
cache = await get_cache()
|
||||
await cache.delete(generate_cache_key("chat_expressions", chat_id))
|
||||
await cache.delete(generate_cache_key("chat_expressions", chat_id_hash))
|
||||
|
||||
logger.info(f"创建表达方式成功: {situation} -> {style}")
|
||||
logger.info(f"创建表达方式成功: {situation} -> {style} (chat_id={chat_id_hash})")
|
||||
|
||||
return await get_expression_detail(new_expression.id) # type: ignore
|
||||
|
||||
@@ -564,7 +663,9 @@ async def trigger_learning(
|
||||
手动触发学习
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
chat_id: 聊天流ID,支持两种格式:
|
||||
- 哈希值格式(如: "abc123def456...")
|
||||
- platform:raw_id:type 格式(如: "QQ:12345:group" 或 "QQ:67890:private")
|
||||
type: 学习类型
|
||||
force: 是否强制学习(忽略时间和消息数量限制)
|
||||
|
||||
@@ -577,7 +678,10 @@ async def trigger_learning(
|
||||
}
|
||||
"""
|
||||
try:
|
||||
learner = ExpressionLearner(chat_id)
|
||||
# 解析并转换chat_id
|
||||
chat_id_hash = parse_chat_id_input(chat_id)
|
||||
|
||||
learner = ExpressionLearner(chat_id_hash)
|
||||
await learner._initialize_chat_name()
|
||||
|
||||
# 检查是否允许学习
|
||||
@@ -621,6 +725,11 @@ async def get_learning_status(chat_id: str) -> dict[str, Any]:
|
||||
"""
|
||||
获取学习状态
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID,支持两种格式:
|
||||
- 哈希值格式(如: "abc123def456...")
|
||||
- platform:raw_id:type 格式(如: "QQ:12345:group" 或 "QQ:67890:private")
|
||||
|
||||
Returns:
|
||||
{
|
||||
"can_learn": true,
|
||||
@@ -632,7 +741,10 @@ async def get_learning_status(chat_id: str) -> dict[str, Any]:
|
||||
}
|
||||
"""
|
||||
try:
|
||||
learner = ExpressionLearner(chat_id)
|
||||
# 解析并转换chat_id
|
||||
chat_id_hash = parse_chat_id_input(chat_id)
|
||||
|
||||
learner = ExpressionLearner(chat_id_hash)
|
||||
await learner._initialize_chat_name()
|
||||
|
||||
# 获取配置
|
||||
@@ -640,7 +752,7 @@ async def get_learning_status(chat_id: str) -> dict[str, Any]:
|
||||
raise RuntimeError("Global config is not initialized")
|
||||
|
||||
_use_expression, enable_learning, learning_intensity = global_config.expression.get_expression_config_for_chat(
|
||||
chat_id
|
||||
chat_id_hash
|
||||
)
|
||||
|
||||
can_learn = learner.can_learn_for_chat()
|
||||
@@ -655,7 +767,7 @@ async def get_learning_status(chat_id: str) -> dict[str, Any]:
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive
|
||||
|
||||
recent_messages = await get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=chat_id,
|
||||
chat_id=chat_id_hash,
|
||||
timestamp_start=learner.last_learning_time,
|
||||
timestamp_end=time.time(),
|
||||
filter_bot=True,
|
||||
|
||||
Reference in New Issue
Block a user