re-style: 格式化代码
This commit is contained in:
@@ -1,17 +1,19 @@
|
||||
import time
|
||||
import orjson
|
||||
import hashlib
|
||||
import time
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
from typing import Any
|
||||
|
||||
import faiss
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
import numpy as np
|
||||
import orjson
|
||||
|
||||
from src.common.config_helpers import resolve_embedding_dimension
|
||||
from src.common.database.sqlalchemy_models import CacheEntries
|
||||
from src.common.database.sqlalchemy_database_api import db_query, db_save
|
||||
from src.common.database.sqlalchemy_models import CacheEntries
|
||||
from src.common.logger import get_logger
|
||||
from src.common.vector_db import vector_db_service
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
logger = get_logger("cache_manager")
|
||||
|
||||
@@ -40,14 +42,14 @@ class CacheManager:
|
||||
self.semantic_cache_collection_name = "semantic_cache"
|
||||
|
||||
# L1 缓存 (内存)
|
||||
self.l1_kv_cache: Dict[str, Dict[str, Any]] = {}
|
||||
self.l1_kv_cache: dict[str, dict[str, Any]] = {}
|
||||
embedding_dim = resolve_embedding_dimension(global_config.lpmm_knowledge.embedding_dimension)
|
||||
if not embedding_dim:
|
||||
embedding_dim = global_config.lpmm_knowledge.embedding_dimension
|
||||
|
||||
self.embedding_dimension = embedding_dim
|
||||
self.l1_vector_index = faiss.IndexFlatIP(embedding_dim)
|
||||
self.l1_vector_id_to_key: Dict[int, str] = {}
|
||||
self.l1_vector_id_to_key: dict[int, str] = {}
|
||||
|
||||
# L2 向量缓存 (使用新的服务)
|
||||
vector_db_service.get_or_create_collection(self.semantic_cache_collection_name)
|
||||
@@ -59,7 +61,7 @@ class CacheManager:
|
||||
logger.info("缓存管理器已初始化: L1 (内存+FAISS), L2 (数据库+ChromaDB)")
|
||||
|
||||
@staticmethod
|
||||
def _validate_embedding(embedding_result: Any) -> Optional[np.ndarray]:
|
||||
def _validate_embedding(embedding_result: Any) -> np.ndarray | None:
|
||||
"""
|
||||
验证和标准化嵌入向量格式
|
||||
"""
|
||||
@@ -100,7 +102,7 @@ class CacheManager:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _generate_key(tool_name: str, function_args: Dict[str, Any], tool_file_path: Union[str, Path]) -> str:
|
||||
def _generate_key(tool_name: str, function_args: dict[str, Any], tool_file_path: str | Path) -> str:
|
||||
"""生成确定性的缓存键,包含文件修改时间以实现自动失效。"""
|
||||
try:
|
||||
tool_file_path = Path(tool_file_path)
|
||||
@@ -124,10 +126,10 @@ class CacheManager:
|
||||
async def get(
|
||||
self,
|
||||
tool_name: str,
|
||||
function_args: Dict[str, Any],
|
||||
tool_file_path: Union[str, Path],
|
||||
semantic_query: Optional[str] = None,
|
||||
) -> Optional[Any]:
|
||||
function_args: dict[str, Any],
|
||||
tool_file_path: str | Path,
|
||||
semantic_query: str | None = None,
|
||||
) -> Any | None:
|
||||
"""
|
||||
从缓存获取结果,查询顺序: L1-KV -> L1-Vector -> L2-KV -> L2-Vector。
|
||||
"""
|
||||
@@ -251,11 +253,11 @@ class CacheManager:
|
||||
async def set(
|
||||
self,
|
||||
tool_name: str,
|
||||
function_args: Dict[str, Any],
|
||||
tool_file_path: Union[str, Path],
|
||||
function_args: dict[str, Any],
|
||||
tool_file_path: str | Path,
|
||||
data: Any,
|
||||
ttl: Optional[int] = None,
|
||||
semantic_query: Optional[str] = None,
|
||||
ttl: int | None = None,
|
||||
semantic_query: str | None = None,
|
||||
):
|
||||
"""将结果存入所有缓存层。"""
|
||||
if ttl is None:
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from src.config.config import global_config, model_config
|
||||
|
||||
|
||||
def resolve_embedding_dimension(fallback: Optional[int] = None, *, sync_global: bool = True) -> Optional[int]:
|
||||
def resolve_embedding_dimension(fallback: int | None = None, *, sync_global: bool = True) -> int | None:
|
||||
"""获取当前配置的嵌入向量维度。
|
||||
|
||||
优先顺序:
|
||||
@@ -14,7 +12,7 @@ def resolve_embedding_dimension(fallback: Optional[int] = None, *, sync_global:
|
||||
3. 调用方提供的 fallback
|
||||
"""
|
||||
|
||||
candidates: list[Optional[int]] = []
|
||||
candidates: list[int | None] = []
|
||||
|
||||
try:
|
||||
embedding_task = getattr(model_config.model_task_config, "embedding", None)
|
||||
@@ -30,7 +28,7 @@ def resolve_embedding_dimension(fallback: Optional[int] = None, *, sync_global:
|
||||
|
||||
candidates.append(fallback)
|
||||
|
||||
resolved: Optional[int] = next((int(dim) for dim in candidates if dim and int(dim) > 0), None)
|
||||
resolved: int | None = next((int(dim) for dim in candidates if dim and int(dim) > 0), None)
|
||||
|
||||
if resolved and sync_global:
|
||||
try:
|
||||
|
||||
@@ -4,8 +4,8 @@
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Dict, Optional, Any
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from . import BaseDataModel
|
||||
|
||||
@@ -16,12 +16,12 @@ class BotInterestTag(BaseDataModel):
|
||||
|
||||
tag_name: str
|
||||
weight: float = 1.0 # 权重,表示对这个兴趣的喜好程度 (0.0-1.0)
|
||||
embedding: Optional[List[float]] = None # 标签的embedding向量
|
||||
embedding: list[float] | None = None # 标签的embedding向量
|
||||
created_at: datetime = field(default_factory=datetime.now)
|
||||
updated_at: datetime = field(default_factory=datetime.now)
|
||||
is_active: bool = True
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""转换为字典格式"""
|
||||
return {
|
||||
"tag_name": self.tag_name,
|
||||
@@ -33,7 +33,7 @@ class BotInterestTag(BaseDataModel):
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "BotInterestTag":
|
||||
def from_dict(cls, data: dict[str, Any]) -> "BotInterestTag":
|
||||
"""从字典创建对象"""
|
||||
return cls(
|
||||
tag_name=data["tag_name"],
|
||||
@@ -51,16 +51,16 @@ class BotPersonalityInterests(BaseDataModel):
|
||||
|
||||
personality_id: str
|
||||
personality_description: str # 人设描述文本
|
||||
interest_tags: List[BotInterestTag] = field(default_factory=list)
|
||||
interest_tags: list[BotInterestTag] = field(default_factory=list)
|
||||
embedding_model: str = "text-embedding-ada-002" # 使用的embedding模型
|
||||
last_updated: datetime = field(default_factory=datetime.now)
|
||||
version: int = 1 # 版本号,用于追踪更新
|
||||
|
||||
def get_active_tags(self) -> List[BotInterestTag]:
|
||||
def get_active_tags(self) -> list[BotInterestTag]:
|
||||
"""获取活跃的兴趣标签"""
|
||||
return [tag for tag in self.interest_tags if tag.is_active]
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""转换为字典格式"""
|
||||
return {
|
||||
"personality_id": self.personality_id,
|
||||
@@ -72,7 +72,7 @@ class BotPersonalityInterests(BaseDataModel):
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "BotPersonalityInterests":
|
||||
def from_dict(cls, data: dict[str, Any]) -> "BotPersonalityInterests":
|
||||
"""从字典创建对象"""
|
||||
return cls(
|
||||
personality_id=data["personality_id"],
|
||||
@@ -89,14 +89,14 @@ class InterestMatchResult(BaseDataModel):
|
||||
"""兴趣匹配结果"""
|
||||
|
||||
message_id: str
|
||||
matched_tags: List[str] = field(default_factory=list)
|
||||
match_scores: Dict[str, float] = field(default_factory=dict) # tag_name -> score
|
||||
matched_tags: list[str] = field(default_factory=list)
|
||||
match_scores: dict[str, float] = field(default_factory=dict) # tag_name -> score
|
||||
overall_score: float = 0.0
|
||||
top_tag: Optional[str] = None
|
||||
top_tag: str | None = None
|
||||
confidence: float = 0.0 # 匹配置信度 (0.0-1.0)
|
||||
matched_keywords: List[str] = field(default_factory=list)
|
||||
matched_keywords: list[str] = field(default_factory=list)
|
||||
|
||||
def add_match(self, tag_name: str, score: float, keywords: List[str] = None):
|
||||
def add_match(self, tag_name: str, score: float, keywords: list[str] = None):
|
||||
"""添加匹配结果"""
|
||||
self.matched_tags.append(tag_name)
|
||||
self.match_scores[tag_name] = score
|
||||
@@ -131,7 +131,7 @@ class InterestMatchResult(BaseDataModel):
|
||||
else:
|
||||
self.confidence = 0.0
|
||||
|
||||
def get_top_matches(self, top_n: int = 3) -> List[tuple]:
|
||||
def get_top_matches(self, top_n: int = 3) -> list[tuple]:
|
||||
"""获取前N个最佳匹配"""
|
||||
sorted_matches = sorted(self.match_scores.items(), key=lambda x: x[1], reverse=True)
|
||||
return sorted_matches[:top_n]
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
from typing import Optional, Any, Dict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from . import BaseDataModel
|
||||
|
||||
@@ -10,7 +10,7 @@ class DatabaseUserInfo(BaseDataModel):
|
||||
platform: str = field(default_factory=str)
|
||||
user_id: str = field(default_factory=str)
|
||||
user_nickname: str = field(default_factory=str)
|
||||
user_cardname: Optional[str] = None
|
||||
user_cardname: str | None = None
|
||||
|
||||
# def __post_init__(self):
|
||||
# assert isinstance(self.platform, str), "platform must be a string"
|
||||
@@ -25,7 +25,7 @@ class DatabaseUserInfo(BaseDataModel):
|
||||
class DatabaseGroupInfo(BaseDataModel):
|
||||
group_id: str = field(default_factory=str)
|
||||
group_name: str = field(default_factory=str)
|
||||
group_platform: Optional[str] = None
|
||||
group_platform: str | None = None
|
||||
|
||||
# def __post_init__(self):
|
||||
# assert isinstance(self.group_id, str), "group_id must be a string"
|
||||
@@ -42,7 +42,7 @@ class DatabaseChatInfo(BaseDataModel):
|
||||
create_time: float = field(default_factory=float)
|
||||
last_active_time: float = field(default_factory=float)
|
||||
user_info: DatabaseUserInfo = field(default_factory=DatabaseUserInfo)
|
||||
group_info: Optional[DatabaseGroupInfo] = None
|
||||
group_info: DatabaseGroupInfo | None = None
|
||||
|
||||
# def __post_init__(self):
|
||||
# assert isinstance(self.stream_id, str), "stream_id must be a string"
|
||||
@@ -62,41 +62,41 @@ class DatabaseMessages(BaseDataModel):
|
||||
message_id: str = "",
|
||||
time: float = 0.0,
|
||||
chat_id: str = "",
|
||||
reply_to: Optional[str] = None,
|
||||
interest_value: Optional[float] = None,
|
||||
key_words: Optional[str] = None,
|
||||
key_words_lite: Optional[str] = None,
|
||||
is_mentioned: Optional[bool] = None,
|
||||
is_at: Optional[bool] = None,
|
||||
reply_probability_boost: Optional[float] = None,
|
||||
processed_plain_text: Optional[str] = None,
|
||||
display_message: Optional[str] = None,
|
||||
priority_mode: Optional[str] = None,
|
||||
priority_info: Optional[str] = None,
|
||||
additional_config: Optional[str] = None,
|
||||
reply_to: str | None = None,
|
||||
interest_value: float | None = None,
|
||||
key_words: str | None = None,
|
||||
key_words_lite: str | None = None,
|
||||
is_mentioned: bool | None = None,
|
||||
is_at: bool | None = None,
|
||||
reply_probability_boost: float | None = None,
|
||||
processed_plain_text: str | None = None,
|
||||
display_message: str | None = None,
|
||||
priority_mode: str | None = None,
|
||||
priority_info: str | None = None,
|
||||
additional_config: str | None = None,
|
||||
is_emoji: bool = False,
|
||||
is_picid: bool = False,
|
||||
is_command: bool = False,
|
||||
is_notify: bool = False,
|
||||
selected_expressions: Optional[str] = None,
|
||||
selected_expressions: str | None = None,
|
||||
is_read: bool = False,
|
||||
user_id: str = "",
|
||||
user_nickname: str = "",
|
||||
user_cardname: Optional[str] = None,
|
||||
user_cardname: str | None = None,
|
||||
user_platform: str = "",
|
||||
chat_info_group_id: Optional[str] = None,
|
||||
chat_info_group_name: Optional[str] = None,
|
||||
chat_info_group_platform: Optional[str] = None,
|
||||
chat_info_group_id: str | None = None,
|
||||
chat_info_group_name: str | None = None,
|
||||
chat_info_group_platform: str | None = None,
|
||||
chat_info_user_id: str = "",
|
||||
chat_info_user_nickname: str = "",
|
||||
chat_info_user_cardname: Optional[str] = None,
|
||||
chat_info_user_cardname: str | None = None,
|
||||
chat_info_user_platform: str = "",
|
||||
chat_info_stream_id: str = "",
|
||||
chat_info_platform: str = "",
|
||||
chat_info_create_time: float = 0.0,
|
||||
chat_info_last_active_time: float = 0.0,
|
||||
# 新增字段
|
||||
actions: Optional[list] = None,
|
||||
actions: list | None = None,
|
||||
should_reply: bool = False,
|
||||
**kwargs: Any,
|
||||
):
|
||||
@@ -132,7 +132,7 @@ class DatabaseMessages(BaseDataModel):
|
||||
self.selected_expressions = selected_expressions
|
||||
self.is_read = is_read
|
||||
|
||||
self.group_info: Optional[DatabaseGroupInfo] = None
|
||||
self.group_info: DatabaseGroupInfo | None = None
|
||||
self.user_info = DatabaseUserInfo(
|
||||
user_id=user_id,
|
||||
user_nickname=user_nickname,
|
||||
@@ -172,7 +172,7 @@ class DatabaseMessages(BaseDataModel):
|
||||
# assert isinstance(self.interest_value, float) or self.interest_value is None, (
|
||||
# "interest_value must be a float or None"
|
||||
# )
|
||||
def flatten(self) -> Dict[str, Any]:
|
||||
def flatten(self) -> dict[str, Any]:
|
||||
"""
|
||||
将消息数据模型转换为字典格式,便于存储或传输
|
||||
"""
|
||||
@@ -255,7 +255,7 @@ class DatabaseMessages(BaseDataModel):
|
||||
"""
|
||||
return self.actions or []
|
||||
|
||||
def get_message_summary(self) -> Dict[str, Any]:
|
||||
def get_message_summary(self) -> dict[str, Any]:
|
||||
"""
|
||||
获取消息摘要信息
|
||||
|
||||
|
||||
@@ -1,30 +1,32 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Dict, List, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from src.plugin_system.base.component_types import ChatType
|
||||
|
||||
from . import BaseDataModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .database_data_model import DatabaseMessages
|
||||
from src.plugin_system.base.component_types import ActionInfo, ChatMode
|
||||
|
||||
from .database_data_model import DatabaseMessages
|
||||
|
||||
|
||||
@dataclass
|
||||
class TargetPersonInfo(BaseDataModel):
|
||||
platform: str = field(default_factory=str)
|
||||
user_id: str = field(default_factory=str)
|
||||
user_nickname: str = field(default_factory=str)
|
||||
person_id: Optional[str] = None
|
||||
person_name: Optional[str] = None
|
||||
person_id: str | None = None
|
||||
person_name: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActionPlannerInfo(BaseDataModel):
|
||||
action_type: str = field(default_factory=str)
|
||||
reasoning: Optional[str] = None
|
||||
action_data: Optional[Dict] = None
|
||||
reasoning: str | None = None
|
||||
action_data: dict | None = None
|
||||
action_message: Optional["DatabaseMessages"] = None
|
||||
available_actions: Optional[Dict[str, "ActionInfo"]] = None
|
||||
available_actions: dict[str, "ActionInfo"] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -36,7 +38,7 @@ class InterestScore(BaseDataModel):
|
||||
interest_match_score: float
|
||||
relationship_score: float
|
||||
mentioned_score: float
|
||||
details: Dict[str, str]
|
||||
details: dict[str, str]
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -50,10 +52,10 @@ class Plan(BaseDataModel):
|
||||
|
||||
chat_type: "ChatType"
|
||||
# Generator 填充
|
||||
available_actions: Dict[str, "ActionInfo"] = field(default_factory=dict)
|
||||
chat_history: List["DatabaseMessages"] = field(default_factory=list)
|
||||
target_info: Optional[TargetPersonInfo] = None
|
||||
available_actions: dict[str, "ActionInfo"] = field(default_factory=dict)
|
||||
chat_history: list["DatabaseMessages"] = field(default_factory=list)
|
||||
target_info: TargetPersonInfo | None = None
|
||||
|
||||
# Filter 填充
|
||||
llm_prompt: Optional[str] = None
|
||||
decided_actions: Optional[List[ActionPlannerInfo]] = None
|
||||
llm_prompt: str | None = None
|
||||
decided_actions: list[ActionPlannerInfo] | None = None
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, List, Tuple, TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from . import BaseDataModel
|
||||
|
||||
@@ -9,10 +9,10 @@ if TYPE_CHECKING:
|
||||
|
||||
@dataclass
|
||||
class LLMGenerationDataModel(BaseDataModel):
|
||||
content: Optional[str] = None
|
||||
reasoning: Optional[str] = None
|
||||
model: Optional[str] = None
|
||||
tool_calls: Optional[List["ToolCall"]] = None
|
||||
prompt: Optional[str] = None
|
||||
selected_expressions: Optional[List[int]] = None
|
||||
reply_set: Optional[List[Tuple[str, Any]]] = None
|
||||
content: str | None = None
|
||||
reasoning: str | None = None
|
||||
model: str | None = None
|
||||
tool_calls: list["ToolCall"] | None = None
|
||||
prompt: str | None = None
|
||||
selected_expressions: list[int] | None = None
|
||||
reply_set: list[tuple[str, Any]] | None = None
|
||||
|
||||
@@ -7,11 +7,12 @@ import asyncio
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import List, Optional, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import ChatMode, ChatType
|
||||
|
||||
from . import BaseDataModel
|
||||
from src.plugin_system.base.component_types import ChatMode, ChatType
|
||||
from src.common.logger import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .database_data_model import DatabaseMessages
|
||||
@@ -34,11 +35,11 @@ class StreamContext(BaseDataModel):
|
||||
stream_id: str
|
||||
chat_type: ChatType = ChatType.PRIVATE # 聊天类型,默认为私聊
|
||||
chat_mode: ChatMode = ChatMode.NORMAL # 聊天模式,默认为普通模式
|
||||
unread_messages: List["DatabaseMessages"] = field(default_factory=list)
|
||||
history_messages: List["DatabaseMessages"] = field(default_factory=list)
|
||||
unread_messages: list["DatabaseMessages"] = field(default_factory=list)
|
||||
history_messages: list["DatabaseMessages"] = field(default_factory=list)
|
||||
last_check_time: float = field(default_factory=time.time)
|
||||
is_active: bool = True
|
||||
processing_task: Optional[asyncio.Task] = None
|
||||
processing_task: asyncio.Task | None = None
|
||||
interruption_count: int = 0 # 打断计数器
|
||||
last_interruption_time: float = 0.0 # 上次打断时间
|
||||
afc_threshold_adjustment: float = 0.0 # afc阈值调整量
|
||||
@@ -49,8 +50,8 @@ class StreamContext(BaseDataModel):
|
||||
|
||||
# 新增字段以替代ChatMessageContext功能
|
||||
current_message: Optional["DatabaseMessages"] = None
|
||||
priority_mode: Optional[str] = None
|
||||
priority_info: Optional[dict] = None
|
||||
priority_mode: str | None = None
|
||||
priority_info: dict | None = None
|
||||
|
||||
def add_message(self, message: "DatabaseMessages"):
|
||||
"""添加消息到上下文"""
|
||||
@@ -150,11 +151,11 @@ class StreamContext(BaseDataModel):
|
||||
self.unread_messages.remove(msg)
|
||||
break
|
||||
|
||||
def get_unread_messages(self) -> List["DatabaseMessages"]:
|
||||
def get_unread_messages(self) -> list["DatabaseMessages"]:
|
||||
"""获取未读消息"""
|
||||
return [msg for msg in self.unread_messages if not msg.is_read]
|
||||
|
||||
def get_history_messages(self, limit: int = 20) -> List["DatabaseMessages"]:
|
||||
def get_history_messages(self, limit: int = 20) -> list["DatabaseMessages"]:
|
||||
"""获取历史消息"""
|
||||
# 优先返回最近的历史消息和所有未读消息
|
||||
recent_history = self.history_messages[-limit:] if len(self.history_messages) > limit else self.history_messages
|
||||
@@ -230,7 +231,7 @@ class StreamContext(BaseDataModel):
|
||||
"""设置当前消息"""
|
||||
self.current_message = message
|
||||
|
||||
def get_template_name(self) -> Optional[str]:
|
||||
def get_template_name(self) -> str | None:
|
||||
"""获取模板名称"""
|
||||
if (
|
||||
self.current_message
|
||||
@@ -336,11 +337,11 @@ class StreamContext(BaseDataModel):
|
||||
return False
|
||||
return True
|
||||
|
||||
def get_priority_mode(self) -> Optional[str]:
|
||||
def get_priority_mode(self) -> str | None:
|
||||
"""获取优先级模式"""
|
||||
return self.priority_mode
|
||||
|
||||
def get_priority_info(self) -> Optional[dict]:
|
||||
def get_priority_info(self) -> dict | None:
|
||||
"""获取优先级信息"""
|
||||
return self.priority_info
|
||||
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
import os
|
||||
|
||||
from rich.traceback import install
|
||||
from src.common.logger import get_logger
|
||||
|
||||
# SQLAlchemy相关导入
|
||||
from src.common.database.sqlalchemy_init import initialize_database_compat
|
||||
from src.common.database.sqlalchemy_models import get_engine, get_db_session
|
||||
from src.common.database.sqlalchemy_models import get_db_session, get_engine
|
||||
from src.common.logger import get_logger
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
|
||||
@@ -6,31 +6,31 @@
|
||||
|
||||
import time
|
||||
import traceback
|
||||
from typing import Dict, List, Any, Union, Optional
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import desc, asc, func, and_, select
|
||||
from sqlalchemy import and_, asc, desc, func, select
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from src.common.database.sqlalchemy_models import (
|
||||
get_db_session,
|
||||
Messages,
|
||||
ActionRecords,
|
||||
PersonInfo,
|
||||
ChatStreams,
|
||||
LLMUsage,
|
||||
Emoji,
|
||||
Images,
|
||||
ImageDescriptions,
|
||||
OnlineTime,
|
||||
Memory,
|
||||
Expression,
|
||||
ThinkingLog,
|
||||
GraphNodes,
|
||||
GraphEdges,
|
||||
Schedule,
|
||||
MaiZoneScheduleStatus,
|
||||
CacheEntries,
|
||||
ChatStreams,
|
||||
Emoji,
|
||||
Expression,
|
||||
GraphEdges,
|
||||
GraphNodes,
|
||||
ImageDescriptions,
|
||||
Images,
|
||||
LLMUsage,
|
||||
MaiZoneScheduleStatus,
|
||||
Memory,
|
||||
Messages,
|
||||
OnlineTime,
|
||||
PersonInfo,
|
||||
Schedule,
|
||||
ThinkingLog,
|
||||
UserRelationships,
|
||||
get_db_session,
|
||||
)
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -59,7 +59,7 @@ MODEL_MAPPING = {
|
||||
}
|
||||
|
||||
|
||||
async def build_filters(model_class, filters: Dict[str, Any]):
|
||||
async def build_filters(model_class, filters: dict[str, Any]):
|
||||
"""构建查询过滤条件"""
|
||||
conditions = []
|
||||
|
||||
@@ -98,13 +98,13 @@ async def build_filters(model_class, filters: Dict[str, Any]):
|
||||
|
||||
async def db_query(
|
||||
model_class,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
query_type: Optional[str] = "get",
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
limit: Optional[int] = None,
|
||||
order_by: Optional[List[str]] = None,
|
||||
single_result: Optional[bool] = False,
|
||||
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
|
||||
data: dict[str, Any] | None = None,
|
||||
query_type: str | None = "get",
|
||||
filters: dict[str, Any] | None = None,
|
||||
limit: int | None = None,
|
||||
order_by: list[str] | None = None,
|
||||
single_result: bool | None = False,
|
||||
) -> list[dict[str, Any]] | dict[str, Any] | None:
|
||||
"""执行异步数据库查询操作
|
||||
|
||||
Args:
|
||||
@@ -263,8 +263,8 @@ async def db_query(
|
||||
|
||||
|
||||
async def db_save(
|
||||
model_class, data: Dict[str, Any], key_field: Optional[str] = None, key_value: Optional[Any] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
model_class, data: dict[str, Any], key_field: str | None = None, key_value: Any | None = None
|
||||
) -> dict[str, Any] | None:
|
||||
"""异步保存数据到数据库(创建或更新)
|
||||
|
||||
Args:
|
||||
@@ -325,11 +325,11 @@ async def db_save(
|
||||
|
||||
async def db_get(
|
||||
model_class,
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
limit: Optional[int] = None,
|
||||
order_by: Optional[str] = None,
|
||||
single_result: Optional[bool] = False,
|
||||
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
|
||||
filters: dict[str, Any] | None = None,
|
||||
limit: int | None = None,
|
||||
order_by: str | None = None,
|
||||
single_result: bool | None = False,
|
||||
) -> list[dict[str, Any]] | dict[str, Any] | None:
|
||||
"""异步从数据库获取记录
|
||||
|
||||
Args:
|
||||
@@ -359,9 +359,9 @@ async def store_action_info(
|
||||
action_prompt_display: str = "",
|
||||
action_done: bool = True,
|
||||
thinking_id: str = "",
|
||||
action_data: Optional[dict] = None,
|
||||
action_data: dict | None = None,
|
||||
action_name: str = "",
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
) -> dict[str, Any] | None:
|
||||
"""异步存储动作信息到数据库
|
||||
|
||||
Args:
|
||||
|
||||
@@ -4,10 +4,10 @@
|
||||
提供统一的异步数据库初始化接口
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from src.common.database.sqlalchemy_models import Base, get_engine, initialize_database
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("sqlalchemy_init")
|
||||
|
||||
@@ -71,7 +71,7 @@ async def create_all_tables() -> bool:
|
||||
return False
|
||||
|
||||
|
||||
async def get_database_info() -> Optional[dict]:
|
||||
async def get_database_info() -> dict | None:
|
||||
"""
|
||||
异步获取数据库信息
|
||||
|
||||
|
||||
@@ -6,11 +6,12 @@
|
||||
import datetime
|
||||
import os
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Optional, Any, Dict, AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import Column, String, Float, Integer, Boolean, Text, Index, DateTime, text
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
||||
from sqlalchemy import Boolean, Column, DateTime, Float, Index, Integer, String, Text, text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
@@ -423,7 +424,7 @@ class Expression(Base):
|
||||
last_active_time: Mapped[float] = mapped_column(Float, nullable=False)
|
||||
chat_id: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True)
|
||||
type: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
create_date: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
|
||||
create_date: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
|
||||
__table_args__ = (Index("idx_expression_chat_id", "chat_id"),)
|
||||
|
||||
@@ -710,7 +711,7 @@ async def initialize_database():
|
||||
config = global_config.database
|
||||
|
||||
# 配置引擎参数
|
||||
engine_kwargs: Dict[str, Any] = {
|
||||
engine_kwargs: dict[str, Any] = {
|
||||
"echo": False, # 生产环境关闭SQL日志
|
||||
"future": True,
|
||||
}
|
||||
@@ -759,12 +760,12 @@ async def initialize_database():
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_db_session() -> AsyncGenerator[Optional[AsyncSession], None]:
|
||||
async def get_db_session() -> AsyncGenerator[AsyncSession | None, None]:
|
||||
"""
|
||||
异步数据库会话上下文管理器。
|
||||
在初始化失败时会yield None,调用方需要检查会话是否为None。
|
||||
"""
|
||||
session: Optional[AsyncSession] = None
|
||||
session: AsyncSession | None = None
|
||||
SessionLocal = None
|
||||
try:
|
||||
_, SessionLocal = await initialize_database()
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
# 使用基于时间戳的文件处理器,简单的轮转份数限制
|
||||
|
||||
import logging
|
||||
import orjson
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
|
||||
import orjson
|
||||
import structlog
|
||||
import tomlkit
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
# 创建logs目录
|
||||
LOG_DIR = Path("logs")
|
||||
LOG_DIR.mkdir(exist_ok=True)
|
||||
@@ -212,7 +212,7 @@ def load_log_config(): # sourcery skip: use-contextlib-suppress
|
||||
|
||||
try:
|
||||
if config_path.exists():
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
with open(config_path, encoding="utf-8") as f:
|
||||
config = tomlkit.load(f)
|
||||
return config.get("log", default_config)
|
||||
except Exception as e:
|
||||
@@ -942,7 +942,7 @@ raw_logger: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
binds: dict[str, Callable] = {}
|
||||
|
||||
|
||||
def get_logger(name: Optional[str]) -> structlog.stdlib.BoundLogger:
|
||||
def get_logger(name: str | None) -> structlog.stdlib.BoundLogger:
|
||||
"""获取logger实例,支持按名称绑定"""
|
||||
if name is None:
|
||||
return raw_logger
|
||||
|
||||
@@ -4,7 +4,6 @@ __version__ = "0.1.0"
|
||||
|
||||
from .api import get_global_api
|
||||
|
||||
|
||||
__all__ = [
|
||||
"get_global_api",
|
||||
]
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
from src.common.server import get_global_server
|
||||
import importlib.metadata
|
||||
from maim_message import MessageServer
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
import os
|
||||
|
||||
from maim_message import MessageServer
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.server import get_global_server
|
||||
from src.config.config import global_config
|
||||
|
||||
global_api = None
|
||||
|
||||
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
import traceback
|
||||
from typing import Any
|
||||
|
||||
from typing import List, Optional, Any, Dict
|
||||
from sqlalchemy import not_, select, func
|
||||
|
||||
from sqlalchemy import func, not_, select
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from src.config.config import global_config
|
||||
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
|
||||
# from src.common.database.database_model import Messages
|
||||
from src.common.database.sqlalchemy_models import Messages
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -18,7 +18,7 @@ class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
def _model_to_dict(instance: Base) -> Dict[str, Any]:
|
||||
def _model_to_dict(instance: Base) -> dict[str, Any]:
|
||||
"""
|
||||
将 SQLAlchemy 模型实例转换为字典。
|
||||
"""
|
||||
@@ -32,12 +32,12 @@ def _model_to_dict(instance: Base) -> Dict[str, Any]:
|
||||
|
||||
async def find_messages(
|
||||
message_filter: dict[str, Any],
|
||||
sort: Optional[List[tuple[str, int]]] = None,
|
||||
sort: list[tuple[str, int]] | None = None,
|
||||
limit: int = 0,
|
||||
limit_mode: str = "latest",
|
||||
filter_bot=False,
|
||||
filter_command=False,
|
||||
) -> List[dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
根据提供的过滤器、排序和限制条件查找消息。
|
||||
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import platform
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import aiohttp
|
||||
import platform
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from cryptography.hazmat.primitives import hashes, serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import padding, rsa
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.tcp_connector import get_tcp_connector
|
||||
from src.config.config import global_config
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import FastAPI, APIRouter
|
||||
from fastapi import APIRouter, FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware # 新增导入
|
||||
from rich.traceback import install
|
||||
from uvicorn import Config, Server as UvicornServer
|
||||
from uvicorn import Config
|
||||
from uvicorn import Server as UvicornServer
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
|
||||
class Server:
|
||||
def __init__(self, host: Optional[str] = None, port: Optional[int] = None, app_name: str = "MaiMCore"):
|
||||
def __init__(self, host: str | None = None, port: int | None = None, app_name: str = "MaiMCore"):
|
||||
self.app = FastAPI(title=app_name)
|
||||
self._host: str = "127.0.0.1"
|
||||
self._port: int = 8080
|
||||
self._server: Optional[UvicornServer] = None
|
||||
self._server: UvicornServer | None = None
|
||||
self.set_address(host, port)
|
||||
|
||||
# 配置 CORS
|
||||
@@ -57,7 +57,7 @@ class Server:
|
||||
"""
|
||||
self.app.include_router(router, prefix=prefix)
|
||||
|
||||
def set_address(self, host: Optional[str] = None, port: Optional[int] = None):
|
||||
def set_address(self, host: str | None = None, port: int | None = None):
|
||||
"""设置服务器地址和端口"""
|
||||
if host:
|
||||
self._host = host
|
||||
@@ -76,7 +76,7 @@ class Server:
|
||||
raise
|
||||
except Exception as e:
|
||||
await self.shutdown()
|
||||
raise RuntimeError(f"服务器运行错误: {str(e)}") from e
|
||||
raise RuntimeError(f"服务器运行错误: {e!s}") from e
|
||||
finally:
|
||||
await self.shutdown()
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import ssl
|
||||
import certifi
|
||||
|
||||
import aiohttp
|
||||
import certifi
|
||||
|
||||
ssl_context = ssl.create_default_context(cafile=certifi.where())
|
||||
|
||||
|
||||
@@ -18,4 +18,4 @@ def get_vector_db_service() -> VectorDBBase:
|
||||
# 全局向量数据库服务实例
|
||||
vector_db_service: VectorDBBase = get_vector_db_service()
|
||||
|
||||
__all__ = ["vector_db_service", "VectorDBBase"]
|
||||
__all__ = ["VectorDBBase", "vector_db_service"]
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
|
||||
class VectorDBBase(ABC):
|
||||
@@ -36,10 +36,10 @@ class VectorDBBase(ABC):
|
||||
def add(
|
||||
self,
|
||||
collection_name: str,
|
||||
embeddings: List[List[float]],
|
||||
documents: Optional[List[str]] = None,
|
||||
metadatas: Optional[List[Dict[str, Any]]] = None,
|
||||
ids: Optional[List[str]] = None,
|
||||
embeddings: list[list[float]],
|
||||
documents: list[str] | None = None,
|
||||
metadatas: list[dict[str, Any]] | None = None,
|
||||
ids: list[str] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
向指定集合中添加数据。
|
||||
@@ -57,11 +57,11 @@ class VectorDBBase(ABC):
|
||||
def query(
|
||||
self,
|
||||
collection_name: str,
|
||||
query_embeddings: List[List[float]],
|
||||
query_embeddings: list[list[float]],
|
||||
n_results: int = 1,
|
||||
where: Optional[Dict[str, Any]] = None,
|
||||
where: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, List[Any]]:
|
||||
) -> dict[str, list[Any]]:
|
||||
"""
|
||||
在指定集合中查询相似向量。
|
||||
|
||||
@@ -81,8 +81,8 @@ class VectorDBBase(ABC):
|
||||
def delete(
|
||||
self,
|
||||
collection_name: str,
|
||||
ids: Optional[List[str]] = None,
|
||||
where: Optional[Dict[str, Any]] = None,
|
||||
ids: list[str] | None = None,
|
||||
where: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
从指定集合中删除数据。
|
||||
@@ -98,13 +98,13 @@ class VectorDBBase(ABC):
|
||||
def get(
|
||||
self,
|
||||
collection_name: str,
|
||||
ids: Optional[List[str]] = None,
|
||||
where: Optional[Dict[str, Any]] = None,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
where_document: Optional[Dict[str, Any]] = None,
|
||||
include: Optional[List[str]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
ids: list[str] | None = None,
|
||||
where: dict[str, Any] | None = None,
|
||||
limit: int | None = None,
|
||||
offset: int | None = None,
|
||||
where_document: dict[str, Any] | None = None,
|
||||
include: list[str] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
根据条件从集合中获取数据。
|
||||
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import threading
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
|
||||
from .base import VectorDBBase
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from .base import VectorDBBase
|
||||
|
||||
logger = get_logger("chromadb_impl")
|
||||
|
||||
|
||||
@@ -38,7 +39,7 @@ class ChromaDBImpl(VectorDBBase):
|
||||
self.client = chromadb.PersistentClient(
|
||||
path=path, settings=Settings(anonymized_telemetry=False)
|
||||
)
|
||||
self._collections: Dict[str, Any] = {}
|
||||
self._collections: dict[str, Any] = {}
|
||||
self._initialized = True
|
||||
logger.info(f"ChromaDB 客户端已初始化,数据库路径: {path}")
|
||||
except Exception as e:
|
||||
@@ -65,10 +66,10 @@ class ChromaDBImpl(VectorDBBase):
|
||||
def add(
|
||||
self,
|
||||
collection_name: str,
|
||||
embeddings: List[List[float]],
|
||||
documents: Optional[List[str]] = None,
|
||||
metadatas: Optional[List[Dict[str, Any]]] = None,
|
||||
ids: Optional[List[str]] = None,
|
||||
embeddings: list[list[float]],
|
||||
documents: list[str] | None = None,
|
||||
metadatas: list[dict[str, Any]] | None = None,
|
||||
ids: list[str] | None = None,
|
||||
) -> None:
|
||||
collection = self.get_or_create_collection(collection_name)
|
||||
if collection:
|
||||
@@ -85,11 +86,11 @@ class ChromaDBImpl(VectorDBBase):
|
||||
def query(
|
||||
self,
|
||||
collection_name: str,
|
||||
query_embeddings: List[List[float]],
|
||||
query_embeddings: list[list[float]],
|
||||
n_results: int = 1,
|
||||
where: Optional[Dict[str, Any]] = None,
|
||||
where: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, List[Any]]:
|
||||
) -> dict[str, list[Any]]:
|
||||
collection = self.get_or_create_collection(collection_name)
|
||||
if collection:
|
||||
try:
|
||||
@@ -120,7 +121,7 @@ class ChromaDBImpl(VectorDBBase):
|
||||
logger.error(f"回退查询也失败: {fallback_e}")
|
||||
return {}
|
||||
|
||||
def _process_where_condition(self, where: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
def _process_where_condition(self, where: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""
|
||||
处理where条件,转换为ChromaDB支持的格式
|
||||
ChromaDB支持的格式:
|
||||
@@ -174,13 +175,13 @@ class ChromaDBImpl(VectorDBBase):
|
||||
def get(
|
||||
self,
|
||||
collection_name: str,
|
||||
ids: Optional[List[str]] = None,
|
||||
where: Optional[Dict[str, Any]] = None,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
where_document: Optional[Dict[str, Any]] = None,
|
||||
include: Optional[List[str]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
ids: list[str] | None = None,
|
||||
where: dict[str, Any] | None = None,
|
||||
limit: int | None = None,
|
||||
offset: int | None = None,
|
||||
where_document: dict[str, Any] | None = None,
|
||||
include: list[str] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""根据条件从集合中获取数据"""
|
||||
collection = self.get_or_create_collection(collection_name)
|
||||
if collection:
|
||||
@@ -217,8 +218,8 @@ class ChromaDBImpl(VectorDBBase):
|
||||
def delete(
|
||||
self,
|
||||
collection_name: str,
|
||||
ids: Optional[List[str]] = None,
|
||||
where: Optional[Dict[str, Any]] = None,
|
||||
ids: list[str] | None = None,
|
||||
where: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
collection = self.get_or_create_collection(collection_name)
|
||||
if collection:
|
||||
|
||||
Reference in New Issue
Block a user