re-style: 格式化代码

This commit is contained in:
John Richard
2025-10-02 20:26:01 +08:00
parent ecb02cae31
commit 7923eafef3
263 changed files with 3103 additions and 3123 deletions

View File

@@ -1,21 +1,20 @@
import time
import random
import orjson
import os
import random
import time
from datetime import datetime
from typing import Any
from typing import List, Dict, Optional, Any, Tuple
from src.common.logger import get_logger
from src.common.database.sqlalchemy_database_api import get_db_session
import orjson
from sqlalchemy import select
from src.common.database.sqlalchemy_models import Expression
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive, build_anonymous_messages
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.chat_message_builder import build_anonymous_messages, get_raw_msg_by_timestamp_with_chat_inclusive
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import Expression
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
MAX_EXPRESSION_COUNT = 300
DECAY_DAYS = 30 # 30天衰减到0.01
@@ -193,7 +192,7 @@ class ExpressionLearner:
logger.error(f"为聊天流 {self.chat_name} 触发学习失败: {e}")
return False
async def get_expression_by_chat_id(self) -> Tuple[List[Dict[str, float]], List[Dict[str, float]]]:
async def get_expression_by_chat_id(self) -> tuple[list[dict[str, float]], list[dict[str, float]]]:
"""
获取指定chat_id的style和grammar表达方式
返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作
@@ -341,7 +340,7 @@ class ExpressionLearner:
return []
# 按chat_id分组
chat_dict: Dict[str, List[Dict[str, Any]]] = {}
chat_dict: dict[str, list[dict[str, Any]]] = {}
for chat_id, situation, style in learnt_expressions:
if chat_id not in chat_dict:
chat_dict[chat_id] = []
@@ -398,7 +397,7 @@ class ExpressionLearner:
return learnt_expressions
return None
async def learn_expression(self, type: str, num: int = 10) -> Optional[Tuple[List[Tuple[str, str, str]], str]]:
async def learn_expression(self, type: str, num: int = 10) -> tuple[list[tuple[str, str, str]], str] | None:
"""从指定聊天流学习表达方式
Args:
@@ -416,7 +415,7 @@ class ExpressionLearner:
current_time = time.time()
# 获取上次学习时间
random_msg: Optional[List[Dict[str, Any]]] = await get_raw_msg_by_timestamp_with_chat_inclusive(
random_msg: list[dict[str, Any]] | None = await get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
timestamp_start=self.last_learning_time,
timestamp_end=current_time,
@@ -447,16 +446,16 @@ class ExpressionLearner:
logger.debug(f"学习{type_str}的response: {response}")
expressions: List[Tuple[str, str, str]] = self.parse_expression_response(response, chat_id)
expressions: list[tuple[str, str, str]] = self.parse_expression_response(response, chat_id)
return expressions, chat_id
@staticmethod
def parse_expression_response(response: str, chat_id: str) -> List[Tuple[str, str, str]]:
def parse_expression_response(response: str, chat_id: str) -> list[tuple[str, str, str]]:
"""
解析LLM返回的表达风格总结每一行提取"""使用"之间的内容,存储为(situation, style)元组
"""
expressions: List[Tuple[str, str, str]] = []
expressions: list[tuple[str, str, str]] = []
for line in response.splitlines():
line = line.strip()
if not line:
@@ -562,7 +561,7 @@ class ExpressionLearnerManager:
if not os.path.exists(expr_file):
continue
try:
with open(expr_file, "r", encoding="utf-8") as f:
with open(expr_file, encoding="utf-8") as f:
expressions = orjson.loads(f.read())
if not isinstance(expressions, list):

View File

@@ -1,18 +1,18 @@
import orjson
import time
import random
import hashlib
import random
import time
from typing import Any
from typing import List, Dict, Tuple, Optional, Any
import orjson
from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from sqlalchemy import select
from src.common.database.sqlalchemy_models import Expression
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import Expression
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
logger = get_logger("expression_selector")
@@ -45,7 +45,7 @@ def init_prompt():
Prompt(expression_evaluation_prompt, "expression_evaluation_prompt")
def weighted_sample(population: List[Dict], weights: List[float], k: int) -> List[Dict]:
def weighted_sample(population: list[dict], weights: list[float], k: int) -> list[dict]:
"""按权重随机抽样"""
if not population or not weights or k <= 0:
return []
@@ -95,7 +95,7 @@ class ExpressionSelector:
return False
@staticmethod
def _parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]:
def _parse_stream_config_to_chat_id(stream_config_str: str) -> str | None:
"""解析'platform:id:type'为chat_id与get_stream_id一致"""
try:
parts = stream_config_str.split(":")
@@ -114,7 +114,7 @@ class ExpressionSelector:
except Exception:
return None
def get_related_chat_ids(self, chat_id: str) -> List[str]:
def get_related_chat_ids(self, chat_id: str) -> list[str]:
"""根据expression.rules配置获取与当前chat_id相关的所有chat_id包括自身"""
rules = global_config.expression.rules
current_group = None
@@ -139,7 +139,7 @@ class ExpressionSelector:
async def get_random_expressions(
self, chat_id: str, total_num: int, style_percentage: float, grammar_percentage: float
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
# sourcery skip: extract-duplicate-method, move-assign
# 支持多chat_id合并抽选
related_chat_ids = self.get_related_chat_ids(chat_id)
@@ -195,7 +195,7 @@ class ExpressionSelector:
return selected_style, selected_grammar
@staticmethod
async def update_expressions_count_batch(expressions_to_update: List[Dict[str, Any]], increment: float = 0.1):
async def update_expressions_count_batch(expressions_to_update: list[dict[str, Any]], increment: float = 0.1):
"""对一批表达方式更新count值按chat_id+type分组后一次性写入数据库"""
if not expressions_to_update:
return
@@ -240,8 +240,8 @@ class ExpressionSelector:
chat_info: str,
max_num: int = 10,
min_num: int = 5,
target_message: Optional[str] = None,
) -> List[Dict[str, Any]]:
target_message: str | None = None,
) -> list[dict[str, Any]]:
# sourcery skip: inline-variable, list-comprehension
"""使用LLM选择适合的表达方式"""