Merge remote-tracking branch 'upstream/debug' into feature

This commit is contained in:
tcmofashi
2025-03-04 08:18:22 +08:00
34 changed files with 13732 additions and 413 deletions

View File

@@ -17,12 +17,12 @@ driver = get_driver()
config = driver.config
Database.initialize(
host= config.mongodb_host,
port= int(config.mongodb_port),
db_name= config.database_name,
username= config.mongodb_username,
password= config.mongodb_password,
auth_source= config.mongodb_auth_source
host= config.MONGODB_HOST,
port= int(config.MONGODB_PORT),
db_name= config.DATABASE_NAME,
username= config.MONGODB_USERNAME,
password= config.MONGODB_PASSWORD,
auth_source= config.MONGODB_AUTH_SOURCE
)
print("\033[1;32m[初始化数据库完成]\033[0m")

View File

@@ -97,8 +97,13 @@ class ChatBot:
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(message.time))
topic = topic_identifier.identify_topic_jieba(message.processed_plain_text)
print(f"\033[1;32m[主题识别]\033[0m 主题: {topic}")
topic1 = topic_identifier.identify_topic_jieba(message.processed_plain_text)
topic2 = await topic_identifier.identify_topic_llm(message.processed_plain_text)
topic3 = topic_identifier.identify_topic_snownlp(message.processed_plain_text)
print(f"\033[1;32m[主题识别]\033[0m 使用jieba主题: {topic1}")
print(f"\033[1;32m[主题识别]\033[0m 使用llm主题: {topic2}")
print(f"\033[1;32m[主题识别]\033[0m 使用snownlp主题: {topic3}")
topic = topic3
all_num = 0
interested_num = 0
@@ -166,7 +171,6 @@ class ChatBot:
group_id=event.group_id,
user_id=global_config.BOT_QQ,
message_id=think_id,
message_based_id=event.message_id,
raw_message=msg,
plain_text=msg,
processed_plain_text=msg,

View File

@@ -116,6 +116,9 @@ class BotConfig:
if "vlm" in model_config:
config.vlm = model_config["vlm"]
if "embedding" in model_config:
config.embedding = model_config["embedding"]
# 消息配置
if "message" in toml_dict:
@@ -138,7 +141,7 @@ class BotConfig:
if "others" in toml_dict:
others_config = toml_dict["others"]
config.enable_advance_output = others_config.get("enable_advance_output", config.enable_advance_output)
config.enable_kuuki_read = others_config.get("enable_kuuki_read", config.enable_kuuki_read)
logger.success(f"成功加载配置文件: {config_path}")
@@ -152,31 +155,13 @@ bot_config_path = os.path.join(bot_config_floder_path, "bot_config_dev.toml")
if not os.path.exists(bot_config_path):
# 如果开发环境配置文件不存在,则使用默认配置文件
bot_config_path = os.path.join(bot_config_floder_path, "bot_config.toml")
logger.info("使用默认配置文件")
logger.info("使用bot配置文件")
else:
logger.info("已找到开发环境配置文件")
logger.info("已找到开发bot配置文件")
global_config = BotConfig.load_config(config_path=bot_config_path)
@dataclass
class LLMConfig:
"""机器人配置类"""
# 基础配置
SILICONFLOW_API_KEY: str = None
SILICONFLOW_BASE_URL: str = None
DEEP_SEEK_API_KEY: str = None
DEEP_SEEK_BASE_URL: str = None
llm_config = LLMConfig()
config = get_driver().config
llm_config.SILICONFLOW_API_KEY = config.siliconflow_key
llm_config.SILICONFLOW_BASE_URL = config.siliconflow_base_url
llm_config.DEEP_SEEK_API_KEY = config.deep_seek_key
llm_config.DEEP_SEEK_BASE_URL = config.deep_seek_base_url
if not global_config.enable_advance_output:
# logger.remove()
pass

View File

@@ -8,7 +8,7 @@ from ...common.database import Database
from PIL import Image
from .config import global_config
import urllib3
from .utils_user import get_user_nickname,get_user_cardname
from .utils_user import get_user_nickname,get_user_cardname,get_groupname
from .utils_cq import parse_cq_code
from .cq_code import cq_code_tool,CQCode
@@ -21,50 +21,47 @@ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
#它还定义了两个辅助属性keywords用于提取消息的关键词is_plain_text用于判断消息是否为纯文本。
@dataclass
class Message:
"""消息数据类"""
message_id: int = None
time: float = None
group_id: int = None
group_name: str = None # 群名称
user_id: int = None
user_nickname: str = None # 用户昵称
user_cardname: str=None # 用户群昵称
group_name: str = None # 群名称
message_id: int = None
raw_message: str = None
plain_text: str = None
message_based_id: int = None
reply_message: Dict = None # 存储回复消息
raw_message: str = None # 原始消息包含未解析的cq码
plain_text: str = None # 纯文本
message_segments: List[Dict] = None # 存储解析后的消息片段
processed_plain_text: str = None # 用于存储处理后的plain_text
detailed_plain_text: str = None # 用于存储详细可读文本
time: float = None
reply_message: Dict = None # 存储 回复的 源消息
is_emoji: bool = False # 是否是表情包
has_emoji: bool = False # 是否包含表情包
translate_cq: bool = True # 是否翻译cq码
reply_benefits: float = 0.0
type: str = 'received' # 消息类型可以是received或者send
def __post_init__(self):
if self.time is None:
self.time = int(time.time())
if not self.group_name:
self.group_name = get_groupname(self.group_id)
if not self.user_nickname:
self.user_nickname = get_user_nickname(self.user_id)
if not self.user_cardname:
self.user_cardname=get_user_cardname(self.user_id)
if not self.group_name:
self.group_name = self.get_groupname(self.group_id)
if not self.processed_plain_text:
if self.raw_message:
self.message_segments = self.parse_message_segments(str(self.raw_message))
@@ -244,6 +241,38 @@ class MessageSet:
return len(self.messages)
@dataclass
class Message_Sending(Message):
    """Outgoing-message data class; extends Message with send scheduling."""
    priority: int = 0  # send priority; larger number = higher priority
    wait_until: float = None  # earliest timestamp (epoch seconds) at which to send
    continue_thinking: bool = False  # whether to keep thinking after this send

    def __post_init__(self):
        # Default the scheduled send time to the message's own timestamp.
        super().__post_init__()
        if self.wait_until is None:
            self.wait_until = self.time

    @property
    def can_send(self) -> bool:
        """Return True once the scheduled send time has been reached."""
        return time.time() >= self.wait_until

    def set_wait_time(self, seconds: float) -> None:
        """Delay sending until *seconds* from now."""
        self.wait_until = time.time() + seconds

    def set_priority(self, priority: int) -> None:
        """Set the send priority (larger number = higher priority)."""
        self.priority = priority

    def __lt__(self, other):
        """Comparison used for priority ordering in the send queue.

        NOTE(review): with this tuple comparison, an ascending sort places
        LOWER ``priority`` values first, and among equal priorities the LATER
        ``wait_until`` first (because of the negation). Whether that matches
        the "larger number = higher priority" intent depends on whether the
        consumer sorts ascending or uses a min-heap — confirm against the
        queue that consumes these objects.
        """
        if not isinstance(other, Message_Sending):
            return NotImplemented
        return (self.priority, -self.wait_until) < (other.priority, -other.wait_until)

View File

@@ -201,7 +201,7 @@ class MessageSendControl:
print(f"- 群组: {group_id} - 内容: {message.processed_plain_text}")
cost_time = round(time.time(), 2) - message.time
if cost_time > 40:
message.processed_plain_text = cq_code_tool.create_reply_cq(message.message_based_id) + message.processed_plain_text
message.processed_plain_text = cq_code_tool.create_reply_cq(message.message_id) + message.processed_plain_text
cur_time = time.time()
await self._current_bot.send_group_msg(
group_id=group_id,

View File

View File

@@ -127,15 +127,15 @@ class MessageStream:
# 从数据库中查询最近的消息
recent_messages = list(db.db.messages.find(
{"group_id": self.group_id},
{
"time": 1,
"user_id": 1,
"user_nickname": 1,
# "user_cardname": 1,
"message_id": 1,
"raw_message": 1,
"processed_text": 1
}
# {
# "time": 1,
# "user_id": 1,
# "user_nickname": 1,
# # "user_cardname": 1,
# "message_id": 1,
# "raw_message": 1,
# "processed_text": 1
# }
).sort("time", -1).limit(count))
if not recent_messages:
@@ -145,17 +145,21 @@ class MessageStream:
from .message import Message
messages = []
for msg_data in recent_messages:
msg = Message(
time=msg_data["time"],
user_id=msg_data["user_id"],
user_nickname=msg_data.get("user_nickname", ""),
user_cardname=msg_data.get("user_cardname", ""),
message_id=msg_data["message_id"],
raw_message=msg_data["raw_message"],
processed_plain_text=msg_data.get("processed_text", ""),
group_id=self.group_id
)
messages.append(msg)
try:
msg = Message(
time=msg_data["time"],
user_id=msg_data["user_id"],
user_nickname=msg_data.get("user_nickname", ""),
user_cardname=msg_data.get("user_cardname", ""),
message_id=msg_data["message_id"],
raw_message=msg_data["raw_message"],
processed_plain_text=msg_data.get("processed_text", ""),
group_id=self.group_id
)
messages.append(msg)
except KeyError:
print("[WARNING] 数据库中存在无效的消息")
continue
return list(reversed(messages)) # 返回按时间正序的消息

View File

@@ -118,7 +118,7 @@ class PromptBuilder:
prompt_info = self.get_prompt_info(message_txt,threshold=0.5)
if prompt_info:
prompt_info = f'''\n----------------------------------------------------\n你有以下这些[知识]\n{prompt_info}\n请你记住上面的[知识],之后可能会用到\n----------------------------------------------------\n'''
promt_info_prompt = '你有一些[知识],在上面可以参考。'
# promt_info_prompt = '你有一些[知识],在上面可以参考。'
end_time = time.time()
print(f"\033[1;32m[知识检索]\033[0m 耗时: {(end_time - start_time):.3f}")

View File

@@ -0,0 +1,14 @@
#Broca's Area
# 功能:语言产生、语法处理和言语运动控制。
# 损伤后果:布洛卡失语症(表达困难,但理解保留)。
import time
class Thinking_Idea:
    """A unit of in-progress thought: its source messages, current thoughts,
    creation time, and a unique id.

    (File header likens this to Broca's Area: language production / speech
    planning.)
    """

    def __init__(self, message_id: str):
        # NOTE(review): ``message_id`` is accepted but never stored; kept for
        # interface compatibility — confirm whether it should be persisted.
        self.messages = []  # source messages this thought is based on
        self.current_thoughts = []  # thoughts produced so far
        # Single clock read: the original read the clock twice, so ``time``
        # and ``id`` could disagree when the calls straddled a millisecond.
        now = time.time()
        self.time = now  # creation time (epoch seconds)
        self.id = str(int(now * 1000))  # unique id from the ms timestamp

View File

@@ -4,6 +4,8 @@ from .message import Message
import jieba
from nonebot import get_driver
from .config import global_config
from snownlp import SnowNLP
from ..models.utils_model import LLM_request
driver = get_driver()
config = driver.config
@@ -11,12 +13,10 @@ config = driver.config
class TopicIdentifier:
def __init__(self):
self.client = OpenAI(
api_key=config.siliconflow_key, base_url=config.siliconflow_base_url
)
def identify_topic_llm(self, text: str) -> Optional[str]:
"""识别消息主题"""
self.llm_client = LLM_request(model=global_config.llm_normal)
async def identify_topic_llm(self, text: str) -> Optional[List[str]]:
"""识别消息主题,返回主题列表"""
prompt = f"""判断这条消息的主题,如果没有明显主题请回复"无主题",要求:\
1. 主题通常2-4个字必须简短要求精准概括不要太具体。\
@@ -24,36 +24,20 @@ class TopicIdentifier:
3. 这里是
消息内容:{text}"""
response = self.client.chat.completions.create(
model=global_config.SILICONFLOW_MODEL_V3,
messages=[{"role": "user", "content": prompt}],
temperature=0.8,
max_tokens=10,
)
if not response or not response.choices:
print(f"\033[1;31m[错误]\033[0m OpenAI API 返回为空")
# 使用 LLM_request 类进行请求
topic, _ = await self.llm_client.generate_response(prompt)
if not topic:
print(f"\033[1;31m[错误]\033[0m LLM API 返回为空")
return None
# 从 OpenAI API 响应中获取第一个选项的消息内容,并去除首尾空白字符
topic = (
response.choices[0].message.content.strip()
if response.choices[0].message.content
else None
)
if topic == "无主题":
return None
else:
# print(f"[主题分析结果]{text[:20]}... : {topic}")
split_topic = self.parse_topic(topic)
return split_topic
def parse_topic(self, topic: str) -> List[str]:
"""解析主题,返回主题列表"""
# 直接在这里处理主题解析
if not topic or topic == "无主题":
return []
return [t.strip() for t in topic.split(",") if t.strip()]
return None
# 解析主题字符串为列表
topic_list = [t.strip() for t in topic.split(",") if t.strip()]
return topic_list if topic_list else None
def identify_topic_jieba(self, text: str) -> Optional[str]:
"""使用jieba识别主题"""
@@ -239,33 +223,12 @@ class TopicIdentifier:
filtered_words = []
for word in words:
if word not in stop_words and not word.strip() in {
"",
"",
"",
"",
"",
"",
"",
'"',
'"',
""", """,
"",
"",
"",
"",
"",
"",
"",
"",
"·",
"",
"~",
"",
"+",
"=",
"-",
"[",
"]",
'', '', '', '', '', '', '', '"', '"', ''', ''',
'', '', '', '', '', '', '', '', '·', '', '~',
'', '+', '=', '-', '/', '\\', '|', '*', '#', '@', '$', '%',
'^', '&', '[', ']', '{', '}', '<', '>', '`', '_', '.', ',',
';', ':', '\'', '"', '(', ')', '?', '!', '±', '×', '÷', '',
'', '', '', '', '', '', '', '', '', '', ''
}:
filtered_words.append(word)
@@ -280,5 +243,25 @@ class TopicIdentifier:
return top_words if top_words else None
def identify_topic_snownlp(self, text: str) -> Optional[List[str]]:
    """Extract up to three topic keywords from *text* using SnowNLP.

    Args:
        text: the text whose topic should be identified.

    Returns:
        A list of keyword strings, or None for blank input, when no
        keywords are found, or when SnowNLP raises.
    """
    if not text or not text.strip():
        return None
    try:
        extracted = SnowNLP(text).keywords(3)
    except Exception as e:
        print(f"\033[1;31m[错误]\033[0m SnowNLP 处理失败: {str(e)}")
        return None
    return extracted or None
topic_identifier = TopicIdentifier()

View File

@@ -10,6 +10,7 @@ from typing import Dict
from collections import Counter
import math
from nonebot import get_driver
from ..models.utils_model import LLM_request
driver = get_driver()
config = driver.config
@@ -64,25 +65,9 @@ def is_mentioned_bot_in_txt(message: str) -> bool:
return False
def get_embedding(text):
url = "https://api.siliconflow.cn/v1/embeddings"
payload = {
"model": "BAAI/bge-m3",
"input": text,
"encoding_format": "float"
}
headers = {
"Authorization": f"Bearer {config.siliconflow_key}",
"Content-Type": "application/json"
}
response = requests.request("POST", url, json=payload, headers=headers)
if response.status_code != 200:
print(f"API请求失败: {response.status_code}")
print(f"错误信息: {response.text}")
return None
return response.json()['data'][0]['embedding']
"""获取文本的embedding向量"""
llm = LLM_request(model=global_config.embedding)
return llm.get_embedding_sync(text)
def cosine_similarity(v1, v2):
dot_product = np.dot(v1, v2)
@@ -142,14 +127,14 @@ def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
# 从数据库获取最近消息
recent_messages = list(db.db.messages.find(
{"group_id": group_id},
{
"time": 1,
"user_id": 1,
"user_nickname": 1,
"message_id": 1,
"raw_message": 1,
"processed_text": 1
}
# {
# "time": 1,
# "user_id": 1,
# "user_nickname": 1,
# "message_id": 1,
# "raw_message": 1,
# "processed_text": 1
# }
).sort("time", -1).limit(limit))
if not recent_messages:
@@ -159,16 +144,20 @@ def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
from .message import Message
message_objects = []
for msg_data in recent_messages:
msg = Message(
time=msg_data["time"],
user_id=msg_data["user_id"],
user_nickname=msg_data.get("user_nickname", ""),
message_id=msg_data["message_id"],
raw_message=msg_data["raw_message"],
processed_plain_text=msg_data.get("processed_text", ""),
group_id=group_id
)
message_objects.append(msg)
try:
msg = Message(
time=msg_data["time"],
user_id=msg_data["user_id"],
user_nickname=msg_data.get("user_nickname", ""),
message_id=msg_data["message_id"],
raw_message=msg_data["raw_message"],
processed_plain_text=msg_data.get("processed_text", ""),
group_id=group_id
)
message_objects.append(msg)
except KeyError:
print("[WARNING] 数据库中存在无效的消息")
continue
# 按时间正序排列
message_objects.reverse()
@@ -181,7 +170,6 @@ def get_recent_group_detailed_plain_text(db, group_id: int, limit: int = 12,comb
"time": 1, # 返回时间字段
"user_id": 1, # 返回用户ID字段
"user_nickname": 1, # 返回用户昵称字段
"user_cardname": 1, #返回用户群昵称
"message_id": 1, # 返回消息ID字段
"detailed_plain_text": 1 # 返回处理后的文本字段
}
@@ -193,6 +181,8 @@ def get_recent_group_detailed_plain_text(db, group_id: int, limit: int = 12,comb
message_detailed_plain_text = ''
message_detailed_plain_text_list = []
# 反转消息列表,使最新的消息在最后
recent_messages.reverse()
if combine:
for msg_db_data in recent_messages:

View File

@@ -6,8 +6,12 @@ def get_user_nickname(user_id: int) -> str:
return global_config.BOT_NICKNAME
# print(user_id)
return relationship_manager.get_name(user_id)
def get_user_cardname(user_id: int) -> str:
    """Return the group card name for *user_id*.

    The bot's own id maps to its configured nickname; card-name lookup for
    other users is not implemented yet, so they get an empty string.
    """
    if int(user_id) == int(global_config.BOT_QQ):
        return global_config.BOT_NICKNAME
    # Fix: the original had two consecutive ``return ''`` statements (the
    # second unreachable dead code); one is enough.
    return ''
def get_groupname(group_id: int) -> str:
    """Return a display name for the group; currently just the id as text."""
    return str(group_id)

View File

@@ -17,12 +17,12 @@ from src.plugins.chat.config import llm_config
# 直接配置数据库连接信息
Database.initialize(
host= config.mongodb_host,
port= int(config.mongodb_port),
db_name= config.database_name,
username= config.mongodb_username,
password= config.mongodb_password,
auth_source=config.mongodb_auth_source
host= config.MONGODB_HOST,
port= int(config.MONGODB_PORT),
db_name= config.DATABASE_NAME,
username= config.MONGODB_USERNAME,
password= config.MONGODB_PASSWORD,
auth_source=config.MONGODB_AUTH_SOURCE
)
class KnowledgeLibrary:

View File

@@ -2,7 +2,6 @@
import os
import sys
import jieba
from llm_module import LLMModel
import networkx as nx
import matplotlib.pyplot as plt
import math
@@ -10,10 +9,76 @@ from collections import Counter
import datetime
import random
import time
# from chat.config import global_config
from dotenv import load_dotenv
import sys
import asyncio
import aiohttp
from typing import Tuple
sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径
from src.common.database import Database # 使用正确的导入语法
# 加载.env.dev文件
env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), '.env.dev')
load_dotenv(env_path)
class LLMModel:
    """Minimal async client for a SiliconFlow chat-completions endpoint.

    Credentials and base URL come from the environment
    (``SILICONFLOW_KEY`` / ``SILICONFLOW_BASE_URL``); any extra keyword
    arguments given at construction are merged into every request body.
    """

    def __init__(self, model_name=os.getenv("SILICONFLOW_MODEL_V3"), **kwargs):
        self.model_name = model_name
        self.params = kwargs
        self.api_key = os.getenv("SILICONFLOW_KEY")
        self.base_url = os.getenv("SILICONFLOW_BASE_URL")

    async def generate_response(self, prompt: str) -> Tuple[str, str]:
        """Send *prompt* to the model and return ``(content, reasoning)``.

        Retries with exponential backoff on HTTP 429 and on transport
        errors; on terminal failure returns an error string instead of
        raising.
        """
        request_headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }
        request_body = {
            "model": self.model_name,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": 0.5,
            **self.params
        }
        # Full chat/completions endpoint, tolerating a trailing slash in base_url.
        api_url = f"{self.base_url.rstrip('/')}/chat/completions"

        max_retries = 3
        base_wait_time = 15
        for attempt in range(max_retries):
            try:
                async with aiohttp.ClientSession() as session:
                    async with session.post(api_url, headers=request_headers, json=request_body) as response:
                        if response.status == 429:
                            # Rate limited: back off exponentially, then retry.
                            delay = base_wait_time * (2 ** attempt)
                            print(f"遇到请求限制(429),等待{delay}秒后重试...")
                            await asyncio.sleep(delay)
                            continue
                        response.raise_for_status()
                        payload = await response.json()
                        if "choices" in payload and len(payload["choices"]) > 0:
                            message = payload["choices"][0]["message"]
                            return message["content"], message.get("reasoning_content", "")
                        return "没有返回结果", ""
            except Exception as e:
                if attempt == max_retries - 1:
                    return f"请求失败: {str(e)}", ""
                delay = base_wait_time * (2 ** attempt)
                print(f"请求失败,等待{delay}秒后重试... 错误: {str(e)}")
                await asyncio.sleep(delay)
        return "达到最大重试次数,请求仍然失败", ""
class Memory_graph:
def __init__(self):
@@ -158,12 +223,12 @@ class Memory_graph:
def main():
# 初始化数据库
Database.initialize(
host= os.getenv("MONGODB_HOST"),
port= int(os.getenv("MONGODB_PORT")),
db_name= os.getenv("DATABASE_NAME"),
username= os.getenv("MONGODB_USERNAME"),
password= os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME", ""),
password=os.getenv("MONGODB_PASSWORD", ""),
auth_source=os.getenv("MONGODB_AUTH_SOURCE", "")
)
memory_graph = Memory_graph()
@@ -185,11 +250,14 @@ def main():
query = input("请输入新的查询概念(输入'退出'以结束):")
if query.lower() == '退出':
break
items_list = memory_graph.get_related_item(query)
if items_list:
# print(items_list)
for memory_item in items_list:
print(memory_item)
first_layer_items, second_layer_items = memory_graph.get_related_item(query)
if first_layer_items or second_layer_items:
print("\n第一层记忆:")
for item in first_layer_items:
print(item)
print("\n第二层记忆:")
for item in second_layer_items:
print(item)
else:
print("未找到相关记忆。")

View File

@@ -66,7 +66,7 @@ class LLMModel:
except Exception as e:
if retry < max_retries - 1: # 如果还有重试机会
wait_time = base_wait_time * (2 ** retry)
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
print(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
await asyncio.sleep(wait_time)
else:
return f"请求失败: {str(e)}", ""

View File

@@ -259,12 +259,12 @@ config = driver.config
start_time = time.time()
Database.initialize(
host= config.mongodb_host,
port= int(config.mongodb_port),
db_name= config.database_name,
username= config.mongodb_username,
password= config.mongodb_password,
auth_source=config.mongodb_auth_source
host= config.MONGODB_HOST,
port= config.MONGODB_PORT,
db_name= config.DATABASE_NAME,
username= config.MONGODB_USERNAME,
password= config.MONGODB_PASSWORD,
auth_source=config.MONGODB_AUTH_SOURCE
)
#创建记忆图
memory_graph = Memory_graph()

View File

@@ -9,7 +9,7 @@ driver = get_driver()
config = driver.config
class LLM_request:
def __init__(self, model = global_config.llm_normal,**kwargs):
def __init__(self, model ,**kwargs):
# 将大写的配置键转换为小写并从config中获取实际值
try:
self.api_key = getattr(config, model["key"])
@@ -61,7 +61,7 @@ class LLM_request:
except Exception as e:
if retry < max_retries - 1: # 如果还有重试机会
wait_time = base_wait_time * (2 ** retry)
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
print(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
await asyncio.sleep(wait_time)
else:
return f"请求失败: {str(e)}", ""
@@ -126,7 +126,7 @@ class LLM_request:
except Exception as e:
if retry < max_retries - 1: # 如果还有重试机会
wait_time = base_wait_time * (2 ** retry)
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
print(f"[image回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
await asyncio.sleep(wait_time)
else:
return f"请求失败: {str(e)}", ""
@@ -166,8 +166,8 @@ class LLM_request:
# 发送请求到完整的chat/completions端点
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
max_retries = 3
base_wait_time = 15
max_retries = 2
base_wait_time = 6
for retry in range(max_retries):
try:
@@ -191,9 +191,119 @@ class LLM_request:
except Exception as e:
if retry < max_retries - 1: # 如果还有重试机会
wait_time = base_wait_time * (2 ** retry)
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
print(f"[image_sync回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
time.sleep(wait_time)
else:
return f"请求失败: {str(e)}", ""
return "达到最大重试次数,请求仍然失败", ""
def get_embedding_sync(self, text: str, model: str = "BAAI/bge-m3") -> Union[list, None]:
    """Synchronously fetch the embedding vector for *text*.

    Args:
        text: the text to embed.
        model: embedding model name, defaults to "BAAI/bge-m3".

    Returns:
        list: the embedding vector, or None on any failure.
    """
    headers = {
        "Authorization": f"Bearer {self.api_key}",
        "Content-Type": "application/json"
    }
    data = {
        "model": model,
        "input": text,
        "encoding_format": "float"
    }
    # Embeddings endpoint, tolerating a trailing slash in base_url.
    api_url = f"{self.base_url.rstrip('/')}/embeddings"
    max_retries = 2
    base_wait_time = 6
    for retry in range(max_retries):
        try:
            response = requests.post(api_url, headers=headers, json=data, timeout=30)
            if response.status_code == 429:
                # Rate limited: exponential backoff, then retry.
                wait_time = base_wait_time * (2 ** retry)
                print(f"遇到请求限制(429),等待{wait_time}秒后重试...")
                time.sleep(wait_time)
                continue
            response.raise_for_status()
            result = response.json()
            if 'data' in result and len(result['data']) > 0:
                return result['data'][0]['embedding']
            # Well-formed response but no embedding payload.
            return None
        except Exception as e:
            if retry < max_retries - 1:
                wait_time = base_wait_time * (2 ** retry)
                print(f"[embedding_sync]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
                time.sleep(wait_time)
            else:
                print(f"embedding请求失败: {str(e)}")
                return None
    # Only reached if every attempt ended in a 429 retry.
    print("达到最大重试次数embedding请求仍然失败")
    return None
async def get_embedding(self, text: str, model: str = "BAAI/bge-m3") -> Union[list, None]:
    """Asynchronously fetch the embedding vector for *text*.

    Args:
        text: the text to embed.
        model: embedding model name, defaults to "BAAI/bge-m3".

    Returns:
        list: the embedding vector, or None on any failure.
    """
    headers = {
        "Authorization": f"Bearer {self.api_key}",
        "Content-Type": "application/json"
    }
    data = {
        "model": model,
        "input": text,
        "encoding_format": "float"
    }
    # Embeddings endpoint, tolerating a trailing slash in base_url.
    api_url = f"{self.base_url.rstrip('/')}/embeddings"
    # NOTE(review): retry budget (3/15s) differs from get_embedding_sync
    # (2/6s) — confirm whether that asymmetry is intentional.
    max_retries = 3
    base_wait_time = 15
    for retry in range(max_retries):
        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(api_url, headers=headers, json=data) as response:
                    if response.status == 429:
                        # Rate limited: exponential backoff, then retry.
                        wait_time = base_wait_time * (2 ** retry)
                        print(f"遇到请求限制(429),等待{wait_time}秒后重试...")
                        await asyncio.sleep(wait_time)
                        continue
                    response.raise_for_status()
                    result = await response.json()
                    if 'data' in result and len(result['data']) > 0:
                        return result['data'][0]['embedding']
                    # Well-formed response but no embedding payload.
                    return None
        except Exception as e:
            if retry < max_retries - 1:
                wait_time = base_wait_time * (2 ** retry)
                print(f"[embedding]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
                await asyncio.sleep(wait_time)
            else:
                print(f"embedding请求失败: {str(e)}")
                return None
    # Only reached if every attempt ended in a 429 retry.
    print("达到最大重试次数embedding请求仍然失败")
    return None

View File

@@ -11,12 +11,12 @@ config = driver.config
Database.initialize(
host= config.mongodb_host,
port= int(config.mongodb_port),
db_name= config.database_name,
username= config.mongodb_username,
password= config.mongodb_password,
auth_source=config.mongodb_auth_source
host= config.MONGODB_HOST,
port= int(config.MONGODB_PORT),
db_name= config.DATABASE_NAME,
username= config.MONGODB_USERNAME,
password= config.MONGODB_PASSWORD,
auth_source=config.MONGODB_AUTH_SOURCE
)
class ScheduleGenerator:
@@ -128,6 +128,10 @@ class ScheduleGenerator:
def _time_diff(self, time1: str, time2: str) -> int:
"""计算两个时间字符串之间的分钟差"""
if time1=="24:00":
time1="23:59"
if time2=="24:00":
time2="23:59"
t1 = datetime.datetime.strptime(time1, "%H:%M")
t2 = datetime.datetime.strptime(time2, "%H:%M")
diff = int((t2 - t1).total_seconds() / 60)
@@ -165,4 +169,4 @@ class ScheduleGenerator:
# if __name__ == "__main__":
# main()
bot_schedule = ScheduleGenerator()
bot_schedule = ScheduleGenerator()