feat: 强制注册长期记忆目标ID,支持中文描述作为ID映射
This commit is contained in:
@@ -170,6 +170,18 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
logger.info("\n" + "\n".join(output))
|
||||
|
||||
@staticmethod
|
||||
async def _yield_control(iteration: int, interval: int = 200) -> None:
|
||||
"""
|
||||
<20>ڴ<EFBFBD><DAB4><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ʱ<EFBFBD><CAB1><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>첽<EFBFBD>¼<EFBFBD>ѭ<EFBFBD><D1AD><EFBFBD><EFBFBD><EFBFBD>Ӧ
|
||||
|
||||
Args:
|
||||
iteration: <20><>ǰ<EFBFBD><C7B0><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>
|
||||
interval: ÿ<><C3BF><EFBFBD><EFBFBD><EFBFBD>ٴ<EFBFBD><D9B4>л<EFBFBD>һ<EFBFBD><D2BB>
|
||||
"""
|
||||
if iteration % interval == 0:
|
||||
await asyncio.sleep(0)
|
||||
|
||||
async def run(self):
|
||||
try:
|
||||
now = datetime.now()
|
||||
@@ -304,7 +316,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
or []
|
||||
)
|
||||
|
||||
for record in records:
|
||||
for record_idx, record in enumerate(records, 1):
|
||||
if not isinstance(record, dict):
|
||||
continue
|
||||
|
||||
@@ -315,9 +327,9 @@ class StatisticOutputTask(AsyncTask):
|
||||
if not record_timestamp:
|
||||
continue
|
||||
|
||||
for idx, (_, period_start) in enumerate(collect_period):
|
||||
for period_idx, (_, period_start) in enumerate(collect_period):
|
||||
if record_timestamp >= period_start:
|
||||
for period_key, _ in collect_period[idx:]:
|
||||
for period_key, _ in collect_period[period_idx:]:
|
||||
stats[period_key][TOTAL_REQ_CNT] += 1
|
||||
|
||||
request_type = record.get("request_type") or "unknown"
|
||||
@@ -372,10 +384,11 @@ class StatisticOutputTask(AsyncTask):
|
||||
stats[period_key][TIME_COST_BY_PROVIDER][provider_name].append(time_cost)
|
||||
break
|
||||
|
||||
await StatisticOutputTask._yield_control(record_idx)
|
||||
# -- 计算派生指标 --
|
||||
for period_key, period_stats in stats.items():
|
||||
# 计算模型相关指标
|
||||
for model_name, req_count in period_stats[REQ_CNT_BY_MODEL].items():
|
||||
for model_idx, (model_name, req_count) in enumerate(period_stats[REQ_CNT_BY_MODEL].items(), 1):
|
||||
total_tok = period_stats[TOTAL_TOK_BY_MODEL][model_name] or 0
|
||||
total_cost = period_stats[COST_BY_MODEL][model_name] or 0
|
||||
time_costs = period_stats[TIME_COST_BY_MODEL][model_name] or []
|
||||
@@ -390,8 +403,10 @@ class StatisticOutputTask(AsyncTask):
|
||||
# Avg Tokens per Request
|
||||
period_stats[AVG_TOK_BY_MODEL][model_name] = round(total_tok / req_count) if req_count > 0 else 0
|
||||
|
||||
await StatisticOutputTask._yield_control(model_idx, interval=100)
|
||||
|
||||
# 计算供应商相关指标
|
||||
for provider_name, req_count in period_stats[REQ_CNT_BY_PROVIDER].items():
|
||||
for provider_idx, (provider_name, req_count) in enumerate(period_stats[REQ_CNT_BY_PROVIDER].items(), 1):
|
||||
total_tok = period_stats[TOTAL_TOK_BY_PROVIDER][provider_name]
|
||||
total_cost = period_stats[COST_BY_PROVIDER][provider_name]
|
||||
time_costs = period_stats[TIME_COST_BY_PROVIDER][provider_name]
|
||||
@@ -404,6 +419,8 @@ class StatisticOutputTask(AsyncTask):
|
||||
if total_tok > 0:
|
||||
period_stats[COST_PER_KTOK_BY_PROVIDER][provider_name] = round((total_cost / total_tok) * 1000, 4)
|
||||
|
||||
await StatisticOutputTask._yield_control(provider_idx, interval=100)
|
||||
|
||||
# 计算平均耗时和标准差
|
||||
for category_key, items in [
|
||||
(REQ_CNT_BY_USER, "user"),
|
||||
@@ -414,7 +431,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
time_cost_key = f"time_costs_by_{items.lower()}"
|
||||
avg_key = f"avg_time_costs_by_{items.lower()}"
|
||||
std_key = f"std_time_costs_by_{items.lower()}"
|
||||
for item_name in period_stats[category_key]:
|
||||
for idx, item_name in enumerate(period_stats[category_key], 1):
|
||||
time_costs = period_stats[time_cost_key][item_name]
|
||||
if time_costs:
|
||||
avg_time = sum(time_costs) / len(time_costs)
|
||||
@@ -428,6 +445,8 @@ class StatisticOutputTask(AsyncTask):
|
||||
period_stats[avg_key][item_name] = 0.0
|
||||
period_stats[std_key][item_name] = 0.0
|
||||
|
||||
await StatisticOutputTask._yield_control(idx, interval=200)
|
||||
|
||||
# 准备图表数据
|
||||
# 按供应商花费饼图
|
||||
provider_costs = period_stats[COST_BY_PROVIDER]
|
||||
@@ -479,7 +498,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
or []
|
||||
)
|
||||
|
||||
for record in records:
|
||||
for record_idx, record in enumerate(records, 1):
|
||||
if not isinstance(record, dict):
|
||||
continue
|
||||
|
||||
@@ -494,12 +513,12 @@ class StatisticOutputTask(AsyncTask):
|
||||
if not record_end_timestamp or not record_start_timestamp:
|
||||
continue
|
||||
|
||||
for idx, (_, period_boundary_start) in enumerate(collect_period):
|
||||
for boundary_idx, (_, period_boundary_start) in enumerate(collect_period):
|
||||
if record_end_timestamp >= period_boundary_start:
|
||||
# Calculate effective end time for this record in relation to 'now'
|
||||
effective_end_time = min(record_end_timestamp, now)
|
||||
|
||||
for period_key, current_period_start_time in collect_period[idx:]:
|
||||
for period_key, current_period_start_time in collect_period[boundary_idx:]:
|
||||
# Determine the portion of the record that falls within this specific statistical period
|
||||
overlap_start = max(record_start_timestamp, current_period_start_time)
|
||||
overlap_end = effective_end_time # Already capped by 'now' and record's own end
|
||||
@@ -507,6 +526,8 @@ class StatisticOutputTask(AsyncTask):
|
||||
if overlap_end > overlap_start:
|
||||
stats[period_key][ONLINE_TIME] += (overlap_end - overlap_start).total_seconds()
|
||||
break
|
||||
|
||||
await StatisticOutputTask._yield_control(record_idx)
|
||||
return stats
|
||||
|
||||
async def _collect_message_count_for_period(self, collect_period: list[tuple[str, datetime]]) -> dict[str, Any]:
|
||||
@@ -538,7 +559,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
or []
|
||||
)
|
||||
|
||||
for message in records:
|
||||
for message_idx, message in enumerate(records, 1):
|
||||
if not isinstance(message, dict):
|
||||
continue
|
||||
message_time_ts = message.get("time") # This is a float timestamp
|
||||
@@ -572,12 +593,15 @@ class StatisticOutputTask(AsyncTask):
|
||||
self.name_mapping[chat_id] = (chat_name, message_time_ts)
|
||||
else:
|
||||
self.name_mapping[chat_id] = (chat_name, message_time_ts)
|
||||
for idx, (_, period_start_dt) in enumerate(collect_period):
|
||||
for period_idx, (_, period_start_dt) in enumerate(collect_period):
|
||||
if message_time_ts >= period_start_dt.timestamp():
|
||||
for period_key, _ in collect_period[idx:]:
|
||||
for period_key, _ in collect_period[period_idx:]:
|
||||
stats[period_key][TOTAL_MSG_CNT] += 1
|
||||
stats[period_key][MSG_CNT_BY_CHAT][chat_id] += 1
|
||||
break
|
||||
|
||||
await StatisticOutputTask._yield_control(message_idx)
|
||||
|
||||
return stats
|
||||
|
||||
async def _collect_all_statistics(self, now: datetime) -> dict[str, dict[str, Any]]:
|
||||
@@ -767,7 +791,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
)
|
||||
or []
|
||||
)
|
||||
for record in llm_records:
|
||||
for record_idx, record in enumerate(llm_records, 1):
|
||||
if not isinstance(record, dict) or not record.get("timestamp"):
|
||||
continue
|
||||
record_time = record["timestamp"]
|
||||
@@ -791,6 +815,8 @@ class StatisticOutputTask(AsyncTask):
|
||||
cost_by_module[module_name] = [0.0] * len(time_points)
|
||||
cost_by_module[module_name][idx] += cost
|
||||
|
||||
await StatisticOutputTask._yield_control(record_idx)
|
||||
|
||||
# 单次查询 Messages
|
||||
msg_records = (
|
||||
await db_get(
|
||||
@@ -800,7 +826,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
)
|
||||
or []
|
||||
)
|
||||
for msg in msg_records:
|
||||
for msg_idx, msg in enumerate(msg_records, 1):
|
||||
if not isinstance(msg, dict) or not msg.get("time"):
|
||||
continue
|
||||
msg_ts = msg["time"]
|
||||
@@ -819,6 +845,8 @@ class StatisticOutputTask(AsyncTask):
|
||||
message_by_chat[chat_name] = [0] * len(time_points)
|
||||
message_by_chat[chat_name][idx] += 1
|
||||
|
||||
await StatisticOutputTask._yield_control(msg_idx)
|
||||
|
||||
return {
|
||||
"time_labels": time_labels,
|
||||
"total_cost_data": total_cost_data,
|
||||
|
||||
@@ -658,7 +658,9 @@ class LongTermMemoryManager:
|
||||
memory.metadata["transfer_time"] = datetime.now().isoformat()
|
||||
|
||||
logger.info(f"✅ 创建长期记忆: {memory.id} (来自短期记忆 {source_stm.id})")
|
||||
self._register_temp_id(op.target_id, memory.id, temp_id_map)
|
||||
# 强制注册 target_id,无论它是否符合 placeholder 格式
|
||||
# 这样即使 LLM 使用了中文描述作为 ID (如 "新创建的记忆"), 也能正确映射
|
||||
self._register_temp_id(op.target_id, memory.id, temp_id_map, force=True)
|
||||
self._register_aliases_from_params(
|
||||
op.parameters,
|
||||
memory.id,
|
||||
@@ -766,7 +768,8 @@ class LongTermMemoryManager:
|
||||
# 尝试为新节点生成 embedding (异步)
|
||||
asyncio.create_task(self._generate_node_embedding(node_id, content))
|
||||
logger.info(f"✅ 创建节点: {content} ({node_type}) -> {memory_id}")
|
||||
self._register_temp_id(op.target_id, node_id, temp_id_map)
|
||||
# 强制注册 target_id,无论它是否符合 placeholder 格式
|
||||
self._register_temp_id(op.target_id, node_id, temp_id_map, force=True)
|
||||
self._register_aliases_from_params(
|
||||
op.parameters,
|
||||
node_id,
|
||||
|
||||
Reference in New Issue
Block a user