Merge branch 'debug' into feature
This commit is contained in:
@@ -2,7 +2,6 @@
|
||||
import os
|
||||
import sys
|
||||
import jieba
|
||||
from llm_module import LLMModel
|
||||
import networkx as nx
|
||||
import matplotlib.pyplot as plt
|
||||
import math
|
||||
@@ -10,10 +9,76 @@ from collections import Counter
|
||||
import datetime
|
||||
import random
|
||||
import time
|
||||
# from chat.config import global_config
|
||||
from dotenv import load_dotenv
|
||||
import sys
|
||||
import asyncio
|
||||
import aiohttp
|
||||
from typing import Tuple
|
||||
|
||||
sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径
|
||||
from src.common.database import Database # 使用正确的导入语法
|
||||
|
||||
# 加载.env.dev文件
|
||||
env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), '.env.dev')
|
||||
load_dotenv(env_path)
|
||||
|
||||
class LLMModel:
|
||||
def __init__(self, model_name=os.getenv("SILICONFLOW_MODEL_V3"), **kwargs):
|
||||
self.model_name = model_name
|
||||
self.params = kwargs
|
||||
self.api_key = os.getenv("SILICONFLOW_KEY")
|
||||
self.base_url = os.getenv("SILICONFLOW_BASE_URL")
|
||||
|
||||
async def generate_response(self, prompt: str) -> Tuple[str, str]:
|
||||
"""根据输入的提示生成模型的响应"""
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
# 构建请求体
|
||||
data = {
|
||||
"model": self.model_name,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"temperature": 0.5,
|
||||
**self.params
|
||||
}
|
||||
|
||||
# 发送请求到完整的chat/completions端点
|
||||
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
|
||||
|
||||
max_retries = 3
|
||||
base_wait_time = 15
|
||||
|
||||
for retry in range(max_retries):
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(api_url, headers=headers, json=data) as response:
|
||||
if response.status == 429:
|
||||
wait_time = base_wait_time * (2 ** retry) # 指数退避
|
||||
print(f"遇到请求限制(429),等待{wait_time}秒后重试...")
|
||||
await asyncio.sleep(wait_time)
|
||||
continue
|
||||
|
||||
response.raise_for_status() # 检查其他响应状态
|
||||
|
||||
result = await response.json()
|
||||
if "choices" in result and len(result["choices"]) > 0:
|
||||
content = result["choices"][0]["message"]["content"]
|
||||
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
|
||||
return content, reasoning_content
|
||||
return "没有返回结果", ""
|
||||
|
||||
except Exception as e:
|
||||
if retry < max_retries - 1: # 如果还有重试机会
|
||||
wait_time = base_wait_time * (2 ** retry)
|
||||
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
||||
await asyncio.sleep(wait_time)
|
||||
else:
|
||||
return f"请求失败: {str(e)}", ""
|
||||
|
||||
return "达到最大重试次数,请求仍然失败", ""
|
||||
|
||||
|
||||
class Memory_graph:
|
||||
def __init__(self):
|
||||
@@ -158,12 +223,12 @@ class Memory_graph:
|
||||
def main():
|
||||
# 初始化数据库
|
||||
Database.initialize(
|
||||
host= os.getenv("MONGODB_HOST"),
|
||||
port= int(os.getenv("MONGODB_PORT")),
|
||||
db_name= os.getenv("DATABASE_NAME"),
|
||||
username= os.getenv("MONGODB_USERNAME"),
|
||||
password= os.getenv("MONGODB_PASSWORD"),
|
||||
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
|
||||
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
|
||||
port=int(os.getenv("MONGODB_PORT", "27017")),
|
||||
db_name=os.getenv("DATABASE_NAME", "MegBot"),
|
||||
username=os.getenv("MONGODB_USERNAME", ""),
|
||||
password=os.getenv("MONGODB_PASSWORD", ""),
|
||||
auth_source=os.getenv("MONGODB_AUTH_SOURCE", "")
|
||||
)
|
||||
|
||||
memory_graph = Memory_graph()
|
||||
@@ -185,11 +250,14 @@ def main():
|
||||
query = input("请输入新的查询概念(输入'退出'以结束):")
|
||||
if query.lower() == '退出':
|
||||
break
|
||||
items_list = memory_graph.get_related_item(query)
|
||||
if items_list:
|
||||
# print(items_list)
|
||||
for memory_item in items_list:
|
||||
print(memory_item)
|
||||
first_layer_items, second_layer_items = memory_graph.get_related_item(query)
|
||||
if first_layer_items or second_layer_items:
|
||||
print("\n第一层记忆:")
|
||||
for item in first_layer_items:
|
||||
print(item)
|
||||
print("\n第二层记忆:")
|
||||
for item in second_layer_items:
|
||||
print(item)
|
||||
else:
|
||||
print("未找到相关记忆。")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user