From 4c30388f186d46e73131ce48f4dd1d9a9887ad7d Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Wed, 19 Nov 2025 20:19:06 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=BC=BA=E5=88=B6=E6=B3=A8=E5=86=8C?= =?UTF-8?q?=E9=95=BF=E6=9C=9F=E8=AE=B0=E5=BF=86=E7=9B=AE=E6=A0=87ID?= =?UTF-8?q?=EF=BC=8C=E6=94=AF=E6=8C=81=E4=B8=AD=E6=96=87=E6=8F=8F=E8=BF=B0?= =?UTF-8?q?=E4=BD=9C=E4=B8=BAID=E6=98=A0=E5=B0=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/utils/statistic.py | 56 ++++++++++++++++++++------- src/memory_graph/long_term_manager.py | 7 +++- 2 files changed, 47 insertions(+), 16 deletions(-) diff --git a/src/chat/utils/statistic.py b/src/chat/utils/statistic.py index 5b4b811c0..a94278b04 100644 --- a/src/chat/utils/statistic.py +++ b/src/chat/utils/statistic.py @@ -170,6 +170,18 @@ class StatisticOutputTask(AsyncTask): logger.info("\n" + "\n".join(output)) + @staticmethod + async def _yield_control(iteration: int, interval: int = 200) -> None: + """ + �ڴ����������ʱ������������첽�¼�ѭ�����Ӧ + + Args: + iteration: ��ǰ������� + interval: ÿ�����ٴ��л�һ�� + """ + if iteration % interval == 0: + await asyncio.sleep(0) + async def run(self): try: now = datetime.now() @@ -304,7 +316,7 @@ class StatisticOutputTask(AsyncTask): or [] ) - for record in records: + for record_idx, record in enumerate(records, 1): if not isinstance(record, dict): continue @@ -315,9 +327,9 @@ class StatisticOutputTask(AsyncTask): if not record_timestamp: continue - for idx, (_, period_start) in enumerate(collect_period): + for period_idx, (_, period_start) in enumerate(collect_period): if record_timestamp >= period_start: - for period_key, _ in collect_period[idx:]: + for period_key, _ in collect_period[period_idx:]: stats[period_key][TOTAL_REQ_CNT] += 1 request_type = record.get("request_type") or "unknown" @@ -372,10 +384,11 @@ class StatisticOutputTask(AsyncTask): stats[period_key][TIME_COST_BY_PROVIDER][provider_name].append(time_cost) break + await StatisticOutputTask._yield_control(record_idx) # -- 计算派生指标 -- for period_key, period_stats in stats.items(): # 计算模型相关指标 - for model_name, req_count in period_stats[REQ_CNT_BY_MODEL].items(): + for model_idx, (model_name, req_count) in enumerate(period_stats[REQ_CNT_BY_MODEL].items(), 1): total_tok = period_stats[TOTAL_TOK_BY_MODEL][model_name] or 0 total_cost = period_stats[COST_BY_MODEL][model_name] or 0 time_costs = period_stats[TIME_COST_BY_MODEL][model_name] or [] @@ -390,8 +403,10 @@ class StatisticOutputTask(AsyncTask): # Avg Tokens per Request period_stats[AVG_TOK_BY_MODEL][model_name] = round(total_tok / req_count) if req_count > 0 else 0 + await StatisticOutputTask._yield_control(model_idx, interval=100) + # 计算供应商相关指标 - for provider_name, req_count in period_stats[REQ_CNT_BY_PROVIDER].items(): + for provider_idx, (provider_name, req_count) in enumerate(period_stats[REQ_CNT_BY_PROVIDER].items(), 1): total_tok = period_stats[TOTAL_TOK_BY_PROVIDER][provider_name] total_cost = period_stats[COST_BY_PROVIDER][provider_name] time_costs = period_stats[TIME_COST_BY_PROVIDER][provider_name] @@ -404,6 +419,8 @@ class StatisticOutputTask(AsyncTask): if total_tok > 0: period_stats[COST_PER_KTOK_BY_PROVIDER][provider_name] = round((total_cost / total_tok) * 1000, 4) + await StatisticOutputTask._yield_control(provider_idx, interval=100) + # 计算平均耗时和标准差 for category_key, items in [ (REQ_CNT_BY_USER, "user"), @@ -414,7 +431,7 @@ class StatisticOutputTask(AsyncTask): time_cost_key = f"time_costs_by_{items.lower()}" avg_key = f"avg_time_costs_by_{items.lower()}" std_key = f"std_time_costs_by_{items.lower()}" - for item_name in period_stats[category_key]: + for idx, item_name in enumerate(period_stats[category_key], 1): time_costs = period_stats[time_cost_key][item_name] if time_costs: avg_time = sum(time_costs) / len(time_costs) @@ -428,6 +445,8 @@ class StatisticOutputTask(AsyncTask): period_stats[avg_key][item_name] = 0.0 period_stats[std_key][item_name] = 0.0 + await StatisticOutputTask._yield_control(idx, interval=200) + # 准备图表数据 # 按供应商花费饼图 provider_costs = period_stats[COST_BY_PROVIDER] @@ -479,7 +498,7 @@ class StatisticOutputTask(AsyncTask): or [] ) - for record in records: + for record_idx, record in enumerate(records, 1): if not isinstance(record, dict): continue @@ -494,12 +513,12 @@ class StatisticOutputTask(AsyncTask): if not record_end_timestamp or not record_start_timestamp: continue - for idx, (_, period_boundary_start) in enumerate(collect_period): + for boundary_idx, (_, period_boundary_start) in enumerate(collect_period): if record_end_timestamp >= period_boundary_start: # Calculate effective end time for this record in relation to 'now' effective_end_time = min(record_end_timestamp, now) - for period_key, current_period_start_time in collect_period[idx:]: + for period_key, current_period_start_time in collect_period[boundary_idx:]: # Determine the portion of the record that falls within this specific statistical period overlap_start = max(record_start_timestamp, current_period_start_time) overlap_end = effective_end_time # Already capped by 'now' and record's own end @@ -507,6 +526,8 @@ class StatisticOutputTask(AsyncTask): if overlap_end > overlap_start: stats[period_key][ONLINE_TIME] += (overlap_end - overlap_start).total_seconds() break + + await StatisticOutputTask._yield_control(record_idx) return stats async def _collect_message_count_for_period(self, collect_period: list[tuple[str, datetime]]) -> dict[str, Any]: @@ -538,7 +559,7 @@ class StatisticOutputTask(AsyncTask): or [] ) - for message in records: + for message_idx, message in enumerate(records, 1): if not isinstance(message, dict): continue message_time_ts = message.get("time") # This is a float timestamp @@ -572,12 +593,15 @@ class StatisticOutputTask(AsyncTask): self.name_mapping[chat_id] = (chat_name, message_time_ts) else: self.name_mapping[chat_id] = (chat_name, message_time_ts) - for idx, (_, period_start_dt) in enumerate(collect_period): + for period_idx, (_, period_start_dt) in enumerate(collect_period): if message_time_ts >= period_start_dt.timestamp(): - for period_key, _ in collect_period[idx:]: + for period_key, _ in collect_period[period_idx:]: stats[period_key][TOTAL_MSG_CNT] += 1 stats[period_key][MSG_CNT_BY_CHAT][chat_id] += 1 break + + await StatisticOutputTask._yield_control(message_idx) + return stats async def _collect_all_statistics(self, now: datetime) -> dict[str, dict[str, Any]]: @@ -767,7 +791,7 @@ class StatisticOutputTask(AsyncTask): ) or [] ) - for record in llm_records: + for record_idx, record in enumerate(llm_records, 1): if not isinstance(record, dict) or not record.get("timestamp"): continue record_time = record["timestamp"] @@ -791,6 +815,8 @@ class StatisticOutputTask(AsyncTask): cost_by_module[module_name] = [0.0] * len(time_points) cost_by_module[module_name][idx] += cost + await StatisticOutputTask._yield_control(record_idx) + # 单次查询 Messages msg_records = ( await db_get( @@ -800,7 +826,7 @@ class StatisticOutputTask(AsyncTask): ) or [] ) - for msg in msg_records: + for msg_idx, msg in enumerate(msg_records, 1): if not isinstance(msg, dict) or not msg.get("time"): continue msg_ts = msg["time"] @@ -819,6 +845,8 @@ class StatisticOutputTask(AsyncTask): message_by_chat[chat_name] = [0] * len(time_points) message_by_chat[chat_name][idx] += 1 + await StatisticOutputTask._yield_control(msg_idx) + return { "time_labels": time_labels, "total_cost_data": total_cost_data, diff --git a/src/memory_graph/long_term_manager.py b/src/memory_graph/long_term_manager.py index 9e6c91393..8dba0fc2d 100644 --- a/src/memory_graph/long_term_manager.py +++ b/src/memory_graph/long_term_manager.py @@ -658,7 +658,9 @@ class LongTermMemoryManager: memory.metadata["transfer_time"] = datetime.now().isoformat() logger.info(f"✅ 创建长期记忆: {memory.id} (来自短期记忆 {source_stm.id})") - self._register_temp_id(op.target_id, memory.id, temp_id_map) + # 强制注册 target_id,无论它是否符合 placeholder 格式 + # 这样即使 LLM 使用了中文描述作为 ID (如 "新创建的记忆"), 也能正确映射 + self._register_temp_id(op.target_id, memory.id, temp_id_map, force=True) self._register_aliases_from_params( op.parameters, memory.id, @@ -766,7 +768,8 @@ class LongTermMemoryManager: # 尝试为新节点生成 embedding (异步) asyncio.create_task(self._generate_node_embedding(node_id, content)) logger.info(f"✅ 创建节点: {content} ({node_type}) -> {memory_id}") - self._register_temp_id(op.target_id, node_id, temp_id_map) + # 强制注册 target_id,无论它是否符合 placeholder 格式 + self._register_temp_id(op.target_id, node_id, temp_id_map, force=True) self._register_aliases_from_params( op.parameters, node_id,