Merge branch 'dev' of https://github.com/MaiM-with-u/MaiBot into dev
@@ -5,7 +5,7 @@ MaiBot模块系统
from src.chat.message_receive.chat_stream import chat_manager
from src.chat.emoji_system.emoji_manager import emoji_manager
-from src.chat.person_info.relationship_manager import relationship_manager
+from src.person_info.relationship_manager import relationship_manager
from src.chat.normal_chat.willing.willing_manager import willing_manager

# 导出主要组件供外部使用

@@ -12,11 +12,11 @@ import re

# from gradio_client import file

-from ...common.database.database_model import Emoji
-from ...common.database.database import db as peewee_db
-from ...config.config import global_config
-from ..utils.utils_image import image_path_to_base64, image_manager
-from ..models.utils_model import LLMRequest
+from src.common.database.database_model import Emoji
+from src.common.database.database import db as peewee_db
+from src.config.config import global_config
+from src.chat.utils.utils_image import image_path_to_base64, image_manager
+from src.llm_models.utils_model import LLMRequest
from src.common.logger_manager import get_logger
from rich.traceback import install

@@ -5,7 +5,7 @@ from src.chat.message_receive.message import Seg  # Local import needed after mo
from src.chat.message_receive.message import UserInfo
from src.chat.message_receive.chat_stream import chat_manager
from src.common.logger_manager import get_logger
-from src.chat.models.utils_model import LLMRequest
+from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.chat.utils.utils_image import image_path_to_base64  # Local import needed after move
from src.chat.utils.timer_calculator import Timer  # <--- Import Timer

@@ -2,7 +2,7 @@ import time
import random
from typing import List, Dict, Optional, Any, Tuple
from src.common.logger_manager import get_logger
-from src.chat.models.utils_model import LLMRequest
+from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_random, build_anonymous_messages
from src.chat.focus_chat.heartflow_prompt_builder import Prompt, global_prompt_manager

@@ -425,7 +425,10 @@ class HeartFChatting:
        self.all_observations = observations

+        with Timer("回忆", cycle_timers):
+            logger.debug(f"{self.log_prefix} 开始回忆")
+            running_memorys = await self.memory_activator.activate_memory(observations)
+            logger.debug(f"{self.log_prefix} 回忆完成")
+            print(running_memorys)

        with Timer("执行 信息处理器", cycle_timers):
            all_plan_info = await self._process_processors(observations, running_memorys, cycle_timers)

@@ -11,7 +11,7 @@ from ..message_receive.chat_stream import chat_manager

# from ..message_receive.message_buffer import message_buffer
from ..utils.timer_calculator import Timer
-from src.chat.person_info.relationship_manager import relationship_manager
+from src.person_info.relationship_manager import relationship_manager
from typing import Optional, Tuple, Dict, Any

logger = get_logger("chat")

@@ -3,7 +3,7 @@ from src.common.logger_manager import get_logger
from src.individuality.individuality import individuality
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
-from src.chat.person_info.relationship_manager import relationship_manager
+from src.person_info.relationship_manager import relationship_manager
import time
from typing import Optional
from src.chat.utils.utils import get_recent_group_speaker

@@ -37,4 +37,4 @@ class SelfInfo(InfoBase):
        Returns:
            str: 处理后的信息
        """
-        return self.get_self_info()
+        return self.get_self_info() or ""

@@ -67,3 +67,16 @@ class StructuredInfo:
            value: 要设置的属性值
        """
        self.data[key] = value
+
+    def get_processed_info(self) -> str:
+        """获取处理后的信息
+
+        Returns:
+            str: 处理后的信息字符串
+        """
+
+        info_str = ""
+        for key, value in self.data.items():
+            info_str += f"信息类型:{key},信息内容:{value}\n"
+
+        return info_str

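For illustration, with two hypothetical entries in `self.data`, `get_processed_info()` yields one line per entry:

# Assuming self.data == {"time": "2024-01-01", "topic": "greetings"} (hypothetical values),
# the method above returns:
#   信息类型:time,信息内容:2024-01-01
#   信息类型:topic,信息内容:greetings
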
@@ -8,7 +8,7 @@ from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservati
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
from src.chat.message_receive.chat_stream import ChatStream, chat_manager
from typing import Dict
-from src.chat.models.utils_model import LLMRequest
+from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
import random

@@ -9,7 +9,7 @@ from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservati
from src.chat.focus_chat.info.cycle_info import CycleInfo
from datetime import datetime
from typing import Dict
-from src.chat.models.utils_model import LLMRequest
+from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config

logger = get_logger("processor")

@@ -1,6 +1,6 @@
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
from src.chat.heart_flow.observation.observation import Observation
-from src.chat.models.utils_model import LLMRequest
+from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
import time
import traceback

@@ -9,7 +9,7 @@ from src.individuality.individuality import individuality
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.json_utils import safe_json_dumps
from src.chat.message_receive.chat_stream import chat_manager
-from src.chat.person_info.relationship_manager import relationship_manager
+from src.person_info.relationship_manager import relationship_manager
from .base_processor import BaseProcessor
from src.chat.focus_chat.info.mind_info import MindInfo
from typing import List, Optional

@@ -1,6 +1,6 @@
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
from src.chat.heart_flow.observation.observation import Observation
-from src.chat.models.utils_model import LLMRequest
+from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
import time
import traceback

@@ -8,7 +8,7 @@ from src.common.logger_manager import get_logger
from src.individuality.individuality import individuality
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import chat_manager
-from src.chat.person_info.relationship_manager import relationship_manager
+from src.person_info.relationship_manager import relationship_manager
from .base_processor import BaseProcessor
from typing import List, Optional
from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation

@@ -33,12 +33,13 @@ def init_prompt():

现在请你根据现有的信息,思考自我认同
1. 你是一个什么样的人,你和群里的人关系如何
-2. 思考有没有人提到你,或者图片与你有关
-3. 你的自我认同是否有助于你的回答,如果你需要自我相关的信息来帮你参与聊天,请输出,否则请输出十个字以内的简短自我认同
-4. 一般情况下不用输出自我认同,只需要输出十几个字的简短自我认同就好,除非有明显需要自我认同的场景
+2. 你的形象是什么
+3. 思考有没有人提到你,或者图片与你有关
+4. 你的自我认同是否有助于你的回答,如果你需要自我相关的信息来帮你参与聊天,请输出,否则请输出十几个字的简短自我认同
+5. 一般情况下不用输出自我认同,只需要输出十几个字的简短自我认同就好,除非有明显需要自我认同的场景

-请思考的平淡一些,简短一些,说中文,不要浮夸,平淡一些。
-请注意不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出自我认同内容。
+输出内容平淡一些,说中文,不要浮夸,平淡一些。
+请注意不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出自我认同内容,记得明确说明这是你的自我认同。

"""
    Prompt(indentify_prompt, "indentify_prompt")

@@ -1,5 +1,5 @@
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
-from src.chat.models.utils_model import LLMRequest
+from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
import time
from src.common.logger_manager import get_logger

@@ -7,7 +7,7 @@ from src.individuality.individuality import individuality
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.tools.tool_use import ToolUser
from src.chat.utils.json_utils import process_llm_tool_calls
-from src.chat.person_info.relationship_manager import relationship_manager
+from src.person_info.relationship_manager import relationship_manager
from .base_processor import BaseProcessor
from typing import List, Optional, Dict
from src.chat.heart_flow.observation.observation import Observation

@@ -1,6 +1,6 @@
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
from src.chat.heart_flow.observation.observation import Observation
-from src.chat.models.utils_model import LLMRequest
+from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
import time
import traceback

@@ -1,7 +1,7 @@
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
from src.chat.heart_flow.observation.structure_observation import StructureObservation
from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation
-from src.chat.models.utils_model import LLMRequest
+from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.common.logger_manager import get_logger
from src.chat.utils.prompt_builder import Prompt

@@ -61,6 +61,8 @@ class MemoryActivator:
            elif isinstance(observation, HFCloopObservation):
                obs_info_text += observation.get_observe_info()

+        logger.debug(f"回忆待检索内容:obs_info_text: {obs_info_text}")
+
        # prompt = await global_prompt_manager.format_prompt(
        #     "memory_activator_prompt",
        #     obs_info_text=obs_info_text,

@@ -81,7 +83,7 @@ class MemoryActivator:
        #     valid_keywords=keywords, max_memory_num=3, max_memory_length=2, max_depth=3
        # )
        related_memory = await HippocampusManager.get_instance().get_memory_from_text(
-            text=obs_info_text, max_memory_num=3, max_memory_length=2, max_depth=3, fast_retrieval=True
+            text=obs_info_text, max_memory_num=5, max_memory_length=2, max_depth=3, fast_retrieval=True
        )

        # logger.debug(f"获取到的记忆: {related_memory}")

@@ -1,11 +1,14 @@
import traceback
from typing import Tuple, Dict, List, Any, Optional
-from src.chat.focus_chat.planners.actions.base_action import BaseAction
+from src.chat.focus_chat.planners.actions.base_action import BaseAction, register_action  # noqa F401
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
from src.chat.focus_chat.hfc_utils import create_empty_anchor_message
from src.common.logger_manager import get_logger
-from src.chat.person_info.person_info import person_info_manager
+from src.person_info.person_info import person_info_manager
from abc import abstractmethod
+import os
+import inspect
+import toml  # 导入 toml 库

logger = get_logger("plugin_action")

@@ -16,12 +19,24 @@ class PluginAction(BaseAction):
    封装了主程序内部依赖,提供简化的API接口给插件开发者
    """

-    def __init__(self, action_data: dict, reasoning: str, cycle_timers: dict, thinking_id: str, **kwargs):
+    action_config_file_name: Optional[str] = None  # 插件可以覆盖此属性来指定配置文件名
+
+    def __init__(
+        self,
+        action_data: dict,
+        reasoning: str,
+        cycle_timers: dict,
+        thinking_id: str,
+        global_config: Optional[dict] = None,
+        **kwargs,
+    ):
        """初始化插件动作基类"""
        super().__init__(action_data, reasoning, cycle_timers, thinking_id)

        # 存储内部服务和对象引用
        self._services = {}
+        self._global_config = global_config  # 存储全局配置的只读引用
+        self.config: Dict[str, Any] = {}  # 用于存储插件自身的配置

        # 从kwargs提取必要的内部服务
        if "observations" in kwargs:
@@ -32,6 +47,61 @@ class PluginAction(BaseAction):
            self._services["chat_stream"] = kwargs["chat_stream"]

        self.log_prefix = kwargs.get("log_prefix", "")
+        self._load_plugin_config()  # 初始化时加载插件配置
+
+    def _load_plugin_config(self):
+        """
+        加载插件自身的配置文件。
+        配置文件应与插件模块在同一目录下。
+        插件可以通过覆盖 `action_config_file_name` 类属性来指定文件名。
+        如果 `action_config_file_name` 未指定,则不加载配置。
+        仅支持 TOML (.toml) 格式。
+        """
+        if not self.action_config_file_name:
+            logger.debug(
+                f"{self.log_prefix} 插件 {self.__class__.__name__} 未指定 action_config_file_name,不加载插件配置。"
+            )
+            return
+
+        try:
+            plugin_module_path = inspect.getfile(self.__class__)
+            plugin_dir = os.path.dirname(plugin_module_path)
+            config_file_path = os.path.join(plugin_dir, self.action_config_file_name)
+
+            if not os.path.exists(config_file_path):
+                logger.warning(
+                    f"{self.log_prefix} 插件 {self.__class__.__name__} 的配置文件 {config_file_path} 不存在。"
+                )
+                return
+
+            file_ext = os.path.splitext(self.action_config_file_name)[1].lower()
+
+            if file_ext == ".toml":
+                with open(config_file_path, "r", encoding="utf-8") as f:
+                    self.config = toml.load(f) or {}
+                logger.info(f"{self.log_prefix} 插件 {self.__class__.__name__} 的配置已从 {config_file_path} 加载。")
+            else:
+                logger.warning(
+                    f"{self.log_prefix} 不支持的插件配置文件格式: {file_ext}。仅支持 .toml。插件配置未加载。"
+                )
+                self.config = {}  # 确保未加载时为空字典
+                return
+
+        except Exception as e:
+            logger.error(
+                f"{self.log_prefix} 加载插件 {self.__class__.__name__} 的配置文件 {self.action_config_file_name} 时出错: {e}"
+            )
+            self.config = {}  # 出错时确保 config 是一个空字典
+
+    def get_global_config(self, key: str, default: Any = None) -> Any:
+        """
+        安全地从全局配置中获取一个值。
+        插件应使用此方法读取全局配置,以保证只读和隔离性。
+        """
+        if self._global_config:
+            return self._global_config.get(key, default)
+        logger.debug(f"{self.log_prefix} 尝试访问全局配置项 '{key}',但全局配置未提供。")
+        return default

    async def get_user_id_by_person_name(self, person_name: str) -> Tuple[str, str]:
        """根据用户名获取用户ID"""

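A minimal sketch of how a plugin might hook into the config machinery added above; the class name, file name, config keys, and the entry-point method name are all hypothetical, not from the repo:

# Hypothetical plugin action; "greet_config.toml" would sit next to this plugin's module.
class GreetAction(PluginAction):
    action_config_file_name = "greet_config.toml"  # loaded into self.config at init

    async def run(self):  # illustrative entry point, not the real PluginAction API
        greeting = self.config.get("greeting", "hello")       # plugin-local TOML config
        bot_name = self.get_global_config("nickname", "bot")  # read-only global config
        return f"{greeting}, I am {bot_name}"
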
@@ -2,7 +2,7 @@ import json  # <--- 确保导入 json
import traceback
from typing import List, Dict, Any, Optional
from rich.traceback import install
-from src.chat.models.utils_model import LLMRequest
+from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.chat.focus_chat.info.info_base import InfoBase
from src.chat.focus_chat.info.obs_info import ObsInfo

@@ -10,6 +10,7 @@ from src.chat.focus_chat.info.cycle_info import CycleInfo
from src.chat.focus_chat.info.mind_info import MindInfo
from src.chat.focus_chat.info.action_info import ActionInfo
from src.chat.focus_chat.info.structured_info import StructuredInfo
+from src.chat.focus_chat.info.self_info import SelfInfo
from src.common.logger_manager import get_logger
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.individuality.individuality import individuality

@@ -22,7 +23,11 @@ install(extra_lines=3)


def init_prompt():
    Prompt(
-        """{extra_info_block}
+        """
+你的自我认知是:
+{self_info_block}
+
+{extra_info_block}
+
你需要基于以下信息决定如何参与对话
这些信息可能会有冲突,请你整合这些信息,并选择一个最合适的action:

@@ -127,6 +132,8 @@ class ActionPlanner:
                current_mind = info.get_current_mind()
            elif isinstance(info, CycleInfo):
                cycle_info = info.get_observe_info()
+            elif isinstance(info, SelfInfo):
+                self_info = info.get_processed_info()
            elif isinstance(info, StructuredInfo):
                _structured_info = info.get_data()
            elif not isinstance(info, ActionInfo):  # 跳过已处理的ActionInfo

@@ -148,6 +155,7 @@ class ActionPlanner:

        # --- 构建提示词 (调用修改后的 PromptBuilder 方法) ---
        prompt = await self.build_planner_prompt(
+            self_info_block=self_info,
            is_group_chat=is_group_chat,  # <-- Pass HFC state
            chat_target_info=None,
            observed_messages_str=observed_messages_str,  # <-- Pass local variable

@@ -236,6 +244,7 @@ class ActionPlanner:

    async def build_planner_prompt(
        self,
+        self_info_block: str,
        is_group_chat: bool,  # Now passed as argument
        chat_target_info: Optional[dict],  # Now passed as argument
        observed_messages_str: str,

@@ -301,7 +310,8 @@ class ActionPlanner:

        planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt")
        prompt = planner_prompt_template.format(
-            bot_name=global_config.bot.nickname,
+            self_info_block=self_info_block,
+            # bot_name=global_config.bot.nickname,
            prompt_personality=personality_block,
            chat_context_description=chat_context_description,
            chat_content_block=chat_content_block,

@@ -3,7 +3,7 @@ import traceback
from json_repair import repair_json
from rich.traceback import install
from src.common.logger_manager import get_logger
-from src.chat.models.utils_model import LLMRequest
+from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.chat.focus_chat.working_memory.memory_item import MemoryItem
import json  # 添加json模块导入

@@ -1,5 +1,5 @@
from datetime import datetime
-from src.chat.models.utils_model import LLMRequest
+from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
import traceback
from src.chat.utils.chat_message_builder import (

@@ -88,5 +88,6 @@ class HFCloopObservation:
        for action_name, action_info in using_actions.items():
            action_description = action_info["description"]
            cycle_info_block += f"\n你在聊天中可以使用{action_name},这个动作的描述是{action_description}\n"
+        cycle_info_block += "注意,除了上述动作选项之外,你在群聊里不能做其他任何事情,这是你能力的边界\n"

        self.observe_info = cycle_info_block

@@ -2,7 +2,7 @@ import asyncio
from typing import Optional, Tuple, Dict
from src.common.logger_manager import get_logger
from src.chat.message_receive.chat_stream import chat_manager
-from src.chat.person_info.person_info import person_info_manager
+from src.person_info.person_info import person_info_manager

logger = get_logger("heartflow_utils")

@@ -11,7 +11,7 @@ import jieba
import networkx as nx
import numpy as np
from collections import Counter
-from ...chat.models.utils_model import LLMRequest
+from ...llm_models.utils_model import LLMRequest
from src.common.logger_manager import get_logger
from src.chat.memory_system.sample_distribution import MemoryBuildScheduler  # 分布生成器
from ..utils.chat_message_builder import (

@@ -338,7 +338,8 @@ class Hippocampus:
            # 去重
            keywords = list(set(keywords))
            # 限制关键词数量
+            keywords = keywords[:5]
            logger.debug(f"提取关键词: {keywords}")

        else:
            # 使用LLM提取关键词
            topic_num = min(5, max(1, int(len(text) * 0.1)))  # 根据文本长度动态调整关键词数量

@@ -361,7 +362,7 @@ class Hippocampus:
        # 过滤掉不存在于记忆图中的关键词
        valid_keywords = [keyword for keyword in keywords if keyword in self.memory_graph.G]
        if not valid_keywords:
-            # logger.info("没有找到有效的关键词节点")
+            logger.info("没有找到有效的关键词节点")
            return []

        logger.debug(f"有效的关键词: {', '.join(valid_keywords)}")

@@ -1,8 +1,8 @@
-from ..emoji_system.emoji_manager import emoji_manager
-from ..person_info.relationship_manager import relationship_manager
-from .chat_stream import chat_manager
-from .message_sender import message_manager
-from .storage import MessageStorage
+from src.chat.emoji_system.emoji_manager import emoji_manager
+from src.person_info.relationship_manager import relationship_manager
+from src.chat.message_receive.chat_stream import chat_manager
+from src.chat.message_receive.message_sender import message_manager
+from src.chat.message_receive.storage import MessageStorage


__all__ = [

@@ -1,886 +0,0 @@
import asyncio
import json
import re
from datetime import datetime
from typing import Tuple, Union, Dict, Any

import aiohttp
from aiohttp.client import ClientResponse

from src.common.logger import get_module_logger
import base64
from PIL import Image
import io
import os
from src.common.database.database import db  # 确保 db 被导入用于 create_tables
from src.common.database.database_model import LLMUsage  # 导入 LLMUsage 模型
from ...config.config import global_config
from rich.traceback import install

install(extra_lines=3)

logger = get_module_logger("model_utils")

class PayLoadTooLargeError(Exception):
    """自定义异常类,用于处理请求体过大错误"""

    def __init__(self, message: str):
        super().__init__(message)
        self.message = message

    def __str__(self):
        return "请求体过大,请尝试压缩图片或减少输入内容。"


class RequestAbortException(Exception):
    """自定义异常类,用于处理请求中断异常"""

    def __init__(self, message: str, response: ClientResponse):
        super().__init__(message)
        self.message = message
        self.response = response

    def __str__(self):
        return self.message


class PermissionDeniedException(Exception):
    """自定义异常类,用于处理访问拒绝的异常"""

    def __init__(self, message: str):
        super().__init__(message)
        self.message = message

    def __str__(self):
        return self.message


# 常见Error Code Mapping
error_code_mapping = {
    400: "参数不正确",
    401: "API key 错误,认证失败,请检查/config/bot_config.toml和.env中的配置是否正确哦~",
    402: "账号余额不足",
    403: "需要实名,或余额不足",
    404: "Not Found",
    429: "请求过于频繁,请稍后再试",
    500: "服务器内部故障",
    503: "服务器负载过高",
}

async def _safely_record(request_content: Dict[str, Any], payload: Dict[str, Any]):
    image_base64: str = request_content.get("image_base64")
    image_format: str = request_content.get("image_format")
    if (
        image_base64
        and payload
        and isinstance(payload, dict)
        and "messages" in payload
        and len(payload["messages"]) > 0
    ):
        if isinstance(payload["messages"][0], dict) and "content" in payload["messages"][0]:
            content = payload["messages"][0]["content"]
            if isinstance(content, list) and len(content) > 1 and "image_url" in content[1]:
                payload["messages"][0]["content"][1]["image_url"]["url"] = (
                    f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,"
                    f"{image_base64[:10]}...{image_base64[-10:]}"
                )
    return payload

class LLMRequest:
    # 定义需要转换的模型列表,作为类变量避免重复
    MODELS_NEEDING_TRANSFORMATION = [
        "o1",
        "o1-2024-12-17",
        "o1-mini",
        "o1-mini-2024-09-12",
        "o1-preview",
        "o1-preview-2024-09-12",
        "o1-pro",
        "o1-pro-2025-03-19",
        "o3",
        "o3-2025-04-16",
        "o3-mini",
        "o3-mini-2025-01-31",
        "o4-mini",
        "o4-mini-2025-04-16",
    ]

    def __init__(self, model: dict, **kwargs):
        # 将大写的配置键转换为小写并从config中获取实际值
        try:
            self.api_key = os.environ[f"{model['provider']}_KEY"]
            self.base_url = os.environ[f"{model['provider']}_BASE_URL"]
        except AttributeError as e:
            logger.error(f"原始 model dict 信息:{model}")
            logger.error(f"配置错误:找不到对应的配置项 - {str(e)}")
            raise ValueError(f"配置错误:找不到对应的配置项 - {str(e)}") from e
        self.model_name: str = model["name"]
        self.params = kwargs

        self.stream = model.get("stream", False)
        self.pri_in = model.get("pri_in", 0)
        self.pri_out = model.get("pri_out", 0)

        # 获取数据库实例
        self._init_database()

        # 从 kwargs 中提取 request_type,如果没有提供则默认为 "default"
        self.request_type = kwargs.pop("request_type", "default")

    @staticmethod
    def _init_database():
        """初始化数据库集合"""
        try:
            # 使用 Peewee 创建表,safe=True 表示如果表已存在则不会抛出错误
            db.create_tables([LLMUsage], safe=True)
            logger.debug("LLMUsage 表已初始化/确保存在。")
        except Exception as e:
            logger.error(f"创建 LLMUsage 表失败: {str(e)}")

    def _record_usage(
        self,
        prompt_tokens: int,
        completion_tokens: int,
        total_tokens: int,
        user_id: str = "system",
        request_type: str = None,
        endpoint: str = "/chat/completions",
    ):
        """记录模型使用情况到数据库
        Args:
            prompt_tokens: 输入token数
            completion_tokens: 输出token数
            total_tokens: 总token数
            user_id: 用户ID,默认为system
            request_type: 请求类型
            endpoint: API端点
        """
        # 如果 request_type 为 None,则使用实例变量中的值
        if request_type is None:
            request_type = self.request_type

        try:
            # 使用 Peewee 模型创建记录
            LLMUsage.create(
                model_name=self.model_name,
                user_id=user_id,
                request_type=request_type,
                endpoint=endpoint,
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=total_tokens,
                cost=self._calculate_cost(prompt_tokens, completion_tokens),
                status="success",
                timestamp=datetime.now(),  # Peewee 会处理 DateTimeField
            )
            logger.trace(
                f"Token使用情况 - 模型: {self.model_name}, "
                f"用户: {user_id}, 类型: {request_type}, "
                f"提示词: {prompt_tokens}, 完成: {completion_tokens}, "
                f"总计: {total_tokens}"
            )
        except Exception as e:
            logger.error(f"记录token使用情况失败: {str(e)}")

    def _calculate_cost(self, prompt_tokens: int, completion_tokens: int) -> float:
        """计算API调用成本
        使用模型的pri_in和pri_out价格计算输入和输出的成本

        Args:
            prompt_tokens: 输入token数量
            completion_tokens: 输出token数量

        Returns:
            float: 总成本(元)
        """
        # 使用模型的pri_in和pri_out计算成本
        input_cost = (prompt_tokens / 1000000) * self.pri_in
        output_cost = (completion_tokens / 1000000) * self.pri_out
        return round(input_cost + output_cost, 6)

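A quick worked example of the cost formula above, with hypothetical per-million-token prices:

    # Assume pri_in = 2.0 (yuan per 1M prompt tokens) and pri_out = 8.0 (yuan per 1M completion tokens).
    # For 1,500 prompt tokens and 500 completion tokens:
    #   input_cost  = (1500 / 1000000) * 2.0 = 0.003
    #   output_cost = (500 / 1000000) * 8.0  = 0.004
    #   total       = round(0.003 + 0.004, 6) = 0.007 yuan
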
    async def _prepare_request(
        self,
        endpoint: str,
        prompt: str = None,
        image_base64: str = None,
        image_format: str = None,
        payload: dict = None,
        retry_policy: dict = None,
    ) -> Dict[str, Any]:
        """配置请求参数
        Args:
            endpoint: API端点路径 (如 "chat/completions")
            prompt: prompt文本
            image_base64: 图片的base64编码
            image_format: 图片格式
            payload: 请求体数据
            retry_policy: 自定义重试策略
            request_type: 请求类型
        """

        # 合并重试策略
        default_retry = {
            "max_retries": 3,
            "base_wait": 10,
            "retry_codes": [429, 413, 500, 503],
            "abort_codes": [400, 401, 402, 403],
        }
        policy = {**default_retry, **(retry_policy or {})}

        api_url = f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}"

        stream_mode = self.stream

        # 构建请求体
        if image_base64:
            payload = await self._build_payload(prompt, image_base64, image_format)
        elif payload is None:
            payload = await self._build_payload(prompt)

        if stream_mode:
            payload["stream"] = stream_mode

        return {
            "policy": policy,
            "payload": payload,
            "api_url": api_url,
            "stream_mode": stream_mode,
            "image_base64": image_base64,  # 保留必要的exception处理所需的原始数据
            "image_format": image_format,
            "prompt": prompt,
        }

    async def _execute_request(
        self,
        endpoint: str,
        prompt: str = None,
        image_base64: str = None,
        image_format: str = None,
        payload: dict = None,
        retry_policy: dict = None,
        response_handler: callable = None,
        user_id: str = "system",
        request_type: str = None,
    ):
        """统一请求执行入口
        Args:
            endpoint: API端点路径 (如 "chat/completions")
            prompt: prompt文本
            image_base64: 图片的base64编码
            image_format: 图片格式
            payload: 请求体数据
            retry_policy: 自定义重试策略
            response_handler: 自定义响应处理器
            user_id: 用户ID
            request_type: 请求类型
        """
        # 获取请求配置
        request_content = await self._prepare_request(
            endpoint, prompt, image_base64, image_format, payload, retry_policy
        )
        if request_type is None:
            request_type = self.request_type
        for retry in range(request_content["policy"]["max_retries"]):
            try:
                # 使用上下文管理器处理会话
                headers = await self._build_headers()
                # 似乎是openai流式必须要的东西,不过阿里云的qwq-plus加了这个没有影响
                if request_content["stream_mode"]:
                    headers["Accept"] = "text/event-stream"
                async with aiohttp.ClientSession() as session:
                    async with session.post(
                        request_content["api_url"], headers=headers, json=request_content["payload"]
                    ) as response:
                        handled_result = await self._handle_response(
                            response, request_content, retry, response_handler, user_id, request_type, endpoint
                        )
                        return handled_result
            except Exception as e:
                handled_payload, count_delta = await self._handle_exception(e, retry, request_content)
                retry += count_delta  # 降级不计入重试次数
                if handled_payload:
                    # 如果降级成功,重新构建请求体
                    request_content["payload"] = handled_payload
                continue

        logger.error(f"模型 {self.model_name} 达到最大重试次数,请求仍然失败")
        raise RuntimeError(f"模型 {self.model_name} 达到最大重试次数,API请求仍然失败")

    async def _handle_response(
        self,
        response: ClientResponse,
        request_content: Dict[str, Any],
        retry_count: int,
        response_handler: callable,
        user_id,
        request_type,
        endpoint,
    ) -> Union[Dict[str, Any], None]:
        policy = request_content["policy"]
        stream_mode = request_content["stream_mode"]
        if response.status in policy["retry_codes"] or response.status in policy["abort_codes"]:
            await self._handle_error_response(response, retry_count, policy)
            return None

        response.raise_for_status()
        result = {}
        if stream_mode:
            # 将流式输出转化为非流式输出
            result = await self._handle_stream_output(response)
        else:
            result = await response.json()
        return (
            response_handler(result)
            if response_handler
            else self._default_response_handler(result, user_id, request_type, endpoint)
        )

    async def _handle_stream_output(self, response: ClientResponse) -> Dict[str, Any]:
        flag_delta_content_finished = False
        accumulated_content = ""
        usage = None  # 初始化usage变量,避免未定义错误
        reasoning_content = ""
        content = ""
        tool_calls = None  # 初始化工具调用变量

        async for line_bytes in response.content:
            try:
                line = line_bytes.decode("utf-8").strip()
                if not line:
                    continue
                if line.startswith("data:"):
                    data_str = line[5:].strip()
                    if data_str == "[DONE]":
                        break
                    try:
                        chunk = json.loads(data_str)
                        if flag_delta_content_finished:
                            chunk_usage = chunk.get("usage", None)
                            if chunk_usage:
                                usage = chunk_usage  # 获取token用量
                        else:
                            delta = chunk["choices"][0]["delta"]
                            delta_content = delta.get("content")
                            if delta_content is None:
                                delta_content = ""
                            accumulated_content += delta_content

                            # 提取工具调用信息
                            if "tool_calls" in delta:
                                if tool_calls is None:
                                    tool_calls = delta["tool_calls"]
                                else:
                                    # 合并工具调用信息
                                    tool_calls.extend(delta["tool_calls"])

                            # 检测流式输出文本是否结束
                            finish_reason = chunk["choices"][0].get("finish_reason")
                            if delta.get("reasoning_content", None):
                                reasoning_content += delta["reasoning_content"]
                            if finish_reason == "stop" or finish_reason == "tool_calls":
                                chunk_usage = chunk.get("usage", None)
                                if chunk_usage:
                                    usage = chunk_usage
                                    break
                                # 部分平台在文本输出结束前不会返回token用量,此时需要再获取一次chunk
                                flag_delta_content_finished = True
                    except Exception as e:
                        logger.exception(f"模型 {self.model_name} 解析流式输出错误: {str(e)}")
            except Exception as e:
                if isinstance(e, GeneratorExit):
                    log_content = f"模型 {self.model_name} 流式输出被中断,正在清理资源..."
                else:
                    log_content = f"模型 {self.model_name} 处理流式输出时发生错误: {str(e)}"
                logger.warning(log_content)
                # 确保资源被正确清理
                try:
                    await response.release()
                except Exception as cleanup_error:
                    logger.error(f"清理资源时发生错误: {cleanup_error}")
                # 返回已经累积的内容
                content = accumulated_content
        if not content:
            content = accumulated_content
        think_match = re.search(r"<think>(.*?)</think>", content, re.DOTALL)
        if think_match:
            reasoning_content = think_match.group(1).strip()
            content = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL).strip()

        # 构建消息对象
        message = {
            "content": content,
            "reasoning_content": reasoning_content,
        }

        # 如果有工具调用,添加到消息中
        if tool_calls:
            message["tool_calls"] = tool_calls

        result = {
            "choices": [{"message": message}],
            "usage": usage,
        }
        return result

    async def _handle_error_response(
        self, response: ClientResponse, retry_count: int, policy: Dict[str, Any]
    ) -> Union[Dict[str, any]]:
        if response.status in policy["retry_codes"]:
            wait_time = policy["base_wait"] * (2**retry_count)
            logger.warning(f"模型 {self.model_name} 错误码: {response.status}, 等待 {wait_time}秒后重试")
            if response.status == 413:
                logger.warning("请求体过大,尝试压缩...")
                raise PayLoadTooLargeError("请求体过大")
            elif response.status in [500, 503]:
                logger.error(
                    f"模型 {self.model_name} 错误码: {response.status} - {error_code_mapping.get(response.status)}"
                )
                raise RuntimeError("服务器负载过高,模型恢复失败QAQ")
            else:
                logger.warning(f"模型 {self.model_name} 请求限制(429),等待{wait_time}秒后重试...")
                raise RuntimeError("请求限制(429)")
        elif response.status in policy["abort_codes"]:
            if response.status != 403:
                raise RequestAbortException("请求出现错误,中断处理", response)
            else:
                raise PermissionDeniedException("模型禁止访问")

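The wait before each retry grows exponentially from `base_wait`; with the default policy this gives:

    # wait_time = base_wait * (2 ** retry_count), so with base_wait = 10 and max_retries = 3:
    #   retry_count 0 -> 10 s, retry_count 1 -> 20 s, retry_count 2 -> 40 s
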
    async def _handle_exception(
        self, exception, retry_count: int, request_content: Dict[str, Any]
    ) -> Union[Tuple[Dict[str, Any], int], Tuple[None, int]]:
        policy = request_content["policy"]
        payload = request_content["payload"]
        wait_time = policy["base_wait"] * (2**retry_count)
        keep_request = False
        if retry_count < policy["max_retries"] - 1:
            keep_request = True
        if isinstance(exception, RequestAbortException):
            response = exception.response
            logger.error(
                f"模型 {self.model_name} 错误码: {response.status} - {error_code_mapping.get(response.status)}"
            )
            # 尝试获取并记录服务器返回的详细错误信息
            try:
                error_json = await response.json()
                if error_json and isinstance(error_json, list) and len(error_json) > 0:
                    # 处理多个错误的情况
                    for error_item in error_json:
                        if "error" in error_item and isinstance(error_item["error"], dict):
                            error_obj: dict = error_item["error"]
                            error_code = error_obj.get("code")
                            error_message = error_obj.get("message")
                            error_status = error_obj.get("status")
                            logger.error(
                                f"服务器错误详情: 代码={error_code}, 状态={error_status}, 消息={error_message}"
                            )
                elif isinstance(error_json, dict) and "error" in error_json:
                    # 处理单个错误对象的情况
                    error_obj = error_json.get("error", {})
                    error_code = error_obj.get("code")
                    error_message = error_obj.get("message")
                    error_status = error_obj.get("status")
                    logger.error(f"服务器错误详情: 代码={error_code}, 状态={error_status}, 消息={error_message}")
                else:
                    # 记录原始错误响应内容
                    logger.error(f"服务器错误响应: {error_json}")
            except Exception as e:
                logger.warning(f"无法解析服务器错误响应: {str(e)}")
            raise RuntimeError(f"请求被拒绝: {error_code_mapping.get(response.status)}")

        elif isinstance(exception, PermissionDeniedException):
            # 只针对硅基流动的V3和R1进行降级处理
            if self.model_name.startswith("Pro/deepseek-ai") and self.base_url == "https://api.siliconflow.cn/v1/":
                old_model_name = self.model_name
                self.model_name = self.model_name[4:]  # 移除"Pro/"前缀
                logger.warning(f"检测到403错误,模型从 {old_model_name} 降级为 {self.model_name}")

                # 对全局配置进行更新
                if global_config.model.normal.get("name") == old_model_name:
                    global_config.model.normal["name"] = self.model_name
                    logger.warning(f"将全局配置中的 llm_normal 模型临时降级至{self.model_name}")
                if global_config.model.reasoning.get("name") == old_model_name:
                    global_config.model.reasoning["name"] = self.model_name
                    logger.warning(f"将全局配置中的 llm_reasoning 模型临时降级至{self.model_name}")

                if payload and "model" in payload:
                    payload["model"] = self.model_name

                await asyncio.sleep(wait_time)
                return payload, -1
            raise RuntimeError(f"请求被拒绝: {error_code_mapping.get(403)}")

        elif isinstance(exception, PayLoadTooLargeError):
            if keep_request:
                image_base64 = request_content["image_base64"]
                compressed_image_base64 = compress_base64_image_by_scale(image_base64)
                new_payload = await self._build_payload(
                    request_content["prompt"], compressed_image_base64, request_content["image_format"]
                )
                return new_payload, 0
            else:
                return None, 0

        elif isinstance(exception, aiohttp.ClientError) or isinstance(exception, asyncio.TimeoutError):
            if keep_request:
                logger.error(f"模型 {self.model_name} 网络错误,等待{wait_time}秒后重试... 错误: {str(exception)}")
                await asyncio.sleep(wait_time)
                return None, 0
            else:
                logger.critical(f"模型 {self.model_name} 网络错误达到最大重试次数: {str(exception)}")
                raise RuntimeError(f"网络请求失败: {str(exception)}")

        elif isinstance(exception, aiohttp.ClientResponseError):
            # 处理aiohttp抛出的,除了policy中的status的响应错误
            if keep_request:
                logger.error(
                    f"模型 {self.model_name} HTTP响应错误,等待{wait_time}秒后重试... 状态码: {exception.status}, 错误: {exception.message}"
                )
                try:
                    error_text = await exception.response.text()
                    error_json = json.loads(error_text)
                    if isinstance(error_json, list) and len(error_json) > 0:
                        # 处理多个错误的情况
                        for error_item in error_json:
                            if "error" in error_item and isinstance(error_item["error"], dict):
                                error_obj = error_item["error"]
                                logger.error(
                                    f"模型 {self.model_name} 服务器错误详情: 代码={error_obj.get('code')}, "
                                    f"状态={error_obj.get('status')}, "
                                    f"消息={error_obj.get('message')}"
                                )
                    elif isinstance(error_json, dict) and "error" in error_json:
                        error_obj = error_json.get("error", {})
                        logger.error(
                            f"模型 {self.model_name} 服务器错误详情: 代码={error_obj.get('code')}, "
                            f"状态={error_obj.get('status')}, "
                            f"消息={error_obj.get('message')}"
                        )
                    else:
                        logger.error(f"模型 {self.model_name} 服务器错误响应: {error_json}")
                except (json.JSONDecodeError, TypeError) as json_err:
                    logger.warning(
                        f"模型 {self.model_name} 响应不是有效的JSON: {str(json_err)}, 原始内容: {error_text[:200]}"
                    )
                except Exception as parse_err:
                    logger.warning(f"模型 {self.model_name} 无法解析响应错误内容: {str(parse_err)}")

                await asyncio.sleep(wait_time)
                return None, 0
            else:
                logger.critical(
                    f"模型 {self.model_name} HTTP响应错误达到最大重试次数: 状态码: {exception.status}, 错误: {exception.message}"
                )
                # 安全地检查和记录请求详情
                handled_payload = await _safely_record(request_content, payload)
                logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {handled_payload[:100]}")
                raise RuntimeError(
                    f"模型 {self.model_name} API请求失败: 状态码 {exception.status}, {exception.message}"
                )

        else:
            if keep_request:
                logger.error(f"模型 {self.model_name} 请求失败,等待{wait_time}秒后重试... 错误: {str(exception)}")
                await asyncio.sleep(wait_time)
                return None, 0
            else:
                logger.critical(f"模型 {self.model_name} 请求失败: {str(exception)}")
                # 安全地检查和记录请求详情
                handled_payload = await _safely_record(request_content, payload)
                logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {handled_payload[:100]}")
                raise RuntimeError(f"模型 {self.model_name} API请求失败: {str(exception)}")

    async def _transform_parameters(self, params: dict) -> dict:
        """
        根据模型名称转换参数:
        - 对于需要转换的OpenAI CoT系列模型(例如 "o3-mini"),删除 'temperature' 参数,
          并将 'max_tokens' 重命名为 'max_completion_tokens'
        """
        # 复制一份参数,避免直接修改原始数据
        new_params = dict(params)

        if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION:
            # 删除 'temperature' 参数(如果存在)
            new_params.pop("temperature", None)
            # 如果存在 'max_tokens',则重命名为 'max_completion_tokens'
            if "max_tokens" in new_params:
                new_params["max_completion_tokens"] = new_params.pop("max_tokens")
        return new_params

    async def _build_payload(self, prompt: str, image_base64: str = None, image_format: str = None) -> dict:
        """构建请求体"""
        # 复制一份参数,避免直接修改 self.params
        params_copy = await self._transform_parameters(self.params)
        if image_base64:
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:image/{image_format.lower()};base64,{image_base64}"},
                        },
                    ],
                }
            ]
        else:
            messages = [{"role": "user", "content": prompt}]
        payload = {
            "model": self.model_name,
            "messages": messages,
            **params_copy,
        }
        if "max_tokens" not in payload and "max_completion_tokens" not in payload:
            payload["max_tokens"] = global_config.model.model_max_output_length
        # 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查
        if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload:
            payload["max_completion_tokens"] = payload.pop("max_tokens")
        return payload

    def _default_response_handler(
        self, result: dict, user_id: str = "system", request_type: str = None, endpoint: str = "/chat/completions"
    ) -> Tuple:
        """默认响应解析"""
        if "choices" in result and result["choices"]:
            message = result["choices"][0]["message"]
            content = message.get("content", "")
            content, reasoning = self._extract_reasoning(content)
            reasoning_content = message.get("model_extra", {}).get("reasoning_content", "")
            if not reasoning_content:
                reasoning_content = message.get("reasoning_content", "")
            if not reasoning_content:
                reasoning_content = reasoning

            # 提取工具调用信息
            tool_calls = message.get("tool_calls", None)

            # 记录token使用情况
            usage = result.get("usage", {})
            if usage:
                prompt_tokens = usage.get("prompt_tokens", 0)
                completion_tokens = usage.get("completion_tokens", 0)
                total_tokens = usage.get("total_tokens", 0)
                self._record_usage(
                    prompt_tokens=prompt_tokens,
                    completion_tokens=completion_tokens,
                    total_tokens=total_tokens,
                    user_id=user_id,
                    request_type=request_type if request_type is not None else self.request_type,
                    endpoint=endpoint,
                )

            # 只有当tool_calls存在且不为空时才返回
            if tool_calls:
                logger.debug(f"检测到工具调用: {tool_calls}")
                return content, reasoning_content, tool_calls
            else:
                return content, reasoning_content

        return "没有返回结果", ""

    @staticmethod
    def _extract_reasoning(content: str) -> Tuple[str, str]:
        """CoT思维链提取"""
        match = re.search(r"(?:<think>)?(.*?)</think>", content, re.DOTALL)
        content = re.sub(r"(?:<think>)?.*?</think>", "", content, flags=re.DOTALL, count=1).strip()
        if match:
            reasoning = match.group(1).strip()
        else:
            reasoning = ""
        return content, reasoning

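A small sketch of the extraction on an illustrative response (input string invented for illustration):

    # raw = "<think>The user is greeting me.</think>Hello!"
    # content, reasoning = LLMRequest._extract_reasoning(raw)
    # content   -> "Hello!"
    # reasoning -> "The user is greeting me."
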
    async def _build_headers(self, no_key: bool = False) -> dict:
        """构建请求头"""
        if no_key:
            return {"Authorization": "Bearer **********", "Content-Type": "application/json"}
        else:
            return {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
            # 防止小朋友们截图自己的key

    async def generate_response(self, prompt: str) -> Tuple:
        """根据输入的提示生成模型的异步响应"""

        response = await self._execute_request(endpoint="/chat/completions", prompt=prompt)
        # 根据返回值的长度决定怎么处理
        if len(response) == 3:
            content, reasoning_content, tool_calls = response
            return content, reasoning_content, self.model_name, tool_calls
        else:
            content, reasoning_content = response
            return content, reasoning_content, self.model_name

    async def generate_response_for_image(self, prompt: str, image_base64: str, image_format: str) -> Tuple:
        """根据输入的提示和图片生成模型的异步响应"""

        response = await self._execute_request(
            endpoint="/chat/completions", prompt=prompt, image_base64=image_base64, image_format=image_format
        )
        # 根据返回值的长度决定怎么处理
        if len(response) == 3:
            content, reasoning_content, tool_calls = response
            return content, reasoning_content, tool_calls
        else:
            content, reasoning_content = response
            return content, reasoning_content

    async def generate_response_async(self, prompt: str, **kwargs) -> Union[str, Tuple]:
        """异步方式根据输入的提示生成模型的响应"""
        # 构建请求体,不硬编码max_tokens
        data = {
            "model": self.model_name,
            "messages": [{"role": "user", "content": prompt}],
            **self.params,
            **kwargs,
        }

        response = await self._execute_request(endpoint="/chat/completions", payload=data, prompt=prompt)
        # 原样返回响应,不做处理

        return response

    async def generate_response_tool_async(self, prompt: str, tools: list, **kwargs) -> tuple[str, str, list]:
        """异步方式根据输入的提示生成模型的响应"""
        # 构建请求体,不硬编码max_tokens
        data = {
            "model": self.model_name,
            "messages": [{"role": "user", "content": prompt}],
            **self.params,
            **kwargs,
            "tools": tools,
        }

        response = await self._execute_request(endpoint="/chat/completions", payload=data, prompt=prompt)
        logger.debug(f"向模型 {self.model_name} 发送工具调用请求,包含 {len(tools)} 个工具,返回结果: {response}")
        # 检查响应是否包含工具调用
        if len(response) == 3:
            content, reasoning_content, tool_calls = response
            logger.debug(f"收到工具调用响应,包含 {len(tool_calls) if tool_calls else 0} 个工具调用")
            return content, reasoning_content, tool_calls
        else:
            content, reasoning_content = response
            logger.debug("收到普通响应,无工具调用")
            return content, reasoning_content, None

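The `tools` argument is forwarded verbatim in the payload; providers following the OpenAI function-calling schema would expect entries shaped roughly like this (a hypothetical tool, not one from the repo):

    # example_tools = [
    #     {
    #         "type": "function",
    #         "function": {
    #             "name": "get_weather",  # hypothetical tool name
    #             "description": "Look up the weather for a city",
    #             "parameters": {
    #                 "type": "object",
    #                 "properties": {"city": {"type": "string"}},
    #                 "required": ["city"],
    #             },
    #         },
    #     }
    # ]
    # content, reasoning, tool_calls = await llm.generate_response_tool_async(prompt, example_tools)
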
    async def get_embedding(self, text: str) -> Union[list, None]:
        """异步方法:获取文本的embedding向量

        Args:
            text: 需要获取embedding的文本

        Returns:
            list: embedding向量,如果失败则返回None
        """

        if len(text) < 1:
            logger.debug("该消息没有长度,不再发送获取embedding向量的请求")
            return None

        def embedding_handler(result):
            """处理响应"""
            if "data" in result and len(result["data"]) > 0:
                # 提取 token 使用信息
                usage = result.get("usage", {})
                if usage:
                    prompt_tokens = usage.get("prompt_tokens", 0)
                    completion_tokens = usage.get("completion_tokens", 0)
                    total_tokens = usage.get("total_tokens", 0)
                    # 记录 token 使用情况
                    self._record_usage(
                        prompt_tokens=prompt_tokens,
                        completion_tokens=completion_tokens,
                        total_tokens=total_tokens,
                        user_id="system",  # 可以根据需要修改 user_id
                        # request_type="embedding",  # 请求类型为 embedding
                        request_type=self.request_type,  # 请求类型为 text
                        endpoint="/embeddings",  # API 端点
                    )
                    return result["data"][0].get("embedding", None)
                return result["data"][0].get("embedding", None)
            return None

        embedding = await self._execute_request(
            endpoint="/embeddings",
            prompt=text,
            payload={"model": self.model_name, "input": text, "encoding_format": "float"},
            retry_policy={"max_retries": 2, "base_wait": 6},
            response_handler=embedding_handler,
        )
        return embedding


def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 1024 * 1024) -> str:
    """压缩base64格式的图片到指定大小
    Args:
        base64_data: base64编码的图片数据
        target_size: 目标文件大小(字节),默认0.8MB
    Returns:
        str: 压缩后的base64图片数据
    """
    try:
        # 将base64转换为字节数据
        image_data = base64.b64decode(base64_data)

        # 如果已经小于目标大小,直接返回原图
        if len(image_data) <= 2 * 1024 * 1024:
            return base64_data

        # 将字节数据转换为图片对象
        img = Image.open(io.BytesIO(image_data))

        # 获取原始尺寸
        original_width, original_height = img.size

        # 计算缩放比例
        scale = min(1.0, (target_size / len(image_data)) ** 0.5)

        # 计算新的尺寸
        new_width = int(original_width * scale)
        new_height = int(original_height * scale)

        # 创建内存缓冲区
        output_buffer = io.BytesIO()

        # 如果是GIF,处理所有帧
        if getattr(img, "is_animated", False):
            frames = []
            for frame_idx in range(img.n_frames):
                img.seek(frame_idx)
                new_frame = img.copy()
                new_frame = new_frame.resize((new_width // 2, new_height // 2), Image.Resampling.LANCZOS)  # 动图折上折
                frames.append(new_frame)

            # 保存到缓冲区
            frames[0].save(
                output_buffer,
                format="GIF",
                save_all=True,
                append_images=frames[1:],
                optimize=True,
                duration=img.info.get("duration", 100),
                loop=img.info.get("loop", 0),
            )
        else:
            # 处理静态图片
            resized_img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)

            # 保存到缓冲区,保持原始格式
            if img.format == "PNG" and img.mode in ("RGBA", "LA"):
                resized_img.save(output_buffer, format="PNG", optimize=True)
            else:
                resized_img.save(output_buffer, format="JPEG", quality=95, optimize=True)

        # 获取压缩后的数据并转换为base64
        compressed_data = output_buffer.getvalue()
        logger.success(f"压缩图片: {original_width}x{original_height} -> {new_width}x{new_height}")
        logger.info(f"压缩前大小: {len(image_data) / 1024:.1f}KB, 压缩后大小: {len(compressed_data) / 1024:.1f}KB")

        return base64.b64encode(compressed_data).decode("utf-8")

    except Exception as e:
        logger.error(f"压缩图片失败: {str(e)}")
        import traceback

        logger.error(traceback.format_exc())
        return base64_data

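A short sketch of the scaling arithmetic, with hypothetical sizes:

    # For a 4 MB image and the default 0.8 MB target:
    #   scale = min(1.0, (0.8 / 4.0) ** 0.5) ≈ 0.447
    # Both dimensions shrink to ~45%, so the pixel count (and roughly the encoded size)
    # drops to about 20% of the original. Inputs at or under 2 MB are returned unchanged.
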
@@ -11,7 +11,7 @@ from src.common.logger_manager import get_logger
from src.chat.heart_flow.utils_chat import get_chat_type_and_target_info
from src.manager.mood_manager import mood_manager
from src.chat.message_receive.chat_stream import ChatStream, chat_manager
-from src.chat.person_info.relationship_manager import relationship_manager
+from src.person_info.relationship_manager import relationship_manager
from src.chat.utils.info_catcher import info_catcher_manager
from src.chat.utils.timer_calculator import Timer
from src.chat.utils.prompt_builder import global_prompt_manager

@@ -1,8 +1,8 @@
from typing import List, Optional, Tuple, Union
import random
-from ..models.utils_model import LLMRequest
-from ...config.config import global_config
-from ..message_receive.message import MessageThinking
+from src.llm_models.utils_model import LLMRequest
+from src.config.config import global_config
+from src.chat.message_receive.message import MessageThinking
from src.chat.focus_chat.heartflow_prompt_builder import prompt_builder
from src.chat.utils.utils import process_llm_response
from src.chat.utils.timer_calculator import Timer

@@ -3,7 +3,7 @@ from dataclasses import dataclass
from src.config.config import global_config
from src.chat.message_receive.chat_stream import ChatStream, GroupInfo
from src.chat.message_receive.message import MessageRecv
-from src.chat.person_info.person_info import person_info_manager, PersonInfoManager
+from src.person_info.person_info import person_info_manager, PersonInfoManager
from abc import ABC, abstractmethod
import importlib
from typing import Dict, Optional

@@ -1,639 +0,0 @@
from src.common.logger_manager import get_logger
from ...common.database.database import db
from ...common.database.database_model import PersonInfo  # 新增导入
import copy
import hashlib
from typing import Any, Callable, Dict
import datetime
import asyncio
import numpy as np
from src.chat.models.utils_model import LLMRequest
from src.config.config import global_config
from src.individuality.individuality import individuality

import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
from pathlib import Path
import pandas as pd
import json  # 新增导入
import re


"""
PersonInfoManager 类方法功能摘要:
1. get_person_id - 根据平台和用户ID生成MD5哈希的唯一person_id
2. create_person_info - 创建新个人信息文档(自动合并默认值)
3. update_one_field - 更新单个字段值(若文档不存在则创建)
4. del_one_document - 删除指定person_id的文档
5. get_value - 获取单个字段值(返回实际值或默认值)
6. get_values - 批量获取字段值(任一字段无效则返回空字典)
7. del_all_undefined_field - 清理全集合中未定义的字段
8. get_specific_value_list - 根据指定条件,返回person_id,value字典
9. personal_habit_deduction - 定时推断个人习惯
"""


logger = get_logger("person_info")

person_info_default = {
    "person_id": None,
    "person_name": None,  # 模型中已设为 null=True,此默认值OK
    "name_reason": None,
    "platform": "unknown",  # 提供非None的默认值
    "user_id": "unknown",  # 提供非None的默认值
    "nickname": "Unknown",  # 提供非None的默认值
    "relationship_value": 0,
    "know_time": 0,  # 修正拼写:konw_time -> know_time
    "msg_interval": 2000,
    "msg_interval_list": [],  # 将作为 JSON 字符串存储在 Peewee 的 TextField
    "user_cardname": None,  # 注意:此字段不在 PersonInfo Peewee 模型中
    "user_avatar": None,  # 注意:此字段不在 PersonInfo Peewee 模型中
}


class PersonInfoManager:
|
||||
def __init__(self):
|
||||
self.person_name_list = {}
|
||||
# TODO: API-Adapter修改标记
|
||||
self.qv_name_llm = LLMRequest(
|
||||
model=global_config.model.normal,
|
||||
max_tokens=256,
|
||||
request_type="qv_name",
|
||||
)
|
||||
try:
|
||||
db.connect(reuse_if_open=True)
|
||||
db.create_tables([PersonInfo], safe=True)
|
||||
except Exception as e:
|
||||
logger.error(f"数据库连接或 PersonInfo 表创建失败: {e}")
|
||||
|
||||
# 初始化时读取所有person_name
|
||||
try:
|
||||
for record in PersonInfo.select(PersonInfo.person_id, PersonInfo.person_name).where(
|
||||
PersonInfo.person_name.is_null(False)
|
||||
):
|
||||
if record.person_name:
|
||||
self.person_name_list[record.person_id] = record.person_name
|
||||
logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称 (Peewee)")
|
||||
except Exception as e:
|
||||
logger.error(f"从 Peewee 加载 person_name_list 失败: {e}")
|
||||
|
||||
@staticmethod
|
||||
def get_person_id(platform: str, user_id: int):
|
||||
"""获取唯一id"""
|
||||
if "-" in platform:
|
||||
platform = platform.split("-")[1]
|
||||
|
||||
components = [platform, str(user_id)]
|
||||
key = "_".join(components)
|
||||
return hashlib.md5(key.encode()).hexdigest()
|
||||
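    # Worked example (illustrative, not in the original file; platform and
    # user values are made up): an adapter-prefixed platform is normalized
    # first, so get_person_id("maibot-qq", 10001) and get_person_id("qq", 10001)
    # both hash the key "qq_10001" and return the same 32-char hex digest.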

    async def is_person_known(self, platform: str, user_id: int):
        """Check whether we already know this person"""
        person_id = self.get_person_id(platform, user_id)

        def _db_check_known_sync(p_id: str):
            return PersonInfo.get_or_none(PersonInfo.person_id == p_id) is not None

        try:
            return await asyncio.to_thread(_db_check_known_sync, person_id)
        except Exception as e:
            logger.error(f"Error while checking whether user {person_id} is known (Peewee): {e}")
            return False

    def get_person_id_by_person_name(self, person_name: str):
        """Look up a user ID by person_name"""
        # Fixed: the original called db.person_info.find_one(...), a leftover
        # MongoDB API; `db` is a Peewee database here, so query the model instead.
        record = PersonInfo.get_or_none(PersonInfo.person_name == person_name)
        if record:
            return record.person_id
        else:
            return ""

    @staticmethod
    async def create_person_info(person_id: str, data: dict = None):
        """Create a new record"""
        if not person_id:
            logger.debug("Creation failed: person_id is missing")
            return

        _person_info_default = copy.deepcopy(person_info_default)
        model_fields = PersonInfo._meta.fields.keys()

        final_data = {"person_id": person_id}

        if data:
            for key, value in data.items():
                if key in model_fields:
                    final_data[key] = value

        for key, default_value in _person_info_default.items():
            if key in model_fields and key not in final_data:
                final_data[key] = default_value

        if "msg_interval_list" in final_data and isinstance(final_data["msg_interval_list"], list):
            final_data["msg_interval_list"] = json.dumps(final_data["msg_interval_list"])
        elif "msg_interval_list" not in final_data and "msg_interval_list" in model_fields:
            final_data["msg_interval_list"] = json.dumps([])

        def _db_create_sync(p_data: dict):
            try:
                PersonInfo.create(**p_data)
                return True
            except Exception as e:
                logger.error(f"Failed to create PersonInfo record {p_data.get('person_id')} (Peewee): {e}")
                return False

        await asyncio.to_thread(_db_create_sync, final_data)

    async def update_one_field(self, person_id: str, field_name: str, value, data: dict = None):
        """Update a single field, creating the record if needed"""
        if field_name not in PersonInfo._meta.fields:
            if field_name in person_info_default:
                logger.debug(f"Skipping update of '{field_name}': the field exists in the defaults but not in the PersonInfo Peewee model.")
                return
            logger.debug(f"Update of '{field_name}' failed: the field is not defined in the PersonInfo Peewee model.")
            return

        def _db_update_sync(p_id: str, f_name: str, val):
            record = PersonInfo.get_or_none(PersonInfo.person_id == p_id)
            if record:
                if f_name == "msg_interval_list" and isinstance(val, list):
                    setattr(record, f_name, json.dumps(val))
                else:
                    setattr(record, f_name, val)
                record.save()
                return True, False
            return False, True

        found, needs_creation = await asyncio.to_thread(_db_update_sync, person_id, field_name, value)

        if needs_creation:
            logger.debug(f"{person_id} does not exist yet; it will be created.")
            creation_data = data if data is not None else {}
            creation_data[field_name] = value
            if "platform" not in creation_data or "user_id" not in creation_data:
                logger.warning(f"platform/user_id may be missing while creating a record for {person_id}.")

            await self.create_person_info(person_id, creation_data)
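    # Illustrative usage (not in the original file; the id and value are made
    # up): awaiting update_one_field("abc123", "nickname", "小明") saves the
    # field when the record exists; otherwise it may warn that platform/user_id
    # are missing and falls through to create_person_info with the defaults.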

    @staticmethod
    async def has_one_field(person_id: str, field_name: str):
        """Check whether the record has a value for the given field"""
        if field_name not in PersonInfo._meta.fields:
            logger.debug(f"Field check for '{field_name}' failed: not defined in the PersonInfo Peewee model.")
            return False

        def _db_has_field_sync(p_id: str, f_name: str):
            record = PersonInfo.get_or_none(PersonInfo.person_id == p_id)
            # Fixed: the original returned True whenever the record existed,
            # ignoring the field itself; check the field's value instead.
            if record:
                return getattr(record, f_name, None) is not None
            return False

        try:
            return await asyncio.to_thread(_db_has_field_sync, person_id, field_name)
        except Exception as e:
            logger.error(f"Error while checking field {field_name} for {person_id} (Peewee): {e}")
            return False
    @staticmethod
    def _extract_json_from_text(text: str) -> dict:
        """Fault-tolerant extraction of JSON data from free-form text"""
        try:
            parsed_json = json.loads(text)
            if isinstance(parsed_json, list):
                if parsed_json:
                    parsed_json = parsed_json[0]
                else:
                    parsed_json = None
            if isinstance(parsed_json, dict):
                return parsed_json

        except json.JSONDecodeError:
            pass
        except Exception as e:
            logger.warning(f"Unexpected error while parsing JSON directly: {e}")
            pass

        try:
            json_pattern = r"\{[^{}]*\}"
            matches = re.findall(json_pattern, text)
            if matches:
                parsed_obj = json.loads(matches[0])
                if isinstance(parsed_obj, dict):
                    return parsed_obj

            nickname_pattern = r'"nickname"[:\s]+"([^"]+)"'
            reason_pattern = r'"reason"[:\s]+"([^"]+)"'

            nickname_match = re.search(nickname_pattern, text)
            reason_match = re.search(reason_pattern, text)

            if nickname_match:
                return {
                    "nickname": nickname_match.group(1),
                    "reason": reason_match.group(1) if reason_match else "未提供理由",
                }
        except Exception as e:
            logger.error(f"Fallback JSON extraction failed: {str(e)}")

        logger.warning(f"Could not extract a valid JSON dict from the text: {text}")
        return {"nickname": "", "reason": ""}
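    # Illustrative walk-through (not in the original file; the reply text is
    # made up): for an LLM reply like
    #     '好的!{"nickname": "小明", "reason": "沿用qq昵称"}'
    # json.loads fails on the leading prose, the r"\{[^{}]*\}" regex then
    # captures the brace-delimited span, and json.loads on that span yields
    # {"nickname": "小明", "reason": "沿用qq昵称"}. Only if that also failed
    # would the per-field regexes recover "nickname"/"reason" individually.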

    async def qv_person_name(
        self, person_id: str, user_nickname: str, user_cardname: str, user_avatar: str, request: str = ""
    ):
        """Pick a name for a user"""
        if not person_id:
            logger.debug("Naming failed: person_id must not be empty")
            return None

        old_name = await self.get_value(person_id, "person_name")
        old_reason = await self.get_value(person_id, "name_reason")

        max_retries = 5
        current_try = 0
        existing_names_str = ""
        current_name_set = set(self.person_name_list.values())

        while current_try < max_retries:
            prompt_personality = individuality.get_prompt(x_person=2, level=1)
            bot_name = individuality.personality.bot_nickname

            # The prompt below is model-facing Chinese text (sent to the LLM verbatim).
            qv_name_prompt = f"你是{bot_name},{prompt_personality}"
            qv_name_prompt += f"现在你想给一个用户取一个昵称,用户的qq昵称是{user_nickname},"
            qv_name_prompt += f"用户的qq群昵称名是{user_cardname},"
            if user_avatar:
                qv_name_prompt += f"用户的qq头像是{user_avatar},"
            if old_name:
                qv_name_prompt += f"你之前叫他{old_name},是因为{old_reason},"

            qv_name_prompt += f"\n其他取名的要求是:{request},不要太浮夸"
            qv_name_prompt += (
                "\n请根据以上用户信息,想想你叫他什么比较好,不要太浮夸,请最好使用用户的qq昵称,可以稍作修改"
            )

            if existing_names_str:
                qv_name_prompt += f"\n请注意,以下名称已被你尝试过或已知存在,请避免:{existing_names_str}。\n"

            if len(current_name_set) < 50 and current_name_set:
                qv_name_prompt += f"已知的其他昵称有: {', '.join(list(current_name_set)[:10])}等。\n"

            qv_name_prompt += "请用json给出你的想法,并给出理由,示例如下:"
            qv_name_prompt += """{
"nickname": "昵称",
"reason": "理由"
}"""
            response = await self.qv_name_llm.generate_response(qv_name_prompt)
            logger.trace(f"Naming prompt: {qv_name_prompt}\nNaming reply: {response}")
            result = self._extract_json_from_text(response[0])

            if not result or not result.get("nickname"):
                logger.error("Generated nickname is empty or the result is malformed; retrying...")
                current_try += 1
                continue

            generated_nickname = result["nickname"]

            is_duplicate = False
            if generated_nickname in current_name_set:
                is_duplicate = True
            else:

                def _db_check_name_exists_sync(name_to_check):
                    return PersonInfo.select().where(PersonInfo.person_name == name_to_check).exists()

                if await asyncio.to_thread(_db_check_name_exists_sync, generated_nickname):
                    is_duplicate = True
                    current_name_set.add(generated_nickname)

            if not is_duplicate:
                await self.update_one_field(person_id, "person_name", generated_nickname)
                await self.update_one_field(person_id, "name_reason", result.get("reason", "未提供理由"))

                self.person_name_list[person_id] = generated_nickname
                return result
            else:
                if existing_names_str:
                    existing_names_str += "、"
                existing_names_str += generated_nickname
                logger.debug(f"Generated nickname {generated_nickname} already exists; retrying...")
                current_try += 1

        logger.error(f"Failed to generate a unique nickname for {person_id} after {max_retries} attempts")
        return None

    @staticmethod
    async def del_one_document(person_id: str):
        """Delete the record for the given person_id"""
        if not person_id:
            logger.debug("Deletion failed: person_id must not be empty")
            return

        def _db_delete_sync(p_id: str):
            try:
                query = PersonInfo.delete().where(PersonInfo.person_id == p_id)
                deleted_count = query.execute()
                return deleted_count
            except Exception as e:
                logger.error(f"Failed to delete PersonInfo {p_id} (Peewee): {e}")
                return 0

        deleted_count = await asyncio.to_thread(_db_delete_sync, person_id)

        if deleted_count > 0:
            logger.debug(f"Deleted person_id={person_id} (Peewee)")
        else:
            logger.debug(f"Deletion failed: person_id={person_id} not found or no rows affected (Peewee)")
    @staticmethod
    async def get_value(person_id: str, field_name: str):
        """Fetch a field value for the given person_id, falling back to the global default if the field is absent"""
        if not person_id:
            logger.debug("get_value failed: person_id must not be empty")
            return person_info_default.get(field_name)

        if field_name not in PersonInfo._meta.fields:
            if field_name in person_info_default:
                logger.trace(f"Field '{field_name}' is not in the Peewee model but exists in the defaults; returning the configured default.")
                return copy.deepcopy(person_info_default[field_name])
            logger.debug(f"get_value failed: field '{field_name}' is defined neither in the Peewee model nor in the defaults.")
            return None

        def _db_get_value_sync(p_id: str, f_name: str):
            record = PersonInfo.get_or_none(PersonInfo.person_id == p_id)
            if record:
                val = getattr(record, f_name)
                if f_name == "msg_interval_list" and isinstance(val, str):
                    try:
                        return json.loads(val)
                    except json.JSONDecodeError:
                        logger.warning(f"Could not parse msg_interval_list JSON for {p_id}: {val}")
                        return copy.deepcopy(person_info_default.get(f_name, []))
                return val
            return None

        value = await asyncio.to_thread(_db_get_value_sync, person_id, field_name)

        if value is not None:
            return value
        else:
            default_value = copy.deepcopy(person_info_default.get(field_name))
            logger.trace(f"Fetching {field_name} for {person_id} failed or was None; returned default {default_value} (Peewee)")
            return default_value
    @staticmethod
    async def get_values(person_id: str, field_names: list) -> dict:
        """Fetch several field values for the given person_id, falling back to the global defaults for absent fields"""
        if not person_id:
            logger.debug("get_values failed: person_id must not be empty")
            return {}

        result = {}

        def _db_get_record_sync(p_id: str):
            return PersonInfo.get_or_none(PersonInfo.person_id == p_id)

        record = await asyncio.to_thread(_db_get_record_sync, person_id)

        for field_name in field_names:
            if field_name not in PersonInfo._meta.fields:
                if field_name in person_info_default:
                    result[field_name] = copy.deepcopy(person_info_default[field_name])
                    logger.trace(f"Field '{field_name}' is not in the Peewee model; using the configured default.")
                else:
                    logger.debug(f"get_values failed: field '{field_name}' is defined neither in the Peewee model nor in the defaults.")
                    result[field_name] = None
                continue

            if record:
                value = getattr(record, field_name)
                if field_name == "msg_interval_list" and isinstance(value, str):
                    try:
                        result[field_name] = json.loads(value)
                    except json.JSONDecodeError:
                        logger.warning(f"Could not parse msg_interval_list JSON for {person_id}: {value}")
                        result[field_name] = copy.deepcopy(person_info_default.get(field_name, []))
                elif value is not None:
                    result[field_name] = value
                else:
                    result[field_name] = copy.deepcopy(person_info_default.get(field_name))
            else:
                result[field_name] = copy.deepcopy(person_info_default.get(field_name))

        return result
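    # Illustrative usage (not in the original file; the id is made up, the
    # field names are real):
    #   interval = await person_info_manager.get_value("abc123", "msg_interval")
    #   info = await person_info_manager.get_values("abc123", ["nickname", "msg_interval_list"])
    # A missing record or NULL column falls back to person_info_default, and
    # msg_interval_list is decoded from its JSON TextField into a Python list.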
    @staticmethod
    async def del_all_undefined_field():
        """Purge undefined fields from all records; generally not applicable for Peewee (SQL), since the schema is fixed."""
        logger.info(
            "del_all_undefined_field: for a SQL database behind Peewee this operation is usually unnecessary, as the table structure is predefined."
        )
        return
    @staticmethod
    async def get_specific_value_list(
        field_name: str,
        way: Callable[[Any], bool],
    ) -> Dict[str, Any]:
        """
        Return a {person_id: value} dict for all records whose value satisfies the condition
        """
        if field_name not in PersonInfo._meta.fields:
            logger.error(f"Field check failed: '{field_name}' is not defined in the PersonInfo Peewee model")
            return {}

        def _db_get_specific_sync(f_name: str):
            found_results = {}
            try:
                for record in PersonInfo.select(PersonInfo.person_id, getattr(PersonInfo, f_name)):
                    value = getattr(record, f_name)
                    if f_name == "msg_interval_list" and isinstance(value, str):
                        try:
                            processed_value = json.loads(value)
                        except json.JSONDecodeError:
                            logger.warning(f"Skipping record {record.person_id}; could not parse msg_interval_list: {value}")
                            continue
                    else:
                        processed_value = value

                    if way(processed_value):
                        found_results[record.person_id] = processed_value
            except Exception as e_query:
                logger.error(f"Database query failed (Peewee specific_value_list for {f_name}): {str(e_query)}", exc_info=True)
            return found_results

        try:
            return await asyncio.to_thread(_db_get_specific_sync, field_name)
        except Exception as e:
            logger.error(f"Error while running the get_specific_value_list thread: {str(e)}", exc_info=True)
            return {}
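    # Illustrative call (not in the original file): collect every user whose
    # relationship_value exceeds 700, exactly as the relationship code below
    # does:
    #   rdict = await person_info_manager.get_specific_value_list(
    #       "relationship_value", lambda v: v > 700
    #   )
    # The `way` predicate runs on each decoded value; the result maps
    # person_id -> value for the matching rows only.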
    async def personal_habit_deduction(self):
        """Run personal-info inference; deduced once per day when conditions are met"""
        try:
            while True:
                await asyncio.sleep(600)
                current_time_dt = datetime.datetime.now()
                logger.info(f"Personal-info inference started: {current_time_dt.strftime('%Y-%m-%d %H:%M:%S')}")

                msg_interval_map_generated = False
                msg_interval_lists_map = await self.get_specific_value_list(
                    "msg_interval_list", lambda x: isinstance(x, list) and len(x) >= 100
                )

                for person_id, actual_msg_interval_list in msg_interval_lists_map.items():
                    await asyncio.sleep(0.3)
                    try:
                        time_interval = []
                        for t1, t2 in zip(actual_msg_interval_list, actual_msg_interval_list[1:]):
                            delta = t2 - t1
                            if delta > 0:
                                time_interval.append(delta)

                        time_interval = [t for t in time_interval if 200 <= t <= 8000]

                        if len(time_interval) >= 30 + 10:
                            time_interval.sort()
                            msg_interval_map_generated = True
                            log_dir = Path("logs/person_info")
                            log_dir.mkdir(parents=True, exist_ok=True)
                            plt.figure(figsize=(10, 6))
                            time_series_original = pd.Series(time_interval)
                            plt.hist(
                                time_series_original,
                                bins=50,
                                density=True,
                                alpha=0.4,
                                color="pink",
                                label="Histogram (Original Filtered)",
                            )
                            time_series_original.plot(
                                kind="kde", color="mediumpurple", linewidth=1, label="Density (Original Filtered)"
                            )
                            plt.grid(True, alpha=0.2)
                            plt.xlim(0, 8000)
                            plt.title(f"Message Interval Distribution (User: {person_id[:8]}...)")
                            plt.xlabel("Interval (ms)")
                            plt.ylabel("Density")
                            plt.legend(framealpha=0.9, facecolor="white")
                            img_path = log_dir / f"interval_distribution_{person_id[:8]}.png"
                            plt.savefig(img_path)
                            plt.close()

                            trimmed_interval = time_interval[5:-5]
                            if trimmed_interval:
                                msg_interval_val = int(round(np.percentile(trimmed_interval, 37)))
                                await self.update_one_field(person_id, "msg_interval", msg_interval_val)
                                logger.trace(
                                    f"msg_interval for user {person_id} updated to {msg_interval_val} via head/tail trimming and the 37th percentile"
                                )
                            else:
                                logger.trace(f"No data left for user {person_id} after trimming; msg_interval not computed")
                        else:
                            logger.trace(
                                f"User {person_id} has too few valid message intervals ({len(time_interval)}) for inference (at least {30 + 10} required)"
                            )
                    except Exception as e_inner:
                        logger.trace(f"Message-interval computation failed for user {person_id}: {type(e_inner).__name__}: {str(e_inner)}")
                        continue

                if msg_interval_map_generated:
                    logger.trace("Distribution plots saved to: logs/person_info")

                current_time_dt_end = datetime.datetime.now()
                logger.trace(f"Personal-info inference finished: {current_time_dt_end.strftime('%Y-%m-%d %H:%M:%S')}")
                await asyncio.sleep(86400)

        except Exception as e:
            logger.error(f"Personal-info inference crashed: {str(e)}")
            logger.exception("Details:")
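    # Worked numbers for the deduction above (illustrative, not in the
    # original file): with 40 filtered intervals, the 5 smallest and 5 largest
    # are dropped, leaving 30 samples. For trimmed values 1000, 1010, ..., 1290
    # ms, np.percentile(..., 37) interpolates at index 0.37 * 29 = 10.73, i.e.
    # 1000 + 10.73 * 10 = 1107.3, so msg_interval is stored as 1107.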
    async def get_or_create_person(
        self, platform: str, user_id: int, nickname: str = None, user_cardname: str = None, user_avatar: str = None
    ) -> str:
        """
        Look up the person_id for the given platform and user_id.
        If the user does not exist yet, create one from the optional details provided.
        """
        person_id = self.get_person_id(platform, user_id)

        def _db_check_exists_sync(p_id: str):
            return PersonInfo.get_or_none(PersonInfo.person_id == p_id)

        record = await asyncio.to_thread(_db_check_exists_sync, person_id)

        if record is None:
            logger.info(f"User {platform}:{user_id} (person_id: {person_id}) does not exist; creating a new record (Peewee).")
            initial_data = {
                "platform": platform,
                "user_id": str(user_id),
                "nickname": nickname,
                "know_time": int(datetime.datetime.now().timestamp()),  # spelling fixed: konw_time -> know_time
            }
            model_fields = PersonInfo._meta.fields.keys()
            filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields}

            await self.create_person_info(person_id, data=filtered_initial_data)
            logger.debug(f"Created a new record for {person_id} with initial data (filtered for model): {filtered_initial_data}")

        return person_id
    async def get_person_info_by_name(self, person_name: str) -> dict | None:
        """Look up a user by person_name and return basic info (if found)"""
        if not person_name:
            logger.debug("get_person_info_by_name failed: person_name must not be empty")
            return None

        found_person_id = None
        for pid, name_in_cache in self.person_name_list.items():
            if name_in_cache == person_name:
                found_person_id = pid
                break

        if not found_person_id:

            def _db_find_by_name_sync(p_name_to_find: str):
                return PersonInfo.get_or_none(PersonInfo.person_name == p_name_to_find)

            record = await asyncio.to_thread(_db_find_by_name_sync, person_name)
            if record:
                found_person_id = record.person_id
                if (
                    found_person_id not in self.person_name_list
                    or self.person_name_list[found_person_id] != person_name
                ):
                    self.person_name_list[found_person_id] = person_name
            else:
                logger.debug(f"No user named '{person_name}' found in the database either (Peewee)")
                return None

        if found_person_id:
            required_fields = [
                "person_id",
                "platform",
                "user_id",
                "nickname",
                "user_cardname",
                "user_avatar",
                "person_name",
                "name_reason",
            ]
            valid_fields_to_get = [
                f for f in required_fields if f in PersonInfo._meta.fields or f in person_info_default
            ]

            person_data = await self.get_values(found_person_id, valid_fields_to_get)

            if person_data:
                final_result = {key: person_data.get(key) for key in required_fields}
                return final_result
            else:
                logger.warning(f"Found person_id '{found_person_id}' but get_values returned nothing (Peewee)")
                return None

        logger.error(f"Logic error: could not determine a person_id for '{person_name}' (Peewee)")
        return None


person_info_manager = PersonInfoManager()
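# A minimal end-to-end sketch (not part of the original module; the platform
# and user values are invented, and it assumes the database configured above):
#
# async def _demo():
#     pid = await person_info_manager.get_or_create_person("qq", 10001, nickname="小明")
#     await person_info_manager.update_one_field(pid, "nickname", "小明")
#     print(await person_info_manager.get_values(pid, ["nickname", "msg_interval"]))
#
# asyncio.run(_demo())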
@@ -1,359 +0,0 @@
from src.common.logger_manager import get_logger
from ..message_receive.chat_stream import ChatStream
import math
from bson.decimal128 import Decimal128
from .person_info import person_info_manager
import time
import random
from maim_message import UserInfo

from ...manager.mood_manager import mood_manager

# import re
# import traceback


logger = get_logger("relation")

class RelationshipManager:
    def __init__(self):
        self.positive_feedback_value = 0  # positive-feedback system
        self.gain_coefficient = [1.0, 1.0, 1.1, 1.2, 1.4, 1.7, 1.9, 2.0]
        self._mood_manager = None

    @property
    def mood_manager(self):
        if self._mood_manager is None:
            self._mood_manager = mood_manager
        return self._mood_manager

    def positive_feedback_sys(self, label: str, stance: str):
        """Positive-feedback system: the feedback coefficient amplifies mood changes, and mood in turn influences relationship updates"""

        positive_list = [
            "开心",
            "惊讶",
            "害羞",
        ]

        negative_list = [
            "愤怒",
            "悲伤",
            "恐惧",
            "厌恶",
        ]

        if label in positive_list:
            if 7 > self.positive_feedback_value >= 0:
                self.positive_feedback_value += 1
            elif self.positive_feedback_value < 0:
                self.positive_feedback_value = 0
        elif label in negative_list:
            if -7 < self.positive_feedback_value <= 0:
                self.positive_feedback_value -= 1
            elif self.positive_feedback_value > 0:
                self.positive_feedback_value = 0

        if abs(self.positive_feedback_value) > 1:
            logger.info(f"Mood-change gain triggered; current gain coefficient: {self.gain_coefficient[abs(self.positive_feedback_value)]}")
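    # Worked ladder (illustrative, not in the original file): three consecutive
    # positive labels such as "开心" move positive_feedback_value 0 -> 3, so
    # feedback_to_mood scales a positive mood delta by gain_coefficient[3] = 1.2;
    # a single negative label then resets the counter to 0.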
    def mood_feedback(self, value):
        """Mood feedback"""
        mood_manager = self.mood_manager
        mood_gain = mood_manager.current_mood.valence**2 * math.copysign(1, value * mood_manager.current_mood.valence)
        value += value * mood_gain
        logger.info(f"Current relationship gain coefficient: {mood_gain:.3f}")
        return value
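    # Worked numbers (illustrative, not in the original file): with mood
    # valence 0.5 and a positive incoming value, mood_gain = 0.25 and the
    # change grows by 25%; with valence -0.5 and the same positive value,
    # mood_gain = -0.25 and the change shrinks by 25%.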

    def feedback_to_mood(self, mood_value):
        """Feedback applied to mood"""
        coefficient = self.gain_coefficient[abs(self.positive_feedback_value)]
        if (mood_value > 0 and self.positive_feedback_value > 0) or (mood_value < 0 and self.positive_feedback_value < 0):
            return mood_value * coefficient
        else:
            return mood_value / coefficient
    @staticmethod
    async def is_known_some_one(platform, user_id):
        """Check whether we already know this person"""
        is_known = await person_info_manager.is_person_known(platform, user_id)
        return is_known

    @staticmethod
    async def is_qved_name(platform, user_id):
        """Check whether this person has already been given a name"""
        person_id = person_info_manager.get_person_id(platform, user_id)
        is_qved = await person_info_manager.has_one_field(person_id, "person_name")
        old_name = await person_info_manager.get_value(person_id, "person_name")
        # print(f"old_name: {old_name}")
        # print(f"is_qved: {is_qved}")
        if is_qved and old_name is not None:
            return True
        else:
            return False

    @staticmethod
    async def first_knowing_some_one(
        platform: str, user_id: str, user_nickname: str, user_cardname: str, user_avatar: str
    ):
        """Meet someone for the first time: create their record and pick a name"""
        person_id = person_info_manager.get_person_id(platform, user_id)
        data = {
            "platform": platform,
            "user_id": user_id,
            "nickname": user_nickname,
            "know_time": int(time.time()),  # spelling fixed: konw_time -> know_time
        }
        await person_info_manager.update_one_field(
            person_id=person_id, field_name="nickname", value=user_nickname, data=data
        )
        await person_info_manager.qv_person_name(
            person_id=person_id, user_nickname=user_nickname, user_cardname=user_cardname, user_avatar=user_avatar
        )
    async def calculate_update_relationship_value(self, user_info: UserInfo, platform: str, label: str, stance: str):
        """Compute and apply a relationship-value change.

        The new update scheme:
        relationship values are clamped to [-1000, 1000].
        Desired properties of a change:
        1. growth slows down as the value approaches either extreme
        2. the worse the relationship, the harder it is to improve; the better it is, the easier it is to damage
        3. attention for maintaining relationships is limited, so the more high-value users there are, the slower mid/high-value users grow
        4. consecutive positive or negative emotions feed back on themselves

        Returns:
        user nickname, change amount, relationship level after the change
        """
        stancedict = {
            "支持": 0,
            "中立": 1,
            "反对": 2,
        }

        valuedict = {
            "开心": 1.5,
            "愤怒": -2.0,
            "悲伤": -0.5,
            "惊讶": 0.6,
            "害羞": 2.0,
            "平静": 0.3,
            "恐惧": -1.5,
            "厌恶": -1.0,
            "困惑": 0.5,
        }

        person_id = person_info_manager.get_person_id(platform, user_info.user_id)
        data = {
            "platform": platform,
            "user_id": user_info.user_id,
            "nickname": user_info.user_nickname,
            "know_time": int(time.time()),  # spelling fixed: konw_time -> know_time
        }
        old_value = await person_info_manager.get_value(person_id, "relationship_value")
        old_value = self.ensure_float(old_value, person_id)

        if old_value > 1000:
            old_value = 1000
        elif old_value < -1000:
            old_value = -1000

        value = valuedict[label]
        if old_value >= 0:
            if valuedict[label] >= 0 and stancedict[stance] != 2:
                value = value * math.cos(math.pi * old_value / 2000)
                if old_value > 500:
                    rdict = await person_info_manager.get_specific_value_list("relationship_value", lambda x: x > 700)
                    high_value_count = len(rdict)
                    if old_value > 700:
                        value *= 3 / (high_value_count + 2)  # excluding this user
                    else:
                        value *= 3 / (high_value_count + 3)
            elif valuedict[label] < 0 and stancedict[stance] != 0:
                value = value * math.exp(old_value / 2000)
            else:
                value = 0
        elif old_value < 0:
            if valuedict[label] >= 0 and stancedict[stance] != 2:
                value = value * math.exp(old_value / 2000)
            elif valuedict[label] < 0 and stancedict[stance] != 0:
                value = value * math.cos(math.pi * old_value / 2000)
            else:
                value = 0

        self.positive_feedback_sys(label, stance)
        value = self.mood_feedback(value)

        level_num = self.calculate_level_num(old_value + value)
        relationship_level = ["厌恶", "冷漠", "一般", "友好", "喜欢", "暧昧"]
        logger.info(
            f"User: {user_info.user_nickname}, "
            f"current relationship: {relationship_level[level_num]}, "
            f"relationship value: {old_value:.2f}, "
            f"current stance/emotion: {stance}-{label}, "
            f"change: {value:+.5f}"
        )

        await person_info_manager.update_one_field(person_id, "relationship_value", old_value + value, data)
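    # Worked numbers (illustrative, not in the original file): at
    # old_value = 500, a supportive "开心" (+1.5) scales by cos(pi*500/2000),
    # giving 1.5 * cos(pi/4) ≈ +1.06, while an opposing "愤怒" (-2.0) scales by
    # exp(500/2000), giving -2.0 * exp(0.25) ≈ -2.57; improvement damps near
    # the top while damage is amplified, matching expectations 1 and 2 in the
    # docstring.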
    async def calculate_update_relationship_value_with_reason(
        self, chat_stream: ChatStream, label: str, stance: str, reason: str
    ) -> tuple:
        """Compute and apply a relationship-value change.

        The new update scheme:
        relationship values are clamped to [-1000, 1000].
        Desired properties of a change:
        1. growth slows down as the value approaches either extreme
        2. the worse the relationship, the harder it is to improve; the better it is, the easier it is to damage
        3. attention for maintaining relationships is limited, so the more high-value users there are, the slower mid/high-value users grow
        4. consecutive positive or negative emotions feed back on themselves

        Returns:
        user nickname, change amount, relationship level after the change
        """
        stancedict = {
            "支持": 0,
            "中立": 1,
            "反对": 2,
        }

        valuedict = {
            "开心": 1.5,
            "愤怒": -2.0,
            "悲伤": -0.5,
            "惊讶": 0.6,
            "害羞": 2.0,
            "平静": 0.3,
            "恐惧": -1.5,
            "厌恶": -1.0,
            "困惑": 0.5,
        }

        person_id = person_info_manager.get_person_id(chat_stream.user_info.platform, chat_stream.user_info.user_id)
        data = {
            "platform": chat_stream.user_info.platform,
            "user_id": chat_stream.user_info.user_id,
            "nickname": chat_stream.user_info.user_nickname,
            "know_time": int(time.time()),  # spelling fixed: konw_time -> know_time
        }
        old_value = await person_info_manager.get_value(person_id, "relationship_value")
        old_value = self.ensure_float(old_value, person_id)

        if old_value > 1000:
            old_value = 1000
        elif old_value < -1000:
            old_value = -1000

        value = valuedict[label]
        if old_value >= 0:
            if valuedict[label] >= 0 and stancedict[stance] != 2:
                value = value * math.cos(math.pi * old_value / 2000)
                if old_value > 500:
                    rdict = await person_info_manager.get_specific_value_list("relationship_value", lambda x: x > 700)
                    high_value_count = len(rdict)
                    if old_value > 700:
                        value *= 3 / (high_value_count + 2)  # excluding this user
                    else:
                        value *= 3 / (high_value_count + 3)
            elif valuedict[label] < 0 and stancedict[stance] != 0:
                value = value * math.exp(old_value / 2000)
            else:
                value = 0
        elif old_value < 0:
            if valuedict[label] >= 0 and stancedict[stance] != 2:
                value = value * math.exp(old_value / 2000)
            elif valuedict[label] < 0 and stancedict[stance] != 0:
                value = value * math.cos(math.pi * old_value / 2000)
            else:
                value = 0

        self.positive_feedback_sys(label, stance)
        value = self.mood_feedback(value)

        level_num = self.calculate_level_num(old_value + value)
        relationship_level = ["厌恶", "冷漠", "一般", "友好", "喜欢", "暧昧"]
        logger.info(
            f"User: {chat_stream.user_info.user_nickname}, "
            f"current relationship: {relationship_level[level_num]}, "
            f"relationship value: {old_value:.2f}, "
            f"current stance/emotion: {stance}-{label}, "
            f"change: {value:+.5f}"
        )

        await person_info_manager.update_one_field(person_id, "relationship_value", old_value + value, data)

        return chat_stream.user_info.user_nickname, value, relationship_level[level_num]
    async def build_relationship_info(self, person, is_id: bool = False) -> str:
        if is_id:
            person_id = person
        else:
            # print(f"person: {person}")
            person_id = person_info_manager.get_person_id(person[0], person[1])
        person_name = await person_info_manager.get_value(person_id, "person_name")
        # print(f"person_name: {person_name}")
        relationship_value = await person_info_manager.get_value(person_id, "relationship_value")
        level_num = self.calculate_level_num(relationship_value)

        # Prompt fragments are model-facing Chinese text and are kept verbatim.
        relationship_level = ["厌恶", "冷漠以对", "认识", "友好对待", "喜欢", "暧昧"]
        relation_prompt2_list = [
            "忽视的回应",
            "冷淡回复",
            "保持理性",
            "愿意回复",
            "积极回复",
            "友善和包容的回复",
        ]

        if level_num == 0 or level_num == 5:
            return f"你{relationship_level[level_num]}{person_name},打算{relation_prompt2_list[level_num]}。\n"
        elif level_num == 2:
            return ""
        else:
            if random.random() < 0.6:
                return f"你{relationship_level[level_num]}{person_name},打算{relation_prompt2_list[level_num]}。\n"
            else:
                return ""
    @staticmethod
    def calculate_level_num(relationship_value) -> int:
        """Compute the relationship level"""
        if -1000 <= relationship_value < -227:
            level_num = 0
        elif -227 <= relationship_value < -73:
            level_num = 1
        elif -73 <= relationship_value < 227:
            level_num = 2
        elif 227 <= relationship_value < 587:
            level_num = 3
        elif 587 <= relationship_value < 900:
            level_num = 4
        elif 900 <= relationship_value <= 1000:
            level_num = 5
        else:
            level_num = 5 if relationship_value > 1000 else 0
        return level_num
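    # Spot checks (illustrative, not in the original file):
    #   calculate_level_num(-300) -> 0 ("厌恶")
    #   calculate_level_num(0)    -> 2 ("一般")
    #   calculate_level_num(600)  -> 4 ("喜欢")
    #   calculate_level_num(1200) -> 5 (out-of-range values clamp to the ends)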
    @staticmethod
    def ensure_float(value, person_id):
        """Ensure a float is returned; fall back to 0.0 on conversion failure"""
        if isinstance(value, float):
            return value
        try:
            return float(value.to_decimal() if isinstance(value, Decimal128) else value)
        except (ValueError, TypeError, AttributeError):
            logger.warning(f"[relationship] Value conversion for {person_id} failed (raw value: {value}); reset to 0")
            return 0.0


relationship_manager = RelationshipManager()
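# A minimal sketch (not part of the original module; the values are invented):
# the synchronous helpers can be exercised directly, e.g.
#   RelationshipManager.calculate_level_num(650)   # -> 4
#   relationship_manager.feedback_to_mood(1.0)     # scaled by the gain ladder
# while the async entry points (calculate_update_relationship_value, ...)
# are assumed to run inside the bot's event loop.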
@@ -4,7 +4,7 @@ import time  # import the time module to get the current time
import random
import re
from src.common.message_repository import find_messages, count_messages
from src.chat.person_info.person_info import person_info_manager
from src.person_info.person_info import person_info_manager
from src.chat.utils.utils import translate_timestamp_to_human_readable



@@ -10,7 +10,7 @@ from maim_message import UserInfo
from src.common.logger import get_module_logger
from src.manager.mood_manager import mood_manager
from ..message_receive.message import MessageRecv
from ..models.utils_model import LLMRequest
from src.llm_models.utils_model import LLMRequest
from .typo_generator import ChineseTypoGenerator
from ...config.config import global_config
from ...common.message_repository import find_messages, count_messages

@@ -8,10 +8,10 @@ import io
import numpy as np


from ...common.database.database import db
from ...common.database.database_model import Images, ImageDescriptions
from ...config.config import global_config
from ..models.utils_model import LLMRequest
from src.common.database.database import db
from src.common.database.database_model import Images, ImageDescriptions
from src.config.config import global_config
from src.llm_models.utils_model import LLMRequest

from src.common.logger_manager import get_logger
from rich.traceback import install