fix: Update type hints to use newer Python syntax

- Replace Dict, List, Optional with dict, list,  < /dev/null |  None syntax
- Fix abstract method implementation in message.py
- Improve type annotations and function return types
- Remove unreachable code in get_current_task_tool.py
- Refactor HTML elements to use style attributes

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
晴猫
2025-05-01 06:55:05 +09:00
parent 3d001da30e
commit 263e8d196a
29 changed files with 125 additions and 143 deletions

View File

@@ -1,6 +1,6 @@
# 麦麦MaiCore-MaiMBot (编辑中) # 麦麦MaiCore-MaiMBot (编辑中)
<br /> <br />
<div align="center"> <div style="text-align: center">
![Python Version](https://img.shields.io/badge/Python-3.10+-blue) ![Python Version](https://img.shields.io/badge/Python-3.10+-blue)
![License](https://img.shields.io/github/license/SengokuCola/MaiMBot?label=协议) ![License](https://img.shields.io/github/license/SengokuCola/MaiMBot?label=协议)
@@ -12,7 +12,7 @@
</div> </div>
<p align="center"> <p style="text-align: center">
<a href="https://github.com/MaiM-with-u/MaiBot/"> <a href="https://github.com/MaiM-with-u/MaiBot/">
<img src="depends-data/maimai.png" alt="Logo" style="width: 200px"> <img src="depends-data/maimai.png" alt="Logo" style="width: 200px">
</a> </a>
@@ -21,8 +21,8 @@
画师略nd 画师略nd
</a> </a>
<h3 align="center">MaiBot(麦麦)</h3> <h3 style="text-align: center">MaiBot(麦麦)</h3>
<p align="center"> <p style="text-align: center">
一款专注于<strong> 群组聊天 </strong>的赛博网友 一款专注于<strong> 群组聊天 </strong>的赛博网友
<br /> <br />
<a href="https://docs.mai-mai.org"><strong>探索本项目的文档 »</strong></a> <a href="https://docs.mai-mai.org"><strong>探索本项目的文档 »</strong></a>
@@ -50,7 +50,7 @@
- 🧠 **持久记忆系统**基于MongoDB的长期记忆存储 - 🧠 **持久记忆系统**基于MongoDB的长期记忆存储
- 🔄 **动态人格系统**:自适应的性格特征 - 🔄 **动态人格系统**:自适应的性格特征
<div align="center"> <div style="text-align: center">
<a href="https://www.bilibili.com/video/BV1amAneGE3P" target="_blank"> <a href="https://www.bilibili.com/video/BV1amAneGE3P" target="_blank">
<img src="depends-data/video.png" style="max-width: 200px" alt="麦麦演示视频"> <img src="depends-data/video.png" style="max-width: 200px" alt="麦麦演示视频">
<br> <br>
@@ -97,9 +97,9 @@
- [四群](https://qm.qq.com/q/wlH5eT8OmQ) 729957033【已满】 - [四群](https://qm.qq.com/q/wlH5eT8OmQ) 729957033【已满】
<div align="left">
<h2>📚 文档 </h2> ## 📚 文档
</div>
### (部分内容可能过时,请注意版本对应) ### (部分内容可能过时,请注意版本对应)

View File

@@ -8,7 +8,6 @@ import sys
import os import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from typing import Dict, List
from src.plugins.knowledge.src.lpmmconfig import PG_NAMESPACE, global_config from src.plugins.knowledge.src.lpmmconfig import PG_NAMESPACE, global_config
from src.plugins.knowledge.src.embedding_store import EmbeddingManager from src.plugins.knowledge.src.embedding_store import EmbeddingManager
@@ -26,8 +25,8 @@ logger = get_module_logger("LPMM知识库-OpenIE导入")
def hash_deduplicate( def hash_deduplicate(
raw_paragraphs: Dict[str, str], raw_paragraphs: dict[str, str],
triple_list_data: Dict[str, List[List[str]]], triple_list_data: dict[str, list[list[str]]],
stored_pg_hashes: set, stored_pg_hashes: set,
stored_paragraph_hashes: set, stored_paragraph_hashes: set,
): ):
@@ -126,7 +125,7 @@ def main():
) )
# 初始化Embedding库 # 初始化Embedding库
embed_manager = embed_manager = EmbeddingManager(llm_client_list[global_config["embedding"]["provider"]]) embed_manager = EmbeddingManager(llm_client_list[global_config["embedding"]["provider"]])
logger.info("正在从文件加载Embedding库") logger.info("正在从文件加载Embedding库")
try: try:
embed_manager.load_from_file() embed_manager.load_from_file()

View File

@@ -1,4 +1,4 @@
from typing import Dict, List, Optional from typing import List, Optional
import strawberry import strawberry
# from packaging.version import Version, InvalidVersion # from packaging.version import Version, InvalidVersion
@@ -128,22 +128,22 @@ class BotConfig:
enable_pfc_chatting: bool # 是否启用PFC聊天 enable_pfc_chatting: bool # 是否启用PFC聊天
# 模型配置 # 模型配置
llm_reasoning: Dict[str, str] # LLM推理 llm_reasoning: dict[str, str] # LLM推理
# llm_reasoning_minor: Dict[str, str] # llm_reasoning_minor: dict[str, str]
llm_normal: Dict[str, str] # LLM普通 llm_normal: dict[str, str] # LLM普通
llm_topic_judge: Dict[str, str] # LLM话题判断 llm_topic_judge: dict[str, str] # LLM话题判断
llm_summary: Dict[str, str] # LLM话题总结 llm_summary: dict[str, str] # LLM话题总结
llm_emotion_judge: Dict[str, str] # LLM情感判断 llm_emotion_judge: dict[str, str] # LLM情感判断
embedding: Dict[str, str] # 嵌入 embedding: dict[str, str] # 嵌入
vlm: Dict[str, str] # VLM vlm: dict[str, str] # VLM
moderation: Dict[str, str] # 审核 moderation: dict[str, str] # 审核
# 实验性 # 实验性
llm_observation: Dict[str, str] # LLM观察 llm_observation: dict[str, str] # LLM观察
llm_sub_heartflow: Dict[str, str] # LLM子心流 llm_sub_heartflow: dict[str, str] # LLM子心流
llm_heartflow: Dict[str, str] # LLM心流 llm_heartflow: dict[str, str] # LLM心流
api_urls: Dict[str, str] # API URLs api_urls: dict[str, str] # API URLs
@strawberry.type @strawberry.type

View File

@@ -1,5 +1,5 @@
from loguru import logger from loguru import logger
from typing import Dict, Optional, Union, List, Tuple from typing import Optional, Union, List, Tuple
import sys import sys
import os import os
from types import ModuleType from types import ModuleType
@@ -75,8 +75,8 @@ if default_handler_id is not None:
LoguruLogger = logger.__class__ LoguruLogger = logger.__class__
# 全局注册表记录模块与处理器ID的映射 # 全局注册表记录模块与处理器ID的映射
_handler_registry: Dict[str, List[int]] = {} _handler_registry: dict[str, List[int]] = {}
_custom_style_handlers: Dict[Tuple[str, str], List[int]] = {} # 记录自定义样式处理器ID _custom_style_handlers: dict[Tuple[str, str], List[int]] = {} # 记录自定义样式处理器ID
# 获取日志存储根地址 # 获取日志存储根地址
current_file_path = Path(__file__).resolve() current_file_path = Path(__file__).resolve()

View File

@@ -7,11 +7,11 @@ logger = get_module_logger(__name__)
def find_messages( def find_messages(
message_filter: Dict[str, Any], message_filter: dict[str, Any],
sort: Optional[List[tuple[str, int]]] = None, sort: Optional[List[tuple[str, int]]] = None,
limit: int = 0, limit: int = 0,
limit_mode: str = "latest", limit_mode: str = "latest",
) -> List[Dict[str, Any]]: ) -> List[dict[str, Any]]:
""" """
根据提供的过滤器、排序和限制条件查找消息。 根据提供的过滤器、排序和限制条件查找消息。
@@ -26,7 +26,7 @@ def find_messages(
""" """
try: try:
query = db.messages.find(message_filter) query = db.messages.find(message_filter)
results: List[Dict[str, Any]] = [] results: List[dict[str, Any]] = []
if limit > 0: if limit > 0:
if limit_mode == "earliest": if limit_mode == "earliest":
@@ -56,7 +56,7 @@ def find_messages(
return [] return []
def count_messages(message_filter: Dict[str, Any]) -> int: def count_messages(message_filter: dict[str, Any]) -> int:
""" """
根据提供的过滤器计算消息数量。 根据提供的过滤器计算消息数量。

View File

@@ -271,8 +271,8 @@ class BotConfig:
enable_pfc_chatting: bool = False # 是否启用PFC聊天 enable_pfc_chatting: bool = False # 是否启用PFC聊天
# 模型配置 # 模型配置
llm_reasoning: Dict[str, str] = field(default_factory=lambda: {}) llm_reasoning: dict[str, str] = field(default_factory=lambda: {})
# llm_reasoning_minor: Dict[str, str] = field(default_factory=lambda: {}) # llm_reasoning_minor: dict[str, str] = field(default_factory=lambda: {})
llm_normal: Dict[str, str] = field(default_factory=lambda: {}) llm_normal: Dict[str, str] = field(default_factory=lambda: {})
llm_topic_judge: Dict[str, str] = field(default_factory=lambda: {}) llm_topic_judge: Dict[str, str] = field(default_factory=lambda: {})
llm_summary: Dict[str, str] = field(default_factory=lambda: {}) llm_summary: Dict[str, str] = field(default_factory=lambda: {})

View File

@@ -3,7 +3,7 @@ from src.config.config import global_config
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
from src.plugins.moods.moods import MoodManager from src.plugins.moods.moods import MoodManager
from typing import Dict, Any from typing import Any
logger = get_logger("change_mood_tool") logger = get_logger("change_mood_tool")
@@ -22,7 +22,7 @@ class ChangeMoodTool(BaseTool):
"required": ["text", "response_set"], "required": ["text", "response_set"],
} }
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]: async def execute(self, function_args: dict[str, Any], message_txt: str = "") -> dict[str, Any]:
"""执行心情改变 """执行心情改变
Args: Args:
@@ -30,7 +30,7 @@ class ChangeMoodTool(BaseTool):
message_txt: 原始消息文本 message_txt: 原始消息文本
Returns: Returns:
Dict: 工具执行结果 dict: 工具执行结果
""" """
try: try:
response_set = function_args.get("response_set") response_set = function_args.get("response_set")

View File

@@ -19,7 +19,7 @@ class RelationshipTool(BaseTool):
"required": ["text", "changed_value", "reason"], "required": ["text", "changed_value", "reason"],
} }
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> dict: async def execute(self, function_args: dict[str, Any], message_txt: str = "") -> dict:
"""执行工具功能 """执行工具功能
Args: Args:

View File

@@ -1,7 +1,7 @@
from src.do_tool.tool_can_use.base_tool import BaseTool from src.do_tool.tool_can_use.base_tool import BaseTool
from src.plugins.schedule.schedule_generator import bot_schedule from src.plugins.schedule.schedule_generator import bot_schedule
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
from typing import Dict, Any from typing import Any
from datetime import datetime from datetime import datetime
logger = get_module_logger("get_current_task_tool") logger = get_module_logger("get_current_task_tool")
@@ -21,7 +21,7 @@ class GetCurrentTaskTool(BaseTool):
"required": ["start_time", "end_time"], "required": ["start_time", "end_time"],
} }
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]: async def execute(self, function_args: dict[str, Any], message_txt: str = "") -> dict[str, Any]:
"""执行获取当前任务或指定时间段的日程信息 """执行获取当前任务或指定时间段的日程信息
Args: Args:
@@ -29,7 +29,7 @@ class GetCurrentTaskTool(BaseTool):
message_txt: 原始消息文本,此工具不使用 message_txt: 原始消息文本,此工具不使用
Returns: Returns:
Dict: 工具执行结果 dict: 工具执行结果
""" """
start_time = function_args.get("start_time") start_time = function_args.get("start_time")
end_time = function_args.get("end_time") end_time = function_args.get("end_time")
@@ -55,5 +55,6 @@ class GetCurrentTaskTool(BaseTool):
task_info = "\n".join(task_list) task_info = "\n".join(task_list)
else: else:
task_info = f"{start_time}{end_time} 之间没有找到日程信息" task_info = f"{start_time}{end_time} 之间没有找到日程信息"
else:
task_info = "请提供有效的开始时间和结束时间"
return {"name": "get_current_task", "content": f"日程信息: {task_info}"} return {"name": "get_current_task", "content": f"日程信息: {task_info}"}

View File

@@ -1,6 +1,6 @@
from src.do_tool.tool_can_use.base_tool import BaseTool from src.do_tool.tool_can_use.base_tool import BaseTool
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
from typing import Dict, Any from typing import Any
logger = get_module_logger("get_mid_memory_tool") logger = get_module_logger("get_mid_memory_tool")
@@ -18,7 +18,7 @@ class GetMidMemoryTool(BaseTool):
"required": ["id"], "required": ["id"],
} }
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]: async def execute(self, function_args: dict[str, Any], message_txt: str = "") -> dict[str, Any]:
"""执行记忆获取 """执行记忆获取
Args: Args:
@@ -26,7 +26,7 @@ class GetMidMemoryTool(BaseTool):
message_txt: 原始消息文本 message_txt: 原始消息文本
Returns: Returns:
Dict: 工具执行结果 dict: 工具执行结果
""" """
try: try:
id = function_args.get("id") id = function_args.get("id")

View File

@@ -17,7 +17,7 @@ class SendEmojiTool(BaseTool):
"required": ["text"], "required": ["text"],
} }
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]: async def execute(self, function_args: dict[str, Any], message_txt: str = "") -> dict[str, Any]:
text = function_args.get("text", message_txt) text = function_args.get("text", message_txt)
return { return {
"name": "send_emoji", "name": "send_emoji",

View File

@@ -42,7 +42,7 @@ class MyNewTool(BaseTool):
message_txt: 原始消息文本 message_txt: 原始消息文本
Returns: Returns:
Dict: 包含执行结果的字典必须包含name和content字段 dict: 包含执行结果的字典必须包含name和content字段
""" """
# 实现工具逻辑 # 实现工具逻辑
result = f"工具执行结果: {function_args.get('param1')}" result = f"工具执行结果: {function_args.get('param1')}"

View File

@@ -22,11 +22,11 @@ class BaseTool:
parameters = None parameters = None
@classmethod @classmethod
def get_tool_definition(cls) -> Dict[str, Any]: def get_tool_definition(cls) -> dict[str, Any]:
"""获取工具定义用于LLM工具调用 """获取工具定义用于LLM工具调用
Returns: Returns:
Dict: 工具定义字典 dict: 工具定义字典
""" """
if not cls.name or not cls.description or not cls.parameters: if not cls.name or not cls.description or not cls.parameters:
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性") raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性")
@@ -36,14 +36,14 @@ class BaseTool:
"function": {"name": cls.name, "description": cls.description, "parameters": cls.parameters}, "function": {"name": cls.name, "description": cls.description, "parameters": cls.parameters},
} }
async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]: async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
"""执行工具函数 """执行工具函数
Args: Args:
function_args: 工具调用参数 function_args: 工具调用参数
Returns: Returns:
Dict: 工具执行结果 dict: 工具执行结果
""" """
raise NotImplementedError("子类必须实现execute方法") raise NotImplementedError("子类必须实现execute方法")
@@ -88,11 +88,11 @@ def discover_tools():
logger.info(f"工具发现完成,共注册 {len(TOOL_REGISTRY)} 个工具") logger.info(f"工具发现完成,共注册 {len(TOOL_REGISTRY)} 个工具")
def get_all_tool_definitions() -> List[Dict[str, Any]]: def get_all_tool_definitions() -> List[dict[str, Any]]:
"""获取所有已注册工具的定义 """获取所有已注册工具的定义
Returns: Returns:
List[Dict]: 工具定义列表 List[dict]: 工具定义列表
""" """
return [tool_class().get_tool_definition() for tool_class in TOOL_REGISTRY.values()] return [tool_class().get_tool_definition() for tool_class in TOOL_REGISTRY.values()]

View File

@@ -1,6 +1,6 @@
from src.do_tool.tool_can_use.base_tool import BaseTool from src.do_tool.tool_can_use.base_tool import BaseTool
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
from typing import Dict, Any from typing import Any
logger = get_module_logger("compare_numbers_tool") logger = get_module_logger("compare_numbers_tool")
@@ -19,15 +19,14 @@ class CompareNumbersTool(BaseTool):
"required": ["num1", "num2"], "required": ["num1", "num2"],
} }
async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]: async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
"""执行比较两个数的大小 """执行比较两个数的大小
Args: Args:
function_args: 工具参数 function_args: 工具参数
message_txt: 原始消息文本
Returns: Returns:
Dict: 工具执行结果 dict: 工具执行结果
""" """
try: try:
num1 = function_args.get("num1") num1 = function_args.get("num1")

View File

@@ -2,7 +2,7 @@ from src.do_tool.tool_can_use.base_tool import BaseTool
from src.plugins.chat.utils import get_embedding from src.plugins.chat.utils import get_embedding
from src.common.database import db from src.common.database import db
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
from typing import Dict, Any, Union from typing import Any, Union
logger = get_logger("get_knowledge_tool") logger = get_logger("get_knowledge_tool")
@@ -21,15 +21,14 @@ class SearchKnowledgeTool(BaseTool):
"required": ["query"], "required": ["query"],
} }
async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]: async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
"""执行知识库搜索 """执行知识库搜索
Args: Args:
function_args: 工具参数 function_args: 工具参数
message_txt: 原始消息文本
Returns: Returns:
Dict: 工具执行结果 dict: 工具执行结果
""" """
try: try:
query = function_args.get("query") query = function_args.get("query")

View File

@@ -25,7 +25,6 @@ class GetMemoryTool(BaseTool):
Args: Args:
function_args: 工具参数 function_args: 工具参数
message_txt: 原始消息文本
Returns: Returns:
Dict: 工具执行结果 Dict: 工具执行结果

View File

@@ -22,7 +22,6 @@ class GetCurrentDateTimeTool(BaseTool):
Args: Args:
function_args: 工具参数(此工具不使用) function_args: 工具参数(此工具不使用)
message_txt: 原始消息文本(此工具不使用)
Returns: Returns:
Dict: 工具执行结果 Dict: 工具执行结果

View File

@@ -29,7 +29,6 @@ class SearchKnowledgeFromLPMMTool(BaseTool):
Args: Args:
function_args: 工具参数 function_args: 工具参数
message_txt: 原始消息文本
Returns: Returns:
Dict: 工具执行结果 Dict: 工具执行结果

View File

@@ -106,7 +106,6 @@ class ToolUser:
Args: Args:
message_txt: 用户消息文本 message_txt: 用户消息文本
sender_name: 发送者名称
chat_stream: 聊天流对象 chat_stream: 聊天流对象
observation: 观察对象(可选) observation: 观察对象(可选)

View File

@@ -1,9 +1,9 @@
import json import json
from typing import Dict
import os import os
from typing import Any
def load_scenes() -> Dict: def load_scenes() -> dict[str, Any]:
""" """
从JSON文件加载场景数据 从JSON文件加载场景数据
@@ -20,7 +20,7 @@ def load_scenes() -> Dict:
PERSONALITY_SCENES = load_scenes() PERSONALITY_SCENES = load_scenes()
def get_scene_by_factor(factor: str) -> Dict: def get_scene_by_factor(factor: str) -> dict | None:
""" """
根据人格因子获取对应的情景测试 根据人格因子获取对应的情景测试
@@ -28,12 +28,12 @@ def get_scene_by_factor(factor: str) -> Dict:
factor (str): 人格因子名称 factor (str): 人格因子名称
Returns: Returns:
Dict: 包含情景描述的字典 dict: 包含情景描述的字典
""" """
return PERSONALITY_SCENES.get(factor, None) return PERSONALITY_SCENES.get(factor,None)
def get_all_scenes() -> Dict: def get_all_scenes() -> dict:
""" """
获取所有情景测试 获取所有情景测试

View File

@@ -23,6 +23,7 @@ class ChatObserver:
Args: Args:
stream_id: 聊天流ID stream_id: 聊天流ID
private_name: 私聊名称
Returns: Returns:
ChatObserver: 观察器实例 ChatObserver: 观察器实例

View File

@@ -33,6 +33,7 @@ class PFCManager:
Args: Args:
stream_id: 聊天流ID stream_id: 聊天流ID
private_name: 私聊名称
Returns: Returns:
Optional[Conversation]: 对话实例创建失败则返回None Optional[Conversation]: 对话实例创建失败则返回None

View File

@@ -18,6 +18,7 @@ def get_items_from_json(
Args: Args:
content: 包含JSON的文本 content: 包含JSON的文本
private_name: 私聊名称
*items: 要提取的字段名 *items: 要提取的字段名
default_values: 字段的默认值,格式为 {字段名: 默认值} default_values: 字段的默认值,格式为 {字段名: 默认值}
required_types: 字段的必需类型,格式为 {字段名: 类型} required_types: 字段的必需类型,格式为 {字段名: 类型}

View File

@@ -29,6 +29,8 @@ class ReplyChecker:
Args: Args:
reply: 生成的回复 reply: 生成的回复
goal: 对话目标 goal: 对话目标
chat_history: 对话历史记录
chat_history_text: 对话历史记录文本
retry_count: 当前重试次数 retry_count: 当前重试次数
Returns: Returns:

View File

@@ -1,6 +1,7 @@
import time import time
from abc import abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional, Union from typing import Optional, Any
import urllib3 import urllib3
@@ -58,12 +59,37 @@ class Message(MessageBase):
# 回复消息 # 回复消息
self.reply = reply self.reply = reply
async def _process_message_segments(self, segment: Seg) -> str:
"""递归处理消息段,转换为文字描述
Args:
segment: 要处理的消息段
Returns:
str: 处理后的文本
"""
if segment.type == "seglist":
# 处理消息段列表
segments_text = []
for seg in segment.data:
processed = await self._process_message_segments(seg)
if processed:
segments_text.append(processed)
return " ".join(segments_text)
else:
# 处理单个消息段
return await self._process_single_segment(segment)
@abstractmethod
async def _process_single_segment(self, segment):
pass
@dataclass @dataclass
class MessageRecv(Message): class MessageRecv(Message):
"""接收消息类用于处理从MessageCQ序列化的消息""" """接收消息类用于处理从MessageCQ序列化的消息"""
def __init__(self, message_dict: Dict): def __init__(self, message_dict: dict[str, Any]):
"""从MessageCQ的字典初始化 """从MessageCQ的字典初始化
Args: Args:
@@ -90,26 +116,7 @@ class MessageRecv(Message):
self.processed_plain_text = await self._process_message_segments(self.message_segment) self.processed_plain_text = await self._process_message_segments(self.message_segment)
self.detailed_plain_text = self._generate_detailed_text() self.detailed_plain_text = self._generate_detailed_text()
async def _process_message_segments(self, segment: Seg) -> str:
"""递归处理消息段,转换为文字描述
Args:
segment: 要处理的消息段
Returns:
str: 处理后的文本
"""
if segment.type == "seglist":
# 处理消息段列表
segments_text = []
for seg in segment.data:
processed = await self._process_message_segments(seg)
if processed:
segments_text.append(processed)
return " ".join(segments_text)
else:
# 处理单个消息段
return await self._process_single_segment(segment)
async def _process_single_segment(self, seg: Seg) -> str: async def _process_single_segment(self, seg: Seg) -> str:
"""处理单个消息段 """处理单个消息段
@@ -179,28 +186,7 @@ class MessageProcessBase(Message):
self.thinking_time = round(time.time() - self.thinking_start_time, 2) self.thinking_time = round(time.time() - self.thinking_start_time, 2)
return self.thinking_time return self.thinking_time
async def _process_message_segments(self, segment: Seg) -> str: async def _process_single_segment(self, seg: Seg) -> str | None:
"""递归处理消息段,转换为文字描述
Args:
segment: 要处理的消息段
Returns:
str: 处理后的文本
"""
if segment.type == "seglist":
# 处理消息段列表
segments_text = []
for seg in segment.data:
processed = await self._process_message_segments(seg)
if processed:
segments_text.append(processed)
return " ".join(segments_text)
else:
# 处理单个消息段
return await self._process_single_segment(segment)
async def _process_single_segment(self, seg: Seg) -> Union[str, None]:
"""处理单个消息段 """处理单个消息段
Args: Args:
@@ -278,7 +264,7 @@ class MessageSending(MessageProcessBase):
message_id: str, message_id: str,
chat_stream: ChatStream, chat_stream: ChatStream,
bot_user_info: UserInfo, bot_user_info: UserInfo,
sender_info: UserInfo, # 用来记录发送者信息,用于私聊回复 sender_info: UserInfo | None, # 用来记录发送者信息,用于私聊回复
message_segment: Seg, message_segment: Seg,
reply: Optional["MessageRecv"] = None, reply: Optional["MessageRecv"] = None,
is_head: bool = False, is_head: bool = False,
@@ -303,7 +289,7 @@ class MessageSending(MessageProcessBase):
self.is_emoji = is_emoji self.is_emoji = is_emoji
self.apply_set_reply_logic = apply_set_reply_logic self.apply_set_reply_logic = apply_set_reply_logic
def set_reply(self, reply: Optional["MessageRecv"] = None) -> None: def set_reply(self, reply: Optional["MessageRecv"] = None):
"""设置回复消息""" """设置回复消息"""
if self.message_info.format_info is not None and "reply" in self.message_info.format_info.accept_format: if self.message_info.format_info is not None and "reply" in self.message_info.format_info.accept_format:
if reply: if reply:
@@ -317,7 +303,6 @@ class MessageSending(MessageProcessBase):
self.message_segment, self.message_segment,
], ],
) )
return self
async def process(self) -> None: async def process(self) -> None:
"""处理消息内容,生成纯文本和详细文本""" """处理消息内容,生成纯文本和详细文本"""
@@ -342,6 +327,7 @@ class MessageSending(MessageProcessBase):
reply=thinking.reply, reply=thinking.reply,
is_head=is_head, is_head=is_head,
is_emoji=is_emoji, is_emoji=is_emoji,
sender_info=None,
) )
def to_dict(self): def to_dict(self):
@@ -361,7 +347,7 @@ class MessageSet:
def __init__(self, chat_stream: ChatStream, message_id: str): def __init__(self, chat_stream: ChatStream, message_id: str):
self.chat_stream = chat_stream self.chat_stream = chat_stream
self.message_id = message_id self.message_id = message_id
self.messages: List[MessageSending] = [] self.messages: list[MessageSending] = []
self.time = round(time.time(), 3) # 保留3位小数 self.time = round(time.time(), 3) # 保留3位小数
def add_message(self, message: MessageSending) -> None: def add_message(self, message: MessageSending) -> None:

View File

@@ -1,7 +1,7 @@
# src/plugins/chat/message_sender.py # src/plugins/chat/message_sender.py
import asyncio import asyncio
import time import time
from typing import Dict, List, Optional, Union from typing import Union
# from ...common.database import db # 数据库依赖似乎不需要了,注释掉 # from ...common.database import db # 数据库依赖似乎不需要了,注释掉
from ..message.api import global_api from ..message.api import global_api
@@ -70,7 +70,7 @@ class MessageContainer:
def __init__(self, chat_id: str, max_size: int = 100): def __init__(self, chat_id: str, max_size: int = 100):
self.chat_id = chat_id self.chat_id = chat_id
self.max_size = max_size self.max_size = max_size
self.messages: List[Union[MessageThinking, MessageSending]] = [] # 明确类型 self.messages: list[MessageThinking | MessageSending] = [] # 明确类型
self.last_send_time = 0 self.last_send_time = 0
self.thinking_wait_timeout = 20 # 思考等待超时时间(秒) - 从旧 sender 合并 self.thinking_wait_timeout = 20 # 思考等待超时时间(秒) - 从旧 sender 合并
@@ -78,7 +78,7 @@ class MessageContainer:
"""计算当前容器中思考消息的数量""" """计算当前容器中思考消息的数量"""
return sum(1 for msg in self.messages if isinstance(msg, MessageThinking)) return sum(1 for msg in self.messages if isinstance(msg, MessageThinking))
def get_timeout_sending_messages(self) -> List[MessageSending]: def get_timeout_sending_messages(self) -> list[MessageSending]:
"""获取所有超时的MessageSending对象思考时间超过20秒按thinking_start_time排序 - 从旧 sender 合并""" """获取所有超时的MessageSending对象思考时间超过20秒按thinking_start_time排序 - 从旧 sender 合并"""
current_time = time.time() current_time = time.time()
timeout_messages = [] timeout_messages = []
@@ -94,7 +94,7 @@ class MessageContainer:
timeout_messages.sort(key=lambda x: x.thinking_start_time) timeout_messages.sort(key=lambda x: x.thinking_start_time)
return timeout_messages return timeout_messages
def get_earliest_message(self) -> Optional[Union[MessageThinking, MessageSending]]: def get_earliest_message(self):
"""获取thinking_start_time最早的消息对象""" """获取thinking_start_time最早的消息对象"""
if not self.messages: if not self.messages:
return None return None
@@ -108,7 +108,7 @@ class MessageContainer:
earliest_message = msg earliest_message = msg
return earliest_message return earliest_message
def add_message(self, message: Union[MessageThinking, MessageSending, MessageSet]) -> None: def add_message(self, message: Union[MessageThinking, MessageSending, MessageSet]):
"""添加消息到队列""" """添加消息到队列"""
if isinstance(message, MessageSet): if isinstance(message, MessageSet):
for single_message in message.messages: for single_message in message.messages:
@@ -116,7 +116,7 @@ class MessageContainer:
else: else:
self.messages.append(message) self.messages.append(message)
def remove_message(self, message_to_remove: Union[MessageThinking, MessageSending]) -> bool: def remove_message(self, message_to_remove: Union[MessageThinking, MessageSending]):
"""移除指定的消息对象如果消息存在则返回True否则返回False""" """移除指定的消息对象如果消息存在则返回True否则返回False"""
try: try:
_initial_len = len(self.messages) _initial_len = len(self.messages)
@@ -138,7 +138,7 @@ class MessageContainer:
"""检查是否有待发送的消息""" """检查是否有待发送的消息"""
return bool(self.messages) return bool(self.messages)
def get_all_messages(self) -> List[Union[MessageSending, MessageThinking]]: def get_all_messages(self) -> list[MessageThinking | MessageSending]:
"""获取所有消息""" """获取所有消息"""
return list(self.messages) # 返回副本 return list(self.messages) # 返回副本
@@ -148,7 +148,7 @@ class MessageManager:
def __init__(self): def __init__(self):
self._processor_task = None self._processor_task = None
self.containers: Dict[str, MessageContainer] = {} self.containers: dict[str, MessageContainer] = {}
self.storage = MessageStorage() # 添加 storage 实例 self.storage = MessageStorage() # 添加 storage 实例
self._running = True # 处理器运行状态 self._running = True # 处理器运行状态
self._container_lock = asyncio.Lock() # 保护 containers 字典的锁 self._container_lock = asyncio.Lock() # 保护 containers 字典的锁

View File

@@ -2,7 +2,6 @@ import random
import time import time
import re import re
from collections import Counter from collections import Counter
from typing import Dict, List, Optional
import jieba import jieba
import numpy as np import numpy as np
@@ -26,7 +25,7 @@ def is_english_letter(char: str) -> bool:
return "a" <= char.lower() <= "z" return "a" <= char.lower() <= "z"
def db_message_to_str(message_dict: Dict) -> str: def db_message_to_str(message_dict: dict) -> str:
logger.debug(f"message_dict: {message_dict}") logger.debug(f"message_dict: {message_dict}")
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message_dict["time"])) time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message_dict["time"]))
try: try:
@@ -35,7 +34,7 @@ def db_message_to_str(message_dict: Dict) -> str:
message_dict.get("user_nickname", ""), message_dict.get("user_nickname", ""),
message_dict.get("user_cardname", ""), message_dict.get("user_cardname", ""),
) )
except Exception: except Exception as e:
name = message_dict.get("user_nickname", "") or f"用户{message_dict['user_id']}" name = message_dict.get("user_nickname", "") or f"用户{message_dict['user_id']}"
content = message_dict.get("processed_plain_text", "") content = message_dict.get("processed_plain_text", "")
result = f"[{time_str}] {name}: {content}\n" result = f"[{time_str}] {name}: {content}\n"
@@ -77,13 +76,13 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
if not is_mentioned: if not is_mentioned:
# 判断是否被回复 # 判断是否被回复
if re.match( if re.match(
f"\[回复 [\s\S]*?\({str(global_config.BOT_QQ)}\)[\s\S]*?\],说:", message.processed_plain_text f"\[回复 [\s\S]*?\({str(global_config.BOT_QQ)}\)[\s\S]*?],说:" , message.processed_plain_text
): ):
is_mentioned = True is_mentioned = True
else: else:
# 判断内容中是否被提及 # 判断内容中是否被提及
message_content = re.sub(r"@[\s\S]*?(\d+)", "", message.processed_plain_text) message_content = re.sub(r"@[\s\S]*?(\d+)", "", message.processed_plain_text)
message_content = re.sub(r"\[回复 [\s\S]*?\(((\d+)|未知id)\)[\s\S]*?\],说:", "", message_content) message_content = re.sub(r"\[回复 [\s\S]*?\(((\d+)|未知id)\)[\s\S]*?],说:", "", message_content)
for keyword in keywords: for keyword in keywords:
if keyword in message_content: if keyword in message_content:
is_mentioned = True is_mentioned = True
@@ -223,7 +222,7 @@ def get_recent_group_speaker(chat_stream_id: int, sender, limit: int = 12) -> li
return who_chat_in_group return who_chat_in_group
def split_into_sentences_w_remove_punctuation(text: str) -> List[str]: def split_into_sentences_w_remove_punctuation(text: str) -> list[str]:
"""将文本分割成句子,并根据概率合并 """将文本分割成句子,并根据概率合并
1. 识别分割点(, 。 ; 空格),但如果分割点左右都是英文字母则不分割。 1. 识别分割点(, 。 ; 空格),但如果分割点左右都是英文字母则不分割。
2. 将文本分割成 (内容, 分隔符) 的元组。 2. 将文本分割成 (内容, 分隔符) 的元组。
@@ -370,7 +369,7 @@ def random_remove_punctuation(text: str) -> str:
return result return result
def process_llm_response(text: str) -> List[str]: def process_llm_response(text: str) -> list[str]:
# 先保护颜文字 # 先保护颜文字
if global_config.enable_kaomoji_protection: if global_config.enable_kaomoji_protection:
protected_text, kaomoji_mapping = protect_kaomoji(text) protected_text, kaomoji_mapping = protect_kaomoji(text)
@@ -379,7 +378,7 @@ def process_llm_response(text: str) -> List[str]:
protected_text = text protected_text = text
kaomoji_mapping = {} kaomoji_mapping = {}
# 提取被 () 或 [] 包裹且包含中文的内容 # 提取被 () 或 [] 包裹且包含中文的内容
pattern = re.compile(r"[\(\[\](?=.*[\u4e00-\u9fff]).*?[\)\]\]") pattern = re.compile(r"[(\[](?=.*[一-鿿]).*?[)\]]")
# _extracted_contents = pattern.findall(text) # _extracted_contents = pattern.findall(text)
_extracted_contents = pattern.findall(protected_text) # 在保护后的文本上查找 _extracted_contents = pattern.findall(protected_text) # 在保护后的文本上查找
# 去除 () 和 [] 及其包裹的内容 # 去除 () 和 [] 及其包裹的内容
@@ -554,7 +553,7 @@ def protect_kaomoji(sentence):
r"[^()\[\]()【】]*?" # 非括号字符(惰性匹配) r"[^()\[\]()【】]*?" # 非括号字符(惰性匹配)
r"[^一-龥a-zA-Z0-9\s]" # 非中文、非英文、非数字、非空格字符(必须包含至少一个) r"[^一-龥a-zA-Z0-9\s]" # 非中文、非英文、非数字、非空格字符(必须包含至少一个)
r"[^()\[\]()【】]*?" # 非括号字符(惰性匹配) r"[^()\[\]()【】]*?" # 非括号字符(惰性匹配)
r"[\)\])】" # 右括号 r"[)\])】" # 右括号
r"]" r"]"
r")" r")"
r"|" r"|"
@@ -704,7 +703,7 @@ def count_messages_between(start_time: float, end_time: float, stream_id: str) -
return 0, 0 return 0, 0
def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal") -> Optional[str]: def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal") -> str:
"""将时间戳转换为人类可读的时间格式 """将时间戳转换为人类可读的时间格式
Args: Args:
@@ -732,10 +731,9 @@ def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal"
return f"{int(diff / 86400)}天前:\n" return f"{int(diff / 86400)}天前:\n"
else: else:
return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(timestamp)) + ":\n" return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(timestamp)) + ":\n"
elif mode == "lite": else: # mode = "lite" or unknown
# 只返回时分秒格式,喵~ # 只返回时分秒格式,喵~
return time.strftime("%H:%M:%S", time.localtime(timestamp)) return time.strftime("%H:%M:%S", time.localtime(timestamp))
return None
def parse_text_timestamps(text: str, mode: str = "normal") -> str: def parse_text_timestamps(text: str, mode: str = "normal") -> str:

View File

@@ -511,7 +511,7 @@ class Hippocampus:
"""从文本中提取关键词并获取相关记忆。 """从文本中提取关键词并获取相关记忆。
Args: Args:
topic (str): 记忆主题 keywords (list): 输入文本
max_memory_num (int, optional): 返回的记忆条目数量上限。默认为3表示最多返回3条与输入文本相关度最高的记忆。 max_memory_num (int, optional): 返回的记忆条目数量上限。默认为3表示最多返回3条与输入文本相关度最高的记忆。
max_memory_length (int, optional): 每个主题最多返回的记忆条目数量。默认为2表示每个主题最多返回2条相似度最高的记忆。 max_memory_length (int, optional): 每个主题最多返回的记忆条目数量。默认为2表示每个主题最多返回2条相似度最高的记忆。
max_depth (int, optional): 记忆检索深度。默认为3。值越大检索范围越广可以获取更多间接相关的记忆但速度会变慢。 max_depth (int, optional): 记忆检索深度。默认为3。值越大检索范围越广可以获取更多间接相关的记忆但速度会变慢。
@@ -829,7 +829,7 @@ class EntorhinalCortex:
return chat_samples return chat_samples
@staticmethod @staticmethod
def random_get_msg_snippet(target_timestamp: float, chat_size: int, max_memorized_time_per_msg: int) -> list: def random_get_msg_snippet(target_timestamp: float, chat_size: int, max_memorized_time_per_msg: int) -> list | None:
"""从数据库中随机获取指定时间戳附近的消息片段 (使用 chat_message_builder)""" """从数据库中随机获取指定时间戳附近的消息片段 (使用 chat_message_builder)"""
try_count = 0 try_count = 0
time_window_seconds = random.randint(300, 1800) # 随机时间窗口5到30分钟 time_window_seconds = random.randint(300, 1800) # 随机时间窗口5到30分钟

View File

@@ -1,7 +1,6 @@
import datetime import datetime
import os import os
import sys import sys
from typing import Dict
import asyncio import asyncio
from dateutil import tz from dateutil import tz
@@ -162,7 +161,7 @@ class ScheduleGenerator:
async def generate_daily_schedule( async def generate_daily_schedule(
self, self,
target_date: datetime.datetime = None, target_date: datetime.datetime = None,
) -> Dict[str, str]: ) -> dict[str, str]:
daytime_prompt = self.construct_daytime_prompt(target_date) daytime_prompt = self.construct_daytime_prompt(target_date)
daytime_response, _ = await self.llm_scheduler_all.generate_response_async(daytime_prompt) daytime_response, _ = await self.llm_scheduler_all.generate_response_async(daytime_prompt)
return daytime_response return daytime_response