调整对应的调用

This commit is contained in:
UnCLAS-Prommer
2025-07-30 17:07:55 +08:00
parent 3c40ceda4c
commit 6c0edd0ad7
40 changed files with 580 additions and 1236 deletions

View File

@@ -1,16 +1,17 @@
import json
import time
import random
import hashlib
from typing import List, Dict, Tuple, Optional, Any
from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from .expression_learner import get_expression_learner
from src.common.database.database_model import Expression
logger = get_logger("expression_selector")
@@ -75,10 +76,8 @@ def weighted_sample(population: List[Dict], weights: List[float], k: int) -> Lis
class ExpressionSelector:
def __init__(self):
self.expression_learner = get_expression_learner()
# TODO: API-Adapter修改标记
self.llm_model = LLMRequest(
model=global_config.model.utils_small,
request_type="expression.selector",
model_set=model_config.model_task_config.utils_small, request_type="expression.selector"
)
@staticmethod
@@ -92,7 +91,6 @@ class ExpressionSelector:
id_str = parts[1]
stream_type = parts[2]
is_group = stream_type == "group"
import hashlib
if is_group:
components = [platform, str(id_str)]
else:
@@ -108,8 +106,7 @@ class ExpressionSelector:
for group in groups:
group_chat_ids = []
for stream_config_str in group:
chat_id_candidate = self._parse_stream_config_to_chat_id(stream_config_str)
if chat_id_candidate:
if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
group_chat_ids.append(chat_id_candidate)
if chat_id in group_chat_ids:
return group_chat_ids
@@ -118,9 +115,10 @@ class ExpressionSelector:
def get_random_expressions(
self, chat_id: str, total_num: int, style_percentage: float, grammar_percentage: float
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
# sourcery skip: extract-duplicate-method, move-assign
# 支持多chat_id合并抽选
related_chat_ids = self.get_related_chat_ids(chat_id)
# 优化一次性查询所有相关chat_id的表达方式
style_query = Expression.select().where(
(Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style")
@@ -128,7 +126,7 @@ class ExpressionSelector:
grammar_query = Expression.select().where(
(Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "grammar")
)
style_exprs = [
{
"situation": expr.situation,
@@ -138,9 +136,10 @@ class ExpressionSelector:
"source_id": expr.chat_id,
"type": "style",
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
} for expr in style_query
}
for expr in style_query
]
grammar_exprs = [
{
"situation": expr.situation,
@@ -150,9 +149,10 @@ class ExpressionSelector:
"source_id": expr.chat_id,
"type": "grammar",
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
} for expr in grammar_query
}
for expr in grammar_query
]
style_num = int(total_num * style_percentage)
grammar_num = int(total_num * grammar_percentage)
# 按权重抽样使用count作为权重
@@ -174,22 +174,22 @@ class ExpressionSelector:
return
updates_by_key = {}
for expr in expressions_to_update:
source_id = expr.get("source_id")
expr_type = expr.get("type", "style")
situation = expr.get("situation")
style = expr.get("style")
source_id: str = expr.get("source_id") # type: ignore
expr_type: str = expr.get("type", "style")
situation: str = expr.get("situation") # type: ignore
style: str = expr.get("style") # type: ignore
if not source_id or not situation or not style:
logger.warning(f"表达方式缺少必要字段,无法更新: {expr}")
continue
key = (source_id, expr_type, situation, style)
if key not in updates_by_key:
updates_by_key[key] = expr
for (chat_id, expr_type, situation, style), _expr in updates_by_key.items():
for chat_id, expr_type, situation, style in updates_by_key:
query = Expression.select().where(
(Expression.chat_id == chat_id) &
(Expression.type == expr_type) &
(Expression.situation == situation) &
(Expression.style == style)
(Expression.chat_id == chat_id)
& (Expression.type == expr_type)
& (Expression.situation == situation)
& (Expression.style == style)
)
if query.exists():
expr_obj = query.get()
@@ -264,7 +264,7 @@ class ExpressionSelector:
# 4. 调用LLM
try:
content, (_, _) = await self.llm_model.generate_response_async(prompt=prompt)
content, _ = await self.llm_model.generate_response_async(prompt=prompt)
# logger.info(f"{self.log_prefix} LLM返回结果: {content}")