feat: 强制注册长期记忆目标ID,支持中文描述作为ID映射
This commit is contained in:
@@ -170,6 +170,18 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
|
|
||||||
logger.info("\n" + "\n".join(output))
|
logger.info("\n" + "\n".join(output))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _yield_control(iteration: int, interval: int = 200) -> None:
|
||||||
|
"""
|
||||||
|
在处理大量数据时让出控制权,保持异步事件循环响应
|
||||||
|
|
||||||
|
Args:
|
||||||
|
iteration: 当前迭代次数
|
||||||
|
interval: 每多少次切换一次
|
||||||
|
"""
|
||||||
|
if iteration % interval == 0:
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
try:
|
try:
|
||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
@@ -304,7 +316,7 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
or []
|
or []
|
||||||
)
|
)
|
||||||
|
|
||||||
for record in records:
|
for record_idx, record in enumerate(records, 1):
|
||||||
if not isinstance(record, dict):
|
if not isinstance(record, dict):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -315,9 +327,9 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
if not record_timestamp:
|
if not record_timestamp:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for idx, (_, period_start) in enumerate(collect_period):
|
for period_idx, (_, period_start) in enumerate(collect_period):
|
||||||
if record_timestamp >= period_start:
|
if record_timestamp >= period_start:
|
||||||
for period_key, _ in collect_period[idx:]:
|
for period_key, _ in collect_period[period_idx:]:
|
||||||
stats[period_key][TOTAL_REQ_CNT] += 1
|
stats[period_key][TOTAL_REQ_CNT] += 1
|
||||||
|
|
||||||
request_type = record.get("request_type") or "unknown"
|
request_type = record.get("request_type") or "unknown"
|
||||||
@@ -372,10 +384,11 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
stats[period_key][TIME_COST_BY_PROVIDER][provider_name].append(time_cost)
|
stats[period_key][TIME_COST_BY_PROVIDER][provider_name].append(time_cost)
|
||||||
break
|
break
|
||||||
|
|
||||||
|
await StatisticOutputTask._yield_control(record_idx)
|
||||||
# -- 计算派生指标 --
|
# -- 计算派生指标 --
|
||||||
for period_key, period_stats in stats.items():
|
for period_key, period_stats in stats.items():
|
||||||
# 计算模型相关指标
|
# 计算模型相关指标
|
||||||
for model_name, req_count in period_stats[REQ_CNT_BY_MODEL].items():
|
for model_idx, (model_name, req_count) in enumerate(period_stats[REQ_CNT_BY_MODEL].items(), 1):
|
||||||
total_tok = period_stats[TOTAL_TOK_BY_MODEL][model_name] or 0
|
total_tok = period_stats[TOTAL_TOK_BY_MODEL][model_name] or 0
|
||||||
total_cost = period_stats[COST_BY_MODEL][model_name] or 0
|
total_cost = period_stats[COST_BY_MODEL][model_name] or 0
|
||||||
time_costs = period_stats[TIME_COST_BY_MODEL][model_name] or []
|
time_costs = period_stats[TIME_COST_BY_MODEL][model_name] or []
|
||||||
@@ -390,8 +403,10 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
# Avg Tokens per Request
|
# Avg Tokens per Request
|
||||||
period_stats[AVG_TOK_BY_MODEL][model_name] = round(total_tok / req_count) if req_count > 0 else 0
|
period_stats[AVG_TOK_BY_MODEL][model_name] = round(total_tok / req_count) if req_count > 0 else 0
|
||||||
|
|
||||||
|
await StatisticOutputTask._yield_control(model_idx, interval=100)
|
||||||
|
|
||||||
# 计算供应商相关指标
|
# 计算供应商相关指标
|
||||||
for provider_name, req_count in period_stats[REQ_CNT_BY_PROVIDER].items():
|
for provider_idx, (provider_name, req_count) in enumerate(period_stats[REQ_CNT_BY_PROVIDER].items(), 1):
|
||||||
total_tok = period_stats[TOTAL_TOK_BY_PROVIDER][provider_name]
|
total_tok = period_stats[TOTAL_TOK_BY_PROVIDER][provider_name]
|
||||||
total_cost = period_stats[COST_BY_PROVIDER][provider_name]
|
total_cost = period_stats[COST_BY_PROVIDER][provider_name]
|
||||||
time_costs = period_stats[TIME_COST_BY_PROVIDER][provider_name]
|
time_costs = period_stats[TIME_COST_BY_PROVIDER][provider_name]
|
||||||
@@ -404,6 +419,8 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
if total_tok > 0:
|
if total_tok > 0:
|
||||||
period_stats[COST_PER_KTOK_BY_PROVIDER][provider_name] = round((total_cost / total_tok) * 1000, 4)
|
period_stats[COST_PER_KTOK_BY_PROVIDER][provider_name] = round((total_cost / total_tok) * 1000, 4)
|
||||||
|
|
||||||
|
await StatisticOutputTask._yield_control(provider_idx, interval=100)
|
||||||
|
|
||||||
# 计算平均耗时和标准差
|
# 计算平均耗时和标准差
|
||||||
for category_key, items in [
|
for category_key, items in [
|
||||||
(REQ_CNT_BY_USER, "user"),
|
(REQ_CNT_BY_USER, "user"),
|
||||||
@@ -414,7 +431,7 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
time_cost_key = f"time_costs_by_{items.lower()}"
|
time_cost_key = f"time_costs_by_{items.lower()}"
|
||||||
avg_key = f"avg_time_costs_by_{items.lower()}"
|
avg_key = f"avg_time_costs_by_{items.lower()}"
|
||||||
std_key = f"std_time_costs_by_{items.lower()}"
|
std_key = f"std_time_costs_by_{items.lower()}"
|
||||||
for item_name in period_stats[category_key]:
|
for idx, item_name in enumerate(period_stats[category_key], 1):
|
||||||
time_costs = period_stats[time_cost_key][item_name]
|
time_costs = period_stats[time_cost_key][item_name]
|
||||||
if time_costs:
|
if time_costs:
|
||||||
avg_time = sum(time_costs) / len(time_costs)
|
avg_time = sum(time_costs) / len(time_costs)
|
||||||
@@ -428,6 +445,8 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
period_stats[avg_key][item_name] = 0.0
|
period_stats[avg_key][item_name] = 0.0
|
||||||
period_stats[std_key][item_name] = 0.0
|
period_stats[std_key][item_name] = 0.0
|
||||||
|
|
||||||
|
await StatisticOutputTask._yield_control(idx, interval=200)
|
||||||
|
|
||||||
# 准备图表数据
|
# 准备图表数据
|
||||||
# 按供应商花费饼图
|
# 按供应商花费饼图
|
||||||
provider_costs = period_stats[COST_BY_PROVIDER]
|
provider_costs = period_stats[COST_BY_PROVIDER]
|
||||||
@@ -479,7 +498,7 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
or []
|
or []
|
||||||
)
|
)
|
||||||
|
|
||||||
for record in records:
|
for record_idx, record in enumerate(records, 1):
|
||||||
if not isinstance(record, dict):
|
if not isinstance(record, dict):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -494,12 +513,12 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
if not record_end_timestamp or not record_start_timestamp:
|
if not record_end_timestamp or not record_start_timestamp:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for idx, (_, period_boundary_start) in enumerate(collect_period):
|
for boundary_idx, (_, period_boundary_start) in enumerate(collect_period):
|
||||||
if record_end_timestamp >= period_boundary_start:
|
if record_end_timestamp >= period_boundary_start:
|
||||||
# Calculate effective end time for this record in relation to 'now'
|
# Calculate effective end time for this record in relation to 'now'
|
||||||
effective_end_time = min(record_end_timestamp, now)
|
effective_end_time = min(record_end_timestamp, now)
|
||||||
|
|
||||||
for period_key, current_period_start_time in collect_period[idx:]:
|
for period_key, current_period_start_time in collect_period[boundary_idx:]:
|
||||||
# Determine the portion of the record that falls within this specific statistical period
|
# Determine the portion of the record that falls within this specific statistical period
|
||||||
overlap_start = max(record_start_timestamp, current_period_start_time)
|
overlap_start = max(record_start_timestamp, current_period_start_time)
|
||||||
overlap_end = effective_end_time # Already capped by 'now' and record's own end
|
overlap_end = effective_end_time # Already capped by 'now' and record's own end
|
||||||
@@ -507,6 +526,8 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
if overlap_end > overlap_start:
|
if overlap_end > overlap_start:
|
||||||
stats[period_key][ONLINE_TIME] += (overlap_end - overlap_start).total_seconds()
|
stats[period_key][ONLINE_TIME] += (overlap_end - overlap_start).total_seconds()
|
||||||
break
|
break
|
||||||
|
|
||||||
|
await StatisticOutputTask._yield_control(record_idx)
|
||||||
return stats
|
return stats
|
||||||
|
|
||||||
async def _collect_message_count_for_period(self, collect_period: list[tuple[str, datetime]]) -> dict[str, Any]:
|
async def _collect_message_count_for_period(self, collect_period: list[tuple[str, datetime]]) -> dict[str, Any]:
|
||||||
@@ -538,7 +559,7 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
or []
|
or []
|
||||||
)
|
)
|
||||||
|
|
||||||
for message in records:
|
for message_idx, message in enumerate(records, 1):
|
||||||
if not isinstance(message, dict):
|
if not isinstance(message, dict):
|
||||||
continue
|
continue
|
||||||
message_time_ts = message.get("time") # This is a float timestamp
|
message_time_ts = message.get("time") # This is a float timestamp
|
||||||
@@ -572,12 +593,15 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
self.name_mapping[chat_id] = (chat_name, message_time_ts)
|
self.name_mapping[chat_id] = (chat_name, message_time_ts)
|
||||||
else:
|
else:
|
||||||
self.name_mapping[chat_id] = (chat_name, message_time_ts)
|
self.name_mapping[chat_id] = (chat_name, message_time_ts)
|
||||||
for idx, (_, period_start_dt) in enumerate(collect_period):
|
for period_idx, (_, period_start_dt) in enumerate(collect_period):
|
||||||
if message_time_ts >= period_start_dt.timestamp():
|
if message_time_ts >= period_start_dt.timestamp():
|
||||||
for period_key, _ in collect_period[idx:]:
|
for period_key, _ in collect_period[period_idx:]:
|
||||||
stats[period_key][TOTAL_MSG_CNT] += 1
|
stats[period_key][TOTAL_MSG_CNT] += 1
|
||||||
stats[period_key][MSG_CNT_BY_CHAT][chat_id] += 1
|
stats[period_key][MSG_CNT_BY_CHAT][chat_id] += 1
|
||||||
break
|
break
|
||||||
|
|
||||||
|
await StatisticOutputTask._yield_control(message_idx)
|
||||||
|
|
||||||
return stats
|
return stats
|
||||||
|
|
||||||
async def _collect_all_statistics(self, now: datetime) -> dict[str, dict[str, Any]]:
|
async def _collect_all_statistics(self, now: datetime) -> dict[str, dict[str, Any]]:
|
||||||
@@ -767,7 +791,7 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
)
|
)
|
||||||
or []
|
or []
|
||||||
)
|
)
|
||||||
for record in llm_records:
|
for record_idx, record in enumerate(llm_records, 1):
|
||||||
if not isinstance(record, dict) or not record.get("timestamp"):
|
if not isinstance(record, dict) or not record.get("timestamp"):
|
||||||
continue
|
continue
|
||||||
record_time = record["timestamp"]
|
record_time = record["timestamp"]
|
||||||
@@ -791,6 +815,8 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
cost_by_module[module_name] = [0.0] * len(time_points)
|
cost_by_module[module_name] = [0.0] * len(time_points)
|
||||||
cost_by_module[module_name][idx] += cost
|
cost_by_module[module_name][idx] += cost
|
||||||
|
|
||||||
|
await StatisticOutputTask._yield_control(record_idx)
|
||||||
|
|
||||||
# 单次查询 Messages
|
# 单次查询 Messages
|
||||||
msg_records = (
|
msg_records = (
|
||||||
await db_get(
|
await db_get(
|
||||||
@@ -800,7 +826,7 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
)
|
)
|
||||||
or []
|
or []
|
||||||
)
|
)
|
||||||
for msg in msg_records:
|
for msg_idx, msg in enumerate(msg_records, 1):
|
||||||
if not isinstance(msg, dict) or not msg.get("time"):
|
if not isinstance(msg, dict) or not msg.get("time"):
|
||||||
continue
|
continue
|
||||||
msg_ts = msg["time"]
|
msg_ts = msg["time"]
|
||||||
@@ -819,6 +845,8 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
message_by_chat[chat_name] = [0] * len(time_points)
|
message_by_chat[chat_name] = [0] * len(time_points)
|
||||||
message_by_chat[chat_name][idx] += 1
|
message_by_chat[chat_name][idx] += 1
|
||||||
|
|
||||||
|
await StatisticOutputTask._yield_control(msg_idx)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"time_labels": time_labels,
|
"time_labels": time_labels,
|
||||||
"total_cost_data": total_cost_data,
|
"total_cost_data": total_cost_data,
|
||||||
|
|||||||
@@ -658,7 +658,9 @@ class LongTermMemoryManager:
|
|||||||
memory.metadata["transfer_time"] = datetime.now().isoformat()
|
memory.metadata["transfer_time"] = datetime.now().isoformat()
|
||||||
|
|
||||||
logger.info(f"✅ 创建长期记忆: {memory.id} (来自短期记忆 {source_stm.id})")
|
logger.info(f"✅ 创建长期记忆: {memory.id} (来自短期记忆 {source_stm.id})")
|
||||||
self._register_temp_id(op.target_id, memory.id, temp_id_map)
|
# 强制注册 target_id,无论它是否符合 placeholder 格式
|
||||||
|
# 这样即使 LLM 使用了中文描述作为 ID (如 "新创建的记忆"), 也能正确映射
|
||||||
|
self._register_temp_id(op.target_id, memory.id, temp_id_map, force=True)
|
||||||
self._register_aliases_from_params(
|
self._register_aliases_from_params(
|
||||||
op.parameters,
|
op.parameters,
|
||||||
memory.id,
|
memory.id,
|
||||||
@@ -766,7 +768,8 @@ class LongTermMemoryManager:
|
|||||||
# 尝试为新节点生成 embedding (异步)
|
# 尝试为新节点生成 embedding (异步)
|
||||||
asyncio.create_task(self._generate_node_embedding(node_id, content))
|
asyncio.create_task(self._generate_node_embedding(node_id, content))
|
||||||
logger.info(f"✅ 创建节点: {content} ({node_type}) -> {memory_id}")
|
logger.info(f"✅ 创建节点: {content} ({node_type}) -> {memory_id}")
|
||||||
self._register_temp_id(op.target_id, node_id, temp_id_map)
|
# 强制注册 target_id,无论它是否符合 placeholder 格式
|
||||||
|
self._register_temp_id(op.target_id, node_id, temp_id_map, force=True)
|
||||||
self._register_aliases_from_params(
|
self._register_aliases_from_params(
|
||||||
op.parameters,
|
op.parameters,
|
||||||
node_id,
|
node_id,
|
||||||
|
|||||||
Reference in New Issue
Block a user