tcmofashi
2025-04-17 15:53:26 +08:00
52 changed files with 299 additions and 336 deletions

View File

@@ -119,7 +119,6 @@ class ChatObserver:
self.last_cold_chat_check = current_time
# 判断是否冷场
is_cold = False
if self.last_message_time is None:
is_cold = True
else:

View File

@@ -113,7 +113,8 @@ class Conversation:
return True
return False
def _convert_to_message(self, msg_dict: Dict[str, Any]) -> Message:
@staticmethod
def _convert_to_message(msg_dict: Dict[str, Any]) -> Message:
"""将消息字典转换为Message对象"""
try:
chat_info = msg_dict.get("chat_info", {})
@@ -123,7 +124,7 @@ class Conversation:
return Message(
message_id=msg_dict["message_id"],
chat_stream=chat_stream,
time=msg_dict["time"],
timestamp=msg_dict["time"],
user_info=user_info,
processed_plain_text=msg_dict.get("processed_plain_text", ""),
detailed_plain_text=msg_dict.get("detailed_plain_text", ""),
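
Most hunks in this commit follow the same mechanical pattern: a method whose body never reads `self` is turned into a `@staticmethod` and the `self` parameter is dropped, as in `_convert_to_message` above. (The second hunk above also switches the `time=` keyword to `timestamp=`, matching the `Message` constructor rename further down.) A minimal sketch of the before/after shape; the `Converter` class and `to_upper` name are illustrative, not from this repo:

class Converter:
    # Before: an instance method whose body never touches `self`.
    # def to_upper(self, text: str) -> str:
    #     return text.upper()

    # After: same body, decorated with @staticmethod and without `self`.
    @staticmethod
    def to_upper(text: str) -> str:
        return text.upper()


# Call sites do not change: both the class and an instance work.
assert Converter.to_upper("abc") == "ABC"
assert Converter().to_upper("abc") == "ABC"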

View File

@@ -15,8 +15,8 @@ class DirectMessageSender:
def __init__(self):
pass
@staticmethod
async def send_message(
self,
chat_stream: ChatStream,
content: str,
reply_to_message: Optional[Message] = None,
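
`DirectMessageSender.send_message` gets the same treatment even though it is a coroutine: `@staticmethod` composes with `async def` without changing how callers await it. A small runnable sketch; the `Sender` class and its body are placeholders, not the project's API:

import asyncio


class Sender:
    @staticmethod
    async def send_message(content: str) -> str:
        # Placeholder body; the real method forwards to the chat stream / API.
        return f"sent: {content}"


# Awaiting via the class or via an instance behaves the same way.
print(asyncio.run(Sender.send_message("hi")))    # sent: hi
print(asyncio.run(Sender().send_message("hi")))  # sent: hi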

View File

@@ -51,11 +51,9 @@ class MongoDBMessageStorage(MessageStorage):
"""MongoDB消息存储实现"""
async def get_messages_after(self, chat_id: str, message_time: float) -> List[Dict[str, Any]]:
query = {"chat_id": chat_id}
query = {"chat_id": chat_id, "time": {"$gt": message_time}}
# print(f"storage_check_message: {message_time}")
query["time"] = {"$gt": message_time}
return list(db.messages.find(query).sort("time", 1))
async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]:
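
In `MongoDBMessageStorage`, the `time` filter is now folded into the query dict at construction instead of being added by a second assignment; the resulting query is identical. A standalone sketch with a pymongo-style handle (`db.messages` and the field names mirror the hunk above; the free-standing helper is made up for illustration):

from typing import Any, Dict, List


def get_messages_after(db, chat_id: str, message_time: float) -> List[Dict[str, Any]]:
    # One dict literal instead of building the query in two steps.
    query = {"chat_id": chat_id, "time": {"$gt": message_time}}
    return list(db.messages.find(query).sort("time", 1))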

View File

@@ -160,16 +160,16 @@ class GoalAnalyzer:
# 返回第一个目标作为当前主要目标(如果有)
if result:
first_goal = result[0]
return (first_goal.get("goal", ""), "", first_goal.get("reasoning", ""))
return first_goal.get("goal", ""), "", first_goal.get("reasoning", "")
else:
# 单个目标的情况
goal = result.get("goal", "")
reasoning = result.get("reasoning", "")
conversation_info.goal_list.append((goal, reasoning))
return (goal, "", reasoning)
return goal, "", reasoning
# 如果解析失败,返回默认值
return ("", "", "")
return "", "", ""
async def _update_goals(self, new_goal: str, method: str, reasoning: str):
"""更新目标列表
@@ -195,7 +195,8 @@ class GoalAnalyzer:
if len(self.goals) > self.max_goals:
self.goals.pop() # 移除最老的目标
def _calculate_similarity(self, goal1: str, goal2: str) -> float:
@staticmethod
def _calculate_similarity(goal1: str, goal2: str) -> float:
"""简单计算两个目标之间的相似度
这里使用一个简单的实现,实际可以使用更复杂的文本相似度算法
@@ -299,7 +300,8 @@ class DirectMessageSender:
self.logger = get_module_logger("direct_sender")
self.storage = MessageStorage()
async def send_via_ws(self, message: MessageSending) -> None:
@staticmethod
async def send_via_ws(message: MessageSending) -> None:
try:
await global_api.send_message(message)
except Exception as e:
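
Two kinds of changes run through `GoalAnalyzer`: more helpers become static, and `return (a, b, c)` loses its parentheses. The parentheses are purely cosmetic; the comma is what builds the tuple, so both forms return the same object:

def current_goal():
    return "goal", "", "reasoning"     # identical to: return ("goal", "", "reasoning")


assert current_goal() == ("goal", "", "reasoning")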

View File

@@ -19,7 +19,8 @@ class KnowledgeFetcher:
request_type="knowledge_fetch",
)
async def fetch(self, query: str, chat_history: List[Message]) -> Tuple[str, str]:
@staticmethod
async def fetch(query: str, chat_history: List[Message]) -> Tuple[str, str]:
"""获取相关知识
Args:

View File

@@ -30,11 +30,8 @@ class ReplyGenerator:
"""生成回复
Args:
goal: 对话目标
chat_history: 聊天历史
knowledge_cache: 知识缓存
previous_reply: 上一次生成的回复(如果有)
retry_count: 当前重试次数
observation_info: 观察信息
conversation_info: 对话信息
Returns:
str: 生成的回复

View File

@@ -17,6 +17,5 @@ __all__ = [
"relationship_manager",
"MoodManager",
"willing_manager",
"hippocampus",
"bot_schedule",
]
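
The hunk header (`-17,6 +17,5`) shows one entry leaving this `__all__`; assuming it is the `"hippocampus"` line, star-imports of the package stop re-exporting that name while explicit imports keep working. A sketch of the resulting list (an assumption, since the +/- markers are not preserved here):

__all__ = [
    "relationship_manager",
    "MoodManager",
    "willing_manager",
    "bot_schedule",
]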

View File

@@ -103,7 +103,8 @@ class ChatManager:
except Exception as e:
logger.error(f"聊天流自动保存失败: {str(e)}")
def _ensure_collection(self):
@staticmethod
def _ensure_collection():
"""确保数据库集合存在并创建索引"""
if "chat_streams" not in db.list_collection_names():
db.create_collection("chat_streams")
@@ -111,7 +112,8 @@ class ChatManager:
db.chat_streams.create_index([("stream_id", 1)], unique=True)
db.chat_streams.create_index([("platform", 1), ("user_info.user_id", 1), ("group_info.group_id", 1)])
def _generate_stream_id(self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None) -> str:
@staticmethod
def _generate_stream_id(platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None) -> str:
"""生成聊天流唯一ID"""
if group_info:
# 组合关键信息
@@ -188,7 +190,8 @@ class ChatManager:
stream_id = self._generate_stream_id(platform, user_info, group_info)
return self.streams.get(stream_id)
async def _save_stream(self, stream: ChatStream):
@staticmethod
async def _save_stream(stream: ChatStream):
"""保存聊天流到数据库"""
if not stream.saved:
db.chat_streams.update_one({"stream_id": stream.stream_id}, {"$set": stream.to_dict()}, upsert=True)

View File

@@ -82,7 +82,8 @@ class EmojiManager:
if not self._initialized:
raise RuntimeError("EmojiManager not initialized")
def _ensure_emoji_collection(self):
@staticmethod
def _ensure_emoji_collection():
"""确保emoji集合存在并创建索引
这个函数用于确保MongoDB数据库中存在emoji集合,并创建必要的索引。
@@ -193,7 +194,8 @@ class EmojiManager:
logger.error(f"[错误] 获取表情包失败: {str(e)}")
return None
async def _get_emoji_description(self, image_base64: str) -> str:
@staticmethod
async def _get_emoji_description(image_base64: str) -> str:
"""获取表情包的标签使用image_manager的描述生成功能"""
try:
@@ -554,7 +556,8 @@ class EmojiManager:
self.check_emoji_file_full()
await asyncio.sleep(global_config.EMOJI_CHECK_INTERVAL * 60)
async def delete_all_images(self):
@staticmethod
async def delete_all_images():
"""删除 data/image 目录下的所有文件"""
try:
image_dir = os.path.join("data", "image")

View File

@@ -31,7 +31,7 @@ class Message(MessageBase):
def __init__(
self,
message_id: str,
time: float,
timestamp: float,
chat_stream: ChatStream,
user_info: UserInfo,
message_segment: Optional[Seg] = None,
@@ -43,7 +43,7 @@ class Message(MessageBase):
message_info = BaseMessageInfo(
platform=chat_stream.platform,
message_id=message_id,
time=time,
time=timestamp,
group_info=chat_stream.group_info,
user_info=user_info,
)
@@ -143,7 +143,7 @@ class MessageRecv(Message):
def _generate_detailed_text(self) -> str:
"""生成详细文本,包含时间和用户信息"""
# time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.message_info.time))
time = self.message_info.time
timestamp = self.message_info.time
user_info = self.message_info.user_info
# name = (
# f"{user_info.user_nickname}(ta的昵称:{user_info.user_cardname},ta的id:{user_info.user_id})"
@@ -151,7 +151,7 @@ class MessageRecv(Message):
# else f"{user_info.user_nickname}(ta的id:{user_info.user_id})"
# )
name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>"
return f"[{time}] {name}: {self.processed_plain_text}\n"
return f"[{timestamp}] {name}: {self.processed_plain_text}\n"
@dataclass
@@ -170,7 +170,7 @@ class MessageProcessBase(Message):
# 调用父类初始化
super().__init__(
message_id=message_id,
time=round(time.time(), 3), # 保留3位小数
timestamp=round(time.time(), 3), # 保留3位小数
chat_stream=chat_stream,
user_info=bot_user_info,
message_segment=message_segment,
@@ -242,7 +242,7 @@ class MessageProcessBase(Message):
def _generate_detailed_text(self) -> str:
"""生成详细文本,包含时间和用户信息"""
# time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.message_info.time))
time = self.message_info.time
timestamp = self.message_info.time
user_info = self.message_info.user_info
# name = (
# f"{user_info.user_nickname}(ta的昵称:{user_info.user_cardname},ta的id:{user_info.user_id})"
@@ -250,7 +250,7 @@ class MessageProcessBase(Message):
# else f"{user_info.user_nickname}(ta的id:{user_info.user_id})"
# )
name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>"
return f"[{time}] {name}: {self.processed_plain_text}\n"
return f"[{timestamp}] {name}: {self.processed_plain_text}\n"
@dataclass
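
The `time` → `timestamp` rename here is more than style: `Message.__init__` took a parameter named `time`, and `_generate_detailed_text` bound a local named `time`, both of which shadow the imported `time` module inside those scopes (note the commented-out `time.strftime(...)` line above, which would not work while a local named `time` exists in the same function). A small sketch of the pitfall and the fix; the helper names are illustrative only:

import time


def detailed_text_shadowed(message_time: float, text: str) -> str:
    time = message_time          # shadows the module: `time` is a float in this scope,
    return f"[{time}] {text}"    # so time.strftime(...) cannot be used here


def detailed_text_fixed(message_time: float, text: str) -> str:
    timestamp = message_time     # the module stays available under its own name
    readable = time.strftime("%m-%d %H:%M:%S", time.localtime(timestamp))
    return f"[{readable}] ({timestamp}) {text}"


print(detailed_text_fixed(1713340406.0, "hello"))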

View File

@@ -26,7 +26,8 @@ class MessageBuffer:
self.buffer_pool: Dict[str, OrderedDict[str, CacheMessages]] = {}
self.lock = asyncio.Lock()
def get_person_id_(self, platform: str, user_id: str, group_info: GroupInfo):
@staticmethod
def get_person_id_(platform: str, user_id: str, group_info: GroupInfo):
"""获取唯一id"""
if group_info:
group_id = group_info.group_id
@@ -150,20 +151,20 @@ class MessageBuffer:
keep_msgs[msg_id] = msg
elif msg.result == "F":
# 收集F消息的文本内容
F_type = "seglist"
f_type = "seglist"
if msg.message.message_segment.type != "seglist":
F_type = msg.message.message_segment.type
f_type = msg.message.message_segment.type
else:
if (
isinstance(msg.message.message_segment.data, list)
and all(isinstance(x, Seg) for x in msg.message.message_segment.data)
and len(msg.message.message_segment.data) == 1
):
F_type = msg.message.message_segment.data[0].type
f_type = msg.message.message_segment.data[0].type
if hasattr(msg.message, "processed_plain_text") and msg.message.processed_plain_text:
if F_type == "text":
if f_type == "text":
combined_text.append(msg.message.processed_plain_text)
elif F_type != "text":
elif f_type != "text":
is_update = False
elif msg.result == "U":
logger.debug(f"异常未处理信息id {msg.message.message_info.message_id}")
@@ -185,7 +186,8 @@ class MessageBuffer:
logger.debug(f"查询超时消息id {message.message_info.message_id}")
return False
async def save_message_interval(self, person_id: str, message: BaseMessageInfo):
@staticmethod
async def save_message_interval(person_id: str, message: BaseMessageInfo):
message_interval_list = await person_info_manager.get_value(person_id, "msg_interval_list")
now_time_ms = int(round(time.time() * 1000))
if len(message_interval_list) < 1000:
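
`F_type` becomes `f_type` per PEP 8, which reserves CapWords for classes and uses lowercase_with_underscores for locals; the seglist detection itself is untouched, and the same rename recurs in ReasoningChat and ThinkFlowChat below. A self-contained sketch of that detection; the `Seg` dataclass is a minimal stand-in for the project's own `Seg`, and `effective_type` is a made-up helper:

from dataclasses import dataclass
from typing import Any


@dataclass
class Seg:          # minimal stand-in, not the project's real Seg class
    type: str
    data: Any


def effective_type(segment: Seg) -> str:
    f_type = "seglist"
    if segment.type != "seglist":
        f_type = segment.type
    elif (
        isinstance(segment.data, list)
        and all(isinstance(x, Seg) for x in segment.data)
        and len(segment.data) == 1
    ):
        f_type = segment.data[0].type
    return f_type


print(effective_type(Seg(type="text", data="hi")))                     # text
print(effective_type(Seg(type="seglist", data=[Seg("image", b"")])))   # image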

View File

@@ -35,7 +35,8 @@ class MessageSender:
"""设置当前bot实例"""
pass
def get_recalled_messages(self, stream_id: str) -> list:
@staticmethod
def get_recalled_messages(stream_id: str) -> list:
"""获取所有撤回的消息"""
recalled_messages = []
@@ -43,7 +44,8 @@ class MessageSender:
# 按thinking_start_time排序时间早的在前面
return recalled_messages
async def send_via_ws(self, message: MessageSending) -> None:
@staticmethod
async def send_via_ws(message: MessageSending) -> None:
try:
await global_api.send_message(message)
except Exception as e:

View File

@@ -135,7 +135,7 @@ async def get_recent_group_messages(chat_id: str, limit: int = 12) -> list:
msg = Message(
message_id=msg_data["message_id"],
chat_stream=chat_stream,
time=msg_data["time"],
timestamp=msg_data["time"],
user_info=user_info,
processed_plain_text=msg_data.get("processed_text", ""),
detailed_plain_text=msg_data.get("detailed_plain_text", ""),

View File

@@ -38,7 +38,8 @@ class ImageManager:
"""确保图像存储目录存在"""
os.makedirs(self.IMAGE_DIR, exist_ok=True)
def _ensure_image_collection(self):
@staticmethod
def _ensure_image_collection():
"""确保images集合存在并创建索引"""
if "images" not in db.list_collection_names():
db.create_collection("images")
@@ -50,7 +51,8 @@ class ImageManager:
db.images.create_index([("url", 1)])
db.images.create_index([("path", 1)])
def _ensure_description_collection(self):
@staticmethod
def _ensure_description_collection():
"""确保image_descriptions集合存在并创建索引"""
if "image_descriptions" not in db.list_collection_names():
db.create_collection("image_descriptions")
@@ -60,7 +62,8 @@ class ImageManager:
# 创建新的复合索引
db.image_descriptions.create_index([("hash", 1), ("type", 1)], unique=True)
def _get_description_from_db(self, image_hash: str, description_type: str) -> Optional[str]:
@staticmethod
def _get_description_from_db(image_hash: str, description_type: str) -> Optional[str]:
"""从数据库获取图片描述
Args:
@@ -73,7 +76,8 @@ class ImageManager:
result = db.image_descriptions.find_one({"hash": image_hash, "type": description_type})
return result["description"] if result else None
def _save_description_to_db(self, image_hash: str, description: str, description_type: str) -> None:
@staticmethod
def _save_description_to_db(image_hash: str, description: str, description_type: str) -> None:
"""保存图片描述到数据库
Args:
@@ -226,7 +230,8 @@ class ImageManager:
logger.error(f"获取图片描述失败: {str(e)}")
return "[图片]"
def transform_gif(self, gif_base64: str) -> str:
@staticmethod
def transform_gif(gif_base64: str) -> str:
"""将GIF转换为水平拼接的静态图像
Args:

View File

@@ -13,7 +13,8 @@ class MessageProcessor:
def __init__(self):
self.storage = MessageStorage()
def _check_ban_words(self, text: str, chat, userinfo) -> bool:
@staticmethod
def _check_ban_words(text: str, chat, userinfo) -> bool:
"""检查消息中是否包含过滤词"""
for word in global_config.ban_words:
if word in text:
@@ -24,7 +25,8 @@ class MessageProcessor:
return True
return False
def _check_ban_regex(self, text: str, chat, userinfo) -> bool:
@staticmethod
def _check_ban_regex(text: str, chat, userinfo) -> bool:
"""检查消息是否匹配过滤正则表达式"""
for pattern in global_config.ban_msgs_regex:
if pattern.search(text):

View File

@@ -37,7 +37,8 @@ class ReasoningChat:
self.mood_manager = MoodManager.get_instance()
self.mood_manager.start_mood_update()
async def _create_thinking_message(self, message, chat, userinfo, messageinfo):
@staticmethod
async def _create_thinking_message(message, chat, userinfo, messageinfo):
"""创建思考消息"""
bot_user_info = UserInfo(
user_id=global_config.BOT_QQ,
@@ -59,7 +60,8 @@ class ReasoningChat:
return thinking_id
async def _send_response_messages(self, message, chat, response_set: List[str], thinking_id) -> MessageSending:
@staticmethod
async def _send_response_messages(message, chat, response_set: List[str], thinking_id) -> MessageSending:
"""发送回复消息"""
container = message_manager.get_container(chat.stream_id)
thinking_message = None
@@ -104,7 +106,8 @@ class ReasoningChat:
return first_bot_msg
async def _handle_emoji(self, message, chat, response):
@staticmethod
async def _handle_emoji(message, chat, response):
"""处理表情包"""
if random() < global_config.emoji_chance:
emoji_raw = await emoji_manager.get_emoji_for_text(response)
@@ -192,21 +195,21 @@ class ReasoningChat:
if not buffer_result:
await willing_manager.bombing_buffer_message_handle(message.message_info.message_id)
willing_manager.delete(message.message_info.message_id)
F_type = "seglist"
f_type = "seglist"
if message.message_segment.type != "seglist":
F_type = message.message_segment.type
f_type = message.message_segment.type
else:
if (
isinstance(message.message_segment.data, list)
and all(isinstance(x, Seg) for x in message.message_segment.data)
and len(message.message_segment.data) == 1
):
F_type = message.message_segment.data[0].type
if F_type == "text":
f_type = message.message_segment.data[0].type
if f_type == "text":
logger.info(f"触发缓冲,已炸飞消息:{message.processed_plain_text}")
elif F_type == "image":
elif f_type == "image":
logger.info("触发缓冲,已炸飞表情包/图片")
elif F_type == "seglist":
elif f_type == "seglist":
logger.info("触发缓冲,已炸飞消息列")
return
@@ -291,7 +294,8 @@ class ReasoningChat:
# 意愿管理器注销当前message信息
willing_manager.delete(message.message_info.message_id)
def _check_ban_words(self, text: str, chat, userinfo) -> bool:
@staticmethod
def _check_ban_words(text: str, chat, userinfo) -> bool:
"""检查消息中是否包含过滤词"""
for word in global_config.ban_words:
if word in text:
@@ -302,7 +306,8 @@ class ReasoningChat:
return True
return False
def _check_ban_regex(self, text: str, chat, userinfo) -> bool:
@staticmethod
def _check_ban_regex(text: str, chat, userinfo) -> bool:
"""检查消息是否匹配过滤正则表达式"""
for pattern in global_config.ban_msgs_regex:
if pattern.search(text):

View File

@@ -69,8 +69,6 @@ class ResponseGenerator:
return None
async def _generate_response_with_model(self, message: MessageThinking, model: LLMRequest, thinking_id: str):
sender_name = ""
info_catcher = info_catcher_manager.get_info_catcher(thinking_id)
if message.chat_stream.user_info.user_cardname and message.chat_stream.user_info.user_nickname:
@@ -188,7 +186,8 @@ class ResponseGenerator:
logger.debug(f"获取情感标签时出错: {e}")
return "中立", "平静" # 出错时返回默认值
async def _process_response(self, content: str) -> Tuple[List[str], List[str]]:
@staticmethod
async def _process_response(content: str) -> Tuple[List[str], List[str]]:
"""处理响应内容,返回处理后的内容和情感标签"""
if not content:
return None, []

View File

@@ -101,16 +101,14 @@ class PromptBuilder:
related_memory = await HippocampusManager.get_instance().get_memory_from_text(
text=message_txt, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False
)
related_memory_info = ""
if related_memory:
related_memory_info = ""
for memory in related_memory:
related_memory_info += memory[1]
# memory_prompt = f"你想起你之前见过的事情:{related_memory_info}。\n以上是你的回忆不一定是目前聊天里的人说的也不一定是现在发生的事情请记住。\n"
memory_prompt = await global_prompt_manager.format_prompt(
"memory_prompt", related_memory_info=related_memory_info
)
else:
related_memory_info = ""
# print(f"相关记忆:{related_memory_info}")
@@ -162,7 +160,6 @@ class PromptBuilder:
# 知识构建
start_time = time.time()
prompt_info = ""
prompt_info = await self.get_prompt_info(message_txt, threshold=0.38)
if prompt_info:
# prompt_info = f"""\n你有以下这些**知识**\n{prompt_info}\n请你**记住上面的知识**,之后可能会用到。\n"""
@@ -373,8 +370,9 @@ class PromptBuilder:
logger.info(f"知识库检索总耗时: {time.time() - start_time:.3f}")
return related_info
@staticmethod
def get_info_from_db(
self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False
query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False
) -> Union[str, list]:
if not query_embedding:
return "" if not return_raw else []

View File

@@ -40,7 +40,8 @@ class ThinkFlowChat:
self.mood_manager.start_mood_update()
self.tool_user = ToolUser()
async def _create_thinking_message(self, message, chat, userinfo, messageinfo):
@staticmethod
async def _create_thinking_message(message, chat, userinfo, messageinfo):
"""创建思考消息"""
bot_user_info = UserInfo(
user_id=global_config.BOT_QQ,
@@ -62,7 +63,8 @@ class ThinkFlowChat:
return thinking_id
async def _send_response_messages(self, message, chat, response_set: List[str], thinking_id) -> MessageSending:
@staticmethod
async def _send_response_messages(message, chat, response_set: List[str], thinking_id) -> MessageSending:
"""发送回复消息"""
container = message_manager.get_container(chat.stream_id)
thinking_message = None
@@ -108,7 +110,8 @@ class ThinkFlowChat:
message_manager.add_message(message_set)
return first_bot_msg
async def _handle_emoji(self, message, chat, response, send_emoji=""):
@staticmethod
async def _handle_emoji(message, chat, response, send_emoji=""):
"""处理表情包"""
if send_emoji:
emoji_raw = await emoji_manager.get_emoji_for_text(send_emoji)
@@ -204,21 +207,21 @@ class ThinkFlowChat:
if not buffer_result:
await willing_manager.bombing_buffer_message_handle(message.message_info.message_id)
willing_manager.delete(message.message_info.message_id)
F_type = "seglist"
f_type = "seglist"
if message.message_segment.type != "seglist":
F_type = message.message_segment.type
f_type = message.message_segment.type
else:
if (
isinstance(message.message_segment.data, list)
and all(isinstance(x, Seg) for x in message.message_segment.data)
and len(message.message_segment.data) == 1
):
F_type = message.message_segment.data[0].type
if F_type == "text":
f_type = message.message_segment.data[0].type
if f_type == "text":
logger.info(f"触发缓冲,已炸飞消息:{message.processed_plain_text}")
elif F_type == "image":
elif f_type == "image":
logger.info("触发缓冲,已炸飞表情包/图片")
elif F_type == "seglist":
elif f_type == "seglist":
logger.info("触发缓冲,已炸飞消息列")
return
@@ -461,7 +464,8 @@ class ThinkFlowChat:
# 意愿管理器注销当前message信息
willing_manager.delete(message.message_info.message_id)
def _check_ban_words(self, text: str, chat, userinfo) -> bool:
@staticmethod
def _check_ban_words(text: str, chat, userinfo) -> bool:
"""检查消息中是否包含过滤词"""
for word in global_config.ban_words:
if word in text:
@@ -472,7 +476,8 @@ class ThinkFlowChat:
return True
return False
def _check_ban_regex(self, text: str, chat, userinfo) -> bool:
@staticmethod
def _check_ban_regex(text: str, chat, userinfo) -> bool:
"""检查消息是否匹配过滤正则表达式"""
for pattern in global_config.ban_msgs_regex:
if pattern.search(text):

View File

@@ -236,7 +236,8 @@ class ResponseGenerator:
logger.debug(f"获取情感标签时出错: {e}")
return "中立", "平静" # 出错时返回默认值
async def _process_response(self, content: str) -> List[str]:
@staticmethod
async def _process_response(content: str) -> List[str]:
"""处理响应内容,返回处理后的内容和情感标签"""
if not content:
return None

View File

@@ -64,8 +64,9 @@ class PromptBuilder:
self.prompt_built = ""
self.activate_messages = ""
@staticmethod
async def _build_prompt(
self, chat_stream, message_txt: str, sender_name: str = "某人", stream_id: Optional[int] = None
chat_stream, message_txt: str, sender_name: str = "某人", stream_id: Optional[int] = None
) -> tuple[str, str]:
current_mind_info = heartflow.get_subheartflow(stream_id).current_mind
@@ -168,8 +169,9 @@ class PromptBuilder:
return prompt
@staticmethod
async def _build_prompt_simple(
self, chat_stream, message_txt: str, sender_name: str = "某人", stream_id: Optional[int] = None
chat_stream, message_txt: str, sender_name: str = "某人", stream_id: Optional[int] = None
) -> tuple[str, str]:
current_mind_info = heartflow.get_subheartflow(stream_id).current_mind
@@ -237,8 +239,8 @@ class PromptBuilder:
logger.info(f"生成回复的prompt: {prompt}")
return prompt
@staticmethod
async def _build_prompt_check_response(
self,
chat_stream,
message_txt: str,
sender_name: str = "某人",

View File

@@ -4,6 +4,8 @@ import math
import random
import time
import re
from itertools import combinations
import jieba
import networkx as nx
import numpy as np
@@ -250,7 +252,8 @@ class Hippocampus:
"""获取记忆图中所有节点的名字列表"""
return list(self.memory_graph.G.nodes())
def calculate_node_hash(self, concept, memory_items) -> int:
@staticmethod
def calculate_node_hash(concept, memory_items) -> int:
"""计算节点的特征值"""
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
@@ -258,12 +261,14 @@ class Hippocampus:
content = f"{concept}:{'|'.join(sorted_items)}"
return hash(content)
def calculate_edge_hash(self, source, target) -> int:
@staticmethod
def calculate_edge_hash(source, target) -> int:
"""计算边的特征值"""
nodes = sorted([source, target])
return hash(f"{nodes[0]}:{nodes[1]}")
def find_topic_llm(self, text, topic_num):
@staticmethod
def find_topic_llm(text, topic_num):
prompt = (
f"这是一段文字:{text}。请你从这段话中总结出最多{topic_num}个关键的概念,可以是名词,动词,或者特定人物,帮我列出来,"
f"将主题用逗号隔开,并加上<>,例如<主题1>,<主题2>......尽可能精简。只需要列举最多{topic_num}个话题就好,不要有序号,不要告诉我其他内容。"
@@ -271,14 +276,16 @@ class Hippocampus:
)
return prompt
def topic_what(self, text, topic, time_info):
@staticmethod
def topic_what(text, topic, time_info):
prompt = (
f'这是一段文字,{time_info}{text}。我想让你基于这段文字来概括"{topic}"这个概念,帮我总结成一句自然的话,'
f"可以包含时间和人物,以及具体的观点。只输出这句话就好"
)
return prompt
def calculate_topic_num(self, text, compress_rate):
@staticmethod
def calculate_topic_num(text, compress_rate):
"""计算文本的话题数量"""
information_content = calculate_information_content(text)
topic_by_length = text.count("\n") * compress_rate
@@ -693,7 +700,8 @@ class EntorhinalCortex:
return chat_samples
def random_get_msg_snippet(self, target_timestamp: float, chat_size: int, max_memorized_time_per_msg: int) -> list:
@staticmethod
def random_get_msg_snippet(target_timestamp: float, chat_size: int, max_memorized_time_per_msg: int) -> list:
"""从数据库中随机获取指定时间戳附近的消息片段"""
try_count = 0
while try_count < 3:
@@ -958,7 +966,8 @@ class Hippocampus:
"""获取记忆图中所有节点的名字列表"""
return list(self.memory_graph.G.nodes())
def calculate_node_hash(self, concept, memory_items) -> int:
@staticmethod
def calculate_node_hash(concept, memory_items) -> int:
"""计算节点的特征值"""
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
@@ -966,12 +975,14 @@ class Hippocampus:
content = f"{concept}:{'|'.join(sorted_items)}"
return hash(content)
def calculate_edge_hash(self, source, target) -> int:
@staticmethod
def calculate_edge_hash(source, target) -> int:
"""计算边的特征值"""
nodes = sorted([source, target])
return hash(f"{nodes[0]}:{nodes[1]}")
def find_topic_llm(self, text, topic_num):
@staticmethod
def find_topic_llm(text, topic_num):
prompt = (
f"这是一段文字:{text}。请你从这段话中总结出最多{topic_num}个关键的概念,可以是名词,动词,或者特定人物,帮我列出来,"
f"将主题用逗号隔开,并加上<>,例如<主题1>,<主题2>......尽可能精简。只需要列举最多{topic_num}个话题就好,不要有序号,不要告诉我其他内容。"
@@ -979,14 +990,16 @@ class Hippocampus:
)
return prompt
def topic_what(self, text, topic, time_info):
@staticmethod
def topic_what(text, topic, time_info):
prompt = (
f'这是一段文字,{time_info}{text}。我想让你基于这段文字来概括"{topic}"这个概念,帮我总结成一句自然的话,'
f"可以包含时间和人物,以及具体的观点。只输出这句话就好"
)
return prompt
def calculate_topic_num(self, text, compress_rate):
@staticmethod
def calculate_topic_num(text, compress_rate):
"""计算文本的话题数量"""
information_content = calculate_information_content(text)
topic_by_length = text.count("\n") * compress_rate
@@ -1542,11 +1555,10 @@ class ParahippocampalGyrus:
last_modified=current_time,
)
for i in range(len(all_topics)):
for j in range(i + 1, len(all_topics)):
logger.debug(f"连接同批次节点: {all_topics[i]}{all_topics[j]}")
all_added_edges.append(f"{all_topics[i]}-{all_topics[j]}")
self.memory_graph.connect_dot(all_topics[i], all_topics[j])
for topic1, topic2 in combinations(all_topics, 2):
logger.debug(f"连接同批次节点: {topic1}{topic2}")
all_added_edges.append(f"{topic1}-{topic2}")
self.memory_graph.connect_dot(topic1, topic2)
logger.success(f"更新记忆: {', '.join(all_added_nodes)}")
logger.debug(f"强化连接: {', '.join(all_added_edges)}")

View File

@@ -1,95 +0,0 @@
import unittest
import asyncio
import aiohttp
from api import BaseMessageAPI
from message_base import (
BaseMessageInfo,
UserInfo,
GroupInfo,
FormatInfo,
MessageBase,
Seg,
)
send_url = "http://localhost"
receive_port = 18002 # 接收消息的端口
send_port = 18000 # 发送消息的端口
test_endpoint = "/api/message"
# 创建并启动API实例
api = BaseMessageAPI(host="0.0.0.0", port=receive_port)
class TestLiveAPI(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
"""测试前的设置"""
self.received_messages = []
async def message_handler(message):
self.received_messages.append(message)
self.api = api
self.api.register_message_handler(message_handler)
self.server_task = asyncio.create_task(self.api.run())
try:
await asyncio.wait_for(asyncio.sleep(1), timeout=5)
except asyncio.TimeoutError:
self.skipTest("服务器启动超时")
async def asyncTearDown(self):
"""测试后的清理"""
if hasattr(self, "server_task"):
await self.api.stop() # 先调用正常的停止流程
if not self.server_task.done():
self.server_task.cancel()
try:
await asyncio.wait_for(self.server_task, timeout=100)
except (asyncio.CancelledError, asyncio.TimeoutError):
pass
async def test_send_and_receive_message(self):
"""测试向运行中的API发送消息并接收响应"""
# 准备测试消息
user_info = UserInfo(user_id=12345678, user_nickname="测试用户", platform="qq")
group_info = GroupInfo(group_id=12345678, group_name="测试群", platform="qq")
format_info = FormatInfo(content_format=["text"], accept_format=["text", "emoji", "reply"])
template_info = None
message_info = BaseMessageInfo(
platform="qq",
message_id=12345678,
time=12345678,
group_info=group_info,
user_info=user_info,
format_info=format_info,
template_info=template_info,
)
message = MessageBase(
message_info=message_info,
raw_message="测试消息",
message_segment=Seg(type="text", data="测试消息"),
)
test_message = message.to_dict()
# 发送测试消息到发送端口
async with aiohttp.ClientSession() as session:
async with session.post(
f"{send_url}:{send_port}{test_endpoint}",
json=test_message,
) as response:
response_data = await response.json()
self.assertEqual(response.status, 200)
self.assertEqual(response_data["status"], "success")
try:
async with asyncio.timeout(5): # 设置5秒超时
while len(self.received_messages) == 0:
await asyncio.sleep(0.1)
received_message = self.received_messages[0]
print(received_message)
self.received_messages.clear()
except asyncio.TimeoutError:
self.fail("等待接收消息超时")
if __name__ == "__main__":
unittest.main()

View File

@@ -72,7 +72,8 @@ class PersonInfoManager:
self.person_name_list[doc["person_id"]] = doc["person_name"]
logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称")
def get_person_id(self, platform: str, user_id: int):
@staticmethod
def get_person_id(platform: str, user_id: int):
"""获取唯一id"""
# 如果platform中存在-,就截取-后面的部分
if "-" in platform:
@@ -91,7 +92,8 @@ class PersonInfoManager:
else:
return False
async def create_person_info(self, person_id: str, data: dict = None):
@staticmethod
async def create_person_info(person_id: str, data: dict = None):
"""创建一个项"""
if not person_id:
logger.debug("创建失败personid不存在")
@@ -131,7 +133,8 @@ class PersonInfoManager:
else:
return False
def _extract_json_from_text(self, text: str) -> dict:
@staticmethod
def _extract_json_from_text(text: str) -> dict:
"""从文本中提取JSON数据的高容错方法"""
try:
# 尝试直接解析
@@ -225,7 +228,8 @@ class PersonInfoManager:
logger.error(f"{max_retries}次尝试后仍未能生成唯一昵称")
return None
async def del_one_document(self, person_id: str):
@staticmethod
async def del_one_document(person_id: str):
"""删除指定 person_id 的文档"""
if not person_id:
logger.debug("删除失败person_id 不能为空")
@@ -237,7 +241,8 @@ class PersonInfoManager:
else:
logger.debug(f"删除失败:未找到 person_id={person_id}")
async def get_value(self, person_id: str, field_name: str):
@staticmethod
async def get_value(person_id: str, field_name: str):
"""获取指定person_id文档的字段值若不存在该字段则返回该字段的全局默认值"""
if not person_id:
logger.debug("get_value获取失败person_id不能为空")
@@ -256,7 +261,8 @@ class PersonInfoManager:
logger.trace(f"获取{person_id}{field_name}失败,已返回默认值{default_value}")
return default_value
async def get_values(self, person_id: str, field_names: list) -> dict:
@staticmethod
async def get_values(person_id: str, field_names: list) -> dict:
"""获取指定person_id文档的多个字段值若不存在该字段则返回该字段的全局默认值"""
if not person_id:
logger.debug("get_values获取失败person_id不能为空")
@@ -281,7 +287,8 @@ class PersonInfoManager:
return result
async def del_all_undefined_field(self):
@staticmethod
async def del_all_undefined_field():
"""删除所有项里的未定义字段"""
# 获取所有已定义的字段名
defined_fields = set(person_info_default.keys())
@@ -307,8 +314,8 @@ class PersonInfoManager:
logger.error(f"清理未定义字段时出错: {e}")
return
@staticmethod
async def get_specific_value_list(
self,
field_name: str,
way: Callable[[Any], bool], # 接受任意类型值
) -> Dict[str, Any]:

View File

@@ -62,7 +62,7 @@ class RelationshipManager:
def mood_feedback(self, value):
"""情绪反馈"""
mood_manager = self.mood_manager
mood_gain = (mood_manager.get_current_mood().valence) ** 2 * math.copysign(
mood_gain = mood_manager.get_current_mood().valence ** 2 * math.copysign(
1, value * mood_manager.get_current_mood().valence
)
value += value * mood_gain
@@ -77,24 +77,27 @@ class RelationshipManager:
else:
return mood_value / coefficient
async def is_known_some_one(self, platform, user_id):
@staticmethod
async def is_known_some_one(platform, user_id):
"""判断是否认识某人"""
is_known = person_info_manager.is_person_known(platform, user_id)
return is_known
async def is_qved_name(self, platform, user_id):
@staticmethod
async def is_qved_name(platform, user_id):
"""判断是否认识某人"""
person_id = person_info_manager.get_person_id(platform, user_id)
is_qved = await person_info_manager.has_one_field(person_id, "person_name")
old_name = await person_info_manager.get_value(person_id, "person_name")
print(f"old_name: {old_name}")
print(f"is_qved: {is_qved}")
if is_qved and old_name != None:
if is_qved and old_name is not None:
return True
else:
return False
async def first_knowing_some_one(self, platform, user_id, user_nickname, user_cardname, user_avatar):
@staticmethod
async def first_knowing_some_one(platform, user_id, user_nickname, user_cardname, user_avatar):
"""判断是否认识某人"""
person_id = person_info_manager.get_person_id(platform, user_id)
await person_info_manager.update_one_field(person_id, "nickname", user_nickname)
@@ -102,7 +105,8 @@ class RelationshipManager:
# await person_info_manager.update_one_field(person_id, "user_avatar", user_avatar)
await person_info_manager.qv_person_name(person_id, user_nickname, user_cardname, user_avatar)
async def convert_all_person_sign_to_person_name(self, input_text: str):
@staticmethod
async def convert_all_person_sign_to_person_name(input_text: str):
"""将所有人的<platform:user_id:nickname:cardname>格式转换为person_name"""
try:
# 使用正则表达式匹配<platform:user_id:nickname:cardname>格式
@@ -119,7 +123,7 @@ class RelationshipManager:
person_name = nickname.strip() if nickname.strip() else cardname.strip()
if person_id in all_person:
if all_person[person_id] != None:
if all_person[person_id] is not None:
person_name = all_person[person_id]
print(f"将<{platform}:{user_id}:{nickname}:{cardname}>替换为{person_name}")
@@ -326,7 +330,8 @@ class RelationshipManager:
f"回复态度为{relation_prompt2_list[level_num]},关系等级为{level_num}"
)
def calculate_level_num(self, relationship_value) -> int:
@staticmethod
def calculate_level_num(relationship_value) -> int:
"""关系等级计算"""
if -1000 <= relationship_value < -227:
level_num = 0
@@ -344,7 +349,8 @@ class RelationshipManager:
level_num = 5 if relationship_value > 1000 else 0
return level_num
def ensure_float(self, value, person_id):
@staticmethod
def ensure_float(value, person_id):
"""确保返回浮点数转换失败返回0.0"""
if isinstance(value, float):
return value
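
Besides the `@staticmethod` conversions, `RelationshipManager` swaps `!= None` for `is not None`. PEP 8 asks for the identity test because `==`/`!=` can be hijacked by a custom `__eq__`; a deliberately contrived sketch of why that matters:

class AlwaysEqual:
    def __eq__(self, other):     # pathological __eq__ that claims equality with anything
        return True


probe = AlwaysEqual()
print(probe == None)        # True  -> the old `!= None` test would be fooled
print(probe is None)        # False -> identity is immune to __eq__ overrides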

View File

@@ -100,7 +100,8 @@ class InfoCatcher:
self.trigger_response_message, first_bot_msg
)
def get_message_from_db_between_msgs(self, message_start: Message, message_end: Message):
@staticmethod
def get_message_from_db_between_msgs(message_start: Message, message_end: Message):
try:
# 从数据库中获取消息的时间戳
time_start = message_start.message_info.time
@@ -155,7 +156,8 @@ class InfoCatcher:
return result
def message_to_dict(self, message):
@staticmethod
def message_to_dict(message):
if not message:
return None
if isinstance(message, dict):

View File

@@ -235,6 +235,7 @@ class ScheduleGenerator:
Args:
num (int): 需要获取的日程数量默认为1
time_info (bool): 是否包含时间信息默认为False
Returns:
list: 最新加入的日程列表
@@ -267,7 +268,8 @@ class ScheduleGenerator:
db.schedule.update_one({"date": date_str}, {"$set": schedule_data}, upsert=True)
logger.debug(f"已保存{date_str}的日程到数据库")
def load_schedule_from_db(self, date: datetime.datetime):
@staticmethod
def load_schedule_from_db(date: datetime.datetime):
"""从数据库加载日程,同时加载 today_done_list"""
date_str = date.strftime("%Y-%m-%d")
existing_schedule = db.schedule.find_one({"date": date_str})

View File

@@ -10,7 +10,8 @@ logger = get_module_logger("message_storage")
class MessageStorage:
async def store_message(self, message: Union[MessageSending, MessageRecv], chat_stream: ChatStream) -> None:
@staticmethod
async def store_message(message: Union[MessageSending, MessageRecv], chat_stream: ChatStream) -> None:
"""存储消息到数据库"""
try:
# 莫越权 救世啊
@@ -43,7 +44,8 @@ class MessageStorage:
except Exception:
logger.exception("存储消息失败")
async def store_recalled_message(self, message_id: str, time: str, chat_stream: ChatStream) -> None:
@staticmethod
async def store_recalled_message(message_id: str, time: str, chat_stream: ChatStream) -> None:
"""存储撤回消息到数据库"""
if "recalled_messages" not in db.list_collection_names():
db.create_collection("recalled_messages")
@@ -58,7 +60,8 @@ class MessageStorage:
except Exception:
logger.exception("存储撤回消息失败")
async def remove_recalled_message(self, time: str) -> None:
@staticmethod
async def remove_recalled_message(time: str) -> None:
"""删除撤回消息"""
try:
db.recalled_messages.delete_many({"time": {"$lt": time - 300}})

View File

@@ -28,7 +28,7 @@ class TopicIdentifier:
消息内容:{text}"""
# 使用 LLM_request 类进行请求
# 使用 LLMRequest 类进行请求
try:
topic, _, _ = await self.llm_topic_judge.generate_response(prompt)
except Exception as e:

View File

@@ -24,7 +24,8 @@ class LLMStatistics:
self._init_database()
self.name_dict: Dict[List] = {}
def _init_database(self):
@staticmethod
def _init_database():
"""初始化数据库集合"""
if "online_time" not in db.list_collection_names():
db.create_collection("online_time")
@@ -51,7 +52,8 @@ class LLMStatistics:
if self.console_thread:
self.console_thread.join()
def _record_online_time(self):
@staticmethod
def _record_online_time():
"""记录在线时间"""
current_time = datetime.now()
# 检查5分钟内是否已有记录
@@ -187,7 +189,7 @@ class LLMStatistics:
# 按模型统计
output.append("按模型统计:")
output.append(("模型名称 调用次数 Token总量 累计花费"))
output.append("模型名称 调用次数 Token总量 累计花费")
for model_name, count in sorted(stats["requests_by_model"].items()):
tokens = stats["tokens_by_model"][model_name]
cost = stats["costs_by_model"][model_name]
@@ -198,7 +200,7 @@ class LLMStatistics:
# 按请求类型统计
output.append("按请求类型统计:")
output.append(("模型名称 调用次数 Token总量 累计花费"))
output.append("模型名称 调用次数 Token总量 累计花费")
for req_type, count in sorted(stats["requests_by_type"].items()):
tokens = stats["tokens_by_type"][req_type]
cost = stats["costs_by_type"][req_type]
@@ -209,7 +211,7 @@ class LLMStatistics:
# 修正用户统计列宽
output.append("按用户统计:")
output.append(("用户ID 调用次数 Token总量 累计花费"))
output.append("用户ID 调用次数 Token总量 累计花费")
for user_id, count in sorted(stats["requests_by_user"].items()):
tokens = stats["tokens_by_user"][user_id]
cost = stats["costs_by_user"][user_id]
@@ -225,7 +227,7 @@ class LLMStatistics:
# 添加聊天统计
output.append("群组统计:")
output.append(("群组名称 消息数量"))
output.append("群组名称 消息数量")
for group_id, count in sorted(stats["messages_by_chat"].items()):
output.append(f"{self.name_dict[group_id][0][:32]:<32} {count:>10}")
@@ -246,7 +248,7 @@ class LLMStatistics:
# 按模型统计
output.append("按模型统计:")
output.append(("模型名称 调用次数 Token总量 累计花费"))
output.append("模型名称 调用次数 Token总量 累计花费")
for model_name, count in sorted(stats["requests_by_model"].items()):
tokens = stats["tokens_by_model"][model_name]
cost = stats["costs_by_model"][model_name]
@@ -284,7 +286,7 @@ class LLMStatistics:
# 添加聊天统计
output.append("群组统计:")
output.append(("群组名称 消息数量"))
output.append("群组名称 消息数量")
for group_id, count in sorted(stats["messages_by_chat"].items()):
output.append(f"{self.name_dict[group_id][0][:32]:<32} {count:>10}")

View File

@@ -90,7 +90,8 @@ class Timer:
self.auto_unit = auto_unit
self.start = None
def _validate_types(self, name, storage):
@staticmethod
def _validate_types(name, storage):
"""类型检查"""
if name is not None and not isinstance(name, str):
raise TimerTypeError("name", "Optional[str]", type(name))

View File

@@ -77,7 +77,8 @@ class ChineseTypoGenerator:
return normalized_freq
def _create_pinyin_dict(self):
@staticmethod
def _create_pinyin_dict():
"""
创建拼音到汉字的映射字典
"""
@@ -95,7 +96,8 @@ class ChineseTypoGenerator:
return pinyin_dict
def _is_chinese_char(self, char):
@staticmethod
def _is_chinese_char(char):
"""
判断是否为汉字
"""
@@ -124,7 +126,8 @@ class ChineseTypoGenerator:
return result
def _get_similar_tone_pinyin(self, py):
@staticmethod
def _get_similar_tone_pinyin(py):
"""
获取相似声调的拼音
"""
@@ -211,13 +214,15 @@ class ChineseTypoGenerator:
# 返回概率最高的几个字
return [char for char, _ in candidates_with_prob[:num_candidates]]
def _get_word_pinyin(self, word):
@staticmethod
def _get_word_pinyin(word):
"""
获取词语的拼音列表
"""
return [py[0] for py in pinyin(word, style=Style.TONE3)]
def _segment_sentence(self, sentence):
@staticmethod
def _segment_sentence(sentence):
"""
使用jieba分词返回词语列表
"""
@@ -392,7 +397,8 @@ class ChineseTypoGenerator:
return "".join(result), correction_suggestion
def format_typo_info(self, typo_info):
@staticmethod
def format_typo_info(typo_info):
"""
格式化错别字信息

View File

@@ -13,7 +13,7 @@ llmcheck 模式:
import time
from loguru import logger
from ..models.utils_model import LLM_request
from ..models.utils_model import LLMRequest
from ...config.config import global_config
# from ..chat.chat_stream import ChatStream
@@ -61,7 +61,7 @@ def llmcheck_decorator(trigger_condition_func):
class LlmcheckWillingManager(MxpWillingManager):
def __init__(self):
super().__init__()
self.model_v3 = LLM_request(model=global_config.llm_normal, temperature=0.3)
self.model_v3 = LLMRequest(model=global_config.llm_normal, temperature=0.3)
async def get_llmreply_probability(self, message_id: str):
message_info = self.ongoing_messages[message_id]

View File

@@ -240,7 +240,8 @@ class MxpWillingManager(BaseWillingManager):
-2 * self.basic_maximum_willing * self.fatigue_coefficient
)
def _willing_to_probability(self, willing: float) -> float:
@staticmethod
def _willing_to_probability(willing: float) -> float:
"""意愿值转化为概率"""
willing = max(0, willing)
if willing < 2:
@@ -285,7 +286,8 @@ class MxpWillingManager(BaseWillingManager):
if self.is_debug:
self.logger.debug(f"聊天流意愿值更新:{self.chat_reply_willing}")
def _get_relationship_level_num(self, relationship_value) -> int:
@staticmethod
def _get_relationship_level_num(relationship_value) -> int:
"""关系等级计算"""
if -1000 <= relationship_value < -227:
level_num = 0

View File

@@ -35,12 +35,14 @@ class KnowledgeLibrary:
"""确保必要的目录存在"""
os.makedirs(self.raw_info_dir, exist_ok=True)
def read_file(self, file_path: str) -> str:
@staticmethod
def read_file(file_path: str) -> str:
"""读取文件内容"""
with open(file_path, "r", encoding="utf-8") as f:
return f.read()
def split_content(self, content: str, max_length: int = 512) -> list:
@staticmethod
def split_content(content: str, max_length: int = 512) -> list:
"""将内容分割成适当大小的块,按空行分割
Args:
@@ -146,7 +148,8 @@ class KnowledgeLibrary:
return result
def _update_stats(self, total_stats, result, filename):
@staticmethod
def _update_stats(total_stats, result, filename):
"""更新总体统计信息"""
if result["status"] == "success":
total_stats["processed_files"] += 1
@@ -181,7 +184,8 @@ class KnowledgeLibrary:
for filename in stats["skipped_files"]:
self.console.print(f"[yellow]- {filename}[/yellow]")
def calculate_file_hash(self, file_path):
@staticmethod
def calculate_file_hash(file_path):
"""计算文件的MD5哈希值"""
hash_md5 = hashlib.md5()
with open(file_path, "rb") as f: