feat:更好的配置文件更新,表达方式迁移到数据库
This commit is contained in:
@@ -11,6 +11,7 @@ from src.config.config import global_config
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from .expression_learner import get_expression_learner
|
||||
from src.common.database.database_model import Expression
|
||||
|
||||
logger = get_logger("expression_selector")
|
||||
|
||||
@@ -84,88 +85,77 @@ class ExpressionSelector:
|
||||
def get_random_expressions(
|
||||
self, chat_id: str, total_num: int, style_percentage: float, grammar_percentage: float
|
||||
) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]:
|
||||
# sourcery skip: extract-duplicate-method, move-assign
|
||||
(
|
||||
learnt_style_expressions,
|
||||
learnt_grammar_expressions,
|
||||
) = self.expression_learner.get_expression_by_chat_id(chat_id)
|
||||
|
||||
# 直接数据库查询
|
||||
style_query = Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == "style"))
|
||||
grammar_query = Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == "grammar"))
|
||||
style_exprs = [
|
||||
{
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"count": expr.count,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": chat_id,
|
||||
"type": "style"
|
||||
} for expr in style_query
|
||||
]
|
||||
grammar_exprs = [
|
||||
{
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"count": expr.count,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": chat_id,
|
||||
"type": "grammar"
|
||||
} for expr in grammar_query
|
||||
]
|
||||
style_num = int(total_num * style_percentage)
|
||||
grammar_num = int(total_num * grammar_percentage)
|
||||
|
||||
# 按权重抽样(使用count作为权重)
|
||||
if learnt_style_expressions:
|
||||
style_weights = [expr.get("count", 1) for expr in learnt_style_expressions]
|
||||
selected_style = weighted_sample(learnt_style_expressions, style_weights, style_num)
|
||||
if style_exprs:
|
||||
style_weights = [expr.get("count", 1) for expr in style_exprs]
|
||||
selected_style = weighted_sample(style_exprs, style_weights, style_num)
|
||||
else:
|
||||
selected_style = []
|
||||
|
||||
if learnt_grammar_expressions:
|
||||
grammar_weights = [expr.get("count", 1) for expr in learnt_grammar_expressions]
|
||||
selected_grammar = weighted_sample(learnt_grammar_expressions, grammar_weights, grammar_num)
|
||||
if grammar_exprs:
|
||||
grammar_weights = [expr.get("count", 1) for expr in grammar_exprs]
|
||||
selected_grammar = weighted_sample(grammar_exprs, grammar_weights, grammar_num)
|
||||
else:
|
||||
selected_grammar = []
|
||||
|
||||
return selected_style, selected_grammar
|
||||
|
||||
def update_expressions_count_batch(self, expressions_to_update: List[Dict[str, str]], increment: float = 0.1):
|
||||
"""对一批表达方式更新count值,按文件分组后一次性写入"""
|
||||
"""对一批表达方式更新count值,按chat_id+type分组后一次性写入数据库"""
|
||||
if not expressions_to_update:
|
||||
return
|
||||
|
||||
updates_by_file = {}
|
||||
updates_by_key = {}
|
||||
for expr in expressions_to_update:
|
||||
source_id = expr.get("source_id")
|
||||
if not source_id:
|
||||
logger.warning(f"表达方式缺少source_id,无法更新: {expr}")
|
||||
expr_type = expr.get("type", "style")
|
||||
situation = expr.get("situation")
|
||||
style = expr.get("style")
|
||||
if not source_id or not situation or not style:
|
||||
logger.warning(f"表达方式缺少必要字段,无法更新: {expr}")
|
||||
continue
|
||||
|
||||
file_path = ""
|
||||
if source_id == "personality":
|
||||
file_path = os.path.join("data", "expression", "personality", "expressions.json")
|
||||
else:
|
||||
chat_id = source_id
|
||||
expr_type = expr.get("type", "style")
|
||||
if expr_type == "style":
|
||||
file_path = os.path.join("data", "expression", "learnt_style", str(chat_id), "expressions.json")
|
||||
elif expr_type == "grammar":
|
||||
file_path = os.path.join("data", "expression", "learnt_grammar", str(chat_id), "expressions.json")
|
||||
|
||||
if file_path:
|
||||
if file_path not in updates_by_file:
|
||||
updates_by_file[file_path] = []
|
||||
updates_by_file[file_path].append(expr)
|
||||
|
||||
for file_path, updates in updates_by_file.items():
|
||||
if not os.path.exists(file_path):
|
||||
continue
|
||||
|
||||
try:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
all_expressions = json.load(f)
|
||||
|
||||
# Create a dictionary for quick lookup
|
||||
expr_map = {(e.get("situation"), e.get("style")): e for e in all_expressions}
|
||||
|
||||
# Update counts in memory
|
||||
for expr_to_update in updates:
|
||||
key = (expr_to_update.get("situation"), expr_to_update.get("style"))
|
||||
if key in expr_map:
|
||||
expr_in_map = expr_map[key]
|
||||
current_count = expr_in_map.get("count", 1)
|
||||
new_count = min(current_count + increment, 5.0)
|
||||
expr_in_map["count"] = new_count
|
||||
expr_in_map["last_active_time"] = time.time()
|
||||
logger.debug(
|
||||
f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in {file_path}"
|
||||
)
|
||||
|
||||
# Save the updated list once for this file
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
json.dump(all_expressions, f, ensure_ascii=False, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"批量更新表达方式count失败 for {file_path}: {e}")
|
||||
key = (source_id, expr_type, situation, style)
|
||||
if key not in updates_by_key:
|
||||
updates_by_key[key] = expr
|
||||
for (chat_id, expr_type, situation, style), expr in updates_by_key.items():
|
||||
query = Expression.select().where(
|
||||
(Expression.chat_id == chat_id) &
|
||||
(Expression.type == expr_type) &
|
||||
(Expression.situation == situation) &
|
||||
(Expression.style == style)
|
||||
)
|
||||
if query.exists():
|
||||
expr_obj = query.get()
|
||||
current_count = expr_obj.count
|
||||
new_count = min(current_count + increment, 5.0)
|
||||
expr_obj.count = new_count
|
||||
expr_obj.last_active_time = time.time()
|
||||
expr_obj.save()
|
||||
logger.debug(
|
||||
f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in db"
|
||||
)
|
||||
|
||||
async def select_suitable_expressions_llm(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user