Merge branch 'dev' of https://github.com/MaiM-with-u/MaiBot into dev
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
.git
|
||||
__pycache__
|
||||
*.pyc
|
||||
*.pyo
|
||||
*.pyd
|
||||
.DS_Store
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -239,6 +239,5 @@ logs
|
||||
.vscode
|
||||
|
||||
/config/*
|
||||
run_none.bat
|
||||
config/old/bot_config_20250405_212257.toml
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@
|
||||
|
||||
● [我有问题](#我有问题)
|
||||
● [我想做贡献](#我想做贡献)
|
||||
● [我想报告BUG](#报告BUG)
|
||||
● [我想提出建议](#提出建议)
|
||||
|
||||
## 我有问题
|
||||
|
||||
@@ -114,7 +114,7 @@
|
||||
## 🎯 功能介绍
|
||||
|
||||
| 模块 | 主要功能 | 特点 |
|
||||
|------|---------|------|
|
||||
|----------|------------------------------------------------------------------|-------|
|
||||
| 💬 聊天系统 | • 心流/推理聊天<br>• 关键词主动发言<br>• 多模型支持<br>• 动态prompt构建<br>• 私聊功能(PFC) | 拟人化交互 |
|
||||
| 🧠 心流系统 | • 实时思考生成<br>• 自动启停机制<br>• 日程系统联动<br>• 工具调用能力 | 智能化决策 |
|
||||
| 🧠 记忆系统 | • 优化记忆抽取<br>• 海马体记忆机制<br>• 聊天记录概括 | 持久化记忆 |
|
||||
|
||||
@@ -47,7 +47,7 @@ if not SIMPLE_OUTPUT:
|
||||
"<cyan>{extra[module]: <12}</cyan> | "
|
||||
"<level>{message}</level>"
|
||||
),
|
||||
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | {message}"),
|
||||
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | {message}",
|
||||
"log_dir": LOG_ROOT,
|
||||
"rotation": "00:00",
|
||||
"retention": "3 days",
|
||||
@@ -59,8 +59,8 @@ else:
|
||||
"console_level": "INFO",
|
||||
"file_level": "DEBUG",
|
||||
# 格式配置
|
||||
"console_format": ("<green>{time:MM-DD HH:mm}</green> | <cyan>{extra[module]}</cyan> | {message}"),
|
||||
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | {message}"),
|
||||
"console_format": "<green>{time:MM-DD HH:mm}</green> | <cyan>{extra[module]}</cyan> | {message}",
|
||||
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | {message}",
|
||||
"log_dir": LOG_ROOT,
|
||||
"rotation": "00:00",
|
||||
"retention": "3 days",
|
||||
@@ -78,13 +78,13 @@ MEMORY_STYLE_CONFIG = {
|
||||
"<light-yellow>海马体</light-yellow> | "
|
||||
"<level>{message}</level>"
|
||||
),
|
||||
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 海马体 | {message}"),
|
||||
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 海马体 | {message}",
|
||||
},
|
||||
"simple": {
|
||||
"console_format": (
|
||||
"<green>{time:MM-DD HH:mm}</green> | <light-yellow>海马体</light-yellow> | <light-yellow>{message}</light-yellow>"
|
||||
),
|
||||
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 海马体 | {message}"),
|
||||
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 海马体 | {message}",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -99,11 +99,11 @@ MOOD_STYLE_CONFIG = {
|
||||
"<light-green>心情</light-green> | "
|
||||
"<level>{message}</level>"
|
||||
),
|
||||
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 心情 | {message}"),
|
||||
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 心情 | {message}",
|
||||
},
|
||||
"simple": {
|
||||
"console_format": ("<green>{time:MM-DD HH:mm}</green> | <magenta>心情</magenta> | {message}"),
|
||||
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 心情 | {message}"),
|
||||
"console_format": "<green>{time:MM-DD HH:mm}</green> | <magenta>心情</magenta> | {message}",
|
||||
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 心情 | {message}",
|
||||
},
|
||||
}
|
||||
# tool use
|
||||
@@ -116,11 +116,11 @@ TOOL_USE_STYLE_CONFIG = {
|
||||
"<magenta>工具使用</magenta> | "
|
||||
"<level>{message}</level>"
|
||||
),
|
||||
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 工具使用 | {message}"),
|
||||
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 工具使用 | {message}",
|
||||
},
|
||||
"simple": {
|
||||
"console_format": ("<green>{time:MM-DD HH:mm}</green> | <magenta>工具使用</magenta> | {message}"),
|
||||
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 工具使用 | {message}"),
|
||||
"console_format": "<green>{time:MM-DD HH:mm}</green> | <magenta>工具使用</magenta> | {message}",
|
||||
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 工具使用 | {message}",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -135,11 +135,11 @@ RELATION_STYLE_CONFIG = {
|
||||
"<light-magenta>关系</light-magenta> | "
|
||||
"<level>{message}</level>"
|
||||
),
|
||||
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 关系 | {message}"),
|
||||
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 关系 | {message}",
|
||||
},
|
||||
"simple": {
|
||||
"console_format": ("<green>{time:MM-DD HH:mm}</green> | <light-magenta>关系</light-magenta> | {message}"),
|
||||
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 关系 | {message}"),
|
||||
"console_format": "<green>{time:MM-DD HH:mm}</green> | <light-magenta>关系</light-magenta> | {message}",
|
||||
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 关系 | {message}",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -153,11 +153,11 @@ CONFIG_STYLE_CONFIG = {
|
||||
"<light-cyan>配置</light-cyan> | "
|
||||
"<level>{message}</level>"
|
||||
),
|
||||
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 配置 | {message}"),
|
||||
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 配置 | {message}",
|
||||
},
|
||||
"simple": {
|
||||
"console_format": ("<green>{time:MM-DD HH:mm}</green> | <light-cyan>配置</light-cyan> | {message}"),
|
||||
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 配置 | {message}"),
|
||||
"console_format": "<green>{time:MM-DD HH:mm}</green> | <light-cyan>配置</light-cyan> | {message}",
|
||||
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 配置 | {message}",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -170,11 +170,11 @@ SENDER_STYLE_CONFIG = {
|
||||
"<light-yellow>消息发送</light-yellow> | "
|
||||
"<level>{message}</level>"
|
||||
),
|
||||
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 消息发送 | {message}"),
|
||||
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 消息发送 | {message}",
|
||||
},
|
||||
"simple": {
|
||||
"console_format": ("<green>{time:MM-DD HH:mm}</green> | <green>消息发送</green> | {message}"),
|
||||
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 消息发送 | {message}"),
|
||||
"console_format": "<green>{time:MM-DD HH:mm}</green> | <green>消息发送</green> | {message}",
|
||||
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 消息发送 | {message}",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -187,13 +187,13 @@ HEARTFLOW_STYLE_CONFIG = {
|
||||
"<light-yellow>麦麦大脑袋</light-yellow> | "
|
||||
"<level>{message}</level>"
|
||||
),
|
||||
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 麦麦大脑袋 | {message}"),
|
||||
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 麦麦大脑袋 | {message}",
|
||||
},
|
||||
"simple": {
|
||||
"console_format": (
|
||||
"<green>{time:MM-DD HH:mm}</green> | <light-green>麦麦大脑袋</light-green> | <light-green>{message}</light-green>"
|
||||
), # noqa: E501
|
||||
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 麦麦大脑袋 | {message}"),
|
||||
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 麦麦大脑袋 | {message}",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -206,11 +206,11 @@ SCHEDULE_STYLE_CONFIG = {
|
||||
"<light-yellow>在干嘛</light-yellow> | "
|
||||
"<level>{message}</level>"
|
||||
),
|
||||
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 在干嘛 | {message}"),
|
||||
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 在干嘛 | {message}",
|
||||
},
|
||||
"simple": {
|
||||
"console_format": ("<green>{time:MM-DD HH:mm}</green> | <cyan>在干嘛</cyan> | <cyan>{message}</cyan>"),
|
||||
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 在干嘛 | {message}"),
|
||||
"console_format": "<green>{time:MM-DD HH:mm}</green> | <cyan>在干嘛</cyan> | <cyan>{message}</cyan>",
|
||||
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 在干嘛 | {message}",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -223,11 +223,11 @@ LLM_STYLE_CONFIG = {
|
||||
"<light-yellow>麦麦组织语言</light-yellow> | "
|
||||
"<level>{message}</level>"
|
||||
),
|
||||
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 麦麦组织语言 | {message}"),
|
||||
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 麦麦组织语言 | {message}",
|
||||
},
|
||||
"simple": {
|
||||
"console_format": ("<green>{time:MM-DD HH:mm}</green> | <light-green>麦麦组织语言</light-green> | {message}"),
|
||||
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 麦麦组织语言 | {message}"),
|
||||
"console_format": "<green>{time:MM-DD HH:mm}</green> | <light-green>麦麦组织语言</light-green> | {message}",
|
||||
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 麦麦组织语言 | {message}",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -242,11 +242,11 @@ TOPIC_STYLE_CONFIG = {
|
||||
"<light-blue>话题</light-blue> | "
|
||||
"<level>{message}</level>"
|
||||
),
|
||||
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 话题 | {message}"),
|
||||
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 话题 | {message}",
|
||||
},
|
||||
"simple": {
|
||||
"console_format": ("<green>{time:MM-DD HH:mm}</green> | <light-blue>主题</light-blue> | {message}"),
|
||||
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 话题 | {message}"),
|
||||
"console_format": "<green>{time:MM-DD HH:mm}</green> | <light-blue>主题</light-blue> | {message}",
|
||||
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 话题 | {message}",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -260,13 +260,13 @@ CHAT_STYLE_CONFIG = {
|
||||
"<light-blue>见闻</light-blue> | "
|
||||
"<level>{message}</level>"
|
||||
),
|
||||
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 见闻 | {message}"),
|
||||
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 见闻 | {message}",
|
||||
},
|
||||
"simple": {
|
||||
"console_format": (
|
||||
"<green>{time:MM-DD HH:mm}</green> | <light-blue>见闻</light-blue> | <green>{message}</green>"
|
||||
), # noqa: E501
|
||||
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 见闻 | {message}"),
|
||||
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 见闻 | {message}",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -279,13 +279,13 @@ SUB_HEARTFLOW_STYLE_CONFIG = {
|
||||
"<light-blue>麦麦小脑袋</light-blue> | "
|
||||
"<level>{message}</level>"
|
||||
),
|
||||
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 麦麦小脑袋 | {message}"),
|
||||
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 麦麦小脑袋 | {message}",
|
||||
},
|
||||
"simple": {
|
||||
"console_format": (
|
||||
"<green>{time:MM-DD HH:mm}</green> | <light-blue>麦麦小脑袋</light-blue> | <light-blue>{message}</light-blue>"
|
||||
), # noqa: E501
|
||||
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 麦麦小脑袋 | {message}"),
|
||||
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 麦麦小脑袋 | {message}",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -298,17 +298,17 @@ WILLING_STYLE_CONFIG = {
|
||||
"<light-blue>意愿</light-blue> | "
|
||||
"<level>{message}</level>"
|
||||
),
|
||||
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 意愿 | {message}"),
|
||||
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 意愿 | {message}",
|
||||
},
|
||||
"simple": {
|
||||
"console_format": ("<green>{time:MM-DD HH:mm}</green> | <light-blue>意愿</light-blue> | {message}"), # noqa: E501
|
||||
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 意愿 | {message}"),
|
||||
"console_format": "<green>{time:MM-DD HH:mm}</green> | <light-blue>意愿</light-blue> | {message}", # noqa: E501
|
||||
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 意愿 | {message}",
|
||||
},
|
||||
}
|
||||
|
||||
CONFIRM_STYLE_CONFIG = {
|
||||
"console_format": ("<RED>{message}</RED>"), # noqa: E501
|
||||
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | EULA与PRIVACY确认 | {message}"),
|
||||
"console_format": "<RED>{message}</RED>", # noqa: E501
|
||||
"file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | EULA与PRIVACY确认 | {message}",
|
||||
}
|
||||
|
||||
# 根据SIMPLE_OUTPUT选择配置
|
||||
@@ -459,7 +459,7 @@ other_log_dir.mkdir(parents=True, exist_ok=True)
|
||||
DEFAULT_FILE_HANDLER = logger.add(
|
||||
sink=str(other_log_dir / "{time:YYYY-MM-DD}.log"),
|
||||
level=os.getenv("DEFAULT_FILE_LOG_LEVEL", "DEBUG"),
|
||||
format=("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name: <15} | {message}"),
|
||||
format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name: <15} | {message}",
|
||||
rotation=DEFAULT_CONFIG["rotation"],
|
||||
retention=DEFAULT_CONFIG["retention"],
|
||||
compression=DEFAULT_CONFIG["compression"],
|
||||
|
||||
@@ -130,9 +130,6 @@ def update_config():
|
||||
logger.info("配置文件更新完成")
|
||||
|
||||
|
||||
logger = get_module_logger("config")
|
||||
|
||||
|
||||
@dataclass
|
||||
class BotConfig:
|
||||
"""机器人配置类"""
|
||||
@@ -284,21 +281,6 @@ class BotConfig:
|
||||
llm_sub_heartflow: Dict[str, str] = field(default_factory=lambda: {})
|
||||
llm_heartflow: Dict[str, str] = field(default_factory=lambda: {})
|
||||
|
||||
build_memory_interval: int = 600 # 记忆构建间隔(秒)
|
||||
|
||||
forget_memory_interval: int = 600 # 记忆遗忘间隔(秒)
|
||||
memory_forget_time: int = 24 # 记忆遗忘时间(小时)
|
||||
memory_forget_percentage: float = 0.01 # 记忆遗忘比例
|
||||
memory_compress_rate: float = 0.1 # 记忆压缩率
|
||||
build_memory_sample_num: int = 10 # 记忆构建采样数量
|
||||
build_memory_sample_length: int = 20 # 记忆构建采样长度
|
||||
memory_build_distribution: list = field(
|
||||
default_factory=lambda: [4, 2, 0.6, 24, 8, 0.4]
|
||||
) # 记忆构建分布,参数:分布1均值,标准差,权重,分布2均值,标准差,权重
|
||||
memory_ban_words: list = field(
|
||||
default_factory=lambda: ["表情包", "图片", "回复", "聊天记录"]
|
||||
) # 添加新的配置项默认值
|
||||
|
||||
api_urls: Dict[str, str] = field(default_factory=lambda: {})
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -23,13 +23,12 @@ class ChangeMoodTool(BaseTool):
|
||||
"required": ["text", "response_set"],
|
||||
}
|
||||
|
||||
async def execute(self, function_args: Dict[str, Any], message_txt: str) -> Dict[str, Any]:
|
||||
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]:
|
||||
"""执行心情改变
|
||||
|
||||
Args:
|
||||
function_args: 工具参数
|
||||
message_processed_plain_text: 原始消息文本
|
||||
response_set: 原始消息文本
|
||||
message_txt: 原始消息文本
|
||||
|
||||
Returns:
|
||||
Dict: 工具执行结果
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
# from src.plugins.person_info.relationship_manager import relationship_manager
|
||||
from typing import Dict, Any
|
||||
|
||||
from src.common.logger import get_module_logger
|
||||
from src.do_tool.tool_can_use.base_tool import BaseTool
|
||||
# from src.plugins.chat_module.think_flow_chat.think_flow_generator import ResponseGenerator
|
||||
@@ -20,22 +22,20 @@ class RelationshipTool(BaseTool):
|
||||
"required": ["text", "changed_value", "reason"],
|
||||
}
|
||||
|
||||
async def execute(self, args: dict, message_txt: str) -> dict:
|
||||
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> dict:
|
||||
"""执行工具功能
|
||||
|
||||
Args:
|
||||
args: 包含工具参数的字典
|
||||
text: 原始消息文本
|
||||
changed_value: 变更值
|
||||
reason: 变更原因
|
||||
function_args: 包含工具参数的字典
|
||||
message_txt: 原始消息文本
|
||||
|
||||
Returns:
|
||||
dict: 包含执行结果的字典
|
||||
"""
|
||||
try:
|
||||
text = args.get("text")
|
||||
changed_value = args.get("changed_value")
|
||||
reason = args.get("reason")
|
||||
text = function_args.get("text")
|
||||
changed_value = function_args.get("changed_value")
|
||||
reason = function_args.get("reason")
|
||||
|
||||
return {"content": f"因为你刚刚因为{reason},所以你和发[{text}]这条消息的人的关系值变化为{changed_value}"}
|
||||
|
||||
|
||||
@@ -49,8 +49,9 @@ class SearchKnowledgeTool(BaseTool):
|
||||
logger.error(f"知识库搜索工具执行失败: {str(e)}")
|
||||
return {"name": "search_knowledge", "content": f"知识库搜索失败: {str(e)}"}
|
||||
|
||||
@staticmethod
|
||||
def get_info_from_db(
|
||||
self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False
|
||||
query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False
|
||||
) -> Union[str, list]:
|
||||
"""从数据库中获取相关信息
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ class SendEmojiTool(BaseTool):
|
||||
"required": ["text"],
|
||||
}
|
||||
|
||||
async def execute(self, function_args: Dict[str, Any], message_txt: str) -> Dict[str, Any]:
|
||||
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]:
|
||||
text = function_args.get("text", message_txt)
|
||||
return {
|
||||
"name": "send_emoji",
|
||||
|
||||
@@ -24,8 +24,9 @@ class ToolUser:
|
||||
model=global_config.llm_tool_use, temperature=0.2, max_tokens=1000, request_type="tool_use"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _build_tool_prompt(
|
||||
self, message_txt: str, sender_name: str, chat_stream: ChatStream, subheartflow: SubHeartflow = None
|
||||
message_txt: str, sender_name: str, chat_stream: ChatStream, subheartflow: SubHeartflow = None
|
||||
):
|
||||
"""构建工具使用的提示词
|
||||
|
||||
@@ -69,7 +70,8 @@ class ToolUser:
|
||||
prompt += "你现在需要对群里的聊天内容进行回复,现在选择工具来对消息和你的回复进行处理,你是否需要额外的信息,比如回忆或者搜寻已有的知识,改变关系和情感,或者了解你现在正在做什么。"
|
||||
return prompt
|
||||
|
||||
def _define_tools(self):
|
||||
@staticmethod
|
||||
def _define_tools():
|
||||
"""获取所有已注册工具的定义
|
||||
|
||||
Returns:
|
||||
@@ -77,7 +79,8 @@ class ToolUser:
|
||||
"""
|
||||
return get_all_tool_definitions()
|
||||
|
||||
async def _execute_tool_call(self, tool_call, message_txt: str):
|
||||
@staticmethod
|
||||
async def _execute_tool_call(tool_call, message_txt: str):
|
||||
"""执行特定的工具调用
|
||||
|
||||
Args:
|
||||
|
||||
@@ -105,7 +105,8 @@ class Heartflow:
|
||||
# 启动子心流更新任务
|
||||
asyncio.create_task(self._sub_heartflow_update())
|
||||
|
||||
async def _update_current_state(self):
|
||||
@staticmethod
|
||||
async def _update_current_state():
|
||||
print("TODO")
|
||||
|
||||
async def do_a_thinking(self):
|
||||
|
||||
@@ -161,7 +161,8 @@ class ChattingObservation(Observation):
|
||||
# print(f"prompt:{prompt}")
|
||||
# print(f"self.observe_info:{self.observe_info}")
|
||||
|
||||
def translate_message_list_to_str(self, talking_message):
|
||||
@staticmethod
|
||||
def translate_message_list_to_str(talking_message):
|
||||
talking_message_str = ""
|
||||
for message in talking_message:
|
||||
talking_message_str += message["detailed_plain_text"]
|
||||
|
||||
@@ -102,7 +102,7 @@ class Identity:
|
||||
random.shuffle(identity_detail)
|
||||
prompt_identity += identity_detail[0]
|
||||
elif level == 2:
|
||||
for detail in identity_detail:
|
||||
for detail in self.identity_detail:
|
||||
prompt_identity += f",{detail}"
|
||||
prompt_identity += "。"
|
||||
return prompt_identity
|
||||
|
||||
@@ -131,14 +131,16 @@ class MainSystem:
|
||||
]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
async def build_memory_task(self):
|
||||
@staticmethod
|
||||
async def build_memory_task():
|
||||
"""记忆构建任务"""
|
||||
while True:
|
||||
await asyncio.sleep(global_config.build_memory_interval)
|
||||
logger.info("正在进行记忆构建")
|
||||
await HippocampusManager.get_instance().build_memory()
|
||||
|
||||
async def forget_memory_task(self):
|
||||
@staticmethod
|
||||
async def forget_memory_task():
|
||||
"""记忆遗忘任务"""
|
||||
while True:
|
||||
await asyncio.sleep(global_config.forget_memory_interval)
|
||||
@@ -152,7 +154,8 @@ class MainSystem:
|
||||
self.mood_manager.print_mood_status()
|
||||
await asyncio.sleep(30)
|
||||
|
||||
async def remove_recalled_message_task(self):
|
||||
@staticmethod
|
||||
async def remove_recalled_message_task():
|
||||
"""删除撤回消息任务"""
|
||||
while True:
|
||||
try:
|
||||
|
||||
@@ -119,7 +119,6 @@ class ChatObserver:
|
||||
self.last_cold_chat_check = current_time
|
||||
|
||||
# 判断是否冷场
|
||||
is_cold = False
|
||||
if self.last_message_time is None:
|
||||
is_cold = True
|
||||
else:
|
||||
|
||||
@@ -113,7 +113,8 @@ class Conversation:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _convert_to_message(self, msg_dict: Dict[str, Any]) -> Message:
|
||||
@staticmethod
|
||||
def _convert_to_message(msg_dict: Dict[str, Any]) -> Message:
|
||||
"""将消息字典转换为Message对象"""
|
||||
try:
|
||||
chat_info = msg_dict.get("chat_info", {})
|
||||
@@ -123,7 +124,7 @@ class Conversation:
|
||||
return Message(
|
||||
message_id=msg_dict["message_id"],
|
||||
chat_stream=chat_stream,
|
||||
time=msg_dict["time"],
|
||||
timestamp=msg_dict["time"],
|
||||
user_info=user_info,
|
||||
processed_plain_text=msg_dict.get("processed_plain_text", ""),
|
||||
detailed_plain_text=msg_dict.get("detailed_plain_text", ""),
|
||||
|
||||
@@ -15,8 +15,8 @@ class DirectMessageSender:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
async def send_message(
|
||||
self,
|
||||
chat_stream: ChatStream,
|
||||
content: str,
|
||||
reply_to_message: Optional[Message] = None,
|
||||
|
||||
@@ -51,11 +51,9 @@ class MongoDBMessageStorage(MessageStorage):
|
||||
"""MongoDB消息存储实现"""
|
||||
|
||||
async def get_messages_after(self, chat_id: str, message_time: float) -> List[Dict[str, Any]]:
|
||||
query = {"chat_id": chat_id}
|
||||
query = {"chat_id": chat_id, "time": {"$gt": message_time}}
|
||||
# print(f"storage_check_message: {message_time}")
|
||||
|
||||
query["time"] = {"$gt": message_time}
|
||||
|
||||
return list(db.messages.find(query).sort("time", 1))
|
||||
|
||||
async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]:
|
||||
|
||||
@@ -160,16 +160,16 @@ class GoalAnalyzer:
|
||||
# 返回第一个目标作为当前主要目标(如果有)
|
||||
if result:
|
||||
first_goal = result[0]
|
||||
return (first_goal.get("goal", ""), "", first_goal.get("reasoning", ""))
|
||||
return first_goal.get("goal", ""), "", first_goal.get("reasoning", "")
|
||||
else:
|
||||
# 单个目标的情况
|
||||
goal = result.get("goal", "")
|
||||
reasoning = result.get("reasoning", "")
|
||||
conversation_info.goal_list.append((goal, reasoning))
|
||||
return (goal, "", reasoning)
|
||||
return goal, "", reasoning
|
||||
|
||||
# 如果解析失败,返回默认值
|
||||
return ("", "", "")
|
||||
return "", "", ""
|
||||
|
||||
async def _update_goals(self, new_goal: str, method: str, reasoning: str):
|
||||
"""更新目标列表
|
||||
@@ -195,7 +195,8 @@ class GoalAnalyzer:
|
||||
if len(self.goals) > self.max_goals:
|
||||
self.goals.pop() # 移除最老的目标
|
||||
|
||||
def _calculate_similarity(self, goal1: str, goal2: str) -> float:
|
||||
@staticmethod
|
||||
def _calculate_similarity(goal1: str, goal2: str) -> float:
|
||||
"""简单计算两个目标之间的相似度
|
||||
|
||||
这里使用一个简单的实现,实际可以使用更复杂的文本相似度算法
|
||||
@@ -299,7 +300,8 @@ class DirectMessageSender:
|
||||
self.logger = get_module_logger("direct_sender")
|
||||
self.storage = MessageStorage()
|
||||
|
||||
async def send_via_ws(self, message: MessageSending) -> None:
|
||||
@staticmethod
|
||||
async def send_via_ws(message: MessageSending) -> None:
|
||||
try:
|
||||
await global_api.send_message(message)
|
||||
except Exception as e:
|
||||
|
||||
@@ -19,7 +19,8 @@ class KnowledgeFetcher:
|
||||
request_type="knowledge_fetch",
|
||||
)
|
||||
|
||||
async def fetch(self, query: str, chat_history: List[Message]) -> Tuple[str, str]:
|
||||
@staticmethod
|
||||
async def fetch(query: str, chat_history: List[Message]) -> Tuple[str, str]:
|
||||
"""获取相关知识
|
||||
|
||||
Args:
|
||||
|
||||
@@ -30,11 +30,8 @@ class ReplyGenerator:
|
||||
"""生成回复
|
||||
|
||||
Args:
|
||||
goal: 对话目标
|
||||
chat_history: 聊天历史
|
||||
knowledge_cache: 知识缓存
|
||||
previous_reply: 上一次生成的回复(如果有)
|
||||
retry_count: 当前重试次数
|
||||
observation_info: 观察信息
|
||||
conversation_info: 对话信息
|
||||
|
||||
Returns:
|
||||
str: 生成的回复
|
||||
|
||||
@@ -17,6 +17,5 @@ __all__ = [
|
||||
"relationship_manager",
|
||||
"MoodManager",
|
||||
"willing_manager",
|
||||
"hippocampus",
|
||||
"bot_schedule",
|
||||
]
|
||||
|
||||
@@ -103,7 +103,8 @@ class ChatManager:
|
||||
except Exception as e:
|
||||
logger.error(f"聊天流自动保存失败: {str(e)}")
|
||||
|
||||
def _ensure_collection(self):
|
||||
@staticmethod
|
||||
def _ensure_collection():
|
||||
"""确保数据库集合存在并创建索引"""
|
||||
if "chat_streams" not in db.list_collection_names():
|
||||
db.create_collection("chat_streams")
|
||||
@@ -111,7 +112,8 @@ class ChatManager:
|
||||
db.chat_streams.create_index([("stream_id", 1)], unique=True)
|
||||
db.chat_streams.create_index([("platform", 1), ("user_info.user_id", 1), ("group_info.group_id", 1)])
|
||||
|
||||
def _generate_stream_id(self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None) -> str:
|
||||
@staticmethod
|
||||
def _generate_stream_id(platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None) -> str:
|
||||
"""生成聊天流唯一ID"""
|
||||
if group_info:
|
||||
# 组合关键信息
|
||||
@@ -188,7 +190,8 @@ class ChatManager:
|
||||
stream_id = self._generate_stream_id(platform, user_info, group_info)
|
||||
return self.streams.get(stream_id)
|
||||
|
||||
async def _save_stream(self, stream: ChatStream):
|
||||
@staticmethod
|
||||
async def _save_stream(stream: ChatStream):
|
||||
"""保存聊天流到数据库"""
|
||||
if not stream.saved:
|
||||
db.chat_streams.update_one({"stream_id": stream.stream_id}, {"$set": stream.to_dict()}, upsert=True)
|
||||
|
||||
@@ -82,7 +82,8 @@ class EmojiManager:
|
||||
if not self._initialized:
|
||||
raise RuntimeError("EmojiManager not initialized")
|
||||
|
||||
def _ensure_emoji_collection(self):
|
||||
@staticmethod
|
||||
def _ensure_emoji_collection():
|
||||
"""确保emoji集合存在并创建索引
|
||||
|
||||
这个函数用于确保MongoDB数据库中存在emoji集合,并创建必要的索引。
|
||||
@@ -193,7 +194,8 @@ class EmojiManager:
|
||||
logger.error(f"[错误] 获取表情包失败: {str(e)}")
|
||||
return None
|
||||
|
||||
async def _get_emoji_description(self, image_base64: str) -> str:
|
||||
@staticmethod
|
||||
async def _get_emoji_description(image_base64: str) -> str:
|
||||
"""获取表情包的标签,使用image_manager的描述生成功能"""
|
||||
|
||||
try:
|
||||
@@ -554,7 +556,8 @@ class EmojiManager:
|
||||
self.check_emoji_file_full()
|
||||
await asyncio.sleep(global_config.EMOJI_CHECK_INTERVAL * 60)
|
||||
|
||||
async def delete_all_images(self):
|
||||
@staticmethod
|
||||
async def delete_all_images():
|
||||
"""删除 data/image 目录下的所有文件"""
|
||||
try:
|
||||
image_dir = os.path.join("data", "image")
|
||||
|
||||
@@ -31,7 +31,7 @@ class Message(MessageBase):
|
||||
def __init__(
|
||||
self,
|
||||
message_id: str,
|
||||
time: float,
|
||||
timestamp: float,
|
||||
chat_stream: ChatStream,
|
||||
user_info: UserInfo,
|
||||
message_segment: Optional[Seg] = None,
|
||||
@@ -43,7 +43,7 @@ class Message(MessageBase):
|
||||
message_info = BaseMessageInfo(
|
||||
platform=chat_stream.platform,
|
||||
message_id=message_id,
|
||||
time=time,
|
||||
time=timestamp,
|
||||
group_info=chat_stream.group_info,
|
||||
user_info=user_info,
|
||||
)
|
||||
@@ -143,7 +143,7 @@ class MessageRecv(Message):
|
||||
def _generate_detailed_text(self) -> str:
|
||||
"""生成详细文本,包含时间和用户信息"""
|
||||
# time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.message_info.time))
|
||||
time = self.message_info.time
|
||||
timestamp = self.message_info.time
|
||||
user_info = self.message_info.user_info
|
||||
# name = (
|
||||
# f"{user_info.user_nickname}(ta的昵称:{user_info.user_cardname},ta的id:{user_info.user_id})"
|
||||
@@ -151,7 +151,7 @@ class MessageRecv(Message):
|
||||
# else f"{user_info.user_nickname}(ta的id:{user_info.user_id})"
|
||||
# )
|
||||
name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>"
|
||||
return f"[{time}] {name}: {self.processed_plain_text}\n"
|
||||
return f"[{timestamp}] {name}: {self.processed_plain_text}\n"
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -170,7 +170,7 @@ class MessageProcessBase(Message):
|
||||
# 调用父类初始化
|
||||
super().__init__(
|
||||
message_id=message_id,
|
||||
time=round(time.time(), 3), # 保留3位小数
|
||||
timestamp=round(time.time(), 3), # 保留3位小数
|
||||
chat_stream=chat_stream,
|
||||
user_info=bot_user_info,
|
||||
message_segment=message_segment,
|
||||
@@ -242,7 +242,7 @@ class MessageProcessBase(Message):
|
||||
def _generate_detailed_text(self) -> str:
|
||||
"""生成详细文本,包含时间和用户信息"""
|
||||
# time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.message_info.time))
|
||||
time = self.message_info.time
|
||||
timestamp = self.message_info.time
|
||||
user_info = self.message_info.user_info
|
||||
# name = (
|
||||
# f"{user_info.user_nickname}(ta的昵称:{user_info.user_cardname},ta的id:{user_info.user_id})"
|
||||
@@ -250,7 +250,7 @@ class MessageProcessBase(Message):
|
||||
# else f"{user_info.user_nickname}(ta的id:{user_info.user_id})"
|
||||
# )
|
||||
name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>"
|
||||
return f"[{time}] {name}: {self.processed_plain_text}\n"
|
||||
return f"[{timestamp}] {name}: {self.processed_plain_text}\n"
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -26,7 +26,8 @@ class MessageBuffer:
|
||||
self.buffer_pool: Dict[str, OrderedDict[str, CacheMessages]] = {}
|
||||
self.lock = asyncio.Lock()
|
||||
|
||||
def get_person_id_(self, platform: str, user_id: str, group_info: GroupInfo):
|
||||
@staticmethod
|
||||
def get_person_id_(platform: str, user_id: str, group_info: GroupInfo):
|
||||
"""获取唯一id"""
|
||||
if group_info:
|
||||
group_id = group_info.group_id
|
||||
@@ -150,20 +151,20 @@ class MessageBuffer:
|
||||
keep_msgs[msg_id] = msg
|
||||
elif msg.result == "F":
|
||||
# 收集F消息的文本内容
|
||||
F_type = "seglist"
|
||||
f_type = "seglist"
|
||||
if msg.message.message_segment.type != "seglist":
|
||||
F_type = msg.message.message_segment.type
|
||||
f_type = msg.message.message_segment.type
|
||||
else:
|
||||
if (
|
||||
isinstance(msg.message.message_segment.data, list)
|
||||
and all(isinstance(x, Seg) for x in msg.message.message_segment.data)
|
||||
and len(msg.message.message_segment.data) == 1
|
||||
):
|
||||
F_type = msg.message.message_segment.data[0].type
|
||||
f_type = msg.message.message_segment.data[0].type
|
||||
if hasattr(msg.message, "processed_plain_text") and msg.message.processed_plain_text:
|
||||
if F_type == "text":
|
||||
if f_type == "text":
|
||||
combined_text.append(msg.message.processed_plain_text)
|
||||
elif F_type != "text":
|
||||
elif f_type != "text":
|
||||
is_update = False
|
||||
elif msg.result == "U":
|
||||
logger.debug(f"异常未处理信息id: {msg.message.message_info.message_id}")
|
||||
@@ -185,7 +186,8 @@ class MessageBuffer:
|
||||
logger.debug(f"查询超时消息id: {message.message_info.message_id}")
|
||||
return False
|
||||
|
||||
async def save_message_interval(self, person_id: str, message: BaseMessageInfo):
|
||||
@staticmethod
|
||||
async def save_message_interval(person_id: str, message: BaseMessageInfo):
|
||||
message_interval_list = await person_info_manager.get_value(person_id, "msg_interval_list")
|
||||
now_time_ms = int(round(time.time() * 1000))
|
||||
if len(message_interval_list) < 1000:
|
||||
|
||||
@@ -35,7 +35,8 @@ class MessageSender:
|
||||
"""设置当前bot实例"""
|
||||
pass
|
||||
|
||||
def get_recalled_messages(self, stream_id: str) -> list:
|
||||
@staticmethod
|
||||
def get_recalled_messages(stream_id: str) -> list:
|
||||
"""获取所有撤回的消息"""
|
||||
recalled_messages = []
|
||||
|
||||
@@ -43,7 +44,8 @@ class MessageSender:
|
||||
# 按thinking_start_time排序,时间早的在前面
|
||||
return recalled_messages
|
||||
|
||||
async def send_via_ws(self, message: MessageSending) -> None:
|
||||
@staticmethod
|
||||
async def send_via_ws(message: MessageSending) -> None:
|
||||
try:
|
||||
await global_api.send_message(message)
|
||||
except Exception as e:
|
||||
|
||||
@@ -135,7 +135,7 @@ async def get_recent_group_messages(chat_id: str, limit: int = 12) -> list:
|
||||
msg = Message(
|
||||
message_id=msg_data["message_id"],
|
||||
chat_stream=chat_stream,
|
||||
time=msg_data["time"],
|
||||
timestamp=msg_data["time"],
|
||||
user_info=user_info,
|
||||
processed_plain_text=msg_data.get("processed_text", ""),
|
||||
detailed_plain_text=msg_data.get("detailed_plain_text", ""),
|
||||
|
||||
@@ -38,7 +38,8 @@ class ImageManager:
|
||||
"""确保图像存储目录存在"""
|
||||
os.makedirs(self.IMAGE_DIR, exist_ok=True)
|
||||
|
||||
def _ensure_image_collection(self):
|
||||
@staticmethod
|
||||
def _ensure_image_collection():
|
||||
"""确保images集合存在并创建索引"""
|
||||
if "images" not in db.list_collection_names():
|
||||
db.create_collection("images")
|
||||
@@ -50,7 +51,8 @@ class ImageManager:
|
||||
db.images.create_index([("url", 1)])
|
||||
db.images.create_index([("path", 1)])
|
||||
|
||||
def _ensure_description_collection(self):
|
||||
@staticmethod
|
||||
def _ensure_description_collection():
|
||||
"""确保image_descriptions集合存在并创建索引"""
|
||||
if "image_descriptions" not in db.list_collection_names():
|
||||
db.create_collection("image_descriptions")
|
||||
@@ -60,7 +62,8 @@ class ImageManager:
|
||||
# 创建新的复合索引
|
||||
db.image_descriptions.create_index([("hash", 1), ("type", 1)], unique=True)
|
||||
|
||||
def _get_description_from_db(self, image_hash: str, description_type: str) -> Optional[str]:
|
||||
@staticmethod
|
||||
def _get_description_from_db(image_hash: str, description_type: str) -> Optional[str]:
|
||||
"""从数据库获取图片描述
|
||||
|
||||
Args:
|
||||
@@ -73,7 +76,8 @@ class ImageManager:
|
||||
result = db.image_descriptions.find_one({"hash": image_hash, "type": description_type})
|
||||
return result["description"] if result else None
|
||||
|
||||
def _save_description_to_db(self, image_hash: str, description: str, description_type: str) -> None:
|
||||
@staticmethod
|
||||
def _save_description_to_db(image_hash: str, description: str, description_type: str) -> None:
|
||||
"""保存图片描述到数据库
|
||||
|
||||
Args:
|
||||
@@ -226,7 +230,8 @@ class ImageManager:
|
||||
logger.error(f"获取图片描述失败: {str(e)}")
|
||||
return "[图片]"
|
||||
|
||||
def transform_gif(self, gif_base64: str) -> str:
|
||||
@staticmethod
|
||||
def transform_gif(gif_base64: str) -> str:
|
||||
"""将GIF转换为水平拼接的静态图像
|
||||
|
||||
Args:
|
||||
|
||||
@@ -13,7 +13,8 @@ class MessageProcessor:
|
||||
def __init__(self):
|
||||
self.storage = MessageStorage()
|
||||
|
||||
def _check_ban_words(self, text: str, chat, userinfo) -> bool:
|
||||
@staticmethod
|
||||
def _check_ban_words(text: str, chat, userinfo) -> bool:
|
||||
"""检查消息中是否包含过滤词"""
|
||||
for word in global_config.ban_words:
|
||||
if word in text:
|
||||
@@ -24,7 +25,8 @@ class MessageProcessor:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _check_ban_regex(self, text: str, chat, userinfo) -> bool:
|
||||
@staticmethod
|
||||
def _check_ban_regex(text: str, chat, userinfo) -> bool:
|
||||
"""检查消息是否匹配过滤正则表达式"""
|
||||
for pattern in global_config.ban_msgs_regex:
|
||||
if pattern.search(text):
|
||||
|
||||
@@ -37,7 +37,8 @@ class ReasoningChat:
|
||||
self.mood_manager = MoodManager.get_instance()
|
||||
self.mood_manager.start_mood_update()
|
||||
|
||||
async def _create_thinking_message(self, message, chat, userinfo, messageinfo):
|
||||
@staticmethod
|
||||
async def _create_thinking_message(message, chat, userinfo, messageinfo):
|
||||
"""创建思考消息"""
|
||||
bot_user_info = UserInfo(
|
||||
user_id=global_config.BOT_QQ,
|
||||
@@ -59,7 +60,8 @@ class ReasoningChat:
|
||||
|
||||
return thinking_id
|
||||
|
||||
async def _send_response_messages(self, message, chat, response_set: List[str], thinking_id) -> MessageSending:
|
||||
@staticmethod
|
||||
async def _send_response_messages(message, chat, response_set: List[str], thinking_id) -> MessageSending:
|
||||
"""发送回复消息"""
|
||||
container = message_manager.get_container(chat.stream_id)
|
||||
thinking_message = None
|
||||
@@ -104,7 +106,8 @@ class ReasoningChat:
|
||||
|
||||
return first_bot_msg
|
||||
|
||||
async def _handle_emoji(self, message, chat, response):
|
||||
@staticmethod
|
||||
async def _handle_emoji(message, chat, response):
|
||||
"""处理表情包"""
|
||||
if random() < global_config.emoji_chance:
|
||||
emoji_raw = await emoji_manager.get_emoji_for_text(response)
|
||||
@@ -192,21 +195,21 @@ class ReasoningChat:
|
||||
if not buffer_result:
|
||||
await willing_manager.bombing_buffer_message_handle(message.message_info.message_id)
|
||||
willing_manager.delete(message.message_info.message_id)
|
||||
F_type = "seglist"
|
||||
f_type = "seglist"
|
||||
if message.message_segment.type != "seglist":
|
||||
F_type = message.message_segment.type
|
||||
f_type = message.message_segment.type
|
||||
else:
|
||||
if (
|
||||
isinstance(message.message_segment.data, list)
|
||||
and all(isinstance(x, Seg) for x in message.message_segment.data)
|
||||
and len(message.message_segment.data) == 1
|
||||
):
|
||||
F_type = message.message_segment.data[0].type
|
||||
if F_type == "text":
|
||||
f_type = message.message_segment.data[0].type
|
||||
if f_type == "text":
|
||||
logger.info(f"触发缓冲,已炸飞消息:{message.processed_plain_text}")
|
||||
elif F_type == "image":
|
||||
elif f_type == "image":
|
||||
logger.info("触发缓冲,已炸飞表情包/图片")
|
||||
elif F_type == "seglist":
|
||||
elif f_type == "seglist":
|
||||
logger.info("触发缓冲,已炸飞消息列")
|
||||
return
|
||||
|
||||
@@ -291,7 +294,8 @@ class ReasoningChat:
|
||||
# 意愿管理器:注销当前message信息
|
||||
willing_manager.delete(message.message_info.message_id)
|
||||
|
||||
def _check_ban_words(self, text: str, chat, userinfo) -> bool:
|
||||
@staticmethod
|
||||
def _check_ban_words(text: str, chat, userinfo) -> bool:
|
||||
"""检查消息中是否包含过滤词"""
|
||||
for word in global_config.ban_words:
|
||||
if word in text:
|
||||
@@ -302,7 +306,8 @@ class ReasoningChat:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _check_ban_regex(self, text: str, chat, userinfo) -> bool:
|
||||
@staticmethod
|
||||
def _check_ban_regex(text: str, chat, userinfo) -> bool:
|
||||
"""检查消息是否匹配过滤正则表达式"""
|
||||
for pattern in global_config.ban_msgs_regex:
|
||||
if pattern.search(text):
|
||||
|
||||
@@ -69,8 +69,6 @@ class ResponseGenerator:
|
||||
return None
|
||||
|
||||
async def _generate_response_with_model(self, message: MessageThinking, model: LLMRequest, thinking_id: str):
|
||||
sender_name = ""
|
||||
|
||||
info_catcher = info_catcher_manager.get_info_catcher(thinking_id)
|
||||
|
||||
if message.chat_stream.user_info.user_cardname and message.chat_stream.user_info.user_nickname:
|
||||
@@ -188,7 +186,8 @@ class ResponseGenerator:
|
||||
logger.debug(f"获取情感标签时出错: {e}")
|
||||
return "中立", "平静" # 出错时返回默认值
|
||||
|
||||
async def _process_response(self, content: str) -> Tuple[List[str], List[str]]:
|
||||
@staticmethod
|
||||
async def _process_response(content: str) -> Tuple[List[str], List[str]]:
|
||||
"""处理响应内容,返回处理后的内容和情感标签"""
|
||||
if not content:
|
||||
return None, []
|
||||
|
||||
@@ -101,16 +101,14 @@ class PromptBuilder:
|
||||
related_memory = await HippocampusManager.get_instance().get_memory_from_text(
|
||||
text=message_txt, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False
|
||||
)
|
||||
if related_memory:
|
||||
related_memory_info = ""
|
||||
if related_memory:
|
||||
for memory in related_memory:
|
||||
related_memory_info += memory[1]
|
||||
# memory_prompt = f"你想起你之前见过的事情:{related_memory_info}。\n以上是你的回忆,不一定是目前聊天里的人说的,也不一定是现在发生的事情,请记住。\n"
|
||||
memory_prompt = await global_prompt_manager.format_prompt(
|
||||
"memory_prompt", related_memory_info=related_memory_info
|
||||
)
|
||||
else:
|
||||
related_memory_info = ""
|
||||
|
||||
# print(f"相关记忆:{related_memory_info}")
|
||||
|
||||
@@ -162,7 +160,6 @@ class PromptBuilder:
|
||||
|
||||
# 知识构建
|
||||
start_time = time.time()
|
||||
prompt_info = ""
|
||||
prompt_info = await self.get_prompt_info(message_txt, threshold=0.38)
|
||||
if prompt_info:
|
||||
# prompt_info = f"""\n你有以下这些**知识**:\n{prompt_info}\n请你**记住上面的知识**,之后可能会用到。\n"""
|
||||
@@ -373,8 +370,9 @@ class PromptBuilder:
|
||||
logger.info(f"知识库检索总耗时: {time.time() - start_time:.3f}秒")
|
||||
return related_info
|
||||
|
||||
@staticmethod
|
||||
def get_info_from_db(
|
||||
self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False
|
||||
query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False
|
||||
) -> Union[str, list]:
|
||||
if not query_embedding:
|
||||
return "" if not return_raw else []
|
||||
|
||||
@@ -40,7 +40,8 @@ class ThinkFlowChat:
|
||||
self.mood_manager.start_mood_update()
|
||||
self.tool_user = ToolUser()
|
||||
|
||||
async def _create_thinking_message(self, message, chat, userinfo, messageinfo):
|
||||
@staticmethod
|
||||
async def _create_thinking_message(message, chat, userinfo, messageinfo):
|
||||
"""创建思考消息"""
|
||||
bot_user_info = UserInfo(
|
||||
user_id=global_config.BOT_QQ,
|
||||
@@ -62,7 +63,8 @@ class ThinkFlowChat:
|
||||
|
||||
return thinking_id
|
||||
|
||||
async def _send_response_messages(self, message, chat, response_set: List[str], thinking_id) -> MessageSending:
|
||||
@staticmethod
|
||||
async def _send_response_messages(message, chat, response_set: List[str], thinking_id) -> MessageSending:
|
||||
"""发送回复消息"""
|
||||
container = message_manager.get_container(chat.stream_id)
|
||||
thinking_message = None
|
||||
@@ -108,7 +110,8 @@ class ThinkFlowChat:
|
||||
message_manager.add_message(message_set)
|
||||
return first_bot_msg
|
||||
|
||||
async def _handle_emoji(self, message, chat, response, send_emoji=""):
|
||||
@staticmethod
|
||||
async def _handle_emoji(message, chat, response, send_emoji=""):
|
||||
"""处理表情包"""
|
||||
if send_emoji:
|
||||
emoji_raw = await emoji_manager.get_emoji_for_text(send_emoji)
|
||||
@@ -204,21 +207,21 @@ class ThinkFlowChat:
|
||||
if not buffer_result:
|
||||
await willing_manager.bombing_buffer_message_handle(message.message_info.message_id)
|
||||
willing_manager.delete(message.message_info.message_id)
|
||||
F_type = "seglist"
|
||||
f_type = "seglist"
|
||||
if message.message_segment.type != "seglist":
|
||||
F_type = message.message_segment.type
|
||||
f_type = message.message_segment.type
|
||||
else:
|
||||
if (
|
||||
isinstance(message.message_segment.data, list)
|
||||
and all(isinstance(x, Seg) for x in message.message_segment.data)
|
||||
and len(message.message_segment.data) == 1
|
||||
):
|
||||
F_type = message.message_segment.data[0].type
|
||||
if F_type == "text":
|
||||
f_type = message.message_segment.data[0].type
|
||||
if f_type == "text":
|
||||
logger.info(f"触发缓冲,已炸飞消息:{message.processed_plain_text}")
|
||||
elif F_type == "image":
|
||||
elif f_type == "image":
|
||||
logger.info("触发缓冲,已炸飞表情包/图片")
|
||||
elif F_type == "seglist":
|
||||
elif f_type == "seglist":
|
||||
logger.info("触发缓冲,已炸飞消息列")
|
||||
return
|
||||
|
||||
@@ -461,7 +464,8 @@ class ThinkFlowChat:
|
||||
# 意愿管理器:注销当前message信息
|
||||
willing_manager.delete(message.message_info.message_id)
|
||||
|
||||
def _check_ban_words(self, text: str, chat, userinfo) -> bool:
|
||||
@staticmethod
|
||||
def _check_ban_words(text: str, chat, userinfo) -> bool:
|
||||
"""检查消息中是否包含过滤词"""
|
||||
for word in global_config.ban_words:
|
||||
if word in text:
|
||||
@@ -472,7 +476,8 @@ class ThinkFlowChat:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _check_ban_regex(self, text: str, chat, userinfo) -> bool:
|
||||
@staticmethod
|
||||
def _check_ban_regex(text: str, chat, userinfo) -> bool:
|
||||
"""检查消息是否匹配过滤正则表达式"""
|
||||
for pattern in global_config.ban_msgs_regex:
|
||||
if pattern.search(text):
|
||||
|
||||
@@ -236,7 +236,8 @@ class ResponseGenerator:
|
||||
logger.debug(f"获取情感标签时出错: {e}")
|
||||
return "中立", "平静" # 出错时返回默认值
|
||||
|
||||
async def _process_response(self, content: str) -> List[str]:
|
||||
@staticmethod
|
||||
async def _process_response(content: str) -> List[str]:
|
||||
"""处理响应内容,返回处理后的内容和情感标签"""
|
||||
if not content:
|
||||
return None
|
||||
|
||||
@@ -64,8 +64,9 @@ class PromptBuilder:
|
||||
self.prompt_built = ""
|
||||
self.activate_messages = ""
|
||||
|
||||
@staticmethod
|
||||
async def _build_prompt(
|
||||
self, chat_stream, message_txt: str, sender_name: str = "某人", stream_id: Optional[int] = None
|
||||
chat_stream, message_txt: str, sender_name: str = "某人", stream_id: Optional[int] = None
|
||||
) -> tuple[str, str]:
|
||||
current_mind_info = heartflow.get_subheartflow(stream_id).current_mind
|
||||
|
||||
@@ -168,8 +169,9 @@ class PromptBuilder:
|
||||
|
||||
return prompt
|
||||
|
||||
@staticmethod
|
||||
async def _build_prompt_simple(
|
||||
self, chat_stream, message_txt: str, sender_name: str = "某人", stream_id: Optional[int] = None
|
||||
chat_stream, message_txt: str, sender_name: str = "某人", stream_id: Optional[int] = None
|
||||
) -> tuple[str, str]:
|
||||
current_mind_info = heartflow.get_subheartflow(stream_id).current_mind
|
||||
|
||||
@@ -237,8 +239,8 @@ class PromptBuilder:
|
||||
logger.info(f"生成回复的prompt: {prompt}")
|
||||
return prompt
|
||||
|
||||
@staticmethod
|
||||
async def _build_prompt_check_response(
|
||||
self,
|
||||
chat_stream,
|
||||
message_txt: str,
|
||||
sender_name: str = "某人",
|
||||
|
||||
@@ -4,6 +4,8 @@ import math
|
||||
import random
|
||||
import time
|
||||
import re
|
||||
from itertools import combinations
|
||||
|
||||
import jieba
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
@@ -250,7 +252,8 @@ class Hippocampus:
|
||||
"""获取记忆图中所有节点的名字列表"""
|
||||
return list(self.memory_graph.G.nodes())
|
||||
|
||||
def calculate_node_hash(self, concept, memory_items) -> int:
|
||||
@staticmethod
|
||||
def calculate_node_hash(concept, memory_items) -> int:
|
||||
"""计算节点的特征值"""
|
||||
if not isinstance(memory_items, list):
|
||||
memory_items = [memory_items] if memory_items else []
|
||||
@@ -258,12 +261,14 @@ class Hippocampus:
|
||||
content = f"{concept}:{'|'.join(sorted_items)}"
|
||||
return hash(content)
|
||||
|
||||
def calculate_edge_hash(self, source, target) -> int:
|
||||
@staticmethod
|
||||
def calculate_edge_hash(source, target) -> int:
|
||||
"""计算边的特征值"""
|
||||
nodes = sorted([source, target])
|
||||
return hash(f"{nodes[0]}:{nodes[1]}")
|
||||
|
||||
def find_topic_llm(self, text, topic_num):
|
||||
@staticmethod
|
||||
def find_topic_llm(text, topic_num):
|
||||
prompt = (
|
||||
f"这是一段文字:{text}。请你从这段话中总结出最多{topic_num}个关键的概念,可以是名词,动词,或者特定人物,帮我列出来,"
|
||||
f"将主题用逗号隔开,并加上<>,例如<主题1>,<主题2>......尽可能精简。只需要列举最多{topic_num}个话题就好,不要有序号,不要告诉我其他内容。"
|
||||
@@ -271,14 +276,16 @@ class Hippocampus:
|
||||
)
|
||||
return prompt
|
||||
|
||||
def topic_what(self, text, topic, time_info):
|
||||
@staticmethod
|
||||
def topic_what(text, topic, time_info):
|
||||
prompt = (
|
||||
f'这是一段文字,{time_info}:{text}。我想让你基于这段文字来概括"{topic}"这个概念,帮我总结成一句自然的话,'
|
||||
f"可以包含时间和人物,以及具体的观点。只输出这句话就好"
|
||||
)
|
||||
return prompt
|
||||
|
||||
def calculate_topic_num(self, text, compress_rate):
|
||||
@staticmethod
|
||||
def calculate_topic_num(text, compress_rate):
|
||||
"""计算文本的话题数量"""
|
||||
information_content = calculate_information_content(text)
|
||||
topic_by_length = text.count("\n") * compress_rate
|
||||
@@ -693,7 +700,8 @@ class EntorhinalCortex:
|
||||
|
||||
return chat_samples
|
||||
|
||||
def random_get_msg_snippet(self, target_timestamp: float, chat_size: int, max_memorized_time_per_msg: int) -> list:
|
||||
@staticmethod
|
||||
def random_get_msg_snippet(target_timestamp: float, chat_size: int, max_memorized_time_per_msg: int) -> list:
|
||||
"""从数据库中随机获取指定时间戳附近的消息片段"""
|
||||
try_count = 0
|
||||
while try_count < 3:
|
||||
@@ -958,7 +966,8 @@ class Hippocampus:
|
||||
"""获取记忆图中所有节点的名字列表"""
|
||||
return list(self.memory_graph.G.nodes())
|
||||
|
||||
def calculate_node_hash(self, concept, memory_items) -> int:
|
||||
@staticmethod
|
||||
def calculate_node_hash(concept, memory_items) -> int:
|
||||
"""计算节点的特征值"""
|
||||
if not isinstance(memory_items, list):
|
||||
memory_items = [memory_items] if memory_items else []
|
||||
@@ -966,12 +975,14 @@ class Hippocampus:
|
||||
content = f"{concept}:{'|'.join(sorted_items)}"
|
||||
return hash(content)
|
||||
|
||||
def calculate_edge_hash(self, source, target) -> int:
|
||||
@staticmethod
|
||||
def calculate_edge_hash(source, target) -> int:
|
||||
"""计算边的特征值"""
|
||||
nodes = sorted([source, target])
|
||||
return hash(f"{nodes[0]}:{nodes[1]}")
|
||||
|
||||
def find_topic_llm(self, text, topic_num):
|
||||
@staticmethod
|
||||
def find_topic_llm(text, topic_num):
|
||||
prompt = (
|
||||
f"这是一段文字:{text}。请你从这段话中总结出最多{topic_num}个关键的概念,可以是名词,动词,或者特定人物,帮我列出来,"
|
||||
f"将主题用逗号隔开,并加上<>,例如<主题1>,<主题2>......尽可能精简。只需要列举最多{topic_num}个话题就好,不要有序号,不要告诉我其他内容。"
|
||||
@@ -979,14 +990,16 @@ class Hippocampus:
|
||||
)
|
||||
return prompt
|
||||
|
||||
def topic_what(self, text, topic, time_info):
|
||||
@staticmethod
|
||||
def topic_what(text, topic, time_info):
|
||||
prompt = (
|
||||
f'这是一段文字,{time_info}:{text}。我想让你基于这段文字来概括"{topic}"这个概念,帮我总结成一句自然的话,'
|
||||
f"可以包含时间和人物,以及具体的观点。只输出这句话就好"
|
||||
)
|
||||
return prompt
|
||||
|
||||
def calculate_topic_num(self, text, compress_rate):
|
||||
@staticmethod
|
||||
def calculate_topic_num(text, compress_rate):
|
||||
"""计算文本的话题数量"""
|
||||
information_content = calculate_information_content(text)
|
||||
topic_by_length = text.count("\n") * compress_rate
|
||||
@@ -1542,11 +1555,10 @@ class ParahippocampalGyrus:
|
||||
last_modified=current_time,
|
||||
)
|
||||
|
||||
for i in range(len(all_topics)):
|
||||
for j in range(i + 1, len(all_topics)):
|
||||
logger.debug(f"连接同批次节点: {all_topics[i]} 和 {all_topics[j]}")
|
||||
all_added_edges.append(f"{all_topics[i]}-{all_topics[j]}")
|
||||
self.memory_graph.connect_dot(all_topics[i], all_topics[j])
|
||||
for topic1, topic2 in combinations(all_topics, 2):
|
||||
logger.debug(f"连接同批次节点: {topic1} 和 {topic2}")
|
||||
all_added_edges.append(f"{topic1}-{topic2}")
|
||||
self.memory_graph.connect_dot(topic1, topic2)
|
||||
|
||||
logger.success(f"更新记忆: {', '.join(all_added_nodes)}")
|
||||
logger.debug(f"强化连接: {', '.join(all_added_edges)}")
|
||||
|
||||
@@ -1,95 +0,0 @@
|
||||
import unittest
|
||||
import asyncio
|
||||
import aiohttp
|
||||
from api import BaseMessageAPI
|
||||
from message_base import (
|
||||
BaseMessageInfo,
|
||||
UserInfo,
|
||||
GroupInfo,
|
||||
FormatInfo,
|
||||
MessageBase,
|
||||
Seg,
|
||||
)
|
||||
|
||||
|
||||
send_url = "http://localhost"
|
||||
receive_port = 18002 # 接收消息的端口
|
||||
send_port = 18000 # 发送消息的端口
|
||||
test_endpoint = "/api/message"
|
||||
|
||||
# 创建并启动API实例
|
||||
api = BaseMessageAPI(host="0.0.0.0", port=receive_port)
|
||||
|
||||
|
||||
class TestLiveAPI(unittest.IsolatedAsyncioTestCase):
|
||||
async def asyncSetUp(self):
|
||||
"""测试前的设置"""
|
||||
self.received_messages = []
|
||||
|
||||
async def message_handler(message):
|
||||
self.received_messages.append(message)
|
||||
|
||||
self.api = api
|
||||
self.api.register_message_handler(message_handler)
|
||||
self.server_task = asyncio.create_task(self.api.run())
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.sleep(1), timeout=5)
|
||||
except asyncio.TimeoutError:
|
||||
self.skipTest("服务器启动超时")
|
||||
|
||||
async def asyncTearDown(self):
|
||||
"""测试后的清理"""
|
||||
if hasattr(self, "server_task"):
|
||||
await self.api.stop() # 先调用正常的停止流程
|
||||
if not self.server_task.done():
|
||||
self.server_task.cancel()
|
||||
try:
|
||||
await asyncio.wait_for(self.server_task, timeout=100)
|
||||
except (asyncio.CancelledError, asyncio.TimeoutError):
|
||||
pass
|
||||
|
||||
async def test_send_and_receive_message(self):
|
||||
"""测试向运行中的API发送消息并接收响应"""
|
||||
# 准备测试消息
|
||||
user_info = UserInfo(user_id=12345678, user_nickname="测试用户", platform="qq")
|
||||
group_info = GroupInfo(group_id=12345678, group_name="测试群", platform="qq")
|
||||
format_info = FormatInfo(content_format=["text"], accept_format=["text", "emoji", "reply"])
|
||||
template_info = None
|
||||
message_info = BaseMessageInfo(
|
||||
platform="qq",
|
||||
message_id=12345678,
|
||||
time=12345678,
|
||||
group_info=group_info,
|
||||
user_info=user_info,
|
||||
format_info=format_info,
|
||||
template_info=template_info,
|
||||
)
|
||||
message = MessageBase(
|
||||
message_info=message_info,
|
||||
raw_message="测试消息",
|
||||
message_segment=Seg(type="text", data="测试消息"),
|
||||
)
|
||||
test_message = message.to_dict()
|
||||
|
||||
# 发送测试消息到发送端口
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{send_url}:{send_port}{test_endpoint}",
|
||||
json=test_message,
|
||||
) as response:
|
||||
response_data = await response.json()
|
||||
self.assertEqual(response.status, 200)
|
||||
self.assertEqual(response_data["status"], "success")
|
||||
try:
|
||||
async with asyncio.timeout(5): # 设置5秒超时
|
||||
while len(self.received_messages) == 0:
|
||||
await asyncio.sleep(0.1)
|
||||
received_message = self.received_messages[0]
|
||||
print(received_message)
|
||||
self.received_messages.clear()
|
||||
except asyncio.TimeoutError:
|
||||
self.fail("等待接收消息超时")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -72,7 +72,8 @@ class PersonInfoManager:
|
||||
self.person_name_list[doc["person_id"]] = doc["person_name"]
|
||||
logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称")
|
||||
|
||||
def get_person_id(self, platform: str, user_id: int):
|
||||
@staticmethod
|
||||
def get_person_id(platform: str, user_id: int):
|
||||
"""获取唯一id"""
|
||||
# 如果platform中存在-,就截取-后面的部分
|
||||
if "-" in platform:
|
||||
@@ -91,7 +92,8 @@ class PersonInfoManager:
|
||||
else:
|
||||
return False
|
||||
|
||||
async def create_person_info(self, person_id: str, data: dict = None):
|
||||
@staticmethod
|
||||
async def create_person_info(person_id: str, data: dict = None):
|
||||
"""创建一个项"""
|
||||
if not person_id:
|
||||
logger.debug("创建失败,personid不存在")
|
||||
@@ -131,7 +133,8 @@ class PersonInfoManager:
|
||||
else:
|
||||
return False
|
||||
|
||||
def _extract_json_from_text(self, text: str) -> dict:
|
||||
@staticmethod
|
||||
def _extract_json_from_text(text: str) -> dict:
|
||||
"""从文本中提取JSON数据的高容错方法"""
|
||||
try:
|
||||
# 尝试直接解析
|
||||
@@ -225,7 +228,8 @@ class PersonInfoManager:
|
||||
logger.error(f"在{max_retries}次尝试后仍未能生成唯一昵称")
|
||||
return None
|
||||
|
||||
async def del_one_document(self, person_id: str):
|
||||
@staticmethod
|
||||
async def del_one_document(person_id: str):
|
||||
"""删除指定 person_id 的文档"""
|
||||
if not person_id:
|
||||
logger.debug("删除失败:person_id 不能为空")
|
||||
@@ -237,7 +241,8 @@ class PersonInfoManager:
|
||||
else:
|
||||
logger.debug(f"删除失败:未找到 person_id={person_id}")
|
||||
|
||||
async def get_value(self, person_id: str, field_name: str):
|
||||
@staticmethod
|
||||
async def get_value(person_id: str, field_name: str):
|
||||
"""获取指定person_id文档的字段值,若不存在该字段,则返回该字段的全局默认值"""
|
||||
if not person_id:
|
||||
logger.debug("get_value获取失败:person_id不能为空")
|
||||
@@ -256,7 +261,8 @@ class PersonInfoManager:
|
||||
logger.trace(f"获取{person_id}的{field_name}失败,已返回默认值{default_value}")
|
||||
return default_value
|
||||
|
||||
async def get_values(self, person_id: str, field_names: list) -> dict:
|
||||
@staticmethod
|
||||
async def get_values(person_id: str, field_names: list) -> dict:
|
||||
"""获取指定person_id文档的多个字段值,若不存在该字段,则返回该字段的全局默认值"""
|
||||
if not person_id:
|
||||
logger.debug("get_values获取失败:person_id不能为空")
|
||||
@@ -281,7 +287,8 @@ class PersonInfoManager:
|
||||
|
||||
return result
|
||||
|
||||
async def del_all_undefined_field(self):
|
||||
@staticmethod
|
||||
async def del_all_undefined_field():
|
||||
"""删除所有项里的未定义字段"""
|
||||
# 获取所有已定义的字段名
|
||||
defined_fields = set(person_info_default.keys())
|
||||
@@ -307,8 +314,8 @@ class PersonInfoManager:
|
||||
logger.error(f"清理未定义字段时出错: {e}")
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
async def get_specific_value_list(
|
||||
self,
|
||||
field_name: str,
|
||||
way: Callable[[Any], bool], # 接受任意类型值
|
||||
) -> Dict[str, Any]:
|
||||
|
||||
@@ -62,7 +62,7 @@ class RelationshipManager:
|
||||
def mood_feedback(self, value):
|
||||
"""情绪反馈"""
|
||||
mood_manager = self.mood_manager
|
||||
mood_gain = (mood_manager.get_current_mood().valence) ** 2 * math.copysign(
|
||||
mood_gain = mood_manager.get_current_mood().valence ** 2 * math.copysign(
|
||||
1, value * mood_manager.get_current_mood().valence
|
||||
)
|
||||
value += value * mood_gain
|
||||
@@ -77,24 +77,27 @@ class RelationshipManager:
|
||||
else:
|
||||
return mood_value / coefficient
|
||||
|
||||
async def is_known_some_one(self, platform, user_id):
|
||||
@staticmethod
|
||||
async def is_known_some_one(platform, user_id):
|
||||
"""判断是否认识某人"""
|
||||
is_known = person_info_manager.is_person_known(platform, user_id)
|
||||
return is_known
|
||||
|
||||
async def is_qved_name(self, platform, user_id):
|
||||
@staticmethod
|
||||
async def is_qved_name(platform, user_id):
|
||||
"""判断是否认识某人"""
|
||||
person_id = person_info_manager.get_person_id(platform, user_id)
|
||||
is_qved = await person_info_manager.has_one_field(person_id, "person_name")
|
||||
old_name = await person_info_manager.get_value(person_id, "person_name")
|
||||
print(f"old_name: {old_name}")
|
||||
print(f"is_qved: {is_qved}")
|
||||
if is_qved and old_name != None:
|
||||
if is_qved and old_name is not None:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
async def first_knowing_some_one(self, platform, user_id, user_nickname, user_cardname, user_avatar):
|
||||
@staticmethod
|
||||
async def first_knowing_some_one(platform, user_id, user_nickname, user_cardname, user_avatar):
|
||||
"""判断是否认识某人"""
|
||||
person_id = person_info_manager.get_person_id(platform, user_id)
|
||||
await person_info_manager.update_one_field(person_id, "nickname", user_nickname)
|
||||
@@ -102,7 +105,8 @@ class RelationshipManager:
|
||||
# await person_info_manager.update_one_field(person_id, "user_avatar", user_avatar)
|
||||
await person_info_manager.qv_person_name(person_id, user_nickname, user_cardname, user_avatar)
|
||||
|
||||
async def convert_all_person_sign_to_person_name(self, input_text: str):
|
||||
@staticmethod
|
||||
async def convert_all_person_sign_to_person_name(input_text: str):
|
||||
"""将所有人的<platform:user_id:nickname:cardname>格式转换为person_name"""
|
||||
try:
|
||||
# 使用正则表达式匹配<platform:user_id:nickname:cardname>格式
|
||||
@@ -119,7 +123,7 @@ class RelationshipManager:
|
||||
person_name = nickname.strip() if nickname.strip() else cardname.strip()
|
||||
|
||||
if person_id in all_person:
|
||||
if all_person[person_id] != None:
|
||||
if all_person[person_id] is not None:
|
||||
person_name = all_person[person_id]
|
||||
|
||||
print(f"将<{platform}:{user_id}:{nickname}:{cardname}>替换为{person_name}")
|
||||
@@ -326,7 +330,8 @@ class RelationshipManager:
|
||||
f"回复态度为{relation_prompt2_list[level_num]},关系等级为{level_num}。"
|
||||
)
|
||||
|
||||
def calculate_level_num(self, relationship_value) -> int:
|
||||
@staticmethod
|
||||
def calculate_level_num(relationship_value) -> int:
|
||||
"""关系等级计算"""
|
||||
if -1000 <= relationship_value < -227:
|
||||
level_num = 0
|
||||
@@ -344,7 +349,8 @@ class RelationshipManager:
|
||||
level_num = 5 if relationship_value > 1000 else 0
|
||||
return level_num
|
||||
|
||||
def ensure_float(self, value, person_id):
|
||||
@staticmethod
|
||||
def ensure_float(value, person_id):
|
||||
"""确保返回浮点数,转换失败返回0.0"""
|
||||
if isinstance(value, float):
|
||||
return value
|
||||
|
||||
@@ -100,7 +100,8 @@ class InfoCatcher:
|
||||
self.trigger_response_message, first_bot_msg
|
||||
)
|
||||
|
||||
def get_message_from_db_between_msgs(self, message_start: Message, message_end: Message):
|
||||
@staticmethod
|
||||
def get_message_from_db_between_msgs(message_start: Message, message_end: Message):
|
||||
try:
|
||||
# 从数据库中获取消息的时间戳
|
||||
time_start = message_start.message_info.time
|
||||
@@ -155,7 +156,8 @@ class InfoCatcher:
|
||||
|
||||
return result
|
||||
|
||||
def message_to_dict(self, message):
|
||||
@staticmethod
|
||||
def message_to_dict(message):
|
||||
if not message:
|
||||
return None
|
||||
if isinstance(message, dict):
|
||||
|
||||
@@ -235,6 +235,7 @@ class ScheduleGenerator:
|
||||
|
||||
Args:
|
||||
num (int): 需要获取的日程数量,默认为1
|
||||
time_info (bool): 是否包含时间信息,默认为False
|
||||
|
||||
Returns:
|
||||
list: 最新加入的日程列表
|
||||
@@ -267,7 +268,8 @@ class ScheduleGenerator:
|
||||
db.schedule.update_one({"date": date_str}, {"$set": schedule_data}, upsert=True)
|
||||
logger.debug(f"已保存{date_str}的日程到数据库")
|
||||
|
||||
def load_schedule_from_db(self, date: datetime.datetime):
|
||||
@staticmethod
|
||||
def load_schedule_from_db(date: datetime.datetime):
|
||||
"""从数据库加载日程,同时加载 today_done_list"""
|
||||
date_str = date.strftime("%Y-%m-%d")
|
||||
existing_schedule = db.schedule.find_one({"date": date_str})
|
||||
|
||||
@@ -10,7 +10,8 @@ logger = get_module_logger("message_storage")
|
||||
|
||||
|
||||
class MessageStorage:
|
||||
async def store_message(self, message: Union[MessageSending, MessageRecv], chat_stream: ChatStream) -> None:
|
||||
@staticmethod
|
||||
async def store_message(message: Union[MessageSending, MessageRecv], chat_stream: ChatStream) -> None:
|
||||
"""存储消息到数据库"""
|
||||
try:
|
||||
# 莫越权 救世啊
|
||||
@@ -43,7 +44,8 @@ class MessageStorage:
|
||||
except Exception:
|
||||
logger.exception("存储消息失败")
|
||||
|
||||
async def store_recalled_message(self, message_id: str, time: str, chat_stream: ChatStream) -> None:
|
||||
@staticmethod
|
||||
async def store_recalled_message(message_id: str, time: str, chat_stream: ChatStream) -> None:
|
||||
"""存储撤回消息到数据库"""
|
||||
if "recalled_messages" not in db.list_collection_names():
|
||||
db.create_collection("recalled_messages")
|
||||
@@ -58,7 +60,8 @@ class MessageStorage:
|
||||
except Exception:
|
||||
logger.exception("存储撤回消息失败")
|
||||
|
||||
async def remove_recalled_message(self, time: str) -> None:
|
||||
@staticmethod
|
||||
async def remove_recalled_message(time: str) -> None:
|
||||
"""删除撤回消息"""
|
||||
try:
|
||||
db.recalled_messages.delete_many({"time": {"$lt": time - 300}})
|
||||
|
||||
@@ -28,7 +28,7 @@ class TopicIdentifier:
|
||||
|
||||
消息内容:{text}"""
|
||||
|
||||
# 使用 LLM_request 类进行请求
|
||||
# 使用 LLMRequest 类进行请求
|
||||
try:
|
||||
topic, _, _ = await self.llm_topic_judge.generate_response(prompt)
|
||||
except Exception as e:
|
||||
|
||||
@@ -24,7 +24,8 @@ class LLMStatistics:
|
||||
self._init_database()
|
||||
self.name_dict: Dict[List] = {}
|
||||
|
||||
def _init_database(self):
|
||||
@staticmethod
|
||||
def _init_database():
|
||||
"""初始化数据库集合"""
|
||||
if "online_time" not in db.list_collection_names():
|
||||
db.create_collection("online_time")
|
||||
@@ -51,7 +52,8 @@ class LLMStatistics:
|
||||
if self.console_thread:
|
||||
self.console_thread.join()
|
||||
|
||||
def _record_online_time(self):
|
||||
@staticmethod
|
||||
def _record_online_time():
|
||||
"""记录在线时间"""
|
||||
current_time = datetime.now()
|
||||
# 检查5分钟内是否已有记录
|
||||
@@ -187,7 +189,7 @@ class LLMStatistics:
|
||||
|
||||
# 按模型统计
|
||||
output.append("按模型统计:")
|
||||
output.append(("模型名称 调用次数 Token总量 累计花费"))
|
||||
output.append("模型名称 调用次数 Token总量 累计花费")
|
||||
for model_name, count in sorted(stats["requests_by_model"].items()):
|
||||
tokens = stats["tokens_by_model"][model_name]
|
||||
cost = stats["costs_by_model"][model_name]
|
||||
@@ -198,7 +200,7 @@ class LLMStatistics:
|
||||
|
||||
# 按请求类型统计
|
||||
output.append("按请求类型统计:")
|
||||
output.append(("模型名称 调用次数 Token总量 累计花费"))
|
||||
output.append("模型名称 调用次数 Token总量 累计花费")
|
||||
for req_type, count in sorted(stats["requests_by_type"].items()):
|
||||
tokens = stats["tokens_by_type"][req_type]
|
||||
cost = stats["costs_by_type"][req_type]
|
||||
@@ -209,7 +211,7 @@ class LLMStatistics:
|
||||
|
||||
# 修正用户统计列宽
|
||||
output.append("按用户统计:")
|
||||
output.append(("用户ID 调用次数 Token总量 累计花费"))
|
||||
output.append("用户ID 调用次数 Token总量 累计花费")
|
||||
for user_id, count in sorted(stats["requests_by_user"].items()):
|
||||
tokens = stats["tokens_by_user"][user_id]
|
||||
cost = stats["costs_by_user"][user_id]
|
||||
@@ -225,7 +227,7 @@ class LLMStatistics:
|
||||
|
||||
# 添加聊天统计
|
||||
output.append("群组统计:")
|
||||
output.append(("群组名称 消息数量"))
|
||||
output.append("群组名称 消息数量")
|
||||
for group_id, count in sorted(stats["messages_by_chat"].items()):
|
||||
output.append(f"{self.name_dict[group_id][0][:32]:<32} {count:>10}")
|
||||
|
||||
@@ -246,7 +248,7 @@ class LLMStatistics:
|
||||
|
||||
# 按模型统计
|
||||
output.append("按模型统计:")
|
||||
output.append(("模型名称 调用次数 Token总量 累计花费"))
|
||||
output.append("模型名称 调用次数 Token总量 累计花费")
|
||||
for model_name, count in sorted(stats["requests_by_model"].items()):
|
||||
tokens = stats["tokens_by_model"][model_name]
|
||||
cost = stats["costs_by_model"][model_name]
|
||||
@@ -284,7 +286,7 @@ class LLMStatistics:
|
||||
|
||||
# 添加聊天统计
|
||||
output.append("群组统计:")
|
||||
output.append(("群组名称 消息数量"))
|
||||
output.append("群组名称 消息数量")
|
||||
for group_id, count in sorted(stats["messages_by_chat"].items()):
|
||||
output.append(f"{self.name_dict[group_id][0][:32]:<32} {count:>10}")
|
||||
|
||||
|
||||
@@ -90,7 +90,8 @@ class Timer:
|
||||
self.auto_unit = auto_unit
|
||||
self.start = None
|
||||
|
||||
def _validate_types(self, name, storage):
|
||||
@staticmethod
|
||||
def _validate_types(name, storage):
|
||||
"""类型检查"""
|
||||
if name is not None and not isinstance(name, str):
|
||||
raise TimerTypeError("name", "Optional[str]", type(name))
|
||||
|
||||
@@ -77,7 +77,8 @@ class ChineseTypoGenerator:
|
||||
|
||||
return normalized_freq
|
||||
|
||||
def _create_pinyin_dict(self):
|
||||
@staticmethod
|
||||
def _create_pinyin_dict():
|
||||
"""
|
||||
创建拼音到汉字的映射字典
|
||||
"""
|
||||
@@ -95,7 +96,8 @@ class ChineseTypoGenerator:
|
||||
|
||||
return pinyin_dict
|
||||
|
||||
def _is_chinese_char(self, char):
|
||||
@staticmethod
|
||||
def _is_chinese_char(char):
|
||||
"""
|
||||
判断是否为汉字
|
||||
"""
|
||||
@@ -124,7 +126,8 @@ class ChineseTypoGenerator:
|
||||
|
||||
return result
|
||||
|
||||
def _get_similar_tone_pinyin(self, py):
|
||||
@staticmethod
|
||||
def _get_similar_tone_pinyin(py):
|
||||
"""
|
||||
获取相似声调的拼音
|
||||
"""
|
||||
@@ -211,13 +214,15 @@ class ChineseTypoGenerator:
|
||||
# 返回概率最高的几个字
|
||||
return [char for char, _ in candidates_with_prob[:num_candidates]]
|
||||
|
||||
def _get_word_pinyin(self, word):
|
||||
@staticmethod
|
||||
def _get_word_pinyin(word):
|
||||
"""
|
||||
获取词语的拼音列表
|
||||
"""
|
||||
return [py[0] for py in pinyin(word, style=Style.TONE3)]
|
||||
|
||||
def _segment_sentence(self, sentence):
|
||||
@staticmethod
|
||||
def _segment_sentence(sentence):
|
||||
"""
|
||||
使用jieba分词,返回词语列表
|
||||
"""
|
||||
@@ -392,7 +397,8 @@ class ChineseTypoGenerator:
|
||||
|
||||
return "".join(result), correction_suggestion
|
||||
|
||||
def format_typo_info(self, typo_info):
|
||||
@staticmethod
|
||||
def format_typo_info(typo_info):
|
||||
"""
|
||||
格式化错别字信息
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ llmcheck 模式:
|
||||
|
||||
import time
|
||||
from loguru import logger
|
||||
from ..models.utils_model import LLM_request
|
||||
from ..models.utils_model import LLMRequest
|
||||
from ...config.config import global_config
|
||||
|
||||
# from ..chat.chat_stream import ChatStream
|
||||
@@ -61,7 +61,7 @@ def llmcheck_decorator(trigger_condition_func):
|
||||
class LlmcheckWillingManager(MxpWillingManager):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model_v3 = LLM_request(model=global_config.llm_normal, temperature=0.3)
|
||||
self.model_v3 = LLMRequest(model=global_config.llm_normal, temperature=0.3)
|
||||
|
||||
async def get_llmreply_probability(self, message_id: str):
|
||||
message_info = self.ongoing_messages[message_id]
|
||||
|
||||
@@ -240,7 +240,8 @@ class MxpWillingManager(BaseWillingManager):
|
||||
-2 * self.basic_maximum_willing * self.fatigue_coefficient
|
||||
)
|
||||
|
||||
def _willing_to_probability(self, willing: float) -> float:
|
||||
@staticmethod
|
||||
def _willing_to_probability(willing: float) -> float:
|
||||
"""意愿值转化为概率"""
|
||||
willing = max(0, willing)
|
||||
if willing < 2:
|
||||
@@ -285,7 +286,8 @@ class MxpWillingManager(BaseWillingManager):
|
||||
if self.is_debug:
|
||||
self.logger.debug(f"聊天流意愿值更新:{self.chat_reply_willing}")
|
||||
|
||||
def _get_relationship_level_num(self, relationship_value) -> int:
|
||||
@staticmethod
|
||||
def _get_relationship_level_num(relationship_value) -> int:
|
||||
"""关系等级计算"""
|
||||
if -1000 <= relationship_value < -227:
|
||||
level_num = 0
|
||||
|
||||
@@ -35,12 +35,14 @@ class KnowledgeLibrary:
|
||||
"""确保必要的目录存在"""
|
||||
os.makedirs(self.raw_info_dir, exist_ok=True)
|
||||
|
||||
def read_file(self, file_path: str) -> str:
|
||||
@staticmethod
|
||||
def read_file(file_path: str) -> str:
|
||||
"""读取文件内容"""
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
return f.read()
|
||||
|
||||
def split_content(self, content: str, max_length: int = 512) -> list:
|
||||
@staticmethod
|
||||
def split_content(content: str, max_length: int = 512) -> list:
|
||||
"""将内容分割成适当大小的块,按空行分割
|
||||
|
||||
Args:
|
||||
@@ -146,7 +148,8 @@ class KnowledgeLibrary:
|
||||
|
||||
return result
|
||||
|
||||
def _update_stats(self, total_stats, result, filename):
|
||||
@staticmethod
|
||||
def _update_stats(total_stats, result, filename):
|
||||
"""更新总体统计信息"""
|
||||
if result["status"] == "success":
|
||||
total_stats["processed_files"] += 1
|
||||
@@ -181,7 +184,8 @@ class KnowledgeLibrary:
|
||||
for filename in stats["skipped_files"]:
|
||||
self.console.print(f"[yellow]- {filename}[/yellow]")
|
||||
|
||||
def calculate_file_hash(self, file_path):
|
||||
@staticmethod
|
||||
def calculate_file_hash(file_path):
|
||||
"""计算文件的MD5哈希值"""
|
||||
hash_md5 = hashlib.md5()
|
||||
with open(file_path, "rb") as f:
|
||||
|
||||
Reference in New Issue
Block a user