数据库重构
This commit is contained in:
@@ -7,7 +7,7 @@ from datetime import datetime
|
||||
from typing import List, Dict, Optional, Any, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.sqlalchemy_database_api import get_session
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from sqlalchemy import select
|
||||
from src.common.database.sqlalchemy_models import Expression
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
@@ -22,7 +22,6 @@ DECAY_DAYS = 30 # 30天衰减到0.01
|
||||
DECAY_MIN = 0.01 # 最小衰减值
|
||||
|
||||
logger = get_logger("expressor")
|
||||
session = get_session()
|
||||
|
||||
def format_create_date(timestamp: float) -> str:
|
||||
"""
|
||||
@@ -204,7 +203,8 @@ class ExpressionLearner:
|
||||
learnt_grammar_expressions = []
|
||||
|
||||
# 直接从数据库查询
|
||||
style_query = session.execute(select(Expression).where((Expression.chat_id == self.chat_id) & (Expression.type == "style")))
|
||||
with get_db_session() as session:
|
||||
style_query = session.execute(select(Expression).where((Expression.chat_id == self.chat_id) & (Expression.type == "style")))
|
||||
for expr in style_query.scalars():
|
||||
# 确保create_date存在,如果不存在则使用last_active_time
|
||||
create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
|
||||
@@ -247,8 +247,9 @@ class ExpressionLearner:
|
||||
对数据库中的所有表达方式应用全局衰减
|
||||
"""
|
||||
try:
|
||||
# 获取所有表达方式
|
||||
all_expressions = session.execute(select(Expression)).scalars()
|
||||
with get_db_session() as session:
|
||||
# 获取所有表达方式
|
||||
all_expressions = session.execute(select(Expression)).scalars()
|
||||
|
||||
updated_count = 0
|
||||
deleted_count = 0
|
||||
@@ -265,19 +266,19 @@ class ExpressionLearner:
|
||||
if new_count <= 0.01:
|
||||
# 如果count太小,删除这个表达方式
|
||||
session.delete(expr)
|
||||
session.commit()
|
||||
deleted_count += 1
|
||||
else:
|
||||
# 更新count
|
||||
expr.count = new_count
|
||||
updated_count += 1
|
||||
|
||||
session.commit()
|
||||
|
||||
|
||||
if updated_count > 0 or deleted_count > 0:
|
||||
logger.info(f"全局衰减完成:更新了 {updated_count} 个表达方式,删除了 {deleted_count} 个表达方式")
|
||||
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(f"数据库全局衰减失败: {e}")
|
||||
|
||||
def calculate_decay_factor(self, time_diff_days: float) -> float:
|
||||
@@ -355,43 +356,46 @@ class ExpressionLearner:
|
||||
for chat_id, expr_list in chat_dict.items():
|
||||
for new_expr in expr_list:
|
||||
# 查找是否已存在相似表达方式
|
||||
query = session.execute(select(Expression).where(
|
||||
with get_db_session() as session:
|
||||
query = session.execute(select(Expression).where(
|
||||
(Expression.chat_id == chat_id)
|
||||
& (Expression.type == type)
|
||||
& (Expression.situation == new_expr["situation"])
|
||||
& (Expression.style == new_expr["style"])
|
||||
)).scalar()
|
||||
if query:
|
||||
expr_obj = query
|
||||
if query:
|
||||
expr_obj = query
|
||||
# 50%概率替换内容
|
||||
if random.random() < 0.5:
|
||||
expr_obj.situation = new_expr["situation"]
|
||||
expr_obj.style = new_expr["style"]
|
||||
expr_obj.count = expr_obj.count + 1
|
||||
expr_obj.last_active_time = current_time
|
||||
else:
|
||||
new_expression = Expression(
|
||||
situation=new_expr["situation"],
|
||||
style=new_expr["style"],
|
||||
count=1,
|
||||
last_active_time=current_time,
|
||||
chat_id=chat_id,
|
||||
type=type,
|
||||
create_date=current_time, # 手动设置创建日期
|
||||
)
|
||||
session.add(new_expression)
|
||||
# 限制最大数量
|
||||
exprs = list(
|
||||
session.execute(select(Expression)
|
||||
.where((Expression.chat_id == chat_id) & (Expression.type == type))
|
||||
.order_by(Expression.count.asc())).scalars()
|
||||
)
|
||||
if len(exprs) > MAX_EXPRESSION_COUNT:
|
||||
# 删除count最小的多余表达方式
|
||||
for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]:
|
||||
session.delete(expr)
|
||||
session.commit()
|
||||
return learnt_expressions
|
||||
if random.random() < 0.5:
|
||||
expr_obj.situation = new_expr["situation"]
|
||||
expr_obj.style = new_expr["style"]
|
||||
expr_obj.count = expr_obj.count + 1
|
||||
expr_obj.last_active_time = current_time
|
||||
else:
|
||||
new_expression = Expression(
|
||||
situation=new_expr["situation"],
|
||||
style=new_expr["style"],
|
||||
count=1,
|
||||
last_active_time=current_time,
|
||||
chat_id=chat_id,
|
||||
type=type,
|
||||
create_date=current_time, # 手动设置创建日期
|
||||
)
|
||||
session.add(new_expression)
|
||||
session.commit()
|
||||
# 限制最大数量
|
||||
exprs = list(
|
||||
session.execute(select(Expression)
|
||||
.where((Expression.chat_id == chat_id) & (Expression.type == type))
|
||||
.order_by(Expression.count.asc())).scalars()
|
||||
)
|
||||
if len(exprs) > MAX_EXPRESSION_COUNT:
|
||||
# 删除count最小的多余表达方式
|
||||
for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]:
|
||||
session.delete(expr)
|
||||
session.commit()
|
||||
|
||||
return learnt_expressions
|
||||
|
||||
async def learn_expression(self, type: str, num: int = 10) -> Optional[Tuple[List[Tuple[str, str, str]], str]]:
|
||||
"""从指定聊天流学习表达方式
|
||||
@@ -574,8 +578,8 @@ class ExpressionLearnerManager:
|
||||
continue
|
||||
|
||||
# 查重:同chat_id+type+situation+style
|
||||
|
||||
query = session.execute(select(Expression).where(
|
||||
with get_db_session() as session:
|
||||
query = session.execute(select(Expression).where(
|
||||
(Expression.chat_id == chat_id)
|
||||
& (Expression.type == type_str)
|
||||
& (Expression.situation == situation)
|
||||
@@ -596,6 +600,8 @@ class ExpressionLearnerManager:
|
||||
create_date=last_active_time, # 迁移时使用last_active_time作为创建时间
|
||||
)
|
||||
session.add(new_expression)
|
||||
session.commit()
|
||||
|
||||
migrated_count += 1
|
||||
logger.info(f"已迁移 {expr_file} 到数据库,包含 {len(expressions)} 个表达方式")
|
||||
except json.JSONDecodeError as e:
|
||||
@@ -627,21 +633,22 @@ class ExpressionLearnerManager:
|
||||
使用last_active_time作为create_date的默认值
|
||||
"""
|
||||
try:
|
||||
# 查找所有create_date为空的表达方式
|
||||
old_expressions = session.execute(select(Expression).where(Expression.create_date.is_(None))).scalars()
|
||||
updated_count = 0
|
||||
with get_db_session() as session:
|
||||
# 查找所有create_date为空的表达方式
|
||||
old_expressions = session.execute(select(Expression).where(Expression.create_date.is_(None))).scalars()
|
||||
updated_count = 0
|
||||
|
||||
for expr in old_expressions:
|
||||
# 使用last_active_time作为create_date
|
||||
expr.create_date = expr.last_active_time
|
||||
updated_count += 1
|
||||
for expr in old_expressions:
|
||||
# 使用last_active_time作为create_date
|
||||
expr.create_date = expr.last_active_time
|
||||
updated_count += 1
|
||||
|
||||
session.commit()
|
||||
|
||||
|
||||
if updated_count > 0:
|
||||
logger.info(f"已为 {updated_count} 个老的表达方式设置创建日期")
|
||||
if updated_count > 0:
|
||||
logger.info(f"已为 {updated_count} 个老的表达方式设置创建日期")
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(f"迁移老数据创建日期失败: {e}")
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user