Third round of changes
@@ -361,7 +361,7 @@ class HeartFChatting:
# 核心修复:在睡眠模式(包括失眠)下获取消息时,不过滤命令消息,以确保@消息能被接收
filter_command_flag = not (is_sleeping or is_in_insomnia)

recent_messages = message_api.get_messages_by_time_in_chat(
recent_messages = await message_api.get_messages_by_time_in_chat(
chat_id=self.context.stream_id,
start_time=self.context.last_read_time,
end_time=time.time(),
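The pattern above repeats throughout the commit: a blocking database or message-API call becomes a coroutine and gains an `await`, and synchronous SQLAlchemy sessions become async sessions. A self-contained sketch of the target shape, assuming SQLAlchemy 2.x with the `aiosqlite` driver; the model, engine URL, and session factory below are illustrative stand-ins, not the project's real ones:

```python
import asyncio
from sqlalchemy import String, select
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class Base(DeclarativeBase):
    pass

class Message(Base):
    __tablename__ = "messages"
    id: Mapped[int] = mapped_column(primary_key=True)
    text: Mapped[str] = mapped_column(String(256))

engine = create_async_engine("sqlite+aiosqlite:///:memory:")
new_session = async_sessionmaker(engine, expire_on_commit=False)

async def demo() -> None:
    # Create the schema for this toy example.
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    async with new_session() as session:      # was: with get_db_session() as session:
        session.add(Message(text="hello"))
        await session.commit()                # was: session.commit()
        rows = (await session.execute(select(Message))).scalars().all()
        print([m.text for m in rows])

asyncio.run(demo())
```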
@@ -149,7 +149,7 @@ class MaiEmoji:
# --- 数据库操作 ---
try:
# 准备数据库记录 for emoji collection
with get_db_session() as session:
async with get_db_session() as session:
emotion_str = ",".join(self.emotion) if self.emotion else ""

emoji = Emoji(
@@ -167,7 +167,7 @@ class MaiEmoji:
last_used_time=self.last_used_time,
)
session.add(emoji)
session.commit()
await session.commit()

logger.info(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})")

@@ -203,17 +203,17 @@ class MaiEmoji:

# 2. 删除数据库记录
try:
with get_db_session() as session:
will_delete_emoji = session.execute(
select(Emoji).where(Emoji.emoji_hash == self.hash)
async with get_db_session() as session:
will_delete_emoji = (
await session.execute(select(Emoji).where(Emoji.emoji_hash == self.hash))
).scalar_one_or_none()
if will_delete_emoji is None:
logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。")
result = 0 # Indicate no DB record was deleted
result = 0
else:
session.delete(will_delete_emoji)
result = 1 # Successfully deleted one record
session.commit()
await session.delete(will_delete_emoji)
result = 1
await session.commit()
except Exception as e:
logger.error(f"[错误] 删除数据库记录时出错: {str(e)}")
result = 0
@@ -424,17 +424,19 @@ class EmojiManager:
# if not self._initialized:
# raise RuntimeError("EmojiManager not initialized")

def record_usage(self, emoji_hash: str) -> None:
async def record_usage(self, emoji_hash: str) -> None:
"""记录表情使用次数"""
try:
with get_db_session() as session:
emoji_update = session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalar_one_or_none()
async with get_db_session() as session:
emoji_update = (
await session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash))
).scalar_one_or_none()
if emoji_update is None:
logger.error(f"记录表情使用失败: 未找到 hash 为 {emoji_hash} 的表情包")
else:
emoji_update.usage_count += 1
emoji_update.last_used_time = time.time() # Update last used time
session.commit()
emoji_update.last_used_time = time.time()
await session.commit()
except Exception as e:
logger.error(f"记录表情使用失败: {str(e)}")
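`record_usage` and the surrounding methods rely on `get_db_session()` being usable as `async with`. The commit does not show that helper, so the following is only a hedged sketch of one plausible implementation built on `async_sessionmaker`; the engine URL and the rollback policy are assumptions, not the project's actual code:

```python
from contextlib import asynccontextmanager
from typing import AsyncIterator
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

_engine = create_async_engine("sqlite+aiosqlite:///bot.db")        # illustrative URL
_session_factory = async_sessionmaker(_engine, expire_on_commit=False)

@asynccontextmanager
async def get_db_session() -> AsyncIterator[AsyncSession]:
    """Async replacement for the old synchronous session context manager."""
    async with _session_factory() as session:
        try:
            yield session
        except Exception:
            await session.rollback()   # roll back on error so callers can safely retry
            raise
```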
@@ -658,10 +660,11 @@ class EmojiManager:
async def get_all_emoji_from_db(self) -> None:
"""获取所有表情包并初始化为MaiEmoji类对象,更新 self.emoji_objects"""
try:
with get_db_session() as session:
async with get_db_session() as session:
logger.debug("[数据库] 开始加载所有表情包记录 ...")

emoji_instances = session.execute(select(Emoji)).scalars().all()
result = await session.execute(select(Emoji))
emoji_instances = result.scalars().all()
emoji_objects, load_errors = _to_emoji_objects(emoji_instances)

# 更新内存中的列表和数量
@@ -687,14 +690,16 @@ class EmojiManager:
list[MaiEmoji]: 表情包对象列表
"""
try:
with get_db_session() as session:
async with get_db_session() as session:
if emoji_hash:
query = session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalars().all()
result = await session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash))
query = result.scalars().all()
else:
logger.warning(
"[查询] 未提供 hash,将尝试加载所有表情包,建议使用 get_all_emoji_from_db 更新管理器状态。"
)
query = session.execute(select(Emoji)).scalars().all()
result = await session.execute(select(Emoji))
query = result.scalars().all()

emoji_instances = query
emoji_objects, load_errors = _to_emoji_objects(emoji_instances)
@@ -771,10 +776,11 @@ class EmojiManager:

# 如果内存中没有,从数据库查找
try:
with get_db_session() as session:
emoji_record = session.execute(
async with get_db_session() as session:
result = await session.execute(
select(Emoji).where(Emoji.emoji_hash == emoji_hash)
).scalar_one_or_none()
)
emoji_record = result.scalar_one_or_none()
if emoji_record and emoji_record.description:
logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.description[:50]}...")
return emoji_record.description
@@ -937,10 +943,13 @@ class EmojiManager:
# 2. 检查数据库中是否已存在该表情包的描述,实现复用
existing_description = None
try:
with get_db_session() as session:
existing_image = session.query(Images).filter(
(Images.emoji_hash == image_hash) & (Images.type == "emoji")
).one_or_none()
async with get_db_session() as session:
result = await session.execute(
select(Images).filter(
(Images.emoji_hash == image_hash) & (Images.type == "emoji")
)
)
existing_image = result.scalar_one_or_none()
if existing_image and existing_image.description:
existing_description = existing_image.description
logger.info(f"[复用描述] 找到已有详细描述: {existing_description[:50]}...")
@@ -136,18 +136,18 @@ class ExpressionSelector:

return related_chat_ids if related_chat_ids else [chat_id]

def get_random_expressions(
self, chat_id: str, total_num: int
) -> List[Dict[str, Any]]:
async def get_random_expressions(
self, chat_id: str, total_num: int, style_percentage: float, grammar_percentage: float
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
# sourcery skip: extract-duplicate-method, move-assign
# 支持多chat_id合并抽选
related_chat_ids = self.get_related_chat_ids(chat_id)
with get_db_session() as session:
async with get_db_session() as session:
# 优化:一次性查询所有相关chat_id的表达方式
style_query = session.execute(
style_query = await session.execute(
select(Expression).where((Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style"))
)
grammar_query = session.execute(
grammar_query = await session.execute(
select(Expression).where((Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "grammar"))
)

@@ -193,7 +193,7 @@ class ExpressionSelector:

return selected_style, selected_grammar

def update_expressions_count_batch(self, expressions_to_update: List[Dict[str, Any]], increment: float = 0.1):
async def update_expressions_count_batch(self, expressions_to_update: List[Dict[str, Any]], increment: float = 0.1):
"""对一批表达方式更新count值,按chat_id+type分组后一次性写入数据库"""
if not expressions_to_update:
return
@@ -210,26 +210,27 @@ class ExpressionSelector:
if key not in updates_by_key:
updates_by_key[key] = expr
for chat_id, expr_type, situation, style in updates_by_key:
with get_db_session() as session:
query = session.execute(
async with get_db_session() as session:
query = await session.execute(
select(Expression).where(
(Expression.chat_id == chat_id)
& (Expression.type == expr_type)
& (Expression.situation == situation)
& (Expression.style == style)
)
).scalar()
if query:
expr_obj = query
current_count = expr_obj.count
new_count = min(current_count + increment, 5.0)
expr_obj.count = new_count
expr_obj.last_active_time = time.time()

logger.debug(
f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in db"
)
session.commit()
query = query.scalar()
if query:
expr_obj = query
current_count = expr_obj.count
new_count = min(current_count + increment, 5.0)
expr_obj.count = new_count
expr_obj.last_active_time = time.time()

logger.debug(
f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in db"
)
await session.commit()

async def select_suitable_expressions_llm(
self,
@@ -246,12 +247,8 @@ class ExpressionSelector:
logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
return [], []

# 1. 获取20个随机表达方式(现在按权重抽取)
style_exprs = self.get_random_expressions(chat_id, 10)

if len(style_exprs) < 10:
logger.info(f"聊天流 {chat_id} 表达方式正在积累中")
return [], []
# 1. 获取35个随机表达方式(现在按权重抽取)
style_exprs, grammar_exprs = await self.get_random_expressions(chat_id, 30, 0.5, 0.5)

# 2. 构建所有表达方式的索引和情境列表
all_expressions: List[Dict[str, Any]] = []
@@ -326,7 +323,7 @@ class ExpressionSelector:

# 对选中的所有表达方式,一次性更新count数
if valid_expressions:
self.update_expressions_count_batch(valid_expressions, 0.006)
await self.update_expressions_count_batch(valid_expressions, 0.006)

# logger.info(f"LLM从{len(all_expressions)}个情境中选择了{len(valid_expressions)}个")
return valid_expressions, selected_ids
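The new `get_random_expressions` signature returns separate style and grammar lists and takes `style_percentage` / `grammar_percentage`, which suggests the `total_num` budget is split between the two pools before sampling. A minimal, runnable sketch of that splitting idea under that assumption; the real method additionally weights by `count` and merges related chat ids:

```python
import random
from typing import Any, Dict, List, Tuple

def split_sample(
    style_pool: List[Dict[str, Any]],
    grammar_pool: List[Dict[str, Any]],
    total_num: int,
    style_percentage: float,
    grammar_percentage: float,
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """Split the total budget across the two pools, then sample each (illustrative only)."""
    style_n = min(len(style_pool), int(total_num * style_percentage))
    grammar_n = min(len(grammar_pool), int(total_num * grammar_percentage))
    return random.sample(style_pool, style_n), random.sample(grammar_pool, grammar_n)

styles = [{"situation": f"s{i}", "style": f"style{i}"} for i in range(40)]
grammars = [{"situation": f"g{i}", "style": f"grammar{i}"} for i in range(40)]
picked_styles, picked_grammars = split_sample(styles, grammars, 30, 0.5, 0.5)
print(len(picked_styles), len(picked_grammars))   # 15 15
```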
@@ -267,8 +267,8 @@ class Hippocampus:
self.entorhinal_cortex = EntorhinalCortex(self)
self.parahippocampal_gyrus = ParahippocampalGyrus(self)
# 从数据库加载记忆图
self.entorhinal_cortex.sync_memory_from_db()
self.model_small = LLMRequest(model_set=model_config.model_task_config.utils_small, request_type="memory.modify")
# self.entorhinal_cortex.sync_memory_from_db() # 改为异步启动
self.model_small = LLMRequest(model_set=model_config.model_task_config.utils_small, request_type="memory.small")

def get_all_node_names(self) -> list:
"""获取记忆图中所有节点的名字列表"""
@@ -877,7 +877,7 @@ class EntorhinalCortex:
self.hippocampus = hippocampus
self.memory_graph = hippocampus.memory_graph

def get_memory_sample(self):
async def get_memory_sample(self):
"""从数据库获取记忆样本"""
# 硬编码:每条消息最大记忆次数
max_memorized_time_per_msg = 2
@@ -900,7 +900,7 @@ class EntorhinalCortex:
logger.debug(f"回忆往事: {readable_timestamp}")
chat_samples = []
for timestamp in timestamps:
if messages := self.random_get_msg_snippet(
if messages := await self.random_get_msg_snippet(
timestamp,
global_config.memory.memory_build_sample_length,
max_memorized_time_per_msg,
@@ -914,7 +914,9 @@ class EntorhinalCortex:
return chat_samples

@staticmethod
def random_get_msg_snippet(target_timestamp: float, chat_size: int, max_memorized_time_per_msg: int) -> list | None:
async def random_get_msg_snippet(
target_timestamp: float, chat_size: int, max_memorized_time_per_msg: int
) -> list | None:
# sourcery skip: invert-any-all, use-any, use-named-expression, use-next
"""从数据库中随机获取指定时间戳附近的消息片段 (使用 chat_message_builder)"""
time_window_seconds = random.randint(300, 1800) # 随机时间窗口,5到30分钟
@@ -952,13 +954,13 @@ class EntorhinalCortex:
|
||||
for message in messages:
|
||||
# 确保在更新前获取最新的 memorized_times
|
||||
current_memorized_times = message.get("memorized_times", 0)
|
||||
with get_db_session() as session:
|
||||
session.execute(
|
||||
async with get_db_session() as session:
|
||||
await session.execute(
|
||||
update(Messages)
|
||||
.where(Messages.message_id == message["message_id"])
|
||||
.values(memorized_times=current_memorized_times + 1)
|
||||
)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
return messages # 直接返回原始的消息列表
|
||||
|
||||
target_timestamp -= 120 # 如果第一次尝试失败,稍微向前调整时间戳再试
|
||||
@@ -972,8 +974,8 @@ class EntorhinalCortex:
|
||||
current_time = datetime.datetime.now().timestamp()
|
||||
|
||||
# 获取数据库中所有节点和内存中所有节点
|
||||
with get_db_session() as session:
|
||||
db_nodes = {node.concept: node for node in session.execute(select(GraphNodes)).scalars()}
|
||||
async with get_db_session() as session:
|
||||
db_nodes = {node.concept: node for node in (await session.execute(select(GraphNodes))).scalars()}
|
||||
memory_nodes = list(self.memory_graph.G.nodes(data=True))
|
||||
|
||||
# 批量准备节点数据
|
||||
@@ -1043,24 +1045,24 @@ class EntorhinalCortex:
|
||||
batch_size = 100
|
||||
for i in range(0, len(nodes_to_create), batch_size):
|
||||
batch = nodes_to_create[i : i + batch_size]
|
||||
session.execute(insert(GraphNodes), batch)
|
||||
await session.execute(insert(GraphNodes), batch)
|
||||
|
||||
if nodes_to_update:
|
||||
batch_size = 100
|
||||
for i in range(0, len(nodes_to_update), batch_size):
|
||||
batch = nodes_to_update[i : i + batch_size]
|
||||
for node_data in batch:
|
||||
session.execute(
|
||||
await session.execute(
|
||||
update(GraphNodes)
|
||||
.where(GraphNodes.concept == node_data["concept"])
|
||||
.values(**{k: v for k, v in node_data.items() if k != "concept"})
|
||||
)
|
||||
|
||||
if nodes_to_delete:
|
||||
session.execute(delete(GraphNodes).where(GraphNodes.concept.in_(nodes_to_delete)))
|
||||
await session.execute(delete(GraphNodes).where(GraphNodes.concept.in_(nodes_to_delete)))
|
||||
|
||||
# 处理边的信息
|
||||
db_edges = list(session.execute(select(GraphEdges)).scalars())
|
||||
db_edges = list((await session.execute(select(GraphEdges))).scalars())
|
||||
memory_edges = list(self.memory_graph.G.edges(data=True))
|
||||
|
||||
# 创建边的哈希值字典
|
||||
@@ -1112,14 +1114,14 @@ class EntorhinalCortex:
|
||||
batch_size = 100
|
||||
for i in range(0, len(edges_to_create), batch_size):
|
||||
batch = edges_to_create[i : i + batch_size]
|
||||
session.execute(insert(GraphEdges), batch)
|
||||
await session.execute(insert(GraphEdges), batch)
|
||||
|
||||
if edges_to_update:
|
||||
batch_size = 100
|
||||
for i in range(0, len(edges_to_update), batch_size):
|
||||
batch = edges_to_update[i : i + batch_size]
|
||||
for edge_data in batch:
|
||||
session.execute(
|
||||
await session.execute(
|
||||
update(GraphEdges)
|
||||
.where(
|
||||
(GraphEdges.source == edge_data["source"]) & (GraphEdges.target == edge_data["target"])
|
||||
@@ -1129,12 +1131,12 @@ class EntorhinalCortex:
|
||||
|
||||
if edges_to_delete:
|
||||
for source, target in edges_to_delete:
|
||||
session.execute(
|
||||
await session.execute(
|
||||
delete(GraphEdges).where((GraphEdges.source == source) & (GraphEdges.target == target))
|
||||
)
|
||||
|
||||
# 提交事务
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
end_time = time.time()
|
||||
logger.info(f"[同步] 总耗时: {end_time - start_time:.2f}秒")
|
||||
@@ -1146,10 +1148,10 @@ class EntorhinalCortex:
|
||||
logger.info("[数据库] 开始重新同步所有记忆数据...")
|
||||
|
||||
# 清空数据库
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
clear_start = time.time()
|
||||
session.execute(delete(GraphNodes))
|
||||
session.execute(delete(GraphEdges))
|
||||
await session.execute(delete(GraphNodes))
|
||||
await session.execute(delete(GraphEdges))
|
||||
|
||||
clear_end = time.time()
|
||||
logger.info(f"[数据库] 清空数据库耗时: {clear_end - clear_start:.2f}秒")
|
||||
@@ -1208,7 +1210,7 @@ class EntorhinalCortex:
|
||||
batch_size = 500 # 增加批量大小
|
||||
for i in range(0, len(nodes_data), batch_size):
|
||||
batch = nodes_data[i : i + batch_size]
|
||||
session.execute(insert(GraphNodes), batch)
|
||||
await session.execute(insert(GraphNodes), batch)
|
||||
|
||||
node_end = time.time()
|
||||
logger.info(f"[数据库] 写入 {len(nodes_data)} 个节点耗时: {node_end - node_start:.2f}秒")
|
||||
@@ -1219,8 +1221,8 @@ class EntorhinalCortex:
|
||||
batch_size = 500 # 增加批量大小
|
||||
for i in range(0, len(edges_data), batch_size):
|
||||
batch = edges_data[i : i + batch_size]
|
||||
session.execute(insert(GraphEdges), batch)
|
||||
session.commit()
|
||||
await session.execute(insert(GraphEdges), batch)
|
||||
await session.commit()
|
||||
|
||||
edge_end = time.time()
|
||||
logger.info(f"[数据库] 写入 {len(edges_data)} 条边耗时: {edge_end - edge_start:.2f}秒")
|
||||
@@ -1229,7 +1231,7 @@ class EntorhinalCortex:
|
||||
logger.info(f"[数据库] 重新同步完成,总耗时: {end_time - start_time:.2f}秒")
|
||||
logger.info(f"[数据库] 同步了 {len(nodes_data)} 个节点和 {len(edges_data)} 条边")
|
||||
|
||||
def sync_memory_from_db(self):
|
||||
async def sync_memory_from_db(self):
|
||||
"""从数据库同步数据到内存中的图结构"""
|
||||
current_time = datetime.datetime.now().timestamp()
|
||||
need_update = False
|
||||
@@ -1243,8 +1245,8 @@ class EntorhinalCortex:
|
||||
skipped_nodes = 0
|
||||
|
||||
# 从数据库加载所有节点
|
||||
with get_db_session() as session:
|
||||
nodes = list(session.execute(select(GraphNodes)).scalars())
|
||||
async with get_db_session() as session:
|
||||
nodes = list((await session.execute(select(GraphNodes))).scalars())
|
||||
for node in nodes:
|
||||
concept = node.concept
|
||||
try:
|
||||
@@ -1262,7 +1264,9 @@ class EntorhinalCortex:
|
||||
if not node.last_modified:
|
||||
update_data["last_modified"] = current_time
|
||||
|
||||
session.execute(update(GraphNodes).where(GraphNodes.concept == concept).values(**update_data))
|
||||
await session.execute(
|
||||
update(GraphNodes).where(GraphNodes.concept == concept).values(**update_data)
|
||||
)
|
||||
|
||||
# 获取时间信息(如果不存在则使用当前时间)
|
||||
created_time = node.created_time or current_time
|
||||
@@ -1277,7 +1281,7 @@ class EntorhinalCortex:
|
||||
continue
|
||||
|
||||
# 从数据库加载所有边
|
||||
edges = list(session.execute(select(GraphEdges)).scalars())
|
||||
edges = list((await session.execute(select(GraphEdges))).scalars())
|
||||
for edge in edges:
|
||||
source = edge.source
|
||||
target = edge.target
|
||||
@@ -1293,7 +1297,7 @@ class EntorhinalCortex:
|
||||
if not edge.last_modified:
|
||||
update_data["last_modified"] = current_time
|
||||
|
||||
session.execute(
|
||||
await session.execute(
|
||||
update(GraphEdges)
|
||||
.where((GraphEdges.source == source) & (GraphEdges.target == target))
|
||||
.values(**update_data)
|
||||
@@ -1308,7 +1312,7 @@ class EntorhinalCortex:
|
||||
self.memory_graph.G.add_edge(
|
||||
source, target, strength=strength, created_time=created_time, last_modified=last_modified
|
||||
)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
if need_update:
|
||||
logger.info("[数据库] 已为缺失的时间字段进行补充")
|
||||
@@ -1348,7 +1352,7 @@ class ParahippocampalGyrus:
|
||||
|
||||
# 1. 使用 build_readable_messages 生成格式化文本
|
||||
# build_readable_messages 只返回一个字符串,不需要解包
|
||||
input_text = build_readable_messages(
|
||||
input_text = await build_readable_messages(
|
||||
messages,
|
||||
merge_messages=True, # 合并连续消息
|
||||
timestamp_mode="normal_no_YMD", # 使用 'YYYY-MM-DD HH:MM:SS' 格式
|
||||
@@ -1436,7 +1440,7 @@ class ParahippocampalGyrus:
|
||||
# sourcery skip: merge-list-appends-into-extend
|
||||
logger.info("------------------------------------开始构建记忆--------------------------------------")
|
||||
start_time = time.time()
|
||||
memory_samples = self.hippocampus.entorhinal_cortex.get_memory_sample()
|
||||
memory_samples = await self.hippocampus.entorhinal_cortex.get_memory_sample()
|
||||
all_added_nodes = []
|
||||
all_connected_nodes = []
|
||||
all_added_edges = []
|
||||
@@ -1697,7 +1701,7 @@ class HippocampusManager:
return self._hippocampus

self._hippocampus = Hippocampus()
self._hippocampus.initialize()
# self._hippocampus.initialize() # 改为异步启动
self._initialized = True

# 输出记忆图统计信息
@@ -1716,6 +1720,13 @@ class HippocampusManager:

return self._hippocampus

async def initialize_async(self):
"""异步初始化海马体实例"""
if not self._initialized:
self.initialize() # 先进行同步部分的初始化
self._hippocampus.initialize()
await self._hippocampus.entorhinal_cortex.sync_memory_from_db()

def get_hippocampus(self):
if not self._initialized:
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
@@ -246,11 +246,11 @@ class ChatManager:
return stream

# 检查数据库中是否存在
def _db_find_stream_sync(s_id: str):
with get_db_session() as session:
return session.execute(select(ChatStreams).where(ChatStreams.stream_id == s_id)).scalar()
async def _db_find_stream_async(s_id: str):
async with get_db_session() as session:
return (await session.execute(select(ChatStreams).where(ChatStreams.stream_id == s_id))).scalar()

model_instance = await asyncio.to_thread(_db_find_stream_sync, stream_id)
model_instance = await _db_find_stream_async(stream_id)

if model_instance:
# 从 Peewee 模型转换回 ChatStream.from_dict 期望的格式
@@ -344,11 +344,10 @@ class ChatManager:
return
stream_data_dict = stream.to_dict()

def _db_save_stream_sync(s_data_dict: dict):
with get_db_session() as session:
async def _db_save_stream_async(s_data_dict: dict):
async with get_db_session() as session:
user_info_d = s_data_dict.get("user_info")
group_info_d = s_data_dict.get("group_info")

fields_to_save = {
"platform": s_data_dict["platform"],
"create_time": s_data_dict["create_time"],
@@ -364,8 +363,6 @@ class ChatManager:
"sleep_pressure": s_data_dict.get("sleep_pressure", 0.0),
"focus_energy": s_data_dict.get("focus_energy", global_config.chat.focus_value),
}

# 根据数据库类型选择插入语句
if global_config.database.database_type == "sqlite":
stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=fields_to_save)
@@ -375,15 +372,13 @@ class ChatManager:
**{key: value for key, value in fields_to_save.items() if key != "stream_id"}
)
else:
# 默认使用通用插入,尝试SQLite语法
stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=fields_to_save)

session.execute(stmt)
session.commit()
await session.execute(stmt)
await session.commit()

try:
await asyncio.to_thread(_db_save_stream_sync, stream_data_dict)
await _db_save_stream_async(stream_data_dict)
stream.saved = True
except Exception as e:
logger.error(f"保存聊天流 {stream.stream_id} 到数据库失败 (Peewee): {e}", exc_info=True)
@@ -397,10 +392,10 @@ class ChatManager:
"""从数据库加载所有聊天流"""
logger.info("正在从数据库加载所有聊天流")

def _db_load_all_streams_sync():
async def _db_load_all_streams_async():
loaded_streams_data = []
with get_db_session() as session:
for model_instance in session.execute(select(ChatStreams)).scalars():
async with get_db_session() as session:
for model_instance in (await session.execute(select(ChatStreams))).scalars():
user_info_data = {
"platform": model_instance.user_platform,
"user_id": model_instance.user_id,
@@ -414,7 +409,6 @@ class ChatManager:
"group_id": model_instance.group_id,
"group_name": model_instance.group_name,
}

data_for_from_dict = {
"stream_id": model_instance.stream_id,
"platform": model_instance.platform,
@@ -427,11 +421,11 @@ class ChatManager:
"focus_energy": getattr(model_instance, "focus_energy", global_config.chat.focus_value),
}
loaded_streams_data.append(data_for_from_dict)
session.commit()
await session.commit()
return loaded_streams_data

try:
all_streams_data_list = await asyncio.to_thread(_db_load_all_streams_sync)
all_streams_data_list = await _db_load_all_streams_async()
self.streams.clear()
for data in all_streams_data_list:
stream = ChatStream.from_dict(data)
@@ -42,7 +42,7 @@ class MessageStorage:
processed_plain_text = message.processed_plain_text

if processed_plain_text:
processed_plain_text = MessageStorage.replace_image_descriptions(processed_plain_text)
processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text)
filtered_processed_plain_text = re.sub(pattern, "", processed_plain_text, flags=re.DOTALL)
else:
filtered_processed_plain_text = ""
@@ -129,9 +129,9 @@ class MessageStorage:
key_words=key_words,
key_words_lite=key_words_lite,
)
with get_db_session() as session:
async with get_db_session() as session:
session.add(new_message)
session.commit()
await session.commit()

except Exception:
logger.exception("存储消息失败")
@@ -174,16 +174,18 @@ class MessageStorage:
# 使用上下文管理器确保session正确管理
from src.common.database.sqlalchemy_models import get_db_session

with get_db_session() as session:
matched_message = session.execute(
select(Messages).where(Messages.message_id == mmc_message_id).order_by(desc(Messages.time))
async with get_db_session() as session:
matched_message = (
await session.execute(
select(Messages).where(Messages.message_id == mmc_message_id).order_by(desc(Messages.time))
)
).scalar()

if matched_message:
session.execute(
await session.execute(
update(Messages).where(Messages.id == matched_message.id).values(message_id=qq_message_id)
)
session.commit()
await session.commit()
# 会在上下文管理器中自动调用
logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}")
else:
@@ -197,28 +199,36 @@ class MessageStorage:
)

@staticmethod
def replace_image_descriptions(text: str) -> str:
async def replace_image_descriptions(text: str) -> str:
"""将[图片:描述]替换为[picid:image_id]"""
# 先检查文本中是否有图片标记
pattern = r"\[图片:([^\]]+)\]"
matches = re.findall(pattern, text)
matches = list(re.finditer(pattern, text))

if not matches:
logger.debug("文本中没有图片标记,直接返回原文本")
return text

def replace_match(match):
new_text = ""
last_end = 0
for match in matches:
new_text += text[last_end : match.start()]
description = match.group(1).strip()
try:
from src.common.database.sqlalchemy_models import get_db_session

with get_db_session() as session:
image_record = session.execute(
select(Images).where(Images.description == description).order_by(desc(Images.timestamp))
async with get_db_session() as session:
image_record = (
await session.execute(
select(Images).where(Images.description == description).order_by(desc(Images.timestamp))
)
).scalar()
session.commit()
return f"[picid:{image_record.image_id}]" if image_record else match.group(0)
if image_record:
new_text += f"[picid:{image_record.image_id}]"
else:
new_text += match.group(0)
except Exception:
return match.group(0)

return re.sub(r"\[图片:([^\]]+)\]", replace_match, text)
new_text += match.group(0)
last_end = match.end()
new_text += text[last_end:]
return new_text
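The `replace_image_descriptions` rewrite swaps `re.sub` with a replacement callback for an explicit `re.finditer` loop because the per-description lookup is now awaited, and `re.sub` cannot call an async replacement function. A self-contained sketch of the same pattern, with the awaited `Images` query replaced by an in-memory stand-in:

```python
import asyncio
import re

async def lookup_pic_id(description: str) -> str | None:
    # Stand-in for the awaited database query used in the real code.
    fake_db = {"一只猫": "cat001"}
    await asyncio.sleep(0)                      # simulate async I/O
    return fake_db.get(description)

async def replace_image_descriptions(text: str) -> str:
    """Rebuild the string around each [图片:...] marker; re.sub callbacks cannot await."""
    pattern = r"\[图片:([^\]]+)\]"
    new_text, last_end = "", 0
    for match in re.finditer(pattern, text):
        new_text += text[last_end:match.start()]
        pic_id = await lookup_pic_id(match.group(1).strip())
        new_text += f"[picid:{pic_id}]" if pic_id else match.group(0)
        last_end = match.end()
    return new_text + text[last_end:]

print(asyncio.run(replace_image_descriptions("看[图片:一只猫]和[图片:未知]")))
```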
@@ -97,12 +97,12 @@ class ActionModifier:
|
||||
for action_name, reason in chat_type_removals:
|
||||
logger.debug(f"{self.log_prefix} - 移除 {action_name}: {reason}")
|
||||
|
||||
message_list_before_now_half = get_raw_msg_before_timestamp_with_chat(
|
||||
message_list_before_now_half = await get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=self.chat_stream.stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=min(int(global_config.chat.max_context_size * 0.33), 10),
|
||||
)
|
||||
chat_content = build_readable_messages(
|
||||
chat_content = await build_readable_messages(
|
||||
message_list_before_now_half,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
|
||||
@@ -152,7 +152,7 @@ class PlanFilter:
|
||||
)
|
||||
return prompt, message_id_list
|
||||
|
||||
chat_content_block, message_id_list = build_readable_messages_with_id(
|
||||
chat_content_block, message_id_list = await build_readable_messages_with_id(
|
||||
messages=[msg.flatten() for msg in plan.chat_history],
|
||||
timestamp_mode="normal",
|
||||
read_mark=self.last_obs_time_mark,
|
||||
@@ -167,7 +167,7 @@ class PlanFilter:
|
||||
limit=5,
|
||||
)
|
||||
|
||||
actions_before_now_block = build_readable_actions(actions=actions_before_now)
|
||||
actions_before_now_block = build_readable_actions(actions=await actions_before_now)
|
||||
actions_before_now_block = f"你刚刚选择并执行过的action是:\n{actions_before_now_block}"
|
||||
|
||||
self.last_obs_time_mark = time.time()
|
||||
|
||||
@@ -63,7 +63,7 @@ class PlanGenerator:
|
||||
timestamp=time.time(),
|
||||
limit=int(global_config.chat.max_context_size),
|
||||
)
|
||||
chat_history = [DatabaseMessages(**msg) for msg in chat_history_raw]
|
||||
chat_history = [DatabaseMessages(**msg) for msg in await chat_history_raw]
|
||||
|
||||
|
||||
plan = Plan(
|
||||
|
||||
@@ -873,7 +873,8 @@ class DefaultReplyer:
|
||||
platform, # type: ignore
|
||||
reply_message.get("user_id"), # type: ignore
|
||||
)
|
||||
person_name = await person_info_manager.get_value(person_id, "person_name")
|
||||
person_info = await person_info_manager.get_values(person_id, ["person_name", "user_id"])
|
||||
person_name = person_info.get("person_name")
|
||||
|
||||
# 如果person_name为None,使用fallback值
|
||||
if person_name is None:
|
||||
@@ -884,7 +885,7 @@ class DefaultReplyer:
|
||||
|
||||
# 检查是否是bot自己的名字,如果是则替换为"(你)"
|
||||
bot_user_id = str(global_config.bot.qq_account)
|
||||
current_user_id = person_info_manager.get_value_sync(person_id, "user_id")
|
||||
current_user_id = person_info.get("user_id")
|
||||
current_platform = reply_message.get("chat_info_platform")
|
||||
|
||||
if current_user_id == bot_user_id and current_platform == global_config.bot.platform:
|
||||
@@ -909,18 +910,18 @@ class DefaultReplyer:
|
||||
target = replace_user_references_sync(target, chat_stream.platform, replace_bot_name=True)
|
||||
|
||||
|
||||
message_list_before_now_long = get_raw_msg_before_timestamp_with_chat(
|
||||
message_list_before_now_long = await get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_id,
|
||||
timestamp=time.time(),
|
||||
limit=global_config.chat.max_context_size * 1,
|
||||
)
|
||||
|
||||
message_list_before_short = get_raw_msg_before_timestamp_with_chat(
|
||||
message_list_before_short = await get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_id,
|
||||
timestamp=time.time(),
|
||||
limit=int(global_config.chat.max_context_size * 0.33),
|
||||
)
|
||||
chat_talking_prompt_short = build_readable_messages(
|
||||
chat_talking_prompt_short = await build_readable_messages(
|
||||
message_list_before_short,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
@@ -932,7 +933,7 @@ class DefaultReplyer:
|
||||
# 获取目标用户信息,用于s4u模式
|
||||
target_user_info = None
|
||||
if sender:
|
||||
target_user_info = await person_info_manager.get_person_info_by_name(sender)
|
||||
target_user_info = person_info_manager.get_person_info_by_name(sender)
|
||||
|
||||
from src.chat.utils.prompt import Prompt
|
||||
# 并行执行六个构建任务
|
||||
@@ -1150,12 +1151,12 @@ class DefaultReplyer:
|
||||
else:
|
||||
mood_prompt = ""
|
||||
|
||||
message_list_before_now_half = get_raw_msg_before_timestamp_with_chat(
|
||||
message_list_before_now_half = await get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_id,
|
||||
timestamp=time.time(),
|
||||
limit=min(int(global_config.chat.max_context_size * 0.33), 15),
|
||||
)
|
||||
chat_talking_prompt_half = build_readable_messages(
|
||||
chat_talking_prompt_half = await build_readable_messages(
|
||||
message_list_before_now_half,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
|
||||
@@ -116,8 +116,9 @@ async def replace_user_references_async(
|
||||
# 检查是否是机器人自己
|
||||
if replace_bot_name and user_id == global_config.bot.qq_account:
|
||||
return f"{global_config.bot.nickname}(你)"
|
||||
person = Person(platform=platform, user_id=user_id)
|
||||
return person.person_name or user_id # type: ignore
|
||||
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
||||
person_info = await person_info_manager.get_values(person_id, ["person_name"])
|
||||
return person_info.get("person_name") or user_id
|
||||
|
||||
name_resolver = default_resolver
|
||||
|
||||
@@ -165,7 +166,7 @@ async def replace_user_references_async(
|
||||
return content
|
||||
|
||||
|
||||
def get_raw_msg_by_timestamp(
|
||||
async def get_raw_msg_by_timestamp(
|
||||
timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
@@ -176,10 +177,10 @@ def get_raw_msg_by_timestamp(
|
||||
filter_query = {"time": {"$gt": timestamp_start, "$lt": timestamp_end}}
|
||||
# 只有当 limit 为 0 时才应用外部 sort
|
||||
sort_order = [("time", 1)] if limit == 0 else None
|
||||
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
|
||||
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
|
||||
|
||||
|
||||
def get_raw_msg_by_timestamp_with_chat(
|
||||
async def get_raw_msg_by_timestamp_with_chat(
|
||||
chat_id: str,
|
||||
timestamp_start: float,
|
||||
timestamp_end: float,
|
||||
@@ -196,7 +197,7 @@ def get_raw_msg_by_timestamp_with_chat(
|
||||
# 只有当 limit 为 0 时才应用外部 sort
|
||||
sort_order = [("time", 1)] if limit == 0 else None
|
||||
# 直接将 limit_mode 传递给 find_messages
|
||||
return find_messages(
|
||||
return await find_messages(
|
||||
message_filter=filter_query,
|
||||
sort=sort_order,
|
||||
limit=limit,
|
||||
@@ -206,7 +207,7 @@ def get_raw_msg_by_timestamp_with_chat(
|
||||
)
|
||||
|
||||
|
||||
def get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
async def get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id: str,
|
||||
timestamp_start: float,
|
||||
timestamp_end: float,
|
||||
@@ -223,12 +224,12 @@ def get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
sort_order = [("time", 1)] if limit == 0 else None
|
||||
# 直接将 limit_mode 传递给 find_messages
|
||||
|
||||
return find_messages(
|
||||
return await find_messages(
|
||||
message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode, filter_bot=filter_bot
|
||||
)
|
||||
|
||||
|
||||
def get_raw_msg_by_timestamp_with_chat_users(
|
||||
async def get_raw_msg_by_timestamp_with_chat_users(
|
||||
chat_id: str,
|
||||
timestamp_start: float,
|
||||
timestamp_end: float,
|
||||
@@ -247,10 +248,10 @@ def get_raw_msg_by_timestamp_with_chat_users(
|
||||
}
|
||||
# 只有当 limit 为 0 时才应用外部 sort
|
||||
sort_order = [("time", 1)] if limit == 0 else None
|
||||
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
|
||||
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
|
||||
|
||||
|
||||
def get_actions_by_timestamp_with_chat(
|
||||
async def get_actions_by_timestamp_with_chat(
|
||||
chat_id: str,
|
||||
timestamp_start: float = 0,
|
||||
timestamp_end: float = time.time(),
|
||||
@@ -269,10 +270,10 @@ def get_actions_by_timestamp_with_chat(
|
||||
f"limit={limit}, limit_mode={limit_mode}"
|
||||
)
|
||||
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
if limit > 0:
|
||||
if limit_mode == "latest":
|
||||
query = session.execute(
|
||||
query = await session.execute(
|
||||
select(ActionRecords)
|
||||
.where(
|
||||
and_(
|
||||
@@ -302,7 +303,7 @@ def get_actions_by_timestamp_with_chat(
|
||||
}
|
||||
actions_result.append(action_dict)
|
||||
else: # earliest
|
||||
query = session.execute(
|
||||
query = await session.execute(
|
||||
select(ActionRecords)
|
||||
.where(
|
||||
and_(
|
||||
@@ -332,7 +333,7 @@ def get_actions_by_timestamp_with_chat(
|
||||
}
|
||||
actions_result.append(action_dict)
|
||||
else:
|
||||
query = session.execute(
|
||||
query = await session.execute(
|
||||
select(ActionRecords)
|
||||
.where(
|
||||
and_(
|
||||
@@ -363,14 +364,14 @@ def get_actions_by_timestamp_with_chat(
|
||||
return actions_result
|
||||
|
||||
|
||||
def get_actions_by_timestamp_with_chat_inclusive(
|
||||
async def get_actions_by_timestamp_with_chat_inclusive(
|
||||
chat_id: str, timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录(包含边界),按时间升序排序,返回动作记录列表"""
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
if limit > 0:
|
||||
if limit_mode == "latest":
|
||||
query = session.execute(
|
||||
query = await session.execute(
|
||||
select(ActionRecords)
|
||||
.where(
|
||||
and_(
|
||||
@@ -385,7 +386,7 @@ def get_actions_by_timestamp_with_chat_inclusive(
|
||||
actions = list(query.scalars())
|
||||
return [action.__dict__ for action in reversed(actions)]
|
||||
else: # earliest
|
||||
query = session.execute(
|
||||
query = await session.execute(
|
||||
select(ActionRecords)
|
||||
.where(
|
||||
and_(
|
||||
@@ -398,7 +399,7 @@ def get_actions_by_timestamp_with_chat_inclusive(
|
||||
.limit(limit)
|
||||
)
|
||||
else:
|
||||
query = session.execute(
|
||||
query = await session.execute(
|
||||
select(ActionRecords)
|
||||
.where(
|
||||
and_(
|
||||
@@ -414,14 +415,14 @@ def get_actions_by_timestamp_with_chat_inclusive(
|
||||
return [action.__dict__ for action in actions]
|
||||
|
||||
|
||||
def get_raw_msg_by_timestamp_random(
|
||||
async def get_raw_msg_by_timestamp_random(
|
||||
timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
先在范围时间戳内随机选择一条消息,取得消息的chat_id,然后根据chat_id获取该聊天在指定时间戳范围内的消息
|
||||
"""
|
||||
# 获取所有消息,只取chat_id字段
|
||||
all_msgs = get_raw_msg_by_timestamp(timestamp_start, timestamp_end)
|
||||
all_msgs = await get_raw_msg_by_timestamp(timestamp_start, timestamp_end)
|
||||
if not all_msgs:
|
||||
return []
|
||||
# 随机选一条
|
||||
@@ -429,10 +430,10 @@ def get_raw_msg_by_timestamp_random(
|
||||
chat_id = msg["chat_id"]
|
||||
timestamp_start = msg["time"]
|
||||
# 用 chat_id 获取该聊天在指定时间戳范围内的消息
|
||||
return get_raw_msg_by_timestamp_with_chat(chat_id, timestamp_start, timestamp_end, limit, "earliest")
|
||||
return await get_raw_msg_by_timestamp_with_chat(chat_id, timestamp_start, timestamp_end, limit, "earliest")
|
||||
|
||||
|
||||
def get_raw_msg_by_timestamp_with_users(
|
||||
async def get_raw_msg_by_timestamp_with_users(
|
||||
timestamp_start: float, timestamp_end: float, person_ids: list, limit: int = 0, limit_mode: str = "latest"
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取某些特定用户在 *所有聊天* 中从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
|
||||
@@ -442,37 +443,39 @@ def get_raw_msg_by_timestamp_with_users(
|
||||
filter_query = {"time": {"$gt": timestamp_start, "$lt": timestamp_end}, "user_id": {"$in": person_ids}}
|
||||
# 只有当 limit 为 0 时才应用外部 sort
|
||||
sort_order = [("time", 1)] if limit == 0 else None
|
||||
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
|
||||
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
|
||||
|
||||
|
||||
def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List[Dict[str, Any]]:
|
||||
async def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List[Dict[str, Any]]:
|
||||
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
"""
|
||||
filter_query = {"time": {"$lt": timestamp}}
|
||||
sort_order = [("time", 1)]
|
||||
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
|
||||
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
|
||||
|
||||
|
||||
def get_raw_msg_before_timestamp_with_chat(chat_id: str, timestamp: float, limit: int = 0) -> List[Dict[str, Any]]:
|
||||
async def get_raw_msg_before_timestamp_with_chat(chat_id: str, timestamp: float, limit: int = 0) -> List[Dict[str, Any]]:
|
||||
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
"""
|
||||
filter_query = {"chat_id": chat_id, "time": {"$lt": timestamp}}
|
||||
sort_order = [("time", 1)]
|
||||
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
|
||||
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
|
||||
|
||||
|
||||
def get_raw_msg_before_timestamp_with_users(timestamp: float, person_ids: list, limit: int = 0) -> List[Dict[str, Any]]:
|
||||
async def get_raw_msg_before_timestamp_with_users(
|
||||
timestamp: float, person_ids: list, limit: int = 0
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
"""
|
||||
filter_query = {"time": {"$lt": timestamp}, "user_id": {"$in": person_ids}}
|
||||
sort_order = [("time", 1)]
|
||||
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
|
||||
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
|
||||
|
||||
|
||||
def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: Optional[float] = None) -> int:
|
||||
async def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: Optional[float] = None) -> int:
|
||||
"""
|
||||
检查特定聊天从 timestamp_start (不含) 到 timestamp_end (不含) 之间有多少新消息。
|
||||
如果 timestamp_end 为 None,则检查从 timestamp_start (不含) 到当前时间的消息。
|
||||
@@ -486,10 +489,10 @@ def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp
|
||||
return 0 # 起始时间大于等于结束时间,没有新消息
|
||||
|
||||
filter_query = {"chat_id": chat_id, "time": {"$gt": timestamp_start, "$lt": _timestamp_end}}
|
||||
return count_messages(message_filter=filter_query)
|
||||
return await count_messages(message_filter=filter_query)
|
||||
|
||||
|
||||
def num_new_messages_since_with_users(
async def num_new_messages_since_with_users(
chat_id: str, timestamp_start: float, timestamp_end: float, person_ids: list
) -> int:
"""检查某些特定用户在特定聊天在指定时间戳之间有多少新消息"""
@@ -500,10 +503,10 @@ def num_new_messages_since_with_users(
"time": {"$gt": timestamp_start, "$lt": timestamp_end},
"user_id": {"$in": person_ids},
}
return count_messages(message_filter=filter_query)
return await count_messages(message_filter=filter_query)
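Because `find_messages` and `count_messages` are now coroutines, every thin query helper above them becomes `async def` and every call site adds `await`; the change propagates all the way up to the planners and repliers touched elsewhere in this commit. A minimal sketch of that propagation, with stand-in implementations rather than the project's real ones:

```python
import asyncio
from typing import Any, Dict, List

async def find_messages(message_filter: Dict[str, Any]) -> List[Dict[str, Any]]:
    await asyncio.sleep(0)   # stands in for the awaited SQLAlchemy query
    return [{"chat_id": message_filter.get("chat_id"), "time": 1.0}]

async def get_raw_msg_before_timestamp_with_chat(chat_id: str, timestamp: float) -> List[Dict[str, Any]]:
    # The wrapper becomes async and awaits the query helper it delegates to.
    return await find_messages({"chat_id": chat_id, "time": {"$lt": timestamp}})

async def caller() -> None:
    msgs = await get_raw_msg_before_timestamp_with_chat("chat-1", 2.0)
    print(len(msgs))

asyncio.run(caller())
```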
def _build_readable_messages_internal(
|
||||
async def _build_readable_messages_internal(
|
||||
messages: List[Dict[str, Any]],
|
||||
replace_bot_name: bool = True,
|
||||
merge_messages: bool = False,
|
||||
@@ -622,7 +625,8 @@ def _build_readable_messages_internal(
|
||||
if replace_bot_name and user_id == global_config.bot.qq_account:
|
||||
person_name = f"{global_config.bot.nickname}(你)"
|
||||
else:
|
||||
person_name = person.person_name or user_id # type: ignore
|
||||
person_info = await person_info_manager.get_values(person_id, ["person_name"])
|
||||
person_name = person_info.get("person_name") # type: ignore
|
||||
|
||||
# 如果 person_name 未设置,则使用消息中的 nickname 或默认名称
|
||||
if not person_name:
|
||||
@@ -791,7 +795,7 @@ def _build_readable_messages_internal(
|
||||
)
|
||||
|
||||
|
||||
def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
|
||||
async def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
|
||||
# sourcery skip: use-contextlib-suppress
|
||||
"""
|
||||
构建图片映射信息字符串,显示图片的具体描述内容
|
||||
@@ -814,9 +818,9 @@ def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
|
||||
# 从数据库中获取图片描述
|
||||
description = "[图片内容未知]" # 默认描述
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
image = session.execute(select(Images).where(Images.image_id == pic_id)).scalar_one_or_none()
|
||||
if image and image.description: # type: ignore
|
||||
async with get_db_session() as session:
|
||||
image = (await session.execute(select(Images).where(Images.image_id == pic_id))).scalar_one_or_none()
|
||||
if image and image.description: # type: ignore
|
||||
description = image.description
|
||||
except Exception:
|
||||
# 如果查询失败,保持默认描述
|
||||
@@ -912,17 +916,17 @@ async def build_readable_messages_with_list(
|
||||
将消息列表转换为可读的文本格式,并返回原始(时间戳, 昵称, 内容)列表。
|
||||
允许通过参数控制格式化行为。
|
||||
"""
|
||||
formatted_string, details_list, pic_id_mapping, _ = _build_readable_messages_internal(
|
||||
formatted_string, details_list, pic_id_mapping, _ = await _build_readable_messages_internal(
|
||||
messages, replace_bot_name, merge_messages, timestamp_mode, truncate
|
||||
)
|
||||
|
||||
if pic_mapping_info := build_pic_mapping_info(pic_id_mapping):
|
||||
if pic_mapping_info := await build_pic_mapping_info(pic_id_mapping):
|
||||
formatted_string = f"{pic_mapping_info}\n\n{formatted_string}"
|
||||
|
||||
return formatted_string, details_list
|
||||
|
||||
|
||||
def build_readable_messages_with_id(
|
||||
async def build_readable_messages_with_id(
|
||||
messages: List[Dict[str, Any]],
|
||||
replace_bot_name: bool = True,
|
||||
merge_messages: bool = False,
|
||||
@@ -938,7 +942,7 @@ def build_readable_messages_with_id(
|
||||
"""
|
||||
message_id_list = assign_message_ids(messages)
|
||||
|
||||
formatted_string = build_readable_messages(
|
||||
formatted_string = await build_readable_messages(
|
||||
messages=messages,
|
||||
replace_bot_name=replace_bot_name,
|
||||
merge_messages=merge_messages,
|
||||
@@ -953,7 +957,7 @@ def build_readable_messages_with_id(
|
||||
return formatted_string, message_id_list
|
||||
|
||||
|
||||
def build_readable_messages(
|
||||
async def build_readable_messages(
|
||||
messages: List[Dict[str, Any]],
|
||||
replace_bot_name: bool = True,
|
||||
merge_messages: bool = False,
|
||||
@@ -994,24 +998,28 @@ def build_readable_messages(
|
||||
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
# 获取这个时间范围内的动作记录,并匹配chat_id
|
||||
actions_in_range = session.execute(
|
||||
select(ActionRecords)
|
||||
.where(
|
||||
and_(
|
||||
ActionRecords.time >= min_time, ActionRecords.time <= max_time, ActionRecords.chat_id == chat_id
|
||||
actions_in_range = (
|
||||
await session.execute(
|
||||
select(ActionRecords)
|
||||
.where(
|
||||
and_(
|
||||
ActionRecords.time >= min_time, ActionRecords.time <= max_time, ActionRecords.chat_id == chat_id
|
||||
)
|
||||
)
|
||||
.order_by(ActionRecords.time)
|
||||
)
|
||||
.order_by(ActionRecords.time)
|
||||
).scalars()
|
||||
|
||||
# 获取最新消息之后的第一个动作记录
|
||||
action_after_latest = session.execute(
|
||||
select(ActionRecords)
|
||||
.where(and_(ActionRecords.time > max_time, ActionRecords.chat_id == chat_id))
|
||||
.order_by(ActionRecords.time)
|
||||
.limit(1)
|
||||
action_after_latest = (
|
||||
await session.execute(
|
||||
select(ActionRecords)
|
||||
.where(and_(ActionRecords.time > max_time, ActionRecords.chat_id == chat_id))
|
||||
.order_by(ActionRecords.time)
|
||||
.limit(1)
|
||||
)
|
||||
).scalars()
|
||||
|
||||
# 合并两部分动作记录,并转为 dict,避免 DetachedInstanceError
|
||||
@@ -1043,7 +1051,7 @@ def build_readable_messages(
|
||||
|
||||
if read_mark <= 0:
|
||||
# 没有有效的 read_mark,直接格式化所有消息
|
||||
formatted_string, _, pic_id_mapping, _ = _build_readable_messages_internal(
|
||||
formatted_string, _, pic_id_mapping, _ = await _build_readable_messages_internal(
|
||||
copy_messages,
|
||||
replace_bot_name,
|
||||
merge_messages,
|
||||
@@ -1054,7 +1062,7 @@ def build_readable_messages(
|
||||
)
|
||||
|
||||
# 生成图片映射信息并添加到最前面
|
||||
pic_mapping_info = build_pic_mapping_info(pic_id_mapping)
|
||||
pic_mapping_info = await build_pic_mapping_info(pic_id_mapping)
|
||||
if pic_mapping_info:
|
||||
return f"{pic_mapping_info}\n\n{formatted_string}"
|
||||
else:
|
||||
@@ -1069,7 +1077,7 @@ def build_readable_messages(
|
||||
pic_counter = 1
|
||||
|
||||
# 分别格式化,但使用共享的图片映射
|
||||
formatted_before, _, pic_id_mapping, pic_counter = _build_readable_messages_internal(
|
||||
formatted_before, _, pic_id_mapping, pic_counter = await _build_readable_messages_internal(
|
||||
messages_before_mark,
|
||||
replace_bot_name,
|
||||
merge_messages,
|
||||
@@ -1080,7 +1088,7 @@ def build_readable_messages(
|
||||
show_pic=show_pic,
|
||||
message_id_list=message_id_list,
|
||||
)
|
||||
formatted_after, _, pic_id_mapping, _ = _build_readable_messages_internal(
|
||||
formatted_after, _, pic_id_mapping, _ = await _build_readable_messages_internal(
|
||||
messages_after_mark,
|
||||
replace_bot_name,
|
||||
merge_messages,
|
||||
@@ -1096,7 +1104,7 @@ def build_readable_messages(
|
||||
|
||||
# 生成图片映射信息
|
||||
if pic_id_mapping:
|
||||
pic_mapping_info = f"图片信息:\n{build_pic_mapping_info(pic_id_mapping)}\n聊天记录信息:\n"
|
||||
pic_mapping_info = f"图片信息:\n{await build_pic_mapping_info(pic_id_mapping)}\n聊天记录信息:\n"
|
||||
else:
|
||||
pic_mapping_info = "聊天记录信息:\n"
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ def _model_to_dict(instance: Base) -> Dict[str, Any]:
|
||||
return {col.name: getattr(instance, col.name) for col in instance.__table__.columns}
|
||||
|
||||
|
||||
def find_messages(
|
||||
async def find_messages(
|
||||
message_filter: dict[str, Any],
|
||||
sort: Optional[List[tuple[str, int]]] = None,
|
||||
limit: int = 0,
|
||||
@@ -46,7 +46,7 @@ def find_messages(
|
||||
消息字典列表,如果出错则返回空列表。
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
query = select(Messages)
|
||||
|
||||
# 应用过滤器
|
||||
@@ -96,7 +96,7 @@ def find_messages(
|
||||
# 获取时间最早的 limit 条记录,已经是正序
|
||||
query = query.order_by(Messages.time.asc()).limit(limit)
|
||||
try:
|
||||
results = session.execute(query).scalars().all()
|
||||
results = (await session.execute(query)).scalars().all()
|
||||
except Exception as e:
|
||||
logger.error(f"执行earliest查询失败: {e}")
|
||||
results = []
|
||||
@@ -104,7 +104,7 @@ def find_messages(
|
||||
# 获取时间最晚的 limit 条记录
|
||||
query = query.order_by(Messages.time.desc()).limit(limit)
|
||||
try:
|
||||
latest_results = session.execute(query).scalars().all()
|
||||
latest_results = (await session.execute(query)).scalars().all()
|
||||
# 将结果按时间正序排列
|
||||
results = sorted(latest_results, key=lambda msg: msg.time)
|
||||
except Exception as e:
|
||||
@@ -128,12 +128,12 @@ def find_messages(
|
||||
if sort_terms:
|
||||
query = query.order_by(*sort_terms)
|
||||
try:
|
||||
results = session.execute(query).scalars().all()
|
||||
results = (await session.execute(query)).scalars().all()
|
||||
except Exception as e:
|
||||
logger.error(f"执行无限制查询失败: {e}")
|
||||
results = []
|
||||
|
||||
return [_model_to_dict(msg) for msg in results]
|
||||
return [_model_to_dict(msg) for msg in results]
|
||||
except Exception as e:
|
||||
log_message = (
|
||||
f"使用 SQLAlchemy 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n"
|
||||
@@ -143,7 +143,7 @@ def find_messages(
|
||||
return []
|
||||
|
||||
|
||||
def count_messages(message_filter: dict[str, Any]) -> int:
|
||||
async def count_messages(message_filter: dict[str, Any]) -> int:
|
||||
"""
|
||||
根据提供的过滤器计算消息数量。
|
||||
|
||||
@@ -154,7 +154,7 @@ def count_messages(message_filter: dict[str, Any]) -> int:
|
||||
符合条件的消息数量,如果出错则返回 0。
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
query = select(func.count(Messages.id))
|
||||
|
||||
# 应用过滤器
|
||||
@@ -192,7 +192,7 @@ def count_messages(message_filter: dict[str, Any]) -> int:
|
||||
if conditions:
|
||||
query = query.where(*conditions)
|
||||
|
||||
count = session.execute(query).scalar()
|
||||
count = (await session.execute(query)).scalar()
|
||||
return count or 0
|
||||
except Exception as e:
|
||||
log_message = f"使用 SQLAlchemy 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}"
|
||||
|
||||
src/main.py
@@ -40,6 +40,9 @@ if not global_config.memory.enable_memory:
def initialize(self):
pass

async def initialize_async(self):
pass

def get_hippocampus(self):
return None

@@ -248,7 +251,7 @@ MoFox_Bot(第三方修改版)
logger.info("聊天管理器初始化成功")

# 初始化记忆系统
self.hippocampus_manager.initialize()
await self.hippocampus_manager.initialize_async()
logger.info("记忆系统初始化成功")

# 初始化LPMM知识库
@@ -283,7 +286,7 @@ MoFox_Bot(第三方修改版)
if global_config.planning_system.monthly_plan_enable:
logger.info("正在初始化月度计划管理器...")
try:
await monthly_plan_manager.start_monthly_plan_generation()
await monthly_plan_manager.initialize()
logger.info("月度计划管理器初始化成功")
except Exception as e:
logger.error(f"月度计划管理器初始化失败: {e}")
@@ -291,8 +294,7 @@ MoFox_Bot(第三方修改版)
# 初始化日程管理器
if global_config.planning_system.schedule_enable:
logger.info("日程表功能已启用,正在初始化管理器...")
await schedule_manager.load_or_generate_today_schedule()
await schedule_manager.start_daily_schedule_generation()
await schedule_manager.initialize()
logger.info("日程表管理器初始化成功。")

try:
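With `sync_memory_from_db` now a coroutine, startup splits initialization into a synchronous constructor step plus an awaited `initialize_async` that performs the database sync, which is why `main.py` now awaits `initialize_async()` instead of calling `initialize()`. A hedged, self-contained sketch of that ordering; the bodies are stand-ins for the real Hippocampus wiring:

```python
import asyncio

class HippocampusManager:
    def __init__(self) -> None:
        self._initialized = False
        self._hippocampus = None

    def initialize(self):
        # Synchronous part: construct the object graph only.
        self._hippocampus = object()          # stand-in for Hippocampus()
        self._initialized = True
        return self._hippocampus

    async def initialize_async(self) -> None:
        if not self._initialized:
            self.initialize()                 # sync construction first
        # Then the awaited part, standing in for:
        # await self._hippocampus.entorhinal_cortex.sync_memory_from_db()
        await asyncio.sleep(0)

async def startup() -> None:
    manager = HippocampusManager()
    await manager.initialize_async()          # replaces the old manager.initialize() call
    print("memory system ready")

asyncio.run(startup())
```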
@@ -401,14 +401,15 @@ class PersonInfoManager:
|
||||
|
||||
# # 初始化时读取所有person_name
|
||||
try:
|
||||
pass
|
||||
# 在这里获取会话
|
||||
with get_db_session() as session:
|
||||
for record in session.execute(
|
||||
select(PersonInfo.person_id, PersonInfo.person_name).where(PersonInfo.person_name.is_not(None))
|
||||
).fetchall():
|
||||
if record.person_name:
|
||||
self.person_name_list[record.person_id] = record.person_name
|
||||
logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称 (SQLAlchemy)")
|
||||
# with get_db_session() as session:
|
||||
# for record in session.execute(
|
||||
# select(PersonInfo.person_id, PersonInfo.person_name).where(PersonInfo.person_name.is_not(None))
|
||||
# ).fetchall():
|
||||
# if record.person_name:
|
||||
# self.person_name_list[record.person_id] = record.person_name
|
||||
# logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称 (SQLAlchemy)")
|
||||
except Exception as e:
|
||||
logger.error(f"从 SQLAlchemy 加载 person_name_list 失败: {e}")
|
||||
|
||||
@@ -430,23 +431,25 @@ class PersonInfoManager:
|
||||
"""判断是否认识某人"""
|
||||
person_id = self.get_person_id(platform, user_id)
|
||||
|
||||
def _db_check_known_sync(p_id: str):
|
||||
async def _db_check_known_async(p_id: str):
|
||||
# 在需要时获取会话
|
||||
with get_db_session() as session:
|
||||
return session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar() is not None
|
||||
async with get_db_session() as session:
|
||||
return (
|
||||
await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))
|
||||
).scalar() is not None
|
||||
|
||||
try:
|
||||
return await asyncio.to_thread(_db_check_known_sync, person_id)
|
||||
return await _db_check_known_async(person_id)
|
||||
except Exception as e:
|
||||
logger.error(f"检查用户 {person_id} 是否已知时出错 (SQLAlchemy): {e}")
|
||||
return False
|
||||
|
||||
def get_person_id_by_person_name(self, person_name: str) -> str:
|
||||
async def get_person_id_by_person_name(self, person_name: str) -> str:
|
||||
"""根据用户名获取用户ID"""
|
||||
try:
|
||||
# 在需要时获取会话
|
||||
with get_db_session() as session:
|
||||
record = session.execute(select(PersonInfo).where(PersonInfo.person_name == person_name)).scalar()
|
||||
async with get_db_session() as session:
|
||||
record = (await session.execute(select(PersonInfo).where(PersonInfo.person_name == person_name))).scalar()
|
||||
return record.person_id if record else ""
|
||||
except Exception as e:
|
||||
logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (SQLAlchemy): {e}")
|
||||
@@ -500,19 +503,18 @@ class PersonInfoManager:
|
||||
final_data[key] = orjson.dumps([]).decode("utf-8")
|
||||
# If it's already a string, assume it's valid JSON or a non-JSON string field
|
||||
|
||||
def _db_create_sync(p_data: dict):
|
||||
with get_db_session() as session:
|
||||
async def _db_create_async(p_data: dict):
|
||||
async with get_db_session() as session:
|
||||
try:
|
||||
new_person = PersonInfo(**p_data)
|
||||
session.add(new_person)
|
||||
session.commit()
|
||||
|
||||
await session.commit()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"创建 PersonInfo 记录 {p_data.get('person_id')} 失败 (SQLAlchemy): {e}")
|
||||
return False
|
||||
|
||||
await asyncio.to_thread(_db_create_sync, final_data)
|
||||
await _db_create_async(final_data)
|
||||
|
||||
async def _safe_create_person_info(self, person_id: str, data: Optional[dict] = None):
|
||||
"""安全地创建用户信息,处理竞态条件"""
|
||||
@@ -557,11 +559,11 @@ class PersonInfoManager:
|
||||
elif final_data[key] is None: # Default for lists is [], store as "[]"
|
||||
final_data[key] = orjson.dumps([]).decode("utf-8")
|
||||
|
||||
def _db_safe_create_sync(p_data: dict):
|
||||
with get_db_session() as session:
|
||||
async def _db_safe_create_async(p_data: dict):
|
||||
async with get_db_session() as session:
|
||||
try:
|
||||
existing = session.execute(
|
||||
select(PersonInfo).where(PersonInfo.person_id == p_data["person_id"])
|
||||
existing = (
|
||||
await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_data["person_id"]))
|
||||
).scalar()
|
||||
if existing:
|
||||
logger.debug(f"用户 {p_data['person_id']} 已存在,跳过创建")
|
||||
@@ -570,18 +572,17 @@ class PersonInfoManager:
|
||||
# 尝试创建
|
||||
new_person = PersonInfo(**p_data)
|
||||
session.add(new_person)
|
||||
session.commit()
|
||||
|
||||
await session.commit()
|
||||
return True
|
||||
except Exception as e:
|
||||
if "UNIQUE constraint failed" in str(e):
|
||||
logger.debug(f"检测到并发创建用户 {p_data.get('person_id')},跳过错误")
|
||||
return True # 其他协程已创建,视为成功
|
||||
return True
|
||||
else:
|
||||
logger.error(f"创建 PersonInfo 记录 {p_data.get('person_id')} 失败 (SQLAlchemy): {e}")
|
||||
return False
|
||||
|
||||
await asyncio.to_thread(_db_safe_create_sync, final_data)
|
||||
await _db_safe_create_async(final_data)
|
||||
|
||||
async def update_one_field(self, person_id: str, field_name: str, value, data: Optional[Dict] = None):
|
||||
"""更新某一个字段,会补全"""
|
||||
@@ -598,37 +599,33 @@ class PersonInfoManager:
|
||||
elif value is None: # Store None as "[]" for JSON list fields
|
||||
processed_value = orjson.dumps([]).decode("utf-8")
|
||||
|
||||
def _db_update_sync(p_id: str, f_name: str, val_to_set):
|
||||
async def _db_update_async(p_id: str, f_name: str, val_to_set):
|
||||
start_time = time.time()
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
try:
|
||||
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
|
||||
record = (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar()
|
||||
query_time = time.time()
|
||||
|
||||
if record:
|
||||
setattr(record, f_name, val_to_set)
|
||||
|
||||
save_time = time.time()
|
||||
|
||||
total_time = save_time - start_time
|
||||
if total_time > 0.5: # 如果超过500ms就记录日志
|
||||
if total_time > 0.5:
|
||||
logger.warning(
|
||||
f"数据库更新操作耗时 {total_time:.3f}秒 (查询: {query_time - start_time:.3f}s, 保存: {save_time - query_time:.3f}s) person_id={p_id}, field={f_name}"
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return True, False # Found and updated, no creation needed
|
||||
await session.commit()
|
||||
return True, False
|
||||
else:
|
||||
total_time = time.time() - start_time
|
||||
if total_time > 0.5:
|
||||
logger.warning(f"数据库查询操作耗时 {total_time:.3f}秒 person_id={p_id}, field={f_name}")
|
||||
return False, True # Not found, needs creation
|
||||
return False, True
|
||||
except Exception as e:
|
||||
total_time = time.time() - start_time
|
||||
logger.error(f"数据库操作异常,耗时 {total_time:.3f}秒: {e}")
|
||||
raise
|
||||
|
||||
found, needs_creation = await asyncio.to_thread(_db_update_sync, person_id, field_name, processed_value)
|
||||
found, needs_creation = await _db_update_async(person_id, field_name, processed_value)
|
||||
|
||||
if needs_creation:
|
||||
logger.info(f"{person_id} 不存在,将新建。")
|
||||
@@ -666,13 +663,13 @@ class PersonInfoManager:
|
||||
logger.debug(f"检查字段'{field_name}'失败,未在 PersonInfo SQLAlchemy 模型中定义。")
|
||||
return False
|
||||
|
||||
def _db_has_field_sync(p_id: str, f_name: str):
|
||||
with get_db_session() as session:
|
||||
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
|
||||
async def _db_has_field_async(p_id: str, f_name: str):
|
||||
async with get_db_session() as session:
|
||||
record = (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar()
|
||||
return bool(record)
|
||||
|
||||
try:
|
||||
return await asyncio.to_thread(_db_has_field_sync, person_id, field_name)
|
||||
return await _db_has_field_async(person_id, field_name)
|
||||
except Exception as e:
|
||||
logger.error(f"检查字段 {field_name} for {person_id} 时出错 (SQLAlchemy): {e}")
|
||||
return False
|
||||
@@ -778,14 +775,14 @@ class PersonInfoManager:
|
||||
logger.info(f"尝试给用户{user_nickname} {person_id} 取名,但是 {generated_nickname} 已存在,重试中...")
|
||||
else:
|
||||
|
||||
def _db_check_name_exists_sync(name_to_check):
|
||||
with get_db_session() as session:
|
||||
async def _db_check_name_exists_async(name_to_check):
|
||||
async with get_db_session() as session:
|
||||
return (
|
||||
session.execute(select(PersonInfo).where(PersonInfo.person_name == name_to_check)).scalar()
|
||||
(await session.execute(select(PersonInfo).where(PersonInfo.person_name == name_to_check))).scalar()
|
||||
is not None
|
||||
)
|
||||
|
||||
if await asyncio.to_thread(_db_check_name_exists_sync, generated_nickname):
|
||||
if await _db_check_name_exists_async(generated_nickname):
|
||||
is_duplicate = True
|
||||
current_name_set.add(generated_nickname)
|
||||
|
||||
@@ -824,91 +821,26 @@ class PersonInfoManager:
|
||||
logger.debug("删除失败:person_id 不能为空")
|
||||
return
|
||||
|
||||
def _db_delete_sync(p_id: str):
|
||||
async def _db_delete_async(p_id: str):
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
|
||||
async with get_db_session() as session:
|
||||
record = (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar()
|
||||
if record:
|
||||
session.delete(record)
|
||||
session.commit()
|
||||
return 1
|
||||
await session.delete(record)
|
||||
await session.commit()
|
||||
return 1
|
||||
return 0
|
||||
except Exception as e:
|
||||
logger.error(f"删除 PersonInfo {p_id} 失败 (SQLAlchemy): {e}")
|
||||
return 0
|
||||
|
||||
deleted_count = await asyncio.to_thread(_db_delete_sync, person_id)
|
||||
deleted_count = await _db_delete_async(person_id)
|
||||
|
||||
if deleted_count > 0:
|
||||
logger.debug(f"删除成功:person_id={person_id} (Peewee)")
|
||||
logger.debug(f"删除成功:person_id={person_id}")
|
||||
else:
|
||||
logger.debug(f"删除失败:未找到 person_id={person_id} 或删除未影响行 (Peewee)")
|
||||
logger.debug(f"删除失败:未找到 person_id={person_id} 或删除未影响行")
|
||||
|
||||
@staticmethod
|
||||
async def get_value(person_id: str, field_name: str):
|
||||
"""获取指定用户指定字段的值"""
|
||||
default_value_for_field = person_info_default.get(field_name)
|
||||
if field_name in JSON_SERIALIZED_FIELDS and default_value_for_field is None:
|
||||
default_value_for_field = [] # Ensure JSON fields default to [] if not in DB
|
||||
|
||||
def _db_get_value_sync(p_id: str, f_name: str):
|
||||
with get_db_session() as session:
|
||||
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
|
||||
if record:
|
||||
val = getattr(record, f_name, None)
|
||||
if f_name in JSON_SERIALIZED_FIELDS:
|
||||
if isinstance(val, str):
|
||||
try:
|
||||
return orjson.loads(val)
|
||||
except orjson.JSONDecodeError:
|
||||
logger.warning(f"字段 {f_name} for {p_id} 包含无效JSON: {val}. 返回默认值.")
|
||||
return [] # Default for JSON fields on error
|
||||
elif val is None: # Field exists in DB but is None
|
||||
return [] # Default for JSON fields
|
||||
# If val is already a list/dict (e.g. if somehow set without serialization)
|
||||
return val # Should ideally not happen if update_one_field is always used
|
||||
return val
|
||||
return None # Record not found
|
||||
|
||||
try:
|
||||
value_from_db = await asyncio.to_thread(_db_get_value_sync, person_id, field_name)
|
||||
if value_from_db is not None:
|
||||
return value_from_db
|
||||
if field_name in person_info_default:
|
||||
return default_value_for_field
|
||||
logger.warning(f"字段 {field_name} 在 person_info_default 中未定义,且在数据库中未找到。")
|
||||
return None # Ultimate fallback
|
||||
except Exception as e:
|
||||
logger.error(f"获取字段 {field_name} for {person_id} 时出错 (Peewee): {e}")
|
||||
# Fallback to default in case of any error during DB access
|
||||
return default_value_for_field if field_name in person_info_default else None
|
||||
|
||||
@staticmethod
|
||||
def get_value_sync(person_id: str, field_name: str):
|
||||
"""同步获取指定用户指定字段的值"""
|
||||
default_value_for_field = person_info_default.get(field_name)
|
||||
with get_db_session() as session:
|
||||
if field_name in JSON_SERIALIZED_FIELDS and default_value_for_field is None:
|
||||
default_value_for_field = []
|
||||
|
||||
if record := session.execute(select(PersonInfo).where(PersonInfo.person_id == person_id)).scalar():
|
||||
val = getattr(record, field_name, None)
|
||||
if field_name in JSON_SERIALIZED_FIELDS:
|
||||
if isinstance(val, str):
|
||||
try:
|
||||
return orjson.loads(val)
|
||||
except orjson.JSONDecodeError:
|
||||
logger.warning(f"字段 {field_name} for {person_id} 包含无效JSON: {val}. 返回默认值.")
|
||||
return []
|
||||
elif val is None:
|
||||
return []
|
||||
return val
|
||||
return val
|
||||
|
||||
if field_name in person_info_default:
|
||||
return default_value_for_field
|
||||
logger.warning(f"字段 {field_name} 在 person_info_default 中未定义,且在数据库中未找到。")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def get_values(person_id: str, field_names: list) -> dict:
|
||||
@@ -919,11 +851,11 @@ class PersonInfoManager:
|
||||
|
||||
result = {}
|
||||
|
||||
def _db_get_record_sync(p_id: str):
|
||||
with get_db_session() as session:
|
||||
return session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
|
||||
async def _db_get_record_async(p_id: str):
|
||||
async with get_db_session() as session:
|
||||
return (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar()
|
||||
|
||||
record = await asyncio.to_thread(_db_get_record_sync, person_id)
|
||||
record = await _db_get_record_async(person_id)
|
||||
|
||||
# 获取 SQLAlchemy 模型的所有字段名
|
||||
model_fields = [column.name for column in PersonInfo.__table__.columns]
|
||||
@@ -960,14 +892,15 @@ class PersonInfoManager:
|
||||
# 获取 SQLAlchemy 模型的所有字段名
|
||||
model_fields = [column.name for column in PersonInfo.__table__.columns]
|
||||
if field_name not in model_fields:
|
||||
logger.error(f"字段检查失败:'{field_name}'未在 PersonInfo SQLAlchemy 模 modelo中定义")
|
||||
logger.error(f"字段检查失败:'{field_name}'未在 PersonInfo SQLAlchemy 模型中定义")
|
||||
return {}
|
||||
|
||||
def _db_get_specific_sync(f_name: str):
|
||||
async def _db_get_specific_async(f_name: str):
|
||||
found_results = {}
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
for record in session.execute(select(PersonInfo.person_id, getattr(PersonInfo, f_name))).fetchall():
|
||||
async with get_db_session() as session:
|
||||
result = await session.execute(select(PersonInfo.person_id, getattr(PersonInfo, f_name)))
|
||||
for record in result.fetchall():
|
||||
value = getattr(record, f_name)
|
||||
if way(value):
|
||||
found_results[record.person_id] = value
|
||||
@@ -978,9 +911,9 @@ class PersonInfoManager:
|
||||
return found_results
|
||||
|
||||
try:
|
||||
return await asyncio.to_thread(_db_get_specific_sync, field_name)
|
||||
return await _db_get_specific_async(field_name)
|
||||
except Exception as e:
|
||||
logger.error(f"执行 get_specific_value_list 线程时出错: {str(e)}", exc_info=True)
|
||||
logger.error(f"执行 get_specific_value_list 时出错: {str(e)}", exc_info=True)
|
||||
return {}
|
||||
|
||||
async def get_or_create_person(
|
||||
@@ -993,40 +926,38 @@ class PersonInfoManager:
|
||||
"""
|
||||
person_id = self.get_person_id(platform, user_id)
|
||||
|
||||
def _db_get_or_create_sync(p_id: str, init_data: dict):
|
||||
async def _db_get_or_create_async(p_id: str, init_data: dict):
|
||||
"""原子性的获取或创建操作"""
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
# 首先尝试获取现有记录
|
||||
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
|
||||
record = (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar()
|
||||
if record:
|
||||
return record, False # 记录存在,未创建
|
||||
|
||||
# 记录不存在,尝试创建
|
||||
try:
|
||||
new_person = PersonInfo(**init_data)
|
||||
session.add(new_person)
|
||||
session.commit()
|
||||
|
||||
return session.execute(
|
||||
select(PersonInfo).where(PersonInfo.person_id == p_id)
|
||||
).scalar(), True # 创建成功
|
||||
except Exception as e:
|
||||
# 如果创建失败(可能是因为竞态条件),再次尝试获取
|
||||
if "UNIQUE constraint failed" in str(e):
|
||||
logger.debug(f"检测到并发创建用户 {p_id},获取现有记录")
|
||||
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
|
||||
if record:
|
||||
return record, False # 其他协程已创建,返回现有记录
|
||||
# 如果仍然失败,重新抛出异常
|
||||
raise e
|
||||
|
||||
# 记录不存在,尝试创建
|
||||
try:
|
||||
new_person = PersonInfo(**init_data)
|
||||
session.add(new_person)
|
||||
await session.commit()
|
||||
await session.refresh(new_person)
|
||||
return new_person, True # 创建成功
|
||||
except Exception as e:
|
||||
# 如果创建失败(可能是因为竞态条件),再次尝试获取
|
||||
if "UNIQUE constraint failed" in str(e):
|
||||
logger.debug(f"检测到并发创建用户 {p_id},获取现有记录")
|
||||
record = (await session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id))).scalar()
|
||||
if record:
|
||||
return record, False # 其他协程已创建,返回现有记录
|
||||
# 如果仍然失败,重新抛出异常
|
||||
raise e
|
||||
|
||||
unique_nickname = await self._generate_unique_person_name(nickname)
|
||||
initial_data = {
|
||||
"person_id": person_id,
|
||||
"platform": platform,
|
||||
"user_id": str(user_id),
|
||||
"nickname": nickname,
|
||||
"person_name": unique_nickname, # 使用群昵称作为person_name
|
||||
"person_name": unique_nickname,
|
||||
"name_reason": "从群昵称获取",
|
||||
"know_times": 0,
|
||||
"know_since": int(datetime.datetime.now().timestamp()),
|
||||
@@ -1036,7 +967,6 @@ class PersonInfoManager:
|
||||
"forgotten_points": [],
|
||||
}
|
||||
|
||||
# 序列化JSON字段
|
||||
for key in JSON_SERIALIZED_FIELDS:
|
||||
if key in initial_data:
|
||||
if isinstance(initial_data[key], (list, dict)):
|
||||
@@ -1044,15 +974,14 @@ class PersonInfoManager:
|
||||
elif initial_data[key] is None:
|
||||
initial_data[key] = orjson.dumps([]).decode("utf-8")
|
||||
|
||||
# 获取 SQLAlchemy 模型的所有字段名
|
||||
model_fields = [column.name for column in PersonInfo.__table__.columns]
|
||||
filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields}
|
||||
|
||||
record, was_created = await asyncio.to_thread(_db_get_or_create_sync, person_id, filtered_initial_data)
|
||||
record, was_created = await _db_get_or_create_async(person_id, filtered_initial_data)
|
||||
|
||||
if was_created:
|
||||
logger.info(f"用户 {platform}:{user_id} (person_id: {person_id}) 不存在,将创建新记录 (Peewee)。")
|
||||
logger.info(f"已为 {person_id} 创建新记录,初始数据 (filtered for model): {filtered_initial_data}")
|
||||
logger.info(f"用户 {platform}:{user_id} (person_id: {person_id}) 不存在,将创建新记录。")
|
||||
logger.info(f"已为 {person_id} 创建新记录,初始数据: {filtered_initial_data}")
|
||||
else:
|
||||
logger.debug(f"用户 {platform}:{user_id} (person_id: {person_id}) 已存在,返回现有记录。")
|
||||
|
||||
@@ -1072,11 +1001,13 @@ class PersonInfoManager:
|
||||
|
||||
if not found_person_id:
|
||||
|
||||
def _db_find_by_name_sync(p_name_to_find: str):
|
||||
with get_db_session() as session:
|
||||
return session.execute(select(PersonInfo).where(PersonInfo.person_name == p_name_to_find)).scalar()
|
||||
async def _db_find_by_name_async(p_name_to_find: str):
|
||||
async with get_db_session() as session:
|
||||
return (
|
||||
await session.execute(select(PersonInfo).where(PersonInfo.person_name == p_name_to_find))
|
||||
).scalar()
|
||||
|
||||
record = await asyncio.to_thread(_db_find_by_name_sync, person_name)
|
||||
record = await _db_find_by_name_async(person_name)
|
||||
if record:
|
||||
found_person_id = record.person_id
|
||||
if (
|
||||
|
||||
457
src/person_info/relationship_fetcher.py
Normal file
@@ -0,0 +1,457 @@
|
||||
import time
|
||||
import traceback
|
||||
import orjson
|
||||
import random
|
||||
|
||||
from typing import List, Dict, Any
|
||||
from json_repair import repair_json
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
|
||||
|
||||
logger = get_logger("relationship_fetcher")
|
||||
|
||||
|
||||
def init_real_time_info_prompts():
|
||||
"""初始化实时信息提取相关的提示词"""
|
||||
relationship_prompt = """
|
||||
<聊天记录>
|
||||
{chat_observe_info}
|
||||
</聊天记录>
|
||||
|
||||
{name_block}
|
||||
现在,你想要回复{person_name}的消息,消息内容是:{target_message}。请根据聊天记录和你要回复的消息,从你对{person_name}的了解中提取有关的信息:
|
||||
1.你需要提供你想要提取的信息具体是哪方面的信息,例如:年龄,性别,你们之间的交流方式,最近发生的事等等。
|
||||
2.请注意,请不要重复调取相同的信息,已经调取的信息如下:
|
||||
{info_cache_block}
|
||||
3.如果当前聊天记录中没有需要查询的信息,或者现有信息已经足够回复,请返回{{"none": "不需要查询"}}
|
||||
|
||||
请以json格式输出,例如:
|
||||
|
||||
{{
|
||||
"info_type": "信息类型",
|
||||
}}
|
||||
|
||||
请严格按照json输出格式,不要输出多余内容:
|
||||
"""
|
||||
Prompt(relationship_prompt, "real_time_info_identify_prompt")
|
||||
|
||||
fetch_info_prompt = """
|
||||
|
||||
{name_block}
|
||||
以下是你在之前与{person_name}的交流中,产生的对{person_name}的了解:
|
||||
{person_impression_block}
|
||||
{points_text_block}
|
||||
|
||||
请从中提取用户"{person_name}"的有关"{info_type}"信息
|
||||
请以json格式输出,例如:
|
||||
|
||||
{{
|
||||
{info_json_str}
|
||||
}}
|
||||
|
||||
请严格按照json输出格式,不要输出多余内容:
|
||||
"""
|
||||
Prompt(fetch_info_prompt, "real_time_fetch_person_info_prompt")
|
||||
|
||||
|
||||
class RelationshipFetcher:
|
||||
def __init__(self, chat_id):
|
||||
self.chat_id = chat_id
|
||||
|
||||
# 信息获取缓存:记录正在获取的信息请求
|
||||
self.info_fetching_cache: List[Dict[str, Any]] = []
|
||||
|
||||
# 信息结果缓存:存储已获取的信息结果,带TTL
|
||||
self.info_fetched_cache: Dict[str, Dict[str, Any]] = {}
|
||||
# 结构:{person_id: {info_type: {"info": str, "ttl": int, "start_time": float, "person_name": str, "unknown": bool}}}
|
||||
|
||||
# LLM模型配置
|
||||
self.llm_model = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils_small, request_type="relation.fetcher"
|
||||
)
|
||||
|
||||
# 小模型用于即时信息提取
|
||||
self.instant_llm_model = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils_small, request_type="relation.fetch"
|
||||
)
|
||||
|
||||
name = get_chat_manager().get_stream_name(self.chat_id)
|
||||
self.log_prefix = f"[{name}] 实时信息"
|
||||
|
||||
def _cleanup_expired_cache(self):
|
||||
"""清理过期的信息缓存"""
|
||||
for person_id in list(self.info_fetched_cache.keys()):
|
||||
for info_type in list(self.info_fetched_cache[person_id].keys()):
|
||||
self.info_fetched_cache[person_id][info_type]["ttl"] -= 1
|
||||
if self.info_fetched_cache[person_id][info_type]["ttl"] <= 0:
|
||||
del self.info_fetched_cache[person_id][info_type]
|
||||
if not self.info_fetched_cache[person_id]:
|
||||
del self.info_fetched_cache[person_id]
|
||||
|
||||
async def build_relation_info(self, person_id, points_num=3):
|
||||
# 清理过期的信息缓存
|
||||
self._cleanup_expired_cache()
|
||||
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_info = await person_info_manager.get_values(
|
||||
person_id, ["person_name", "short_impression", "nickname", "platform", "points"]
|
||||
)
|
||||
person_name = person_info.get("person_name")
|
||||
short_impression = person_info.get("short_impression")
|
||||
nickname_str = person_info.get("nickname")
|
||||
platform = person_info.get("platform")
|
||||
|
||||
if person_name == nickname_str and not short_impression:
|
||||
return ""
|
||||
|
||||
current_points = person_info.get("points") or []
|
||||
|
||||
# 按时间排序forgotten_points
|
||||
current_points.sort(key=lambda x: x[2])
|
||||
# 按权重加权随机抽取最多3个不重复的points,point[1]的值在1-10之间,权重越高被抽到概率越大
|
||||
if len(current_points) > points_num:
|
||||
# point[1] 取值范围1-10,直接作为权重
|
||||
weights = [max(1, min(10, int(point[1]))) for point in current_points]
|
||||
# 使用加权采样不放回,保证不重复
|
||||
indices = list(range(len(current_points)))
|
||||
points = []
|
||||
for _ in range(points_num):
|
||||
if not indices:
|
||||
break
|
||||
sub_weights = [weights[i] for i in indices]
|
||||
chosen_idx = random.choices(indices, weights=sub_weights, k=1)[0]
|
||||
points.append(current_points[chosen_idx])
|
||||
indices.remove(chosen_idx)
|
||||
else:
|
||||
points = current_points
|
||||
|
||||
# 构建points文本
|
||||
points_text = "\n".join([f"{point[2]}:{point[0]}" for point in points])
|
||||
|
||||
nickname_str = ""
|
||||
if person_name != nickname_str:
|
||||
nickname_str = f"(ta在{platform}上的昵称是{nickname_str})"
|
||||
|
||||
relation_info = ""
|
||||
|
||||
if short_impression and relation_info:
|
||||
if points_text:
|
||||
relation_info = f"你对{person_name}的印象是{nickname_str}:{short_impression}。具体来说:{relation_info}。你还记得ta最近做的事:{points_text}"
|
||||
else:
|
||||
relation_info = (
|
||||
f"你对{person_name}的印象是{nickname_str}:{short_impression}。具体来说:{relation_info}"
|
||||
)
|
||||
elif short_impression:
|
||||
if points_text:
|
||||
relation_info = (
|
||||
f"你对{person_name}的印象是{nickname_str}:{short_impression}。你还记得ta最近做的事:{points_text}"
|
||||
)
|
||||
else:
|
||||
relation_info = f"你对{person_name}的印象是{nickname_str}:{short_impression}"
|
||||
elif relation_info:
|
||||
if points_text:
|
||||
relation_info = (
|
||||
f"你对{person_name}的了解{nickname_str}:{relation_info}。你还记得ta最近做的事:{points_text}"
|
||||
)
|
||||
else:
|
||||
relation_info = f"你对{person_name}的了解{nickname_str}:{relation_info}"
|
||||
elif points_text:
|
||||
relation_info = f"你记得{person_name}{nickname_str}最近做的事:{points_text}"
|
||||
else:
|
||||
relation_info = ""
|
||||
|
||||
return relation_info
|
||||
|
||||
async def _build_fetch_query(self, person_id, target_message, chat_history):
|
||||
nickname_str = ",".join(global_config.bot.alias_names)
|
||||
name_block = f"你的名字是{global_config.bot.nickname},你的昵称有{nickname_str},有人也会用这些昵称称呼你。"
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_info = await person_info_manager.get_values(person_id, ["person_name"])
|
||||
person_name: str = person_info.get("person_name") # type: ignore
|
||||
|
||||
info_cache_block = self._build_info_cache_block()
|
||||
|
||||
prompt = (await global_prompt_manager.get_prompt_async("real_time_info_identify_prompt")).format(
|
||||
chat_observe_info=chat_history,
|
||||
name_block=name_block,
|
||||
info_cache_block=info_cache_block,
|
||||
person_name=person_name,
|
||||
target_message=target_message,
|
||||
)
|
||||
|
||||
try:
|
||||
logger.debug(f"{self.log_prefix} 信息识别prompt: \n{prompt}\n")
|
||||
content, _ = await self.llm_model.generate_response_async(prompt=prompt)
|
||||
|
||||
if content:
|
||||
content_json = orjson.loads(repair_json(content))
|
||||
|
||||
# 检查是否返回了不需要查询的标志
|
||||
if "none" in content_json:
|
||||
logger.debug(f"{self.log_prefix} LLM判断当前不需要查询任何信息:{content_json.get('none', '')}")
|
||||
return None
|
||||
|
||||
if info_type := content_json.get("info_type"):
|
||||
# 记录信息获取请求
|
||||
self.info_fetching_cache.append(
|
||||
{
|
||||
"person_id": get_person_info_manager().get_person_id_by_person_name(person_name),
|
||||
"person_name": person_name,
|
||||
"info_type": info_type,
|
||||
"start_time": time.time(),
|
||||
"forget": False,
|
||||
}
|
||||
)
|
||||
|
||||
# 限制缓存大小
|
||||
if len(self.info_fetching_cache) > 10:
|
||||
self.info_fetching_cache.pop(0)
|
||||
|
||||
logger.info(f"{self.log_prefix} 识别到需要调取用户 {person_name} 的[{info_type}]信息")
|
||||
return info_type
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} LLM未返回有效的info_type。响应: {content}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 执行信息识别LLM请求时出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
return None
|
||||
|
||||
def _build_info_cache_block(self) -> str:
|
||||
"""构建已获取信息的缓存块"""
|
||||
info_cache_block = ""
|
||||
if self.info_fetching_cache:
|
||||
# 对于每个(person_id, info_type)组合,只保留最新的记录
|
||||
latest_records = {}
|
||||
for info_fetching in self.info_fetching_cache:
|
||||
key = (info_fetching["person_id"], info_fetching["info_type"])
|
||||
if key not in latest_records or info_fetching["start_time"] > latest_records[key]["start_time"]:
|
||||
latest_records[key] = info_fetching
|
||||
|
||||
# 按时间排序并生成显示文本
|
||||
sorted_records = sorted(latest_records.values(), key=lambda x: x["start_time"])
|
||||
for info_fetching in sorted_records:
|
||||
info_cache_block += (
|
||||
f"你已经调取了[{info_fetching['person_name']}]的[{info_fetching['info_type']}]信息\n"
|
||||
)
|
||||
return info_cache_block
|
||||
|
||||
async def _extract_single_info(self, person_id: str, info_type: str, person_name: str):
|
||||
"""提取单个信息类型
|
||||
|
||||
Args:
|
||||
person_id: 用户ID
|
||||
info_type: 信息类型
|
||||
person_name: 用户名
|
||||
"""
|
||||
start_time = time.time()
|
||||
person_info_manager = get_person_info_manager()
|
||||
|
||||
# 首先检查 info_list 缓存
|
||||
person_info = await person_info_manager.get_values(person_id, ["info_list"])
|
||||
info_list = person_info.get("info_list") or []
|
||||
cached_info = None
|
||||
|
||||
# 查找对应的 info_type
|
||||
for info_item in info_list:
|
||||
if info_item.get("info_type") == info_type:
|
||||
cached_info = info_item.get("info_content")
|
||||
logger.debug(f"{self.log_prefix} 在info_list中找到 {person_name} 的 {info_type} 信息: {cached_info}")
|
||||
break
|
||||
|
||||
# 如果缓存中有信息,直接使用
|
||||
if cached_info:
|
||||
if person_id not in self.info_fetched_cache:
|
||||
self.info_fetched_cache[person_id] = {}
|
||||
|
||||
self.info_fetched_cache[person_id][info_type] = {
|
||||
"info": cached_info,
|
||||
"ttl": 2,
|
||||
"start_time": start_time,
|
||||
"person_name": person_name,
|
||||
"unknown": cached_info == "none",
|
||||
}
|
||||
logger.info(f"{self.log_prefix} 记得 {person_name} 的 {info_type}: {cached_info}")
|
||||
return
|
||||
|
||||
# 如果缓存中没有,尝试从用户档案中提取
|
||||
try:
|
||||
person_info = await person_info_manager.get_values(person_id, ["impression", "points"])
|
||||
person_impression = person_info.get("impression")
|
||||
points = person_info.get("points")
|
||||
|
||||
# 构建印象信息块
|
||||
if person_impression:
|
||||
person_impression_block = (
|
||||
f"<对{person_name}的总体了解>\n{person_impression}\n</对{person_name}的总体了解>"
|
||||
)
|
||||
else:
|
||||
person_impression_block = ""
|
||||
|
||||
# 构建要点信息块
|
||||
if points:
|
||||
points_text = "\n".join([f"{point[2]}:{point[0]}" for point in points])
|
||||
points_text_block = f"<对{person_name}的近期了解>\n{points_text}\n</对{person_name}的近期了解>"
|
||||
else:
|
||||
points_text_block = ""
|
||||
|
||||
# 如果完全没有用户信息
|
||||
if not points_text_block and not person_impression_block:
|
||||
if person_id not in self.info_fetched_cache:
|
||||
self.info_fetched_cache[person_id] = {}
|
||||
self.info_fetched_cache[person_id][info_type] = {
|
||||
"info": "none",
|
||||
"ttl": 2,
|
||||
"start_time": start_time,
|
||||
"person_name": person_name,
|
||||
"unknown": True,
|
||||
}
|
||||
logger.info(f"{self.log_prefix} 完全不认识 {person_name}")
|
||||
await self._save_info_to_cache(person_id, info_type, "none")
|
||||
return
|
||||
|
||||
# 使用LLM提取信息
|
||||
nickname_str = ",".join(global_config.bot.alias_names)
|
||||
name_block = f"你的名字是{global_config.bot.nickname},你的昵称有{nickname_str},有人也会用这些昵称称呼你。"
|
||||
|
||||
prompt = (await global_prompt_manager.get_prompt_async("real_time_fetch_person_info_prompt")).format(
|
||||
name_block=name_block,
|
||||
info_type=info_type,
|
||||
person_impression_block=person_impression_block,
|
||||
person_name=person_name,
|
||||
info_json_str=f'"{info_type}": "有关{info_type}的信息内容"',
|
||||
points_text_block=points_text_block,
|
||||
)
|
||||
|
||||
# 使用小模型进行即时提取
|
||||
content, _ = await self.instant_llm_model.generate_response_async(prompt=prompt)
|
||||
|
||||
if content:
|
||||
content_json = orjson.loads(repair_json(content))
|
||||
if info_type in content_json:
|
||||
info_content = content_json[info_type]
|
||||
is_unknown = info_content == "none" or not info_content
|
||||
|
||||
# 保存到运行时缓存
|
||||
if person_id not in self.info_fetched_cache:
|
||||
self.info_fetched_cache[person_id] = {}
|
||||
self.info_fetched_cache[person_id][info_type] = {
|
||||
"info": "unknown" if is_unknown else info_content,
|
||||
"ttl": 3,
|
||||
"start_time": start_time,
|
||||
"person_name": person_name,
|
||||
"unknown": is_unknown,
|
||||
}
|
||||
|
||||
# 保存到持久化缓存 (info_list)
|
||||
await self._save_info_to_cache(person_id, info_type, "none" if is_unknown else info_content)
|
||||
|
||||
if not is_unknown:
|
||||
logger.info(f"{self.log_prefix} 思考得到,{person_name} 的 {info_type}: {info_content}")
|
||||
else:
|
||||
logger.info(f"{self.log_prefix} 思考了也不知道{person_name} 的 {info_type} 信息")
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} 小模型返回空结果,获取 {person_name} 的 {info_type} 信息失败。")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 执行信息提取时出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
async def _save_info_to_cache(self, person_id: str, info_type: str, info_content: str):
|
||||
# sourcery skip: use-next
|
||||
"""将提取到的信息保存到 person_info 的 info_list 字段中
|
||||
|
||||
Args:
|
||||
person_id: 用户ID
|
||||
info_type: 信息类型
|
||||
info_content: 信息内容
|
||||
"""
|
||||
try:
|
||||
person_info_manager = get_person_info_manager()
|
||||
|
||||
# 获取现有的 info_list
|
||||
person_info = await person_info_manager.get_values(person_id, ["info_list"])
|
||||
info_list = person_info.get("info_list") or []
|
||||
|
||||
# 查找是否已存在相同 info_type 的记录
|
||||
found_index = -1
|
||||
for i, info_item in enumerate(info_list):
|
||||
if isinstance(info_item, dict) and info_item.get("info_type") == info_type:
|
||||
found_index = i
|
||||
break
|
||||
|
||||
# 创建新的信息记录
|
||||
new_info_item = {
|
||||
"info_type": info_type,
|
||||
"info_content": info_content,
|
||||
}
|
||||
|
||||
if found_index >= 0:
|
||||
# 更新现有记录
|
||||
info_list[found_index] = new_info_item
|
||||
logger.info(f"{self.log_prefix} [缓存更新] 更新 {person_id} 的 {info_type} 信息缓存")
|
||||
else:
|
||||
# 添加新记录
|
||||
info_list.append(new_info_item)
|
||||
logger.info(f"{self.log_prefix} [缓存保存] 新增 {person_id} 的 {info_type} 信息缓存")
|
||||
|
||||
# 保存更新后的 info_list
|
||||
await person_info_manager.update_one_field(person_id, "info_list", info_list)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} [缓存保存] 保存信息到缓存失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
|
||||
class RelationshipFetcherManager:
|
||||
"""关系提取器管理器
|
||||
|
||||
管理不同 chat_id 的 RelationshipFetcher 实例
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._fetchers: Dict[str, RelationshipFetcher] = {}
|
||||
|
||||
def get_fetcher(self, chat_id: str) -> RelationshipFetcher:
|
||||
"""获取或创建指定 chat_id 的 RelationshipFetcher
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
|
||||
Returns:
|
||||
RelationshipFetcher: 关系提取器实例
|
||||
"""
|
||||
if chat_id not in self._fetchers:
|
||||
self._fetchers[chat_id] = RelationshipFetcher(chat_id)
|
||||
return self._fetchers[chat_id]
|
||||
|
||||
def remove_fetcher(self, chat_id: str):
|
||||
"""移除指定 chat_id 的 RelationshipFetcher
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
"""
|
||||
if chat_id in self._fetchers:
|
||||
del self._fetchers[chat_id]
|
||||
|
||||
def clear_all(self):
|
||||
"""清空所有 RelationshipFetcher"""
|
||||
self._fetchers.clear()
|
||||
|
||||
def get_active_chat_ids(self) -> List[str]:
|
||||
"""获取所有活跃的 chat_id 列表"""
|
||||
return list(self._fetchers.keys())
|
||||
|
||||
|
||||
# 全局管理器实例
|
||||
relationship_fetcher_manager = RelationshipFetcherManager()
|
||||
|
||||
|
||||
init_real_time_info_prompts()
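# 用法示意(非本提交内容,仅作说明;import 路径为按文件位置推测的假设):
# 通过模块级的 relationship_fetcher_manager 按 chat_id 获取 RelationshipFetcher,
# 并在异步上下文中调用 build_relation_info 构建可注入提示词的关系信息文本。
from src.person_info.relationship_fetcher import relationship_fetcher_manager


async def demo_build_relation_info(chat_id: str, person_id: str) -> str:
    # 同一 chat_id 会复用同一个 fetcher 实例,其内部缓存带 TTL 清理
    fetcher = relationship_fetcher_manager.get_fetcher(chat_id)
    return await fetcher.build_relation_info(person_id, points_num=3)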
|
||||
@@ -62,7 +62,7 @@ def get_messages_by_time(
|
||||
return get_raw_msg_by_timestamp(start_time, end_time, limit, limit_mode)
|
||||
|
||||
|
||||
def get_messages_by_time_in_chat(
|
||||
async def get_messages_by_time_in_chat(
|
||||
chat_id: str,
|
||||
start_time: float,
|
||||
end_time: float,
|
||||
@@ -97,13 +97,13 @@ def get_messages_by_time_in_chat(
|
||||
if not isinstance(chat_id, str):
|
||||
raise ValueError("chat_id 必须是字符串类型")
|
||||
if filter_mai:
|
||||
return filter_mai_messages(
|
||||
get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time, limit, limit_mode, filter_command)
|
||||
return await filter_mai_messages(
|
||||
await get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time, limit, limit_mode, filter_command)
|
||||
)
|
||||
return get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time, limit, limit_mode, filter_command)
|
||||
return await get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time, limit, limit_mode, filter_command)
|
||||
|
||||
|
||||
def get_messages_by_time_in_chat_inclusive(
|
||||
async def get_messages_by_time_in_chat_inclusive(
|
||||
chat_id: str,
|
||||
start_time: float,
|
||||
end_time: float,
|
||||
@@ -138,12 +138,12 @@ def get_messages_by_time_in_chat_inclusive(
|
||||
if not isinstance(chat_id, str):
|
||||
raise ValueError("chat_id 必须是字符串类型")
|
||||
if filter_mai:
|
||||
return filter_mai_messages(
|
||||
get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
return await filter_mai_messages(
|
||||
await get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id, start_time, end_time, limit, limit_mode, filter_command
|
||||
)
|
||||
)
|
||||
return get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
return await get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id, start_time, end_time, limit, limit_mode, filter_command
|
||||
)
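# 调用方迁移示意(非本提交内容,仅作说明;消息 API 模块的导入路径为假设,
# filter_mai 等参数以上面的函数定义为准):这些查询函数改为协程后,调用处需要 await。
import time

from src.plugin_system.apis import message_api  # 假设的导入路径


async def demo_fetch_recent(chat_id: str, last_read_time: float):
    # 获取 last_read_time 到当前时刻之间的消息,并过滤掉机器人自己的消息
    return await message_api.get_messages_by_time_in_chat(
        chat_id=chat_id,
        start_time=last_read_time,
        end_time=time.time(),
        filter_mai=True,
    )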
|
||||
|
||||
@@ -478,7 +478,7 @@ async def get_person_ids_from_messages(messages: List[Dict[str, Any]]) -> List[s
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def filter_mai_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
async def filter_mai_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
从消息列表中移除麦麦的消息
|
||||
Args:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# mmc/src/schedule/database.py
|
||||
|
||||
from typing import List
|
||||
from sqlalchemy import select, func, update, delete
|
||||
from src.common.database.sqlalchemy_models import MonthlyPlan, get_db_session
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
@@ -8,21 +9,22 @@ from src.config.config import global_config
|
||||
logger = get_logger("schedule_database")
|
||||
|
||||
|
||||
def add_new_plans(plans: List[str], month: str):
|
||||
async def add_new_plans(plans: List[str], month: str):
|
||||
"""
|
||||
批量添加新生成的月度计划到数据库,并确保不超过上限。
|
||||
|
||||
:param plans: 计划内容列表。
|
||||
:param month: 目标月份,格式为 "YYYY-MM"。
|
||||
"""
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
try:
|
||||
# 1. 获取当前有效计划数量(状态为 'active')
|
||||
current_plan_count = (
|
||||
session.query(MonthlyPlan)
|
||||
.filter(MonthlyPlan.target_month == month, MonthlyPlan.status == "active")
|
||||
.count()
|
||||
result = await session.execute(
|
||||
select(func.count(MonthlyPlan.id)).where(
|
||||
MonthlyPlan.target_month == month, MonthlyPlan.status == "active"
|
||||
)
|
||||
)
|
||||
current_plan_count = result.scalar_one()
|
||||
|
||||
# 2. 从配置获取上限
|
||||
max_plans = global_config.planning_system.max_plans_per_month
|
||||
@@ -41,7 +43,7 @@ def add_new_plans(plans: List[str], month: str):
|
||||
MonthlyPlan(plan_text=plan, target_month=month, status="active") for plan in plans_to_add
|
||||
]
|
||||
session.add_all(new_plan_objects)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
logger.info(f"成功向数据库添加了 {len(new_plan_objects)} 条 {month} 的月度计划。")
|
||||
if len(plans) > len(plans_to_add):
|
||||
@@ -49,32 +51,31 @@ def add_new_plans(plans: List[str], month: str):
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"添加月度计划时发生错误: {e}")
|
||||
session.rollback()
|
||||
await session.rollback()
|
||||
raise
|
||||
|
||||
|
||||
def get_active_plans_for_month(month: str) -> List[MonthlyPlan]:
|
||||
async def get_active_plans_for_month(month: str) -> List[MonthlyPlan]:
|
||||
"""
|
||||
获取指定月份所有状态为 'active' 的计划。
|
||||
|
||||
:param month: 目标月份,格式为 "YYYY-MM"。
|
||||
:return: MonthlyPlan 对象列表。
|
||||
"""
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
try:
|
||||
plans = (
|
||||
session.query(MonthlyPlan)
|
||||
.filter(MonthlyPlan.target_month == month, MonthlyPlan.status == "active")
|
||||
result = await session.execute(
|
||||
select(MonthlyPlan)
|
||||
.where(MonthlyPlan.target_month == month, MonthlyPlan.status == "active")
|
||||
.order_by(MonthlyPlan.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
return plans
|
||||
return result.scalars().all()
|
||||
except Exception as e:
|
||||
logger.error(f"查询 {month} 的有效月度计划时发生错误: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def mark_plans_completed(plan_ids: List[int]):
|
||||
async def mark_plans_completed(plan_ids: List[int]):
|
||||
"""
|
||||
将指定ID的计划标记为已完成。
|
||||
|
||||
@@ -83,9 +84,10 @@ def mark_plans_completed(plan_ids: List[int]):
|
||||
if not plan_ids:
|
||||
return
|
||||
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
try:
|
||||
plans_to_mark = session.query(MonthlyPlan).filter(MonthlyPlan.id.in_(plan_ids)).all()
|
||||
result = await session.execute(select(MonthlyPlan).where(MonthlyPlan.id.in_(plan_ids)))
|
||||
plans_to_mark = result.scalars().all()
|
||||
if not plans_to_mark:
|
||||
logger.info("没有需要标记为完成的月度计划。")
|
||||
return
|
||||
@@ -93,17 +95,17 @@ def mark_plans_completed(plan_ids: List[int]):
|
||||
plan_details = "\n".join([f" {i + 1}. {plan.plan_text}" for i, plan in enumerate(plans_to_mark)])
|
||||
logger.info(f"以下 {len(plans_to_mark)} 条月度计划将被标记为已完成:\n{plan_details}")
|
||||
|
||||
session.query(MonthlyPlan).filter(MonthlyPlan.id.in_(plan_ids)).update(
|
||||
{"status": "completed"}, synchronize_session=False
|
||||
await session.execute(
|
||||
update(MonthlyPlan).where(MonthlyPlan.id.in_(plan_ids)).values(status="completed")
|
||||
)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"标记月度计划为完成时发生错误: {e}")
|
||||
session.rollback()
|
||||
await session.rollback()
|
||||
raise
|
||||
|
||||
|
||||
def delete_plans_by_ids(plan_ids: List[int]):
|
||||
async def delete_plans_by_ids(plan_ids: List[int]):
|
||||
"""
|
||||
根据ID列表从数据库中物理删除月度计划。
|
||||
|
||||
@@ -112,10 +114,11 @@ def delete_plans_by_ids(plan_ids: List[int]):
|
||||
if not plan_ids:
|
||||
return
|
||||
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
try:
|
||||
# 先查询要删除的计划,用于日志记录
|
||||
plans_to_delete = session.query(MonthlyPlan).filter(MonthlyPlan.id.in_(plan_ids)).all()
|
||||
result = await session.execute(select(MonthlyPlan).where(MonthlyPlan.id.in_(plan_ids)))
|
||||
plans_to_delete = result.scalars().all()
|
||||
if not plans_to_delete:
|
||||
logger.info("没有找到需要删除的月度计划。")
|
||||
return
|
||||
@@ -124,16 +127,16 @@ def delete_plans_by_ids(plan_ids: List[int]):
|
||||
logger.info(f"检测到月度计划超额,将删除以下 {len(plans_to_delete)} 条计划:\n{plan_details}")
|
||||
|
||||
# 执行删除
|
||||
session.query(MonthlyPlan).filter(MonthlyPlan.id.in_(plan_ids)).delete(synchronize_session=False)
|
||||
session.commit()
|
||||
await session.execute(delete(MonthlyPlan).where(MonthlyPlan.id.in_(plan_ids)))
|
||||
await session.commit()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"删除月度计划时发生错误: {e}")
|
||||
session.rollback()
|
||||
await session.rollback()
|
||||
raise
|
||||
|
||||
|
||||
def update_plan_usage(plan_ids: List[int], used_date: str):
|
||||
async def update_plan_usage(plan_ids: List[int], used_date: str):
|
||||
"""
|
||||
更新计划的使用统计信息。
|
||||
|
||||
@@ -143,44 +146,47 @@ def update_plan_usage(plan_ids: List[int], used_date: str):
|
||||
if not plan_ids:
|
||||
return
|
||||
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
try:
|
||||
# 获取完成阈值配置,如果不存在则使用默认值
|
||||
completion_threshold = getattr(global_config.planning_system, "completion_threshold", 3)
|
||||
|
||||
# 批量更新使用次数和最后使用日期
|
||||
session.query(MonthlyPlan).filter(MonthlyPlan.id.in_(plan_ids)).update(
|
||||
{"usage_count": MonthlyPlan.usage_count + 1, "last_used_date": used_date}, synchronize_session=False
|
||||
await session.execute(
|
||||
update(MonthlyPlan)
|
||||
.where(MonthlyPlan.id.in_(plan_ids))
|
||||
.values(usage_count=MonthlyPlan.usage_count + 1, last_used_date=used_date)
|
||||
)
|
||||
|
||||
# 检查是否有计划达到完成阈值
|
||||
plans_to_complete = (
|
||||
session.query(MonthlyPlan)
|
||||
.filter(
|
||||
result = await session.execute(
|
||||
select(MonthlyPlan).where(
|
||||
MonthlyPlan.id.in_(plan_ids),
|
||||
MonthlyPlan.usage_count >= completion_threshold,
|
||||
MonthlyPlan.status == "active",
|
||||
)
|
||||
.all()
|
||||
)
|
||||
plans_to_complete = result.scalars().all()
|
||||
|
||||
if plans_to_complete:
|
||||
completed_ids = [plan.id for plan in plans_to_complete]
|
||||
session.query(MonthlyPlan).filter(MonthlyPlan.id.in_(completed_ids)).update(
|
||||
{"status": "completed"}, synchronize_session=False
|
||||
await session.execute(
|
||||
update(MonthlyPlan).where(MonthlyPlan.id.in_(completed_ids)).values(status="completed")
|
||||
)
|
||||
|
||||
logger.info(f"计划 {completed_ids} 已达到使用阈值 ({completion_threshold}),标记为已完成。")
|
||||
|
||||
session.commit()
|
||||
await session.commit()
|
||||
logger.info(f"成功更新了 {len(plan_ids)} 条月度计划的使用统计。")
|
||||
except Exception as e:
|
||||
logger.error(f"更新月度计划使用统计时发生错误: {e}")
|
||||
session.rollback()
|
||||
await session.rollback()
|
||||
raise
|
||||
|
||||
|
||||
def get_smart_plans_for_daily_schedule(month: str, max_count: int = 3, avoid_days: int = 7) -> List[MonthlyPlan]:
|
||||
async def get_smart_plans_for_daily_schedule(
|
||||
month: str, max_count: int = 3, avoid_days: int = 7
|
||||
) -> List[MonthlyPlan]:
|
||||
"""
|
||||
智能抽取月度计划用于每日日程生成。
|
||||
|
||||
@@ -196,19 +202,24 @@ def get_smart_plans_for_daily_schedule(month: str, max_count: int = 3, avoid_day
|
||||
"""
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
try:
|
||||
# 计算避免重复的日期阈值
|
||||
avoid_date = (datetime.now() - timedelta(days=avoid_days)).strftime("%Y-%m-%d")
|
||||
|
||||
# 查询符合条件的计划
|
||||
query = session.query(MonthlyPlan).filter(MonthlyPlan.target_month == month, MonthlyPlan.status == "active")
|
||||
query = select(MonthlyPlan).where(
|
||||
MonthlyPlan.target_month == month, MonthlyPlan.status == "active"
|
||||
)
|
||||
|
||||
# 排除最近使用过的计划
|
||||
query = query.filter((MonthlyPlan.last_used_date.is_(None)) | (MonthlyPlan.last_used_date < avoid_date))
|
||||
query = query.where(
|
||||
(MonthlyPlan.last_used_date.is_(None)) | (MonthlyPlan.last_used_date < avoid_date)
|
||||
)
|
||||
|
||||
# 按使用次数升序排列,优先选择使用次数少的
|
||||
plans = query.order_by(MonthlyPlan.usage_count.asc()).all()
|
||||
result = await session.execute(query.order_by(MonthlyPlan.usage_count.asc()))
|
||||
plans = result.scalars().all()
|
||||
|
||||
if not plans:
|
||||
logger.info(f"没有找到符合条件的 {month} 月度计划。")
|
||||
@@ -228,31 +239,31 @@ def get_smart_plans_for_daily_schedule(month: str, max_count: int = 3, avoid_day
|
||||
return []
|
||||
|
||||
|
||||
def archive_active_plans_for_month(month: str):
|
||||
async def archive_active_plans_for_month(month: str):
|
||||
"""
|
||||
将指定月份所有状态为 'active' 的计划归档为 'archived'。
|
||||
通常在月底调用。
|
||||
|
||||
:param month: 目标月份,格式为 "YYYY-MM"。
|
||||
"""
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
try:
|
||||
updated_count = (
|
||||
session.query(MonthlyPlan)
|
||||
.filter(MonthlyPlan.target_month == month, MonthlyPlan.status == "active")
|
||||
.update({"status": "archived"}, synchronize_session=False)
|
||||
result = await session.execute(
|
||||
update(MonthlyPlan)
|
||||
.where(MonthlyPlan.target_month == month, MonthlyPlan.status == "active")
|
||||
.values(status="archived")
|
||||
)
|
||||
|
||||
session.commit()
|
||||
updated_count = result.rowcount
|
||||
await session.commit()
|
||||
logger.info(f"成功将 {updated_count} 条 {month} 的活跃月度计划归档。")
|
||||
return updated_count
|
||||
except Exception as e:
|
||||
logger.error(f"归档 {month} 的月度计划时发生错误: {e}")
|
||||
session.rollback()
|
||||
await session.rollback()
|
||||
raise
|
||||
|
||||
|
||||
def get_archived_plans_for_month(month: str) -> List[MonthlyPlan]:
|
||||
async def get_archived_plans_for_month(month: str) -> List[MonthlyPlan]:
|
||||
"""
|
||||
获取指定月份所有状态为 'archived' 的计划。
|
||||
用于生成下个月计划时的参考。
|
||||
@@ -260,34 +271,34 @@ def get_archived_plans_for_month(month: str) -> List[MonthlyPlan]:
|
||||
:param month: 目标月份,格式为 "YYYY-MM"。
|
||||
:return: MonthlyPlan 对象列表。
|
||||
"""
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
try:
|
||||
plans = (
|
||||
session.query(MonthlyPlan)
|
||||
.filter(MonthlyPlan.target_month == month, MonthlyPlan.status == "archived")
|
||||
.all()
|
||||
result = await session.execute(
|
||||
select(MonthlyPlan).where(
|
||||
MonthlyPlan.target_month == month, MonthlyPlan.status == "archived"
|
||||
)
|
||||
)
|
||||
return plans
|
||||
return result.scalars().all()
|
||||
except Exception as e:
|
||||
logger.error(f"查询 {month} 的归档月度计划时发生错误: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def has_active_plans(month: str) -> bool:
|
||||
async def has_active_plans(month: str) -> bool:
|
||||
"""
|
||||
检查指定月份是否存在任何状态为 'active' 的计划。
|
||||
|
||||
:param month: 目标月份,格式为 "YYYY-MM"。
|
||||
:return: 如果存在则返回 True,否则返回 False。
|
||||
"""
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
try:
|
||||
count = (
|
||||
session.query(MonthlyPlan)
|
||||
.filter(MonthlyPlan.target_month == month, MonthlyPlan.status == "active")
|
||||
.count()
|
||||
result = await session.execute(
|
||||
select(func.count(MonthlyPlan.id)).where(
|
||||
MonthlyPlan.target_month == month, MonthlyPlan.status == "active"
|
||||
)
|
||||
)
|
||||
return count > 0
|
||||
return result.scalar_one() > 0
|
||||
except Exception as e:
|
||||
logger.error(f"检查 {month} 的有效月度计划时发生错误: {e}")
|
||||
return False
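# 调用方示意(非本提交内容,仅作说明;candidate_plans 的内容为假设):
# 本模块的数据库函数全部改为协程后,需要在异步上下文中依次 await 调用。
async def demo_refresh_month_plans(month: str, candidate_plans: List[str]) -> List[MonthlyPlan]:
    # 若该月份没有有效计划,则先写入候选计划,再取回当前有效计划列表
    if not await has_active_plans(month):
        await add_new_plans(candidate_plans, month)
    return await get_active_plans_for_month(month)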
|
||||
@@ -14,6 +14,11 @@ class MonthlyPlanManager:
|
||||
self.plan_manager = PlanManager()
|
||||
self.monthly_task_started = False
|
||||
|
||||
async def initialize(self):
|
||||
logger.info("正在初始化月度计划管理器...")
|
||||
await self.start_monthly_plan_generation()
|
||||
logger.info("月度计划管理器初始化成功")
|
||||
|
||||
async def start_monthly_plan_generation(self):
|
||||
if not self.monthly_task_started:
|
||||
logger.info(" 正在启动每月月度计划生成任务...")
|
||||
|
||||
@@ -28,20 +28,20 @@ class PlanManager:
|
||||
if target_month is None:
|
||||
target_month = datetime.now().strftime("%Y-%m")
|
||||
|
||||
if not has_active_plans(target_month):
|
||||
if not await has_active_plans(target_month):
|
||||
logger.info(f" {target_month} 没有任何有效的月度计划,将触发同步生成。")
|
||||
generation_successful = await self._generate_monthly_plans_logic(target_month)
|
||||
return generation_successful
|
||||
else:
|
||||
logger.info(f"{target_month} 已存在有效的月度计划。")
|
||||
plans = get_active_plans_for_month(target_month)
|
||||
plans = await get_active_plans_for_month(target_month)
|
||||
max_plans = global_config.planning_system.max_plans_per_month
|
||||
if len(plans) > max_plans:
|
||||
logger.warning(f"当前月度计划数量 ({len(plans)}) 超出上限 ({max_plans}),将自动删除多余的计划。")
|
||||
plans_to_delete = plans[: len(plans) - max_plans]
|
||||
delete_ids = [p.id for p in plans_to_delete]
|
||||
delete_plans_by_ids(delete_ids) # type: ignore
|
||||
plans = get_active_plans_for_month(target_month)
|
||||
await delete_plans_by_ids(delete_ids) # type: ignore
|
||||
plans = await get_active_plans_for_month(target_month)
|
||||
|
||||
if plans:
|
||||
plan_texts = "\n".join([f" {i + 1}. {plan.plan_text}" for i, plan in enumerate(plans)])
|
||||
@@ -64,11 +64,11 @@ class PlanManager:
|
||||
return False
|
||||
|
||||
last_month = self._get_previous_month(target_month)
|
||||
archived_plans = get_archived_plans_for_month(last_month)
|
||||
archived_plans = await get_archived_plans_for_month(last_month)
|
||||
plans = await self.llm_generator.generate_plans_with_llm(target_month, archived_plans)
|
||||
|
||||
if plans:
|
||||
add_new_plans(plans, target_month)
|
||||
await add_new_plans(plans, target_month)
|
||||
logger.info(f"成功为 {target_month} 生成并保存了 {len(plans)} 条月度计划。")
|
||||
return True
|
||||
else:
|
||||
@@ -95,11 +95,11 @@ class PlanManager:
|
||||
if target_month is None:
|
||||
target_month = datetime.now().strftime("%Y-%m")
|
||||
logger.info(f" 开始归档 {target_month} 的活跃月度计划...")
|
||||
archived_count = archive_active_plans_for_month(target_month)
|
||||
archived_count = await archive_active_plans_for_month(target_month)
|
||||
logger.info(f" 成功归档了 {archived_count} 条 {target_month} 的月度计划。")
|
||||
except Exception as e:
|
||||
logger.error(f" 归档 {target_month} 月度计划时发生错误: {e}")
|
||||
|
||||
def get_plans_for_schedule(self, month: str, max_count: int) -> List:
|
||||
async def get_plans_for_schedule(self, month: str, max_count: int) -> List:
|
||||
avoid_days = global_config.planning_system.avoid_repetition_days
|
||||
return get_smart_plans_for_daily_schedule(month, max_count=max_count, avoid_days=avoid_days)
|
||||
return await get_smart_plans_for_daily_schedule(month, max_count=max_count, avoid_days=avoid_days)
|
||||
@@ -23,6 +23,13 @@ class ScheduleManager:
|
||||
self.daily_task_started = False
|
||||
self.schedule_generation_running = False
|
||||
|
||||
async def initialize(self):
|
||||
if global_config.planning_system.schedule_enable:
|
||||
logger.info("日程表功能已启用,正在初始化管理器...")
|
||||
await self.load_or_generate_today_schedule()
|
||||
await self.start_daily_schedule_generation()
|
||||
logger.info("日程表管理器初始化成功。")
|
||||
|
||||
async def start_daily_schedule_generation(self):
|
||||
if not self.daily_task_started:
|
||||
logger.info("正在启动每日日程生成任务...")
|
||||
@@ -40,7 +47,7 @@ class ScheduleManager:
|
||||
|
||||
today_str = datetime.now().strftime("%Y-%m-%d")
|
||||
try:
|
||||
schedule_data = self._load_schedule_from_db(today_str)
|
||||
schedule_data = await self._load_schedule_from_db(today_str)
|
||||
if schedule_data:
|
||||
self.today_schedule = schedule_data
|
||||
self._log_loaded_schedule(today_str)
|
||||
@@ -54,9 +61,10 @@ class ScheduleManager:
|
||||
logger.info("尝试生成日程作为备用方案...")
|
||||
await self.generate_and_save_schedule()
|
||||
|
||||
def _load_schedule_from_db(self, date_str: str) -> Optional[List[Dict[str, Any]]]:
|
||||
with get_db_session() as session:
|
||||
schedule_record = session.query(Schedule).filter(Schedule.date == date_str).first()
|
||||
async def _load_schedule_from_db(self, date_str: str) -> Optional[List[Dict[str, Any]]]:
|
||||
async with get_db_session() as session:
|
||||
result = await session.execute(select(Schedule).filter(Schedule.date == date_str))
|
||||
schedule_record = result.scalars().first()
|
||||
if schedule_record:
|
||||
logger.info(f"从数据库加载今天的日程 ({date_str})。")
|
||||
schedule_data = orjson.loads(str(schedule_record.schedule_data))
|
||||
@@ -90,35 +98,35 @@ class ScheduleManager:
|
||||
sampled_plans = []
|
||||
if global_config.planning_system.monthly_plan_enable:
|
||||
await self.plan_manager.ensure_and_generate_plans_if_needed(current_month_str)
|
||||
sampled_plans = self.plan_manager.get_plans_for_schedule(current_month_str, max_count=3)
|
||||
sampled_plans = await self.plan_manager.get_plans_for_schedule(current_month_str, max_count=3)
|
||||
|
||||
schedule_data = await self.llm_generator.generate_schedule_with_llm(sampled_plans)
|
||||
|
||||
if schedule_data:
|
||||
self._save_schedule_to_db(today_str, schedule_data)
|
||||
await self._save_schedule_to_db(today_str, schedule_data)
|
||||
self.today_schedule = schedule_data
|
||||
self._log_generated_schedule(today_str, schedule_data)
|
||||
|
||||
if sampled_plans:
|
||||
used_plan_ids = [plan.id for plan in sampled_plans]
|
||||
logger.info(f"更新使用过的月度计划 {used_plan_ids} 的统计信息。")
|
||||
update_plan_usage(used_plan_ids, today_str)
|
||||
await update_plan_usage(used_plan_ids, today_str)
|
||||
finally:
|
||||
self.schedule_generation_running = False
|
||||
logger.info("日程生成任务结束")
|
||||
|
||||
def _save_schedule_to_db(self, date_str: str, schedule_data: List[Dict[str, Any]]):
|
||||
with get_db_session() as session:
|
||||
async def _save_schedule_to_db(self, date_str: str, schedule_data: List[Dict[str, Any]]):
|
||||
async with get_db_session() as session:
|
||||
schedule_json = orjson.dumps(schedule_data).decode("utf-8")
|
||||
existing_schedule = session.query(Schedule).filter(Schedule.date == date_str).first()
|
||||
result = await session.execute(select(Schedule).filter(Schedule.date == date_str))
|
||||
existing_schedule = result.scalars().first()
|
||||
if existing_schedule:
|
||||
session.query(Schedule).filter(Schedule.date == date_str).update(
|
||||
{Schedule.schedule_data: schedule_json, Schedule.updated_at: datetime.now()}
|
||||
)
|
||||
existing_schedule.schedule_data = schedule_json
|
||||
existing_schedule.updated_at = datetime.now()
|
||||
else:
|
||||
new_schedule = Schedule(date=date_str, schedule_data=schedule_json)
|
||||
session.add(new_schedule)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
def _log_generated_schedule(self, date_str: str, schedule_data: List[Dict[str, Any]]):
|
||||
schedule_str = f"✅ 成功生成并保存今天的日程 ({date_str}):\n"
Block a user