三次修改

This commit is contained in:
tt-P607
2025-09-20 02:21:53 +08:00
committed by Windpicker-owo
parent 635311bc80
commit aba4f1a947
20 changed files with 923 additions and 479 deletions

View File

@@ -361,7 +361,7 @@ class HeartFChatting:
# 核心修复:在睡眠模式(包括失眠)下获取消息时,不过滤命令消息,以确保@消息能被接收
filter_command_flag = not (is_sleeping or is_in_insomnia)
recent_messages = message_api.get_messages_by_time_in_chat(
recent_messages = await message_api.get_messages_by_time_in_chat(
chat_id=self.context.stream_id,
start_time=self.context.last_read_time,
end_time=time.time(),

View File

@@ -149,7 +149,7 @@ class MaiEmoji:
# --- 数据库操作 ---
try:
# 准备数据库记录 for emoji collection
with get_db_session() as session:
async with get_db_session() as session:
emotion_str = ",".join(self.emotion) if self.emotion else ""
emoji = Emoji(
@@ -167,7 +167,7 @@ class MaiEmoji:
last_used_time=self.last_used_time,
)
session.add(emoji)
session.commit()
await session.commit()
logger.info(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})")
@@ -203,17 +203,17 @@ class MaiEmoji:
# 2. 删除数据库记录
try:
with get_db_session() as session:
will_delete_emoji = session.execute(
select(Emoji).where(Emoji.emoji_hash == self.hash)
async with get_db_session() as session:
will_delete_emoji = (
await session.execute(select(Emoji).where(Emoji.emoji_hash == self.hash))
).scalar_one_or_none()
if will_delete_emoji is None:
logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。")
result = 0 # Indicate no DB record was deleted
result = 0
else:
session.delete(will_delete_emoji)
result = 1 # Successfully deleted one record
session.commit()
await session.delete(will_delete_emoji)
result = 1
await session.commit()
except Exception as e:
logger.error(f"[错误] 删除数据库记录时出错: {str(e)}")
result = 0
@@ -424,17 +424,19 @@ class EmojiManager:
# if not self._initialized:
# raise RuntimeError("EmojiManager not initialized")
def record_usage(self, emoji_hash: str) -> None:
async def record_usage(self, emoji_hash: str) -> None:
"""记录表情使用次数"""
try:
with get_db_session() as session:
emoji_update = session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalar_one_or_none()
async with get_db_session() as session:
emoji_update = (
await session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash))
).scalar_one_or_none()
if emoji_update is None:
logger.error(f"记录表情使用失败: 未找到 hash 为 {emoji_hash} 的表情包")
else:
emoji_update.usage_count += 1
emoji_update.last_used_time = time.time() # Update last used time
session.commit()
emoji_update.last_used_time = time.time()
await session.commit()
except Exception as e:
logger.error(f"记录表情使用失败: {str(e)}")
@@ -658,10 +660,11 @@ class EmojiManager:
async def get_all_emoji_from_db(self) -> None:
"""获取所有表情包并初始化为MaiEmoji类对象更新 self.emoji_objects"""
try:
with get_db_session() as session:
async with get_db_session() as session:
logger.debug("[数据库] 开始加载所有表情包记录 ...")
emoji_instances = session.execute(select(Emoji)).scalars().all()
result = await session.execute(select(Emoji))
emoji_instances = result.scalars().all()
emoji_objects, load_errors = _to_emoji_objects(emoji_instances)
# 更新内存中的列表和数量
@@ -687,14 +690,16 @@ class EmojiManager:
list[MaiEmoji]: 表情包对象列表
"""
try:
with get_db_session() as session:
async with get_db_session() as session:
if emoji_hash:
query = session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalars().all()
result = await session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash))
query = result.scalars().all()
else:
logger.warning(
"[查询] 未提供 hash将尝试加载所有表情包建议使用 get_all_emoji_from_db 更新管理器状态。"
)
query = session.execute(select(Emoji)).scalars().all()
result = await session.execute(select(Emoji))
query = result.scalars().all()
emoji_instances = query
emoji_objects, load_errors = _to_emoji_objects(emoji_instances)
@@ -771,10 +776,11 @@ class EmojiManager:
# 如果内存中没有,从数据库查找
try:
with get_db_session() as session:
emoji_record = session.execute(
async with get_db_session() as session:
result = await session.execute(
select(Emoji).where(Emoji.emoji_hash == emoji_hash)
).scalar_one_or_none()
)
emoji_record = result.scalar_one_or_none()
if emoji_record and emoji_record.description:
logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.description[:50]}...")
return emoji_record.description
@@ -937,10 +943,13 @@ class EmojiManager:
# 2. 检查数据库中是否已存在该表情包的描述,实现复用
existing_description = None
try:
with get_db_session() as session:
existing_image = session.query(Images).filter(
(Images.emoji_hash == image_hash) & (Images.type == "emoji")
).one_or_none()
async with get_db_session() as session:
result = await session.execute(
select(Images).filter(
(Images.emoji_hash == image_hash) & (Images.type == "emoji")
)
)
existing_image = result.scalar_one_or_none()
if existing_image and existing_image.description:
existing_description = existing_image.description
logger.info(f"[复用描述] 找到已有详细描述: {existing_description[:50]}...")

View File

@@ -136,18 +136,18 @@ class ExpressionSelector:
return related_chat_ids if related_chat_ids else [chat_id]
def get_random_expressions(
self, chat_id: str, total_num: int
) -> List[Dict[str, Any]]:
async def get_random_expressions(
self, chat_id: str, total_num: int, style_percentage: float, grammar_percentage: float
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
# sourcery skip: extract-duplicate-method, move-assign
# 支持多chat_id合并抽选
related_chat_ids = self.get_related_chat_ids(chat_id)
with get_db_session() as session:
async with get_db_session() as session:
# 优化一次性查询所有相关chat_id的表达方式
style_query = session.execute(
style_query = await session.execute(
select(Expression).where((Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style"))
)
grammar_query = session.execute(
grammar_query = await session.execute(
select(Expression).where((Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "grammar"))
)
@@ -193,7 +193,7 @@ class ExpressionSelector:
return selected_style, selected_grammar
def update_expressions_count_batch(self, expressions_to_update: List[Dict[str, Any]], increment: float = 0.1):
async def update_expressions_count_batch(self, expressions_to_update: List[Dict[str, Any]], increment: float = 0.1):
"""对一批表达方式更新count值按chat_id+type分组后一次性写入数据库"""
if not expressions_to_update:
return
@@ -210,26 +210,27 @@ class ExpressionSelector:
if key not in updates_by_key:
updates_by_key[key] = expr
for chat_id, expr_type, situation, style in updates_by_key:
with get_db_session() as session:
query = session.execute(
async with get_db_session() as session:
query = await session.execute(
select(Expression).where(
(Expression.chat_id == chat_id)
& (Expression.type == expr_type)
& (Expression.situation == situation)
& (Expression.style == style)
)
).scalar()
if query:
expr_obj = query
current_count = expr_obj.count
new_count = min(current_count + increment, 5.0)
expr_obj.count = new_count
expr_obj.last_active_time = time.time()
logger.debug(
f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in db"
)
session.commit()
query = query.scalar()
if query:
expr_obj = query
current_count = expr_obj.count
new_count = min(current_count + increment, 5.0)
expr_obj.count = new_count
expr_obj.last_active_time = time.time()
logger.debug(
f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in db"
)
await session.commit()
async def select_suitable_expressions_llm(
self,
@@ -246,12 +247,8 @@ class ExpressionSelector:
logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
return [], []
# 1. 获取20个随机表达方式(现在按权重抽取)
style_exprs = self.get_random_expressions(chat_id, 10)
if len(style_exprs) < 10:
logger.info(f"聊天流 {chat_id} 表达方式正在积累中")
return [], []
# 1. 获取35个随机表达方式(现在按权重抽取)
style_exprs, grammar_exprs = await self.get_random_expressions(chat_id, 30, 0.5, 0.5)
# 2. 构建所有表达方式的索引和情境列表
all_expressions: List[Dict[str, Any]] = []
@@ -326,7 +323,7 @@ class ExpressionSelector:
# 对选中的所有表达方式一次性更新count数
if valid_expressions:
self.update_expressions_count_batch(valid_expressions, 0.006)
await self.update_expressions_count_batch(valid_expressions, 0.006)
# logger.info(f"LLM从{len(all_expressions)}个情境中选择了{len(valid_expressions)}个")
return valid_expressions, selected_ids

View File

@@ -267,8 +267,8 @@ class Hippocampus:
self.entorhinal_cortex = EntorhinalCortex(self)
self.parahippocampal_gyrus = ParahippocampalGyrus(self)
# 从数据库加载记忆图
self.entorhinal_cortex.sync_memory_from_db()
self.model_small = LLMRequest(model_set=model_config.model_task_config.utils_small, request_type="memory.modify")
# self.entorhinal_cortex.sync_memory_from_db() # 改为异步启动
self.model_small = LLMRequest(model_set=model_config.model_task_config.utils_small, request_type="memory.small")
def get_all_node_names(self) -> list:
"""获取记忆图中所有节点的名字列表"""
@@ -877,7 +877,7 @@ class EntorhinalCortex:
self.hippocampus = hippocampus
self.memory_graph = hippocampus.memory_graph
def get_memory_sample(self):
async def get_memory_sample(self):
"""从数据库获取记忆样本"""
# 硬编码:每条消息最大记忆次数
max_memorized_time_per_msg = 2
@@ -900,7 +900,7 @@ class EntorhinalCortex:
logger.debug(f"回忆往事: {readable_timestamp}")
chat_samples = []
for timestamp in timestamps:
if messages := self.random_get_msg_snippet(
if messages := await self.random_get_msg_snippet(
timestamp,
global_config.memory.memory_build_sample_length,
max_memorized_time_per_msg,
@@ -914,7 +914,9 @@ class EntorhinalCortex:
return chat_samples
@staticmethod
def random_get_msg_snippet(target_timestamp: float, chat_size: int, max_memorized_time_per_msg: int) -> list | None:
async def random_get_msg_snippet(
target_timestamp: float, chat_size: int, max_memorized_time_per_msg: int
) -> list | None:
# sourcery skip: invert-any-all, use-any, use-named-expression, use-next
"""从数据库中随机获取指定时间戳附近的消息片段 (使用 chat_message_builder)"""
time_window_seconds = random.randint(300, 1800) # 随机时间窗口5到30分钟
@@ -952,13 +954,13 @@ class EntorhinalCortex:
for message in messages:
# 确保在更新前获取最新的 memorized_times
current_memorized_times = message.get("memorized_times", 0)
with get_db_session() as session:
session.execute(
async with get_db_session() as session:
await session.execute(
update(Messages)
.where(Messages.message_id == message["message_id"])
.values(memorized_times=current_memorized_times + 1)
)
session.commit()
await session.commit()
return messages # 直接返回原始的消息列表
target_timestamp -= 120 # 如果第一次尝试失败,稍微向前调整时间戳再试
@@ -972,8 +974,8 @@ class EntorhinalCortex:
current_time = datetime.datetime.now().timestamp()
# 获取数据库中所有节点和内存中所有节点
with get_db_session() as session:
db_nodes = {node.concept: node for node in session.execute(select(GraphNodes)).scalars()}
async with get_db_session() as session:
db_nodes = {node.concept: node for node in (await session.execute(select(GraphNodes))).scalars()}
memory_nodes = list(self.memory_graph.G.nodes(data=True))
# 批量准备节点数据
@@ -1043,24 +1045,24 @@ class EntorhinalCortex:
batch_size = 100
for i in range(0, len(nodes_to_create), batch_size):
batch = nodes_to_create[i : i + batch_size]
session.execute(insert(GraphNodes), batch)
await session.execute(insert(GraphNodes), batch)
if nodes_to_update:
batch_size = 100
for i in range(0, len(nodes_to_update), batch_size):
batch = nodes_to_update[i : i + batch_size]
for node_data in batch:
session.execute(
await session.execute(
update(GraphNodes)
.where(GraphNodes.concept == node_data["concept"])
.values(**{k: v for k, v in node_data.items() if k != "concept"})
)
if nodes_to_delete:
session.execute(delete(GraphNodes).where(GraphNodes.concept.in_(nodes_to_delete)))
await session.execute(delete(GraphNodes).where(GraphNodes.concept.in_(nodes_to_delete)))
# 处理边的信息
db_edges = list(session.execute(select(GraphEdges)).scalars())
db_edges = list((await session.execute(select(GraphEdges))).scalars())
memory_edges = list(self.memory_graph.G.edges(data=True))
# 创建边的哈希值字典
@@ -1112,14 +1114,14 @@ class EntorhinalCortex:
batch_size = 100
for i in range(0, len(edges_to_create), batch_size):
batch = edges_to_create[i : i + batch_size]
session.execute(insert(GraphEdges), batch)
await session.execute(insert(GraphEdges), batch)
if edges_to_update:
batch_size = 100
for i in range(0, len(edges_to_update), batch_size):
batch = edges_to_update[i : i + batch_size]
for edge_data in batch:
session.execute(
await session.execute(
update(GraphEdges)
.where(
(GraphEdges.source == edge_data["source"]) & (GraphEdges.target == edge_data["target"])
@@ -1129,12 +1131,12 @@ class EntorhinalCortex:
if edges_to_delete:
for source, target in edges_to_delete:
session.execute(
await session.execute(
delete(GraphEdges).where((GraphEdges.source == source) & (GraphEdges.target == target))
)
# 提交事务
session.commit()
await session.commit()
end_time = time.time()
logger.info(f"[同步] 总耗时: {end_time - start_time:.2f}")
@@ -1146,10 +1148,10 @@ class EntorhinalCortex:
logger.info("[数据库] 开始重新同步所有记忆数据...")
# 清空数据库
with get_db_session() as session:
async with get_db_session() as session:
clear_start = time.time()
session.execute(delete(GraphNodes))
session.execute(delete(GraphEdges))
await session.execute(delete(GraphNodes))
await session.execute(delete(GraphEdges))
clear_end = time.time()
logger.info(f"[数据库] 清空数据库耗时: {clear_end - clear_start:.2f}")
@@ -1208,7 +1210,7 @@ class EntorhinalCortex:
batch_size = 500 # 增加批量大小
for i in range(0, len(nodes_data), batch_size):
batch = nodes_data[i : i + batch_size]
session.execute(insert(GraphNodes), batch)
await session.execute(insert(GraphNodes), batch)
node_end = time.time()
logger.info(f"[数据库] 写入 {len(nodes_data)} 个节点耗时: {node_end - node_start:.2f}")
@@ -1219,8 +1221,8 @@ class EntorhinalCortex:
batch_size = 500 # 增加批量大小
for i in range(0, len(edges_data), batch_size):
batch = edges_data[i : i + batch_size]
session.execute(insert(GraphEdges), batch)
session.commit()
await session.execute(insert(GraphEdges), batch)
await session.commit()
edge_end = time.time()
logger.info(f"[数据库] 写入 {len(edges_data)} 条边耗时: {edge_end - edge_start:.2f}")
@@ -1229,7 +1231,7 @@ class EntorhinalCortex:
logger.info(f"[数据库] 重新同步完成,总耗时: {end_time - start_time:.2f}")
logger.info(f"[数据库] 同步了 {len(nodes_data)} 个节点和 {len(edges_data)} 条边")
def sync_memory_from_db(self):
async def sync_memory_from_db(self):
"""从数据库同步数据到内存中的图结构"""
current_time = datetime.datetime.now().timestamp()
need_update = False
@@ -1243,8 +1245,8 @@ class EntorhinalCortex:
skipped_nodes = 0
# 从数据库加载所有节点
with get_db_session() as session:
nodes = list(session.execute(select(GraphNodes)).scalars())
async with get_db_session() as session:
nodes = list((await session.execute(select(GraphNodes))).scalars())
for node in nodes:
concept = node.concept
try:
@@ -1262,7 +1264,9 @@ class EntorhinalCortex:
if not node.last_modified:
update_data["last_modified"] = current_time
session.execute(update(GraphNodes).where(GraphNodes.concept == concept).values(**update_data))
await session.execute(
update(GraphNodes).where(GraphNodes.concept == concept).values(**update_data)
)
# 获取时间信息(如果不存在则使用当前时间)
created_time = node.created_time or current_time
@@ -1277,7 +1281,7 @@ class EntorhinalCortex:
continue
# 从数据库加载所有边
edges = list(session.execute(select(GraphEdges)).scalars())
edges = list((await session.execute(select(GraphEdges))).scalars())
for edge in edges:
source = edge.source
target = edge.target
@@ -1293,7 +1297,7 @@ class EntorhinalCortex:
if not edge.last_modified:
update_data["last_modified"] = current_time
session.execute(
await session.execute(
update(GraphEdges)
.where((GraphEdges.source == source) & (GraphEdges.target == target))
.values(**update_data)
@@ -1308,7 +1312,7 @@ class EntorhinalCortex:
self.memory_graph.G.add_edge(
source, target, strength=strength, created_time=created_time, last_modified=last_modified
)
session.commit()
await session.commit()
if need_update:
logger.info("[数据库] 已为缺失的时间字段进行补充")
@@ -1348,7 +1352,7 @@ class ParahippocampalGyrus:
# 1. 使用 build_readable_messages 生成格式化文本
# build_readable_messages 只返回一个字符串,不需要解包
input_text = build_readable_messages(
input_text = await build_readable_messages(
messages,
merge_messages=True, # 合并连续消息
timestamp_mode="normal_no_YMD", # 使用 'YYYY-MM-DD HH:MM:SS' 格式
@@ -1436,7 +1440,7 @@ class ParahippocampalGyrus:
# sourcery skip: merge-list-appends-into-extend
logger.info("------------------------------------开始构建记忆--------------------------------------")
start_time = time.time()
memory_samples = self.hippocampus.entorhinal_cortex.get_memory_sample()
memory_samples = await self.hippocampus.entorhinal_cortex.get_memory_sample()
all_added_nodes = []
all_connected_nodes = []
all_added_edges = []
@@ -1697,7 +1701,7 @@ class HippocampusManager:
return self._hippocampus
self._hippocampus = Hippocampus()
self._hippocampus.initialize()
# self._hippocampus.initialize() # 改为异步启动
self._initialized = True
# 输出记忆图统计信息
@@ -1716,6 +1720,13 @@ class HippocampusManager:
return self._hippocampus
async def initialize_async(self):
"""异步初始化海马体实例"""
if not self._initialized:
self.initialize() # 先进行同步部分的初始化
self._hippocampus.initialize()
await self._hippocampus.entorhinal_cortex.sync_memory_from_db()
def get_hippocampus(self):
if not self._initialized:
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")

View File

@@ -246,11 +246,11 @@ class ChatManager:
return stream
# 检查数据库中是否存在
def _db_find_stream_sync(s_id: str):
with get_db_session() as session:
return session.execute(select(ChatStreams).where(ChatStreams.stream_id == s_id)).scalar()
async def _db_find_stream_async(s_id: str):
async with get_db_session() as session:
return (await session.execute(select(ChatStreams).where(ChatStreams.stream_id == s_id))).scalar()
model_instance = await asyncio.to_thread(_db_find_stream_sync, stream_id)
model_instance = await _db_find_stream_async(stream_id)
if model_instance:
# 从 Peewee 模型转换回 ChatStream.from_dict 期望的格式
@@ -344,11 +344,10 @@ class ChatManager:
return
stream_data_dict = stream.to_dict()
def _db_save_stream_sync(s_data_dict: dict):
with get_db_session() as session:
async def _db_save_stream_async(s_data_dict: dict):
async with get_db_session() as session:
user_info_d = s_data_dict.get("user_info")
group_info_d = s_data_dict.get("group_info")
fields_to_save = {
"platform": s_data_dict["platform"],
"create_time": s_data_dict["create_time"],
@@ -364,8 +363,6 @@ class ChatManager:
"sleep_pressure": s_data_dict.get("sleep_pressure", 0.0),
"focus_energy": s_data_dict.get("focus_energy", global_config.chat.focus_value),
}
# 根据数据库类型选择插入语句
if global_config.database.database_type == "sqlite":
stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=fields_to_save)
@@ -375,15 +372,13 @@ class ChatManager:
**{key: value for key, value in fields_to_save.items() if key != "stream_id"}
)
else:
# 默认使用通用插入尝试SQLite语法
stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=fields_to_save)
session.execute(stmt)
session.commit()
await session.execute(stmt)
await session.commit()
try:
await asyncio.to_thread(_db_save_stream_sync, stream_data_dict)
await _db_save_stream_async(stream_data_dict)
stream.saved = True
except Exception as e:
logger.error(f"保存聊天流 {stream.stream_id} 到数据库失败 (Peewee): {e}", exc_info=True)
@@ -397,10 +392,10 @@ class ChatManager:
"""从数据库加载所有聊天流"""
logger.info("正在从数据库加载所有聊天流")
def _db_load_all_streams_sync():
async def _db_load_all_streams_async():
loaded_streams_data = []
with get_db_session() as session:
for model_instance in session.execute(select(ChatStreams)).scalars():
async with get_db_session() as session:
for model_instance in (await session.execute(select(ChatStreams))).scalars():
user_info_data = {
"platform": model_instance.user_platform,
"user_id": model_instance.user_id,
@@ -414,7 +409,6 @@ class ChatManager:
"group_id": model_instance.group_id,
"group_name": model_instance.group_name,
}
data_for_from_dict = {
"stream_id": model_instance.stream_id,
"platform": model_instance.platform,
@@ -427,11 +421,11 @@ class ChatManager:
"focus_energy": getattr(model_instance, "focus_energy", global_config.chat.focus_value),
}
loaded_streams_data.append(data_for_from_dict)
session.commit()
await session.commit()
return loaded_streams_data
try:
all_streams_data_list = await asyncio.to_thread(_db_load_all_streams_sync)
all_streams_data_list = await _db_load_all_streams_async()
self.streams.clear()
for data in all_streams_data_list:
stream = ChatStream.from_dict(data)

View File

@@ -42,7 +42,7 @@ class MessageStorage:
processed_plain_text = message.processed_plain_text
if processed_plain_text:
processed_plain_text = MessageStorage.replace_image_descriptions(processed_plain_text)
processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text)
filtered_processed_plain_text = re.sub(pattern, "", processed_plain_text, flags=re.DOTALL)
else:
filtered_processed_plain_text = ""
@@ -129,9 +129,9 @@ class MessageStorage:
key_words=key_words,
key_words_lite=key_words_lite,
)
with get_db_session() as session:
async with get_db_session() as session:
session.add(new_message)
session.commit()
await session.commit()
except Exception:
logger.exception("存储消息失败")
@@ -174,16 +174,18 @@ class MessageStorage:
# 使用上下文管理器确保session正确管理
from src.common.database.sqlalchemy_models import get_db_session
with get_db_session() as session:
matched_message = session.execute(
select(Messages).where(Messages.message_id == mmc_message_id).order_by(desc(Messages.time))
async with get_db_session() as session:
matched_message = (
await session.execute(
select(Messages).where(Messages.message_id == mmc_message_id).order_by(desc(Messages.time))
)
).scalar()
if matched_message:
session.execute(
await session.execute(
update(Messages).where(Messages.id == matched_message.id).values(message_id=qq_message_id)
)
session.commit()
await session.commit()
# 会在上下文管理器中自动调用
logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}")
else:
@@ -197,28 +199,36 @@ class MessageStorage:
)
@staticmethod
def replace_image_descriptions(text: str) -> str:
async def replace_image_descriptions(text: str) -> str:
"""将[图片:描述]替换为[picid:image_id]"""
# 先检查文本中是否有图片标记
pattern = r"\[图片:([^\]]+)\]"
matches = re.findall(pattern, text)
matches = list(re.finditer(pattern, text))
if not matches:
logger.debug("文本中没有图片标记,直接返回原文本")
return text
def replace_match(match):
new_text = ""
last_end = 0
for match in matches:
new_text += text[last_end : match.start()]
description = match.group(1).strip()
try:
from src.common.database.sqlalchemy_models import get_db_session
with get_db_session() as session:
image_record = session.execute(
select(Images).where(Images.description == description).order_by(desc(Images.timestamp))
async with get_db_session() as session:
image_record = (
await session.execute(
select(Images).where(Images.description == description).order_by(desc(Images.timestamp))
)
).scalar()
session.commit()
return f"[picid:{image_record.image_id}]" if image_record else match.group(0)
if image_record:
new_text += f"[picid:{image_record.image_id}]"
else:
new_text += match.group(0)
except Exception:
return match.group(0)
return re.sub(r"\[图片:([^\]]+)\]", replace_match, text)
new_text += match.group(0)
last_end = match.end()
new_text += text[last_end:]
return new_text

View File

@@ -97,12 +97,12 @@ class ActionModifier:
for action_name, reason in chat_type_removals:
logger.debug(f"{self.log_prefix} - 移除 {action_name}: {reason}")
message_list_before_now_half = get_raw_msg_before_timestamp_with_chat(
message_list_before_now_half = await get_raw_msg_before_timestamp_with_chat(
chat_id=self.chat_stream.stream_id,
timestamp=time.time(),
limit=min(int(global_config.chat.max_context_size * 0.33), 10),
)
chat_content = build_readable_messages(
chat_content = await build_readable_messages(
message_list_before_now_half,
replace_bot_name=True,
merge_messages=False,

View File

@@ -152,7 +152,7 @@ class PlanFilter:
)
return prompt, message_id_list
chat_content_block, message_id_list = build_readable_messages_with_id(
chat_content_block, message_id_list = await build_readable_messages_with_id(
messages=[msg.flatten() for msg in plan.chat_history],
timestamp_mode="normal",
read_mark=self.last_obs_time_mark,
@@ -167,7 +167,7 @@ class PlanFilter:
limit=5,
)
actions_before_now_block = build_readable_actions(actions=actions_before_now)
actions_before_now_block = build_readable_actions(actions=await actions_before_now)
actions_before_now_block = f"你刚刚选择并执行过的action是\n{actions_before_now_block}"
self.last_obs_time_mark = time.time()

View File

@@ -63,7 +63,7 @@ class PlanGenerator:
timestamp=time.time(),
limit=int(global_config.chat.max_context_size),
)
chat_history = [DatabaseMessages(**msg) for msg in chat_history_raw]
chat_history = [DatabaseMessages(**msg) for msg in await chat_history_raw]
plan = Plan(

View File

@@ -873,7 +873,8 @@ class DefaultReplyer:
platform, # type: ignore
reply_message.get("user_id"), # type: ignore
)
person_name = await person_info_manager.get_value(person_id, "person_name")
person_info = await person_info_manager.get_values(person_id, ["person_name", "user_id"])
person_name = person_info.get("person_name")
# 如果person_name为None使用fallback值
if person_name is None:
@@ -884,7 +885,7 @@ class DefaultReplyer:
# 检查是否是bot自己的名字如果是则替换为"(你)"
bot_user_id = str(global_config.bot.qq_account)
current_user_id = person_info_manager.get_value_sync(person_id, "user_id")
current_user_id = person_info.get("user_id")
current_platform = reply_message.get("chat_info_platform")
if current_user_id == bot_user_id and current_platform == global_config.bot.platform:
@@ -909,18 +910,18 @@ class DefaultReplyer:
target = replace_user_references_sync(target, chat_stream.platform, replace_bot_name=True)
message_list_before_now_long = get_raw_msg_before_timestamp_with_chat(
message_list_before_now_long = await get_raw_msg_before_timestamp_with_chat(
chat_id=chat_id,
timestamp=time.time(),
limit=global_config.chat.max_context_size * 1,
)
message_list_before_short = get_raw_msg_before_timestamp_with_chat(
message_list_before_short = await get_raw_msg_before_timestamp_with_chat(
chat_id=chat_id,
timestamp=time.time(),
limit=int(global_config.chat.max_context_size * 0.33),
)
chat_talking_prompt_short = build_readable_messages(
chat_talking_prompt_short = await build_readable_messages(
message_list_before_short,
replace_bot_name=True,
merge_messages=False,
@@ -932,7 +933,7 @@ class DefaultReplyer:
# 获取目标用户信息用于s4u模式
target_user_info = None
if sender:
target_user_info = await person_info_manager.get_person_info_by_name(sender)
target_user_info = person_info_manager.get_person_info_by_name(sender)
from src.chat.utils.prompt import Prompt
# 并行执行六个构建任务
@@ -1150,12 +1151,12 @@ class DefaultReplyer:
else:
mood_prompt = ""
message_list_before_now_half = get_raw_msg_before_timestamp_with_chat(
message_list_before_now_half = await get_raw_msg_before_timestamp_with_chat(
chat_id=chat_id,
timestamp=time.time(),
limit=min(int(global_config.chat.max_context_size * 0.33), 15),
)
chat_talking_prompt_half = build_readable_messages(
chat_talking_prompt_half = await build_readable_messages(
message_list_before_now_half,
replace_bot_name=True,
merge_messages=False,

View File

@@ -116,8 +116,9 @@ async def replace_user_references_async(
# 检查是否是机器人自己
if replace_bot_name and user_id == global_config.bot.qq_account:
return f"{global_config.bot.nickname}(你)"
person = Person(platform=platform, user_id=user_id)
return person.person_name or user_id # type: ignore
person_id = PersonInfoManager.get_person_id(platform, user_id)
person_info = await person_info_manager.get_values(person_id, ["person_name"])
return person_info.get("person_name") or user_id
name_resolver = default_resolver
@@ -165,7 +166,7 @@ async def replace_user_references_async(
return content
def get_raw_msg_by_timestamp(
async def get_raw_msg_by_timestamp(
timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
) -> List[Dict[str, Any]]:
"""
@@ -176,10 +177,10 @@ def get_raw_msg_by_timestamp(
filter_query = {"time": {"$gt": timestamp_start, "$lt": timestamp_end}}
# 只有当 limit 为 0 时才应用外部 sort
sort_order = [("time", 1)] if limit == 0 else None
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
def get_raw_msg_by_timestamp_with_chat(
async def get_raw_msg_by_timestamp_with_chat(
chat_id: str,
timestamp_start: float,
timestamp_end: float,
@@ -196,7 +197,7 @@ def get_raw_msg_by_timestamp_with_chat(
# 只有当 limit 为 0 时才应用外部 sort
sort_order = [("time", 1)] if limit == 0 else None
# 直接将 limit_mode 传递给 find_messages
return find_messages(
return await find_messages(
message_filter=filter_query,
sort=sort_order,
limit=limit,
@@ -206,7 +207,7 @@ def get_raw_msg_by_timestamp_with_chat(
)
def get_raw_msg_by_timestamp_with_chat_inclusive(
async def get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id: str,
timestamp_start: float,
timestamp_end: float,
@@ -223,12 +224,12 @@ def get_raw_msg_by_timestamp_with_chat_inclusive(
sort_order = [("time", 1)] if limit == 0 else None
# 直接将 limit_mode 传递给 find_messages
return find_messages(
return await find_messages(
message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode, filter_bot=filter_bot
)
def get_raw_msg_by_timestamp_with_chat_users(
async def get_raw_msg_by_timestamp_with_chat_users(
chat_id: str,
timestamp_start: float,
timestamp_end: float,
@@ -247,10 +248,10 @@ def get_raw_msg_by_timestamp_with_chat_users(
}
# 只有当 limit 为 0 时才应用外部 sort
sort_order = [("time", 1)] if limit == 0 else None
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
def get_actions_by_timestamp_with_chat(
async def get_actions_by_timestamp_with_chat(
chat_id: str,
timestamp_start: float = 0,
timestamp_end: float = time.time(),
@@ -269,10 +270,10 @@ def get_actions_by_timestamp_with_chat(
f"limit={limit}, limit_mode={limit_mode}"
)
with get_db_session() as session:
async with get_db_session() as session:
if limit > 0:
if limit_mode == "latest":
query = session.execute(
query = await session.execute(
select(ActionRecords)
.where(
and_(
@@ -302,7 +303,7 @@ def get_actions_by_timestamp_with_chat(
}
actions_result.append(action_dict)
else: # earliest
query = session.execute(
query = await session.execute(
select(ActionRecords)
.where(
and_(
@@ -332,7 +333,7 @@ def get_actions_by_timestamp_with_chat(
}
actions_result.append(action_dict)
else:
query = session.execute(
query = await session.execute(
select(ActionRecords)
.where(
and_(
@@ -363,14 +364,14 @@ def get_actions_by_timestamp_with_chat(
return actions_result
def get_actions_by_timestamp_with_chat_inclusive(
async def get_actions_by_timestamp_with_chat_inclusive(
chat_id: str, timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
) -> List[Dict[str, Any]]:
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录(包含边界),按时间升序排序,返回动作记录列表"""
with get_db_session() as session:
async with get_db_session() as session:
if limit > 0:
if limit_mode == "latest":
query = session.execute(
query = await session.execute(
select(ActionRecords)
.where(
and_(
@@ -385,7 +386,7 @@ def get_actions_by_timestamp_with_chat_inclusive(
actions = list(query.scalars())
return [action.__dict__ for action in reversed(actions)]
else: # earliest
query = session.execute(
query = await session.execute(
select(ActionRecords)
.where(
and_(
@@ -398,7 +399,7 @@ def get_actions_by_timestamp_with_chat_inclusive(
.limit(limit)
)
else:
query = session.execute(
query = await session.execute(
select(ActionRecords)
.where(
and_(
@@ -414,14 +415,14 @@ def get_actions_by_timestamp_with_chat_inclusive(
return [action.__dict__ for action in actions]
def get_raw_msg_by_timestamp_random(
async def get_raw_msg_by_timestamp_random(
timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
) -> List[Dict[str, Any]]:
"""
先在范围时间戳内随机选择一条消息取得消息的chat_id然后根据chat_id获取该聊天在指定时间戳范围内的消息
"""
# 获取所有消息只取chat_id字段
all_msgs = get_raw_msg_by_timestamp(timestamp_start, timestamp_end)
all_msgs = await get_raw_msg_by_timestamp(timestamp_start, timestamp_end)
if not all_msgs:
return []
# 随机选一条
@@ -429,10 +430,10 @@ def get_raw_msg_by_timestamp_random(
chat_id = msg["chat_id"]
timestamp_start = msg["time"]
# 用 chat_id 获取该聊天在指定时间戳范围内的消息
return get_raw_msg_by_timestamp_with_chat(chat_id, timestamp_start, timestamp_end, limit, "earliest")
return await get_raw_msg_by_timestamp_with_chat(chat_id, timestamp_start, timestamp_end, limit, "earliest")
def get_raw_msg_by_timestamp_with_users(
async def get_raw_msg_by_timestamp_with_users(
timestamp_start: float, timestamp_end: float, person_ids: list, limit: int = 0, limit_mode: str = "latest"
) -> List[Dict[str, Any]]:
"""获取某些特定用户在 *所有聊天* 中从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
@@ -442,37 +443,39 @@ def get_raw_msg_by_timestamp_with_users(
filter_query = {"time": {"$gt": timestamp_start, "$lt": timestamp_end}, "user_id": {"$in": person_ids}}
# 只有当 limit 为 0 时才应用外部 sort
sort_order = [("time", 1)] if limit == 0 else None
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List[Dict[str, Any]]:
async def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List[Dict[str, Any]]:
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制
"""
filter_query = {"time": {"$lt": timestamp}}
sort_order = [("time", 1)]
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
def get_raw_msg_before_timestamp_with_chat(chat_id: str, timestamp: float, limit: int = 0) -> List[Dict[str, Any]]:
async def get_raw_msg_before_timestamp_with_chat(chat_id: str, timestamp: float, limit: int = 0) -> List[Dict[str, Any]]:
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制
"""
filter_query = {"chat_id": chat_id, "time": {"$lt": timestamp}}
sort_order = [("time", 1)]
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
def get_raw_msg_before_timestamp_with_users(timestamp: float, person_ids: list, limit: int = 0) -> List[Dict[str, Any]]:
async def get_raw_msg_before_timestamp_with_users(
timestamp: float, person_ids: list, limit: int = 0
) -> List[Dict[str, Any]]:
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制
"""
filter_query = {"time": {"$lt": timestamp}, "user_id": {"$in": person_ids}}
sort_order = [("time", 1)]
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: Optional[float] = None) -> int:
async def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: Optional[float] = None) -> int:
"""
检查特定聊天从 timestamp_start (不含) 到 timestamp_end (不含) 之间有多少新消息。
如果 timestamp_end 为 None则检查从 timestamp_start (不含) 到当前时间的消息。
@@ -486,10 +489,10 @@ def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp
return 0 # 起始时间大于等于结束时间,没有新消息
filter_query = {"chat_id": chat_id, "time": {"$gt": timestamp_start, "$lt": _timestamp_end}}
return count_messages(message_filter=filter_query)
return await count_messages(message_filter=filter_query)
def num_new_messages_since_with_users(
async def num_new_messages_since_with_users(
chat_id: str, timestamp_start: float, timestamp_end: float, person_ids: list
) -> int:
"""检查某些特定用户在特定聊天在指定时间戳之间有多少新消息"""
@@ -500,10 +503,10 @@ def num_new_messages_since_with_users(
"time": {"$gt": timestamp_start, "$lt": timestamp_end},
"user_id": {"$in": person_ids},
}
return count_messages(message_filter=filter_query)
return await count_messages(message_filter=filter_query)
def _build_readable_messages_internal(
async def _build_readable_messages_internal(
messages: List[Dict[str, Any]],
replace_bot_name: bool = True,
merge_messages: bool = False,
@@ -622,7 +625,8 @@ def _build_readable_messages_internal(
if replace_bot_name and user_id == global_config.bot.qq_account:
person_name = f"{global_config.bot.nickname}(你)"
else:
person_name = person.person_name or user_id # type: ignore
person_info = await person_info_manager.get_values(person_id, ["person_name"])
person_name = person_info.get("person_name") # type: ignore
# 如果 person_name 未设置,则使用消息中的 nickname 或默认名称
if not person_name:
@@ -791,7 +795,7 @@ def _build_readable_messages_internal(
)
def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
async def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
# sourcery skip: use-contextlib-suppress
"""
构建图片映射信息字符串,显示图片的具体描述内容
@@ -814,9 +818,9 @@ def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
# 从数据库中获取图片描述
description = "[图片内容未知]" # 默认描述
try:
with get_db_session() as session:
image = session.execute(select(Images).where(Images.image_id == pic_id)).scalar_one_or_none()
if image and image.description: # type: ignore
async with get_db_session() as session:
image = (await session.execute(select(Images).where(Images.image_id == pic_id))).scalar_one_or_none()
if image and image.description: # type: ignore
description = image.description
except Exception:
# 如果查询失败,保持默认描述
@@ -912,17 +916,17 @@ async def build_readable_messages_with_list(
将消息列表转换为可读的文本格式,并返回原始(时间戳, 昵称, 内容)列表。
允许通过参数控制格式化行为。
"""
formatted_string, details_list, pic_id_mapping, _ = _build_readable_messages_internal(
formatted_string, details_list, pic_id_mapping, _ = await _build_readable_messages_internal(
messages, replace_bot_name, merge_messages, timestamp_mode, truncate
)
if pic_mapping_info := build_pic_mapping_info(pic_id_mapping):
if pic_mapping_info := await build_pic_mapping_info(pic_id_mapping):
formatted_string = f"{pic_mapping_info}\n\n{formatted_string}"
return formatted_string, details_list
def build_readable_messages_with_id(
async def build_readable_messages_with_id(
messages: List[Dict[str, Any]],
replace_bot_name: bool = True,
merge_messages: bool = False,
@@ -938,7 +942,7 @@ def build_readable_messages_with_id(
"""
message_id_list = assign_message_ids(messages)
formatted_string = build_readable_messages(
formatted_string = await build_readable_messages(
messages=messages,
replace_bot_name=replace_bot_name,
merge_messages=merge_messages,
@@ -953,7 +957,7 @@ def build_readable_messages_with_id(
return formatted_string, message_id_list
def build_readable_messages(
async def build_readable_messages(
messages: List[Dict[str, Any]],
replace_bot_name: bool = True,
merge_messages: bool = False,
@@ -994,24 +998,28 @@ def build_readable_messages(
from src.common.database.sqlalchemy_database_api import get_db_session
with get_db_session() as session:
async with get_db_session() as session:
# 获取这个时间范围内的动作记录并匹配chat_id
actions_in_range = session.execute(
select(ActionRecords)
.where(
and_(
ActionRecords.time >= min_time, ActionRecords.time <= max_time, ActionRecords.chat_id == chat_id
actions_in_range = (
await session.execute(
select(ActionRecords)
.where(
and_(
ActionRecords.time >= min_time, ActionRecords.time <= max_time, ActionRecords.chat_id == chat_id
)
)
.order_by(ActionRecords.time)
)
.order_by(ActionRecords.time)
).scalars()
# 获取最新消息之后的第一个动作记录
action_after_latest = session.execute(
select(ActionRecords)
.where(and_(ActionRecords.time > max_time, ActionRecords.chat_id == chat_id))
.order_by(ActionRecords.time)
.limit(1)
action_after_latest = (
await session.execute(
select(ActionRecords)
.where(and_(ActionRecords.time > max_time, ActionRecords.chat_id == chat_id))
.order_by(ActionRecords.time)
.limit(1)
)
).scalars()
# 合并两部分动作记录,并转为 dict避免 DetachedInstanceError
@@ -1043,7 +1051,7 @@ def build_readable_messages(
if read_mark <= 0:
# 没有有效的 read_mark直接格式化所有消息
formatted_string, _, pic_id_mapping, _ = _build_readable_messages_internal(
formatted_string, _, pic_id_mapping, _ = await _build_readable_messages_internal(
copy_messages,
replace_bot_name,
merge_messages,
@@ -1054,7 +1062,7 @@ def build_readable_messages(
)
# 生成图片映射信息并添加到最前面
pic_mapping_info = build_pic_mapping_info(pic_id_mapping)
pic_mapping_info = await build_pic_mapping_info(pic_id_mapping)
if pic_mapping_info:
return f"{pic_mapping_info}\n\n{formatted_string}"
else:
@@ -1069,7 +1077,7 @@ def build_readable_messages(
pic_counter = 1
# 分别格式化,但使用共享的图片映射
formatted_before, _, pic_id_mapping, pic_counter = _build_readable_messages_internal(
formatted_before, _, pic_id_mapping, pic_counter = await _build_readable_messages_internal(
messages_before_mark,
replace_bot_name,
merge_messages,
@@ -1080,7 +1088,7 @@ def build_readable_messages(
show_pic=show_pic,
message_id_list=message_id_list,
)
formatted_after, _, pic_id_mapping, _ = _build_readable_messages_internal(
formatted_after, _, pic_id_mapping, _ = await _build_readable_messages_internal(
messages_after_mark,
replace_bot_name,
merge_messages,
@@ -1096,7 +1104,7 @@ def build_readable_messages(
# 生成图片映射信息
if pic_id_mapping:
pic_mapping_info = f"图片信息:\n{build_pic_mapping_info(pic_id_mapping)}\n聊天记录信息:\n"
pic_mapping_info = f"图片信息:\n{await build_pic_mapping_info(pic_id_mapping)}\n聊天记录信息:\n"
else:
pic_mapping_info = "聊天记录信息:\n"

View File

@@ -25,7 +25,7 @@ def _model_to_dict(instance: Base) -> Dict[str, Any]:
return {col.name: getattr(instance, col.name) for col in instance.__table__.columns}
def find_messages(
async def find_messages(
message_filter: dict[str, Any],
sort: Optional[List[tuple[str, int]]] = None,
limit: int = 0,
@@ -46,7 +46,7 @@ def find_messages(
消息字典列表,如果出错则返回空列表。
"""
try:
with get_db_session() as session:
async with get_db_session() as session:
query = select(Messages)
# 应用过滤器
@@ -96,7 +96,7 @@ def find_messages(
# 获取时间最早的 limit 条记录,已经是正序
query = query.order_by(Messages.time.asc()).limit(limit)
try:
results = session.execute(query).scalars().all()
results = (await session.execute(query)).scalars().all()
except Exception as e:
logger.error(f"执行earliest查询失败: {e}")
results = []
@@ -104,7 +104,7 @@ def find_messages(
# 获取时间最晚的 limit 条记录
query = query.order_by(Messages.time.desc()).limit(limit)
try:
latest_results = session.execute(query).scalars().all()
latest_results = (await session.execute(query)).scalars().all()
# 将结果按时间正序排列
results = sorted(latest_results, key=lambda msg: msg.time)
except Exception as e:
@@ -128,12 +128,12 @@ def find_messages(
if sort_terms:
query = query.order_by(*sort_terms)
try:
results = session.execute(query).scalars().all()
results = (await session.execute(query)).scalars().all()
except Exception as e:
logger.error(f"执行无限制查询失败: {e}")
results = []
return [_model_to_dict(msg) for msg in results]
return [_model_to_dict(msg) for msg in results]
except Exception as e:
log_message = (
f"使用 SQLAlchemy 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n"
@@ -143,7 +143,7 @@ def find_messages(
return []
def count_messages(message_filter: dict[str, Any]) -> int:
async def count_messages(message_filter: dict[str, Any]) -> int:
"""
根据提供的过滤器计算消息数量。
@@ -154,7 +154,7 @@ def count_messages(message_filter: dict[str, Any]) -> int:
符合条件的消息数量,如果出错则返回 0。
"""
try:
with get_db_session() as session:
async with get_db_session() as session:
query = select(func.count(Messages.id))
# 应用过滤器
@@ -192,7 +192,7 @@ def count_messages(message_filter: dict[str, Any]) -> int:
if conditions:
query = query.where(*conditions)
count = session.execute(query).scalar()
count = (await session.execute(query)).scalar()
return count or 0
except Exception as e:
log_message = f"使用 SQLAlchemy 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}"

View File

@@ -40,6 +40,9 @@ if not global_config.memory.enable_memory:
def initialize(self):
pass
async def initialize_async(self):
pass
def get_hippocampus(self):
return None
@@ -248,7 +251,7 @@ MoFox_Bot(第三方修改版)
logger.info("聊天管理器初始化成功")
# 初始化记忆系统
self.hippocampus_manager.initialize()
await self.hippocampus_manager.initialize_async()
logger.info("记忆系统初始化成功")
# 初始化LPMM知识库
@@ -283,7 +286,7 @@ MoFox_Bot(第三方修改版)
if global_config.planning_system.monthly_plan_enable:
logger.info("正在初始化月度计划管理器...")
try:
await monthly_plan_manager.start_monthly_plan_generation()
await monthly_plan_manager.initialize()
logger.info("月度计划管理器初始化成功")
except Exception as e:
logger.error(f"月度计划管理器初始化失败: {e}")
@@ -291,8 +294,7 @@ MoFox_Bot(第三方修改版)
# 初始化日程管理器
if global_config.planning_system.schedule_enable:
logger.info("日程表功能已启用,正在初始化管理器...")
await schedule_manager.load_or_generate_today_schedule()
await schedule_manager.start_daily_schedule_generation()
await schedule_manager.initialize()
logger.info("日程表管理器初始化成功。")
try:

View File

@@ -401,14 +401,15 @@ class PersonInfoManager:
# # 初始化时读取所有person_name
try:
pass
# 在这里获取会话
with get_db_session() as session:
for record in session.execute(
select(PersonInfo.person_id, PersonInfo.person_name).where(PersonInfo.person_name.is_not(None))
).fetchall():
if record.person_name:
self.person_name_list[record.person_id] = record.person_name
logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称 (SQLAlchemy)")
# with get_db_session() as session:
# for record in session.execute(
# select(PersonInfo.person_id, PersonInfo.person_name).where(PersonInfo.person_name.is_not(None))
# ).fetchall():
# if record.person_name:
# self.person_name_list[record.person_id] = record.person_name
# logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称 (SQLAlchemy)")
except Exception as e:
logger.error(f"从 SQLAlchemy 加载 person_name_list 失败: {e}")
@@ -430,23 +431,25 @@ class PersonInfoManager:
"""判断是否认识某人"""
person_id = self.get_person_id(platform, user_id)
def _db_check_known_sync(p_id: str):
async def _db_check_known_async(p_id: str):
# 在需要时获取会话
with get_db_session() as session:
return session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar() is not None
async with get_db_session() as session:
return (
await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))
).scalar() is not None
try:
return await asyncio.to_thread(_db_check_known_sync, person_id)
return await _db_check_known_async(person_id)
except Exception as e:
logger.error(f"检查用户 {person_id} 是否已知时出错 (SQLAlchemy): {e}")
return False
def get_person_id_by_person_name(self, person_name: str) -> str:
async def get_person_id_by_person_name(self, person_name: str) -> str:
"""根据用户名获取用户ID"""
try:
# 在需要时获取会话
with get_db_session() as session:
record = session.execute(select(PersonInfo).where(PersonInfo.person_name == person_name)).scalar()
async with get_db_session() as session:
record = (await session.execute(select(PersonInfo).where(PersonInfo.person_name == person_name))).scalar()
return record.person_id if record else ""
except Exception as e:
logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (SQLAlchemy): {e}")
@@ -500,19 +503,18 @@ class PersonInfoManager:
final_data[key] = orjson.dumps([]).decode("utf-8")
# If it's already a string, assume it's valid JSON or a non-JSON string field
def _db_create_sync(p_data: dict):
with get_db_session() as session:
async def _db_create_async(p_data: dict):
async with get_db_session() as session:
try:
new_person = PersonInfo(**p_data)
session.add(new_person)
session.commit()
await session.commit()
return True
except Exception as e:
logger.error(f"创建 PersonInfo 记录 {p_data.get('person_id')} 失败 (SQLAlchemy): {e}")
return False
await asyncio.to_thread(_db_create_sync, final_data)
await _db_create_async(final_data)
async def _safe_create_person_info(self, person_id: str, data: Optional[dict] = None):
"""安全地创建用户信息,处理竞态条件"""
@@ -557,11 +559,11 @@ class PersonInfoManager:
elif final_data[key] is None: # Default for lists is [], store as "[]"
final_data[key] = orjson.dumps([]).decode("utf-8")
def _db_safe_create_sync(p_data: dict):
with get_db_session() as session:
async def _db_safe_create_async(p_data: dict):
async with get_db_session() as session:
try:
existing = session.execute(
select(PersonInfo).where(PersonInfo.person_id == p_data["person_id"])
existing = (
await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_data["person_id"]))
).scalar()
if existing:
logger.debug(f"用户 {p_data['person_id']} 已存在,跳过创建")
@@ -570,18 +572,17 @@ class PersonInfoManager:
# 尝试创建
new_person = PersonInfo(**p_data)
session.add(new_person)
session.commit()
await session.commit()
return True
except Exception as e:
if "UNIQUE constraint failed" in str(e):
logger.debug(f"检测到并发创建用户 {p_data.get('person_id')},跳过错误")
return True # 其他协程已创建,视为成功
return True
else:
logger.error(f"创建 PersonInfo 记录 {p_data.get('person_id')} 失败 (SQLAlchemy): {e}")
return False
await asyncio.to_thread(_db_safe_create_sync, final_data)
await _db_safe_create_async(final_data)
async def update_one_field(self, person_id: str, field_name: str, value, data: Optional[Dict] = None):
"""更新某一个字段,会补全"""
@@ -598,37 +599,33 @@ class PersonInfoManager:
elif value is None: # Store None as "[]" for JSON list fields
processed_value = orjson.dumps([]).decode("utf-8")
def _db_update_sync(p_id: str, f_name: str, val_to_set):
async def _db_update_async(p_id: str, f_name: str, val_to_set):
start_time = time.time()
with get_db_session() as session:
async with get_db_session() as session:
try:
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
record = (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar()
query_time = time.time()
if record:
setattr(record, f_name, val_to_set)
save_time = time.time()
total_time = save_time - start_time
if total_time > 0.5: # 如果超过500ms就记录日志
if total_time > 0.5:
logger.warning(
f"数据库更新操作耗时 {total_time:.3f}秒 (查询: {query_time - start_time:.3f}s, 保存: {save_time - query_time:.3f}s) person_id={p_id}, field={f_name}"
)
session.commit()
return True, False # Found and updated, no creation needed
await session.commit()
return True, False
else:
total_time = time.time() - start_time
if total_time > 0.5:
logger.warning(f"数据库查询操作耗时 {total_time:.3f}秒 person_id={p_id}, field={f_name}")
return False, True # Not found, needs creation
return False, True
except Exception as e:
total_time = time.time() - start_time
logger.error(f"数据库操作异常,耗时 {total_time:.3f}秒: {e}")
raise
found, needs_creation = await asyncio.to_thread(_db_update_sync, person_id, field_name, processed_value)
found, needs_creation = await _db_update_async(person_id, field_name, processed_value)
if needs_creation:
logger.info(f"{person_id} 不存在,将新建。")
@@ -666,13 +663,13 @@ class PersonInfoManager:
logger.debug(f"检查字段'{field_name}'失败,未在 PersonInfo SQLAlchemy 模型中定义。")
return False
def _db_has_field_sync(p_id: str, f_name: str):
with get_db_session() as session:
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
async def _db_has_field_async(p_id: str, f_name: str):
async with get_db_session() as session:
record = (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar()
return bool(record)
try:
return await asyncio.to_thread(_db_has_field_sync, person_id, field_name)
return await _db_has_field_async(person_id, field_name)
except Exception as e:
logger.error(f"检查字段 {field_name} for {person_id} 时出错 (SQLAlchemy): {e}")
return False
@@ -778,14 +775,14 @@ class PersonInfoManager:
logger.info(f"尝试给用户{user_nickname} {person_id} 取名,但是 {generated_nickname} 已存在,重试中...")
else:
def _db_check_name_exists_sync(name_to_check):
with get_db_session() as session:
async def _db_check_name_exists_async(name_to_check):
async with get_db_session() as session:
return (
session.execute(select(PersonInfo).where(PersonInfo.person_name == name_to_check)).scalar()
(await session.execute(select(PersonInfo).where(PersonInfo.person_name == name_to_check))).scalar()
is not None
)
if await asyncio.to_thread(_db_check_name_exists_sync, generated_nickname):
if await _db_check_name_exists_async(generated_nickname):
is_duplicate = True
current_name_set.add(generated_nickname)
@@ -824,91 +821,26 @@ class PersonInfoManager:
logger.debug("删除失败person_id 不能为空")
return
def _db_delete_sync(p_id: str):
async def _db_delete_async(p_id: str):
try:
with get_db_session() as session:
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
async with get_db_session() as session:
record = (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar()
if record:
session.delete(record)
session.commit()
return 1
await session.delete(record)
await session.commit()
return 1
return 0
except Exception as e:
logger.error(f"删除 PersonInfo {p_id} 失败 (SQLAlchemy): {e}")
return 0
deleted_count = await asyncio.to_thread(_db_delete_sync, person_id)
deleted_count = await _db_delete_async(person_id)
if deleted_count > 0:
logger.debug(f"删除成功person_id={person_id} (Peewee)")
logger.debug(f"删除成功person_id={person_id}")
else:
logger.debug(f"删除失败:未找到 person_id={person_id} 或删除未影响行 (Peewee)")
logger.debug(f"删除失败:未找到 person_id={person_id} 或删除未影响行")
@staticmethod
async def get_value(person_id: str, field_name: str):
"""获取指定用户指定字段的值"""
default_value_for_field = person_info_default.get(field_name)
if field_name in JSON_SERIALIZED_FIELDS and default_value_for_field is None:
default_value_for_field = [] # Ensure JSON fields default to [] if not in DB
def _db_get_value_sync(p_id: str, f_name: str):
with get_db_session() as session:
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
if record:
val = getattr(record, f_name, None)
if f_name in JSON_SERIALIZED_FIELDS:
if isinstance(val, str):
try:
return orjson.loads(val)
except orjson.JSONDecodeError:
logger.warning(f"字段 {f_name} for {p_id} 包含无效JSON: {val}. 返回默认值.")
return [] # Default for JSON fields on error
elif val is None: # Field exists in DB but is None
return [] # Default for JSON fields
# If val is already a list/dict (e.g. if somehow set without serialization)
return val # Should ideally not happen if update_one_field is always used
return val
return None # Record not found
try:
value_from_db = await asyncio.to_thread(_db_get_value_sync, person_id, field_name)
if value_from_db is not None:
return value_from_db
if field_name in person_info_default:
return default_value_for_field
logger.warning(f"字段 {field_name} 在 person_info_default 中未定义,且在数据库中未找到。")
return None # Ultimate fallback
except Exception as e:
logger.error(f"获取字段 {field_name} for {person_id} 时出错 (Peewee): {e}")
# Fallback to default in case of any error during DB access
return default_value_for_field if field_name in person_info_default else None
@staticmethod
def get_value_sync(person_id: str, field_name: str):
"""同步获取指定用户指定字段的值"""
default_value_for_field = person_info_default.get(field_name)
with get_db_session() as session:
if field_name in JSON_SERIALIZED_FIELDS and default_value_for_field is None:
default_value_for_field = []
if record := session.execute(select(PersonInfo).where(PersonInfo.person_id == person_id)).scalar():
val = getattr(record, field_name, None)
if field_name in JSON_SERIALIZED_FIELDS:
if isinstance(val, str):
try:
return orjson.loads(val)
except orjson.JSONDecodeError:
logger.warning(f"字段 {field_name} for {person_id} 包含无效JSON: {val}. 返回默认值.")
return []
elif val is None:
return []
return val
return val
if field_name in person_info_default:
return default_value_for_field
logger.warning(f"字段 {field_name} 在 person_info_default 中未定义,且在数据库中未找到。")
return None
@staticmethod
async def get_values(person_id: str, field_names: list) -> dict:
@@ -919,11 +851,11 @@ class PersonInfoManager:
result = {}
def _db_get_record_sync(p_id: str):
with get_db_session() as session:
return session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
async def _db_get_record_async(p_id: str):
async with get_db_session() as session:
return (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar()
record = await asyncio.to_thread(_db_get_record_sync, person_id)
record = await _db_get_record_async(person_id)
# 获取 SQLAlchemy 模型的所有字段名
model_fields = [column.name for column in PersonInfo.__table__.columns]
@@ -960,14 +892,15 @@ class PersonInfoManager:
# 获取 SQLAlchemy 模型的所有字段名
model_fields = [column.name for column in PersonInfo.__table__.columns]
if field_name not in model_fields:
logger.error(f"字段检查失败:'{field_name}'未在 PersonInfo SQLAlchemy 模 modelo中定义")
logger.error(f"字段检查失败:'{field_name}'未在 PersonInfo SQLAlchemy 模中定义")
return {}
def _db_get_specific_sync(f_name: str):
async def _db_get_specific_async(f_name: str):
found_results = {}
try:
with get_db_session() as session:
for record in session.execute(select(PersonInfo.person_id, getattr(PersonInfo, f_name))).fetchall():
async with get_db_session() as session:
result = await session.execute(select(PersonInfo.person_id, getattr(PersonInfo, f_name)))
for record in result.fetchall():
value = getattr(record, f_name)
if way(value):
found_results[record.person_id] = value
@@ -978,9 +911,9 @@ class PersonInfoManager:
return found_results
try:
return await asyncio.to_thread(_db_get_specific_sync, field_name)
return await _db_get_specific_async(field_name)
except Exception as e:
logger.error(f"执行 get_specific_value_list 线程时出错: {str(e)}", exc_info=True)
logger.error(f"执行 get_specific_value_list 时出错: {str(e)}", exc_info=True)
return {}
async def get_or_create_person(
@@ -993,40 +926,38 @@ class PersonInfoManager:
"""
person_id = self.get_person_id(platform, user_id)
def _db_get_or_create_sync(p_id: str, init_data: dict):
async def _db_get_or_create_async(p_id: str, init_data: dict):
"""原子性的获取或创建操作"""
with get_db_session() as session:
async with get_db_session() as session:
# 首先尝试获取现有记录
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
record = (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar()
if record:
return record, False # 记录存在,未创建
# 记录不存在,尝试创建
try:
new_person = PersonInfo(**init_data)
session.add(new_person)
session.commit()
return session.execute(
select(PersonInfo).where(PersonInfo.person_id == p_id)
).scalar(), True # 创建成功
except Exception as e:
# 如果创建失败(可能是因为竞态条件),再次尝试获取
if "UNIQUE constraint failed" in str(e):
logger.debug(f"检测到并发创建用户 {p_id},获取现有记录")
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
if record:
return record, False # 其他协程已创建,返回现有记录
# 如果仍然失败,重新抛出异常
raise e
# 记录不存在,尝试创建
try:
new_person = PersonInfo(**init_data)
session.add(new_person)
await session.commit()
await session.refresh(new_person)
return new_person, True # 创建成功
except Exception as e:
# 如果创建失败(可能是因为竞态条件),再次尝试获取
if "UNIQUE constraint failed" in str(e):
logger.debug(f"检测到并发创建用户 {p_id},获取现有记录")
record = (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar()
if record:
return record, False # 其他协程已创建,返回现有记录
# 如果仍然失败,重新抛出异常
raise e
unique_nickname = await self._generate_unique_person_name(nickname)
initial_data = {
"person_id": person_id,
"platform": platform,
"user_id": str(user_id),
"nickname": nickname,
"person_name": unique_nickname, # 使用群昵称作为person_name
"person_name": unique_nickname,
"name_reason": "从群昵称获取",
"know_times": 0,
"know_since": int(datetime.datetime.now().timestamp()),
@@ -1036,7 +967,6 @@ class PersonInfoManager:
"forgotten_points": [],
}
# 序列化JSON字段
for key in JSON_SERIALIZED_FIELDS:
if key in initial_data:
if isinstance(initial_data[key], (list, dict)):
@@ -1044,15 +974,14 @@ class PersonInfoManager:
elif initial_data[key] is None:
initial_data[key] = orjson.dumps([]).decode("utf-8")
# 获取 SQLAlchemy 模odel的所有字段名
model_fields = [column.name for column in PersonInfo.__table__.columns]
filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields}
record, was_created = await asyncio.to_thread(_db_get_or_create_sync, person_id, filtered_initial_data)
record, was_created = await _db_get_or_create_async(person_id, filtered_initial_data)
if was_created:
logger.info(f"用户 {platform}:{user_id} (person_id: {person_id}) 不存在,将创建新记录 (Peewee)")
logger.info(f"已为 {person_id} 创建新记录,初始数据 (filtered for model): {filtered_initial_data}")
logger.info(f"用户 {platform}:{user_id} (person_id: {person_id}) 不存在,将创建新记录。")
logger.info(f"已为 {person_id} 创建新记录,初始数据: {filtered_initial_data}")
else:
logger.debug(f"用户 {platform}:{user_id} (person_id: {person_id}) 已存在,返回现有记录。")
@@ -1072,11 +1001,13 @@ class PersonInfoManager:
if not found_person_id:
def _db_find_by_name_sync(p_name_to_find: str):
with get_db_session() as session:
return session.execute(select(PersonInfo).where(PersonInfo.person_name == p_name_to_find)).scalar()
async def _db_find_by_name_async(p_name_to_find: str):
async with get_db_session() as session:
return (
await session.execute(select(PersonInfo).where(PersonInfo.person_name == p_name_to_find))
).scalar()
record = await asyncio.to_thread(_db_find_by_name_sync, person_name)
record = await _db_find_by_name_async(person_name)
if record:
found_person_id = record.person_id
if (

View File

@@ -0,0 +1,457 @@
import time
import traceback
import orjson
import random
from typing import List, Dict, Any
from json_repair import repair_json
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.person_info.person_info import get_person_info_manager
logger = get_logger("relationship_fetcher")
def init_real_time_info_prompts():
"""初始化实时信息提取相关的提示词"""
relationship_prompt = """
<聊天记录>
{chat_observe_info}
</聊天记录>
{name_block}
现在,你想要回复{person_name}的消息,消息内容是:{target_message}。请根据聊天记录和你要回复的消息,从你对{person_name}的了解中提取有关的信息:
1.你需要提供你想要提取的信息具体是哪方面的信息,例如:年龄,性别,你们之间的交流方式,最近发生的事等等。
2.请注意,请不要重复调取相同的信息,已经调取的信息如下:
{info_cache_block}
3.如果当前聊天记录中没有需要查询的信息,或者现有信息已经足够回复,请返回{{"none": "不需要查询"}}
请以json格式输出例如
{{
"info_type": "信息类型",
}}
请严格按照json输出格式不要输出多余内容
"""
Prompt(relationship_prompt, "real_time_info_identify_prompt")
fetch_info_prompt = """
{name_block}
以下是你在之前与{person_name}的交流中,产生的对{person_name}的了解:
{person_impression_block}
{points_text_block}
请从中提取用户"{person_name}"的有关"{info_type}"信息
请以json格式输出例如
{{
{info_json_str}
}}
请严格按照json输出格式不要输出多余内容
"""
Prompt(fetch_info_prompt, "real_time_fetch_person_info_prompt")
class RelationshipFetcher:
def __init__(self, chat_id):
self.chat_id = chat_id
# 信息获取缓存:记录正在获取的信息请求
self.info_fetching_cache: List[Dict[str, Any]] = []
# 信息结果缓存存储已获取的信息结果带TTL
self.info_fetched_cache: Dict[str, Dict[str, Any]] = {}
# 结构:{person_id: {info_type: {"info": str, "ttl": int, "start_time": float, "person_name": str, "unknown": bool}}}
# LLM模型配置
self.llm_model = LLMRequest(
model_set=model_config.model_task_config.utils_small, request_type="relation.fetcher"
)
# 小模型用于即时信息提取
self.instant_llm_model = LLMRequest(
model_set=model_config.model_task_config.utils_small, request_type="relation.fetch"
)
name = get_chat_manager().get_stream_name(self.chat_id)
self.log_prefix = f"[{name}] 实时信息"
def _cleanup_expired_cache(self):
"""清理过期的信息缓存"""
for person_id in list(self.info_fetched_cache.keys()):
for info_type in list(self.info_fetched_cache[person_id].keys()):
self.info_fetched_cache[person_id][info_type]["ttl"] -= 1
if self.info_fetched_cache[person_id][info_type]["ttl"] <= 0:
del self.info_fetched_cache[person_id][info_type]
if not self.info_fetched_cache[person_id]:
del self.info_fetched_cache[person_id]
async def build_relation_info(self, person_id, points_num=3):
# 清理过期的信息缓存
self._cleanup_expired_cache()
person_info_manager = get_person_info_manager()
person_info = await person_info_manager.get_values(
person_id, ["person_name", "short_impression", "nickname", "platform", "points"]
)
person_name = person_info.get("person_name")
short_impression = person_info.get("short_impression")
nickname_str = person_info.get("nickname")
platform = person_info.get("platform")
if person_name == nickname_str and not short_impression:
return ""
current_points = person_info.get("points") or []
# 按时间排序forgotten_points
current_points.sort(key=lambda x: x[2])
# 按权重加权随机抽取最多3个不重复的pointspoint[1]的值在1-10之间权重越高被抽到概率越大
if len(current_points) > points_num:
# point[1] 取值范围1-10直接作为权重
weights = [max(1, min(10, int(point[1]))) for point in current_points]
# 使用加权采样不放回,保证不重复
indices = list(range(len(current_points)))
points = []
for _ in range(points_num):
if not indices:
break
sub_weights = [weights[i] for i in indices]
chosen_idx = random.choices(indices, weights=sub_weights, k=1)[0]
points.append(current_points[chosen_idx])
indices.remove(chosen_idx)
else:
points = current_points
# 构建points文本
points_text = "\n".join([f"{point[2]}{point[0]}" for point in points])
nickname_str = ""
if person_name != nickname_str:
nickname_str = f"(ta在{platform}上的昵称是{nickname_str})"
relation_info = ""
if short_impression and relation_info:
if points_text:
relation_info = f"你对{person_name}的印象是{nickname_str}{short_impression}。具体来说:{relation_info}。你还记得ta最近做的事{points_text}"
else:
relation_info = (
f"你对{person_name}的印象是{nickname_str}{short_impression}。具体来说:{relation_info}"
)
elif short_impression:
if points_text:
relation_info = (
f"你对{person_name}的印象是{nickname_str}{short_impression}。你还记得ta最近做的事{points_text}"
)
else:
relation_info = f"你对{person_name}的印象是{nickname_str}{short_impression}"
elif relation_info:
if points_text:
relation_info = (
f"你对{person_name}的了解{nickname_str}{relation_info}。你还记得ta最近做的事{points_text}"
)
else:
relation_info = f"你对{person_name}的了解{nickname_str}{relation_info}"
elif points_text:
relation_info = f"你记得{person_name}{nickname_str}最近做的事:{points_text}"
else:
relation_info = ""
return relation_info
async def _build_fetch_query(self, person_id, target_message, chat_history):
nickname_str = ",".join(global_config.bot.alias_names)
name_block = f"你的名字是{global_config.bot.nickname},你的昵称有{nickname_str},有人也会用这些昵称称呼你。"
person_info_manager = get_person_info_manager()
person_info = await person_info_manager.get_values(person_id, ["person_name"])
person_name: str = person_info.get("person_name") # type: ignore
info_cache_block = self._build_info_cache_block()
prompt = (await global_prompt_manager.get_prompt_async("real_time_info_identify_prompt")).format(
chat_observe_info=chat_history,
name_block=name_block,
info_cache_block=info_cache_block,
person_name=person_name,
target_message=target_message,
)
try:
logger.debug(f"{self.log_prefix} 信息识别prompt: \n{prompt}\n")
content, _ = await self.llm_model.generate_response_async(prompt=prompt)
if content:
content_json = orjson.loads(repair_json(content))
# 检查是否返回了不需要查询的标志
if "none" in content_json:
logger.debug(f"{self.log_prefix} LLM判断当前不需要查询任何信息{content_json.get('none', '')}")
return None
if info_type := content_json.get("info_type"):
# 记录信息获取请求
self.info_fetching_cache.append(
{
"person_id": get_person_info_manager().get_person_id_by_person_name(person_name),
"person_name": person_name,
"info_type": info_type,
"start_time": time.time(),
"forget": False,
}
)
# 限制缓存大小
if len(self.info_fetching_cache) > 10:
self.info_fetching_cache.pop(0)
logger.info(f"{self.log_prefix} 识别到需要调取用户 {person_name} 的[{info_type}]信息")
return info_type
else:
logger.warning(f"{self.log_prefix} LLM未返回有效的info_type。响应: {content}")
except Exception as e:
logger.error(f"{self.log_prefix} 执行信息识别LLM请求时出错: {e}")
logger.error(traceback.format_exc())
return None
def _build_info_cache_block(self) -> str:
"""构建已获取信息的缓存块"""
info_cache_block = ""
if self.info_fetching_cache:
# 对于每个(person_id, info_type)组合,只保留最新的记录
latest_records = {}
for info_fetching in self.info_fetching_cache:
key = (info_fetching["person_id"], info_fetching["info_type"])
if key not in latest_records or info_fetching["start_time"] > latest_records[key]["start_time"]:
latest_records[key] = info_fetching
# 按时间排序并生成显示文本
sorted_records = sorted(latest_records.values(), key=lambda x: x["start_time"])
for info_fetching in sorted_records:
info_cache_block += (
f"你已经调取了[{info_fetching['person_name']}]的[{info_fetching['info_type']}]信息\n"
)
return info_cache_block
async def _extract_single_info(self, person_id: str, info_type: str, person_name: str):
"""提取单个信息类型
Args:
person_id: 用户ID
info_type: 信息类型
person_name: 用户名
"""
start_time = time.time()
person_info_manager = get_person_info_manager()
# 首先检查 info_list 缓存
person_info = await person_info_manager.get_values(person_id, ["info_list"])
info_list = person_info.get("info_list") or []
cached_info = None
# 查找对应的 info_type
for info_item in info_list:
if info_item.get("info_type") == info_type:
cached_info = info_item.get("info_content")
logger.debug(f"{self.log_prefix} 在info_list中找到 {person_name}{info_type} 信息: {cached_info}")
break
# 如果缓存中有信息,直接使用
if cached_info:
if person_id not in self.info_fetched_cache:
self.info_fetched_cache[person_id] = {}
self.info_fetched_cache[person_id][info_type] = {
"info": cached_info,
"ttl": 2,
"start_time": start_time,
"person_name": person_name,
"unknown": cached_info == "none",
}
logger.info(f"{self.log_prefix} 记得 {person_name}{info_type}: {cached_info}")
return
# 如果缓存中没有,尝试从用户档案中提取
try:
person_info = await person_info_manager.get_values(person_id, ["impression", "points"])
person_impression = person_info.get("impression")
points = person_info.get("points")
# 构建印象信息块
if person_impression:
person_impression_block = (
f"<对{person_name}的总体了解>\n{person_impression}\n</对{person_name}的总体了解>"
)
else:
person_impression_block = ""
# 构建要点信息块
if points:
points_text = "\n".join([f"{point[2]}:{point[0]}" for point in points])
points_text_block = f"<对{person_name}的近期了解>\n{points_text}\n</对{person_name}的近期了解>"
else:
points_text_block = ""
# 如果完全没有用户信息
if not points_text_block and not person_impression_block:
if person_id not in self.info_fetched_cache:
self.info_fetched_cache[person_id] = {}
self.info_fetched_cache[person_id][info_type] = {
"info": "none",
"ttl": 2,
"start_time": start_time,
"person_name": person_name,
"unknown": True,
}
logger.info(f"{self.log_prefix} 完全不认识 {person_name}")
await self._save_info_to_cache(person_id, info_type, "none")
return
# 使用LLM提取信息
nickname_str = ",".join(global_config.bot.alias_names)
name_block = f"你的名字是{global_config.bot.nickname},你的昵称有{nickname_str},有人也会用这些昵称称呼你。"
prompt = (await global_prompt_manager.get_prompt_async("real_time_fetch_person_info_prompt")).format(
name_block=name_block,
info_type=info_type,
person_impression_block=person_impression_block,
person_name=person_name,
info_json_str=f'"{info_type}": "有关{info_type}的信息内容"',
points_text_block=points_text_block,
)
# 使用小模型进行即时提取
content, _ = await self.instant_llm_model.generate_response_async(prompt=prompt)
if content:
content_json = orjson.loads(repair_json(content))
if info_type in content_json:
info_content = content_json[info_type]
is_unknown = info_content == "none" or not info_content
# 保存到运行时缓存
if person_id not in self.info_fetched_cache:
self.info_fetched_cache[person_id] = {}
self.info_fetched_cache[person_id][info_type] = {
"info": "unknown" if is_unknown else info_content,
"ttl": 3,
"start_time": start_time,
"person_name": person_name,
"unknown": is_unknown,
}
# 保存到持久化缓存 (info_list)
await self._save_info_to_cache(person_id, info_type, "none" if is_unknown else info_content)
if not is_unknown:
logger.info(f"{self.log_prefix} 思考得到,{person_name}{info_type}: {info_content}")
else:
logger.info(f"{self.log_prefix} 思考了也不知道{person_name}{info_type} 信息")
else:
logger.warning(f"{self.log_prefix} 小模型返回空结果,获取 {person_name}{info_type} 信息失败。")
except Exception as e:
logger.error(f"{self.log_prefix} 执行信息提取时出错: {e}")
logger.error(traceback.format_exc())
async def _save_info_to_cache(self, person_id: str, info_type: str, info_content: str):
# sourcery skip: use-next
"""将提取到的信息保存到 person_info 的 info_list 字段中
Args:
person_id: 用户ID
info_type: 信息类型
info_content: 信息内容
"""
try:
person_info_manager = get_person_info_manager()
# 获取现有的 info_list
person_info = await person_info_manager.get_values(person_id, ["info_list"])
info_list = person_info.get("info_list") or []
# 查找是否已存在相同 info_type 的记录
found_index = -1
for i, info_item in enumerate(info_list):
if isinstance(info_item, dict) and info_item.get("info_type") == info_type:
found_index = i
break
# 创建新的信息记录
new_info_item = {
"info_type": info_type,
"info_content": info_content,
}
if found_index >= 0:
# 更新现有记录
info_list[found_index] = new_info_item
logger.info(f"{self.log_prefix} [缓存更新] 更新 {person_id}{info_type} 信息缓存")
else:
# 添加新记录
info_list.append(new_info_item)
logger.info(f"{self.log_prefix} [缓存保存] 新增 {person_id}{info_type} 信息缓存")
# 保存更新后的 info_list
await person_info_manager.update_one_field(person_id, "info_list", info_list)
except Exception as e:
logger.error(f"{self.log_prefix} [缓存保存] 保存信息到缓存失败: {e}")
logger.error(traceback.format_exc())
class RelationshipFetcherManager:
"""关系提取器管理器
管理不同 chat_id 的 RelationshipFetcher 实例
"""
def __init__(self):
self._fetchers: Dict[str, RelationshipFetcher] = {}
def get_fetcher(self, chat_id: str) -> RelationshipFetcher:
"""获取或创建指定 chat_id 的 RelationshipFetcher
Args:
chat_id: 聊天ID
Returns:
RelationshipFetcher: 关系提取器实例
"""
if chat_id not in self._fetchers:
self._fetchers[chat_id] = RelationshipFetcher(chat_id)
return self._fetchers[chat_id]
def remove_fetcher(self, chat_id: str):
"""移除指定 chat_id 的 RelationshipFetcher
Args:
chat_id: 聊天ID
"""
if chat_id in self._fetchers:
del self._fetchers[chat_id]
def clear_all(self):
"""清空所有 RelationshipFetcher"""
self._fetchers.clear()
def get_active_chat_ids(self) -> List[str]:
"""获取所有活跃的 chat_id 列表"""
return list(self._fetchers.keys())
# 全局管理器实例
relationship_fetcher_manager = RelationshipFetcherManager()
init_real_time_info_prompts()

View File

@@ -62,7 +62,7 @@ def get_messages_by_time(
return get_raw_msg_by_timestamp(start_time, end_time, limit, limit_mode)
def get_messages_by_time_in_chat(
async def get_messages_by_time_in_chat(
chat_id: str,
start_time: float,
end_time: float,
@@ -97,13 +97,13 @@ def get_messages_by_time_in_chat(
if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型")
if filter_mai:
return filter_mai_messages(
get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time, limit, limit_mode, filter_command)
return await filter_mai_messages(
await get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time, limit, limit_mode, filter_command)
)
return get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time, limit, limit_mode, filter_command)
return await get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time, limit, limit_mode, filter_command)
def get_messages_by_time_in_chat_inclusive(
async def get_messages_by_time_in_chat_inclusive(
chat_id: str,
start_time: float,
end_time: float,
@@ -138,12 +138,12 @@ def get_messages_by_time_in_chat_inclusive(
if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型")
if filter_mai:
return filter_mai_messages(
get_raw_msg_by_timestamp_with_chat_inclusive(
return await filter_mai_messages(
await get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id, start_time, end_time, limit, limit_mode, filter_command
)
)
return get_raw_msg_by_timestamp_with_chat_inclusive(
return await get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id, start_time, end_time, limit, limit_mode, filter_command
)
@@ -478,7 +478,7 @@ async def get_person_ids_from_messages(messages: List[Dict[str, Any]]) -> List[s
# =============================================================================
def filter_mai_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
async def filter_mai_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
从消息列表中移除麦麦的消息
Args:

View File

@@ -1,6 +1,7 @@
# mmc/src/schedule/database.py
from typing import List
from sqlalchemy import select, func, update, delete
from src.common.database.sqlalchemy_models import MonthlyPlan, get_db_session
from src.common.logger import get_logger
from src.config.config import global_config
@@ -8,21 +9,22 @@ from src.config.config import global_config
logger = get_logger("schedule_database")
def add_new_plans(plans: List[str], month: str):
async def add_new_plans(plans: List[str], month: str):
"""
批量添加新生成的月度计划到数据库,并确保不超过上限。
:param plans: 计划内容列表。
:param month: 目标月份,格式为 "YYYY-MM"
"""
with get_db_session() as session:
async with get_db_session() as session:
try:
# 1. 获取当前有效计划数量(状态为 'active'
current_plan_count = (
session.query(MonthlyPlan)
.filter(MonthlyPlan.target_month == month, MonthlyPlan.status == "active")
.count()
result = await session.execute(
select(func.count(MonthlyPlan.id)).where(
MonthlyPlan.target_month == month, MonthlyPlan.status == "active"
)
)
current_plan_count = result.scalar_one()
# 2. 从配置获取上限
max_plans = global_config.planning_system.max_plans_per_month
@@ -41,7 +43,7 @@ def add_new_plans(plans: List[str], month: str):
MonthlyPlan(plan_text=plan, target_month=month, status="active") for plan in plans_to_add
]
session.add_all(new_plan_objects)
session.commit()
await session.commit()
logger.info(f"成功向数据库添加了 {len(new_plan_objects)}{month} 的月度计划。")
if len(plans) > len(plans_to_add):
@@ -49,32 +51,31 @@ def add_new_plans(plans: List[str], month: str):
except Exception as e:
logger.error(f"添加月度计划时发生错误: {e}")
session.rollback()
await session.rollback()
raise
def get_active_plans_for_month(month: str) -> List[MonthlyPlan]:
async def get_active_plans_for_month(month: str) -> List[MonthlyPlan]:
"""
获取指定月份所有状态为 'active' 的计划。
:param month: 目标月份,格式为 "YYYY-MM"
:return: MonthlyPlan 对象列表。
"""
with get_db_session() as session:
async with get_db_session() as session:
try:
plans = (
session.query(MonthlyPlan)
.filter(MonthlyPlan.target_month == month, MonthlyPlan.status == "active")
result = await session.execute(
select(MonthlyPlan)
.where(MonthlyPlan.target_month == month, MonthlyPlan.status == "active")
.order_by(MonthlyPlan.created_at.desc())
.all()
)
return plans
return result.scalars().all()
except Exception as e:
logger.error(f"查询 {month} 的有效月度计划时发生错误: {e}")
return []
def mark_plans_completed(plan_ids: List[int]):
async def mark_plans_completed(plan_ids: List[int]):
"""
将指定ID的计划标记为已完成。
@@ -83,9 +84,10 @@ def mark_plans_completed(plan_ids: List[int]):
if not plan_ids:
return
with get_db_session() as session:
async with get_db_session() as session:
try:
plans_to_mark = session.query(MonthlyPlan).filter(MonthlyPlan.id.in_(plan_ids)).all()
result = await session.execute(select(MonthlyPlan).where(MonthlyPlan.id.in_(plan_ids)))
plans_to_mark = result.scalars().all()
if not plans_to_mark:
logger.info("没有需要标记为完成的月度计划。")
return
@@ -93,17 +95,17 @@ def mark_plans_completed(plan_ids: List[int]):
plan_details = "\n".join([f" {i + 1}. {plan.plan_text}" for i, plan in enumerate(plans_to_mark)])
logger.info(f"以下 {len(plans_to_mark)} 条月度计划将被标记为已完成:\n{plan_details}")
session.query(MonthlyPlan).filter(MonthlyPlan.id.in_(plan_ids)).update(
{"status": "completed"}, synchronize_session=False
await session.execute(
update(MonthlyPlan).where(MonthlyPlan.id.in_(plan_ids)).values(status="completed")
)
session.commit()
await session.commit()
except Exception as e:
logger.error(f"标记月度计划为完成时发生错误: {e}")
session.rollback()
await session.rollback()
raise
def delete_plans_by_ids(plan_ids: List[int]):
async def delete_plans_by_ids(plan_ids: List[int]):
"""
根据ID列表从数据库中物理删除月度计划。
@@ -112,10 +114,11 @@ def delete_plans_by_ids(plan_ids: List[int]):
if not plan_ids:
return
with get_db_session() as session:
async with get_db_session() as session:
try:
# 先查询要删除的计划,用于日志记录
plans_to_delete = session.query(MonthlyPlan).filter(MonthlyPlan.id.in_(plan_ids)).all()
result = await session.execute(select(MonthlyPlan).where(MonthlyPlan.id.in_(plan_ids)))
plans_to_delete = result.scalars().all()
if not plans_to_delete:
logger.info("没有找到需要删除的月度计划。")
return
@@ -124,16 +127,16 @@ def delete_plans_by_ids(plan_ids: List[int]):
logger.info(f"检测到月度计划超额,将删除以下 {len(plans_to_delete)} 条计划:\n{plan_details}")
# 执行删除
session.query(MonthlyPlan).filter(MonthlyPlan.id.in_(plan_ids)).delete(synchronize_session=False)
session.commit()
await session.execute(delete(MonthlyPlan).where(MonthlyPlan.id.in_(plan_ids)))
await session.commit()
except Exception as e:
logger.error(f"删除月度计划时发生错误: {e}")
session.rollback()
await session.rollback()
raise
def update_plan_usage(plan_ids: List[int], used_date: str):
async def update_plan_usage(plan_ids: List[int], used_date: str):
"""
更新计划的使用统计信息。
@@ -143,44 +146,47 @@ def update_plan_usage(plan_ids: List[int], used_date: str):
if not plan_ids:
return
with get_db_session() as session:
async with get_db_session() as session:
try:
# 获取完成阈值配置,如果不存在则使用默认值
completion_threshold = getattr(global_config.planning_system, "completion_threshold", 3)
# 批量更新使用次数和最后使用日期
session.query(MonthlyPlan).filter(MonthlyPlan.id.in_(plan_ids)).update(
{"usage_count": MonthlyPlan.usage_count + 1, "last_used_date": used_date}, synchronize_session=False
await session.execute(
update(MonthlyPlan)
.where(MonthlyPlan.id.in_(plan_ids))
.values(usage_count=MonthlyPlan.usage_count + 1, last_used_date=used_date)
)
# 检查是否有计划达到完成阈值
plans_to_complete = (
session.query(MonthlyPlan)
.filter(
result = await session.execute(
select(MonthlyPlan).where(
MonthlyPlan.id.in_(plan_ids),
MonthlyPlan.usage_count >= completion_threshold,
MonthlyPlan.status == "active",
)
.all()
)
plans_to_complete = result.scalars().all()
if plans_to_complete:
completed_ids = [plan.id for plan in plans_to_complete]
session.query(MonthlyPlan).filter(MonthlyPlan.id.in_(completed_ids)).update(
{"status": "completed"}, synchronize_session=False
await session.execute(
update(MonthlyPlan).where(MonthlyPlan.id.in_(completed_ids)).values(status="completed")
)
logger.info(f"计划 {completed_ids} 已达到使用阈值 ({completion_threshold}),标记为已完成。")
session.commit()
await session.commit()
logger.info(f"成功更新了 {len(plan_ids)} 条月度计划的使用统计。")
except Exception as e:
logger.error(f"更新月度计划使用统计时发生错误: {e}")
session.rollback()
await session.rollback()
raise
def get_smart_plans_for_daily_schedule(month: str, max_count: int = 3, avoid_days: int = 7) -> List[MonthlyPlan]:
async def get_smart_plans_for_daily_schedule(
month: str, max_count: int = 3, avoid_days: int = 7
) -> List[MonthlyPlan]:
"""
智能抽取月度计划用于每日日程生成。
@@ -196,19 +202,24 @@ def get_smart_plans_for_daily_schedule(month: str, max_count: int = 3, avoid_day
"""
from datetime import datetime, timedelta
with get_db_session() as session:
async with get_db_session() as session:
try:
# 计算避免重复的日期阈值
avoid_date = (datetime.now() - timedelta(days=avoid_days)).strftime("%Y-%m-%d")
# 查询符合条件的计划
query = session.query(MonthlyPlan).filter(MonthlyPlan.target_month == month, MonthlyPlan.status == "active")
query = select(MonthlyPlan).where(
MonthlyPlan.target_month == month, MonthlyPlan.status == "active"
)
# 排除最近使用过的计划
query = query.filter((MonthlyPlan.last_used_date.is_(None)) | (MonthlyPlan.last_used_date < avoid_date))
query = query.where(
(MonthlyPlan.last_used_date.is_(None)) | (MonthlyPlan.last_used_date < avoid_date)
)
# 按使用次数升序排列,优先选择使用次数少的
plans = query.order_by(MonthlyPlan.usage_count.asc()).all()
result = await session.execute(query.order_by(MonthlyPlan.usage_count.asc()))
plans = result.scalars().all()
if not plans:
logger.info(f"没有找到符合条件的 {month} 月度计划。")
@@ -228,31 +239,31 @@ def get_smart_plans_for_daily_schedule(month: str, max_count: int = 3, avoid_day
return []
def archive_active_plans_for_month(month: str):
async def archive_active_plans_for_month(month: str):
"""
将指定月份所有状态为 'active' 的计划归档为 'archived'
通常在月底调用。
:param month: 目标月份,格式为 "YYYY-MM"
"""
with get_db_session() as session:
async with get_db_session() as session:
try:
updated_count = (
session.query(MonthlyPlan)
.filter(MonthlyPlan.target_month == month, MonthlyPlan.status == "active")
.update({"status": "archived"}, synchronize_session=False)
result = await session.execute(
update(MonthlyPlan)
.where(MonthlyPlan.target_month == month, MonthlyPlan.status == "active")
.values(status="archived")
)
session.commit()
updated_count = result.rowcount
await session.commit()
logger.info(f"成功将 {updated_count}{month} 的活跃月度计划归档。")
return updated_count
except Exception as e:
logger.error(f"归档 {month} 的月度计划时发生错误: {e}")
session.rollback()
await session.rollback()
raise
def get_archived_plans_for_month(month: str) -> List[MonthlyPlan]:
async def get_archived_plans_for_month(month: str) -> List[MonthlyPlan]:
"""
获取指定月份所有状态为 'archived' 的计划。
用于生成下个月计划时的参考。
@@ -260,34 +271,34 @@ def get_archived_plans_for_month(month: str) -> List[MonthlyPlan]:
:param month: 目标月份,格式为 "YYYY-MM"
:return: MonthlyPlan 对象列表。
"""
with get_db_session() as session:
async with get_db_session() as session:
try:
plans = (
session.query(MonthlyPlan)
.filter(MonthlyPlan.target_month == month, MonthlyPlan.status == "archived")
.all()
result = await session.execute(
select(MonthlyPlan).where(
MonthlyPlan.target_month == month, MonthlyPlan.status == "archived"
)
)
return plans
return result.scalars().all()
except Exception as e:
logger.error(f"查询 {month} 的归档月度计划时发生错误: {e}")
return []
def has_active_plans(month: str) -> bool:
async def has_active_plans(month: str) -> bool:
"""
检查指定月份是否存在任何状态为 'active' 的计划。
:param month: 目标月份,格式为 "YYYY-MM"
:return: 如果存在则返回 True否则返回 False。
"""
with get_db_session() as session:
async with get_db_session() as session:
try:
count = (
session.query(MonthlyPlan)
.filter(MonthlyPlan.target_month == month, MonthlyPlan.status == "active")
.count()
result = await session.execute(
select(func.count(MonthlyPlan.id)).where(
MonthlyPlan.target_month == month, MonthlyPlan.status == "active"
)
)
return count > 0
return result.scalar_one() > 0
except Exception as e:
logger.error(f"检查 {month} 的有效月度计划时发生错误: {e}")
return False

View File

@@ -14,6 +14,11 @@ class MonthlyPlanManager:
self.plan_manager = PlanManager()
self.monthly_task_started = False
async def initialize(self):
logger.info("正在初始化月度计划管理器...")
await self.start_monthly_plan_generation()
logger.info("月度计划管理器初始化成功")
async def start_monthly_plan_generation(self):
if not self.monthly_task_started:
logger.info(" 正在启动每月月度计划生成任务...")

View File

@@ -28,20 +28,20 @@ class PlanManager:
if target_month is None:
target_month = datetime.now().strftime("%Y-%m")
if not has_active_plans(target_month):
if not await has_active_plans(target_month):
logger.info(f" {target_month} 没有任何有效的月度计划,将触发同步生成。")
generation_successful = await self._generate_monthly_plans_logic(target_month)
return generation_successful
else:
logger.info(f"{target_month} 已存在有效的月度计划。")
plans = get_active_plans_for_month(target_month)
plans = await get_active_plans_for_month(target_month)
max_plans = global_config.planning_system.max_plans_per_month
if len(plans) > max_plans:
logger.warning(f"当前月度计划数量 ({len(plans)}) 超出上限 ({max_plans}),将自动删除多余的计划。")
plans_to_delete = plans[: len(plans) - max_plans]
delete_ids = [p.id for p in plans_to_delete]
delete_plans_by_ids(delete_ids) # type: ignore
plans = get_active_plans_for_month(target_month)
await delete_plans_by_ids(delete_ids) # type: ignore
plans = await get_active_plans_for_month(target_month)
if plans:
plan_texts = "\n".join([f" {i + 1}. {plan.plan_text}" for i, plan in enumerate(plans)])
@@ -64,11 +64,11 @@ class PlanManager:
return False
last_month = self._get_previous_month(target_month)
archived_plans = get_archived_plans_for_month(last_month)
archived_plans = await get_archived_plans_for_month(last_month)
plans = await self.llm_generator.generate_plans_with_llm(target_month, archived_plans)
if plans:
add_new_plans(plans, target_month)
await add_new_plans(plans, target_month)
logger.info(f"成功为 {target_month} 生成并保存了 {len(plans)} 条月度计划。")
return True
else:
@@ -95,11 +95,11 @@ class PlanManager:
if target_month is None:
target_month = datetime.now().strftime("%Y-%m")
logger.info(f" 开始归档 {target_month} 的活跃月度计划...")
archived_count = archive_active_plans_for_month(target_month)
archived_count = await archive_active_plans_for_month(target_month)
logger.info(f" 成功归档了 {archived_count}{target_month} 的月度计划。")
except Exception as e:
logger.error(f" 归档 {target_month} 月度计划时发生错误: {e}")
def get_plans_for_schedule(self, month: str, max_count: int) -> List:
async def get_plans_for_schedule(self, month: str, max_count: int) -> List:
avoid_days = global_config.planning_system.avoid_repetition_days
return get_smart_plans_for_daily_schedule(month, max_count=max_count, avoid_days=avoid_days)
return await get_smart_plans_for_daily_schedule(month, max_count=max_count, avoid_days=avoid_days)

View File

@@ -23,6 +23,13 @@ class ScheduleManager:
self.daily_task_started = False
self.schedule_generation_running = False
async def initialize(self):
if global_config.planning_system.schedule_enable:
logger.info("日程表功能已启用,正在初始化管理器...")
await self.load_or_generate_today_schedule()
await self.start_daily_schedule_generation()
logger.info("日程表管理器初始化成功。")
async def start_daily_schedule_generation(self):
if not self.daily_task_started:
logger.info("正在启动每日日程生成任务...")
@@ -40,7 +47,7 @@ class ScheduleManager:
today_str = datetime.now().strftime("%Y-%m-%d")
try:
schedule_data = self._load_schedule_from_db(today_str)
schedule_data = await self._load_schedule_from_db(today_str)
if schedule_data:
self.today_schedule = schedule_data
self._log_loaded_schedule(today_str)
@@ -54,9 +61,10 @@ class ScheduleManager:
logger.info("尝试生成日程作为备用方案...")
await self.generate_and_save_schedule()
def _load_schedule_from_db(self, date_str: str) -> Optional[List[Dict[str, Any]]]:
with get_db_session() as session:
schedule_record = session.query(Schedule).filter(Schedule.date == date_str).first()
async def _load_schedule_from_db(self, date_str: str) -> Optional[List[Dict[str, Any]]]:
async with get_db_session() as session:
result = await session.execute(select(Schedule).filter(Schedule.date == date_str))
schedule_record = result.scalars().first()
if schedule_record:
logger.info(f"从数据库加载今天的日程 ({date_str})。")
schedule_data = orjson.loads(str(schedule_record.schedule_data))
@@ -90,35 +98,35 @@ class ScheduleManager:
sampled_plans = []
if global_config.planning_system.monthly_plan_enable:
await self.plan_manager.ensure_and_generate_plans_if_needed(current_month_str)
sampled_plans = self.plan_manager.get_plans_for_schedule(current_month_str, max_count=3)
sampled_plans = await self.plan_manager.get_plans_for_schedule(current_month_str, max_count=3)
schedule_data = await self.llm_generator.generate_schedule_with_llm(sampled_plans)
if schedule_data:
self._save_schedule_to_db(today_str, schedule_data)
await self._save_schedule_to_db(today_str, schedule_data)
self.today_schedule = schedule_data
self._log_generated_schedule(today_str, schedule_data)
if sampled_plans:
used_plan_ids = [plan.id for plan in sampled_plans]
logger.info(f"更新使用过的月度计划 {used_plan_ids} 的统计信息。")
update_plan_usage(used_plan_ids, today_str)
await update_plan_usage(used_plan_ids, today_str)
finally:
self.schedule_generation_running = False
logger.info("日程生成任务结束")
def _save_schedule_to_db(self, date_str: str, schedule_data: List[Dict[str, Any]]):
with get_db_session() as session:
async def _save_schedule_to_db(self, date_str: str, schedule_data: List[Dict[str, Any]]):
async with get_db_session() as session:
schedule_json = orjson.dumps(schedule_data).decode("utf-8")
existing_schedule = session.query(Schedule).filter(Schedule.date == date_str).first()
result = await session.execute(select(Schedule).filter(Schedule.date == date_str))
existing_schedule = result.scalars().first()
if existing_schedule:
session.query(Schedule).filter(Schedule.date == date_str).update(
{Schedule.schedule_data: schedule_json, Schedule.updated_at: datetime.now()}
)
existing_schedule.schedule_data = schedule_json
existing_schedule.updated_at = datetime.now()
else:
new_schedule = Schedule(date=date_str, schedule_data=schedule_json)
session.add(new_schedule)
session.commit()
await session.commit()
def _log_generated_schedule(self, date_str: str, schedule_data: List[Dict[str, Any]]):
schedule_str = f"✅ 成功生成并保存今天的日程 ({date_str})\n"