re-style: 格式化代码
This commit is contained in:
@@ -1,21 +1,20 @@
|
||||
import time
|
||||
import random
|
||||
import orjson
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from typing import List, Dict, Optional, Any, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
import orjson
|
||||
from sqlalchemy import select
|
||||
from src.common.database.sqlalchemy_models import Expression
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config, global_config
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive, build_anonymous_messages
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.chat_message_builder import build_anonymous_messages, get_raw_msg_by_timestamp_with_chat_inclusive
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import Expression
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
MAX_EXPRESSION_COUNT = 300
|
||||
DECAY_DAYS = 30 # 30天衰减到0.01
|
||||
@@ -193,7 +192,7 @@ class ExpressionLearner:
|
||||
logger.error(f"为聊天流 {self.chat_name} 触发学习失败: {e}")
|
||||
return False
|
||||
|
||||
async def get_expression_by_chat_id(self) -> Tuple[List[Dict[str, float]], List[Dict[str, float]]]:
|
||||
async def get_expression_by_chat_id(self) -> tuple[list[dict[str, float]], list[dict[str, float]]]:
|
||||
"""
|
||||
获取指定chat_id的style和grammar表达方式
|
||||
返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作
|
||||
@@ -341,7 +340,7 @@ class ExpressionLearner:
|
||||
return []
|
||||
|
||||
# 按chat_id分组
|
||||
chat_dict: Dict[str, List[Dict[str, Any]]] = {}
|
||||
chat_dict: dict[str, list[dict[str, Any]]] = {}
|
||||
for chat_id, situation, style in learnt_expressions:
|
||||
if chat_id not in chat_dict:
|
||||
chat_dict[chat_id] = []
|
||||
@@ -398,7 +397,7 @@ class ExpressionLearner:
|
||||
return learnt_expressions
|
||||
return None
|
||||
|
||||
async def learn_expression(self, type: str, num: int = 10) -> Optional[Tuple[List[Tuple[str, str, str]], str]]:
|
||||
async def learn_expression(self, type: str, num: int = 10) -> tuple[list[tuple[str, str, str]], str] | None:
|
||||
"""从指定聊天流学习表达方式
|
||||
|
||||
Args:
|
||||
@@ -416,7 +415,7 @@ class ExpressionLearner:
|
||||
current_time = time.time()
|
||||
|
||||
# 获取上次学习时间
|
||||
random_msg: Optional[List[Dict[str, Any]]] = await get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
random_msg: list[dict[str, Any]] | None = await get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_learning_time,
|
||||
timestamp_end=current_time,
|
||||
@@ -447,16 +446,16 @@ class ExpressionLearner:
|
||||
|
||||
logger.debug(f"学习{type_str}的response: {response}")
|
||||
|
||||
expressions: List[Tuple[str, str, str]] = self.parse_expression_response(response, chat_id)
|
||||
expressions: list[tuple[str, str, str]] = self.parse_expression_response(response, chat_id)
|
||||
|
||||
return expressions, chat_id
|
||||
|
||||
@staticmethod
|
||||
def parse_expression_response(response: str, chat_id: str) -> List[Tuple[str, str, str]]:
|
||||
def parse_expression_response(response: str, chat_id: str) -> list[tuple[str, str, str]]:
|
||||
"""
|
||||
解析LLM返回的表达风格总结,每一行提取"当"和"使用"之间的内容,存储为(situation, style)元组
|
||||
"""
|
||||
expressions: List[Tuple[str, str, str]] = []
|
||||
expressions: list[tuple[str, str, str]] = []
|
||||
for line in response.splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
@@ -562,7 +561,7 @@ class ExpressionLearnerManager:
|
||||
if not os.path.exists(expr_file):
|
||||
continue
|
||||
try:
|
||||
with open(expr_file, "r", encoding="utf-8") as f:
|
||||
with open(expr_file, encoding="utf-8") as f:
|
||||
expressions = orjson.loads(f.read())
|
||||
|
||||
if not isinstance(expressions, list):
|
||||
|
||||
Reference in New Issue
Block a user