数据库重构

This commit is contained in:
雅诺狐
2025-08-16 23:43:45 +08:00
committed by Windpicker-owo
parent 6bd4170c90
commit 875e02d42f
21 changed files with 841 additions and 1034 deletions

View File

@@ -12,8 +12,7 @@ from src.common.logger import get_logger
from sqlalchemy import select
from src.common.database.sqlalchemy_models import Expression
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.common.database.sqlalchemy_database_api import get_session
session = get_session()
from src.common.database.sqlalchemy_database_api import get_db_session
logger = get_logger("expression_selector")
@@ -132,14 +131,14 @@ class ExpressionSelector:
# sourcery skip: extract-duplicate-method, move-assign
# 支持多chat_id合并抽选
related_chat_ids = self.get_related_chat_ids(chat_id)
with get_db_session() as session:
# 优化一次性查询所有相关chat_id的表达方式
style_query = session.execute(select(Expression).where(
(Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style")
))
grammar_query = session.execute(select(Expression).where(
(Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "grammar")
))
style_query = session.execute(select(Expression).where(
(Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style")
))
grammar_query = session.execute(select(Expression).where(
(Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "grammar")
))
style_exprs = [
{
@@ -176,7 +175,13 @@ class ExpressionSelector:
selected_style = weighted_sample(style_exprs, style_weights, total_num)
else:
selected_style = []
return selected_style
if grammar_exprs:
grammar_weights = [expr.get("count", 1) for expr in grammar_exprs]
selected_grammar = weighted_sample(grammar_exprs, grammar_weights, grammar_num)
else:
selected_grammar = []
return selected_style, selected_grammar
def update_expressions_count_batch(self, expressions_to_update: List[Dict[str, Any]], increment: float = 0.1):
"""对一批表达方式更新count值按chat_id+type分组后一次性写入数据库"""
@@ -195,7 +200,8 @@ class ExpressionSelector:
if key not in updates_by_key:
updates_by_key[key] = expr
for chat_id, expr_type, situation, style in updates_by_key:
query = session.execute(select(Expression).where(
with get_db_session() as session:
query = session.execute(select(Expression).where(
(Expression.chat_id == chat_id)
& (Expression.type == expr_type)
& (Expression.situation == situation)
@@ -207,10 +213,11 @@ class ExpressionSelector:
new_count = min(current_count + increment, 5.0)
expr_obj.count = new_count
expr_obj.last_active_time = time.time()
session.commit()
logger.debug(
f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in db"
)
session.commit()
async def select_suitable_expressions_llm(
self,