typing fix
This commit is contained in:
@@ -9,51 +9,49 @@ from src.common.logger import get_logger
|
||||
import traceback
|
||||
|
||||
from src.config.config import global_config
|
||||
from src.common.database.database_model import Memory # Peewee Models导入
|
||||
from src.common.database.database_model import Memory # Peewee Models导入
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class MemoryItem:
|
||||
def __init__(self,memory_id:str,chat_id:str,memory_text:str,keywords:list[str]):
|
||||
def __init__(self, memory_id: str, chat_id: str, memory_text: str, keywords: list[str]):
|
||||
self.memory_id = memory_id
|
||||
self.chat_id = chat_id
|
||||
self.memory_text:str = memory_text
|
||||
self.keywords:list[str] = keywords
|
||||
self.create_time:float = time.time()
|
||||
self.last_view_time:float = time.time()
|
||||
|
||||
self.memory_text: str = memory_text
|
||||
self.keywords: list[str] = keywords
|
||||
self.create_time: float = time.time()
|
||||
self.last_view_time: float = time.time()
|
||||
|
||||
|
||||
class MemoryManager:
|
||||
def __init__(self):
|
||||
# self.memory_items:list[MemoryItem] = []
|
||||
pass
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class InstantMemory:
|
||||
def __init__(self,chat_id):
|
||||
self.chat_id = chat_id
|
||||
def __init__(self, chat_id):
|
||||
self.chat_id = chat_id
|
||||
self.last_view_time = time.time()
|
||||
self.summary_model = LLMRequest(
|
||||
model=global_config.model.memory,
|
||||
temperature=0.5,
|
||||
request_type="memory.summary",
|
||||
)
|
||||
|
||||
async def if_need_build(self,text):
|
||||
|
||||
async def if_need_build(self, text):
|
||||
prompt = f"""
|
||||
请判断以下内容中是否有值得记忆的信息,如果有,请输出1,否则输出0
|
||||
{text}
|
||||
请只输出1或0就好
|
||||
"""
|
||||
|
||||
|
||||
try:
|
||||
response,_ = await self.summary_model.generate_response_async(prompt)
|
||||
response, _ = await self.summary_model.generate_response_async(prompt)
|
||||
print(prompt)
|
||||
print(response)
|
||||
|
||||
|
||||
|
||||
if "1" in response:
|
||||
return True
|
||||
else:
|
||||
@@ -61,8 +59,8 @@ class InstantMemory:
|
||||
except Exception as e:
|
||||
logger.error(f"判断是否需要记忆出现错误:{str(e)} {traceback.format_exc()}")
|
||||
return False
|
||||
|
||||
async def build_memory(self,text):
|
||||
|
||||
async def build_memory(self, text):
|
||||
prompt = f"""
|
||||
以下内容中存在值得记忆的信息,请你从中总结出一段值得记忆的信息,并输出
|
||||
{text}
|
||||
@@ -73,7 +71,7 @@ class InstantMemory:
|
||||
}}
|
||||
"""
|
||||
try:
|
||||
response,_ = await self.summary_model.generate_response_async(prompt)
|
||||
response, _ = await self.summary_model.generate_response_async(prompt)
|
||||
print(prompt)
|
||||
print(response)
|
||||
if not response:
|
||||
@@ -81,53 +79,53 @@ class InstantMemory:
|
||||
try:
|
||||
repaired = repair_json(response)
|
||||
result = json.loads(repaired)
|
||||
memory_text = result.get('memory_text', '')
|
||||
keywords = result.get('keywords', '')
|
||||
memory_text = result.get("memory_text", "")
|
||||
keywords = result.get("keywords", "")
|
||||
if isinstance(keywords, str):
|
||||
keywords_list = [k.strip() for k in keywords.split('/') if k.strip()]
|
||||
keywords_list = [k.strip() for k in keywords.split("/") if k.strip()]
|
||||
elif isinstance(keywords, list):
|
||||
keywords_list = keywords
|
||||
else:
|
||||
keywords_list = []
|
||||
return {'memory_text': memory_text, 'keywords': keywords_list}
|
||||
return {"memory_text": memory_text, "keywords": keywords_list}
|
||||
except Exception as parse_e:
|
||||
logger.error(f"解析记忆json失败:{str(parse_e)} {traceback.format_exc()}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"构建记忆出现错误:{str(e)} {traceback.format_exc()}")
|
||||
return None
|
||||
|
||||
|
||||
async def create_and_store_memory(self,text):
|
||||
async def create_and_store_memory(self, text):
|
||||
if_need = await self.if_need_build(text)
|
||||
if if_need:
|
||||
logger.info(f"需要记忆:{text}")
|
||||
memory = await self.build_memory(text)
|
||||
if memory and memory.get('memory_text'):
|
||||
memory = await self.build_memory(text)
|
||||
if memory and memory.get("memory_text"):
|
||||
memory_id = f"{self.chat_id}_{time.time()}"
|
||||
memory_item = MemoryItem(
|
||||
memory_id=memory_id,
|
||||
chat_id=self.chat_id,
|
||||
memory_text=memory['memory_text'],
|
||||
keywords=memory.get('keywords', [])
|
||||
memory_text=memory["memory_text"],
|
||||
keywords=memory.get("keywords", []),
|
||||
)
|
||||
await self.store_memory(memory_item)
|
||||
else:
|
||||
logger.info(f"不需要记忆:{text}")
|
||||
|
||||
async def store_memory(self,memory_item:MemoryItem):
|
||||
|
||||
async def store_memory(self, memory_item: MemoryItem):
|
||||
memory = Memory(
|
||||
memory_id=memory_item.memory_id,
|
||||
chat_id=memory_item.chat_id,
|
||||
memory_text=memory_item.memory_text,
|
||||
keywords=memory_item.keywords,
|
||||
create_time=memory_item.create_time,
|
||||
last_view_time=memory_item.last_view_time
|
||||
last_view_time=memory_item.last_view_time,
|
||||
)
|
||||
memory.save()
|
||||
|
||||
async def get_memory(self,target:str):
|
||||
|
||||
async def get_memory(self, target: str):
|
||||
from json_repair import repair_json
|
||||
|
||||
prompt = f"""
|
||||
请根据以下发言内容,判断是否需要提取记忆
|
||||
{target}
|
||||
@@ -144,7 +142,7 @@ class InstantMemory:
|
||||
请只输出json格式,不要输出其他多余内容
|
||||
"""
|
||||
try:
|
||||
response,_ = await self.summary_model.generate_response_async(prompt)
|
||||
response, _ = await self.summary_model.generate_response_async(prompt)
|
||||
print(prompt)
|
||||
print(response)
|
||||
if not response:
|
||||
@@ -153,15 +151,15 @@ class InstantMemory:
|
||||
repaired = repair_json(response)
|
||||
result = json.loads(repaired)
|
||||
# 解析keywords
|
||||
keywords = result.get('keywords', '')
|
||||
keywords = result.get("keywords", "")
|
||||
if isinstance(keywords, str):
|
||||
keywords_list = [k.strip() for k in keywords.split('/') if k.strip()]
|
||||
keywords_list = [k.strip() for k in keywords.split("/") if k.strip()]
|
||||
elif isinstance(keywords, list):
|
||||
keywords_list = keywords
|
||||
else:
|
||||
keywords_list = []
|
||||
# 解析time为时间段
|
||||
time_str = result.get('time', '').strip()
|
||||
time_str = result.get("time", "").strip()
|
||||
start_time, end_time = self._parse_time_range(time_str)
|
||||
logger.info(f"start_time: {start_time}, end_time: {end_time}")
|
||||
# 检索包含关键词的记忆
|
||||
@@ -170,16 +168,15 @@ class InstantMemory:
|
||||
start_ts = start_time.timestamp()
|
||||
end_ts = end_time.timestamp()
|
||||
query = Memory.select().where(
|
||||
(Memory.chat_id == self.chat_id) &
|
||||
(Memory.create_time >= start_ts) &
|
||||
(Memory.create_time < end_ts)
|
||||
(Memory.chat_id == self.chat_id)
|
||||
& (Memory.create_time >= start_ts) # type: ignore
|
||||
& (Memory.create_time < end_ts) # type: ignore
|
||||
)
|
||||
else:
|
||||
query = Memory.select().where(Memory.chat_id == self.chat_id)
|
||||
|
||||
|
||||
for mem in query:
|
||||
#对每条记忆
|
||||
# 对每条记忆
|
||||
mem_keywords = mem.keywords or []
|
||||
parsed = ast.literal_eval(mem_keywords)
|
||||
if isinstance(parsed, list):
|
||||
@@ -212,6 +209,7 @@ class InstantMemory:
|
||||
- 空字符串:返回(None, None)
|
||||
"""
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
now = datetime.now()
|
||||
if not time_str:
|
||||
return 0, now
|
||||
@@ -251,8 +249,8 @@ class InstantMemory:
|
||||
if m:
|
||||
months = int(m.group(1))
|
||||
# 近似每月30天
|
||||
start = (now - timedelta(days=months*30)).replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
start = (now - timedelta(days=months * 30)).replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
end = start + timedelta(days=1)
|
||||
return start, end
|
||||
# 其他无法解析
|
||||
return 0, now
|
||||
return 0, now
|
||||
|
||||
Reference in New Issue
Block a user