feat:测试性的新辅助记忆系统

This commit is contained in:
SengokuCola
2025-07-16 16:11:56 +08:00
parent e2ce6a14f4
commit 4aff3c8005
8 changed files with 306 additions and 8 deletions

View File

@@ -0,0 +1,258 @@
# -*- coding: utf-8 -*-
import time
import re
import json
import ast
from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger
import traceback
from src.config.config import global_config
from src.common.database.database_model import Memory # Peewee Models导入
logger = get_logger(__name__)
class MemoryItem:
def __init__(self,memory_id:str,chat_id:str,memory_text:str,keywords:list[str]):
self.memory_id = memory_id
self.chat_id = chat_id
self.memory_text:str = memory_text
self.keywords:list[str] = keywords
self.create_time:float = time.time()
self.last_view_time:float = time.time()
class MemoryManager:
def __init__(self):
# self.memory_items:list[MemoryItem] = []
pass
class InstantMemory:
def __init__(self,chat_id):
self.chat_id = chat_id
self.last_view_time = time.time()
self.summary_model = LLMRequest(
model=global_config.model.memory,
temperature=0.5,
request_type="memory.summary",
)
async def if_need_build(self,text):
prompt = f"""
请判断以下内容中是否有值得记忆的信息如果有请输出1否则输出0
{text}
请只输出1或0就好
"""
try:
response,_ = await self.summary_model.generate_response_async(prompt)
print(prompt)
print(response)
if "1" in response:
return True
else:
return False
except Exception as e:
logger.error(f"判断是否需要记忆出现错误:{str(e)} {traceback.format_exc()}")
return False
async def build_memory(self,text):
prompt = f"""
以下内容中存在值得记忆的信息,请你从中总结出一段值得记忆的信息,并输出
{text}
请以json格式输出一段概括的记忆内容和关键词
{{
"memory_text": "记忆内容",
"keywords": "关键词,用/划分"
}}
"""
try:
response,_ = await self.summary_model.generate_response_async(prompt)
print(prompt)
print(response)
if not response:
return None
try:
repaired = repair_json(response)
result = json.loads(repaired)
memory_text = result.get('memory_text', '')
keywords = result.get('keywords', '')
if isinstance(keywords, str):
keywords_list = [k.strip() for k in keywords.split('/') if k.strip()]
elif isinstance(keywords, list):
keywords_list = keywords
else:
keywords_list = []
return {'memory_text': memory_text, 'keywords': keywords_list}
except Exception as parse_e:
logger.error(f"解析记忆json失败{str(parse_e)} {traceback.format_exc()}")
return None
except Exception as e:
logger.error(f"构建记忆出现错误:{str(e)} {traceback.format_exc()}")
return None
async def create_and_store_memory(self,text):
if_need = await self.if_need_build(text)
if if_need:
logger.info(f"需要记忆:{text}")
memory = await self.build_memory(text)
if memory and memory.get('memory_text'):
memory_id = f"{self.chat_id}_{time.time()}"
memory_item = MemoryItem(
memory_id=memory_id,
chat_id=self.chat_id,
memory_text=memory['memory_text'],
keywords=memory.get('keywords', [])
)
await self.store_memory(memory_item)
else:
logger.info(f"不需要记忆:{text}")
async def store_memory(self,memory_item:MemoryItem):
memory = Memory(
memory_id=memory_item.memory_id,
chat_id=memory_item.chat_id,
memory_text=memory_item.memory_text,
keywords=memory_item.keywords,
create_time=memory_item.create_time,
last_view_time=memory_item.last_view_time
)
memory.save()
async def get_memory(self,target:str):
from json_repair import repair_json
prompt = f"""
请根据以下发言内容,判断是否需要提取记忆
{target}
请用json格式输出包含以下字段
其中time的要求是
可以选择具体日期时间格式为YYYY-MM-DD HH:MM:SS或者大致时间格式为YYYY-MM-DD
可以选择相对时间例如今天昨天前天5天前1个月前
可以选择留空进行模糊搜索
{{
"need_memory": 1,
"keywords": "希望获取的记忆关键词,用/划分",
"time": "希望获取的记忆大致时间"
}}
请只输出json格式不要输出其他多余内容
"""
try:
response,_ = await self.summary_model.generate_response_async(prompt)
print(prompt)
print(response)
if not response:
return None
try:
repaired = repair_json(response)
result = json.loads(repaired)
# 解析keywords
keywords = result.get('keywords', '')
if isinstance(keywords, str):
keywords_list = [k.strip() for k in keywords.split('/') if k.strip()]
elif isinstance(keywords, list):
keywords_list = keywords
else:
keywords_list = []
# 解析time为时间段
time_str = result.get('time', '').strip()
start_time, end_time = self._parse_time_range(time_str)
logger.info(f"start_time: {start_time}, end_time: {end_time}")
# 检索包含关键词的记忆
memories_set = set()
if start_time and end_time:
start_ts = start_time.timestamp()
end_ts = end_time.timestamp()
query = Memory.select().where(
(Memory.chat_id == self.chat_id) &
(Memory.create_time >= start_ts) &
(Memory.create_time < end_ts)
)
else:
query = Memory.select().where(Memory.chat_id == self.chat_id)
for mem in query:
#对每条记忆
mem_keywords = mem.keywords or []
parsed = ast.literal_eval(mem_keywords)
if isinstance(parsed, list):
mem_keywords = [str(k).strip() for k in parsed if str(k).strip()]
else:
mem_keywords = []
# logger.info(f"mem_keywords: {mem_keywords}")
# logger.info(f"keywords_list: {keywords_list}")
for kw in keywords_list:
# logger.info(f"kw: {kw}")
# logger.info(f"kw in mem_keywords: {kw in mem_keywords}")
if kw in mem_keywords:
# logger.info(f"mem.memory_text: {mem.memory_text}")
memories_set.add(mem.memory_text)
break
return list(memories_set)
except Exception as parse_e:
logger.error(f"解析记忆json失败{str(parse_e)} {traceback.format_exc()}")
return None
except Exception as e:
logger.error(f"获取记忆出现错误:{str(e)} {traceback.format_exc()}")
return None
def _parse_time_range(self, time_str):
"""
支持解析如下格式:
- 具体日期时间YYYY-MM-DD HH:MM:SS
- 具体日期YYYY-MM-DD
- 相对时间今天昨天前天N天前N个月前
- 空字符串:返回(None, None)
"""
from datetime import datetime, timedelta
now = datetime.now()
if not time_str:
return 0, now
time_str = time_str.strip()
# 具体日期时间
try:
dt = datetime.strptime(time_str, "%Y-%m-%d %H:%M:%S")
return dt, dt + timedelta(hours=1)
except Exception:
pass
# 具体日期
try:
dt = datetime.strptime(time_str, "%Y-%m-%d")
return dt, dt + timedelta(days=1)
except Exception:
pass
# 相对时间
if time_str == "今天":
start = now.replace(hour=0, minute=0, second=0, microsecond=0)
end = start + timedelta(days=1)
return start, end
if time_str == "昨天":
start = (now - timedelta(days=1)).replace(hour=0, minute=0, second=0, microsecond=0)
end = start + timedelta(days=1)
return start, end
if time_str == "前天":
start = (now - timedelta(days=2)).replace(hour=0, minute=0, second=0, microsecond=0)
end = start + timedelta(days=1)
return start, end
m = re.match(r"(\d+)天前", time_str)
if m:
days = int(m.group(1))
start = (now - timedelta(days=days)).replace(hour=0, minute=0, second=0, microsecond=0)
end = start + timedelta(days=1)
return start, end
m = re.match(r"(\d+)个月前", time_str)
if m:
months = int(m.group(1))
# 近似每月30天
start = (now - timedelta(days=months*30)).replace(hour=0, minute=0, second=0, microsecond=0)
end = start + timedelta(days=1)
return start, end
# 其他无法解析
return 0, now