re-style: format code

Author: John Richard
Date: 2025-10-02 20:26:01 +08:00
Committed by: Windpicker-owo
Parent: 00ba07e0e1
Commit: a79253c714
263 changed files with 3781 additions and 3189 deletions
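The diff below is almost entirely a mechanical re-style: typing-module generics and unions (Optional, Union, Dict, List, Tuple) give way to the built-in generics of PEP 585 and the X | Y union syntax of PEP 604, and import blocks are re-sorted. A minimal before/after sketch of the pattern, using a hypothetical function rather than code from this repository:

from pathlib import Path
from typing import Any, Dict, List, Optional, Union


# Before: typing-module generics plus Optional/Union (hypothetical example).
def lookup_old(key: str, tags: List[str], meta: Dict[str, Any],
               path: Union[str, Path], ttl: Optional[int] = None) -> Optional[Any]:
    ...


# After: built-in generics (PEP 585) and the | union operator (PEP 604);
# only Any still needs to be imported from typing.
def lookup_new(key: str, tags: list[str], meta: dict[str, Any],
               path: str | Path, ttl: int | None = None) -> Any | None:
    ...

The X | None form requires Python 3.10+ at annotation-evaluation time (or from __future__ import annotations on earlier versions), so the change presumably reflects the project's minimum supported Python.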

View File

@@ -1,17 +1,19 @@
import time
import orjson
import hashlib
import time
from pathlib import Path
import numpy as np
from typing import Any
import faiss
from typing import Any, Dict, Optional, Union
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
import numpy as np
import orjson
from src.common.config_helpers import resolve_embedding_dimension
from src.common.database.sqlalchemy_models import CacheEntries
from src.common.database.sqlalchemy_database_api import db_query, db_save
from src.common.database.sqlalchemy_models import CacheEntries
from src.common.logger import get_logger
from src.common.vector_db import vector_db_service
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
logger = get_logger("cache_manager")
@@ -40,14 +42,14 @@ class CacheManager:
self.semantic_cache_collection_name = "semantic_cache"
# L1 缓存 (内存)
self.l1_kv_cache: Dict[str, Dict[str, Any]] = {}
self.l1_kv_cache: dict[str, dict[str, Any]] = {}
embedding_dim = resolve_embedding_dimension(global_config.lpmm_knowledge.embedding_dimension)
if not embedding_dim:
embedding_dim = global_config.lpmm_knowledge.embedding_dimension
self.embedding_dimension = embedding_dim
self.l1_vector_index = faiss.IndexFlatIP(embedding_dim)
self.l1_vector_id_to_key: Dict[int, str] = {}
self.l1_vector_id_to_key: dict[int, str] = {}
# L2 向量缓存 (使用新的服务)
vector_db_service.get_or_create_collection(self.semantic_cache_collection_name)
@@ -59,7 +61,7 @@ class CacheManager:
logger.info("缓存管理器已初始化: L1 (内存+FAISS), L2 (数据库+ChromaDB)")
@staticmethod
def _validate_embedding(embedding_result: Any) -> Optional[np.ndarray]:
def _validate_embedding(embedding_result: Any) -> np.ndarray | None:
"""
验证和标准化嵌入向量格式
"""
@@ -100,7 +102,7 @@ class CacheManager:
return None
@staticmethod
def _generate_key(tool_name: str, function_args: Dict[str, Any], tool_file_path: Union[str, Path]) -> str:
def _generate_key(tool_name: str, function_args: dict[str, Any], tool_file_path: str | Path) -> str:
"""生成确定性的缓存键,包含文件修改时间以实现自动失效。"""
try:
tool_file_path = Path(tool_file_path)
@@ -124,10 +126,10 @@ class CacheManager:
async def get(
self,
tool_name: str,
function_args: Dict[str, Any],
tool_file_path: Union[str, Path],
semantic_query: Optional[str] = None,
) -> Optional[Any]:
function_args: dict[str, Any],
tool_file_path: str | Path,
semantic_query: str | None = None,
) -> Any | None:
"""
从缓存获取结果,查询顺序: L1-KV -> L1-Vector -> L2-KV -> L2-Vector。
"""
@@ -251,11 +253,11 @@ class CacheManager:
async def set(
self,
tool_name: str,
function_args: Dict[str, Any],
tool_file_path: Union[str, Path],
function_args: dict[str, Any],
tool_file_path: str | Path,
data: Any,
ttl: Optional[int] = None,
semantic_query: Optional[str] = None,
ttl: int | None = None,
semantic_query: str | None = None,
):
"""将结果存入所有缓存层。"""
if ttl is None:
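For context on the _generate_key signature modernized above: its docstring says the key is deterministic and folds in the tool file's modification time so that editing the tool invalidates stale entries. A minimal sketch of that idea, assuming SHA-256 over an orjson dump with sorted keys; the real implementation may differ:

import hashlib
from pathlib import Path
from typing import Any

import orjson


def make_cache_key(tool_name: str, function_args: dict[str, Any],
                   tool_file_path: str | Path) -> str:
    """Deterministic key: same tool + args + file mtime -> same key."""
    mtime = Path(tool_file_path).stat().st_mtime  # changes whenever the tool file is edited
    payload = orjson.dumps(
        {"tool": tool_name, "args": function_args, "mtime": mtime},
        option=orjson.OPT_SORT_KEYS,  # stable byte layout regardless of dict ordering
    )
    return hashlib.sha256(payload).hexdigest()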

View File

@@ -1,11 +1,9 @@
from __future__ import annotations
from typing import Optional
from src.config.config import global_config, model_config
def resolve_embedding_dimension(fallback: Optional[int] = None, *, sync_global: bool = True) -> Optional[int]:
def resolve_embedding_dimension(fallback: int | None = None, *, sync_global: bool = True) -> int | None:
"""获取当前配置的嵌入向量维度。
优先顺序:
@@ -14,7 +12,7 @@ def resolve_embedding_dimension(fallback: Optional[int] = None, *, sync_global:
3. 调用方提供的 fallback
"""
candidates: list[Optional[int]] = []
candidates: list[int | None] = []
try:
embedding_task = getattr(model_config.model_task_config, "embedding", None)
@@ -30,7 +28,7 @@ def resolve_embedding_dimension(fallback: Optional[int] = None, *, sync_global:
candidates.append(fallback)
resolved: Optional[int] = next((int(dim) for dim in candidates if dim and int(dim) > 0), None)
resolved: int | None = next((int(dim) for dim in candidates if dim and int(dim) > 0), None)
if resolved and sync_global:
try:
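The selection logic in this hunk reduces to "first positive candidate wins, else the caller's fallback, else None". A self-contained sketch of that pattern, stripped of the config lookups the real helper performs:

def first_positive_dimension(candidates: list[int | None],
                             fallback: int | None = None) -> int | None:
    """Return the first positive value among candidates + fallback, or None."""
    ordered = [*candidates, fallback]
    return next((int(dim) for dim in ordered if dim and int(dim) > 0), None)


assert first_positive_dimension([None, 0, 1024]) == 1024
assert first_positive_dimension([None, None], fallback=768) == 768
assert first_positive_dimension([None, None]) is None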

View File

@@ -4,8 +4,8 @@
"""
from dataclasses import dataclass, field
from typing import List, Dict, Optional, Any
from datetime import datetime
from typing import Any
from . import BaseDataModel
@@ -16,12 +16,12 @@ class BotInterestTag(BaseDataModel):
tag_name: str
weight: float = 1.0 # 权重,表示对这个兴趣的喜好程度 (0.0-1.0)
embedding: Optional[List[float]] = None # 标签的embedding向量
embedding: list[float] | None = None # 标签的embedding向量
created_at: datetime = field(default_factory=datetime.now)
updated_at: datetime = field(default_factory=datetime.now)
is_active: bool = True
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""转换为字典格式"""
return {
"tag_name": self.tag_name,
@@ -33,7 +33,7 @@ class BotInterestTag(BaseDataModel):
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "BotInterestTag":
def from_dict(cls, data: dict[str, Any]) -> "BotInterestTag":
"""从字典创建对象"""
return cls(
tag_name=data["tag_name"],
@@ -51,16 +51,16 @@ class BotPersonalityInterests(BaseDataModel):
personality_id: str
personality_description: str # 人设描述文本
interest_tags: List[BotInterestTag] = field(default_factory=list)
interest_tags: list[BotInterestTag] = field(default_factory=list)
embedding_model: str = "text-embedding-ada-002" # 使用的embedding模型
last_updated: datetime = field(default_factory=datetime.now)
version: int = 1 # 版本号,用于追踪更新
def get_active_tags(self) -> List[BotInterestTag]:
def get_active_tags(self) -> list[BotInterestTag]:
"""获取活跃的兴趣标签"""
return [tag for tag in self.interest_tags if tag.is_active]
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""转换为字典格式"""
return {
"personality_id": self.personality_id,
@@ -72,7 +72,7 @@ class BotPersonalityInterests(BaseDataModel):
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "BotPersonalityInterests":
def from_dict(cls, data: dict[str, Any]) -> "BotPersonalityInterests":
"""从字典创建对象"""
return cls(
personality_id=data["personality_id"],
@@ -89,14 +89,14 @@ class InterestMatchResult(BaseDataModel):
"""兴趣匹配结果"""
message_id: str
matched_tags: List[str] = field(default_factory=list)
match_scores: Dict[str, float] = field(default_factory=dict) # tag_name -> score
matched_tags: list[str] = field(default_factory=list)
match_scores: dict[str, float] = field(default_factory=dict) # tag_name -> score
overall_score: float = 0.0
top_tag: Optional[str] = None
top_tag: str | None = None
confidence: float = 0.0 # 匹配置信度 (0.0-1.0)
matched_keywords: List[str] = field(default_factory=list)
matched_keywords: list[str] = field(default_factory=list)
def add_match(self, tag_name: str, score: float, keywords: List[str] = None):
def add_match(self, tag_name: str, score: float, keywords: list[str] = None):
"""添加匹配结果"""
self.matched_tags.append(tag_name)
self.match_scores[tag_name] = score
@@ -131,7 +131,7 @@ class InterestMatchResult(BaseDataModel):
else:
self.confidence = 0.0
def get_top_matches(self, top_n: int = 3) -> List[tuple]:
def get_top_matches(self, top_n: int = 3) -> list[tuple]:
"""获取前N个最佳匹配"""
sorted_matches = sorted(self.match_scores.items(), key=lambda x: x[1], reverse=True)
return sorted_matches[:top_n]
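For reference, get_top_matches (whose return annotation becomes list[tuple] above) is just a reverse sort over the score map; a standalone sketch with made-up tag scores:

def top_matches(match_scores: dict[str, float], top_n: int = 3) -> list[tuple[str, float]]:
    """Highest-scoring (tag_name, score) pairs first, mirroring get_top_matches."""
    return sorted(match_scores.items(), key=lambda x: x[1], reverse=True)[:top_n]


scores = {"gaming": 0.82, "music": 0.41, "cooking": 0.67, "sports": 0.12}
print(top_matches(scores, top_n=2))  # [('gaming', 0.82), ('cooking', 0.67)]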

View File

@@ -1,6 +1,6 @@
import json
from typing import Optional, Any, Dict
from dataclasses import dataclass, field
from typing import Any
from . import BaseDataModel
@@ -10,7 +10,7 @@ class DatabaseUserInfo(BaseDataModel):
platform: str = field(default_factory=str)
user_id: str = field(default_factory=str)
user_nickname: str = field(default_factory=str)
user_cardname: Optional[str] = None
user_cardname: str | None = None
# def __post_init__(self):
# assert isinstance(self.platform, str), "platform must be a string"
@@ -25,7 +25,7 @@ class DatabaseUserInfo(BaseDataModel):
class DatabaseGroupInfo(BaseDataModel):
group_id: str = field(default_factory=str)
group_name: str = field(default_factory=str)
group_platform: Optional[str] = None
group_platform: str | None = None
# def __post_init__(self):
# assert isinstance(self.group_id, str), "group_id must be a string"
@@ -42,7 +42,7 @@ class DatabaseChatInfo(BaseDataModel):
create_time: float = field(default_factory=float)
last_active_time: float = field(default_factory=float)
user_info: DatabaseUserInfo = field(default_factory=DatabaseUserInfo)
group_info: Optional[DatabaseGroupInfo] = None
group_info: DatabaseGroupInfo | None = None
# def __post_init__(self):
# assert isinstance(self.stream_id, str), "stream_id must be a string"
@@ -62,41 +62,41 @@ class DatabaseMessages(BaseDataModel):
message_id: str = "",
time: float = 0.0,
chat_id: str = "",
reply_to: Optional[str] = None,
interest_value: Optional[float] = None,
key_words: Optional[str] = None,
key_words_lite: Optional[str] = None,
is_mentioned: Optional[bool] = None,
is_at: Optional[bool] = None,
reply_probability_boost: Optional[float] = None,
processed_plain_text: Optional[str] = None,
display_message: Optional[str] = None,
priority_mode: Optional[str] = None,
priority_info: Optional[str] = None,
additional_config: Optional[str] = None,
reply_to: str | None = None,
interest_value: float | None = None,
key_words: str | None = None,
key_words_lite: str | None = None,
is_mentioned: bool | None = None,
is_at: bool | None = None,
reply_probability_boost: float | None = None,
processed_plain_text: str | None = None,
display_message: str | None = None,
priority_mode: str | None = None,
priority_info: str | None = None,
additional_config: str | None = None,
is_emoji: bool = False,
is_picid: bool = False,
is_command: bool = False,
is_notify: bool = False,
selected_expressions: Optional[str] = None,
selected_expressions: str | None = None,
is_read: bool = False,
user_id: str = "",
user_nickname: str = "",
user_cardname: Optional[str] = None,
user_cardname: str | None = None,
user_platform: str = "",
chat_info_group_id: Optional[str] = None,
chat_info_group_name: Optional[str] = None,
chat_info_group_platform: Optional[str] = None,
chat_info_group_id: str | None = None,
chat_info_group_name: str | None = None,
chat_info_group_platform: str | None = None,
chat_info_user_id: str = "",
chat_info_user_nickname: str = "",
chat_info_user_cardname: Optional[str] = None,
chat_info_user_cardname: str | None = None,
chat_info_user_platform: str = "",
chat_info_stream_id: str = "",
chat_info_platform: str = "",
chat_info_create_time: float = 0.0,
chat_info_last_active_time: float = 0.0,
# 新增字段
actions: Optional[list] = None,
actions: list | None = None,
should_reply: bool = False,
**kwargs: Any,
):
@@ -132,7 +132,7 @@ class DatabaseMessages(BaseDataModel):
self.selected_expressions = selected_expressions
self.is_read = is_read
self.group_info: Optional[DatabaseGroupInfo] = None
self.group_info: DatabaseGroupInfo | None = None
self.user_info = DatabaseUserInfo(
user_id=user_id,
user_nickname=user_nickname,
@@ -172,7 +172,7 @@ class DatabaseMessages(BaseDataModel):
# assert isinstance(self.interest_value, float) or self.interest_value is None, (
# "interest_value must be a float or None"
# )
def flatten(self) -> Dict[str, Any]:
def flatten(self) -> dict[str, Any]:
"""
将消息数据模型转换为字典格式,便于存储或传输
"""
@@ -255,7 +255,7 @@ class DatabaseMessages(BaseDataModel):
"""
return self.actions or []
def get_message_summary(self) -> Dict[str, Any]:
def get_message_summary(self) -> dict[str, Any]:
"""
获取消息摘要信息

View File

@@ -1,11 +1,14 @@
from dataclasses import dataclass, field
from typing import Optional, Dict, List, TYPE_CHECKING
from typing import TYPE_CHECKING, Optional
from src.plugin_system.base.component_types import ChatType
from . import BaseDataModel
if TYPE_CHECKING:
pass
from src.plugin_system.base.component_types import ActionInfo, ChatMode
from .database_data_model import DatabaseMessages
@dataclass
@@ -13,17 +16,17 @@ class TargetPersonInfo(BaseDataModel):
platform: str = field(default_factory=str)
user_id: str = field(default_factory=str)
user_nickname: str = field(default_factory=str)
person_id: Optional[str] = None
person_name: Optional[str] = None
person_id: str | None = None
person_name: str | None = None
@dataclass
class ActionPlannerInfo(BaseDataModel):
action_type: str = field(default_factory=str)
reasoning: Optional[str] = None
action_data: Optional[Dict] = None
action_message: Optional[Dict] = None
available_actions: Optional[Dict[str, "ActionInfo"]] = None
reasoning: str | None = None
action_data: dict | None = None
action_message: Optional["DatabaseMessages"] = None
available_actions: dict[str, "ActionInfo"] | None = None
@dataclass
@@ -35,7 +38,7 @@ class InterestScore(BaseDataModel):
interest_match_score: float
relationship_score: float
mentioned_score: float
details: Dict[str, str]
details: dict[str, str]
@dataclass
@@ -49,10 +52,10 @@ class Plan(BaseDataModel):
chat_type: "ChatType"
# Generator 填充
available_actions: Dict[str, "ActionInfo"] = field(default_factory=dict)
chat_history: List["DatabaseMessages"] = field(default_factory=list)
target_info: Optional[TargetPersonInfo] = None
available_actions: dict[str, "ActionInfo"] = field(default_factory=dict)
chat_history: list["DatabaseMessages"] = field(default_factory=list)
target_info: TargetPersonInfo | None = None
# Filter 填充
llm_prompt: Optional[str] = None
decided_actions: Optional[List[ActionPlannerInfo]] = None
llm_prompt: str | None = None
decided_actions: list[ActionPlannerInfo] | None = None

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Optional, List, Tuple, TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any
from . import BaseDataModel
@@ -9,10 +9,10 @@ if TYPE_CHECKING:
@dataclass
class LLMGenerationDataModel(BaseDataModel):
content: Optional[str] = None
reasoning: Optional[str] = None
model: Optional[str] = None
tool_calls: Optional[List["ToolCall"]] = None
prompt: Optional[str] = None
selected_expressions: Optional[List[int]] = None
reply_set: Optional[List[Tuple[str, Any]]] = None
content: str | None = None
reasoning: str | None = None
model: str | None = None
tool_calls: list["ToolCall"] | None = None
prompt: str | None = None
selected_expressions: list[int] | None = None
reply_set: list[tuple[str, Any]] | None = None

View File

@@ -7,11 +7,12 @@ import asyncio
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional, TYPE_CHECKING
from typing import TYPE_CHECKING, Optional
from src.common.logger import get_logger
from src.plugin_system.base.component_types import ChatMode, ChatType
from . import BaseDataModel
from src.plugin_system.base.component_types import ChatMode, ChatType
from src.common.logger import get_logger
if TYPE_CHECKING:
from .database_data_model import DatabaseMessages
@@ -34,11 +35,11 @@ class StreamContext(BaseDataModel):
stream_id: str
chat_type: ChatType = ChatType.PRIVATE # 聊天类型,默认为私聊
chat_mode: ChatMode = ChatMode.NORMAL # 聊天模式,默认为普通模式
unread_messages: List["DatabaseMessages"] = field(default_factory=list)
history_messages: List["DatabaseMessages"] = field(default_factory=list)
unread_messages: list["DatabaseMessages"] = field(default_factory=list)
history_messages: list["DatabaseMessages"] = field(default_factory=list)
last_check_time: float = field(default_factory=time.time)
is_active: bool = True
processing_task: Optional[asyncio.Task] = None
processing_task: asyncio.Task | None = None
interruption_count: int = 0 # 打断计数器
last_interruption_time: float = 0.0 # 上次打断时间
afc_threshold_adjustment: float = 0.0 # afc阈值调整量
@@ -49,8 +50,8 @@ class StreamContext(BaseDataModel):
# 新增字段以替代ChatMessageContext功能
current_message: Optional["DatabaseMessages"] = None
priority_mode: Optional[str] = None
priority_info: Optional[dict] = None
priority_mode: str | None = None
priority_info: dict | None = None
def add_message(self, message: "DatabaseMessages"):
"""添加消息到上下文"""
@@ -150,11 +151,11 @@ class StreamContext(BaseDataModel):
self.unread_messages.remove(msg)
break
def get_unread_messages(self) -> List["DatabaseMessages"]:
def get_unread_messages(self) -> list["DatabaseMessages"]:
"""获取未读消息"""
return [msg for msg in self.unread_messages if not msg.is_read]
def get_history_messages(self, limit: int = 20) -> List["DatabaseMessages"]:
def get_history_messages(self, limit: int = 20) -> list["DatabaseMessages"]:
"""获取历史消息"""
# 优先返回最近的历史消息和所有未读消息
recent_history = self.history_messages[-limit:] if len(self.history_messages) > limit else self.history_messages
@@ -230,7 +231,7 @@ class StreamContext(BaseDataModel):
"""设置当前消息"""
self.current_message = message
def get_template_name(self) -> Optional[str]:
def get_template_name(self) -> str | None:
"""获取模板名称"""
if (
self.current_message
@@ -336,11 +337,11 @@ class StreamContext(BaseDataModel):
return False
return True
def get_priority_mode(self) -> Optional[str]:
def get_priority_mode(self) -> str | None:
"""获取优先级模式"""
return self.priority_mode
def get_priority_info(self) -> Optional[dict]:
def get_priority_info(self) -> dict | None:
"""获取优先级信息"""
return self.priority_info

View File

@@ -1,10 +1,11 @@
import os
from rich.traceback import install
from src.common.logger import get_logger
# SQLAlchemy相关导入
from src.common.database.sqlalchemy_init import initialize_database_compat
from src.common.database.sqlalchemy_models import get_engine, get_db_session
from src.common.database.sqlalchemy_models import get_db_session, get_engine
from src.common.logger import get_logger
install(extra_lines=3)

View File

@@ -6,31 +6,31 @@
import time
import traceback
from typing import Dict, List, Any, Union, Optional
from typing import Any
from sqlalchemy import desc, asc, func, and_, select
from sqlalchemy import and_, asc, desc, func, select
from sqlalchemy.exc import SQLAlchemyError
from src.common.database.sqlalchemy_models import (
get_db_session,
Messages,
ActionRecords,
PersonInfo,
ChatStreams,
LLMUsage,
Emoji,
Images,
ImageDescriptions,
OnlineTime,
Memory,
Expression,
ThinkingLog,
GraphNodes,
GraphEdges,
Schedule,
MaiZoneScheduleStatus,
CacheEntries,
ChatStreams,
Emoji,
Expression,
GraphEdges,
GraphNodes,
ImageDescriptions,
Images,
LLMUsage,
MaiZoneScheduleStatus,
Memory,
Messages,
OnlineTime,
PersonInfo,
Schedule,
ThinkingLog,
UserRelationships,
get_db_session,
)
from src.common.logger import get_logger
@@ -59,7 +59,7 @@ MODEL_MAPPING = {
}
async def build_filters(model_class, filters: Dict[str, Any]):
async def build_filters(model_class, filters: dict[str, Any]):
"""构建查询过滤条件"""
conditions = []
@@ -98,13 +98,13 @@ async def build_filters(model_class, filters: Dict[str, Any]):
async def db_query(
model_class,
data: Optional[Dict[str, Any]] = None,
query_type: Optional[str] = "get",
filters: Optional[Dict[str, Any]] = None,
limit: Optional[int] = None,
order_by: Optional[List[str]] = None,
single_result: Optional[bool] = False,
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
data: dict[str, Any] | None = None,
query_type: str | None = "get",
filters: dict[str, Any] | None = None,
limit: int | None = None,
order_by: list[str] | None = None,
single_result: bool | None = False,
) -> list[dict[str, Any]] | dict[str, Any] | None:
"""执行异步数据库查询操作
Args:
@@ -263,8 +263,8 @@ async def db_query(
async def db_save(
model_class, data: Dict[str, Any], key_field: Optional[str] = None, key_value: Optional[Any] = None
) -> Optional[Dict[str, Any]]:
model_class, data: dict[str, Any], key_field: str | None = None, key_value: Any | None = None
) -> dict[str, Any] | None:
"""异步保存数据到数据库(创建或更新)
Args:
@@ -325,11 +325,11 @@ async def db_save(
async def db_get(
model_class,
filters: Optional[Dict[str, Any]] = None,
limit: Optional[int] = None,
order_by: Optional[str] = None,
single_result: Optional[bool] = False,
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
filters: dict[str, Any] | None = None,
limit: int | None = None,
order_by: str | None = None,
single_result: bool | None = False,
) -> list[dict[str, Any]] | dict[str, Any] | None:
"""异步从数据库获取记录
Args:
@@ -359,9 +359,9 @@ async def store_action_info(
action_prompt_display: str = "",
action_done: bool = True,
thinking_id: str = "",
action_data: Optional[dict] = None,
action_data: dict | None = None,
action_name: str = "",
) -> Optional[Dict[str, Any]]:
) -> dict[str, Any] | None:
"""异步存储动作信息到数据库
Args:
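A hedged usage sketch for the db_query helper re-typed above; the filter and order keys ("chat_id", "-time") are illustrative guesses, not field names confirmed by this diff:

from src.common.database.sqlalchemy_database_api import db_query
from src.common.database.sqlalchemy_models import Messages


async def recent_messages(chat_id: str, limit: int = 20) -> list[dict]:
    # query_type/filters/limit/order_by follow the signature shown in the hunk above;
    # the concrete field names are assumptions for illustration only.
    rows = await db_query(
        Messages,
        query_type="get",
        filters={"chat_id": chat_id},
        order_by=["-time"],
        limit=limit,
    )
    return rows or []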

View File

@@ -4,10 +4,10 @@
提供统一的异步数据库初始化接口
"""
from typing import Optional
from sqlalchemy.exc import SQLAlchemyError
from src.common.logger import get_logger
from src.common.database.sqlalchemy_models import Base, get_engine, initialize_database
from src.common.logger import get_logger
logger = get_logger("sqlalchemy_init")
@@ -71,7 +71,7 @@ async def create_all_tables() -> bool:
return False
async def get_database_info() -> Optional[dict]:
async def get_database_info() -> dict | None:
"""
异步获取数据库信息

View File

@@ -6,11 +6,12 @@
import datetime
import os
import time
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from typing import Optional, Any, Dict, AsyncGenerator
from typing import Any
from sqlalchemy import Column, String, Float, Integer, Boolean, Text, Index, DateTime, text
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from sqlalchemy import Boolean, Column, DateTime, Float, Index, Integer, String, Text, text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Mapped, mapped_column
@@ -423,7 +424,7 @@ class Expression(Base):
last_active_time: Mapped[float] = mapped_column(Float, nullable=False)
chat_id: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True)
type: Mapped[str] = mapped_column(Text, nullable=False)
create_date: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
create_date: Mapped[float | None] = mapped_column(Float, nullable=True)
__table_args__ = (Index("idx_expression_chat_id", "chat_id"),)
@@ -710,7 +711,7 @@ async def initialize_database():
config = global_config.database
# 配置引擎参数
engine_kwargs: Dict[str, Any] = {
engine_kwargs: dict[str, Any] = {
"echo": False, # 生产环境关闭SQL日志
"future": True,
}
@@ -759,12 +760,12 @@ async def initialize_database():
@asynccontextmanager
async def get_db_session() -> AsyncGenerator[Optional[AsyncSession], None]:
async def get_db_session() -> AsyncGenerator[AsyncSession | None, None]:
"""
异步数据库会话上下文管理器。
在初始化失败时会yield None调用方需要检查会话是否为None。
"""
session: Optional[AsyncSession] = None
session: AsyncSession | None = None
SessionLocal = None
try:
_, SessionLocal = await initialize_database()
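The docstring kept in this hunk states that get_db_session() yields None when database initialization fails, so callers still need the guard. A minimal caller sketch; the count query is illustrative:

from sqlalchemy import func, select

from src.common.database.sqlalchemy_models import Messages, get_db_session


async def count_messages() -> int:
    async with get_db_session() as session:
        if session is None:  # initialization failed; degrade gracefully
            return 0
        result = await session.execute(select(func.count()).select_from(Messages))
        return int(result.scalar_one())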

View File

@@ -1,16 +1,16 @@
# 使用基于时间戳的文件处理器,简单的轮转份数限制
import logging
import orjson
import threading
import time
from collections.abc import Callable
from datetime import datetime, timedelta
from pathlib import Path
import orjson
import structlog
import tomlkit
from pathlib import Path
from typing import Callable, Optional
from datetime import datetime, timedelta
# 创建logs目录
LOG_DIR = Path("logs")
LOG_DIR.mkdir(exist_ok=True)
@@ -212,7 +212,7 @@ def load_log_config(): # sourcery skip: use-contextlib-suppress
try:
if config_path.exists():
with open(config_path, "r", encoding="utf-8") as f:
with open(config_path, encoding="utf-8") as f:
config = tomlkit.load(f)
return config.get("log", default_config)
except Exception as e:
@@ -982,7 +982,7 @@ raw_logger: structlog.stdlib.BoundLogger = structlog.get_logger()
binds: dict[str, Callable] = {}
def get_logger(name: Optional[str]) -> structlog.stdlib.BoundLogger:
def get_logger(name: str | None) -> structlog.stdlib.BoundLogger:
"""获取logger实例支持按名称绑定"""
if name is None:
return raw_logger

View File

@@ -4,7 +4,6 @@ __version__ = "0.1.0"
from .api import get_global_api
__all__ = [
"get_global_api",
]

View File

@@ -1,10 +1,12 @@
from src.common.server import get_global_server
import importlib.metadata
from maim_message import MessageServer
from src.common.logger import get_logger
from src.config.config import global_config
import os
from maim_message import MessageServer
from src.common.logger import get_logger
from src.common.server import get_global_server
from src.config.config import global_config
global_api = None

View File

@@ -1,15 +1,15 @@
import traceback
from typing import Any
from typing import List, Optional, Any, Dict
from sqlalchemy import not_, select, func
from sqlalchemy import func, not_, select
from sqlalchemy.orm import DeclarativeBase
from src.config.config import global_config
from src.common.database.sqlalchemy_database_api import get_db_session
# from src.common.database.database_model import Messages
from src.common.database.sqlalchemy_models import Messages
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.logger import get_logger
from src.config.config import global_config
logger = get_logger(__name__)
@@ -18,7 +18,7 @@ class Base(DeclarativeBase):
pass
def _model_to_dict(instance: Base) -> Dict[str, Any]:
def _model_to_dict(instance: Base) -> dict[str, Any]:
"""
将 SQLAlchemy 模型实例转换为字典。
"""
@@ -32,12 +32,12 @@ def _model_to_dict(instance: Base) -> Dict[str, Any]:
async def find_messages(
message_filter: dict[str, Any],
sort: Optional[List[tuple[str, int]]] = None,
sort: list[tuple[str, int]] | None = None,
limit: int = 0,
limit_mode: str = "latest",
filter_bot=False,
filter_command=False,
) -> List[dict[str, Any]]:
) -> list[dict[str, Any]]:
"""
根据提供的过滤器、排序和限制条件查找消息。

View File

@@ -1,13 +1,13 @@
import asyncio
import base64
import json
import platform
from datetime import datetime, timezone
import aiohttp
import platform
from datetime import datetime, timezone
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import padding, rsa
from src.common.logger import get_logger
from src.common.tcp_connector import get_tcp_connector
from src.config.config import global_config

View File

@@ -1,20 +1,20 @@
import os
from typing import Optional
from fastapi import FastAPI, APIRouter
from fastapi import APIRouter, FastAPI
from fastapi.middleware.cors import CORSMiddleware # 新增导入
from rich.traceback import install
from uvicorn import Config, Server as UvicornServer
from uvicorn import Config
from uvicorn import Server as UvicornServer
install(extra_lines=3)
class Server:
def __init__(self, host: Optional[str] = None, port: Optional[int] = None, app_name: str = "MaiMCore"):
def __init__(self, host: str | None = None, port: int | None = None, app_name: str = "MaiMCore"):
self.app = FastAPI(title=app_name)
self._host: str = "127.0.0.1"
self._port: int = 8080
self._server: Optional[UvicornServer] = None
self._server: UvicornServer | None = None
self.set_address(host, port)
# 配置 CORS
@@ -57,7 +57,7 @@ class Server:
"""
self.app.include_router(router, prefix=prefix)
def set_address(self, host: Optional[str] = None, port: Optional[int] = None):
def set_address(self, host: str | None = None, port: int | None = None):
"""设置服务器地址和端口"""
if host:
self._host = host
@@ -76,7 +76,7 @@ class Server:
raise
except Exception as e:
await self.shutdown()
raise RuntimeError(f"服务器运行错误: {str(e)}") from e
raise RuntimeError(f"服务器运行错误: {e!s}") from e
finally:
await self.shutdown()

View File

@@ -1,6 +1,7 @@
import ssl
import certifi
import aiohttp
import certifi
ssl_context = ssl.create_default_context(cafile=certifi.where())

View File

@@ -18,4 +18,4 @@ def get_vector_db_service() -> VectorDBBase:
# 全局向量数据库服务实例
vector_db_service: VectorDBBase = get_vector_db_service()
__all__ = ["vector_db_service", "VectorDBBase"]
__all__ = ["VectorDBBase", "vector_db_service"]

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from typing import Any
class VectorDBBase(ABC):
@@ -36,10 +36,10 @@ class VectorDBBase(ABC):
def add(
self,
collection_name: str,
embeddings: List[List[float]],
documents: Optional[List[str]] = None,
metadatas: Optional[List[Dict[str, Any]]] = None,
ids: Optional[List[str]] = None,
embeddings: list[list[float]],
documents: list[str] | None = None,
metadatas: list[dict[str, Any]] | None = None,
ids: list[str] | None = None,
) -> None:
"""
向指定集合中添加数据。
@@ -57,11 +57,11 @@ class VectorDBBase(ABC):
def query(
self,
collection_name: str,
query_embeddings: List[List[float]],
query_embeddings: list[list[float]],
n_results: int = 1,
where: Optional[Dict[str, Any]] = None,
where: dict[str, Any] | None = None,
**kwargs: Any,
) -> Dict[str, List[Any]]:
) -> dict[str, list[Any]]:
"""
在指定集合中查询相似向量。
@@ -81,8 +81,8 @@ class VectorDBBase(ABC):
def delete(
self,
collection_name: str,
ids: Optional[List[str]] = None,
where: Optional[Dict[str, Any]] = None,
ids: list[str] | None = None,
where: dict[str, Any] | None = None,
) -> None:
"""
从指定集合中删除数据。
@@ -98,13 +98,13 @@ class VectorDBBase(ABC):
def get(
self,
collection_name: str,
ids: Optional[List[str]] = None,
where: Optional[Dict[str, Any]] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
where_document: Optional[Dict[str, Any]] = None,
include: Optional[List[str]] = None,
) -> Dict[str, Any]:
ids: list[str] | None = None,
where: dict[str, Any] | None = None,
limit: int | None = None,
offset: int | None = None,
where_document: dict[str, Any] | None = None,
include: list[str] | None = None,
) -> dict[str, Any]:
"""
根据条件从集合中获取数据。
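Combining the re-typed VectorDBBase interface with the global vector_db_service singleton from the earlier hunk, a hedged usage sketch: the collection name comes from the cache manager above, while the vectors and metadata are made up, and real embeddings must match the collection's configured dimension.

from src.common.vector_db import vector_db_service

vector_db_service.get_or_create_collection("semantic_cache")
vector_db_service.add(
    "semantic_cache",
    embeddings=[[0.1, 0.2, 0.3]],          # one 3-dim vector, purely illustrative
    documents=["cached tool result"],
    metadatas=[{"tool_name": "example_tool"}],
    ids=["example-1"],
)
hits = vector_db_service.query(
    "semantic_cache",
    query_embeddings=[[0.1, 0.2, 0.3]],
    n_results=1,
)
print(hits.get("ids"))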

View File

@@ -1,12 +1,13 @@
import threading
from typing import Any, Dict, List, Optional
from typing import Any
import chromadb
from chromadb.config import Settings
from .base import VectorDBBase
from src.common.logger import get_logger
from .base import VectorDBBase
logger = get_logger("chromadb_impl")
@@ -38,7 +39,7 @@ class ChromaDBImpl(VectorDBBase):
self.client = chromadb.PersistentClient(
path=path, settings=Settings(anonymized_telemetry=False)
)
self._collections: Dict[str, Any] = {}
self._collections: dict[str, Any] = {}
self._initialized = True
logger.info(f"ChromaDB 客户端已初始化,数据库路径: {path}")
except Exception as e:
@@ -65,10 +66,10 @@ class ChromaDBImpl(VectorDBBase):
def add(
self,
collection_name: str,
embeddings: List[List[float]],
documents: Optional[List[str]] = None,
metadatas: Optional[List[Dict[str, Any]]] = None,
ids: Optional[List[str]] = None,
embeddings: list[list[float]],
documents: list[str] | None = None,
metadatas: list[dict[str, Any]] | None = None,
ids: list[str] | None = None,
) -> None:
collection = self.get_or_create_collection(collection_name)
if collection:
@@ -85,11 +86,11 @@ class ChromaDBImpl(VectorDBBase):
def query(
self,
collection_name: str,
query_embeddings: List[List[float]],
query_embeddings: list[list[float]],
n_results: int = 1,
where: Optional[Dict[str, Any]] = None,
where: dict[str, Any] | None = None,
**kwargs: Any,
) -> Dict[str, List[Any]]:
) -> dict[str, list[Any]]:
collection = self.get_or_create_collection(collection_name)
if collection:
try:
@@ -120,7 +121,7 @@ class ChromaDBImpl(VectorDBBase):
logger.error(f"回退查询也失败: {fallback_e}")
return {}
def _process_where_condition(self, where: Dict[str, Any]) -> Optional[Dict[str, Any]]:
def _process_where_condition(self, where: dict[str, Any]) -> dict[str, Any] | None:
"""
处理where条件转换为ChromaDB支持的格式
ChromaDB支持的格式
@@ -174,13 +175,13 @@ class ChromaDBImpl(VectorDBBase):
def get(
self,
collection_name: str,
ids: Optional[List[str]] = None,
where: Optional[Dict[str, Any]] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
where_document: Optional[Dict[str, Any]] = None,
include: Optional[List[str]] = None,
) -> Dict[str, Any]:
ids: list[str] | None = None,
where: dict[str, Any] | None = None,
limit: int | None = None,
offset: int | None = None,
where_document: dict[str, Any] | None = None,
include: list[str] | None = None,
) -> dict[str, Any]:
"""根据条件从集合中获取数据"""
collection = self.get_or_create_collection(collection_name)
if collection:
@@ -217,8 +218,8 @@ class ChromaDBImpl(VectorDBBase):
def delete(
self,
collection_name: str,
ids: Optional[List[str]] = None,
where: Optional[Dict[str, Any]] = None,
ids: list[str] | None = None,
where: dict[str, Any] | None = None,
) -> None:
collection = self.get_or_create_collection(collection_name)
if collection: