feat: better config file updating; migrate expressions to the database

This commit is contained in:
SengokuCola
2025-07-16 18:13:02 +08:00
parent e533fec8f6
commit 5c97bcf083
11 changed files with 428 additions and 244 deletions


@@ -11,6 +11,7 @@ from src.config.config import global_config
 from src.common.logger import get_logger
 from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
 from .expression_learner import get_expression_learner
+from src.common.database.database_model import Expression
 
 logger = get_logger("expression_selector")
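
For context, the new import refers to the Expression ORM model. Below is a minimal Peewee-style sketch of what that model presumably looks like, inferred only from the attributes the selector reads and writes in this diff; the field types, defaults, and database binding are assumptions, not taken from this commit.

# Hypothetical reconstruction of src.common.database.database_model.Expression,
# inferred from the attributes used in this diff; the real definition may differ.
from peewee import Model, TextField, FloatField, SqliteDatabase

db = SqliteDatabase("data/expressions.db")  # placeholder path, not from this commit

class Expression(Model):
    situation = TextField()           # situation the expression applies to
    style = TextField()               # the learnt wording/expression itself
    count = FloatField(default=1.0)   # activation weight, capped at 5.0 by the selector
    last_active_time = FloatField()   # unix timestamp of the last activation
    chat_id = TextField()             # chat the expression was learnt from
    type = TextField()                # "style" or "grammar"

    class Meta:
        database = db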
@@ -84,88 +85,77 @@ class ExpressionSelector:
     def get_random_expressions(
         self, chat_id: str, total_num: int, style_percentage: float, grammar_percentage: float
     ) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]:
         # sourcery skip: extract-duplicate-method, move-assign
-        (
-            learnt_style_expressions,
-            learnt_grammar_expressions,
-        ) = self.expression_learner.get_expression_by_chat_id(chat_id)
+        # Query the database directly
+        style_query = Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == "style"))
+        grammar_query = Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == "grammar"))
+        style_exprs = [
+            {
+                "situation": expr.situation,
+                "style": expr.style,
+                "count": expr.count,
+                "last_active_time": expr.last_active_time,
+                "source_id": chat_id,
+                "type": "style"
+            } for expr in style_query
+        ]
+        grammar_exprs = [
+            {
+                "situation": expr.situation,
+                "style": expr.style,
+                "count": expr.count,
+                "last_active_time": expr.last_active_time,
+                "source_id": chat_id,
+                "type": "grammar"
+            } for expr in grammar_query
+        ]
         style_num = int(total_num * style_percentage)
         grammar_num = int(total_num * grammar_percentage)
         # Weighted sampling, using count as the weight
-        if learnt_style_expressions:
-            style_weights = [expr.get("count", 1) for expr in learnt_style_expressions]
-            selected_style = weighted_sample(learnt_style_expressions, style_weights, style_num)
+        if style_exprs:
+            style_weights = [expr.get("count", 1) for expr in style_exprs]
+            selected_style = weighted_sample(style_exprs, style_weights, style_num)
         else:
             selected_style = []
-        if learnt_grammar_expressions:
-            grammar_weights = [expr.get("count", 1) for expr in learnt_grammar_expressions]
-            selected_grammar = weighted_sample(learnt_grammar_expressions, grammar_weights, grammar_num)
+        if grammar_exprs:
+            grammar_weights = [expr.get("count", 1) for expr in grammar_exprs]
+            selected_grammar = weighted_sample(grammar_exprs, grammar_weights, grammar_num)
         else:
             selected_grammar = []
         return selected_style, selected_grammar
 
     def update_expressions_count_batch(self, expressions_to_update: List[Dict[str, str]], increment: float = 0.1):
-        """Update the count of a batch of expressions, grouped by file and written once per file"""
+        """Update the count of a batch of expressions, grouped by chat_id + type and written to the database in one pass"""
         if not expressions_to_update:
             return
-        updates_by_file = {}
+        updates_by_key = {}
         for expr in expressions_to_update:
             source_id = expr.get("source_id")
-            if not source_id:
-                logger.warning(f"Expression is missing source_id, cannot update: {expr}")
+            expr_type = expr.get("type", "style")
+            situation = expr.get("situation")
+            style = expr.get("style")
+            if not source_id or not situation or not style:
+                logger.warning(f"Expression is missing required fields, cannot update: {expr}")
                 continue
-            file_path = ""
-            if source_id == "personality":
-                file_path = os.path.join("data", "expression", "personality", "expressions.json")
-            else:
-                chat_id = source_id
-                expr_type = expr.get("type", "style")
-                if expr_type == "style":
-                    file_path = os.path.join("data", "expression", "learnt_style", str(chat_id), "expressions.json")
-                elif expr_type == "grammar":
-                    file_path = os.path.join("data", "expression", "learnt_grammar", str(chat_id), "expressions.json")
-            if file_path:
-                if file_path not in updates_by_file:
-                    updates_by_file[file_path] = []
-                updates_by_file[file_path].append(expr)
-        for file_path, updates in updates_by_file.items():
-            if not os.path.exists(file_path):
-                continue
-            try:
-                with open(file_path, "r", encoding="utf-8") as f:
-                    all_expressions = json.load(f)
-                # Create a dictionary for quick lookup
-                expr_map = {(e.get("situation"), e.get("style")): e for e in all_expressions}
-                # Update counts in memory
-                for expr_to_update in updates:
-                    key = (expr_to_update.get("situation"), expr_to_update.get("style"))
-                    if key in expr_map:
-                        expr_in_map = expr_map[key]
-                        current_count = expr_in_map.get("count", 1)
-                        new_count = min(current_count + increment, 5.0)
-                        expr_in_map["count"] = new_count
-                        expr_in_map["last_active_time"] = time.time()
-                        logger.debug(
-                            f"Expression activated: old count={current_count:.3f}, increment={increment}, new count={new_count:.3f} in {file_path}"
-                        )
-                # Save the updated list once for this file
-                with open(file_path, "w", encoding="utf-8") as f:
-                    json.dump(all_expressions, f, ensure_ascii=False, indent=2)
-            except Exception as e:
-                logger.error(f"Batch update of expression count failed for {file_path}: {e}")
+            key = (source_id, expr_type, situation, style)
+            if key not in updates_by_key:
+                updates_by_key[key] = expr
+        for (chat_id, expr_type, situation, style), expr in updates_by_key.items():
+            query = Expression.select().where(
+                (Expression.chat_id == chat_id) &
+                (Expression.type == expr_type) &
+                (Expression.situation == situation) &
+                (Expression.style == style)
+            )
+            if query.exists():
+                expr_obj = query.get()
+                current_count = expr_obj.count
+                new_count = min(current_count + increment, 5.0)
+                expr_obj.count = new_count
+                expr_obj.last_active_time = time.time()
+                expr_obj.save()
+                logger.debug(
+                    f"Expression activated: old count={current_count:.3f}, increment={increment}, new count={new_count:.3f} in db"
+                )
 
     async def select_suitable_expressions_llm(
         self,
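
Taken together, selection and reinforcement now go through the ORM rather than the per-chat JSON files. A hedged usage sketch of the two methods touched in this hunk follows; the call site, chat id, and selector construction are hypothetical, only the method names and signatures come from the diff.

# Hypothetical call site: sample weighted expressions for a chat, then bump the
# count of the ones that were actually used, all backed by the Expression table.
selector = ExpressionSelector()  # construction details are not shown in this diff

styles, grammars = selector.get_random_expressions(
    chat_id="example_chat",   # hypothetical chat id
    total_num=10,
    style_percentage=0.6,     # up to 6 style expressions
    grammar_percentage=0.4,   # up to 4 grammar expressions
)

# Each selected dict carries source_id/type/situation/style, which is exactly the
# key that update_expressions_count_batch() groups on before touching the database.
selector.update_expressions_count_batch(styles + grammars, increment=0.1)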