This commit is contained in:
minecraft1024a
2025-11-01 10:59:38 +08:00
committed by Windpicker-owo
parent 4a0bd5845e
commit 12c2b806a2
10 changed files with 30 additions and 30 deletions

View File

@@ -245,11 +245,11 @@ class ExpressionLearner:
all_expressions = await session.execute( all_expressions = await session.execute(
select(Expression).where(Expression.chat_id == self.chat_id) select(Expression).where(Expression.chat_id == self.chat_id)
) )
for expr in all_expressions.scalars(): for expr in all_expressions.scalars():
# 确保create_date存在如果不存在则使用last_active_time # 确保create_date存在如果不存在则使用last_active_time
create_date = expr.create_date if expr.create_date is not None else expr.last_active_time create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
expr_data = { expr_data = {
"situation": expr.situation, "situation": expr.situation,
"style": expr.style, "style": expr.style,
@@ -259,13 +259,13 @@ class ExpressionLearner:
"type": expr.type, "type": expr.type,
"create_date": create_date, "create_date": create_date,
} }
# 根据类型分类 # 根据类型分类
if expr.type == "style": if expr.type == "style":
learnt_style_expressions.append(expr_data) learnt_style_expressions.append(expr_data)
elif expr.type == "grammar": elif expr.type == "grammar":
learnt_grammar_expressions.append(expr_data) learnt_grammar_expressions.append(expr_data)
return learnt_style_expressions, learnt_grammar_expressions return learnt_style_expressions, learnt_grammar_expressions
async def _apply_global_decay_to_database(self, current_time: float) -> None: async def _apply_global_decay_to_database(self, current_time: float) -> None:

View File

@@ -354,7 +354,7 @@ class MessageManager:
# 取消 stream_loop_task子任务会通过 try-catch 自动取消 # 取消 stream_loop_task子任务会通过 try-catch 自动取消
try: try:
stream_loop_task.cancel() stream_loop_task.cancel()
# 等待任务真正结束(设置超时避免死锁) # 等待任务真正结束(设置超时避免死锁)
try: try:
await asyncio.wait_for(stream_loop_task, timeout=2.0) await asyncio.wait_for(stream_loop_task, timeout=2.0)

View File

@@ -24,20 +24,20 @@ class MessageUpdateBatcher:
优化: 将多个消息ID更新操作批量处理减少数据库连接次数 优化: 将多个消息ID更新操作批量处理减少数据库连接次数
""" """
def __init__(self, batch_size: int = 20, flush_interval: float = 2.0): def __init__(self, batch_size: int = 20, flush_interval: float = 2.0):
self.batch_size = batch_size self.batch_size = batch_size
self.flush_interval = flush_interval self.flush_interval = flush_interval
self.pending_updates: deque = deque() self.pending_updates: deque = deque()
self._lock = asyncio.Lock() self._lock = asyncio.Lock()
self._flush_task = None self._flush_task = None
async def start(self): async def start(self):
"""启动自动刷新任务""" """启动自动刷新任务"""
if self._flush_task is None: if self._flush_task is None:
self._flush_task = asyncio.create_task(self._auto_flush_loop()) self._flush_task = asyncio.create_task(self._auto_flush_loop())
logger.debug("消息更新批处理器已启动") logger.debug("消息更新批处理器已启动")
async def stop(self): async def stop(self):
"""停止批处理器""" """停止批处理器"""
if self._flush_task: if self._flush_task:
@@ -47,29 +47,29 @@ class MessageUpdateBatcher:
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
self._flush_task = None self._flush_task = None
# 刷新剩余的更新 # 刷新剩余的更新
await self.flush() await self.flush()
logger.debug("消息更新批处理器已停止") logger.debug("消息更新批处理器已停止")
async def add_update(self, mmc_message_id: str, qq_message_id: str): async def add_update(self, mmc_message_id: str, qq_message_id: str):
"""添加消息ID更新到批处理队列""" """添加消息ID更新到批处理队列"""
async with self._lock: async with self._lock:
self.pending_updates.append((mmc_message_id, qq_message_id)) self.pending_updates.append((mmc_message_id, qq_message_id))
# 如果达到批量大小,立即刷新 # 如果达到批量大小,立即刷新
if len(self.pending_updates) >= self.batch_size: if len(self.pending_updates) >= self.batch_size:
await self.flush() await self.flush()
async def flush(self): async def flush(self):
"""执行批量更新""" """执行批量更新"""
async with self._lock: async with self._lock:
if not self.pending_updates: if not self.pending_updates:
return return
updates = list(self.pending_updates) updates = list(self.pending_updates)
self.pending_updates.clear() self.pending_updates.clear()
try: try:
async with get_db_session() as session: async with get_db_session() as session:
updated_count = 0 updated_count = 0
@@ -81,15 +81,15 @@ class MessageUpdateBatcher:
) )
if result.rowcount > 0: if result.rowcount > 0:
updated_count += 1 updated_count += 1
await session.commit() await session.commit()
if updated_count > 0: if updated_count > 0:
logger.debug(f"批量更新了 {updated_count}/{len(updates)} 条消息ID") logger.debug(f"批量更新了 {updated_count}/{len(updates)} 条消息ID")
except Exception as e: except Exception as e:
logger.error(f"批量更新消息ID失败: {e}") logger.error(f"批量更新消息ID失败: {e}")
async def _auto_flush_loop(self): async def _auto_flush_loop(self):
"""自动刷新循环""" """自动刷新循环"""
while True: while True:

View File

@@ -628,7 +628,7 @@ class ChatterActionManager:
if not first_replied: if not first_replied:
# 决定是否引用回复 # 决定是否引用回复
is_private_chat = not bool(chat_stream.group_info) is_private_chat = not bool(chat_stream.group_info)
# 如果明确指定了should_quote_reply则使用指定值 # 如果明确指定了should_quote_reply则使用指定值
if should_quote_reply is not None: if should_quote_reply is not None:
set_reply_flag = should_quote_reply and bool(message_data) set_reply_flag = should_quote_reply and bool(message_data)
@@ -641,7 +641,7 @@ class ChatterActionManager:
logger.debug( logger.debug(
f"📤 [ActionManager] 使用默认引用逻辑: 默认不引用(is_private={is_private_chat})" f"📤 [ActionManager] 使用默认引用逻辑: 默认不引用(is_private={is_private_chat})"
) )
logger.debug( logger.debug(
f"📤 [ActionManager] 准备发送第一段回复。message_data: {message_data}, set_reply: {set_reply_flag}" f"📤 [ActionManager] 准备发送第一段回复。message_data: {message_data}, set_reply: {set_reply_flag}"
) )

View File

@@ -425,7 +425,7 @@ class OpenaiClient(BaseClient):
# 清理其他事件循环的过期缓存 # 清理其他事件循环的过期缓存
keys_to_remove = [ keys_to_remove = [
key for key in self._global_client_cache.keys() key for key in self._global_client_cache.keys()
if key[0] == self._config_hash and key[1] != current_loop_id if key[0] == self._config_hash and key[1] != current_loop_id
] ]
for key in keys_to_remove: for key in keys_to_remove:
@@ -459,7 +459,7 @@ class OpenaiClient(BaseClient):
return { return {
"cached_openai_clients": len(cls._global_client_cache), "cached_openai_clients": len(cls._global_client_cache),
"cache_keys": [ "cache_keys": [
{"config_hash": k[0], "loop_id": k[1]} {"config_hash": k[0], "loop_id": k[1]}
for k in cls._global_client_cache.keys() for k in cls._global_client_cache.keys()
], ],
} }

View File

@@ -73,7 +73,7 @@ class ChatStreamImpressionTool(BaseTool):
) )
except AttributeError: except AttributeError:
# 降级处理 # 降级处理
available_models: ClassVar = [ available_models = [
attr attr
for attr in dir(model_config.model_task_config) for attr in dir(model_config.model_task_config)
if not attr.startswith("_") and attr != "model_dump" if not attr.startswith("_") and attr != "model_dump"
@@ -153,7 +153,7 @@ class ChatStreamImpressionTool(BaseTool):
await self._update_stream_impression_in_db(stream_id, final_impression) await self._update_stream_impression_in_db(stream_id, final_impression)
# 构建返回信息 # 构建返回信息
updates: ClassVar = [] updates = []
if final_impression.get("stream_impression_text"): if final_impression.get("stream_impression_text"):
updates.append(f"印象: {final_impression['stream_impression_text'][:50]}...") updates.append(f"印象: {final_impression['stream_impression_text'][:50]}...")
if final_impression.get("stream_chat_style"): if final_impression.get("stream_chat_style"):

View File

@@ -198,11 +198,11 @@ class ChatterPlanExecutor:
} }
# 构建回复动作参数 # 构建回复动作参数
action_data = action_info.action_data or {} action_data = action_info.action_data or {}
# 如果action_info中有should_quote_reply且action_data中没有则添加到action_data中 # 如果action_info中有should_quote_reply且action_data中没有则添加到action_data中
if action_info.should_quote_reply is not None and "should_quote_reply" not in action_data: if action_info.should_quote_reply is not None and "should_quote_reply" not in action_data:
action_data["should_quote_reply"] = action_info.should_quote_reply action_data["should_quote_reply"] = action_info.should_quote_reply
action_params = { action_params = {
"chat_id": plan.chat_id, "chat_id": plan.chat_id,
"target_message": action_info.action_message, "target_message": action_info.action_message,

View File

@@ -290,7 +290,7 @@ class MonthlyPlanLLMGenerator:
# 过滤掉一些明显不是计划的句子 # 过滤掉一些明显不是计划的句子
if len(line) > 5 and not line.startswith(("", "以上", "总结", "注意")): if len(line) > 5 and not line.startswith(("", "以上", "总结", "注意")):
plans.append(line) plans.append(line)
# 根据配置限制最大计划数量 # 根据配置限制最大计划数量
max_plans = global_config.planning_system.max_plans_per_month max_plans = global_config.planning_system.max_plans_per_month
if len(plans) > max_plans: if len(plans) > max_plans:

View File

@@ -90,7 +90,7 @@ class MonthlyPlanGenerationTask(AsyncTask):
next_month = datetime(now.year + 1, 1, 1) next_month = datetime(now.year + 1, 1, 1)
else: else:
next_month = datetime(now.year, now.month + 1, 1) next_month = datetime(now.year, now.month + 1, 1)
sleep_seconds = (next_month - now).total_seconds() sleep_seconds = (next_month - now).total_seconds()
logger.info( logger.info(
f" 下一次月度计划生成任务将在 {sleep_seconds:.2f} 秒后运行 (北京时间 {next_month.strftime('%Y-%m-%d %H:%M:%S')})" f" 下一次月度计划生成任务将在 {sleep_seconds:.2f} 秒后运行 (北京时间 {next_month.strftime('%Y-%m-%d %H:%M:%S')})"
@@ -100,7 +100,7 @@ class MonthlyPlanGenerationTask(AsyncTask):
# 到达月初,先归档上个月的计划 # 到达月初,先归档上个月的计划
last_month = (next_month - timedelta(days=1)).strftime("%Y-%m") last_month = (next_month - timedelta(days=1)).strftime("%Y-%m")
await self.monthly_plan_manager.plan_manager.archive_current_month_plans(last_month) await self.monthly_plan_manager.plan_manager.archive_current_month_plans(last_month)
# 为当前月生成新计划 # 为当前月生成新计划
current_month = next_month.strftime("%Y-%m") current_month = next_month.strftime("%Y-%m")
logger.info(f" 到达月初,开始生成 {current_month} 的月度计划...") logger.info(f" 到达月初,开始生成 {current_month} 的月度计划...")

View File

@@ -182,7 +182,7 @@ class UnifiedScheduler:
except ImportError: except ImportError:
pass pass
logger.info(f"统一调度器已停止") logger.info("统一调度器已停止")
self._tasks.clear() self._tasks.clear()
self._event_subscriptions.clear() self._event_subscriptions.clear()