三次修改

This commit is contained in:
tt-P607
2025-09-20 02:21:53 +08:00
parent 6a98ae6208
commit 0cc4f5bb27
20 changed files with 478 additions and 481 deletions

View File

@@ -361,7 +361,7 @@ class HeartFChatting:
# 核心修复:在睡眠模式(包括失眠)下获取消息时,不过滤命令消息,以确保@消息能被接收
filter_command_flag = not (is_sleeping or is_in_insomnia)
recent_messages = message_api.get_messages_by_time_in_chat(
recent_messages = await message_api.get_messages_by_time_in_chat(
chat_id=self.context.stream_id,
start_time=self.context.last_read_time,
end_time=time.time(),

View File

@@ -149,7 +149,7 @@ class MaiEmoji:
# --- 数据库操作 ---
try:
# 准备数据库记录 for emoji collection
with get_db_session() as session:
async with get_db_session() as session:
emotion_str = ",".join(self.emotion) if self.emotion else ""
emoji = Emoji(
@@ -167,7 +167,7 @@ class MaiEmoji:
last_used_time=self.last_used_time,
)
session.add(emoji)
session.commit()
await session.commit()
logger.info(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})")
@@ -203,17 +203,17 @@ class MaiEmoji:
# 2. 删除数据库记录
try:
with get_db_session() as session:
will_delete_emoji = session.execute(
select(Emoji).where(Emoji.emoji_hash == self.hash)
async with get_db_session() as session:
will_delete_emoji = (
await session.execute(select(Emoji).where(Emoji.emoji_hash == self.hash))
).scalar_one_or_none()
if will_delete_emoji is None:
logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。")
result = 0 # Indicate no DB record was deleted
result = 0
else:
session.delete(will_delete_emoji)
result = 1 # Successfully deleted one record
session.commit()
await session.delete(will_delete_emoji)
result = 1
await session.commit()
except Exception as e:
logger.error(f"[错误] 删除数据库记录时出错: {str(e)}")
result = 0
@@ -424,17 +424,19 @@ class EmojiManager:
# if not self._initialized:
# raise RuntimeError("EmojiManager not initialized")
def record_usage(self, emoji_hash: str) -> None:
async def record_usage(self, emoji_hash: str) -> None:
"""记录表情使用次数"""
try:
with get_db_session() as session:
emoji_update = session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalar_one_or_none()
async with get_db_session() as session:
emoji_update = (
await session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash))
).scalar_one_or_none()
if emoji_update is None:
logger.error(f"记录表情使用失败: 未找到 hash 为 {emoji_hash} 的表情包")
else:
emoji_update.usage_count += 1
emoji_update.last_used_time = time.time() # Update last used time
session.commit()
emoji_update.last_used_time = time.time()
await session.commit()
except Exception as e:
logger.error(f"记录表情使用失败: {str(e)}")
@@ -658,10 +660,11 @@ class EmojiManager:
async def get_all_emoji_from_db(self) -> None:
"""获取所有表情包并初始化为MaiEmoji类对象更新 self.emoji_objects"""
try:
with get_db_session() as session:
async with get_db_session() as session:
logger.debug("[数据库] 开始加载所有表情包记录 ...")
emoji_instances = session.execute(select(Emoji)).scalars().all()
result = await session.execute(select(Emoji))
emoji_instances = result.scalars().all()
emoji_objects, load_errors = _to_emoji_objects(emoji_instances)
# 更新内存中的列表和数量
@@ -687,14 +690,16 @@ class EmojiManager:
list[MaiEmoji]: 表情包对象列表
"""
try:
with get_db_session() as session:
async with get_db_session() as session:
if emoji_hash:
query = session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalars().all()
result = await session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash))
query = result.scalars().all()
else:
logger.warning(
"[查询] 未提供 hash将尝试加载所有表情包建议使用 get_all_emoji_from_db 更新管理器状态。"
)
query = session.execute(select(Emoji)).scalars().all()
result = await session.execute(select(Emoji))
query = result.scalars().all()
emoji_instances = query
emoji_objects, load_errors = _to_emoji_objects(emoji_instances)
@@ -771,10 +776,11 @@ class EmojiManager:
# 如果内存中没有,从数据库查找
try:
with get_db_session() as session:
emoji_record = session.execute(
async with get_db_session() as session:
result = await session.execute(
select(Emoji).where(Emoji.emoji_hash == emoji_hash)
).scalar_one_or_none()
)
emoji_record = result.scalar_one_or_none()
if emoji_record and emoji_record.description:
logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.description[:50]}...")
return emoji_record.description
@@ -937,10 +943,13 @@ class EmojiManager:
# 2. 检查数据库中是否已存在该表情包的描述,实现复用
existing_description = None
try:
with get_db_session() as session:
existing_image = session.query(Images).filter(
(Images.emoji_hash == image_hash) & (Images.type == "emoji")
).one_or_none()
async with get_db_session() as session:
result = await session.execute(
select(Images).filter(
(Images.emoji_hash == image_hash) & (Images.type == "emoji")
)
)
existing_image = result.scalar_one_or_none()
if existing_image and existing_image.description:
existing_description = existing_image.description
logger.info(f"[复用描述] 找到已有详细描述: {existing_description[:50]}...")

View File

@@ -136,18 +136,18 @@ class ExpressionSelector:
return related_chat_ids if related_chat_ids else [chat_id]
def get_random_expressions(
async def get_random_expressions(
self, chat_id: str, total_num: int, style_percentage: float, grammar_percentage: float
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
# sourcery skip: extract-duplicate-method, move-assign
# 支持多chat_id合并抽选
related_chat_ids = self.get_related_chat_ids(chat_id)
with get_db_session() as session:
async with get_db_session() as session:
# 优化一次性查询所有相关chat_id的表达方式
style_query = session.execute(
style_query = await session.execute(
select(Expression).where((Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style"))
)
grammar_query = session.execute(
grammar_query = await session.execute(
select(Expression).where((Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "grammar"))
)
@@ -193,7 +193,7 @@ class ExpressionSelector:
return selected_style, selected_grammar
def update_expressions_count_batch(self, expressions_to_update: List[Dict[str, Any]], increment: float = 0.1):
async def update_expressions_count_batch(self, expressions_to_update: List[Dict[str, Any]], increment: float = 0.1):
"""对一批表达方式更新count值按chat_id+type分组后一次性写入数据库"""
if not expressions_to_update:
return
@@ -210,26 +210,27 @@ class ExpressionSelector:
if key not in updates_by_key:
updates_by_key[key] = expr
for chat_id, expr_type, situation, style in updates_by_key:
with get_db_session() as session:
query = session.execute(
async with get_db_session() as session:
query = await session.execute(
select(Expression).where(
(Expression.chat_id == chat_id)
& (Expression.type == expr_type)
& (Expression.situation == situation)
& (Expression.style == style)
)
).scalar()
if query:
expr_obj = query
current_count = expr_obj.count
new_count = min(current_count + increment, 5.0)
expr_obj.count = new_count
expr_obj.last_active_time = time.time()
logger.debug(
f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in db"
)
session.commit()
query = query.scalar()
if query:
expr_obj = query
current_count = expr_obj.count
new_count = min(current_count + increment, 5.0)
expr_obj.count = new_count
expr_obj.last_active_time = time.time()
logger.debug(
f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in db"
)
await session.commit()
async def select_suitable_expressions_llm(
self,
@@ -248,7 +249,7 @@ class ExpressionSelector:
return []
# 1. 获取35个随机表达方式现在按权重抽取
style_exprs, grammar_exprs = self.get_random_expressions(chat_id, 30, 0.5, 0.5)
style_exprs, grammar_exprs = await self.get_random_expressions(chat_id, 30, 0.5, 0.5)
# 2. 构建所有表达方式的索引和情境列表
all_expressions = []
@@ -334,7 +335,7 @@ class ExpressionSelector:
# 对选中的所有表达方式一次性更新count数
if valid_expressions:
self.update_expressions_count_batch(valid_expressions, 0.006)
await self.update_expressions_count_batch(valid_expressions, 0.006)
# logger.info(f"LLM从{len(all_expressions)}个情境中选择了{len(valid_expressions)}个")
return valid_expressions

View File

@@ -201,7 +201,7 @@ class Hippocampus:
self.entorhinal_cortex = EntorhinalCortex(self)
self.parahippocampal_gyrus = ParahippocampalGyrus(self)
# 从数据库加载记忆图
self.entorhinal_cortex.sync_memory_from_db()
# self.entorhinal_cortex.sync_memory_from_db() # 改为异步启动
self.model_small = LLMRequest(model_set=model_config.model_task_config.utils_small, request_type="memory.small")
def get_all_node_names(self) -> list:
@@ -789,7 +789,7 @@ class EntorhinalCortex:
self.hippocampus = hippocampus
self.memory_graph = hippocampus.memory_graph
def get_memory_sample(self):
async def get_memory_sample(self):
"""从数据库获取记忆样本"""
# 硬编码:每条消息最大记忆次数
max_memorized_time_per_msg = 2
@@ -812,7 +812,7 @@ class EntorhinalCortex:
logger.debug(f"回忆往事: {readable_timestamp}")
chat_samples = []
for timestamp in timestamps:
if messages := self.random_get_msg_snippet(
if messages := await self.random_get_msg_snippet(
timestamp,
global_config.memory.memory_build_sample_length,
max_memorized_time_per_msg,
@@ -826,7 +826,9 @@ class EntorhinalCortex:
return chat_samples
@staticmethod
def random_get_msg_snippet(target_timestamp: float, chat_size: int, max_memorized_time_per_msg: int) -> list | None:
async def random_get_msg_snippet(
target_timestamp: float, chat_size: int, max_memorized_time_per_msg: int
) -> list | None:
# sourcery skip: invert-any-all, use-any, use-named-expression, use-next
"""从数据库中随机获取指定时间戳附近的消息片段 (使用 chat_message_builder)"""
time_window_seconds = random.randint(300, 1800) # 随机时间窗口5到30分钟
@@ -864,13 +866,13 @@ class EntorhinalCortex:
for message in messages:
# 确保在更新前获取最新的 memorized_times
current_memorized_times = message.get("memorized_times", 0)
with get_db_session() as session:
session.execute(
async with get_db_session() as session:
await session.execute(
update(Messages)
.where(Messages.message_id == message["message_id"])
.values(memorized_times=current_memorized_times + 1)
)
session.commit()
await session.commit()
return messages # 直接返回原始的消息列表
target_timestamp -= 120 # 如果第一次尝试失败,稍微向前调整时间戳再试
@@ -884,8 +886,8 @@ class EntorhinalCortex:
current_time = datetime.datetime.now().timestamp()
# 获取数据库中所有节点和内存中所有节点
with get_db_session() as session:
db_nodes = {node.concept: node for node in session.execute(select(GraphNodes)).scalars()}
async with get_db_session() as session:
db_nodes = {node.concept: node for node in (await session.execute(select(GraphNodes))).scalars()}
memory_nodes = list(self.memory_graph.G.nodes(data=True))
# 批量准备节点数据
@@ -954,24 +956,24 @@ class EntorhinalCortex:
batch_size = 100
for i in range(0, len(nodes_to_create), batch_size):
batch = nodes_to_create[i : i + batch_size]
session.execute(insert(GraphNodes), batch)
await session.execute(insert(GraphNodes), batch)
if nodes_to_update:
batch_size = 100
for i in range(0, len(nodes_to_update), batch_size):
batch = nodes_to_update[i : i + batch_size]
for node_data in batch:
session.execute(
await session.execute(
update(GraphNodes)
.where(GraphNodes.concept == node_data["concept"])
.values(**{k: v for k, v in node_data.items() if k != "concept"})
)
if nodes_to_delete:
session.execute(delete(GraphNodes).where(GraphNodes.concept.in_(nodes_to_delete)))
await session.execute(delete(GraphNodes).where(GraphNodes.concept.in_(nodes_to_delete)))
# 处理边的信息
db_edges = list(session.execute(select(GraphEdges)).scalars())
db_edges = list((await session.execute(select(GraphEdges))).scalars())
memory_edges = list(self.memory_graph.G.edges(data=True))
# 创建边的哈希值字典
@@ -1023,14 +1025,14 @@ class EntorhinalCortex:
batch_size = 100
for i in range(0, len(edges_to_create), batch_size):
batch = edges_to_create[i : i + batch_size]
session.execute(insert(GraphEdges), batch)
await session.execute(insert(GraphEdges), batch)
if edges_to_update:
batch_size = 100
for i in range(0, len(edges_to_update), batch_size):
batch = edges_to_update[i : i + batch_size]
for edge_data in batch:
session.execute(
await session.execute(
update(GraphEdges)
.where(
(GraphEdges.source == edge_data["source"]) & (GraphEdges.target == edge_data["target"])
@@ -1040,12 +1042,12 @@ class EntorhinalCortex:
if edges_to_delete:
for source, target in edges_to_delete:
session.execute(
await session.execute(
delete(GraphEdges).where((GraphEdges.source == source) & (GraphEdges.target == target))
)
# 提交事务
session.commit()
await session.commit()
end_time = time.time()
logger.info(f"[同步] 总耗时: {end_time - start_time:.2f}")
@@ -1057,10 +1059,10 @@ class EntorhinalCortex:
logger.info("[数据库] 开始重新同步所有记忆数据...")
# 清空数据库
with get_db_session() as session:
async with get_db_session() as session:
clear_start = time.time()
session.execute(delete(GraphNodes))
session.execute(delete(GraphEdges))
await session.execute(delete(GraphNodes))
await session.execute(delete(GraphEdges))
clear_end = time.time()
logger.info(f"[数据库] 清空数据库耗时: {clear_end - clear_start:.2f}")
@@ -1119,7 +1121,7 @@ class EntorhinalCortex:
batch_size = 500 # 增加批量大小
for i in range(0, len(nodes_data), batch_size):
batch = nodes_data[i : i + batch_size]
session.execute(insert(GraphNodes), batch)
await session.execute(insert(GraphNodes), batch)
node_end = time.time()
logger.info(f"[数据库] 写入 {len(nodes_data)} 个节点耗时: {node_end - node_start:.2f}")
@@ -1130,8 +1132,8 @@ class EntorhinalCortex:
batch_size = 500 # 增加批量大小
for i in range(0, len(edges_data), batch_size):
batch = edges_data[i : i + batch_size]
session.execute(insert(GraphEdges), batch)
session.commit()
await session.execute(insert(GraphEdges), batch)
await session.commit()
edge_end = time.time()
logger.info(f"[数据库] 写入 {len(edges_data)} 条边耗时: {edge_end - edge_start:.2f}")
@@ -1140,7 +1142,7 @@ class EntorhinalCortex:
logger.info(f"[数据库] 重新同步完成,总耗时: {end_time - start_time:.2f}")
logger.info(f"[数据库] 同步了 {len(nodes_data)} 个节点和 {len(edges_data)} 条边")
def sync_memory_from_db(self):
async def sync_memory_from_db(self):
"""从数据库同步数据到内存中的图结构"""
current_time = datetime.datetime.now().timestamp()
need_update = False
@@ -1149,8 +1151,8 @@ class EntorhinalCortex:
self.memory_graph.G.clear()
# 从数据库加载所有节点
with get_db_session() as session:
nodes = list(session.execute(select(GraphNodes)).scalars())
async with get_db_session() as session:
nodes = list((await session.execute(select(GraphNodes))).scalars())
for node in nodes:
concept = node.concept
try:
@@ -1168,7 +1170,9 @@ class EntorhinalCortex:
if not node.last_modified:
update_data["last_modified"] = current_time
session.execute(update(GraphNodes).where(GraphNodes.concept == concept).values(**update_data))
await session.execute(
update(GraphNodes).where(GraphNodes.concept == concept).values(**update_data)
)
# 获取时间信息(如果不存在则使用当前时间)
created_time = node.created_time or current_time
@@ -1183,7 +1187,7 @@ class EntorhinalCortex:
continue
# 从数据库加载所有边
edges = list(session.execute(select(GraphEdges)).scalars())
edges = list((await session.execute(select(GraphEdges))).scalars())
for edge in edges:
source = edge.source
target = edge.target
@@ -1199,7 +1203,7 @@ class EntorhinalCortex:
if not edge.last_modified:
update_data["last_modified"] = current_time
session.execute(
await session.execute(
update(GraphEdges)
.where((GraphEdges.source == source) & (GraphEdges.target == target))
.values(**update_data)
@@ -1214,7 +1218,7 @@ class EntorhinalCortex:
self.memory_graph.G.add_edge(
source, target, strength=strength, created_time=created_time, last_modified=last_modified
)
session.commit()
await session.commit()
if need_update:
logger.info("[数据库] 已为缺失的时间字段进行补充")
@@ -1254,7 +1258,7 @@ class ParahippocampalGyrus:
# 1. 使用 build_readable_messages 生成格式化文本
# build_readable_messages 只返回一个字符串,不需要解包
input_text = build_readable_messages(
input_text = await build_readable_messages(
messages,
merge_messages=True, # 合并连续消息
timestamp_mode="normal_no_YMD", # 使用 'YYYY-MM-DD HH:MM:SS' 格式
@@ -1342,7 +1346,7 @@ class ParahippocampalGyrus:
# sourcery skip: merge-list-appends-into-extend
logger.info("------------------------------------开始构建记忆--------------------------------------")
start_time = time.time()
memory_samples = self.hippocampus.entorhinal_cortex.get_memory_sample()
memory_samples = await self.hippocampus.entorhinal_cortex.get_memory_sample()
all_added_nodes = []
all_connected_nodes = []
all_added_edges = []
@@ -1620,7 +1624,7 @@ class HippocampusManager:
return self._hippocampus
self._hippocampus = Hippocampus()
self._hippocampus.initialize()
# self._hippocampus.initialize() # 改为异步启动
self._initialized = True
# 输出记忆图统计信息
@@ -1639,6 +1643,13 @@ class HippocampusManager:
return self._hippocampus
async def initialize_async(self):
"""异步初始化海马体实例"""
if not self._initialized:
self.initialize() # 先进行同步部分的初始化
self._hippocampus.initialize()
await self._hippocampus.entorhinal_cortex.sync_memory_from_db()
def get_hippocampus(self):
if not self._initialized:
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")

View File

@@ -246,11 +246,11 @@ class ChatManager:
return stream
# 检查数据库中是否存在
def _db_find_stream_sync(s_id: str):
with get_db_session() as session:
return session.execute(select(ChatStreams).where(ChatStreams.stream_id == s_id)).scalar()
async def _db_find_stream_async(s_id: str):
async with get_db_session() as session:
return (await session.execute(select(ChatStreams).where(ChatStreams.stream_id == s_id))).scalar()
model_instance = await asyncio.to_thread(_db_find_stream_sync, stream_id)
model_instance = await _db_find_stream_async(stream_id)
if model_instance:
# 从 Peewee 模型转换回 ChatStream.from_dict 期望的格式
@@ -344,11 +344,10 @@ class ChatManager:
return
stream_data_dict = stream.to_dict()
def _db_save_stream_sync(s_data_dict: dict):
with get_db_session() as session:
async def _db_save_stream_async(s_data_dict: dict):
async with get_db_session() as session:
user_info_d = s_data_dict.get("user_info")
group_info_d = s_data_dict.get("group_info")
fields_to_save = {
"platform": s_data_dict["platform"],
"create_time": s_data_dict["create_time"],
@@ -364,8 +363,6 @@ class ChatManager:
"sleep_pressure": s_data_dict.get("sleep_pressure", 0.0),
"focus_energy": s_data_dict.get("focus_energy", global_config.chat.focus_value),
}
# 根据数据库类型选择插入语句
if global_config.database.database_type == "sqlite":
stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=fields_to_save)
@@ -375,15 +372,13 @@ class ChatManager:
**{key: value for key, value in fields_to_save.items() if key != "stream_id"}
)
else:
# 默认使用通用插入尝试SQLite语法
stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=fields_to_save)
session.execute(stmt)
session.commit()
await session.execute(stmt)
await session.commit()
try:
await asyncio.to_thread(_db_save_stream_sync, stream_data_dict)
await _db_save_stream_async(stream_data_dict)
stream.saved = True
except Exception as e:
logger.error(f"保存聊天流 {stream.stream_id} 到数据库失败 (Peewee): {e}", exc_info=True)
@@ -397,10 +392,10 @@ class ChatManager:
"""从数据库加载所有聊天流"""
logger.info("正在从数据库加载所有聊天流")
def _db_load_all_streams_sync():
async def _db_load_all_streams_async():
loaded_streams_data = []
with get_db_session() as session:
for model_instance in session.execute(select(ChatStreams)).scalars():
async with get_db_session() as session:
for model_instance in (await session.execute(select(ChatStreams))).scalars():
user_info_data = {
"platform": model_instance.user_platform,
"user_id": model_instance.user_id,
@@ -414,7 +409,6 @@ class ChatManager:
"group_id": model_instance.group_id,
"group_name": model_instance.group_name,
}
data_for_from_dict = {
"stream_id": model_instance.stream_id,
"platform": model_instance.platform,
@@ -427,11 +421,11 @@ class ChatManager:
"focus_energy": getattr(model_instance, "focus_energy", global_config.chat.focus_value),
}
loaded_streams_data.append(data_for_from_dict)
session.commit()
await session.commit()
return loaded_streams_data
try:
all_streams_data_list = await asyncio.to_thread(_db_load_all_streams_sync)
all_streams_data_list = await _db_load_all_streams_async()
self.streams.clear()
for data in all_streams_data_list:
stream = ChatStream.from_dict(data)

View File

@@ -41,7 +41,7 @@ class MessageStorage:
processed_plain_text = message.processed_plain_text
if processed_plain_text:
processed_plain_text = MessageStorage.replace_image_descriptions(processed_plain_text)
processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text)
filtered_processed_plain_text = re.sub(pattern, "", processed_plain_text, flags=re.DOTALL)
else:
filtered_processed_plain_text = ""
@@ -128,9 +128,9 @@ class MessageStorage:
key_words=key_words,
key_words_lite=key_words_lite,
)
with get_db_session() as session:
async with get_db_session() as session:
session.add(new_message)
session.commit()
await session.commit()
except Exception:
logger.exception("存储消息失败")
@@ -173,16 +173,18 @@ class MessageStorage:
# 使用上下文管理器确保session正确管理
from src.common.database.sqlalchemy_models import get_db_session
with get_db_session() as session:
matched_message = session.execute(
select(Messages).where(Messages.message_id == mmc_message_id).order_by(desc(Messages.time))
async with get_db_session() as session:
matched_message = (
await session.execute(
select(Messages).where(Messages.message_id == mmc_message_id).order_by(desc(Messages.time))
)
).scalar()
if matched_message:
session.execute(
await session.execute(
update(Messages).where(Messages.id == matched_message.id).values(message_id=qq_message_id)
)
session.commit()
await session.commit()
# 会在上下文管理器中自动调用
logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}")
else:
@@ -196,28 +198,36 @@ class MessageStorage:
)
@staticmethod
def replace_image_descriptions(text: str) -> str:
async def replace_image_descriptions(text: str) -> str:
"""将[图片:描述]替换为[picid:image_id]"""
# 先检查文本中是否有图片标记
pattern = r"\[图片:([^\]]+)\]"
matches = re.findall(pattern, text)
matches = list(re.finditer(pattern, text))
if not matches:
logger.debug("文本中没有图片标记,直接返回原文本")
return text
def replace_match(match):
new_text = ""
last_end = 0
for match in matches:
new_text += text[last_end : match.start()]
description = match.group(1).strip()
try:
from src.common.database.sqlalchemy_models import get_db_session
with get_db_session() as session:
image_record = session.execute(
select(Images).where(Images.description == description).order_by(desc(Images.timestamp))
async with get_db_session() as session:
image_record = (
await session.execute(
select(Images).where(Images.description == description).order_by(desc(Images.timestamp))
)
).scalar()
session.commit()
return f"[picid:{image_record.image_id}]" if image_record else match.group(0)
if image_record:
new_text += f"[picid:{image_record.image_id}]"
else:
new_text += match.group(0)
except Exception:
return match.group(0)
return re.sub(r"\[图片:([^\]]+)\]", replace_match, text)
new_text += match.group(0)
last_end = match.end()
new_text += text[last_end:]
return new_text

View File

@@ -97,12 +97,12 @@ class ActionModifier:
for action_name, reason in chat_type_removals:
logger.debug(f"{self.log_prefix} - 移除 {action_name}: {reason}")
message_list_before_now_half = get_raw_msg_before_timestamp_with_chat(
message_list_before_now_half = await get_raw_msg_before_timestamp_with_chat(
chat_id=self.chat_stream.stream_id,
timestamp=time.time(),
limit=min(int(global_config.chat.max_context_size * 0.33), 10),
)
chat_content = build_readable_messages(
chat_content = await build_readable_messages(
message_list_before_now_half,
replace_bot_name=True,
merge_messages=False,

View File

@@ -152,7 +152,7 @@ class PlanFilter:
)
return prompt, message_id_list
chat_content_block, message_id_list = build_readable_messages_with_id(
chat_content_block, message_id_list = await build_readable_messages_with_id(
messages=[msg.flatten() for msg in plan.chat_history],
timestamp_mode="normal",
read_mark=self.last_obs_time_mark,
@@ -167,7 +167,7 @@ class PlanFilter:
limit=5,
)
actions_before_now_block = build_readable_actions(actions=actions_before_now)
actions_before_now_block = build_readable_actions(actions=await actions_before_now)
actions_before_now_block = f"你刚刚选择并执行过的action是\n{actions_before_now_block}"
self.last_obs_time_mark = time.time()

View File

@@ -63,7 +63,7 @@ class PlanGenerator:
timestamp=time.time(),
limit=int(global_config.chat.max_context_size),
)
chat_history = [DatabaseMessages(**msg) for msg in chat_history_raw]
chat_history = [DatabaseMessages(**msg) for msg in await chat_history_raw]
plan = Plan(

View File

@@ -828,7 +828,8 @@ class DefaultReplyer:
platform, # type: ignore
reply_message.get("user_id"), # type: ignore
)
person_name = await person_info_manager.get_value(person_id, "person_name")
person_info = await person_info_manager.get_values(person_id, ["person_name", "user_id"])
person_name = person_info.get("person_name")
# 如果person_name为None使用fallback值
if person_name is None:
@@ -839,7 +840,7 @@ class DefaultReplyer:
# 检查是否是bot自己的名字如果是则替换为"(你)"
bot_user_id = str(global_config.bot.qq_account)
current_user_id = person_info_manager.get_value_sync(person_id, "user_id")
current_user_id = person_info.get("user_id")
current_platform = reply_message.get("chat_info_platform")
if current_user_id == bot_user_id and current_platform == global_config.bot.platform:
@@ -872,18 +873,18 @@ class DefaultReplyer:
action_descriptions += f"- {action_name}: {action_description}\n"
action_descriptions += "\n"
message_list_before_now_long = get_raw_msg_before_timestamp_with_chat(
message_list_before_now_long = await get_raw_msg_before_timestamp_with_chat(
chat_id=chat_id,
timestamp=time.time(),
limit=global_config.chat.max_context_size * 2,
)
message_list_before_short = get_raw_msg_before_timestamp_with_chat(
message_list_before_short = await get_raw_msg_before_timestamp_with_chat(
chat_id=chat_id,
timestamp=time.time(),
limit=int(global_config.chat.max_context_size * 0.33),
)
chat_talking_prompt_short = build_readable_messages(
chat_talking_prompt_short = await build_readable_messages(
message_list_before_short,
replace_bot_name=True,
merge_messages=False,
@@ -895,7 +896,7 @@ class DefaultReplyer:
# 获取目标用户信息用于s4u模式
target_user_info = None
if sender:
target_user_info = await person_info_manager.get_person_info_by_name(sender)
target_user_info = person_info_manager.get_person_info_by_name(sender)
from src.chat.utils.prompt import Prompt
# 并行执行六个构建任务
@@ -1122,12 +1123,12 @@ class DefaultReplyer:
else:
mood_prompt = ""
message_list_before_now_half = get_raw_msg_before_timestamp_with_chat(
message_list_before_now_half = await get_raw_msg_before_timestamp_with_chat(
chat_id=chat_id,
timestamp=time.time(),
limit=min(int(global_config.chat.max_context_size * 0.33), 15),
)
chat_talking_prompt_half = build_readable_messages(
chat_talking_prompt_half = await build_readable_messages(
message_list_before_now_half,
replace_bot_name=True,
merge_messages=False,

View File

@@ -121,7 +121,8 @@ async def replace_user_references_async(
if replace_bot_name and user_id == global_config.bot.qq_account:
return f"{global_config.bot.nickname}(你)"
person_id = PersonInfoManager.get_person_id(platform, user_id)
return await person_info_manager.get_value(person_id, "person_name") or user_id # type: ignore
person_info = await person_info_manager.get_values(person_id, ["person_name"])
return person_info.get("person_name") or user_id
name_resolver = default_resolver
@@ -169,7 +170,7 @@ async def replace_user_references_async(
return content
def get_raw_msg_by_timestamp(
async def get_raw_msg_by_timestamp(
timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
) -> List[Dict[str, Any]]:
"""
@@ -180,10 +181,10 @@ def get_raw_msg_by_timestamp(
filter_query = {"time": {"$gt": timestamp_start, "$lt": timestamp_end}}
# 只有当 limit 为 0 时才应用外部 sort
sort_order = [("time", 1)] if limit == 0 else None
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
def get_raw_msg_by_timestamp_with_chat(
async def get_raw_msg_by_timestamp_with_chat(
chat_id: str,
timestamp_start: float,
timestamp_end: float,
@@ -200,7 +201,7 @@ def get_raw_msg_by_timestamp_with_chat(
# 只有当 limit 为 0 时才应用外部 sort
sort_order = [("time", 1)] if limit == 0 else None
# 直接将 limit_mode 传递给 find_messages
return find_messages(
return await find_messages(
message_filter=filter_query,
sort=sort_order,
limit=limit,
@@ -210,7 +211,7 @@ def get_raw_msg_by_timestamp_with_chat(
)
def get_raw_msg_by_timestamp_with_chat_inclusive(
async def get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id: str,
timestamp_start: float,
timestamp_end: float,
@@ -227,12 +228,12 @@ def get_raw_msg_by_timestamp_with_chat_inclusive(
sort_order = [("time", 1)] if limit == 0 else None
# 直接将 limit_mode 传递给 find_messages
return find_messages(
return await find_messages(
message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode, filter_bot=filter_bot
)
def get_raw_msg_by_timestamp_with_chat_users(
async def get_raw_msg_by_timestamp_with_chat_users(
chat_id: str,
timestamp_start: float,
timestamp_end: float,
@@ -251,10 +252,10 @@ def get_raw_msg_by_timestamp_with_chat_users(
}
# 只有当 limit 为 0 时才应用外部 sort
sort_order = [("time", 1)] if limit == 0 else None
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
def get_actions_by_timestamp_with_chat(
async def get_actions_by_timestamp_with_chat(
chat_id: str,
timestamp_start: float = 0,
timestamp_end: float = time.time(),
@@ -273,10 +274,10 @@ def get_actions_by_timestamp_with_chat(
f"limit={limit}, limit_mode={limit_mode}"
)
with get_db_session() as session:
async with get_db_session() as session:
if limit > 0:
if limit_mode == "latest":
query = session.execute(
query = await session.execute(
select(ActionRecords)
.where(
and_(
@@ -306,7 +307,7 @@ def get_actions_by_timestamp_with_chat(
}
actions_result.append(action_dict)
else: # earliest
query = session.execute(
query = await session.execute(
select(ActionRecords)
.where(
and_(
@@ -336,7 +337,7 @@ def get_actions_by_timestamp_with_chat(
}
actions_result.append(action_dict)
else:
query = session.execute(
query = await session.execute(
select(ActionRecords)
.where(
and_(
@@ -367,14 +368,14 @@ def get_actions_by_timestamp_with_chat(
return actions_result
def get_actions_by_timestamp_with_chat_inclusive(
async def get_actions_by_timestamp_with_chat_inclusive(
chat_id: str, timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
) -> List[Dict[str, Any]]:
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录(包含边界),按时间升序排序,返回动作记录列表"""
with get_db_session() as session:
async with get_db_session() as session:
if limit > 0:
if limit_mode == "latest":
query = session.execute(
query = await session.execute(
select(ActionRecords)
.where(
and_(
@@ -389,7 +390,7 @@ def get_actions_by_timestamp_with_chat_inclusive(
actions = list(query.scalars())
return [action.__dict__ for action in reversed(actions)]
else: # earliest
query = session.execute(
query = await session.execute(
select(ActionRecords)
.where(
and_(
@@ -402,7 +403,7 @@ def get_actions_by_timestamp_with_chat_inclusive(
.limit(limit)
)
else:
query = session.execute(
query = await session.execute(
select(ActionRecords)
.where(
and_(
@@ -418,14 +419,14 @@ def get_actions_by_timestamp_with_chat_inclusive(
return [action.__dict__ for action in actions]
def get_raw_msg_by_timestamp_random(
async def get_raw_msg_by_timestamp_random(
timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
) -> List[Dict[str, Any]]:
"""
先在范围时间戳内随机选择一条消息取得消息的chat_id然后根据chat_id获取该聊天在指定时间戳范围内的消息
"""
# 获取所有消息只取chat_id字段
all_msgs = get_raw_msg_by_timestamp(timestamp_start, timestamp_end)
all_msgs = await get_raw_msg_by_timestamp(timestamp_start, timestamp_end)
if not all_msgs:
return []
# 随机选一条
@@ -433,10 +434,10 @@ def get_raw_msg_by_timestamp_random(
chat_id = msg["chat_id"]
timestamp_start = msg["time"]
# 用 chat_id 获取该聊天在指定时间戳范围内的消息
return get_raw_msg_by_timestamp_with_chat(chat_id, timestamp_start, timestamp_end, limit, "earliest")
return await get_raw_msg_by_timestamp_with_chat(chat_id, timestamp_start, timestamp_end, limit, "earliest")
def get_raw_msg_by_timestamp_with_users(
async def get_raw_msg_by_timestamp_with_users(
timestamp_start: float, timestamp_end: float, person_ids: list, limit: int = 0, limit_mode: str = "latest"
) -> List[Dict[str, Any]]:
"""获取某些特定用户在 *所有聊天* 中从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
@@ -446,37 +447,39 @@ def get_raw_msg_by_timestamp_with_users(
filter_query = {"time": {"$gt": timestamp_start, "$lt": timestamp_end}, "user_id": {"$in": person_ids}}
# 只有当 limit 为 0 时才应用外部 sort
sort_order = [("time", 1)] if limit == 0 else None
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List[Dict[str, Any]]:
async def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List[Dict[str, Any]]:
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制
"""
filter_query = {"time": {"$lt": timestamp}}
sort_order = [("time", 1)]
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
def get_raw_msg_before_timestamp_with_chat(chat_id: str, timestamp: float, limit: int = 0) -> List[Dict[str, Any]]:
async def get_raw_msg_before_timestamp_with_chat(chat_id: str, timestamp: float, limit: int = 0) -> List[Dict[str, Any]]:
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制
"""
filter_query = {"chat_id": chat_id, "time": {"$lt": timestamp}}
sort_order = [("time", 1)]
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
def get_raw_msg_before_timestamp_with_users(timestamp: float, person_ids: list, limit: int = 0) -> List[Dict[str, Any]]:
async def get_raw_msg_before_timestamp_with_users(
timestamp: float, person_ids: list, limit: int = 0
) -> List[Dict[str, Any]]:
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制
"""
filter_query = {"time": {"$lt": timestamp}, "user_id": {"$in": person_ids}}
sort_order = [("time", 1)]
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: Optional[float] = None) -> int:
async def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: Optional[float] = None) -> int:
"""
检查特定聊天从 timestamp_start (不含) 到 timestamp_end (不含) 之间有多少新消息。
如果 timestamp_end 为 None则检查从 timestamp_start (不含) 到当前时间的消息。
@@ -490,10 +493,10 @@ def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp
return 0 # 起始时间大于等于结束时间,没有新消息
filter_query = {"chat_id": chat_id, "time": {"$gt": timestamp_start, "$lt": _timestamp_end}}
return count_messages(message_filter=filter_query)
return await count_messages(message_filter=filter_query)
def num_new_messages_since_with_users(
async def num_new_messages_since_with_users(
chat_id: str, timestamp_start: float, timestamp_end: float, person_ids: list
) -> int:
"""检查某些特定用户在特定聊天在指定时间戳之间有多少新消息"""
@@ -504,10 +507,10 @@ def num_new_messages_since_with_users(
"time": {"$gt": timestamp_start, "$lt": timestamp_end},
"user_id": {"$in": person_ids},
}
return count_messages(message_filter=filter_query)
return await count_messages(message_filter=filter_query)
def _build_readable_messages_internal(
async def _build_readable_messages_internal(
messages: List[Dict[str, Any]],
replace_bot_name: bool = True,
merge_messages: bool = False,
@@ -627,7 +630,8 @@ def _build_readable_messages_internal(
if replace_bot_name and user_id == global_config.bot.qq_account:
person_name = f"{global_config.bot.nickname}(你)"
else:
person_name = person_info_manager.get_value_sync(person_id, "person_name") # type: ignore
person_info = await person_info_manager.get_values(person_id, ["person_name"])
person_name = person_info.get("person_name") # type: ignore
# 如果 person_name 未设置,则使用消息中的 nickname 或默认名称
if not person_name:
@@ -796,7 +800,7 @@ def _build_readable_messages_internal(
)
def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
async def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
# sourcery skip: use-contextlib-suppress
"""
构建图片映射信息字符串,显示图片的具体描述内容
@@ -819,9 +823,9 @@ def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
# 从数据库中获取图片描述
description = "[图片内容未知]" # 默认描述
try:
with get_db_session() as session:
image = session.execute(select(Images).where(Images.image_id == pic_id)).scalar_one_or_none()
if image and image.description: # type: ignore
async with get_db_session() as session:
image = (await session.execute(select(Images).where(Images.image_id == pic_id))).scalar_one_or_none()
if image and image.description: # type: ignore
description = image.description
except Exception:
# 如果查询失败,保持默认描述
@@ -917,17 +921,17 @@ async def build_readable_messages_with_list(
将消息列表转换为可读的文本格式,并返回原始(时间戳, 昵称, 内容)列表。
允许通过参数控制格式化行为。
"""
formatted_string, details_list, pic_id_mapping, _ = _build_readable_messages_internal(
formatted_string, details_list, pic_id_mapping, _ = await _build_readable_messages_internal(
messages, replace_bot_name, merge_messages, timestamp_mode, truncate
)
if pic_mapping_info := build_pic_mapping_info(pic_id_mapping):
if pic_mapping_info := await build_pic_mapping_info(pic_id_mapping):
formatted_string = f"{pic_mapping_info}\n\n{formatted_string}"
return formatted_string, details_list
def build_readable_messages_with_id(
async def build_readable_messages_with_id(
messages: List[Dict[str, Any]],
replace_bot_name: bool = True,
merge_messages: bool = False,
@@ -943,7 +947,7 @@ def build_readable_messages_with_id(
"""
message_id_list = assign_message_ids(messages)
formatted_string = build_readable_messages(
formatted_string = await build_readable_messages(
messages=messages,
replace_bot_name=replace_bot_name,
merge_messages=merge_messages,
@@ -958,7 +962,7 @@ def build_readable_messages_with_id(
return formatted_string, message_id_list
def build_readable_messages(
async def build_readable_messages(
messages: List[Dict[str, Any]],
replace_bot_name: bool = True,
merge_messages: bool = False,
@@ -999,24 +1003,28 @@ def build_readable_messages(
from src.common.database.sqlalchemy_database_api import get_db_session
with get_db_session() as session:
async with get_db_session() as session:
# 获取这个时间范围内的动作记录并匹配chat_id
actions_in_range = session.execute(
select(ActionRecords)
.where(
and_(
ActionRecords.time >= min_time, ActionRecords.time <= max_time, ActionRecords.chat_id == chat_id
actions_in_range = (
await session.execute(
select(ActionRecords)
.where(
and_(
ActionRecords.time >= min_time, ActionRecords.time <= max_time, ActionRecords.chat_id == chat_id
)
)
.order_by(ActionRecords.time)
)
.order_by(ActionRecords.time)
).scalars()
# 获取最新消息之后的第一个动作记录
action_after_latest = session.execute(
select(ActionRecords)
.where(and_(ActionRecords.time > max_time, ActionRecords.chat_id == chat_id))
.order_by(ActionRecords.time)
.limit(1)
action_after_latest = (
await session.execute(
select(ActionRecords)
.where(and_(ActionRecords.time > max_time, ActionRecords.chat_id == chat_id))
.order_by(ActionRecords.time)
.limit(1)
)
).scalars()
# 合并两部分动作记录,并转为 dict避免 DetachedInstanceError
@@ -1048,7 +1056,7 @@ def build_readable_messages(
if read_mark <= 0:
# 没有有效的 read_mark直接格式化所有消息
formatted_string, _, pic_id_mapping, _ = _build_readable_messages_internal(
formatted_string, _, pic_id_mapping, _ = await _build_readable_messages_internal(
copy_messages,
replace_bot_name,
merge_messages,
@@ -1059,7 +1067,7 @@ def build_readable_messages(
)
# 生成图片映射信息并添加到最前面
pic_mapping_info = build_pic_mapping_info(pic_id_mapping)
pic_mapping_info = await build_pic_mapping_info(pic_id_mapping)
if pic_mapping_info:
return f"{pic_mapping_info}\n\n{formatted_string}"
else:
@@ -1074,7 +1082,7 @@ def build_readable_messages(
pic_counter = 1
# 分别格式化,但使用共享的图片映射
formatted_before, _, pic_id_mapping, pic_counter = _build_readable_messages_internal(
formatted_before, _, pic_id_mapping, pic_counter = await _build_readable_messages_internal(
messages_before_mark,
replace_bot_name,
merge_messages,
@@ -1085,7 +1093,7 @@ def build_readable_messages(
show_pic=show_pic,
message_id_list=message_id_list,
)
formatted_after, _, pic_id_mapping, _ = _build_readable_messages_internal(
formatted_after, _, pic_id_mapping, _ = await _build_readable_messages_internal(
messages_after_mark,
replace_bot_name,
merge_messages,
@@ -1101,7 +1109,7 @@ def build_readable_messages(
# 生成图片映射信息
if pic_id_mapping:
pic_mapping_info = f"图片信息:\n{build_pic_mapping_info(pic_id_mapping)}\n聊天记录信息:\n"
pic_mapping_info = f"图片信息:\n{await build_pic_mapping_info(pic_id_mapping)}\n聊天记录信息:\n"
else:
pic_mapping_info = "聊天记录信息:\n"