🤖 自动格式化代码 [skip ci]

This commit is contained in:
github-actions[bot]
2025-05-14 15:11:33 +00:00
parent 17d19e7cac
commit fb6094d269
17 changed files with 278 additions and 254 deletions

View File

@@ -14,18 +14,21 @@ import datetime
# db = MySQLDatabase('your_db_name', user='your_user', password='your_password',
# host='localhost', port=3306)
# 定义一个基础模型是一个好习惯,所有其他模型都应继承自它。
# 这允许您在一个地方为所有模型指定数据库。
class BaseModel(Model):
class Meta:
# 将下面的 'db' 替换为您实际的数据库实例变量名。
database = db # 例如: database = my_actual_db_instance
pass # 在用户定义数据库实例之前,此处为占位符
pass # 在用户定义数据库实例之前,此处为占位符
class ChatStreams(BaseModel):
"""
用于存储流式记录数据的模型,类似于提供的 MongoDB 结构。
"""
# stream_id: "a544edeb1a9b73e3e1d77dff36e41264"
# 假设 stream_id 是唯一的,并为其创建索引以提高查询性能。
stream_id = TextField(unique=True, index=True)
@@ -63,28 +66,31 @@ class ChatStreams(BaseModel):
# 如果 BaseModel.Meta.database 已设置,则此模型将继承该数据库配置。
# 如果不使用带有数据库实例的 BaseModel或者想覆盖它
# 请取消注释并在下面设置数据库实例:
# database = db
table_name = 'chat_streams' # 可选:明确指定数据库中的表名
# database = db
table_name = "chat_streams" # 可选:明确指定数据库中的表名
class LLMUsage(BaseModel):
"""
用于存储 API 使用日志数据的模型。
"""
model_name = TextField(index=True) # 添加索引
user_id = TextField(index=True) # 添加索引
request_type = TextField(index=True) # 添加索引
model_name = TextField(index=True) # 添加索引
user_id = TextField(index=True) # 添加索引
request_type = TextField(index=True) # 添加索引
endpoint = TextField()
prompt_tokens = IntegerField()
completion_tokens = IntegerField()
total_tokens = IntegerField()
cost = DoubleField()
status = TextField()
timestamp = DateTimeField(index=True) # 更改为 DateTimeField 并添加索引
timestamp = DateTimeField(index=True) # 更改为 DateTimeField 并添加索引
class Meta:
# 如果 BaseModel.Meta.database 已设置,则此模型将继承该数据库配置。
# database = db
table_name = 'llm_usage'
# database = db
table_name = "llm_usage"
class Emoji(BaseModel):
"""表情包"""
@@ -105,16 +111,18 @@ class Emoji(BaseModel):
class Meta:
# database = db # 继承自 BaseModel
table_name = 'emoji'
table_name = "emoji"
class Messages(BaseModel):
"""
用于存储消息数据的模型。
"""
message_id = IntegerField(index=True) # 消息 ID
time = DoubleField() # 消息时间戳
chat_id = TextField(index=True) # 对应的 ChatStreams stream_id
message_id = IntegerField(index=True) # 消息 ID
time = DoubleField() # 消息时间戳
chat_id = TextField(index=True) # 对应的 ChatStreams stream_id
# 从 chat_info 扁平化而来的字段
chat_info_stream_id = TextField()
@@ -123,7 +131,7 @@ class Messages(BaseModel):
chat_info_user_id = TextField()
chat_info_user_nickname = TextField()
chat_info_user_cardname = TextField(null=True)
chat_info_group_platform = TextField(null=True) # 群聊信息可能不存在
chat_info_group_platform = TextField(null=True) # 群聊信息可能不存在
chat_info_group_id = TextField(null=True)
chat_info_group_name = TextField(null=True)
chat_info_create_time = DoubleField()
@@ -135,18 +143,20 @@ class Messages(BaseModel):
user_nickname = TextField()
user_cardname = TextField(null=True)
processed_plain_text = TextField(null=True) # 处理后的纯文本消息
detailed_plain_text = TextField(null=True) # 详细的纯文本消息
memorized_times = IntegerField(default=0) # 被记忆的次数
processed_plain_text = TextField(null=True) # 处理后的纯文本消息
detailed_plain_text = TextField(null=True) # 详细的纯文本消息
memorized_times = IntegerField(default=0) # 被记忆的次数
class Meta:
# database = db # 继承自 BaseModel
table_name = 'messages'
table_name = "messages"
class Images(BaseModel):
"""
用于存储图像信息的模型。
"""
hash = TextField(index=True) # 图像的哈希值
description = TextField(null=True) # 图像的描述
path = TextField(unique=True) # 图像文件的路径
@@ -155,12 +165,14 @@ class Images(BaseModel):
class Meta:
# database = db # 继承自 BaseModel
table_name = 'images'
table_name = "images"
class ImageDescriptions(BaseModel):
"""
用于存储图像描述信息的模型。
"""
type = TextField() # 类型,例如 "emoji"
hash = TextField(index=True) # 图像的哈希值
description = TextField() # 图像的描述
@@ -168,12 +180,14 @@ class ImageDescriptions(BaseModel):
class Meta:
# database = db # 继承自 BaseModel
table_name = 'image_descriptions'
table_name = "image_descriptions"
class OnlineTime(BaseModel):
"""
用于存储在线时长记录的模型。
"""
# timestamp: "$date": "2025-05-01T18:52:18.191Z" (存储为字符串)
timestamp = TextField()
duration = IntegerField() # 时长,单位分钟
@@ -182,12 +196,14 @@ class OnlineTime(BaseModel):
class Meta:
# database = db # 继承自 BaseModel
table_name = 'online_time'
table_name = "online_time"
class PersonInfo(BaseModel):
"""
用于存储个人信息数据的模型。
"""
person_id = TextField(unique=True, index=True) # 个人唯一ID
person_name = TextField() # 个人名称
name_reason = TextField(null=True) # 名称设定的原因
@@ -202,26 +218,28 @@ class PersonInfo(BaseModel):
class Meta:
# database = db # 继承自 BaseModel
table_name = 'person_info'
table_name = "person_info"
class Knowledges(BaseModel):
"""
用于存储知识库条目的模型。
"""
content = TextField() # 知识内容的文本
embedding = TextField() # 知识内容的嵌入向量,存储为 JSON 字符串的浮点数列表
# 可以添加其他元数据字段,如 source, create_time 等
class Meta:
# database = db # 继承自 BaseModel
table_name = 'knowledges'
table_name = "knowledges"
class ThinkingLog(BaseModel):
chat_id = TextField(index=True)
trigger_text = TextField(null=True)
response_text = TextField(null=True)
# Store complex dicts/lists as JSON strings
trigger_info_json = TextField(null=True)
response_info_json = TextField(null=True)
@@ -235,28 +253,32 @@ class ThinkingLog(BaseModel):
# Add a timestamp for the log entry itself
# Ensure you have: from peewee import DateTimeField
# And: import datetime
created_at = DateTimeField(default=datetime.datetime.now)
created_at = DateTimeField(default=datetime.datetime.now)
class Meta:
table_name = 'thinking_logs'
table_name = "thinking_logs"
def create_tables():
"""
创建所有在模型中定义的数据库表。
"""
with db:
db.create_tables([
ChatStreams,
LLMUsage,
Emoji,
Messages,
Images,
ImageDescriptions,
OnlineTime,
PersonInfo,
Knowledges,
ThinkingLog
])
db.create_tables(
[
ChatStreams,
LLMUsage,
Emoji,
Messages,
Images,
ImageDescriptions,
OnlineTime,
PersonInfo,
Knowledges,
ThinkingLog,
]
)
def initialize_database():
"""
@@ -272,9 +294,9 @@ def initialize_database():
OnlineTime,
PersonInfo,
Knowledges,
ThinkingLog
ThinkingLog,
]
needs_creation = False
try:
with db: # 管理 table_exists 检查的连接
@@ -298,5 +320,6 @@ def initialize_database():
else:
print("所有数据库表均已存在。")
# 模块加载时调用初始化函数
initialize_database()

View File

@@ -1,4 +1,4 @@
from src.common.database.database_model import Messages # 更改导入
from src.common.database.database_model import Messages # 更改导入
from src.common.logger import get_module_logger
import traceback
from typing import List, Any, Optional
@@ -42,9 +42,7 @@ def find_messages(
if hasattr(Messages, key):
conditions.append(getattr(Messages, key) == value)
else:
logger.warning(
f"过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。"
)
logger.warning(f"过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。")
if conditions:
# 使用 *conditions 将所有条件以 AND 连接
query = query.where(*conditions)
@@ -59,9 +57,7 @@ def find_messages(
query = query.order_by(Messages.time.desc()).limit(limit)
latest_results_peewee = list(query)
# 将结果按时间正序排列
peewee_results = sorted(
latest_results_peewee, key=lambda msg: msg.time
)
peewee_results = sorted(latest_results_peewee, key=lambda msg: msg.time)
else:
# limit 为 0 时,应用传入的 sort 参数
if sort:
@@ -74,13 +70,9 @@ def find_messages(
elif direction == -1: # DESC
peewee_sort_terms.append(field.desc())
else:
logger.warning(
f"字段 '{field_name}' 的排序方向 '{direction}' 无效。将跳过此排序条件。"
)
logger.warning(f"字段 '{field_name}' 的排序方向 '{direction}' 无效。将跳过此排序条件。")
else:
logger.warning(
f"排序字段 '{field_name}' 在 Messages 模型中未找到。将跳过此排序条件。"
)
logger.warning(f"排序字段 '{field_name}' 在 Messages 模型中未找到。将跳过此排序条件。")
if peewee_sort_terms:
query = query.order_by(*peewee_sort_terms)
peewee_results = list(query)
@@ -116,9 +108,7 @@ def count_messages(message_filter: dict[str, Any]) -> int:
if hasattr(Messages, key):
conditions.append(getattr(Messages, key) == value)
else:
logger.warning(
f"计数时,过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。"
)
logger.warning(f"计数时,过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。")
if conditions:
query = query.where(*conditions)