Merge branch 'debug' into feature

This commit is contained in:
SengokuCola
2025-03-03 22:43:44 +08:00
committed by GitHub
10 changed files with 677 additions and 570 deletions

View File

@@ -97,8 +97,13 @@ class ChatBot:
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(message.time))
topic = topic_identifier.identify_topic_jieba(message.processed_plain_text)
print(f"\033[1;32m[主题识别]\033[0m 主题: {topic}")
topic1 = topic_identifier.identify_topic_jieba(message.processed_plain_text)
topic2 = await topic_identifier.identify_topic_llm(message.processed_plain_text)
topic3 = topic_identifier.identify_topic_snownlp(message.processed_plain_text)
print(f"\033[1;32m[主题识别]\033[0m 使用jieba主题: {topic1}")
print(f"\033[1;32m[主题识别]\033[0m 使用llm主题: {topic2}")
print(f"\033[1;32m[主题识别]\033[0m 使用snownlp主题: {topic3}")
topic = topic3
all_num = 0
interested_num = 0

View File

@@ -4,19 +4,18 @@ from .message import Message
import jieba
from nonebot import get_driver
from .config import global_config
from snownlp import SnowNLP
from ..models.utils_model import LLM_request
driver = get_driver()
config = driver.config
class TopicIdentifier:
def __init__(self):
self.client = OpenAI(
api_key=config.siliconflow_key,
base_url=config.siliconflow_base_url
)
self.llm_client = LLM_request(model=global_config.llm_normal)
def identify_topic_llm(self, text: str) -> Optional[str]:
"""识别消息主题"""
async def identify_topic_llm(self, text: str) -> Optional[List[str]]:
"""识别消息主题,返回主题列表"""
prompt = f"""判断这条消息的主题,如果没有明显主题请回复"无主题",要求:
1. 主题通常2-4个字必须简短要求精准概括不要太具体。
@@ -24,33 +23,20 @@ class TopicIdentifier:
消息内容:{text}"""
response = self.client.chat.completions.create(
model=global_config.SILICONFLOW_MODEL_V3,
messages=[{"role": "user", "content": prompt}],
temperature=0.8,
max_tokens=10
)
# 使用 LLM_request 类进行请求
topic, _ = await self.llm_client.generate_response(prompt)
if not response or not response.choices:
print(f"\033[1;31m[错误]\033[0m OpenAI API 返回为空")
if not topic:
print(f"\033[1;31m[错误]\033[0m LLM API 返回为空")
return None
# 从 OpenAI API 响应中获取第一个选项的消息内容,并去除首尾空白字符
topic = response.choices[0].message.content.strip() if response.choices[0].message.content else None
if topic == "无主题":
return None
else:
# print(f"[主题分析结果]{text[:20]}... : {topic}")
split_topic = self.parse_topic(topic)
return split_topic
def parse_topic(self, topic: str) -> List[str]:
"""解析主题,返回主题列表"""
# 直接在这里处理主题解析
if not topic or topic == "无主题":
return []
return [t.strip() for t in topic.split(",") if t.strip()]
return None
# 解析主题字符串为列表
topic_list = [t.strip() for t in topic.split(",") if t.strip()]
return topic_list if topic_list else None
def identify_topic_jieba(self, text: str) -> Optional[str]:
"""使用jieba识别主题"""
@@ -80,9 +66,12 @@ class TopicIdentifier:
filtered_words = []
for word in words:
if word not in stop_words and not word.strip() in {
'', '', '', '', '', '', '', '"', '"', ''', ''',
'', '', '', '', '', '', '', '', '·', '', '~',
'', '+', '=', '-','[',']'
'', '', '', '', '', '', '', '"', '"', ''', ''',
'', '', '', '', '', '', '', '', '·', '', '~',
'', '+', '=', '-', '/', '\\', '|', '*', '#', '@', '$', '%',
'^', '&', '[', ']', '{', '}', '<', '>', '`', '_', '.', ',',
';', ':', '\'', '"', '(', ')', '?', '!', '±', '×', '÷', '',
'', '', '', '', '', '', '', '', '', '', ''
}:
filtered_words.append(word)
@@ -97,4 +86,25 @@ class TopicIdentifier:
return top_words if top_words else None
topic_identifier = TopicIdentifier()
def identify_topic_snownlp(self, text: str) -> Optional[List[str]]:
    """Identify topics of *text* via SnowNLP keyword extraction.

    Args:
        text: The text whose topics should be identified.

    Returns:
        Optional[List[str]]: Up to three keyword topics, or None when the
        input is blank, no keywords are found, or SnowNLP raises.
    """
    # Blank / whitespace-only input carries no topic information.
    stripped = text.strip() if text else ""
    if not stripped:
        return None

    try:
        analyzer = SnowNLP(text)
        # Top-3 keywords serve as the topic list.
        top_keywords = analyzer.keywords(3)
    except Exception as e:
        print(f"\033[1;31m[错误]\033[0m SnowNLP 处理失败: {str(e)}")
        return None

    return top_keywords or None
topic_identifier = TopicIdentifier()

View File

@@ -2,7 +2,6 @@
import os
import sys
import jieba
from llm_module import LLMModel
import networkx as nx
import matplotlib.pyplot as plt
import math
@@ -10,10 +9,76 @@ from collections import Counter
import datetime
import random
import time
# from chat.config import global_config
from dotenv import load_dotenv
import sys
import asyncio
import aiohttp
from typing import Tuple
sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径
from src.common.database import Database # 使用正确的导入语法
# 加载.env.dev文件
env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), '.env.dev')
load_dotenv(env_path)
class LLMModel:
    """Minimal async client for a SiliconFlow-style chat-completions API.

    The model name, API key and base URL are read from environment
    variables (SILICONFLOW_MODEL_V3 / SILICONFLOW_KEY /
    SILICONFLOW_BASE_URL), which are loaded from .env.dev at module
    import time, above.
    """

    def __init__(self, model_name=os.getenv("SILICONFLOW_MODEL_V3"), **kwargs):
        # NOTE: the default model name is resolved when this `def` line
        # executes, so load_dotenv must already have run (it has, at the
        # top of this module).
        self.model_name = model_name
        # Extra keyword args are forwarded verbatim into the request body
        # (e.g. max_tokens), overriding the defaults set below.
        self.params = kwargs
        self.api_key = os.getenv("SILICONFLOW_KEY")
        self.base_url = os.getenv("SILICONFLOW_BASE_URL")

    async def generate_response(self, prompt: str) -> Tuple[str, str]:
        """Generate the model's response for *prompt*.

        Retries up to 3 times with exponential backoff on HTTP 429 or
        transport errors.

        Returns:
            (content, reasoning_content). On failure a human-readable
            error message is returned as *content* with "" as
            *reasoning_content* — this method never raises to the caller.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }

        # Build the chat-completions request body.
        data = {
            "model": self.model_name,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": 0.5,
            **self.params
        }

        # POST to the full chat/completions endpoint.
        api_url = f"{self.base_url.rstrip('/')}/chat/completions"

        max_retries = 3
        base_wait_time = 15

        for retry in range(max_retries):
            try:
                async with aiohttp.ClientSession() as session:
                    async with session.post(api_url, headers=headers, json=data) as response:
                        if response.status == 429:
                            # Rate limited: exponential backoff, then retry.
                            wait_time = base_wait_time * (2 ** retry)
                            print(f"遇到请求限制(429),等待{wait_time}秒后重试...")
                            await asyncio.sleep(wait_time)
                            continue

                        # Raise for any other non-success status.
                        response.raise_for_status()

                        result = await response.json()
                        if "choices" in result and len(result["choices"]) > 0:
                            content = result["choices"][0]["message"]["content"]
                            # "reasoning_content" is optional in the reply.
                            reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
                            return content, reasoning_content
                        return "没有返回结果", ""

            except Exception as e:
                if retry < max_retries - 1:
                    # Retries remain: back off and try again.
                    wait_time = base_wait_time * (2 ** retry)
                    print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
                    await asyncio.sleep(wait_time)
                else:
                    return f"请求失败: {str(e)}", ""

        # Reached only when every attempt was consumed by 429 backoffs.
        return "达到最大重试次数,请求仍然失败", ""
class Memory_graph:
def __init__(self):
@@ -158,12 +223,12 @@ class Memory_graph:
def main():
# 初始化数据库
Database.initialize(
host= os.getenv("MONGODB_HOST"),
port= int(os.getenv("MONGODB_PORT")),
db_name= os.getenv("DATABASE_NAME"),
username= os.getenv("MONGODB_USERNAME"),
password= os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME", ""),
password=os.getenv("MONGODB_PASSWORD", ""),
auth_source=os.getenv("MONGODB_AUTH_SOURCE", "")
)
memory_graph = Memory_graph()
@@ -185,11 +250,14 @@ def main():
query = input("请输入新的查询概念(输入'退出'以结束):")
if query.lower() == '退出':
break
items_list = memory_graph.get_related_item(query)
if items_list:
# print(items_list)
for memory_item in items_list:
print(memory_item)
first_layer_items, second_layer_items = memory_graph.get_related_item(query)
if first_layer_items or second_layer_items:
print("\n第一层记忆:")
for item in first_layer_items:
print(item)
print("\n第二层记忆:")
for item in second_layer_items:
print(item)
else:
print("未找到相关记忆。")