Adjust the corresponding call sites
@@ -5,25 +5,27 @@ import random
 import time
 import re
 import json
-from itertools import combinations
 
 import jieba
 import networkx as nx
 import numpy as np
 
+from itertools import combinations
+from typing import List, Tuple, Coroutine, Any, Dict, Set
 from collections import Counter
-from ...llm_models.utils_model import LLMRequest
-from rich.traceback import install
 
+from src.llm_models.utils_model import LLMRequest
+from src.config.config import global_config, model_config
+from src.common.database.database_model import Messages, GraphNodes, GraphEdges  # Peewee model imports
 from src.common.logger import get_logger
 from src.chat.memory_system.sample_distribution import MemoryBuildScheduler  # distribution generator
-from ..utils.chat_message_builder import (
+from src.chat.utils.chat_message_builder import (
     get_raw_msg_by_timestamp,
     build_readable_messages,
     get_raw_msg_by_timestamp_with_chat,
 )  # import build_readable_messages
-from ..utils.utils import translate_timestamp_to_human_readable
+from rich.traceback import install
+from src.chat.utils.utils import translate_timestamp_to_human_readable
+
 
-from ...config.config import global_config
-from src.common.database.database_model import Messages, GraphNodes, GraphEdges  # Peewee model imports
 
 install(extra_lines=3)
@@ -198,8 +200,7 @@ class Hippocampus:
         self.parahippocampal_gyrus = ParahippocampalGyrus(self)
         # Load the memory graph from the database
         self.entorhinal_cortex.sync_memory_from_db()
-        # TODO: API-Adapter migration marker
-        self.model_summary = LLMRequest(global_config.model.memory, request_type="memory.builder")
+        self.model_summary = LLMRequest(model_set=model_config.model_task_config.memory, request_type="memory.builder")
 
     def get_all_node_names(self) -> list:
         """Return a list of the names of all nodes in the memory graph."""
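
Note on the pattern above: the LLMRequest constructor now takes a model_set from model_config.model_task_config instead of a global_config model entry, and sampling options such as temperature move onto each generate call (see the InstantMemory and MemoryActivator hunks below). A minimal, self-contained sketch of that shape — FakeLLMClient and every name in it are illustrative stand-ins, not the project's real API:

    import asyncio
    from dataclasses import dataclass

    @dataclass
    class FakeLLMClient:
        # Stand-in for LLMRequest: configured once per task type...
        model_set: str
        request_type: str

        async def generate_response_async(self, prompt: str, temperature: float = 1.0):
            # ...while sampling options arrive per call, as in the new call sites.
            return f"[{self.model_set} @ T={temperature}] {prompt}", ("reasoning", "model-name", None)

    async def main():
        client = FakeLLMClient(model_set="memory", request_type="memory.builder")
        text, _ = await client.generate_response_async("Summarize this chat.", temperature=0.5)
        print(text)

    asyncio.run(main())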
@@ -339,9 +340,7 @@ class Hippocampus:
         else:
             topic_num = 5  # 51+ chars: 5 keywords (all other long texts)
 
-        topics_response, (reasoning_content, model_name) = await self.model_summary.generate_response_async(
-            self.find_topic_llm(text, topic_num)
-        )
+        topics_response, _ = await self.model_summary.generate_response_async(self.find_topic_llm(text, topic_num))
 
         # Extract keywords
         keywords = re.findall(r"<([^>]+)>", topics_response)
@@ -353,12 +352,11 @@ class Hippocampus:
             for keyword in ",".join(keywords).replace(",", ",").replace("、", ",").replace(" ", ",").split(",")
             if keyword.strip()
         ]
 
 
         if keywords:
             logger.info(f"Extracted keywords: {keywords}")
-
-            return keywords
+        return keywords
 
     async def get_memory_from_text(
         self,
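
The keyword pipeline kept by the two hunks above asks the model to wrap topics in angle brackets, then normalizes mixed Chinese/ASCII separators before splitting. A self-contained sketch of just that post-processing — the sample response string is invented:

    import re

    topics_response = "<天气,旅行>、<美食 计划>"
    # Pull out everything the model wrapped in <...>
    keywords = re.findall(r"<([^>]+)>", topics_response)
    # Normalize fullwidth commas, 、, and spaces to a single separator
    keywords = [
        kw.strip()
        for kw in ",".join(keywords).replace(",", ",").replace("、", ",").replace(" ", ",").split(",")
        if kw.strip()
    ]
    print(keywords)  # ['天气', '旅行', '美食', '计划']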
@@ -1245,7 +1243,7 @@ class ParahippocampalGyrus:
 
         # 2. Use the LLM to extract key topics
         topic_num = self.hippocampus.calculate_topic_num(input_text, compress_rate)
-        topics_response, (reasoning_content, model_name) = await self.hippocampus.model_summary.generate_response_async(
+        topics_response, _ = await self.hippocampus.model_summary.generate_response_async(
             self.hippocampus.find_topic_llm(input_text, topic_num)
         )
 
@@ -1269,7 +1267,7 @@ class ParahippocampalGyrus:
         logger.debug(f"Filtered topics: {filtered_topics}")
 
         # 4. Create summary-generation tasks for all topics
-        tasks = []
+        tasks: List[Tuple[str, Coroutine[Any, Any, Tuple[str, Tuple[str, str, List[Dict[str, Any]] | None]]]]] = []
         for topic in filtered_topics:
             # Call the revised topic_what; time_info is no longer needed
             topic_what_prompt = self.hippocampus.topic_what(input_text, topic)
@@ -1281,7 +1279,7 @@ class ParahippocampalGyrus:
                 continue
 
         # Wait for all tasks to finish
-        compressed_memory = set()
+        compressed_memory: Set[Tuple[str, str]] = set()
         similar_topics_dict = {}
 
         for topic, task in tasks:
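
The new annotations in the two hunks above describe a list of (topic, coroutine) pairs awaited later, plus a set of (topic, summary) tuples. A minimal runnable sketch of that pattern, with an invented summarize stand-in for the project's LLM call:

    import asyncio
    from typing import Any, Coroutine, List, Set, Tuple

    async def summarize(topic: str) -> str:
        await asyncio.sleep(0)  # stand-in for an LLM request
        return f"summary of {topic}"

    async def main():
        # Each topic is paired with an un-awaited coroutine...
        tasks: List[Tuple[str, Coroutine[Any, Any, str]]] = [
            (topic, summarize(topic)) for topic in ("weather", "travel")
        ]
        # ...and the results are collected as (topic, summary) tuples.
        compressed_memory: Set[Tuple[str, str]] = set()
        for topic, task in tasks:
            compressed_memory.add((topic, await task))
        print(sorted(compressed_memory))

    asyncio.run(main())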
@@ -3,13 +3,16 @@ import time
 import re
 import json
 import ast
-from json_repair import repair_json
-from src.llm_models.utils_model import LLMRequest
-from src.common.logger import get_logger
 import traceback
 
-from src.config.config import global_config
+from json_repair import repair_json
+from datetime import datetime, timedelta
+
+from src.llm_models.utils_model import LLMRequest
+from src.common.logger import get_logger
 from src.common.database.database_model import Memory  # Peewee model import
+from src.config.config import model_config
+
 
 logger = get_logger(__name__)
 
@@ -35,8 +38,7 @@ class InstantMemory:
         self.chat_id = chat_id
         self.last_view_time = time.time()
         self.summary_model = LLMRequest(
-            model=global_config.model.memory,
-            temperature=0.5,
+            model_set=model_config.model_task_config.memory,
             request_type="memory.summary",
         )
 
@@ -48,14 +50,11 @@ class InstantMemory:
         """
 
         try:
-            response, _ = await self.summary_model.generate_response_async(prompt)
+            response, _ = await self.summary_model.generate_response_async(prompt, temperature=0.5)
             print(prompt)
             print(response)
 
-            if "1" in response:
-                return True
-            else:
-                return False
+            return "1" in response
         except Exception as e:
             logger.error(f"Error while judging whether to remember: {str(e)} {traceback.format_exc()}")
             return False
@@ -71,9 +70,9 @@ class InstantMemory:
         }}
         """
         try:
-            response, _ = await self.summary_model.generate_response_async(prompt)
-            print(prompt)
-            print(response)
+            response, _ = await self.summary_model.generate_response_async(prompt, temperature=0.5)
+            # print(prompt)
+            # print(response)
             if not response:
                 return None
             try:
@@ -142,7 +141,7 @@ class InstantMemory:
         Please output JSON format only, without any extra content
         """
         try:
-            response, _ = await self.summary_model.generate_response_async(prompt)
+            response, _ = await self.summary_model.generate_response_async(prompt, temperature=0.5)
             print(prompt)
             print(response)
             if not response:
@@ -177,7 +176,7 @@ class InstantMemory:
 
         for mem in query:
             # For each memory
-            mem_keywords = mem.keywords or []
+            mem_keywords = mem.keywords or ""
             parsed = ast.literal_eval(mem_keywords)
             if isinstance(parsed, list):
                 mem_keywords = [str(k).strip() for k in parsed if str(k).strip()]
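
The change above feeds mem.keywords — keywords persisted as text — to ast.literal_eval. Worth noting: literal_eval raises SyntaxError on an empty string, so the or "" fallback still needs a guard or an enclosing except in the caller. A hedged sketch of the parse step with an explicit guard; the sample data is invented:

    import ast

    stored = "['天气', '旅行', ' ']"  # a Python-literal string, as stored in the DB
    mem_keywords = stored or ""
    # Guard the empty case explicitly before literal_eval
    parsed = ast.literal_eval(mem_keywords) if mem_keywords else []
    if isinstance(parsed, list):
        mem_keywords = [str(k).strip() for k in parsed if str(k).strip()]
    print(mem_keywords)  # ['天气', '旅行']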
@@ -201,6 +200,7 @@ class InstantMemory:
         return None
 
     def _parse_time_range(self, time_str):
+        # sourcery skip: extract-duplicate-method, use-contextlib-suppress
         """
         Supports parsing the following formats:
         - Exact datetime: YYYY-MM-DD HH:MM:SS
@@ -208,8 +208,6 @@ class InstantMemory:
         - Relative times: 今天 (today), 昨天 (yesterday), 前天 (the day before), N天前 (N days ago), N个月前 (N months ago)
         - Empty string: returns (None, None)
         """
-        from datetime import datetime, timedelta
-
         now = datetime.now()
         if not time_str:
             return 0, now
@@ -239,14 +237,12 @@ class InstantMemory:
             start = (now - timedelta(days=2)).replace(hour=0, minute=0, second=0, microsecond=0)
             end = start + timedelta(days=1)
             return start, end
-        m = re.match(r"(\d+)天前", time_str)
-        if m:
+        if m := re.match(r"(\d+)天前", time_str):
             days = int(m.group(1))
             start = (now - timedelta(days=days)).replace(hour=0, minute=0, second=0, microsecond=0)
             end = start + timedelta(days=1)
             return start, end
-        m = re.match(r"(\d+)个月前", time_str)
-        if m:
+        if m := re.match(r"(\d+)个月前", time_str):
             months = int(m.group(1))
             # Approximate a month as 30 days
             start = (now - timedelta(days=months * 30)).replace(hour=0, minute=0, second=0, microsecond=0)
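
The two rewrites above fold the match-then-test idiom into assignment expressions (Python 3.8+): binding the re.match result inside the condition removes the separate assignment line. A self-contained sketch of the same shape for the "N天前" (N days ago) branch:

    import re
    from datetime import datetime, timedelta

    def parse_days_ago(time_str: str):
        now = datetime.now()
        # The walrus operator binds m and tests it in one expression
        if m := re.match(r"(\d+)天前", time_str):  # e.g. "3天前" = "3 days ago"
            days = int(m.group(1))
            start = (now - timedelta(days=days)).replace(hour=0, minute=0, second=0, microsecond=0)
            return start, start + timedelta(days=1)
        return None, None

    print(parse_days_ago("3天前"))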
@@ -1,13 +1,15 @@
-from src.llm_models.utils_model import LLMRequest
-from src.config.config import global_config
-from src.common.logger import get_logger
-from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
-from datetime import datetime
-from src.chat.memory_system.Hippocampus import hippocampus_manager
-from typing import List, Dict
 import difflib
 import json
 
 from json_repair import repair_json
+from typing import List, Dict
+from datetime import datetime
+
+from src.llm_models.utils_model import LLMRequest
+from src.config.config import global_config, model_config
+from src.common.logger import get_logger
+from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
+from src.chat.memory_system.Hippocampus import hippocampus_manager
+
 
 logger = get_logger("memory_activator")
@@ -61,11 +63,8 @@ def init_prompt():
 
 class MemoryActivator:
     def __init__(self):
-        # TODO: API-Adapter migration marker
-
         self.key_words_model = LLMRequest(
-            model=global_config.model.utils_small,
-            temperature=0.5,
+            model_set=model_config.model_task_config.utils_small,
             request_type="memory.activator",
         )
 
@@ -92,7 +91,9 @@ class MemoryActivator:
 
         # logger.debug(f"prompt: {prompt}")
 
-        response, (reasoning_content, model_name) = await self.key_words_model.generate_response_async(prompt)
+        response, (reasoning_content, model_name, _) = await self.key_words_model.generate_response_async(
+            prompt, temperature=0.5
+        )
 
         keywords = list(get_keywords_from_json(response))
 
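The call site above now unpacks a three-element metadata tuple; the old two-element destructuring would raise ValueError against the widened return shape. A minimal sketch of that unpacking — the stand-in function and its return values are invented, not the project's real API:

    import asyncio

    async def generate_response_async(prompt: str, temperature: float = 1.0):
        # Stand-in for the model call; the third slot represents whatever
        # extra field the new API returns (assumed here for illustration).
        return f"echo: {prompt}", ("reasoning text", "model-name", None)

    async def main():
        # New three-element unpacking; a two-element pattern would raise ValueError
        response, (reasoning_content, model_name, _) = await generate_response_async(
            "Find keywords.", temperature=0.5
        )
        print(response, "| from", model_name)

    asyncio.run(main())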