Merge branch 'dev' of https://github.com/MaiM-with-u/MaiBot into dev
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
from typing import Tuple
|
||||
from src.common.logger import get_module_logger
|
||||
from ..models.utils_model import LLMRequest
|
||||
from ..config.config import global_config
|
||||
from ...config.config import global_config
|
||||
from .chat_observer import ChatObserver
|
||||
from .pfc_utils import get_items_from_json
|
||||
from src.individuality.individuality import Individuality
|
||||
|
||||
@@ -4,7 +4,7 @@ import traceback
|
||||
from typing import Optional, Dict, Any, List
|
||||
from src.common.logger import get_module_logger
|
||||
from ..message.message_base import UserInfo
|
||||
from ..config.config import global_config
|
||||
from ...config.config import global_config
|
||||
from .chat_states import NotificationManager, create_new_message_notification, create_cold_chat_notification
|
||||
from .message_storage import MongoDBMessageStorage
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ from ..chat.chat_stream import ChatStream
|
||||
from ..message.message_base import UserInfo, Seg
|
||||
from ..chat.message import Message
|
||||
from ..models.utils_model import LLMRequest
|
||||
from ..config.config import global_config
|
||||
from ...config.config import global_config
|
||||
from src.plugins.chat.message import MessageSending
|
||||
from ..message.api import global_api
|
||||
from ..storage.storage import MessageStorage
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import List, Tuple
|
||||
from src.common.logger import get_module_logger
|
||||
from src.plugins.memory_system.Hippocampus import HippocampusManager
|
||||
from ..models.utils_model import LLMRequest
|
||||
from ..config.config import global_config
|
||||
from ...config.config import global_config
|
||||
from ..chat.message import Message
|
||||
|
||||
logger = get_module_logger("knowledge_fetcher")
|
||||
|
||||
@@ -3,7 +3,7 @@ import datetime
|
||||
from typing import Tuple
|
||||
from src.common.logger import get_module_logger
|
||||
from ..models.utils_model import LLMRequest
|
||||
from ..config.config import global_config
|
||||
from ...config.config import global_config
|
||||
from .chat_observer import ChatObserver
|
||||
from ..message.message_base import UserInfo
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Tuple
|
||||
from src.common.logger import get_module_logger
|
||||
from ..models.utils_model import LLMRequest
|
||||
from ..config.config import global_config
|
||||
from ...config.config import global_config
|
||||
from .chat_observer import ChatObserver
|
||||
from .reply_checker import ReplyChecker
|
||||
from src.individuality.individuality import Individuality
|
||||
|
||||
@@ -2,7 +2,7 @@ from src.common.logger import get_module_logger
|
||||
from .chat_observer import ChatObserver
|
||||
from .conversation_info import ConversationInfo
|
||||
from src.individuality.individuality import Individuality
|
||||
from ..config.config import global_config
|
||||
from ...config.config import global_config
|
||||
import time
|
||||
import asyncio
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from ..moods.moods import MoodManager # 导入情绪管理器
|
||||
from ..config.config import global_config
|
||||
from ...config.config import global_config
|
||||
from .message import MessageRecv
|
||||
from ..PFC.pfc_manager import PFCManager
|
||||
from .chat_stream import chat_manager
|
||||
|
||||
@@ -10,7 +10,7 @@ from PIL import Image
|
||||
import io
|
||||
|
||||
from ...common.database import db
|
||||
from ..config.config import global_config
|
||||
from ...config.config import global_config
|
||||
from ..chat.utils import get_embedding
|
||||
from ..chat.utils_image import ImageManager, image_path_to_base64
|
||||
from ..models.utils_model import LLMRequest
|
||||
|
||||
@@ -3,13 +3,13 @@ from src.common.logger import get_module_logger
|
||||
import asyncio
|
||||
from dataclasses import dataclass, field
|
||||
from .message import MessageRecv
|
||||
from ..message.message_base import BaseMessageInfo, GroupInfo
|
||||
from ..message.message_base import BaseMessageInfo, GroupInfo, Seg
|
||||
import hashlib
|
||||
from typing import Dict
|
||||
from collections import OrderedDict
|
||||
import random
|
||||
import time
|
||||
from ..config.config import global_config
|
||||
from ...config.config import global_config
|
||||
|
||||
logger = get_module_logger("message_buffer")
|
||||
|
||||
@@ -130,22 +130,40 @@ class MessageBuffer:
|
||||
keep_msgs = OrderedDict()
|
||||
combined_text = []
|
||||
found = False
|
||||
type = "text"
|
||||
type = "seglist"
|
||||
is_update = True
|
||||
for msg_id, msg in self.buffer_pool[person_id_].items():
|
||||
if msg_id == message.message_info.message_id:
|
||||
found = True
|
||||
type = msg.message.message_segment.type
|
||||
if msg.message.message_segment.type != "seglist":
|
||||
type = msg.message.message_segment.type
|
||||
else:
|
||||
if (
|
||||
isinstance(msg.message.message_segment.data, list)
|
||||
and all(isinstance(x, Seg) for x in msg.message.message_segment.data)
|
||||
and len(msg.message.message_segment.data) == 1
|
||||
):
|
||||
type = msg.message.message_segment.data[0].type
|
||||
combined_text.append(msg.message.processed_plain_text)
|
||||
continue
|
||||
if found:
|
||||
keep_msgs[msg_id] = msg
|
||||
elif msg.result == "F":
|
||||
# 收集F消息的文本内容
|
||||
F_type = "seglist"
|
||||
if msg.message.message_segment.type != "seglist":
|
||||
F_type = msg.message.message_segment.type
|
||||
else:
|
||||
if (
|
||||
isinstance(msg.message.message_segment.data, list)
|
||||
and all(isinstance(x, Seg) for x in msg.message.message_segment.data)
|
||||
and len(msg.message.message_segment.data) == 1
|
||||
):
|
||||
F_type = msg.message.message_segment.data[0].type
|
||||
if hasattr(msg.message, "processed_plain_text") and msg.message.processed_plain_text:
|
||||
if msg.message.message_segment.type == "text":
|
||||
if F_type == "text":
|
||||
combined_text.append(msg.message.processed_plain_text)
|
||||
elif msg.message.message_segment.type != "text":
|
||||
elif F_type != "text":
|
||||
is_update = False
|
||||
elif msg.result == "U":
|
||||
logger.debug(f"异常未处理信息id: {msg.message.message_info.message_id}")
|
||||
|
||||
@@ -8,7 +8,7 @@ from ..message.api import global_api
|
||||
from .message import MessageSending, MessageThinking, MessageSet
|
||||
|
||||
from ..storage.storage import MessageStorage
|
||||
from ..config.config import global_config
|
||||
from ...config.config import global_config
|
||||
from .utils import truncate_message, calculate_typing_time, count_messages_between
|
||||
|
||||
from src.common.logger import LogConfig, SENDER_STYLE_CONFIG
|
||||
|
||||
@@ -10,7 +10,7 @@ from src.common.logger import get_module_logger
|
||||
|
||||
from ..models.utils_model import LLMRequest
|
||||
from ..utils.typo_generator import ChineseTypoGenerator
|
||||
from ..config.config import global_config
|
||||
from ...config.config import global_config
|
||||
from .message import MessageRecv, Message
|
||||
from ..message.message_base import UserInfo
|
||||
from .chat_stream import ChatStream
|
||||
@@ -338,11 +338,21 @@ def random_remove_punctuation(text: str) -> str:
|
||||
|
||||
|
||||
def process_llm_response(text: str) -> List[str]:
|
||||
# 先保护颜文字
|
||||
protected_text, kaomoji_mapping = protect_kaomoji(text)
|
||||
logger.debug(f"保护颜文字后的文本: {protected_text}")
|
||||
# 提取被 () 或 [] 包裹的内容
|
||||
pattern = re.compile(r"[(\[].*?[\)\]")
|
||||
_extracted_contents = pattern.findall(text)
|
||||
pattern = re.compile(r"[\(\[\(].*?[\)\]\)]")
|
||||
# _extracted_contents = pattern.findall(text)
|
||||
_extracted_contents = pattern.findall(protected_text) # 在保护后的文本上查找
|
||||
|
||||
# 去除 () 和 [] 及其包裹的内容
|
||||
cleaned_text = pattern.sub("", text)
|
||||
# cleaned_text = pattern.sub("", text)
|
||||
cleaned_text = pattern.sub("", protected_text)
|
||||
|
||||
if cleaned_text == "":
|
||||
return ["呃呃"]
|
||||
|
||||
logger.debug(f"{text}去除括号处理后的文本: {cleaned_text}")
|
||||
|
||||
# 对清理后的文本进行进一步处理
|
||||
@@ -382,6 +392,8 @@ def process_llm_response(text: str) -> List[str]:
|
||||
return [f"{global_config.BOT_NICKNAME}不知道哦"]
|
||||
|
||||
# sentences.extend(extracted_contents)
|
||||
# 在所有句子处理完毕后,对包含占位符的列表进行恢复
|
||||
sentences = recover_kaomoji(sentences, kaomoji_mapping)
|
||||
|
||||
return sentences
|
||||
|
||||
@@ -508,8 +520,7 @@ def protect_kaomoji(sentence):
|
||||
r"]"
|
||||
r")"
|
||||
r"|"
|
||||
r"([▼▽・ᴥω・﹏^><≧≦ ̄`´∀ヮДд︿﹀へ。゚╥╯╰︶︹•⁄]{2,15"
|
||||
r"}"
|
||||
r"([▼▽・ᴥω・﹏^><≧≦ ̄`´∀ヮДд︿﹀へ。゚╥╯╰︶︹•⁄]{2,15})"
|
||||
)
|
||||
|
||||
kaomoji_matches = kaomoji_pattern.findall(sentence)
|
||||
@@ -706,12 +717,30 @@ def parse_text_timestamps(text: str, mode: str = "normal") -> str:
|
||||
# normal模式: 直接转换所有时间戳
|
||||
if mode == "normal":
|
||||
result_text = text
|
||||
|
||||
# 将时间戳转换为可读格式并记录相同格式的时间戳
|
||||
timestamp_readable_map = {}
|
||||
readable_time_used = set()
|
||||
|
||||
for match in matches:
|
||||
timestamp = float(match.group(1))
|
||||
readable_time = translate_timestamp_to_human_readable(timestamp, "normal")
|
||||
# 由于替换会改变文本长度,需要使用正则替换而非直接替换
|
||||
pattern_instance = re.escape(match.group(0))
|
||||
result_text = re.sub(pattern_instance, readable_time, result_text, count=1)
|
||||
timestamp_readable_map[match.group(0)] = (timestamp, readable_time)
|
||||
|
||||
# 按时间戳排序
|
||||
sorted_timestamps = sorted(timestamp_readable_map.items(), key=lambda x: x[1][0])
|
||||
|
||||
# 执行替换,相同格式的只保留最早的
|
||||
for ts_str, (_, readable) in sorted_timestamps:
|
||||
pattern_instance = re.escape(ts_str)
|
||||
if readable in readable_time_used:
|
||||
# 如果这个可读时间已经使用过,替换为空字符串
|
||||
result_text = re.sub(pattern_instance, "", result_text, count=1)
|
||||
else:
|
||||
# 否则替换为可读时间并记录
|
||||
result_text = re.sub(pattern_instance, readable, result_text, count=1)
|
||||
readable_time_used.add(readable)
|
||||
|
||||
return result_text
|
||||
else:
|
||||
# lite模式: 按5秒间隔划分并选择性转换
|
||||
@@ -770,15 +799,30 @@ def parse_text_timestamps(text: str, mode: str = "normal") -> str:
|
||||
pattern_instance = re.escape(match.group(0))
|
||||
result_text = re.sub(pattern_instance, "", result_text, count=1)
|
||||
|
||||
# 按照时间戳原始顺序排序,避免替换时位置错误
|
||||
to_convert.sort(key=lambda x: x[1].start())
|
||||
# 按照时间戳升序排序
|
||||
to_convert.sort(key=lambda x: x[0])
|
||||
|
||||
# 将时间戳转换为可读时间并记录哪些可读时间已经使用过
|
||||
converted_timestamps = []
|
||||
readable_time_used = set()
|
||||
|
||||
# 执行替换
|
||||
# 由于替换会改变文本长度,从后向前替换
|
||||
to_convert.reverse()
|
||||
for ts, match in to_convert:
|
||||
readable_time = translate_timestamp_to_human_readable(ts, "relative")
|
||||
converted_timestamps.append((ts, match, readable_time))
|
||||
|
||||
# 按照时间戳原始顺序排序,避免替换时位置错误
|
||||
converted_timestamps.sort(key=lambda x: x[1].start())
|
||||
|
||||
# 从后向前替换,避免位置改变
|
||||
converted_timestamps.reverse()
|
||||
for match, readable_time in converted_timestamps:
|
||||
pattern_instance = re.escape(match.group(0))
|
||||
result_text = re.sub(pattern_instance, readable_time, result_text, count=1)
|
||||
if readable_time in readable_time_used:
|
||||
# 如果相同格式的时间已存在,替换为空字符串
|
||||
result_text = re.sub(pattern_instance, "", result_text, count=1)
|
||||
else:
|
||||
# 否则替换为可读时间并记录
|
||||
result_text = re.sub(pattern_instance, readable_time, result_text, count=1)
|
||||
readable_time_used.add(readable_time)
|
||||
|
||||
return result_text
|
||||
|
||||
@@ -8,7 +8,7 @@ import io
|
||||
|
||||
|
||||
from ...common.database import db
|
||||
from ..config.config import global_config
|
||||
from ...config.config import global_config
|
||||
from ..models.utils_model import LLMRequest
|
||||
|
||||
from src.common.logger import get_module_logger
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from src.common.logger import get_module_logger
|
||||
from src.plugins.chat.message import MessageRecv
|
||||
from src.plugins.storage.storage import MessageStorage
|
||||
from src.plugins.config.config import global_config
|
||||
from src.config.config import global_config
|
||||
from datetime import datetime
|
||||
|
||||
logger = get_module_logger("pfc_message_processor")
|
||||
|
||||
@@ -4,7 +4,7 @@ import traceback
|
||||
from typing import List
|
||||
from ...memory_system.Hippocampus import HippocampusManager
|
||||
from ...moods.moods import MoodManager
|
||||
from ...config.config import global_config
|
||||
from ....config.config import global_config
|
||||
from ...chat.emoji_manager import emoji_manager
|
||||
from .reasoning_generator import ResponseGenerator
|
||||
from ...chat.message import MessageSending, MessageRecv, MessageThinking, MessageSet
|
||||
@@ -192,11 +192,21 @@ class ReasoningChat:
|
||||
if not buffer_result:
|
||||
await willing_manager.bombing_buffer_message_handle(message.message_info.message_id)
|
||||
willing_manager.delete(message.message_info.message_id)
|
||||
if message.message_segment.type == "text":
|
||||
F_type = "seglist"
|
||||
if message.message_segment.type != "seglist":
|
||||
F_type = message.message_segment.type
|
||||
else:
|
||||
if (
|
||||
isinstance(message.message_segment.data, list)
|
||||
and all(isinstance(x, Seg) for x in message.message_segment.data)
|
||||
and len(message.message_segment.data) == 1
|
||||
):
|
||||
F_type = message.message_segment.data[0].type
|
||||
if F_type == "text":
|
||||
logger.info(f"触发缓冲,已炸飞消息:{message.processed_plain_text}")
|
||||
elif message.message_segment.type == "image":
|
||||
elif F_type == "image":
|
||||
logger.info("触发缓冲,已炸飞表情包/图片")
|
||||
elif message.message_segment.type == "seglist":
|
||||
elif F_type == "seglist":
|
||||
logger.info("触发缓冲,已炸飞消息列")
|
||||
return
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import List, Optional, Tuple, Union
|
||||
import random
|
||||
|
||||
from ...models.utils_model import LLMRequest
|
||||
from ...config.config import global_config
|
||||
from ....config.config import global_config
|
||||
from ...chat.message import MessageThinking
|
||||
from .reasoning_prompt_builder import prompt_builder
|
||||
from ...chat.utils import process_llm_response
|
||||
|
||||
@@ -9,7 +9,7 @@ from ...moods.moods import MoodManager
|
||||
from ....individuality.individuality import Individuality
|
||||
from ...memory_system.Hippocampus import HippocampusManager
|
||||
from ...schedule.schedule_generator import bot_schedule
|
||||
from ...config.config import global_config
|
||||
from ....config.config import global_config
|
||||
from ...person_info.relationship_manager import relationship_manager
|
||||
from src.common.logger import get_module_logger
|
||||
from src.plugins.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
|
||||
@@ -4,7 +4,7 @@ import traceback
|
||||
from typing import List
|
||||
from ...memory_system.Hippocampus import HippocampusManager
|
||||
from ...moods.moods import MoodManager
|
||||
from ...config.config import global_config
|
||||
from ....config.config import global_config
|
||||
from ...chat.emoji_manager import emoji_manager
|
||||
from .think_flow_generator import ResponseGenerator
|
||||
from ...chat.message import MessageSending, MessageRecv, MessageThinking, MessageSet
|
||||
@@ -204,11 +204,21 @@ class ThinkFlowChat:
|
||||
if not buffer_result:
|
||||
await willing_manager.bombing_buffer_message_handle(message.message_info.message_id)
|
||||
willing_manager.delete(message.message_info.message_id)
|
||||
if message.message_segment.type == "text":
|
||||
F_type = "seglist"
|
||||
if message.message_segment.type != "seglist":
|
||||
F_type = message.message_segment.type
|
||||
else:
|
||||
if (
|
||||
isinstance(message.message_segment.data, list)
|
||||
and all(isinstance(x, Seg) for x in message.message_segment.data)
|
||||
and len(message.message_segment.data) == 1
|
||||
):
|
||||
F_type = message.message_segment.data[0].type
|
||||
if F_type == "text":
|
||||
logger.info(f"触发缓冲,已炸飞消息:{message.processed_plain_text}")
|
||||
elif message.message_segment.type == "image":
|
||||
elif F_type == "image":
|
||||
logger.info("触发缓冲,已炸飞表情包/图片")
|
||||
elif message.message_segment.type == "seglist":
|
||||
elif F_type == "seglist":
|
||||
logger.info("触发缓冲,已炸飞消息列")
|
||||
return
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import random
|
||||
|
||||
|
||||
from ...models.utils_model import LLMRequest
|
||||
from ...config.config import global_config
|
||||
from ....config.config import global_config
|
||||
from ...chat.message import MessageRecv
|
||||
from .think_flow_prompt_builder import prompt_builder
|
||||
from ...chat.utils import process_llm_response
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import random
|
||||
from typing import Optional
|
||||
|
||||
from ...config.config import global_config
|
||||
from ....config.config import global_config
|
||||
from ...chat.utils import get_recent_group_detailed_plain_text
|
||||
from ...chat.chat_stream import chat_manager
|
||||
from src.common.logger import get_module_logger
|
||||
|
||||
@@ -1,94 +0,0 @@
|
||||
import shutil
|
||||
import tomlkit
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
def update_config():
|
||||
print("开始更新配置文件...")
|
||||
# 获取根目录路径
|
||||
root_dir = Path(__file__).parent.parent.parent.parent
|
||||
template_dir = root_dir / "template"
|
||||
config_dir = root_dir / "config"
|
||||
old_config_dir = config_dir / "old"
|
||||
|
||||
# 创建old目录(如果不存在)
|
||||
old_config_dir.mkdir(exist_ok=True)
|
||||
|
||||
# 定义文件路径
|
||||
template_path = template_dir / "bot_config_template.toml"
|
||||
old_config_path = config_dir / "bot_config.toml"
|
||||
new_config_path = config_dir / "bot_config.toml"
|
||||
|
||||
# 读取旧配置文件
|
||||
old_config = {}
|
||||
if old_config_path.exists():
|
||||
print(f"发现旧配置文件: {old_config_path}")
|
||||
with open(old_config_path, "r", encoding="utf-8") as f:
|
||||
old_config = tomlkit.load(f)
|
||||
|
||||
# 生成带时间戳的新文件名
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
old_backup_path = old_config_dir / f"bot_config_{timestamp}.toml"
|
||||
|
||||
# 移动旧配置文件到old目录
|
||||
shutil.move(old_config_path, old_backup_path)
|
||||
print(f"已备份旧配置文件到: {old_backup_path}")
|
||||
|
||||
# 复制模板文件到配置目录
|
||||
print(f"从模板文件创建新配置: {template_path}")
|
||||
shutil.copy2(template_path, new_config_path)
|
||||
|
||||
# 读取新配置文件
|
||||
with open(new_config_path, "r", encoding="utf-8") as f:
|
||||
new_config = tomlkit.load(f)
|
||||
|
||||
# 检查version是否相同
|
||||
if old_config and "inner" in old_config and "inner" in new_config:
|
||||
old_version = old_config["inner"].get("version")
|
||||
new_version = new_config["inner"].get("version")
|
||||
if old_version and new_version and old_version == new_version:
|
||||
print(f"检测到版本号相同 (v{old_version}),跳过更新")
|
||||
# 如果version相同,恢复旧配置文件并返回
|
||||
shutil.move(old_backup_path, old_config_path)
|
||||
return
|
||||
else:
|
||||
print(f"检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}")
|
||||
|
||||
# 递归更新配置
|
||||
def update_dict(target, source):
|
||||
for key, value in source.items():
|
||||
# 跳过version字段的更新
|
||||
if key == "version":
|
||||
continue
|
||||
if key in target:
|
||||
if isinstance(value, dict) and isinstance(target[key], (dict, tomlkit.items.Table)):
|
||||
update_dict(target[key], value)
|
||||
else:
|
||||
try:
|
||||
# 对数组类型进行特殊处理
|
||||
if isinstance(value, list):
|
||||
# 如果是空数组,确保它保持为空数组
|
||||
if not value:
|
||||
target[key] = tomlkit.array()
|
||||
else:
|
||||
target[key] = tomlkit.array(value)
|
||||
else:
|
||||
# 其他类型使用item方法创建新值
|
||||
target[key] = tomlkit.item(value)
|
||||
except (TypeError, ValueError):
|
||||
# 如果转换失败,直接赋值
|
||||
target[key] = value
|
||||
|
||||
# 将旧配置的值更新到新配置中
|
||||
print("开始合并新旧配置...")
|
||||
update_dict(new_config, old_config)
|
||||
|
||||
# 保存更新后的配置(保留注释和格式)
|
||||
with open(new_config_path, "w", encoding="utf-8") as f:
|
||||
f.write(tomlkit.dumps(new_config))
|
||||
print("配置文件更新完成")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
update_config()
|
||||
@@ -1,773 +0,0 @@
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional
|
||||
from dateutil import tz
|
||||
|
||||
import tomli
|
||||
import tomlkit
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from packaging import version
|
||||
from packaging.version import Version, InvalidVersion
|
||||
from packaging.specifiers import SpecifierSet, InvalidSpecifier
|
||||
|
||||
from src.common.logger import get_module_logger, CONFIG_STYLE_CONFIG, LogConfig
|
||||
|
||||
# 定义日志配置
|
||||
config_config = LogConfig(
|
||||
# 使用消息发送专用样式
|
||||
console_format=CONFIG_STYLE_CONFIG["console_format"],
|
||||
file_format=CONFIG_STYLE_CONFIG["file_format"],
|
||||
)
|
||||
|
||||
# 配置主程序日志格式
|
||||
logger = get_module_logger("config", config=config_config)
|
||||
|
||||
# 考虑到,实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
|
||||
is_test = True
|
||||
mai_version_main = "0.6.3"
|
||||
mai_version_fix = "snapshot-1"
|
||||
|
||||
if mai_version_fix:
|
||||
if is_test:
|
||||
mai_version = f"test-{mai_version_main}-{mai_version_fix}"
|
||||
else:
|
||||
mai_version = f"{mai_version_main}-{mai_version_fix}"
|
||||
else:
|
||||
if is_test:
|
||||
mai_version = f"test-{mai_version_main}"
|
||||
else:
|
||||
mai_version = mai_version_main
|
||||
|
||||
|
||||
def update_config():
|
||||
# 获取根目录路径
|
||||
root_dir = Path(__file__).parent.parent.parent.parent
|
||||
template_dir = root_dir / "template"
|
||||
config_dir = root_dir / "config"
|
||||
old_config_dir = config_dir / "old"
|
||||
|
||||
# 定义文件路径
|
||||
template_path = template_dir / "bot_config_template.toml"
|
||||
old_config_path = config_dir / "bot_config.toml"
|
||||
new_config_path = config_dir / "bot_config.toml"
|
||||
|
||||
# 检查配置文件是否存在
|
||||
if not old_config_path.exists():
|
||||
logger.info("配置文件不存在,从模板创建新配置")
|
||||
# 创建文件夹
|
||||
old_config_dir.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy2(template_path, old_config_path)
|
||||
logger.info(f"已创建新配置文件,请填写后重新运行: {old_config_path}")
|
||||
# 如果是新创建的配置文件,直接返回
|
||||
return quit()
|
||||
|
||||
# 读取旧配置文件和模板文件
|
||||
with open(old_config_path, "r", encoding="utf-8") as f:
|
||||
old_config = tomlkit.load(f)
|
||||
with open(template_path, "r", encoding="utf-8") as f:
|
||||
new_config = tomlkit.load(f)
|
||||
|
||||
# 检查version是否相同
|
||||
if old_config and "inner" in old_config and "inner" in new_config:
|
||||
old_version = old_config["inner"].get("version")
|
||||
new_version = new_config["inner"].get("version")
|
||||
if old_version and new_version and old_version == new_version:
|
||||
logger.info(f"检测到配置文件版本号相同 (v{old_version}),跳过更新")
|
||||
return
|
||||
else:
|
||||
logger.info(f"检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}")
|
||||
|
||||
# 创建old目录(如果不存在)
|
||||
old_config_dir.mkdir(exist_ok=True)
|
||||
|
||||
# 生成带时间戳的新文件名
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
old_backup_path = old_config_dir / f"bot_config_{timestamp}.toml"
|
||||
|
||||
# 移动旧配置文件到old目录
|
||||
shutil.move(old_config_path, old_backup_path)
|
||||
logger.info(f"已备份旧配置文件到: {old_backup_path}")
|
||||
|
||||
# 复制模板文件到配置目录
|
||||
shutil.copy2(template_path, new_config_path)
|
||||
logger.info(f"已创建新配置文件: {new_config_path}")
|
||||
|
||||
# 递归更新配置
|
||||
def update_dict(target, source):
|
||||
for key, value in source.items():
|
||||
# 跳过version字段的更新
|
||||
if key == "version":
|
||||
continue
|
||||
if key in target:
|
||||
if isinstance(value, dict) and isinstance(target[key], (dict, tomlkit.items.Table)):
|
||||
update_dict(target[key], value)
|
||||
else:
|
||||
try:
|
||||
# 对数组类型进行特殊处理
|
||||
if isinstance(value, list):
|
||||
# 如果是空数组,确保它保持为空数组
|
||||
if not value:
|
||||
target[key] = tomlkit.array()
|
||||
else:
|
||||
target[key] = tomlkit.array(value)
|
||||
else:
|
||||
# 其他类型使用item方法创建新值
|
||||
target[key] = tomlkit.item(value)
|
||||
except (TypeError, ValueError):
|
||||
# 如果转换失败,直接赋值
|
||||
target[key] = value
|
||||
|
||||
# 将旧配置的值更新到新配置中
|
||||
logger.info("开始合并新旧配置...")
|
||||
update_dict(new_config, old_config)
|
||||
|
||||
# 保存更新后的配置(保留注释和格式)
|
||||
with open(new_config_path, "w", encoding="utf-8") as f:
|
||||
f.write(tomlkit.dumps(new_config))
|
||||
logger.info("配置文件更新完成")
|
||||
|
||||
|
||||
logger = get_module_logger("config")
|
||||
|
||||
|
||||
@dataclass
|
||||
class BotConfig:
|
||||
"""机器人配置类"""
|
||||
|
||||
INNER_VERSION: Version = None
|
||||
MAI_VERSION: str = mai_version # 硬编码的版本信息
|
||||
|
||||
# bot
|
||||
BOT_QQ: Optional[int] = 114514
|
||||
BOT_NICKNAME: Optional[str] = None
|
||||
BOT_ALIAS_NAMES: List[str] = field(default_factory=list) # 别名,可以通过这个叫它
|
||||
|
||||
# group
|
||||
talk_allowed_groups = set()
|
||||
talk_frequency_down_groups = set()
|
||||
ban_user_id = set()
|
||||
|
||||
# personality
|
||||
personality_core = "用一句话或几句话描述人格的核心特点" # 建议20字以内,谁再写3000字小作文敲谁脑袋
|
||||
personality_sides: List[str] = field(
|
||||
default_factory=lambda: [
|
||||
"用一句话或几句话描述人格的一些侧面",
|
||||
"用一句话或几句话描述人格的一些侧面",
|
||||
"用一句话或几句话描述人格的一些侧面",
|
||||
]
|
||||
)
|
||||
# identity
|
||||
identity_detail: List[str] = field(
|
||||
default_factory=lambda: [
|
||||
"身份特点",
|
||||
"身份特点",
|
||||
]
|
||||
)
|
||||
height: int = 170 # 身高 单位厘米
|
||||
weight: int = 50 # 体重 单位千克
|
||||
age: int = 20 # 年龄 单位岁
|
||||
gender: str = "男" # 性别
|
||||
appearance: str = "用几句话描述外貌特征" # 外貌特征
|
||||
|
||||
# schedule
|
||||
ENABLE_SCHEDULE_GEN: bool = False # 是否启用日程生成
|
||||
PROMPT_SCHEDULE_GEN = "无日程"
|
||||
SCHEDULE_DOING_UPDATE_INTERVAL: int = 300 # 日程表更新间隔 单位秒
|
||||
SCHEDULE_TEMPERATURE: float = 0.5 # 日程表温度,建议0.5-1.0
|
||||
TIME_ZONE: str = "Asia/Shanghai" # 时区
|
||||
|
||||
# message
|
||||
MAX_CONTEXT_SIZE: int = 15 # 上下文最大消息数
|
||||
emoji_chance: float = 0.2 # 发送表情包的基础概率
|
||||
thinking_timeout: int = 120 # 思考时间
|
||||
max_response_length: int = 1024 # 最大回复长度
|
||||
message_buffer: bool = True # 消息缓冲器
|
||||
|
||||
ban_words = set()
|
||||
ban_msgs_regex = set()
|
||||
|
||||
# heartflow
|
||||
# enable_heartflow: bool = False # 是否启用心流
|
||||
sub_heart_flow_update_interval: int = 60 # 子心流更新频率,间隔 单位秒
|
||||
sub_heart_flow_freeze_time: int = 120 # 子心流冻结时间,超过这个时间没有回复,子心流会冻结,间隔 单位秒
|
||||
sub_heart_flow_stop_time: int = 600 # 子心流停止时间,超过这个时间没有回复,子心流会停止,间隔 单位秒
|
||||
heart_flow_update_interval: int = 300 # 心流更新频率,间隔 单位秒
|
||||
observation_context_size: int = 20 # 心流观察到的最长上下文大小,超过这个值的上下文会被压缩
|
||||
compressed_length: int = 5 # 不能大于observation_context_size,心流上下文压缩的最短压缩长度,超过心流观察到的上下文长度,会压缩,最短压缩长度为5
|
||||
compress_length_limit: int = 5 # 最多压缩份数,超过该数值的压缩上下文会被删除
|
||||
|
||||
# willing
|
||||
willing_mode: str = "classical" # 意愿模式
|
||||
response_willing_amplifier: float = 1.0 # 回复意愿放大系数
|
||||
response_interested_rate_amplifier: float = 1.0 # 回复兴趣度放大系数
|
||||
down_frequency_rate: float = 3 # 降低回复频率的群组回复意愿降低系数
|
||||
emoji_response_penalty: float = 0.0 # 表情包回复惩罚
|
||||
mentioned_bot_inevitable_reply: bool = False # 提及 bot 必然回复
|
||||
at_bot_inevitable_reply: bool = False # @bot 必然回复
|
||||
|
||||
# response
|
||||
response_mode: str = "heart_flow" # 回复策略
|
||||
MODEL_R1_PROBABILITY: float = 0.8 # R1模型概率
|
||||
MODEL_V3_PROBABILITY: float = 0.1 # V3模型概率
|
||||
# MODEL_R1_DISTILL_PROBABILITY: float = 0.1 # R1蒸馏模型概率
|
||||
|
||||
# emoji
|
||||
max_emoji_num: int = 200 # 表情包最大数量
|
||||
max_reach_deletion: bool = True # 开启则在达到最大数量时删除表情包,关闭则不会继续收集表情包
|
||||
EMOJI_CHECK_INTERVAL: int = 120 # 表情包检查间隔(分钟)
|
||||
EMOJI_REGISTER_INTERVAL: int = 10 # 表情包注册间隔(分钟)
|
||||
EMOJI_SAVE: bool = True # 偷表情包
|
||||
EMOJI_CHECK: bool = False # 是否开启过滤
|
||||
EMOJI_CHECK_PROMPT: str = "符合公序良俗" # 表情包过滤要求
|
||||
|
||||
# memory
|
||||
build_memory_interval: int = 600 # 记忆构建间隔(秒)
|
||||
memory_build_distribution: list = field(
|
||||
default_factory=lambda: [4, 2, 0.6, 24, 8, 0.4]
|
||||
) # 记忆构建分布,参数:分布1均值,标准差,权重,分布2均值,标准差,权重
|
||||
build_memory_sample_num: int = 10 # 记忆构建采样数量
|
||||
build_memory_sample_length: int = 20 # 记忆构建采样长度
|
||||
memory_compress_rate: float = 0.1 # 记忆压缩率
|
||||
|
||||
forget_memory_interval: int = 600 # 记忆遗忘间隔(秒)
|
||||
memory_forget_time: int = 24 # 记忆遗忘时间(小时)
|
||||
memory_forget_percentage: float = 0.01 # 记忆遗忘比例
|
||||
|
||||
memory_ban_words: list = field(
|
||||
default_factory=lambda: ["表情包", "图片", "回复", "聊天记录"]
|
||||
) # 添加新的配置项默认值
|
||||
|
||||
# mood
|
||||
mood_update_interval: float = 1.0 # 情绪更新间隔 单位秒
|
||||
mood_decay_rate: float = 0.95 # 情绪衰减率
|
||||
mood_intensity_factor: float = 0.7 # 情绪强度因子
|
||||
|
||||
# keywords
|
||||
keywords_reaction_rules = [] # 关键词回复规则
|
||||
|
||||
# chinese_typo
|
||||
chinese_typo_enable = True # 是否启用中文错别字生成器
|
||||
chinese_typo_error_rate = 0.03 # 单字替换概率
|
||||
chinese_typo_min_freq = 7 # 最小字频阈值
|
||||
chinese_typo_tone_error_rate = 0.2 # 声调错误概率
|
||||
chinese_typo_word_replace_rate = 0.02 # 整词替换概率
|
||||
|
||||
# response_splitter
|
||||
enable_response_splitter = True # 是否启用回复分割器
|
||||
response_max_length = 100 # 回复允许的最大长度
|
||||
response_max_sentence_num = 3 # 回复允许的最大句子数
|
||||
|
||||
# remote
|
||||
remote_enable: bool = True # 是否启用远程控制
|
||||
|
||||
# experimental
|
||||
enable_friend_chat: bool = False # 是否启用好友聊天
|
||||
# enable_think_flow: bool = False # 是否启用思考流程
|
||||
enable_pfc_chatting: bool = False # 是否启用PFC聊天
|
||||
|
||||
# 模型配置
|
||||
llm_reasoning: Dict[str, str] = field(default_factory=lambda: {})
|
||||
# llm_reasoning_minor: Dict[str, str] = field(default_factory=lambda: {})
|
||||
llm_normal: Dict[str, str] = field(default_factory=lambda: {})
|
||||
llm_topic_judge: Dict[str, str] = field(default_factory=lambda: {})
|
||||
llm_summary_by_topic: Dict[str, str] = field(default_factory=lambda: {})
|
||||
llm_emotion_judge: Dict[str, str] = field(default_factory=lambda: {})
|
||||
embedding: Dict[str, str] = field(default_factory=lambda: {})
|
||||
vlm: Dict[str, str] = field(default_factory=lambda: {})
|
||||
moderation: Dict[str, str] = field(default_factory=lambda: {})
|
||||
|
||||
# 实验性
|
||||
llm_observation: Dict[str, str] = field(default_factory=lambda: {})
|
||||
llm_sub_heartflow: Dict[str, str] = field(default_factory=lambda: {})
|
||||
llm_heartflow: Dict[str, str] = field(default_factory=lambda: {})
|
||||
|
||||
build_memory_interval: int = 600 # 记忆构建间隔(秒)
|
||||
|
||||
forget_memory_interval: int = 600 # 记忆遗忘间隔(秒)
|
||||
memory_forget_time: int = 24 # 记忆遗忘时间(小时)
|
||||
memory_forget_percentage: float = 0.01 # 记忆遗忘比例
|
||||
memory_compress_rate: float = 0.1 # 记忆压缩率
|
||||
build_memory_sample_num: int = 10 # 记忆构建采样数量
|
||||
build_memory_sample_length: int = 20 # 记忆构建采样长度
|
||||
memory_build_distribution: list = field(
|
||||
default_factory=lambda: [4, 2, 0.6, 24, 8, 0.4]
|
||||
) # 记忆构建分布,参数:分布1均值,标准差,权重,分布2均值,标准差,权重
|
||||
memory_ban_words: list = field(
|
||||
default_factory=lambda: ["表情包", "图片", "回复", "聊天记录"]
|
||||
) # 添加新的配置项默认值
|
||||
|
||||
api_urls: Dict[str, str] = field(default_factory=lambda: {})
|
||||
|
||||
@staticmethod
|
||||
def get_config_dir() -> str:
|
||||
"""获取配置文件目录"""
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
root_dir = os.path.abspath(os.path.join(current_dir, "..", "..", ".."))
|
||||
config_dir = os.path.join(root_dir, "config")
|
||||
if not os.path.exists(config_dir):
|
||||
os.makedirs(config_dir)
|
||||
return config_dir
|
||||
|
||||
@classmethod
|
||||
def convert_to_specifierset(cls, value: str) -> SpecifierSet:
|
||||
"""将 字符串 版本表达式转换成 SpecifierSet
|
||||
Args:
|
||||
value[str]: 版本表达式(字符串)
|
||||
Returns:
|
||||
SpecifierSet
|
||||
"""
|
||||
|
||||
try:
|
||||
converted = SpecifierSet(value)
|
||||
except InvalidSpecifier:
|
||||
logger.error(f"{value} 分类使用了错误的版本约束表达式\n", "请阅读 https://semver.org/lang/zh-CN/ 修改代码")
|
||||
exit(1)
|
||||
|
||||
return converted
|
||||
|
||||
@classmethod
|
||||
def get_config_version(cls, toml: dict) -> Version:
|
||||
"""提取配置文件的 SpecifierSet 版本数据
|
||||
Args:
|
||||
toml[dict]: 输入的配置文件字典
|
||||
Returns:
|
||||
Version
|
||||
"""
|
||||
|
||||
if "inner" in toml:
|
||||
try:
|
||||
config_version: str = toml["inner"]["version"]
|
||||
except KeyError as e:
|
||||
logger.error("配置文件中 inner 段 不存在, 这是错误的配置文件")
|
||||
raise KeyError(f"配置文件中 inner 段 不存在 {e}, 这是错误的配置文件") from e
|
||||
else:
|
||||
toml["inner"] = {"version": "0.0.0"}
|
||||
config_version = toml["inner"]["version"]
|
||||
|
||||
try:
|
||||
ver = version.parse(config_version)
|
||||
except InvalidVersion as e:
|
||||
logger.error(
|
||||
"配置文件中 inner段 的 version 键是错误的版本描述\n"
|
||||
"请阅读 https://semver.org/lang/zh-CN/ 修改配置,并参考本项目指定的模板进行修改\n"
|
||||
"本项目在不同的版本下有不同的模板,请注意识别"
|
||||
)
|
||||
raise InvalidVersion("配置文件中 inner段 的 version 键是错误的版本描述\n") from e
|
||||
|
||||
return ver
|
||||
|
||||
@classmethod
|
||||
def load_config(cls, config_path: str = None) -> "BotConfig":
|
||||
"""从TOML配置文件加载配置"""
|
||||
config = cls()
|
||||
|
||||
def personality(parent: dict):
|
||||
personality_config = parent["personality"]
|
||||
if config.INNER_VERSION in SpecifierSet(">=1.2.4"):
|
||||
config.personality_core = personality_config.get("personality_core", config.personality_core)
|
||||
config.personality_sides = personality_config.get("personality_sides", config.personality_sides)
|
||||
|
||||
def identity(parent: dict):
|
||||
identity_config = parent["identity"]
|
||||
if config.INNER_VERSION in SpecifierSet(">=1.2.4"):
|
||||
config.identity_detail = identity_config.get("identity_detail", config.identity_detail)
|
||||
config.height = identity_config.get("height", config.height)
|
||||
config.weight = identity_config.get("weight", config.weight)
|
||||
config.age = identity_config.get("age", config.age)
|
||||
config.gender = identity_config.get("gender", config.gender)
|
||||
config.appearance = identity_config.get("appearance", config.appearance)
|
||||
|
||||
def schedule(parent: dict):
|
||||
schedule_config = parent["schedule"]
|
||||
config.ENABLE_SCHEDULE_GEN = schedule_config.get("enable_schedule_gen", config.ENABLE_SCHEDULE_GEN)
|
||||
config.PROMPT_SCHEDULE_GEN = schedule_config.get("prompt_schedule_gen", config.PROMPT_SCHEDULE_GEN)
|
||||
config.SCHEDULE_DOING_UPDATE_INTERVAL = schedule_config.get(
|
||||
"schedule_doing_update_interval", config.SCHEDULE_DOING_UPDATE_INTERVAL
|
||||
)
|
||||
logger.info(
|
||||
f"载入自定义日程prompt:{schedule_config.get('prompt_schedule_gen', config.PROMPT_SCHEDULE_GEN)}"
|
||||
)
|
||||
if config.INNER_VERSION in SpecifierSet(">=1.0.2"):
|
||||
config.SCHEDULE_TEMPERATURE = schedule_config.get("schedule_temperature", config.SCHEDULE_TEMPERATURE)
|
||||
time_zone = schedule_config.get("time_zone", config.TIME_ZONE)
|
||||
if tz.gettz(time_zone) is None:
|
||||
logger.error(f"无效的时区: {time_zone},使用默认值: {config.TIME_ZONE}")
|
||||
else:
|
||||
config.TIME_ZONE = time_zone
|
||||
|
||||
def emoji(parent: dict):
|
||||
emoji_config = parent["emoji"]
|
||||
config.EMOJI_CHECK_INTERVAL = emoji_config.get("check_interval", config.EMOJI_CHECK_INTERVAL)
|
||||
config.EMOJI_REGISTER_INTERVAL = emoji_config.get("register_interval", config.EMOJI_REGISTER_INTERVAL)
|
||||
config.EMOJI_CHECK_PROMPT = emoji_config.get("check_prompt", config.EMOJI_CHECK_PROMPT)
|
||||
config.EMOJI_SAVE = emoji_config.get("auto_save", config.EMOJI_SAVE)
|
||||
config.EMOJI_CHECK = emoji_config.get("enable_check", config.EMOJI_CHECK)
|
||||
if config.INNER_VERSION in SpecifierSet(">=1.1.1"):
|
||||
config.max_emoji_num = emoji_config.get("max_emoji_num", config.max_emoji_num)
|
||||
config.max_reach_deletion = emoji_config.get("max_reach_deletion", config.max_reach_deletion)
|
||||
|
||||
def bot(parent: dict):
|
||||
# 机器人基础配置
|
||||
bot_config = parent["bot"]
|
||||
bot_qq = bot_config.get("qq")
|
||||
config.BOT_QQ = int(bot_qq)
|
||||
config.BOT_NICKNAME = bot_config.get("nickname", config.BOT_NICKNAME)
|
||||
config.BOT_ALIAS_NAMES = bot_config.get("alias_names", config.BOT_ALIAS_NAMES)
|
||||
|
||||
def response(parent: dict):
|
||||
response_config = parent["response"]
|
||||
config.MODEL_R1_PROBABILITY = response_config.get("model_r1_probability", config.MODEL_R1_PROBABILITY)
|
||||
config.MODEL_V3_PROBABILITY = response_config.get("model_v3_probability", config.MODEL_V3_PROBABILITY)
|
||||
# config.MODEL_R1_DISTILL_PROBABILITY = response_config.get(
|
||||
# "model_r1_distill_probability", config.MODEL_R1_DISTILL_PROBABILITY
|
||||
# )
|
||||
config.max_response_length = response_config.get("max_response_length", config.max_response_length)
|
||||
if config.INNER_VERSION in SpecifierSet(">=1.0.4"):
|
||||
config.response_mode = response_config.get("response_mode", config.response_mode)
|
||||
|
||||
def heartflow(parent: dict):
|
||||
heartflow_config = parent["heartflow"]
|
||||
config.sub_heart_flow_update_interval = heartflow_config.get(
|
||||
"sub_heart_flow_update_interval", config.sub_heart_flow_update_interval
|
||||
)
|
||||
config.sub_heart_flow_freeze_time = heartflow_config.get(
|
||||
"sub_heart_flow_freeze_time", config.sub_heart_flow_freeze_time
|
||||
)
|
||||
config.sub_heart_flow_stop_time = heartflow_config.get(
|
||||
"sub_heart_flow_stop_time", config.sub_heart_flow_stop_time
|
||||
)
|
||||
config.heart_flow_update_interval = heartflow_config.get(
|
||||
"heart_flow_update_interval", config.heart_flow_update_interval
|
||||
)
|
||||
if config.INNER_VERSION in SpecifierSet(">=1.3.0"):
|
||||
config.observation_context_size = heartflow_config.get(
|
||||
"observation_context_size", config.observation_context_size
|
||||
)
|
||||
config.compressed_length = heartflow_config.get("compressed_length", config.compressed_length)
|
||||
config.compress_length_limit = heartflow_config.get(
|
||||
"compress_length_limit", config.compress_length_limit
|
||||
)
|
||||
|
||||
def willing(parent: dict):
|
||||
willing_config = parent["willing"]
|
||||
config.willing_mode = willing_config.get("willing_mode", config.willing_mode)
|
||||
|
||||
if config.INNER_VERSION in SpecifierSet(">=0.0.11"):
|
||||
config.response_willing_amplifier = willing_config.get(
|
||||
"response_willing_amplifier", config.response_willing_amplifier
|
||||
)
|
||||
config.response_interested_rate_amplifier = willing_config.get(
|
||||
"response_interested_rate_amplifier", config.response_interested_rate_amplifier
|
||||
)
|
||||
config.down_frequency_rate = willing_config.get("down_frequency_rate", config.down_frequency_rate)
|
||||
config.emoji_response_penalty = willing_config.get(
|
||||
"emoji_response_penalty", config.emoji_response_penalty
|
||||
)
|
||||
if config.INNER_VERSION in SpecifierSet(">=1.2.5"):
|
||||
config.mentioned_bot_inevitable_reply = willing_config.get(
|
||||
"mentioned_bot_inevitable_reply", config.mentioned_bot_inevitable_reply
|
||||
)
|
||||
config.at_bot_inevitable_reply = willing_config.get(
|
||||
"at_bot_inevitable_reply", config.at_bot_inevitable_reply
|
||||
)
|
||||
|
||||
def model(parent: dict):
|
||||
# 加载模型配置
|
||||
model_config: dict = parent["model"]
|
||||
|
||||
config_list = [
|
||||
"llm_reasoning",
|
||||
# "llm_reasoning_minor",
|
||||
"llm_normal",
|
||||
"llm_topic_judge",
|
||||
"llm_summary_by_topic",
|
||||
"llm_emotion_judge",
|
||||
"vlm",
|
||||
"embedding",
|
||||
"llm_tool_use",
|
||||
"llm_observation",
|
||||
"llm_sub_heartflow",
|
||||
"llm_heartflow",
|
||||
]
|
||||
|
||||
for item in config_list:
|
||||
if item in model_config:
|
||||
cfg_item: dict = model_config[item]
|
||||
|
||||
# base_url 的例子: SILICONFLOW_BASE_URL
|
||||
# key 的例子: SILICONFLOW_KEY
|
||||
cfg_target = {
|
||||
"name": "",
|
||||
"base_url": "",
|
||||
"key": "",
|
||||
"stream": False,
|
||||
"pri_in": 0,
|
||||
"pri_out": 0,
|
||||
"temp": 0.7,
|
||||
}
|
||||
|
||||
if config.INNER_VERSION in SpecifierSet("<=0.0.0"):
|
||||
cfg_target = cfg_item
|
||||
|
||||
elif config.INNER_VERSION in SpecifierSet(">=0.0.1"):
|
||||
stable_item = ["name", "pri_in", "pri_out"]
|
||||
|
||||
stream_item = ["stream"]
|
||||
if config.INNER_VERSION in SpecifierSet(">=1.0.1"):
|
||||
stable_item.append("stream")
|
||||
|
||||
pricing_item = ["pri_in", "pri_out"]
|
||||
|
||||
# 从配置中原始拷贝稳定字段
|
||||
for i in stable_item:
|
||||
# 如果 字段 属于计费项 且获取不到,那默认值是 0
|
||||
if i in pricing_item and i not in cfg_item:
|
||||
cfg_target[i] = 0
|
||||
|
||||
if i in stream_item and i not in cfg_item:
|
||||
cfg_target[i] = False
|
||||
|
||||
else:
|
||||
# 没有特殊情况则原样复制
|
||||
try:
|
||||
cfg_target[i] = cfg_item[i]
|
||||
except KeyError as e:
|
||||
logger.error(f"{item} 中的必要字段不存在,请检查")
|
||||
raise KeyError(f"{item} 中的必要字段 {e} 不存在,请检查") from e
|
||||
|
||||
# 如果配置中有temp参数,就使用配置中的值
|
||||
if "temp" in cfg_item:
|
||||
cfg_target["temp"] = cfg_item["temp"]
|
||||
else:
|
||||
# 如果没有temp参数,就删除默认值
|
||||
cfg_target.pop("temp", None)
|
||||
|
||||
provider = cfg_item.get("provider")
|
||||
if provider is None:
|
||||
logger.error(f"provider 字段在模型配置 {item} 中不存在,请检查")
|
||||
raise KeyError(f"provider 字段在模型配置 {item} 中不存在,请检查")
|
||||
|
||||
cfg_target["base_url"] = f"{provider}_BASE_URL"
|
||||
cfg_target["key"] = f"{provider}_KEY"
|
||||
|
||||
# 如果 列表中的项目在 model_config 中,利用反射来设置对应项目
|
||||
setattr(config, item, cfg_target)
|
||||
else:
|
||||
logger.error(f"模型 {item} 在config中不存在,请检查,或尝试更新配置文件")
|
||||
raise KeyError(f"模型 {item} 在config中不存在,请检查,或尝试更新配置文件")
|
||||
|
||||
def message(parent: dict):
|
||||
msg_config = parent["message"]
|
||||
config.MAX_CONTEXT_SIZE = msg_config.get("max_context_size", config.MAX_CONTEXT_SIZE)
|
||||
config.emoji_chance = msg_config.get("emoji_chance", config.emoji_chance)
|
||||
config.ban_words = msg_config.get("ban_words", config.ban_words)
|
||||
config.thinking_timeout = msg_config.get("thinking_timeout", config.thinking_timeout)
|
||||
config.response_willing_amplifier = msg_config.get(
|
||||
"response_willing_amplifier", config.response_willing_amplifier
|
||||
)
|
||||
config.response_interested_rate_amplifier = msg_config.get(
|
||||
"response_interested_rate_amplifier", config.response_interested_rate_amplifier
|
||||
)
|
||||
config.down_frequency_rate = msg_config.get("down_frequency_rate", config.down_frequency_rate)
|
||||
for r in msg_config.get("ban_msgs_regex", config.ban_msgs_regex):
|
||||
config.ban_msgs_regex.add(re.compile(r))
|
||||
if config.INNER_VERSION in SpecifierSet(">=0.0.11"):
|
||||
config.max_response_length = msg_config.get("max_response_length", config.max_response_length)
|
||||
if config.INNER_VERSION in SpecifierSet(">=1.1.4"):
|
||||
config.message_buffer = msg_config.get("message_buffer", config.message_buffer)
|
||||
|
||||
def memory(parent: dict):
|
||||
memory_config = parent["memory"]
|
||||
config.build_memory_interval = memory_config.get("build_memory_interval", config.build_memory_interval)
|
||||
config.forget_memory_interval = memory_config.get("forget_memory_interval", config.forget_memory_interval)
|
||||
config.memory_ban_words = set(memory_config.get("memory_ban_words", []))
|
||||
config.memory_forget_time = memory_config.get("memory_forget_time", config.memory_forget_time)
|
||||
config.memory_forget_percentage = memory_config.get(
|
||||
"memory_forget_percentage", config.memory_forget_percentage
|
||||
)
|
||||
config.memory_compress_rate = memory_config.get("memory_compress_rate", config.memory_compress_rate)
|
||||
if config.INNER_VERSION in SpecifierSet(">=0.0.11"):
|
||||
config.memory_build_distribution = memory_config.get(
|
||||
"memory_build_distribution", config.memory_build_distribution
|
||||
)
|
||||
config.build_memory_sample_num = memory_config.get(
|
||||
"build_memory_sample_num", config.build_memory_sample_num
|
||||
)
|
||||
config.build_memory_sample_length = memory_config.get(
|
||||
"build_memory_sample_length", config.build_memory_sample_length
|
||||
)
|
||||
|
||||
def remote(parent: dict):
|
||||
remote_config = parent["remote"]
|
||||
config.remote_enable = remote_config.get("enable", config.remote_enable)
|
||||
|
||||
def mood(parent: dict):
|
||||
mood_config = parent["mood"]
|
||||
config.mood_update_interval = mood_config.get("mood_update_interval", config.mood_update_interval)
|
||||
config.mood_decay_rate = mood_config.get("mood_decay_rate", config.mood_decay_rate)
|
||||
config.mood_intensity_factor = mood_config.get("mood_intensity_factor", config.mood_intensity_factor)
|
||||
|
||||
def keywords_reaction(parent: dict):
|
||||
keywords_reaction_config = parent["keywords_reaction"]
|
||||
if keywords_reaction_config.get("enable", False):
|
||||
config.keywords_reaction_rules = keywords_reaction_config.get("rules", config.keywords_reaction_rules)
|
||||
for rule in config.keywords_reaction_rules:
|
||||
if rule.get("enable", False) and "regex" in rule:
|
||||
rule["regex"] = [re.compile(r) for r in rule.get("regex", [])]
|
||||
|
||||
def chinese_typo(parent: dict):
|
||||
chinese_typo_config = parent["chinese_typo"]
|
||||
config.chinese_typo_enable = chinese_typo_config.get("enable", config.chinese_typo_enable)
|
||||
config.chinese_typo_error_rate = chinese_typo_config.get("error_rate", config.chinese_typo_error_rate)
|
||||
config.chinese_typo_min_freq = chinese_typo_config.get("min_freq", config.chinese_typo_min_freq)
|
||||
config.chinese_typo_tone_error_rate = chinese_typo_config.get(
|
||||
"tone_error_rate", config.chinese_typo_tone_error_rate
|
||||
)
|
||||
config.chinese_typo_word_replace_rate = chinese_typo_config.get(
|
||||
"word_replace_rate", config.chinese_typo_word_replace_rate
|
||||
)
|
||||
|
||||
def response_splitter(parent: dict):
|
||||
response_splitter_config = parent["response_splitter"]
|
||||
config.enable_response_splitter = response_splitter_config.get(
|
||||
"enable_response_splitter", config.enable_response_splitter
|
||||
)
|
||||
config.response_max_length = response_splitter_config.get("response_max_length", config.response_max_length)
|
||||
config.response_max_sentence_num = response_splitter_config.get(
|
||||
"response_max_sentence_num", config.response_max_sentence_num
|
||||
)
|
||||
|
||||
def groups(parent: dict):
|
||||
groups_config = parent["groups"]
|
||||
config.talk_allowed_groups = set(groups_config.get("talk_allowed", []))
|
||||
config.talk_frequency_down_groups = set(groups_config.get("talk_frequency_down", []))
|
||||
config.ban_user_id = set(groups_config.get("ban_user_id", []))
|
||||
|
||||
def platforms(parent: dict):
|
||||
platforms_config = parent["platforms"]
|
||||
if platforms_config and isinstance(platforms_config, dict):
|
||||
for k in platforms_config.keys():
|
||||
config.api_urls[k] = platforms_config[k]
|
||||
|
||||
def experimental(parent: dict):
|
||||
experimental_config = parent["experimental"]
|
||||
config.enable_friend_chat = experimental_config.get("enable_friend_chat", config.enable_friend_chat)
|
||||
# config.enable_think_flow = experimental_config.get("enable_think_flow", config.enable_think_flow)
|
||||
if config.INNER_VERSION in SpecifierSet(">=1.1.0"):
|
||||
config.enable_pfc_chatting = experimental_config.get("pfc_chatting", config.enable_pfc_chatting)
|
||||
|
||||
# 版本表达式:>=1.0.0,<2.0.0
|
||||
# 允许字段:func: method, support: str, notice: str, necessary: bool
|
||||
# 如果使用 notice 字段,在该组配置加载时,会展示该字段对用户的警示
|
||||
# 例如:"notice": "personality 将在 1.3.2 后被移除",那么在有效版本中的用户就会虽然可以
|
||||
# 正常执行程序,但是会看到这条自定义提示
|
||||
|
||||
# 版本格式:主版本号.次版本号.修订号,版本号递增规则如下:
|
||||
# 主版本号:当你做了不兼容的 API 修改,
|
||||
# 次版本号:当你做了向下兼容的功能性新增,
|
||||
# 修订号:当你做了向下兼容的问题修正。
|
||||
# 先行版本号及版本编译信息可以加到"主版本号.次版本号.修订号"的后面,作为延伸。
|
||||
|
||||
# 如果你做了break的修改,就应该改动主版本号
|
||||
# 如果做了一个兼容修改,就不应该要求这个选项是必须的!
|
||||
include_configs = {
|
||||
"bot": {"func": bot, "support": ">=0.0.0"},
|
||||
"groups": {"func": groups, "support": ">=0.0.0"},
|
||||
"personality": {"func": personality, "support": ">=0.0.0"},
|
||||
"identity": {"func": identity, "support": ">=1.2.4"},
|
||||
"schedule": {"func": schedule, "support": ">=0.0.11", "necessary": False},
|
||||
"message": {"func": message, "support": ">=0.0.0"},
|
||||
"willing": {"func": willing, "support": ">=0.0.9", "necessary": False},
|
||||
"emoji": {"func": emoji, "support": ">=0.0.0"},
|
||||
"response": {"func": response, "support": ">=0.0.0"},
|
||||
"model": {"func": model, "support": ">=0.0.0"},
|
||||
"memory": {"func": memory, "support": ">=0.0.0", "necessary": False},
|
||||
"mood": {"func": mood, "support": ">=0.0.0"},
|
||||
"remote": {"func": remote, "support": ">=0.0.10", "necessary": False},
|
||||
"keywords_reaction": {"func": keywords_reaction, "support": ">=0.0.2", "necessary": False},
|
||||
"chinese_typo": {"func": chinese_typo, "support": ">=0.0.3", "necessary": False},
|
||||
"platforms": {"func": platforms, "support": ">=1.0.0"},
|
||||
"response_splitter": {"func": response_splitter, "support": ">=0.0.11", "necessary": False},
|
||||
"experimental": {"func": experimental, "support": ">=0.0.11", "necessary": False},
|
||||
"heartflow": {"func": heartflow, "support": ">=1.0.2", "necessary": False},
|
||||
}
|
||||
|
||||
# 原地修改,将 字符串版本表达式 转换成 版本对象
|
||||
for key in include_configs:
|
||||
item_support = include_configs[key]["support"]
|
||||
include_configs[key]["support"] = cls.convert_to_specifierset(item_support)
|
||||
|
||||
if os.path.exists(config_path):
|
||||
with open(config_path, "rb") as f:
|
||||
try:
|
||||
toml_dict = tomli.load(f)
|
||||
except tomli.TOMLDecodeError as e:
|
||||
logger.critical(f"配置文件bot_config.toml填写有误,请检查第{e.lineno}行第{e.colno}处:{e.msg}")
|
||||
exit(1)
|
||||
|
||||
# 获取配置文件版本
|
||||
config.INNER_VERSION = cls.get_config_version(toml_dict)
|
||||
|
||||
# 如果在配置中找到了需要的项,调用对应项的闭包函数处理
|
||||
for key in include_configs:
|
||||
if key in toml_dict:
|
||||
group_specifierset: SpecifierSet = include_configs[key]["support"]
|
||||
|
||||
# 检查配置文件版本是否在支持范围内
|
||||
if config.INNER_VERSION in group_specifierset:
|
||||
# 如果版本在支持范围内,检查是否存在通知
|
||||
if "notice" in include_configs[key]:
|
||||
logger.warning(include_configs[key]["notice"])
|
||||
|
||||
include_configs[key]["func"](toml_dict)
|
||||
|
||||
else:
|
||||
# 如果版本不在支持范围内,崩溃并提示用户
|
||||
logger.error(
|
||||
f"配置文件中的 '{key}' 字段的版本 ({config.INNER_VERSION}) 不在支持范围内。\n"
|
||||
f"当前程序仅支持以下版本范围: {group_specifierset}"
|
||||
)
|
||||
raise InvalidVersion(f"当前程序仅支持以下版本范围: {group_specifierset}")
|
||||
|
||||
# 如果 necessary 项目存在,而且显式声明是 False,进入特殊处理
|
||||
elif "necessary" in include_configs[key] and include_configs[key].get("necessary") is False:
|
||||
# 通过 pass 处理的项虽然直接忽略也是可以的,但是为了不增加理解困难,依然需要在这里显式处理
|
||||
if key == "keywords_reaction":
|
||||
pass
|
||||
|
||||
else:
|
||||
# 如果用户根本没有需要的配置项,提示缺少配置
|
||||
logger.error(f"配置文件中缺少必需的字段: '{key}'")
|
||||
raise KeyError(f"配置文件中缺少必需的字段: '{key}'")
|
||||
|
||||
# identity_detail字段非空检查
|
||||
if not config.identity_detail:
|
||||
logger.error("配置文件错误:[identity] 部分的 identity_detail 不能为空字符串")
|
||||
raise ValueError("配置文件错误:[identity] 部分的 identity_detail 不能为空字符串")
|
||||
|
||||
logger.success(f"成功加载配置文件: {config_path}")
|
||||
|
||||
return config
|
||||
|
||||
|
||||
# 获取配置文件路径
|
||||
logger.info(f"MaiCore当前版本: {mai_version}")
|
||||
update_config()
|
||||
|
||||
bot_config_floder_path = BotConfig.get_config_dir()
|
||||
logger.info(f"正在品鉴配置文件目录: {bot_config_floder_path}")
|
||||
|
||||
bot_config_path = os.path.join(bot_config_floder_path, "bot_config.toml")
|
||||
|
||||
if os.path.exists(bot_config_path):
|
||||
# 如果开发环境配置文件不存在,则使用默认配置文件
|
||||
logger.info(f"异常的新鲜,异常的美味: {bot_config_path}")
|
||||
else:
|
||||
# 配置文件不存在
|
||||
logger.error("配置文件不存在,请检查路径: {bot_config_path}")
|
||||
raise FileNotFoundError(f"配置文件不存在: {bot_config_path}")
|
||||
|
||||
global_config = BotConfig.load_config(config_path=bot_config_path)
|
||||
@@ -1,59 +0,0 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from dotenv import load_dotenv
|
||||
|
||||
|
||||
class EnvConfig:
|
||||
_instance = None
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super(EnvConfig, cls).__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
self._initialized = True
|
||||
self.ROOT_DIR = Path(__file__).parent.parent.parent.parent
|
||||
self.load_env()
|
||||
|
||||
def load_env(self):
|
||||
env_file = self.ROOT_DIR / ".env"
|
||||
if env_file.exists():
|
||||
load_dotenv(env_file)
|
||||
|
||||
# 根据ENVIRONMENT变量加载对应的环境文件
|
||||
env_type = os.getenv("ENVIRONMENT", "prod")
|
||||
if env_type == "dev":
|
||||
env_file = self.ROOT_DIR / ".env.dev"
|
||||
elif env_type == "prod":
|
||||
env_file = self.ROOT_DIR / ".env"
|
||||
|
||||
if env_file.exists():
|
||||
load_dotenv(env_file, override=True)
|
||||
|
||||
def get(self, key, default=None):
|
||||
return os.getenv(key, default)
|
||||
|
||||
def get_all(self):
|
||||
return dict(os.environ)
|
||||
|
||||
def __getattr__(self, name):
|
||||
return self.get(name)
|
||||
|
||||
|
||||
# 创建全局实例
|
||||
env_config = EnvConfig()
|
||||
|
||||
|
||||
# 导出环境变量
|
||||
def get_env(key, default=None):
|
||||
return os.getenv(key, default)
|
||||
|
||||
|
||||
# 导出所有环境变量
|
||||
def get_all_env():
|
||||
return dict(os.environ)
|
||||
@@ -7,7 +7,7 @@ import os
|
||||
# 添加项目根目录到系统路径
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))))
|
||||
from src.plugins.memory_system.Hippocampus import HippocampusManager
|
||||
from src.plugins.config.config import global_config
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
async def test_memory_system():
|
||||
|
||||
@@ -11,7 +11,7 @@ from PIL import Image
|
||||
import io
|
||||
import os
|
||||
from ...common.database import db
|
||||
from ..config.config import global_config
|
||||
from ...config.config import global_config
|
||||
|
||||
logger = get_module_logger("model_utils")
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..config.config import global_config
|
||||
from ...config.config import global_config
|
||||
from src.common.logger import get_module_logger, LogConfig, MOOD_STYLE_CONFIG
|
||||
from ..person_info.relationship_manager import relationship_manager
|
||||
from src.individuality.individuality import Individuality
|
||||
|
||||
@@ -7,7 +7,7 @@ import datetime
|
||||
import asyncio
|
||||
import numpy as np
|
||||
from src.plugins.models.utils_model import LLMRequest
|
||||
from src.plugins.config.config import global_config
|
||||
from src.config.config import global_config
|
||||
from src.individuality.individuality import Individuality
|
||||
|
||||
import matplotlib
|
||||
@@ -354,7 +354,7 @@ class PersonInfoManager:
|
||||
"""启动个人信息推断,每天根据一定条件推断一次"""
|
||||
try:
|
||||
while 1:
|
||||
await asyncio.sleep(60)
|
||||
await asyncio.sleep(600)
|
||||
current_time = datetime.datetime.now()
|
||||
logger.info(f"个人信息推断启动: {current_time.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
|
||||
|
||||
@@ -1,111 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# from .questionnaire import PERSONALITY_QUESTIONS, FACTOR_DESCRIPTIONS
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import random
|
||||
|
||||
current_dir = Path(__file__).resolve().parent
|
||||
project_root = current_dir.parent.parent.parent
|
||||
env_path = project_root / ".env"
|
||||
|
||||
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
|
||||
sys.path.append(root_path)
|
||||
|
||||
from src.plugins.personality.questionnaire import PERSONALITY_QUESTIONS, FACTOR_DESCRIPTIONS # noqa: E402
|
||||
|
||||
|
||||
class BigFiveTest:
|
||||
def __init__(self):
|
||||
self.questions = PERSONALITY_QUESTIONS
|
||||
self.factors = FACTOR_DESCRIPTIONS
|
||||
|
||||
def run_test(self):
|
||||
"""运行测试并收集答案"""
|
||||
print("\n欢迎参加中国大五人格测试!")
|
||||
print("\n本测试采用六级评分,请根据每个描述与您的符合程度进行打分:")
|
||||
print("1 = 完全不符合")
|
||||
print("2 = 比较不符合")
|
||||
print("3 = 有点不符合")
|
||||
print("4 = 有点符合")
|
||||
print("5 = 比较符合")
|
||||
print("6 = 完全符合")
|
||||
print("\n请认真阅读每个描述,选择最符合您实际情况的选项。\n")
|
||||
|
||||
# 创建题目序号到题目的映射
|
||||
questions_map = {q["id"]: q for q in self.questions}
|
||||
|
||||
# 获取所有题目ID并随机打乱顺序
|
||||
question_ids = list(questions_map.keys())
|
||||
random.shuffle(question_ids)
|
||||
|
||||
answers = {}
|
||||
total_questions = len(question_ids)
|
||||
|
||||
for i, question_id in enumerate(question_ids, 1):
|
||||
question = questions_map[question_id]
|
||||
while True:
|
||||
try:
|
||||
print(f"\n[{i}/{total_questions}] {question['content']}")
|
||||
score = int(input("您的评分(1-6): "))
|
||||
if 1 <= score <= 6:
|
||||
answers[question_id] = score
|
||||
break
|
||||
else:
|
||||
print("请输入1-6之间的数字!")
|
||||
except ValueError:
|
||||
print("请输入有效的数字!")
|
||||
|
||||
return self.calculate_scores(answers)
|
||||
|
||||
def calculate_scores(self, answers):
|
||||
"""计算各维度得分"""
|
||||
results = {}
|
||||
factor_questions = {"外向性": [], "神经质": [], "严谨性": [], "开放性": [], "宜人性": []}
|
||||
|
||||
# 将题目按因子分类
|
||||
for q in self.questions:
|
||||
factor_questions[q["factor"]].append(q)
|
||||
|
||||
# 计算每个维度的得分
|
||||
for factor, questions in factor_questions.items():
|
||||
total_score = 0
|
||||
for q in questions:
|
||||
score = answers[q["id"]]
|
||||
# 处理反向计分题目
|
||||
if q["reverse_scoring"]:
|
||||
score = 7 - score # 6分量表反向计分为7减原始分
|
||||
total_score += score
|
||||
|
||||
# 计算平均分
|
||||
avg_score = round(total_score / len(questions), 2)
|
||||
results[factor] = {"得分": avg_score, "题目数": len(questions), "总分": total_score}
|
||||
|
||||
return results
|
||||
|
||||
def get_factor_description(self, factor):
|
||||
"""获取因子的详细描述"""
|
||||
return self.factors[factor]
|
||||
|
||||
|
||||
def main():
|
||||
test = BigFiveTest()
|
||||
results = test.run_test()
|
||||
|
||||
print("\n测试结果:")
|
||||
print("=" * 50)
|
||||
for factor, data in results.items():
|
||||
print(f"\n{factor}:")
|
||||
print(f"平均分: {data['得分']} (总分: {data['总分']}, 题目数: {data['题目数']})")
|
||||
print("-" * 30)
|
||||
description = test.get_factor_description(factor)
|
||||
print("维度说明:", description["description"][:100] + "...")
|
||||
print("\n特征词:", ", ".join(description["trait_words"]))
|
||||
print("=" * 50)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,353 +0,0 @@
|
||||
"""
|
||||
基于聊天记录的人格特征分析系统
|
||||
"""
|
||||
|
||||
from typing import Dict, List
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from dotenv import load_dotenv
|
||||
import sys
|
||||
import random
|
||||
from collections import defaultdict
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from datetime import datetime
|
||||
import matplotlib.font_manager as fm
|
||||
|
||||
current_dir = Path(__file__).resolve().parent
|
||||
project_root = current_dir.parent.parent.parent
|
||||
env_path = project_root / ".env"
|
||||
|
||||
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
|
||||
sys.path.append(root_path)
|
||||
|
||||
from src.plugins.personality.scene import get_scene_by_factor, PERSONALITY_SCENES # noqa: E402
|
||||
from src.plugins.personality.questionnaire import FACTOR_DESCRIPTIONS # noqa: E402
|
||||
from src.plugins.personality.offline_llm import LLMModel # noqa: E402
|
||||
from src.plugins.personality.who_r_u import MessageAnalyzer # noqa: E402
|
||||
|
||||
# 加载环境变量
|
||||
if env_path.exists():
|
||||
print(f"从 {env_path} 加载环境变量")
|
||||
load_dotenv(env_path)
|
||||
else:
|
||||
print(f"未找到环境变量文件: {env_path}")
|
||||
print("将使用默认配置")
|
||||
|
||||
|
||||
class ChatBasedPersonalityEvaluator:
|
||||
def __init__(self):
|
||||
self.personality_traits = {"开放性": 0, "严谨性": 0, "外向性": 0, "宜人性": 0, "神经质": 0}
|
||||
self.scenarios = []
|
||||
self.message_analyzer = MessageAnalyzer()
|
||||
self.llm = LLMModel()
|
||||
self.trait_scores_history = defaultdict(list) # 记录每个特质的得分历史
|
||||
|
||||
# 为每个人格特质获取对应的场景
|
||||
for trait in PERSONALITY_SCENES:
|
||||
scenes = get_scene_by_factor(trait)
|
||||
if not scenes:
|
||||
continue
|
||||
scene_keys = list(scenes.keys())
|
||||
selected_scenes = random.sample(scene_keys, min(3, len(scene_keys)))
|
||||
|
||||
for scene_key in selected_scenes:
|
||||
scene = scenes[scene_key]
|
||||
other_traits = [t for t in PERSONALITY_SCENES if t != trait]
|
||||
secondary_trait = random.choice(other_traits)
|
||||
self.scenarios.append(
|
||||
{"场景": scene["scenario"], "评估维度": [trait, secondary_trait], "场景编号": scene_key}
|
||||
)
|
||||
|
||||
def analyze_chat_context(self, messages: List[Dict]) -> str:
|
||||
"""
|
||||
分析一组消息的上下文,生成场景描述
|
||||
"""
|
||||
context = ""
|
||||
for msg in messages:
|
||||
nickname = msg.get("user_info", {}).get("user_nickname", "未知用户")
|
||||
content = msg.get("processed_plain_text", msg.get("detailed_plain_text", ""))
|
||||
if content:
|
||||
context += f"{nickname}: {content}\n"
|
||||
return context
|
||||
|
||||
def evaluate_chat_response(
|
||||
self, user_nickname: str, chat_context: str, dimensions: List[str] = None
|
||||
) -> Dict[str, float]:
|
||||
"""
|
||||
评估聊天内容在各个人格维度上的得分
|
||||
"""
|
||||
# 使用所有维度进行评估
|
||||
dimensions = list(self.personality_traits.keys())
|
||||
|
||||
dimension_descriptions = []
|
||||
for dim in dimensions:
|
||||
desc = FACTOR_DESCRIPTIONS.get(dim, "")
|
||||
if desc:
|
||||
dimension_descriptions.append(f"- {dim}:{desc}")
|
||||
|
||||
dimensions_text = "\n".join(dimension_descriptions)
|
||||
|
||||
prompt = f"""请根据以下聊天记录,评估"{user_nickname}"在大五人格模型中的维度得分(1-6分)。
|
||||
|
||||
聊天记录:
|
||||
{chat_context}
|
||||
|
||||
需要评估的维度说明:
|
||||
{dimensions_text}
|
||||
|
||||
请按照以下格式输出评估结果,注意,你的评价对象是"{user_nickname}"(仅输出JSON格式):
|
||||
{{
|
||||
"开放性": 分数,
|
||||
"严谨性": 分数,
|
||||
"外向性": 分数,
|
||||
"宜人性": 分数,
|
||||
"神经质": 分数
|
||||
}}
|
||||
|
||||
评分标准:
|
||||
1 = 非常不符合该维度特征
|
||||
2 = 比较不符合该维度特征
|
||||
3 = 有点不符合该维度特征
|
||||
4 = 有点符合该维度特征
|
||||
5 = 比较符合该维度特征
|
||||
6 = 非常符合该维度特征
|
||||
|
||||
如果你觉得某个维度没有相关信息或者无法判断,请输出0分
|
||||
|
||||
请根据聊天记录的内容和语气,结合维度说明进行评分。如果维度可以评分,确保分数在1-6之间。如果没有体现,请输出0分"""
|
||||
|
||||
try:
|
||||
ai_response, _ = self.llm.generate_response(prompt)
|
||||
start_idx = ai_response.find("{")
|
||||
end_idx = ai_response.rfind("}") + 1
|
||||
if start_idx != -1 and end_idx != 0:
|
||||
json_str = ai_response[start_idx:end_idx]
|
||||
scores = json.loads(json_str)
|
||||
return {k: max(0, min(6, float(v))) for k, v in scores.items()}
|
||||
else:
|
||||
print("AI响应格式不正确,使用默认评分")
|
||||
return {dim: 0 for dim in dimensions}
|
||||
except Exception as e:
|
||||
print(f"评估过程出错:{str(e)}")
|
||||
return {dim: 0 for dim in dimensions}
|
||||
|
||||
def evaluate_user_personality(self, qq_id: str, num_samples: int = 10, context_length: int = 5) -> Dict:
|
||||
"""
|
||||
基于用户的聊天记录评估人格特征
|
||||
|
||||
Args:
|
||||
qq_id (str): 用户QQ号
|
||||
num_samples (int): 要分析的聊天片段数量
|
||||
context_length (int): 每个聊天片段的上下文长度
|
||||
|
||||
Returns:
|
||||
Dict: 评估结果
|
||||
"""
|
||||
# 获取用户的随机消息及其上下文
|
||||
chat_contexts, user_nickname = self.message_analyzer.get_user_random_contexts(
|
||||
qq_id, num_messages=num_samples, context_length=context_length
|
||||
)
|
||||
if not chat_contexts:
|
||||
return {"error": f"没有找到QQ号 {qq_id} 的消息记录"}
|
||||
|
||||
# 初始化评分
|
||||
final_scores = defaultdict(float)
|
||||
dimension_counts = defaultdict(int)
|
||||
chat_samples = []
|
||||
|
||||
# 清空历史记录
|
||||
self.trait_scores_history.clear()
|
||||
|
||||
# 分析每个聊天上下文
|
||||
for chat_context in chat_contexts:
|
||||
# 评估这段聊天内容的所有维度
|
||||
scores = self.evaluate_chat_response(user_nickname, chat_context)
|
||||
|
||||
# 记录样本
|
||||
chat_samples.append(
|
||||
{"聊天内容": chat_context, "评估维度": list(self.personality_traits.keys()), "评分": scores}
|
||||
)
|
||||
|
||||
# 更新总分和历史记录
|
||||
for dimension, score in scores.items():
|
||||
if score > 0: # 只统计大于0的有效分数
|
||||
final_scores[dimension] += score
|
||||
dimension_counts[dimension] += 1
|
||||
self.trait_scores_history[dimension].append(score)
|
||||
|
||||
# 计算平均分
|
||||
average_scores = {}
|
||||
for dimension in self.personality_traits:
|
||||
if dimension_counts[dimension] > 0:
|
||||
average_scores[dimension] = round(final_scores[dimension] / dimension_counts[dimension], 2)
|
||||
else:
|
||||
average_scores[dimension] = 0 # 如果没有有效分数,返回0
|
||||
|
||||
# 生成趋势图
|
||||
self._generate_trend_plot(qq_id, user_nickname)
|
||||
|
||||
result = {
|
||||
"用户QQ": qq_id,
|
||||
"用户昵称": user_nickname,
|
||||
"样本数量": len(chat_samples),
|
||||
"人格特征评分": average_scores,
|
||||
"维度评估次数": dict(dimension_counts),
|
||||
"详细样本": chat_samples,
|
||||
"特质得分历史": {k: v for k, v in self.trait_scores_history.items()},
|
||||
}
|
||||
|
||||
# 保存结果
|
||||
os.makedirs("results", exist_ok=True)
|
||||
result_file = f"results/personality_result_{qq_id}.json"
|
||||
with open(result_file, "w", encoding="utf-8") as f:
|
||||
json.dump(result, f, ensure_ascii=False, indent=2)
|
||||
|
||||
return result
|
||||
|
||||
def _generate_trend_plot(self, qq_id: str, user_nickname: str):
    """
    Generate and save a line chart of each personality trait's cumulative
    average score across the evaluated chat samples.

    Args:
        qq_id: QQ account id, used in the output file name.
        user_nickname: Display name for the chart title; transliterated to
            pinyin (or replaced by "User") when no CJK font is available.

    Side effects:
        Writes a timestamped PNG under results/plots/ and mutates the global
        matplotlib rcParams. Reads self.trait_scores_history.
    """
    # Probe the system font list for CJK-capable fonts so Chinese labels render.
    chinese_fonts = []
    for f in fm.fontManager.ttflist:
        try:
            if "简" in f.name or "SC" in f.name or "黑" in f.name or "宋" in f.name or "微软" in f.name:
                chinese_fonts.append(f.name)
        except Exception:
            # Some font entries carry unusable metadata; skip them.
            continue

    if chinese_fonts:
        plt.rcParams["font.sans-serif"] = chinese_fonts + ["SimHei", "Microsoft YaHei", "Arial Unicode MS"]
    else:
        # No CJK font found: keep the default font and transliterate the
        # nickname to pinyin so the title stays readable.
        try:
            from pypinyin import lazy_pinyin

            user_nickname = "".join(lazy_pinyin(user_nickname))
        except ImportError:
            user_nickname = "User"  # pypinyin unavailable: generic English fallback

    plt.rcParams["axes.unicode_minus"] = False  # render minus signs correctly

    plt.figure(figsize=(12, 6))
    plt.style.use("bmh")  # built-in style with a seaborn-like look

    # Fixed color per Big Five trait (Chinese labels).
    colors = {
        "开放性": "#FF9999",
        "严谨性": "#66B2FF",
        "外向性": "#99FF99",
        "宜人性": "#FFCC99",
        "神经质": "#FF99CC",
    }

    # Cumulative average score of each trait at every evaluation step.
    cumulative_averages = {}
    for trait, scores in self.trait_scores_history.items():
        if not scores:
            continue

        averages = []
        total = 0
        valid_count = 0
        for score in scores:
            if score > 0:  # only scores > 0 count as valid
                total += score
                valid_count += 1
                if valid_count > 0:
                    averages.append(total / valid_count)
            else:
                # Invalid score: carry the previous valid average forward.
                if averages:
                    averages.append(averages[-1])
                else:
                    continue  # no valid average yet; skip this point

        if averages:  # only traits with at least one valid score are plotted
            cumulative_averages[trait] = averages

    # Plot each trait's cumulative-average series plus a linear trend line.
    for trait, averages in cumulative_averages.items():
        x = range(1, len(averages) + 1)
        plt.plot(x, averages, "o-", label=trait, color=colors.get(trait), linewidth=2, markersize=8)

        # Least-squares degree-1 fit, drawn dashed in the trait's color.
        z = np.polyfit(x, averages, 1)
        p = np.poly1d(z)
        plt.plot(x, p(x), "--", color=colors.get(trait), alpha=0.5)

    plt.title(f"{user_nickname} 的人格特质累计平均分变化趋势", fontsize=14, pad=20)
    plt.xlabel("评估次数", fontsize=12)
    plt.ylabel("累计平均分", fontsize=12)
    plt.grid(True, linestyle="--", alpha=0.7)
    plt.legend(loc="center left", bbox_to_anchor=(1, 0.5))
    plt.ylim(0, 7)
    plt.tight_layout()

    # Save the chart with a timestamped file name.
    os.makedirs("results/plots", exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    plot_file = f"results/plots/personality_trend_{qq_id}_{timestamp}.png"
    plt.savefig(plot_file, dpi=300, bbox_inches="tight")
    plt.close()
|
||||
|
||||
|
||||
def analyze_user_personality(qq_id: str, num_samples: int = 10, context_length: int = 5) -> str:
    """
    Convenience wrapper: evaluate a user's personality from chat history and
    return a human-readable report string.

    Args:
        qq_id (str): The user's QQ account id.
        num_samples (int): Number of chat snippets to analyze.
        context_length (int): Context length of each chat snippet.

    Returns:
        str: A formatted analysis report, or the evaluator's error message.
    """
    evaluator = ChatBasedPersonalityEvaluator()
    result = evaluator.evaluate_user_personality(qq_id, num_samples, context_length)

    if "error" in result:
        return result["error"]

    # Report header.
    output = f"QQ号 {qq_id} ({result['用户昵称']}) 的人格特征分析结果:\n"
    output += "=" * 50 + "\n\n"

    output += "人格特征评分:\n"
    for trait, score in result["人格特征评分"].items():
        if score == 0:
            # A score of 0 means no valid sample scored this trait.
            output += f"{trait}: 数据不足,无法判断 (评估次数: {result['维度评估次数'].get(trait, 0)})\n"
        else:
            output += f"{trait}: {score}/6 (评估次数: {result['维度评估次数'].get(trait, 0)})\n"

        # Append a trend description when the trait has enough score history.
        if trait in result["特质得分历史"] and len(result["特质得分历史"][trait]) > 1:
            scores = [s for s in result["特质得分历史"][trait] if s != 0]  # drop invalid scores
            if len(scores) > 1:  # need at least two valid points for a slope
                trend = np.polyfit(range(len(scores)), scores, 1)[0]
                if abs(trend) < 0.1:
                    trend_desc = "保持稳定"
                elif trend > 0:
                    trend_desc = "呈上升趋势"
                else:
                    trend_desc = "呈下降趋势"
                output += f" 变化趋势: {trend_desc} (斜率: {trend:.2f})\n"

    output += f"\n分析样本数量:{result['样本数量']}\n"
    output += f"结果已保存至:results/personality_result_{qq_id}.json\n"
    output += "变化趋势图已保存至:results/plots/目录\n"

    return output
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Ad-hoc manual smoke test: analyze one hard-coded QQ account and print
    # the resulting report.
    target_qq = "1026294844"
    report = analyze_user_personality(target_qq, num_samples=30, context_length=30)
    print(report)
|
||||
@@ -1,349 +0,0 @@
|
||||
from typing import Dict
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
import sys
|
||||
from datetime import datetime
|
||||
import random
|
||||
from scipy import stats # 添加scipy导入用于t检验
|
||||
|
||||
# Resolve the project root (three levels up from this file) so that absolute
# `src.` imports work when the script is executed directly.
current_dir = Path(__file__).resolve().parent
project_root = current_dir.parent.parent.parent
env_path = project_root / ".env"

# NOTE(review): this duplicates the project-root computation above using
# os.path; presumably kept for the sys.path injection — verify both stay in sync.
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
sys.path.append(root_path)
|
||||
|
||||
from src.plugins.personality.big5_test import BigFiveTest # noqa: E402
|
||||
from src.plugins.personality.renqingziji import PersonalityEvaluator_direct # noqa: E402
|
||||
from src.plugins.personality.questionnaire import FACTOR_DESCRIPTIONS, PERSONALITY_QUESTIONS # noqa: E402
|
||||
|
||||
|
||||
class CombinedPersonalityTest:
    """Runs a questionnaire-based and a scenario-based Big Five assessment
    back to back, then compares and persists the two result sets."""

    def __init__(self):
        # Questionnaire-based test backend.
        self.big5_test = BigFiveTest()
        # Scenario-based (LLM-scored) test backend.
        self.scenario_test = PersonalityEvaluator_direct()
        # Big Five dimension names (Chinese labels) used throughout.
        self.dimensions = ["开放性", "严谨性", "外向性", "宜人性", "神经质"]
|
||||
|
||||
def run_combined_test(self):
    """Run both assessments interactively, then compare and save the results.

    Flow: intro text -> questionnaire (part 1) -> scenario test (part 2) ->
    comparison report -> persistence under results/.
    """
    # Session introduction.
    print("\n=== 人格特征综合评估系统 ===")
    print("\n本测试将通过两种方式评估人格特征:")
    print("1. 传统问卷测评(约40题)")
    print("2. 情景反应测评(15个场景)")
    print("\n两种测评完成后,将对比分析结果的异同。")
    input("\n准备好开始第一部分(问卷测评)了吗?按回车继续...")

    # Part 1: questionnaire on a 6-point Likert scale.
    print("\n=== 第一部分:问卷测评 ===")
    print("本部分采用六级评分,请根据每个描述与您的符合程度进行打分:")
    print("1 = 完全不符合")
    print("2 = 比较不符合")
    print("3 = 有点不符合")
    print("4 = 有点符合")
    print("5 = 比较符合")
    print("6 = 完全符合")
    print("\n重要提示:您可以选择以下两种方式之一来回答问题:")
    print("1. 根据您自身的真实情况来回答")
    print("2. 根据您想要扮演的角色特征来回答")
    print("\n无论选择哪种方式,请保持一致并认真回答每个问题。")
    input("\n按回车开始答题...")

    questionnaire_results = self.run_questionnaire()

    # Flatten the per-factor result dicts to plain factor -> score for comparison.
    questionnaire_scores = {factor: data["得分"] for factor, data in questionnaire_results.items()}

    # Part 2: scenario-based assessment (LLM-scored free-text reactions).
    print("\n=== 第二部分:情景反应测评 ===")
    print("接下来,您将面对一系列具体场景,请描述您在每个场景中可能的反应。")
    print("每个场景都会评估不同的人格维度,共15个场景。")
    print("您可以选择提供自己的真实反应,也可以选择扮演一个您创作的角色来回答。")
    input("\n准备好开始了吗?按回车继续...")

    scenario_results = self.run_scenario_test()

    # Side-by-side comparison plus significance analysis.
    self.compare_and_display_results(questionnaire_scores, scenario_results)

    # Persist both result sets under results/.
    self.save_results(questionnaire_scores, scenario_results)
|
||||
|
||||
def run_questionnaire(self):
    """Administer the questionnaire interactively.

    Presents all questions in random order, validates each answer to the
    1-6 range, and returns the per-factor score dict produced by
    calculate_questionnaire_scores.
    """
    # Map question id -> question so ids can be shuffled independently.
    questions_map = {q["id"]: q for q in PERSONALITY_QUESTIONS}

    # Randomize presentation order to reduce order effects.
    question_ids = list(questions_map.keys())
    random.shuffle(question_ids)

    answers = {}
    total_questions = len(question_ids)

    for i, question_id in enumerate(question_ids, 1):
        question = questions_map[question_id]
        # Re-prompt until a valid integer in [1, 6] is entered.
        while True:
            try:
                print(f"\n问题 [{i}/{total_questions}]")
                print(f"{question['content']}")
                score = int(input("您的评分(1-6): "))
                if 1 <= score <= 6:
                    answers[question_id] = score
                    break
                else:
                    print("请输入1-6之间的数字!")
            except ValueError:
                print("请输入有效的数字!")

        # Progress indicator every 10 questions.
        if i % 10 == 0:
            print(f"\n已完成 {i}/{total_questions} 题 ({int(i / total_questions * 100)}%)")

    return self.calculate_questionnaire_scores(answers)
|
||||
|
||||
def calculate_questionnaire_scores(self, answers):
    """Aggregate raw questionnaire answers into per-factor scores.

    Args:
        answers: Mapping of question id -> raw 1-6 rating.

    Returns:
        Dict mapping each Big Five factor to a dict with the average score
        ("得分", rounded to 2 decimals), the question count ("题目数"), and
        the raw total ("总分"). Reverse-keyed items score as 7 - raw.
    """
    # Group the questions of the inventory by their factor.
    by_factor = {"外向性": [], "神经质": [], "严谨性": [], "开放性": [], "宜人性": []}
    for question in PERSONALITY_QUESTIONS:
        by_factor[question["factor"]].append(question)

    results = {}
    for factor, questions in by_factor.items():
        # Sum the (possibly reverse-scored) answers for this factor.
        total = sum(
            (7 - answers[q["id"]]) if q["reverse_scoring"] else answers[q["id"]]
            for q in questions
        )
        results[factor] = {
            "得分": round(total / len(questions), 2),
            "题目数": len(questions),
            "总分": total,
        }

    return results
|
||||
|
||||
def run_scenario_test(self):
    """Administer the scenario test interactively.

    Presents the evaluator's scenarios in random order, has the LLM score
    each free-text reaction, and returns the per-dimension average scores.
    """
    final_scores = {"开放性": 0, "严谨性": 0, "外向性": 0, "宜人性": 0, "神经质": 0}
    dimension_counts = {trait: 0 for trait in final_scores.keys()}

    # Shuffle a copy so the evaluator's own scenario list is untouched.
    scenarios = self.scenario_test.scenarios.copy()
    random.shuffle(scenarios)

    for i, scenario_data in enumerate(scenarios, 1):
        print(f"\n场景 [{i}/{len(scenarios)}] - {scenario_data['场景编号']}")
        print("-" * 50)
        print(scenario_data["场景"])
        print("\n请描述您在这种情况下会如何反应:")
        response = input().strip()

        # Empty reactions cannot be scored; skip the scenario.
        if not response:
            print("反应描述不能为空!")
            continue

        print("\n正在评估您的描述...")
        scores = self.scenario_test.evaluate_response(scenario_data["场景"], response, scenario_data["评估维度"])

        # Accumulate raw scores and per-dimension counts for averaging later.
        for dimension, score in scores.items():
            final_scores[dimension] += score
            dimension_counts[dimension] += 1

        # Overall progress every 5 scenarios.
        if i % 5 == 0:
            print(f"\n已完成 {i}/{len(scenarios)} 个场景 ({int(i / len(scenarios) * 100)}%)")

        if i < len(scenarios):
            input("\n按回车继续下一个场景...")

    # Convert accumulated totals into per-dimension averages.
    for dimension in final_scores:
        if dimension_counts[dimension] > 0:
            final_scores[dimension] = round(final_scores[dimension] / dimension_counts[dimension], 2)

    return final_scores
|
||||
|
||||
def compare_and_display_results(self, questionnaire_scores: Dict, scenario_scores: Dict):
    """Print a side-by-side comparison of the two assessments.

    Shows per-dimension scores and differences, overall statistics (mean
    difference, Cohen's d, paired t-test), the factor descriptions, and a
    breakdown of dimensions whose results differ by >= 1.0.

    Args:
        questionnaire_scores: Factor -> average questionnaire score.
        scenario_scores: Factor -> average scenario score.
    """
    print("\n=== 测评结果对比分析 ===")
    print("\n" + "=" * 60)
    print(f"{'维度':<8} {'问卷得分':>10} {'情景得分':>10} {'差异':>10} {'差异程度':>10}")
    print("-" * 60)

    # Collect per-dimension values for the statistics below.
    questionnaire_values = []
    scenario_values = []
    diffs = []

    for dimension in self.dimensions:
        q_score = questionnaire_scores[dimension]
        s_score = scenario_scores[dimension]
        diff = round(abs(q_score - s_score), 2)

        questionnaire_values.append(q_score)
        scenario_values.append(s_score)
        diffs.append(diff)

        # Qualitative difference level: <0.5 low, <1.0 medium, else high.
        diff_level = "低" if diff < 0.5 else "中" if diff < 1.0 else "高"
        print(f"{dimension:<8} {q_score:>10.2f} {s_score:>10.2f} {diff:>10.2f} {diff_level:>10}")

    print("=" * 60)

    # Overall statistics on the absolute differences.
    mean_diff = sum(diffs) / len(diffs)
    std_diff = (sum((x - mean_diff) ** 2 for x in diffs) / (len(diffs) - 1)) ** 0.5

    # Pooled standard deviation for the effect size (Cohen's d).
    pooled_std = (
        (
            sum((x - sum(questionnaire_values) / len(questionnaire_values)) ** 2 for x in questionnaire_values)
            + sum((x - sum(scenario_values) / len(scenario_values)) ** 2 for x in scenario_values)
        )
        / (2 * len(self.dimensions) - 2)
    ) ** 0.5

    # Bug fix: cohens_d / effect_size were previously assigned only inside
    # the `pooled_std != 0` branch, so the prints below raised NameError when
    # every score was identical. Default to "no measurable effect".
    cohens_d = 0.0
    effect_size = "微小"
    if pooled_std != 0:
        cohens_d = abs(mean_diff / pooled_std)

        # Conventional Cohen's d interpretation thresholds.
        if cohens_d < 0.2:
            effect_size = "微小"
        elif cohens_d < 0.5:
            effect_size = "小"
        elif cohens_d < 0.8:
            effect_size = "中等"
        else:
            effect_size = "大"

    # Paired t-test across the five dimensions.
    t_stat, p_value = stats.ttest_rel(questionnaire_values, scenario_values)
    print("\n整体统计分析:")
    print(f"平均差异: {mean_diff:.3f}")
    print(f"差异标准差: {std_diff:.3f}")
    print(f"效应量(Cohen's d): {cohens_d:.3f}")
    print(f"效应量大小: {effect_size}")
    print(f"t统计量: {t_stat:.3f}")
    print(f"p值: {p_value:.3f}")

    if p_value < 0.05:
        print("结论: 两种测评方法的结果存在显著差异 (p < 0.05)")
    else:
        print("结论: 两种测评方法的结果无显著差异 (p >= 0.05)")

    # Print the definition and trait words of each dimension.
    print("\n维度说明:")
    for dimension in self.dimensions:
        print(f"\n{dimension}:")
        desc = FACTOR_DESCRIPTIONS[dimension]
        print(f"定义:{desc['description']}")
        print(f"特征词:{', '.join(desc['trait_words'])}")

    # Flag dimensions with a difference of at least one point as significant.
    significant_diffs = []
    for dimension in self.dimensions:
        diff = abs(questionnaire_scores[dimension] - scenario_scores[dimension])
        if diff >= 1.0:
            significant_diffs.append(
                {
                    "dimension": dimension,
                    "diff": diff,
                    "questionnaire": questionnaire_scores[dimension],
                    "scenario": scenario_scores[dimension],
                }
            )

    if significant_diffs:
        print("\n\n显著差异分析:")
        print("-" * 40)
        for diff in significant_diffs:
            print(f"\n{diff['dimension']}维度的测评结果存在显著差异:")
            print(f"问卷得分:{diff['questionnaire']:.2f}")
            print(f"情景得分:{diff['scenario']:.2f}")
            print(f"差异值:{diff['diff']:.2f}")

            # Suggest a likely explanation based on which score is higher.
            if diff["questionnaire"] > diff["scenario"]:
                print("可能原因:在问卷中的自我评价较高,但在具体情景中的表现较为保守。")
            else:
                print("可能原因:在具体情景中表现出更多该维度特征,而在问卷自评时较为保守。")
|
||||
|
||||
def save_results(self, questionnaire_scores: Dict, scenario_scores: Dict):
    """Persist both result sets as pretty-printed UTF-8 JSON under results/.

    Args:
        questionnaire_scores: Factor -> average questionnaire score.
        scenario_scores: Factor -> average scenario score.

    Side effects:
        Creates the results/ directory if needed and writes a timestamped
        personality_combined_*.json file; prints the saved path.
    """
    results = {
        "测试时间": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "问卷测评结果": questionnaire_scores,
        "情景测评结果": scenario_scores,
        "维度说明": FACTOR_DESCRIPTIONS,
    }

    # Make sure the output directory exists.
    os.makedirs("results", exist_ok=True)

    # Timestamped file name so repeated runs never overwrite each other.
    filename = f"results/personality_combined_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"

    with open(filename, "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)

    # Bug fix: the message previously printed a broken literal instead of
    # interpolating the actual path the results were written to.
    print(f"\n完整的测评结果已保存到:{filename}")
|
||||
|
||||
|
||||
def load_existing_results():
    """Load the most recent saved combined-test result, if any.

    Scans results/ for personality_combined_*.json files and returns the
    parsed JSON of the newest one (by modification time), or None when no
    such file exists or it cannot be read.
    """
    results_dir = "results"
    if not os.path.exists(results_dir):
        return None

    # All saved combined-test result files.
    candidates = [
        name
        for name in os.listdir(results_dir)
        if name.startswith("personality_combined_") and name.endswith(".json")
    ]
    if not candidates:
        return None

    # The newest file by modification time wins.
    latest_file = max(candidates, key=lambda name: os.path.getmtime(os.path.join(results_dir, name)))

    print(f"\n发现已有的测试结果:{latest_file}")
    try:
        with open(os.path.join(results_dir, latest_file), "r", encoding="utf-8") as fp:
            return json.load(fp)
    except Exception as err:  # best effort: fall back to running a fresh test
        print(f"读取结果文件时出错:{str(err)}")
        return None
|
||||
|
||||
|
||||
def main():
    """Entry point: reuse a previously saved result when one exists,
    otherwise run the full combined test interactively."""
    test = CombinedPersonalityTest()

    existing = load_existing_results()

    # Guard clause: no usable saved result -> run a fresh interactive test.
    if not existing:
        print("\n未找到已有的测试结果,开始新的测试...")
        test.run_combined_test()
        return

    # Reuse the saved scores and jump straight to the comparison report.
    print("\n=== 使用已有测试结果进行分析 ===")
    print(f"测试时间:{existing['测试时间']}")
    test.compare_and_display_results(existing["问卷测评结果"], existing["情景测评结果"])
|
||||
|
||||
|
||||
# Standard script guard: run the combined test only when executed directly.
if __name__ == "__main__":
    main()
|
||||
@@ -1,123 +0,0 @@
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from typing import Tuple, Union
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
from src.common.logger import get_module_logger
|
||||
|
||||
logger = get_module_logger("offline_llm")
|
||||
|
||||
|
||||
class LLMModel:
    """Minimal sync/async client for the SiliconFlow chat-completions API."""

    def __init__(self, model_name="Pro/deepseek-ai/DeepSeek-V3", **kwargs):
        # Model id sent with every request; extra kwargs are forwarded
        # verbatim into the request body (e.g. max_tokens).
        self.model_name = model_name
        self.params = kwargs
        # Credentials and endpoint come from the environment; fail fast when absent.
        self.api_key = os.getenv("SILICONFLOW_KEY")
        self.base_url = os.getenv("SILICONFLOW_BASE_URL")

        if not self.api_key or not self.base_url:
            raise ValueError("环境变量未正确加载:SILICONFLOW_KEY 或 SILICONFLOW_BASE_URL 未设置")

        logger.info(f"API URL: {self.base_url}")  # record the configured base URL for debugging
|
||||
|
||||
def generate_response(self, prompt: str) -> Tuple[str, str]:
    """Synchronously send *prompt* to the chat-completions endpoint.

    Retries up to 3 times with exponential backoff on HTTP 429 or request
    errors. Every path returns a 2-tuple (content, reasoning_content) —
    the previous ``Union[str, Tuple[str, str]]`` annotation was wrong —
    where both elements degrade to an error message and "" on failure.
    """
    headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}

    # Request body; self.params lets callers override/extend fields.
    data = {
        "model": self.model_name,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0.5,
        **self.params,
    }

    # Full chat/completions endpoint URL.
    api_url = f"{self.base_url.rstrip('/')}/chat/completions"
    logger.info(f"Request URL: {api_url}")

    max_retries = 3
    base_wait_time = 15  # base backoff in seconds

    for retry in range(max_retries):
        try:
            # Bug fix: an explicit timeout — requests.post with no timeout
            # can block forever if the server never responds.
            response = requests.post(api_url, headers=headers, json=data, timeout=300)

            if response.status_code == 429:
                wait_time = base_wait_time * (2**retry)  # exponential backoff
                logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...")
                time.sleep(wait_time)
                continue

            response.raise_for_status()  # raise on any other HTTP error

            result = response.json()
            if "choices" in result and len(result["choices"]) > 0:
                content = result["choices"][0]["message"]["content"]
                reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
                return content, reasoning_content
            return "没有返回结果", ""

        except Exception as e:
            if retry < max_retries - 1:  # still have retries left
                wait_time = base_wait_time * (2**retry)
                logger.error(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
                time.sleep(wait_time)
            else:
                logger.error(f"请求失败: {str(e)}")
                return f"请求失败: {str(e)}", ""

    logger.error("达到最大重试次数,请求仍然失败")
    return "达到最大重试次数,请求仍然失败", ""
|
||||
|
||||
async def generate_response_async(self, prompt: str) -> Tuple[str, str]:
    """Async variant of generate_response using aiohttp.

    Retries up to 3 times with exponential backoff on HTTP 429 or request
    errors. Every path returns a 2-tuple (content, reasoning_content) —
    the previous ``Union[str, Tuple[str, str]]`` annotation was wrong —
    where both elements degrade to an error message and "" on failure.
    """
    headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}

    # Request body; self.params lets callers override/extend fields.
    data = {
        "model": self.model_name,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0.5,
        **self.params,
    }

    # Full chat/completions endpoint URL.
    api_url = f"{self.base_url.rstrip('/')}/chat/completions"
    logger.info(f"Request URL: {api_url}")

    max_retries = 3
    base_wait_time = 15  # base backoff in seconds

    async with aiohttp.ClientSession() as session:
        for retry in range(max_retries):
            try:
                async with session.post(api_url, headers=headers, json=data) as response:
                    if response.status == 429:
                        wait_time = base_wait_time * (2**retry)  # exponential backoff
                        logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...")
                        await asyncio.sleep(wait_time)
                        continue

                    response.raise_for_status()  # raise on any other HTTP error

                    result = await response.json()
                    if "choices" in result and len(result["choices"]) > 0:
                        content = result["choices"][0]["message"]["content"]
                        reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
                        return content, reasoning_content
                    return "没有返回结果", ""

            except Exception as e:
                if retry < max_retries - 1:  # still have retries left
                    wait_time = base_wait_time * (2**retry)
                    logger.error(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
                    await asyncio.sleep(wait_time)
                else:
                    logger.error(f"请求失败: {str(e)}")
                    return f"请求失败: {str(e)}", ""

    logger.error("达到最大重试次数,请求仍然失败")
    return "达到最大重试次数,请求仍然失败", ""
|
||||
@@ -1,142 +0,0 @@
|
||||
# 人格测试问卷题目
|
||||
# 王孟成, 戴晓阳, & 姚树桥. (2011).
|
||||
# 中国大五人格问卷的初步编制Ⅲ:简式版的制定及信效度检验. 中国临床心理学杂志, 19(04), Article 04.
|
||||
|
||||
# 王孟成, 戴晓阳, & 姚树桥. (2010).
|
||||
# 中国大五人格问卷的初步编制Ⅰ:理论框架与信度分析. 中国临床心理学杂志, 18(05), Article 05.
|
||||
|
||||
PERSONALITY_QUESTIONS = [
|
||||
# 神经质维度 (F1)
|
||||
{"id": 1, "content": "我常担心有什么不好的事情要发生", "factor": "神经质", "reverse_scoring": False},
|
||||
{"id": 2, "content": "我常感到害怕", "factor": "神经质", "reverse_scoring": False},
|
||||
{"id": 3, "content": "有时我觉得自己一无是处", "factor": "神经质", "reverse_scoring": False},
|
||||
{"id": 4, "content": "我很少感到忧郁或沮丧", "factor": "神经质", "reverse_scoring": True},
|
||||
{"id": 5, "content": "别人一句漫不经心的话,我常会联系在自己身上", "factor": "神经质", "reverse_scoring": False},
|
||||
{"id": 6, "content": "在面对压力时,我有种快要崩溃的感觉", "factor": "神经质", "reverse_scoring": False},
|
||||
{"id": 7, "content": "我常担忧一些无关紧要的事情", "factor": "神经质", "reverse_scoring": False},
|
||||
{"id": 8, "content": "我常常感到内心不踏实", "factor": "神经质", "reverse_scoring": False},
|
||||
# 严谨性维度 (F2)
|
||||
{"id": 9, "content": "在工作上,我常只求能应付过去便可", "factor": "严谨性", "reverse_scoring": True},
|
||||
{"id": 10, "content": "一旦确定了目标,我会坚持努力地实现它", "factor": "严谨性", "reverse_scoring": False},
|
||||
{"id": 11, "content": "我常常是仔细考虑之后才做出决定", "factor": "严谨性", "reverse_scoring": False},
|
||||
{"id": 12, "content": "别人认为我是个慎重的人", "factor": "严谨性", "reverse_scoring": False},
|
||||
{"id": 13, "content": "做事讲究逻辑和条理是我的一个特点", "factor": "严谨性", "reverse_scoring": False},
|
||||
{"id": 14, "content": "我喜欢一开头就把事情计划好", "factor": "严谨性", "reverse_scoring": False},
|
||||
{"id": 15, "content": "我工作或学习很勤奋", "factor": "严谨性", "reverse_scoring": False},
|
||||
{"id": 16, "content": "我是个倾尽全力做事的人", "factor": "严谨性", "reverse_scoring": False},
|
||||
# 宜人性维度 (F3)
|
||||
{
|
||||
"id": 17,
|
||||
"content": "尽管人类社会存在着一些阴暗的东西(如战争、罪恶、欺诈),我仍然相信人性总的来说是善良的",
|
||||
"factor": "宜人性",
|
||||
"reverse_scoring": False,
|
||||
},
|
||||
{"id": 18, "content": "我觉得大部分人基本上是心怀善意的", "factor": "宜人性", "reverse_scoring": False},
|
||||
{"id": 19, "content": "虽然社会上有骗子,但我觉得大部分人还是可信的", "factor": "宜人性", "reverse_scoring": False},
|
||||
{"id": 20, "content": "我不太关心别人是否受到不公正的待遇", "factor": "宜人性", "reverse_scoring": True},
|
||||
{"id": 21, "content": "我时常觉得别人的痛苦与我无关", "factor": "宜人性", "reverse_scoring": True},
|
||||
{"id": 22, "content": "我常为那些遭遇不幸的人感到难过", "factor": "宜人性", "reverse_scoring": False},
|
||||
{"id": 23, "content": "我是那种只照顾好自己,不替别人担忧的人", "factor": "宜人性", "reverse_scoring": True},
|
||||
{"id": 24, "content": "当别人向我诉说不幸时,我常感到难过", "factor": "宜人性", "reverse_scoring": False},
|
||||
# 开放性维度 (F4)
|
||||
{"id": 25, "content": "我的想象力相当丰富", "factor": "开放性", "reverse_scoring": False},
|
||||
{"id": 26, "content": "我头脑中经常充满生动的画面", "factor": "开放性", "reverse_scoring": False},
|
||||
{"id": 27, "content": "我对许多事情有着很强的好奇心", "factor": "开放性", "reverse_scoring": False},
|
||||
{"id": 28, "content": "我喜欢冒险", "factor": "开放性", "reverse_scoring": False},
|
||||
{"id": 29, "content": "我是个勇于冒险,突破常规的人", "factor": "开放性", "reverse_scoring": False},
|
||||
{"id": 30, "content": "我身上具有别人没有的冒险精神", "factor": "开放性", "reverse_scoring": False},
|
||||
{
|
||||
"id": 31,
|
||||
"content": "我渴望学习一些新东西,即使它们与我的日常生活无关",
|
||||
"factor": "开放性",
|
||||
"reverse_scoring": False,
|
||||
},
|
||||
{
|
||||
"id": 32,
|
||||
"content": "我很愿意也很容易接受那些新事物、新观点、新想法",
|
||||
"factor": "开放性",
|
||||
"reverse_scoring": False,
|
||||
},
|
||||
# 外向性维度 (F5)
|
||||
{"id": 33, "content": "我喜欢参加社交与娱乐聚会", "factor": "外向性", "reverse_scoring": False},
|
||||
{"id": 34, "content": "我对人多的聚会感到乏味", "factor": "外向性", "reverse_scoring": True},
|
||||
{"id": 35, "content": "我尽量避免参加人多的聚会和嘈杂的环境", "factor": "外向性", "reverse_scoring": True},
|
||||
{"id": 36, "content": "在热闹的聚会上,我常常表现主动并尽情玩耍", "factor": "外向性", "reverse_scoring": False},
|
||||
{"id": 37, "content": "有我在的场合一般不会冷场", "factor": "外向性", "reverse_scoring": False},
|
||||
{"id": 38, "content": "我希望成为领导者而不是被领导者", "factor": "外向性", "reverse_scoring": False},
|
||||
{"id": 39, "content": "在一个团体中,我希望处于领导地位", "factor": "外向性", "reverse_scoring": False},
|
||||
{"id": 40, "content": "别人多认为我是一个热情和友好的人", "factor": "外向性", "reverse_scoring": False},
|
||||
]
|
||||
|
||||
# 因子维度说明
|
||||
FACTOR_DESCRIPTIONS = {
|
||||
"外向性": {
|
||||
"description": "反映个体神经系统的强弱和动力特征。外向性主要表现为个体在人际交往和社交活动中的倾向性,"
|
||||
"包括对社交活动的兴趣、"
|
||||
"对人群的态度、社交互动中的主动程度以及在群体中的影响力。高分者倾向于积极参与社交活动,乐于与人交往,善于表达自我,"
|
||||
"并往往在群体中发挥领导作用;低分者则倾向于独处,不喜欢热闹的社交场合,表现出内向、安静的特征。",
|
||||
"trait_words": ["热情", "活力", "社交", "主动"],
|
||||
"subfactors": {
|
||||
"合群性": "个体愿意与他人聚在一起,即接近人群的倾向;高分表现乐群、好交际,低分表现封闭、独处",
|
||||
"热情": "个体对待别人时所表现出的态度;高分表现热情好客,低分表现冷淡",
|
||||
"支配性": "个体喜欢指使、操纵他人,倾向于领导别人的特点;高分表现好强、发号施令,低分表现顺从、低调",
|
||||
"活跃": "个体精力充沛,活跃、主动性等特点;高分表现活跃,低分表现安静",
|
||||
},
|
||||
},
|
||||
"神经质": {
|
||||
"description": "反映个体情绪的状态和体验内心苦恼的倾向性。这个维度主要关注个体在面对压力、"
|
||||
"挫折和日常生活挑战时的情绪稳定性和适应能力。它包含了对焦虑、抑郁、愤怒等负面情绪的敏感程度,"
|
||||
"以及个体对这些情绪的调节和控制能力。高分者容易体验负面情绪,对压力较为敏感,情绪波动较大;"
|
||||
"低分者则表现出较强的情绪稳定性,能够较好地应对压力和挫折。",
|
||||
"trait_words": ["稳定", "沉着", "从容", "坚韧"],
|
||||
"subfactors": {
|
||||
"焦虑": "个体体验焦虑感的个体差异;高分表现坐立不安,低分表现平静",
|
||||
"抑郁": "个体体验抑郁情感的个体差异;高分表现郁郁寡欢,低分表现平静",
|
||||
"敏感多疑": "个体常常关注自己的内心活动,行为和过于意识人对自己的看法、评价;高分表现敏感多疑,"
|
||||
"低分表现淡定、自信",
|
||||
"脆弱性": "个体在危机或困难面前无力、脆弱的特点;高分表现无能、易受伤、逃避,低分表现坚强",
|
||||
"愤怒-敌意": "个体准备体验愤怒,及相关情绪的状态;高分表现暴躁易怒,低分表现平静",
|
||||
},
|
||||
},
|
||||
"严谨性": {
|
||||
"description": "反映个体在目标导向行为上的组织、坚持和动机特征。这个维度体现了个体在工作、"
|
||||
"学习等目标性活动中的自我约束和行为管理能力。它涉及到个体的责任感、自律性、计划性、条理性以及完成任务的态度。"
|
||||
"高分者往往表现出强烈的责任心、良好的组织能力、谨慎的决策风格和持续的努力精神;低分者则可能表现出随意性强、"
|
||||
"缺乏规划、做事马虎或易放弃的特点。",
|
||||
"trait_words": ["负责", "自律", "条理", "勤奋"],
|
||||
"subfactors": {
|
||||
"责任心": "个体对待任务和他人认真负责,以及对自己承诺的信守;高分表现有责任心、负责任,"
|
||||
"低分表现推卸责任、逃避处罚",
|
||||
"自我控制": "个体约束自己的能力,及自始至终的坚持性;高分表现自制、有毅力,低分表现冲动、无毅力",
|
||||
"审慎性": "个体在采取具体行动前的心理状态;高分表现谨慎、小心,低分表现鲁莽、草率",
|
||||
"条理性": "个体处理事务和工作的秩序,条理和逻辑性;高分表现整洁、有秩序,低分表现混乱、遗漏",
|
||||
"勤奋": "个体工作和学习的努力程度及为达到目标而表现出的进取精神;高分表现勤奋、刻苦,低分表现懒散",
|
||||
},
|
||||
},
|
||||
"开放性": {
|
||||
"description": "反映个体对新异事物、新观念和新经验的接受程度,以及在思维和行为方面的创新倾向。"
|
||||
"这个维度体现了个体在认知和体验方面的广度、深度和灵活性。它包括对艺术的欣赏能力、对知识的求知欲、想象力的丰富程度,"
|
||||
"以及对冒险和创新的态度。高分者往往具有丰富的想象力、广泛的兴趣、开放的思维方式和创新的倾向;低分者则倾向于保守、"
|
||||
"传统,喜欢熟悉和常规的事物。",
|
||||
"trait_words": ["创新", "好奇", "艺术", "冒险"],
|
||||
"subfactors": {
|
||||
"幻想": "个体富于幻想和想象的水平;高分表现想象力丰富,低分表现想象力匮乏",
|
||||
"审美": "个体对于艺术和美的敏感与热爱程度;高分表现富有艺术气息,低分表现一般对艺术不敏感",
|
||||
"好奇心": "个体对未知事物的态度;高分表现兴趣广泛、好奇心浓,低分表现兴趣少、无好奇心",
|
||||
"冒险精神": "个体愿意尝试有风险活动的个体差异;高分表现好冒险,低分表现保守",
|
||||
"价值观念": "个体对新事物、新观念、怪异想法的态度;高分表现开放、坦然接受新事物,低分则相反",
|
||||
},
|
||||
},
|
||||
"宜人性": {
|
||||
"description": "反映个体在人际关系中的亲和倾向,体现了对他人的关心、同情和合作意愿。"
|
||||
"这个维度主要关注个体与他人互动时的态度和行为特征,包括对他人的信任程度、同理心水平、"
|
||||
"助人意愿以及在人际冲突中的处理方式。高分者通常表现出友善、富有同情心、乐于助人的特质,善于与他人建立和谐关系;"
|
||||
"低分者则可能表现出较少的人际关注,在社交互动中更注重自身利益,较少考虑他人感受。",
|
||||
"trait_words": ["友善", "同理", "信任", "合作"],
|
||||
"subfactors": {
|
||||
"信任": "个体对他人和/或他人言论的相信程度;高分表现信任他人,低分表现怀疑",
|
||||
"体贴": "个体对别人的兴趣和需要的关注程度;高分表现体贴、温存,低分表现冷漠、不在乎",
|
||||
"同情": "个体对处于不利地位的人或物的态度;高分表现富有同情心,低分表现冷漠",
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -1,195 +0,0 @@
|
||||
"""
The definition of artificial personality in this paper follows the dispositional paradigm and adapts a definition of
personality developed for humans [17]:
Personality for a human is the "whole and organisation of relatively stable tendencies and patterns of experience and
behaviour within one person (distinguishing it from other persons)". This definition is modified for artificial
personality:
Artificial personality describes the relatively stable tendencies and patterns of behaviour of an AI-based machine that
can be designed by developers and designers via different modalities, such as language, creating the impression
of individuality of a humanized social agent when users interact with the machine."""
|
||||
|
||||
from typing import Dict, List
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from dotenv import load_dotenv
|
||||
import sys
|
||||
|
||||
"""
|
||||
第一种方案:基于情景评估的人格测定
|
||||
"""
|
||||
# Resolve the project root (three levels up) so `src.` imports resolve when
# this script runs directly, and locate the project's .env file.
current_dir = Path(__file__).resolve().parent
project_root = current_dir.parent.parent.parent
env_path = project_root / ".env"

# NOTE(review): duplicates the project-root computation via os.path for the
# sys.path injection — verify the two stay in sync.
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
sys.path.append(root_path)

# These imports must follow the sys.path manipulation above (hence noqa: E402).
from src.plugins.personality.scene import get_scene_by_factor, PERSONALITY_SCENES  # noqa: E402
from src.plugins.personality.questionnaire import FACTOR_DESCRIPTIONS  # noqa: E402
from src.plugins.personality.offline_llm import LLMModel  # noqa: E402

# Load environment variables (API keys etc.) from the project .env if present.
if env_path.exists():
    print(f"从 {env_path} 加载环境变量")
    load_dotenv(env_path)
else:
    print(f"未找到环境变量文件: {env_path}")
    print("将使用默认配置")
|
||||
|
||||
|
||||
class PersonalityEvaluatorDirect:
    """Scenario-based Big Five evaluator: picks scenes per trait and scores
    free-text reactions with an LLM."""

    def __init__(self):
        # Idiom fix: `import random` previously sat inside the per-trait loop
        # body; hoisted here (ideally it would live at module level).
        import random

        # Accumulators for the five Big Five traits (Chinese labels).
        self.personality_traits = {"开放性": 0, "严谨性": 0, "外向性": 0, "宜人性": 0, "神经质": 0}
        self.scenarios = []

        # Collect scenarios for every personality trait.
        for trait in PERSONALITY_SCENES:
            scenes = get_scene_by_factor(trait)
            if not scenes:
                continue

            # Sample up to 3 scenes from this trait's pool.
            scene_keys = list(scenes.keys())
            selected_scenes = random.sample(scene_keys, min(3, len(scene_keys)))

            for scene_key in selected_scenes:
                scene = scenes[scene_key]

                # Each scene is scored on two dimensions: the primary trait
                # plus one randomly chosen secondary trait.
                other_traits = [t for t in PERSONALITY_SCENES if t != trait]
                secondary_trait = random.choice(other_traits)

                self.scenarios.append(
                    {"场景": scene["scenario"], "评估维度": [trait, secondary_trait], "场景编号": scene_key}
                )

        # LLM backend used by evaluate_response.
        self.llm = LLMModel()
|
||||
|
||||
def evaluate_response(self, scenario: str, response: str, dimensions: List[str]) -> Dict[str, float]:
    """Score a user's free-text reaction to *scenario* with the LLM.

    Args:
        scenario: The scenario text shown to the user.
        response: The user's described reaction.
        dimensions: Big Five dimension names to score (any count >= 1; the
            previous version hard-coded exactly dimensions[0]/dimensions[1]
            and raised IndexError otherwise).

    Returns:
        Mapping of dimension -> score clamped to [1, 6]; falls back to 3.5
        for every requested dimension when the LLM output cannot be parsed.
    """
    # Human-readable description of each requested dimension for the prompt.
    dimension_descriptions = []
    for dim in dimensions:
        desc = FACTOR_DESCRIPTIONS.get(dim, "")
        if desc:
            dimension_descriptions.append(f"- {dim}:{desc}")

    dimensions_text = "\n".join(dimension_descriptions)

    # Generalization: one JSON line per requested dimension instead of the
    # former hard-coded two-dimension template.
    score_lines = ",\n".join(f'"{dim}": 分数' for dim in dimensions)

    prompt = f"""请根据以下场景和用户描述,评估用户在大五人格模型中的相关维度得分(1-6分)。

场景描述:
{scenario}

用户回应:
{response}

需要评估的维度说明:
{dimensions_text}

请按照以下格式输出评估结果(仅输出JSON格式):
{{
{score_lines}
}}

评分标准:
1 = 非常不符合该维度特征
2 = 比较不符合该维度特征
3 = 有点不符合该维度特征
4 = 有点符合该维度特征
5 = 比较符合该维度特征
6 = 非常符合该维度特征

请根据用户的回应,结合场景和维度说明进行评分。确保分数在1-6之间,并给出合理的评估。"""

    try:
        ai_response, _ = self.llm.generate_response(prompt)

        # Extract the JSON object from the (possibly chatty) LLM reply.
        start_idx = ai_response.find("{")
        end_idx = ai_response.rfind("}") + 1
        if start_idx != -1 and end_idx != 0:
            json_str = ai_response[start_idx:end_idx]
            scores = json.loads(json_str)
            # Clamp every score into the valid 1-6 range.
            return {k: max(1, min(6, float(v))) for k, v in scores.items()}
        else:
            print("AI响应格式不正确,使用默认评分")
            return {dim: 3.5 for dim in dimensions}
    except Exception as e:
        # Parsing/transport errors degrade to a neutral default score.
        print(f"评估过程出错:{str(e)}")
        return {dim: 3.5 for dim in dimensions}
|
||||
|
||||
|
||||
def main():
    """Run the interactive Big-Five personality questionnaire.

    Presents every scenario from PersonalityEvaluatorDirect, asks the user
    to describe their character's reaction, scores each answer with the LLM,
    accumulates per-dimension averages and saves the final result to
    results/personality_result.json.
    """
    print("欢迎使用人格形象创建程序!")
    print("接下来,您将面对一系列场景(共15个)。请根据您想要创建的角色形象,描述在该场景下可能的反应。")
    print("每个场景都会评估不同的人格维度,最终得出完整的人格特征评估。")
    print("评分标准:1=非常不符合,2=比较不符合,3=有点不符合,4=有点符合,5=比较符合,6=非常符合")
    print("\n准备好了吗?按回车键开始...")
    input()

    evaluator = PersonalityEvaluatorDirect()
    # Accumulated raw scores and how many scenarios contributed to each trait.
    final_scores = {"开放性": 0, "严谨性": 0, "外向性": 0, "宜人性": 0, "神经质": 0}
    dimension_counts = {trait: 0 for trait in final_scores.keys()}

    for i, scenario_data in enumerate(evaluator.scenarios, 1):
        print(f"\n场景 {i}/{len(evaluator.scenarios)} - {scenario_data['场景编号']}:")
        print("-" * 50)
        print(scenario_data["场景"])
        print("\n请描述您的角色在这种情况下会如何反应:")
        response = input().strip()

        if not response:
            # NOTE(review): an empty answer skips the scenario entirely
            # instead of re-prompting; kept as-is to preserve behavior.
            print("反应描述不能为空!")
            continue

        print("\n正在评估您的描述...")
        scores = evaluator.evaluate_response(scenario_data["场景"], response, scenario_data["评估维度"])

        # Update running totals. Bug fix: guard against dimension names the
        # LLM invented — the old code raised KeyError on any key outside the
        # five known traits.
        for dimension, score in scores.items():
            if dimension in final_scores:
                final_scores[dimension] += score
                dimension_counts[dimension] += 1

        print("\n当前评估结果:")
        print("-" * 30)
        for dimension, score in scores.items():
            print(f"{dimension}: {score}/6")

        if i < len(evaluator.scenarios):
            print("\n按回车键继续下一个场景...")
            input()

    # Convert accumulated totals into per-dimension averages.
    for dimension in final_scores:
        if dimension_counts[dimension] > 0:
            final_scores[dimension] = round(final_scores[dimension] / dimension_counts[dimension], 2)

    print("\n最终人格特征评估结果:")
    print("-" * 30)
    for trait, score in final_scores.items():
        print(f"{trait}: {score}/6")
        print(f"测试场景数:{dimension_counts[trait]}")

    # Persist the raw result for later inspection.
    result = {"final_scores": final_scores, "dimension_counts": dimension_counts, "scenarios": evaluator.scenarios}

    os.makedirs("results", exist_ok=True)
    with open("results/personality_result.json", "w", encoding="utf-8") as f:
        json.dump(result, f, ensure_ascii=False, indent=2)

    print("\n结果已保存到 results/personality_result.json")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,156 +0,0 @@
|
||||
import random
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import datetime
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
current_dir = Path(__file__).resolve().parent
|
||||
project_root = current_dir.parent.parent.parent
|
||||
env_path = project_root / ".env"
|
||||
|
||||
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
|
||||
sys.path.append(root_path)
|
||||
|
||||
from src.common.database import db # noqa: E402
|
||||
|
||||
|
||||
class MessageAnalyzer:
    """Helper that pulls chat messages from MongoDB and renders them as
    human-readable context windows around a given message."""

    def __init__(self):
        # Collection holding every stored chat message.
        self.messages_collection = db["messages"]

    def get_message_context(self, message_id: int, context_length: int = 5) -> Optional[List[Dict]]:
        """Return the messages surrounding *message_id* in its chat stream.

        Args:
            message_id: ID of the anchor message.
            context_length: Messages to keep on each side (window is at most
                2 * context_length + 1 long).

        Returns:
            Time-ordered slice of surrounding messages, or None when the
            message, its stream id, or its position cannot be resolved.
        """
        anchor = self.messages_collection.find_one({"message_id": message_id})
        if not anchor:
            return None

        stream_id = anchor.get("chat_info", {}).get("stream_id")
        if not stream_id:
            return None

        # Every message of the same stream, oldest first.
        stream_messages = list(self.messages_collection.find({"chat_info.stream_id": stream_id}).sort("time", 1))

        # Locate the anchor inside the stream.
        anchor_index = next(
            (i for i, msg in enumerate(stream_messages) if msg["message_id"] == message_id),
            None,
        )
        if anchor_index is None:
            return None

        lo = max(0, anchor_index - context_length)
        hi = min(len(stream_messages), anchor_index + context_length + 1)
        return stream_messages[lo:hi]

    def format_messages(self, messages: List[Dict], target_message_id: Optional[int] = None) -> str:
        """Render *messages* as one readable block, marking the target message.

        Args:
            messages: Messages to render.
            target_message_id: Optional ID to highlight with an arrow marker
                and a separator line.

        Returns:
            The formatted multi-line string.
        """
        if not messages:
            return "没有消息记录"

        lines = []
        for msg in messages:
            timestamp = datetime.datetime.fromtimestamp(int(msg["time"])).strftime("%Y-%m-%d %H:%M:%S")
            # Prefer the processed text, fall back to the detailed dump.
            text = msg.get("processed_plain_text", msg.get("detailed_plain_text", "无消息内容"))
            nickname = msg.get("user_info", {}).get("user_nickname", "未知用户")

            hit = bool(target_message_id) and msg["message_id"] == target_message_id
            marker = "→ " if hit else "  "
            lines.append(f"{marker}[{timestamp}] {nickname}: {text}\n")
            if hit:
                lines.append("  " + "-" * 50 + "\n")

        return "".join(lines)

    def get_user_random_contexts(
        self, qq_id: str, num_messages: int = 10, context_length: int = 5
    ) -> tuple[List[str], str]:  # noqa: E501
        """Sample random messages from one user and return their contexts.

        Args:
            qq_id: QQ account id as a string.
            num_messages: How many random messages to sample.
            context_length: Context window radius per sampled message.

        Returns:
            (formatted context strings, user nickname); ([], "") when the
            id is empty or the user has no stored messages.
        """
        if not qq_id:
            return [], ""

        all_messages = list(self.messages_collection.find({"user_info.user_id": int(qq_id)}))
        if not all_messages:
            return [], ""

        # NOTE(review): nickname is read from chat_info.user_info of the
        # first stored message — presumably the same user; verify schema.
        user_nickname = all_messages[0].get("chat_info", {}).get("user_info", {}).get("user_nickname", "未知用户")

        # Sample without replacement, then present in chronological order.
        sampled = random.sample(all_messages, min(num_messages, len(all_messages)))
        sampled.sort(key=lambda m: int(m["time"]))

        contexts = []
        for msg in sampled:
            window = self.get_message_context(msg["message_id"], context_length)
            if window:
                contexts.append(self.format_messages(window, msg["message_id"]))

        return contexts, user_nickname
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: dump a few random context windows for one account.
    analyzer = MessageAnalyzer()
    test_qq = "1026294844"  # replace with the QQ id you want to inspect
    print(f"测试QQ号: {test_qq}")
    print("-" * 50)
    # 5 random messages, each with 3 messages of context on both sides.
    contexts, nickname = analyzer.get_user_random_contexts(test_qq, num_messages=5, context_length=3)

    print(f"用户昵称: {nickname}\n")
    for i, context in enumerate(contexts, 1):
        print(f"\n随机消息 {i}/{len(contexts)}:")
        print("-" * 30)
        print(context)
        print("=" * 50)
|
||||
@@ -1 +0,0 @@
|
||||
那是以后会用到的妙妙小工具.jpg
|
||||
@@ -6,7 +6,7 @@ import os
|
||||
import json
|
||||
import threading
|
||||
from src.common.logger import get_module_logger
|
||||
from src.plugins.config.config import global_config
|
||||
from src.config.config import global_config
|
||||
|
||||
logger = get_module_logger("remote")
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from src.plugins.config.config import global_config
|
||||
from src.config.config import global_config
|
||||
from src.plugins.chat.message import MessageRecv, MessageSending, Message
|
||||
from src.common.database import db
|
||||
import time
|
||||
|
||||
@@ -12,7 +12,7 @@ sys.path.append(root_path)
|
||||
from src.common.database import db # noqa: E402
|
||||
from src.common.logger import get_module_logger, SCHEDULE_STYLE_CONFIG, LogConfig # noqa: E402
|
||||
from src.plugins.models.utils_model import LLMRequest # noqa: E402
|
||||
from src.plugins.config.config import global_config # noqa: E402
|
||||
from src.config.config import global_config # noqa: E402
|
||||
|
||||
TIME_ZONE = tz.gettz(global_config.TIME_ZONE) # 设置时区
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import List, Optional
|
||||
|
||||
|
||||
from ..models.utils_model import LLMRequest
|
||||
from ..config.config import global_config
|
||||
from ...config.config import global_config
|
||||
from src.common.logger import get_module_logger, LogConfig, TOPIC_STYLE_CONFIG
|
||||
|
||||
# 定义日志配置
|
||||
|
||||
@@ -74,9 +74,3 @@ class ClassicalWillingManager(BaseWillingManager):
|
||||
|
||||
async def not_reply_handle(self, message_id):
|
||||
return await super().not_reply_handle(message_id)
|
||||
|
||||
async def get_variable_parameters(self):
|
||||
return await super().get_variable_parameters()
|
||||
|
||||
async def set_variable_parameters(self, parameters):
|
||||
return await super().set_variable_parameters(parameters)
|
||||
|
||||
@@ -234,9 +234,3 @@ class DynamicWillingManager(BaseWillingManager):
|
||||
|
||||
async def after_generate_reply_handle(self, message_id):
|
||||
return await super().after_generate_reply_handle(message_id)
|
||||
|
||||
async def get_variable_parameters(self):
|
||||
return await super().get_variable_parameters()
|
||||
|
||||
async def set_variable_parameters(self, parameters):
|
||||
return await super().set_variable_parameters(parameters)
|
||||
|
||||
157
src/plugins/willing/mode_llmcheck.py
Normal file
157
src/plugins/willing/mode_llmcheck.py
Normal file
@@ -0,0 +1,157 @@
|
||||
"""
|
||||
llmcheck 模式:
|
||||
此模式的一些参数不会在配置文件中显示,要修改请在可变参数下修改
|
||||
此模式的特点:
|
||||
1.在群聊内的连续对话场景下,使用大语言模型来判断回复概率
|
||||
2.非连续对话场景,使用mxp模式的意愿管理器(可另外配置)
|
||||
3.默认配置的是model_v3,当前参数适用于deepseek-v3-0324
|
||||
|
||||
继承自其他模式,实质上仅重写get_reply_probability方法,未来可能重构成一个插件,可方便地组装到其他意愿模式上。
|
||||
目前的使用方式是拓展到其他意愿管理模式
|
||||
|
||||
"""
|
||||
|
||||
import time
|
||||
from loguru import logger
|
||||
from ..models.utils_model import LLM_request
|
||||
from ...config.config import global_config
|
||||
|
||||
# from ..chat.chat_stream import ChatStream
|
||||
from ..chat.utils import get_recent_group_detailed_plain_text
|
||||
|
||||
# from .willing_manager import BaseWillingManager
|
||||
from .mode_mxp import MxpWillingManager
|
||||
import re
|
||||
from functools import wraps
|
||||
|
||||
|
||||
def is_continuous_chat(self, message_id: str):
    """Return True when this looks like a continuous group conversation.

    A conversation counts as continuous when the bot's QQ id appears in the
    most recent messages of the chat. Only the latest 5 messages are
    inspected, to keep LLM-judgement cost down.
    """
    info = self.ongoing_messages[message_id]
    config = self.global_config
    recent_limit = 5  # cost cap: only look at the newest 5 messages
    if not info.chat_id:
        return False
    recent_text = get_recent_group_detailed_plain_text(info.chat_id, limit=recent_limit, combine=True)
    if not info.group_info:
        return False
    return str(config.BOT_QQ) in recent_text
|
||||
|
||||
|
||||
def llmcheck_decorator(trigger_condition_func):
    """Route probability lookups through the LLM check when a trigger fires.

    Wraps a ``get_reply_probability``-style method: if
    ``trigger_condition_func(self, message_id)`` is truthy, the call is
    redirected to ``self.get_llmreply_probability``; otherwise the wrapped
    method runs unchanged.
    """

    def decorator(func):
        @wraps(func)
        def wrapper(self, message_id: str):
            if trigger_condition_func(self, message_id):
                # Trigger satisfied: take the LLM judgement path.
                return self.get_llmreply_probability(message_id)
            # Trigger not satisfied: fall through to the default behaviour.
            return func(self, message_id)

        return wrapper

    return decorator
|
||||
|
||||
|
||||
class LlmcheckWillingManager(MxpWillingManager):
    """Mxp willing manager extended with an LLM-based reply-probability
    judgement for continuous group conversations (see module docstring)."""

    def __init__(self):
        super().__init__()
        # Model used for the probability judgement; low temperature for
        # stable numeric output.
        self.model_v3 = LLM_request(model=global_config.llm_normal, temperature=0.3)

    async def get_llmreply_probability(self, message_id: str):
        """Ask the LLM how likely the newest message needs a reply.

        Returns:
            A float in [0, 1]. Returns 0 when the chat has no stream id or
            the group is not in the allow-list.
        """
        message_info = self.ongoing_messages[message_id]
        chat_id = message_info.chat_id
        config = self.global_config
        # Number of recent messages handed to the model as context.
        length = 5
        if message_info.group_info and config:
            if message_info.group_info.group_id not in config.talk_allowed_groups:
                reply_probability = 0
                return reply_probability

        current_date = time.strftime("%Y-%m-%d", time.localtime())
        current_time = time.strftime("%H:%M:%S", time.localtime())
        chat_talking_prompt = ""
        if chat_id:
            chat_talking_prompt = get_recent_group_detailed_plain_text(chat_id, limit=length, combine=True)
        else:
            return 0

        prompt = f"""
假设你正在查看一个群聊,你在这个群聊里的网名叫{global_config.BOT_NICKNAME},你还有很多别名: {"/".join(global_config.BOT_ALIAS_NAMES)},
现在群里聊天的内容是{chat_talking_prompt},
今天是{current_date},现在是{current_time}。
综合群内的氛围和你自己之前的发言,给出你认为**最新的消息**需要你回复的概率,数值在0到1之间。请注意,群聊内容杂乱,很多时候对话连续,但很可能不是在和你说话。
如果最新的消息和你之前的发言在内容上连续,或者提到了你的名字或者称谓,将其视作明确指向你的互动,给出高于0.8的概率。如果现在是睡眠时间,直接概率为0。如果话题内容与你之前不是紧密相关,请不要给出高于0.1的概率。
请注意是判断概率,而不是编写回复内容,
仅输出在0到1区间内的概率值,不要给出你的判断依据。
"""

        content_check, reasoning_check, _ = await self.model_v3.generate_response(prompt)
        logger.info(f"{content_check} {reasoning_check}")
        # NOTE(review): the prompt never asks for ##PROBABILITY_START##
        # markers, so parsing normally relies on the unmarked fallback
        # patterns inside extract_marked_probability.
        probability = self.extract_marked_probability(content_check)
        # Interest-rate correction. The original comment says this should be
        # disabled until the new memory system lands, yet it is still applied
        # here — confirm before relying on it.
        probability += message_info.interested_rate * 0.25
        probability = min(1.0, probability)
        # Push borderline results toward a clear yes/no decision.
        if probability <= 0.1:
            probability = min(0.03, probability)
        if probability >= 0.8:
            probability = max(probability, 0.90)

        # Emoji understanding is still weak — damp replies to emoji messages.
        if message_info.is_emoji:
            probability *= global_config.emoji_response_penalty

        return probability

    @staticmethod
    def extract_marked_probability(text):
        """Parse a probability out of LLM output.

        Accepts ##PROBABILITY_START##...##PROBABILITY_END## markers as well
        as bare percentages (65%), decimals (0.65) and fractions (2/3).

        Returns:
            The parsed probability, or 0 for anything unparsable or outside
            the [0, 1] range.
        """
        text = text.strip()
        pattern = r"##PROBABILITY_START##(.*?)##PROBABILITY_END##"
        match = re.search(pattern, text, re.DOTALL)
        if match:
            prob_str = match.group(1).strip()
            # Bug fix: this branch used to run outside any try-block, so
            # malformed marked input ("abc", "2/0") crashed with
            # ValueError/ZeroDivisionError, and it skipped the 0..1 range
            # check that the fallback branch performs.
            try:
                if "%" in prob_str:
                    # Percentage (65% → 0.65).
                    prob = float(prob_str.replace("%", "")) / 100
                elif "/" in prob_str:
                    # Fraction (2/3 → 0.666...).
                    numerator, denominator = map(float, prob_str.split("/"))
                    prob = numerator / denominator
                else:
                    # Plain decimal.
                    prob = float(prob_str)
                return prob if 0 <= prob <= 1 else 0
            except (ValueError, ZeroDivisionError):
                return 0

        percent_match = re.search(r"(\d{1,3})%", text)  # 65%
        decimal_match = re.search(r"(0\.\d+|1\.0+)", text)  # 0.65
        fraction_match = re.search(r"(\d+)/(\d+)", text)  # 2/3
        try:
            if percent_match:
                prob = float(percent_match.group(1)) / 100
            elif decimal_match:
                prob = float(decimal_match.group(0))
            elif fraction_match:
                numerator, denominator = map(float, fraction_match.groups())
                prob = numerator / denominator
            else:
                return 0  # no recognizable format

            # Reject values outside the valid probability range.
            if 0 <= prob <= 1:
                return prob
            return 0
        except (ValueError, ZeroDivisionError):
            return 0

    @llmcheck_decorator(is_continuous_chat)
    def get_reply_probability(self, message_id):
        # Continuous-conversation messages are routed to the LLM check by
        # the decorator; everything else falls through to Mxp's default.
        return super().get_reply_probability(message_id)
|
||||
@@ -10,6 +10,7 @@ Mxp 模式:梦溪畔独家赞助
|
||||
4.限制同时思考的消息数量,防止喷射
|
||||
5.拥有单聊增益,无论在群里还是私聊,只要bot一直和你聊,就会增加意愿值
|
||||
6.意愿分为衰减意愿+临时意愿
|
||||
7.疲劳机制
|
||||
|
||||
如果你发现本模式出现了bug
|
||||
上上策是询问智慧的小草神()
|
||||
@@ -34,26 +35,50 @@ class MxpWillingManager(BaseWillingManager):
|
||||
self.chat_new_message_time: Dict[str, list[float]] = {} # 聊天流ID: 消息时间
|
||||
self.last_response_person: Dict[str, tuple[str, int]] = {} # 上次回复的用户信息
|
||||
self.temporary_willing: float = 0 # 临时意愿值
|
||||
self.chat_bot_message_time: Dict[str, list[float]] = {} # 聊天流ID: bot已回复消息时间
|
||||
self.chat_fatigue_punishment_list: Dict[
|
||||
str, list[tuple[float, float]]
|
||||
] = {} # 聊天流疲劳惩罚列, 聊天流ID: 惩罚时间列(开始时间,持续时间)
|
||||
self.chat_fatigue_willing_attenuation: Dict[str, float] = {} # 聊天流疲劳意愿衰减值
|
||||
|
||||
# 可变参数
|
||||
self.intention_decay_rate = 0.93 # 意愿衰减率
|
||||
self.message_expiration_time = 120 # 消息过期时间(秒)
|
||||
self.number_of_message_storage = 10 # 消息存储数量
|
||||
|
||||
self.number_of_message_storage = 12 # 消息存储数量
|
||||
self.expected_replies_per_min = 3 # 每分钟预期回复数
|
||||
self.basic_maximum_willing = 0.5 # 基础最大意愿值
|
||||
|
||||
self.mention_willing_gain = 0.6 # 提及意愿增益
|
||||
self.interest_willing_gain = 0.3 # 兴趣意愿增益
|
||||
self.emoji_response_penalty = self.global_config.emoji_response_penalty # 表情包回复惩罚
|
||||
self.down_frequency_rate = self.global_config.down_frequency_rate # 降低回复频率的群组惩罚系数
|
||||
self.single_chat_gain = 0.12 # 单聊增益
|
||||
|
||||
self.fatigue_messages_triggered_num = self.expected_replies_per_min # 疲劳消息触发数量(int)
|
||||
self.fatigue_coefficient = 1.0 # 疲劳系数
|
||||
|
||||
self.is_debug = False # 是否开启调试模式
|
||||
|
||||
async def async_task_starter(self) -> None:
|
||||
"""异步任务启动器"""
|
||||
asyncio.create_task(self._return_to_basic_willing())
|
||||
asyncio.create_task(self._chat_new_message_to_change_basic_willing())
|
||||
asyncio.create_task(self._fatigue_attenuation())
|
||||
|
||||
async def before_generate_reply_handle(self, message_id: str):
|
||||
"""回复前处理"""
|
||||
pass
|
||||
current_time = time.time()
|
||||
async with self.lock:
|
||||
w_info = self.ongoing_messages[message_id]
|
||||
if w_info.chat_id not in self.chat_bot_message_time:
|
||||
self.chat_bot_message_time[w_info.chat_id] = []
|
||||
self.chat_bot_message_time[w_info.chat_id] = [
|
||||
t for t in self.chat_bot_message_time[w_info.chat_id] if current_time - t < 60
|
||||
]
|
||||
self.chat_bot_message_time[w_info.chat_id].append(current_time)
|
||||
if len(self.chat_bot_message_time[w_info.chat_id]) == int(self.fatigue_messages_triggered_num):
|
||||
time_interval = 60 - (current_time - self.chat_bot_message_time[w_info.chat_id].pop(0))
|
||||
self.chat_fatigue_punishment_list[w_info.chat_id].append([current_time, time_interval * 2])
|
||||
|
||||
async def after_generate_reply_handle(self, message_id: str):
|
||||
"""回复后处理"""
|
||||
@@ -63,9 +88,9 @@ class MxpWillingManager(BaseWillingManager):
|
||||
rel_level = self._get_relationship_level_num(rel_value)
|
||||
self.chat_person_reply_willing[w_info.chat_id][w_info.person_id] += rel_level * 0.05
|
||||
|
||||
now_chat_new_person = self.last_response_person.get(w_info.chat_id, ["", 0])
|
||||
now_chat_new_person = self.last_response_person.get(w_info.chat_id, [w_info.person_id, 0])
|
||||
if now_chat_new_person[0] == w_info.person_id:
|
||||
if now_chat_new_person[1] < 2:
|
||||
if now_chat_new_person[1] < 3:
|
||||
now_chat_new_person[1] += 1
|
||||
else:
|
||||
self.last_response_person[w_info.chat_id] = [w_info.person_id, 0]
|
||||
@@ -75,13 +100,14 @@ class MxpWillingManager(BaseWillingManager):
|
||||
async with self.lock:
|
||||
w_info = self.ongoing_messages[message_id]
|
||||
if w_info.is_mentioned_bot:
|
||||
self.chat_person_reply_willing[w_info.chat_id][w_info.person_id] += 0.2
|
||||
self.chat_person_reply_willing[w_info.chat_id][w_info.person_id] += self.mention_willing_gain / 2.5
|
||||
if (
|
||||
w_info.chat_id in self.last_response_person
|
||||
and self.last_response_person[w_info.chat_id][0] == w_info.person_id
|
||||
and self.last_response_person[w_info.chat_id][1]
|
||||
):
|
||||
self.chat_person_reply_willing[w_info.chat_id][w_info.person_id] += self.single_chat_gain * (
|
||||
2 * self.last_response_person[w_info.chat_id][1] + 1
|
||||
2 * self.last_response_person[w_info.chat_id][1] - 1
|
||||
)
|
||||
now_chat_new_person = self.last_response_person.get(w_info.chat_id, ["", 0])
|
||||
if now_chat_new_person[0] != w_info.person_id:
|
||||
@@ -92,35 +118,63 @@ class MxpWillingManager(BaseWillingManager):
|
||||
async with self.lock:
|
||||
w_info = self.ongoing_messages[message_id]
|
||||
current_willing = self.chat_person_reply_willing[w_info.chat_id][w_info.person_id]
|
||||
if self.is_debug:
|
||||
self.logger.debug(f"基础意愿值:{current_willing}")
|
||||
|
||||
if w_info.is_mentioned_bot:
|
||||
current_willing += self.mention_willing_gain / (int(current_willing) + 1)
|
||||
current_willing_ = self.mention_willing_gain / (int(current_willing) + 1)
|
||||
current_willing += current_willing_
|
||||
if self.is_debug:
|
||||
self.logger.debug(f"提及增益:{current_willing_}")
|
||||
|
||||
if w_info.interested_rate > 0:
|
||||
current_willing += math.atan(w_info.interested_rate / 2) / math.pi * 2 * self.interest_willing_gain
|
||||
if self.is_debug:
|
||||
self.logger.debug(
|
||||
f"兴趣增益:{math.atan(w_info.interested_rate / 2) / math.pi * 2 * self.interest_willing_gain}"
|
||||
)
|
||||
|
||||
self.chat_person_reply_willing[w_info.chat_id][w_info.person_id] = current_willing
|
||||
|
||||
rel_value = await w_info.person_info_manager.get_value(w_info.person_id, "relationship_value")
|
||||
rel_level = self._get_relationship_level_num(rel_value)
|
||||
current_willing += rel_level * 0.1
|
||||
if self.is_debug and rel_level != 0:
|
||||
self.logger.debug(f"关系增益:{rel_level * 0.1}")
|
||||
|
||||
if (
|
||||
w_info.chat_id in self.last_response_person
|
||||
and self.last_response_person[w_info.chat_id][0] == w_info.person_id
|
||||
and self.last_response_person[w_info.chat_id][1]
|
||||
):
|
||||
current_willing += self.single_chat_gain * (2 * self.last_response_person[w_info.chat_id][1] + 1)
|
||||
if self.is_debug:
|
||||
self.logger.debug(
|
||||
f"单聊增益:{self.single_chat_gain * (2 * self.last_response_person[w_info.chat_id][1] + 1)}"
|
||||
)
|
||||
|
||||
current_willing += self.chat_fatigue_willing_attenuation.get(w_info.chat_id, 0)
|
||||
if self.is_debug:
|
||||
self.logger.debug(f"疲劳衰减:{self.chat_fatigue_willing_attenuation.get(w_info.chat_id, 0)}")
|
||||
|
||||
chat_ongoing_messages = [msg for msg in self.ongoing_messages.values() if msg.chat_id == w_info.chat_id]
|
||||
chat_person_ogoing_messages = [msg for msg in chat_ongoing_messages if msg.person_id == w_info.person_id]
|
||||
if len(chat_person_ogoing_messages) >= 2:
|
||||
current_willing = 0
|
||||
if self.is_debug:
|
||||
self.logger.debug("进行中消息惩罚:归0")
|
||||
elif len(chat_ongoing_messages) == 2:
|
||||
current_willing -= 0.5
|
||||
if self.is_debug:
|
||||
self.logger.debug("进行中消息惩罚:-0.5")
|
||||
elif len(chat_ongoing_messages) == 3:
|
||||
current_willing -= 1.5
|
||||
if self.is_debug:
|
||||
self.logger.debug("进行中消息惩罚:-1.5")
|
||||
elif len(chat_ongoing_messages) >= 4:
|
||||
current_willing = 0
|
||||
if self.is_debug:
|
||||
self.logger.debug("进行中消息惩罚:归0")
|
||||
|
||||
probability = self._willing_to_probability(current_willing)
|
||||
|
||||
@@ -168,32 +222,52 @@ class MxpWillingManager(BaseWillingManager):
|
||||
self.ongoing_messages[message.message_info.message_id].person_id, self.chat_reply_willing[chat.stream_id]
|
||||
)
|
||||
|
||||
current_time = time.time()
|
||||
if chat.stream_id not in self.chat_new_message_time:
|
||||
self.chat_new_message_time[chat.stream_id] = []
|
||||
self.chat_new_message_time[chat.stream_id].append(time.time())
|
||||
self.chat_new_message_time[chat.stream_id].append(current_time)
|
||||
if len(self.chat_new_message_time[chat.stream_id]) > self.number_of_message_storage:
|
||||
self.chat_new_message_time[chat.stream_id].pop(0)
|
||||
|
||||
if chat.stream_id not in self.chat_fatigue_punishment_list:
|
||||
self.chat_fatigue_punishment_list[chat.stream_id] = [
|
||||
(
|
||||
current_time,
|
||||
self.number_of_message_storage * self.basic_maximum_willing / self.expected_replies_per_min * 60,
|
||||
)
|
||||
]
|
||||
self.chat_fatigue_willing_attenuation[chat.stream_id] = (
|
||||
-2 * self.basic_maximum_willing * self.fatigue_coefficient
|
||||
)
|
||||
|
||||
def _willing_to_probability(self, willing: float) -> float:
|
||||
"""意愿值转化为概率"""
|
||||
willing = max(0, willing)
|
||||
if willing < 2:
|
||||
probability = math.atan(willing * 2) / math.pi * 2
|
||||
else:
|
||||
elif willing < 2.5:
|
||||
probability = math.atan(willing * 4) / math.pi * 2
|
||||
else:
|
||||
probability = 1
|
||||
return probability
|
||||
|
||||
async def _chat_new_message_to_change_basic_willing(self):
|
||||
"""聊天流新消息改变基础意愿"""
|
||||
update_time = 20
|
||||
while True:
|
||||
update_time = 20
|
||||
await asyncio.sleep(update_time)
|
||||
async with self.lock:
|
||||
for chat_id, message_times in self.chat_new_message_time.items():
|
||||
# 清理过期消息
|
||||
current_time = time.time()
|
||||
message_times = [
|
||||
msg_time for msg_time in message_times if current_time - msg_time < self.message_expiration_time
|
||||
msg_time
|
||||
for msg_time in message_times
|
||||
if current_time - msg_time
|
||||
< self.number_of_message_storage
|
||||
* self.basic_maximum_willing
|
||||
/ self.expected_replies_per_min
|
||||
* 60
|
||||
]
|
||||
self.chat_new_message_time[chat_id] = message_times
|
||||
|
||||
@@ -202,38 +276,14 @@ class MxpWillingManager(BaseWillingManager):
|
||||
update_time = 20
|
||||
elif len(message_times) == self.number_of_message_storage:
|
||||
time_interval = current_time - message_times[0]
|
||||
basic_willing = self.basic_maximum_willing * math.sqrt(
|
||||
time_interval / self.message_expiration_time
|
||||
)
|
||||
basic_willing = self._basic_willing_culculate(time_interval)
|
||||
self.chat_reply_willing[chat_id] = basic_willing
|
||||
update_time = 17 * math.sqrt(time_interval / self.message_expiration_time) + 3
|
||||
update_time = 17 * basic_willing / self.basic_maximum_willing + 3
|
||||
else:
|
||||
self.logger.debug(f"聊天流{chat_id}消息时间数量异常,数量:{len(message_times)}")
|
||||
self.chat_reply_willing[chat_id] = 0
|
||||
|
||||
async def get_variable_parameters(self) -> Dict[str, str]:
|
||||
"""获取可变参数"""
|
||||
return {
|
||||
"intention_decay_rate": "意愿衰减率",
|
||||
"message_expiration_time": "消息过期时间(秒)",
|
||||
"number_of_message_storage": "消息存储数量",
|
||||
"basic_maximum_willing": "基础最大意愿值",
|
||||
"mention_willing_gain": "提及意愿增益",
|
||||
"interest_willing_gain": "兴趣意愿增益",
|
||||
"emoji_response_penalty": "表情包回复惩罚",
|
||||
"down_frequency_rate": "降低回复频率的群组惩罚系数",
|
||||
"single_chat_gain": "单聊增益(不仅是私聊)",
|
||||
}
|
||||
|
||||
async def set_variable_parameters(self, parameters: Dict[str, any]):
|
||||
"""设置可变参数"""
|
||||
async with self.lock:
|
||||
for key, value in parameters.items():
|
||||
if hasattr(self, key):
|
||||
setattr(self, key, value)
|
||||
self.logger.debug(f"参数 {key} 已更新为 {value}")
|
||||
else:
|
||||
self.logger.debug(f"尝试设置未知参数 {key}")
|
||||
if self.is_debug:
|
||||
self.logger.debug(f"聊天流意愿值更新:{self.chat_reply_willing}")
|
||||
|
||||
def _get_relationship_level_num(self, relationship_value) -> int:
|
||||
"""关系等级计算"""
|
||||
@@ -253,5 +303,27 @@ class MxpWillingManager(BaseWillingManager):
|
||||
level_num = 5 if relationship_value > 1000 else 0
|
||||
return level_num - 2
|
||||
|
||||
def _basic_willing_culculate(self, t: float) -> float:
|
||||
"""基础意愿值计算"""
|
||||
return math.tan(t * self.expected_replies_per_min * math.pi / 120 / self.number_of_message_storage) / 2
|
||||
|
||||
async def _fatigue_attenuation(self):
|
||||
"""疲劳衰减"""
|
||||
while True:
|
||||
await asyncio.sleep(1)
|
||||
current_time = time.time()
|
||||
async with self.lock:
|
||||
for chat_id, fatigue_list in self.chat_fatigue_punishment_list.items():
|
||||
fatigue_list = [z for z in fatigue_list if current_time - z[0] < z[1]]
|
||||
self.chat_fatigue_willing_attenuation[chat_id] = 0
|
||||
for start_time, duration in fatigue_list:
|
||||
self.chat_fatigue_willing_attenuation[chat_id] += (
|
||||
self.chat_reply_willing[chat_id]
|
||||
* 2
|
||||
/ math.pi
|
||||
* math.asin(2 * (current_time - start_time) / duration - 1)
|
||||
- self.chat_reply_willing[chat_id]
|
||||
) * self.fatigue_coefficient
|
||||
|
||||
async def get_willing(self, chat_id):
|
||||
return self.temporary_willing
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from src.common.logger import LogConfig, WILLING_STYLE_CONFIG, LoguruLogger, get_module_logger
|
||||
from dataclasses import dataclass
|
||||
from ..config.config import global_config, BotConfig
|
||||
from ...config.config import global_config, BotConfig
|
||||
from ..chat.chat_stream import ChatStream, GroupInfo
|
||||
from ..chat.message import MessageRecv
|
||||
from ..person_info.person_info import person_info_manager, PersonInfoManager
|
||||
@@ -18,8 +18,8 @@ after_generate_reply_handle 确定要回复后,在生成回复后的处理
|
||||
not_reply_handle 确定不回复后的处理
|
||||
get_reply_probability 获取回复概率
|
||||
bombing_buffer_message_handle 缓冲器炸飞消息后的处理
|
||||
get_variable_parameters 获取可变参数组,返回一个字典,key为参数名称,value为参数描述(此方法是为拆分全局设置准备)
|
||||
set_variable_parameters 设置可变参数组,你需要传入一个字典,key为参数名称,value为参数值(此方法是为拆分全局设置准备)
|
||||
get_variable_parameters 暂不确定
|
||||
set_variable_parameters 暂不确定
|
||||
以下2个方法根据你的实现可以做调整:
|
||||
get_willing 获取某聊天流意愿
|
||||
set_willing 设置某聊天流意愿
|
||||
@@ -152,15 +152,15 @@ class BaseWillingManager(ABC):
|
||||
async with self.lock:
|
||||
self.chat_reply_willing[chat_id] = willing
|
||||
|
||||
@abstractmethod
|
||||
async def get_variable_parameters(self) -> Dict[str, str]:
|
||||
"""抽象方法:获取可变参数"""
|
||||
pass
|
||||
# @abstractmethod
|
||||
# async def get_variable_parameters(self) -> Dict[str, str]:
|
||||
# """抽象方法:获取可变参数"""
|
||||
# pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_variable_parameters(self, parameters: Dict[str, any]):
|
||||
"""抽象方法:设置可变参数"""
|
||||
pass
|
||||
# @abstractmethod
|
||||
# async def set_variable_parameters(self, parameters: Dict[str, any]):
|
||||
# """抽象方法:设置可变参数"""
|
||||
# pass
|
||||
|
||||
|
||||
def init_willing_manager() -> BaseWillingManager:
|
||||
|
||||
Reference in New Issue
Block a user