style: 格式化代码

This commit is contained in:
John Richard
2025-10-02 19:38:39 +08:00
committed by Windpicker-owo
parent e7aaafde2f
commit 00ba07e0e1
111 changed files with 2343 additions and 2316 deletions

View File

@@ -13,6 +13,7 @@ from src.chat.utils.utils import translate_timestamp_to_human_readable, assign_m
from src.common.database.sqlalchemy_database_api import get_db_session
from sqlalchemy import select, and_
from src.common.logger import get_logger
logger = get_logger("chat_message_builder")
install(extra_lines=3)
@@ -274,21 +275,52 @@ async def get_actions_by_timestamp_with_chat(
async with get_db_session() as session:
if limit > 0:
result = await session.execute(
select(ActionRecords)
.where(
and_(
ActionRecords.chat_id == chat_id,
ActionRecords.time >= timestamp_start,
ActionRecords.time <= timestamp_end,
)
)
.order_by(ActionRecords.time.desc())
.limit(limit)
)
actions = list(result.scalars())
actions_result = []
for action in reversed(actions):
action_dict = {
"id": action.id,
"action_id": action.action_id,
"time": action.time,
"action_name": action.action_name,
"action_data": action.action_data,
"action_done": action.action_done,
"action_build_into_prompt": action.action_build_into_prompt,
"action_prompt_display": action.action_prompt_display,
"chat_id": action.chat_id,
"chat_info_stream_id": action.chat_info_stream_id,
"chat_info_platform": action.chat_info_platform,
}
actions_result.append(action_dict)
actions_result.append(action_dict)
else: # earliest
result = await session.execute(
select(ActionRecords)
.where(
and_(
ActionRecords.chat_id == chat_id,
ActionRecords.time >= timestamp_start,
ActionRecords.time <= timestamp_end,
ActionRecords.time > timestamp_start,
ActionRecords.time < timestamp_end,
)
)
.order_by(ActionRecords.time.desc())
.order_by(ActionRecords.time.asc())
.limit(limit)
)
actions = list(result.scalars())
actions_result = []
for action in reversed(actions):
for action in actions:
action_dict = {
"id": action.id,
"action_id": action.action_id,
@@ -303,37 +335,6 @@ async def get_actions_by_timestamp_with_chat(
"chat_info_platform": action.chat_info_platform,
}
actions_result.append(action_dict)
actions_result.append(action_dict)
else: # earliest
result = await session.execute(
select(ActionRecords)
.where(
and_(
ActionRecords.chat_id == chat_id,
ActionRecords.time > timestamp_start,
ActionRecords.time < timestamp_end,
)
)
.order_by(ActionRecords.time.asc())
.limit(limit)
)
actions = list(result.scalars())
actions_result = []
for action in actions:
action_dict = {
"id": action.id,
"action_id": action.action_id,
"time": action.time,
"action_name": action.action_name,
"action_data": action.action_data,
"action_done": action.action_done,
"action_build_into_prompt": action.action_build_into_prompt,
"action_prompt_display": action.action_prompt_display,
"chat_id": action.chat_id,
"chat_info_stream_id": action.chat_info_stream_id,
"chat_info_platform": action.chat_info_platform,
}
actions_result.append(action_dict)
else:
result = await session.execute(
select(ActionRecords)
@@ -457,7 +458,9 @@ async def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
async def get_raw_msg_before_timestamp_with_chat(chat_id: str, timestamp: float, limit: int = 0) -> List[Dict[str, Any]]:
async def get_raw_msg_before_timestamp_with_chat(
chat_id: str, timestamp: float, limit: int = 0
) -> List[Dict[str, Any]]:
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制
"""
@@ -466,7 +469,9 @@ async def get_raw_msg_before_timestamp_with_chat(chat_id: str, timestamp: float,
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
async def get_raw_msg_before_timestamp_with_users(timestamp: float, person_ids: list, limit: int = 0) -> List[Dict[str, Any]]:
async def get_raw_msg_before_timestamp_with_users(
timestamp: float, person_ids: list, limit: int = 0
) -> List[Dict[str, Any]]:
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制
"""
@@ -475,7 +480,9 @@ async def get_raw_msg_before_timestamp_with_users(timestamp: float, person_ids:
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
async def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: Optional[float] = None) -> int:
async def num_new_messages_since(
chat_id: str, timestamp_start: float = 0.0, timestamp_end: Optional[float] = None
) -> int:
"""
检查特定聊天从 timestamp_start (不含) 到 timestamp_end (不含) 之间有多少新消息。
如果 timestamp_end 为 None则检查从 timestamp_start (不含) 到当前时间的消息。
@@ -830,7 +837,7 @@ async def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
async with get_db_session() as session:
result = await session.execute(select(Images).where(Images.image_id == pic_id))
image = result.scalar_one_or_none()
if image and hasattr(image, 'description') and image.description:
if image and hasattr(image, "description") and image.description:
description = image.description
except Exception as e:
# 如果查询失败,保持默认描述
@@ -1017,24 +1024,29 @@ async def build_readable_messages(
async with get_db_session() as session:
# 获取这个时间范围内的动作记录并匹配chat_id
actions_in_range = (await session.execute(
select(ActionRecords)
.where(
and_(
ActionRecords.time >= min_time, ActionRecords.time <= max_time, ActionRecords.chat_id == chat_id
actions_in_range = (
await session.execute(
select(ActionRecords)
.where(
and_(
ActionRecords.time >= min_time,
ActionRecords.time <= max_time,
ActionRecords.chat_id == chat_id,
)
)
.order_by(ActionRecords.time)
)
.order_by(ActionRecords.time)
)).scalars()
).scalars()
# 获取最新消息之后的第一个动作记录
action_after_latest = (await session.execute(
select(ActionRecords)
.where(and_(ActionRecords.time > max_time, ActionRecords.chat_id == chat_id))
.order_by(ActionRecords.time)
.limit(1)
)).scalars()
action_after_latest = (
await session.execute(
select(ActionRecords)
.where(and_(ActionRecords.time > max_time, ActionRecords.chat_id == chat_id))
.order_by(ActionRecords.time)
.limit(1)
)
).scalars()
# 合并两部分动作记录,并转为 dict避免 DetachedInstanceError
actions = [
@@ -1225,9 +1237,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
except Exception:
return "?"
content = await replace_user_references_async(
content, platform, anon_name_resolver, replace_bot_name=False
)
content = await replace_user_references_async(content, platform, anon_name_resolver, replace_bot_name=False)
header = f"{anon_name}"
output_lines.append(header)

View File

@@ -17,7 +17,7 @@ MEMORY_TYPE_CHINESE_MAPPING = {
"goal": "目标计划",
"experience": "经验教训",
"contextual": "上下文信息",
"unknown": "未知"
"unknown": "未知",
}
# 置信度等级到中文标签的映射表
@@ -30,7 +30,7 @@ CONFIDENCE_LEVEL_CHINESE_MAPPING = {
"MEDIUM": "中等置信度",
"HIGH": "高置信度",
"VERIFIED": "已验证",
"unknown": "未知"
"unknown": "未知",
}
# 重要性等级到中文标签的映射表
@@ -43,7 +43,7 @@ IMPORTANCE_LEVEL_CHINESE_MAPPING = {
"NORMAL": "一般重要性",
"HIGH": "高重要性",
"CRITICAL": "关键重要性",
"unknown": "未知"
"unknown": "未知",
}
@@ -69,7 +69,7 @@ def get_confidence_level_chinese_label(level) -> str:
str: 对应的中文标签,如果找不到则返回"未知"
"""
# 处理枚举实例
if hasattr(level, 'value'):
if hasattr(level, "value"):
level = level.value
# 处理数字
@@ -94,7 +94,7 @@ def get_importance_level_chinese_label(level) -> str:
str: 对应的中文标签,如果找不到则返回"未知"
"""
# 处理枚举实例
if hasattr(level, 'value'):
if hasattr(level, "value"):
level = level.value
# 处理数字
@@ -106,4 +106,4 @@ def get_importance_level_chinese_label(level) -> str:
level_upper = level.upper()
return IMPORTANCE_LEVEL_CHINESE_MAPPING.get(level_upper, "未知")
return "未知"
return "未知"

View File

@@ -381,12 +381,12 @@ class Prompt:
# 性能优化 - 为不同任务设置不同的超时时间
task_timeouts = {
"memory_block": 15.0, # 记忆系统 - 降低超时时间,鼓励预构建
"tool_info": 15.0, # 工具信息
"relation_info": 10.0, # 关系信息
"knowledge_info": 10.0, # 知识库查询
"cross_context": 10.0, # 上下文处理
"expression_habits": 10.0, # 表达习惯
"memory_block": 15.0, # 记忆系统 - 降低超时时间,鼓励预构建
"tool_info": 15.0, # 工具信息
"relation_info": 10.0, # 关系信息
"knowledge_info": 10.0, # 知识库查询
"cross_context": 10.0, # 上下文处理
"expression_habits": 10.0, # 表达习惯
}
# 分别处理每个任务,避免慢任务影响快任务
@@ -563,7 +563,7 @@ class Prompt:
),
enhanced_memory_activator.get_instant_memory(
target_message=self.parameters.target, chat_id=self.parameters.chat_id
)
),
]
try:
@@ -606,26 +606,27 @@ class Prompt:
"opinion": "opinion",
"personal_fact": "personal_fact",
"preference": "preference",
"event": "event"
"event": "event",
}
mapped_type = memory_type_mapping.get(topic, "personal_fact")
formatted_memories.append({
"display": display_text,
"memory_type": mapped_type,
"metadata": {
"confidence": memory.get("confidence", "未知"),
"importance": memory.get("importance", "一般"),
"timestamp": memory.get("timestamp", ""),
"source": memory.get("source", "unknown"),
"relevance_score": memory.get("relevance_score", 0.0)
formatted_memories.append(
{
"display": display_text,
"memory_type": mapped_type,
"metadata": {
"confidence": memory.get("confidence", "未知"),
"importance": memory.get("importance", "一般"),
"timestamp": memory.get("timestamp", ""),
"source": memory.get("source", "unknown"),
"relevance_score": memory.get("relevance_score", 0.0),
},
}
})
)
# 使用方括号格式格式化记忆
memory_block = format_memories_bracket_style(
formatted_memories,
query_context=self.parameters.target
formatted_memories, query_context=self.parameters.target
)
except Exception as e:
logger.warning(f"记忆格式化失败,使用简化格式: {e}")
@@ -833,7 +834,8 @@ class Prompt:
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
"safety_guidelines_block": self.parameters.safety_guidelines_block
or context_data.get("safety_guidelines_block", ""),
"chat_scene": self.parameters.chat_scene or "你正在一个QQ群里聊天你需要理解整个群的聊天动态和话题走向并做出自然的回应。",
"chat_scene": self.parameters.chat_scene
or "你正在一个QQ群里聊天你需要理解整个群的聊天动态和话题走向并做出自然的回应。",
}
def _prepare_normal_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
@@ -860,7 +862,8 @@ class Prompt:
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
"safety_guidelines_block": self.parameters.safety_guidelines_block
or context_data.get("safety_guidelines_block", ""),
"chat_scene": self.parameters.chat_scene or "你正在一个QQ群里聊天你需要理解整个群的聊天动态和话题走向并做出自然的回应。",
"chat_scene": self.parameters.chat_scene
or "你正在一个QQ群里聊天你需要理解整个群的聊天动态和话题走向并做出自然的回应。",
}
def _prepare_default_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:

View File

@@ -305,11 +305,14 @@ class StatisticOutputTask(AsyncTask):
# 以最早的时间戳为起始时间获取记录
query_start_time = collect_period[-1][1]
records = await db_get(
model_class=LLMUsage,
filters={"timestamp": {"$gte": query_start_time}},
order_by="-timestamp",
) or []
records = (
await db_get(
model_class=LLMUsage,
filters={"timestamp": {"$gte": query_start_time}},
order_by="-timestamp",
)
or []
)
for record in records:
if not isinstance(record, dict):
@@ -401,7 +404,9 @@ class StatisticOutputTask(AsyncTask):
return stats
@staticmethod
async def _collect_online_time_for_period(collect_period: List[Tuple[str, datetime]], now: datetime) -> Dict[str, Any]:
async def _collect_online_time_for_period(
collect_period: List[Tuple[str, datetime]], now: datetime
) -> Dict[str, Any]:
"""
收集指定时间段的在线时间统计数据
@@ -420,11 +425,14 @@ class StatisticOutputTask(AsyncTask):
}
query_start_time = collect_period[-1][1]
records = await db_get(
model_class=OnlineTime,
filters={"end_timestamp": {"$gte": query_start_time}},
order_by="-end_timestamp",
) or []
records = (
await db_get(
model_class=OnlineTime,
filters={"end_timestamp": {"$gte": query_start_time}},
order_by="-end_timestamp",
)
or []
)
for record in records:
if not isinstance(record, dict):
@@ -476,11 +484,14 @@ class StatisticOutputTask(AsyncTask):
}
query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp)
records = await db_get(
model_class=Messages,
filters={"time": {"$gte": query_start_timestamp}},
order_by="-time",
) or []
records = (
await db_get(
model_class=Messages,
filters={"time": {"$gte": query_start_timestamp}},
order_by="-time",
)
or []
)
for message in records:
if not isinstance(message, dict):
@@ -1038,11 +1049,14 @@ class StatisticOutputTask(AsyncTask):
interval_seconds = interval_minutes * 60
# 单次查询 LLMUsage
llm_records = await db_get(
model_class=LLMUsage,
filters={"timestamp": {"$gte": start_time}},
order_by="-timestamp",
) or []
llm_records = (
await db_get(
model_class=LLMUsage,
filters={"timestamp": {"$gte": start_time}},
order_by="-timestamp",
)
or []
)
for record in llm_records:
if not isinstance(record, dict) or not record.get("timestamp"):
continue
@@ -1068,11 +1082,14 @@ class StatisticOutputTask(AsyncTask):
cost_by_module[module_name][idx] += cost
# 单次查询 Messages
msg_records = await db_get(
model_class=Messages,
filters={"time": {"$gte": start_time.timestamp()}},
order_by="-time",
) or []
msg_records = (
await db_get(
model_class=Messages,
filters={"time": {"$gte": start_time.timestamp()}},
order_by="-time",
)
or []
)
for msg in msg_records:
if not isinstance(msg, dict) or not msg.get("time"):
continue
@@ -1375,4 +1392,4 @@ class StatisticOutputTask(AsyncTask):
}});
</script>
</div>
"""
"""

View File

@@ -675,7 +675,6 @@ async def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Di
if loop.is_running():
# 如果事件循环在运行,从其他线程提交并等待结果
try:
from concurrent.futures import TimeoutError
fut = asyncio.run_coroutine_threadsafe(
person_info_manager.get_value(person_id, "person_name"), loop

View File

@@ -81,14 +81,16 @@ class ImageManager:
"""
try:
async with get_db_session() as session:
record = (await session.execute(
select(ImageDescriptions).where(
and_(
ImageDescriptions.image_description_hash == image_hash,
ImageDescriptions.type == description_type,
record = (
await session.execute(
select(ImageDescriptions).where(
and_(
ImageDescriptions.image_description_hash == image_hash,
ImageDescriptions.type == description_type,
)
)
)
)).scalar()
).scalar()
return record.description if record else None
except Exception as e:
logger.error(f"从数据库获取描述失败 (SQLAlchemy): {str(e)}")
@@ -107,14 +109,16 @@ class ImageManager:
current_timestamp = time.time()
async with get_db_session() as session:
# 查找现有记录
existing = (await session.execute(
select(ImageDescriptions).where(
and_(
ImageDescriptions.image_description_hash == image_hash,
ImageDescriptions.type == description_type,
existing = (
await session.execute(
select(ImageDescriptions).where(
and_(
ImageDescriptions.image_description_hash == image_hash,
ImageDescriptions.type == description_type,
)
)
)
)).scalar()
).scalar()
if existing:
# 更新现有记录
@@ -262,9 +266,11 @@ class ImageManager:
from src.common.database.sqlalchemy_models import get_db_session
async with get_db_session() as session:
existing_img = (await session.execute(
select(Images).where(and_(Images.emoji_hash == image_hash, Images.type == "emoji"))
)).scalar()
existing_img = (
await session.execute(
select(Images).where(and_(Images.emoji_hash == image_hash, Images.type == "emoji"))
)
).scalar()
if existing_img:
existing_img.path = file_path

View File

@@ -35,7 +35,7 @@ logger = get_logger("utils_video")
# Rust模块可用性检测
RUST_VIDEO_AVAILABLE = False
try:
import rust_video # pyright: ignore[reportMissingImports]
import rust_video # pyright: ignore[reportMissingImports]
RUST_VIDEO_AVAILABLE = True
logger.info("✅ Rust 视频处理模块加载成功")
@@ -222,7 +222,7 @@ class VideoAnalyzer:
return None
async def _store_video_result(
self, video_hash: str, description: str, metadata: Optional[Dict] = None
self, video_hash: str, description: str, metadata: Optional[Dict] = None
) -> Optional[Videos]:
"""存储视频分析结果到数据库"""
# 检查描述是否为错误信息,如果是则不保存