feat: 强制注册长期记忆目标ID,支持中文描述作为ID映射

This commit is contained in:
Windpicker-owo
2025-11-19 20:19:06 +08:00
parent d9d5fe26ea
commit 2c346c3580
2 changed files with 47 additions and 16 deletions

View File

@@ -170,6 +170,18 @@ class StatisticOutputTask(AsyncTask):
logger.info("\n" + "\n".join(output))
@staticmethod
async def _yield_control(iteration: int, interval: int = 200) -> None:
"""
<20>ڴ<EFBFBD><DAB4><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ʱ<EFBFBD><CAB1><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>¼<EFBFBD>ѭ<EFBFBD><D1AD><EFBFBD><EFBFBD><EFBFBD>Ӧ
Args:
iteration: <20><>ǰ<EFBFBD><C7B0><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>
interval: ÿ<><C3BF><EFBFBD><EFBFBD><EFBFBD>ٴ<EFBFBD><D9B4>л<EFBFBD>һ<EFBFBD><D2BB>
"""
if iteration % interval == 0:
await asyncio.sleep(0)
async def run(self):
try:
now = datetime.now()
@@ -304,7 +316,7 @@ class StatisticOutputTask(AsyncTask):
or []
)
for record in records:
for record_idx, record in enumerate(records, 1):
if not isinstance(record, dict):
continue
@@ -315,9 +327,9 @@ class StatisticOutputTask(AsyncTask):
if not record_timestamp:
continue
for idx, (_, period_start) in enumerate(collect_period):
for period_idx, (_, period_start) in enumerate(collect_period):
if record_timestamp >= period_start:
for period_key, _ in collect_period[idx:]:
for period_key, _ in collect_period[period_idx:]:
stats[period_key][TOTAL_REQ_CNT] += 1
request_type = record.get("request_type") or "unknown"
@@ -372,10 +384,11 @@ class StatisticOutputTask(AsyncTask):
stats[period_key][TIME_COST_BY_PROVIDER][provider_name].append(time_cost)
break
await StatisticOutputTask._yield_control(record_idx)
# -- 计算派生指标 --
for period_key, period_stats in stats.items():
# 计算模型相关指标
for model_name, req_count in period_stats[REQ_CNT_BY_MODEL].items():
for model_idx, (model_name, req_count) in enumerate(period_stats[REQ_CNT_BY_MODEL].items(), 1):
total_tok = period_stats[TOTAL_TOK_BY_MODEL][model_name] or 0
total_cost = period_stats[COST_BY_MODEL][model_name] or 0
time_costs = period_stats[TIME_COST_BY_MODEL][model_name] or []
@@ -390,8 +403,10 @@ class StatisticOutputTask(AsyncTask):
# Avg Tokens per Request
period_stats[AVG_TOK_BY_MODEL][model_name] = round(total_tok / req_count) if req_count > 0 else 0
await StatisticOutputTask._yield_control(model_idx, interval=100)
# 计算供应商相关指标
for provider_name, req_count in period_stats[REQ_CNT_BY_PROVIDER].items():
for provider_idx, (provider_name, req_count) in enumerate(period_stats[REQ_CNT_BY_PROVIDER].items(), 1):
total_tok = period_stats[TOTAL_TOK_BY_PROVIDER][provider_name]
total_cost = period_stats[COST_BY_PROVIDER][provider_name]
time_costs = period_stats[TIME_COST_BY_PROVIDER][provider_name]
@@ -404,6 +419,8 @@ class StatisticOutputTask(AsyncTask):
if total_tok > 0:
period_stats[COST_PER_KTOK_BY_PROVIDER][provider_name] = round((total_cost / total_tok) * 1000, 4)
await StatisticOutputTask._yield_control(provider_idx, interval=100)
# 计算平均耗时和标准差
for category_key, items in [
(REQ_CNT_BY_USER, "user"),
@@ -414,7 +431,7 @@ class StatisticOutputTask(AsyncTask):
time_cost_key = f"time_costs_by_{items.lower()}"
avg_key = f"avg_time_costs_by_{items.lower()}"
std_key = f"std_time_costs_by_{items.lower()}"
for item_name in period_stats[category_key]:
for idx, item_name in enumerate(period_stats[category_key], 1):
time_costs = period_stats[time_cost_key][item_name]
if time_costs:
avg_time = sum(time_costs) / len(time_costs)
@@ -428,6 +445,8 @@ class StatisticOutputTask(AsyncTask):
period_stats[avg_key][item_name] = 0.0
period_stats[std_key][item_name] = 0.0
await StatisticOutputTask._yield_control(idx, interval=200)
# 准备图表数据
# 按供应商花费饼图
provider_costs = period_stats[COST_BY_PROVIDER]
@@ -479,7 +498,7 @@ class StatisticOutputTask(AsyncTask):
or []
)
for record in records:
for record_idx, record in enumerate(records, 1):
if not isinstance(record, dict):
continue
@@ -494,12 +513,12 @@ class StatisticOutputTask(AsyncTask):
if not record_end_timestamp or not record_start_timestamp:
continue
for idx, (_, period_boundary_start) in enumerate(collect_period):
for boundary_idx, (_, period_boundary_start) in enumerate(collect_period):
if record_end_timestamp >= period_boundary_start:
# Calculate effective end time for this record in relation to 'now'
effective_end_time = min(record_end_timestamp, now)
for period_key, current_period_start_time in collect_period[idx:]:
for period_key, current_period_start_time in collect_period[boundary_idx:]:
# Determine the portion of the record that falls within this specific statistical period
overlap_start = max(record_start_timestamp, current_period_start_time)
overlap_end = effective_end_time # Already capped by 'now' and record's own end
@@ -507,6 +526,8 @@ class StatisticOutputTask(AsyncTask):
if overlap_end > overlap_start:
stats[period_key][ONLINE_TIME] += (overlap_end - overlap_start).total_seconds()
break
await StatisticOutputTask._yield_control(record_idx)
return stats
async def _collect_message_count_for_period(self, collect_period: list[tuple[str, datetime]]) -> dict[str, Any]:
@@ -538,7 +559,7 @@ class StatisticOutputTask(AsyncTask):
or []
)
for message in records:
for message_idx, message in enumerate(records, 1):
if not isinstance(message, dict):
continue
message_time_ts = message.get("time") # This is a float timestamp
@@ -572,12 +593,15 @@ class StatisticOutputTask(AsyncTask):
self.name_mapping[chat_id] = (chat_name, message_time_ts)
else:
self.name_mapping[chat_id] = (chat_name, message_time_ts)
for idx, (_, period_start_dt) in enumerate(collect_period):
for period_idx, (_, period_start_dt) in enumerate(collect_period):
if message_time_ts >= period_start_dt.timestamp():
for period_key, _ in collect_period[idx:]:
for period_key, _ in collect_period[period_idx:]:
stats[period_key][TOTAL_MSG_CNT] += 1
stats[period_key][MSG_CNT_BY_CHAT][chat_id] += 1
break
await StatisticOutputTask._yield_control(message_idx)
return stats
async def _collect_all_statistics(self, now: datetime) -> dict[str, dict[str, Any]]:
@@ -767,7 +791,7 @@ class StatisticOutputTask(AsyncTask):
)
or []
)
for record in llm_records:
for record_idx, record in enumerate(llm_records, 1):
if not isinstance(record, dict) or not record.get("timestamp"):
continue
record_time = record["timestamp"]
@@ -791,6 +815,8 @@ class StatisticOutputTask(AsyncTask):
cost_by_module[module_name] = [0.0] * len(time_points)
cost_by_module[module_name][idx] += cost
await StatisticOutputTask._yield_control(record_idx)
# 单次查询 Messages
msg_records = (
await db_get(
@@ -800,7 +826,7 @@ class StatisticOutputTask(AsyncTask):
)
or []
)
for msg in msg_records:
for msg_idx, msg in enumerate(msg_records, 1):
if not isinstance(msg, dict) or not msg.get("time"):
continue
msg_ts = msg["time"]
@@ -819,6 +845,8 @@ class StatisticOutputTask(AsyncTask):
message_by_chat[chat_name] = [0] * len(time_points)
message_by_chat[chat_name][idx] += 1
await StatisticOutputTask._yield_control(msg_idx)
return {
"time_labels": time_labels,
"total_cost_data": total_cost_data,

View File

@@ -658,7 +658,9 @@ class LongTermMemoryManager:
memory.metadata["transfer_time"] = datetime.now().isoformat()
logger.info(f"✅ 创建长期记忆: {memory.id} (来自短期记忆 {source_stm.id})")
self._register_temp_id(op.target_id, memory.id, temp_id_map)
# 强制注册 target_id无论它是否符合 placeholder 格式
# 这样即使 LLM 使用了中文描述作为 ID (如 "新创建的记忆"), 也能正确映射
self._register_temp_id(op.target_id, memory.id, temp_id_map, force=True)
self._register_aliases_from_params(
op.parameters,
memory.id,
@@ -766,7 +768,8 @@ class LongTermMemoryManager:
# 尝试为新节点生成 embedding (异步)
asyncio.create_task(self._generate_node_embedding(node_id, content))
logger.info(f"✅ 创建节点: {content} ({node_type}) -> {memory_id}")
self._register_temp_id(op.target_id, node_id, temp_id_map)
# 强制注册 target_id无论它是否符合 placeholder 格式
self._register_temp_id(op.target_id, node_id, temp_id_map, force=True)
self._register_aliases_from_params(
op.parameters,
node_id,