🤖 自动格式化代码 [skip ci]
This commit is contained in:
@@ -148,7 +148,8 @@ class MaiEmoji:
|
||||
# 准备数据库记录 for emoji collection
|
||||
emotion_str = ",".join(self.emotion) if self.emotion else ""
|
||||
|
||||
Emoji.create(hash=self.hash,
|
||||
Emoji.create(
|
||||
hash=self.hash,
|
||||
full_path=self.full_path,
|
||||
format=self.format,
|
||||
description=self.description,
|
||||
@@ -250,7 +251,9 @@ def _to_emoji_objects(data):
|
||||
for emoji_data in emoji_data_list: # emoji_data is an Emoji model instance
|
||||
full_path = emoji_data.full_path
|
||||
if not full_path:
|
||||
logger.warning(f"[加载错误] 数据库记录缺少 'full_path' 字段: ID {emoji_data.id if hasattr(emoji_data, 'id') else 'Unknown'}")
|
||||
logger.warning(
|
||||
f"[加载错误] 数据库记录缺少 'full_path' 字段: ID {emoji_data.id if hasattr(emoji_data, 'id') else 'Unknown'}"
|
||||
)
|
||||
load_errors += 1
|
||||
continue
|
||||
|
||||
@@ -265,7 +268,7 @@ def _to_emoji_objects(data):
|
||||
|
||||
emoji.description = emoji_data.description
|
||||
# Deserialize emotion string from DB to list
|
||||
emoji.emotion = emoji_data.emotion.split(',') if emoji_data.emotion else []
|
||||
emoji.emotion = emoji_data.emotion.split(",") if emoji_data.emotion else []
|
||||
emoji.usage_count = emoji_data.usage_count
|
||||
|
||||
db_last_used_time = emoji_data.last_used_time
|
||||
|
||||
@@ -91,7 +91,6 @@ class HeartFChatting:
|
||||
self.action_manager = ActionManager()
|
||||
self.action_planner = ActionPlanner(log_prefix=self.log_prefix, action_manager=self.action_manager)
|
||||
|
||||
|
||||
# --- 处理器列表 ---
|
||||
self.processors: List[BaseProcessor] = []
|
||||
self._register_default_processors()
|
||||
@@ -526,5 +525,3 @@ class HeartFChatting:
|
||||
if last_n is not None:
|
||||
history = history[-last_n:]
|
||||
return [cycle.to_dict() for cycle in history]
|
||||
|
||||
|
||||
|
||||
@@ -7,12 +7,14 @@ from src.chat.person_info.relationship_manager import relationship_manager
|
||||
from src.chat.utils.utils import get_embedding
|
||||
import time
|
||||
from typing import Union, Optional
|
||||
|
||||
# from common.database.database import db
|
||||
from src.chat.utils.utils import get_recent_group_speaker
|
||||
from src.manager.mood_manager import mood_manager
|
||||
from src.chat.memory_system.Hippocampus import HippocampusManager
|
||||
from src.chat.knowledge.knowledge_lib import qa_manager
|
||||
from src.chat.focus_chat.expressors.exprssion_learner import expression_learner
|
||||
|
||||
# import traceback
|
||||
import random
|
||||
import json
|
||||
@@ -623,7 +625,9 @@ class PromptBuilder:
|
||||
db_embedding = json.loads(db_embedding_str)
|
||||
|
||||
if len(db_embedding) != len(query_embedding):
|
||||
logger.warning(f"Embedding length mismatch for knowledge ID {knowledge_item.id if hasattr(knowledge_item, 'id') else 'N/A'}. Skipping.")
|
||||
logger.warning(
|
||||
f"Embedding length mismatch for knowledge ID {knowledge_item.id if hasattr(knowledge_item, 'id') else 'N/A'}. Skipping."
|
||||
)
|
||||
continue
|
||||
|
||||
# Calculate Cosine Similarity
|
||||
@@ -636,16 +640,14 @@ class PromptBuilder:
|
||||
similarity = dot_product / (query_embedding_magnitude * db_embedding_magnitude)
|
||||
|
||||
if similarity >= threshold:
|
||||
results_with_similarity.append({
|
||||
"content": knowledge_item.content,
|
||||
"similarity": similarity
|
||||
})
|
||||
results_with_similarity.append({"content": knowledge_item.content, "similarity": similarity})
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"Failed to parse embedding for knowledge ID {knowledge_item.id if hasattr(knowledge_item, 'id') else 'N/A'}")
|
||||
logger.error(
|
||||
f"Failed to parse embedding for knowledge ID {knowledge_item.id if hasattr(knowledge_item, 'id') else 'N/A'}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing knowledge item: {e}")
|
||||
|
||||
|
||||
# Sort by similarity in descending order
|
||||
results_with_similarity.sort(key=lambda x: x["similarity"], reverse=True)
|
||||
|
||||
|
||||
@@ -41,7 +41,6 @@ class ActionManager:
|
||||
# for action_name, action_info in self._using_actions.items():
|
||||
# logger.info(f"动作名称: {action_name}, 动作信息: {action_info}")
|
||||
|
||||
|
||||
def _load_registered_actions(self) -> None:
|
||||
"""
|
||||
加载所有通过装饰器注册的动作
|
||||
@@ -50,17 +49,17 @@ class ActionManager:
|
||||
# 从_ACTION_REGISTRY获取所有已注册动作
|
||||
for action_name, action_class in _ACTION_REGISTRY.items():
|
||||
# 获取动作相关信息
|
||||
action_description:str = getattr(action_class, "action_description", "")
|
||||
action_parameters:dict[str:str] = getattr(action_class, "action_parameters", {})
|
||||
action_require:list[str] = getattr(action_class, "action_require", [])
|
||||
is_default:bool = getattr(action_class, "default", False)
|
||||
action_description: str = getattr(action_class, "action_description", "")
|
||||
action_parameters: dict[str:str] = getattr(action_class, "action_parameters", {})
|
||||
action_require: list[str] = getattr(action_class, "action_require", [])
|
||||
is_default: bool = getattr(action_class, "default", False)
|
||||
|
||||
if action_name and action_description:
|
||||
# 创建动作信息字典
|
||||
action_info = {
|
||||
"description": action_description,
|
||||
"parameters": action_parameters,
|
||||
"require": action_require
|
||||
"require": action_require,
|
||||
}
|
||||
|
||||
# 注册2
|
||||
@@ -229,11 +228,7 @@ class ActionManager:
|
||||
if require is None:
|
||||
require = []
|
||||
|
||||
action_info = {
|
||||
"description": description,
|
||||
"parameters": parameters,
|
||||
"require": require
|
||||
}
|
||||
action_info = {"description": description, "parameters": parameters, "require": require}
|
||||
|
||||
self._registered_actions[action_name] = action_info
|
||||
return True
|
||||
|
||||
@@ -25,8 +25,8 @@ def register_action(cls):
|
||||
logger.error(f"动作类 {cls.__name__} 缺少必要的属性: action_name 或 action_description")
|
||||
return cls
|
||||
|
||||
action_name = getattr(cls, "action_name") #noqa
|
||||
action_description = getattr(cls, "action_description") #noqa
|
||||
action_name = getattr(cls, "action_name") # noqa
|
||||
action_description = getattr(cls, "action_description") # noqa
|
||||
is_default = getattr(cls, "default", False)
|
||||
|
||||
if not action_name or not action_description:
|
||||
@@ -60,14 +60,13 @@ class BaseAction(ABC):
|
||||
cycle_timers: 计时器字典
|
||||
thinking_id: 思考ID
|
||||
"""
|
||||
#每个动作必须实现
|
||||
self.action_name:str = "base_action"
|
||||
self.action_description:str = "基础动作"
|
||||
self.action_parameters:dict = {}
|
||||
self.action_require:list[str] = []
|
||||
|
||||
self.default:bool = False
|
||||
# 每个动作必须实现
|
||||
self.action_name: str = "base_action"
|
||||
self.action_description: str = "基础动作"
|
||||
self.action_parameters: dict = {}
|
||||
self.action_require: list[str] = []
|
||||
|
||||
self.default: bool = False
|
||||
|
||||
self.action_data = action_data
|
||||
self.reasoning = reasoning
|
||||
|
||||
@@ -29,7 +29,7 @@ class NoReplyAction(BaseAction):
|
||||
action_require = [
|
||||
"话题无关/无聊/不感兴趣/不懂",
|
||||
"最后一条消息是你自己发的且无人回应你",
|
||||
"你发送了太多消息,且无人回复"
|
||||
"你发送了太多消息,且无人回复",
|
||||
]
|
||||
default = True
|
||||
|
||||
@@ -46,7 +46,7 @@ class NoReplyAction(BaseAction):
|
||||
total_no_reply_count: int = 0,
|
||||
total_waiting_time: float = 0.0,
|
||||
shutting_down: bool = False,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
"""初始化不回复动作处理器
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from src.common.logger_manager import get_logger
|
||||
|
||||
# from src.chat.utils.timer_calculator import Timer
|
||||
from src.chat.focus_chat.planners.actions.base_action import BaseAction, register_action
|
||||
from typing import Tuple, List
|
||||
@@ -22,14 +23,14 @@ class ReplyAction(BaseAction):
|
||||
处理构建和发送消息回复的动作。
|
||||
"""
|
||||
|
||||
action_name:str = "reply"
|
||||
action_description:str = "表达想法,可以只包含文本、表情或两者都有"
|
||||
action_parameters:dict[str:str] = {
|
||||
action_name: str = "reply"
|
||||
action_description: str = "表达想法,可以只包含文本、表情或两者都有"
|
||||
action_parameters: dict[str:str] = {
|
||||
"text": "你想要表达的内容(可选)",
|
||||
"emojis": "描述当前使用表情包的场景(可选)",
|
||||
"target": "你想要回复的原始文本内容(非必须,仅文本,不包含发送者)(可选)",
|
||||
}
|
||||
action_require:list[str] = [
|
||||
action_require: list[str] = [
|
||||
"有实质性内容需要表达",
|
||||
"有人提到你,但你还没有回应他",
|
||||
"在合适的时候添加表情(不要总是添加)",
|
||||
@@ -38,7 +39,7 @@ class ReplyAction(BaseAction):
|
||||
"一次只回复一个人,一次只回复一个话题,突出重点",
|
||||
"如果是自己发的消息想继续,需自然衔接",
|
||||
"避免重复或评价自己的发言,不要和自己聊天",
|
||||
"注意:回复尽量简短一些。可以参考贴吧,知乎和微博的回复风格,回复不要浮夸,不要用夸张修辞,平淡一些。"
|
||||
"注意:回复尽量简短一些。可以参考贴吧,知乎和微博的回复风格,回复不要浮夸,不要用夸张修辞,平淡一些。",
|
||||
]
|
||||
default = True
|
||||
|
||||
@@ -54,7 +55,7 @@ class ReplyAction(BaseAction):
|
||||
chat_stream: ChatStream,
|
||||
current_cycle: CycleDetail,
|
||||
log_prefix: str,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
"""初始化回复动作处理器
|
||||
|
||||
@@ -89,7 +90,7 @@ class ReplyAction(BaseAction):
|
||||
reasoning=self.reasoning,
|
||||
reply_data=self.action_data,
|
||||
cycle_timers=self.cycle_timers,
|
||||
thinking_id=self.thinking_id
|
||||
thinking_id=self.thinking_id,
|
||||
)
|
||||
|
||||
async def _handle_reply(
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import List, Dict, Any, Optional
|
||||
from rich.traceback import install
|
||||
from src.chat.models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
|
||||
# from src.chat.focus_chat.heartflow_prompt_builder import prompt_builder
|
||||
from src.chat.focus_chat.info.info_base import InfoBase
|
||||
from src.chat.focus_chat.info.obs_info import ObsInfo
|
||||
@@ -15,10 +16,12 @@ from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.individuality.individuality import Individuality
|
||||
from src.chat.focus_chat.planners.action_factory import ActionManager
|
||||
from src.chat.focus_chat.planners.action_factory import ActionInfo
|
||||
|
||||
logger = get_logger("planner")
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
|
||||
def init_prompt():
|
||||
Prompt(
|
||||
"""你的名字是{bot_name},{prompt_personality},{chat_context_description}。需要基于以下信息决定如何参与对话:
|
||||
@@ -44,7 +47,8 @@ def init_prompt():
|
||||
}}
|
||||
|
||||
请输出你的决策 JSON:""",
|
||||
"planner_prompt",)
|
||||
"planner_prompt",
|
||||
)
|
||||
|
||||
Prompt(
|
||||
"""
|
||||
@@ -197,7 +201,6 @@ class ActionPlanner:
|
||||
# 返回结果字典
|
||||
return plan_result
|
||||
|
||||
|
||||
async def build_planner_prompt(
|
||||
self,
|
||||
is_group_chat: bool, # Now passed as argument
|
||||
@@ -218,7 +221,6 @@ class ActionPlanner:
|
||||
)
|
||||
chat_context_description = f"你正在和 {chat_target_name} 私聊"
|
||||
|
||||
|
||||
chat_content_block = ""
|
||||
if observed_messages_str:
|
||||
chat_content_block = f"聊天记录:\n{observed_messages_str}"
|
||||
@@ -234,7 +236,6 @@ class ActionPlanner:
|
||||
individuality = Individuality.get_instance()
|
||||
personality_block = individuality.get_prompt(x_person=2, level=2)
|
||||
|
||||
|
||||
action_options_block = ""
|
||||
for using_actions_name, using_actions_info in current_available_actions.items():
|
||||
# print(using_actions_name)
|
||||
@@ -262,9 +263,6 @@ class ActionPlanner:
|
||||
|
||||
action_options_block += using_action_prompt
|
||||
|
||||
|
||||
|
||||
|
||||
planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt")
|
||||
prompt = planner_prompt_template.format(
|
||||
bot_name=global_config.BOT_NICKNAME,
|
||||
|
||||
@@ -261,7 +261,9 @@ class PersonInfoManager:
|
||||
qv_name_prompt += f"你之前叫他{old_name},是因为{old_reason},"
|
||||
|
||||
qv_name_prompt += f"\n其他取名的要求是:{request},不要太浮夸"
|
||||
qv_name_prompt += "\n请根据以上用户信息,想想你叫他什么比较好,不要太浮夸,请最好使用用户的qq昵称,可以稍作修改"
|
||||
qv_name_prompt += (
|
||||
"\n请根据以上用户信息,想想你叫他什么比较好,不要太浮夸,请最好使用用户的qq昵称,可以稍作修改"
|
||||
)
|
||||
|
||||
if existing_names_str:
|
||||
qv_name_prompt += f"\n请注意,以下名称已被你尝试过或已知存在,请避免:{existing_names_str}。\n"
|
||||
@@ -289,6 +291,7 @@ class PersonInfoManager:
|
||||
if generated_nickname in current_name_set:
|
||||
is_duplicate = True
|
||||
else:
|
||||
|
||||
def _db_check_name_exists_sync(name_to_check):
|
||||
return PersonInfo.select().where(PersonInfo.person_name == name_to_check).exists()
|
||||
|
||||
@@ -415,7 +418,9 @@ class PersonInfoManager:
|
||||
@staticmethod
|
||||
async def del_all_undefined_field():
|
||||
"""删除所有项里的未定义字段 - 对于Peewee (SQL),此操作通常不适用,因为模式是固定的。"""
|
||||
logger.info("del_all_undefined_field: 对于使用Peewee的SQL数据库,此操作通常不适用或不需要,因为表结构是预定义的。")
|
||||
logger.info(
|
||||
"del_all_undefined_field: 对于使用Peewee的SQL数据库,此操作通常不适用或不需要,因为表结构是预定义的。"
|
||||
)
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
@@ -512,7 +517,9 @@ class PersonInfoManager:
|
||||
if trimmed_interval:
|
||||
msg_interval_val = int(round(np.percentile(trimmed_interval, 37)))
|
||||
await self.update_one_field(person_id, "msg_interval", msg_interval_val)
|
||||
logger.trace(f"用户{person_id}的msg_interval通过头尾截断和37分位数更新为{msg_interval_val}")
|
||||
logger.trace(
|
||||
f"用户{person_id}的msg_interval通过头尾截断和37分位数更新为{msg_interval_val}"
|
||||
)
|
||||
else:
|
||||
logger.trace(f"用户{person_id}截断后数据为空,无法计算msg_interval")
|
||||
else:
|
||||
@@ -577,13 +584,17 @@ class PersonInfoManager:
|
||||
break
|
||||
|
||||
if not found_person_id:
|
||||
|
||||
def _db_find_by_name_sync(p_name_to_find: str):
|
||||
return PersonInfo.get_or_none(PersonInfo.person_name == p_name_to_find)
|
||||
|
||||
record = await asyncio.to_thread(_db_find_by_name_sync, person_name)
|
||||
if record:
|
||||
found_person_id = record.person_id
|
||||
if found_person_id not in self.person_name_list or self.person_name_list[found_person_id] != person_name:
|
||||
if (
|
||||
found_person_id not in self.person_name_list
|
||||
or self.person_name_list[found_person_id] != person_name
|
||||
):
|
||||
self.person_name_list[found_person_id] = person_name
|
||||
else:
|
||||
logger.debug(f"数据库中也未找到名为 '{person_name}' 的用户 (Peewee)")
|
||||
@@ -600,7 +611,9 @@ class PersonInfoManager:
|
||||
"person_name",
|
||||
"name_reason",
|
||||
]
|
||||
valid_fields_to_get = [f for f in required_fields if f in PersonInfo._meta.fields or f in person_info_default]
|
||||
valid_fields_to_get = [
|
||||
f for f in required_fields if f in PersonInfo._meta.fields or f in person_info_default
|
||||
]
|
||||
|
||||
person_data = await self.get_values(found_person_id, valid_fields_to_get)
|
||||
|
||||
|
||||
@@ -454,7 +454,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
|
||||
def reply_replacer(match):
|
||||
# aaa = match.group(1)
|
||||
bbb = match.group(2)
|
||||
anon_reply = get_anon_name(platform, bbb) #noqa
|
||||
anon_reply = get_anon_name(platform, bbb) # noqa
|
||||
return f"回复 {anon_reply}"
|
||||
|
||||
content = re.sub(reply_pattern, reply_replacer, content, count=1)
|
||||
@@ -465,7 +465,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
|
||||
def at_replacer(match):
|
||||
# aaa = match.group(1)
|
||||
bbb = match.group(2)
|
||||
anon_at = get_anon_name(platform, bbb) #noqa
|
||||
anon_at = get_anon_name(platform, bbb) # noqa
|
||||
return f"@{anon_at}"
|
||||
|
||||
content = re.sub(at_pattern, at_replacer, content)
|
||||
|
||||
@@ -103,11 +103,11 @@ class InfoCatcher:
|
||||
|
||||
print(f"查询参数: time_start={time_start}, time_end={time_end}, chat_id={chat_id}")
|
||||
|
||||
messages_between_query = Messages.select().where(
|
||||
(Messages.chat_id == chat_id) &
|
||||
(Messages.time > time_start) &
|
||||
(Messages.time < time_end)
|
||||
).order_by(Messages.time.desc())
|
||||
messages_between_query = (
|
||||
Messages.select()
|
||||
.where((Messages.chat_id == chat_id) & (Messages.time > time_start) & (Messages.time < time_end))
|
||||
.order_by(Messages.time.desc())
|
||||
)
|
||||
|
||||
result = list(messages_between_query)
|
||||
print(f"查询结果数量: {len(result)}")
|
||||
@@ -124,10 +124,12 @@ class InfoCatcher:
|
||||
message_id_val = message.message_info.message_id
|
||||
chat_id_val = message.chat_stream.stream_id
|
||||
|
||||
messages_before_query = Messages.select().where(
|
||||
(Messages.chat_id == chat_id_val) &
|
||||
(Messages.message_id < message_id_val)
|
||||
).order_by(Messages.time.desc()).limit(self.context_length * 3)
|
||||
messages_before_query = (
|
||||
Messages.select()
|
||||
.where((Messages.chat_id == chat_id_val) & (Messages.message_id < message_id_val))
|
||||
.order_by(Messages.time.desc())
|
||||
.limit(self.context_length * 3)
|
||||
)
|
||||
|
||||
return list(messages_before_query)
|
||||
|
||||
@@ -164,7 +166,7 @@ class InfoCatcher:
|
||||
"processed_plain_text": msg_obj.processed_plain_text,
|
||||
}
|
||||
|
||||
if hasattr(msg_obj, 'message_info') and hasattr(msg_obj.message_info, 'user_info'):
|
||||
if hasattr(msg_obj, "message_info") and hasattr(msg_obj.message_info, "user_info"):
|
||||
return {
|
||||
"time": msg_obj.message_info.time,
|
||||
"user_id": msg_obj.message_info.user_info.user_id,
|
||||
@@ -198,7 +200,7 @@ class InfoCatcher:
|
||||
chat_history_in_thinking_json=json.dumps(chat_history_in_thinking_list),
|
||||
chat_history_after_response_json=json.dumps(chat_history_after_response_list),
|
||||
heartflow_data_json=json.dumps(self.heartflow_data),
|
||||
reasoning_data_json=json.dumps(self.reasoning_data)
|
||||
reasoning_data_json=json.dumps(self.reasoning_data),
|
||||
)
|
||||
log_entry.save()
|
||||
|
||||
|
||||
@@ -67,9 +67,12 @@ class OnlineTimeRecordTask(AsyncTask):
|
||||
if not self.record_id: # Check again if record_id was reset or initially None
|
||||
# 如果没有记录,检查一分钟以内是否已有记录
|
||||
# Look for a record whose end_timestamp is recent enough to be considered ongoing
|
||||
recent_record = OnlineTime.select().where(
|
||||
OnlineTime.end_timestamp >= (current_time - timedelta(minutes=1))
|
||||
).order_by(OnlineTime.end_timestamp.desc()).first()
|
||||
recent_record = (
|
||||
OnlineTime.select()
|
||||
.where(OnlineTime.end_timestamp >= (current_time - timedelta(minutes=1)))
|
||||
.order_by(OnlineTime.end_timestamp.desc())
|
||||
.first()
|
||||
)
|
||||
|
||||
if recent_record:
|
||||
# 如果有记录,则更新结束时间
|
||||
@@ -87,7 +90,6 @@ class OnlineTimeRecordTask(AsyncTask):
|
||||
logger.error(f"在线时间记录失败,错误信息:{e}")
|
||||
|
||||
|
||||
|
||||
def _format_online_time(online_seconds: int) -> str:
|
||||
"""
|
||||
格式化在线时间
|
||||
@@ -343,7 +345,9 @@ class StatisticOutputTask(AsyncTask):
|
||||
chat_name = message.user_nickname # SENDER's nickname
|
||||
else:
|
||||
# If neither group_id nor sender_id is available for chat identification
|
||||
logger.warning(f"Message (PK: {message.id if hasattr(message, 'id') else 'N/A'}) lacks group_id and user_id for chat stats.")
|
||||
logger.warning(
|
||||
f"Message (PK: {message.id if hasattr(message, 'id') else 'N/A'}) lacks group_id and user_id for chat stats."
|
||||
)
|
||||
continue
|
||||
|
||||
if not chat_id: # Should not happen if above logic is correct
|
||||
|
||||
@@ -61,8 +61,7 @@ class ImageManager:
|
||||
"""
|
||||
try:
|
||||
record = ImageDescriptions.get_or_none(
|
||||
(ImageDescriptions.hash == image_hash) &
|
||||
(ImageDescriptions.type == description_type)
|
||||
(ImageDescriptions.hash == image_hash) & (ImageDescriptions.type == description_type)
|
||||
)
|
||||
return record.description if record else None
|
||||
except Exception as e:
|
||||
@@ -80,14 +79,9 @@ class ImageManager:
|
||||
"""
|
||||
try:
|
||||
current_timestamp = time.time()
|
||||
defaults = {
|
||||
'description': description,
|
||||
'timestamp': current_timestamp
|
||||
}
|
||||
defaults = {"description": description, "timestamp": current_timestamp}
|
||||
desc_obj, created = ImageDescriptions.get_or_create(
|
||||
hash=image_hash,
|
||||
type=description_type,
|
||||
defaults=defaults
|
||||
hash=image_hash, type=description_type, defaults=defaults
|
||||
)
|
||||
if not created: # 如果记录已存在,则更新
|
||||
desc_obj.description = description
|
||||
|
||||
@@ -14,6 +14,7 @@ import datetime
|
||||
# db = MySQLDatabase('your_db_name', user='your_user', password='your_password',
|
||||
# host='localhost', port=3306)
|
||||
|
||||
|
||||
# 定义一个基础模型是一个好习惯,所有其他模型都应继承自它。
|
||||
# 这允许您在一个地方为所有模型指定数据库。
|
||||
class BaseModel(Model):
|
||||
@@ -22,10 +23,12 @@ class BaseModel(Model):
|
||||
database = db # 例如: database = my_actual_db_instance
|
||||
pass # 在用户定义数据库实例之前,此处为占位符
|
||||
|
||||
|
||||
class ChatStreams(BaseModel):
|
||||
"""
|
||||
用于存储流式记录数据的模型,类似于提供的 MongoDB 结构。
|
||||
"""
|
||||
|
||||
# stream_id: "a544edeb1a9b73e3e1d77dff36e41264"
|
||||
# 假设 stream_id 是唯一的,并为其创建索引以提高查询性能。
|
||||
stream_id = TextField(unique=True, index=True)
|
||||
@@ -64,12 +67,14 @@ class ChatStreams(BaseModel):
|
||||
# 如果不使用带有数据库实例的 BaseModel,或者想覆盖它,
|
||||
# 请取消注释并在下面设置数据库实例:
|
||||
# database = db
|
||||
table_name = 'chat_streams' # 可选:明确指定数据库中的表名
|
||||
table_name = "chat_streams" # 可选:明确指定数据库中的表名
|
||||
|
||||
|
||||
class LLMUsage(BaseModel):
|
||||
"""
|
||||
用于存储 API 使用日志数据的模型。
|
||||
"""
|
||||
|
||||
model_name = TextField(index=True) # 添加索引
|
||||
user_id = TextField(index=True) # 添加索引
|
||||
request_type = TextField(index=True) # 添加索引
|
||||
@@ -84,7 +89,8 @@ class LLMUsage(BaseModel):
|
||||
class Meta:
|
||||
# 如果 BaseModel.Meta.database 已设置,则此模型将继承该数据库配置。
|
||||
# database = db
|
||||
table_name = 'llm_usage'
|
||||
table_name = "llm_usage"
|
||||
|
||||
|
||||
class Emoji(BaseModel):
|
||||
"""表情包"""
|
||||
@@ -105,12 +111,14 @@ class Emoji(BaseModel):
|
||||
|
||||
class Meta:
|
||||
# database = db # 继承自 BaseModel
|
||||
table_name = 'emoji'
|
||||
table_name = "emoji"
|
||||
|
||||
|
||||
class Messages(BaseModel):
|
||||
"""
|
||||
用于存储消息数据的模型。
|
||||
"""
|
||||
|
||||
message_id = IntegerField(index=True) # 消息 ID
|
||||
time = DoubleField() # 消息时间戳
|
||||
|
||||
@@ -141,12 +149,14 @@ class Messages(BaseModel):
|
||||
|
||||
class Meta:
|
||||
# database = db # 继承自 BaseModel
|
||||
table_name = 'messages'
|
||||
table_name = "messages"
|
||||
|
||||
|
||||
class Images(BaseModel):
|
||||
"""
|
||||
用于存储图像信息的模型。
|
||||
"""
|
||||
|
||||
hash = TextField(index=True) # 图像的哈希值
|
||||
description = TextField(null=True) # 图像的描述
|
||||
path = TextField(unique=True) # 图像文件的路径
|
||||
@@ -155,12 +165,14 @@ class Images(BaseModel):
|
||||
|
||||
class Meta:
|
||||
# database = db # 继承自 BaseModel
|
||||
table_name = 'images'
|
||||
table_name = "images"
|
||||
|
||||
|
||||
class ImageDescriptions(BaseModel):
|
||||
"""
|
||||
用于存储图像描述信息的模型。
|
||||
"""
|
||||
|
||||
type = TextField() # 类型,例如 "emoji"
|
||||
hash = TextField(index=True) # 图像的哈希值
|
||||
description = TextField() # 图像的描述
|
||||
@@ -168,12 +180,14 @@ class ImageDescriptions(BaseModel):
|
||||
|
||||
class Meta:
|
||||
# database = db # 继承自 BaseModel
|
||||
table_name = 'image_descriptions'
|
||||
table_name = "image_descriptions"
|
||||
|
||||
|
||||
class OnlineTime(BaseModel):
|
||||
"""
|
||||
用于存储在线时长记录的模型。
|
||||
"""
|
||||
|
||||
# timestamp: "$date": "2025-05-01T18:52:18.191Z" (存储为字符串)
|
||||
timestamp = TextField()
|
||||
duration = IntegerField() # 时长,单位分钟
|
||||
@@ -182,12 +196,14 @@ class OnlineTime(BaseModel):
|
||||
|
||||
class Meta:
|
||||
# database = db # 继承自 BaseModel
|
||||
table_name = 'online_time'
|
||||
table_name = "online_time"
|
||||
|
||||
|
||||
class PersonInfo(BaseModel):
|
||||
"""
|
||||
用于存储个人信息数据的模型。
|
||||
"""
|
||||
|
||||
person_id = TextField(unique=True, index=True) # 个人唯一ID
|
||||
person_name = TextField() # 个人名称
|
||||
name_reason = TextField(null=True) # 名称设定的原因
|
||||
@@ -202,19 +218,21 @@ class PersonInfo(BaseModel):
|
||||
|
||||
class Meta:
|
||||
# database = db # 继承自 BaseModel
|
||||
table_name = 'person_info'
|
||||
table_name = "person_info"
|
||||
|
||||
|
||||
class Knowledges(BaseModel):
|
||||
"""
|
||||
用于存储知识库条目的模型。
|
||||
"""
|
||||
|
||||
content = TextField() # 知识内容的文本
|
||||
embedding = TextField() # 知识内容的嵌入向量,存储为 JSON 字符串的浮点数列表
|
||||
# 可以添加其他元数据字段,如 source, create_time 等
|
||||
|
||||
class Meta:
|
||||
# database = db # 继承自 BaseModel
|
||||
table_name = 'knowledges'
|
||||
table_name = "knowledges"
|
||||
|
||||
|
||||
class ThinkingLog(BaseModel):
|
||||
@@ -238,14 +256,16 @@ class ThinkingLog(BaseModel):
|
||||
created_at = DateTimeField(default=datetime.datetime.now)
|
||||
|
||||
class Meta:
|
||||
table_name = 'thinking_logs'
|
||||
table_name = "thinking_logs"
|
||||
|
||||
|
||||
def create_tables():
|
||||
"""
|
||||
创建所有在模型中定义的数据库表。
|
||||
"""
|
||||
with db:
|
||||
db.create_tables([
|
||||
db.create_tables(
|
||||
[
|
||||
ChatStreams,
|
||||
LLMUsage,
|
||||
Emoji,
|
||||
@@ -255,8 +275,10 @@ def create_tables():
|
||||
OnlineTime,
|
||||
PersonInfo,
|
||||
Knowledges,
|
||||
ThinkingLog
|
||||
])
|
||||
ThinkingLog,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def initialize_database():
|
||||
"""
|
||||
@@ -272,7 +294,7 @@ def initialize_database():
|
||||
OnlineTime,
|
||||
PersonInfo,
|
||||
Knowledges,
|
||||
ThinkingLog
|
||||
ThinkingLog,
|
||||
]
|
||||
|
||||
needs_creation = False
|
||||
@@ -298,5 +320,6 @@ def initialize_database():
|
||||
else:
|
||||
print("所有数据库表均已存在。")
|
||||
|
||||
|
||||
# 模块加载时调用初始化函数
|
||||
initialize_database()
|
||||
|
||||
@@ -42,9 +42,7 @@ def find_messages(
|
||||
if hasattr(Messages, key):
|
||||
conditions.append(getattr(Messages, key) == value)
|
||||
else:
|
||||
logger.warning(
|
||||
f"过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。"
|
||||
)
|
||||
logger.warning(f"过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。")
|
||||
if conditions:
|
||||
# 使用 *conditions 将所有条件以 AND 连接
|
||||
query = query.where(*conditions)
|
||||
@@ -59,9 +57,7 @@ def find_messages(
|
||||
query = query.order_by(Messages.time.desc()).limit(limit)
|
||||
latest_results_peewee = list(query)
|
||||
# 将结果按时间正序排列
|
||||
peewee_results = sorted(
|
||||
latest_results_peewee, key=lambda msg: msg.time
|
||||
)
|
||||
peewee_results = sorted(latest_results_peewee, key=lambda msg: msg.time)
|
||||
else:
|
||||
# limit 为 0 时,应用传入的 sort 参数
|
||||
if sort:
|
||||
@@ -74,13 +70,9 @@ def find_messages(
|
||||
elif direction == -1: # DESC
|
||||
peewee_sort_terms.append(field.desc())
|
||||
else:
|
||||
logger.warning(
|
||||
f"字段 '{field_name}' 的排序方向 '{direction}' 无效。将跳过此排序条件。"
|
||||
)
|
||||
logger.warning(f"字段 '{field_name}' 的排序方向 '{direction}' 无效。将跳过此排序条件。")
|
||||
else:
|
||||
logger.warning(
|
||||
f"排序字段 '{field_name}' 在 Messages 模型中未找到。将跳过此排序条件。"
|
||||
)
|
||||
logger.warning(f"排序字段 '{field_name}' 在 Messages 模型中未找到。将跳过此排序条件。")
|
||||
if peewee_sort_terms:
|
||||
query = query.order_by(*peewee_sort_terms)
|
||||
peewee_results = list(query)
|
||||
@@ -116,9 +108,7 @@ def count_messages(message_filter: dict[str, Any]) -> int:
|
||||
if hasattr(Messages, key):
|
||||
conditions.append(getattr(Messages, key) == value)
|
||||
else:
|
||||
logger.warning(
|
||||
f"计数时,过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。"
|
||||
)
|
||||
logger.warning(f"计数时,过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。")
|
||||
if conditions:
|
||||
query = query.where(*conditions)
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict, Any
|
||||
|
||||
# from src.common.database.database import db # Peewee db 导入
|
||||
from src.common.database.database_model import Messages # Peewee Messages 模型导入
|
||||
from playhouse.shortcuts import model_to_dict # 用于将模型实例转换为字典
|
||||
@@ -53,20 +54,23 @@ class PeeweeMessageStorage(MessageStorage):
|
||||
"""Peewee消息存储实现"""
|
||||
|
||||
async def get_messages_after(self, chat_id: str, message_time: float) -> List[Dict[str, Any]]:
|
||||
query = Messages.select().where(
|
||||
(Messages.chat_id == chat_id) &
|
||||
(Messages.time > message_time)
|
||||
).order_by(Messages.time.asc())
|
||||
query = (
|
||||
Messages.select()
|
||||
.where((Messages.chat_id == chat_id) & (Messages.time > message_time))
|
||||
.order_by(Messages.time.asc())
|
||||
)
|
||||
|
||||
# print(f"storage_check_message: {message_time}")
|
||||
messages_models = list(query)
|
||||
return [model_to_dict(msg) for msg in messages_models]
|
||||
|
||||
async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]:
|
||||
query = Messages.select().where(
|
||||
(Messages.chat_id == chat_id) &
|
||||
(Messages.time < time_point)
|
||||
).order_by(Messages.time.desc()).limit(limit)
|
||||
query = (
|
||||
Messages.select()
|
||||
.where((Messages.chat_id == chat_id) & (Messages.time < time_point))
|
||||
.order_by(Messages.time.desc())
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
messages_models = list(query)
|
||||
# 将消息按时间正序排列
|
||||
@@ -74,10 +78,7 @@ class PeeweeMessageStorage(MessageStorage):
|
||||
return [model_to_dict(msg) for msg in messages_models]
|
||||
|
||||
async def has_new_messages(self, chat_id: str, after_time: float) -> bool:
|
||||
return Messages.select().where(
|
||||
(Messages.chat_id == chat_id) &
|
||||
(Messages.time > after_time)
|
||||
).exists()
|
||||
return Messages.select().where((Messages.chat_id == chat_id) & (Messages.time > after_time)).exists()
|
||||
|
||||
|
||||
# # 创建一个内存消息存储实现,用于测试
|
||||
|
||||
@@ -89,7 +89,9 @@ class SearchKnowledgeTool(BaseTool):
|
||||
logger.warning(f"Knowledge item ID {item.id} has empty embedding string.")
|
||||
continue
|
||||
item_embedding = json.loads(item_embedding_str)
|
||||
if not isinstance(item_embedding, list) or not all(isinstance(x, (int, float)) for x in item_embedding):
|
||||
if not isinstance(item_embedding, list) or not all(
|
||||
isinstance(x, (int, float)) for x in item_embedding
|
||||
):
|
||||
logger.warning(f"Knowledge item ID {item.id} has invalid embedding format after JSON parsing.")
|
||||
continue
|
||||
except json.JSONDecodeError:
|
||||
|
||||
Reference in New Issue
Block a user