typing fix
This commit is contained in:
@@ -107,11 +107,12 @@ class ExpressionLearner:
|
||||
last_active_time = expr.get("last_active_time", time.time())
|
||||
# 查重:同chat_id+type+situation+style
|
||||
from src.common.database.database_model import Expression
|
||||
|
||||
query = Expression.select().where(
|
||||
(Expression.chat_id == chat_id) &
|
||||
(Expression.type == type_str) &
|
||||
(Expression.situation == situation) &
|
||||
(Expression.style == style_val)
|
||||
(Expression.chat_id == chat_id)
|
||||
& (Expression.type == type_str)
|
||||
& (Expression.situation == situation)
|
||||
& (Expression.style == style_val)
|
||||
)
|
||||
if query.exists():
|
||||
expr_obj = query.get()
|
||||
@@ -125,7 +126,7 @@ class ExpressionLearner:
|
||||
count=count,
|
||||
last_active_time=last_active_time,
|
||||
chat_id=chat_id,
|
||||
type=type_str
|
||||
type=type_str,
|
||||
)
|
||||
logger.info(f"已迁移 {expr_file} 到数据库")
|
||||
except Exception as e:
|
||||
@@ -149,24 +150,28 @@ class ExpressionLearner:
|
||||
# 直接从数据库查询
|
||||
style_query = Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == "style"))
|
||||
for expr in style_query:
|
||||
learnt_style_expressions.append({
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"count": expr.count,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": chat_id,
|
||||
"type": "style"
|
||||
})
|
||||
learnt_style_expressions.append(
|
||||
{
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"count": expr.count,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": chat_id,
|
||||
"type": "style",
|
||||
}
|
||||
)
|
||||
grammar_query = Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == "grammar"))
|
||||
for expr in grammar_query:
|
||||
learnt_grammar_expressions.append({
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"count": expr.count,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": chat_id,
|
||||
"type": "grammar"
|
||||
})
|
||||
learnt_grammar_expressions.append(
|
||||
{
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"count": expr.count,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": chat_id,
|
||||
"type": "grammar",
|
||||
}
|
||||
)
|
||||
return learnt_style_expressions, learnt_grammar_expressions
|
||||
|
||||
def is_similar(self, s1: str, s2: str) -> bool:
|
||||
@@ -213,14 +218,16 @@ class ExpressionLearner:
|
||||
logger.error(f"全局衰减{type}表达方式失败: {e}")
|
||||
continue
|
||||
|
||||
learnt_style: Optional[List[Tuple[str, str, str]]] = []
|
||||
learnt_grammar: Optional[List[Tuple[str, str, str]]] = []
|
||||
# 学习新的表达方式(这里会进行局部衰减)
|
||||
for _ in range(3):
|
||||
learnt_style: Optional[List[Tuple[str, str, str]]] = await self.learn_and_store(type="style", num=25)
|
||||
learnt_style = await self.learn_and_store(type="style", num=25)
|
||||
if not learnt_style:
|
||||
return [], []
|
||||
|
||||
for _ in range(1):
|
||||
learnt_grammar: Optional[List[Tuple[str, str, str]]] = await self.learn_and_store(type="grammar", num=10)
|
||||
learnt_grammar = await self.learn_and_store(type="grammar", num=10)
|
||||
if not learnt_grammar:
|
||||
return [], []
|
||||
|
||||
@@ -321,10 +328,10 @@ class ExpressionLearner:
|
||||
for new_expr in expr_list:
|
||||
# 查找是否已存在相似表达方式
|
||||
query = Expression.select().where(
|
||||
(Expression.chat_id == chat_id) &
|
||||
(Expression.type == type) &
|
||||
(Expression.situation == new_expr["situation"]) &
|
||||
(Expression.style == new_expr["style"])
|
||||
(Expression.chat_id == chat_id)
|
||||
& (Expression.type == type)
|
||||
& (Expression.situation == new_expr["situation"])
|
||||
& (Expression.style == new_expr["style"])
|
||||
)
|
||||
if query.exists():
|
||||
expr_obj = query.get()
|
||||
@@ -342,13 +349,17 @@ class ExpressionLearner:
|
||||
count=1,
|
||||
last_active_time=current_time,
|
||||
chat_id=chat_id,
|
||||
type=type
|
||||
type=type,
|
||||
)
|
||||
# 限制最大数量
|
||||
exprs = list(Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == type)).order_by(Expression.count.asc()))
|
||||
exprs = list(
|
||||
Expression.select()
|
||||
.where((Expression.chat_id == chat_id) & (Expression.type == type))
|
||||
.order_by(Expression.count.asc())
|
||||
)
|
||||
if len(exprs) > MAX_EXPRESSION_COUNT:
|
||||
# 删除count最小的多余表达方式
|
||||
for expr in exprs[:len(exprs) - MAX_EXPRESSION_COUNT]:
|
||||
for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]:
|
||||
expr.delete_instance()
|
||||
return learnt_expressions
|
||||
|
||||
|
||||
Reference in New Issue
Block a user