Merge remote-tracking branch 'upstream/debug' into debug
This commit is contained in:
@@ -235,10 +235,10 @@ class ChatBot:
|
||||
is_head=not mark_head,
|
||||
is_emoji=False,
|
||||
)
|
||||
print(f"bot_message: {bot_message}")
|
||||
logger.debug(f"bot_message: {bot_message}")
|
||||
if not mark_head:
|
||||
mark_head = True
|
||||
print(f"添加消息到message_set: {bot_message}")
|
||||
logger.debug(f"添加消息到message_set: {bot_message}")
|
||||
message_set.add_message(bot_message)
|
||||
|
||||
# message_set 可以直接加入 message_manager
|
||||
|
||||
@@ -4,6 +4,8 @@ import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import os
|
||||
|
||||
import requests
|
||||
|
||||
# 解析各种CQ码
|
||||
|
||||
@@ -246,7 +246,7 @@ class EmojiManager:
|
||||
# 即使表情包已存在,也检查是否需要同步到images集合
|
||||
description = existing_emoji.get('discription')
|
||||
# 检查是否在images集合中存在
|
||||
existing_image = image_manager.db.db.images.find_one({'hash': image_hash})
|
||||
existing_image = image_manager.db.images.find_one({'hash': image_hash})
|
||||
if not existing_image:
|
||||
# 同步到images集合
|
||||
image_doc = {
|
||||
@@ -256,7 +256,7 @@ class EmojiManager:
|
||||
'description': description,
|
||||
'timestamp': int(time.time())
|
||||
}
|
||||
image_manager.db.db.images.update_one(
|
||||
image_manager.db.images.update_one(
|
||||
{'hash': image_hash},
|
||||
{'$set': image_doc},
|
||||
upsert=True
|
||||
@@ -318,7 +318,7 @@ class EmojiManager:
|
||||
'description': description,
|
||||
'timestamp': int(time.time())
|
||||
}
|
||||
image_manager.db.db.images.update_one(
|
||||
image_manager.db.images.update_one(
|
||||
{'hash': image_hash},
|
||||
{'$set': image_doc},
|
||||
upsert=True
|
||||
|
||||
@@ -88,13 +88,13 @@ def get_cloest_chat_from_db(db, length: int, timestamp: str):
|
||||
list: 消息记录列表,每个记录包含时间和文本信息
|
||||
"""
|
||||
chat_records = []
|
||||
closest_record = db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)])
|
||||
closest_record = db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)])
|
||||
|
||||
if closest_record:
|
||||
closest_time = closest_record['time']
|
||||
chat_id = closest_record['chat_id'] # 获取chat_id
|
||||
# 获取该时间戳之后的length条消息,保持相同的chat_id
|
||||
chat_records = list(db.db.messages.find(
|
||||
chat_records = list(db.messages.find(
|
||||
{
|
||||
"time": {"$gt": closest_time},
|
||||
"chat_id": chat_id # 添加chat_id过滤
|
||||
@@ -128,7 +128,7 @@ async def get_recent_group_messages(db, chat_id:str, limit: int = 12) -> list:
|
||||
"""
|
||||
|
||||
# 从数据库获取最近消息
|
||||
recent_messages = list(db.db.messages.find(
|
||||
recent_messages = list(db.messages.find(
|
||||
{"chat_id": chat_id},
|
||||
).sort("time", -1).limit(limit))
|
||||
|
||||
@@ -162,7 +162,7 @@ async def get_recent_group_messages(db, chat_id:str, limit: int = 12) -> list:
|
||||
|
||||
|
||||
def get_recent_group_detailed_plain_text(db, chat_stream_id: int, limit: int = 12, combine=False):
|
||||
recent_messages = list(db.db.messages.find(
|
||||
recent_messages = list(db.messages.find(
|
||||
{"chat_id": chat_stream_id},
|
||||
{
|
||||
"time": 1, # 返回时间字段
|
||||
|
||||
@@ -289,6 +289,7 @@ class ImageManager:
|
||||
async def get_image_description(self, image_base64: str) -> str:
|
||||
"""获取普通图片描述,带查重和保存功能"""
|
||||
try:
|
||||
print("处理图片中")
|
||||
# 计算图片哈希
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
@@ -296,12 +297,15 @@ class ImageManager:
|
||||
# 查询缓存的描述
|
||||
cached_description = self._get_description_from_db(image_hash, 'image')
|
||||
if cached_description:
|
||||
print("图片描述缓存中")
|
||||
return f"[图片:{cached_description}]"
|
||||
|
||||
# 调用AI获取描述
|
||||
prompt = "请用中文描述这张图片的内容。如果有文字,请把文字都描述出来。并尝试猜测这个图片的含义。最多200个字。"
|
||||
description, _ = await self._llm.generate_response_for_image(prompt, image_base64)
|
||||
|
||||
print(f"描述是{description}")
|
||||
|
||||
if description is None:
|
||||
logger.warning("AI未能生成图片描述")
|
||||
return "[图片]"
|
||||
|
||||
@@ -5,101 +5,98 @@ from typing import Dict
|
||||
from .config import global_config
|
||||
from .chat_stream import ChatStream
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class WillingManager:
|
||||
def __init__(self):
|
||||
self.chat_reply_willing: Dict[str, float] = {} # 存储每个聊天流的回复意愿
|
||||
self.chat_reply_willing: Dict[str, float] = {} # 存储每个聊天流的回复意愿
|
||||
self._decay_task = None
|
||||
self._started = False
|
||||
|
||||
|
||||
async def _decay_reply_willing(self):
|
||||
"""定期衰减回复意愿"""
|
||||
while True:
|
||||
await asyncio.sleep(5)
|
||||
for chat_id in self.chat_reply_willing:
|
||||
self.chat_reply_willing[chat_id] = max(0, self.chat_reply_willing[chat_id] * 0.6)
|
||||
for chat_id in self.chat_reply_willing:
|
||||
self.chat_reply_willing[chat_id] = max(0, self.chat_reply_willing[chat_id] * 0.6)
|
||||
|
||||
def get_willing(self,chat_stream:ChatStream) -> float:
|
||||
|
||||
def get_willing(self, chat_stream: ChatStream) -> float:
|
||||
"""获取指定聊天流的回复意愿"""
|
||||
stream = chat_stream
|
||||
if stream:
|
||||
return self.chat_reply_willing.get(stream.stream_id, 0)
|
||||
return 0
|
||||
|
||||
|
||||
def set_willing(self, chat_id: str, willing: float):
|
||||
"""设置指定聊天流的回复意愿"""
|
||||
self.chat_reply_willing[chat_id] = willing
|
||||
def set_willing(self, chat_id: str, willing: float):
|
||||
"""设置指定聊天流的回复意愿"""
|
||||
self.chat_reply_willing[chat_id] = willing
|
||||
|
||||
async def change_reply_willing_received(self,
|
||||
chat_stream:ChatStream,
|
||||
topic: str = None,
|
||||
is_mentioned_bot: bool = False,
|
||||
config = None,
|
||||
is_emoji: bool = False,
|
||||
interested_rate: float = 0) -> float:
|
||||
|
||||
async def change_reply_willing_received(
|
||||
self,
|
||||
chat_stream: ChatStream,
|
||||
topic: str = None,
|
||||
is_mentioned_bot: bool = False,
|
||||
config=None,
|
||||
is_emoji: bool = False,
|
||||
interested_rate: float = 0,
|
||||
) -> float:
|
||||
"""改变指定聊天流的回复意愿并返回回复概率"""
|
||||
# 获取或创建聊天流
|
||||
stream = chat_stream
|
||||
chat_id = stream.stream_id
|
||||
|
||||
|
||||
current_willing = self.chat_reply_willing.get(chat_id, 0)
|
||||
|
||||
# print(f"初始意愿: {current_willing}")
|
||||
|
||||
if is_mentioned_bot and current_willing < 1.0:
|
||||
current_willing += 0.9
|
||||
print(f"被提及, 当前意愿: {current_willing}")
|
||||
logger.debug(f"被提及, 当前意愿: {current_willing}")
|
||||
elif is_mentioned_bot:
|
||||
current_willing += 0.05
|
||||
print(f"被重复提及, 当前意愿: {current_willing}")
|
||||
|
||||
logger.debug(f"被重复提及, 当前意愿: {current_willing}")
|
||||
|
||||
if is_emoji:
|
||||
current_willing *= 0.1
|
||||
print(f"表情包, 当前意愿: {current_willing}")
|
||||
|
||||
print(f"放大系数_interested_rate: {global_config.response_interested_rate_amplifier}")
|
||||
interested_rate *= global_config.response_interested_rate_amplifier #放大回复兴趣度
|
||||
logger.debug(f"表情包, 当前意愿: {current_willing}")
|
||||
|
||||
logger.debug(f"放大系数_interested_rate: {global_config.response_interested_rate_amplifier}")
|
||||
interested_rate *= global_config.response_interested_rate_amplifier # 放大回复兴趣度
|
||||
if interested_rate > 0.4:
|
||||
# print(f"兴趣度: {interested_rate}, 当前意愿: {current_willing}")
|
||||
current_willing += interested_rate-0.4
|
||||
|
||||
current_willing *= global_config.response_willing_amplifier #放大回复意愿
|
||||
current_willing += interested_rate - 0.4
|
||||
|
||||
current_willing *= global_config.response_willing_amplifier # 放大回复意愿
|
||||
# print(f"放大系数_willing: {global_config.response_willing_amplifier}, 当前意愿: {current_willing}")
|
||||
|
||||
|
||||
reply_probability = max((current_willing - 0.45) * 2, 0)
|
||||
|
||||
|
||||
# 检查群组权限(如果是群聊)
|
||||
if chat_stream.group_info:
|
||||
if chat_stream.group_info:
|
||||
if chat_stream.group_info.group_id in config.talk_frequency_down_groups:
|
||||
reply_probability = reply_probability / global_config.down_frequency_rate
|
||||
|
||||
reply_probability = min(reply_probability, 1)
|
||||
if reply_probability < 0:
|
||||
reply_probability = 0
|
||||
|
||||
|
||||
self.chat_reply_willing[chat_id] = min(current_willing, 3.0)
|
||||
return reply_probability
|
||||
|
||||
def change_reply_willing_sent(self, chat_stream:ChatStream):
|
||||
|
||||
def change_reply_willing_sent(self, chat_stream: ChatStream):
|
||||
"""开始思考后降低聊天流的回复意愿"""
|
||||
stream = chat_stream
|
||||
if stream:
|
||||
current_willing = self.chat_reply_willing.get(stream.stream_id, 0)
|
||||
self.chat_reply_willing[stream.stream_id] = max(0, current_willing - 2)
|
||||
|
||||
def change_reply_willing_after_sent(self,chat_stream:ChatStream):
|
||||
|
||||
def change_reply_willing_after_sent(self, chat_stream: ChatStream):
|
||||
"""发送消息后提高聊天流的回复意愿"""
|
||||
stream = chat_stream
|
||||
if stream:
|
||||
current_willing = self.chat_reply_willing.get(stream.stream_id, 0)
|
||||
if current_willing < 1:
|
||||
self.chat_reply_willing[stream.stream_id] = min(1, current_willing + 0.2)
|
||||
|
||||
|
||||
async def ensure_started(self):
|
||||
"""确保衰减任务已启动"""
|
||||
if not self._started:
|
||||
@@ -107,5 +104,6 @@ class WillingManager:
|
||||
self._decay_task = asyncio.create_task(self._decay_reply_willing())
|
||||
self._started = True
|
||||
|
||||
|
||||
# 创建全局实例
|
||||
willing_manager = WillingManager()
|
||||
willing_manager = WillingManager()
|
||||
|
||||
@@ -349,7 +349,7 @@ class Hippocampus:
|
||||
def sync_memory_to_db(self):
|
||||
"""检查并同步内存中的图结构与数据库"""
|
||||
# 获取数据库中所有节点和内存中所有节点
|
||||
db_nodes = list(self.memory_graph.db.db.graph_data.nodes.find())
|
||||
db_nodes = list(self.memory_graph.db.graph_data.nodes.find())
|
||||
memory_nodes = list(self.memory_graph.G.nodes(data=True))
|
||||
|
||||
# 转换数据库节点为字典格式,方便查找
|
||||
@@ -377,7 +377,7 @@ class Hippocampus:
|
||||
'created_time': created_time,
|
||||
'last_modified': last_modified
|
||||
}
|
||||
self.memory_graph.db.db.graph_data.nodes.insert_one(node_data)
|
||||
self.memory_graph.db.graph_data.nodes.insert_one(node_data)
|
||||
else:
|
||||
# 获取数据库中节点的特征值
|
||||
db_node = db_nodes_dict[concept]
|
||||
@@ -385,7 +385,7 @@ class Hippocampus:
|
||||
|
||||
# 如果特征值不同,则更新节点
|
||||
if db_hash != memory_hash:
|
||||
self.memory_graph.db.db.graph_data.nodes.update_one(
|
||||
self.memory_graph.db.graph_data.nodes.update_one(
|
||||
{'concept': concept},
|
||||
{'$set': {
|
||||
'memory_items': memory_items,
|
||||
@@ -396,7 +396,7 @@ class Hippocampus:
|
||||
)
|
||||
|
||||
# 处理边的信息
|
||||
db_edges = list(self.memory_graph.db.db.graph_data.edges.find())
|
||||
db_edges = list(self.memory_graph.db.graph_data.edges.find())
|
||||
memory_edges = list(self.memory_graph.G.edges(data=True))
|
||||
|
||||
# 创建边的哈希值字典
|
||||
@@ -428,11 +428,11 @@ class Hippocampus:
|
||||
'created_time': created_time,
|
||||
'last_modified': last_modified
|
||||
}
|
||||
self.memory_graph.db.db.graph_data.edges.insert_one(edge_data)
|
||||
self.memory_graph.db.graph_data.edges.insert_one(edge_data)
|
||||
else:
|
||||
# 检查边的特征值是否变化
|
||||
if db_edge_dict[edge_key]['hash'] != edge_hash:
|
||||
self.memory_graph.db.db.graph_data.edges.update_one(
|
||||
self.memory_graph.db.graph_data.edges.update_one(
|
||||
{'source': source, 'target': target},
|
||||
{'$set': {
|
||||
'hash': edge_hash,
|
||||
@@ -451,7 +451,7 @@ class Hippocampus:
|
||||
self.memory_graph.G.clear()
|
||||
|
||||
# 从数据库加载所有节点
|
||||
nodes = list(self.memory_graph.db.db.graph_data.nodes.find())
|
||||
nodes = list(self.memory_graph.db.graph_data.nodes.find())
|
||||
for node in nodes:
|
||||
concept = node['concept']
|
||||
memory_items = node.get('memory_items', [])
|
||||
@@ -468,7 +468,7 @@ class Hippocampus:
|
||||
if 'last_modified' not in node:
|
||||
update_data['last_modified'] = current_time
|
||||
|
||||
self.memory_graph.db.db.graph_data.nodes.update_one(
|
||||
self.memory_graph.db.graph_data.nodes.update_one(
|
||||
{'concept': concept},
|
||||
{'$set': update_data}
|
||||
)
|
||||
@@ -485,7 +485,7 @@ class Hippocampus:
|
||||
last_modified=last_modified)
|
||||
|
||||
# 从数据库加载所有边
|
||||
edges = list(self.memory_graph.db.db.graph_data.edges.find())
|
||||
edges = list(self.memory_graph.db.graph_data.edges.find())
|
||||
for edge in edges:
|
||||
source = edge['source']
|
||||
target = edge['target']
|
||||
@@ -501,7 +501,7 @@ class Hippocampus:
|
||||
if 'last_modified' not in edge:
|
||||
update_data['last_modified'] = current_time
|
||||
|
||||
self.memory_graph.db.db.graph_data.edges.update_one(
|
||||
self.memory_graph.db.graph_data.edges.update_one(
|
||||
{'source': source, 'target': target},
|
||||
{'$set': update_data}
|
||||
)
|
||||
|
||||
@@ -56,13 +56,13 @@ def get_cloest_chat_from_db(db, length: int, timestamp: str):
|
||||
list: 消息记录字典列表,每个字典包含消息内容和时间信息
|
||||
"""
|
||||
chat_records = []
|
||||
closest_record = db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)])
|
||||
closest_record = db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)])
|
||||
|
||||
if closest_record and closest_record.get('memorized', 0) < 4:
|
||||
closest_time = closest_record['time']
|
||||
group_id = closest_record['group_id']
|
||||
# 获取该时间戳之后的length条消息,且groupid相同
|
||||
records = list(db.db.messages.find(
|
||||
records = list(db.messages.find(
|
||||
{"time": {"$gt": closest_time}, "group_id": group_id}
|
||||
).sort('time', 1).limit(length))
|
||||
|
||||
@@ -74,7 +74,7 @@ def get_cloest_chat_from_db(db, length: int, timestamp: str):
|
||||
return ''
|
||||
|
||||
# 更新memorized值
|
||||
db.db.messages.update_one(
|
||||
db.messages.update_one(
|
||||
{"_id": record["_id"]},
|
||||
{"$set": {"memorized": current_memorized + 1}}
|
||||
)
|
||||
@@ -323,7 +323,7 @@ class Hippocampus:
|
||||
self.memory_graph.G.clear()
|
||||
|
||||
# 从数据库加载所有节点
|
||||
nodes = self.memory_graph.db.db.graph_data.nodes.find()
|
||||
nodes = self.memory_graph.db.graph_data.nodes.find()
|
||||
for node in nodes:
|
||||
concept = node['concept']
|
||||
memory_items = node.get('memory_items', [])
|
||||
@@ -334,7 +334,7 @@ class Hippocampus:
|
||||
self.memory_graph.G.add_node(concept, memory_items=memory_items)
|
||||
|
||||
# 从数据库加载所有边
|
||||
edges = self.memory_graph.db.db.graph_data.edges.find()
|
||||
edges = self.memory_graph.db.graph_data.edges.find()
|
||||
for edge in edges:
|
||||
source = edge['source']
|
||||
target = edge['target']
|
||||
@@ -371,7 +371,7 @@ class Hippocampus:
|
||||
使用特征值(哈希值)快速判断是否需要更新
|
||||
"""
|
||||
# 获取数据库中所有节点和内存中所有节点
|
||||
db_nodes = list(self.memory_graph.db.db.graph_data.nodes.find())
|
||||
db_nodes = list(self.memory_graph.db.graph_data.nodes.find())
|
||||
memory_nodes = list(self.memory_graph.G.nodes(data=True))
|
||||
|
||||
# 转换数据库节点为字典格式,方便查找
|
||||
@@ -394,7 +394,7 @@ class Hippocampus:
|
||||
'memory_items': memory_items,
|
||||
'hash': memory_hash
|
||||
}
|
||||
self.memory_graph.db.db.graph_data.nodes.insert_one(node_data)
|
||||
self.memory_graph.db.graph_data.nodes.insert_one(node_data)
|
||||
else:
|
||||
# 获取数据库中节点的特征值
|
||||
db_node = db_nodes_dict[concept]
|
||||
@@ -403,7 +403,7 @@ class Hippocampus:
|
||||
# 如果特征值不同,则更新节点
|
||||
if db_hash != memory_hash:
|
||||
# logger.info(f"更新节点内容: {concept}")
|
||||
self.memory_graph.db.db.graph_data.nodes.update_one(
|
||||
self.memory_graph.db.graph_data.nodes.update_one(
|
||||
{'concept': concept},
|
||||
{'$set': {
|
||||
'memory_items': memory_items,
|
||||
@@ -416,10 +416,10 @@ class Hippocampus:
|
||||
for db_node in db_nodes:
|
||||
if db_node['concept'] not in memory_concepts:
|
||||
# logger.info(f"删除多余节点: {db_node['concept']}")
|
||||
self.memory_graph.db.db.graph_data.nodes.delete_one({'concept': db_node['concept']})
|
||||
self.memory_graph.db.graph_data.nodes.delete_one({'concept': db_node['concept']})
|
||||
|
||||
# 处理边的信息
|
||||
db_edges = list(self.memory_graph.db.db.graph_data.edges.find())
|
||||
db_edges = list(self.memory_graph.db.graph_data.edges.find())
|
||||
memory_edges = list(self.memory_graph.G.edges())
|
||||
|
||||
# 创建边的哈希值字典
|
||||
@@ -445,12 +445,12 @@ class Hippocampus:
|
||||
'num': 1,
|
||||
'hash': edge_hash
|
||||
}
|
||||
self.memory_graph.db.db.graph_data.edges.insert_one(edge_data)
|
||||
self.memory_graph.db.graph_data.edges.insert_one(edge_data)
|
||||
else:
|
||||
# 检查边的特征值是否变化
|
||||
if db_edge_dict[edge_key]['hash'] != edge_hash:
|
||||
logger.info(f"更新边: {source} - {target}")
|
||||
self.memory_graph.db.db.graph_data.edges.update_one(
|
||||
self.memory_graph.db.graph_data.edges.update_one(
|
||||
{'source': source, 'target': target},
|
||||
{'$set': {'hash': edge_hash}}
|
||||
)
|
||||
@@ -461,7 +461,7 @@ class Hippocampus:
|
||||
if edge_key not in memory_edge_set:
|
||||
source, target = edge_key
|
||||
logger.info(f"删除多余边: {source} - {target}")
|
||||
self.memory_graph.db.db.graph_data.edges.delete_one({
|
||||
self.memory_graph.db.graph_data.edges.delete_one({
|
||||
'source': source,
|
||||
'target': target
|
||||
})
|
||||
@@ -487,9 +487,9 @@ class Hippocampus:
|
||||
topic: 要删除的节点概念
|
||||
"""
|
||||
# 删除节点
|
||||
self.memory_graph.db.db.graph_data.nodes.delete_one({'concept': topic})
|
||||
self.memory_graph.db.graph_data.nodes.delete_one({'concept': topic})
|
||||
# 删除所有涉及该节点的边
|
||||
self.memory_graph.db.db.graph_data.edges.delete_many({
|
||||
self.memory_graph.db.graph_data.edges.delete_many({
|
||||
'$or': [
|
||||
{'source': topic},
|
||||
{'target': topic}
|
||||
|
||||
@@ -115,13 +115,13 @@ def get_cloest_chat_from_db(db, length: int, timestamp: str):
|
||||
list: 消息记录字典列表,每个字典包含消息内容和时间信息
|
||||
"""
|
||||
chat_records = []
|
||||
closest_record = db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)])
|
||||
closest_record = db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)])
|
||||
|
||||
if closest_record and closest_record.get('memorized', 0) < 4:
|
||||
closest_time = closest_record['time']
|
||||
group_id = closest_record['group_id']
|
||||
# 获取该时间戳之后的length条消息,且groupid相同
|
||||
records = list(db.db.messages.find(
|
||||
records = list(db.messages.find(
|
||||
{"time": {"$gt": closest_time}, "group_id": group_id}
|
||||
).sort('time', 1).limit(length))
|
||||
|
||||
@@ -133,7 +133,7 @@ def get_cloest_chat_from_db(db, length: int, timestamp: str):
|
||||
return ''
|
||||
|
||||
# 更新memorized值
|
||||
db.db.messages.update_one(
|
||||
db.messages.update_one(
|
||||
{"_id": record["_id"]},
|
||||
{"$set": {"memorized": current_memorized + 1}}
|
||||
)
|
||||
@@ -163,7 +163,7 @@ class Memory_cortex:
|
||||
default_time = datetime.datetime.now().timestamp()
|
||||
|
||||
# 从数据库加载所有节点
|
||||
nodes = self.memory_graph.db.db.graph_data.nodes.find()
|
||||
nodes = self.memory_graph.db.graph_data.nodes.find()
|
||||
for node in nodes:
|
||||
concept = node['concept']
|
||||
memory_items = node.get('memory_items', [])
|
||||
@@ -180,7 +180,7 @@ class Memory_cortex:
|
||||
created_time = default_time
|
||||
last_modified = default_time
|
||||
# 更新数据库中的节点
|
||||
self.memory_graph.db.db.graph_data.nodes.update_one(
|
||||
self.memory_graph.db.graph_data.nodes.update_one(
|
||||
{'concept': concept},
|
||||
{'$set': {
|
||||
'created_time': created_time,
|
||||
@@ -196,7 +196,7 @@ class Memory_cortex:
|
||||
last_modified=last_modified)
|
||||
|
||||
# 从数据库加载所有边
|
||||
edges = self.memory_graph.db.db.graph_data.edges.find()
|
||||
edges = self.memory_graph.db.graph_data.edges.find()
|
||||
for edge in edges:
|
||||
source = edge['source']
|
||||
target = edge['target']
|
||||
@@ -212,7 +212,7 @@ class Memory_cortex:
|
||||
created_time = default_time
|
||||
last_modified = default_time
|
||||
# 更新数据库中的边
|
||||
self.memory_graph.db.db.graph_data.edges.update_one(
|
||||
self.memory_graph.db.graph_data.edges.update_one(
|
||||
{'source': source, 'target': target},
|
||||
{'$set': {
|
||||
'created_time': created_time,
|
||||
@@ -256,7 +256,7 @@ class Memory_cortex:
|
||||
current_time = datetime.datetime.now().timestamp()
|
||||
|
||||
# 获取数据库中所有节点和内存中所有节点
|
||||
db_nodes = list(self.memory_graph.db.db.graph_data.nodes.find())
|
||||
db_nodes = list(self.memory_graph.db.graph_data.nodes.find())
|
||||
memory_nodes = list(self.memory_graph.G.nodes(data=True))
|
||||
|
||||
# 转换数据库节点为字典格式,方便查找
|
||||
@@ -280,7 +280,7 @@ class Memory_cortex:
|
||||
'created_time': data.get('created_time', current_time),
|
||||
'last_modified': data.get('last_modified', current_time)
|
||||
}
|
||||
self.memory_graph.db.db.graph_data.nodes.insert_one(node_data)
|
||||
self.memory_graph.db.graph_data.nodes.insert_one(node_data)
|
||||
else:
|
||||
# 获取数据库中节点的特征值
|
||||
db_node = db_nodes_dict[concept]
|
||||
@@ -288,7 +288,7 @@ class Memory_cortex:
|
||||
|
||||
# 如果特征值不同,则更新节点
|
||||
if db_hash != memory_hash:
|
||||
self.memory_graph.db.db.graph_data.nodes.update_one(
|
||||
self.memory_graph.db.graph_data.nodes.update_one(
|
||||
{'concept': concept},
|
||||
{'$set': {
|
||||
'memory_items': memory_items,
|
||||
@@ -301,10 +301,10 @@ class Memory_cortex:
|
||||
memory_concepts = set(node[0] for node in memory_nodes)
|
||||
for db_node in db_nodes:
|
||||
if db_node['concept'] not in memory_concepts:
|
||||
self.memory_graph.db.db.graph_data.nodes.delete_one({'concept': db_node['concept']})
|
||||
self.memory_graph.db.graph_data.nodes.delete_one({'concept': db_node['concept']})
|
||||
|
||||
# 处理边的信息
|
||||
db_edges = list(self.memory_graph.db.db.graph_data.edges.find())
|
||||
db_edges = list(self.memory_graph.db.graph_data.edges.find())
|
||||
memory_edges = list(self.memory_graph.G.edges(data=True))
|
||||
|
||||
# 创建边的哈希值字典
|
||||
@@ -332,11 +332,11 @@ class Memory_cortex:
|
||||
'created_time': data.get('created_time', current_time),
|
||||
'last_modified': data.get('last_modified', current_time)
|
||||
}
|
||||
self.memory_graph.db.db.graph_data.edges.insert_one(edge_data)
|
||||
self.memory_graph.db.graph_data.edges.insert_one(edge_data)
|
||||
else:
|
||||
# 检查边的特征值是否变化
|
||||
if db_edge_dict[edge_key]['hash'] != edge_hash:
|
||||
self.memory_graph.db.db.graph_data.edges.update_one(
|
||||
self.memory_graph.db.graph_data.edges.update_one(
|
||||
{'source': source, 'target': target},
|
||||
{'$set': {
|
||||
'hash': edge_hash,
|
||||
@@ -350,7 +350,7 @@ class Memory_cortex:
|
||||
for edge_key in db_edge_dict:
|
||||
if edge_key not in memory_edge_set:
|
||||
source, target = edge_key
|
||||
self.memory_graph.db.db.graph_data.edges.delete_one({
|
||||
self.memory_graph.db.graph_data.edges.delete_one({
|
||||
'source': source,
|
||||
'target': target
|
||||
})
|
||||
@@ -365,9 +365,9 @@ class Memory_cortex:
|
||||
topic: 要删除的节点概念
|
||||
"""
|
||||
# 删除节点
|
||||
self.memory_graph.db.db.graph_data.nodes.delete_one({'concept': topic})
|
||||
self.memory_graph.db.graph_data.nodes.delete_one({'concept': topic})
|
||||
# 删除所有涉及该节点的边
|
||||
self.memory_graph.db.db.graph_data.edges.delete_many({
|
||||
self.memory_graph.db.graph_data.edges.delete_many({
|
||||
'$or': [
|
||||
{'source': topic},
|
||||
{'target': topic}
|
||||
|
||||
@@ -235,7 +235,7 @@ class LLM_request:
|
||||
delta_content = ""
|
||||
accumulated_content += delta_content
|
||||
# 检测流式输出文本是否结束
|
||||
finish_reason = chunk["choices"][0]["finish_reason"]
|
||||
finish_reason = chunk["choices"][0].get("finish_reason")
|
||||
if finish_reason == "stop":
|
||||
usage = chunk.get("usage", None)
|
||||
if usage:
|
||||
|
||||
@@ -13,6 +13,8 @@ from pathlib import Path
|
||||
import jieba
|
||||
from pypinyin import Style, pinyin
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class ChineseTypoGenerator:
|
||||
def __init__(self,
|
||||
@@ -38,7 +40,9 @@ class ChineseTypoGenerator:
|
||||
self.max_freq_diff = max_freq_diff
|
||||
|
||||
# 加载数据
|
||||
print("正在加载汉字数据库,请稍候...")
|
||||
# print("正在加载汉字数据库,请稍候...")
|
||||
logger.info("正在加载汉字数据库,请稍候...")
|
||||
|
||||
self.pinyin_dict = self._create_pinyin_dict()
|
||||
self.char_frequency = self._load_or_create_char_frequency()
|
||||
|
||||
|
||||
@@ -176,7 +176,7 @@ class KnowledgeLibrary:
|
||||
|
||||
try:
|
||||
current_hash = self.calculate_file_hash(file_path)
|
||||
processed_record = self.db.db.processed_files.find_one({"file_path": file_path})
|
||||
processed_record = self.db.processed_files.find_one({"file_path": file_path})
|
||||
|
||||
if processed_record:
|
||||
if processed_record.get("hash") == current_hash:
|
||||
@@ -197,14 +197,14 @@ class KnowledgeLibrary:
|
||||
"split_length": knowledge_length,
|
||||
"created_at": datetime.now()
|
||||
}
|
||||
self.db.db.knowledges.insert_one(knowledge)
|
||||
self.db.knowledges.insert_one(knowledge)
|
||||
result["chunks_processed"] += 1
|
||||
|
||||
split_by = processed_record.get("split_by", []) if processed_record else []
|
||||
if knowledge_length not in split_by:
|
||||
split_by.append(knowledge_length)
|
||||
|
||||
self.db.db.processed_files.update_one(
|
||||
self.db.knowledges.processed_files.update_one(
|
||||
{"file_path": file_path},
|
||||
{
|
||||
"$set": {
|
||||
@@ -322,7 +322,7 @@ class KnowledgeLibrary:
|
||||
{"$project": {"content": 1, "similarity": 1, "file_path": 1}}
|
||||
]
|
||||
|
||||
results = list(self.db.db.knowledges.aggregate(pipeline))
|
||||
results = list(self.db.knowledges.aggregate(pipeline))
|
||||
return results
|
||||
|
||||
# 创建单例实例
|
||||
@@ -346,7 +346,7 @@ if __name__ == "__main__":
|
||||
elif choice == '2':
|
||||
confirm = input("确定要删除所有知识吗?这个操作不可撤销!(y/n): ").strip().lower()
|
||||
if confirm == 'y':
|
||||
knowledge_library.db.db.knowledges.delete_many({})
|
||||
knowledge_library.db.knowledges.delete_many({})
|
||||
console.print("[green]已清空所有知识![/green]")
|
||||
continue
|
||||
elif choice == '1':
|
||||
|
||||
Reference in New Issue
Block a user