re-style: 格式化代码
This commit is contained in:
@@ -1,18 +1,18 @@
|
||||
import orjson
|
||||
import time
|
||||
import random
|
||||
import hashlib
|
||||
import random
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from typing import List, Dict, Tuple, Optional, Any
|
||||
import orjson
|
||||
from json_repair import repair_json
|
||||
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.common.logger import get_logger
|
||||
from sqlalchemy import select
|
||||
from src.common.database.sqlalchemy_models import Expression
|
||||
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import Expression
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
logger = get_logger("expression_selector")
|
||||
|
||||
@@ -45,7 +45,7 @@ def init_prompt():
|
||||
Prompt(expression_evaluation_prompt, "expression_evaluation_prompt")
|
||||
|
||||
|
||||
def weighted_sample(population: List[Dict], weights: List[float], k: int) -> List[Dict]:
|
||||
def weighted_sample(population: list[dict], weights: list[float], k: int) -> list[dict]:
|
||||
"""按权重随机抽样"""
|
||||
if not population or not weights or k <= 0:
|
||||
return []
|
||||
@@ -95,7 +95,7 @@ class ExpressionSelector:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]:
|
||||
def _parse_stream_config_to_chat_id(stream_config_str: str) -> str | None:
|
||||
"""解析'platform:id:type'为chat_id(与get_stream_id一致)"""
|
||||
try:
|
||||
parts = stream_config_str.split(":")
|
||||
@@ -114,7 +114,7 @@ class ExpressionSelector:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_related_chat_ids(self, chat_id: str) -> List[str]:
|
||||
def get_related_chat_ids(self, chat_id: str) -> list[str]:
|
||||
"""根据expression.rules配置,获取与当前chat_id相关的所有chat_id(包括自身)"""
|
||||
rules = global_config.expression.rules
|
||||
current_group = None
|
||||
@@ -139,7 +139,7 @@ class ExpressionSelector:
|
||||
|
||||
async def get_random_expressions(
|
||||
self, chat_id: str, total_num: int, style_percentage: float, grammar_percentage: float
|
||||
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
|
||||
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
|
||||
# sourcery skip: extract-duplicate-method, move-assign
|
||||
# 支持多chat_id合并抽选
|
||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||
@@ -195,7 +195,7 @@ class ExpressionSelector:
|
||||
return selected_style, selected_grammar
|
||||
|
||||
@staticmethod
|
||||
async def update_expressions_count_batch(expressions_to_update: List[Dict[str, Any]], increment: float = 0.1):
|
||||
async def update_expressions_count_batch(expressions_to_update: list[dict[str, Any]], increment: float = 0.1):
|
||||
"""对一批表达方式更新count值,按chat_id+type分组后一次性写入数据库"""
|
||||
if not expressions_to_update:
|
||||
return
|
||||
@@ -240,8 +240,8 @@ class ExpressionSelector:
|
||||
chat_info: str,
|
||||
max_num: int = 10,
|
||||
min_num: int = 5,
|
||||
target_message: Optional[str] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
target_message: str | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
# sourcery skip: inline-variable, list-comprehension
|
||||
"""使用LLM选择适合的表达方式"""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user