feat: 强制注册长期记忆目标ID,支持中文描述作为ID映射

This commit is contained in:
Windpicker-owo
2025-11-19 20:19:06 +08:00
parent ff3d2f5ef3
commit 4c30388f18
2 changed files with 47 additions and 16 deletions

View File

@@ -170,6 +170,18 @@ class StatisticOutputTask(AsyncTask):
logger.info("\n" + "\n".join(output)) logger.info("\n" + "\n".join(output))
@staticmethod
async def _yield_control(iteration: int, interval: int = 200) -> None:
"""
<20>ڴ<EFBFBD><DAB4><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ʱ<EFBFBD><CAB1><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>¼<EFBFBD>ѭ<EFBFBD><D1AD><EFBFBD><EFBFBD><EFBFBD>Ӧ
Args:
iteration: <20><>ǰ<EFBFBD><C7B0><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>
interval: ÿ<><C3BF><EFBFBD><EFBFBD><EFBFBD>ٴ<EFBFBD><D9B4>л<EFBFBD>һ<EFBFBD><D2BB>
"""
if iteration % interval == 0:
await asyncio.sleep(0)
async def run(self): async def run(self):
try: try:
now = datetime.now() now = datetime.now()
@@ -304,7 +316,7 @@ class StatisticOutputTask(AsyncTask):
or [] or []
) )
for record in records: for record_idx, record in enumerate(records, 1):
if not isinstance(record, dict): if not isinstance(record, dict):
continue continue
@@ -315,9 +327,9 @@ class StatisticOutputTask(AsyncTask):
if not record_timestamp: if not record_timestamp:
continue continue
for idx, (_, period_start) in enumerate(collect_period): for period_idx, (_, period_start) in enumerate(collect_period):
if record_timestamp >= period_start: if record_timestamp >= period_start:
for period_key, _ in collect_period[idx:]: for period_key, _ in collect_period[period_idx:]:
stats[period_key][TOTAL_REQ_CNT] += 1 stats[period_key][TOTAL_REQ_CNT] += 1
request_type = record.get("request_type") or "unknown" request_type = record.get("request_type") or "unknown"
@@ -372,10 +384,11 @@ class StatisticOutputTask(AsyncTask):
stats[period_key][TIME_COST_BY_PROVIDER][provider_name].append(time_cost) stats[period_key][TIME_COST_BY_PROVIDER][provider_name].append(time_cost)
break break
await StatisticOutputTask._yield_control(record_idx)
# -- 计算派生指标 -- # -- 计算派生指标 --
for period_key, period_stats in stats.items(): for period_key, period_stats in stats.items():
# 计算模型相关指标 # 计算模型相关指标
for model_name, req_count in period_stats[REQ_CNT_BY_MODEL].items(): for model_idx, (model_name, req_count) in enumerate(period_stats[REQ_CNT_BY_MODEL].items(), 1):
total_tok = period_stats[TOTAL_TOK_BY_MODEL][model_name] or 0 total_tok = period_stats[TOTAL_TOK_BY_MODEL][model_name] or 0
total_cost = period_stats[COST_BY_MODEL][model_name] or 0 total_cost = period_stats[COST_BY_MODEL][model_name] or 0
time_costs = period_stats[TIME_COST_BY_MODEL][model_name] or [] time_costs = period_stats[TIME_COST_BY_MODEL][model_name] or []
@@ -390,8 +403,10 @@ class StatisticOutputTask(AsyncTask):
# Avg Tokens per Request # Avg Tokens per Request
period_stats[AVG_TOK_BY_MODEL][model_name] = round(total_tok / req_count) if req_count > 0 else 0 period_stats[AVG_TOK_BY_MODEL][model_name] = round(total_tok / req_count) if req_count > 0 else 0
await StatisticOutputTask._yield_control(model_idx, interval=100)
# 计算供应商相关指标 # 计算供应商相关指标
for provider_name, req_count in period_stats[REQ_CNT_BY_PROVIDER].items(): for provider_idx, (provider_name, req_count) in enumerate(period_stats[REQ_CNT_BY_PROVIDER].items(), 1):
total_tok = period_stats[TOTAL_TOK_BY_PROVIDER][provider_name] total_tok = period_stats[TOTAL_TOK_BY_PROVIDER][provider_name]
total_cost = period_stats[COST_BY_PROVIDER][provider_name] total_cost = period_stats[COST_BY_PROVIDER][provider_name]
time_costs = period_stats[TIME_COST_BY_PROVIDER][provider_name] time_costs = period_stats[TIME_COST_BY_PROVIDER][provider_name]
@@ -404,6 +419,8 @@ class StatisticOutputTask(AsyncTask):
if total_tok > 0: if total_tok > 0:
period_stats[COST_PER_KTOK_BY_PROVIDER][provider_name] = round((total_cost / total_tok) * 1000, 4) period_stats[COST_PER_KTOK_BY_PROVIDER][provider_name] = round((total_cost / total_tok) * 1000, 4)
await StatisticOutputTask._yield_control(provider_idx, interval=100)
# 计算平均耗时和标准差 # 计算平均耗时和标准差
for category_key, items in [ for category_key, items in [
(REQ_CNT_BY_USER, "user"), (REQ_CNT_BY_USER, "user"),
@@ -414,7 +431,7 @@ class StatisticOutputTask(AsyncTask):
time_cost_key = f"time_costs_by_{items.lower()}" time_cost_key = f"time_costs_by_{items.lower()}"
avg_key = f"avg_time_costs_by_{items.lower()}" avg_key = f"avg_time_costs_by_{items.lower()}"
std_key = f"std_time_costs_by_{items.lower()}" std_key = f"std_time_costs_by_{items.lower()}"
for item_name in period_stats[category_key]: for idx, item_name in enumerate(period_stats[category_key], 1):
time_costs = period_stats[time_cost_key][item_name] time_costs = period_stats[time_cost_key][item_name]
if time_costs: if time_costs:
avg_time = sum(time_costs) / len(time_costs) avg_time = sum(time_costs) / len(time_costs)
@@ -428,6 +445,8 @@ class StatisticOutputTask(AsyncTask):
period_stats[avg_key][item_name] = 0.0 period_stats[avg_key][item_name] = 0.0
period_stats[std_key][item_name] = 0.0 period_stats[std_key][item_name] = 0.0
await StatisticOutputTask._yield_control(idx, interval=200)
# 准备图表数据 # 准备图表数据
# 按供应商花费饼图 # 按供应商花费饼图
provider_costs = period_stats[COST_BY_PROVIDER] provider_costs = period_stats[COST_BY_PROVIDER]
@@ -479,7 +498,7 @@ class StatisticOutputTask(AsyncTask):
or [] or []
) )
for record in records: for record_idx, record in enumerate(records, 1):
if not isinstance(record, dict): if not isinstance(record, dict):
continue continue
@@ -494,12 +513,12 @@ class StatisticOutputTask(AsyncTask):
if not record_end_timestamp or not record_start_timestamp: if not record_end_timestamp or not record_start_timestamp:
continue continue
for idx, (_, period_boundary_start) in enumerate(collect_period): for boundary_idx, (_, period_boundary_start) in enumerate(collect_period):
if record_end_timestamp >= period_boundary_start: if record_end_timestamp >= period_boundary_start:
# Calculate effective end time for this record in relation to 'now' # Calculate effective end time for this record in relation to 'now'
effective_end_time = min(record_end_timestamp, now) effective_end_time = min(record_end_timestamp, now)
for period_key, current_period_start_time in collect_period[idx:]: for period_key, current_period_start_time in collect_period[boundary_idx:]:
# Determine the portion of the record that falls within this specific statistical period # Determine the portion of the record that falls within this specific statistical period
overlap_start = max(record_start_timestamp, current_period_start_time) overlap_start = max(record_start_timestamp, current_period_start_time)
overlap_end = effective_end_time # Already capped by 'now' and record's own end overlap_end = effective_end_time # Already capped by 'now' and record's own end
@@ -507,6 +526,8 @@ class StatisticOutputTask(AsyncTask):
if overlap_end > overlap_start: if overlap_end > overlap_start:
stats[period_key][ONLINE_TIME] += (overlap_end - overlap_start).total_seconds() stats[period_key][ONLINE_TIME] += (overlap_end - overlap_start).total_seconds()
break break
await StatisticOutputTask._yield_control(record_idx)
return stats return stats
async def _collect_message_count_for_period(self, collect_period: list[tuple[str, datetime]]) -> dict[str, Any]: async def _collect_message_count_for_period(self, collect_period: list[tuple[str, datetime]]) -> dict[str, Any]:
@@ -538,7 +559,7 @@ class StatisticOutputTask(AsyncTask):
or [] or []
) )
for message in records: for message_idx, message in enumerate(records, 1):
if not isinstance(message, dict): if not isinstance(message, dict):
continue continue
message_time_ts = message.get("time") # This is a float timestamp message_time_ts = message.get("time") # This is a float timestamp
@@ -572,12 +593,15 @@ class StatisticOutputTask(AsyncTask):
self.name_mapping[chat_id] = (chat_name, message_time_ts) self.name_mapping[chat_id] = (chat_name, message_time_ts)
else: else:
self.name_mapping[chat_id] = (chat_name, message_time_ts) self.name_mapping[chat_id] = (chat_name, message_time_ts)
for idx, (_, period_start_dt) in enumerate(collect_period): for period_idx, (_, period_start_dt) in enumerate(collect_period):
if message_time_ts >= period_start_dt.timestamp(): if message_time_ts >= period_start_dt.timestamp():
for period_key, _ in collect_period[idx:]: for period_key, _ in collect_period[period_idx:]:
stats[period_key][TOTAL_MSG_CNT] += 1 stats[period_key][TOTAL_MSG_CNT] += 1
stats[period_key][MSG_CNT_BY_CHAT][chat_id] += 1 stats[period_key][MSG_CNT_BY_CHAT][chat_id] += 1
break break
await StatisticOutputTask._yield_control(message_idx)
return stats return stats
async def _collect_all_statistics(self, now: datetime) -> dict[str, dict[str, Any]]: async def _collect_all_statistics(self, now: datetime) -> dict[str, dict[str, Any]]:
@@ -767,7 +791,7 @@ class StatisticOutputTask(AsyncTask):
) )
or [] or []
) )
for record in llm_records: for record_idx, record in enumerate(llm_records, 1):
if not isinstance(record, dict) or not record.get("timestamp"): if not isinstance(record, dict) or not record.get("timestamp"):
continue continue
record_time = record["timestamp"] record_time = record["timestamp"]
@@ -791,6 +815,8 @@ class StatisticOutputTask(AsyncTask):
cost_by_module[module_name] = [0.0] * len(time_points) cost_by_module[module_name] = [0.0] * len(time_points)
cost_by_module[module_name][idx] += cost cost_by_module[module_name][idx] += cost
await StatisticOutputTask._yield_control(record_idx)
# 单次查询 Messages # 单次查询 Messages
msg_records = ( msg_records = (
await db_get( await db_get(
@@ -800,7 +826,7 @@ class StatisticOutputTask(AsyncTask):
) )
or [] or []
) )
for msg in msg_records: for msg_idx, msg in enumerate(msg_records, 1):
if not isinstance(msg, dict) or not msg.get("time"): if not isinstance(msg, dict) or not msg.get("time"):
continue continue
msg_ts = msg["time"] msg_ts = msg["time"]
@@ -819,6 +845,8 @@ class StatisticOutputTask(AsyncTask):
message_by_chat[chat_name] = [0] * len(time_points) message_by_chat[chat_name] = [0] * len(time_points)
message_by_chat[chat_name][idx] += 1 message_by_chat[chat_name][idx] += 1
await StatisticOutputTask._yield_control(msg_idx)
return { return {
"time_labels": time_labels, "time_labels": time_labels,
"total_cost_data": total_cost_data, "total_cost_data": total_cost_data,

View File

@@ -658,7 +658,9 @@ class LongTermMemoryManager:
memory.metadata["transfer_time"] = datetime.now().isoformat() memory.metadata["transfer_time"] = datetime.now().isoformat()
logger.info(f"✅ 创建长期记忆: {memory.id} (来自短期记忆 {source_stm.id})") logger.info(f"✅ 创建长期记忆: {memory.id} (来自短期记忆 {source_stm.id})")
self._register_temp_id(op.target_id, memory.id, temp_id_map) # 强制注册 target_id无论它是否符合 placeholder 格式
# 这样即使 LLM 使用了中文描述作为 ID (如 "新创建的记忆"), 也能正确映射
self._register_temp_id(op.target_id, memory.id, temp_id_map, force=True)
self._register_aliases_from_params( self._register_aliases_from_params(
op.parameters, op.parameters,
memory.id, memory.id,
@@ -766,7 +768,8 @@ class LongTermMemoryManager:
# 尝试为新节点生成 embedding (异步) # 尝试为新节点生成 embedding (异步)
asyncio.create_task(self._generate_node_embedding(node_id, content)) asyncio.create_task(self._generate_node_embedding(node_id, content))
logger.info(f"✅ 创建节点: {content} ({node_type}) -> {memory_id}") logger.info(f"✅ 创建节点: {content} ({node_type}) -> {memory_id}")
self._register_temp_id(op.target_id, node_id, temp_id_map) # 强制注册 target_id无论它是否符合 placeholder 格式
self._register_temp_id(op.target_id, node_id, temp_id_map, force=True)
self._register_aliases_from_params( self._register_aliases_from_params(
op.parameters, op.parameters,
node_id, node_id,