re-style: format code

John Richard
2025-10-02 20:26:01 +08:00
parent ecb02cae31
commit 7923eafef3
263 changed files with 3103 additions and 3123 deletions

View File

@@ -1,10 +1,9 @@
-import time
-import sys
import os
-from typing import Dict, List
+import sys
+import time
# Add project root to Python path
-from src.common.database.database_model import Expression, ChatStreams
+from src.common.database.database_model import ChatStreams, Expression
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)
@@ -30,7 +29,7 @@ def get_chat_name(chat_id: str) -> str:
return f"查询失败 ({chat_id})"
-def calculate_time_distribution(expressions) -> Dict[str, int]:
+def calculate_time_distribution(expressions) -> dict[str, int]:
"""Calculate distribution of last active time in days"""
now = time.time()
distribution = {
@@ -64,7 +63,7 @@ def calculate_time_distribution(expressions) -> Dict[str, int]:
return distribution
-def calculate_count_distribution(expressions) -> Dict[str, int]:
+def calculate_count_distribution(expressions) -> dict[str, int]:
"""Calculate distribution of count values"""
distribution = {"0-1": 0, "1-2": 0, "2-3": 0, "3-4": 0, "4-5": 0, "5-10": 0, "10+": 0}
for expr in expressions:
@@ -86,7 +85,7 @@ def calculate_count_distribution(expressions) -> Dict[str, int]:
return distribution
-def get_top_expressions_by_chat(chat_id: str, top_n: int = 5) -> List[Expression]:
+def get_top_expressions_by_chat(chat_id: str, top_n: int = 5) -> list[Expression]:
"""Get top N most used expressions for a specific chat_id"""
return Expression.select().where(Expression.chat_id == chat_id).order_by(Expression.count.desc()).limit(top_n)
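Note: the signature changes above follow the PEP 585 style, where the builtin container types are used as generics, so Dict[str, int] and List[Expression] become dict[str, int] and list[Expression] and the "from typing import Dict, List" line can be dropped on Python 3.9+. A minimal sketch of the pattern (illustrative, not taken from this repository):

def count_by_key(rows: list[str]) -> dict[str, int]:
    # builtin generics replace typing.List / typing.Dict on Python 3.9+
    totals: dict[str, int] = {}
    for row in rows:
        totals[row] = totals.get(row, 0) + 1
    return totals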

View File

@@ -1,7 +1,6 @@
-import time
-import sys
import os
-from typing import Dict, List, Tuple, Optional
+import sys
+import time
from datetime import datetime
# Add project root to Python path
@@ -35,7 +34,7 @@ def format_timestamp(timestamp: float) -> str:
return "未知时间"
-def calculate_interest_value_distribution(messages) -> Dict[str, int]:
+def calculate_interest_value_distribution(messages) -> dict[str, int]:
"""Calculate distribution of interest_value"""
distribution = {
"0.000-0.010": 0,
@@ -76,7 +75,7 @@ def calculate_interest_value_distribution(messages) -> Dict[str, int]:
return distribution
-def get_interest_value_stats(messages) -> Dict[str, float]:
+def get_interest_value_stats(messages) -> dict[str, float]:
"""Calculate basic statistics for interest_value"""
values = [
float(msg.interest_value) for msg in messages if msg.interest_value is not None and msg.interest_value != 0.0
@@ -97,7 +96,7 @@ def get_interest_value_stats(messages) -> Dict[str, float]:
}
-def get_available_chats() -> List[Tuple[str, str, int]]:
+def get_available_chats() -> list[tuple[str, str, int]]:
"""Get all available chats with message counts"""
try:
# 获取所有有消息的chat_id
@@ -130,7 +129,7 @@ def get_available_chats() -> List[Tuple[str, str, int]]:
return []
-def get_time_range_input() -> Tuple[Optional[float], Optional[float]]:
+def get_time_range_input() -> tuple[float | None, float | None]:
"""Get time range input from user"""
print("\n时间范围选择:")
print("1. 最近1天")
@@ -170,7 +169,7 @@ def get_time_range_input() -> Tuple[Optional[float], Optional[float]]:
def analyze_interest_values(
-    chat_id: Optional[str] = None, start_time: Optional[float] = None, end_time: Optional[float] = None
+    chat_id: str | None = None, start_time: float | None = None, end_time: float | None = None
) -> None:
"""Analyze interest values with optional filters"""

View File

@@ -1,13 +1,14 @@
-import tkinter as tk
-from tkinter import ttk, messagebox, filedialog, colorchooser
-import orjson
-from pathlib import Path
-import threading
-import toml
-from datetime import datetime
-from collections import defaultdict
import os
+import threading
import time
+import tkinter as tk
+from collections import defaultdict
+from datetime import datetime
+from pathlib import Path
+from tkinter import colorchooser, filedialog, messagebox, ttk
+import orjson
+import toml
class LogIndex:
@@ -409,7 +410,7 @@ class AsyncLogLoader:
file_size = os.path.getsize(file_path)
processed_size = 0
-with open(file_path, "r", encoding="utf-8") as f:
+with open(file_path, encoding="utf-8") as f:
line_count = 0
batch_size = 1000 # 批量处理
@@ -561,7 +562,7 @@ class LogViewer:
try:
if config_path.exists():
-with open(config_path, "r", encoding="utf-8") as f:
+with open(config_path, encoding="utf-8") as f:
bot_config = toml.load(f)
if "log" in bot_config:
self.log_config.update(bot_config["log"])
@@ -575,7 +576,7 @@ class LogViewer:
try:
if viewer_config_path.exists():
-with open(viewer_config_path, "r", encoding="utf-8") as f:
+with open(viewer_config_path, encoding="utf-8") as f:
viewer_config = toml.load(f)
if "viewer" in viewer_config:
self.viewer_config.update(viewer_config["viewer"])
@@ -843,7 +844,7 @@ class LogViewer:
mapping_file = Path("config/module_mapping.json")
if mapping_file.exists():
try:
-with open(mapping_file, "r", encoding="utf-8") as f:
+with open(mapping_file, encoding="utf-8") as f:
custom_mapping = orjson.loads(f.read())
self.module_name_mapping.update(custom_mapping)
except Exception as e:
@@ -1172,7 +1173,7 @@ class LogViewer:
"""读取新的日志条目并返回它们"""
new_entries = []
new_modules = set() # 收集新发现的模块
-with open(self.current_log_file, "r", encoding="utf-8") as f:
+with open(self.current_log_file, encoding="utf-8") as f:
f.seek(from_position)
line_count = self.log_index.total_entries
for line in f:
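The open() edits in this file drop the redundant "r" argument: text read mode is already the default, so the two calls behave identically as long as an encoding is still passed explicitly. A small sketch (the path is hypothetical):

with open("logs/app.log", encoding="utf-8") as f:  # mode defaults to "r" (text, read-only)
    for line in f:
        print(line.rstrip())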

View File

@@ -1,36 +1,37 @@
import asyncio
+import datetime
import os
import shutil
import sys
-import orjson
-import datetime
-from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
+from pathlib import Path
from threading import Lock
-from typing import Optional
+import orjson
from json_repair import repair_json
# 将项目根目录添加到 sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
-from src.common.logger import get_logger
-from src.chat.knowledge.utils.hash import get_sha256
-from src.llm_models.utils_model import LLMRequest
-from src.config.config import model_config
-from src.chat.knowledge.open_ie import OpenIE
-from src.chat.knowledge.embedding_store import EmbeddingManager
-from src.chat.knowledge.kg_manager import KGManager
from rich.progress import (
-Progress,
BarColumn,
+MofNCompleteColumn,
+Progress,
+SpinnerColumn,
+TaskProgressColumn,
+TextColumn,
TimeElapsedColumn,
TimeRemainingColumn,
-TaskProgressColumn,
-MofNCompleteColumn,
-SpinnerColumn,
-TextColumn,
)
+from src.chat.knowledge.embedding_store import EmbeddingManager
+from src.chat.knowledge.kg_manager import KGManager
+from src.chat.knowledge.open_ie import OpenIE
+from src.chat.knowledge.utils.hash import get_sha256
+from src.common.logger import get_logger
+from src.config.config import model_config
+from src.llm_models.utils_model import LLMRequest
logger = get_logger("LPMM_LearningTool")
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
RAW_DATA_PATH = os.path.join(ROOT_PATH, "data", "lpmm_raw_data")
@@ -59,7 +60,7 @@ def clear_cache():
def process_text_file(file_path):
-with open(file_path, "r", encoding="utf-8") as f:
+with open(file_path, encoding="utf-8") as f:
raw = f.read()
return [p.strip() for p in raw.split("\n\n") if p.strip()]
@@ -86,7 +87,7 @@ def preprocess_raw_data():
# --- 模块二:信息提取 ---
-def _parse_and_repair_json(json_string: str) -> Optional[dict]:
+def _parse_and_repair_json(json_string: str) -> dict | None:
"""
尝试解析JSON字符串如果失败则尝试修复并重新解析。
@@ -249,7 +250,7 @@ def extract_information(paragraphs_dict, model_set):
# --- 模块三:数据导入 ---
-async def import_data(openie_obj: Optional[OpenIE] = None):
+async def import_data(openie_obj: OpenIE | None = None):
"""
将OpenIE数据导入知识库Embedding Store 和 KG
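The import reshuffling in this commit follows an isort-style grouping (the exact formatter is not named in the commit message): standard-library modules first, then third-party packages, then first-party src.* modules, each group alphabetized, including the names inside parenthesized from-imports such as the rich.progress block above. A minimal sketch of the resulting layout (module selection illustrative):

import os
import sys
from pathlib import Path

import orjson
from rich.progress import BarColumn, Progress, TextColumn

from src.common.logger import get_logger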

View File

@@ -4,11 +4,13 @@
提供插件manifest文件的创建、验证和管理功能
"""
+import argparse
import os
import sys
-import argparse
-import orjson
from pathlib import Path
+import orjson
from src.common.logger import get_logger
from src.plugin_system.utils.manifest_utils import (
ManifestValidator,
@@ -124,7 +126,7 @@ def validate_manifest_file(plugin_dir: str) -> bool:
return False
try:
-with open(manifest_path, "r", encoding="utf-8") as f:
+with open(manifest_path, encoding="utf-8") as f:
manifest_data = orjson.loads(f.read())
validator = ManifestValidator()

View File

@@ -1,46 +1,48 @@
import os
-import orjson
-import sys # 新增系统模块导入
# import time
import pickle
+import sys # 新增系统模块导入
from pathlib import Path
+import orjson
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
-from typing import Dict, Any, List, Optional, Type
from dataclasses import dataclass, field
from datetime import datetime
+from typing import Any
+from peewee import Field, IntegrityError, Model
from pymongo import MongoClient
from pymongo.errors import ConnectionFailure
-from peewee import Model, Field, IntegrityError
# Rich 进度条和显示组件
from rich.console import Console
+from rich.panel import Panel
from rich.progress import (
-Progress,
-TextColumn,
BarColumn,
-TaskProgressColumn,
-TimeRemainingColumn,
-TimeElapsedColumn,
+Progress,
SpinnerColumn,
+TaskProgressColumn,
+TextColumn,
+TimeElapsedColumn,
+TimeRemainingColumn,
)
from rich.table import Table
-from rich.panel import Panel
-# from rich.text import Text
+# from rich.text import Text
from src.common.database.database import db
from src.common.database.sqlalchemy_models import (
ChatStreams,
Emoji,
-Messages,
-Images,
-ImageDescriptions,
-PersonInfo,
-Knowledges,
-ThinkingLog,
-GraphNodes,
GraphEdges,
+GraphNodes,
+ImageDescriptions,
+Images,
+Knowledges,
+Messages,
+PersonInfo,
+ThinkingLog,
)
from src.common.logger import get_logger
@@ -54,12 +56,12 @@ class MigrationConfig:
"""迁移配置类"""
mongo_collection: str
-target_model: Type[Model]
-field_mapping: Dict[str, str]
+target_model: type[Model]
+field_mapping: dict[str, str]
batch_size: int = 500
enable_validation: bool = True
skip_duplicates: bool = True
-unique_fields: List[str] = field(default_factory=list) # 用于重复检查的字段
+unique_fields: list[str] = field(default_factory=list) # 用于重复检查的字段
# 数据验证相关类已移除 - 用户要求不要数据验证
@@ -73,7 +75,7 @@ class MigrationCheckpoint:
processed_count: int
last_processed_id: Any
timestamp: datetime
-batch_errors: List[Dict[str, Any]] = field(default_factory=list)
+batch_errors: list[dict[str, Any]] = field(default_factory=list)
@dataclass
@@ -88,11 +90,11 @@ class MigrationStats:
duplicate_count: int = 0
validation_errors: int = 0
batch_insert_count: int = 0
-errors: List[Dict[str, Any]] = field(default_factory=list)
-start_time: Optional[datetime] = None
-end_time: Optional[datetime] = None
+errors: list[dict[str, Any]] = field(default_factory=list)
+start_time: datetime | None = None
+end_time: datetime | None = None
-def add_error(self, doc_id: Any, error: str, doc_data: Optional[Dict] = None):
+def add_error(self, doc_id: Any, error: str, doc_data: dict | None = None):
"""添加错误记录"""
self.errors.append(
{"doc_id": str(doc_id), "error": error, "timestamp": datetime.now().isoformat(), "doc_data": doc_data}
@@ -108,10 +110,10 @@ class MigrationStats:
class MongoToSQLiteMigrator:
"""MongoDB到SQLite数据迁移器 - 使用Peewee ORM"""
-def __init__(self, mongo_uri: Optional[str] = None, database_name: Optional[str] = None):
+def __init__(self, mongo_uri: str | None = None, database_name: str | None = None):
self.database_name = database_name or os.getenv("DATABASE_NAME", "MegBot")
self.mongo_uri = mongo_uri or self._build_mongo_uri()
-self.mongo_client: Optional[MongoClient] = None
+self.mongo_client: MongoClient | None = None
self.mongo_db = None
# 迁移配置
@@ -142,7 +144,7 @@ class MongoToSQLiteMigrator:
else:
return f"mongodb://{host}:{port}/{self.database_name}"
-def _initialize_migration_configs(self) -> List[MigrationConfig]:
+def _initialize_migration_configs(self) -> list[MigrationConfig]:
"""初始化迁移配置"""
return [ # 表情包迁移配置
MigrationConfig(
@@ -306,7 +308,7 @@ class MongoToSQLiteMigrator:
),
]
-def _initialize_validation_rules(self) -> Dict[str, Any]:
+def _initialize_validation_rules(self) -> dict[str, Any]:
"""数据验证已禁用 - 返回空字典"""
return {}
@@ -337,7 +339,7 @@ class MongoToSQLiteMigrator:
self.mongo_client.close()
logger.info("MongoDB连接已关闭")
-def _get_nested_value(self, document: Dict[str, Any], field_path: str) -> Any:
+def _get_nested_value(self, document: dict[str, Any], field_path: str) -> Any:
"""获取嵌套字段的值"""
if "." not in field_path:
return document.get(field_path)
@@ -434,7 +436,7 @@ class MongoToSQLiteMigrator:
return None
-def _validate_data(self, collection_name: str, data: Dict[str, Any], doc_id: Any, stats: MigrationStats) -> bool:
+def _validate_data(self, collection_name: str, data: dict[str, Any], doc_id: Any, stats: MigrationStats) -> bool:
"""数据验证已禁用 - 始终返回True"""
return True
@@ -454,7 +456,7 @@ class MongoToSQLiteMigrator:
except Exception as e:
logger.warning(f"保存断点失败: {e}")
-def _load_checkpoint(self, collection_name: str) -> Optional[MigrationCheckpoint]:
+def _load_checkpoint(self, collection_name: str) -> MigrationCheckpoint | None:
"""加载迁移断点"""
checkpoint_file = self.checkpoint_dir / f"{collection_name}_checkpoint.pkl"
if not checkpoint_file.exists():
@@ -467,7 +469,7 @@ class MongoToSQLiteMigrator:
logger.warning(f"加载断点失败: {e}")
return None
-def _batch_insert(self, model: Type[Model], data_list: List[Dict[str, Any]]) -> int:
+def _batch_insert(self, model: type[Model], data_list: list[dict[str, Any]]) -> int:
"""批量插入数据"""
if not data_list:
return 0
@@ -494,7 +496,7 @@ class MongoToSQLiteMigrator:
return success_count
def _check_duplicate_by_unique_fields(
-self, model: Type[Model], data: Dict[str, Any], unique_fields: List[str]
+self, model: type[Model], data: dict[str, Any], unique_fields: list[str]
) -> bool:
"""根据唯一字段检查重复"""
if not unique_fields:
@@ -512,7 +514,7 @@ class MongoToSQLiteMigrator:
logger.debug(f"重复检查失败: {e}")
return False
-def _create_model_instance(self, model: Type[Model], data: Dict[str, Any]) -> Optional[Model]:
+def _create_model_instance(self, model: type[Model], data: dict[str, Any]) -> Model | None:
"""使用ORM创建模型实例"""
try:
# 过滤掉不存在的字段
@@ -669,7 +671,7 @@ class MongoToSQLiteMigrator:
return stats
-def migrate_all(self) -> Dict[str, MigrationStats]:
+def migrate_all(self) -> dict[str, MigrationStats]:
"""执行所有迁移任务"""
logger.info("开始执行数据库迁移...")
@@ -730,7 +732,7 @@ class MongoToSQLiteMigrator:
self._print_migration_summary(all_stats)
return all_stats
-def _print_migration_summary(self, all_stats: Dict[str, MigrationStats]):
+def _print_migration_summary(self, all_stats: dict[str, MigrationStats]):
"""使用Rich打印美观的迁移汇总信息"""
# 计算总体统计
total_processed = sum(stats.processed_count for stats in all_stats.values())
@@ -857,7 +859,7 @@ class MongoToSQLiteMigrator:
"""添加新的迁移配置"""
self.migration_configs.append(config)
-def migrate_single_collection(self, collection_name: str) -> Optional[MigrationStats]:
+def migrate_single_collection(self, collection_name: str) -> MigrationStats | None:
"""迁移单个指定的集合"""
config = next((c for c in self.migration_configs if c.mongo_collection == collection_name), None)
if not config:
@@ -875,7 +877,7 @@ class MongoToSQLiteMigrator:
finally:
self.disconnect_mongodb()
-def export_error_report(self, all_stats: Dict[str, MigrationStats], filepath: str):
+def export_error_report(self, all_stats: dict[str, MigrationStats], filepath: str):
"""导出错误报告"""
error_report = {
"timestamp": datetime.now().isoformat(),

View File

@@ -1,17 +1,16 @@
#!/usr/bin/env python
-# -*- coding: utf-8 -*-
"""
从现有ChromaDB数据重建JSON元数据索引
"""
import asyncio
-import sys
import os
+import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-from src.chat.memory_system.memory_system import MemorySystem
from src.chat.memory_system.memory_metadata_index import MemoryMetadataIndexEntry
+from src.chat.memory_system.memory_system import MemorySystem
from src.common.logger import get_logger
logger = get_logger(__name__)

View File

@@ -1,12 +1,11 @@
#!/usr/bin/env python
-# -*- coding: utf-8 -*-
"""
轻量烟雾测试:初始化 MemorySystem 并运行一次检索,验证 MemoryMetadata.source 访问不再报错
"""
import asyncio
-import sys
import os
+import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

View File

@@ -1,8 +1,7 @@
-import time
-import sys
import os
import re
-from typing import Dict, List, Tuple, Optional
+import sys
+import time
from datetime import datetime
# Add project root to Python path
@@ -63,7 +62,7 @@ def format_timestamp(timestamp: float) -> str:
return "未知时间"
-def calculate_text_length_distribution(messages) -> Dict[str, int]:
+def calculate_text_length_distribution(messages) -> dict[str, int]:
"""Calculate distribution of processed_plain_text length"""
distribution = {
"0": 0, # 空文本
@@ -126,7 +125,7 @@ def calculate_text_length_distribution(messages) -> Dict[str, int]:
return distribution
-def get_text_length_stats(messages) -> Dict[str, float]:
+def get_text_length_stats(messages) -> dict[str, float]:
"""Calculate basic statistics for processed_plain_text length"""
lengths = []
null_count = 0
@@ -168,7 +167,7 @@ def get_text_length_stats(messages) -> Dict[str, float]:
}
-def get_available_chats() -> List[Tuple[str, str, int]]:
+def get_available_chats() -> list[tuple[str, str, int]]:
"""Get all available chats with message counts"""
try:
# 获取所有有消息的chat_id排除特殊类型消息
@@ -202,7 +201,7 @@ def get_available_chats() -> List[Tuple[str, str, int]]:
return []
-def get_time_range_input() -> Tuple[Optional[float], Optional[float]]:
+def get_time_range_input() -> tuple[float | None, float | None]:
"""Get time range input from user"""
print("\n时间范围选择:")
print("1. 最近1天")
@@ -241,7 +240,7 @@ def get_time_range_input() -> Tuple[Optional[float], Optional[float]]:
return None, None
-def get_top_longest_messages(messages, top_n: int = 10) -> List[Tuple[str, int, str, str]]:
+def get_top_longest_messages(messages, top_n: int = 10) -> list[tuple[str, int, str, str]]:
"""Get top N longest messages"""
message_lengths = []
@@ -266,7 +265,7 @@ def get_top_longest_messages(messages, top_n: int = 10) -> List[Tuple[str, int,
def analyze_text_lengths(
-    chat_id: Optional[str] = None, start_time: Optional[float] = None, end_time: Optional[float] = None
+    chat_id: str | None = None, start_time: float | None = None, end_time: float | None = None
) -> None:
"""Analyze processed_plain_text lengths with optional filters"""

View File

@@ -30,7 +30,7 @@ def update_prompt_imports(file_path):
print(f"文件不存在: {file_path}")
return False
-with open(file_path, "r", encoding="utf-8") as f:
+with open(file_path, encoding="utf-8") as f:
content = f.read()
# 替换导入语句