feat(expression): 添加聊天ID解析功能,支持哈希值和platform:raw_id:type格式
This commit is contained in:
@@ -5,6 +5,7 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import csv
|
import csv
|
||||||
|
import hashlib
|
||||||
import io
|
import io
|
||||||
import math
|
import math
|
||||||
import time
|
import time
|
||||||
@@ -25,6 +26,51 @@ from src.config.config import global_config
|
|||||||
logger = get_logger("expression_api")
|
logger = get_logger("expression_api")
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 辅助函数 ====================
|
||||||
|
|
||||||
|
|
||||||
|
def parse_chat_id_input(chat_id_input: str) -> str:
|
||||||
|
"""
|
||||||
|
解析聊天ID输入,支持两种格式:
|
||||||
|
1. 哈希值格式(直接返回)
|
||||||
|
2. platform:raw_id:type 格式(转换为哈希值)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chat_id_input: 输入的chat_id,可以是哈希值或 platform:raw_id:type 格式
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
哈希值格式的chat_id
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> parse_chat_id_input("abc123def456") # 哈希值
|
||||||
|
"abc123def456"
|
||||||
|
>>> parse_chat_id_input("QQ:12345:group") # platform:id:type
|
||||||
|
"..." (转换后的哈希值)
|
||||||
|
"""
|
||||||
|
# 如果包含冒号,认为是 platform:id:type 格式
|
||||||
|
if ":" in chat_id_input:
|
||||||
|
parts = chat_id_input.split(":")
|
||||||
|
if len(parts) != 3:
|
||||||
|
raise ValueError(
|
||||||
|
f"无效的chat_id格式: {chat_id_input},"
|
||||||
|
"应为 'platform:raw_id:type' 格式,例如 'QQ:12345:group' 或 'QQ:67890:private'"
|
||||||
|
)
|
||||||
|
|
||||||
|
platform, raw_id, chat_type = parts
|
||||||
|
|
||||||
|
if chat_type not in ["group", "private"]:
|
||||||
|
raise ValueError(f"无效的chat_type: {chat_type},只支持 'group' 或 'private'")
|
||||||
|
|
||||||
|
# 使用与 ChatStream.get_stream_id 相同的逻辑生成哈希值
|
||||||
|
is_group = chat_type == "group"
|
||||||
|
components = [platform, raw_id] if is_group else [platform, raw_id, "private"]
|
||||||
|
key = "_".join(components)
|
||||||
|
return hashlib.sha256(key.encode()).hexdigest()
|
||||||
|
|
||||||
|
# 否则认为已经是哈希值
|
||||||
|
return chat_id_input
|
||||||
|
|
||||||
|
|
||||||
# ==================== 查询接口 ====================
|
# ==================== 查询接口 ====================
|
||||||
|
|
||||||
|
|
||||||
@@ -97,8 +143,25 @@ async def get_expression_list(
|
|||||||
chat_manager = get_chat_manager()
|
chat_manager = get_chat_manager()
|
||||||
|
|
||||||
for expr in expressions:
|
for expr in expressions:
|
||||||
# 获取聊天流名称
|
# 获取聊天流名称和详细信息
|
||||||
chat_name = await chat_manager.get_stream_name(expr.chat_id)
|
chat_name = await chat_manager.get_stream_name(expr.chat_id)
|
||||||
|
chat_stream = await chat_manager.get_stream(expr.chat_id)
|
||||||
|
|
||||||
|
# 构建格式化的chat_id信息
|
||||||
|
chat_id_display = expr.chat_id # 默认使用哈希值
|
||||||
|
platform = "未知"
|
||||||
|
raw_id = "未知"
|
||||||
|
chat_type = "未知"
|
||||||
|
|
||||||
|
if chat_stream:
|
||||||
|
platform = chat_stream.platform
|
||||||
|
if chat_stream.group_info:
|
||||||
|
raw_id = chat_stream.group_info.group_id
|
||||||
|
chat_type = "group"
|
||||||
|
elif chat_stream.user_info:
|
||||||
|
raw_id = chat_stream.user_info.user_id
|
||||||
|
chat_type = "private"
|
||||||
|
chat_id_display = f"{platform}:{raw_id}:{chat_type}"
|
||||||
|
|
||||||
expression_list.append(
|
expression_list.append(
|
||||||
{
|
{
|
||||||
@@ -107,7 +170,11 @@ async def get_expression_list(
|
|||||||
"style": expr.style,
|
"style": expr.style,
|
||||||
"count": expr.count,
|
"count": expr.count,
|
||||||
"last_active_time": expr.last_active_time,
|
"last_active_time": expr.last_active_time,
|
||||||
"chat_id": expr.chat_id,
|
"chat_id": expr.chat_id, # 保留哈希值用于后端操作
|
||||||
|
"chat_id_display": chat_id_display, # 显示用的格式化ID
|
||||||
|
"chat_platform": platform,
|
||||||
|
"chat_raw_id": raw_id,
|
||||||
|
"chat_type": chat_type,
|
||||||
"chat_name": chat_name or expr.chat_id,
|
"chat_name": chat_name or expr.chat_id,
|
||||||
"type": expr.type,
|
"type": expr.type,
|
||||||
"create_date": expr.create_date if expr.create_date else expr.last_active_time,
|
"create_date": expr.create_date if expr.create_date else expr.last_active_time,
|
||||||
@@ -155,9 +222,26 @@ async def get_expression_detail(expression_id: int) -> dict[str, Any] | None:
|
|||||||
if not expr:
|
if not expr:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 获取聊天流名称
|
# 获取聊天流名称和详细信息
|
||||||
chat_manager = get_chat_manager()
|
chat_manager = get_chat_manager()
|
||||||
chat_name = await chat_manager.get_stream_name(expr.chat_id)
|
chat_name = await chat_manager.get_stream_name(expr.chat_id)
|
||||||
|
chat_stream = await chat_manager.get_stream(expr.chat_id)
|
||||||
|
|
||||||
|
# 构建格式化的chat_id信息
|
||||||
|
chat_id_display = expr.chat_id
|
||||||
|
platform = "未知"
|
||||||
|
raw_id = "未知"
|
||||||
|
chat_type = "未知"
|
||||||
|
|
||||||
|
if chat_stream:
|
||||||
|
platform = chat_stream.platform
|
||||||
|
if chat_stream.group_info:
|
||||||
|
raw_id = chat_stream.group_info.group_id
|
||||||
|
chat_type = "group"
|
||||||
|
elif chat_stream.user_info:
|
||||||
|
raw_id = chat_stream.user_info.user_id
|
||||||
|
chat_type = "private"
|
||||||
|
chat_id_display = f"{platform}:{raw_id}:{chat_type}"
|
||||||
|
|
||||||
# 计算使用统计
|
# 计算使用统计
|
||||||
days_since_create = (time.time() - (expr.create_date or expr.last_active_time)) / 86400
|
days_since_create = (time.time() - (expr.create_date or expr.last_active_time)) / 86400
|
||||||
@@ -170,6 +254,10 @@ async def get_expression_detail(expression_id: int) -> dict[str, Any] | None:
|
|||||||
"count": expr.count,
|
"count": expr.count,
|
||||||
"last_active_time": expr.last_active_time,
|
"last_active_time": expr.last_active_time,
|
||||||
"chat_id": expr.chat_id,
|
"chat_id": expr.chat_id,
|
||||||
|
"chat_id_display": chat_id_display,
|
||||||
|
"chat_platform": platform,
|
||||||
|
"chat_raw_id": raw_id,
|
||||||
|
"chat_type": chat_type,
|
||||||
"chat_name": chat_name or expr.chat_id,
|
"chat_name": chat_name or expr.chat_id,
|
||||||
"type": expr.type,
|
"type": expr.type,
|
||||||
"create_date": expr.create_date if expr.create_date else expr.last_active_time,
|
"create_date": expr.create_date if expr.create_date else expr.last_active_time,
|
||||||
@@ -360,10 +448,21 @@ async def create_expression(
|
|||||||
"""
|
"""
|
||||||
手动创建表达方式
|
手动创建表达方式
|
||||||
|
|
||||||
|
Args:
|
||||||
|
situation: 情境描述
|
||||||
|
style: 表达风格
|
||||||
|
chat_id: 聊天流ID,支持两种格式:
|
||||||
|
- 哈希值格式(如: "abc123def456...")
|
||||||
|
- platform:raw_id:type 格式(如: "QQ:12345:group" 或 "QQ:67890:private")
|
||||||
|
type: 表达类型
|
||||||
|
count: 权重
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
创建的表达方式详情
|
创建的表达方式详情
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
# 解析并转换chat_id
|
||||||
|
chat_id_hash = parse_chat_id_input(chat_id)
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
|
|
||||||
async with get_db_session() as session:
|
async with get_db_session() as session:
|
||||||
@@ -371,7 +470,7 @@ async def create_expression(
|
|||||||
existing_query = await session.execute(
|
existing_query = await session.execute(
|
||||||
select(Expression).where(
|
select(Expression).where(
|
||||||
and_(
|
and_(
|
||||||
Expression.chat_id == chat_id,
|
Expression.chat_id == chat_id_hash,
|
||||||
Expression.type == type,
|
Expression.type == type,
|
||||||
Expression.situation == situation,
|
Expression.situation == situation,
|
||||||
Expression.style == style,
|
Expression.style == style,
|
||||||
@@ -389,7 +488,7 @@ async def create_expression(
|
|||||||
style=style,
|
style=style,
|
||||||
count=count,
|
count=count,
|
||||||
last_active_time=current_time,
|
last_active_time=current_time,
|
||||||
chat_id=chat_id,
|
chat_id=chat_id_hash,
|
||||||
type=type,
|
type=type,
|
||||||
create_date=current_time,
|
create_date=current_time,
|
||||||
)
|
)
|
||||||
@@ -400,9 +499,9 @@ async def create_expression(
|
|||||||
|
|
||||||
# 清除缓存
|
# 清除缓存
|
||||||
cache = await get_cache()
|
cache = await get_cache()
|
||||||
await cache.delete(generate_cache_key("chat_expressions", chat_id))
|
await cache.delete(generate_cache_key("chat_expressions", chat_id_hash))
|
||||||
|
|
||||||
logger.info(f"创建表达方式成功: {situation} -> {style}")
|
logger.info(f"创建表达方式成功: {situation} -> {style} (chat_id={chat_id_hash})")
|
||||||
|
|
||||||
return await get_expression_detail(new_expression.id) # type: ignore
|
return await get_expression_detail(new_expression.id) # type: ignore
|
||||||
|
|
||||||
@@ -564,7 +663,9 @@ async def trigger_learning(
|
|||||||
手动触发学习
|
手动触发学习
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
chat_id: 聊天流ID
|
chat_id: 聊天流ID,支持两种格式:
|
||||||
|
- 哈希值格式(如: "abc123def456...")
|
||||||
|
- platform:raw_id:type 格式(如: "QQ:12345:group" 或 "QQ:67890:private")
|
||||||
type: 学习类型
|
type: 学习类型
|
||||||
force: 是否强制学习(忽略时间和消息数量限制)
|
force: 是否强制学习(忽略时间和消息数量限制)
|
||||||
|
|
||||||
@@ -577,7 +678,10 @@ async def trigger_learning(
|
|||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
learner = ExpressionLearner(chat_id)
|
# 解析并转换chat_id
|
||||||
|
chat_id_hash = parse_chat_id_input(chat_id)
|
||||||
|
|
||||||
|
learner = ExpressionLearner(chat_id_hash)
|
||||||
await learner._initialize_chat_name()
|
await learner._initialize_chat_name()
|
||||||
|
|
||||||
# 检查是否允许学习
|
# 检查是否允许学习
|
||||||
@@ -621,6 +725,11 @@ async def get_learning_status(chat_id: str) -> dict[str, Any]:
|
|||||||
"""
|
"""
|
||||||
获取学习状态
|
获取学习状态
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chat_id: 聊天流ID,支持两种格式:
|
||||||
|
- 哈希值格式(如: "abc123def456...")
|
||||||
|
- platform:raw_id:type 格式(如: "QQ:12345:group" 或 "QQ:67890:private")
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
{
|
{
|
||||||
"can_learn": true,
|
"can_learn": true,
|
||||||
@@ -632,7 +741,10 @@ async def get_learning_status(chat_id: str) -> dict[str, Any]:
|
|||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
learner = ExpressionLearner(chat_id)
|
# 解析并转换chat_id
|
||||||
|
chat_id_hash = parse_chat_id_input(chat_id)
|
||||||
|
|
||||||
|
learner = ExpressionLearner(chat_id_hash)
|
||||||
await learner._initialize_chat_name()
|
await learner._initialize_chat_name()
|
||||||
|
|
||||||
# 获取配置
|
# 获取配置
|
||||||
@@ -640,7 +752,7 @@ async def get_learning_status(chat_id: str) -> dict[str, Any]:
|
|||||||
raise RuntimeError("Global config is not initialized")
|
raise RuntimeError("Global config is not initialized")
|
||||||
|
|
||||||
_use_expression, enable_learning, learning_intensity = global_config.expression.get_expression_config_for_chat(
|
_use_expression, enable_learning, learning_intensity = global_config.expression.get_expression_config_for_chat(
|
||||||
chat_id
|
chat_id_hash
|
||||||
)
|
)
|
||||||
|
|
||||||
can_learn = learner.can_learn_for_chat()
|
can_learn = learner.can_learn_for_chat()
|
||||||
@@ -655,7 +767,7 @@ async def get_learning_status(chat_id: str) -> dict[str, Any]:
|
|||||||
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive
|
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive
|
||||||
|
|
||||||
recent_messages = await get_raw_msg_by_timestamp_with_chat_inclusive(
|
recent_messages = await get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id_hash,
|
||||||
timestamp_start=learner.last_learning_time,
|
timestamp_start=learner.last_learning_time,
|
||||||
timestamp_end=time.time(),
|
timestamp_end=time.time(),
|
||||||
filter_bot=True,
|
filter_bot=True,
|
||||||
|
|||||||
Reference in New Issue
Block a user