🤖 自动格式化代码 [skip ci]

This commit is contained in:
github-actions[bot]
2025-05-14 15:11:33 +00:00
parent 17d19e7cac
commit fb6094d269
17 changed files with 278 additions and 254 deletions

View File

@@ -148,20 +148,21 @@ class MaiEmoji:
# 准备数据库记录 for emoji collection
emotion_str = ",".join(self.emotion) if self.emotion else ""
Emoji.create(hash=self.hash,
full_path=self.full_path,
format=self.format,
description=self.description,
emotion=emotion_str, # Store as comma-separated string
query_count=0, # Default value
is_registered=True,
is_banned=False, # Default value
record_time=self.register_time, # Use MaiEmoji's register_time for DB record_time
register_time=self.register_time,
usage_count=self.usage_count,
last_used_time=self.last_used_time,
)
Emoji.create(
hash=self.hash,
full_path=self.full_path,
format=self.format,
description=self.description,
emotion=emotion_str, # Store as comma-separated string
query_count=0, # Default value
is_registered=True,
is_banned=False, # Default value
record_time=self.register_time, # Use MaiEmoji's register_time for DB record_time
register_time=self.register_time,
usage_count=self.usage_count,
last_used_time=self.last_used_time,
)
logger.success(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})")
return True
@@ -197,10 +198,10 @@ class MaiEmoji:
# 2. 删除数据库记录
try:
will_delete_emoji = Emoji.get(Emoji.hash == self.hash)
result = will_delete_emoji.delete_instance() # Returns the number of rows deleted.
result = will_delete_emoji.delete_instance() # Returns the number of rows deleted.
except Emoji.DoesNotExist:
logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。")
result = 0 # Indicate no DB record was deleted
result = 0 # Indicate no DB record was deleted
if result > 0:
logger.info(f"[删除] 表情包数据库记录 {self.filename} (Hash: {self.hash})")
@@ -245,12 +246,14 @@ def _to_emoji_objects(data):
emoji_objects = []
load_errors = 0
# data is now an iterable of Peewee Emoji model instances
emoji_data_list = list(data)
emoji_data_list = list(data)
for emoji_data in emoji_data_list: # emoji_data is an Emoji model instance
for emoji_data in emoji_data_list: # emoji_data is an Emoji model instance
full_path = emoji_data.full_path
if not full_path:
logger.warning(f"[加载错误] 数据库记录缺少 'full_path' 字段: ID {emoji_data.id if hasattr(emoji_data, 'id') else 'Unknown'}")
logger.warning(
f"[加载错误] 数据库记录缺少 'full_path' 字段: ID {emoji_data.id if hasattr(emoji_data, 'id') else 'Unknown'}"
)
load_errors += 1
continue
@@ -265,9 +268,9 @@ def _to_emoji_objects(data):
emoji.description = emoji_data.description
# Deserialize emotion string from DB to list
emoji.emotion = emoji_data.emotion.split(',') if emoji_data.emotion else []
emoji.emotion = emoji_data.emotion.split(",") if emoji_data.emotion else []
emoji.usage_count = emoji_data.usage_count
db_last_used_time = emoji_data.last_used_time
db_register_time = emoji_data.register_time
@@ -275,7 +278,7 @@ def _to_emoji_objects(data):
emoji.last_used_time = db_last_used_time if db_last_used_time is not None else emoji.register_time
# If register_time from DB is None, use MaiEmoji's initialized register_time (which is time.time())
emoji.register_time = db_register_time if db_register_time is not None else emoji.register_time
emoji.format = emoji_data.format
emoji_objects.append(emoji)
@@ -385,7 +388,7 @@ class EmojiManager:
# Ensure Peewee database connection is up and tables are created
# Connect unconditionally: reuse_if_open=True turns this into a no-op when a
# connection is already open, so no is_closed() guard is needed. Guarding with
# `not is_closed()` would only ever call connect() on an already-open database.
peewee_db.connect(reuse_if_open=True)
Emoji.create_table(safe=True) # Ensures table exists
Emoji.create_table(safe=True) # Ensures table exists
_ensure_emoji_dir()
self._initialized = True
@@ -404,8 +407,8 @@ class EmojiManager:
try:
emoji_update = Emoji.get(Emoji.hash == emoji_hash)
emoji_update.usage_count += 1
emoji_update.last_used_time = time.time() # Update last used time
emoji_update.save() # Persist changes to DB
emoji_update.last_used_time = time.time() # Update last used time
emoji_update.save() # Persist changes to DB
except Emoji.DoesNotExist:
logger.error(f"记录表情使用失败: 未找到 hash 为 {emoji_hash} 的表情包")
except Exception as e:
@@ -674,7 +677,7 @@ class EmojiManager:
"[查询] 未提供 hash将尝试加载所有表情包建议使用 get_all_emoji_from_db 更新管理器状态。"
)
query = Emoji.select()
emoji_peewee_instances = query
emoji_objects, load_errors = _to_emoji_objects(emoji_peewee_instances)

View File

@@ -91,7 +91,6 @@ class HeartFChatting:
self.action_manager = ActionManager()
self.action_planner = ActionPlanner(log_prefix=self.log_prefix, action_manager=self.action_manager)
# --- 处理器列表 ---
self.processors: List[BaseProcessor] = []
self._register_default_processors()
@@ -526,5 +525,3 @@ class HeartFChatting:
if last_n is not None:
history = history[-last_n:]
return [cycle.to_dict() for cycle in history]

View File

@@ -7,12 +7,14 @@ from src.chat.person_info.relationship_manager import relationship_manager
from src.chat.utils.utils import get_embedding
import time
from typing import Union, Optional
# from common.database.database import db
from src.chat.utils.utils import get_recent_group_speaker
from src.manager.mood_manager import mood_manager
from src.chat.memory_system.Hippocampus import HippocampusManager
from src.chat.knowledge.knowledge_lib import qa_manager
from src.chat.focus_chat.expressors.exprssion_learner import expression_learner
# import traceback
import random
import json
@@ -614,7 +616,7 @@ class PromptBuilder:
return "" if not return_raw else []
query_embedding_magnitude = math.sqrt(sum(x * x for x in query_embedding))
if query_embedding_magnitude == 0: # Avoid division by zero
if query_embedding_magnitude == 0: # Avoid division by zero
return "" if not return_raw else []
for knowledge_item in all_knowledges:
@@ -623,35 +625,35 @@ class PromptBuilder:
db_embedding = json.loads(db_embedding_str)
if len(db_embedding) != len(query_embedding):
logger.warning(f"Embedding length mismatch for knowledge ID {knowledge_item.id if hasattr(knowledge_item, 'id') else 'N/A'}. Skipping.")
logger.warning(
f"Embedding length mismatch for knowledge ID {knowledge_item.id if hasattr(knowledge_item, 'id') else 'N/A'}. Skipping."
)
continue
# Calculate Cosine Similarity
dot_product = sum(q * d for q, d in zip(query_embedding, db_embedding))
db_embedding_magnitude = math.sqrt(sum(x * x for x in db_embedding))
if db_embedding_magnitude == 0: # Avoid division by zero
if db_embedding_magnitude == 0: # Avoid division by zero
similarity = 0.0
else:
similarity = dot_product / (query_embedding_magnitude * db_embedding_magnitude)
if similarity >= threshold:
results_with_similarity.append({
"content": knowledge_item.content,
"similarity": similarity
})
results_with_similarity.append({"content": knowledge_item.content, "similarity": similarity})
except json.JSONDecodeError:
logger.error(f"Failed to parse embedding for knowledge ID {knowledge_item.id if hasattr(knowledge_item, 'id') else 'N/A'}")
logger.error(
f"Failed to parse embedding for knowledge ID {knowledge_item.id if hasattr(knowledge_item, 'id') else 'N/A'}"
)
except Exception as e:
logger.error(f"Error processing knowledge item: {e}")
# Sort by similarity in descending order
results_with_similarity.sort(key=lambda x: x["similarity"], reverse=True)
# Limit results
limited_results = results_with_similarity[:limit]
logger.debug(f"知识库查询结果数量 (after Peewee processing): {len(limited_results)}")
if not limited_results:

View File

@@ -27,20 +27,19 @@ class ActionManager:
self._using_actions: Dict[str, ActionInfo] = {}
# 临时备份原始使用中的动作
self._original_actions_backup: Optional[Dict[str, ActionInfo]] = None
# 默认动作集,仅作为快照,用于恢复默认
self._default_actions: Dict[str, ActionInfo] = {}
# 加载所有已注册动作
self._load_registered_actions()
# 初始化时将默认动作加载到使用中的动作
self._using_actions = self._default_actions.copy()
# logger.info(f"当前可用动作: {list(self._using_actions.keys())}")
# for action_name, action_info in self._using_actions.items():
# logger.info(f"动作名称: {action_name}, 动作信息: {action_info}")
# logger.info(f"动作名称: {action_name}, 动作信息: {action_info}")
def _load_registered_actions(self) -> None:
"""
@@ -50,35 +49,35 @@ class ActionManager:
# 从_ACTION_REGISTRY获取所有已注册动作
for action_name, action_class in _ACTION_REGISTRY.items():
# 获取动作相关信息
action_description:str = getattr(action_class, "action_description", "")
action_parameters:dict[str:str] = getattr(action_class, "action_parameters", {})
action_require:list[str] = getattr(action_class, "action_require", [])
is_default:bool = getattr(action_class, "default", False)
action_description: str = getattr(action_class, "action_description", "")
action_parameters: dict[str:str] = getattr(action_class, "action_parameters", {})
action_require: list[str] = getattr(action_class, "action_require", [])
is_default: bool = getattr(action_class, "default", False)
if action_name and action_description:
# 创建动作信息字典
action_info = {
"description": action_description,
"parameters": action_parameters,
"require": action_require
"require": action_require,
}
# 注册2
print("注册2")
print(action_info)
# 添加到所有已注册的动作
self._registered_actions[action_name] = action_info
# 添加到默认动作(如果是默认动作)
if is_default:
self._default_actions[action_name] = action_info
logger.info(f"所有注册动作: {list(self._registered_actions.keys())}")
logger.info(f"默认动作: {list(self._default_actions.keys())}")
# for action_name, action_info in self._default_actions.items():
# logger.info(f"动作名称: {action_name}, 动作信息: {action_info}")
# logger.info(f"动作名称: {action_name}, 动作信息: {action_info}")
except Exception as e:
logger.error(f"加载已注册动作失败: {e}")
@@ -125,7 +124,7 @@ class ActionManager:
if action_name not in self._using_actions:
logger.warning(f"当前不可用的动作类型: {action_name}")
return None
handler_class = _ACTION_REGISTRY.get(action_name)
if not handler_class:
logger.warning(f"未注册的动作类型: {action_name}")
@@ -149,7 +148,7 @@ class ActionManager:
expressor=expressor,
chat_stream=chat_stream,
)
return instance
except Exception as e:
@@ -163,7 +162,7 @@ class ActionManager:
def get_default_actions(self) -> Dict[str, ActionInfo]:
    """Return a shallow copy of the default action set.

    The result is a new dict, so adding or removing keys on it does not
    modify ``self._default_actions``. The per-action info values themselves
    are shared with the stored mapping, not duplicated.
    """
    default_snapshot = dict(self._default_actions)
    return default_snapshot
def get_using_actions(self) -> Dict[str, ActionInfo]:
    """Return a shallow copy of the action set currently in use.

    The result is a new dict, so adding or removing keys on it does not
    modify ``self._using_actions``. The per-action info values themselves
    are shared with the stored mapping, not duplicated.
    """
    using_snapshot = dict(self._using_actions)
    return using_snapshot
@@ -171,21 +170,21 @@ class ActionManager:
def add_action_to_using(self, action_name: str) -> bool:
    """
    Enable an already-registered action by adding it to the in-use action set.

    Args:
        action_name: Name of the action to enable.

    Returns:
        bool: False when the action is not registered; True when it was
        added or was already present in the in-use set.
    """
    is_registered = action_name in self._registered_actions
    if not is_registered:
        logger.warning(f"添加失败: 动作 {action_name} 未注册")
        return False

    if action_name not in self._using_actions:
        # The in-use set stores the same info object as the registry (no copy).
        self._using_actions[action_name] = self._registered_actions[action_name]
        logger.info(f"添加动作 {action_name} 到使用集")
    else:
        # Already enabled: nothing to change, still reported as success.
        logger.info(f"动作 {action_name} 已经在使用中")
    return True
@@ -193,17 +192,17 @@ class ActionManager:
def remove_action_from_using(self, action_name: str) -> bool:
    """
    Remove the given action from the in-use action set.

    Args:
        action_name: Name of the action to disable.

    Returns:
        bool: True when the action was present and has been removed;
        False when it was not in the in-use set.
    """
    if action_name in self._using_actions:
        del self._using_actions[action_name]
        logger.info(f"已从使用集中移除动作 {action_name}")
        return True

    # Nothing to remove: warn and report failure.
    logger.warning(f"移除失败: 动作 {action_name} 不在当前使用的动作集中")
    return False
@@ -211,30 +210,26 @@ class ActionManager:
def add_action(self, action_name: str, description: str, parameters: Dict = None, require: List = None) -> bool:
"""
添加新的动作到注册集
Args:
action_name: 动作名称
description: 动作描述
parameters: 动作参数定义,默认为空字典
require: 动作依赖项,默认为空列表
Returns:
bool: 添加是否成功
"""
if action_name in self._registered_actions:
return False
if parameters is None:
parameters = {}
if require is None:
require = []
action_info = {
"description": description,
"parameters": parameters,
"require": require
}
action_info = {"description": description, "parameters": parameters, "require": require}
self._registered_actions[action_name] = action_info
return True
@@ -260,7 +255,7 @@ class ActionManager:
if self._original_actions_backup is not None:
self._using_actions = self._original_actions_backup.copy()
self._original_actions_backup = None
def restore_default_actions(self) -> None:
"""恢复默认动作集到使用集"""
self._using_actions = self._default_actions.copy()
@@ -269,10 +264,10 @@ class ActionManager:
def get_action(self, action_name: str) -> Optional[Type[BaseAction]]:
"""
获取指定动作的处理器类
Args:
action_name: 动作名称
Returns:
Optional[Type[BaseAction]]: 动作处理器类如果不存在则返回None
"""

View File

@@ -12,7 +12,7 @@ _DEFAULT_ACTIONS: Dict[str, str] = {}
def register_action(cls):
"""
动作注册装饰器
用法:
@register_action
class MyAction(BaseAction):
@@ -24,22 +24,22 @@ def register_action(cls):
if not hasattr(cls, "action_name") or not hasattr(cls, "action_description"):
logger.error(f"动作类 {cls.__name__} 缺少必要的属性: action_name 或 action_description")
return cls
action_name = getattr(cls, "action_name") #noqa
action_description = getattr(cls, "action_description") #noqa
action_name = getattr(cls, "action_name") # noqa
action_description = getattr(cls, "action_description") # noqa
is_default = getattr(cls, "default", False)
if not action_name or not action_description:
logger.error(f"动作类 {cls.__name__} 的 action_name 或 action_description 为空")
return cls
# 将动作类注册到全局注册表
_ACTION_REGISTRY[action_name] = cls
# 如果是默认动作,添加到默认动作集
if is_default:
_DEFAULT_ACTIONS[action_name] = action_description
logger.info(f"已注册动作: {action_name} -> {cls.__name__},默认: {is_default}")
return cls
@@ -60,15 +60,14 @@ class BaseAction(ABC):
cycle_timers: 计时器字典
thinking_id: 思考ID
"""
#每个动作必须实现
self.action_name:str = "base_action"
self.action_description:str = "基础动作"
self.action_parameters:dict = {}
self.action_require:list[str] = []
self.default:bool = False
# 每个动作必须实现
self.action_name: str = "base_action"
self.action_description: str = "基础动作"
self.action_parameters: dict = {}
self.action_require: list[str] = []
self.default: bool = False
self.action_data = action_data
self.reasoning = reasoning
self.cycle_timers = cycle_timers

View File

@@ -29,7 +29,7 @@ class NoReplyAction(BaseAction):
action_require = [
"话题无关/无聊/不感兴趣/不懂",
"最后一条消息是你自己发的且无人回应你",
"你发送了太多消息,且无人回复"
"你发送了太多消息,且无人回复",
]
default = True
@@ -46,7 +46,7 @@ class NoReplyAction(BaseAction):
total_no_reply_count: int = 0,
total_waiting_time: float = 0.0,
shutting_down: bool = False,
**kwargs
**kwargs,
):
"""初始化不回复动作处理器

View File

@@ -2,6 +2,7 @@
# -*- coding: utf-8 -*-
from src.common.logger_manager import get_logger
# from src.chat.utils.timer_calculator import Timer
from src.chat.focus_chat.planners.actions.base_action import BaseAction, register_action
from typing import Tuple, List
@@ -22,14 +23,14 @@ class ReplyAction(BaseAction):
处理构建和发送消息回复的动作。
"""
action_name:str = "reply"
action_description:str = "表达想法,可以只包含文本、表情或两者都有"
action_parameters:dict[str:str] = {
action_name: str = "reply"
action_description: str = "表达想法,可以只包含文本、表情或两者都有"
action_parameters: dict[str:str] = {
"text": "你想要表达的内容(可选)",
"emojis": "描述当前使用表情包的场景(可选)",
"target": "你想要回复的原始文本内容(非必须,仅文本,不包含发送者)(可选)",
}
action_require:list[str] = [
action_require: list[str] = [
"有实质性内容需要表达",
"有人提到你,但你还没有回应他",
"在合适的时候添加表情(不要总是添加)",
@@ -38,7 +39,7 @@ class ReplyAction(BaseAction):
"一次只回复一个人,一次只回复一个话题,突出重点",
"如果是自己发的消息想继续,需自然衔接",
"避免重复或评价自己的发言,不要和自己聊天",
"注意:回复尽量简短一些。可以参考贴吧,知乎和微博的回复风格,回复不要浮夸,不要用夸张修辞,平淡一些。"
"注意:回复尽量简短一些。可以参考贴吧,知乎和微博的回复风格,回复不要浮夸,不要用夸张修辞,平淡一些。",
]
default = True
@@ -54,7 +55,7 @@ class ReplyAction(BaseAction):
chat_stream: ChatStream,
current_cycle: CycleDetail,
log_prefix: str,
**kwargs
**kwargs,
):
"""初始化回复动作处理器
@@ -89,9 +90,9 @@ class ReplyAction(BaseAction):
reasoning=self.reasoning,
reply_data=self.action_data,
cycle_timers=self.cycle_timers,
thinking_id=self.thinking_id
thinking_id=self.thinking_id,
)
async def _handle_reply(
self, reasoning: str, reply_data: dict, cycle_timers: dict, thinking_id: str
) -> tuple[bool, str]:

View File

@@ -4,6 +4,7 @@ from typing import List, Dict, Any, Optional
from rich.traceback import install
from src.chat.models.utils_model import LLMRequest
from src.config.config import global_config
# from src.chat.focus_chat.heartflow_prompt_builder import prompt_builder
from src.chat.focus_chat.info.info_base import InfoBase
from src.chat.focus_chat.info.obs_info import ObsInfo
@@ -15,10 +16,12 @@ from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.individuality.individuality import Individuality
from src.chat.focus_chat.planners.action_factory import ActionManager
from src.chat.focus_chat.planners.action_factory import ActionInfo
logger = get_logger("planner")
install(extra_lines=3)
def init_prompt():
Prompt(
"""你的名字是{bot_name},{prompt_personality}{chat_context_description}。需要基于以下信息决定如何参与对话:
@@ -44,8 +47,9 @@ def init_prompt():
}}
请输出你的决策 JSON""",
"planner_prompt",)
"planner_prompt",
)
Prompt(
"""
action_name: {action_name}
@@ -57,7 +61,7 @@ action_name: {action_name}
""",
"action_prompt",
)
class ActionPlanner:
def __init__(self, log_prefix: str, action_manager: ActionManager):
@@ -68,7 +72,7 @@ class ActionPlanner:
max_tokens=1000,
request_type="action_planning", # 用于动作规划
)
self.action_manager = action_manager
async def plan(self, all_plan_info: List[InfoBase], cycle_timers: dict) -> Dict[str, Any]:
@@ -106,7 +110,7 @@ class ActionPlanner:
_structured_info = info.get_data()
current_available_actions = self.action_manager.get_using_actions()
# --- 构建提示词 (调用修改后的 PromptBuilder 方法) ---
prompt = await self.build_planner_prompt(
is_group_chat=is_group_chat, # <-- Pass HFC state
@@ -197,7 +201,6 @@ class ActionPlanner:
# 返回结果字典
return plan_result
async def build_planner_prompt(
self,
is_group_chat: bool, # Now passed as argument
@@ -218,7 +221,6 @@ class ActionPlanner:
)
chat_context_description = f"你正在和 {chat_target_name} 私聊"
chat_content_block = ""
if observed_messages_str:
chat_content_block = f"聊天记录:\n{observed_messages_str}"
@@ -234,7 +236,6 @@ class ActionPlanner:
individuality = Individuality.get_instance()
personality_block = individuality.get_prompt(x_person=2, level=2)
action_options_block = ""
for using_actions_name, using_actions_info in current_available_actions.items():
# print(using_actions_name)
@@ -242,29 +243,26 @@ class ActionPlanner:
# print(using_actions_info["parameters"])
# print(using_actions_info["require"])
# print(using_actions_info["description"])
using_action_prompt = await global_prompt_manager.get_prompt_async("action_prompt")
param_text = ""
for param_name, param_description in using_actions_info["parameters"].items():
param_text += f"{param_name}: {param_description}\n"
require_text = ""
for require_item in using_actions_info["require"]:
require_text += f"- {require_item}\n"
using_action_prompt = using_action_prompt.format(
action_name=using_actions_name,
action_description=using_actions_info["description"],
action_parameters=param_text,
action_require=require_text,
)
action_options_block += using_action_prompt
planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt")
prompt = planner_prompt_template.format(
bot_name=global_config.BOT_NICKNAME,

View File

@@ -261,7 +261,9 @@ class PersonInfoManager:
qv_name_prompt += f"你之前叫他{old_name},是因为{old_reason}"
qv_name_prompt += f"\n其他取名的要求是:{request},不要太浮夸"
qv_name_prompt += "\n请根据以上用户信息想想你叫他什么比较好不要太浮夸请最好使用用户的qq昵称可以稍作修改"
qv_name_prompt += (
"\n请根据以上用户信息想想你叫他什么比较好不要太浮夸请最好使用用户的qq昵称可以稍作修改"
)
if existing_names_str:
qv_name_prompt += f"\n请注意,以下名称已被你尝试过或已知存在,请避免:{existing_names_str}\n"
@@ -289,6 +291,7 @@ class PersonInfoManager:
if generated_nickname in current_name_set:
is_duplicate = True
else:
def _db_check_name_exists_sync(name_to_check):
    """Return whether any PersonInfo row has ``person_name`` equal to ``name_to_check``."""
    same_name_rows = PersonInfo.select().where(PersonInfo.person_name == name_to_check)
    return same_name_rows.exists()
@@ -415,7 +418,9 @@ class PersonInfoManager:
@staticmethod
async def del_all_undefined_field():
"""删除所有项里的未定义字段 - 对于Peewee (SQL),此操作通常不适用,因为模式是固定的。"""
logger.info("del_all_undefined_field: 对于使用Peewee的SQL数据库此操作通常不适用或不需要因为表结构是预定义的。")
logger.info(
"del_all_undefined_field: 对于使用Peewee的SQL数据库此操作通常不适用或不需要因为表结构是预定义的。"
)
return
@staticmethod
@@ -512,7 +517,9 @@ class PersonInfoManager:
if trimmed_interval:
msg_interval_val = int(round(np.percentile(trimmed_interval, 37)))
await self.update_one_field(person_id, "msg_interval", msg_interval_val)
logger.trace(f"用户{person_id}的msg_interval通过头尾截断和37分位数更新为{msg_interval_val}")
logger.trace(
f"用户{person_id}的msg_interval通过头尾截断和37分位数更新为{msg_interval_val}"
)
else:
logger.trace(f"用户{person_id}截断后数据为空无法计算msg_interval")
else:
@@ -577,13 +584,17 @@ class PersonInfoManager:
break
if not found_person_id:
def _db_find_by_name_sync(p_name_to_find: str):
    """Fetch the PersonInfo row whose ``person_name`` equals ``p_name_to_find``, or None if there is none."""
    name_matches = PersonInfo.person_name == p_name_to_find
    return PersonInfo.get_or_none(name_matches)
record = await asyncio.to_thread(_db_find_by_name_sync, person_name)
if record:
found_person_id = record.person_id
if found_person_id not in self.person_name_list or self.person_name_list[found_person_id] != person_name:
if (
found_person_id not in self.person_name_list
or self.person_name_list[found_person_id] != person_name
):
self.person_name_list[found_person_id] = person_name
else:
logger.debug(f"数据库中也未找到名为 '{person_name}' 的用户 (Peewee)")
@@ -600,7 +611,9 @@ class PersonInfoManager:
"person_name",
"name_reason",
]
valid_fields_to_get = [f for f in required_fields if f in PersonInfo._meta.fields or f in person_info_default]
valid_fields_to_get = [
f for f in required_fields if f in PersonInfo._meta.fields or f in person_info_default
]
person_data = await self.get_values(found_person_id, valid_fields_to_get)

View File

@@ -454,7 +454,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
def reply_replacer(match):
# aaa = match.group(1)
bbb = match.group(2)
anon_reply = get_anon_name(platform, bbb) #noqa
anon_reply = get_anon_name(platform, bbb) # noqa
return f"回复 {anon_reply}"
content = re.sub(reply_pattern, reply_replacer, content, count=1)
@@ -465,7 +465,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
def at_replacer(match):
# aaa = match.group(1)
bbb = match.group(2)
anon_at = get_anon_name(platform, bbb) #noqa
anon_at = get_anon_name(platform, bbb) # noqa
return f"@{anon_at}"
content = re.sub(at_pattern, at_replacer, content)

View File

@@ -103,11 +103,11 @@ class InfoCatcher:
print(f"查询参数: time_start={time_start}, time_end={time_end}, chat_id={chat_id}")
messages_between_query = Messages.select().where(
(Messages.chat_id == chat_id) &
(Messages.time > time_start) &
(Messages.time < time_end)
).order_by(Messages.time.desc())
messages_between_query = (
Messages.select()
.where((Messages.chat_id == chat_id) & (Messages.time > time_start) & (Messages.time < time_end))
.order_by(Messages.time.desc())
)
result = list(messages_between_query)
print(f"查询结果数量: {len(result)}")
@@ -124,10 +124,12 @@ class InfoCatcher:
message_id_val = message.message_info.message_id
chat_id_val = message.chat_stream.stream_id
messages_before_query = Messages.select().where(
(Messages.chat_id == chat_id_val) &
(Messages.message_id < message_id_val)
).order_by(Messages.time.desc()).limit(self.context_length * 3)
messages_before_query = (
Messages.select()
.where((Messages.chat_id == chat_id_val) & (Messages.message_id < message_id_val))
.order_by(Messages.time.desc())
.limit(self.context_length * 3)
)
return list(messages_before_query)
@@ -137,7 +139,7 @@ class InfoCatcher:
processed_msg_item = msg_item
if not isinstance(msg_item, dict):
processed_msg_item = self.message_to_dict(msg_item)
if not processed_msg_item:
continue
@@ -163,15 +165,15 @@ class InfoCatcher:
"user_nickname": msg_obj.user_nickname,
"processed_plain_text": msg_obj.processed_plain_text,
}
if hasattr(msg_obj, 'message_info') and hasattr(msg_obj.message_info, 'user_info'):
if hasattr(msg_obj, "message_info") and hasattr(msg_obj.message_info, "user_info"):
return {
"time": msg_obj.message_info.time,
"user_id": msg_obj.message_info.user_info.user_id,
"user_nickname": msg_obj.message_info.user_info.user_nickname,
"processed_plain_text": msg_obj.processed_plain_text,
}
print(f"Warning: message_to_dict received an unhandled type: {type(msg_obj)}")
return {}
@@ -198,7 +200,7 @@ class InfoCatcher:
chat_history_in_thinking_json=json.dumps(chat_history_in_thinking_list),
chat_history_after_response_json=json.dumps(chat_history_after_response_list),
heartflow_data_json=json.dumps(self.heartflow_data),
reasoning_data_json=json.dumps(self.reasoning_data)
reasoning_data_json=json.dumps(self.reasoning_data),
)
log_entry.save()

View File

@@ -5,8 +5,8 @@ from typing import Any, Dict, Tuple, List
from src.common.logger import get_module_logger
from src.manager.async_task_manager import AsyncTask
from ...common.database.database import db # This db is the Peewee database instance
from ...common.database.database_model import OnlineTime, LLMUsage, Messages # Import the Peewee model
from ...common.database.database import db # This db is the Peewee database instance
from ...common.database.database_model import OnlineTime, LLMUsage, Messages # Import the Peewee model
from src.manager.local_store_manager import local_storage
logger = get_module_logger("maibot_statistic")
@@ -48,8 +48,8 @@ class OnlineTimeRecordTask(AsyncTask):
@staticmethod
def _init_database():
"""初始化数据库"""
with db.atomic(): # Use atomic operations for schema changes
OnlineTime.create_table(safe=True) # Creates table if it doesn't exist, Peewee handles indexes from model
with db.atomic(): # Use atomic operations for schema changes
OnlineTime.create_table(safe=True) # Creates table if it doesn't exist, Peewee handles indexes from model
async def run(self):
try:
@@ -62,14 +62,17 @@ class OnlineTimeRecordTask(AsyncTask):
updated_rows = query.execute()
if updated_rows == 0:
# Record might have been deleted or ID is stale, try to find/create
self.record_id = None # Reset record_id to trigger find/create logic below
if not self.record_id: # Check again if record_id was reset or initially None
self.record_id = None # Reset record_id to trigger find/create logic below
if not self.record_id: # Check again if record_id was reset or initially None
# 如果没有记录,检查一分钟以内是否已有记录
# Look for a record whose end_timestamp is recent enough to be considered ongoing
recent_record = OnlineTime.select().where(
OnlineTime.end_timestamp >= (current_time - timedelta(minutes=1))
).order_by(OnlineTime.end_timestamp.desc()).first()
recent_record = (
OnlineTime.select()
.where(OnlineTime.end_timestamp >= (current_time - timedelta(minutes=1)))
.order_by(OnlineTime.end_timestamp.desc())
.first()
)
if recent_record:
# 如果有记录,则更新结束时间
@@ -87,7 +90,6 @@ class OnlineTimeRecordTask(AsyncTask):
logger.error(f"在线时间记录失败,错误信息:{e}")
def _format_online_time(online_seconds: int) -> str:
"""
格式化在线时间
@@ -197,7 +199,7 @@ class StatisticOutputTask(AsyncTask):
"""
if not collect_period:
return {}
# 排序-按照时间段开始时间降序排列(最晚的时间段在前)
collect_period.sort(key=lambda x: x[1], reverse=True)
@@ -228,14 +230,14 @@ class StatisticOutputTask(AsyncTask):
# Assuming LLMUsage.timestamp is a DateTimeField
query_start_time = collect_period[-1][1]
for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time):
record_timestamp = record.timestamp # This is already a datetime object
record_timestamp = record.timestamp # This is already a datetime object
for idx, (_, period_start) in enumerate(collect_period):
if record_timestamp >= period_start:
for period_key, _ in collect_period[idx:]:
stats[period_key][TOTAL_REQ_CNT] += 1
request_type = record.request_type or "unknown"
user_id = record.user_id or "unknown" # user_id is TextField, already string
user_id = record.user_id or "unknown" # user_id is TextField, already string
model_name = record.model_name or "unknown"
stats[period_key][REQ_CNT_BY_TYPE][request_type] += 1
@@ -275,7 +277,7 @@ class StatisticOutputTask(AsyncTask):
"""
if not collect_period:
return {}
collect_period.sort(key=lambda x: x[1], reverse=True)
stats = {
@@ -300,7 +302,7 @@ class StatisticOutputTask(AsyncTask):
for period_key, current_period_start_time in collect_period[idx:]:
# Determine the portion of the record that falls within this specific statistical period
overlap_start = max(record_start_timestamp, current_period_start_time)
overlap_end = effective_end_time # Already capped by 'now' and record's own end
overlap_end = effective_end_time # Already capped by 'now' and record's own end
if overlap_end > overlap_start:
stats[period_key][ONLINE_TIME] += (overlap_end - overlap_start).total_seconds()
@@ -315,7 +317,7 @@ class StatisticOutputTask(AsyncTask):
"""
if not collect_period:
return {}
collect_period.sort(key=lambda x: x[1], reverse=True)
stats = {
@@ -326,9 +328,9 @@ class StatisticOutputTask(AsyncTask):
for period_key, _ in collect_period
}
query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp)
query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp)
for message in Messages.select().where(Messages.time >= query_start_timestamp):
message_time_ts = message.time # This is a float timestamp
message_time_ts = message.time # This is a float timestamp
chat_id = None
chat_name = None
@@ -337,16 +339,18 @@ class StatisticOutputTask(AsyncTask):
if message.chat_info_group_id:
chat_id = f"g{message.chat_info_group_id}"
chat_name = message.chat_info_group_name or f"{message.chat_info_group_id}"
elif message.user_id: # Fallback to sender's info for chat_id if not a group_info based chat
# This uses the message SENDER's ID as per original logic's fallback
chat_id = f"u{message.user_id}" # SENDER's user_id
chat_name = message.user_nickname # SENDER's nickname
elif message.user_id: # Fallback to sender's info for chat_id if not a group_info based chat
# This uses the message SENDER's ID as per original logic's fallback
chat_id = f"u{message.user_id}" # SENDER's user_id
chat_name = message.user_nickname # SENDER's nickname
else:
# If neither group_id nor sender_id is available for chat identification
logger.warning(f"Message (PK: {message.id if hasattr(message, 'id') else 'N/A'}) lacks group_id and user_id for chat stats.")
logger.warning(
f"Message (PK: {message.id if hasattr(message, 'id') else 'N/A'}) lacks group_id and user_id for chat stats."
)
continue
if not chat_id: # Should not happen if above logic is correct
if not chat_id: # Should not happen if above logic is correct
continue
# Update name_mapping

View File

@@ -35,13 +35,13 @@ class ImageManager:
if not self._initialized:
self._ensure_image_dir()
self._llm = LLMRequest(model=global_config.vlm, temperature=0.4, max_tokens=300, request_type="image")
try:
db.connect(reuse_if_open=True)
db.create_tables([Images, ImageDescriptions], safe=True)
except Exception as e:
logger.error(f"数据库连接或表创建失败: {e}")
self._initialized = True
def _ensure_image_dir(self):
@@ -61,8 +61,7 @@ class ImageManager:
"""
try:
record = ImageDescriptions.get_or_none(
(ImageDescriptions.hash == image_hash) &
(ImageDescriptions.type == description_type)
(ImageDescriptions.hash == image_hash) & (ImageDescriptions.type == description_type)
)
return record.description if record else None
except Exception as e:
@@ -80,14 +79,9 @@ class ImageManager:
"""
try:
current_timestamp = time.time()
defaults = {
'description': description,
'timestamp': current_timestamp
}
defaults = {"description": description, "timestamp": current_timestamp}
desc_obj, created = ImageDescriptions.get_or_create(
hash=image_hash,
type=description_type,
defaults=defaults
hash=image_hash, type=description_type, defaults=defaults
)
if not created: # 如果记录已存在,则更新
desc_obj.description = description
@@ -120,7 +114,7 @@ class ImageManager:
else:
prompt = "这是一个表情包,请用使用几个词描述一下表情包所表达的情感和内容,简短一些"
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
if description is None:
logger.warning("AI未能生成表情包描述")
return "[表情包(描述生成失败)]"
@@ -191,7 +185,7 @@ class ImageManager:
"请用中文描述这张图片的内容。如果有文字请把文字都描述出来。并尝试猜测这个图片的含义。最多100个字。"
)
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
if description is None:
logger.warning("AI未能生成图片描述")
return "[图片(描述生成失败)]"

View File

@@ -14,18 +14,21 @@ import datetime
# db = MySQLDatabase('your_db_name', user='your_user', password='your_password',
# host='localhost', port=3306)
# 定义一个基础模型是一个好习惯,所有其他模型都应继承自它。
# 这允许您在一个地方为所有模型指定数据库。
class BaseModel(Model):
class Meta:
# 将下面的 'db' 替换为您实际的数据库实例变量名。
database = db # 例如: database = my_actual_db_instance
pass # 在用户定义数据库实例之前,此处为占位符
pass # 在用户定义数据库实例之前,此处为占位符
class ChatStreams(BaseModel):
"""
用于存储流式记录数据的模型,类似于提供的 MongoDB 结构。
"""
# stream_id: "a544edeb1a9b73e3e1d77dff36e41264"
# 假设 stream_id 是唯一的,并为其创建索引以提高查询性能。
stream_id = TextField(unique=True, index=True)
@@ -63,28 +66,31 @@ class ChatStreams(BaseModel):
# 如果 BaseModel.Meta.database 已设置,则此模型将继承该数据库配置。
# 如果不使用带有数据库实例的 BaseModel或者想覆盖它
# 请取消注释并在下面设置数据库实例:
# database = db
table_name = 'chat_streams' # 可选:明确指定数据库中的表名
# database = db
table_name = "chat_streams" # 可选:明确指定数据库中的表名
class LLMUsage(BaseModel):
"""
Model for storing API usage log data.
"""
# NOTE(review): several declarations below appear twice (identical apart from
# comment spacing or quote style) — presumably the before/after sides of a
# formatting diff; confirm and keep one copy of each.
model_name = TextField(index=True) # Indexed
user_id = TextField(index=True) # Indexed
request_type = TextField(index=True) # Indexed
model_name = TextField(index=True) # Indexed
user_id = TextField(index=True) # Indexed
request_type = TextField(index=True) # Indexed
endpoint = TextField()
prompt_tokens = IntegerField()
completion_tokens = IntegerField()
total_tokens = IntegerField()
cost = DoubleField()
status = TextField()
timestamp = DateTimeField(index=True) # Changed to DateTimeField and indexed
timestamp = DateTimeField(index=True) # Changed to DateTimeField and indexed
class Meta:
# If BaseModel.Meta.database is set, this model inherits that database configuration.
# database = db
table_name = 'llm_usage'
# database = db
table_name = "llm_usage"
class Emoji(BaseModel):
"""表情包"""
@@ -105,16 +111,18 @@ class Emoji(BaseModel):
class Meta:
# database = db # 继承自 BaseModel
table_name = 'emoji'
table_name = "emoji"
class Messages(BaseModel):
"""
用于存储消息数据的模型。
"""
message_id = IntegerField(index=True) # 消息 ID
time = DoubleField() # 消息时间戳
chat_id = TextField(index=True) # 对应的 ChatStreams stream_id
message_id = IntegerField(index=True) # 消息 ID
time = DoubleField() # 消息时间戳
chat_id = TextField(index=True) # 对应的 ChatStreams stream_id
# 从 chat_info 扁平化而来的字段
chat_info_stream_id = TextField()
@@ -123,7 +131,7 @@ class Messages(BaseModel):
chat_info_user_id = TextField()
chat_info_user_nickname = TextField()
chat_info_user_cardname = TextField(null=True)
chat_info_group_platform = TextField(null=True) # 群聊信息可能不存在
chat_info_group_platform = TextField(null=True) # 群聊信息可能不存在
chat_info_group_id = TextField(null=True)
chat_info_group_name = TextField(null=True)
chat_info_create_time = DoubleField()
@@ -135,18 +143,20 @@ class Messages(BaseModel):
user_nickname = TextField()
user_cardname = TextField(null=True)
processed_plain_text = TextField(null=True) # 处理后的纯文本消息
detailed_plain_text = TextField(null=True) # 详细的纯文本消息
memorized_times = IntegerField(default=0) # 被记忆的次数
processed_plain_text = TextField(null=True) # 处理后的纯文本消息
detailed_plain_text = TextField(null=True) # 详细的纯文本消息
memorized_times = IntegerField(default=0) # 被记忆的次数
class Meta:
# database = db # 继承自 BaseModel
table_name = 'messages'
table_name = "messages"
class Images(BaseModel):
"""
用于存储图像信息的模型。
"""
hash = TextField(index=True) # 图像的哈希值
description = TextField(null=True) # 图像的描述
path = TextField(unique=True) # 图像文件的路径
@@ -155,12 +165,14 @@ class Images(BaseModel):
class Meta:
# database = db # 继承自 BaseModel
table_name = 'images'
table_name = "images"
class ImageDescriptions(BaseModel):
"""
用于存储图像描述信息的模型。
"""
type = TextField() # 类型,例如 "emoji"
hash = TextField(index=True) # 图像的哈希值
description = TextField() # 图像的描述
@@ -168,12 +180,14 @@ class ImageDescriptions(BaseModel):
class Meta:
# database = db # 继承自 BaseModel
table_name = 'image_descriptions'
table_name = "image_descriptions"
class OnlineTime(BaseModel):
"""
用于存储在线时长记录的模型。
"""
# timestamp: "$date": "2025-05-01T18:52:18.191Z" (存储为字符串)
timestamp = TextField()
duration = IntegerField() # 时长,单位分钟
@@ -182,12 +196,14 @@ class OnlineTime(BaseModel):
class Meta:
# database = db # 继承自 BaseModel
table_name = 'online_time'
table_name = "online_time"
class PersonInfo(BaseModel):
"""
用于存储个人信息数据的模型。
"""
person_id = TextField(unique=True, index=True) # 个人唯一ID
person_name = TextField() # 个人名称
name_reason = TextField(null=True) # 名称设定的原因
@@ -202,26 +218,28 @@ class PersonInfo(BaseModel):
class Meta:
# database = db # 继承自 BaseModel
table_name = 'person_info'
table_name = "person_info"
class Knowledges(BaseModel):
"""
Model for storing knowledge-base entries.
"""
content = TextField() # Text of the knowledge entry
embedding = TextField() # Embedding vector of the content, stored as a JSON string holding a list of floats
# Further metadata fields such as source, create_time, etc. can be added here
class Meta:
# database = db # Inherited from BaseModel
# NOTE(review): table_name is assigned twice with the same value (quote style
# differs) — presumably the before/after sides of a formatting diff; confirm.
table_name = 'knowledges'
table_name = "knowledges"
class ThinkingLog(BaseModel):
chat_id = TextField(index=True)
trigger_text = TextField(null=True)
response_text = TextField(null=True)
# Store complex dicts/lists as JSON strings
trigger_info_json = TextField(null=True)
response_info_json = TextField(null=True)
@@ -235,28 +253,32 @@ class ThinkingLog(BaseModel):
# Add a timestamp for the log entry itself
# Ensure you have: from peewee import DateTimeField
# And: import datetime
created_at = DateTimeField(default=datetime.datetime.now)
created_at = DateTimeField(default=datetime.datetime.now)
class Meta:
table_name = 'thinking_logs'
table_name = "thinking_logs"
def create_tables():
"""
Create every database table defined by the models.
"""
# The table creation runs inside the `with db:` block, which scopes the database
# connection to this call (see peewee's Database context-manager documentation).
with db:
# NOTE(review): db.create_tables(...) appears twice below with the same ten
# models (only layout and the trailing comma differ) — presumably the
# before/after sides of a formatting diff; confirm and keep one call.
db.create_tables([
ChatStreams,
LLMUsage,
Emoji,
Messages,
Images,
ImageDescriptions,
OnlineTime,
PersonInfo,
Knowledges,
ThinkingLog
])
db.create_tables(
[
ChatStreams,
LLMUsage,
Emoji,
Messages,
Images,
ImageDescriptions,
OnlineTime,
PersonInfo,
Knowledges,
ThinkingLog,
]
)
def initialize_database():
"""
@@ -272,9 +294,9 @@ def initialize_database():
OnlineTime,
PersonInfo,
Knowledges,
ThinkingLog
ThinkingLog,
]
needs_creation = False
try:
with db: # 管理 table_exists 检查的连接
@@ -298,5 +320,6 @@ def initialize_database():
else:
print("所有数据库表均已存在。")
# 模块加载时调用初始化函数
initialize_database()

View File

@@ -1,4 +1,4 @@
from src.common.database.database_model import Messages # 更改导入
from src.common.database.database_model import Messages # 更改导入
from src.common.logger import get_module_logger
import traceback
from typing import List, Any, Optional
@@ -42,9 +42,7 @@ def find_messages(
if hasattr(Messages, key):
conditions.append(getattr(Messages, key) == value)
else:
logger.warning(
f"过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。"
)
logger.warning(f"过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。")
if conditions:
# 使用 *conditions 将所有条件以 AND 连接
query = query.where(*conditions)
@@ -59,9 +57,7 @@ def find_messages(
query = query.order_by(Messages.time.desc()).limit(limit)
latest_results_peewee = list(query)
# 将结果按时间正序排列
peewee_results = sorted(
latest_results_peewee, key=lambda msg: msg.time
)
peewee_results = sorted(latest_results_peewee, key=lambda msg: msg.time)
else:
# limit 为 0 时,应用传入的 sort 参数
if sort:
@@ -74,13 +70,9 @@ def find_messages(
elif direction == -1: # DESC
peewee_sort_terms.append(field.desc())
else:
logger.warning(
f"字段 '{field_name}' 的排序方向 '{direction}' 无效。将跳过此排序条件。"
)
logger.warning(f"字段 '{field_name}' 的排序方向 '{direction}' 无效。将跳过此排序条件。")
else:
logger.warning(
f"排序字段 '{field_name}' 在 Messages 模型中未找到。将跳过此排序条件。"
)
logger.warning(f"排序字段 '{field_name}' 在 Messages 模型中未找到。将跳过此排序条件。")
if peewee_sort_terms:
query = query.order_by(*peewee_sort_terms)
peewee_results = list(query)
@@ -116,9 +108,7 @@ def count_messages(message_filter: dict[str, Any]) -> int:
if hasattr(Messages, key):
conditions.append(getattr(Messages, key) == value)
else:
logger.warning(
f"计数时,过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。"
)
logger.warning(f"计数时,过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。")
if conditions:
query = query.where(*conditions)

View File

@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from typing import List, Dict, Any
# from src.common.database.database import db # Peewee db 导入
from src.common.database.database_model import Messages # Peewee Messages 模型导入
from playhouse.shortcuts import model_to_dict # 用于将模型实例转换为字典
@@ -53,20 +54,23 @@ class PeeweeMessageStorage(MessageStorage):
"""Peewee消息存储实现"""
async def get_messages_after(self, chat_id: str, message_time: float) -> List[Dict[str, Any]]:
# Return the messages of `chat_id` whose `time` is strictly greater than
# `message_time`, ordered by ascending time, each converted to a dict.
# NOTE(review): declared async but the body contains no await — the query is
# executed synchronously by list(query); confirm this is acceptable for callers.
# NOTE(review): `query` is built twice below with the same filter and ordering —
# presumably the before/after sides of a formatting diff; confirm and keep one.
query = Messages.select().where(
(Messages.chat_id == chat_id) &
(Messages.time > message_time)
).order_by(Messages.time.asc())
query = (
Messages.select()
.where((Messages.chat_id == chat_id) & (Messages.time > message_time))
.order_by(Messages.time.asc())
)
# print(f"storage_check_message: {message_time}")
# Materialise the query into model instances before converting them.
messages_models = list(query)
# model_to_dict turns each Messages model instance into a plain dictionary.
return [model_to_dict(msg) for msg in messages_models]
async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]:
query = Messages.select().where(
(Messages.chat_id == chat_id) &
(Messages.time < time_point)
).order_by(Messages.time.desc()).limit(limit)
query = (
Messages.select()
.where((Messages.chat_id == chat_id) & (Messages.time < time_point))
.order_by(Messages.time.desc())
.limit(limit)
)
messages_models = list(query)
# 将消息按时间正序排列
@@ -74,10 +78,7 @@ class PeeweeMessageStorage(MessageStorage):
return [model_to_dict(msg) for msg in messages_models]
async def has_new_messages(self, chat_id: str, after_time: float) -> bool:
# Report whether at least one message of `chat_id` has a `time` strictly
# greater than `after_time`, using the query's exists() check.
# NOTE(review): two return statements follow with the same filter — presumably
# the before/after sides of a formatting diff; confirm and keep one.
return Messages.select().where(
(Messages.chat_id == chat_id) &
(Messages.time > after_time)
).exists()
return Messages.select().where((Messages.chat_id == chat_id) & (Messages.time > after_time)).exists()
# # 创建一个内存消息存储实现,用于测试

View File

@@ -89,7 +89,9 @@ class SearchKnowledgeTool(BaseTool):
logger.warning(f"Knowledge item ID {item.id} has empty embedding string.")
continue
item_embedding = json.loads(item_embedding_str)
if not isinstance(item_embedding, list) or not all(isinstance(x, (int, float)) for x in item_embedding):
if not isinstance(item_embedding, list) or not all(
isinstance(x, (int, float)) for x in item_embedding
):
logger.warning(f"Knowledge item ID {item.id} has invalid embedding format after JSON parsing.")
continue
except json.JSONDecodeError: