Adjust the corresponding calls
@@ -1,16 +1,17 @@
 import json
 import time
 import random
+import hashlib

 from typing import List, Dict, Tuple, Optional, Any
 from json_repair import repair_json

 from src.llm_models.utils_model import LLMRequest
-from src.config.config import global_config
+from src.config.config import global_config, model_config
 from src.common.logger import get_logger
-from src.common.database.database_model import Expression
 from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
 from .expression_learner import get_expression_learner
+from src.common.database.database_model import Expression

 logger = get_logger("expression_selector")

@@ -75,10 +76,8 @@ def weighted_sample(population: List[Dict], weights: List[float], k: int) -> Lis
 class ExpressionSelector:
     def __init__(self):
         self.expression_learner = get_expression_learner()
-        # TODO: API-Adapter修改标记
         self.llm_model = LLMRequest(
-            model=global_config.model.utils_small,
-            request_type="expression.selector",
+            model_set=model_config.model_task_config.utils_small, request_type="expression.selector"
         )

     @staticmethod
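Note on the constructor change above: the selector's LLM client is now built from the task-scoped entry on model_config instead of a single model field on global_config. The real config classes live in src/config/config.py and are not part of this diff; the snippet below is only a hypothetical sketch of that kind of per-task indirection, not the project's actual structures.

# Hypothetical sketch only: the real global_config / model_config shapes may differ.
from dataclasses import dataclass


@dataclass
class TaskModelSet:
    model_name: str
    temperature: float


@dataclass
class ModelTaskConfig:
    utils_small: TaskModelSet


@dataclass
class ModelConfig:
    model_task_config: ModelTaskConfig


model_config = ModelConfig(
    model_task_config=ModelTaskConfig(
        utils_small=TaskModelSet(model_name="small-utility-model", temperature=0.7)
    )
)

# A consumer resolves its model from the per-task entry instead of a global field:
print(model_config.model_task_config.utils_small)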
@@ -92,7 +91,6 @@ class ExpressionSelector:
         id_str = parts[1]
         stream_type = parts[2]
         is_group = stream_type == "group"
-        import hashlib
         if is_group:
             components = [platform, str(id_str)]
         else:
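The hunk above drops the function-local import hashlib now that the module imports it at the top. The helper's full body is not shown in this diff; the sketch below only illustrates how components assembled this way are typically hashed into a stable chat_id (the separator, the non-group branch, and the md5 choice are assumptions, not taken from the source).

import hashlib
from typing import Optional


def parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]:
    # Assumed input shape for illustration: "<platform>:<id>:<group|private>".
    parts = stream_config_str.split(":")
    if len(parts) != 3:
        return None
    platform, id_str, stream_type = parts
    is_group = stream_type == "group"
    # Group chats hash (platform, id); the private-chat branch here is invented.
    components = [platform, str(id_str)] if is_group else [platform, str(id_str), "private"]
    return hashlib.md5("_".join(components).encode()).hexdigest()


print(parse_stream_config_to_chat_id("qq:12345:group"))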
@@ -108,8 +106,7 @@ class ExpressionSelector:
         for group in groups:
             group_chat_ids = []
             for stream_config_str in group:
-                chat_id_candidate = self._parse_stream_config_to_chat_id(stream_config_str)
-                if chat_id_candidate:
+                if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
                     group_chat_ids.append(chat_id_candidate)
             if chat_id in group_chat_ids:
                 return group_chat_ids
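The refactor above collapses the assign-then-test pair into a single assignment expression (the walrus operator). A standalone sketch of the same pattern, using a made-up parser in place of _parse_stream_config_to_chat_id:

from typing import List, Optional


def parse_stream_config(s: str) -> Optional[str]:
    # Toy stand-in: accept "platform:id:type" and return "platform:id", else None.
    parts = s.split(":")
    return ":".join(parts[:2]) if len(parts) == 3 else None


configs = ["qq:12345:group", "not-a-config", "qq:67890:private"]
chat_ids: List[str] = []
for cfg in configs:
    # Bind and test in one step; behaves the same as the removed two-line form.
    if chat_id := parse_stream_config(cfg):
        chat_ids.append(chat_id)

print(chat_ids)  # ['qq:12345', 'qq:67890']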
@@ -118,9 +115,10 @@ class ExpressionSelector:
     def get_random_expressions(
         self, chat_id: str, total_num: int, style_percentage: float, grammar_percentage: float
     ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
+        # sourcery skip: extract-duplicate-method, move-assign
         # 支持多chat_id合并抽选
         related_chat_ids = self.get_related_chat_ids(chat_id)

         # 优化:一次性查询所有相关chat_id的表达方式
         style_query = Expression.select().where(
             (Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style")
@@ -128,7 +126,7 @@ class ExpressionSelector:
         grammar_query = Expression.select().where(
             (Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "grammar")
         )

         style_exprs = [
             {
                 "situation": expr.situation,
@@ -138,9 +136,10 @@ class ExpressionSelector:
                 "source_id": expr.chat_id,
                 "type": "style",
                 "create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
-            } for expr in style_query
+            }
+            for expr in style_query
         ]

         grammar_exprs = [
             {
                 "situation": expr.situation,
@@ -150,9 +149,10 @@ class ExpressionSelector:
                 "source_id": expr.chat_id,
                 "type": "grammar",
                 "create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
-            } for expr in grammar_query
+            }
+            for expr in grammar_query
         ]

         style_num = int(total_num * style_percentage)
         grammar_num = int(total_num * grammar_percentage)
         # 按权重抽样(使用count作为权重)
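The sampling step above feeds expression counts into weighted_sample(population, weights, k), whose signature is visible only in the hunk header near the top of this diff; its implementation is not part of the change. Below is a sketch of one common way to do count-weighted sampling without replacement, not the project's actual code.

import random
from typing import Dict, List


def weighted_sample(population: List[Dict], weights: List[float], k: int) -> List[Dict]:
    # Draw up to k distinct items; each draw picks an index with probability
    # proportional to its remaining weight (e.g. an expression's count).
    population = list(population)
    weights = list(weights)
    picked: List[Dict] = []
    while population and len(picked) < k:
        idx = random.choices(range(len(population)), weights=weights, k=1)[0]
        picked.append(population.pop(idx))
        weights.pop(idx)
    return picked


exprs = [{"style": "a", "count": 5}, {"style": "b", "count": 1}, {"style": "c", "count": 2}]
print(weighted_sample(exprs, [e["count"] for e in exprs], k=2))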
@@ -174,22 +174,22 @@ class ExpressionSelector:
             return
         updates_by_key = {}
         for expr in expressions_to_update:
-            source_id = expr.get("source_id")
-            expr_type = expr.get("type", "style")
-            situation = expr.get("situation")
-            style = expr.get("style")
+            source_id: str = expr.get("source_id")  # type: ignore
+            expr_type: str = expr.get("type", "style")
+            situation: str = expr.get("situation")  # type: ignore
+            style: str = expr.get("style")  # type: ignore
             if not source_id or not situation or not style:
                 logger.warning(f"表达方式缺少必要字段,无法更新: {expr}")
                 continue
             key = (source_id, expr_type, situation, style)
             if key not in updates_by_key:
                 updates_by_key[key] = expr
-        for (chat_id, expr_type, situation, style), _expr in updates_by_key.items():
+        for chat_id, expr_type, situation, style in updates_by_key:
             query = Expression.select().where(
-                (Expression.chat_id == chat_id) &
-                (Expression.type == expr_type) &
-                (Expression.situation == situation) &
-                (Expression.style == style)
+                (Expression.chat_id == chat_id)
+                & (Expression.type == expr_type)
+                & (Expression.situation == situation)
+                & (Expression.style == style)
             )
             if query.exists():
                 expr_obj = query.get()
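The hunk above dedupes the incoming expressions on an identifying tuple and then looks each one up with a chained & condition; the query syntax matches peewee. Below is a self-contained sketch of the same pattern against an in-memory SQLite model; the field list and the count increment are assumptions for illustration, not the project's schema.

from peewee import SqliteDatabase, Model, CharField, IntegerField

db = SqliteDatabase(":memory:")


class Expression(Model):
    chat_id = CharField()
    type = CharField()
    situation = CharField()
    style = CharField()
    count = IntegerField(default=1)

    class Meta:
        database = db


db.create_tables([Expression])
Expression.create(chat_id="c1", type="style", situation="greeting", style="casual")

updates = [
    {"source_id": "c1", "type": "style", "situation": "greeting", "style": "casual"},
    {"source_id": "c1", "type": "style", "situation": "greeting", "style": "casual"},  # duplicate
]

# Dedupe on the identifying tuple so each matching row is touched only once.
updates_by_key = {}
for expr in updates:
    key = (expr["source_id"], expr["type"], expr["situation"], expr["style"])
    updates_by_key.setdefault(key, expr)

for chat_id, expr_type, situation, style in updates_by_key:
    query = Expression.select().where(
        (Expression.chat_id == chat_id)
        & (Expression.type == expr_type)
        & (Expression.situation == situation)
        & (Expression.style == style)
    )
    if query.exists():
        expr_obj = query.get()
        expr_obj.count += 1
        expr_obj.save()

print(Expression.get().count)  # 2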
@@ -264,7 +264,7 @@ class ExpressionSelector:

         # 4. 调用LLM
         try:
-            content, (_, _) = await self.llm_model.generate_response_async(prompt=prompt)
+            content, _ = await self.llm_model.generate_response_async(prompt=prompt)

             # logger.info(f"{self.log_prefix} LLM返回结果: {content}")

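The last hunk simplifies how the awaited result is unpacked: the second element is discarded as a whole instead of being destructured as a nested (_, _) pair. A minimal sketch with a stub coroutine standing in for LLMRequest.generate_response_async, whose real return shape is project-specific:

import asyncio
from typing import Any, Tuple


async def generate_response_async(prompt: str) -> Tuple[str, Any]:
    # Stub: returns (content, metadata), loosely mirroring the call site above.
    return "llm output", {"model": "stub", "reasoning": None}


async def main() -> None:
    # New call-site shape: keep the content, ignore whatever comes second.
    content, _ = await generate_response_async(prompt="hello")
    print(content)


asyncio.run(main())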