Fix various minor issues

春河晴
2025-04-16 17:37:28 +09:00
parent a0b1b1f8d8
commit dc2cf843e5
36 changed files with 114 additions and 107 deletions

View File

@@ -1,4 +1,4 @@
-from src.plugins.models.utils_model import LLM_request
+from src.plugins.models.utils_model import LLMRequest
 from src.plugins.config.config import global_config
 from src.plugins.chat.chat_stream import ChatStream
 from src.common.database import db
@@ -18,7 +18,7 @@ logger = get_module_logger("tool_use", config=tool_use_config)
 class ToolUser:
     def __init__(self):
-        self.llm_model_tool = LLM_request(
+        self.llm_model_tool = LLMRequest(
             model=global_config.llm_tool_use, temperature=0.2, max_tokens=1000, request_type="tool_use"
         )
@@ -107,7 +107,7 @@ class ToolUser:
         return None
     async def use_tool(
-        self, message_txt: str, sender_name: str, chat_stream: ChatStream, subheartflow: SubHeartflow = None
+        self, message_txt: str, sender_name: str, chat_stream: ChatStream, sub_heartflow: SubHeartflow = None
     ):
         """使用工具辅助思考,判断是否需要额外信息
@@ -115,13 +115,14 @@ class ToolUser:
             message_txt: 用户消息文本
             sender_name: 发送者名称
             chat_stream: 聊天流对象
+            sub_heartflow: 子心流对象(可选)
         Returns:
             dict: 工具使用结果,包含结构化的信息
         """
         try:
             # 构建提示词
-            prompt = await self._build_tool_prompt(message_txt, sender_name, chat_stream, subheartflow)
+            prompt = await self._build_tool_prompt(message_txt, sender_name, chat_stream, sub_heartflow)
             # 定义可用工具
             tools = self._define_tools()

View File

@@ -1,7 +1,7 @@
 from .sub_heartflow import SubHeartflow
 from .observation import ChattingObservation
 from src.plugins.moods.moods import MoodManager
-from src.plugins.models.utils_model import LLM_request
+from src.plugins.models.utils_model import LLMRequest
 from src.plugins.config.config import global_config
 from src.plugins.schedule.schedule_generator import bot_schedule
 from src.plugins.utils.prompt_builder import Prompt, global_prompt_manager
@@ -60,7 +60,7 @@ class Heartflow:
         self.current_mind = "你什么也没想"
         self.past_mind = []
         self.current_state: CurrentState = CurrentState()
-        self.llm_model = LLM_request(
+        self.llm_model = LLMRequest(
             model=global_config.llm_heartflow, temperature=0.6, max_tokens=1000, request_type="heart_flow"
         )

View File

@@ -1,7 +1,7 @@
 # 定义了来自外部世界的信息
 # 外部世界可以是某个聊天 不同平台的聊天 也可以是任意媒体
 from datetime import datetime
-from src.plugins.models.utils_model import LLM_request
+from src.plugins.models.utils_model import LLMRequest
 from src.plugins.config.config import global_config
 from src.common.database import db
 from src.common.logger import get_module_logger
@@ -40,7 +40,7 @@ class ChattingObservation(Observation):
         self.updating_old = False
-        self.llm_summary = LLM_request(
+        self.llm_summary = LLMRequest(
             model=global_config.llm_observation, temperature=0.7, max_tokens=300, request_type="chat_observation"
         )

View File

@@ -1,7 +1,7 @@
 from .observation import Observation, ChattingObservation
 import asyncio
 from src.plugins.moods.moods import MoodManager
-from src.plugins.models.utils_model import LLM_request
+from src.plugins.models.utils_model import LLMRequest
 from src.plugins.config.config import global_config
 import time
 from src.plugins.chat.message import UserInfo
@@ -79,7 +79,7 @@ class SubHeartflow:
         self.current_mind = ""
         self.past_mind = []
         self.current_state: CurrentState = CurrentState()
-        self.llm_model = LLM_request(
+        self.llm_model = LLMRequest(
             model=global_config.llm_sub_heartflow,
             temperature=global_config.llm_sub_heartflow["temp"],
             max_tokens=600,

View File

@@ -10,7 +10,7 @@ from src.common.logger import get_module_logger
 logger = get_module_logger("offline_llm")
-class LLM_request_off:
+class LLMRequestOff:
     def __init__(self, model_name="Pro/deepseek-ai/DeepSeek-V3", **kwargs):
         self.model_name = model_name
         self.params = kwargs

View File

@@ -19,7 +19,7 @@ with open(config_path, "r", encoding="utf-8") as f:
 # 现在可以导入src模块
 from src.individuality.scene import get_scene_by_factor, PERSONALITY_SCENES # noqa E402
 from src.individuality.questionnaire import FACTOR_DESCRIPTIONS # noqa E402
-from src.individuality.offline_llm import LLM_request_off # noqa E402
+from src.individuality.offline_llm import LLMRequestOff # noqa E402
 # 加载环境变量
 env_path = os.path.join(root_path, ".env")
@@ -65,7 +65,7 @@ def adapt_scene(scene: str) -> str:
     现在,请你给出改编后的场景描述
     """
-    llm = LLM_request_off(model_name=config["model"]["llm_normal"]["name"])
+    llm = LLMRequestOff(model_name=config["model"]["llm_normal"]["name"])
     adapted_scene, _ = llm.generate_response(prompt)
     # 检查返回的场景是否为空或错误信息
@@ -79,7 +79,7 @@ def adapt_scene(scene: str) -> str:
     return scene
-class PersonalityEvaluator_direct:
+class PersonalityEvaluatorDirect:
     def __init__(self):
         self.personality_traits = {"开放性": 0, "严谨性": 0, "外向性": 0, "宜人性": 0, "神经质": 0}
         self.scenarios = []
@@ -110,7 +110,7 @@ class PersonalityEvaluator_direct:
                 {"场景": scene["scenario"], "评估维度": [trait, secondary_trait], "场景编号": scene_key}
             )
-        self.llm = LLM_request_off()
+        self.llm = LLMRequestOff()
     def evaluate_response(self, scenario: str, response: str, dimensions: List[str]) -> Dict[str, float]:
         """
@@ -269,7 +269,7 @@ class PersonalityEvaluator_direct:
 def main():
-    evaluator = PersonalityEvaluator_direct()
+    evaluator = PersonalityEvaluatorDirect()
     result = evaluator.run_evaluation()
     # 准备简化的结果数据

View File

@@ -9,7 +9,7 @@ from .plugins.willing.willing_manager import willing_manager
 from .plugins.chat.chat_stream import chat_manager
 from .heart_flow.heartflow import heartflow
 from .plugins.memory_system.Hippocampus import HippocampusManager
-from .plugins.chat.message_sender import message_manager
+from .plugins.chat.messagesender import message_manager
 from .plugins.storage.storage import MessageStorage
 from .plugins.config.config import global_config
 from .plugins.chat.bot import chat_bot

View File

@@ -1,6 +1,6 @@
 from typing import Tuple
 from src.common.logger import get_module_logger
-from ..models.utils_model import LLM_request
+from ..models.utils_model import LLMRequest
 from ..config.config import global_config
 from .chat_observer import ChatObserver
 from .pfc_utils import get_items_from_json
@@ -23,7 +23,7 @@ class ActionPlanner:
     """行动规划器"""
     def __init__(self, stream_id: str):
-        self.llm = LLM_request(
+        self.llm = LLMRequest(
             model=global_config.llm_normal,
             temperature=global_config.llm_normal["temp"],
             max_tokens=1000,

View File

@@ -4,7 +4,7 @@ from ..chat.chat_stream import ChatStream
 from ..chat.message import Message
 from ..message.message_base import Seg
 from src.plugins.chat.message import MessageSending, MessageSet
-from src.plugins.chat.message_sender import message_manager
+from src.plugins.chat.messagesender import message_manager
 logger = get_module_logger("message_sender")

View File

@@ -120,6 +120,10 @@ class ObservationInfo:
     # #spec
     # meta_plan_trigger: bool = False
+    def __init__(self):
+        self.last_message_id = None
+        self.chat_observer = None
     def __post_init__(self):
         """初始化后创建handler"""
         self.chat_observer = None
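For reference, a minimal sketch of the `__post_init__` mechanism these lines rely on: in a `@dataclass`, `__post_init__` runs automatically right after the dataclass-generated `__init__`, but it is not invoked when a hand-written `__init__` is used instead. The class below is illustrative only and not taken from this repository.

from dataclasses import dataclass, field

@dataclass
class Example:
    value: int = 0
    derived: int = field(init=False, default=0)

    def __post_init__(self):
        # runs automatically after the dataclass-generated __init__
        self.derived = self.value * 2

print(Example(value=3).derived)  # 6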
@@ -129,7 +133,7 @@ class ObservationInfo:
         """绑定到指定的chat_observer
         Args:
-            stream_id: 聊天流ID
+            chat_observer: 要绑定的ChatObserver实例
         """
         self.chat_observer = chat_observer
         self.chat_observer.notification_manager.register_handler(
@@ -171,7 +175,8 @@ class ObservationInfo:
             self.last_bot_speak_time = message["time"]
         else:
             self.last_user_speak_time = message["time"]
-            self.active_users.add(user_info.user_id)
+            if user_info.user_id is not None:
+                self.active_users.add(str(user_info.user_id))
         self.new_messages_count += 1
         self.unprocessed_messages.append(message)
@@ -227,7 +232,7 @@ class ObservationInfo:
         """清空未处理消息列表"""
         # 将未处理消息添加到历史记录中
         for message in self.unprocessed_messages:
-            self.chat_history.append(message)
+            self.chat_history.append(message) # TODO NEED FIX TYPE???
         # 清空未处理消息列表
         self.has_unread_messages = False
         self.unprocessed_messages.clear()

View File

@@ -8,7 +8,7 @@ from src.common.logger import get_module_logger
 from ..chat.chat_stream import ChatStream
 from ..message.message_base import UserInfo, Seg
 from ..chat.message import Message
-from ..models.utils_model import LLM_request
+from ..models.utils_model import LLMRequest
 from ..config.config import global_config
 from src.plugins.chat.message import MessageSending
 from ..message.api import global_api
@@ -30,7 +30,7 @@ class GoalAnalyzer:
     """对话目标分析器"""
     def __init__(self, stream_id: str):
-        self.llm = LLM_request(
+        self.llm = LLMRequest(
             model=global_config.llm_normal, temperature=0.7, max_tokens=1000, request_type="conversation_goal"
         )
@@ -350,7 +350,7 @@ class DirectMessageSender:
         # logger.info(f"发送消息到{end_point}")
         # logger.info(message_json)
         try:
-            await global_api.send_message_REST(end_point, message_json)
+            await global_api.send_message_rest(end_point, message_json)
         except Exception as e:
             logger.error(f"REST方式发送失败出现错误: {str(e)}")
             logger.info("尝试使用ws发送")

View File

@@ -1,7 +1,7 @@
 from typing import List, Tuple
 from src.common.logger import get_module_logger
 from src.plugins.memory_system.Hippocampus import HippocampusManager
-from ..models.utils_model import LLM_request
+from ..models.utils_model import LLMRequest
 from ..config.config import global_config
 from ..chat.message import Message
@@ -12,7 +12,7 @@ class KnowledgeFetcher:
     """知识调取器"""
     def __init__(self):
-        self.llm = LLM_request(
+        self.llm = LLMRequest(
             model=global_config.llm_normal,
             temperature=global_config.llm_normal["temp"],
             max_tokens=1000,

View File

@@ -2,7 +2,7 @@ import json
 import datetime
 from typing import Tuple
 from src.common.logger import get_module_logger
-from ..models.utils_model import LLM_request
+from ..models.utils_model import LLMRequest
 from ..config.config import global_config
 from .chat_observer import ChatObserver
 from ..message.message_base import UserInfo
@@ -14,7 +14,7 @@ class ReplyChecker:
     """回复检查器"""
     def __init__(self, stream_id: str):
-        self.llm = LLM_request(
+        self.llm = LLMRequest(
             model=global_config.llm_normal, temperature=0.7, max_tokens=1000, request_type="reply_check"
         )
         self.name = global_config.BOT_NICKNAME

View File

@@ -1,6 +1,6 @@
 from typing import Tuple
 from src.common.logger import get_module_logger
-from ..models.utils_model import LLM_request
+from ..models.utils_model import LLMRequest
 from ..config.config import global_config
 from .chat_observer import ChatObserver
 from .reply_checker import ReplyChecker
@@ -15,7 +15,7 @@ class ReplyGenerator:
     """回复生成器"""
     def __init__(self, stream_id: str):
-        self.llm = LLM_request(
+        self.llm = LLMRequest(
             model=global_config.llm_normal,
             temperature=global_config.llm_normal["temp"],
             max_tokens=300,

View File

@@ -1,7 +1,7 @@
 from .emoji_manager import emoji_manager
 from ..person_info.relationship_manager import relationship_manager
 from .chat_stream import chat_manager
-from .message_sender import message_manager
+from .messagesender import message_manager
 from ..storage.storage import MessageStorage

View File

@@ -42,7 +42,7 @@ class ChatBot:
         self._started = True
-    async def _create_PFC_chat(self, message: MessageRecv):
+    async def _create_pfc_chat(self, message: MessageRecv):
         try:
             chat_id = str(message.chat_stream.stream_id)
@@ -112,7 +112,7 @@ class ChatBot:
                 )
                 message.update_chat_stream(chat)
                 await self.only_process_chat.process_message(message)
-                await self._create_PFC_chat(message)
+                await self._create_pfc_chat(message)
             else:
                 if groupinfo.group_id in global_config.talk_allowed_groups:
                     # logger.debug(f"开始群聊模式{str(message_data)[:50]}...")

View File

@@ -13,7 +13,7 @@ from ...common.database import db
 from ..config.config import global_config
 from ..chat.utils import get_embedding
 from ..chat.utils_image import ImageManager, image_path_to_base64
-from ..models.utils_model import LLM_request
+from ..models.utils_model import LLMRequest
 from src.common.logger import get_module_logger
 logger = get_module_logger("emoji")
@@ -34,8 +34,8 @@ class EmojiManager:
     def __init__(self):
         self._scan_task = None
-        self.vlm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=1000, request_type="emoji")
-        self.llm_emotion_judge = LLM_request(
+        self.vlm = LLMRequest(model=global_config.vlm, temperature=0.3, max_tokens=1000, request_type="emoji")
+        self.llm_emotion_judge = LLMRequest(
             model=global_config.llm_emotion_judge, max_tokens=600, temperature=0.8, request_type="emoji"
         ) # 更高的温度更少的token后续可以根据情绪来调整温度

View File

@@ -59,20 +59,20 @@ class MessageBuffer:
                 logger.debug(f"被新消息覆盖信息id: {cache_msg.message.message_info.message_id}")
         # 查找最近的处理成功消息(T)
-        recent_F_count = 0
+        recent_f_count = 0
         for msg_id in reversed(self.buffer_pool[person_id_]):
             msg = self.buffer_pool[person_id_][msg_id]
             if msg.result == "T":
                 break
             elif msg.result == "F":
-                recent_F_count += 1
+                recent_f_count += 1
         # 判断条件最近T之后有超过3-5条F
-        if recent_F_count >= random.randint(3, 5):
+        if recent_f_count >= random.randint(3, 5):
             new_msg = CacheMessages(message=message, result="T")
             new_msg.cache_determination.set()
             self.buffer_pool[person_id_][message.message_info.message_id] = new_msg
-            logger.debug(f"快速处理消息(已堆积{recent_F_count}条F): {message.message_info.message_id}")
+            logger.debug(f"快速处理消息(已堆积{recent_f_count}条F): {message.message_info.message_id}")
             return
         # 添加新消息

View File

@@ -23,7 +23,7 @@ sender_config = LogConfig(
 logger = get_module_logger("msg_sender", config=sender_config)
-class Message_Sender:
+class MessageSender:
     """发送器"""
     def __init__(self):
@@ -83,7 +83,7 @@ class Message_Sender:
         # logger.info(f"发送消息到{end_point}")
         # logger.info(message_json)
         try:
-            await global_api.send_message_REST(end_point, message_json)
+            await global_api.send_message_rest(end_point, message_json)
         except Exception as e:
             logger.error(f"REST方式发送失败出现错误: {str(e)}")
             logger.info("尝试使用ws发送")
@@ -286,4 +286,4 @@ class MessageManager:
 # 创建全局消息管理器实例
 message_manager = MessageManager()
 # 创建全局发送器实例
-message_sender = Message_Sender()
+message_sender = MessageSender()

View File

@@ -8,7 +8,7 @@ import jieba
 import numpy as np
 from src.common.logger import get_module_logger
-from ..models.utils_model import LLM_request
+from ..models.utils_model import LLMRequest
 from ..utils.typo_generator import ChineseTypoGenerator
 from ..config.config import global_config
 from .message import MessageRecv, Message
@@ -91,7 +91,7 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
 async def get_embedding(text, request_type="embedding"):
     """获取文本的embedding向量"""
-    llm = LLM_request(model=global_config.embedding, request_type=request_type)
+    llm = LLMRequest(model=global_config.embedding, request_type=request_type)
     # return llm.get_embedding_sync(text)
     try:
         embedding = await llm.get_embedding(text)
@@ -105,7 +105,7 @@ async def get_recent_group_messages(chat_id: str, limit: int = 12) -> list:
     """从数据库获取群组最近的消息记录
     Args:
-        group_id: 群组ID
+        chat_id: 群组ID
         limit: 获取消息数量默认12条
     Returns:

View File

@@ -9,7 +9,7 @@ import io
 from ...common.database import db
 from ..config.config import global_config
-from ..models.utils_model import LLM_request
+from ..models.utils_model import LLMRequest
 from src.common.logger import get_module_logger
@@ -32,7 +32,7 @@ class ImageManager:
         self._ensure_description_collection()
         self._ensure_image_dir()
         self._initialized = True
-        self._llm = LLM_request(model=global_config.vlm, temperature=0.4, max_tokens=300, request_type="image")
+        self._llm = LLMRequest(model=global_config.vlm, temperature=0.4, max_tokens=300, request_type="image")
     def _ensure_image_dir(self):
         """确保图像存储目录存在"""

View File

@@ -8,7 +8,7 @@ from ...config.config import global_config
 from ...chat.emoji_manager import emoji_manager
 from .reasoning_generator import ResponseGenerator
 from ...chat.message import MessageSending, MessageRecv, MessageThinking, MessageSet
-from ...chat.message_sender import message_manager
+from ...chat.messagesender import message_manager
 from ...storage.storage import MessageStorage
 from ...chat.utils import is_mentioned_bot_in_message
 from ...chat.utils_image import image_path_to_base64

View File

@@ -1,7 +1,7 @@
 from typing import List, Optional, Tuple, Union
 import random
-from ...models.utils_model import LLM_request
+from ...models.utils_model import LLMRequest
 from ...config.config import global_config
 from ...chat.message import MessageThinking
 from .reasoning_prompt_builder import prompt_builder
@@ -22,20 +22,20 @@ logger = get_module_logger("llm_generator", config=llm_config)
 class ResponseGenerator:
     def __init__(self):
-        self.model_reasoning = LLM_request(
+        self.model_reasoning = LLMRequest(
             model=global_config.llm_reasoning,
             temperature=0.7,
             max_tokens=3000,
             request_type="response_reasoning",
         )
-        self.model_normal = LLM_request(
+        self.model_normal = LLMRequest(
             model=global_config.llm_normal,
             temperature=global_config.llm_normal["temp"],
             max_tokens=256,
             request_type="response_reasoning",
         )
-        self.model_sum = LLM_request(
+        self.model_sum = LLMRequest(
             model=global_config.llm_summary_by_topic, temperature=0.7, max_tokens=3000, request_type="relation"
         )
         self.current_model_type = "r1" # 默认使用 R1
@@ -68,7 +68,7 @@ class ResponseGenerator:
             logger.info(f"{self.current_model_type}思考,失败")
             return None
-    async def _generate_response_with_model(self, message: MessageThinking, model: LLM_request, thinking_id: str):
+    async def _generate_response_with_model(self, message: MessageThinking, model: LLMRequest, thinking_id: str):
         sender_name = ""
         info_catcher = info_catcher_manager.get_info_catcher(thinking_id)

View File

@@ -8,7 +8,7 @@ from ...config.config import global_config
 from ...chat.emoji_manager import emoji_manager
 from .think_flow_generator import ResponseGenerator
 from ...chat.message import MessageSending, MessageRecv, MessageThinking, MessageSet
-from ...chat.message_sender import message_manager
+from ...chat.messagesender import message_manager
 from ...storage.storage import MessageStorage
 from ...chat.utils import is_mentioned_bot_in_message, get_recent_group_detailed_plain_text
 from ...chat.utils_image import image_path_to_base64

View File

@@ -2,7 +2,7 @@ from typing import List, Optional
 import random
-from ...models.utils_model import LLM_request
+from ...models.utils_model import LLMRequest
 from ...config.config import global_config
 from ...chat.message import MessageRecv
 from .think_flow_prompt_builder import prompt_builder
@@ -25,14 +25,14 @@ logger = get_module_logger("llm_generator", config=llm_config)
 class ResponseGenerator:
     def __init__(self):
-        self.model_normal = LLM_request(
+        self.model_normal = LLMRequest(
             model=global_config.llm_normal,
             temperature=global_config.llm_normal["temp"],
             max_tokens=256,
             request_type="response_heartflow",
         )
-        self.model_sum = LLM_request(
+        self.model_sum = LLMRequest(
             model=global_config.llm_summary_by_topic, temperature=0.6, max_tokens=2000, request_type="relation"
         )
         self.current_model_type = "r1" # 默认使用 R1
@@ -94,7 +94,7 @@ class ResponseGenerator:
         return None
     async def _generate_response_with_model(
-        self, message: MessageRecv, model: LLM_request, thinking_id: str, mode: str = "normal"
+        self, message: MessageRecv, model: LLMRequest, thinking_id: str, mode: str = "normal"
     ) -> str:
         sender_name = ""

View File

@@ -62,8 +62,7 @@ def update_config():
         shutil.copy2(template_path, old_config_path)
         logger.info(f"已创建新配置文件,请填写后重新运行: {old_config_path}")
         # 如果是新创建的配置文件,直接返回
-        quit()
-        return
+        return quit()
     # 读取旧配置文件和模板文件
     with open(old_config_path, "r", encoding="utf-8") as f:

View File

@@ -9,7 +9,7 @@ import networkx as nx
 import numpy as np
 from collections import Counter
 from ...common.database import db
-from ...plugins.models.utils_model import LLM_request
+from ...plugins.models.utils_model import LLMRequest
 from src.common.logger import get_module_logger, LogConfig, MEMORY_STYLE_CONFIG
 from src.plugins.memory_system.sample_distribution import MemoryBuildScheduler # 分布生成器
 from .memory_config import MemoryConfig
@@ -91,7 +91,7 @@ memory_config = LogConfig(
 logger = get_module_logger("memory_system", config=memory_config)
-class Memory_graph:
+class MemoryGraph:
     def __init__(self):
         self.G = nx.Graph() # 使用 networkx 的图结构
@@ -229,7 +229,7 @@ class Memory_graph:
 # 海马体
 class Hippocampus:
     def __init__(self):
-        self.memory_graph = Memory_graph()
+        self.memory_graph = MemoryGraph()
         self.llm_topic_judge = None
         self.llm_summary_by_topic = None
         self.entorhinal_cortex = None
@@ -243,8 +243,8 @@ class Hippocampus:
         self.parahippocampal_gyrus = ParahippocampalGyrus(self)
         # 从数据库加载记忆图
         self.entorhinal_cortex.sync_memory_from_db()
-        self.llm_topic_judge = LLM_request(self.config.llm_topic_judge, request_type="memory")
-        self.llm_summary_by_topic = LLM_request(self.config.llm_summary_by_topic, request_type="memory")
+        self.llm_topic_judge = LLMRequest(self.config.llm_topic_judge, request_type="memory")
+        self.llm_summary_by_topic = LLMRequest(self.config.llm_summary_by_topic, request_type="memory")
     def get_all_node_names(self) -> list:
         """获取记忆图中所有节点的名字列表"""
"""获取记忆图中所有节点的名字列表""" """获取记忆图中所有节点的名字列表"""
@@ -346,7 +346,8 @@ class Hippocampus:
         Args:
             text (str): 输入文本
-            num (int, optional): 需要返回的记忆数量。默认为5
+            max_memory_num (int, optional): 记忆数量限制。默认为3
+            max_memory_length (int, optional): 记忆长度限制。默认为2。
             max_depth (int, optional): 记忆检索深度。默认为2。
             fast_retrieval (bool, optional): 是否使用快速检索。默认为False。
                 如果为True使用jieba分词和TF-IDF提取关键词速度更快但可能不够准确。
@@ -540,7 +541,6 @@ class Hippocampus:
         Args:
             text (str): 输入文本
-            num (int, optional): 需要返回的记忆数量。默认为5。
             max_depth (int, optional): 记忆检索深度。默认为2。
             fast_retrieval (bool, optional): 是否使用快速检索。默认为False。
                 如果为True使用jieba分词和TF-IDF提取关键词速度更快但可能不够准确。
@@ -937,7 +937,7 @@ class EntorhinalCortex:
 # 海马体
 class Hippocampus:
     def __init__(self):
-        self.memory_graph = Memory_graph()
+        self.memory_graph = MemoryGraph()
         self.llm_topic_judge = None
         self.llm_summary_by_topic = None
         self.entorhinal_cortex = None
@@ -951,8 +951,8 @@ class Hippocampus:
         self.parahippocampal_gyrus = ParahippocampalGyrus(self)
         # 从数据库加载记忆图
         self.entorhinal_cortex.sync_memory_from_db()
-        self.llm_topic_judge = LLM_request(self.config.llm_topic_judge, request_type="memory")
-        self.llm_summary_by_topic = LLM_request(self.config.llm_summary_by_topic, request_type="memory")
+        self.llm_topic_judge = LLMRequest(self.config.llm_topic_judge, request_type="memory")
+        self.llm_summary_by_topic = LLMRequest(self.config.llm_summary_by_topic, request_type="memory")
     def get_all_node_names(self) -> list:
         """获取记忆图中所有节点的名字列表"""
"""获取记忆图中所有节点的名字列表""" """获取记忆图中所有节点的名字列表"""
@@ -1054,8 +1054,9 @@ class Hippocampus:
         Args:
             text (str): 输入文本
-            num (int, optional): 需要返回的记忆数量。默认为5
-            max_depth (int, optional): 记忆检索深度。默认为2
+            max_memory_num (int, optional): 返回的记忆条目数量上限。默认为3表示最多返回3条与输入文本相关度最高的记忆
+            max_memory_length (int, optional): 每个主题最多返回的记忆条目数量。默认为2表示每个主题最多返回2条相似度最高的记忆。
+            max_depth (int, optional): 记忆检索深度。默认为3。值越大检索范围越广可以获取更多间接相关的记忆但速度会变慢。
             fast_retrieval (bool, optional): 是否使用快速检索。默认为False。
                 如果为True使用jieba分词和TF-IDF提取关键词速度更快但可能不够准确。
                 如果为False使用LLM提取关键词速度较慢但更准确。
@@ -1248,7 +1249,6 @@ class Hippocampus:
         Args:
             text (str): 输入文本
-            num (int, optional): 需要返回的记忆数量。默认为5。
             max_depth (int, optional): 记忆检索深度。默认为2。
             fast_retrieval (bool, optional): 是否使用快速检索。默认为False。
                 如果为True使用jieba分词和TF-IDF提取关键词速度更快但可能不够准确。
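Since the updated docstrings above document the retrieval parameters, a minimal usage sketch under those defaults may help; the method name get_memory_from_text and the await call shape are assumptions for illustration only, not taken from this diff.

# Hypothetical call, based only on the parameters documented above.
memories = await hippocampus.get_memory_from_text(
    "昨天讨论过的出行计划",
    max_memory_num=3,      # overall cap on returned memory entries
    max_memory_length=2,   # cap on entries returned per topic
    max_depth=3,           # how far to expand through the memory graph
    fast_retrieval=False,  # False: LLM keyword extraction (slower, more accurate)
)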

View File

@@ -177,7 +177,7 @@ def remove_mem_edge(hippocampus: Hippocampus):
 # 修改节点信息
 def alter_mem_node(hippocampus: Hippocampus):
-    batchEnviroment = dict()
+    batch_environment = dict()
     while True:
         concept = input("请输入节点概念名(输入'终止'以结束):\n")
         if concept.lower() == "终止":
@@ -229,7 +229,7 @@ def alter_mem_node(hippocampus: Hippocampus):
                 break
             try:
-                user_exec(command, node_environment, batchEnviroment)
+                user_exec(command, node_environment, batch_environment)
             except Exception as e:
                 console.print(e)
                 console.print(
@@ -239,7 +239,7 @@ def alter_mem_node(hippocampus: Hippocampus):
 # 修改边信息
 def alter_mem_edge(hippocampus: Hippocampus):
-    batchEnviroment = dict()
+    batch_enviroment = dict()
     while True:
         source = input("请输入 **第一个节点** 名称(输入'终止'以结束):\n")
         if source.lower() == "终止":
@@ -262,21 +262,21 @@ def alter_mem_edge(hippocampus: Hippocampus):
         console.print("[yellow]你将获得一个执行任意代码的环境[/yellow]")
         console.print("[red]你已经被警告过了。[/red]\n")
-        edgeEnviroment = {"source": "<节点名>", "target": "<节点名>", "strength": "<强度值,装在一个list里>"}
+        edge_environment = {"source": "<节点名>", "target": "<节点名>", "strength": "<强度值,装在一个list里>"}
         console.print(
             "[green]环境变量中会有env与batchEnv两个dict, env在切换节点时会清空, batchEnv在操作终止时才会清空[/green]"
         )
         console.print(
-            f"[green] env 会被初始化为[/green]\n{edgeEnviroment}\n[green]且会在用户代码执行完毕后被提交 [/green]"
+            f"[green] env 会被初始化为[/green]\n{edge_environment}\n[green]且会在用户代码执行完毕后被提交 [/green]"
         )
         console.print(
             "[yellow]为便于书写临时脚本请手动在输入代码通过Ctrl+C等方式触发KeyboardInterrupt来结束代码执行[/yellow]"
         )
         # 拷贝数据以防操作炸了
-        edgeEnviroment["strength"] = [edge["strength"]]
-        edgeEnviroment["source"] = source
-        edgeEnviroment["target"] = target
+        edge_environment["strength"] = [edge["strength"]]
+        edge_environment["source"] = source
+        edge_environment["target"] = target
         while True:
while True: while True:
@@ -288,8 +288,8 @@ def alter_mem_edge(hippocampus: Hippocampus):
         except KeyboardInterrupt:
             # 稍微防一下小天才
             try:
-                if isinstance(edgeEnviroment["strength"][0], int):
-                    edge["strength"] = edgeEnviroment["strength"][0]
+                if isinstance(edge_environment["strength"][0], int):
+                    edge["strength"] = edge_environment["strength"][0]
                 else:
                     raise Exception
@@ -301,7 +301,7 @@ def alter_mem_edge(hippocampus: Hippocampus):
                 break
             try:
-                user_exec(command, edgeEnviroment, batchEnviroment)
+                user_exec(command, edge_environment, batch_enviroment)
             except Exception as e:
                 console.print(e)
                 console.print(

View File

@@ -10,7 +10,7 @@ from src.common.logger import get_module_logger
 logger = get_module_logger("offline_llm")
-class LLM_request_off:
+class LLMRequestOff:
     def __init__(self, model_name="deepseek-ai/DeepSeek-V3", **kwargs):
         self.model_name = model_name
         self.params = kwargs

View File

@@ -233,7 +233,8 @@ class MessageServer(BaseMessageHandler):
     async def send_message(self, message: MessageBase):
         await self.broadcast_to_platform(message.message_info.platform, message.to_dict())
-    async def send_message_REST(self, url: str, data: Dict[str, Any]) -> Dict[str, Any]:
+    @staticmethod
+    async def send_message_rest(url: str, data: Dict[str, Any]) -> Dict[str, Any]:
         """发送消息到指定端点"""
         async with aiohttp.ClientSession() as session:
             try:
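The hunk above only shows the new @staticmethod signature and the opening of the aiohttp session; the rest of the body is not part of this diff. A minimal sketch of how such a JSON POST could look, assuming aiohttp and a dict payload (the error-handling shape is a guess):

import aiohttp
from typing import Any, Dict

async def send_message_rest(url: str, data: Dict[str, Any]) -> Dict[str, Any]:
    """发送消息到指定端点(sketch)"""
    async with aiohttp.ClientSession() as session:
        try:
            async with session.post(url, json=data) as resp:
                resp.raise_for_status()
                return await resp.json()
        except aiohttp.ClientError as e:
            # surface transport/HTTP errors to the caller instead of raising
            return {"error": str(e)}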

View File

@@ -16,7 +16,7 @@ from ..config.config import global_config
 logger = get_module_logger("model_utils")
-class LLM_request:
+class LLMRequest:
     # 定义需要转换的模型列表,作为类变量避免重复
     MODELS_NEEDING_TRANSFORMATION = [
         "o3-mini",

View File

@@ -6,7 +6,7 @@ from typing import Any, Callable, Dict
 import datetime
 import asyncio
 import numpy as np
-from src.plugins.models.utils_model import LLM_request
+from src.plugins.models.utils_model import LLMRequest
 from src.plugins.config.config import global_config
 from src.individuality.individuality import Individuality
@@ -56,7 +56,7 @@ person_info_default = {
 class PersonInfoManager:
     def __init__(self):
         self.person_name_list = {}
-        self.qv_name_llm = LLM_request(
+        self.qv_name_llm = LLMRequest(
             model=global_config.llm_normal,
             max_tokens=256,
             request_type="qv_name",
@@ -107,7 +107,7 @@ class PersonInfoManager:
             db.person_info.insert_one(_person_info_default)
-    async def update_one_field(self, person_id: str, field_name: str, value, Data: dict = None):
+    async def update_one_field(self, person_id: str, field_name: str, value, data: dict = None):
         """更新某一个字段,会补全"""
         if field_name not in person_info_default.keys():
             logger.debug(f"更新'{field_name}'失败,未定义的字段")
@@ -118,11 +118,12 @@ class PersonInfoManager:
         if document:
             db.person_info.update_one({"person_id": person_id}, {"$set": {field_name: value}})
         else:
-            Data[field_name] = value
+            data[field_name] = value
             logger.debug(f"更新时{person_id}不存在,已新建")
-            await self.create_person_info(person_id, Data)
+            await self.create_person_info(person_id, data)
-    async def has_one_field(self, person_id: str, field_name: str):
+    @staticmethod
+    async def has_one_field(person_id: str, field_name: str):
         """判断是否存在某一个字段"""
         document = db.person_info.find_one({"person_id": person_id}, {field_name: 1})
         if document:

View File

@@ -38,7 +38,7 @@ else:
     print("将使用默认配置")
-class PersonalityEvaluator_direct:
+class PersonalityEvaluatorDirect:
     def __init__(self):
         self.personality_traits = {"开放性": 0, "严谨性": 0, "外向性": 0, "宜人性": 0, "神经质": 0}
         self.scenarios = []
@@ -135,7 +135,7 @@ def main():
     print("\n准备好了吗?按回车键开始...")
     input()
-    evaluator = PersonalityEvaluator_direct()
+    evaluator = PersonalityEvaluatorDirect()
     final_scores = {"开放性": 0, "严谨性": 0, "外向性": 0, "宜人性": 0, "神经质": 0}
     dimension_counts = {trait: 0 for trait in final_scores.keys()}

View File

@@ -125,12 +125,12 @@ def main():
     if global_config.remote_enable:
         """主函数,启动心跳线程"""
         # 配置
-        SERVER_URL = "http://hyybuth.xyz:10058"
-        # SERVER_URL = "http://localhost:10058"
-        HEARTBEAT_INTERVAL = 300 # 5分钟
+        server_url = "http://hyybuth.xyz:10058"
+        # server_url = "http://localhost:10058"
+        heartbeat_interval = 300 # 5分钟
         # 创建并启动心跳线程
-        heartbeat_thread = HeartbeatThread(SERVER_URL, HEARTBEAT_INTERVAL)
+        heartbeat_thread = HeartbeatThread(server_url, heartbeat_interval)
         heartbeat_thread.start()
         return heartbeat_thread # 返回线程对象,便于外部控制

View File

@@ -11,7 +11,7 @@ sys.path.append(root_path)
 from src.common.database import db # noqa: E402
 from src.common.logger import get_module_logger, SCHEDULE_STYLE_CONFIG, LogConfig # noqa: E402
-from src.plugins.models.utils_model import LLM_request # noqa: E402
+from src.plugins.models.utils_model import LLMRequest # noqa: E402
 from src.plugins.config.config import global_config # noqa: E402
 TIME_ZONE = tz.gettz(global_config.TIME_ZONE) # 设置时区
@@ -30,13 +30,13 @@ class ScheduleGenerator:
     def __init__(self):
         # 使用离线LLM模型
-        self.llm_scheduler_all = LLM_request(
+        self.llm_scheduler_all = LLMRequest(
             model=global_config.llm_reasoning,
             temperature=global_config.SCHEDULE_TEMPERATURE + 0.3,
             max_tokens=7000,
             request_type="schedule",
         )
-        self.llm_scheduler_doing = LLM_request(
+        self.llm_scheduler_doing = LLMRequest(
             model=global_config.llm_normal,
             temperature=global_config.SCHEDULE_TEMPERATURE,
             max_tokens=2048,

View File

@@ -1,7 +1,7 @@
 from typing import List, Optional
-from ..models.utils_model import LLM_request
+from ..models.utils_model import LLMRequest
 from ..config.config import global_config
 from src.common.logger import get_module_logger, LogConfig, TOPIC_STYLE_CONFIG
@@ -17,7 +17,7 @@ logger = get_module_logger("topic_identifier", config=topic_config)
 class TopicIdentifier:
     def __init__(self):
-        self.llm_topic_judge = LLM_request(model=global_config.llm_topic_judge, request_type="topic")
+        self.llm_topic_judge = LLMRequest(model=global_config.llm_topic_judge, request_type="topic")
     async def identify_topic_llm(self, text: str) -> Optional[List[str]]:
         """识别消息主题,返回主题列表"""