This commit is contained in:
tcmofashi
2025-04-17 15:53:26 +08:00
52 changed files with 299 additions and 336 deletions

BIN
.DS_Store vendored Normal file

Binary file not shown.

View File

@@ -1,6 +1,5 @@
.git .git
__pycache__ __pycache__
*.pyc
*.pyo *.pyo
*.pyd *.pyd
.DS_Store .DS_Store

1
.gitignore vendored
View File

@@ -239,6 +239,5 @@ logs
.vscode .vscode
/config/* /config/*
run_none.bat
config/old/bot_config_20250405_212257.toml config/old/bot_config_20250405_212257.toml

View File

@@ -19,7 +19,6 @@
● [我有问题](#我有问题) ● [我有问题](#我有问题)
● [我想做贡献](#我想做贡献) ● [我想做贡献](#我想做贡献)
● [我想报告BUG](#报告BUG)
● [我想提出建议](#提出建议) ● [我想提出建议](#提出建议)
## 我有问题 ## 我有问题

View File

@@ -114,7 +114,7 @@
## 🎯 功能介绍 ## 🎯 功能介绍
| 模块 | 主要功能 | 特点 | | 模块 | 主要功能 | 特点 |
|------|---------|------| |----------|------------------------------------------------------------------|-------|
| 💬 聊天系统 | • 心流/推理聊天<br>• 关键词主动发言<br>• 多模型支持<br>• 动态prompt构建<br>• 私聊功能(PFC) | 拟人化交互 | | 💬 聊天系统 | • 心流/推理聊天<br>• 关键词主动发言<br>• 多模型支持<br>• 动态prompt构建<br>• 私聊功能(PFC) | 拟人化交互 |
| 🧠 心流系统 | • 实时思考生成<br>• 自动启停机制<br>• 日程系统联动<br>• 工具调用能力 | 智能化决策 | | 🧠 心流系统 | • 实时思考生成<br>• 自动启停机制<br>• 日程系统联动<br>• 工具调用能力 | 智能化决策 |
| 🧠 记忆系统 | • 优化记忆抽取<br>• 海马体记忆机制<br>• 聊天记录概括 | 持久化记忆 | | 🧠 记忆系统 | • 优化记忆抽取<br>• 海马体记忆机制<br>• 聊天记录概括 | 持久化记忆 |

View File

@@ -47,7 +47,7 @@ if not SIMPLE_OUTPUT:
"<cyan>{extra[module]: <12}</cyan> | " "<cyan>{extra[module]: <12}</cyan> | "
"<level>{message}</level>" "<level>{message}</level>"
), ),
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | {message}"), "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | {message}",
"log_dir": LOG_ROOT, "log_dir": LOG_ROOT,
"rotation": "00:00", "rotation": "00:00",
"retention": "3 days", "retention": "3 days",
@@ -59,8 +59,8 @@ else:
"console_level": "INFO", "console_level": "INFO",
"file_level": "DEBUG", "file_level": "DEBUG",
# 格式配置 # 格式配置
"console_format": ("<green>{time:MM-DD HH:mm}</green> | <cyan>{extra[module]}</cyan> | {message}"), "console_format": "<green>{time:MM-DD HH:mm}</green> | <cyan>{extra[module]}</cyan> | {message}",
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | {message}"), "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | {message}",
"log_dir": LOG_ROOT, "log_dir": LOG_ROOT,
"rotation": "00:00", "rotation": "00:00",
"retention": "3 days", "retention": "3 days",
@@ -78,13 +78,13 @@ MEMORY_STYLE_CONFIG = {
"<light-yellow>海马体</light-yellow> | " "<light-yellow>海马体</light-yellow> | "
"<level>{message}</level>" "<level>{message}</level>"
), ),
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 海马体 | {message}"), "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 海马体 | {message}",
}, },
"simple": { "simple": {
"console_format": ( "console_format": (
"<green>{time:MM-DD HH:mm}</green> | <light-yellow>海马体</light-yellow> | <light-yellow>{message}</light-yellow>" "<green>{time:MM-DD HH:mm}</green> | <light-yellow>海马体</light-yellow> | <light-yellow>{message}</light-yellow>"
), ),
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 海马体 | {message}"), "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 海马体 | {message}",
}, },
} }
@@ -99,11 +99,11 @@ MOOD_STYLE_CONFIG = {
"<light-green>心情</light-green> | " "<light-green>心情</light-green> | "
"<level>{message}</level>" "<level>{message}</level>"
), ),
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 心情 | {message}"), "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 心情 | {message}",
}, },
"simple": { "simple": {
"console_format": ("<green>{time:MM-DD HH:mm}</green> | <magenta>心情</magenta> | {message}"), "console_format": "<green>{time:MM-DD HH:mm}</green> | <magenta>心情</magenta> | {message}",
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 心情 | {message}"), "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 心情 | {message}",
}, },
} }
# tool use # tool use
@@ -116,11 +116,11 @@ TOOL_USE_STYLE_CONFIG = {
"<magenta>工具使用</magenta> | " "<magenta>工具使用</magenta> | "
"<level>{message}</level>" "<level>{message}</level>"
), ),
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 工具使用 | {message}"), "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 工具使用 | {message}",
}, },
"simple": { "simple": {
"console_format": ("<green>{time:MM-DD HH:mm}</green> | <magenta>工具使用</magenta> | {message}"), "console_format": "<green>{time:MM-DD HH:mm}</green> | <magenta>工具使用</magenta> | {message}",
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 工具使用 | {message}"), "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 工具使用 | {message}",
}, },
} }
@@ -135,11 +135,11 @@ RELATION_STYLE_CONFIG = {
"<light-magenta>关系</light-magenta> | " "<light-magenta>关系</light-magenta> | "
"<level>{message}</level>" "<level>{message}</level>"
), ),
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 关系 | {message}"), "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 关系 | {message}",
}, },
"simple": { "simple": {
"console_format": ("<green>{time:MM-DD HH:mm}</green> | <light-magenta>关系</light-magenta> | {message}"), "console_format": "<green>{time:MM-DD HH:mm}</green> | <light-magenta>关系</light-magenta> | {message}",
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 关系 | {message}"), "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 关系 | {message}",
}, },
} }
@@ -153,11 +153,11 @@ CONFIG_STYLE_CONFIG = {
"<light-cyan>配置</light-cyan> | " "<light-cyan>配置</light-cyan> | "
"<level>{message}</level>" "<level>{message}</level>"
), ),
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 配置 | {message}"), "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 配置 | {message}",
}, },
"simple": { "simple": {
"console_format": ("<green>{time:MM-DD HH:mm}</green> | <light-cyan>配置</light-cyan> | {message}"), "console_format": "<green>{time:MM-DD HH:mm}</green> | <light-cyan>配置</light-cyan> | {message}",
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 配置 | {message}"), "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 配置 | {message}",
}, },
} }
@@ -170,11 +170,11 @@ SENDER_STYLE_CONFIG = {
"<light-yellow>消息发送</light-yellow> | " "<light-yellow>消息发送</light-yellow> | "
"<level>{message}</level>" "<level>{message}</level>"
), ),
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 消息发送 | {message}"), "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 消息发送 | {message}",
}, },
"simple": { "simple": {
"console_format": ("<green>{time:MM-DD HH:mm}</green> | <green>消息发送</green> | {message}"), "console_format": "<green>{time:MM-DD HH:mm}</green> | <green>消息发送</green> | {message}",
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 消息发送 | {message}"), "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 消息发送 | {message}",
}, },
} }
@@ -187,13 +187,13 @@ HEARTFLOW_STYLE_CONFIG = {
"<light-yellow>麦麦大脑袋</light-yellow> | " "<light-yellow>麦麦大脑袋</light-yellow> | "
"<level>{message}</level>" "<level>{message}</level>"
), ),
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 麦麦大脑袋 | {message}"), "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 麦麦大脑袋 | {message}",
}, },
"simple": { "simple": {
"console_format": ( "console_format": (
"<green>{time:MM-DD HH:mm}</green> | <light-green>麦麦大脑袋</light-green> | <light-green>{message}</light-green>" "<green>{time:MM-DD HH:mm}</green> | <light-green>麦麦大脑袋</light-green> | <light-green>{message}</light-green>"
), # noqa: E501 ), # noqa: E501
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 麦麦大脑袋 | {message}"), "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 麦麦大脑袋 | {message}",
}, },
} }
@@ -206,11 +206,11 @@ SCHEDULE_STYLE_CONFIG = {
"<light-yellow>在干嘛</light-yellow> | " "<light-yellow>在干嘛</light-yellow> | "
"<level>{message}</level>" "<level>{message}</level>"
), ),
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 在干嘛 | {message}"), "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 在干嘛 | {message}",
}, },
"simple": { "simple": {
"console_format": ("<green>{time:MM-DD HH:mm}</green> | <cyan>在干嘛</cyan> | <cyan>{message}</cyan>"), "console_format": "<green>{time:MM-DD HH:mm}</green> | <cyan>在干嘛</cyan> | <cyan>{message}</cyan>",
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 在干嘛 | {message}"), "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 在干嘛 | {message}",
}, },
} }
@@ -223,11 +223,11 @@ LLM_STYLE_CONFIG = {
"<light-yellow>麦麦组织语言</light-yellow> | " "<light-yellow>麦麦组织语言</light-yellow> | "
"<level>{message}</level>" "<level>{message}</level>"
), ),
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 麦麦组织语言 | {message}"), "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 麦麦组织语言 | {message}",
}, },
"simple": { "simple": {
"console_format": ("<green>{time:MM-DD HH:mm}</green> | <light-green>麦麦组织语言</light-green> | {message}"), "console_format": "<green>{time:MM-DD HH:mm}</green> | <light-green>麦麦组织语言</light-green> | {message}",
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 麦麦组织语言 | {message}"), "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 麦麦组织语言 | {message}",
}, },
} }
@@ -242,11 +242,11 @@ TOPIC_STYLE_CONFIG = {
"<light-blue>话题</light-blue> | " "<light-blue>话题</light-blue> | "
"<level>{message}</level>" "<level>{message}</level>"
), ),
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 话题 | {message}"), "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 话题 | {message}",
}, },
"simple": { "simple": {
"console_format": ("<green>{time:MM-DD HH:mm}</green> | <light-blue>主题</light-blue> | {message}"), "console_format": "<green>{time:MM-DD HH:mm}</green> | <light-blue>主题</light-blue> | {message}",
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 话题 | {message}"), "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 话题 | {message}",
}, },
} }
@@ -260,13 +260,13 @@ CHAT_STYLE_CONFIG = {
"<light-blue>见闻</light-blue> | " "<light-blue>见闻</light-blue> | "
"<level>{message}</level>" "<level>{message}</level>"
), ),
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 见闻 | {message}"), "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 见闻 | {message}",
}, },
"simple": { "simple": {
"console_format": ( "console_format": (
"<green>{time:MM-DD HH:mm}</green> | <light-blue>见闻</light-blue> | <green>{message}</green>" "<green>{time:MM-DD HH:mm}</green> | <light-blue>见闻</light-blue> | <green>{message}</green>"
), # noqa: E501 ), # noqa: E501
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 见闻 | {message}"), "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 见闻 | {message}",
}, },
} }
@@ -279,13 +279,13 @@ SUB_HEARTFLOW_STYLE_CONFIG = {
"<light-blue>麦麦小脑袋</light-blue> | " "<light-blue>麦麦小脑袋</light-blue> | "
"<level>{message}</level>" "<level>{message}</level>"
), ),
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 麦麦小脑袋 | {message}"), "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 麦麦小脑袋 | {message}",
}, },
"simple": { "simple": {
"console_format": ( "console_format": (
"<green>{time:MM-DD HH:mm}</green> | <light-blue>麦麦小脑袋</light-blue> | <light-blue>{message}</light-blue>" "<green>{time:MM-DD HH:mm}</green> | <light-blue>麦麦小脑袋</light-blue> | <light-blue>{message}</light-blue>"
), # noqa: E501 ), # noqa: E501
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 麦麦小脑袋 | {message}"), "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 麦麦小脑袋 | {message}",
}, },
} }
@@ -298,17 +298,17 @@ WILLING_STYLE_CONFIG = {
"<light-blue>意愿</light-blue> | " "<light-blue>意愿</light-blue> | "
"<level>{message}</level>" "<level>{message}</level>"
), ),
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 意愿 | {message}"), "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 意愿 | {message}",
}, },
"simple": { "simple": {
"console_format": ("<green>{time:MM-DD HH:mm}</green> | <light-blue>意愿</light-blue> | {message}"), # noqa: E501 "console_format": "<green>{time:MM-DD HH:mm}</green> | <light-blue>意愿</light-blue> | {message}", # noqa: E501
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 意愿 | {message}"), "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 意愿 | {message}",
}, },
} }
CONFIRM_STYLE_CONFIG = { CONFIRM_STYLE_CONFIG = {
"console_format": ("<RED>{message}</RED>"), # noqa: E501 "console_format": "<RED>{message}</RED>", # noqa: E501
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | EULA与PRIVACY确认 | {message}"), "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | EULA与PRIVACY确认 | {message}",
} }
# 根据SIMPLE_OUTPUT选择配置 # 根据SIMPLE_OUTPUT选择配置
@@ -459,7 +459,7 @@ other_log_dir.mkdir(parents=True, exist_ok=True)
DEFAULT_FILE_HANDLER = logger.add( DEFAULT_FILE_HANDLER = logger.add(
sink=str(other_log_dir / "{time:YYYY-MM-DD}.log"), sink=str(other_log_dir / "{time:YYYY-MM-DD}.log"),
level=os.getenv("DEFAULT_FILE_LOG_LEVEL", "DEBUG"), level=os.getenv("DEFAULT_FILE_LOG_LEVEL", "DEBUG"),
format=("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name: <15} | {message}"), format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name: <15} | {message}",
rotation=DEFAULT_CONFIG["rotation"], rotation=DEFAULT_CONFIG["rotation"],
retention=DEFAULT_CONFIG["retention"], retention=DEFAULT_CONFIG["retention"],
compression=DEFAULT_CONFIG["compression"], compression=DEFAULT_CONFIG["compression"],

View File

@@ -130,9 +130,6 @@ def update_config():
logger.info("配置文件更新完成") logger.info("配置文件更新完成")
logger = get_module_logger("config")
@dataclass @dataclass
class BotConfig: class BotConfig:
"""机器人配置类""" """机器人配置类"""
@@ -284,21 +281,6 @@ class BotConfig:
llm_sub_heartflow: Dict[str, str] = field(default_factory=lambda: {}) llm_sub_heartflow: Dict[str, str] = field(default_factory=lambda: {})
llm_heartflow: Dict[str, str] = field(default_factory=lambda: {}) llm_heartflow: Dict[str, str] = field(default_factory=lambda: {})
build_memory_interval: int = 600 # 记忆构建间隔(秒)
forget_memory_interval: int = 600 # 记忆遗忘间隔(秒)
memory_forget_time: int = 24 # 记忆遗忘时间(小时)
memory_forget_percentage: float = 0.01 # 记忆遗忘比例
memory_compress_rate: float = 0.1 # 记忆压缩率
build_memory_sample_num: int = 10 # 记忆构建采样数量
build_memory_sample_length: int = 20 # 记忆构建采样长度
memory_build_distribution: list = field(
default_factory=lambda: [4, 2, 0.6, 24, 8, 0.4]
) # 记忆构建分布参数分布1均值标准差权重分布2均值标准差权重
memory_ban_words: list = field(
default_factory=lambda: ["表情包", "图片", "回复", "聊天记录"]
) # 添加新的配置项默认值
api_urls: Dict[str, str] = field(default_factory=lambda: {}) api_urls: Dict[str, str] = field(default_factory=lambda: {})
@staticmethod @staticmethod

View File

@@ -23,13 +23,12 @@ class ChangeMoodTool(BaseTool):
"required": ["text", "response_set"], "required": ["text", "response_set"],
} }
async def execute(self, function_args: Dict[str, Any], message_txt: str) -> Dict[str, Any]: async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]:
"""执行心情改变 """执行心情改变
Args: Args:
function_args: 工具参数 function_args: 工具参数
message_processed_plain_text: 原始消息文本 message_txt: 原始消息文本
response_set: 原始消息文本
Returns: Returns:
Dict: 工具执行结果 Dict: 工具执行结果

View File

@@ -1,4 +1,6 @@
# from src.plugins.person_info.relationship_manager import relationship_manager # from src.plugins.person_info.relationship_manager import relationship_manager
from typing import Dict, Any
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
from src.do_tool.tool_can_use.base_tool import BaseTool from src.do_tool.tool_can_use.base_tool import BaseTool
# from src.plugins.chat_module.think_flow_chat.think_flow_generator import ResponseGenerator # from src.plugins.chat_module.think_flow_chat.think_flow_generator import ResponseGenerator
@@ -20,22 +22,20 @@ class RelationshipTool(BaseTool):
"required": ["text", "changed_value", "reason"], "required": ["text", "changed_value", "reason"],
} }
async def execute(self, args: dict, message_txt: str) -> dict: async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> dict:
"""执行工具功能 """执行工具功能
Args: Args:
args: 包含工具参数的字典 function_args: 包含工具参数的字典
text: 原始消息文本 message_txt: 原始消息文本
changed_value: 变更值
reason: 变更原因
Returns: Returns:
dict: 包含执行结果的字典 dict: 包含执行结果的字典
""" """
try: try:
text = args.get("text") text = function_args.get("text")
changed_value = args.get("changed_value") changed_value = function_args.get("changed_value")
reason = args.get("reason") reason = function_args.get("reason")
return {"content": f"因为你刚刚因为{reason},所以你和发[{text}]这条消息的人的关系值变化为{changed_value}"} return {"content": f"因为你刚刚因为{reason},所以你和发[{text}]这条消息的人的关系值变化为{changed_value}"}

View File

@@ -49,8 +49,9 @@ class SearchKnowledgeTool(BaseTool):
logger.error(f"知识库搜索工具执行失败: {str(e)}") logger.error(f"知识库搜索工具执行失败: {str(e)}")
return {"name": "search_knowledge", "content": f"知识库搜索失败: {str(e)}"} return {"name": "search_knowledge", "content": f"知识库搜索失败: {str(e)}"}
@staticmethod
def get_info_from_db( def get_info_from_db(
self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False
) -> Union[str, list]: ) -> Union[str, list]:
"""从数据库中获取相关信息 """从数据库中获取相关信息

View File

@@ -17,7 +17,7 @@ class SendEmojiTool(BaseTool):
"required": ["text"], "required": ["text"],
} }
async def execute(self, function_args: Dict[str, Any], message_txt: str) -> Dict[str, Any]: async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]:
text = function_args.get("text", message_txt) text = function_args.get("text", message_txt)
return { return {
"name": "send_emoji", "name": "send_emoji",

View File

@@ -24,8 +24,9 @@ class ToolUser:
model=global_config.llm_tool_use, temperature=0.2, max_tokens=1000, request_type="tool_use" model=global_config.llm_tool_use, temperature=0.2, max_tokens=1000, request_type="tool_use"
) )
@staticmethod
async def _build_tool_prompt( async def _build_tool_prompt(
self, message_txt: str, sender_name: str, chat_stream: ChatStream, subheartflow: SubHeartflow = None message_txt: str, sender_name: str, chat_stream: ChatStream, subheartflow: SubHeartflow = None
): ):
"""构建工具使用的提示词 """构建工具使用的提示词
@@ -69,7 +70,8 @@ class ToolUser:
prompt += "你现在需要对群里的聊天内容进行回复,现在选择工具来对消息和你的回复进行处理,你是否需要额外的信息,比如回忆或者搜寻已有的知识,改变关系和情感,或者了解你现在正在做什么。" prompt += "你现在需要对群里的聊天内容进行回复,现在选择工具来对消息和你的回复进行处理,你是否需要额外的信息,比如回忆或者搜寻已有的知识,改变关系和情感,或者了解你现在正在做什么。"
return prompt return prompt
def _define_tools(self): @staticmethod
def _define_tools():
"""获取所有已注册工具的定义 """获取所有已注册工具的定义
Returns: Returns:
@@ -77,7 +79,8 @@ class ToolUser:
""" """
return get_all_tool_definitions() return get_all_tool_definitions()
async def _execute_tool_call(self, tool_call, message_txt: str): @staticmethod
async def _execute_tool_call(tool_call, message_txt: str):
"""执行特定的工具调用 """执行特定的工具调用
Args: Args:

View File

@@ -105,7 +105,8 @@ class Heartflow:
# 启动子心流更新任务 # 启动子心流更新任务
asyncio.create_task(self._sub_heartflow_update()) asyncio.create_task(self._sub_heartflow_update())
async def _update_current_state(self): @staticmethod
async def _update_current_state():
print("TODO") print("TODO")
async def do_a_thinking(self): async def do_a_thinking(self):

View File

@@ -161,7 +161,8 @@ class ChattingObservation(Observation):
# print(f"prompt{prompt}") # print(f"prompt{prompt}")
# print(f"self.observe_info{self.observe_info}") # print(f"self.observe_info{self.observe_info}")
def translate_message_list_to_str(self, talking_message): @staticmethod
def translate_message_list_to_str(talking_message):
talking_message_str = "" talking_message_str = ""
for message in talking_message: for message in talking_message:
talking_message_str += message["detailed_plain_text"] talking_message_str += message["detailed_plain_text"]

View File

@@ -102,7 +102,7 @@ class Identity:
random.shuffle(identity_detail) random.shuffle(identity_detail)
prompt_identity += identity_detail[0] prompt_identity += identity_detail[0]
elif level == 2: elif level == 2:
for detail in identity_detail: for detail in self.identity_detail:
prompt_identity += f",{detail}" prompt_identity += f",{detail}"
prompt_identity += "" prompt_identity += ""
return prompt_identity return prompt_identity

View File

@@ -131,14 +131,16 @@ class MainSystem:
] ]
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
async def build_memory_task(self): @staticmethod
async def build_memory_task():
"""记忆构建任务""" """记忆构建任务"""
while True: while True:
await asyncio.sleep(global_config.build_memory_interval) await asyncio.sleep(global_config.build_memory_interval)
logger.info("正在进行记忆构建") logger.info("正在进行记忆构建")
await HippocampusManager.get_instance().build_memory() await HippocampusManager.get_instance().build_memory()
async def forget_memory_task(self): @staticmethod
async def forget_memory_task():
"""记忆遗忘任务""" """记忆遗忘任务"""
while True: while True:
await asyncio.sleep(global_config.forget_memory_interval) await asyncio.sleep(global_config.forget_memory_interval)
@@ -152,7 +154,8 @@ class MainSystem:
self.mood_manager.print_mood_status() self.mood_manager.print_mood_status()
await asyncio.sleep(30) await asyncio.sleep(30)
async def remove_recalled_message_task(self): @staticmethod
async def remove_recalled_message_task():
"""删除撤回消息任务""" """删除撤回消息任务"""
while True: while True:
try: try:

View File

@@ -119,7 +119,6 @@ class ChatObserver:
self.last_cold_chat_check = current_time self.last_cold_chat_check = current_time
# 判断是否冷场 # 判断是否冷场
is_cold = False
if self.last_message_time is None: if self.last_message_time is None:
is_cold = True is_cold = True
else: else:

View File

@@ -113,7 +113,8 @@ class Conversation:
return True return True
return False return False
def _convert_to_message(self, msg_dict: Dict[str, Any]) -> Message: @staticmethod
def _convert_to_message(msg_dict: Dict[str, Any]) -> Message:
"""将消息字典转换为Message对象""" """将消息字典转换为Message对象"""
try: try:
chat_info = msg_dict.get("chat_info", {}) chat_info = msg_dict.get("chat_info", {})
@@ -123,7 +124,7 @@ class Conversation:
return Message( return Message(
message_id=msg_dict["message_id"], message_id=msg_dict["message_id"],
chat_stream=chat_stream, chat_stream=chat_stream,
time=msg_dict["time"], timestamp=msg_dict["time"],
user_info=user_info, user_info=user_info,
processed_plain_text=msg_dict.get("processed_plain_text", ""), processed_plain_text=msg_dict.get("processed_plain_text", ""),
detailed_plain_text=msg_dict.get("detailed_plain_text", ""), detailed_plain_text=msg_dict.get("detailed_plain_text", ""),

View File

@@ -15,8 +15,8 @@ class DirectMessageSender:
def __init__(self): def __init__(self):
pass pass
@staticmethod
async def send_message( async def send_message(
self,
chat_stream: ChatStream, chat_stream: ChatStream,
content: str, content: str,
reply_to_message: Optional[Message] = None, reply_to_message: Optional[Message] = None,

View File

@@ -51,11 +51,9 @@ class MongoDBMessageStorage(MessageStorage):
"""MongoDB消息存储实现""" """MongoDB消息存储实现"""
async def get_messages_after(self, chat_id: str, message_time: float) -> List[Dict[str, Any]]: async def get_messages_after(self, chat_id: str, message_time: float) -> List[Dict[str, Any]]:
query = {"chat_id": chat_id} query = {"chat_id": chat_id, "time": {"$gt": message_time}}
# print(f"storage_check_message: {message_time}") # print(f"storage_check_message: {message_time}")
query["time"] = {"$gt": message_time}
return list(db.messages.find(query).sort("time", 1)) return list(db.messages.find(query).sort("time", 1))
async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]: async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]:

View File

@@ -160,16 +160,16 @@ class GoalAnalyzer:
# 返回第一个目标作为当前主要目标(如果有) # 返回第一个目标作为当前主要目标(如果有)
if result: if result:
first_goal = result[0] first_goal = result[0]
return (first_goal.get("goal", ""), "", first_goal.get("reasoning", "")) return first_goal.get("goal", ""), "", first_goal.get("reasoning", "")
else: else:
# 单个目标的情况 # 单个目标的情况
goal = result.get("goal", "") goal = result.get("goal", "")
reasoning = result.get("reasoning", "") reasoning = result.get("reasoning", "")
conversation_info.goal_list.append((goal, reasoning)) conversation_info.goal_list.append((goal, reasoning))
return (goal, "", reasoning) return goal, "", reasoning
# 如果解析失败,返回默认值 # 如果解析失败,返回默认值
return ("", "", "") return "", "", ""
async def _update_goals(self, new_goal: str, method: str, reasoning: str): async def _update_goals(self, new_goal: str, method: str, reasoning: str):
"""更新目标列表 """更新目标列表
@@ -195,7 +195,8 @@ class GoalAnalyzer:
if len(self.goals) > self.max_goals: if len(self.goals) > self.max_goals:
self.goals.pop() # 移除最老的目标 self.goals.pop() # 移除最老的目标
def _calculate_similarity(self, goal1: str, goal2: str) -> float: @staticmethod
def _calculate_similarity(goal1: str, goal2: str) -> float:
"""简单计算两个目标之间的相似度 """简单计算两个目标之间的相似度
这里使用一个简单的实现,实际可以使用更复杂的文本相似度算法 这里使用一个简单的实现,实际可以使用更复杂的文本相似度算法
@@ -299,7 +300,8 @@ class DirectMessageSender:
self.logger = get_module_logger("direct_sender") self.logger = get_module_logger("direct_sender")
self.storage = MessageStorage() self.storage = MessageStorage()
async def send_via_ws(self, message: MessageSending) -> None: @staticmethod
async def send_via_ws(message: MessageSending) -> None:
try: try:
await global_api.send_message(message) await global_api.send_message(message)
except Exception as e: except Exception as e:

View File

@@ -19,7 +19,8 @@ class KnowledgeFetcher:
request_type="knowledge_fetch", request_type="knowledge_fetch",
) )
async def fetch(self, query: str, chat_history: List[Message]) -> Tuple[str, str]: @staticmethod
async def fetch(query: str, chat_history: List[Message]) -> Tuple[str, str]:
"""获取相关知识 """获取相关知识
Args: Args:

View File

@@ -30,11 +30,8 @@ class ReplyGenerator:
"""生成回复 """生成回复
Args: Args:
goal: 对话目标 observation_info: 观察信息
chat_history: 聊天历史 conversation_info: 对话信息
knowledge_cache: 知识缓存
previous_reply: 上一次生成的回复(如果有)
retry_count: 当前重试次数
Returns: Returns:
str: 生成的回复 str: 生成的回复

View File

@@ -17,6 +17,5 @@ __all__ = [
"relationship_manager", "relationship_manager",
"MoodManager", "MoodManager",
"willing_manager", "willing_manager",
"hippocampus",
"bot_schedule", "bot_schedule",
] ]

View File

@@ -103,7 +103,8 @@ class ChatManager:
except Exception as e: except Exception as e:
logger.error(f"聊天流自动保存失败: {str(e)}") logger.error(f"聊天流自动保存失败: {str(e)}")
def _ensure_collection(self): @staticmethod
def _ensure_collection():
"""确保数据库集合存在并创建索引""" """确保数据库集合存在并创建索引"""
if "chat_streams" not in db.list_collection_names(): if "chat_streams" not in db.list_collection_names():
db.create_collection("chat_streams") db.create_collection("chat_streams")
@@ -111,7 +112,8 @@ class ChatManager:
db.chat_streams.create_index([("stream_id", 1)], unique=True) db.chat_streams.create_index([("stream_id", 1)], unique=True)
db.chat_streams.create_index([("platform", 1), ("user_info.user_id", 1), ("group_info.group_id", 1)]) db.chat_streams.create_index([("platform", 1), ("user_info.user_id", 1), ("group_info.group_id", 1)])
def _generate_stream_id(self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None) -> str: @staticmethod
def _generate_stream_id(platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None) -> str:
"""生成聊天流唯一ID""" """生成聊天流唯一ID"""
if group_info: if group_info:
# 组合关键信息 # 组合关键信息
@@ -188,7 +190,8 @@ class ChatManager:
stream_id = self._generate_stream_id(platform, user_info, group_info) stream_id = self._generate_stream_id(platform, user_info, group_info)
return self.streams.get(stream_id) return self.streams.get(stream_id)
async def _save_stream(self, stream: ChatStream): @staticmethod
async def _save_stream(stream: ChatStream):
"""保存聊天流到数据库""" """保存聊天流到数据库"""
if not stream.saved: if not stream.saved:
db.chat_streams.update_one({"stream_id": stream.stream_id}, {"$set": stream.to_dict()}, upsert=True) db.chat_streams.update_one({"stream_id": stream.stream_id}, {"$set": stream.to_dict()}, upsert=True)

View File

@@ -82,7 +82,8 @@ class EmojiManager:
if not self._initialized: if not self._initialized:
raise RuntimeError("EmojiManager not initialized") raise RuntimeError("EmojiManager not initialized")
def _ensure_emoji_collection(self): @staticmethod
def _ensure_emoji_collection():
"""确保emoji集合存在并创建索引 """确保emoji集合存在并创建索引
这个函数用于确保MongoDB数据库中存在emoji集合,并创建必要的索引。 这个函数用于确保MongoDB数据库中存在emoji集合,并创建必要的索引。
@@ -193,7 +194,8 @@ class EmojiManager:
logger.error(f"[错误] 获取表情包失败: {str(e)}") logger.error(f"[错误] 获取表情包失败: {str(e)}")
return None return None
async def _get_emoji_description(self, image_base64: str) -> str: @staticmethod
async def _get_emoji_description(image_base64: str) -> str:
"""获取表情包的标签使用image_manager的描述生成功能""" """获取表情包的标签使用image_manager的描述生成功能"""
try: try:
@@ -554,7 +556,8 @@ class EmojiManager:
self.check_emoji_file_full() self.check_emoji_file_full()
await asyncio.sleep(global_config.EMOJI_CHECK_INTERVAL * 60) await asyncio.sleep(global_config.EMOJI_CHECK_INTERVAL * 60)
async def delete_all_images(self): @staticmethod
async def delete_all_images():
"""删除 data/image 目录下的所有文件""" """删除 data/image 目录下的所有文件"""
try: try:
image_dir = os.path.join("data", "image") image_dir = os.path.join("data", "image")

View File

@@ -31,7 +31,7 @@ class Message(MessageBase):
def __init__( def __init__(
self, self,
message_id: str, message_id: str,
time: float, timestamp: float,
chat_stream: ChatStream, chat_stream: ChatStream,
user_info: UserInfo, user_info: UserInfo,
message_segment: Optional[Seg] = None, message_segment: Optional[Seg] = None,
@@ -43,7 +43,7 @@ class Message(MessageBase):
message_info = BaseMessageInfo( message_info = BaseMessageInfo(
platform=chat_stream.platform, platform=chat_stream.platform,
message_id=message_id, message_id=message_id,
time=time, time=timestamp,
group_info=chat_stream.group_info, group_info=chat_stream.group_info,
user_info=user_info, user_info=user_info,
) )
@@ -143,7 +143,7 @@ class MessageRecv(Message):
def _generate_detailed_text(self) -> str: def _generate_detailed_text(self) -> str:
"""生成详细文本,包含时间和用户信息""" """生成详细文本,包含时间和用户信息"""
# time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.message_info.time)) # time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.message_info.time))
time = self.message_info.time timestamp = self.message_info.time
user_info = self.message_info.user_info user_info = self.message_info.user_info
# name = ( # name = (
# f"{user_info.user_nickname}(ta的昵称:{user_info.user_cardname},ta的id:{user_info.user_id})" # f"{user_info.user_nickname}(ta的昵称:{user_info.user_cardname},ta的id:{user_info.user_id})"
@@ -151,7 +151,7 @@ class MessageRecv(Message):
# else f"{user_info.user_nickname}(ta的id:{user_info.user_id})" # else f"{user_info.user_nickname}(ta的id:{user_info.user_id})"
# ) # )
name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>" name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>"
return f"[{time}] {name}: {self.processed_plain_text}\n" return f"[{timestamp}] {name}: {self.processed_plain_text}\n"
@dataclass @dataclass
@@ -170,7 +170,7 @@ class MessageProcessBase(Message):
# 调用父类初始化 # 调用父类初始化
super().__init__( super().__init__(
message_id=message_id, message_id=message_id,
time=round(time.time(), 3), # 保留3位小数 timestamp=round(time.time(), 3), # 保留3位小数
chat_stream=chat_stream, chat_stream=chat_stream,
user_info=bot_user_info, user_info=bot_user_info,
message_segment=message_segment, message_segment=message_segment,
@@ -242,7 +242,7 @@ class MessageProcessBase(Message):
def _generate_detailed_text(self) -> str: def _generate_detailed_text(self) -> str:
"""生成详细文本,包含时间和用户信息""" """生成详细文本,包含时间和用户信息"""
# time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.message_info.time)) # time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.message_info.time))
time = self.message_info.time timestamp = self.message_info.time
user_info = self.message_info.user_info user_info = self.message_info.user_info
# name = ( # name = (
# f"{user_info.user_nickname}(ta的昵称:{user_info.user_cardname},ta的id:{user_info.user_id})" # f"{user_info.user_nickname}(ta的昵称:{user_info.user_cardname},ta的id:{user_info.user_id})"
@@ -250,7 +250,7 @@ class MessageProcessBase(Message):
# else f"{user_info.user_nickname}(ta的id:{user_info.user_id})" # else f"{user_info.user_nickname}(ta的id:{user_info.user_id})"
# ) # )
name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>" name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>"
return f"[{time}] {name}: {self.processed_plain_text}\n" return f"[{timestamp}] {name}: {self.processed_plain_text}\n"
@dataclass @dataclass

View File

@@ -26,7 +26,8 @@ class MessageBuffer:
self.buffer_pool: Dict[str, OrderedDict[str, CacheMessages]] = {} self.buffer_pool: Dict[str, OrderedDict[str, CacheMessages]] = {}
self.lock = asyncio.Lock() self.lock = asyncio.Lock()
def get_person_id_(self, platform: str, user_id: str, group_info: GroupInfo): @staticmethod
def get_person_id_(platform: str, user_id: str, group_info: GroupInfo):
"""获取唯一id""" """获取唯一id"""
if group_info: if group_info:
group_id = group_info.group_id group_id = group_info.group_id
@@ -150,20 +151,20 @@ class MessageBuffer:
keep_msgs[msg_id] = msg keep_msgs[msg_id] = msg
elif msg.result == "F": elif msg.result == "F":
# 收集F消息的文本内容 # 收集F消息的文本内容
F_type = "seglist" f_type = "seglist"
if msg.message.message_segment.type != "seglist": if msg.message.message_segment.type != "seglist":
F_type = msg.message.message_segment.type f_type = msg.message.message_segment.type
else: else:
if ( if (
isinstance(msg.message.message_segment.data, list) isinstance(msg.message.message_segment.data, list)
and all(isinstance(x, Seg) for x in msg.message.message_segment.data) and all(isinstance(x, Seg) for x in msg.message.message_segment.data)
and len(msg.message.message_segment.data) == 1 and len(msg.message.message_segment.data) == 1
): ):
F_type = msg.message.message_segment.data[0].type f_type = msg.message.message_segment.data[0].type
if hasattr(msg.message, "processed_plain_text") and msg.message.processed_plain_text: if hasattr(msg.message, "processed_plain_text") and msg.message.processed_plain_text:
if F_type == "text": if f_type == "text":
combined_text.append(msg.message.processed_plain_text) combined_text.append(msg.message.processed_plain_text)
elif F_type != "text": elif f_type != "text":
is_update = False is_update = False
elif msg.result == "U": elif msg.result == "U":
logger.debug(f"异常未处理信息id {msg.message.message_info.message_id}") logger.debug(f"异常未处理信息id {msg.message.message_info.message_id}")
@@ -185,7 +186,8 @@ class MessageBuffer:
logger.debug(f"查询超时消息id {message.message_info.message_id}") logger.debug(f"查询超时消息id {message.message_info.message_id}")
return False return False
async def save_message_interval(self, person_id: str, message: BaseMessageInfo): @staticmethod
async def save_message_interval(person_id: str, message: BaseMessageInfo):
message_interval_list = await person_info_manager.get_value(person_id, "msg_interval_list") message_interval_list = await person_info_manager.get_value(person_id, "msg_interval_list")
now_time_ms = int(round(time.time() * 1000)) now_time_ms = int(round(time.time() * 1000))
if len(message_interval_list) < 1000: if len(message_interval_list) < 1000:

View File

@@ -35,7 +35,8 @@ class MessageSender:
"""设置当前bot实例""" """设置当前bot实例"""
pass pass
def get_recalled_messages(self, stream_id: str) -> list: @staticmethod
def get_recalled_messages(stream_id: str) -> list:
"""获取所有撤回的消息""" """获取所有撤回的消息"""
recalled_messages = [] recalled_messages = []
@@ -43,7 +44,8 @@ class MessageSender:
# 按thinking_start_time排序时间早的在前面 # 按thinking_start_time排序时间早的在前面
return recalled_messages return recalled_messages
async def send_via_ws(self, message: MessageSending) -> None: @staticmethod
async def send_via_ws(message: MessageSending) -> None:
try: try:
await global_api.send_message(message) await global_api.send_message(message)
except Exception as e: except Exception as e:

View File

@@ -135,7 +135,7 @@ async def get_recent_group_messages(chat_id: str, limit: int = 12) -> list:
msg = Message( msg = Message(
message_id=msg_data["message_id"], message_id=msg_data["message_id"],
chat_stream=chat_stream, chat_stream=chat_stream,
time=msg_data["time"], timestamp=msg_data["time"],
user_info=user_info, user_info=user_info,
processed_plain_text=msg_data.get("processed_text", ""), processed_plain_text=msg_data.get("processed_text", ""),
detailed_plain_text=msg_data.get("detailed_plain_text", ""), detailed_plain_text=msg_data.get("detailed_plain_text", ""),

View File

@@ -38,7 +38,8 @@ class ImageManager:
"""确保图像存储目录存在""" """确保图像存储目录存在"""
os.makedirs(self.IMAGE_DIR, exist_ok=True) os.makedirs(self.IMAGE_DIR, exist_ok=True)
def _ensure_image_collection(self): @staticmethod
def _ensure_image_collection():
"""确保images集合存在并创建索引""" """确保images集合存在并创建索引"""
if "images" not in db.list_collection_names(): if "images" not in db.list_collection_names():
db.create_collection("images") db.create_collection("images")
@@ -50,7 +51,8 @@ class ImageManager:
db.images.create_index([("url", 1)]) db.images.create_index([("url", 1)])
db.images.create_index([("path", 1)]) db.images.create_index([("path", 1)])
def _ensure_description_collection(self): @staticmethod
def _ensure_description_collection():
"""确保image_descriptions集合存在并创建索引""" """确保image_descriptions集合存在并创建索引"""
if "image_descriptions" not in db.list_collection_names(): if "image_descriptions" not in db.list_collection_names():
db.create_collection("image_descriptions") db.create_collection("image_descriptions")
@@ -60,7 +62,8 @@ class ImageManager:
# 创建新的复合索引 # 创建新的复合索引
db.image_descriptions.create_index([("hash", 1), ("type", 1)], unique=True) db.image_descriptions.create_index([("hash", 1), ("type", 1)], unique=True)
def _get_description_from_db(self, image_hash: str, description_type: str) -> Optional[str]: @staticmethod
def _get_description_from_db(image_hash: str, description_type: str) -> Optional[str]:
"""从数据库获取图片描述 """从数据库获取图片描述
Args: Args:
@@ -73,7 +76,8 @@ class ImageManager:
result = db.image_descriptions.find_one({"hash": image_hash, "type": description_type}) result = db.image_descriptions.find_one({"hash": image_hash, "type": description_type})
return result["description"] if result else None return result["description"] if result else None
def _save_description_to_db(self, image_hash: str, description: str, description_type: str) -> None: @staticmethod
def _save_description_to_db(image_hash: str, description: str, description_type: str) -> None:
"""保存图片描述到数据库 """保存图片描述到数据库
Args: Args:
@@ -226,7 +230,8 @@ class ImageManager:
logger.error(f"获取图片描述失败: {str(e)}") logger.error(f"获取图片描述失败: {str(e)}")
return "[图片]" return "[图片]"
def transform_gif(self, gif_base64: str) -> str: @staticmethod
def transform_gif(gif_base64: str) -> str:
"""将GIF转换为水平拼接的静态图像 """将GIF转换为水平拼接的静态图像
Args: Args:

View File

@@ -13,7 +13,8 @@ class MessageProcessor:
def __init__(self): def __init__(self):
self.storage = MessageStorage() self.storage = MessageStorage()
def _check_ban_words(self, text: str, chat, userinfo) -> bool: @staticmethod
def _check_ban_words(text: str, chat, userinfo) -> bool:
"""检查消息中是否包含过滤词""" """检查消息中是否包含过滤词"""
for word in global_config.ban_words: for word in global_config.ban_words:
if word in text: if word in text:
@@ -24,7 +25,8 @@ class MessageProcessor:
return True return True
return False return False
def _check_ban_regex(self, text: str, chat, userinfo) -> bool: @staticmethod
def _check_ban_regex(text: str, chat, userinfo) -> bool:
"""检查消息是否匹配过滤正则表达式""" """检查消息是否匹配过滤正则表达式"""
for pattern in global_config.ban_msgs_regex: for pattern in global_config.ban_msgs_regex:
if pattern.search(text): if pattern.search(text):

View File

@@ -37,7 +37,8 @@ class ReasoningChat:
self.mood_manager = MoodManager.get_instance() self.mood_manager = MoodManager.get_instance()
self.mood_manager.start_mood_update() self.mood_manager.start_mood_update()
async def _create_thinking_message(self, message, chat, userinfo, messageinfo): @staticmethod
async def _create_thinking_message(message, chat, userinfo, messageinfo):
"""创建思考消息""" """创建思考消息"""
bot_user_info = UserInfo( bot_user_info = UserInfo(
user_id=global_config.BOT_QQ, user_id=global_config.BOT_QQ,
@@ -59,7 +60,8 @@ class ReasoningChat:
return thinking_id return thinking_id
async def _send_response_messages(self, message, chat, response_set: List[str], thinking_id) -> MessageSending: @staticmethod
async def _send_response_messages(message, chat, response_set: List[str], thinking_id) -> MessageSending:
"""发送回复消息""" """发送回复消息"""
container = message_manager.get_container(chat.stream_id) container = message_manager.get_container(chat.stream_id)
thinking_message = None thinking_message = None
@@ -104,7 +106,8 @@ class ReasoningChat:
return first_bot_msg return first_bot_msg
async def _handle_emoji(self, message, chat, response): @staticmethod
async def _handle_emoji(message, chat, response):
"""处理表情包""" """处理表情包"""
if random() < global_config.emoji_chance: if random() < global_config.emoji_chance:
emoji_raw = await emoji_manager.get_emoji_for_text(response) emoji_raw = await emoji_manager.get_emoji_for_text(response)
@@ -192,21 +195,21 @@ class ReasoningChat:
if not buffer_result: if not buffer_result:
await willing_manager.bombing_buffer_message_handle(message.message_info.message_id) await willing_manager.bombing_buffer_message_handle(message.message_info.message_id)
willing_manager.delete(message.message_info.message_id) willing_manager.delete(message.message_info.message_id)
F_type = "seglist" f_type = "seglist"
if message.message_segment.type != "seglist": if message.message_segment.type != "seglist":
F_type = message.message_segment.type f_type = message.message_segment.type
else: else:
if ( if (
isinstance(message.message_segment.data, list) isinstance(message.message_segment.data, list)
and all(isinstance(x, Seg) for x in message.message_segment.data) and all(isinstance(x, Seg) for x in message.message_segment.data)
and len(message.message_segment.data) == 1 and len(message.message_segment.data) == 1
): ):
F_type = message.message_segment.data[0].type f_type = message.message_segment.data[0].type
if F_type == "text": if f_type == "text":
logger.info(f"触发缓冲,已炸飞消息:{message.processed_plain_text}") logger.info(f"触发缓冲,已炸飞消息:{message.processed_plain_text}")
elif F_type == "image": elif f_type == "image":
logger.info("触发缓冲,已炸飞表情包/图片") logger.info("触发缓冲,已炸飞表情包/图片")
elif F_type == "seglist": elif f_type == "seglist":
logger.info("触发缓冲,已炸飞消息列") logger.info("触发缓冲,已炸飞消息列")
return return
@@ -291,7 +294,8 @@ class ReasoningChat:
# 意愿管理器注销当前message信息 # 意愿管理器注销当前message信息
willing_manager.delete(message.message_info.message_id) willing_manager.delete(message.message_info.message_id)
def _check_ban_words(self, text: str, chat, userinfo) -> bool: @staticmethod
def _check_ban_words(text: str, chat, userinfo) -> bool:
"""检查消息中是否包含过滤词""" """检查消息中是否包含过滤词"""
for word in global_config.ban_words: for word in global_config.ban_words:
if word in text: if word in text:
@@ -302,7 +306,8 @@ class ReasoningChat:
return True return True
return False return False
def _check_ban_regex(self, text: str, chat, userinfo) -> bool: @staticmethod
def _check_ban_regex(text: str, chat, userinfo) -> bool:
"""检查消息是否匹配过滤正则表达式""" """检查消息是否匹配过滤正则表达式"""
for pattern in global_config.ban_msgs_regex: for pattern in global_config.ban_msgs_regex:
if pattern.search(text): if pattern.search(text):

View File

@@ -69,8 +69,6 @@ class ResponseGenerator:
return None return None
async def _generate_response_with_model(self, message: MessageThinking, model: LLMRequest, thinking_id: str): async def _generate_response_with_model(self, message: MessageThinking, model: LLMRequest, thinking_id: str):
sender_name = ""
info_catcher = info_catcher_manager.get_info_catcher(thinking_id) info_catcher = info_catcher_manager.get_info_catcher(thinking_id)
if message.chat_stream.user_info.user_cardname and message.chat_stream.user_info.user_nickname: if message.chat_stream.user_info.user_cardname and message.chat_stream.user_info.user_nickname:
@@ -188,7 +186,8 @@ class ResponseGenerator:
logger.debug(f"获取情感标签时出错: {e}") logger.debug(f"获取情感标签时出错: {e}")
return "中立", "平静" # 出错时返回默认值 return "中立", "平静" # 出错时返回默认值
async def _process_response(self, content: str) -> Tuple[List[str], List[str]]: @staticmethod
async def _process_response(content: str) -> Tuple[List[str], List[str]]:
"""处理响应内容,返回处理后的内容和情感标签""" """处理响应内容,返回处理后的内容和情感标签"""
if not content: if not content:
return None, [] return None, []

View File

@@ -101,16 +101,14 @@ class PromptBuilder:
related_memory = await HippocampusManager.get_instance().get_memory_from_text( related_memory = await HippocampusManager.get_instance().get_memory_from_text(
text=message_txt, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False text=message_txt, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False
) )
if related_memory:
related_memory_info = "" related_memory_info = ""
if related_memory:
for memory in related_memory: for memory in related_memory:
related_memory_info += memory[1] related_memory_info += memory[1]
# memory_prompt = f"你想起你之前见过的事情:{related_memory_info}。\n以上是你的回忆不一定是目前聊天里的人说的也不一定是现在发生的事情请记住。\n" # memory_prompt = f"你想起你之前见过的事情:{related_memory_info}。\n以上是你的回忆不一定是目前聊天里的人说的也不一定是现在发生的事情请记住。\n"
memory_prompt = await global_prompt_manager.format_prompt( memory_prompt = await global_prompt_manager.format_prompt(
"memory_prompt", related_memory_info=related_memory_info "memory_prompt", related_memory_info=related_memory_info
) )
else:
related_memory_info = ""
# print(f"相关记忆:{related_memory_info}") # print(f"相关记忆:{related_memory_info}")
@@ -162,7 +160,6 @@ class PromptBuilder:
# 知识构建 # 知识构建
start_time = time.time() start_time = time.time()
prompt_info = ""
prompt_info = await self.get_prompt_info(message_txt, threshold=0.38) prompt_info = await self.get_prompt_info(message_txt, threshold=0.38)
if prompt_info: if prompt_info:
# prompt_info = f"""\n你有以下这些**知识**\n{prompt_info}\n请你**记住上面的知识**,之后可能会用到。\n""" # prompt_info = f"""\n你有以下这些**知识**\n{prompt_info}\n请你**记住上面的知识**,之后可能会用到。\n"""
@@ -373,8 +370,9 @@ class PromptBuilder:
logger.info(f"知识库检索总耗时: {time.time() - start_time:.3f}") logger.info(f"知识库检索总耗时: {time.time() - start_time:.3f}")
return related_info return related_info
@staticmethod
def get_info_from_db( def get_info_from_db(
self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False
) -> Union[str, list]: ) -> Union[str, list]:
if not query_embedding: if not query_embedding:
return "" if not return_raw else [] return "" if not return_raw else []

View File

@@ -40,7 +40,8 @@ class ThinkFlowChat:
self.mood_manager.start_mood_update() self.mood_manager.start_mood_update()
self.tool_user = ToolUser() self.tool_user = ToolUser()
async def _create_thinking_message(self, message, chat, userinfo, messageinfo): @staticmethod
async def _create_thinking_message(message, chat, userinfo, messageinfo):
"""创建思考消息""" """创建思考消息"""
bot_user_info = UserInfo( bot_user_info = UserInfo(
user_id=global_config.BOT_QQ, user_id=global_config.BOT_QQ,
@@ -62,7 +63,8 @@ class ThinkFlowChat:
return thinking_id return thinking_id
async def _send_response_messages(self, message, chat, response_set: List[str], thinking_id) -> MessageSending: @staticmethod
async def _send_response_messages(message, chat, response_set: List[str], thinking_id) -> MessageSending:
"""发送回复消息""" """发送回复消息"""
container = message_manager.get_container(chat.stream_id) container = message_manager.get_container(chat.stream_id)
thinking_message = None thinking_message = None
@@ -108,7 +110,8 @@ class ThinkFlowChat:
message_manager.add_message(message_set) message_manager.add_message(message_set)
return first_bot_msg return first_bot_msg
async def _handle_emoji(self, message, chat, response, send_emoji=""): @staticmethod
async def _handle_emoji(message, chat, response, send_emoji=""):
"""处理表情包""" """处理表情包"""
if send_emoji: if send_emoji:
emoji_raw = await emoji_manager.get_emoji_for_text(send_emoji) emoji_raw = await emoji_manager.get_emoji_for_text(send_emoji)
@@ -204,21 +207,21 @@ class ThinkFlowChat:
if not buffer_result: if not buffer_result:
await willing_manager.bombing_buffer_message_handle(message.message_info.message_id) await willing_manager.bombing_buffer_message_handle(message.message_info.message_id)
willing_manager.delete(message.message_info.message_id) willing_manager.delete(message.message_info.message_id)
F_type = "seglist" f_type = "seglist"
if message.message_segment.type != "seglist": if message.message_segment.type != "seglist":
F_type = message.message_segment.type f_type = message.message_segment.type
else: else:
if ( if (
isinstance(message.message_segment.data, list) isinstance(message.message_segment.data, list)
and all(isinstance(x, Seg) for x in message.message_segment.data) and all(isinstance(x, Seg) for x in message.message_segment.data)
and len(message.message_segment.data) == 1 and len(message.message_segment.data) == 1
): ):
F_type = message.message_segment.data[0].type f_type = message.message_segment.data[0].type
if F_type == "text": if f_type == "text":
logger.info(f"触发缓冲,已炸飞消息:{message.processed_plain_text}") logger.info(f"触发缓冲,已炸飞消息:{message.processed_plain_text}")
elif F_type == "image": elif f_type == "image":
logger.info("触发缓冲,已炸飞表情包/图片") logger.info("触发缓冲,已炸飞表情包/图片")
elif F_type == "seglist": elif f_type == "seglist":
logger.info("触发缓冲,已炸飞消息列") logger.info("触发缓冲,已炸飞消息列")
return return
@@ -461,7 +464,8 @@ class ThinkFlowChat:
# 意愿管理器注销当前message信息 # 意愿管理器注销当前message信息
willing_manager.delete(message.message_info.message_id) willing_manager.delete(message.message_info.message_id)
def _check_ban_words(self, text: str, chat, userinfo) -> bool: @staticmethod
def _check_ban_words(text: str, chat, userinfo) -> bool:
"""检查消息中是否包含过滤词""" """检查消息中是否包含过滤词"""
for word in global_config.ban_words: for word in global_config.ban_words:
if word in text: if word in text:
@@ -472,7 +476,8 @@ class ThinkFlowChat:
return True return True
return False return False
def _check_ban_regex(self, text: str, chat, userinfo) -> bool: @staticmethod
def _check_ban_regex(text: str, chat, userinfo) -> bool:
"""检查消息是否匹配过滤正则表达式""" """检查消息是否匹配过滤正则表达式"""
for pattern in global_config.ban_msgs_regex: for pattern in global_config.ban_msgs_regex:
if pattern.search(text): if pattern.search(text):

View File

@@ -236,7 +236,8 @@ class ResponseGenerator:
logger.debug(f"获取情感标签时出错: {e}") logger.debug(f"获取情感标签时出错: {e}")
return "中立", "平静" # 出错时返回默认值 return "中立", "平静" # 出错时返回默认值
async def _process_response(self, content: str) -> List[str]: @staticmethod
async def _process_response(content: str) -> List[str]:
"""处理响应内容,返回处理后的内容和情感标签""" """处理响应内容,返回处理后的内容和情感标签"""
if not content: if not content:
return None return None

View File

@@ -64,8 +64,9 @@ class PromptBuilder:
self.prompt_built = "" self.prompt_built = ""
self.activate_messages = "" self.activate_messages = ""
@staticmethod
async def _build_prompt( async def _build_prompt(
self, chat_stream, message_txt: str, sender_name: str = "某人", stream_id: Optional[int] = None chat_stream, message_txt: str, sender_name: str = "某人", stream_id: Optional[int] = None
) -> tuple[str, str]: ) -> tuple[str, str]:
current_mind_info = heartflow.get_subheartflow(stream_id).current_mind current_mind_info = heartflow.get_subheartflow(stream_id).current_mind
@@ -168,8 +169,9 @@ class PromptBuilder:
return prompt return prompt
@staticmethod
async def _build_prompt_simple( async def _build_prompt_simple(
self, chat_stream, message_txt: str, sender_name: str = "某人", stream_id: Optional[int] = None chat_stream, message_txt: str, sender_name: str = "某人", stream_id: Optional[int] = None
) -> tuple[str, str]: ) -> tuple[str, str]:
current_mind_info = heartflow.get_subheartflow(stream_id).current_mind current_mind_info = heartflow.get_subheartflow(stream_id).current_mind
@@ -237,8 +239,8 @@ class PromptBuilder:
logger.info(f"生成回复的prompt: {prompt}") logger.info(f"生成回复的prompt: {prompt}")
return prompt return prompt
@staticmethod
async def _build_prompt_check_response( async def _build_prompt_check_response(
self,
chat_stream, chat_stream,
message_txt: str, message_txt: str,
sender_name: str = "某人", sender_name: str = "某人",

View File

@@ -4,6 +4,8 @@ import math
import random import random
import time import time
import re import re
from itertools import combinations
import jieba import jieba
import networkx as nx import networkx as nx
import numpy as np import numpy as np
@@ -250,7 +252,8 @@ class Hippocampus:
"""获取记忆图中所有节点的名字列表""" """获取记忆图中所有节点的名字列表"""
return list(self.memory_graph.G.nodes()) return list(self.memory_graph.G.nodes())
def calculate_node_hash(self, concept, memory_items) -> int: @staticmethod
def calculate_node_hash(concept, memory_items) -> int:
"""计算节点的特征值""" """计算节点的特征值"""
if not isinstance(memory_items, list): if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else [] memory_items = [memory_items] if memory_items else []
@@ -258,12 +261,14 @@ class Hippocampus:
content = f"{concept}:{'|'.join(sorted_items)}" content = f"{concept}:{'|'.join(sorted_items)}"
return hash(content) return hash(content)
def calculate_edge_hash(self, source, target) -> int: @staticmethod
def calculate_edge_hash(source, target) -> int:
"""计算边的特征值""" """计算边的特征值"""
nodes = sorted([source, target]) nodes = sorted([source, target])
return hash(f"{nodes[0]}:{nodes[1]}") return hash(f"{nodes[0]}:{nodes[1]}")
def find_topic_llm(self, text, topic_num): @staticmethod
def find_topic_llm(text, topic_num):
prompt = ( prompt = (
f"这是一段文字:{text}。请你从这段话中总结出最多{topic_num}个关键的概念,可以是名词,动词,或者特定人物,帮我列出来," f"这是一段文字:{text}。请你从这段话中总结出最多{topic_num}个关键的概念,可以是名词,动词,或者特定人物,帮我列出来,"
f"将主题用逗号隔开,并加上<>,例如<主题1>,<主题2>......尽可能精简。只需要列举最多{topic_num}个话题就好,不要有序号,不要告诉我其他内容。" f"将主题用逗号隔开,并加上<>,例如<主题1>,<主题2>......尽可能精简。只需要列举最多{topic_num}个话题就好,不要有序号,不要告诉我其他内容。"
@@ -271,14 +276,16 @@ class Hippocampus:
) )
return prompt return prompt
def topic_what(self, text, topic, time_info): @staticmethod
def topic_what(text, topic, time_info):
prompt = ( prompt = (
f'这是一段文字,{time_info}{text}。我想让你基于这段文字来概括"{topic}"这个概念,帮我总结成一句自然的话,' f'这是一段文字,{time_info}{text}。我想让你基于这段文字来概括"{topic}"这个概念,帮我总结成一句自然的话,'
f"可以包含时间和人物,以及具体的观点。只输出这句话就好" f"可以包含时间和人物,以及具体的观点。只输出这句话就好"
) )
return prompt return prompt
def calculate_topic_num(self, text, compress_rate): @staticmethod
def calculate_topic_num(text, compress_rate):
"""计算文本的话题数量""" """计算文本的话题数量"""
information_content = calculate_information_content(text) information_content = calculate_information_content(text)
topic_by_length = text.count("\n") * compress_rate topic_by_length = text.count("\n") * compress_rate
@@ -693,7 +700,8 @@ class EntorhinalCortex:
return chat_samples return chat_samples
def random_get_msg_snippet(self, target_timestamp: float, chat_size: int, max_memorized_time_per_msg: int) -> list: @staticmethod
def random_get_msg_snippet(target_timestamp: float, chat_size: int, max_memorized_time_per_msg: int) -> list:
"""从数据库中随机获取指定时间戳附近的消息片段""" """从数据库中随机获取指定时间戳附近的消息片段"""
try_count = 0 try_count = 0
while try_count < 3: while try_count < 3:
@@ -958,7 +966,8 @@ class Hippocampus:
"""获取记忆图中所有节点的名字列表""" """获取记忆图中所有节点的名字列表"""
return list(self.memory_graph.G.nodes()) return list(self.memory_graph.G.nodes())
def calculate_node_hash(self, concept, memory_items) -> int: @staticmethod
def calculate_node_hash(concept, memory_items) -> int:
"""计算节点的特征值""" """计算节点的特征值"""
if not isinstance(memory_items, list): if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else [] memory_items = [memory_items] if memory_items else []
@@ -966,12 +975,14 @@ class Hippocampus:
content = f"{concept}:{'|'.join(sorted_items)}" content = f"{concept}:{'|'.join(sorted_items)}"
return hash(content) return hash(content)
def calculate_edge_hash(self, source, target) -> int: @staticmethod
def calculate_edge_hash(source, target) -> int:
"""计算边的特征值""" """计算边的特征值"""
nodes = sorted([source, target]) nodes = sorted([source, target])
return hash(f"{nodes[0]}:{nodes[1]}") return hash(f"{nodes[0]}:{nodes[1]}")
def find_topic_llm(self, text, topic_num): @staticmethod
def find_topic_llm(text, topic_num):
prompt = ( prompt = (
f"这是一段文字:{text}。请你从这段话中总结出最多{topic_num}个关键的概念,可以是名词,动词,或者特定人物,帮我列出来," f"这是一段文字:{text}。请你从这段话中总结出最多{topic_num}个关键的概念,可以是名词,动词,或者特定人物,帮我列出来,"
f"将主题用逗号隔开,并加上<>,例如<主题1>,<主题2>......尽可能精简。只需要列举最多{topic_num}个话题就好,不要有序号,不要告诉我其他内容。" f"将主题用逗号隔开,并加上<>,例如<主题1>,<主题2>......尽可能精简。只需要列举最多{topic_num}个话题就好,不要有序号,不要告诉我其他内容。"
@@ -979,14 +990,16 @@ class Hippocampus:
) )
return prompt return prompt
def topic_what(self, text, topic, time_info): @staticmethod
def topic_what(text, topic, time_info):
prompt = ( prompt = (
f'这是一段文字,{time_info}{text}。我想让你基于这段文字来概括"{topic}"这个概念,帮我总结成一句自然的话,' f'这是一段文字,{time_info}{text}。我想让你基于这段文字来概括"{topic}"这个概念,帮我总结成一句自然的话,'
f"可以包含时间和人物,以及具体的观点。只输出这句话就好" f"可以包含时间和人物,以及具体的观点。只输出这句话就好"
) )
return prompt return prompt
def calculate_topic_num(self, text, compress_rate): @staticmethod
def calculate_topic_num(text, compress_rate):
"""计算文本的话题数量""" """计算文本的话题数量"""
information_content = calculate_information_content(text) information_content = calculate_information_content(text)
topic_by_length = text.count("\n") * compress_rate topic_by_length = text.count("\n") * compress_rate
@@ -1542,11 +1555,10 @@ class ParahippocampalGyrus:
last_modified=current_time, last_modified=current_time,
) )
for i in range(len(all_topics)): for topic1, topic2 in combinations(all_topics, 2):
for j in range(i + 1, len(all_topics)): logger.debug(f"连接同批次节点: {topic1}{topic2}")
logger.debug(f"连接同批次节点: {all_topics[i]}{all_topics[j]}") all_added_edges.append(f"{topic1}-{topic2}")
all_added_edges.append(f"{all_topics[i]}-{all_topics[j]}") self.memory_graph.connect_dot(topic1, topic2)
self.memory_graph.connect_dot(all_topics[i], all_topics[j])
logger.success(f"更新记忆: {', '.join(all_added_nodes)}") logger.success(f"更新记忆: {', '.join(all_added_nodes)}")
logger.debug(f"强化连接: {', '.join(all_added_edges)}") logger.debug(f"强化连接: {', '.join(all_added_edges)}")

View File

@@ -1,95 +0,0 @@
import unittest
import asyncio
import aiohttp
from api import BaseMessageAPI
from message_base import (
BaseMessageInfo,
UserInfo,
GroupInfo,
FormatInfo,
MessageBase,
Seg,
)
send_url = "http://localhost"
receive_port = 18002 # 接收消息的端口
send_port = 18000 # 发送消息的端口
test_endpoint = "/api/message"
# 创建并启动API实例
api = BaseMessageAPI(host="0.0.0.0", port=receive_port)
class TestLiveAPI(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
"""测试前的设置"""
self.received_messages = []
async def message_handler(message):
self.received_messages.append(message)
self.api = api
self.api.register_message_handler(message_handler)
self.server_task = asyncio.create_task(self.api.run())
try:
await asyncio.wait_for(asyncio.sleep(1), timeout=5)
except asyncio.TimeoutError:
self.skipTest("服务器启动超时")
async def asyncTearDown(self):
"""测试后的清理"""
if hasattr(self, "server_task"):
await self.api.stop() # 先调用正常的停止流程
if not self.server_task.done():
self.server_task.cancel()
try:
await asyncio.wait_for(self.server_task, timeout=100)
except (asyncio.CancelledError, asyncio.TimeoutError):
pass
async def test_send_and_receive_message(self):
"""测试向运行中的API发送消息并接收响应"""
# 准备测试消息
user_info = UserInfo(user_id=12345678, user_nickname="测试用户", platform="qq")
group_info = GroupInfo(group_id=12345678, group_name="测试群", platform="qq")
format_info = FormatInfo(content_format=["text"], accept_format=["text", "emoji", "reply"])
template_info = None
message_info = BaseMessageInfo(
platform="qq",
message_id=12345678,
time=12345678,
group_info=group_info,
user_info=user_info,
format_info=format_info,
template_info=template_info,
)
message = MessageBase(
message_info=message_info,
raw_message="测试消息",
message_segment=Seg(type="text", data="测试消息"),
)
test_message = message.to_dict()
# 发送测试消息到发送端口
async with aiohttp.ClientSession() as session:
async with session.post(
f"{send_url}:{send_port}{test_endpoint}",
json=test_message,
) as response:
response_data = await response.json()
self.assertEqual(response.status, 200)
self.assertEqual(response_data["status"], "success")
try:
async with asyncio.timeout(5): # 设置5秒超时
while len(self.received_messages) == 0:
await asyncio.sleep(0.1)
received_message = self.received_messages[0]
print(received_message)
self.received_messages.clear()
except asyncio.TimeoutError:
self.fail("等待接收消息超时")
if __name__ == "__main__":
unittest.main()

View File

@@ -72,7 +72,8 @@ class PersonInfoManager:
self.person_name_list[doc["person_id"]] = doc["person_name"] self.person_name_list[doc["person_id"]] = doc["person_name"]
logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称") logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称")
def get_person_id(self, platform: str, user_id: int): @staticmethod
def get_person_id(platform: str, user_id: int):
"""获取唯一id""" """获取唯一id"""
# 如果platform中存在-,就截取-后面的部分 # 如果platform中存在-,就截取-后面的部分
if "-" in platform: if "-" in platform:
@@ -91,7 +92,8 @@ class PersonInfoManager:
else: else:
return False return False
async def create_person_info(self, person_id: str, data: dict = None): @staticmethod
async def create_person_info(person_id: str, data: dict = None):
"""创建一个项""" """创建一个项"""
if not person_id: if not person_id:
logger.debug("创建失败personid不存在") logger.debug("创建失败personid不存在")
@@ -131,7 +133,8 @@ class PersonInfoManager:
else: else:
return False return False
def _extract_json_from_text(self, text: str) -> dict: @staticmethod
def _extract_json_from_text(text: str) -> dict:
"""从文本中提取JSON数据的高容错方法""" """从文本中提取JSON数据的高容错方法"""
try: try:
# 尝试直接解析 # 尝试直接解析
@@ -225,7 +228,8 @@ class PersonInfoManager:
logger.error(f"{max_retries}次尝试后仍未能生成唯一昵称") logger.error(f"{max_retries}次尝试后仍未能生成唯一昵称")
return None return None
async def del_one_document(self, person_id: str): @staticmethod
async def del_one_document(person_id: str):
"""删除指定 person_id 的文档""" """删除指定 person_id 的文档"""
if not person_id: if not person_id:
logger.debug("删除失败person_id 不能为空") logger.debug("删除失败person_id 不能为空")
@@ -237,7 +241,8 @@ class PersonInfoManager:
else: else:
logger.debug(f"删除失败:未找到 person_id={person_id}") logger.debug(f"删除失败:未找到 person_id={person_id}")
async def get_value(self, person_id: str, field_name: str): @staticmethod
async def get_value(person_id: str, field_name: str):
"""获取指定person_id文档的字段值若不存在该字段则返回该字段的全局默认值""" """获取指定person_id文档的字段值若不存在该字段则返回该字段的全局默认值"""
if not person_id: if not person_id:
logger.debug("get_value获取失败person_id不能为空") logger.debug("get_value获取失败person_id不能为空")
@@ -256,7 +261,8 @@ class PersonInfoManager:
logger.trace(f"获取{person_id}{field_name}失败,已返回默认值{default_value}") logger.trace(f"获取{person_id}{field_name}失败,已返回默认值{default_value}")
return default_value return default_value
async def get_values(self, person_id: str, field_names: list) -> dict: @staticmethod
async def get_values(person_id: str, field_names: list) -> dict:
"""获取指定person_id文档的多个字段值若不存在该字段则返回该字段的全局默认值""" """获取指定person_id文档的多个字段值若不存在该字段则返回该字段的全局默认值"""
if not person_id: if not person_id:
logger.debug("get_values获取失败person_id不能为空") logger.debug("get_values获取失败person_id不能为空")
@@ -281,7 +287,8 @@ class PersonInfoManager:
return result return result
async def del_all_undefined_field(self): @staticmethod
async def del_all_undefined_field():
"""删除所有项里的未定义字段""" """删除所有项里的未定义字段"""
# 获取所有已定义的字段名 # 获取所有已定义的字段名
defined_fields = set(person_info_default.keys()) defined_fields = set(person_info_default.keys())
@@ -307,8 +314,8 @@ class PersonInfoManager:
logger.error(f"清理未定义字段时出错: {e}") logger.error(f"清理未定义字段时出错: {e}")
return return
@staticmethod
async def get_specific_value_list( async def get_specific_value_list(
self,
field_name: str, field_name: str,
way: Callable[[Any], bool], # 接受任意类型值 way: Callable[[Any], bool], # 接受任意类型值
) -> Dict[str, Any]: ) -> Dict[str, Any]:

View File

@@ -62,7 +62,7 @@ class RelationshipManager:
def mood_feedback(self, value): def mood_feedback(self, value):
"""情绪反馈""" """情绪反馈"""
mood_manager = self.mood_manager mood_manager = self.mood_manager
mood_gain = (mood_manager.get_current_mood().valence) ** 2 * math.copysign( mood_gain = mood_manager.get_current_mood().valence ** 2 * math.copysign(
1, value * mood_manager.get_current_mood().valence 1, value * mood_manager.get_current_mood().valence
) )
value += value * mood_gain value += value * mood_gain
@@ -77,24 +77,27 @@ class RelationshipManager:
else: else:
return mood_value / coefficient return mood_value / coefficient
async def is_known_some_one(self, platform, user_id): @staticmethod
async def is_known_some_one(platform, user_id):
"""判断是否认识某人""" """判断是否认识某人"""
is_known = person_info_manager.is_person_known(platform, user_id) is_known = person_info_manager.is_person_known(platform, user_id)
return is_known return is_known
async def is_qved_name(self, platform, user_id): @staticmethod
async def is_qved_name(platform, user_id):
"""判断是否认识某人""" """判断是否认识某人"""
person_id = person_info_manager.get_person_id(platform, user_id) person_id = person_info_manager.get_person_id(platform, user_id)
is_qved = await person_info_manager.has_one_field(person_id, "person_name") is_qved = await person_info_manager.has_one_field(person_id, "person_name")
old_name = await person_info_manager.get_value(person_id, "person_name") old_name = await person_info_manager.get_value(person_id, "person_name")
print(f"old_name: {old_name}") print(f"old_name: {old_name}")
print(f"is_qved: {is_qved}") print(f"is_qved: {is_qved}")
if is_qved and old_name != None: if is_qved and old_name is not None:
return True return True
else: else:
return False return False
async def first_knowing_some_one(self, platform, user_id, user_nickname, user_cardname, user_avatar): @staticmethod
async def first_knowing_some_one(platform, user_id, user_nickname, user_cardname, user_avatar):
"""判断是否认识某人""" """判断是否认识某人"""
person_id = person_info_manager.get_person_id(platform, user_id) person_id = person_info_manager.get_person_id(platform, user_id)
await person_info_manager.update_one_field(person_id, "nickname", user_nickname) await person_info_manager.update_one_field(person_id, "nickname", user_nickname)
@@ -102,7 +105,8 @@ class RelationshipManager:
# await person_info_manager.update_one_field(person_id, "user_avatar", user_avatar) # await person_info_manager.update_one_field(person_id, "user_avatar", user_avatar)
await person_info_manager.qv_person_name(person_id, user_nickname, user_cardname, user_avatar) await person_info_manager.qv_person_name(person_id, user_nickname, user_cardname, user_avatar)
async def convert_all_person_sign_to_person_name(self, input_text: str): @staticmethod
async def convert_all_person_sign_to_person_name(input_text: str):
"""将所有人的<platform:user_id:nickname:cardname>格式转换为person_name""" """将所有人的<platform:user_id:nickname:cardname>格式转换为person_name"""
try: try:
# 使用正则表达式匹配<platform:user_id:nickname:cardname>格式 # 使用正则表达式匹配<platform:user_id:nickname:cardname>格式
@@ -119,7 +123,7 @@ class RelationshipManager:
person_name = nickname.strip() if nickname.strip() else cardname.strip() person_name = nickname.strip() if nickname.strip() else cardname.strip()
if person_id in all_person: if person_id in all_person:
if all_person[person_id] != None: if all_person[person_id] is not None:
person_name = all_person[person_id] person_name = all_person[person_id]
print(f"将<{platform}:{user_id}:{nickname}:{cardname}>替换为{person_name}") print(f"将<{platform}:{user_id}:{nickname}:{cardname}>替换为{person_name}")
@@ -326,7 +330,8 @@ class RelationshipManager:
f"回复态度为{relation_prompt2_list[level_num]},关系等级为{level_num}" f"回复态度为{relation_prompt2_list[level_num]},关系等级为{level_num}"
) )
def calculate_level_num(self, relationship_value) -> int: @staticmethod
def calculate_level_num(relationship_value) -> int:
"""关系等级计算""" """关系等级计算"""
if -1000 <= relationship_value < -227: if -1000 <= relationship_value < -227:
level_num = 0 level_num = 0
@@ -344,7 +349,8 @@ class RelationshipManager:
level_num = 5 if relationship_value > 1000 else 0 level_num = 5 if relationship_value > 1000 else 0
return level_num return level_num
def ensure_float(self, value, person_id): @staticmethod
def ensure_float(value, person_id):
"""确保返回浮点数转换失败返回0.0""" """确保返回浮点数转换失败返回0.0"""
if isinstance(value, float): if isinstance(value, float):
return value return value

View File

@@ -100,7 +100,8 @@ class InfoCatcher:
self.trigger_response_message, first_bot_msg self.trigger_response_message, first_bot_msg
) )
def get_message_from_db_between_msgs(self, message_start: Message, message_end: Message): @staticmethod
def get_message_from_db_between_msgs(message_start: Message, message_end: Message):
try: try:
# 从数据库中获取消息的时间戳 # 从数据库中获取消息的时间戳
time_start = message_start.message_info.time time_start = message_start.message_info.time
@@ -155,7 +156,8 @@ class InfoCatcher:
return result return result
def message_to_dict(self, message): @staticmethod
def message_to_dict(message):
if not message: if not message:
return None return None
if isinstance(message, dict): if isinstance(message, dict):

View File

@@ -235,6 +235,7 @@ class ScheduleGenerator:
Args: Args:
num (int): 需要获取的日程数量默认为1 num (int): 需要获取的日程数量默认为1
time_info (bool): 是否包含时间信息默认为False
Returns: Returns:
list: 最新加入的日程列表 list: 最新加入的日程列表
@@ -267,7 +268,8 @@ class ScheduleGenerator:
db.schedule.update_one({"date": date_str}, {"$set": schedule_data}, upsert=True) db.schedule.update_one({"date": date_str}, {"$set": schedule_data}, upsert=True)
logger.debug(f"已保存{date_str}的日程到数据库") logger.debug(f"已保存{date_str}的日程到数据库")
def load_schedule_from_db(self, date: datetime.datetime): @staticmethod
def load_schedule_from_db(date: datetime.datetime):
"""从数据库加载日程,同时加载 today_done_list""" """从数据库加载日程,同时加载 today_done_list"""
date_str = date.strftime("%Y-%m-%d") date_str = date.strftime("%Y-%m-%d")
existing_schedule = db.schedule.find_one({"date": date_str}) existing_schedule = db.schedule.find_one({"date": date_str})

View File

@@ -10,7 +10,8 @@ logger = get_module_logger("message_storage")
class MessageStorage: class MessageStorage:
async def store_message(self, message: Union[MessageSending, MessageRecv], chat_stream: ChatStream) -> None: @staticmethod
async def store_message(message: Union[MessageSending, MessageRecv], chat_stream: ChatStream) -> None:
"""存储消息到数据库""" """存储消息到数据库"""
try: try:
# 莫越权 救世啊 # 莫越权 救世啊
@@ -43,7 +44,8 @@ class MessageStorage:
except Exception: except Exception:
logger.exception("存储消息失败") logger.exception("存储消息失败")
async def store_recalled_message(self, message_id: str, time: str, chat_stream: ChatStream) -> None: @staticmethod
async def store_recalled_message(message_id: str, time: str, chat_stream: ChatStream) -> None:
"""存储撤回消息到数据库""" """存储撤回消息到数据库"""
if "recalled_messages" not in db.list_collection_names(): if "recalled_messages" not in db.list_collection_names():
db.create_collection("recalled_messages") db.create_collection("recalled_messages")
@@ -58,7 +60,8 @@ class MessageStorage:
except Exception: except Exception:
logger.exception("存储撤回消息失败") logger.exception("存储撤回消息失败")
async def remove_recalled_message(self, time: str) -> None: @staticmethod
async def remove_recalled_message(time: str) -> None:
"""删除撤回消息""" """删除撤回消息"""
try: try:
db.recalled_messages.delete_many({"time": {"$lt": time - 300}}) db.recalled_messages.delete_many({"time": {"$lt": time - 300}})

View File

@@ -28,7 +28,7 @@ class TopicIdentifier:
消息内容:{text}""" 消息内容:{text}"""
# 使用 LLM_request 类进行请求 # 使用 LLMRequest 类进行请求
try: try:
topic, _, _ = await self.llm_topic_judge.generate_response(prompt) topic, _, _ = await self.llm_topic_judge.generate_response(prompt)
except Exception as e: except Exception as e:

View File

@@ -24,7 +24,8 @@ class LLMStatistics:
self._init_database() self._init_database()
self.name_dict: Dict[List] = {} self.name_dict: Dict[List] = {}
def _init_database(self): @staticmethod
def _init_database():
"""初始化数据库集合""" """初始化数据库集合"""
if "online_time" not in db.list_collection_names(): if "online_time" not in db.list_collection_names():
db.create_collection("online_time") db.create_collection("online_time")
@@ -51,7 +52,8 @@ class LLMStatistics:
if self.console_thread: if self.console_thread:
self.console_thread.join() self.console_thread.join()
def _record_online_time(self): @staticmethod
def _record_online_time():
"""记录在线时间""" """记录在线时间"""
current_time = datetime.now() current_time = datetime.now()
# 检查5分钟内是否已有记录 # 检查5分钟内是否已有记录
@@ -187,7 +189,7 @@ class LLMStatistics:
# 按模型统计 # 按模型统计
output.append("按模型统计:") output.append("按模型统计:")
output.append(("模型名称 调用次数 Token总量 累计花费")) output.append("模型名称 调用次数 Token总量 累计花费")
for model_name, count in sorted(stats["requests_by_model"].items()): for model_name, count in sorted(stats["requests_by_model"].items()):
tokens = stats["tokens_by_model"][model_name] tokens = stats["tokens_by_model"][model_name]
cost = stats["costs_by_model"][model_name] cost = stats["costs_by_model"][model_name]
@@ -198,7 +200,7 @@ class LLMStatistics:
# 按请求类型统计 # 按请求类型统计
output.append("按请求类型统计:") output.append("按请求类型统计:")
output.append(("模型名称 调用次数 Token总量 累计花费")) output.append("模型名称 调用次数 Token总量 累计花费")
for req_type, count in sorted(stats["requests_by_type"].items()): for req_type, count in sorted(stats["requests_by_type"].items()):
tokens = stats["tokens_by_type"][req_type] tokens = stats["tokens_by_type"][req_type]
cost = stats["costs_by_type"][req_type] cost = stats["costs_by_type"][req_type]
@@ -209,7 +211,7 @@ class LLMStatistics:
# 修正用户统计列宽 # 修正用户统计列宽
output.append("按用户统计:") output.append("按用户统计:")
output.append(("用户ID 调用次数 Token总量 累计花费")) output.append("用户ID 调用次数 Token总量 累计花费")
for user_id, count in sorted(stats["requests_by_user"].items()): for user_id, count in sorted(stats["requests_by_user"].items()):
tokens = stats["tokens_by_user"][user_id] tokens = stats["tokens_by_user"][user_id]
cost = stats["costs_by_user"][user_id] cost = stats["costs_by_user"][user_id]
@@ -225,7 +227,7 @@ class LLMStatistics:
# 添加聊天统计 # 添加聊天统计
output.append("群组统计:") output.append("群组统计:")
output.append(("群组名称 消息数量")) output.append("群组名称 消息数量")
for group_id, count in sorted(stats["messages_by_chat"].items()): for group_id, count in sorted(stats["messages_by_chat"].items()):
output.append(f"{self.name_dict[group_id][0][:32]:<32} {count:>10}") output.append(f"{self.name_dict[group_id][0][:32]:<32} {count:>10}")
@@ -246,7 +248,7 @@ class LLMStatistics:
# 按模型统计 # 按模型统计
output.append("按模型统计:") output.append("按模型统计:")
output.append(("模型名称 调用次数 Token总量 累计花费")) output.append("模型名称 调用次数 Token总量 累计花费")
for model_name, count in sorted(stats["requests_by_model"].items()): for model_name, count in sorted(stats["requests_by_model"].items()):
tokens = stats["tokens_by_model"][model_name] tokens = stats["tokens_by_model"][model_name]
cost = stats["costs_by_model"][model_name] cost = stats["costs_by_model"][model_name]
@@ -284,7 +286,7 @@ class LLMStatistics:
# 添加聊天统计 # 添加聊天统计
output.append("群组统计:") output.append("群组统计:")
output.append(("群组名称 消息数量")) output.append("群组名称 消息数量")
for group_id, count in sorted(stats["messages_by_chat"].items()): for group_id, count in sorted(stats["messages_by_chat"].items()):
output.append(f"{self.name_dict[group_id][0][:32]:<32} {count:>10}") output.append(f"{self.name_dict[group_id][0][:32]:<32} {count:>10}")

View File

@@ -90,7 +90,8 @@ class Timer:
self.auto_unit = auto_unit self.auto_unit = auto_unit
self.start = None self.start = None
def _validate_types(self, name, storage): @staticmethod
def _validate_types(name, storage):
"""类型检查""" """类型检查"""
if name is not None and not isinstance(name, str): if name is not None and not isinstance(name, str):
raise TimerTypeError("name", "Optional[str]", type(name)) raise TimerTypeError("name", "Optional[str]", type(name))

View File

@@ -77,7 +77,8 @@ class ChineseTypoGenerator:
return normalized_freq return normalized_freq
def _create_pinyin_dict(self): @staticmethod
def _create_pinyin_dict():
""" """
创建拼音到汉字的映射字典 创建拼音到汉字的映射字典
""" """
@@ -95,7 +96,8 @@ class ChineseTypoGenerator:
return pinyin_dict return pinyin_dict
def _is_chinese_char(self, char): @staticmethod
def _is_chinese_char(char):
""" """
判断是否为汉字 判断是否为汉字
""" """
@@ -124,7 +126,8 @@ class ChineseTypoGenerator:
return result return result
def _get_similar_tone_pinyin(self, py): @staticmethod
def _get_similar_tone_pinyin(py):
""" """
获取相似声调的拼音 获取相似声调的拼音
""" """
@@ -211,13 +214,15 @@ class ChineseTypoGenerator:
# 返回概率最高的几个字 # 返回概率最高的几个字
return [char for char, _ in candidates_with_prob[:num_candidates]] return [char for char, _ in candidates_with_prob[:num_candidates]]
def _get_word_pinyin(self, word): @staticmethod
def _get_word_pinyin(word):
""" """
获取词语的拼音列表 获取词语的拼音列表
""" """
return [py[0] for py in pinyin(word, style=Style.TONE3)] return [py[0] for py in pinyin(word, style=Style.TONE3)]
def _segment_sentence(self, sentence): @staticmethod
def _segment_sentence(sentence):
""" """
使用jieba分词返回词语列表 使用jieba分词返回词语列表
""" """
@@ -392,7 +397,8 @@ class ChineseTypoGenerator:
return "".join(result), correction_suggestion return "".join(result), correction_suggestion
def format_typo_info(self, typo_info): @staticmethod
def format_typo_info(typo_info):
""" """
格式化错别字信息 格式化错别字信息

View File

@@ -13,7 +13,7 @@ llmcheck 模式:
import time import time
from loguru import logger from loguru import logger
from ..models.utils_model import LLM_request from ..models.utils_model import LLMRequest
from ...config.config import global_config from ...config.config import global_config
# from ..chat.chat_stream import ChatStream # from ..chat.chat_stream import ChatStream
@@ -61,7 +61,7 @@ def llmcheck_decorator(trigger_condition_func):
class LlmcheckWillingManager(MxpWillingManager): class LlmcheckWillingManager(MxpWillingManager):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.model_v3 = LLM_request(model=global_config.llm_normal, temperature=0.3) self.model_v3 = LLMRequest(model=global_config.llm_normal, temperature=0.3)
async def get_llmreply_probability(self, message_id: str): async def get_llmreply_probability(self, message_id: str):
message_info = self.ongoing_messages[message_id] message_info = self.ongoing_messages[message_id]

View File

@@ -240,7 +240,8 @@ class MxpWillingManager(BaseWillingManager):
-2 * self.basic_maximum_willing * self.fatigue_coefficient -2 * self.basic_maximum_willing * self.fatigue_coefficient
) )
def _willing_to_probability(self, willing: float) -> float: @staticmethod
def _willing_to_probability(willing: float) -> float:
"""意愿值转化为概率""" """意愿值转化为概率"""
willing = max(0, willing) willing = max(0, willing)
if willing < 2: if willing < 2:
@@ -285,7 +286,8 @@ class MxpWillingManager(BaseWillingManager):
if self.is_debug: if self.is_debug:
self.logger.debug(f"聊天流意愿值更新:{self.chat_reply_willing}") self.logger.debug(f"聊天流意愿值更新:{self.chat_reply_willing}")
def _get_relationship_level_num(self, relationship_value) -> int: @staticmethod
def _get_relationship_level_num(relationship_value) -> int:
"""关系等级计算""" """关系等级计算"""
if -1000 <= relationship_value < -227: if -1000 <= relationship_value < -227:
level_num = 0 level_num = 0

View File

@@ -35,12 +35,14 @@ class KnowledgeLibrary:
"""确保必要的目录存在""" """确保必要的目录存在"""
os.makedirs(self.raw_info_dir, exist_ok=True) os.makedirs(self.raw_info_dir, exist_ok=True)
def read_file(self, file_path: str) -> str: @staticmethod
def read_file(file_path: str) -> str:
"""读取文件内容""" """读取文件内容"""
with open(file_path, "r", encoding="utf-8") as f: with open(file_path, "r", encoding="utf-8") as f:
return f.read() return f.read()
def split_content(self, content: str, max_length: int = 512) -> list: @staticmethod
def split_content(content: str, max_length: int = 512) -> list:
"""将内容分割成适当大小的块,按空行分割 """将内容分割成适当大小的块,按空行分割
Args: Args:
@@ -146,7 +148,8 @@ class KnowledgeLibrary:
return result return result
def _update_stats(self, total_stats, result, filename): @staticmethod
def _update_stats(total_stats, result, filename):
"""更新总体统计信息""" """更新总体统计信息"""
if result["status"] == "success": if result["status"] == "success":
total_stats["processed_files"] += 1 total_stats["processed_files"] += 1
@@ -181,7 +184,8 @@ class KnowledgeLibrary:
for filename in stats["skipped_files"]: for filename in stats["skipped_files"]:
self.console.print(f"[yellow]- {filename}[/yellow]") self.console.print(f"[yellow]- {filename}[/yellow]")
def calculate_file_hash(self, file_path): @staticmethod
def calculate_file_hash(file_path):
"""计算文件的MD5哈希值""" """计算文件的MD5哈希值"""
hash_md5 = hashlib.md5() hash_md5 = hashlib.md5()
with open(file_path, "rb") as f: with open(file_path, "rb") as f: